# VishwamAI Training with T5 Integration on Colab T4

This notebook integrates the T5 architecture with VishwamAI and is optimized for training on Google Colab's T4 GPU.

In [None]:
# Install dependencies
!pip install -q torch==2.1.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q transformers==4.34.0 datasets accelerate huggingface_hub wandb bitsandbytes
!pip install -q sentencepiece protobuf

# Clone repository
!git clone https://github.com/VishwamAI/VishwamAI.git
%cd VishwamAI
!pip install -e . -q

In [None]:
import gc
import json
import time

import bitsandbytes as bnb
import torch
from datasets import concatenate_datasets, load_dataset
from transformers import (
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5Config,
    Trainer,
    TrainingArguments
)

from vishwamai.config import ModelArgs
from vishwamai.model import Transformer

# Verify GPU
!nvidia-smi
print(f"\nGPU: {torch.cuda.get_device_name(0)}")
print(f"Memory: {torch.cuda.get_device_properties(0).total_memory/1024/1024/1024:.2f}GB")

In [None]:
def _copy_if_same_shape(target, source, label):
    """Copy `source` into `target` in place when their shapes agree.

    Returns True when the copy happened. On a shape mismatch nothing is
    written, a message is printed and False is returned, so that weights are
    only transferred "where architectures match" instead of being silently
    broadcast or aborting the whole initialization.
    """
    if target.shape != source.shape:
        print(
            f"Skipping {label}: T5 shape {tuple(source.shape)} "
            f"!= VishwamAI shape {tuple(target.shape)}"
        )
        return False
    target.copy_(source)
    return True


def _replace_linear_with_int8(model):
    """Swap every torch.nn.Linear inside `model` for a bnb Linear8bitLt with the same weights."""
    # Snapshot the layers first: replacing children while named_modules() is
    # still running would mutate the module dicts under the generator.
    # Linear8bitLt is itself an nn.Linear, so exclude it to stay idempotent.
    linear_layers = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
        and not isinstance(module, bnb.nn.Linear8bitLt)
    ]
    for qualified_name, linear in linear_layers:
        if not qualified_name:
            # The root module itself is a Linear; there is no parent to patch.
            continue
        # named_modules() yields dotted paths ("layers.0.attn.wq"); the layer
        # has to be replaced on its direct parent, not on the root's _modules.
        parent_path, _, child_name = qualified_name.rpartition(".")
        parent = model.get_submodule(parent_path)
        quantized = bnb.nn.Linear8bitLt(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            has_fp16_weights=False,
            threshold=6.0
        )
        # Carry the (possibly T5-initialized) weights over; bitsandbytes
        # quantizes them to int8 when the module is moved to the GPU.
        quantized.load_state_dict(linear.state_dict())
        setattr(parent, child_name, quantized)


def prepare_model():
    """Initialize T5-based model with VishwamAI enhancements.

    Builds a VishwamAI ``Transformer`` sized like ``t5-base``, seeds its
    embedding, attention q/o and dense-FFN weights from the pretrained T5
    encoder, and swaps its ``torch.nn.Linear`` layers for 8-bit layers.

    Returns:
        tuple: ``(model, tokenizer)`` - the VishwamAI model and the
        ``t5-base`` tokenizer.
    """
    # Load T5 base model and tokenizer.
    # T5 is only a weight source here, so it is loaded unquantized: with
    # load_in_8bit=True the linear weights are stored as int8 and copying them
    # would transfer quantized integers instead of the real parameters.
    t5_config = T5Config.from_pretrained('t5-base')
    t5_model = T5ForConditionalGeneration.from_pretrained('t5-base')
    tokenizer = AutoTokenizer.from_pretrained('t5-base')

    # Initialize VishwamAI config with T5 dimensions
    # NOTE(review): the T4 has no native FP8 support - confirm that
    # dtype="fp8" is usable for ModelArgs on this GPU.
    model_config = ModelArgs(
        max_batch_size=4,
        max_seq_len=512,
        vocab_size=t5_config.vocab_size,
        dim=t5_config.d_model,  # Match T5 hidden size
        n_heads=t5_config.num_heads,
        n_layers=t5_config.num_layers,
        inter_dim=t5_config.d_ff,
        dtype="fp8",
        gradient_checkpointing=True
    )

    # Initialize VishwamAI model
    model = Transformer(model_config)

    # Copy T5 weights where architectures match.
    # Assumes the VishwamAI Transformer exposes tok_embeddings,
    # layers[i].attn.wq/wo and layers[i].ffn.w1/w2 - TODO confirm against
    # vishwamai.model.
    with torch.no_grad():
        _copy_if_same_shape(
            model.tok_embeddings.weight, t5_model.shared.weight, "tok_embeddings"
        )

        # Only the T5 encoder's query/output projections and feed-forward
        # weights are transferred; key/value projections and the decoder are
        # left at their VishwamAI initialization.
        for i in range(min(model_config.n_layers, t5_config.num_layers)):
            t5_attention = t5_model.encoder.block[i].layer[0].SelfAttention
            t5_ffn = t5_model.encoder.block[i].layer[1].DenseReluDense
            vish_layer = model.layers[i]

            _copy_if_same_shape(
                vish_layer.attn.wq.weight, t5_attention.q.weight, f"layers.{i}.attn.wq"
            )
            _copy_if_same_shape(
                vish_layer.attn.wo.weight, t5_attention.o.weight, f"layers.{i}.attn.wo"
            )

            # Layers past n_dense_layers are skipped for the FFN copy.
            if i < model_config.n_dense_layers:
                _copy_if_same_shape(
                    vish_layer.ffn.w1.weight, t5_ffn.wi.weight, f"layers.{i}.ffn.w1"
                )
                _copy_if_same_shape(
                    vish_layer.ffn.w2.weight, t5_ffn.wo.weight, f"layers.{i}.ffn.w2"
                )

    # The donor model is no longer needed; release it before training.
    del t5_model
    gc.collect()

    # Apply 8-bit quantization
    # NOTE(review): with has_fp16_weights=False the int8 weights of these
    # layers do not receive gradients - confirm this is intended for training.
    _replace_linear_with_int8(model)

    return model, tokenizer

