In [1]:
pip install pykan

Collecting pykan
  Downloading pykan-0.0.5-py3-none-any.whl (33 kB)
Installing collected packages: pykan
Successfully installed pykan-0.0.5


In [2]:
import torch
from torch import nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from kan import *
from google.colab import files

from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive', force_remount=True)




print(f'Using Pytorch version:{torch.__version__}')

if torch.cuda.is_available():
  print(f'Using GPU device: {torch.cuda.get_device_name(0)}')
  device = torch.device('cuda')
else:
  print('No GPU Found, using CPU instead')
  device = torch.device('cpu')

print('Using device:', device)


#### Training Dataset #####

train_dataset = datasets.MNIST(root='.', train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(dataset=train_dataset, batch_size=100, shuffle=True, pin_memory=True)

#### Test dataset ####

test_dataset = datasets.MNIST(root='.', train=False, transform=transforms.ToTensor())
test_loader = DataLoader(dataset=test_dataset, batch_size=100, shuffle=False, pin_memory=True)


Mounted at /content/drive
Using Pytorch version:2.3.0+cu121
No GPU Found, using CPU instead
Using device: cpu
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 16056533.19it/s]


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 539392.52it/s]


Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 4444280.40it/s]


Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 1609677.12it/s]


Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw



In [3]:
##### Implement KAN ######

model = KAN(width=[28*28,10,10], device = device)
model.to(device)


KAN(
  (biases): ModuleList(
    (0-1): 2 x Linear(in_features=10, out_features=1, bias=False)
  )
  (act_fun): ModuleList(
    (0-1): 2 x KANLayer(
      (base_fun): SiLU()
    )
  )
  (base_fun): SiLU()
  (symbolic_fun): ModuleList(
    (0-1): 2 x Symbolic_KANLayer()
  )
)

In [4]:
######## Implement loss function (crossentropyloss) and optimizer (adam)

#### Setting the loss and the optimizer

Loss = nn.CrossEntropyLoss()

GD = torch.optim.Adam(model.parameters(), lr=1e-3)


#### Calculating correctly predicted labels ####

def correctly_predicted(output,target):
  predicted = torch.argmax(output, dim=1)
  correct_ones = (predicted==target).type(torch.float)
  return correct_ones.sum().item()


In [5]:

def train(data_loader,model,Loss,GD, start_epoch):



  num_batches = len(data_loader)
  num_samples = len(data_loader.dataset)

  total_loss = 0
  total_accurate = 0

  checkpoint_path = '/content/drive/MyDrive/mnist_model_checkpoint-10hiddennodes.pth'

  #model.train()

  for batch_idx, (data, target) in enumerate(data_loader):

    if batch_idx % 100 == 0:
            print(f"\n >>> Epoch {start_epoch}, Batch {batch_idx}")

    # Save checkpoint after every 100 batches
    if batch_idx > 0 and batch_idx % 100 == 0:
            print(f"\n Saving checkpoint for epoch {start_epoch}, batch {batch_idx}")
            checkpoint = {
                'epoch': start_epoch,
                'batch_idx': batch_idx,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': GD.state_dict(),
                'train_loss': total_loss / (batch_idx + 1),  # Average loss up to current batch
                'accuracy': total_accurate / ((batch_idx + 1) * data.size(0))  # Average accuracy up to current batch
            }
            torch.save(checkpoint, checkpoint_path)
            #files.download(checkpoint_path)
            print(f"Checkpoint saved and downloaded on local for epoch {start_epoch}, batch {batch_idx}")

    ## copy to device
    data = data.view(-1, 28 * 28).to(device)
    target = target.to(device)

    #print(f"Length of data = {len(data)} and that of target is {len(target)}")

    ### forward pass
    output = model(data)
    #print(f"Length of output = {len(output)}")

    ### Calculate loss ###
    #output_singular = torch.argmax(output, dim=1)
    #target_float = target.float()
    batch_loss = Loss(output,target)
    total_loss += batch_loss.item()

    ### Count correctly predicted labels###
    correctly_predicted_count = correctly_predicted(output,target)
    total_accurate += correctly_predicted_count

    ### backward propagation
    GD.zero_grad()
    batch_loss.backward()
    GD.step()




  train_loss = total_loss/num_batches
  accuracy = total_accurate/num_samples
  print(f"Average loss: {train_loss:4f}, accuracy: {accuracy:.2%}")




In [6]:

epochs = 10
for i in range(epochs):
  print(f"Traning Start: EPOCH number: {i+1}")
  train(train_loader, model, Loss, GD, i+1)


Traning Start: EPOCH number: 1

 >>> Epoch 1, Batch 0

 >>> Epoch 1, Batch 100

 Saving checkpoint for epoch 1, batch 100
Checkpoint saved and downloaded on local for epoch 1, batch 100

 >>> Epoch 1, Batch 200

 Saving checkpoint for epoch 1, batch 200
Checkpoint saved and downloaded on local for epoch 1, batch 200

 >>> Epoch 1, Batch 300

 Saving checkpoint for epoch 1, batch 300
Checkpoint saved and downloaded on local for epoch 1, batch 300

 >>> Epoch 1, Batch 400

 Saving checkpoint for epoch 1, batch 400
Checkpoint saved and downloaded on local for epoch 1, batch 400

 >>> Epoch 1, Batch 500

 Saving checkpoint for epoch 1, batch 500
Checkpoint saved and downloaded on local for epoch 1, batch 500
Average loss: 1.318138, accuracy: 64.73%
Traning Start: EPOCH number: 2

 >>> Epoch 2, Batch 0

 >>> Epoch 2, Batch 100

 Saving checkpoint for epoch 2, batch 100
Checkpoint saved and downloaded on local for epoch 2, batch 100

 >>> Epoch 2, Batch 200

 Saving checkpoint for epoch 2, b

In [2]:
### Testing #####

def test(data_loader,model,Loss,GD):

  #model.eval()

  num_batches_test = len(data_loader)
  num_samples_test = len(data_loader.dataset)

  total_loss_test = 0
  total_accurate_test = 0



  for batch_idx, (data, target) in enumerate(data_loader):

    print(f"Batch ID = {batch_idx}")
    ## copy to device
    data = data.view(-1, 28 * 28).to(device)
    target = target.to(device)



    ### forward pass
    output = model(data)
    #print(f"Length of output = {len(output)}")

    ### Calculate loss ###
    batch_loss = Loss(output,target)
    total_loss_test += batch_loss

    ### Count correctly predicted labels###
    correctly_predicted_count = correctly_predicted(output,target)
    total_accurate_test += correctly_predicted_count


  test_loss = total_loss_test/num_batches_test
  accuracy_test = total_accurate_test/num_samples_test
  print(f"Average loss: {test_loss:4f}, accuracy: {accuracy_test:.2%}")


test(test_loader,model,Loss,GD)

NameError: name 'test_loader' is not defined