In [1]:
# import mamba_ssm
import networks.mamba_sys as mamba_sys
import torch
import torch as tc
import my_mambanet
import importlib
import torch.nn as nn

from einops import rearrange, repeat

importlib.reload(my_mambanet)


<module 'my_mambanet' from '/home/collettida/myprojects/ExploringMamba/code/my_mambanet.py'>

In [2]:
class PatchEmbed2D(nn.Module):
    r""" Image to Patch Embedding
    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, groups=1, norm_layer=None, **kwargs):
        super().__init__()
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        self.proj = nn.Conv2d(in_chans, 
                              embed_dim, 
                              kernel_size=patch_size,
                              stride=patch_size,
                              groups=groups)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None
    
    def forward(self, x):
        x = self.proj(x)
        m = nn.ZeroPad2d((0, 1, 0, 1))
        # x = m(x)
        x = x.permute(0, 2, 3, 1)
        if self.norm is not None:
            x = self.norm(x)
        return x

inp = torch.randn(16,30,21,21).cuda()
m = nn.ZeroPad2d((0, 3, 0, 3))
inp_p = m(inp)
# int = torch.randn(16,1,21,21).cuda()
pb2d = PatchEmbed2D(img_size=21, patch_size=3, in_chans=30, embed_dim=96, groups=1).cuda()
out = pb2d(inp)
print(out.shape)
out = pb2d(inp_p)
print(out.shape)



torch.Size([16, 7, 7, 96])
torch.Size([16, 8, 8, 96])


In [3]:
class PatchEmbedVideo(nn.Module):
    """ Video to Patch Embedding
    Args:
        img_size (int): Frame size.  Default: 21
        in_chans (int): Number of input image channels. Default: 3
        n_frames (int): Number of frames. Default: 8
        patch_size (int): Patch token size. Default: 3
        stride (int): Stride of the patch embedding. Default: None
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """
    def __init__(self, image_size=21, n_frames=8, patch_size=3, 
                 stride=None, in_chans=2, embed_dim=96, groups=1, 
                 norm_layer=None, **kwargs):
        super().__init__()
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        if stride is None:
            stride = patch_size
        if isinstance(stride, int):
            stride = (stride, stride)

        self.proj = nn.Conv2d(in_chans, 
                              embed_dim, 
                              kernel_size=patch_size,
                              stride=patch_size,
                              groups=groups)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None 
        
    def forward(self, x: tc.Tensor):
        """ Forward function.
        Args:
            x: (tc.Tensor) Input video of shape (B, T, C, H, W)
        Returns:
            tc.Tensor: Patch embedded video of shape (B, H', W', embed_dim)
        """
        B, T, C, H, W = x.shape
         
        for t in range(T):
            x_t = x[:, t, :, :, :]  # (B, C, H, W)
            x_t = self.proj(x_t)  # (B, embed_dim, H', W')
            if t == 0:
                x_out = x_t.unsqueeze(1)  # (B, 1, embed_dim, H', W')
            else:
                x_out = tc.cat((x_out, x_t.unsqueeze(1)), dim=1)  # (B, T, embed_dim, H', W')
        
        
        x_out = tc.reshape(x_out, (B, T, x_out.shape[2], x_out.shape[3]*x_out.shape[4]))
        x_out = x_out.permute(0, 1, 3, 2)  # (B, T, H'*W', embed_dim)
        if self.norm is not None:
            x_out = self.norm(x_out)
        
        return x_out

inp = torch.randn(16,16,2,24,24).cuda()
m = nn.ZeroPad2d((0, 3, 0, 3))
# inp_p = m(inp)
# int = torch.randn(16,1,21,21).cuda()
pbv = PatchEmbedVideo().cuda()
out = pbv(inp)
print(out.shape)
# out = pbv(inp_p)
# print(out.shape)

torch.Size([16, 16, 64, 96])


In [4]:
class PatchMerging2D(nn.Module):
    r""" PatchMerging2D performs spatial downsampling by a factor of 2.
        It groups each 2x2 neighborhood of tokens, concatenates their channels
        (C → 4C), applies LayerNorm, and projects them to a lower-dimensional
        embedding (4C → 2C). The output has half the spatial resolution and
        twice the channel dimension: (B, H, W, C) → (B, H/2, W/2, 2C).
        Patch Merging Layer.
    Args:
        dim (int): Resolution of input token.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        B, H, W, C = x.shape

        SHAPE_FIX = [-1, -1]
        if H % 2 != 0:
            print(f"Warning, x.shape {x.shape} is not match even ===========", flush=True)
            SHAPE_FIX[0] = H // 2
            SHAPE_FIX[1] = W // 2
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        
        if SHAPE_FIX[0] > 0:
            x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]

        x = tc.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, H//2, W//2, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)       # FC layer to reduce dimension

        return x

