Code for CvT-13, largely based on https://github.com/microsoft/CvT/blob/main/lib/dataset/transformas/build.py
and the CvT paper (https://arxiv.org/pdf/2103.15808),
with some modifications.

In [None]:
!pip install torchsummary torchvision tqdm wandb typing_extensions pytorch_metric_learning timm==0.9.16 einops


In [None]:
!pip freeze > requirements.txt

In [None]:
import torch
from torchsummary import summary
import torchvision
from torchvision.utils import make_grid
from torchvision import datasets, transforms as T
from torch.utils.data import DataLoader
import torch.nn.functional as F
import os
import gc
# from tqdm import tqdm
from tqdm.auto import tqdm
from PIL import Image
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn import metrics as mt
from scipy.optimize import brentq
from scipy.interpolate import interp1d
import glob
import wandb
import matplotlib.pyplot as plt
from pytorch_metric_learning import samplers
import csv
import logging
from timm.data import create_loader, create_transform
import torch.nn as nn
from functools import partial

from einops import rearrange
from einops.layers.torch import Rearrange

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device: ", DEVICE)

In [None]:
# Experiment configuration for CvT-13.
# Params largely copied from https://github.com/leoxiaobin/CvT/blob/main/experiments/imagenet/cvt/cvt-13-224x224.yaml
config = {
  "OUTPUT_DIR": "OUTPUT/imagenet-blur-1",
  # Passed as num_workers to every DataLoader built in this notebook.
  "WORKERS": 8,
  "PRINT_FREQ": 500,
  "AMP": {
    "ENABLED": True
  },
  "MODEL": {
    "NAME": "cls_cvt",
    # Every list below has one entry per stage; ConvolutionalVisionTransformer
    # reads entry i when it builds stage i.
    "SPEC": {
      "INIT": "trunc_norm",
      "NUM_STAGES": 3,
      "PATCH_SIZE": [7, 3, 3],
      "PATCH_STRIDE": [4, 2, 2],
      "PATCH_PADDING": [2, 1, 1],
      "DIM_EMBED": [64, 192, 384],
      "NUM_HEADS": [1, 3, 6],
      "DEPTH": [1, 2, 10],
      "MLP_RATIO": [4.0, 4.0, 4.0],
      "ATTN_DROP_RATE": [0.0, 0.0, 0.0],
      "DROP_RATE": [0.0, 0.0, 0.0],
      "DROP_PATH_RATE": [0.0, 0.0, 0.1],
      "QKV_BIAS": [True, True, True],
      # Only the last stage carries a classification token.
      "CLS_TOKEN": [False, False, True],
      # NOTE(review): POS_EMBED is not read by the model code in this notebook section.
      "POS_EMBED": [False, False, False],
      "QKV_PROJ_METHOD": ["dw_bn", "dw_bn", "dw_bn"],
      "KERNEL_QKV": [3, 3, 3],
      "PADDING_KV": [1, 1, 1],
      "STRIDE_KV": [2, 2, 2],
      "PADDING_Q": [1, 1, 1],
      "STRIDE_Q": [1, 1, 1]
    }
  },
  "AUG": {
    # NOTE(review): MIXUP_PROB / MIXUP / MIXCUT are not read by the data-loading
    # code in this section -- confirm where (or whether) they are applied.
    "MIXUP_PROB": 1.0,
    "MIXUP": 0.8,
    "MIXCUT": 1.0,
    # Forwarded to timm's create_transform / create_loader for the training split.
    "TIMM_AUG": {
      # NOTE(review): build_dataloader uses the timm loader for training
      # unconditionally; this flag is not checked there.
      "USE_LOADER": True,
      "RE_COUNT": 1,
      "RE_MODE": "pixel",
      "RE_SPLIT": False,
      "RE_PROB": 0.25,
      "AUTO_AUGMENT": "rand-m9-mstd0.5-inc1",
      "HFLIP": 0.5,
      "VFLIP": 0.0,
      "COLOR_JITTER": 0.4,
      "INTERPOLATION": "bicubic"
    }
  },
  "LOSS": {
    "LABEL_SMOOTHING": 0.1
  },
  "CUDNN": {
    "BENCHMARK": True,
    "DETERMINISTIC": False,
    "ENABLED": True
  },
  "DATASET": {
    "DATASET": "imagenet",
    "DATA_FORMAT": "jpg",
    # ROOT/TRAIN_SET and ROOT/TEST_SET are opened with torchvision's ImageFolder.
    "ROOT": "./ImageNet100_224",
    "TEST_SET": "val",
    "TRAIN_SET": "train"
  },
  "TEST": {
    "BATCH_SIZE_PER_GPU": 32,
    "IMAGE_SIZE": [224, 224],
    "MODEL_FILE": "",
    # NOTE(review): build_transforms hard-codes bicubic for evaluation and does
    # not read this key.
    "INTERPOLATION": "bicubic"
  },
  "TRAIN": {
    "BATCH_SIZE_PER_GPU": 256,
    "GRADIENT_ACCUMULATION_STEPS": 8,
    "LR": .00125,
    "IMAGE_SIZE": [224, 224],
    "BEGIN_EPOCH": 0,
    "END_EPOCH": 100,
    # KERNEL_SIZE and SIGMA parameterize T.GaussianBlur in build_gaussian_transforms.
    # EPOCHS is presumably the number of early epochs trained on blurred images --
    # TODO confirm against the training loop.
    "BLUR": {
        "KERNEL_SIZE": 7,
        "SIGMA": 1,
        "EPOCHS": 20
    },
    "LR_SCHEDULER": {
      "METHOD": "timm",
      "ARGS": {
        "sched": "cosine",
        "warmup_epochs": 5,
        "warmup_lr": 0.000001,
        "min_lr": 0.00001,
        "cooldown_epochs": 10,
        "decay_rate": 0.1
      }
    },
    "OPTIMIZER": "adamW",
    "WD": 0.05,
    # Read by set_wd: parameter groups named here are excluded from weight decay.
    "WITHOUT_WD_LIST": ["bn", "bias", "ln"],
    "SHUFFLE": True
  },
  "DEBUG": {
    "DEBUG": False
  }
}

### Defining transforms

In [None]:
def build_transforms(config, is_train):
    """Return the image transform pipeline for one split.

    Args:
        config: Experiment configuration dict (reads ``TRAIN``/``TEST``
            ``IMAGE_SIZE`` and ``AUG.TIMM_AUG``).
        is_train: When truthy, build the timm training augmentation pipeline;
            otherwise build the resize / center-crop evaluation pipeline.

    Returns:
        The transform object to hand to the dataset.
    """
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)

    if not is_train:
        side = config['TEST']['IMAGE_SIZE'][0]
        # Resize so that the center crop keeps 87.5% of the shorter edge.
        eval_steps = [
            T.Resize(
                int(side / 0.875),
                interpolation=torchvision.transforms.InterpolationMode.BICUBIC,
            ),
            T.CenterCrop(side),
            T.ToTensor(),
            T.Normalize(mean=list(imagenet_mean), std=list(imagenet_std)),
        ]
        return T.Compose(eval_steps)

    side = config['TRAIN']['IMAGE_SIZE'][0]
    aug = config['AUG']['TIMM_AUG']
    # Hardcoded values are the upstream defaults, e.g.
    # https://github.com/microsoft/CvT/blob/main/lib/config/default.py#L68
    return create_transform(
        input_size=side,
        is_training=True,
        use_prefetcher=False,
        no_aug=False,
        re_prob=aug['RE_PROB'],
        re_mode=aug['RE_MODE'],
        re_count=aug['RE_COUNT'],
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        hflip=aug['HFLIP'],
        vflip=aug['VFLIP'],
        color_jitter=aug['COLOR_JITTER'],
        auto_augment=aug['AUTO_AUGMENT'],
        interpolation=aug['INTERPOLATION'],
        mean=imagenet_mean,
        std=imagenet_std,
    )

