In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.autograd import grad
from torch import nn, optim

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torch.optim.lr_scheduler import MultiStepLR
from torchvision.utils import make_grid
from torchvision import transforms as torch_transforms

import sys
sys.path.append("./../../")

from modules.dvae.model import DVAE
from modules.common_utils import latent_to_img
from datasets.mnist_loader import MNISTData
from notebooks.utils import show

In [2]:
def compute_accuracy(clf, testloader, device):
    """Return the fraction of samples in ``testloader`` that ``clf`` labels correctly.

    Args:
        clf: Model invoked as ``clf(img)``. The predicted class is the index
            of the maximum along dimension 1 of its output.
        testloader: Iterable yielding ``(img, label)`` tensor pairs.
        device: Device every batch is moved to before the forward pass.

    Returns:
        ``correct / total`` as a float, or ``0.0`` when ``testloader`` yields
        no samples (previously this raised ``ZeroDivisionError``).

    Note:
        ``clf`` is switched to evaluation mode and is left in that mode.
    """
    clf.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for img, label in testloader:
            img, label = img.to(device), label.to(device)
            outputs = clf(img)
            # No `.data` needed: autograd is already disabled by no_grad().
            _, predicted = torch.max(outputs, 1)
            total += label.size(0)
            correct += (predicted == label).sum().item()
    if total == 0:
        # Nothing was evaluated; report zero accuracy instead of dividing by zero.
        return 0.0
    return correct / total

In [3]:
from einops.layers.torch import Rearrange
from modules.common_blocks import TrEncoderBlock


class ViT(nn.Module):
    """Patch-based transformer classifier.

    The input image is cut into non-overlapping patches, each patch is
    linearly projected to ``embed_dim``, learned row and column position
    embeddings are added, and a learned class token is prepended. The token
    sequence then passes through ``num_blocks`` ``TrEncoderBlock`` layers and
    a LayerNorm + Linear head produces ``out_dim`` values per image.

    Args:
        img_height: Image height in pixels; must be a multiple of ``patch_height``.
        img_width: Image width in pixels; must be a multiple of ``patch_width``.
        img_channels: Number of image channels.
        patch_height: Patch height in pixels.
        patch_width: Patch width in pixels.
        embed_dim: Size of every token embedding.
        num_blocks: Number of ``TrEncoderBlock`` layers.
        hidden_dim: Passed to each block as ``n_hidden``.
        n_attn_heads: Passed to each block as ``n_attn_heads``.
        dropout_prob: Passed to each block as ``dropout_prob``.
        out_dim: Number of values produced by the head.
        sigmoid_output: If True, ``forward`` applies ``torch.sigmoid`` to the
            head output.
        device: Device the whole module is moved to at the end of construction.

    Raises:
        ValueError: If the image size is not a whole number of patches.
    """

    def __init__(self,
                 img_height,
                 img_width,
                 img_channels,
                 patch_height,
                 patch_width,
                 embed_dim,
                 num_blocks,
                 hidden_dim,
                 n_attn_heads,
                 dropout_prob,
                 out_dim,
                 sigmoid_output=False,
                 device=torch.device('cpu')):
        super(ViT, self).__init__()

        # Fail early with a clear message instead of silently flooring the
        # patch counts and failing later inside the patch rearrangement.
        if img_height % patch_height != 0 or img_width % patch_width != 0:
            raise ValueError(
                "Image size ({}x{}) must be divisible by patch size ({}x{})".format(
                    img_height, img_width, patch_height, patch_width))

        self.device = device
        self.sigmoid_output = sigmoid_output

        self.n_h_patch = img_height // patch_height
        self.n_w_patch = img_width // patch_width
        patch_dim = img_channels * patch_height * patch_width

        # One learned embedding per patch column (n_w_patch of them) and one
        # per patch row (n_h_patch of them). For square patch grids these
        # shapes are the same as before, so existing checkpoints still load.
        self.img_pe_col = nn.Parameter(torch.randn(self.n_w_patch, 1, embed_dim))
        self.img_pe_row = nn.Parameter(torch.randn(self.n_h_patch, 1, embed_dim))

        self.to_patch_embedding = nn.Sequential(
            # Tokens are emitted row by row: token i = (row i // n_w, col i % n_w).
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.Linear(patch_dim, embed_dim),
        )

        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        self.tr_encoder_blocks = nn.ModuleList([
            TrEncoderBlock(n_features=embed_dim,
                           n_attn_heads=n_attn_heads,
                           n_hidden=hidden_dim,
                           dropout_prob=dropout_prob,
                           norm_first=True)
            for _ in range(num_blocks)
        ])

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, out_dim),
        )

        self.to(self.device)

    def forward(self, x, average_cls_token=False):
        """Run the classifier on a batch of images.

        Args:
            x: Image batch laid out as (batch, channels, height, width).
            average_cls_token: If True, the head receives the mean over all
                tokens (class token included) instead of the class token only.

        Returns:
            Tensor of shape (batch, out_dim), or (batch,) when ``out_dim`` is 1;
            passed through a sigmoid when ``sigmoid_output`` is set.
        """
        batch = x.size(0)

        x = self.to_patch_embedding(x)
        # Sequence-first layout (tokens, batch, embed) for the encoder blocks.
        x = x.permute(1, 0, 2)

        # Column embedding is tiled once per row; row embedding is held for a
        # whole row of n_w_patch tokens. Both broadcast over the batch axis.
        pe_column = self.img_pe_col.repeat(self.n_h_patch, 1, 1)
        pe_row = self.img_pe_row.repeat_interleave(self.n_w_patch, dim=0)
        x = x + pe_column + pe_row

        cls_tokens = self.cls_token.expand(-1, batch, -1)

        full_x = torch.cat([cls_tokens, x], dim=0)
        for block in self.tr_encoder_blocks:
            full_x = block(full_x)

        if average_cls_token:
            cls_input = full_x.mean(dim=0)
        else:
            cls_input = full_x[0, :, :]

        # Only drop the trailing axis (the out_dim == 1 case). A bare squeeze()
        # also removed the batch axis for single-image batches.
        cls = self.mlp_head(cls_input).squeeze(-1)

        if self.sigmoid_output:
            return torch.sigmoid(cls)
        return cls

