In [1]:
!pip install torch torchvision tqdm pdm-backend



In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import math

In [3]:


# Define the KANLinear class
class KANLinear(torch.nn.Module):
    """
    A linear-like layer whose output is the sum of a base branch and a B-spline branch.

    The base branch computes ``F.linear(base_activation(x), base_weight)``; the spline
    branch evaluates B-spline bases of the input on a per-feature knot grid and combines
    them linearly with ``spline_weight`` (optionally rescaled by ``spline_scaler``).

    Args:
        in_features (int): Size of each input sample.
        out_features (int): Size of each output sample.
        grid_size (int, optional): Number of grid intervals spanning ``grid_range``
            (the grid has ``grid_size + 1`` knots inside the range, plus ``spline_order``
            extra knots on each side). Default is 5.
        spline_order (int, optional): Order of the spline used in the interpolation. Default is 3.
        scale_noise (float, optional): Scaling factor for the noise used to initialize the spline weights. Default is 0.1.
        scale_base (float, optional): Scaling factor used when initializing the base weights. Default is 1.0.
        scale_spline (float, optional): Scaling factor used when initializing the spline weights / scaler. Default is 1.0.
        enable_standalone_scale_spline (bool, optional): If True, a learnable per-edge scale
            (``spline_scaler``) multiplies the spline weights. Default is True.
        base_activation (callable, optional): Activation class; it is instantiated with no
            arguments and applied to the input *before* the base linear transformation.
            Default is SiLU.
        grid_eps (float, optional): Mixing factor between the uniform grid and the
            data-adaptive grid in ``update_grid``. Default is 0.02.
        grid_range (list, optional): Range of the grid values; only read, never mutated.
            Default is [-1, 1].

    Attributes:
        grid (torch.Tensor): Buffer of shape (in_features, grid_size + 2 * spline_order + 1)
            holding the knots for every input feature.
        base_weight (torch.nn.Parameter): Weight of the base linear transformation,
            shape (out_features, in_features).
        spline_weight (torch.nn.Parameter): Spline coefficients,
            shape (out_features, in_features, grid_size + spline_order).
        spline_scaler (torch.nn.Parameter): Per-edge scale of the spline weights,
            shape (out_features, in_features); only present if enabled.
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Width of one grid interval
        h = (grid_range[1] - grid_range[0]) / grid_size

        # Uniform knots covering grid_range, extended by spline_order knots on each
        # side, replicated once per input feature.
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        # Allocate (uninitialized) parameters; reset_parameters() fills them in below.
        self.base_weight = torch.nn.Parameter(torch.empty(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.empty(out_features, in_features, grid_size + spline_order)
        )

        # Optional learnable per-edge scale for the spline weights
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.empty(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize base weights, spline weights and (if enabled) the spline scaler."""
        # scale_base enters through the `a` argument of kaiming_uniform_.
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            # Small uniform noise in [-0.5, 0.5) * scale_noise / grid_size, one value per
            # in-range knot and per (input, output) pair.
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            # Fit spline coefficients so the splines pass (in the least-squares sense)
            # through the noise values at the in-range knots. scale_spline is applied
            # here only when there is no standalone scaler to carry it.
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = self.grid
        x = x.unsqueeze(-1)

        # Order-0 bases: indicator of the half-open knot interval containing x.
        # Inputs outside [first knot, last knot) therefore get all-zero bases.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)

        # Raise the order one step at a time; each step blends two neighbouring
        # lower-order bases and shortens the last dimension by one.
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute spline coefficients that fit the given points in the least-squares sense.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        # One least-squares problem per input feature:
        #   A: (in_features, batch_size, grid_size + spline_order)
        #   B: (in_features, batch_size, out_features)
        A = self.b_splines(x).transpose(0, 1)
        B = y.transpose(0, 1)
        solution = torch.linalg.lstsq(A, B).solution
        # (in_features, coeff, out_features) -> (out_features, in_features, coeff)
        result = solution.permute(2, 0, 1)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        """Return the spline weights, multiplied by ``spline_scaler`` when it is enabled."""
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        """
        Forward pass through the KANLinear layer.

        Args:
            x (torch.Tensor): Input tensor whose last dimension is in_features;
                any leading dimensions are preserved.

        Returns:
            torch.Tensor: Output tensor with the same leading dimensions and a last
            dimension of out_features.
        """
        assert x.size(-1) == self.in_features
        original_shape = x.shape
        x = x.reshape(-1, self.in_features)

        # Base branch: activation on the input, then the linear map
        base_output = F.linear(self.base_activation(x), self.base_weight)

        # Spline branch: flatten (in_features, coeff) so it is a single linear map
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )

        # Combine the two branches and restore the leading dimensions
        output = base_output + spline_output
        output = output.reshape(*original_shape[:-1], self.out_features)
        return output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        """
        Re-fit the knot grid to the distribution of ``x`` and refit the spline
        coefficients so the spline outputs on ``x`` are preserved (least squares).

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            margin (float, optional): Margin added to the grid boundaries. Default is 0.01.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        # Per-edge spline outputs on x under the *current* grid; these are the
        # targets for the refit. Shape after permute: (batch, in, out).
        splines = self.b_splines(x).permute(1, 0, 2)
        orig_coeff = self.scaled_spline_weight.permute(1, 2, 0)
        unreduced_spline_output = torch.bmm(splines, orig_coeff).permute(1, 0, 2)

        # Adaptive knots: grid_size + 1 evenly spaced order statistics per feature
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device)
        ]

        # Uniform knots spanning [min - margin, max + margin] per feature
        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(self.grid_size + 1, dtype=torch.float32, device=x.device).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        # Blend the two grids, then extend by spline_order uniform steps on each side
        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        grid = torch.cat(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        # Install the new grid first: curve2coeff evaluates bases on self.grid
        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute a regularization loss from the spline weights.

        The "activation" term is the sum over edges of the mean absolute spline
        coefficient; the "entropy" term is the entropy of those per-edge values
        normalized to sum to one. Edges whose coefficients are all zero contribute
        zero to the entropy (the limit of ``p * log(p)`` as ``p -> 0``) instead of
        producing NaN, and an all-zero weight tensor yields a loss of zero.

        Args:
            regularize_activation (float, optional): Weight for activation regularization loss. Default is 1.0.
            regularize_entropy (float, optional): Weight for entropy regularization loss. Default is 1.0.

        Returns:
            torch.Tensor: Regularization loss value (0-dim tensor).
        """
        # Smallest positive normal number of the weight dtype: clamping with it leaves
        # every non-degenerate value untouched but keeps 0/0 and log(0) out of the graph.
        eps = torch.finfo(self.spline_weight.dtype).tiny

        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation.clamp_min(eps)
        regularization_loss_entropy = -torch.sum(p * p.clamp_min(eps).log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


In [4]:
# Define KAN Class
class KAN(torch.nn.Module):
    """
    Kolmogorov-Arnold Network built as a stack of KANLinear layers.

    Consecutive entries of ``layers_hidden`` define one KANLinear layer each, so a
    list of N sizes produces N - 1 layers. All layers share the same spline and
    initialization settings.

    Parameters:
    - layers_hidden (list of int): Layer widths, from the input size (first element)
      to the output size (last element).
    - grid_size (int, optional): Grid size passed to every KANLinear layer. Default is 5.
    - spline_order (int, optional): Spline order passed to every KANLinear layer. Default is 3.
    - scale_noise (float, optional): Noise scale passed to every KANLinear layer. Default is 0.1.
    - scale_base (float, optional): Base-weight scale passed to every KANLinear layer. Default is 1.0.
    - scale_spline (float, optional): Spline-weight scale passed to every KANLinear layer. Default is 1.0.
    - base_activation (callable, optional): Activation class passed to every KANLinear layer. Default is SiLU.
    - grid_eps (float, optional): Grid mixing factor passed to every KANLinear layer. Default is 0.02.
    - grid_range (list of float, optional): Grid range passed to every KANLinear layer. Default is [-1, 1].
    """

    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super().__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Settings that are identical for every layer in the stack
        shared_kwargs = dict(
            grid_size=grid_size,
            spline_order=spline_order,
            scale_noise=scale_noise,
            scale_base=scale_base,
            scale_spline=scale_spline,
            base_activation=base_activation,
            grid_eps=grid_eps,
            grid_range=grid_range,
        )

        # One KANLinear per pair of consecutive widths
        self.layers = torch.nn.ModuleList(
            [
                KANLinear(width_in, width_out, **shared_kwargs)
                for width_in, width_out in zip(layers_hidden, layers_hidden[1:])
            ]
        )

    def forward(self, x: torch.Tensor, update_grid=False):
        """
        Run the input through every layer in order.

        Parameters:
        - x (torch.Tensor): Input tensor.
        - update_grid (bool, optional): When True, each layer's grid is updated from
          the activations it is about to receive, before that layer is applied.
          Default is False.

        Returns:
        - torch.Tensor: Output of the last layer.
        """
        activations = x
        for kan_layer in self.layers:
            if update_grid:
                kan_layer.update_grid(activations)
            activations = kan_layer(activations)
        return activations

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Sum the regularization losses of all layers.

        Parameters:
        - regularize_activation (float, optional): Weight for the activation regularization term. Default is 1.0.
        - regularize_entropy (float, optional): Weight for the entropy regularization term. Default is 1.0.

        Returns:
        - torch.Tensor: Total regularization loss (the integer 0 if there are no layers).
        """
        total = 0
        for kan_layer in self.layers:
            total = total + kan_layer.regularization_loss(
                regularize_activation, regularize_entropy
            )
        return total


In [5]:

# Preprocessing shared by the training and validation splits:
# convert each image to a tensor, then normalize with mean=0.5 and std=0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# MNIST training split (train=True), stored under ./data and downloaded if missing.
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)

# MNIST test split (train=False), used here as the validation set.
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)

# Mini-batch loaders: the training data is reshuffled every epoch,
# the validation data is iterated in a fixed order.
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)


In [6]:


# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# KAN for MNIST: 28x28 flattened inputs -> 64 hidden units -> 10 output classes
# (digits 0-9), placed on the selected device.
model = KAN([28 * 28, 64, 10]).to(device)

# AdamW over all model parameters, with an initial learning rate of 1e-3 and
# a weight decay of 1e-4 to discourage overfitting.
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

# Exponential schedule: every scheduler step multiplies the learning rate by
# gamma=0.8, i.e. a 20% reduction per step.
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)

# Multi-class classification loss; CrossEntropyLoss takes raw logits and
# combines log-softmax with negative log-likelihood.
criterion = nn.CrossEntropyLoss()


In [7]:
from tqdm import tqdm

# Training loop for 5 epochs
for epoch in range(5):
    # Set the model to training mode
    model.train()

    # Progress bar over the training batches
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            # Flatten the images and move the batch to the appropriate device (once)
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)

            # Zero the gradients before the backward pass
            optimizer.zero_grad()

            # Forward pass and loss
            output = model(images)
            loss = criterion(output, labels)

            # Backward pass and parameter update
            loss.backward()
            optimizer.step()

            # Accuracy of the current batch (for display only)
            accuracy = (output.argmax(dim=1) == labels).float().mean()

            # Show the current loss, accuracy, and learning rate on the progress bar
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item(), lr=optimizer.param_groups[0]['lr'])

    # Validation loop
    model.eval()  # Set the model to evaluation mode (disables dropout, batchnorm, etc.)
    val_loss = 0.0
    val_correct = 0
    val_samples = 0

    # Disable gradient calculation for validation to save memory and computation
    with torch.no_grad():
        for images, labels in valloader:
            # Flatten the images and move the batch to the appropriate device
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)

            # Forward pass: compute the model output
            output = model(images)

            # Accumulate per-sample totals. `criterion` returns the batch *mean*, so
            # it is weighted by the batch size; averaging batch means directly would
            # over-weight a smaller final batch.
            batch_size = labels.size(0)
            val_loss += criterion(output, labels).item() * batch_size
            val_correct += (output.argmax(dim=1) == labels).sum().item()
            val_samples += batch_size

    # Per-sample averages over the whole validation set (guard against an empty loader)
    val_loss /= max(val_samples, 1)
    val_accuracy = val_correct / max(val_samples, 1)

    # Update the learning rate according to the scheduler (once per epoch)
    scheduler.step()

    # Print the results for the current epoch
    print(f"Epoch {epoch + 1}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")


100%|██████████| 938/938 [00:53<00:00, 17.63it/s, accuracy=0.938, loss=0.21, lr=0.001]


Epoch 1, Val Loss: 0.2147, Val Accuracy: 0.9358


100%|██████████| 938/938 [00:53<00:00, 17.44it/s, accuracy=0.969, loss=0.174, lr=0.0008]


Epoch 2, Val Loss: 0.1580, Val Accuracy: 0.9538


100%|██████████| 938/938 [00:57<00:00, 16.22it/s, accuracy=0.969, loss=0.194, lr=0.00064]


Epoch 3, Val Loss: 0.1239, Val Accuracy: 0.9625


100%|██████████| 938/938 [00:55<00:00, 16.97it/s, accuracy=1, loss=0.0188, lr=0.000512]


Epoch 4, Val Loss: 0.1121, Val Accuracy: 0.9679


100%|██████████| 938/938 [00:58<00:00, 16.09it/s, accuracy=1, loss=0.0133, lr=0.00041]


Epoch 5, Val Loss: 0.1052, Val Accuracy: 0.9690
