In [1]:
from google.colab import drive
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


## 1. Import packages

In [2]:
from torchtext import data
from torchtext import datasets
from torchtext.vocab import GloVe
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch
from tqdm import tqdm
import pandas as pd
import numpy as np

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 32
epochs = 10
embedding_dim = 300
hidden_dim = 200
max_seq_length = 64

## 2. Bulid dataloader and vocab

In [4]:
train_df = pd.read_csv("/content/drive/My Drive/SST-2/data/train.tsv",sep='\t',header=None, names=['similarity','s1'])
dev_df = pd.read_csv("/content/drive/My Drive/SST-2/data/dev.tsv",sep='\t',header=None, names=['similarity','s1'])
test_df = pd.read_csv("/content/drive/My Drive/SST-2/data/test.tsv",sep='\t',header=None, names=['similarity','s1'])

# define Field
tokenize = lambda x: x.split()
TEXT = data.Field(sequential=True, tokenize=tokenize, lower=True, fix_length=max_seq_length)
LABEL = data.Field(sequential=False, use_vocab=False)

# get_dataset constructs and returns the examples and fields required by the Dataset
def get_dataset(csv_data, text_field, label_field, test=False):
    fields = [('id', None), ('s1', text_field), ('similarity', label_field)]
    examples = []  
    for text, label in tqdm(zip(csv_data['s1'], csv_data['similarity'])):
      examples.append(data.Example.fromlist([None, text, label], fields))
    return examples, fields

# Get the examples and fields needed to build the Dataset
train_examples, train_fields = get_dataset(train_df, TEXT, LABEL)
valid_examples, valid_fields = get_dataset(dev_df, TEXT, LABEL)
test_examples, test_fields = get_dataset(test_df, TEXT, LABEL)

# Build Dataset
train = data.Dataset(train_examples, train_fields)
valid = data.Dataset(valid_examples, valid_fields)
test = data.Dataset(test_examples, test_fields)

6920it [00:00, 84243.02it/s]
872it [00:00, 79793.90it/s]
1821it [00:00, 22474.31it/s]


In [5]:
# build the vocabulary
TEXT.build_vocab(train, vectors=GloVe(name='6B', dim=embedding_dim))
LABEL.build_vocab(train)

In [6]:
from torchtext.data import Iterator, BucketIterator

# make splits for data
train_iter, test_iter = BucketIterator.splits(
        (train, test), 
        batch_size=batch_size, 
        device="cuda:0", 
        sort_key=lambda x: len(x.s1),
        sort_within_batch=True,
        repeat=False 
)

valid_iter = data.BucketIterator(dataset=valid, batch_size=batch_size, device="cuda:0", 
        sort_key=lambda x: len(x.s1), shuffle=True, sort_within_batch=True, repeat=False)

## 3. Define the BiLSTM_Attention Model

In [7]:
class SelfAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.projection = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(True),
            nn.Linear(64, 1)
        )

    def forward(self, encoder_outputs):
        batch_size = encoder_outputs.size(0)
        # (B, L, H) -> (B , L, 1)
        energy = self.projection(encoder_outputs)
        weights = F.softmax(energy.squeeze(-1), dim=1)
        # (B, L, H) * (B, L, 1) -> (B, H)
        outputs = (encoder_outputs * weights.unsqueeze(-1)).sum(dim=1)
        return outputs, weights

class AttnClassifier(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim):
        super().__init__()
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True)
        self.attention = SelfAttention(hidden_dim)
        self.fc = nn.Linear(hidden_dim, 1)
        
    def set_embedding(self, vectors):
        self.embedding.weight.data.copy_(vectors)

    def dropout(self, v):
        return F.dropout(v, p=0.5, training=self.training)

    def forward(self, inputs, lengths):
        batch_size = inputs.size(1)
        # (L, B)
        embedded = self.embedding(inputs)
        embedded = self.dropout(embedded)
        # (L, B, E)
        packed_emb = nn.utils.rnn.pack_padded_sequence(embedded, lengths)
        out, hidden = self.lstm(packed_emb)
        out = nn.utils.rnn.pad_packed_sequence(out)[0]
        out = out[:, :, :self.hidden_dim] + out[:, :, self.hidden_dim:]
        # (L, B, H)
        embedding, attn_weights = self.attention(out.transpose(0, 1))
        # (B, HOP, H)
        outputs = self.fc(embedding.view(batch_size, -1))
        # (B, 1)
        return outputs, attn_weights

