In [1]:
from pathlib import Path
import torch
import numpy as np
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
# Uncomment this line for MNIST training.
from torchvision.datasets.mnist import MNIST
from torchvision import transforms
from tqdm import tqdm, trange
from torch.optim import Adam

np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7917b6f8f390>

In [2]:
class MSA(torch.nn.Module):
    """
    This is the template implementation of the "Multi-Scale Attention" Layer.

    The query, key and value mapping are matrix-multipled against each other in order to
    find the attention, or, the relation of a word and its interaction with surrounding words.
    """
    def __init__(self, d, n_heads=4):
        super(MSA, self).__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0  # Shouldn't divide dimension (d) into n_heads

        d_head = int(d / n_heads)
        self.q_mappings = torch.nn.ModuleList([torch.nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = torch.nn.ModuleList([torch.nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = torch.nn.ModuleList([torch.nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, sequences):
        result = []
        for sequence in sequences:
            seq_result = []
            for head in range(self.n_heads):
                q_mapping = self.q_mappings[head]
                k_mapping = self.k_mappings[head]
                v_mapping = self.v_mappings[head]

                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]
                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)

                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                seq_result.append(attention @ v)
            result.append(torch.hstack(seq_result))
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])

In [3]:
!pip install monai
!pip install einops

Collecting monai
  Downloading monai-1.3.2-py3-none-any.whl.metadata (10 kB)
Downloading monai-1.3.2-py3-none-any.whl (1.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m22.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: monai
Successfully installed monai-1.3.2
Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0


In [4]:
import math
import torch
from functools import partial
from torch import nn, einsum
from torch.autograd.function import Function

from einops import rearrange

from torch.jit import fork, wait

from torch.cuda.amp import autocast, GradScaler
from torch.nn import DataParallel
# constants

EPSILON = 1e-10

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

class FlashAttentionFunction(Function):
    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """ Algorithm 1 in the v2 paper """

        device = q.device
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), device = device)
        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, device = device)

        scale = (q.shape[-1] ** -0.5)

        num_row_tiles = math.ceil(q.shape[-2] / q_bucket_size)
        num_col_tiles = math.ceil(k.shape[-2] / k_bucket_size)

        if exists(mask) and mask.ndim == 2:
            mask = rearrange(mask, 'b n -> b 1 1 n')

        if not exists(mask):
            col_masks = (None,) * num_col_tiles
            mask = (col_masks,) * num_row_tiles 
        else:
            mask = ((mask,) * num_row_tiles) if mask.shape[-2] == 1 else mask.split(q_bucket_size, dim = -2)
            mask = tuple(((row_mask,) * num_col_tiles) if row_mask.shape[-1] == 1 else row_mask.split(k_bucket_size, dim = -1) for row_mask in mask)

        row_splits = zip(
            q.split(q_bucket_size, dim = -2),
            o.split(q_bucket_size, dim = -2),
            mask,
            all_row_sums.split(q_bucket_size, dim = -2),
            all_row_maxes.split(q_bucket_size, dim = -2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim = -2),
                v.split(k_bucket_size, dim = -2),
                row_mask
            )

            for k_ind, (kc, vc, col_mask) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                if exists(col_mask):
                    attn_weights.masked_fill_(~col_mask, max_neg_value)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                block_row_maxes = attn_weights.amax(dim = -1, keepdims = True)
                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_weights = torch.exp(attn_weights - new_row_maxes)

                if exists(col_mask):
                    exp_weights.masked_fill_(~col_mask, 0.)

                block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = EPSILON)

                exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)

                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)

                new_row_sums = exp_row_max_diff * row_sums + block_row_sums

                oc.mul_(exp_row_max_diff).add_(exp_values)

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

            oc.div_(row_sums)

        lse = all_row_sums.log() + all_row_maxes

        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, lse)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        """ Algorithm 2 in the v2 paper """

        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, lse = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        row_splits = zip(
            q.split(q_bucket_size, dim = -2),
            o.split(q_bucket_size, dim = -2),
            do.split(q_bucket_size, dim = -2),
            mask,
            lse.split(q_bucket_size, dim = -2),
            dq.split(q_bucket_size, dim = -2)
        )

        for ind, (qc, oc, doc, row_mask, lsec, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim = -2),
                v.split(k_bucket_size, dim = -2),
                dk.split(k_bucket_size, dim = -2),
                dv.split(k_bucket_size, dim = -2),
                row_mask
            )

            for k_ind, (kc, vc, dkc, dvc, col_mask) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                p = torch.exp(attn_weights - lsec)

                if exists(col_mask):
                    p.masked_fill_(~col_mask, 0.)

                dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
                dp = einsum('... i d, ... j d -> ... i j', doc, vc)

                D = (doc * oc).sum(dim = -1, keepdims = True)
                ds = p * scale * (dp - D)

                dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
                dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)

                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None

