In [2]:
# Install required packages

!pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2
!pip install datasets==2.15.0 tokenizers==0.13.3
!pip install altair==5.1.1 pandas matplotlib ipywidgets tqdm




In [4]:
# Verify NumPy version
# NOTE(review): presumably checked because the torch==2.0.1 pinned in the install
# cell needs NumPy 1.x — confirm. The version is only printed here, not enforced.
import numpy as np
print(f"NumPy version: {np.__version__}")


NumPy version: 1.24.3


In [5]:
import torch
import torch.nn as nn
# `model` and `config` are presumably project-local modules (not installed by the pip cell)
from model import build_transformer
from config import get_config, get_weights_file_path
import pandas as pd
import altair as alt
import warnings
# NOTE(review): this silences *every* warning for the rest of the kernel session
# (including deprecation warnings from torch/pandas/altair); consider a narrower filter.
warnings.filterwarnings("ignore")
from pathlib import Path

In [6]:
# Define the device
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cpu


In [7]:
# Define causal mask function
def causal_mask(size):
    """Build the decoder's look-ahead mask.

    Returns a boolean tensor of shape ``(1, size, size)`` that is True on and
    below the main diagonal (positions a query may attend to) and False above
    it (future positions).
    """
    lower_triangle = torch.ones(1, size, size).tril()
    return lower_triangle.bool()

# Load the configuration and one tokenizer per language (source, then target)
config = get_config()
from tokenizers import Tokenizer
tokenizer_src, tokenizer_tgt = (
    Tokenizer.from_file(config['tokenizer_file'].format(config[lang_key]))
    for lang_key in ('lang_src', 'lang_tgt')
)

# Build the transformer; source and target use the same configured sequence length
model = build_transformer(
    tokenizer_src.get_vocab_size(),
    tokenizer_tgt.get_vocab_size(),
    config["seq_len"],
    config['seq_len'],
    d_model=config['d_model'],
)
model = model.to(device)

# Restore pretrained weights; change WEIGHTS_EPOCH to load a different checkpoint.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
WEIGHTS_EPOCH = "30"
model_filename = get_weights_file_path(config, WEIGHTS_EPOCH)
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state['model_state_dict'])
model.eval()
print(f"Successfully loaded model from {model_filename}")

Successfully loaded model from opus_books_weights/tmodel_30.pt


In [8]:
# After loading the model but before using it, add a forward method
def forward_hook(self, src, tgt, src_mask, tgt_mask):
    """Full encoder-decoder pass, meant to be bound to the model as ``forward``.

    Encodes ``src`` under ``src_mask``, decodes ``tgt`` against that encoding
    (using both masks), and returns the result of ``self.project`` on the
    decoder output.
    """
    return self.project(
        self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)
    )

# Bind forward_hook to this model *instance* (its class is left untouched), so that
# calling model(...) dispatches to it through nn.Module's forward lookup.
import types
bound_forward = types.MethodType(forward_hook, model)
setattr(model, "forward", bound_forward)

In [9]:

