In [1]:
# Import necessary libraries
import random
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import lpips
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from rdrobust import rdrobust, rdplot

# Set device (GPU/CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
# Data Preprocessing and Dataset Definition
class BlurredImageDataset(Dataset):
    def __init__(self, input_dir, target_dir, transform=None):
        # Initialize the paths for noisy and clean images
        self.input_dir = input_dir
        self.target_dir = target_dir
        self.transform = transform
        
        # Get the list of filenames in both directories, filter out directories
        self.input_files = [f for f in sorted(os.listdir(input_dir)) if os.path.isfile(os.path.join(input_dir, f))]
        self.target_files = [f for f in sorted(os.listdir(target_dir)) if os.path.isfile(os.path.join(target_dir, f))]
        
        # Ensure both directories have the same number of images
        assert len(self.input_files) == len(self.target_files), "Mismatch between input and target images!"

    def __len__(self):
        return len(self.input_files)  # Return the number of images in the dataset

    def __getitem__(self, idx):
        # Construct the file paths for input and target images
        input_image_path = os.path.join(self.input_dir, self.input_files[idx])
        target_image_path = os.path.join(self.target_dir, self.target_files[idx])
        
        # Open the images and convert them to RGB mode
        input_image = Image.open(input_image_path).convert('RGB')
        target_image = Image.open(target_image_path).convert('RGB')
        
        # Apply any transformations (such as resizing, normalizing, etc.)
        if self.transform:
            input_image = self.transform(input_image)
            target_image = self.transform(target_image)
        
        # Return a pair of input (noisy) and target (clean) images
        return input_image, target_image


In [3]:

# Function to add Gaussian noise to an image
def add_gaussian_noise(image, mean=0, std=0.5):
    """Add Gaussian noise to an image"""
    np_image = np.array(image)
    noise = np.random.normal(mean, std, np_image.shape)
    noisy_image = np_image + noise * 255  # Adjust noise level
    noisy_image = np.clip(noisy_image, 0, 255)  # Clip to valid range [0, 255]
    return Image.fromarray(noisy_image.astype(np.uint8))

# Function to apply motion blur to an image
def apply_motion_blur(image, size=5, angle=45):
    """Apply motion blur to an image"""
    image = image.convert('RGB')
    # Generate a motion blur kernel
    kernel = ImageFilter.GaussianBlur(radius=size)
    blurred_image = image.filter(kernel)
    return blurred_image

# Function to generate simulated dataset (noisy and blurred images)
def generate_simulated_dataset(num_images=1000, image_size=(96, 96), output_dir='simulated_dataset'):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    # Example: Use random images from torchvision datasets or generate random images
    # For simplicity, using random images with solid color and basic shapes
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])

    # Create dataset
    for i in range(num_images):
        # Create a random image (e.g., solid color, gradient, or random noise)
        img = np.random.rand(*image_size, 3) * 255  # Random noise as a basic image
        img = Image.fromarray(img.astype(np.uint8))

        # Apply motion blur
        blurred_img = apply_motion_blur(img, size=random.randint(3, 10), angle=random.randint(0, 360))
        
        # Add Gaussian noise
        noisy_img = add_gaussian_noise(blurred_img, mean=0, std=random.uniform(0.05, 0.2))
        
        # Save the generated noisy image and target (clean) image
        noisy_image_path = os.path.join(output_dir, f"noisy_{i:04d}.png")
        clean_image_path = os.path.join(output_dir, f"clean_{i:04d}.png")
        
        noisy_img.save(noisy_image_path)
        img.save(clean_image_path)

    print(f"Simulated dataset generated with {num_images} images in {output_dir}")

# Generate simulated dataset
generate_simulated_dataset(num_images=1000, image_size=(128, 128), output_dir='simulated_dataset')

# Paths to simulated dataset
input_dir_simulated = 'simulated_dataset'  # Simulated dataset input (noisy images)
target_dir_simulated = 'simulated_dataset'  # Simulated dataset target (clean images)

# Define Data Transformation
transform = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Load the simulated dataset
simulated_dataset = BlurredImageDataset(input_dir_simulated, target_dir_simulated, transform=transform)

# Split the simulated dataset into train and validation
train_dataset_simulated, val_dataset_simulated = train_test_split(simulated_dataset, test_size=0.2, random_state=42)
train_loader_simulated = DataLoader(train_dataset_simulated, batch_size=10, shuffle=True)
val_loader_simulated = DataLoader(val_dataset_simulated, batch_size=10, shuffle=False)


Simulated dataset generated with 1000 images in simulated_dataset


