EECS 542 Final Project - Swin Transformer

Jiangyan Feng, Pengfei Gao, Yilin Li, Zekun Li

In [None]:
!pip install timm
import os
import sys
import numpy as np
import math
import random
import cv2
import copy
from PIL import Image
# import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler, AdamW
from torch.utils.data import Dataset
import torchvision
from torchvision import datasets, transforms
from torchvision.transforms import InterpolationMode
from timm.models.layers import to_2tuple
from timm.scheduler.cosine_lr import CosineLRScheduler
from timm.scheduler.step_lr import StepLRScheduler
from timm.loss import SoftTargetCrossEntropy
from timm.data import Mixup
from timm.data import create_transform

# Report library versions and select the compute device: the first CUDA GPU
# when one is available, otherwise the CPU. `device` is a notebook-level global
# that later helpers (e.g. train_one_epoch, validate) read directly.
print("PyTorch Version: ", torch.__version__)
print("Torchvision Version: ", torchvision.__version__)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


Collecting timm
  Downloading timm-0.5.4-py3-none-any.whl (431 kB)
[?25l[K     |▊                               | 10 kB 27.1 MB/s eta 0:00:01[K     |█▌                              | 20 kB 34.1 MB/s eta 0:00:01[K     |██▎                             | 30 kB 17.8 MB/s eta 0:00:01[K     |███                             | 40 kB 13.3 MB/s eta 0:00:01[K     |███▉                            | 51 kB 11.7 MB/s eta 0:00:01[K     |████▋                           | 61 kB 13.5 MB/s eta 0:00:01[K     |█████▎                          | 71 kB 12.6 MB/s eta 0:00:01[K     |██████                          | 81 kB 12.8 MB/s eta 0:00:01[K     |██████▉                         | 92 kB 13.9 MB/s eta 0:00:01[K     |███████▋                        | 102 kB 15.0 MB/s eta 0:00:01[K     |████████▍                       | 112 kB 15.0 MB/s eta 0:00:01[K     |█████████▏                      | 122 kB 15.0 MB/s eta 0:00:01[K     |█████████▉                      | 133 kB 15.0 MB/s eta 0:00:01

In [None]:
# Mount Google Drive (Colab only) so the project folder used later,
# /content/drive/MyDrive/EECS542_project/, is reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


1. Patch partition, linear embedding and patch merging (Zekun Li)

In [None]:
class PatchEmbed(nn.Module):
  """Split an image into non-overlapping patches and linearly embed each one.

  Implemented as a Conv2d whose kernel size and stride both equal patch_size,
  which applies one shared linear map to every patch.

  Args:
    patch_size (int): side length of a square patch.
    in_c (int): number of input image channels.
    output_dim (int): embedding dimension of every patch token.
    norm: accepted for interface compatibility but ignored; a LayerNorm over
      `output_dim` is always applied to the tokens.
  """
  def __init__(self, patch_size=4, in_c=3, output_dim=96, norm=None):
    super(PatchEmbed, self).__init__()
    self.patch_size = patch_size
    self.linear = nn.Conv2d(in_c, output_dim, kernel_size=self.patch_size, stride=self.patch_size, padding=0)
    self.norm = nn.LayerNorm(output_dim)

  def forward(self, x):
    """x: (B, C, H, W) -> tokens (B, ceil(H/p) * ceil(W/p), output_dim)."""
    _, _, H, W = x.shape
    # Pad each spatial side only up to the next multiple of patch_size.
    # (-n) % p is 0 when n is already a multiple, so an aligned side gets no padding.
    pad_h = (-H) % self.patch_size
    pad_w = (-W) % self.patch_size
    if pad_h or pad_w:
      x = F.pad(x, (0, pad_w, 0, pad_h))   # (left, right) on W, (top, bottom) on H
    x = self.linear(x)                     # (B, output_dim, H', W')
    x = x.flatten(2).transpose(1, 2)       # (B, H'*W', output_dim)
    if self.norm is not None:
      x = self.norm(x)
    return x

class PatchMerging(nn.Module):
  """Downsample a token grid by 2x in each spatial direction.

  The four tokens of every 2x2 neighbourhood are concatenated along the channel
  axis (C -> 4C), normalised with LayerNorm and projected to 2C without bias.

  Args:
    in_c (int): channel count C of the incoming tokens.
    norm: accepted but unused; a LayerNorm over 4 * in_c is always built.
  """
  def __init__(self, in_c, norm=nn.LayerNorm):
    super(PatchMerging, self).__init__()
    widened = in_c * 4
    self.norm = nn.LayerNorm(widened)
    self.linear = nn.Linear(widened, in_c * 2, bias=False)

  def forward(self, x, H, W):
    """x: (B, H*W, C) tokens laid out row-major on an H x W grid.

    Returns (B, ceil(H/2) * ceil(W/2), 2C).
    """
    batch, _, chans = x.shape
    grid = x.view(batch, H, W, chans)
    if H % 2 or W % 2:
      # pad one extra row at the bottom / column at the right when odd
      grid = F.pad(grid, (0, 0, 0, W % 2, 0, H % 2))
    even_rows = grid[:, 0::2]
    odd_rows = grid[:, 1::2]
    pieces = (
        even_rows[:, :, 0::2],   # top-left of each 2x2 block
        odd_rows[:, :, 0::2],    # bottom-left
        even_rows[:, :, 1::2],   # top-right
        odd_rows[:, :, 1::2],    # bottom-right
    )
    merged = torch.cat(pieces, -1).view(batch, -1, chans * 4)
    return self.linear(self.norm(merged))

2. Swin Transformer block: basic building blocks (Yilin Li)

In [None]:
class Mlp(nn.Module):
    """
    Two-layer feed-forward network:
    Linear -> activation -> Dropout -> Linear -> Dropout.

    Args:
        in_features: size of the input feature dimension
        hidden_features: width of the hidden layer (falls back to in_features)
        out_features: size of the output feature dimension (falls back to in_features)
        act_layer: activation class, instantiated without arguments (GELU by default)
        drop: dropout probability used after the activation and after fc2
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_dim = in_features if not out_features else out_features
        hidden_dim = in_features if not hidden_features else hidden_features
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """
        Args:
            x: input features whose last dimension is in_features
        """
        # the same Dropout module is applied twice, as in the standard ViT MLP
        for stage in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = stage(x)
        return x

def window_partition(x, window_size):
    """
    Cut a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C); H and W are expected to be multiples of window_size
        window_size (int): side length of each window

    Returns:
        windows: (num_windows*B, window_size, window_size, C), ordered by batch,
        then window row, then window column
    """
    _, height, width, channels = x.shape
    ws = window_size
    blocks = x.reshape(-1, height // ws, ws, width // ws, ws, channels)
    # move the two window-index axes in front of the two in-window axes
    blocks = blocks.permute(0, 1, 3, 2, 4, 5)
    return blocks.reshape(-1, ws, ws, channels)

def window_reverse(windows, window_size, H, W):
    """
    Stitch windows back into a full feature map; inverse of window_partition.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): side length of each window
        H (int): height of the reconstructed map
        W (int): width of the reconstructed map

    Returns:
        x: (B, H, W, C)
    """
    ws = window_size
    rows, cols = H // ws, W // ws
    _, _, _, channels = windows.shape
    grid = windows.reshape(-1, rows, cols, ws, ws, channels)
    # interleave window-index and in-window axes back into (row, col) order
    grid = grid.permute(0, 1, 3, 2, 4, 5)
    return grid.reshape(-1, H, W, channels)

class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    During training, each sample in the batch has its whole branch zeroed with
    probability drop_prob; in eval mode the input is returned unchanged.

    Args:
        drop_prob (float | None): probability that a sample's path is dropped.
            None or 0.0 disables dropping.
        scale_by_keep (bool): if True, surviving samples are divided by the keep
            probability so the expected output equals the input.
    """
    def __init__(self, drop_prob=None, scale_by_keep=True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        """
        Args:
            x: input tensor whose first dimension is the batch

        Returns:
            x itself when dropping is disabled or in eval mode; otherwise x with
            each sample either zeroed or (optionally) rescaled.
        """
        # `not self.drop_prob` treats both None (the default) and 0.0 as "disabled";
        # checking only `== 0.` let DropPath() reach `1 - None` in training mode.
        if not self.drop_prob or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # one Bernoulli draw per sample, broadcast over all remaining dimensions
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
        if keep_prob > 0.0 and self.scale_by_keep:
            random_tensor.div_(keep_prob)
        return x * random_tensor

def relative_position(window_size):
    """
    Get the pair-wise relative-position index for every token pair inside a window.

    Args:
        window_size (tuple[int]): (Wh, Ww), the height and width of the window.

    Returns:
        relative_position_index: LongTensor of shape (Wh*Ww, Wh*Ww) whose values lie
        in [0, (2*Wh-1)*(2*Ww-1)), i.e. row indices into a bias table of that size.
    """
    coords_h = torch.arange(window_size[0])
    coords_w = torch.arange(window_size[1])
    # Explicit "ij" indexing silences the meshgrid deprecation warning on newer
    # torch; versions without the keyword raise TypeError but already default to "ij".
    try:
        grid = torch.meshgrid(coords_h, coords_w, indexing="ij")
    except TypeError:
        grid = torch.meshgrid(coords_h, coords_w)
    coords = torch.stack(grid)                       # (2, Wh, Ww)
    coords_flatten = torch.flatten(coords, 1)        # (2, N), N = Wh*Ww
    # (dh, dw) offset between every pair of tokens
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # (2, N, N)
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()            # (N, N, 2)
    # shift offsets from [-(W-1), W-1] to [0, 2W-2] so they are non-negative
    relative_coords[:, :, 0] += window_size[0] - 1
    relative_coords[:, :, 1] += window_size[1] - 1
    # row-major flattening of the (2Wh-1) x (2Ww-1) offset grid
    relative_coords[:, :, 0] *= 2 * window_size[1] - 1
    relative_position_index = relative_coords.sum(-1)
    return relative_position_index

def trunc_normal(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill `tensor` in place with samples from N(mean, std^2) truncated to [a, b].

    Inverse-CDF sampling: draw uniformly between the CDF values of the two bounds
    (expressed in the erf domain), map back through erfinv, then rescale/shift.
    Note that `a` and `b` are absolute bounds, not multiples of `std`.

    Returns:
        the same tensor object, modified in place.
    """
    sqrt2 = math.sqrt(2.)

    def _std_normal_cdf(z):
        # cumulative distribution function of the standard normal
        return (1. + math.erf(z / sqrt2)) / 2.

    lo_cdf = _std_normal_cdf((a - mean) / std)
    hi_cdf = _std_normal_cdf((b - mean) / std)
    with torch.no_grad():
        # uniform on [2*lo-1, 2*hi-1]: the erf-domain image of [lo, hi]
        tensor.uniform_(2 * lo_cdf - 1, 2 * hi_cdf - 1)
        tensor.erfinv_()
        tensor.mul_(std * sqrt2)
        tensor.add_(mean)
        # guard against values landing marginally outside [a, b]
        tensor.clamp_(min=a, max=b)
    return tensor

class WindowAttention(nn.Module):
    """ W-MSA / SW-MSA: multi-head self-attention inside each window, with a
    learned relative-position bias per head.

    Args:
        dim (int): Number of input channels C.
        window_size (tuple[int]): (Wh, Ww), the height and width of the window.
        num_heads (int): Number of attention heads; C is split evenly across them.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Overrides the default scale head_dim ** -0.5 when truthy
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** (-0.5)

        # one learnable bias per head for each of the (2Wh-1)*(2Ww-1) possible offsets
        table_rows = (2 * window_size[0] - 1) * (2 * window_size[1] - 1)
        self.relative_position_bias_table = nn.Parameter(torch.zeros(table_rows, num_heads))
        # fixed (N, N) map from token pair to table row; registered as a (persistent) buffer
        self.register_buffer("relative_position_index", relative_position(self.window_size))

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal(self.relative_position_bias_table, std=.02)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C), N = Wh*Ww
            mask: additive mask (0 / large negative) with shape (num_windows, N, N), or None

        Returns:
            (num_windows*B, N, C)
        """
        n_win_batch, n_tok, chans = x.shape
        heads = self.num_heads
        # (3, num_windows*B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(n_win_batch, n_tok, 3, heads, chans // heads).permute(2, 0, 3, 1, 4)
        query, key, value = qkv[0], qkv[1], qkv[2]

        scores = (query * self.scale) @ key.transpose(-2, -1)   # (num_windows*B, heads, N, N)

        area = self.window_size[0] * self.window_size[1]
        bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        bias = bias.view(area, area, -1).permute(2, 0, 1).contiguous()  # (heads, N, N)
        scores = scores + bias.unsqueeze(0)

        if mask is not None:
            # broadcast the per-window mask over the batch and head axes
            n_masks = mask.shape[0]
            scores = scores.view(n_win_batch // n_masks, n_masks, heads, n_tok, n_tok) + mask.unsqueeze(1).unsqueeze(0)
            scores = scores.view(-1, heads, n_tok, n_tok)

        weights = self.attn_drop(self.softmax(scores))
        out = (weights @ value).transpose(1, 2).reshape(n_win_batch, n_tok, chans)
        return self.proj_drop(self.proj(out))

3. Model construction (Pengfei Gao)

In [None]:
def get_mask(input_resolution, window_size, shift_size):
    """Build the additive attention mask for shifted-window attention (SW-MSA).

    The token grid is split into 3 x 3 regions (bands of sizes H-ws, ws-shift,
    shift along each axis), each with a distinct id. Within every window, token
    pairs with different ids get -100 (effectively zero weight after softmax)
    and pairs with the same id get 0.

    Args:
        input_resolution (tuple[int, int]): (H, W) of the token grid; both are
            expected to be multiples of window_size.
        window_size (int): window side length.
        shift_size (int): cyclic shift used by the shifted windows.

    Returns:
        float tensor of shape (num_windows, window_size**2, window_size**2).
    """
    height, width = input_resolution[0], input_resolution[1]
    ws = window_size
    # the same three bands are used along both axes
    bands = (slice(0, -ws), slice(-ws, -shift_size), slice(-shift_size, None))
    region_ids = torch.zeros((1, height, width, 1))
    label = 0
    for rows in bands:
        for cols in bands:
            region_ids[:, rows, cols, :] = label
            label += 1
    # split the id map into windows -> (num_windows, ws*ws)
    win_ids = (region_ids
               .reshape(-1, height // ws, ws, width // ws, ws, 1)
               .permute(0, 1, 3, 2, 4, 5)
               .reshape(-1, ws, ws, 1)
               .view(-1, ws ** 2))
    diff = win_ids.unsqueeze(1) - win_ids.unsqueeze(2)
    return diff.masked_fill(diff != 0, float(-100.00)).masked_fill(diff == 0, float(0.00))

def normal_init(m):
    """Initialise a single module in place; intended for `nn.Module.apply`.

    - nn.Linear: weight ~ truncated normal with std 0.02, bias (if present) = 0.
    - nn.LayerNorm: weight = 1, bias = 0; a norm built without affine
      parameters (weight/bias are None) is left alone.
    - any other module is left untouched.
    """
    if isinstance(m, nn.Linear):
        trunc_normal(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        # elementwise_affine=False leaves both parameters as None
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        if m.weight is not None:
            nn.init.constant_(m.weight, 1.0)
    

class SwinTransformerBlock(nn.Module):
    """Two consecutive Swin Transformer layers: W-MSA + MLP, then SW-MSA + MLP.

    Each layer is pre-norm with residual connections:
        x = x + DropPath(Attn(LN(x)));  x = x + DropPath(MLP(LN(x)))
    The second layer cyclically shifts the token grid by -shift before window
    attention (using a region mask) and shifts it back afterwards.

    Args:
        dim (int): channel count C of the tokens.
        input_resolution (tuple[int, int]): (H, W) of the token grid; both are
            expected to be multiples of window_size.
        num_heads (int): attention heads in both layers.
        window_size (int): side length of the square attention window.
        shift_size: accepted for interface compatibility but ignored; the
            shifted layer always uses window_size // 2.
        mlp_ratio (float): MLP hidden width as a multiple of dim.
        qkv_bias, qk_scale, attn_drop: forwarded to both WindowAttention modules.
        drop (float): dropout in the MLPs and on the attention output projection.
        drop_path (float | sequence of two floats): stochastic-depth rate for the
            first and second layer; a single number is used for both.
    """
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                  mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.):
        super(SwinTransformerBlock, self).__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.window_size = window_size
        # `shift_size` is deliberately ignored: changing it would alter the
        # architecture that previously saved checkpoints were trained with.
        self.shift_size = window_size // 2
        # Non-persistent buffer: follows model.to(...) instead of being pinned to
        # the global `device`, and stays out of the state_dict so existing
        # checkpoints still load with strict=True.
        self.register_buffer(
            "attention_mask",
            get_mask(self.input_resolution, self.window_size, self.shift_size),
            persistent=False)
        self.num_heads = num_heads
        # LayerNorms: LN1/LN2 for the regular-window layer, LN3/LN4 for the shifted one
        self.LN1 = nn.LayerNorm(dim)
        self.LN2 = nn.LayerNorm(dim)
        self.LN3 = nn.LayerNorm(dim)
        self.LN4 = nn.LayerNorm(dim)
        # MLPs
        hidden = int(mlp_ratio * self.dim)
        self.MLP1 = Mlp(in_features=self.dim, hidden_features=hidden, drop=drop)
        self.MLP2 = Mlp(in_features=self.dim, hidden_features=hidden, drop=drop)
        # Window attention (regular, then shifted)
        win = to_2tuple(self.window_size)
        self.WindowAttention1 = WindowAttention(dim, window_size=win, num_heads=num_heads,
                qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.WindowAttention2 = WindowAttention(dim, window_size=win, num_heads=num_heads,
                qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # Drop path: a scalar (e.g. the 0. default) applies to both layers; indexing
        # the scalar default directly used to raise TypeError.
        if isinstance(drop_path, (int, float)):
            drop_path = (drop_path, drop_path)
        self.drop_path1 = DropPath(drop_path[0] if drop_path[0] > 0. else 0.)
        self.drop_path2 = DropPath(drop_path[1] if drop_path[1] > 0. else 0.)

    def _windowed_attention(self, x, attention, shift, mask):
        """Apply `attention` to non-overlapping windows of the (B, H, W, C) map `x`.

        When `shift` is non-zero the map is rolled by -shift before partitioning
        and rolled back afterwards. Returns tokens of shape (B, H*W, C).
        """
        H, W = self.input_resolution
        B, _, _, C = x.shape
        if shift:
            x = torch.roll(x, shifts=(-shift, -shift), dims=(1, 2))
        windows = window_partition(x, self.window_size).reshape(-1, self.window_size ** 2, C)
        windows = attention(windows, mask=mask)
        windows = windows.reshape(-1, self.window_size, self.window_size, C)
        x = window_reverse(windows, self.window_size, H, W)
        if shift:
            x = torch.roll(x, shifts=(shift, shift), dims=(1, 2))
        return x.reshape(B, H * W, C)

    def forward(self, x):
        """x: (B, H*W, C) tokens -> (B, H*W, C)."""
        H, W = self.input_resolution
        B, L, C = x.shape

        # layer 1: regular windows (W-MSA), no mask
        shortcut = x
        x = self.LN1(x).reshape(B, H, W, C)
        x = shortcut + self.drop_path1(self._windowed_attention(x, self.WindowAttention1, 0, None))
        x = x + self.drop_path1(self.MLP1(self.LN2(x)))

        # layer 2: shifted windows (SW-MSA) with the region mask
        shortcut = x
        x = self.LN3(x).reshape(B, H, W, C)
        x = shortcut + self.drop_path2(
            self._windowed_attention(x, self.WindowAttention2, self.shift_size, self.attention_mask))
        x = x + self.drop_path2(self.MLP2(self.LN4(x)))
        return x


class SwinTransformerModel(nn.Module):
    """Swin Transformer image classifier with four stages (Swin-T layout by default).

    Pipeline: PatchEmbed -> block1 -> merge -> block2 -> merge ->
    block3 (depths[2] // 2 blocks) -> merge -> block4 -> LayerNorm -> mean over
    tokens -> head.  Every SwinTransformerBlock holds two layers (W-MSA + SW-MSA),
    so stages 1, 2 and 4 support depth 2 only and stage 3 needs an even depth.

    Args:
        img_size (int): side of the square input image.
        patch_size (int): side of the square patches used for the embedding.
        in_chans (int): input image channels.
        num_classes (int): number of output classes; 0 turns the head into Identity.
        embed_dim (int): channels after patch embedding (doubled at every merge).
        depths (list[int]): transformer layers per stage, e.g. [2, 2, 6, 2].
        num_heads (list[int]): attention heads per stage.
        window_size, mlp_ratio, qkv_bias, qk_scale: forwarded to every block.
        drop_rate, attn_drop_rate: dropout rates forwarded to every block.
        drop_path_rate (float): largest stochastic-depth rate; rates grow
            linearly from 0 over all layers.
        norm_layer, ape, patch_norm, use_checkpoint, **kwargs: accepted for
            interface compatibility but unused.

    Raises:
        ValueError: if `depths` is not of the form [2, 2, 2k, 2] with k >= 1.
    """
    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                  embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                  window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                  norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                  use_checkpoint=False, **kwargs):
        super(SwinTransformerModel, self).__init__()
        # Only stage 3 is a list of blocks and each block is a fixed pair of
        # layers, so other layouts would be silently built wrong.
        if (len(depths) != 4 or depths[0] != 2 or depths[1] != 2 or depths[3] != 2
                or depths[2] < 2 or depths[2] % 2):
            raise ValueError(f"depths must look like [2, 2, 2k, 2] with k >= 1, got {depths}")
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = [self.img_size // self.patch_size, self.img_size // self.patch_size]
        self.embed_dim = embed_dim
        self.patch_embed = PatchEmbed(patch_size=self.patch_size, in_c=in_chans, output_dim=self.embed_dim)
        # channels after the last stage: embed_dim doubles at each of the len(depths)-1 merges
        self.num_features = int(self.embed_dim * 2 ** (len(depths) - 1))
        self.input_resolution = (self.patches_resolution[0], self.patches_resolution[1])
        # index in dpr of the first transformer layer of each stage
        stage_start = [sum(depths[:i]) for i in range(len(depths))]

        def make_block(stage, first_layer):
            # one W-MSA + SW-MSA pair for `stage`, using dpr[first_layer:first_layer + 2]
            scale = 2 ** stage
            return SwinTransformerBlock(
                dim=int(self.embed_dim * scale),
                input_resolution=(self.patches_resolution[0] // scale, self.patches_resolution[1] // scale),
                num_heads=num_heads[stage],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[first_layer:first_layer + 2])

        self.block1 = make_block(0, stage_start[0])
        self.block2 = make_block(1, stage_start[1])
        self.block3 = nn.ModuleList([make_block(2, stage_start[2] + 2 * i)
                                     for i in range(depths[2] // 2)])
        self.block4 = make_block(3, stage_start[3])
        self.patch_merging1 = PatchMerging(int(self.embed_dim))
        self.patch_merging2 = PatchMerging(int(self.embed_dim * 2))
        self.patch_merging3 = PatchMerging(int(self.embed_dim * 4))
        # Not used in forward(); kept so state_dicts saved from this class
        # (which contain final_layer.*) still load with strict=True.
        self.final_layer = nn.Linear(int(self.embed_dim * 8), num_classes)
        self.norm = nn.LayerNorm(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self.weight_init)

    def weight_init(self, m):
        """Per-module initialiser applied to every submodule through self.apply."""
        normal_init(m)

    def forward(self, image):
        """image: (B, in_chans, img_size, img_size) -> logits (B, num_classes)."""
        H, W = self.input_resolution
        x = self.patch_embed(image)                 # (B, H*W, C)
        x = self.block1(x)
        x = self.patch_merging1(x, H, W)            # (B, H*W/4, 2C)
        x = self.block2(x)
        x = self.patch_merging2(x, H // 2, W // 2)  # (B, H*W/16, 4C)
        for blk in self.block3:
            x = blk(x)
        x = self.patch_merging3(x, H // 4, W // 4)  # (B, H*W/64, 8C)
        x = self.block4(x)

        x = self.norm(x)
        x = self.avgpool(x.transpose(1, 2))         # (B, C, 1): mean over tokens
        x = torch.flatten(x, 1)
        return self.head(x)


4. Data processing and training (Jiangyan Feng)

In [None]:
# logger
import logging
import time

def log_creater(output_dir):
    """Create (or re-point) the shared 'train_log' logger.

    Records of level DEBUG and above go both to a StreamHandler and to a new
    file `<output_dir>/<YYYY-mm-dd-HH-MM>.log`. Any handlers left by a previous
    call are closed and replaced, so calling this again with a different
    directory (e.g. re-running a notebook cell for a new run) no longer keeps
    writing to the old run's log file.

    Args:
        output_dir (str): directory for the log file; created if missing.

    Returns:
        logging.Logger: the logger named 'train_log'.
    """
    os.makedirs(output_dir, exist_ok=True)
    logger_name = '{}.log'.format(time.strftime('%Y-%m-%d-%H-%M'))
    final_log_file = os.path.join(output_dir, logger_name)

    logger = logging.getLogger('train_log')
    logger.setLevel(logging.DEBUG)
    # detach handlers from an earlier call before attaching fresh ones
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter('[%(asctime)s][line: %(lineno)d] ==> %(message)s')
    file_handler = logging.FileHandler(final_log_file)
    stream_handler = logging.StreamHandler()
    for handler in (file_handler, stream_handler):
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    logger.info('creating {}'.format(final_log_file))
    return logger


# define transforms
def get_transforms(input_size=None):
    """Build the training and validation image transforms.

    Args:
        input_size (int | None): output image side in pixels. Defaults to the
            notebook-level `img_size`, so existing no-argument calls keep working.

    Returns:
        (train_transform, val_transform)
    """
    if input_size is None:
        input_size = img_size  # notebook-level global from the hyperparameter cell
    # ImageNet per-channel normalisation statistics, shared by both pipelines
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # Training pipeline built by timm; augmentation settings mirror the official
    # Swin training config. mean/std are passed explicitly (they equal timm's
    # ImageNet defaults) so both pipelines visibly share one normalisation.
    train_transform = create_transform(
            input_size=input_size,
            is_training=True,
            color_jitter=0.4,
            auto_augment='rand-m9-mstd0.5-inc1',
            re_prob=0.25,
            re_mode='pixel',
            re_count=1,
            interpolation='bicubic',
            mean=tuple(mean),
            std=tuple(std))

    # Validation: deterministic bicubic resize straight to input_size x input_size
    val_transform = transforms.Compose([
        transforms.Resize((input_size, input_size),
                          interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)])
    return train_transform, val_transform


# load train, val images
def get_dataloaders(path, train_transform, val_transform, batch_size=None,
                    val_batch_size=200, workers=None, train_subset=None, val_subset=None):
    """Build ImageFolder datasets for `<path>/train` and `<path>/val` and wrap them in DataLoaders.

    Args:
        path (str): dataset root containing `train/` and `val/` class sub-folders.
        train_transform, val_transform: transforms applied to each split.
        batch_size (int | None): training batch size; defaults to the
            notebook-level `batchsize`.
        val_batch_size (int): validation batch size.
        workers (int | None): DataLoader worker processes; defaults to the
            notebook-level `num_workers`.
        train_subset, val_subset (int | None): if given, randomly sample that
            many images (without replacement) from the split, e.g. for quick
            debugging runs.

    Returns:
        (dataloaders, dataset_sizes): dicts keyed by 'train' and 'val'.
    """
    if batch_size is None:
        batch_size = batchsize      # notebook-level global
    if workers is None:
        workers = num_workers       # notebook-level global

    train_path = os.path.join(path, 'train')
    val_path = os.path.join(path, 'val')
    train_data = datasets.ImageFolder(root=train_path, transform=train_transform)
    val_data = datasets.ImageFolder(root=val_path, transform=val_transform)

    # optional random subsets; indices are drawn from each split's own length
    if train_subset is not None:
        picked = np.random.choice(len(train_data), train_subset, replace=False)
        train_data = torch.utils.data.Subset(train_data, picked.tolist())
    if val_subset is not None:
        picked = np.random.choice(len(val_data), val_subset, replace=False)
        val_data = torch.utils.data.Subset(val_data, picked.tolist())
    dataset_sizes = {'train': len(train_data), 'val': len(val_data)}

    # drop_last keeps every training batch the same size
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True,
        pin_memory=True, num_workers=workers, drop_last=True)

    val_loader = torch.utils.data.DataLoader(
        val_data, batch_size=val_batch_size, shuffle=False,
        pin_memory=True, num_workers=workers, drop_last=False)

    dataloaders = {'train': train_loader, 'val': val_loader}
    return dataloaders, dataset_sizes

def get_grad_norm(parameters, norm_type=2):
    """Total gradient norm over `parameters`, computed like `clip_grad_norm_` but without clipping.

    Args:
        parameters: a tensor or an iterable of tensors; those whose `.grad` is
            None are skipped.
        norm_type: p of the p-norm; `float('inf')` gives the largest absolute
            gradient entry.

    Returns:
        float: norm of all gradients taken together (0.0 if there are none).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        return 0.0
    target = grads[0].device
    per_param = torch.stack([g.norm(norm_type).to(target) for g in grads])
    if norm_type == math.inf:
        # sum(|g|^inf)^(1/inf) is meaningless numerically; the inf-norm is a max
        return per_param.max().item()
    # (sum_i n_i^p)^(1/p) over per-parameter norms, with a single host sync
    return per_param.double().norm(norm_type).item()

def train_one_epoch(model, epoch, dataloaders, optimizer, lr_scheduler, criterion, mixup_fn):
    """Run one training epoch with Mixup/CutMix and per-iteration LR scheduling.

    Relies on notebook globals: `logger`, `device`, `print_freq`, `num_epoches`.

    Returns:
        (epoch_loss, epoch_acc): mean per-sample loss (float) and accuracy as a
        0-dim float64 tensor, both averaged over the samples actually processed.
        The train loader uses drop_last=True, so that count can be smaller than
        the dataset size. Accuracy compares predictions with the argmax of the
        mixed soft labels, so it only approximates clean training accuracy.
    """
    model.train()
    logger.info('Train: ')

    running_loss = 0.0
    running_acc = 0.0
    num_seen = 0

    niter_per_epoch = len(dataloaders['train'])
    for idx, (images, labels) in enumerate(dataloaders['train']):
        images = images.to(device)
        labels = labels.to(device)

        # Mixup/CutMix mixes the images and turns labels into soft targets
        images, labels = mixup_fn(images, labels)

        optimizer.zero_grad()

        outputs = model(images)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)

        loss.backward()
        # clip_grad_norm_ returns the total norm *before* clipping; measuring it
        # after clipping can never report more than the 2.0 limit.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
        optimizer.step()
        lr_scheduler.step_update(epoch * niter_per_epoch + idx)  # timm scheduler, per-iteration

        batch_n = images.size(0)
        num_seen += batch_n
        running_loss += loss.item() * batch_n
        _, indices = torch.max(labels, 1)
        running_acc += torch.sum(preds == indices)

        if idx % print_freq == 0:
            lr = optimizer.param_groups[0]['lr']
            logger.info(
                f'Epoch [{epoch}/{num_epoches - 1}][{idx}/{niter_per_epoch}]\t'
                f'lr: {lr:.7f}\t'
                f'loss: {loss.item():.4f}\t'
                f'grad norm: {float(grad_norm):.4f}'
                )
    denom = max(num_seen, 1)
    epoch_loss = running_loss / denom
    epoch_acc = torch.as_tensor(running_acc).double() / denom
    logger.info(
        f'Train: epoch loss: {epoch_loss:.6f}\t'
        f'epoch acc: {epoch_acc:.4f}')
    return epoch_loss, epoch_acc


@torch.no_grad()
def validate(model, epoch, dataloaders):
    """Evaluate on dataloaders['val'] with plain cross-entropy on hard labels.

    Relies on notebook globals: `logger`, `device`, `print_freq`, `num_epoches`.

    Returns:
        (epoch_loss, epoch_acc): mean per-sample loss (float) and top-1 accuracy
        as a 0-dim float64 tensor, averaged over the samples actually evaluated
        (so the result does not depend on a separately maintained dataset size).
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    logger.info('Validate: ')

    running_loss = 0.0
    running_acc = 0.0
    num_seen = 0
    niter_per_epoch = len(dataloaders['val'])
    for idx, (images, labels) in enumerate(dataloaders['val']):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)

        batch_n = images.size(0)
        num_seen += batch_n
        running_loss += loss.item() * batch_n
        running_acc += torch.sum(preds == labels)

        if idx % print_freq == 0:
            logger.info(
                f'Epoch [{epoch}/{num_epoches - 1}][{idx}/{niter_per_epoch}]\t'
                f'loss: {loss.item():.4f}')
    denom = max(num_seen, 1)
    epoch_loss = running_loss / denom
    epoch_acc = torch.as_tensor(running_acc).double() / denom
    logger.info(
        f'Val: epoch loss: {epoch_loss:.6f}\t'
        f'epoch acc: {epoch_acc:.4f}')
    print(' ')
    print('=='*15)
    return epoch_loss, epoch_acc


def decay_filter(model):
    """Split the model's parameters into two AdamW parameter groups.

    1-D parameters (all biases and normalisation-layer weights/biases) and any
    parameter whose name ends in '.relative_position_bias_table' get
    weight_decay = 0.0; every other parameter keeps the optimizer's default.

    Returns:
        list of two param-group dicts: [no-decay group, decay group].
    """
    def _skip_decay(name, param):
        return param.dim() == 1 or name.endswith('.relative_position_bias_table')

    named = list(model.named_parameters())
    no_weight_decay = [p for n, p in named if _skip_decay(n, p)]
    has_weight_decay = [p for n, p in named if not _skip_decay(n, p)]
    return [{'params': no_weight_decay, "weight_decay": 0.0},
            {'params': has_weight_decay}]

# define training process
def train(dataloaders, model, num_epoches=30, save_dir='./models/', model_name='swinT_'):
    """Train `model` with AdamW, a warm-up + cosine LR schedule and Mixup/CutMix.

    The model is validated after every epoch. Relies on notebook globals
    `warmup_epoches`, `num_classes` and `logger` (and, through the helpers,
    `device` and `print_freq`).

    Args:
        dataloaders (dict): 'train' and 'val' DataLoaders.
        model: the network to train, already on `device`.
        num_epoches (int): number of epochs.
        save_dir (str): if truthy, `<model_name><epoch>.pth` is written there
            after every epoch; the directory is created when missing.
        model_name (str): checkpoint filename prefix.

    Returns:
        (train_loss_history, val_loss_history, train_acc_history, val_acc_history)
    """
    parameters = decay_filter(model)
    # The paper trains from scratch with lr=0.001, weight_decay=0.05; smaller
    # values are used here.
    optimizer = AdamW(parameters, eps=1e-8, betas=(0.9, 0.999), lr=0.0001, weight_decay=0.005)

    # timm cosine decay with linear warm-up, stepped per iteration (t_in_epochs=False)
    niter_per_epoch = len(dataloaders['train'])
    lr_scheduler = CosineLRScheduler(optimizer,
                                     t_initial=int(num_epoches * niter_per_epoch),
                                     lr_min=5e-6,
                                     warmup_lr_init=1e-7,  # 5e-7 on github
                                     warmup_t=int(warmup_epoches * niter_per_epoch),
                                     t_in_epochs=False)

    # Mixup/CutMix produce soft targets, hence SoftTargetCrossEntropy
    mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
                     prob=1.0, switch_prob=0.5, mode='batch',
                     label_smoothing=0.1, num_classes=num_classes)
    criterion = SoftTargetCrossEntropy()

    train_loss_history = []
    val_loss_history = []
    train_acc_history = []
    val_acc_history = []

    best_tr_acc = 0.0
    best_val_acc = 0.0

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    logger.info("Start training-----------")
    for epoch in range(num_epoches):
        epoch_loss, epoch_acc = train_one_epoch(model, epoch, dataloaders,
                                                optimizer, lr_scheduler, criterion, mixup_fn)
        if epoch_acc > best_tr_acc:
            best_tr_acc = epoch_acc
        train_loss_history.append(epoch_loss)
        train_acc_history.append(epoch_acc.cpu())

        epoch_loss, epoch_acc = validate(model, epoch, dataloaders)
        if epoch_acc > best_val_acc:
            best_val_acc = epoch_acc
        val_loss_history.append(epoch_loss)
        val_acc_history.append(epoch_acc.cpu())

        # torch.save serialises the tensors immediately, so no defensive deepcopy
        # (which would temporarily double GPU memory) is needed
        if save_dir:
            torch.save(model.state_dict(), os.path.join(save_dir, model_name + str(epoch) + '.pth'),
                       _use_new_zipfile_serialization=False)

    print('Best train Acc: {:.4f}'.format(best_tr_acc))
    print('Best val Acc: {:.4f}'.format(best_val_acc))

    return train_loss_history, val_loss_history, train_acc_history, val_acc_history


In [None]:
# Hyperparameters. The architecture values (patch_size ... drop_path_rate) match
# the Swin-T configuration of the paper; the schedule and batch size are reduced
# for this project (paper: 300 epochs, 20 warm-up epochs, batch size 1024).
# Several helpers read these as notebook globals, e.g. get_transforms (img_size),
# get_dataloaders (batchsize, num_workers), train (warmup_epoches, num_classes),
# train_one_epoch / validate (print_freq, and num_epoches in log messages).
num_epoches = 100
warmup_epoches = 10
# shorter alternative schedule:
# num_epoches = 50
# warmup_epoches = 3
# batchsize = 1024  # paper setting
# batchsize = 512
batchsize = 128 # batch size actually used for the runs in this notebook
img_size = 224
patch_size = 4
embed_dim = 96
depths = [2, 2, 6, 2]
num_heads = [3, 6, 12, 24]
window_size = 7
drop_path_rate = 0.2
num_workers = 2 # in public code this is 8
print_freq = 10

num_classes = 200 # for tiny-imagenet-200
# num_classes = 1000 # for imagenet-1k
# num_classes = 1000 # for imagenet-1k


In [None]:
# Build the Swin Transformer classifier from the hyperparameters in the cell
# above. The model is created on the CPU; the next cell moves it to `device`.
model = SwinTransformerModel(
    img_size=img_size,
    patch_size=patch_size, 
    in_chans=3, 
    num_classes=num_classes,
    embed_dim=embed_dim, 
    depths=depths, 
    num_heads=num_heads,
    window_size=window_size,
    drop_path_rate=drop_path_rate
    )

(56, 56)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


(28, 28)
(14, 14)
(14, 14)
(14, 14)
(7, 7)


In [None]:
# load pretrained 
%cd '/content/drive/MyDrive/EECS542_project/'

pretrain = './models/log021/swinT_210.pth'
model = model.to(device)
model.load_state_dict(torch.load(pretrain))
print(model)


/content/drive/MyDrive/EECS542_project
SwinTransformerModel(
  (patch_embed): PatchEmbed(
    (linear): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
  )
  (block1): SwinTransformerBlock(
    (LN1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
    (LN2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
    (LN3): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
    (LN4): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
    (MLP1): Mlp(
      (fc1): Linear(in_features=96, out_features=384, bias=True)
      (act): GELU()
      (fc2): Linear(in_features=384, out_features=96, bias=True)
      (drop): Dropout(p=0.0, inplace=False)
    )
    (MLP2): Mlp(
      (fc1): Linear(in_features=96, out_features=384, bias=True)
      (act): GELU()
      (fc2): Linear(in_features=384, out_features=96, bias=True)
      (drop): Dropout(p=0.0, inplace=False)
    )
    (WindowAttention1): WindowAttention(
      (qkv): 

In [None]:
# Build the train/val transforms and DataLoaders for Tiny-ImageNet-200.
# The relative path assumes the working directory was set to the project folder
# by the `%cd` in the previous cell.
# (Optional, commented out: copy the dataset from Drive to the Colab local disk.)
# !cp -r /content/drive/MyDrive/EECS542_project/tiny-imagenet-200/ /content/data/

imagenet_path = './tiny-imagenet-200/'
train_transform, val_transform = get_transforms()
dataloaders, dataset_sizes = get_dataloaders(imagenet_path, train_transform, val_transform)
print(dataset_sizes)


{'train': 100261, 'val': 10000}


In [None]:
# Training run: create the logger first (train_one_epoch / validate use the
# global `logger`), then train; a checkpoint per epoch is written to save_dir.
save_dir = './models/log022/'
logger = log_creater(save_dir)

train_loss_history, val_loss_history, train_acc_history, val_acc_history = \
train(dataloaders, model, num_epoches=num_epoches, save_dir=save_dir, model_name='swinT_')


[2022-04-26 08:23:54,193][line: 237] ==> Start training-----------
[2022-04-26 08:23:54,198][line: 97] ==> Train: 
[2022-04-26 08:23:56,947][line: 132] ==> Epoch [0/99][0/783]	lr: 0.0000001	loss: 3.1719	grad norm: 2.0000
[2022-04-26 08:24:04,444][line: 132] ==> Epoch [0/99][10/783]	lr: 0.0000002	loss: 3.7946	grad norm: 2.0000
[2022-04-26 08:24:11,766][line: 132] ==> Epoch [0/99][20/783]	lr: 0.0000004	loss: 2.2735	grad norm: 2.0000
[2022-04-26 08:24:19,220][line: 132] ==> Epoch [0/99][30/783]	lr: 0.0000005	loss: 2.6768	grad norm: 2.0000
[2022-04-26 08:24:26,669][line: 132] ==> Epoch [0/99][40/783]	lr: 0.0000006	loss: 3.8157	grad norm: 1.9178
[2022-04-26 08:24:33,654][line: 132] ==> Epoch [0/99][50/783]	lr: 0.0000007	loss: 3.1721	grad norm: 2.0000
[2022-04-26 08:24:40,734][line: 132] ==> Epoch [0/99][60/783]	lr: 0.0000009	loss: 3.5855	grad norm: 2.0000
[2022-04-26 08:24:47,979][line: 132] ==> Epoch [0/99][70/783]	lr: 0.0000010	loss: 2.8225	grad norm: 2.0000
[2022-04-26 08:24:55,037][line

100261


[2022-04-26 08:33:38,368][line: 169] ==> Epoch [0/99][0/50]	loss: 1.0043
[2022-04-26 08:33:45,419][line: 169] ==> Epoch [0/99][10/50]	loss: 1.7498
[2022-04-26 08:33:52,331][line: 169] ==> Epoch [0/99][20/50]	loss: 1.6558
[2022-04-26 08:33:59,219][line: 169] ==> Epoch [0/99][30/50]	loss: 1.9630
[2022-04-26 08:34:06,189][line: 169] ==> Epoch [0/99][40/50]	loss: 1.5559
[2022-04-26 08:34:11,835][line: 174] ==> Val: epoch loss: 1.505751	epoch acc: 0.6613


 


[2022-04-26 08:34:12,195][line: 97] ==> Train: 
[2022-04-26 08:34:14,264][line: 132] ==> Epoch [1/99][0/783]	lr: 0.0000101	loss: 2.4126	grad norm: 2.0000
[2022-04-26 08:34:22,333][line: 132] ==> Epoch [1/99][10/783]	lr: 0.0000102	loss: 3.7527	grad norm: 1.5011
[2022-04-26 08:34:30,044][line: 132] ==> Epoch [1/99][20/783]	lr: 0.0000103	loss: 4.0022	grad norm: 2.0000
[2022-04-26 08:34:37,827][line: 132] ==> Epoch [1/99][30/783]	lr: 0.0000105	loss: 3.6228	grad norm: 1.5142
[2022-04-26 08:34:45,689][line: 132] ==> Epoch [1/99][40/783]	lr: 0.0000106	loss: 4.1558	grad norm: 2.0000
[2022-04-26 08:34:53,678][line: 132] ==> Epoch [1/99][50/783]	lr: 0.0000107	loss: 3.8928	grad norm: 1.7571
[2022-04-26 08:35:01,656][line: 132] ==> Epoch [1/99][60/783]	lr: 0.0000109	loss: 3.8377	grad norm: 1.6968
[2022-04-26 08:35:09,577][line: 132] ==> Epoch [1/99][70/783]	lr: 0.0000110	loss: 4.2779	grad norm: 1.6926
[2022-04-26 08:35:17,292][line: 132] ==> Epoch [1/99][80/783]	lr: 0.0000111	loss: 2.8522	grad nor

100261


[2022-04-26 08:44:33,980][line: 169] ==> Epoch [1/99][0/50]	loss: 1.0103
[2022-04-26 08:44:40,866][line: 169] ==> Epoch [1/99][10/50]	loss: 1.6679
[2022-04-26 08:44:47,741][line: 169] ==> Epoch [1/99][20/50]	loss: 1.6118
[2022-04-26 08:44:54,625][line: 169] ==> Epoch [1/99][30/50]	loss: 1.8387
[2022-04-26 08:45:01,098][line: 169] ==> Epoch [1/99][40/50]	loss: 1.5442
[2022-04-26 08:45:06,822][line: 174] ==> Val: epoch loss: 1.476737	epoch acc: 0.6677


 


[2022-04-26 08:45:07,188][line: 97] ==> Train: 


RuntimeError: ignored

In [None]:
import matplotlib.pyplot as plt

# Summarise the run: print the raw histories, then plot loss and accuracy curves.
print(train_loss_history)
print(val_loss_history)
print(train_acc_history)
print(val_acc_history)

# x-axis from the recorded history length, so a run that stopped before
# num_epoches epochs can still be plotted
epochs = list(range(len(train_loss_history)))
train_acc = [float(a) for a in train_acc_history]
val_acc = [float(a) for a in val_acc_history]

fig, ax = plt.subplots()
ax.plot(epochs, train_loss_history, label='train')
ax.plot(epochs, val_loss_history, label='val')
ax.set(title='loss history', xlabel='epoch', ylabel='loss')
ax.legend()
plt.show()

fig, ax = plt.subplots()
ax.plot(epochs, train_acc, label='train')
ax.plot(epochs, val_acc, label='val')
ax.set(title='acc history', xlabel='epoch', ylabel='accuracy')
ax.legend()
plt.show()
