In [1]:
import os
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt

from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, random_split
from torch_pconv import PConv2d  # Partial convolution layer; ensure this is installed.

In [2]:
def conv_block(in_channels, out_channels):
    """Build a two-stage ``Conv2d -> BatchNorm2d -> ReLU`` stack.

    Both stages use 3x3 convolutions with ``padding=1``. The first stage maps
    ``in_channels`` to ``out_channels``; the second keeps ``out_channels``.

    Args:
        in_channels (int): Channels of the incoming feature map.
        out_channels (int): Channels produced by both convolutions.

    Returns:
        nn.Sequential: Six layers in the order conv, bn, relu, conv, bn, relu.
    """
    layers = []
    stage_in = in_channels
    for _ in range(2):
        layers.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
        layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        # Only the first stage changes the channel count.
        stage_in = out_channels
    return nn.Sequential(*layers)

# Encoder module that downsamples the input and returns skip connections.
class Encoder(nn.Module):
    """Stack of conv blocks, each followed by 2x2 max pooling.

    Level ``i`` produces ``base_channels * 2**i`` channels. The activation of
    every level is recorded before pooling so it can be used as a skip
    connection.

    Args:
        in_channels (int): Channels of the input tensor.
        base_channels (int): Channels produced by the first level.
        levels (int): Number of conv-block / pooling stages.
    """

    def __init__(self, in_channels, base_channels=8, levels=3):
        super().__init__()
        self.levels = levels
        widths = [base_channels * (2 ** depth) for depth in range(levels)]
        stages = []
        prev_width = in_channels
        for width in widths:
            stages.append(conv_block(prev_width, width))
            prev_width = width
        self.blocks = nn.ModuleList(stages)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Run every level and collect the pre-pooling activations.

        Returns:
            tuple: ``(skips, x)``. ``skips`` lists the pre-pooling output of
            each level, shallowest first; ``x`` is the pooled output of the
            last level.
        """
        skips = []
        feat = x
        for stage in self.blocks:
            feat = stage(feat)
            skips.append(feat)
            # Pooling also follows the last level, so the returned tensor is
            # half the spatial size of the deepest skip.
            feat = self.pool(feat)
        return skips, feat

# Decoder block that upsamples and fuses with the skip connection.
class DecoderBlock(nn.Module):
    """Upsample a feature map, concatenate a skip tensor, and refine with a conv block.

    Args:
        in_channels (int): Channels of the tensor that is upsampled.
        skip_channels (int): Channels of the skip tensor.
        out_channels (int): Channels produced by the conv block.
    """

    def __init__(self, in_channels, skip_channels, out_channels):
        super(DecoderBlock, self).__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = conv_block(in_channels + skip_channels, out_channels)

    def forward(self, x, skip):
        """Fuse ``x`` with ``skip`` along the channel axis.

        Args:
            x: Tensor to upsample; indexed as (batch, channels, H, W).
            skip: Tensor concatenated after upsampling; its height/width
                determine the output's height/width.

        Returns:
            Output of the conv block applied to ``cat([upsampled x, skip], dim=1)``.
        """
        up = self.upsample(x)
        # Doubling only lands on the skip's size when the skip's height and
        # width are even. For an odd-sized skip (e.g. one whose pooled version
        # was floored upstream), resample straight to the skip's size so the
        # concatenation below does not fail.
        if up.shape[-2:] != skip.shape[-2:]:
            up = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=True)
        fused = torch.cat([up, skip], dim=1)
        return self.conv(fused)

# Decoder that uses a list of DecoderBlocks.
class Decoder(nn.Module):
    """Chain of ``DecoderBlock`` stages followed by a 1x1 convolution to 3 channels.

    Stages consume the skip connections from deepest to shallowest. Each stage
    outputs half (integer division) of its skip's channel count.

    Args:
        bottleneck_channels (int): Channels of the tensor fed to the first stage.
        skip_channels (list[int]): Skip channel counts, shallowest first.
        levels (int): Number of decoder stages to build.
    """

    def __init__(self, bottleneck_channels, skip_channels, levels=3):
        super().__init__()
        self.levels = levels
        stages = []
        current = bottleneck_channels
        # depth=1 selects the deepest skip, depth=levels the shallowest.
        for depth in range(1, levels + 1):
            skip_width = skip_channels[-depth]
            reduced = skip_width // 2
            stages.append(DecoderBlock(current, skip_width, reduced))
            current = reduced
        self.blocks = nn.ModuleList(stages)
        self.final_conv = nn.Conv2d(current, 3, kernel_size=1)  # Output 3 channels

    def forward(self, x, skips):
        """Decode ``x`` using ``skips`` (shallowest first) and project to 3 channels."""
        for depth, stage in enumerate(self.blocks, start=1):
            x = stage(x, skips[-depth])
        return self.final_conv(x)

# Simple encoder-decoder network for inpainting.
class SimpleInpaintingNet(nn.Module):
    """Single-encoder / single-decoder inpainting network with sigmoid output.

    Args:
        base_channels (int): Channels of the first encoder level.
        levels (int): Number of encoder levels and decoder stages.
    """

    def __init__(self, base_channels=8, levels=3):
        super().__init__()
        # The only encoder takes the 3-channel masked image.
        self.encoder = Encoder(in_channels=3, base_channels=base_channels, levels=levels)

        # One encoder only, so skip widths are the encoder widths themselves.
        level_widths = [base_channels * (2 ** depth) for depth in range(levels)]
        # Pooling keeps the channel count, so the tensor handed to the decoder
        # has as many channels as the last encoder level.
        deepest_width = base_channels * (2 ** (levels - 1))

        self.decoder = Decoder(
            bottleneck_channels=deepest_width,
            skip_channels=level_widths,
            levels=levels,
        )

    def forward(self, masked_img):
        """Encode ``masked_img``, decode with the skips, and squash the result with a sigmoid."""
        skips, bottleneck = self.encoder(masked_img)
        return torch.sigmoid(self.decoder(bottleneck, skips))

In [3]:
# A simple convolutional block using built-in Sequential.
def conv_block(in_channels, out_channels):
    """Build a two-stage ``Conv2d -> BatchNorm2d -> ReLU`` stack.

    Both stages use 3x3 convolutions with ``padding=1``. The first stage maps
    ``in_channels`` to ``out_channels``; the second keeps ``out_channels``.

    Note:
        Re-defines the same-named function from an earlier cell; on Run All
        the last definition executed is the one in effect.

    Args:
        in_channels (int): Channels of the incoming feature map.
        out_channels (int): Channels produced by both convolutions.

    Returns:
        nn.Sequential: Six layers in the order conv, bn, relu, conv, bn, relu.
    """
    layers = []
    stage_in = in_channels
    for _ in range(2):
        layers.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
        layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        # Only the first stage changes the channel count.
        stage_in = out_channels
    return nn.Sequential(*layers)

# Encoder module that downsamples the input and returns skip connections.
class Encoder(nn.Module):
    """Stack of conv blocks, each followed by 2x2 max pooling.

    Level ``i`` produces ``base_channels * 2**i`` channels. The activation of
    every level is recorded before pooling so it can be used as a skip
    connection.

    Note:
        Re-defines the same-named class from an earlier cell; on Run All the
        last definition executed is the one in effect.

    Args:
        in_channels (int): Channels of the input tensor.
        base_channels (int): Channels produced by the first level.
        levels (int): Number of conv-block / pooling stages.
    """

    def __init__(self, in_channels, base_channels=8, levels=3):
        super().__init__()
        self.levels = levels
        widths = [base_channels * (2 ** depth) for depth in range(levels)]
        stages = []
        prev_width = in_channels
        for width in widths:
            stages.append(conv_block(prev_width, width))
            prev_width = width
        self.blocks = nn.ModuleList(stages)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Run every level and collect the pre-pooling activations.

        Returns:
            tuple: ``(skips, x)``. ``skips`` lists the pre-pooling output of
            each level, shallowest first; ``x`` is the pooled output of the
            last level.
        """
        skips = []
        feat = x
        for stage in self.blocks:
            feat = stage(feat)
            skips.append(feat)
            # Pooling also follows the last level, so the returned tensor is
            # half the spatial size of the deepest skip.
            feat = self.pool(feat)
        return skips, feat

