# VishwamAI Pre-training Experiment

This notebook implements pre-training optimized for TPU v2, using:
- Model: Custom Phi-1.6 TPU implementation (microsoft/phi-4)
- Dataset: GSM8K
- Hardware: TPU v2
- Training Types: Normal and Distillation training

In [None]:
import os
import json
import jax
import jax.numpy as jnp
import numpy as np
from datasets import load_dataset
import jax.tools.colab_tpu
from vishwamai.transformer import (
    create_vishwamai_transformer,
    create_train_state,
    train_step,
    evaluate_step,
    setup_distributed_training,
    data_loader
)

# Check TPU configuration.
# Register the Colab TPU driver before the first device query: JAX picks its
# backend the first time devices are requested, so querying first would
# report the default (CPU) backend instead of the TPU cores.
jax.tools.colab_tpu.setup_tpu()
print("JAX devices:", jax.devices())
print("Number of devices:", jax.device_count())

## TPU Setup

In [None]:
# Initialize TPU system
jax.tools.colab_tpu.setup_tpu()

# NOTE(review): `config` must already be loaded when this cell runs. The cell
# that reads vishwamai/configs/config_16b.json appears further down in the
# notebook; running it after this cell replaces `config` and discards the
# overrides applied here.

# Update model configuration
config['model_config'].update({
    'vocab_size': 32000,
    'num_layers': 12,
    'num_heads': 12,
    'head_dim': 64,
    'hidden_dim': 768,
    'mlp_dim': 3072,
    'max_seq_len': 2048,
    'use_flash_attn': True,
    'use_rotary': True,
    'use_rms_norm': True,
    'dtype': 'bfloat16',
    'compute_dtype': 'float32'
})

# Update training configuration.
# The section is addressed as 'training_config' because that is the key every
# later cell reads (batch size, learning rate, number of epochs).
config['training_config'].update({
    # Kept unscaled: the data-loader cell multiplies this value by
    # config['tpu_config']['num_devices'], so scaling by the device count
    # here as well would apply it twice.
    'batch_size': 32,
    'learning_rate': 1e-4,
    'warmup_steps': 2000,
    'decay_steps': 50000,
    'weight_decay': 0.01,
    'gradient_checkpointing': True,
    'gradient_accumulation_steps': 4,
    'mixed_precision': True,
    'tpu_iterations_per_loop': 100
})

print("Model configuration:", config['model_config'])
print("\nTraining configuration:", config['training_config'])

## Load Configuration

In [None]:
# Read the base configuration from disk.
with open('vishwamai/configs/config_16b.json', 'r') as config_file:
    config = json.loads(config_file.read())

# Values layered on top of the file's TPU section.
tpu_overrides = {
    "num_devices": 8,
    "batch_partition_size": 8,
    "model_partition_size": 4,
    "block_size": 128,
    "use_f8_training": True,
    "rematerialize": True,
}

# Values layered on top of the file's model section.
model_overrides = {
    "use_enhanced": True,
    "use_rotary": True,
    "use_flash_attn": True,
    "use_rms_norm": True,
    "dtype": "bfloat16",
}

# Apply the overrides in place, TPU section first, then the model section.
for section_name, overrides in (
    ('tpu_config', tpu_overrides),
    ('model_config', model_overrides),
):
    config[section_name].update(overrides)

## Load Dataset

In [None]:
# Load the GSM8K training split. The Hub dataset defines two configurations
# ("main" and "socratic") and no default, so the configuration name has to be
# given explicitly; without it load_dataset raises "Config name is missing".
dataset = load_dataset("openai/gsm8k", "main", split="train")

# Create tokenizer
# NOTE(review): create_tokenizer is not among the names imported at the top of
# the notebook — confirm where it should be imported from.
tokenizer = create_tokenizer(config['model_config'])

def preprocess_function(examples):
    """Render a batch of question/answer rows as prompts and tokenize them.

    Reads the ``"question"`` and ``"answer"`` columns of ``examples``, formats
    each pair as a single prompt string, and encodes the batch with the
    module-level ``tokenizer`` (padding and truncation enabled, length capped
    at ``config['model_config']['max_seq_len']``).

    Returns:
        A dict with ``'input_ids'`` and ``'attention_mask'`` as NumPy arrays.
    """
    prompts = []
    for q, a in zip(examples["question"], examples["answer"]):
        prompts.append(f"Question: {q}\nAnswer: {a}")

    encoded = tokenizer.batch_encode(
        prompts,
        max_length=config['model_config']['max_seq_len'],
        padding=True,
        truncation=True,
    )

    # Convert only the two fields the training code consumes.
    return {
        field: np.array(encoded[field])
        for field in ('input_ids', 'attention_mask')
    }

# Tokenize the dataset in batches; the original columns are removed.
tokenized_dataset = dataset.map(
    preprocess_function,
    remove_columns=dataset.column_names,
    batched=True,
)

# Split by position: the first 90% of rows train, the remainder evaluates.
num_examples = len(tokenized_dataset)
train_size = int(0.9 * num_examples)
train_dataset = tokenized_dataset.select(range(train_size))
eval_dataset = tokenized_dataset.select(range(train_size, num_examples))

# Loader batch size = configured batch size x configured device count.
configured_batch = config['training_config']['batch_size']
batch_size = configured_batch * config['tpu_config']['num_devices']

