In [None]:
import sys
sys.path.append("..")
import os
import re
import json
import parsers
import random
import torch
import plotly.express as px
from typing import Dict
from model import TextMappingModel
from config import Config
from constants import SPECIAL_TOKENS
from transformers import AutoTokenizer
from data import LPMappingDataset
from torch.utils.data import DataLoader
from utils import generate_decoder_inputs_outputs
from plotly.subplots import make_subplots

# Seed every RNG this notebook can touch so a Restart & Run All is reproducible.
# torch is seeded as well because the models below are executed through torch.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Relative path of a results JSON file (not referenced by the cells shown here).
RES_FILEPATH = 'results/best-checkpoint/test.out.json'
# File name of the checkpoint that load_model looks for inside a model directory.
CKPT_PATH = 'best-checkpoint.mdl'
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def load_model(model_dir, data_path = '../data/test.jsonl'):
    """Load a trained checkpoint together with a dataloader over evaluation data.

    Args:
        model_dir: Directory that contains the checkpoint file named by CKPT_PATH.
        data_path: Path of the JSONL file the dataset is built from. Defaults to
            the test split path this notebook uses.

    Returns:
        A tuple ``(model, dataloader, config, tokenizer)``. The model has been
        moved to DEVICE and switched to evaluation mode; the dataloader yields
        one example per batch in file order (no shuffling).
    """
    print(f'Loading model from {model_dir}')
    # NOTE: torch.load unpickles the file, so only point this at trusted checkpoints.
    saved_dict = torch.load(os.path.join(model_dir, CKPT_PATH), map_location = DEVICE)
    config = Config.from_dict(saved_dict['config'])

    # The tokenizer must contain the special tokens before the embedding matrix
    # is resized, otherwise the saved weights would not match the model's shape.
    tokenizer = AutoTokenizer.from_pretrained(config.bert_model_name)
    tokenizer.add_tokens(SPECIAL_TOKENS)

    model = TextMappingModel(config)
    model.load_bert(config.bert_model_name)
    model.bert.resize_token_embeddings(len(tokenizer))
    model.load_state_dict(saved_dict['model'])
    model.to(DEVICE)
    # This notebook only runs inference: evaluation mode turns off dropout so
    # the attention maps are deterministic rather than randomly perturbed.
    model.eval()

    print(f'Loading dataset for {model_dir}')
    dataset = LPMappingDataset(
        path = data_path,
        tokenizer = tokenizer,
        max_length = config.max_length,
        gpu = torch.cuda.is_available(),
        enrich_ner = config.enrich_ner,
    )
    dataset.numberize()
    # batch_size = 1 because the plotting code asserts a single example per batch.
    dataloader = DataLoader(dataset, batch_size = 1, shuffle = False, collate_fn = dataset.collate_fn)

    return model, dataloader, config, tokenizer

In [None]:
@torch.no_grad()
def get_output(model, batch, tokenizer, config):
    """Run ``model`` on ``batch`` without gradients and return labels plus encode output.

    Args:
        model: Model exposing ``__call__`` and ``encode`` with the signatures used below.
        batch: Batch object passed through to the helper and the model.
        tokenizer: Tokenizer forwarded to the helper and to the forward call.
        config: Configuration providing ``max_position_embeddings``.

    Returns:
        A tuple ``(decoder_labels, encode_output)`` where ``encode_output`` is
        whatever ``model.encode`` returns for this batch.
    """
    use_gpu = torch.cuda.is_available()
    prepared = generate_decoder_inputs_outputs(
        batch,
        tokenizer,
        model,
        use_gpu,
        config.max_position_embeddings,
        replace_pad_tokens = False,
    )
    dec_inputs, dec_labels = prepared['decoder_input_ids'], prepared['decoder_labels']

    # The forward pass result is discarded; the call is kept, in this order,
    # in case it has side effects that are not visible from this notebook.
    model(batch, dec_inputs, dec_labels, tokenizer = tokenizer)
    encoded = model.encode(batch, dec_inputs, dec_labels)
    return dec_labels, encoded


