In [1]:
import torch
from torch import nn
from model import  Transformer
from config import get_config, get_weights_file_path
from train import get_model, get_ds, greedy_decoder
import altair as alt
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
print("Using Device: {}".format(device))
config = get_config()
train_loader, val_loader, tokenize_src, tokenize_tgt = get_ds(config)
model = get_model(config, tokenize_src.get_vocab_size(), tokenize_tgt.get_vocab_size()).to(device)

#laoding weights
model_filename = get_weights_file_path(config, "08")
state = torch.load(model_filename)
model.load_state_dict(state["model_state_dict"])

Using Device: cuda
Max len of source sentence:  291
Max len of Target sentence:  349


<All keys matched successfully>

In [4]:
def load_next_batch():
    batch = next(iter(val_loader))
    encoder_input = batch["encoder_input"].to(device)
    encoder_mask = batch["encoder_mask"].to(device)
    decoder_input = batch["decoder_input"].to(device)
    decoder_mask = batch["decoder_mask"].to(device)

    encoder_input_token = [tokenize_src.id_to_token(idx) for idx in encoder_input[0].cpu().numpy()]
    decoder_input_token = [tokenize_tgt.id_to_token(idx) for idx in decoder_input[0].cpu().numpy()]

    model_out = greedy_decoder(model, encoder_input, encoder_mask, tokenize_src, tokenize_tgt, config["seq_len"], device)

    return batch, encoder_input_token, decoder_input_token

In [12]:
def mtx2df(m, max_row, max_col, row_tokens, col_tokens):
    return pd.DataFrame(
        [
            (
                r,
                c,
                float(m[r, c]),
                "%.3d %s" % (r, row_tokens[r] if len(row_tokens) > r else "<blank>"),
                "%.3d %s" % (c, col_tokens[c] if len(col_tokens) > c else "<blank>"),
            )
            for r in range(m.shape[0])
            for c in range(m.shape[1])
            if r < max_row and c < max_col
        ],
        columns=["row", "column", "value", "row_token", "col_token"],
    )

def get_attn_map(attn_type: str, layer: int, head: int):
    if attn_type == "encoder":
        attn = model.encoder.layers[layer].self_attention_block.attention_score
    elif attn_type == "decoder":
        attn = model.decoder.layers[layer].attention_block.attention_score
    elif attn_type == "encoder-decoder":
        print(model.decoder.layers[layer].cross_attention_block)
        attn = model.decoder.layers[layer].cross_attention_block.attention_score
    return attn[0, head].data

def attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len):
    df = mtx2df(
        get_attn_map(attn_type, layer, head),
        max_sentence_len,
        max_sentence_len,
        row_tokens,
        col_tokens,
    )
    return (
        alt.Chart(data=df)
        .mark_rect()
        .encode(
            x=alt.X("col_token", axis=alt.Axis(title="")),
            y=alt.Y("row_token", axis=alt.Axis(title="")),
            color="value",
            tooltip=["row", "column", "value", "row_token", "col_token"],
        )
        #.title(f"Layer {layer} Head {head}")
        .properties(height=400, width=400, title=f"Layer {layer} Head {head}")
        .interactive()
    )

def get_all_attention_maps(attn_type: str, layers: list[int], heads: list[int], row_tokens: list, col_tokens, max_sentence_len: int):
    charts = []
    for layer in layers:
        rowCharts = []
        for head in heads:
            rowCharts.append(attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len))
        charts.append(alt.hconcat(*rowCharts))
    return alt.vconcat(*charts)

In [13]:
batch, encoder_input_token, decoder_input_token = load_next_batch()
print(f"Source: {batch['src_text'][0]}")
print(f"Target: {batch['tgt_text'][0]}")
sen_len = encoder_input_token.index("[PAD]")
print(sen_len)

Source: Compress display of addresses in TO/CC/BCC to the number specified in address_count.
Target: निर्दिष्ट संख्या के लिए TO/CC/BCC में पता के प्रदर्शन को सिकोड़े address_count में.
19


In [14]:
layers = [0,1,2]
head = [0,1,2,3,4,5,6,7]

get_all_attention_maps("encoder", layers, head, encoder_input_token, encoder_input_token, min(20, sen_len))

In [15]:
#decoder attention
get_all_attention_maps("decoder", layers, head, decoder_input_token, decoder_input_token, min(20, sen_len))

In [16]:
# for crross attention
get_all_attention_maps("encoder-decoder", layers, head, encoder_input_token, decoder_input_token, min(20, sen_len))

MultiHeadAttention(
  (dropout): Dropout(p=0.1, inplace=False)
  (w_q): Linear(in_features=512, out_features=512, bias=True)
  (w_k): Linear(in_features=512, out_features=512, bias=True)
  (w_v): Linear(in_features=512, out_features=512, bias=True)
  (w_o): Linear(in_features=512, out_features=512, bias=True)
)


AttributeError: 'MultiHeadAttention' object has no attribute 'attention_score'