# Only the training loader shuffles.
train_loader = data_loader(train_dataset, batch_size=batch_size, shuffle=True)
eval_loader = data_loader(eval_dataset, batch_size=batch_size, shuffle=False)

## Normal Training Setup

In [None]:
# Build the model and the training state for the standard run.
print("Initializing model...")
rng = jax.random.PRNGKey(config['seed'])
model = create_vishwamai_transformer(config['model_config'])


def _constant_learning_rate(step):
    """Return the configured learning rate, ignoring ``step``."""
    return config['training_config']['learning_rate']


# The schedule is constant: every step gets the configured learning rate.
state = create_train_state(
    rng=rng,
    config=config,
    learning_rate_schedule=_constant_learning_rate,
)

## Distillation Training Setup

In [None]:
# Set up the training state used for distillation.
print("Initializing distillation training...")

# Seed offset by one relative to config['seed'].
distill_seed = config['seed'] + 1
distill_rng = jax.random.PRNGKey(distill_seed)

# NOTE(review): create_distillation_train_state is not among the names
# imported at the top of the notebook — confirm it is in scope.
distill_state = create_distillation_train_state(
    learning_rate_schedule=lambda step: config['training_config']['learning_rate'],
    config=config,
    rng=distill_rng,
)

## Training Loops

In [None]:
def train_epoch(state, is_distill=False, rng=None):
    """Train for one epoch over the module-level ``train_loader``.

    Args:
        state: Training state; each step returns the state used by the next.
        is_distill: When True, each batch goes through
            ``distillation_train_step`` with ``temperature`` and ``alpha``
            read from ``config['distillation_config']``; otherwise through
            ``train_step``.
        rng: Optional base PRNG key for dropout. When omitted, a key is
            seeded from the current time once at the start of the epoch.

    Returns:
        ``(state, metrics_mean)`` where ``metrics_mean`` maps each metric name
        to its mean over all batches, or ``{}`` if the loader yielded no
        batches.
    """
    if rng is None:
        # Local import: `time` is not imported at the top of the notebook.
        import time
        rng = jax.random.PRNGKey(int(time.time()))

    train_metrics = []

    for batch in train_loader:
        # Derive a fresh key per step. Re-seeding from the clock on every
        # step would give the same dropout key to all steps that start
        # within the same second.
        rng, dropout_rng = jax.random.split(rng)

        if is_distill:
            # NOTE(review): distillation_train_step is not among the names
            # imported at the top of the notebook — confirm it is in scope.
            state, metrics = distillation_train_step(
                state=state,
                batch=batch,
                dropout_rng=dropout_rng,
                temperature=config['distillation_config']['temperature'],
                alpha=config['distillation_config']['alpha']
            )
        else:
            state, metrics = train_step(
                state=state,
                batch=batch,
                dropout_rng=dropout_rng
            )

        train_metrics.append(metrics)

    # Nothing to average if the loader was empty.
    if not train_metrics:
        return state, {}

    # Fetch the metrics to host memory, then average each one across batches.
    metrics_np = jax.device_get(train_metrics)
    metrics_mean = {
        k: np.mean([metrics[k] for metrics in metrics_np])
        for k in metrics_np[0]
    }

    return state, metrics_mean

def evaluate(state):
    """Run evaluation over the module-level ``eval_loader``.

    Args:
        state: Training state passed unchanged to ``evaluate_step`` for
            every batch.

    Returns:
        A dict mapping each metric name to its mean over all evaluation
        batches, or ``{}`` if the loader yielded no batches.
    """
    eval_metrics = [evaluate_step(state, batch) for batch in eval_loader]

    # Nothing to average if the loader was empty.
    if not eval_metrics:
        return {}

    # Fetch the metrics to host memory, then average each one across batches.
    metrics_np = jax.device_get(eval_metrics)
    return {
        k: np.mean([metrics[k] for metrics in metrics_np])
        for k in metrics_np[0]
    }

## Run Training

In [None]:
# Standard training: each epoch is one pass over the training loader followed
# by one pass over the evaluation loader.
print("Starting normal training...")
training_settings = config['training_config']
num_epochs = training_settings['num_epochs']

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch + 1}/{num_epochs}")

    # Train; the returned state feeds both evaluation and the next epoch.
    state, train_metrics = train_epoch(state, is_distill=False)
    print(f"Training metrics: {train_metrics}")

    # Evaluate the state produced by this epoch.
    eval_metrics = evaluate(state)
    print(f"Evaluation metrics: {eval_metrics}")

In [None]:
# Distillation training
# NOTE(review): this cell announces distillation training but contains no
# training loop; `distill_state` from the distillation setup cell is not used
# anywhere in the visible notebook. The cell only creates an output directory
# and writes a checkpoint and the configuration.
print("\nStarting distillation training...")

# Output directory for the checkpoint and the configuration dump; an
# existing directory is reused.
save_dir = 'experiment2_final'
os.makedirs(save_dir, exist_ok=True)

# Save model
# NOTE(review): `trainer` is not defined in any visible cell — confirm which
# object (and which training state) is meant to be checkpointed here.
trainer.save_checkpoint(f"{save_dir}/model")

# Save configuration (the in-memory `config`, including any overrides
# applied by earlier cells) as indented JSON.
with open(f"{save_dir}/config.json", 'w') as f:
    json.dump(config, f, indent=2)

print(f"Model and configuration saved to {save_dir}")