In [None]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler, random_split
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import copy
import transformers
import numpy as np
import random
from transformers import AdamW, BertTokenizer, get_linear_schedule_with_warmup
from datasets import load_dataset
from tqdm import tqdm
# from model import SentimentClassifierWithMultipleHeads

In [None]:
# Fetch the IMDB movie-review dataset from the Hugging Face Hub (reused from the local cache on re-runs).
dataset = load_dataset(path="imdb")

In [None]:
# Hold out 10% of the official IMDB train split as a validation set. Stratifying on the
# label keeps the class balance of both parts identical to the source split.
# The official test split is used untouched as the test set.
df_train_full = dataset['train'].to_pandas()
df_train, df_val = train_test_split(
    df_train_full,
    test_size=0.1,
    random_state=42,
    stratify=df_train_full['label'],
)
df_test = dataset['test'].to_pandas()

In [None]:
# Name of the DataFrame column that holds the raw review text.
text_column_name: str = "text"

In [None]:
# Plain Python lists of review texts and integer labels for each split.
train_messages, train_labels = df_train[text_column_name].tolist(), df_train['label'].tolist()
val_messages, val_labels = df_val[text_column_name].tolist(), df_val['label'].tolist()
test_messages, test_labels = df_test[text_column_name].tolist(), df_test['label'].tolist()

In [None]:
# Preview only the first rows instead of rendering the whole training frame of full-length reviews.
df_train.head()

In [None]:
# Tokenizer for the uncased BERT-base checkpoint; do_lower_case=True matches its uncased vocabulary.
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-uncased',
    do_lower_case=True,
)

In [None]:
# Show one training review as raw text, as word-piece tokens, and as vocabulary ids.
sample_text = train_messages[0]
sample_tokens = tokenizer.tokenize(sample_text)

print(' Original: ', sample_text)
print('Tokenized: ', sample_tokens)
print('Token IDs: ', tokenizer.convert_tokens_to_ids(sample_tokens))

In [None]:
# Every review is padded or truncated to this many tokens (512 is BERT-base's upper limit).
max_len: int = 512

In [None]:
def tokenize_texts(texts, tokenizer, max_len):
    """Tokenize a sequence of strings into fixed-length model inputs.

    Args:
        texts: iterable of raw strings.
        tokenizer: a Hugging Face tokenizer (anything exposing the callable
            ``tokenizer(texts, **kwargs)`` API).
        max_len: every sequence is padded or truncated to exactly this many
            tokens, including the special tokens added by the tokenizer.

    Returns:
        (input_ids, attention_masks): two tensors of shape (len(texts), max_len).
        For empty input both have shape (0, max_len).
    """
    texts = list(texts)
    if not texts:
        # torch.cat on an empty list would raise; return correctly-shaped empty tensors instead.
        empty = torch.empty((0, max_len), dtype=torch.long)
        return empty, empty.clone()

    encoded = tokenizer(
        texts,
        add_special_tokens=True,   # add [CLS] / [SEP]
        padding='max_length',      # replaces the deprecated pad_to_max_length=True
        truncation=True,           # explicit: long reviews are cut to max_len
        max_length=max_len,
        return_attention_mask=True,
        return_tensors='pt',
    )
    return encoded['input_ids'], encoded['attention_mask']

In [None]:
# Tokenize every split once; each call yields (input_ids, attention_mask) tensors.
train_inputs, train_masks = tokenize_texts(train_messages, tokenizer, max_len)
val_inputs, val_masks = tokenize_texts(val_messages, tokenizer, max_len)
test_inputs, test_masks = tokenize_texts(test_messages, tokenizer, max_len)

# Labels become tensors so they can sit next to the token tensors in a TensorDataset.
train_labels, val_labels, test_labels = [
    torch.tensor(split_labels) for split_labels in (train_labels, val_labels, test_labels)
]

In [None]:
# Print text counts and tensor shapes per split so any misalignment is visible at a glance.
for split_name, messages, inputs, masks, labels in (
    ("Train", train_messages, train_inputs, train_masks, train_labels),
    ("Validation", val_messages, val_inputs, val_masks, val_labels),
    ("Test", test_messages, test_inputs, test_masks, test_labels),
):
    print(f"{split_name} Messages: ", len(messages))
    print(f"{split_name} Inputs: ", inputs.shape)
    print(f"{split_name} Masks: ", masks.shape)
    print(f"{split_name} Labels: ", labels.shape)

