In [1]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader
from torch_pconv import PConv2d  # Ensure you have torch_pconv installed

In [None]:
def pair_color_depth_files(root_dir):
    """
    Recursively scan ``root_dir`` and pair each color image with its depth image.

    File roles are decided by suffix, checked in this order:
      * ``<stem>_colors.png`` -> color
      * ``<stem>_depth.png``  -> depth
      * ``<stem>.jpg``        -> color
      * ``<stem>.png``        -> depth (any .png not matched above)

    A color file and a depth file are paired when they live in the same
    directory and share the same ``<stem>``.

    Args:
        root_dir: Directory to walk.

    Returns:
        A sorted list of ``(color_path, depth_path)`` tuples. Sorting makes the
        sample order independent of the filesystem's directory-listing order.
    """
    # (suffix, is_color) in priority order; the specific suffixes must come
    # before the generic ".png" rule.
    suffix_rules = (
        ("_colors.png", True),
        ("_depth.png", False),
        (".jpg", True),
        (".png", False),
    )

    color_dict = {}
    depth_dict = {}

    for directory, _, files in os.walk(root_dir):
        for f in files:
            for suffix, is_color in suffix_rules:
                if not f.endswith(suffix):
                    continue
                # Slice the suffix off the end only (str.replace would also
                # remove the same text if it appeared earlier in the name).
                stem = f[:-len(suffix)]
                # Key on the full directory, not just its basename, so that
                # equally named folders in different parents cannot collide.
                key = (directory, stem)
                bucket = color_dict if is_color else depth_dict
                bucket[key] = os.path.join(directory, f)
                break

    shared_keys = color_dict.keys() & depth_dict.keys()
    return sorted((color_dict[key], depth_dict[key]) for key in shared_keys)


