In [37]:
import re
import torch
import numpy as np
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchtext.datasets import IMDB
from torchtext.vocab import build_vocab_from_iterator, Vocab
from torch.nn.utils.rnn import pad_sequence

In [32]:
!pip install torchdata

## Text Preprocessing

In [36]:
training_data = IMDB(split='train')
test_data = IMDB(split='test')

In [38]:
training_data, valid_data = random_split(list(training_data),[20000,5000])
test_data, _ = random_split(list(test_data),[25000,0])

In [39]:
print(type(training_data))
print(type(valid_data))
print(type(test_data))

<class 'torch.utils.data.dataset.Subset'>
<class 'torch.utils.data.dataset.Subset'>
<class 'torch.utils.data.dataset.Subset'>


In [40]:
IterDataPipe = iter(training_data)
def get_tokenizer(sentence):
  tokenized = re.sub(r'[^\w\s]+',' ',sentence.lower())
  return tokenized.split()

In [41]:
def yield_tokens(example):
  for label,text in example:
    yield get_tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(IterDataPipe))
vocab.insert_token("<pad>",0)
vocab.insert_token("<unk>",1)
vocab.set_default_index(vocab["<unk>"])

In [42]:
textpipeline = lambda x: vocab(get_tokenizer(x))
labelpipeline = lambda x:1 if x == 'pos' else 0

In [43]:
def collate_batch(data_iter):
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  text_processed, label_processed, length = [],[],[]

  for y_batch,x_batch in data_iter:
    text = torch.tensor(textpipeline(x_batch),dtype=torch.float32)
    text_processed.append(text)
    length.append(text.size(0))
    label = torch.tensor(labelpipeline(y_batch),dtype=torch.float32)
    label_processed.append(label)
    
  padded_text = pad_sequence(text_processed,batch_first=True).clone().detach().requires_grad_(True)
  label_processed = torch.tensor(label_processed)
  length = torch.tensor(length,dtype=torch.float32)
  return padded_text.to(device),label_processed.to(device), length

In [44]:
## sanity check
batch = [('neg','This is bad'),('pos','This is good'),('pos','This is good again')]
print(collate_batch(batch))

(tensor([[ 12.,   7.,  77.,   0.],
        [ 12.,   7.,  51.,   0.],
        [ 12.,   7.,  51., 177.]], device='cuda:0', grad_fn=<ToCopyBackward0>), tensor([0., 1., 1.], device='cuda:0'), tensor([3., 3., 4.]))


In [45]:
train_dl = DataLoader(training_data,batch_size=32,collate_fn=collate_batch,shuffle=True,drop_last=True)
valid_dl = DataLoader(valid_data,batch_size=32,collate_fn=collate_batch,drop_last=True)
test_dl = DataLoader(test_data,batch_size=32,collate_fn=collate_batch,drop_last=True)

## Model Definition


In [47]:
vocab_size = len(vocab)
embed_size = 20
rnn_hidden_size = 64
fc_hidden_size = 64

In [48]:
class TextClassification(nn.Module):
  def __init__(self,vocab_size:int, embed_size:int,rnn_hidden_size:int,fc_hidden_size:int):
    super(TextClassification,self).__init__()
    self.embedding = nn.Embedding(vocab_size,embed_size,padding_idx=0)
    self.lstm = nn.LSTM(embed_size,rnn_hidden_size,batch_first=True)
    self.linear_1 = nn.Linear(rnn_hidden_size, fc_hidden_size)
    self.relu = nn.ReLU()
    self.linear_2 = nn.Linear(fc_hidden_size,1)
    self.sigmoid = nn.Sigmoid()

  def forward(self,input,length):
    ## lstm hidden and cell states default to zero
    input = input.to(torch.int64)     ## due to argument type mismatch
    out = self.embedding(input)
    out = nn.utils.rnn.pack_padded_sequence(out,length.cpu().numpy(),batch_first=True, enforce_sorted=False)  ##length should 1D cpu tensor
    out, (hidden,cell) = self.lstm(out)
    out = hidden[-1,:,:]
    out = self.linear_1(out)
    out = self.relu(out)
    out = self.linear_2(out)
    out = self.sigmoid(out)
    return out

In [49]:
model = TextClassification(vocab_size,embed_size,rnn_hidden_size,fc_hidden_size)
model.cuda()

TextClassification(
  (embedding): Embedding(68516, 20, padding_idx=0)
  (lstm): LSTM(20, 64, batch_first=True)
  (linear_1): Linear(in_features=64, out_features=64, bias=True)
  (relu): ReLU()
  (linear_2): Linear(in_features=64, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

## Model Training and Evaluation

In [50]:
loss_fn = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)

In [51]:
def train(dataloader):
  model.train()
  total_loss, total_accuracy = 0.0, 0.0
  for x_batch,y_batch, length in dataloader:
    optimizer.zero_grad()
    pred = model(x_batch,length)[:,0]   ## due to target size mismatch 
    loss = loss_fn(pred,y_batch)
    loss.backward()
    optimizer.step()
    total_accuracy += ((pred >= 0.5) == y_batch).float().sum().item()
    total_loss += loss.item()*y_batch.size(0)
  train_loss = total_loss/len(dataloader.dataset)
  train_accuracy = total_accuracy/len(dataloader.dataset)
  return train_loss, train_accuracy
  

def evaluate(dataloader):
  model.eval()
  total_loss, total_accuracy = 0.0, 0.0
  with torch.no_grad():
    for x_batch, y_batch, length in dataloader:
      pred = model(x_batch,length)[:,0]    ##due to target size mismatch
      loss = loss_fn(pred,y_batch)
      total_accuracy += ((pred >= 0.5).float() == y_batch).float().sum().item()
      total_loss += loss.item()*y_batch.size(0)
  valid_loss = total_loss/len(dataloader.dataset)
  valid_accuracy = total_accuracy/len(dataloader.dataset)
  return valid_loss, valid_accuracy

In [52]:
num_epochs = 10
torch.manual_seed(1)
for epoch in range(num_epochs):
  train_loss, train_accuracy = train(train_dl)
  valid_loss, valid_accuracy = evaluate(valid_dl)
  print(f"Epoch: {epoch}, Train accuracy: {train_accuracy:.4f}, Valid accuracy: {valid_accuracy:.4f}")

Epoch: 0, Train accuracy: 0.6170, Valid accuracy: 0.6516
Epoch: 1, Train accuracy: 0.7409, Valid accuracy: 0.7538
Epoch: 2, Train accuracy: 0.7683, Valid accuracy: 0.7280
Epoch: 3, Train accuracy: 0.8097, Valid accuracy: 0.8246
Epoch: 4, Train accuracy: 0.8802, Valid accuracy: 0.8452
Epoch: 5, Train accuracy: 0.9060, Valid accuracy: 0.8560
Epoch: 6, Train accuracy: 0.9307, Valid accuracy: 0.8604
Epoch: 7, Train accuracy: 0.9472, Valid accuracy: 0.8624
Epoch: 8, Train accuracy: 0.9598, Valid accuracy: 0.8698
Epoch: 9, Train accuracy: 0.9721, Valid accuracy: 0.8700


In [53]:
test_loss, test_accuracy = evaluate(test_dl)
print(f"Test Accuracy: {test_accuracy: .4f}")

Test Accuracy:  0.8548
