# Training a 2B Parameter CRSM Model on Google Colab

This notebook provides a proof-of-concept implementation for training a ~2B parameter language model using the CRSM architecture. It includes:
- Base model pretraining
- Instruction tuning
- Model validation and evaluation
- Colab-specific optimizations

Make sure to:
1. Use a GPU runtime (preferably A100)
2. Mount Google Drive for model checkpoints
3. Set batch sizes according to available memory

# 1. Environment and GPU Check

First, let's verify we have the correct runtime and resources available:

In [None]:
import torch
import psutil
import os
import subprocess
from pynvml import *

def check_gpu():
    """Verify a CUDA GPU is present, print hardware details, and pick batch settings.

    Prints the GPU name, total GPU memory, CUDA version and total system RAM.

    Returns:
        dict: ``{'batch_size': int, 'gradient_accum': int}`` chosen from the
        GPU name / memory size (larger GPUs get a larger per-step batch and
        fewer accumulation steps).

    Raises:
        RuntimeError: If ``torch.cuda.is_available()`` is false.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("No GPU available. Please select GPU runtime in Colab.")

    # Get GPU info. NVML is intentionally left initialised: the export cell at
    # the end of the notebook queries the device name again without calling
    # nvmlInit() itself.
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    gpu_model = nvmlDeviceGetName(handle)
    # Depending on the installed pynvml release the name comes back as bytes
    # or as str; calling .decode() unconditionally fails on the latter.
    if isinstance(gpu_model, bytes):
        gpu_model = gpu_model.decode('utf-8')
    gpu_memory_gb = info.total / (1024**3)

    print(f"GPU: {gpu_model}")
    print(f"GPU Memory: {gpu_memory_gb:.1f}GB")
    print(f"CUDA Version: {torch.version.cuda}")

    # Get system memory
    ram_gb = psutil.virtual_memory().total / (1024**3)
    print(f"System RAM: {ram_gb:.1f}GB")

    # Determine safe batch size based on GPU
    if 'A100' in gpu_model:
        return {'batch_size': 32, 'gradient_accum': 4}
    elif gpu_memory_gb > 24:  # e.g. a 32GB V100 or similar
        return {'batch_size': 16, 'gradient_accum': 8}
    else:  # T4 or similar
        return {'batch_size': 8, 'gradient_accum': 16}

training_params = check_gpu()

# 2. Install Required Packages

Let's install the necessary packages for training our CRSM model:

In [None]:
!pip install -q transformers datasets accelerate bitsandbytes evaluate huggingface-hub sentencepiece
!pip install -q pytorch-lightning wandb

# Clone the CRSM repository from GitHub and install it in editable mode
!git clone https://github.com/pomilon/CRSM.git
%cd CRSM
!pip install -e .
%cd ..

# 3. Mount Google Drive and Set Up Workspace

We'll mount Google Drive to store our model checkpoints and set up our workspace directories:

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive so the workspace below lives on Drive
drive.mount('/content/drive')

# Set up workspace directories
# Everything (checkpoints, data, logs) is kept under one Drive folder.
WORKSPACE_DIR = "/content/drive/MyDrive/crsm_2b"
CHECKPOINTS_DIR = f"{WORKSPACE_DIR}/checkpoints"
DATA_DIR = f"{WORKSPACE_DIR}/data"
LOGS_DIR = f"{WORKSPACE_DIR}/logs"

# exist_ok=True makes this cell safe to re-run after the folders already exist
for d in [WORKSPACE_DIR, CHECKPOINTS_DIR, DATA_DIR, LOGS_DIR]:
    os.makedirs(d, exist_ok=True)

print("Workspace directories created at:", WORKSPACE_DIR)

# 4. Download and Prepare Datasets

We'll use a mix of:
- SlimPajama (a cleaned, deduplicated derivative of RedPajama) for base model training (subset)
- OpenAssistant (oasst1) for instruction tuning

In [None]:
from datasets import load_dataset
import random

# Open the SlimPajama-627B train split in streaming mode for base training.
# Only the first examples are consumed later (see prepare_base_data's take()).
base_dataset = load_dataset("cerebras/SlimPajama-627B", 
                          streaming=True,
                          split="train")

# Filter and prepare base dataset
def prepare_base_data():
    """Collect a small text corpus from ``base_dataset`` and split it 90/10.

    Reads the first 10000 streamed examples, keeps the ``'text'`` field of
    those longer than 100 characters, shuffles them with ``random.shuffle``
    and splits them.

    Returns:
        dict: ``{'train': list[str], 'validation': list[str]}``.
    """
    texts = [
        record['text']
        for record in base_dataset.take(10000)  # Limit for PoC
        if len(record['text']) > 100  # Filter very short texts
    ]

    # Split train/val
    random.shuffle(texts)
    cut = int(len(texts) * 0.9)
    train_texts, val_texts = texts[:cut], texts[cut:]
    return {'train': train_texts, 'validation': val_texts}

# Load instruction dataset (OpenAssistant oasst1, train split)
instruct_dataset = load_dataset("OpenAssistant/oasst1", split="train")
# Index every message by its id so a reply's parent message can be looked up directly
messages_by_id = {msg['message_id']: msg for msg in instruct_dataset}

# Prepare instruction data
def prepare_instruct_data():
    """Turn assistant replies in ``instruct_dataset`` into instruction records.

    For every message whose role is ``'assistant'`` and whose ``parent_id`` is
    present in ``messages_by_id``, the parent's text becomes the instruction
    and the message's own text becomes the output. The parent's role is not
    checked here; it is presumably the human prompt.

    Returns:
        dict: ``{'train': [...], 'validation': [...]}`` where each element is
        a dict with ``'instruction'``, ``'input'`` (always empty) and
        ``'output'``; the split is 90/10 after ``random.shuffle``.
    """
    pairs = []
    for message in instruct_dataset:
        # Skip anything that is not an assistant reply with a known parent
        if message['role'] != 'assistant':
            continue
        if message['parent_id'] not in messages_by_id:
            continue

        parent = messages_by_id[message['parent_id']]
        record = {
            'instruction': parent['text'],
            'input': '',  # OpenAssistant format doesn't separate input
            'output': message['text'],
        }
        pairs.append(record)

    # Split train/val
    random.shuffle(pairs)
    cut = int(len(pairs) * 0.9)
    train_pairs, val_pairs = pairs[:cut], pairs[cut:]
    return {'train': train_pairs, 'validation': val_pairs}

# Build the processed datasets (kept in memory; nothing is written to disk here)
base_data = prepare_base_data()
instruct_data = prepare_instruct_data()

# Report split sizes as a quick sanity check
print(f"Base dataset size: {len(base_data['train'])} train, {len(base_data['validation'])} validation")
print(f"Instruction dataset size: {len(instruct_data['train'])} train, {len(instruct_data['validation'])} validation")

# 5. Initialize CRSM Model and Tokenizer

Let's configure and initialize our CRSM model:

In [None]:
import os
import json
import torch
import crsm
from crsm.model import CRSMConfig, CRSMModel
from crsm.tokenizer import Tokenizer

# Model configuration for ~2B parameters
# NOTE(review): the actual parameter count is printed further below; confirm it matches the ~2B target.
config = CRSMConfig(
    vocab_size=32000,  # Will be updated after tokenizer setup
    hidden_size=2048,
    intermediate_size=8192,
    num_hidden_layers=24,
    num_attention_heads=16,
    max_position_embeddings=2048,
    d_state=256,  # Increased for larger model
    dropout=0.1
)

# Initialize tokenizer (uses HF AutoTokenizer if a name is provided to Tokenizer, otherwise uses SimpleVocab fallback)
# For Colab PoC we use the simple fallback unless you provide a pretrained tokenizer name below.
HF_TOKENIZER_NAME = None  # set to e.g. 'gpt2' to use a standard HF tokenizer
if HF_TOKENIZER_NAME:
    tokenizer = Tokenizer(hf_name=HF_TOKENIZER_NAME)
else:
    tokenizer = Tokenizer()

# Save tokenizer in a compatible way
# The wrapper exposes either an HF tokenizer (`_hf`) or a simple vocab (`_simple`);
# each is persisted in its own format under WORKSPACE_DIR/tokenizer.
tokenizer_dir = f"{WORKSPACE_DIR}/tokenizer"
os.makedirs(tokenizer_dir, exist_ok=True)
if getattr(tokenizer, "_hf", None) is not None:
    tokenizer._hf.save_pretrained(tokenizer_dir)
else:
    # Save simple vocab mapping
    vocab = {"itos": tokenizer._simple.itos, "stoi": tokenizer._simple.stoi}
    with open(os.path.join(tokenizer_dir, "simple_vocab.json"), "w") as f:
        json.dump(vocab, f)

# Update config with actual vocab size
# This must happen before the model is constructed so the model is built from the updated config.
config.vocab_size = tokenizer.vocab_size

# Initialize model
model = CRSMModel(config)

# Move model to GPU if available (explicitly place weights on CUDA VRAM)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# When using mixed precision training Lightning will handle autocast/optimizer states,
# but explicitly set model to half if you want to pre-convert weights to fp16 for memory savings
# (only do this if you understand mixed-precision implications):
# if device.type == 'cuda':
#     model.half()

# Print model size
def count_parameters(model):
    """Return the total number of parameters of ``model`` in billions.

    Sums ``numel()`` over everything yielded by ``model.parameters()`` and
    divides by 1e9.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total / 1e9

