Mar 6

Visualize how well the class embeddings attend to words and
sentences. The expected result is that the “married” class
embedding, for example, attends heavily to words and sentences
related to marriage, such as “married”, “husband”, and “wife”.

This notebook prints the results on the validation data after training
(not during training).

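Also, shuffle each entity's sentences when building batches (this also
happens for validation batches, since both loaders share the same collate
function), and weight each class's positive examples by the inverse of the
class's frequency in the training set.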

# Imports

In [None]:
%load_ext autoreload
%autoreload 2

from datetime import datetime
from pathlib import Path
from random import shuffle
from typing import List, Tuple

import numpy as np
import pandas as pd
import torch
from IPython.core.display import display, HTML
from jinja2 import Template
from matplotlib import pyplot
from sklearn.metrics import precision_score, recall_score, f1_score
from torch import tensor, Tensor
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from classifier import Classifier, debug
from dao.ower.ower_dir import Sample, OwerDir
from dao.ryn.ryn_dir import RynDir
from util import plot_tensor

# Show full cell contents and two decimal places when displaying DataFrames
for option_name, option_value in (('display.max_colwidth', None), ('display.precision', 2)):
    pd.set_option(option_name, option_value)

# Config

In [None]:
# Dataset: OWER directory plus the number of classes and sentences per entity it
# is read with (both are passed to OwerDir below)
ower_dataset = 'ower-v3-fb-irt-3'
ower_dir_path = Path('data/ower') / ower_dataset
class_count = 4
sent_count = 3

# Ryn directory, used at the end of the notebook to look up entity labels
ryn_dir_path = Path('data/ryn/irt.fb.irt.5.clean')

# Pre-trained word vectors passed to read_datasets(). Presumably None means no
# vectors are loaded, in which case the classifier gets random embeddings.
# Alternatives: None, 'charngram.100d', 'fasttext.en.300d', 'fasttext.simple.300d',
# 'glove.42B.300d', 'glove.840B.300d', 'glove.twitter.27B.25d',
# 'glove.twitter.27B.50d', 'glove.twitter.27B.100d', 'glove.6B.50d',
# 'glove.6B.100d', 'glove.6B.200d', 'glove.6B.300d'
vectors = 'glove.twitter.27B.200d'

# Embedding size for Classifier.from_random (only used when vocab.vectors is
# None); e.g. 200
emb_size = None

# Batching: entities per batch and tokens per sentence (sentences are cropped /
# zero-padded to this length by the collate function)
batch_size = 1024
sent_len = 64

# Optimization
lr = 0.1
epoch_count = 20

# TensorBoard run directory: start timestamp, dataset name and embedding size
run_stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_dir = f'runs/{run_stamp}_{ower_dataset}_emb-{emb_size}'

# 1 Build train/valid DataLoaders

In [None]:
# Open the OWER dataset directory and run its own check() before reading from it
ower_dir = OwerDir('OWER Dataset Directory', ower_dir_path, class_count, sent_count)
ower_dir.check()

# read_datasets() returns four values: the train split, the valid split, a third
# value that is discarded here (presumably the test split — TODO confirm) and the
# vocabulary.
train_set: List[Sample]
valid_set: List[Sample]
train_set, valid_set, _, vocab = ower_dir.read_datasets(vectors=vectors)

#
# Create dataloaders
#

def generate_batch(batch: 'List[Sample]',
                   *,
                   max_len: 'Optional[int]' = None,
                   shuffle_sents: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
    """Collate samples into ``(entities, token ids, ground-truth classes)`` tensors.

    Each sample is unpacked as ``(ent, gt_classes, tok_lists)``. Every token list
    is cropped to ``max_len`` tokens and right-padded with token id 0 to exactly
    ``max_len`` tokens. By default, each entity's list of sentences is then shuffled.

    Args:
        batch: Samples as stored in the dataset (used as a DataLoader collate_fn).
        max_len: Target sentence length; defaults to the notebook-level ``sent_len``.
        shuffle_sents: Shuffle each entity's sentences (default True, the original
            behavior). Pass False for deterministic validation/visualization batches.

    Returns:
        ``(ent_batch, padded_tok_lists_batch, gt_classes_batch)`` as tensors. The
        token tensor is (batch, sentences, max_len) when every entity has the same
        number of sentences; ``tensor()`` raises on ragged input.
    """
    length = sent_len if max_len is None else max_len

    ent_batch, gt_classes_batch, tok_lists_batch = zip(*batch)

    padded_tok_lists_batch = []
    for tok_lists in tok_lists_batch:
        padded_tok_lists = []
        for tok_list in tok_lists:
            # Crop first, then pad with id 0, so every sentence has exactly `length` tokens
            cropped = list(tok_list[:length])
            padded_tok_lists.append(cropped + [0] * (length - len(cropped)))

        # Shuffles the freshly built list only; the dataset samples are not mutated
        if shuffle_sents:
            shuffle(padded_tok_lists)

        padded_tok_lists_batch.append(padded_tok_lists)

    return tensor(ent_batch), tensor(padded_tok_lists_batch), tensor(gt_classes_batch)

train_loader = DataLoader(train_set, batch_size=batch_size, collate_fn=generate_batch, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size, collate_fn=generate_batch)

#
# Calc class freqs and set class weights
#

# Stack every training sample's 0/1 class vector into a (samples, classes) array;
# the column means are the per-class positive frequencies.
_, train_classes_stack, _ = zip(*train_set)
train_classes_stack = np.array(train_classes_stack)
train_freqs = train_classes_stack.mean(axis=0)

# A class that never occurs would get an infinite weight. That turns the
# BCEWithLogitsLoss into NaN, so fail loudly here instead.
if (train_freqs == 0).any():
    raise ValueError(f'Classes without positive training samples: {np.flatnonzero(train_freqs == 0).tolist()}')

# Used as BCEWithLogitsLoss(pos_weight=...), so positives of rare classes weigh more.
# Note: 1 / freq == 1 + n_neg / n_pos, slightly above the n_neg / n_pos ratio the
# PyTorch docs suggest. numpy produces float64 here; cast to float32 to match the
# model's logits (presumably float32).
class_weights = tensor(1 / train_freqs, dtype=torch.float32)
class_weights

# 2 Create classifier

In [None]:
# Reuse the vocabulary's pre-trained vectors when they exist; otherwise build a
# classifier with randomly initialized embeddings of size emb_size.
has_pretrained_vectors = vocab.vectors is not None
classifier = (Classifier.from_pre_trained(vocab, class_count) if has_pretrained_vectors
              else Classifier.from_random(len(vocab), emb_size, class_count))

# Turn on the classifier module's debug flag. Presumably this makes the forward pass
# record intermediate values such as debug['atts_batch'], which the final cell plots.
debug['enabled'] = True
print(classifier)

# 3 Training

In [None]:
# Multi-label loss: an independent sigmoid + BCE term per class. pos_weight scales
# each class's positive term by class_weights.
# criterion = MSELoss()
criterion = BCEWithLogitsLoss(pos_weight=class_weights)

# optimizer = SGD(classifier.parameters(), lr=lr)
optimizer = Adam(classifier.parameters(), lr=lr)

writer = SummaryWriter(log_dir=log_dir)

for epoch in range(epoch_count):

    #
    # Train
    #

    # Training mode; relevant if the classifier contains dropout/batch-norm layers
    # (assumes Classifier is a torch.nn.Module — TODO confirm)
    classifier.train()

    # Train loss summed over batches (turned into a per-batch mean below)
    train_loss = 0.0

    # Train gt/pred classes across all batches
    train_gt_classes_stack: List[List[int]] = []
    train_pred_classes_stack: List[List[int]] = []

    for _, sents_batch, gt_classes_batch in tqdm(train_loader):
        logits_batch = classifier(sents_batch)

        loss = criterion(logits_batch, gt_classes_batch.float())
        train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # logit > 0  <=>  sigmoid probability > 0.5
        pred_classes_batch = (logits_batch > 0).int()

        train_gt_classes_stack += gt_classes_batch.numpy().tolist()
        train_pred_classes_stack += pred_classes_batch.numpy().tolist()

    #
    # Validate
    #

    # Evaluation mode: disables dropout / uses running batch-norm statistics
    classifier.eval()

    # Valid loss summed over batches (turned into a per-batch mean below)
    valid_loss = 0.0

    # Valid gt/pred classes across all batches
    valid_gt_classes_stack: List[List[int]] = []
    valid_pred_classes_stack: List[List[int]] = []

    with torch.no_grad():
        # NOTE: ent_batch, sents_batch, gt_classes_batch, logits_batch and
        # pred_classes_batch of the LAST validation batch stay in scope on purpose;
        # the visualization cell at the end of the notebook reads them.
        for ent_batch, sents_batch, gt_classes_batch in tqdm(valid_loader):
            logits_batch = classifier(sents_batch)

            loss = criterion(logits_batch, gt_classes_batch.float())
            valid_loss += loss.item()

            pred_classes_batch = (logits_batch > 0).int()

            valid_gt_classes_stack += gt_classes_batch.numpy().tolist()
            valid_pred_classes_stack += pred_classes_batch.numpy().tolist()

    #
    # Log
    #

    print(f'Epoch {epoch}')

    train_loss /= len(train_loader)
    valid_loss /= len(valid_loader)

    writer.add_scalars('loss', {'train': train_loss, 'valid': valid_loss}, epoch)

    # Macro averages = unweighted mean of the per-class scores
    # (tp = train precision, vp = valid precision, etc.)
    tp = precision_score(train_gt_classes_stack, train_pred_classes_stack, average='macro')
    vp = precision_score(valid_gt_classes_stack, valid_pred_classes_stack, average='macro')
    tr = recall_score(train_gt_classes_stack, train_pred_classes_stack, average='macro')
    vr = recall_score(valid_gt_classes_stack, valid_pred_classes_stack, average='macro')
    tf = f1_score(train_gt_classes_stack, train_pred_classes_stack, average='macro')
    vf = f1_score(valid_gt_classes_stack, valid_pred_classes_stack, average='macro')

    writer.add_scalars('precision', {'train': tp, 'valid': vp}, epoch)
    writer.add_scalars('recall', {'train': tr, 'valid': vr}, epoch)
    writer.add_scalars('f1', {'train': tf, 'valid': vf}, epoch)

# Flush pending events to disk and release the event file
writer.close()

# 4 Calc top class-word attentions

In [None]:
# Score every (class, token) pair by the dot product of the class embedding with
# the token embedding. This is pure analysis, so it runs under no_grad instead of
# building an autograd graph on the trained parameters.
with torch.no_grad():
    class_embs = classifier.class_embs
    tok_embs = classifier.embedding_bag.weight

    # (classes, emb) x (vocab, emb) -> (classes, vocab)
    tok_atts = torch.einsum('ce, ve -> cv', class_embs, tok_embs)

# Sort each class's row so the best-matching tokens come first
result = tok_atts.sort(descending=True)
indices = result.indices.numpy()
values = result.values.numpy()

# Human-readable class names; assumed to follow the dataset's class order — TODO confirm
class_labels = ['male', 'married', 'American', 'actor']

# Print the top_k highest-scoring vocabulary tokens per class
top_k = 10
for c, c_lbl in enumerate(class_labels[:class_count]):
    print(c_lbl)
    for tok, val in zip(indices[c][:top_k], values[c][:top_k]):
        print('\t{} ({:.2f})'.format(vocab.itos[tok], val))

# 5 Visualize class-word attentions in sentences

In [None]:
# Rescale the attentions for coloring. Divide by the largest absolute value (so
# scores lie in [-1, 1]), stretch by 512 and shift by 128. get_color clamps to
# [0, 255], so only scores within roughly ±0.25 of the extreme get distinct colors.
att_scale = max(-tok_atts.min(), tok_atts.max())
tok_atts = tok_atts / att_scale * 512 + 128

In [None]:
def get_color(att: float) -> str:
    """Map an attention score to a semi-transparent CSS ``rgba(...)`` color.

    The score is clamped to [0, 255] and truncated to an integer. That integer
    indexes matplotlib's 256-entry 'viridis' colormap.

    Args:
        att: Score on the 0..255 scale; values outside are clamped. 0-d tensors
            are accepted too, since the value is converted with ``float()``.

    Returns:
        A CSS color string of the form ``'rgba(r, g, b, 0.5)'``.
    """
    index = int(max(min(float(att), 255.0), 0.0))
    r, g, b = pyplot.get_cmap('viridis').colors[index]

    # Colormap channels are floats in [0, 1]; scale by 255 (not 256) so a channel
    # value of 1.0 maps to the valid maximum 255 instead of an out-of-range 256.
    return f'rgba({int(r * 255)}, {int(g * 255)}, {int(b * 255)}, 0.5)'


def render_sent(class_: int, sent: List[int]) -> str:
    """Render a sentence as HTML, coloring each token by its attention to a class.

    Args:
        class_: Row of ``tok_atts`` to use, i.e. the class whose attentions are shown.
        sent: Token ids; id 0 (the padding id used by the collate function) is
            shown as ``_``.

    Returns:
        Space-separated ``<span>`` elements, one per token.
    """
    from html import escape

    spans = []
    for tok in sent:
        # Escape the token text. Vocabulary entries such as '<unk>' (if present)
        # would otherwise be parsed as HTML tags and vanish from the table.
        word = escape(vocab.itos[tok]) if tok != 0 else '_'
        color = get_color(tok_atts[class_][tok])
        spans.append(f'<span style="background-color: {color}">{word}</span>')

    return ' '.join(spans)


# Jinja template for the (class x sentence) attention table; compiled once at
# definition time and rendered on every render_table() call.
_ATTS_TABLE_TEMPLATE = Template('''
        <style>
            table.atts td { text-align: left }
        </style>

        <table class='atts'>
            <tr>
                <th></th>

                {% for i in range(len(sents)) %}
                <th> Sent {{ i }} </th>
                {% endfor %}
            </tr>

            {% for c in range(len(class_labels)) %}
            <tr>
                <th>{{ class_labels[c] }}</th>

                {% for sent in sents %}
                <td>{{ render_sent(c, sent) }}</td>
                {% endfor %}
            </tr>
            {% endfor %}
        </table>
    ''')


def render_table(sents: List[List[int]]) -> None:
    """Display an HTML table with one row per class and one column per sentence.

    Each cell holds ``render_sent(class, sentence)``. The table is shown via
    IPython's ``display`` and nothing is returned.

    Args:
        sents: Token-id sequences of one entity's sentences.
    """
    markup = _ATTS_TABLE_TEMPLATE.render(
        len=len,  # Jinja has no ``len`` global, so pass the builtin in
        render_sent=render_sent,
        class_labels=class_labels,
        sents=sents,
    )
    display(HTML(markup))

In [None]:
from html import escape

ryn_dir = RynDir('Ryn Directory', ryn_dir_path)
ryn_dir.check()
rid_to_label = ryn_dir.split_dir.entity_labels_txt.load_rid_to_label()

# Axis labels for the attention plot: classes (rows) and sentences (columns)
class_labels = ['male', 'married', 'American', 'actor']
sent_labels = [f'sent {s}' for s in range(sent_count)]

# ent_batch, sents_batch, logits_batch, pred_classes_batch, gt_classes_batch and
# debug['atts_batch'] still hold the LAST validation batch from the training cell.
# That batch may contain fewer than 20 entities, so cap the number of examples.
example_count = min(20, len(ent_batch))

for i in range(example_count):

    ent = ent_batch[i].item()

    display(HTML('<h1>{} ({})</h1>'.format(escape(str(rid_to_label[ent])), ent)))

    # Escape token text so vocabulary entries like '<unk>' stay visible as HTML
    texts = [' '.join([escape(vocab.itos[tok]) if tok != 0 else '_' for tok in tok_list])
             for tok_list in sents_batch[i]]

    display(HTML(Template('''
        <ul>
            {% for text in texts %}
                <li> {{ text }} </li>
            {% endfor %}
        </ul>
    ''').render(texts=texts)))

    print('male', 'married', 'American', 'actor')
    print('logits =', logits_batch[i].detach().numpy())
    print('pred =', pred_classes_batch[i].detach().numpy())
    print('gt =', gt_classes_batch[i].detach().numpy())

    # debug['atts_batch'] is presumably filled by the classifier's forward pass while
    # debug['enabled'] is True — TODO confirm in classifier.py
    plot_tensor(debug['atts_batch'][i], 'atts', [class_labels, sent_labels])

    # render_table displays the table itself and returns None. Wrapping that result
    # in HTML(...) would only print a stray '<IPython.core.display.HTML object>'.
    render_table(sents_batch[i])