In [1]:
import os
import sys
import math
from typing import *
from tqdm.auto import tqdm
import torch
from torch import nn, Tensor, optim
from torch.utils.data import DataLoader
import datasets
from datasets import *
from transformers import AutoTokenizer, DataCollatorWithPadding
import project_paths as pp

In [2]:
# Make the directory two levels above `pp.current_file_path` importable so the
# project-local `models` package can be found. The membership check keeps the cell
# idempotent: re-running it no longer appends duplicate entries to sys.path.
project_root_path = os.path.realpath(os.path.join(os.path.dirname(pp.current_file_path), '..', '..'))
if project_root_path not in sys.path:
    sys.path.append(project_root_path)
from models import RNN

In [3]:
dataset_folder_path = os.path.join(pp.aclImdb_dataset_folder_path, 'train')
dataset = datasets.load_from_disk(dataset_folder_path)

# Fixed seed so the train/validation split is identical on every kernel restart.
# Without it each session draws a new split, so a checkpoint trained in an earlier
# session (see the checkpoint-loading cell) would be "validated" partly on its own
# training data.
split_seed = 42
train_and_val_datasets = dataset.train_test_split(test_size=0.33, seed=split_seed)
train_dataset = train_and_val_datasets['train']
val_dataset = train_and_val_datasets['test']


def _count_pos_labels(split):
    """Return how many rows of `split` have the label 'pos'.

    Reads the label column directly instead of calling `Dataset.filter`, which
    makes a pass over the data and builds a new Dataset just to take its length.
    """
    return sum(label == 'pos' for label in split['label'])


len_train_dataset = len(train_dataset)
num_pos_instances_in_train_dataset = _count_pos_labels(train_dataset)
num_neg_instances_in_train_dataset = len_train_dataset - num_pos_instances_in_train_dataset

len_val_dataset = len(val_dataset)
num_pos_instances_in_val_dataset = _count_pos_labels(val_dataset)
num_neg_instances_in_val_dataset = len_val_dataset - num_pos_instances_in_val_dataset

Filter:   0%|          | 0/16750 [00:00<?, ? examples/s]

Filter:   0%|          | 0/8250 [00:00<?, ? examples/s]

In [4]:
# Load the project's WordPiece tokenizer trained for aclImdb.
tokenizer_folder_path = os.path.join(
    pp.word_piece_tokenizer_folder_path,
    'aclImdb_4096',
)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_folder_path)
# Pad on the left: the training loop reads the model output at the last time step
# ([:, -1, :]), so the final position must never be a padding token.
tokenizer.padding_side = 'left'

In [5]:
# Use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Architecture hyperparameters: embedding and hidden state share a width of 256,
# the model emits a single logit (binary sentiment), and it has `num_rnn_cells` cells.
vocab_size = tokenizer.vocab_size
embedding_dim = hidden_size = 256
output_size = 1
num_rnn_cells = 3

model = RNN(vocab_size, embedding_dim, hidden_size, output_size, num_rnn_cells)
model.to(device)

  self.register_buffer('input_dim', torch.tensor(input_dim))
  self.register_buffer('hidden_dim', torch.tensor(hidden_dim))


RNN(
  (embedding): Embedding(4096, 256)
  (rnn_cells): ModuleList(
    (0-2): 3 x RNNCell(
      (W_h): Linear(in_features=512, out_features=256, bias=True)
      (f_h): Tanh()
    )
  )
  (output_layer): Linear(in_features=256, out_features=1, bias=False)
)

In [6]:
# Optimisation hyperparameters.
learning_rate = 1e-3
num_epochs = 5
train_batch_size = val_batch_size = 256

# Batches per pass over each split (a trailing partial batch counts as one);
# these size the progress bars below.
num_train_batches = math.ceil(len_train_dataset / train_batch_size)
num_val_batches = math.ceil(len_val_dataset / val_batch_size)

# The model emits one raw logit per review; BCEWithLogitsLoss applies the sigmoid itself.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(params=model.parameters(), lr=learning_rate)

In [7]:
# Map the dataset's string labels onto the 0/1 targets used with BCEWithLogitsLoss.
label_map = {'neg': 0, 'pos': 1}
# Reviews longer than this many tokens are truncated by the tokenizer.
max_sequence_length = 1024


