# WPP

In [2]:
import sys
  
# Prints the list of directories that the 
# interpreter will search for the required module. 
print(sys.path)

sys.path.insert(0, "/home/prof/smignon/ot_patch_denoising/Wasserstein_Patch_Prior/GitHub_SIAM")

['/home/prof/smignon/ot_patch_denoising/Wasserstein_Patch_Prior/GitHub_SIAM/WPP/OT_SUOT', '/home/prof/smignon/anaconda3/envs/WPP_color/lib/python37.zip', '/home/prof/smignon/anaconda3/envs/WPP_color/lib/python3.7', '/home/prof/smignon/anaconda3/envs/WPP_color/lib/python3.7/lib-dynload', '', '/home/prof/smignon/anaconda3/envs/WPP_color/lib/python3.7/site-packages', '/home/prof/smignon/anaconda3/envs/WPP_color/lib/python3.7/site-packages/IPython/extensions', '/home/prof/smignon/.ipython']


In [3]:
# This code belongs to the paper
#
# J. Hertrich, A. Houdard and C. Redenbach.
# Wasserstein Patch Prior for Image Superresolution.
# IEEE Transactions on Computational Imaging, 2022.
#
# Please cite the paper, if you use this code.
#
# This script applies the Wasserstein Patch Prior reconstruction onto the 2D SiC Diamonds image
# from Section 4.2 of the paper.
#
import argparse
import wgenpatex_color as wgenpatex
import torch
import skimage.transform
from skimage import (color, data, measure)
from skimage.metrics import structural_similarity as ssim
import numpy as np
from scipy.interpolate import griddata
import os
import lpips
import glob
import torchvision
from torchvision.transforms import Resize as tv_resize

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)
torch.cuda.set_device(2)

# images PATH
os.chdir('/home/prof/smignon/ot_patch_denoising/Wasserstein_Patch_Prior/GitHub_SIAM/Datasets/18_images_wa_woa_dataset')  

list_im_name   = [file for file in glob.glob("HR/*.png")]#.sort()
list_im_modele_with_a   = [file for file in glob.glob("HR_with_anomalies/*.png")]#.sort()
list_im_modele_without_a   = [file for file in glob.glob("HR_without_anomalies/*.png")]#.sort()

list_im_name.sort()
list_im_modele_with_a.sort()
list_im_modele_without_a.sort()
list_im_modele=[list_im_modele_with_a,list_im_modele_without_a]

# Lists
#PSNR 
def PSNR(im,im_new):
    C,M,N=im_new.shape
    EQM=1/(C*M*N)*torch.sum((im-im_new)**2)
    psnr=10*torch.log10(1/EQM)
    return(psnr)

list_psnr_WPP=[[],[]]

# restored images 
list_im_rest_WPP=[[],[]]

# lr img
list_im_LR=[[],[]]

# LPIPS
loss_fn_alex = lpips.LPIPS(net='alex')
list_lpips_WPP=[[],[]]

# SSIM
list_ssim_WPP=[[],[]]