# Create batch with sample data for visualization
def create_sample_batch(src_text="Hello, how are you?", tgt_text="Ciao, come stai?"):
    """Build a one-sentence batch, run the model on it, and return it with readable tokens.

    The forward pass is run only for its side effect of populating the
    attention scores that the visualization helpers read afterwards.

    Args:
        src_text: Source-language sentence for the encoder.
        tgt_text: Target-language sentence fed to the decoder as its input.

    Returns:
        ``(batch, encoder_input_tokens, decoder_input_tokens)``. ``batch`` holds
        the id tensors (shape ``(1, seq_len)``), their masks, and the raw texts.
        The token lists are the padded inputs mapped back to token strings.

    Raises:
        ValueError: if either sentence, including its special tokens, needs more
            than ``config['seq_len']`` positions.
    """
    max_len = config['seq_len']

    # Look up special-token ids in the tokenizer that encodes each side:
    # source vocabulary for the encoder, target vocabulary for the decoder.
    src_sos_id = tokenizer_src.token_to_id('[SOS]')
    src_eos_id = tokenizer_src.token_to_id('[EOS]')
    src_pad_id = tokenizer_src.token_to_id('[PAD]')
    tgt_sos_id = tokenizer_tgt.token_to_id('[SOS]')
    tgt_pad_id = tokenizer_tgt.token_to_id('[PAD]')

    # Encoder input: [SOS] + tokens + [EOS]; decoder input: [SOS] + tokens (no [EOS]).
    encoder_ids = [src_sos_id] + tokenizer_src.encode(src_text).ids + [src_eos_id]
    decoder_ids = [tgt_sos_id] + tokenizer_tgt.encode(tgt_text).ids
    for side, ids in (("source", encoder_ids), ("target", decoder_ids)):
        if len(ids) > max_len:
            raise ValueError(
                f"{side} sentence needs {len(ids)} positions but seq_len is {max_len}"
            )

    # Right-pad both sequences to the model's fixed length.
    encoder_ids = encoder_ids + [src_pad_id] * (max_len - len(encoder_ids))
    decoder_ids = decoder_ids + [tgt_pad_id] * (max_len - len(decoder_ids))

    encoder_input = torch.tensor([encoder_ids], dtype=torch.long, device=device)  # (1, seq_len)
    decoder_input = torch.tensor([decoder_ids], dtype=torch.long, device=device)  # (1, seq_len)

    # (1, 1, 1, seq_len): 1 for real source tokens, 0 for padding.
    encoder_mask = (encoder_input != src_pad_id).unsqueeze(1).unsqueeze(1).int()
    # (1, seq_len, seq_len): target padding mask combined with the causal mask.
    decoder_mask = (decoder_input != tgt_pad_id).unsqueeze(1).int() & causal_mask(decoder_input.size(1)).to(device)

    # Map the padded id lists back to token strings for the chart axes.
    encoder_input_tokens = [tokenizer_src.id_to_token(idx) for idx in encoder_ids]
    decoder_input_tokens = [tokenizer_tgt.id_to_token(idx) for idx in decoder_ids]

    # Run the model only to populate the attention scores.
    with torch.no_grad():
        model(encoder_input, decoder_input, encoder_mask, decoder_mask)

    batch = {
        "encoder_input": encoder_input,
        "decoder_input": decoder_input,
        "encoder_mask": encoder_mask,
        "decoder_mask": decoder_mask,
        "src_text": src_text,
        "tgt_text": tgt_text
    }

    return batch, encoder_input_tokens, decoder_input_tokens

In [10]:
# Functions for visualization
def mtx2df(m, max_row, max_col, row_tokens, col_tokens):
    """Flatten the top-left corner of an attention matrix into a long-form DataFrame.

    Args:
        m: 2-D matrix with ``.shape`` that supports ``m[r, c]`` indexing
            (e.g. a torch tensor or NumPy array).
        max_row: Number of leading rows to keep.
        max_col: Number of leading columns to keep.
        row_tokens: Row labels; rows past the end of the list get ``"<blank>"``.
        col_tokens: Column labels; same fallback.

    Returns:
        DataFrame with one record per kept cell and columns ``row``, ``column``,
        ``value``, ``row_token``, ``col_token``. Token labels look like
        ``"007 tok"``; the zero-padded index makes lexicographic order match
        position order for indices below 1000.
    """
    # Iterate only over the kept corner instead of the whole matrix and filtering.
    n_rows = min(m.shape[0], max_row)
    n_cols = min(m.shape[1], max_col)

    def _label(idx, tokens):
        # "%.3d" zero-pads the integer index to at least three digits.
        return "%.3d %s" % (idx, tokens[idx] if len(tokens) > idx else "<blank>")

    return pd.DataFrame(
        [
            (r, c, float(m[r, c]), _label(r, row_tokens), _label(c, col_tokens))
            for r in range(n_rows)
            for c in range(n_cols)
        ],
        columns=["row", "column", "value", "row_token", "col_token"],
    )

def get_attn_map(attn_type: str, layer: int, head: int):
    """Return the stored attention matrix of one layer/head of the global ``model``.

    Args:
        attn_type: ``"encoder"`` (encoder self-attention), ``"decoder"``
            (decoder self-attention) or ``"encoder-decoder"`` (the decoder's
            cross-attention block).
        layer: Index into ``model.encoder.layers`` / ``model.decoder.layers``.
        head: Head index into the stored scores.

    Returns:
        ``attention_scores[0, head].data`` of the selected attention block.
        The leading index presumably selects the first batch element; confirm
        against the model code.

    Raises:
        ValueError: if ``attn_type`` is not one of the three names above.
    """
    if attn_type == "encoder":
        block = model.encoder.layers[layer].self_attention_block
    elif attn_type == "decoder":
        block = model.decoder.layers[layer].self_attention_block
    elif attn_type == "encoder-decoder":
        block = model.decoder.layers[layer].cross_attention_block
    else:
        # Without this branch an unknown name failed later with an UnboundLocalError.
        raise ValueError(
            f"Unknown attn_type {attn_type!r}; expected 'encoder', 'decoder' or 'encoder-decoder'"
        )
    # attention_scores is only available after a forward pass has run.
    return block.attention_scores[0, head].data

def attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len):
    """Render one layer/head attention matrix as an interactive Altair heatmap.

    The matrix from ``get_attn_map`` is cropped to ``max_sentence_len`` rows
    and columns by ``mtx2df`` before plotting. Axis titles are blank, and each
    cell's tooltip lists its indices, value and tokens.
    """
    scores = get_attn_map(attn_type, layer, head)
    cells = mtx2df(scores, max_sentence_len, max_sentence_len, row_tokens, col_tokens)

    chart = alt.Chart(data=cells).mark_rect()
    chart = chart.encode(
        x=alt.X("col_token", axis=alt.Axis(title="")),
        y=alt.Y("row_token", axis=alt.Axis(title="")),
        color="value",
        tooltip=["row", "column", "value", "row_token", "col_token"],
    )
    chart = chart.properties(height=400, width=400, title=f"Layer {layer} Head {head}")
    return chart.interactive()


In [11]:

def get_all_attention_maps(attn_type: str, layers: list, heads: list, row_tokens: list, col_tokens, max_sentence_len: int):
    """Lay out attention heatmaps as a grid: one row per layer, one column per head.

    Each cell is built by ``attn_map`` with the same tokens and length limit.
    The per-layer rows are horizontal concatenations, stacked vertically.
    """
    grid_rows = [
        alt.hconcat(
            *(
                attn_map(attn_type, layer_idx, head_idx, row_tokens, col_tokens, max_sentence_len)
                for head_idx in heads
            )
        )
        for layer_idx in layers
    ]
    return alt.vconcat(*grid_rows)

def _unpadded_len(tokens, pad_token="[PAD]"):
    """Number of tokens before the first ``pad_token`` (the full length if there is none)."""
    return tokens.index(pad_token) if pad_token in tokens else len(tokens)


# Create sample and visualize
try:
    batch, encoder_input_tokens, decoder_input_tokens = create_sample_batch()
    print(f'Source: {batch["src_text"]}')
    print(f'Target: {batch["tgt_text"]}')

    # Useful length (excluding padding) is measured per side: the source and
    # target sentences generally tokenize to different lengths.
    sentence_len = _unpadded_len(encoder_input_tokens)
    tgt_sentence_len = _unpadded_len(decoder_input_tokens)
    print(f"Sentence length: {sentence_len}")
    print(f"Target sentence length: {tgt_sentence_len}")

    # Define which layers and heads to visualize
    layers = [0, 1, 2]  # Adjust based on your model
    heads = [0, 1, 2, 3]  # Using fewer heads for clearer display
    display_cap = 20  # Never draw more than this many tokens per axis
    max_display_len = min(display_cap, sentence_len)  # Encoder axis
    tgt_display_len = min(display_cap, tgt_sentence_len)  # Decoder axis
    # Cross-attention has decoder rows and encoder columns but takes a single limit;
    # use the longer side so no real token of either sentence is cropped.
    cross_display_len = max(max_display_len, tgt_display_len)

    # Visualize Encoder Self-Attention
    print("Generating Encoder Self-Attention Visualization...")
    encoder_attn = get_all_attention_maps("encoder", layers, heads, encoder_input_tokens, encoder_input_tokens, max_display_len)
    display(encoder_attn)

    # Visualize Decoder Self-Attention
    print("Generating Decoder Self-Attention Visualization...")
    decoder_attn = get_all_attention_maps("decoder", layers, heads, decoder_input_tokens, decoder_input_tokens, tgt_display_len)
    display(decoder_attn)

    # Visualize Encoder-Decoder Cross-Attention
    print("Generating Encoder-Decoder Cross-Attention Visualization...")
    cross_attn = get_all_attention_maps("encoder-decoder", layers, heads, decoder_input_tokens, encoder_input_tokens, cross_display_len)
    display(cross_attn)

except Exception as e:
    print(f"Error during visualization: {e}")

Source: Hello, how are you?
Target: Ciao, come stai?
Sentence length: 8
Generating Encoder Self-Attention Visualization...


Generating Decoder Self-Attention Visualization...


Generating Encoder-Decoder Cross-Attention Visualization...
