## <img src="https://img.icons8.com/bubbles/50/000000/checklist.png" style="height:50px;display:inline"> Table of Contents
---

* [Data Preprocessing](#data-preprocessing)
    * [Data Loading](#data-loading)
    * [Data Augmentation](#data-augmentation)
    * [Data Tokenization and Vectorization](#data-tokenization-and-vectorization)
    * [DataLoaders](#dataloaders)
    * [Statistics](#statistics)
* [Training Functions](#training-functions)
* [Hyper Parameters](#hyper-parameters)
* [Models](#models)
    * [LSTM Model](#lstm-model)
    * [xLSTM Model](#xlstm-model)
    * [GRU Model](#gru-model)
    * [RWKV Model](#rwkv-model)
* [Optuna](#optuna)
    * [Framework](#framework)
    * [LSTM Study](#lstm-study)
    * [GRU Study](#gru-study)
    * [RWKV Study](#rwkv-study)
* [Training](#training)
    * [LSTM Train](#lstm-train)
    * [GRU Train](#gru-train)
    * [RWKV Train](#rwkv-train)
* [Comparison](#comparison)

## <img src="https://img.icons8.com/?size=100&id=uwzWDmqxwaFo&format=png&color=000000" style="height:50px;display:inline"> Imports
---

In [None]:
import os
import time

from collections import Counter

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, precision_score, recall_score, f1_score

import joblib

# uncomment if you are using google colab
# !pip install torchviz
# !pip install textattack

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchtext.vocab import build_vocab_from_iterator
from torchviz import make_dot
import torch.nn.functional as F

import math, os
import logging

from tqdm import tqdm

from textattack.augmentation import EasyDataAugmenter
from transformers import BertModel, BertTokenizer

import datasets
from datasets import load_dataset, concatenate_datasets
# uncomment if you are using google colab
# !pip install optuna
import optuna

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'{device=}')

## <img src="https://img.icons8.com/?size=100&id=D7OBPFcT8dSK&format=png&color=000000" style="height:50px;display:inline"> Data Preprocessing
---

### Data Loading

In [None]:
# Load GoEmotions dataset
dataset = load_dataset("go_emotions")

In [None]:
# Define the classes
classes = ['admiration', 'amusement', 'anger', 'annoyance', 'approval',
           'caring', 'confusion', 'curiosity', 'desire', 'disappointment',
           'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear',
           'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism',
           'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise',
           'neutral']

# Ekman Mapping
primary_emotion_to_sub_emotions = {
    "anger": ["anger", "annoyance", "disapproval"],
    "disgust": ["disgust"],
    "fear": ["fear", "nervousness"],
    "joy": ["joy", "amusement", "approval", "excitement", "gratitude",  "love", "optimism", "relief", "pride", "admiration", "desire", "caring"],
    "sadness": ["sadness", "disappointment", "embarrassment", "grief",  "remorse"],
    "surprise": ["surprise", "realization", "confusion", "curiosity"],
    "neutral": ["neutral"]
}

sub_emotion_to_primary_emotion = {sub: primary for primary, subs in primary_emotion_to_sub_emotions.items() for sub in subs}

In [None]:
raw_train_labels = [label for entry in dataset['train'] for label in entry['labels']]
raw_validation_labels = [label for entry in dataset['validation'] for label in entry['labels']]
raw_test_labels = [label for entry in dataset['test'] for label in entry['labels']]

In [None]:
def disp_labels_distribution(labels, split, classes=classes):
    label_counts = Counter(labels)
    print(f'{label_counts.most_common()=}')

    # Calculate total labels
    total_labels = sum(label_counts.values())
    print(f'{total_labels=}')

    # Prepare data for plotting
    counts = [label_counts.get(i, 0) for i in range(len(classes))]

    # Plotting the bar graph
    fig, ax = plt.subplots(figsize=(8, 8))
    bars = ax.barh(classes, counts)
    ax.set_xlabel('Count')
    ax.set_title(f'Class Distribution in GoEmotions {split} Data')

    # Adding data labels
    for bar in bars:
        width = bar.get_width()
        label_x_pos = width + total_labels * 0.005
        ax.text(label_x_pos, bar.get_y() + bar.get_height()/2, f'{(width/total_labels)*100:.2f}%', va='center')

    plt.show()
    return label_counts.values()

In [None]:
disp_labels_distribution(raw_train_labels, 'Train')
disp_labels_distribution(raw_validation_labels, 'Validation')
disp_labels_distribution(raw_test_labels, 'Test')

In [None]:
print(f"Train Size: {len(dataset['train'])} | Valid Size: {len(dataset['validation'])} | Test Size: {len(dataset['test'])}")

### Data Augmentation

In [None]:
def primary_label_pipeline(labels):
    primary_labels = [sub_emotion_to_primary_emotion[classes[label]] for label in labels]

    primary_emotions = list(primary_emotion_to_sub_emotions.keys())
    primary_labels_indx = [primary_emotions.index(p) for p in primary_labels]

    repetitions = Counter(primary_labels_indx)

    most_common_label, _ = repetitions.most_common(1)[0]
    return most_common_label

In [None]:
eda_augmenter = EasyDataAugmenter()

def balanced_augment(dataset):
    # Count the occurrence of each label in the dataset
    unbalanced_train_counter = Counter([primary_label_pipeline(label) for label in dataset['labels']])
    print(f'{unbalanced_train_counter=}')
    target_count = max(unbalanced_train_counter.values()) // 4

    augmented_items = []

    for label, count in unbalanced_train_counter.items():
        if count < target_count:
            print(f'Augmenting {label=}: {count}/{target_count}')
            # Get the samples of the weak class
            weak_class_samples = [data['text'] for data in dataset if primary_label_pipeline(data['labels']) == label]

            # Calculate the number of augmentations needed
            num_augmentations = target_count - count

            # Augment the weak class samples
            for _ in tqdm(range(num_augmentations)):
                sample_to_augment = weak_class_samples[_ % len(weak_class_samples)]
                augmented_samples = eda_augmenter.augment(sample_to_augment)
                augmented_items.extend([{'text': aug_text, 'labels': [classes.index(list(primary_emotion_to_sub_emotions.keys())[label])], 'id': None} for aug_text in augmented_samples])

    # Create a new dataset from the augmented items
    augmented_dataset = datasets.Dataset.from_list(augmented_items).cast(dataset.features)

    # Concatenate the original dataset with the augmented dataset
    merged_dataset = concatenate_datasets([dataset, augmented_dataset])

    # Check the new label distribution
    balanced_train_counter = Counter([primary_label_pipeline(label) for label in merged_dataset['labels']])
    print(f'{balanced_train_counter=}')

    return merged_dataset, dataset

In [None]:
dataset = load_dataset("go_emotions")

print("========= Augmenting Train Dataset =========")
aug_train, train = balanced_augment(dataset['train'])
dataset['train'] = aug_train

# print("========= Augmenting Validation Dataset =========")
# aug_validation, validation = balanced_augment(dataset['validation'])
# dataset['validation'] = aug_validation

# print("========= Augmenting Test Dataset =========")
# aug_test, test = balanced_augment(dataset['test'])
# dataset['test'] = aug_test

# if not os.path.isdir('datasets'):
#     os.mkdir('datasets')
# dataset.save_to_disk('./datasets/')

### Data Tokenization and Vectorization

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
def yield_tokens(data_iter):
    for data in data_iter:
        text = data['text']
        yield tokenizer.tokenize(text)

vocab = build_vocab_from_iterator(yield_tokens(iter(dataset['train'])), specials=["<unk>", "<pad>"])
vocab.set_default_index(vocab["<unk>"])

In [None]:
bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)

# Freeze BERT parameters
for param in bert_model.parameters():
    param.requires_grad = False

# Create an embedding matrix for the vocabulary
embedding_dim = bert_model.config.hidden_size
vocab_dim = len(vocab)
embedding_matrix = torch.zeros(vocab_dim, embedding_dim)

# Get the embeddings for each token
for token, idx in tqdm(vocab.get_stoi().items()):
    inputs = tokenizer(token, return_tensors='pt').to(device)
    outputs = bert_model(**inputs)
    embeddings = outputs.last_hidden_state
    # Take the average of the token embeddings
    embedding_matrix[idx] = embeddings.mean(dim=1).squeeze()

In [None]:
np.savez_compressed('embedding_matrix.npz', embeddings=embedding_matrix.cpu().numpy(), vocab_dim=vocab_dim)
torch.save(vocab, 'vocab.pth')

### DataLoaders

In [None]:
text_pipeline = lambda x: vocab(tokenizer.tokenize(x))

label_pipeline = {
    'train': lambda x: primary_label_pipeline(x),
    'validation': lambda x: primary_label_pipeline(x),
    'test': lambda x: primary_label_pipeline(x)
}

In [None]:
max_seq_len = 30

def collate_batch(batch, split):
    label_list, text_tokenized_list = [], []
    for data in batch:
        _text, _labels, _id = data.values()
        label_list.append(label_pipeline[split](_labels))
        processed_text = torch.tensor(text_pipeline(_text)[:max_seq_len], dtype=torch.int64, device=device)
        if processed_text.shape[0] < max_seq_len:
            pad = vocab(['<pad>'])[0] * torch.ones(max_seq_len - len(processed_text), dtype=torch.int64, device=processed_text.device)
            processed_text = torch.cat([processed_text, pad])
        text_tokenized_list.append(processed_text)
    label_list = torch.tensor(label_list, dtype=torch.int64, device=device)
    text_tokenized_list = torch.stack(text_tokenized_list, dim=0)
    return label_list, text_tokenized_list

In [None]:
batch_size = 64

train_dataloader = DataLoader(
    dataset['train'],
    batch_size=batch_size,
    shuffle=True,
    collate_fn=lambda batch: collate_batch(batch, 'train'))

valid_dataloader = DataLoader(
    dataset['validation'],
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda batch: collate_batch(batch, 'validation'))

test_dataloader = DataLoader(
    dataset['test'],
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda batch: collate_batch(batch, 'test'))

### Statistics

In [None]:
primary_emotions = list(primary_emotion_to_sub_emotions.keys())

In [None]:
def calc_primary_emotions_distribution(dataloader, split):
    total_labels = np.array([])
    for _, (labels, _) in enumerate(dataloader):
        total_labels = np.append(total_labels, labels.cpu().numpy())

    return disp_labels_distribution(total_labels, split, classes=primary_emotions)

total_sum_train = calc_primary_emotions_distribution(train_dataloader, 'train')
total_sum_validation = calc_primary_emotions_distribution(valid_dataloader, 'validation')
total_sum_test = calc_primary_emotions_distribution(test_dataloader, 'test')


## <img src="https://img.icons8.com/?size=100&id=114910&format=png&color=000000" style="height:50px;display:inline"> Training Functions
---

In [None]:
def clip_gradient(model, clip_value):
    params = list(filter(lambda p: p.grad is not None, model.parameters()))
    for p in params:
        p.grad.data.clamp_(-clip_value, clip_value)

In [None]:
def count_layers_and_parameters(dummy_model):
    # Count the number of layers
    def count_layers(module):
        if len(list(module.children())) == 0:
            return 1
        return sum(count_layers(child) for child in module.children())

    num_layers = count_layers(dummy_model)

    # Count the number of parameters
    num_params = sum(p.numel() for p in dummy_model.parameters() if p.requires_grad)
    print(f"{type(dummy_model).__name__}: {num_layers=} {num_params=}")

In [None]:
def calc_model_size(dummy_model):
    param_size = 0
    for param in dummy_model.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in dummy_model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    size_all_mb = (param_size + buffer_size) / 1024 ** 2
    print(f"{type(dummy_model).__name__} size: {size_all_mb:.2f} MB")

In [None]:
class EarlyStopping:
    """Early stops the training if validation accuracy doesn't improve after a given patience."""
    def __init__(self, patience=5, verbose=False, delta=0, path='checkpoint.pth', save_th=60, trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation accuracy improved. (Default: 5)
            verbose (bool): If True, prints a message for each validation accuracy improvement. (Default: False)
            delta (float): Minimum change in the monitored quantity to qualify as an improvement. (Default: 0)
            path (str): Path for the checkpoint to be saved to. (Default: 'checkpoint.pth')
            save_th (int): Accuracy threshold for saving the model. (Default: 60)
            trace_func (function): Trace print function. (Default: print)
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_acc_max = -np.Inf
        self.delta = delta
        self.path = path
        self.save_th = save_th
        self.trace_func = trace_func

    def __call__(self, val_acc, model):
        if self.best_score is None:
            self.best_score = val_acc
        elif val_acc < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter}/{self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = val_acc
            if self.best_score > self.save_th:
                self.save_checkpoint(val_acc, model)
            self.counter = 0

    def save_checkpoint(self, val_acc, model):
        if self.verbose:
            self.trace_func(f'{self.best_score}% >= {self.save_th}% - Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_acc_max = val_acc

In [None]:
def calculate_accuracy(model, dataloader):
    model.eval()
    total_correct = 0
    total_images = 0
    confusion_matrix = np.zeros([len(primary_emotions), len(primary_emotions)], int)
    with torch.no_grad():
        for _, (labels, text) in enumerate(dataloader):
            outputs = model(text)[0] if isinstance(model, GoEmotions_RWKV) else model(text)
            _, predicted = torch.max(outputs.data, 1)
            total_images += labels.size(0)
            total_correct += (predicted == labels).sum().item()
            for i, l in enumerate(labels):
                confusion_matrix[l.item(), predicted[i].item()] += 1

    model_accuracy = total_correct / total_images * 100
    return model_accuracy, confusion_matrix


def train_model(model, train_dataloader, valid_dataloader, learning_rate, step_size, optimizer_name, num_epochs):
    optimizer = getattr(torch.optim, optimizer_name)(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, step_size)
    early_stopping = EarlyStopping(verbose=True, path=f'./checkpoints/{type(model).__name__}_best_ckpt.pth')

    train_losses = []
    train_accuracies = []
    validation_accuracies = []
    for epoch in range(1, num_epochs + 1):
        model.train()
        running_loss = 0.0
        epoch_time = time.time()

        for _, (labels, text) in enumerate(train_dataloader):
            outputs = model(text)[0] if isinstance(model, GoEmotions_RWKV) else model(text)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            clip_gradient(model, 1e-1)
            optimizer.step()

            running_loss += loss.item()

        running_loss /= len(train_dataloader)

        train_accuracy, train_cm = calculate_accuracy(model, train_dataloader)
        validation_accuracy, valid_cm = calculate_accuracy(model, valid_dataloader)
        print(f'Epoch [{epoch:2}/{num_epochs}] | Loss: {running_loss:.6f} | Training Accuracy: {train_accuracy:.4f}% | Validation Accuracy: {validation_accuracy:.4f}% | Time: {time.time() - epoch_time:.2f}s | Learning Rate: {scheduler.get_last_lr()}')

        scheduler.step()

        train_losses.append(running_loss)
        train_accuracies.append(train_accuracy)
        validation_accuracies.append(validation_accuracy)

        if epoch % 10 == 0:
            print('==> Saving model ...')
            state = {
                'net': model.state_dict(),
                'epoch': epoch,
            }
            if not os.path.isdir('checkpoints'):
                os.mkdir('checkpoints')
            torch.save(state, f'./checkpoints/{type(model).__name__}_ckpt.pth')

        early_stopping(validation_accuracy, model)
        if early_stopping.early_stop:
            print("Early Stopping")
            break

    return train_losses, train_accuracies, validation_accuracies, train_cm, valid_cm


In [None]:
def plot_statistics(train_losses, train_accuracies, validation_accuracies, train_cm, valid_cm):
    epochs = range(1, len(train_losses) + 1)

    plt.figure(figsize=(14, 5))

    # Plot losses
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, label='Training Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training Loss Over Epochs')
    plt.legend()

    # Plot accuracies
    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_accuracies, label='Training Accuracy', color='orange')
    plt.plot(epochs, validation_accuracies, label='Validation Accuracy', color='green')
    max_valid_acc_indx = validation_accuracies.index(max(validation_accuracies)) + 1
    plt.axvline(max_valid_acc_indx, linestyle='--', color='r',label='Early Stopping Checkpoint')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.title('Training Accuracy Over Epochs')
    plt.legend()

    plt.tight_layout()
    plt.show()

    # Displaying confusion matrices using ConfusionMatrixDisplay
    def plot_confusion_matrix(cm, title='Confusion Matrix'):
        cmn = np.round(cm.astype('float') / cm.sum(axis=1)[:, np.newaxis], 3)
        disp = ConfusionMatrixDisplay(confusion_matrix=cmn, display_labels=primary_emotions)
        disp.plot(cmap=plt.cm.Blues, xticks_rotation='vertical')
        plt.title(title)
        plt.show()

    # Plot training confusion matrix
    plot_confusion_matrix(train_cm, title='Training Confusion Matrix (Normalized)')

    # Plot validation confusion matrix
    plot_confusion_matrix(valid_cm, title='Validation Confusion Matrix (Normalized)')

## <img src="https://img.icons8.com/cute-clipart/64/000000/horizontal-settings-mixer.png" style="height:50px;display:inline"> Hyper Parameters
---

In [None]:
# These are parameters not tuned by Optuna later

# Constants
vocab_dim = len(vocab)
num_classes = len(primary_emotion_to_sub_emotions.keys())
num_epochs = 60

# Scheduler
step_size = [10, 50]

if os.path.exists("./studies/LSTM_study.pkl"):
    lstm_study = joblib.load("./studies/LSTM_study.pkl")
    best_lstm_trial = lstm_study.best_trial

if os.path.exists("./studies/xLSTM_study.pkl"):
    xlstm_study = joblib.load("./studies/xLSTM_study.pkl")
    best_xlstm_trial = xlstm_study.best_trial

if os.path.exists("./studies/RWKV_study.pkl"):
    rwkv_study = joblib.load("./studies/RWKV_study.pkl")
    best_rwkv_trial = rwkv_study.best_trial

if os.path.exists("./studies/GRU_study.pkl"):
    gru_study = joblib.load("./studies/GRU_study.pkl")
    best_gru_trial = gru_study.best_trial

In [None]:
if os.path.exists('embedding_matrix.npz'):
    with np.load('embedding_matrix.npz') as data:
        embedding_matrix = torch.from_numpy(data['embeddings']).to(device)
        vocab_dim = data['vocab_dim']
        vocab = torch.load('vocab.pth')
        print("Loaded embedding_matrix.npz and vocab.pth")

print(embedding_matrix.shape, vocab_dim, len(vocab))

## <img src="https://img.icons8.com/?size=100&id=Y6kSC37ALOtM&format=png&color=000000" style="height:50px;display:inline"> Models
---

### LSTM Model

In [None]:
class GoEmotions_LSTM(nn.Module):
    def __init__(self,
                 # Vocab
                 vocab_dim,
                 # Embedding
                 embedding_dim,
                 embedding_weights,
                 # LSTM
                 lstm_hidden_dim,
                 lstm_num_layers,
                 lstm_dropout,
                 bi_directional,
                 # Dense
                 dense_hidden_dims,
                 dense_dropouts,
                 # Output
                 num_classes):
        super(GoEmotions_LSTM, self).__init__()

        # Vocab
        self.vocab_dim = vocab_dim
        # Embedding
        self.embedding_dim = embedding_dim
        self.embedding_weights = embedding_weights
        # LSTM
        self.lstm_hidden_dim = lstm_hidden_dim
        self.lstm_num_layers = lstm_num_layers
        self.lstm_dropout = lstm_dropout
        self.bi_directional = bi_directional
        # Dense
        self.dense_hidden_dims = dense_hidden_dims
        self.dense_dropouts = dense_dropouts
        self.dense_input_dim = 2 * self.lstm_hidden_dim if bi_directional else self.lstm_hidden_dim
        # Output
        self.num_classes = num_classes

        # Layer definitions
        self.embedding = nn.Embedding(self.vocab_dim, self.embedding_dim)
        self.embedding.weights = nn.Parameter(self.embedding_weights, requires_grad=False)

        self.embedding_norm = nn.LayerNorm(embedding_dim, eps=1e-12, elementwise_affine=True)

        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.lstm_hidden_dim,
            num_layers=self.lstm_num_layers,
            batch_first=True,
            dropout=self.lstm_dropout,
            bidirectional=self.bi_directional)

        in_features = self.dense_input_dim
        layers = []

        for l in range(len(self.dense_hidden_dims)):
            out_features = self.dense_hidden_dims[l]
            p = self.dense_dropouts[l]

            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.GELU())
            layers.append(nn.LayerNorm(out_features, eps=1e-12, elementwise_affine=True))
            layers.append(nn.Dropout(p))

            in_features = out_features

        layers.append(nn.Linear(out_features, self.num_classes))
        layers.append(nn.Dropout(self.dense_dropouts[-1]))
        self.dense_layer = nn.Sequential(*layers)

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.init_weights()

    def init_weights(self):
        # pick initialzation: https://pytorch.org/docs/stable/nn.init.html
        # examples
        # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        # nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu', a=math.sqrt(5))
        # nn.init.normal_(m.weight, 0, 0.005)
        # don't forget the bias term (m.bias)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_normal_(m.weight, gain=1.0)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)


    def forward(self, x):
        h0 = torch.zeros((self.lstm_num_layers * 2, x.size(0), self.lstm_hidden_dim), device=self.device)
        c0 = torch.zeros((self.lstm_num_layers * 2, x.size(0), self.lstm_hidden_dim), device=self.device)

        torch.nn.init.xavier_normal_(h0)
        torch.nn.init.xavier_normal_(c0)

        x = self.embedding(x)
        x_embedding_shape = x.shape
        x = self.embedding_norm(x)

        # From: https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
        #
        # For bidirectional LSTMs, h_n is not equivalent to the last element of output; the former
        # contains the final forward and reverse hidden states, while the latter contains the final
        # forward hidden state and the initial reverse hidden state.
        x, (h, c) = self.lstm(x, (h0, c0))
        if self.bi_directional:
            h = h.view(self.lstm_num_layers, 2, x_embedding_shape[0], self.lstm_hidden_dim)
            h = h[-1]
            h = h.transpose(0, 1).reshape(x_embedding_shape[0], 2 * self.lstm_hidden_dim)
        else:
            h = h[-1]

        out = self.dense_layer(h)
        return out

### xLSTM Model

W.I.P - Unable to test due to lack of compute power.
The following blocks are left untested.

In [None]:
class BlockDiagonal(nn.Module):
    def __init__(self, in_features, out_features, num_blocks, bias=True):
        super(BlockDiagonal, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_blocks = num_blocks

        assert out_features % num_blocks == 0

        block_out_features = out_features // num_blocks

        self.blocks = nn.ModuleList([
            nn.Linear(in_features, block_out_features, bias=bias)
            for _ in range(num_blocks)
        ])

    def forward(self, x):
        x = [block(x) for block in self.blocks]
        x = torch.cat(x, dim=-1)
        return x

class CausalConv1D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, **kwargs):
        super(CausalConv1D, self).__init__()
        self.padding = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=self.padding, dilation=dilation, **kwargs)

    def forward(self, x):
        x = self.conv(x)
        return x[:, :, :-self.padding]

In [None]:
class mLSTMblock(nn.Module):
    def __init__(self, x_example, factor, depth, dropout=0.2):
        super().__init__()
        self.input_size = x_example.shape[2]
        self.hidden_size = int(self.input_size*factor)

        self.ln = nn.LayerNorm(self.input_size)

        self.left = nn.Linear(self.input_size, self.hidden_size)
        self.right = nn.Linear(self.input_size, self.hidden_size)

        self.conv = CausalConv1D(self.hidden_size, self.hidden_size, int(self.input_size/10))
        self.drop = nn.Dropout(dropout+0.1)

        self.lskip = nn.Linear(self.hidden_size, self.hidden_size)

        self.wq = BlockDiagonal(self.hidden_size, self.hidden_size, depth)
        self.wk = BlockDiagonal(self.hidden_size, self.hidden_size, depth)
        self.wv = BlockDiagonal(self.hidden_size, self.hidden_size, depth)
        self.dropq = nn.Dropout(dropout/2)
        self.dropk = nn.Dropout(dropout/2)
        self.dropv = nn.Dropout(dropout/2)

        self.i_gate = nn.Linear(self.hidden_size, self.hidden_size)
        self.f_gate = nn.Linear(self.hidden_size, self.hidden_size)
        self.o_gate = nn.Linear(self.hidden_size, self.hidden_size)

        self.ln_c = nn.LayerNorm(self.hidden_size)
        self.ln_n = nn.LayerNorm(self.hidden_size)

        self.lnf = nn.LayerNorm(self.hidden_size)
        self.lno = nn.LayerNorm(self.hidden_size)
        self.lni = nn.LayerNorm(self.hidden_size)

        self.GN = nn.LayerNorm(self.hidden_size)
        self.ln_out = nn.LayerNorm(self.hidden_size)

        self.drop2 = nn.Dropout(dropout)

        self.proj = nn.Linear(self.hidden_size, self.input_size)
        self.ln_proj = nn.LayerNorm(self.input_size)

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.init_weights()

    def init_weights(self):
        self.ct_1 = torch.zeros([1, 1, self.hidden_size], device=self.device)
        self.nt_1 = torch.zeros([1, 1, self.hidden_size], device=self.device)

    def forward(self, x):
        assert x.ndim == 3

        x = self.ln(x) # layer norm on x

        left = self.left(x) # part left
        right = F.silu(self.right(x)) # part right with just swish (silu) function

        left_left = left.transpose(1, 2)
        left_left = F.silu( self.drop( self.conv( left_left ).transpose(1, 2) ) )
        l_skip = self.lskip(left_left)

        # start mLSTM
        q = self.dropq(self.wq(left_left))
        k = self.dropk(self.wk(left_left))
        v = self.dropv(self.wv(left))

        i = torch.exp(self.lni(self.i_gate(left_left)))
        f = torch.exp(self.lnf(self.f_gate(left_left)))
        o = torch.sigmoid(self.lno(self.o_gate(left_left)))

        ct_1 = self.ct_1
        ct = f*ct_1 + i*v*k
        ct = torch.mean(self.ln_c(ct), [0, 1], keepdim=True)
        self.ct_1 = ct.detach()

        nt_1 = self.nt_1
        nt = f*nt_1 + i*k
        nt =torch.mean( self.ln_n(nt), [0, 1], keepdim=True)
        self.nt_1 = nt.detach()

        ht = o * ((ct*q) / torch.max(nt*q))
        # end mLSTM
        ht = ht

        left = self.drop2(self.GN(ht + l_skip))

        out = self.ln_out(left * right)
        out = self.ln_proj(self.proj(out))

        return out

In [None]:
class sLSTMblock(nn.Module):
    def __init__(self, x_example, depth, dropout=0.2):
        super().__init__()
        self.input_size = x_example.shape[2]
        # conv_channels = x_example.shape[1]

        self.ln = nn.LayerNorm(self.input_size)

        self.conv = CausalConv1D(self.input_size, self.input_size, int(self.input_size/8))
        self.drop = nn.Dropout(dropout)

        self.i_gate = BlockDiagonal(self.input_size, self.input_size, depth)
        self.f_gate = BlockDiagonal(self.input_size, self.input_size, depth)
        self.o_gate = BlockDiagonal(self.input_size, self.input_size, depth)
        self.z_gate = BlockDiagonal(self.input_size, self.input_size, depth)

        self.ri_gate = BlockDiagonal(self.input_size, self.input_size, depth, bias=False)
        self.rf_gate = BlockDiagonal(self.input_size, self.input_size, depth, bias=False)
        self.ro_gate = BlockDiagonal(self.input_size, self.input_size, depth, bias=False)
        self.rz_gate = BlockDiagonal(self.input_size, self.input_size, depth, bias=False)

        self.ln_i = nn.LayerNorm(self.input_size)
        self.ln_f = nn.LayerNorm(self.input_size)
        self.ln_o = nn.LayerNorm(self.input_size)
        self.ln_z = nn.LayerNorm(self.input_size)

        self.GN = nn.LayerNorm(self.input_size)
        self.ln_c = nn.LayerNorm(self.input_size)
        self.ln_n = nn.LayerNorm(self.input_size)
        self.ln_h = nn.LayerNorm(self.input_size)

        self.left_linear = nn.Linear(self.input_size, int(self.input_size*(4/3)))
        self.right_linear = nn.Linear(self.input_size, int(self.input_size*(4/3)))

        self.ln_out = nn.LayerNorm(int(self.input_size*(4/3)))

        self.proj = nn.Linear(int(self.input_size*(4/3)), self.input_size)

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.init_weights()

    def init_weights(self):
        self.nt_1 = torch.zeros(1, 1, self.input_size, device=self.device)
        self.ct_1 = torch.zeros(1, 1, self.input_size, device=self.device)
        self.ht_1 = torch.zeros(1, 1, self.input_size, device=self.device)
        self.mt_1 = torch.zeros(1, 1, self.input_size, device=self.device)

    def forward(self, x):
        x = self.ln(x)

        x_conv = F.silu( self.drop(self.conv( x.transpose(1, 2) ).transpose(1, 2) ) )

        # start sLSTM
        ht_1 = self.ht_1

        i = torch.exp(self.ln_i( self.i_gate(x_conv) + self.ri_gate(ht_1) ) )
        f = torch.exp( self.ln_f(self.f_gate(x_conv) + self.rf_gate(ht_1) ) )

        m = torch.max(torch.log(f)+self.mt_1[:, 0, :].unsqueeze(1), torch.log(i))
        i = torch.exp(torch.log(i) - m)
        f = torch.exp(torch.log(f) + self.mt_1[:, 0, :].unsqueeze(1)-m)
        self.mt_1 = m.detach()

        o = torch.sigmoid( self.ln_o(self.o_gate(x) + self.ro_gate(ht_1) ) )
        z = torch.tanh( self.ln_z(self.z_gate(x) + self.rz_gate(ht_1) ) )

        ct_1 = self.ct_1
        ct = f*ct_1 + i*z
        ct = torch.mean(self.ln_c(ct), [0, 1], keepdim=True)
        self.ct_1 = ct.detach()

        nt_1 = self.nt_1
        nt = f*nt_1 + i
        nt = torch.mean(self.ln_n(nt), [0, 1], keepdim=True)
        self.nt_1 = nt.detach()

        ht = o*(ct/nt) # torch.Size([4, 8, 16])
        ht = torch.mean(self.ln_h(ht), [0, 1], keepdim=True)
        self.ht_1 = ht.detach()
        # end sLSTM

        slstm_out = self.GN(ht)

        left = self.left_linear(slstm_out)
        right = F.gelu(self.right_linear(slstm_out))

        out = self.ln_out(left*right)
        out = self.proj(out)
        return out


In [None]:
class xLSTM(nn.Module):
    def __init__(self, layers, x_example, depth=4, factor=2):
        super(xLSTM, self).__init__()

        self.layers = nn.ModuleList()
        for layer_type in layers:
            if layer_type == 's':
                layer = sLSTMblock(x_example, depth)
            elif layer_type == 'm':
                layer = mLSTMblock(x_example, factor, depth)
            else:
                raise ValueError(f"Invalid layer type: {layer_type}. Choose 's' for sLSTM or 'm' for mLSTM.")
            self.layers.append(layer)

    def init_weights(self, x):
        [l.init_weights(x) for l in self.layers]

    def forward(self, x):
        x_original = x.clone()
        for l in self.layers:
             x = l(x) + x_original

        return x

In [None]:
class GoEmotions_xLSTM(nn.Module):
    def __init__(self,
                 # Vocab
                 vocab_dim,
                 # Embedding
                 embedding_dim,
                 embedding_weights,
                 # xLSTM
                 x_example,
                 config_block, # "msmmmmmm" - [7:1] 7 mlstm, 1 slstm.
                 hidden_size,
                 num_heads, # define number of block diagonal
                 dropout,
                #  # Dense
                #  dense_hidden_dims,
                #  dense_dropouts,
                 # Output
                 num_classes):
        super(GoEmotions_xLSTM, self).__init__()
        self.vocab_dim = vocab_dim
        # self.embedding_dim = embedding_dim
        self.embedding_dim = x_example.shape[2]
        self.embedding_weights = embedding_weights

        # self.block_size = block_size # input_size
        self.block_size = x_example.shape[1] # input_size
        self.config_block = config_block
        self.num_heads = num_heads   # define number of block diagonal
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.ln = nn.LayerNorm(self.embedding_dim)
        self.head = nn.Linear(self.embedding_dim, self.vocab_dim)

        self.embedding = nn.Embedding(self.vocab_dim, self.embedding_dim)
        self.embedding.weights = nn.Parameter(self.embedding_weights, requires_grad=False)
        self.position_embedding_table = nn.Embedding(self.block_size, self.embedding_dim)
        self.embedding_norm = nn.LayerNorm(embedding_dim, eps=1e-12, elementwise_affine=True)

        # self.xlstm = xLSTM(layers=self.config_block, input_size=block_size, depth=4)
        self.xlstm = xLSTM(layers=self.config_block, x_example=x_example, depth=4)
        self.num_classes = num_classes

        self.proj = nn.Linear(self.vocab_dim, num_classes)

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def init_weights(self, x):
        self.xlstm.init_states(x)

    def forward(self, x):
        B, T = x.shape

        # Ensure T does not exceed block_size
        if T > self.block_size:
            x = x[:, :self.block_size]  # Truncate input to block_size
            T = self.block_size
        elif T < self.block_size:
            padding = torch.zeros(B, self.block_size - T, dtype=x.dtype, device=x.device)
            x = torch.cat((x, padding), dim=1)
            T = self.block_size

        # x and targets are both (B,T) tensor of integers
        tok_emb = self.embedding(x) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=self.device)) # T, C
        x = tok_emb + pos_emb # (B, T, C)
        x = self.embedding_norm(x)
        x = self.xlstm(x)

        x = self.ln(x)

        out = self.head(x)
        out = self.proj(out[:, -1, :])
        # out = self.proj(x[-1])
        # print(f"out shape - {out.shape}")
        # B, T, C = out.shape
        # out = out.view(B, C)

        return out

### GRU Model

In [None]:
class GoEmotions_GRU(nn.Module):
    def __init__(self,
                 # Vocab
                 vocab_dim,
                 # Embedding
                 embedding_dim,
                 embedding_weights,
                 # GRU
                 gru_hidden_dim,
                 gru_num_layers,
                 gru_dropout,
                 bi_directional,
                 # Dense
                 dense_hidden_dims,
                 dense_dropouts,
                 # Output
                 num_classes):
        super(GoEmotions_GRU, self).__init__()

        # Vocab
        self.vocab_dim = vocab_dim
        # Embedding
        self.embedding_dim = embedding_dim
        self.embedding_weights = embedding_weights
        # GRU
        self.gru_hidden_dim = gru_hidden_dim
        self.gru_num_layers = gru_num_layers
        self.gru_dropout = gru_dropout
        self.bi_directional = bi_directional
        # Dense
        self.dense_hidden_dims = dense_hidden_dims
        self.dense_dropouts = dense_dropouts
        self.dense_input_dim = 2 * self.gru_hidden_dim if bi_directional else self.gru_hidden_dim
        # Output
        self.num_classes = num_classes

        # Layer definitions
        self.embedding = nn.Embedding(self.vocab_dim, self.embedding_dim)
        self.embedding.weights = nn.Parameter(self.embedding_weights, requires_grad=False)

        self.embedding_norm = nn.LayerNorm(embedding_dim, eps=1e-12, elementwise_affine=True)

        self.gru = nn.GRU(
            input_size=self.embedding_dim,
            hidden_size=self.gru_hidden_dim,
            num_layers=self.gru_num_layers,
            batch_first=True,
            dropout=self.gru_dropout,
            bidirectional=self.bi_directional)

        in_features = self.dense_input_dim
        layers = []

        for l in range(len(self.dense_hidden_dims)):
            out_features = self.dense_hidden_dims[l]
            p = self.dense_dropouts[l]

            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.GELU())
            layers.append(nn.LayerNorm(out_features, eps=1e-12, elementwise_affine=True))
            layers.append(nn.Dropout(p))

            in_features = out_features

        layers.append(nn.Linear(out_features, self.num_classes))
        layers.append(nn.Dropout(self.dense_dropouts[-1]))
        self.dense_layer = nn.Sequential(*layers)

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.init_weights()

    def init_weights(self):
        # pick initialzation: https://pytorch.org/docs/stable/nn.init.html
        # examples
        # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        # nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu', a=math.sqrt(5))
        # nn.init.normal_(m.weight, 0, 0.005)
        # don't forget the bias term (m.bias)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_normal_(m.weight, gain=1.0)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)


    def forward(self, x):
        h0 = torch.zeros((self.gru_num_layers * 2, x.size(0), self.gru_hidden_dim), device=self.device)

        torch.nn.init.xavier_normal_(h0)

        x = self.embedding(x)
        x_embedding_shape = x.shape
        x = self.embedding_norm(x)

        # While it is not documented in Pytorch, it is fundumentally different taking the output (x) vs
        # the final hidden state (h) when using Bidirectional GRU.
        x, h = self.gru(x, h0)
        if self.bi_directional:
            h = h.view(self.gru_num_layers, 2, x_embedding_shape[0], self.gru_hidden_dim)
            h = h[-1]
            h = h.transpose(0, 1).reshape(x_embedding_shape[0], 2 * self.gru_hidden_dim)
        else:
            h = h[-1]

        out = self.dense_layer(h)
        return out

### RWKV Model

In [None]:
class ChannelMix(nn.Module):
    def __init__(self, layer_id, n_layer, n_embed):
        super().__init__()
        self.layer_id = layer_id

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():
            ratio_1_to_almost0 = 1.0 - layer_id/n_layer
            x = torch.ones(1,1, n_embed)
            for i in range(n_embed):
                x[0, 0, i] = i/n_embed

            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))

        hidden_size = 4*n_embed
        self.key = nn.Linear(n_embed, hidden_size, bias=False)
        self.receptance = nn.Linear(n_embed, n_embed, bias=False)

        self.value = nn.Linear(hidden_size, n_embed, bias=False)

    def forward(self, x):
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + (1-self.time_mix_k) * xx
        xr = x * self.time_mix_r + (1-self.time_mix_r) * xx

        k = self.key(xk)
        k = torch.square(torch.relu(k))

        kv = self.value(k)
        rkv = torch.sigmoid(self.receptance(xr)) * kv
        return rkv

In [None]:
class TimeMix(nn.Module):
    def __init__(self, layer_id, n_layer, n_embed, aa_bb_pp_shape_1=1):
        super().__init__()
        self.layer_id = layer_id

        attn_sz = n_embed
        with torch.no_grad():
            ratio_1_to_almost0 = 1.0 - layer_id/n_layer
            ratio_0_to_1 = layer_id / (n_layer - 1)

            decay_speed = torch.ones(attn_sz)
            for h in range(attn_sz):
                decay_speed[h] = -5 + 8 * (h / (attn_sz-1)) ** (0.7 + 1.3 * ratio_0_to_1)

            self.time_decay = nn.Parameter(decay_speed)

            zigzag = (torch.tensor([(i+1)%3 - 1 for i in range(attn_sz)]) * 0.5)
            self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)

            x = torch.ones(1,1, n_embed)
            for i in range(n_embed):
                x[0, 0, i] = i/n_embed

            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) +0.3 * ratio_0_to_1)
            self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))

            self.aa = nn.Parameter(torch.ones(1,aa_bb_pp_shape_1,attn_sz))
            self.bb = nn.Parameter(torch.ones(1,aa_bb_pp_shape_1,attn_sz))
            pp = torch.ones(1,aa_bb_pp_shape_1,attn_sz)
            pp = pp * -1e30
            self.pp = nn.Parameter(pp)
            self.xx = nn.Parameter(torch.ones(1,1,attn_sz))

        hidden_size = attn_sz
        self.key = nn.Linear(n_embed, attn_sz, bias=False)
        self.receptance = nn.Linear(n_embed, attn_sz, bias=False)

        self.value = nn.Linear(hidden_size, attn_sz, bias=False)
        self.output = nn.Linear(attn_sz, n_embed, bias=False)

    def forward(self, x):
        b, t, c = x.shape
        xx = self.xx

        xk = x * self.time_mix_k + (1-self.time_mix_k) * xx
        xv = x * self.time_mix_v + (1-self.time_mix_v) * xx
        xr = x * self.time_mix_r + (1-self.time_mix_r) * xx

        min_size = min(b, self.aa.shape[0])

        k = self.key(xk)[:min_size, :, :]

        v = self.value(xv)[:min_size, :, :]
        r = self.receptance(xr)[:min_size, :, :]

        r = torch.sigmoid(r)

        aa = self.aa[:min_size, :, :]
        bb = self.bb[:min_size, :, :]
        pp = self.pp[:min_size, :, :]

        ww = self.time_first + k

        qq = torch.maximum(pp, ww )
        e1 = torch.exp(pp - qq)
        e2 = torch.exp(ww - qq)

        a = e1 * aa + e2 * v
        b = e1 * bb + e2
        wkv = a / b
        ww = pp + self.time_decay

        qq = torch.maximum(ww, k)
        e1 = torch.exp(ww - qq)
        e2 = torch.exp(k - qq)

        with torch.no_grad():
            xx = nn.Parameter(x)
            self.aa = nn.Parameter(e1 * aa + e2 * v)
            self.bb = nn.Parameter(e1 * bb + e2)
            self.pp = nn.Parameter(qq)

        return self.output(r * wkv)

In [None]:
class Block(nn.Module):
    def __init__(self, layer_id, rwkv_num_layers, embedding_dim, aa_bb_pp_shape_1=1):
        super().__init__()
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(embedding_dim)
        self.ln2 = nn.LayerNorm(embedding_dim)

        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(embedding_dim)

        if self.layer_id == 0 :
            self.ffnPre = ChannelMix(0, rwkv_num_layers, embedding_dim)
        else:
            self.att = TimeMix(layer_id, rwkv_num_layers, embedding_dim, aa_bb_pp_shape_1)

        self.ffn = ChannelMix(layer_id, rwkv_num_layers, embedding_dim)

    def forward(self, x):
        if self.layer_id == 0:
            x = self.ln0(x)
        if self.layer_id == 0 :
            x = x + self.ffnPre(self.ln1(x))  # better in some cases
        else:
            x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

In [None]:
class GoEmotions_RWKV(nn.Module):
    def __init__(self,
                 # Vocab
                 vocab_dim,
                 # Embedding
                 embedding_dim,
                 embedding_weights,
                 # RWKV
                 rwkv_num_layers,
                 ctx_len,
                 # Output
                 num_classes,
                 # Default Values
                 aa_bb_pp_shape_1=1
                 ):
        super().__init__()

        # Vocab
        self.vocab_dim = vocab_dim
        # Embedding
        self.embedding_dim = embedding_dim
        self.embedding_weights = embedding_weights
        # RWKV
        self.step = 0
        self.rwkv_num_layers = rwkv_num_layers
        self.ctx_len = ctx_len # block size - what is the maximum context length for predictions?
        # Output
        self.num_classes = num_classes

        self.embedding = nn.Embedding(self.vocab_dim, self.embedding_dim)
        self.embedding.weights = nn.Parameter(self.embedding_weights, requires_grad=False)

        self.blocks = nn.Sequential(*[Block(i, self.rwkv_num_layers, self.embedding_dim, aa_bb_pp_shape_1)
                                    for i in range(self.rwkv_num_layers)])

        self.ln_out = nn.LayerNorm(self.embedding_dim)
        self.head = nn.Linear(self.embedding_dim, self.vocab_dim, bias=False)
        self.proj = nn.Linear(self.vocab_dim, self.num_classes, bias=False)

    def forward(self, idx, targets=None):
            idx = idx.to(self.embedding.weight.device)

            self.step += 1

            B, T = idx.size()
            assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."

            x = self.embedding(idx)
            x = self.blocks(x)
            x = self.ln_out(x)

            x = self.head(x)
            x = self.proj(x)

            loss = None
            if targets is not None:
                loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.to(x.device).view(-1))
            x = torch.mean(x, dim=1, keepdim=True).squeeze()
            return x, loss

    # def generate(self, idx, max_new_tokes):
    #     for _ in range(max_new_tokes):
    #         idx_cond = idx[:, -block_size:]
    #         logits, loss = self(idx_cond)
    #         logits = logits[:, -1, :]
    #         probs = F.softmax(logits, dim = -1)
    #         idx_next = torch.multinomial(probs, num_samples = 1)
    #         idx = torch.cat((idx, idx_next), dim = 1)
    #     return idx

## <img src="https://img.icons8.com/color/96/000000/pie-chart--v1.png" style="height:50px;display:inline"> Optuna
---

### Framework

In [None]:
log_interval = 10
n_train_examples = batch_size * 100
n_valid_examples = batch_size * 300


def define_model(trial, model_type):
    embedding_dim = trial.suggest_int("embedding_dim", 64, 256)
    rnn_num_layers = trial.suggest_int("rnn_num_layers", 2, 5)

    if model_type == 'LSTM':
        rnn_hidden_dim = trial.suggest_int("rnn_hidden_dim", 64, 256)
        rnn_dropout = trial.suggest_float("rnn_dropout", 0.1, 0.5)
        dense_num_layers = trial.suggest_int("n_layers", 1, 4)
        dense_hidden_dims = [trial.suggest_int(f"n_units_l{l}", 64, 256) for l in range(dense_num_layers)]
        dense_dropouts = [trial.suggest_float(f"dropout_l{l}", 0.1, 0.5) for l in range(dense_num_layers + 1)]

        model = GoEmotions_LSTM(
            vocab_dim=vocab_dim,
            embedding_dim=embedding_dim,
            embedding_weights=embedding_matrix,
            lstm_hidden_dim=rnn_hidden_dim,
            lstm_num_layers=rnn_num_layers,
            lstm_dropout=rnn_dropout,
            bi_directional=True,
            dense_hidden_dims=dense_hidden_dims,
            dense_dropouts=dense_dropouts,
            num_classes=num_classes
        ).to(device)
    elif model_type == 'RWKV':
        ctx_len = trial.suggest_int("ctx_len", max_seq_len, 2 * max_seq_len)

        model = GoEmotions_RWKV(
            vocab_dim=vocab_dim,
            embedding_dim=embedding_dim,
            embedding_weights=embedding_matrix,
            rwkv_num_layers=rnn_num_layers,
            ctx_len=ctx_len,
            num_classes=num_classes
        ).to(device)
    elif model_type == 'GRU':
        rnn_hidden_dim = trial.suggest_int("rnn_hidden_dim", 64, 256)
        rnn_dropout = trial.suggest_float("rnn_dropout", 0.1, 0.5)
        dense_num_layers = trial.suggest_int("n_layers", 1, 4)
        dense_hidden_dims = [trial.suggest_int(f"n_units_l{l}", 64, 256) for l in range(dense_num_layers)]
        dense_dropouts = [trial.suggest_float(f"dropout_l{l}", 0.1, 0.5) for l in range(dense_num_layers + 1)]

        model = GoEmotions_GRU(
            vocab_dim=vocab_dim,
            embedding_dim=embedding_dim,
            embedding_weights=embedding_matrix,
            gru_hidden_dim=rnn_hidden_dim,
            gru_num_layers=rnn_num_layers,
            gru_dropout=rnn_dropout,
            bi_directional=True,
            dense_hidden_dims=dense_hidden_dims,
            dense_dropouts=dense_dropouts,
            num_classes=num_classes
        ).to(device)

    return model


def objective(trial, model_type):
    model = define_model(trial, model_type)

    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop"])
    optimizer = getattr(torch.optim, optimizer_name)(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(1, num_epochs + 1):
        model.train()

        for batch_idx, (labels, text) in enumerate(train_dataloader):
            if batch_idx * batch_size >= n_train_examples:
                break

            outputs = model(text)[0] if isinstance(model, GoEmotions_RWKV) else model(text)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            clip_gradient(model, 1e-1)
            optimizer.step()

        model.eval()
        total_correct = 0
        total_labels = 0
        with torch.no_grad():
            for batch_idx, (labels, text) in enumerate(valid_dataloader):
                if batch_idx * batch_size >= n_valid_examples:
                    break

                outputs = model(text)[0] if isinstance(model, GoEmotions_RWKV) else model(text)
                _, predicted = torch.max(outputs.data, 1)
                total_labels += labels.size(0)
                total_correct += (predicted == labels).sum().item()

        model_accuracy = total_correct / total_labels * 100
        trial.report(model_accuracy, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return model_accuracy

In [None]:
def plot_trials(study, study_name, acc_th):
    trials = study.trials
    fig, ax = plt.subplots()

    for trial in trials:
        if trial.state == optuna.trial.TrialState.COMPLETE:
            epochs = list(trial.intermediate_values.keys())
            accuracies = list(trial.intermediate_values.values())

            if accuracies[-1] >= acc_th:
                ax.plot(epochs, accuracies, label=f'Trial {trial.number}')

    ax.axhline(y=acc_th, linestyle='--', label=f'{acc_th}% Accuracy')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Accuracy')
    ax.set_title(f'Accuracy vs Epochs for all Trials - Study {study_name} (Accuracy Threshold: {acc_th}%)')
    ax.legend()
    plt.show()

### LSTM Study

In [None]:
sampler = optuna.samplers.TPESampler()
study = optuna.create_study(study_name="goemotions-lstm", direction="maximize", sampler=sampler)
study.optimize(lambda study: objective(study, 'LSTM'), n_trials=100, timeout=60 * 90)

pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]

print("Study statistics: ")
print("  Number of finished trials: ", len(study.trials))
print("  Number of pruned trials: ", len(pruned_trials))
print("  Number of complete trials: ", len(complete_trials))

print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)

print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))

lstm_study = study

In [None]:
plot_trials(lstm_study, "LSTM", 50)

In [None]:
optuna.visualization.plot_param_importances(lstm_study)

In [None]:
param_importances = optuna.importance.get_param_importances(lstm_study)
top_params = [param for param, _ in sorted(param_importances.items(), key=lambda item: item[1], reverse=True)[:3]]
fig = optuna.visualization.plot_contour(lstm_study, params=top_params[:3])
fig.show()

In [None]:
best_lstm_trial = trial
if not os.path.isdir("studies"):
    os.mkdir("studies")
joblib.dump(study, "./studies/LSTM_study.pkl")

### GRU Study

In [None]:
sampler = optuna.samplers.TPESampler()
study = optuna.create_study(study_name="goemotions-gru", direction="maximize", sampler=sampler)
study.optimize(lambda study: objective(study, 'GRU'), n_trials=100, timeout=60 * 90)

pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]

print("Study statistics: ")
print("  Number of finished trials: ", len(study.trials))
print("  Number of pruned trials: ", len(pruned_trials))
print("  Number of complete trials: ", len(complete_trials))

print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)

print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))

gru_study = study

In [None]:
plot_trials(gru_study, "GRU", 50)

In [None]:
optuna.visualization.plot_param_importances(gru_study)

In [None]:
param_importances = optuna.importance.get_param_importances(gru_study)
top_params = [param for param, _ in sorted(param_importances.items(), key=lambda item: item[1], reverse=True)[:3]]
fig = optuna.visualization.plot_contour(gru_study, params=top_params[:3])
fig.show()

In [None]:
best_gru_trial = trial
if not os.path.isdir("studies"):
    os.mkdir("studies")
joblib.dump(study, "./studies/GRU_study.pkl")

### RWKV Study

In [None]:
sampler = optuna.samplers.TPESampler()
study = optuna.create_study(study_name="goemotions-rwkv", direction="maximize", sampler=sampler)
study.optimize(lambda study: objective(study, 'RWKV'), n_trials=100, timeout=60 * 90)

pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]

