In [1]:
!pip install monai
!pip install einops



In [2]:
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import Adam
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy

In [3]:
import torch
import torch.nn as nn


# This is inspired by Kolmogorov-Arnold Networks but using Chebyshev polynomials instead of splines coefficients
class ChebyKANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, degree):
        super(ChebyKANLayer, self).__init__()
        self.inputdim = input_dim
        self.outdim = output_dim
        self.degree = degree

        self.cheby_coeffs = nn.Parameter(torch.empty(input_dim, output_dim, degree + 1))
        nn.init.normal_(self.cheby_coeffs, mean=0.0, std=1 / (input_dim * (degree + 1)))
        self.register_buffer("arange", torch.arange(0, degree + 1, 1))

    def forward(self, x):
        # Since Chebyshev polynomial is defined in [-1, 1]
        # We need to normalize x to [-1, 1] using tanh
        x = torch.tanh(x)
        # View and repeat input degree + 1 times
        x = x.view((-1, self.inputdim, 1)).expand(
            -1, -1, self.degree + 1
        )  # shape = (batch_size, inputdim, self.degree + 1)
        # Apply acos
        x = x.acos()
        # Multiply by arange [0 .. degree]
        x *= self.arange
        # Apply cos
        x = x.cos()
        # Compute the Chebyshev interpolation
        y = torch.einsum(
            "bid,iod->bo", x, self.cheby_coeffs
        )  # shape = (batch_size, outdim)
        y = y.view(-1, self.outdim)
        return y

In [4]:
torch.manual_seed(42)
numpy.random.seed(42)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [5]:
class ChebyMSA(torch.nn.Module):
    """
    This is the template implementation of the "Multi-Scale Attention" Layer.

    The query, key and value mapping are matrix-multipled against each other in order to
    find the attention, or, the relation of a word and its interaction with surrounding words.
    """
    def __init__(self, d, n_heads=4):
        super(ChebyMSA, self).__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0  # Shouldn't divide dimension (d) into n_heads

        d_head = int(d / n_heads)
        self.q_mappings = torch.nn.ModuleList([ChebyKANLayer(d_head, d_head, 4) for _ in range(self.n_heads)])
        self.k_mappings = torch.nn.ModuleList([ChebyKANLayer(d_head, d_head, 4) for _ in range(self.n_heads)])
        self.v_mappings = torch.nn.ModuleList([ChebyKANLayer(d_head, d_head, 4) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, sequences):
        result = []
        for sequence in sequences:
            seq_result = []
            for head in range(self.n_heads):
                q_mapping = self.q_mappings[head]
                k_mapping = self.k_mappings[head]
                v_mapping = self.v_mappings[head]

                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]
                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)

                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                seq_result.append(attention @ v)
            result.append(torch.hstack(seq_result))
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])

