# ViT assignment

colab의 경우, 런타임 유형을 GPU로 바꿔주세요.

# 0. Setting

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from einops import repeat
from einops.layers.torch import Rearrange
from torch import Tensor
import math
import time

# 1. Project input to patches

In [2]:
import torch
from torch import nn, Tensor
from einops.layers.torch import Rearrange  # 패치 차원 재배치를 위해 필요

class PatchProjection(nn.Module):
    def __init__(self, in_channels=3, patch_size=16, emb_size=768, img_size=224):
        super().__init__()
        self.patch_size = patch_size

        # (이미지 한 변 // 패치 크기)^2 = 총 패치 수
        self.num_patches = (img_size // patch_size) ** 2

        self.projection = nn.Sequential(
            # Conv2d: 입력 이미지를 patch 단위로 자르고, 각 patch를 emb_size 차원으로 매핑
            nn.Conv2d(
                in_channels,              # 입력 채널 수 (RGB면 3)
                emb_size,                 # 출력 채널 수 = 임베딩 차원
                kernel_size=patch_size,   # 패치 크기만큼 자름
                stride=patch_size         # 겹치지 않게 stride도 동일하게
            ),
            # Rearrange: (B, emb_size, H', W') → (B, num_patches, emb_size)
            Rearrange('b e h w -> b (h w) e')  # 패치 차원(h*w)을 flatten
        )

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Xavier 초기화는 일반적으로 Linear/Conv에 사용됨
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        # 입력 x: (B, C, H, W) → projection → (B, num_patches, emb_size)
        return self.projection(x)

# Test
if __name__ == "__main__":
    x = torch.randn(8, 3, 224, 224)  # 배치 크기 8, 3채널, 224x224 이미지
    patch_proj = PatchProjection()
    out = patch_proj(x)

    print(f'Input shape: {x.shape}')  # (8, 3, 224, 224)
    print(f'Patch embeddings shape: {out.shape}')  # (8, 196, 768) expected
    print(f'Number of patches: {patch_proj.num_patches}')  # 14 * 14 = 196

Input shape: torch.Size([8, 3, 224, 224])
Patch embeddings shape: torch.Size([8, 196, 768])
Number of patches: 196


# 2. Patches embedding

In [3]:
import torch
from torch import nn, Tensor
from einops.layers.torch import Rearrange

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, patch_size=16, emb_size=768, img_size=224):
        super().__init__()

        # 이미지 크기를 patch_size로 나눈 후 전체 패치 수 계산
        self.num_patches = (img_size // patch_size) ** 2  # 14x14 = 196 patches for 224x224

        # Patch projection: Conv2d로 패치 추출 후 Rearrange로 flatten
        self.projection = nn.Sequential(
            nn.Conv2d(
                in_channels, emb_size,
                kernel_size=patch_size, stride=patch_size  # 패치 단위로 분할
            ),
            Rearrange('b e h w -> b (h w) e')  # (B, emb_size, H, W) → (B, num_patches, emb_size)
        )

        # [CLS] 토큰 (learnable)과 positional encoding (num_patches + 1 for CLS)
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))  # (1, 1, D)
        self.positions = nn.Parameter(torch.randn(self.num_patches + 1, emb_size))  # (197, D)

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        # [CLS] 토큰과 포지셔널 임베딩 초기화 (truncated normal)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.positions, std=0.02)

    def forward(self, x: Tensor) -> Tensor:
        B = x.shape[0]  # 배치 크기

        # 1. 이미지 → 패치 임베딩 (B, 196, 768)
        x = self.projection(x)

        # 2. CLS 토큰을 배치 크기에 맞게 복제 (B, 1, 768)
        cls_tokens = self.cls_token.expand(B, -1, -1)

        # 3. CLS 토큰을 앞에 붙여서 (B, 197, 768)
        x = torch.cat((cls_tokens, x), dim=1)

        # 4. 포지셔널 임베딩 더하기 (broadcast됨)
        x = x + self.positions  # (B, 197, 768)

        return x  # shape: (B, 197, emb_size)

# Test
if __name__ == "__main__":
    x = torch.randn(8, 3, 224, 224)
    patch_emb = PatchEmbedding()
    out = patch_emb(x)
    print(f'Input shape: {x.shape}')    # (8, 3, 224, 224)
    print(f'Output shape: {out.shape}')  # (8, 197, 768)
    print(f'Expected: (8, 197, 768)')    # 196 patches + 1 CLS token


