# Comparing MLP vs. KAN on a classification problem with more input dimensions (MNIST)

Original KAN example: https://github.com/KindXiaoming/pykan/blob/master/tutorials/Example_3_classfication.ipynb \
EfficientKAN by Blealtan: https://github.com/Blealtan/efficient-kan \
MLP in pytorch Referenced from: https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html \
ConvolutionalKAN by AntonioTepsich: https://github.com/AntonioTepsich/Convolutional-KANs

Test with MNIST dataset 

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

import pickle
import struct
from array import array

from tqdm import tqdm
import time

import gc
import os
from os.path import join
import sys
import math

In [2]:
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '..', 'efficient-kan-master', 'src')))

from efficient_kan import KAN 

In [3]:
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '..', 'Convolutional-KANs-master')))
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '..', 'Convolutional-KANs-master', 'kan_convolutional')))

from kan_convolutional.KANConv import KAN_Convolutional_Layer
from kan_convolutional.KANLinear import KANLinear

In [4]:
# Report whether a CUDA device is visible and, if so, the name of device 0.
cuda_ok = torch.cuda.is_available()
print('Has CUDA:', cuda_ok)
if cuda_ok:
    print('device name:', torch.cuda.get_device_name(0))

Has CUDA: True
device name: NVIDIA GeForce RTX 3060


### Prepare the dataset

##### Read from file

In [5]:
# MNIST Data Loader Class
# https://www.kaggle.com/code/hojjatk/read-mnist-dataset
class MnistDataloader(object):
    """Reader for MNIST-style IDX files (one images file and one labels file per split).

    The constructor only stores the four file paths; nothing is read from disk
    until ``load_data`` or ``read_images_labels`` is called.
    """
    def __init__(self, training_images_filepath,training_labels_filepath,
                 test_images_filepath, test_labels_filepath):
        self.training_images_filepath = training_images_filepath
        self.training_labels_filepath = training_labels_filepath
        self.test_images_filepath = test_images_filepath
        self.test_labels_filepath = test_labels_filepath
    
    def read_images_labels(self, images_filepath, labels_filepath):
        """Parse one images file and its matching labels file.

        Args:
            images_filepath: path of an IDX image file (magic number 2051).
            labels_filepath: path of an IDX label file (magic number 2049).

        Returns:
            (images, labels): ``images`` is a list with one entry per image,
            each entry being a list of ``rows`` 1-D NumPy arrays of length
            ``cols``; ``labels`` is an ``array('B')`` of the label bytes.

        Raises:
            ValueError: if a magic number does not match, or if the image file
                holds fewer pixel bytes than its header announces.
        """
        with open(labels_filepath, 'rb') as file:
            magic, size = struct.unpack(">II", file.read(8))
            if magic != 2049:
                raise ValueError('Magic number mismatch, expected 2049, got {}'.format(magic))
            labels = array("B", file.read())        
        
        with open(images_filepath, 'rb') as file:
            magic, size, rows, cols = struct.unpack(">IIII", file.read(16))
            if magic != 2051:
                raise ValueError('Magic number mismatch, expected 2051, got {}'.format(magic))
            image_data = array("B", file.read())        

        # Fail early with a clear message when the file is truncated.
        expected_bytes = size * rows * cols
        if len(image_data) < expected_bytes:
            raise ValueError('Image file too short, expected {} pixel bytes, got {}'.format(
                expected_bytes, len(image_data)))

        # Use the geometry from the file header rather than a fixed 28x28, and
        # reshape every image in one step instead of a per-image Python loop.
        pixels = np.array(image_data[:expected_bytes]).reshape(size, rows, cols)
        # Keep the historical return structure: a list (per image) of row arrays.
        images = [list(img) for img in pixels]
        
        return images, labels
            
    def load_data(self):
        """Read both splits and return ``(x_train, y_train), (x_test, y_test)``."""
        x_train, y_train = self.read_images_labels(self.training_images_filepath, self.training_labels_filepath)
        x_test, y_test = self.read_images_labels(self.test_images_filepath, self.test_labels_filepath)
        return (x_train, y_train),(x_test, y_test) 

