In [31]:
# Import necessary libraries
import random
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import lpips
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

import sys
sys.path.append("/Users/zhanglin/MPRNet/Deblurring") # Path to where your MPRNet.py file is stored
from Deblurring.MPRNet import MPRNet # Make sure you have cloned and set up the MPRNet repository

# Set device (GPU/CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [33]:
# Data Preprocessing and Dataset Definition
class BlurredImageDataset(Dataset):
    def __init__(self, input_dir, target_dir, transform=None):
        self.input_dir = input_dir
        self.target_dir = target_dir
        self.transform = transform
        self.input_files = sorted(os.listdir(input_dir))
        self.target_files = sorted(os.listdir(target_dir))
        
        assert len(self.input_files) == len(self.target_files), "Mismatch between input and target images!"

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        input_image_path = os.path.join(self.input_dir, self.input_files[idx])
        target_image_path = os.path.join(self.target_dir, self.target_files[idx])
        
        input_image = Image.open(input_image_path).convert('RGB')
        target_image = Image.open(target_image_path).convert('RGB')
        
        if self.transform:
            input_image = self.transform(input_image)
            target_image = self.transform(target_image)
        
        return input_image, target_image


In [42]:
# Data augmentation techniques
transform = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.RandomRotation(15),  # Random rotation for data augmentation
    transforms.RandomHorizontalFlip(),  # Random horizontal flip
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


In [44]:
# Paths to GS-Blur dataset
input_dir_gsblur = '/Users/zhanglin/Documents/dku/2024-2025/session3/STATS 201/reflection/week2/mini/input_noise'  # Folder containing blurry images
target_dir_gsblur = '/Users/zhanglin/Documents/dku/2024-2025/session3/STATS 201/reflection/week2/mini/target'      # Folder containing clean images

# Load GS-Blur dataset
gsblur_dataset = BlurredImageDataset(input_dir_gsblur, target_dir_gsblur, transform=transform)

# Split the dataset into train and validation
train_dataset_gsblur, val_dataset_gsblur = train_test_split(gsblur_dataset, test_size=0.2, random_state=42)
train_loader_gsblur = DataLoader(train_dataset_gsblur, batch_size=16, shuffle=True, num_workers=2)
val_loader_gsblur = DataLoader(val_dataset_gsblur, batch_size=10, shuffle=False)


In [45]:

# Initialize MPRNet model (without passing num_channels or num_features)
model = MPRNet().to(device)  

# Initialize Loss Function (e.g., MSELoss for deblurring)
criterion = nn.MSELoss()

# Initialize Optimizer (Adam optimizer)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Learning Rate Scheduler (optional)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)



In [46]:

# Training Function
def train_model(model, train_loader, optimizer, criterion, scheduler, num_epochs=10):
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0
        for i, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)  # MPRNet returns a list of tensors
            
            # Extract the final output (last tensor in the list)
            final_output = outputs[-1]  # Assuming the last tensor is the final output
            
            # Ensure final_output and targets have the same shape
            if final_output.shape != targets.shape:
                raise ValueError(f"Shape mismatch: final_output {final_output.shape}, targets {targets.shape}")
            
            # Compute loss using the final output
            loss = criterion(final_output, targets)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

            if (i + 1) % 10 == 0:
                print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}')
        
        # Step learning rate scheduler
        scheduler.step(epoch_loss / len(train_loader))
        print(f'Epoch [{epoch + 1}/{num_epochs}], Average Loss: {epoch_loss / len(train_loader):.4f}')

# Train MPRNet on GS-Blur dataset
print("Training MPRNet on GS-Blur dataset...")
train_model(model, train_loader_gsblur, optimizer, criterion, scheduler, num_epochs=10)


Training MPRNet on GS-Blur dataset...
Epoch [1/10], Step [10/50], Loss: 0.1503
Epoch [1/10], Step [20/50], Loss: 0.1339
Epoch [1/10], Step [30/50], Loss: 0.1789
Epoch [1/10], Step [40/50], Loss: 0.1295
Epoch [1/10], Step [50/50], Loss: 0.1349
Epoch [1/10], Average Loss: 0.1540


python(2178) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(2179) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Epoch [2/10], Step [10/50], Loss: 0.1553
Epoch [2/10], Step [20/50], Loss: 0.1539
Epoch [2/10], Step [30/50], Loss: 0.1455
Epoch [2/10], Step [40/50], Loss: 0.1139
Epoch [2/10], Step [50/50], Loss: 0.1403
Epoch [2/10], Average Loss: 0.1476


python(3200) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(3202) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Epoch [3/10], Step [10/50], Loss: 0.1342
Epoch [3/10], Step [20/50], Loss: 0.1435
Epoch [3/10], Step [30/50], Loss: 0.1293
Epoch [3/10], Step [40/50], Loss: 0.1631
Epoch [3/10], Step [50/50], Loss: 0.1904
Epoch [3/10], Average Loss: 0.1454


python(4292) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(4293) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Epoch [4/10], Step [10/50], Loss: 0.1874
Epoch [4/10], Step [20/50], Loss: 0.1170
Epoch [4/10], Step [30/50], Loss: 0.1104
Epoch [4/10], Step [40/50], Loss: 0.1548
Epoch [4/10], Step [50/50], Loss: 0.1775
Epoch [4/10], Average Loss: 0.1433


