In [2]:
import torch
from pykeops.torch import LazyTensor
import numpy as np
from torchvision.transforms import Resize as tv_resize
from PIL import Image
import time
from torch import nn
import math
import skimage.io as io
import os
import skimage.transform
import skimage.metrics as sm
from skimage.metrics import structural_similarity as ssim
from scipy.interpolate import griddata
import lpips
import torchvision
import glob
import argparse
import wgenpatex as wgenpatex
from ROT_RSUOT_RUOT import ROT,RSUOT

  from .autonotebook import tqdm as notebook_tqdm


cuda


In [1]:
import sys

# NOTE(review): this cell executed first (In [1]) but sits below the imports
# cell (In [2]); move it to the top so Restart & Run All reproduces that order.

# Show the directories the interpreter searches for modules (handy when a
# project import such as `wgenpatex` or `ROT_RSUOT_RUOT` fails).
print(sys.path)

# Project root, presumably needed so project-local modules can be imported.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
PROJECT_ROOT = "/home/prof/smignon/ot_patch_denoising/Wasserstein_Patch_Prior/GitHub_SIAM"

# Guarded so re-running the cell does not stack duplicate entries on sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

['/home/prof/smignon/ot_patch_denoising/Wasserstein_Patch_Prior/GitHub_SIAM/WPP/ROT_RSUOT', '/home/prof/smignon/anaconda3/envs/WPP_color/lib/python37.zip', '/home/prof/smignon/anaconda3/envs/WPP_color/lib/python3.7', '/home/prof/smignon/anaconda3/envs/WPP_color/lib/python3.7/lib-dynload', '', '/home/prof/smignon/anaconda3/envs/WPP_color/lib/python3.7/site-packages', '/home/prof/smignon/anaconda3/envs/WPP_color/lib/python3.7/site-packages/IPython/extensions', '/home/prof/smignon/.ipython']


### Helper functions: image metrics, display and the super-resolution solver

In [3]:
def PSNR(im, im_new, data_range=1.0):
    '''
    Compute the peak signal-to-noise ratio (in dB) between two images.

    PSNR = 10 * log10(data_range**2 / MSE), where MSE is the mean of the
    squared pixel differences over every element of the tensors.

    Parameters
    ----------
    im : torch.Tensor
        Reference image, of any shape (e.g. (H, W) or (1, 1, H, W)).
    im_new : torch.Tensor
        Image to evaluate; same shape as ``im``.
    data_range : float, optional
        Peak signal value. The default 1.0 corresponds to images in [0, 1].

    Returns
    -------
    torch.Tensor
        0-dim tensor holding the PSNR; ``inf`` when the images are identical.

    Raises
    ------
    ValueError
        If the images contain no pixels.
    '''
    err = im - im_new
    n_pixels = err.numel()
    if n_pixels == 0:
        raise ValueError("PSNR is undefined for empty images")
    # Sum / count (rather than torch.mean) also accepts integer tensors.
    mse = torch.sum(err ** 2) / n_pixels
    psnr = 10 * torch.log10(data_range ** 2 / mse)
    return psnr

def show(im_deb, col=False):
    '''
    Display an image tensor inline in the notebook.

    Values are clipped to [0, 1], scaled to 0..255 and cast to uint8 before
    being passed to PIL. The input tensor itself is not modified.

    Parameters
    ----------
    im_deb : torch.Tensor
        (H, W) grey-level image, or (C, H, W) when ``col`` is True. It may
        live on any device; it is copied to the CPU for display.
    col : bool, optional
        If True, move the leading channel axis last (H, W, C) as PIL expects.
    '''
    # .cpu() is required: Tensor.numpy() raises for CUDA tensors.
    img = im_deb.detach().cpu().clamp(0, 1)
    if col:
        img = img.permute(1, 2, 0)
    # Cast truncates toward zero, matching the original 8-bit conversion.
    img = (255 * img).to(torch.uint8)
    display(Image.fromarray(img.numpy()))
    
