# Task IX: Kolmogorov-Arnold Network 
Implement a classical Kolmogorov-Arnold Network using basis-splines or some other 
KAN architecture and apply it to MNIST. Show its performance on the test data. 
Comment on potential ideas to extend this classical KAN architecture to a quantum 
KAN and sketch out the architecture in detail.

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np

# Step 1: Load MNIST Dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Step 2: Define Basis Splines (B-splines)
class BSplineLayer(nn.Module):
    def __init__(self, input_dim, num_basis):
        super(BSplineLayer, self).__init__()
        self.input_dim = input_dim
        self.num_basis = num_basis
        self.knots = nn.Parameter(torch.linspace(-1, 1, num_basis))  # Uniformly spaced knots
        self.weights = nn.Parameter(torch.randn(input_dim, num_basis))  # Learnable weights

    def forward(self, x):
        # Compute B-spline basis functions
        x = x.unsqueeze(-1)  # Add a dimension for basis functions
        basis = torch.relu(1 - torch.abs(x - self.knots))  # Linear B-spline basis
        weighted_basis = torch.matmul(basis, self.weights.T)  # Weighted sum of basis functions
        return weighted_basis.sum(dim=1)  # Aggregate across the input dimension
    
    
# Step 3: Define Kolmogorov-Arnold Network (KAN)
class KolmogorovArnoldNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_basis):
        super(KolmogorovArnoldNetwork, self).__init__()
        self.basis_layer = BSplineLayer(input_dim, num_basis)
        self.hidden_layer = nn.Linear(num_basis, hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten the input (28x28 -> 784)
        x = self.basis_layer(x)
        x = torch.relu(self.hidden_layer(x))
        x = self.output_layer(x)
        return x

# Step 4: Initialize Model, Loss, and Optimizer
input_dim = 784  # MNIST images are 28x28
hidden_dim = 128
output_dim = 10  # 10 classes for digits 0-9
num_basis = 32

model = KolmogorovArnoldNetwork(input_dim, hidden_dim, output_dim, num_basis)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Step 5: Train the Model
def train(model, train_loader, criterion, optimizer, epochs=5):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for images, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}")

# Step 6: Evaluate the Model
def evaluate(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    accuracy = correct / total
    print(f"Test Accuracy: {accuracy * 100:.2f}%")
    return accuracy

# Train and Evaluate
train(model, train_loader, criterion, optimizer, epochs=5)
evaluate(model, test_loader)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x784 and 32x128)

In [8]:
from kan import *
torch.set_default_dtype(torch.float64)


# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5).
model = KAN(width=[2,5,1], grid=3, k=3, seed=42)

ModuleNotFoundError: No module named 'sklearn'