In [83]:
# !pip install -q torchinfo
# !pip install -q pytorch_lightning
# !pip install -q timm
# !pip install -q einops

In [84]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from pytorch_lightning.loggers import TensorBoardLogger
import pickle
from pathlib import Path
import torchvision.transforms as T
import cv2
from PIL import Image
from torchvision.utils import make_grid
import torchvision.transforms.functional as TF
import numpy as np
from time import time
from datetime import timedelta
from sklearn.model_selection import train_test_split
from torch.optim import Adam
from torch.cuda.amp import GradScaler
from timm.scheduler import CosineLRScheduler
from einops import rearrange

In [49]:
### Data
VAL_RATIO = 0.1
### CIFAR-10
DATA_DIR = "/content/data/cifar-10-batches-py"
with open(Path(DATA_DIR)/"batches.meta", mode="rb") as f:
    meta = pickle.load(f, encoding="bytes")
label_names = meta[b"label_names"]
CIFAR10_CLASSES = [i.decode("ascii") for i in label_names]
N_CLASSES = len(CIFAR10_CLASSES)
IMG_SIZE = 32

### Architecture
DROP_PROB = 0.1
N_LAYERS = 6
HIDDEN_SIZE = 384
MLP_SIZE = 384
N_HEADS = 12
PATCH_SIZE = 4

### Optimizer
# "Adam with $beta_{1} = 0.9$, $beta_{2}= 0.999$, a batch size of 4096 and apply a high weight decay
# of 0.1, which we found to be useful for transfer of all models."
BASE_LR = 1e-3
BETA1 = 0.9
BETA2 = 0.999
WEIGHT_DECAY = 5e-5
WARMUP_EPOCHS = 5

### Regularization
SMOOTHING = 0.1 # If `0`, do not employ label smoothing
CUTMIX = False
CUTOUT = False
HIDE_AND_SEEK = False

### Training
SEED = 17
# BATCH_SIZE = 4096 # "All models are trained with a batch size of 4096."
BATCH_SIZE = 2048
N_EPOCHS = 300
N_WORKERS = 6
N_GPUS = torch.cuda.device_count()
if N_GPUS > 0:
    DEVICE = torch.device("cuda")
    print(f"""Using {N_GPUS} GPU(s).""")
else:
    DEVICE = torch.device("cpu")
    print(f"""Using CPU(s).""")
MULTI_GPU = True
AUTOCAST = True
N_PRINT_EPOCHS = 4
N_VAL_EPOCHS = 4
CKPT_DIR = Path("/content/").parent/"checkpoints"

### Resume
CKPT_PATH = None

Using 1 GPU(s).


In [55]:
class PatchEmbedding(nn.Module):
    def __init__(self, patch_size, hidden_size, drop_prob=DROP_PROB):
        super().__init__()

        self.patch_size = patch_size
        dim = (patch_size ** 2) * 3

        self.norm1 = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, hidden_size)
        self.drop = nn.Dropout(drop_prob)
        self.norm2 = nn.LayerNorm(hidden_size)

    def forward(self, x):
        x = rearrange(
            x,
            pattern="b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
            p1=self.patch_size,
            p2=self.patch_size,
        )
        x = self.norm1(x) # Not in the paper
        x = self.proj(x)
        # "Dropout is applied after every dense layer except for the the qkv-projections
        # and directly after adding positional- to patch embeddings."
        x = self.drop(x)
        x = self.norm2(x) # Not in the paper
        return x


