https://www.youtube.com/watch?v=LxPDpAiyqSU&ab_channel=AIHMP

Paper: "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" — https://arxiv.org/pdf/2103.14030

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

In [None]:
def next_iter(temp):
    """Return the first item obtained by iterating over ``temp``.

    A new iterator is requested on every call, so calling this repeatedly on
    a re-iterable object such as a list keeps returning its first element.
    Raises ``StopIteration`` when ``temp`` yields nothing.
    """
    iterator = iter(temp)
    first_item = next(iterator)
    return first_item

In [None]:
x = torch.randn(size = (1, 56, 56, 96))

In [None]:
# Two layouts of the same values, used by the %%timeit cells below to compare matmul cost.
# NOTE(review): relies on x still being the (1, 56, 56, 96) tensor from the cell above — x is reassigned later in the notebook.
x1 = x.reshape(1, -1, 96)  # (1, 3136, 96): all 56 * 56 positions in one sequence
x2 = x.reshape(1, 3, 64, 49, 32)  # presumably (batch, heads, windows, tokens per window, head_dim); a plain reshape, so this is a shape demo rather than a true spatial window partition

In [None]:
x2.shape

torch.Size([1, 3, 64, 49, 32])

In [None]:
%%timeit

x2 @ x2.transpose(-1, -2)

1.53 ms ± 117 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [None]:
%%timeit

x1 @ x1.transpose(-1, -2)

57.1 ms ± 1.61 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
class PatchifyMerging(nn.Module):
    """Downsample a feature map with a convolution whose kernel equals its stride.

    Because ``kernel_size == stride == downscale_factor``, every output
    position is computed from one non-overlapping
    ``downscale_factor x downscale_factor`` patch of the input.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        downscale_factor: side length of each patch (and the spatial stride).
    """

    def __init__(self, in_ch, out_ch, downscale_factor):
        super().__init__()
        self.model = nn.Conv2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=downscale_factor,
            stride=downscale_factor,
        )

    def forward(self, x, is_permute = True):
        """Apply the patch convolution to a channels-first ``(B, C, H, W)`` input.

        Returns the result channels-last ``(B, H, W, C)`` when ``is_permute``
        is true, otherwise in the convolution's native ``(B, C, H, W)`` layout.
        """
        merged = self.model(x)  # B, C, H, W
        if not is_permute:
            return merged
        return merged.permute(0, 2, 3, 1)  # B, H, W, C


In [None]:
pat = PatchifyMerging(3, 96, 4)

In [None]:
x = torch.randn(size = (1, 3, 224, 224))

In [None]:
pat(x).shape

torch.Size([1, 56, 56, 96])

In [None]:
class CyclicShift(nn.Module):
    """Circularly roll a tensor by the same offset along dimensions 1 and 2.

    Args:
        displacement: number of positions to roll by; elements pushed past one
            edge re-enter at the opposite edge. Negative values roll towards
            index 0.
    """

    def __init__(self, displacement):
        super().__init__()
        self.displacement = displacement

    def forward(self, x):
        """Return ``x`` rolled by ``displacement`` on dims 1 and 2."""
        offset = self.displacement
        return x.roll(shifts=(offset, offset), dims=(1, 2))

In [None]:
torch.arange(100).view(10,10)

tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34, 35, 36, 37, 38, 39],
        [40, 41, 42, 43, 44, 45, 46, 47, 48, 49],
        [50, 51, 52, 53, 54, 55, 56, 57, 58, 59],
        [60, 61, 62, 63, 64, 65, 66, 67, 68, 69],
        [70, 71, 72, 73, 74, 75, 76, 77, 78, 79],
        [80, 81, 82, 83, 84, 85, 86, 87, 88, 89],
        [90, 91, 92, 93, 94, 95, 96, 97, 98, 99]])

In [None]:
torch.roll(torch.arange(100).view(10,10), shifts=(-1, -1), dims=(0,1))