print("Study statistics: ")
print("  Number of finished trials: ", len(study.trials))
print("  Number of pruned trials: ", len(pruned_trials))
print("  Number of complete trials: ", len(complete_trials))

print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)

print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))

rwkv_study = study

In [None]:
plot_trials(rwkv_study, "RWKV", 50)

In [None]:
optuna.visualization.plot_param_importances(rwkv_study)

In [None]:
param_importances = optuna.importance.get_param_importances(rwkv_study)
top_params = [param for param, _ in sorted(param_importances.items(), key=lambda item: item[1], reverse=True)[:3]]
fig = optuna.visualization.plot_contour(rwkv_study, params=top_params[:3])
fig.show()

In [None]:
best_rwkv_trial = trial
if not os.path.isdir("studies"):
    os.mkdir("studies")
joblib.dump(study, "./studies/RWKV_study.pkl")

## <img src="https://img.icons8.com/?size=100&id=s8cTlBs8lfX0&format=png&color=000000" style="height:50px;display:inline"> Training
---

### LSTM Train

In [None]:
seed = 211
np.random.seed(seed)
torch.manual_seed(seed)

model_blstm = GoEmotions_LSTM(
    vocab_dim=vocab_dim,
    embedding_dim=best_lstm_trial.params['embedding_dim'],
    embedding_weights=embedding_matrix,
    lstm_hidden_dim=best_lstm_trial.params['rnn_hidden_dim'],
    lstm_num_layers=best_lstm_trial.params['rnn_num_layers'],
    lstm_dropout=best_lstm_trial.params['rnn_dropout'],
    bi_directional=True,
    dense_hidden_dims=[dim for key, dim in best_lstm_trial.params.items() if key.startswith('n_units_l')],
    dense_dropouts=[p for key, p in best_lstm_trial.params.items() if key.startswith('dropout_l')],
    num_classes=num_classes
).to(device)
print(model_blstm)