class MSA(nn.Module):
    def __init__(self, hidden_size, n_heads, drop_prob=DROP_PROB):
        super().__init__()

        self.head_size = hidden_size // n_heads
        self.n_heads = n_heads

        # "U_{qkv} \in \mathbb{R}^{D \times 3D_{h}}"
        self.qkv_proj = nn.Linear(hidden_size, 3 * n_heads * self.head_size, bias=False)
        self.drop = nn.Dropout(drop_prob)
        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=False)

    def _get_attention_score(self, q, k):
        # "$qk^{T}$"
        attn_score = torch.einsum("bhnd,bhmd->bhnm", q, k)
        return attn_score

    def forward(self, x):
        # "$[q, k, v] = zU_{qkv}$"
        q, k, v = torch.split(
            self.qkv_proj(x), split_size_or_sections=self.n_heads * self.head_size, dim=2,
        )
        q = rearrange(q, pattern="b n (h d) -> b h n d", h=self.n_heads, d=self.head_size)
        k = rearrange(k, pattern="b n (h d) -> b h n d", h=self.n_heads, d=self.head_size)
        v = rearrange(v, pattern="b n (h d) -> b h n d", h=self.n_heads, d=self.head_size)
        attn_score = self._get_attention_score(q=q, k=k)
        # "$A = softmax(qk^{T}/\sqrt{D_{h}}), A \in \mathbb{R}^{N \times N}$"
        attn_weight = F.softmax(attn_score / (self.head_size ** 0.5), dim=3)
        # attn_weight = self.drop(attn_weight)
        x = torch.einsum("bhnm,bhmd->bhnd", attn_weight, v)
        # "$U_{msa} \in \mathbb{R}^{k \cdot D_{h} \times D}$"
        x = rearrange(x, pattern="b h n d -> b n (h d)")
        x = self.out_proj(x)
        # "Dropout is applied after every dense layer except for the the qkv-projections
        # and directly after adding positional- to patch embeddings."
        x = self.drop(x)
        return x


class SkipConnection(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()

        self.norm = nn.LayerNorm(hidden_size) # "$LN$"

    def forward(self, x, sublayer):
        # "Layernorm (LN) is applied before every block, and residual connections after every block."
        # "$z'_{l} = MSA(LN(z_{l - 1})) + z_{l - 1}$", "$z_{l} = MLP(LN(z'_{l})) + z'_{l}$"
        skip = x.clone()
        x = self.norm(x)
        x = sublayer(x)
        x += skip
        return x


class MLP(nn.Module):
    def __init__(self, hidden_size, mlp_size):
        super().__init__()

        self.proj1 = nn.Linear(hidden_size, mlp_size)
        self.drop1 = nn.Dropout(0.1)
        self.proj2 = nn.Linear(mlp_size, hidden_size)
        self.drop2 = nn.Dropout(0.1)

    def forward(self, x):
        x = self.proj1(x)
        x = F.gelu(x) # "The MLP contains two layers with a GELU non-linearity."
        # "Dropout is applied after every dense layer except for the the qkv-projections
        # and directly after adding positional- to patch embeddings."
        # Activation function 다음에 Dropout이 오도록!
        x = self.drop1(x)
        x = self.proj2(x)
        x = F.gelu(x)
        x = self.drop2(x)
        return x


class TransformerEncoderLayer(nn.Module):
    def __init__(self, hidden_size, mlp_size, n_heads):
        super().__init__()

        self.self_attn = MSA(hidden_size=hidden_size, n_heads=n_heads)
        self.self_attn_resid = SkipConnection(hidden_size=hidden_size)
        self.mlp = MLP(hidden_size=hidden_size, mlp_size=mlp_size)
        self.mlp_resid = SkipConnection(hidden_size=hidden_size)

    def forward(self, x):
        x = self.self_attn_resid(x=x, sublayer=self.self_attn)
        x = self.mlp_resid(x=x, sublayer=self.mlp)
        return x


class TransformerEncoder(nn.Module):
    def __init__(self, n_layers, hidden_size, mlp_size, n_heads):
        super().__init__()

        self.enc_stack = nn.ModuleList(
            [TransformerEncoderLayer(hidden_size=hidden_size, mlp_size=mlp_size, n_heads=n_heads)
                for _ in range(n_layers)]
        )

    def forward(self, x):
        for enc_layer in self.enc_stack:
            x = enc_layer(x)
        return x


class ViT(nn.Module):
    """
    ViT-Base: `n_layers=12, hidden_size=768, mlp_size=3072, n_heads=12`
    ViT-Large: `n_layers=24, hidden_size=1024, mlp_size=4096, n_heads=16`
    ViT-Huge: `n_layers=32, hidden_size=1280, mlp_size=5120, n_heads=16`
    """
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        n_layers=12,
        hidden_size=768,
        mlp_size=3072,
        n_heads=12,
        drop_prob=DROP_PROB,
        n_classes=0,
    ):
        super().__init__()

        self.n_classes = n_classes

        assert img_size % patch_size == 0, "`img_size` must be divisible by `patch_size`!"

        cell_size = img_size // patch_size
        n_patches = cell_size ** 2

        # $\textbf{E}$ of the equation 1 in the paper
        self.patch_embed = PatchEmbedding(patch_size=patch_size, hidden_size=hidden_size)
        self.cls_token = nn.Parameter(torch.randn((1, 1, hidden_size))) # $x_{\text{class}}$
        # $\textbf{E}_\text{pos}$
        self.pos_embed = nn.Parameter(torch.randn((1, n_patches + 1, hidden_size)))
        self.drop1 = nn.Dropout(drop_prob)
        self.tf_enc = TransformerEncoder(
            n_layers=n_layers, hidden_size=hidden_size, mlp_size=mlp_size, n_heads=n_heads,
        )

        self.norm = nn.LayerNorm(hidden_size) # "$LN$"
        self.proj = nn.Linear(hidden_size, n_classes)
        self.drop2 = nn.Dropout(drop_prob)

    def forward(self, x):
        b, _, _, _ = x.shape

        x = self.patch_embed(x)
        x = torch.cat((self.cls_token.repeat(b, 1, 1), x), dim=1)
        x += self.pos_embed
        # "Dropout is applied after every dense layer except for the the qkv-projections
        # and directly after adding positional- to patch embeddings."
        x = self.drop1(x)
        x = self.tf_enc(x)

        if self.n_classes == 0:
            x = x.mean(dim=1)
        else:
            x = x[:, 0, :] # $z^{0}_{L}$ of the equation 4 in the paper
            # "Layernorm (LN) is applied before every block."
            x = self.norm(x) # $y$
            x = self.proj(x)
            # "Dropout is applied after every dense layer except for the the qkv-projections
            # and directly after adding positional- to patch embeddings."
            x = self.drop2(x)
        return x

