## 1. Random Seed 설정

In [None]:
import random
import torch
import numpy as np

def setup_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus all CUDA devices), then configures cuDNN for repeatable runs.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may select different kernels from run to run, which
    # defeats the deterministic flag above, so it is switched off as well.
    torch.backends.cudnn.benchmark = False

## 2. Masked Auto Encoder 모델

In [None]:
import torch
import timm
import numpy as np

from einops import repeat, rearrange
from einops.layers.torch import Rearrange

from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import Block

In [None]:
def random_indexes(size : int):
    """Draw a random permutation of ``range(size)`` together with its inverse.

    The permutation is drawn from NumPy's global random state.

    Args:
        size: Number of positions to permute.

    Returns:
        A ``(forward, backward)`` pair of 1-D NumPy arrays where ``forward``
        is the shuffled order and ``backward`` is its argsort, so that
        ``forward[backward]`` is ``arange(size)``.
    """
    permutation = np.arange(size)
    np.random.shuffle(permutation)  # shuffles in place
    inverse = permutation.argsort()
    return permutation, inverse

def take_indexes(sequences, indexes):
    """Reorder ``sequences`` along dim 0 with a separate index per batch column.

    Args:
        sequences: Tensor whose last dimension is the channel dimension; the
            gather runs along dim 0.
        indexes: Integer tensor laid out as ``(t, b)``; it is repeated across
            the channel dimension of ``sequences`` before gathering.

    Returns:
        The gathered tensor, shaped like the repeated index tensor.
    """
    channels = sequences.shape[-1]
    gather_index = repeat(indexes, 't b -> t b c', c=channels)
    return torch.gather(sequences, 0, gather_index)


### 2-1. PatchShuffle

In [None]:
class PatchShuffle(torch.nn.Module):
    """Randomly shuffle patches per sample and keep only the leading part.

    Args:
        ratio: Fraction of patches to drop; ``int(T * (1 - ratio))`` patches
            are kept from each sample.
    """

    def __init__(self, ratio) -> None:
        super().__init__()
        self.ratio = ratio

    def forward(self, patches : torch.Tensor):
        """Shuffle and truncate ``patches``.

        Args:
            patches: Tensor unpacked as ``(T, B, C)``.

        Returns:
            ``(kept_patches, forward_indexes, backward_indexes)`` where the
            index tensors are ``(T, B)`` long tensors on ``patches.device``
            holding each sample's permutation and its inverse.
        """
        num_patches, batch, _ = patches.shape
        num_kept = int(num_patches * (1 - self.ratio))

        # One independent permutation (and its inverse) per sample.
        permutations = [random_indexes(num_patches) for _ in range(batch)]

        def _stack(position):
            stacked = np.stack([pair[position] for pair in permutations], axis=-1)
            return torch.as_tensor(stacked, dtype=torch.long).to(patches.device)

        forward_indexes = _stack(0)
        backward_indexes = _stack(1)

        shuffled = take_indexes(patches, forward_indexes)
        return shuffled[:num_kept], forward_indexes, backward_indexes

In [None]:
# Smoke test for PatchShuffle: 16 patches, batch of 2, 10 channels.
# With ratio 0.75 the module keeps int(16 * 0.25) = 4 patches per sample.
shuffle = PatchShuffle(0.75)
a = torch.rand(16, 2, 10)
b, forward_indexes, backward_indexes = shuffle(a)
print(b.shape)

### 2-2. Encoder & Decoder

In [None]:
class MAE_Encoder(torch.nn.Module):
    """ViT encoder of the masked auto-encoder; it only sees unmasked patches.

    Args:
        image_size: Side length of the (square) input image in pixels.
        patch_size: Side length of one patch; also the stride of the
            patchifying convolution.
        emb_dim: Embedding width of every token.
        num_layer: Number of transformer blocks.
        num_head: Number of attention heads per block.
        mask_ratio: Fraction of patches dropped by ``PatchShuffle``.
    """

    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=12,
                 num_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        num_patches = (image_size // patch_size) ** 2

        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros(num_patches, 1, emb_dim))
        self.shuffle = PatchShuffle(mask_ratio)

        # Non-overlapping convolution: one output position per patch.
        self.patchify = torch.nn.Conv2d(3, emb_dim, patch_size, patch_size)

        blocks = [Block(emb_dim, num_head) for _ in range(num_layer)]
        self.transformer = torch.nn.Sequential(*blocks)

        self.layer_norm = torch.nn.LayerNorm(emb_dim)

        self.init_weight()

    def init_weight(self):
        """Initialise the class token and positional embedding (std 0.02)."""
        for parameter in (self.cls_token, self.pos_embedding):
            trunc_normal_(parameter, std=.02)

    def forward(self, img):
        """Encode the visible patches of ``img``.

        Returns:
            ``(features, backward_indexes)``: the encoded tokens in
            token-first layout (class token first) and the inverse
            permutation produced by the patch shuffle.
        """
        tokens = rearrange(self.patchify(img), 'b c h w -> (h w) b c')
        tokens = tokens + self.pos_embedding

        visible, _, backward_indexes = self.shuffle(tokens)

        cls = self.cls_token.expand(-1, visible.shape[1], -1)
        sequence = rearrange(torch.cat([cls, visible], dim=0), 't b c -> b t c')
        encoded = self.layer_norm(self.transformer(sequence))

        return rearrange(encoded, 'b t c -> t b c'), backward_indexes

