In [114]:
import sys
sys.path.append("..")
import torch
import torch.nn as nn
from Tensorized_components.patch_embedding  import Patch_Embedding     
from Tensorized_components.w_msa  import WindowMSA     
from Tensorized_components.sh_wmsa  import ShiftedWindowMSA     
from Tensorized_components.patch_merging  import TensorizedPatchMerging  
from Tensorized_Layers.TCL_CHANGED import TCL_CHANGED   
from Tensorized_Layers.TRL import TRL   

In [115]:
class SwinBlock1(nn.Module):
    """
    Stage-1 Swin block: four normalised residual sub-layers applied in order.

    Sequence executed by ``forward``:
        1. LayerNorm -> window MSA         -> dropout -> + residual
        2. LayerNorm -> TCL                           -> + residual
        3. LayerNorm -> shifted-window MSA -> dropout -> + residual
        4. LayerNorm -> TCL                           -> + residual

    The sub-modules are built by the caller and injected here. The single
    ``tcl`` module is used for both TCL steps, so those two steps share
    whatever parameters ``tcl`` holds.

    Args:
        w_msa: module applied in step 1.
        sw_msa: module applied in step 3.
        tcl: module applied in steps 2 and 4.
        embed_shape: ``normalized_shape`` given to each of the four LayerNorms.
        dropout: probability for the dropout applied after both MSA modules.
    """

    def __init__(self, w_msa, sw_msa, tcl, embed_shape, dropout=0.5):
        super().__init__()

        # Four independent LayerNorms, one in front of every sub-layer.
        # Targets are assigned left to right, so registration order is norm1..norm4.
        self.norm1, self.norm2, self.norm3, self.norm4 = (
            nn.LayerNorm(embed_shape) for _ in range(4)
        )

        # A single dropout module serves both attention branches.
        self.dropout = nn.Dropout(dropout)

        # Externally constructed sub-layers.
        self.w_msa = w_msa
        self.sw_msa = sw_msa
        self.tcl = tcl

    def forward(self, x):
        """Run the four residual sub-layers on ``x`` and return the result."""
        # (normalisation, sub-layer, apply dropout?) in execution order.
        pipeline = (
            (self.norm1, self.w_msa, True),
            (self.norm2, self.tcl, False),
            (self.norm3, self.sw_msa, True),
            (self.norm4, self.tcl, False),
        )

        for norm, sublayer, with_dropout in pipeline:
            shortcut = x
            branch = sublayer(norm(x))
            if with_dropout:
                branch = self.dropout(branch)
            x = branch + shortcut

        return x


In [116]:
class SwinBlock2(nn.Module):
    """
    Stage-2 Swin block (default embedding shape ``(4, 4, 6)``).

    ``forward`` chains four residual steps, each preceded by its own LayerNorm:
    window MSA (with dropout), TCL, shifted-window MSA (with dropout), TCL.
    The same ``tcl`` module is invoked in both TCL steps.

    Args:
        w_msa: module used for the window-attention step.
        sw_msa: module used for the shifted-window-attention step.
        tcl: module used for both TCL steps.
        embed_shape: ``normalized_shape`` for the four LayerNorm layers.
        dropout: dropout probability applied to both attention outputs.
    """

    def __init__(self, w_msa, sw_msa, tcl, embed_shape=(4,4,6), dropout=0.5):
        super().__init__()

        # Per-sub-layer normalisation.
        self.norm1 = nn.LayerNorm(embed_shape)
        self.norm2 = nn.LayerNorm(embed_shape)
        self.norm3 = nn.LayerNorm(embed_shape)
        self.norm4 = nn.LayerNorm(embed_shape)

        self.dropout = nn.Dropout(dropout)

        # Sub-layers supplied by the caller.
        self.w_msa = w_msa
        self.sw_msa = sw_msa
        self.tcl = tcl

    def _residual(self, x, norm, sublayer, drop=False):
        """Return ``sublayer(norm(x))`` (optionally through dropout) plus ``x``."""
        branch = sublayer(norm(x))
        if drop:
            branch = self.dropout(branch)
        return branch + x

    def forward(self, x):
        """Apply W-MSA, TCL, SW-MSA and TCL, each as a normalised residual step."""
        x = self._residual(x, self.norm1, self.w_msa, drop=True)   # window MSA
        x = self._residual(x, self.norm2, self.tcl)                # TCL
        x = self._residual(x, self.norm3, self.sw_msa, drop=True)  # shifted window MSA
        x = self._residual(x, self.norm4, self.tcl)                # TCL
        return x


