In [1]:
import torch
from torch.utils.data import DataLoader

from src.dataset.en_vi_dataset import EN_VIDataset
from src.models.model import Transformer
from src.utils.utils import input_target_collate_fn

import matplotlib.pyplot as plt
import traceback

In [2]:
# Device on which the checkpoint, the model and every batch tensor are placed.
dev = "cpu"

In [3]:
# Training checkpoint. Later cells read its 'config' entry (dataset and model
# settings) and its 'model_state_dict' entry (weights).
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
pretrained_path = "weights/Transformer-En2Vi-CPE-WordPiece/best_bleu.pth"
config = torch.load(pretrained_path, map_location=dev)

In [4]:
# Tokenisation scheme recorded in the checkpoint's training-dataset settings;
# defaults to 'bpe' when the checkpoint does not record one.
token_type = (
    config['config']['dataset']['train']
    .get('config', {'token_type': 'bpe'})
    .get('token_type', 'bpe')
)

# Tokenizer settings for each scheme this notebook knows about.
# (Alternative for 'bpe': a shared vocabulary for both languages, i.e.
#  ["vocab/shared/shared-vocab.json", "vocab/shared/shared-merges.txt"].)
vocab_presets = {
    'bpe': {
        'token_type': 'bpe',
        'src_vocab': ["vocab/english_bpe/en-bpe-minfreq5-vocab.json",
                      "vocab/english_bpe/en-bpe-minfreq5-merges.txt"],
        'trg_vocab': ["vocab/vietnamese_bpe/vi-bpe-minfreq5-vocab.json",
                      "vocab/vietnamese_bpe/vi-bpe-minfreq5-merges.txt"],
    },
    'wordpiece': {
        'token_type': 'wordpiece',
        'src_vocab': 'vocab/english_word/en-wordpiece-minfreq5-vocab.txt',
        'trg_vocab': 'vocab/vietnamese_word/vi-wordpiece-minfreq5-vocab.txt',
    },
}

# English/Vietnamese sentence files, extended with the tokenizer settings of
# the detected scheme. An unrecognised scheme adds nothing to the config.
data_cfg = {
    'src_path': 'data/en-vi/raw-data/val/tst2012.en',
    'trg_path': 'data/en-vi/raw-data/val/tst2012.vi',
}
data_cfg.update(vocab_presets.get(token_type, {}))

ds = EN_VIDataset(**data_cfg)
# Shuffled loader yielding two sentence pairs per batch.
dl = DataLoader(
    ds,
    batch_size=2,
    shuffle=True,
    collate_fn=input_target_collate_fn,
)

In [5]:
# End-of-sentence marker on the target (Vietnamese) side and the id the
# target tokenizer assigns to it.
TRG_EOS_TOKEN = "</s>"
TRG_EOS_ID = ds.vi_tokenizer.token_to_id(TRG_EOS_TOKEN)

In [6]:
# Vocabulary sizes and padding ids come from the dataset's tokenizers; every
# other constructor argument is taken from the checkpoint's model settings.
vocab_kwargs = dict(
    n_src_vocab=ds.en_tokenizer.get_vocab_size(),
    n_trg_vocab=ds.vi_tokenizer.get_vocab_size(),
    src_pad_idx=ds.en_tokenizer.token_to_id('<pad>'),
    trg_pad_idx=ds.vi_tokenizer.token_to_id('<pad>'),
)
model = Transformer(**vocab_kwargs, **config['config']['model']).to(dev)

# Restore the trained weights and switch to evaluation mode.
model.load_state_dict(config['model_state_dict'])
model.eval()
# End the cell on print() so no return value is echoed as cell output.
print()




# **Visualize attention**

In [7]:
def visualize_attn(attn, dsts, ques):
    """Draw one attention heat-map per example of a batch.

    Args:
        attn: iterable of per-example attention tensors. Each one is averaged
            over its first dimension before plotting (presumably the head
            dimension -- TODO confirm against the attention module).
        dsts: per-example sequences of tick labels for the x-axis (columns).
        ques: per-example sequences of tick labels for the y-axis (rows).

    Iteration stops at the shortest of the three arguments; nothing is
    returned.
    """
    for j, (b, dst, que) in enumerate(zip(attn, dsts, ques)):
        print(f'===== Input {j} =====')
        fig, ax = plt.subplots(figsize=(10, 10), dpi=150)
        ax.imshow(b.mean(0).detach().cpu())

        # Columns are labelled with `dst`, slanted so long tokens stay legible.
        ax.set_xticks(range(len(dst)))
        ax.set_xticklabels(dst)
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

        # Rows are labelled with `que`.
        ax.set_yticks(range(len(que)))
        ax.set_yticklabels(que)
        plt.show()
        # One figure is created per example and per layer; close it explicitly
        # so figures do not pile up on backends that keep them open after show().
        plt.close(fig)

