In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class KANLayer(nn.Module):
    """Kolmogorov-Arnold-style layer: a learnable 1-D function on every input->output edge.

    Each edge (output ``o``, input ``i``) computes

        phi_{o,i}(x_i) = sum_g w[o, i, g] * relu(t[o, i, g] - x_i)

    i.e. a weighted sum of ReLU hinge functions anchored at learnable knots ``t``.
    Output ``o`` is the sum of its incoming edge functions over all inputs.
    Note: this is a piecewise-linear hinge basis, not a true B-spline basis.

    Args:
        input_dim: number of input features (columns of ``x``).
        output_dim: number of output features.
        grid_size: number of knots / basis functions per edge.
    """

    def __init__(self, input_dim, output_dim, grid_size=10):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.grid_size = grid_size

        # Basis coefficients, one length-grid_size vector per (output, input) edge.
        # Each output sums input_dim * grid_size weighted terms, so scale the init by
        # 1/sqrt(fan_in); otherwise activation magnitude grows with width and
        # explodes when layers are stacked.
        self.spline_weights = nn.Parameter(
            torch.randn(output_dim, input_dim, grid_size) / (input_dim * grid_size) ** 0.5
        )
        # Knot positions, initialised evenly on [0, 1] for every edge.
        # .repeat() (not .expand()) gives each edge its own storage: an expanded
        # view aliases one memory location across all edges, and in-place
        # optimizer updates (e.g. Adam's param.addcdiv_) raise on such tensors.
        self.grid_points = nn.Parameter(
            torch.linspace(0, 1, grid_size).repeat(output_dim, input_dim, 1)
        )

    def forward(self, x):
        """Apply the edge functions and sum them per output.

        Args:
            x: tensor of shape ``(batch, input_dim)``.

        Returns:
            Tensor of shape ``(batch, output_dim)``.

        Raises:
            ValueError: if ``x`` is not 2-D with ``input_dim`` columns.
        """
        if x.dim() != 2 or x.size(1) != self.input_dim:
            raise ValueError(
                f"expected input of shape (batch, {self.input_dim}), got {tuple(x.shape)}"
            )
        # basis[b, o, i, g] = relu(t[o, i, g] - x[b, i]) -> (batch, out, in, grid).
        # Broadcasting replaces the per-edge Python double loop. Autograd already kept
        # these same per-edge activations for backward in the loop version.
        basis = F.relu(self.grid_points.unsqueeze(0) - x[:, None, :, None])
        # Weight every basis function and sum over inputs and knots.
        return torch.einsum("boig,oig->bo", basis, self.spline_weights)


In [2]:
class SentimentKANModel(nn.Module):
    """Stack of ``KANLayer`` blocks followed by a linear classification head.

    Args:
        embedding_dim: size of each input feature vector (columns of ``x``).
        hidden_dim: width of every KAN layer.
        output_dim: number of classes.
        num_layers: number of stacked KAN layers (0 means a plain linear classifier).
        grid_size: knots per edge in each KAN layer.
    """

    def __init__(self, embedding_dim, hidden_dim, output_dim, num_layers, grid_size):
        super().__init__()
        self.kan_layers = nn.ModuleList([
            KANLayer(embedding_dim if i == 0 else hidden_dim, hidden_dim, grid_size)
            for i in range(num_layers)
        ])
        # With no KAN layers the classifier sees the raw embeddings directly.
        classifier_in = hidden_dim if num_layers > 0 else embedding_dim
        self.classifier = nn.Linear(classifier_in, output_dim)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, output_dim)``.

        Logits (not probabilities) are returned because ``nn.CrossEntropyLoss``
        applies log-softmax internally. Feeding it softmax outputs applies softmax
        twice, which compresses the loss range and weakens gradients. Use
        ``predict_proba`` when probabilities are needed.
        """
        for kan_layer in self.kan_layers:
            x = kan_layer(x)
        return self.classifier(x)

    def predict_proba(self, x):
        """Return class probabilities: softmax over dim 1 of ``forward(x)``."""
        return F.softmax(self(x), dim=1)


In [None]:
# Training configuration (tune as needed).
EPOCHS = 10
SEED = 0

# NOTE(review): `dataloader` is not defined in this notebook; an earlier cell must
# create it. It is iterated as (inputs, labels) pairs; inputs presumably have shape
# (batch, 768) and labels are class indices in [0, 3) - TODO confirm.

torch.manual_seed(SEED)  # reproducible weight initialisation

model = SentimentKANModel(embedding_dim=768, hidden_dim=256, output_dim=3, num_layers=3, grid_size=10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

model.train()
for epoch in range(EPOCHS):
    running_loss = 0.0
    num_batches = 0
    for inputs, labels in dataloader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    # Report the mean loss over the epoch. The last batch's loss alone is noisy,
    # and it is undefined when the dataloader yields no batches.
    mean_loss = running_loss / max(num_batches, 1)
    print(f"Epoch {epoch}, Loss: {mean_loss:.4f}")
