In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.transforms.functional as F
from scipy.ndimage import gaussian_filter
import numpy as np
import cv2

In [None]:
# Define the Unrolled Network for 5 Layers
class UnrolledDeblurringNetwork(nn.Module):
    """Unrolled deblurring network: one shared conv + ReLU stage applied repeatedly.

    Every unrolled iteration reuses the *same* ``shared_layer`` weights
    (a single-channel 3x3 convolution with padding 1, so spatial size is kept),
    hence the parameter count does not grow with ``num_layers``.

    Args:
        num_layers: number of times the shared conv + ReLU stage is applied.
    """

    def __init__(self, num_layers=5):
        super().__init__()
        self.num_layers = num_layers
        self.shared_layer = nn.Conv2d(1, 1, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply ``relu(shared_layer(.))`` ``num_layers`` times to ``x``."""
        out = x
        for _ in range(self.num_layers):
            out = self.shared_layer(out)
            out = self.relu(out)
        return out

In [None]:
# Gaussian-blur an image (scipy.ndimage) and convert the result to a tensor
def gaussian_blur(image, sigma):
    """Blur ``image`` with a Gaussian of standard deviation ``sigma`` and return a tensor.

    ``scipy.ndimage.gaussian_filter`` with a scalar ``sigma`` smooths along
    *every* axis of the array, including a channel axis if one is present.
    The blurred array is then passed to ``F.to_tensor``, which interprets an
    ndarray as (H, W, C) and returns (C, H, W).

    NOTE(review): a (C, H, W) tensor input is therefore read as (H, W, C), and
    the returned tensor is laid out (W, C, H). Callers must permute it back
    (the training loop does this with ``permute(0, 2, 3, 1)`` on the batch).

    Args:
        image: anything ``np.array`` accepts (e.g. a CPU tensor, ndarray, PIL image).
        sigma: Gaussian standard deviation, applied to each array axis.

    Returns:
        The tensor produced by ``F.to_tensor`` on the blurred array.
    """
    return F.to_tensor(gaussian_filter(np.array(image), sigma=sigma))

# Filter an image tensor with an explicit 2-D kernel via OpenCV
def blur_image(img, kernel):
    """Correlate ``img`` with ``kernel`` using ``cv2.filter2D`` and return a tensor.

    ``cv2.filter2D`` treats a 3-D array as (H, W, C), whereas torch image
    tensors (e.g. from ``transforms.ToTensor``) are channel-first (C, H, W).
    Passing such a tensor straight through would filter over the wrong axes.
    Here, 3-D input is therefore treated as (C, H, W) and each channel is
    filtered independently. 2-D (H, W) input is filtered directly, as before.

    Args:
        img: image tensor, (H, W) or (C, H, W). May live on any device and
            may require grad; it is detached and copied to the CPU.
        kernel: 2-D filter, as a torch tensor or any array-like.

    Returns:
        CPU tensor with the same shape as ``img``. ``ddepth=-1`` keeps the
        input's dtype.
    """
    # .numpy() only works on detached CPU tensors.
    img_np = img.detach().cpu().numpy()
    if isinstance(kernel, torch.Tensor):
        kernel_np = kernel.detach().cpu().numpy()
    else:
        kernel_np = np.asarray(kernel)

    if img_np.ndim == 3:
        # Channel-first: filter each (H, W) plane on its own, then restack.
        blurred = np.stack(
            [cv2.filter2D(np.ascontiguousarray(plane), -1, kernel_np) for plane in img_np]
        )
    else:
        blurred = cv2.filter2D(img_np, -1, kernel_np)
    return torch.from_numpy(blurred)

In [None]:
# PSNR Calculation
import math
def calculate_psnr(target, prediction, max_val=1.0):
    """Peak signal-to-noise ratio, in dB, of ``prediction`` against ``target``.

    PSNR = 10 * log10(max_val**2 / MSE), where MSE is averaged over all
    elements.

    Args:
        target: reference tensor.
        prediction: tensor of the same shape as ``target``.
        max_val: peak possible signal value. The default 1.0 suits images
            scaled to [0, 1] and matches the previous hard-coded value.

    Returns:
        PSNR as a Python float, or ``float('inf')`` when the tensors are
        identical (zero MSE).
    """
    # Metrics never need gradients; detaching avoids extending the autograd
    # graph when called on training outputs.
    mse = nn.functional.mse_loss(prediction.detach(), target.detach(), reduction='mean').item()
    if mse == 0:
        return float('inf')
    return 10 * math.log10(max_val ** 2 / mse)

In [None]:
# Load MNIST; ToTensor() yields float image tensors scaled to [0, 1].
transform = transforms.Compose([transforms.ToTensor()])
train_dataset, test_dataset = (
    torchvision.datasets.MNIST(root='./data', train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)

# Shuffle only the training split so evaluation order stays fixed.
train_loader, test_loader = (
    DataLoader(split, batch_size=64, shuffle=reshuffle)
    for split, reshuffle in ((train_dataset, True), (test_dataset, False))
)

In [None]:
# Instantiate model, loss, and optimizer
SEED = 0
# Seed torch's global RNG so weight initialisation and the training
# DataLoader's shuffle order are reproducible on Restart & Run All.
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UnrolledDeblurringNetwork(num_layers=5).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training configuration.
# Training inputs are blurred with sigma=h1_sigma, test inputs with sigma=h2_sigma.
# NOTE(review): the two blurs differ; presumably this deliberately measures
# generalisation to an unseen blur strength -- confirm.
h1_sigma = 1
h2_sigma = 2
epochs = 10
# A reconstructed pixel counts as "correct" when it lies within this absolute
# error of the clean pixel (drives the "Accuracy" figures below).
PIXEL_TOLERANCE = 0.05
# Set to True to print the PSNR of every training mini-batch. It is off by
# default because it emits one line per batch (hundreds per epoch).
PRINT_BATCH_PSNR = False


def _blur_batch(images, sigma):
    """Gaussian-blur every image of a (B, C, H, W) batch; returns (B, C, H, W).

    ``gaussian_blur`` hands a (C, H, W) array to ``F.to_tensor``, which reads
    it as (H, W, C) and returns a (W, C, H) tensor. ``permute(0, 2, 3, 1)`` on
    the stacked (B, W, C, H) result restores (B, C, H, W).
    """
    blurred = torch.stack([gaussian_blur(img, sigma=sigma) for img in images.cpu()])
    return blurred.permute(0, 2, 3, 1)


def _run_epoch(model, loader, sigma, criterion, device, optimizer=None):
    """Run one pass of ``model`` over ``loader`` with inputs blurred at ``sigma``.

    If ``optimizer`` is given, the model is trained: train mode, gradients
    enabled, one optimizer step per batch. Otherwise it is only evaluated, in
    eval mode with gradients disabled.

    Returns:
        (mean batch loss, pixel accuracy in percent, mean batch PSNR in dB)
    """
    training = optimizer is not None
    model.train(training)
    loss_sum = 0.0
    psnr_sum = 0.0
    correct = 0
    total = 0
    with torch.set_grad_enabled(training):
        for images, _ in loader:
            blurred_images = _blur_batch(images, sigma).to(device)
            images = images.to(device)

            outputs = model(blurred_images)
            loss = criterion(outputs, images)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()

            # Metrics need no gradients; detach so they don't extend the graph.
            outputs = outputs.detach()
            psnr = calculate_psnr(images, outputs)
            psnr_sum += psnr
            if training and PRINT_BATCH_PSNR:
                print(f"Mini-batch PSNR: {psnr:.4f}")

            correct += torch.sum(torch.abs(outputs - images) < PIXEL_TOLERANCE).item()
            total += images.numel()

    num_batches = len(loader)
    return loss_sum / num_batches, 100 * correct / total, psnr_sum / num_batches


for epoch in range(epochs):
    avg_loss_train, train_accuracy, avg_psnr_train = _run_epoch(
        model, train_loader, h1_sigma, criterion, device, optimizer=optimizer)
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss_train:.4f}, Train Accuracy: {train_accuracy:.4f}, Avg Train PSNR: {avg_psnr_train:.4f}")

    avg_loss_test, test_accuracy, avg_psnr_test = _run_epoch(
        model, test_loader, h2_sigma, criterion, device)
    print(f"Epoch [{epoch+1}/{epochs}], Test Loss: {avg_loss_test:.4f}, Test Accuracy: {test_accuracy:.4f}, Avg Test PSNR: {avg_psnr_test:.4f}")