# count_parameters() returns billions, so the value is formatted with a "B" suffix
print(f"Model size: {count_parameters(model):.2f}B parameters")

# Note: Lightning Trainer will move the LightningModule to the correct device as well.
# If you wrap `model` into a LightningModule before Trainer.fit, Lightning will relocate as needed.

# 6. Set Up Training Pipeline

Now we'll set up the training pipeline with data loading, optimization, and logging:

In [None]:
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from crsm.dataset import StreamingTextDataset
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch
from torch.utils.data import Dataset

class CRSMLightningModule(pl.LightningModule):
    """Lightning wrapper that trains a CRSM model with next-token cross-entropy.

    Batches are dicts with ``'input_ids'`` and ``'labels'`` tensors of the same
    ``(batch, positions)`` shape. Labels must already be aligned with the
    inputs, i.e. ``labels[:, t]`` is the target for the prediction made at
    ``input_ids[:, t]`` (this is what ``InMemoryTextDataset`` produces).
    Label value ``-100`` is ignored by the loss.

    Args:
        model: The wrapped model; called as ``model(input_ids)`` and expected
            to return a ``(logits, states)`` pair.
        tokenizer: Kept as an attribute for convenience; not used by the
            training logic itself.
        learning_rate: Learning rate passed to AdamW.
    """

    def __init__(self, model, tokenizer, learning_rate=3e-4):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.learning_rate = learning_rate

    def forward(self, input_ids, attention_mask=None):
        return self.model(input_ids, attention_mask=attention_mask)

    def _compute_loss(self, batch):
        """Run the model on one batch and return the mean cross-entropy loss."""
        input_ids = batch['input_ids'].to(self.device)
        labels = batch['labels'].to(self.device)

        # Expect model to return (logits, states)
        logits, _ = self.model(input_ids)

        # The dataset has already shifted the labels by one token, so logits
        # and labels line up position by position. Dropping the last logit
        # here (as was done previously) left one more label than logit rows
        # and made the loss fail on a shape mismatch.
        if logits.shape[:2] != labels.shape:
            raise ValueError(
                f"logits cover positions {tuple(logits.shape[:2])} but labels have "
                f"shape {tuple(labels.shape)}; labels must be pre-shifted and the "
                f"same length as input_ids"
            )

        return torch.nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1),
            ignore_index=-100,
        )

    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.learning_rate)
        # NOTE(review): returned in list form, Lightning presumably steps this
        # scheduler once per epoch, so T_max=1000 would be counted in epochs -
        # confirm this is the intended decay horizon.
        scheduler = CosineAnnealingLR(optimizer, T_max=1000)
        return [optimizer], [scheduler]

