<a href="https://colab.research.google.com/github/ShokiSuzuki/MPRGDeepLearningLectureNotebook/blob/dev/16_vit/01_vit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Vision Transformer

---
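Vision Transformer (ViT) [1] is an image classification method that applies the Transformer to computer vision. ViT splits the input image into fixed-size patches, embeds them, and feeds the resulting sequence to a Transformer Encoder. Because Self-Attention inside the Transformer Encoder learns the relationships between patches, ViT can capture features spanning the whole image from its shallow layers, unlike a convolutional neural network (CNN: Convolutional Neural Network). With this design, ViT outperformed CNNs on classification tasks such as ImageNet. ViT has also been applied to tasks such as semantic segmentation and video recognition, where it outperformed CNN-based methods.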

<img src="https://github.com/ShokiSuzuki/MPRGDeepLearningLectureNotebook/blob/dev/16_vit/model_scheme.png?raw=true" width=60%>


## Patch Embedding

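Patch Embedding splits the input image into fixed-size patches and embeds them. For example, if the input is a $224 \times 224$-pixel image and each patch is $16 \times 16$ pixels, the image is divided into a $14 \times 14$ grid of non-overlapping patches (196 patches in total). Each patch is flattened and passed through a fully connected layer to embed it. A class token, which is a learnable parameter, is then concatenated to the sequence; after passing through the Transformer Encoder, it is used for classification.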
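The arithmetic in the example above can be checked with a few lines of Python. This is only a sketch; the helper name `patch_grid` is hypothetical and is not part of the notebook's code.

```python
def patch_grid(img_size: int, patch_size: int, in_chans: int = 3):
    """Return (patches per side, total patches, flattened patch length)."""
    assert img_size % patch_size == 0, "image size must be divisible by patch size"
    per_side = img_size // patch_size                  # patches along one image side
    num_patches = per_side ** 2                        # non-overlapping patches in total
    flat_dim = patch_size * patch_size * in_chans      # length of one flattened patch
    return per_side, num_patches, flat_dim


print(patch_grid(224, 16))  # (14, 196, 768)
print(patch_grid(32, 4))    # (8, 64, 48)
```

The second call uses the CIFAR-10 setting that appears later in this notebook (32-pixel images, 4-pixel patches), which gives 64 patch tokens plus one class token.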

## Position Embedding

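Position Embedding is a learnable parameter that encodes the position of each patch. It is added to the corresponding token after Patch Embedding. Because the network acquires positional information during training, the learned values depend on the training conditions.

Patch Embedding and Position Embedding can be formulated as follows.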

\begin{aligned}
\mathbf{z}_0 &= [ \mathbf{x}_\text{class}; \, \mathbf{x}^1_p \mathbf{E}; \, \mathbf{x}^2_p \mathbf{E}; \cdots; \, \mathbf{x}^{N}_p \mathbf{E} ] + \mathbf{E}_{pos},
&& \mathbf{E} \in \mathbb{R}^{(P^2 \cdot C) \times D},\, \mathbf{E}_{pos}  \in \mathbb{R}^{(N + 1) \times D}
\end{aligned}

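Here, $\mathbf{x}_\text{class}$ is the class token, $\mathbf{x}^i_p$ is the $i$-th flattened patch, $N$ is the number of patches, $P$ is the patch size, $C$ is the number of channels, $D$ is the embedding dimension, $\mathbf{E}$ is the fully connected (linear projection) layer, and $\mathbf{E}_{pos}$ is the Position Embedding.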
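To make the shapes in this formula concrete, the following sketch builds $\mathbf{z}_0$ step by step with plain tensor operations. The sizes and the random initial values are illustrative assumptions, and the variable names mirror the symbols in the equation rather than the classes defined later.

```python
import torch

B, C, H, W = 2, 3, 32, 32      # toy batch; sizes are illustrative
P, D = 4, 384                  # patch size and embedding dimension
N = (H // P) * (W // P)        # number of patches

x = torch.randn(B, C, H, W)

# Cut the image into non-overlapping P x P patches and flatten each one:
# (B, C, H, W) -> (B, C, H/P, W/P, P, P) -> (B, N, P*P*C)
patches = x.unfold(2, P, P).unfold(3, P, P)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, N, C * P * P)

E = torch.randn(P * P * C, D) * 0.02        # linear projection E
x_class = torch.zeros(1, 1, D)              # class token
E_pos = torch.randn(1, N + 1, D) * 0.02     # position embedding

tokens = patches @ E                        # (B, N, D)
z0 = torch.cat([x_class.expand(B, -1, -1), tokens], dim=1) + E_pos
print(z0.shape)  # torch.Size([2, 65, 384])
```

The class token is prepended before the position embedding is added, so $\mathbf{E}_{pos}$ needs $N + 1$ rows.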

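## Fine-tuning

ViT is most effective when it is pre-trained on a large-scale dataset and then fine-tuned on a smaller one. When the number of pre-training images is varied, the accuracy of CNNs saturates even as more images are added, whereas the accuracy of ViT keeps improving with more images. ViT achieved SoTA results by pre-training on JFT-300M, a dataset of 300 million images, and fine-tuning on various datasets; however, because JFT-300M is not publicly available, these results cannot be reproduced.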

# Training the Vision Transformer

### Loading modules

In [None]:
!pip install timm==0.5.4


In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import timm
from timm.models.layers import trunc_normal_, DropPath
from timm.models import create_model
from timm.scheduler.cosine_lr import CosineLRScheduler
from functools import partial
from time import time

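### Defining the network

#### Patch Embedding

Patch Embedding splits the image into patches and embeds them. The embedded patches are called patch tokens. ViT flattens each patch and passes it through a fully connected layer, but in practice the same operation can be implemented as a 2D convolution whose kernel size (the patch size) equals its stride.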

In [None]:
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size    = img_size
        self.patch_size  = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        
        # Weights for the embedding (projection) step
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        return self.proj(x).flatten(2).transpose(1, 2)
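The claim that a strided convolution is equivalent to "flatten each patch, then apply a fully connected layer" can be checked numerically. The sketch below uses small, arbitrary sizes and copies the convolution weights into an `nn.Linear`; it is an illustration under those assumptions, not part of the model code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, P, D = 2, 3, 4, 8                    # small illustrative sizes
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
x = torch.randn(B, C, 16, 16)

# Path 1: strided convolution, as in PatchEmbed
out_conv = conv(x).flatten(2).transpose(1, 2)             # (B, N, D)

# Path 2: flatten each patch and apply a linear layer with the same weights
patches = x.unfold(2, P, P).unfold(3, P, P)               # (B, C, H/P, W/P, P, P)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * P * P)
linear = nn.Linear(C * P * P, D)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(D, -1))       # (D, C*P*P)
    linear.bias.copy_(conv.bias)
out_linear = linear(patches)                              # (B, N, D)

# The two paths should agree up to floating-point tolerance
print(torch.allclose(out_conv, out_linear, atol=1e-5))
```

The convolution form is usually preferred in practice because it avoids materializing the flattened patches.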

#### Multi-Head Attention

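Self-Attention applies a transformation that mixes patch tokens along the spatial (token) direction. Multi-Head Attention splits each patch token into $h$ parts along the depth (channel) direction of the vector and computes Self-Attention for each part separately. Because each head can attend to different patches and therefore yields different features, an ensemble-like effect that improves accuracy can be expected.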

In [None]:
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
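The reshape and permute in `forward` are easier to follow with the intermediate shapes printed. The sketch below repeats the same tensor operations outside the class with illustrative sizes (65 tokens, 384 channels, 6 heads); the variable `qkv_proj` is a stand-in for `self.qkv`.

```python
import torch

B, N, C, h = 2, 65, 384, 6           # batch, tokens, embedding dim, heads (illustrative)
x = torch.randn(B, N, C)
qkv_proj = torch.nn.Linear(C, 3 * C)

# (B, N, 3C) -> (B, N, 3, h, C/h) -> (3, B, h, N, C/h)
qkv = qkv_proj(x).reshape(B, N, 3, h, C // h).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)              # each: (B, h, N, C/h)

# Scaled dot-product attention, computed independently for every head
attn = (q @ k.transpose(-2, -1)) * (C // h) ** -0.5    # (B, h, N, N)
attn = attn.softmax(dim=-1)                            # each row sums to 1

# Merge the heads back into one vector per token
out = (attn @ v).transpose(1, 2).reshape(B, N, C)

print(q.shape, attn.shape, out.shape)
```

Each head works on a slice of `C // h` channels, so changing the number of heads changes how the channels are grouped but not the size of the `qkv` projection.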

#### Multi-Layer Perceptron

Whereas Multi-Head Attention mixes information along the spatial direction, the Multi-Layer Perceptron mixes information along the depth direction of each token vector. GELU is used as the activation function.

In [None]:
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
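For reference, GELU multiplies its input by the standard normal CDF, $\text{GELU}(x) = x \, \Phi(x)$. The sketch below evaluates this definition with the standard library only; the function name `gelu` is chosen here for illustration.

```python
import math


def gelu(x: float) -> float:
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(x, round(gelu(x), 4))
```

Unlike ReLU, GELU is smooth and lets small negative values pass through slightly attenuated instead of clipping them to zero.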

#### Transformer Encoder

The Transformer Encoder alternates Multi-Head Attention and Multi-Layer Perceptron layers. Each of them is wrapped in a Residual Connection.

In [None]:
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., 
                 attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        x = x + self.drop_path1(self.attn(self.norm1(x)))
        x = x + self.drop_path2(self.mlp(self.norm2(x)))
        return x
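Ignoring DropPath, the two lines of `forward` correspond to the following pre-norm residual updates for layer $\ell$, where LN denotes Layer Normalization, MSA the Multi-Head Attention, and $L$ the number of blocks (`depth`). The notation follows the embedding formula above.

\begin{aligned}
\mathbf{z}'_\ell &= \mathbf{z}_{\ell-1} + \text{MSA}(\text{LN}(\mathbf{z}_{\ell-1})), \\
\mathbf{z}_\ell &= \mathbf{z}'_\ell + \text{MLP}(\text{LN}(\mathbf{z}'_\ell)), && \ell = 1, \ldots, L
\end{aligned}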

#### Building the whole network

We build ViT from the classes defined so far.

In [None]:
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, 
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, 
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, 
                 norm_layer=None, act_layer=None, block_fn=Block):
        super().__init__()
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        # Class token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        # Position Embedding
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop  = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth
        
        # Transformer Encoder
        self.blocks = nn.Sequential(*[
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer
            )
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Classifier Head
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        
        cls_tokens = self.cls_token.expand(B, -1, -1)
        
        # Concatenate the class token
        x = torch.cat((cls_tokens, x), dim=1)
        
        x = self.pos_drop(x + self.pos_embed)
        
        # Feed the tokens through the Transformer Encoder
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x[:, 0])  # take the class token at index 0 and feed it to the linear head
        return x

In [None]:
def vit_tiny_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model

def vit_small_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model

### Preparing the data
This time, we train from scratch on CIFAR-10.

In [None]:
img_size = 32

train_transform = transforms.Compose([transforms.RandomCrop(img_size, padding=4),
                                      transforms.Resize(img_size),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                                     ])

test_transform  = transforms.Compose([transforms.Resize(img_size),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                                     ])

dataset_train = torchvision.datasets.CIFAR10("./", train=True, transform=train_transform, download=True)
dataset_test  = torchvision.datasets.CIFAR10("./", train=False, transform=test_transform, download=False)

dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=128, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
dataloader_test  = torch.utils.data.DataLoader(dataset_test, batch_size=64, num_workers=2, pin_memory=True, drop_last=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


Extracting ./cifar-10-python.tar.gz to ./



### Setting the training conditions

To compare ViT with a CNN, we use ResNet-50, whose parameter count is comparable to that of the ViT Small model.

In [None]:
# Number of classes
num_classes = 10

# Define the ViT
vit = vit_small_patch16_224(pretrained=False, num_classes=num_classes, img_size=img_size, patch_size=4)
# Define the CNN
cnn = create_model("resnet50", pretrained=False, num_classes=num_classes)


lr  = 0.0005
weight_decay = 0.05
epochs = 10   # number of epochs
warmup_t = 3

optimizer_vit     = torch.optim.AdamW(vit.parameters(), lr=lr, weight_decay=weight_decay)
lr_scheduler_vit  = CosineLRScheduler(optimizer=optimizer_vit, t_initial=epochs, warmup_t=warmup_t)
optimizer_cnn     = torch.optim.AdamW(cnn.parameters(), lr=lr, weight_decay=weight_decay)
lr_scheduler_cnn  = CosineLRScheduler(optimizer=optimizer_cnn, t_initial=epochs, warmup_t=warmup_t)
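To see what learning rates these settings produce, the sketch below re-implements a linear warmup followed by cosine decay in plain Python. It is a simplified stand-in and not timm's code: the function name `warmup_cosine_lr` is hypothetical, and it assumes a warmup that starts from 0, a minimum learning rate of 0, and a single cycle without restarts.

```python
import math


def warmup_cosine_lr(t, base_lr, t_initial, warmup_t, warmup_lr_init=0.0, lr_min=0.0):
    """Learning rate at epoch index t: linear warmup, then cosine decay."""
    if t < warmup_t:
        # Linear ramp from warmup_lr_init towards base_lr
        step = (base_lr - warmup_lr_init) / warmup_t
        return warmup_lr_init + t * step
    # Cosine decay from base_lr (t = 0) down to lr_min (t = t_initial)
    return lr_min + 0.5 * (base_lr - lr_min) * (1.0 + math.cos(math.pi * t / t_initial))


for t in range(10):
    print(t, round(warmup_cosine_lr(t, base_lr=0.0005, t_initial=10, warmup_t=3), 6))
```

Under these assumptions the learning rate is exactly 0 at index 0, so no parameters are updated while the scheduler is at that index.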

Check the number of parameters.

In [None]:
def num_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print("ViT parameters: ", num_parameters(vit))
print("CNN parameters: ", num_parameters(cnn))

ViT parameters:  21342346
CNN parameters:  23528522
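The ViT figure can also be derived by hand from the layer shapes. The sketch below counts the parameters of the `VisionTransformer` layout defined above (with `qkv_bias=True`) using plain arithmetic; the function name `vit_num_params` is hypothetical.

```python
def vit_num_params(img_size, patch_size, in_chans, embed_dim, depth, mlp_ratio, num_classes):
    """Count trainable parameters of the VisionTransformer layout (qkv_bias=True)."""
    num_patches = (img_size // patch_size) ** 2
    hidden = int(embed_dim * mlp_ratio)

    patch_embed = in_chans * patch_size ** 2 * embed_dim + embed_dim   # conv weight + bias
    cls_and_pos = embed_dim + (num_patches + 1) * embed_dim            # class token + position embedding

    layer_norm = 2 * embed_dim                                         # weight + bias
    attention = (embed_dim * 3 * embed_dim + 3 * embed_dim) + (embed_dim * embed_dim + embed_dim)
    mlp = (embed_dim * hidden + hidden) + (hidden * embed_dim + embed_dim)
    block = 2 * layer_norm + attention + mlp                           # one Transformer Encoder block

    head = embed_dim * num_classes + num_classes
    return patch_embed + cls_and_pos + depth * block + layer_norm + head


# ViT Small as configured above: 32-pixel images, 4-pixel patches, 10 classes
print(vit_num_params(32, 4, 3, 384, 12, 4, 10))  # 21342346
```

The number of heads does not appear in the count, because the heads only partition the channels of the shared `qkv` projection.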


### Training the CNN

In [None]:
criterion = torch.nn.CrossEntropyLoss()

device = torch.device("cuda")
cnn.to(device)
use_amp = True
scaler_cnn = torch.cuda.amp.GradScaler(enabled=use_amp)

start = time()
for epoch in range(epochs):
    cnn.train()
    
    sum_loss = 0.0
    count    = 0
    for img, cls in dataloader_train:
        
        img = img.to(device, non_blocking=True)
        cls = cls.to(device, non_blocking=True)
        
        # Feed the images to the CNN and compute the loss
        with torch.cuda.amp.autocast(enabled=use_amp):
            logit_cnn = cnn(img)
            loss_cnn  = criterion(logit_cnn, cls)
            
        # Update the CNN
        optimizer_cnn.zero_grad()
        scaler_cnn.scale(loss_cnn).backward()
        scaler_cnn.step(optimizer_cnn)
        scaler_cnn.update()
        
        # Accumulate the loss and the number of correct predictions for logging
        sum_loss += loss_cnn.item()
        count    += torch.sum(logit_cnn.argmax(dim=1) == cls).item()
        
    # timm schedulers take the index of the upcoming epoch
    lr_scheduler_cnn.step(epoch + 1)
    
    # Show the log
    print(f"epoch: {epoch+1},\
            mean loss: {round(sum_loss/len(dataloader_train), 3)},\
            mean accuracy: {round(count/len(dataloader_train.dataset), 2)},\
            elapsed_time : {round(time()-start, 2)}")
    
    # Evaluation
    cnn.eval()
    count = 0
    with torch.no_grad():
        for img, cls in dataloader_test:
            img = img.to(device, non_blocking=True)
            cls = cls.to(device, non_blocking=True)
        
            logit_cnn = cnn(img)
            count += torch.sum(logit_cnn.argmax(dim=1) == cls).item()
            
        print(f"test accuracy: {count/len(dataloader_test.dataset)}")

epoch: 1,            mean loss: 2.35,            mean accuracy: 0.11,            elapsed_time : 44.7
test accuracy: 0.1253
epoch: 2,            mean loss: 2.352,            mean accuracy: 0.11,            elapsed_time : 85.9
test accuracy: 0.1271
epoch: 3,            mean loss: 1.713,            mean accuracy: 0.36,            elapsed_time : 128.19
test accuracy: 0.4976
epoch: 4,            mean loss: 1.371,            mean accuracy: 0.5,            elapsed_time : 169.86
test accuracy: 0.5753
epoch: 5,            mean loss: 1.174,            mean accuracy: 0.58,            elapsed_time : 211.54
test accuracy: 0.6257
epoch: 6,            mean loss: 1.017,            mean accuracy: 0.63,            elapsed_time : 253.89
test accuracy: 0.6742
epoch: 7,            mean loss: 0.906,            mean accuracy: 0.68,            elapsed_time : 299.72
test accuracy: 0.7056
epoch: 8,            mean loss: 0.824,            mean accuracy: 0.7,            elapsed_time : 344.07
test accuracy: 0.7296

### Training the ViT

In [None]:
criterion = torch.nn.CrossEntropyLoss()

device = torch.device("cuda")
vit.to(device)
use_amp = True
scaler_vit = torch.cuda.amp.GradScaler(enabled=use_amp)

start = time()
for epoch in range(epochs):
    vit.train()
    
    sum_loss = 0.0
    count    = 0
    for img, cls in dataloader_train:
        
        img = img.to(device, non_blocking=True)
        cls = cls.to(device, non_blocking=True)
        
        # Feed the images to the ViT and compute the loss
        with torch.cuda.amp.autocast(enabled=use_amp):
            logit_vit = vit(img)
            loss_vit  = criterion(logit_vit, cls)
            
        # Update the ViT
        optimizer_vit.zero_grad()
        scaler_vit.scale(loss_vit).backward()
        scaler_vit.step(optimizer_vit)
        scaler_vit.update()
        
        # Accumulate the loss and the number of correct predictions for logging
        sum_loss += loss_vit.item()
        count    += torch.sum(logit_vit.argmax(dim=1) == cls).item()
        
    # timm schedulers take the index of the upcoming epoch
    lr_scheduler_vit.step(epoch + 1)
    
    # Show the log
    print(f"epoch: {epoch+1},\
            mean loss: {round(sum_loss/len(dataloader_train), 3)},\
            mean accuracy: {round(count/len(dataloader_train.dataset), 2)},\
            elapsed_time : {round(time()-start, 2)}")
    
    # Evaluation
    vit.eval()
    count = 0
    with torch.no_grad():
        for img, cls in dataloader_test:
            img = img.to(device, non_blocking=True)
            cls = cls.to(device, non_blocking=True)
        
            logit_vit = vit(img)
            count += torch.sum(logit_vit.argmax(dim=1) == cls).item()
            
        print(f"test accuracy: {count/len(dataloader_test.dataset)}")

epoch: 1,            mean loss: 2.388,            mean accuracy: 0.1,            elapsed_time : 68.22
test accuracy: 0.0993
epoch: 2,            mean loss: 2.389,            mean accuracy: 0.1,            elapsed_time : 147.16
test accuracy: 0.0993
epoch: 3,            mean loss: 1.826,            mean accuracy: 0.32,            elapsed_time : 225.19
test accuracy: 0.4146
epoch: 4,            mean loss: 1.531,            mean accuracy: 0.43,            elapsed_time : 303.73
test accuracy: 0.4897
epoch: 5,            mean loss: 1.361,            mean accuracy: 0.5,            elapsed_time : 382.24
test accuracy: 0.5143
epoch: 6,            mean loss: 1.233,            mean accuracy: 0.56,            elapsed_time : 460.59
test accuracy: 0.5802
epoch: 7,            mean loss: 1.136,            mean accuracy: 0.59,            elapsed_time : 539.01
test accuracy: 0.6044
epoch: 8,            mean loss: 1.046,            mean accuracy: 0.62,            elapsed_time : 617.51
test accuracy: 0.6

# Fine-tuning with a model pre-trained on ImageNet

Next, we take models pre-trained on ImageNet and fine-tune them on CIFAR-10.

### Preparing the data

In [None]:
train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),
                                      transforms.Resize(224),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                                     ])

test_transform  = transforms.Compose([transforms.Resize(224),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                                     ])

dataset_train = torchvision.datasets.CIFAR10("./", train=True, transform=train_transform, download=True)
dataset_test  = torchvision.datasets.CIFAR10("./", train=False, transform=test_transform, download=False)

dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=64, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
dataloader_test  = torch.utils.data.DataLoader(dataset_test, batch_size=64, num_workers=2, pin_memory=True, drop_last=False)

Files already downloaded and verified


### Setting the training conditions

For fine-tuning as well, we use ResNet-50, whose parameter count is comparable to that of the ViT Small model.

In [None]:
# Number of classes
num_classes = 10

# Define the ViT (using timm's create_model)
vit_finetune = create_model("deit_small_patch16_224", pretrained=True, num_classes=num_classes)
# Define the CNN
cnn_finetune = create_model("resnet50", pretrained=True, num_classes=num_classes)

lr  = 0.0001
weight_decay = 0.05
epochs = 5
warmup_t = 0

optimizer_vit     = torch.optim.AdamW(vit_finetune.parameters(), lr=lr, weight_decay=weight_decay)
lr_scheduler_vit  = CosineLRScheduler(optimizer=optimizer_vit, t_initial=epochs, warmup_t=warmup_t)
optimizer_cnn     = torch.optim.AdamW(cnn_finetune.parameters(), lr=lr, weight_decay=weight_decay)
lr_scheduler_cnn  = CosineLRScheduler(optimizer=optimizer_cnn, t_initial=epochs, warmup_t=warmup_t)

Downloading: "https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth" to /root/.cache/torch/hub/checkpoints/deit_small_patch16_224-cd65a155.pth
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth" to /root/.cache/torch/hub/checkpoints/resnet50_a1_0-14fe96d1.pth


### Training the CNN

In [None]:
criterion = torch.nn.CrossEntropyLoss()

device = torch.device("cuda")
cnn_finetune.to(device)
use_amp = True
scaler_cnn = torch.cuda.amp.GradScaler(enabled=use_amp)

start = time()
for epoch in range(epochs):
    cnn_finetune.train()
    
    sum_loss = 0.0
    count    = 0
    for img, cls in dataloader_train:
        img = img.to(device, non_blocking=True)
        cls = cls.to(device, non_blocking=True)
        
        # Feed the images to the CNN and compute the loss
        with torch.cuda.amp.autocast(enabled=use_amp):
            logit_cnn = cnn_finetune(img)
            loss_cnn  = criterion(logit_cnn, cls)
            
        # Update the CNN
        optimizer_cnn.zero_grad()
        scaler_cnn.scale(loss_cnn).backward()
        scaler_cnn.step(optimizer_cnn)
        scaler_cnn.update()
        
        # Accumulate the loss and the number of correct predictions for logging
        sum_loss += loss_cnn.item()
        count    += torch.sum(logit_cnn.argmax(dim=1) == cls).item()
        
    # timm schedulers take the index of the upcoming epoch
    lr_scheduler_cnn.step(epoch + 1)
    
    # Show the log
    print(f"epoch: {epoch+1},\
            mean loss: {round(sum_loss/len(dataloader_train), 3)},\
            mean accuracy: {round(count/len(dataloader_train.dataset), 2)},\
            elapsed_time : {round(time()-start, 2)}")
    
    # Evaluation
    cnn_finetune.eval()
    count = 0
    with torch.no_grad():
        for img, cls in dataloader_test:
            img = img.to(device, non_blocking=True)
            cls = cls.to(device, non_blocking=True)
        
            logit_cnn = cnn_finetune(img)
            count += torch.sum(logit_cnn.argmax(dim=1) == cls).item()
            
        print(f"test accuracy: {count/len(dataloader_test.dataset)}")

epoch: 1,            mean loss: 0.672,            mean accuracy: 0.81,            elapsed_time : 217.19
test accuracy: 0.9459
epoch: 2,            mean loss: 0.165,            mean accuracy: 0.95,            elapsed_time : 463.87
test accuracy: 0.9578
epoch: 3,            mean loss: 0.108,            mean accuracy: 0.96,            elapsed_time : 711.87
test accuracy: 0.9628
epoch: 4,            mean loss: 0.07,            mean accuracy: 0.98,            elapsed_time : 958.31
test accuracy: 0.9657
epoch: 5,            mean loss: 0.048,            mean accuracy: 0.99,            elapsed_time : 1206.16
test accuracy: 0.9672


### Training the ViT

In [None]:
criterion = torch.nn.CrossEntropyLoss()

device = torch.device("cuda")
vit_finetune.to(device)
use_amp = True
scaler_vit = torch.cuda.amp.GradScaler(enabled=use_amp)

start = time()
for epoch in range(epochs):
    vit_finetune.train()
    
    sum_loss = 0.0
    count    = 0
    for img, cls in dataloader_train:
        img = img.to(device, non_blocking=True)
        cls = cls.to(device, non_blocking=True)
        
        # Feed the images to the ViT and compute the loss
        with torch.cuda.amp.autocast(enabled=use_amp):
            logit_vit = vit_finetune(img)
            loss_vit  = criterion(logit_vit, cls)
            
        # Update the ViT
        optimizer_vit.zero_grad()
        scaler_vit.scale(loss_vit).backward()
        scaler_vit.step(optimizer_vit)
        scaler_vit.update()
        
        # Accumulate the loss and the number of correct predictions for logging
        sum_loss += loss_vit.item()
        count    += torch.sum(logit_vit.argmax(dim=1) == cls).item()
        
    # timm schedulers take the index of the upcoming epoch
    lr_scheduler_vit.step(epoch + 1)
    
    # Show the log
    print(f"epoch: {epoch+1},\
            mean loss: {round(sum_loss/len(dataloader_train), 3)},\
            mean accuracy: {round(count/len(dataloader_train.dataset), 2)},\
            elapsed_time : {round(time()-start, 2)}")
    
    # Evaluation
    vit_finetune.eval()
    count = 0
    with torch.no_grad():
        for img, cls in dataloader_test:
            img = img.to(device, non_blocking=True)
            cls = cls.to(device, non_blocking=True)
        
            logit_vit = vit_finetune(img)
            count += torch.sum(logit_vit.argmax(dim=1) == cls).item()
            
        print(f"test accuracy: {count/len(dataloader_test.dataset)}")


epoch: 1,            mean loss: 0.219,            mean accuracy: 0.93,            elapsed_time : 273.08
test accuracy: 0.9547
epoch: 2,            mean loss: 0.101,            mean accuracy: 0.97,            elapsed_time : 586.94
test accuracy: 0.9631
epoch: 3,            mean loss: 0.064,            mean accuracy: 0.98,            elapsed_time : 900.3
test accuracy: 0.9594
epoch: 4,            mean loss: 0.036,            mean accuracy: 0.99,            elapsed_time : 1213.06
test accuracy: 0.9683
epoch: 5,            mean loss: 0.015,            mean accuracy: 0.99,            elapsed_time : 1524.83
test accuracy: 0.9749


# Exercises
1. Try changing the number of heads in Multi-Head Attention.

# References

[1]  Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby. An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. In *International Conference on Learning Representations*, 2021.