In [1]:
import numpy as np
from tqdm import tqdm, trange
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets.mnist import MNIST
np.random.seed(0)
torch.manual_seed(0)


<torch._C.Generator at 0x22282964c70>

In [2]:
import time 
def patchify_tommy(images, n_patches):
    n_batch, c, h, w = images.shape
    patch_size_y, patch_size_x = h//n_patches, w//n_patches
    patches = torch.zeros((n_batch, n_patches ** 2,  h * w * c // n_patches ** 2))
    for idx, image in enumerate(images):
        c_patch=0
        for i in range(n_patches):
            for j in range(n_patches):
                pos_h_st = int(i * patch_size_y)
                pos_h_end = int((i + 1) * patch_size_y)
                pos_w_st = int(j * patch_size_x)
                pos_w_end = int((j + 1) * patch_size_x)
                patches[idx, c_patch , :] = image[:, pos_h_st: pos_h_end, pos_w_st: pos_w_end] .flatten()  
                c_patch+=1
    return patches

transform = ToTensor()
test_set = MNIST(root='./../datasets', train=False, download=True, transform=transform)
test_loader = DataLoader(test_set, shuffle=False, batch_size=256)
st_count=0
for batch in test_loader:
    x, y = batch
    st = time.time()
    patches = patchify_tommy(x, n_patches=7)
    st_count += time.time()-st
    break
print(st_count)


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


0.27900075912475586


In [3]:
class MultiHeadAttentation(nn.Module):
    def __init__(self, d, n_heads=2):
        super(MultiHeadAttentation, self).__init__()
        self.d = d
        self.n_heads = n_heads
        self.q_mappings = nn.ModuleList([nn.Linear(d, d) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d, d) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d, d) for _ in range(self.n_heads)])
        self.o_mappings = nn.Linear(d*n_heads, d)
        self.softmax = nn.Softmax(dim=-1)
    def forward(self, sequences):
        # Sequences has shape (N, seq_length, token_dim)
        # We go into shape    (N, seq_length, n_heads, token_dim / n_heads)
        # And come back to    (N, seq_length, item_dim)  (through concatenation)
        result_head = []
        # multi-head
        for head in range(self.n_heads):
            q_mapping = self.q_mappings[head]
            k_mapping = self.k_mappings[head]
            v_mapping = self.v_mappings[head]
    
            q, k, v = q_mapping(sequences), k_mapping(sequences), v_mapping(sequences)
            k = torch.transpose(k, 1, 2)
            a=torch.matmul(q,k)
            attention = self.softmax(a / (self.d ** 0.5))
            o=torch.matmul(attention,v)
            result_head.append(o)
            
        o = torch.cat(result_head,axis=-1)
        o = self.o_mappings(o)
        
        return o
        

class ViTBlock(nn.Module):
    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super(ViTBlock, self).__init__()
        self.norm1 = nn.LayerNorm(hidden_d)
        self.mha = MultiHeadAttentation(hidden_d, n_heads)
        self.norm2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, mlp_ratio * hidden_d),
            nn.GELU(),
            nn.Linear(mlp_ratio * hidden_d, hidden_d)
        )

    def forward(self, x):
        out = x + self.mha(self.norm1(x))
        out = out + self.mlp(self.norm2(out))
        return out


