## ViTの実装 (Vision Transformer implementation)

Train a Vision Transformer (ViT) image classifier on CIFAR10.

## 1. Import the necessary libraries

In [1]:
#### 1. import the necessary libraries
import torch
import torch.nn as nn
from vit import ViT
import matplotlib.pyplot as plt
from train import train
from torchsummary import summary

## 2. Train the model

### 2.1. Check the device

In [2]:
# Select the compute backend, preferring an NVIDIA GPU (CUDA), then an
# Apple-Silicon GPU (MPS), and falling back to the CPU. The original cell only
# checked MPS, so CUDA machines silently trained on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

### 2.2. Load the dataset

In [3]:
from dataloader import get_dataloader

train_dataloader, val_dataloader, test_dataloader = get_dataloader()

# Report the batch size and spatial size of the first batch from every split.
# `images` / `labels` intentionally stay bound to the LAST split's batch after
# the loop, because the model-definition cell reads `images.shape`.
splits = {'Train': train_dataloader, 'Eval': val_dataloader, 'Test': test_dataloader}
for name, dataloader in splits.items():
    images, labels = next(iter(dataloader))
    # images.shape is read as (batch_size, channel, img_size, img_size)
    batch_n = images.shape[0]
    height = images.shape[2]
    width = images.shape[3]
    print(f"{name} Dataloader - Batch size: {batch_n}, Image size: {height}x{width}")
print("Classes:", train_dataloader.dataset.classes)

Files already downloaded and verified
Files already downloaded and verified
Train Dataloader - Batch size: 32, Image size: 32x32
Eval Dataloader - Batch size: 32, Image size: 32x32
Test Dataloader - Batch size: 32, Image size: 32x32
Classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


### 2.3. Define the model

In [4]:
# Derive the model's input geometry from a real batch rather than hard-coding it.
# NOTE: `images` comes from the previous cell (the last batch it drew, i.e. from
# the test loader); the output above shows all three loaders share one shape.
batch_size = images.shape[0]
channel = images.shape[1]
img_size = images.shape[2]
num_classes = len(train_dataloader.dataset.classes)  ## Using CIFAR10 (10 classes)

model = ViT(in_channels=channel,
            image_size=img_size,
            num_classes=num_classes)

# torchsummary's `summary` defaults to device="cuda". The model is still on the
# CPU at this point, so on a CUDA machine the default would feed CUDA inputs to
# CPU weights and raise a device-mismatch error. Request CPU inputs explicitly.
summary(model, input_size=(channel, img_size, img_size), device="cpu")

# Shape sanity check on random input. `no_grad` avoids building an autograd
# graph that is never used for backpropagation.
x = torch.randn(batch_size, channel, img_size, img_size)
print("Input shape is ", x.shape)
with torch.no_grad():
    pred = model(x)
print("Output shape is ", pred.shape)

ViT is implemented
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 384, 2, 2]         295,296
           Dropout-2               [-1, 5, 384]               0
   ViT_input_Layer-3               [-1, 5, 384]               0
         LayerNorm-4               [-1, 5, 384]             768
            Linear-5               [-1, 5, 384]         147,456
            Linear-6               [-1, 5, 384]         147,456
            Linear-7               [-1, 5, 384]         147,456
            Linear-8               [-1, 5, 384]         147,840
Self_Attention_Layer-9               [-1, 5, 384]               0
        LayerNorm-10               [-1, 5, 384]             768
           Linear-11              [-1, 5, 1536]         591,360
             GELU-12              [-1, 5, 1536]               0
          Dropout-13              [-1, 5, 1536]               0
           Linear-

### 2.4. Define the optimizer and the loss function

In [5]:
# Loss: cross-entropy for single-label classification over the dataset classes.
criterion = nn.CrossEntropyLoss()
# Optimizer: Adam over every model parameter with a fixed learning rate of 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

### 2.5. Train the model

In [None]:
# Launch the full training run. `train` (project-local) receives, positionally:
# model, train/val/test loaders, epoch budget, optimizer, loss and device.
# NOTE(review): the model is never moved to `device` in this notebook;
# presumably `train` does so itself — confirm.
num_epochs = 200
train(
    model,
    train_dataloader,
    val_dataloader,
    test_dataloader,
    num_epochs,
    optimizer,
    criterion,
    device,
)