count_layers_and_parameters(model_blstm)
calc_model_size(model_blstm)

# y = model_blstm(next(iter(train_dataloader))[1])
# graph = make_dot(y.mean(), params=dict(model_blstm.named_parameters()))
# graph.render('./images/lstm.svg', format='svg')

state = {
    'net': model_blstm.state_dict(),
    'epoch': 0,
}
if not os.path.isdir('checkpoints'):
    os.mkdir('checkpoints')
torch.save(state, f'./checkpoints/{type(model_blstm).__name__}_initial_ckpt.pth')

In [None]:
train_losses_lstm, train_accuracies_lstm, validation_accuracies_lstm, train_cm_lstm, valid_cm_lstm = train_model(model_blstm, train_dataloader, valid_dataloader, best_lstm_trial.params['lr'], step_size, best_lstm_trial.params['optimizer'], num_epochs)

In [None]:
plot_statistics(train_losses_lstm, train_accuracies_lstm, validation_accuracies_lstm, train_cm_lstm, valid_cm_lstm)

In [None]:
np.savez('lstm_training_stats.npz',
         train_losses=train_losses_lstm,
         train_accuracies=train_accuracies_lstm,
         validation_accuracies=validation_accuracies_lstm,
         train_cm=train_cm_lstm,
         valid_cm=valid_cm_lstm)

