# KAN For Multivariate Classification in Predicting Visual Stimuli Given Firing Rates

This notebook applies the Kolmogorov-Arnold Networks (KAN) architecture to visual stimulus prediction in mice, given the firing rates of single neurons.
The model predicts what image the mouse is being shown out of 118 images. The data consists of spike-trains containing the total spikes each neuron fired when each image was shown, with around 2000 neurons each. 

This model's highest accuracy was ~75%, where an accuracy of ~0.85% would be a total guess. The model was found to work best with only one KAN layer connecting the input and output dimensions. This means only one activation function was found to transform the dataset. Although running PCA before this layer, with principal components roughly the same number of components as the number of classes, allowed the model to perform faster and with a diminished accuracy. This performed worse than models, like LSTMs or ST-GATs, that can account for temporal and spatial dependencies. However, the ability to show inference between the effect of individual neurons and the possibility to network effects might create a niche for Kolmogorov-Arnold Networks to be used in future studies.

### Best Model - KAN with one layer

There is only an input and output layer with no hidden dimension here. This means only one singular activation function was found to predict the output. 

In [24]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from kan import KANLayer  # Assuming kan is the module containing KANLayer
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from data_processors.pull_and_process_data import master_function
from tqdm import tqdm

# Hyper-parameters
batch_size = 128
num_epochs = 20
num_intervals = 8
k = 6
noise_scale = 0.1
scale_base = 1.0
scale_sp = 1.0
base_fun = nn.SiLU()
grid_eps = 0.0
grid_range = [0, 1]
sp_trainable = True
sb_trainable = True
# Sparsification parameters
lambda_l1 = 1e-4
lambda_entropy = 1e-4

# Check for GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pull and process data
mouse_number = 715093703
spike_df = master_function(session_number=mouse_number, output_dir="output", timesteps_per_frame=1)
spike_df = spike_df[spike_df['frame'] >= 0]

# Prepare data
X = spike_df.drop('frame', axis=1).values
y = spike_df['frame'].values.astype(int)

# Normalize the data
scaler = MinMaxScaler()
X_normalized = scaler.fit_transform(X)
X = X_normalized
num_classes = len(np.unique(y))

# Split data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

# Convert to PyTorch tensors
X_train = torch.from_numpy(X_train).float().to(device)
y_train = torch.from_numpy(y_train).long().to(device)
X_test = torch.from_numpy(X_test).float().to(device)
y_test = torch.from_numpy(y_test).long().to(device)

# Create DataLoader
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Initialize the KANLayer
in_dim = X_train.shape[1]  # 2073 in this case
out_dim = num_classes

kan_layer = KANLayer(in_dim=in_dim, out_dim=out_dim, num=num_intervals, k=k, noise_scale=noise_scale, 
                     scale_base=scale_base, scale_sp=scale_sp, base_fun=base_fun, grid_eps=grid_eps,
                     grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, device=device).to(device)

# Example forward pass
def forward_pass(kan_layer, X):
    y_pred, preacts, postacts, postspline = kan_layer(X)
    return y_pred

# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(kan_layer.parameters(), lr=0.001)

# Helper function to compute sparsification penalties
def compute_sparsification_penalty(kan_layer):
    l1_penalty = 0.0
    entropy_penalty = 0.0
    for param in kan_layer.parameters():
        l1_penalty += torch.sum(torch.abs(param))
        param_normalized = torch.abs(param) / torch.sum(torch.abs(param))
        entropy_penalty -= torch.sum(param_normalized * torch.log(param_normalized + 1e-10))
    return lambda_l1 * l1_penalty + lambda_entropy * entropy_penalty

