In [1]:
import sys
sys.path.append("..")
import torch
from torch import nn
from torch import optim
import torchvision.transforms as transforms
import time
import os
from Tensorized_components.tcl_patch_embedding_ex  import PatchEmbedding     
from Tensorized_components.w_msa_w_o_b_sign_extended  import WindowMSA     
from Tensorized_components.sh_wmsa_w_o_b_sign_extended import ShiftedWindowMSA     
from Tensorized_components.patch_merging  import TensorizedPatchMerging  
from Tensorized_Layers.TCL import TCL_extended as TCL_CHANGED   
from Tensorized_Layers.TRL import TRL   
from Utils.Accuracy_measures import topk_accuracy
from Utils.TinyImageNet_loader import get_tinyimagenet_dataloaders
from Utils.Num_parameter import count_parameters

In [2]:
class SwinBlock1(nn.Module):
    """
    A class representing 'Block 1' in your Swin Transformer.
    This captures the sequence of:
        (1) Window MSA + residual
        (2) TCL + residual
        (3) Shifted Window MSA + residual
        (4) TCL + residual
    but only for the first block’s hyperparameters and submodules.
    """
    def __init__(self, w_msa, sw_msa, tcl, embed_shape, dropout=0):
        super(SwinBlock1, self).__init__()
        # Typically each sub-layer has its own LayerNorm
        self.norm1 = nn.LayerNorm(embed_shape)
        self.norm2 = nn.LayerNorm(embed_shape)
        self.norm3 = nn.LayerNorm(embed_shape)
        self.norm4 = nn.LayerNorm(embed_shape)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # We pass in pre-built modules (WindowMSA, ShiftedWindowMSA, TCL)
        self.w_msa = w_msa
        self.sw_msa = sw_msa
        self.tcl = tcl

    def forward(self, x):
        # ----- First Window MSA + Residual -----
        x_res = x
        x = self.norm1(x)
        x = self.dropout(self.w_msa(x))
        x = x + x_res

        # ----- TCL + Residual -----
        x_res = x
        x = self.norm2(x)
        x = self.tcl(x)
        x = x + x_res

        # ----- Shifted Window MSA + Residual -----
        x_res = x
        x = self.norm3(x)
        x = self.dropout(self.sw_msa(x))
        x = x + x_res

        # ----- TCL + Residual -----
        x_res = x
        x = self.norm4(x)
        x = self.tcl(x)
        x = x + x_res

        return x


In [3]:
class SwinBlock2(nn.Module):
    def __init__(self, w_msa, sw_msa, tcl, embed_shape=(4,4,6), dropout=0):
        super(SwinBlock2, self).__init__()
        # LN layers
        self.norm1 = nn.LayerNorm(embed_shape)
        self.norm2 = nn.LayerNorm(embed_shape)
        self.norm3 = nn.LayerNorm(embed_shape)
        self.norm4 = nn.LayerNorm(embed_shape)

        self.dropout = nn.Dropout(dropout)

        self.w_msa = w_msa
        self.sw_msa = sw_msa
        self.tcl = tcl

    def forward(self, x):
        # Window MSA
        x_res = x
        x = self.norm1(x)
        x = self.dropout(self.w_msa(x))
        x = x + x_res

        # TCL
        x_res = x
        x = self.norm2(x)
        x = self.tcl(x)
        x = x + x_res

        # Shifted Window MSA
        x_res = x
        x = self.norm3(x)
        x = self.dropout(self.sw_msa(x))
        x = x + x_res

        # TCL
        x_res = x
        x = self.norm4(x)
        x = self.tcl(x)
        x = x + x_res

        return x


In [4]:
class SwinBlock3(nn.Module):
    def __init__(self, w_msa, sw_msa, tcl, embed_shape=(4,4,12), dropout=0):
        super(SwinBlock3, self).__init__()
        self.norm1 = nn.LayerNorm(embed_shape)
        self.norm2 = nn.LayerNorm(embed_shape)
        self.norm3 = nn.LayerNorm(embed_shape)
        self.norm4 = nn.LayerNorm(embed_shape)

        self.dropout = nn.Dropout(dropout)

        self.w_msa = w_msa
        self.sw_msa = sw_msa
        self.tcl = tcl

    def forward(self, x):
        x_res = x
        x = self.norm1(x)
        x = self.dropout(self.w_msa(x))
        x = x + x_res

        x_res = x
        x = self.norm2(x)
        x = self.tcl(x)
        x = x + x_res

        x_res = x
        x = self.norm3(x)
        x = self.dropout(self.sw_msa(x))
        x = x + x_res

        x_res = x
        x = self.norm4(x)
        x = self.tcl(x)
        x = x + x_res
        return x

In [5]:
class SwinBlock4(nn.Module):
    def __init__(self, w_msa, sw_msa, tcl, embed_shape=(4,4,24), dropout=0):
        super(SwinBlock4, self).__init__()
        self.norm1 = nn.LayerNorm(embed_shape)
        self.norm2 = nn.LayerNorm(embed_shape)
        self.norm3 = nn.LayerNorm(embed_shape)
        self.norm4 = nn.LayerNorm(embed_shape)

        self.dropout = nn.Dropout(dropout)
        self.w_msa = w_msa
        self.sw_msa = sw_msa
        self.tcl = tcl

    def forward(self, x):
        x_res = x
        x = self.norm1(x)
        x = self.dropout(self.w_msa(x))
        x = x + x_res

        x_res = x
        x = self.norm2(x)
        x = self.tcl(x)
        x = x + x_res

        x_res = x
        x = self.norm3(x)
        x = self.dropout(self.sw_msa(x))
        x = x + x_res

        x_res = x
        x = self.norm4(x)
        x = self.tcl(x)
        x = x + x_res

        return x


In [6]:
class SwinTransformer(nn.Module):
    def __init__(self,
                 img_size=224,
                 patch_size=4,
                 in_chans=3,
                 embed_shape=(4,4,12),
                 bias=True,
                 dropout=0,
                 device="cuda"):
        super(SwinTransformer, self).__init__()

        self.device = device


        self.patch_embedding = PatchEmbedding(
            input_size=(32,3,224,224),
            patch_size=4,
            embed_dim=(4,4,12),
            bias=bias,
            device=self.device,
            ignore_modes = (0,1,2),
            tcl_type='ex'
        )

        # self, input_size, patch_size, embed_dim, bias = True, device = 'cuda', ignore_modes = (0,1,2), tcl_type='normal', tcl_r = 2

        # -------------------------------- block 1 --------------------------

        self.w_msa_1 = WindowMSA(
            window_size=7,
            embed_dims=embed_shape,
            rank_window=embed_shape,
            head_factors=(1,2,3),
            device=self.device
        )

        self.sw_msa_1 = ShiftedWindowMSA(
            window_size=7,
            embed_dims=embed_shape,
            rank_window=embed_shape,
            head_factors=(1,2,3),
            device=self.device
        )

        self.tcl_1 = TCL_CHANGED(
            input_size=(16, 56, 56, 4,4,12),
            rank=(4,4,12),
            ignore_modes=(0, 1, 2),
            bias=bias,
            device=self.device
        )

        self.block1_list = nn.ModuleList([
            SwinBlock1(
                w_msa=self.w_msa_1,
                sw_msa=self.sw_msa_1,
                tcl=self.tcl_1,
                embed_shape=embed_shape,
                dropout=dropout
            )
            for _ in range(2)
        ])

        # -------------------------------- block 2 --------------------------


        self.patch_merging_1 = TensorizedPatchMerging(
            input_size=(16, 56, 56, 4,4,12),
            in_embed_shape=embed_shape,
            out_embed_shape=(4,4,24),
            bias=bias,
            ignore_modes=(0, 1, 2),
            device=self.device
        )

        self.w_msa_2 = WindowMSA(
            window_size=7,
            embed_dims=(4,4,24),
            rank_window=(4,4,24),
            head_factors=(1,2,6),
            device=self.device
        )

        self.sw_msa_2 = ShiftedWindowMSA(
            window_size=7,
            embed_dims=(4,4,24),
            rank_window=(4,4,24),
            head_factors=(1,2,6),
            device=self.device
        )

        self.tcl_2 = TCL_CHANGED(
            input_size=(16, 28, 28, 4,4,24),
            rank=(4,4,24),
            ignore_modes=(0, 1, 2),
            bias=bias,
            device=self.device
        )

        # We repeat Block2 two times
        self.block2_list = nn.ModuleList([
            SwinBlock2(
                w_msa=self.w_msa_2,
                sw_msa=self.sw_msa_2,
                tcl=self.tcl_2,
                embed_shape=(4,4,24),  
                dropout=dropout
            )
            for _ in range(2)
        ])


        # # -------------------------------- block 3 --------------------------

        self.patch_merging_2 = TensorizedPatchMerging(
            input_size=(16, 28, 28, 4,4,24),
            in_embed_shape=(4,4,24),
            out_embed_shape=(4,4,48),
            bias=bias,
            ignore_modes=(0, 1, 2),
            device=self.device
        )


        self.w_msa_3 = WindowMSA(
            window_size=7,
            embed_dims=(4,4,48),
            rank_window=(4,4,48),
            head_factors=(2,1,12),
            device=self.device
        )

        self.sw_msa_3 = ShiftedWindowMSA(
            window_size=7,
            embed_dims=(4,4,48),
            rank_window=(4,4,48),
            head_factors=(2,1,12),
            device=self.device
        )

        self.tcl_3 = TCL_CHANGED(
            input_size=(16, 14, 14, 4,4,48),
            rank=(4,4,48),
            ignore_modes=(0, 1, 2),
            bias=bias,
            device=self.device
        )

        # Repeat Block3 6 times
        self.block3_list = nn.ModuleList([
            SwinBlock3(
                w_msa=self.w_msa_3,
                sw_msa=self.sw_msa_3,
                tcl=self.tcl_3,
                embed_shape=(4,4,48),
                dropout=dropout
            )
            for _ in range(18)
        ])

        # # # -------------------------------- block 4 --------------------------

        self.patch_merging_3 = TensorizedPatchMerging(
            input_size=(16, 14, 14, 4,4,48),
            in_embed_shape=(4,4,48),
            out_embed_shape=(4,4,96),
            bias=bias,
            ignore_modes=(0, 1, 2),
            device=self.device
        )

        self.w_msa_4 = WindowMSA(
            window_size=7,
            embed_dims=(4,4,96),
            rank_window=(4,4,96),
            head_factors=(2,1,24),
            device=self.device
        )

        self.sw_msa_4 = ShiftedWindowMSA(
            window_size=7,
            embed_dims=(4,4,96),
            rank_window=(4,4,96),
            head_factors=(2,1,24),
            device=self.device
        )

        self.tcl_4 = TCL_CHANGED(
            input_size=(16, 7, 7, 4,4,96),
            rank=(4,4,96),
            ignore_modes=(0, 1, 2),
            bias=bias,
            device=self.device
        )


        self.block4_list = nn.ModuleList([
            SwinBlock4(
                w_msa=self.w_msa_4,
                sw_msa=self.sw_msa_4,
                tcl=self.tcl_4,
                embed_shape=(4,4,96),
                dropout=dropout
            )
            for _ in range(2)
        ])

        # -------------------------------- classifier --------------------------

    

        # self.classifier = TRL(input_size=(16,4,4,96),
        #                     output=(200,),
        #                     rank=(4,4,96,200),
        #                     ignore_modes=(0,),
        #                     bias=bias,
        #                     device=self.device) 
        

        # positoin embedding


        # self.pos_embedding = nn.Parameter(
        #     torch.randn(1,
        #                 56,
        #                 56,
        #                 4,
        #                 4,
        #                 12,
        #                 device = self.device
        #                 ), requires_grad=True)

    def forward(self, x):
 

        x = self.patch_embedding(x)

        # x += self.pos_embedding

        for i, blk in enumerate(self.block1_list, 1):
            x = blk(x)


        x = self.patch_merging_1(x)



        for i, blk in enumerate(self.block2_list, 1):
            x = blk(x)


        x = self.patch_merging_2(x)

        for i, blk in enumerate(self.block3_list, 1):
            x = blk(x)


        x = self.patch_merging_3(x)


        for i, blk in enumerate(self.block4_list, 1):
            x = blk(x)


        x = x.mean(dim=(1, 2))

        output = self.classifier(x)
        return output

In [7]:
# Setup the device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = 'cpu'
print(f'Device is set to : {device}')

# Configs

TEST_ID = 'Test_ID015'
batch_size = 16
n_epoch = 400
image_size = 224

model = SwinTransformer(img_size=224,patch_size=4,in_chans=3,embed_shape=(4,4,12),bias=True,device=device).to(device)


Device is set to : cpu


In [8]:
num_parameters = count_parameters(model)
print(f'This Model has {num_parameters} parameters')

This Model has 423080 parameters
