In [1]:
import torch
import torch.nn as nn
from model import Transformer
from train import get_model, get_ds, greedy_decode
import altair as alt
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings("ignore")

In [2]:
# Pick the compute device once for the whole notebook: a CUDA GPU when one is visible,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [3]:
from config import get_config, get_weights_file_path

config = get_config()
train_dataloader, val_dataloader, vocab_src, vocab_tgt = get_ds(config)
model = get_model(config, vocab_src.get_vocab_size(), vocab_tgt.get_vocab_size()).to(device)

# Checkpoint to visualise: the epoch tag handed to get_weights_file_path.
# To load a specific file instead, set model_filename directly,
# e.g. model_filename = './weights/L1H6D36maxPos200_99.pt'
CHECKPOINT_EPOCH = "01"
model_filename = get_weights_file_path(config, CHECKPOINT_EPOCH)

# map_location remaps the stored tensors onto the current device, so a checkpoint that was
# saved on a GPU also loads on a CPU-only machine.
state = torch.load(model_filename, map_location=device)
# This notebook only runs inference: put the model in eval mode (turns off dropout and other
# training-only behaviour) before any forward pass records attention scores.
model.eval()
model.load_state_dict(state['model_state_dict'])

Max length of source sentence: 100
Max length of target sentence: 100


<All keys matched successfully>

In [4]:
# Optional static (matplotlib) view of decoder layer-0 cross-attention: one heatmap per head
# plus the head-averaged map. Disabled by default; the interactive Altair charts below are the
# main view. It reads attention_scores stored on the model, which presumably need a forward
# pass first (e.g. load_next_batch() further down) — on a fresh kernel this cell runs before that.
PLOT_MATPLOTLIB_CROSS_ATTENTION = False
if PLOT_MATPLOTLIB_CROSS_ATTENTION:
    import matplotlib.pyplot as plt

    cross_scores = model.decoder.layers[0].cross_attention_block.attention_scores
    num_heads = config['num_heads']
    # Index 0 of the leading (presumably batch) dimension, then one 2-D map per head.
    head_maps = [cross_scores[0][h].detach().cpu().numpy() for h in range(num_heads)]
    # Sum every head first, then divide exactly once to get the mean.
    mean_map = sum(head_maps) / num_heads

    panels = [(f"head {h}", hm) for h, hm in enumerate(head_maps)]
    panels.append(("mean over heads", mean_map))
    for title, matrix in panels:
        fig, ax = plt.subplots(figsize=(4, 4))
        image = ax.imshow(matrix, cmap='coolwarm', interpolation='nearest')
        fig.colorbar(image, ax=ax)
        ax.set_title(f"Decoder layer 0 cross-attention: {title}")
        plt.show()
        # Close each figure so a many-head loop does not accumulate open figures.
        plt.close(fig)


In [5]:
def load_next_batch(return_model_out: bool = False):
    """Fetch one validation example and run greedy decoding on it.

    The decode is what fills the model's attention blocks with scores for this example
    (presumably: greedy_decode runs the model forward; the attention-map helpers then read
    ``attention_scores`` from the model).

    Note: ``next(iter(val_dataloader))`` builds a fresh iterator on every call, so each call
    yields the loader's first batch; successive calls only differ if val_dataloader shuffles
    (not visible here — TODO confirm).

    Args:
        return_model_out: when True, also return the output of greedy_decode.

    Returns:
        ``(batch, encoder_input_tokens, decoder_input_tokens)``, with ``model_out`` appended
        when ``return_model_out`` is True. The token lists are the vocabulary tokens of the
        first (only) example's ``encoder_input`` and ``decoder_input`` ids.

    Raises:
        AssertionError: the validation batch size is not 1.
    """
    batch = next(iter(val_dataloader))
    encoder_input = batch["encoder_input"].to(device)
    encoder_mask = batch["encoder_mask"].to(device)
    # Only used to build token labels, so it never needs to go to the device.
    decoder_input = batch["decoder_input"]

    # Check the batch shape before doing any work on it.
    assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

    encoder_input_tokens = [vocab_src.id_to_token(idx) for idx in encoder_input[0].cpu().numpy()]
    decoder_input_tokens = [vocab_tgt.id_to_token(idx) for idx in decoder_input[0].cpu().numpy()]

    # Pure inference: disable training-only behaviour (e.g. dropout) and skip autograd bookkeeping.
    model.eval()
    with torch.no_grad():
        model_out = greedy_decode(
            model, encoder_input, encoder_mask, vocab_src, vocab_tgt, config['seq_len'], device)

    if return_model_out:
        return batch, encoder_input_tokens, decoder_input_tokens, model_out
    return batch, encoder_input_tokens, decoder_input_tokens

