In [1]:
import torch as th
import numpy as np

In [2]:
import math
import numpy
import torch
import torch.nn.functional as F

torch.manual_seed(42)
numpy.random.seed(42)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [3]:
class NaiveFourierKANLayer(th.nn.Module):
    def __init__( self, inputdim, outdim, gridsize, addbias=True, smooth_initialization=False):
        super(NaiveFourierKANLayer,self).__init__()
        self.gridsize= gridsize
        self.addbias = addbias
        self.inputdim = inputdim
        self.outdim = outdim
        
        # With smooth_initialization, fourier coefficients are attenuated by the square of their frequency.
        # This makes KAN's scalar functions smooth at initialization.
        # Without smooth_initialization, high gridsizes will lead to high-frequency scalar functions,
        # with high derivatives and low correlation between similar inputs.
        grid_norm_factor = (th.arange(gridsize) + 1)**2 if smooth_initialization else np.sqrt(gridsize)
        
        #The normalization has been chosen so that if given inputs where each coordinate is of unit variance,
        #then each coordinates of the output is of unit variance 
        #independently of the various sizes
        self.fouriercoeffs = th.nn.Parameter( th.randn(2,outdim,inputdim,gridsize) / 
                                                (np.sqrt(inputdim) * grid_norm_factor ) )
        if( self.addbias ):
            self.bias  = th.nn.Parameter( th.zeros(1,outdim))

    #x.shape ( ... , indim ) 
    #out.shape ( ..., outdim)
    def forward(self,x):
        xshp = x.shape
        outshape = xshp[0:-1]+(self.outdim,)
        x = th.reshape(x,(-1,self.inputdim))
        #Starting at 1 because constant terms are in the bias
        k = th.reshape( th.arange(1,self.gridsize+1,device=x.device),(1,1,1,self.gridsize))
        xrshp = th.reshape(x,(x.shape[0],1,x.shape[1],1) ) 
        #This should be fused to avoid materializing memory
        c = th.cos( k*xrshp )
        s = th.sin( k*xrshp )
        #We compute the interpolation of the various functions defined by their fourier coefficient for each input coordinates and we sum them 
        y =  th.sum( c*self.fouriercoeffs[0:1],(-2,-1)) 
        y += th.sum( s*self.fouriercoeffs[1:2],(-2,-1))
        if( self.addbias):
            y += self.bias
        #End fuse
        '''
        #You can use einsum instead to reduce memory usage
        #It stills not as good as fully fused but it should help
        #einsum is usually slower though
        c = th.reshape(c,(1,x.shape[0],x.shape[1],self.gridsize))
        s = th.reshape(s,(1,x.shape[0],x.shape[1],self.gridsize))
        y2 = th.einsum( "dbik,djik->bj", th.concat([c,s],axis=0) ,self.fouriercoeffs )
        if( self.addbias):
            y2 += self.bias
        diff = th.sum((y2-y)**2)
        print("diff")
        print(diff) #should be ~0
        '''
        y = th.reshape( y, outshape)
        return y

In [4]:
class NaiveFourierMSA(torch.nn.Module):
    """
        Initializes the Multi-Head Self-Attention (MSA) module with the given dimensions.

        Args:
            d (int): The total dimension of the input.
            n_heads (int): The number of attention heads.

        Returns:
            None
    """
    def __init__(self, d, n_heads):
        super(NaiveFourierMSA, self).__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0
        d_head = int(d / n_heads)

        self.q_mappings = torch.nn.ModuleList([NaiveFourierKANLayer(d_head, d_head, gridsize=4) for _ in range(self.n_heads)])
        self.k_mappings = torch.nn.ModuleList([NaiveFourierKANLayer(d_head, d_head, gridsize=4) for _ in range(self.n_heads)])
        self.v_mappings = torch.nn.ModuleList([NaiveFourierKANLayer(d_head, d_head, gridsize=4) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, sequence):
        result = []
        for sequence in sequence:
            seq_res = []
            for head in range(self.n_heads):
                q_map = self.q_mappings[head]
                k_map = self.k_mappings[head]
                v_map = self.v_mappings[head]

                seq = sequence[:, head*self.d_head: (head+1)*self.d_head]
                q, k, v = q_map(seq), k_map(seq), v_map(seq)

                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                seq_res.append(attention @ v)
            result.append(torch.hstack(seq_res))
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])

