In [1]:
import os
import sys
import torch

# Root of the tutorial project, resolved from the WORKSPACE environment
# variable (a missing variable raises KeyError right here).
PROJ_DIR = os.path.join(os.environ['WORKSPACE'], 'tutorial/')

# Register the project root on the import path exactly once, so re-running
# this cell does not pile up duplicate entries. Presumably this is what makes
# the `src` package imported by later cells resolvable.
if sys.path.count(PROJ_DIR) == 0:
    sys.path.append(PROJ_DIR)

# Prepare the data

In [2]:
import pickle
from src.dataset import IMDBDatset
from src.utilities import flatten, get_dataloader

# NOTE: unpickling can execute arbitrary code; only load a data.pickle that
# you produced yourself or otherwise trust.
with open('data.pickle', 'rb') as pickle_file:
    corpus = pickle.load(pickle_file)

# One loader per split: small shuffled batches for training, larger
# unshuffled batches for the two evaluation splits.
dataloaders = dict(
    train=get_dataloader(corpus['train'], batch_size=32, shuffle=True),
    dev=get_dataloader(corpus['dev'], batch_size=128, shuffle=False),
    test=get_dataloader(corpus['test'], batch_size=128, shuffle=False),
)

# Attention layer

In [3]:
import torch.nn as nn

class AttentionLayer(nn.Module):
    """Additive attention pooling over a padded sequence.

    Every time step is scored with ``v' · tanh(W h + b)``; the scores are
    normalised by a softmax restricted to the positions selected by ``mask``
    and used to build a weighted sum of the input vectors.

    Args:
        hidden_size: size of the last dimension of the input vectors.
        attn_size: size of the intermediate attention projection.
    """
    def __init__(self, hidden_size, attn_size):
        super(AttentionLayer, self).__init__()
        self.dh = hidden_size
        self.da = attn_size

        self.W = nn.Linear(self.dh, self.da)        # hidden -> attention projection
        self.v = nn.Linear(self.da, 1)              # attention projection -> scalar score

    def forward(self, inputs, mask):
        """Pool ``inputs`` along the sequence dimension.

        Args:
            inputs: tensor of shape (batch, seq, hidden_size).
            mask: tensor of shape (batch, seq); positions whose value is 0
                receive no attention. Non-zero values multiply the
                un-normalised weight, so a 0/1 mask gives a plain masked
                softmax.

        Returns:
            dict with
              'output':    (batch, hidden_size) attention-weighted sum of inputs,
              'attention': (batch, seq) attention weights. Each row sums to 1,
                           except rows whose mask is all zeros, which are all
                           zeros (and pool to a zero vector) instead of NaN.
        """
        # Raw scores
        scores = self.v(torch.tanh(self.W(inputs)))     # (batch, seq, 1)
        weights = mask.unsqueeze(2).float()             # (batch, seq, 1)
        ignored = weights == 0

        # Masked softmax, computed stably. Ignored positions are set to -inf
        # so that exp() maps them to exactly 0.
        masked_scores = scores.masked_fill(ignored, float('-inf'))
        if masked_scores.size(1) > 0:
            # Softmax is shift-invariant: subtracting the largest in-mask
            # score of each row keeps exp() from overflowing without changing
            # the result. The shift is a constant, hence detach().
            row_max = masked_scores.max(dim=1, keepdim=True)[0].detach()
            # A fully masked row has max == -inf; shifting by it would give
            # (-inf) - (-inf) = NaN, so shift such rows by 0 instead.
            row_max = row_max.masked_fill(row_max == float('-inf'), 0.0)
            masked_scores = masked_scores - row_max

        u = weights * masked_scores.exp()               # zero outside the mask
        sums = torch.sum(u, dim=1, keepdim=True)        # only in-mask values contribute
        # A zero denominator only occurs when every weight in the row is zero;
        # divide by 1 there so the attention is all zeros rather than NaN.
        sums = sums.masked_fill(sums == 0, 1.0)
        a = u / sums                                    # distribution over in-mask positions

        # Weighted sum of the input vectors
        z = torch.sum(inputs * a, dim=1)

        return {'output': z, 'attention': a.view(inputs.size(0), inputs.size(1))}

# LSTM + Attention Classifier

