In [1]:
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import Adam
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy

In [2]:
class SplineLinear(nn.Linear):
    def __init__(self, in_features: int, out_features: int, init_scale: float = 0.1, **kw) -> None:
        self.init_scale = init_scale
        super().__init__(in_features, out_features, bias=False, **kw)

    def reset_parameters(self) -> None:
        nn.init.trunc_normal_(self.weight, mean=0, std=self.init_scale)

class RadialBasisFunction(nn.Module):
    def __init__(
        self,
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 8,
        denominator: float = None,  # larger denominators lead to smoother basis
    ):
        super().__init__()
        grid = torch.linspace(grid_min, grid_max, num_grids)
        self.grid = torch.nn.Parameter(grid, requires_grad=False)
        self.denominator = denominator or (grid_max - grid_min) / (num_grids - 1)

    def forward(self, x):
        return torch.exp(-((x[..., None] - self.grid) / self.denominator) ** 2)

class FastKANLayer(nn.Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 8,
        use_base_update: bool = True,
        base_activation = F.silu,
        spline_weight_init_scale: float = 0.1,
    ) -> None:
        super().__init__()
        self.layernorm = nn.LayerNorm(input_dim)
        self.rbf = RadialBasisFunction(grid_min, grid_max, num_grids)
        self.spline_linear = SplineLinear(input_dim * num_grids, output_dim, spline_weight_init_scale)
        self.use_base_update = use_base_update
        if use_base_update:
            self.base_activation = base_activation
            self.base_linear = nn.Linear(input_dim, output_dim)

    def forward(self, x, time_benchmark=False):
        if not time_benchmark:
            spline_basis = self.rbf(self.layernorm(x))
        else:
            spline_basis = self.rbf(x)
        ret = self.spline_linear(spline_basis.view(*spline_basis.shape[:-2], -1))
        if self.use_base_update:
            base = self.base_linear(self.base_activation(x))
            ret = ret + base
        return ret

In [3]:
torch.manual_seed(42)
numpy.random.seed(42)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [4]:
def positional_embeddings(seq_length, d):
    """
    the purpose of this function is to find high and low interaction of a word with surrounding words.
    We can do so by the following equation below:

    Args:
        seq_length (int): The length of the sequence/sentence
        d (int): The dimension of the embedding
    """

    result = torch.ones(seq_length, d)
    for i in range(seq_length):
        for j in range(d):
            result[i][j] = numpy.sin(i / 10000 ** (j / d)) if j % 2 == 0 else numpy.cos(i / 10000 ** (j/ d))
    return result