In [None]:
class MAE_Decoder(torch.nn.Module):
    """Decoder that reconstructs every patch from encoder output plus mask tokens.

    Args:
        image_size: Side length of the (square) image to reconstruct.
        patch_size: Side length of one patch.
        emb_dim: Embedding width of every token.
        num_layer: Number of transformer blocks.
        num_head: Number of attention heads per block.
    """

    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=4,
                 num_head=3,
                 ) -> None:
        super().__init__()

        # One position per patch plus one for the leading class token.
        num_tokens = (image_size // patch_size) ** 2 + 1

        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros(num_tokens, 1, emb_dim))

        blocks = [Block(emb_dim, num_head) for _ in range(num_layer)]
        self.transformer = torch.nn.Sequential(*blocks)

        # Each token is projected to the 3 * p * p pixel values of its patch.
        self.head = torch.nn.Linear(emb_dim, 3 * patch_size ** 2)
        self.patch2img = Rearrange('(h w) b (c p1 p2) -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=image_size//patch_size)

        self.init_weight()

    def init_weight(self):
        """Initialise the mask token and positional embedding (std 0.02)."""
        for parameter in (self.mask_token, self.pos_embedding):
            trunc_normal_(parameter, std=.02)

    def forward(self, features, backward_indexes):
        """Reconstruct the image and report which pixels were masked.

        Args:
            features: Encoder output, token-first, class token at index 0.
            backward_indexes: Inverse patch permutation from the encoder.

        Returns:
            ``(img, mask)``: the predicted image and a same-shaped tensor
            that is 1 on pixels of masked patches and 0 elsewhere.
        """
        num_encoded = features.shape[0]

        # Shift patch indexes by one and prepend index 0 so the class token
        # stays in front after un-shuffling.
        cls_index = torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes)
        backward_indexes = torch.cat([cls_index, backward_indexes + 1], dim=0)

        # Pad with mask tokens up to the full sequence length, then restore
        # the original patch order.
        num_masked = backward_indexes.shape[0] - features.shape[0]
        mask_tokens = self.mask_token.expand(num_masked, features.shape[1], -1)
        tokens = torch.cat([features, mask_tokens], dim=0)
        tokens = take_indexes(tokens, backward_indexes) + self.pos_embedding

        tokens = rearrange(tokens, 't b c -> b t c')
        tokens = self.transformer(tokens)
        tokens = rearrange(tokens, 'b t c -> t b c')[1:]  # drop the class token

        patches = self.head(tokens)

        # In shuffled order the masked patches are the trailing ones; mark
        # them, then un-shuffle the mask like the patches.
        mask = torch.zeros_like(patches)
        mask[num_encoded - 1:] = 1
        mask = take_indexes(mask, backward_indexes[1:] - 1)

        return self.patch2img(patches), self.patch2img(mask)

In [None]:
# Smoke test: run a random batch through the encoder and decoder and compute
# the masked reconstruction loss once.
img = torch.rand(2, 3, 32, 32)
encoder, decoder = MAE_Encoder(), MAE_Decoder()

features, backward_indexes = encoder(img)
# Report the encoder output; the encoder does not return forward indexes, so
# printing `forward_indexes` here would show a stale value from another cell.
print(features.shape)

predicted_img, mask = decoder(features, backward_indexes)
print(predicted_img.shape)

# Squared error on masked pixels only, rescaled by the default mask ratio.
loss = torch.mean((predicted_img - img) ** 2 * mask / 0.75)
print(loss)

### 2-3. MAE Based Vision Transformer

In [None]:
class MAE_ViT(torch.nn.Module):
    """Full masked auto-encoder: an ``MAE_Encoder`` followed by an ``MAE_Decoder``.

    Args:
        image_size: Side length of the (square) input image.
        patch_size: Side length of one patch.
        emb_dim: Token embedding width shared by encoder and decoder.
        encoder_layer: Number of encoder transformer blocks.
        encoder_head: Number of attention heads in the encoder.
        decoder_layer: Number of decoder transformer blocks.
        decoder_head: Number of attention heads in the decoder.
        mask_ratio: Fraction of patches hidden from the encoder.
    """

    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 encoder_layer=12,
                 encoder_head=3,
                 decoder_layer=4,
                 decoder_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        self.encoder = MAE_Encoder(
            image_size=image_size,
            patch_size=patch_size,
            emb_dim=emb_dim,
            num_layer=encoder_layer,
            num_head=encoder_head,
            mask_ratio=mask_ratio,
        )
        self.decoder = MAE_Decoder(
            image_size=image_size,
            patch_size=patch_size,
            emb_dim=emb_dim,
            num_layer=decoder_layer,
            num_head=decoder_head,
        )

    def forward(self, img):
        """Return ``(predicted_img, mask)`` for a batch of images."""
        encoded, backward_indexes = self.encoder(img)
        return self.decoder(encoded, backward_indexes)

### 2-4. Classification을 위한 VisionTransformer

In [None]:
class ViT_Classifier(torch.nn.Module):
    """Classifier that reuses an ``MAE_Encoder``'s weights without masking.

    The encoder's parameters and sub-modules are referenced, not copied, so
    training this classifier also updates the encoder passed in.

    Args:
        encoder: Encoder whose token, embedding and transformer modules are reused.
        num_classes: Number of output logits.
    """

    def __init__(self, encoder : MAE_Encoder, num_classes=10) -> None:
        super().__init__()
        self.cls_token = encoder.cls_token
        self.pos_embedding = encoder.pos_embedding
        self.patchify = encoder.patchify
        self.transformer = encoder.transformer
        self.layer_norm = encoder.layer_norm

        emb_dim = self.pos_embedding.shape[-1]
        self.head = torch.nn.Linear(emb_dim, num_classes)

    def forward(self, img):
        """Return class logits computed from the class token of ``img``."""
        tokens = rearrange(self.patchify(img), 'b c h w -> (h w) b c')
        tokens = tokens + self.pos_embedding

        cls = self.cls_token.expand(-1, tokens.shape[1], -1)
        sequence = rearrange(torch.cat([cls, tokens], dim=0), 't b c -> b t c')

        encoded = self.layer_norm(self.transformer(sequence))
        encoded = rearrange(encoded, 'b t c -> t b c')

        # Index 0 along the token axis is the class token.
        return self.head(encoded[0])

## 3. Pre-training MAE with CIFAR10Dataset (Self-Supervised Learning)

### 3-1. Library & HyperParameter

In [None]:
import os
import math
import torch
import torchvision

from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import ToTensor, Compose, Normalize
from tqdm.notebook import tqdm
from einops import rearrange
from torchvision import datasets, transforms

In [None]:
# Hyperparameters for MAE pre-training on CIFAR-10. They are module-level
# globals that pretrain_mae reads when it is called.
SEED = 42
BATCH_SIZE = 4096
MAX_DEVICE_BATCH_SIZE = 512
LR = 1.5e-4
WEIGHT_DECAY = 0.05
MASK_RATIO = 0.75
EPOCH = 2000
WARMUP = 200