# LPIPS perceptual distance (AlexNet backbone), used below as an image-quality score.
# NOTE(review): LPIPS expects inputs scaled to [-1, 1] unless it is called with
# normalize=True; the images scored in this notebook appear to be in [0, 1]
# (PSNR assumes a peak of 1 and `show` clips to [0, 1]) — confirm.
loss_fn_alex = lpips.LPIPS(net="alex")

    
def sinkhorn_super_resolution(operator, high_resolution_image, low_resolution_image, init,
                              loss_fct, lbd, niters, patch_size,
                              n_patches_out, device, verbose, lr):
    '''
    Super-resolve an image with a two-scale optimal-transport patch prior.

    The high-resolution estimate x is optimised with Adam on

        0.5 * loss_fct(patches(x_0), patches(ref_0))
      + 0.5 * loss_fct(patches(x_1), patches(ref_1))
      + 0.5 * lbd * ||operator(x) - y||^2

    where level 0 / 1 are the first two levels of a Gaussian pyramid built
    with ``wgenpatex.create_gaussian_pyramid``, ``ref`` is
    ``high_resolution_image`` and ``y`` is ``low_resolution_image``.

    Parameters
    ----------
    operator : callable
        Forward (degradation) model; ``operator(x)`` must match ``y`` in shape.
    high_resolution_image : torch.Tensor
        Reference image whose patches act as the prior.
    low_resolution_image : torch.Tensor
        Observation ``y``.
    init : numpy.ndarray
        2-D initial estimate; x is created from it with shape (1, 1, H, W).
    loss_fct : callable
        ``loss_fct(nuX, nuY, dual_init) -> (loss, dual)``. The returned dual
        variable is passed back on the next iteration (warm start); ``None``
        is passed on the first iteration.
    lbd : float
        Weight of the data-fidelity term.
    niters : int
        Number of Adam iterations.
    patch_size : int
        Patch size given to ``wgenpatex.patch_extractor``.
    n_patches_out : int
        Number of reference patches requested from the patch extractor at
        each scale.
    device : torch.device or str
        Device on which x is created and optimised.
    verbose : bool
        If True, print the loss terms at every iteration.
    lr : float
        Adam learning rate.

    Returns
    -------
    torch.Tensor
        The optimised estimate x, shape (1, 1, H, W), still a leaf tensor with
        ``requires_grad=True``.
    '''

    def _sync():
        # Wait for queued CUDA work so the wall-clock timing is meaningful.
        # Guarded because torch.cuda.synchronize() fails on CPU-only hosts.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    # Gaussian pyramid parameters (shared by the reference and by x)
    gaussian_kernel_size = 4
    gaussian_std = 1
    stride = 2
    n_scales = 2

    # Pyramid operators for the reference image (target_downsampler) and for x (x_downsampler)
    target_downsampler = wgenpatex.create_gaussian_pyramid(gaussian_kernel_size, gaussian_std, n_scales+1, stride, pad=False, dim=2)
    x_downsampler = wgenpatex.create_gaussian_pyramid(gaussian_kernel_size, gaussian_std, n_scales+1, stride, pad=False, dim=2)

    # Optimisation variable x, shaped (1, 1, H, W)
    x = torch.tensor(init[np.newaxis, np.newaxis, :, :], dtype=torch.float, device=device).requires_grad_()
    y = low_resolution_image

    # Run the pyramid on the reference; its levels are read back below via [k].down_img
    target_downsampler(high_resolution_image)

    # Patch extractors for the reference (target_im2pat) and for x (input_im2pat)
    target_im2pat = wgenpatex.patch_extractor(patch_size, pad=False, center=False, dim=2)
    input_im2pat = wgenpatex.patch_extractor(patch_size, pad=False, center=False, dim=2)

    # Reference patches at scales 0 and 1 (fixed for the whole optimisation)
    nuM = target_im2pat(target_downsampler[0].down_img, n_patches_out).contiguous()
    nuM_ds = target_im2pat(target_downsampler[1].down_img, n_patches_out).contiguous()

    optimizer = torch.optim.Adam([x], lr=lr)

    # Dual variables returned by loss_fct, reused as warm starts
    fg_i = None
    fg_i_ds = None

    _sync()
    t = time.time()

    for i in range(niters):
        optimizer.zero_grad()

        # Recompute the Gaussian pyramid of the current estimate
        x_downsampler(x)

        # OT term at scale 0
        nuX = input_im2pat(x_downsampler[0].down_img, -1, split=[1, 0]).contiguous()
        l1, fg_i = loss_fct(nuX, nuM, fg_i)

        # OT term at scale 1
        nuX_ds = input_im2pat(x_downsampler[1].down_img, -1, split=[1, 0]).contiguous()
        l1_ds, fg_i_ds = loss_fct(nuX_ds, nuM_ds, fg_i_ds)

        # Data-fidelity term
        l2 = torch.sum((operator(x) - y) ** 2)

        # Total objective (same arithmetic order as the reported experiments)
        l = 1/2*l1+1/2*l1_ds + 1/2*lbd*l2

        l.backward()

        if verbose:
            print('OT:', "{:.10f}".format(0.5*(l1+l1_ds)),'attache:', "{:.10f}".format(0.5*lbd*l2),'l:', "{:.10f}".format(l))
            print('i=',i)
            print('-------------------------------------------------------------------')

        optimizer.step()

    _sync()
    print('DONE - total time is '+str(int(time.time()-t))+'s')

    return x

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /home/prof/smignon/anaconda3/envs/WPP_color/lib/python3.7/site-packages/lpips/weights/v0.1/alex.pth


