<a href="https://colab.research.google.com/github/NethmiAmasha/Edge-Detection-with-Mamba/blob/main/EDMB_fromscratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Colab cell: report which PyTorch build is installed and whether a GPU is visible.
import torch
print("cuda available:", torch.cuda.is_available())
print("torch:", torch.__version__)
print("cuda device:", None if not torch.cuda.is_available() else torch.cuda.get_device_name(0))

cuda available: True
torch: 2.8.0+cu126
cuda device: Tesla T4


In [2]:
pip install -q torch torchvision

In [3]:
pip install -q tqdm opencv-python matplotlib pillow scipy

In [4]:
pip install -q git+https://github.com/Dao-AILab/causal-conv1d@v1.1.1

  Preparing metadata (setup.py) ... done
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 446.5/446.5 kB 29.7 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
  Preparing metadata (setup.py) ... done
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.9/44.9 kB 4.4 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 180.7/180.7 kB 18.0 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.2/3.2 MB 119.0 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.8/42.8 kB 4.0 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 74.6/74.6 kB 7.8 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 264.7/264.7 kB 28.0 MB/s eta 0:00:00

High-resolution encoder (Eh)

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class HighResEncoder(nn.Module):
    """
    Minimal high-resolution encoder (Eh).

    Produces two feature maps, BOTH at the input's spatial resolution and with
    ``base_ch`` channels:
      - f1: Conv3x3 (stride 1) -> BN -> ReLU applied to the input
        (keeps full resolution; useful for edges).
      - f2_refined: a stride-2 Conv3x3 -> BN -> ReLU branch (2*base_ch channels
        at ceil(H/2) x ceil(W/2)), bilinearly upsampled back to f1's size and
        projected to ``base_ch`` channels by a Conv3x3 -> BN -> ReLU "refine"
        layer.
    This is intentionally simple so you can read & understand every line.

    Args:
        in_ch: number of input channels.
        base_ch: channel width of both returned maps.

    forward(x) returns:
        [f1, f2_refined], each of shape (B, base_ch, H, W) for x of shape
        (B, in_ch, H, W).
    """
    def __init__(self, in_ch=3, base_ch=16):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_ch, base_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(base_ch),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(base_ch, base_ch*2, kernel_size=3, stride=2, padding=1),  # downsamples by 2
            nn.BatchNorm2d(base_ch*2),
            nn.ReLU(inplace=True)
        )
        # small extra conv to produce a refined same-res feature from upsampled low-res
        self.refine = nn.Sequential(
            nn.Conv2d(base_ch*2, base_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(base_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # x: (B, in_ch, H, W)
        f1 = self.conv1(x)               # (B, base_ch, H, W)
        f2 = self.conv2(f1)              # (B, base_ch*2, ceil(H/2), ceil(W/2))
        # upsample f2 back to input resolution (f1's exact size, so odd H/W work) and refine
        f2_up = F.interpolate(f2, size=f1.shape[-2:], mode='bilinear', align_corners=False)
        f2_refined = self.refine(f2_up)  # (B, base_ch, H, W)
        # return a list of features consistent with paper notation [f1 (HR), f2 (down->HR)]
        return [f1, f2_refined]

# Quick shape test
if __name__ == "__main__":  # true when this runs as a notebook cell
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = HighResEncoder(in_ch=3, base_ch=16).to(device)
    x = torch.randn(2,3,320,320).to(device)   # batch of 2, 320x320 images
    with torch.no_grad():  # shape check only: no autograd graph needed
        feats = model(x)
    print("f1 shape:", feats[0].shape)
    print("f2_refined shape:", feats[1].shape)
    # Both Eh outputs are full-resolution maps with base_ch channels.
    assert feats[0].shape == feats[1].shape == (2, 16, 320, 320)

f1 shape: torch.Size([2, 16, 320, 320])
f2_refined shape: torch.Size([2, 16, 320, 320])


Global Encoder (Eg) and Fine-grained Encoder (Ef) — simplified CNN stand-ins for the paper's Mamba encoders

In [8]:
# Simplified MIXENC (Global + Fine-grained encoders)
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNNEncoder(nn.Module):
    """
    Lightweight three-stage CNN feature pyramid.

    Every stage is Conv3x3(padding=1) -> BatchNorm2d -> in-place ReLU.
    Stage 1 keeps the input resolution; stages 2 and 3 each use stride 2.
    forward(x) therefore returns [f1, f2, f3] at 1x, 1/2x and 1/4x of the
    input size, with base_ch, 2*base_ch and 4*base_ch channels.
    """

    @staticmethod
    def _conv_bn_relu(cin, cout, stride):
        """One Conv3x3 -> BN -> ReLU stage packed into an nn.Sequential."""
        return nn.Sequential(
            nn.Conv2d(cin, cout, 3, stride, 1),
            nn.BatchNorm2d(cout),
            nn.ReLU(inplace=True),
        )

    def __init__(self, in_ch=3, base_ch=32):
        super().__init__()
        width1, width2, width3 = base_ch, base_ch * 2, base_ch * 4
        self.stage1 = self._conv_bn_relu(in_ch, width1, stride=1)   # full resolution
        self.stage2 = self._conv_bn_relu(width1, width2, stride=2)  # 1/2 resolution
        self.stage3 = self._conv_bn_relu(width2, width3, stride=2)  # 1/4 resolution (2x again)

    def forward(self, x):
        pyramid = []
        out = x
        for stage in (self.stage1, self.stage2, self.stage3):
            out = stage(out)
            pyramid.append(out)
        return pyramid

class MIXENC_Light(nn.Module):
    """
    Simplified, CNN-only version of MIXENC (global + local encoders) for understanding.
    Mimics the structure of Eg (global) + Ef (fine-grained). Both branches here are
    ``SimpleCNNEncoder`` instances, not Mamba blocks.

    - ``global_encoder`` runs once on the full input.
    - ``local_encoder`` runs on the four quadrants of the input after it has been
      bilinearly upsampled by ``patch_scale``. The per-quadrant features are
      stitched back together level by level and resized to the spatial size of
      the matching global level.

    Args:
        base_ch: base channel width for both encoders.
        patch_scale: upsampling factor applied before the 2x2 patch split (default 1.2).
        in_ch: number of input channels (default 3).
    """
    def __init__(self, base_ch=32, patch_scale=1.2, in_ch=3):
        super().__init__()
        self.patch_scale = patch_scale
        self.global_encoder = SimpleCNNEncoder(in_ch=in_ch, base_ch=base_ch)
        self.local_encoder = SimpleCNNEncoder(in_ch=in_ch, base_ch=base_ch)

    def cat_patch(self, f00, f01, f10, f11, target_size):
        """
        Stitch a 2x2 grid of patch features and resize to ``target_size``.

        Args:
            f00, f01, f10, f11: (B, C, h, w) maps for the top-left, top-right,
                bottom-left and bottom-right patches. These torch.cat constraints apply:
                f00/f01 must share a height, f10/f11 must share a height, and the
                top and bottom rows must have the same total width.
            target_size: (H, W) output size, normally that of the matching global level.

        Returns:
            A (B, C, *target_size) tensor.
        """
        top = torch.cat([f00, f01], dim=3)
        bottom = torch.cat([f10, f11], dim=3)
        combined = torch.cat([top, bottom], dim=2)
        return F.interpolate(combined, size=target_size, mode='bilinear', align_corners=False)

    def forward(self, x):
        """
        Returns:
          global_feats: list of feature maps from the global encoder
          local_feats: list of patch-fused feature maps, one per global level and
                       resized to that level's spatial size
        """
        # 1) Global features on the full image.
        global_feats = self.global_encoder(x)

        # 2) Upsample, then split into a 2x2 grid of quadrants. With odd sizes the
        #    bottom/right quadrants are one pixel larger; quadrants in the same row
        #    share a height and both rows share the same column split, which is
        #    what cat_patch needs.
        x_up = F.interpolate(x, scale_factor=self.patch_scale, mode='bilinear', align_corners=False)
        h_mid, w_mid = x_up.shape[-2] // 2, x_up.shape[-1] // 2
        patches = (
            x_up[..., :h_mid, :w_mid],   # top-left
            x_up[..., :h_mid, w_mid:],   # top-right
            x_up[..., h_mid:, :w_mid],   # bottom-left
            x_up[..., h_mid:, w_mid:],   # bottom-right
        )

        # 3) Encode each quadrant independently with the shared local encoder.
        per_patch = [self.local_encoder(p) for p in patches]  # 4 lists of per-level maps

        # 4) Regroup per level as (TL, TR, BL, BR), stitch, and match the global level size.
        local_feats = [
            self.cat_patch(*level, target_size=g.shape[2:])
            for g, level in zip(global_feats, zip(*per_patch))
        ]
        return global_feats, local_feats

# 🔍 Test with random input
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MIXENC_Light(base_ch=32).to(device)
x = torch.randn(1, 3, 320, 320).to(device)

with torch.no_grad():  # shape check only: no autograd graph needed
    global_feats, local_feats = model(x)

# Each level must produce a global and a fused-local map of identical shape.
assert len(global_feats) == len(local_feats)
for i, (g, l) in enumerate(zip(global_feats, local_feats)):
    assert g.shape == l.shape, f"level {i+1}: {g.shape} vs {l.shape}"
    print(f"Level {i+1}: global {g.shape}, local {l.shape}")

Level 1: global torch.Size([1, 32, 320, 320]), local torch.Size([1, 32, 320, 320])
Level 2: global torch.Size([1, 64, 160, 160]), local torch.Size([1, 64, 160, 160])
Level 3: global torch.Size([1, 128, 80, 80]), local torch.Size([1, 128, 80, 80])


In [9]:
# Encoders used by the shape check in the next cell, moved to `device`.
# Eh: high-resolution encoder (in_ch=3, base_ch=16).
model_Eh = HighResEncoder(3, 16).to(device=device)
# Eg + Ef: global encoder plus patch-wise fine-grained encoder (base_ch=32).
model_MIX = MIXENC_Light(32).to(device=device)

In [10]:
# Test both encoders and print feature shapes
with torch.no_grad():
    x = torch.randn(1, 3, 320, 320).to(device)
    f_high = model_Eh(x)
    g_global, g_local = model_MIX(x)

print("Eh outputs:")
for i, f in enumerate(f_high):
    print(f"  f{i+1}: {list(f.shape)}")

print("\nEg/Ef outputs (local_feats):")
for i, f in enumerate(g_local):
    print(f"  level{i+1}: {list(f.shape)}")

# Eh maps are all at input resolution; every fused local level matches its global level.
assert all(feat.shape[-2:] == x.shape[-2:] for feat in f_high)
assert len(g_local) == len(g_global)
assert all(loc.shape == glob.shape for glob, loc in zip(g_global, g_local))

Eh outputs:
  f1: [1, 16, 320, 320]
  f2: [1, 16, 320, 320]

Eg/Ef outputs (local_feats):
  level1: [1, 32, 320, 320]
  level2: [1, 64, 160, 160]
  level3: [1, 128, 80, 80]


Learnable Gaussian Distribution Decoder

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LGDDecoder_Light(nn.Module):
    """
    Lightweight stand-in for the Learnable Gaussian Distribution (LGD) decoder.

    The global feature map (Eg/Ef) is bilinearly resized to the high-res (Eh)
    map's spatial size and concatenated with it along channels. The result goes
    through two Conv3x3 -> BN -> ReLU layers and then into three 1x1 heads:
      - mu:          per-pixel mean (no activation)
      - sigma2:      per-pixel variance, passed through Softplus so it is >= 0
      - edge_logits: auxiliary edge map as raw logits (no sigmoid applied)
    """
    def __init__(self, in_ch_high=16, in_ch_global=128, mid_ch=64):
        super().__init__()
        # Shared trunk: two Conv3x3 -> BN -> ReLU layers, the first one reading the
        # concatenated (Eh + Eg/Ef) channels.
        trunk_layers = []
        for cin in (in_ch_high + in_ch_global, mid_ch):
            trunk_layers.extend([
                nn.Conv2d(cin, mid_ch, 3, padding=1),
                nn.BatchNorm2d(mid_ch),
                nn.ReLU(inplace=True),
            ])
        self.fuse = nn.Sequential(*trunk_layers)

        # Single-channel 1x1 prediction heads. The construction order
        # (mu, var, edge) fixes the parameter-initialisation RNG order.
        def single_channel_head():
            return nn.Conv2d(mid_ch, 1, kernel_size=1)

        self.mu_head = single_channel_head()
        self.var_head = nn.Sequential(single_channel_head(), nn.Softplus())
        self.edge_head = single_channel_head()

    def forward(self, f_high, f_global):
        out_hw = f_high.shape[-2:]
        global_resized = F.interpolate(f_global, size=out_hw, mode='bilinear', align_corners=False)
        trunk_out = self.fuse(torch.cat([f_high, global_resized], dim=1))
        return (
            self.mu_head(trunk_out),
            self.var_head(trunk_out),
            self.edge_head(trunk_out),
        )

# 🔍 Test with random features
device = "cuda" if torch.cuda.is_available() else "cpu"
decoder = LGDDecoder_Light(in_ch_high=16, in_ch_global=128, mid_ch=64).to(device)

# fake inputs shaped like the encoder outputs above:
# Eh f1 (16 ch, full res) and the deepest Eg/Ef level (128 ch, 1/4 res)
f_high = torch.randn(1, 16, 320, 320).to(device)
f_global = torch.randn(1, 128, 80, 80).to(device)

with torch.no_grad():  # shape/range check only: no autograd graph needed
    mu, sigma2, edge_logits = decoder(f_high, f_global)

# Every head predicts one channel at the high-res spatial size; Softplus keeps sigma2 >= 0.
assert mu.shape == sigma2.shape == edge_logits.shape == (1, 1, 320, 320)
assert bool((sigma2 >= 0).all()), "variance head must be non-negative"
print("mu:", mu.shape, "sigma2:", sigma2.shape, "edge_logits:", edge_logits.shape)
print("sigma2 min/max:", sigma2.min().item(), sigma2.max().item())

mu: torch.Size([1, 1, 320, 320]) sigma2: torch.Size([1, 1, 320, 320]) edge_logits: torch.Size([1, 1, 320, 320])
sigma2 min/max: 0.24327895045280457 1.2697547674179077
