In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import os
from PIL import Image
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
#from linformer import Linformer

In [2]:
# Mount Google Drive so the project code stored under MyDrive can be imported.
from google.colab import drive

DRIVE_ROOT = '/content/drive'
drive.mount(DRIVE_ROOT)

Mounted at /content/drive


In [3]:
%pip install einops
%pip install wandb

Collecting einops
  Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl
Installing collected packages: einops
Successfully installed einops-0.3.0
Collecting wandb
[?25l  Downloading https://files.pythonhosted.org/packages/67/5a/b037b50f9849212863a2fed313624d8f6f33ffa4ce89dc706e2a0e98c780/wandb-0.10.29-py2.py3-none-any.whl (2.1MB)
[K     |████████████████████████████████| 2.1MB 5.7MB/s 
Collecting shortuuid>=0.5.0
  Downloading https://files.pythonhosted.org/packages/25/a6/2ecc1daa6a304e7f1b216f0896b26156b78e7c38e1211e9b798b4716c53d/shortuuid-1.0.1-py3-none-any.whl
Collecting pathtools
  Downloading https://files.pythonhosted.org/packages/e7/7f/470d6fcdf23f9f3518f6b0b76be9df16dcc8630ad409947f8be2eb0ed13a/pathtools-0.1.2.tar.gz
Collecting GitPython>=1.0.0
[?25l  Downloading https://files.pythonhosted.org/packages/a6/99/98019716955ba243657daedd1de8f3a88ca1f5b75057c38e959db22fb87b/GitPyt

In [4]:
%cd drive/MyDrive/Colab Notebooks/ViT

/content/drive/MyDrive/Colab Notebooks/ViT


In [5]:
from model.vit import ViT
from model.t2t import T2TViT, RearrangeImage
from dataloader import set_transforms
from utils.utils import seed_everything

# Seed via the project helper (presumably Python/NumPy/torch RNGs — see utils.utils).
SEED = 42
seed_everything(SEED)

In [6]:
import torch.backends.cudnn as cudnn

# Use the first GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
# Enable cuDNN benchmark (autotuner) mode.
# NOTE(review): benchmark mode may select non-deterministic kernels; confirm this
# does not undo whatever determinism seed_everything() is meant to provide.
cudnn.benchmark = True
print('Use CUDA:', use_cuda)

Use CUDA: True


In [7]:
# Hyperparameters for this run. The whole dict is passed to wandb.init(config=...)
# below so it is recorded alongside the run.
# NOTE: patch_size is not consumed by T2TViT (it takes no patch_size argument).
config = dict(
    epochs=20,
    batch_size=64,
    lr=3e-5,
    gamma=0.7,
    image_size=32,
    patch_size=16,
    num_classes=10,
    dim=128,
    depth=12,
    heads=8,
    mlp_dim=1024,
    channels=3,
    dropout=0.,
    emb_dropout=0.,
)

# Training settings
import wandb
wandb.init(
    config=config,
    project="Cifar10_ViT",
    entity='NoguNogu',
    name='T2T',
)

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [8]:
# Build the CIFAR-10 train and test loaders through the project helper.
loader_factory = set_transforms(config['image_size'], config['batch_size'], name='Cifar10')
trainloader, testloader = loader_factory._set_transforms()

In [14]:
# Note: T2TViT takes no patch_size argument (unlike the plain ViT).
# nn.Unfold output-size formula: https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html
# t2t_layers entries are presumably (kernel_size, stride) per soft-split step — confirm in model/t2t.py.
model = T2TViT(
    dim = config['dim'],
    image_size = config['image_size'],
    num_classes= config['num_classes'],
    channels = config['channels'],
    depth = config['depth'],
    heads = config['heads'],
    mlp_dim = config['mlp_dim'],
    # Forward the dropout rates so the values logged to wandb in `config`
    # are the ones the model actually uses.
    dropout = config['dropout'],
    emb_dropout = config['emb_dropout'],
    t2t_layers = ((3, 2), (3, 2), (3, 2))
    ).to(device)

In [15]:
from torchsummary import summary
# Derive the input shape from config rather than hard-coding (3, 32, 32), and run the
# summary on the model's device: torchsummary defaults to device="cuda", which fails
# on the CPU fallback selected above.
summary(
    model,
    (config['channels'], config['image_size'], config['image_size']),
    device=device.type,
)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
          Identity-1            [-1, 3, 32, 32]               0
            Unfold-2              [-1, 27, 256]               0
         Rearrange-3              [-1, 256, 27]               0
         LayerNorm-4              [-1, 256, 27]              54
            Linear-5              [-1, 256, 81]           2,187
           Softmax-6          [-1, 1, 256, 256]               0
          Identity-7              [-1, 256, 27]               0
         Attention-8              [-1, 256, 27]               0
           PreNorm-9              [-1, 256, 27]               0
        LayerNorm-10              [-1, 256, 27]              54
           Linear-11              [-1, 256, 27]             756
             GELU-12              [-1, 256, 27]               0
          Dropout-13              [-1, 256, 27]               0
           Linear-14              [-1, 

In [16]:
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=config['lr'])
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=config['gamma'])

In [17]:
# 3. Log gradients and model parameters
wandb.watch(model, log_freq=100)
for epoch in range(config['epochs']):
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(trainloader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(trainloader)
        epoch_loss += loss / len(trainloader)

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in testloader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(testloader)
            epoch_val_loss += val_loss / len(testloader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

    # 4. Log metrics to visualize performance
    wandb.log({
        "Epoch": epoch+1,
        "loss": epoch_loss,
        "acc": epoch_accuracy,
        "val_loss" : epoch_val_loss,
        "val_acc": epoch_val_accuracy
        })

HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 1 - loss : 1.9409 - acc: 0.2676 - val_loss : 1.8588 - val_acc: 0.3051



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 2 - loss : 1.7975 - acc: 0.3348 - val_loss : 1.7729 - val_acc: 0.3464



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 3 - loss : 1.7222 - acc: 0.3725 - val_loss : 1.6691 - val_acc: 0.3939



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 4 - loss : 1.6534 - acc: 0.3984 - val_loss : 1.6258 - val_acc: 0.4099



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 5 - loss : 1.6047 - acc: 0.4175 - val_loss : 1.5709 - val_acc: 0.4303



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 6 - loss : 1.5620 - acc: 0.4346 - val_loss : 1.5488 - val_acc: 0.4448



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 7 - loss : 1.5281 - acc: 0.4488 - val_loss : 1.5089 - val_acc: 0.4514



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 8 - loss : 1.4930 - acc: 0.4629 - val_loss : 1.4808 - val_acc: 0.4707



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 9 - loss : 1.4575 - acc: 0.4757 - val_loss : 1.4621 - val_acc: 0.4785



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 10 - loss : 1.4355 - acc: 0.4839 - val_loss : 1.4461 - val_acc: 0.4770



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 11 - loss : 1.4027 - acc: 0.4980 - val_loss : 1.4199 - val_acc: 0.4892



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 12 - loss : 1.3749 - acc: 0.5054 - val_loss : 1.3774 - val_acc: 0.5140



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 13 - loss : 1.3449 - acc: 0.5204 - val_loss : 1.3774 - val_acc: 0.5093



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 14 - loss : 1.3186 - acc: 0.5306 - val_loss : 1.3590 - val_acc: 0.5158



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 15 - loss : 1.2942 - acc: 0.5384 - val_loss : 1.3182 - val_acc: 0.5336



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 16 - loss : 1.2668 - acc: 0.5491 - val_loss : 1.3034 - val_acc: 0.5329



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 17 - loss : 1.2461 - acc: 0.5573 - val_loss : 1.2721 - val_acc: 0.5456



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 18 - loss : 1.2273 - acc: 0.5652 - val_loss : 1.2511 - val_acc: 0.5574



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 19 - loss : 1.2026 - acc: 0.5731 - val_loss : 1.2347 - val_acc: 0.5590



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 20 - loss : 1.1737 - acc: 0.5839 - val_loss : 1.2040 - val_acc: 0.5711



In [18]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

Thu May  6 13:56:48 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   72C    P0    55W / 250W |  15537MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces