Feb 23

Notebook 1, "Simple Positive Case", does not deliver very helpful results.
It is not clear whether the prepared classifier weights represent the
state the classifier would actually reach through training.

A better test might be to start out with a classifier with random weights
and try to overfit it on a single sample. The expected result would be
learned weights that are similar to the ones prepared in notebook 1.

# Set up helpers

In [None]:
from IPython.lib.pretty import pretty

def log_tensor(tensor, title, labels, vmin=None, vmax=None):
    """Accept the tensor-logging arguments and do nothing.

    This is a no-op stub: every argument is ignored.

    NOTE(review): presumably the real implementation is provided by
    ``util.ipynb``, which is run right after this definition -- confirm.

    Args:
        tensor: Tensor to log.
        title: Name under which the tensor is logged.
        labels: List of label lists; call sites in this notebook pass one
            list per tensor dimension.
        vmin: Optional lower bound; presumably a display range limit
            -- TODO confirm against the real helper.
        vmax: Optional upper bound; presumably a display range limit
            -- TODO confirm against the real helper.

    Returns:
        None.
    """
    return None

%run util.ipynb

# 1 Input data

In [None]:
# Two identical copies of the same entity. Each copy is built from its own
# dict and list literals, so the entries do not share any mutable objects.
data = [
    {
        'ent': 123,
        'classes': [1, 1, 1],
        'sents': [
            'married married married',
            'male male male',
            'American American American',
        ],
    }
    for _ in range(2)
]

# 2 Pre-processing

## 2.1 Build vocabulary

In [None]:
from collections import Counter
from torchtext.vocab import Vocab

def tokenize(text):
    """Split ``text`` into tokens on runs of whitespace.

    Args:
        text: String to tokenize.

    Returns:
        List of the whitespace-separated tokens, in order; empty when
        ``text`` contains no non-whitespace characters.
    """
    tokens = text.split()
    return tokens

# Gather every token of every sentence of every entity, in data order.
words = [
    token
    for entity in data
    for sentence in entity['sents']
    for token in tokenize(sentence)
]

# Build the vocabulary from the token frequencies.
vocab = Vocab(Counter(words))

# Show the resulting token -> index mapping.
print(pretty(vocab.stoi))

## 2.2 Transform data

Map words to tokens and create tensors.

In [None]:
import torch
from torch import tensor

# Map every word to its vocabulary index: one row of token ids per sentence,
# one group of sentences per entity.
#
# > sents_batch    (batch_size, sent_count, sent_len)
# > classes_batch  (batch_size, number of classes)
sents_batch = torch.tensor([[[vocab[word] for word in tokenize(sent)] for sent in ent['sents']] for ent in data])
classes_batch = torch.tensor([ent['classes'] for ent in data])

# Every entity needs both its sentences and its class flags.
assert len(sents_batch) == len(classes_batch)

# Read the dimensions off the tensor instead of hard-coding them, so they
# cannot drift from `data` (later reshapes rely on these values).
batch_size, sent_count, sent_len = sents_batch.shape

# Axis labels for logging, one list per tensor dimension.
ent_labels = [f'ent {i}' for i in range(batch_size)]
sent_labels = [f'sent {i}' for i in range(sent_count)]
tok_labels = [f'tok {i}' for i in range(sent_len)]

log_tensor(sents_batch, 'sents_batch', [ent_labels, sent_labels, tok_labels])

# 3 Prepare classifier

## 3.1 Prepare EmbeddingBag

Create and prepare an `EmbeddingBag` with randomly initialized token embeddings.

In [None]:
from torch.nn import EmbeddingBag
from torch import tensor

vocab_size = len(vocab)
# Sanity check for this toy data set; presumably the three distinct words in
# `data` plus the vocabulary's special tokens -- TODO confirm.
assert vocab_size == 5

emb_size = 4

# One embedding row per vocabulary index.
embedding_bag = EmbeddingBag(num_embeddings=vocab_size, embedding_dim=emb_size)

log_output = 1

# Row i of the embedding matrix belongs to the token with vocabulary index i,
# so the row labels are taken from the vocabulary's own index -> token list.
# A hand-written list can silently disagree with the indices the vocabulary
# actually assigned and mislabel the logged rows.
word_labels = list(vocab.itos)
emb_labels = [f'emb {i}' for i in range(emb_size)]

if log_output:
    log_tensor(embedding_bag.weight, 'embedding_bag.weight', [word_labels, emb_labels])

## 3.2 Prepare class embeddings

Create randomly initialized class embeddings.

In [None]:
class_count = 3

# One label per class, used as row labels when logging.
class_labels = ['married', 'male', 'American']

# One randomly initialized embedding per class; requires_grad=True so that
# gradients are tracked for it.
class_embs = torch.rand(class_count, emb_size, requires_grad=True)

log_output = 1

if log_output:
    log_tensor(class_embs, 'class_embs', [class_labels, emb_labels])

## 3.3 Prepare linear layer

Create a randomly initialized linear layer.

In [None]:
from torch.nn import Linear

# Input: one value per (class, embedding dimension) pair.
# Output: one value per class.
linear = Linear(in_features=class_count * emb_size, out_features=class_count)

# Column labels of the weight matrix, one per (mix, embedding dimension) pair.
mix_emb_labels = [
    f'mix {i} / emb {j}'
    for i in range(class_count)
    for j in range(emb_size)
]

log_output = 1

if log_output:
    log_tensor(linear.weight.data, 'linear.weight.data', [class_labels, mix_emb_labels])
    log_tensor(linear.bias.data, 'linear.bias.data', [class_labels])

# 4 Forward & Backward

In [None]:
#
# 4.1 Embed sentences
#

#
# Collapse the entity and sentence axes into a single axis
#
# < sents_batch  (batch_size, sent_count, sent_len)
# > flat_sents   (batch_size * sent_count, sent_len)
#

ent_sent_labels = [
    f'ent {i} / sent {j}'
    for i in range(batch_size)
    for j in range(sent_count)
]

flat_sents = torch.reshape(sents_batch, (batch_size * sent_count, sent_len))

log_input = log_output = 0

if log_input:
    log_tensor(sents_batch, 'sents_batch', [ent_labels, sent_labels, tok_labels])

if log_output:
    log_tensor(flat_sents, 'flat_sents', [ent_sent_labels, tok_labels])

#
# Turn each sentence (row of token ids) into one embedding
#
# < embedding_bag.weight  (vocab_size, emb_size)
# < flat_sents            (batch_size * sent_count, sent_len)
# > flat_sent_embs        (batch_size * sent_count, emb_size)
#

flat_sent_embs = embedding_bag(flat_sents)

log_input = log_output = 0

if log_input:
    log_tensor(embedding_bag.weight, 'embedding_bag.weight', [word_labels, emb_labels])
    log_tensor(flat_sents, 'flat_sents', [ent_sent_labels, tok_labels])

if log_output:
    log_tensor(flat_sent_embs, 'flat_sent_embs', [ent_sent_labels, emb_labels])

#
# Split the flat axis back into entity and sentence axes
#
# < flat_sent_embs   (batch_size * sent_count, emb_size)
# > sent_embs_batch  (batch_size, sent_count, emb_size)
#

sent_embs_batch = torch.reshape(flat_sent_embs, (batch_size, sent_count, emb_size))

log_input, log_output = 0, 1

if log_input:
    log_tensor(flat_sent_embs, 'flat_sent_embs', [ent_sent_labels, emb_labels])

if log_output:
    log_tensor(sent_embs_batch, 'sent_embs_batch', [ent_labels, sent_labels, emb_labels])

#
# 4.2 Calc attentions
#

#
# Give every entity in the batch its own view of the class embeddings,
# as bmm() needs a batch axis on both operands
#
# < class_embs        (class_count, emb_size)
# > class_embs_batch  (batch_size, class_count, emb_size)
#

