# VishwamAI Advanced Pre-training on A100 GPUs

This notebook implements advanced pre-training for VishwamAI using:

- Mixed precision FP8/FP16 training
- Fully Sharded Data Parallel (FSDP)
- Gradient checkpointing
- Efficient memory management
- Multi-dataset curriculum learning

In [None]:
# Colab setup
# When running inside Google Colab, clone the repository, install dependencies
# and create the working directories. Outside Colab the `google.colab` import
# raises ImportError and the whole setup is skipped.
# NOTE: lines starting with `!` or `%` are IPython shell/magic commands, so this
# cell only runs inside a notebook kernel, not as plain Python.
try:
    import google.colab
    print("Setting up VishwamAI in Google Colab...")
    # Clone repository if not exists
    # NOTE: the directory test is relative to the current working directory, so
    # re-running this cell after the `%cd` below clones a nested copy.
    ![ ! -d "VishwamAI" ] && git clone https://github.com/VishwamAI/VishwamAI.git
    # Switch the notebook's working directory into the checkout; the relative
    # paths used below (requirements.txt, checkpoints, logs) resolve from here.
    %cd VishwamAI
    
    # Install core dependencies first
    !pip install --quiet torch accelerate datasets transformers
    !pip install --quiet matplotlib seaborn pandas
    
    # Install git-lfs and initialize
    !apt-get -qq update && apt-get -qq install git-lfs
    !git lfs install
    
    # Install VishwamAI
    # Editable install of the checkout, followed by its pinned requirements
    !pip install -e .
    !pip install -r requirements.txt
    
    # Create directories
    !mkdir -p checkpoints logs training_visualizations
    
    print("Setup complete!")
except ImportError:
    # `google.colab` is not importable: assume a non-Colab environment and
    # leave it untouched.
    print("Not running in Google Colab")

# Core imports
import os
import gc
import torch
import numpy as np
import pandas as pd
from datetime import datetime
from typing import Dict

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Dataset handling
from datasets import load_dataset, concatenate_datasets
from torch.utils.data import DataLoader

# Distributed training
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from accelerate import Accelerator

# Import VishwamAI classes directly to avoid circular imports
from vishwamai.model import VishwamAI
from vishwamai.utils.config import ModelConfig, TrainingConfig
from vishwamai.training.advanced_training import AdvancedTrainer
from vishwamai.data.dataset import create_combined_dataset
from vishwamai.utils.logging import PretrainingLogger
from vishwamai.utils.checkpoint import CheckpointManager
from vishwamai.utils.hub_utils import HuggingFaceUploader

In [None]:
# Make sure the output locations for logs and checkpoints exist
os.makedirs('logs', exist_ok=True)
os.makedirs('checkpoints', exist_ok=True)

# Architecture hyperparameters handed to the model constructor.
# NOTE(review): these dimensions describe a very large model; confirm it fits
# the target hardware before instantiating.
model_config = ModelConfig(
    # Core transformer dimensions
    vocab_size=64000,
    max_position_embeddings=32768,
    num_layers=120,
    num_heads=64,
    hidden_size=8192,
    intermediate_size=32768,
    # Optional components (their semantics are defined by ModelConfig)
    use_moe=True,
    num_experts=8,
    use_memory=True,
    memory_size=4096,
    enable_emergent=True,
    tree_search_depth=3,
)

# Optimisation hyperparameters handed to the trainer
training_config = TrainingConfig(
    # Optimizer schedule
    learning_rate=1e-4,
    warmup_steps=2000,
    weight_decay=0.1,
    max_grad_norm=1.0,
    # Memory / throughput switches
    gradient_accumulation_steps=32,
    gradient_checkpointing=True,
    fp8_training=True,
)

In [None]:
# Initialize logging and load config
# Builds the shared PretrainingLogger from a YAML path that is relative to the
# current working directory. Later cells use `logger` to record hardware
# stats, metrics, checkpoints and errors, and to close the run via finish().
logger = PretrainingLogger('configs/pretrain_config.yaml')

# Monitor GPU stats
def log_gpu_stats() -> Dict[str, float]:
    """Collect current CUDA memory/utilization figures and log them.

    The figures are forwarded to ``logger.log_hardware_stats`` and also
    returned to the caller.

    Returns:
        A dictionary with the keys ``gpu_memory_used`` and
        ``gpu_memory_cached`` (bytes divided by 1e9, i.e. GB) and
        ``gpu_utilization`` (the value of ``torch.cuda.utilization()``, or
        NaN when that query is unavailable).
    """
    try:
        utilization = torch.cuda.utilization()
    except (ImportError, RuntimeError):
        # torch.cuda.utilization() depends on the NVML bindings (pynvml) and
        # a loadable driver. Monitoring is optional telemetry, so a missing
        # dependency must not abort the training run; NaN keeps the value
        # numeric while marking it as unavailable.
        utilization = float('nan')

    stats = {
        'gpu_memory_used': torch.cuda.memory_allocated() / 1e9,  # GB
        'gpu_memory_cached': torch.cuda.memory_reserved() / 1e9,  # GB
        'gpu_utilization': utilization,
    }
    logger.log_hardware_stats(stats)
    return stats

In [None]:
# Accelerator configured for FP8 mixed precision, reusing the accumulation
# step count from training_config.
# NOTE(review): no later cell in this notebook references `accelerator`
# (nothing is passed through accelerator.prepare) -- confirm this is intended.
# NOTE(review): FP8 support depends on the GPU generation and the installed
# FP8 backend -- verify on the target hardware.
accelerator = Accelerator(
    gradient_accumulation_steps=training_config.gradient_accumulation_steps,
    mixed_precision='fp8',
)