tensor([[11, 12, 13, 14, 15, 16, 17, 18, 19, 10],
        [21, 22, 23, 24, 25, 26, 27, 28, 29, 20],
        [31, 32, 33, 34, 35, 36, 37, 38, 39, 30],
        [41, 42, 43, 44, 45, 46, 47, 48, 49, 40],
        [51, 52, 53, 54, 55, 56, 57, 58, 59, 50],
        [61, 62, 63, 64, 65, 66, 67, 68, 69, 60],
        [71, 72, 73, 74, 75, 76, 77, 78, 79, 70],
        [81, 82, 83, 84, 85, 86, 87, 88, 89, 80],
        [91, 92, 93, 94, 95, 96, 97, 98, 99, 90],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9,  0]])

In [None]:
def create_mask(window_size, displacement, upper_lower, left_right):
    """Build an additive attention mask for one window of ``window_size ** 2`` tokens.

    The result is a ``(window_size ** 2, window_size ** 2)`` tensor holding
    ``0`` where attention is allowed and ``-inf`` where it is blocked.
    Tokens are indexed row-major inside the window.

    Args:
        window_size: side length of the square window.
        displacement: number of trailing rows / columns that form the second
            group of tokens.
        upper_lower: if true, block attention between the last
            ``displacement`` rows of the window and the rows before them.
        left_right: if true, block attention between the last
            ``displacement`` columns of the window and the columns before them.
    """
    n_tokens = window_size ** 2
    neg_inf = float('-inf')
    mask = torch.zeros(n_tokens, n_tokens) # (49, 49) for window_size = 7

    if upper_lower:
        # With row-major tokens, the last `displacement` rows are the last
        # displacement * window_size token indices.
        boundary = -displacement * window_size
        mask[boundary:, :boundary] = neg_inf
        mask[:boundary, boundary:] = neg_inf

    if left_right:
        # Expose the (row, column) structure of query and key tokens so that the
        # column index can be sliced directly, then flatten back.
        mask = rearrange(mask, '(h1 w1) (h2 w2) -> h1 w1 h2 w2', h1=window_size, h2=window_size)
        mask[:, -displacement:, :, :-displacement] = neg_inf
        mask[:, :-displacement, :, -displacement:] = neg_inf
        mask = rearrange(mask, 'h1 w1 h2 w2 -> (h1 w1) (h2 w2)')

    return mask


def get_relative_distances(window_size):
    """Return the pairwise (row, column) offsets between all positions of a window.

    Positions are enumerated row-major. The result has shape
    ``(window_size ** 2, window_size ** 2, 2)`` and entry ``[i, j]`` equals
    ``position[j] - position[i]``.
    """
    grid_positions = []
    for row in range(window_size):
        for col in range(window_size):
            grid_positions.append([row, col])

    coords = torch.tensor(np.array(grid_positions))
    # Broadcasting: (1, N, 2) - (N, 1, 2) -> (N, N, 2)
    pairwise_offsets = coords[None, :, :] - coords[:, None, :]
    return pairwise_offsets

