In [None]:
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage
from .swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys

6. This is the model to be trained (Swin-Unet). It is built on the Swin Transformer (Hierarchical Vision Transformer using Shifted Windows).

This code takes the Swin Transformer implementation and creates an untrained Swin-Unet. It then loads pretrained parameters into the model from a checkpoint.


The model's arguments are defined in the code below. The pretrained parameters are loaded into both the downsampling (encoder) and the upsampling (decoder) layers.

In [None]:


logger = logging.getLogger(__name__)


# This is the main class of the SwinUnet network.
class SwinUnet(nn.Module):
    r"""Swin-Unet model: a thin wrapper around ``SwinTransformerSys``.

    The backbone is based on the Swin Transformer ("Swin Transformer:
    Hierarchical Vision Transformer using Shifted Windows",
    https://arxiv.org/pdf/2103.14030). All architecture hyper-parameters
    are read from ``config``; see ``__init__`` for the mapping.

    Args:
        config: Configuration object. Reads ``config.DATA.IMG_SIZE``,
            ``config.MODEL.SWIN.*``, ``config.MODEL.DROP_RATE``,
            ``config.MODEL.DROP_PATH_RATE`` and ``config.TRAIN.USE_CHECKPOINT``.
        img_size (int): Not used; the input size is taken from
            ``config.DATA.IMG_SIZE``. Kept for interface compatibility.
        num_classes (int): Number of output classes passed to the backbone.
        zero_head (bool): Stored on the instance; not otherwise used here.
        vis (bool): Not used. Kept for interface compatibility.
    """

    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super(SwinUnet, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.config = config

        # Swin-Unet is made up of a Swin Transformer encoder/decoder
        # (``SwinTransformerSys``). Argument meanings, as given in the upstream
        # Swin Transformer documentation:
        #   img_size        input image size
        #   patch_size      patch size
        #   in_chans        number of input image channels
        #   num_classes     number of classes for the output head
        #   embed_dim       patch embedding dimension
        #   depths          depth of each Swin Transformer layer (stage)
        #   num_heads       number of attention heads in each layer
        #   window_size     attention window size
        #   mlp_ratio       ratio of MLP hidden dim to embedding dim
        #   qkv_bias        if True, add a learnable bias to query, key, value
        #   qk_scale        overrides the default qk scale (head_dim ** -0.5) if set
        #   drop_rate       dropout rate
        #   drop_path_rate  stochastic depth rate
        #   ape             if True, add absolute position embedding to the patch embedding
        #   patch_norm      if True, add normalization after patch embedding
        #   use_checkpoint  whether to use gradient checkpointing to save memory
        self.swin_unet = SwinTransformerSys(img_size=config.DATA.IMG_SIZE,
                                            patch_size=config.MODEL.SWIN.PATCH_SIZE,
                                            in_chans=config.MODEL.SWIN.IN_CHANS,
                                            num_classes=self.num_classes,
                                            embed_dim=config.MODEL.SWIN.EMBED_DIM,
                                            depths=config.MODEL.SWIN.DEPTHS,
                                            num_heads=config.MODEL.SWIN.NUM_HEADS,
                                            window_size=config.MODEL.SWIN.WINDOW_SIZE,
                                            mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
                                            qkv_bias=config.MODEL.SWIN.QKV_BIAS,
                                            qk_scale=config.MODEL.SWIN.QK_SCALE,
                                            drop_rate=config.MODEL.DROP_RATE,
                                            drop_path_rate=config.MODEL.DROP_PATH_RATE,
                                            ape=config.MODEL.SWIN.APE,
                                            patch_norm=config.MODEL.SWIN.PATCH_NORM,
                                            use_checkpoint=config.TRAIN.USE_CHECKPOINT)
        # The backbone contains the encoder, bottleneck and decoder; see
        # swin_transformer_unet_skip_expand_decoder_sys for details.

    def forward(self, x):
        """Run the backbone on ``x`` and return its output unchanged.

        If dimension 1 of ``x`` has size 1 (a single-channel input), that
        channel is repeated three times before the backbone is called.
        """
        if x.size()[1] == 1:
            # Single-channel input: tile it to 3 channels. The channel test
            # reads dim 1, so x is presumably (N, C, H, W) — confirm with callers.
            x = x.repeat(1, 3, 1, 1)
        return self.swin_unet(x)

    def load_from(self, config):
        """Load pretrained weights from ``config.MODEL.PRETRAIN_CKPT`` into ``self.swin_unet``.

        Two checkpoint layouts are handled:

        * No ``"model"`` key. The file is treated as a state dict saved from a
          wrapped Swin-Unet. The first 17 characters of every key are removed,
          which is the length of ``"module.swin_unet."``. Keys containing
          ``"output"`` are dropped, and the rest is loaded.
        * A ``"model"`` key. The file is treated as a Swin Transformer encoder
          checkpoint. Every ``layers.<i>.*`` entry is loaded into the encoder
          as-is. It is also copied into the mirrored decoder stage as
          ``layers_up.<last - i>.*``, where ``last = len(config.MODEL.SWIN.DEPTHS) - 1``.
          Entries whose shape disagrees with the model are skipped.

        Loading is non-strict (``strict=False``): keys present on only one side
        are ignored. When ``config.MODEL.PRETRAIN_CKPT`` is None, the method
        only prints a message.

        Security: ``torch.load`` unpickles the file. Only load trusted checkpoints.
        """
        pretrained_path = config.MODEL.PRETRAIN_CKPT  # checkpoint file named in the configuration
        if pretrained_path is None:
            print("none pretrain")
            return

        print("pretrained_path:{}".format(pretrained_path))
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE(security): pickle-based load; never point this at untrusted files.
        pretrained_dict = torch.load(pretrained_path, map_location=device)

        if "model" not in pretrained_dict:
            print("---start load pretrained modle by splitting---")
            # Remove the 17-character key prefix (presumably "module.swin_unet."
            # from a DataParallel-wrapped SwinUnet — confirm against the saver).
            pretrained_dict = {k[17:]: v for k, v in pretrained_dict.items()}
            # Drop output-head weights; presumably so a different num_classes
            # can be used.
            for k in list(pretrained_dict.keys()):
                if "output" in k:
                    print("delete key:{}".format(k))
                    del pretrained_dict[k]
            self.swin_unet.load_state_dict(pretrained_dict, strict=False)
            return

        pretrained_dict = pretrained_dict['model']  # keep only the model parameters
        print("---start load pretrained modle of swin encoder---")

        # Encoder stage i is mirrored onto decoder stage (last_stage - i),
        # e.g. layers.0 -> layers_up.3 for a 4-stage network.
        last_stage = len(config.MODEL.SWIN.DEPTHS) - 1
        prefix = "layers."
        # A shallow copy is enough: tensor values are never modified here, and
        # load_state_dict copies them into the model parameters.
        full_dict = dict(pretrained_dict)
        for k, v in pretrained_dict.items():
            if not k.startswith(prefix):
                continue
            # "layers.<i>.<rest>" -> stage "<i>", separator ".", suffix "<rest>"
            stage, sep, rest = k[len(prefix):].partition(".")
            current_k = "layers_up." + str(last_stage - int(stage)) + sep + rest
            full_dict[current_k] = v

        # Skip entries whose shape does not match the model, since
        # load_state_dict raises on shape mismatches even with strict=False.
        model_dict = self.swin_unet.state_dict()
        for k in list(full_dict.keys()):
            if k in model_dict and full_dict[k].shape != model_dict[k].shape:
                print("delete:{};shape pretrain:{};shape model:{}".format(
                    k, full_dict[k].shape, model_dict[k].shape))
                del full_dict[k]

        # The model now receives pretrained parameters for both the
        # downsampling (encoder) and upsampling (decoder) layers.
        self.swin_unet.load_state_dict(full_dict, strict=False)