# The first attempt to recreate a KAN

In [118]:
import torch.nn as nn
import torch
import numpy as np
from spline import *
from utils import *
from scipy.interpolate import BSpline
import matplotlib.pyplot as plt

In [119]:
# KAN Layer implementation based off of PyKAN
class KANLayer(nn.Module):

    def __init__(self, in_dim, out_dim, num, k=3, noise_scale=0.1, scale_base=1.0, grid_range=[-1, 1], device='cpu'):
        super(KANLayer, self).__init__()

        # self.func_matrix = nn.Parameter(torch.zeros(in_dim, out_dim))

        # square = lambda x: x**2

        # self.func_matrix[0,0] = square

        self.size = size = out_dim * in_dim
        self.out_dim = out_dim
        self.in_dim = in_dim
        self.num = num
        self.k = k

        self.grid = torch.linspace(grid_range[0], grid_range[1], steps=num + 1, device=device).repeat(size, 1)
        self.grid = torch.nn.Parameter(self.grid).requires_grad_(False)

        # NO IDEA WHAT THIS DOES
        noises = (torch.rand(size, self.grid.shape[1]) - 1 / 2) * noise_scale / num
        noises = noises.to(device)

        self.coef = torch.nn.Parameter(curve2coef(self.grid, noises, self.grid, k, device))

        # NOR THIS
        # if isinstance(scale_base, float):
        #     self.scale_base = torch.nn.Parameter(torch.ones(size, device=device) * scale_base).requires_grad_(sb_trainable) 


        self.mask = torch.nn.Parameter(torch.ones(size, device=device)).requires_grad_(False)
        # self.grid_eps = grid_eps
        self.weight_sharing = torch.arange(size)
        self.lock_counter = 0
        self.lock_id = torch.zeros(size)
        self.device = device


    def forward(self, x):
        batch = x.shape[0]
        # x: shape (batch, in_dim) => shape (size, batch) (size = out_dim * in_dim)
        # x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, device=self.device)).reshape(batch, self.size).permute(1, 0)

        return x


# Only work from here

In [140]:
# create dataset f(x,y) = exp(sin(pi*x)+y^2)
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
dataset['train_input'].shape, dataset['train_label'].shape, dataset['test_input'].shape, dataset['test_label'].shape

X_train, y_train, X_test, y_test = dataset['train_input'], dataset['train_label'], dataset['test_input'], dataset['test_label']

In [203]:
# creating a KAN layer from scratch
class MyKANLayer(nn.Module):

    def __init__(self, in_dim, out_dim, grid, degree=3, noise_scale=0.1, scale_base=1.0, grid_range=[-1, 1], device='cpu'):
        super(MyKANLayer, self).__init__()

        # initiliaze variables about the layer
        self.size = size = out_dim * in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.degree = degree
        self.grid_range = grid_range
        self.cache = None

        # The spline function requires three parameters: knots, coeff, and degree
        # knots: the grid points for the spline
        # self.knots = nn.Parameter(torch.linspace(grid_range[0], grid_range[1], steps=grid + 1, device=device).repeat(size, 1)).requires_grad_(False)
        self.knots = torch.linspace(grid_range[0], grid_range[1], steps=grid + 1, device=device)

        # coeff: the coefficients for the spline - THESE ARE LEARNABLE
        # I am wrapping them in a parameter since that is what they are
        self.coeff = nn.Parameter(torch.rand(size, grid + 1 + degree, device=device)).requires_grad_(True)


    def forward(self, x):

        # we process data in batches!
        print('batch size: ', x.shape[0])
        batch_size = x.shape[0]

        # we need to repeat the input for each spline function
        # x.shape = (size, batch_size)
        x = x.transpose(0, 1).repeat(self.out_dim, 1)

        # store the input for later
        self.cache = x

        # store the output of the spline functions
        out = torch.zeros(self.size, batch_size)


        # OLD WAY: WITH BUILT-IN B-SPLINE FUNCTION, WHICH DOES NOT PROVIDE US WITH GRADIENTS
        # for i in range(self.size):
        #     # x[i].shape = batch_size
        #     spline = BSpline(self.knots[i], self.coeff[i].detach().numpy(), self.degree)
        #     out[i] = torch.tensor(spline(x[i]))


        # loop through all the spline functions, and apply them to a single element for the whole batch
        # TODO: see if we can vectorize this
        for i in range(self.size):
            # Use torch operations to evaluate the B-spline
            knots = self.knots
            print('knots shape: ', knots.shape)
            coeff = self.coeff[i]
            print('coeff shape: ', coeff.shape)
            spline_values = self.evaluate_spline(x[i], knots, coeff, self.degree)
            out[i] = spline_values

        
        # plot the spline functions (optional)
        # self.plot_splines()


        # reshape the output to be of shape (out_dim, in_dim, batch_size)
        # then we sum it as part of the algorithm
        # then we transpose it so subsequent layers can use it
        y = out.reshape(self.out_dim, self.in_dim, batch_size).sum(dim=1).transpose(0, 1)

        return y
    






    def evaluate_spline(self, x, knots, coeff, degree):
        # Implement the B-spline evaluation directly in PyTorch
        # This is a simplified version and assumes a cubic B-spline (degree=3)
        assert degree == 3, "This implementation only supports cubic B-splines (degree=3)"

        # Initialize the B-spline basis functions
        n_knots = len(knots)
        n_coeffs = len(coeff)
        # assert n_knots == n_coeffs + degree + 1, "Mismatch between number of knots and coefficients for cubic B-splines"

        # Implement basis function evaluation (recursively or using a loop)
        # This example uses a loop for simplicity
        B = torch.zeros(x.shape[0], n_coeffs)

        # Basis function calculation
        for i in range(n_coeffs):
            B[:, i] = self.basis_function(x, knots, i, degree)

        # Evaluate the spline
        spline_values = B.matmul(coeff)

        return spline_values
    




    def basis_function(self, x, knots, degree):
        # Compute the B-spline basis function value
        # This is a placeholder for the basis function calculation
        pass

        # You need to implement the Cox-de Boor recursion formula here
        


        
        



    






    # pytorch may automatically handle this for us!!!
    # def backward(self, dupstream):
    #     pass


    # If we want to plot the spline curves of a layer
    def plot_splines(self):
        # Plot the spline functions (optional)
        points = np.linspace(self.grid_range[0], self.grid_range[1], 100)
        for i in range(self.size):
            spline = BSpline(self.knots[i].cpu().numpy(), self.coeff[i].detach().cpu().numpy(), self.degree)
            y = spline(points)
            plt.plot(points, y, label=f'B_{i,3}(points)')
            
        plt.title('Cubic B-spline Basis Functions')
        plt.xlabel('x')
        plt.ylabel('B_{i,3}(points)')
        plt.legend()
        plt.show()