python(5102) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(5103) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Epoch [5/10], Step [10/50], Loss: 0.1564
Epoch [5/10], Step [20/50], Loss: 0.1304
Epoch [5/10], Step [30/50], Loss: 0.1640
Epoch [5/10], Step [40/50], Loss: 0.1438
Epoch [5/10], Step [50/50], Loss: 0.1465
Epoch [5/10], Average Loss: 0.1430


python(6461) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(6462) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Epoch [6/10], Step [10/50], Loss: 0.1323
Epoch [6/10], Step [20/50], Loss: 0.1753
Epoch [6/10], Step [30/50], Loss: 0.1378
Epoch [6/10], Step [40/50], Loss: 0.1277
Epoch [6/10], Step [50/50], Loss: 0.1361
Epoch [6/10], Average Loss: 0.1410


python(7665) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(7666) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Epoch [7/10], Step [10/50], Loss: 0.1315
Epoch [7/10], Step [20/50], Loss: 0.1286
Epoch [7/10], Step [30/50], Loss: 0.1312
Epoch [7/10], Step [40/50], Loss: 0.1582
Epoch [7/10], Step [50/50], Loss: 0.1156
Epoch [7/10], Average Loss: 0.1389


python(8616) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8617) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Epoch [8/10], Step [10/50], Loss: 0.1417
Epoch [8/10], Step [20/50], Loss: 0.1432
Epoch [8/10], Step [30/50], Loss: 0.1286
Epoch [8/10], Step [40/50], Loss: 0.1274
Epoch [8/10], Step [50/50], Loss: 0.1266
Epoch [8/10], Average Loss: 0.1385


python(9508) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9511) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Epoch [9/10], Step [10/50], Loss: 0.1293
Epoch [9/10], Step [20/50], Loss: 0.1462
Epoch [9/10], Step [30/50], Loss: 0.1312
Epoch [9/10], Step [40/50], Loss: 0.1521
Epoch [9/10], Step [50/50], Loss: 0.1641
Epoch [9/10], Average Loss: 0.1357


python(10392) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10393) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Epoch [10/10], Step [10/50], Loss: 0.1533
Epoch [10/10], Step [20/50], Loss: 0.1303
Epoch [10/10], Step [30/50], Loss: 0.1265
Epoch [10/10], Step [40/50], Loss: 0.1105
Epoch [10/10], Step [50/50], Loss: 0.1511
Epoch [10/10], Average Loss: 0.1373


In [63]:
def evaluate_model(model, val_loader):
    model.eval()
    psnr_values = []
    ssim_values = []
    lpips_values = []
    lpips_model = lpips.LPIPS(net='vgg').to(device)  # Make sure LPIPS is initialized here

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)  # MPRNet returns a list of tensors

            # Extract the final output (last tensor in the list)
            final_output = outputs[-1]  # Assuming the last tensor is the final output

            # Normalize inputs and targets to [-1, 1] for LPIPS
            inputs_lpips = (inputs - 0.5) / 0.5
            targets_lpips = (targets - 0.5) / 0.5
            outputs_lpips = (final_output - 0.5) / 0.5  # Normalize the final output

            # PSNR
            for i in range(inputs.shape[0]):
                psnr_value = psnr(targets[i].cpu().numpy(), final_output[i].cpu().numpy(), data_range=1.0)
                psnr_values.append(psnr_value)

            # SSIM
            for i in range(inputs.shape[0]):
                if min(targets[i].shape[-2:]) >= 7:
                    ssim_value = ssim(
                        targets[i].cpu().numpy(), 
                        final_output[i].cpu().numpy(), 
                        win_size=3, 
                        channel_axis=-1, 
                        data_range=1.0
                    )
                else:
                    ssim_value = 0  # or handle it differently
                ssim_values.append(ssim_value)

            # LPIPS
            lpips_value = lpips_model(outputs_lpips, targets_lpips)  # Compute LPIPS
            lpips_values.extend(lpips_value.squeeze().cpu().numpy())  # Flatten and append

    mean_psnr = np.mean(psnr_values)
    mean_ssim = np.mean(ssim_values)
    mean_lpips = np.mean(lpips_values)

    return mean_psnr, mean_ssim, mean_lpips

# Evaluate MPRNet on GS-Blur dataset
print("Evaluating MPRNet on GS-Blur dataset...")
mean_psnr_gsblur, mean_ssim_gsblur, mean_lpips_gsblur = evaluate_model(model, val_loader_gsblur)
print(f'GS-Blur Dataset -> Mean PSNR: {mean_psnr_gsblur:.4f}, Mean SSIM: {mean_ssim_gsblur:.4f}, Mean LPIPS: {mean_lpips_gsblur:.4f}')

Evaluating MPRNet on GS-Blur dataset...
Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]
Loading model from: /opt/anaconda3/lib/python3.12/site-packages/lpips/weights/v0.1/vgg.pth
GS-Blur Dataset -> Mean PSNR: 8.8505, Mean SSIM: 0.2413, Mean LPIPS: 0.6165
