Feb 24

The training in notebook _03_overfit_and_test seems to work: the classifier
overfits the training data and predicts its classes correctly when the training
data is also used for testing.

Next, the classifier code should be turned into a PyTorch module that is trained
in a training + validation loop, as is the case in the code base. The loss
curves should be plotted (via TensorBoard) as well.

The expected result is a loss curve similar to the one produced by the code
base, i.e., the validation loss should drop until around the 5th epoch, after
which it should rise again.

Hopefully, this notebook provides the means to debug and understand why the
validation loss curve does not drop further and, thus, why the classifier does
not yield better test results.

# Imports

In [None]:
%load_ext autoreload
%autoreload 2

from collections import Counter

import torch
from IPython.lib.pretty import pretty
from torch import tensor
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from torchtext.vocab import Vocab

from notebooks._04_class_word_attentions import util
from notebooks._04_class_word_attentions.classifier import Classifier
from notebooks._04_class_word_attentions.util import log_tensor, get_ent_lbls, get_sent_lbls, get_tok_lbls, \
    get_emb_lbls, get_class_lbls, get_mix_emb_lbls, get_word_lbls

# 1 Define train/valid data

In [None]:
# Sizes used by this notebook. They are also pushed into `util`
# (presumably for the axis labels produced by log_tensor -- TODO confirm).
util.batch_size = batch_size = 2
util.class_count = class_count = 3
util.emb_size = emb_size = 4
util.sent_count = sent_count = 3
util.sent_len = sent_len = 3

# Each entity has one binary label per class, in the order
# (married, male, American), plus a list of sentences about it.
train_data = [
    {
        'classes': [1, 1, 1],
        'sents': [
            'married married married',
            'male male male',
            'American American American'
        ]
    },
    {
        'classes': [0, 0, 0],
        'sents': [
            'single single single',
            'female female female',
            'German German German'
        ]
    },
]

valid_data = [
    {
        'classes': [1, 1, 1],  # married, male, American
        'sents': [
            'Barack is married',
            'Barack is male',
            'Barack is American'
        ]
    },
    {
        'classes': [1, 0, 0],  # married, male, American
        'sents': [
            'Angela is married',
            'Angela is female',
            'Angela is German'
        ]
    }
]

# The sizes above are hard-coded independently of the data. Fail fast if they
# drift apart. Each dataset is fed as a single batch, so its entity count must
# equal batch_size. Whitespace splitting mirrors `tokenize` in section 2.1.
for data_name, data in [('train_data', train_data), ('valid_data', valid_data)]:
    assert len(data) == batch_size, f'{data_name}: expected {batch_size} entities, got {len(data)}'
    for ent in data:
        assert len(ent['classes']) == class_count, f'{data_name}: expected {class_count} class labels per entity'
        assert len(ent['sents']) == sent_count, f'{data_name}: expected {sent_count} sentences per entity'
        assert all(len(sent.split()) == sent_len for sent in ent['sents']), \
            f'{data_name}: expected {sent_len} words per sentence'

# 2 Pre-processing

## 2.1 Build vocabulary from train data

In [None]:
def tokenize(text):
    """Split a sentence into its words, using whitespace as the separator."""
    return text.split()


# Collect every word of every training sentence, in order, and count them
train_words = []
for ent in train_data:
    for sent in ent['sents']:
        train_words.extend(tokenize(sent))

vocab = Vocab(Counter(train_words))

print(pretty(vocab.stoi))

vocab_size = len(vocab)
util.vocab = vocab
util.vocab_size = vocab_size

## 2.2 Transform train/valid data

Map words to vocabulary indices and create tensors.

In [None]:
def encode_sents(data):
    """Map each word of each entity's sentences to its vocabulary index.

    Returns a nested tensor indexed as (entity, sentence, token). torch.tensor
    rejects ragged nesting, so all entities need the same number of sentences
    and all sentences the same number of words.
    """
    return torch.tensor([[[vocab[word] for word in tokenize(sent)] for sent in ent['sents']] for ent in data])


def encode_classes(data):
    """Stack each entity's per-class labels into an (entity, class) tensor."""
    return torch.tensor([ent['classes'] for ent in data])


train_sents_batch = encode_sents(train_data)
train_classes_batch = encode_classes(train_data)

valid_sents_batch = encode_sents(valid_data)
valid_classes_batch = encode_classes(valid_data)

log_tensor(train_sents_batch, 'train_sents_batch', [get_ent_lbls(), get_sent_lbls(), get_tok_lbls()])
log_tensor(valid_sents_batch, 'valid_sents_batch', [get_ent_lbls(), get_sent_lbls(), get_tok_lbls()])

# 3 Create classifier

In [None]:
# Fix the RNG in case Classifier initialises its parameters randomly, so the
# initial weights (and therefore the loss curves) are reproducible on
# Restart & Run All.
torch.manual_seed(0)

classifier = Classifier(vocab_size, emb_size, class_count)

# Inspect the initial parameters. `.detach()` alone already gives a
# gradient-free view; the extra `.data` was redundant.
log_tensor(classifier.embedding_bag.weight.detach(), 'classifier.embedding_bag.weight', [get_word_lbls(), get_emb_lbls()])
log_tensor(classifier.class_embs.detach(), 'classifier.class_embs', [get_class_lbls(), get_emb_lbls()])
log_tensor(classifier.linear.weight.detach(), 'classifier.linear.weight.data', [get_class_lbls(), get_mix_emb_lbls()])
log_tensor(classifier.linear.bias.detach(), 'classifier.linear.bias.data', [get_class_lbls()])

