Visualize how well the class embeddings attend to words and
sentences. The expected result is that the “married” class
embedding, for example, attends heavily to words and sentences
related to marriage, such as “married”, “husband”, “wife”, etc.

# Imports

In [20]:
%load_ext autoreload
%autoreload 2

from dao.ryn.text.text_dir import TextDir
from pathlib import Path
from pprint import pprint
from random import shuffle
from typing import List, Tuple

import pandas as pd
import torch
from IPython.core.display import display, HTML
from jinja2 import Template
from matplotlib import pyplot
from sklearn.metrics import precision_score, recall_score, f1_score
from torch import tensor, Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm

from dao.ower.ower_dir import Sample, OwerDir
from data.power.texter_pkl import TexterPkl
from util import plot_tensor

# Notebook-wide pandas display settings: no column-width limit, precision of 2
for _option, _value in (('display.max_colwidth', None), ('display.precision', 2)):
    pd.set_option(_option, _value)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Args

In [21]:
# --- Input locations ---------------------------------------------------------

# Pickled Texter that is inspected in this notebook
texter_pkl_path = '../data/power/texter-v2/static_attend_cde-irt-5-clean.pkl'

# OWER dataset directory and Ryn text directory
ower_dir_path = '../data/ower-v4/cde-irt-100-5/'
text_dir_path = '../data/ryn/text/cde-irt-5-clean/'

# --- Dataset shape (handed to OwerDir.read_datasets) -------------------------

class_count, sent_count = 100, 5

# --- Pre-processing: every sentence is cropped / zero-padded to this length ---

sent_len = 64

# --- Testing: number of samples per DataLoader batch --------------------------

batch_size = 4

# Check data

In [22]:
#
# Wrap the (input) OWER Texter PKL path and run its check()
# (presumably verifies that the file exists - confirm in TexterPkl)
#

texter_pkl = TexterPkl(Path(texter_pkl_path))
texter_pkl.check()

#
# Wrap the (input) OWER Directory path and run its check()
# (presumably verifies that the directory and its files exist - confirm in OwerDir)
#

ower_dir = OwerDir(Path(ower_dir_path))
ower_dir.check()

#
# Wrap the (input) Ryn Text Directory path and run its check()
# (presumably verifies that the directory and its files exist - confirm in TextDir)
#

text_dir = TextDir(Path(text_dir_path))
text_dir.check()

# Load Texter and test data

In [23]:
# Unpickle the Texter, then move it to the CPU
texter = texter_pkl.load()
texter = texter.cpu()

# read_datasets() yields the three splits plus the vocabulary, in this order
datasets = ower_dir.read_datasets(class_count, sent_count)
train_set, valid_set, test_set, vocab = datasets


def generate_batch(batch: List[Sample]) -> Tuple[Tensor, Tensor, Tensor]:
    """Collate a list of samples into the three tensors consumed by the notebook.

    Each sample unpacks into (entity, ground-truth classes, token lists). Every
    token list is cropped to the module-level `sent_len` and right-padded with
    token 0 up to that length; the sentences of each entity are then shuffled.

    :param batch: Samples to collate
    :return: (entities, padded sentences, ground-truth classes) - note that the
             sentences come second although they are last within a sample
    """

    ents, gt_classes, sent_groups = zip(*batch)

    def fit_to_sent_len(tok_list):
        # Crop first, then fill whatever is missing with padding token 0
        clipped = tok_list[:sent_len]
        return clipped + [0] * (sent_len - len(clipped))

    fitted_groups = []
    for sents in sent_groups:
        # Build fresh lists so that shuffling never touches the dataset's own lists
        fitted = [fit_to_sent_len(tok_list) for tok_list in sents]
        shuffle(fitted)
        fitted_groups.append(fitted)

    return tensor(ents), tensor(fitted_groups), tensor(gt_classes)


# Batches of `batch_size` test samples; generate_batch crops, pads and shuffles the sentences of each sample
test_loader = DataLoader(test_set, batch_size=batch_size, collate_fn=generate_batch)



# Load Debug Info

In [24]:
# Load class info (index 3 of each row is used as the class label when plotting)
rel_tail_freq_lbl_list = ower_dir.classes_tsv.load()
# Entity labels, looked up by entity id when printing predictions
ent_to_lbl = ower_dir.ent_labels_txt.load()

# Load the sentences of the test entities (looked up by entity id, not the entity labels)
test_ent_to_sents = text_dir.ow_test_sents_txt.load()