### Results of Table 1 and Figure 5

In [6]:
# --- Device -------------------------------------------------------------------
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# GPU 2 was the card used on the original machine. Only select it when it
# exists, so the cell also runs on CPU-only or smaller multi-GPU hosts.
GPU_INDEX = 2
if DEVICE.type == 'cuda' and GPU_INDEX < torch.cuda.device_count():
    torch.cuda.set_device(GPU_INDEX)

# --- Data ---------------------------------------------------------------------
# NOTE(review): absolute, machine-specific paths — adjust before running elsewhere.
DATASET_DIR = '/home/prof/smignon/ot_patch_denoising/Wasserstein_Patch_Prior/GitHub_SIAM/Datasets/18_images_wa_woa_dataset'
RESULTS_DIR = '/home/prof/smignon/ot_patch_denoising/Wasserstein_Patch_Prior/GitHub_SIAM/WPP/ROT_RSUOT/'

# The image paths below are relative to DATASET_DIR.
os.chdir(DATASET_DIR)

# Sorted so that the i-th ground truth is paired with the i-th reference images.
list_im_name = sorted(glob.glob("HR/*.png"))
list_im_modele_with_a = sorted(glob.glob("HR_with_anomalies/*.png"))
list_im_modele_without_a = sorted(glob.glob("HR_without_anomalies/*.png"))
# j = 0: references with defects, j = 1: references without defects
list_im_modele = [list_im_modele_with_a, list_im_modele_without_a]

assert list_im_name, "no ground-truth images found in " + os.path.join(DATASET_DIR, "HR")
assert all(len(refs) == len(list_im_name) for refs in list_im_modele), \
    "ground-truth and reference images cannot be paired by index"

# --- Result containers: every list is [with defects, without defects] ---------
# PSNR
list_psnr_ROT = [[], []]
list_psnr_RSUOT = [[], []]
# LPIPS
list_lpips_ROT = [[], []]
list_lpips_RSUOT = [[], []]
# SSIM
list_ssim_ROT = [[], []]
list_ssim_RSUOT = [[], []]
# Restored images
list_im_restored_ROT = [[], []]
list_im_restored_RSUOT = [[], []]
# Loss histories (initialised but not filled in this cell)
list_Loss_ROT = [[], []]
list_Loss_RSUOT = [[], []]
# Low-resolution observations
list_im_LR = [[], []]
# High-resolution ground truths
list_im_HR = [[], []]

# --- Super-resolution settings ------------------------------------------------
# Weight of the data-fidelity term
lamb = (36/6000)*(600**2/256**2)

# True processes every image; False runs a quick single-image test.
RUN_ALL_IMAGES = False
images_to_run = list_im_name if RUN_ALL_IMAGES else list_im_name[:1]

# Forward operator: wgenpatex.gaussian_layer (presumably a Gaussian blur
# followed by subsampling by `stride`). It does not depend on the image,
# so it is built once.
blur_width = 2.0
add_boundary = 0  # no artificial boundary
kernel_size = 16
stride = 4
my_layer = wgenpatex.gaussian_layer(kernel_size, blur_width, stride=stride)


