In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import os
from PIL import Image
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
#from linformer import Linformer

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
%pip install einops
%pip install wandb

In [None]:
%cd drive/MyDrive/Colab Notebooks/ViT

In [None]:
from model.TNT import TNT
from dataloader import set_transforms
from utils.utils import seed_everything

seed_everything(42)

In [None]:
import torch.backends.cudnn as cudnn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
use_cuda = torch.cuda.is_available()
cudnn.benchmark = True
print('Use CUDA:', use_cuda)

In [None]:
config={
    "epochs": 20, 
    "batch_size": 64,
    "lr" : 3e-5,
    "gamma" : 0.7,
    "image_size" : 32,
    "patch_size" : 16,
    "num_classes" : 10,
    "dim" : 128,
    "depth" : 12,
    "heads" : 8,
    "mlp_dim" : 1024,
    "channels" : 3,
    "dropout" : 0.1,
    "emb_dropout": 0.
    }

# Training settings
import wandb
wandb.init(config=config, project="Cifar10_ViT", entity='NoguNogu',name='TNT')

In [None]:
trainloader, testloader = set_transforms(
    config['image_size'],
    config['batch_size'],
    name='Cifar10')._set_transforms()

In [13]:
model = TNT(
    image_size = config['image_size'],
    patch_dim = config['dim'],
    pixel_dim = 24,
    patch_size = config['patch_size'],
    pixel_size = 4,
    depth= config['depth'],
    num_classes= config['num_classes'],
    heads= config['heads'],
    dim_head = 64,
    ff_dropout = config['dropout'],
    attn_dropout = config['dropout'],
    ).to(device)

In [19]:
import torchsummary
torchsummary.summary(model, (3, 32, 32))

  pixels += rearrange(self.pixel_pos_emb, 'n d -> () n d')


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Rearrange-1            [-1, 3, 16, 16]               0
            Unfold-2               [-1, 48, 16]               0
         Rearrange-3               [-1, 16, 48]               0
            Linear-4               [-1, 16, 24]           1,176
         LayerNorm-5               [-1, 16, 24]              48
            Linear-6             [-1, 16, 1536]          36,864
            Linear-7               [-1, 16, 24]          12,312
           Dropout-8               [-1, 16, 24]               0
         Attention-9               [-1, 16, 24]               0
          PreNorm-10               [-1, 16, 24]               0
        LayerNorm-11               [-1, 16, 24]              48
           Linear-12               [-1, 16, 96]           2,400
             GELU-13               [-1, 16, 96]               0
          Dropout-14               [-1,





```
# tnt = TNT(
    image_size = 256,       # size of image
    patch_dim = 512,        # dimension of patch token
    pixel_dim = 24,         # dimension of pixel token
    patch_size = 16,        # patch size
    pixel_size = 4,         # pixel size
    depth = 6,              # depth
    num_classes = 1000,     # output number of classes
    attn_dropout = 0.1,     # attention dropout
    ff_dropout = 0.1        # feedforward dropout
```






In [15]:
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=config['lr'])
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=config['gamma'])

In [16]:
# 3. Log gradients and model parameters
wandb.watch(model, log_freq=100)
for epoch in range(config['epochs']):
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(trainloader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(trainloader)
        epoch_loss += loss / len(trainloader)

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in testloader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(testloader)
            epoch_val_loss += val_loss / len(testloader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

    # 4. Log metrics to visualize performance
    wandb.log({
        "Epoch": epoch+1,
        "loss": epoch_loss,
        "acc": epoch_accuracy,
        "val_loss" : epoch_val_loss,
        "val_acc": epoch_val_accuracy
        })

HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))

  pixels += rearrange(self.pixel_pos_emb, 'n d -> () n d')



Epoch : 1 - loss : 2.1120 - acc: 0.2064 - val_loss : 2.0243 - val_acc: 0.2517



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 2 - loss : 2.0034 - acc: 0.2552 - val_loss : 1.9661 - val_acc: 0.2743



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 3 - loss : 1.9377 - acc: 0.2856 - val_loss : 1.9042 - val_acc: 0.3069



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 4 - loss : 1.8824 - acc: 0.3097 - val_loss : 1.8421 - val_acc: 0.3258



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 5 - loss : 1.8308 - acc: 0.3303 - val_loss : 1.8050 - val_acc: 0.3424



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 6 - loss : 1.7998 - acc: 0.3415 - val_loss : 1.7690 - val_acc: 0.3504



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 7 - loss : 1.7624 - acc: 0.3582 - val_loss : 1.7350 - val_acc: 0.3708



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 8 - loss : 1.7440 - acc: 0.3638 - val_loss : 1.7239 - val_acc: 0.3742



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 9 - loss : 1.7205 - acc: 0.3748 - val_loss : 1.7014 - val_acc: 0.3785



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 10 - loss : 1.7040 - acc: 0.3833 - val_loss : 1.7251 - val_acc: 0.3754



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 11 - loss : 1.6894 - acc: 0.3878 - val_loss : 1.6731 - val_acc: 0.3946



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 12 - loss : 1.6764 - acc: 0.3890 - val_loss : 1.6636 - val_acc: 0.3963



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 13 - loss : 1.6598 - acc: 0.3983 - val_loss : 1.6411 - val_acc: 0.4114



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 14 - loss : 1.6511 - acc: 0.4036 - val_loss : 1.6283 - val_acc: 0.4122



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 15 - loss : 1.6396 - acc: 0.4065 - val_loss : 1.6238 - val_acc: 0.4081



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 16 - loss : 1.6277 - acc: 0.4126 - val_loss : 1.6211 - val_acc: 0.4177



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 17 - loss : 1.6131 - acc: 0.4161 - val_loss : 1.6222 - val_acc: 0.4218



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 18 - loss : 1.6015 - acc: 0.4220 - val_loss : 1.6056 - val_acc: 0.4198



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 19 - loss : 1.5990 - acc: 0.4237 - val_loss : 1.6053 - val_acc: 0.4240



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 20 - loss : 1.5860 - acc: 0.4276 - val_loss : 1.5773 - val_acc: 0.4356



In [17]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

Thu May  6 10:48:11 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   48C    P0    54W / 300W |   2041MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces