In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
from typing import Tuple, List, Optional
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import numpy as np
import cv2
import matplotlib.pyplot as plt
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
import random

# Seed every RNG this notebook draws from, not just Python's `random`:
# torch's generator drives the initial weights of the nn.Conv2d layers and
# NumPy's is imported for use alongside it. Seeding all three makes a
# Restart & Run All reproduce the same sample and the same initial weights.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

class SampledCelebA(Dataset):
    """Dataset exposing a fixed-size, seeded random subset of an image folder.

    Every entry of ``root_dir`` is treated as an image file. Items are
    ``(image, attributes)`` pairs where ``attributes`` is currently always a
    zero vector of length 40; the attribute table is loaded into
    ``self.attr_df`` but is not consulted by ``__getitem__``.
    """

    def __init__(self, root_dir, attr_path, transform=None, sample_size=5000, seed=42):
        """
        Args:
            root_dir (string): Directory with all the images.
            attr_path (string): Path to attributes file.
            transform (callable, optional): Optional transform to be applied on a sample.
            sample_size (int): Number of images to sample from the dataset.
            seed (int): Random seed for reproducibility.
        """
        self.root_dir = root_dir
        self.transform = transform

        # os.listdir() returns entries in arbitrary, filesystem-dependent
        # order. Sorting first is what makes "same seed -> same subset" hold
        # across machines and runs.
        all_image_files = sorted(os.listdir(root_dir))

        # Sample with a private generator so that building a dataset does not
        # reset the global `random` state used elsewhere in the notebook.
        self.sample_size = min(sample_size, len(all_image_files))
        rng = random.Random(seed)
        self.image_files = rng.sample(all_image_files, self.sample_size)

        # Read attributes (optional). `sep=r"\s+"` is the supported spelling
        # of the deprecated `delim_whitespace=True`. Only I/O and parsing
        # failures are treated as "no attributes"; anything else is a
        # programming error and should surface.
        try:
            self.attr_df = pd.read_csv(attr_path, sep=r"\s+", header=1)
        except (OSError, ValueError):
            print("Warning: Could not read attributes file. Will return zero attributes.")
            self.attr_df = None

    def __len__(self):
        """Return the number of sampled images."""
        return self.sample_size

    def __getitem__(self, idx):
        """Load the ``idx``-th sampled image as RGB and pair it with zero attributes.

        Returns:
            tuple: ``(image, attributes)`` where ``image`` is the (optionally
            transformed) picture and ``attributes`` is ``torch.zeros(40)``.
        """
        img_name = os.path.join(self.root_dir, self.image_files[idx])
        # The context manager closes the underlying file as soon as the pixel
        # data has been converted, instead of leaving that to the GC.
        with Image.open(img_name) as img:
            image = img.convert('RGB')  # Ensure RGB format

        # Default empty attributes tensor
        attributes = torch.zeros(40, dtype=torch.float32)

        if self.transform:
            image = self.transform(image)

        return image, attributes

# Preprocessing applied to every raw image: resize to 64x64, then convert the
# PIL image to a tensor.
transform_raw = transforms.Compose(
    [transforms.Resize((64, 64)), transforms.ToTensor()]
)


# 1. Sharpening Filter
class SharpeningFilter(nn.Module):
    """Learnable 3x3 sharpening convolution applied to each channel independently.

    Args:
        strength (float): Interpolates between the identity filter (``0.0``,
            output equals input) and the Laplacian-style sharpening kernel
            ``[[0,-1,0],[-1,5,-1],[0,-1,0]]`` (``1.0``, the default).

    The resulting ``(1, 1, 3, 3)`` kernel is stored as the ``kernel``
    parameter, so it is updated by the optimiser.
    """

    def __init__(self, strength=1.0):
        super().__init__()
        sharpen = torch.tensor([
            [0, -1, 0],
            [-1, 5, -1],
            [0, -1, 0]
        ], dtype=torch.float32)

        # The identity element of convolution is a centred delta (a single 1
        # in the middle), not the identity *matrix*: torch.eye(3) would smear
        # the image along the diagonal and triple its brightness.
        identity = torch.zeros(3, 3, dtype=torch.float32)
        identity[1, 1] = 1.0

        kernel_init = identity + (sharpen - identity) * strength
        self.kernel = nn.Parameter(kernel_init.unsqueeze(0).unsqueeze(0))

    def forward(self, x):
        """Sharpen ``x`` of shape ``(batch, channels, H, W)`` and clamp to ``[0, 1]``."""
        channels = x.shape[1]
        # One grouped (depthwise) convolution shares the single kernel across
        # all channels; equivalent to convolving each channel separately.
        weight = self.kernel.expand(channels, 1, 3, 3)
        out = F.conv2d(x, weight, padding=1, groups=channels)
        return torch.clamp(out, 0, 1)

