In [1]:
from efficient_kan import KAN
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
import numpy as np
import torch
import mplhep as hep
hep.style.use("CMS")
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
torch.cuda.empty_cache()
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
# device = 'cpu'

In [2]:
X_train=torch.from_numpy(np.load('data/X_train_val.npy')).float().to(device)
y_train=torch.from_numpy(np.load('data/y_train_val.npy')).float().to(device).argmax(dim=1)
X_test=torch.from_numpy(np.load('data/X_test.npy')).float().to(device)
y_test=torch.from_numpy(np.load('data/y_test.npy')).float().to(device).argmax(dim=1)

# Create TensorDataset objects
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)

# Create DataLoader objects
batch_size = 64  # Adjust this based on your available memory
trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [3]:
model = KAN([16,4,5], grid_size=40, spline_order=4, grid_eps=0.03, base_activation=nn.GELU, grid_range=[-8,8])
print(sum(p.numel() for p in model.parameters()))

3864


In [4]:
model.to(device)
# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
# Define learning rate scheduler
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

In [5]:
training_loss = []
testing_loss = []

# Define loss
criterion = nn.CrossEntropyLoss()
for epoch in range(30):
    # Train
    model.train()
    epoch_train_loss = 0  # Initialize loss for the epoch
    total_batches = 0
    with tqdm(trainloader) as pbar:
        for i, (inputs, labels) in enumerate(pbar):
            inputs = inputs.to(device)
            optimizer.zero_grad()
            output = model(inputs)
            loss = criterion(output, labels.to(device))
            loss.backward()
            optimizer.step()

            epoch_train_loss += loss.item()
            total_batches += 1

            accuracy = (output.argmax(dim=1) == labels.to(device)).float().mean()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item(), lr=optimizer.param_groups[0]['lr'])
    
    average_train_loss = epoch_train_loss / total_batches
    training_loss.append(average_train_loss)  # Record the average training loss

    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs = inputs.to(device)
            output = model(inputs)
            val_loss += criterion(output, labels.to(device)).item()
            val_accuracy += (
                (output.argmax(dim=1) == labels.to(device)).float().mean().item()
            )
    val_loss /= len(testloader)
    val_accuracy /= len(testloader)
    testing_loss.append(val_loss)

    # Update learning rate
    scheduler.step()

    print(
        f"Epoch {epoch + 1}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}"
    )

100%|██████████| 10375/10375 [00:20<00:00, 507.04it/s, accuracy=0.766, loss=0.764, lr=0.0001]


Epoch 1, Val Loss: 0.8091774684425125, Val Accuracy: 0.7233118093188395


100%|██████████| 10375/10375 [00:19<00:00, 520.25it/s, accuracy=0.75, loss=0.79, lr=9e-5]  


Epoch 2, Val Loss: 0.7629604442540553, Val Accuracy: 0.7316282767925983


100%|██████████| 10375/10375 [00:19<00:00, 535.38it/s, accuracy=0.781, loss=0.715, lr=8.1e-5]


Epoch 3, Val Loss: 0.7471557933730902, Val Accuracy: 0.7359310749317926


100%|██████████| 10375/10375 [00:19<00:00, 540.79it/s, accuracy=0.719, loss=0.745, lr=7.29e-5]


Epoch 4, Val Loss: 0.737323245132714, Val Accuracy: 0.7393223143304047


100%|██████████| 10375/10375 [00:19<00:00, 541.14it/s, accuracy=0.734, loss=0.712, lr=6.56e-5]


Epoch 5, Val Loss: 0.7306114924280296, Val Accuracy: 0.7416975873887952


100%|██████████| 10375/10375 [00:19<00:00, 527.69it/s, accuracy=0.656, loss=0.847, lr=5.9e-5]


Epoch 6, Val Loss: 0.7259147253143117, Val Accuracy: 0.7436130654150096


100%|██████████| 10375/10375 [00:19<00:00, 529.33it/s, accuracy=0.797, loss=0.578, lr=5.31e-5]


Epoch 7, Val Loss: 0.7223938440784824, Val Accuracy: 0.7447635569338993


100%|██████████| 10375/10375 [00:19<00:00, 540.35it/s, accuracy=0.75, loss=0.745, lr=4.78e-5] 


Epoch 8, Val Loss: 0.7195391563056704, Val Accuracy: 0.7458357427473149


100%|██████████| 10375/10375 [00:19<00:00, 541.08it/s, accuracy=0.844, loss=0.511, lr=4.3e-5]


Epoch 9, Val Loss: 0.717457350989719, Val Accuracy: 0.7466489173810852


100%|██████████| 10375/10375 [00:19<00:00, 541.96it/s, accuracy=0.672, loss=0.861, lr=3.87e-5]


Epoch 10, Val Loss: 0.7155892815420788, Val Accuracy: 0.7474861860780783


100%|██████████| 10375/10375 [00:19<00:00, 538.72it/s, accuracy=0.766, loss=0.735, lr=3.49e-5]


