In [14]:
# Original code has been adapted from the KAN Paper.
import math
import numpy
import torch
import torch.nn.functional as F

In [15]:
torch.manual_seed(42)
numpy.random.seed(42)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [16]:
class KANLinear(torch.nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        assert x.size(-1) == self.in_features
        original_shape = x.shape
        x = x.reshape(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output
        
        output = output.reshape(*original_shape[:-1], self.out_features)
        return output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a dumb simulation of the original L1 regularization as stated in the
        paper, since the original one requires computing absolutes and entropy from the
        expanded (batch, in_features, out_features) intermediate tensor, which is hidden
        behind the F.linear function if we want an memory efficient implementation.

        The L1 regularization is now computed as mean absolute value of the spline
        weights. The authors implementation also includes this term in addition to the
        sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


In [17]:
def patchify(images, n_patches): 
    """
    The purpose of this function is to break down the main image into multiple sub-images and map them. 

    Args:
        images (_type_): The image passeed into this function. 
        n_patches (_type_): The number of sub-images that will be created.
    """
    
    n, c, h, w = images.shape
    assert h == w, "Only for square images"
    
    patches = torch.zeros(n, n_patches**2, h * w * c // n_patches ** 2) # The equation to calculate the patches
    patch_size = h // n_patches
    
    for idx, image in enumerate(images):
        for i in range(n_patches): 
            for j in range(n_patches): 
                patch = image[:, i * patch_size: (i + 1) * patch_size, j * patch_size: (j + 1) * patch_size]
                patches[idx, i * n_patches + j] = patch.flatten()
    return patches

In [18]:
def positional_embeddings(seq_length, d): 
    """
    the purpose of this function is to find high and low interaction of a word with surrounding words. 
    We can do so by the following equation below: 
    
    Args: 
        seq_length (int): The length of the sequence/sentence
        d (int): The dimension of the embedding
    """
    
    result = torch.ones(seq_length, d)
    for i in range(seq_length): 
        for j in range(d): 
            result[i][j] = numpy.sin(i / 10000 ** (j / d)) if j % 2 == 0 else numpy.cos(i / 10000 ** (j/ d))
    return result

In [24]:
class MSA(torch.nn.Module): 
    """
        Initializes the Multi-Head Self-Attention (MSA) module with the given dimensions.

        Args:
            d (int): The total dimension of the input.
            n_heads (int): The number of attention heads.

        Returns:
            None
    """
    def __init__(self, d, n_heads): 
        super(MSA, self).__init__()
        self.d = d 
        self.n_heads = n_heads
        
        assert d % n_heads == 0 
        d_head = int(d / n_heads)
        
        self.q_mappings = torch.nn.ModuleList([KANLinear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = torch.nn.ModuleList([KANLinear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = torch.nn.ModuleList([KANLinear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = torch.nn.Softmax(dim=-1)
        
    def forward(self, sequence): 
        result = [] 
        for sequence in sequence: 
            seq_res = [] 
            for head in range(self.n_heads): 
                q_map = self.q_mappings[head]
                k_map = self.k_mappings[head]
                v_map = self.v_mappings[head]
                
                seq = sequence[:, head*self.d_head: (head+1)*self.d_head]
                q, k, v = q_map(seq), k_map(seq), v_map(seq)
                
                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                seq_res.append(attention @ v)
            result.append(torch.hstack(seq_res))
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])

In [25]:
class Residual(torch.nn.Module): 
    """
        Initializes a Residual module.

        Args:
            d_hidden (int): The number of hidden dimensions.
            n_heads (int): The number of attention heads.
            mlp_ratio (int, optional): The ratio of the number of hidden dimensions in the MLP layer. Defaults to 4.

        Returns:
            None
    """
    def __init___(self, d_hidden, n_heads, mlp_ratio=4): 
        super(Residual, self).__init__()
        self.d_hidden = d_hidden
        self.n_heads = n_heads
        self.norm1 = torch.nn.LayerNorm(d_hidden)
        self.mhsa = MSA(d_hidden, n_heads)
        self.ml = torch.nn.Sequential(
            KANLinear(d_hidden, mlp_ratio * d_hidden), 
            torch.nn.GELU(), 
            KANLinear(mlp_ratio * d_hidden, d_hidden)
        )
        
    def forward(self, x): 
        out = x = self.mhsa(self.norm1(x))
        out = out + self.mlp(self.norm2(x))
        return out

In [29]:
class KAN_ViT(torch.nn.Module): 
    """
        Initializes a Vision Transformer (ViT) module.

        Args:
            chw (list/tuple of 3 ints): The input image shape.
            n_patches (int, optional): The number of patches to split the image into. Defaults to 10.
            n_blocks (int, optional): The number of blocks in the transformer encoder. Defaults to 2.
            d_hidden (int, optional): The number of hidden dimensions in the transformer encoder. Defaults to 8.
            n_heads (int, optional): The number of attention heads in each block. Defaults to 2.
            out_d (int, optional): The number of output dimensions. Defaults to 10.

        Returns:
            None
    """    
    def __init__(self, chw, n_patches=10, n_blocks=2, d_hidden=8, n_heads=2, out_d=10): 
        super(KAN_ViT, self).__init__()
        
        self.chw = chw
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.d_hidden = d_hidden
        
        assert chw[1] % n_patches == 0 
        assert chw[2] % n_patches == 0
        
        self.patch_size = (chw[1] / n_patches, chw[2] / n_patches)

        # Linear mapping
        self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = KANLinear(self.input_d, self.d_hidden)

        # Classification token
        self.v_class = torch.nn.Parameter(torch.rand(1, self.d_hidden))

        # Positional embedding
        self.register_buffer('positional_embeddings', positional_embeddings(n_patches ** 2 + 1, d_hidden),
                             persistent=False)

        # Encoder blocks
        self.blocks = torch.nn.ModuleList([MSA(d_hidden, n_heads) for _ in range(n_blocks)])

        self.mlp = torch.nn.Sequential(
            KANLinear(self.d_hidden, out_d),
            torch.nn.Softmax(dim=-1)
        )

    def forward(self, images):
        n, c, h, w = images.shape
        patches = patchify(images, self.n_patches).to(self.positional_embeddings.device)

        # running tokenization
        tokens = self.linear_mapper(patches)
        tokens = torch.cat((self.v_class.expand(n, 1, -1), tokens), dim=1)
        out = tokens + self.positional_embeddings.repeat(n, 1, 1)

        for block in self.blocks:
            out = block(out)

        out = out[:, 0]
        return self.mlp(out)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mnist_model = KAN_ViT((1, 28, 28), n_patches=7, n_blocks=2, d_hidden=8, n_heads=2, out_d=10).to(device)
optimizer = Adam(mnist_model.parameters(), lr=0.005)

In [34]:
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from tqdm import tqdm, trange

def main(train_loader, test_loader):
    """
    This code contains the training and testing loop for training the vision transformers model. It requires two
    parameters

    :param train_loader: The dataloader for the training set for training the model.
    :param test_loader: The dataloader for the testing set during evaluation phase.
    """
    print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")

    epochs = 5
    criterion = torch.nn.CrossEntropyLoss()
    for epoch in trange(epochs, desc="train"):
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = mnist_model(x)
            loss = criterion(y_hat, y)

            train_loss += loss.detach().cpu().item() / len(train_loader)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch + 1}/{epochs} loss: {train_loss:.2f}")

    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = mnist_model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.detach().cpu().item() / len(test_loader)

            correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
            total += len(x)

        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {correct / total * 100:.2f}%")

In [35]:
transform = transforms.ToTensor()
train_mnist = MNIST(root='./mnist', train=True, download=True, transform=transform)
test_mnist = MNIST(root='./mnist', train=False, download=True, transform=transform)
train_loader = DataLoader(train_mnist, shuffle=True, batch_size=128)
test_loader = DataLoader(test_mnist, shuffle=False, batch_size=128)
main(train_loader=train_loader, test_loader=test_loader)

Using device:  cuda (Tesla P100-PCIE-16GB)


train:   0%|          | 0/8 [00:00<?, ?it/s]
Epoch 1 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 1 in training:   0%|          | 1/469 [00:03<26:48,  3.44s/it][A
Epoch 1 in training:   0%|          | 2/469 [00:07<28:18,  3.64s/it][A
Epoch 1 in training:   1%|          | 3/469 [00:10<28:31,  3.67s/it][A
Epoch 1 in training:   1%|          | 4/469 [00:14<28:10,  3.63s/it][A
Epoch 1 in training:   1%|          | 5/469 [00:18<28:28,  3.68s/it][A
Epoch 1 in training:   1%|▏         | 6/469 [00:21<28:05,  3.64s/it][A
Epoch 1 in training:   1%|▏         | 7/469 [00:25<28:00,  3.64s/it][A
Epoch 1 in training:   2%|▏         | 8/469 [00:29<27:46,  3.61s/it][A
Epoch 1 in training:   2%|▏         | 9/469 [00:32<27:56,  3.64s/it][A
Epoch 1 in training:   2%|▏         | 10/469 [00:36<27:44,  3.63s/it][A
Epoch 1 in training:   2%|▏         | 11/469 [00:39<27:39,  3.62s/it][A
Epoch 1 in training:   3%|▎         | 12/469 [00:43<28:08,  3.69s/it][A
Epoch 1 in training:   3

Epoch 1/8 loss: 2.08



Epoch 2 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 2 in training:   0%|          | 1/469 [00:03<27:28,  3.52s/it][A
Epoch 2 in training:   0%|          | 2/469 [00:07<27:29,  3.53s/it][A
Epoch 2 in training:   1%|          | 3/469 [00:10<27:53,  3.59s/it][A
Epoch 2 in training:   1%|          | 4/469 [00:14<27:46,  3.58s/it][A
Epoch 2 in training:   1%|          | 5/469 [00:17<27:46,  3.59s/it][A
Epoch 2 in training:   1%|▏         | 6/469 [00:21<27:55,  3.62s/it][A
Epoch 2 in training:   1%|▏         | 7/469 [00:25<27:50,  3.62s/it][A
Epoch 2 in training:   2%|▏         | 8/469 [00:28<27:39,  3.60s/it][A
Epoch 2 in training:   2%|▏         | 9/469 [00:32<27:49,  3.63s/it][A
Epoch 2 in training:   2%|▏         | 10/469 [00:36<27:35,  3.61s/it][A
Epoch 2 in training:   2%|▏         | 11/469 [00:39<27:25,  3.59s/it][A
Epoch 2 in training:   3%|▎         | 12/469 [00:43<27:37,  3.63s/it][A
Epoch 2 in training:   3%|▎         | 13/469 [00:46<27:31,  3.62s/it

Epoch 2/8 loss: 1.87



Epoch 3 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 3 in training:   0%|          | 1/469 [00:03<28:22,  3.64s/it][A
Epoch 3 in training:   0%|          | 2/469 [00:07<28:14,  3.63s/it][A
Epoch 3 in training:   1%|          | 3/469 [00:11<28:44,  3.70s/it][A
Epoch 3 in training:   1%|          | 4/469 [00:14<28:20,  3.66s/it][A
Epoch 3 in training:   1%|          | 5/469 [00:18<28:08,  3.64s/it][A
Epoch 3 in training:   1%|▏         | 6/469 [00:21<28:16,  3.66s/it][A
Epoch 3 in training:   1%|▏         | 7/469 [00:25<28:06,  3.65s/it][A
Epoch 3 in training:   2%|▏         | 8/469 [00:29<27:47,  3.62s/it][A
Epoch 3 in training:   2%|▏         | 9/469 [00:32<28:00,  3.65s/it][A
Epoch 3 in training:   2%|▏         | 10/469 [00:36<27:51,  3.64s/it][A
Epoch 3 in training:   2%|▏         | 11/469 [00:40<27:38,  3.62s/it][A
Epoch 3 in training:   3%|▎         | 12/469 [00:43<27:47,  3.65s/it][A
Epoch 3 in training:   3%|▎         | 13/469 [00:47<27:37,  3.64s/it

Epoch 3/8 loss: 1.81



Epoch 4 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 4 in training:   0%|          | 1/469 [00:03<28:33,  3.66s/it][A
Epoch 4 in training:   0%|          | 2/469 [00:07<27:55,  3.59s/it][A
Epoch 4 in training:   1%|          | 3/469 [00:10<27:45,  3.57s/it][A
Epoch 4 in training:   1%|          | 4/469 [00:14<28:03,  3.62s/it][A
Epoch 4 in training:   1%|          | 5/469 [00:17<27:44,  3.59s/it][A
Epoch 4 in training:   1%|▏         | 6/469 [00:21<27:35,  3.58s/it][A
Epoch 4 in training:   1%|▏         | 7/469 [00:25<28:01,  3.64s/it][A
Epoch 4 in training:   2%|▏         | 8/469 [00:28<27:45,  3.61s/it][A
Epoch 4 in training:   2%|▏         | 9/469 [00:32<27:29,  3.59s/it][A
Epoch 4 in training:   2%|▏         | 10/469 [00:36<27:41,  3.62s/it][A
Epoch 4 in training:   2%|▏         | 11/469 [00:39<27:31,  3.61s/it][A
Epoch 4 in training:   3%|▎         | 12/469 [00:43<27:40,  3.63s/it][A
Epoch 4 in training:   3%|▎         | 13/469 [00:46<27:28,  3.62s/it

Epoch 4/8 loss: 1.76



Epoch 5 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 5 in training:   0%|          | 1/469 [00:03<30:42,  3.94s/it][A
Epoch 5 in training:   0%|          | 2/469 [00:07<31:10,  4.01s/it][A
Epoch 5 in training:   1%|          | 3/469 [00:12<31:31,  4.06s/it][A
Epoch 5 in training:   1%|          | 4/469 [00:16<31:29,  4.06s/it][A
Epoch 5 in training:   1%|          | 5/469 [00:20<31:19,  4.05s/it][A
Epoch 5 in training:   1%|▏         | 6/469 [00:24<31:03,  4.02s/it][A
Epoch 5 in training:   1%|▏         | 7/469 [00:28<31:06,  4.04s/it][A
Epoch 5 in training:   2%|▏         | 8/469 [00:32<30:48,  4.01s/it][A
Epoch 5 in training:   2%|▏         | 9/469 [00:36<31:06,  4.06s/it][A
Epoch 5 in training:   2%|▏         | 10/469 [00:40<30:59,  4.05s/it][A
Epoch 5 in training:   2%|▏         | 11/469 [00:44<30:50,  4.04s/it][A
Epoch 5 in training:   3%|▎         | 12/469 [00:48<31:14,  4.10s/it][A
Epoch 5 in training:   3%|▎         | 13/469 [00:52<30:52,  4.06s/it

Epoch 5/8 loss: 1.72



Epoch 6 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 6 in training:   0%|          | 1/469 [00:03<27:45,  3.56s/it][A
Epoch 6 in training:   0%|          | 2/469 [00:07<27:34,  3.54s/it][A
Epoch 6 in training:   1%|          | 3/469 [00:10<28:29,  3.67s/it][A
Epoch 6 in training:   1%|          | 4/469 [00:14<28:09,  3.63s/it][A
Epoch 6 in training:   1%|          | 5/469 [00:18<27:56,  3.61s/it][A
Epoch 6 in training:   1%|▏         | 6/469 [00:21<28:10,  3.65s/it][A
Epoch 6 in training:   1%|▏         | 7/469 [00:25<27:53,  3.62s/it][A
Epoch 6 in training:   2%|▏         | 8/469 [00:28<27:39,  3.60s/it][A
Epoch 6 in training:   2%|▏         | 9/469 [00:32<27:35,  3.60s/it][A
Epoch 6 in training:   2%|▏         | 10/469 [00:36<27:51,  3.64s/it][A
Epoch 6 in training:   2%|▏         | 11/469 [00:39<27:53,  3.65s/it][A
Epoch 6 in training:   3%|▎         | 12/469 [00:43<28:04,  3.69s/it][A
Epoch 6 in training:   3%|▎         | 13/469 [00:47<27:52,  3.67s/it

Epoch 6/8 loss: 1.71



Epoch 7 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 7 in training:   0%|          | 1/469 [00:03<27:52,  3.57s/it][A
Epoch 7 in training:   0%|          | 2/469 [00:07<28:23,  3.65s/it][A
Epoch 7 in training:   1%|          | 3/469 [00:10<28:35,  3.68s/it][A
Epoch 7 in training:   1%|          | 4/469 [00:14<28:12,  3.64s/it][A
Epoch 7 in training:   1%|          | 5/469 [00:18<28:11,  3.64s/it][A
Epoch 7 in training:   1%|▏         | 6/469 [00:22<28:37,  3.71s/it][A
Epoch 7 in training:   1%|▏         | 7/469 [00:25<28:18,  3.68s/it][A
Epoch 7 in training:   2%|▏         | 8/469 [00:29<28:11,  3.67s/it][A
Epoch 7 in training:   2%|▏         | 9/469 [00:33<28:34,  3.73s/it][A
Epoch 7 in training:   2%|▏         | 10/469 [00:36<28:22,  3.71s/it][A
Epoch 7 in training:   2%|▏         | 11/469 [00:40<28:01,  3.67s/it][A
Epoch 7 in training:   3%|▎         | 12/469 [00:44<27:49,  3.65s/it][A
Epoch 7 in training:   3%|▎         | 13/469 [00:47<27:59,  3.68s/it

Epoch 7/8 loss: 1.69



Epoch 8 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 8 in training:   0%|          | 1/469 [00:03<27:44,  3.56s/it][A
Epoch 8 in training:   0%|          | 2/469 [00:07<27:29,  3.53s/it][A
Epoch 8 in training:   1%|          | 3/469 [00:10<28:21,  3.65s/it][A
Epoch 8 in training:   1%|          | 4/469 [00:14<27:53,  3.60s/it][A
Epoch 8 in training:   1%|          | 5/469 [00:17<27:35,  3.57s/it][A
Epoch 8 in training:   1%|▏         | 6/469 [00:21<27:54,  3.62s/it][A
Epoch 8 in training:   1%|▏         | 7/469 [00:25<27:36,  3.59s/it][A
Epoch 8 in training:   2%|▏         | 8/469 [00:28<27:25,  3.57s/it][A
Epoch 8 in training:   2%|▏         | 9/469 [00:32<27:42,  3.61s/it][A
Epoch 8 in training:   2%|▏         | 10/469 [00:35<27:29,  3.59s/it][A
Epoch 8 in training:   2%|▏         | 11/469 [00:39<27:21,  3.58s/it][A
Epoch 8 in training:   3%|▎         | 12/469 [00:43<27:54,  3.66s/it][A
Epoch 8 in training:   3%|▎         | 13/469 [00:46<27:36,  3.63s/it

Epoch 8/8 loss: 1.68


Testing: 100%|██████████| 79/79 [01:54<00:00,  1.45s/it]

Test loss: 1.67
Test accuracy: 79.44%





In [None]:
print("Model's state_dict")
for param_tensor in mnist_model.state_dict(): 
  print(param_tensor, "\t", mnist_model.state_dict()[param_tensor].size())

print("Optim's state_dict")
for var in optimizer.state_dict(): 
  print(var, "\t", optimizer.state_dict()[var])

In [None]:
path: str = "./content/kan_vit_5epochs.pth"

torch.save(mnist_model.state_dict(), path)