class_embs_batch = class_embs.unsqueeze(0).expand(batch_size, class_count, emb_size)

log_input = log_output = 0

if log_input:
    log_tensor(class_embs, 'class_embs', [class_labels, emb_labels])

if log_output:
    log_tensor(class_embs_batch, 'class_embs_batch', [ent_labels, class_labels, emb_labels])

#
# Dot product of each class embedding with each sentence embedding
#
# < class_embs_batch    (batch_size, class_count, emb_size)
# < sent_embs_batch     (batch_size, sent_count, emb_size)
# > atts_batch          (batch_size, class_count, sent_count)
#

# Swap the last two axes so the embedding axis is the one being contracted.
atts_batch = class_embs_batch.bmm(sent_embs_batch.transpose(-2, -1))

log_input = log_output = 0

if log_input:
    log_tensor(class_embs_batch, 'class_embs_batch', [ent_labels, class_labels, emb_labels])
    log_tensor(sent_embs_batch, 'sent_embs_batch', [ent_labels, sent_labels, emb_labels])

if log_output:
    log_tensor(atts_batch, 'atts_batch', [ent_labels, class_labels, sent_labels])

#
# Normalize the attentions of each class over the sentence axis
#
# < atts_batch      (batch_size, class_count, sent_count)
# > softs_batch     (batch_size, class_count, sent_count)
#

from torch.nn import Softmax

softs_batch = Softmax(dim=-1)(atts_batch)

log_input, log_output = 0, 1

if log_input:
    log_tensor(atts_batch, 'atts_batch', [ent_labels, class_labels, sent_labels])

if log_output:
    log_tensor(softs_batch, 'softs_batch', [ent_labels, class_labels, sent_labels])

#
# 4.3 Weight and mix sentences
#

#
# Give every class its own view of the entity's sentence embeddings
#
# < sent_embs_batch     (batch_size, sent_count, emb_size)
# > expaned_batch       (batch_size, class_count, sent_count, emb_size)
#

expaned_batch = sent_embs_batch[:, None].expand(batch_size, class_count, sent_count, emb_size)

log_input = log_output = 0

if log_input:
    log_tensor(sent_embs_batch, 'sent_embs_batch', [ent_labels, sent_labels, emb_labels])

if log_output:
    log_tensor(expaned_batch, 'expaned_batch', [ent_labels, class_labels, sent_labels, emb_labels])

#
# Collapse the entity and class axes, as bmm() takes 3-D operands
#
# < expaned_batch   (batch_size, class_count, sent_count, emb_size)
# > flat_expanded   (batch_size * class_count, sent_count, emb_size)
#

ent_class_labels = [
    f'ent {i} / class {j}'
    for i in range(batch_size)
    for j in range(class_count)
]

flat_expanded = torch.reshape(expaned_batch, (-1, sent_count, emb_size))

log_input = log_output = 0

if log_input:
    log_tensor(expaned_batch, 'expaned_batch', [ent_labels, class_labels, sent_labels, emb_labels])

if log_output:
    log_tensor(flat_expanded, 'flat_expanded', [ent_class_labels, sent_labels, emb_labels])

#
# Collapse the same two axes of the attentions and add a trailing unit axis,
# turning each attention row into a column vector for bmm()
#
# < softs_batch     (batch_size, class_count, sent_count)
# > flat_softs      (batch_size * class_count, sent_count, 1)
#

flat_softs = softs_batch.reshape(batch_size * class_count, sent_count, 1)

log_input = log_output = 0

if log_input:
    log_tensor(softs_batch, 'softs_batch', [ent_labels, class_labels, sent_labels])

if log_output:
    log_tensor(flat_softs, 'flat_softs', [ent_class_labels, sent_labels, ['']])

#
# Attention-weighted sum of the sentence embeddings per (entity, class) pair
#
# < flat_expanded   (batch_size * class_count, sent_count, emb_size)
# < flat_softs      (batch_size * class_count, sent_count, 1)
# > flat_weighted   (batch_size * class_count, emb_size)
#

