<a href="https://colab.research.google.com/github/ShinAsakawa/ShinAsakawa.github.io/blob/master/2023notebooks/2023_0804vit_example_Vision_Transformers_versus_Convolutional_Neural_Networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/RustamyF/vision-transformer/blob/master/vit_example.ipynb)


- source: `https://github.com/RustamyF/vision-transformer/blob/master/vit_example.ipynb`
- colab_source:  `https://colab.research.google.com/github/RustamyF/vision-transformer/blob/master/vit_example.ipynb`
- blog: `https://medium.com/@faheemrustamy/vision-transformers-vs-convolutional-neural-networks-5fe8f9e18efc`

# Vision Transformers vs. Convolutional Neural Networks

Fahim Rustamy, PhD

このブログ記事は google の研究チームによる [AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE](https://arxiv.org/pdf/2010.11929.pdf) 論文に触発された。
この論文では，画像分類課題に，画像パッチに直接適用される純粋な Transformer の使用を提案している。
Vision Transformer (ViT) は，大量のデータで事前に訓練された後，訓練に必要な計算リソースが少ない一方で，複数のベンチマークにおいて最先端の畳み込みネットワークを凌駕する。
<!-- This blog post is inspired by the paper titled [AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE](https://arxiv.org/pdf/2010.11929.pdf) from google’s research team.
The paper proposes using a pure Transformer applied directly to image patches for image classification tasks.
The Vision Transformer (ViT) outperforms state-of-the-art convolutional networks in multiple benchmarks while requiring fewer computational resources to train, after being pre-trained on large amounts of data.-->

Transformer は，その計算効率とスケーラビリティにより，自然言語処理において選択されるモデルとなっている。
コンピュータビジョンでは，畳み込みニューラルネットワーク (CNN) アーキテクチャが依然として主流であるが，CNN と自己注意を組み合わせることを試みた研究者もいる。
著者らは，標準的な Transformerを画像に直接適用する実験を行い，中規模のデータセットで訓練した場合，ResNetのようなアーキテクチャと比較して，モデルの精度が控えめであることを発見した。
しかし，より大きなデータセットで訓練した場合，Vision Transformer (ViT) は優れた結果を達成し，複数の画像認識ベンチマークにおいて，最先端の技術に近づいたり，凌駕したりした。
<!--Transformers have become the model of choice in NLP due to their computational efficiency and scalability.
In computer vision, convolutional neural network (CNN) architectures remain dominant, but some researchers have tried combining CNNs with self-attention.
The authors experimented with applying a standard Transformer directly to images and found that when trained on mid-sized datasets, the models had modest accuracy compared to ResNet-like architectures.
However, when trained on larger datasets, the Vision Transformer (ViT) achieved excellent results and approached or surpassed the state of the art on multiple image recognition benchmarks. -->

<img src="https://miro.medium.com/v2/resize:fit:1400/0*brmcPLvJpiQWjZpY">

図 1 (原著論文より引用) は，2D 画像を平坦化された 2D パッチの系列に変換することで処理するモデルを記述している。
パッチは次に，学習可能な線形射影を用いて一定の潜在ベクトルサイズに写像される。
学習可能な埋め込みがパッチの系列に付加され，Transformer 符号化器の出力におけるその状態が画像表現となる。
この画像表現は，事前学習または微調整のために分類ヘッドに渡される。
位置情報を保持するために位置埋め込みが追加され，埋め込みベクトルの系列が Transformer 符号化器の入力となる。
<!-- Figure 1 (taken from the original paper) describes a model that processes 2D images by transforming them into sequences of flattened 2D patches.
The patches are then mapped to a constant latent vector size with a trainable linear projection.
A learnable embedding is prepended to the sequence of patches and its state at the output of the Transformer encoder serves as the image representation.
The image representation is then passed through a classification head for either pre-training or fine-tuning.
Position embeddings are added to retain positional information and the sequence of embedding vectors serves as input to the Transformer encoder, which consists of alternating layers of multiheaded self-attention and MLP blocks. -->

過去，CNN は長い間，画像処理課題の有力な選択肢であった。
CNN は畳み込み層を通して局所的な空間パターンを捉えることに優れており，階層的な特徴抽出を可能にする。
CNN は大量の画像データから学習することに長けており，画像分類，物体検出，切り分けなどの課題で目覚ましい成功を収めている。
<!-- In the past, CNNs have been the go-to choice for image processing tasks for a long time.
They excel at capturing local spatial patterns through convolutional layers, enabling hierarchical feature extraction.
CNNs are adept at learning from large amounts of image data and have achieved remarkable success in tasks like image classification, object detection, and segmentation.-->

CNN は様々なコンピュータビジョン課題で実績があり，大規模なデータセットを効率的に扱うが，Vision Transformer は大域的な依存関係や文脈の理解が重要なシナリオで優位性を発揮する。
しかし，Vision Transformer は通常，CNN と同等の性能を達成するために，より大量の学習データを必要とする。
また，CNN は並列処理が可能なため計算効率が高く，実時間でリソースに制約のある応用ではより実用的である。
<!-- While CNNs have a proven track record in various computer vision tasks and handle large-scale datasets efficiently, Vision Transformers offer advantages in scenarios where global dependencies and contextual understanding are crucial.
However, Vision Transformers typically require larger amounts of training data to achieve comparable performance to CNNs.
Also, CNNs are computationally efficient due to their parallelizable nature, making them more practical for real-time and resource-constrained applications. -->

### 例 CNN と視覚 Transformer の比較 <!-- ### Example: CNN vs. Vision Transformer-->

本節では，Kaggle で公開されている cats and dogs データセットに対して，CNN と vision Transformer の両方のアプローチを使って視覚分類器を学習する。
まず，Kaggle から 25000 枚の RGB 画像を含む cats and dogs データセットをダウンロードする。
まだの方は，こちらの説明を読んで，Kaggle API クレデンシャルの設定方法を学んで欲しい。
以下の Python コードはデータセットを現在の作業ディレクトリにダウンロードする。
<!-- In this section, we will train a vision classifier on the cats and dogs dataset available in Kaggle, using both CNN and vision transformer approaches.
First, we will download the cats and dogs dataset from Kaggle with 25000 RGB images.
If you haven’t already, you can read the instructions here to learn how to get your Kaggle API credential set up.
The following Python code will download the dataset into your current working directory. -->

In [7]:
#!mkdir ~/.kaggle
!touch ~/.kaggle/kaggle.json

#api_token = {"username":"username","key":"api-key"}
api_token = {"username":"turingcomplete","key":"a49cdd9a6452346d9fdacca035bde21a"}

import json

with open('/root/.kaggle/kaggle.json', 'w') as file:
    json.dump(api_token, file)

!chmod 600 ~/.kaggle/kaggle.json

In [None]:
# !mkdir ~/.kaggle
# !cp kaggle.json ~/.kaggle/
# !chmod 600 ~/.kaggle/kaggle.json

In [None]:
!pip install kaggle einops

In [8]:
from kaggle.api.kaggle_api_extended import KaggleApi

api = KaggleApi()
api.authenticate()

# we write to the current directory with './'
api.dataset_download_files('karakaggle/kaggle-cat-vs-dog-dataset', path='./')

In [9]:
!unzip -qq kaggle-cat-vs-dog-dataset.zip
!rm -r kaggle-cat-vs-dog-dataset.zip

In [None]:
# !wget --no-check-certificate \
#     https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip \
#     -O pets


In [None]:
!git clone https://github.com/RustamyF/vision-transformer.git
!mv vision-transformer/vision_tr .

In [None]:
# !unzip pets
# !rm pets
# !mv -v cats_and_dogs_filtered/validation/dogs/* cats_and_dogs_filtered/train/dogs/
# !mv -v cats_and_dogs_filtered/validation/cats/* cats_and_dogs_filtered/train/cats/

In [11]:
from vision_tr.simple_vit import Transformer

### CNN アプローチ <!-- ### CNN Approach-->


この画像分類器の CNN モデルは、カーネル・サイズ 3，ストライド 2，最大プーリング層  2 の 3 層の 2 次元畳み込みで構成される。
畳み込み層に続いて、2 つの完全連結層があり，それぞれ 10 ノードで構成される。
以下はこの構造を示すコード・スニペットである：
<!-- The CNN model for this image classifier consists of three layers of 2D convolutions, with a kernel size of 3, stride of 2, and a max pooling layer of 2.
Following the convolution layers, there are two fully connected layers, each composed of 10 nodes.
Here is a code snippet that illustrates this structure: -->

訓練は Tesla T4  (g4dn-xlarge) GPU マシンで 10 訓練エポック行った。
Jupyter notebook はプロジェクトの GitHub リポジトリで公開されており，訓練ループのコードが含まれている。
以下は，各エポックのトレーニングループの結果である。
<!-- The training was performed with a Tesla T4 (g4dn-xlarge) GPU machine for 10 training epochs.
The Jupyter Notebook is available in the project’s GitHub repository and contains the code for the training loop.
The following are the results of training loops for each epoch. -->


In [13]:
import torch.nn as nn
import torch
import torch.optim as optim

from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from sklearn.model_selection import train_test_split

import os


class LoadData:
    def __init__(self):
        self.cat_path = 'kagglecatsanddogs_3367a/PetImages/Cat'
        self.dog_path = 'kagglecatsanddogs_3367a/PetImages/Dog'

    def delete_non_jpeg_files(self, directory):
        for filename in os.listdir(directory):
            if not filename.endswith('.jpg') and not filename.endswith('.jpeg'):
                file_path = os.path.join(directory, filename)
                try:
                    if os.path.isfile(file_path) or os.path.islink(file_path):
                        os.unlink(file_path)
                    elif os.path.isdir(file_path):
                        shutil.rmtree(file_path)
                    print('deleted', file_path)
                except Exception as e:
                    print('Failed to delete %s. Reason: %s' % (file_path, e))

    def data(self):
        self.delete_non_jpeg_files(self.dog_path)
        self.delete_non_jpeg_files(self.cat_path)

        dog_list = os.listdir(self.dog_path)
        dog_list = [(os.path.join(self.dog_path, i), 1) for i in dog_list]

        cat_list = os.listdir(self.cat_path)
        cat_list = [(os.path.join(self.cat_path, i), 0) for i in cat_list]

        total_list = cat_list + dog_list

        train_list, test_list = train_test_split(total_list, test_size=0.2)
        train_list, val_list = train_test_split(train_list, test_size=0.2)
        print('train list', len(train_list))
        print('test list', len(test_list))
        print('val list', len(val_list))
        return train_list, test_list, val_list


# data Augumentation
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])


