# Notebook cell In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand, plus shortcut.

    All three convolutions are bias-free and each is followed by an 8-group
    GroupNorm; a ReLU follows the first two. The main path is added to the
    shortcut branch and the sum goes through a final ReLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels; the inner (bottleneck) width is
            ``out_ch // 4``.
        stride: stride of the 3x3 convolution and, when a projection shortcut
            is used, of its 1x1 convolution.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        mid_ch = out_ch // 4  # bottleneck width shared by the first two convs

        main_path = [
            # compress: in_ch -> mid_ch with a pointwise conv
            nn.Conv2d(in_ch, mid_ch, kernel_size=1, bias=False),
            nn.GroupNorm(8, mid_ch),
            nn.ReLU(inplace=True),
            # spatial 3x3 conv at the reduced width; carries the stride
            nn.Conv2d(mid_ch, mid_ch, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.GroupNorm(8, mid_ch),
            nn.ReLU(inplace=True),
            # expand: mid_ch -> out_ch with a pointwise conv (no ReLU before the add)
            nn.Conv2d(mid_ch, out_ch, kernel_size=1, bias=False),
            nn.GroupNorm(8, out_ch),
        ]
        self.conv = nn.Sequential(*main_path)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

        # The input can be added as-is only when neither the resolution nor the
        # channel count changes; otherwise project it with a strided 1x1 conv.
        shape_preserved = stride == 1 and in_ch == out_ch
        if shape_preserved:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
                nn.GroupNorm(8, out_ch),
            )

    def forward(self, x):
        """Return ReLU(main_path(x) + shortcut(x))."""
        transformed = self.conv(x)
        skip = self.shortcut(x)
        return self.relu(transformed + skip)

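# =============================================================================
# Multi-head attention used by POLA (S = current patch, T = its neighborhood):
#   h_i = Attention(L_i^Q(S), L_i^K(T), L_i^V(T))
#   H = Concat(h_1, h_2, ..., h_n)
#   U = L^O(H)
# =============================================================================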
class POLA_Attention(nn.Module):
    """
    Patch-based OverLapping Attention (POLA)

    The feature map is cut into non-overlapping m×m patches. The m² tokens of
    each patch are the queries; keys and values come from the 9m² tokens of the
    3×3 block of patches centred on it. Patches on the border of the patch grid
    see replicated copies of the edge patches in place of missing neighbours.
    Attention is multi-head scaled dot-product attention followed by an output
    projection.

    Args:
        channels: feature dimension d of the input map; must be a multiple of
            ``num_heads``.
        patch_size: side length m of a square patch, in pixels.
        num_heads: number of attention heads n.

    Raises:
        ValueError: if ``patch_size`` or ``num_heads`` is not positive, or if
            ``channels`` is not divisible by ``num_heads``.
    """
    def __init__(self, channels, patch_size=4, num_heads=8):
        super(POLA_Attention, self).__init__()
        if patch_size < 1:
            raise ValueError(f"patch_size must be positive, got {patch_size}")
        if num_heads < 1:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        # Fail at construction time: otherwise the head split in forward()
        # breaks with an opaque shape error.
        if channels % num_heads != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by num_heads ({num_heads})"
            )

        self.patch_size = patch_size      # m: size of each square patch (m×m pixels)
        self.num_heads = num_heads        # n: number of attention heads (paper uses n=8)
        self.head_dim = channels // num_heads  # d_k = d / n
        self.scale = self.head_dim ** -0.5     # 1/sqrt(d_k) scaling factor, Eq. (1)

        # Separate linear projections for Q, K, V as specified in Eq. (2):
        #   L_i^Q projects the current patch S
        #   L_i^K, L_i^V project the neighborhood T
        self.q_proj = nn.Linear(channels, channels)   # L^Q: query projection
        self.k_proj = nn.Linear(channels, channels)   # L^K: key projection
        self.v_proj = nn.Linear(channels, channels)   # L^V: value projection
        self.out_proj = nn.Linear(channels, channels)  # L^O: output projection, Eq. (4)

    def _gather_neighborhood(self, tokens):
        """Collect, for every patch, the tokens of its 3×3 patch neighbourhood.

        Args:
            tokens: (B, n_h, n_w, m², C) per-patch tokens.

        Returns:
            (B, n_h, n_w, 9m², C): tokens of the 9 surrounding patches
            (row-major over the 3×3 window), border patches replicated.
        """
        B, n_h, n_w, m2, C = tokens.shape
        # Treat every patch as one "pixel" with m²*C features: (B, m²*C, n_h, n_w)
        grid = tokens.reshape(B, n_h, n_w, m2 * C).permute(0, 3, 1, 2)
        # Pad the patch grid by one patch on each side so border patches also
        # have a full 3×3 window.
        grid = F.pad(grid, (1, 1, 1, 1), mode='replicate')
        # Sliding 3×3 windows over patch rows and columns: (B, m²*C, n_h, n_w, 3, 3)
        windows = grid.unfold(2, 3, 1).unfold(3, 3, 1)
        # (B, n_h, n_w, 3, 3, m²*C) → merge window and token dims → (B, n_h, n_w, 9m², C)
        windows = windows.permute(0, 2, 3, 4, 5, 1)
        return windows.reshape(B, n_h, n_w, 9 * m2, C)

    def forward(self, x):
        """Apply POLA to a feature map.

        Args:
            x: tensor of shape (B, C, H, W) with C equal to ``channels``.

        Returns:
            Tensor of shape (B, C, H, W).
        """
        B, C, H, W = x.shape
        p = self.patch_size

        # --- Step 1: make H and W multiples of the patch size ---
        # Replicate-pad the bottom/right edge when needed; the padding is
        # cropped away again before returning. No-op for divisible sizes.
        pad_h, pad_w = (-H) % p, (-W) % p
        if pad_h or pad_w:
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='replicate')
        H_pad, W_pad = H + pad_h, W + pad_w

        # Number of patches along height and width
        n_h, n_w = H_pad // p, W_pad // p

        # --- Step 2: split the map into patches ---
        # (B, C, H, W) → (B, C, n_h, p, n_w, p) → (B, n_h, n_w, p, p, C)
        x_patches = x.view(B, C, n_h, p, n_w, p)
        x_patches = x_patches.permute(0, 2, 4, 3, 5, 1).contiguous()
        # Flatten each patch's pixels into tokens: S ∈ R^(m²×d) per patch
        x_patches = x_patches.view(B, n_h, n_w, p * p, C)

        # --- Step 3: project, then gather neighbourhoods ---
        # The K/V projections act on each token independently, so projecting
        # the patch tokens once and gathering afterwards yields the same K/V
        # as projecting the 9×-duplicated neighbourhood T, at 1/9 of the cost.
        Q = self.q_proj(x_patches)                               # (B, n_h, n_w, m², C)
        K = self._gather_neighborhood(self.k_proj(x_patches))    # (B, n_h, n_w, 9m², C)
        V = self._gather_neighborhood(self.v_proj(x_patches))    # (B, n_h, n_w, 9m², C)

        # Split channels into heads and move the head dim before the token dim:
        # (B, n_h, n_w, tokens, C) → (B, n_h, n_w, heads, tokens, head_dim)
        Q = Q.view(B, n_h, n_w, p * p, self.num_heads, self.head_dim).permute(0, 1, 2, 4, 3, 5)
        K = K.view(B, n_h, n_w, 9 * p * p, self.num_heads, self.head_dim).permute(0, 1, 2, 4, 3, 5)
        V = V.view(B, n_h, n_w, 9 * p * p, self.num_heads, self.head_dim).permute(0, 1, 2, 4, 3, 5)

        # --- Step 4: scaled dot-product attention, Eq. (1) ---
        # Attention matrix shape: (B, n_h, n_w, heads, m², 9m²)
        # Each of the m² query tokens attends to all 9m² neighborhood tokens
        attn = (Q @ K.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        # Weighted sum of values: (B, n_h, n_w, heads, m², head_dim)
        out = attn @ V

        # --- Step 5: merge heads and project output (Eq. 3-4) ---
        # H = Concat(h_1, ..., h_n): move the head dim back next to head_dim
        out = out.permute(0, 1, 2, 4, 3, 5).contiguous()
        out = out.view(B, n_h, n_w, p * p, C)
        # U = L^O(H)
        out = self.out_proj(out)

        # --- Step 6: reconstruct the spatial feature map from patches ---
        # (B, n_h, n_w, m², C) → (B, n_h, n_w, p, p, C)
        out = out.view(B, n_h, n_w, p, p, C)
        # Permute back to (B, C, n_h, p, n_w, p) then merge to (B, C, H_pad, W_pad)
        out = out.permute(0, 5, 1, 3, 2, 4).contiguous()
        out = out.view(B, C, H_pad, W_pad)

        # Drop the rows/columns that were only added for patch alignment.
        if pad_h or pad_w:
            out = out[:, :, :H, :W].contiguous()

        return out

class PT_FeatureExtractor(nn.Module):
    """Multi-scale feature extractor: bottleneck backbone, top-down pyramid, POLA.

    A stem convolution and four stride-2 ``ConvBlock`` stages produce feature
    maps C1..C5. C5, C4 and C3 are reduced to 256 channels by 1x1 lateral
    convolutions and merged top-down (upsample + add). Each merged map is then
    smoothed by a 3x3 conv block and passed through one ``POLA_Attention``
    module whose weights are shared by all three levels.

    ``forward`` takes a (B, 3, H, W) image batch and returns a dict with keys
    ``"f3"``, ``"f4"`` and ``"f5"`` (finest to coarsest), each with 256
    channels.
    """
    def __init__(self):
        super(PT_FeatureExtractor, self).__init__()

        # Initial convolution: 3 → 64 channels at full resolution
        self.conv0 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(8, 64),
            nn.ReLU(inplace=True)
        )
        # Progressive downsampling with increasing channels
        # Each layer halves spatial dims (stride=2) and increases channels
        self.layer1 = ConvBlock(64, 256, stride=2)     # 1/2 resolution, 256 ch
        self.layer2 = ConvBlock(256, 512, stride=2)    # 1/4 resolution, 512 ch
        self.layer3 = ConvBlock(512, 1024, stride=2)   # 1/8 resolution, 1024 ch
        self.layer4 = ConvBlock(1024, 2048, stride=2)  # 1/16 resolution, 2048 ch

        # 1x1 lateral convolutions bringing every pyramid level to 256 channels
        self.latlayer4 = nn.Conv2d(2048, 256, kernel_size=1)  # C5 → 256 ch
        self.latlayer3 = nn.Conv2d(1024, 256, kernel_size=1)  # C4 → 256 ch
        self.latlayer2 = nn.Conv2d(512, 256, kernel_size=1)   # C3 → 256 ch

        # A single attention module, applied to all three output levels
        self.pola_enhancer = POLA_Attention(channels=256)

        # Per-level 3x3 smoothing blocks applied after the top-down merge
        self.out5 = nn.Sequential(nn.Conv2d(256, 256, 3, padding=1), nn.GroupNorm(8, 256), nn.ReLU(inplace=True))
        self.out4 = nn.Sequential(nn.Conv2d(256, 256, 3, padding=1), nn.GroupNorm(8, 256), nn.ReLU(inplace=True))
        self.out3 = nn.Sequential(nn.Conv2d(256, 256, 3, padding=1), nn.GroupNorm(8, 256), nn.ReLU(inplace=True))

    def forward(self, x):
        """Compute the three-level feature pyramid for an image batch.

        Args:
            x: tensor of shape (B, 3, H, W).

        Returns:
            Dict ``{"f3": ..., "f4": ..., "f5": ...}`` of 256-channel maps.
            The resolutions in the comments below assume H and W are
            divisible by 16.
        """
        # Extract hierarchical features at decreasing resolutions
        c1 = self.conv0(x)      # (B, 64, H, W)
        c2 = self.layer1(c1)    # (B, 256, H/2, W/2)
        c3 = self.layer2(c2)    # (B, 512, H/4, W/4)
        c4 = self.layer3(c3)    # (B, 1024, H/8, W/8)
        c5 = self.layer4(c4)    # (B, 2048, H/16, W/16)

        # Reduce C5 to 256 channels → M5
        m5 = self.latlayer4(c5)

        # Top-down merge. Upsample to the lateral map's own size rather than
        # by a fixed factor of 2: a stride-2 stage rounds odd sizes up, so the
        # coarser map is not always exactly half the finer one and a fixed 2x
        # upsample would then not match the lateral map in the addition.
        lat4 = self.latlayer3(c4)
        m4 = F.interpolate(m5, size=lat4.shape[-2:], mode='nearest') + lat4
        lat3 = self.latlayer2(c3)
        m3 = F.interpolate(m4, size=lat3.shape[-2:], mode='nearest') + lat3

        # Apply 3×3 smoothing conv, then POLA overlapping attention
        # POLA enables each patch to attend to its 3×3 neighborhood,
        # improving local feature matching for cost volume construction
        f5 = self.pola_enhancer(self.out5(m5))  # (B, 256, H/16, W/16) — coarsest
        f4 = self.pola_enhancer(self.out4(m4))  # (B, 256, H/8, W/8)   — middle
        f3 = self.pola_enhancer(self.out3(m3))  # (B, 256, H/4, W/4)   — finest

        # Return multi-scale features for coarse-to-fine depth estimation
        return {"f3": f3, "f4": f4, "f5": f5}


if __name__ == "__main__":
    # Smoke test: push one random image at the paper's training resolution
    # (512×640) through the extractor and list the shape of each output level.
    extractor = PT_FeatureExtractor()
    dummy_image = torch.randn(1, 3, 512, 640)
    pyramid = extractor(dummy_image)
    for level_name in pyramid:
        print(f"{level_name}: {pyramid[level_name].shape}")

# Output of the run above:
# f3: torch.Size([1, 256, 128, 160])
# f4: torch.Size([1, 256, 64, 80])
# f5: torch.Size([1, 256, 32, 40])