In [None]:
# Each dataset item is an (input_ids, attention_mask, label) triple indexed in lock-step.
train_dataset, val_dataset, test_dataset = (
    TensorDataset(*split_tensors)
    for split_tensors in (
        (train_inputs, train_masks, train_labels),
        (val_inputs, val_masks, val_labels),
        (test_inputs, test_masks, test_labels),
    )
)

In [None]:
batch_size = 32

# Training batches are drawn in a fresh random order each epoch;
# validation and test batches keep dataset order so evaluation is reproducible.
train_dataloader = DataLoader(
    train_dataset, sampler=RandomSampler(train_dataset), batch_size=batch_size
)
validation_dataloader = DataLoader(
    val_dataset, sampler=SequentialSampler(val_dataset), batch_size=batch_size
)
test_dataloader = DataLoader(
    test_dataset, sampler=SequentialSampler(test_dataset), batch_size=batch_size
)

In [None]:
# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Display the selected device so the saved output records whether training ran on GPU or CPU.
device

In [None]:
from transformers import AutoModel, AutoTokenizer
import torch
from torch import nn

class SentimentClassifierWithMultipleHeads(nn.Module):
    """Pretrained transformer encoder with one classification head per encoder layer.

    Each head mean-pools the token states of its own layer (ignoring padded
    positions) and maps them to ``num_labels`` logits, so the predictive power
    of every layer can be compared side by side.
    """

    def __init__(self, model_name, num_labels, head_hidden_size=128, head_dropout=0.2):
        """Load the encoder/tokenizer for ``model_name`` and build one head per layer.

        Args:
            model_name: Hugging Face checkpoint name, e.g. 'bert-base-uncased'.
            num_labels: number of output classes per head.
            head_hidden_size: width of each head's hidden layer (previously fixed at 128).
            head_dropout: dropout probability inside each head (previously fixed at 0.2).
        """
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.num_labels = num_labels
        config = self.model.config
        # Head input width and count come from the encoder config instead of a hard-coded
        # 768 / 12. nn.ModuleList (not a plain list) registers the heads as submodules, so
        # their weights appear in .parameters() for the optimizer and follow model.to(device).
        self.classification_heads = self._build_heads(
            config.hidden_size, config.num_hidden_layers, num_labels, head_hidden_size, head_dropout
        )

    @staticmethod
    def _build_heads(hidden_size, num_heads, num_labels, head_hidden_size=128, head_dropout=0.2):
        """Create ``num_heads`` independent MLP classifiers mapping hidden_size -> num_labels."""
        return nn.ModuleList(
            nn.Sequential(
                nn.Linear(hidden_size, head_hidden_size),
                nn.ReLU(),
                nn.Dropout(head_dropout),
                nn.Linear(head_hidden_size, num_labels),
            )
            for _ in range(num_heads)
        )

    @staticmethod
    def _masked_mean(hidden, attention_mask):
        """Average token vectors over the sequence axis (dim 1), skipping masked-out positions."""
        if attention_mask is None:
            return hidden.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        # clamp avoids division by zero for a row whose mask is all zeros.
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

    def _head_logits(self, input_ids, attention_mask):
        """Run the encoder and return a list of logits tensors, one per head."""
        outputs = self.model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        # hidden_states[0] is the embedding output; head i reads hidden_states[i + 1].
        layer_states = outputs.hidden_states[1:]
        return [
            head(self._masked_mean(states, attention_mask))
            for head, states in zip(self.classification_heads, layer_states)
        ]

    def forward(self, input_ids, attention_mask, labels=None):
        """Compute per-head logits and, when ``labels`` is given, per-head losses.

        Args:
            input_ids, attention_mask: encoder inputs for one batch.
            labels: optional targets, either integer class ids of shape (batch,)
                or one-hot / probability rows of shape (batch, num_labels).

        Returns:
            (loss, logits): ``loss`` is a list with one scalar cross-entropy per
            head, or None when ``labels`` is None; ``logits`` is a list of
            (batch, num_labels) tensors, one per head.
        """
        logits = self._head_logits(input_ids, attention_mask)
        if labels is None:
            return None, logits
        # cross_entropy takes class ids (long) or class probabilities (float, same rank as logits).
        target = labels.float() if labels.dim() == logits[0].dim() else labels.long()
        loss = [torch.nn.functional.cross_entropy(logit, target) for logit in logits]
        return loss, logits

    @torch.no_grad()
    def predict(self, input_ids, attention_mask):
        """Return a list with one tensor of predicted class ids (shape (batch,)) per head.

        Does not switch train/eval mode; call ``.eval()`` first so dropout is disabled.
        """
        # argmax of the logits equals argmax of their softmax, so the softmax is skipped.
        return [logit.argmax(dim=-1) for logit in self._head_logits(input_ids, attention_mask)]