### Building Datasets

In [None]:
def _build_imagenet_dataset(config, is_train):
    """Open the train or test image folder under ``config['DATASET']['ROOT']``.

    The split directory name comes from ``TRAIN_SET`` when ``is_train`` is
    truthy and from ``TEST_SET`` otherwise; the matching transform pipeline is
    attached to the returned ``ImageFolder``.
    """
    pipeline = build_transforms(config, is_train)
    data_cfg = config['DATASET']
    split_key = 'TRAIN_SET' if is_train else 'TEST_SET'
    split_dir = os.path.join(data_cfg['ROOT'], data_cfg[split_key])

    dataset = datasets.ImageFolder(split_dir, pipeline)
    logging.info(f'load samples: {len(dataset)}, is_train: {is_train}')
    return dataset
    
def build_dataset(config, is_train):
    """Build the dataset for the requested split.

    In this file the call is always forwarded to ``_build_imagenet_dataset``;
    the CIFAR variant of this file dispatches to its own builder instead.
    """
    return _build_imagenet_dataset(config, is_train)

### Building Dataloader

In [None]:
def build_dataloader(config, is_train):
    """Create the data loader for the training or evaluation split.

    Training goes through timm's ``create_loader`` with the augmentation
    settings from ``config['AUG']['TIMM_AUG']``. Evaluation uses a plain,
    unshuffled ``torch.utils.data.DataLoader`` that keeps the last partial
    batch.

    Args:
        config: Experiment configuration dict.
        is_train: Selects the split and the loader type.

    Returns:
        The constructed loader.
    """
    split = 'TRAIN' if is_train else 'TEST'
    batch_size = config[split]['BATCH_SIZE_PER_GPU']
    dataset = build_dataset(config, is_train)

    if not is_train:
        # The upstream CvT code supports multi-GPU samplers; a single GPU is
        # assumed here, so no sampler is used.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=config['WORKERS'],
            pin_memory=True,
            sampler=None,
            drop_last=False,
        )

    logging.info('use timm loader for training')
    aug = config['AUG']['TIMM_AUG']
    return create_loader(
        dataset,
        input_size=tuple(config['TRAIN']['IMAGE_SIZE']),
        batch_size=batch_size,
        is_training=True,
        use_prefetcher=False,
        no_aug=False,
        re_prob=aug['RE_PROB'],
        re_mode=aug['RE_MODE'],
        re_count=aug['RE_COUNT'],
        re_split=aug['RE_SPLIT'],
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        hflip=aug['HFLIP'],
        vflip=aug['VFLIP'],
        color_jitter=aug['COLOR_JITTER'],
        auto_augment=aug['AUTO_AUGMENT'],
        interpolation=aug['INTERPOLATION'],
        num_aug_splits=0,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        num_workers=config['WORKERS'],
        distributed=False,
        collate_fn=None,
        pin_memory=True,
        use_multi_epochs_loader=True,
    )
# Build the training and validation loaders when this cell runs; both calls
# instantiate their ImageFolder dataset.
train_loader = build_dataloader(config, is_train=True)
val_loader = build_dataloader(config, is_train=False)

# Summary of loader / dataset sizes as a quick sanity check.
print(f"\nDataLoader Info:")
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Train dataset size: {len(train_loader.dataset)}")
print(f"Val dataset size: {len(val_loader.dataset)}")
print(f"Number of classes: {len(train_loader.dataset.classes)}")

# Test loading a batch (pulls one batch from the training loader and prints its shapes)
images, labels = next(iter(train_loader))
print(f"\nBatch shapes:")
print(f"Images: {images.shape}")
print(f"Labels: {labels.shape}")

## Setup Gaussian DataLoader
This section sets up a Gaussian-blur DataLoader that is used during the early epochs of training. The approach closely follows the method described in https://dl.acm.org/doi/pdf/10.1145/3606043.3606086, except that it is scaled down to a limited number of epochs.

In [None]:
def build_gaussian_transforms(config):
    """Return the blur pipeline: to-tensor, Gaussian blur, then normalization.

    Kernel size and sigma are read from ``config['TRAIN']['BLUR']``. No resize
    or crop step is applied here.
    """
    blur_cfg = config['TRAIN']['BLUR']
    blur = T.GaussianBlur(
        kernel_size=blur_cfg['KERNEL_SIZE'],
        sigma=blur_cfg['SIGMA'],
    )
    normalize = T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    return T.Compose([T.ToTensor(), blur, normalize])

def build_gaussian_dataset(config, is_train=True):
    """Open the training image folder with the Gaussian-blur transforms attached.

    The blurred dataset is only ever built from the training split;
    ``is_train`` is used solely in the log message.
    """
    data_cfg = config['DATASET']
    train_dir = os.path.join(data_cfg['ROOT'], data_cfg['TRAIN_SET'])
    blurred = datasets.ImageFolder(train_dir, build_gaussian_transforms(config))
    logging.info(f'load samples: {len(blurred)}, is_train: {is_train}')
    return blurred

def build_gaussian_dataloader(config):
    """Return a shuffled training loader over the Gaussian-blurred dataset.

    Uses the training batch size and worker count from ``config`` and drops
    the last incomplete batch.
    """
    loader_kwargs = dict(
        batch_size=config['TRAIN']['BATCH_SIZE_PER_GPU'],
        shuffle=True,
        num_workers=config['WORKERS'],
        pin_memory=True,
        sampler=None,
        drop_last=True,
    )
    return torch.utils.data.DataLoader(build_gaussian_dataset(config), **loader_kwargs)
    
# Build the blurred-image loader when this cell runs.
gaussian_loader = build_gaussian_dataloader(config) # will only be used for train_data 
print(f"\nDataLoader Info:")
print(f"Gaussian Loader batches: {len(gaussian_loader)}")
print(f"Gaussian dataset size: {len(gaussian_loader.dataset)}")
print(f"Gaussian Dataset of classes: {len(gaussian_loader.dataset.classes)}")

# Test loading a batch (rebinds the notebook-level `images` / `labels` names)
images, labels = next(iter(gaussian_loader))
print(f"\nBatch shapes:")
print(f"Images: {images.shape}")
print(f"Labels: {labels.shape}")