# Training loop with sparsification
for epoch in range(num_epochs):
    kan_layer.train()
    train_losses = []
    train_correct = 0
    total_train = 0
    
    # Training phase
    for X_batch, y_batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False):
        # Forward pass
        y_pred = forward_pass(kan_layer, X_batch)
        # Compute loss
        loss = criterion(y_pred, y_batch)
        
        # Compute sparsification penalty
        sparsification_penalty = compute_sparsification_penalty(kan_layer)
        
        # Total loss
        total_loss = loss + sparsification_penalty
        
        # Backward pass and optimization
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        train_losses.append(total_loss.item())
        _, predicted = torch.max(y_pred.data, 1)
        total_train += y_batch.size(0)
        train_correct += (predicted == y_batch).sum().item()
    
    train_accuracy = 100 * train_correct / total_train
    
    # Testing phase
    kan_layer.eval()
    test_losses = []
    test_correct = 0
    total_test = 0
    
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            y_pred = forward_pass(kan_layer, X_batch)
            loss = criterion(y_pred, y_batch)
            test_losses.append(loss.item())
            _, predicted = torch.max(y_pred.data, 1)
            total_test += y_batch.size(0)
            test_correct += (predicted == y_batch).sum().item()
    
    test_accuracy = 100 * test_correct / total_test
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {np.mean(train_losses):.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {np.mean(test_losses):.4f}, Test Accuracy: {test_accuracy:.2f}%')

Updated version 3!
Initializing workflow...
Loading existing datasets...
Loaded spike trains dataset: <class 'pandas.core.frame.DataFrame'>
Total time elapsed: 0.01 seconds


                                                           

Epoch [1/20], Train Loss: 188.4661, Train Accuracy: 4.79%, Test Loss: 4.3529, Test Accuracy: 12.80%


                                                           

Epoch [2/20], Train Loss: 185.9678, Train Accuracy: 46.80%, Test Loss: 3.6774, Test Accuracy: 39.15%


                                                           

Epoch [3/20], Train Loss: 184.1534, Train Accuracy: 70.57%, Test Loss: 3.2550, Test Accuracy: 48.14%


                                                           

Epoch [4/20], Train Loss: 182.6214, Train Accuracy: 81.06%, Test Loss: 2.9202, Test Accuracy: 57.71%


                                                           

Epoch [5/20], Train Loss: 181.2276, Train Accuracy: 85.68%, Test Loss: 2.6852, Test Accuracy: 63.31%


                                                           

Epoch [6/20], Train Loss: 179.9529, Train Accuracy: 86.69%, Test Loss: 2.5426, Test Accuracy: 66.61%


                                                           

Epoch [7/20], Train Loss: 178.7396, Train Accuracy: 87.01%, Test Loss: 2.3908, Test Accuracy: 68.14%


                                                           

Epoch [8/20], Train Loss: 177.5277, Train Accuracy: 87.56%, Test Loss: 2.3430, Test Accuracy: 70.85%


                                                           

Epoch [9/20], Train Loss: 176.3341, Train Accuracy: 87.84%, Test Loss: 2.2758, Test Accuracy: 69.58%


                                                            

Epoch [10/20], Train Loss: 175.1333, Train Accuracy: 87.92%, Test Loss: 2.1722, Test Accuracy: 74.07%


                                                            

Epoch [11/20], Train Loss: 173.9204, Train Accuracy: 88.14%, Test Loss: 2.1115, Test Accuracy: 74.15%


                                                            

Epoch [12/20], Train Loss: 172.7098, Train Accuracy: 88.22%, Test Loss: 2.1030, Test Accuracy: 72.63%


                                                            

Epoch [13/20], Train Loss: 171.5150, Train Accuracy: 88.24%, Test Loss: 2.0394, Test Accuracy: 73.98%


                                                            

Epoch [14/20], Train Loss: 170.3055, Train Accuracy: 88.22%, Test Loss: 2.0104, Test Accuracy: 73.81%


                                                            

Epoch [15/20], Train Loss: 169.0863, Train Accuracy: 88.20%, Test Loss: 1.9663, Test Accuracy: 74.07%


                                                            

Epoch [16/20], Train Loss: 167.8666, Train Accuracy: 88.28%, Test Loss: 1.9235, Test Accuracy: 75.51%


                                                            

Epoch [17/20], Train Loss: 166.6372, Train Accuracy: 88.31%, Test Loss: 1.8950, Test Accuracy: 76.27%


                                                            

Epoch [18/20], Train Loss: 165.4019, Train Accuracy: 88.26%, Test Loss: 1.8645, Test Accuracy: 76.10%


                                                            

Epoch [19/20], Train Loss: 164.1638, Train Accuracy: 88.26%, Test Loss: 1.8368, Test Accuracy: 76.10%


                                                            

Epoch [20/20], Train Loss: 162.9174, Train Accuracy: 88.31%, Test Loss: 1.8233, Test Accuracy: 76.02%


In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from kan import KANLayer  # Assuming kan is the module containing KANLayer
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from data_processors.pull_and_process_data import master_function
from tqdm import tqdm

# Hyper-parameters
batch_size = 64
num_epochs = 16
num_intervals = 8
k = 6
noise_scale = 0.1
scale_base = 1.0
scale_sp = 1.0
base_fun = nn.SiLU()
grid_eps = 0.0
grid_range = [0, 1]
sp_trainable = True
sb_trainable = True
# Sparsification parameters
lambda_l1 = 1e-4
lambda_entropy = 1e-4

# Check for GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pull and process data
mouse_number = 715093703
spike_df = master_function(session_number=mouse_number, output_dir="output", timesteps_per_frame=1)
spike_df = spike_df[spike_df['frame'] >= 0]

# Prepare data
X = spike_df.drop('frame', axis=1).values
y = spike_df['frame'].values.astype(int)

# Normalize the data
scaler = MinMaxScaler()
X_normalized = scaler.fit_transform(X)
X = X_normalized
num_classes = len(np.unique(y))

# Split data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

# Convert to PyTorch tensors
X_train = torch.from_numpy(X_train).float().to(device)
y_train = torch.from_numpy(y_train).long().to(device)
X_test = torch.from_numpy(X_test).float().to(device)
y_test = torch.from_numpy(y_test).long().to(device)

# Create DataLoader
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Initialize the KANLayer
in_dim = X_train.shape[1]  # 2073 in this case
out_dim = num_classes

kan_layer = KANLayer(in_dim=in_dim, out_dim=out_dim, num=num_intervals, k=k, noise_scale=noise_scale, 
                     scale_base=scale_base, scale_sp=scale_sp, base_fun=base_fun, grid_eps=grid_eps,
                     grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, device=device).to(device)

# Example forward pass
def forward_pass(kan_layer, X):
    y_pred, preacts, postacts, postspline = kan_layer(X)
    return y_pred

# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(kan_layer.parameters(), lr=0.001)

# Helper function to compute sparsification penalties
def compute_sparsification_penalty(kan_layer):
    l1_penalty = 0.0
    entropy_penalty = 0.0
    for param in kan_layer.parameters():
        l1_penalty += torch.sum(torch.abs(param))
        param_normalized = torch.abs(param) / torch.sum(torch.abs(param))
        entropy_penalty -= torch.sum(param_normalized * torch.log(param_normalized + 1e-10))
    return lambda_l1 * l1_penalty + lambda_entropy * entropy_penalty

# Training loop with sparsification
for epoch in range(num_epochs):
    kan_layer.train()
    train_losses = []
    train_correct = 0
    total_train = 0
    
    # Training phase
    for X_batch, y_batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False):
        # Forward pass
        y_pred = forward_pass(kan_layer, X_batch)
        # Compute loss
        loss = criterion(y_pred, y_batch)
        
        # Compute sparsification penalty
        sparsification_penalty = compute_sparsification_penalty(kan_layer)
        
        # Total loss
        total_loss = loss + sparsification_penalty
        
        # Backward pass and optimization
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        train_losses.append(total_loss.item())
        _, predicted = torch.max(y_pred.data, 1)
        total_train += y_batch.size(0)
        train_correct += (predicted == y_batch).sum().item()
    
    train_accuracy = 100 * train_correct / total_train
    
    # Testing phase
    kan_layer.eval()
    test_losses = []
    test_correct = 0
    total_test = 0
    
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            y_pred = forward_pass(kan_layer, X_batch)
            loss = criterion(y_pred, y_batch)
            test_losses.append(loss.item())
            _, predicted = torch.max(y_pred.data, 1)
            total_test += y_batch.size(0)
            test_correct += (predicted == y_batch).sum().item()
    
    test_accuracy = 100 * test_correct / total_test
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {np.mean(train_losses):.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {np.mean(test_losses):.4f}, Test Accuracy: {test_accuracy:.2f}%')

Updated version 3!
Initializing workflow...
Loading existing datasets...
Loaded spike trains dataset: <class 'pandas.core.frame.DataFrame'>
Total time elapsed: 0.01 seconds


                                                           

Epoch [1/16], Train Loss: 188.0290, Train Accuracy: 6.93%, Test Loss: 4.2406, Test Accuracy: 19.15%


                                                           

Epoch [2/16], Train Loss: 184.2855, Train Accuracy: 54.47%, Test Loss: 3.2804, Test Accuracy: 43.81%


                                                           

Epoch [3/16], Train Loss: 181.4581, Train Accuracy: 77.78%, Test Loss: 2.8695, Test Accuracy: 57.37%


                                                           

Epoch [4/16], Train Loss: 179.0327, Train Accuracy: 84.72%, Test Loss: 2.6087, Test Accuracy: 60.34%


                                                           

Epoch [5/16], Train Loss: 176.7684, Train Accuracy: 86.69%, Test Loss: 2.3947, Test Accuracy: 66.95%


                                                           

Epoch [6/16], Train Loss: 174.5394, Train Accuracy: 87.54%, Test Loss: 2.2611, Test Accuracy: 67.97%


                                                           

Epoch [7/16], Train Loss: 172.3232, Train Accuracy: 88.05%, Test Loss: 2.1331, Test Accuracy: 71.78%


                                                           

Epoch [8/16], Train Loss: 170.1048, Train Accuracy: 88.03%, Test Loss: 2.0538, Test Accuracy: 73.73%


                                                           

Epoch [9/16], Train Loss: 167.8531, Train Accuracy: 88.22%, Test Loss: 2.0073, Test Accuracy: 73.56%


                                                            

Epoch [10/16], Train Loss: 165.5976, Train Accuracy: 88.24%, Test Loss: 1.9333, Test Accuracy: 74.49%


                                                            

Epoch [11/16], Train Loss: 163.3189, Train Accuracy: 88.20%, Test Loss: 1.8913, Test Accuracy: 74.32%


                                                            

Epoch [12/16], Train Loss: 161.0274, Train Accuracy: 88.31%, Test Loss: 1.8483, Test Accuracy: 75.25%


                                                            

Epoch [13/16], Train Loss: 158.7204, Train Accuracy: 88.24%, Test Loss: 1.8169, Test Accuracy: 76.02%


                                                            

Epoch [14/16], Train Loss: 156.4054, Train Accuracy: 88.28%, Test Loss: 1.7938, Test Accuracy: 75.25%


                                                            

Epoch [15/16], Train Loss: 154.6162, Train Accuracy: 88.31%, Test Loss: 1.7689, Test Accuracy: 75.51%


                                                            

Epoch [16/16], Train Loss: 153.6548, Train Accuracy: 88.39%, Test Loss: 1.7530, Test Accuracy: 75.08%


### KAN with PCA Layer as Input

This works much quicker than running each neuron as an input, but at the cost of reducing accuracy.

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from kan import KANLayer  
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score
from data_processors.pull_and_process_data import master_function
from tqdm import tqdm

# Hyper-parameters
batch_size = 16
num_epochs = 5
num_intervals = 3
k = 3
noise_scale = 0.1
scale_base = 1.0
scale_sp = 1.0
base_fun = nn.SiLU()
grid_eps = 0.0
grid_range = [0, 1]
sp_trainable = True
sb_trainable = True
n_components = 120
# Sparsification parameters
lambda_l1 = 1e-4
lambda_entropy = 1e-4

# Check for GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pull and process data
mouse_number = 715093703
spike_df = master_function(session_number=mouse_number, output_dir="output", timesteps_per_frame=1)
spike_df = spike_df[spike_df['frame'] >= 0]

# Prepare data
X = spike_df.drop('frame', axis=1).values
y = spike_df['frame'].values.astype(int)

# Normalize the data
scaler = MinMaxScaler()
X_normalized = scaler.fit_transform(X)
X = X_normalized
num_classes = len(np.unique(y))

# Split data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

# Perform PCA to reduce dimensions
pca = PCA(n_components=n_components)
X_train_reduced = pca.fit_transform(X_train)
X_test_reduced = pca.transform(X_test)

# Convert reduced data back to PyTorch tensors
X_train_reduced = torch.from_numpy(X_train_reduced).float().to(device)
X_test_reduced = torch.from_numpy(X_test_reduced).float().to(device)
y_train = torch.from_numpy(y_train).long().to(device)
y_test = torch.from_numpy(y_test).long().to(device)

# Create DataLoader with reduced data
train_dataset_reduced = TensorDataset(X_train_reduced, y_train)
test_dataset_reduced = TensorDataset(X_test_reduced, y_test)
train_loader_reduced = DataLoader(train_dataset_reduced, batch_size=batch_size, shuffle=True)
test_loader_reduced = DataLoader(test_dataset_reduced, batch_size=batch_size, shuffle=False)

# Initialize the KANLayer with reduced input dimension
in_dim_reduced = X_train_reduced.shape[1]  # This should be equal to n_components
out_dim = num_classes

kan_layer_reduced = KANLayer(in_dim=in_dim_reduced, out_dim=out_dim, num=num_intervals, k=k, noise_scale=noise_scale, 
                             scale_base=scale_base, scale_sp=scale_sp, base_fun=base_fun, grid_eps=grid_eps,
                             grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, device=device).to(device)

# Example forward pass
def forward_pass(kan_layer, X):
    y_pred, preacts, postacts, postspline = kan_layer(X)
    return y_pred

# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(kan_layer_reduced.parameters(), lr=0.01)

# Helper function to compute sparsification penalties
def compute_sparsification_penalty(kan_layer):
    l1_penalty = 0.0
    entropy_penalty = 0.0
    for param in kan_layer.parameters():
        l1_penalty += torch.sum(torch.abs(param))
        param_normalized = torch.abs(param) / torch.sum(torch.abs(param))
        entropy_penalty -= torch.sum(param_normalized * torch.log(param_normalized + 1e-10))
    return lambda_l1 * l1_penalty + lambda_entropy * entropy_penalty

# Training loop with reduced data
for epoch in range(num_epochs):
    kan_layer_reduced.train()
    train_losses = []
    train_correct = 0
    total_train = 0
    
    # Training phase
    for X_batch, y_batch in tqdm(train_loader_reduced, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False):
        # Forward pass
        y_pred = forward_pass(kan_layer_reduced, X_batch)
        # Compute loss
        loss = criterion(y_pred, y_batch)
        
        # Compute sparsification penalty
        sparsification_penalty = compute_sparsification_penalty(kan_layer_reduced)
        
        # Total loss
        total_loss = loss + sparsification_penalty
        
        # Backward pass and optimization
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        train_losses.append(total_loss.item())
        _, predicted = torch.max(y_pred.data, 1)
        total_train += y_batch.size(0)
        train_correct += (predicted == y_batch).sum().item()
    
    train_accuracy = 100 * train_correct / total_train
    
    # Testing phase
    kan_layer_reduced.eval()
    test_losses = []
    test_correct = 0
    total_test = 0
    
    with torch.no_grad():
        for X_batch, y_batch in test_loader_reduced:
            y_pred = forward_pass(kan_layer_reduced, X_batch)
            loss = criterion(y_pred, y_batch)
            test_losses.append(loss.item())
            _, predicted = torch.max(y_pred.data, 1)
            total_test += y_batch.size(0)
            test_correct += (predicted == y_batch).sum().item()
    
    test_accuracy = 100 * test_correct / total_test
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {np.mean(train_losses):.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {np.mean(test_losses):.4f}, Test Accuracy: {test_accuracy:.2f}%')

Updated version 3!
Initializing workflow...
Loading existing datasets...
Loaded spike trains dataset: <class 'pandas.core.frame.DataFrame'>
Total time elapsed: 0.01 seconds


                                                             

Epoch [1/5], Train Loss: 10.2513, Train Accuracy: 37.75%, Test Loss: 2.1236, Test Accuracy: 59.66%


                                                             

Epoch [2/5], Train Loss: 7.8138, Train Accuracy: 74.89%, Test Loss: 1.8550, Test Accuracy: 62.63%


                                                             

Epoch [3/5], Train Loss: 7.0677, Train Accuracy: 79.75%, Test Loss: 1.8371, Test Accuracy: 61.69%


                                                             

Epoch [4/5], Train Loss: 6.7355, Train Accuracy: 81.10%, Test Loss: 1.8331, Test Accuracy: 60.85%


                                                             

Epoch [5/5], Train Loss: 6.5631, Train Accuracy: 82.06%, Test Loss: 1.8799, Test Accuracy: 59.24%


### KAN Symbolic layers

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from kan import Symbolic_KANLayer  
import numpy as np
from data_processors.pull_and_process_data import master_function
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score
from data_processors.pull_and_process_data import master_function
from tqdm import tqdm

# Hyper-parameters
batch_size = 128
num_epochs = 5
num_intervals = 4
k = 3
noise_scale = 0.1
scale_base = 1.0
scale_sp = 1.0
base_fun = nn.SiLU()
grid_eps = 0.01
grid_range = [0, 1]
sp_trainable = True
sb_trainable = True
n_components = 12
# Sparsification parameters
lambda_l1 = 1e-4
lambda_entropy = 1e-4

# Check for GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pull and process data
mouse_number = 715093703
spike_df = master_function(session_number=mouse_number, output_dir="output", timesteps_per_frame=1)
spike_df = spike_df[spike_df['frame'] >= 0]

# Prepare data
X = spike_df.drop('frame', axis=1).values
y = spike_df['frame'].values.astype(int)

# Normalize the data
scaler = MinMaxScaler()
X_normalized = scaler.fit_transform(X)
X = X_normalized
num_classes = len(np.unique(y))

# Split data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

# Perform PCA to reduce dimensions
pca = PCA(n_components=n_components)
X_train = pca.fit_transform(X_train)
X_test = pca.transform(X_test)

# Convert to PyTorch tensors
X_train = torch.from_numpy(X_train).float().to(device)
X_test = torch.from_numpy(X_test).float().to(device)
y_train = torch.from_numpy(y_train).long().to(device)
y_test = torch.from_numpy(y_test).long().to(device)

# Create DataLoader
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Initialize the Symbolic_KANLayer
in_dim = X_train.shape[1]  # 2073 in this case
out_dim = num_classes

symbolic_kan_layer = Symbolic_KANLayer(in_dim=n_components, out_dim=out_dim, device=device).to(device)

# Define symbolic functions for the layer (example, you can modify as needed)
for i in range(in_dim):
    for j in range(out_dim):
        symbolic_kan_layer.fix_symbolic(i, j, 'sin')  # Replace 'sin' with the desired symbolic function

# Example forward pass
def forward_pass(symbolic_kan_layer, X):
    y_pred, postacts = symbolic_kan_layer(X)
    return y_pred

# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(symbolic_kan_layer.parameters(), lr=0.01)

# Helper function to compute sparsification penalties
def compute_sparsification_penalty(symbolic_kan_layer):
    l1_penalty = 0.0
    entropy_penalty = 0.0
    for param in symbolic_kan_layer.parameters():
        l1_penalty += torch.sum(torch.abs(param))
        param_normalized = torch.abs(param) / torch.sum(torch.abs(param))
        entropy_penalty -= torch.sum(param_normalized * torch.log(param_normalized + 1e-10))
    return lambda_l1 * l1_penalty + lambda_entropy * entropy_penalty

# Training loop with sparsification
for epoch in range(num_epochs):
    symbolic_kan_layer.train()
    train_losses = []
    train_correct = 0
    total_train = 0
    
    # Training phase
    for X_batch, y_batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False):
        # Forward pass
        y_pred = forward_pass(symbolic_kan_layer, X_batch)
        # Compute loss
        loss = criterion(y_pred, y_batch)
        
        # Compute sparsification penalty
        sparsification_penalty = compute_sparsification_penalty(symbolic_kan_layer)
        
        # Total loss
        total_loss = loss + sparsification_penalty
        
        # Backward pass and optimization
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        train_losses.append(total_loss.item())
        _, predicted = torch.max(y_pred.data, 1)
        total_train += y_batch.size(0)
        train_correct += (predicted == y_batch).sum().item()
    
    train_accuracy = 100 * train_correct / total_train
    
    # Testing phase
    symbolic_kan_layer.eval()
    test_losses = []
    test_correct = 0
    total_test = 0
    
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            y_pred = forward_pass(symbolic_kan_layer, X_batch)
            loss = criterion(y_pred, y_batch)
            test_losses.append(loss.item())
            _, predicted = torch.max(y_pred.data, 1)
            total_test += y_batch.size(0)
            test_correct += (predicted == y_batch).sum().item()
    
    test_accuracy = 100 * test_correct / total_test
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {np.mean(train_losses):.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {np.mean(test_losses):.4f}, Test Accuracy: {test_accuracy:.2f}%')

Updated version 3!
Initializing workflow...
Loading existing datasets...
Loaded spike trains dataset: <class 'pandas.core.frame.DataFrame'>
Total time elapsed: 0.01 seconds


                                                          

Epoch [1/5], Train Loss: nan, Train Accuracy: 0.85%, Test Loss: 4.7707, Test Accuracy: 0.85%


                                                          

Epoch [2/5], Train Loss: nan, Train Accuracy: 0.85%, Test Loss: 4.7707, Test Accuracy: 0.85%


                                                          

Epoch [3/5], Train Loss: nan, Train Accuracy: 0.85%, Test Loss: 4.7707, Test Accuracy: 0.85%


                                                          

Epoch [4/5], Train Loss: nan, Train Accuracy: 0.85%, Test Loss: 4.7707, Test Accuracy: 0.85%


                                                          

Epoch [5/5], Train Loss: nan, Train Accuracy: 0.85%, Test Loss: 4.7707, Test Accuracy: 0.85%


### Multi-Layered KAN

I may come back and redo this section in the future as it is very computationally expensive in it's current iteration.

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from kan import KANLayer  # Assuming kan is the module containing KANLayer
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from data_processors.pull_and_process_data import master_function
from tqdm import tqdm

# Check for GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pull and process data
mouse_number = 715093703
spike_df = master_function(session_number=mouse_number, output_dir="output", timesteps_per_frame=1)
spike_df = spike_df[spike_df['frame'] >= 0]

# Prepare data
X = spike_df.drop('frame', axis=1).values
y = spike_df['frame'].values.astype(int)

# Normalize the data
scaler = MinMaxScaler()
X_normalized = scaler.fit_transform(X)
X = X_normalized
num_classes = len(np.unique(y))

# Split data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

# Convert to PyTorch tensors
X_train = torch.from_numpy(X_train).float().to(device)
y_train = torch.from_numpy(y_train).long().to(device)
X_test = torch.from_numpy(X_test).float().to(device)
y_test = torch.from_numpy(y_test).long().to(device)

# Create DataLoader
batch_size = 16
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Define a custom KAN model with an intermediate layer
class CustomKANModel(nn.Module):
    def __init__(self, in_dim, inter_dim, out_dim, num_intervals, k, noise_scale, scale_base, scale_sp, base_fun, grid_eps, grid_range, sp_trainable, sb_trainable, device):
        super(CustomKANModel, self).__init__()
        self.kan_layer1 = KANLayer(in_dim=in_dim, out_dim=inter_dim, num=num_intervals, k=k, noise_scale=noise_scale, 
                                   scale_base=scale_base, scale_sp=scale_sp, base_fun=base_fun, grid_eps=grid_eps,
                                   grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, device=device).to(device)
        self.kan_layer2 = KANLayer(in_dim=inter_dim, out_dim=out_dim, num=num_intervals, k=k, noise_scale=noise_scale, 
                                   scale_base=scale_base, scale_sp=scale_sp, base_fun=base_fun, grid_eps=grid_eps,
                                   grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, device=device).to(device)

    def forward(self, X):
        y_pred1, preacts1, postacts1, postspline1 = self.kan_layer1(X)
        y_pred2, preacts2, postacts2, postspline2 = self.kan_layer2(y_pred1)
        return y_pred2

# Initialize the CustomKANModel
in_dim = X_train.shape[1]  # 2073 in this case
inter_dim = 1000  # Example intermediate dimension
out_dim = num_classes
num_intervals = 8
k = 6
noise_scale = 0.1
scale_base = 1.0
scale_sp = 1.0
base_fun = nn.SiLU()
grid_eps = 1
grid_range = [0, 1]
sp_trainable = True
sb_trainable = True

kan_model = CustomKANModel(in_dim=in_dim, inter_dim=inter_dim, out_dim=out_dim, num_intervals=num_intervals, k=k, noise_scale=noise_scale, 
                            scale_base=scale_base, scale_sp=scale_sp, base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range, 
                            sp_trainable=sp_trainable, sb_trainable=sb_trainable, device=device).to(device)

# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(kan_model.parameters(), lr=0.01)

# Sparsification parameters
lambda_l1 = 1e-4
lambda_entropy = 1e-4

# Helper function to compute sparsification penalties
def compute_sparsification_penalty(model):
    l1_penalty = 0.0
    entropy_penalty = 0.0
    for param in model.parameters():
        l1_penalty += torch.sum(torch.abs(param))
        param_normalized = torch.abs(param) / torch.sum(torch.abs(param))
        entropy_penalty -= torch.sum(param_normalized * torch.log(param_normalized + 1e-10))
    return lambda_l1 * l1_penalty + lambda_entropy * entropy_penalty

# Training loop with sparsification
num_epochs = 100
for epoch in range(num_epochs):
    kan_model.train()
    train_losses = []
    train_correct = 0
    total_train = 0
    
    # Training phase
    for X_batch, y_batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False):
        # Forward pass
        y_pred = kan_model(X_batch)
        
        # Compute loss
        loss = criterion(y_pred, y_batch)
        
        # Compute sparsification penalty
        sparsification_penalty = compute_sparsification_penalty(kan_model)
        
        # Total loss
        total_loss = loss + sparsification_penalty
        
        # Backward pass and optimization
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        train_losses.append(total_loss.item())
        _, predicted = torch.max(y_pred.data, 1)
        total_train += y_batch.size(0)
        train_correct += (predicted == y_batch).sum().item()
    
    train_accuracy = 100 * train_correct / total_train
    
    # Testing phase
    kan_model.eval()
    test_losses = []
    test_correct = 0
    total_test = 0
    
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            y_pred = kan_model(X_batch)
            loss = criterion(y_pred, y_batch)
            test_losses.append(loss.item())
            _, predicted = torch.max(y_pred.data, 1)
            total_test += y_batch.size(0)
            test_correct += (predicted == y_batch).sum().item()
    
    test_accuracy = 100 * test_correct / total_test
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {np.mean(train_losses):.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {np.mean(test_losses):.4f}, Test Accuracy: {test_accuracy:.2f}%')

# Pruning function
def prune_kan_layer(kan_layer, threshold=1e-2):
    with torch.no_grad():
        for name, param in kan_layer.named_parameters():
            if 'weight' in name:
                param *= (torch.abs(param) > threshold).float()

# Prune the KAN layers after training
prune_kan_layer(kan_model.kan_layer1)
prune_kan_layer(kan_model.kan_layer2)

# Retrain the pruned KAN model
for epoch in range(num_epochs):
    kan_model.train()
    train_losses = []
    train_correct = 0
    total_train = 0
    
    # Training phase
    for X_batch, y_batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False):
        # Forward pass
        y_pred = kan_model(X_batch)
        
        # Compute loss
        loss = criterion(y_pred, y_batch)
        
        # Compute sparsification penalty
        sparsification_penalty = compute_sparsification_penalty(kan_model)
        
        # Total loss
        total_loss = loss + sparsification_penalty
        
        # Backward pass and optimization
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        train_losses.append(total_loss.item())
        _, predicted = torch.max(y_pred.data, 1)
        total_train += y_batch.size(0)
        train_correct += (predicted == y_batch).sum().item()
    
    train_accuracy = 100 * train_correct / total_train
    
    # Testing phase
    kan_model.eval()
    test_losses = []
    test_correct = 0
    total_test = 0
    
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            y_pred = kan_model(X_batch)
            loss = criterion(y_pred, y_batch)
            test_losses.append(loss.item())
            _, predicted = torch.max(y_pred.data, 1)
            total_test += y_batch.size(0)
            test_correct += (predicted == y_batch).sum().item()
    
    test_accuracy = 100 * test_correct / total_test
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {np.mean(train_losses):.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {np.mean(test_losses):.4f}, Test Accuracy: {test_accuracy:.2f}%')

Updated version 3!
Initializing workflow...
Loading existing datasets...
Loaded spike trains dataset: <class 'pandas.core.frame.DataFrame'>
Total time elapsed: 0.01 seconds


                                                    

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.35 GiB. GPU 0 has a total capacity of 9.62 GiB of which 2.10 GiB is free. Including non-PyTorch memory, this process has 7.49 GiB memory in use. Of the allocated memory 6.49 GiB is allocated by PyTorch, and 866.04 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from kan import KANLayer  # Assuming kan is the module containing KANLayer
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from data_processors.pull_and_process_data import master_function
from tqdm import tqdm

class KANModelTrainer:
    def __init__(self, mouse_number, output_dir, timesteps_per_frame, batch_size=16, num_epochs=100, device=None):
        self.mouse_number = mouse_number
        self.output_dir = output_dir
        self.timesteps_per_frame = timesteps_per_frame
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.criterion = nn.CrossEntropyLoss()
        self.lambda_l1 = 1e-4
        self.lambda_entropy = 1e-4

        # Set default hyperparameters
        self.in_dim = None
        self.inter_dim = 250
        self.out_dim = None
        self.num_intervals = 7
        self.k = 6
        self.noise_scale = 0.1
        self.scale_base = 1.0
        self.scale_sp = 1.0
        self.base_fun = nn.SiLU()
        self.grid_eps = 0.02
        self.grid_range = [0, 1]
        self.sp_trainable = True
        self.sb_trainable = True
        self.num_layers = 3
        self.lr = 0.01

    def set_hyperparameters(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

    def prepare_data(self):
        spike_df = master_function(session_number=self.mouse_number, output_dir=self.output_dir, timesteps_per_frame=self.timesteps_per_frame)
        spike_df = spike_df[spike_df['frame'] >= 0]

        X = spike_df.drop('frame', axis=1).values
        y = spike_df['frame'].values.astype(int)

        scaler = MinMaxScaler()
        X_normalized = scaler.fit_transform(X)
        X = X_normalized
        self.out_dim = len(np.unique(y))

        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
        
        self.in_dim = X_train.shape[1]

        X_train = torch.from_numpy(X_train).float().to(self.device)
        y_train = torch.from_numpy(y_train).long().to(self.device)
        X_test = torch.from_numpy(X_test).float().to(self.device)
        y_test = torch.from_numpy(y_test).long().to(self.device)

        train_dataset = TensorDataset(X_train, y_train)
        test_dataset = TensorDataset(X_test, y_test)
        self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        self.test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)

    def create_kan_model(self):
        class CustomKANModel(nn.Module):
            def __init__(self, in_dim, inter_dim, out_dim, num_intervals, k, noise_scale, scale_base, scale_sp, base_fun, grid_eps, grid_range, sp_trainable, sb_trainable, num_layers, device):
                super(CustomKANModel, self).__init__()
                self.layers = nn.ModuleList()
                self.layers.append(KANLayer(in_dim=in_dim, out_dim=inter_dim, num=num_intervals, k=k, noise_scale=noise_scale, 
                                            scale_base=scale_base, scale_sp=scale_sp, base_fun=base_fun, grid_eps=grid_eps,
                                            grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, device=device).to(device))
                for _ in range(num_layers - 2):
                    self.layers.append(KANLayer(in_dim=inter_dim, out_dim=inter_dim, num=num_intervals, k=k, noise_scale=noise_scale, 
                                                scale_base=scale_base, scale_sp=scale_sp, base_fun=base_fun, grid_eps=grid_eps,
                                                grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, device=device).to(device))
                self.layers.append(KANLayer(in_dim=inter_dim, out_dim=out_dim, num=num_intervals, k=k, noise_scale=noise_scale, 
                                            scale_base=scale_base, scale_sp=scale_sp, base_fun=base_fun, grid_eps=grid_eps,
                                            grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, device=device).to(device))
            
            def forward(self, X):
                for layer in self.layers:
                    y_pred, _, _, _ = layer(X)
                    X = y_pred
                return y_pred

        self.kan_model = CustomKANModel(self.in_dim, self.inter_dim, self.out_dim, self.num_intervals, self.k, self.noise_scale, 
                                        self.scale_base, self.scale_sp, self.base_fun, self.grid_eps, self.grid_range, 
                                        self.sp_trainable, self.sb_trainable, self.num_layers, self.device).to(self.device)
        self.optimizer = optim.Adam(self.kan_model.parameters(), lr=self.lr)

    def compute_sparsification_penalty(self, model):
        l1_penalty = 0.0
        entropy_penalty = 0.0
        for param in model.parameters():
            l1_penalty += torch.sum(torch.abs(param))
            param_normalized = torch.abs(param) / torch.sum(torch.abs(param))
            entropy_penalty -= torch.sum(param_normalized * torch.log(param_normalized + 1e-10))
        return self.lambda_l1 * l1_penalty + self.lambda_entropy * entropy_penalty

    def train(self):
        for epoch in range(self.num_epochs):
            self.kan_model.train()
            train_losses = []
            train_correct = 0
            total_train = 0
            
            for X_batch, y_batch in tqdm(self.train_loader, desc=f'Epoch {epoch+1}/{self.num_epochs}', leave=False):
                y_pred = self.kan_model(X_batch)
                loss = self.criterion(y_pred, y_batch)
                sparsification_penalty = self.compute_sparsification_penalty(self.kan_model)
                total_loss = loss + sparsification_penalty
                self.optimizer.zero_grad()
                total_loss.backward()
                self.optimizer.step()
                
                train_losses.append(total_loss.item())
                _, predicted = torch.max(y_pred.data, 1)
                total_train += y_batch.size(0)
                train_correct += (predicted == y_batch).sum().item()
            
            train_accuracy = 100 * train_correct / total_train
            
            self.kan_model.eval()
            test_losses = []
            test_correct = 0
            total_test = 0
            
            with torch.no_grad():
                for X_batch, y_batch in self.test_loader:
                    y_pred = self.kan_model(X_batch)
                    loss = self.criterion(y_pred, y_batch)
                    test_losses.append(loss.item())
                    _, predicted = torch.max(y_pred.data, 1)
                    total_test += y_batch.size(0)
                    test_correct += (predicted == y_batch).sum().item()
            
            test_accuracy = 100 * test_correct / total_test
            
            print(f'Epoch [{epoch+1}/{self.num_epochs}], Train Loss: {np.mean(train_losses):.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {np.mean(test_losses):.4f}, Test Accuracy: {test_accuracy:.2f}%')

    def prune_kan_layer(self, kan_layer, threshold=1e-2):
        with torch.no_grad():
            for name, param in kan_layer.named_parameters():
                if 'weight' in name:
                    param *= (torch.abs(param) > threshold).float()

    def prune_and_retrain(self):
        for layer in self.kan_model.layers:
            self.prune_kan_layer(layer)
        
        for epoch in range(self.num_epochs):
            self.kan_model.train()
            train_losses = []
            train_correct = 0
            total_train = 0
            
            for X_batch, y_batch in tqdm(self.train_loader, desc=f'Epoch {epoch+1}/{self.num_epochs}', leave=False):
                y_pred = self.kan_model(X_batch)
                loss = self.criterion(y_pred, y_batch)
                sparsification_penalty = self.compute_sparsification_penalty(self.kan_model)
                total_loss = loss + sparsification_penalty
                self.optimizer.zero_grad()
                total_loss.backward()
                self.optimizer.step()
                
                train_losses.append(total_loss.item())
                _, predicted = torch.max(y_pred.data, 1)
                total_train += y_batch.size(0)
                train_correct += (predicted == y_batch).sum().item()
            
            train_accuracy = 100 * train_correct / total_train
            
            self.kan_model.eval()
            test_losses = []
            test_correct = 0
            total_test = 0
            
            with torch.no_grad():
                for X_batch, y_batch in self.test_loader:
                    y_pred = self.kan_model(X_batch)
                    loss = self.criterion(y_pred, y_batch)
                    test_losses.append(loss.item())
                    _, predicted = torch.max(y_pred.data, 1)
                    total_test += y_batch.size(0)
                    test_correct += (predicted == y_batch).sum().item()
            
            test_accuracy = 100 * test_correct / total_test
            
            print(f'Epoch [{epoch+1}/{self.num_epochs}], Train Loss: {np.mean(train_losses):.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {np.mean(test_losses):.4f}, Test Accuracy: {test_accuracy:.2f}%')

# Example usage
trainer = KANModelTrainer(mouse_number=715093703, output_dir="output", timesteps_per_frame=1, batch_size=16, num_epochs=6)
trainer.set_hyperparameters(inter_dim=500, num_layers=1, lr=0.01)
trainer.prepare_data()
trainer.create_kan_model()
trainer.train()
trainer.prune_and_retrain()


Updated version 3!
Initializing workflow...
Loading existing datasets...
Loaded spike trains dataset: <class 'pandas.core.frame.DataFrame'>
Total time elapsed: 0.01 seconds


                                                            

Epoch [1/6], Train Loss: 739.1743, Train Accuracy: 0.68%, Test Loss: 4.8930, Test Accuracy: 1.02%


                                                            

Epoch [2/6], Train Loss: 573.5616, Train Accuracy: 1.50%, Test Loss: 4.7925, Test Accuracy: 3.05%


                                                            

Epoch [3/6], Train Loss: 560.4291, Train Accuracy: 11.29%, Test Loss: 3.5963, Test Accuracy: 19.75%


                                                            

Epoch [4/6], Train Loss: 558.4958, Train Accuracy: 38.33%, Test Loss: 2.2982, Test Accuracy: 47.12%


                                                            

Epoch [5/6], Train Loss: 557.6904, Train Accuracy: 59.11%, Test Loss: 1.8608, Test Accuracy: 57.37%


                                                            

Epoch [6/6], Train Loss: 557.2176, Train Accuracy: 68.94%, Test Loss: 1.8196, Test Accuracy: 56.02%


                                                            

Epoch [1/6], Train Loss: 557.0008, Train Accuracy: 74.64%, Test Loss: 1.6900, Test Accuracy: 57.54%


                                                            

Epoch [2/6], Train Loss: 556.9613, Train Accuracy: 76.97%, Test Loss: 1.6114, Test Accuracy: 60.25%


                                                            

Epoch [3/6], Train Loss: 557.0994, Train Accuracy: 75.47%, Test Loss: 2.3321, Test Accuracy: 49.83%


                                                            

KeyboardInterrupt: 

In [None]:
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from kan import KANLayer  # Assuming kan is the module containing KANLayer
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from data_processors.pull_and_process_data import master_function
from tqdm import tqdm

# Check for GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pull and process data
mouse_number = 715093703
spike_df = master_function(session_number=mouse_number, output_dir="output", timesteps_per_frame=1)
spike_df = spike_df[spike_df['frame'] >= 0]

# Prepare data
X = spike_df.drop('frame', axis=1).values
y = spike_df['frame'].values.astype(int)

# Normalize the data
scaler = MinMaxScaler()
X_normalized = scaler.fit_transform(X)
X = X_normalized
num_classes = len(np.unique(y))

# Split data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

# Convert to PyTorch tensors
X_train = torch.from_numpy(X_train).float().to(device)
y_train = torch.from_numpy(y_train).long().to(device)
X_test = torch.from_numpy(X_test).float().to(device)
y_test = torch.from_numpy(y_test).long().to(device)

# Create DataLoader
batch_size = 16
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Define a custom KAN model with an intermediate layer
class CustomKANModel(nn.Module):
    def __init__(self, in_dim, inter_dim, out_dim, num_intervals, k, noise_scale, scale_base, scale_sp, base_fun, grid_eps, grid_range, sp_trainable, sb_trainable, device):
        super(CustomKANModel, self).__init__()
        self.kan_layer1 = KANLayer(in_dim=in_dim, out_dim=inter_dim, num=num_intervals, k=k, noise_scale=noise_scale, 
                                   scale_base=scale_base, scale_sp=scale_sp, base_fun=base_fun, grid_eps=grid_eps,
                                   grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, device=device).to(device)
        self.kan_layer2 = KANLayer(in_dim=inter_dim, out_dim=out_dim, num=num_intervals, k=k, noise_scale=noise_scale, 
                                   scale_base=scale_base, scale_sp=scale_sp, base_fun=base_fun, grid_eps=grid_eps,
                                   grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, device=device).to(device)

    def forward(self, X):
        y_pred1, preacts1, postacts1, postspline1 = self.kan_layer1(X)
        y_pred2, preacts2, postacts2, postspline2 = self.kan_layer2(y_pred1)
        return y_pred2

# Initialize the CustomKANModel
in_dim = X_train.shape[1]  # 2073 in this case
inter_dim = 118  # Example intermediate dimension
out_dim = num_classes
num_intervals = 7
k = 3
noise_scale = 0.1
scale_base = 1.0
scale_sp = 1.0
base_fun = nn.SiLU()
grid_eps = 0.02
grid_range = [-1, 1]
sp_trainable = True
sb_trainable = True

kan_model = CustomKANModel(in_dim=in_dim, inter_dim=inter_dim, out_dim=out_dim, num_intervals=num_intervals, k=k, noise_scale=noise_scale, 
                            scale_base=scale_base, scale_sp=scale_sp, base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range, 
                            sp_trainable=sp_trainable, sb_trainable=sb_trainable, device=device).to(device)

# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(kan_model.parameters(), lr=0.01)

# Sparsification parameters
lambda_l1 = 1e-4
lambda_entropy = 1e-4

# Helper function to compute sparsification penalties
def compute_sparsification_penalty(model):
    l1_penalty = 0.0
    entropy_penalty = 0.0
    for param in model.parameters():
        l1_penalty += torch.sum(torch.abs(param))
        param_normalized = torch.abs(param) / torch.sum(torch.abs(param))
        entropy_penalty -= torch.sum(param_normalized * torch.log(param_normalized + 1e-10))
    return lambda_l1 * l1_penalty + lambda_entropy * entropy_penalty

# Training loop with sparsification
num_epochs = 100
for epoch in range(num_epochs):
    kan_model.train()
    train_losses = []
    train_correct = 0
    total_train = 0
    
    # Training phase
    for X_batch, y_batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False):
        # Forward pass
        y_pred = kan_model(X_batch)
        
        # Compute loss
        loss = criterion(y_pred, y_batch)
        
        # Compute sparsification penalty
        sparsification_penalty = compute_sparsification_penalty(kan_model)
        
        # Total loss
        total_loss = loss + sparsification_penalty
        
        # Backward pass and optimization
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        train_losses.append(total_loss.item())
        _, predicted = torch.max(y_pred.data, 1)
        total_train += y_batch.size(0)
        train_correct += (predicted == y_batch).sum().item()
    
    train_accuracy = 100 * train_correct / total_train
    
    # Testing phase
    kan_model.eval()
    test_losses = []
    test_correct = 0
    total_test = 0
    
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            y_pred = kan_model(X_batch)
            loss = criterion(y_pred, y_batch)
            test_losses.append(loss.item())
            _, predicted = torch.max(y_pred.data, 1)
            total_test += y_batch.size(0)
            test_correct += (predicted == y_batch).sum().item()
    
    test_accuracy = 100 * test_correct / total_test
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {np.mean(train_losses):.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {np.mean(test_losses):.4f}, Test Accuracy: {test_accuracy:.2f}%')

# Pruning function
def prune_kan_layer(kan_layer, threshold=1e-2):
    with torch.no_grad():
        for name, param in kan_layer.named_parameters():
            if 'weight' in name:
                param *= (torch.abs(param) > threshold).float()

# Prune the KAN layers after training
prune_kan_layer(kan_model.kan_layer1)
prune_kan_layer(kan_model.kan_layer2)

# Retrain the pruned KAN model
for epoch in range(num_epochs):
    kan_model.train()
    train_losses = []
    train_correct = 0
    total_train = 0
    
    # Training phase
    for X_batch, y_batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False):
        # Forward pass
        y_pred = kan_model(X_batch)
        
        # Compute loss
        loss = criterion(y_pred, y_batch)
        
        # Compute sparsification penalty
        sparsification_penalty = compute_sparsification_penalty(kan_model)
        
        # Total loss
        total_loss = loss + sparsification_penalty
        
        # Backward pass and optimization
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        train_losses.append(total_loss.item())
        _, predicted = torch.max(y_pred.data, 1)
        total_train += y_batch.size(0)
        train_correct += (predicted == y_batch).sum().item()
    
    train_accuracy = 100 * train_correct / total_train
    
    # Testing phase
    kan_model.eval()
    test_losses = []
    test_correct = 0
    total_test = 0
    
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            y_pred = kan_model(X_batch)
            loss = criterion(y_pred, y_batch)
            test_losses.append(loss.item())
            _, predicted = torch.max(y_pred.data, 1)
            total_test += y_batch.size(0)
            test_correct += (predicted == y_batch).sum().item()
    
    test_accuracy = 100 * test_correct / total_test
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {np.mean(train_losses):.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {np.mean(test_losses):.4f}, Test Accuracy: {test_accuracy:.2f}%')

In [None]:
# Example usage
trainer = KANModelTrainer(mouse_number=715093703, output_dir="output", timesteps_per_frame=1, batch_size=16, num_epochs=100)
trainer.set_hyperparameters(inter_dim=118, num_layers=3, lr=0.01)
trainer.prepare_data()
trainer.create_kan_model()

# Number of repetitions
num_repetitions = 3

# Repeat training and pruning
for i in range(num_repetitions):
    print(f"Repetition {i+1}/{num_repetitions}")
    trainer.train()
    trainer.prune_and_retrain()
