In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import numpy as np
from einops import rearrange
from tqdm import tqdm
from pytorch_msssim import ssim
from torchsummary import summary

  from .autonotebook import tqdm as notebook_tqdm


In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class EfficientAttention(nn.Module):
    """Multi-head linear attention that never materialises an N x N score matrix.

    Keys are softmax-normalised over the token axis and contracted with the
    values into one ``head_dim x head_dim`` context per head; each query then
    reads from that context.

    Args:
        embed_dim: Size of the last (channel) dimension of the input.
        num_heads: Number of attention heads; must be positive and divide
            ``embed_dim`` evenly.
        dropout: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads=8, dropout=0.1):
        super().__init__()
        # Fail at construction time: otherwise the mismatch only shows up as an
        # opaque reshape error inside forward() (or a ZeroDivisionError below).
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

    def forward(self, x):
        """Apply attention to ``x`` of shape (batch, tokens, embed_dim).

        Returns:
            A tensor with the same shape as ``x``.
        """
        # One projection produces q, k and v side by side on the channel axis.
        qkv = self.qkv_proj(x).chunk(3, dim=-1)
        q, k, v = (
            rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads) for t in qkv
        )
        # Normalise keys across tokens (dim=-2 is the token axis after rearrange).
        # NOTE(review): q is used without any normalisation - confirm intended.
        k = k.softmax(dim=-2)
        # Contract over tokens first: (d x e) context per head.
        context = torch.einsum('bhnd,bhne->bhde', k, v)
        attn = torch.einsum('bhnd,bhde->bhne', q, context) * self.scale
        out = rearrange(attn, 'b h n d -> b n (h d)')
        return self.dropout(self.out_proj(out))

class EnhancedPatchEmbedding(nn.Module):
    """Convolutional stem that converts an image batch into a token sequence.

    Three strided convolutions run back to back (stride 2, stride 2, then
    stride ``patch_size // 2``), giving a total stride of
    ``4 * (patch_size // 2)`` before the spatial grid is flattened into tokens.

    Args:
        img_size: Currently unused by this module.
        patch_size: Sets kernel size and stride of the last convolution
            (both ``patch_size // 2``).
        in_channels: Number of channels of the input image.
        embed_dim: Channel width of the produced tokens.
    """

    def __init__(self, img_size=256, patch_size=16, in_channels=3, embed_dim=512):
        super().__init__()
        stem_width = embed_dim // 2
        patch_stride = patch_size // 2
        stages = [
            # Stage 1: wide-kernel stem, halves the resolution.
            nn.Conv2d(in_channels, stem_width, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(stem_width),
            nn.GELU(),
            # Stage 2: halves the resolution again and reaches full width.
            nn.Conv2d(stem_width, embed_dim, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(embed_dim),
            nn.GELU(),
            # Stage 3: non-overlapping patches (kernel == stride, no padding).
            nn.Conv2d(embed_dim, embed_dim, kernel_size=patch_stride, stride=patch_stride, padding=0),
            nn.BatchNorm2d(embed_dim),
        ]
        self.proj = nn.Sequential(*stages)

    def forward(self, x):
        """Map (batch, channels, height, width) images to (batch, tokens, embed_dim)."""
        feature_map = self.proj(x)
        return rearrange(feature_map, 'b c h w -> b (h w) c')


class EnhancedTransformerEncoder(nn.Module):
    """Stack of pre-norm attention + MLP blocks that also exposes every block's output.

    Args:
        embed_dim: Token width used throughout the stack.
        num_heads: Head count passed to each ``EfficientAttention``.
        num_layers: Number of blocks in the stack.
        dropout: Dropout probability used in the attention and MLP branches.
    """

    def __init__(self, embed_dim=512, num_heads=8, num_layers=4, dropout=0.1):
        super().__init__()
        hidden_dim = embed_dim * 4

        def build_block():
            # Sub-module names are part of the state_dict keys; keep them stable.
            return nn.ModuleDict({
                'norm1': nn.LayerNorm(embed_dim),
                'attn': EfficientAttention(embed_dim, num_heads, dropout),
                'norm2': nn.LayerNorm(embed_dim),
                'mlp': nn.Sequential(
                    nn.Linear(embed_dim, hidden_dim),
                    nn.GELU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim, embed_dim),
                    nn.Dropout(dropout),
                ),
            })

        self.layers = nn.ModuleList([build_block() for _ in range(num_layers)])
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Run every block on ``x``.

        Returns:
            A pair ``(normed, skips)``: the final output after ``self.norm``,
            and a list with the (un-normalised) output of each block in order.
        """
        skips = []
        for block in self.layers:
            # Residual attention branch, then residual MLP branch (both pre-norm).
            x = x + block['attn'](block['norm1'](x))
            x = x + block['mlp'](block['norm2'](x))
            skips.append(x)
        return self.norm(x), skips

class FeatureAlignmentModule(nn.Module):
    """Project encoder tokens to a new width and resize them to a target feature map.

    Args:
        in_dim: Channel width of the incoming tokens.
        out_dim: Channel width of the produced feature map.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        token_proj = [nn.Linear(in_dim, out_dim), nn.LayerNorm(out_dim), nn.GELU()]
        self.proj = nn.Sequential(*token_proj)
        # Single-channel 3x3 conv applied to the channel-mean map in forward().
        self.spatial_align = nn.Conv2d(1, 1, kernel_size=3, padding=1)

    def forward(self, x, target_size):
        """Turn (batch, tokens, in_dim) tokens into a (batch, out_dim, H, W) map.

        The token axis is folded into a grid whose height is
        ``int(tokens ** 0.5)``; ``target_size`` is the ``(H, W)`` to resize to.
        """
        _, num_tokens, _ = x.shape      # also rejects inputs that are not 3-D
        _out_h, _out_w = target_size    # also rejects sizes that are not 2 elements
        grid_side = int(num_tokens ** 0.5)

        tokens = self.proj(x)
        fmap = rearrange(tokens, 'b (h w) c -> b c h w', h=grid_side)
        fmap = F.interpolate(fmap, size=target_size, mode='bilinear', align_corners=False)

        # The (batch, 1, H, W) correction broadcasts across all channels.
        channel_mean = fmap.mean(dim=1, keepdim=True)
        return fmap + self.spatial_align(channel_mean)

class EnhancedSSMDecoder(nn.Module):
    """Up-sampling decoder that fuses encoder skip tokens at every stage.

    Six up-sampling blocks and four skip-alignment modules are constructed.
    ``forward`` walks them with ``zip``, so only the first four up blocks are
    ever executed; the last two are registered but not called there.

    Args:
        embed_dim: Channel width of the incoming tokens.
        out_channels: Number of channels of the decoded image.
    """

    def __init__(self, embed_dim=512, out_channels=3):
        super().__init__()
        self.embed_dim = embed_dim

        # One aligner per executed stage; widths follow the up-block outputs.
        self.skip_align = nn.ModuleList([
            FeatureAlignmentModule(embed_dim, embed_dim // divisor)
            for divisor in (2, 4, 8, 16)
        ])

        # Channel ladder embed_dim, /2, /4, ... /64 -> six (in, out) pairs.
        widths = [embed_dim // (2 ** step) for step in range(7)]
        self.up_blocks = nn.ModuleList([
            self._make_up_block(c_in, c_out)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ])

        # Consumes embed_dim // 16 channels, i.e. the output of the 4th up block.
        head_width = embed_dim // 16
        self.final_conv = nn.Sequential(
            nn.Conv2d(head_width, head_width, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(head_width, out_channels, kernel_size=1),
        )

    def _make_up_block(self, in_ch, out_ch):
        """Build a 2x up-sampling block mapping ``in_ch`` to ``out_ch`` channels."""
        stages = [
            # Expand to 4 * out_ch so PixelShuffle(2) leaves out_ch channels.
            nn.Conv2d(in_ch, out_ch * 4, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.BatchNorm2d(out_ch),
            nn.GELU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.GELU(),
        ]
        return nn.Sequential(*stages)

    def forward(self, x, skip_connections):
        """Decode (batch, tokens, embed_dim) tokens into an image in [0, 1].

        Stage ``i`` adds the ``(i + 1)``-th skip counted from the end of
        ``skip_connections``; if there are fewer skips than stages, the first
        skip is reused for the remaining stages.
        """
        _, num_tokens, _ = x.shape
        x = rearrange(x, 'b (h w) c -> b c h w', h=int(num_tokens ** 0.5))

        deepest = len(skip_connections) - 1
        # zip stops at the shorter list (skip_align), which bounds the stage count.
        for stage, (upsample, align) in enumerate(zip(self.up_blocks, self.skip_align)):
            x = upsample(x)
            back = min(deepest, stage)
            x = x + align(skip_connections[-(back + 1)], x.shape[-2:])

        return torch.sigmoid(self.final_conv(x))

class EnhancedHybridMamba(nn.Module):
    """Image-to-image model: conv patch embedding -> transformer encoder -> decoder -> 2x head.

    Args:
        img_size: Stored on the instance and forwarded to the patch embedding.
        patch_size: Stored on the instance and forwarded to the patch embedding.
        in_channels: Number of channels of the input image.
        embed_dim: Token width shared by the embedding, encoder and decoder.
    """

    def __init__(self, img_size=256, patch_size=16, in_channels=3, embed_dim=512):
        super().__init__()
        self.img_size, self.patch_size = img_size, patch_size

        self.patch_embed = EnhancedPatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.transformer = EnhancedTransformerEncoder(embed_dim, num_layers=8)
        self.decoder = EnhancedSSMDecoder(embed_dim)

        # Output head: one more 2x bilinear up-sampling (e.g. 128x128 -> 256x256)
        # followed by a 3 -> 3 channel conv and a sigmoid to keep values in [0, 1].
        # NOTE(review): EnhancedSSMDecoder.forward already ends in a sigmoid, so
        # this is a second squashing - confirm that is intended.
        head = [
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(3, 3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        ]
        self.final_upsample = nn.Sequential(*head)

    def forward(self, x):
        """Embed, encode, decode and up-sample ``x``; returns the head's output."""
        tokens = self.patch_embed(x)
        encoded, skips = self.transformer(tokens)
        decoded = self.decoder(encoded, skips)
        return self.final_upsample(decoded)



In [18]:
# Smoke test: a single random RGB image must come back with the same shape.
model = EnhancedHybridMamba(img_size=256)
input_tensor = torch.randn(1, 3, 256, 256)  # Batch of 1 RGB image
with torch.no_grad():  # shape check only - no autograd graph needed
    output = model(input_tensor)
# Enforce the expectation instead of leaving it in a comment.
assert output.shape == input_tensor.shape, f"unexpected output shape: {tuple(output.shape)}"
print(output.shape)

torch.Size([1, 3, 256, 256])


In [19]:
# This rebinds `summary`: from here on it is torchinfo's function, not the
# torchsummary one imported in the first cell.
from torchinfo import summary

model = EnhancedHybridMamba()

# Kept as the last expression so the notebook renders the returned summary.
summary(
    model,
    input_size=(1, 3, 256, 256),
    col_names=["input_size", "output_size", "num_params", "mult_adds"],
)



Layer (type:depth-idx)                        Input Shape               Output Shape              Param #                   Mult-Adds
EnhancedHybridMamba                           --                        --                        --                        --
├─EnhancedTransformerEncoder: 1               --                        --                        --                        --
│    └─ModuleList: 2-1                        --                        --                        --                        --
├─EnhancedSSMDecoder: 1                       --                        --                        --                        --
│    └─ModuleList: 2-2                        --                        --                        --                        --
│    └─ModuleList: 2-3                        --                        --                        --                        --
├─EnhancedPatchEmbedding: 1-1                 [1, 3, 256, 256]          [1, 64, 512]              --    