Visualize how well the class embeddings attend to words and
sentences. The expected result is that the “married” class
embedding, for example, attends heavily to words and sentences
related to marriage, such as “married”, “husband”, “wife”, etc.

# Imports

In [None]:
%load_ext autoreload
%autoreload 2

from itertools import islice
from pathlib import Path
from pprint import pprint
from random import shuffle
from typing import List, Tuple

import pandas as pd
from torch import no_grad, tensor, Tensor
from torch.utils.data import DataLoader

from dao.ower.ower_dir import Sample
from data.power.samples.samples_dir import SamplesDir
from data.power.split.split_dir import SplitDir
from data.power.texter_pkl import TexterPkl
from util import plot_tensor

# Pandas display settings for this notebook:
# do not truncate long cell contents (e.g. full sentences) ...
pd.set_option('display.max_colwidth', None)
# ... and show floats with two digits of precision
pd.set_option('display.precision', 2)

# Args

In [None]:
# --- Model ---

# Pickled Texter whose attention is visualized
texter_pkl_path = '../data/power/texter-v2/context_attend_cde-irt-5-marked.pkl'

# --- Input data ---

# Samples directory; class_count and sent_count are passed to its loader
# and fix the class / sentence dimensions used throughout the notebook
samples_dir_path = '../data/power/samples-v5/cde-irt-5-marked/'
class_count, sent_count = 100, 5

# Split directory (provides the test entity labels)
split_dir_path = '../data/power/split-v2/cde-0/'

# --- Pre-processing ---

# Maximum number of tokens per sentence (longer sentences are truncated)
sent_len = 64

# --- Testing ---

# Number of samples per batch fed to the Texter
batch_size = 4

# Check data

In [None]:
#
# Check that (input) POWER Texter PKL exists
# (check() is defined by the project's TexterPkl; presumably it raises
#  when the file is missing - TODO confirm)
#

texter_pkl = TexterPkl(Path(texter_pkl_path))
texter_pkl.check()

#
# Check that (input) POWER Samples Directory exists
#

samples_dir = SamplesDir(Path(samples_dir_path))
samples_dir.check()

#
# Check that (input) POWER Split Directory exists
#

split_dir = SplitDir(Path(split_dir_path))
split_dir.check()

# Load Texter and test data

In [None]:
# Unpickle the Texter, then move it to the CPU
texter = texter_pkl.load()
texter = texter.cpu()

# Test samples, loaded for the configured number of classes and sentences
test_set = samples_dir.test_samples_tsv.load(class_count, sent_count)

# Entity -> sentences lookup used when printing predictions. The values are
# the samples' own `sents` objects (no copies are made here).
test_ent_to_sents = {
    test_sample.ent: test_sample.sents
    for test_sample in test_set
}


def generate_batch(batch: List[Sample]) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """
    Collate a list of samples into the tensors that are fed to the Texter.

    Every sample's sentence list is shuffled IN PLACE before tokenization, so
    any other reference to the same list object observes the shuffled order.

    Relies on the notebook globals `texter`, `sent_count` and `sent_len`, and
    expects every sample to carry exactly `sent_count` sentences (otherwise
    the reshape below fails).

    :param    batch:            [Sample(ent, ent_lbl, [class], [sent])]

    :return:  ent_batch:        IntTensor[batch_size],
              tok_lists_batch:  IntTensor[batch_size, sent_count, tok_count],
              masks_batch:      IntTensor[batch_size, sent_count, tok_count],
              classes_batch:    IntTensor[batch_size, class_count]

              where tok_count is the padded sentence length chosen by the
              tokenizer (presumably the longest sentence in the batch,
              capped at sent_len - confirm against the tokenizer in use)
    """

    ents, _, class_lists, sent_lists = zip(*batch)

    # Shuffle each sample's sentences and collect them into one flat list,
    # because the tokenizer is called once for the whole batch
    flat_sents = []
    for sent_list in sent_lists:
        shuffle(sent_list)
        flat_sents.extend(sent_list)

    encoding = texter.tokenizer(
        flat_sents,
        padding=True,
        truncation=True,
        max_length=sent_len,
        return_tensors='pt',
    )

    # Use the actual sample count: the last batch may hold fewer than
    # batch_size samples. -1 lets the token dimension follow the padding.
    target_shape = (len(ents), sent_count, -1)
    tok_lists_batch = encoding.input_ids.reshape(target_shape)
    masks_batch = encoding.attention_mask.reshape(target_shape)

    return tensor(ents), tok_lists_batch, masks_batch, tensor(class_lists)


# No shuffling is requested, so batches follow the order of test_set;
# generate_batch turns each list of samples into tensors
test_loader = DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    collate_fn=generate_batch,
)

# Load Debug Info

In [None]:
# Load class info
# (rows are later indexed as [class][3] to obtain the label shown on the
#  attention plot; judging by the name the rows are (rel, tail, freq, lbl)
#  - TODO confirm)
rel_tail_freq_lbl_list = samples_dir.classes_tsv.load()

# Load test entity labels (looked up by entity ID when printing predictions)
test_ent_to_lbl = split_dir.test_entities_tsv.load()

# Predict test entities

In [None]:
# Number of classes (taken from the front of the class list) to print and plot
limit_classes = 4

# Number of test batches to visualize
batch_limit = 1

texter.eval()

# The axis labels are identical for every entity, so build them once.
# They are limited by limit_classes (not a hard-coded count) so that they
# always match the rows of the plotted attention slice.
# Index 3 of a class row is presumably its label - TODO confirm.
class_labels = [rel_tail_freq_lbl_list[c][3] for c in range(min(limit_classes, class_count))]
sent_labels = [f'sent {s}' for s in range(sent_count)]

# islice stops the loader after batch_limit batches, i.e. no further batch is
# collated (and its sentences shuffled) just to be thrown away
for ent_batch, sents_batch, masks_batch, gt_batch in islice(test_loader, batch_limit):

    # Pure inference: do not build an autograd graph
    with no_grad():
        logits_batch, atts_batch = texter(sents_batch, masks_batch)
        no_att_logits_batch = texter.forward_without_attention(sents_batch, masks_batch)

    for ent, gt, logits, atts, no_att_logits in \
            zip(ent_batch, gt_batch, logits_batch, atts_batch, no_att_logits_batch):
        ent_id = ent.item()

        print(test_ent_to_lbl[ent_id])

        # NOTE(review): the collate function shuffles the sentence lists in
        # place; if test_ent_to_sents holds those same list objects, the order
        # printed here is the order the model saw ('sent 0', 'sent 1', ...).
        print('sents')
        pprint(test_ent_to_sents[ent_id])

        print('gt')
        pprint(gt[:limit_classes])

        print('logits')
        pprint(logits[:limit_classes])

        # Attention of the first limit_classes classes; the plot labels below
        # imply one row per class and one column per sentence
        print('atts')
        pprint(atts[:limit_classes])

        plot_tensor(atts[:limit_classes], 'atts', [class_labels, sent_labels])

        # Here the classes are on the second axis (presumably one row per
        # sentence - TODO confirm against forward_without_attention)
        print('no_att_logits')
        pprint(no_att_logits[:, :limit_classes])