## Vision Transformer (ViT) in PyTorch

Patch Embedding

In [1]:
!pip install einops

Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0


In [2]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary
import os

In [3]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import random
from torch.nn import Module
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from functools import partial
from typing import Any, Callable, List, Optional, Type, Union
from torch import Tensor

In [4]:
class CustomTrainDataset1(Dataset,Module):
    """Dataset serving a random subset of the files in one folder, all sharing one label.

    Args:
        folder_path: Directory whose entries are listed with ``os.listdir``.
            Entries are not filtered by extension or type.
        index_array: Label object stored as ``self.label`` and returned
            unchanged with every sample.
        transform: Optional callable applied to each RGB ``PIL.Image``.
        num_samples: Maximum number of files to draw at random (default 7000).
            If the folder holds fewer entries, all of them are used instead of
            letting ``random.sample`` raise ``ValueError``.
    """

    def __init__(self, folder_path, index_array, transform=None, num_samples=7000):
        # The class also derives from nn.Module, so its machinery must be
        # initialised before attributes are assigned; otherwise assigning an
        # nn.Module-based transform raises AttributeError.
        super().__init__()
        self.folder_path = folder_path
        self.image_files = [os.path.join(folder_path, file) for file in os.listdir(folder_path)]
        # Cap the request at the number of available files.
        sample_size = min(num_samples, len(self.image_files))
        self.image_paths = random.sample(self.image_files, sample_size)
        self.label = index_array
        self.transform = transform

    def __len__(self):
        """Return the number of sampled image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return ``(image, label)``."""
        image_path = self.image_paths[idx]
        # Open inside a context manager so the file handle is closed promptly;
        # convert() returns a new, fully loaded RGB image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, self.label

# Image preprocessing pipeline: resize to 224x224, convert to a tensor,
# then normalize each channel with the given mean / std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)



In [5]:
class CustomTrainDataset2(Dataset,Module):
    """Dataset serving a random subset of the files in one folder, all sharing one label.

    Args:
        folder_path: Directory whose entries are listed with ``os.listdir``.
            Entries are not filtered by extension or type.
        index_array: Label object stored as ``self.label`` and returned
            unchanged with every sample.
        transform: Optional callable applied to each RGB ``PIL.Image``.
        num_samples: Maximum number of files to draw at random (default 3000).
            If the folder holds fewer entries, all of them are used instead of
            letting ``random.sample`` raise ``ValueError``.
    """

    def __init__(self, folder_path, index_array, transform=None, num_samples=3000):
        # The class also derives from nn.Module, so its machinery must be
        # initialised before attributes are assigned; otherwise assigning an
        # nn.Module-based transform raises AttributeError.
        super().__init__()
        self.folder_path = folder_path
        image_files = [os.path.join(folder_path, file) for file in os.listdir(folder_path)]
        # Cap the request at the number of available files.
        sample_size = min(num_samples, len(image_files))
        self.image_paths = random.sample(image_files, sample_size)
        self.label = index_array
        self.transform = transform

    def __len__(self):
        """Return the number of sampled image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return ``(image, label)``."""
        image_path = self.image_paths[idx]
        # Open inside a context manager so the file handle is closed promptly;
        # convert() returns a new, fully loaded RGB image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, self.label

# Image preprocessing pipeline: resize to 224x224, convert to a tensor,
# then normalize each channel with the given mean / std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


