In [8]:
import torch
import torch.nn as nn
import torch.nn.init as init
import math
from einops import rearrange
import torch.nn.functional as F

class TCL_CHANGED(nn.Module):
    def __init__(self, input_size, rank, ignore_modes=(0,), bias=True, device='cuda'):
        """
        input_size: tuple, shape of the input tensor (e.g., (7,7,4,4,3) for a window)
        rank: tuple or int, target rank for the non-ignored modes (e.g., (4,4,3))
        ignore_modes: tuple or int, indices of dimensions to leave unchanged (e.g., (0,1) to preserve spatial grid)
        bias: bool, whether to add a bias parameter.
        device: device string.
        """
        super(TCL_CHANGED, self).__init__()
        
        alphabet = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQERSUVWXYZ'
        self.device = device
        self.bias = bias
        
        if isinstance(input_size, int):
            self.input_size = (input_size,)
        else:
            self.input_size = tuple(input_size)
        
        if isinstance(rank, int):
            self.rank = (rank,)
        else:
            self.rank = tuple(rank)
        
        if isinstance(ignore_modes, int):
            self.ignore_modes = (ignore_modes,)
        else:
            self.ignore_modes = tuple(ignore_modes)
        
        # Remove ignored modes from the input size
        new_size = []
        for i in range(len(self.input_size)):
            if i in self.ignore_modes:
                continue
            else:
                new_size.append(self.input_size[i])
        
        # Register bias if enabled
        if self.bias:
            # Bias shape is the same as the rank tensor
            self.register_parameter('b', nn.Parameter(torch.empty(self.rank, device=self.device), requires_grad=True))
        else:
            self.register_parameter('b', None)
            
        # Register factor matrices (one per mode being contracted)
        for i, r in enumerate(self.rank):
            self.register_parameter(f'u{i}', nn.Parameter(torch.empty((r, new_size[i]), device=self.device), requires_grad=True))
        
        # Dynamically build the einsum formula for tensor contraction.
        index = 0
        formula = ''
        core_str = ''
        extend_str = ''
        out_str = ''
        # Build input part and track core (contracted) vs. extended (ignored) dimensions
        for i in range(len(self.input_size)):
            formula += alphabet[index]
            if i not in self.ignore_modes:
                core_str += alphabet[index]
            else:
                extend_str += alphabet[index]
            index += 1
            if i == len(self.input_size) - 1:
                formula += ','
        
        # Build factor matrices part and output mapping
        for l in range(len(self.rank)):
            formula += alphabet[index]
            formula += core_str[l]
            out_str += alphabet[index]
            index += 1
            if l < len(self.rank) - 1:
                formula += ','
            elif l == len(self.rank) - 1:
                formula += '->'
        formula += extend_str + out_str
        
        self.out_formula = formula
        # Uncomment the following line to inspect the generated einsum formula:
        # print("Generated einsum formula:", self.out_formula)

        self.init_param()  # Initialize parameters

    def forward(self, x):
        """
        If the input x has an extra batch dimension (i.e. its dimension equals len(input_size)+1),
        insert an ellipsis in both the input and output parts of the einsum equation.
        """


        
        if x.dim() == len(self.input_size) + 1:
            input_part, output_part = self.out_formula.split("->")
            new_formula = "..." + input_part + "->..." + output_part
        else:
            new_formula = self.out_formula
        
        operands = [x]
        for i in range(len(self.rank)):
            operands.append(getattr(self, f'u{i}'))
        
        out = torch.einsum(new_formula, *operands)
        if self.bias:
            out += self.b
        return out

    def init_param(self):
        # Initialize factor matrices using Kaiming Uniform initialization
        for i in range(len(self.rank)):
            init.kaiming_uniform_(getattr(self, f'u{i}'), a=math.sqrt(5))
        if self.bias:
            bound = 1 / math.sqrt(self.input_size[0])
            init.uniform_(self.b, -bound, bound)



In [9]:
import torch
import torch.nn as nn