# Predict test entities

In [31]:
# Only the first `limit_classes` classes are printed / plotted per entity
limit_classes = 4

# Number of test batches to inspect before stopping
max_batches = 1

# The axis labels are identical for every entity, so build them once up front.
# Index 3 of each class row is used as the class label.
class_labels = [rel_tail_freq_lbl_list[c][3] for c in range(min(limit_classes, class_count))]
# class_labels = [f'class {c}' for c in range(min(limit_classes, class_count))]
sent_labels = [f'sent {s}' for s in range(sent_count)]

texter.eval()

# Pure inference: disable autograd so that no graph is built for the forward passes
with torch.no_grad():
    for i, (ent_batch, sents_batch, gt_batch) in enumerate(test_loader):
        if i >= max_batches:
            break

        # Forward pass with attention (logits + attentions) and without attention
        logits_batch, atts_batch = texter(sents_batch)
        no_att_logits_batch = texter.forward_without_attention(sents_batch)

        for ent, sents, gt, logits, atts, no_att_logits in \
                zip(ent_batch, sents_batch, gt_batch, logits_batch, atts_batch, no_att_logits_batch):
            ent_id = ent.item()

            print(ent_to_lbl[ent_id])

            print('sents')
            pprint(list(test_ent_to_sents[ent_id]))

            print('gt')
            pprint(gt[:limit_classes])

            print('logits')
            pprint(logits[:limit_classes])

            print('atts')
            pprint(atts[:limit_classes])

            plot_tensor(atts[:limit_classes], 'atts', [class_labels, sent_labels])

            # Here the classes are on the second axis, hence the column slice
            print('no_att_logits')
            pprint(no_att_logits[:, :limit_classes])


KeyboardInterrupt: 

# Train

In [None]:
log_first_batch = False


def _show_first_batch(logits_batch: Tensor, pred_classes_batch: Tensor, gt_classes_batch: Tensor,
                      sents_batch: Tensor, max_rows: int = 8) -> None:
    """Display a debug table and the attention plot for the first entities of a batch.

    Shared by the train and the validation phase. Reads the notebook-level
    `vocab`, `debug`, `class_count` and `sent_count`.

    :param logits_batch: Logits produced by the classifier for the batch
    :param pred_classes_batch: Classes predicted from the logits
    :param gt_classes_batch: Ground truth classes of the batch
    :param sents_batch: Token ids of the batch's sentences
    :param max_rows: Maximum number of entities to show
    """

    dlb = logits_batch.cpu().detach().numpy()[:max_rows]  # logits batch
    dpb = pred_classes_batch.cpu().detach().numpy()[:max_rows]  # predicted classes batch
    dgb = gt_classes_batch.cpu().detach().numpy()[:max_rows]  # ground truth classes batch
    dsb = sents_batch.cpu().detach().numpy()[:max_rows]  # sentences batch

    # 'foo' is a placeholder: the entity column is not filled in here
    df_cols = ['entity', 'logits', 'p', 'gt', 'sents']
    df_data = [('foo', logits, pred_classes, classes, [[vocab.itos[tok] for tok in sent] for sent in sents])
               for logits, pred_classes, classes, sents in zip(dlb, dpb, dgb, dsb)]

    display(pd.DataFrame(df_data, columns=df_cols))

    display_atts = debug['atts_batch'][:max_rows].cpu()

    # One label per plotted entity: at most `max_rows` rows are shown, regardless of batch_size
    ent_labels = [f'ent {i}' for i in range(len(display_atts))]
    class_labels = [f'clss {i}' for i in range(class_count)]
    sent_labels = [f'sent {i}' for i in range(sent_count)]
    plot_tensor(display_atts, 'atts_batch', [ent_labels, class_labels, sent_labels])


