In [None]:
import torch
import torch.nn as nn
from model import Transformer
from config import get_config, get_weights_file_path
from train import get_model, get_dataset, greedy_decode
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import warnings
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Select the compute device: prefer a CUDA GPU when torch can see one, otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
# Build everything the visualisations need from the project helpers:
# the run configuration, the dataloaders with their source/target vocabularies, and the model.
config = get_config()
(train_dataloader, valid_dataloader, vocab_src, vocab_tar) = get_dataset(config)
# The model is sized by the two vocabularies and then moved to the selected device.
model = get_model(
    config,
    vocab_src.get_vocab_size(),
    vocab_tar.get_vocab_size(),
)
model = model.to(device)

In [None]:
# Load the pretrained weights written for the last training epoch (index num_epochs - 1).
model_filename = get_weights_file_path(config, str(config['num_epochs']-1))
# map_location remaps the stored tensors onto the device selected for this notebook, so a
# checkpoint saved on a GPU can still be opened when the notebook falls back to the CPU.
# NOTE: torch.load unpickles the file - only open checkpoints that come from a trusted source.
state = torch.load(model_filename, map_location=device)
# Switch to inference mode so that train-only layers (e.g. dropout) are disabled while the
# attention scores are captured for plotting. This is done before the final statement so the
# cell still displays the result of load_state_dict.
model.eval()
model.load_state_dict(state['model_state_dict'])

In [None]:
def load_next_batch():
    """Fetch the next validation batch, decode it greedily and return its tokens.

    Running the model here is what refreshes the attention scores that the plotting
    helpers in this notebook read afterwards.

    A single iterator over ``valid_dataloader`` is kept alive between calls (stored as an
    attribute on this function) so that successive calls really advance through the
    validation set. Calling ``next(iter(valid_dataloader))`` every time would restart the
    loader on each call instead. When the loader is exhausted, or ``valid_dataloader`` has
    been rebuilt, a fresh iterator is started.

    Returns:
        tuple: ``(batch, encoder_input_tokens, decoder_input_tokens, model_output)`` where
            ``batch`` is the raw item produced by the dataloader, the two token lists are
            the string tokens of the first (only) sample of the encoder / decoder input,
            and ``model_output`` is whatever ``greedy_decode`` returns for this sample.

    Raises:
        AssertionError: if the validation batch does not contain exactly one sample.
        StopIteration: if ``valid_dataloader`` yields no batches at all.
    """
    # Reuse the iterator from the previous call as long as it belongs to the current loader.
    cached = getattr(load_next_batch, "_iter_state", None)
    if cached is None or cached[0] is not valid_dataloader:
        cached = (valid_dataloader, iter(valid_dataloader))
    try:
        batch = next(cached[1])
    except StopIteration:
        # End of the validation set: wrap around and start another pass.
        cached = (valid_dataloader, iter(valid_dataloader))
        batch = next(cached[1])
    load_next_batch._iter_state = cached

    encoder_input = batch["encoder_input"].to(device) # torch.tensor: (batch_size=1, seq_len)
    encoder_mask = batch["encoder_mask"].to(device) # torch.tensor: (batch_size=1, 1, 1, seq_len)
    decoder_input = batch["decoder_input"].to(device) # torch.tensor: (batch_size=1, seq_len)

    # Fail fast: everything below only looks at sample 0, so validation must use batch_size = 1.
    assert encoder_input.size(0) == 1, "Batch size for validation must be 1 ..."

    # .tolist() yields plain Python ints for the vocabulary lookups.
    encoder_input_tokens = [vocab_src.id_to_token(idx) for idx in encoder_input[0].tolist()] # (seq_len)
    decoder_input_tokens = [vocab_tar.id_to_token(idx) for idx in decoder_input[0].tolist()] # (seq_len)

    # Inference only: no autograd graph is needed for decoding or for the stored attention.
    with torch.no_grad():
        model_output = greedy_decode(model, encoder_input, encoder_mask, vocab_src, vocab_tar, config["seq_len"], device)

    return batch, encoder_input_tokens, decoder_input_tokens, model_output

In [None]:
# convert a 2D matrix into a structured Pandas DataFrame, with added context from row_tokens and col_tokens
def mtx2df(m, max_row, max_col, row_tokens, col_tokens):
    """Flatten the top-left corner of a 2D matrix into a long-format DataFrame.

    One record is emitted, in row-major order, for every cell ``(r, c)`` of ``m`` with
    ``r < max_row`` and ``c < max_col``.

    Args:
        m: 2D matrix exposing ``.shape`` and ``m[r, c]`` indexing.
        max_row: exclusive upper bound on the row indices to keep.
        max_col: exclusive upper bound on the column indices to keep.
        row_tokens: tokens used to label rows; missing positions are labelled "<blank>".
        col_tokens: tokens used to label columns; missing positions are labelled "<blank>".

    Returns:
        pd.DataFrame with columns ``row``, ``column``, ``value``, ``row_token``, ``col_token``.
    """

    def _label(index, tokens):
        # zero-padded position followed by its token, or a placeholder past the end of the list
        token = tokens[index] if index < len(tokens) else "<blank>"
        return "%.3d %s" % (index, token)

    records = []
    for r in range(m.shape[0]):
        if not r < max_row:
            continue
        for c in range(m.shape[1]):
            if not c < max_col:
                continue
            records.append(
                (r, c, float(m[r, c]), _label(r, row_tokens), _label(c, col_tokens))
            )

    return pd.DataFrame(
        records,
        columns=["row", "column", "value", "row_token", "col_token"],
    )