In [6]:
# Locations of the four MNIST IDX files, all resolved under `input_path`
input_path = 'MNIST_dataset'
(
    training_images_filepath,
    training_labels_filepath,
    test_images_filepath,
    test_labels_filepath,
) = (
    join(input_path, relative_name)
    for relative_name in (
        'train-images-idx3-ubyte/train-images-idx3-ubyte',
        'train-labels-idx1-ubyte/train-labels-idx1-ubyte',
        't10k-images-idx3-ubyte/t10k-images-idx3-ubyte',
        't10k-labels-idx1-ubyte/t10k-labels-idx1-ubyte',
    )
)

In [7]:
# Load the MNIST dataset: read the four IDX files into (images, labels) pairs
# for the training split and the test split.
mnist_dataloader = MnistDataloader(training_images_filepath, training_labels_filepath, test_images_filepath, test_labels_filepath)
(x_tr, y_tr), (x_te, y_te) = mnist_dataloader.load_data()

In [8]:
# Convert the loaded images and labels of both splits into NumPy arrays
x_tr, y_tr, x_te, y_te = (np.array(split_part) for split_part in (x_tr, y_tr, x_te, y_te))

##### Data processing

In [9]:
# Normalize: divide both splits by the training-set maximum so that the
# training values lie in [0.0, 1.0]
vmax = x_tr.max()
x_tr, x_te = x_tr / vmax, x_te / vmax

In [10]:
# Flattened layout for the fully connected models: (N, H, W) -> (N, H*W)
x_tr_flattened = x_tr.reshape(x_tr.shape[0], int(np.prod(x_tr.shape[1:])))
x_te_flattened = x_te.reshape(x_te.shape[0], int(np.prod(x_te.shape[1:])))

# Singleton channel axis for the convolutional models: (N, H, W) -> (N, 1, H, W)
x_tr_channeled = np.expand_dims(x_tr, axis=1)
x_te_channeled = np.expand_dims(x_te, axis=1)

In [11]:
# Convert the NumPy arrays to torch tensors. torch.tensor() copies its input,
# so wrapping each array in np.array() first only added a second full copy.
x_train_flattened = torch.tensor(x_tr_flattened)
x_train_channeled = torch.tensor(x_tr_channeled)
y_train = torch.tensor(y_tr)

x_test_flattened  = torch.tensor(x_te_flattened)
x_test_channeled  = torch.tensor(x_te_channeled)
y_test  = torch.tensor(y_te)

In [12]:
# Sanity check: shape of the flattened training tensor (samples, features)
x_train_flattened.shape

torch.Size([60000, 784])

In [13]:
# Sanity check: shape of the channeled training tensor (samples, channel, rows, cols)
x_train_channeled.shape

torch.Size([60000, 1, 28, 28])

In [14]:
# Keep only the leading `ratio_keep` fraction of BOTH the train and test sets
# (1.0 keeps everything)
ratio_keep = 1.0

num_train, num_test = x_tr.shape[0], x_te.shape[0]
num_train_1, num_test_1 = int(num_train * ratio_keep), int(num_test * ratio_keep)

x_train_flattened = x_train_flattened[:num_train_1]
x_train_channeled = x_train_channeled[:num_train_1]
y_train = y_train[:num_train_1]

x_test_flattened = x_test_flattened[:num_test_1]
x_test_channeled = x_test_channeled[:num_test_1]
y_test = y_test[:num_test_1]

##### Create torch dataset objects

In [15]:
from torch.utils.data import TensorDataset, DataLoader

In [16]:
# Mini-batch size used by every DataLoader built in this notebook
batch_size = 256

In [17]:
# Flattened inputs -> datasets/loaders for the fully connected models.
# Inputs are cast to float32; only the training loaders shuffle.
train_set_flattened = TensorDataset(x_train_flattened.float(), y_train)
test_set_flattened = TensorDataset(x_test_flattened.float(), y_test)
train_loader_flattened = DataLoader(train_set_flattened, batch_size=batch_size, shuffle=True)
test_loader_flattened = DataLoader(test_set_flattened, batch_size=batch_size, shuffle=False)

