<a href="https://colab.research.google.com/github/Bismuthe-32/Adaptive-Basis-Functions-for-Enhanced-Kolmogorov/blob/main/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Adaptive Basis Functions in Kolmogorov–Arnold Networks

This paper proposes adaptive basis functions (ABFs) for Kolmogorov–Arnold Networks (KANs), which let the univariate basis functions adjust dynamically during training. ABF-KANs achieve smoother convergence, improved generalization, and lower mean-squared error on noisy oscillatory and polynomial functions. Results indicate a 7.6% reduction in test error relative to fixed-basis KANs. The method addresses limitations of static basis functions and improves the scalability, flexibility, and applicability of KANs to complex function-approximation tasks.

In [61]:
%pip install ucimlrepo



In [69]:
import torch
from torch import nn
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
from ucimlrepo import fetch_ucirepo
import torch.nn.functional as F


We use the following datasets:

- **Synthetic Polynomial Functions:** Functions of the form  

  $$
  f(x) = \sum_{i=0}^{d} p_i x^i + \epsilon
  $$

  where each polynomial coefficient $p_i$ is sampled from a standard Gaussian distribution, and $\epsilon$ represents additive noise. The domain is $x \in [0,10]$, and the output is normalized to the range $[0,1]$.  

- **Boston Housing:** Predicts house prices using 13 input features.  

- **Airfoil Self-Noise:** Predicts sound pressure levels from 5 aerodynamic input variables.  

- **Energy Efficiency:** Predicts heating and cooling loads from 8 building characteristics.  

The three real-world datasets are standardized to zero mean and unit variance.
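Standardization avoids leaking test information only if the mean and standard deviation are computed on the training split. The sketch below assumes a random 80/20 split and standardizes both inputs and targets. The helper `split_and_standardize`, its split ratio, and its seed are illustrative assumptions, not the settings behind the reported results.

```python
import numpy as np


def split_and_standardize(X, y, test_fraction=0.2, seed=0):
    """Hypothetical helper (split ratio and seed are assumptions):
    random train/test split, then z-score both splits with training statistics."""
    X = np.asarray(X, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64).reshape(len(X), -1)

    rng = np.random.default_rng(seed)
    perm = rng.permutation(len(X))
    n_test = int(round(test_fraction * len(X)))
    test_idx, train_idx = perm[:n_test], perm[n_test:]

    def zscore(train, test):
        mean = train.mean(axis=0)
        std = train.std(axis=0)
        std[std == 0] = 1.0  # guard against constant columns
        return (train - mean) / std, (test - mean) / std

    X_train, X_test = zscore(X[train_idx], X[test_idx])
    y_train, y_test = zscore(y[train_idx], y[test_idx])
    return X_train, X_test, y_train, y_test
```

If only the inputs should be standardized, drop the line that transforms `y`.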


In [51]:
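def generate_synthetic_poly_function(degrees=3, num_points=100, noise_std=0.0):
    """Sample f(x) = sum_i p_i x^i + eps on x in [0, 10] with p_i ~ N(0, 1),
    then min-max normalize the output to [0, 1]."""
    X = torch.linspace(0, 10, num_points)
    pis = torch.randn(degrees + 1)  # polynomial coefficients p_0..p_d

    y = sum(pis[i] * X**i for i in range(degrees + 1))

    if noise_std > 0:
        y = y + torch.randn_like(y) * noise_std  # additive Gaussian noise eps

    # Normalize the output to [0, 1]
    y = (y - y.min()) / (y.max() - y.min()).clamp_min(1e-12)

    return X, y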


In [63]:
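def boston():
    """Boston Housing: returns the 13 input features and the MEDV target (median house value)."""
    data_url = "http://lib.stat.cmu.edu/datasets/boston"
    # Each record in this file spans two physical lines, so the rows are re-joined below
    raw_df = pd.read_csv(data_url, sep=r"\s+", skiprows=22, header=None)
    data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])  # (506, 13)
    target = raw_df.values[1::2, 2]  # (506,)
    return data, target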


In [64]:
def airfoil_selfnoise():
    URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/00291/airfoil_self_noise.dat'
    columns = ['frequency',
               'angle_of_attack',
               'chord_length',
               'free_stream_velocity',
               'suction_side_displacement_thickness',
               'scaled_sound_pressure_level']

    features = ['frequency',
                'angle_of_attack',
                'chord_length',
                'free_stream_velocity',
                'suction_side_displacement_thickness']

    # The file is tab-separated with no header row
    airfoil_dataset = pd.read_csv(URL, sep='\t', header=None, names=columns)
    X = airfoil_dataset[features].values                       # (n_samples, 5)
    y = airfoil_dataset['scaled_sound_pressure_level'].values  # (n_samples,)
    return X, y

In [65]:
def energy_efficiency():
  return fetch_ucirepo(id=242)

# Model Architectures

We compare Adaptive KANs to three other models:

**Fourier Feature Networks**: Fourier Feature Networks use a Fourier feature mapping to project low-dimensional inputs (such as pixel coordinates) into a higher-dimensional space of sines and cosines. This lets standard Multi-Layer Perceptrons (MLPs) learn complex, high-frequency functions efficiently.

**Multilayer Perceptrons (MLPs)**: MLPs are fundamental, fully connected feedforward neural networks with input, hidden, and output layers. They use non-linear activation functions (such as ReLU or sigmoid) to learn complex, non-linear patterns in structured data.

**KANs with fixed basis functions**: Kolmogorov–Arnold Networks (KANs) are a neural network architecture that replaces the fixed node activations of traditional MLPs with learnable univariate functions (splines) on the network's edges. The goal is better accuracy, interpretability, and efficiency in complex function approximation.
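To make the Fourier feature mapping concrete, the sketch below encodes an input $\mathbf{x}$ as $[\mathbf{a}\cos(\pi\mathbf{x}B), \mathbf{a}\sin(\pi\mathbf{x}B)]$. This is the same form computed in `FourierFeatureMLP.forward` below. Three things are assumptions: $B$ is drawn from a Gaussian with scale `sigma`, the amplitudes $\mathbf{a}$ are all one, and the helper names are hypothetical.

```python
import math
import torch


def gaussian_fourier_params(num_inputs, num_frequencies, sigma=10.0, seed=0):
    """Hypothetical helper: B ~ N(0, sigma^2) of shape (num_inputs, num_frequencies),
    with unit amplitudes a (both choices are assumptions)."""
    gen = torch.Generator().manual_seed(seed)
    b_values = torch.randn(num_inputs, num_frequencies, generator=gen) * sigma
    a_values = torch.ones(num_frequencies)
    return a_values, b_values


def fourier_encode(x, a_values, b_values):
    """Map x of shape (batch, num_inputs) to (batch, 2 * num_frequencies)."""
    encoded = math.pi * x @ b_values  # same pi scaling as FourierFeatureMLP.forward
    return torch.cat([a_values * encoded.cos(), a_values * encoded.sin()], dim=-1)
```

The returned tensors have the shapes that `FourierFeatureMLP` asserts on: `b_values` is `(num_inputs, num_frequencies)` and `a_values` has length `num_frequencies`, so the encoded width is `2 * num_frequencies`.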

Many thanks to Professor Matthew Johnson at the University of Cambridge Department of Engineering for providing the PyTorch implementation of Fourier Feature Networks. The full repository on FFNs can be found [here](https://github.com/matajoh/fourier_feature_nets).

In [68]:
import math  # needed for math.pi in forward()
from typing import Optional, List

class FourierFeatureMLP(nn.Module):
    """MLP using Fourier features as a preprocessing step."""

    def __init__(self,
                 num_inputs: int,
                 num_outputs: int,
                 a_values: Optional[torch.Tensor],
                 b_values: Optional[torch.Tensor],
                 layer_channels: List[int]):
        super().__init__()
        self.num_inputs = num_inputs

        # Fourier feature encoding
        if b_values is None:
            self.a_values = None
            self.b_values = None
            encoded_inputs = num_inputs
        else:
            assert b_values.shape[0] == num_inputs, "b_values first dim must match num_inputs"
            assert a_values.shape[0] == b_values.shape[1], "a_values shape mismatch"
            self.a_values = nn.Parameter(a_values, requires_grad=False)
            self.b_values = nn.Parameter(b_values, requires_grad=False)
            encoded_inputs = b_values.shape[1] * 2

        # Store params for reproducibility
        self.params = {
            "num_inputs": num_inputs,
            "num_outputs": num_outputs,
            "a_values": None if a_values is None else a_values.tolist(),
            "b_values": None if b_values is None else b_values.tolist(),
            "layer_channels": layer_channels
        }

        # Build MLP layers
        self.layers = nn.ModuleList()
        for num_channels in layer_channels:
            self.layers.append(nn.Linear(encoded_inputs, num_channels))
            encoded_inputs = num_channels
        self.layers.append(nn.Linear(encoded_inputs, num_outputs))

        # Activation tracking
        self.keep_activations = False
        self.activations: List = []

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.b_values is not None:
            encoded = math.pi * x @ self.b_values
            x = torch.cat([self.a_values * encoded.cos(),
                           self.a_values * encoded.sin()], dim=-1)

        self.activations.clear()
        for layer in self.layers[:-1]:
            x = torch.relu(layer(x))
        if self.keep_activations:
            self.activations.append(x.detach().cpu().numpy())
        x = self.layers[-1](x)
        return x

# MLP

We use an MLP with a single hidden layer and a ReLU activation.
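The models are compared on mean-squared error, so the same loop can train any of them. The sketch below is a minimal full-batch version that uses Adam. The function name `fit`, the optimizer, learning rate, epoch count, and toy target are all assumptions, not the notebook's training configuration.

```python
import torch
from torch import nn


def fit(model, X, y, epochs=500, lr=1e-2):
    """Hypothetical full-batch loop (optimizer, lr and epochs are assumptions):
    minimize mean-squared error with Adam and return the per-epoch losses."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    losses = []
    model.train()
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = loss_fn(model(X), y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


# One-hidden-layer ReLU MLP, structurally the same as SimpleMLP(1, 32, 1)
torch.manual_seed(0)
mlp = nn.Sequential(nn.Linear(1, 32), nn.ReLU(), nn.Linear(32, 1))
X = torch.linspace(0, 1, 64).unsqueeze(1)  # (64, 1) inputs
y = X ** 2                                  # smooth toy target
losses = fit(mlp, X, y)
```

`KAN.forward` and `AdaptiveKAN.forward` both take the input tensor as their first argument, so those models can also be passed to `fit` unchanged.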

In [72]:
class SimpleMLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleMLP, self).__init__()
        # Define the layers
        self.fc1 = nn.Linear(input_size, hidden_size)  # Fully connected layer 1
        self.fc2 = nn.Linear(hidden_size, output_size) # Fully connected layer 2

    def forward(self, x):
        # Define the forward pass
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

# KAN with fixed basis functions

The network below is a Kolmogorov–Arnold network with fixed basis functions. Many thanks to Blealtan for creating the following code snippet; the full repository can be found [here](https://github.com/Blealtan/efficient-kan/tree/master).
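Each `KANLinear` edge computes a SiLU base term plus a learned combination of B-spline bases. `b_splines` evaluates those bases with the Cox–de Boor recursion on a uniform grid that is extended by `spline_order` knots on each side. The sketch below mirrors that recursion for a single input feature and omits the class's vectorization over `in_features`. `bspline_bases_1d` is a hypothetical helper.

```python
import torch


def bspline_bases_1d(x, grid_size=5, spline_order=3, grid_range=(-1.0, 1.0)):
    """Hypothetical 1D version of KANLinear.b_splines.
    x: (batch,) -> bases: (batch, grid_size + spline_order)."""
    h = (grid_range[1] - grid_range[0]) / grid_size
    # grid_size + 2 * spline_order + 1 uniform knots, extended past grid_range on both sides
    grid = torch.arange(-spline_order, grid_size + spline_order + 1) * h + grid_range[0]
    x = x.unsqueeze(-1)  # (batch, 1) so it broadcasts against the knots
    # Order-0 bases: indicator of each half-open knot interval [t_i, t_{i+1})
    bases = ((x >= grid[:-1]) & (x < grid[1:])).to(x.dtype)
    for k in range(1, spline_order + 1):
        # Cox-de Boor: blend neighbouring order-(k-1) bases with linear ramps
        left = (x - grid[: -(k + 1)]) / (grid[k:-1] - grid[: -(k + 1)]) * bases[:, :-1]
        right = (grid[k + 1 :] - x) / (grid[k + 1 :] - grid[1:(-k)]) * bases[:, 1:]
        bases = left + right
    return bases
```

Inside `grid_range` the cubic bases are non-negative and sum to one. Each spline term is therefore a local weighted average of its coefficients.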

In [73]:
import torch
import torch.nn.functional as F
import math


class KANLinear(torch.nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        assert x.size(-1) == self.in_features
        original_shape = x.shape
        x = x.reshape(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output

        output = output.reshape(*original_shape[:-1], self.out_features)
        return output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a crude simulation of the original L1 regularization as stated in the
        paper, since the original one requires computing absolutes and entropy from the
        expanded (batch, in_features, out_features) intermediate tensor, which is hidden
        behind the F.linear function if we want a memory-efficient implementation.

        The L1 regularization is now computed as the mean absolute value of the spline
        weights. The authors' implementation also includes this term in addition to the
        sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


class KAN(torch.nn.Module):
    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KAN, self).__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        self.layers = torch.nn.ModuleList()
        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                KANLinear(
                    in_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
            )

    def forward(self, x: torch.Tensor, update_grid=False):
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )
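`regularization_loss` above reduces each edge's spline to the mean absolute value of its coefficients. It then adds the sum of these values (an L1 term) to the entropy of their normalized distribution. The standalone function below reproduces that arithmetic on a raw weight tensor so the two terms can be checked by hand. The name `spline_regularization_terms` is hypothetical.

```python
import math
import torch


def spline_regularization_terms(spline_weight):
    """Hypothetical helper mirroring KANLinear.regularization_loss term by term.
    spline_weight: (out_features, in_features, grid_size + spline_order)."""
    l1_fake = spline_weight.abs().mean(-1)   # (out, in): one magnitude per edge
    activation_term = l1_fake.sum()          # L1 term summed over all edges
    p = l1_fake / activation_term            # edge magnitudes as a distribution
    entropy_term = -torch.sum(p * p.log())   # small when a few edges dominate
    return activation_term, entropy_term


# Six equal edges: the L1 term is 6 and the entropy is log(6), its maximum for six edges
act, ent = spline_regularization_terms(torch.ones(2, 3, 4))
max_entropy = math.log(6)
```

If every coefficient of some edge is exactly zero, that edge has `p = 0`. `p * p.log()` is then NaN, both in this sketch and in the original method.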

In [74]:
class AdaptiveKAN(KAN):
    """
    KAN with adaptive basis functions.
    Automatically updates the spline grid based on input distribution
    at each forward pass if `update_grid=True`.
    """

    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
        update_grid_every_forward=True,  # controls auto adaptation
    ):
        super().__init__(
            layers_hidden,
            grid_size=grid_size,
            spline_order=spline_order,
            scale_noise=scale_noise,
            scale_base=scale_base,
            scale_spline=scale_spline,
            base_activation=base_activation,
            grid_eps=grid_eps,
            grid_range=grid_range,
        )
        self.update_grid_every_forward = update_grid_every_forward

    def forward(self, x: torch.Tensor, update_grid=False):
        # Adapt the grids if enabled for every forward pass or requested for this call
        adaptive_update = self.update_grid_every_forward or update_grid
        for layer in self.layers:
            if adaptive_update:
                layer.update_grid(x)
            x = layer(x)
        return x
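At each grid update, `KANLinear.update_grid` places the `grid_size + 1` interior knots by blending quantiles of the current batch with a uniform grid: `grid = grid_eps * uniform + (1 - grid_eps) * adaptive`. It then pads `spline_order` uniformly spaced knots on each side and refits the spline coefficients by least squares. The sketch below covers only the knot placement, for a single feature. The name `adaptive_knots_1d` is hypothetical. It computes the quantile indices by truncating a float `linspace`, which is an assumption about rounding.

```python
import torch


def adaptive_knots_1d(x, grid_size=5, spline_order=3, grid_eps=0.02, margin=0.01):
    """Hypothetical 1D sketch of the knot placement in KANLinear.update_grid.
    x: (batch,) -> knots: (grid_size + 2 * spline_order + 1,)."""
    x_sorted = torch.sort(x)[0]
    batch = x_sorted.numel()
    # Data-driven knots: evenly spaced order statistics (empirical quantiles)
    idx = torch.linspace(0, batch - 1, grid_size + 1).to(torch.int64)
    grid_adaptive = x_sorted[idx]
    # Uniform knots spanning the data range plus a small margin on both ends
    uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / grid_size
    grid_uniform = (torch.arange(grid_size + 1, dtype=x.dtype) * uniform_step
                    + x_sorted[0] - margin)
    # grid_eps = 1 gives a uniform grid; grid_eps = 0 gives pure quantiles
    grid = grid_eps * grid_uniform + (1 - grid_eps) * grid_adaptive
    # Extend by spline_order uniformly spaced knots on each side
    left = grid[:1] - uniform_step * torch.arange(spline_order, 0, -1, dtype=x.dtype)
    right = grid[-1:] + uniform_step * torch.arange(1, spline_order + 1, dtype=x.dtype)
    return torch.cat([left, grid, right])
```

`AdaptiveKAN` defaults to `update_grid_every_forward=True`, so the knots also move to match whatever batch is passed at evaluation time.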