## Defining Cvt-13 model:

In [None]:
# CVT code from: https://github.com/microsoft/CvT/blob/main/lib/models/cls_cvt.py
from functools import partial
from itertools import repeat
import collections.abc as container_abcs
from collections import OrderedDict
from einops.layers.torch import Rearrange
from timm.models.layers import DropPath, trunc_normal_

# From PyTorch internals
def _ntuple(n):
    """Return a converter that repeats a non-iterable value ``n`` times.

    The returned function hands any iterable argument (strings included) back
    unchanged and wraps everything else in an ``n``-element tuple.
    """
    def parse(x):
        already_iterable = isinstance(x, container_abcs.Iterable)
        return x if already_iterable else tuple(repeat(x, n))

    return parse


# Fixed-arity converters, mirroring the helpers in PyTorch's internals.
to_1tuple, to_2tuple, to_3tuple, to_4tuple = (_ntuple(k) for k in (1, 2, 3, 4))
to_ntuple = _ntuple

class LayerNorm(nn.LayerNorm):
    """LayerNorm that computes in float32 and returns the caller's dtype.

    Lets half-precision activations pass through the normalization without
    the caller having to cast.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)


class QuickGELU(nn.Module):
    """Sigmoid-gated activation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate


class Mlp(nn.Module):
    """Two-layer feed-forward block.

    Applies Linear -> activation -> dropout -> Linear -> dropout. The hidden
    and output widths fall back to ``in_features`` when not given (or given
    as a falsy value).
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=QuickGELU,
                 drop=0.):
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        # A single Dropout module is reused after both linear layers.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Attention(nn.Module):
    """Multi-head attention with an optional convolutional Q/K/V projection.

    Tokens are reshaped back to a 2-D feature map, passed through a per-branch
    projection chosen by ``method`` and flattened again before the usual
    linear Q/K/V projections:

    * ``'dw_bn'``  -- depthwise ``Conv2d`` followed by ``BatchNorm2d``.
    * ``'avg'``    -- average pooling for K and V; the query branch falls back
      to ``'linear'``.
    * ``'linear'`` -- no spatial projection; tokens are used as they are.

    Args:
        dim_in: Channel width of the incoming tokens.
        dim_out: Channel width of the attention output.
        num_heads: Number of attention heads ``dim_out`` is split into.
        qkv_bias: Whether the linear Q/K/V projections have a bias.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
        method: Projection type, one of ``'dw_bn'``, ``'avg'``, ``'linear'``.
        kernel_size: Kernel size of the conv / pooling projection.
        stride_kv, stride_q: Strides of the K/V and Q projections.
        padding_kv, padding_q: Paddings of the K/V and Q projections.
        with_cls_token: Whether the first token is a classification token that
            must bypass the spatial projection.
        **kwargs: Ignored; accepted so callers can pass a shared kwargs dict.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """

    def __init__(self,
                 dim_in,
                 dim_out,
                 num_heads,
                 qkv_bias=False,
                 attn_drop=0.,
                 proj_drop=0.,
                 method='dw_bn',
                 kernel_size=3,
                 stride_kv=1,
                 stride_q=1,
                 padding_kv=1,
                 padding_q=1,
                 with_cls_token=True,
                 **kwargs
                 ):
        super().__init__()
        self.stride_kv = stride_kv
        self.stride_q = stride_q
        self.dim = dim_out
        self.num_heads = num_heads
        # NOTE: the scale uses the full output width (dim_out), not the
        # per-head width.
        self.scale = dim_out ** -0.5
        self.with_cls_token = with_cls_token

        # Queries never use average pooling: 'avg' degrades to 'linear' here.
        self.conv_proj_q = self._build_projection(
            dim_in, dim_out, kernel_size, padding_q,
            stride_q, 'linear' if method == 'avg' else method
        )
        self.conv_proj_k = self._build_projection(
            dim_in, dim_out, kernel_size, padding_kv,
            stride_kv, method
        )
        self.conv_proj_v = self._build_projection(
            dim_in, dim_out, kernel_size, padding_kv,
            stride_kv, method
        )

        self.proj_q = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.proj_k = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.proj_v = nn.Linear(dim_in, dim_out, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim_out, dim_out)
        self.proj_drop = nn.Dropout(proj_drop)

    def _build_projection(self,
                          dim_in,
                          dim_out,
                          kernel_size,
                          padding,
                          stride,
                          method):
        """Return the spatial projection module for ``method`` (``None`` for 'linear').

        The returned module maps a ``b c h w`` feature map to ``b (h w) c``
        tokens. ``dim_out`` is accepted for signature symmetry but unused: the
        depthwise conv keeps ``dim_in`` channels.
        """
        if method == 'dw_bn':
            proj = nn.Sequential(OrderedDict([
                ('conv', nn.Conv2d(
                    dim_in,
                    dim_in,
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    bias=False,
                    groups=dim_in
                )),
                ('bn', nn.BatchNorm2d(dim_in)),
                # Key spelling kept as-is: it is part of the module's child names.
                ('rearrage', Rearrange('b c h w -> b (h w) c')),
            ]))
        elif method == 'avg':
            proj = nn.Sequential(OrderedDict([
                ('avg', nn.AvgPool2d(
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    ceil_mode=True
                )),
                ('rearrage', Rearrange('b c h w -> b (h w) c')),
            ]))
        elif method == 'linear':
            proj = None
        else:
            raise ValueError('Unknown method ({})'.format(method))

        return proj

    def forward_conv(self, x, h, w):
        """Apply the spatial Q/K/V projections to tokens laid out on an ``h`` x ``w`` grid.

        When a classification token is present it is split off first and
        re-attached, unprojected, to each of q, k and v.
        """
        if self.with_cls_token:
            cls_token, x = torch.split(x, [1, h*w], 1)

        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

        def project(branch):
            # A missing projection ('linear') just flattens the map back to tokens.
            if branch is not None:
                return branch(x)
            return rearrange(x, 'b c h w -> b (h w) c')

        q = project(self.conv_proj_q)
        k = project(self.conv_proj_k)
        v = project(self.conv_proj_v)

        if self.with_cls_token:
            q = torch.cat((cls_token, q), dim=1)
            k = torch.cat((cls_token, k), dim=1)
            v = torch.cat((cls_token, v), dim=1)

        return q, k, v

    def forward(self, x, h, w):
        """Run attention over tokens ``x`` whose spatial part forms an ``h`` x ``w`` grid."""
        if (
            self.conv_proj_q is not None
            or self.conv_proj_k is not None
            or self.conv_proj_v is not None
        ):
            q, k, v = self.forward_conv(x, h, w)
        else:
            # method == 'linear' builds no spatial projection at all; without
            # this fallback q/k/v would be unbound below.
            q = k = v = x

        q = rearrange(self.proj_q(q), 'b t (h d) -> b h t d', h=self.num_heads)
        k = rearrange(self.proj_k(k), 'b t (h d) -> b h t d', h=self.num_heads)
        v = rearrange(self.proj_v(v), 'b t (h d) -> b h t d', h=self.num_heads)

        attn_score = torch.einsum('bhlk,bhtk->bhlt', [q, k]) * self.scale
        attn = F.softmax(attn_score, dim=-1)
        attn = self.attn_drop(attn)

        x = torch.einsum('bhlt,bhtv->bhlv', [attn, v])
        x = rearrange(x, 'b h t d -> b t (h d)')

        x = self.proj(x)
        x = self.proj_drop(x)

        return x

    @staticmethod
    def compute_macs(module, input, output):
        """Forward-hook style MAC counter; adds its estimate to ``module.__flops__``.

        Args:
            module: The attention module being profiled.
            input: Tuple whose first element is the ``(batch, tokens, channels)``
                tensor fed to the module.
            output: Unused; present to match the hook signature.
        """
        tokens_in = input[0]
        flops = 0

        _, T, C = tokens_in.shape
        # The spatial grid is assumed square.
        H = W = int(np.sqrt(T-1)) if module.with_cls_token else int(np.sqrt(T))

        H_Q = H / module.stride_q
        W_Q = W / module.stride_q
        T_Q = H_Q * W_Q + 1 if module.with_cls_token else H_Q * W_Q

        H_KV = H / module.stride_kv
        W_KV = W / module.stride_kv
        T_KV = H_KV * W_KV + 1 if module.with_cls_token else H_KV * W_KV

        # Scaled-dot-product macs
        # [B x T x C] x [B x C x T] --> [B x T x S]
        # multiplication-addition is counted as 1 because operations can be fused
        flops += T_Q * T_KV * module.dim
        # [B x T x S] x [B x S x C] --> [B x T x C]
        flops += T_Q * module.dim * T_KV

        def count_params(mod):
            return sum(p.numel() for p in mod.parameters())

        # Depthwise convs: one pass of the kernel per output position.
        conv_branches = (
            ('conv_proj_q', H_Q * W_Q),
            ('conv_proj_k', H_KV * W_KV),
            ('conv_proj_v', H_KV * W_KV),
        )
        for attr, positions in conv_branches:
            conv = getattr(getattr(module, attr, None), 'conv', None)
            if conv is not None:
                flops += count_params(conv) * positions

        # Linear projections: parameter count times the number of tokens seen.
        flops += count_params(module.proj_q) * T_Q
        flops += count_params(module.proj_k) * T_KV
        flops += count_params(module.proj_v) * T_KV
        flops += count_params(module.proj) * T

        module.__flops__ += flops


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual connection.

    ``kwargs`` must contain ``with_cls_token`` and is forwarded unchanged to
    ``Attention``.
    """

    def __init__(self,
                 dim_in,
                 dim_out,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 **kwargs):
        super().__init__()

        self.with_cls_token = kwargs['with_cls_token']

        self.norm1 = norm_layer(dim_in)
        # `drop` doubles as the attention's output-projection dropout.
        self.attn = Attention(
            dim_in,
            dim_out,
            num_heads,
            qkv_bias,
            attn_drop,
            drop,
            **kwargs
        )

        # Stochastic depth is only instantiated when it would do something.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim_out)

        self.mlp = Mlp(
            in_features=dim_out,
            hidden_features=int(dim_out * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, h, w):
        attended = self.attn(self.norm1(x), h, w)
        x = x + self.drop_path(attended)
        return x + self.drop_path(self.mlp(self.norm2(x)))


class ConvEmbed(nn.Module):
    """Convolutional token embedding.

    A strided ``Conv2d`` turns the input map into an embedded map; an optional
    norm layer is applied over the channel dimension of the flattened tokens
    before the result is reshaped back to ``b c h w``.
    """

    def __init__(self,
                 patch_size=7,
                 in_chans=3,
                 embed_dim=64,
                 stride=4,
                 padding=2,
                 norm_layer=None):
        super().__init__()
        kernel = to_2tuple(patch_size)
        self.patch_size = kernel

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
        )
        self.norm = norm_layer(embed_dim) if norm_layer else None

    def forward(self, x):
        feature_map = self.proj(x)
        _, _, height, width = feature_map.shape

        # Flatten to tokens so the norm acts on the channel axis, then restore the grid.
        tokens = rearrange(feature_map, 'b c h w -> b (h w) c')
        if self.norm:
            tokens = self.norm(tokens)
        return rearrange(tokens, 'b (h w) c -> b c h w', h=height, w=width)