device = 'cuda' if torch.cuda.is_available() else 'cpu'
setup_seed(SEED)

model_path = './pth/mae_pretrained_cifar.pt'

# Gradient accumulation: the loader yields `batch_size` samples per step and
# the optimizer is stepped once every `steps_per_update` steps.
batch_size = min(MAX_DEVICE_BATCH_SIZE, BATCH_SIZE)
assert BATCH_SIZE % batch_size == 0
steps_per_update = BATCH_SIZE // batch_size

### 3-2. CIFAR10Dataset

In [None]:
# CIFAR-10 datasets; pixels are normalized with mean 0.5 and std 0.5.
train_dataset = torchvision.datasets.CIFAR10('../../datasets/CIFAR10', train=True, download=True, transform=Compose([ToTensor(), Normalize(0.5, 0.5)]))
val_dataset = torchvision.datasets.CIFAR10('../../datasets/CIFAR10', train=False, download=True, transform=Compose([ToTensor(), Normalize(0.5, 0.5)]))

# Data loader and TensorBoard writer setup.
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True, num_workers=4)
writer = SummaryWriter(os.path.join('logs', 'cifar10', 'mae-pretrain'))

### 3-3. Network 설정

In [None]:
# MAE with default architecture settings; only the mask ratio is overridden.
model = MAE_ViT(mask_ratio=MASK_RATIO).to(device)

### 3-4. Pre-Training

In [None]:
def pretrain_mae(model, dataloader, val_dataset, writer, model_path):
    """Pre-train an MAE model with the masked-patch reconstruction loss.

    Hyperparameters are read at call time from the module-level globals
    ``LR``, ``BATCH_SIZE``, ``WEIGHT_DECAY``, ``EPOCH``, ``WARMUP``,
    ``MASK_RATIO``, ``steps_per_update`` and ``device``.

    Args:
        model: Module returning ``(predicted_img, mask)`` for an image batch.
        dataloader: Training loader yielding ``(image, label)`` batches;
            labels are ignored.
        val_dataset: Indexable dataset; the images of its first 16 items are
            reconstructed and logged after every epoch.
        writer: TensorBoard ``SummaryWriter`` receiving the loss and images.
        model_path: File the whole model is saved to after every epoch.
    """
    import os

    # Create the checkpoint directory up front so the first save cannot fail
    # on a missing folder.
    save_dir = os.path.dirname(model_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    optim = torch.optim.AdamW(model.parameters(), lr=LR * BATCH_SIZE / 256, betas=(0.9, 0.95), weight_decay=WEIGHT_DECAY)
    # Linear warm-up capped by a half-cosine decay, evaluated once per epoch.
    lr_func = lambda epoch: min((epoch + 1) / (WARMUP + 1e-8), 0.5 * (math.cos(epoch / EPOCH * math.pi) + 1))
    # `verbose` is not passed: the argument is deprecated in recent PyTorch.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lr_func)

    step_count = 0
    optim.zero_grad()

    for e in range(EPOCH):
        model.train()
        losses = []
        # Passing the loader itself (not iter(loader)) lets tqdm show a total.
        for img, _ in tqdm(dataloader):
            step_count += 1

            img = img.to(device)
            predicted_img, mask = model(img)

            # Squared error on masked pixels only, rescaled by the mask ratio.
            loss = torch.mean((predicted_img - img) ** 2 * mask) / MASK_RATIO
            loss.backward()

            # Gradient accumulation: step once every `steps_per_update` batches.
            if step_count % steps_per_update == 0:
                optim.step()
                optim.zero_grad()

            losses.append(loss.item())
        lr_scheduler.step()

        avg_loss = sum(losses) / len(losses)
        writer.add_scalar('mae_loss', avg_loss, global_step=e)
        print(f'In epoch {e}, average traning loss is {avg_loss}.')

        # Visualize the first 16 predicted images of the validation dataset.
        model.eval()
        with torch.no_grad():
            val_img = torch.stack([val_dataset[i][0] for i in range(16)])
            val_img = val_img.to(device)

            predicted_val_img, mask = model(val_img)
            # Keep the original pixels wherever the patch was visible.
            predicted_val_img = predicted_val_img * mask + val_img * (1 - mask)

            # Grid of (masked input, reconstruction, original) triplets.
            img = torch.cat([val_img * (1 - mask), predicted_val_img, val_img], dim=0)
            img = rearrange(img, '(v h1 w1) c h w -> c (h1 h) (w1 v w)', w1=2, v=3)

            writer.add_image('mae_image', (img + 1) / 2, global_step=e)

        # Save the whole model; the fine-tuning cells load it and use `.encoder`.
        torch.save(model, model_path)

In [None]:
# Run MAE pre-training on CIFAR-10 with the objects built in the cells above.
pretrain_mae(model, dataloader, val_dataset, writer, model_path)

## 4. Transfer Learning MAE with CIFAR10Dataset (Supervised Learning)

### 4-1. Library & HyperParameter

In [None]:
import os
import math
import torch
import torchvision

from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
from torchvision import datasets, transforms

In [None]:
# Hyperparameters for supervised fine-tuning of the MAE encoder on CIFAR-10.
# These overwrite the pre-training globals of the same names.
SEED = 42
BATCH_SIZE = 128
MAX_DEVICE_BATCH_SIZE = 256
LR = 1e-3
WEIGHT_DECAY = 0.05
EPOCH = 100
WARMUP = 5

device = 'cuda' if torch.cuda.is_available() else 'cpu'
setup_seed(SEED)

# model_path: where the fine-tuned classifier is saved.
# pretrain_path: checkpoint written by the pre-training section.
model_path = './pth/mae_classifer_cifar.pt'
pretrain_path = './pth/mae_pretrained_cifar.pt'

# With these values batch_size equals BATCH_SIZE, so steps_per_update is 1.
batch_size = min(MAX_DEVICE_BATCH_SIZE, BATCH_SIZE)
assert BATCH_SIZE % batch_size == 0
steps_per_update = BATCH_SIZE // batch_size

### 4-2. Dataset

In [None]:
# Transformation pipelines: training adds a random horizontal flip, testing
# is deterministic.
transform_train = transforms.Compose([
        transforms.Resize((32, 32)),  # resize images to 32x32
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # per-channel normalization
    ])

transform_test = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # per-channel normalization
])

