In [61]:
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import scipy

from guided_diffusion.blind_condition_methods import get_conditioning_method
from guided_diffusion.measurements import get_operator, get_noise

# Here replaces the regular unet by our trained unet
# from guided_diffusion.unet import create_model
import guided_diffusion.diffusion_model_unet 
import guided_diffusion.unet

from guided_diffusion.gaussian_diffusion import create_sampler
from data.dataloader import get_dataset, get_dataloader
from motionblur.motionblur import Kernel
from util.img_utils import Blurkernel, clear_color
from util.logger import get_logger
from skimage.restoration import richardson_lucy, wiener, unsupervised_wiener
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.signal import convolve
import numpy as np
from scipy.signal import fftconvolve

from skimage.restoration import denoise_tv_chambolle
from skimage.metrics import peak_signal_noise_ratio

device = 'cuda'
    
name = 'ffhq'
root = '/home/modrzyk/code/data/EUSIPCO_2024/label/'

transform = transforms.Compose([transforms.ToTensor(), 
                                ])
dataset = get_dataset(name=name, root=root, transforms=transform)
loader = get_dataloader(dataset, batch_size=1, num_workers=0, train=False)

# set seed for reproduce
# set seed for reproduce
np.random.seed(123)
torch.manual_seed(123)
torch.backends.cudnn.deterministic = True  # if using CUDA