class VisionTransformer(nn.Module):
    """One CvT stage: convolutional token embedding followed by transformer blocks.

    ``kwargs`` must contain ``with_cls_token`` and is forwarded to every
    ``Block``. ``forward`` returns the stage output as a ``b c h w`` map plus
    the classification tokens (``None`` when the stage has no such token).
    """

    def __init__(self,
                 patch_size=16,
                 patch_stride=16,
                 patch_padding=0,
                 in_chans=3,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 init='trunc_norm',
                 **kwargs):
        super().__init__()
        # num_features mirrors embed_dim for consistency with other models.
        self.num_features = embed_dim
        self.embed_dim = embed_dim

        self.rearrage = None

        self.patch_embed = ConvEmbed(
            patch_size=patch_size,
            in_chans=in_chans,
            stride=patch_stride,
            padding=patch_padding,
            embed_dim=embed_dim,
            norm_layer=norm_layer,
        )

        self.cls_token = (
            nn.Parameter(torch.zeros(1, 1, embed_dim))
            if kwargs['with_cls_token']
            else None
        )

        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic depth decay rule: rates grow linearly from 0 to drop_path_rate.
        drop_path_schedule = torch.linspace(0, drop_path_rate, depth).tolist()
        self.blocks = nn.ModuleList([
            Block(
                dim_in=embed_dim,
                dim_out=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=rate,
                act_layer=act_layer,
                norm_layer=norm_layer,
                **kwargs
            )
            for rate in drop_path_schedule
        ])

        if self.cls_token is not None:
            trunc_normal_(self.cls_token, std=.02)

        initializer = (
            self._init_weights_xavier
            if init == 'xavier'
            else self._init_weights_trunc_normal
        )
        self.apply(initializer)

    def _init_weights_trunc_normal(self, m):
        """Truncated-normal init for Linear weights; unit/zero affine for norm layers."""
        if isinstance(m, nn.Linear):
            logging.info('=> init weight of Linear from trunc norm')
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                logging.info('=> init bias of Linear to zeros')
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _init_weights_xavier(self, m):
        """Xavier-uniform init for Linear weights; unit/zero affine for norm layers."""
        if isinstance(m, nn.Linear):
            logging.info('=> init weight of Linear from xavier uniform')
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                logging.info('=> init bias of Linear to zeros')
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        feature_map = self.patch_embed(x)
        batch, _, height, width = feature_map.size()
        tokens = rearrange(feature_map, 'b c h w -> b (h w) c')

        has_cls = self.cls_token is not None
        cls_tokens = None
        if has_cls:
            # Broadcast the single learned token across the batch and prepend it.
            cls_tokens = self.cls_token.expand(batch, -1, -1)
            tokens = torch.cat((cls_tokens, tokens), dim=1)

        tokens = self.pos_drop(tokens)

        for blk in self.blocks:
            tokens = blk(tokens, height, width)

        if has_cls:
            cls_tokens, tokens = torch.split(tokens, [1, height * width], 1)
        grid = rearrange(tokens, 'b (h w) c -> b c h w', h=height, w=width)

        return grid, cls_tokens


