# WPPNets

### Neural network (pre-trained by default):

In [4]:
# This code belongs to the paper
#
# F. Altekrüger and J. Hertrich. 
# WPPNets and WPPFlows: The Power of Wasserstein Patch Priors for Superresolution. 
# ArXiv Preprint#2201.08157
#
# Please cite the paper, if you use the code.
#
# The script reproduces the numerical example with the textures 'Floor'
# and 'Grass' in the paper.

import torch
from torch import nn
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('agg')
import numpy as np
import os
import skimage.io as io
import model.small_acnet
import random
import utils
import argparse
from tqdm import tqdm
import time

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)
# The original setup pins the work to GPU index 2. Only do so when that GPU
# actually exists: calling set_device(2) unconditionally fails on CPU-only
# machines and on machines with fewer than three GPUs.
if torch.cuda.is_available() and torch.cuda.device_count() > 2:
    torch.cuda.set_device(2)

def Downsample(scale = 0.25, gaussian_std = 2):
    '''
    builds the gaussian downsampling operator from utils.py and moves it to DEVICE

    scale: downsampling scale, must lie in (0, 1]. The integer int(1/scale)
           (4 for the default scale of 0.25) is what gets passed on to
           utils.gaussian_downsample.
    gaussian_std: standard deviation forwarded to utils.gaussian_downsample.

    returns the object created by utils.gaussian_downsample, moved to DEVICE.
    raises ValueError if scale is not in (0, 1]. Failing here is deliberate:
    returning None would only surface later, when the operator is called.
    '''
    if not 0 < scale <= 1:
        raise ValueError(
            'Scale factor must be in (0, 1], got ' + str(scale) + '.')
    kernel_size = 16
    factor = int(1/scale)
    #gaussian downsample with zero padding (per the original comment on pad=True)
    gaussian_down = utils.gaussian_downsample(kernel_size,gaussian_std,factor,pad=True)
    return gaussian_down.to(DEVICE)

def WLoss(args, input_img, ref_pat, model, psi):
    '''
    Computes the proposed wasserstein loss fct consisting of a MSELoss and a Wasserstein regularizer

    args: namespace; the attributes read here are lam, patch_size, center,
          n_iter_psi and keops.
    input_img: batch fed to the model; also the target of the MSE term.
    ref_pat: patches of the reference image (first dimension = number of patches).
    model: network producing the superresolved prediction from input_img.
    psi: maximizer of the semidual problem saved from the previous step for this batch.

    returns the list [total_loss, loss, lam*reg, psi] where loss is the MSE
    data term, reg the Wasserstein regularizer and psi the updated maximizer.

    NOTE: the downsampling operator is not a parameter; this function reads
    the module-level name `operator`, which must be defined before the call.
    '''
    lam = args.lam
    n_iter_psi = args.n_iter_psi

    im2patch = utils.patch_extractor(args.patch_size,center=args.center)

    num_ref = ref_pat.shape[0] #number of patches of reference image
    patch_weights = torch.ones(num_ref,device=DEVICE,dtype=torch.float) #same weight for all patches

    semidual_loss = utils.semidual(ref_pat,usekeops=args.keops)
    semidual_loss.psi.data = psi #start from the maximizer psi of the previous step
    pred = model(input_img) #superresolution of input_img

    #collect the patches of every image in the batch and concatenate them once
    #(instead of re-allocating the growing tensor in every loop iteration)
    patches = [im2patch(pred[k,:,:,:].unsqueeze(0)) for k in range(pred.shape[0])]
    if patches:
        inp = torch.cat(patches,0)
    else:
        inp = torch.empty(0, device = DEVICE) #empty batch: same result as before

    #gradient ascent to find maximizer psi for dual formulation of W2^2
    optim_psi = torch.optim.ASGD([semidual_loss.psi], lr=1e-0, alpha=0.5, t0=1)
    for _ in range(n_iter_psi):
        sem = -semidual_loss(inp,patch_weights)
        optim_psi.zero_grad()
        #the graph through pred is needed again for the regularizer below
        sem.backward(retain_graph=True)
        optim_psi.step()
    #use the averaged iterate kept by ASGD; its state only exists after at
    #least one step, so skip the lookup when no iteration was run
    if n_iter_psi > 0:
        semidual_loss.psi.data = optim_psi.state[semidual_loss.psi]['ax']
    psi = semidual_loss.psi.data #update psi

    reg = semidual_loss(inp,patch_weights) #wasserstein regularizer

    down_pred = operator(pred) #downsample pred with the module-level operator

    loss_fct = nn.MSELoss()
    loss = loss_fct(down_pred,input_img) #||f(G(y)) - y||^2
    total_loss = loss + lam * reg

    return [total_loss,loss,lam*reg,psi]