In [6]:
class ChebyViT(torch.nn.Module): 
    """
        Initializes a Vision Transformer (ViT) module.

        Args:
            chw (list/tuple of 3 ints): The input image shape.
            n_patches (int, optional): The number of patches to split the image into. Defaults to 10.
            n_blocks (int, optional): The number of blocks in the transformer encoder. Defaults to 2.
            d_hidden (int, optional): The number of hidden dimensions in the transformer encoder. Defaults to 8.
            n_heads (int, optional): The number of attention heads in each block. Defaults to 2.
            out_d (int, optional): The number of output dimensions. Defaults to 10.

        Returns:
            None
    """    
    def __init__(self, chw, n_patches=10, n_blocks=2, d_hidden=8, n_heads=2, out_d=10): 
        super(ChebyViT, self).__init__()
        
        self.chw = chw
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.d_hidden = d_hidden
        
        assert chw[1] % n_patches == 0 
        assert chw[2] % n_patches == 0
        
        self.patch_size = (chw[1] / n_patches, chw[2] / n_patches)

        # Linear mapping
        self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = ChebyKANLayer(self.input_d, self.d_hidden, 4)

        # Classification token
        self.v_class = torch.nn.Parameter(torch.rand(1, self.d_hidden))

        # Positional embedding
        self.register_buffer('pos_embeddings', self.positional_embeddings(n_patches ** 2 + 1, d_hidden),
                             persistent=False)

        # Encoder blocks
        self.blocks = torch.nn.ModuleList([ChebyMSA(d_hidden, n_heads) for _ in range(n_blocks)])

        self.mlp = torch.nn.Sequential(
            ChebyKANLayer(self.d_hidden, out_d, 4),
            torch.nn.Softmax(dim=-1)
        )
        
    def patchify(self, images, n_patches):
        """
        The purpose of this function is to break down the main image into multiple sub-images and map them.

        Args:
            images (_type_): The image passeed into this function.
            n_patches (_type_): The number of sub-images that will be created.
        """

        n, c, h, w = images.shape
        assert h == w, "Only for square images"

        patches = torch.zeros(n, n_patches**2, h * w * c // n_patches ** 2) # The equation to calculate the patches
        patch_size = h // n_patches

        for idx, image in enumerate(images):
            for i in range(n_patches):
                for j in range(n_patches):
                    patch = image[:, i * patch_size: (i + 1) * patch_size, j * patch_size: (j + 1) * patch_size]
                    patches[idx, i * n_patches + j] = patch.flatten()
        return patches

    def positional_embeddings(self, seq_length, d):
        """
        the purpose of this function is to find high and low interaction of a word with surrounding words.
        We can do so by the following equation below:

        Args:
            seq_length (int): The length of the sequence/sentence
            d (int): The dimension of the embedding
        """

        result = torch.ones(seq_length, d)
        for i in range(seq_length):
            for j in range(d):
                result[i][j] = numpy.sin(i / 10000 ** (j / d)) if j % 2 == 0 else numpy.cos(i / 10000 ** (j/ d))
        return result

    def forward(self, images):
        n, c, h, w = images.shape
        patches = self.patchify(images, self.n_patches).to(self.pos_embeddings.device)
        # running tokenization
        tokens = self.linear_mapper(patches)
        # Reshape tokens to maintain batch dimension
        tokens = tokens.reshape(n, -1, tokens.shape[-1])
        # Expand v_class
        v_class_expanded = self.v_class.expand(n, 1, -1)
        # Concatenate
        tokens = torch.cat((v_class_expanded, tokens), dim=1)
        out = tokens + self.pos_embeddings.repeat(n, 1, 1)
        for block in self.blocks:
            out = block(out)

        out = out[:, 0]
        return self.mlp(out)

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mnist_model = ChebyViT((1, 28, 28), n_patches=7, n_blocks=2, d_hidden=8, n_heads=2, out_d=10).to(device)
optimizer = Adam(mnist_model.parameters(), lr=0.005)

In [8]:
from tqdm import tqdm, trange
def main(train_loader, test_loader):
    """
    This code contains the training and testing loop for training the vision transformers model. It requires two
    parameters

    :param train_loader: The dataloader for the training set for training the model.
    :param test_loader: The dataloader for the testing set during evaluation phase.
    """
    print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")

    epochs = 10
    criterion = torch.nn.CrossEntropyLoss()
    for epoch in trange(epochs, desc="train"):
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = mnist_model(x)
            loss = criterion(y_hat, y)

            train_loss += loss.detach().cpu().item() / len(train_loader)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch + 1}/{epochs} loss: {train_loss:.2f}")

    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = mnist_model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.detach().cpu().item() / len(test_loader)

            correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
            total += len(x)

        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {correct / total * 100:.2f}%")

In [9]:
transform = transforms.ToTensor()
train_mnist = MNIST(root='./mnist', train=True, download=True, transform=transform)
test_mnist = MNIST(root='./mnist', train=False, download=True, transform=transform)
train_loader = DataLoader(train_mnist, shuffle=True, batch_size=128)
test_loader = DataLoader(test_mnist, shuffle=False, batch_size=128)
main(train_loader=train_loader, test_loader=test_loader)

Using device:  cuda (Tesla P100-PCIE-16GB)