In [None]:
model = SentimentClassifierWithMultipleHeads('bert-base-uncased', 2)
model = model.to(device)

# BERT fine-tuning is typically run with learning rates around 2e-5 to 5e-5; the optimizer
# default of 1e-3 tends to destabilise the pretrained encoder.
learning_rate = 2e-5
# torch.optim.AdamW replaces the deprecated transformers.AdamW; weight_decay=0.0 keeps that
# class's default of no weight decay.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.0)

epochs = 10

# One scheduler step per training batch over the whole run.
total_steps = len(train_dataloader) * epochs

# Linear decay from learning_rate to 0 over total_steps, with no warm-up.
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

def flat_accuracy(preds, labels):
    """Fraction of rows whose highest-scoring column matches the label.

    Args:
        preds: 2-D array-like of per-class scores, one row per example.
        labels: numpy array of class ids; flattened before comparison.

    Returns:
        Accuracy as a numpy float.
    """
    predicted = np.argmax(preds, axis=1).ravel()
    expected = labels.ravel()
    hits = (predicted == expected).sum()
    return hits / len(expected)


def compute_aggregate_and_classwise_metrics(predicted_labels, true_labels):
    """Return sklearn's classification report for one head as a nested dict.

    Args:
        predicted_labels: sequence of predicted class ids.
        true_labels: sequence of ground-truth class ids, same length.

    Returns:
        dict with one entry per class and aggregate entries such as 'macro avg'.
    """
    # zero_division=0 reports the same 0.0 that sklearn's default ("warn") would for a class
    # a head never predicts, but without an UndefinedMetricWarning for every head on every
    # evaluation pass.
    return classification_report(
        true_labels, predicted_labels, output_dict=True, zero_division=0
    )

In [None]:
seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

# Set to a positive float (e.g. 1.0) to clip the global gradient norm before each optimizer step.
max_grad_norm = None


def _batch_to_device(batch):
    """Move an (input_ids, attention_mask, labels) batch to `device` and add one-hot labels."""
    b_input_ids, b_input_mask, b_labels = (t.to(device) for t in batch)
    b_labels_one_hot = torch.nn.functional.one_hot(b_labels, num_classes=model.num_labels)
    return b_input_ids, b_input_mask, b_labels, b_labels_one_hot


def train_one_epoch(dataloader, desc):
    """Run one optimisation pass; return the mean over batches of the summed head losses."""
    model.train()
    total_loss = 0.0
    # tqdm wraps the loader locally: reassigning the loader variable to its tqdm wrapper
    # (as before) nested one more progress bar around it on every epoch.
    for batch in tqdm(dataloader, desc=desc):
        b_input_ids, b_input_mask, _, b_labels_one_hot = _batch_to_device(batch)
        optimizer.zero_grad()
        head_losses, _ = model(b_input_ids, b_input_mask, b_labels_one_hot)
        # All heads are trained jointly by back-propagating the sum of their losses.
        loss = sum(head_losses)
        total_loss += loss.item()
        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()
    return total_loss / len(dataloader)


