In [5]:
# Imports: third-party libraries, then the KAN implementation from efficient_kan.
import matplotlib.pyplot as plt
import mplhep as hep
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import roc_curve, auc
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

from efficient_kan import KAN

# Configuration
hep.style.use("CMS")  # apply mplhep's CMS plotting style to every subsequent figure

# Fix RNG seeds so DataLoader shuffling (which draws from torch's global RNG) and,
# presumably, the KAN weight initialisation are reproducible under Restart & Run All.
# NOTE: some CUDA kernels are nondeterministic, so GPU runs may still differ slightly.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU generator and every CUDA device

# Release cached-but-unused GPU memory (only matters when re-running in a kernel
# that already allocated CUDA tensors).
torch.cuda.empty_cache()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = 'cpu'  # uncomment to force CPU execution

In [6]:
def _npy_to_device(path):
    """Read a NumPy ``.npy`` file and return its contents as a float32 tensor on ``device``."""
    array = np.load(path)
    return torch.from_numpy(array).float().to(device)


# Feature files are used as-is. Label files are reduced with argmax over dim 1,
# presumably turning one-hot rows into integer class indices for CrossEntropyLoss
# (TODO confirm the label files are one-hot encoded).
X_train = _npy_to_device('data/X_train_val.npy')
y_train = _npy_to_device('data/y_train_val.npy').argmax(dim=1)
X_test = _npy_to_device('data/X_test.npy')
y_test = _npy_to_device('data/y_test.npy').argmax(dim=1)

# Pair each feature row with its label.
train_dataset, test_dataset = TensorDataset(X_train, y_train), TensorDataset(X_test, y_test)

# Only the training split is shuffled; the test split is iterated in file order.
batch_size = 64  # Adjust this based on your available memory
trainloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [7]:
# KAN with layer widths [16, 4, 5]: 16 inputs, one hidden layer of width 4, 5 outputs
# (assumes X has 16 feature columns and the labels have 5 classes — TODO confirm).
# grid_size / spline_order / grid_eps / grid_range are efficient_kan spline-grid
# settings; GELU is used as the base activation.
model = KAN(
    [16, 4, 5],
    grid_size=30,
    spline_order=3,
    grid_eps=0.05,
    base_activation=nn.GELU,
    grid_range=[-5, 5],
)

# Total number of scalar values across all registered parameters (buffers excluded).
n_params = sum(param.numel() for param in model.parameters())
print(n_params)

2940


In [8]:
# for i in range(15):
#     plt.hist(X_train.cpu().numpy()[:,i], bins=100, range=(-5,5))
#     plt.show()

In [9]:
# Move the model to `device` before constructing its optimizer (the order the
# torch.optim docs recommend). Module.to() returns the same module.
model = model.to(device)

# AdamW with a fixed base learning rate and weight decay.
optimizer = optim.AdamW(params=model.parameters(), lr=1e-4, weight_decay=1e-5)

# Every scheduler.step() multiplies the learning rate by gamma = 0.9.
scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.9)

In [10]:
# Per-epoch mean losses (one entry per epoch), plotted in the next cell.
training_loss = []
testing_loss = []

# Define loss (default reduction='mean' over the batch)
criterion = nn.CrossEntropyLoss()
N_EPOCHS = 15


def _train_one_epoch(loader):
    """Run one optimisation pass over ``loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion`` and ``device``.
    The tqdm bar shows the current batch's loss and accuracy and the current
    learning rate.

    Returns:
        The mean training loss over all samples. Each sample is weighted equally,
        so a short final batch does not count as much as a full one. Returns NaN
        if ``loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    n_samples = 0
    with tqdm(loader) as pbar:
        for inputs, labels in pbar:
            # No-ops when the tensors already live on `device`.
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            output = model(inputs)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            n_batch = labels.size(0)
            # criterion averages over the batch, so scale back up to a per-batch sum.
            loss_sum += batch_loss * n_batch
            n_samples += n_batch

            accuracy = (output.argmax(dim=1) == labels).float().mean()
            pbar.set_postfix(loss=batch_loss, accuracy=accuracy.item(), lr=optimizer.param_groups[0]['lr'])
    return loss_sum / n_samples if n_samples else float('nan')


@torch.no_grad()
def _evaluate(loader):
    """Score the notebook-level ``model`` on ``loader`` without tracking gradients.

    Returns:
        ``(mean_loss, accuracy)`` computed over all samples rather than averaged
        per batch, or ``(nan, nan)`` if ``loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_samples = 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        output = model(inputs)
        n_batch = labels.size(0)
        loss_sum += criterion(output, labels).item() * n_batch
        n_correct += (output.argmax(dim=1) == labels).sum().item()
        n_samples += n_batch
    if not n_samples:
        return float('nan'), float('nan')
    return loss_sum / n_samples, n_correct / n_samples


