In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel
import numpy as np

# Define KAN Layer
class KANLayer(nn.Module):
    """Simplified KAN-style layer with a learnable 1-D function per edge.

    Each (output i, input j) edge owns ``grid_size`` knots ``g[i, j, :]`` and
    weights ``w[i, j, :]``. The edge function is a weighted sum of ReLU hinge
    functions (not a true B-spline basis):

        phi_ij(x) = sum_k w[i, j, k] * relu(g[i, j, k] - x)

    and output ``i`` is ``sum_j phi_ij(x[:, j])``.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output features.
        grid_size: Number of knots per edge; knots start evenly spaced on
            [0, 1] and are trained together with the weights.
    """

    def __init__(self, input_dim, output_dim, grid_size=10):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.grid_size = grid_size

        # Learnable hinge weights, one vector per (output, input) edge.
        self.spline_weights = nn.Parameter(
            torch.randn(output_dim, input_dim, grid_size)
        )
        # expand() only creates a stride-0 view in which every edge aliases
        # the same grid_size storage cells. A Parameter built on that view
        # cannot learn per-edge knots, and in-place optimizer updates on
        # overlapping memory are unsafe. contiguous() materialises a real
        # (output_dim, input_dim, grid_size) tensor.
        self.grid_points = nn.Parameter(
            torch.linspace(0, 1, grid_size)
            .expand(output_dim, input_dim, grid_size)
            .contiguous()
        )

    def forward(self, x):
        """Apply the layer.

        Args:
            x: Tensor of shape ``(batch, input_dim)``.

        Returns:
            Tensor of shape ``(batch, output_dim)``.

        Raises:
            ValueError: If ``x`` is not 2-D or its feature size differs from
                ``input_dim``.
        """
        if x.dim() != 2 or x.size(1) != self.input_dim:
            raise ValueError(
                f"expected input of shape (batch, {self.input_dim}), "
                f"got {tuple(x.shape)}"
            )
        # Broadcast (1, out, in, grid) - (batch, 1, in, 1) so all edges are
        # evaluated in one tensor op instead of output_dim * input_dim Python
        # iterations. The intermediate holds batch*out*in*grid elements.
        basis = F.relu(self.grid_points.unsqueeze(0) - x[:, None, :, None])
        # Contract over the input (i) and grid (g) axes.
        return torch.einsum('boig,oig->bo', basis, self.spline_weights)


# Define SentimentKANModel
class SentimentKANModel(nn.Module):
    """Stack of KAN layers followed by a linear classification head.

    Args:
        embedding_dim: Feature size of the input fed to the first KAN layer.
        hidden_dim: Output size of every KAN layer and input size of the
            classifier.
        output_dim: Number of classes.
        num_layers: Number of stacked KAN layers.
        grid_size: Number of knots per edge, forwarded to each KAN layer.
    """

    def __init__(self, embedding_dim, hidden_dim, output_dim, num_layers, grid_size):
        super().__init__()
        # Only the first layer consumes embedding_dim features; the rest map
        # hidden_dim -> hidden_dim.
        self.kan_layers = nn.ModuleList([
            KANLayer(embedding_dim if i == 0 else hidden_dim, hidden_dim, grid_size)
            for i in range(num_layers)
        ])
        self.classifier = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return unnormalised class logits of shape ``(batch, output_dim)``.

        Logits are returned rather than probabilities because
        ``nn.CrossEntropyLoss`` applies log-softmax internally; a softmax
        here would be applied twice during training and flatten the
        gradients. Call ``F.softmax(logits, dim=1)`` on the result when
        probabilities are needed.
        """
        for kan_layer in self.kan_layers:
            x = kan_layer(x)
        return self.classifier(x)


# Dataset Class
class SentimentDataset(Dataset):
    """Map-style dataset that tokenizes one text per lookup.

    Each item is ``(input_ids, label)``: the tokenizer's ``input_ids`` for
    the text with the leading batch axis squeezed away, and the label stored
    at the same index. The ids are raw token ids, not embeddings.
    """

    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts, self.labels = texts, labels
        self.tokenizer, self.max_length = tokenizer, max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Pad/truncate every sample to max_length so items share one length.
        tokenizer_options = dict(
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        encoding = self.tokenizer(self.texts[idx], **tokenizer_options)
        # The tokenizer output carries a leading batch axis of size 1.
        token_ids = encoding['input_ids'].squeeze(0)
        label = self.labels[idx]
        return token_ids, label


# Prepare Data
def prepare_data():
    """Build the toy sentiment DataLoader and return it with its tokenizer.

    Loads the 'bert-base-uncased' tokenizer, wraps three hard-coded example
    sentences in a SentimentDataset (sequences padded to 128 tokens) and
    serves them in shuffled batches of two.

    Returns:
        Tuple ``(dataloader, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

    # (text, label) pairs; 2: Positive, 1: Neutral, 0: Negative
    samples = [
        ("I love this movie!", 2),
        ("This was the worst film ever.", 0),
        ("It was okay.", 1),
    ]
    texts = [sentence for sentence, _ in samples]
    labels = [sentiment for _, sentiment in samples]

    loader = DataLoader(
        SentimentDataset(texts, labels, tokenizer, max_length=128),
        batch_size=2,
        shuffle=True,
    )
    return loader, tokenizer


# Training Function
def train_model(model, dataloader, criterion, optimizer, device, epochs=5):
    """Run a plain supervised training loop and print the loss per epoch.

    Args:
        model: Module to optimise; switched to training mode.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer holding the model's parameters.
        device: Device each batch is moved to before the forward pass.
        epochs: Number of passes over ``dataloader``.

    The printed loss is the sum of the batch losses in the epoch, not
    their mean.
    """
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for features, targets in dataloader:
            features = features.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            # Inputs are cast to float before the forward pass.
            predictions = model(features.float())
            batch_loss = criterion(predictions, targets)
            batch_loss.backward()
            optimizer.step()

            total_loss += batch_loss.item()
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss:.4f}")


# Main Script
def main():
    """Build the data, model, loss and optimizer, then train for 5 epochs."""
    # Prefer the GPU when one is available.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # The tokenizer returned alongside the loader is not needed here.
    dataloader, _ = prepare_data()

    # embedding_dim matches the 128-token padded length produced by
    # prepare_data, since the model receives one feature per token position.
    # NOTE(review): those features are raw token ids cast to float, while
    # KANLayer's knots start on [0, 1] - confirm this input scale is intended.
    hyperparams = dict(
        embedding_dim=128,
        hidden_dim=256,
        output_dim=3,  # Positive, Neutral, Negative
        num_layers=3,
        grid_size=10,
    )
    model = SentimentKANModel(**hyperparams).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    train_model(model, dataloader, criterion, optimizer, device, epochs=5)


# Run the training demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

Epoch 1/5, Loss: 2.6029