# Channeled inputs -> datasets/loaders for the convolutional models.
train_set_channeled = TensorDataset(x_train_channeled.float(), y_train)
test_set_channeled = TensorDataset(x_test_channeled.float(), y_test)
train_loader_channeled = DataLoader(train_set_channeled, batch_size=batch_size, shuffle=True)
test_loader_channeled = DataLoader(test_set_channeled, batch_size=batch_size, shuffle=False)

In [18]:
# Bundle each input layout's loaders under the keys 'train' and 'test'
dataset_flattened = dict(train=train_loader_flattened, test=test_loader_flattened)
dataset_channeled = dict(train=train_loader_channeled, test=test_loader_channeled)

### Code for classical MLP

In [19]:
class NeuralNetwork(nn.Module):
    '''
    Plain fully connected network with a ReLU after every layer except the last.

    width is the dimension of each layer, like [784, 20, 20, 10]; it needs at
    least two entries (input size and output size), otherwise ValueError is raised.
    '''
    def __init__(self, width):
        super().__init__()
        # A shorter list used to be accepted silently: negative indexing turned
        # [w] into Linear(w, w) instead of reporting the mistake.
        if len(width) < 2:
            raise ValueError(
                'width needs at least an input and an output size, got {}'.format(list(width)))

        self.flatten = nn.Flatten()
        stack = []
        for fan_in, fan_out in zip(width[:-1], width[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        stack.pop()  # the output layer returns raw logits, so drop its ReLU
        
        self.linear_relu_stack = nn.Sequential(*stack)

    def forward(self, x):
        '''Flatten every sample to 1-D and return the logits of the final layer.'''
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

### Code for ConvolutionKAN
copied from: https://github.com/AntonioTepsich/Convolutional-KANs/tree/master/architectures_28x28

In [20]:
class ConvKAN_2conv(nn.Module):
    """Two KAN convolution stages, each followed by 2x2 max pooling, then a
    two-layer linear head (no activation in between) producing 10 outputs.

    ``forward`` returns log-probabilities (``log_softmax`` over dim 1).
    """
    def __init__(self,device: str = 'cuda'):
        super().__init__()
        # Both convolution stages share the same configuration.
        conv_kwargs = dict(n_convs=5, kernel_size=(3, 3), device=device)
        self.conv1 = KAN_Convolutional_Layer(**conv_kwargs)
        self.conv2 = KAN_Convolutional_Layer(**conv_kwargs)

        # One pooling module, applied after each convolution stage.
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2))
        self.flat = nn.Flatten()

        # NOTE(review): 625 must equal the flattened feature count after the
        # second pooling; presumably sized for 1x28x28 inputs -- confirm
        # against KAN_Convolutional_Layer.
        self.linear1 = nn.Linear(625, 64)
        self.linear2 = nn.Linear(64, 10)

    def forward(self, x):
        features = self.pool1(self.conv1(x))
        features = self.pool1(self.conv2(features))
        scores = self.linear2(self.linear1(self.flat(features)))
        return F.log_softmax(scores, dim=1)

In [21]:
class ConvKAN_1conv(nn.Module):
    """One KAN convolution stage with 2x2 max pooling, then a two-layer linear
    head (no activation in between) producing 10 outputs.

    ``forward`` returns log-probabilities (``log_softmax`` over dim 1).
    """
    def __init__(self,device: str = 'cuda'):
        super().__init__()
        self.conv1 = KAN_Convolutional_Layer(n_convs=5, kernel_size=(3, 3), device=device)
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2))
        self.flat = nn.Flatten()

        # NOTE(review): 845 must equal the flattened feature count after
        # pooling; presumably sized for 1x28x28 inputs -- confirm against
        # KAN_Convolutional_Layer.
        self.linear1 = nn.Linear(845, 64)
        self.linear2 = nn.Linear(64, 10)

    def forward(self, x):
        features = self.flat(self.pool1(self.conv1(x)))
        scores = self.linear2(self.linear1(features))
        return F.log_softmax(scores, dim=1)

