In [1]:
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import scipy

from guided_diffusion.blind_condition_methods import get_conditioning_method
from guided_diffusion.measurements import get_operator, get_noise

# Here replaces the regular unet by our trained unet
# from guided_diffusion.unet import create_model
import guided_diffusion.diffusion_model_unet 
import guided_diffusion.unet

from guided_diffusion.gaussian_diffusion import create_sampler
from data.dataloader import get_dataset, get_dataloader
from motionblur.motionblur import Kernel
from util.img_utils import Blurkernel, clear_color
from util.logger import get_logger
from skimage.restoration import richardson_lucy, wiener, unsupervised_wiener
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.signal import convolve
import numpy as np
from scipy.signal import fftconvolve

device = 'cuda'
    
name = 'ffhq'
root = '/home/modrzyk/code/data/EUSIPCO_2024/label/'

transform = transforms.Compose([transforms.ToTensor(), 
                                ])
dataset = get_dataset(name=name, root=root, transforms=transform)
loader = get_dataloader(dataset, batch_size=1, num_workers=0, train=False)

# set seed for reproduce
# set seed for reproduce
np.random.seed(123)
torch.manual_seed(123)
torch.backends.cudnn.deterministic = True  # if using CUDA



def wiener_deconv(x_0_hat, steps, **kwargs):
    img = x_0_hat['img'].numpy().astype(np.float32)
    kernel = x_0_hat['kernel'].numpy().astype(np.float32)
    
    deconv_img = wiener(image=img, balance=1.0, psf=kernel, clip=False)
    # deconv_img, _ = unsupervised_wiener(image=img.numpy(), psf=kernel.numpy(), clip=False)
    
    x_0_hat['img'] = torch.from_numpy(deconv_img).to(device)
    
    return x_0_hat

def richardson_lucy_blind(image, psf, original, num_iter=50):    
    im_deconv = original.copy()    # init output
    for i in range(num_iter):
        psf_mirror = np.flip(psf)
        conv = fftconvolve(im_deconv, psf, mode='same')
        relative_blur = image / conv
        im_deconv *= fftconvolve(relative_blur, psf_mirror, mode='same')
        im_deconv_mirror = np.flip(im_deconv)
        psf *= fftconvolve(relative_blur, im_deconv_mirror, mode='same')    
    return im_deconv, psf

def blind_mlem(x_0_hat, steps, clip, filter_epsilon, **kwargs):
    img = x_0_hat['img'].numpy()
    psf = x_0_hat['kernel'].numpy()
    
    im_deconv, psf_deconv = richardson_lucy_blind(img, psf, img, num_iter=steps)
    x_0_hat['img'] = torch.from_numpy(im_deconv).to(device)
    x_0_hat['kernel'] = torch.from_numpy(psf_deconv).to(device)
    
    return x_0_hat
    

In [24]:
from scipy.signal import convolve
import numpy as np
import torch

def mlem(observation, x_0_hat, steps, clip, filter_epsilon, **kwargs):
    img = x_0_hat['img']
    kernel = x_0_hat['kernel']
    
    image = observation.cpu().numpy().astype(np.float32, copy=True)
    psf = kernel.cpu().numpy().astype(np.float32, copy=False)
    # im_deconv = np.full(image.shape, 0.5, dtype=np.float32)
    im_deconv = img.cpu().numpy().astype(np.float32, copy=True)
    psf_mirror = np.flip(psf)

    # Small regularization parameter used to avoid 0 divisions
    eps = 1e-6

    for _ in range(steps):
        conv = convolve(im_deconv, psf, mode='same', method='fft') + eps
        if filter_epsilon:
            relative_blur = np.where(conv < filter_epsilon, 0, image / conv)
        else:
            relative_blur = image / conv
        im_deconv *= convolve(relative_blur, psf_mirror, mode='same', method='fft')

    if clip:
        im_deconv[im_deconv > 1] = 1
        im_deconv[im_deconv < -1] = -1

    x_0_hat['img'] = torch.from_numpy(im_deconv).to(device)
    
    return x_0_hat

In [2]:
import torch
import torch.nn.functional as F
import numpy as np
from skimage.restoration import denoise_tv_chambolle
from skimage.metrics import peak_signal_noise_ratio