In [6]:
class CustomValidDataset(Dataset,Module):
    """Validation dataset serving a random subset of one folder's files, all sharing one label.

    Args:
        folder_path: Directory whose entries are listed with ``os.listdir``.
            Entries are not filtered by extension or type.
        index_array: Label object stored as ``self.label`` and returned
            unchanged with every sample.
        transform: Optional callable applied to each RGB ``PIL.Image``.
        num_samples: Maximum number of files to draw at random (default 1000).
            If the folder holds fewer entries, all of them are used instead of
            letting ``random.sample`` raise ``ValueError``.
    """

    def __init__(self, folder_path, index_array, transform=None, num_samples=1000):
        # The class also derives from nn.Module, so its machinery must be
        # initialised before attributes are assigned; otherwise assigning an
        # nn.Module-based transform raises AttributeError.
        super().__init__()
        self.folder_path = folder_path
        image_files = [os.path.join(folder_path, file) for file in os.listdir(folder_path)]
        # Cap the request at the number of available files.
        sample_size = min(num_samples, len(image_files))
        self.image_paths = random.sample(image_files, sample_size)
        self.label = index_array
        self.transform = transform

    def __len__(self):
        """Return the number of sampled image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return ``(image, label)``."""
        image_path = self.image_paths[idx]
        # Open inside a context manager so the file handle is closed promptly;
        # convert() returns a new, fully loaded RGB image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, self.label

# Image preprocessing pipeline: resize to 224x224, convert to a tensor,
# then normalize each channel with the given mean / std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [7]:
import pickle

def save_data_loader(dataloader, file_path):
    """Pickle ``dataloader`` into the file at ``file_path``.

    The file is opened in binary write mode, so any existing content is
    replaced.
    """
    with open(file_path, mode='wb') as handle:
        pickle.dump(dataloader, handle)

def load_data_loader(file_path):
    """Read the file at ``file_path`` and return the object unpickled from it.

    Caution: ``pickle.load`` can execute arbitrary code while deserializing,
    so only use this on files from a trusted source.
    """
    with open(file_path, mode='rb') as handle:
        return pickle.load(handle)

In [8]:
# Save the DataLoader object to a file (disabled)
#save_dataset(datasets, 'dataset.pkl')

# Load the previously pickled DataLoader objects from files.
# NOTE(review): unpickling presumably requires the dataset classes used when
# the loaders were saved to be defined in this session — confirm.
train_loader = load_data_loader('/content/drive/MyDrive/종설/train_Dataloader.pkl')
valid_loader = load_data_loader('/content/drive/MyDrive/종설/valid_Dataloader.pkl')

In [9]:
# Sanity check: number of samples in each loader's underlying dataset.
print(len(train_loader.dataset))
print(len(valid_loader.dataset))

60000
6000