### GRU Train

In [None]:
seed = 211
np.random.seed(seed)
torch.manual_seed(seed)

model_gru = GoEmotions_GRU(
    vocab_dim=vocab_dim,
    embedding_dim=best_gru_trial.params['embedding_dim'],
    embedding_weights=embedding_matrix,
    gru_hidden_dim=best_gru_trial.params['rnn_hidden_dim'],
    gru_num_layers=best_gru_trial.params['rnn_num_layers'],
    gru_dropout=best_gru_trial.params['rnn_dropout'],
    bi_directional=True,
    dense_hidden_dims=[dim for key, dim in best_gru_trial.params.items() if key.startswith('n_units_l')],
    dense_dropouts=[p for key, p in best_gru_trial.params.items() if key.startswith('dropout_l')],
    num_classes=num_classes
).to(device)
print(model_gru)

count_layers_and_parameters(model_gru)
calc_model_size(model_gru)

# y = model_blstm(next(iter(train_dataloader))[1])
# graph = make_dot(y.mean(), params=dict(model_blstm.named_parameters()))
# graph.render('./images/gru.svg', format='svg')

state = {
    'net': model_gru.state_dict(),
    'epoch': 0,
}
if not os.path.isdir('checkpoints'):
    os.mkdir('checkpoints')
