In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [6]:
# Define the B-Spline basis functions
class BSpline(nn.Module):
    def __init__(self, order=3, knots=None):
        super(BSpline, self).__init__()
        self.order = order
        if knots is None:
            self.knots = torch.linspace(0, 1, order + 1)
        else:
            self.knots = knots

    def forward(self, x):
        return self.basis(x, self.knots, self.order)

    def basis(self, x, knots, order):
        if order == 0:
            return ((knots[:-1] <= x) & (x < knots[1:])).float()
        else:
            denom1 = knots[order:] - knots[:-order]
            denom1[denom1 == 0] = 1
            term1 = (x - knots[:-order].unsqueeze(1)) / denom1.unsqueeze(1)

            denom2 = knots[order + 1:] - knots[1:-order]
            denom2[denom2 == 0] = 1
            term2 = (knots[order + 1:].unsqueeze(1) - x) / denom2.unsqueeze(1)

            basis_left = self.basis(x, knots, order - 1)[:-1]
            basis_right = self.basis(x, knots, order - 1)[1:]

            return term1 * basis_left + term2 * basis_right

# Define a single KAN layer
class KANLayer(nn.Module):
    def __init__(self, in_features, out_features, spline_order=3):
        super(KANLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.basis_functions = nn.ModuleList([BSpline(order=spline_order) for _ in range(in_features * out_features)])

    def forward(self, x):
        batch_size = x.size(0)
        output = torch.zeros(batch_size, self.out_features).to(x.device)
        for i in range(self.out_features):
            for j in range(self.in_features):
                output[:, i] += self.basis_functions[i * self.in_features + j](x[:, j])
        return output

# Define the KAN network
class KAN(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=128, depth=3, spline_order=3):
        super(KAN, self).__init__()
        layers = []
        for _ in range(depth):
            layers.append(KANLayer(input_dim, hidden_dim, spline_order))
            input_dim = hidden_dim
        layers.append(nn.Linear(hidden_dim, output_dim))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [3]:


# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_set = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=1000, shuffle=False)



Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:11<00:00, 875391.99it/s] 


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 117800.53it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 915396.43it/s] 


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4866035.45it/s]

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw






In [4]:

# Initialize the KAN model, loss function, and optimizer
input_dim = 28 * 28  # MNIST images are 28x28
output_dim = 10  # MNIST has 10 classes
model = KAN(input_dim=input_dim, output_dim=output_dim)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


In [None]:


# Training loop
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    for images, labels in train_loader:
        images = images.view(-1, 28 * 28)
        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# Evaluation loop
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.view(-1, 28 * 28)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Test Accuracy: {100 * correct / total:.2f}%')


In [8]:
import torch
import torch.nn as nn
from tqdm import tqdm
from kan import KAN



def test_mul():
    kan = KAN([2, 2, 1], base_activation=nn.Identity)
    optimizer = torch.optim.LBFGS(kan.parameters(), lr=1)
    with tqdm(range(100)) as pbar:
        for i in pbar:
            loss, reg_loss = None, None

            def closure():
                optimizer.zero_grad()
                x = torch.rand(1024, 2)
                y = kan(x, update_grid=(i % 20 == 0))

                assert y.shape == (1024, 1)
                nonlocal loss, reg_loss
                u = x[:, 0]
                v = x[:, 1]
                loss = nn.functional.mse_loss(y.squeeze(-1), (u + v) / (1 + u * v))
                reg_loss = kan.regularization_loss(1, 0)
                (loss + 1e-5 * reg_loss).backward()
                return loss + reg_loss

            optimizer.step(closure)
            pbar.set_postfix(mse_loss=loss.item(), reg_loss=reg_loss.item())
    for layer in kan.layers:
        print(layer.spline_weight)

In [9]:
# Train on MNIST
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# Load MNIST
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
trainset = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
valset = torchvision.datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

In [10]:
# Define model
model = KAN([28 * 28, 64, 10])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
# Define learning rate scheduler
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)

In [11]:




# Define loss
criterion = nn.CrossEntropyLoss()
for epoch in range(10):
    # Train
    model.train()
    with tqdm(trainloader) as pbar:
        for i, (images, labels) in enumerate(pbar):
            images = images.view(-1, 28 * 28).to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels.to(device))
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels.to(device)).float().mean()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item(), lr=optimizer.param_groups[0]['lr'])

    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            output = model(images)
            val_loss += criterion(output, labels.to(device)).item()
            val_accuracy += (
                (output.argmax(dim=1) == labels.to(device)).float().mean().item()
            )
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Update learning rate
    scheduler.step()

    print(
        f"Epoch {epoch + 1}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}"
    )

100%|██████████| 938/938 [00:17<00:00, 53.73it/s, accuracy=0.906, loss=0.447, lr=0.001] 


Epoch 1, Val Loss: 0.23048456607588158, Val Accuracy: 0.9325238853503185


100%|██████████| 938/938 [00:16<00:00, 55.22it/s, accuracy=0.938, loss=0.165, lr=0.0008] 


Epoch 2, Val Loss: 0.16485021670888753, Val Accuracy: 0.9535230891719745


100%|██████████| 938/938 [00:17<00:00, 54.18it/s, accuracy=1, loss=0.0532, lr=0.00064]    


Epoch 3, Val Loss: 0.12734886335214943, Val Accuracy: 0.9625796178343949


100%|██████████| 938/938 [00:18<00:00, 50.08it/s, accuracy=0.969, loss=0.0837, lr=0.000512]


Epoch 4, Val Loss: 0.11534768961826755, Val Accuracy: 0.9658638535031847


100%|██████████| 938/938 [00:29<00:00, 32.28it/s, accuracy=0.969, loss=0.0481, lr=0.00041]


Epoch 5, Val Loss: 0.1104087866694447, Val Accuracy: 0.9678542993630573


100%|██████████| 938/938 [00:29<00:00, 31.54it/s, accuracy=0.938, loss=0.153, lr=0.000328] 


Epoch 6, Val Loss: 0.10152506728798957, Val Accuracy: 0.9694466560509554


100%|██████████| 938/938 [00:17<00:00, 54.30it/s, accuracy=1, loss=0.015, lr=0.000262]     


Epoch 7, Val Loss: 0.09775269842501967, Val Accuracy: 0.9705414012738853


100%|██████████| 938/938 [00:16<00:00, 55.36it/s, accuracy=1, loss=0.0102, lr=0.00021]    


Epoch 8, Val Loss: 0.09474892176800426, Val Accuracy: 0.971437101910828


100%|██████████| 938/938 [00:16<00:00, 56.07it/s, accuracy=1, loss=0.034, lr=0.000168]     


Epoch 9, Val Loss: 0.09255986028220387, Val Accuracy: 0.972531847133758


100%|██████████| 938/938 [00:16<00:00, 56.17it/s, accuracy=1, loss=0.0177, lr=0.000134]    


Epoch 10, Val Loss: 0.09221856407732199, Val Accuracy: 0.9726313694267515
