In [1]:
import torch
from torchtext.legacy import data, datasets
import torch.nn as nn
import torch.optim as optim
from torchtext.data.utils import get_tokenizer
import time
import random

# Set random seed for reproducibility
SEED = 42

torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# Define the Fields for processing the dataset
tokenizer = get_tokenizer('basic_english')

TEXT = data.Field(tokenize=tokenizer, include_lengths=True)
LABEL = data.LabelField(dtype=torch.float)

# Load the IMDb dataset
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

train_data, _ = train_data.split(split_ratio=0.01, random_state=random.seed(SEED))


# Split the training data to create a validation set
# train_data, valid_data = train_data.split(random_state = torch.manual_seed(SEED))
train_data, valid_data = train_data.split(split_ratio=0.7)

# Build the vocabulary and load pre-trained word embeddings (GloVe)
MAX_VOCAB_SIZE = 25_000

TEXT.build_vocab(train_data, 
                 max_size = MAX_VOCAB_SIZE, 
                 vectors = "glove.6B.100d", 
                 unk_init = torch.Tensor.normal_)

LABEL.build_vocab(train_data)

# Create iterators for the data
BATCH_SIZE = 64

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size = BATCH_SIZE,
    sort_within_batch = True,
    device = device)

In [2]:
class RNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, 
                 bidirectional, dropout, pad_idx):
        
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = pad_idx)
        
        self.rnn = nn.LSTM(embedding_dim, 
                           hidden_dim, 
                           num_layers=n_layers, 
                           bidirectional=bidirectional, 
                           dropout=dropout)
        
        self.fc = nn.Linear(hidden_dim * 2, output_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, text, text_lengths):
        
        #text = [sent len, batch size]
        
        embedded = self.dropout(self.embedding(text))
        
        #embedded = [sent len, batch size, emb dim]
        
        # Pack sequence
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths.to('cpu'))
        packed_output, (hidden, cell) = self.rnn(packed_embedded)
        
        # Unpack sequence
        output, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_output)
        
        #hidden = [num layers * num directions, batch size, hid dim]
        #cell = [num layers * num directions, batch size, hid dim]
        
        # Concat the final forward (hidden[-2,:,:]) and backward (hidden[-1,:,:]) hidden layers
        # and apply dropout
        
        hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1))
                
        return self.fc(hidden)

In [3]:
INPUT_DIM = len(TEXT.vocab)
EMBEDDING_DIM = 100
HIDDEN_DIM = 256
OUTPUT_DIM = 1
N_LAYERS = 5
BIDIRECTIONAL = True
DROPOUT = 0.5
PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]

model = RNN(INPUT_DIM, 
            EMBEDDING_DIM, 
            HIDDEN_DIM, 
            OUTPUT_DIM, 
            N_LAYERS, 
            BIDIRECTIONAL, 
            DROPOUT, 
            PAD_IDX)

# Load the pre-trained embeddings
pretrained_embeddings = TEXT.vocab.vectors
model.embedding.weight.data.copy_(pretrained_embeddings)

# Zero the initial weights of the unknown and padding tokens
UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]

model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM)
model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)

optimizer = optim.Adam(model.parameters())
criterion = nn.BCEWithLogitsLoss()

model = model.to(device)
criterion = criterion.to(device)

In [4]:
def binary_accuracy(preds, y):
    """Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8"""
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float()  # convert into float for division 
    acc = correct.sum() / len(correct)
    return acc

def train(model, iterator, optimizer, criterion):
    epoch_loss = 0
    epoch_acc = 0
    
    model.train()
    
    for batch in iterator:
        
        optimizer.zero_grad()
        
        text, text_lengths = batch.text
        
        predictions = model(text, text_lengths).squeeze(1)
        
        loss = criterion(predictions, batch.label)
        acc = binary_accuracy(predictions, batch.label)
        
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