train:   0%|          | 0/10 [00:00<?, ?it/s]
Epoch 1 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 1 in training:   0%|          | 1/469 [00:01<12:05,  1.55s/it][A
Epoch 1 in training:   0%|          | 2/469 [00:02<11:13,  1.44s/it][A
Epoch 1 in training:   1%|          | 3/469 [00:04<10:59,  1.42s/it][A
Epoch 1 in training:   1%|          | 4/469 [00:05<10:52,  1.40s/it][A
Epoch 1 in training:   1%|          | 5/469 [00:07<10:54,  1.41s/it][A
Epoch 1 in training:   1%|▏         | 6/469 [00:08<10:49,  1.40s/it][A
Epoch 1 in training:   1%|▏         | 7/469 [00:10<11:07,  1.44s/it][A
Epoch 1 in training:   2%|▏         | 8/469 [00:11<10:59,  1.43s/it][A
Epoch 1 in training:   2%|▏         | 9/469 [00:12<10:55,  1.43s/it][A
Epoch 1 in training:   2%|▏         | 10/469 [00:14<10:55,  1.43s/it][A
Epoch 1 in training:   2%|▏         | 11/469 [00:15<10:49,  1.42s/it][A
Epoch 1 in training:   3%|▎         | 12/469 [00:17<10:57,  1.44s/it][A
Epoch 1 in training:   

Epoch 1/10 loss: 2.06



Epoch 2 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 2 in training:   0%|          | 1/469 [00:01<10:36,  1.36s/it][A
Epoch 2 in training:   0%|          | 2/469 [00:02<10:40,  1.37s/it][A
Epoch 2 in training:   1%|          | 3/469 [00:04<10:45,  1.39s/it][A
Epoch 2 in training:   1%|          | 4/469 [00:05<10:44,  1.39s/it][A
Epoch 2 in training:   1%|          | 5/469 [00:06<10:42,  1.38s/it][A
Epoch 2 in training:   1%|▏         | 6/469 [00:08<10:46,  1.40s/it][A
Epoch 2 in training:   1%|▏         | 7/469 [00:09<10:43,  1.39s/it][A
Epoch 2 in training:   2%|▏         | 8/469 [00:11<10:43,  1.40s/it][A
Epoch 2 in training:   2%|▏         | 9/469 [00:12<10:49,  1.41s/it][A
Epoch 2 in training:   2%|▏         | 10/469 [00:14<10:52,  1.42s/it][A
Epoch 2 in training:   2%|▏         | 11/469 [00:15<10:46,  1.41s/it][A
Epoch 2 in training:   3%|▎         | 12/469 [00:16<10:39,  1.40s/it][A
Epoch 2 in training:   3%|▎         | 13/469 [00:18<10:53,  1.43s/it

Epoch 2/10 loss: 1.91



Epoch 3 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 3 in training:   0%|          | 1/469 [00:01<10:35,  1.36s/it][A
Epoch 3 in training:   0%|          | 2/469 [00:02<10:37,  1.36s/it][A
Epoch 3 in training:   1%|          | 3/469 [00:04<10:49,  1.39s/it][A
Epoch 3 in training:   1%|          | 4/469 [00:05<10:44,  1.39s/it][A
Epoch 3 in training:   1%|          | 5/469 [00:06<10:45,  1.39s/it][A
Epoch 3 in training:   1%|▏         | 6/469 [00:08<10:44,  1.39s/it][A
Epoch 3 in training:   1%|▏         | 7/469 [00:09<10:41,  1.39s/it][A
Epoch 3 in training:   2%|▏         | 8/469 [00:11<10:37,  1.38s/it][A
Epoch 3 in training:   2%|▏         | 9/469 [00:12<10:35,  1.38s/it][A
Epoch 3 in training:   2%|▏         | 10/469 [00:13<10:34,  1.38s/it][A
Epoch 3 in training:   2%|▏         | 11/469 [00:15<10:32,  1.38s/it][A
Epoch 3 in training:   3%|▎         | 12/469 [00:16<10:36,  1.39s/it][A
Epoch 3 in training:   3%|▎         | 13/469 [00:18<10:35,  1.39s/it

Epoch 3/10 loss: 1.83



Epoch 4 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 4 in training:   0%|          | 1/469 [00:01<10:42,  1.37s/it][A
Epoch 4 in training:   0%|          | 2/469 [00:02<10:47,  1.39s/it][A
Epoch 4 in training:   1%|          | 3/469 [00:04<11:14,  1.45s/it][A
Epoch 4 in training:   1%|          | 4/469 [00:05<10:57,  1.41s/it][A
Epoch 4 in training:   1%|          | 5/469 [00:07<10:50,  1.40s/it][A
Epoch 4 in training:   1%|▏         | 6/469 [00:08<10:43,  1.39s/it][A
Epoch 4 in training:   1%|▏         | 7/469 [00:09<10:41,  1.39s/it][A
Epoch 4 in training:   2%|▏         | 8/469 [00:11<10:40,  1.39s/it][A
Epoch 4 in training:   2%|▏         | 9/469 [00:12<10:42,  1.40s/it][A
Epoch 4 in training:   2%|▏         | 10/469 [00:13<10:38,  1.39s/it][A
Epoch 4 in training:   2%|▏         | 11/469 [00:15<10:34,  1.39s/it][A
Epoch 4 in training:   3%|▎         | 12/469 [00:16<10:31,  1.38s/it][A
Epoch 4 in training:   3%|▎         | 13/469 [00:18<10:32,  1.39s/it

Epoch 4/10 loss: 1.80



Epoch 5 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 5 in training:   0%|          | 1/469 [00:01<10:49,  1.39s/it][A
Epoch 5 in training:   0%|          | 2/469 [00:02<10:48,  1.39s/it][A
Epoch 5 in training:   1%|          | 3/469 [00:04<10:45,  1.39s/it][A
Epoch 5 in training:   1%|          | 4/469 [00:05<10:49,  1.40s/it][A
Epoch 5 in training:   1%|          | 5/469 [00:06<10:46,  1.39s/it][A
Epoch 5 in training:   1%|▏         | 6/469 [00:08<10:46,  1.40s/it][A
Epoch 5 in training:   1%|▏         | 7/469 [00:09<10:43,  1.39s/it][A
Epoch 5 in training:   2%|▏         | 8/469 [00:11<11:03,  1.44s/it][A
Epoch 5 in training:   2%|▏         | 9/469 [00:12<10:54,  1.42s/it][A
Epoch 5 in training:   2%|▏         | 10/469 [00:14<10:46,  1.41s/it][A
Epoch 5 in training:   2%|▏         | 11/469 [00:15<10:43,  1.41s/it][A
Epoch 5 in training:   3%|▎         | 12/469 [00:16<10:38,  1.40s/it][A
Epoch 5 in training:   3%|▎         | 13/469 [00:18<10:33,  1.39s/it

Epoch 5/10 loss: 1.78



Epoch 6 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 6 in training:   0%|          | 1/469 [00:01<10:40,  1.37s/it][A
Epoch 6 in training:   0%|          | 2/469 [00:02<10:43,  1.38s/it][A
Epoch 6 in training:   1%|          | 3/469 [00:04<10:55,  1.41s/it][A
Epoch 6 in training:   1%|          | 4/469 [00:05<10:54,  1.41s/it][A
Epoch 6 in training:   1%|          | 5/469 [00:06<10:48,  1.40s/it][A
Epoch 6 in training:   1%|▏         | 6/469 [00:08<10:43,  1.39s/it][A
Epoch 6 in training:   1%|▏         | 7/469 [00:09<10:42,  1.39s/it][A
Epoch 6 in training:   2%|▏         | 8/469 [00:11<10:41,  1.39s/it][A
Epoch 6 in training:   2%|▏         | 9/469 [00:12<10:42,  1.40s/it][A
Epoch 6 in training:   2%|▏         | 10/469 [00:13<10:39,  1.39s/it][A
Epoch 6 in training:   2%|▏         | 11/469 [00:15<10:35,  1.39s/it][A
Epoch 6 in training:   3%|▎         | 12/469 [00:16<10:33,  1.39s/it][A
Epoch 6 in training:   3%|▎         | 13/469 [00:18<10:30,  1.38s/it

Epoch 6/10 loss: 1.76



Epoch 7 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 7 in training:   0%|          | 1/469 [00:01<10:36,  1.36s/it][A
Epoch 7 in training:   0%|          | 2/469 [00:02<10:37,  1.36s/it][A
Epoch 7 in training:   1%|          | 3/469 [00:04<10:40,  1.37s/it][A
Epoch 7 in training:   1%|          | 4/469 [00:05<10:54,  1.41s/it][A
Epoch 7 in training:   1%|          | 5/469 [00:07<10:58,  1.42s/it][A
Epoch 7 in training:   1%|▏         | 6/469 [00:08<10:51,  1.41s/it][A
Epoch 7 in training:   1%|▏         | 7/469 [00:09<10:48,  1.40s/it][A
Epoch 7 in training:   2%|▏         | 8/469 [00:11<10:46,  1.40s/it][A
Epoch 7 in training:   2%|▏         | 9/469 [00:12<10:42,  1.40s/it][A
Epoch 7 in training:   2%|▏         | 10/469 [00:13<10:38,  1.39s/it][A
Epoch 7 in training:   2%|▏         | 11/469 [00:15<10:42,  1.40s/it][A
Epoch 7 in training:   3%|▎         | 12/469 [00:16<10:37,  1.39s/it][A
Epoch 7 in training:   3%|▎         | 13/469 [00:18<10:32,  1.39s/it

Epoch 7/10 loss: 1.74



Epoch 8 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 8 in training:   0%|          | 1/469 [00:01<11:50,  1.52s/it][A
Epoch 8 in training:   0%|          | 2/469 [00:02<11:14,  1.45s/it][A
Epoch 8 in training:   1%|          | 3/469 [00:04<11:11,  1.44s/it][A
Epoch 8 in training:   1%|          | 4/469 [00:05<11:23,  1.47s/it][A
Epoch 8 in training:   1%|          | 5/469 [00:07<11:09,  1.44s/it][A
Epoch 8 in training:   1%|▏         | 6/469 [00:08<11:04,  1.44s/it][A
Epoch 8 in training:   1%|▏         | 7/469 [00:10<11:00,  1.43s/it][A
Epoch 8 in training:   2%|▏         | 8/469 [00:11<10:53,  1.42s/it][A
Epoch 8 in training:   2%|▏         | 9/469 [00:12<10:53,  1.42s/it][A
Epoch 8 in training:   2%|▏         | 10/469 [00:14<10:50,  1.42s/it][A
Epoch 8 in training:   2%|▏         | 11/469 [00:15<10:55,  1.43s/it][A
Epoch 8 in training:   3%|▎         | 12/469 [00:17<10:50,  1.42s/it][A
Epoch 8 in training:   3%|▎         | 13/469 [00:18<10:44,  1.41s/it

Epoch 8/10 loss: 1.72



Epoch 9 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 9 in training:   0%|          | 1/469 [00:01<10:44,  1.38s/it][A
Epoch 9 in training:   0%|          | 2/469 [00:02<10:41,  1.37s/it][A
Epoch 9 in training:   1%|          | 3/469 [00:04<10:41,  1.38s/it][A
Epoch 9 in training:   1%|          | 4/469 [00:05<10:44,  1.39s/it][A
Epoch 9 in training:   1%|          | 5/469 [00:06<10:40,  1.38s/it][A
Epoch 9 in training:   1%|▏         | 6/469 [00:08<10:47,  1.40s/it][A
Epoch 9 in training:   1%|▏         | 7/469 [00:09<10:42,  1.39s/it][A
Epoch 9 in training:   2%|▏         | 8/469 [00:11<10:43,  1.40s/it][A
Epoch 9 in training:   2%|▏         | 9/469 [00:12<10:40,  1.39s/it][A
Epoch 9 in training:   2%|▏         | 10/469 [00:14<10:55,  1.43s/it][A
Epoch 9 in training:   2%|▏         | 11/469 [00:15<10:50,  1.42s/it][A
Epoch 9 in training:   3%|▎         | 12/469 [00:16<10:43,  1.41s/it][A
Epoch 9 in training:   3%|▎         | 13/469 [00:18<10:47,  1.42s/it

Epoch 9/10 loss: 1.69



Epoch 10 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 10 in training:   0%|          | 1/469 [00:01<10:59,  1.41s/it][A
Epoch 10 in training:   0%|          | 2/469 [00:02<11:01,  1.42s/it][A
Epoch 10 in training:   1%|          | 3/469 [00:04<11:00,  1.42s/it][A
Epoch 10 in training:   1%|          | 4/469 [00:05<11:03,  1.43s/it][A
Epoch 10 in training:   1%|          | 5/469 [00:07<10:58,  1.42s/it][A
Epoch 10 in training:   1%|▏         | 6/469 [00:08<10:52,  1.41s/it][A
Epoch 10 in training:   1%|▏         | 7/469 [00:09<10:49,  1.41s/it][A
Epoch 10 in training:   2%|▏         | 8/469 [00:11<10:47,  1.40s/it][A
Epoch 10 in training:   2%|▏         | 9/469 [00:12<11:07,  1.45s/it][A
Epoch 10 in training:   2%|▏         | 10/469 [00:14<10:59,  1.44s/it][A
Epoch 10 in training:   2%|▏         | 11/469 [00:15<10:55,  1.43s/it][A
Epoch 10 in training:   3%|▎         | 12/469 [00:17<11:07,  1.46s/it][A
Epoch 10 in training:   3%|▎         | 13/469 [00:18<11

Epoch 10/10 loss: 1.68


Testing: 100%|██████████| 79/79 [00:49<00:00,  1.60it/s]

Test loss: 1.70
Test accuracy: 76.14%





In [10]:
path: str = "chebykan_vit_10epochs.pth"

torch.save(mnist_model.state_dict(), path)