for epoch in range(N_EPOCHS):
    average_train_loss = _train_one_epoch(trainloader)
    training_loss.append(average_train_loss)  # Record the average training loss

    # NOTE(review): this per-epoch "validation" runs on the test split (testloader).
    # If these numbers drive any model or hyper-parameter choice, hold out a separate
    # validation split from the training data instead.
    val_loss, val_accuracy = _evaluate(testloader)
    testing_loss.append(val_loss)

    # Decay the learning rate once per epoch.
    scheduler.step()

    print(
        f"Epoch {epoch + 1}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}"
    )

100%|██████████| 10375/10375 [00:17<00:00, 604.33it/s, accuracy=0.688, loss=0.925, lr=0.0001]


Epoch 1, Val Loss: 0.8035608896771669, Val Accuracy: 0.7214746369980993


100%|██████████| 10375/10375 [00:16<00:00, 616.12it/s, accuracy=0.672, loss=0.994, lr=9e-5]


Epoch 2, Val Loss: 0.7525244999961295, Val Accuracy: 0.7349131007606283


100%|██████████| 10375/10375 [00:16<00:00, 617.41it/s, accuracy=0.797, loss=0.626, lr=8.1e-5]


Epoch 3, Val Loss: 0.7373613281693933, Val Accuracy: 0.7402921003751233


100%|██████████| 10375/10375 [00:16<00:00, 617.38it/s, accuracy=0.672, loss=0.809, lr=7.29e-5]


Epoch 4, Val Loss: 0.728964772468911, Val Accuracy: 0.7423280487174517


100%|██████████| 10375/10375 [00:16<00:00, 617.13it/s, accuracy=0.688, loss=0.838, lr=6.56e-5]


Epoch 5, Val Loss: 0.7230823457057235, Val Accuracy: 0.744285691354306


100%|██████████| 10375/10375 [00:17<00:00, 605.09it/s, accuracy=0.719, loss=0.673, lr=5.9e-5]


Epoch 6, Val Loss: 0.7189688165094086, Val Accuracy: 0.7455024415470585


100%|██████████| 10375/10375 [00:16<00:00, 622.92it/s, accuracy=0.828, loss=0.619, lr=5.31e-5]


Epoch 7, Val Loss: 0.7156047168472867, Val Accuracy: 0.7463035691492174


100%|██████████| 10375/10375 [00:16<00:00, 631.60it/s, accuracy=0.781, loss=0.705, lr=4.78e-5]


Epoch 8, Val Loss: 0.7130680248175756, Val Accuracy: 0.7476046485632497


100%|██████████| 10375/10375 [00:16<00:00, 632.10it/s, accuracy=0.766, loss=0.56, lr=4.3e-5] 


Epoch 9, Val Loss: 0.7110616498114792, Val Accuracy: 0.7479238949009521


100%|██████████| 10375/10375 [00:16<00:00, 643.48it/s, accuracy=0.75, loss=0.695, lr=3.87e-5] 


Epoch 10, Val Loss: 0.70937959049055, Val Accuracy: 0.7485824659547166


100%|██████████| 10375/10375 [00:15<00:00, 655.04it/s, accuracy=0.828, loss=0.591, lr=3.49e-5]


Epoch 11, Val Loss: 0.7079626138199268, Val Accuracy: 0.7491647391569274


100%|██████████| 10375/10375 [00:15<00:00, 654.41it/s, accuracy=0.719, loss=0.735, lr=3.14e-5]


Epoch 12, Val Loss: 0.7068833232316588, Val Accuracy: 0.7493775700410698


100%|██████████| 10375/10375 [00:15<00:00, 650.92it/s, accuracy=0.766, loss=0.646, lr=2.82e-5]


Epoch 13, Val Loss: 0.7057936670877608, Val Accuracy: 0.7499618510742232


100%|██████████| 10375/10375 [00:15<00:00, 653.80it/s, accuracy=0.781, loss=0.643, lr=2.54e-5]


Epoch 14, Val Loss: 0.7049752966936681, Val Accuracy: 0.750558179138988


100%|██████████| 10375/10375 [00:15<00:00, 657.14it/s, accuracy=0.672, loss=0.763, lr=2.29e-5]


Epoch 15, Val Loss: 0.7042573128834815, Val Accuracy: 0.7503714501490112


In [None]:
# Mean loss per epoch recorded by the training cell. Epochs are numbered from 1
# to match its "Epoch N" log lines. Each series gets its own x range, so an
# interrupted run with unequal list lengths still plots.
fig, ax = plt.subplots()
ax.plot(range(1, len(training_loss) + 1), training_loss, label='KAN Training Loss', linewidth=5)
ax.plot(range(1, len(testing_loss) + 1), testing_loss, label='KAN Testing Loss', linewidth=5)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('KAN loss per epoch')
ax.legend()
plt.show()