# Decoder block that upsamples and fuses with the skip connection.
class DecoderBlock(nn.Module):
    """Upsample a feature map, concatenate a skip tensor, and refine with a conv block.

    Note:
        Re-defines the same-named class from an earlier cell; on Run All the
        last definition executed is the one in effect.

    Args:
        in_channels (int): Channels of the tensor that is upsampled.
        skip_channels (int): Channels of the skip tensor.
        out_channels (int): Channels produced by the conv block.
    """

    def __init__(self, in_channels, skip_channels, out_channels):
        super(DecoderBlock, self).__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = conv_block(in_channels + skip_channels, out_channels)

    def forward(self, x, skip):
        """Fuse ``x`` with ``skip`` along the channel axis.

        Args:
            x: Tensor to upsample; indexed as (batch, channels, H, W).
            skip: Tensor concatenated after upsampling; its height/width
                determine the output's height/width.

        Returns:
            Output of the conv block applied to ``cat([upsampled x, skip], dim=1)``.
        """
        up = self.upsample(x)
        # Doubling only lands on the skip's size when the skip's height and
        # width are even. For an odd-sized skip (e.g. one whose pooled version
        # was floored upstream), resample straight to the skip's size so the
        # concatenation below does not fail.
        if up.shape[-2:] != skip.shape[-2:]:
            up = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=True)
        fused = torch.cat([up, skip], dim=1)
        return self.conv(fused)