# In-memory dataset for lists of raw text
class InMemoryTextDataset(Dataset):
    """Sliding-window next-token dataset over a list of raw texts.

    All texts are encoded with ``tokenizer.encode`` and concatenated into one
    long id tensor. Sample ``i`` is the window of ``seq_len`` tokens starting
    at ``i * stride``; it yields ``input_ids`` (the first ``seq_len - 1``
    tokens) and ``labels`` (the same window shifted by one token), so both
    tensors have length ``seq_len - 1``.

    Args:
        texts: Iterable of strings to encode.
        tokenizer: Object with ``encode(text) -> list[int]`` and, optionally,
            a ``vocab_size`` attribute used to clamp labels.
        seq_len: Window length in tokens (must be at least 2).
        stride: Distance in tokens between consecutive windows. The default of
            1 keeps the original fully-overlapping behaviour; pass ``seq_len``
            for non-overlapping windows.

    Raises:
        ValueError: If ``seq_len < 2`` or ``stride < 1``.
    """

    def __init__(self, texts, tokenizer, seq_len=128, stride=1):
        if seq_len < 2:
            raise ValueError("seq_len must be at least 2 to form an input/label pair")
        if stride < 1:
            raise ValueError("stride must be a positive integer")
        self.seq_len = seq_len
        self.stride = stride
        self.tokenizer = tokenizer
        ids = []
        for text in texts:
            ids.extend(tokenizer.encode(text))
        self.ids = torch.tensor(ids, dtype=torch.long)

    def __len__(self):
        n_tokens = len(self.ids)
        if n_tokens < self.seq_len:
            return 0
        return (n_tokens - self.seq_len) // self.stride + 1

    def __getitem__(self, idx):
        n_windows = len(self)
        if idx < 0:
            idx += n_windows
        # Without this check an out-of-range index silently returned a short
        # (or empty) window, and plain iteration over the dataset never ended.
        if not 0 <= idx < n_windows:
            raise IndexError(f"index out of range for dataset with {n_windows} samples")

        start = idx * self.stride
        seq = self.ids[start: start + self.seq_len]
        input_ids = seq[:-1].clone()
        labels = seq[1:].clone()
        # Ensure labels are in valid range: [0, vocab_size) or -100 for ignore
        # If any label is >= vocab_size, clamp it
        if hasattr(self.tokenizer, 'vocab_size'):
            vocab_sz = self.tokenizer.vocab_size
            labels = torch.clamp(labels, min=-100, max=vocab_sz - 1)
        return {'input_ids': input_ids, 'labels': labels}

