In [2]:
import torch
import torchtext
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import time

In [3]:
start = time.time()
TEXT = torchtext.legacy.data.Field(lower=True, fix_length=200, batch_first=False)
LABEL = torchtext.legacy.data.Field(sequential=False)

In [4]:
from torchtext.legacy import datasets
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

downloading aclImdb_v1.tar.gz


100%|██████████| 84.1M/84.1M [00:02<00:00, 30.1MB/s]


In [5]:
print(vars(train_data.examples[0]))

{'text': ['spoilers.', 'like', 'other', 'posters,', 'i', 'felt', 'that', 'the', 'ending', 'was', 'a', 'bit', 'abrupt.', 'i', 'would', 'have', 'liked', 'to', 'have', 'seen', 'the', 'crew', 'adjusting', 'to', 'life', 'back', 'on', 'earth', 'after', 'their', 'return.', 'i', 'suppose', 'the', 'writers', 'anticipated', 'this', 'problem', 'by', '"front', 'loading"', 'some', 'voyager', 'on', 'earth', 'sequences', 'at', 'the', 'beginning', 'of', 'the', 'episode.', '(of', 'course,', 'that', 'time', 'line', 'has', 'been', 'eradicated,', 'so', "it's", 'all', 'moot.)', 'i', 'did', 'like', 'how', 'admiral', 'janeway', 'died', 'for', 'the', 'voyager', 'crew.', 'as', 'fans,', 'we', 'get', 'to', 'have', 'our', 'cake', 'and', 'eat', 'it', 'to,', 'by', 'having', 'janeway', 'both', 'make', 'the', 'ultimate', 'sacrifice', 'and', 'live', 'on.', 'i', 'admit', 'that', 'the', 'scenes', 'of', 'janeway', 'and', 'her', 'older', 'self', 'having', 'conversations', 'was', 'bizarre', 'and', 'so', 'easily', 'could', 

In [6]:
import string
for example in train_data.examples:
  text = [x.lower() for x in vars(example)['text']]
  text = [x.replace('<br', '') for x in text]
  text = [''.join(c for c in s if c not in string.punctuation) for s in text]
  text = [s for s in text if s]
  vars(example)['text'] = text

In [7]:
import random
train_data, valid_data = train_data.split(random_state=random.seed(0), split_ratio=0.8)

In [8]:
print(f'Number of training examples: {len(train_data)}')
print(f'Number of validation examples: {len(valid_data)}')
print(f'Number of testing examples: {len(test_data)}')

Number of training examples: 20000
Number of validation examples: 5000
Number of testing examples: 25000


In [9]:
TEXT.build_vocab(train_data, max_size=10000, min_freq=10, vectors=None)
LABEL.build_vocab(train_data)

print(f'Unique tokens in TEXT vocabulary: {len(TEXT.vocab)}')
print(f'Unique tokens in LABEL vocabulary: {len(LABEL.vocab)}')

Unique tokens in TEXT vocabulary: 10002
Unique tokens in LABEL vocabulary: 3


In [10]:
BATCH_SIZE = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

embedding_dim = 100
hidden_size = 300

train_iterator, valid_iterator, test_iterator = torchtext.legacy.data.BucketIterator.splits((train_data, valid_data, test_data), batch_size=BATCH_SIZE, device=device)

In [11]:
class RNNCell_Encoder(nn.Module):
  def __init__(self, input_dim, hidden_size):
    super().__init__()
    self.rnn = nn.RNNCell(input_dim, hidden_size)

  def forward(self, inputs):
    bz = inputs.shape[1]
    ht = torch.zeros((bz, hidden_size)).to(device)
    for word in inputs:
      ht = self.rnn(word, ht)
    return ht

class Net(nn.Module):
  def __init__(self):
    super().__init__()
    self.em = nn.Embedding(len(TEXT.vocab.stoi), embedding_dim)
    self.rnn = RNNCell_Encoder(embedding_dim, hidden_size)
    self.fc1 = nn.Linear(hidden_size, 256)
    self.fc2 = nn.Linear(256, 3)

  def forward(self, x):
    x = self.em(x)
    x = self.rnn(x)
    x = F.relu(self.fc1(x))
    x = self.fc2(x)
    return x

In [13]:
model = Net()
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

In [14]:
def training(epoch, model, trainloader, validloader):
  correct = 0
  total = 0
  running_loss = 0

  model.train()

  for b in trainloader:
    x, y = b.text, b.label
    x, y = x.to(device), y.to(device)
    y_pred = model(x)
    loss = criterion(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    with torch.no_grad():
      y_pred = torch.argmax(y_pred, dim=1)
      correct += (y_pred == y).sum().item()
      total += y.size(0)
      running_loss += loss.item()

  epoch_loss = running_loss / len(trainloader.dataset)
  epoch_acc = correct / total

  valid_correct = 0
  valid_total = 0
  valid_running_loss = 0

  model.eval()
  with torch.no_grad():
    for b in validloader:
      x, y = b.text, b.label
      x, y = x.to(device), y.to(device)
      y_pred = model(x)
      loss = criterion(y_pred, y)
      y_pred = torch.argmax(y_pred, dim=1)
      valid_correct += (y_pred == y).sum().item()
      valid_total += y.size(0)
      valid_running_loss += loss.item()

  epoch_valid_loss = valid_running_loss / len(validloader.dataset)
  epoch_valid_acc = valid_correct / valid_total

  print('epoch: ', epoch)
  print('loss: ', round(epoch_loss, 3))
  print('accuracy: ', round(epoch_acc, 3))
  print('valid_loss: ', round(epoch_valid_loss, 3))
  print('valid_accuracy: ', round(epoch_valid_acc, 3))

  return epoch_loss, epoch_acc, epoch_valid_loss, epoch_valid_acc

In [15]:
epochs = 15
train_loss = []
train_acc = []
valid_loss = []
valid_acc = []

for epoch in range(epochs):
  epoch_loss, epoch_acc, epoch_valid_loss, epoch_valid_acc = training(epoch, model, train_iterator, valid_iterator)
  train_loss.append(epoch_loss)
  train_acc.append(epoch_acc)
  valid_loss.append(epoch_valid_loss)
  valid_acc.append(epoch_valid_acc)

end = time.time()
print(f'Training done in {(end-start / 60)}m {(end-start) % 60}s')

epoch:  0
loss:  0.011
accuracy:  0.495
valid_loss:  0.011
valid_accuracy:  0.506
epoch:  1
loss:  0.011
accuracy:  0.506
valid_loss:  0.011
valid_accuracy:  0.491
epoch:  2
loss:  0.011
accuracy:  0.514
valid_loss:  0.011
valid_accuracy:  0.493
epoch:  3
loss:  0.011
accuracy:  0.517
valid_loss:  0.011
valid_accuracy:  0.49
epoch:  4
loss:  0.011
accuracy:  0.527
valid_loss:  0.011
valid_accuracy:  0.504
epoch:  5
loss:  0.011
accuracy:  0.531
valid_loss:  0.011
valid_accuracy:  0.491
epoch:  6
loss:  0.011
accuracy:  0.543
valid_loss:  0.011
valid_accuracy:  0.51
epoch:  7
loss:  0.011
accuracy:  0.548
valid_loss:  0.011
valid_accuracy:  0.498
epoch:  8
loss:  0.01
accuracy:  0.558
valid_loss:  0.011
valid_accuracy:  0.513
epoch:  9
loss:  0.01
accuracy:  0.563
valid_loss:  0.011
valid_accuracy:  0.523
epoch:  10
loss:  0.01
accuracy:  0.575
valid_loss:  0.011
valid_accuracy:  0.504
epoch:  11
loss:  0.01
accuracy:  0.584
valid_loss:  0.011
valid_accuracy:  0.51
epoch:  12
loss:  0.0