class ConvolutionalVisionTransformer(nn.Module):
    """CvT classifier: a chain of ``VisionTransformer`` stages plus a linear head.

    Args:
        in_chans: Number of channels of the input image.
        num_classes: Size of the classification output; a value ``<= 0``
            replaces the head with ``nn.Identity``.
        act_layer: Activation class passed to every stage.
        norm_layer: Normalization class (or partial) passed to every stage and
            used for the final norm.
        init: Weight-init scheme name forwarded to every stage.
        spec: Model specification dict with ``NUM_STAGES`` and per-stage lists
            (``PATCH_SIZE``, ``DIM_EMBED``, ``DEPTH``, ...). Required.
    """

    def __init__(self,
                 in_chans=3,
                 num_classes=1000,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 init='trunc_norm',
                 spec=None):
        super().__init__()
        self.num_classes = num_classes

        self.num_stages = spec['NUM_STAGES']
        for i in range(self.num_stages):
            kwargs = {
                'patch_size': spec['PATCH_SIZE'][i],
                'patch_stride': spec['PATCH_STRIDE'][i],
                'patch_padding': spec['PATCH_PADDING'][i],
                'embed_dim': spec['DIM_EMBED'][i],
                'depth': spec['DEPTH'][i],
                'num_heads': spec['NUM_HEADS'][i],
                'mlp_ratio': spec['MLP_RATIO'][i],
                'qkv_bias': spec['QKV_BIAS'][i],
                'drop_rate': spec['DROP_RATE'][i],
                'attn_drop_rate': spec['ATTN_DROP_RATE'][i],
                'drop_path_rate': spec['DROP_PATH_RATE'][i],
                'with_cls_token': spec['CLS_TOKEN'][i],
                'method': spec['QKV_PROJ_METHOD'][i],
                'kernel_size': spec['KERNEL_QKV'][i],
                'padding_q': spec['PADDING_Q'][i],
                'padding_kv': spec['PADDING_KV'][i],
                'stride_kv': spec['STRIDE_KV'][i],
                'stride_q': spec['STRIDE_Q'][i],
            }

            stage = VisionTransformer(
                in_chans=in_chans,
                init=init,
                act_layer=act_layer,
                norm_layer=norm_layer,
                **kwargs
            )
            setattr(self, f'stage{i}', stage)

            # The next stage consumes this stage's embedding width.
            in_chans = spec['DIM_EMBED'][i]

        dim_embed = spec['DIM_EMBED'][-1]
        self.norm = norm_layer(dim_embed)
        # Boolean flag (not a parameter): does the last stage emit a cls token?
        self.cls_token = spec['CLS_TOKEN'][-1]

        # Classifier head
        self.head = nn.Linear(dim_embed, num_classes) if num_classes > 0 else nn.Identity()
        # nn.Identity has no weight, so only initialise a real linear head.
        if isinstance(self.head, nn.Linear):
            trunc_normal_(self.head.weight, std=0.02)

    def init_weights(self, pretrained='', pretrained_layers=(), verbose=True):
        """Load selected top-level layers from a checkpoint file, if it exists.

        Args:
            pretrained: Path to a checkpoint holding a state dict. Nothing
                happens when the path is not an existing file.
            pretrained_layers: Names of top-level submodules to load. A first
                element of ``'*'`` selects every layer; an empty sequence
                selects none.
            verbose: Log each loaded key when true.
        """
        if not os.path.isfile(pretrained):
            return

        # NOTE(review): torch.load unpickles the file -- only point this at
        # checkpoints from a trusted source.
        pretrained_dict = torch.load(pretrained, map_location='cpu')
        logging.info(f'=> loading pretrained model {pretrained}')
        model_dict = self.state_dict()
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items()
            if k in model_dict
        }

        # Compare by value and tolerate an empty selection (the default).
        load_all = bool(pretrained_layers) and pretrained_layers[0] == '*'

        need_init_state_dict = {}
        for k, v in pretrained_dict.items():
            if not (load_all or k.split('.')[0] in pretrained_layers):
                continue

            if verbose:
                logging.info(f'=> init {k} from {pretrained}')

            if 'pos_embed' in k and v.size() != model_dict[k].size():
                # Imported here because the file's top-level imports never
                # bind the bare `scipy` name.
                import scipy.ndimage

                size_pretrained = v.size()
                size_new = model_dict[k].size()
                logging.info(
                    '=> load_pretrained: resized variant: {} to {}'
                    .format(size_pretrained, size_new)
                )

                # Token count without the leading classification token.
                ntok_new = size_new[1] - 1

                posemb_tok, posemb_grid = v[:, :1], v[0, 1:]

                gs_old = int(np.sqrt(len(posemb_grid)))
                gs_new = int(np.sqrt(ntok_new))

                logging.info(
                    '=> load_pretrained: grid-size from {} to {}'
                    .format(gs_old, gs_new)
                )

                # Linearly interpolate the position grid to the new size.
                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                posemb_grid = scipy.ndimage.zoom(
                    posemb_grid, zoom, order=1
                )
                posemb_grid = posemb_grid.reshape(1, gs_new ** 2, -1)
                v = torch.tensor(
                    np.concatenate([posemb_tok, posemb_grid], axis=1)
                )

            need_init_state_dict[k] = v
        self.load_state_dict(need_init_state_dict, strict=False)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return parameter names that should be excluded from weight decay."""
        layers = set()
        for i in range(self.num_stages):
            layers.add(f'stage{i}.pos_embed')
            layers.add(f'stage{i}.cls_token')

        return layers

    def forward_features(self, x):
        """Run all stages and pool the result to one feature vector per sample."""
        for i in range(self.num_stages):
            x, cls_tokens = getattr(self, f'stage{i}')(x)

        if self.cls_token:
            x = self.norm(cls_tokens)
            # Drop only the token axis; a bare squeeze() would also remove the
            # batch axis when the batch size is 1.
            x = torch.squeeze(x, dim=1)
        else:
            x = rearrange(x, 'b c h w -> b (h w) c')
            x = self.norm(x)
            x = torch.mean(x, dim=1)

        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)

        return x

In [None]:
def create_cvt_13(config, num_classes=100):
    """Instantiate the CvT model described by ``config['MODEL']['SPEC']``.

    Uses GELU activations and the fp16-tolerant ``LayerNorm`` with
    ``eps=1e-5``; the weight-init scheme is taken from the spec's ``INIT``.
    """
    model_spec = config['MODEL']['SPEC']
    return ConvolutionalVisionTransformer(
        in_chans=3,
        num_classes=num_classes,
        act_layer=nn.GELU,
        norm_layer=partial(LayerNorm, eps=1e-5),
        init=model_spec['INIT'],
        spec=model_spec,
    )


# Notebook-level model instance with a 100-way classification head.
model = create_cvt_13(config, num_classes=100)

In [None]:
def count_parameters(model):
    """Print a parameter summary and return ``(total, trainable)`` counts.

    A parameter counts as trainable when its ``requires_grad`` flag is set.
    """
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        size = param.numel()
        total_params += size
        if param.requires_grad:
            trainable_params += size

    rule = "=" * 60
    print(rule)
    print("Model Parameter Count")
    print(rule)
    print(f"Total parameters: {total_params:,} ({total_params/1e6:.2f}M)")
    print(f"Trainable parameters: {trainable_params:,} ({trainable_params/1e6:.2f}M)")
    print(rule)

    return total_params, trainable_params

# Count parameters of the notebook-level `model` (prints a summary and keeps both totals)
total_params, trainable_params = count_parameters(model)

## Setup Optimizer 

In [None]:
def _is_depthwise(m):
    return (
        isinstance(m, nn.Conv2d)
        and m.groups == m.in_channels
        and m.groups == m.out_channels
    )

def set_wd(cfg, model):
    """Split the trainable parameters of ``model`` into decay / no-decay groups.

    A parameter goes to the no-decay group when any of the following holds
    (checked in this order):

    * its name equals or contains an entry of ``model.no_weight_decay()``, or
      contains an entry of ``model.no_weight_decay_keywords()`` (both hooks
      are optional);
    * it is the weight of a depthwise convolution and ``'dw'`` is listed in
      ``cfg['TRAIN']['WITHOUT_WD_LIST']``;
    * it is the weight or bias of a BatchNorm2d / GroupNorm / LayerNorm module
      and the matching tag (``'bn'`` / ``'gn'`` / ``'ln'``) is listed;
    * its name ends with ``.bias`` and ``'bias'`` is listed.

    Every parameter lands in exactly one group; parameters whose
    ``requires_grad`` is False are left out entirely.

    Args:
        cfg: Config mapping. Reads ``cfg['TRAIN']['WITHOUT_WD_LIST']`` and the
            optional ``cfg['DEBUG']['DEBUG']`` flag (False when absent).
        model: Object exposing ``modules()`` and ``named_parameters()``.

    Returns:
        A list of two optimizer param-group dicts: the decayed parameters,
        then the non-decayed ones with ``'weight_decay': 0.``.
    """
    without_decay_list = cfg["TRAIN"]["WITHOUT_WD_LIST"]
    # The debug flag is optional; a missing DEBUG section must not crash.
    debug = bool((cfg.get('DEBUG') or {}).get('DEBUG', False))

    norm_tags = (
        (nn.BatchNorm2d, 'bn'),
        (nn.GroupNorm, 'gn'),
        (nn.LayerNorm, 'ln'),
    )

    # Track parameters by identity so the per-parameter lookup below is O(1).
    depthwise_ids = set()
    norm_ids = set()
    for m in model.modules():
        if _is_depthwise(m) and 'dw' in without_decay_list:
            depthwise_ids.add(id(m.weight))
            continue
        for norm_cls, tag in norm_tags:
            if isinstance(m, norm_cls) and tag in without_decay_list:
                # weight/bias may be None for non-affine norm layers.
                norm_ids.update(
                    id(t) for t in (m.weight, m.bias) if t is not None
                )
                break

    skip = set()
    if hasattr(model, 'no_weight_decay'):
        skip = model.no_weight_decay()

    skip_keys = set()
    if hasattr(model, 'no_weight_decay_keywords'):
        skip_keys = model.no_weight_decay_keywords()

    with_decay = []
    without_decay = []

    for n, p in model.named_parameters():
        if p.requires_grad is False:
            continue

        # An exact name match is also a substring match, so one test covers
        # both; any() guarantees the parameter is appended at most once.
        if any(key in n for key in skip) or any(key in n for key in skip_keys):
            print('=> set {} wd to 0'.format(n))
            without_decay.append(p)
        elif id(p) in depthwise_ids:
            if debug:
                print('=> set depthwise({}) wd to 0'.format(n))
            without_decay.append(p)
        elif id(p) in norm_ids:
            if debug:
                print('=> set norm({}) wd to 0'.format(n))
            without_decay.append(p)
        elif 'bias' in without_decay_list and n.endswith('.bias'):
            if debug:
                print('=> set bias({}) wd to 0'.format(n))
            without_decay.append(p)
        else:
            with_decay.append(p)

    params = [
        {'params': with_decay},
        {'params': without_decay, 'weight_decay': 0.}
    ]
    return params
def build_optimizer(config, model):
    """Create an AdamW optimizer over the parameter groups from ``set_wd``.

    The learning rate and default weight decay are read from
    ``config['TRAIN']['LR']`` and ``config['TRAIN']['WD']``; the chosen values
    are printed for reference.
    """
    param_groups = set_wd(config, model)
    train_cfg = config['TRAIN']
    optimizer = torch.optim.AdamW(
        param_groups,
        lr=train_cfg['LR'],
        weight_decay=train_cfg['WD'],
    )
    print("Optimizer: AdamW")
    print(f"  LR: {config['TRAIN']['LR']}")
    print(f"  Weight Decay: {config['TRAIN']['WD']}")
    return optimizer
# Build the optimizer for the model created above.
optimizer = build_optimizer(config, model)


## Setup Scheduler

In [None]:
from timm.scheduler.cosine_lr import CosineLRScheduler

def get_scheduler(config, optimizer, n_iter_per_epoch):
    """Create an epoch-based cosine LR scheduler with a warmup prefix.

    Args:
        config: Reads ``config['TRAIN']['END_EPOCH']`` and the optional
            ``warmup_epochs`` (default 5), ``warmup_lr`` (default 1e-6) and
            ``min_lr`` (default 1e-5) entries of
            ``config['TRAIN']['LR_SCHEDULER']['ARGS']``.
        optimizer: Optimizer whose learning rate is scheduled.
        n_iter_per_epoch: Accepted for interface compatibility; not used,
            since the scheduler is configured with ``t_in_epochs=True``.

    Returns:
        The configured ``CosineLRScheduler``.
    """
    train_cfg = config['TRAIN']
    sched_args = train_cfg['LR_SCHEDULER']['ARGS']
    num_epochs = train_cfg['END_EPOCH']

    warmup_epochs = sched_args.get('warmup_epochs', 5)
    warmup_lr = sched_args.get('warmup_lr', 1e-6)
    min_lr = sched_args.get('min_lr', 1e-5)

    scheduler_kwargs = dict(
        t_initial=num_epochs,
        lr_min=min_lr,
        warmup_t=warmup_epochs,
        warmup_lr_init=warmup_lr,
        warmup_prefix=True,
        cycle_limit=1,
        t_in_epochs=True,
    )
    scheduler = CosineLRScheduler(optimizer, **scheduler_kwargs)

    summary_lines = (
        "Scheduler: Cosine with warmup",
        f"  Total epochs: {num_epochs}",
        f"  Warmup epochs: {warmup_epochs}",
        f"  Warmup LR: {warmup_lr}",
        f"  Min LR: {min_lr}",
    )
    for line in summary_lines:
        print(line)

    return scheduler

# Build the LR scheduler; the batch count is passed along but the schedule is epoch-based.
scheduler = get_scheduler(config, optimizer, len(train_loader))

## Set Criterion

In [None]:
from timm.data.mixup import Mixup
class SoftTargetCrossEntropy(nn.Module):
    """Cross-entropy against soft (probability-vector) targets.

    The per-sample loss is ``sum(-target * log_softmax(x))`` over the last
    axis; ``forward`` returns its mean over the remaining axes.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, target):
        log_probs = F.log_softmax(x, dim=-1)
        per_sample = (-target * log_probs).sum(dim=-1)
        return per_sample.mean()

