In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
from scipy.ndimage import gaussian_filter
import math

In [18]:
# Define the Half-Quadratic Splitting Module
class HalfQuadraticSplitting(nn.Module):
    """One half-quadratic-splitting (HQS) stage built from ``C`` learnable 3x3 filters.

    For every filter the stage computes a Fourier-domain estimate ``g_i`` from
    the input and the filtered input, soft-thresholds it into ``z_i``, and then
    combines all ``z_i`` into a non-negative, sum-normalised estimate ``k``.

    Args:
        C: Number of single-channel 3x3 convolution filters (and of
            ``zeta`` / ``lambd`` entries).
        eps: Small constant added to the Fourier-domain denominators and used
            as a floor for the normalisation sum, to avoid division by zero.
    """

    def __init__(self, C, eps=1e-5):
        super(HalfQuadraticSplitting, self).__init__()
        self.filters = nn.ModuleList([nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False) for _ in range(C)])
        self.zeta = nn.Parameter(torch.ones(C))
        self.lambd = nn.Parameter(torch.ones(C))
        self.eps = eps

    def forward(self, y, k):
        """Run one HQS stage on ``y``.

        Args:
            y: Input tensor. The filters are ``Conv2d(1, 1, ...)``, so ``y``
                must have exactly one channel.
            k: Accepted for interface compatibility only. It is never read;
                the returned ``k`` is computed from ``y`` alone.

        Returns:
            ``(g_list, z_list, k)``: the per-filter estimates, their
            soft-thresholded versions, and the normalised estimate ``k``
            (same shape as ``y``).
        """
        # The spectrum of y does not depend on the filter, so compute it once.
        y_dft = torch.fft.fft2(y)

        g_list, z_list = [], []
        for i, f in enumerate(self.filters):
            z_i_dft = torch.fft.fft2(f(y))
            numerator = self.zeta[i] * torch.conj(z_i_dft) * y_dft
            denominator = (self.zeta[i] * torch.abs(z_i_dft) ** 2) + self.eps
            g_i = torch.fft.ifft2(numerator / denominator).real
            g_list.append(g_i)

            # Soft-thresholding written with tensor ops so the threshold stays
            # in the autograd graph. F.softshrink only accepts a Python float,
            # which would cut lambd/zeta off from this step. The threshold is
            # clamped at 0 because shrinkage is undefined for negative values
            # (F.softshrink raises in that case).
            threshold = (self.lambd[i] * self.zeta[i]).clamp_min(0)
            z_list.append(torch.sign(g_i) * F.relu(torch.abs(g_i) - threshold))

        # Combine all thresholded estimates into k; each spectrum is computed once.
        z_dfts = [torch.fft.fft2(z) for z in z_list]
        numerator = torch.sum(torch.stack([torch.conj(z_dft) * y_dft for z_dft in z_dfts]), dim=0)
        denominator = torch.sum(torch.stack([torch.abs(z_dft) ** 2 for z_dft in z_dfts]), dim=0) + self.eps
        k = F.relu(torch.fft.ifft2(numerator / denominator).real)

        # Normalise over the two spatial axes of each sample, so one sample's
        # result does not depend on the rest of the batch. The floor keeps the
        # division finite when ReLU has zeroed every entry.
        k = k / k.sum(dim=(-2, -1), keepdim=True).clamp_min(self.eps)

        return g_list, z_list, k

# Define the Deblurring Network (Unrolled L-Layer Network)
class DeblurringNetwork(nn.Module):
    """Stack of ``L`` HalfQuadraticSplitting stages applied in sequence.

    Args:
        C: Number of filters in each stage.
        L: Number of stages.
    """

    def __init__(self, C, L):
        super(DeblurringNetwork, self).__init__()
        self.C = C
        self.L = L
        self.hqs_modules = nn.ModuleList([HalfQuadraticSplitting(C) for _ in range(L)])
        # Not read anywhere in forward(); kept so existing checkpoints and
        # optimizer parameter groups keep the same set of parameters.
        self.kernels = nn.ParameterList([nn.Parameter(torch.ones(1, 1, 3, 3)) for _ in range(L)])

    def forward(self, y):
        """Feed ``y`` through every stage and return the last stage's outputs.

        Args:
            y: 4-D input tensor (it is sliced as ``y[:, :, :3, :3]``).

        Returns:
            ``(g_list, z_list, k)`` from the final stage. With no stages
            (``L == 0``) the lists are empty and ``k`` is the initial all-ones
            tensor.
        """
        # Initial estimate: an all-ones patch matching y's batch/channel dims,
        # dtype and device (a flat patch, not an identity/delta kernel).
        k = torch.ones_like(y[:, :, :3, :3])
        # Bound up front so the return below is valid even with zero stages.
        g_list, z_list = [], []
        for hqs in self.hqs_modules:
            g_list, z_list, k = hqs(y, k)
        return g_list, z_list, k