In [22]:
class KKAN_Convolutional_Network(nn.Module):
    """Two KAN convolution stages, each followed by 2x2 max pooling, with a
    single ``KANLinear`` layer as the 10-output classification head.

    ``forward`` returns log-probabilities (``log_softmax`` over dim 1).
    """
    def __init__(self,device: str = 'cuda'):
        super().__init__()
        # Both convolution stages share the same configuration.
        conv_kwargs = dict(n_convs=5, kernel_size=(3, 3), device=device)
        self.conv1 = KAN_Convolutional_Layer(**conv_kwargs)
        self.conv2 = KAN_Convolutional_Layer(**conv_kwargs)

        # One pooling module, applied after each convolution stage.
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2))
        self.flat = nn.Flatten()

        # NOTE(review): 625 must equal the flattened feature count after the
        # second pooling; presumably sized for 1x28x28 inputs -- confirm
        # against KAN_Convolutional_Layer.
        self.kan1 = KANLinear(
            625, 10,
            grid_size=10, spline_order=3,
            scale_noise=0.01, scale_base=1, scale_spline=1,
            base_activation=nn.SiLU,
            grid_eps=0.02, grid_range=[0,1],
        )

    def forward(self, x):
        features = self.pool1(self.conv1(x))
        features = self.pool1(self.conv2(features))
        scores = self.kan1(self.flat(features))
        return F.log_softmax(scores, dim=1)

In [23]:
# Integer architecture id -> ConvKAN model class (the id is what gets passed as
# `width` when the model type is 'convkan')
convkan_types_map = {0: ConvKAN_1conv, 1: ConvKAN_2conv, 2: KKAN_Convolutional_Network}

### Create and train model

##### utility funcs