In [117]:
class SwinBlock3(nn.Module):
    """
    Stage-3 Swin block (default embedding shape ``(4, 4, 12)``).

    Four residual steps are applied in sequence, each on a LayerNorm-ed copy
    of the running tensor: window MSA + dropout, TCL, shifted-window MSA +
    dropout, TCL. Both TCL steps call the same injected ``tcl`` module.

    Args:
        w_msa: window-attention module (step 1).
        sw_msa: shifted-window-attention module (step 3).
        tcl: module used in steps 2 and 4.
        embed_shape: ``normalized_shape`` for every LayerNorm in the block.
        dropout: dropout probability after each attention module.
    """

    def __init__(self, w_msa, sw_msa, tcl, embed_shape=(4,4,12), dropout=0.5):
        super().__init__()

        # Assignment targets are bound left to right -> norm1, norm2, norm3, norm4.
        self.norm1, self.norm2, self.norm3, self.norm4 = (
            nn.LayerNorm(embed_shape) for _ in range(4)
        )

        self.dropout = nn.Dropout(dropout)

        self.w_msa = w_msa
        self.sw_msa = sw_msa
        self.tcl = tcl

    def forward(self, x):
        """Return ``x`` after the four normalised residual sub-layers."""
        # Each line: normalise, transform, (dropout for attention), add the input back.
        x = self.dropout(self.w_msa(self.norm1(x))) + x
        x = self.tcl(self.norm2(x)) + x
        x = self.dropout(self.sw_msa(self.norm3(x))) + x
        x = self.tcl(self.norm4(x)) + x
        return x

In [118]:
class SwinBlock4(nn.Module):
    """
    Stage-4 Swin block (default embedding shape ``(4, 4, 24)``).

    Applies, in order and each with its own LayerNorm and a residual add:
    window MSA (followed by dropout), TCL, shifted-window MSA (followed by
    dropout), TCL. The ``tcl`` module is reused for both TCL steps.

    Args:
        w_msa: module for the window-attention step.
        sw_msa: module for the shifted-window-attention step.
        tcl: module for both TCL steps.
        embed_shape: ``normalized_shape`` passed to each LayerNorm.
        dropout: dropout probability used after both attention steps.
    """

    def __init__(self, w_msa, sw_msa, tcl, embed_shape=(4,4,24), dropout=0.5):
        super().__init__()

        # One LayerNorm per sub-layer.
        self.norm1 = nn.LayerNorm(embed_shape)
        self.norm2 = nn.LayerNorm(embed_shape)
        self.norm3 = nn.LayerNorm(embed_shape)
        self.norm4 = nn.LayerNorm(embed_shape)

        self.dropout = nn.Dropout(dropout)

        # Caller-provided sub-layers.
        self.w_msa = w_msa
        self.sw_msa = sw_msa
        self.tcl = tcl

    def forward(self, x):
        """Pass ``x`` through the four residual sub-layers and return the result."""
        # Window attention branch.
        attn_branch = self.dropout(self.w_msa(self.norm1(x)))
        after_wmsa = attn_branch + x

        # First TCL branch.
        tcl_branch = self.tcl(self.norm2(after_wmsa))
        after_tcl = tcl_branch + after_wmsa

        # Shifted-window attention branch.
        shifted_branch = self.dropout(self.sw_msa(self.norm3(after_tcl)))
        after_swmsa = shifted_branch + after_tcl

        # Second TCL branch (same tcl module as above).
        final_branch = self.tcl(self.norm4(after_swmsa))
        return final_branch + after_swmsa


