<a href="https://colab.research.google.com/github/Anirudh-R-1201/CodePapers/blob/main/KAN_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

In [19]:
class KolmogorovArnoldNetwork(nn.Module):
    def __init__(self, input_dim=3072, output_dim=10, hidden_dim=128, num_inner_functions=16, num_outer_functions=10):
        super(KolmogorovArnoldNetwork, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.num_inner_functions = num_inner_functions
        self.num_outer_functions = num_outer_functions

        self.inner_functions = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim + 1, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 1)
            ) for _ in range(num_inner_functions)
        ])

        self.outer_functions = nn.Sequential(
            nn.Linear(num_inner_functions, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        batch_size = x.size(0)
        x = x.view(batch_size, -1)

        inner_results = []
        for j in range(self.num_inner_functions):
            j_tensor = torch.full((batch_size, 1), j, dtype=torch.float, device=x.device)
            input_j = torch.cat([x, j_tensor], dim=1)
            inner_results.append(self.inner_functions[j](input_j))

        inner_output = torch.cat(inner_results, dim=1)

        output = self.outer_functions(inner_output)

        return output


In [3]:
def load_cifar10(batch_size=8):#64):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    return trainloader, testloader

In [30]:
def train(model, trainloader, epochs=25, lr=0.0001):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 1000 == 199:
              print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 200:.3f}')
              running_loss = 0.0

    #print('Finished Training')

In [16]:
def evaluate(model, testloader):
    correct = 0
    total = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    model.eval()
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy on test images: {100 * correct / total:.2f}%')


In [22]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [23]:
trainloader, testloader = load_cifar10()

Files already downloaded and verified
Files already downloaded and verified


In [32]:
model = KolmogorovArnoldNetwork().to(device)

In [None]:
train(model, trainloader)

[1,   200] loss: 2.088
[1,  1200] loss: 9.038
[1,  2200] loss: 8.553
[1,  3200] loss: 8.134
[1,  4200] loss: 8.076
[1,  5200] loss: 7.811
[1,  6200] loss: 7.736
[2,   200] loss: 1.494
[2,  1200] loss: 7.379
[2,  2200] loss: 7.362
[2,  3200] loss: 7.269
[2,  4200] loss: 7.167
[2,  5200] loss: 7.017
[2,  6200] loss: 7.144
[3,   200] loss: 1.349
[3,  1200] loss: 6.711
[3,  2200] loss: 6.775
[3,  3200] loss: 6.638
[3,  4200] loss: 6.732
[3,  5200] loss: 6.672
[3,  6200] loss: 6.527
[4,   200] loss: 1.239
[4,  1200] loss: 6.178
[4,  2200] loss: 6.268
[4,  3200] loss: 6.329
[4,  4200] loss: 6.311
[4,  5200] loss: 6.232
[4,  6200] loss: 6.255
[5,   200] loss: 1.156
[5,  1200] loss: 5.783
[5,  2200] loss: 5.936
[5,  3200] loss: 5.952
[5,  4200] loss: 6.052
[5,  5200] loss: 5.848
[5,  6200] loss: 5.911
[6,   200] loss: 1.102
[6,  1200] loss: 5.502
[6,  2200] loss: 5.565
[6,  3200] loss: 5.572
[6,  4200] loss: 5.580
[6,  5200] loss: 5.640
[6,  6200] loss: 5.652
[7,   200] loss: 1.003
[7,  1200] 

In [35]:
evaluate(model, testloader)

Accuracy on test images: 53.28%