# Create data loaders - supports either a list of texts or a dataset name/file-based StreamingTextDataset
def create_dataloaders(texts_or_name, tokenizer, batch_size, seq_len=2048):
    """Build a single ``DataLoader`` over in-memory texts or a named dataset.

    Args:
        texts_or_name: A Python ``list`` of raw texts (wrapped in
            ``InMemoryTextDataset``) or anything else, which is passed as
            ``dataset_name`` to ``StreamingTextDataset``.
        tokenizer: Tokenizer handed to the dataset.
        batch_size: Batch size for the loader.
        seq_len: Sequence length handed to the dataset.

    Returns:
        DataLoader: Loader with two workers and ``pin_memory`` disabled.
    """
    if not isinstance(texts_or_name, list):
        # assume a dataset name string for StreamingTextDataset
        source = StreamingTextDataset(dataset_name=texts_or_name, seq_len=seq_len, tokenizer=tokenizer)
    else:
        source = InMemoryTextDataset(texts_or_name, tokenizer, seq_len=seq_len)

    loader_options = {
        'batch_size': batch_size,
        'num_workers': 2,
        # Disable pin_memory to avoid CUDA assert during memory pinning; batches are moved to device manually in training_step
        'pin_memory': False,
    }
    return DataLoader(source, **loader_options)

# Initialize training
# Wrap the raw model so the Lightning Trainer can drive it
model_pl = CRSMLightningModule(model, tokenizer)
# Use seq_len from config
seq_len = getattr(config, 'max_position_embeddings', 2048)

# If base_data entries are lists of strings, pass that; else pass dataset name
# Batch size comes from the GPU check performed at the top of the notebook.
train_loader = create_dataloaders(base_data['train'], tokenizer, training_params['batch_size'], seq_len=seq_len)
val_loader = create_dataloaders(base_data['validation'], tokenizer, training_params['batch_size'], seq_len=seq_len)

# Set up trainer
# Gradient accumulation also comes from the GPU check; logs and checkpoints go under LOGS_DIR.
trainer = pl.Trainer(
    accelerator='cuda',
    devices=1,
    max_epochs=10,  # Adjust as needed
    precision='16-mixed',
    accumulate_grad_batches=training_params['gradient_accum'],
    default_root_dir=LOGS_DIR,
    enable_checkpointing=True,
    val_check_interval=0.25
)

# 7. Train Base Model

Let's train our base model and monitor the loss:

In [None]:
# Train the base model
trainer.fit(model_pl, train_loader, val_loader)

# Save the trained model
import json
import os

base_model_dir = f"{CHECKPOINTS_DIR}/base_model"
os.makedirs(base_model_dir, exist_ok=True)
model_pl.model.save_pretrained(base_model_dir)