torch.save(state, f'./checkpoints/{type(model_gru).__name__}_initial_ckpt.pth')

In [None]:
train_losses_gru, train_accuracies_gru, validation_accuracies_gru, train_cm_gru, valid_cm_gru = train_model(model_gru, train_dataloader, valid_dataloader, best_gru_trial.params['lr'], step_size, best_gru_trial.params['optimizer'], num_epochs)

In [None]:
plot_statistics(train_losses_gru, train_accuracies_gru, validation_accuracies_gru, train_cm_gru, valid_cm_gru)

In [None]:
np.savez('gru_training_stats.npz',
         train_losses=train_losses_gru,
         train_accuracies=train_accuracies_gru,
         validation_accuracies=validation_accuracies_gru,
         train_cm=train_cm_gru,
         valid_cm=valid_cm_gru)

### RWkV Train

In [None]:
seed = 211
np.random.seed(seed)
torch.manual_seed(seed)

model_rwkv = GoEmotions_RWKV(
    vocab_dim=vocab_dim,
    embedding_dim=best_rwkv_trial.params['embedding_dim'],
    embedding_weights=embedding_matrix,
    rwkv_num_layers=best_rwkv_trial.params['rnn_num_layers'],
    ctx_len=best_rwkv_trial.params['ctx_len'],
    num_classes=num_classes
).to(device)
print(model_rwkv)

count_layers_and_parameters(model_rwkv)
calc_model_size(model_rwkv)

state = {
    'net': model_rwkv.state_dict(),
    'epoch': 0,
}
if not os.path.isdir('checkpoints'):
    os.mkdir('checkpoints')
torch.save(state, f'./checkpoints/{type(model_rwkv).__name__}_initial_ckpt.pth')

In [None]:
train_losses_rwkv, train_accuracies_rwkv, validation_accuracies_rwkv, train_cm_rwkv, valid_cm_rwkv = train_model(model_rwkv, train_dataloader, valid_dataloader, best_rwkv_trial.params['lr'], step_size, best_rwkv_trial.params['optimizer'], num_epochs)

In [None]:
np.savez('rwkv_training_stats.npz',
         train_losses=train_losses_rwkv,
         train_accuracies=train_accuracies_rwkv,
         validation_accuracies=validation_accuracies_rwkv,
         train_cm=train_cm_rwkv,
         valid_cm=valid_cm_rwkv)

## <img src="https://img.icons8.com/?size=100&id=HkG28tSEJLgP&format=png&color=000000" style="height:50px;display:inline"> Comparison
---

In [None]:
if os.path.exists('lstm_training_stats.npz'):
    lstm_stats = np.load('lstm_training_stats.npz')
    train_losses_lstm = lstm_stats['train_losses']
    train_accuracies_lstm = lstm_stats['train_accuracies']
    validation_accuracies_lstm = lstm_stats['validation_accuracies']
    train_cm_lstm = lstm_stats['train_cm']
    valid_cm_lstm = lstm_stats['valid_cm']
    print('Loaded LSTM training stats')

if os.path.exists('gru_training_stats.npz'):
    gru_stats = np.load('gru_training_stats.npz')
    train_losses_gru = gru_stats['train_losses']
    train_accuracies_gru = gru_stats['train_accuracies']
    validation_accuracies_gru = gru_stats['validation_accuracies']
    train_cm_gru = gru_stats['train_cm']
    valid_cm_gru = gru_stats['valid_cm']
    print('Loaded GRU training stats')

if os.path.exists('rwkv_training_stats.npz'):
    rwkv_stats = np.load('rwkv_training_stats.npz')
    train_losses_rwkv = rwkv_stats['train_losses']
    train_accuracies_rwkv = rwkv_stats['train_accuracies']
    validation_accuracies_rwkv = rwkv_stats['validation_accuracies']
    train_cm_rwkv = rwkv_stats['train_cm']
    valid_cm_rwkv = rwkv_stats['valid_cm']
    print('Loaded RWKV training stats')

In [None]:
epochs_lstm = range(1, len(train_losses_lstm) + 1)
epochs_gru = range(1, len(train_losses_gru) + 1)
epochs_rwkv = range(1, len(train_losses_rwkv) + 1)

plt.figure(figsize=(24, 5))

# Plot losses
plt.subplot(1, 3, 1)
plt.plot(epochs_lstm[1:], train_losses_lstm[1:], label='LSTM')
plt.plot(epochs_gru[1:], train_losses_gru[1:], label='GRU')
plt.plot(epochs_rwkv[1:], train_losses_rwkv[1:], label='RWKV')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training Loss Over Epochs')
plt.legend()

