# Laplacian Pyramid Transformer Reflection Removal
Author: Peiyao Tao

Date: 10/19/2025

Class: CS 7180 Advanced Perception

## Purpose of the file
This file introduces a Laplacian-pyramid-based transformer for reflection removal tasks.

In [None]:
import json
import os
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision.models import vgg19
from PIL import Image
import timm
import torchmetrics
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision.models import vgg19, VGG19_Weights

In [None]:
# Master list of image pairs (one JSON entry per blended/transmission/reflection triple)
JSON_FILE_PATH = "./VOC2012/VOC_results_list.json"

# Read the complete list once; it is partitioned below
with open(JSON_FILE_PATH, 'r') as pairs_file:
    all_image_pairs = json.load(pairs_file)

# 80% training / 20% validation; the fixed seed keeps the partition reproducible
train_pairs, val_pairs = train_test_split(
    all_image_pairs,
    test_size=0.2,
    random_state=42,
)
print(f"Data split into {len(train_pairs)} training pairs and {len(val_pairs)} validation pairs.")

# Write each partition to its own JSON file in the current directory
for split_filename, split_pairs in (('train_list.json', train_pairs), ('val_list.json', val_pairs)):
    with open(split_filename, 'w') as split_file:
        json.dump(split_pairs, split_file, indent=4)

print("Created 'train_list.json' and 'val_list.json'.")

class ReflectionDataset(Dataset):
    """
    Paired dataset of blended / transmission / reflection images.

    Image triples are listed in a JSON file. Each entry holds three file names
    under the keys 'blended', 'transmission_layer' and 'reflection_layer'; they
    are resolved against the 'blended', 'transmission_layer' and
    'reflection_layer' sub-directories of `root_dir`.

    Training samples are randomly cropped to `crop_size` (images too small for
    the crop are upscaled first) and randomly flipped horizontally. Validation
    samples are resized directly to `crop_size`. Only the blended image is
    normalized with ImageNet statistics; the two targets are plain
    `to_tensor` outputs.

    Args:
        root_dir: Directory containing the three image sub-directories.
        json_path: Path of the JSON file listing the image triples.
        crop_size: Output size as (height, width).
        is_train: Selects the random-crop/flip pipeline when True, the plain
            resize pipeline when False.
    """

    def __init__(self, root_dir, json_path, crop_size=(224, 224), is_train=True):
        self.root_dir = root_dir
        self.is_train = is_train
        self.crop_size = crop_size
        self.json_path = json_path

        self.blended_dir = os.path.join(root_dir, 'blended')
        self.transmission_dir = os.path.join(root_dir, 'transmission_layer')
        self.reflection_dir = os.path.join(root_dir, 'reflection_layer')

        with open(json_path, 'r') as f:
            self.image_pairs = json.load(f)

        # Normalization for ImageNet pretrained models
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def __len__(self):
        """Return the number of image triples listed in the JSON file."""
        return len(self.image_pairs)

    @staticmethod
    def _load_rgb(directory, filename):
        """Open `directory/filename` and convert it to an RGB PIL image."""
        return Image.open(os.path.join(directory, filename)).convert('RGB')

    def __getitem__(self, idx):
        """
        Load the triple at `idx` and return it as a dict of tensors.

        Returns:
            Dict with keys 'blended' (ImageNet-normalized), 'transmission' and
            'reflection'.
        """
        pair_info = self.image_pairs[idx]
        images = [
            self._load_rgb(self.blended_dir, pair_info['blended']),
            self._load_rgb(self.transmission_dir, pair_info['transmission_layer']),
            self._load_rgb(self.reflection_dir, pair_info['reflection_layer']),
        ]

        bicubic = transforms.InterpolationMode.BICUBIC
        crop_h, crop_w = self.crop_size

        if self.is_train:
            # PIL reports size as (width, height)
            img_w, img_h = images[0].size
            if img_w < crop_w or img_h < crop_h:
                # An int size rescales the shorter edge and keeps the aspect
                # ratio. Using the larger crop dimension guarantees that both
                # image edges end up at least as large as the crop, also for
                # non-square crop sizes.
                shorter_edge = max(crop_h, crop_w)
                images = [TF.resize(img, shorter_edge, interpolation=bicubic) for img in images]

            # One set of crop parameters is shared so the three images stay aligned
            i, j, h, w = transforms.RandomCrop.get_params(images[0], output_size=self.crop_size)
            images = [TF.crop(img, i, j, h, w) for img in images]

            # Random horizontal flip, applied to all three images or none
            if torch.rand(1) < 0.5:
                images = [TF.hflip(img) for img in images]
        else:
            # For validation, resize to crop size
            images = [TF.resize(img, self.crop_size, interpolation=bicubic) for img in images]

        blended_img, transmission_img, reflection_img = images

        return {
            'blended': self.normalize(TF.to_tensor(blended_img)),
            'transmission': TF.to_tensor(transmission_img),
            'reflection': TF.to_tensor(reflection_img)
        }