class ViT(nn.Module):
    def __init__(self, input_dim, n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10, device='cpu'):
        # Super constructor
        super(ViT, self).__init__()
        
        # 1) Linear mapper
        self.linear_mapper = nn.Linear(input_dim, hidden_d)
        
        # 2) Learnable classification token
        self.class_token = nn.Parameter(torch.rand(1, hidden_d))
        
        # 3) Positional embedding
        self.positional_embeddings= self.get_positional_embeddings(n_patches ** 2 + 1, hidden_d).to(device)
        
        # 4) Transformer encoder blocks
        self.blocks = nn.ModuleList([ViTBlock(hidden_d, n_heads) for _ in range(n_blocks)])
        
        # 5) Classification MLPk
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, out_d),
            nn.Softmax(dim=-1)
        )

    def forward(self, images_patch):
        # patch:torch.Size([n_batch, 49, 16]) MNIST
        n = images_patch.shape[0]

        # Running linear layer tokenization
        # Map the vector corresponding to each patch to the hidden size dimension
        tokens = self.linear_mapper(images_patch)
        # tokens1:torch.Size([128, 49, 8])

        
        # Adding classification token to the tokens
        tokens = torch.cat((self.class_token.expand(n, 1, -1), tokens), dim=1)
        # tokens2:torch.Size([128, 50, 8])
        
        # Adding positional embedding
        out = tokens + self.positional_embeddings.repeat(n, 1, 1)
        # Adding positional embedding:torch.Size([128, 50, 8])

        # Transformer Blocks
        for block in self.blocks:
            out = block(out)
            
        # Getting the classification token only
        out = out[:, 0]
        
        return self.mlp(out) # Map to output dimension, output category distribution

    def get_positional_embeddings(self, sequence_length, d):
        result = torch.ones(sequence_length, d)
        for i in range(sequence_length):
            for j in range(d):
                result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
        return result

    
images = torch.randn(5, 1, 28,28)
n_patches=7
n, c, h, w = images.shape
print('images:{}'.format(images.shape))
patches = patchify_tommy(images, n_patches)
print('patch:{}'.format(patches.shape))
model = ViT(patches.shape[-1], n_patches=n_patches, n_blocks=2, hidden_d=8, n_heads=5, out_d=10)
   
# dummy_input = torch.randn(100, 50, 8)
# print('123')
# out=model(images)
# torch.onnx.export(model, images, "ViT.onnx",opset_version =11) 
    

images:torch.Size([5, 1, 28, 28])
patch:torch.Size([5, 49, 16])


In [5]:
N_EPOCHS = 50
LR = 0.005

# ViT parameter
n_patches=7
n_blocks=5
hidden_d=8
n_heads=2
out_class=10
# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")


# data loader 
transform = ToTensor()
train_set = MNIST(root='./../datasets', train=True, download=True, transform=transform)
test_set = MNIST(root='./../datasets', train=False, download=True, transform=transform)
train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
test_loader = DataLoader(test_set, shuffle=False, batch_size=128)


# Defining model and training options
images = torch.randn(1, 1, 28,28)
patches = patchify_tommy(images, n_patches)
model = ViT(patches.shape[-1], n_patches=n_patches, n_blocks=n_blocks, hidden_d=hidden_d, n_heads=n_heads, out_d=out_class, device=device).to(device)


# Training loop
optimizer = Adam(model.parameters(), lr=LR)
criterion = CrossEntropyLoss()
for epoch in trange(N_EPOCHS, desc="Training"):
    train_loss = 0.0
    for batch in train_loader:
        x, y = batch
        x = patchify_tommy(x, n_patches)

        x, y = x.to(device), y.to(device)

        y_hat = model(x)
        loss = criterion(y_hat, y)

        train_loss += loss.detach().cpu().item() / len(train_loader)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")

# Test loop
with torch.no_grad():
    correct, total = 0, 0
    test_loss = 0.0
    for batch in tqdm(test_loader, desc="Testing"):
        x, y = batch
        x = patchify_tommy(x, n_patches)

        x, y = x.to(device), y.to(device)
        y_hat = model(x)
        loss = criterion(y_hat, y)
        test_loss += loss.detach().cpu().item() / len(test_loader)

        correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
        total += len(x)
    print(f"Test loss: {test_loss:.2f}")
    print(f"Test accuracy: {correct / total * 100:.2f}%")


checkpoint={}
checkpoint['model_dict'] = model.state_dict()
torch.save(checkpoint, 'Model_ViT_MNIST.pth')
print('Model saved.')

Using device:  cuda (NVIDIA GeForce RTX 2060)


Training:   2%|█▍                                                                    | 1/50 [01:41<1:22:42, 101.28s/it]

Epoch 1/50 loss: 2.10


Training:   4%|██▊                                                                   | 2/50 [03:22<1:21:04, 101.35s/it]