In [119]:
class SwinTransformer(nn.Module):
    """
    Four-stage tensorized Swin Transformer with a TRL classification head.

    Pipeline executed by ``forward``:
        patch embedding
        -> 2 x SwinBlock1 -> patch merging 1
        -> 2 x SwinBlock2 -> patch merging 2
        -> 6 x SwinBlock3 -> patch merging 3
        -> 2 x SwinBlock4
        -> mean over dims 1 and 2 -> TRL classifier

    Args:
        img_size: forwarded to ``Patch_Embedding``.
        patch_size: forwarded to ``Patch_Embedding``.
        in_chans: forwarded to ``Patch_Embedding``.
        embed_shape: embedding shape used by the patch embedding, the stage-1
            attention modules, the stage-1 LayerNorms and as the input shape
            of the first patch merging.
        bias: forwarded to the patch embedding, every TCL, every patch merging
            and the TRL classifier.
        dropout: dropout probability handed to every Swin block.
        device: forwarded to the tensorized sub-modules; the assembled model is
            also moved to this device at the end of construction.

    NOTE(review): the stage 2-4 embedding shapes ((4,4,6), (4,4,12), (4,4,24)),
    the stage-1 TCL rank (4,4,3) and all ``input_size`` tuples are hard-coded.
    They presumably assume img_size=224, patch_size=4 and embed_shape=(4,4,3)
    -- confirm before using other settings.

    NOTE(review): within a stage every repeated block receives the *same*
    ``w_msa`` / ``sw_msa`` / ``tcl`` instances, so those weights are shared
    across the repeated blocks (only the LayerNorms are per-block). Confirm
    this is intended.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=4,
                 in_chans=3,
                 embed_shape=(4,4,3),
                 bias=True,
                 dropout=0.5,
                 device="cpu"):
        super(SwinTransformer, self).__init__()

        self.patch_embedding = Patch_Embedding(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_shape=embed_shape,
            bias=bias
        )

        # -------------------------------- block 1 --------------------------

        self.w_msa_1 = WindowMSA(
            window_size=7,
            embed_dims=embed_shape,
            rank_window=embed_shape,
            head_factors=(2, 2, 1),
            device=device
        ).to(device)

        self.sw_msa_1 = ShiftedWindowMSA(
            window_size=7,
            embed_dims=embed_shape,
            rank_window=embed_shape,
            head_factors=(2, 2, 1),
            device=device
        ).to(device)

        self.tcl_1 = TCL_CHANGED(
            input_size=(1, 56, 56, 4, 4, 3),
            rank=(4, 4, 3),
            ignore_modes=(0, 1, 2),
            bias=bias,
            device=device
        )

        # Block1 is repeated two times (same attention/TCL instances in both).
        self.block1_list = nn.ModuleList([
            SwinBlock1(
                w_msa=self.w_msa_1,
                sw_msa=self.sw_msa_1,
                tcl=self.tcl_1,
                embed_shape=embed_shape,
                dropout=dropout
            )
            for _ in range(2)
        ])

        # -------------------------------- block 2 --------------------------

        # NOTE(review): the leading 16 in input_size differs from the leading 1
        # used by every other layer here; mode 0 is listed in ignore_modes, so
        # it is presumably irrelevant -- confirm against TensorizedPatchMerging.
        self.patch_merging_1 = TensorizedPatchMerging(
            input_size=(16, 56, 56, 4, 4, 3),
            in_embed_shape=embed_shape,
            out_embed_shape=(4, 4, 6),
            bias=bias,
            ignore_modes=(0, 1, 2),
            device=device
        ).to(device)

        self.w_msa_2 = WindowMSA(
            window_size=7,
            embed_dims=(4,4,6),
            rank_window=(4,4,6),
            head_factors=(2, 2, 1),
            device=device
        ).to(device)

        self.sw_msa_2 = ShiftedWindowMSA(
            window_size=7,
            embed_dims=(4,4,6),
            rank_window=(4,4,6),
            head_factors=(2, 2, 1),
            device=device
        ).to(device)

        self.tcl_2 = TCL_CHANGED(
            input_size=(1, 28, 28, 4, 4, 6),
            rank=(4, 4, 6),
            ignore_modes=(0, 1, 2),
            bias=bias,
            device=device
        )

        # We repeat Block2 two times
        self.block2_list = nn.ModuleList([
            SwinBlock2(
                w_msa=self.w_msa_2,
                sw_msa=self.sw_msa_2,
                tcl=self.tcl_2,
                embed_shape=(4,4,6),  # Stage 2 shape
                dropout=dropout
            )
            for _ in range(2)
        ])

        # -------------------------------- block 3 --------------------------

        self.patch_merging_2 = TensorizedPatchMerging(
            input_size=(1, 28, 28, 4, 4, 6),
            in_embed_shape=(4,4,6),
            out_embed_shape=(4, 4, 12),
            bias=bias,
            ignore_modes=(0, 1, 2),
            device=device
        ).to(device)

        self.w_msa_3 = WindowMSA(
            window_size=7,
            embed_dims=(4,4,12),
            rank_window=(4,4,12),
            head_factors=(2, 2, 1),
            device=device
        ).to(device)

        self.sw_msa_3 = ShiftedWindowMSA(
            window_size=7,
            embed_dims=(4,4,12),
            rank_window=(4,4,12),
            head_factors=(2, 2, 1),
            device=device
        ).to(device)

        self.tcl_3 = TCL_CHANGED(
            input_size=(1, 14, 14, 4, 4, 12),
            rank=(4,4,12),
            ignore_modes=(0, 1, 2),
            bias=bias,
            device=device
        )

        # Repeat Block3 6 times
        self.block3_list = nn.ModuleList([
            SwinBlock3(
                w_msa=self.w_msa_3,
                sw_msa=self.sw_msa_3,
                tcl=self.tcl_3,
                embed_shape=(4,4,12),
                dropout=dropout
            )
            for _ in range(6)
        ])

        # -------------------------------- block 4 --------------------------

        self.patch_merging_3 = TensorizedPatchMerging(
            input_size=(1, 14, 14, 4, 4, 12),
            in_embed_shape=(4,4,12),
            out_embed_shape=(4, 4, 24),
            bias=bias,
            ignore_modes=(0, 1, 2),
            device=device
        ).to(device)

        self.w_msa_4 = WindowMSA(
            window_size=7,
            embed_dims=(4,4,24),
            rank_window=(4,4,24),
            head_factors=(2, 2, 1),
            device=device
        ).to(device)

        self.sw_msa_4 = ShiftedWindowMSA(
            window_size=7,
            embed_dims=(4,4,24),
            rank_window=(4,4,24),
            head_factors=(2, 2, 1),
            device=device
        ).to(device)

        self.tcl_4 = TCL_CHANGED(
            input_size=(1, 7, 7, 4, 4, 24),
            rank=(4, 4, 24),
            ignore_modes=(0, 1, 2),
            bias=bias,
            device=device
        )

        # Repeat Block4 two times
        self.block4_list = nn.ModuleList([
            SwinBlock4(
                w_msa=self.w_msa_4,
                sw_msa=self.sw_msa_4,
                tcl=self.tcl_4,
                embed_shape=(4,4,24),
                dropout=dropout
            )
            for _ in range(2)
        ])

        # -------------------------------- classifier --------------------------

        # NOTE(review): the leading 2 in input_size matches the batch size of
        # the demo input; mode 0 is in ignore_modes, so it is presumably not
        # used -- confirm against TRL.
        self.classifier = TRL(input_size=(2,4,4,24),
                            output=(200,),
                            rank=(4, 4, 24, 200),
                            ignore_modes=(0,),
                            bias=bias,
                            device=device) # trl rank the same

        # The patch embedding receives no device argument and the TCL/TRL
        # modules are never moved explicitly, so move the whole assembled model
        # once. This is a no-op for the default device="cpu".
        self.to(device)

    def forward(self, x):
        """
        Run the full network.

        Args:
            x: input batch handed directly to the patch embedding.

        Returns:
            The output of the TRL classifier.
        """
        # --------
        # Stage 1
        # --------
        x = self.patch_embedding(x)
        print("Data size after patch:", x.shape)

        for blk in self.block1_list:
            x = blk(x)

        # --------
        # Stage 2
        # --------
        x = self.patch_merging_1(x)

        for blk in self.block2_list:
            x = blk(x)

        # --------
        # Stage 3
        # --------
        x = self.patch_merging_2(x)
        print("Data size after patch:", x.shape)

        for blk in self.block3_list:
            x = blk(x)

        print("Data size after patch:", x.shape)

        # --------
        # Stage 4
        # --------
        x = self.patch_merging_3(x)

        print("x after patch merging" , x.shape)

        for i, blk in enumerate(self.block4_list, 1):
            x = blk(x)
            print(f"Shape after Block4_{i}:", x.shape)

        # Average over dims 1 and 2 before classification.
        x = x.mean(dim=(1, 2))

        print(x.shape)

        output = self.classifier(x)

        print(output.shape)
        # Return the classifier result, not the pooled features.
        return output

In [120]:
# Random dummy batch: batch_size=2, channels=3, height=224, width=224
dummy_input = torch.randn(2, 3, 224, 224)

# Build the model on the CPU
model = SwinTransformer(
    img_size=224,
    patch_size=4,
    in_chans=3,
    embed_shape=(4, 4, 3),
    bias=True,
    device="cpu",
)

# Single forward pass; forward() prints the intermediate shapes itself
output = model(dummy_input)

# Output shape
# print(output.shape)

Data size after patch: torch.Size([2, 56, 56, 4, 4, 3])
Data size after patch: torch.Size([2, 14, 14, 4, 4, 12])
Data size after patch: torch.Size([2, 14, 14, 4, 4, 12])
x after patch merging torch.Size([2, 7, 7, 4, 4, 24])
Shape after Block4_1: torch.Size([2, 7, 7, 4, 4, 24])
Shape after Block4_2: torch.Size([2, 7, 7, 4, 4, 24])
torch.Size([2, 4, 4, 24])
torch.Size([2, 200])


In [121]:
# Sum the element counts of every parameter tensor yielded by the model.
total_params = sum(
    tensor.numel()
    for tensor in model.parameters()
)
print("Total parameters:", total_params)

Total parameters: 156699
