In [1]:
import random
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import lpips
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

# Set device (GPU/CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
# Define the BlurredImageDataset class
class BlurredImageDataset(Dataset):
    def __init__(self, input_dir, target_dir, transform=None, noise_std=0.1):
        self.input_dir = input_dir
        self.target_dir = target_dir
        self.transform = transform
        self.noise_std = noise_std
        self.input_files = sorted(os.listdir(input_dir))
        self.target_files = sorted(os.listdir(target_dir))
        
        assert len(self.input_files) == len(self.target_files), "Mismatch between input and target images!"

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        input_image_path = os.path.join(self.input_dir, self.input_files[idx])
        target_image_path = os.path.join(self.target_dir, self.target_files[idx])

        input_image = Image.open(input_image_path).convert('RGB')
        target_image = Image.open(target_image_path).convert('RGB')

        # Add noise to the input image (simulating real-world noisy input)
        noisy_input_image = add_gaussian_noise(input_image, mean=0, std=self.noise_std)

        if self.transform:
            noisy_input_image = self.transform(noisy_input_image)
            target_image = self.transform(target_image)

        return noisy_input_image, target_image


In [3]:
# Function to add Gaussian noise to an image
def add_gaussian_noise(image, mean=0, std=0.1):
    np_image = np.array(image)
    noise = np.random.normal(mean, std, np_image.shape)
    noisy_image = np_image + noise * 255  # Scale the noise to match image pixel values
    noisy_image = np.clip(noisy_image, 0, 255)  # Ensure pixel values are within the valid range
    return Image.fromarray(noisy_image.astype(np.uint8))


In [7]:
# Define Data Transformation
transform = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Paths to GS-Blur dataset (use your dataset path)
input_dir_gsblur = '/Users/zhanglin/Documents/dku/2024-2025/session3/STATS 201/reflection/week2/mini/input_noise'  # Folder containing blurry images
target_dir_gsblur = '/Users/zhanglin/Documents/dku/2024-2025/session3/STATS 201/reflection/week2/mini/target'      # Folder containing clean images

# Load the GS-Blur dataset
gsblur_dataset = BlurredImageDataset(input_dir_gsblur, target_dir_gsblur, transform=transform, noise_std=0.2)

# Split the GS-Blur dataset into train and validation
train_dataset_gsblur, val_dataset_gsblur = train_test_split(gsblur_dataset, test_size=0.2, random_state=42)
train_loader_gsblur = DataLoader(train_dataset_gsblur, batch_size=10, shuffle=True)
val_loader_gsblur = DataLoader(val_dataset_gsblur, batch_size=10, shuffle=False)


In [8]:
# Define WEDDM Model (based on ANoiseRobust proposed algorithm)
class WEDDM(nn.Module):
    def __init__(self):
        super(WEDDM, self).__init__()
        self.denoise_module = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=3, padding=1)
        )
        self.diffusion_module = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=3, padding=1)
        )

    def forward(self, x):
        denoised = self.denoise_module(x)
        output = self.diffusion_module(denoised)
        return output


In [9]:
# Initialize Model, Loss, and Optimizer
model = WEDDM().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training Function
def train_model(model, train_loader, optimizer, criterion, num_epochs=10):
    for epoch in range(num_epochs):
        model.train()
        for i, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            if (i+1) % 10 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

# Train WEDDM on GS-Blur dataset
print("Training WEDDM on GS-Blur dataset...")
train_model(model, train_loader_gsblur, optimizer, criterion)


Training WEDDM on GS-Blur dataset...
Epoch [1/10], Step [10/80], Loss: 0.1481
Epoch [1/10], Step [20/80], Loss: 0.1259
Epoch [1/10], Step [30/80], Loss: 0.0830
Epoch [1/10], Step [40/80], Loss: 0.0550
Epoch [1/10], Step [50/80], Loss: 0.0513
Epoch [1/10], Step [60/80], Loss: 0.0297
Epoch [1/10], Step [70/80], Loss: 0.0459
Epoch [1/10], Step [80/80], Loss: 0.0272
Epoch [2/10], Step [10/80], Loss: 0.0359
Epoch [2/10], Step [20/80], Loss: 0.0269
Epoch [2/10], Step [30/80], Loss: 0.0289
Epoch [2/10], Step [40/80], Loss: 0.0171
Epoch [2/10], Step [50/80], Loss: 0.0237
Epoch [2/10], Step [60/80], Loss: 0.0206
Epoch [2/10], Step [70/80], Loss: 0.0166
Epoch [2/10], Step [80/80], Loss: 0.0196
Epoch [3/10], Step [10/80], Loss: 0.0155
Epoch [3/10], Step [20/80], Loss: 0.0198
Epoch [3/10], Step [30/80], Loss: 0.0172
Epoch [3/10], Step [40/80], Loss: 0.0139
Epoch [3/10], Step [50/80], Loss: 0.0184
Epoch [3/10], Step [60/80], Loss: 0.0192
Epoch [3/10], Step [70/80], Loss: 0.0149
Epoch [3/10], Step [

In [10]:
# Evaluate Function (using PSNR, SSIM, LPIPS)
def evaluate_model(model, val_loader):
    model.eval()
    psnr_values = []
    ssim_values = []
    lpips_values = []
    lpips_model = lpips.LPIPS(net='vgg').to(device)  # Make sure LPIPS is initialized here

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)

            # Normalize inputs and targets to [-1, 1] for LPIPS
            inputs_lpips = (inputs - 0.5) / 0.5
            targets_lpips = (targets - 0.5) / 0.5
            outputs_lpips = (outputs - 0.5) / 0.5

            # PSNR
            for i in range(inputs.shape[0]):
                psnr_value = psnr(targets[i].cpu().numpy(), outputs[i].cpu().numpy(), data_range=1.0)
                psnr_values.append(psnr_value)

            # SSIM
            for i in range(inputs.shape[0]):
                if min(targets[i].shape[-2:]) >= 7:
                    ssim_value = ssim(
                        targets[i].cpu().numpy(), 
                        outputs[i].cpu().numpy(), 
                        win_size=3, 
                        channel_axis=-1, 
                        data_range=1.0
                    )
                else:
                    ssim_value = 0  # or handle it differently
                ssim_values.append(ssim_value)

            # LPIPS
            lpips_value = lpips_model(outputs_lpips, targets_lpips)  # Compute LPIPS
            lpips_values.extend(lpips_value.squeeze().cpu().numpy())  # Flatten and append

    mean_psnr = np.mean(psnr_values)
    mean_ssim = np.mean(ssim_values)
    mean_lpips = np.mean(lpips_values)

    return mean_psnr, mean_ssim, mean_lpips

# Evaluate WEDDM on GS-Blur dataset
print("Evaluating WEDDM on GS-Blur dataset...")
mean_psnr_gsblur, mean_ssim_gsblur, mean_lpips_gsblur = evaluate_model(model, val_loader_gsblur)
print(f'GS-Blur Dataset -> Mean PSNR: {mean_psnr_gsblur:.4f}, Mean SSIM: {mean_ssim_gsblur:.4f}, Mean LPIPS: {mean_lpips_gsblur:.4f}')


Evaluating WEDDM on GS-Blur dataset...
Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: /opt/anaconda3/lib/python3.12/site-packages/lpips/weights/v0.1/vgg.pth


  self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)


GS-Blur Dataset -> Mean PSNR: 20.3270, Mean SSIM: 0.7277, Mean LPIPS: 0.2872
