In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
from datasets import load_dataset
from torch.utils.data import random_split
from torchvision.transforms import v2
from collections import defaultdict
import pandas as pd
import matplotlib.pyplot as plt
import os
import random

In [None]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
import numpy as np

import pytorch_lightning as pl
from torch.utils.data import DataLoader


class SegmentationDataset(Dataset):
    """Pairs ``*.jpg`` images with same-stem ``*.png`` masks for binary segmentation.

    Args:
        images_dir: Directory holding the input images; only files whose name
            ends in ``.jpg`` (case-insensitive) are used.
        labels_dir: Directory holding one ``<stem>.png`` mask per image.

    Each item is ``(image, label)``:
        image: float32 tensor of shape (3, H, W) scaled to [0, 1].
        label: int64 tensor of shape (H, W); 1 where the grayscale mask is
            brighter than half intensity, 0 elsewhere.
    """

    def __init__(self, images_dir, labels_dir):
        self.images_dir = images_dir
        self.labels_dir = labels_dir

        # Keep only JPEG inputs; sorting gives a deterministic index order.
        jpgs = [name for name in os.listdir(images_dir) if name.lower().endswith(".jpg")]
        jpgs.sort()
        self.image_files = jpgs

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        file_name = self.image_files[idx]
        stem, _ = os.path.splitext(file_name)

        image_path = os.path.join(self.images_dir, file_name)
        mask_path = os.path.join(self.labels_dir, stem + ".png")

        rgb = np.array(Image.open(image_path).convert("RGB"), dtype=np.float32) / 255.0
        gray = np.array(Image.open(mask_path).convert("L"), dtype=np.float32) / 255.0

        # HWC (PIL/numpy layout) -> CHW (PyTorch layout).
        image_tensor = torch.from_numpy(rgb).permute(2, 0, 1)
        # Binarize: pixels brighter than half intensity become class 1.
        label_tensor = (torch.from_numpy(gray) > 0.5).long()

        return image_tensor, label_tensor