# 4 Training

In [None]:
# Load the TensorBoard notebook extension and embed its UI, reading event files
# under ./runs. The SummaryWriter() in the training cell is created without a
# log_dir, so presumably it writes below ./runs -- verify if no curves appear.
%load_ext tensorboard
%tensorboard --logdir runs

In [None]:
# criterion = MSELoss()
criterion = BCEWithLogitsLoss()
# criterion = BCEWithLogitsLoss(pos_weight=torch.tensor([80] * class_count))

# optimizer = SGD(classifier.parameters(), lr=0.1)
optimizer = Adam(classifier.parameters(), lr=0.1)

epoch_count = 1000

# 0-based epochs after which losses and class/word attentions are printed
log_epochs = {0, 9, 99, epoch_count - 1}

writer = SummaryWriter()


for epoch in range(epoch_count):

    #
    # Train
    #

    # Training mode matters if the classifier contains dropout or normalisation
    # layers; it is harmless otherwise.
    classifier.train()

    train_logits_batch = classifier(train_sents_batch)
    train_loss = criterion(train_logits_batch, train_classes_batch.float())

    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()

    #
    # Validate
    #

    classifier.eval()

    with torch.no_grad():
        valid_logits_batch = classifier(valid_sents_batch)
        valid_loss = criterion(valid_logits_batch, valid_classes_batch.float())

    #
    # Log
    #

    # Log plain floats: train_loss is still attached to the autograd graph
    writer.add_scalars('loss', {'train': train_loss.item(), 'valid': valid_loss.item()}, epoch)

    if epoch in log_epochs:
        print(f'Epoch {epoch}: Train loss = {train_loss.item()}, valid loss = {valid_loss.item()}')

        # log_tensor(classifier.embedding_bag.weight.detach(), 'classifier.embedding_bag.weight', [get_word_lbls(), get_emb_lbls()])
        # log_tensor(classifier.class_embs.detach(), 'classifier.class_embs', [get_class_lbls(), get_emb_lbls()])
        # log_tensor(classifier.linear.weight.detach(), 'classifier.linear.weight.data', [get_class_lbls(), get_mix_emb_lbls()])
        # log_tensor(classifier.linear.bias.detach(), 'classifier.linear.bias.data', [get_class_lbls()])
        #
        # log_tensor(train_logits_batch.detach(), 'train_logits_batch', [get_ent_lbls(), get_class_lbls()])
        # log_tensor(train_classes_batch.detach(), 'train_classes_batch', [get_ent_lbls(), get_class_lbls()])
        #
        # log_tensor(valid_logits_batch.detach(), 'valid_logits_batch', [get_ent_lbls(), get_class_lbls()])
        # log_tensor(valid_classes_batch.detach(), 'valid_classes_batch', [get_ent_lbls(), get_class_lbls()])

        #
        # How well do class embeddings match words?
        #

        class_embs = classifier.class_embs.detach()
        tok_embs = classifier.embedding_bag.weight.detach()

        # Dot product of every class embedding with every word embedding
        atts_batch = torch.mm(class_embs, tok_embs.T)
        log_tensor(atts_batch, 'atts_batch', [get_class_lbls(), get_word_lbls()])

# Flush all pending events to disk and release the event file
writer.close()

# 5 Test

## 5.1 Define test data

In [None]:
# Same schema as train_data / valid_data: one binary label per class in the
# order (married, male, American), plus the entity's sentences.
test_data = [
    dict(
        classes=[1, 0, 1],  # married, male, American
        sents=[
            'Michelle is married',
            'Michelle is female',
            'Michelle is American',
        ],
    ),
    dict(
        classes=[1, 0, 0],  # married, male, American
        sents=[
            'Angela is married',
            'Angela is female',
            'Angela is German',
        ],
    ),
]

## 5.2 Pre-process test data

In [None]:
# Encode the test entities exactly like the train/valid batches:
# words -> vocabulary indices, labels -> integer tensor.
test_token_ids = [
    [[vocab[word] for word in tokenize(sent)] for sent in ent['sents']]
    for ent in test_data
]
test_labels = [ent['classes'] for ent in test_data]

test_sents_batch = torch.tensor(test_token_ids)
test_classes_batch = torch.tensor(test_labels)

log_tensor(test_sents_batch, 'test_sents_batch', [get_ent_lbls(), get_sent_lbls(), get_tok_lbls()])
log_tensor(test_classes_batch, 'test_classes_batch', [get_ent_lbls(), get_class_lbls()])

## 5.3 Forward test batch

In [None]:
# Evaluate in eval mode and without building an autograd graph. The logits can
# then be passed to log_tensor directly; every other log_tensor call in this
# notebook receives a detached tensor.
classifier.eval()

with torch.no_grad():
    test_logits_batch = classifier(test_sents_batch)
    test_loss = criterion(test_logits_batch, test_classes_batch.float())

print(f'Test loss = {test_loss.item()}')

# The training criterion is BCEWithLogitsLoss, so the model outputs raw logits.
# A class is predicted positive when sigmoid(logit) > 0.5.
test_preds_batch = (torch.sigmoid(test_logits_batch) > 0.5).long()

log_tensor(test_logits_batch, 'test_logits_batch', [get_ent_lbls(), get_class_lbls()])
log_tensor(test_preds_batch, 'test_preds_batch', [get_ent_lbls(), get_class_lbls()])
log_tensor(test_classes_batch, 'test_classes_batch', [get_ent_lbls(), get_class_lbls()])