In [6]:
def _axis_label(index, tokens):
    """Format one axis label as a zero-padded position plus its token ('<blank>' past the end)."""
    return "%.3d %s" % (index, tokens[index] if len(tokens) > index else "<blank>")


def mtx2df(m, max_row, max_col, row_tokens, col_tokens):
    """Flatten the top-left ``max_row`` x ``max_col`` corner of a 2-D matrix into a long DataFrame.

    Args:
        m: 2-D numpy array or torch tensor (any device) indexable as ``m[r, c]``.
        max_row: maximum number of rows to keep.
        max_col: maximum number of columns to keep.
        row_tokens: labels for the rows; positions beyond its length are labelled '<blank>'.
        col_tokens: labels for the columns; positions beyond its length are labelled '<blank>'.

    Returns:
        DataFrame with one record per kept cell and columns
        ``row, column, value, row_token, col_token``. Labels are "<3-digit position> <token>",
        so sorting them alphabetically keeps positional order.
    """
    # Bring a torch tensor to host memory once; indexing a CUDA tensor element by element
    # would otherwise cost one device-to-host copy per cell.
    if isinstance(m, torch.Tensor):
        m = m.detach().cpu().numpy()

    # Bound the loops directly instead of scanning the whole matrix and filtering.
    n_rows = min(m.shape[0], max_row)
    n_cols = min(m.shape[1], max_col)
    row_labels = [_axis_label(r, row_tokens) for r in range(n_rows)]
    col_labels = [_axis_label(c, col_tokens) for c in range(n_cols)]

    return pd.DataFrame(
        [
            (r, c, float(m[r, c]), row_labels[r], col_labels[c])
            for r in range(n_rows)
            for c in range(n_cols)
        ],
        columns=["row", "column", "value", "row_token", "col_token"],
    )

def get_attn_map(attn_type: str, layer: int, head: int):
    """Return the stored attention-score matrix of one head in one layer of the global ``model``.

    Args:
        attn_type: "encoder" (encoder self-attention), "decoder" (decoder self-attention)
            or "encoder-decoder" (the decoder's cross-attention block).
        layer: index into ``model.encoder.layers`` / ``model.decoder.layers``.
        head: head index, used on the second dimension of ``attention_scores``.

    Returns:
        ``attention_scores[0, head]`` detached from autograd (same storage and device).
        Index 0 is the leading dimension — presumably the batch; TODO confirm against model.py.

    Raises:
        ValueError: ``attn_type`` is not one of the three supported values.
        RuntimeError: the selected block has no recorded scores (e.g. no forward pass ran yet).
    """
    if attn_type == "encoder":
        block = model.encoder.layers[layer].self_attention_block
    elif attn_type == "decoder":
        block = model.decoder.layers[layer].self_attention_block
    elif attn_type == "encoder-decoder":
        block = model.decoder.layers[layer].cross_attention_block
    else:
        # Fail loudly instead of falling through to an UnboundLocalError further down.
        raise ValueError(
            f"attn_type must be 'encoder', 'decoder' or 'encoder-decoder', got {attn_type!r}"
        )

    attn = getattr(block, "attention_scores", None)
    if attn is None:
        raise RuntimeError(
            f"No attention scores recorded for {attn_type} layer {layer}; run a forward pass first."
        )
    return attn[0, head].detach()

def attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len):
    """Build an interactive Altair heatmap for a single attention head.

    The score matrix from ``get_attn_map`` is cropped to ``max_sentence_len`` rows and columns
    by ``mtx2df``. Axes are labelled with ``row_tokens`` / ``col_tokens``, cells are coloured by
    score, and the tooltip shows position, score and both tokens.

    Returns:
        An Altair chart titled "Layer {layer} Head {head}", 400x400, with pan/zoom enabled.
    """
    scores = get_attn_map(attn_type, layer, head)
    long_form = mtx2df(scores, max_sentence_len, max_sentence_len, row_tokens, col_tokens)

    heatmap = alt.Chart(data=long_form).mark_rect()
    heatmap = heatmap.encode(
        x=alt.X("col_token", axis=alt.Axis(title="")),
        y=alt.Y("row_token", axis=alt.Axis(title="")),
        color="value",
        tooltip=["row", "column", "value", "row_token", "col_token"],
    )
    heatmap = heatmap.properties(height=400, width=400, title=f"Layer {layer} Head {head}")
    return heatmap.interactive()