Input shape: torch.Size([8, 3, 224, 224])
Output shape: torch.Size([8, 197, 768])
Expected: (8, 197, 768)


# 3. Multi Head Attention (MHA)

In [4]:
import torch
from torch import nn, Tensor
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size=768, num_heads=12, dropout=0.1):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.head_dim = emb_size // num_heads  # 각 head당 차원 수
        self.scale = self.head_dim ** -0.5     # Scaled dot-product 시 나눌 값

        assert emb_size % num_heads == 0, "Embedding size must be divisible by number of heads"

        # Q, K, V를 한 번에 생성하는 Linear 레이어 (bias=False)
        self.qkv = nn.Linear(emb_size, emb_size * 3, bias=False)

        # Output projection (attention 결과를 통합 후 projection)
        self.proj = nn.Linear(emb_size, emb_size)

        # Dropout 레이어
        self.dropout = nn.Dropout(dropout)

        self._init_weights()

    def _init_weights(self):
        # QKV의 가중치 초기화 (xavier)
        nn.init.xavier_uniform_(self.qkv.weight)
        # 출력 projection 가중치 초기화
        nn.init.xavier_uniform_(self.proj.weight)
        nn.init.constant_(self.proj.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        B, N, C = x.shape  # B: batch, N: tokens, C: embedding dim

        # 1. Q, K, V 생성: (B, N, 3 * C)
        qkv = self.qkv(x)

        # 2. Reshape: (B, N, 3, num_heads, head_dim)
        qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim)

        # 3. Permute: (3, B, num_heads, N, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)

        # 4. q, k, v 각각 분리 (각 shape: B, num_heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # 5. Attention score 계산: (B, num_heads, N, N)
        attn_scores = (q @ k.transpose(-2, -1)) * self.scale  # QK^T / sqrt(d)
        attn_probs = F.softmax(attn_scores, dim=-1)           # Softmax over last dim
        attn_probs = self.dropout(attn_probs)                 # Dropout 적용

        # 6. Attention 결과 계산: (B, num_heads, N, head_dim)
        attn_out = attn_probs @ v

        # 7. head 합치기: (B, N, num_heads * head_dim = C)
        attn_out = attn_out.transpose(1, 2).reshape(B, N, C)

        # 8. 최종 projection 및 dropout
        out = self.proj(attn_out)
        out = self.dropout(out)

        return out

# Test
if __name__ == "__main__":
    x = torch.randn(8, 197, 768)  # (batch, tokens, emb_size)
    mha = MultiHeadAttention()
    out = mha(x)
    print(f'Input shape: {x.shape}')      # (8, 197, 768)
    print(f'Output shape: {out.shape}')    # (8, 197, 768)
    print(f'Parameters: {sum(p.numel() for p in mha.parameters()):,}')  # 파라미터 수


Input shape: torch.Size([8, 197, 768])
Output shape: torch.Size([8, 197, 768])
Parameters: 2,360,064


# 4. Transformer Encoder Block

In [5]:
import torch
from torch import nn, Tensor
import torch.nn.functional as F


class MLP(nn.Module):
    def __init__(self, emb_size=768, mlp_ratio=4, dropout=0.1):
        super().__init__()
        hidden_size = int(emb_size * mlp_ratio)

        self.net = nn.Sequential(
            nn.Linear(emb_size, hidden_size),  # 확장 (768 → 3072)
            nn.GELU(),                         # 비선형 활성화
            nn.Dropout(dropout),              # regularization
            nn.Linear(hidden_size, emb_size),  # 축소 (3072 → 768)
            nn.Dropout(dropout)               # 다시 dropout
        )

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        return self.net(x)


class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size=768, num_heads=12, dropout=0.1):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.head_dim = emb_size // num_heads
        self.scale = self.head_dim ** -0.5  # Scaled Dot-Product에서 나눠줄 값

        # Q, K, V를 한 번에 생성 (B, N, 3 * C)
        self.qkv = nn.Linear(emb_size, emb_size * 3, bias=False)

        # Attention 결과 projection
        self.proj = nn.Linear(emb_size, emb_size)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.qkv.weight)
        nn.init.xavier_uniform_(self.proj.weight)
        nn.init.constant_(self.proj.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        B, N, C = x.shape  # Batch, Token 수, Embedding dim

        # qkv: (B, N, 3 * C) → reshape → (B, N, 3, num_heads, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)

        # permute: (3, B, num_heads, N, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)

        q, k, v = qkv[0], qkv[1], qkv[2]  # 각각 shape: (B, num_heads, N, head_dim)

        # Attention score 계산: (B, num_heads, N, N)
        attn_score = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn_score, dim=-1)
        attn = self.dropout(attn)

        # attention-weighted value 합산: (B, num_heads, N, head_dim)
        out = attn @ v

        # head 병합: (B, N, C)
        out = out.transpose(1, 2).reshape(B, N, C)

        # 최종 projection + dropout
        out = self.proj(out)
        out = self.dropout(out)

        return out