# for i,name_im in enumerate(list_im_name): # RUN ALL IMAGES 
for i,name_im in enumerate([list_im_name[0]]): # TEST ON SINGLE IMAGE
    for j in range(2):
        # set arguments
        args=argparse.Namespace()
        args.target_image_path=name_im
        args.scales=2
        args.keops=True
        args.n_iter_max=500
        args.save=True
        args.n_patches_out=10000
        args.learn_image_path=list_im_modele[j][i]
        args.patch_size=6
        args.lam=(6000/(3*args.patch_size**2))*(256**2/600**2)
        args.n_iter_psi=10
        args.n_patches_in=-1
        args.visu=False


        # define forward operator
        blur_width=2.0
        add_boundary=20
        kernel_size=16
        stride=4
        my_layer=wgenpatex.gaussian_layer(kernel_size,blur_width,stride=stride)

        def operator(inp):
            if add_boundary==0:
                return my_layer.forward(inp)
            return my_layer.forward(inp[:,:,add_boundary:-add_boundary,add_boundary:-add_boundary])
        
        torch.manual_seed(i) # reproductibilité de l'expérience
        
        # read HR ground truth
        hr_img=wgenpatex.imread(args.target_image_path)
        hr_img=tv_resize(256, antialias=True)(hr_img)

        lr_img=wgenpatex.imread(args.target_image_path)
        lr_img=tv_resize(256, antialias=True)(lr_img)
        args.size=lr_img.shape[2:4]
        lr_img_=np.zeros((3,lr_img.shape[2]+2*add_boundary,lr_img.shape[3]+2*add_boundary))
        if add_boundary>0:
            lr_img_[:,add_boundary:-add_boundary,add_boundary:-add_boundary]=lr_img.squeeze().cpu().numpy()
        else:
            lr_img_=lr_img.squeeze().cpu().numpy()

        # create (artificially) LR observation
        lr_img=operator(torch.tensor(lr_img_,dtype=torch.float,device=DEVICE).view(1,3,lr_img_.shape[1],lr_img_.shape[2]))
        lr_img+=0.01*torch.randn_like(lr_img)


        # build initialization by rescaling the lr observation and extending it to the boundary
        upscaled=skimage.transform.resize(lr_img.squeeze().cpu().numpy(),[3,lr_img.shape[2]*stride,lr_img.shape[3]*stride])
        diff=args.size[0]-upscaled.shape[1]

        init=np.zeros((3,args.size[0],args.size[1]),dtype=bool)
        init[:,diff//2:-diff//2,diff//2:-diff//2]=True
        grid_x=np.array(range(init.shape[1]))
        grid_x=np.tile(grid_x[:,np.newaxis],[1,init.shape[2]])
        grid_y=np.array(range(init.shape[2]))
        grid_y=np.tile(grid_y[np.newaxis,:],[init.shape[1],1])
        points_x=np.reshape(grid_x[init[0,:,:]],[-1])
        points_y=np.reshape(grid_y[init[0,:,:]],[-1])
        init=init.astype(float)
        for k in range(3):
            values=np.reshape(upscaled[k,:,:],[-1])
            points=np.stack([points_x,points_y],0).transpose()
            init[k,:,:]=griddata(points,values,(grid_x,grid_y),method='nearest')
            
        init_=np.random.uniform(size=(3,init.shape[1]+2*add_boundary,init.shape[2]+2*add_boundary))
        if add_boundary==0:
            init_=init
        else:
            init_[:,add_boundary:-add_boundary,add_boundary:-add_boundary]=init
        args.size=init_.shape

        # load learn img
        learn_img=wgenpatex.imread(args.learn_image_path)
        learn_img=tv_resize(256, antialias=True)(learn_img)
        print(learn_img.dtype)

        # run reconstruction
        synth_img= wgenpatex.optim_synthesis_SR(args,operator,lr_img,learn_img,args.lam,init=init_,add_boundary=add_boundary)
        
        if add_boundary>0:
            synth_img=synth_img[:,:,add_boundary:-add_boundary,add_boundary:-add_boundary]

        list_im_rest_WPP[j].append(synth_img)
        # PSNR
        list_psnr_WPP[j].append(PSNR(torchvision.transforms.CenterCrop(256-12)(hr_img.squeeze().to('cpu')),
                                torchvision.transforms.CenterCrop(256-12)(synth_img.squeeze().to('cpu'))))
        # LPIPS
        list_lpips_WPP[j].append(loss_fn_alex(torchvision.transforms.CenterCrop(256-12)(hr_img.squeeze().to('cpu')).unsqueeze(0)
                                          , torchvision.transforms.CenterCrop(256-12)(synth_img.squeeze().to('cpu')).unsqueeze(0)))
        
        # SSIM
        img_hr=torchvision.transforms.CenterCrop(256-12)(hr_img.squeeze().to('cpu')).detach().numpy().transpose(1, 2, 0)
        img_pred_WPP=torchvision.transforms.CenterCrop(256-12)(synth_img.squeeze().to('cpu')).detach().numpy().transpose(1, 2, 0)
        list_ssim_WPP[j].append(ssim(img_hr, img_pred_WPP,data_range=img_pred_WPP.max() - img_pred_WPP.min(),multichannel=True))
        print(i)
        
        #wgenpatex.imsave('output_imgs_diam/synthesized_with_bound.png', synth_img)

os.chdir('/home/prof/smignon/ot_patch_denoising/Wasserstein_Patch_Prior/GitHub_SIAM/WPP/OT_SUOT') 
torch.save(list_psnr_WPP,"list_psnr_WPP_color")
torch.save(list_lpips_WPP,"list_lpips_WPP_color")
torch.save(list_ssim_WPP,"list_ssim_WPP_color")
torch.save(list_im_rest_WPP,"list_im_rest_WPP_color")

  from .autonotebook import tqdm as notebook_tqdm


cuda
cuda
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /home/prof/smignon/anaconda3/envs/WPP_color/lib/python3.7/site-packages/lpips/weights/v0.1/alex.pth
torch.float32
iteration 0 - elapsed 2s - loss = 1.7010601460933685
iteration 50 - elapsed 27s - loss = 0.16055934876203537
iteration 100 - elapsed 53s - loss = 0.10671623796224594
iteration 150 - elapsed 79s - loss = 0.11182770598679781
iteration 200 - elapsed 105s - loss = 0.11609230376780033
iteration 250 - elapsed 132s - loss = 0.11927298456430435
iteration 300 - elapsed 158s - loss = 0.12256767321377993
iteration 350 - elapsed 185s - loss = 0.12541600223630667
iteration 400 - elapsed 212s - loss = 0.12815236300230026
iteration 450 - elapsed 239s - loss = 0.13089385628700256
DONE - total time is 265s


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


0
torch.float32
iteration 0 - elapsed 0s - loss = 1.813114732503891
iteration 50 - elapsed 25s - loss = 0.1541540939360857
iteration 100 - elapsed 51s - loss = 0.08210498373955488
iteration 150 - elapsed 77s - loss = 0.0814661243930459
iteration 200 - elapsed 104s - loss = 0.08433889597654343
iteration 250 - elapsed 130s - loss = 0.08783263806253672
iteration 300 - elapsed 157s - loss = 0.09105996135622263
iteration 350 - elapsed 184s - loss = 0.093702451325953
iteration 400 - elapsed 211s - loss = 0.09561012964695692
iteration 450 - elapsed 237s - loss = 0.0973976356908679
DONE - total time is 255s
0


# WPPSU

In [3]:
# This code belongs to the paper
#
# J. Hertrich, A. Houdard and C. Redenbach.
# Wasserstein Patch Prior for Image Superresolution.
# IEEE Transactions on Computational Imaging, 2022.
#
# Please cite the paper, if you use this code.
#
# This script applies the Wasserstein Patch Prior reconstruction onto the 2D SiC Diamonds image
# from Section 4.2 of the paper.
#
import argparse
import wgenpatex_color_SU as wgenpatex
import torch
import skimage.transform
from skimage import (color, data, measure)
from skimage.metrics import structural_similarity as ssim
import numpy as np
from scipy.interpolate import griddata
import os
import lpips
import glob
import torchvision
from torchvision.transforms import Resize as tv_resize

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)
torch.cuda.set_device(2)