In [None]:
def get_attention_map(attention_type: str, layer: int, head: int):
    """Return the attention matrix of one head, as stored by the model's last forward pass.

    Where the scores live in the model:
        model:           Transformer
        encoder:         Encoder
        layers:          nn.ModuleList
        Layers[layer]:   EncoderBlock
        self_attention:  MultiHeadAttention
        cross_attention: MultiHeadAttention
        for class MultiHeadAttention, we have self.attention_scores

    Args:
        attention_type: "encoder" (encoder self-attention), "decoder" (decoder
            self-attention) or "encoder-decoder" (decoder cross-attention).
        layer: index into the encoder / decoder ``layers`` list.
        head: index of the attention head.

    Returns:
        The detached 2D attention map of the first batch element for the requested head;
        per the original author's note, attention_scores is
        (batch_size, num_head, seq_len, seq_len), giving a (seq_len, seq_len) result.

    Raises:
        ValueError: if ``attention_type`` is not one of the three supported names.
    """
    if attention_type == "encoder":
        attention_module = model.encoder.layers[layer].self_attention
    elif attention_type == "decoder":
        attention_module = model.decoder.layers[layer].self_attention
    elif attention_type == "encoder-decoder":
        attention_module = model.decoder.layers[layer].cross_attention
    else:
        # Without this branch an unknown name would surface later as an UnboundLocalError.
        raise ValueError(
            f"Unknown attention_type {attention_type!r}; "
            "expected 'encoder', 'decoder' or 'encoder-decoder'."
        )

    attention = attention_module.attention_scores

    return attention[0, head].detach() # return the attention map (a 2D matrix) for batch element 0

In [None]:
# visualize an attention map using matplotlib/seaborn
def plot_attention_map(attention_type, layer, head, row_tokens, col_tokens, max_sentence_len):

    """
    Draw the attention map of a single layer/head as a labelled heatmap.

    args:
        attention_type (str)  : 'encoder', 'decoder', or 'encoder-decoder'
        layer (int)           : layer index
        head (int)            : head index
        row_tokens (list[str]): list of tokens for y-axis (queries)
        col_tokens (list[str]): list of tokens for x-axis (keys)
        max_sentence_len (int): number of leading tokens (rows and columns) to display

    returns:
        None - the figure is shown with plt.show()
    """

    # keep only the leading tokens that should appear on the axes
    row_tokens = row_tokens[:max_sentence_len]
    col_tokens = col_tokens[:max_sentence_len]

    # get attention matrix
    attention_matrix = get_attention_map(attention_type, layer, head)
    # Crop the matrix to the labelled region before plotting. Otherwise the whole stored
    # matrix is drawn while only the first few ticks carry a token label.
    attention = attention_matrix[:len(row_tokens), :len(col_tokens)].cpu().numpy()

    # plot attention map on an explicit axes object instead of relying on the current axes
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(attention, xticklabels=col_tokens, yticklabels=row_tokens, cmap="viridis",
                square=True, cbar=True, linewidths=0.5, linecolor='gray', ax=ax)

    ax.set_title(f"Layer {layer} - Head {head} Attention Map")
    ax.set_xlabel("col_tokens")
    ax.set_ylabel("row_tokens")
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    plt.setp(ax.get_yticklabels(), rotation=0)
    fig.tight_layout()
    plt.show()

In [None]:
# display a grid of attention maps
def get_all_attention_maps(attention_type: str,
                           layers: list[int],
                           heads: list[int],
                           row_tokens: list[str],
                           col_tokens: list[str],
                           max_sentence_len: int,
                           figsize_per_plot=(4, 4),
                           cmap="viridis"):
    """Show a grid of attention heatmaps: one row per layer, one column per head.

    Args:
        attention_type: 'encoder', 'decoder', or 'encoder-decoder'.
        layers: layer indices to plot (grid rows).
        heads: head indices to plot (grid columns).
        row_tokens: tokens labelling the y-axis of every heatmap.
        col_tokens: tokens labelling the x-axis of every heatmap.
        max_sentence_len: number of leading tokens (rows and columns) to display.
        figsize_per_plot: (width, height) in inches of each individual heatmap.
        cmap: colormap passed to seaborn.

    Returns:
        None - the figure is shown with plt.show().

    Raises:
        ValueError: if ``layers`` or ``heads`` is empty.
    """

    if not layers or not heads:
        raise ValueError("Both 'layers' and 'heads' lists must be non-empty.")

    num_rows = len(layers)
    num_cols = len(heads)
    figsize = (figsize_per_plot[0] * num_cols, figsize_per_plot[1] * num_rows)

    # The labels are the same for every subplot, so truncate them once up front.
    row_tokens = row_tokens[:max_sentence_len]
    col_tokens = col_tokens[:max_sentence_len]

    # squeeze=False keeps `axes` two-dimensional even for a single row or column
    fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)

    for i, layer in enumerate(layers):
        for j, head in enumerate(heads):
            ax = axes[i, j]

            # get attention map (already detached) and crop it to the labelled region, so
            # the heatmap shows exactly the cells that have a token label
            attention_matrix = get_attention_map(attention_type, layer, head)
            attention = attention_matrix[:len(row_tokens), :len(col_tokens)].cpu().numpy()

            # plot heatmap
            sns.heatmap(attention, xticklabels=col_tokens, yticklabels=row_tokens, cmap=cmap,
                        square=True, cbar=(j == num_cols-1), ax=ax, linewidths=0.5, linecolor='gray') # only last column has colorbar
            ax.set_title(f"L{layer} - H{head}", fontsize=10)
            ax.set_xlabel("")
            ax.set_ylabel("")
            ax.tick_params(axis='x', labelrotation=45)
            ax.tick_params(axis='y', labelrotation=0)

    fig.suptitle(f"{attention_type.title()} Attention Maps", fontsize=14) # set global title
    fig.tight_layout(rect=[0, 0, 1, 0.97]) # leave room at the top for the global title
    plt.show()