def evaluate(dataloader, desc):
    """Score every head on `dataloader` without gradients.

    Returns:
        (avg_loss, metrics): mean over batches of the summed head losses, and a
        list with one compute_aggregate_and_classwise_metrics dict per head.
    """
    model.eval()
    total_loss = 0.0
    head_preds = None  # one list of predicted class ids per head, created on the first batch
    true_labels = []   # local name: the global `test_labels` tensor is no longer shadowed
    for batch in tqdm(dataloader, desc=desc):
        b_input_ids, b_input_mask, b_labels, b_labels_one_hot = _batch_to_device(batch)
        with torch.no_grad():
            head_losses, head_logits = model(b_input_ids, b_input_mask, b_labels_one_hot)
        total_loss += sum(head_losses).item()
        if head_preds is None:
            head_preds = [[] for _ in head_logits]
        for preds, logits in zip(head_preds, head_logits):
            preds.extend(logits.argmax(dim=-1).cpu().tolist())
        true_labels.extend(b_labels.cpu().tolist())
    metrics = [compute_aggregate_and_classwise_metrics(preds, true_labels) for preds in head_preds]
    return total_loss / len(dataloader), metrics


all_stats = []

for epoch_i in range(epochs):
    avg_train_loss = train_one_epoch(train_dataloader, f"Epoch {epoch_i + 1}")
    print("Average train loss: {}".format(avg_train_loss))

    avg_val_loss, val_metrics = evaluate(validation_dataloader, "Validation")
    print("Validation Loss: {}".format(avg_val_loss))
    for i, head_metrics in enumerate(val_metrics):
        print("Validation Metrics for Head {}: {}".format(i, head_metrics))

    # Test loss is now accumulated into its own total (it was previously added to the
    # validation total, so the reported test loss was always 0). The test split is scored
    # each epoch for monitoring only; select epochs/heads on the validation metrics.
    avg_test_loss, test_metrics = evaluate(test_dataloader, "Test")
    print("Test Loss: {}".format(avg_test_loss))
    for i, head_metrics in enumerate(test_metrics):
        print("Test Metrics for Head {}: {}".format(i, head_metrics))

    all_stats.append({
        'epoch': epoch_i + 1,
        'train_loss': avg_train_loss,
        'val_loss': avg_val_loss,
        'test_loss': avg_test_loss,
        'val_metrics': val_metrics,
        'test_metrics': test_metrics,
    })

print("Training complete!")
    

In [None]:
# Plot how each head's validation macro-F1 evolves over the training epochs.
import matplotlib.pyplot as plt

num_val_heads = len(all_stats[0]['val_metrics']) if all_stats else 0
val_macrof1_per_head = {
    head: [stat['val_metrics'][head]['macro avg']['f1-score'] for stat in all_stats]
    for head in range(num_val_heads)
}
# Epochs are numbered from 1, matching the 'epoch' field stored in all_stats.
epoch_numbers = [stat['epoch'] for stat in all_stats]

# Size the figure when it is created: assigning plt.rcParams after plotting (as before)
# does not resize the figure that is already drawn.
fig, ax = plt.subplots(figsize=(20, 10))
for head, scores in val_macrof1_per_head.items():
    ax.plot(epoch_numbers, scores, marker='o', label='Head {}'.format(head))
ax.set_xticks(epoch_numbers)
ax.set(xlabel='Epoch', ylabel='Macro F1 Score', title='Validation Macro F1 Score Per Head')
ax.legend()
plt.show()

In [None]:
# Same view for the held-out test split: one line per head, macro-F1 per epoch.
num_test_heads = len(all_stats[0]['test_metrics']) if all_stats else 0
test_macrof1_per_head = {
    head: [stat['test_metrics'][head]['macro avg']['f1-score'] for stat in all_stats]
    for head in range(num_test_heads)
}
# Epochs are numbered from 1, matching the 'epoch' field stored in all_stats.
test_epoch_numbers = [stat['epoch'] for stat in all_stats]

# Figure size is set at creation; changing plt.rcParams after plotting has no effect on it.
fig, ax = plt.subplots(figsize=(20, 10))
for head, scores in test_macrof1_per_head.items():
    ax.plot(test_epoch_numbers, scores, marker='o', label='Head {}'.format(head))
ax.set_xticks(test_epoch_numbers)
ax.set(xlabel='Epoch', ylabel='Macro F1 Score', title='Test Macro F1 Score Per Head')
ax.legend()
plt.show()