# In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Residual Block
class ConvBlock(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 (optionally strided) -> 1x1 expand.

    The bottleneck width is ``out_ch // 4``. Every convolution is followed by
    ``GroupNorm`` with 8 groups, so both ``out_ch`` and ``out_ch // 4`` must be
    divisible by 8. If the stride is not 1 or the channel count changes, the
    skip path is a strided 1x1 projection followed by GroupNorm. Otherwise the
    skip path is the identity. The sum of the two paths goes through a final ReLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        stride: stride of the 3x3 convolution and of the projection shortcut.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        mid = out_ch // 4
        layers = [
            # 1x1 channel reduction
            nn.Conv2d(in_ch, mid, kernel_size=1, bias=False),
            nn.GroupNorm(8, mid),
            nn.ReLU(inplace=True),
            # 3x3 spatial convolution; this is where the stride is applied
            nn.Conv2d(mid, mid, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.GroupNorm(8, mid),
            nn.ReLU(inplace=True),
            # 1x1 expansion back to out_ch; no activation before the residual add
            nn.Conv2d(mid, out_ch, kernel_size=1, bias=False),
            nn.GroupNorm(8, out_ch),
        ]
        self.conv = nn.Sequential(*layers)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

        # Project the skip path whenever its shape would not match the main path.
        needs_projection = stride != 1 or in_ch != out_ch
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
                nn.GroupNorm(8, out_ch),
            )
            if needs_projection
            else nn.Identity()
        )

    def forward(self, x):
        main = self.conv(x)
        skip = self.shortcut(x)
        return self.relu(main + skip)

