# HippoFormer Training on Colab

**Setup:** Runtime → Change runtime type → **T4 GPU** (or A100 with Pro+)

This notebook trains HippoFormer with Gemma-2B + LoRA on WikiText-2.

In [None]:
# Check GPU
!nvidia-smi --query-gpu=name,memory.total --format=csv

## 1. Clone the Repository and Install Dependencies

In [None]:
# Clone repository
!git clone https://github.com/YOUR_USERNAME/BrainLLM.git
%cd BrainLLM

# Install with all dependencies
!pip install -e ".[all]" -q

# Verify
import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Hugging Face Login (required: Gemma is a gated model)

In [None]:
from huggingface_hub import login

# Gemma is gated on the Hugging Face Hub. Before running this cell:
#   1. Create an access token:   https://huggingface.co/settings/tokens
#   2. Accept the Gemma license: https://huggingface.co/google/gemma-2b
# Called without a token argument, login() asks for the token interactively.
login()

## 3. Quick Test (5 minutes)

In [None]:
from hippoformer.config import HippoFormerConfig
from hippoformer.model import HippoFormer
from hippoformer.train import TrainingArgs, HippoFormerTrainer, create_dataloaders
from transformers import AutoTokenizer, set_seed
import torch

# Fail fast with an actionable message instead of a CUDA error during model loading/training.
if not torch.cuda.is_available():
    raise RuntimeError("No GPU detected: Runtime → Change runtime type → T4 GPU")

# Seed Python, NumPy and PyTorch so model initialization and data shuffling are reproducible.
SEED = 42
set_seed(SEED)

# Configuration: frozen Gemma-2B base model with LoRA adapters
config = HippoFormerConfig(
    base_model_name="google/gemma-2b",
    freeze_base=True,
    use_lora=True,
)

# Quick test settings (effective batch size = batch_size * gradient_accumulation_steps)
args = TrainingArgs(
    dataset_name="wikitext",
    dataset_config="wikitext-2-raw-v1",
    batch_size=4,
    gradient_accumulation_steps=4,
    num_epochs=1,
    learning_rate=1e-4,
    max_seq_length=512,
    output_dir="./outputs/quick_test",
    logging_steps=10,
    save_steps=100,
    device="cuda",
)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name)
# Only fall back to EOS for padding when the tokenizer defines no pad token of its own.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Creating model...")
model = HippoFormer(config)
print(f"Total parameters: {model.get_num_total_params():,}")
print(f"Trainable parameters: {model.get_num_trainable_params():,}")

In [None]:
# Create dataloaders with a limited number of samples for the quick test
from datasets import load_dataset
from torch.utils.data import DataLoader

print("Loading dataset...")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

# Sample caps for the quick test (raise them for a more meaningful run)
MAX_SAMPLES = 500       # non-empty training lines
MAX_EVAL_SAMPLES = 100  # non-empty validation lines


def tokenize_function(examples):
    """Tokenize a batch of raw lines, truncating/padding each to `args.max_seq_length` tokens."""
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=args.max_seq_length,
        padding="max_length",
    )


def take_nonempty(split, limit):
    """Return the first `limit` rows of `split` whose text is not blank.

    WikiText-2 (raw) stores one line per row and many rows are empty separators.
    Padded to max_length, they would become sequences made almost entirely of
    padding and waste a large share of the small quick-test budget.
    """
    nonempty = split.filter(lambda example: example["text"].strip() != "")
    return nonempty.select(range(min(limit, len(nonempty))))


train_data = take_nonempty(dataset["train"], MAX_SAMPLES)
val_data = take_nonempty(dataset["validation"], MAX_EVAL_SAMPLES)

train_tokenized = train_data.map(tokenize_function, batched=True, remove_columns=["text"])
val_tokenized = val_data.map(tokenize_function, batched=True, remove_columns=["text"])

train_tokenized.set_format(type="torch", columns=["input_ids", "attention_mask"])
val_tokenized.set_format(type="torch", columns=["input_ids", "attention_mask"])

train_dataloader = DataLoader(train_tokenized, batch_size=args.batch_size, shuffle=True)
eval_dataloader = DataLoader(val_tokenized, batch_size=args.batch_size)

print(f"Train batches: {len(train_dataloader)}")
print(f"Eval batches: {len(eval_dataloader)}")

In [None]:
# Train!
trainer = HippoFormerTrainer(model, args, tokenizer)
history = trainer.train(train_dataloader, eval_dataloader)

# Report the last recorded losses. Either list can be empty (e.g. if fewer steps ran
# than the logging/eval interval), so check before indexing with [-1].
if history['train_loss']:
    print(f"\nFinal train loss: {history['train_loss'][-1]:.4f}")
else:
    print("\nNo train loss was recorded.")