# (dataset id, configuration name) pairs handed to create_combined_dataset.
# NOTE(review): ids/config names are taken as-is; verify each one resolves.
datasets = [
    ('openai/gsm8k', 'main'),
    ('cais/mmlu', 'all'),
    ('TIGER-Lab/MMLU-Pro', 'main'),
    ('deepmind/math_dataset', 'algebra'),
    ('wikimedia/wikipedia', '20231101.en'),
    ('HuggingFace/c4', 'en'),
    ('sentence-transformers/codesearchnet', 'all'),
]

# Merge all sources into one dataset and batch it for training
train_dataset = create_combined_dataset(datasets)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    pin_memory=True,
    num_workers=4,
)

In [None]:
# Initialize model
import functools

from torch.distributed.fsdp import MixedPrecision

model = VishwamAI(model_config)

# FSDP's `mixed_precision` argument takes a MixedPrecision config object (or
# None), not a bool. bfloat16 is chosen for parameters, gradient reduction
# and buffers because it needs no loss scaling.
# NOTE(review): confirm bfloat16 is the intended low-precision dtype here.
fsdp_precision = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# transformer_auto_wrap_policy cannot be passed bare: FSDP calls the policy
# with (module, recurse, nonwrapped_numel) only, so `transformer_layer_cls`
# has to be bound up front with functools.partial.
# NOTE(review): the layer classes are inferred as the non-torch module types
# stored in nn.ModuleList containers; replace this with the explicit
# transformer block class of VishwamAI once confirmed.
fsdp_layer_classes = {
    type(layer)
    for container in model.modules()
    if isinstance(container, torch.nn.ModuleList)
    for layer in container
    if not type(layer).__module__.startswith('torch.')
}

# Wrap model in FSDP
# NOTE(review): FSDP requires an initialized torch.distributed process group;
# verify the launcher sets one up before this cell runs.
model = FSDP(
    model,
    auto_wrap_policy=functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=fsdp_layer_classes,
    ),
    mixed_precision=fsdp_precision,
    device_id=torch.cuda.current_device()
)

In [None]:
# Trainer bound to the wrapped model and both configuration objects
trainer = AdvancedTrainer(
    model=model,
    training_config=training_config,
    config=model_config,
    use_tree_search=True,
)

# Checkpoint manager used by the training loop to write local checkpoints
checkpoint_manager = CheckpointManager(
    shard_size=1024 ** 3,  # 1 GiB expressed in bytes
    compression=True,
)

# Uploader for pushing checkpoints and metrics to the Hugging Face Hub; the
# access token is read from the HF_TOKEN environment variable (None if unset)
hub_uploader = HuggingFaceUploader(
    token=os.environ.get("HF_TOKEN"),
    repo_id="VishwamAI/VishwamAI",
    private=False,
)

In [None]:
# Training loop
# Runs `num_epochs` epochs. Every `save_every` epochs the model is saved
# locally, uploaded to the Hub together with its metrics, and local
# checkpoints older than the `keep_last` most recent ones are pruned.
import shutil

num_epochs = 10
save_every = 2  # checkpoint/upload interval, in epochs
keep_last = 2   # number of most recent local checkpoints to retain

try:
    for epoch in range(num_epochs):
        epoch_number = epoch + 1  # 1-based epoch used in names and messages
        print(f"\nEpoch {epoch_number}/{num_epochs}")

        # Log GPU stats before training
        log_gpu_stats()

        # Train epoch
        metrics = trainer.train_epoch(
            dataloader=train_loader,
            epoch=epoch,
            use_curriculum=True,
            checkpoint_dir='checkpoints'
        )

        # Log metrics
        logger.log_metrics(metrics, step=epoch)

        # Save checkpoint and upload to Hub
        if epoch_number % save_every == 0:
            checkpoint_path = f"checkpoints/model_epoch_{epoch_number}"

            # Save local checkpoint
            checkpoint_manager.save_checkpoint(
                model=model,
                optimizer=trainer.optimizer,
                filepath=checkpoint_path,
                extra_data={
                    'epoch': epoch,
                    'metrics': metrics
                },
                quantize=True
            )

            # Upload to HuggingFace Hub
            hub_uploader.upload_checkpoint(
                checkpoint_path=checkpoint_path,
                commit_message=f"Upload model checkpoint for epoch {epoch_number}",
                epoch=epoch_number,
                metrics=metrics
            )

            # Upload metrics separately
            hub_uploader.upload_metrics(metrics, epoch_number)

            # Log checkpoint
            logger.log_checkpoint(checkpoint_path, epoch)

            # Prune the checkpoint that just fell out of the retention window.
            # Checkpoints are written every `save_every` epochs, so the one
            # to drop is `keep_last` save intervals behind the current one
            # (the previous code removed the immediately preceding checkpoint,
            # leaving only one on disk).
            stale_epoch = epoch_number - keep_last * save_every
            if stale_epoch > 0:
                old_checkpoint = f"checkpoints/model_epoch_{stale_epoch}"
                # NOTE(review): assumes the checkpoint lives exactly at this
                # path; verify against CheckpointManager if it adds suffixes.
                if os.path.isdir(old_checkpoint):
                    # A checkpoint stored as a directory cannot be removed
                    # with os.remove
                    shutil.rmtree(old_checkpoint)
                elif os.path.exists(old_checkpoint):
                    os.remove(old_checkpoint)

except Exception as e:
    # Record the failure, then re-raise with the original traceback intact
    logger.log_error(e)
    raise
finally:
    # Always close the logger, whether training succeeded or failed
    logger.finish()