In [1]:
from efficient_kan import KAN

# Train on MNIST
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

In [2]:
# Load MNIST
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
trainset = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
valset = torchvision.datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 10391854.64it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 358342.98it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 1298461.95it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<?, ?it/s]

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw






In [3]:
# Define model
model = KAN([28 * 28, 64, 10])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
# Define learning rate scheduler
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)

# Define loss
criterion = nn.CrossEntropyLoss()

In [4]:
for epoch in range(10):
    # Train
    model.train()
    with tqdm(trainloader) as pbar:
        for i, (images, labels) in enumerate(pbar):
            images = images.view(-1, 28 * 28).to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels.to(device))
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels.to(device)).float().mean()
            pbar.set_postfix(
                loss=loss.item(),
                accuracy=accuracy.item(),
                lr=optimizer.param_groups[0]["lr"],
            )

    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            output = model(images)
            val_loss += criterion(output, labels.to(device)).item()
            val_accuracy += (
                (output.argmax(dim=1) == labels.to(device)).float().mean().item()
            )
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Update learning rate
    scheduler.step()

    print(f"Epoch {epoch + 1}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|██████████| 938/938 [00:16<00:00, 56.43it/s, accuracy=0.969, loss=0.143, lr=0.001] 


Epoch 1, Val Loss: 0.22028351798179044, Val Accuracy: 0.9375995222929936


100%|██████████| 938/938 [00:16<00:00, 56.71it/s, accuracy=0.875, loss=0.341, lr=0.0008] 


Epoch 2, Val Loss: 0.1603689688957848, Val Accuracy: 0.9530254777070064


100%|██████████| 938/938 [00:16<00:00, 56.20it/s, accuracy=0.938, loss=0.119, lr=0.00064] 


Epoch 3, Val Loss: 0.13523798967581124, Val Accuracy: 0.960390127388535


100%|██████████| 938/938 [00:16<00:00, 55.53it/s, accuracy=0.969, loss=0.0589, lr=0.000512]


Epoch 4, Val Loss: 0.12260323289965701, Val Accuracy: 0.9632762738853503


100%|██████████| 938/938 [00:17<00:00, 55.00it/s, accuracy=0.938, loss=0.116, lr=0.00041] 


Epoch 5, Val Loss: 0.10406205562910267, Val Accuracy: 0.9692476114649682


100%|██████████| 938/938 [00:17<00:00, 54.99it/s, accuracy=0.938, loss=0.147, lr=0.000328] 


Epoch 6, Val Loss: 0.10186860299147191, Val Accuracy: 0.9705414012738853


100%|██████████| 938/938 [00:17<00:00, 53.91it/s, accuracy=0.969, loss=0.0553, lr=0.000262]


Epoch 7, Val Loss: 0.09674896704104201, Val Accuracy: 0.9717356687898089


100%|██████████| 938/938 [00:17<00:00, 55.09it/s, accuracy=1, loss=0.0235, lr=0.00021]    


Epoch 8, Val Loss: 0.0943883761606375, Val Accuracy: 0.9717356687898089


100%|██████████| 938/938 [00:16<00:00, 55.45it/s, accuracy=0.969, loss=0.151, lr=0.000168] 


Epoch 9, Val Loss: 0.09376816293350747, Val Accuracy: 0.9728304140127388


100%|██████████| 938/938 [00:17<00:00, 54.56it/s, accuracy=1, loss=0.0453, lr=0.000134]    


Epoch 10, Val Loss: 0.09265094004633724, Val Accuracy: 0.9727308917197452