# Save the tokenizer next to the model. The project Tokenizer wrapper is saved
# elsewhere in this notebook via its `_hf` / `_simple` members, so it may not
# offer save_pretrained itself; fall back to the same scheme instead of
# failing with AttributeError after a long training run.
if hasattr(tokenizer, "save_pretrained"):
    tokenizer.save_pretrained(base_model_dir)
elif getattr(tokenizer, "_hf", None) is not None:
    tokenizer._hf.save_pretrained(base_model_dir)
else:
    simple_vocab = {"itos": tokenizer._simple.itos, "stoi": tokenizer._simple.stoi}
    with open(os.path.join(base_model_dir, "simple_vocab.json"), "w") as f:
        json.dump(simple_vocab, f)

print("Base model training completed and saved!")

In [None]:
# Debugging cell: single-batch smoke test without pin_memory issues
# Directly inspect dataset __getitem__ before the dataloader tries to pin it to GPU
import os, traceback, torch
# Request synchronous CUDA kernel launches so errors surface at the failing call.
# NOTE(review): presumably this only takes effect if set before CUDA is initialised - confirm.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Running smoke test on device:', device)

# Check tokenizer vocab_size first
vocab_size = getattr(tokenizer, 'vocab_size', None)
print('tokenizer.vocab_size:', vocab_size)

# Get the underlying dataset (before DataLoader wrapping)
# train_loader.dataset gives us the InMemoryTextDataset or StreamingTextDataset
underlying_dataset = train_loader.dataset
print('underlying_dataset type:', type(underlying_dataset).__name__)

# Get one raw sample from the dataset (bypasses DataLoader/pin_memory)
# Failures are printed rather than raised so the rest of the cell still runs.
try:
    raw_sample = underlying_dataset[0]
    print('raw sample keys:', list(raw_sample.keys()) if isinstance(raw_sample, dict) else type(raw_sample))
    print('input_ids shape/dtype:', raw_sample['input_ids'].shape, raw_sample['input_ids'].dtype)
    print('labels shape/dtype:', raw_sample['labels'].shape, raw_sample['labels'].dtype)
    
    # Check label ranges
    # Valid labels are token ids in [0, vocab_size) or the ignore value -100.
    lbl_min = int(raw_sample['labels'].min().item())
    lbl_max = int(raw_sample['labels'].max().item())
    print('labels min/max:', lbl_min, lbl_max)
    if vocab_size is not None and lbl_max >= vocab_size:
        print(f'>>> ERROR: label value {lbl_max} >= tokenizer.vocab_size {vocab_size}')
    if lbl_min < -100:
        print(f'>>> WARNING: label value {lbl_min} < -100 (unusual; typically -100 is ignore_index)')
except Exception as e:
    print('Failed to get raw sample from dataset:')
    traceback.print_exc()

# Manually collate a batch (simulating what DataLoader does, but without pin_memory)
# Any failure is printed with its traceback instead of aborting the cell.
try:
    print('\nCollating batch manually (no pin_memory)...')
    batch_samples = [underlying_dataset[i] for i in range(min(4, len(underlying_dataset)))]

    # Stack input_ids and labels
    batch = {
        'input_ids': torch.stack([s['input_ids'] for s in batch_samples]),
        'labels': torch.stack([s['labels'] for s in batch_samples])
    }

    print('Batch (on CPU):', {k: (v.dtype, v.device, v.shape) for k, v in batch.items()})

    # Move to device manually (avoids pin_memory CUDA assert)
    batch = {k: v.to(device) for k, v in batch.items()}
    print('Batch (on device):', {k: (v.dtype, v.device, v.shape) for k, v in batch.items()})

    input_ids = batch['input_ids']
    labels = batch['labels']

    # Run model forward
    model.eval()
    with torch.no_grad():
        out = model(input_ids)
        # The model may return (logits, states) or bare logits
        logits = out[0] if isinstance(out, (tuple, list)) else out
        print('logits shape/dtype/device:', logits.shape, logits.dtype, logits.device)
        if vocab_size is not None and logits.shape[-1] != vocab_size:
            print(f'>>> MISMATCH: logits last-dim={logits.shape[-1]} != tokenizer.vocab_size={vocab_size}')

    # Compute loss on CPU to avoid CUDA kernel asserts
    llogits = logits.float().cpu()
    llabels = labels.cpu()
    # The in-memory dataset already shifts labels by one token, so logits and
    # labels align position by position. Dropping the last logit here gave one
    # fewer logit row than labels and made the loss fail on shape.
    if llogits.shape[:2] != llabels.shape:
        print(f'>>> MISMATCH: logits positions {tuple(llogits.shape[:2])} != labels shape {tuple(llabels.shape)}')
    flat_logits = llogits.reshape(-1, llogits.size(-1))
    flat_labels = llabels.reshape(-1)
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
    loss = loss_fct(flat_logits, flat_labels)
    print('CrossEntropyLoss (computed on CPU):', float(loss.item()))