# 2. Edge Detection
class SobelEdgeDetection(nn.Module):
    """Soft Sobel edge map of an RGB batch, returned as three identical channels.

    The two Sobel kernels, the luminance weights and the threshold are all
    registered as ``nn.Parameter`` objects.
    """

    def __init__(self, threshold=0.1):
        super().__init__()
        horizontal = torch.tensor(
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32
        )
        vertical = torch.tensor(
            [[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32
        )
        luminance = torch.tensor([0.299, 0.587, 0.114], dtype=torch.float32)

        # Kernels get leading (out_channels, in_channels) axes for F.conv2d.
        self.sobel_x = nn.Parameter(horizontal.unsqueeze(0).unsqueeze(0))
        self.sobel_y = nn.Parameter(vertical.unsqueeze(0).unsqueeze(0))
        self.threshold = nn.Parameter(torch.tensor(threshold))
        self.rgb_weights = nn.Parameter(luminance)

    def forward(self, x):
        """Return a ``(batch, 3, H, W)`` edge-strength map with values in (0, 1)."""
        # Weighted channel sum -> single-channel grayscale image.
        channel_weights = self.rgb_weights.view(1, 3, 1, 1)
        gray = (x * channel_weights).sum(dim=1, keepdim=True)

        gx = F.conv2d(gray, self.sobel_x, padding=1)
        gy = F.conv2d(gray, self.sobel_y, padding=1)

        # The epsilon keeps the square root differentiable where both gradients are 0.
        magnitude = (gx ** 2 + gy ** 2 + 1e-8).sqrt()

        # Normalise each image by its own peak gradient magnitude.
        row_peak = magnitude.max(dim=2, keepdim=True)[0]
        image_peak = row_peak.max(dim=3, keepdim=True)[0]
        magnitude = magnitude / (image_peak + 1e-8)

        # Steep sigmoid acts as a soft threshold around `self.threshold`.
        edges = torch.sigmoid((magnitude - self.threshold) * 10)
        return edges.repeat(1, 3, 1, 1)

# 3. Median Filter
class DifferentiableMedianFilter(nn.Module):
    """Per-channel median filter over square sliding windows with reflect padding.

    Args:
        kernel_size (int): Side length of the window; must be a positive odd
            integer so the window has a centre pixel and the output keeps the
            input's spatial size.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.

    Note:
        ``temperature`` is registered as a parameter but is not read by
        ``forward``; it is kept so existing state dicts still load.
    """

    def __init__(self, kernel_size=3):
        super().__init__()
        # An even window makes unfold() yield (H+1)*(W+1) patches, which
        # previously surfaced as an opaque reshape error inside forward().
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        self.kernel_size = kernel_size
        self.pad = kernel_size // 2
        self.temperature = nn.Parameter(torch.tensor(0.1))

    def forward(self, x):
        """Return the windowed median of ``x`` (``(batch, channels, H, W)``), same shape."""
        b, c, h, w = x.shape
        window = self.kernel_size * self.kernel_size
        x_padded = F.pad(x, (self.pad, self.pad, self.pad, self.pad), mode='reflect')

        # unfold() lays patches out channel-major, so one call covers every
        # channel: (b, c * window, h * w) -> (b, c, window, h * w).
        patches = F.unfold(x_padded, kernel_size=self.kernel_size, stride=1)
        patches = patches.reshape(b, c, window, h * w)

        # Sort within each window and take the middle element.
        sorted_patches, _ = torch.sort(patches, dim=2)
        median_values = sorted_patches[:, :, window // 2, :]
        return median_values.reshape(b, c, h, w)

# 4. Contrast Enhancement
class ContrastEnhancement(nn.Module):
    """Learnable linear contrast/brightness adjustment: ``clamp(alpha * x + beta, 0, 1)``.

    Args:
        alpha: Initial multiplicative gain (contrast).
        beta: Initial additive offset (brightness).
    """

    def __init__(self, alpha=1.5, beta=0):
        super().__init__()
        gain = torch.tensor(alpha, dtype=torch.float32)
        offset = torch.tensor(beta, dtype=torch.float32)
        self.alpha = nn.Parameter(gain)
        self.beta = nn.Parameter(offset)

    def forward(self, x):
        """Apply the affine adjustment and clip the result to ``[0, 1]``."""
        adjusted = self.alpha * x + self.beta
        return adjusted.clamp(0, 1)

# 5. Bilateral Filter
class BilateralFilter(nn.Module):
    """Simplified per-channel bilateral filter with learnable bandwidths.

    Each output pixel is a weighted average of its ``kernel_size`` x
    ``kernel_size`` neighbourhood (reflect-padded). A neighbour's weight is
    the product of a spatial Gaussian (bandwidth ``sigma_space``) and a
    Gaussian on its intensity difference to the centre pixel (bandwidth
    ``sigma_color``), computed separately for every channel.

    Args:
        kernel_size (int): Window side length; must be a positive odd integer.
        sigma_space (float): Initial spatial bandwidth (learnable).
        sigma_color (float): Initial intensity bandwidth (learnable).

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size=5, sigma_space=1.0, sigma_color=0.1):
        super().__init__()
        # Even sizes give a (k+1)^2 spatial kernel against k^2-element patches
        # and only failed later with a shape error inside forward().
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        self.kernel_size = kernel_size
        self.sigma_space = nn.Parameter(torch.tensor(sigma_space, dtype=torch.float32))
        self.sigma_color = nn.Parameter(torch.tensor(sigma_color, dtype=torch.float32))
        self.padding = kernel_size // 2

        offsets = torch.arange(-self.padding, self.padding + 1, dtype=torch.float32)
        # Explicit indexing avoids the torch.meshgrid deprecation warning; the
        # squared distance is symmetric, so the choice does not affect values.
        yy, xx = torch.meshgrid(offsets, offsets, indexing='ij')
        sq_dist = xx**2 + yy**2

        # Squared offsets are constant; the Gaussian itself is rebuilt in
        # forward() from the *current* sigma_space so that parameter is
        # actually trained. Non-persistent: derived data, not checkpoint state.
        self.register_buffer('_sq_dist', sq_dist, persistent=False)
        # Initial spatial kernel, retained under its original name so state
        # dicts saved before this change still load. forward() does not read it.
        self.register_buffer('spatial_kernel', torch.exp(-sq_dist / (2 * sigma_space**2)))

    def forward(self, x):
        """Filter ``x`` of shape ``(batch, channels, H, W)``; output has the same shape."""
        b, c, h, w = x.shape
        window = self.kernel_size ** 2
        x_padded = F.pad(x, (self.padding, self.padding, self.padding, self.padding), mode='reflect')

        # Extract patches: (b, c * window, h * w) -> (b, c, window, h * w)
        patches = F.unfold(x_padded, kernel_size=self.kernel_size, stride=1)
        patches = patches.view(b, c, window, h * w)

        # The middle element of each flattened patch is the pixel being filtered.
        center_idx = window // 2
        centers = patches[:, :, center_idx:center_idx + 1, :]

        # Range weight: penalise neighbours whose intensity differs from the centre.
        color_diff = patches - centers
        color_weight = torch.exp(-(color_diff**2) / (2 * self.sigma_color**2))

        # Spatial weight: computed from the live parameter so gradients reach sigma_space.
        spatial_weight = torch.exp(-self._sq_dist / (2 * self.sigma_space**2))

        weight = spatial_weight.view(1, 1, -1, 1) * color_weight
        weight = weight / (weight.sum(dim=2, keepdim=True) + 1e-8)

        # Weighted average over the window dimension.
        out = torch.sum(patches * weight, dim=2)
        return out.view(b, c, h, w)

# 6. Unsharp Masking
class UnsharpMasking(nn.Module):
    """Unsharp masking: add back a scaled copy of the detail removed by a Gaussian blur.

    Args:
        strength: Initial (learnable) multiplier on the detail layer.
        kernel_size: Kernel size handed to the internal ``GaussianBlur``.
        sigma: Sigma handed to the internal ``GaussianBlur``.
    """

    def __init__(self, strength=1.5, kernel_size=5, sigma=1.0):
        super().__init__()
        self.strength = nn.Parameter(torch.tensor(strength))
        self.blur = GaussianBlur(kernel_size=kernel_size, sigma=sigma)

    def forward(self, x):
        """Return ``clamp(x + strength * (x - blur(x)), 0, 1)``."""
        # High-frequency detail is whatever the blur removed.
        detail = x - self.blur(x)
        boosted = x + self.strength * detail
        return boosted.clamp(0, 1)

# --------------------------------
# From your original code (reused)
# --------------------------------

# Gaussian Blur (from your original code)
class GaussianBlur(nn.Module):
    """Fixed (non-learnable) depthwise Gaussian blur.

    Args:
        kernel_size (int): Side length of the square kernel.
        sigma (float): Standard deviation of the Gaussian, in pixels.

    The kernel is rebuilt on every call from ``kernel_size`` and ``sigma`` and
    normalised to sum to 1. Zero padding of ``kernel_size // 2`` is used.
    """

    def __init__(self, kernel_size=5, sigma=1.5):
        super().__init__()
        self.kernel_size = kernel_size
        self.sigma = sigma

    def forward(self, x):
        """Blur every channel of ``x`` (``(batch, channels, H, W)``) with the same kernel."""
        channels = x.shape[1]
        half = self.kernel_size // 2

        # Build the kernel directly in the input's dtype and on its device:
        # a hard-coded float32 kernel makes conv2d reject float64/half inputs.
        kernel_dtype = x.dtype if x.is_floating_point() else torch.float32
        coords = torch.arange(self.kernel_size, dtype=kernel_dtype, device=x.device) - half

        kernel_1d = torch.exp(-coords ** 2 / (2 * self.sigma ** 2))
        kernel_1d = kernel_1d / kernel_1d.sum()
        # A 2-D Gaussian is separable: the outer product of the 1-D profile with itself.
        kernel_2d = torch.outer(kernel_1d, kernel_1d)

        # groups=channels -> each channel is convolved with its own copy of the kernel.
        kernel = kernel_2d.expand(channels, 1, -1, -1)
        return F.conv2d(x, kernel, padding=half, groups=channels)

# White Balance (from your original code)
class WhiteBalance(nn.Module):
    """Learnable per-channel gain.

    Args:
        init_gains: Initial multiplier for each channel, in channel order.
    """

    def __init__(self, init_gains=(1.2, 1.0, 0.9)):
        super().__init__()
        initial = torch.tensor(init_gains, dtype=torch.float32)
        self.gains = nn.Parameter(initial)

    def forward(self, x):
        """Scale each channel of ``x`` by its gain (no clamping is applied)."""
        # Reshape to (1, C, 1, 1) so the gains broadcast over batch and space.
        per_channel = self.gains.view(1, -1, 1, 1)
        return x * per_channel

# Gamma Correction (from your original code)
class GammaCorrection(nn.Module):
    """Learnable gamma curve: ``clamp(x, 1e-8, 1) ** (1 / gamma)``.

    Args:
        init_gamma: Starting value of the learnable ``gamma`` parameter.
    """

    def __init__(self, init_gamma=2.2):
        super().__init__()
        self.gamma = nn.Parameter(torch.tensor(init_gamma))

    def forward(self, x):
        """Clamp ``x`` into ``[1e-8, 1]`` and raise it to the power ``1 / gamma``."""
        # The lower bound keeps the base strictly positive before the power.
        base = x.clamp(min=1e-8, max=1)
        exponent = 1 / self.gamma
        return base ** exponent

# Upsampling CNN (from your original code)
class UpsampleCNN(nn.Module):
    """Small CNN that doubles spatial resolution: two conv+ReLU stages, 2x bilinear
    upsampling, then a final conv back to 3 channels."""

    def __init__(self):
        super().__init__()
        # Layers are created in forward order so sub-module indices (and
        # therefore state-dict keys) stay 0..5.
        layers = [
            nn.Conv2d(3, 64, 5, padding=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(64, 3, 3, padding=1),
        ]
        self.upsample = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``(batch, 3, H, W)`` to ``(batch, 3, 2H, 2W)``."""
        upscaled = self.upsample(x)
        return upscaled

# --------------------------------
# Complete Pipeline Implementations
# --------------------------------

class CompletePipeline(nn.Module):
    """End-to-end pipeline: downsample -> white balance -> filter -> gamma -> CNN upsampling.

    Args:
        filter_type (str): One of ``'sharpening'``, ``'edge_detection'``,
            ``'median'``, ``'contrast'``, ``'bilateral'`` or
            ``'unsharp_masking'``. Any other value selects ``GaussianBlur``.
    """

    def __init__(self, filter_type='bilateral'):
        super().__init__()
        # Fixed (non-learnable) scale used to synthesise the low-res input.
        self.downsample_factor = 0.5

        self.white_balance = WhiteBalance(init_gains=(1.2, 1.0, 0.9))

        # Name -> filter class lookup; unknown names fall back to Gaussian blur.
        known_filters = (
            ('sharpening', SharpeningFilter),
            ('edge_detection', SobelEdgeDetection),
            ('median', DifferentiableMedianFilter),
            ('contrast', ContrastEnhancement),
            ('bilateral', BilateralFilter),
            ('unsharp_masking', UnsharpMasking),
        )
        filter_cls = next(
            (cls for name, cls in known_filters if filter_type == name),
            GaussianBlur,
        )
        self.filter = filter_cls()

        self.gamma = GammaCorrection(init_gamma=2.2)
        self.upsample = UpsampleCNN()

        self.filter_type = filter_type

    def forward(self, x_hr):
        """Run the pipeline on a high-resolution batch.

        Returns:
            tuple: ``(x_sr, x_lr)`` - the CNN-upsampled output and the
            processed low-resolution image that was fed to the CNN.
        """
        # Simulate a low-res capture by bicubic downsampling.
        low_res = F.interpolate(
            x_hr,
            scale_factor=self.downsample_factor,
            mode='bicubic',
            align_corners=False,
        )

        # White balance -> selected filter -> gamma, in that order.
        for stage in (self.white_balance, self.filter, self.gamma):
            low_res = stage(low_res)

        super_res = self.upsample(low_res)
        return super_res, low_res

In [10]:
# --------------------------------
# Pipeline construction check
# --------------------------------
# Build one CompletePipeline per filter variant and print the filter module
# each one selected.
filter_types = [
    'bilateral',        # Edge-preserving smoothing
    'sharpening',       # Edge enhancement
    'median',           # Noise reduction
    'contrast',         # Contrast enhancement
    'unsharp_masking',  # Another sharpening technique
]

for filter_type in filter_types:
    print(f"\n===== Processing with {filter_type} filter =====")

    pipeline = CompletePipeline(filter_type=filter_type)
    print(pipeline.filter)
 


===== Processing with bilateral filter =====
BilateralFilter()

===== Processing with sharpening filter =====
SharpeningFilter()

===== Processing with median filter =====
DifferentiableMedianFilter()

===== Processing with contrast filter =====
ContrastEnhancement()

===== Processing with unsharp_masking filter =====
UnsharpMasking(
  (blur): GaussianBlur()
)