In [None]:
create_mask(3, 3 // 2, True, False)

tensor([[0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0.],
        [-inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0.],
        [-inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0.]])

In [None]:
create_mask(3, 3 // 2, False, True)

tensor([[0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [-inf, -inf, 0., -inf, -inf, 0., -inf, -inf, 0.],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [-inf, -inf, 0., -inf, -inf, 0., -inf, -inf, 0.],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [-inf, -inf, 0., -inf, -inf, 0., -inf, -inf, 0.]])

In [None]:
class WindowAttention(nn.Module):
    """Multi-head self-attention restricted to non-overlapping square windows.

    The input is a channels-last feature map ``(B, H, W, C)`` with ``C == dim``.
    It is split into ``window_size x window_size`` windows and attention is
    computed independently inside every window and head. A learned
    ``(window_size ** 2, window_size ** 2)`` bias (``pos_embedding``) is added
    to the attention logits.

    When ``shifted`` is true the map is cyclically rolled by
    ``-(window_size // 2)`` before partitioning and rolled back afterwards.
    Windows in the last row / last column of the window grid then get an
    additive ``-inf`` mask, so tokens that were wrapped around by the roll do
    not attend to the tokens they were moved next to.

    Args:
        dim: channel dimension of the input and output.
        head_dim: channels per attention head; must divide ``dim``.
        shifted: whether to use the cyclically shifted window layout.
        window_size: side length of each attention window.

    Raises:
        ValueError: if ``dim`` is not a multiple of ``head_dim``.
    """

    def __init__(self, dim, head_dim, shifted, window_size):
        super().__init__()

        if dim % head_dim != 0:
            raise ValueError(f'dim ({dim}) must be divisible by head_dim ({head_dim})')

        self.n_heads = dim // head_dim
        self.shifted = shifted
        self.window_size = window_size
        self.head_dim = head_dim

        if shifted:
            displacement = window_size // 2
            self.cyclic_shift = CyclicShift(-displacement)
            self.back_cyclic_shift = CyclicShift(displacement)

            # The masks are constants, not weights: buffers still follow
            # .to(device) and are saved in the state_dict under the same keys,
            # but they are no longer reported by .parameters() or handed to
            # the optimizer.
            self.register_buffer(
                'upper_lower_mask',
                create_mask(window_size=window_size, displacement=displacement,
                            upper_lower=True, left_right=False),
            )
            self.register_buffer(
                'left_right_mask',
                create_mask(window_size=window_size, displacement=displacement,
                            upper_lower=False, left_right=True),
            )

        self.Q = nn.Linear(dim, dim)
        self.K = nn.Linear(dim, dim)
        self.V = nn.Linear(dim, dim)
        self.Out = nn.Linear(dim, dim)
        self.scale = head_dim ** -0.5

        self.pos_embedding = nn.Parameter(torch.randn(size = (window_size ** 2, window_size ** 2)))

    def _to_windows(self, t, n_windows_H, n_windows_W):
        """Reshape ``(B, H, W, C)`` to ``(B, heads, windows, tokens, head_dim)``."""
        B = t.shape[0]
        ws = self.window_size
        t = t.reshape(B, n_windows_H, ws, n_windows_W, ws, self.n_heads, self.head_dim)
        # -> (B, heads, nw_h, nw_w, w_h, w_w, head_dim); windows and tokens are flattened row-major.
        t = t.permute(0, 5, 1, 3, 2, 4, 6)
        return t.reshape(B, self.n_heads, n_windows_H * n_windows_W, ws * ws, self.head_dim)

    def forward(self, x):
        """Apply window attention to ``x`` of shape ``(B, H, W, C)``.

        Returns a tensor of the same shape.

        Raises:
            ValueError: if ``H`` or ``W`` is not a multiple of ``window_size``.
        """
        B, H, W, C = x.shape
        ws = self.window_size
        if H % ws != 0 or W % ws != 0:
            raise ValueError(
                f'spatial size ({H}, {W}) must be divisible by window_size ({ws})'
            )
        n_windows_H = H // ws
        n_windows_W = W // ws

        if self.shifted:
            x = self.cyclic_shift(x)

        Q = self._to_windows(self.Q(x), n_windows_H, n_windows_W)
        K = self._to_windows(self.K(x), n_windows_H, n_windows_W)
        V = self._to_windows(self.V(x), n_windows_H, n_windows_W)

        # (B, heads, windows, tokens, tokens)
        att = (Q @ K.transpose(-2, -1)) * self.scale

        att += self.pos_embedding

        if self.shifted:
            # Windows are flattened row-major, so the last n_windows_W entries
            # are the bottom row of windows ...
            att[:, :, -n_windows_W:] += self.upper_lower_mask
            # ... and every n_windows_W-th entry starting at n_windows_W - 1 is
            # the right-most column of windows.
            att[:, :, n_windows_W - 1::n_windows_W] += self.left_right_mask

        att = att.softmax(dim = -1)

        att_v = att @ V

        # Inverse of _to_windows: back to (B, H, W, C) with heads merged into C.
        att_v = att_v.reshape(B, self.n_heads, n_windows_H, n_windows_W, ws, ws, self.head_dim)
        att_v = att_v.permute(0, 2, 4, 3, 5, 1, 6).reshape(B, H, W, C)

        out = self.Out(att_v)

        if self.shifted:
            out = self.back_cyclic_shift(out)

        return out

In [None]:
wind_att = WindowAttention(96, 32, True, 7)

In [None]:
out = wind_att(pat(x))

In [None]:
out.shape

torch.Size([1, 56, 56, 96])

In [None]:
class SwinBlock(nn.Module):
    """A regular-window transformer layer followed by a shifted-window one.

    Each of the two layers is pre-norm attention plus a pre-norm MLP
    (``dim -> 4 * dim -> dim`` with GELU), both wrapped in residual
    connections. Input and output share the same shape; the last dimension
    must equal ``dim``.

    Args:
        dim: channel dimension.
        head_dim: channels per attention head.
        window_size: side length of the attention windows.
    """

    def __init__(self, dim, head_dim, window_size):
        super().__init__()
        hidden_dim = dim * 4

        # Norms for the regular-window layer ...
        self.layer_norm_1 = nn.LayerNorm(dim)
        self.layer_norm_2 = nn.LayerNorm(dim)
        # ... and for the shifted-window layer.
        self.layer_norm_1_shift = nn.LayerNorm(dim)
        self.layer_norm_2_shift = nn.LayerNorm(dim)

        self.WMSA = WindowAttention(dim, head_dim, False, window_size)
        self.SWMSA = WindowAttention(dim, head_dim, True, window_size)

        self.MLP = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
        self.MLP_shift = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        """Run the regular-window layer, then the shifted-window layer."""
        # Regular windows.
        attended = self.WMSA(self.layer_norm_1(x))
        x = attended + x
        expanded = self.MLP(self.layer_norm_2(x))
        x = expanded + x

        # Shifted windows.
        attended_shift = self.SWMSA(self.layer_norm_1_shift(x))
        x = attended_shift + x
        expanded_shift = self.MLP_shift(self.layer_norm_2_shift(x))
        x = expanded_shift + x

        return x

In [None]:
swin_block = SwinBlock(96, 32, 7)

In [None]:
out_swin_block = swin_block(pat(x))


In [None]:
out_swin_block.shape

torch.Size([1, 56, 56, 96])

In [None]:
class SwinStage(nn.Module):
    """One resolution stage: patch merging followed by ``n_blocks`` SwinBlocks.

    Args:
        in_dim: channels of the channels-first input.
        dim: channels produced by the patch merging and used by the blocks.
        head_dim: channels per attention head.
        downscale_factor: spatial reduction applied by the patch merging.
        window_size: side length of the attention windows.
        n_blocks: number of SwinBlock modules in the stage.
    """

    def __init__(self, in_dim, dim, head_dim, downscale_factor, window_size, n_blocks):
        super().__init__()

        self.patchify_merging = PatchifyMerging(in_dim, dim, downscale_factor)

        self.blocks = nn.Sequential(*[SwinBlock(dim, head_dim, window_size) for _ in range(n_blocks)])

    def forward(self, x):
        """Map a channels-first input to a channels-first, downscaled output."""
        x = self.patchify_merging(x)  # B, H, W, C
        x = self.blocks(x)

        # Back to channels-first for the next stage's convolution. The axes must
        # go to (B, C, H, W): permute(0, 3, 2, 1) would yield (B, C, W, H) and
        # transpose the feature map spatially after every stage.
        return x.permute(0, 3, 1, 2)

In [None]:
swin_stage = SwinStage(3, 96, 32, 4, 7, 1)

In [None]:
out_stage = swin_stage(x)

In [None]:
out_stage.shape

torch.Size([1, 96, 56, 56])

In [None]:
class SwinTransformer(nn.Module):
    """Stack of SwinStage modules followed by a linear classification head.

    The per-stage settings are given as equally long sequences; stage ``i`` is
    built from ``in_dims[i]``, ``dims[i]``, ``downscale_factors[i]`` and
    ``n_blocks[i]``.

    Args:
        in_dims: input channels of every stage.
        dims: output channels of every stage.
        n_blocks: number of SwinBlock modules in every stage.
        downscale_factors: spatial reduction of every stage.
        window_size: side length of the attention windows (shared by all stages).
        head_dim: channels per attention head (shared by all stages).
        n_classes: size of the output logits vector.

    Raises:
        ValueError: if the per-stage sequences differ in length.
    """

    # Tuples instead of lists: default arguments are shared between calls and
    # should not be mutable.
    def __init__(self, in_dims = (3, 96, 192, 384), dims = (96, 192, 384, 768), n_blocks = (1, 1, 3, 1), downscale_factors = (4, 2, 2, 2), window_size = 7, head_dim = 32, n_classes = 10):
        super().__init__()

        # zip() would silently drop trailing entries of the longer sequences.
        if not (len(in_dims) == len(dims) == len(downscale_factors) == len(n_blocks)):
            raise ValueError(
                'in_dims, dims, downscale_factors and n_blocks must have the same length, got '
                f'{len(in_dims)}, {len(dims)}, {len(downscale_factors)} and {len(n_blocks)}'
            )

        self.stages = nn.Sequential(*[
            SwinStage(in_dim, dim, head_dim, downscale_factor, window_size, n_block)
            for in_dim, dim, downscale_factor, n_block in zip(in_dims, dims, downscale_factors, n_blocks)
        ])

        self.final = nn.Linear(dims[-1], n_classes)

    def forward(self, x):
        """Return class logits ``(B, n_classes)`` for a channels-first image batch.

        The input is moved to the device of the model's first parameter.
        """
        x = x.to(next(self.parameters()).device)
        x = self.stages(x)  # B, C, H', W'
        # Global average pooling over the two spatial dimensions.
        out = self.final(x.mean(dim = [2,3]))
        return out

In [None]:
swin_transformer = SwinTransformer()

In [None]:
out = swin_transformer(x)

In [None]:
out.shape

torch.Size([1, 10])

In [None]:
# Training hyper-parameters shared by the experiments below.
batch_size = 64  # samples per mini-batch for the train and test DataLoaders
learning_rate = 0.0001  # learning rate passed to optim.AdamW
num_epochs = 5  # number of passes over the training set requested from train()

In [None]:
# Pre-processing applied to every MNIST image; the order of the steps matters.
transform = transforms.Compose([
    transforms.ToTensor(),  # image -> float tensor
    transforms.Resize((224, 224)),  # match the 224 x 224 input used with the model above
    transforms.Lambda(lambda x: torch.cat([x, x, x], 0)),  # repeat along dim 0: presumably 1 grayscale channel -> 3 channels
    transforms.Normalize((0.5,), (0.5,))  # (value - 0.5) / 0.5; the single mean/std pair is applied to all channels
])

# The train=True call downloads the raw files into ./data; the test split then reuses
# them, which is why it can pass download=False.
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=False)

# Only the training data is shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:02<00:00, 3803112.66it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 491711.94it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 4489144.94it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 13266384.94it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [None]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook still runs on machines without a GPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
model = SwinTransformer().to(DEVICE)

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

In [None]:
def train(model, train_loader, criterion, optimizer, num_epochs, log_every=5):
    """Train ``model`` for ``num_epochs`` passes over ``train_loader``.

    Batches are moved to the device of the model's first parameter, so the
    function does not depend on a global device variable or on the model's
    ``forward`` moving its own input.

    Args:
        model: module to optimise; switched to training mode.
        train_loader: iterable of ``(data, target)`` batches that supports ``len()``.
        criterion: loss callable invoked as ``criterion(output, target)``.
        optimizer: optimizer holding the model's parameters.
        num_epochs: number of passes over ``train_loader``.
        log_every: print the loss every ``log_every`` steps; ``0`` or ``None``
            disables the printing.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    for epoch in range(num_epochs):
        for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            loss = criterion(output, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if log_every and batch_idx % log_every == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}')


def test(model, test_loader):
    """Evaluate classification accuracy of ``model`` on ``test_loader``.

    The model is switched to evaluation mode and run without gradient
    tracking. Batches are moved to the device of the model's first parameter.

    Args:
        model: module returning per-class scores of shape ``(batch, classes)``.
        test_loader: iterable of ``(data, target)`` batches.

    Returns:
        Accuracy in percent (0-100) as a float; the value is also printed.

    Raises:
        ValueError: if ``test_loader`` yields no samples, since the accuracy
            would be undefined.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in tqdm(test_loader):
            data = data.to(device)
            target = target.to(device)
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    if total == 0:
        raise ValueError('test_loader produced no samples; accuracy is undefined')

    accuracy = 100 * correct / total
    print(f'Test Accuracy of the model on the test images: {accuracy:.2f}%')
    return accuracy

In [None]:
train(model, train_loader, criterion, optimizer, num_epochs)

  0%|          | 0/938 [00:00<?, ?it/s]

Epoch [1/5], Step [0/938], Loss: 2.3717
Epoch [1/5], Step [5/938], Loss: 3.8486
Epoch [1/5], Step [10/938], Loss: 2.9476
Epoch [1/5], Step [15/938], Loss: 2.6948
Epoch [1/5], Step [20/938], Loss: 2.3393
Epoch [1/5], Step [25/938], Loss: 2.0098
Epoch [1/5], Step [30/938], Loss: 1.8856
Epoch [1/5], Step [35/938], Loss: 1.7361
Epoch [1/5], Step [40/938], Loss: 1.8501
Epoch [1/5], Step [45/938], Loss: 1.7887
Epoch [1/5], Step [50/938], Loss: 1.5656
Epoch [1/5], Step [55/938], Loss: 1.6079
Epoch [1/5], Step [60/938], Loss: 1.3303
Epoch [1/5], Step [65/938], Loss: 1.2169
Epoch [1/5], Step [70/938], Loss: 1.3359
Epoch [1/5], Step [75/938], Loss: 1.2404
Epoch [1/5], Step [80/938], Loss: 1.1320
Epoch [1/5], Step [85/938], Loss: 1.0766
Epoch [1/5], Step [90/938], Loss: 0.9578
Epoch [1/5], Step [95/938], Loss: 0.8727
Epoch [1/5], Step [100/938], Loss: 0.8238
Epoch [1/5], Step [105/938], Loss: 0.8548
Epoch [1/5], Step [110/938], Loss: 0.5943
Epoch [1/5], Step [115/938], Loss: 0.8396
Epoch [1/5], S

KeyboardInterrupt: 

In [None]:
test(model, test_loader)

  0%|          | 0/157 [00:00<?, ?it/s]

Test Accuracy of the model on the test images: 92.03%


92.03

In [None]:
imgs , tragets = next_iter(test_loader)

In [None]:
# Inference only: skip autograd bookkeeping instead of detaching afterwards.
with torch.no_grad():
    preds = model(imgs).argmax(1).cpu()

In [None]:
(preds == tragets).float().mean()

tensor(0.9219)

In [None]:
# Total number of scalar elements over all tensors returned by model.parameters().
num = sum(param.numel() for param in model.parameters())

In [None]:
num

27555586

In [None]:
!pip install datasets -q

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/474.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m471.0/474.3 kB[0m [31m20.5 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.3/474.3 kB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m39.9/39.9 MB[0m [31m15.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the fo

In [None]:
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import CLIPVisionModel, AutoImageProcessor

In [None]:
class CLIPVitClassifier(nn.Module):
    """Trainable linear classifier on top of a frozen CLIP vision model.

    Args:
        pth: model identifier or path passed to ``from_pretrained`` for both
            the vision model and its image processor.
        num_classes: number of output logits.
    """

    def __init__(self, pth = 'openai/clip-vit-large-patch14-336', num_classes = 10):
        super().__init__()
        backbone = CLIPVisionModel.from_pretrained(pth)
        self.model = backbone
        # Kept on the module so that datasets can reuse the matching pre-processing.
        self.image_proc = AutoImageProcessor.from_pretrained(pth)
        # Freeze the backbone; only final_layer keeps requires_grad=True.
        backbone.requires_grad_(False)
        self.final_layer = nn.Linear(backbone.config.hidden_size, num_classes)

    def forward(self, x):
        """Return logits for ``x``, which is first moved to the model's device."""
        target_device = next(iter(self.parameters())).device
        pooled = self.model(x.to(target_device)).pooler_output
        return self.final_layer(pooled)

In [None]:
class Mnist(Dataset):
    """MNIST split loaded with ``load_dataset`` and pre-processed by ``image_proc``.

    Args:
        image_proc: callable invoked as ``image_proc(img, return_tensors='pt')``
            that returns a mapping with a batched ``'pixel_values'`` tensor.
        split: dataset split name passed to ``load_dataset``.
    """

    def __init__(self, image_proc, split = 'train'):
        self.data = load_dataset("mnist", split = split)
        self.image_proc = image_proc

    def __len__(self):
        """Number of samples in the split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(pixel_values, label)`` for sample ``idx``.

        ``pixel_values`` is the processor output with its batch dimension
        removed; ``label`` is a scalar ``torch.long`` tensor.
        """
        # Fetch the row once: indexing the dataset separately for 'image' and
        # 'label' would load the whole sample twice.
        sample = self.data[idx]

        img = self.image_proc(sample['image'], return_tensors = 'pt')['pixel_values'][0]
        label = torch.tensor(sample['label'], dtype=torch.long)

        return img, label

In [None]:
model = CLIPVitClassifier(pth = 'openai/clip-vit-base-patch32').to(DEVICE)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/4.19k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

In [None]:
mnist_train = Mnist(model.image_proc)
mnist_test = Mnist(model.image_proc, split='test')
train_loader = DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=False)

README.md:   0%|          | 0.00/6.97k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/15.6M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/2.60M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/60000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [None]:
train(model, train_loader, criterion, optimizer, num_epochs)

  0%|          | 0/938 [00:00<?, ?it/s]

Epoch [1/5], Step [0/938], Loss: 2.4545
Epoch [1/5], Step [5/938], Loss: 2.4268
Epoch [1/5], Step [10/938], Loss: 2.3622
Epoch [1/5], Step [15/938], Loss: 2.2909
Epoch [1/5], Step [20/938], Loss: 2.2282
Epoch [1/5], Step [25/938], Loss: 2.2043
Epoch [1/5], Step [30/938], Loss: 2.1940
Epoch [1/5], Step [35/938], Loss: 2.1635
Epoch [1/5], Step [40/938], Loss: 2.1018
Epoch [1/5], Step [45/938], Loss: 2.1207
Epoch [1/5], Step [50/938], Loss: 2.0739
Epoch [1/5], Step [55/938], Loss: 2.1175
Epoch [1/5], Step [60/938], Loss: 2.1484
Epoch [1/5], Step [65/938], Loss: 2.0346
Epoch [1/5], Step [70/938], Loss: 2.0596
Epoch [1/5], Step [75/938], Loss: 2.0458
Epoch [1/5], Step [80/938], Loss: 1.9615
Epoch [1/5], Step [85/938], Loss: 1.9104
Epoch [1/5], Step [90/938], Loss: 1.8799
Epoch [1/5], Step [95/938], Loss: 1.9032
Epoch [1/5], Step [100/938], Loss: 1.9588
Epoch [1/5], Step [105/938], Loss: 1.9078
Epoch [1/5], Step [110/938], Loss: 1.8595
Epoch [1/5], Step [115/938], Loss: 1.8135
Epoch [1/5], S

KeyboardInterrupt: 

In [None]:
test(model, test_loader)

  0%|          | 0/157 [00:00<?, ?it/s]

Test Accuracy of the model on the test images: 88.86%


88.86

https://huggingface.co/microsoft/swin-base-patch4-window7-224-in22k