# WPP

In [2]:
import sys
  
# Show the directories the interpreter searches when importing modules, then
# put the GitHub_SIAM project root in front of them (presumably where the local
# ``wgenpatex`` modules live -- confirm) so that it takes precedence.
print(sys.path)

sys.path[:0] = ["/home/prof/smignon/ot_patch_denoising/Wasserstein_Patch_Prior/GitHub_SIAM"]

['/home/prof/smignon/ot_patch_denoising/Wasserstein_Patch_Prior/GitHub_SIAM/WPP/OT_SUOT', '/home/prof/smignon/anaconda3/envs/WPP_color/lib/python37.zip', '/home/prof/smignon/anaconda3/envs/WPP_color/lib/python3.7', '/home/prof/smignon/anaconda3/envs/WPP_color/lib/python3.7/lib-dynload', '', '/home/prof/smignon/anaconda3/envs/WPP_color/lib/python3.7/site-packages', '/home/prof/smignon/anaconda3/envs/WPP_color/lib/python3.7/site-packages/IPython/extensions', '/home/prof/smignon/.ipython']


In [3]:
# This code belongs to the paper
#
# J. Hertrich, A. Houdard and C. Redenbach.
# Wasserstein Patch Prior for Image Superresolution.
# IEEE Transactions on Computational Imaging, 2022.
#
# Please cite the paper, if you use this code.
#
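# The original script applies the Wasserstein Patch Prior reconstruction to the 2D SiC Diamonds
# image from Section 4.2 of the paper; this notebook runs the same pipeline on the images of the
# 18_images_wa_woa_dataset instead.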
#
import argparse
import wgenpatex as wgenpatex
import torch
import skimage.transform
from skimage import (color, data, measure)
from skimage.metrics import structural_similarity as ssim
import numpy as np
from scipy.interpolate import griddata
import os
import lpips
import glob
from torchvision.transforms import Resize as tv_resize
import torchvision

# Compute device for the reconstruction: CUDA when available, CPU otherwise.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)

# The experiments were run on GPU index 2. Select it only when it actually
# exists; otherwise torch.cuda.set_device would raise on CPU-only hosts or
# on machines with fewer GPUs, defeating the CPU fallback above.
GPU_INDEX = 2
if torch.cuda.is_available() and torch.cuda.device_count() > GPU_INDEX:
    torch.cuda.set_device(GPU_INDEX)

# Work inside the dataset folder: every glob pattern below is relative to it.
os.chdir('/home/prof/smignon/ot_patch_denoising/Wasserstein_Patch_Prior/GitHub_SIAM/Datasets/18_images_wa_woa_dataset')  

# Sorted file lists; the reconstruction loop pairs entries of these lists by
# index, which presumably relies on matching file names across the folders.
list_im_name = sorted(glob.glob("HR/*.png"))
list_im_modele_with_a = sorted(glob.glob("HR_with_anomalies/*.png"))
list_im_modele_without_a = sorted(glob.glob("HR_without_anomalies/*.png"))

# Learning-image lists: index 0 -> with anomalies, index 1 -> without.
list_im_modele = [list_im_modele_with_a, list_im_modele_without_a]


# lists
#PSNR
def PSNR(im, im_new, max_val=1.0):
    """Return the peak signal-to-noise ratio (in dB) of ``im_new`` w.r.t. ``im``.

    The mean squared error is the sum of squared differences divided by the
    number of elements of ``im_new``. For 2-D inputs this is exactly the
    ``1/(M*N)`` normalisation, and tensors of any shape are accepted.

    Args:
        im: reference image tensor.
        im_new: estimated image tensor (broadcastable against ``im``).
        max_val: peak intensity; the default 1.0 corresponds to images
            scaled to [0, 1].

    Returns:
        0-dim tensor holding the PSNR in dB (``inf`` for identical images).
    """
    eqm = torch.sum((im - im_new) ** 2) / im_new.numel()
    return 10 * torch.log10(max_val ** 2 / eqm)

# Result containers with one sub-list per learning image, in the order of
# list_im_modele (index 0: with anomalies, index 1: without anomalies).
list_psnr_WPP = [[] for _ in range(2)]      # PSNR values

list_im_rest_WPP = [[] for _ in range(2)]   # restored images

list_im_LR = [[] for _ in range(2)]         # low-resolution observations

# LPIPS metric with the AlexNet trunk, and the scores it produces
loss_fn_alex = lpips.LPIPS(net='alex')
list_lpips_WPP = [[] for _ in range(2)]

list_ssim_WPP = [[] for _ in range(2)]      # SSIM values

# Helpers shared by the reconstruction loop below.


def _to_gray_256(img):
    """Convert an RGB image tensor to greyscale and resize its shorter side to 256 px.

    Uses the luma weights 0.2989 / 0.5870 / 0.1140 (as in MATLAB ``rgb2gray``)
    on channels 0 / 1 / 2. The indexing requires a (batch, 3, H, W) layout;
    the result is (batch, 1, H', W').
    """
    gray = 0.2989 * img[:, 0, :, :] + 0.5870 * img[:, 1, :, :] + 0.1140 * img[:, 2, :, :]
    return tv_resize(256, antialias=True)(gray.unsqueeze(1))


def _lpips_input(img2d):
    """Lift a 2-D (H, W) greyscale tensor to the (1, 3, H, W) layout LPIPS expects."""
    return img2d.unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1)


# All metrics are evaluated on a centred 244-px crop, i.e. 6 pixels are
# discarded on every side of the 256-px images.
METRIC_CROP = torchvision.transforms.CenterCrop(256 - 12)

# for i,name_im in enumerate(list_im_name): # RUN ALL IMAGES
for i, name_im in enumerate([list_im_name[0]]):  # TEST ON SINGLE IMAGE
    # j indexes list_im_modele: 0 -> learning image with anomalies,
    # 1 -> learning image without anomalies.
    for j in range(2):
        # --- reconstruction settings passed to wgenpatex.optim_synthesis ---
        args = argparse.Namespace()
        args.target_image_path = name_im
        args.scales = 2
        args.keops = True
        args.n_iter_max = 500
        args.save = True
        args.n_patches_out = 10000
        args.learn_image_path = list_im_modele[j][i]
        args.patch_size = 6
        # presumably the data-term / patch-prior trade-off weight; rescaled by
        # the patch size and by the area ratio 256^2 / 600^2
        args.lam = (6000 / args.patch_size ** 2) * (256 ** 2 / 600 ** 2)
        args.n_iter_psi = 10
        args.n_patches_in = -1
        args.visu = False

        # --- forward operator: wgenpatex.gaussian_layer with a 16-px kernel,
        # width 2.0 and stride 4 (presumably blur + 4x subsampling) ---
        blur_width = 2.0
        add_boundary = 20
        kernel_size = 16
        stride = 4
        my_layer = wgenpatex.gaussian_layer(kernel_size, blur_width, stride=stride)

        def operator(inp):
            """Apply ``my_layer`` to ``inp`` after removing ``add_boundary`` pixels per side."""
            if add_boundary == 0:
                return my_layer.forward(inp)
            return my_layer.forward(inp[:, :, add_boundary:-add_boundary, add_boundary:-add_boundary])

        # Reproducibility: the seed depends on the image only, so both learning
        # images (j = 0, 1) see the same noise realisation. numpy is seeded as
        # well because the boundary initialisation below draws from np.random.
        torch.manual_seed(i)
        np.random.seed(i)

        # --- HR ground truth (greyscale, shorter side 256 px) ---
        hr_img = _to_gray_256(wgenpatex.imread(args.target_image_path))
        args.size = hr_img.shape[2:4]
        # Zero-pad by add_boundary pixels; operator() strips that padding again.
        lr_img_ = np.pad(hr_img.squeeze().cpu().numpy(), add_boundary)

        # --- create (artificially) the LR observation: operator + 1% noise ---
        lr_img = operator(torch.tensor(lr_img_, dtype=torch.float, device=DEVICE).view(1, 1, *lr_img_.shape))
        lr_img += 0.01 * torch.randn_like(lr_img)
        list_im_LR[j].append(lr_img)
        #wgenpatex.imsave('input_imgs/lr_diam.png',lr_img)

        # --- initialisation: upsample the observation by `stride`, centre it
        # in an args.size canvas and fill the uncovered border by
        # nearest-neighbour extrapolation. For a centred rectangle this is
        # exactly edge replication, so np.pad(mode='edge') is used. Margins
        # are computed per axis, which also handles a zero margin and
        # non-square images. ---
        upscaled = skimage.transform.resize(
            lr_img.squeeze().cpu().numpy(),
            [lr_img.shape[2] * stride, lr_img.shape[3] * stride],
        )
        top = (args.size[0] - upscaled.shape[0]) // 2
        left = (args.size[1] - upscaled.shape[1]) // 2
        bottom = args.size[0] - upscaled.shape[0] - top
        right = args.size[1] - upscaled.shape[1] - left
        init = np.pad(upscaled, ((top, bottom), (left, right)), mode='edge')

        # Surround the initialisation with uniform noise on the padding border.
        if add_boundary == 0:
            init_ = init
        else:
            init_ = np.random.uniform(size=(init.shape[0] + 2 * add_boundary, init.shape[1] + 2 * add_boundary))
            init_[add_boundary:-add_boundary, add_boundary:-add_boundary] = init
        args.size = init_.shape

        # --- learning image providing the reference patches ---
        learn_img = _to_gray_256(wgenpatex.imread(args.learn_image_path))
        print(learn_img.dtype)

        # --- run reconstruction ---
        synth_img = wgenpatex.optim_synthesis(args, operator, lr_img, learn_img, args.lam, init=init_, add_boundary=add_boundary)
        if add_boundary > 0:
            synth_img = synth_img[:, :, add_boundary:-add_boundary, add_boundary:-add_boundary]
        list_im_rest_WPP[j].append(synth_img)

        # --- evaluation on the centred crop; no autograd graph is retained ---
        with torch.no_grad():
            hr_crop = METRIC_CROP(hr_img.detach().squeeze().to('cpu'))
            synth_crop = METRIC_CROP(synth_img.detach().squeeze().to('cpu'))
            list_psnr_WPP[j].append(PSNR(hr_crop, synth_crop))
            # LPIPS expects inputs in [-1, 1]. The images are assumed to be in
            # [0, 1] (consistent with PSNR's peak value of 1), so
            # normalize=True rescales them first.
            list_lpips_WPP[j].append(
                loss_fn_alex(_lpips_input(hr_crop), _lpips_input(synth_crop), normalize=True)
            )

        # SSIM uses the nominal [0, 1] intensity range rather than the
        # dynamic range of the reconstruction being scored.
        img_hr = hr_crop.numpy()
        img_pred_WPP = synth_crop.numpy()
        list_ssim_WPP[j].append(ssim(img_hr, img_pred_WPP, data_range=1.0))
        print(i)

# Switch to the OT_SUOT folder and serialise each WPP result list there,
# one file per list, named after the list.
os.chdir('/home/prof/smignon/ot_patch_denoising/Wasserstein_Patch_Prior/GitHub_SIAM/WPP/OT_SUOT') 
for _result, _fname in (
    (list_psnr_WPP, "list_psnr_WPP"),
    (list_lpips_WPP, "list_lpips_WPP"),
    (list_ssim_WPP, "list_ssim_WPP"),
    (list_im_rest_WPP, "list_im_rest_WPP"),
):
    torch.save(_result, _fname)

  from .autonotebook import tqdm as notebook_tqdm


cuda
cuda
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /home/prof/smignon/anaconda3/envs/WPP_color/lib/python3.7/site-packages/lpips/weights/v0.1/alex.pth




torch.float32




iteration 0 - elapsed 0s - loss = 0.44148455901630224, wloss = 0.4005368944872171, oloss = 0.04094766452908516




iteration 50 - elapsed 9s - loss = 0.04654542612284285, wloss = 0.043637916207312166, oloss = 0.0029075099155306816




iteration 100 - elapsed 20s - loss = 0.03125158735773326, wloss = 0.028111100813735135, oloss = 0.003140486543998122




iteration 150 - elapsed 31s - loss = 0.030907877694779984, wloss = 0.027623163044417254, oloss = 0.00328471465036273




iteration 200 - elapsed 43s - loss = 0.03210480267442506, wloss = 0.02882691904231649, oloss = 0.003277883632108569




iteration 250 - elapsed 54s - loss = 0.033695402495496296, wloss = 0.030404930408977293, oloss = 0.003290472086519003




iteration 300 - elapsed 66s - loss = 0.03513225962653621, wloss = 0.03183794769704207, oloss = 0.0032943119294941425




iteration 350 - elapsed 78s - loss = 0.03659563499765284, wloss = 0.033288613774644205, oloss = 0.0033070212230086327




iteration 400 - elapsed 90s - loss = 0.037704988060468736, wloss = 0.034357910886356535, oloss = 0.0033470771741122007




iteration 450 - elapsed 103s - loss = 0.03873749871503662, wloss = 0.03535787926752221, oloss = 0.003379619447514415
DONE - total time is 115s


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


0
torch.float32




iteration 0 - elapsed 0s - loss = 0.48349132731202993, wloss = 0.44254366278294477, oloss = 0.04094766452908516




iteration 50 - elapsed 9s - loss = 0.04594942640699351, wloss = 0.04332298848032856, oloss = 0.0026264379266649485




iteration 100 - elapsed 20s - loss = 0.02415741889215539, wloss = 0.02156030924468766, oloss = 0.0025971096474677324




iteration 150 - elapsed 31s - loss = 0.02388948324880147, wloss = 0.021335330855087875, oloss = 0.0025541523937135935




iteration 200 - elapsed 43s - loss = 0.024628483321876615, wloss = 0.022113825058923453, oloss = 0.002514658262953162




iteration 250 - elapsed 54s - loss = 0.025605161593844628, wloss = 0.023124555045356487, oloss = 0.00248060654848814




iteration 300 - elapsed 66s - loss = 0.026567173671594446, wloss = 0.024108244587353056, oloss = 0.0024589290842413902




iteration 350 - elapsed 78s - loss = 0.027392507313507508, wloss = 0.024938914940449308, oloss = 0.0024535923730582




iteration 400 - elapsed 90s - loss = 0.028234400252600267, wloss = 0.025773860678796723, oloss = 0.002460539573803544




iteration 450 - elapsed 103s - loss = 0.028639746014761158, wloss = 0.026169171673046776, oloss = 0.002470574341714382
DONE - total time is 115s
0


# WPPSU

In [4]:
# Adapted from
#
# J. Hertrich, A. Houdard and C. Redenbach.
# Wasserstein Patch Prior for Image Superresolution.
# IEEE Transactions on Computational Imaging, 2022.
#
#

import argparse
import wgenpatex_SU as wgenpatex
import torch
import skimage.transform
import numpy as np
from scipy.interpolate import griddata
import os
import lpips
import glob
from torchvision.transforms import Resize as tv_resize
from skimage import (color, data, measure)
from skimage.metrics import structural_similarity as ssim
import torchvision 

# Compute device for the reconstruction: CUDA when available, CPU otherwise.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)

# The experiments were run on GPU index 2. Select it only when it actually
# exists; otherwise torch.cuda.set_device would raise on CPU-only hosts or
# on machines with fewer GPUs, defeating the CPU fallback above.
GPU_INDEX = 2
if torch.cuda.is_available() and torch.cuda.device_count() > GPU_INDEX:
    torch.cuda.set_device(GPU_INDEX)

# Work inside the dataset folder: every glob pattern below is relative to it.
os.chdir('/home/prof/smignon/ot_patch_denoising/Wasserstein_Patch_Prior/GitHub_SIAM/Datasets/18_images_wa_woa_dataset')  

# Sorted file lists; the reconstruction loop pairs entries of these lists by
# index, which presumably relies on matching file names across the folders.
list_im_name = sorted(glob.glob("HR/*.png"))
list_im_modele_with_a = sorted(glob.glob("HR_with_anomalies/*.png"))
list_im_modele_without_a = sorted(glob.glob("HR_without_anomalies/*.png"))

# Learning-image lists: index 0 -> with anomalies, index 1 -> without.
list_im_modele = [list_im_modele_with_a, list_im_modele_without_a]

# lists
#PSNR
def PSNR(im, im_new, max_val=1.0):
    """Return the peak signal-to-noise ratio (in dB) of ``im_new`` w.r.t. ``im``.

    The mean squared error is the sum of squared differences divided by the
    number of elements of ``im_new``. For 2-D inputs this is exactly the
    ``1/(M*N)`` normalisation, and tensors of any shape are accepted.

    Args:
        im: reference image tensor.
        im_new: estimated image tensor (broadcastable against ``im``).
        max_val: peak intensity; the default 1.0 corresponds to images
            scaled to [0, 1].

    Returns:
        0-dim tensor holding the PSNR in dB (``inf`` for identical images).
    """
    eqm = torch.sum((im - im_new) ** 2) / im_new.numel()
    return 10 * torch.log10(max_val ** 2 / eqm)

# Result containers with one sub-list per learning image, in the order of
# list_im_modele (index 0: with anomalies, index 1: without anomalies).
list_psnr_WPPSU = [[] for _ in range(2)]      # PSNR values

list_im_rest_WPPSU = [[] for _ in range(2)]   # restored images

list_im_LR = [[] for _ in range(2)]           # low-resolution observations

# LPIPS metric with the AlexNet trunk ("best forward scores"), and its scores
loss_fn_alex = lpips.LPIPS(net='alex')
list_lpips_WPPSU = [[] for _ in range(2)]

list_ssim_WPPSU = [[] for _ in range(2)]      # SSIM values

# Helpers shared by the reconstruction loop below.


def _to_gray_256(img):
    """Convert an RGB image tensor to greyscale and resize its shorter side to 256 px.

    Uses the luma weights 0.2989 / 0.5870 / 0.1140 (as in MATLAB ``rgb2gray``)
    on channels 0 / 1 / 2. The indexing requires a (batch, 3, H, W) layout;
    the result is (batch, 1, H', W').
    """
    gray = 0.2989 * img[:, 0, :, :] + 0.5870 * img[:, 1, :, :] + 0.1140 * img[:, 2, :, :]
    return tv_resize(256, antialias=True)(gray.unsqueeze(1))


def _lpips_input(img2d):
    """Lift a 2-D (H, W) greyscale tensor to the (1, 3, H, W) layout LPIPS expects."""
    return img2d.unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1)


# All metrics are evaluated on a centred 244-px crop, i.e. 6 pixels are
# discarded on every side of the 256-px images.
METRIC_CROP = torchvision.transforms.CenterCrop(256 - 12)

# for i,name_im in enumerate(list_im_name): # RUN ALL IMAGES
for i, name_im in enumerate([list_im_name[0]]):  # TEST ON SINGLE IMAGE
    # j indexes list_im_modele: 0 -> learning image with anomalies,
    # 1 -> learning image without anomalies.
    for j in range(2):
        # --- reconstruction settings passed to wgenpatex_SU.optim_synthesis ---
        args = argparse.Namespace()
        args.target_image_path = name_im
        args.scales = 2
        args.keops = True
        args.n_iter_max = 500
        args.save = True
        # NOTE(review): the WPP cell uses n_patches_out=10000; confirm that 10
        # is intended here (the WPPSU loss log barely decreases).
        args.n_patches_out = 10
        args.learn_image_path = list_im_modele[j][i]
        args.patch_size = 6
        # presumably the data-term / patch-prior trade-off weight; rescaled by
        # the patch size and by the area ratio 256^2 / 600^2
        args.lam = (6000 / args.patch_size ** 2) * (256 ** 2 / 600 ** 2)
        args.n_iter_psi = 10
        args.n_patches_in = -1
        args.visu = False

        # --- forward operator: gaussian_layer with a 16-px kernel, width 2.0
        # and stride 4 (presumably blur + 4x subsampling) ---
        blur_width = 2.0
        # NOTE(review): the WPP cell pads by 20 px; confirm 0 is intended here.
        add_boundary = 0
        kernel_size = 16
        stride = 4
        my_layer = wgenpatex.gaussian_layer(kernel_size, blur_width, stride=stride)

        def operator(inp):
            """Apply ``my_layer`` to ``inp`` after removing ``add_boundary`` pixels per side."""
            if add_boundary == 0:
                return my_layer.forward(inp)
            return my_layer.forward(inp[:, :, add_boundary:-add_boundary, add_boundary:-add_boundary])

        # Reproducibility: the seed depends on the image only, so both learning
        # images (j = 0, 1) see the same noise realisation. numpy is seeded as
        # well because the boundary initialisation below draws from np.random.
        torch.manual_seed(i)
        np.random.seed(i)

        # --- HR ground truth (greyscale, shorter side 256 px) ---
        hr_img = _to_gray_256(wgenpatex.imread(args.target_image_path))
        args.size = hr_img.shape[2:4]
        # Zero-pad by add_boundary pixels; operator() strips that padding again.
        lr_img_ = np.pad(hr_img.squeeze().cpu().numpy(), add_boundary)

        # --- create (artificially) the LR observation: operator + 1% noise ---
        lr_img = operator(torch.tensor(lr_img_, dtype=torch.float, device=DEVICE).view(1, 1, *lr_img_.shape))
        lr_img += 0.01 * torch.randn_like(lr_img)
        list_im_LR[j].append(lr_img)
        #wgenpatex.imsave('input_imgs/lr_diam.png',lr_img)

        # --- initialisation: upsample the observation by `stride`, centre it
        # in an args.size canvas and fill the uncovered border by
        # nearest-neighbour extrapolation. For a centred rectangle this is
        # exactly edge replication, so np.pad(mode='edge') is used. Margins
        # are computed per axis, which also handles a zero margin and
        # non-square images. ---
        upscaled = skimage.transform.resize(
            lr_img.squeeze().cpu().numpy(),
            [lr_img.shape[2] * stride, lr_img.shape[3] * stride],
        )
        top = (args.size[0] - upscaled.shape[0]) // 2
        left = (args.size[1] - upscaled.shape[1]) // 2
        bottom = args.size[0] - upscaled.shape[0] - top
        right = args.size[1] - upscaled.shape[1] - left
        init = np.pad(upscaled, ((top, bottom), (left, right)), mode='edge')

        # Surround the initialisation with uniform noise on the padding border.
        if add_boundary == 0:
            init_ = init
        else:
            init_ = np.random.uniform(size=(init.shape[0] + 2 * add_boundary, init.shape[1] + 2 * add_boundary))
            init_[add_boundary:-add_boundary, add_boundary:-add_boundary] = init
        args.size = init_.shape

        # --- learning image providing the reference patches ---
        learn_img = _to_gray_256(wgenpatex.imread(args.learn_image_path))
        print(learn_img.dtype)

        # --- run reconstruction (ρ is passed through to wgenpatex_SU) ---
        synth_img = wgenpatex.optim_synthesis(args, operator, lr_img, learn_img, args.lam, ρ=0.01, init=init_, add_boundary=add_boundary)
        if add_boundary > 0:
            synth_img = synth_img[:, :, add_boundary:-add_boundary, add_boundary:-add_boundary]
        list_im_rest_WPPSU[j].append(synth_img)

        # --- evaluation on the centred crop; no autograd graph is retained ---
        with torch.no_grad():
            hr_crop = METRIC_CROP(hr_img.detach().squeeze().to('cpu'))
            synth_crop = METRIC_CROP(synth_img.detach().squeeze().to('cpu'))
            list_psnr_WPPSU[j].append(PSNR(hr_crop, synth_crop))
            # LPIPS expects inputs in [-1, 1]. The images are assumed to be in
            # [0, 1] (consistent with PSNR's peak value of 1), so
            # normalize=True rescales them first.
            list_lpips_WPPSU[j].append(
                loss_fn_alex(_lpips_input(hr_crop), _lpips_input(synth_crop), normalize=True)
            )

        # SSIM uses the nominal [0, 1] intensity range rather than the
        # dynamic range of the reconstruction being scored.
        img_hr = hr_crop.numpy()
        img_pred_WPPSU = synth_crop.numpy()
        list_ssim_WPPSU[j].append(ssim(img_hr, img_pred_WPPSU, data_range=1.0))

        print(i)

# Switch to the OT_SUOT folder and serialise each WPPSU result list there,
# one file per list, named after the list.
os.chdir('/home/prof/smignon/ot_patch_denoising/Wasserstein_Patch_Prior/GitHub_SIAM/WPP/OT_SUOT') 
for _result, _fname in (
    (list_psnr_WPPSU, "list_psnr_WPPSU"),
    (list_lpips_WPPSU, "list_lpips_WPPSU"),
    (list_im_rest_WPPSU, "list_im_rest_WPPSU"),
    (list_ssim_WPPSU, "list_ssim_WPPSU"),
):
    torch.save(_result, _fname)

cuda
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /home/prof/smignon/anaconda3/envs/WPP_color/lib/python3.7/site-packages/lpips/weights/v0.1/alex.pth
torch.float32




iteration 0 - elapsed 0s - loss = 0.34382583311526105, wloss = 0.3028781685861759, oloss = 0.04094766452908516




iteration 50 - elapsed 9s - loss = 0.3154049883596599, wloss = 0.2895002025179565, oloss = 0.025904785841703415




iteration 100 - elapsed 17s - loss = 0.31282070570159703, wloss = 0.28705770254600793, oloss = 0.025763003155589104




iteration 150 - elapsed 26s - loss = 0.31316940812394023, wloss = 0.28739070473238826, oloss = 0.02577870339155197




iteration 200 - elapsed 35s - loss = 0.31577573635149747, wloss = 0.29001202166546136, oloss = 0.02576371468603611




iteration 250 - elapsed 44s - loss = 0.3123007846297696, wloss = 0.28661222953815013, oloss = 0.02568855509161949




iteration 300 - elapsed 53s - loss = 0.3148857973283157, wloss = 0.2891425866400823, oloss = 0.025743210688233376




iteration 350 - elapsed 62s - loss = 0.31289385887794197, wloss = 0.2871284845750779, oloss = 0.025765374302864075




iteration 400 - elapsed 70s - loss = 0.31422740733250976, wloss = 0.2885154443792999, oloss = 0.025711962953209877




iteration 450 - elapsed 79s - loss = 0.31279169500339776, wloss = 0.2870479404227808, oloss = 0.02574375458061695
DONE - total time is 88s




0
torch.float32




iteration 0 - elapsed 0s - loss = 0.3510368379938882, wloss = 0.310089173464803, oloss = 0.04094766452908516




iteration 50 - elapsed 8s - loss = 0.31653955433284864, wloss = 0.2914858855656348, oloss = 0.02505366876721382




iteration 100 - elapsed 17s - loss = 0.31465076887980103, wloss = 0.28950416715815663, oloss = 0.0251466017216444




iteration 150 - elapsed 26s - loss = 0.3163250861107372, wloss = 0.2911794920801185, oloss = 0.025145594030618668




iteration 200 - elapsed 35s - loss = 0.313397447578609, wloss = 0.28824157547205687, oloss = 0.025155872106552124




iteration 250 - elapsed 44s - loss = 0.3143176601151936, wloss = 0.28916387230856344, oloss = 0.025153787806630135




iteration 300 - elapsed 52s - loss = 0.31415976368589327, wloss = 0.28902642213506624, oloss = 0.025133341550827026




iteration 350 - elapsed 61s - loss = 0.3170935559319332, wloss = 0.2919520299183205, oloss = 0.025141526013612747




iteration 400 - elapsed 70s - loss = 0.31541887333150953, wloss = 0.29028618743177503, oloss = 0.025132685899734497




iteration 450 - elapsed 79s - loss = 0.31778490875149146, wloss = 0.2926421031006612, oloss = 0.02514280565083027
DONE - total time is 87s
0