In [24]:
# see total number of params
def get_num_params(model):
    """Count the scalar parameters of ``model``.

    Returns:
        (total, trainable): the number of parameter elements overall and the
        number belonging to parameters with ``requires_grad`` set.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count

    return total, trainable

##### New util codes
copied from: https://github.com/AntonioTepsich/Convolutional-KANs/blob/master/experiment_28x28.ipynb

In [25]:
def train_batch_loader(model, device, train_loader, optimizer, epoch, criterion):
    """
    Train the model for one epoch

    Args:
        model: the neural network model (moved to ``device`` before training)
        device: cuda or cpu
        train_loader: iterable of (data, target) batches, e.g. a DataLoader
        optimizer: the optimizer to use (e.g. SGD)
        epoch: the current epoch (not used by the function; kept so existing
            callers keep working)
        criterion: the loss function (e.g. CrossEntropy)

    Returns:
        avg_loss: the mean of the per-batch losses over the epoch

    Raises:
        ValueError: if ``train_loader`` yields no batches
    """

    model.to(device)
    model.train()
    train_loss = 0
    # Count batches explicitly so an empty loader is reported clearly instead
    # of failing on an undefined loop variable.
    num_batches = 0
    # Process the images in batches
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        # Reset the gradients accumulated by the previous step
        optimizer.zero_grad()
        # Push the data forward through the model layers
        output = model(data)
        # Get the loss
        loss = criterion(output, target)
        # Keep a running total
        train_loss += loss.item()
        # Backpropagate and update the weights
        loss.backward()
        optimizer.step()
        num_batches += 1

    if num_batches == 0:
        raise ValueError('train_loader yielded no batches')

    # return average loss for the epoch
    avg_loss = train_loss / num_batches
    return avg_loss

def test_batch_loader(model, device, test_loader, criterion):
    """
    Test the model

    Args:
        model: the neural network model
        device: cuda or cpu
        test_loader: DataLoader for test data
        criterion: the loss function (e.g. CrossEntropy)

    Returns:
        test_loss: the average loss over the test set
        accuracy: the accuracy of the model on the test set
        precision: the precision of the model on the test set
        recall: the recall of the model on the test set
        f1: the f1 score of the model on the test set
    """

    model.eval()
    test_loss = 0
    correct = 0
    all_targets = []
    all_predictions = []
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            
            # Get the predicted classes for this batch
            output = model(data)
            
            # Calculate the loss for this batch
            test_loss += criterion(output, target).item()
            
            # Calculate the accuracy for this batch
            _, predicted = torch.max(output.data, 1)
            correct += (target == predicted).sum().item()

            # Collect all targets and predictions for metric calculations
            all_targets.extend(target.view_as(predicted).cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    # Normalize test loss
    test_loss /= len(test_loader.dataset)
    accuracy = correct / len(test_loader.dataset)

    # print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%), Precision: {:.2f}, Recall: {:.2f}, F1 Score: {:.2f}\n'.format(
    #     test_loss, correct, len(test_loader.dataset), accuracy, precision, recall, f1))

    return test_loss, accuracy

In [26]:
def train_process(model, num_epochs, loss_fn, train_loader, test_loader, optimizer, scheduler=None, device='cuda'):
    """Run ``num_epochs`` rounds of training followed by evaluation.

    Args:
        model: the neural network model
        num_epochs: number of epochs to run
        loss_fn: the loss function, used for both training and evaluation
        train_loader: loader passed to ``train_batch_loader``
        test_loader: loader passed to ``test_batch_loader``
        optimizer: the optimizer that updates ``model``
        scheduler: optional learning-rate scheduler, stepped once per epoch
        device: device used for training and evaluation; defaults to 'cuda',
            the value that used to be hard-coded

    Returns:
        dict with the keys 'train_loss', 'test_loss' and 'test_acc', each a
        list holding one value per epoch
    """
    res = {'train_loss': [], 'test_loss': [], 'test_acc': []}
    for n in tqdm(range(num_epochs)):
        train_loss = train_batch_loader(model, device, train_loader, optimizer, n, loss_fn)
        test_loss, test_acc = test_batch_loader(model, device, test_loader, loss_fn)
        # Step the schedule once per epoch, after that epoch's optimizer updates
        if scheduler is not None:
            scheduler.step()
        
        res['train_loss'].append(train_loss)
        res['test_loss'].append(test_loss)
        res['test_acc'].append(test_acc)
    
    return res

##### Create different types of models and train them
We want to use the same optimization method, batch size, etc. to compare the different architectures fairly \
This might cause slower training

In [27]:
def create_and_train(model_type, width, epochs, lr, batch_size=-1):
    '''
    create, train, then discard a model
    our objective is to figure out the efficacy of different shaped models

    Args:
        model_type: 'mlp', 'kan' or 'convkan'
        width: list of layer sizes for 'mlp' and 'kan'; for 'convkan' it is a
            key of convkan_types_map selecting the architecture
        epochs: number of training epochs
        lr: initial learning rate of the AdamW optimizer
        batch_size: not used; the data loaders are built beforehand with the
            notebook-level batch size

    Returns:
        dict with the keys 'model_type', 'width', 'num_params' (total and
        trainable counts), 'train_result' (per-epoch metrics) and 'train_time'
        (seconds)

    Raises:
        ValueError: if model_type is not one of the supported names
    '''
    dset = dataset_flattened
    
    # create model
    if model_type == 'mlp':
        model = NeuralNetwork(width).to(device='cuda')
    elif model_type == 'kan':
        model = KAN(width, grid_size=3, spline_order=3).to(device='cuda')
    elif model_type == 'convkan':
        # TODO: change to use width
        model = convkan_types_map[width]().to(device='cuda')
        # the convolutional models take the (N, 1, H, W) inputs
        dset = dataset_channeled
    else:
        # without this branch an unknown name failed later on an unbound `model`
        raise ValueError(
            "Unknown model_type {!r}; expected 'mlp', 'kan' or 'convkan'".format(model_type))

    # record num of params
    num_params = get_num_params(model)
    
    # use categorical cross entropy loss
    loss_fn = torch.nn.CrossEntropyLoss()
    # AdamW with exponential learning-rate decay; the caller's `lr` is the
    # initial rate (it used to be ignored in favour of a hard-coded 1e-3)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)
    
    # train as any vanilla pytorch model, timing the whole run
    start = time.time()
    results = train_process(model, epochs, loss_fn, dset['train'], dset['test'], optimizer, scheduler=scheduler)
    end = time.time()

    # free the memory
    model.cpu()
    del model
    gc.collect()
    torch.cuda.empty_cache()

    return {'model_type': model_type, 'width': width, 'num_params': num_params, 'train_result': results, 'train_time': (end - start)}

In [28]:
# Layer-size lists to train for the KAN models: 784 inputs, 10 outputs,
# with zero, one or two hidden layers in between
kan_shapes = [
    [784, 10],
    [784, 16, 10],
    [784, 32, 10],
    [784, 64, 10],
    [784, 128, 10],
    [784, 32, 32, 10],
    [784, 64, 64, 10],
]

# Layer-size lists to train for the MLP models; includes wider variants that
# are not in the KAN list
mlp_shapes = [
    [784, 10],
    [784, 16, 10],
    [784, 32, 10],
    [784, 64, 10],
    [784, 128, 10],
    [784, 256, 10],
    [784, 32, 32, 10],
    [784, 64, 64, 10],
    [784, 128, 128, 10],
    [784, 256, 256, 10],
    [784, 512, 512, 10],
]

# ConvKAN architectures are selected by their integer id in convkan_types_map
convkan_types = convkan_types_map.keys()

In [29]:
# Collects one result dict per trained model; re-running this cell resets it
results = []

In [30]:
# Number of training epochs used for every model
epoch_ct = 40
# epoch_ct = 1  # alternative setting for a quick run

In [31]:
# Train every configuration in turn: KAN shapes first, then the ConvKAN
# architectures, then the MLP shapes. Each entry is
# (banner printed before training, model type, configurations to train).
training_plan = [
    ('Training KAN shape:', 'kan', kan_shapes),
    ('Training ConvKAN type:', 'convkan', convkan_types),
    ('Training MLP shape:', 'mlp', mlp_shapes),
]

for banner, model_kind, configurations in training_plan:
    for shape in configurations:
        print(banner, shape)
        res = create_and_train(model_kind, shape, epoch_ct, lr=1e-3)
        print('train time:', res['train_time'])
        print('------------------------------------------------------------')
        results.append(res)

Training KAN shape: [784, 10]


100%|███████████████████████████████████████████████████████████████████████████████| 40/40 [00:57<00:00,  1.43s/it]


train time: 57.024134397506714
------------------------------------------------------------
Training KAN shape: [784, 16, 10]


100%|███████████████████████████████████████████████████████████████████████████████| 40/40 [01:09<00:00,  1.74s/it]


train time: 69.70530200004578
------------------------------------------------------------
Training KAN shape: [784, 32, 10]


100%|███████████████████████████████████████████████████████████████████████████████| 40/40 [01:09<00:00,  1.74s/it]


train time: 69.58233690261841
------------------------------------------------------------
Training KAN shape: [784, 64, 10]


100%|███████████████████████████████████████████████████████████████████████████████| 40/40 [01:05<00:00,  1.65s/it]


train time: 65.91526436805725
------------------------------------------------------------
Training KAN shape: [784, 128, 10]


100%|███████████████████████████████████████████████████████████████████████████████| 40/40 [01:06<00:00,  1.67s/it]


train time: 66.72944474220276
------------------------------------------------------------
Training KAN shape: [784, 32, 32, 10]


100%|███████████████████████████████████████████████████████████████████████████████| 40/40 [01:23<00:00,  2.09s/it]


train time: 83.45880961418152
------------------------------------------------------------
Training KAN shape: [784, 64, 64, 10]


100%|███████████████████████████████████████████████████████████████████████████████| 40/40 [01:20<00:00,  2.01s/it]


train time: 80.53492832183838
------------------------------------------------------------
Training ConvKAN type: 0


100%|███████████████████████████████████████████████████████████████████████████████| 40/40 [09:58<00:00, 14.97s/it]


train time: 598.6782262325287
------------------------------------------------------------
Training ConvKAN type: 1


100%|███████████████████████████████████████████████████████████████████████████████| 40/40 [33:28<00:00, 50.21s/it]


train time: 2008.5516653060913
------------------------------------------------------------
Training ConvKAN type: 2


100%|███████████████████████████████████████████████████████████████████████████████| 40/40 [34:13<00:00, 51.34s/it]


train time: 2053.6402065753937
------------------------------------------------------------
Training MLP shape: [784, 10]


100%|███████████████████████████████████████████████████████████████████████████████| 40/40 [00:39<00:00,  1.02it/s]


train time: 39.381596326828
------------------------------------------------------------
Training MLP shape: [784, 16, 10]


100%|███████████████████████████████████████████████████████████████████████████████| 40/40 [00:43<00:00,  1.08s/it]


train time: 43.26330351829529
------------------------------------------------------------
Training MLP shape: [784, 32, 10]


100%|███████████████████████████████████████████████████████████████████████████████| 40/40 [00:43<00:00,  1.08s/it]


train time: 43.281919717788696
------------------------------------------------------------
Training MLP shape: [784, 64, 10]


100%|███████████████████████████████████████████████████████████████████████████████| 40/40 [00:44<00:00,  1.11s/it]


train time: 44.23043966293335
------------------------------------------------------------
Training MLP shape: [784, 128, 10]


100%|███████████████████████████████████████████████████████████████████████████████| 40/40 [00:44<00:00,  1.11s/it]


train time: 44.35863900184631
------------------------------------------------------------
Training MLP shape: [784, 256, 10]


100%|███████████████████████████████████████████████████████████████████████████████| 40/40 [00:44<00:00,  1.11s/it]


train time: 44.43997931480408
------------------------------------------------------------
Training MLP shape: [784, 32, 32, 10]


100%|███████████████████████████████████████████████████████████████████████████████| 40/40 [00:44<00:00,  1.11s/it]


train time: 44.26502466201782
------------------------------------------------------------
Training MLP shape: [784, 64, 64, 10]


100%|███████████████████████████████████████████████████████████████████████████████| 40/40 [00:46<00:00,  1.16s/it]


train time: 46.36068105697632
------------------------------------------------------------
Training MLP shape: [784, 128, 128, 10]


100%|███████████████████████████████████████████████████████████████████████████████| 40/40 [00:43<00:00,  1.08s/it]


train time: 43.14708209037781
------------------------------------------------------------
Training MLP shape: [784, 256, 256, 10]


100%|███████████████████████████████████████████████████████████████████████████████| 40/40 [00:46<00:00,  1.17s/it]


train time: 46.928699254989624
------------------------------------------------------------
Training MLP shape: [784, 512, 512, 10]


100%|███████████████████████████████████████████████████████████████████████████████| 40/40 [00:43<00:00,  1.09s/it]

train time: 43.43510174751282
------------------------------------------------------------





In [32]:
# Display the collected results (one dict per trained model)
results

[{'model_type': 'kan',
  'width': [784, 10],
  'num_params': (62720, 62720),
  'train_result': {'train_loss': [0.8547003550732389,
    0.385613499803746,
    0.33371262036739513,
    0.31314211119996743,
    0.3015814148999275,
    0.2943480814390994,
    0.2889997953430135,
    0.2853761023029368,
    0.28301142365374465,
    0.28106454820074933,
    0.2790402493578322,
    0.2779865837477623,
    0.2771090750364547,
    0.27592858621414673,
    0.27546686308181034,
    0.2746155270870696,
    0.2748618239417989,
    0.27429414829041093,
    0.27417522408860795,
    0.27410230826824267,
    0.27352276691730987,
    0.27331564305944644,
    0.2734539895615679,
    0.2732338231294713,
    0.27364757859960515,
    0.273379614188316,
    0.273683895773076,
    0.2732433831438105,
    0.27292971446159037,
    0.2734361114020043,
    0.272896498061241,
    0.27337897071178924,
    0.2733150319850191,
    0.27300069655509707,
    0.27303145591248856,
    0.2733002274594408,
    0.27326074001

In [33]:
# save results to file; the `with` block closes the file even if pickling fails
with open('test_results_MNIST_2.pickle', 'wb') as f:
    pickle.dump(results, f)

### Extras