if history['eval_loss']:
    print(f"Final eval loss: {history['eval_loss'][-1]:.4f}")

## 4. Full Training (2-3 hours on T4)

In [None]:
# Full training (~2-3 hours on a T4). Set RUN_FULL_TRAINING = True to run it.
# A flag (instead of wrapping the code in a string literal) keeps "Run all" safe by
# default, keeps the code syntax-checked, and avoids echoing the code as cell output.
RUN_FULL_TRAINING = False

if RUN_FULL_TRAINING:
    import gc

    import torch
    from hippoformer.config import HippoFormerConfig
    from hippoformer.model import HippoFormer
    from hippoformer.train import TrainingArgs, HippoFormerTrainer, create_dataloaders
    from transformers import AutoTokenizer, set_seed

    # Drop the quick-test model/trainer (if they exist) and release cached GPU memory,
    # so the full run does not share the GPU with a second copy of the model.
    for stale_name in ("trainer", "model"):
        globals().pop(stale_name, None)
    gc.collect()
    torch.cuda.empty_cache()

    set_seed(42)  # reproducible model init and data order

    config = HippoFormerConfig(
        base_model_name="google/gemma-2b",
        freeze_base=True,
        use_lora=True,
    )

    args = TrainingArgs(
        dataset_name="wikitext",
        dataset_config="wikitext-2-raw-v1",
        batch_size=8,
        gradient_accumulation_steps=4,
        num_epochs=3,
        learning_rate=1e-4,
        max_seq_length=512,
        output_dir="./outputs/full_training",
        device="cuda",
    )

    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name)
    # Same rule as the quick test: fall back to EOS only when no pad token is defined.
    # Overwriting an existing pad token would make padding indistinguishable from EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = HippoFormer(config)
    train_dataloader, eval_dataloader = create_dataloaders(tokenizer, args)

    trainer = HippoFormerTrainer(model, args, tokenizer)
    history = trainer.train(train_dataloader, eval_dataloader)

## 5. Save to Google Drive

In [None]:
# Mount Google Drive so outputs survive the end of the Colab session
import shutil

from google.colab import drive

drive.mount('/content/drive')

# Copy ./outputs into Drive. dirs_exist_ok=True merges into an existing folder on re-runs;
# plain `cp -r` would instead nest a second copy at hippoformer_outputs/outputs.
# copytree raises on failure, so the confirmation only prints after a successful copy.
DRIVE_OUTPUT_DIR = "/content/drive/MyDrive/hippoformer_outputs"
shutil.copytree("./outputs", DRIVE_OUTPUT_DIR, dirs_exist_ok=True)
print(f"Outputs saved to Google Drive: {DRIVE_OUTPUT_DIR}")

## 6. Evaluation

In [None]:
# Evaluate perplexity of the trained model on WikiText-2
import torch
from evaluation.metrics import compute_perplexity
from evaluation.datasets import create_eval_dataloader

EVAL_BATCH_SIZE = 4
EVAL_MAX_SAMPLES = 500

# Load eval data
eval_loader = create_eval_dataloader(
    "wikitext-2",
    tokenizer=tokenizer,
    batch_size=EVAL_BATCH_SIZE,
    max_samples=EVAL_MAX_SAMPLES,
)

# Compute perplexity on the same device used for training. no_grad avoids building
# autograd graphs during evaluation, which is harmless if compute_perplexity already
# disables gradients internally.
model.eval()
with torch.no_grad():
    perplexity = compute_perplexity(model, eval_loader, device=args.device)
print(f"Perplexity: {perplexity:.2f}")

## 7. Ablation Study

In [None]:
# Ablation study (takes much longer than full training). Set RUN_ABLATION = True to run it.
# A flag (instead of a string literal) keeps "Run all" safe by default and keeps the
# code syntax-checked.
RUN_ABLATION = False

if RUN_ABLATION:
    import gc

    import torch
    from IPython.display import display
    from evaluation.ablation import AblationRunner
    from hippoformer.config import HippoFormerConfig

    # Drop earlier models/trainers (if any) and release cached GPU memory before the
    # runner builds its own models.
    for stale_name in ("trainer", "model"):
        globals().pop(stale_name, None)
    gc.collect()
    torch.cuda.empty_cache()

    base_config = HippoFormerConfig(
        base_model_name="google/gemma-2b",
        freeze_base=True,
        use_lora=True,
    )

    runner = AblationRunner(
        base_config=base_config,
        output_dir="./outputs/ablation",
        num_seeds=3,
    )

    # Run key ablations
    variants = ["baseline", "no_salience", "no_memory", "no_drift"]
    results = runner.run_ablation_suite(variants)

    # Inside an `if` the last expression is not auto-displayed, so show the table explicitly.
    comparison_table = runner.generate_comparison_table()
    if comparison_table is not None:
        display(comparison_table)