# In [None]:
import torch
import torch.nn.functional as F
import math


class Layer(torch.nn.Module):
    """One KAN layer: a base linear branch plus a learnable B-spline branch.

    ``forward`` returns ``F.linear(base_activation(x), w_b)`` added to a linear
    combination of the B-spline basis values of ``x`` weighted by ``c_i``.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        grid_range: ``(low, high)`` interval covered by the un-extended grid.
        grid_size: Number of equal-width intervals the grid range is split
            into (so ``grid_size + 1`` knots lie inside the range).
        spline_order: Polynomial degree of the B-spline pieces (3 = cubic).
            The knot vector is extended by this many knots on each side.
        sigma: Scale of the uniform noise used to initialise the spline
            coefficients (the sigma of the paper's initialisation footnote).
        base_activation: Either an ``torch.nn.Module`` instance, which is used
            as-is, or a zero-argument factory / module class, which is called
            once to create the activation.
    """

    def __init__(
        self,
        in_features = 2,
        out_features = 5,
        grid_range = (-1, 1),
        grid_size = 5,
        spline_order = 3,
        sigma = 0.1,
        base_activation = torch.nn.SiLU,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order
        self.sigma = sigma

        # Uniform knot vector, extended by `spline_order` knots on each side and
        # replicated per input feature: (in_features, grid_size + 2*spline_order + 1).
        grid_spacing = (grid_range[1] - grid_range[0]) / grid_size
        knots = (
            torch.arange(-spline_order, grid_size + spline_order + 1) * grid_spacing
            + grid_range[0]
        )
        # Registered as a buffer so it follows .to()/.cuda()/.double();
        # non-persistent so the state_dict keys stay exactly {w_b, c_i}.
        self.register_buffer(
            "grid", knots.expand(in_features, -1).contiguous(), persistent=False
        )

        # w_b in 2.10 in the paper
        self.w_b = torch.nn.Parameter(torch.empty(out_features, in_features))
        # C_is in 2.12 in the paper
        self.c_i = torch.nn.Parameter(
            torch.empty(out_features, in_features, grid_size + spline_order)
        )

        # Accept both an already-built module and a class/factory. Calling a
        # module instance with no arguments (what the old code did for the
        # default) raises TypeError.
        if isinstance(base_activation, torch.nn.Module):
            self.base_activation = base_activation
        else:
            self.base_activation = base_activation()

        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise ``w_b`` (Kaiming uniform) and ``c_i`` (noise fit).

        ``c_i`` is set to the least-squares spline coefficients that reproduce
        small uniform noise sampled at the ``grid_size + 1`` in-range knots.
        """
        torch.nn.init.kaiming_uniform_(self.w_b, a=math.sqrt(5))

        noise = (
            torch.rand(
                self.grid_size + 1,
                self.in_features,
                self.out_features,
                dtype=self.grid.dtype,
                device=self.grid.device,
            )
            - 1 / 2
        ) * self.sigma / self.grid_size

        # Explicit end index: a `[k:-k]` slice would be empty for spline_order == 0.
        first = self.spline_order
        inner_knots = self.grid.T[first : first + self.grid_size + 1]

        with torch.no_grad():
            self.c_i.copy_(self.curve2coeff(inner_knots, noise))

    @staticmethod
    def _guarded_ratio(numerator, span):
        """Return ``numerator / span``, with 0 wherever ``span`` is exactly 0."""
        nonzero = span != 0
        safe_span = torch.where(nonzero, span, torch.ones_like(span))
        return numerator / safe_span * nonzero

    def b_splines(self, x):
        """Evaluate all B-spline basis functions at ``x``.

        Uses the Cox-de Boor recursion, vectorised over samples, features and
        basis index. Terms whose knot span is zero are dropped (treated as 0).

        Args:
            x: Tensor of shape ``(n, in_features)``.

        Returns:
            Tensor of shape ``(n, in_features, grid_size + spline_order)``.
        """
        grid = self.grid                 # (in_features, n_knots)
        x = x.unsqueeze(-1)              # (n, in_features, 1)

        # Order 0: indicator of the half-open knot interval containing x.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)

        # Raise the order one step at a time; each step drops one basis function.
        for k in range(1, self.spline_order + 1):
            left = self._guarded_ratio(
                x - grid[:, : -(k + 1)],
                grid[:, k:-1] - grid[:, : -(k + 1)],
            ) * bases[:, :, :-1]
            right = self._guarded_ratio(
                grid[:, k + 1 :] - x,
                grid[:, k + 1 :] - grid[:, 1:-k],
            ) * bases[:, :, 1:]
            bases = left + right

        return bases

    def curve2coeff(self, x, y):
        """Least-squares spline coefficients that fit samples ``(x, y)``.

        Args:
            x: Sample locations, shape ``(n, in_features)``.
            y: Target values, shape ``(n, in_features, out_features)``.

        Returns:
            Tensor of shape ``(out_features, in_features, grid_size + spline_order)``.
        """
        # (in_features, n, grid_size + spline_order)
        design = self.b_splines(x).transpose(0, 1)
        # (in_features, n, out_features)
        targets = y.transpose(0, 1)
        # One independent least-squares problem per input feature:
        # (in_features, grid_size + spline_order, out_features)
        solution = torch.linalg.lstsq(design, targets).solution
        return solution.permute(2, 0, 1)

    def forward(self, x):
        """Apply the layer to the last dimension of ``x``.

        Args:
            x: Tensor whose last dimension has size ``in_features``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            size ``out_features``.
        """
        original_shape = x.shape
        x = x.reshape(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.w_b)
        # Flatten (in_features, n_basis) so the spline branch is one matmul.
        spline_output = F.linear(
            self.b_splines(x).reshape(x.size(0), -1),
            self.c_i.reshape(self.out_features, -1),
        )

        output = base_output + spline_output
        return output.reshape(*original_shape[:-1], self.out_features)


class KAN(torch.nn.Module):
    """Sequential stack of ``Layer`` modules.

    ``layers_hidden`` lists the feature widths; one ``Layer`` is created for
    every consecutive pair ``(layers_hidden[i], layers_hidden[i + 1])``. The
    remaining arguments are forwarded unchanged to every ``Layer``.
    """

    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        sigma=0.1,
        base_activation=torch.nn.SiLU,
        grid_range=[-1, 1],
    ):
        super().__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Settings shared by every layer in the stack.
        shared = dict(
            grid_size=grid_size,
            spline_order=spline_order,
            sigma=sigma,
            base_activation=base_activation,
            grid_range=grid_range,
        )
        width_pairs = zip(layers_hidden, layers_hidden[1:])
        self.layers = torch.nn.ModuleList(
            [Layer(n_in, n_out, **shared) for n_in, n_out in width_pairs]
        )

    def forward(self, x: torch.Tensor):
        """Feed ``x`` through every layer in order and return the result."""
        out = x
        for stage in self.layers:
            out = stage(out)
        return out