In [1]:
import sys
sys.path.append('..')
import torch
import torch.nn as nn
from einops import rearrange

from Tensorized_Layers.TCL import   TCL as TCL_CHANGED
# from Tensorized_Layers.TRL import TRL

class PatchEmbedding(nn.Module):
    def __init__(self, input_size, patch_size, embed_dim, bias = True, device = 'cuda', ignore_modes = (0,1,2)):
        super(PatchEmbedding, self).__init__()
        self.input_size = input_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.bias = bias
        self.device = device
        self.ignore_modes = ignore_modes

        self.tcl_input_size = (self.input_size[0], self.input_size[2]//self.patch_size, self.input_size[3]//self.patch_size,
                                self.patch_size, self.patch_size, self.input_size[1]) # patched input image size
        self.tcl = TCL_CHANGED(input_size=self.tcl_input_size,
                            rank=self.embed_dim,
                            ignore_modes=self.ignore_modes,
                            bias=self.bias, 
                            device=self.device)
            

    def forward(self, x):
        x = rearrange(x, 
                        'b c (p1 h) (p2 w) -> b p1 p2 h w c',
                        h=self.patch_size, w=self.patch_size) # X = [B P1 P2 H W C]
        
        x = self.tcl(x) # X = [B P1 P2 D1 D2 D3]



        return x # patches

In [2]:
x = torch.randn(16, 3, 224, 224)

patch_embed = PatchEmbedding(
    input_size=(16, 3, 224, 224),
    patch_size=4,
    embed_dim=(4,4,3),
    bias=True,
    device='cpu'
).to('cpu')

# 3. Forward pass
out = patch_embed(x)

print(f"Output shape: {out.shape}")

Output shape: torch.Size([16, 56, 56, 4, 4, 3])


# trl patch embedding

In [5]:
import sys
sys.path.append('..')
import torch
import torch.nn as nn
from einops import rearrange

from Tensorized_Layers.TCL import   TCL as TCL_CHANGED
from Tensorized_Layers.TRL import   TRL
# from Tensorized_Layers.TRL import TRL

class TRLPatchEmbedding(nn.Module):
    def __init__(self, input_size, patch_size, embed_dim, bias = True, device = 'cuda', ignore_modes = (0,1,2)):
        super(TRLPatchEmbedding, self).__init__()
        self.input_size = input_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.bias = bias
        self.device = device
        self.ignore_modes = ignore_modes

        self.tcl_input_size = (self.input_size[0], self.input_size[2]//self.patch_size, self.input_size[3]//self.patch_size,
                                self.patch_size, self.patch_size, self.input_size[1]) # patched input image size

        rank = self.embed_dim + self.embed_dim

    
        self.trl = TRL(input_size=self.tcl_input_size,
                            output=self.embed_dim,
                            rank=rank,
                            ignore_modes=self.ignore_modes,
                            bias=self.bias, 
                            device=self.device)

    def forward(self, x):
        x = rearrange(x, 
                        'b c (p1 h) (p2 w) -> b p1 p2 h w c',
                        h=self.patch_size, w=self.patch_size) # X = [B P1 P2 H W C]
        

        print("shape of x is" , x.shape)
        
        x = self.trl(x) # X = [B P1 P2 D1 D2 D3]



        return x # patches

In [8]:
x = torch.randn(16, 3, 224, 224)

patch_embed = TRLPatchEmbedding(
    input_size=(16, 3, 224, 224),
    patch_size=4,
    embed_dim=(4,4,3),
    bias=True,
    device='cpu'
).to('cpu')

# 3. Forward pass
out = patch_embed(x)

print(f"Output shape: {out.shape}")

shape of x is torch.Size([16, 56, 56, 4, 4, 3])
Output shape: torch.Size([16, 56, 56, 4, 4, 3])