except Exception:
    print('Error during batch collation/forward/loss:')
    traceback.print_exc()

print('\nSmoke test complete.')

# 8. Instruction Fine-tuning

Now we'll define the dataset used to fine-tune the base model on our instruction data:

In [None]:
from torch.utils.data import Dataset
import torch

class InstructionDataset(Dataset):
    """Fixed-length next-token dataset built from instruction records.

    Each record is a dict with ``'instruction'``, ``'input'`` and ``'output'``
    strings. The record is rendered to text, encoded, truncated or padded, and
    returned as tensors of length ``max_length``:

    * ``input_ids``      - the tokens fed to the model,
    * ``attention_mask`` - 1 for real tokens, 0 for padding,
    * ``labels``         - the next token for every input position
      (``labels[t]`` is the token that follows ``input_ids[t]``), with ``-100``
      wherever the next token is padding.

    Args:
        data: Sequence of instruction records.
        tokenizer: Object with ``encode(text) -> list[int]``. The pad id is
            taken from ``tokenizer._hf.pad_token_id`` when available, else from
            ``tokenizer._simple.stoi['<pad>']``, else 0.
        max_length: Length of every returned tensor.
    """

    def __init__(self, data, tokenizer, max_length=2048):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.pad_id = self._resolve_pad_id(tokenizer)

    @staticmethod
    def _resolve_pad_id(tokenizer):
        """Find the padding id without assuming which tokenizer backend is active."""
        hf_pad = getattr(getattr(tokenizer, "_hf", None), "pad_token_id", None)
        if hf_pad is not None:
            return hf_pad
        stoi = getattr(getattr(tokenizer, "_simple", None), "stoi", None)
        if stoi is not None:
            return stoi.get('<pad>', 0)
        return 0

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        # Format: Instruction \n Input: {input} \n Output: {output}
        text = f"{item['instruction']}\nInput: {item['input']}\nOutput: {item['output']}"

        # Keep one extra token so that, after shifting, inputs and labels both
        # have exactly max_length entries.
        window = self.max_length + 1
        ids = list(self.tokenizer.encode(text))[:window]
        n_real = len(ids)
        ids = ids + [self.pad_id] * (window - n_real)

        tokens = torch.tensor(ids, dtype=torch.long)
        input_ids = tokens[:-1].clone()
        # labels are next-token (shifted)
        labels = tokens[1:].clone()

        # Mask by position rather than by value: when the pad id falls back to
        # 0 it may coincide with a real token id, which must not be hidden.
        positions = torch.arange(self.max_length)
        attention_mask = (positions < n_real).long()
        labels[positions >= n_real - 1] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

# 9. Model Validation and Testing

Let's validate our models on some test examples:

In [None]:
from torch.nn.functional import softmax
import evaluate

