In [None]:
# If needed (e.g. in Colab), install the pinned dependencies first:
# !pip install torch==2.2.0 torchtext==0.17.0
# !pip install "portalocker>=2.0.0"   # quote the spec: an unquoted `>=` is a shell redirect
# !pip install transformers
# !pip install torchmetrics

import torch
from torchtext.datasets import IMDB
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
from torchmetrics.classification import BinaryF1Score
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm
import time
import hashlib
from transformers import DistilBertTokenizerFast, DistilBertModel

In [None]:
# Takes a long time...
!wget https://nlp.stanford.edu/data/glove.6B.zip
!unzip glove.6B.zip > /dev/null 2>&1

In [None]:
# Global variables
BATCH_SIZE = 8
EMBEDDING_DIM = 100
HIDDEN_DIM = 128
OUTPUT_DIM = 1
DROPOUT = 0.2
NUM_LAYERS = 2
NUM_EPOCHS = 5
GLOVE_PATH = "glove.6B.100d.txt"
SEED = 42
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG (also used by random_split, DataLoader shuffling and
# weight init) so that Restart & Run All reproduces the same results.
torch.manual_seed(SEED)
print(DEVICE)

## Building the vocabulary with a tokenizer

In [None]:
# Build vocab
def yield_tokens(data_iter, tokenizer, model_name='basic'):
    """Return an iterator of token lists, one per ``(label, text)`` item of `data_iter`.

    Args:
        data_iter: iterable of ``(label, text)`` pairs; only the text is used.
        tokenizer: callable mapping a string to a list of tokens.
        model_name: tokenization scheme; only ``'basic'`` is implemented.

    Raises:
        ValueError: immediately (not lazily on first iteration) for any other
            `model_name`; previously such names silently produced no tokens and
            therefore an empty vocabulary.
    """
    if model_name != 'basic':
        raise ValueError(f"Unsupported model_name {model_name!r}; only 'basic' is implemented")
    return (tokenizer(text) for _, text in data_iter)
            
# Iterator over the IMDB training split; the next cell tokenizes it to build `basic_vocab`.
train_iter = IMDB(split='train')

### Basic vocab

In [None]:
# Load tokenizer and build the vocabulary from the IMDB training split.
# '<pad>' gets its own index (0, specials come first) so padding is never
# confused with genuinely unknown words, which map to '<unk>' via the default index.
basic_tokenizer = get_tokenizer("basic_english")
basic_vocab = build_vocab_from_iterator(
    yield_tokens(train_iter, basic_tokenizer, model_name='basic'),
    specials=['<pad>', '<unk>'],
)
basic_vocab.set_default_index(basic_vocab['<unk>'])  # Default index for unknown words
BASIC_VOCAB_SIZE = len(basic_vocab)
BASIC_PAD_IDX = basic_vocab['<pad>']  # 0
print(f"Basic vocab length: {BASIC_VOCAB_SIZE}")

## Checking the data and building loaders

In [None]:
# Content fingerprint of a single (label, text) item
def hash_data(data):
    """Return the hex MD5 digest of the text of a ``(label, text)`` pair.

    The label is ignored, so identical texts with different labels collide on
    purpose. MD5 serves only as a fast content fingerprint, not for security.
    """
    _label, text = data
    digest = hashlib.md5(text.encode('utf-8'))
    return digest.hexdigest()

# Deletes duplicates in the dataset
def remove_duplicates(data):
    """Return a list of the items of `data` whose text was not seen before.

    Items are compared through ``hash_data`` (their text only). The first
    occurrence of each text is kept and the original order is preserved.
    """
    kept, seen_hashes = [], set()
    for entry in data:
        fingerprint = hash_data(entry)
        if fingerprint in seen_hashes:
            continue
        seen_hashes.add(fingerprint)
        kept.append(entry)
    return kept

def verify_no_overlap(train_datas, val_datas, test_datas, hash_fn):
    """Assert that no item is shared between the train, validation and test splits.

    Items are compared by ``hash_fn(item)``. Raises AssertionError naming the first
    overlapping pair of splits; otherwise prints a confirmation message.
    """
    train_hashes, val_hashes, test_hashes = (
        {hash_fn(item) for item in split}
        for split in (train_datas, val_datas, test_datas)
    )

    pairwise_checks = (
        (train_hashes, val_hashes, "Overlap between train and validation"),
        (train_hashes, test_hashes, "Overlap between train and test"),
        (val_hashes, test_hashes, "Overlap between validation and test"),
    )
    for left, right, message in pairwise_checks:
        assert left.isdisjoint(right), message

    print("The data sets are well disjointed.")

In [None]:
# Merge the official IMDB train/test splits so they can be re-split below,
# dropping reviews whose text occurs more than once.
train_iter = IMDB(split='train')
test_iter = IMDB(split='test')
raw_data = list(train_iter) + list(test_iter)
all_data = remove_duplicates(raw_data)
print(f"Removed {len(raw_data) - len(all_data)} duplicate reviews; {len(all_data)} remain.")

# Sanity check: after de-duplication every remaining text must hash uniquely.
all_hashes = [hash_data(data) for data in all_data]
assert len(all_hashes) == len(set(all_hashes)), "Caution: Duplicate data exists!"
print("All data is unique.")

In [None]:
# 80/10/10 split of the de-duplicated reviews. A dedicated, seeded generator
# makes the split identical across kernel restarts, independent of any other
# use of torch's global RNG.
SPLIT_SEED = 42
TRAIN_SIZE = int(0.8 * len(all_data))
VAL_SIZE = int(0.1 * len(all_data))
TEST_SIZE = len(all_data) - TRAIN_SIZE - VAL_SIZE

train_datas, val_datas, test_datas = random_split(
    all_data,
    [TRAIN_SIZE, VAL_SIZE, TEST_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# Cheap sanity check: redundant after `remove_duplicates`, but it guards against
# leakage if de-duplication is ever skipped or changed.
verify_no_overlap(train_datas, val_datas, test_datas, hash_fn=hash_data)

### Basic loaders

In [None]:
def basic_collate_batch(batch):
    """Convert a list of ``(label, text)`` pairs into padded token ids and float labels.

    Each text is tokenized with ``basic_tokenizer``, mapped to ids through
    ``basic_vocab`` and padded at the end with ``BASIC_PAD_IDX`` up to the longest
    sequence in the batch. Labels are shifted down by one (presumably IMDB's
    {1, 2} -> {0, 1}; verify against the dataset version).

    Returns:
        (text_tensor, label_tensor): int32 ids of shape (batch, max_len) and
        float32 labels of shape (batch,).
    """
    sequences = [
        torch.tensor(basic_vocab(basic_tokenizer(text)), dtype=torch.int32)
        for _, text in batch
    ]
    targets = [label - 1 for label, _ in batch]
    padded = pad_sequence(sequences, batch_first=True, padding_value=BASIC_PAD_IDX)
    return padded, torch.tensor(targets, dtype=torch.float32)

In [None]:
# Only training batches are shuffled (and the incomplete last batch dropped).
# Validation/test are kept in a fixed order so evaluation is deterministic:
# shuffling them changed batch composition (hence padding) between epochs and
# consumed global RNG state. The test loader uses one review per batch.
basic_train_loader = DataLoader(train_datas, batch_size=BATCH_SIZE, shuffle=True, collate_fn=basic_collate_batch, drop_last=True)
basic_val_loader = DataLoader(val_datas, batch_size=BATCH_SIZE, shuffle=False, collate_fn=basic_collate_batch, drop_last=False)
basic_test_loader = DataLoader(test_datas, batch_size=1, shuffle=False, collate_fn=basic_collate_batch, drop_last=False)

# Print loaders' sizes (number of examples in each underlying dataset)
print(f"Size train loader : {len(basic_train_loader.dataset)}")
print(f"Size validation loader : {len(basic_val_loader.dataset)}")
print(f"Size test loader : {len(basic_test_loader.dataset)}")

### Bert loaders

In [None]:
bert_tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
BERT_PAD_IDX = bert_tokenizer.pad_token_id # 0

def bert_collate_batch(batch):
    """Tokenize a list of ``(label, text)`` pairs for the BERT-embedding models.

    The whole batch is encoded in a single call to the fast tokenizer. Texts are
    truncated to the model's maximum length and padded with the tokenizer's pad
    token (id ``BERT_PAD_IDX``) to the longest sequence in the batch. This gives
    the same ids as tokenizing each text separately and calling ``pad_sequence``,
    without the per-sample Python loop. Labels are shifted down by one
    (presumably IMDB's {1, 2} -> {0, 1}; verify against the dataset version).

    Returns:
        (input_ids, labels): int64 tensor of shape (batch, max_len) and float32
        tensor of shape (batch,).
    """
    labels, texts = zip(*batch)
    encoded = bert_tokenizer(list(texts), truncation=True, padding=True, return_tensors="pt")
    label_tensor = torch.tensor([label - 1 for label in labels], dtype=torch.float32)
    return encoded["input_ids"], label_tensor

# The attention mask (1 for real tokens, 0 for padding) is available as
# encoded["attention_mask"] inside bert_collate_batch if a model needs it.

In [None]:
# Only training batches are shuffled (and the incomplete last batch dropped).
# Validation/test are kept in a fixed order so evaluation is deterministic:
# shuffling them changed batch composition (hence padding) between epochs and
# consumed global RNG state. The test loader uses one review per batch.
bert_train_loader = DataLoader(train_datas, batch_size=BATCH_SIZE, shuffle=True, collate_fn=bert_collate_batch, drop_last=True)
bert_val_loader = DataLoader(val_datas, batch_size=BATCH_SIZE, shuffle=False, collate_fn=bert_collate_batch, drop_last=False)
bert_test_loader = DataLoader(test_datas, batch_size=1, shuffle=False, collate_fn=bert_collate_batch, drop_last=False)

# Print loaders' sizes (number of examples in each underlying dataset)
print(f"Size train loader : {len(bert_train_loader.dataset)}")
print(f"Size validation loader : {len(bert_val_loader.dataset)}")
print(f"Size test loader : {len(bert_test_loader.dataset)}")

## Loading pre-trained embeddings

### GloVe (to be used with the basic vocab)

In [None]:
def load_glove_embeddings(vocab, path=GLOVE_PATH, embedding_dim=EMBEDDING_DIM):
    """Build an embedding matrix for `vocab` from a GloVe text file.

    Only lines whose word is in the vocabulary are parsed. This avoids holding
    every GloVe vector in memory just to discard most of them.

    Args:
        vocab: vocabulary exposing ``get_stoi()`` (token -> index) and ``len()``.
        path: GloVe text file with one ``word v1 v2 ...`` entry per line.
        embedding_dim: number of leading components kept from each vector; may be
            smaller than the file's dimensionality.

    Returns:
        float32 tensor of shape ``(len(vocab), embedding_dim)``. Row ``i`` holds the
        GloVe vector of the token with index ``i``. Tokens absent from the file
        (e.g. the specials) stay all-zero. If a word appears twice, the later line wins.

    Raises:
        ValueError: if an in-vocabulary line has fewer than `embedding_dim` values.
    """
    stoi = vocab.get_stoi()
    weights_matrix = torch.zeros((len(vocab), embedding_dim), dtype=torch.float32)
    with open(path, 'r', encoding='utf8') as f:
        for line_no, line in enumerate(f, start=1):
            parts = line.split(maxsplit=1)
            if not parts:
                continue  # blank line
            idx = stoi.get(parts[0])
            if idx is None:
                continue  # word not in our vocab: skip parsing its vector
            values = parts[1].split() if len(parts) > 1 else []
            if len(values) < embedding_dim:
                raise ValueError(
                    f"{path}:{line_no}: expected at least {embedding_dim} values "
                    f"for {parts[0]!r}, found {len(values)}"
                )
            weights_matrix[idx] = torch.tensor(
                [float(val) for val in values[:embedding_dim]], dtype=torch.float32
            )
    return weights_matrix

# Embedding matrix aligned with basic_vocab's indices; rows for words missing
# from GloVe are zero. (`vocab` was never defined: the vocabulary is `basic_vocab`.)
GLOVE_EMBS = load_glove_embeddings(basic_vocab)
print(GLOVE_EMBS.shape)

### Bert

In [None]:
bert_model = DistilBertModel.from_pretrained("distilbert-base-uncased")
# Detached copy of DistilBERT's token-embedding table. It does not require grad
# and does not share storage with `bert_model`, so using it as a (frozen)
# classifier embedding cannot alter the transformer's weights, or vice versa.
BERT_EMBS = bert_model.embeddings.word_embeddings.weight.detach().clone()
BERT_EMB_SIZE = BERT_EMBS.size(1)  # 768 for distilbert-base-uncased
print(BERT_EMBS.shape)

## Define models

In [None]:
class LSTMClassifier(nn.Module):
    """Binary text classifier: embedding -> optional projection -> stacked LSTM -> linear -> sigmoid.

    Args:
        vocab_size: rows of the trainable embedding (unused when ``pretrained``).
        embed_size: LSTM input size. If ``pretrained_embs`` has a different width,
            a linear layer ``fc1`` projects the embeddings to ``embed_size``.
        hidden_size: LSTM hidden size.
        output_size: width of the final linear layer.
        dropout: inter-layer LSTM dropout. It is disabled for a single layer,
            where PyTorch would ignore it with a warning.
        num_layers: number of stacked LSTM layers.
        pretrained_embs: 2-D weight matrix used (frozen) when ``pretrained``.
        pretrained: if True, use ``pretrained_embs`` instead of a trainable table.
        pad_idx: padding token id. The classifier reads the LSTM output at the last
            non-padding position of each sequence, so trailing padding no longer
            dilutes short reviews. If the pad id doubles as another token (e.g.
            '<unk>'), trailing occurrences of that token are treated as padding.
            Pass the pad id of whatever tokenizer produced the inputs. ``None``
            reads the final time step, i.e. the previous behaviour.

    ``forward(x)`` takes token ids shaped ``(batch, seq_len)`` (``batch_first=True``)
    and returns sigmoid outputs of shape ``(batch, output_size)``.
    """

    def __init__(self, vocab_size=BASIC_VOCAB_SIZE, embed_size=EMBEDDING_DIM, hidden_size=HIDDEN_DIM,
                 output_size=OUTPUT_DIM, dropout=DROPOUT, num_layers=NUM_LAYERS, pretrained_embs=GLOVE_EMBS,
                 pretrained=False, pad_idx=BASIC_PAD_IDX):
        super().__init__()
        if pretrained:
            self.embedding = nn.Embedding.from_pretrained(pretrained_embs, freeze=True) # To ensure weights are not trained
        else:
            self.embedding = nn.Embedding(vocab_size, embed_size)

        self.add_fc = False
        if pretrained and pretrained_embs.size(1) != embed_size:
            self.fc1 = nn.Linear(pretrained_embs.size(1), embed_size)
            self.add_fc = True

        self.pad_idx = pad_idx
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers=num_layers,
                            dropout=dropout if num_layers > 1 else 0.0, batch_first=True)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()

    def _last_token_index(self, x):
        """Return, per row of ``x``, the index of its last non-padding token (0 if all padding)."""
        batch_size, seq_len = x.shape
        if self.pad_idx is None:
            return torch.full((batch_size,), seq_len - 1, dtype=torch.long, device=x.device)
        positions = torch.arange(1, seq_len + 1, device=x.device)
        lengths = ((x != self.pad_idx).long() * positions).amax(dim=1)
        return (lengths - 1).clamp(min=0)

    def forward(self, x):
        embedded = self.embedding(x)
        if self.add_fc:
            embedded = self.fc1(embedded)
        lstm_out, _ = self.lstm(embedded)
        # For a unidirectional LSTM, the top-layer output at step t only depends on
        # tokens 0..t, so taking it at the last real token ignores trailing padding.
        rows = torch.arange(x.size(0), device=x.device)
        last_hidden = lstm_out[rows, self._last_token_index(x)]
        output = self.fc2(last_hidden)
        return self.sigmoid(output)

In [None]:
class GRUClassifier(nn.Module):
    """Binary text classifier: embedding -> optional projection -> stacked GRU -> linear -> sigmoid.

    Args:
        vocab_size: rows of the trainable embedding (unused when ``pretrained``).
        embed_size: GRU input size. If ``pretrained_embs`` has a different width,
            a linear layer ``fc1`` projects the embeddings to ``embed_size``.
        hidden_size: GRU hidden size.
        output_size: width of the final linear layer.
        dropout: inter-layer GRU dropout. It is disabled for a single layer,
            where PyTorch would ignore it with a warning.
        num_layers: number of stacked GRU layers.
        pretrained_embs: 2-D weight matrix used (frozen) when ``pretrained``.
        pretrained: if True, use ``pretrained_embs`` instead of a trainable table.
        pad_idx: padding token id. The classifier reads the GRU output at the last
            non-padding position of each sequence instead of the final time step,
            which for shorter reviews in a batch came after many padding tokens.
            If the pad id doubles as another token (e.g. '<unk>'), trailing
            occurrences of that token are treated as padding. Pass the pad id of
            whatever tokenizer produced the inputs. ``None`` reads the final time
            step, i.e. the previous behaviour.

    ``forward(x)`` takes token ids shaped ``(batch, seq_len)`` (``batch_first=True``)
    and returns sigmoid outputs of shape ``(batch, output_size)``.
    """

    def __init__(self, vocab_size=BASIC_VOCAB_SIZE, embed_size=EMBEDDING_DIM, hidden_size=HIDDEN_DIM,
                 output_size=OUTPUT_DIM, dropout=DROPOUT, num_layers=NUM_LAYERS, pretrained_embs=GLOVE_EMBS,
                 pretrained=False, pad_idx=BASIC_PAD_IDX):
        super().__init__()
        if pretrained:
            self.embedding = nn.Embedding.from_pretrained(pretrained_embs, freeze=True) # To ensure weights are not trained
        else:
            self.embedding = nn.Embedding(vocab_size, embed_size)

        self.add_fc = False
        if pretrained and pretrained_embs.size(1) != embed_size:
            self.fc1 = nn.Linear(pretrained_embs.size(1), embed_size)
            self.add_fc = True

        self.pad_idx = pad_idx
        self.gru = nn.GRU(embed_size, hidden_size, num_layers=num_layers,
                          dropout=dropout if num_layers > 1 else 0.0, batch_first=True)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()

    def _last_token_index(self, x):
        """Return, per row of ``x``, the index of its last non-padding token (0 if all padding)."""
        batch_size, seq_len = x.shape
        if self.pad_idx is None:
            return torch.full((batch_size,), seq_len - 1, dtype=torch.long, device=x.device)
        positions = torch.arange(1, seq_len + 1, device=x.device)
        lengths = ((x != self.pad_idx).long() * positions).amax(dim=1)
        return (lengths - 1).clamp(min=0)

    def forward(self, x):
        embedded = self.embedding(x)
        if self.add_fc:
            embedded = self.fc1(embedded)
        gru_out, _ = self.gru(embedded)
        # For a unidirectional GRU, the top-layer output at step t only depends on
        # tokens 0..t, so taking it at the last real token ignores trailing padding.
        rows = torch.arange(x.size(0), device=x.device)
        last_hidden_state = gru_out[rows, self._last_token_index(x)]
        output = self.fc2(last_hidden_state)
        return self.sigmoid(output)

## Train and test functions

In [None]:
def _run_epoch(model, data_loader, criterion, optimizer=None):
    """Shared train/eval loop: one pass over `data_loader`, updating `model` iff `optimizer` is given.

    Returns (loss per sample, accuracy, F1 as a 0-dim tensor). Predictions are
    ``model output >= 0.5``.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, total_correct, n_samples = 0.0, 0, 0
    f1_metric = BinaryF1Score().to(DEVICE)

    # Gradients are only tracked when training (equivalent to torch.no_grad() for eval).
    with torch.set_grad_enabled(training):
        for text, labels in tqdm(data_loader):
            text, labels = text.to(DEVICE), labels.to(DEVICE)
            if training:
                optimizer.zero_grad()
            output = model(text).squeeze(dim=1)  # (batch, 1) -> (batch,), also safe for batch size 1
            loss = criterion(output, labels)
            if training:
                loss.backward()
                optimizer.step()

            # Metrics part
            preds = output >= 0.5
            total_loss += loss.item()
            total_correct += (preds == labels.bool()).sum().item()
            n_samples += labels.size(0)
            # torchmetrics expects integer targets for binary classification
            f1_metric.update(preds.long(), labels.long())

    if n_samples == 0:
        raise ValueError("data_loader yielded no samples")
    # Dividing by the sample count assumes `criterion` sums over the batch (reduction='sum').
    return total_loss / n_samples, total_correct / n_samples, f1_metric.compute()


# Train function
def train_epoch(model, data_loader, criterion, optimizer):
    """Run one training epoch of `model` over `data_loader` with `optimizer`.

    Args:
        model: module mapping a batch of token ids to probabilities; its output is
            squeezed on dim 1 (assumed shape (batch, 1) — TODO confirm for new models).
        data_loader: iterable of ``(token_ids, labels)`` batches with float labels
            (assumed to be 0/1).
        criterion: loss on (probabilities, labels). The returned loss is divided by
            the number of samples, so a ``reduction='sum'`` loss gives a per-sample mean.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        ``(mean loss per sample, accuracy, F1)``; F1 is a 0-dim tensor.

    Raises:
        ValueError: if `data_loader` yields no samples.
    """
    return _run_epoch(model, data_loader, criterion, optimizer)


# Test function
def evaluate(model, data_loader, criterion):
    """Compute loss, accuracy and F1 of `model` on `data_loader` without updating it.

    Runs in eval mode with gradients disabled. Takes the same arguments and
    returns the same values as ``train_epoch``, minus the optimizer.
    """
    return _run_epoch(model, data_loader, criterion, optimizer=None)
    
def count_trainable_parameters(model):
    """Return how many scalar parameters of `model` have ``requires_grad=True``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

### Basic training

In [None]:
# Baseline run: GRU over trainable, randomly initialised embeddings (basic vocab).
model = GRUClassifier(pretrained=False).to(DEVICE)
criterion = nn.BCELoss(reduction='sum')  # summed; train_epoch/evaluate divide by the sample count
optimizer = optim.Adam(model.parameters(), lr=1e-3)
print(f"Number of trainable params: {count_trainable_parameters(model)}")

for epoch in range(NUM_EPOCHS):
    start_time = time.time()
    train_loss, train_acc, train_f1 = train_epoch(model, basic_train_loader, criterion, optimizer)
    val_loss, val_acc, val_f1 = evaluate(model, basic_val_loader, criterion)
    end_time = time.time()

    epoch_report = [
        f"Epoch: {epoch+1}/{NUM_EPOCHS}",
        f"Train Loss: {train_loss:.4f}",
        f"Train Acc: {train_acc:.4f}",
        f"Train F1: {train_f1:.4f}",
        f"Val Loss: {val_loss:.4f}",
        f"Val Acc: {val_acc:.4f}",
        f"Val F1: {val_f1:.4f}",
        f"Time: {end_time-start_time:.2f}s",
    ]
    print(" | ".join(epoch_report))

### Bert training

In [None]:
# GRU over frozen DistilBERT token embeddings (BERT tokenizer / loaders).
model = GRUClassifier(
    pretrained=True,
    pretrained_embs=BERT_EMBS,
    embed_size=BERT_EMB_SIZE,
).to(DEVICE)
criterion = nn.BCELoss(reduction='sum')  # summed; train_epoch/evaluate divide by the sample count
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
print(f"Number of trainable params: {count_trainable_parameters(model)}")

for epoch in range(NUM_EPOCHS):
    start_time = time.time()
    train_loss, train_acc, train_f1 = train_epoch(model, bert_train_loader, criterion, optimizer)
    val_loss, val_acc, val_f1 = evaluate(model, bert_val_loader, criterion)
    end_time = time.time()

    epoch_report = [
        f"Epoch: {epoch+1}/{NUM_EPOCHS}",
        f"Train Loss: {train_loss:.4f}",
        f"Train Acc: {train_acc:.4f}",
        f"Train F1: {train_f1:.4f}",
        f"Val Loss: {val_loss:.4f}",
        f"Val Acc: {val_acc:.4f}",
        f"Val F1: {val_f1:.4f}",
        f"Time: {end_time-start_time:.2f}s",
    ]
    print(" | ".join(epoch_report))