In [5]:
class FlashAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        causal = False,
        q_bucket_size = 512,
        k_bucket_size = 1024,
        parallel = False,
        mixed_precision = False
    ):
        super().__init__()
        self.heads = heads
        self.causal = causal
        self.parallel = parallel
        self.mixed_precision = mixed_precision

        inner_dim = heads * dim_head

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # memory efficient attention related parameters
        # can be overriden on forward
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

        if self.parallel:
            self.model = DataParallel(self)
        if self.mixed_precision:
            self.scaler = GradScaler()

    def forward(
        self,
        x,
        context = None,
        mask = None,
        q_bucket_size = None,
        k_bucket_size = None,
    ):
        q_bucket_size = default(q_bucket_size, self.q_bucket_size)
        k_bucket_size = default(k_bucket_size, self.k_bucket_size)

        h = self.heads
        context = default(context, x)

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim=-1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        if self.parallel:
            # Split the input data into chunks and move each chunk to the correct GPU
            num_gpus = torch.cuda.device_count()
            x_chunks = x.split(x.size(0) // num_gpus)
            x_chunks = [chunk.to(f'cuda:{i}') for i, chunk in enumerate(x_chunks)]
            q = x_chunks

        if self.mixed_precision:
            # Use autocast to allow operations to run in lower precision
            with autocast():
                out = FlashAttentionFunction.apply(q, k, v, mask, self.causal, q_bucket_size, k_bucket_size)
        else:
            out = FlashAttentionFunction.apply(q, k, v, mask, self.causal, q_bucket_size, k_bucket_size)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

In [6]:
class FlashViT(torch.nn.Module):
    """
    The workflow will be as follows.
        1. Find the linear mapping of the input
        2. Embed them using the function that we have written
        3. Use 'n' MSA blocks and add a linear and a softmax layer at the end
    """

    def __init__(self, chw, n_patches=16, n_blocks=2, hidden_d=8, n_heads=4, out_d=10):
        super(FlashViT, self).__init__()

        self.chw = chw
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.hidden_d = hidden_d

        # Input and patch sizes
        assert chw[1] % n_patches == 0
        assert chw[2] % n_patches == 0
        self.patch_size = (chw[1] / n_patches, chw[2] / n_patches)

        # Linear mapping
        self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = torch.nn.Linear(self.input_d, self.hidden_d)

        # Classification token
        self.v_class = torch.nn.Parameter(torch.rand(1, self.hidden_d))

        # Positional embedding
        self.register_buffer('pos_embeddings', self.positional_embeddings(n_patches ** 2 + 1, hidden_d),
                             persistent=False)

        # Encoder blocks
        self.blocks = torch.nn.ModuleList([FlashAttention(dim = hidden_d, heads = n_heads) for _ in range(n_blocks)])

        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_d, out_d),
            torch.nn.Softmax(dim=-1)
        )
        
    def patchify(self, images, n_patches):
        """
        In order to "sequentially" pass in the images, we can break down the main image into multiple sub-images
        and map them to a vector. This is exactly what this function does.

        Arguments:
        images: The image passed into this function
        n_patches: The number of patches to split the image into.

        Returns our patches aka the sub-images.
        """
        n, c, h, w = images.shape

        assert h == w, "Only for square images"

        patches = torch.zeros(n, n_patches ** 2, h * w * c // n_patches ** 2)
        patch_size = h // n_patches

        for idx, image in enumerate(images):
            for i in range(n_patches):
                for j in range(n_patches):
                    patch = image[:, i * patch_size: (i + 1) * patch_size, j * patch_size: (j + 1) * patch_size]
                    patches[idx, i * n_patches + j] = patch.flatten()
        return patches
    
    def positional_embeddings(self, sequence_length, d):
        """
        In order for the model to know where to place each image, one can use positional embeddings where high freq values
        are classified into the first few dimensions while low frequency values are added on to the latter dimensions. This
        function performs exactly that. It has two parameters.

        Arguments:
        sequence_length: The number of tokens for the dataset.
        d: The dimensionality for each token.

        Returns a matrix where each (i,j) is added as token i in dimension j.
        """
        result = torch.ones(sequence_length, d)
        for i in range(sequence_length):
            for j in range(d):
                result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** (j / d)))
        return result

    def forward(self, images):
        n, c, h, w = images.shape
        patches = self.patchify(images, self.n_patches).to(self.pos_embeddings.device)

        # rutorch.nning tokenization
        tokens = self.linear_mapper(patches)
        tokens = torch.cat((self.v_class.expand(n, 1, -1), tokens), dim=1)
        out = tokens + self.pos_embeddings.repeat(n, 1, 1)

        for block in self.blocks:
            out = block(out)

        out = out[:, 0]
        return self.mlp(out)