# POLA
class POLA_Attention(nn.Module):
    """
    Patch-based OverLapping Attention (POLA).

    The feature map is split into non-overlapping ``patch_size x patch_size``
    patches. Every pixel (query) of a patch attends to all pixels of the 3x3
    neighbourhood of patches centred on its own patch, i.e.
    ``9 * patch_size**2`` keys/values. The attention is multi-head, scaled
    dot-product. Neighbour patches that fall outside the patch grid are filled
    by replicating the border patches. The output has the same shape as the
    input.

    If H or W is not a multiple of ``patch_size``, the input is padded on the
    bottom/right by edge replication and the output is cropped back to (H, W).

    Args:
        channels: number of input/output channels; must be divisible by ``num_heads``.
        patch_size: side length of a square patch.
        num_heads: number of attention heads.

    Raises:
        ValueError: if ``channels`` is not divisible by ``num_heads``.
    """
    def __init__(self, channels, patch_size=4, num_heads=8):
        super(POLA_Attention, self).__init__()
        # Fail early: the per-head reshape in forward needs an exact split.
        if channels % num_heads != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by num_heads ({num_heads})"
            )
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(channels, channels)
        self.k_proj = nn.Linear(channels, channels)
        self.v_proj = nn.Linear(channels, channels)
        self.out_proj = nn.Linear(channels, channels)

    def forward(self, x):
        """Apply POLA to ``x`` of shape (B, C, H, W); returns a tensor of the same shape."""
        B, C, H, W = x.shape
        p = self.patch_size

        # Pad bottom/right so both spatial dims are multiples of the patch size.
        pad_h = (-H) % p
        pad_w = (-W) % p
        if pad_h or pad_w:
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='replicate')
        Hp, Wp = H + pad_h, W + pad_w
        n_h, n_w = Hp // p, Wp // p

        # (B, C, Hp, Wp) -> (B, n_h, n_w, p*p, C); tokens inside a patch are row-major.
        x_patches = x.reshape(B, C, n_h, p, n_w, p)
        x_patches = x_patches.permute(0, 2, 4, 3, 5, 1).reshape(B, n_h, n_w, p * p, C)

        # Treat each patch as a single grid cell with p*p*C channels. A 3x3
        # unfold over the replicate-padded patch grid then gathers the 9
        # neighbouring patches of every patch.
        grid = x_patches.reshape(B, n_h, n_w, p * p * C).permute(0, 3, 1, 2)
        grid = F.pad(grid, (1, 1, 1, 1), mode='replicate')
        neighbors = grid.unfold(2, 3, 1).unfold(3, 3, 1)  # (B, p*p*C, n_h, n_w, 3, 3)
        neighbors = neighbors.permute(0, 2, 3, 4, 5, 1).reshape(B, n_h, n_w, 9 * p * p, C)

        heads, hd = self.num_heads, self.head_dim
        # Split channels into heads: (B, n_h, n_w, heads, tokens, head_dim)
        Q = self.q_proj(x_patches).reshape(B, n_h, n_w, p * p, heads, hd).transpose(3, 4)
        K = self.k_proj(neighbors).reshape(B, n_h, n_w, 9 * p * p, heads, hd).transpose(3, 4)
        V = self.v_proj(neighbors).reshape(B, n_h, n_w, 9 * p * p, heads, hd).transpose(3, 4)

        attn = (Q @ K.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        out = attn @ V  # (B, n_h, n_w, heads, p*p, head_dim)
        out = out.transpose(3, 4).reshape(B, n_h, n_w, p * p, C)
        out = self.out_proj(out)

        # Undo the patch partition: (B, n_h, n_w, p, p, C) -> (B, C, Hp, Wp)
        out = out.reshape(B, n_h, n_w, p, p, C).permute(0, 5, 1, 3, 2, 4)
        out = out.reshape(B, C, Hp, Wp)

        # Drop the padding added above, if any.
        if pad_h or pad_w:
            out = out[:, :, :H, :W].contiguous()
        return out

# PT-MVSNet Backbone
class PT_FeatureExtractor(nn.Module):
    """FPN-style multi-scale feature extractor with POLA refinement.

    Bottom-up pathway:
        A stride-1 3x3 stem, then four stride-2 bottleneck residual stages
        (64 -> 256 -> 512 -> 1024 -> 2048 channels).

    Top-down pathway:
        The deepest stage is projected to 256 channels. It is upsampled
        (nearest) and added to 1x1 lateral projections of the two preceding
        stages.

    Refinement:
        Each fused level is smoothed by a 3x3 conv + GroupNorm + ReLU, then
        refined by one POLA attention module whose weights are shared by all
        three levels.

    Returns:
        dict with keys "f3", "f4", "f5". Each entry has 256 channels and comes
        from the 1/4, 1/8 and 1/16 resolution stages respectively; for example,
        a 512x640 input gives 128x160, 64x80 and 32x40.
    """
    def __init__(self):
        super(PT_FeatureExtractor, self).__init__()
        
        self.conv0 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(8, 64),
            nn.ReLU(inplace=True)
        )
        self.layer1 = ConvBlock(64, 256, stride=2)   
        self.layer2 = ConvBlock(256, 512, stride=2)  
        self.layer3 = ConvBlock(512, 1024, stride=2) 
        self.layer4 = ConvBlock(1024, 2048, stride=2)

        # 1x1 lateral projections to the common 256-channel pyramid width
        self.latlayer4 = nn.Conv2d(2048, 256, kernel_size=1)
        self.latlayer3 = nn.Conv2d(1024, 256, kernel_size=1)
        self.latlayer2 = nn.Conv2d(512, 256, kernel_size=1)

        # A single POLA module (one set of weights) is applied to every pyramid level.
        self.pola_enhancer = POLA_Attention(channels=256)

        # Per-level 3x3 smoothing applied after fusion
        self.out5 = nn.Sequential(nn.Conv2d(256, 256, 3, padding=1), nn.GroupNorm(8, 256), nn.ReLU(inplace=True))
        self.out4 = nn.Sequential(nn.Conv2d(256, 256, 3, padding=1), nn.GroupNorm(8, 256), nn.ReLU(inplace=True))
        self.out3 = nn.Sequential(nn.Conv2d(256, 256, 3, padding=1), nn.GroupNorm(8, 256), nn.ReLU(inplace=True))

    def forward(self, x):
        # 1. Bottom-up
        c1 = self.conv0(x)
        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        # 2. Top-down + lateral fusion.
        # Upsample to the lateral map's exact size instead of a fixed factor of 2.
        # Stride-2 convs round odd sizes up, so doubling the coarser map can
        # overshoot the finer one by a pixel. When the sizes are exact
        # multiples, the result is identical to scale_factor=2.
        m5 = self.latlayer4(c5)
        m4 = F.interpolate(m5, size=c4.shape[-2:], mode='nearest') + self.latlayer3(c4)
        m3 = F.interpolate(m4, size=c3.shape[-2:], mode='nearest') + self.latlayer2(c3)

        # 3. Smoothing + POLA refinement (shared attention weights across levels)
        f5 = self.pola_enhancer(self.out5(m5))
        f4 = self.pola_enhancer(self.out4(m4))
        f3 = self.pola_enhancer(self.out3(m3))

        return {"f3": f3, "f4": f4, "f5": f5}

if __name__ == "__main__":
    # Shape smoke test: build the extractor and print each pyramid level's shape.
    model = PT_FeatureExtractor()
    model.eval()
    test = torch.randn(1, 3, 512, 640)
    # Inference only; skip building an autograd graph through the deep stages.
    with torch.no_grad():
        features = model(test)
    for k, v in features.items():
        print(f"{k}: {v.shape}")

# Output:
# f3: torch.Size([1, 256, 128, 160])
# f4: torch.Size([1, 256, 64, 80])
# f5: torch.Size([1, 256, 32, 40])