In [208]:
# Testing the layer
layer = MyKANLayer(in_dim=2, out_dim=5, grid=8, degree=3)

out = layer(dataset['train_input'][:2])
out
layer.coeff.shape

batch size:  2
knots shape:  torch.Size([9])
coeff shape:  torch.Size([9])
knots shape:  torch.Size([9])
coeff shape:  torch.Size([9])
knots shape:  torch.Size([9])
coeff shape:  torch.Size([9])
knots shape:  torch.Size([9])
coeff shape:  torch.Size([9])
knots shape:  torch.Size([9])
coeff shape:  torch.Size([9])
knots shape:  torch.Size([9])
coeff shape:  torch.Size([9])
knots shape:  torch.Size([9])
coeff shape:  torch.Size([9])
knots shape:  torch.Size([9])
coeff shape:  torch.Size([9])
knots shape:  torch.Size([9])
coeff shape:  torch.Size([9])
knots shape:  torch.Size([9])
coeff shape:  torch.Size([9])


torch.Size([10, 9])

In [188]:
class MyKAN(nn.Module):
    
    def __init__(self, width=None, grid=3, degree=3, seed=69, device='cpu'):
        super(MyKAN, self).__init__()

        # intialize variables for the KAN
        self.biases = []
        self.act_fun = nn.ModuleList()
        self.depth = len(width) - 1
        self.width = width

        # create the layers here
        for l in range(self.depth):
            kan_layer = MyKANLayer(width[l], width[l+1], grid, degree, device=device)
            self.act_fun.append(kan_layer)

    # x should only be passed in batches
    def forward(self, x):
        for l in range(self.depth):
            x = self.act_fun[l](x)
            
        return x
    

    # # The backward and optimizer_step functions may not be necessary
    # def backward(self, dupstream):
    #     """
    #     Performs backward pass through all layers of the network.

    #     Args:
    #         dupstream: Gradient of loss with respect to output.
    #     """
    #     dx = dupstream

    #     for layer in reversed(self.layers):
    #         dx = layer.backward(dx)

    #     return dx
    

    # def optimizer_step(self, lr):
    #     """
    #     Updates network weights by performing a step in the negative gradient
    #     direction in each layer. The step size is determined by the learning
    #     rate.

    #     Args:
    #         lr: Learning rate to use for update step.
    #     """
    #     for layer in self.layers:
    #         if hasattr(layer, 'weight'):
    #             layer.weight -= layer.weight_grad * lr
    #         if hasattr(layer, 'bias'):
    #             layer.bias -= layer.bias_grad * lr

In [189]:
# test the KAN here
model = MyKAN(width=[2, 5, 1], grid=10, degree=3)

out = model(dataset['train_input'][:10])
out

batch size:  10


AssertionError: Mismatch between number of knots and coefficients for cubic B-splines

In [190]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# stick to MSE Loss for now, later we can do cross entropy
# criterion = nn.CrossEntropyLoss()
criterion = nn.MSELoss()



num_epochs = 10
for epoch in range(num_epochs):
    model.train()  # Set the model to training mode

    optimizer.zero_grad()  # Zero the gradients

    outputs = model(X_train)  # Forward pass
    loss = criterion(outputs, y_train)  # Compute the loss

    loss.backward()  # Backward pass (compute gradients)
    optimizer.step()  # Update the parameters

    if (epoch+1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# After training, you can evaluate the model
model.eval()  # Set the model to evaluation mode
with torch.no_grad():
    test_outputs = model(X_train)  # Forward pass on the training data
    test_loss = criterion(test_outputs, y_train)  # Compute the loss
    print(f'Test Loss: {test_loss.item():.4f}')

batch size:  1000


AssertionError: Mismatch between number of knots and coefficients for cubic B-splines