In [None]:
%pip install torch numpy matplotlib

import sys
import os
from pathlib import Path

# Make the local `transformer` package importable. The project root is taken to be the
# parent of the working directory. NOTE(review): this assumes the notebook is launched
# from a direct subdirectory of the project (e.g. notebooks/) — confirm.
notebook_path = Path.cwd()
project_root = notebook_path.parent

# Put the project root first so the local package takes precedence over any installed
# package with the same name. Remove earlier copies first so re-running this cell never
# accumulates duplicate sys.path entries.
while str(project_root) in sys.path:
    sys.path.remove(str(project_root))
sys.path.insert(0, str(project_root))

import torch
import matplotlib.pyplot as plt
import seaborn as sns

from transformer.config import TransformerConfig
from transformer.model import Transformer


In [None]:
# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Restore the trained checkpoint, mapping its tensors onto the chosen device.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code — only
# load checkpoints you trust. weights_only=True may be an option if the stored config
# holds only plain Python types; confirm before switching.
checkpoint_path = project_root.joinpath("checkpoints", "transformer_model.pt")
checkpoint = torch.load(checkpoint_path, map_location=device)

# Rebuild the architecture from the saved hyper-parameters, restore the trained weights,
# then put the model in evaluation mode.
config = TransformerConfig(**checkpoint["config"])
model = Transformer(config).to(device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Echo every attribute of the rebuilt configuration.
print("\nModel configuration:")
for name, setting in vars(config).items():
    print(f"{name}: {setting}")


In [None]:
# Seed torch's global RNG so the random source batch below is reproducible on
# Restart & Run All. torch.manual_seed seeds the CPU and CUDA generators. It also fixes
# any sampling generate() draws from torch's global RNG; temperature=0.7 suggests
# stochastic decoding, but that is an assumption — confirm in Transformer.generate.
SEED = 0
torch.manual_seed(SEED)

# Random source sequences of length config.max_seq_len.
# NOTE(review): ids start at MIN_TOKEN_ID, presumably so the lowest ids (special tokens
# such as PAD/BOS/EOS) never appear in the input. Confirm against the vocabulary.
MIN_TOKEN_ID = 4
batch_size = 3
src = torch.randint(
    MIN_TOKEN_ID, config.vocab_size, (batch_size, config.max_seq_len), device=device
)

# Generate sequences; no autograd bookkeeping is needed for inference.
with torch.no_grad():
    output_sequences, attention_weights = model.generate(
        src,
        max_length=config.max_seq_len,
        temperature=0.7,
    )

print("Input sequences:")
print(src)
print("\nGenerated sequences:")
print(output_sequences)


In [None]:
def plot_attention(attention_weights, title, ax=None):
    """Plot a single 2-D attention map as a heatmap.

    Args:
        attention_weights: 2-D torch tensor. Rows are drawn as query positions (y axis)
            and columns as key positions (x axis).
        title: Title for the plot.
        ax: Optional matplotlib Axes to draw into.
            - If omitted, a new 10x8 figure is created and shown immediately.
            - If given, nothing is shown; the caller decides when to display the figure
              (useful for multi-panel layouts).

    Raises:
        ValueError: if ``attention_weights`` is not 2-D.
    """
    if attention_weights.dim() != 2:
        raise ValueError(
            "plot_attention expects a 2-D (query, key) tensor, got shape "
            f"{tuple(attention_weights.shape)}"
        )

    owns_figure = ax is None
    if owns_figure:
        _, ax = plt.subplots(figsize=(10, 8))

    n_queries, n_keys = attention_weights.shape
    sns.heatmap(
        # detach() lets tensors that still track gradients be converted to numpy.
        attention_weights.detach().cpu().numpy(),
        cmap='viridis',
        xticklabels=range(n_keys),
        yticklabels=range(n_queries),
        ax=ax,
    )
    ax.set_title(title)
    ax.set_xlabel('Key position')
    ax.set_ylabel('Query position')

    if owns_figure:
        plt.show()

# Encoder self-attention. Indexing 0 on each of the three leading axes should leave a
# 2-D (query x key) map for the heatmap.
# NOTE(review): the original comments read those axes as (layer, head, batch). Confirm
# the layout against what Transformer.generate returns.
encoder_attention = attention_weights['encoder_attention'][0, 0, 0]
plot_attention(
    encoder_attention,
    title='Encoder Self-Attention (Layer 0, Head 0)',
)

# Decoder cross-attention, sliced the same way.
decoder_cross_attention = attention_weights['decoder_cross_attention'][0, 0, 0]
plot_attention(
    decoder_cross_attention,
    title='Decoder Cross-Attention (Layer 0, Head 0)',
)
