In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from efficientkan import KAN

from time import time
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
# Load the MNIST train / validation splits and wrap each one in a DataLoader.

# ToTensor converts images to tensors scaled to [0, 1]; Normalize then applies
# (x - 0.5) / 0.5, which maps the pixel values to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# Same dataset arguments for both splits; only the `train` flag differs.
trainset, valset = (
    torchvision.datasets.MNIST(root="./data", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Shuffle only the training split; validation order stays fixed.
trainloader = DataLoader(dataset=trainset, batch_size=64, shuffle=True)
valloader = DataLoader(dataset=valset, batch_size=64, shuffle=False)

In [3]:
class MyMLP(nn.Module):
    """Fully connected ReLU network whose layer widths are given positionally.

    ``MyMLP(784, 64, 10)`` builds ``Linear(784, 64) -> ReLU -> Linear(64, 10)``.
    Any number of widths (at least two) is accepted: a ``Linear`` layer is
    created for every consecutive pair, with a ``ReLU`` between consecutive
    ``Linear`` layers and no activation after the last one.

    Attributes:
        layers: The layer widths as a list, in the order they were passed.
        mlp: The ``nn.Sequential`` stack that ``forward`` applies.

    Raises:
        ValueError: If fewer than two widths are supplied.
    """

    def __init__(self, *args: int):
        super().__init__()

        if len(args) < 2:
            raise ValueError(
                f"MyMLP needs at least an input and an output width, got {len(args)} value(s)"
            )

        self.layers = list(args)

        # Build Linear layers for every consecutive pair of widths so that the
        # network always matches self.layers, whatever the requested depth.
        modules = []
        for index, (fan_in, fan_out) in enumerate(zip(self.layers, self.layers[1:])):
            if index:
                # Non-linearity goes between Linear layers, never after the last.
                modules.append(nn.ReLU())
            modules.append(nn.Linear(fan_in, fan_out))
        self.mlp = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        return self.mlp(x)
    


class MyKAN(nn.Module):
    """Wrapper module that builds a ``KAN`` from positional layer widths.

    Attributes:
        layers: The widths as a list, in the order they were passed.
        kan: The ``KAN`` instance constructed from ``layers``.
    """

    def __init__(self, *args: int):
        super().__init__()

        # Keep the widths on the module so callers can read them back.
        self.layers = [*args]
        self.kan = KAN(self.layers)

    def forward(self, x):
        """Run ``x`` through the wrapped KAN and return its output."""
        out = self.kan(x)
        return out

In [4]:
# Build both classifiers with the same layer widths and create their training objects.

# Prefer the GPU when one is available; to force CPU, use torch.device("cpu") here.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 28x28 input pixels -> 64 hidden units -> 10 digit classes, shared by both models.
layer_sizes = (28 * 28, 64, 10)
model_mlp = MyMLP(*layer_sizes).to(device)
model_kan = MyKAN(*layer_sizes).to(device)

num_epochs = 10
learning_rate = 0.001
criterion = nn.CrossEntropyLoss()

# One Adam optimizer per model, both using the same learning rate.
optimizer_mlp, optimizer_kan = (
    torch.optim.Adam(net.parameters(), lr=learning_rate)
    for net in (model_mlp, model_kan)
)

[784, 64, 10]


In [5]:
def _batch_mean(total, num_batches):
    """Average a per-batch metric sum; NaN when no batch was processed."""
    return total / num_batches if num_batches else float("nan")


def train_and_evaluate(model, trainloader, valloader, device, num_epochs, optimizer, criterion):
    """Train ``model`` for ``num_epochs`` epochs, validating after every epoch.

    Each batch of images is flattened to ``(-1, 28 * 28)`` before it is fed to
    the model. Per-epoch metrics are the mean of the per-batch values. A
    summary line is printed after each training pass and each validation pass.

    Args:
        model: Module to train; must expose a ``layers`` attribute (printed in
            the header).
        trainloader: Iterable of ``(images, labels)`` training batches.
        valloader: Iterable of ``(images, labels)`` validation batches.
        device: Device (string or ``torch.device``) that batches are moved to.
        num_epochs: Number of passes over ``trainloader``.
        optimizer: Optimizer updating ``model``'s parameters; the learning
            rate of its first parameter group is printed in the header.
        criterion: Loss called as ``criterion(output, labels)``.

    Returns:
        ``(losses, accs)``: two lists holding the mean training loss and mean
        training accuracy of each epoch. Validation metrics are only printed.
    """
    lr = optimizer.param_groups[0]['lr']

    # Normalise to a torch.device first: a torch.device object does not compare
    # equal to the string 'cpu', so a plain `device == 'cpu'` test would try to
    # query CUDA even on CPU-only machines.
    resolved_device = torch.device(device)
    if resolved_device.type == 'cuda':
        print(f'Training model "{model.__class__.__name__}" with device: {device} ({torch.cuda.get_device_name(resolved_device)}) and parameters:')
    else:
        print(f'Training model "{model.__class__.__name__}" with device: {device} and parameters:')
    print(f'\tLayers: {model.layers}\n\tLearning rate: {lr}\n\tLoss function: {criterion}\n\tOptimizer: {optimizer.__class__.__name__}\n')

    losses = []
    accs = []

    for epoch in range(num_epochs):

        # Train
        model.train()
        train_loss = 0.0
        train_acc = 0.0
        train_batches = 0

        start = time()

        with tqdm(trainloader) as pbar:
            for images, labels in pbar:
                images = images.view(-1, 28 * 28).to(device)
                labels = labels.to(device)  # moved once, reused for loss and accuracy

                optimizer.zero_grad()
                output = model(images)

                loss = criterion(output, labels)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                train_acc += (output.argmax(dim=1) == labels).float().mean().item()
                train_batches += 1

        # Divide by the number of batches seen, not by the last batch index.
        epoch_loss = _batch_mean(train_loss, train_batches)
        epoch_acc = _batch_mean(train_acc, train_batches)
        losses.append(epoch_loss)
        accs.append(epoch_acc)

        print(f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {epoch_loss:.4f}, Train Accuracy: {epoch_acc:.4f}, Time: {time()-start:.2f}s")

        # Validation
        model.eval()
        val_loss = 0.0
        val_acc = 0.0
        val_batches = 0

        with torch.no_grad():
            for images, labels in valloader:
                images = images.view(-1, 28 * 28).to(device)
                labels = labels.to(device)
                output = model(images)

                val_loss += criterion(output, labels).item()
                val_acc += (output.argmax(dim=1) == labels).float().mean().item()
                val_batches += 1

        print(f"\t      Valid Loss: {_batch_mean(val_loss, val_batches):.4f}, Valid Accuracy: {_batch_mean(val_acc, val_batches):.4f}\n")

    return losses, accs

In [None]:
# Train and evaluate the MLP; keep its per-epoch training loss and accuracy.
losses_mlp, accs_mlp = train_and_evaluate(
    model=model_mlp,
    trainloader=trainloader,
    valloader=valloader,
    device=device,
    num_epochs=num_epochs,
    optimizer=optimizer_mlp,
    criterion=criterion,
)

In [None]:
# Train and evaluate the KAN; keep its per-epoch training loss and accuracy.
losses_kan, accs_kan = train_and_evaluate(
    model=model_kan,
    trainloader=trainloader,
    valloader=valloader,
    device=device,
    num_epochs=num_epochs,
    optimizer=optimizer_kan,
    criterion=criterion,
)

#### Training cycle

*Plotting testing results*