In [1]:
! git clone https://github.com/MrPio/KAN_tests
! mv KAN_tests/* ./
! rm -r KAN_tests/

Cloning into 'KAN_tests'...
remote: Enumerating objects: 32, done.[K
remote: Counting objects: 100% (32/32), done.[K
remote: Compressing objects: 100% (29/29), done.[K
remote: Total 32 (delta 3), reused 31 (delta 2), pack-reused 0[K
Receiving objects: 100% (32/32), 2.61 MiB | 13.22 MiB/s, done.
Resolving deltas: 100% (3/3), done.


In [2]:
from matplotlib import pyplot as plt
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms
from tqdm import tqdm

from efficient_kan.kan import KAN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [14]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# Train set. Here we sort the MNIST by digits and disable data shuffling
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
sorted_indices = sorted(range(len(train_dataset)//1), key=lambda idx: train_dataset.targets[idx])
# sorted_indices = range(len(train_dataset)//10)
train_dataset = torch.utils.data.Subset(train_dataset, sorted_indices)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False)
# Test set
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [15]:
def train(model, checkpoint, epochs=5):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-6,weight_decay=1e-4)
    # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.996)
    for epoch in range(epochs):
        model.train()
        with tqdm(train_loader) as pbar:
            for i, (images, labels) in enumerate(pbar):
                images = images.view(-1, 28 * 28).to(device)
                optimizer.zero_grad()
                output = model(images)
                loss = criterion(output, labels.to(device))
                loss.backward()
                optimizer.step()
                accuracy = (output.argmax(dim=1) == labels.to(device)).float().mean()
                pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item(), lr=optimizer.param_groups[0]['lr'])
                # scheduler.step()
        print(f'Epoch {epoch + 1}, Loss: {loss.item()}')
        torch.save(model.state_dict(), checkpoint)

In [16]:
def validate(model):
    model.eval()
    vals=[0]*10
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.view(-1, 28 * 28).to(device)
            output = model(images)
            for out in output.argmax(dim=1):
              vals[out.item()]+=1
            val_accuracy += (output.argmax(dim=1) == labels.to(device)).float().mean().item()
    val_accuracy /= len(test_loader)
    print(vals)
    print(f"Accuracy: {val_accuracy}")

In [17]:
model = KAN([28 * 28, 256, 64, 10]).to(device)
train(model, 'checkpoint/kan_mnist.pth')

100%|██████████| 938/938 [00:29<00:00, 31.52it/s, accuracy=0, loss=2.32, lr=1e-6]     


Epoch 1, Loss: 2.3166604042053223


100%|██████████| 938/938 [00:29<00:00, 32.01it/s, accuracy=0.0312, loss=2.31, lr=1e-6]


Epoch 2, Loss: 2.310750961303711


100%|██████████| 938/938 [00:27<00:00, 34.63it/s, accuracy=0.0312, loss=2.3, lr=1e-6] 


Epoch 3, Loss: 2.3049917221069336


100%|██████████| 938/938 [00:28<00:00, 33.46it/s, accuracy=0, loss=2.3, lr=1e-6]      


Epoch 4, Loss: 2.299386501312256


100%|██████████| 938/938 [00:30<00:00, 31.10it/s, accuracy=0, loss=2.29, lr=1e-6]     

Epoch 5, Loss: 2.2938811779022217





In [19]:
model = KAN([28 * 28, 256, 64, 10]).to(device)
model.load_state_dict(torch.load('checkpoint/kan_mnist.pth'))
validate(model)

[942, 2937, 578, 361, 1294, 502, 1147, 1600, 537, 102]
Accuracy: 0.5325437898089171