In [4]:
class Config:
    """Settings for the ViT classification run in this notebook."""

    # Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
    DEVICE                      = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model settings; each one is forwarded to the ViT constructor argument
    # of the same name.
    img_height                  = 28
    img_width                   = 28
    img_channels                = 1
    patch_height                = 4
    patch_width                 = 4
    embed_dim                   = 128
    num_blocks                  = 8
    hidden_dim                  = 256
    n_attn_heads                = 8
    dropout_prob                = 0.1
    out_dim                     = 10
    sigmoid_output              = False

    # Dataset settings handed to MNISTData. The data directory can be
    # overridden with the MNIST_ROOT_PATH environment variable so the notebook
    # is not tied to one machine; the previous hard-coded location stays the
    # default.
    dataset_type                = "classic"
    root_img_path               = os.environ.get(
        "MNIST_ROOT_PATH",
        "/m/home/home8/82/sukhoba1/data/Desktop/TA-VQVAE/data/MNIST/")

    # Optimisation settings.
    NUM_EPOCHS                  = 20
    BATCH_SIZE                  = 512
    LR                          = 0.001


CONFIG = Config()

In [5]:
# Build the dataset wrapper from the CONFIG values, then take its train and
# test loaders (train loader is requested first, as before).
data_source = MNISTData(img_type=CONFIG.dataset_type,
                        root_path=CONFIG.root_img_path,
                        batch_size=CONFIG.BATCH_SIZE)

train_loader, test_loader = (data_source.get_train_loader(),
                             data_source.get_test_loader())

In [6]:
# Instantiate the classifier: every model setting is read from the CONFIG
# attribute that shares the constructor argument's name. nn.Module.train()
# returns the module itself, so chaining it leaves `clf` in training mode and,
# because the cell ends in an assignment, nothing is echoed as cell output.
clf = ViT(
    device=CONFIG.DEVICE,
    **{
        field: getattr(CONFIG, field)
        for field in (
            "img_height",
            "img_width",
            "img_channels",
            "patch_height",
            "patch_width",
            "embed_dim",
            "num_blocks",
            "hidden_dim",
            "n_attn_heads",
            "dropout_prob",
            "out_dim",
            "sigmoid_output",
        )
    }
).train()

In [7]:
# Train the classifier with Adam and cross-entropy, logging the loss every
# 50 optimisation steps.
print("Device in use: {}".format(CONFIG.DEVICE))

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(clf.parameters(), lr=CONFIG.LR)

iteration = 0
for epoch in range(CONFIG.NUM_EPOCHS):
    # Do not rely on an earlier cell for the mode: compute_accuracy() leaves
    # the model in eval mode, so re-running this cell afterwards would
    # otherwise train in eval mode.
    clf.train()

    for img, label in train_loader:

        label = label.to(CONFIG.DEVICE)
        img = img.to(CONFIG.DEVICE)

        # Clear gradients before the forward/backward pass so that gradients
        # left on the parameters by an interrupted previous run cannot leak
        # into the first update.
        optimizer.zero_grad()
        pred = clf(img)
        loss = criterion(pred, label)
        loss.backward()
        optimizer.step()

        if iteration % 50 == 0:
            print("Epoch: {} Iteration: {} Loss: {}".format(epoch, iteration, loss.item()))

        iteration += 1

Device in use: cuda
Epoch: 0 Iteration: 0 Loss: 2.488192319869995
Epoch: 0 Iteration: 50 Loss: 1.9112589359283447
Epoch: 0 Iteration: 100 Loss: 0.4603889584541321
Epoch: 1 Iteration: 150 Loss: 0.3399171233177185
Epoch: 1 Iteration: 200 Loss: 0.2801797688007355
Epoch: 2 Iteration: 250 Loss: 0.13375626504421234
Epoch: 2 Iteration: 300 Loss: 0.21883732080459595
Epoch: 2 Iteration: 350 Loss: 0.1882646232843399
Epoch: 3 Iteration: 400 Loss: 0.099082350730896
Epoch: 3 Iteration: 450 Loss: 0.1347513645887375
Epoch: 4 Iteration: 500 Loss: 0.12107304483652115
Epoch: 4 Iteration: 550 Loss: 0.07986947894096375
Epoch: 5 Iteration: 600 Loss: 0.10484745353460312
Epoch: 5 Iteration: 650 Loss: 0.07460377365350723
Epoch: 5 Iteration: 700 Loss: 0.0954517349600792
Epoch: 6 Iteration: 750 Loss: 0.10170673578977585
Epoch: 6 Iteration: 800 Loss: 0.09940138459205627
Epoch: 7 Iteration: 850 Loss: 0.08316060155630112
Epoch: 7 Iteration: 900 Loss: 0.06859330832958221
Epoch: 8 Iteration: 950 Loss: 0.113676026463

In [8]:
# Accuracy of the trained classifier on the test split (shown as cell output).
compute_accuracy(clf=clf, testloader=test_loader, device=CONFIG.DEVICE)

0.9839