model, tokenizer = prepare_model()

In [None]:
# Training configuration for a single T4
training_args = TrainingArguments(
    # Outputs: local checkpoint directory, Hub destination, experiment tracking
    output_dir="./vishwamai_t5_output",
    push_to_hub=True,
    hub_model_id="kasinadhsarma/vishwamai-model",
    report_to=["wandb"],
    logging_steps=10,

    # Optimization schedule
    num_train_epochs=3,
    learning_rate=5e-5,
    warmup_steps=500,

    # Memory budget: 4 sequences per step, accumulated over 8 steps,
    # with mixed precision and gradient checkpointing enabled
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    fp16=True,
    gradient_checkpointing=True,

    # Evaluate and checkpoint on the same 200-step cadence so the best
    # checkpoint can be restored when training finishes
    evaluation_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    load_best_model_at_end=True
)

In [None]:
def prepare_datasets():
    """Prepare datasets with T5-style formatting.

    Loads GSM8K and MMLU, rewrites every example into an
    ``input_text`` / ``target_text`` pair, merges the two training sets and
    tokenizes them with the module-level ``tokenizer``.

    Returns:
        tuple: ``(train_dataset, eval_dataset)`` - tokenized datasets holding
        the tokenizer outputs plus a ``labels`` column.
    """
    def format_gsm8k(example):
        """Turn a GSM8K row into a 'solve:' prompt and its worked answer."""
        return {
            "input_text": f"solve: {example['question']}",
            "target_text": example['answer']
        }

    def format_mmlu(example):
        """Turn an MMLU row into a multiple-choice prompt and its answer text."""
        return {
            "input_text": f"answer: {example['question']}\n\nOptions:\nA) {example['choices'][0]}\nB) {example['choices'][1]}\nC) {example['choices'][2]}\nD) {example['choices'][3]}",
            # `answer` is used as an index into `choices`
            "target_text": f"The answer is {example['choices'][example['answer']]}"
        }

    # Load datasets
    gsm8k_raw = load_dataset("gsm8k", "main", split="train")
    # cais/mmlu has no "train" split; its training data is "auxiliary_train".
    mmlu_train_raw = load_dataset("cais/mmlu", "all", split="auxiliary_train")
    mmlu_eval_raw = load_dataset("cais/mmlu", "all", split="validation")

    # Format datasets. The source columns are dropped so every dataset ends up
    # with the same two columns; GSM8K and MMLU otherwise have different
    # schemas and cannot be concatenated.
    gsm8k_formatted = gsm8k_raw.map(
        format_gsm8k, remove_columns=gsm8k_raw.column_names
    )
    mmlu_train_formatted = mmlu_train_raw.map(
        format_mmlu, remove_columns=mmlu_train_raw.column_names
    )
    eval_dataset = mmlu_eval_raw.map(
        format_mmlu, remove_columns=mmlu_eval_raw.column_names
    )

    # Combine training datasets
    train_dataset = concatenate_datasets([gsm8k_formatted, mmlu_train_formatted])

    # Tokenize
    def tokenize_function(examples):
        """Tokenize a batch: prompts to 512 tokens, targets to 128 tokens."""
        # No return_tensors here: the batched map stores plain lists, so
        # building torch tensors per batch would only add overhead.
        model_inputs = tokenizer(
            examples["input_text"],
            padding="max_length",
            truncation=True,
            max_length=512
        )
        labels = tokenizer(
            examples["target_text"],
            padding="max_length",
            truncation=True,
            max_length=128
        )
        # NOTE(review): padded label positions keep the pad token id; confirm
        # whether the model's loss expects them masked (e.g. with -100).
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    train_dataset = train_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=train_dataset.column_names
    )
    eval_dataset = eval_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=eval_dataset.column_names
    )

    return train_dataset, eval_dataset

train_dataset, eval_dataset = prepare_datasets()

In [None]:
# Initialize Hugging Face and W&B
from getpass import getpass

from huggingface_hub import login
import wandb

# getpass does not echo the token, so it never ends up in the saved cell
# output; strip() drops whitespace picked up when pasting.
hf_token = getpass("Enter your Hugging Face access token: ").strip()
login(token=hf_token)
wandb.login()

In [None]:
def train_model():
    """Train the model with T5 integration.

    Builds a ``Trainer`` from the module-level ``model``, ``training_args``,
    ``train_dataset``, ``eval_dataset`` and ``tokenizer``, runs training,
    saves the result to ``./final_model`` and pushes it to the Hub.

    Raises:
        Exception: any error raised during training, saving or pushing is
        re-raised after Python garbage and the CUDA cache have been released.
    """
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer
    )

    print("Starting training...")
    try:
        trainer.train()
        trainer.save_model("./final_model")
        trainer.push_to_hub(
            commit_message=f"VishwamAI-T5 training completed - {time.strftime('%Y-%m-%d %H:%M:%S')}"
        )
        print("Training completed successfully")
    except Exception as e:
        print(f"Training interrupted: {e}")
        # Free what we can so the cell can be retried in the same session
        gc.collect()
        torch.cuda.empty_cache()
        # Bare raise re-raises the active exception with its traceback intact
        raise

train_model()
# Read the repo id from the training arguments so the link cannot drift from
# the repository the model was actually pushed to.
print(f"Model available at: https://huggingface.co/{training_args.hub_model_id}")