def evaluate(model, iterator, criterion):
    epoch_loss = 0
    epoch_acc = 0
    
    model.eval()
    
    with torch.no_grad():
    
        for batch in iterator:

            text, text_lengths = batch.text
            
            predictions = model(text, text_lengths).squeeze(1)
            
            loss = criterion(predictions, batch.label)
            acc = binary_accuracy(predictions, batch.label)

            epoch_loss += loss.item()
            epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [5]:
N_EPOCHS = 100

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()
    
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)
    
    end_time = time.time()

    epoch_mins, epoch_secs = divmod(end_time - start_time, 60)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'model.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

Epoch: 01 | Epoch Time: 0.0m 1.3118689060211182s
	Train Loss: 0.696 | Train Acc: 47.04%
	 Val. Loss: 0.698 |  Val. Acc: 42.40%
Epoch: 02 | Epoch Time: 0.0m 1.2832648754119873s
	Train Loss: 0.700 | Train Acc: 54.00%
	 Val. Loss: 0.719 |  Val. Acc: 42.40%
Epoch: 03 | Epoch Time: 0.0m 1.293835163116455s
	Train Loss: 0.689 | Train Acc: 54.00%
	 Val. Loss: 0.708 |  Val. Acc: 42.40%
Epoch: 04 | Epoch Time: 0.0m 1.2983603477478027s
	Train Loss: 0.694 | Train Acc: 54.00%
	 Val. Loss: 0.698 |  Val. Acc: 42.40%
Epoch: 05 | Epoch Time: 0.0m 1.3104240894317627s
	Train Loss: 0.692 | Train Acc: 54.00%
	 Val. Loss: 0.701 |  Val. Acc: 42.40%
Epoch: 06 | Epoch Time: 0.0m 1.3100824356079102s
	Train Loss: 0.690 | Train Acc: 54.00%
	 Val. Loss: 0.710 |  Val. Acc: 42.40%
Epoch: 07 | Epoch Time: 0.0m 1.2928311824798584s
	Train Loss: 0.690 | Train Acc: 54.00%
	 Val. Loss: 0.719 |  Val. Acc: 42.40%
Epoch: 08 | Epoch Time: 0.0m 1.2923288345336914s
	Train Loss: 0.686 | Train Acc: 53.48%
	 Val. Loss: 0.714 |  Va

Epoch: 66 | Epoch Time: 0.0m 1.2571191787719727s
	Train Loss: 0.007 | Train Acc: 99.48%
	 Val. Loss: 2.510 |  Val. Acc: 50.99%
Epoch: 67 | Epoch Time: 0.0m 1.262887954711914s
	Train Loss: 0.036 | Train Acc: 98.96%
	 Val. Loss: 2.960 |  Val. Acc: 51.63%
Epoch: 68 | Epoch Time: 0.0m 1.2530179023742676s
	Train Loss: 0.028 | Train Acc: 98.96%
	 Val. Loss: 2.587 |  Val. Acc: 45.67%
Epoch: 69 | Epoch Time: 0.0m 1.2600862979888916s
	Train Loss: 0.012 | Train Acc: 99.48%
	 Val. Loss: 2.466 |  Val. Acc: 48.65%
Epoch: 70 | Epoch Time: 0.0m 1.2603058815002441s
	Train Loss: 0.004 | Train Acc: 100.00%
	 Val. Loss: 2.419 |  Val. Acc: 49.43%
Epoch: 71 | Epoch Time: 0.0m 1.2560935020446777s
	Train Loss: 0.015 | Train Acc: 99.48%
	 Val. Loss: 2.337 |  Val. Acc: 47.23%
Epoch: 72 | Epoch Time: 0.0m 1.2604327201843262s
	Train Loss: 0.024 | Train Acc: 98.44%
	 Val. Loss: 2.395 |  Val. Acc: 44.11%
Epoch: 73 | Epoch Time: 0.0m 1.2577569484710693s
	Train Loss: 0.024 | Train Acc: 98.44%
	 Val. Loss: 2.600 |  V