# State Space Model Training Example

This notebook demonstrates how to use the State Space Model for sequence classification tasks.

State Space Models (SSMs) are a powerful class of models for sequence modeling that offer an alternative to Transformers and RNNs. They are based on continuous-time state space equations discretized for sequence processing.

## 1. Import Required Libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import numpy as np

from models.state_space import StateSpaceModel, StateSpaceEncoder

# Set random seeds for reproducibility. The single SEED constant keeps the
# torch and numpy generators in sync and makes the value easy to change.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

## 2. Check Device Availability

In [None]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

## 3. Create a Synthetic Dataset

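For this example, we'll create a simple binary classification task over random sequences of digit tokens (0–9). A sequence is labeled 1 if strictly more than half of its tokens are odd, and 0 otherwise (even majority or an exact tie). Because ties go to class 0, the two classes are not perfectly balanced: for length-20 sequences, roughly 59% of samples fall in class 0.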

In [None]:
def create_synthetic_dataset(num_samples=1000, seq_length=20, vocab_size=10, num_classes=2,
                             generator=None):
    """
    Create a synthetic binary sequence-classification dataset.

    Each sequence is drawn uniformly from the token ids ``0 .. vocab_size - 1``.
    A sequence is labeled 1 when strictly more than half of its tokens are odd,
    and 0 otherwise, i.e. when even tokens are in the majority *or* the counts
    are tied. Ties can only happen when ``seq_length`` is even, and they make
    class 0 more frequent than class 1.

    Args:
        num_samples: Number of sequences to generate.
        seq_length: Number of tokens per sequence.
        vocab_size: Number of distinct token ids.
        num_classes: Accepted for API compatibility only. The task is always
            binary and this value is not used.
        generator: Optional ``torch.Generator`` to draw from instead of the
            global RNG. Passing one makes a split reproducible on its own.

    Returns:
        ``(sequences, labels)``:
        - ``sequences`` is a ``(num_samples, seq_length)`` int64 tensor of token ids.
        - ``labels`` is a ``(num_samples,)`` int64 tensor with values in {0, 1}.
    """
    # Generate random sequences
    sequences = torch.randint(0, vocab_size, (num_samples, seq_length), generator=generator)

    # Count the odd tokens in every row with one vectorized reduction instead
    # of a Python loop. `2 * count > seq_length` is the same test as
    # `count > seq_length / 2`, but stays in integer arithmetic.
    odd_counts = (sequences % 2).sum(dim=1)
    labels = (2 * odd_counts > seq_length).long()

    return sequences, labels

# Dataset sizes shared by both splits. The model cell below uses the matching
# values vocab_size = 10 and max_length = 20.
NUM_TRAIN, NUM_TEST = 800, 200
SEQ_LENGTH, VOCAB_SIZE = 20, 10

# Create train and test datasets (drawn one after the other from the seeded global RNG)
train_sequences, train_labels = create_synthetic_dataset(
    num_samples=NUM_TRAIN, seq_length=SEQ_LENGTH, vocab_size=VOCAB_SIZE
)
test_sequences, test_labels = create_synthetic_dataset(
    num_samples=NUM_TEST, seq_length=SEQ_LENGTH, vocab_size=VOCAB_SIZE
)

assert train_sequences.shape == (NUM_TRAIN, SEQ_LENGTH)
assert test_sequences.shape == (NUM_TEST, SEQ_LENGTH)

print(f"Train dataset: {train_sequences.shape}, {train_labels.shape}")
print(f"Test dataset: {test_sequences.shape}, {test_labels.shape}")
# Ties are labeled 0, so the classes are not balanced. Show the counts so
# accuracies later on can be compared against the majority-class rate.
print(f"Train class counts (0, 1): {torch.bincount(train_labels, minlength=2).tolist()}")
print(f"Test class counts (0, 1): {torch.bincount(test_labels, minlength=2).tolist()}")
print(f"\nExample sequence: {train_sequences[0]}")
print(f"Example label: {train_labels[0]}")

## 4. Create DataLoaders

In [None]:
batch_size = 32

# Wrap the tensors so that each dataset item is a (sequence, label) pair.
train_dataset = TensorDataset(train_sequences, train_labels)
test_dataset = TensorDataset(test_sequences, test_labels)

# Training batches are reshuffled every epoch, and drop_last discards any
# trailing partial batch. With the 800 training samples generated above,
# batch_size 32 divides evenly, so nothing is dropped. Test batches keep a
# fixed order and keep their final partial batch.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

## 5. Initialize the State Space Model

In [None]:
# Model hyperparameters
vocab_size = 10
num_classes = 2
embedding_dim = 64
d_state = 32
num_layers = 4
max_length = 20
dropout = 0.1
pooling = 'mean'  # Options: 'mean', 'max', 'last', 'cls'

# Catch a typo in the pooling option here, with a clear message, instead of
# deep inside the model.
assert pooling in ('mean', 'max', 'last', 'cls'), f"Unknown pooling strategy: {pooling!r}"

# Create the model.
# NOTE(review): `device` is both passed to the constructor and applied with
# .to(device). Presumably the constructor uses it for tensors it creates
# internally; confirm in models.state_space.
model = StateSpaceModel(
    vocab_size=vocab_size,
    num_classes=num_classes,
    embedding_dim=embedding_dim,
    d_state=d_state,
    num_layers=num_layers,
    dropout=dropout,
    max_length=max_length,
    pooling=pooling,
    device=device
).to(device)

# Count parameters
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model has {num_params:,} trainable parameters")
print("\nModel architecture:")
print(model)

## 6. Set Up Training

In [None]:
# Training hyperparameters
num_epochs = 10
learning_rate = 1e-3

# Cross-entropy on the class logits, minimized with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# Optional schedule: multiply the learning rate by 0.5 every 5 scheduler steps.
# The training loop calls scheduler.step() once per epoch.
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=5,
    gamma=0.5,
)

## 7. Training Loop

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one training epoch over ``dataloader`` and return its metrics.

    Puts ``model`` in training mode. For every batch it runs a forward pass,
    computes ``criterion``, backpropagates and takes one ``optimizer`` step.
    Metrics for a batch use that batch's outputs from *before* its parameter
    update.

    Args:
        model: Module mapping a batch of inputs to class logits. Assumed shape
            is (batch, num_classes); TODO confirm for other models.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss that averages over the batch (e.g. the default
            ``nn.CrossEntropyLoss()``). The per-sample averaging below relies
            on that mean reduction.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to.

    Returns:
        ``(avg_loss, accuracy)``: the loss averaged per sample and the
        accuracy in percent. Both are NaN if ``dataloader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for sequences, labels in dataloader:
        sequences, labels = sequences.to(device), labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(sequences)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Track metrics. Weighting each batch-mean loss by its batch size keeps
        # the epoch average exact even when batches differ in size.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_size

    if total == 0:
        # There is nothing to average over. Report "undefined" instead of
        # raising ZeroDivisionError.
        return float('nan'), float('nan')
    return total_loss / total, 100.0 * correct / total

def evaluate(model, dataloader, criterion, device):
    """Evaluate ``model`` on every batch of ``dataloader`` without updating it.

    Switches ``model`` to eval mode (and leaves it there) and disables
    gradient tracking for the whole pass.

    Args:
        model: Module mapping a batch of inputs to class logits. Assumed shape
            is (batch, num_classes); TODO confirm for other models.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss that averages over the batch (e.g. the default
            ``nn.CrossEntropyLoss()``). The per-sample averaging below relies
            on that mean reduction.
        device: Device each batch is moved to.

    Returns:
        ``(avg_loss, accuracy)``: the loss averaged per sample and the
        accuracy in percent. Both are NaN if ``dataloader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for sequences, labels in dataloader:
            sequences, labels = sequences.to(device), labels.to(device)

            outputs = model(sequences)
            loss = criterion(outputs, labels)

            # Weight each batch-mean loss by its size, so a short final batch
            # (e.g. when the dataset size is not a multiple of batch_size)
            # counts no more than its samples deserve.
            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        # There is nothing to average over. Report "undefined" instead of
        # raising ZeroDivisionError.
        return float('nan'), float('nan')
    return total_loss / total, 100.0 * correct / total

# Per-epoch metric history, used later for plotting.
train_losses, train_accs = [], []
test_losses, test_accs = [], []

print("Starting training...\n")
for epoch in range(num_epochs):
    # One optimization pass over the training set, then a no-grad pass over the test set.
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)

    train_losses.append(train_loss)
    test_losses.append(test_loss)
    train_accs.append(train_acc)
    test_accs.append(test_acc)

    print(f"Epoch {epoch+1}/{num_epochs}:")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"  Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")
    print()

    # Advance the learning-rate schedule once per epoch, after this epoch's optimizer steps.
    scheduler.step()

print("Training completed!")

## 8. Visualize Training Results

In [None]:
# The training log numbers epochs from 1, so use the same numbering on the
# x-axis. Plotting the bare lists would start the axis at 0.
epoch_axis = range(1, len(train_losses) + 1)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Plot losses
ax1.plot(epoch_axis, train_losses, label='Train Loss', marker='o')
ax1.plot(epoch_axis, test_losses, label='Test Loss', marker='s')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Test Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot accuracies
ax2.plot(epoch_axis, train_accs, label='Train Accuracy', marker='o')
ax2.plot(epoch_axis, test_accs, label='Test Accuracy', marker='s')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Training and Test Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Guard against an empty history (e.g. num_epochs = 0), where [-1] would raise IndexError.
if test_accs:
    print(f"Final Test Accuracy: {test_accs[-1]:.2f}%")

## 9. Test Model on Sample Predictions

In [None]:
# Get a few test samples (never more than the test set contains)
num_samples = min(5, len(test_sequences))
sample_sequences = test_sequences[:num_samples].to(device)
sample_labels = test_labels[:num_samples]

# Make predictions, bringing them back to the CPU once rather than per element.
model.eval()
with torch.no_grad():
    outputs = model(sample_sequences)
    predictions = outputs.argmax(dim=1).cpu()

# create_synthetic_dataset labels a sequence 1 only when odd tokens are a
# strict majority, so label 0 covers both an even majority and an exact tie.
label_names = {1: 'Odd majority', 0: 'Even majority or tie'}

# Display results
print("Sample Predictions:\n")
for i in range(num_samples):
    seq = sample_sequences[i].cpu().numpy()
    true_label = sample_labels[i].item()
    pred_label = predictions[i].item()

    odd_count = int((seq % 2).sum())

    print(f"Sample {i+1}:")
    print(f"  Sequence: {seq}")
    print(f"  Odd count: {odd_count}/{len(seq)}")
    print(f"  True label: {true_label} ({label_names[true_label]})")
    print(f"  Predicted: {pred_label} ({label_names[pred_label]})")
    print(f"  Correct: {'✓' if true_label == pred_label else '✗'}")
    print()

## 10. Model Configuration Options

The State Space Model supports various configuration options:

### Pooling Strategies
- `'mean'`: Average pooling over the sequence
- `'max'`: Max pooling over the sequence
- `'last'`: Use the last token's representation
- `'cls'`: Use the first token (CLS token style)

### Key Hyperparameters
- `embedding_dim`: Dimension of token embeddings and hidden states
- `d_state`: Dimension of the internal state space (larger = more expressive but slower)
- `num_layers`: Number of State Space blocks to stack
- `dropout`: Dropout probability for regularization

### Example: Try Different Pooling Methods

In [None]:
# Compare different pooling strategies on freshly initialized (untrained) models.
pooling_methods = ['mean', 'max', 'last', 'cls']
results = {}

# Use a distinct loop variable. Reusing `pooling` would overwrite the model
# hyperparameter defined earlier and leave it set to 'cls' after this cell.
for pooling_method in pooling_methods:
    # Create model with specific pooling
    test_model = StateSpaceModel(
        vocab_size=vocab_size,
        num_classes=num_classes,
        embedding_dim=32,  # Smaller for faster comparison
        d_state=16,
        num_layers=2,
        pooling=pooling_method,
        max_length=max_length,
        device=device
    ).to(device)

    # Quick evaluation on test set
    _, acc = evaluate(test_model, test_loader, criterion, device)
    results[pooling_method] = acc
    print(f"Pooling '{pooling_method}': {acc:.2f}% accuracy (untrained)")

# Ties are labeled 0, so the test classes are imbalanced and "chance" is not
# exactly 50%. Report what a constant predictor would score instead.
class_counts = torch.bincount(test_labels, minlength=num_classes)
majority_baseline = 100.0 * class_counts.max().item() / len(test_labels)
print("\nNote: These are untrained models, so accuracies should be near chance level.")
print(f"Because ties are labeled 0, the classes are imbalanced: always predicting the majority class "
      f"scores {majority_baseline:.2f}% on this test set, and always predicting the other class "
      f"scores {100.0 - majority_baseline:.2f}%.")
print("After training, different pooling methods may perform differently depending on the task.")

## Conclusion

This notebook demonstrated how to:
1. Create a State Space Model for sequence classification
2. Train it on a synthetic dataset
3. Evaluate its performance
4. Make predictions on held-out test sequences

State Space Models offer an efficient alternative to Transformers and RNNs for sequence modeling tasks, with the following advantages:
- **Efficiency**: Linear-time complexity in sequence length, compared to quadratic for standard Transformer self-attention
- **Long sequences**: Better at handling long-range dependencies than RNNs
- **Simplicity**: Simpler architecture than attention mechanisms

You can adapt this notebook to work with real datasets by:
1. Loading your own data (e.g., from HuggingFace Datasets)
2. Adjusting the model hyperparameters
3. Using appropriate preprocessing and tokenization