Epoch 11, Val Loss: 0.7142154429288855, Val Accuracy: 0.7479499967951175


100%|██████████| 10375/10375 [00:19<00:00, 541.82it/s, accuracy=0.75, loss=0.727, lr=3.14e-5] 


Epoch 12, Val Loss: 0.712878265158893, Val Accuracy: 0.7484619956386025


100%|██████████| 10375/10375 [00:19<00:00, 542.96it/s, accuracy=0.781, loss=0.675, lr=2.82e-5]


Epoch 13, Val Loss: 0.7117521203106517, Val Accuracy: 0.7486065600179395


100%|██████████| 10375/10375 [00:19<00:00, 544.34it/s, accuracy=0.812, loss=0.576, lr=2.54e-5]


Epoch 14, Val Loss: 0.7108268658404361, Val Accuracy: 0.749064347219173


100%|██████████| 10375/10375 [00:19<00:00, 538.61it/s, accuracy=0.719, loss=0.744, lr=2.29e-5]


Epoch 15, Val Loss: 0.7100387523862703, Val Accuracy: 0.7492149351143157


100%|██████████| 10375/10375 [00:19<00:00, 543.76it/s, accuracy=0.688, loss=0.873, lr=2.06e-5]


Epoch 16, Val Loss: 0.7093372924817739, Val Accuracy: 0.749371546525264


100%|██████████| 10375/10375 [00:19<00:00, 543.68it/s, accuracy=0.734, loss=0.774, lr=1.85e-5]


Epoch 17, Val Loss: 0.7086498234808767, Val Accuracy: 0.749576346062658


100%|██████████| 10375/10375 [00:19<00:00, 544.02it/s, accuracy=0.781, loss=0.679, lr=1.67e-5]


Epoch 18, Val Loss: 0.7081289151791159, Val Accuracy: 0.7496365812207151


100%|██████████| 10375/10375 [00:19<00:00, 541.44it/s, accuracy=0.891, loss=0.45, lr=1.5e-5] 


Epoch 19, Val Loss: 0.7076647112989021, Val Accuracy: 0.7500040156848631


100%|██████████| 10375/10375 [00:19<00:00, 542.10it/s, accuracy=0.766, loss=0.651, lr=1.35e-5]


Epoch 20, Val Loss: 0.7071907868638991, Val Accuracy: 0.7499919686532517


100%|██████████| 10375/10375 [00:19<00:00, 543.08it/s, accuracy=0.719, loss=0.775, lr=1.22e-5]


Epoch 21, Val Loss: 0.7068320451305753, Val Accuracy: 0.7501485800642


100%|██████████| 10375/10375 [00:19<00:00, 543.34it/s, accuracy=0.781, loss=0.679, lr=1.09e-5]


Epoch 22, Val Loss: 0.7065209641696493, Val Accuracy: 0.7502208622538685


100%|██████████| 10375/10375 [00:19<00:00, 528.39it/s, accuracy=0.797, loss=0.666, lr=9.85e-6]


Epoch 23, Val Loss: 0.7062149628502824, Val Accuracy: 0.75038148934968


100%|██████████| 10375/10375 [00:19<00:00, 527.17it/s, accuracy=0.781, loss=0.713, lr=8.86e-6]


Epoch 24, Val Loss: 0.7059569927413554, Val Accuracy: 0.7503232620225655


100%|██████████| 10375/10375 [00:19<00:00, 525.16it/s, accuracy=0.703, loss=0.887, lr=7.98e-6]


Epoch 25, Val Loss: 0.7057261849217169, Val Accuracy: 0.7504798734335139


100%|██████████| 10375/10375 [00:19<00:00, 539.14it/s, accuracy=0.781, loss=0.683, lr=7.18e-6]


Epoch 26, Val Loss: 0.7054998913152272, Val Accuracy: 0.7504959361499883


100%|██████████| 10375/10375 [00:19<00:00, 541.33it/s, accuracy=0.766, loss=0.72, lr=6.46e-6] 


Epoch 27, Val Loss: 0.7053302176750524, Val Accuracy: 0.7505200302132111


100%|██████████| 10375/10375 [00:19<00:00, 538.69it/s, accuracy=0.766, loss=0.641, lr=5.81e-6]


Epoch 28, Val Loss: 0.7051776356323003, Val Accuracy: 0.7505501477922397


100%|██████████| 10375/10375 [00:19<00:00, 541.29it/s, accuracy=0.75, loss=0.572, lr=5.23e-6] 


Epoch 29, Val Loss: 0.7050161322746814, Val Accuracy: 0.7506585710767424


100%|██████████| 10375/10375 [00:19<00:00, 539.50it/s, accuracy=0.781, loss=0.696, lr=4.71e-6]


Epoch 30, Val Loss: 0.7048946131098335, Val Accuracy: 0.7506645945925481


In [None]:
plt.plot(training_loss, label='KAN Training Loss',linewidth=5)
plt.plot(testing_loss, label='KAN Testing Loss',linewidth=5)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()