def training(trainset, model, reference_img, batch_size, epochs, args, opti):
    '''
    training process

    trainset: low-resolution training images, indexed along the first dimension.
    model: network to train; it is also the network used for validation and checkpoints.
    reference_img: image from which the reference patches are extracted.
    batch_size: number of images per batch.
    epochs: number of passes over all batches.
    args: namespace; read here are n_patches_out, patch_size, center and val
          (WLoss reads further attributes).
    opti: optimizer holding the parameters of model.

    Every 10 epochs the average validation PSNR is printed and plotted, every
    30 epochs a checkpoint, a prediction and the loss curves are written to
    the 'checkpoints' directory.

    NOTE: the file names and the checkpoint prediction rely on the
    module-level names `image_class` and `lr`, which must exist before the call.
    '''
    numb_train_img = trainset.shape[0] #number of all img

    #create random batches: shuffle the image indices once, then cut them into chunks
    idx = torch.randperm(numb_train_img)
    batch_lr = [trainset[idx[i:(i+batch_size)],...]
                for i in range(0,numb_train_img,batch_size)] #list of batches

    #create one maximizer psi per batch; it is carried over from epoch to epoch
    psi_length = args.n_patches_out #length of vector psi
    psi_list = [torch.zeros(psi_length, device = DEVICE) for _ in batch_lr]

    #create patches of reference image (random ones, per the original comment)
    im2patch = utils.patch_extractor(args.patch_size,center=args.center)
    ref = im2patch(reference_img,args.n_patches_out)

    a_psnr_list = [] #for validation
    loss_list = []; reg_list = []; MSE_list = [] #for plot

    os.makedirs('checkpoints', exist_ok=True)
    val_step = 10

    for t in tqdm(range(epochs)):
        a_totalloss = 0; a_MSE = 0; a_reg = 0
        ints = random.sample(range(0,len(batch_lr)),len(batch_lr)) #random order of batches
        for i in tqdm(ints):
            psi_temp = psi_list[i] #choose corresponding saved maximizer psi
            [total_loss,loss,reg,p] = WLoss(args, batch_lr[i], ref, model, psi_temp)

            #backpropagation; zero_grad also clears the gradients the psi
            #iterations inside WLoss accumulated on the model parameters
            opti.zero_grad()
            total_loss.backward()
            opti.step()

            total_loss = total_loss.item(); loss = loss.item(); reg = reg.item()
            a_totalloss += total_loss; a_MSE += loss; a_reg += reg
            psi_list[i] = p #update psi

        #epoch averages over all batches
        a_totalloss = a_totalloss/len(batch_lr); a_MSE = a_MSE/len(batch_lr); a_reg = a_reg/len(batch_lr)
        loss_list.append(a_totalloss); MSE_list.append(a_MSE); reg_list.append(a_reg)

        if (t+1)%val_step == 0:
            print(f'------------------------------- \nValidation step')
            val_len = len(args.val)
            a_psnr = 0
            for i in range(val_len):
                #args.val[i] is used as a pair: [0] goes into the network, [1] is compared against
                with torch.no_grad():
                    pred = model(args.val[i][0]) #validate the model being trained
                psnr_val = utils.psnr(pred,args.val[i][1],40)
                a_psnr += psnr_val
            a_psnr = a_psnr / val_len
            print(f'Average Validation PSNR: {a_psnr}')
            a_psnr_list.append(a_psnr)
            plt.plot(list(range(val_step,val_step*len(a_psnr_list)+val_step,val_step)),a_psnr_list, 'k')
            title = 'Avarage PSNR ' + str(round(a_psnr,2))
            plt.title(title)
            plt.savefig('checkpoints/ValidatonPSNR_'+image_class+'.pdf')
            plt.close()
            print(f'-------------------------------')

        #save a checkpoint
        if (t+1)%30 == 0:
            torch.save({'net_state_dict': model.state_dict()}, 'checkpoints/checkpoint_'+image_class+'.pth')
            with torch.no_grad():
                pred_hr = model(lr) #lr is the module-level low-resolution test image
            os.makedirs('checkpoints/tmp', exist_ok=True)
            utils.save_img(pred_hr,'checkpoints/tmp/pred'+str(t+1))
            plt.ylabel('Loss')
            plt.xlabel('Epoch')
            plt.plot(list(range(len(loss_list))), loss_list, 'k-.', label='avarage loss')
            plt.plot(list(range(len(MSE_list))), MSE_list, 'k-', label='avarage MSE')
            plt.plot(list(range(len(reg_list))), reg_list, 'k:', label='avarage Reg')
            plt.legend(loc='upper right')
            plt.yscale('log')
            plt.savefig('checkpoints/losscurve_'+image_class+'.pdf')
            plt.close()

