# VishwamAI Colab Training

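Train the custom VishwamAI Transformer architecture in Google Colab, using the project's advanced training features (curriculum learning, tree-of-thoughts and reward configuration, neural memory, and mixture-of-experts routing).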

In [None]:
# Setup Environment
!nvidia-smi

from google.colab import drive
drive.mount('/content/drive')

# Setup directories
import os
DRIVE_DIR = '/content/drive/MyDrive/VishwamAI'
CHECKPOINT_DIR = f'{DRIVE_DIR}/checkpoints'
!mkdir -p {CHECKPOINT_DIR}

In [None]:
# Install dependencies and clone repo
# NOTE(review): this cell is not idempotent. Re-running it after the `%cd` below
# appears to clone a second copy inside the first checkout and then `%cd` into
# that nested copy; run it once per fresh runtime.
!git clone https://github.com/VishwamAI/VishwamAI.git
%cd VishwamAI

# Third-party packages used by the cells below. Versions are unpinned, so the
# resolved versions can change between runs.
%pip install -q torch transformers datasets accelerate bitsandbytes wandb
# Editable install of the checkout we just changed into; presumably this is what
# provides the `vishwamai` package imported in the next cells.
%pip install -e .

# Secure Hugging Face authentication
from huggingface_hub import login
import os
import getpass

def get_huggingface_token():
    """Return the Hugging Face token, asking for it when the environment has none.

    The ``HUGGINGFACE_TOKEN`` environment variable is used when it is set to a
    non-empty value. Otherwise the user is prompted (input hidden) and the
    entered value is written back to ``os.environ['HUGGINGFACE_TOKEN']`` so that
    later calls in the same process do not prompt again.

    Returns:
        The token string (whatever the environment or the prompt supplied).
    """
    env_token = os.getenv('HUGGINGFACE_TOKEN')
    if env_token:
        return env_token

    print("HUGGINGFACE_TOKEN not found in environment")
    entered_token = getpass.getpass('Enter your Hugging Face token (input will be hidden): ')
    # Keep it for the rest of this session only (process environment, not disk).
    os.environ['HUGGINGFACE_TOKEN'] = entered_token
    return entered_token

# Log in once. A failure is only reported, so the notebook keeps running and any
# later step that needs authentication will fail on its own.
try:
    token = get_huggingface_token()
    login(token=token)
except Exception as e:
    print(f"Error logging in to Hugging Face: {str(e)}")
else:
    print("Successfully logged in to Hugging Face")

In [1]:
# Import required modules
import torch
from transformers import AutoTokenizer
from datasets import load_dataset

# Import VishwamAI components with correct paths
from vishwamai.base_layers import Linear  # Import from base_layers instead of utils
from vishwamai.Transformer import Transformer
from vishwamai import (
    create_model,
    ModelArgs,
    VishwamAITokenizer,
    TokenizerConfig
)

# Import training components
from vishwamai.advanced_training import AdvancedTrainer
from vishwamai.fp8_cast_bf16 import main
from vishwamai.neural_memory import NeuralMemory
from vishwamai.tree_of_thoughts import TreeConfig, RewardConfig
from vishwamai.curriculum import CurriculumConfig

# Import visualization tools
import matplotlib.pyplot as plt
import pandas as pd

# Verify imports were successful
print("Imports completed successfully")

In [None]:
# Initialize visualization and analysis tools
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from datetime import datetime

# Configure plotting style
# The bare 'seaborn' style sheet was renamed to 'seaborn-v0_8' in Matplotlib 3.6
# and the old name was later removed, so `plt.style.use('seaborn')` raises on
# current installs. Use whichever name this Matplotlib provides, else the default.
_seaborn_style = next(
    (name for name in ('seaborn-v0_8', 'seaborn') if name in plt.style.available),
    'default',
)
plt.style.use(_seaborn_style)
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = [12, 6]
plt.rcParams['figure.dpi'] = 100
plt.rcParams['axes.grid'] = True

# Initialize performance tracking
# Column layout of the tracking frame; update_performance_tracking() appends one
# row per training step using exactly these keys.
performance_history = {
    'steps': [],
    'loss': [],
    'learning_rate': [],
    'memory_usage': [],
    'curriculum_level': [],
    'expert_usage': [],
    'evaluation_scores': []
}

# Empty frame with the columns above; it is mutated in place during training.
performance_df = pd.DataFrame(performance_history)

In [None]:
def update_performance_tracking(stats, step):
    """Append one row of training statistics to the global ``performance_df``.

    Args:
        stats: Mapping produced by a training step. Required keys: ``loss``,
            ``lr``, ``memory_usage['allocated']`` and
            ``curriculum_stats['current_difficulty']``. Optional keys:
            ``moe_metrics`` (mapping whose values are averaged) and
            ``eval_score`` (defaults to 0).
        step: Training step the statistics belong to.

    Raises:
        KeyError: If one of the required keys is missing from ``stats``.
    """
    # `moe_metrics` is optional and may be empty. Averaging an empty mapping
    # divided by zero, so report 0 usage when there is nothing to average.
    moe_metrics = stats.get('moe_metrics') or {}
    if moe_metrics:
        expert_usage = sum(moe_metrics.values()) / len(moe_metrics)
    else:
        expert_usage = 0.0

    # Setting a new integer label with .loc appends the row in place.
    performance_df.loc[len(performance_df)] = {
        'steps': step,
        'loss': stats['loss'],
        'learning_rate': stats['lr'],
        'memory_usage': stats['memory_usage']['allocated'],
        'curriculum_level': stats['curriculum_stats']['current_difficulty'],
        'expert_usage': expert_usage,
        'evaluation_scores': stats.get('eval_score', 0)
    }

def plot_training_progress():
    """Draw a 2x2 overview of the tracked metrics, save it, and display it.

    Reads the global ``performance_df`` and plots loss, learning rate, curriculum
    difficulty and expert usage against the ``steps`` column. The figure is saved
    to ``{DRIVE_DIR}/training_progress.png`` before being shown.
    """
    # (column, panel title, y-axis label), listed in row-major panel order.
    panels = [
        ('loss', 'Training Loss', 'Loss'),
        ('learning_rate', 'Learning Rate', 'Learning Rate'),
        ('curriculum_level', 'Curriculum Difficulty', 'Difficulty Level'),
        ('expert_usage', 'Expert Usage', 'Average Usage'),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('Training Progress Overview', fontsize=16)

    # axes.flat walks the grid as [0,0], [0,1], [1,0], [1,1].
    for ax, (column, title, y_label) in zip(axes.flat, panels):
        ax.plot(performance_df['steps'], performance_df[column])
        ax.set_title(title)
        ax.set_xlabel('Steps')
        ax.set_ylabel(y_label)

    plt.tight_layout()
    plt.savefig(f"{DRIVE_DIR}/training_progress.png")
    plt.show()

print("Visualization and performance tracking initialized")

In [None]:
# Create neural memory and model arguments
# Hyper-parameters for the model. The meaning of each field is defined by
# `vishwamai.ModelArgs`, which is not visible in this notebook.
# NOTE(review): in the cells shown here `model_args` is only passed to
# NeuralMemory below; the model-initialisation cell calls `create_model(config)`
# instead. Confirm which object is meant to configure the model.
# NOTE(review): vocab_size here (32000) differs from the vocab_size given to
# TokenizerConfig in the tokenizer-training cell (26519). Confirm this is intended.
model_args = ModelArgs(
    max_batch_size=4,
    max_seq_len=2048,
    dtype="fp8",
    vocab_size=32000,
    dim=1024,
    inter_dim=2816,
    moe_inter_dim=512,
    n_layers=12,
    n_dense_layers=1,
    n_heads=16,
    n_routed_experts=8,
    n_shared_experts=1,
    n_activated_experts=2,
    n_expert_groups=1,
    n_limited_groups=1,
    score_func="softmax",
    route_scale=1.0,
    q_lora_rank=0,
    kv_lora_rank=64,
    qk_nope_head_dim=64,
    qk_rope_head_dim=32,
    v_head_dim=64,
    original_seq_len=2048,
    rope_theta=10000.0,
    rope_factor=20,
    beta_fast=16,
    beta_slow=1,
    mscale=0.5,
    use_alibi=False,
    use_rope_scaling=True,
    gradient_checkpointing=True,
    parallel_attn=True,
    rope_condense_ratio=1.0
)

# Build the neural-memory component from the same arguments.
# NOTE(review): `neural_memory` is not referenced again in the cells visible here.
neural_memory = NeuralMemory(model_args)

In [None]:
# Initialize training components
# The three config objects below are passed to AdvancedTrainer in the
# model-initialisation cell. Field semantics are defined by the vishwamai
# modules they come from and are not visible in this notebook.

# Tree-of-thoughts settings (from vishwamai.tree_of_thoughts).
tot_config = TreeConfig(
    max_branches=4,
    max_depth=3,
    beam_width=2,
    reward_gamma=0.95
)

# Reward weighting; the three weights given here sum to 1.0.
reward_config = RewardConfig(
    reasoning_weight=0.4,
    accuracy_weight=0.4,
    consistency_weight=0.2
)

# Curriculum settings (from vishwamai.curriculum).
# NOTE(review): max_sequence_length (512) is smaller than the 2048-token length
# that process_dataset pads every example to. Confirm the two are meant to differ.
curriculum_config = CurriculumConfig(
    min_sequence_length=32,
    max_sequence_length=512,
    min_vocab_complexity=0.3,
    max_vocab_complexity=1.0,
    min_reasoning_steps=1,
    max_reasoning_steps=8,
    pacing_function='root',
    total_curriculum_steps=10000,
    performance_threshold=0.8,
    min_samples_before_advance=100,
    smoothing_factor=0.95
)

In [None]:
# Train tokenizer on dataset texts
print("Training tokenizer...")

import logging
import sentencepiece as spm
from pathlib import Path

# Setup logging
# INFO and above, each record formatted as "<time> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Ensure tokenizer directory exists
# Relative to the current working directory; an existing directory is left as is.
tokenizer_dir = Path("tokenizer")
tokenizer_dir.mkdir(exist_ok=True)

# Prepare training data
def prepare_training_data(datasets, max_samples=10000):
    """Write raw text from the given datasets into one tokenizer-training file.

    Args:
        datasets: Mapping of dataset name to an iterable of dict-like records.
            ``"gsm8k"`` and ``"mmlu"`` records contribute their ``question`` and
            ``answer`` fields, ``"code"`` records their ``content`` field. Any
            other dataset name is logged but contributes nothing.
        max_samples: Only the first ``max_samples`` records of each dataset are
            looked at. Records missing a required field are skipped but still
            count towards this limit.

    Returns:
        Path of the written file, ``tokenizer_dir / "train.txt"``. Each copied
        field is followed by a newline.
    """
    # Fields copied from each record, keyed by dataset name.
    fields_by_dataset = {
        "gsm8k": ("question", "answer"),
        "mmlu": ("question", "answer"),
        "code": ("content",),
    }

    logger.info("Preparing training data...")
    train_path = tokenizer_dir / "train.txt"

    with open(train_path, "w", encoding="utf-8") as out_file:
        for name, dataset in datasets.items():
            logger.info(f"Processing {name} dataset")
            wanted_fields = fields_by_dataset.get(name)
            if wanted_fields is None:
                continue
            for index, record in enumerate(dataset):
                if index >= max_samples:
                    break
                # A record is written only when every wanted field is present.
                if all(field in record for field in wanted_fields):
                    for field in wanted_fields:
                        out_file.write(f"{record[field]}\n")

    return train_path

try:
    # `datasets` is assigned by the "Load and process datasets" cell, which comes
    # after this one in the notebook. Fail with an actionable message instead of
    # a bare NameError when this cell runs first.
    # TODO: move the dataset loading above this cell so Restart & Run All works.
    if not isinstance(globals().get("datasets"), dict):
        raise RuntimeError(
            "`datasets` is not defined yet: build the dataset dict (see the "
            "'Load and process datasets' cell) before training the tokenizer."
        )

    # Initialize tokenizer with reduced vocabulary size
    tokenizer_config = TokenizerConfig(
        vocab_size=26519,  # Reduced vocab size
        model_prefix="vishwamai",
        character_coverage=0.9995,
        max_sentence_length=2048,
        pad_id=0,
        bos_id=1,
        eos_id=2,
        unk_id=3
    )

    # Create training data file
    train_path = prepare_training_data(datasets)
    logger.info(f"Training data saved to {train_path}")

    # Train SentencePiece model
    # SentencePiece writes "<model_prefix>.model" (loaded below) next to the
    # training file; the special-token ids are taken from tokenizer_config.
    model_prefix = str(tokenizer_dir / tokenizer_config.model_prefix)
    spm.SentencePieceTrainer.train(
        input=str(train_path),
        model_prefix=model_prefix,
        vocab_size=tokenizer_config.vocab_size,
        character_coverage=tokenizer_config.character_coverage,
        model_type="bpe",
        max_sentence_length=tokenizer_config.max_sentence_length,
        pad_id=tokenizer_config.pad_id,
        bos_id=tokenizer_config.bos_id,
        eos_id=tokenizer_config.eos_id,
        unk_id=tokenizer_config.unk_id,
        input_sentence_size=10000000,
        shuffle_input_sentence=True,
        train_extremely_large_corpus=True
    )

    # Load the trained tokenizer
    tokenizer = VishwamAITokenizer(tokenizer_config)
    tokenizer.load(f"{model_prefix}.model")
    logger.info("Tokenizer training complete")

    # Verify tokenizer works
    # Round-trip smoke test; the result is only logged, not asserted.
    test_text = "Hello world"
    encoded = tokenizer.encode(test_text)
    decoded = tokenizer.decode(encoded)
    logger.info(f"Tokenizer test - Encoded: {encoded}")
    logger.info(f"Tokenizer test - Decoded: {decoded}")

except Exception as e:
    # Log for the record, then re-raise so the notebook stops here: later cells
    # depend on `tokenizer`.
    logger.error(f"Error during tokenizer training: {str(e)}")
    raise

In [None]:
# Load and process datasets with trained tokenizer
print("Loading datasets...")
# gsm8k and cais/mmlu are multi-config datasets, so load_dataset needs an explicit
# config name. MMLU has no "train" split; its training data is "auxiliary_train".
# NOTE(review): MMLU's `answer` column appears to be an index into `choices`, not
# the answer text. Verify that the "Question/Answer" prompt built downstream is
# what you want.
datasets = {
    "gsm8k": load_dataset("gsm8k", "main", split="train"),
    "mmlu": load_dataset("cais/mmlu", "all", split="auxiliary_train"),
    # NOTE(review): this dataset is very large, and as far as I can tell its text
    # column is named `code`, while this notebook reads `content`. Confirm the
    # column name and consider a smaller subset before running.
    "code": load_dataset("codeparrot/github-code", split="train")
}

def process_dataset(examples, dataset_type):
    """Tokenise a batch of examples and pad every sequence to 2048 ids.

    Args:
        examples: Batch in column form (mapping of column name to list). For
            ``"gsm8k"`` / ``"mmlu"`` the ``question`` and ``answer`` columns are
            joined into one prompt per row; for any other type the ``content``
            column is used as is.
        dataset_type: Name of the dataset the batch comes from.

    Returns:
        Dict with ``input_ids`` and ``attention_mask``, one entry per example.
        Mask entries are 1 for encoded ids and 0 for padding.
    """
    max_len = 2048

    def _pad_to_max_len(token_ids):
        # Right-pad short sequences with the pad id; clip anything longer.
        missing = max_len - len(token_ids)
        if missing > 0:
            padded = token_ids + [tokenizer.config.pad_id] * missing
            return padded, [1] * (max_len - missing) + [0] * missing
        return token_ids[:max_len], [1] * max_len

    if dataset_type in ["gsm8k", "mmlu"]:
        texts = [
            f"Question: {question}\nAnswer: {answer}"
            for question, answer in zip(examples["question"], examples["answer"])
        ]
    else:
        texts = examples["content"]

    # Encode with trained tokenizer, then pad each sequence.
    padded_rows = [
        _pad_to_max_len(
            tokenizer.encode(
                sample,
                add_special_tokens=True,
                max_length=max_len,
                truncation=True
            )
        )
        for sample in texts
    ]

    return {
        "input_ids": [row_ids for row_ids, _ in padded_rows],
        "attention_mask": [row_mask for _, row_mask in padded_rows]
    }

# Process datasets
# Tokenise every dataset in batches. The original columns are removed, so only
# the fields returned by process_dataset remain.
processed_datasets = {}
for name, dataset in datasets.items():
    print(f"Processing {name}...")
    encoded_dataset = dataset.map(
        process_dataset,
        batched=True,
        remove_columns=dataset.column_names,
        # Passed to process_dataset as a keyword argument on every batch.
        fn_kwargs={"dataset_type": name},
    )
    processed_datasets[name] = encoded_dataset
    print(f"Processed {len(processed_datasets[name])} examples from {name}")

In [None]:
# Initialize model and trainer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): `config` is not defined in any cell visible in this notebook
# (only `model_args` is). The training loop later reads `config["max_steps"]`, so
# a dict-like object is expected. Define it before running this cell.
model, _ = create_model(config)
model = model.to(device)
# NOTE(review): `main` is imported from vishwamai.fp8_cast_bf16. Confirm that it
# accepts a model instance; its signature is not visible here.
main(model)

trainer = AdvancedTrainer(
    model=model,
    config=config,
    device=device,
    memory_size=512,
    cache_size=256,
    tot_config=tot_config,
    reward_config=reward_config,
    curriculum_config=curriculum_config
)

# Initialize wandb
import wandb
wandb.init(
    project="vishwamai-training",
    config={
        "model": config,
        "curriculum": curriculum_config.__dict__,
        "tot": tot_config.__dict__
    }
)

print(f"Model initialized on {device}")
# torch.cuda.memory_allocated() only accepts CUDA devices, so the CPU fallback
# chosen above must not reach it.
if device.type == "cuda":
    print(f"Memory usage: {torch.cuda.memory_allocated(device)/1e9:.2f} GB")
else:
    print("Memory usage: n/a (CUDA not available, running on CPU)")

In [None]:
# Training Loop with Performance Tracking
from tqdm.notebook import tqdm
import wandb

# The model-initialisation cell already starts a run that carries the full
# config. A second unconditional init may finish or replace that run, so only
# start one here when none is active.
if wandb.run is None:
    wandb.init(project="vishwamai-training")

# Per-step records read by the "Generate Performance Graphs" cell. This list was
# declared but never filled, which left that cell with an empty frame.
performance_data = []

# NOTE(review): `config` must be defined before this cell; it is not assigned in
# any cell visible in this notebook.
try:
    for step in tqdm(range(config["max_steps"])):
        stats = trainer.train_step()
        update_performance_tracking(stats, step)
        performance_data.append({
            "step": step,
            "loss": stats["loss"],
            "memory_usage": stats["memory_usage"]["allocated"],
        })

        wandb.log({
            "loss": stats["loss"],
            "learning_rate": stats["lr"],
            "batch_size": stats["batch_size"],
            "curriculum_level": stats["curriculum_stats"]["current_difficulty"],
            "memory_usage": stats["memory_usage"]["allocated"],
            "moe_loss": stats.get("moe_loss", 0),
            "gradient_norm": stats["gradient_norm"],
            "expert_usage": stats.get("moe_metrics", {})
        })

        # Every 1000 steps (including step 0): plot, checkpoint to Drive, push to Hub.
        if step % 1000 == 0:
            plot_training_progress()
            checkpoint_path = f"{CHECKPOINT_DIR}/step_{step}.pt"
            trainer.save_checkpoint(checkpoint_path)

            trainer.push_to_hub(
                "VishwamAI/VishwamAI",
                commit_message=f"Training checkpoint at step {step}"
            )

        # Every 5000 steps (including step 0): run the trainer's evaluation.
        if step % 5000 == 0:
            print(f"\nEvaluating at step {step}...")
            eval_metrics = trainer.evaluate()
            wandb.log({"eval": eval_metrics})

except KeyboardInterrupt:
    # A manual interrupt keeps the progress made so far; any other exception
    # propagates without saving.
    print("\nTraining interrupted. Saving final visualization...")
    plot_training_progress()
    trainer.save_checkpoint(f"{CHECKPOINT_DIR}/interrupted.pt")

plot_training_progress()
performance_df.to_csv(f"{DRIVE_DIR}/training_metrics.csv", index=False)
print("Training complete with performance tracking")

In [None]:
# Generate Performance Graphs
# Plot straight from `performance_df`, the frame that update_performance_tracking
# fills during training. Rebuilding it from `performance_data` replaced the
# tracked rows, and the step column is named 'steps', not 'step'.
# NOTE(review): the axis label says GB, but nothing in this notebook establishes
# the unit of stats["memory_usage"]["allocated"]. Confirm it.
_graph_specs = [
    # (column, legend/axis label, title, output file name)
    ("loss", "Loss", "Training Loss Over Time", "training_loss.png"),
    ("memory_usage", "Memory Usage (GB)", "Memory Usage Over Time", "memory_usage.png"),
]

for _column, _label, _title, _filename in _graph_specs:
    plt.figure(figsize=(12, 6))
    plt.plot(performance_df["steps"], performance_df[_column], label=_label)
    plt.xlabel("Step")
    plt.ylabel(_label)
    plt.title(_title)
    plt.legend()
    plt.grid(True)
    plt.savefig(f"{DRIVE_DIR}/{_filename}")
    plt.show()

In [None]:
# Final Evaluation
print("Running final evaluation...")

# Hub dataset ids evaluated on their "test" split.
# NOTE(review): several of these may need a config name or lack a "test" split;
# such failures are caught and printed below, and the dataset is left out of the
# results.
eval_datasets = [
    "gsm8k",
    "TIGER-Lab/MMLU-Pro",
    "MMMU/MMMU",
    "microsoft/SCBench",
    "camel-ai/math",
    "camel-ai/code"
]

results = {}
for dataset in eval_datasets:
    print(f"\nEvaluating on {dataset}...")
    # Best effort: one failing dataset must not abort the remaining evaluations.
    try:
        eval_data = load_dataset(dataset, split="test")
        metrics = trainer.evaluate(eval_data)
        results[dataset] = metrics
        print(f"{dataset}: {metrics}")
    except Exception as e:
        print(f"Error evaluating {dataset}: {str(e)}")

# Save results
import json
with open(f"{DRIVE_DIR}/final_evaluation.json", "w") as f:
    # default=str stringifies any metric value json cannot encode natively
    # (e.g. tensors), so the results collected above are not lost to a TypeError.
    json.dump(results, f, indent=2, default=str)

print("\nTraining complete!")