In [1]:
import os
import sys

# Point the kernel at the repository root *before* importing project code: the
# `src.efficient_kan` imports at the bottom of this cell are resolved relative to it,
# so they fail on a fresh kernel started in another directory if this runs later.
# Set the EFFICIENT_KAN_ROOT environment variable to override the default location.
PROJECT_ROOT = os.path.abspath(os.environ.get("EFFICIENT_KAN_ROOT", "/workspaces/efficient-kan"))
os.chdir(PROJECT_ROOT)
if PROJECT_ROOT not in sys.path:
    # Make `src.efficient_kan` importable explicitly rather than relying on '' (cwd) in sys.path.
    sys.path.insert(0, PROJECT_ROOT)
sys.path.append(os.path.abspath("./src"))

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np

from src.efficient_kan.kan import KAN
from src.efficient_kan.functions import evaluate, save_kan_model, load_kan_model, data_subset

In [2]:
# Preprocessing: convert each image to a tensor, then Normalize computes (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Both MNIST splits share the same root, download flag and preprocessing.
full_trainset = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
full_valset = torchvision.datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)

In [3]:
# Subsets of the data. The second argument of `data_subset` is presumably the fraction
# of samples to keep (1 keeps everything) — TODO confirm against the helper.
SUBSET_FRACTION = 1
trainset, valset = (data_subset(split, SUBSET_FRACTION) for split in (full_trainset, full_valset))
print(f"Using {len(trainset)} training samples out of {len(full_trainset)} total")
print(f"Using {len(valset)} validation samples out of {len(full_valset)} total")

# Data loaders: shuffle only the training split; validation order stays fixed.
BATCH_SIZE = 64
trainloader = DataLoader(trainset, shuffle=True, batch_size=BATCH_SIZE)
valloader = DataLoader(valset, shuffle=False, batch_size=BATCH_SIZE)

Using 60000 training samples out of 60000 total
Using 10000 validation samples out of 10000 total


In [None]:
# Seed torch's global RNG so weight initialisation and the training DataLoader's
# shuffle order are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Initialize the model: a KAN with layer widths [784 (flattened 28x28 image), 64, 10].
model = KAN([28 * 28, 64, 10])

# Load the model? Uncomment to resume from a checkpoint (it is moved to `device` below).
# NOTE(review): loading has failed with missing "layers.N.spline_scaler" keys, which
# suggests the checkpoint came from a KAN built with different options — confirm the
# constructor options match the saved model before resuming.
#model = load_kan_model("./model/any")

# Other variables
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
# ExponentialLR multiplies the LR by `gamma` on every scheduler.step().
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)
criterion = nn.CrossEntropyLoss()

RuntimeError: Error(s) in loading state_dict for KAN:
	Missing key(s) in state_dict: "layers.0.spline_scaler", "layers.1.spline_scaler". 

In [5]:
# Initial evaluation on the full MNIST test split, before training.
# Switch to eval mode explicitly, as the training loop does before its validation step
# (presumably `evaluate` does not do this itself — TODO confirm).
# NOTE(review): the saved output (~95.7% accuracy) looks like a trained/loaded model from
# earlier kernel state rather than a fresh KAN; re-run from a clean kernel to confirm.
model.eval()
fullloader = DataLoader(full_valset, batch_size=64, shuffle=False)
total_loss, total_accuracy = evaluate(fullloader, model, criterion)
print(f"Total Loss: {total_loss}, Total Accuracy: {total_accuracy}")

Total Loss: 0.13927828586529822, Total Accuracy: 0.9570063694267515


In [6]:
# Train the model. Each epoch runs one optimisation pass over `trainloader`, then
# validates on `valloader`, then takes one scheduler step (so the LR decays per epoch).
NUM_EPOCHS = 10
for epoch in range(NUM_EPOCHS):
    model.train()
    with tqdm(trainloader, desc=f"Epoch {epoch + 1}/{NUM_EPOCHS}") as pbar:
        for images, labels in pbar:
            # Flatten each image to a 28*28 = 784 vector to match the KAN input width.
            images = images.view(-1, 28 * 28).to(device)
            # Transfer labels once; they feed both the loss and the batch accuracy.
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            # Batch accuracy, shown on the progress bar only.
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item(), lr=optimizer.param_groups[0]['lr'])

    # Validation
    model.eval()
    val_loss, val_accuracy = evaluate(valloader, model, criterion)
    scheduler.step()
    print(f"Epoch {epoch + 1}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|██████████| 938/938 [00:27<00:00, 33.88it/s, accuracy=0.969, loss=0.107, lr=0.001] 


Epoch 1, Val Loss: 0.13094316027868705, Val Accuracy: 0.9596934713375797


100%|██████████| 938/938 [00:26<00:00, 35.95it/s, accuracy=0.969, loss=0.0711, lr=0.0008]


Epoch 2, Val Loss: 0.11488506491859532, Val Accuracy: 0.9647691082802548


100%|██████████| 938/938 [00:26<00:00, 35.13it/s, accuracy=0.969, loss=0.166, lr=0.00064] 


Epoch 3, Val Loss: 0.10752356418402521, Val Accuracy: 0.9675557324840764


100%|██████████| 938/938 [00:25<00:00, 36.13it/s, accuracy=0.969, loss=0.132, lr=0.000512] 


Epoch 4, Val Loss: 0.09236083468007054, Val Accuracy: 0.9720342356687898


100%|██████████| 938/938 [00:25<00:00, 36.71it/s, accuracy=1, loss=0.0159, lr=0.00041]    


Epoch 5, Val Loss: 0.09315305239087932, Val Accuracy: 0.971437101910828


100%|██████████| 938/938 [00:25<00:00, 36.12it/s, accuracy=1, loss=0.00473, lr=0.000328]   


Epoch 6, Val Loss: 0.08997582893812685, Val Accuracy: 0.9739251592356688


100%|██████████| 938/938 [00:27<00:00, 34.03it/s, accuracy=1, loss=0.0321, lr=0.000262]    


Epoch 7, Val Loss: 0.09081262281716201, Val Accuracy: 0.9735270700636943


100%|██████████| 938/938 [00:26<00:00, 36.04it/s, accuracy=1, loss=0.00338, lr=0.00021]   


Epoch 8, Val Loss: 0.0895675291703231, Val Accuracy: 0.9743232484076433


100%|██████████| 938/938 [00:26<00:00, 36.04it/s, accuracy=1, loss=0.0019, lr=0.000168]    


Epoch 9, Val Loss: 0.08938034130987017, Val Accuracy: 0.9740246815286624


100%|██████████| 938/938 [00:27<00:00, 34.00it/s, accuracy=1, loss=0.00176, lr=0.000134]   


Epoch 10, Val Loss: 0.08996701971286683, Val Accuracy: 0.9741242038216561


In [7]:
# Loader over the full MNIST test split, in fixed (unshuffled) order.
fullloader = DataLoader(
    full_valset,
    shuffle=False,
    batch_size=64,
)

# Persist the trained model.
save_kan_model(model, path="./model/loadmodel")

# Final metrics on the full test split.
total_loss, total_accuracy = evaluate(fullloader, model, criterion)
print(f"Total Loss: {total_loss}, Total Accuracy: {total_accuracy}")

Total Loss: 0.08996594905060609, Total Accuracy: 0.9741242038216561