In [8]:
# Mini-batch size shared by the training and test loaders.
batch_size = 64

# Only conversion to tensors is applied to the MNIST images.
transform = transforms.Compose([transforms.ToTensor()])

# Build the train split first, then the test split (downloading into ./data if needed).
train_set, test_set = (
    datasets.MNIST(root="./data", train=is_train_split, download=True, transform=transform)
    for is_train_split in (True, False)
)

# Training batches are reshuffled; test batches keep dataset order.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

In [9]:
def calculate_psnr(original, output, max_pixel=1.0):
    """Peak signal-to-noise ratio, in dB, between two tensors.

    Args:
        original: Reference tensor.
        output: Tensor compared against ``original``.
        max_pixel: Peak signal value of the data. Defaults to 1.0.

    Returns:
        ``20 * log10(max_pixel / sqrt(MSE))`` as a Python float, or
        ``float('inf')`` when the mean squared error is exactly zero.
    """
    # The metric is only reported, never back-propagated: skip graph building
    # and convert the scalar to a plain float explicitly.
    with torch.no_grad():
        mse = F.mse_loss(output, original).item()
    if mse == 0:
        return float('inf')
    return 20 * math.log10(max_pixel / math.sqrt(mse))

def calculate_accuracy(original, output, tolerance=0.05):
    """Percentage of elements whose absolute error is strictly below ``tolerance``.

    Args:
        original: Reference tensor.
        output: Tensor compared element-wise against ``original``.
        tolerance: Absolute-error threshold for an element to count as correct.

    Returns:
        A float in ``[0, 100]``: matching elements divided by
        ``original.numel()``, times 100.
    """
    within_tolerance = (original - output).abs() < tolerance
    hits = int(within_tolerance.sum())
    fraction = hits / original.numel()
    return 100 * fraction

In [19]:

# Resolve the compute device once instead of on every batch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = DeblurringNetwork(C=16, L=10).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Gaussian blur strength used to synthesise the degraded inputs.
# NOTE(review): training and evaluation use different sigmas (1 vs 2), as in
# the original cell. Confirm this mismatch is intentional.
TRAIN_BLUR_SIGMA = 1
TEST_BLUR_SIGMA = 2


def blur_batch(images, sigma):
    """Gaussian-blur every image of a batch independently.

    Assumes ``images`` is laid out as (batch, channel, height, width), which is
    what the MNIST loaders built with ``ToTensor`` yield. A scalar ``sigma``
    passed to ``gaussian_filter`` would also smooth along the batch and channel
    axes and mix neighbouring images, so the blur is restricted to the two
    spatial axes by giving the leading axes a sigma of 0.

    Args:
        images: 4-D tensor of images.
        sigma: Standard deviation of the Gaussian along height and width.

    Returns:
        Blurred tensor on the same device as ``images``.
    """
    blurred_np = gaussian_filter(images.cpu().numpy(), sigma=(0, 0, sigma, sigma))
    return torch.from_numpy(blurred_np).to(images.device)