In [5]:
class FourierKAN_ViT(torch.nn.Module): 
    """
        Initializes a Vision Transformer (ViT) module.

        Args:
            chw (list/tuple of 3 ints): The input image shape.
            n_patches (int, optional): The number of patches to split the image into. Defaults to 10.
            n_blocks (int, optional): The number of blocks in the transformer encoder. Defaults to 2.
            d_hidden (int, optional): The number of hidden dimensions in the transformer encoder. Defaults to 8.
            n_heads (int, optional): The number of attention heads in each block. Defaults to 2.
            out_d (int, optional): The number of output dimensions. Defaults to 10.

        Returns:
            None
    """    
    def __init__(self, chw, n_patches=10, n_blocks=2, d_hidden=8, n_heads=2, out_d=10): 
        super(FourierKAN_ViT, self).__init__()
        
        self.chw = chw
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.d_hidden = d_hidden
        
        assert chw[1] % n_patches == 0 
        assert chw[2] % n_patches == 0
        
        self.patch_size = (chw[1] / n_patches, chw[2] / n_patches)

        # Linear mapping
        self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = NaiveFourierKANLayer(self.input_d, self.d_hidden, gridsize=28)

        # Classification token
        self.v_class = torch.nn.Parameter(torch.rand(1, self.d_hidden))

        # Positional embedding
        self.register_buffer('pos_embeddings', self.positional_embeddings(n_patches ** 2 + 1, d_hidden),
                             persistent=False)

        # Encoder blocks
        self.blocks = torch.nn.ModuleList([NaiveFourierMSA(d_hidden, n_heads) for _ in range(n_blocks)])

        self.mlp = torch.nn.Sequential(
            NaiveFourierKANLayer(self.d_hidden, out_d, gridsize=4),
            torch.nn.Softmax(dim=-1)
        )
        
    def patchify(self, images, n_patches):
        """
        The purpose of this function is to break down the main image into multiple sub-images and map them.

        Args:
            images (_type_): The image passeed into this function.
            n_patches (_type_): The number of sub-images that will be created.
        """

        n, c, h, w = images.shape
        assert h == w, "Only for square images"

        patches = torch.zeros(n, n_patches**2, h * w * c // n_patches ** 2) # The equation to calculate the patches
        patch_size = h // n_patches

        for idx, image in enumerate(images):
            for i in range(n_patches):
                for j in range(n_patches):
                    patch = image[:, i * patch_size: (i + 1) * patch_size, j * patch_size: (j + 1) * patch_size]
                    patches[idx, i * n_patches + j] = patch.flatten()
        return patches

    def positional_embeddings(self, seq_length, d):
        """
        the purpose of this function is to find high and low interaction of a word with surrounding words.
        We can do so by the following equation below:

        Args:
            seq_length (int): The length of the sequence/sentence
            d (int): The dimension of the embedding
        """

        result = torch.ones(seq_length, d)
        for i in range(seq_length):
            for j in range(d):
                result[i][j] = numpy.sin(i / 10000 ** (j / d)) if j % 2 == 0 else numpy.cos(i / 10000 ** (j/ d))
        return result

    def forward(self, images):
        n, c, h, w = images.shape
        patches = self.patchify(images, self.n_patches).to(self.pos_embeddings.device)

        # running tokenization
        tokens = self.linear_mapper(patches)
        tokens = torch.cat((self.v_class.expand(n, 1, -1), tokens), dim=1)
        out = tokens + self.pos_embeddings.repeat(n, 1, 1)

        for block in self.blocks:
            out = block(out)

        out = out[:, 0]
        return self.mlp(out)

In [6]:
from torch.optim import Adam
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mnist_model = FourierKAN_ViT((1, 28, 28), n_patches=7, n_blocks=2, d_hidden=8, n_heads=2, out_d=10).to(device)
optimizer = Adam(mnist_model.parameters(), lr=0.005)

In [7]:
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from tqdm import tqdm, trange

def main(train_loader, test_loader):
    """
    This code contains the training and testing loop for training the vision transformers model. It requires two
    parameters

    :param train_loader: The dataloader for the training set for training the model.
    :param test_loader: The dataloader for the testing set during evaluation phase.
    """
    print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")

    epochs = 10
    criterion = torch.nn.CrossEntropyLoss()
    for epoch in trange(epochs, desc="train"):
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = mnist_model(x)
            loss = criterion(y_hat, y)

            train_loss += loss.detach().cpu().item() / len(train_loader)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch + 1}/{epochs} loss: {train_loss:.2f}")

    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = mnist_model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.detach().cpu().item() / len(test_loader)

            correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
            total += len(x)

        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {correct / total * 100:.2f}%")