# Set to True to train the network from scratch; with False the saved
# weights in 'results/' are loaded and only the prediction is computed.
retrain = False
if __name__ == '__main__':
    os.makedirs('results', exist_ok=True)

    net = model.small_acnet.Net(scale=4).to(device=DEVICE)
    image_classes = ['tile1','wood1']
    image_class = image_classes[0] #choose the texture
    print('Superresolution for the texture ' + image_class)

    hr = utils.imread('test_img/hr_'+image_class+'.png')
    lr = utils.imread('test_img/lr_'+image_class+'.png')
    # NOTE(review): a commented-out line here read
    # "= operator(hr) + 0.01*torch.randn_like(operator(hr))", which suggests lr
    # may originally have been generated from hr with noise -- confirm.
    if retrain:
        #inputs
        lr_train = utils.Trainset(image_class = image_class, size = 1000)
        val = utils.Validationset(image_class = image_class)
        lr_size = lr_train.shape[2]
        #module-level downsampling operator; WLoss reads it by this name
        operator = Downsample(scale = 1/4, gaussian_std = 2)

        args=argparse.Namespace()
        args.lam=12.5/lr_size**2
        args.n_patches_out=10000
        args.patch_size=6
        args.val = val
        args.keops = True

        #texture-specific settings (only the number of epochs differs)
        if image_class == 'tile1':
            args.center = True
            epochs = 5
            args.n_iter_psi=20
        elif image_class == 'wood1':
            args.center = True
            epochs = 270
            args.n_iter_psi=20

        reference_img = utils.imread('test_img/ref_'+image_class+'.png')

        #training process
        batch_size = 25
        learning_rate = 1e-4
        OPTIMIZER = torch.optim.Adam(net.parameters(), lr=learning_rate)

        training(lr_train,net,reference_img,batch_size,epochs,args=args,opti=OPTIMIZER)
        with torch.no_grad():
            pred = net(lr)
        torch.save({'net_state_dict': net.state_dict(), 'optimizer_state_dict': OPTIMIZER.state_dict()},
                    'results/weights_'+image_class+'.pth')
        utils.save_img(pred,'results/W2_'+image_class)
    else:
        #pre-trained path: load the saved weights and predict without
        #building an autograd graph, as in the retrain branch above
        weights = torch.load('results/weights_'+image_class+'.pth',map_location=DEVICE)
        net.load_state_dict(weights['net_state_dict'])
        with torch.no_grad():
            pred = net(lr)
        utils.save_img(pred,'results/W2_'+image_class)








cuda
Superresolution for the texture tile1


### Average PSNR/SSIM/LPIPS scores obtained on the test set TILE or WOOD:

In [6]:
# Average PSNR, LPIPS, SSIM and blur effect over an image set made up of a single texture type (WOOD1):
import numpy as np
import glob
from skimage import (color, data, measure)
from skimage.metrics import structural_similarity as ssim
import lpips
import torchvision

# Choose your image class: the name selects the test-image folder
# "test_img/test_img_WPPNets/<image_class>/" and, with a '1' appended, the
# weights file 'results/weights_<image_class>1.pth' used further down.
image_class="wood"
#image_class="tile"

