## Ensuring Reproducibility

In [1]:
import torch
import random
import numpy as np

# One seed value shared by every random-number source this notebook touches.
SEED = 42

# Seed Python's `random`, NumPy's legacy global RNG, PyTorch's default
# generator, and every visible CUDA device.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(SEED)
del _seed_rng

# cuDNN: require deterministic kernels and disable the autotuner, whose
# algorithm choice can differ from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## Loading the Pre-trained Model

In [2]:
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

In [3]:
# Hugging Face checkpoint id shared by the tokenizer and the backbone model.
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)  # text -> input_ids / attention_mask
transformer_model = AutoModel.from_pretrained(model_name)  # backbone without a task head

In [4]:
# Freeze every backbone weight so that only modules stacked on top remain trainable.
# Note: this only turns off gradients; it does not put the backbone in eval mode.
# The trailing `;` keeps Jupyter from echoing the (large) module repr.
transformer_model.requires_grad_(False);

In [5]:
# Pick the best available accelerator: CUDA first, then Apple-silicon MPS, else CPU.
# Kept as a plain string because later cells pass it straight to `.to(device)`.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: mps


## Designing the Classification Heads

In [61]:
class ClassificationHead(nn.Module):
    """Two-layer MLP classification head on top of a transformer backbone.

    The first-token vector of the backbone's first output goes through
    dropout, then Linear -> activation -> Linear, producing raw logits.

    Args:
        transformer_model: backbone module. It must expose ``config.hidden_size``
            and return an indexable output whose element 0 holds per-token hidden
            states (assumed (batch, seq, hidden) -- TODO confirm for other backbones).
        num_classes: number of output logits.
        hidden_size: width of the intermediate layer.
        activation: zero-argument factory for the nonlinearity between the two
            linear layers (default ``nn.ReLU``). Pass ``None`` to reproduce the
            original purely linear stack.
        target_device: where to place the modules. ``None`` falls back to the
            notebook-level ``device``, so existing calls behave as before.
    """

    def __init__(
        self,
        transformer_model,
        num_classes,
        hidden_size=40,
        activation=nn.ReLU,
        target_device=None,
    ):
        super().__init__()
        # Only consult the notebook-global `device` when no explicit target is given.
        target = device if target_device is None else target_device
        self.transformer_model = transformer_model.to(target)
        self.dropout = nn.Dropout(0.1)
        self.fc1 = nn.Linear(transformer_model.config.hidden_size, hidden_size).to(target)
        # Without a nonlinearity, fc1 followed by classifier collapses to a single
        # linear map: the hidden layer would add parameters but no capacity.
        self.activation = activation() if activation is not None else nn.Identity()
        self.classifier = nn.Linear(hidden_size, num_classes).to(target)

    def forward(self, input_ids, attention_mask):
        """Return logits computed from the first ([CLS]) token representation."""
        hidden_states = self.transformer_model(
            input_ids=input_ids, attention_mask=attention_mask
        )[0]
        cls_vector = self.dropout(hidden_states[:, 0])
        return self.classifier(self.activation(self.fc1(cls_vector)))

In [62]:
from kan import FourierKANLayer


# Classification head variant that replaces the MLP with a single Fourier KAN layer.
class KANClassificationHead(nn.Module):
    """Transformer backbone -> dropout on the first token -> FourierKANLayer.

    Args:
        transformer_model: backbone module. It must expose ``config.hidden_size``
            and return an indexable output whose element 0 holds per-token hidden
            states (assumed (batch, seq, hidden) -- TODO confirm).
        num_classes: number of logits produced by the KAN layer.
        gridsize: forwarded as the third positional argument of
            ``FourierKANLayer`` (presumably the number of Fourier terms -- verify
            against the ``kan`` package).

    Both the backbone and the KAN layer are moved to the notebook-level ``device``.
    """

    def __init__(self, transformer_model, num_classes, gridsize=5):
        super().__init__()
        self.transformer_model = transformer_model.to(device)
        self.dropout = nn.Dropout(0.1)
        kan_layer = FourierKANLayer(
            transformer_model.config.hidden_size, num_classes, gridsize
        )
        self.classifier = kan_layer.to(device)

    def forward(self, input_ids, attention_mask):
        """Return KAN logits computed from the first ([CLS]) token representation."""
        backbone_out = self.transformer_model(
            input_ids=input_ids, attention_mask=attention_mask
        )
        token_states = backbone_out[0]
        cls_state = token_states[:, 0]
        return self.classifier(self.dropout(cls_state))

In [63]:
# Number of output logits each classification head produces (must match the label set).
num_classes: int = 4

In [64]:
# Build both heads on the *same* backbone object. They share the frozen transformer
# weights, so only the head parameters differ. Calling .train()/.eval() on either
# wrapper also switches the shared backbone's mode.
linear_model = ClassificationHead(transformer_model=transformer_model, num_classes=num_classes)
kan_model = KANClassificationHead(transformer_model=transformer_model, num_classes=num_classes)

## Model Parameters

In [65]:
from torch.nn.utils import parameters_to_vector
from prettytable import PrettyTable


def count_parameters_per_layer(model):
    """Print a table of trainable parameter tensors followed by overall totals.

    Args:
        model: any ``nn.Module``.

    Prints:
        A PrettyTable with one row (name, element count) per trainable parameter,
        then the element count across *all* parameters (trainable or frozen),
        then the trainable-only count. Returns ``None``.
    """
    table = PrettyTable(["Modules", "Parameters"])
    trainable_total = 0
    # Count directly instead of via parameters_to_vector: that helper builds a flat
    # copy of every weight (hundreds of MB for BERT) just to read its length, and
    # it fails on parameter-less or mixed-device models.
    overall_total = 0
    for name, parameter in model.named_parameters():
        count = parameter.numel()
        overall_total += count
        if parameter.requires_grad:
            table.add_row([name, count])
            trainable_total += count
    print(table)
    print(f"Total Number of Parameters: {overall_total}")
    print(f"Total Trainable Params: {trainable_total}")

In [66]:
count_parameters_per_layer(model=linear_model)  # only the MLP head should be trainable

+-------------------+------------+
|      Modules      | Parameters |
+-------------------+------------+
|     fc1.weight    |   30720    |
|      fc1.bias     |     40     |
| classifier.weight |    160     |
|  classifier.bias  |     4      |
+-------------------+------------+
Total Number of Parameters: 109513164
Total Trainable Params: 30924


In [67]:
count_parameters_per_layer(model=kan_model)  # only the KAN head should be trainable

+--------------------------+------------+
|         Modules          | Parameters |
+--------------------------+------------+
| classifier.fouriercoeffs |   30720    |
|     classifier.bias      |     4      |
+--------------------------+------------+
Total Number of Parameters: 109512964
Total Trainable Params: 30724
