### **Imports and Config**

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import time
from torchvision import datasets
from torchvision.transforms import v2
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, cohen_kappa_score, accuracy_score, classification_report, confusion_matrix
from dataclasses import dataclass
from fvcore.nn import FlopCountAnalysis
import wandb
from soap import SOAP

%matplotlib inline

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@dataclass
class Config:
    batch_size: int = 128
    learning_rate: float = 1.73e-3
    epochs: int = 50
    patience: int = 10
    num_workers: int = 10
    rotation: int = 5
    translation: float = 0.1
    shear_angle: int = 1

# Initialize wandb
run = wandb.init(
    # Set the project where this run will be logged
    project="mnist",
    # Track hyperparameters and run metadata
    config={
        "batch_size": Config.batch_size,
        "learning_rate": Config.learning_rate,
        "epochs": Config.epochs,
        "patience": Config.patience,
        "num_workers": Config.num_workers,
        "rotation": Config.rotation,
        "translation": Config.translation,
        "shear_angle": Config.shear_angle
    },
)

0,1
epoch,▁▃▅▆█
test_acc,▁▃▅▅█
test_loss,█▆▄▄▁
train_acc,▁▆▇██
train_loss,█▃▂▁▁

0,1
epoch,4.0
test_acc,0.9705
test_loss,0.0876
train_acc,0.95097
train_loss,0.15743


cat: /sys/module/amdgpu/initstate: No such file or directory
ERROR:root:Driver not initialized (amdgpu not found in modules)


### **Network Definition**

In [7]:
def vieta_pell(n, x):
    if n == 0:
        return 2 * torch.ones_like(x)
    elif n == 1:
        return x
    else:
        return x * vieta_pell(n - 1, x) + vieta_pell(n - 2, x)

class VietaPellKANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, degree):
        super(VietaPellKANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.degree = degree
        self.vp_coeffs = nn.Parameter(torch.empty(input_dim, output_dim, degree + 1))
        nn.init.normal_(self.vp_coeffs, mean=0.0, std=1 / (input_dim * (degree + 1)))

    def forward(self, x):
        # Normalize x to [-1, 1] using tanh
        x = torch.tanh(x)

        # Compute the Vieta-Pell basis functions
        vp_basis = []
        for n in range(self.degree + 1):
            vp_basis.append(vieta_pell(n, x))
        vp_basis = torch.stack(vp_basis, dim=-1)  # shape = (batch_size, input_dim, degree + 1)

        # Compute the Vieta-Pell interpolation
        y = torch.einsum("bid,iod->bo", vp_basis, self.vp_coeffs)  # shape = (batch_size, output_dim)
        y = y.view(-1, self.output_dim)

        return y

class MNISTVietaPellKAN(nn.Module):
    def __init__(self):
        super(MNISTVietaPellKAN, self).__init__()
        self.trigkan1 = VietaPellKANLayer(784, 32, 3)
        self.bn1 = nn.LayerNorm(32)
        self.trigkan2 = VietaPellKANLayer(32, 32, 3)
        self.bn2 = nn.LayerNorm(32)
        self.trigkan3 = VietaPellKANLayer(32, 10, 3)

    def forward(self, x):
        x = x.view(-1, 28*28)
        #x=x.tanh()
        x = self.trigkan1(x)
        x = self.bn1(x)
        x = self.trigkan2(x)
        x = self.bn2(x)
        x = self.trigkan3(x)
        return x

### **Training and Testing Defintion**

In [8]:
transform_train = v2.Compose([
    v2.ToImage(),
    v2.RandomAffine(degrees=Config.rotation, translate=(Config.translation, Config.translation), shear=Config.shear_angle),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize((0.1307,), (0.3081,))
])

transform_test = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform_test)

train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True, num_workers=Config.num_workers)
test_loader = DataLoader(test_dataset, batch_size=Config.batch_size, shuffle=False, num_workers=Config.num_workers)