def PSNR(im, im_new):
    '''
    peak signal-to-noise ratio (in dB) between two 2D image tensors,
    computed as 10*log10(1/MSE), i.e. with a peak value of 1

    im_new must be two-dimensional (its shape is unpacked into two sizes);
    the result is returned as a tensor.
    '''
    rows, cols = im_new.shape
    squared_error = torch.sum((im - im_new) ** 2)
    mse = 1 / (rows * cols) * squared_error  # mean over all rows*cols pixels
    return 10 * torch.log10(1 / mse)

loss_fn_alex = lpips.LPIPS(net='alex') # best forward scores

# location of the images; both lists are sorted so that the HR and LR files
# are paired by position
liste_im_name_HR   = sorted(glob.glob("test_img/test_img_WPPNets/"+image_class+"/HR/*.png"))
liste_im_name_LR   = sorted(glob.glob("test_img/test_img_WPPNets/"+image_class+"/LR/*.png"))

# lists collecting PSNR, LPIPS, SSIM and the predictions
# (Blue_Effect is created but never filled in this cell)
PSNRs=[]
LPIPS=[]
SSIM=[]
Blue_Effect=[]
PRED=[]

# load the weights matching the chosen image class (once)
weights = torch.load('results/weights_'+image_class+'1.pth',map_location=DEVICE)
net.load_state_dict(weights['net_state_dict'])

# every metric is computed on the same centre crop of both images
crop = torchvision.transforms.CenterCrop(600-12)
# synchronising is only possible (and only needed for timing) with CUDA
use_cuda_sync = torch.cuda.is_available()

L_t=[]
for i,name_im in enumerate(liste_im_name_HR):

    hr = utils.imread(name_im)
    lr = utils.imread(liste_im_name_LR[i])

    # computation time: wait for pending GPU work before starting the clock
    if use_cuda_sync:
        torch.cuda.synchronize()
    t = time.time()

    # inference only: no autograd graph, so the predictions stored in PRED
    # do not keep the graph alive
    with torch.no_grad():
        pred = net(lr)

    # wait for the GPU to finish before stopping the clock
    if use_cuda_sync:
        torch.cuda.synchronize()
    temps=time.time()-t
    print('DONE - total time is '+str(temps)+'s')
    L_t.append(temps)

    hr_crop = crop(hr.squeeze().to('cpu'))
    pred_crop = crop(pred.squeeze().to('cpu'))

    # LPIPS
    # NOTE(review): the images are passed as single-channel tensors without
    # lpips' normalize option -- confirm the value range utils.imread returns.
    with torch.no_grad():
        LPIPS.append(loss_fn_alex(hr_crop.unsqueeze(0).unsqueeze(0),
                                  pred_crop.unsqueeze(0).unsqueeze(0)))
    # PSNR
    PSNRs.append(PSNR(hr_crop, pred_crop))

    # SSIM (data range taken from the prediction, as before)
    img_hr= hr_crop.detach().numpy()
    img_pred= pred_crop.detach().numpy()
    SSIM.append(ssim(img_hr, img_pred,data_range=img_pred.max() - img_pred.min()))

    # predicted image
    PRED.append(pred)

# mean of the values over the data set
print(torch.mean(torch.tensor(PSNRs)))
print(torch.mean(torch.tensor(LPIPS)))
print(np.mean(SSIM))

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /home/prof/smignon/anaconda3/envs/WPPNets/lib/python3.9/site-packages/lpips/weights/v0.1/alex.pth
DONE - total time is 0.013933658599853516s
DONE - total time is 0.012628555297851562s
DONE - total time is 0.013883352279663086s
DONE - total time is 0.013358831405639648s
DONE - total time is 0.012743949890136719s
DONE - total time is 0.012645244598388672s
DONE - total time is 0.012791156768798828s
DONE - total time is 0.012664318084716797s
DONE - total time is 0.01267385482788086s
DONE - total time is 0.012616157531738281s
DONE - total time is 0.012576103210449219s
DONE - total time is 0.012459278106689453s
DONE - total time is 0.013438940048217773s
DONE - total time is 0.012555599212646484s
DONE - total time is 0.012616872787475586s
DONE - total time is 0.012808799743652344s
DONE - total time is 0.012719869613647461s
DONE - total time is 0.012708187103271484s
DONE - total time is 0.01252508163452