In [5]:
import sys
sys.path.append('/home/aarushg/KAN-FPGA/KAN_Impl')
from KANLinear import KAN, Quantizer
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
torch.cuda.empty_cache()
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
device = 'cuda'

In [8]:
TP = 7
FP = 3


grid_range=[-2**(TP - FP - 1), 2**(TP - FP - 1)]
resolution = int(2 ** TP)

REGULARIZE_ACTIVATION = 0.001

print(TP, FP)


7 4


In [9]:

X_train=torch.from_numpy(np.load('data/X_train_val.npy')).float().to(device)
y_train=torch.from_numpy(np.load('data/y_train_val.npy')).float().to(device).argmax(dim=1)
X_test=torch.from_numpy(np.load('data/X_test.npy')).float().to(device)
y_test=torch.from_numpy(np.load('data/y_test.npy')).float().to(device).argmax(dim=1)

# Create TensorDataset objects
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)

# Create DataLoader objects
batch_size = 64  # Adjust this based on your available memory
trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

model = KAN([16,5,5], grid_size=30, spline_order=10, grid_eps=0.05, base_activation=nn.GELU, grid_range=grid_range, quantize=True, tp=TP, fp=FP, lut_res=resolution, quantize_clip=True).to(device)
# model = KAN([16,4,5], grid_size=30, spline_order=3, grid_eps=0.05, base_activation=nn.GELU, grid_range=grid_range).to(device)

print(sum(p.numel() for p in model.parameters()))

model.to(device)
# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
# Define learning rate scheduler
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

training_loss = []
testing_loss = []

# Define loss
criterion = nn.CrossEntropyLoss()
for epoch in range(30):
    # Train
    model.train()
    epoch_train_loss = 0  # Initialize loss for the epoch
    total_batches = 0
    with tqdm(trainloader) as pbar:
        for i, (inputs, labels) in enumerate(pbar):
            inputs = inputs.to(device)
            optimizer.zero_grad()
            output = model(inputs)
            loss = criterion(output, labels.to(device)) + model.regularization_loss(regularize_activation=REGULARIZE_ACTIVATION, regularize_entropy=min(0.005 * epoch, 0.05), regularize_clipping=min(0.05 * epoch, 0.2))
            loss.backward()

            optimizer.step()

            epoch_train_loss += loss.item()
            total_batches += 1

            accuracy = (output.argmax(dim=1) == labels.to(device)).float().mean()

            fracs_clipped = []
            for layer in model.layers:
                fracs_clipped.extend(round(x.item(), 4) for x in layer.get_frac_clipped())
            
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item(), lr=optimizer.param_groups[0]['lr'], frac_clipped=fracs_clipped)
    
    average_train_loss = epoch_train_loss / total_batches
    training_loss.append(average_train_loss)  # Record the average training loss

    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs = inputs.to(device)
            output = model(inputs)
            val_loss += criterion(output, labels.to(device)).item() + model.regularization_loss(regularize_activation=0, regularize_entropy=0, regularize_clipping=0.2)
            val_accuracy += (
                (output.argmax(dim=1) == labels.to(device)).float().mean().item()
            )
    val_loss /= len(testloader)
    val_accuracy /= len(testloader)
    testing_loss.append(val_loss)

    # Update learning rate
    scheduler.step()

    REMAINING_FRACTION = model.prune_below_threshold(threshold=0.1)
    print(f"Overall remaining fraction (Epoch {epoch + 1}): ", REMAINING_FRACTION)

    fracs_clipped = []
    for layer in model.layers:
        fracs_clipped.extend(x.item() for x in layer.get_frac_clipped())

    print(
        f"Epoch {epoch + 1}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}, Val Frac Clipped: {fracs_clipped}"
    )

RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

In [6]:
# Temporarily store quantizers
stored_quantizers = []
for layer in model.layers:
    layer_quantizers = {}
    if hasattr(layer, 'lut_inp_quantizer'):
        layer_quantizers['inp'] = layer.lut_inp_quantizer
        delattr(layer, 'lut_inp_quantizer')
    if hasattr(layer, 'lut_out_quantizer'):
        layer_quantizers['out'] = layer.lut_out_quantizer
        delattr(layer, 'lut_out_quantizer')
    stored_quantizers.append(layer_quantizers)

# Save model without quantizers
torch.save(model, f'models/model_{TP}t{FP}f_pr{1 - REMAINING_FRACTION}.pth')

# Restore quantizers
for layer, quantizers in zip(model.layers, stored_quantizers):
    if 'inp' in quantizers:
        layer.lut_inp_quantizer = quantizers['inp']
    if 'out' in quantizers:
        layer.lut_out_quantizer = quantizers['out']