In [4]:
# Define WEDDM Model (based on ANoiseRobust proposed algorithm)
class WEDDM(nn.Module):
    def __init__(self):
        super(WEDDM, self).__init__()
        # Define the denoising and diffusion modules
        self.denoise_module = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=3, padding=1)
        )
        self.diffusion_module = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=3, padding=1)
        )

    def forward(self, x):
        denoised = self.denoise_module(x)
        output = self.diffusion_module(denoised)
        return output
        

In [5]:
# Initialize Model, Loss, and Optimizer
model = WEDDM().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training Function
def train_model(model, train_loader, optimizer, criterion, num_epochs=10):
    for epoch in range(num_epochs):
        model.train()
        for i, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            if (i+1) % 40 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

# Train WEDDM on Simulated dataset (since GS-Blur dataset loading is not defined for the error you're facing)
print("Training WEDDM on Simulated dataset...")
train_model(model, train_loader_simulated, optimizer, criterion)


Training WEDDM on Simulated dataset...
Epoch [1/10], Step [40/160], Loss: 0.0448
Epoch [1/10], Step [80/160], Loss: 0.0401
Epoch [1/10], Step [120/160], Loss: 0.0208
Epoch [1/10], Step [160/160], Loss: 0.0072
Epoch [2/10], Step [40/160], Loss: 0.0034
Epoch [2/10], Step [80/160], Loss: 0.0044
Epoch [2/10], Step [120/160], Loss: 0.0030
Epoch [2/10], Step [160/160], Loss: 0.0021
Epoch [3/10], Step [40/160], Loss: 0.0035
Epoch [3/10], Step [80/160], Loss: 0.0022
Epoch [3/10], Step [120/160], Loss: 0.0012
Epoch [3/10], Step [160/160], Loss: 0.0022
Epoch [4/10], Step [40/160], Loss: 0.0021
Epoch [4/10], Step [80/160], Loss: 0.0010
Epoch [4/10], Step [120/160], Loss: 0.0019
Epoch [4/10], Step [160/160], Loss: 0.0022
Epoch [5/10], Step [40/160], Loss: 0.0012
Epoch [5/10], Step [80/160], Loss: 0.0010
Epoch [5/10], Step [120/160], Loss: 0.0017
Epoch [5/10], Step [160/160], Loss: 0.0018
Epoch [6/10], Step [40/160], Loss: 0.0021
Epoch [6/10], Step [80/160], Loss: 0.0010
Epoch [6/10], Step [120/160

In [6]:
def evaluate_model(model, val_loader):
    model.eval()
    psnr_values = []
    ssim_values = []
    lpips_values = []
    lpips_model = lpips.LPIPS(net='vgg').to(device)  # Make sure LPIPS is initialized here

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)

            # Normalize inputs and targets to [-1, 1] for LPIPS
            inputs_lpips = (inputs - 0.5) / 0.5
            targets_lpips = (targets - 0.5) / 0.5
            outputs_lpips = (outputs - 0.5) / 0.5

            # PSNR
            for i in range(inputs.shape[0]):
                psnr_value = psnr(targets[i].cpu().numpy(), outputs[i].cpu().numpy(), data_range=1.0)
                psnr_values.append(psnr_value)

            # SSIM
            for i in range(inputs.shape[0]):
                if min(targets[i].shape[-2:]) >= 7:
                    ssim_value = ssim(
                        targets[i].cpu().numpy(), 
                        outputs[i].cpu().numpy(), 
                        win_size=3, 
                        channel_axis=-1, 
                        data_range=1.0
                    )
                else:
                    ssim_value = 0  # or handle it differently
                ssim_values.append(ssim_value)

            # LPIPS
            lpips_value = lpips_model(outputs_lpips, targets_lpips)  # Compute LPIPS
            lpips_values.extend(lpips_value.squeeze().cpu().numpy())  # Flatten and append

    mean_psnr = np.mean(psnr_values)
    mean_ssim = np.mean(ssim_values)
    mean_lpips = np.mean(lpips_values)

    return mean_psnr, mean_ssim, mean_lpips

# Evaluate WEDDM on Simulated dataset
print("Evaluating WEDDM on Simulated dataset...")
mean_psnr_simulated, mean_ssim_simulated, mean_lpips_simulated = evaluate_model(model, val_loader_simulated)
print(f'Simulated Dataset -> Mean PSNR: {mean_psnr_simulated:.4f}, Mean SSIM: {mean_ssim_simulated:.4f}, Mean LPIPS: {mean_lpips_simulated:.4f}')


Evaluating WEDDM on Simulated dataset...
Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: /opt/anaconda3/lib/python3.12/site-packages/lpips/weights/v0.1/vgg.pth




Simulated Dataset -> Mean PSNR: 31.0597, Mean SSIM: 0.9563, Mean LPIPS: 0.0105