In [4]:
class LSTMAttentionClassifier(nn.Module):
    """Binary sequence classifier: embed -> extract -> attend -> linear logit.

    Args:
        embedder: module called with the raw tokens; must return a dict with
            the keys 'output' and 'mask'.
        extractor: module called with (embeddings, mask); must expose a
            ``hidden_dim`` attribute and return a dict with the key 'outputs'.
        attention: module called with (features, mask); must return a dict
            with the keys 'output' and 'attention'.
    """
    def __init__(self, embedder, extractor, attention):
        super(LSTMAttentionClassifier, self).__init__()
        self.embedder = embedder
        self.extractor = extractor
        self.attention = attention
        # NOTE(review): the classifier input width is taken from
        # extractor.hidden_dim; confirm this still matches the attended
        # feature width if the extractor is ever made bidirectional.
        self.classifier = nn.Linear(extractor.hidden_dim, 1)
        self.xentropy = nn.BCEWithLogitsLoss()

    def forward(self, tokens, targets=None):
        """Score a batch and, when labels are given, compute the loss.

        Args:
            tokens: batch of inputs in whatever form ``embedder`` accepts.
            targets: optional tensor of 0/1 labels, one per sample.

        Returns:
            dict with
              'output':    the logits. They are flattened to one dimension
                           only when ``targets`` is given; otherwise they keep
                           the classifier's trailing dimension of size 1.
              'loss':      BCE-with-logits loss, or None without ``targets``.
              'attention': attention weights, detached from the autograd graph.
        """
        embedded = self.embedder(tokens)
        extracted = self.extractor(embedded['output'], embedded['mask'])
        attended = self.attention(extracted['outputs'], embedded['mask'])

        logits = self.classifier(attended['output'])
        loss = None

        if targets is not None:
            # BCEWithLogitsLoss needs logits and targets of identical shape
            # and a floating-point target.
            logits = logits.view(-1)
            targets = targets.float()
            loss = self.xentropy(logits, targets)

        # detach() instead of the discouraged .data: same values, no graph,
        # but in-place edits by a caller are still tracked by autograd.
        return {'output': logits, 'loss': loss, 'attention': attended['attention'].detach()}

In [8]:
from src.nets.embedder import WordEmbedder
from src.nets.lstm import LSTMLayer
from src.nets.classifier import LSTMClassifier

# Set of tokens taken from the train and dev splits (presumably `flatten`
# concatenates the per-sample token lists — confirm in src.utilities).
vocab = set(flatten(corpus['train'].tokens + corpus['dev'].tokens))

def create_lstm_attention_classifier(hidden_dim=64, attn_size=100):
    """Build an untrained LSTM + attention classifier.

    Args:
        hidden_dim: LSTM hidden size. The same value is used as the attention
            layer's input size, so the two can no longer drift apart.
        attn_size: size of the attention layer's intermediate projection.

    Returns:
        An LSTMAttentionClassifier wired as embedder -> unidirectional,
        single-layer LSTM -> attention.
    """
    embedder = WordEmbedder(vocab, os.path.join(PROJ_DIR, 'glove.6B/glove.6B.100d.txt'))
    lstm_layer = LSTMLayer(embedder.emb_dim, hidden_dim=hidden_dim, bidirectional=False, num_layers=1)
    attn_layer = AttentionLayer(hidden_size=hidden_dim, attn_size=attn_size)
    return LSTMAttentionClassifier(embedder, lstm_layer, attn_layer)

model = create_lstm_attention_classifier()
model

LSTMAttentionClassifier(
  (embedder): WordEmbedder(
    (embeddings): Embedding(21695, 100)
  )
  (extractor): LSTMLayer(
    (lstm): LSTM(100, 64, batch_first=True)
  )
  (attention): AttentionLayer(
    (W): Linear(in_features=64, out_features=100, bias=True)
    (v): Linear(in_features=100, out_features=1, bias=True)
  )
  (classifier): Linear(in_features=64, out_features=1, bias=True)
  (xentropy): BCEWithLogitsLoss()
)

# Training the model

In [9]:
from src.utilities import train
import torch.optim as optim

# Hyper-parameters: 'lr' and 'momentum' configure the optimizer below, and the
# whole dict is also handed to train().
config = dict(
    lr=1e-2,
    momentum=0.99,
    epochs=10,
    checkpoint='lstm_attn_model.pt',
)

# Only parameters that require gradients are given to the optimizer.
params = (p for p in model.parameters() if p.requires_grad)
optimizer = optim.SGD(params, lr=config['lr'], momentum=config['momentum'])

# train() returns a model, which replaces the freshly initialised one.
model = train(model, dataloaders, optimizer, config)