num_classes = 10


criterion = nn.CrossEntropyLoss()



def train(model, train_loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    progress_bar = tqdm(train_loader, desc="Training")
    for data, target in progress_bar:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()
        
        progress_bar.set_postfix({'Loss': total_loss / (progress_bar.n + 1), 'Acc': 100. * correct / total})
    
    return total_loss / len(train_loader), correct / total

def validate(model, test_loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            total_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

    return total_loss / len(test_loader), correct / total

### **Model Set-up**

In [9]:
Model_Name='VietaPell'  #Add names of other models
model0 = MNISTVietaPellKAN().to(device)
model=model0
total_params = sum(p.numel() for p in model0.parameters() if p.requires_grad)
flops = FlopCountAnalysis(model, inputs=(torch.randn(1, 28 * 28).to(device),)).total()
print(f"Total trainable parameters of {Model_Name}: {total_params}")
print(f"FLOPs of {Model_Name}: {flops}")

Unsupported operator aten::tanh encountered 3 time(s)
Unsupported operator aten::ones_like encountered 9 time(s)
Unsupported operator aten::mul encountered 18 time(s)
Unsupported operator aten::add encountered 9 time(s)


Total trainable parameters of VietaPell: 105856
FLOPs of VietaPell: 106046.0


### **Training of Model**

In [10]:
def train_and_validate(model, train_loader, test_loader, criterion, optimizer, device, epochs, patience):
    best_test_loss = float('inf')
    best_weights = None
    no_improve = 0
    
    for epoch in range(epochs):
        train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
        test_loss, test_acc = validate(model, test_loader, criterion, device)

        # Log metrics to wandb
        wandb.log({
            "epoch": epoch,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "test_loss": test_loss,
            "test_acc": test_acc
        }, step=epoch+1)
        
        print(f'Epoch {epoch+1}/{epochs}:')
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
        print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')
        
        if test_loss < best_test_loss:
            best_test_loss = test_loss
            best_weights = model.state_dict()
            no_improve = 0
        else:
            no_improve += 1
            if no_improve == patience:
                print(f'Early stopping after {epoch+1} epochs')
                break
    
    return best_weights, best_test_loss


optimizers = SOAP(model.parameters(), lr=Config.learning_rate)

best_weights, model_times = train_and_validate(model, train_loader, test_loader, criterion, optimizers, device, Config.epochs, Config.patience)

wandb.finish()

# Save the best weights for model
model.load_state_dict(best_weights)
torch.save(model.state_dict(), f'{Model_Name}_best_weights.pth')

# Print the processing time for model
print(f"{Model_Name} processing time: {model_times:.2f} seconds")

Training: 100%|██████████| 469/469 [00:02<00:00, 169.69it/s, Loss=0.576, Acc=83.3]


Epoch 1/50:
Train Loss: 0.5567, Train Acc: 0.8335
Test Loss: 0.1795, Test Acc: 0.9441


Training: 100%|██████████| 469/469 [00:02<00:00, 176.43it/s, Loss=0.26, Acc=92.2] 


Epoch 2/50:
Train Loss: 0.2477, Train Acc: 0.9223
Test Loss: 0.1419, Test Acc: 0.9544


Training: 100%|██████████| 469/469 [00:02<00:00, 169.87it/s, Loss=0.206, Acc=93.8]


Epoch 3/50:
Train Loss: 0.2012, Train Acc: 0.9377
Test Loss: 0.1300, Test Acc: 0.9580


Training: 100%|██████████| 469/469 [00:02<00:00, 169.68it/s, Loss=0.194, Acc=94.2]


Epoch 4/50:
Train Loss: 0.1862, Train Acc: 0.9419
Test Loss: 0.1059, Test Acc: 0.9675


Training: 100%|██████████| 469/469 [00:02<00:00, 172.37it/s, Loss=0.171, Acc=94.6]


Epoch 5/50:
Train Loss: 0.1703, Train Acc: 0.9457
Test Loss: 0.1035, Test Acc: 0.9668


Training: 100%|██████████| 469/469 [00:02<00:00, 177.16it/s, Loss=0.166, Acc=95]  


Epoch 6/50:
Train Loss: 0.1593, Train Acc: 0.9502
Test Loss: 0.0975, Test Acc: 0.9691


Training: 100%|██████████| 469/469 [00:02<00:00, 175.66it/s, Loss=0.149, Acc=95.3]


Epoch 7/50:
Train Loss: 0.1491, Train Acc: 0.9534
Test Loss: 0.0963, Test Acc: 0.9685


Training: 100%|██████████| 469/469 [00:02<00:00, 176.84it/s, Loss=0.147, Acc=95.7]


Epoch 8/50:
Train Loss: 0.1412, Train Acc: 0.9567
Test Loss: 0.0876, Test Acc: 0.9720


Training: 100%|██████████| 469/469 [00:02<00:00, 175.84it/s, Loss=0.139, Acc=95.8]


Epoch 9/50:
Train Loss: 0.1341, Train Acc: 0.9576
Test Loss: 0.0971, Test Acc: 0.9676


Training: 100%|██████████| 469/469 [00:02<00:00, 179.92it/s, Loss=0.136, Acc=95.8]


Epoch 10/50:
Train Loss: 0.1337, Train Acc: 0.9583
Test Loss: 0.0842, Test Acc: 0.9743


Training: 100%|██████████| 469/469 [00:02<00:00, 177.19it/s, Loss=0.134, Acc=96]  


Epoch 11/50:
Train Loss: 0.1290, Train Acc: 0.9599
Test Loss: 0.0793, Test Acc: 0.9743


Training: 100%|██████████| 469/469 [00:02<00:00, 177.62it/s, Loss=0.128, Acc=96.1]


Epoch 12/50:
Train Loss: 0.1232, Train Acc: 0.9609
Test Loss: 0.0781, Test Acc: 0.9746


Training: 100%|██████████| 469/469 [00:02<00:00, 173.28it/s, Loss=0.125, Acc=96.2]


Epoch 13/50:
Train Loss: 0.1227, Train Acc: 0.9617
Test Loss: 0.0790, Test Acc: 0.9733


Training: 100%|██████████| 469/469 [00:02<00:00, 175.34it/s, Loss=0.119, Acc=96.3]


Epoch 14/50:
Train Loss: 0.1179, Train Acc: 0.9628
Test Loss: 0.0722, Test Acc: 0.9762


Training: 100%|██████████| 469/469 [00:02<00:00, 176.55it/s, Loss=0.123, Acc=96.3]


Epoch 15/50:
Train Loss: 0.1182, Train Acc: 0.9635
Test Loss: 0.0741, Test Acc: 0.9749


Training: 100%|██████████| 469/469 [00:02<00:00, 176.21it/s, Loss=0.119, Acc=96.4]


Epoch 16/50:
Train Loss: 0.1140, Train Acc: 0.9639
Test Loss: 0.0713, Test Acc: 0.9771


Training: 100%|██████████| 469/469 [00:02<00:00, 174.72it/s, Loss=0.114, Acc=96.6]


Epoch 17/50:
Train Loss: 0.1107, Train Acc: 0.9657
Test Loss: 0.0720, Test Acc: 0.9760


Training: 100%|██████████| 469/469 [00:02<00:00, 175.56it/s, Loss=0.117, Acc=96.5]


Epoch 18/50:
Train Loss: 0.1119, Train Acc: 0.9650
Test Loss: 0.0776, Test Acc: 0.9746


Training: 100%|██████████| 469/469 [00:02<00:00, 176.13it/s, Loss=0.108, Acc=96.5]


Epoch 19/50:
Train Loss: 0.1076, Train Acc: 0.9652
Test Loss: 0.0716, Test Acc: 0.9763


Training: 100%|██████████| 469/469 [00:02<00:00, 175.68it/s, Loss=0.106, Acc=96.7]


Epoch 20/50:
Train Loss: 0.1056, Train Acc: 0.9673
Test Loss: 0.0734, Test Acc: 0.9769


Training: 100%|██████████| 469/469 [00:02<00:00, 173.44it/s, Loss=0.108, Acc=96.7]


Epoch 21/50:
Train Loss: 0.1056, Train Acc: 0.9672
Test Loss: 0.0665, Test Acc: 0.9789


Training: 100%|██████████| 469/469 [00:02<00:00, 174.78it/s, Loss=0.106, Acc=96.8]


Epoch 22/50:
Train Loss: 0.1049, Train Acc: 0.9678
Test Loss: 0.0653, Test Acc: 0.9787


Training: 100%|██████████| 469/469 [00:02<00:00, 173.97it/s, Loss=0.103, Acc=96.8]


Epoch 23/50:
Train Loss: 0.1005, Train Acc: 0.9680
Test Loss: 0.0703, Test Acc: 0.9771


Training: 100%|██████████| 469/469 [00:02<00:00, 174.79it/s, Loss=0.103, Acc=96.8] 


Epoch 24/50:
Train Loss: 0.1018, Train Acc: 0.9679
Test Loss: 0.0637, Test Acc: 0.9799


Training: 100%|██████████| 469/469 [00:02<00:00, 176.44it/s, Loss=0.101, Acc=96.8] 


Epoch 25/50:
Train Loss: 0.1008, Train Acc: 0.9682
Test Loss: 0.0693, Test Acc: 0.9786


Training: 100%|██████████| 469/469 [00:02<00:00, 172.70it/s, Loss=0.1, Acc=96.8]   


Epoch 26/50:
Train Loss: 0.0994, Train Acc: 0.9684
Test Loss: 0.0678, Test Acc: 0.9791


Training: 100%|██████████| 469/469 [00:02<00:00, 174.21it/s, Loss=0.0984, Acc=97] 


Epoch 27/50:
Train Loss: 0.0980, Train Acc: 0.9700
Test Loss: 0.0671, Test Acc: 0.9782


Training: 100%|██████████| 469/469 [00:02<00:00, 175.16it/s, Loss=0.0977, Acc=96.9]


Epoch 28/50:
Train Loss: 0.0977, Train Acc: 0.9692
Test Loss: 0.0674, Test Acc: 0.9775


Training: 100%|██████████| 469/469 [00:02<00:00, 172.72it/s, Loss=0.099, Acc=97]   


Epoch 29/50:
Train Loss: 0.0975, Train Acc: 0.9699
Test Loss: 0.0670, Test Acc: 0.9783


Training: 100%|██████████| 469/469 [00:02<00:00, 176.73it/s, Loss=0.101, Acc=97]   


Epoch 30/50:
Train Loss: 0.0963, Train Acc: 0.9705
Test Loss: 0.0591, Test Acc: 0.9800


Training: 100%|██████████| 469/469 [00:02<00:00, 175.58it/s, Loss=0.0983, Acc=97]  


Epoch 31/50:
Train Loss: 0.0939, Train Acc: 0.9705
Test Loss: 0.0655, Test Acc: 0.9782


Training: 100%|██████████| 469/469 [00:02<00:00, 174.96it/s, Loss=0.0946, Acc=97]  


Epoch 32/50:
Train Loss: 0.0944, Train Acc: 0.9697
Test Loss: 0.0658, Test Acc: 0.9789


Training: 100%|██████████| 469/469 [00:03<00:00, 156.27it/s, Loss=0.0988, Acc=97]  


Epoch 33/50:
Train Loss: 0.0939, Train Acc: 0.9699
Test Loss: 0.0647, Test Acc: 0.9791


Training: 100%|██████████| 469/469 [00:02<00:00, 173.86it/s, Loss=0.0962, Acc=97]  


Epoch 34/50:
Train Loss: 0.0941, Train Acc: 0.9702
Test Loss: 0.0607, Test Acc: 0.9809


Training: 100%|██████████| 469/469 [00:02<00:00, 175.01it/s, Loss=0.0926, Acc=97.1]


Epoch 35/50:
Train Loss: 0.0920, Train Acc: 0.9713
Test Loss: 0.0600, Test Acc: 0.9801


Training: 100%|██████████| 469/469 [00:02<00:00, 174.53it/s, Loss=0.0929, Acc=97.1]


Epoch 36/50:
Train Loss: 0.0923, Train Acc: 0.9707
Test Loss: 0.0606, Test Acc: 0.9792


Training: 100%|██████████| 469/469 [00:02<00:00, 175.79it/s, Loss=0.0905, Acc=97.2]


Epoch 37/50:
Train Loss: 0.0903, Train Acc: 0.9718
Test Loss: 0.0640, Test Acc: 0.9790


Training: 100%|██████████| 469/469 [00:02<00:00, 174.64it/s, Loss=0.092, Acc=97.2] 


Epoch 38/50:
Train Loss: 0.0908, Train Acc: 0.9719
Test Loss: 0.0661, Test Acc: 0.9791


Training: 100%|██████████| 469/469 [00:02<00:00, 175.15it/s, Loss=0.0884, Acc=97.2]


Epoch 39/50:
Train Loss: 0.0880, Train Acc: 0.9720
Test Loss: 0.0617, Test Acc: 0.9803


Training: 100%|██████████| 469/469 [00:02<00:00, 176.53it/s, Loss=0.0906, Acc=97.2]


Epoch 40/50:
Train Loss: 0.0906, Train Acc: 0.9718
Test Loss: 0.0612, Test Acc: 0.9792
Early stopping after 40 epochs


0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
test_acc,▁▃▄▅▅▆▆▆▅▇▇▇▇▇▇▇▇▇▇▇██▇███▇▇██▇█████████
test_loss,█▆▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_acc,▁▅▆▆▇▇▇▇▇▇▇▇▇███████████████████████████
train_loss,█▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,39.0
test_acc,0.9792
test_loss,0.06119
train_acc,0.97185
train_loss,0.09059


VietaPell processing time: 0.06 seconds


### **Model Evaluation**

In [11]:
def evaluate_model(model, test_loader, device):
    model.eval()
    all_preds = []
    all_targets = []
    
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            all_preds.extend(pred.cpu().numpy())
            all_targets.extend(target.numpy())
    
    # Calculate metrics
    accuracy = accuracy_score(all_targets, all_preds)
    f1 = f1_score(all_targets, all_preds, average='weighted')
    kappa = cohen_kappa_score(all_targets, all_preds)
    
    # Print classification report
    print(classification_report(all_targets, all_preds))
    
    # Plot confusion matrix
    cm = confusion_matrix(all_targets, all_preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()
    
    return accuracy, f1, kappa

# Usage
model.load_state_dict(best_weights)
accuracy, f1, kappa = evaluate_model(model, test_loader, device)
print(f'Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}, Cohen\'s Kappa: {kappa:.4f}')


              precision    recall  f1-score   support

           0       0.98      0.99      0.99       980
           1       0.99      0.99      0.99      1135
           2       0.98      0.98      0.98      1032
           3       0.97      0.97      0.97      1010
           4       0.98      0.97      0.98       982
           5       0.98      0.97      0.98       892
           6       0.98      0.98      0.98       958
           7       0.98      0.98      0.98      1028
           8       0.98      0.97      0.98       974
           9       0.96      0.97      0.97      1009

    accuracy                           0.98     10000
   macro avg       0.98      0.98      0.98     10000
weighted avg       0.98      0.98      0.98     10000



NameError: name 'sns' is not defined

<Figure size 1000x800 with 0 Axes>