# (emb_size, sent_count) @ (sent_count, 1) -> (emb_size, 1); drop the unit axis.
flat_weighted = flat_expanded.transpose(1, 2).bmm(flat_softs).squeeze(-1)

log_input = log_output = 0

if log_input:
    log_tensor(flat_expanded, 'flat_expanded', [ent_class_labels, sent_labels, emb_labels])
    log_tensor(flat_softs, 'flat_softs', [ent_class_labels, sent_labels, ['']])

if log_output:
    log_tensor(flat_weighted, 'flat_weighted', [ent_class_labels, emb_labels])

#
# Split the flat axis back into entity and class axes
#
# < flat_weighted   (batch_size * class_count, emb_size)
# > weighted_batch  (batch_size, class_count, emb_size)
#

weighted_batch = torch.reshape(flat_weighted, (batch_size, class_count, emb_size))

log_input, log_output = 0, 1

if log_input:
    log_tensor(flat_weighted, 'flat_weighted', [ent_class_labels, emb_labels])

if log_output:
    log_tensor(weighted_batch, 'weighted_batch', [ent_labels, class_labels, emb_labels])

#
# 4.4 Linear layer
#

#
# Lay the per-class mixes of each entity side by side in one vector
#
# < weighted_batch      (batch_size, class_count, emb_size)
# > concat_mixes_batch  (batch_size, class_count * emb_size)
#

mix_emb_labels = [
    f'mix {i} / emb {j}'
    for i in range(class_count)
    for j in range(emb_size)
]

concat_mixes_batch = torch.reshape(weighted_batch, (batch_size, class_count * emb_size))

log_input, log_output = 0, 1

if log_input:
    log_tensor(weighted_batch, 'weighted_batch', [ent_labels, class_labels, emb_labels])

if log_output:
    log_tensor(concat_mixes_batch, 'concat_mixes_batch', [ent_labels, mix_emb_labels])

#
# Map the concatenated mixes to one logit per class
#
# < concat_mixes_batch  (batch_size, class_count * emb_size)
# > logits_batch        (batch_size, class_count)
#

logits_batch = linear(concat_mixes_batch)

log_input, log_output = 0, 1

if log_input:
    log_tensor(concat_mixes_batch, 'concat_mixes_batch', [ent_labels, mix_emb_labels])

if log_output:
    log_tensor(logits_batch, 'logits_batch', [ent_labels, class_labels], vmin=-1, vmax=1)

#
# 4.5 Loss
#

#
# Compare the logits with the target classes
#
# < logits_batch   (batch_size, class_count)
# < classes_batch  (batch_size, class_count)
# > loss           scalar (logged below with an empty label list)
#

from torch.nn import BCEWithLogitsLoss

# One positive-example weight per class. It is created as a float tensor so
# the loss is computed in floating point without relying on implicit
# integer -> float promotion.
criterion = BCEWithLogitsLoss(pos_weight=torch.tensor([80.0] * class_count))
# criterion = torch.nn.MSELoss()

# The criterion expects float targets, so the integer class flags are cast.
loss = criterion(logits_batch, classes_batch.float())

log_input = 0
log_output = 1

if log_input:
    log_tensor(logits_batch, 'logits_batch', [ent_labels, class_labels], vmin=-1, vmax=1)
    log_tensor(classes_batch, 'classes_batch', [ent_labels, class_labels], vmin=-1, vmax=1)

# Honor the log_output flag like every other step in this cell.
if log_output:
    log_tensor(loss, 'loss', [])

#
# 4.6 Backward
#

from torch.optim import Adam

# Optimize every tensor that takes part in the forward pass: the token
# embeddings, the class embeddings and both parameters of the linear layer.
optimizer = Adam(
    params=[
        embedding_bag.weight,
        class_embs,
        linear.weight,
        linear.bias,
    ],
    lr=0.1,
)

# Clear existing gradients, backpropagate the loss, then apply one update.
optimizer.zero_grad()
loss.backward()
optimizer.step()