E001 [TRAIN] Loss: 0.6854, Acc: 0.5496 [DEV] Loss: 0.6502, Acc: 0.6330 [TEST] Loss: 0.6579, Acc: 0.6076 * 
E002 [TRAIN] Loss: 0.6326, Acc: 0.6387 [DEV] Loss: 0.6107, Acc: 0.6586 [TEST] Loss: 0.6001, Acc: 0.6711 * 
E003 [TRAIN] Loss: 0.5470, Acc: 0.7193 [DEV] Loss: 0.5093, Acc: 0.7510 [TEST] Loss: 0.5114, Acc: 0.7449 * 
E004 [TRAIN] Loss: 0.4996, Acc: 0.7516 [DEV] Loss: 0.4836, Acc: 0.7676 [TEST] Loss: 0.4851, Acc: 0.7572 * 
E005 [TRAIN] Loss: 0.4706, Acc: 0.7705 [DEV] Loss: 0.4716, Acc: 0.7634 [TEST] Loss: 0.4759, Acc: 0.7646
E006 [TRAIN] Loss: 0.4510, Acc: 0.7841 [DEV] Loss: 0.4915, Acc: 0.7516 [TEST] Loss: 0.4943, Acc: 0.7602
E007 [TRAIN] Loss: 0.4455, Acc: 0.7877 [DEV] Loss: 0.4902, Acc: 0.7528 [TEST] Loss: 0.4954, Acc: 0.7626
E008 [TRAIN] Loss: 0.4179, Acc: 0.8097 [DEV] Loss: 0.4738, Acc: 0.7720 [TEST] Loss: 0.4704, Acc: 0.7716 * 
E009 [TRAIN] Loss: 0.3847, Acc: 0.8232 [DEV] Loss: 0.4749, Acc: 0.7706 [TEST] Loss: 0.4720, Acc: 0.7733
E010 [TRAIN] Loss: 0.3558, Acc: 0.8433 [DEV] Loss

# Visualizing attention

In [16]:
from sklearn.metrics import accuracy_score
from src.utilities import process_logits

def attn_model_predict(model, dataset):
    """Run ``model`` over every batch of ``dataset`` and collect per-sample results.

    Puts the model in eval mode (it is left in eval mode afterwards) and
    prints the sample-weighted mean loss and the accuracy.

    Args:
        model: called as ``model(tokens, targets)``; must return a dict with
            the keys 'output' (logits), 'loss' and 'attention'.
        dataset: iterable yielding ``(tokens, targets)`` batches.

    Returns:
        Tuple ``(probs, preds, truth, attns)`` of lists with one entry per
        sample, in iteration order. Each ``attns`` entry is one row of the
        batch attention tensor.

    Raises:
        ZeroDivisionError: if ``dataset`` yields no samples.
    """
    loss = 0
    probs, preds, truth = [], [], []
    attns = []

    model.eval()
    # Inference only: skip building the autograd graph to save memory and time.
    with torch.no_grad():
        for tokens, targets in dataset:
            result = model(tokens, targets)

            batch_preds, batch_probs = process_logits(result['output'])
            # The batch loss is weighted by batch size so that the final
            # division yields a per-sample average even if the last batch
            # is smaller.
            loss += result['loss'].item() * len(batch_preds)

            preds += batch_preds
            probs += batch_probs
            truth += targets.detach().cpu().tolist()
            attns += result['attention'].cpu().tolist()

    loss /= len(truth)
    acc = accuracy_score(truth, preds)
    print("Loss: {:.5f}, Acc: {:.5f}".format(loss, acc))

    return probs, preds, truth, attns

In [18]:
# Score the dev split with the trained model. The four returned lists are in
# loader iteration order; the inspection cell below indexes them by dataset
# position, which assumes the shuffle=False dev loader preserves dataset
# order — TODO confirm against get_dataloader.
probs, preds, truth, attns = attn_model_predict(model, dataloaders['dev'])

Loss: 0.47378, Acc: 0.77200


In [41]:
# Print the per-token attention of the first dev sample shorter than 25
# tokens, followed by a short summary of the prediction for that sample.
for i, tokens in enumerate(dataloaders['dev'].dataset.tokens):
    if len(tokens) >= 25:
        continue

    # zip stops at the shorter sequence, so at most len(tokens) weights are
    # read from attns[i], even if that row is longer.
    attn_total = 0
    for token, attn in zip(tokens, attns[i]):
        print('{:.5f} {}'.format(attn, token))
        attn_total += attn

    summary = [
        f'\nSample {i}',
        f'Pred: {preds[i]} (Prob: {probs[i]:.5f})',
        f'Truth: {truth[i]}',
        f'Attn: {attn_total:.5f}',
        f'Size: {len(tokens)}',
    ]
    print('\n'.join(summary))
    break

0.00749 The
0.01684 characters
0.00437 are
0.02744 unlikeable
0.02073 and
0.00791 the
0.08742 script
0.03646 is
0.29595 awful
0.10089 .
0.04331 It
0.01325 's
0.00459 a
0.11016 waste
0.06131 of
0.02121 the
0.04267 talents
0.03129 of
0.01442 Deneuve
0.02232 and
0.01114 Auteuil
0.01881 .

Sample 1767
Pred: 0.0 (Prob: 0.00752)
Truth: 0
Attn: 1.00000
Size: 22