In [8]:
def get_length(x):
    length = []
    for i in x.transpose(0, 1).cpu().tolist():
        length.append(len(i)-i.count(1))
    return length

def train(train_iter, model, optimizer, criterion):
    model.train()
    epoch_loss = 0
    bar = tqdm(total=len(train_iter))
    b_ix = 1
    for batch in train_iter:
        x, y = batch.s1, batch.similarity
        optimizer.zero_grad()
        outputs, _ = model(x, get_length(x))
        loss = criterion(outputs.view(-1), y.float())
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        if b_ix % 10 == 0:
            bar.update(10)
            bar.set_description('current loss:{:.4f}'.format(epoch_loss / b_ix))
        b_ix += 1
    bar.update((b_ix - 1) % 10)
    bar.close()
    return epoch_loss / len(train_iter)

In [9]:
def binary_accuracy(preds, y):
    # round predictions to the closest integer
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float()  # convert into float for division
    acc = correct.sum() / len(correct)
    return acc

def accuracy(model, test_iter):
    model.eval()
    total_acc = 0
    for i, batch in enumerate(test_iter):
        x, y = batch.s1, batch.similarity
        outputs, _ = model(x, get_length(x))
        total_acc += binary_accuracy(outputs.view(-1), y.float()).item()
    return total_acc / len(test_iter)

## 4. Model training and testing

In [22]:
model = AttnClassifier(len(TEXT.vocab), embedding_dim, hidden_dim).to(device)
model.set_embedding(TEXT.vocab.vectors)
# optim
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0)
criterion = nn.BCEWithLogitsLoss().to(device)

# train the model and stop when the accuracy of dev set is not imporved
patience = 1
patience_counter = 0
best_score = 0.0
for epoch in range(epochs):
    train(train_iter, model, optimizer, criterion)
    dev_accuracy = accuracy(model, valid_iter)
    
    if dev_accuracy < best_score:
      patience_counter += 1
    else:
      best_score = dev_accuracy
      patience_counter = 0

    if patience_counter >= patience:
      print("-> Early stopping: patience limit reached, stopping...")
      break

current loss:0.5113: 100%|██████████| 217/217 [00:01<00:00, 121.66it/s]
current loss:0.3593: 100%|██████████| 217/217 [00:01<00:00, 120.31it/s]
current loss:0.2568: 100%|██████████| 217/217 [00:01<00:00, 120.07it/s]


-> Early stopping: patience limit reached, stopping...


In [23]:
# test the model
print("test accuracy: {}".format(accuracy(model, test_iter)))

test accuracy: 0.8547527225393998


## 5. Attention visualization

In [24]:
def highlight(word, attn):
    html_color = '#%02X%02X%02X' % (255, int(255*(1 - attn)), int(255*(1 - attn)))
    return '<span style="background-color: {}">{}</span>'.format(html_color, word)

def mk_html(seq, attns):
    html = ""
    for ix, attn in zip(seq, attns):
        html += ' ' + highlight(
            TEXT.vocab.itos[ix],
            attn
        )
    return html + "<br><br>\n"

In [25]:
from IPython.display import HTML, display
with torch.no_grad():
    for batch in test_iter:
        x, y = batch.s1, batch.similarity
        y = batch.similarity
        outputs, attn_weights = model(x, get_length(x))
        # show the correctly classified sentences of the first batch
        for i in range(batch_size):
            if torch.round(F.sigmoid(outputs[i])) == y[i].float():
                # print(attn_weights[i].cpu().numpy())
                text = mk_html(x.t()[i].cpu().numpy(), attn_weights[i].cpu().numpy())
                display(HTML(text))
            break



the sentences of each batch are sorted by its length, so the output are also sorted.  
Actually, we can see from the visualized result that the bilstm_attention model capture the key words to identify if the sentiment of the sentence is positive.