# Dataset setup.
train_dataset = torchvision.datasets.CIFAR10('../../datasets/CIFAR10', train=True, download=True, transform=transform_train)
val_dataset = torchvision.datasets.CIFAR10('../../datasets/CIFAR10', train=False, download=True, transform=transform_test)

# Data loader setup.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
# No SummaryWriter is opened here: the network-setup cell creates the
# fine-tuning writer, and a writer on the 'mae-pretrain' directory would only
# add a stray event file to the pre-training logs.

### 4-3. Network 설정

In [None]:
# The checkpoint is a fully pickled model object, not a state dict, so it
# needs weights_only=False: since PyTorch 2.6 torch.load defaults to
# weights_only=True and refuses to unpickle arbitrary classes.
# NOTE(review): unpickling executes code; only load checkpoints you created.
model = torch.load(pretrain_path, map_location='cpu', weights_only=False)
model = model.to(device)
writer = SummaryWriter(os.path.join('logs', 'cifar10', 'pretrain-cls'))

# Wrap the pre-trained encoder with a fresh 10-way classification head.
model = ViT_Classifier(model.encoder, num_classes=10).to(device)

### 4-4. Train

In [None]:
def train_mae(model, train_dataloader, val_dataloader, writer, model_path):
    """Fine-tune a classifier and keep the checkpoint with the best validation accuracy.

    Hyperparameters are read at call time from the module-level globals
    ``LR``, ``BATCH_SIZE``, ``WEIGHT_DECAY``, ``EPOCH``, ``WARMUP``,
    ``steps_per_update`` and ``device``.

    Loss and accuracy are averaged over samples, so a smaller final batch
    does not skew the reported numbers.

    Args:
        model: Classifier mapping an image batch to class logits.
        train_dataloader: Loader yielding ``(image, label)`` training batches.
        val_dataloader: Loader yielding ``(image, label)`` validation batches.
        writer: TensorBoard ``SummaryWriter`` receiving loss/accuracy curves.
        model_path: File the best model (whole object) is saved to.
    """
    import os

    # Create the checkpoint directory up front so saving cannot fail on a
    # missing folder.
    save_dir = os.path.dirname(model_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    loss_fn = torch.nn.CrossEntropyLoss()

    optim = torch.optim.AdamW(model.parameters(), lr=LR * BATCH_SIZE / 256, betas=(0.9, 0.999), weight_decay=WEIGHT_DECAY)
    # Linear warm-up capped by a half-cosine decay, evaluated once per epoch.
    lr_func = lambda epoch: min((epoch + 1) / (WARMUP + 1e-8), 0.5 * (math.cos(epoch / EPOCH * math.pi) + 1))
    # `verbose` is not passed: the argument is deprecated in recent PyTorch.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lr_func)

    best_val_acc = 0
    step_count = 0
    optim.zero_grad()

    for e in range(EPOCH):
        model.train()
        loss_sum, correct, seen = 0.0, 0, 0

        # Passing the loader itself (not iter(loader)) lets tqdm show a total.
        for img, label in tqdm(train_dataloader):
            step_count += 1

            img = img.to(device)
            label = label.to(device)
            logits = model(img)

            loss = loss_fn(logits, label)
            loss.backward()

            # Gradient accumulation: step once every `steps_per_update` batches.
            if step_count % steps_per_update == 0:
                optim.step()
                optim.zero_grad()

            num = label.size(0)
            loss_sum += loss.item() * num
            correct += (logits.argmax(dim=-1) == label).sum().item()
            seen += num

        lr_scheduler.step()

        avg_train_loss = loss_sum / seen
        avg_train_acc = correct / seen
        print(f'In epoch {e}, average training loss is {avg_train_loss}, average training acc is {avg_train_acc}.')

        model.eval()

        with torch.no_grad():
            loss_sum, correct, seen = 0.0, 0, 0

            for img, label in tqdm(val_dataloader):
                img = img.to(device)
                label = label.to(device)
                logits = model(img)

                num = label.size(0)
                loss_sum += loss_fn(logits, label).item() * num
                correct += (logits.argmax(dim=-1) == label).sum().item()
                seen += num

            avg_val_loss = loss_sum / seen
            avg_val_acc = correct / seen
            print(f'In epoch {e}, average validation loss is {avg_val_loss}, average validation acc is {avg_val_acc}.')

        if avg_val_acc > best_val_acc:
            best_val_acc = avg_val_acc
            print(f'saving best model with acc {best_val_acc} at {e} epoch!')
            torch.save(model, model_path)

        writer.add_scalars('cls/loss', {'train': avg_train_loss, 'val': avg_val_loss}, global_step=e)
        writer.add_scalars('cls/acc', {'train': avg_train_acc, 'val': avg_val_acc}, global_step=e)

In [None]:
# Fine-tune the MAE-based classifier on CIFAR-10.
train_mae(model, train_dataloader, val_dataloader, writer, model_path)

## 5. Contrastive Learning을 위한 SimCLR 모델

### 5-1. Contrastive Loss 계산

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class NTXentLoss(nn.Module):
    """Normalized temperature-scaled cross-entropy (NT-Xent) loss.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` form a positive pair; every
    other row of the combined batch is a negative for both of them.

    Args:
        temperature: Divisor applied to the cosine similarities.
    """

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        """Return the mean NT-Xent loss over all ``2 * N`` anchors.

        Args:
            z_i: ``(N, D)`` embeddings of the first view.
            z_j: ``(N, D)`` embeddings of the second view.

        Returns:
            Scalar loss tensor on the device of the inputs. For ``N == 1``
            there are no negatives and the loss is 0.
        """
        batch_size = z_i.size(0)
        total = 2 * batch_size
        device = z_i.device

        # Cosine similarity as a matrix product of unit vectors. This avoids
        # the (2N, 2N, D) intermediate that a broadcast cosine_similarity
        # call would allocate.
        z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)
        similarity_matrix = (z @ z.T) / self.temperature

        # pair_mask[a, b] is True when rows a and b come from the same image.
        # All helper tensors are created on the input device.
        sample_ids = torch.arange(batch_size, device=device).repeat(2)
        pair_mask = sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)

        # Drop the diagonal (self-similarity) from both matrices.
        off_diagonal = ~torch.eye(total, dtype=torch.bool, device=device)
        pair_mask = pair_mask[off_diagonal].view(total, total - 1)
        similarity_matrix = similarity_matrix[off_diagonal].view(total, total - 1)

        # Explicit shapes keep the degenerate N == 1 case (no negatives) valid.
        positives = similarity_matrix[pair_mask].view(total, 1)
        negatives = similarity_matrix[~pair_mask].view(total, total - 2)

        # The positive sits in column 0, so the target class is always 0.
        logits = torch.cat([positives, negatives], dim=1)
        targets = torch.zeros(total, dtype=torch.long, device=device)

        return F.cross_entropy(logits, targets)