def hqs_objective(model, blurred, g_list, z_list, output, layer=0):
    """Objective shared by the training and evaluation loops.

    Sums, over the ``model.C`` filters of stage ``layer``, a data term on the
    filtered images, an L1 term on ``z_i`` weighted by ``lambd[i]``, and a
    coupling term between ``g_i`` and ``z_i`` divided by ``zeta[i]``; then adds
    a small L2 penalty on ``output``.

    NOTE(review): ``g_list`` / ``z_list`` returned by the model come from its
    last stage, while the filters, ``lambd`` and ``zeta`` used here come from
    stage ``layer`` (0 by default, as in the original code). Confirm which
    stage the formula is meant to reference.
    """
    hqs = model.hqs_modules[layer]
    loss = 0
    for i in range(model.C):
        f_i_y = hqs.filters[i](blurred)
        f_i_g = hqs.filters[i](output)
        l2_norm_term = 0.5 * torch.norm(f_i_y - f_i_g) ** 2
        l1_norm_term = hqs.lambd[i] * torch.norm(z_list[i], p=1)
        l2_norm_diff_term = 0.5 * torch.norm(g_list[i] - z_list[i]) ** 2 / hqs.zeta[i]
        loss += l2_norm_term + l1_norm_term + l2_norm_diff_term
    loss += 0.5 * torch.norm(output) ** 2 * 1e-5
    return loss


# Training loop
epochs = 10
for epoch in range(epochs):
    model.train()
    epoch_loss = 0
    epoch_psnr = 0
    epoch_accuracy = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        blurred_data = blur_batch(data, TRAIN_BLUR_SIGMA)

        # Forward pass
        g_list, z_list, output = model(blurred_data)
        loss = hqs_objective(model, blurred_data, g_list, z_list, output)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        # Per-minibatch metrics against the sharp images
        psnr = calculate_psnr(data, output)
        epoch_psnr += psnr

        accuracy = calculate_accuracy(data, output)
        epoch_accuracy += accuracy

        print(f"Batch [{batch_idx+1}/{len(train_loader)}], Training PSNR: {psnr:.2f} dB, Training Accuracy: {accuracy:.2f}%")

    print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {epoch_loss / len(train_loader):.6f}, Training PSNR: {epoch_psnr / len(train_loader):.2f} dB, Training Accuracy: {epoch_accuracy / len(train_loader):.2f}%")

    # Evaluation on the test split; no gradients are needed here.
    model.eval()
    test_loss = 0
    test_psnr = 0
    test_accuracy = 0
    with torch.no_grad():
        for batch_idx, (data, _) in enumerate(test_loader):
            data = data.to(device)
            blurred_data = blur_batch(data, TEST_BLUR_SIGMA)

            g_list, z_list, output = model(blurred_data)
            loss = hqs_objective(model, blurred_data, g_list, z_list, output)

            test_loss += loss.item()

            # Per-minibatch metrics against the sharp images
            psnr = calculate_psnr(data, output)
            test_psnr += psnr

            accuracy = calculate_accuracy(data, output)
            test_accuracy += accuracy

            print(f"Batch [{batch_idx+1}/{len(test_loader)}], Test PSNR: {psnr:.2f} dB, Test Accuracy: {accuracy:.2f}%")

    print(f"Epoch [{epoch+1}/{epochs}], Test Loss: {test_loss / len(test_loader):.6f}, Test PSNR: {test_psnr / len(test_loader):.2f} dB, Test Accuracy: {test_accuracy / len(test_loader):.2f}%")

Batch [1/938], Training PSNR: 9.22 dB, Training Accuracy: 80.62%
Batch [2/938], Training PSNR: 9.77 dB, Training Accuracy: 82.74%
Batch [3/938], Training PSNR: 9.92 dB, Training Accuracy: 82.95%
Batch [4/938], Training PSNR: 9.42 dB, Training Accuracy: 81.73%
Batch [5/938], Training PSNR: 9.30 dB, Training Accuracy: 81.20%
Batch [6/938], Training PSNR: 9.37 dB, Training Accuracy: 81.10%
Batch [7/938], Training PSNR: 9.66 dB, Training Accuracy: 82.19%
Batch [8/938], Training PSNR: 9.12 dB, Training Accuracy: 80.25%
Batch [9/938], Training PSNR: 9.32 dB, Training Accuracy: 80.94%
Batch [10/938], Training PSNR: 9.53 dB, Training Accuracy: 81.84%
Batch [11/938], Training PSNR: 9.33 dB, Training Accuracy: 80.95%
Batch [12/938], Training PSNR: 9.55 dB, Training Accuracy: 81.85%
Batch [13/938], Training PSNR: 9.75 dB, Training Accuracy: 82.42%
Batch [14/938], Training PSNR: 9.65 dB, Training Accuracy: 82.46%
Batch [15/938], Training PSNR: 9.36 dB, Training Accuracy: 80.94%
Batch [16/938], Tra

KeyboardInterrupt: 