def get_all_attention_maps(attn_type: str, layers: list[int], heads: list[int], row_tokens: list, col_tokens, max_sentence_len: int):
    """Arrange per-head heatmaps in a grid: one horizontal row per layer, one column per head.

    Every cell is produced by ``attn_map`` with the same attention type, token labels and
    crop length. Returns the vertically concatenated Altair chart.
    """
    grid_rows = [
        alt.hconcat(
            *[attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len) for head in heads]
        )
        for layer in layers
    ]
    return alt.vconcat(*grid_rows)

In [7]:
batch, encoder_input_tokens, decoder_input_tokens = load_next_batch()
print(f'Source: {batch["src_text"][0]}')
print(f'Target: {batch["tgt_text"][0]}')

# Source length = number of tokens before the first padding token; if the input has no
# padding at all, it fills the whole sequence. An explicit membership test replaces the old
# bare `except:`, which would also have swallowed unrelated errors (even KeyboardInterrupt).
if "[PAD]" in encoder_input_tokens:
    sentence_len = encoder_input_tokens.index("[PAD]")
else:
    sentence_len = config['seq_len']

Source: G45 G90 G15 G81 G6 G26 G52 G71 G73 G29 G11 G84 G51 G19 G82 G37 G30 G78 G25 G9 G1 G72 G4 G22 G2 G97 G33 G77 G60 G36 G56 G67 G100 G66 G3 G59 G91 G61 G65 G94 G55 G58 G62 G68 G57 G92 G89 G27 G44 G70 G21 G64 G99 G86 G63 G80 G12 G35 G32 G34 G10 G14 G93 G47 G95 G46 G83 G13 G85 G5 G96 G43 G79 G7 G98 G53 G54 G75 G23 G18 G48 G40 G31 G49 G42 G74 G50 G28 G24 G39 G20 G76 G88 G16 G8 G87 G41 G38 G17 G69
Target: G45 G44 G90 G15 G81 G6 G26 G2 G9 G71 G19 G11 G73 G37 G29 G84 G82 G25 G72 G36 G33 G1 G30 G66 G60 G67 G51 G100 G97 G52 G78 G56 G55 G91 G61 G4 G22 G94 G65 G59 G62 G57 G68 G58 G12 G64 G3 G70 G92 G35 G63 G83 G7 G39 G27 G41 G93 G40 G54 G48 G31 G69 G23 G21 G87 G89 G86 G49 G88 G99 G38 G5 G79 G34 G10 G18 G32 G76 G98 G74 G24 G14 G46 G28 G47 G20 G53 G17 G43 G85 G13 G8 G50 G75 G16 G77 G95 G42 G96 G80


In [21]:
# Restrict the grids to layer 0 / head 0. To show every layer and head instead, use:
#   layers = list(range(config['num_layers'])); heads = list(range(config['num_heads']))
layers, heads = [0], [0]

# Encoder self-attention over (at most) the first 40 source tokens
get_all_attention_maps(
    "encoder",
    layers,
    heads,
    encoder_input_tokens,
    encoder_input_tokens,
    min(40, sentence_len),
)


In [22]:
# Decoder self-attention: rows and columns are both decoder positions.
# NOTE(review): the labels are the teacher-forced decoder_input_tokens, while the scores were
# presumably recorded during greedy_decode — labels only line up where the model's prediction
# matches the target. Also, the crop length comes from the *source* sentence.
get_all_attention_maps(
    attn_type="decoder",
    layers=layers,
    heads=heads,
    row_tokens=decoder_input_tokens,
    col_tokens=decoder_input_tokens,
    max_sentence_len=min(40, sentence_len),
)

In [23]:
# Encoder-decoder (cross) attention. In a decoder cross-attention block the queries come from
# the decoder and the keys from the encoder, so each row of the score matrix is a decoder
# position and each column an encoder position: rows get decoder tokens, columns encoder
# tokens (the labels were previously swapped). TODO confirm the score layout against model.py.
# NOTE(review): as with decoder self-attention, the decoder labels are teacher-forced targets,
# while the scores presumably come from greedy decoding.
get_all_attention_maps("encoder-decoder", layers, heads, decoder_input_tokens, encoder_input_tokens, min(40, sentence_len))