# images PATH
os.chdir('/home/prof/smignon/ot_patch_denoising/Wasserstein_Patch_Prior/GitHub_SIAM/Datasets/18_images_wa_woa_dataset')  

list_im_name   = [file for file in glob.glob("HR/*.png")]#.sort()
list_im_modele_with_a   = [file for file in glob.glob("HR_with_anomalies/*.png")]#.sort()
list_im_modele_without_a   = [file for file in glob.glob("HR_without_anomalies/*.png")]#.sort()

list_im_name.sort()
list_im_modele_with_a.sort()
list_im_modele_without_a.sort()
list_im_modele=[list_im_modele_with_a,list_im_modele_without_a]

# Lists
#PSNR 
def PSNR(im,im_new):
    C,M,N=im_new.shape
    EQM=1/(C*M*N)*torch.sum((im-im_new)**2)
    psnr=10*torch.log10(1/EQM)
    return(psnr)

list_psnr_WPPSU=[[],[]]

# restored images 
list_im_rest_WPPSU=[[],[]]

# lr img
list_im_LR=[[],[]]

# LPIPS
loss_fn_alex = lpips.LPIPS(net='alex')
list_lpips_WPPSU=[[],[]]

# SSIM
list_ssim_WPPSU=[[],[]]

# SSIM
list_ssim_WPPSU=[[],[]]
# for i,name_im in enumerate(list_im_name): # RUN ALL IMAGES 
for i,name_im in enumerate([list_im_name[0]]): # TEST ON SINGLE IMAGE
    for j in range(2):
        # set arguments
        args=argparse.Namespace()
        args.target_image_path=name_im
        args.scales=2
        args.keops=True
        args.n_iter_max=500
        args.save=True
        args.n_patches_out=10000
        args.learn_image_path=list_im_modele[j][i]
        args.patch_size=6
        args.lam=(6000/(3*args.patch_size**2))*(256**2/600**2)
        args.n_iter_psi=10
        args.n_patches_in=-1
        args.visu=False


        # define forward operator
        blur_width=2.0
        add_boundary=0
        kernel_size=16
        stride=4
        my_layer=wgenpatex.gaussian_layer(kernel_size,blur_width,stride=stride)

        def operator(inp):
            if add_boundary==0:
                return my_layer.forward(inp)
            return my_layer.forward(inp[:,:,add_boundary:-add_boundary,add_boundary:-add_boundary])
        
        torch.manual_seed(i) # reproductibilité de l'expérience
        
        # read HR ground truth
        hr_img=wgenpatex.imread(args.target_image_path)
        hr_img=tv_resize(256, antialias=True)(hr_img)

        lr_img=wgenpatex.imread(args.target_image_path)
        lr_img=tv_resize(256, antialias=True)(lr_img)
        args.size=lr_img.shape[2:4]
        lr_img_=np.zeros((3,lr_img.shape[2]+2*add_boundary,lr_img.shape[3]+2*add_boundary))
        if add_boundary>0:
            lr_img_[:,add_boundary:-add_boundary,add_boundary:-add_boundary]=lr_img.squeeze().cpu().numpy()
        else:
            lr_img_=lr_img.squeeze().cpu().numpy()

        # create (artificially) LR observation
        lr_img=operator(torch.tensor(lr_img_,dtype=torch.float,device=DEVICE).view(1,3,lr_img_.shape[1],lr_img_.shape[2]))
        lr_img+=0.01*torch.randn_like(lr_img)
        


        # build initialization by rescaling the lr observation and extending it to the boundary
        upscaled=skimage.transform.resize(lr_img.squeeze().cpu().numpy(),[3,lr_img.shape[2]*stride,lr_img.shape[3]*stride])
        diff=args.size[0]-upscaled.shape[1]

        init=np.zeros((3,args.size[0],args.size[1]),dtype=bool)
        init[:,diff//2:-diff//2,diff//2:-diff//2]=True
        grid_x=np.array(range(init.shape[1]))
        grid_x=np.tile(grid_x[:,np.newaxis],[1,init.shape[2]])
        grid_y=np.array(range(init.shape[2]))
        grid_y=np.tile(grid_y[np.newaxis,:],[init.shape[1],1])
        points_x=np.reshape(grid_x[init[0,:,:]],[-1])
        points_y=np.reshape(grid_y[init[0,:,:]],[-1])
        init=init.astype(float)
        for k in range(3):
            values=np.reshape(upscaled[k,:,:],[-1])
            points=np.stack([points_x,points_y],0).transpose()
            init[k,:,:]=griddata(points,values,(grid_x,grid_y),method='nearest')
            #print(init.shape)
        init_=np.random.uniform(size=(3,init.shape[1]+2*add_boundary,init.shape[2]+2*add_boundary))
        if add_boundary==0:
            init_=init
        else:
            init_[:,add_boundary:-add_boundary,add_boundary:-add_boundary]=init
        args.size=init_.shape

        # load learn img
        learn_img=wgenpatex.imread(args.learn_image_path)
        learn_img=tv_resize(256, antialias=True)(learn_img)
        print(learn_img.dtype)

       # run reconstruction
        synth_img= wgenpatex.optim_synthesis_SR(args,operator,lr_img,learn_img,args.lam,ρ=0.01,init=init_,add_boundary=add_boundary)

        # save reconstruction
        #if not os.path.isdir('output_imgs_diam'):
            #os.mkdir('output_imgs_diam')

        #wgenpatex.imsave('output_imgs_diam/synthesized_no_crop_bound.png', synth_img)
        if add_boundary>0:
            synth_img=synth_img[:,:,add_boundary:-add_boundary,add_boundary:-add_boundary]
        # Restored image 
        list_im_rest_WPPSU[j].append(synth_img)
        # PSNR
        list_psnr_WPPSU[j].append(PSNR(torchvision.transforms.CenterCrop(256-12)(hr_img.squeeze().to('cpu')),
                                torchvision.transforms.CenterCrop(256-12)(synth_img.squeeze().to('cpu'))))
        # LPIPS
        list_lpips_WPPSU[j].append(loss_fn_alex(torchvision.transforms.CenterCrop(256-12)(hr_img.squeeze().to('cpu')).unsqueeze(0)
                                          , torchvision.transforms.CenterCrop(256-12)(synth_img.squeeze().to('cpu')).unsqueeze(0)))
        
        # SSIM
        img_hr=torchvision.transforms.CenterCrop(256-12)(hr_img.squeeze().to('cpu')).detach().numpy().transpose(1, 2, 0)
        img_pred_WPPSU=torchvision.transforms.CenterCrop(256-12)(synth_img.squeeze().to('cpu')).detach().numpy().transpose(1, 2, 0)
        list_ssim_WPPSU[j].append(ssim(img_hr, img_pred_WPPSU,data_range=img_pred_WPPSU.max() - img_pred_WPPSU.min(),multichannel=True))
        print(i)

os.chdir('/home/prof/smignon/ot_patch_denoising/Wasserstein_Patch_Prior/GitHub_SIAM/WPP/OT_SUOT')
torch.save(list_psnr_WPPSU,"list_psnr_WPPSU_color")
torch.save(list_lpips_WPPSU,"list_lpips_WPPSU_color")
torch.save(list_ssim_WPPSU,"list_ssim_WPPSU_color")
torch.save(list_im_rest_WPPSU,"list_im_rest_WPPSU_color")

  from .autonotebook import tqdm as notebook_tqdm


cuda
cuda
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /home/prof/smignon/anaconda3/envs/WPP_color/lib/python3.7/site-packages/lpips/weights/v0.1/alex.pth
torch.float32
iteration 0 - elapsed 1s - loss = 0.45976512506604195
iteration 50 - elapsed 23s - loss = 0.06566169019788504
iteration 100 - elapsed 44s - loss = 0.06556749250739813
iteration 150 - elapsed 66s - loss = 0.06591822300106287
iteration 200 - elapsed 87s - loss = 0.06607427354902029
iteration 250 - elapsed 109s - loss = 0.06619521789252758
iteration 300 - elapsed 131s - loss = 0.06629220396280289
iteration 350 - elapsed 152s - loss = 0.06636372581124306
iteration 400 - elapsed 173s - loss = 0.0664227232336998
iteration 450 - elapsed 195s - loss = 0.06646533124148846
DONE - total time is 216s


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


0
torch.float32
iteration 0 - elapsed 0s - loss = 0.46226460114121437
iteration 50 - elapsed 22s - loss = 0.053009345196187496
iteration 100 - elapsed 43s - loss = 0.052411168813705444
iteration 150 - elapsed 65s - loss = 0.05260974634438753
iteration 200 - elapsed 86s - loss = 0.05266841221600771
iteration 250 - elapsed 104s - loss = 0.052698975428938866
iteration 300 - elapsed 121s - loss = 0.052756570279598236
iteration 350 - elapsed 152s - loss = 0.0527811162173748
iteration 400 - elapsed 187s - loss = 0.05277923494577408
iteration 450 - elapsed 222s - loss = 0.052807172760367393
DONE - total time is 256s
0