In [56]:
def print_number_of_parameters(model):
    print(f"""{sum([p.numel() for p in model.parameters()]):,}""")


def get_elapsed_time(start_time):
    return timedelta(seconds=round(time() - start_time))


def load_image(img_path):
    img_path = str(img_path)
    img = cv2.imread(img_path, flags=cv2.IMREAD_COLOR)
    img = cv2.cvtColor(src=img, code=cv2.COLOR_BGR2RGB)
    return img


def _to_pil(img):
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)
    return img


def show_image(img):
    copied = img.copy()
    copied = _to_pil(copied)
    copied.show()


def _apply_jet_colormap(img):
    img_jet = cv2.applyColorMap(src=(255 - img), colormap=cv2.COLORMAP_JET)
    return img_jet


def _to_array(img):
    img = np.array(img)
    return img


def _blend_two_images(img1, img2, alpha=0.5):
    img1 = _to_pil(img1)
    img2 = _to_pil(img2)
    img_blended = Image.blend(im1=img1, im2=img2, alpha=alpha)
    return _to_array(img_blended)


def _to_3d(img):
    if img.ndim == 2:
        return np.dstack([img, img, img])
    else:
        return img


def _rgba_to_rgb(img):
    copied = img.copy().astype("float")
    copied[..., 0] *= copied[..., 3] / 255
    copied[..., 1] *= copied[..., 3] / 255
    copied[..., 2] *= copied[..., 3] / 255
    copied = copied.astype("uint8")
    copied = copied[..., : 3]
    return copied


def _preprocess_image(img):
    if img.dtype == "bool":
        img = img.astype("uint8") * 255

    if img.ndim == 2:
        if (
            np.array_equal(np.unique(img), np.array([0, 255])) or
            np.array_equal(np.unique(img), np.array([0])) or
            np.array_equal(np.unique(img), np.array([255]))
        ):
            img = _to_3d(img)
        else:
            img = _apply_jet_colormap(img)
    return img


def _blend_two_images(img1, img2, alpha=0.5):
    img1 = _to_pil(img1)
    img2 = _to_pil(img2)
    img_blended = Image.blend(im1=img1, im2=img2, alpha=alpha)
    return _to_array(img_blended)


def save_image(img1, img2=None, alpha=0.5, path="") -> None:
    copied1 = _preprocess_image(
        _to_array(img1.copy())
    )
    if img2 is None:
        img_arr = copied1
    else:
        copied2 = _to_array(
            _preprocess_image(
                _to_array(img2.copy())
            )
        )
        img_arr = _to_array(
            _blend_two_images(img1=copied1, img2=copied2, alpha=alpha)
        )

    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    if img_arr.ndim == 3:
        cv2.imwrite(
            filename=str(path), img=img_arr[:, :, :: -1], params=[cv2.IMWRITE_JPEG_QUALITY, 100]
        )
    elif img_arr.ndim == 2:
        cv2.imwrite(
            filename=str(path), img=img_arr, params=[cv2.IMWRITE_JPEG_QUALITY, 100]
        )