def build_criterion(is_train=True):
    """Return the soft-target loss for training, plain cross-entropy otherwise."""
    return SoftTargetCrossEntropy() if is_train else nn.CrossEntropyLoss()
# Mixup / CutMix settings are read from the AUG section of the config.
aug = config['AUG']
mixup_fn = Mixup(
        mixup_alpha=aug['MIXUP'], cutmix_alpha=aug['MIXCUT'],
        cutmix_minmax=None,
        prob=aug['MIXUP_PROB'],
        label_smoothing=0.0,
        num_classes=100
    )
# Training uses the soft-target loss, evaluation uses plain cross-entropy.
# Place both on DEVICE (rather than a hard-coded .cuda()) so the notebook
# follows the same CPU fallback as the rest of the code.
criterion = build_criterion().to(DEVICE)
criterion_eval = build_criterion(is_train=False).to(DEVICE)

In [None]:
class AverageMeter:
    """Track the latest value and the running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, average, weighted sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

In [None]:
def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy (in percent) for each k in ``topk``.

    ``k`` values larger than the number of columns in ``output`` are capped
    at that number. The result is a list of 0-dim tensors, one per ``k``.
    """
    num_samples = target.size(0)
    max_k = min(max(topk), output.size()[1])

    # Indices of the max_k highest scores per sample, transposed to (max_k, batch).
    top_indices = output.topk(max_k, dim=1, largest=True, sorted=True)[1].t()
    hits = top_indices.eq(target.reshape(1, -1).expand_as(top_indices))

    results = []
    for k in topk:
        hit_count = hits[:min(k, max_k)].reshape(-1).float().sum(0)
        results.append(hit_count * 100. / num_samples)
    return results

