# Lab 2.4.2: Mamba Architecture Study - SOLUTIONS

Complete solutions for the Mamba architecture exercises.

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer

# Pick the compute device once: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

## Exercise Solution: State Evolution Analysis

In [None]:
# Load model
MODEL_NAME = 'state-spaces/mamba-130m-hf'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16, device_map='auto')

# Plot limits: keep the figures small and comparable across texts.
MAX_TOKENS = 30   # tokens of each text fed to the model
NUM_PLOT_DIMS = 32  # leading hidden dimensions shown in the heatmap

# Test texts
texts = {
    'Repetitive': 'the the the the the the the the the the',
    'Structured (Code)': 'def foo(x):\n    return x * 2\n\ndef bar(y):\n    return y + 1',
    'Diverse': 'Apple banana computer dance elephant future galaxy horizon ice jazz',
}

# squeeze=False keeps `axes` two-dimensional even when `texts` has a single
# entry, so the axes[row, col] indexing below never breaks.
fig, axes = plt.subplots(len(texts), 2, figsize=(14, 4*len(texts)), squeeze=False)

for idx, (label, text) in enumerate(texts.items()):
    tokens = tokenizer.encode(text, add_special_tokens=False)[:MAX_TOKENS]
    # device_map='auto' decides where the weights live, so place the inputs on
    # the model's own device instead of assuming it matches a global choice.
    input_ids = torch.tensor([tokens], device=model.device)

    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=True)

    # Entry 1 of the returned hidden-state tuple, first (only) batch element.
    # NOTE(review): which layer entry 1 corresponds to depends on the
    # transformers Mamba implementation -- confirm before interpreting it as a
    # specific block's output.
    # The model runs in bfloat16, which NumPy cannot represent; cast to float32
    # before leaving torch, otherwise .numpy() raises TypeError.
    layer_hidden = outputs.hidden_states[1][0].float().cpu().numpy()
    hidden = layer_hidden[:, :NUM_PLOT_DIMS]

    # Heatmap
    axes[idx, 0].imshow(hidden.T, aspect='auto', cmap='RdBu_r')
    axes[idx, 0].set_title(f'{label}: State Evolution')
    axes[idx, 0].set_ylabel('State Dim')

    # State magnitude: L2 norm over the full hidden dimension at each position
    # (not just the NUM_PLOT_DIMS slice shown in the heatmap).
    magnitudes = np.linalg.norm(layer_hidden, axis=1)
    axes[idx, 1].plot(magnitudes, 'b-', linewidth=2)
    axes[idx, 1].set_title(f'{label}: State Magnitude')
    axes[idx, 1].set_ylabel('L2 Norm')

axes[-1, 0].set_xlabel('Token Position')
axes[-1, 1].set_xlabel('Token Position')
plt.tight_layout()
plt.show()

print('\n Key Observations:')
print('1. Repetitive text: State converges to stable pattern')
print('2. Structured code: Regular patterns matching code structure')
print('3. Diverse text: High variance as model processes new concepts')

In [None]:
# Cleanup
import gc

# Drop the notebook's references to the model and tokenizer.
del model, tokenizer
# Collect garbage first so the tensors are actually freed, then ask the CUDA
# caching allocator to release its now-unused blocks. Calling empty_cache()
# before gc.collect() would leave memory held by uncollected objects cached.
gc.collect()
torch.cuda.empty_cache()
print('Cleanup complete!')