class dataset(torch.utils.data.Dataset):

    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    # dataset length
    def __len__(self):
        self.filelength = len(self.file_list)
        return self.filelength

    # load an one of images
    def __getitem__(self, idx):
        img_path, label = self.file_list[idx]
        img = Image.open(img_path).convert('RGB')
        img_transformed = self.transform(img)
        return img_transformed, label


class Cnn(nn.Module):
    def __init__(self):
        super(Cnn, self).__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=0, stride=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=0, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=0, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.fc1 = nn.Linear(3 * 3 * 64, 10)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(10, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.view(out.size(0), -1)
        out = self.relu(self.fc1(out))
        out = self.fc2(out)
        return out


if __name__ == "__main__":
    lr = 0.001  # learning_rate
    batch_size = 800  # we will use mini-batch method
    epochs = 10  # How much to train a model

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    torch.manual_seed(1234)
    if device == 'cuda':
        torch.cuda.manual_seed_all(1234)

    print(device)

    load_data = LoadData()

    train_list, test_list, val_list = load_data.data()

    train_data = dataset(train_list, transform=transform)
    test_data = dataset(test_list, transform=transform)
    val_data = dataset(val_list, transform=transform)

    train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False)
    test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(dataset=val_data, batch_size=batch_size, shuffle=True)

    model = Cnn().to(device)
    model.train()

    optimizer = optim.Adam(params=model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        epoch_loss = 0
        epoch_accuracy = 0

        for data, label in train_loader:
            data = data.to(device)
            label = label.to(device)

            output = model(data)
            loss = criterion(output, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            acc = ((output.argmax(dim=1) == label).float().mean())
            epoch_accuracy += acc / len(train_loader)
            epoch_loss += loss / len(train_loader)

        print(f'Epoch : {epoch + 1:2d},',
              f'train accuracy : {epoch_accuracy:.3f},',
              f'train loss : {epoch_loss:.3f}')
        #print('Epoch : {}, train accuracy : {}, train loss : {}'.format(epoch + 1, epoch_accuracy, epoch_loss))

        with torch.no_grad():
            epoch_val_accuracy = 0
            epoch_val_loss = 0
            for data, label in val_loader:
                data = data.to(device)
                label = label.to(device)

                val_output = model(data)
                val_loss = criterion(val_output, label)

                acc = ((val_output.argmax(dim=1) == label).float().mean())
                epoch_val_accuracy += acc / len(val_loader)
                epoch_val_loss += val_loss / len(val_loader)

            print(f'Epoch : {epoch + 1:2d},',
                  f'val_accuracy : {epoch_val_accuracy:.3f},',
                  f'val_loss : {epoch_val_loss:.3f}')
            #print('Epoch : {}, val_accuracy : {}, val_loss : {}'.format(epoch + 1, epoch_val_accuracy, epoch_val_loss))


cuda
train list 15973
test list 4992
val list 3994




Epoch :  1, train accuracy : 0.586, train loss : 0.676
Epoch :  1, val_accuracy : 0.625, val_loss : 0.654
Epoch :  2, train accuracy : 0.633, train loss : 0.647
Epoch :  2, val_accuracy : 0.637, val_loss : 0.640
Epoch :  3, train accuracy : 0.666, train loss : 0.624
Epoch :  3, val_accuracy : 0.672, val_loss : 0.608
Epoch :  4, train accuracy : 0.688, train loss : 0.587
Epoch :  4, val_accuracy : 0.701, val_loss : 0.573
Epoch :  5, train accuracy : 0.713, train loss : 0.560
Epoch :  5, val_accuracy : 0.724, val_loss : 0.549
Epoch :  6, train accuracy : 0.725, train loss : 0.540
Epoch :  6, val_accuracy : 0.720, val_loss : 0.537
Epoch :  7, train accuracy : 0.729, train loss : 0.534
Epoch :  7, val_accuracy : 0.732, val_loss : 0.534
Epoch :  8, train accuracy : 0.742, train loss : 0.517
Epoch :  8, val_accuracy : 0.740, val_loss : 0.528
Epoch :  9, train accuracy : 0.748, train loss : 0.507
Epoch :  9, val_accuracy : 0.752, val_loss : 0.519
Epoch : 10, train accuracy : 0.750, train loss

### Vision Transformer Approach

The Vision Transformer architecture is designed with customizable dimensions that can be adjusted according to specific requirements.
For this size of image dataset, this architecture is still big.

```python
from vision_tr.simple_vit import ViT
model = ViT(
    image_size=224,
    patch_size=32,
    num_classes=2,
    dim=128,
    depth=12,
    heads=8,
    mlp_dim=1024,
    dropout=0.1,
    emb_dropout=0.1,
).to(device)
```

Each parameter in the vision transformer plays a key role and is described here:

* `image_size=224`: This parameter specifies the desired size (width and height) of the input images to the model. In this case, the images are expected to be of size 224x224 pixels.
* `patch_size=32`: The images are divided into smaller patches, and this parameter defines the size (width and height) of each patch. In this case, each patch is 32x32 pixels.
* `num_classes=2`: This parameter indicates the number of classes in the classification task. In this example, the model is designed to classify inputs into two classes (cats and dogs).
* `dim=128`: It specifies the dimensionality of the embedding vectors in the model. The embeddings capture the representation of each image patch.
* `depth=12`: This parameter defines the depth or number of layers in the Vision Transformer model (encoder model). A higher depth allows for more complex feature extraction.
* `heads=8`: This parameter represents the number of attention heads in the self-attention mechanism of the model.
* `mlp_dim=1024`: It specifies the dimensionality of the Multi-Layer Perceptron (MLP) hidden layers in the model. The MLP is responsible for transforming the token representations after self-attention.
* `dropout=0.1`: This parameter controls the dropout rate, which is a regularization technique used to prevent overfitting. It randomly sets a fraction of input units to 0 during training.
* `emb_dropout=0.1`: It defines the dropout rate specifically applied to the token embeddings. This dropout helps prevent over-reliance on specific tokens during training.

The training of the vision transformer for the classification task was performed with the Tesla T4 (g4dn-xlarge) GPU machine for 20 training epochs.
The training was conducted for 20 epochs (instead of 10 epochs used for CNN) because the training loss’s convergence was slow. The following are the results of training loops for each epoch.

The CNN approach reached 75% accuracy in 10 epochs, while the vision transformer model reached 69% accuracy and took significantly longer to train.

### 結論 <!-- ### Conclusion-->

結論として，CNN モデルと Vision Transformer モデルを比較すると，モデルサイズ，メモリ要件，精度，性能の点で顕著な違いがある。
CNN モデルは伝統的に，そのコンパクトなサイズと効率的なメモリ利用で知られており，リソースに制約のある環境に適している。
画像処理課題において非常に効果的であることが証明されており，様々なコンピュータビジョン応用において優れた精度を示している。
一方，Vision Transformers は，画像の大域的な依存関係や文脈的な理解を捉えるための強力なアプローチを提供し，特定の課題における性能向上をもたらす。
しかし，Vision Transformer は CNN に比べてモデルサイズが大きく，メモリ要件が高い傾向がある。
特に大規模なデータセットを扱う場合，素晴らしい精度を達成できるかもしれないが，計算上の要求が，リソースが限られたシナリオでの実用性を制限する可能性がある。
最終的に，CNN モデルと Vision Transformer モデルのどちらを選択するかは，利用可能なリソース，データセットサイズ，モデルの複雑さ，精度，性能のトレードオフなどの要因を考慮し，目の前の課題の特定の要件に依存する。
コンピュータビジョンの分野が進化し続けるにつれて，両アーキテクチャのさらなる進歩が期待され，研究者や実務家が特定のニーズや制約に基づいて，より多くの情報に基づいた選択を行うことができるようになる。
<!--In conclusion, when comparing CNN and Vision Transformer models, there are notable differences in terms of model size, memory requirements, accuracy, and performance.
CNN models are traditionally known for their compact size and efficient memory utilization, making them suitable for resource-constrained environments.
They have proven to be highly effective in image processing tasks and exhibit excellent accuracy in various computer vision applications.
On the other hand, Vision Transformers offer a powerful approach to capture global dependencies and contextual understanding in images, resulting in improved performance in certain tasks. However, Vision Transformers tend to have larger model sizes and higher memory requirements compared to CNNs. While they may achieve impressive accuracy, especially when dealing with larger datasets, the computational demands can limit their practicality in scenarios with limited resources. Ultimately, the choice between CNN and Vision Transformer models depends on the specific requirements of the task at hand, considering factors such as available resources, dataset size, and the trade-off between model complexity, accuracy, and performance.
As the field of computer vision continues to evolve, further advancements in both architectures are expected, enabling researchers and practitioners to make more informed choices based on their specific needs and constraints. -->



In [None]:
import torch.nn as nn
import torch
import torch.optim as optim
from vision_tr.simple_vit import ViT
# from vit_pytorch.efficient import ViT
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from sklearn.model_selection import train_test_split

import os

# from linformer import Linformer
from torch.optim.lr_scheduler import StepLR
# from vit_pytorch.efficient import ViT


class LoadData:

    def __init__(self):
        self.cat_path = 'kagglecatsanddogs_3367a/PetImages/Cat'
        self.dog_path = 'kagglecatsanddogs_3367a/PetImages/Dog'

    def delete_non_jpeg_files(self, directory):
        for filename in os.listdir(directory):
            if not filename.endswith('.jpg') and not filename.endswith('.jpeg'):
                file_path = os.path.join(directory, filename)
                try:
                    if os.path.isfile(file_path) or os.path.islink(file_path):
                        os.unlink(file_path)
                    elif os.path.isdir(file_path):
                        shutil.rmtree(file_path)
                    print('deleted', file_path)
                except Exception as e:
                    print('Failed to delete %s. Reason: %s' % (file_path, e))

    def data(self):
        self.delete_non_jpeg_files(self.dog_path)
        self.delete_non_jpeg_files(self.cat_path)

        dog_list = os.listdir(self.dog_path)
        dog_list = [(os.path.join(self.dog_path, i), 1) for i in dog_list]

        cat_list = os.listdir(self.cat_path)
        cat_list = [(os.path.join(self.cat_path, i), 0) for i in cat_list]

        total_list = cat_list + dog_list

        train_list, test_list = train_test_split(total_list, test_size=0.2)
        train_list, val_list = train_test_split(train_list, test_size=0.2)
        print('train list', len(train_list))
        print('test list', len(test_list))
        print('val list', len(val_list))
        return train_list, test_list, val_list


# data Augumentation
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])