# Plot accuracies
plt.subplot(1, 3, 2)
plt.plot(epochs_lstm, train_accuracies_lstm, label='LSTM Train')
plt.plot(epochs_gru, train_accuracies_gru, label='GRU Train')
plt.plot(epochs_rwkv, train_accuracies_rwkv, label='RWKV Train')
plt.plot(epochs_lstm, validation_accuracies_lstm, label='LSTM Validation')
plt.plot(epochs_gru, validation_accuracies_gru, label='GRU Validation')
plt.plot(epochs_rwkv, validation_accuracies_rwkv, label='RWKV Validation')
max_valid_acc_indx = list(validation_accuracies_lstm).index(max(validation_accuracies_lstm)) + 1
plt.axvline(max_valid_acc_indx, linestyle='--', color='g',label='LSTM Early Stopping Checkpoint')
max_valid_acc_indx = list(validation_accuracies_gru).index(max(validation_accuracies_gru)) + 1
plt.axvline(max_valid_acc_indx, linestyle='--', color='r',label='GRU Early Stopping Checkpoint')
max_valid_acc_indx = list(validation_accuracies_rwkv).index(max(validation_accuracies_rwkv)) + 1
plt.axvline(max_valid_acc_indx, linestyle='--', color='b',label='RWKV Early Stopping Checkpoint')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title('Accuracies Over Epochs')
plt.legend()

plt.tight_layout()
plt.show()

In [None]:
def calculate_metrics(model, dataloader, primary_emotions):
    model.eval()
    total_correct = 0
    total_images = 0
    confusion_matrix = np.zeros([len(primary_emotions), len(primary_emotions)], int)
    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for _, (labels, text) in enumerate(dataloader):
            outputs = model(text)[0] if isinstance(model, GoEmotions_RWKV) else model(text)
            _, predicted = torch.max(outputs.data, 1)
            total_images += labels.size(0)
            total_correct += (predicted == labels).sum().item()
            for i, l in enumerate(labels):
                confusion_matrix[l.item(), predicted[i].item()] += 1
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    model_accuracy = total_correct / total_images * 100

    # Calculate precision, recall, and f1 scores for each class
    precision = precision_score(all_labels, all_predictions, average=None, labels=list(range(len(primary_emotions))))
    recall = recall_score(all_labels, all_predictions, average=None, labels=list(range(len(primary_emotions))))
    f1 = f1_score(all_labels, all_predictions, average=None, labels=list(range(len(primary_emotions))))

    # Calculate macro-average and standard deviation
    macro_precision = np.mean(precision)
    macro_recall = np.mean(recall)
    macro_f1 = np.mean(f1)

    std_precision = np.std(precision)
    std_recall = np.std(recall)
    std_f1 = np.std(f1)

    # Create a pandas DataFrame
    data = {
        'Ekman Emotion': primary_emotions + ['macro-average', 'std'],
        'Precision': np.append(precision, [macro_precision, std_precision]),
        'Recall': np.append(recall, [macro_recall, std_recall]),
        'F1': np.append(f1, [macro_f1, std_f1])
    }

    df = pd.DataFrame(data)

    return model_accuracy, confusion_matrix, df

In [None]:
model_blstm_best = GoEmotions_LSTM(
    vocab_dim=vocab_dim,
    embedding_dim=best_lstm_trial.params['embedding_dim'],
    embedding_weights=embedding_matrix,
    lstm_hidden_dim=best_lstm_trial.params['rnn_hidden_dim'],
    lstm_num_layers=best_lstm_trial.params['rnn_num_layers'],
    lstm_dropout=best_lstm_trial.params['rnn_dropout'],
    bi_directional=True,
    dense_hidden_dims=[dim for key, dim in best_lstm_trial.params.items() if key.startswith('n_units_l')],
    dense_dropouts=[p for key, p in best_lstm_trial.params.items() if key.startswith('dropout_l')],
    num_classes=num_classes
).to(device)

state = torch.load(f'./checkpoints/{type(model_blstm_best).__name__}_best_ckpt.pth', map_location=device)
model_blstm_best.load_state_dict(state)
model_accuracy_blstm, confusion_matrix_blstm, metrics_df_blstm = calculate_metrics(model_blstm_best, test_dataloader, primary_emotions)
print(f'{model_accuracy_blstm=}')
print(metrics_df_blstm)

In [None]:
model_gru_best = GoEmotions_GRU(
    vocab_dim=vocab_dim,
    embedding_dim=best_gru_trial.params['embedding_dim'],
    embedding_weights=embedding_matrix,
    gru_hidden_dim=best_gru_trial.params['rnn_hidden_dim'],
    gru_num_layers=best_gru_trial.params['rnn_num_layers'],
    gru_dropout=best_gru_trial.params['rnn_dropout'],
    bi_directional=True,
    dense_hidden_dims=[dim for key, dim in best_gru_trial.params.items() if key.startswith('n_units_l')],
    dense_dropouts=[p for key, p in best_gru_trial.params.items() if key.startswith('dropout_l')],
    num_classes=num_classes
).to(device)

state = torch.load(f'./checkpoints/{type(model_gru_best).__name__}_best_ckpt.pth', map_location=device)
model_gru_best.load_state_dict(state)
model_accuracy_gru, confusion_matrix_gru, metrics_df_gru = calculate_metrics(model_gru_best, test_dataloader, primary_emotions)
print(f'{model_accuracy_gru=}')
print(metrics_df_gru)

In [None]:
state = torch.load(f'./checkpoints/{GoEmotions_RWKV.__name__}_best_ckpt.pth', map_location=device)

model_rwkv_best = GoEmotions_RWKV(
    vocab_dim=vocab_dim,
    embedding_dim=best_rwkv_trial.params['embedding_dim'],
    embedding_weights=embedding_matrix,
    rwkv_num_layers=best_rwkv_trial.params['rnn_num_layers'],
    ctx_len=best_rwkv_trial.params['ctx_len'],
    num_classes=num_classes,
    aa_bb_pp_shape_1=state['blocks.1.att.aa'].shape[1]
).to(device)

model_rwkv_best.load_state_dict(state)
model_accuracy_rwkv, confusion_matrix_rwkv, metrics_df_rwkv = calculate_metrics(model_rwkv_best, test_dataloader, primary_emotions)
print(f'{model_accuracy_rwkv=}')
print(metrics_df_rwkv)

In [None]:
# Create subplots
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))

# Plot confusion matrix for BLSTM
cmn_lstm = np.round(confusion_matrix_blstm.astype('float') / confusion_matrix_blstm.sum(axis=1)[:, np.newaxis], 3)
disp_blstm = ConfusionMatrixDisplay(confusion_matrix=cmn_lstm, display_labels=primary_emotions)
disp_blstm.plot(ax=ax1, cmap=plt.cm.Blues, xticks_rotation='vertical')
ax1.set_title(f'Confusion Matrix - BLSTM - Model Accuracy: {model_accuracy_blstm:.2f}%')

# Plot confusion matrix for GRU
cmn_gru = np.round(confusion_matrix_gru.astype('float') / confusion_matrix_gru.sum(axis=1)[:, np.newaxis], 3)
disp_gru = ConfusionMatrixDisplay(confusion_matrix=cmn_gru, display_labels=primary_emotions)
disp_gru.plot(ax=ax2, cmap=plt.cm.Blues, xticks_rotation='vertical')
ax2.set_title(f'Confusion Matrix - GRU - Model Accuracy: {model_accuracy_gru:.2f}%')

# Plot confusion matrix for RWKV
cmn_rwkv = np.round(confusion_matrix_rwkv.astype('float') / confusion_matrix_rwkv.sum(axis=1)[:, np.newaxis], 3)
disp_rwkv = ConfusionMatrixDisplay(confusion_matrix=cmn_rwkv, display_labels=primary_emotions)
disp_rwkv.plot(ax=ax3, cmap=plt.cm.Blues, xticks_rotation='vertical')
ax3.set_title(f'Confusion Matrix - RWKV - Model Accuracy: {model_accuracy_rwkv:.2f}%')

fig.suptitle('Comparison of Confusion Matrices on Test Set', fontsize=16)

# Adjust layout
plt.tight_layout()
plt.show()