pm2d = PatchMerging2D(dim=96).cuda()
inp2 = torch.randn(16,8,8,96).cuda()
out2 = pm2d(inp2)
print(out2.shape)

torch.Size([16, 4, 4, 192])


In [5]:
class PatchExpand(nn.Module):
    """
    PatchExpand layer.

    Upsamples the input feature map by a factor of 2 by converting channel
    information into spatial resolution. The operation applies a linear
    projection followed by a PixelShuffle-style rearrangement:
    (B, H, W, C) → (B, 2H, 2W, C/2).
    """
    def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.dim_scale = dim_scale
        self.expand = nn.Linear(
            dim,  2*dim, bias = False) if dim_scale == 2 else nn.Identity()
            # applied to last dimension
        self.norm = norm_layer(dim // dim_scale)

    def forward(self, x):
        x = self.expand(x)      # B, H, W, C -> B, H, W, 2C
        B, H, W, C = x.shape
        x = rearrange(x, 'b h w (p1 p2 c) -> b (h p1) (w p2) c', p1=2, p2=2, c=C//4)
        x = self.norm(x)
        return x

pe = PatchExpand(dim=96*2).cuda()
inp3 = torch.randn(16,4,4,192).cuda()
out3 = pe(inp3)
print(out3.shape)

torch.Size([16, 8, 8, 96])


In [6]:
class FinalPatchExpand(nn.Module):
    """
    FinalPatchExpand layer.
    # a kind of unembedding
    Upsamples the input feature map by a factor of 4 by converting channel
    information into spatial resolution. The operation applies a linear
    projection followed by a PixelShuffle-style rearrangement:
    (B, H, W, C) → (B, 4H, 4W, C).
    """
    def __init__(self, dim, dim_scale=4, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.dim_scale = dim_scale
        self.expand = nn.Linear(dim, self.dim_scale**2 * dim, bias = False)
            # applied to last dimension
        self.output_dim = dim            
        self.norm = norm_layer(self.output_dim)

    def forward(self, x):
        x = self.expand(x)
        B, H, W, C = x.shape
        x = rearrange(x, 'b h w (p1 p2 c) -> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale**2)
        x = self.norm(x)
        return x    
    
fpex4 = FinalPatchExpand(dim=96, dim_scale=3).cuda()
out_l = nn.Conv2d(in_channels=96,out_channels=2,kernel_size=1,bias=False).cuda()
out_l = nn.Conv2d(in_channels=16,out_channels=1,kernel_size=1,bias=False).cuda()
inp4 = torch.randn(16,16,64,96*3).cuda()
out4 = inp4
# out4 = fpex4(inp4)
out4 = out_l(out4)
print(f"out4.shape = {out4.shape}")




out4.shape = torch.Size([16, 1, 64, 288])


In [8]:
def s_flatten(x):
    B, T, H, W, C = x.shape
    x = x.view(B, T, H*W, C)
    return x

def s_unflatten(x, H, W):
    B, T, HW, C = x.shape
    x = x.view(B, T, H, W, C)
    return x

In [11]:
class FinalPatchExpandVideo(nn.Module):
    """
    FinalPatchExpand layer.
    # a kind of unembedding
    Upsamples the input feature map by a factor of 4 by converting channel
    information into spatial resolution. The operation applies a linear
    projection followed by a PixelShuffle-style rearrangement:
    (B, H, W, C) → (B, 4H, 4W, C).
    """
    def __init__(self, T, dim, dim_scale=3, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.dim_scale = dim_scale
        self.expand = nn.Linear(dim, self.dim_scale**2 * dim, bias = False)
            # applied to last dimension
        self.output_dim = dim            
        self.norm = norm_layer(self.output_dim)
        self.condenseT = nn.Conv2d(in_channels=T, out_channels=1, kernel_size=1, bias=False)

    def forward(self, x):
        x = self.condenseT(x)
        print(f"x.shape after condenseT: {x.shape}")
        x = s_unflatten(x, H=8, W=8).squeeze(1)
        x = self.expand(x)
        B, H, W, C = x.shape
        x = rearrange(x, 'b h w (p1 p2 c) -> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale**2)
        x = self.norm(x)
        return x    
 
fpex4v = FinalPatchExpandVideo(dim=96, T = 16, dim_scale=3).cuda()
inp4v = torch.randn(16,16,64,96).cuda()
out4v = fpex4v(inp4v)
print(f"out4v.shape = {out4v.shape}")

x.shape after condenseT: torch.Size([16, 1, 64, 96])
out4v.shape = torch.Size([16, 24, 24, 96])