In [7]:
lr = 2e-3
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mnist_model = FlashViT((1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10).to(device)
optimizer = Adam(mnist_model.parameters(), lr=lr)

In [8]:
def main(train_loader, test_loader):
    """
    This code contains the training and testing loop for training the vision transformers model. It requires two
    parameters

    :param train_loader: The dataloader for the training set for training the model.
    :param test_loader: The dataloader for the testing set during evaluation phase.
    """
    print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")

    epochs = 5
    criterion = CrossEntropyLoss()
    for epoch in trange(epochs, desc="train"):
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = mnist_model(x)
            loss = criterion(y_hat, y)

            train_loss += loss.detach().cpu().item() / len(train_loader)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch + 1}/{epochs} loss: {train_loss:.2f}")

    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = mnist_model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.detach().cpu().item() / len(test_loader)

            correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
            total += len(x)

        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {correct / total * 100:.2f}%")

In [9]:
# For MNIST: comment out the lines above and uncomment the lines below.

transform = transforms.ToTensor()
train_mnist = MNIST(root='./mnist', train=True, download=True, transform=transform)
test_mnist = MNIST(root='./mnist', train=False, download=True, transform=transform)
train_loader = DataLoader(train_mnist, shuffle=True, batch_size=128)
test_loader = DataLoader(test_mnist, shuffle=False, batch_size=128)
main(train_loader=train_loader, test_loader=test_loader)

Using device:  cuda (Tesla P100-PCIE-16GB)