def denorm(tensor, mean, std):
    return TF.normalize(
        tensor, mean=- np.array(mean) / np.array(std), std=1 / np.array(std),
    )


def image_to_grid(image, mean, std, n_cols, padding=1):
    tensor = image.clone().detach().cpu()
    tensor = denorm(tensor, mean=mean, std=std)
    grid = make_grid(tensor, nrow=n_cols, padding=1, pad_value=padding)
    grid.clamp_(0, 1)
    grid = TF.to_pil_image(grid)
    return grid

In [41]:
# Data loaders
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [64]:
class CELossWithLabelSmoothing(nn.Module):
    def __init__(self, n_classes, smoothing=0):
        super().__init__()

        assert 0 <= smoothing <= 1, "The argument `smoothing` must be between 0 and 1!"

        self.n_classes = n_classes
        self.smoothing = smoothing

    def forward(self, pred, gt):
        if gt.ndim == 1:
            gt = torch.eye(self.n_classes, device=gt.device)[gt]
            return self(pred, gt)
        elif gt.ndim == 2:
            log_prob = F.log_softmax(pred, dim=1)
            ce_loss = -torch.sum(gt * log_prob, dim=1)
            loss = (1 - self.smoothing) * ce_loss
            loss += self.smoothing * -torch.sum(log_prob, dim=1)
            return torch.mean(loss)


class ClassificationLoss(nn.Module):
    def __init__(self, n_classes, smoothing=0):
        super().__init__()

        assert 0 <= smoothing <= 1, "The argument `smoothing` must be between 0 and 1!"

        self.n_classes = n_classes
        self.smoothing = smoothing

        self.ce = nn.CrossEntropyLoss(reduction="sum")

    def forward(self, pred, gt):
        if gt.ndim == 1:
            new_gt = torch.full_like(pred, fill_value=self.smoothing / (self.n_classes - 1))
            new_gt.scatter_(1, gt.unsqueeze(1), 1 - self.smoothing)
        elif gt.ndim == 2:
            new_gt = gt.clone()
            new_gt.sum(dim=1)
            new_gt *= (1 - self.smoothing)
            is_zero = (gt == 0)
            likelihood = self.smoothing / (gt.shape[1] - (~is_zero).sum(dim=1))
            new_gt += is_zero * likelihood.unsqueeze(1).repeat(1, self.n_classes)
        loss = self.ce(pred, new_gt)
        return loss

In [65]:
class TopKAccuracy(nn.Module):
    def __init__(self, k):
        super().__init__()

        self.k = k

    def forward(self, pred, gt):
        _, topk = torch.topk(pred, k=self.k, dim=1)
        corr = torch.eq(topk, gt.unsqueeze(1).repeat(1, self.k))
        acc = corr.sum(dim=1).float().mean().item()
        return acc

In [70]:
def get_cifar10_imgs_and_gts(data_path):
    with open(data_path, mode="rb") as f:
        data_dic = pickle.load(f, encoding="bytes")

    imgs = data_dic[b"data"]
    imgs = imgs.reshape(-1, 3, IMG_SIZE, IMG_SIZE)
    imgs = imgs.transpose(0, 2, 3, 1)

    gts = data_dic[b"labels"]
    gts = np.array(gts)
    return imgs, gts


def get_cifar10_train_val_set(data_dir):
    imgs_ls = list()
    gts_ls = list()
    for idx in range(1, 6):
        imgs, gts = get_cifar10_imgs_and_gts(Path(data_dir)/f"data_batch_{idx}")
        imgs_ls.append(imgs)
        gts_ls.append(gts)
    imgs = np.concatenate(imgs_ls, axis=0)
    gts = np.concatenate(gts_ls, axis=0)
    return imgs, gts


def get_all_cifar10_imgs_and_gts(data_dir, val_ratio):
    train_val_imgs, train_val_gts = get_cifar10_train_val_set(data_dir)
    train_imgs, val_imgs, train_gts, val_gts = train_test_split(
        train_val_imgs, train_val_gts, test_size=val_ratio,
    )
    test_imgs, test_gts = get_cifar10_imgs_and_gts(Path(data_dir)/"test_batch")
    return train_imgs, train_gts, val_imgs, val_gts, test_imgs, test_gts