### 5-2. Model Architecture

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

class ProjectionHead(nn.Module):
    """Two-layer MLP (Linear, ReLU, Linear) used as the SimCLR projection head.

    Args:
        in_dim: Width of the incoming feature vectors.
        out_dim: Width of the projected embeddings.
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=2048):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Project ``x`` through the hidden layer to ``out_dim`` features."""
        hidden = self.fc1(x)
        activated = F.relu(hidden)
        return self.fc2(activated)

class SimCLR(nn.Module):
    """Backbone network followed by a projection head.

    The backbone's ``fc`` layer is replaced in place by an identity, so the
    module passed in is modified and afterwards outputs its pooled features.

    Args:
        base_model: Backbone exposing an ``fc`` layer with ``in_features``.
        out_dim: Width of the projected embedding.
    """

    def __init__(self, base_model, out_dim=128):
        super().__init__()
        self.backbone = base_model
        # Read the feature width before the classification layer is removed.
        self.feature_dim = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        self.projection_head = ProjectionHead(self.feature_dim, out_dim)

    def forward(self, x):
        """Return the projected embedding of ``x``."""
        backbone_features = self.backbone(x)
        embedding = self.projection_head(backbone_features)
        return embedding

class SimCLRClassifier(nn.Module):
    """Linear classifier on top of a SimCLR encoder's backbone features.

    Only ``encoder.backbone`` is used in the forward pass; the encoder's
    projection head is bypassed.

    Args:
        encoder: Object exposing ``backbone`` and ``feature_dim``.
        num_classes: Number of output logits.
    """

    def __init__(self, encoder, num_classes):
        super().__init__()
        self.encoder = encoder
        self.fc = nn.Linear(encoder.feature_dim, num_classes)

    def forward(self, x):
        """Return class logits for the batch ``x``."""
        backbone_features = self.encoder.backbone(x)
        logits = self.fc(backbone_features)
        return logits

## 6. Pre-training SimCLR with CIFAR10Dataset (Self-Supervised Learning)

### 6-1. Library & HyperParameter

In [None]:
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from torchvision import datasets
from tqdm.notebook import tqdm

In [None]:
# Hyperparameters for SimCLR pre-training on CIFAR-10 (module-level globals
# read by pretrain_simclr).
# NOTE(review): `device` is not set in this cell; presumably the value from an
# earlier cell is reused - confirm the cell order.
BATCH_SIZE = 512
EPOCH = 100
TEMPERATURE = 0.5
OUT_DIM = 128
LR = 3e-4

# File the pre-trained SimCLR state dict is written to.
model_path = './pth/simclr_pretrained_cifar.pth'

### 6-2. DataLoader