# Decoder that uses a list of DecoderBlocks.
class Decoder(nn.Module):
    """Chain of ``DecoderBlock`` stages followed by a 1x1 convolution to 3 channels.

    Stages consume the skip connections from deepest to shallowest. Each stage
    outputs half (integer division) of its skip's channel count.

    Note:
        Re-defines the same-named class from an earlier cell; on Run All the
        last definition executed is the one in effect.

    Args:
        bottleneck_channels (int): Channels of the tensor fed to the first stage.
        skip_channels (list[int]): Skip channel counts, shallowest first.
        levels (int): Number of decoder stages to build.
    """

    def __init__(self, bottleneck_channels, skip_channels, levels=3):
        super().__init__()
        self.levels = levels
        stages = []
        current = bottleneck_channels
        # depth=1 selects the deepest skip, depth=levels the shallowest.
        for depth in range(1, levels + 1):
            skip_width = skip_channels[-depth]
            reduced = skip_width // 2
            stages.append(DecoderBlock(current, skip_width, reduced))
            current = reduced
        self.blocks = nn.ModuleList(stages)
        self.final_conv = nn.Conv2d(current, 3, kernel_size=1)

    def forward(self, x, skips):
        """Decode ``x`` using ``skips`` (shallowest first) and project to 3 channels."""
        for depth, stage in enumerate(self.blocks, start=1):
            x = stage(x, skips[-depth])
        return self.final_conv(x)
    