def get_cifar_mean_and_std(imgs):
    imgs = imgs.astype("float") / 255
    n_pixels = imgs.size // 3
    sum_ = imgs.reshape(-1, 3).sum(axis=0)
    sum_square = (imgs ** 2).reshape(-1, 3).sum(axis=0)
    mean = (sum_ / n_pixels).round(3)
    std = (((sum_square / n_pixels) - mean ** 2) ** 0.5).round(3)
    return mean, std


class CIFARDataset(Dataset):
    def __init__(self, imgs, gts, mean, std):
        super().__init__()

        self.imgs = imgs
        self.gts = gts

        self.transform = T.Compose([
            T.RandomHorizontalFlip(p=0.5),
            T.RandomCrop(size=IMG_SIZE, padding=4, pad_if_needed=True),
            T.RandomApply(
                [T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)],
                p=0.4,
            ),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])

    def __len__(self):
        return len(self.gts)

    def __getitem__(self, idx):
        img = self.imgs[idx]
        image = Image.fromarray(img, mode="RGB")
        image = self.transform(image)

        gt = self.gts[idx]
        gt = torch.tensor(gt).long()
        return image, gt


def get_cifar10_dses(data_dir, val_ratio=0.1):
    train_imgs, train_gts, val_imgs, val_gts, test_imgs, test_gts = get_all_cifar10_imgs_and_gts(
            data_dir=data_dir, val_ratio=val_ratio,
    )
    mean, std = get_cifar_mean_and_std(train_imgs)
    train_ds = CIFARDataset(imgs=train_imgs, gts=train_gts, mean=mean, std=std)
    val_ds = CIFARDataset(imgs=val_imgs, gts=val_gts, mean=mean, std=std)
    test_ds = CIFARDataset(imgs=test_imgs, gts=test_gts, mean=mean, std=std)
    return train_ds, val_ds, test_ds

In [67]:
import random

def apply_cutmix(image, gt, n_classes):
    if gt.ndim == 1:
        gt = F.one_hot(gt, num_classes=n_classes)

    b, _, h, w = image.shape

    lamb = random.random()
    region_x = random.randint(0, w)
    region_y = random.randint(0, h)
    region_w = region_h = (1 - lamb) ** 0.5

    xmin = max(0, int(region_x - region_w / 2))
    ymin = max(0, int(region_y - region_h / 2))
    xmax = max(w, int(region_x + region_w / 2))
    ymax = max(h, int(region_y + region_h / 2))

    indices = torch.randperm(b)
    image[:, :, ymin: ymax, xmin: xmax] = image[indices][:, :, ymin: ymax, xmin: xmax]
    lamb = 1 - (xmax - xmin) * (ymax - ymin) / (w * h)
    gt = lamb * gt + (1 - lamb) * gt[indices]
    return image, gt