for epoch in range(epoch_count):

    ## Train
    # NOTE(review): the classifier's train/eval mode is never switched in this loop - confirm that is intended

    train_loss = 0.0

    # Train gt/pred classes across all batches
    train_gt_classes_stack: List[List[int]] = []
    train_pred_classes_stack: List[List[int]] = []

    for batch_idx, (_, sents_batch, gt_classes_batch) in enumerate(tqdm(train_loader, desc=f'Epoch {epoch}')):
        sents_batch = sents_batch.to(device)
        gt_classes_batch = gt_classes_batch.to(device)

        logits_batch = classifier(sents_batch)

        loss = criterion(logits_batch, gt_classes_batch.float())
        train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # A class counts as predicted when its logit is positive
        pred_classes_batch = (logits_batch > 0).int()

        train_gt_classes_stack += gt_classes_batch.cpu().numpy().tolist()
        train_pred_classes_stack += pred_classes_batch.cpu().numpy().tolist()

        # Log first batch
        if log_first_batch and batch_idx == 0:
            _show_first_batch(logits_batch, pred_classes_batch, gt_classes_batch, sents_batch)

    ## Validate

    valid_loss = 0.0

    # Valid gt/pred classes across all batches
    valid_gt_classes_stack: List[List[int]] = []
    valid_pred_classes_stack: List[List[int]] = []

    with torch.no_grad():
        for batch_idx, (ent_batch, sents_batch, gt_classes_batch) in enumerate(tqdm(valid_loader, desc=f'Epoch {epoch}')):
            sents_batch = sents_batch.to(device)
            gt_classes_batch = gt_classes_batch.to(device)

            logits_batch = classifier(sents_batch)

            loss = criterion(logits_batch, gt_classes_batch.float())
            valid_loss += loss.item()

            pred_classes_batch = (logits_batch > 0).int()

            valid_gt_classes_stack += gt_classes_batch.cpu().numpy().tolist()
            valid_pred_classes_stack += pred_classes_batch.cpu().numpy().tolist()

            # Log first batch
            if log_first_batch and batch_idx == 0:
                _show_first_batch(logits_batch, pred_classes_batch, gt_classes_batch, sents_batch)

    ## Log loss (mean over the batches of each loader)

    train_loss /= len(train_loader)
    valid_loss /= len(valid_loader)

    writer.add_scalars('loss', {'train': train_loss, 'valid': valid_loss}, epoch)

    ## Log per-class metrics for the first/last classes and macro metrics over all classes

    metric_fns = {'precision': precision_score, 'recall': recall_score, 'f1': f1_score}

    # Number of classes that are logged individually at each end of the class list
    edge_count = 3

    for metric, score_fn in metric_fns.items():

        # average=None -> one score per class instead of a single aggregate
        train_scores = score_fn(train_gt_classes_stack, train_pred_classes_stack, average=None)
        valid_scores = score_fn(valid_gt_classes_stack, valid_pred_classes_stack, average=None)

        for c, (train_score, valid_score) in enumerate(zip(train_scores, valid_scores)):

            # many classes -> log only first and last ones
            if (class_count > 2 * edge_count) and (edge_count <= c <= len(train_scores) - edge_count - 1):
                continue

            writer.add_scalars(metric, {f'train_{c}': train_score}, epoch)
            writer.add_scalars(metric, {f'valid_{c}': valid_score}, epoch)

        # Macro metric = unweighted mean of the per-class scores
        writer.add_scalars(metric, {'train': train_scores.mean()}, epoch)
        writer.add_scalars(metric, {'valid': valid_scores.mean()}, epoch)

# Calc top class-word attentions

In [None]:
# How many of the highest-scoring tokens to print per class
top_k = 10

class_embs = classifier.class_embs
tok_embs = classifier.embedding_bag.weight

# Dot product of every class embedding (c, e) with every token embedding (v, e) -> (c, v)
tok_atts = torch.einsum('ce, ve -> cv', class_embs, tok_embs)

# Rank the tokens of each class from highest to lowest score and keep the best ones
ranked = tok_atts.sort(descending=True)
top_toks = ranked.indices.cpu().numpy()[:, :top_k]
top_atts = ranked.values.cpu().detach().numpy()[:, :top_k]

# The class label is the fourth column of each class row
rel_tail_freq_lbl_tuples = ower_dir.classes_tsv.load()
_, _, _, class_labels = zip(*rel_tail_freq_lbl_tuples)

for c, c_lbl in zip(range(class_count), class_labels):
    print('\n', c_lbl)

    for tok, att in zip(top_toks[c], top_atts[c]):
        print('\t{} ({:.2f})'.format(vocab.itos[tok], att))

# Visualize class-word attentions in sentences

In [None]:
# Rescale tok_atts in place for colouring:
#   1. divide by the largest magnitude -> values within [-1, 1]
#   2. stretch by 512 and shift by 128 -> values within [-384, 640]
# get_color() clamps to [0, 255], so only the values close to zero are
# resolved by the colormap while stronger ones saturate at either end.
# NOTE: in-place, so re-running this cell rescales the already rescaled values.
largest_magnitude = max(-tok_atts.min(), tok_atts.max())

tok_atts /= largest_magnitude
tok_atts *= 512
tok_atts += 128

def get_color(att: float) -> str:
    """Map an attention value to a half-transparent CSS colour of the viridis colormap.

    :param att: Attention value, clamped to [0, 255] before it is used as colormap index
    :return: CSS colour of the form 'rgba(r, g, b, 0.5)'
    """

    # Clamp to the valid index range of the colormap's colour list
    if att > 255:
        level = 255
    elif att < 0:
        level = 0
    else:
        level = att

    palette = pyplot.get_cmap('viridis').colors
    red, green, blue = palette[int(level)]

    # Colormap channels are floats, CSS expects integers
    red, green, blue = int(red * 256), int(green * 256), int(blue * 256)

    return f'rgba({red}, {green}, {blue}, 0.5)'


def render_sent(class_: int, sent: List[int]) -> str:
    """Render a sentence as HTML in which each word is coloured by its attention.

    Reads the notebook-level `tok_atts` (class x token attentions) and `vocab`.

    :param class_: Index of the class whose attentions colour the words
    :param sent: Token ids of the sentence, token 0 is rendered as '_'
    :return: Space-separated <span> elements, one per token
    """

    # Local import so that the cell does not depend on an extra notebook-level import
    from html import escape

    class_atts = tok_atts[class_]

    # Escape the words: tokens such as '<unk>' would otherwise be parsed as
    # HTML tags and disappear from the rendered sentence
    words = ['<span style="background-color: {}">{}</span>'.format(
                get_color(class_atts[tok]),
                escape(str(vocab.itos[tok])) if tok != 0 else '_'
            ) for tok in sent]

    return ' '.join(words)


def render_table(sents: List[List[int]]) -> None:
    """Display an HTML table with one row per class and one column per sentence.

    Each cell holds the sentence as rendered by render_sent() for that class.
    Row headers are the last 20 characters of the notebook-level `class_labels`.
    The table is displayed directly, nothing is returned.

    :param sents: Token ids of the sentences to render
    """

    table_template = Template('''
        <style>
            table.atts td { text-align: left }
        </style>

        <table class='atts'>
            <tr>
                <th></th>

                {% for i in range(len(sents)) %}
                <th> Sent {{ i }} </th>
                {% endfor %}
            </tr>

            {% for c in range(len(class_labels)) %}
            <tr>
                <th>{{ class_labels[c] }}</th>

                {% for sent in sents %}
                <td>{{ render_sent(c, sent) }}</td>
                {% endfor %}
            </tr>
            {% endfor %}
        </table>
    ''')

    # Long class labels are cut down to their tail to keep the header column narrow
    table_html = table_template.render(
        sents=sents,
        class_labels=[label[-20:] for label in class_labels],
        render_sent=render_sent,
        len=len,
    )

    display(HTML(table_html))

In [None]:
ent_to_lbl = ower_dir.ent_labels_txt.load()

# Work on CPU copies of whatever batch the previous cells left behind
logits_batch = logits_batch.cpu()
pred_classes_batch = pred_classes_batch.cpu()
gt_classes_batch = gt_classes_batch.cpu()
atts_batch = debug['atts_batch'].cpu()

# Identical for every entity, so build once outside the loop
sent_labels = [f'sent {s}' for s in range(sent_count)]

# Show at most 20 entities, but never index past the end of the batch
show_count = min(20, len(ent_batch))

for i in range(show_count):

    ent = ent_batch[i].item()

    display(HTML('<h1>{} ({})</h1>'.format(ent_to_lbl[ent], ent)))

    texts = [' '.join([vocab.itos[tok] if tok != 0 else '_' for tok in tok_list])
             for tok_list in sents_batch[i]]

    # autoescape: tokens such as '<unk>' must show up as text instead of being parsed as HTML tags
    display(HTML(Template('''
        <ul>
            {% for text in texts %}
                <li> {{ text }} </li>
            {% endfor %}
        </ul>
    ''', autoescape=True).render(texts=texts)))

    print(class_labels)
    print('logits =', logits_batch[i])
    print('pred =', pred_classes_batch[i])
    print('gt =', gt_classes_batch[i])

    plot_tensor(atts_batch[i], 'atts', [class_labels, sent_labels])

    # render_table() displays the table itself and returns None, so it must not be wrapped in display(HTML(...))
    render_table(sents_batch[i])