class dataset(torch.utils.data.Dataset):

    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    # dataset length
    def __len__(self):
        self.filelength = len(self.file_list)
        return self.filelength

    # load an one of images
    def __getitem__(self, idx):
        img_path, label = self.file_list[idx]
        img = Image.open(img_path).convert('RGB')
        img_transformed = self.transform(img)
        return img_transformed, label



if __name__ == "__main__":
    # Training settings
    batch_size = 64
    epochs = 20
    lr = 3e-5
    gamma = 0.7
    seed = 42

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    torch.manual_seed(1234)
    if device == 'cuda':
        torch.cuda.manual_seed_all(1234)

    print(device)

    load_data = LoadData()

    train_list, test_list, val_list = load_data.data()

    train_data = dataset(train_list, transform=transform)
    test_data = dataset(test_list, transform=transform)
    val_data = dataset(val_list, transform=transform)

    train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False)
    test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(dataset=val_data, batch_size=batch_size, shuffle=True)
    model = ViT(
        image_size=224,
        patch_size=32,
        num_classes=2,
        dim=128,
        depth=12,
        heads=8,
        mlp_dim=1024,
        dropout=0.1,
        emb_dropout=0.1,
    ).to(device)

    # loss function
    criterion = nn.CrossEntropyLoss()
    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # scheduler
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

    epochs = 20

    for epoch in range(epochs):
        epoch_loss = 0
        epoch_accuracy = 0

        for data, label in train_loader:
            data = data.to(device)
            label = label.to(device)

            output = model(data)
            loss = criterion(output, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            acc = ((output.argmax(dim=1) == label).float().mean())
            epoch_accuracy += acc / len(train_loader)
            epoch_loss += loss / len(train_loader)

        print('Epoch : {}, train accuracy : {}, train loss : {}'.format(epoch + 1, epoch_accuracy, epoch_loss))

        with torch.no_grad():
            epoch_val_accuracy = 0
            epoch_val_loss = 0
            for data, label in val_loader:
                data = data.to(device)
                label = label.to(device)

                val_output = model(data)
                val_loss = criterion(val_output, label)

                acc = ((val_output.argmax(dim=1) == label).float().mean())
                epoch_val_accuracy += acc / len(val_loader)
                epoch_val_loss += val_loss / len(val_loader)

            print('Epoch : {}, val_accuracy : {}, val_loss : {}'.format(epoch + 1, epoch_val_accuracy, epoch_val_loss))


cuda
train list 15973
test list 4992
val list 3994
Epoch : 1, train accuracy : 0.5329695343971252, train loss : 0.6937241554260254
Epoch : 1, val_accuracy : 0.5617369413375854, val_loss : 0.6789079904556274
