In [None]:
import torch
import torch.nn as nn

############################################################
# 1) Drop Path Function & Module
############################################################

def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero whole samples of ``x`` along dim 0 (stochastic depth).

    Each index along the first dimension is kept with probability
    ``1 - drop_prob``. Kept samples are divided by ``1 - drop_prob`` so the
    expected value of the output matches the input.

    Args:
        x: Tensor whose first dimension indexes samples. Any rank >= 1 is
            accepted; the mask is broadcast over all remaining dimensions.
        drop_prob: Probability of dropping a sample, in ``[0, 1]``.
        training: Samples are only dropped when this is True.

    Returns:
        ``x`` itself when ``drop_prob == 0.0`` or ``training`` is False;
        otherwise a new tensor with the same shape as ``x``.

    Raises:
        ValueError: If dropping is active and ``drop_prob`` is outside
            ``[0, 1]``.
    """
    if drop_prob == 0.0 or not training:
        return x

    if not 0.0 <= drop_prob <= 1.0:
        raise ValueError(f"drop_prob must be in [0, 1], got {drop_prob}")

    keep_prob = 1 - drop_prob
    # One mask entry per sample with singleton trailing dims, so the mask
    # broadcasts against x whatever its rank is (not only 6D inputs).
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    mask = keep_prob + torch.rand(mask_shape, dtype=x.dtype, device=x.device)
    mask.floor_()
    # Skip the rescale when every sample is dropped (keep_prob == 0);
    # dividing by zero there would turn the output into NaN instead of 0.
    if keep_prob > 0.0:
        x = x / keep_prob
    return x * mask

class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` that follows the module's train/eval mode.

    Args:
        drop_prob: Drop probability forwarded to ``drop_path`` on every call.
    """

    def __init__(self, drop_prob=0.0):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply ``drop_path`` to ``x`` using the stored rate and ``self.training``."""
        rate, is_training = self.drop_prob, self.training
        return drop_path(x, drop_prob=rate, training=is_training)


In [None]:
class SwinBlock1(nn.Module):
    """Four pre-norm residual stages run in order: W-MSA, TCL, SW-MSA, TCL.

    Every stage computes ``x + drop_path(f(norm(x)))``. The output of the two
    attention stages additionally passes through ``nn.Dropout`` before the
    drop-path; the two TCL stages do not. The same ``tcl`` object is used for
    both TCL stages.

    Args:
        w_msa: Callable applied in stage 1 (after ``norm1``).
        sw_msa: Callable applied in stage 3 (after ``norm3``).
        tcl: Callable applied in stages 2 and 4 (after ``norm2`` / ``norm4``).
        embed_shape: ``normalized_shape`` given to each ``nn.LayerNorm``.
        dropout: Probability for the ``nn.Dropout`` on the attention outputs.
        drop_path_rate: Drop probability for the shared ``DropPath`` module.
    """

    def __init__(self, w_msa, sw_msa, tcl, embed_shape, dropout=0.0, drop_path_rate=0.0):
        super().__init__()
        # One LayerNorm per residual stage, registered as norm1 .. norm4.
        for stage_idx in range(1, 5):
            setattr(self, f"norm{stage_idx}", nn.LayerNorm(embed_shape))

        self.dropout = nn.Dropout(dropout)
        self.drop_path = DropPath(drop_path_rate)

        self.w_msa = w_msa
        self.sw_msa = sw_msa
        self.tcl = tcl

    def _residual_stage(self, x, norm, layer, with_dropout):
        """Return ``x + drop_path(layer(norm(x)))``, with optional dropout on the branch."""
        branch = layer(norm(x))
        if with_dropout:
            branch = self.dropout(branch)
        return x + self.drop_path(branch)

    def forward(self, x):
        """Run the four residual stages on ``x`` and return the result."""
        # (norm, sub-layer, apply dropout to the sub-layer output?)
        stages = (
            (self.norm1, self.w_msa, True),    # (1) Window MSA
            (self.norm2, self.tcl, False),     # (2) TCL
            (self.norm3, self.sw_msa, True),   # (3) Shifted Window MSA
            (self.norm4, self.tcl, False),     # (4) TCL
        )
        for norm, layer, with_dropout in stages:
            x = self._residual_stage(x, norm, layer, with_dropout)
        return x

############################################################
# 3) SwinTransformer passing drop_path_rate to the blocks
############################################################

class SwinTransformer(nn.Module):
    """Skeleton classifier that threads ``drop_path_rate`` into its Swin blocks.

    Pipeline in ``forward``: patch embedding -> add learned positional
    embedding -> ``block1_list`` -> ``patch_merging_1`` -> ``block2_list`` ->
    mean over dims 1 and 2 -> ``classifier``. Arguments shown as ``...`` are
    placeholders still to be filled in.

    Args:
        img_size: Forwarded to ``Patch_Embedding``; also used with
            ``patch_size`` to size the positional embedding.
        patch_size: Forwarded to ``Patch_Embedding``.
        in_chans: Forwarded to ``Patch_Embedding``.
        embed_shape: Forwarded to ``Patch_Embedding`` and ``SwinBlock1``; also
            gives the trailing dims of the positional embedding.
        bias: Forwarded to ``Patch_Embedding``.
        dropout: Forwarded to each ``SwinBlock1``.
        drop_path_rate: Forwarded unchanged to every block.
        device: Device on which the positional embedding is allocated.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=4,
                 in_chans=3,
                 embed_shape=(4,4,12),
                 bias=True,
                 dropout=0,
                 drop_path_rate=0.0,  # <---
                 device="cuda"):
        super(SwinTransformer, self).__init__()

        self.device = device

        # patch embedding, etc...
        self.patch_embedding = Patch_Embedding(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_shape=embed_shape,
            bias=bias
        )

        # block 1
        # NOTE(review): both SwinBlock1 entries below receive the SAME
        # w_msa_1 / sw_msa_1 / tcl_1 instances, so they share those
        # sub-modules' parameters - confirm this is intended.
        self.w_msa_1 = WindowMSA(...)
        self.sw_msa_1 = ShiftedWindowMSA(...)
        self.tcl_1    = TCL_CHANGED(...)
        self.block1_list = nn.ModuleList([
            SwinBlock1(
                w_msa=self.w_msa_1,
                sw_msa=self.sw_msa_1,
                tcl=self.tcl_1,
                embed_shape=embed_shape,
                dropout=dropout,
                drop_path_rate=drop_path_rate  # pass in
            )
            for _ in range(2)
        ])

        # block 2, block 3, block 4... do similarly
        self.patch_merging_1 = TensorizedPatchMerging(...)
        self.block2_list = nn.ModuleList([
            SwinBlock2(..., drop_path_rate=drop_path_rate) 
            for _ in range(2)
        ])

        # etc...

        self.classifier = TRL(...)

        # Size the positional embedding from the constructor arguments rather
        # than a fixed (1, 56, 56, 4, 4, 12); the defaults give that same shape.
        # Assumes Patch_Embedding returns
        # (B, img_size // patch_size, img_size // patch_size, *embed_shape)
        # - TODO confirm against Patch_Embedding.
        grid = img_size // patch_size
        self.pos_embedding = nn.Parameter(
            torch.randn(1, grid, grid, *embed_shape, device=self.device),
            requires_grad=True
        )

    def forward(self, x):
        """Return the classifier output for input ``x``."""
        x = self.patch_embedding(x)
        # Out-of-place add: avoids mutating the patch-embedding output, which
        # autograd may still need for the backward pass.
        x = x + self.pos_embedding

        for blk in self.block1_list:
            x = blk(x)

        x = self.patch_merging_1(x)

        for blk in self.block2_list:
            x = blk(x)
        
        # block3, block4, etc...

        # Average over dims 1 and 2 before classification.
        x = x.mean(dim=(1, 2))
        output = self.classifier(x)
        return output