# Mamba Model Training Example

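This notebook demonstrates how to train the Mamba model on a token-level sequence classification task. For every position in an input sequence, the model predicts one target class, producing an output of shape `(batch_size, max_length, vocab_size_trg)`.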

In [None]:
import numpy as np
import torch
import random
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import Dataset, DataLoader
from models.embedding import *
from models.mamba import Mamba, MambaEncoder

# Anomaly detection makes every backward pass much slower; switch it on only
# while hunting NaN/Inf gradients.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

# Seed the RNGs the notebook relies on so Restart & Run All reproduces the same
# data, weight initialisation, batch shuffling and dropout masks.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, including CUDA

# Reload edited project modules (e.g. models.mamba, models.embedding) before
# each cell executes, so changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

## Hyperparameters

In [None]:
# --- Optimisation ---
n_epochs = 10       # full passes over the training set
batch_size = 32
lr = 0.0001         # Adam learning rate

# --- Data ---
max_length = 20     # tokens per synthetic sequence

# --- Architecture (all forwarded to MambaEncoder) ---
embedding_dim = 64
num_layers = 4
d_state = 16
d_conv = 4
expand_factor = 2
dropout = 0.1

## Create Synthetic Data

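For testing purposes, we generate synthetic data: `num_samples` source sequences drawn from a vocabulary of `vocab_size_src` tokens, and aligned target sequences drawn from `vocab_size_trg` classes. Every sequence is `max_length` tokens long.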

In [None]:
# Create synthetic data
vocab_size_src = 100
vocab_size_trg = 50
num_samples = 1000

# Dedicated, seeded generator: the data is identical on every Restart & Run All.
data_rng = np.random.default_rng(0)

# Random source token sequences, shape (num_samples, max_length).
sequences_src = data_rng.integers(0, vocab_size_src, size=(num_samples, max_length))

# Targets are a deterministic per-token function of the source, so the model has a
# real signal to learn and the training-loss curve is meaningful. Independent
# random targets would leave nothing to learn beyond chance level.
sequences_trg = sequences_src % vocab_size_trg

print(f"Source sequences shape: {sequences_src.shape}")
print(f"Target sequences shape: {sequences_trg.shape}")

## Create Dataset and DataLoader

In [None]:
class SyntheticDataset(Dataset):
    """Map-style dataset that pairs source and target token sequences by index.

    Args:
        src_seqs: Indexable collection of source token-id sequences, e.g. a 2-D
            integer numpy array with one row per sample.
        trg_seqs: Indexable collection of target token-id sequences, aligned
            with ``src_seqs`` sample-for-sample.

    Raises:
        ValueError: If ``src_seqs`` and ``trg_seqs`` hold a different number of
            samples.
    """

    def __init__(self, src_seqs, trg_seqs):
        # Fail fast: a mismatch would otherwise surface only as an IndexError
        # partway through an epoch, or as silently ignored extra targets.
        if len(src_seqs) != len(trg_seqs):
            raise ValueError(
                "src_seqs and trg_seqs must contain the same number of samples, "
                f"got {len(src_seqs)} and {len(trg_seqs)}"
            )
        self.src_seqs = src_seqs
        self.trg_seqs = trg_seqs

    def __len__(self):
        """Return the number of (source, target) pairs."""
        return len(self.src_seqs)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair as two freshly allocated ``torch.long`` tensors."""
        return (
            torch.tensor(self.src_seqs[idx], dtype=torch.long),
            torch.tensor(self.trg_seqs[idx], dtype=torch.long),
        )

dataset = SyntheticDataset(sequences_src, sequences_trg)

# drop_last=True discards the final partial batch, so every batch has exactly
# `batch_size` samples.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

n_batches = len(dataloader)
print(f"Number of batches: {n_batches}")

## Initialize Mamba Model

In [None]:
# Gather the encoder configuration in one dict so it is easy to inspect or tweak.
mamba_config = dict(
    embedding_type=EmbeddingType.POS_LEARNED,
    src_vocab_size=vocab_size_src,
    trg_vocab_size=vocab_size_trg,
    embedding_dim=embedding_dim,
    d_state=d_state,
    d_conv=d_conv,
    expand_factor=expand_factor,
    num_layers=num_layers,
    dropout=dropout,
    device=device,
    max_length=max_length,
)
model = MambaEncoder(**mamba_config).to(device)

n_model_params = sum(param.numel() for param in model.parameters())
print(f"Model created with {n_model_params} parameters")

## Test Forward Pass

In [None]:
# Test with a single batch
src_batch, trg_batch = next(iter(dataloader))
src_batch = src_batch.to(device)
trg_batch = trg_batch.to(device)

print(f"Input shape: {src_batch.shape}")
print(f"Target shape: {trg_batch.shape}")

# Inference-style check: eval mode switches layers such as dropout to their
# deterministic behaviour, and no_grad skips building an autograd graph.
# (The training cell below calls model.train() again.)
model.eval()
with torch.no_grad():
    output = model(src_batch)

# One prediction over the target vocabulary for every input position.
expected_shape = (batch_size, max_length, vocab_size_trg)
print(f"Output shape: {output.shape}")
print(f"Expected shape: {expected_shape}")
# Actually verify the shape instead of only printing it; the training loss
# reshapes the output assuming exactly this layout.
assert tuple(output.shape) == expected_shape, (
    f"unexpected output shape {tuple(output.shape)}, expected {expected_shape}"
)
print("\nForward pass successful!")

## Training Loop (Optional)

In [None]:
# Define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# CrossEntropyLoss applies log_softmax internally. It is therefore correct whether
# the model returns raw logits or log-probabilities, since log_softmax is
# idempotent. NLLLoss expects log-probabilities only and yields meaningless
# (even negative) values on raw logits.
criterion = torch.nn.CrossEntropyLoss()

# Training loop
model.train()
losses = []  # mean training loss per epoch, plotted in the next section

print("Starting training...")
for epoch in range(n_epochs):
    epoch_loss = 0.0
    for src, trg in dataloader:
        src = src.to(device)
        trg = trg.to(device)

        # Forward pass; a distinct name avoids clobbering `output` from the
        # forward-pass check above.
        logits = model(src)

        # Flatten (batch, seq, vocab) -> (batch*seq, vocab) and targets -> (batch*seq,)
        loss = criterion(logits.reshape(-1, vocab_size_trg), trg.reshape(-1))

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(dataloader)
    losses.append(avg_loss)
    print(f"Epoch {epoch+1}/{n_epochs}, Loss: {avg_loss:.4f}")

print("\nTraining completed!")

## Plot Training Loss

In [None]:
# The training log numbers epochs from 1, so use the same numbering on the x-axis.
epoch_numbers = range(1, len(losses) + 1)

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epoch_numbers, losses, marker='o')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Mamba Model Training Loss')
ax.grid(True)
plt.show()

## Model Summary

In [None]:
# Count parameters once: all of them, and the subset the optimizer can update.
total_params = sum(param.numel() for param in model.parameters())
trainable_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)

print("\nModel Architecture:")
print(model)
print(f"\nTotal Parameters: {total_params:,}")
print(f"Trainable Parameters: {trainable_params:,}")