# WPP

In [1]:
import sys

# Display the directories the interpreter searches when importing modules.
print(sys.path)

# Prepend the project root (slice assignment at 0:0 is equivalent to
# insert(0, ...)) so its modules shadow same-named installed packages.
sys.path[0:0] = ["/home/prof/smignon/ot_patch_denoising/Wasserstein_Patch_Prior/GitHub_SIAM"]

['/home/prof/smignon/ot_patch_denoising/Wasserstein_Patch_Prior/GitHub_SIAM/WPP/OT_SUOT', '/home/prof/smignon/anaconda3/envs/WPP_color/lib/python37.zip', '/home/prof/smignon/anaconda3/envs/WPP_color/lib/python3.7', '/home/prof/smignon/anaconda3/envs/WPP_color/lib/python3.7/lib-dynload', '', '/home/prof/smignon/anaconda3/envs/WPP_color/lib/python3.7/site-packages', '/home/prof/smignon/anaconda3/envs/WPP_color/lib/python3.7/site-packages/IPython/extensions', '/home/prof/smignon/.ipython']


In [None]:
# This code belongs to the paper
#
# J. Hertrich, A. Houdard and C. Redenbach.
# Wasserstein Patch Prior for Image Superresolution.
# IEEE Transactions on Computational Imaging, 2022.
#
# Please cite the paper, if you use this code.
#
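# The original version of this script applied the Wasserstein Patch Prior reconstruction to the
# 2D SiC Diamonds image (Section 4.2 of the paper). This cell instead runs the colour variant
# (wgenpatex_color) on every image of Datasets/18_images_wd_wod_dataset/HR, once with the HR_wd
# and once with the HR_wod learning image, and records PSNR, LPIPS and SSIM for each run.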
#
import argparse
import wgenpatex_color as wgenpatex
import torch
import skimage.transform
from skimage import (color, data, measure)
from skimage.metrics import structural_similarity as ssim
import numpy as np
from scipy.interpolate import griddata
import os
import lpips
import glob
import torchvision
from torchvision.transforms import Resize as tv_resize

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)
# torch.cuda.set_device raises on a machine without CUDA; DEVICE already
# falls back to the CPU there, so only pin the GPU when one is available.
if DEVICE.type == 'cuda':
    torch.cuda.set_device(0)

# images PATH
os.chdir('/home/prof/smignon/ot_patch_denoising/Wasserstein_Patch_Prior/GitHub_SIAM/Datasets/18_images_wd_wod_dataset')

# HR ground-truth images and the two families of learning images
# (HR_wd -> "with_def", HR_wod -> "without_def"). The lists are sorted so that
# entry i of each list lines up; the main loop pairs list_im_name[i] with
# list_im_modele[j][i]. NOTE(review): this assumes matching file names across
# the three folders.
list_im_name = sorted(glob.glob("HR/*.png"))
list_im_modele_with_def = sorted(glob.glob("HR_wd/*.png"))
list_im_modele_without_def = sorted(glob.glob("HR_wod/*.png"))
if not list_im_name:
    raise FileNotFoundError("no HR/*.png images found in " + os.getcwd())
if not (len(list_im_name) == len(list_im_modele_with_def) == len(list_im_modele_without_def)):
    raise ValueError(
        "HR, HR_wd and HR_wod must contain the same number of images, got "
        f"{len(list_im_name)}, {len(list_im_modele_with_def)} and {len(list_im_modele_without_def)}"
    )
list_im_modele = [list_im_modele_with_def, list_im_modele_without_def]

# Lists
#PSNR 
def PSNR(im, im_new, data_range=1.0):
    """Return the peak signal-to-noise ratio (in dB) between two images.

    Args:
        im: reference tensor.
        im_new: tensor compared against ``im``; same shape as ``im``.
        data_range: peak signal value. The default 1.0 matches images
            scaled to [0, 1].

    Returns:
        A 0-d tensor equal to ``10 * log10(data_range**2 / MSE)``; ``inf``
        when the two images are identical (MSE == 0).
    """
    # Mean over every element; works for any rank, not only (C, M, N).
    mse = torch.mean((im - im_new) ** 2)
    return 10 * torch.log10(data_range ** 2 / mse)

# Result containers. Index 0 collects the runs that learn from the HR_wd
# images, index 1 those that learn from HR_wod (j in the main loop).
list_psnr_WPP = [[], []]      # PSNR values
list_im_rest_WPP = [[], []]   # restored images
list_im_LR = [[], []]         # LR observations (not filled in this cell)

# LPIPS metric using the AlexNet backbone.
loss_fn_alex = lpips.LPIPS(net='alex')
list_lpips_WPP = [[], []]     # LPIPS distances
list_ssim_WPP = [[], []]      # SSIM values

# Each iteration writes its LR observation to input_imgs/. Create the
# directory up front in case wgenpatex.imsave does not create it itself.
os.makedirs('input_imgs', exist_ok=True)

# Evaluation crop: metrics ignore a 6-pixel frame around the 256x256 image.
eval_crop = torchvision.transforms.CenterCrop(256 - 12)

for i, name_im in enumerate(list_im_name):
    # j = 0 learns from the HR_wd image, j = 1 from the HR_wod image
    # (list_im_modele = [with_def, without_def]).
    for j in range(2):
        # set arguments
        args = argparse.Namespace()
        args.target_image_path = name_im
        args.scales = 2
        args.keops = True
        args.n_iter_max = 500
        args.save = True
        args.n_patches_out = 10000
        args.learn_image_path = list_im_modele[j][i]
        args.patch_size = 6
        args.lam = (6000 / (3 * args.patch_size ** 2)) * (256 ** 2 / 600 ** 2)
        args.n_iter_psi = 10
        args.n_patches_in = -1
        args.visu = False

        # define forward operator: crop the artificial boundary (if any), then
        # apply wgenpatex.gaussian_layer (kernel_size, blur_width, stride).
        blur_width = 2.0
        add_boundary = 20
        kernel_size = 16
        stride = 4
        my_layer = wgenpatex.gaussian_layer(kernel_size, blur_width, stride=stride)

        def operator(inp):
            if add_boundary == 0:
                return my_layer.forward(inp)
            return my_layer.forward(inp[:, :, add_boundary:-add_boundary, add_boundary:-add_boundary])

        # Seed both RNGs with the image index so the experiment is reproducible.
        # torch drives the noise added to the LR observation; numpy drives the
        # random boundary of the initialization (init_ below). Re-seeding for
        # each j gives both learning images the same noise and the same boundary.
        torch.manual_seed(i)
        np.random.seed(i)

        # read HR ground truth
        hr_img = wgenpatex.imread(args.target_image_path)
        hr_img = tv_resize(256, antialias=True)(hr_img)

        # The observation is simulated from the same resized HR image.
        lr_img = hr_img.clone()
        args.size = lr_img.shape[2:4]
        lr_img_ = np.zeros((3, lr_img.shape[2] + 2 * add_boundary, lr_img.shape[3] + 2 * add_boundary))
        if add_boundary > 0:
            lr_img_[:, add_boundary:-add_boundary, add_boundary:-add_boundary] = lr_img.squeeze().cpu().numpy()
        else:
            lr_img_ = lr_img.squeeze().cpu().numpy()

        # create (artificially) LR observation: forward operator + noise of std 0.01
        lr_img = operator(torch.tensor(lr_img_, dtype=torch.float, device=DEVICE).view(1, 3, lr_img_.shape[1], lr_img_.shape[2]))
        lr_img += 0.01 * torch.randn_like(lr_img)
        wgenpatex.imsave('input_imgs/lr_diam.png', lr_img)

        # build initialization by rescaling the lr observation and extending it to the boundary
        upscaled = skimage.transform.resize(lr_img.squeeze().cpu().numpy(), [3, lr_img.shape[2] * stride, lr_img.shape[3] * stride])
        # NOTE(review): `diff` is measured on the rows and reused for the columns
        # (assumes square images); for diff == 0 the slice below would be empty.
        diff = args.size[0] - upscaled.shape[1]

        init = np.zeros((3, args.size[0], args.size[1]), dtype=bool)
        init[:, diff // 2:-diff // 2, diff // 2:-diff // 2] = True
        grid_x = np.array(range(init.shape[1]))
        grid_x = np.tile(grid_x[:, np.newaxis], [1, init.shape[2]])
        grid_y = np.array(range(init.shape[2]))
        grid_y = np.tile(grid_y[np.newaxis, :], [init.shape[1], 1])
        points_x = np.reshape(grid_x[init[0, :, :]], [-1])
        points_y = np.reshape(grid_y[init[0, :, :]], [-1])
        init = init.astype(float)
        for k in range(3):
            values = np.reshape(upscaled[k, :, :], [-1])
            points = np.stack([points_x, points_y], 0).transpose()
            # Nearest-neighbour extrapolation fills the frame around the upscaled image.
            init[k, :, :] = griddata(points, values, (grid_x, grid_y), method='nearest')

        # Random boundary (numpy RNG, seeded above) around the initialization.
        init_ = np.random.uniform(size=(3, init.shape[1] + 2 * add_boundary, init.shape[2] + 2 * add_boundary))
        if add_boundary == 0:
            init_ = init
        else:
            init_[:, add_boundary:-add_boundary, add_boundary:-add_boundary] = init
        args.size = init_.shape

        # load learn img
        learn_img = wgenpatex.imread(args.learn_image_path)
        learn_img = tv_resize(256, antialias=True)(learn_img)
        print(learn_img.dtype)

        # run reconstruction
        synth_img = wgenpatex.optim_synthesis_SR(args, operator, lr_img, learn_img, args.lam, init=init_, add_boundary=add_boundary)

        if add_boundary > 0:
            synth_img = synth_img[:, :, add_boundary:-add_boundary, add_boundary:-add_boundary]

        list_im_rest_WPP[j].append(synth_img)

        # Cropped, detached CPU copies shared by all three metrics.
        hr_eval = eval_crop(hr_img.squeeze().to('cpu')).detach()
        synth_eval = eval_crop(synth_img.squeeze().to('cpu')).detach()

        # PSNR
        list_psnr_WPP[j].append(PSNR(hr_eval, synth_eval))

        # LPIPS expects inputs in [-1, 1]. The images here are assumed to be in
        # [0, 1] (PSNR uses a peak of 1), so normalize=True lets LPIPS rescale
        # them. no_grad keeps the stored results free of an autograd graph.
        with torch.no_grad():
            list_lpips_WPP[j].append(
                loss_fn_alex(hr_eval.unsqueeze(0), synth_eval.unsqueeze(0), normalize=True))

        # SSIM on H x W x C arrays.
        # NOTE(review): data_range is taken from the prediction's own dynamic
        # range; for images in [0, 1], data_range=1.0 is the usual choice.
        img_hr = hr_eval.numpy().transpose(1, 2, 0)
        img_pred_WPP = synth_eval.numpy().transpose(1, 2, 0)
        list_ssim_WPP[j].append(ssim(img_hr, img_pred_WPP, data_range=img_pred_WPP.max() - img_pred_WPP.min(), multichannel=True))
        print(i)

# Persist the WPP results: the *_WPP lists filled by the loop of this cell,
# saved under the "_WPP_color" file names in .../WPP/OT_SUOT.
os.chdir('/home/prof/smignon/ot_patch_denoising/Wasserstein_Patch_Prior/GitHub_SIAM/WPP/OT_SUOT')
torch.save(list_psnr_WPP, "list_psnr_WPP_color")
torch.save(list_lpips_WPP, "list_lpips_WPP_color")
torch.save(list_ssim_WPP, "list_ssim_WPP_color")
torch.save(list_im_rest_WPP, "list_im_rest_WPP_color")

# WPPSU

In [None]:
# This code belongs to the paper
#
# J. Hertrich, A. Houdard and C. Redenbach.
# Wasserstein Patch Prior for Image Superresolution.
# IEEE Transactions on Computational Imaging, 2022.
#
# Please cite the paper, if you use this code.
#
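# The original version of this script applied the Wasserstein Patch Prior reconstruction to the
# 2D SiC Diamonds image (Section 4.2 of the paper). This cell instead runs the wgenpatex_color_SU
# variant (called with ρ=0.01) on every image of Datasets/18_images_wd_wod_dataset/HR, once with
# the HR_wd and once with the HR_wod learning image, and records PSNR, LPIPS and SSIM for each run.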
#
import argparse
import wgenpatex_color_SU as wgenpatex
import torch
import skimage.transform
from skimage import (color, data, measure)
from skimage.metrics import structural_similarity as ssim
import numpy as np
from scipy.interpolate import griddata
import os
import lpips
import glob
import torchvision
from torchvision.transforms import Resize as tv_resize

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)
# torch.cuda.set_device raises on a machine without CUDA; DEVICE already
# falls back to the CPU there, so only pin the GPU when one is available.
if DEVICE.type == 'cuda':
    torch.cuda.set_device(2)

# images PATH
os.chdir('/home/prof/smignon/ot_patch_denoising/Wasserstein_Patch_Prior/GitHub_SIAM/Datasets/18_images_wd_wod_dataset')

# HR ground-truth images and the two families of learning images
# (HR_wd -> "with_def", HR_wod -> "without_def"). The lists are sorted so that
# entry i of each list lines up; the main loop pairs list_im_name[i] with
# list_im_modele[j][i]. NOTE(review): this assumes matching file names across
# the three folders.
list_im_name = sorted(glob.glob("HR/*.png"))
list_im_modele_with_def = sorted(glob.glob("HR_wd/*.png"))
list_im_modele_without_def = sorted(glob.glob("HR_wod/*.png"))
if not list_im_name:
    raise FileNotFoundError("no HR/*.png images found in " + os.getcwd())
if not (len(list_im_name) == len(list_im_modele_with_def) == len(list_im_modele_without_def)):
    raise ValueError(
        "HR, HR_wd and HR_wod must contain the same number of images, got "
        f"{len(list_im_name)}, {len(list_im_modele_with_def)} and {len(list_im_modele_without_def)}"
    )
list_im_modele = [list_im_modele_with_def, list_im_modele_without_def]

# Lists
#PSNR 
def PSNR(im, im_new, data_range=1.0):
    """Return the peak signal-to-noise ratio (in dB) between two images.

    Args:
        im: reference tensor.
        im_new: tensor compared against ``im``; same shape as ``im``.
        data_range: peak signal value. The default 1.0 matches images
            scaled to [0, 1].

    Returns:
        A 0-d tensor equal to ``10 * log10(data_range**2 / MSE)``; ``inf``
        when the two images are identical (MSE == 0).
    """
    # Mean over every element; works for any rank, not only (C, M, N).
    mse = torch.mean((im - im_new) ** 2)
    return 10 * torch.log10(data_range ** 2 / mse)

# Result containers. Index 0 collects the runs that learn from the HR_wd
# images, index 1 those that learn from HR_wod (j in the main loop).
list_psnr_WPPSU = [[], []]      # PSNR values
list_im_rest_WPPSU = [[], []]   # restored images
list_im_LR = [[], []]           # LR observations (not filled in this cell)

# LPIPS metric using the AlexNet backbone.
loss_fn_alex = lpips.LPIPS(net='alex')
list_lpips_WPPSU = [[], []]     # LPIPS distances

# SSIM values (previously initialized twice; once is enough).
list_ssim_WPPSU = [[], []]
# Evaluation crop: metrics ignore a 6-pixel frame around the 256x256 image.
eval_crop = torchvision.transforms.CenterCrop(256 - 12)

for i, name_im in enumerate(list_im_name):
    # j = 0 learns from the HR_wd image, j = 1 from the HR_wod image
    # (list_im_modele = [with_def, without_def]).
    for j in range(2):
        # set arguments
        args = argparse.Namespace()
        args.target_image_path = name_im
        args.scales = 2
        args.keops = True
        args.n_iter_max = 500
        args.save = True
        args.n_patches_out = 10000
        args.learn_image_path = list_im_modele[j][i]
        args.patch_size = 6
        args.lam = (6000 / (3 * args.patch_size ** 2)) * (256 ** 2 / 600 ** 2)
        args.n_iter_psi = 10
        args.n_patches_in = -1
        args.visu = False

        # define forward operator: crop the artificial boundary (if any), then
        # apply wgenpatex.gaussian_layer (kernel_size, blur_width, stride).
        blur_width = 2.0
        add_boundary = 0
        kernel_size = 16
        stride = 4
        my_layer = wgenpatex.gaussian_layer(kernel_size, blur_width, stride=stride)

        def operator(inp):
            if add_boundary == 0:
                return my_layer.forward(inp)
            return my_layer.forward(inp[:, :, add_boundary:-add_boundary, add_boundary:-add_boundary])

        # Seed both RNGs with the image index so the experiment is reproducible.
        # torch drives the noise added to the LR observation; numpy is seeded
        # too so that any numpy-based randomness is repeatable as well.
        # Re-seeding for each j gives both learning images the same noise.
        torch.manual_seed(i)
        np.random.seed(i)

        # read HR ground truth
        hr_img = wgenpatex.imread(args.target_image_path)
        hr_img = tv_resize(256, antialias=True)(hr_img)

        # The observation is simulated from the same resized HR image.
        lr_img = hr_img.clone()
        args.size = lr_img.shape[2:4]
        lr_img_ = np.zeros((3, lr_img.shape[2] + 2 * add_boundary, lr_img.shape[3] + 2 * add_boundary))
        if add_boundary > 0:
            lr_img_[:, add_boundary:-add_boundary, add_boundary:-add_boundary] = lr_img.squeeze().cpu().numpy()
        else:
            lr_img_ = lr_img.squeeze().cpu().numpy()

        # create (artificially) LR observation: forward operator + noise of std 0.01
        lr_img = operator(torch.tensor(lr_img_, dtype=torch.float, device=DEVICE).view(1, 3, lr_img_.shape[1], lr_img_.shape[2]))
        lr_img += 0.01 * torch.randn_like(lr_img)

        # build initialization by rescaling the lr observation and extending it to the boundary
        upscaled = skimage.transform.resize(lr_img.squeeze().cpu().numpy(), [3, lr_img.shape[2] * stride, lr_img.shape[3] * stride])
        # NOTE(review): `diff` is measured on the rows and reused for the columns
        # (assumes square images); for diff == 0 the slice below would be empty.
        diff = args.size[0] - upscaled.shape[1]

        init = np.zeros((3, args.size[0], args.size[1]), dtype=bool)
        init[:, diff // 2:-diff // 2, diff // 2:-diff // 2] = True
        grid_x = np.array(range(init.shape[1]))
        grid_x = np.tile(grid_x[:, np.newaxis], [1, init.shape[2]])
        grid_y = np.array(range(init.shape[2]))
        grid_y = np.tile(grid_y[np.newaxis, :], [init.shape[1], 1])
        points_x = np.reshape(grid_x[init[0, :, :]], [-1])
        points_y = np.reshape(grid_y[init[0, :, :]], [-1])
        init = init.astype(float)
        for k in range(3):
            values = np.reshape(upscaled[k, :, :], [-1])
            points = np.stack([points_x, points_y], 0).transpose()
            # Nearest-neighbour extrapolation fills the frame around the upscaled image.
            init[k, :, :] = griddata(points, values, (grid_x, grid_y), method='nearest')

        # With add_boundary == 0 the random array is discarded and init is used as is.
        init_ = np.random.uniform(size=(3, init.shape[1] + 2 * add_boundary, init.shape[2] + 2 * add_boundary))
        if add_boundary == 0:
            init_ = init
        else:
            init_[:, add_boundary:-add_boundary, add_boundary:-add_boundary] = init
        args.size = init_.shape

        # load learn img
        learn_img = wgenpatex.imread(args.learn_image_path)
        learn_img = tv_resize(256, antialias=True)(learn_img)
        print(learn_img.dtype)

        # run reconstruction
        synth_img = wgenpatex.optim_synthesis_SR(args, operator, lr_img, learn_img, args.lam, ρ=0.01, init=init_, add_boundary=add_boundary)

        if add_boundary > 0:
            synth_img = synth_img[:, :, add_boundary:-add_boundary, add_boundary:-add_boundary]
        # Restored image
        list_im_rest_WPPSU[j].append(synth_img)

        # Cropped, detached CPU copies shared by all three metrics.
        hr_eval = eval_crop(hr_img.squeeze().to('cpu')).detach()
        synth_eval = eval_crop(synth_img.squeeze().to('cpu')).detach()

        # PSNR
        list_psnr_WPPSU[j].append(PSNR(hr_eval, synth_eval))

        # LPIPS expects inputs in [-1, 1]. The images here are assumed to be in
        # [0, 1] (PSNR uses a peak of 1), so normalize=True lets LPIPS rescale
        # them. no_grad keeps the stored results free of an autograd graph.
        with torch.no_grad():
            list_lpips_WPPSU[j].append(
                loss_fn_alex(hr_eval.unsqueeze(0), synth_eval.unsqueeze(0), normalize=True))

        # SSIM on H x W x C arrays.
        # NOTE(review): data_range is taken from the prediction's own dynamic
        # range; for images in [0, 1], data_range=1.0 is the usual choice.
        img_hr = hr_eval.numpy().transpose(1, 2, 0)
        img_pred_WPPSU = synth_eval.numpy().transpose(1, 2, 0)
        list_ssim_WPPSU[j].append(ssim(img_hr, img_pred_WPPSU, data_range=img_pred_WPPSU.max() - img_pred_WPPSU.min(), multichannel=True))
        print(i)

# Persist the WPPSU metrics and restored images, one torch.save file each,
# into .../WPP/OT_SUOT (saved in the same order as before).
os.chdir('/home/prof/smignon/ot_patch_denoising/Wasserstein_Patch_Prior/GitHub_SIAM/WPP/OT_SUOT')
for _result, _file_name in (
    (list_psnr_WPPSU, "list_psnr_WPPSU_color"),
    (list_lpips_WPPSU, "list_lpips_WPPSU_color"),
    (list_ssim_WPPSU, "list_ssim_WPPSU_color"),
    (list_im_rest_WPPSU, "list_im_rest_WPPSU_color"),
):
    torch.save(_result, _file_name)
# Drop the loop variables so the notebook namespace matches the original.
del _result, _file_name