class TransformerEncoderBlock(nn.Module):
    def __init__(self, emb_size=768, num_heads=12, mlp_ratio=4, dropout=0.1):
        super().__init__()

        # Pre-LayerNorm → Attention → Residual
        self.norm1 = nn.LayerNorm(emb_size)
        self.attention = MultiHeadAttention(emb_size, num_heads, dropout)

        # Pre-LayerNorm → MLP → Residual
        self.norm2 = nn.LayerNorm(emb_size)
        self.mlp = MLP(emb_size, mlp_ratio, dropout)

    def forward(self, x: Tensor) -> Tensor:
        # Residual connection with attention block
        x = x + self.attention(self.norm1(x))  # LayerNorm → Attention → Add

        # Residual connection with MLP block
        x = x + self.mlp(self.norm2(x))        # LayerNorm → MLP → Add

        return x


# Test
if __name__ == "__main__":
    x = torch.randn(8, 197, 768)  # 8개의 이미지, 197개 토큰 (196 patch + 1 CLS)
    block = TransformerEncoderBlock()
    out = block(x)
    print(f'Input shape: {x.shape}')         # (8, 197, 768)
    print(f'Output shape: {out.shape}')      # (8, 197, 768)
    print(f'Parameters: {sum(p.numel() for p in block.parameters()):,}')  # 총 파라미터 수


Input shape: torch.Size([8, 197, 768])
Output shape: torch.Size([8, 197, 768])
Parameters: 7,085,568


# 5. Complete ViT

In [6]:
import torch
from torch import nn, Tensor

# (기존 PatchEmbedding, TransformerEncoderBlock이 이미 정의돼 있다고 가정)

class VisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        num_classes=1000,
        emb_size=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        dropout=0.1,
        drop_path=0.0  # 사용 안함 (placeholder)
    ):
        super().__init__()

        # Stage 1-2: Patch embedding + CLS token + positional encoding 포함
        self.patch_embed = PatchEmbedding(in_channels, patch_size, emb_size, img_size)

        # Stage 3-4: Transformer encoder blocks (N개 반복)
        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(emb_size, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        # Stage 5: Classification head (LayerNorm + Linear)
        self.norm = nn.LayerNorm(emb_size)
        self.head = nn.Linear(emb_size, num_classes)

        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.head.weight, std=0.02)
        nn.init.constant_(self.head.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        # 1. 패치 임베딩 + CLS 토큰 + 포지셔널 인코딩 (B, 197, D)
        x = self.patch_embed(x)

        # 2. Transformer 블록 N개 순차 적용
        for blk in self.blocks:
            x = blk(x)

        # 3. LayerNorm 적용 후, CLS 토큰 (첫 번째 토큰) 추출
        x = self.norm(x)
        cls_token = x[:, 0]  # (B, D)

        # 4. Linear head로 분류 (B, num_classes)
        x = self.head(cls_token)

        return x


# 6. ViT for CIFAR-10

위의 코드를 완성했다면, 아래 코드를 실행하여 전체 모델을 테스트할 수 있습니다.

In [7]:
class ViTCIFAR10(nn.Module):
    def __init__(
        self,
        img_size=32,
        patch_size=4,
        in_channels=3,
        num_classes=10,
        emb_size=256,
        depth=6,
        num_heads=8,
        mlp_ratio=4,
        dropout=0.1
    ):
        super().__init__()

        self.patch_embed = PatchEmbedding(in_channels, patch_size, emb_size, img_size)

        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(emb_size, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(emb_size)
        self.head = nn.Linear(emb_size, num_classes)

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
            elif isinstance(m, nn.Conv2d):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        if hasattr(self.patch_embed, 'cls_token'):
            nn.init.trunc_normal_(self.patch_embed.cls_token, std=0.02)
        if hasattr(self.patch_embed, 'positions'):
            nn.init.trunc_normal_(self.patch_embed.positions, std=0.02)

    def forward(self, x: Tensor) -> Tensor:
        x = self.patch_embed(x)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)
        cls_token = x[:, 0]
        x = self.head(cls_token)

        return x


def train_one_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        running_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

        if batch_idx % 100 == 0:
            print(f'Batch {batch_idx}: Loss {loss.item():.4f}, Acc {100.*correct/total:.2f}%')

    return running_loss / len(dataloader), 100. * correct / total


def test(model, dataloader, criterion, device):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()

            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

    test_loss /= len(dataloader)
    accuracy = 100. * correct / total

    return test_loss, accuracy


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=4, pin_memory=True)

    model = ViTCIFAR10(
        img_size=32,
        patch_size=4,
        num_classes=10,
        emb_size=256,
        depth=6,
        num_heads=4,
        dropout=0.1
    ).to(device)

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.AdamW(
        model.parameters(),
        lr=5e-4,
        weight_decay=0.03,
        betas=(0.9, 0.999)
    )

    num_epochs = 30
    warmup_epochs = 1
    total_steps = len(train_loader) * num_epochs
    warmup_steps = len(train_loader) * warmup_epochs
    scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)

    best_acc = 0
    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')

        start_time = time.time()
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        test_loss, test_acc = test(model, test_loader, criterion, device)

        for _ in range(len(train_loader)):
            scheduler.step()

        epoch_time = time.time() - start_time
        current_lr = optimizer.param_groups[0]['lr']

        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')
        print(f'LR: {current_lr:.6f}, Epoch time: {epoch_time:.2f}s')

        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), 'vit_cifar10_best.pth')

        if test_acc > 90.0:
            print(f"Reached target accuracy!")
            break

    print(f'\nBest Test Accuracy: {best_acc:.2f}%')


