In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import *
from torch.autograd import Function

In [None]:
from pathlib import Path
import numpy as np
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
# MNIST dataset class used by the training cells below.
from torchvision.datasets.mnist import MNIST
from torchvision import transforms
from tqdm import tqdm, trange
from torch.optim import Adam

# Seed NumPy's and PyTorch's global random number generators with a fixed value.
np.random.seed(42)
torch.manual_seed(42)

In [2]:

class RSWAFFunction(Function):
    """Autograd function for the basis ``sech^2((x - g) * s)``.

    ``x`` is the input, ``g`` a 1-D grid of centres and ``s`` the scalar
    ``inv_denominator``. The output has one extra trailing axis of size
    ``len(grid)`` compared to the input.
    """

    @staticmethod
    def forward(ctx, input, grid, inv_denominator, train_grid, train_inv_denominator):
        """Evaluate ``1 - tanh^2((input[..., None] - grid) * inv_denominator)``.

        Args:
            ctx: Autograd context used to stash tensors for ``backward``.
            input: Tensor of any shape.
            grid: 1-D tensor of grid centres, broadcast against a new last axis.
            inv_denominator: Scalar tensor multiplying the offsets.
            train_grid: Truthy when a gradient for ``grid`` must be produced.
            train_inv_denominator: Truthy when a gradient for
                ``inv_denominator`` must be produced.

        Returns:
            Tensor of shape ``input.shape + grid.shape``.
        """
        diff = input[..., None] - grid
        # The scaled offset is what goes through tanh; backward applies the
        # matching chain-rule factor, so both passes must use it.
        scaled_diff = diff.mul(inv_denominator)
        tanh_diff = torch.tanh(scaled_diff)
        tanh_diff_deriviative = -tanh_diff.mul(tanh_diff) + 1  # sech^2(u) = 1 - tanh^2(u)

        # Save tensors for backward pass
        ctx.save_for_backward(tanh_diff, tanh_diff_deriviative, diff, inv_denominator)
        ctx.train_grid = train_grid
        ctx.train_inv_denominator = train_inv_denominator

        return tanh_diff_deriviative

    @staticmethod
    def backward(ctx, grad_output):
        """Back-propagate through ``sech^2(u)`` with ``u = (input - grid) * s``.

        Returns one entry per ``forward`` argument: gradients for ``input``,
        ``grid`` (or ``None``), ``inv_denominator`` (or ``None``) and ``None``
        for the two flags.
        """
        tanh_diff, tanh_diff_deriviative, diff, inv_denominator = ctx.saved_tensors

        # d/du sech^2(u) = -2 * tanh(u) * sech^2(u); shared by every gradient below.
        grad_scaled = -2 * tanh_diff * tanh_diff_deriviative * grad_output

        # du/dinput = s, summed over the grid axis that forward added.
        grad_input = grad_scaled.sum(dim=-1).mul(inv_denominator)

        grad_grid = None
        grad_inv_denominator = None
        if ctx.train_grid:
            # du/dgrid = -s; reduce every axis except the grid axis so that
            # inputs of any rank are supported.
            grad_grid = -inv_denominator * grad_scaled.reshape(-1, grad_scaled.shape[-1]).sum(dim=0)
        if ctx.train_inv_denominator:
            # du/ds = (input - grid); s is a scalar, so reduce everything.
            grad_inv_denominator = (grad_scaled * diff).sum()
        return grad_input, grad_grid, grad_inv_denominator, None, None  # one slot per forward argument

In [3]:
class ReflectionalSwitchFunction(nn.Module):
    """Module wrapper that owns the grid and scale used by ``RSWAFFunction``.

    Args:
        grid_min: First grid centre.
        grid_max: Last grid centre.
        num_grids: Number of evenly spaced grid centres.
        exponent: Accepted for signature compatibility; not used by this class.
        inv_denominator: Initial value of the scalar scale parameter.
        train_grid: Whether ``grid`` requires gradients.
        train_inv_denominator: Whether ``inv_denominator`` requires gradients.
    """

    def __init__(
        self,
        grid_min: float = -1.2,
        grid_max: float = 0.2,
        num_grids: int = 8,
        exponent: int = 2,
        inv_denominator: float = 0.5,
        train_grid: bool = False,
        train_inv_denominator: bool = False,
    ):
        super().__init__()

        # The trainability flags are kept as 0-dim bool tensors (plain
        # attributes, not registered buffers) and handed to the autograd function.
        self.train_grid = torch.tensor(train_grid, dtype=torch.bool)
        self.train_inv_denominator = torch.tensor(train_inv_denominator, dtype=torch.bool)

        # Evenly spaced grid centres; trainable only on request.
        grid_points = torch.linspace(grid_min, grid_max, num_grids)
        self.grid = nn.Parameter(grid_points, requires_grad=train_grid)

        # Scalar scale, stored as a float32 parameter.
        scale = torch.tensor(inv_denominator, dtype=torch.float32)
        self.inv_denominator = nn.Parameter(scale, requires_grad=train_inv_denominator)

    def forward(self, x):
        """Apply ``RSWAFFunction`` to ``x`` with this module's grid, scale and flags."""
        return RSWAFFunction.apply(
            x,
            self.grid,
            self.inv_denominator,
            self.train_grid,
            self.train_inv_denominator,
        )