In [None]:
def _clip_and_step(model, optimizer, scaler):
    """Clamp and norm-clip the (already unscaled) gradients, then step the optimizer."""
    for param in model.parameters():
        if param.grad is not None:
            param.grad.data.clamp_(-0.5, 0.5)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()


def train_one_epoch(epoch, model, train_loader, gaussian_loader, criterion, optimizer, scheduler,
                    config, scaler, mixup_fn):
    """
    Train for one epoch
    Following: https://github.com/microsoft/CvT/blob/main/lib/core/function.py

    Batches come from ``gaussian_loader`` while ``epoch`` is below
    ``config['TRAIN']['BLUR']['EPOCHS']`` and from ``train_loader`` afterwards.
    Every batch is passed through ``mixup_fn``; the loss is divided by
    ``config['TRAIN']['GRADIENT_ACCUMULATION_STEPS']`` (default 1) and an
    optimizer step is taken once per that many batches, plus one final step
    for any leftover batches. The scheduler is stepped to ``epoch + 1`` at
    the end.

    Returns:
        ``(average loss, average top-1 accuracy in percent)`` over the epoch,
        where accuracy is measured against the argmax of the mixed targets.
    """

    losses = AverageMeter()
    acc_m = AverageMeter()

    model.train()
    accumulation_steps = config['TRAIN'].get('GRADIENT_ACCUMULATION_STEPS', 1)
    loader = gaussian_loader if epoch < config['TRAIN']['BLUR']['EPOCHS'] else train_loader
    # Size the progress bar from the loader that is actually iterated.
    batch_bar = tqdm(total=len(loader), dynamic_ncols=True, leave=False, position=0, desc='Train', ncols=5)

    # Keeps the leftover-step check below well-defined for an empty loader.
    idx = -1
    for idx, (images, targets) in enumerate(loader):
        images = images.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)

        images, targets = mixup_fn(images, targets)

        with torch.cuda.amp.autocast(enabled=True):
            outputs = model(images)
            loss = criterion(outputs, targets)
            # Scale down so the accumulated gradient matches one large batch.
            loss = loss / accumulation_steps

        scaler.scale(loss).backward()
        if (idx + 1) % accumulation_steps == 0:
            # Unscale first so clamping/clipping see the real gradient values.
            scaler.unscale_(optimizer)
            if idx % 100 == 0 and epoch >= 6:
                total_norm = 0
                for p in model.parameters():
                    if p.grad is not None:
                        total_norm += p.grad.norm().item() ** 2
                total_norm = total_norm ** 0.5
                print(f"Epoch {epoch}, Batch {idx}, Grad norm: {total_norm:.4f}")
            _clip_and_step(model, optimizer, scaler)

        # Mixed targets are soft labels; use their argmax as the reference class.
        targets_for_acc = torch.argmax(targets, dim=1)

        acc = accuracy(outputs, targets_for_acc)

        # Undo the accumulation scaling so the logged loss is per-batch.
        losses.update(loss.item()*accumulation_steps, images.size(0))
        acc_m.update(acc[0].item(), images.size(0))
        batch_bar.set_postfix(
            acc="{:.02f}% ({:.02f}%)".format(acc[0].item(), acc_m.avg),
            loss="{:.04f} ({:.04f})".format(loss.item()*accumulation_steps, losses.avg),
            lr="{:.06f}".format(float(optimizer.param_groups[0]['lr'])))

        batch_bar.update() # Update tqdm bar

    # Apply gradients left over when the batch count is not a multiple of
    # accumulation_steps.
    if (idx + 1) % accumulation_steps != 0:
        scaler.unscale_(optimizer)
        _clip_and_step(model, optimizer, scaler)

    batch_bar.close()

    scheduler.step(epoch + 1)
    torch.cuda.empty_cache()

    return losses.avg, acc_m.avg