def mlem_gpu(observation, x_0_hat, steps, clip, filter_epsilon, device):
    with torch.no_grad():
        img = x_0_hat['img']
        kernel = x_0_hat['kernel'].repeat(1,3,1,1)
        # kernel = x_0_hat['kernel']
        
        image = observation.to(torch.float32).clone().to(device)
        psf = kernel.to(torch.float32).clone().to(device)
        im_deconv = img.to(torch.float32).clone().to(device)
        psf_mirror = torch.flip(psf, dims=[0, 1])

        # Small regularization parameter used to avoid 0 divisions
        eps = 1e-12
        pad = (psf.size(2) // 2, psf.size(2) // 2, psf.size(3) // 2, psf.size(3) // 2)
        for _ in range(steps):
            conv = F.conv2d(F.pad(im_deconv, pad, mode='replicate'), psf) + eps
            if filter_epsilon:
                relative_blur = torch.where(conv < filter_epsilon, torch.tensor(0.0, device=device), image / conv)
            else:
                relative_blur = image / conv
            im_deconv *= F.conv2d(F.pad(relative_blur, pad, mode='replicate'), psf_mirror)

            im_deconv = torch.from_numpy(denoise_tv_chambolle(im_deconv.cpu().numpy(), weight=0.005, max_num_iter=50, channel_axis=1)).to(device)
            
        if clip:
            im_deconv = torch.clamp(im_deconv, -1, 1)

        x_0_hat['img'] = im_deconv.to(device)
    
    return x_0_hat


In [132]:
import os
from tqdm import tqdm
import torch
import numpy as np

kernel_size = 61
kernel_std = 3.0
noise_level = 0.05

operator = get_operator(name='gaussian_blur', kernel_size=kernel_size, intensity=kernel_std, device=device)
noiser = get_noise(name='gaussian', sigma=noise_level)

psnr_tab = []

# Do Inference
for i, ref_img in enumerate(tqdm(loader)):
    if(i > 10):
        break   
    conv = Blurkernel('gaussian', kernel_size=kernel_size, std=kernel_std, device=device)
    kernel = conv.get_kernel().type(torch.float32)
    kernel = kernel.to(device).view(1, 1, kernel_size, kernel_size)

    ref_img = ref_img.to(device)
    y_conv = operator.forward(ref_img)
    y = noiser(y_conv)

    y_mlem = y.clone()
    y_mlem = torch.clamp(y_mlem, min=0)
    
    # Call mlem
    x_0_hat={'img': y_mlem, 'kernel': kernel}
    
    psnr_values = []

    def mlem_gpu(observation, x_0_hat, steps, clip, filter_epsilon, device):
        with torch.no_grad():
            img = x_0_hat['img']
            kernel = x_0_hat['kernel'].repeat(1,3,1,1)
            # kernel = x_0_hat['kernel']
            
            psnr_values = []

            image = observation.to(torch.float32).clone().to(device)
            psf = kernel.to(torch.float32).clone().to(device)
            im_deconv = img.to(torch.float32).clone().to(device)
            psf_mirror = torch.flip(psf, dims=[0, 1])

            # Small regularization parameter used to avoid 0 divisions
            eps = 1e-12
            pad = (psf.size(2) // 2, psf.size(2) // 2, psf.size(3) // 2, psf.size(3) // 2)
            for _ in range(steps):
                conv = F.conv2d(F.pad(im_deconv, pad, mode='replicate'), psf) + eps
                if filter_epsilon:
                    relative_blur = torch.where(conv < filter_epsilon, torch.tensor(0.0, device=device), image / conv)
                else:
                    relative_blur = image / conv
                im_deconv *= F.conv2d(F.pad(relative_blur, pad, mode='replicate'), psf_mirror)

                im_deconv = torch.from_numpy(denoise_tv_chambolle(im_deconv.cpu().numpy(), weight=0.005, max_num_iter=50, channel_axis=1)).to(device)
                
                psnr_values.append(peak_signal_noise_ratio(ref_img.cpu().numpy(), torch.clamp(im_deconv, -1, 1).cpu().numpy()))
            if clip:
                im_deconv = torch.clamp(im_deconv, -1, 1)

            x_0_hat['img'] = im_deconv.to(device)
        
        return x_0_hat, psnr_values

    x_0_hat, psnr_values = mlem_gpu(y_mlem, x_0_hat, steps=40, clip=False, filter_epsilon=1e-15, device=device)

    psnr_tab.append(psnr_values)        

  1%|          | 11/1000 [00:16<25:03,  1.52s/it]


In [140]:
min_indices = [np.argmax(arr) for arr in psnr_tab]

for i, k in zip(min_indices, range(len(min_indices))):
    print(i, psnr_tab[k][i]) 
    
np.mean(min_indices)

18 27.47198872407659
15 26.85524340129515
31 28.883712080068587
38 26.70252826697649
34 22.608825886459876
33 26.725120197067067
25 27.13388440854719
18 25.574857032982937
24 26.73496557394038
29 26.727218514995513
20 23.116434288088485


25.90909090909091

In [144]:
import os
from tqdm import tqdm
import torch
import numpy as np
from skimage.metrics import peak_signal_noise_ratio

kernel_size = 61
kernel_std = 3.0
noise_level = 0.05

operator = get_operator(name='gaussian_blur', kernel_size=kernel_size, intensity=kernel_std, device=device)
noiser = get_noise(name='gaussian', sigma=noise_level)

psnr_tab = []

# Do Inference
for i, ref_img in enumerate(tqdm(loader)):
    if(i > 468):
        conv = Blurkernel('gaussian', kernel_size=kernel_size, std=kernel_std, device=device)
        kernel = conv.get_kernel().type(torch.float32)
        kernel = kernel.to(device).view(1, 1, kernel_size, kernel_size)

        ref_img = ref_img.to(device)
        y_conv = operator.forward(ref_img)
        y = noiser(y_conv)

        y_mlem = y.clone()
        y_mlem = torch.clamp(y_mlem, min=0)
        
        # Call mlem
        x_0_hat={'img': y_mlem, 'kernel': kernel}
        
        psnr_values = []

        x_0_hat = mlem_gpu(y_mlem, x_0_hat, steps=25, clip=False, filter_epsilon=1e-15, device=device)
        psnr = peak_signal_noise_ratio(ref_img.cpu().numpy(), x_0_hat['img'].cpu().numpy())
        psnr_values.append(psnr)

        psnr_values = np.array(psnr_values)
        
        
        # Specify the output directory
        output_dir = './results/ffhq/richardson-lucy/'
        label_dir = os.path.join(output_dir, 'label')
        input_dir = os.path.join(output_dir, 'input')
        recon_dir = os.path.join(output_dir, 'recon')
        
        # Create the output directory if it doesn't exist
        os.makedirs(label_dir, exist_ok=True)
        os.makedirs(input_dir, exist_ok=True)
        os.makedirs(recon_dir, exist_ok=True)

        # Save the images
        output_path_original = os.path.join(output_dir, 'label', f'{str(i).zfill(5)}.png')
        output_path_measurement = os.path.join(output_dir, 'input', f'{str(i).zfill(5)}.png')
        output_path_mlem = os.path.join(output_dir, 'recon', f'{str(i).zfill(5)}.png')

        # Save the original image
        plt.imsave(output_path_original, clear_color(ref_img))

        # Save the measurement image
        plt.imsave(output_path_measurement, clear_color(y_mlem))

        # Save the MLEM image
        plt.imsave(output_path_mlem, clear_color(x_0_hat['img']))

        psnr_tab.append(psnr_values)

100%|██████████| 1000/1000 [08:48<00:00,  1.89it/s]


In [95]:
psnr_tab

[array([27.38016392, 27.06987285, 26.47366818, 25.78959156, 25.04268333,
        24.25774465, 23.43894457, 22.57212148]),
 array([26.66908039, 26.19404774, 25.19876411, 23.87090406, 22.50583346,
        21.26535303, 20.1159095 , 19.02529859]),
 array([28.6015395 , 28.67292075, 28.4014288 , 28.05017924, 27.64035424,
        27.19399464, 26.69667925, 26.1056366 ]),
 array([26.33872798, 26.52201741, 26.26174517, 25.87703418, 25.38428416,
        24.79060871, 24.12216647, 23.36355165]),
 array([22.32424186, 22.35960646, 22.16874455, 21.9377999 , 21.67544075,
        21.33137803, 20.89107093, 20.35884114])]

In [96]:
import numpy as np

min_indices = [np.argmax(arr) for arr in psnr_tab]
print(min_indices)



[0, 0, 1, 1, 1]