def plot_cross_attention(output, input_ids, tokenizer, title, label_ids = None):
    """Render the last cross-attention entry of a single example as a heatmap.

    Args:
        output: Mapping with a 'cross_attentions' sequence (only its last entry
            is used) and, when ``label_ids`` is None, a 'logits' entry.
        input_ids: Source token ids with a leading batch dimension of size 1.
        tokenizer: Object whose ``decode`` turns a single token id into text.
        title: Title placed on the figure.
        label_ids: Optional target token ids with a leading batch dimension of
            size 1. When omitted, the argmax of ``output['logits']`` over the
            last dimension is used instead.

    Returns:
        A plotly figure whose rows are target tokens and whose columns are
        source tokens, coloured by head-averaged attention.

    Raises:
        AssertionError: If any input holds more than one example, or if the
            attention matrix does not match the number of label/input tokens.
    """
    last_cross_attention = output['cross_attentions'][-1]
    assert last_cross_attention.size(0) == 1, f'Expected one example but found {last_cross_attention.size(0)}'
    assert input_ids.size(0) == 1, f'Expected one example but found {input_ids.size(0)}'
    assert label_ids is None or label_ids.size(0) == 1, f'Expected one example but found {label_ids.size(0)}'

    input_ids = input_ids[0]
    if label_ids is None:
        label_ids = output['logits'][0].argmax(dim = -1)
    else:
        label_ids = label_ids[0]

    # Drop the batch dimension, then average over dim 0 (the attention heads).
    attention = last_cross_attention[0].mean(dim = 0)
    assert attention.size(0) == label_ids.size(0), (
        f'Attention has {attention.size(0)} rows but there are {label_ids.size(0)} label tokens'
    )
    assert attention.size(1) == input_ids.size(0), (
        f'Attention has {attention.size(1)} columns but there are {input_ids.size(0)} input tokens'
    )

    # detach() so the conversion also works for tensors that still track
    # gradients; numpy() raises on those otherwise.
    attention = attention.detach().cpu().numpy()
    inputs = [tokenizer.decode(i) for i in input_ids]
    labels = [tokenizer.decode(i) for i in label_ids]

    # Ten pixels per token keeps every tick label legible.
    fig_width = len(inputs) * 10
    fig_height = len(labels) * 10
    fig = px.imshow(
        attention,
        labels = dict(x = 'Input tokens', y = 'Output tokens', color = 'Attention'),
        width = fig_width,
        height = fig_height,
        aspect = 'auto',
    )
    fig.update_layout(
        title = title,
        coloraxis_colorbar = dict(
            len = fig_height * 0.93,
            lenmode = 'pixels'
        ),
    )
    # Ticks are positioned by index and labelled with the decoded token, so
    # repeated tokens each keep their own tick.
    fig.update_xaxes(
        tickangle = 270,
        tickfont = dict(family = 'Rockwell', size = 8),
        tickmode = 'array',
        tickvals = list(range(len(inputs))),
        ticktext = inputs,
    )
    fig.update_yaxes(
        tickfont = dict(family = 'Rockwell', size = 8),
        tickmode = 'array',
        tickvals = list(range(len(labels))),
        ticktext = labels,
    )

    return fig


In [None]:
# Checkpoint directories of the two runs compared in this notebook.
default_model_dir = '../output/default/20230514_151250415'
default_noner_model_dir = '../output/default_noner/20230515_223104883'

# Each call returns (model, dataloader, config, tokenizer) for one run.
ner_model, ner_dataloader, ner_config, ner_tokenizer = load_model(default_model_dir)
noner_model, noner_dataloader, noner_config, noner_tokenizer = load_model(default_noner_model_dir)

# Both runs must share the same pretrained backbone name.
assert noner_config.bert_model_name == ner_config.bert_model_name

# Module-level iterators: plot_next advances both by one batch per call.
ner_data_iter, noner_data_iter = iter(ner_dataloader), iter(noner_dataloader)

In [None]:
def plot_next(save = False, output_dir = None, filename_prefix = ''):
    """Plot cross-attention for the next example under both models.

    Advances the module-level ``ner_data_iter`` and ``noner_data_iter`` by one
    batch each, runs the corresponding model on its batch, and shows one
    heatmap per model.

    Args:
        save: When true, also write both figures as PNG files.
        output_dir: Directory for the PNG files; required when ``save`` is
            true. It is created if it does not exist.
        filename_prefix: Text prepended to the file names 'ner.png' and
            'noner.png'. With the default empty prefix, each saving call
            overwrites the files written by the previous one.

    Raises:
        AssertionError: If ``save`` is true without ``output_dir``, if the two
            batches belong to different documents, or if their labels differ.
        StopIteration: If either iterator is exhausted.
    """
    # Validate first so that a bad call does not consume an example.
    if save:
        assert output_dir is not None, 'output_dir is required to save the plots'

    ner_batch = next(ner_data_iter)
    noner_batch = next(noner_data_iter)
    assert ner_batch.doc_ids[0] == noner_batch.doc_ids[0], (
        f'Batches are out of sync: {ner_batch.doc_ids[0]} vs {noner_batch.doc_ids[0]}'
    )

    ner_labels, ner_output = get_output(
        model = ner_model,
        batch = ner_batch,
        tokenizer = ner_tokenizer,
        config = ner_config,
    )
    noner_labels, noner_output = get_output(
        model = noner_model,
        batch = noner_batch,
        tokenizer = noner_tokenizer,
        config = noner_config,
    )

    # Both plots use the same target sequence on the y axis, so the labels
    # produced for the two models have to agree.
    assert (ner_labels == noner_labels).all()

    ner_fig = plot_cross_attention(
        output = ner_output,
        input_ids = ner_batch.input_ids,
        tokenizer = ner_tokenizer,
        title = 'NER Augmentation',
        label_ids = ner_labels,
    )
    noner_fig = plot_cross_attention(
        output = noner_output,
        input_ids = noner_batch.input_ids,
        tokenizer = noner_tokenizer,
        title = 'No NER Augmentation',
        label_ids = noner_labels,
    )

    ner_fig.show()
    noner_fig.show()

    if save:
        # Create the target directory up front so the image export does not
        # fail on a missing directory.
        os.makedirs(output_dir, exist_ok = True)
        ner_fig.write_image(os.path.join(output_dir, f'{filename_prefix}ner.png'))
        noner_fig.write_image(os.path.join(output_dir, f'{filename_prefix}noner.png'))

In [None]:
# Plot the next example for both models and write the two heatmaps under ../output.
plot_next(output_dir = '../output', save = True)