In [1]:
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import math
import matplotlib.pyplot as plt
from einops import rearrange
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(device)


mps


# Hyperparameter

In [2]:
BATCH_SIZE = 32
LAMBDA = 3e-1 
EPOCH = 50

scheduler_name = 'Cos'
warmup_steps = 1000
LR_scale = 0.5
warmup_steps = 12000
LR_peak = 1e-3
LR_init = 1e-4
LR = 1e-4
criterion = nn.CrossEntropyLoss()

save_model_path = '/results/ViT_CIFAR10_2.pt'
save_history_path = '/results/ViT_CIFAR10_2_history.pt'


# Agumentation

In [3]:
transform_train = transforms.Compose([
    # transforms.RandomResizedCrop(scale=(0.9,1), ratio=(0.3,1.7), size=(32,32))

    transforms.RandomGrayscale(p=0.1),
    # transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05)

    transforms.RandomHorizontalFlip(p=0.4), 
    # transforms.RandomVerticalFlip(p=0.5)
    # transforms.RandomAffine(degrees=(0,90),translate=(0.1,0.2),scale=(0.9,1.1))
    # transforms.RandomPerspective(distortion_scale=0.1, p=0.1)

    transforms.ToTensor(), 

    transforms.RandomErasing(p=0.1, scale=(0.03,0.08), ratio=(0.3,3.3)), 
    # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

transform_test = transforms.ToTensor()

train_DS = datasets.CIFAR10(root = './data', train=True, download=True, transform=transform_train)
train_DS, val_DS = torch.utils.data.random_split(train_DS, [45000, 5000])
test_DS = datasets.CIFAR10(root = './data', train=False, download=True, transform=transform_test)

train_DL = torch.utils.data.DataLoader(train_DS, batch_size = BATCH_SIZE, shuffle = True)
val_DL = torch.utils.data.DataLoader(val_DS, batch_size = BATCH_SIZE, shuffle = True)
test_DL = torch.utils.data.DataLoader(test_DS, batch_size = BATCH_SIZE, shuffle = True)

Files already downloaded and verified
Files already downloaded and verified


# Vision Transformer


In [4]:
class MHA(nn.Module):
    def __init__(self, hidden_dim, n_heads):
        super().__init__()

        self.n_heads = n_heads

        self.fc_q = nn.Linear(hidden_dim, hidden_dim)
        self.fc_k = nn.Linear(hidden_dim, hidden_dim)
        self.fc_v = nn.Linear(hidden_dim, hidden_dim)
        self.fc_o = nn.Linear(hidden_dim, hidden_dim)

        self.scale = torch.sqrt(torch.tensor(hidden_dim / n_heads))

        nn.init.xavier_uniform_(self.fc_q.weight)
        nn.init.xavier_uniform_(self.fc_k.weight)
        nn.init.xavier_uniform_(self.fc_v.weight)
        nn.init.xavier_uniform_(self.fc_o.weight)

        if self.fc_q.bias is not None:
            nn.init.constant_(self.fc_q.bias, 0)
        if self.fc_k.bias is not None:
            nn.init.constant_(self.fc_k.bias, 0)
        if self.fc_v.bias is not None:
            nn.init.constant_(self.fc_v.bias, 0)
        if self.fc_o.bias is not None:
            nn.init.constant_(self.fc_o.bias, 0)

    def forward(self, x):

        Q = self.fc_q(x)
        K = self.fc_k(x)
        V = self.fc_v(x)

        Q = rearrange(Q, '개 단 (헤 차) -> 개 헤 단 차', 헤 = self.n_heads)
        K = rearrange(K, '개 단 (헤 차) -> 개 헤 단 차', 헤 = self.n_heads)
        V = rearrange(V, '개 단 (헤 차) -> 개 헤 단 차', 헤 = self.n_heads)

        attention_score = Q @ K.transpose(-2,-1)/self.scale

        attention_weights = torch.softmax(attention_score, dim=-1)

        attention = attention_weights @ V

        x = rearrange(attention, '개 헤 단 차 -> 개 단 (헤 차)')
        x = self.fc_o(x)

        return x, attention_weights

class FeedForward(nn.Module):
    def __init__(self, hidden_dim, d_ff, drop_p):
        super().__init__()

        self.linear = nn.Sequential(nn.Linear(hidden_dim, d_ff),
                                    nn.GELU(),
                                    nn.Dropout(drop_p),
                                    nn.Linear(d_ff, hidden_dim))

    def forward(self, x):
        x = self.linear(x)
        return x

class EncoderLayer(nn.Module):
    def __init__(self, hidden_dim, d_ff, n_heads, drop_p):
        super().__init__()

        self.self_atten_LN = nn.LayerNorm(hidden_dim, eps=1e-6)
        self.self_atten = MHA(hidden_dim, n_heads)

        self.FF_LN = nn.LayerNorm(hidden_dim, eps=1e-6)
        self.FF = FeedForward(hidden_dim, d_ff, drop_p)

        self.dropout = nn.Dropout(drop_p)

    def forward(self, x):

        residual = self.self_atten_LN(x)
        residual, atten_enc = self.self_atten(residual)
        residual = self.dropout(residual)
        x = x + residual

        residual = self.FF_LN(x)
        residual = self.FF(residual)
        residual = self.dropout(residual)
        x = x + residual

        return x, atten_enc

class Encoder(nn.Module):
    def __init__(self, seq_length, n_layers, hidden_dim, d_ff, n_heads, drop_p):
        super().__init__()

        self.pos_embedding = nn.Parameter(0.02*torch.randn(seq_length, hidden_dim))
        self.dropout = nn.Dropout(drop_p)
        self.layers = nn.ModuleList([EncoderLayer(hidden_dim, d_ff, n_heads, drop_p) for _ in range(n_layers)])
        self.ln = nn.LayerNorm(hidden_dim, eps=1e-6)

    def forward(self, src, atten_map_save = False):

        x = src + self.pos_embedding.expand_as(src)
        x = self.dropout(x)

        atten_encs = torch.tensor([]).to(device)
        for layer in self.layers:
            x, atten_enc = layer(x)
            if atten_map_save is True:
                atten_encs = torch.cat([atten_encs , atten_enc[0].unsqueeze(0)], dim=0)

        x = x[:,0,:]
        x = self.ln(x)

        return x, atten_encs

class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, n_layers, hidden_dim, d_ff, n_heads, representation_size = None, drop_p = 0., num_classes = 1000):
        super().__init__()

        self.hidden_dim = hidden_dim

        seq_length = (image_size // patch_size) ** 2 + 1

        self.class_token = nn.Parameter(torch.zeros(hidden_dim))
        self.input_embedding = nn.Conv2d(3, hidden_dim, patch_size, stride=patch_size)
        self.encoder = Encoder(seq_length, n_layers, hidden_dim, d_ff, n_heads, drop_p)

        heads_layers = []
        if representation_size is None:
            self.head = nn.Linear(hidden_dim, num_classes)
        else:
            self.head = nn.Sequential(nn.Linear(hidden_dim, representation_size),
                                      nn.Tanh(),
                                      nn.Linear(representation_size, num_classes))

        fan_in = self.input_embedding.in_channels * self.input_embedding.kernel_size[0] * self.input_embedding.kernel_size[1]
        nn.init.trunc_normal_(self.input_embedding.weight, std=math.sqrt(1 / fan_in))
        if self.input_embedding.bias is not None:
            nn.init.zeros_(self.input_embedding.bias)

        if representation_size is None:
            nn.init.zeros_(self.head.weight)
            nn.init.zeros_(self.head.bias)
        else:
            fan_in = self.head[0].in_features
            nn.init.trunc_normal_(self.head[0].weight, std=math.sqrt(1 / fan_in))
            nn.init.zeros_(self.head[0].bias)

    def forward(self, x):

        x = self.input_embedding(x)
        x = rearrange(x, '개 차 단h 단w -> 개 (단h 단w) 차')

        batch_class_token = self.class_token.expand(x.shape[0], 1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        enc_out, atten_encs = self.encoder(x)

        x = self.head(enc_out)

        return x, atten_encs

# def vit_b_16(**kwargs):
#     return VisionTransformer(image_size = 224, patch_size = 16, n_layers = 12, hidden_dim = 768, d_ff = 3072, n_heads = 12, representation_size = 768, **kwargs)

# def vit_b_32(**kwargs):
#     return VisionTransformer(image_size = 224, patch_size = 32, n_layers = 12, hidden_dim = 768, d_ff = 3072, n_heads = 12, representation_size = 768, **kwargs)

# def vit_l_16(**kwargs):
#     return VisionTransformer(image_size = 224, patch_size = 16, n_layers = 24, hidden_dim = 1024, d_ff = 4096, n_heads = 16, representation_size = 1024, **kwargs)

# def vit_l_32(**kwargs):
#     return VisionTransformer(image_size = 224, patch_size = 32, n_layers = 24, hidden_dim = 1024, d_ff = 4096, n_heads = 16, representation_size = 1024, **kwargs)

# def vit_h_14(**kwargs):
#     return VisionTransformer(image_size = 224, patch_size = 14, n_layers = 32, hidden_dim = 1280, d_ff = 5120, n_heads = 16, representation_size = 1280, **kwargs)

def vit_cifar10(**kwargs):
    return VisionTransformer(image_size = 32, patch_size = 2, n_layers = 3, hidden_dim = 256, d_ff = 1024, n_heads = 8, representation_size = 256, **kwargs)

model = vit_cifar10().to(device)


In [5]:

from torchinfo import summary
summary(model, input_size=(2,3,32,32), device=device)

  from .autonotebook import tqdm as notebook_tqdm


Layer (type:depth-idx)                        Output Shape              Param #
VisionTransformer                             [2, 1000]                 256
├─Conv2d: 1-1                                 [2, 256, 16, 16]          3,328
├─Encoder: 1-2                                [2, 256]                  65,792
│    └─Dropout: 2-1                           [2, 257, 256]             --
│    └─ModuleList: 2-2                        --                        --
│    │    └─EncoderLayer: 3-1                 [2, 257, 256]             789,760
│    │    └─EncoderLayer: 3-2                 [2, 257, 256]             789,760
│    │    └─EncoderLayer: 3-3                 [2, 257, 256]             789,760
│    └─LayerNorm: 2-3                         [2, 256]                  512
├─Sequential: 1-3                             [2, 1000]                 --
│    └─Linear: 2-4                            [2, 256]                  65,792
│    └─Tanh: 2-5                              [2, 256]             

# Train, Test, loss_epoch

In [5]:
def Train(model, train_DL, val_DL, criterion, optimizer, scheduler = None):
    loss_history = {"train": [], "val": []}
    acc_history = {"train": [], "val": []}
    best_loss = 9999
    for ep in range(EPOCH):
        model.train() 
        train_loss, train_acc, _ = loss_epoch(model, train_DL, criterion, optimizer = optimizer, scheduler = scheduler)
        loss_history["train"] += [train_loss]
        acc_history["train"] += [train_acc]

        model.eval() 
        with torch.no_grad():
            val_loss, val_acc, _ = loss_epoch(model, val_DL, criterion)
            loss_history["val"] += [val_loss]
            acc_history["val"] += [val_acc]
            if val_loss < best_loss:
                best_loss = val_loss
                torch.save({"model": model,
                            "ep": ep,
                            "optimizer": optimizer,
                            "scheduler": scheduler}, save_model_path)
        print(f"Epoch: {ep+1}, current_LR = {optimizer.param_groups[0]['lr']:.8f}")
        print(f"train loss: {train_loss:.5f}, "
              f"val loss: {val_loss:.5f} \n"
              f"train acc: {train_acc:.1f} %, "
              f"val acc: {val_acc:.1f} %")
        print("-" * 20)

    torch.save({"loss_history": loss_history,
                "acc_history": acc_history,
                "EPOCH": EPOCH,
                "BATCH_SIZE": BATCH_SIZE}, save_history_path)

def Test(model, test_DL, criterion):
    model.eval() 
    with torch.no_grad():
        test_loss, test_acc, rcorrect = loss_epoch(model, test_DL, criterion)
    print(f"Test loss: {test_loss:.3f}")
    print(f"Test accuracy: {rcorrect}/{len(test_DL.dataset)} ({round(test_acc,1)} %)")
    return round(test_acc,1)

def loss_epoch(model, DL, criterion, optimizer = None, scheduler = None):
    N = len(DL.dataset) 
    rloss=0; rcorrect = 0
    for x_batch, y_batch in tqdm(DL, leave=False):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        y_hat = model(x_batch)[0]
        loss = criterion(y_hat, y_batch)
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if scheduler is not None:
            scheduler.step()
        loss_b = loss.item() * x_batch.shape[0]
        rloss += loss_b
        pred = y_hat.argmax(dim=1)
        corrects_b = torch.sum(pred == y_batch).item()
        rcorrect += corrects_b
    loss_e = rloss/N
    accuracy_e = rcorrect/N * 100

    return loss_e, accuracy_e, rcorrect

class NoamScheduler:
    def __init__(self, optimizer, hidden_dim, warmup_steps, LR_scale = 1):
        self.optimizer = optimizer
        self.current_step = 0
        self.hidden_dim = hidden_dim
        self.warmup_steps = warmup_steps
        self.LR_scale = LR_scale

    def step(self):
        self.current_step += 1
        lrate = self.LR_scale * (self.hidden_dim ** -0.5) * min(self.current_step ** -0.5, self.current_step * self.warmup_steps ** -1.5)
        self.optimizer.param_groups[0]['lr'] = lrate

class LinearWarmupLinearDecayScheduler:
    def __init__(self, optimizer, warmup_steps, total_steps, max_lr):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.max_lr = max_lr
        self.current_step = 0

    def step(self):
        self.current_step += 1
        if self.current_step < self.warmup_steps:
            lrate = self.max_lr * (self.current_step / self.warmup_steps)
        else:
            decay_steps = self.total_steps - self.warmup_steps
            lrate = self.max_lr * max(0, float(decay_steps - (self.current_step - self.warmup_steps)) / decay_steps)
        self.optimizer.param_groups[0]['lr'] = lrate


In [None]:
optimizer = optim.Adam(nn.Linear(1, 1).parameters(), lr=LR_init)
scheduler = CosineAnnealingLR(optimizer, T_max = int(len(train_DS)*EPOCH/BATCH_SIZE))


# Train

In [None]:

params = [p for p in model.parameters() if p.requires_grad] 
if scheduler_name == 'Noam':
    optimizer = optim.AdamW(params, lr=0, weight_decay=LAMBDA)
    scheduler = NoamScheduler(optimizer, hidden_dim=model.hidden_dim, warmup_steps=warmup_steps, LR_scale=LR_scale)
elif scheduler_name == 'Linear':
    optimizer = optim.AdamW(params, lr=0, weight_decay=LAMBDA)
    scheduler = LinearWarmupLinearDecayScheduler(optimizer, warmup_steps=warmup_steps, total_steps = int(len(train_DS)*EPOCH/BATCH_SIZE), max_lr=LR_peak)
elif scheduler_name == 'Cos':
    optimizer = optim.AdamW(params, lr=LR_init, weight_decay=LAMBDA)
    scheduler = CosineAnnealingLR(optimizer, T_max = int(len(train_DS)*EPOCH/BATCH_SIZE))
elif scheduler_name == 'Constant':
    optimizer = optim.AdamW(params, lr=LR, weight_decay=LAMBDA)
    scheduler = None
Train(model, train_DL, val_DL, criterion, optimizer, scheduler)