def mlem_gpu(observation, x_0_hat, steps, clip, filter_epsilon, device, reg, lam=0.005):
    with torch.no_grad():
        kernel = x_0_hat['kernel'].repeat(1, 3, 1, 1)
        
        image = observation.to(torch.float32).clone().to(device)
        psf = kernel.to(torch.float32).clone().to(device)
        im_deconv = x_0_hat['img'].to(torch.float32).clone().to(device)
        psf_mirror = torch.flip(psf, dims=[2, 3])  # Flipping should be on the last two dimensions for 4D tensor

        # Define the Laplacian kernel for gradient L2 regularization
        laplacian_kernel = torch.tensor([[0, 1, 0], [1, -4, 1], [0, 1, 0]], device=device, dtype=torch.float32)
        laplacian_kernel = laplacian_kernel.unsqueeze(0).unsqueeze(0)  # Shape: [1, 1, 3, 3] to match convolution requirements
        laplacian_kernel = laplacian_kernel.repeat(1, 3, 1, 1)  
        eps = 1e-12
        pad = (psf.size(2) // 2, psf.size(2) // 2, psf.size(3) // 2, psf.size(3) // 2)
        
        for _ in range(steps):
            conv = F.conv2d(F.pad(im_deconv, pad, mode='replicate'), psf) + eps
            if filter_epsilon:
                relative_blur = torch.where(conv < filter_epsilon, torch.tensor(0.0, device=device), image / conv)
            else:
                relative_blur = image / conv
            im_deconv *= F.conv2d(F.pad(relative_blur, pad, mode='replicate'), psf_mirror)

            if reg == "tv":
                im_deconv = torch.from_numpy(denoise_tv_chambolle(im_deconv.cpu().numpy(), weight=0.005, max_num_iter=50, channel_axis=1)).to(device)
            
            elif reg == "l1":
                im_deconv -= lam * torch.sign(im_deconv)
                
            elif reg == "l2":
                im_deconv -= lam * im_deconv

            elif reg == "grad_l2":
                # Convolve with the Laplacian kernel and update
                laplacian = F.conv2d(im_deconv, laplacian_kernel, padding=1)
                im_deconv -= lam * laplacian
            
            elif reg == "h1":
                # L2 regularization on image values
                im_deconv -= lam * im_deconv

                # L2 regularization on the image gradient (smoothness)
                laplacian = F.conv2d(im_deconv, laplacian_kernel, padding=1)
                im_deconv -= lam * laplacian
                
        if clip:
            im_deconv = torch.clamp(im_deconv, -1, 1)

        return im_deconv.to(device)
    
    
def gd(observation, x_0_hat, steps, learning_rate, device, reg_type=None, lambda_reg=0.001):
    psf = x_0_hat['kernel'].repeat(1, 3, 1, 1)
    x_hat = observation.clone().to(device)
    psf = psf.to(device)
    psf /= psf.sum()
    psf_flipped = torch.flip(psf, [2, 3])
    
    laplacian_kernel = torch.tensor([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=torch.float32, device=device).unsqueeze(0).unsqueeze(0)
    laplacian_kernel = laplacian_kernel.repeat(1, 3, 1, 1)
    for step in range(steps):
        convolved = F.conv2d(x_hat, psf, padding='same')
        gradient = 2 * (convolved - observation)

        if reg_type == "L2":
            gradient += 2 * lambda_reg * x_hat
        elif reg_type == "L1":
            gradient += lambda_reg * torch.sign(x_hat)
        elif reg_type == "H1":
            # Add L2 regularization
            gradient += 2 * lambda_reg * x_hat
            # Approximate gradient's L2 norm using Laplacian for smoothness
            laplacian_x_hat = F.conv2d(x_hat, laplacian_kernel, padding=1)
            gradient += 2 * lambda_reg * laplacian_x_hat

        x_hat -= learning_rate * F.conv2d(gradient, psf_flipped, padding='same')
    
    return x_hat



In [74]:
import os
from tqdm import tqdm
import torch
import numpy as np
import numpy as np

kernel_size = 61
kernel_std = 3.0
noise_level = 0.05

operator = get_operator(name='gaussian_blur', kernel_size=kernel_size, intensity=kernel_std, device=device)
noiser = get_noise(name='gaussian', sigma=noise_level)

psnr_tab = []

# Do Inference
for i, ref_img in enumerate(tqdm(loader)):
    if(i==1):
        conv = Blurkernel('gaussian', kernel_size=kernel_size, std=kernel_std, device=device)
        kernel = conv.get_kernel().type(torch.float32)
        kernel = kernel.to(device).view(1, 1, kernel_size, kernel_size)

        ref_img = ref_img.to(device)
        y_conv = operator.forward(ref_img)
        y = noiser(y_conv)

        y_mlem = y.clone()
        y_mlem = torch.clamp(y_mlem, min=0)
        
        # Call mlem
        x_0_hat={'img': y_mlem, 'kernel': kernel}
        
        x_0_rl = mlem_gpu(y_mlem, x_0_hat, steps=40, clip=False, filter_epsilon=0, device=device, reg=None, lam = 0)
        x_0_rl_tv = mlem_gpu(y_mlem, x_0_hat, steps=40, clip=False, filter_epsilon=1e-15, device=device, reg="tv", lam = 0.005)
        
        # x_0_gd = gd(y_mlem, x_0_hat, steps=40, learning_rate=1.0, device=device, reg_type=None, lambda_reg = 0)
        # x_0_gd_l1 = gd(y_mlem, x_0_hat, steps=40, learning_rate=1.0, device=device, reg_type="L1", lambda_reg = 0.03)
        # x_0_gd_l2 = gd(y_mlem, x_0_hat, steps=40, learning_rate=1.0, device=device, reg_type="L2", lambda_reg = 0.03)
        # x_0_gd_h1 = gd(y_mlem, x_0_hat, steps=40, learning_rate=1.0, device=device, reg_type="H1", lambda_reg = 0.04)
    
        plt.imsave('results/rl.png', torch.clamp(x_0_rl, 0, 1).squeeze().detach().cpu().numpy().transpose(1, 2, 0))
        plt.imsave('results/rl_tv.png', torch.clamp(x_0_rl_tv, 0, 1).squeeze().detach().cpu().numpy().transpose(1, 2, 0))
        plt.imsave('results/y.png', torch.clamp(y, 0, 1).squeeze().detach().cpu().numpy().transpose(1, 2, 0))
        plt.imsave('results/ref_img.png', torch.clamp(ref_img, 0, 1).squeeze().detach().cpu().numpy().transpose(1, 2, 0))

100%|██████████| 1000/1000 [00:05<00:00, 178.00it/s]