In [5]:
def patchify(images, n_patches):
    """
    The purpose of this function is to break down the main image into multiple sub-images and map them.

    Args:
        images (_type_): The image passeed into this function.
        n_patches (_type_): The number of sub-images that will be created.
    """

    n, c, h, w = images.shape
    assert h == w, "Only for square images"

    patches = torch.zeros(n, n_patches**2, h * w * c // n_patches ** 2) # The equation to calculate the patches
    patch_size = h // n_patches

    for idx, image in enumerate(images):
        for i in range(n_patches):
            for j in range(n_patches):
                patch = image[:, i * patch_size: (i + 1) * patch_size, j * patch_size: (j + 1) * patch_size]
                patches[idx, i * n_patches + j] = patch.flatten()
    return patches

In [6]:
class MSA(torch.nn.Module):
    """
        Initializes the Multi-Head Self-Attention (MSA) module with the given dimensions.

        Args:
            d (int): The total dimension of the input.
            n_heads (int): The number of attention heads.

        Returns:
            None
    """
    def __init__(self, d, n_heads):
        super(MSA, self).__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0
        d_head = int(d / n_heads)

        self.q_mappings = torch.nn.ModuleList([FastKANLayer(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = torch.nn.ModuleList([FastKANLayer(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = torch.nn.ModuleList([FastKANLayer(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, sequence):
        result = []
        for sequence in sequence:
            seq_res = []
            for head in range(self.n_heads):
                q_map = self.q_mappings[head]
                k_map = self.k_mappings[head]
                v_map = self.v_mappings[head]

                seq = sequence[:, head*self.d_head: (head+1)*self.d_head]
                q, k, v = q_map(seq), k_map(seq), v_map(seq)

                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                seq_res.append(attention @ v)
            result.append(torch.hstack(seq_res))
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])

In [7]:
class FastKAN_ViT(torch.nn.Module):
    """
        Initializes a Vision Transformer (ViT) module.

        Args:
            chw (list/tuple of 3 ints): The input image shape.
            n_patches (int, optional): The number of patches to split the image into. Defaults to 10.
            n_blocks (int, optional): The number of blocks in the transformer encoder. Defaults to 2.
            d_hidden (int, optional): The number of hidden dimensions in the transformer encoder. Defaults to 8.
            n_heads (int, optional): The number of attention heads in each block. Defaults to 2.
            out_d (int, optional): The number of output dimensions. Defaults to 10.

        Returns:
            None
    """
    def __init__(self, chw, n_patches=10, n_blocks=2, d_hidden=8, n_heads=2, out_d=10):
        super(FastKAN_ViT, self).__init__()

        self.chw = chw
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.d_hidden = d_hidden

        assert chw[1] % n_patches == 0
        assert chw[2] % n_patches == 0

        self.patch_size = (chw[1] / n_patches, chw[2] / n_patches)

        # Linear mapping
        self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = FastKANLayer(self.input_d, self.d_hidden)

        # Classification token
        self.v_class = torch.nn.Parameter(torch.rand(1, self.d_hidden))

        # Positional embedding
        self.register_buffer('positional_embeddings', positional_embeddings(n_patches ** 2 + 1, d_hidden),
                             persistent=False)

        # Encoder blocks
        self.blocks = torch.nn.ModuleList([MSA(d_hidden, n_heads) for _ in range(n_blocks)])

        self.mlp = torch.nn.Sequential(
            FastKANLayer(self.d_hidden, out_d),
            torch.nn.Softmax(dim=-1)
        )

    def forward(self, images):
        n, c, h, w = images.shape
        patches = patchify(images, self.n_patches).to(self.positional_embeddings.device)

        # running tokenization
        tokens = self.linear_mapper(patches)
        tokens = torch.cat((self.v_class.expand(n, 1, -1), tokens), dim=1)
        out = tokens + self.positional_embeddings.repeat(n, 1, 1)

        for block in self.blocks:
            out = block(out)

        out = out[:, 0]
        return self.mlp(out)

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mnist_model = FastKAN_ViT((1, 28, 28), n_patches=7, n_blocks=2, d_hidden=8, n_heads=2, out_d=10).to(device)
optimizer = Adam(mnist_model.parameters(), lr=0.005)

In [11]:
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from tqdm import tqdm, trange

def main(train_loader, test_loader):
    """
    This code contains the training and testing loop for training the vision transformers model. It requires two
    parameters

    :param train_loader: The dataloader for the training set for training the model.
    :param test_loader: The dataloader for the testing set during evaluation phase.
    """
    print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")

    epochs = 10
    criterion = torch.nn.CrossEntropyLoss()
    for epoch in trange(epochs, desc="train"):
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = mnist_model(x)
            loss = criterion(y_hat, y)

            train_loss += loss.detach().cpu().item() / len(train_loader)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch + 1}/{epochs} loss: {train_loss:.2f}")

    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = mnist_model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.detach().cpu().item() / len(test_loader)

            correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
            total += len(x)

        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {correct / total * 100:.2f}%")

In [12]:
transform = transforms.ToTensor()
train_mnist = MNIST(root='./mnist', train=True, download=True, transform=transform)
test_mnist = MNIST(root='./mnist', train=False, download=True, transform=transform)
train_loader = DataLoader(train_mnist, shuffle=True, batch_size=128)
test_loader = DataLoader(test_mnist, shuffle=False, batch_size=128)
main(train_loader=train_loader, test_loader=test_loader)

Using device:  cuda (Tesla P100-PCIE-16GB)


train:   0%|          | 0/10 [00:00<?, ?it/s]
Epoch 1 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 1 in training:   0%|          | 1/469 [00:01<12:17,  1.58s/it][A
Epoch 1 in training:   0%|          | 2/469 [00:03<12:01,  1.54s/it][A
Epoch 1 in training:   1%|          | 3/469 [00:04<11:53,  1.53s/it][A
Epoch 1 in training:   1%|          | 4/469 [00:06<11:48,  1.52s/it][A
Epoch 1 in training:   1%|          | 5/469 [00:07<11:47,  1.53s/it][A
Epoch 1 in training:   1%|▏         | 6/469 [00:09<11:56,  1.55s/it][A
Epoch 1 in training:   1%|▏         | 7/469 [00:10<11:55,  1.55s/it][A
Epoch 1 in training:   2%|▏         | 8/469 [00:12<11:50,  1.54s/it][A
Epoch 1 in training:   2%|▏         | 9/469 [00:13<11:45,  1.53s/it][A
Epoch 1 in training:   2%|▏         | 10/469 [00:15<11:44,  1.54s/it][A
Epoch 1 in training:   2%|▏         | 11/469 [00:16<11:43,  1.54s/it][A
Epoch 1 in training:   3%|▎         | 12/469 [00:18<11:41,  1.54s/it][A
Epoch 1 in training:   

Epoch 1/10 loss: 1.90



Epoch 2 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 2 in training:   0%|          | 1/469 [00:01<12:05,  1.55s/it][A
Epoch 2 in training:   0%|          | 2/469 [00:03<11:55,  1.53s/it][A
Epoch 2 in training:   1%|          | 3/469 [00:04<11:59,  1.54s/it][A
Epoch 2 in training:   1%|          | 4/469 [00:06<11:53,  1.53s/it][A
Epoch 2 in training:   1%|          | 5/469 [00:07<12:09,  1.57s/it][A
Epoch 2 in training:   1%|▏         | 6/469 [00:09<11:58,  1.55s/it][A
Epoch 2 in training:   1%|▏         | 7/469 [00:10<11:52,  1.54s/it][A
Epoch 2 in training:   2%|▏         | 8/469 [00:12<11:47,  1.53s/it][A
Epoch 2 in training:   2%|▏         | 9/469 [00:13<11:43,  1.53s/it][A
Epoch 2 in training:   2%|▏         | 10/469 [00:15<11:44,  1.53s/it][A
Epoch 2 in training:   2%|▏         | 11/469 [00:16<11:42,  1.53s/it][A
Epoch 2 in training:   3%|▎         | 12/469 [00:18<11:38,  1.53s/it][A
Epoch 2 in training:   3%|▎         | 13/469 [00:19<11:36,  1.53s/it

Epoch 2/10 loss: 1.85



Epoch 3 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 3 in training:   0%|          | 1/469 [00:01<11:40,  1.50s/it][A
Epoch 3 in training:   0%|          | 2/469 [00:03<11:53,  1.53s/it][A
Epoch 3 in training:   1%|          | 3/469 [00:04<11:59,  1.54s/it][A
Epoch 3 in training:   1%|          | 4/469 [00:06<11:57,  1.54s/it][A
Epoch 3 in training:   1%|          | 5/469 [00:07<11:50,  1.53s/it][A
Epoch 3 in training:   1%|▏         | 6/469 [00:09<11:50,  1.53s/it][A
Epoch 3 in training:   1%|▏         | 7/469 [00:10<11:43,  1.52s/it][A
Epoch 3 in training:   2%|▏         | 8/469 [00:12<11:37,  1.51s/it][A
Epoch 3 in training:   2%|▏         | 9/469 [00:13<11:34,  1.51s/it][A
Epoch 3 in training:   2%|▏         | 10/469 [00:15<11:50,  1.55s/it][A
Epoch 3 in training:   2%|▏         | 11/469 [00:16<11:45,  1.54s/it][A
Epoch 3 in training:   3%|▎         | 12/469 [00:18<11:40,  1.53s/it][A
Epoch 3 in training:   3%|▎         | 13/469 [00:19<11:39,  1.53s/it

Epoch 3/10 loss: 1.81



Epoch 4 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 4 in training:   0%|          | 1/469 [00:01<11:37,  1.49s/it][A
Epoch 4 in training:   0%|          | 2/469 [00:02<11:34,  1.49s/it][A
Epoch 4 in training:   1%|          | 3/469 [00:04<12:06,  1.56s/it][A
Epoch 4 in training:   1%|          | 4/469 [00:06<11:53,  1.54s/it][A
Epoch 4 in training:   1%|          | 5/469 [00:07<11:45,  1.52s/it][A
Epoch 4 in training:   1%|▏         | 6/469 [00:09<11:50,  1.53s/it][A
Epoch 4 in training:   1%|▏         | 7/469 [00:10<11:43,  1.52s/it][A
Epoch 4 in training:   2%|▏         | 8/469 [00:12<11:38,  1.52s/it][A
Epoch 4 in training:   2%|▏         | 9/469 [00:13<11:41,  1.53s/it][A
Epoch 4 in training:   2%|▏         | 10/469 [00:15<11:37,  1.52s/it][A
Epoch 4 in training:   2%|▏         | 11/469 [00:16<11:33,  1.51s/it][A
Epoch 4 in training:   3%|▎         | 12/469 [00:18<11:31,  1.51s/it][A
Epoch 4 in training:   3%|▎         | 13/469 [00:19<11:29,  1.51s/it

Epoch 4/10 loss: 1.76



Epoch 5 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 5 in training:   0%|          | 1/469 [00:01<11:39,  1.49s/it][A
Epoch 5 in training:   0%|          | 2/469 [00:03<11:43,  1.51s/it][A
Epoch 5 in training:   1%|          | 3/469 [00:04<11:46,  1.52s/it][A
Epoch 5 in training:   1%|          | 4/469 [00:06<11:46,  1.52s/it][A
Epoch 5 in training:   1%|          | 5/469 [00:07<11:46,  1.52s/it][A
Epoch 5 in training:   1%|▏         | 6/469 [00:09<11:51,  1.54s/it][A
Epoch 5 in training:   1%|▏         | 7/469 [00:10<12:07,  1.58s/it][A
Epoch 5 in training:   2%|▏         | 8/469 [00:12<12:09,  1.58s/it][A
Epoch 5 in training:   2%|▏         | 9/469 [00:13<12:02,  1.57s/it][A
Epoch 5 in training:   2%|▏         | 10/469 [00:15<11:56,  1.56s/it][A
Epoch 5 in training:   2%|▏         | 11/469 [00:17<11:50,  1.55s/it][A
Epoch 5 in training:   3%|▎         | 12/469 [00:18<11:46,  1.55s/it][A
Epoch 5 in training:   3%|▎         | 13/469 [00:20<11:48,  1.55s/it

Epoch 5/10 loss: 1.81



Epoch 6 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 6 in training:   0%|          | 1/469 [00:01<11:42,  1.50s/it][A
Epoch 6 in training:   0%|          | 2/469 [00:03<11:40,  1.50s/it][A
Epoch 6 in training:   1%|          | 3/469 [00:04<11:42,  1.51s/it][A
Epoch 6 in training:   1%|          | 4/469 [00:06<11:41,  1.51s/it][A
Epoch 6 in training:   1%|          | 5/469 [00:07<11:38,  1.51s/it][A
Epoch 6 in training:   1%|▏         | 6/469 [00:09<11:36,  1.50s/it][A
Epoch 6 in training:   1%|▏         | 7/469 [00:10<11:36,  1.51s/it][A
Epoch 6 in training:   2%|▏         | 8/469 [00:12<11:34,  1.51s/it][A
Epoch 6 in training:   2%|▏         | 9/469 [00:13<11:32,  1.51s/it][A
Epoch 6 in training:   2%|▏         | 10/469 [00:15<11:34,  1.51s/it][A
Epoch 6 in training:   2%|▏         | 11/469 [00:16<11:41,  1.53s/it][A
Epoch 6 in training:   3%|▎         | 12/469 [00:18<11:35,  1.52s/it][A
Epoch 6 in training:   3%|▎         | 13/469 [00:19<11:31,  1.52s/it

Epoch 6/10 loss: 1.77



Epoch 7 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 7 in training:   0%|          | 1/469 [00:01<11:36,  1.49s/it][A
Epoch 7 in training:   0%|          | 2/469 [00:03<12:18,  1.58s/it][A
Epoch 7 in training:   1%|          | 3/469 [00:04<11:58,  1.54s/it][A
Epoch 7 in training:   1%|          | 4/469 [00:06<11:48,  1.52s/it][A
Epoch 7 in training:   1%|          | 5/469 [00:07<11:43,  1.52s/it][A
Epoch 7 in training:   1%|▏         | 6/469 [00:09<11:40,  1.51s/it][A
Epoch 7 in training:   1%|▏         | 7/469 [00:10<11:36,  1.51s/it][A
Epoch 7 in training:   2%|▏         | 8/469 [00:12<11:38,  1.52s/it][A
Epoch 7 in training:   2%|▏         | 9/469 [00:13<11:33,  1.51s/it][A
Epoch 7 in training:   2%|▏         | 10/469 [00:15<11:31,  1.51s/it][A
Epoch 7 in training:   2%|▏         | 11/469 [00:16<11:29,  1.51s/it][A
Epoch 7 in training:   3%|▎         | 12/469 [00:18<11:29,  1.51s/it][A
Epoch 7 in training:   3%|▎         | 13/469 [00:19<11:27,  1.51s/it

Epoch 7/10 loss: 1.72



Epoch 8 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 8 in training:   0%|          | 1/469 [00:01<11:44,  1.51s/it][A
Epoch 8 in training:   0%|          | 2/469 [00:03<11:41,  1.50s/it][A
Epoch 8 in training:   1%|          | 3/469 [00:04<11:41,  1.50s/it][A
Epoch 8 in training:   1%|          | 4/469 [00:06<11:40,  1.51s/it][A
Epoch 8 in training:   1%|          | 5/469 [00:07<11:37,  1.50s/it][A
Epoch 8 in training:   1%|▏         | 6/469 [00:09<12:00,  1.56s/it][A
Epoch 8 in training:   1%|▏         | 7/469 [00:10<11:52,  1.54s/it][A
Epoch 8 in training:   2%|▏         | 8/469 [00:12<11:49,  1.54s/it][A
Epoch 8 in training:   2%|▏         | 9/469 [00:13<11:44,  1.53s/it][A
Epoch 8 in training:   2%|▏         | 10/469 [00:15<11:39,  1.52s/it][A
Epoch 8 in training:   2%|▏         | 11/469 [00:16<11:35,  1.52s/it][A
Epoch 8 in training:   3%|▎         | 12/469 [00:18<11:35,  1.52s/it][A
Epoch 8 in training:   3%|▎         | 13/469 [00:19<11:31,  1.52s/it

Epoch 8/10 loss: 1.70



Epoch 9 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 9 in training:   0%|          | 1/469 [00:01<12:10,  1.56s/it][A
Epoch 9 in training:   0%|          | 2/469 [00:03<11:50,  1.52s/it][A
Epoch 9 in training:   1%|          | 3/469 [00:04<11:42,  1.51s/it][A
Epoch 9 in training:   1%|          | 4/469 [00:06<11:42,  1.51s/it][A
Epoch 9 in training:   1%|          | 5/469 [00:07<11:39,  1.51s/it][A
Epoch 9 in training:   1%|▏         | 6/469 [00:09<11:36,  1.50s/it][A
Epoch 9 in training:   1%|▏         | 7/469 [00:10<11:34,  1.50s/it][A
Epoch 9 in training:   2%|▏         | 8/469 [00:12<11:29,  1.50s/it][A
Epoch 9 in training:   2%|▏         | 9/469 [00:13<11:26,  1.49s/it][A
Epoch 9 in training:   2%|▏         | 10/469 [00:15<11:24,  1.49s/it][A
Epoch 9 in training:   2%|▏         | 11/469 [00:16<11:27,  1.50s/it][A
Epoch 9 in training:   3%|▎         | 12/469 [00:18<11:23,  1.50s/it][A
Epoch 9 in training:   3%|▎         | 13/469 [00:19<11:39,  1.53s/it

Epoch 9/10 loss: 1.69



Epoch 10 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 10 in training:   0%|          | 1/469 [00:01<12:42,  1.63s/it][A
Epoch 10 in training:   0%|          | 2/469 [00:03<12:03,  1.55s/it][A
Epoch 10 in training:   1%|          | 3/469 [00:04<11:51,  1.53s/it][A
Epoch 10 in training:   1%|          | 4/469 [00:06<11:48,  1.52s/it][A
Epoch 10 in training:   1%|          | 5/469 [00:07<11:41,  1.51s/it][A
Epoch 10 in training:   1%|▏         | 6/469 [00:09<11:37,  1.51s/it][A
Epoch 10 in training:   1%|▏         | 7/469 [00:10<11:45,  1.53s/it][A
Epoch 10 in training:   2%|▏         | 8/469 [00:12<11:40,  1.52s/it][A
Epoch 10 in training:   2%|▏         | 9/469 [00:13<11:33,  1.51s/it][A
Epoch 10 in training:   2%|▏         | 10/469 [00:15<11:33,  1.51s/it][A
Epoch 10 in training:   2%|▏         | 11/469 [00:16<11:34,  1.52s/it][A
Epoch 10 in training:   3%|▎         | 12/469 [00:18<11:31,  1.51s/it][A
Epoch 10 in training:   3%|▎         | 13/469 [00:19<11

Epoch 10/10 loss: 1.68


Testing: 100%|██████████| 79/79 [00:56<00:00,  1.41it/s]

Test loss: 1.66
Test accuracy: 80.14%





In [15]:
path: str = "fastkan_vit_10epochs.pth"

torch.save(mnist_model.state_dict(), path)