In [None]:
pip install convkan

In [4]:
import torch
from torch import nn
from torchvision import datasets, transforms
from tqdm import tqdm
from convkan import ConvKAN, LayerNorm2D

# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the model: three ConvKAN layers (1 -> 32 -> 32 -> 10 channels, the
# last two with stride 2) with LayerNorm2D after the first two. Adaptive
# average pooling to 1x1 followed by Flatten collapses the spatial dimensions,
# leaving one score per output channel for each sample.
model = nn.Sequential(
    ConvKAN(1, 32, padding=1, kernel_size=3, stride=1),
    LayerNorm2D(32),
    ConvKAN(32, 32, padding=1, kernel_size=3, stride=2),
    LayerNorm2D(32),
    ConvKAN(32, 10, padding=1, kernel_size=3, stride=2),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
).to(device)

# Preprocessing pipeline: convert each sample to a tensor, then normalise it
# with mean 0.1307 and standard deviation 0.3081.
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(mean=(0.1307,), std=(0.3081,))
transform = transforms.Compose([to_tensor, normalize])

# MNIST training split, downloaded into ./data when it is not already there.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 64 samples.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Send each batch to whichever device the model's parameters live on, instead
# of hard-coding CUDA (which raises on CPU-only machines).
device = next(model.parameters()).device

# Training loop: a single pass over the training set, one optimizer step per batch.
model.train()
for epoch in range(1):
    for data, target in tqdm(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()  # clear gradients left over from the previous batch
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()


 19%|█▊        | 174/938 [00:11<00:48, 15.64it/s]


KeyboardInterrupt: 

In [3]:
# Load the test dataset (reuses the `transform` defined alongside the training data)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

# Evaluate on whichever device the model's parameters live on, instead of
# hard-coding CUDA (which raises on CPU-only machines).
device = next(model.parameters()).device

# Switch to evaluation mode
model.eval()

# Running counts of correct predictions and of samples seen
correct = 0
total = 0

# Disable gradient computation for testing
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        # The model ends in Flatten with no softmax, so `output` holds raw
        # class scores; the predicted class is the index of the largest one.
        predicted = output.argmax(dim=1)
        correct += (predicted == target).sum().item()
        total += target.size(0)

# Calculate accuracy as a percentage of the test set
accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")


Test Accuracy: 20.66%