In [None]:
@torch.no_grad()
def validate(model, val_loader, criterion, config):
    """Evaluate ``model`` on ``val_loader`` with gradients and autocast disabled.

    ``config`` is accepted but not read. Prints the average accuracy and
    returns ``(average loss, average top-1 accuracy in percent)``.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    model.eval()
    progress = tqdm(total=len(val_loader), dynamic_ncols=True, position=0, leave=False, desc='Val.', ncols=5)

    for images, targets in val_loader:
        images = images.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)
        batch_size = images.size(0)

        with torch.cuda.amp.autocast(enabled=False):
            outputs = model(images)
            loss = criterion(outputs, targets)

        batch_loss = loss.item()
        batch_acc = accuracy(outputs, targets)[0].item()
        loss_meter.update(batch_loss, batch_size)
        acc_meter.update(batch_acc, batch_size)

        progress.set_postfix(
            acc="{:.02f}% ({:.02f}%)".format(batch_acc, acc_meter.avg),
            loss="{:.04f} ({:.04f})".format(batch_loss, loss_meter.avg))
        progress.update()

    progress.close()
    print(f' * Acc {acc_meter.avg:.3f}')

    return loss_meter.avg, acc_meter.avg

In [None]:
def save_model(model, optimizer, scheduler, epoch, path):
    """Write the model, optimizer and scheduler state dicts plus ``epoch`` to ``path``."""
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'epoch': epoch,
    }
    torch.save(checkpoint, path)


def load_model(model, optimizer=None, scheduler=None, path='./checkpoint.pth'):
    """Restore state from a checkpoint in the ``save_model`` format.

    Args:
        model: Receives ``checkpoint['model_state_dict']``.
        optimizer: If given, receives ``checkpoint['optimizer_state_dict']``.
        scheduler: If given, receives ``checkpoint['scheduler_state_dict']``.
        path: Checkpoint file to read.

    Returns:
        ``(model, optimizer, scheduler, epoch)``; ``optimizer`` / ``scheduler``
        are returned as passed in (``None`` when not supplied).
    """
    # Deserialize onto the CPU so a checkpoint saved on a GPU also loads on a
    # CPU-only host; load_state_dict then copies the values into the existing
    # tensors on whatever device they live on.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return model, optimizer, scheduler, checkpoint['epoch']

In [None]:
# Make sure the checkpoint directory exists before training writes to it.
os.makedirs(config['OUTPUT_DIR'], exist_ok=True)
model = model.to(DEVICE)
# Gradient scaler used by the mixed-precision training loop.
scaler = torch.cuda.amp.GradScaler()

In [None]:
# Log in to Weights & Biases with the key read from the WANDB_API_KEY environment variable.
wandb.login(key=os.environ.get('WANDB_API_KEY')) # The API key is in your wandb account, under settings (wandb.ai/settings)


In [None]:
# Start a W&B run for this experiment and attach the full config to it.
run = wandb.init(
    name = "idl-project-cvt-13-imagenet-blur-1", ## Wandb creates random run names if you skip this field
    reinit = True, ### Allows reinitializing runs when you re-run this cell
#     id = "t0xoatlu", # Insert a specific run id here if you want to resume a previous run
#     resume = "must", ### You need this to resume previous runs, but comment out reinit = True when using this
    project = "idl-project", ### Project should be created in your wandb account
    config = config ### Wandb config for your run
)

In [None]:
# Collect unreferenced Python objects first, then release PyTorch's cached CUDA memory.
gc.collect() # These commands help when you hit a CUDA OOM error
torch.cuda.empty_cache()

In [None]:
# Per-epoch history, filled in by the loop below.
train_losses = []
train_accs = []
val_losses = []
val_accs = []
# Start at +inf so the first epoch always counts as an improvement. (A -1
# sentinel combined with min() would stay at -1 forever and mark every epoch
# as "best".)
best_loss = float('inf')
start_epoch = config['TRAIN']['BEGIN_EPOCH']
end_epoch = config['TRAIN']['END_EPOCH']

print("Starting Training")
print(f"Epochs: {start_epoch} -> {end_epoch}")
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Device: {DEVICE}")

for epoch in range(start_epoch, end_epoch):
    # epoch
    print("\nEpoch {}/{}".format(epoch+1, end_epoch))
    train_loss, train_acc = train_one_epoch(epoch, model, train_loader, gaussian_loader, criterion, optimizer, scheduler,
                    config, scaler, mixup_fn)

    val_loss, val_acc = validate(model, val_loader, criterion_eval, config)

    train_losses.append(train_loss)
    train_accs.append(train_acc)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    # Only a strictly lower validation loss replaces the best checkpoint.
    is_best = val_loss < best_loss
    if is_best:
        best_loss = val_loss
    # Periodic snapshot every 10 epochs, independent of the best checkpoint.
    if epoch % 10 == 0:
        save_model(model, optimizer, scheduler, epoch, os.path.join(config['OUTPUT_DIR'], f'{epoch}.pth'))
    if is_best:
        save_model(model, optimizer, scheduler, epoch, os.path.join(config['OUTPUT_DIR'], 'best.pth'))

    print(f"Epoch {epoch+1} Summary:")
    print(f"  Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    print(f"  Train Acc: {train_acc:.3f}, Val Acc: {val_acc:.3f}")
    metrics = {'train_loss': train_loss, 'val_loss': val_loss, 'train_acc': train_acc, 'val_acc': val_acc, 'epoch': epoch}
    if run is not None:
        run.log(metrics)

In [None]:
# Reload the 'best.pth' checkpoint from the directory the training loop saves
# to, instead of repeating that directory as a hard-coded path.
model, optimizer, scheduler, epoch = load_model(
    model, optimizer, scheduler,
    path=os.path.join(config['OUTPUT_DIR'], 'best.pth'),
)

In [None]:
@torch.no_grad()
def test(model, test_loader, config):
    """Run ``model`` over ``test_loader`` and collect predictions and metrics.

    ``config`` is accepted but not read. Returns a dict with the NumPy arrays
    ``predictions`` (top-1 class per sample), ``targets`` and ``probabilities``
    (softmax of the outputs), plus the average cross-entropy ``loss`` and the
    average top-1 accuracy ``top1_acc`` in percent.
    """
    model.eval()

    loss_fn = torch.nn.CrossEntropyLoss()
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()

    pred_list, target_list, prob_list = [], [], []

    for images, targets in tqdm(test_loader, desc='Testing'):
        images = images.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)
        batch_size = images.size(0)

        with torch.cuda.amp.autocast(enabled=True):
            outputs = model(images)
            loss = loss_fn(outputs, targets)

        probs = torch.softmax(outputs, dim=1)
        top1_idx = outputs.topk(1, 1, True, True)[1]

        pred_list.extend(top1_idx.cpu().numpy().flatten())
        target_list.extend(targets.cpu().numpy())
        prob_list.extend(probs.cpu().numpy())

        batch_acc = accuracy(outputs, targets)[0].item()
        loss_meter.update(loss.item(), batch_size)
        top1_meter.update(batch_acc, batch_size)

    results = {
        'predictions': np.array(pred_list),
        'targets': np.array(target_list),
        'probabilities': np.array(prob_list),
        'loss': loss_meter.avg,
        'top1_acc': top1_meter.avg,
    }
    return results

# Final evaluation of the reloaded model. Note that this runs on val_loader,
# so the "Test" numbers printed below are computed on the validation data.
test_results = test(model, val_loader, config)

print("Final Test Results")
print(f"Test Loss: {test_results['loss']:.4f}")
print(f"Test Acc@1: {test_results['top1_acc']:.3f}%")
