In [1]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding

# Load the IMDB dataset via `datasets.load_dataset`.
# Only its "train" and "test" splits are used further down, and the "text" column of
# each example is what gets tokenized.
imdb = load_dataset("imdb")

In [2]:
# Load the pre-trained transformer model and tokenizer.
# Both are created from the same checkpoint name. `AutoModel` is used here, and the
# classification layer is added separately by KANClassificationHead later in the notebook.
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
transformer_model = AutoModel.from_pretrained(model_name)

In [3]:
# Freeze the backbone: switch off gradient tracking for every pre-trained parameter
# in one call, so only layers added on top of it can be updated by the optimizer.
transformer_model.requires_grad_(False)

In [4]:
# Set up the data collator and dataloaders
# The collator pads each batch to the longest sequence in that batch (dynamic padding).
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

max_length = 512  # upper bound on tokens kept per review; longer texts are truncated
batch_size = 32

def preprocess_function(examples):
    """Tokenize the "text" field of a batch of examples.

    Sequences are truncated to ``max_length`` tokens but deliberately NOT padded here:
    with ``map(batched=True)`` a ``padding=True`` would pad every example to the longest
    one in the whole map batch, which defeats the per-batch padding that
    ``data_collator`` applies when the DataLoader assembles a batch.

    Args:
        examples: a batch from ``Dataset.map`` exposing a "text" entry.

    Returns:
        The tokenizer output for the batch (variable-length sequences).
    """
    return tokenizer(examples["text"], truncation=True, max_length=max_length)


# Training subset: 3000 reviews taken from a seeded shuffle of the train split.
small_train_dataset = imdb["train"].shuffle(seed=42).select(range(3000))

# Shuffle the test split ONCE and cut two disjoint slices from that single permutation.
# Slicing the un-shuffled split for the test subset (as was done before) could re-use
# rows already picked for the eval subset, and made the test subset depend on the
# on-disk ordering of the split (presumably grouped by label for IMDB — verify).
# The eval subset is the same 300 rows as before (same seed, same first 300 indices).
shuffled_test = imdb["test"].shuffle(seed=42)
small_eval_dataset = shuffled_test.select(range(300))
small_test_dataset = shuffled_test.select(range(300, 600))

# Tokenize the three subsets with the same settings; the raw "text" column is dropped
# afterwards so that only tokenizer outputs and the remaining columns reach the collator.
tokenized_train, tokenized_eval, tokenized_test = (
    subset.map(preprocess_function, batched=True, remove_columns=["text"])
    for subset in (small_train_dataset, small_eval_dataset, small_test_dataset)
)

# Batches are assembled (and padded) by data_collator; only the training loader shuffles.
train_dataloader = DataLoader(
    tokenized_train,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=data_collator,
)
eval_dataloader = DataLoader(
    tokenized_eval,
    batch_size=batch_size,
    collate_fn=data_collator,
)
test_dataloader = DataLoader(
    tokenized_test,
    batch_size=batch_size,
    collate_fn=data_collator,
)

Map:   0%|          | 0/300 [00:00<?, ? examples/s]

In [5]:
# Select the compute device: prefer a CUDA GPU, then Apple-Silicon MPS, otherwise the CPU.
# (Previously only MPS was checked, so machines with a CUDA GPU silently ran on the CPU.)
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: mps


In [9]:
from kan_fourier import NaiveFourierKANLayer

# Define the custom classification head using KAN
class KANClassificationHead(nn.Module):
    """Sequence classifier: a transformer encoder topped with a Fourier-KAN output layer.

    The encoder's hidden state at sequence position 0 is passed through dropout and then
    through a ``NaiveFourierKANLayer`` that maps ``hidden_size`` features to
    ``num_classes`` outputs.
    """

    def __init__(self, transformer_model, num_classes, gridsize=5):
        """Assemble the encoder, dropout and KAN output layer.

        Args:
            transformer_model: encoder module exposing ``config.hidden_size``; it is moved
                to the notebook-level ``device``.
            num_classes: number of outputs produced per sequence.
            gridsize: forwarded as the third positional argument of
                ``NaiveFourierKANLayer`` (presumably its Fourier grid size — confirm
                in ``kan_fourier``).
        """
        super().__init__()
        feature_dim = transformer_model.config.hidden_size
        self.transformer_model = transformer_model.to(device)
        self.dropout = nn.Dropout(0.1)
        kan_layer = NaiveFourierKANLayer(feature_dim, num_classes, gridsize)
        self.classifier = kan_layer.to(device)

    def forward(self, input_ids, attention_mask):
        """Return the classifier output for a batch of tokenized sequences.

        Args:
            input_ids: token ids passed to the encoder as ``input_ids``.
            attention_mask: mask passed to the encoder as ``attention_mask``.

        Returns:
            Output of the KAN layer applied to the (dropout-regularised) position-0
            hidden state of each sequence.
        """
        encoder_outputs = self.transformer_model(input_ids=input_ids, attention_mask=attention_mask)
        sequence_states = encoder_outputs[0]
        # Position 0 of every sequence holds the CLS token representation
        cls_state = sequence_states[:, 0]
        return self.classifier(self.dropout(cls_state))

In [7]:
# Set hyperparameters
# NOTE(review): the backbone is frozen earlier in the notebook, so only the KAN layer is
# trained; 2e-5 may be conservative for training a freshly initialised layer alone —
# worth confirming with a small sweep.
num_classes = 2  # Binary classification (positive/negative)
learning_rate = 2e-5  # step size handed to the AdamW optimizer
num_epochs = 3  # number of passes over train_dataloader

In [10]:
# Build the classifier: the (frozen) transformer backbone plus the KAN output layer
model = KANClassificationHead(transformer_model=transformer_model, num_classes=num_classes)

# Loss and optimizer. All model parameters are handed to AdamW; those of the backbone
# had requires_grad switched off earlier in the notebook.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)

In [11]:
# Train for num_epochs epochs; after each epoch, report loss and accuracy on the eval subset
for epoch in range(num_epochs):
    # NOTE(review): train() also puts the frozen backbone into training mode, so any
    # dropout inside it is active during these steps — confirm this is intended.
    model.train()
    for batch in train_dataloader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        optimizer.zero_grad()
        output = model(input_ids, attention_mask)
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()

    model.eval()
    eval_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    # No gradients are needed for evaluation; disabling autograd avoids building a
    # graph for every eval batch (less memory, faster forward passes).
    with torch.no_grad():
        for batch in eval_dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            output = model(input_ids, attention_mask)
            loss = loss_fn(output, labels)
            eval_loss += loss.item()
            correct_predictions += (output.argmax(dim=1) == labels).sum().item()
            total_samples += labels.size(0)

    # Show metrics (the loss is the mean of the per-batch mean losses)
    accuracy = correct_predictions / total_samples
    print(
        f"Epoch {epoch+1}/{num_epochs}, Eval Loss: {eval_loss/len(eval_dataloader)}, Accuracy: {accuracy:.4f}"
    )

Epoch 1/3, Eval Loss: 0.6800906598567963, Accuracy: 0.5700
Epoch 2/3, Eval Loss: 0.6051929473876954, Accuracy: 0.6967
Epoch 3/3, Eval Loss: 0.5615160465240479, Accuracy: 0.7300