BATCH_SIZE = 32
DATASET_ROOT_PATH = "./VOC2012"

# Training split: random crops/flips, shuffled mini-batches
train_dataset = ReflectionDataset(
    root_dir=DATASET_ROOT_PATH,
    json_path='train_list.json',
    is_train=True,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

# Validation split: deterministic resize, one image per batch, fixed order
val_dataset = ReflectionDataset(
    root_dir=DATASET_ROOT_PATH,
    json_path='val_list.json',
    is_train=False,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)

print("DataLoaders created.")

In [None]:
class CrossAttention(nn.Module):
    """
    Multi-head cross-attention between feature maps of the two decoder streams.

    `query`, `key` and `value` are (B, C, H, W) feature maps. They are flattened
    to token sequences, attended with scaled dot-product attention, projected,
    and reshaped back to the spatial layout of `query`.
    """
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        # Standard 1/sqrt(head_dim) scaling of the attention logits
        self.scale = (dim // num_heads) ** -0.5

        self.q = nn.Linear(dim, dim, bias=False)
        self.k = nn.Linear(dim, dim, bias=False)
        self.v = nn.Linear(dim, dim, bias=False)
        self.proj = nn.Linear(dim, dim)

    def forward(self, query, key, value):
        batch, channels, height, width = query.shape
        heads = self.num_heads
        per_head = channels // heads

        def to_heads(feature_map, projection):
            # (B, C, H, W) -> (B, tokens, C) -> (B, heads, tokens, per_head)
            tokens = feature_map.flatten(2).transpose(1, 2)
            return projection(tokens).view(batch, -1, heads, per_head).permute(0, 2, 1, 3)

        q = to_heads(query, self.q)
        k = to_heads(key, self.k)
        v = to_heads(value, self.v)

        logits = (q @ k.transpose(-2, -1)) * self.scale
        weights = logits.softmax(dim=-1)

        # Merge the heads again and restore the spatial layout of the query
        merged = (weights @ v).transpose(1, 2).reshape(batch, height * width, channels)
        projected = self.proj(merged)
        return projected.transpose(1, 2).view(batch, channels, height, width)

class ReflectionRemovalModel(nn.Module):
    """
    Reflection Removal Model with Dual-Stream Decoder and Cross-Attention.

    A pretrained Swin-Small backbone (frozen except for `layers_3`) produces
    four feature maps. Two parallel decoders with skip connections predict the
    transmission and the reflection layer; after each decoder level the two
    streams exchange information through a shared cross-attention module.

    For inference a Laplacian-pyramid path is available: the lowest pyramid
    level goes through the dual-stream network, and the higher-frequency levels
    are refined by a small mask-guided convolutional head.
    """
    def __init__(self):
        super().__init__()

        # Load pretrained Swin Small as backbone
        self.backbone = timm.create_model(
            'swin_small_patch4_window7_224',
            pretrained=True,
            features_only=True
        )

        # Freeze the backbone, then re-enable training for its last stage only
        for param in self.backbone.parameters():
            param.requires_grad = False

        for param in self.backbone.layers_3.parameters():
            param.requires_grad = True

        # Attribute names and registration order are kept stable so that
        # existing checkpoints (state_dict keys) continue to load.

        # Transmission decoder
        self.trans_up3 = self._upsample_block(768, 384)
        self.trans_conv3 = self._fusion_block(384 + 384, 384)
        self.trans_up2 = self._upsample_block(384, 192)
        self.trans_conv2 = self._fusion_block(192 + 192, 192)
        self.trans_up1 = self._upsample_block(192, 96)
        self.trans_conv1 = self._fusion_block(96 + 96, 96)
        self.trans_final = self._output_head()

        # Reflection decoder (same layout, separate weights)
        self.ref_up3 = self._upsample_block(768, 384)
        self.ref_conv3 = self._fusion_block(384 + 384, 384)
        self.ref_up2 = self._upsample_block(384, 192)
        self.ref_conv2 = self._fusion_block(192 + 192, 192)
        self.ref_up1 = self._upsample_block(192, 96)
        self.ref_conv1 = self._fusion_block(96 + 96, 96)
        self.ref_final = self._output_head()

        # Cross-attention modules at each level
        self.cross3 = CrossAttention(384)
        self.cross2 = CrossAttention(192)
        self.cross1 = CrossAttention(96)

        # Reflection mask head (from low-freq processing)
        self.mask_head = nn.Sequential(
            nn.Conv2d(3, 1, 3, padding=1),
            nn.Sigmoid()
        )

        # Refiner for high-freq Laplacian levels; input is a 3-channel
        # Laplacian level concatenated with the 1-channel mask
        self.refiner = nn.Sequential(
            nn.Conv2d(3 + 1, 3, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(3, 3, 3, padding=1)
        )

    @staticmethod
    def _upsample_block(in_channels, out_channels):
        """2x bilinear upsampling followed by a 3x3 conv + ReLU."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _fusion_block(in_channels, out_channels):
        """Two 3x3 conv + ReLU layers that fuse upsampled and skip features."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _output_head():
        """4x bilinear upsampling and convs mapping 96 channels to a 3-channel sigmoid image."""
        return nn.Sequential(
            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True),
            nn.Conv2d(96, 96, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(96, 3, 3, padding=1),
            nn.Sigmoid()
        )

    @staticmethod
    def _reconstruct_laplacian_pyramid(pyramid):
        """
        Collapse a pyramid ordered from the base level to the finest level.

        Each step bilinearly upsamples the running image to the size of the
        next level and adds that level, i.e. the inverse of the
        decomposition `level = image - upsample(downsample(image))`.
        """
        current = pyramid[0]
        for lap in pyramid[1:]:
            upsampled = F.interpolate(current, size=lap.shape[2:], mode='bilinear', align_corners=False)
            current = upsampled + lap
        return current

    def _separate(self, x):
        """Run backbone and both decoder streams; return (transmission, reflection)."""
        features = self.backbone(x)

        # Backbone features are permuted from channels-last to (B, C, H, W)
        feat0 = features[0].permute(0, 3, 1, 2)
        feat1 = features[1].permute(0, 3, 1, 2)
        feat2 = features[2].permute(0, 3, 1, 2)
        feat3 = features[3].permute(0, 3, 1, 2)

        decoder_levels = (
            (self.trans_up3, self.trans_conv3, self.ref_up3, self.ref_conv3, self.cross3, feat2),
            (self.trans_up2, self.trans_conv2, self.ref_up2, self.ref_conv2, self.cross2, feat1),
            (self.trans_up1, self.trans_conv1, self.ref_up1, self.ref_conv1, self.cross1, feat0),
        )

        x_trans = feat3
        x_ref = feat3
        for trans_up, trans_conv, ref_up, ref_conv, cross, skip in decoder_levels:
            # Upsample, concatenate the encoder skip feature, fuse
            x_trans = trans_conv(torch.cat([trans_up(x_trans), skip], dim=1))
            x_ref = ref_conv(torch.cat([ref_up(x_ref), skip], dim=1))

            # Cross-attention: the reflection update sees the already
            # updated transmission features
            x_trans = x_trans + cross(x_trans, x_ref, x_ref)
            x_ref = x_ref + cross(x_ref, x_trans, x_trans)

        return self.trans_final(x_trans), self.ref_final(x_ref)

    def forward(self, x, is_train=True, pyramid_levels=4):
        """
        Forward pass of the model.

        Args:
            x: Blended input image batch, (B, 3, H, W).
            is_train: When True, `x` is processed directly at its given size.
                When False, `x` is decomposed into a Laplacian pyramid and
                only the lowest level goes through the dual-stream network.
            pyramid_levels: Number of pyramid levels for the `is_train=False` path.

        Returns:
            Tuple (transmission, reflection).

        NOTE(review): in the pyramid path the lowest level is fed to the
        backbone, so its spatial size must be one the backbone accepts, and
        `reflection` is computed as `x - transmission`, which presumably
        expects `x` in the same value range as the prediction. Confirm both
        before relying on this path.
        """
        if is_train:
            return self._separate(x)

        # During eval/inference, use Laplacian pyramid for full resolution
        pyramid = build_laplacian_pyramid(x, levels=pyramid_levels)

        # Process lowest level (low-freq) with the full dual-stream network
        transmission_low, reflection_low = self._separate(pyramid[0])

        # Get reflection mask from low-freq reflection
        mask = self.mask_head(reflection_low)

        # Refine higher Laplacian levels (high-freq)
        refined_pyramid = [transmission_low]
        for lap in pyramid[1:]:
            # Upsample mask to current level; reused as the source for the next level
            mask = F.interpolate(mask, size=lap.shape[2:], mode='bilinear', align_corners=False)

            # The refiner's first conv takes 3 + 1 channels: level and mask concatenated
            refined_pyramid.append(self.refiner(torch.cat([lap, mask], dim=1)))

        # Reconstruct transmission; reflection is the residual of the input
        transmission = self._reconstruct_laplacian_pyramid(refined_pyramid)
        reflection = x - transmission

        return transmission, reflection

In [None]:
# Cell 6: Loss Function with Perceptual and Exclusion Losses
class PerceptualLoss(nn.Module):
    """
    Perceptual Loss using VGG19 features.

    The frozen feature extractor of an ImageNet-pretrained VGG19 is split into
    four consecutive slices. The loss is the L1 distance between the
    activations of prediction and ground truth after each slice, averaged over
    the slices.

    Inputs are normalized with the ImageNet mean/std before entering VGG,
    because the pretrained weights were trained on images normalized that way.
    Both inputs are therefore assumed to be RGB images in [0, 1].
    """
    def __init__(self):
        super().__init__()
        vgg = vgg19(weights=VGG19_Weights.DEFAULT).features
        layers = [vgg[:4], vgg[4:9], vgg[9:16], vgg[16:23]]
        self.layers = nn.ModuleList(layers)
        # VGG is a fixed feature extractor here; it is never trained
        for param in self.parameters():
            param.requires_grad = False

        # Non-persistent buffers: they follow .to(device) but stay out of the state_dict
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False)
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False)

    def forward(self, pred, gt):
        """
        Args:
            pred: Predicted image batch, (B, 3, H, W).
            gt: Ground-truth image batch of the same shape.

        Returns:
            Scalar tensor: mean over the slices of the per-slice L1 feature distance.
        """
        loss = 0.0
        x = (pred - self.mean) / self.std
        y = (gt - self.mean) / self.std
        for layer in self.layers:
            # Each slice continues from the previous slice's activations
            x = layer(x)
            y = layer(y)
            loss += F.l1_loss(x, y)
        return loss / len(self.layers)

def build_laplacian_pyramid(img, levels=4):
    """
    Decompose `img` into a Laplacian pyramid.

    Each step halves the image with 2x2 average pooling, upsamples the result
    back bilinearly, and stores the difference to the image before pooling.

    Returns:
        List ordered from low to high frequency: the coarsest (base) image
        first, followed by the residuals from the coarsest scale up to the
        full-resolution one.
    """
    residuals = []
    coarse = img
    for _ in range(1, levels):
        finer = coarse
        coarse = F.avg_pool2d(finer, kernel_size=2, stride=2)
        restored = F.interpolate(coarse, size=finer.shape[2:], mode='bilinear', align_corners=False)
        residuals.append(finer - restored)
    # Residuals were collected fine-to-coarse; emit them coarse-to-fine after the base
    return [coarse, *reversed(residuals)]

def exclusion_loss(pred_trans, pred_ref, levels=3):
    """
    Exclusion loss that penalizes edges shared by transmission and reflection.

    At each of `levels` scales the forward differences of both images are
    multiplied element-wise; the mean absolute product in x and y is added to
    the loss. Both images are halved by 2x2 average pooling between scales.
    The result is the average over the scales.
    """
    def forward_diffs(image):
        # Differences to the right / lower neighbour, cropped to a common size
        anchor = image[:, :, :-1, :-1]
        return anchor - image[:, :, :-1, 1:], anchor - image[:, :, 1:, :-1]

    per_level = []
    cur_trans, cur_ref = pred_trans, pred_ref
    for _ in range(levels):
        trans_dx, trans_dy = forward_diffs(cur_trans)
        ref_dx, ref_dy = forward_diffs(cur_ref)

        overlap_x = (trans_dx * ref_dx).abs().mean()
        overlap_y = (trans_dy * ref_dy).abs().mean()
        per_level.append(overlap_x + overlap_y)

        cur_trans = F.avg_pool2d(cur_trans, kernel_size=2, stride=2)
        cur_ref = F.avg_pool2d(cur_ref, kernel_size=2, stride=2)
    return sum(per_level, 0.0) / levels

def gradient_loss(pred, gt, alpha=1.0):
    """
    Edge-preservation loss based on Sobel gradients.

    Both images are reduced to one channel by averaging over the channel
    dimension, filtered with the horizontal and vertical Sobel kernels, and
    compared with L1. The two L1 terms are summed and scaled by `alpha`.
    """
    # The vertical Sobel kernel is the transpose of the horizontal one
    horizontal = torch.tensor(
        [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
        dtype=torch.float32,
    ).view(1, 1, 3, 3).to(pred.device)
    vertical = horizontal.transpose(2, 3).contiguous()

    pred_gray = pred.mean(1, keepdim=True)
    gt_gray = gt.mean(1, keepdim=True)

    def edge_distance(kernel):
        pred_edges = F.conv2d(pred_gray, kernel, padding=1)
        gt_edges = F.conv2d(gt_gray, kernel, padding=1)
        return F.l1_loss(pred_edges, gt_edges)

    distance_x = edge_distance(horizontal)
    distance_y = edge_distance(vertical)
    return alpha * (distance_x + distance_y)

def reflection_removal_loss(pred_trans, pred_ref, gt_trans, gt_ref, blended, perceptual_module, pyramid_levels=4):
    """
    Combined training objective for reflection removal.

    Terms and weights:
        1.0 * (MSE(pred_trans, gt_trans) + MSE(pred_ref, gt_ref)
               + MSE(pred_trans + pred_ref, blended))
        0.5 * perceptual loss on the transmission
        0.5 * exclusion loss between the two predictions
        0.5 * Sobel gradient loss on the transmission
        0.5 * multi-scale L1 loss over the Laplacian pyramid of the transmission

    NOTE(review): the reconstruction term compares `pred_trans + pred_ref`
    with `blended` directly, so `blended` presumably has to be in the same
    value range as the predictions (not a normalized tensor); verify against
    the callers.
    """
    # Pixel-space terms at full resolution
    fidelity = F.mse_loss(pred_trans, gt_trans) + F.mse_loss(pred_ref, gt_ref)
    base_loss = fidelity + F.mse_loss(pred_trans + pred_ref, blended)

    total_loss = (
        1.0 * base_loss
        + 0.5 * perceptual_module(pred_trans, gt_trans)
        + 0.5 * exclusion_loss(pred_trans, pred_ref)
        + 0.5 * gradient_loss(pred_trans, gt_trans, alpha=1.0)
    )

    # Multi-scale pyramid loss
    gt_pyr_trans = build_laplacian_pyramid(gt_trans, pyramid_levels)
    pred_pyr_trans = build_laplacian_pyramid(pred_trans, pyramid_levels)

    ms_loss = 0.0
    for level, pred_level, gt_level in zip(range(pyramid_levels), pred_pyr_trans, gt_pyr_trans):
        # Pyramids are ordered base level first, so the weight halves with
        # every step towards the finer (higher-frequency) levels
        ms_loss += (2.0 ** -level) * F.l1_loss(pred_level, gt_level)

    # Add to encourage multi-scale fidelity
    return total_loss + 0.5 * ms_loss

In [5]:
class EarlyStopping:
    """
    Early stops the training if validation loss doesn't improve after a given patience.
    Saves the best model and allows for resuming training.

    Args:
        patience: Number of consecutive non-improving calls after which
            `early_stop` is set.
        verbose: Print a message on every checkpoint save/load and counter increment.
        delta: Minimum decrease of the validation loss that counts as an improvement.
        path: File the best checkpoint is written to and loaded from.
    """
    def __init__(self, patience=10, verbose=False, delta=0, path='best_model.pt'):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_loss = float('inf')  # Use a direct loss value, not a score
        self.early_stop = False

    def __call__(self, val_loss, model):
        """Record one validation result; checkpoint on improvement, count otherwise."""
        if self.best_loss - val_loss > self.delta:
            # Improvement: save the model and reset the counter
            self.save_checkpoint(val_loss, model)
            self.best_loss = val_loss
            self.counter = 0
        else:
            # No improvement: increment the counter
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter}/{self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Saves a checkpoint with model state and validation loss."""
        if self.verbose:
            print(f'Val loss decreased ({self.best_loss:.6f} --> {val_loss:.6f}). Saving model to {self.path}.')

        # Save a dictionary containing both the model's state and the best loss
        checkpoint = {
            'best_loss': val_loss,
            'model_state_dict': model.state_dict()
        }
        torch.save(checkpoint, self.path)

    def load_checkpoint(self, model):
        """
        Loads model and best loss from a checkpoint.

        Does nothing (apart from an optional message) when no file exists at
        `self.path`.
        """
        if os.path.exists(self.path):
            if self.verbose:
                print(f"Loading checkpoint from '{self.path}'")
            # Load onto the CPU so a checkpoint written on a GPU machine also
            # loads where no GPU is available; load_state_dict then copies the
            # values into the model's own parameters on whatever device they live.
            checkpoint = torch.load(self.path, map_location='cpu')
            model.load_state_dict(checkpoint['model_state_dict'])
            self.best_loss = checkpoint['best_loss']
            if self.verbose:
                print(f"Resuming with best validation loss: {self.best_loss:.6f}")
        else:
            if self.verbose:
                print(f"No checkpoint found at '{self.path}'. Starting from scratch.")

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = ReflectionRemovalModel()
model.to(device)

perceptual_module = PerceptualLoss().to(device)

# Optimizer: two parameter groups, with the fine-tuned backbone stage on a
# smaller learning rate than the remaining trainable parameters
encoder_layer3_params = list(model.backbone.layers_3.parameters())
encoder_layer3_param_ids = set(map(id, encoder_layer3_params))
remaining_trainable_params = [
    p for p in model.parameters()
    if p.requires_grad and id(p) not in encoder_layer3_param_ids
]

optimizer = optim.Adam([
    {'params': encoder_layer3_params, 'lr': 1e-5},
    {"params": remaining_trainable_params, "lr": 1e-4},
])

# Learning rate scheduler: cut the rate by 10x after 5 epochs without a lower validation loss
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

# Early stopping; weights are restored from the checkpoint file when it exists
early_stopping = EarlyStopping(patience=10, verbose=True, path='transformer_best_model.pt')
early_stopping.load_checkpoint(model)
# NOTE(review): resetting best_loss right after loading means the first
# validated epoch replaces the stored checkpoint whatever loss it had - confirm
# that this is intended.
early_stopping.best_loss = float('inf')

num_epochs = 100

In [None]:
# Training loop
# The datasets normalize `blended` with the ImageNet statistics before it is fed
# to the backbone, while the predictions are sigmoid outputs and the targets are
# plain to_tensor images. The reconstruction term of the loss compares
# pred_trans + pred_ref with the blended image, so the normalization is undone
# for the loss to keep both sides in the same value range.
imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} Training"):
        blended = batch['blended'].to(device)
        transmission = batch['transmission'].to(device)
        reflection = batch['reflection'].to(device)

        optimizer.zero_grad()

        pred_trans, pred_ref = model(blended, is_train=model.training)

        # The network sees the normalized image, the loss sees the original one
        blended_image = blended * imagenet_std + imagenet_mean
        loss = reflection_removal_loss(pred_trans, pred_ref, transmission, reflection, blended_image, perceptual_module)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}")

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch+1} Validation"):
            blended = batch['blended'].to(device)
            transmission = batch['transmission'].to(device)
            reflection = batch['reflection'].to(device)

            # Validation images are resized to the training size, so the
            # fixed-size path (is_train=True, the default) is used here as well
            pred_trans, pred_ref = model(blended, is_train=True)

            blended_image = blended * imagenet_std + imagenet_mean
            loss = reflection_removal_loss(pred_trans, pred_ref, transmission, reflection, blended_image, perceptual_module)

            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Val Loss: {avg_val_loss:.4f}")

    # Step the scheduler on the validation loss
    scheduler.step(avg_val_loss)

    # Early stopping check (also saves the checkpoint on improvement)
    early_stopping(avg_val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping triggered.")
        break

print("Training complete.")

In [None]:
def evaluate_model(model, dataloader, device):
    """
    Evaluates the model and visualizes results with a two-plot style:
    1. Ground Truth Plot: Input, GT Transmission, GT Reflection.
    2. Prediction Plot: Predicted Transmission, Inferred Reflection (with PSNR/SSIM).

    First the average PSNR and SSIM of the predicted transmission are computed
    over the whole dataloader and printed; then the first three samples are
    plotted. The plotting code squeezes the batch dimension, so the dataloader
    is expected to use batch_size=1.

    Args:
        model: Network returning (transmission, reflection) for a blended batch.
        dataloader: Yields dicts with 'blended', 'transmission', 'reflection'.
        device: Device the model lives on.
    """
    # Predictions and targets are images in [0, 1]; a fixed data range keeps
    # PSNR/SSIM independent of the content of individual images
    psnr_metric = torchmetrics.PeakSignalNoiseRatio(data_range=1.0).to(device)
    ssim_metric = torchmetrics.StructuralSimilarityIndexMeasure(data_range=1.0).to(device)

    # Mixed precision is only requested where it is available
    device_type = torch.device(device).type
    use_amp = device_type == 'cuda'

    def denormalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        """Undo the ImageNet normalization of a (C, H, W) tensor; returns a new tensor."""
        mean_t = tensor.new_tensor(mean).view(-1, 1, 1)
        std_t = tensor.new_tensor(std).view(-1, 1, 1)
        return tensor * std_t + mean_t

    def predict(blended, target_hw):
        """Run the model and resize both outputs to the ground-truth size."""
        with torch.autocast(device_type=device_type, enabled=use_amp):
            recon_bg, inferred_reflection = model(blended)  # Two outputs
        # Autocast may return half precision; metrics and plots use float32
        recon_bg = TF.resize(recon_bg.float(), size=list(target_hw))
        inferred_reflection = TF.resize(inferred_reflection.float(), size=list(target_hw))
        return recon_bg, inferred_reflection

    model.eval()
    print("Calculating metrics over the entire validation set.")
    with torch.no_grad():
        # This part for calculating average metrics
        for data in tqdm(dataloader, desc="Calculating Metrics"):
            if not data:
                continue
            blended = data['blended'].to(device)
            ground_truth_transmission = data['transmission'].to(device)

            recon_bg_resized, _ = predict(blended, ground_truth_transmission.shape[-2:])

            psnr_metric.update(recon_bg_resized, ground_truth_transmission)
            ssim_metric.update(recon_bg_resized, ground_truth_transmission)

    avg_psnr = psnr_metric.compute()
    avg_ssim = ssim_metric.compute()

    print(f"\nAverage PSNR: {avg_psnr:.4f}")
    print(f"Average SSIM: {avg_ssim:.4f}")

    print("\nDisplaying a few examples with the new plot style...")
    # Reset so the calls below report per-image values
    psnr_metric.reset()
    ssim_metric.reset()

    def show_row(figure_title, titled_images, figsize):
        """Plot (C, H, W) images side by side in one figure."""
        fig, axes = plt.subplots(1, len(titled_images), figsize=figsize)
        fig.suptitle(figure_title, fontsize=16)
        for ax, (title, img) in zip(axes, titled_images.items()):
            img_np = img.permute(1, 2, 0).numpy().astype(np.float32)
            ax.imshow(np.clip(img_np, 0, 1))
            ax.set_title(title)
            ax.axis('off')
        plt.show()

    images_shown = 0
    # Inference only: no autograd graph is needed for the visualization pass
    with torch.no_grad():
        for data in dataloader:
            if images_shown >= 3:
                break
            if not data:
                continue

            blended_vis = data['blended'].to(device)
            t_vis = data['transmission'].to(device)
            r_vis = data['reflection'].to(device)

            recon_bg_resized, inferred_reflection_resized = predict(blended_vis, t_vis.shape[-2:])

            image_psnr = psnr_metric(recon_bg_resized, t_vis)
            image_ssim = ssim_metric(recon_bg_resized, t_vis)

            show_row(
                "Ground Truth Comparison",
                {
                    "Input Image": denormalize(blended_vis.cpu().squeeze(0)),
                    "Ground Truth Transmission": t_vis.cpu().squeeze(0),
                    "Ground Truth Reflection": r_vis.cpu().squeeze(0)
                },
                figsize=(18, 6),
            )

            show_row(
                f"Model Predictions (Transmission PSNR: {image_psnr:.2f} dB, SSIM: {image_ssim:.4f})",
                {
                    "Predicted Transmission": recon_bg_resized.cpu().squeeze(0),
                    "Inferred Reflection": inferred_reflection_resized.cpu().squeeze(0)
                },
                figsize=(12, 6),
            )

            images_shown += 1

# Report metrics and example plots for the validation loader
print("\nEvaluating the model on the validation set.")
evaluate_model(model=model, dataloader=val_loader, device=device)

In [None]:
class WildSceneDataset(Dataset):
    """
    A Dataset class for the SIRR Wildscene test set.

    It scans the sub-directories of `root_dir` and collects every scene that
    contains both m.jpg (blended input) and g.jpg (ground-truth transmission).
    r.jpg (reflection) is loaded when present; otherwise a black image of the
    same size is returned for it.

    Args:
        root_dir: Directory containing one sub-directory per scene.
        crop_size: Size (height, width) every image is resized to.
    """
    def __init__(self, root_dir, crop_size=(224, 224)):
        self.root_dir = root_dir
        self.crop_size = crop_size
        # Normalization for ImageNet pretrained models
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.image_pairs = self._find_image_pairs(root_dir)

    @staticmethod
    def _find_image_pairs(root_dir):
        """
        Return the path dicts of all usable scenes under `root_dir`.

        Numbered directories come first in numeric order; entries whose name is
        not an integer (e.g. stray files such as '.DS_Store') no longer abort
        the scan and are ordered by name after the numbered ones.
        """
        def scene_order(name):
            try:
                return (0, int(name), "")
            except ValueError:
                return (1, 0, name)

        image_pairs = []
        for scene_name in sorted(os.listdir(root_dir), key=scene_order):
            scene_path = os.path.join(root_dir, scene_name)
            if not os.path.isdir(scene_path):
                continue

            mixed_path = os.path.join(scene_path, 'm.jpg')
            gt_path = os.path.join(scene_path, 'g.jpg')
            reflection_path = os.path.join(scene_path, 'r.jpg')

            # Input and ground truth are mandatory; the reflection is optional
            if os.path.exists(mixed_path) and os.path.exists(gt_path):
                image_pairs.append({
                    'blended': mixed_path,
                    'transmission': gt_path,
                    'reflection': reflection_path
                })
        return image_pairs

    def __len__(self):
        """Return the number of scenes found."""
        return len(self.image_pairs)

    def __getitem__(self, idx):
        """
        Load the scene at `idx`.

        Returns:
            Dict with keys 'blended' (ImageNet-normalized), 'transmission' and
            'reflection', all resized to `crop_size`.
        """
        pair_info = self.image_pairs[idx]

        blended_img = Image.open(pair_info['blended']).convert('RGB')
        transmission_img = Image.open(pair_info['transmission']).convert('RGB')
        if os.path.exists(pair_info['reflection']):
            reflection_img = Image.open(pair_info['reflection']).convert('RGB')
        else:
            # r.jpg was not required during the scan; substitute a black image
            # instead of failing when the sample is requested
            reflection_img = Image.new('RGB', blended_img.size)

        bicubic = transforms.InterpolationMode.BICUBIC
        blended_img = TF.resize(blended_img, self.crop_size, interpolation=bicubic)
        transmission_img = TF.resize(transmission_img, self.crop_size, interpolation=bicubic)
        reflection_img = TF.resize(reflection_img, self.crop_size, interpolation=bicubic)

        return {
            'blended': self.normalize(TF.to_tensor(blended_img)),
            'transmission': TF.to_tensor(transmission_img),
            'reflection': TF.to_tensor(reflection_img)
        }

WILD_SCENE_PATH = "./Wildscene/Wildscene"

# Evaluate the trained model on the Wildscene test set, one scene per batch
wildscene_dataset = WildSceneDataset(root_dir=WILD_SCENE_PATH)
test_loader = DataLoader(
    wildscene_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=8,
)
evaluate_model(model=model, dataloader=test_loader, device=device)