In [8]:
transform = transforms.ToTensor()
train_mnist = MNIST(root='./mnist', train=True, download=True, transform=transform)
test_mnist = MNIST(root='./mnist', train=False, download=True, transform=transform)
train_loader = DataLoader(train_mnist, shuffle=True, batch_size=128)
test_loader = DataLoader(test_mnist, shuffle=False, batch_size=128)
main(train_loader=train_loader, test_loader=test_loader)

Using device:  cuda (Tesla P100-PCIE-16GB)


train:   0%|          | 0/10 [00:00<?, ?it/s]
Epoch 1 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 1 in training:   0%|          | 1/469 [00:02<15:44,  2.02s/it][A
Epoch 1 in training:   0%|          | 2/469 [00:03<13:41,  1.76s/it][A
Epoch 1 in training:   1%|          | 3/469 [00:05<12:59,  1.67s/it][A
Epoch 1 in training:   1%|          | 4/469 [00:06<12:13,  1.58s/it][A
Epoch 1 in training:   1%|          | 5/469 [00:08<12:07,  1.57s/it][A
Epoch 1 in training:   1%|▏         | 6/469 [00:09<11:44,  1.52s/it][A
Epoch 1 in training:   1%|▏         | 7/469 [00:11<11:31,  1.50s/it][A
Epoch 1 in training:   2%|▏         | 8/469 [00:12<11:38,  1.51s/it][A
Epoch 1 in training:   2%|▏         | 9/469 [00:14<11:25,  1.49s/it][A
Epoch 1 in training:   2%|▏         | 10/469 [00:15<11:18,  1.48s/it][A
Epoch 1 in training:   2%|▏         | 11/469 [00:16<11:13,  1.47s/it][A
Epoch 1 in training:   3%|▎         | 12/469 [00:18<11:13,  1.47s/it][A
Epoch 1 in training:   

Epoch 1/10 loss: 2.28



Epoch 2 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 2 in training:   0%|          | 1/469 [00:01<11:15,  1.44s/it][A
Epoch 2 in training:   0%|          | 2/469 [00:02<11:13,  1.44s/it][A
Epoch 2 in training:   1%|          | 3/469 [00:04<11:11,  1.44s/it][A
Epoch 2 in training:   1%|          | 4/469 [00:05<11:10,  1.44s/it][A
Epoch 2 in training:   1%|          | 5/469 [00:07<11:07,  1.44s/it][A
Epoch 2 in training:   1%|▏         | 6/469 [00:08<11:05,  1.44s/it][A
Epoch 2 in training:   1%|▏         | 7/469 [00:10<11:27,  1.49s/it][A
Epoch 2 in training:   2%|▏         | 8/469 [00:11<11:17,  1.47s/it][A
Epoch 2 in training:   2%|▏         | 9/469 [00:13<11:11,  1.46s/it][A
Epoch 2 in training:   2%|▏         | 10/469 [00:14<11:07,  1.45s/it][A
Epoch 2 in training:   2%|▏         | 11/469 [00:15<11:03,  1.45s/it][A
Epoch 2 in training:   3%|▎         | 12/469 [00:17<10:57,  1.44s/it][A
Epoch 2 in training:   3%|▎         | 13/469 [00:18<10:56,  1.44s/it

Epoch 2/10 loss: 2.27



Epoch 3 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 3 in training:   0%|          | 1/469 [00:01<11:06,  1.42s/it][A
Epoch 3 in training:   0%|          | 2/469 [00:02<11:12,  1.44s/it][A
Epoch 3 in training:   1%|          | 3/469 [00:04<11:07,  1.43s/it][A
Epoch 3 in training:   1%|          | 4/469 [00:05<11:18,  1.46s/it][A
Epoch 3 in training:   1%|          | 5/469 [00:07<11:31,  1.49s/it][A
Epoch 3 in training:   1%|▏         | 6/469 [00:08<11:18,  1.47s/it][A
Epoch 3 in training:   1%|▏         | 7/469 [00:10<11:10,  1.45s/it][A
Epoch 3 in training:   2%|▏         | 8/469 [00:11<11:16,  1.47s/it][A
Epoch 3 in training:   2%|▏         | 9/469 [00:13<11:40,  1.52s/it][A
Epoch 3 in training:   2%|▏         | 10/469 [00:14<11:38,  1.52s/it][A
Epoch 3 in training:   2%|▏         | 11/469 [00:16<11:41,  1.53s/it][A
Epoch 3 in training:   3%|▎         | 12/469 [00:17<11:39,  1.53s/it][A
Epoch 3 in training:   3%|▎         | 13/469 [00:19<11:23,  1.50s/it

Epoch 3/10 loss: 2.26



Epoch 4 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 4 in training:   0%|          | 1/469 [00:01<11:08,  1.43s/it][A
Epoch 4 in training:   0%|          | 2/469 [00:02<11:34,  1.49s/it][A
Epoch 4 in training:   1%|          | 3/469 [00:04<11:21,  1.46s/it][A
Epoch 4 in training:   1%|          | 4/469 [00:05<11:12,  1.45s/it][A
Epoch 4 in training:   1%|          | 5/469 [00:07<11:07,  1.44s/it][A
Epoch 4 in training:   1%|▏         | 6/469 [00:08<11:05,  1.44s/it][A
Epoch 4 in training:   1%|▏         | 7/469 [00:10<11:01,  1.43s/it][A
Epoch 4 in training:   2%|▏         | 8/469 [00:11<11:13,  1.46s/it][A
Epoch 4 in training:   2%|▏         | 9/469 [00:13<11:11,  1.46s/it][A
Epoch 4 in training:   2%|▏         | 10/469 [00:14<11:04,  1.45s/it][A
Epoch 4 in training:   2%|▏         | 11/469 [00:15<10:58,  1.44s/it][A
Epoch 4 in training:   3%|▎         | 12/469 [00:17<10:56,  1.44s/it][A
Epoch 4 in training:   3%|▎         | 13/469 [00:18<10:52,  1.43s/it

Epoch 4/10 loss: 2.25



Epoch 5 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 5 in training:   0%|          | 1/469 [00:01<11:16,  1.45s/it][A
Epoch 5 in training:   0%|          | 2/469 [00:02<11:16,  1.45s/it][A
Epoch 5 in training:   1%|          | 3/469 [00:04<11:25,  1.47s/it][A
Epoch 5 in training:   1%|          | 4/469 [00:05<11:21,  1.47s/it][A
Epoch 5 in training:   1%|          | 5/469 [00:07<11:17,  1.46s/it][A
Epoch 5 in training:   1%|▏         | 6/469 [00:08<11:33,  1.50s/it][A
Epoch 5 in training:   1%|▏         | 7/469 [00:10<11:26,  1.49s/it][A
Epoch 5 in training:   2%|▏         | 8/469 [00:11<11:21,  1.48s/it][A
Epoch 5 in training:   2%|▏         | 9/469 [00:13<11:16,  1.47s/it][A
Epoch 5 in training:   2%|▏         | 10/469 [00:14<11:18,  1.48s/it][A
Epoch 5 in training:   2%|▏         | 11/469 [00:16<11:12,  1.47s/it][A
Epoch 5 in training:   3%|▎         | 12/469 [00:17<11:30,  1.51s/it][A
Epoch 5 in training:   3%|▎         | 13/469 [00:19<11:21,  1.49s/it

Epoch 5/10 loss: 2.24



Epoch 6 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 6 in training:   0%|          | 1/469 [00:01<11:22,  1.46s/it][A
Epoch 6 in training:   0%|          | 2/469 [00:02<11:22,  1.46s/it][A
Epoch 6 in training:   1%|          | 3/469 [00:04<11:17,  1.45s/it][A
Epoch 6 in training:   1%|          | 4/469 [00:05<11:40,  1.51s/it][A
Epoch 6 in training:   1%|          | 5/469 [00:07<11:31,  1.49s/it][A
Epoch 6 in training:   1%|▏         | 6/469 [00:08<11:29,  1.49s/it][A
Epoch 6 in training:   1%|▏         | 7/469 [00:10<11:24,  1.48s/it][A
Epoch 6 in training:   2%|▏         | 8/469 [00:11<11:24,  1.49s/it][A
Epoch 6 in training:   2%|▏         | 9/469 [00:13<11:27,  1.49s/it][A
Epoch 6 in training:   2%|▏         | 10/469 [00:14<11:19,  1.48s/it][A
Epoch 6 in training:   2%|▏         | 11/469 [00:16<11:33,  1.51s/it][A
Epoch 6 in training:   3%|▎         | 12/469 [00:17<11:24,  1.50s/it][A
Epoch 6 in training:   3%|▎         | 13/469 [00:19<11:23,  1.50s/it

Epoch 6/10 loss: 2.23



Epoch 7 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 7 in training:   0%|          | 1/469 [00:01<11:15,  1.44s/it][A
Epoch 7 in training:   0%|          | 2/469 [00:02<11:22,  1.46s/it][A
Epoch 7 in training:   1%|          | 3/469 [00:04<11:54,  1.53s/it][A
Epoch 7 in training:   1%|          | 4/469 [00:05<11:38,  1.50s/it][A
Epoch 7 in training:   1%|          | 5/469 [00:07<11:27,  1.48s/it][A
Epoch 7 in training:   1%|▏         | 6/469 [00:08<11:22,  1.47s/it][A
Epoch 7 in training:   1%|▏         | 7/469 [00:10<11:18,  1.47s/it][A
Epoch 7 in training:   2%|▏         | 8/469 [00:11<11:14,  1.46s/it][A
Epoch 7 in training:   2%|▏         | 9/469 [00:13<11:10,  1.46s/it][A
Epoch 7 in training:   2%|▏         | 10/469 [00:14<11:29,  1.50s/it][A
Epoch 7 in training:   2%|▏         | 11/469 [00:16<11:20,  1.49s/it][A
Epoch 7 in training:   3%|▎         | 12/469 [00:17<11:13,  1.47s/it][A
Epoch 7 in training:   3%|▎         | 13/469 [00:19<11:07,  1.46s/it

Epoch 7/10 loss: 2.23



Epoch 8 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 8 in training:   0%|          | 1/469 [00:01<11:08,  1.43s/it][A
Epoch 8 in training:   0%|          | 2/469 [00:02<11:45,  1.51s/it][A
Epoch 8 in training:   1%|          | 3/469 [00:04<11:37,  1.50s/it][A
Epoch 8 in training:   1%|          | 4/469 [00:05<11:25,  1.47s/it][A
Epoch 8 in training:   1%|          | 5/469 [00:07<11:19,  1.46s/it][A
Epoch 8 in training:   1%|▏         | 6/469 [00:08<11:13,  1.45s/it][A
Epoch 8 in training:   1%|▏         | 7/469 [00:10<11:13,  1.46s/it][A
Epoch 8 in training:   2%|▏         | 8/469 [00:11<11:27,  1.49s/it][A
Epoch 8 in training:   2%|▏         | 9/469 [00:13<11:23,  1.49s/it][A
Epoch 8 in training:   2%|▏         | 10/469 [00:14<11:15,  1.47s/it][A
Epoch 8 in training:   2%|▏         | 11/469 [00:16<11:09,  1.46s/it][A
Epoch 8 in training:   3%|▎         | 12/469 [00:17<11:06,  1.46s/it][A
Epoch 8 in training:   3%|▎         | 13/469 [00:19<11:03,  1.45s/it

Epoch 8/10 loss: 2.23



Epoch 9 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 9 in training:   0%|          | 1/469 [00:01<11:11,  1.43s/it][A
Epoch 9 in training:   0%|          | 2/469 [00:02<11:23,  1.46s/it][A
Epoch 9 in training:   1%|          | 3/469 [00:04<11:16,  1.45s/it][A
Epoch 9 in training:   1%|          | 4/469 [00:05<11:13,  1.45s/it][A
Epoch 9 in training:   1%|          | 5/469 [00:07<11:12,  1.45s/it][A
Epoch 9 in training:   1%|▏         | 6/469 [00:08<11:30,  1.49s/it][A
Epoch 9 in training:   1%|▏         | 7/469 [00:10<11:22,  1.48s/it][A
Epoch 9 in training:   2%|▏         | 8/469 [00:11<11:16,  1.47s/it][A
Epoch 9 in training:   2%|▏         | 9/469 [00:13<11:14,  1.47s/it][A
Epoch 9 in training:   2%|▏         | 10/469 [00:14<11:09,  1.46s/it][A
Epoch 9 in training:   2%|▏         | 11/469 [00:16<11:05,  1.45s/it][A
Epoch 9 in training:   3%|▎         | 12/469 [00:17<11:04,  1.45s/it][A
Epoch 9 in training:   3%|▎         | 13/469 [00:18<11:01,  1.45s/it

Epoch 9/10 loss: 2.23



Epoch 10 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 10 in training:   0%|          | 1/469 [00:01<11:21,  1.46s/it][A
Epoch 10 in training:   0%|          | 2/469 [00:02<11:12,  1.44s/it][A
Epoch 10 in training:   1%|          | 3/469 [00:04<11:10,  1.44s/it][A
Epoch 10 in training:   1%|          | 4/469 [00:05<11:30,  1.49s/it][A
Epoch 10 in training:   1%|          | 5/469 [00:07<11:21,  1.47s/it][A
Epoch 10 in training:   1%|▏         | 6/469 [00:08<11:13,  1.45s/it][A
Epoch 10 in training:   1%|▏         | 7/469 [00:10<11:08,  1.45s/it][A
Epoch 10 in training:   2%|▏         | 8/469 [00:11<11:09,  1.45s/it][A
Epoch 10 in training:   2%|▏         | 9/469 [00:13<11:07,  1.45s/it][A
Epoch 10 in training:   2%|▏         | 10/469 [00:14<11:02,  1.44s/it][A
Epoch 10 in training:   2%|▏         | 11/469 [00:16<11:23,  1.49s/it][A
Epoch 10 in training:   3%|▎         | 12/469 [00:17<11:14,  1.48s/it][A
Epoch 10 in training:   3%|▎         | 13/469 [00:19<11

Epoch 10/10 loss: 2.23


Testing: 100%|██████████| 79/79 [00:49<00:00,  1.59it/s]

Test loss: 2.23
Test accuracy: 20.33%





In [9]:
path: str = "fourierkan_vit_10epochs.pth"

torch.save(mnist_model.state_dict(), path)