In [1]:
# Weird deprecation issues in torchtext means this cell needs to be run twice
import torch
from torch import nn
from torchtext.datasets import IMDB
from torchtext.vocab import vocab
from torch.utils.data import DataLoader, random_split
import re
from collections import Counter, OrderedDict

torch.manual_seed(1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda




In [2]:
# Download data
train_dataset = IMDB(split='train')
test_dataset = IMDB(split='test')
train_dataset, val_dataset = random_split(list(train_dataset),[20000,5000])

In [3]:
# Implement tokenizer to remove unwanted tokens
def tokenizer(text):
    text = re.sub('<[^>]*>','',text)
    emoticons = re.findall('(?::|;|=)(?:-)?(?:\)|\(|D|P)',text.lower())
    text = re.sub('[\W]+',' ', text.lower() + ' '.join(emoticons).replace('-',''))
    tokenized = text.split()
    return tokenized

In [4]:
token_counts = Counter()
for label, line in train_dataset:
    tokens = tokenizer(line)
    token_counts.update(tokens)
print(f'Vocab size: {len(token_counts)}')

Vocab size: 69000


In [5]:
# Construct vocabulary from Counter in above cell and add tokens: <pad>,<unk>
sorted_tuples = sorted(token_counts.items(),key=lambda x: x[1],reverse=True)
ordered_dict = OrderedDict(sorted_tuples)
vocab = vocab(ordered_dict)
vocab.insert_token('<pad>',0)
vocab.insert_token('<unk>',1)
vocab.set_default_index(1)

In [6]:
# Collate function to load batches
#   - Tokenize each sentence
#   - Extract target label
#   - Pad with zeroes
text_pipeline = lambda x: [vocab[token] for token in tokenizer(x)]
label_pipeline = lambda x: 1.0 if x == 2 else 0.0

def collate_batch(batch):
    label_list,text_list,lengths = [],[],[]
    for _label,_text in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text),dtype=torch.int64)
        text_list.append(processed_text)
        lengths.append(processed_text.size(0))
    label_list = torch.tensor(label_list).to(device)
    lengths = torch.tensor(lengths).to(device)
    padded_text_list = nn.utils.rnn.pad_sequence(text_list,batch_first=True).to(device)
    return padded_text_list,label_list,lengths

In [7]:
batch_size = 32
train_dl = DataLoader(train_dataset,batch_size=batch_size,shuffle=True,collate_fn=collate_batch)
val_dl = DataLoader(val_dataset,batch_size=batch_size,shuffle=False,collate_fn=collate_batch)
test_dl = DataLoader(test_dataset,batch_size=batch_size,shuffle=False,collate_fn=collate_batch)

In [8]:
class RNN(nn.Module):
    def __init__(self,vocab_size,embed_dim,rnn_hidden_size,fc_hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size,embed_dim,padding_idx=0)
        self.rnn = nn.LSTM(embed_dim,rnn_hidden_size,batch_first=True,bidirectional=True)
        # self.rnn = nn.LSTM(embed_dim,rnn_hidden_size,batch_first=True)
        # self.rnn = nn.RNN(embed_dim,rnn_hidden_size,batch_first=True)
        self.fc1 = nn.Linear(rnn_hidden_size*2,fc_hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(fc_hidden_size,1)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self,text,lengths):
        out = self.embedding(text)
        out = nn.utils.rnn.pack_padded_sequence(out,lengths.cpu().numpy(),enforce_sorted=False,batch_first=True)
        out, (hidden,cell) = self.rnn(out)
        # _,hidden = self.rnn(out)
        # out = hidden[-1,:,:] # Extract last state
        out = torch.concat((hidden[-2,:,:],hidden[-1,:,:]),dim=1) # Extract last state
        out = self.relu(self.fc1(out))
        out = self.sigmoid(self.fc2(out))
        return out

vocab_size = len(vocab)
embed_dim = 20
rnn_hidden_size = 64
fc_hidden_size = 64
model = RNN(vocab_size,embed_dim,rnn_hidden_size,fc_hidden_size).to(device)

In [9]:
loss_fn = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)

def train(dataloader):
    model.train()
    total_acc, total_loss = 0,0
    for text_batch,label_batch,lengths in dataloader:
        optimizer.zero_grad()
        preds = model(text_batch,lengths)[:,0]
        loss = loss_fn(preds,label_batch)
        loss.backward()
        optimizer.step()
        total_acc += ((preds >= 0.5).float() == label_batch).float().sum().item()
        total_loss += loss.item()*label_batch.size(0)
    return total_acc/len(dataloader.dataset), total_loss/len(dataloader.dataset)

def evaluate(dataloader):
    model.eval()
    total_acc, total_loss = 0,0
    with torch.no_grad():
        for text_batch,label_batch,lengths in dataloader:
            preds = model(text_batch,lengths)[:,0]
            loss = loss_fn(preds,label_batch)
            total_acc += ((preds >= 0.5).float() == label_batch).float().sum().item()
            total_loss += loss.item()*label_batch.size(0)
    return total_acc/len(dataloader.dataset), total_loss/len(dataloader.dataset)

In [10]:
num_epochs = 10

for epoch in range(num_epochs):
    acc_train, loss_train = train(train_dl)
    acc_val, loss_val = evaluate(val_dl)
    print(f'Epoch {epoch}: Accuracy {acc_train:.4f} Val acc {acc_val:.4f}')
    print(f'Epoch {epoch}: Loss {loss_train:.4f} Val loss {loss_val:.4f}')

Epoch 0: Accuracy 0.6252 Val acc 0.7086
Epoch 0: Loss 0.6340 Val loss 0.5607
Epoch 1: Accuracy 0.7771 Val acc 0.7982
Epoch 1: Loss 0.4668 Val loss 0.4442
Epoch 2: Accuracy 0.8396 Val acc 0.7758
Epoch 2: Loss 0.3673 Val loss 0.4582
Epoch 3: Accuracy 0.8821 Val acc 0.8436
Epoch 3: Loss 0.2864 Val loss 0.3769
Epoch 4: Accuracy 0.8807 Val acc 0.8390
Epoch 4: Loss 0.2811 Val loss 0.4120
Epoch 5: Accuracy 0.9208 Val acc 0.8452
Epoch 5: Loss 0.1999 Val loss 0.4333
Epoch 6: Accuracy 0.9403 Val acc 0.8542
Epoch 6: Loss 0.1569 Val loss 0.3916
Epoch 7: Accuracy 0.9570 Val acc 0.8588
Epoch 7: Loss 0.1183 Val loss 0.4216
Epoch 8: Accuracy 0.9689 Val acc 0.8338
Epoch 8: Loss 0.0883 Val loss 0.5147
Epoch 9: Accuracy 0.9718 Val acc 0.8582
Epoch 9: Loss 0.0808 Val loss 0.5700
