In [1]:
import torch
import torch.nn as nn
import altair as alt
import pandas as pd
import numpy as np
import warnings
import re
import gc
warnings.filterwarnings("ignore")


from transformer_model import Transformer
from config import get_config, get_weights_file_path
from train import get_model, get_ds
from validation import greedy_valid_decode

In [2]:
# Load the project configuration returned by get_config(). Later cells read
# config['device'], config['max_len'] and config['temperature'] from it.
config = get_config()

In [3]:
# Select the torch device named by config['device'] and print it so the reader
# can see where the model and the validation batch will be placed.
device = torch.device(config['device'])
print(f'Using device: {device}')

Using device: cpu


In [4]:
# Build the train/validation dataloaders and the source/target tokenizers, then
# create the model sized to both vocabularies and move it to `device`.
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)

# NOTE(review): checkpoint loading below is disabled. Unless get_model() restores
# weights itself (not visible from this notebook -- confirm), every attention map
# further down is produced by a model with freshly initialised weights.
# Re-enable the block to visualise a trained model; when running on CPU consider
# torch.load(model_filename, map_location=device).
# model_filename = get_weights_file_path(config, f"29")
# if model_filename:
#     state = torch.load(model_filename)
#     model.load_state_dict(state['model_state_dict'])

Max length of source sentence: 229
Max length of target sentence: 195


In [5]:
def load_next_valid_batch():
    """Decode one validation batch and return it with its token strings.

    A new iterator over ``val_dataloader`` is created on every call, so the
    function returns the first batch that iterator yields; it does not advance
    through the validation set between calls.

    The batch is run through ``greedy_valid_decode``. ``get_attn_map`` later
    reads the ``attention_score`` attributes of the model's attention blocks,
    which are presumably populated by this forward pass (confirm in the model
    code).

    Returns:
        tuple: ``(batch, encoder_input_tokens, decoder_output_tokens)`` where
        ``batch`` is the raw dataloader batch, ``encoder_input_tokens`` are the
        source-side token strings of the (single) input sentence and
        ``decoder_output_tokens`` are the token strings of the decoded output.

    Raises:
        AssertionError: if the batch contains more than one sentence.
    """
    batch = next(iter(val_dataloader))
    encoder_input = batch["encoder_input"].to(device)
    encoder_mask = batch["encoder_mask"].to(device)

    # The heatmaps index the attention scores with batch position 0 only.
    assert encoder_input.size(0) == 1, 'Batch size must be 1 for validation'

    # Inference only: switch off dropout and gradient tracking in case
    # greedy_valid_decode does not already do so, so the stored attention
    # scores are deterministic and carry no autograd state.
    model.eval()
    with torch.no_grad():
        model_out = greedy_valid_decode(
            model,
            encoder_input,
            encoder_mask,
            tokenizer_tgt,
            device,
            config['max_len'],
            config['temperature'],
        )

    # Token strings used as axis labels in the attention heatmaps.
    # .tolist() yields plain Python ints regardless of the tensor's device.
    encoder_input_tokens = [tokenizer_src.id_to_token(idx) for idx in encoder_input[0].tolist()]
    decoder_output_tokens = [tokenizer_tgt.id_to_token(idx) for idx in model_out.tolist()]

    return batch, encoder_input_tokens, decoder_output_tokens

In [6]:
def mtx2df(m, max_row, max_col, row_tokens, col_tokens):
    """Flatten the top-left corner of a 2-D matrix into a long-form DataFrame.

    Args:
        m: 2-D matrix exposing ``.shape`` and ``m[r, c]`` indexing.
        max_row: exclusive upper bound on the row indices to include.
        max_col: exclusive upper bound on the column indices to include.
        row_tokens: labels for the rows; missing entries become ``"<blank>"``.
        col_tokens: labels for the columns; missing entries become ``"<blank>"``.

    Returns:
        pandas.DataFrame with one record per kept cell and the columns
        ``row``, ``column``, ``value``, ``row_token`` and ``col_token``. The
        token columns are prefixed with the zero-padded index so that they
        sort in positional order.
    """
    n_rows = min(m.shape[0], max_row)
    n_cols = min(m.shape[1], max_col)

    def _label(position, tokens):
        # Positions beyond the supplied token list are shown as "<blank>".
        token = tokens[position] if position < len(tokens) else "<blank>"
        return "%.3d %s" % (position, token)

    records = []
    for row_idx in range(n_rows):
        for col_idx in range(n_cols):
            records.append(
                (
                    row_idx,
                    col_idx,
                    float(m[row_idx, col_idx]),
                    _label(row_idx, row_tokens),
                    _label(col_idx, col_tokens),
                )
            )

    return pd.DataFrame(
        records,
        columns=["row", "column", "value", "row_token", "col_token"],
    )