def operator(inp):
    '''Apply the forward model, first cropping `add_boundary` pixels per side.'''
    if add_boundary == 0:
        return my_layer.forward(inp)
    return my_layer.forward(inp[:, :, add_boundary:-add_boundary, add_boundary:-add_boundary])


# Transport losses; the second returned value is passed back as a warm start.
def R_OT(x, y, fg_init):
    return ROT(x, y, ε=1e-4, fg_init=fg_init, nb_it=10, dev=DEVICE)


def RSU_OT(x, y, f_init):
    return RSUOT(x, y, ε=1e-4, ρ=0.01, f_init=f_init, nb_it=10, dev=DEVICE)


def _to_gray_256(rgb):
    '''Grey-level version of a (N, C>=3, H, W) batch, smaller side resized to 256.'''
    gray = 0.2989 * rgb[:, 0, :, :] + 0.5870 * rgb[:, 1, :, :] + 0.1140 * rgb[:, 2, :, :]
    return tv_resize(256, antialias=True)(gray.unsqueeze(1))


# --- Experiment ---------------------------------------------------------------
for i, name_im in enumerate(images_to_run):
    for j in range(2):
        # Arguments of this run
        args = argparse.Namespace()
        args.target_image_path = name_im
        args.learn_image_path = list_im_modele[j][i]
        args.patch_size = 6
        args.n_patches_out = 10000
        args.n_patches_in = -1

        # Same seed for j = 0 and j = 1, so both references face the same noise realisation
        torch.manual_seed(i)

        # Grey-level HR ground truth
        hr_img = _to_gray_256(wgenpatex.imread(args.target_image_path))
        args.size = hr_img.shape[2:4]

        # Synthetic LR observation: forward operator + Gaussian noise (std 0.01)
        if add_boundary > 0:
            lr_img_ = np.zeros((hr_img.shape[2] + 2*add_boundary, hr_img.shape[3] + 2*add_boundary))
            lr_img_[add_boundary:-add_boundary, add_boundary:-add_boundary] = hr_img.squeeze().cpu().numpy()
        else:
            lr_img_ = hr_img.squeeze().cpu().numpy()
        lr_img = operator(torch.tensor(lr_img_, dtype=torch.float, device=DEVICE).view(1, 1, lr_img_.shape[0], lr_img_.shape[1]))
        lr_img += 0.01*torch.randn_like(lr_img)

        # Initialisation: upscale the LR observation, then extend it to the
        # border of the HR grid by nearest-neighbour interpolation.
        upscaled = skimage.transform.resize(lr_img.squeeze().cpu().numpy(), [lr_img.shape[2]*stride, lr_img.shape[3]*stride])
        diff = args.size[0] - upscaled.shape[0]        # rows
        diff_cols = args.size[1] - upscaled.shape[1]
        # Pixels covered by the centred upscaled image. Bounds are computed per
        # axis and without negative slice ends, so this also works for
        # non-square images and when the size difference is 0.
        row0, col0 = diff // 2, diff_cols // 2
        known = np.zeros(args.size, dtype=bool)
        known[row0:row0 + upscaled.shape[0], col0:col0 + upscaled.shape[1]] = True
        grid_x = np.tile(np.arange(known.shape[0])[:, np.newaxis], [1, known.shape[1]])
        grid_y = np.tile(np.arange(known.shape[1])[np.newaxis, :], [known.shape[0], 1])
        # Boolean indexing and reshape are both row-major, so points and values line up.
        points_x = grid_x[known]
        points_y = grid_y[known]
        points = np.stack([points_x, points_y], 1)
        values = np.reshape(upscaled, [-1])
        init = griddata(points, values, (grid_x, grid_y), method='nearest')
        if add_boundary == 0:
            init_ = init
        else:
            # Artificial boundary filled with uniform noise
            init_ = np.random.uniform(size=(init.shape[0] + 2*add_boundary, init.shape[1] + 2*add_boundary))
            init_[add_boundary:-add_boundary, add_boundary:-add_boundary] = init
        args.size = init_.shape

        # HR reference image providing the patch prior
        learn_img = _to_gray_256(wgenpatex.imread(args.learn_image_path))

        # Super-resolution with ROT
        im_deb_ROT = sinkhorn_super_resolution(operator=operator, high_resolution_image=learn_img,
                                               low_resolution_image=lr_img, init=init_,
                                               loss_fct=R_OT, lbd=lamb,
                                               niters=500, patch_size=args.patch_size,
                                               n_patches_out=args.n_patches_out, device=DEVICE, verbose=False, lr=0.01)
        # Super-resolution with RSUOT
        im_deb_RSUOT = sinkhorn_super_resolution(operator=operator, high_resolution_image=learn_img,
                                                 low_resolution_image=lr_img, init=init_,
                                                 loss_fct=RSU_OT, lbd=lamb,
                                                 niters=500, patch_size=args.patch_size,
                                                 n_patches_out=args.n_patches_out, device=DEVICE, verbose=False, lr=0.001)

        print("Etape :", i)

        # Metrics are computed on a centred square crop of side H - 12.
        crop = torchvision.transforms.CenterCrop(args.size[0] - 12)
        hr_eval = crop(hr_img.squeeze().to('cpu'))
        rot_eval = crop(im_deb_ROT.squeeze().to('cpu'))
        rsuot_eval = crop(im_deb_RSUOT.squeeze().to('cpu'))

        # PSNR
        list_psnr_ROT[j].append(PSNR(hr_eval, rot_eval).item())
        list_psnr_RSUOT[j].append(PSNR(hr_eval, rsuot_eval).item())

        # LPIPS. no_grad: the estimates require grad, so without it every stored
        # score would keep an autograd graph (and LPIPS activations) alive.
        # NOTE(review): LPIPS expects inputs in [-1, 1] unless normalize=True is
        # passed; kept as-is to reproduce the reported numbers — confirm.
        with torch.no_grad():
            list_lpips_ROT[j].append(loss_fn_alex(hr_eval[None, None], rot_eval[None, None]))
            list_lpips_RSUOT[j].append(loss_fn_alex(hr_eval[None, None], rsuot_eval[None, None]))

        # SSIM
        img_hr = hr_eval.detach().numpy()
        img_pred_ROT = rot_eval.detach().numpy()
        img_pred_RSUOT = rsuot_eval.detach().numpy()
        # NOTE(review): data_range is taken from the prediction's dynamic range,
        # whereas skimage documents it as the range of possible values (1.0
        # for [0, 1] images); kept to reproduce the reported numbers — confirm.
        list_ssim_ROT[j].append(ssim(img_hr, img_pred_ROT, data_range=img_pred_ROT.max() - img_pred_ROT.min()))
        list_ssim_RSUOT[j].append(ssim(img_hr, img_pred_RSUOT, data_range=img_pred_RSUOT.max() - img_pred_RSUOT.min()))

        # Restored, LR and HR images
        list_im_restored_ROT[j].append(im_deb_ROT.clone().to('cpu'))
        list_im_restored_RSUOT[j].append(im_deb_RSUOT.clone().to('cpu'))
        list_im_LR[j].append(lr_img.clone().to('cpu'))
        list_im_HR[j].append(hr_img.clone().to('cpu'))

# --- Save results ---------------------------------------------------------------
os.chdir(RESULTS_DIR)

results_to_save = {
    # PSNR
    "list_psnr_ROT": list_psnr_ROT,
    "list_psnr_RSUOT": list_psnr_RSUOT,
    # LPIPS
    "list_lpips_ROT": list_lpips_ROT,
    "list_lpips_RSUOT": list_lpips_RSUOT,
    # SSIM
    "list_ssim_ROT": list_ssim_ROT,
    "list_ssim_RSUOT": list_ssim_RSUOT,
    # Restored images
    "list_im_restored_ROT": list_im_restored_ROT,
    "list_im_restored_RSUOT": list_im_restored_RSUOT,
    # LR images
    "list_im_LR": list_im_LR,
    # HR images
    "list_im_HR": list_im_HR,
}
for file_name, result in results_to_save.items():
    torch.save(result, file_name)

DONE - total time is 229s
DONE - total time is 229s
Etape : 0


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


DONE - total time is 229s
DONE - total time is 229s
Etape : 0