In [None]:
def get_data_loaders(batch_size):
    """Build the shuffled CIFAR-10 training loader used for SimCLR pre-training.

    Every image gets a random resized crop to 32x32 and a random horizontal
    flip before being converted to a tensor and normalized per channel.

    Args:
        batch_size: Number of images per batch.

    Returns:
        A ``DataLoader`` over the CIFAR-10 training split.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.247, 0.243, 0.261)
    pipeline = [
        transforms.RandomResizedCrop(size=32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]

    dataset = torchvision.datasets.CIFAR10(
        root='../../datasets/CIFAR10',
        train=True,
        transform=transforms.Compose(pipeline),
        download=True,
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

### 6-3. Train

In [None]:
def pretrain_simclr(get_data_loaders, model_path):
    """Pre-train a ResNet-18 based SimCLR model with the NT-Xent loss.

    Hyperparameters are read at call time from the module-level globals
    ``BATCH_SIZE``, ``OUT_DIM``, ``TEMPERATURE``, ``LR``, ``EPOCH`` and
    ``device``.

    The contrastive loss needs two different augmented views of each image.
    If the loader already yields a pair of view batches they are used as is;
    otherwise two views are created here by applying an independent random
    resized crop and horizontal flip to every image of the batch.

    Args:
        get_data_loaders: Callable taking a batch size and returning the
            training ``DataLoader`` of ``(images, label)`` batches.
        model_path: File the trained state dict is saved to.
    """
    import os

    # Create the output directory before training so a missing folder is
    # found immediately rather than after the final epoch.
    save_dir = os.path.dirname(model_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    train_loader = get_data_loaders(BATCH_SIZE)

    # resnet18() without arguments is randomly initialised; the explicit
    # `pretrained=False` keyword is deprecated in recent torchvision.
    model = SimCLR(torchvision.models.resnet18(), OUT_DIM).to(device)
    criterion = NTXentLoss(temperature=TEMPERATURE)
    optimizer = optim.Adam(model.parameters(), lr=LR)

    def _random_view(batch):
        """Return a copy of ``batch`` with each image augmented independently."""
        augment = transforms.Compose([
            transforms.RandomResizedCrop(size=tuple(batch.shape[-2:])),
            transforms.RandomHorizontalFlip(),
        ])
        # Applied per image: on a whole batch the transforms would draw one
        # set of random parameters shared by every sample.
        return torch.stack([augment(image) for image in batch])

    for epoch in range(EPOCH):
        model.train()
        total_loss = 0
        for images, _ in tqdm(train_loader):
            if isinstance(images, (list, tuple)):
                images_i, images_j = images
            else:
                # Feeding the same tensor twice would make both embeddings
                # identical and the positive-pair term trivial.
                images_i, images_j = _random_view(images), _random_view(images)
            images_i, images_j = images_i.to(device), images_j.to(device)

            z_i = model(images_i)
            z_j = model(images_j)

            loss = criterion(z_i, z_j)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f'Epoch {epoch + 1}/{EPOCH}, Loss: {avg_loss:.4f}')

    torch.save(model.state_dict(), model_path)

In [None]:
# Run SimCLR pre-training on CIFAR-10.
pretrain_simclr(get_data_loaders, model_path)

## 7. Transfer Learning SimCLR with CIFAR10Dataset (Supervised Learning)

### 7-1. Library & HyperParameter

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from torchvision import datasets
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

In [None]:
# Hyperparameters for supervised fine-tuning of the SimCLR backbone on
# CIFAR-10 (module-level globals read by train_simclr).
OUT_DIM = 128
BATCH_SIZE = 128
LR = 1e-3
EPOCH = 100

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# model_path: where the fine-tuned classifier is saved.
# pretrain_path: state dict written by the SimCLR pre-training section.
model_path = './pth/simclr_classifer_cifar.pt'
pretrain_path = './pth/simclr_pretrained_cifar.pth'

### 7-2. Dataset

In [None]:
# Transformation pipelines: training adds a random horizontal flip, testing
# is deterministic.
transform_train = transforms.Compose([
        transforms.Resize((32, 32)),  # resize images to 32x32
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # per-channel normalization
    ])

transform_test = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # per-channel normalization
])

# Dataset setup.
train_dataset = torchvision.datasets.CIFAR10('../../datasets/CIFAR10', train=True, download=True, transform=transform_train)
val_dataset = torchvision.datasets.CIFAR10('../../datasets/CIFAR10', train=False, download=True, transform=transform_test)

# Data loader setup.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
# No SummaryWriter is opened here: train_simclr does not take a writer, and a
# writer on the 'mae-pretrain' directory would only add a stray event file to
# the MAE pre-training logs.

### 7-3. Network 설정

In [None]:
# Rebuild the SimCLR architecture and restore the pre-trained weights.
# resnet18() without arguments is randomly initialised; the explicit
# `pretrained=False` keyword is deprecated in recent torchvision.
simclr = SimCLR(torchvision.models.resnet18(), OUT_DIM)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only
# machine; the classifier is moved to `device` afterwards.
simclr.load_state_dict(torch.load(pretrain_path, map_location='cpu'))
classifier = SimCLRClassifier(simclr, num_classes=10).to(device)

### 7-4. Train

In [None]:
def train_simclr(classifier, train_loader, val_loader, model_path):
    """Fine-tune a classifier, evaluate it once at the end and save it.

    Hyperparameters are read at call time from the module-level globals
    ``LR``, ``EPOCH`` and ``device``.

    Args:
        classifier: Module mapping an image batch to class logits.
        train_loader: Loader yielding ``(images, labels)`` training batches.
        val_loader: Loader yielding ``(images, labels)`` evaluation batches.
        model_path: File the trained classifier (whole object) is saved to.
    """
    import os

    # The model is only saved after all epochs, so create the output
    # directory first: a missing folder would otherwise lose the whole run.
    save_dir = os.path.dirname(model_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(classifier.parameters(), lr=LR)

    for epoch in range(EPOCH):
        classifier.train()
        total_loss = 0
        correct = 0
        total = 0
        for images, labels in tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)

            outputs = classifier(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

        train_acc = 100.0 * correct / total
        avg_loss = total_loss / len(train_loader)
        print(f'Epoch {epoch + 1}/{EPOCH}, Loss: {avg_loss:.4f}, Accuracy: {train_acc:.2f}%')

    # Evaluate the model once, after the last epoch.
    classifier.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(val_loader):
            images, labels = images.to(device), labels.to(device)

            outputs = classifier(images)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

    test_acc = 100.0 * correct / total
    avg_loss = total_loss / len(val_loader)
    print(f'Test Loss: {avg_loss:.4f}, Test Accuracy: {test_acc:.2f}%')

    torch.save(classifier, model_path)

In [None]:
# Fine-tune and evaluate the SimCLR-based classifier on CIFAR-10.
train_simclr(classifier, train_loader, val_loader, model_path)

## 8. Compare Model

### 8-1. Library & HyperParameter

In [None]:
import torch
import torchvision
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

In [None]:
# Settings for the CIFAR-10 model comparison (module-level globals read by
# compare_models).
BATCH_SIZE = 128
OUT_DIM = 128
IMAGE_SIZE = 32
NUM_CLASS = 10

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Classifier checkpoints written by the two fine-tuning sections.
mae_path = './pth/mae_classifer_cifar.pt'
simclr_path = './pth/simclr_classifer_cifar.pt'

### 8-2. Test DataLoader

In [None]:
def get_test_loader(batch_size, img_size):
    """Build the deterministic CIFAR-10 test loader used for model comparison.

    Args:
        batch_size: Number of images per batch.
        img_size: Side length every image is resized to.

    Returns:
        An unshuffled ``DataLoader`` over the CIFAR-10 test split.
    """
    preprocessing = [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        # Per-channel normalization.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    dataset = torchvision.datasets.CIFAR10(
        '../../datasets/CIFAR10',
        train=False,
        download=True,
        transform=transforms.Compose(preprocessing),
    )
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)

### 8-3. Compare Model

In [None]:
# MAE 및 SimCLR 모델 성능 비교
def _evaluate_classifier(classifier, dataloader, num_classes, eval_device, desc):
    """Return top-1 %, top-5 % and per-class top-1 accuracy of one classifier."""
    classifier.eval()
    class_correct = torch.zeros(num_classes)
    class_total = torch.zeros(num_classes)
    top1_hits = 0
    top5_hits = 0
    k = min(5, num_classes)  # topk cannot request more entries than classes

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc=desc):
            images, labels = images.to(eval_device), labels.to(eval_device)
            _, predicted = classifier(images).topk(k, dim=1, largest=True, sorted=True)

            top1_mask = predicted[:, 0] == labels
            top1_hits += top1_mask.sum().item()
            top5_hits += (predicted == labels.unsqueeze(1)).sum().item()

            # Per-class counts in one shot. Unlike indexing a squeezed mask
            # sample by sample, this also works for a batch of size one.
            labels_cpu = labels.cpu()
            class_total += torch.bincount(labels_cpu, minlength=num_classes)
            class_correct += torch.bincount(labels_cpu[top1_mask.cpu()], minlength=num_classes)

    num_samples = len(dataloader.dataset)
    return {
        "top1": 100.0 * top1_hits / num_samples,
        "top5": 100.0 * top5_hits / num_samples,
        "class_acc": class_correct / class_total,
    }


# Compare the performance of the MAE and SimCLR models.
def compare_models(mae_path, simclr_path, get_test_loader):
    """Evaluate the saved MAE and SimCLR classifiers on one test set and print the results.

    Reads the module-level globals ``BATCH_SIZE``, ``IMAGE_SIZE``,
    ``NUM_CLASS`` and ``device`` at call time.

    Args:
        mae_path: File containing the pickled MAE-based classifier.
        simclr_path: File containing the pickled SimCLR-based classifier.
        get_test_loader: Callable ``(batch_size, img_size) -> DataLoader``
            yielding ``(images, labels)`` test batches.
    """
    # Both checkpoints are fully pickled model objects, so they need
    # weights_only=False: since PyTorch 2.6 torch.load defaults to
    # weights_only=True and refuses to unpickle arbitrary classes.
    # NOTE(review): unpickling executes code; only load checkpoints you created.
    mae_classifier = torch.load(mae_path, map_location=device, weights_only=False).to(device)
    simclr_classifier = torch.load(simclr_path, map_location=device, weights_only=False).to(device)

    dataloader = get_test_loader(BATCH_SIZE, IMAGE_SIZE)

    results = {
        "MAE": _evaluate_classifier(mae_classifier, dataloader, NUM_CLASS, device, "Evaluating MAE"),
        "SimCLR": _evaluate_classifier(simclr_classifier, dataloader, NUM_CLASS, device, "Evaluating SimCLR"),
    }

    print("\nResults Summary:")
    print("MAE - Top-1 Accuracy: {:.2f}%, Top-5 Accuracy: {:.2f}%".format(results["MAE"]["top1"], results["MAE"]["top5"]))
    print("SimCLR - Top-1 Accuracy: {:.2f}%, Top-5 Accuracy: {:.2f}%".format(results["SimCLR"]["top1"], results["SimCLR"]["top5"]))

    print("\nClass-wise Accuracy:")
    for label in range(NUM_CLASS):
        print(f"Class {label} - MAE: {results['MAE']['class_acc'][label]:.2f}, SimCLR: {results['SimCLR']['class_acc'][label]:.2f}")


In [None]:
# Compare the two fine-tuned classifiers on the CIFAR-10 test split.
compare_models(mae_path, simclr_path, get_test_loader)

## 9. Oxford Dataset에 대해 수행

### 9-1. Pretrain MAE with OxfordIIITPet Dataset (Self-Supervised Learning)

In [None]:
import os
import math
import torch
import torchvision

from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
from torchvision import datasets, transforms

In [None]:
# Hyperparameters for MAE pre-training on Oxford-IIIT Pet. They overwrite the
# globals of the same names that pretrain_mae reads.
SEED = 42
BATCH_SIZE = 4096
MAX_DEVICE_BATCH_SIZE = 512
LR = 1.5e-4
WEIGHT_DECAY = 0.05
MASK_RATIO = 0.75
EPOCH = 2000
WARMUP = 200

device = 'cuda' if torch.cuda.is_available() else 'cpu'
setup_seed(SEED)

model_path = './pth/mae_pretrained_oxford.pt'

# Gradient accumulation: the loader yields `batch_size` samples per step and
# the optimizer is stepped once every `steps_per_update` steps.
batch_size = min(MAX_DEVICE_BATCH_SIZE, BATCH_SIZE)
assert BATCH_SIZE % batch_size == 0
steps_per_update = BATCH_SIZE // batch_size

In [None]:
# 데이터셋 전처리 설정
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # 이미지 크기를 32x32로 조정
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 표준 정규화
])

# Oxford-IIIT Pet 데이터셋 로딩
train_dataset = datasets.OxfordIIITPet(root='../../datasets/OxfordPet', split='trainval', transform=transform, download=True)
val_dataset = datasets.OxfordIIITPet(root='../../datasets/OxfordPet', split='test', transform=transform, download=True)

# 데이터 로더 설정
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
writer = SummaryWriter(os.path.join('logs', 'oxford-iiit-pet', 'mae-pretrain'))

In [None]:
# 224x224 inputs with 16x16 patches: (224 // 16) ** 2 = 196 patches per image.
model = MAE_ViT(image_size=224, patch_size=16, mask_ratio=MASK_RATIO).to(device)

In [None]:
# Run MAE pre-training on Oxford-IIIT Pet.
pretrain_mae(model, dataloader, val_dataset, writer, model_path)

### 9-2. Transfer Learning MAE with OxfordIIITPet Dataset (Supervised Learning)

In [None]:
import os
import math
import torch
import torchvision

from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
from torchvision import datasets, transforms

In [None]:
# Hyperparameters for supervised fine-tuning of the MAE encoder on
# Oxford-IIIT Pet. They overwrite the globals that train_mae reads.
SEED = 42
BATCH_SIZE = 128
MAX_DEVICE_BATCH_SIZE = 256
LR = 1e-3
WEIGHT_DECAY = 0.05
EPOCH = 100
WARMUP = 5

device = 'cuda' if torch.cuda.is_available() else 'cpu'
setup_seed(SEED)

# model_path: where the fine-tuned classifier is saved.
# pretrain_path: checkpoint written by the Oxford MAE pre-training section.
model_path = './pth/mae_classifer_oxford.pt'
pretrain_path = './pth/mae_pretrained_oxford.pt'

# With these values batch_size equals BATCH_SIZE, so steps_per_update is 1.
batch_size = min(MAX_DEVICE_BATCH_SIZE, BATCH_SIZE)
assert BATCH_SIZE % batch_size == 0
steps_per_update = BATCH_SIZE // batch_size

In [None]:
# Define the transforms in this cell. The Oxford MAE is pre-trained on
# 224x224 inputs with 16x16 patches, so its positional embedding expects 196
# patches; reusing the 32x32 CIFAR transforms left over from earlier cells
# would feed it images of the wrong size.
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = datasets.OxfordIIITPet(
    root='../../datasets/OxfordPet', split='trainval', target_types='category', transform=transform_train, download=True)
val_dataset = datasets.OxfordIIITPet(
    root='../../datasets/OxfordPet', split='test', target_types='category', transform=transform_test, download=True)

# Data loader setup.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
# No SummaryWriter is opened here: the network-setup cell creates the
# fine-tuning writer, and a writer on the 'mae-pretrain' directory would only
# add a stray event file to the pre-training logs.

In [None]:
# The checkpoint is a fully pickled model object, not a state dict, so it
# needs weights_only=False: since PyTorch 2.6 torch.load defaults to
# weights_only=True and refuses to unpickle arbitrary classes.
# NOTE(review): unpickling executes code; only load checkpoints you created.
model = torch.load(pretrain_path, map_location='cpu', weights_only=False)
model = model.to(device)
writer = SummaryWriter(os.path.join('logs', 'oxford-iiit-pet', 'pretrain-cls'))

# Wrap the pre-trained encoder with a fresh 37-way classification head.
model = ViT_Classifier(model.encoder, num_classes=37).to(device)

In [None]:
# Fine-tune the MAE-based classifier on Oxford-IIIT Pet.
train_mae(model, train_dataloader, val_dataloader, writer, model_path)

### 9-3. Pretrain SimCLR with OxfordIIITPet Dataset (Self-Supervised Learning)

In [None]:
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from torchvision import datasets
from tqdm.notebook import tqdm

In [None]:
# Hyperparameters for SimCLR pre-training on Oxford-IIIT Pet. They overwrite
# the globals that pretrain_simclr reads.
BATCH_SIZE = 512
EPOCH = 100
TEMPERATURE = 0.5
OUT_DIM = 128
LR = 3e-4

# File the pre-trained SimCLR state dict is written to.
model_path = './pth/simclr_pretrained_oxford.pth'

In [None]:
def get_data_loaders(batch_size):
    """Build the shuffled Oxford-IIIT Pet training loader for SimCLR pre-training.

    Images are resized to 224x224, converted to tensors and normalized per
    channel; this pipeline applies no random augmentation.

    Args:
        batch_size: Number of images per batch.

    Returns:
        A ``DataLoader`` over the ``trainval`` split.
    """
    preprocessing = [
        transforms.Resize((224, 224)),  # resize images to 224x224
        transforms.ToTensor(),
        # Per-channel normalization.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    dataset = torchvision.datasets.OxfordIIITPet(
        root='../../datasets/OxfordPet',
        split='trainval',
        transform=transforms.Compose(preprocessing),
        download=True,
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)


In [None]:
# Run SimCLR pre-training on Oxford-IIIT Pet.
pretrain_simclr(get_data_loaders, model_path)

### 9-4. Transfer Learning SimCLR with OxfordIIITPet Dataset (Supervised Learning)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from torchvision import datasets
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

In [None]:
# Hyperparameters for supervised fine-tuning of the SimCLR backbone on
# Oxford-IIIT Pet (module-level globals read by train_simclr).
OUT_DIM = 128
BATCH_SIZE = 128
LR = 1e-3
EPOCH = 100

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# model_path: where the fine-tuned classifier is saved.
# pretrain_path: state dict written by the Oxford SimCLR pre-training section.
model_path = './pth/simclr_classifer_oxford.pt'
pretrain_path = './pth/simclr_pretrained_oxford.pth'

In [None]:
# Training pipeline: 224x224 resize followed by random flip, rotation,
# resized crop and color jitter.
transform_train = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

# Evaluation pipeline: deterministic resize and normalization only.
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the Oxford-IIIT Pet dataset with category labels.
train_dataset = datasets.OxfordIIITPet(
    root='../../datasets/OxfordPet', split='trainval', target_types='category', transform=transform_train, download=True)
val_dataset = datasets.OxfordIIITPet(
    root='../../datasets/OxfordPet', split='test', target_types='category', transform=transform_test, download=True)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

In [None]:
# Rebuild the SimCLR architecture and restore the pre-trained weights.
# resnet18() without arguments is randomly initialised; the explicit
# `pretrained=False` keyword is deprecated in recent torchvision.
simclr = SimCLR(torchvision.models.resnet18(), OUT_DIM)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only
# machine; the classifier is moved to `device` afterwards.
simclr.load_state_dict(torch.load(pretrain_path, map_location='cpu'))
classifier = SimCLRClassifier(simclr, num_classes=37).to(device)

In [None]:
# Fine-tune and evaluate the SimCLR-based classifier on Oxford-IIIT Pet.
train_simclr(classifier, train_loader, val_loader, model_path)

### 9-5. Compare Model

In [None]:
import torch
import torchvision
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

In [None]:
# Settings for the Oxford-IIIT Pet model comparison (module-level globals
# read by compare_models).
BATCH_SIZE = 128
OUT_DIM = 128
# Both Oxford models are trained on 224x224 inputs, and the MAE classifier's
# positional embedding is sized for exactly that resolution with 16x16
# patches, so the test images must be resized to 224 rather than 32.
IMAGE_SIZE = 224
NUM_CLASS = 37

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Classifier checkpoints written by the two Oxford fine-tuning sections.
mae_path = './pth/mae_classifer_oxford.pt'
simclr_path = './pth/simclr_classifer_oxford.pt'

In [None]:
def get_test_loader(batch_size, img_size):
    """Build the deterministic Oxford-IIIT Pet test loader used for model comparison.

    Args:
        batch_size: Number of images per batch.
        img_size: Side length every image is resized to.

    Returns:
        An unshuffled ``DataLoader`` over the ``test`` split.
    """
    preprocessing = [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        # Per-channel normalization.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    dataset = datasets.OxfordIIITPet(
        root='../../datasets/OxfordPet',
        split='test',
        transform=transforms.Compose(preprocessing),
        download=True,
    )
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [None]:
# Compare the two fine-tuned classifiers on the Oxford-IIIT Pet test split.
compare_models(mae_path, simclr_path, get_test_loader)