In [8]:
@torch.no_grad()
def make_prediction(model, src, trg):
    """Run one forward pass and return the batch as token strings.

    The model is called with `src` and with `trg` minus its last position;
    the prediction is the arg-max over the last dimension of the model output.
    Relies on the notebook globals `dev` (device) and `ds` (dataset providing
    `en_tokenizer` and `vi_tokenizer`).

    Args:
        model: callable invoked as `model(src, trg[:, :-1])`.
        src: 2-D tensor of source token ids.
        trg: 2-D tensor of target token ids.

    Returns:
        `(src_tokens, trg_tokens, pred_tokens)`, each a list with one list of
        token strings per row. `trg_tokens` covers the full `trg`, so it is one
        position longer than what the model received as target input.
    """
    src = src.to(dev)
    trg = trg.to(dev)

    def ids_to_tokens(tokenizer, id_batch):
        # .tolist() hands the tokenizer plain Python ints, whatever the device.
        return [[tokenizer.id_to_token(tok_id) for tok_id in row]
                for row in id_batch.tolist()]

    src_tokens = ids_to_tokens(ds.en_tokenizer, src)
    trg_tokens = ids_to_tokens(ds.vi_tokenizer, trg)

    # The last target position is dropped from the model's target input.
    pred = model(src, trg[:, :-1])
    pred_tokens = ids_to_tokens(ds.vi_tokenizer, pred.argmax(-1))

    return src_tokens, trg_tokens, pred_tokens

In [9]:
def visualize_attn_hook(m, i, o):
    """Forward hook that records a module's attention output.

    Appends the second element of the hooked module's output (`o[1]`) to the
    notebook-level list `attn`, which the caller must create before running
    the forward pass.

    Args:
        m: the hooked module (unused).
        i: the module's inputs (unused).
        o: the module's output; must be indexable at position 1.
    """
    global attn
    try:
        attn.append(o[1])
    except Exception:
        # Best effort: report the problem without aborting the forward pass.
        # Only Exception is caught, so Ctrl-C still interrupts the cell.
        traceback.print_exc()

In [10]:
# Take the first (source, target) batch from the loader. next() fails
# immediately on an empty loader instead of leaving src/trg undefined.
src, trg = next(iter(dl))

## **Encoder self-attention**

In [11]:
# Hook every encoder layer's self-attention, run one forward pass, then plot
# the attention captured for each layer.
nlayers = len(model.encoder.layer_stack)
attn = []
handle = [
    model.encoder.layer_stack[layer_idx].self_attn.attention.register_forward_hook(visualize_attn_hook)
    for layer_idx in range(nlayers)
]
try:
    src_tokens, trg_tokens, pred_tokens = make_prediction(model, src, trg)
finally:
    # Always detach the hooks, even if the forward pass fails, so stale hooks
    # cannot append to `attn` when a later cell runs the model.
    for hook_handle in handle:
        hook_handle.remove()

for layer_idx in range(nlayers):
    print(f'=== Layer {layer_idx} ===')
    visualize_attn(attn[layer_idx], src_tokens, src_tokens)

## **Decoder self-attention**

In [None]:
# Hook every decoder layer's masked self-attention, run one forward pass, then
# plot the attention captured for each layer.
# The layer count comes from the decoder stack, since that is what is hooked.
nlayers = len(model.decoder.layer_stack)
attn = []
handle = [
    model.decoder.layer_stack[layer_idx].masked_self_attn.attention.register_forward_hook(visualize_attn_hook)
    for layer_idx in range(nlayers)
]
try:
    src_tokens, trg_tokens, pred_tokens = make_prediction(model, src, trg)
finally:
    # Always detach the hooks, even if the forward pass fails, so stale hooks
    # cannot append to `attn` when a later cell runs the model.
    for hook_handle in handle:
        hook_handle.remove()

# make_prediction feeds the decoder trg[:, :-1], so the attention maps cover
# every target position except the last; label the axes with those tokens.
dec_in_tokens = [tokens[:-1] for tokens in trg_tokens]

for layer_idx in range(nlayers):
    print(f'=== Layer {layer_idx} ===')
    visualize_attn(attn[layer_idx], dec_in_tokens, dec_in_tokens)

## **Encoder-Decoder attention**

In [None]:
# Hook every decoder layer's encoder-decoder attention, run one forward pass,
# then plot the attention captured for each layer.
# The layer count comes from the decoder stack, since that is what is hooked.
nlayers = len(model.decoder.layer_stack)
attn = []
handle = [
    model.decoder.layer_stack[layer_idx].enc_dec_attn.attention.register_forward_hook(visualize_attn_hook)
    for layer_idx in range(nlayers)
]
try:
    src_tokens, trg_tokens, pred_tokens = make_prediction(model, src, trg)
finally:
    # Always detach the hooks, even if the forward pass fails, so stale hooks
    # cannot append to `attn` when a later cell runs the model.
    for hook_handle in handle:
        hook_handle.remove()

# make_prediction feeds the decoder trg[:, :-1], so the decoder side of the
# attention maps covers every target position except the last. The source
# side keeps all of its tokens.
dec_in_tokens = [tokens[:-1] for tokens in trg_tokens]

for layer_idx in range(nlayers):
    print(f'=== Layer {layer_idx} ===')
    visualize_attn(attn[layer_idx], src_tokens, dec_in_tokens)