In [10]:
class PatchEmbedding(nn.Module):
  """Turn an image batch into a sequence of patch embeddings plus a class token.

  A strided convolution (kernel = stride = ``patch_size``) projects each
  non-overlapping patch to ``emb_size`` channels; the spatial grid is then
  flattened into a token sequence, a learnable class token is prepended and
  learnable position embeddings are added.

  Args:
    in_channels: Number of channels of the input images.
    patch_size: Side length of each square patch.
    emb_size: Embedding dimension of every token.
    img_size: Image side length used to size the position embeddings:
      ``(img_size // patch_size) ** 2 + 1`` rows. It should match the actual
      input resolution. Note that if ``img_size < patch_size`` the table has a
      single row, which broadcasts silently over all tokens.
  """
  def __init__(self, in_channels:int=3, patch_size:int=16, emb_size:int=768, img_size:int=224):
    # Initialise nn.Module before assigning any attribute.
    super().__init__()
    self.P = patch_size
    self.projection = nn.Sequential(
        nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
        Rearrange('b e h w -> b (h w) e')
    )
    self.cls_token = nn.Parameter(torch.randn(1,1,emb_size))
    self.positions = nn.Parameter(torch.randn((img_size//patch_size)**2+1,emb_size))

  def forward(self, x:Tensor)->Tensor:
    """Map ``x`` of shape (batch, channels, height, width) to (batch, tokens + 1, emb_size)."""
    batch_size = x.shape[0]
    x = self.projection(x)
    # expand() creates a broadcast view; torch.cat below materialises the copy.
    cls_tokens = self.cls_token.expand(batch_size, -1, -1)
    x = torch.cat([cls_tokens, x], dim=1)
    # Out-of-place add keeps the concatenated tensor untouched for autograd.
    x = x + self.positions
    return x

Multi-Head Attention

In [11]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product self-attention with a fused QKV projection.

  Args:
    emb_size: Token embedding dimension; must be divisible by ``num_heads``.
    num_heads: Number of attention heads.
    dropout: Dropout probability applied to the attention weights.

  Raises:
    ValueError: If ``emb_size`` is not divisible by ``num_heads``.

  Note:
    Scores are divided by ``sqrt(emb_size // num_heads)`` *before* the softmax.
    Earlier revisions divided the softmax output by ``sqrt(emb_size)``, so
    weights trained with that revision will produce different activations.
  """
  def __init__(self, emb_size:int=768, num_heads:int=8, dropout:float=0):
    super().__init__()
    if emb_size % num_heads != 0:
      raise ValueError(
          f"emb_size ({emb_size}) must be divisible by num_heads ({num_heads})")
    self.emb_size = emb_size
    self.num_heads = num_heads
    # Fuse the queries, keys and values in one matrix
    self.qkv = nn.Linear(emb_size, emb_size*3)
    self.att_drop = nn.Dropout(dropout)
    self.projection = nn.Linear(emb_size, emb_size)

  def forward(self, x:Tensor, mask:Tensor = None) -> Tensor:
    """Apply self-attention to ``x`` of shape (batch, tokens, emb_size).

    Args:
      x: Input token sequence.
      mask: Optional boolean tensor broadcastable to
        (batch, heads, queries, keys). Positions where it is ``False`` are
        excluded from attention.

    Returns:
      Tensor with the same shape as ``x``.
    """
    # Split the fused projection into per-head queries, keys and values
    qkv = rearrange(self.qkv(x), 'b n (h d qkv) -> (qkv) b h n d',
                    h=self.num_heads, qkv=3)
    queries, keys, values = qkv[0],qkv[1],qkv[2]

    # Scaled dot product over the per-head feature axis
    scaling = (self.emb_size // self.num_heads) ** 0.5
    energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) / scaling
    if mask is not None:
      fill_value = torch.finfo(energy.dtype).min
      # masked_fill is out-of-place, so the result must be kept
      energy = energy.masked_fill(~mask, fill_value)

    att = F.softmax(energy, dim=-1)
    att = self.att_drop(att)

    # Weighted sum of the values over the key axis
    out = torch.einsum('bhal, bhlv -> bhav',att, values)
    out = rearrange(out, 'b h n d -> b n (h d)')
    out = self.projection(out)
    return out

Residual Block

In [12]:
class ResidualAdd(nn.Module):
  """Wrap a module with a skip connection: ``forward(x) = fn(x) + x``.

  Args:
    fn: Callable module applied to the input; its output must be
      broadcast-compatible with the input so the two can be added.
  """
  def __init__(self,fn):
    super().__init__()
    self.fn = fn

  def forward(self, x, **kwargs):
    """Return ``self.fn(x, **kwargs) + x`` without modifying any tensor in place."""
    # An in-place ``+=`` on fn's output would mutate the caller's tensor when fn
    # returns its input unchanged, and breaks autograd when fn's last op saves
    # its output for the backward pass (e.g. tanh, sigmoid).
    return self.fn(x, **kwargs) + x

Feed Forward MLP

In [13]:
class FeedForwardBlock(nn.Sequential):
  """Two-layer MLP: Linear -> GELU -> Dropout -> Linear.

  The hidden width is ``expansion * emb_size``; input and output widths are
  both ``emb_size``. ``drop_p`` is the dropout probability after the GELU.
  """
  def __init__(self, emb_size:int, expansion:int=4, drop_p:float=0.):
    hidden_size = expansion*emb_size
    layers = [
        nn.Linear(emb_size, hidden_size),
        nn.GELU(),
        nn.Dropout(drop_p),
        nn.Linear(hidden_size, emb_size),
    ]
    super().__init__(*layers)

Transformer Encoder Block

In [14]:
class TransformerEncoderBlock(nn.Sequential):
  """One pre-norm encoder layer made of two residual sub-blocks.

  The first sub-block is LayerNorm -> MultiHeadAttention -> Dropout, the
  second is LayerNorm -> FeedForwardBlock; each is wrapped in ``ResidualAdd``.

  Args:
    emb_size: Token embedding dimension.
    drop_p: Dropout probability applied after the attention module.
    forward_expansion: Hidden-width multiplier passed to ``FeedForwardBlock``.
    forward_drop_p: Dropout probability passed to ``FeedForwardBlock``.
    **kwargs: Extra keyword arguments forwarded to ``MultiHeadAttention``.
  """
  def __init__(self,
               emb_size:int=768,
               drop_p:float=0,
               forward_expansion:int=4,
               forward_drop_p:float=0.,
               **kwargs):
    attention_branch = nn.Sequential(
        nn.LayerNorm(emb_size),
        MultiHeadAttention(emb_size, **kwargs),
        nn.Dropout(drop_p),
    )
    mlp_branch = nn.Sequential(
        nn.LayerNorm(emb_size),
        FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
    )
    super().__init__(ResidualAdd(attention_branch), ResidualAdd(mlp_branch))

Building Block

In [15]:
class TransformerEncoder(nn.Sequential):
  """Stack of ``depth`` TransformerEncoderBlock layers applied in sequence.

  All keyword arguments are forwarded unchanged to every block.
  """
  def __init__(self, depth:int=12, **kwargs):
    blocks = []
    for _ in range(depth):
      blocks.append(TransformerEncoderBlock(**kwargs))
    super().__init__(*blocks)

Head

In [16]:
class ClassificationHead(nn.Sequential):
  """Classification head: mean over the token axis, LayerNorm, then a Linear layer.

  Maps (batch, tokens, emb_size) to (batch, n_classes).
  """
  def __init__(self, emb_size:int=768, n_classes: int=10):
    token_pool = Reduce('b n e -> b e', reduction='mean')
    norm = nn.LayerNorm(emb_size)
    classifier = nn.Linear(emb_size, n_classes)
    super().__init__(token_pool, norm, classifier)

Summary

In [17]:
class ViT(nn.Sequential):
  """Vision Transformer: PatchEmbedding -> TransformerEncoder -> ClassificationHead.

  Args:
    in_channels: Number of channels of the input images.
    patch_size: Side length of each square patch.
    emb_size: Token embedding dimension used throughout the model.
    img_size: Side length of the (square) input images. It sizes the position
      embeddings, so it must match the real input resolution. The default is
      224 to match the 224x224 inputs used in this notebook; the previous
      default of 12 gave a single position row that was broadcast over all
      tokens, i.e. no per-token positional information.
    depth: Number of encoder blocks.
    n_classes: Number of output classes.
    **kwargs: Extra keyword arguments forwarded to ``TransformerEncoder``.

  Note:
    Checkpoints saved with the old default ``img_size=12`` hold a position
    table of a different shape; pass ``img_size=12`` explicitly to load them.
  """
  def __init__(self,
               in_channels: int=3,
               patch_size: int=16,
               emb_size: int=768,
               img_size:int=224,
               depth: int=12,
               n_classes: int=10,
               **kwargs):
    super().__init__(
        PatchEmbedding(in_channels,patch_size, emb_size, img_size),
        TransformerEncoder(depth, emb_size=emb_size, **kwargs),
        ClassificationHead(emb_size, n_classes)
    )

# Print a layer-by-layer summary of a default ViT for a 3x224x224 input, on the CPU.
summary(ViT(), (3,224,224), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 14, 14]         590,592
         Rearrange-2             [-1, 196, 768]               0
    PatchEmbedding-3             [-1, 197, 768]               0
         LayerNorm-4             [-1, 197, 768]           1,536
            Linear-5            [-1, 197, 2304]       1,771,776
           Dropout-6          [-1, 8, 197, 197]               0
            Linear-7             [-1, 197, 768]         590,592
MultiHeadAttention-8             [-1, 197, 768]               0
           Dropout-9             [-1, 197, 768]               0
      ResidualAdd-10             [-1, 197, 768]               0
        LayerNorm-11             [-1, 197, 768]           1,536
           Linear-12            [-1, 197, 3072]       2,362,368
             GELU-13            [-1, 197, 3072]               0
          Dropout-14            [-1, 19

In [18]:
# Define the ViT model with a 6-class output layer
model = ViT(n_classes=6)

# Set up the loss function and the optimizer
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001)

In [19]:
# Train the model
num_epochs = 10
# Run a training pass followed by a validation pass for every epoch
for epoch in range(num_epochs):
    # Training phase. Switch to training mode at the START of each epoch so the
    # loop is also correct when a previous, interrupted run left the model in
    # eval mode. After the last epoch the model is left in eval mode.
    model.train()
    train_loss = 0.0
    num_train_batches = 0
    for i, (inputs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs,targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        num_train_batches += 1
        print(f'epoch: {epoch}, batch: {i}')
    # Mean training loss over the epoch; max(..., 1) avoids dividing by zero
    # when the loader yields no batches.
    average_train_loss = train_loss / max(num_train_batches, 1)

    # Validation phase at the end of each epoch
    model.eval()  # switch the model to evaluation mode
    val_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    with torch.no_grad():
        for val_inputs, val_targets in valid_loader:
            val_outputs = model(val_inputs)
            val_loss += criterion(val_outputs, val_targets).item()

            # Accuracy: the predicted class is the index of the largest logit
            predicted = val_outputs.argmax(dim=1)
            correct_predictions += (predicted == val_targets).sum().item()
            total_samples += val_targets.size(0)

    average_val_loss = val_loss / max(len(valid_loader), 1)
    accuracy = (correct_predictions / max(total_samples, 1)) * 100

    # Per-epoch report ("Loss" is the mean training loss of the epoch)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {average_train_loss:.4f}, Validation Loss: {average_val_loss:.4f}, Accuracy: {accuracy:.2f}%')

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'accuracy':accuracy,
    }

    # Destination path of this epoch's checkpoint
    save_path = '/content/drive/MyDrive/ViT_checkpoint_epoch{}.pt'.format(epoch)

    # Write the checkpoint dictionary to disk
    torch.save(checkpoint, save_path)

epoch: 0, batch: 0
epoch: 0, batch: 1
epoch: 0, batch: 2
epoch: 0, batch: 3
epoch: 0, batch: 4
epoch: 0, batch: 5
epoch: 0, batch: 6
epoch: 0, batch: 7
epoch: 0, batch: 8
epoch: 0, batch: 9
epoch: 0, batch: 10
epoch: 0, batch: 11
epoch: 0, batch: 12
epoch: 0, batch: 13
epoch: 0, batch: 14
epoch: 0, batch: 15
epoch: 0, batch: 16
epoch: 0, batch: 17
epoch: 0, batch: 18
epoch: 0, batch: 19
epoch: 0, batch: 20
epoch: 0, batch: 21
epoch: 0, batch: 22
epoch: 0, batch: 23
epoch: 0, batch: 24
epoch: 0, batch: 25
epoch: 0, batch: 26
epoch: 0, batch: 27
epoch: 0, batch: 28
epoch: 0, batch: 29
epoch: 0, batch: 30
epoch: 0, batch: 31
epoch: 0, batch: 32
epoch: 0, batch: 33
epoch: 0, batch: 34
epoch: 0, batch: 35
epoch: 0, batch: 36
epoch: 0, batch: 37
epoch: 0, batch: 38
epoch: 0, batch: 39
epoch: 0, batch: 40
epoch: 0, batch: 41
epoch: 0, batch: 42


KeyboardInterrupt: ignored

In [None]:
# Epoch index whose checkpoint file should be loaded
saved_epoch=0
checkpoint = torch.load('/content/drive/MyDrive/ViT_checkpoint_epoch{}.pt'.format(saved_epoch))
# Restore the model and optimizer state from the checkpoint
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#epoch = checkpoint['epoch']
# Validation accuracy (in percent) stored with the checkpoint
accuracy = checkpoint['accuracy']
print(accuracy)