In [4]:
class SplineLinear(nn.Linear):
    """Bias-free ``nn.Linear`` whose weight is initialised with Xavier-uniform.

    ``init_scale`` is stored on the instance but is not consulted by
    ``reset_parameters``.
    """

    def __init__(self, in_features: int, out_features: int, init_scale: float = 0.1, **kw) -> None:
        # Assigned before the parent constructor runs so the attribute already
        # exists by the time the parent triggers reset_parameters().
        self.init_scale = init_scale
        super().__init__(in_features=in_features, out_features=out_features, bias=False, **kw)

    def reset_parameters(self) -> None:
        """Re-initialise the weight matrix in place (Xavier uniform)."""
        nn.init.xavier_uniform_(self.weight)

In [5]:
class FasterKANLayer(nn.Module):
    """LayerNorm -> reflectional-switch basis expansion -> bias-free linear map.

    Args:
        input_dim: Size of the last axis of the input.
        output_dim: Size of the last axis of the output.
        grid_min, grid_max, num_grids, exponent, inv_denominator, train_grid,
        train_inv_denominator: Forwarded positionally to
            ``ReflectionalSwitchFunction``.
        base_activation: Accepted for signature compatibility; not used here.
        spline_weight_init_scale: Forwarded to ``SplineLinear`` as ``init_scale``.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        grid_min: float = -1.2,
        grid_max: float = 0.2,
        num_grids: int = 8,
        exponent: int = 2,
        inv_denominator: float = 0.5,
        train_grid: bool = False,
        train_inv_denominator: bool = False,
        #use_base_update: bool = True,
        base_activation = F.silu,
        spline_weight_init_scale: float = 0.667,
    ) -> None:
        super().__init__()
        self.layernorm = nn.LayerNorm(input_dim)
        self.rbf = ReflectionalSwitchFunction(grid_min, grid_max, num_grids, exponent, inv_denominator, train_grid, train_inv_denominator)
        self.spline_linear = SplineLinear(input_dim * num_grids, output_dim, spline_weight_init_scale)

    def forward(self, x):
        """Map ``x`` of shape ``(..., input_dim)`` to ``(..., output_dim)``.

        Only the trailing (feature, grid) axes produced by the basis expansion
        are merged, so any leading axes (batch, sequence, ...) are preserved.
        For 2-D input this is the same as flattening everything after the
        batch axis.
        """
        x = self.layernorm(x)
        basis = self.rbf(x)  # (..., input_dim, num_grids)
        # Flattening from axis 1 would fold sequence axes into the feature
        # axis and break the linear layer's expected in_features.
        spline_basis = basis.reshape(*x.shape[:-1], -1)
        return self.spline_linear(spline_basis)

In [7]:
import numpy

# Seed PyTorch's and NumPy's global random number generators with a fixed value.
torch.manual_seed(42)
numpy.random.seed(42)

# Pick the device name as a plain string: 'cuda' when available, otherwise 'cpu'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [8]:
def positional_embeddings(seq_length, d):
    """Build a fixed sinusoidal positional-embedding table.

    Entry ``(i, j)`` is ``sin(i / 10000 ** (j / d))`` for even ``j`` and
    ``cos(i / 10000 ** (j / d))`` for odd ``j``.

    Args:
        seq_length (int): The length of the sequence (number of positions).
        d (int): The dimension of the embedding.

    Returns:
        torch.Tensor: Table of shape ``(seq_length, d)`` in the default dtype.
    """
    # NOTE(review): the odd columns use exponent j / d, not the (j - 1) / d
    # pairing of the usual sinusoidal encoding. Values are kept as-is so that
    # already trained weights still see the same embeddings.
    positions = torch.arange(seq_length, dtype=torch.float64).unsqueeze(1)  # (seq_length, 1)
    columns = torch.arange(d, dtype=torch.float64)                          # (d,)

    # Work in float64 and cast once at the end, mirroring the double-precision
    # arithmetic of a per-element computation.
    angles = positions / torch.pow(10000.0, columns / d)                    # (seq_length, d)
    is_even_column = torch.arange(d) % 2 == 0
    table = torch.where(is_even_column, torch.sin(angles), torch.cos(angles))
    return table.to(torch.get_default_dtype())

In [9]:
def patchify(images, n_patches):
    """Split each image into an ``n_patches x n_patches`` grid of flat patches.

    Patches are ordered row-major over the grid, and each patch is flattened
    in (channel, row, column) order.

    Args:
        images (torch.Tensor): Batch of shape ``(n, c, h, w)`` with ``h == w``.
        n_patches (int): Number of patches along each spatial axis; must
            divide ``h``.

    Returns:
        torch.Tensor: Shape ``(n, n_patches ** 2, c * (h // n_patches) ** 2)``,
        in the default dtype and on the same device as ``images``.
    """
    n, c, h, w = images.shape
    assert h == w, "Only for square images"
    assert h % n_patches == 0, "Image size must be divisible by n_patches"

    patch_size = h // n_patches

    # One reshape/permute replaces a Python loop over every image and patch:
    # split both spatial axes into (grid index, offset), bring the two grid
    # indices forward, then merge the per-patch axes.
    patches = (
        images.reshape(n, c, n_patches, patch_size, n_patches, patch_size)
        .permute(0, 2, 4, 1, 3, 5)
        .reshape(n, n_patches ** 2, c * patch_size * patch_size)
    )
    # Always return the default dtype, whatever dtype the images came in.
    return patches.to(torch.get_default_dtype())

In [None]:
class MSA(torch.nn.Module):
    """Multi-head self-attention with FasterKANLayer query/key/value projections.

    The token dimension is split into ``n_heads`` contiguous chunks; each head
    has its own query, key and value layer acting on its chunk only.

    Args:
        d (int): The total dimension of the input tokens; must be divisible
            by ``n_heads``.
        n_heads (int): The number of attention heads.
    """

    def __init__(self, d, n_heads):
        super().__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0, "d must be divisible by n_heads"
        d_head = d // n_heads

        self.q_mappings = torch.nn.ModuleList([FasterKANLayer(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = torch.nn.ModuleList([FasterKANLayer(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = torch.nn.ModuleList([FasterKANLayer(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, sequence):
        """Apply scaled dot-product self-attention independently per head.

        Args:
            sequence (torch.Tensor): Tokens of shape ``(n, seq_len, d)``.

        Returns:
            torch.Tensor: Tensor of shape ``(n, seq_len, d)`` holding the
            head outputs concatenated along the last axis.
        """
        n, seq_len = sequence.shape[0], sequence.shape[1]
        scale = self.d_head ** 0.5

        head_outputs = []
        for head in range(self.n_heads):
            chunk = sequence[..., head * self.d_head: (head + 1) * self.d_head]

            # The projections act on each token separately, so all tokens of
            # the batch go through the layers as one 2-D matrix and are then
            # folded back into (n, seq_len, d_head).
            tokens = chunk.reshape(-1, self.d_head)
            q = self.q_mappings[head](tokens).reshape(n, seq_len, self.d_head)
            k = self.k_mappings[head](tokens).reshape(n, seq_len, self.d_head)
            v = self.v_mappings[head](tokens).reshape(n, seq_len, self.d_head)

            # Batched matmul keeps attention within each sample.
            attention = self.softmax(q @ k.transpose(-2, -1) / scale)
            head_outputs.append(attention @ v)
        return torch.cat(head_outputs, dim=-1)

In [None]:
class FasterKAN_ViT(torch.nn.Module):
    """Vision Transformer whose learned mappings are FasterKANLayer modules.

    Args:
        chw (list/tuple of 3 ints): The input image shape (channels, height, width).
        n_patches (int, optional): Number of patches along each spatial axis;
            the image is split into ``n_patches ** 2`` patches. Defaults to 10.
        n_blocks (int, optional): The number of attention blocks. Defaults to 2.
        d_hidden (int, optional): The token dimension. Defaults to 8.
        n_heads (int, optional): The number of attention heads in each block. Defaults to 2.
        out_d (int, optional): The number of output dimensions. Defaults to 10.
    """

    def __init__(self, chw, n_patches=10, n_blocks=2, d_hidden=8, n_heads=2, out_d=10):
        super().__init__()

        self.chw = chw
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.d_hidden = d_hidden

        assert chw[1] % n_patches == 0, "Image height must be divisible by n_patches"
        assert chw[2] % n_patches == 0, "Image width must be divisible by n_patches"

        # Divisibility is asserted above, so integer division is exact.
        self.patch_size = (chw[1] // n_patches, chw[2] // n_patches)

        # Linear mapping
        self.input_d = chw[0] * self.patch_size[0] * self.patch_size[1]
        self.linear_mapper = FasterKANLayer(self.input_d, self.d_hidden)

        # Classification token
        self.v_class = torch.nn.Parameter(torch.rand(1, self.d_hidden))

        # Positional embedding: a fixed table, excluded from the state_dict
        # (persistent=False) and rebuilt here on construction.
        self.register_buffer('positional_embeddings', positional_embeddings(n_patches ** 2 + 1, d_hidden),
                             persistent=False)

        # Encoder blocks
        self.blocks = torch.nn.ModuleList([MSA(d_hidden, n_heads) for _ in range(n_blocks)])

        # Classification head. It ends in a softmax, so forward() returns
        # class probabilities rather than logits.
        self.mlp = torch.nn.Sequential(
            FasterKANLayer(self.d_hidden, out_d),
            torch.nn.Softmax(dim=-1)
        )

    def forward(self, images):
        """Classify a batch of images.

        Args:
            images (torch.Tensor): Batch of shape ``(n, c, h, w)``.

        Returns:
            torch.Tensor: Class probabilities of shape ``(n, out_d)``; each
            row sums to 1 because the head ends in a softmax.
        """
        n = images.shape[0]
        patches = patchify(images, self.n_patches).to(self.positional_embeddings.device)
        n_tokens = patches.shape[1]

        # running tokenization: every patch is embedded independently, so the
        # patches go through the mapper as a 2-D (n * n_tokens, patch_dim)
        # matrix and are folded back into sequences afterwards.
        tokens = self.linear_mapper(patches.flatten(0, 1)).reshape(n, n_tokens, self.d_hidden)
        tokens = torch.cat((self.v_class.expand(n, 1, -1), tokens), dim=1)

        # The (n_tokens + 1, d_hidden) table broadcasts over the batch axis.
        out = tokens + self.positional_embeddings

        for block in self.blocks:
            out = block(out)

        # Only the classification token feeds the head.
        out = out[:, 0]
        return self.mlp(out)

In [None]:
# Bind `device` to a torch.device: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# ViT for 1x28x28 inputs: 7 patches per side, 2 attention blocks, token dimension 8, 2 heads, 10 outputs.
mnist_model = FasterKAN_ViT((1, 28, 28), n_patches=7, n_blocks=2, d_hidden=8, n_heads=2, out_d=10).to(device)
# Adam over all model parameters with learning rate 0.005.
optimizer = Adam(mnist_model.parameters(), lr=0.005)

In [None]:
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from tqdm import tqdm, trange

def main(train_loader, test_loader):
    """
    Train the global ``mnist_model`` with the global ``optimizer`` for 10 epochs,
    then evaluate it once and print the test loss and accuracy.

    The model's output is treated as class probabilities (its head ends in a
    softmax), so the loss is the negative log-likelihood of the log of that
    output, not a cross-entropy that would apply a second softmax.

    :param train_loader: The dataloader for the training set for training the model.
    :param test_loader: The dataloader for the testing set during evaluation phase.
    """
    print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")

    epochs = 10
    criterion = torch.nn.NLLLoss()
    eps = 1e-12  # lower clamp keeps log() finite when a probability underflows to 0

    def _loss(probabilities, targets):
        # NLLLoss expects log-probabilities.
        return criterion(torch.log(probabilities.clamp_min(eps)), targets)

    for epoch in trange(epochs, desc="train"):
        mnist_model.train()
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = mnist_model(x)
            loss = _loss(y_hat, y)

            # Running mean of the per-batch loss over the epoch.
            train_loss += loss.detach().cpu().item() / len(train_loader)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch + 1}/{epochs} loss: {train_loss:.2f}")

    mnist_model.eval()
    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = mnist_model(x)
            loss = _loss(y_hat, y)
            test_loss += loss.detach().cpu().item() / len(test_loader)

            correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
            total += len(x)

        # An empty test loader would otherwise divide by zero.
        accuracy = correct / total * 100 if total else float("nan")
        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {accuracy:.2f}%")

In [None]:
# ToTensor is the only preprocessing applied to the images.
transform = transforms.ToTensor()
# MNIST train and test splits rooted at ./mnist (download=True is passed for both).
train_mnist = MNIST(root='./mnist', train=True, download=True, transform=transform)
test_mnist = MNIST(root='./mnist', train=False, download=True, transform=transform)
# Batches of 128; only the training set is shuffled.
train_loader = DataLoader(train_mnist, shuffle=True, batch_size=128)
test_loader = DataLoader(test_mnist, shuffle=False, batch_size=128)
# Train and then evaluate the global `mnist_model`.
main(train_loader=train_loader, test_loader=test_loader)

In [None]:
# Destination file for the trained weights.
path: str = "fasterkan_vit_10epochs.pth"

# Save only the state_dict, not the module object. The positional-embedding
# buffer is registered with persistent=False, so it is not part of this file.
torch.save(mnist_model.state_dict(), path)