if __name__ == "__main__":
    model = ViTCIFAR10(emb_size=256, depth=6, num_heads=4)
    x = torch.randn(4, 3, 32, 32)
    out = model(x)

    print(f'Input shape: {x.shape}')
    print(f'Output shape: {out.shape}')
    print(f'Parameters: {sum(p.numel() for p in model.parameters()):,}')

    main()

Input shape: torch.Size([4, 3, 32, 32])
Output shape: torch.Size([4, 10])
Parameters: 4,766,474


100%|██████████| 170M/170M [00:13<00:00, 13.1MB/s]



Epoch 1/30
Batch 0: Loss 2.3584, Acc 7.81%
Batch 100: Loss 2.3397, Acc 9.58%
Train Loss: 2.3465, Train Acc: 9.56%
Test Loss: 2.3452, Test Acc: 10.14%
LR: 0.000500, Epoch time: 40.17s

Epoch 2/30
Batch 0: Loss 2.3524, Acc 7.81%
Batch 100: Loss 1.8780, Acc 27.10%
Train Loss: 1.9420, Train Acc: 31.24%
Test Loss: 1.7955, Test Acc: 39.13%
LR: 0.000499, Epoch time: 40.79s

Epoch 3/30
Batch 0: Loss 1.7942, Acc 37.89%
Batch 100: Loss 1.6713, Acc 41.00%
Train Loss: 1.7243, Train Acc: 42.64%
Test Loss: 1.6558, Test Acc: 45.61%
LR: 0.000494, Epoch time: 41.27s

Epoch 4/30
Batch 0: Loss 1.7299, Acc 40.62%
Batch 100: Loss 1.5954, Acc 47.18%
Train Loss: 1.6087, Train Acc: 48.18%
Test Loss: 1.5648, Test Acc: 49.45%
LR: 0.000487, Epoch time: 43.83s

Epoch 5/30
Batch 0: Loss 1.5420, Acc 51.56%
Batch 100: Loss 1.5252, Acc 49.74%
Train Loss: 1.5568, Train Acc: 50.68%
Test Loss: 1.5206, Test Acc: 52.54%
LR: 0.000477, Epoch time: 42.73s

Epoch 6/30
Batch 0: Loss 1.5522, Acc 49.61%
Batch 100: Loss 1.5252, 

ViT는 일반적으로 대규모 데이터셋에서 사전 학습된(pretrained) 모델을 활용하는 경우가 많기 때문에, 하이퍼파라미터를 조정하거나 학습 epoch을 늘리면 성능이 개선될 수는 있지만, 소규모 데이터셋에서 처음부터 학습한 ViT의 성능이 낮은 것은 구조적 한계에 가깝습니다.