def apply_hide_and_seek(image, patch_size, hide_prob=0.5, mean=(0.5, 0.5, 0.5)):
    b, _, h, w = image.shape
    assert h % patch_size == 0 and w % patch_size == 0,\
        "`patch_size` argument should be a multiple of both the width and height of the input image"

    mean_tensor = torch.Tensor(mean)[None, :, None, None].repeat(b, 1, patch_size, patch_size)

    copied = image.clone()
    for i in range(h // patch_size):
        for j in range(w // patch_size):
            if random.random() < hide_prob:
                    continue

            copied[
                ..., i * patch_size: (i + 1) * patch_size, j * patch_size: (j + 1) * patch_size
            ] = mean_tensor
    return copied

def apply_cutout(image, cutout_size=16, mean=(0.485, 0.456, 0.406)):
  _, _, h, w = image.shape

  x = random.randint(0, w)
  y = random.randint(0, h)
  xmin = max(0, x - cutout_size // 2)
  ymin = max(0, y - cutout_size // 2)
  xmax = max(0, x + cutout_size // 2)
  ymax = max(0, y + cutout_size // 2)

  image[:, 0, ymin: ymax, xmin: xmax] = mean[0]
  image[:, 1, ymin: ymax, xmin: xmax] = mean[1]
  image[:, 2, ymin: ymax, xmin: xmax] = mean[2]
  return image

In [42]:
# TensorBoard Logger
logger = TensorBoardLogger('tb_logs', name='vit_experiment')

In [85]:
torch.set_printoptions(linewidth=200, sci_mode=False)
torch.manual_seed(SEED)


def save_checkpoint(epoch, model, optim, scaler, avg_acc, ckpt_path):
    Path(ckpt_path).parent.mkdir(parents=True, exist_ok=True)

    ckpt = {
        "epoch": epoch,
        "optimizer": optim.state_dict(),
        "scaler": scaler.state_dict(),
        "average_accuracy": avg_acc,
    }
    if N_GPUS > 0 and MULTI_GPU:
        ckpt["model"] = model.module.state_dict()
    else:
        ckpt["model"] = model.state_dict()

    torch.save(ckpt, str(ckpt_path))


@torch.no_grad()
def validate(dl, model, metric):
    print(f"""Validating...""")
    model.eval()
    sum_acc = 0
    for image, gt in dl:
        image = image.to(DEVICE)
        gt = gt.to(DEVICE)

        pred = model(image)
        acc = metric(pred=pred, gt=gt)
        sum_acc += acc
    avg_acc = sum_acc / len(dl)
    print(f"""Average accuracy: {avg_acc:.3f}""")

    model.train()
    return avg_acc


if __name__ == "__main__":
    print(f"""N_WORKERS = {N_WORKERS}""")
    print(f"""DEVICE = {DEVICE}""")
    print(f"""AUTOCAST = {AUTOCAST}""")
    print(f"""BATCH_SIZE = {BATCH_SIZE}""")

    train_ds, val_ds, test_ds = get_cifar10_dses(data_dir=DATA_DIR, val_ratio=VAL_RATIO)
    train_dl = DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True, drop_last=True,
    )
    val_dl = DataLoader(
        val_ds, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True, drop_last=True,
    )

    model = ViT(
        img_size=IMG_SIZE,
        patch_size=PATCH_SIZE,
        n_layers=N_LAYERS,
        hidden_size=HIDDEN_SIZE,
        mlp_size=MLP_SIZE,
        n_heads=N_HEADS,
        n_classes=N_CLASSES,
    )
    if N_GPUS > 0:
        model = model.to(DEVICE)
        if MULTI_GPU:
            model = nn.DataParallel(model)

    crit = CELossWithLabelSmoothing(n_classes=N_CLASSES, smoothing=SMOOTHING)
    metric = TopKAccuracy(k=1)

    optim = Adam(
        model.parameters(),
        lr=BASE_LR,
        betas=(BETA1, BETA2),
        weight_decay=WEIGHT_DECAY,
    )
    scheduler = CosineLRScheduler(
        optimizer=optim,
        t_initial=N_EPOCHS,
        warmup_t=WARMUP_EPOCHS,
        warmup_lr_init=BASE_LR / 10,
        t_in_epochs=True,
    )

    scaler = GradScaler(enabled=True if AUTOCAST else False)

    ### Resume
    if CKPT_PATH is not None:
        ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
        if N_GPUS > 1 and MULTI_GPU:
            model.module.load_state_dict(ckpt["model"])
        else:
            model.load_state_dict(ckpt["model"])
        optim.load_state_dict(ckpt["optimizer"])
        scaler.load_state_dict(ckpt["scaler"])

        init_epoch = ckpt["epoch"]
        best_avg_acc = ckpt["average_accuracy"]
        print(f"""Resuming from checkpoint '{CKPT_PATH}'...""")

        prev_ckpt_path = CKPT_PATH
    else:
        init_epoch = 0
        prev_ckpt_path = ".pth"
        best_avg_acc = 0

    start_time = time()
    running_loss = 0
    step_cnt = 0
    for epoch in range(init_epoch + 1, N_EPOCHS + 1):
        for step, (image, gt) in enumerate(train_dl, start=1):
            image = image.to(DEVICE)
            gt = gt.to(DEVICE)

            if HIDE_AND_SEEK:
                image = apply_hide_and_seek(
                    image, patch_size=IMG_SIZE // 4, mean=MEAN,
                )
            if CUTMIX:
                image, gt = apply_cutmix(image=image, gt=gt, n_classes=N_CLASSES)
            if CUTOUT:
                image = apply_cutout(image)

            with torch.autocast(
                device_type=DEVICE.type,
                dtype=torch.float16,
                enabled=True if AUTOCAST else False,
            ):
                pred = model(image)
                loss = crit(pred, gt)
            optim.zero_grad()
            if AUTOCAST:
                scaler.scale(loss).backward()
                scaler.step(optim)
                scaler.update()
            else:
                loss.backward()
                optim.step()
            scheduler.step_update(num_updates=epoch * len(train_dl))

            running_loss += loss.item()
            step_cnt += 1

        if (epoch % N_PRINT_EPOCHS == 0) or (epoch == N_EPOCHS):
            loss = running_loss / step_cnt
            lr = optim.param_groups[0]['lr']
            print(f"""[ {epoch:,}/{N_EPOCHS} ][ {step:,}/{len(train_dl):,} ]""", end="")
            print(f"""[ {lr:.5f} ][ {get_elapsed_time(start_time)} ][ {loss:.2f} ]""")

            running_loss = 0
            step_cnt = 0
            start_time = time()

        if (epoch % N_VAL_EPOCHS == 0) or (epoch == N_EPOCHS):
            avg_acc = validate(dl=val_dl, model=model, metric=metric)
            if avg_acc > best_avg_acc:
                cur_ckpt_path = CKPT_DIR/f"""epoch_{epoch}_avg_acc_{round(avg_acc, 3)}.pth"""
                save_checkpoint(
                    epoch=epoch,
                    model=model,
                    optim=optim,
                    scaler=scaler,
                    avg_acc=avg_acc,
                    ckpt_path=cur_ckpt_path,
                )
                print(f"""Saved checkpoint.""")
                prev_ckpt_path = Path(prev_ckpt_path)
                if prev_ckpt_path.exists():
                    prev_ckpt_path.unlink()

                best_avg_acc = avg_acc
                prev_ckpt_path = cur_ckpt_path

        scheduler.step(epoch + 1)


N_WORKERS = 6
DEVICE = cuda
AUTOCAST = True
BATCH_SIZE = 2048
[ 4/300 ][ 21/21 ][ 0.00082 ][ 0:03:35 ][ 4.23 ]
Validating...
Average accuracy: 0.456
Saved checkpoint.
[ 8/300 ][ 21/21 ][ 0.00100 ][ 0:03:36 ][ 4.05 ]
Validating...
Average accuracy: 0.551
Saved checkpoint.
[ 12/300 ][ 21/21 ][ 0.00100 ][ 0:03:37 ][ 3.97 ]
Validating...
Average accuracy: 0.588
Saved checkpoint.
[ 16/300 ][ 21/21 ][ 0.00099 ][ 0:03:35 ][ 3.92 ]
Validating...
Average accuracy: 0.611
Saved checkpoint.
[ 20/300 ][ 21/21 ][ 0.00099 ][ 0:03:35 ][ 3.88 ]
Validating...
Average accuracy: 0.627
Saved checkpoint.
[ 24/300 ][ 21/21 ][ 0.00098 ][ 0:03:40 ][ 3.85 ]
Validating...
Average accuracy: 0.645
Saved checkpoint.
[ 28/300 ][ 21/21 ][ 0.00098 ][ 0:03:34 ][ 3.83 ]
Validating...
Average accuracy: 0.653
Saved checkpoint.
[ 32/300 ][ 21/21 ][ 0.00097 ][ 0:03:35 ][ 3.81 ]
Validating...
Average accuracy: 0.693
Saved checkpoint.
[ 36/300 ][ 21/21 ][ 0.00096 ][ 0:03:35 ][ 3.79 ]
Validating...
Average accuracy: 0.684
[ 40

KeyboardInterrupt: 

In [86]:
from torchinfo import summary
summary(model)

Layer (type:depth-idx)                             Param #
DataParallel                                       --
├─ViT: 1-1                                         25,344
│    └─PatchEmbedding: 2-1                         --
│    │    └─LayerNorm: 3-1                         96
│    │    └─Linear: 3-2                            18,816
│    │    └─Dropout: 3-3                           --
│    │    └─LayerNorm: 3-4                         768
│    └─Dropout: 2-2                                --
│    └─TransformerEncoder: 2-3                     --
│    │    └─ModuleList: 3-5                        5,322,240
│    └─LayerNorm: 2-4                              768
│    └─Linear: 2-5                                 3,850
│    └─Dropout: 2-6                                --
Total params: 5,371,882
Trainable params: 5,371,882
Non-trainable params: 0