class SegmentationDataModule(pl.LightningDataModule):
    """LightningDataModule serving the train and test splits of the segmentation data.

    Expects ``data_dir`` to contain ``Training Set/Training_Images``,
    ``Training Set/Training_Labels``, ``Test Set/Test_Images`` and
    ``Test Set/Test_Labels``. The training loader shuffles; the test loader
    does not.

    Args:
        data_dir: Root directory of the dataset.
        batch_size: Samples per batch.
        num_workers: DataLoader worker processes.
    """

    def __init__(self, data_dir, batch_size=4, num_workers=4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        train_root = os.path.join(self.data_dir, "Training Set")
        test_root = os.path.join(self.data_dir, "Test Set")

        self.train_dataset = SegmentationDataset(
            images_dir=os.path.join(train_root, "Training_Images"),
            labels_dir=os.path.join(train_root, "Training_Labels"),
        )
        self.test_dataset = SegmentationDataset(
            images_dir=os.path.join(test_root, "Test_Images"),
            labels_dir=os.path.join(test_root, "Test_Labels"),
        )

    def _loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using this module's batch/worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        return self._loader(self.train_dataset, shuffle=True)

    def test_dataloader(self):
        return self._loader(self.test_dataset, shuffle=False)


In [None]:
import matplotlib.pyplot as plt

def visualize_sample(dataset, idx=0):
    """Show one ``(image, mask)`` pair from ``dataset`` side by side.

    Args:
        dataset: Indexable returning ``(image, mask)`` where ``image`` is a
            (C, H, W) tensor and ``mask`` is a 2-D tensor.
        idx: Index of the sample to display.
    """
    image_chw, mask_tensor = dataset[idx]

    # (title, pixels, extra imshow kwargs) for each panel, left to right.
    panels = (
        ("Image", image_chw.permute(1, 2, 0).numpy(), {}),
        ("Binary Mask", mask_tensor.numpy(), {"cmap": "gray"}),
    )

    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    for ax, (title, pixels, imshow_kwargs) in zip(axes, panels):
        ax.imshow(pixels, **imshow_kwargs)
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.show()


In [None]:
dm = SegmentationDataModule(data_dir="dataset", batch_size=2)
dm.setup()

# Sanity-check the data: show the first sample of each split.
for split_dataset in (dm.train_dataset, dm.test_dataset):
    visualize_sample(split_dataset, idx=0)


In [None]:
class LightningModule(pl.LightningModule):
    """U-Net (ResNet-50 encoder, ImageNet weights) predicting an RGB image from a masked input.

    Batches are dicts with ``'masked_image'`` (model input), ``'original_image'``
    (target) and, depending on ``use_weighted_loss``, ``'weights'`` or ``'mask'``.

    Args:
        learning_rate: Adam learning rate.
        use_weighted_loss: If True, the per-element L1 loss is averaged using
            ``batch['weights']`` as weights; otherwise it is averaged over the
            elements where ``batch['mask']`` is non-zero.
    """

    def __init__(self, learning_rate=1e-3, use_weighted_loss=True):
        super().__init__()
        self.save_hyperparameters()

        # 3 output channels: the network predicts the RGB image directly.
        self.model = smp.Unet(
            encoder_name='resnet50',
            encoder_weights='imagenet',
            in_channels=3,
            classes=3
        )
        # Running sums used to print per-epoch averages; cleared at epoch end.
        self.training_step_outputs = defaultdict(float)
        self.validation_step_outputs = defaultdict(float)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return {"optimizer": optimizer}

    def _compute_loss(self, batch):
        """Weighted (or masked) mean L1 loss shared by training and validation."""
        y_hat = self(batch['masked_image'])
        y = batch['original_image']
        loss_per_pixel = F.l1_loss(y_hat, y, reduction='none')  # same shape as y_hat

        if self.hparams.use_weighted_loss:
            # Broadcast the weights across the color channels.
            weights = batch['weights'].expand_as(loss_per_pixel)
        else:
            # Expand the mask as well so the denominator counts the same
            # elements the numerator sums over. Previously the numerator summed
            # over every channel while mask.sum() counted a single channel,
            # inflating this loss by the channel count.
            # NOTE(review): assumes mask is (B, 1, H, W) or (B, C, H, W).
            weights = batch['mask'].expand_as(loss_per_pixel)

        return (loss_per_pixel * weights).sum() / (weights.sum() + 1e-8)

    @staticmethod
    def _accumulate(store, loss):
        """Add one step's detached loss to a running-sum dict."""
        store['loss'] += loss.detach().cpu()
        store['steps'] += 1

    def _report_epoch(self, store, label):
        """Print the mean loss stored in ``store`` (if any steps ran) and reset it."""
        steps = store.get('steps', 0)
        if steps:
            avg_loss = store['loss'] / steps
            print(f"Average {label} loss for epoch {self.current_epoch}: {avg_loss.item():.4f}")
        store.clear()

    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self._accumulate(self.training_step_outputs, loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self._accumulate(self.validation_step_outputs, loss)
        return loss

    def on_train_epoch_end(self):
        self._report_epoch(self.training_step_outputs, "training")

    def on_validation_epoch_end(self):
        self._report_epoch(self.validation_step_outputs, "validation")

In [None]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from torchvision import transforms

class InpaintingInferenceDataset(Dataset):
    """Inference dataset for inpainting that uses segmentation masks as the holes.

    Args:
        images_dir: Directory of input images; only names ending in ``.jpg``
            (case-insensitive) are used.
        labels_dir: Directory holding one ``<stem>.png`` mask per image.
        mean: Per-channel mean (indexable, 3 entries) for normalizing the input.
        std: Per-channel standard deviation (indexable, 3 entries).

    Each item is a dict with:
        'original_image': float32 (3, H, W) image in [0, 1], not normalized.
        'masked_image': the image with mask pixels zeroed, then normalized
            channel-wise with ``(x - mean) / std``; this is the model input.
        'mask': float32 (1, H, W), 1.0 where the mask is brighter than half
            intensity (pixels to inpaint), 0.0 elsewhere.
        'filename': the image's file name.
    """

    def __init__(self, images_dir, labels_dir, mean, std):
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.mean = mean
        self.std = std

        # Deterministic order of JPEG inputs.
        self.image_files = sorted(
            name for name in os.listdir(images_dir) if name.lower().endswith(".jpg")
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        file_name = self.image_files[idx]
        stem = os.path.splitext(file_name)[0]

        rgb = Image.open(os.path.join(self.images_dir, file_name)).convert("RGB")
        gray = Image.open(os.path.join(self.labels_dir, stem + ".png")).convert("L")

        # HWC uint8 -> CHW float in [0, 1].
        image = torch.from_numpy(np.array(rgb, dtype=np.float32) / 255.0).permute(2, 0, 1)
        # (1, H, W) binary hole mask: 1.0 = inpaint, 0.0 = keep.
        gray_tensor = torch.from_numpy(np.array(gray, dtype=np.float32) / 255.0).unsqueeze(0)
        hole = (gray_tensor > 0.5).float()

        # Zero the holes, then normalize each channel for the model.
        masked = image * (1 - hole)
        normalized = torch.stack(
            [(masked[c] - self.mean[c]) / self.std[c] for c in range(3)]
        )

        return {
            'original_image': image,
            'masked_image': normalized,
            'mask': hole,
            'filename': file_name
        }

class InpaintingInferenceDataModule(pl.LightningDataModule):
    """LightningDataModule exposing the segmentation splits for inpainting inference.

    Expects ``data_dir`` to contain ``Test Set/Test_Images``,
    ``Test Set/Test_Labels``, ``Training Set/Training_Images`` and
    ``Training Set/Training_Labels``. Neither dataloader shuffles, so
    inference runs visit samples in a fixed order.

    Args:
        data_dir: Root directory of the dataset.
        batch_size: Samples per batch.
        num_workers: DataLoader worker processes.
        mean: Optional 3-entry per-channel mean for input normalization;
            defaults to ``DEFAULT_MEAN``.
        std: Optional 3-entry per-channel std; defaults to ``DEFAULT_STD``.

    Raises:
        ValueError: If ``mean`` or ``std`` does not have exactly 3 entries.
    """

    # Normalization parameters from training (previously hard-coded in __init__).
    DEFAULT_MEAN = (0.34437724, 0.38029198, 0.40777111)
    DEFAULT_STD = (0.20265734, 0.13689059, 0.11554374)

    def __init__(self, data_dir, batch_size=4, num_workers=4, mean=None, std=None):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Fresh per-instance lists: callers may supply their own statistics,
        # and the class-level defaults can never be mutated through an instance.
        self.mean = list(self.DEFAULT_MEAN if mean is None else mean)
        self.std = list(self.DEFAULT_STD if std is None else std)
        # InpaintingInferenceDataset normalizes exactly 3 (RGB) channels.
        if len(self.mean) != 3 or len(self.std) != 3:
            raise ValueError(
                f"mean and std must have 3 entries, got {len(self.mean)} and {len(self.std)}"
            )

    def _make_dataset(self, split_dir, images_subdir, labels_subdir):
        """Build an InpaintingInferenceDataset for one split directory."""
        return InpaintingInferenceDataset(
            images_dir=os.path.join(self.data_dir, split_dir, images_subdir),
            labels_dir=os.path.join(self.data_dir, split_dir, labels_subdir),
            mean=self.mean,
            std=self.std
        )

    def _make_loader(self, dataset):
        """Non-shuffling DataLoader using this module's batch/worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

    def setup(self, stage=None):
        self.test_dataset = self._make_dataset("Test Set", "Test_Images", "Test_Labels")
        self.train_dataset = self._make_dataset(
            "Training Set", "Training_Images", "Training_Labels"
        )

    def test_dataloader(self):
        return self._make_loader(self.test_dataset)

    def train_dataloader(self):
        return self._make_loader(self.train_dataset)

def denormalize(img, mean, std):
    """Undo per-channel normalization and clamp to [0, 1] for visualization.

    Computes ``img * std + mean`` channel-wise, the inverse of
    ``(x - mean) / std``. Works for any number of channels and for both
    single images and batches.

    Args:
        img: Floating-point tensor of shape (C, H, W) or (B, C, H, W), with
            ``C == len(mean) == len(std)``.
        mean: Per-channel means.
        std: Per-channel standard deviations.

    Returns:
        A new tensor with the same shape as ``img``, clamped to [0, 1]. The
        input tensor is not modified.

    Raises:
        ValueError: If ``img`` has fewer than 3 dimensions or its channel
            dimension does not match ``len(mean)``/``len(std)``.
    """
    # Shape (C, 1, 1) so the statistics broadcast over H, W (and any batch dim).
    mean_t = torch.as_tensor(mean, dtype=img.dtype, device=img.device).reshape(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=img.dtype, device=img.device).reshape(-1, 1, 1)
    if img.dim() < 3 or img.shape[-3] != mean_t.shape[0] or std_t.shape[0] != mean_t.shape[0]:
        raise ValueError(
            f"expected img of shape (C, H, W) or (B, C, H, W) with C == {mean_t.shape[0]}, "
            f"got {tuple(img.shape)} (len(std) = {std_t.shape[0]})"
        )
    # The arithmetic allocates a new tensor, so no explicit clone is needed.
    return (img * std_t + mean_t).clamp(0, 1)

def _collect_inpainting_samples(model, dataloader, device, num_images):
    """Run ``model`` on successive batches until ``num_images`` samples are gathered.

    Returns four equal-length lists of per-sample CPU items: original images,
    raw model predictions, masks and filenames (shorter than ``num_images``
    if the dataloader runs out first).
    """
    originals, predictions, masks, filenames = [], [], [], []
    with torch.no_grad():
        for batch in dataloader:
            take = num_images - len(filenames)
            preds = model(batch['masked_image'][:take].to(device)).cpu()
            # extend() over a tensor appends its dim-0 slices, i.e. one entry per sample.
            originals.extend(batch['original_image'][:take])
            predictions.extend(preds)
            masks.extend(batch['mask'][:take])
            filenames.extend(batch['filename'][:take])
            if len(filenames) >= num_images:
                break
    return originals, predictions, masks, filenames


def _save_single_image(image_hwc, path):
    """Save one (H, W, 3) image as a borderless figure and close that figure."""
    single_fig, single_ax = plt.subplots(figsize=(6, 6))
    try:
        single_ax.imshow(image_hwc)
        single_ax.axis('off')
        single_fig.tight_layout()
        single_fig.savefig(path, dpi=150, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(single_fig)


def visualize_inpainting_segmentation(model, dataloader, device='cuda', num_images=4,
                                      save_dir='inpainting_results', mean=None, std=None):
    """
    Visualize inpainting results on the segmentation dataset.

    Each row of the grid shows the original image, the segmentation mask, the
    masked input and the composed result (original pixels outside the mask,
    denormalized prediction inside it), titled with the mean absolute error
    of the prediction inside the mask. Each composed result is saved as
    ``inpainted_<filename>`` and the grid as ``comparison_grid.png``.

    Args:
        model: Trained inpainting model, called on the normalized ``'masked_image'``.
        dataloader: DataLoader yielding dicts with ``'original_image'``,
            ``'masked_image'``, ``'mask'`` and ``'filename'``.
        device: 'cuda' or 'cpu'.
        num_images: Maximum number of images to visualize. Samples are drawn
            from as many batches as needed, so this may exceed the batch size.
        save_dir: Directory to save individual results and the grid.
        mean: Per-channel mean used to denormalize predictions
            (defaults to the training statistics below).
        std: Per-channel std used to denormalize predictions.

    Raises:
        ValueError: If ``num_images < 1`` or the dataloader yields no samples.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")
    if mean is None:
        mean = [0.34437724, 0.38029198, 0.40777111]
    if std is None:
        std = [0.20265734, 0.13689059, 0.11554374]

    model.eval()
    model.to(device)
    os.makedirs(save_dir, exist_ok=True)

    originals, predictions, masks, filenames = _collect_inpainting_samples(
        model, dataloader, device, num_images
    )
    if not filenames:
        raise ValueError("dataloader yielded no samples")
    num_shown = len(filenames)

    # squeeze=False keeps axes 2-D even when only one row is drawn.
    fig, axes = plt.subplots(num_shown, 4, figsize=(16, 4 * num_shown), squeeze=False)

    for i, (original, pred_raw, mask, filename) in enumerate(
            zip(originals, predictions, masks, filenames)):
        # NOTE(review): predictions are denormalized with the input statistics,
        # but the training loss compares the output against 'original_image',
        # which in InpaintingInferenceDataset is unnormalized [0, 1] data.
        # Confirm which space the model was trained to output.
        predicted = denormalize(pred_raw, mean, std)

        # Keep known pixels; use the prediction only inside the mask.
        composed = original * (1 - mask) + predicted * mask
        masked_input = original * (1 - mask)  # masked region blacked out

        # MAE of the prediction, restricted to the masked region.
        mask_3ch = mask.expand_as(original)
        mae_masked = ((predicted - original).abs() * mask_3ch).sum() / (mask_3ch.sum() + 1e-8)

        composed_np = composed.permute(1, 2, 0).numpy()
        row = axes[i]

        row[0].imshow(original.permute(1, 2, 0).numpy())
        row[0].set_title(f'Original\n{filename}', fontsize=10, fontweight='bold')

        row[1].imshow(mask[0].numpy(), cmap='gray')
        row[1].set_title('Segmentation Mask', fontsize=10, fontweight='bold')

        row[2].imshow(masked_input.permute(1, 2, 0).numpy())
        row[2].set_title('Masked Input', fontsize=10, fontweight='bold')

        row[3].imshow(composed_np)
        row[3].set_title(f'Inpainted Result\nMAE: {mae_masked.item():.4f}',
                         fontsize=10, fontweight='bold')

        for ax in row:
            ax.axis('off')

        # A separate, explicitly closed figure, so the grid figure is not
        # affected by pyplot's "current figure" state.
        _save_single_image(composed_np, os.path.join(save_dir, f'inpainted_{filename}'))

    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, 'comparison_grid.png'), dpi=150, bbox_inches='tight')
    plt.show()

    print(f"Results saved to '{save_dir}/'")

# ============================================================================
# MAIN INFERENCE SCRIPT
# ============================================================================

CHECKPOINT_PATH = 'denoising_weights/version_1/checkpoints/best_valid_loss.ckpt'
NUM_VIS_IMAGES = 8  # examples to visualize per split

# Choose the device before loading so the checkpoint is mapped directly onto
# it; this lets a checkpoint saved on GPU load on a CPU-only machine.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the trained inpainting model.
model = LightningModule.load_from_checkpoint(
    CHECKPOINT_PATH,
    map_location=device,
    use_weighted_loss=True
)

# Data module for the segmentation dataset. The batch size matches the number
# of images to visualize, so a single batch already contains all of them.
inpainting_dm = InpaintingInferenceDataModule(
    data_dir="dataset",
    batch_size=NUM_VIS_IMAGES,
    num_workers=4
)
inpainting_dm.setup()

# Run inference on the test set.
print(f"Running inference on {len(inpainting_dm.test_dataset)} test images...")
visualize_inpainting_segmentation(
    model=model,
    dataloader=inpainting_dm.test_dataloader(),
    device=device,
    num_images=NUM_VIS_IMAGES,
    save_dir='inpainting_segmentation_results'
)

# Optionally, run on the training set as well.
print(f"\nRunning inference on {len(inpainting_dm.train_dataset)} training images...")
visualize_inpainting_segmentation(
    model=model,
    dataloader=inpainting_dm.train_dataloader(),
    device=device,
    num_images=NUM_VIS_IMAGES,
    save_dir='inpainting_segmentation_results_train'
)