train:   0%|          | 0/5 [00:00<?, ?it/s]
Epoch 1 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 1 in training:   0%|          | 1/469 [00:00<06:40,  1.17it/s][A
Epoch 1 in training:   0%|          | 2/469 [00:01<04:09,  1.87it/s][A
Epoch 1 in training:   1%|          | 3/469 [00:01<03:18,  2.34it/s][A
Epoch 1 in training:   1%|          | 4/469 [00:01<02:55,  2.66it/s][A
Epoch 1 in training:   1%|          | 5/469 [00:02<02:42,  2.86it/s][A
Epoch 1 in training:   1%|▏         | 6/469 [00:02<02:33,  3.01it/s][A
Epoch 1 in training:   1%|▏         | 7/469 [00:02<02:28,  3.11it/s][A
Epoch 1 in training:   2%|▏         | 8/469 [00:02<02:25,  3.18it/s][A
Epoch 1 in training:   2%|▏         | 9/469 [00:03<02:22,  3.23it/s][A
Epoch 1 in training:   2%|▏         | 10/469 [00:03<02:20,  3.26it/s][A
Epoch 1 in training:   2%|▏         | 11/469 [00:03<02:19,  3.27it/s][A
Epoch 1 in training:   3%|▎         | 12/469 [00:04<02:23,  3.19it/s][A
Epoch 1 in training:   3

Epoch 1/5 loss: 2.17



Epoch 2 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 2 in training:   0%|          | 1/469 [00:00<02:22,  3.29it/s][A
Epoch 2 in training:   0%|          | 2/469 [00:00<02:21,  3.30it/s][A
Epoch 2 in training:   1%|          | 3/469 [00:00<02:20,  3.32it/s][A
Epoch 2 in training:   1%|          | 4/469 [00:01<02:19,  3.33it/s][A
Epoch 2 in training:   1%|          | 5/469 [00:01<02:18,  3.34it/s][A
Epoch 2 in training:   1%|▏         | 6/469 [00:01<02:18,  3.34it/s][A
Epoch 2 in training:   1%|▏         | 7/469 [00:02<02:18,  3.34it/s][A
Epoch 2 in training:   2%|▏         | 8/469 [00:02<02:17,  3.34it/s][A
Epoch 2 in training:   2%|▏         | 9/469 [00:02<02:17,  3.34it/s][A
Epoch 2 in training:   2%|▏         | 10/469 [00:03<02:20,  3.27it/s][A
Epoch 2 in training:   2%|▏         | 11/469 [00:03<02:20,  3.27it/s][A
Epoch 2 in training:   3%|▎         | 12/469 [00:03<02:19,  3.29it/s][A
Epoch 2 in training:   3%|▎         | 13/469 [00:03<02:17,  3.31it/s

Epoch 2/5 loss: 2.11



Epoch 3 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 3 in training:   0%|          | 1/469 [00:00<02:21,  3.32it/s][A
Epoch 3 in training:   0%|          | 2/469 [00:00<02:20,  3.32it/s][A
Epoch 3 in training:   1%|          | 3/469 [00:00<02:19,  3.35it/s][A
Epoch 3 in training:   1%|          | 4/469 [00:01<02:18,  3.35it/s][A
Epoch 3 in training:   1%|          | 5/469 [00:01<02:18,  3.34it/s][A
Epoch 3 in training:   1%|▏         | 6/469 [00:01<02:18,  3.34it/s][A
Epoch 3 in training:   1%|▏         | 7/469 [00:02<02:22,  3.23it/s][A
Epoch 3 in training:   2%|▏         | 8/469 [00:02<02:29,  3.08it/s][A
Epoch 3 in training:   2%|▏         | 9/469 [00:02<02:28,  3.11it/s][A
Epoch 3 in training:   2%|▏         | 10/469 [00:03<02:27,  3.11it/s][A
Epoch 3 in training:   2%|▏         | 11/469 [00:03<02:24,  3.17it/s][A
Epoch 3 in training:   3%|▎         | 12/469 [00:03<02:22,  3.21it/s][A
Epoch 3 in training:   3%|▎         | 13/469 [00:04<02:20,  3.25it/s

Epoch 3/5 loss: 2.03



Epoch 4 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 4 in training:   0%|          | 1/469 [00:00<02:20,  3.33it/s][A
Epoch 4 in training:   0%|          | 2/469 [00:00<02:19,  3.35it/s][A
Epoch 4 in training:   1%|          | 3/469 [00:00<02:19,  3.35it/s][A
Epoch 4 in training:   1%|          | 4/469 [00:01<02:23,  3.24it/s][A
Epoch 4 in training:   1%|          | 5/469 [00:01<02:20,  3.29it/s][A
Epoch 4 in training:   1%|▏         | 6/469 [00:01<02:19,  3.32it/s][A
Epoch 4 in training:   1%|▏         | 7/469 [00:02<02:18,  3.34it/s][A
Epoch 4 in training:   2%|▏         | 8/469 [00:02<02:18,  3.34it/s][A
Epoch 4 in training:   2%|▏         | 9/469 [00:02<02:17,  3.34it/s][A
Epoch 4 in training:   2%|▏         | 10/469 [00:03<02:17,  3.34it/s][A
Epoch 4 in training:   2%|▏         | 11/469 [00:03<02:16,  3.35it/s][A
Epoch 4 in training:   3%|▎         | 12/469 [00:03<02:16,  3.35it/s][A
Epoch 4 in training:   3%|▎         | 13/469 [00:03<02:16,  3.33it/s

Epoch 4/5 loss: 1.93



Epoch 5 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 5 in training:   0%|          | 1/469 [00:00<02:33,  3.05it/s][A
Epoch 5 in training:   0%|          | 2/469 [00:00<02:35,  2.99it/s][A
Epoch 5 in training:   1%|          | 3/469 [00:00<02:32,  3.05it/s][A
Epoch 5 in training:   1%|          | 4/469 [00:01<02:28,  3.12it/s][A
Epoch 5 in training:   1%|          | 5/469 [00:01<02:25,  3.18it/s][A
Epoch 5 in training:   1%|▏         | 6/469 [00:01<02:22,  3.24it/s][A
Epoch 5 in training:   1%|▏         | 7/469 [00:02<02:22,  3.25it/s][A
Epoch 5 in training:   2%|▏         | 8/469 [00:02<02:20,  3.28it/s][A
Epoch 5 in training:   2%|▏         | 9/469 [00:02<02:18,  3.31it/s][A
Epoch 5 in training:   2%|▏         | 10/469 [00:03<02:17,  3.33it/s][A
Epoch 5 in training:   2%|▏         | 11/469 [00:03<02:17,  3.34it/s][A
Epoch 5 in training:   3%|▎         | 12/469 [00:03<02:16,  3.34it/s][A
Epoch 5 in training:   3%|▎         | 13/469 [00:03<02:16,  3.35it/s

Epoch 5/5 loss: 1.84


Testing: 100%|██████████| 79/79 [00:23<00:00,  3.41it/s]

Test loss: 1.82
Test accuracy: 64.35%





In [10]:
path: str = "flash_vit_5epochs.pth"
torch.save(mnist_model.state_dict(), path)