In [45]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from tqdm import tqdm
from kan_convolutional.KANConv import KAN_Convolutional_Layer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [52]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# Train set. Here we sort the MNIST by digits and disable data shuffling
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
sorted_indices = sorted(range(len(train_dataset)//10), key=lambda idx: train_dataset.targets[idx])
# sorted_indices = range(len(train_dataset)//10)
train_dataset = torch.utils.data.Subset(train_dataset, sorted_indices)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False)
# Test set
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [51]:
def train(model, checkpoint, epochs=5):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-6)
    # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.996)
    for epoch in range(epochs):
        model.train()
        with tqdm(train_loader) as pbar:
            for i, (images, labels) in enumerate(pbar):
                optimizer.zero_grad()
                output = model(images.to(device))
                loss = criterion(output, labels.to(device))
                loss.backward()
                optimizer.step()
                accuracy = (output.argmax(dim=1) == labels.to(device)).float().mean()
                pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item(), lr=optimizer.param_groups[0]['lr'])
                # scheduler.step()
        print(f'Epoch {epoch + 1}, Loss: {loss.item()}')
        torch.save(model.state_dict(), checkpoint)

In [49]:
def validate(model):
    model.eval()
    vals=[0]*10
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in test_loader:
            output = model(images.to(device))
            for out in output.argmax(dim=1):
              vals[out.item()]+=1
            val_accuracy += (output.argmax(dim=1) == labels.to(device)).float().mean().item()
    val_accuracy /= len(test_loader)
    print(vals)
    print(f"Accuracy: {val_accuracy}")

In [35]:
class CKAN_BN(nn.Module):
    def __init__(self,device: str = 'cpu'):
        super().__init__()
        self.conv1 = KAN_Convolutional_Layer(
            n_convs = 5,
            kernel_size= (3,3),
            device = device
        )
        self.bn1 = nn.BatchNorm2d(5)

        self.conv2 = KAN_Convolutional_Layer(
            n_convs = 5,
            kernel_size = (3,3),
            device = device
        )
        self.bn2 = nn.BatchNorm2d(25)

        self.pool1 = nn.MaxPool2d(
            kernel_size=(2, 2)
        )
        
        self.flat = nn.Flatten() 
        
        self.linear1 = nn.Linear(625, 256)
        self.linear2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.bn2(x)
       
        x = self.pool1(x)
        x = self.flat(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x

In [53]:
ckan_model = CKAN_BN(device=device).to(device)
train(ckan_model,'checkpoint/ckan_mnist_no_shuffle2.pth')

100%|██████████| 94/94 [00:16<00:00,  5.78it/s, accuracy=0.375, loss=1.94, lr=1e-6] 


Epoch 1, Loss: 1.9429854154586792


100%|██████████| 94/94 [00:15<00:00,  5.91it/s, accuracy=0.417, loss=1.94, lr=1e-6] 


Epoch 2, Loss: 1.9387669563293457


100%|██████████| 94/94 [00:16<00:00,  5.78it/s, accuracy=0.438, loss=1.93, lr=1e-6] 


Epoch 3, Loss: 1.934876561164856


 23%|██▎       | 22/94 [00:03<00:12,  5.65it/s, accuracy=0, loss=2.8, lr=1e-6] 


KeyboardInterrupt: 

In [54]:
model = CKAN_BN(device=device).to(device)
model.load_state_dict(torch.load('checkpoint/ckan_mnist_no_shuffle2.pth'))
validate(model)

[20, 0, 2, 2724, 1357, 900, 37, 100, 2833, 2027]
Accuracy: 0.15764331210191082