class SimpleAttentionFusion(nn.Module):
    """Gate RGB features with a learned per-channel, per-pixel mask and add depth features.

    The mask is computed from the channel-wise concatenation of both inputs by
    two 1x1 convolutions (squeeze, then expand) followed by a sigmoid.
    """

    def __init__(self, in_channels, reduction=4):
        """
        Args:
            in_channels (int): Number of channels from each branch (assumed equal).
            reduction (int): Factor to reduce the number of channels for the attention computation.
        """
        super(SimpleAttentionFusion, self).__init__()
        # Clamp the squeezed width to at least one channel. Without this,
        # in_channels < reduction yields a zero-width hidden layer, and the
        # mask degenerates to sigmoid(bias), independent of the inputs.
        hidden_channels = max(1, in_channels // reduction)
        # The input is the concatenation of RGB and depth features: 2 * in_channels.
        self.conv1   = nn.Conv2d(in_channels * 2, hidden_channels, kernel_size=1)
        self.relu    = nn.ReLU(inplace=True)
        self.conv2   = nn.Conv2d(hidden_channels, in_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, rgb_encoded, depth_encoded):
        """Return ``rgb_encoded * mask + depth_encoded``.

        Both inputs must have ``in_channels`` channels and the same spatial
        size, since they are concatenated on dim 1 and combined elementwise.
        """
        # Concatenate the RGB and depth features along the channel dimension.
        combined = torch.cat([rgb_encoded, depth_encoded], dim=1)
        # Squeeze to the hidden width, then expand back to in_channels.
        attn = self.conv2(self.relu(self.conv1(combined)))
        # Convert to an attention mask with values in the range [0, 1].
        attn = self.sigmoid(attn)
        # Modulate the RGB features with the mask and fuse with depth via addition.
        fused = rgb_encoded * attn + depth_encoded
        return fused


class DepthEnhancedInpaintingNet_SimpleAttention(nn.Module):
    """Dual-encoder (RGB + depth) inpainting network with gated fusion of the deepest features.

    The deepest features of both branches are merged by ``SimpleAttentionFusion``;
    skip connections are merged by channel concatenation.

    Args:
        base_channels (int): Channels of the first level in each encoder.
        levels (int): Number of encoder levels and decoder stages.
        reduction (int): Channel-reduction factor passed to the fusion module.
    """

    def __init__(self, base_channels=8, levels=3, reduction=4):
        super().__init__()
        # Separate encoders: RGB has 3 input channels, depth has 1.
        self.rgb_encoder = Encoder(in_channels=3, base_channels=base_channels, levels=levels)
        self.depth_encoder = Encoder(in_channels=1, base_channels=base_channels, levels=levels)

        # Width of the last level of each encoder. The fusion output keeps
        # this width because it adds, rather than concatenates, the branches.
        deepest_width = base_channels * (2 ** (levels - 1))
        self.fusion = SimpleAttentionFusion(in_channels=deepest_width, reduction=reduction)

        # Skips are concatenated across branches, so each width is doubled.
        doubled_widths = [base_channels * (2 ** depth) * 2 for depth in range(levels)]
        self.decoder = Decoder(
            bottleneck_channels=deepest_width,
            skip_channels=doubled_widths,
            levels=levels,
        )

    def forward(self, rgb, depth):
        """Encode both inputs, fuse them, and decode.

        NOTE(review): returns the decoder output directly. No sigmoid is
        applied here, unlike ``SimpleInpaintingNet.forward`` in this notebook;
        confirm that is intended.
        """
        rgb_skips, rgb_deep = self.rgb_encoder(rgb)
        depth_skips, depth_deep = self.depth_encoder(depth)

        fused_deep = self.fusion(rgb_deep, depth_deep)

        fused_skips = []
        for rgb_feat, depth_feat in zip(rgb_skips, depth_skips):
            fused_skips.append(torch.cat([rgb_feat, depth_feat], dim=1))

        return self.decoder(fused_deep, fused_skips)


In [4]:
# A simple convolutional block using built-in Sequential.
def conv_block(in_channels, out_channels):
    """Build a two-stage ``Conv2d -> BatchNorm2d -> ReLU`` stack.

    Both stages use 3x3 convolutions with ``padding=1``. The first stage maps
    ``in_channels`` to ``out_channels``; the second keeps ``out_channels``.

    Note:
        Re-defines the same-named function from earlier cells; on Run All
        the last definition executed is the one in effect.

    Args:
        in_channels (int): Channels of the incoming feature map.
        out_channels (int): Channels produced by both convolutions.

    Returns:
        nn.Sequential: Six layers in the order conv, bn, relu, conv, bn, relu.
    """
    layers = []
    stage_in = in_channels
    for _ in range(2):
        layers.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
        layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        # Only the first stage changes the channel count.
        stage_in = out_channels
    return nn.Sequential(*layers)

# Encoder module that downsamples the input and returns skip connections.
class Encoder(nn.Module):
    """Stack of conv blocks, each followed by 2x2 max pooling.

    Level ``i`` produces ``base_channels * 2**i`` channels. The activation of
    every level is recorded before pooling so it can be used as a skip
    connection.

    Note:
        Re-defines the same-named class from earlier cells; on Run All the
        last definition executed is the one in effect.

    Args:
        in_channels (int): Channels of the input tensor.
        base_channels (int): Channels produced by the first level.
        levels (int): Number of conv-block / pooling stages.
    """

    def __init__(self, in_channels, base_channels=8, levels=3):
        super().__init__()
        self.levels = levels
        widths = [base_channels * (2 ** depth) for depth in range(levels)]
        stages = []
        prev_width = in_channels
        for width in widths:
            stages.append(conv_block(prev_width, width))
            prev_width = width
        self.blocks = nn.ModuleList(stages)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Run every level and collect the pre-pooling activations.

        Returns:
            tuple: ``(skips, x)``. ``skips`` lists the pre-pooling output of
            each level, shallowest first; ``x`` is the pooled output of the
            last level.
        """
        skips = []
        feat = x
        for stage in self.blocks:
            feat = stage(feat)
            skips.append(feat)
            # Pooling also follows the last level, so the returned tensor is
            # half the spatial size of the deepest skip.
            feat = self.pool(feat)
        return skips, feat

# Decoder block that upsamples and fuses with the skip connection.
class DecoderBlock(nn.Module):
    """Upsample a feature map, concatenate a skip tensor, and refine with a conv block.

    Note:
        Re-defines the same-named class from earlier cells; on Run All the
        last definition executed is the one in effect.

    Args:
        in_channels (int): Channels of the tensor that is upsampled.
        skip_channels (int): Channels of the skip tensor.
        out_channels (int): Channels produced by the conv block.
    """

    def __init__(self, in_channels, skip_channels, out_channels):
        super(DecoderBlock, self).__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = conv_block(in_channels + skip_channels, out_channels)

    def forward(self, x, skip):
        """Fuse ``x`` with ``skip`` along the channel axis.

        Args:
            x: Tensor to upsample; indexed as (batch, channels, H, W).
            skip: Tensor concatenated after upsampling; its height/width
                determine the output's height/width.

        Returns:
            Output of the conv block applied to ``cat([upsampled x, skip], dim=1)``.
        """
        up = self.upsample(x)
        # Doubling only lands on the skip's size when the skip's height and
        # width are even. For an odd-sized skip (e.g. one whose pooled version
        # was floored upstream), resample straight to the skip's size so the
        # concatenation below does not fail.
        if up.shape[-2:] != skip.shape[-2:]:
            up = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=True)
        fused = torch.cat([up, skip], dim=1)
        return self.conv(fused)

# Decoder that uses a list of DecoderBlocks.
class Decoder(nn.Module):
    """Chain of ``DecoderBlock`` stages followed by a 1x1 convolution to 3 channels.

    Stages consume the skip connections from deepest to shallowest. Each stage
    outputs half (integer division) of its skip's channel count.

    Note:
        Re-defines the same-named class from earlier cells; on Run All the
        last definition executed is the one in effect.

    Args:
        bottleneck_channels (int): Channels of the tensor fed to the first stage.
        skip_channels (list[int]): Skip channel counts, shallowest first.
        levels (int): Number of decoder stages to build.
    """

    def __init__(self, bottleneck_channels, skip_channels, levels=3):
        super().__init__()
        self.levels = levels
        stages = []
        current = bottleneck_channels
        # depth=1 selects the deepest skip, depth=levels the shallowest.
        for depth in range(1, levels + 1):
            skip_width = skip_channels[-depth]
            reduced = skip_width // 2
            stages.append(DecoderBlock(current, skip_width, reduced))
            current = reduced
        self.blocks = nn.ModuleList(stages)
        self.final_conv = nn.Conv2d(current, 3, kernel_size=1)

    def forward(self, x, skips):
        """Decode ``x`` using ``skips`` (shallowest first) and project to 3 channels."""
        for depth, stage in enumerate(self.blocks, start=1):
            x = stage(x, skips[-depth])
        return self.final_conv(x)
    

# A self-attention module using PyTorch's built-in MultiheadAttention.
class MultiheadSelfAttention(nn.Module):
    """Self-attention over the spatial positions of a 4-D feature map.

    Each spatial location is treated as one token whose embedding is the
    channel vector at that location.

    Args:
        in_dim (int): Channel count of the input; used as ``embed_dim``.
            ``nn.MultiheadAttention`` requires it to be divisible by ``num_heads``.
        num_heads (int): Number of attention heads.
    """

    def __init__(self, in_dim, num_heads=4):
        super(MultiheadSelfAttention, self).__init__()
        # nn.MultiheadAttention expects input shape (B, L, C) if batch_first=True.
        self.mha = nn.MultiheadAttention(embed_dim=in_dim, num_heads=num_heads, batch_first=True)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, C, H, W) and return the same shape."""
        B, C, H, W = x.size()
        # (B, C, H, W) -> (B, H*W, C). flatten() copies only when needed, so
        # non-contiguous inputs (e.g. a cropped feature map) are accepted;
        # .view() would raise on those.
        tokens = x.flatten(2).transpose(1, 2)
        attn_output, _ = self.mha(tokens, tokens, tokens)
        # (B, H*W, C) -> (B, C, H, W)
        return attn_output.transpose(1, 2).reshape(B, C, H, W)


# Full model: Dual-branch encoder for RGB and depth, attention fusion, then decoder.
class DepthEnhancedInpaintingNet(nn.Module):
    """Dual-encoder (RGB + depth) inpainting network with multi-head self-attention.

    The deepest features of both branches are concatenated on the channel
    axis and passed through ``MultiheadSelfAttention``; skip connections are
    merged by channel concatenation.

    Args:
        base_channels (int): Channels of the first level in each encoder.
        levels (int): Number of encoder levels and decoder stages.
        num_heads (int): Number of heads for the attention module.
    """

    def __init__(self, base_channels=8, levels=3, num_heads=4):
        super().__init__()
        # Separate encoders: RGB has 3 input channels, depth has 1.
        self.rgb_encoder = Encoder(in_channels=3, base_channels=base_channels, levels=levels)
        self.depth_encoder = Encoder(in_channels=1, base_channels=base_channels, levels=levels)

        # Both branches end with the same width; concatenation doubles it.
        branch_width = base_channels * (2 ** (levels - 1))
        fused_width = branch_width + branch_width

        self.attention = MultiheadSelfAttention(in_dim=fused_width, num_heads=num_heads)

        # Skips are concatenated across branches, so each width is doubled.
        doubled_widths = [base_channels * (2 ** depth) * 2 for depth in range(levels)]
        self.decoder = Decoder(
            bottleneck_channels=fused_width,
            skip_channels=doubled_widths,
            levels=levels,
        )

    def forward(self, rgb, depth):
        """Encode both inputs, attend over the concatenated deepest features, and decode.

        NOTE(review): returns the decoder output directly. No sigmoid is
        applied here, unlike ``SimpleInpaintingNet.forward`` in this notebook;
        confirm that is intended.
        """
        rgb_skips, rgb_deep = self.rgb_encoder(rgb)
        depth_skips, depth_deep = self.depth_encoder(depth)

        stacked = torch.cat([rgb_deep, depth_deep], dim=1)
        attended = self.attention(stacked)

        merged_skips = []
        for rgb_feat, depth_feat in zip(rgb_skips, depth_skips):
            merged_skips.append(torch.cat([rgb_feat, depth_feat], dim=1))

        return self.decoder(attended, merged_skips)


In [5]:
def count_parameters(model: torch.nn.Module) -> int:
    """Count the scalar entries of every parameter of `model` that has `requires_grad` set.

    Frozen parameters (``requires_grad=False``) are excluded from the total.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Build one instance of each architecture with the same width/depth settings.
# Insertion order is preserved, so the report below lists them in this order.
models = dict(
    SimpleInpaintingNet=SimpleInpaintingNet(base_channels=8, levels=3),
    DepthSimpleAttention=DepthEnhancedInpaintingNet_SimpleAttention(base_channels=8, levels=3, reduction=4),
    DepthMultiheadAttention=DepthEnhancedInpaintingNet(base_channels=8, levels=3, num_heads=4),
)

# Print one aligned line per model with its trainable-parameter count.
for name, m in models.items():
    print(f"{name:25s}: {count_parameters(m):,} parameters")

SimpleInpaintingNet      : 33,711 parameters
DepthSimpleAttention     : 89,107 parameters
DepthMultiheadAttention  : 114,155 parameters