def generate_text(model, tokenizer, prompt, max_length=100):
    """Sample one continuation of ``prompt`` and return it as decoded text.

    The tokenizer is called directly with ``return_tensors="pt"`` and the
    result is moved to ``model.device``; the model is then asked to
    ``generate`` a single sampled sequence (temperature 0.7), whose first
    entry is decoded with special tokens skipped.

    NOTE(review): this assumes a Hugging Face style callable tokenizer and a
    model exposing ``device`` / ``generate`` - confirm against the objects
    actually passed in.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    sampling_options = {
        'max_length': max_length,
        'num_return_sequences': 1,
        'temperature': 0.7,
        'do_sample': True,
    }
    generated = model.generate(**encoded, **sampling_options)
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

# Test base model on text completion
base_model = CRSMModel.from_pretrained(f"{CHECKPOINTS_DIR}/base_model")
base_model.eval()
base_model.cuda()

print("Base model completion test:")
prompt = "The theory of relativity states that"
completion = generate_text(base_model, tokenizer, prompt)
print(f"Prompt: {prompt}")
print(f"Completion: {completion}\n")

# Hand-written prompts with reference answers for the instruction model
test_instructions = [
    {
        "instruction": "Explain what photosynthesis is in simple terms.",
        "input": "",
        "reference": "Photosynthesis is how plants make their food using sunlight, water, and carbon dioxide."
    },
    {
        "instruction": "Write a haiku about autumn.",
        "input": "",
        "reference": "Leaves drift to the ground\nCool breeze whispers through branches\nAutumn paints the world"
    }
]

# Test instruction model on various tasks
# Only run these tests when an instruction-tuned checkpoint actually exists;
# otherwise loading it would abort the whole validation cell.
instruct_model_dir = f"{CHECKPOINTS_DIR}/instruct_model"
if not os.path.isdir(instruct_model_dir):
    print(f"No instruction model checkpoint found at {instruct_model_dir}; skipping instruction model tests.")
else:
    instruct_model = CRSMModel.from_pretrained(instruct_model_dir)
    instruct_model.eval()
    instruct_model.cuda()

    # Load metrics
    rouge = evaluate.load('rouge')
    bleu = evaluate.load('bleu')

    print("Instruction model tests:")
    for test in test_instructions:
        # Same prompt layout as the instruction-tuning data, stopping at "Output:"
        prompt = f"{test['instruction']}\nInput: {test['input']}\nOutput:"
        response = generate_text(instruct_model, tokenizer, prompt)

        print(f"\nInstruction: {test['instruction']}")
        print(f"Model response: {response}")
        print(f"Reference: {test['reference']}")

        # Calculate metrics
        rouge_scores = rouge.compute(predictions=[response], references=[test['reference']])
        # The `evaluate` BLEU metric tokenizes internally: it takes raw
        # strings (one list of reference strings per prediction), not
        # pre-split token lists.
        bleu_score = bleu.compute(predictions=[response], references=[[test['reference']]])

        print(f"ROUGE-L: {rouge_scores['rougeL']:.3f}")
        print(f"BLEU: {bleu_score['bleu']:.3f}")

print("\nValidation complete!")

# 10. Save and Export Models

Finally, let's save our models in a format ready for deployment:

In [None]:
# Save model card information
import os

# Depending on the installed pynvml release the device name is bytes or str;
# normalise it before it goes into the card.
gpu_name = nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(0))
if isinstance(gpu_name, bytes):
    gpu_name = gpu_name.decode('utf-8')

model_card = f"""
# CRSM 2B Model

This model was trained as a proof-of-concept implementation of the CRSM architecture.

## Model Details
- Parameters: ~2B
- Architecture: CRSM with {config.num_hidden_layers} layers
- Context Length: {config.max_position_embeddings} tokens
- Vocab Size: {config.vocab_size} tokens

## Training
- Base training: SlimPajama-627B dataset subset
- Instruction tuning: OpenAssistant oasst1 dataset
- Training platform: Google Colab
- Hardware: {gpu_name}

## Usage
```python
from crsm.model import CRSMModel
from transformers import AutoTokenizer

model = CRSMModel.from_pretrained("path_to_model")
tokenizer = AutoTokenizer.from_pretrained("path_to_model")

# Generate text
text = generate_text(model, tokenizer, "Your prompt here")
```

## Limitations
- Proof-of-concept implementation
- Limited training data
- Colab resource constraints
"""

# Save model card
# Write a README only into model directories that exist, so a missing
# checkpoint (e.g. no instruction-tuned model yet) does not abort the export.
for model_dir in (f"{CHECKPOINTS_DIR}/base_model", f"{CHECKPOINTS_DIR}/instruct_model"):
    if not os.path.isdir(model_dir):
        print(f"Skipping model card: {model_dir} does not exist")
        continue
    with open(f"{model_dir}/README.md", "w") as f:
        f.write(model_card)

# Optional: Push to Hugging Face Hub
# from huggingface_hub import notebook_login
# notebook_login()
# 
# instruct_model.push_to_hub("your-username/crsm-2b-instruct")
# tokenizer.push_to_hub("your-username/crsm-2b-instruct")

print("Models saved and exported successfully!")
print(f"You can find the models in your Google Drive at: {CHECKPOINTS_DIR}")