Epoch 2/50 loss: 1.81


Training:   6%|████▏                                                                 | 3/50 [05:04<1:19:25, 101.39s/it]

Epoch 3/50 loss: 1.74


Training:   8%|█████▌                                                                | 4/50 [06:45<1:17:35, 101.21s/it]

Epoch 4/50 loss: 1.70


Training:  10%|███████                                                               | 5/50 [08:27<1:16:10, 101.56s/it]

Epoch 5/50 loss: 1.69


Training:  12%|████████▍                                                             | 6/50 [10:09<1:14:32, 101.66s/it]

Epoch 6/50 loss: 1.68


Training:  14%|█████████▊                                                            | 7/50 [11:50<1:12:51, 101.65s/it]

Epoch 7/50 loss: 1.67


Training:  16%|███████████▏                                                          | 8/50 [13:32<1:11:15, 101.80s/it]

Epoch 8/50 loss: 1.66


Training:  18%|████████████▌                                                         | 9/50 [15:15<1:09:44, 102.07s/it]

Epoch 9/50 loss: 1.63


Training:  20%|█████████████▊                                                       | 10/50 [16:57<1:08:07, 102.18s/it]

Epoch 10/50 loss: 1.59


Training:  22%|███████████████▏                                                     | 11/50 [18:41<1:06:36, 102.48s/it]

Epoch 11/50 loss: 1.58


Training:  24%|████████████████▌                                                    | 12/50 [20:24<1:05:00, 102.66s/it]

Epoch 12/50 loss: 1.58


Training:  26%|█████████████████▉                                                   | 13/50 [22:06<1:03:16, 102.60s/it]

Epoch 13/50 loss: 1.58


Training:  28%|███████████████████▎                                                 | 14/50 [23:47<1:01:16, 102.13s/it]

Epoch 14/50 loss: 1.57


Training:  30%|█████████████████████▎                                                 | 15/50 [25:29<59:30, 102.03s/it]

Epoch 15/50 loss: 1.56


Training:  32%|██████████████████████▋                                                | 16/50 [27:10<57:39, 101.74s/it]

Epoch 16/50 loss: 1.56


Training:  34%|████████████████████████▏                                              | 17/50 [28:52<55:57, 101.74s/it]

Epoch 17/50 loss: 1.56


Training:  36%|█████████████████████████▌                                             | 18/50 [30:34<54:16, 101.75s/it]

Epoch 18/50 loss: 1.55


Training:  38%|██████████████████████████▉                                            | 19/50 [32:14<52:24, 101.45s/it]

Epoch 19/50 loss: 1.56


Training:  40%|████████████████████████████▍                                          | 20/50 [33:55<50:35, 101.18s/it]

Epoch 20/50 loss: 1.55


Training:  42%|█████████████████████████████▊                                         | 21/50 [35:35<48:41, 100.74s/it]

Epoch 21/50 loss: 1.55


Training:  44%|███████████████████████████████▏                                       | 22/50 [37:14<46:52, 100.44s/it]

Epoch 22/50 loss: 1.55


Training:  46%|████████████████████████████████▋                                      | 23/50 [38:54<45:04, 100.18s/it]

Epoch 23/50 loss: 1.55


Training:  48%|██████████████████████████████████                                     | 24/50 [40:35<43:34, 100.55s/it]

Epoch 24/50 loss: 1.55


Training:  50%|███████████████████████████████████▌                                   | 25/50 [42:17<42:01, 100.84s/it]

Epoch 25/50 loss: 1.54


Training:  52%|████████████████████████████████████▉                                  | 26/50 [43:57<40:18, 100.75s/it]

Epoch 26/50 loss: 1.54


Training:  54%|██████████████████████████████████████▎                                | 27/50 [45:41<39:00, 101.75s/it]

Epoch 27/50 loss: 1.55


Training:  56%|███████████████████████████████████████▊                               | 28/50 [47:27<37:46, 103.03s/it]

Epoch 28/50 loss: 1.54