class TensorizedPatchMerging(nn.Module):
    """
    Tensorized patch‑merging for Swin‑like models with flexible concatenation
    along r1, r2 or C.

    Args
    ----
    input_size     : tuple  (B, H, W, r1, r2, C)
    in_embed_shape : tuple  (r1, r2, C)          – shape of a **single** input patch
    out_embed_shape: tuple  (r1', r2', C')       – target embedding shape per patch
    channel_mode   : int    {0,1,2}
                     0 → concatenate along r1  → (B, H/2, W/2, 4*r1,  r2,  C)
                     1 → concatenate along r2  → (B, H/2, W/2,  r1, 4*r2,  C)
                     2 → concatenate along C   → (B, H/2, W/2,  r1,  r2, 4*C)
    bias           : bool
    ignore_modes   : tuple – passed straight to TCL_CHANGED
    device         : str
    """

    def __init__(
        self,
        input_size=(16, 56, 56, 4, 4, 3),
        in_embed_shape=(4, 4, 3),
        out_embed_shape=(4, 4, 6),
        channel_mode: int = 2,
        bias: bool = True,
        ignore_modes=(0, 1, 2),
        device: str = "cuda",
    ):
        super().__init__()

        if channel_mode not in (0, 1, 2):
            raise ValueError("`channel_mode` must be 0, 1 or 2.")
        self.channel_mode = channel_mode

        # ----------  bookkeeping ----------
        self.in_r1, self.in_r2, self.in_C = in_embed_shape
        self.out_r1, self.out_r2, self.out_C = out_embed_shape
        self.in_dim = self.in_r1 * self.in_r2 * self.in_C
        self.out_dim = self.out_r1 * self.out_r2 * self.out_C

        # usual Swin constraint (kept here in case you still rely on it)
        if 4 * self.in_dim != 2 * self.out_dim:
            raise ValueError(
                f"Dimension mismatch: expected out_dim = 2 * in_dim, got "
                f"{self.out_dim} != {2 * self.in_dim}"
            )

        self.ignore_modes = ignore_modes
        self.bias = bias
        self.device = device
        self.input_size = input_size  # (B, H, W, r1, r2, C)

        # ----------  sizes after merging ----------
        B, H, W, r1, r2, C = self.input_size
        if channel_mode == 0:        # expand r1
            merged_shape = (4 * r1, r2, C)
            cat_dim = 3
        elif channel_mode == 1:      # expand r2
            merged_shape = (r1, 4 * r2, C)
            cat_dim = 4
        else:                        # expand C
            merged_shape = (r1, r2, 4 * C)
            cat_dim = 5
        self._cat_dim = cat_dim      # store for forward()

        self.tcl_input_size = (B, H // 2, W // 2, *merged_shape)
        self.norm = nn.LayerNorm(merged_shape)

        # ----------  tensorized linear ----------
        self.tcl = TCL_CHANGED(
            input_size=self.tcl_input_size,
            rank=out_embed_shape,
            ignore_modes=self.ignore_modes,
            bias=self.bias,
            device=self.device,
        )

    # --------------------------------------------------------------------- #
    #                                 forward                               #
    # --------------------------------------------------------------------- #
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x : (B, H, W, r1, r2, C)
        returns
        ------
        merged embeddings : (B, H/2, W/2, out_r1, out_r2, out_C)
        """
        B, H, W, r1, r2, C = x.shape
        if (r1, r2, C) != (self.in_r1, self.in_r2, self.in_C):
            raise ValueError("Input patch embedding shape mismatch.")

        # 2×2 window split
        tl = x[:, 0::2, 0::2]  # top‑left
        bl = x[:, 1::2, 0::2]  # bottom‑left
        tr = x[:, 0::2, 1::2]  # top‑right
        br = x[:, 1::2, 1::2]  # bottom‑right

        # concatenate along the requested mode
        x_merged = torch.cat([tl, bl, tr, br], dim=self._cat_dim)


        print(x_merged.shape)

        # norm + tensorized linear projection
        x_merged = self.norm(x_merged)
        return self.tcl(x_merged)


In [None]:
import torch
# from Tensorized_Layers.TCL_CHANGED import TCL_CHANGED  # Replace with your actual TCL module
# from your_module import TensorizedPatchMerging  # Replace with actual module/file name

# Step 1: Create input tensor on CPU
input_tensor = torch.randn(32, 14, 14, 2, 6, 32)  # Shape = (B, H, W, r1, r2, C)

# Step 2: Instantiate the module with CPU device
patch_merger = TensorizedPatchMerging(
    input_size=(32, 14, 14, 2, 6, 32),
    in_embed_shape=(2, 6, 32),
    out_embed_shape=(2, 6, 64),
    bias=True,
    ignore_modes=(0, 1, 2),
    device="cpu",             
    channel_mode=0            
)

# Step 3: Forward pass
output_tensor = patch_merger(input_tensor)


print(output_tensor.shape)


torch.Size([32, 7, 7, 8, 6, 32])
torch.Size([32, 7, 7, 2, 6, 64])


: 