# VishwamAI Experiment 2: Model Training and Evaluation

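This notebook demonstrates training and evaluating a VishwamAI transformer model, with distillation hyperparameters in the config and TPU support. It trains on randomly generated dummy data, so the reported losses and accuracies serve as a pipeline smoke test rather than a meaningful benchmark.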

## Setup and Dependencies

In [None]:
# Install required packages
!pip install jax[tpu] flax optax sentencepiece transformers tqdm safetensors duckdb dm-haiku matplotlib pandas numpy

In [None]:
import os
import sys
import json
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import jax
import jax.numpy as jnp
import flax
import optax
import haiku as hk
from tqdm.auto import tqdm

from vishwamai import (
    create_vishwamai_transformer,
    create_train_state,
    VishwamAITrainer,
    DuckDBLogger,
    compute_distillation_loss
)

# Report which accelerator JAX is using. Nothing is "enabled" here: JAX picks
# the backend itself, so ask it directly instead of relying on the TPU_NAME
# environment variable, which is not set in every TPU environment.
if jax.default_backend() == 'tpu':
    print(f"TPU devices: {jax.devices()}")
else:
    print(f"Using devices: {jax.devices()}")

## Model Configuration

In [None]:
# Experiment configuration, defined inline: model architecture, training
# schedule and distillation hyperparameters. Key order matters only for the
# exported config.json, and is the same as listed here.
config = dict(
    model_config=dict(
        vocab_size=32000,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
        max_seq_len=512,
        dropout_rate=0.1,
        attention_dropout_rate=0.1,
        use_flash_attn=True,
        use_rotary=True,
    ),
    training=dict(
        batch_size=32,
        learning_rate=1e-4,
        warmup_steps=1000,
        max_steps=10000,
        eval_frequency=500,   # run trainer.evaluate every N steps
        save_frequency=1000,  # NOTE(review): not read by any visible cell
    ),
    distillation=dict(
        temperature=2.0,
        alpha=0.5,
        label_smoothing=0.1,
    ),
)

## Data Preparation

In [None]:
def create_dummy_data(num_samples=1000, seed=42, vocab_size=None, seq_len=None):
    """Create random token data for demonstration.

    Args:
        num_samples: Number of sequences to generate.
        seed: Seed for ``np.random.default_rng``. Calls with the same seed
            consume the same random stream, so the ``input_ids`` of a smaller
            call equal the leading rows of a larger call's ``input_ids``.
            Use distinct seeds (or split a single draw) when two datasets
            must not overlap.
        vocab_size: Exclusive upper bound for token ids; defaults to
            ``config['model_config']['vocab_size']``.
        seq_len: Sequence length; defaults to
            ``config['model_config']['max_seq_len']``.

    Returns:
        Dict with ``'input_ids'`` and ``'labels'``, each a JAX integer array
        of shape ``(num_samples, seq_len)`` with values in ``[0, vocab_size)``.
    """
    if vocab_size is None:
        vocab_size = config['model_config']['vocab_size']
    if seq_len is None:
        seq_len = config['model_config']['max_seq_len']

    rng = np.random.default_rng(seed)
    shape = (num_samples, seq_len)

    # Labels are drawn independently of the inputs: pure noise, nothing to learn.
    input_ids = rng.integers(low=0, high=vocab_size, size=shape)
    labels = rng.integers(low=0, high=vocab_size, size=shape)

    return {
        'input_ids': jnp.array(input_ids),
        'labels': jnp.array(labels),
    }

# Create train and eval datasets from ONE random draw and split it, so the two
# sets contain different rows. Two separate create_dummy_data calls with the
# same (default) seed would make the eval inputs identical to the first
# training inputs, leaking training data into evaluation.
num_train_samples, num_eval_samples = 1000, 200
_dummy_pool = create_dummy_data(num_train_samples + num_eval_samples)
train_data = {key: values[:num_train_samples] for key, values in _dummy_pool.items()}
eval_data = {key: values[num_train_samples:] for key, values in _dummy_pool.items()}
del _dummy_pool  # the slices above are independent copies

print("Training data shape:", train_data['input_ids'].shape)
print("Evaluation data shape:", eval_data['input_ids'].shape)

## Model Setup and Training

In [None]:
# Build the transformer from the model section of the config, then create a
# fixed-seed PRNG key.
# NOTE(review): `rng` is not passed to anything in the visible cells.
model = create_vishwamai_transformer(config['model_config'])
rng = jax.random.PRNGKey(42)

# Create learning rate schedule
def create_learning_rate_schedule(base_lr=None, warmup_steps=None):
    """Build a linear-warmup learning-rate schedule (no decay).

    The learning rate ramps linearly from 0 at step 0 to ``base_lr`` at
    ``warmup_steps`` and stays at ``base_lr`` afterwards.

    Args:
        base_lr: Peak learning rate; defaults to
            ``config['training']['learning_rate']``.
        warmup_steps: Length of the warmup ramp; defaults to
            ``config['training']['warmup_steps']``. A value <= 0 disables
            warmup (constant ``base_lr``) instead of computing 0 / 0 = NaN
            at step 0.

    Returns:
        A function ``schedule(step)`` returning the learning rate as a JAX
        scalar; ``step`` may be a Python number or a JAX array.
    """
    if base_lr is None:
        base_lr = config['training']['learning_rate']
    if warmup_steps is None:
        warmup_steps = config['training']['warmup_steps']

    def schedule(step):
        # warmup_steps is a Python value, so this branch is static under jit.
        if warmup_steps <= 0:
            return jnp.asarray(base_lr, dtype=jnp.float32)
        warmup_factor = jnp.minimum(step / warmup_steps, 1.0)
        return base_lr * warmup_factor

    return schedule

# Wire the model, the warmup LR schedule and a DuckDB-backed logger into the
# trainer. The logger arguments are presumably (database file, experiment
# name) — TODO confirm against DuckDBLogger.
trainer = VishwamAITrainer(
    logger=DuckDBLogger('experiment2.db', 'experiment2'),
    model=model,
    config=config,
    learning_rate_schedule=create_learning_rate_schedule(),
)

## Set Up DuckDB Logging

In [None]:
# Extra DuckDB tables for per-layer transformer diagnostics. Rows are written
# by log_attention_metrics (one per layer/head/step) and
# log_activation_metrics (one per layer/step).
# NOTE(review): the FOREIGN KEY clauses assume the logger has already created
# an `experiments` table keyed by experiment_id — confirm in DuckDBLogger.
logger = trainer.logger

_diagnostic_tables_ddl = (
    """
    CREATE TABLE IF NOT EXISTS attention_stats (
        experiment_id VARCHAR,
        step INTEGER,
        layer INTEGER,
        head INTEGER,
        attention_entropy DOUBLE,
        attention_sparsity DOUBLE,
        FOREIGN KEY (experiment_id) REFERENCES experiments(experiment_id)
    )
""",
    """
    CREATE TABLE IF NOT EXISTS activation_stats (
        experiment_id VARCHAR,
        step INTEGER,
        layer INTEGER,
        activation_mean DOUBLE,
        activation_std DOUBLE,
        zero_rate DOUBLE,
        FOREIGN KEY (experiment_id) REFERENCES experiments(experiment_id)
    )
""",
)
for _ddl in _diagnostic_tables_ddl:
    logger.conn.execute(_ddl)

print("Created additional logging tables")

## Attention Pattern Analysis

In [None]:
def analyze_attention_patterns(model_output, layer_idx):
    """Summarize the attention distribution of one layer.

    Args:
        model_output: Mapping whose ``'attention_weights'`` entry is indexable
            by layer. Each entry is an array of attention probabilities whose
            last axis is the normalized one (assumed — TODO confirm the exact
            shape the model produces).
        layer_idx: Index of the layer to analyze.

    Returns:
        Dict with ``'entropy'`` (Shannon entropy in nats over the last axis;
        all leading axes are kept) and ``'sparsity'`` (scalar fraction of
        weights below 0.01).
    """
    weights = model_output['attention_weights'][layer_idx]

    # The 1e-10 offset keeps log() finite for weights that are exactly zero.
    log_weights = jnp.log(weights + 1e-10)
    entropy_last_axis = -jnp.sum(weights * log_weights, axis=-1)

    fraction_below_threshold = jnp.mean(weights < 0.01)

    return {'entropy': entropy_last_axis, 'sparsity': fraction_below_threshold}

def log_attention_metrics(logger, step, model_outputs):
    """Log per-layer, per-head attention statistics to ``attention_stats``.

    For each of ``config['model_config']['num_layers']`` layers, this calls
    ``analyze_attention_patterns`` and builds one row per head index in
    ``range(config['model_config']['num_heads'])``. All rows for the step go
    to DuckDB in one ``executemany`` call instead of one ``execute`` per
    (layer, head) pair (144 round trips per step with the default config).

    Args:
        logger: Object exposing a DuckDB connection as ``logger.conn``.
        step: Training step stored with every row.
        model_outputs: Model output mapping forwarded to
            ``analyze_attention_patterns``.
    """
    num_heads = config['model_config']['num_heads']
    rows = []
    for layer_idx in range(config['model_config']['num_layers']):
        metrics = analyze_attention_patterns(model_outputs, layer_idx)
        # Sparsity is one scalar per layer: convert it once, repeat it per head.
        layer_sparsity = float(metrics['sparsity'])
        for head_idx in range(num_heads):
            # NOTE(review): this indexes the FIRST axis of the entropy array.
            # If attention weights are laid out (batch, heads, ...), that axis
            # is the batch, not the head — confirm the model's output layout.
            head_entropy = float(metrics['entropy'][head_idx].mean())
            rows.append([
                trainer.experiment_name,
                step,
                layer_idx,
                head_idx,
                head_entropy,
                layer_sparsity,
            ])

    if rows:  # executemany with an empty parameter list is not meaningful
        logger.conn.executemany("""
            INSERT INTO attention_stats
            (experiment_id, step, layer, head, attention_entropy, attention_sparsity)
            VALUES (?, ?, ?, ?, ?, ?)
        """, rows)

## Activation Statistics

In [None]:
def analyze_activations(model_output, layer_idx):
    """Summarize the hidden-state activations of one layer.

    Args:
        model_output: Mapping whose ``'hidden_states'`` entry is indexable by
            layer (the shape of each entry is not fixed by this code).
        layer_idx: Index of the layer to analyze.

    Returns:
        Dict of scalars computed over all elements: ``'mean'``, ``'std'``,
        and ``'zero_rate'`` (fraction with absolute value below 1e-6).
    """
    hidden = model_output['hidden_states'][layer_idx]
    near_zero = jnp.abs(hidden) < 1e-6
    return dict(
        mean=jnp.mean(hidden),
        std=jnp.std(hidden),
        zero_rate=jnp.mean(near_zero),
    )

def log_activation_metrics(logger, step, model_outputs):
    """Log per-layer activation statistics to the ``activation_stats`` table.

    Calls ``analyze_activations`` for each of
    ``config['model_config']['num_layers']`` layers and builds one row per
    layer. All rows for the step are sent in a single ``executemany`` call
    rather than one ``execute`` per layer.

    Args:
        logger: Object exposing a DuckDB connection as ``logger.conn``.
        step: Training step stored with every row.
        model_outputs: Model output mapping forwarded to
            ``analyze_activations``.
    """
    rows = []
    for layer_idx in range(config['model_config']['num_layers']):
        metrics = analyze_activations(model_outputs, layer_idx)
        rows.append([
            trainer.experiment_name,
            step,
            layer_idx,
            float(metrics['mean']),
            float(metrics['std']),
            float(metrics['zero_rate']),
        ])

    if rows:  # executemany with an empty parameter list is not meaningful
        logger.conn.executemany("""
            INSERT INTO activation_stats
            (experiment_id, step, layer, activation_mean, activation_std, zero_rate)
            VALUES (?, ?, ?, ?, ?, ?)
        """, rows)

## Input Validation and Error Handling

In [None]:
def validate_model_inputs(batch, batch_size=None, max_seq_len=None):
    """Check that a training batch has the keys, types and shapes the model expects.

    Args:
        batch: Mapping that must contain ``'input_ids'`` and ``'labels'``,
            each a 2-D JAX array of shape ``(batch_size, max_seq_len)``.
        batch_size: Expected leading dimension; defaults to
            ``config['training']['batch_size']``.
        max_seq_len: Expected second dimension; defaults to
            ``config['model_config']['max_seq_len']``.

    Raises:
        ValueError: A required key is missing, an array is not 2-D, or a
            dimension does not match the expected size.
        TypeError: A value is not a JAX array.
    """
    if batch_size is None:
        batch_size = config['training']['batch_size']
    if max_seq_len is None:
        max_seq_len = config['model_config']['max_seq_len']

    for key in ('input_ids', 'labels'):
        if key not in batch:
            raise ValueError(f"Missing required key '{key}' in batch")

        value = batch[key]
        if not isinstance(value, jnp.ndarray):
            raise TypeError(f"Batch['{key}'] must be a JAX array")

        # Check rank before indexing shape[1]: on a 1-D array that lookup
        # would raise an IndexError instead of a clear validation error.
        if value.ndim != 2:
            raise ValueError(f"Expected a 2-D array (batch, seq_len) for '{key}', "
                             f"got shape {value.shape}")

        if value.shape[0] != batch_size:
            raise ValueError(f"Batch size mismatch in '{key}': "
                             f"expected {batch_size}, "
                             f"got {value.shape[0]}")

        if value.shape[1] != max_seq_len:
            raise ValueError(f"Sequence length mismatch in '{key}': "
                             f"expected {max_seq_len}, "
                             f"got {value.shape[1]}")

def log_training_error(logger, step, error):
    """Record an exception caught during training in the ``training_errors`` table.

    Args:
        logger: Object exposing a DuckDB connection as ``logger.conn``.
        step: Training step at which the error occurred.
        error: The caught exception; its class name and ``str()`` message
            are stored.
    """
    row = [trainer.experiment_name, step, type(error).__name__, str(error)]
    logger.conn.execute("""
        INSERT INTO training_errors
        (experiment_id, step, error_type, error_message)
        VALUES (?, ?, ?, ?)
    """, row)

class ExperimentMonitor:
    """Monitor training progress and detect issues.

    Tracks the per-step loss. It raises ``ValueError`` on a loss explosion or
    on too many NaN/Inf losses, and prints a warning when the loss stagnates.

    Attributes:
        config: The experiment config passed at construction (stored only).
        loss_history: Every loss seen, including non-finite values.
        finite_loss_history: Finite losses only. The explosion and stagnation
            checks use this list, so a single NaN/Inf cannot poison their
            averages. Previously, a NaN among the first 10 losses made the
            baseline NaN and silently disabled explosion detection.
        nan_count: Number of NaN losses seen.
        inf_count: Number of Inf losses seen.
    """
    # Detection thresholds (class-level defaults; override per instance if needed).
    window = 10               # size of the baseline and the recent windows
    explosion_factor = 1000   # recent mean > factor * baseline mean => explosion
    stagnation_window = 100   # number of recent losses checked for stagnation
    stagnation_std = 1e-6     # std below this over the window => warning
    max_nonfinite = 5         # NaN (and, separately, Inf) losses tolerated

    def __init__(self, config):
        self.config = config
        self.loss_history = []
        self.finite_loss_history = []
        self.nan_count = 0
        self.inf_count = 0

    def check_metrics(self, metrics):
        """Check one step's metrics for training problems.

        Args:
            metrics: Mapping that may contain a scalar ``'loss'`` (anything
                ``float()`` accepts, e.g. a Python float or a 0-d JAX array).

        Raises:
            ValueError: On a loss explosion, or once more than
                ``max_nonfinite`` NaN (or Inf) losses have been seen.
        """
        loss = metrics.get('loss')
        if loss is not None:
            loss_value = float(loss)
            self.loss_history.append(loss_value)

            if np.isnan(loss_value):
                self.nan_count += 1
            elif np.isinf(loss_value):
                self.inf_count += 1
            else:
                self.finite_loss_history.append(loss_value)
                self._check_explosion()
                self._check_stagnation()

        if self.nan_count > self.max_nonfinite:
            raise ValueError("Too many NaN values in loss")
        if self.inf_count > self.max_nonfinite:
            raise ValueError("Too many Inf values in loss")

    def _check_explosion(self):
        """Raise ValueError if recent finite losses blew up relative to the first ones."""
        history = self.finite_loss_history
        if len(history) > self.window:
            recent_mean = np.mean(history[-self.window:])
            baseline_mean = np.mean(history[:self.window])
            if recent_mean > self.explosion_factor * baseline_mean:
                raise ValueError("Loss explosion detected")

    def _check_stagnation(self):
        """Print a warning if the most recent finite losses are (nearly) constant."""
        history = self.finite_loss_history
        if len(history) > self.stagnation_window:
            if np.std(history[-self.stagnation_window:]) < self.stagnation_std:
                print("Warning: Loss may be stagnating")

# Table for exceptions caught by the training loop (rows are written by
# log_training_error).
logger.conn.execute("""
    CREATE TABLE IF NOT EXISTS training_errors (
        experiment_id VARCHAR,
        step INTEGER,
        error_type VARCHAR,
        error_message TEXT,
        FOREIGN KEY (experiment_id) REFERENCES experiments(experiment_id)
    )
""")

# Monitor that checks each step's loss for NaN/Inf, explosion and stagnation.
monitor = ExperimentMonitor(config)

## Training Loop with Input Validation and Monitoring

In [None]:
# Validated training loop: each batch is checked before the step, the loss is
# monitored, and every caught exception is recorded in training_errors.
# ValueError / RuntimeError stop training immediately. Other errors skip the
# step, but training also stops after `max_consecutive_errors` failures in a
# row, so a persistent problem is not re-logged on every one of max_steps.
batch_size = config['training']['batch_size']
num_batches = len(train_data['input_ids']) // batch_size
if num_batches == 0:
    # Otherwise `step % num_batches` raises ZeroDivisionError on every step.
    raise ValueError(f"Training set ({len(train_data['input_ids'])} samples) "
                     f"is smaller than one batch ({batch_size})")

# Diagnostics are logged every `log_frequency` steps. The default of 1 logs
# every step; each logged step writes num_layers * num_heads attention rows
# plus num_layers activation rows.
log_frequency = config['training'].get('log_frequency', 1)
max_consecutive_errors = 10

consecutive_errors = 0
stopped_early = False
for step in tqdm(range(config['training']['max_steps'])):
    try:
        # Cycle through the full batches of the training set.
        start_idx = (step % num_batches) * batch_size
        end_idx = start_idx + batch_size
        batch = {
            'input_ids': train_data['input_ids'][start_idx:end_idx],
            'labels': train_data['labels'][start_idx:end_idx],
        }
        validate_model_inputs(batch)

        metrics = trainer.train_step(batch)
        monitor.check_metrics(metrics)

        if step % log_frequency == 0:
            log_attention_metrics(trainer.logger, step, metrics['model_outputs'])
            log_activation_metrics(trainer.logger, step, metrics['model_outputs'])

        if step > 0 and step % config['training']['eval_frequency'] == 0:
            eval_metrics = trainer.evaluate(eval_data)
            print(f"\nStep {step} evaluation:")
            print(f"Loss: {eval_metrics['loss']:.4f}")
            print(f"Accuracy: {eval_metrics.get('accuracy', 0):.4f}")

        consecutive_errors = 0
    except Exception as e:
        log_training_error(trainer.logger, step, e)
        print(f"Error at step {step}: {str(e)}")
        consecutive_errors += 1
        # Critical errors, or a run of repeated non-critical ones, stop training.
        if isinstance(e, (ValueError, RuntimeError)) or consecutive_errors >= max_consecutive_errors:
            stopped_early = True
            break

print(f"Training stopped early at step {step}." if stopped_early else "Training completed!")

In [None]:
# Plain (unvalidated) training loop, kept for reference. Disabled by default.
# The validated loop above already runs `max_steps` steps on this same
# trainer, so running this cell as well would presumably continue training
# for another `max_steps` steps. It would also log a second set of
# attention/activation rows under the same step numbers (0..max_steps-1),
# skewing the per-layer averages queried below.
# Set RUN_UNVALIDATED_LOOP = True only when running this cell INSTEAD of the
# validated loop.
RUN_UNVALIDATED_LOOP = False

if RUN_UNVALIDATED_LOOP:
    print("Starting training...")

    for step in tqdm(range(config['training']['max_steps'])):
        # Cycle through the full batches of the training set.
        batch_idx = step % (len(train_data['input_ids']) // config['training']['batch_size'])
        start_idx = batch_idx * config['training']['batch_size']
        end_idx = start_idx + config['training']['batch_size']

        batch = {
            'input_ids': train_data['input_ids'][start_idx:end_idx],
            'labels': train_data['labels'][start_idx:end_idx]
        }

        # Training step
        metrics = trainer.train_step(batch)

        # Log attention and activation metrics
        log_attention_metrics(trainer.logger, step, metrics['model_outputs'])
        log_activation_metrics(trainer.logger, step, metrics['model_outputs'])

        # Evaluation
        if step > 0 and step % config['training']['eval_frequency'] == 0:
            eval_metrics = trainer.evaluate(eval_data)
            print(f"\nStep {step} evaluation:")
            print(f"Loss: {eval_metrics['loss']:.4f}")
            print(f"Accuracy: {eval_metrics.get('accuracy', 0):.4f}")

    print("Training completed!")
else:
    print("Skipping unvalidated training loop (RUN_UNVALIDATED_LOOP is False).")

## Results Analysis

In [None]:
# Plot the loss history (and the accuracy history, if one was logged) from
# the logger's experiment summary.
training_history = trainer.logger.get_experiment_summary()
metric_summaries = training_history['metrics_summary']

fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 6))

loss_ax.plot(metric_summaries['loss']['history'])
loss_ax.set_title('Training Loss')
loss_ax.set_xlabel('Steps')
loss_ax.set_ylabel('Loss')

# The right-hand panel stays empty when no accuracy metric was logged.
if 'accuracy' in metric_summaries:
    acc_ax.plot(metric_summaries['accuracy']['history'])
    acc_ax.set_title('Training Accuracy')
    acc_ax.set_xlabel('Steps')
    acc_ax.set_ylabel('Accuracy')

fig.tight_layout()
plt.show()

## Advanced Analysis Queries

In [None]:
# Average attention entropy and sparsity per layer for this experiment,
# aggregated over every logged step and head row.
attention_df = logger.conn.execute("""
    SELECT 
        layer,
        AVG(attention_entropy) as avg_entropy,
        AVG(attention_sparsity) as avg_sparsity
    FROM attention_stats
    WHERE experiment_id = ?
    GROUP BY layer
    ORDER BY layer
""", [trainer.experiment_name]).fetchdf()

fig, (entropy_ax, sparsity_ax) = plt.subplots(1, 2, figsize=(12, 4))
for ax, column, title, ylabel in (
    (entropy_ax, 'avg_entropy', 'Average Attention Entropy by Layer', 'Entropy'),
    (sparsity_ax, 'avg_sparsity', 'Average Attention Sparsity by Layer', 'Sparsity'),
):
    ax.plot(attention_df['layer'], attention_df[column])
    ax.set_title(title)
    ax.set_xlabel('Layer')
    ax.set_ylabel(ylabel)

fig.tight_layout()
plt.show()

## Additional Visualizations

In [None]:
def plot_layer_activations():
    """Plot per-layer activation mean, std and zero rate for this experiment.

    Queries ``activation_stats`` through the global ``logger`` connection,
    averages each statistic per layer over all logged steps of
    ``trainer.experiment_name``, and draws three side-by-side line plots.
    """
    activation_df = logger.conn.execute("""
        SELECT 
            layer,
            AVG(activation_mean) as avg_mean,
            AVG(activation_std) as avg_std,
            AVG(zero_rate) as avg_zero_rate
        FROM activation_stats
        WHERE experiment_id = ?
        GROUP BY layer
        ORDER BY layer
    """, [trainer.experiment_name]).fetchdf()

    # (column, panel title, y-axis label), drawn left to right.
    panels = [
        ('avg_mean', 'Average Activation Mean', 'Mean'),
        ('avg_std', 'Average Activation Std', 'Standard Deviation'),
        ('avg_zero_rate', 'Average Zero Activation Rate', 'Zero Rate'),
    ]
    fig, axes = plt.subplots(1, len(panels), figsize=(15, 5))
    for ax, (column, title, ylabel) in zip(axes, panels):
        ax.plot(activation_df['layer'], activation_df[column])
        ax.set_title(title)
        ax.set_xlabel('Layer')
        ax.set_ylabel(ylabel)

    fig.tight_layout()
    plt.show()

# Plot the per-layer activation statistics stored in activation_stats. Run
# this after training so the table has been populated.
plot_layer_activations()

## Model Export

In [None]:
# Persist the trained checkpoint together with the exact config used, so the
# run can be reloaded and reproduced from `experiment2_model/`.
save_path = 'experiment2_model'
checkpoint_path = f"{save_path}/checkpoint"
config_path = f"{save_path}/config.json"

os.makedirs(save_path, exist_ok=True)
trainer.save_checkpoint(checkpoint_path)

with open(config_path, 'w') as config_file:
    json.dump(config, config_file, indent=2)

print(f"Model and config saved to {save_path}")