def get_attn_map(attn_type: str, layer: int, head: int):
    """Return the stored attention scores of one head for the first batch item.

    Args:
        attn_type: which attention block to read: ``"encoder"`` (encoder
            self-attention), ``"decoder"`` (decoder masked self-attention) or
            ``"encoder-decoder"`` (decoder cross-attention).
        layer: index of the encoder/decoder layer.
        head: index of the attention head.

    Returns:
        The ``[0, head]`` slice of the selected block's ``attention_score``
        attribute, taken through ``.data`` so it is not tracked by autograd.

    Raises:
        ValueError: if ``attn_type`` is not one of the three supported names.
    """
    if attn_type == "encoder":
        attn = model.encoder.layers[layer].multi_head_attention_block.attention_score
    elif attn_type == "decoder":
        attn = model.decoder.layers[layer].masked_self_attention_block.attention_score
    elif attn_type == "encoder-decoder":
        attn = model.decoder.layers[layer].cross_attention_block.attention_score
    else:
        # Fail with a clear message instead of an UnboundLocalError on `attn`.
        raise ValueError(
            f"Unknown attn_type {attn_type!r}; expected 'encoder', 'decoder' or 'encoder-decoder'"
        )
    return attn[0, head].data


def attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len):
    """Build an interactive Altair heatmap for one attention head.

    Args:
        attn_type: attention block selector forwarded to ``get_attn_map``.
        layer: layer index forwarded to ``get_attn_map``.
        head: head index forwarded to ``get_attn_map``.
        row_tokens: token labels for the matrix rows (y axis).
        col_tokens: token labels for the matrix columns (x axis).
        max_sentence_len: limit applied to both the rows and the columns.

    Returns:
        A 400x400 interactive ``alt.Chart`` titled with the layer and head.
    """
    scores = get_attn_map(attn_type, layer, head)
    frame = mtx2df(scores, max_sentence_len, max_sentence_len, row_tokens, col_tokens)

    heatmap = alt.Chart(data=frame).mark_rect()
    heatmap = heatmap.encode(
        x=alt.X("col_token", axis=alt.Axis(title="")),
        y=alt.Y("row_token", axis=alt.Axis(title="")),
        color="value",
        tooltip=["row", "column", "value", "row_token", "col_token"],
    )
    heatmap = heatmap.properties(height=400, width=400, title=f"Layer {layer} Head {head}")
    return heatmap.interactive()


def get_all_attention_maps(attn_type: str, layers: list[int], heads: list[int], row_tokens: list, col_tokens, max_sentence_len: int):
    """Arrange the heatmaps of several layers and heads in a grid.

    Each requested layer becomes one horizontal strip holding one heatmap per
    requested head; the strips are stacked vertically in the order of
    ``layers``.

    Args:
        attn_type: attention block selector forwarded to ``attn_map``.
        layers: layer indices, one grid row each.
        heads: head indices, one grid column each.
        row_tokens: token labels for the heatmap rows.
        col_tokens: token labels for the heatmap columns.
        max_sentence_len: limit on rows/columns shown in every heatmap.

    Returns:
        The vertically concatenated Altair chart.
    """
    layer_strips = [
        alt.hconcat(
            *[
                attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len)
                for head in heads
            ]
        )
        for layer in layers
    ]
    return alt.vconcat(*layer_strips)

In [7]:
# Decode one validation sentence and show the reference source/target text.
batch, encoder_input_tokens, decoder_output_tokens = load_next_valid_batch()
print(f'Source: {batch["src_text"][0]}')
print(f'Target: {batch["tgt_text"][0]}')

# Number of source tokens before the first padding token. A sentence that
# contains no "[PAD]" token would make list.index raise ValueError, so fall
# back to the full token count in that case.
sentence_len = (
    encoder_input_tokens.index("[PAD]")
    if "[PAD]" in encoder_input_tokens
    else len(encoder_input_tokens)
)

Source: 'Just one advantage: I live in my own house, which is neither bought nor hired.
Target: -- Расчет один, что дома живу, непокупное, ненанятое.


In [8]:
# Layers and heads to visualise (the first three of each); also used by the
# decoder and cross-attention cells below.
layers = list(range(3))
heads = list(range(3))

# Encoder self-attention: source tokens label both axes, limited to the first
# 20 non-padding positions.
get_all_attention_maps(
    attn_type="encoder",
    layers=layers,
    heads=heads,
    row_tokens=encoder_input_tokens,
    col_tokens=encoder_input_tokens,
    max_sentence_len=min(20, sentence_len),
)

In [9]:
# Decoder self-attention: generated tokens label both axes.
# The display limit is taken from the generated sequence rather than from the
# source sentence length, so a translation longer than its source is not cut
# short. mtx2df additionally caps the range at the stored matrix shape.
get_all_attention_maps("decoder", layers, heads, decoder_output_tokens, decoder_output_tokens, min(20, len(decoder_output_tokens)))

In [10]:
# Cross attention: Encoder - Decoder
# Rows of the score matrix are the decoder positions (queries) and columns are
# the encoder positions (keys), so the generated tokens label the rows and the
# source tokens label the columns.
# NOTE(review): assumes attention_score is laid out as
# (batch, head, query_len, key_len) -- confirm in the model's attention block.
# The single limit caps the columns at the non-padding source length; rows are
# additionally capped by the stored matrix shape inside mtx2df.
get_all_attention_maps("encoder-decoder", layers, heads, decoder_output_tokens, encoder_input_tokens, min(20, sentence_len))