# --------------------------------------------------
# Dataset: NYUDepthDataset with combined RGB+Depth
# --------------------------------------------------
class NYUDepthDataset(Dataset):
    """
    Dataset of paired color/depth images returned as combined 4-channel tensors.

    Each item is a dict with:
      * ``"combined_masked"``: (4, H, W) float tensor, RGB + depth with the
        masked pixels overwritten (RGB -> 1.0, depth -> 0.0).
      * ``"combined_target"``: (4, H, W) float tensor, the unmasked RGB + depth.
      * ``"mask"``: (1, H, W) float tensor, 1.0 = keep, 0.0 = masked.
    """

    def __init__(self, root_dir, img_size=(240, 320), transform=None, apply_mask=True,
                 depth_max=None):
        """
        Args:
            root_dir: Path to the folder containing color/depth pairs.
            img_size: Desired (height, width) for resizing.
            transform: Optional callable applied to the combined masked tensor
                and, separately, to the combined target tensor.
                NOTE(review): the mask is not transformed and the two calls are
                independent, so a random/geometric transform would desynchronise
                input, target and mask -- confirm only deterministic transforms
                are used.
            apply_mask: Whether to apply a random mask.
            depth_max: Raw depth value that should map to 1.0. ``None`` (default)
                uses the maximum of the image's integer dtype (255 for 8-bit,
                65535 for 16-bit) so depth always lands in [0, 1].
                NOTE(review): if 8-bit and 16-bit files encode different physical
                ranges, pass an explicit value -- confirm against the dataset docs.
        """
        super(NYUDepthDataset, self).__init__()
        self.root_dir = root_dir
        self.img_size = img_size  # e.g., (240, 320)
        self.transform = transform
        self.apply_mask = apply_mask
        self.depth_max = depth_max

        # Create a list of (color_path, depth_path) pairs.
        self.samples = pair_color_depth_files(root_dir)

    def __len__(self):
        """Number of (color, depth) pairs found under ``root_dir``."""
        return len(self.samples)

    @staticmethod
    def _read_image(path, flags):
        """Read an image with OpenCV, raising a clear error if it cannot be decoded."""
        image = cv2.imread(path, flags)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image file: {path}")
        return image

    @staticmethod
    def _normalize_depth(depth, depth_max=None):
        """
        Convert a raw depth array to float32 scaled by ``depth_max``.

        With ``depth_max=None`` the divisor is the maximum of the array's integer
        dtype; non-integer arrays fall back to 255.0.
        """
        if depth_max is None:
            if np.issubdtype(depth.dtype, np.integer):
                depth_max = np.iinfo(depth.dtype).max
            else:
                depth_max = 255.0
        return depth.astype(np.float32) / np.float32(depth_max)

    def __getitem__(self, idx):
        """Load, resize, optionally mask, and tensorise the ``idx``-th pair."""
        color_path, depth_path = self.samples[idx]
        height, width = self.img_size[0], self.img_size[1]

        # ----- Load color image -----
        img = self._read_image(color_path, cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # cv2.resize expects (width, height)
        img = cv2.resize(img, (width, height))
        img = img.astype(np.float32) / 255.0

        # ----- Load depth image -----
        # IMREAD_UNCHANGED keeps the file's native bit depth, so the divisor
        # must depend on the dtype instead of being a fixed 255.
        depth = self._read_image(depth_path, cv2.IMREAD_UNCHANGED)
        depth = cv2.resize(depth, (width, height))
        depth = self._normalize_depth(depth, self.depth_max)

        # ----- Apply mask if needed -----
        if self.apply_mask:
            mask_2d = self.create_random_mask(self.img_size)
            # For the color image, create a 3-channel mask.
            mask_3d = np.stack([mask_2d] * 3, axis=-1)

            masked_img = img.copy()
            masked_depth = depth.copy()

            # Set masked regions to white for RGB and 0 for depth.
            masked_img[mask_3d == 0] = 1.0
            masked_depth[mask_2d == 0] = 0.0
        else:
            # All-255 mask == nothing masked (becomes all ones after /255 below).
            mask_2d = np.ones((height, width), dtype=np.float32) * 255
            masked_img = img
            masked_depth = depth

        # ----- Convert to Torch Tensors -----
        # Color: (H, W, 3) -> (3, H, W)
        img_tensor = torch.from_numpy(img).permute(2, 0, 1)
        masked_img_tensor = torch.from_numpy(masked_img).permute(2, 0, 1)
        # Depth: (H, W) -> (1, H, W)
        depth_tensor = torch.from_numpy(depth).unsqueeze(0)
        masked_depth_tensor = torch.from_numpy(masked_depth).unsqueeze(0)
        # Mask: (H, W) -> (1, H, W), normalized to [0,1]
        mask_tensor = torch.from_numpy(mask_2d.astype(np.float32) / 255.0).unsqueeze(0)

        # ----- Combine channels: create 4-channel tensors -----
        # Combined masked input: (3+1, H, W)
        combined_masked = torch.cat([masked_img_tensor, masked_depth_tensor], dim=0)
        # Combined target: (3+1, H, W)
        combined_target = torch.cat([img_tensor, depth_tensor], dim=0)

        if self.transform is not None:
            combined_masked = self.transform(combined_masked)
            combined_target = self.transform(combined_target)

        return {
            "combined_masked": combined_masked,
            "combined_target": combined_target,
            "mask": mask_tensor
        }

    def create_random_mask(self, size):
        """
        Creates a random mask with white (255) for unmasked and black (0) for masked areas.

        Between 1 and 9 random straight lines of thickness 1 or 2 are drawn in
        black on a white (H, W) uint8 canvas.

        Args:
            size: (height, width) of the mask.

        Returns:
            np.ndarray of shape (H, W), dtype uint8, with values in {0, 255}.
        """
        H, W = size
        mask = np.full((H, W), 255, np.uint8)
        num_lines = np.random.randint(1, 10)
        for _ in range(num_lines):
            x1, x2 = np.random.randint(0, W, size=2)
            y1, y2 = np.random.randint(0, H, size=2)
            thickness = np.random.randint(1, 3)
            # Pass plain Python ints so point/thickness parsing does not depend
            # on how the installed OpenCV build handles NumPy scalar types.
            cv2.line(mask, (int(x1), int(y1)), (int(x2), int(y2)), 0, int(thickness))
        return mask

# --------------------------------------------------
# Model: U-Net like Inpainting Model with Partial Convolutions
# --------------------------------------------------
class EncoderBlock(nn.Module):
    """
    Encoder stage made of two partial convolutions, each followed by ReLU.

    The first convolution uses stride 1 (its output is the skip connection),
    the second uses stride 2 (its output feeds the next stage).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        """
        Args:
            in_channels: Channels of the incoming feature map.
            out_channels: Channels produced by both convolutions.
            kernel_size: Kernel size shared by both convolutions.
            padding: Padding shared by both convolutions.
        """
        super().__init__()
        shared_kwargs = dict(kernel_size=kernel_size, padding=padding, bias=True)
        self.pconv1 = PConv2d(in_channels, out_channels, stride=1, **shared_kwargs)
        self.pconv2 = PConv2d(out_channels, out_channels, stride=2, **shared_kwargs)

    def forward(self, x, mask):
        """
        Args:
            x: Input feature map.
            mask: Validity mask; a (B, 1, H, W) mask is squeezed to (B, H, W).

        Returns:
            (skip_features, skip_mask, down_features, down_mask): outputs of the
            stride-1 and stride-2 convolutions with their updated masks.
        """
        has_singleton_channel = mask.dim() == 4 and mask.size(1) == 1
        mask_in = mask.squeeze(1) if has_singleton_channel else mask

        skip_features, skip_mask = self.pconv1(x, mask_in)
        skip_features = F.relu(skip_features)

        down_features, down_mask = self.pconv2(skip_features, skip_mask)
        down_features = F.relu(down_features)

        return skip_features, skip_mask, down_features, down_mask

class DecoderBlock(nn.Module):
    """
    Decoder stage: upsample, merge with the encoder skip, then two partial
    convolutions (stride 1), each followed by ReLU.
    """

    def __init__(self, in_channels, skip_channels, out_channels1, out_channels2, kernel_size=3, padding=1):
        """
        Args:
            in_channels: Channels of the low-resolution input ``x``.
            skip_channels: Channels of the skip feature map ``skip_x``.
            out_channels1: Channels produced by the first convolution.
            out_channels2: Channels produced by the second convolution.
            kernel_size: Kernel size shared by both convolutions.
            padding: Padding shared by both convolutions.
        """
        super(DecoderBlock, self).__init__()
        self.pconv1 = PConv2d(skip_channels + in_channels, out_channels1, kernel_size=kernel_size, 
                              stride=1, padding=padding, bias=True)
        self.pconv2 = PConv2d(out_channels1, out_channels2, kernel_size=kernel_size, 
                              stride=1, padding=padding, bias=True)
    
    def forward(self, x, mask, skip_x, skip_mask):
        """
        Args:
            x: Low-resolution feature map from the previous stage.
            mask: Mask belonging to ``x``; (B, 1, h, w) is squeezed to (B, h, w).
            skip_x: Encoder feature map to concatenate with the upsampled ``x``.
            skip_mask: Mask belonging to ``skip_x``; (B, 1, H, W) is squeezed.

        Returns:
            (x1, mask1, x2, mask2): outputs and updated masks of the first and
            second convolution.
        """
        if mask.dim() == 4 and mask.size(1) == 1:
            mask = mask.squeeze(1)
        if skip_mask.dim() == 4 and skip_mask.size(1) == 1:
            skip_mask = skip_mask.squeeze(1)

        # Upsample to the skip tensor's exact spatial size rather than a fixed
        # x2 factor: identical when the skip is exactly twice as large, and
        # still correct when the encoder's stride-2 stage saw an odd size.
        target_hw = tuple(skip_x.shape[-2:])
        x_up = F.interpolate(x, size=target_hw, mode='nearest')
        mask_up = F.interpolate(mask.unsqueeze(1), size=target_hw, mode='nearest').squeeze(1)

        # Combine masks using maximum (logical OR)
        mask_cat = torch.max(skip_mask, mask_up)
        x_cat = torch.cat([skip_x, x_up], dim=1)
        x1, mask1 = self.pconv1(x_cat, mask_cat)
        x1 = F.relu(x1)
        x2, mask2 = self.pconv2(x1, mask1)
        x2 = F.relu(x2)
        return x1, mask1, x2, mask2

class InpaintingModel(nn.Module):
    """
    U-Net-style inpainting network: four encoder stages, four decoder stages
    with skip connections, a final plain convolution and a sigmoid.
    """

    def __init__(self, input_channels=4, output_channels=4):
        """
        Args:
            input_channels: Channel count of the combined input (3 RGB + 1 depth by default).
            output_channels: Channel count of the combined output.
        """
        super().__init__()
        w1, w2, w3, w4 = 32, 64, 128, 256

        # Encoder path: channel width doubles at every stage.
        self.enc1 = EncoderBlock(input_channels, w1)
        self.enc2 = EncoderBlock(w1, w2)
        self.enc3 = EncoderBlock(w2, w3)
        self.enc4 = EncoderBlock(w3, w4)

        # Decoder path: arguments are (in, skip, out_first_conv, out_second_conv).
        self.dec1 = DecoderBlock(w4, w4, w4, w3)
        self.dec2 = DecoderBlock(w3, w3, w3, w2)
        self.dec3 = DecoderBlock(w2, w2, w2, w1)
        self.dec4 = DecoderBlock(w1, w1, w1, output_channels)

        self.final_conv = nn.Conv2d(output_channels, output_channels, kernel_size=3, padding=1)

    def forward(self, x, mask):
        """
        Args:
            x: Combined masked input batch.
            mask: Validity mask; a (B, 1, H, W) mask is squeezed to (B, H, W).

        Returns:
            Sigmoid-activated prediction tensor.
        """
        if mask.dim() == 4 and mask.size(1) == 1:
            mask = mask.squeeze(1)

        features, features_mask = x, mask
        skip_stack = []

        # Encoder: remember each stage's stride-1 output as a skip connection.
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            skip_features, skip_mask, features, features_mask = encoder(features, features_mask)
            skip_stack.append((skip_features, skip_mask))

        # Decoder: consume the skips deepest-first.
        for decoder in (self.dec1, self.dec2, self.dec3, self.dec4):
            skip_features, skip_mask = skip_stack.pop()
            _, _, features, features_mask = decoder(features, features_mask, skip_features, skip_mask)

        logits = self.final_conv(features)
        return torch.sigmoid(logits)

# --------------------------------------------------
# DataLoader Setup
# --------------------------------------------------
# Replace these paths with your actual directories.
train_data_path = "nyu_data/data/nyu2_train"
test_data_path = "nyu_data/data/nyu2_test"

# Seed the RNGs used for random masks (NumPy), weight init and shuffling
# (PyTorch) so that a Restart & Run All reproduces the same run.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Create dataset instances.
# NOTE(review): the test dataset is built with apply_mask=False, so validation
# feeds unmasked images to the model -- confirm this is the intended metric.
train_dataset = NYUDepthDataset(train_data_path, img_size=(240,320), apply_mask=True)
test_dataset = NYUDepthDataset(test_data_path, img_size=(240,320), apply_mask=False)

# Fail early with a readable message instead of a cryptic sampler / division
# error when a path is wrong and no color/depth pairs were found.
for _split_name, _split_dataset, _split_path in (
    ("train", train_dataset, train_data_path),
    ("test", test_dataset, test_data_path),
):
    if len(_split_dataset) == 0:
        raise ValueError(
            f"No color/depth pairs found for the {_split_name} split under '{_split_path}'."
        )

# Create DataLoaders.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# --------------------------------------------------
# Training Setup
# --------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

model = InpaintingModel(input_channels=4, output_channels=4).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.L1Loss()

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """
    Run one optimisation pass over ``dataloader`` and return the mean loss.

    Args:
        model: Module called as ``model(inputs, mask)``.
        dataloader: Iterable of batches; each batch is a dict with the keys
            ``"combined_masked"``, ``"mask"`` and ``"combined_target"``.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Loss called as ``criterion(outputs, targets)``.
        device: Device the batch tensors are moved to.

    Returns:
        Per-sample mean loss over the samples actually processed
        (``nan`` if the dataloader yields no batches).
    """
    model.train()
    running_loss = 0.0
    samples_seen = 0
    for batch in dataloader:
        inputs = batch["combined_masked"].to(device)
        mask = batch["mask"].to(device)
        targets = batch["combined_target"].to(device)

        optimizer.zero_grad()
        outputs = model(inputs, mask)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        # Weight by batch size so a smaller final batch is averaged correctly.
        running_loss += loss.item() * batch_size
        samples_seen += batch_size

    # Divide by what was actually iterated, not len(dataloader.dataset): the
    # two differ with drop_last=True or a subset sampler.
    if samples_seen == 0:
        return float("nan")
    return running_loss / samples_seen

def validate(model, dataloader, criterion, device):
    """
    Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        model: Module called as ``model(inputs, mask)``.
        dataloader: Iterable of batches; each batch is a dict with the keys
            ``"combined_masked"``, ``"mask"`` and ``"combined_target"``.
        criterion: Loss called as ``criterion(outputs, targets)``.
        device: Device the batch tensors are moved to.

    Returns:
        Per-sample mean loss over the samples actually processed
        (``nan`` if the dataloader yields no batches).
    """
    model.eval()
    running_loss = 0.0
    samples_seen = 0
    with torch.no_grad():
        for batch in dataloader:
            inputs = batch["combined_masked"].to(device)
            mask = batch["mask"].to(device)
            targets = batch["combined_target"].to(device)

            outputs = model(inputs, mask)
            loss = criterion(outputs, targets)

            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            samples_seen += batch_size

    # Divide by what was actually iterated, not len(dataloader.dataset): the
    # two differ with drop_last=True or a subset sampler.
    if samples_seen == 0:
        return float("nan")
    return running_loss / samples_seen

num_epochs = 20
for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss = validate(model, test_loader, criterion, device)
    print(f"Epoch {epoch+1}/{num_epochs} Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

# Persist the trained weights, then reload them into a fresh CPU model so the
# inference/plotting below does not depend on the training device.
torch.save(model.state_dict(), "nyu_inpainting_depth_model.pth")
device = "cpu"
model = InpaintingModel(input_channels=4).to(device)
model.load_state_dict(torch.load("nyu_inpainting_depth_model.pth", map_location=device))

# --------------------------------------------------
# Inference & Visualization
# --------------------------------------------------
# Never request more samples than the test set holds (indexing past the end
# of the dataset would raise IndexError).
num_samples = min(32, len(test_dataset))
if num_samples == 0:
    raise ValueError("test_dataset is empty; there is nothing to visualize.")

samples = [test_dataset[i] for i in range(num_samples)]
combined_masked = torch.stack([s["combined_masked"] for s in samples], dim=0)
mask_tensor = torch.stack([s["mask"] for s in samples], dim=0)
combined_target = torch.stack([s["combined_target"] for s in samples], dim=0)

model.eval()
predictions = []
with torch.no_grad():
    for i in range(num_samples):
        inp = combined_masked[i].unsqueeze(0).to(device)
        msk = mask_tensor[i].unsqueeze(0).to(device)
        pred = model(inp, msk)
        predictions.append(pred.squeeze(0).cpu())

# Visualization: show masked input, predicted output, and ground truth for RGB channels.
# NOTE(review): if test_dataset was created with apply_mask=False, the
# "Masked Input" column shows unmasked images -- confirm that is intended.
# squeeze=False keeps `axs` two-dimensional even when there is a single row,
# so axs[i, j] indexing always works.
fig, axs = plt.subplots(nrows=num_samples, ncols=3, figsize=(10, num_samples * 2), squeeze=False)
for i in range(num_samples):
    # Extract the first 3 channels (RGB) and move channels last for imshow.
    panels = (
        ("Masked Input (RGB)", combined_masked[i]),
        ("Predicted Output (RGB)", predictions[i]),
        ("Ground Truth (RGB)", combined_target[i]),
    )
    for col, (title, tensor) in enumerate(panels):
        axs[i, col].imshow(tensor[:3].permute(1, 2, 0).numpy())
        axs[i, col].set_title(title)
        axs[i, col].axis("off")

plt.tight_layout()
plt.show()


Using device: cpu


  from .autonotebook import tqdm as notebook_tqdm


KeyboardInterrupt: 