# The first attempt to recreate a KAN

In [49]:
import torch.nn as nn
import torch
import numpy as np
from spline import *
from utils import *
from scipy.interpolate import BSpline

In [50]:
# KAN Layer implementation based off of PyKAN
class KANLayer(nn.Module):

    def __init__(self, in_dim, out_dim, num, k=3, noise_scale=0.1, scale_base=1.0, grid_range=[-1, 1], device='cpu'):
        super(KANLayer, self).__init__()

        # self.func_matrix = nn.Parameter(torch.zeros(in_dim, out_dim))

        # square = lambda x: x**2

        # self.func_matrix[0,0] = square

        self.size = size = out_dim * in_dim
        self.out_dim = out_dim
        self.in_dim = in_dim
        self.num = num
        self.k = k

        self.grid = torch.linspace(grid_range[0], grid_range[1], steps=num + 1, device=device).repeat(size, 1)
        self.grid = torch.nn.Parameter(self.grid).requires_grad_(False)

        # NO IDEA WHAT THIS DOES
        noises = (torch.rand(size, self.grid.shape[1]) - 1 / 2) * noise_scale / num
        noises = noises.to(device)

        self.coef = torch.nn.Parameter(curve2coef(self.grid, noises, self.grid, k, device))

        # NOR THIS
        # if isinstance(scale_base, float):
        #     self.scale_base = torch.nn.Parameter(torch.ones(size, device=device) * scale_base).requires_grad_(sb_trainable) 


        self.mask = torch.nn.Parameter(torch.ones(size, device=device)).requires_grad_(False)
        # self.grid_eps = grid_eps
        self.weight_sharing = torch.arange(size)
        self.lock_counter = 0
        self.lock_id = torch.zeros(size)
        self.device = device


    def forward(self, x):
        batch = x.shape[0]
        # x: shape (batch, in_dim) => shape (size, batch) (size = out_dim * in_dim)
        # x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, device=self.device)).reshape(batch, self.size).permute(1, 0)

        return x


# Only work from here

In [51]:
# create dataset f(x,y) = exp(sin(pi*x)+y^2)
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
dataset['train_input'].shape, dataset['train_label'].shape

(torch.Size([1000, 2]), torch.Size([1000, 1]))

In [109]:
# creating a KAN layer from scratch
class MyKANLayer(nn.Module):

    def __init__(self, in_dim, out_dim, grid, degree=3, noise_scale=0.1, scale_base=1.0, grid_range=[-1, 1], device='cpu'):
        super(MyKANLayer, self).__init__()

        # initiliaze variables about the layer
        self.size = size = out_dim * in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.degree = degree

        # The spline function requires three parameters: knots, coeff, and degree
        # knots: the grid points for the spline
        self.knots = nn.Parameter(torch.linspace(grid_range[0], grid_range[1], steps=grid + 1, device=device).repeat(size, 1)).requires_grad_(False)

        # coeff: the coefficients for the spline - THESE ARE LEARNABLE
        # I am wrapping them in a parameter since that is what they are
        self.coeff = nn.Parameter(torch.rand(size, grid + 1, device=device)).requires_grad_(True)


    def forward(self, x):

        # we process data in batches!
        print('batch size: ', x.shape[0])
        batch_size = x.shape[0]

        # we need to repeat the input for each spline function
        # x.shape = (size, batch_size)
        x = x.transpose(0, 1).repeat(self.out_dim, 1)

        # store the input for later
        self.cache = x

        # store the output of the spline functions
        out = torch.zeros(self.size, batch_size)

        # loop through all the spline functions, and apply them to a single element for the whole batch
        # TODO: see if we can vectorize this
        for i in range(self.size):
            # x[i].shape = batch_size
            # TODO: see if we can get rid of this detach numpy thingy
            out[i] = torch.tensor(BSpline(self.knots[i], self.coeff[i].detach().numpy(), self.degree)(x[i]))

        # reshape the output to be of shape (out_dim, in_dim, batch_size)
        # then we sum it as part of the algorithm
        # then we transpose it so subsequent layers can use it
        y = out.reshape(self.out_dim, self.in_dim, batch_size).sum(dim=1).transpose(0, 1)

        return y


In [110]:
# Testing the layer
layer = MyKANLayer(in_dim=2, out_dim=5, grid=9, degree=3)

out = layer(dataset['train_input'][:2])
out

batch size:  2


tensor([[0.8266, 0.7423, 0.5339, 1.1364, 0.2176],
        [1.0453, 1.4096, 0.4858, 0.6775, 0.7896]])

In [111]:
class MyKAN(nn.Module):
    
    def __init__(self, width=None, grid=3, degree=3, seed=69, device='cpu'):
        super(MyKAN, self).__init__()

        # intialize variables for the KAN
        self.biases = []
        self.act_fun = []
        self.depth = len(width) - 1
        self.width = width

        # create the layers here
        for l in range(self.depth):
            kan_layer = MyKANLayer(width[l], width[l+1], grid, degree, device=device)
            self.act_fun.append(kan_layer)

    # x should only be passed in batches
    def forward(self, x):
        for l in range(self.depth):
            x = self.act_fun[l](x)
            
        return x

In [115]:
# test the KAN here
kan = MyKAN(width=[2, 5, 1], grid=10, degree=3)

out = kan(dataset['train_input'][:10])
out

batch size:  10
batch size:  10


tensor([[   33.2443],
        [   17.4406],
        [ 7232.3174],
        [ 1155.7405],
        [   33.1464],
        [   37.5522],
        [  -46.5262],
        [-1184.3843],
        [   20.1112],
        [16459.5215]])