def collate_fn(batch):
    """Tokenize a list of dataset rows into one padded batch.

    Args:
        batch: list of dataset rows, each with a ``'text'`` string and a
            ``'label'`` that is a key of ``label_map``.

    Returns:
        dict with
          - ``'texts'``: the raw review strings, in batch order;
          - ``'input_ids'``: token-id tensor padded to the longest sequence in the
            batch (on the side given by ``tokenizer.padding_side``) and truncated
            to ``max_sequence_length`` tokens;
          - ``'labels'``: float32 tensor of shape ``(len(batch), 1)`` with 0/1 targets.
    """
    texts = [item['text'] for item in batch]
    labels = [[label_map[item['label']]] for item in batch]
    encodings = tokenizer(texts, padding=True, truncation=True, max_length=max_sequence_length, return_tensors='pt')
    # NOTE(review): the attention mask is discarded, so unless `RNN` treats the pad id
    # specially it also steps over the leading padding tokens. Consider passing the
    # mask to the model if it can use one.
    return {
        'texts': texts,
        'input_ids': encodings['input_ids'],
        'labels': torch.tensor(labels, dtype=torch.float32),
    }


train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, collate_fn=collate_fn)
# Evaluation gains nothing from shuffling; a fixed order makes validation runs repeatable.
val_dataloader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False, collate_fn=collate_fn)

In [None]:
# Anomaly detection is a debugging aid that adds heavy overhead to every backward
# pass. Disable it explicitly, because the flag is global and may still be on in this
# kernel. Re-enable it only while hunting NaN/inf gradients.
torch.autograd.set_detect_anomaly(False)
progress_bar = tqdm(total=num_epochs * num_train_batches, dynamic_ncols=True)
for epoch_idx in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    num_batches_seen = 0

    for batch in train_dataloader:
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        # Read the prediction at the last time step. Inputs are left-padded, so that
        # position is never padding. Assumes the model returns
        # (batch, seq_len, output_size) — TODO confirm against `RNN.forward`.
        output_logits = model(input_ids)[:, -1, :]

        loss = criterion(output_logits, labels)
        loss.backward()
        optimizer.step()

        # A single .item() call per batch, because each call forces a device sync.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        num_batches_seen += 1

        progress_bar.set_description(f'Batch loss: {round(batch_loss, 3)}')
        progress_bar.update(1)

    # `loss` is already a per-batch mean (BCEWithLogitsLoss uses reduction='mean' by
    # default), so average over batches rather than over individual examples.
    avg_loss = epoch_loss / max(num_batches_seen, 1)
    print(f'Epoch {epoch_idx + 1} average loss: {round(avg_loss, 3)}')
progress_bar.close()

  0%|                                                                                                         …

In [None]:
# Rebuild the architecture with the same hyperparameters used for training, then load
# saved weights into it. Note this replaces the model trained above in this session.
model = RNN(vocab_size, embedding_dim, hidden_size, output_size, num_rnn_cells)
model_save_file_path = os.path.join(pp.rnn_models_folder_path, f'{str(num_rnn_cells).zfill(2)}.pth')
# NOTE(review): no cell above calls torch.save, so this checkpoint must already exist
# from an earlier run.
# map_location='cpu' lets a checkpoint written on a GPU machine load on a CPU-only
# one; the model is moved to `device` afterwards.
model.load_state_dict(torch.load(model_save_file_path, map_location='cpu'))

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Module.to() and Module.eval() both modify the module in place and return it, so
# this chain moves the model, switches it to inference mode, and displays it.
model.to(device).eval()

In [None]:
# Confusion-matrix counts over the validation split, accumulated as Python ints.
tp = fp = tn = fn = 0
progress_bar = tqdm(total=num_val_batches, dynamic_ncols=True)
# No gradients are needed for evaluation; this saves memory and time.
with torch.no_grad():
    for batch in val_dataloader:
        # collate_fn returns a dict, so pick the tensors out by key.
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        # Use the same readout as training: the logit at the final time step.
        output_logits = model(input_ids)[:, -1, :]
        preds = (torch.sigmoid(output_logits) >= 0.5).float()

        tp += ((preds == 1.0) & (labels == 1.0)).sum().item()
        fp += ((preds == 1.0) & (labels == 0.0)).sum().item()
        tn += ((preds == 0.0) & (labels == 0.0)).sum().item()
        fn += ((preds == 0.0) & (labels == 1.0)).sum().item()
        progress_bar.update(1)

progress_bar.close()

# Guard each ratio: e.g. a model that never predicts 'pos' has tp + fp == 0.
total = tp + fp + tn + fn
accuracy = (tp + tn) / total if total else float('nan')
precision = tp / (tp + fp) if (tp + fp) else float('nan')
recall = tp / (tp + fn) if (tp + fn) else float('nan')

print(f'Accuracy: {accuracy}')
print(f'Precision: {precision}')
print(f'Recall: {recall}')