Training:  58%|█████████████████████████████████████████▏                             | 29/50 [49:10<35:59, 102.86s/it]

Epoch 29/50 loss: 1.54


Training:  60%|██████████████████████████████████████████▌                            | 30/50 [50:48<33:46, 101.32s/it]

Epoch 30/50 loss: 1.55


Training:  62%|████████████████████████████████████████████                           | 31/50 [52:25<31:41, 100.10s/it]

Epoch 31/50 loss: 1.54


Training:  64%|██████████████████████████████████████████████                          | 32/50 [54:03<29:50, 99.50s/it]

Epoch 32/50 loss: 1.55


Training:  66%|███████████████████████████████████████████████▌                        | 33/50 [55:41<28:03, 99.01s/it]

Epoch 33/50 loss: 1.54


Training:  68%|████████████████████████████████████████████████▉                       | 34/50 [57:18<26:16, 98.56s/it]

Epoch 34/50 loss: 1.54


Training:  70%|██████████████████████████████████████████████████▍                     | 35/50 [58:56<24:34, 98.32s/it]

Epoch 35/50 loss: 1.54


Training:  72%|██████████████████████████████████████████████████▍                   | 36/50 [1:00:34<22:53, 98.09s/it]

Epoch 36/50 loss: 1.54


Training:  74%|███████████████████████████████████████████████████▊                  | 37/50 [1:02:10<21:10, 97.70s/it]

Epoch 37/50 loss: 1.54


Training:  76%|█████████████████████████████████████████████████████▏                | 38/50 [1:03:48<19:31, 97.62s/it]

Epoch 38/50 loss: 1.54


Training:  78%|██████████████████████████████████████████████████████▌               | 39/50 [1:05:26<17:54, 97.66s/it]

Epoch 39/50 loss: 1.54


Training:  80%|████████████████████████████████████████████████████████              | 40/50 [1:07:03<16:17, 97.73s/it]

Epoch 40/50 loss: 1.54


Training:  82%|█████████████████████████████████████████████████████████▍            | 41/50 [1:08:40<14:37, 97.49s/it]

Epoch 41/50 loss: 1.54


Training:  84%|██████████████████████████████████████████████████████████▊           | 42/50 [1:10:18<12:59, 97.45s/it]

Epoch 42/50 loss: 1.55


Training:  86%|████████████████████████████████████████████████████████████▏         | 43/50 [1:11:56<11:22, 97.53s/it]

Epoch 43/50 loss: 1.55


Training:  88%|█████████████████████████████████████████████████████████████▌        | 44/50 [1:13:33<09:44, 97.40s/it]

Epoch 44/50 loss: 1.55


Training:  90%|███████████████████████████████████████████████████████████████       | 45/50 [1:15:10<08:06, 97.31s/it]

Epoch 45/50 loss: 1.54


Training:  92%|████████████████████████████████████████████████████████████████▍     | 46/50 [1:16:48<06:29, 97.47s/it]

Epoch 46/50 loss: 1.54


Training:  94%|█████████████████████████████████████████████████████████████████▊    | 47/50 [1:18:25<04:52, 97.39s/it]

Epoch 47/50 loss: 1.54


Training:  96%|███████████████████████████████████████████████████████████████████▏  | 48/50 [1:20:02<03:14, 97.27s/it]

Epoch 48/50 loss: 1.54


Training:  98%|████████████████████████████████████████████████████████████████████▌ | 49/50 [1:21:40<01:37, 97.53s/it]

Epoch 49/50 loss: 1.54


Training: 100%|██████████████████████████████████████████████████████████████████████| 50/50 [1:23:18<00:00, 99.96s/it]


Epoch 50/50 loss: 1.54


Testing: 100%|█████████████████████████████████████████████████████████████████████████| 79/79 [00:13<00:00,  6.06it/s]

Test loss: 1.52
Test accuracy: 93.63%
Model saved.





In [15]:
dummy_input = torch.randn(1, 49,16).to(device)
model = model.to(device)
torch.onnx.export(model, dummy_input, "ViT_MNIST.onnx")  