In [None]:
run config

In [None]:
import numpy as np
from tqdm.notebook import tqdm
import os

from networks import *

import torch
from torch.utils.data import Dataset, DataLoader

from skimage.metrics import peak_signal_noise_ratio
PSNR= peak_signal_noise_ratio

In [None]:
##########################Enter device and the model you wish to test##############################################

In [None]:
# Device used for inference: the first CUDA GPU when one is available,
# otherwise the CPU so the notebook still runs on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Enter the same name used when training; it is part of the checkpoint
# filename ('best_' + model_name + '.pth') loaded further down.
model_name = "Img2Img_Mixer"

# Architecture settings; these should match the configuration the checkpoint
# was trained with, since its state dict is loaded into this instance.
model = Img2Img_Mixer(
    img_size=256,     # Image Size (assumed to be square image), here 256 x 256
    img_channels=3,   # Image Channels, 3 for RGB, 1 for greyscale
    patch_size=4,     # Patch Size, P
    embed_dim=128,    # Embedding Dimension, C
    num_layers=16,    # Number of Mixer Layers, N
    f_hidden=4,       # Multiplication Factor for Hidden Dimensions, f
)

In [None]:
###################################################################################################################

In [None]:
##Load paths 
# Validation directories, built by plain concatenation onto `data_path`
# (so `data_path` is expected to already end with a path separator).
clean_val = f"{data_path}clean_val/"
noisy_val = f"{data_path}noisy_val/"

In [None]:
##Prepare data
class data(Dataset):
    """Map-style dataset of paired clean/noisy tensors stored as files.

    Sample ``idx`` is read from the file named ``'{0:05}'.format(idx)``
    (zero-padded to five digits, no extension) in each of the two directories.

    Args:
        path_clean: Directory holding the clean tensors.
        path_noisy: Directory holding the noisy tensors.
    """

    def __init__(self, path_clean, path_noisy):
        self.path_clean = path_clean
        self.path_noisy = path_noisy

    def __len__(self):
        """Return the number of entries found in the clean directory."""
        return len(os.listdir(self.path_clean))

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as ``{'clean': ..., 'noisy': ...}``."""
        filename = '{0:05}'.format(idx)

        # os.path.join gives the right path whether or not the directory
        # strings end with a separator (plain `+` would not).
        return {
            'clean': torch.load(os.path.join(self.path_clean, filename)),
            'noisy': torch.load(os.path.join(self.path_noisy, filename)),
        }
 

In [None]:
#Entire validation set
# Dataset over the complete validation split (clean/noisy pairs)
val_set = data(path_clean=clean_val, path_noisy=noisy_val)

# One sample per batch, visited in a fixed order
val_dl = DataLoader(dataset=val_set, batch_size=1, shuffle=False)

In [None]:
model = model.to(device)

# Path to the trained model
best_path = models_path + 'best_' + model_name + '.pth'

# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a different device (e.g. another GPU index) still loads.
checkpoint = torch.load(best_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
#reconstruct function
def denoise(model, sample):
    """Denoise one sample with a model whose output is subtracted from the input.

    The model output is used as a residual: it is subtracted from the noisy
    input and the difference is clamped to the range [0, 1].

    Args:
        model: Network called on the noisy tensor; its output must be
            broadcastable against the input for the subtraction.
        sample: Mapping whose ``'noisy'`` entry is the noisy input tensor.

    Returns:
        The clamped tensor ``noisy - model(noisy)``, computed without gradient
        tracking and located on the device the computation ran on.

    Note:
        ``model`` is switched to eval mode and left in that mode.
    """
    model.eval()

    noisy = sample['noisy']

    # Use the device the model's weights actually live on rather than a
    # notebook-level global; a model without parameters keeps the input's device.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        noisy = noisy.to(first_param.device)

    with torch.no_grad():
        pred = model(noisy)
        img = torch.clamp(noisy - pred, 0, 1)

    return img

In [None]:
#PSNR of denoised images

# Running sum of per-image PSNR values; divided by the number of batches below
psnr = 0

for sample in tqdm(val_dl):
    pred = denoise(model, sample).cpu().squeeze(0).numpy()
    original = sample['clean'].squeeze(0).numpy()
    # The reference image goes first (image_true), the estimate second.
    # data_range is set explicitly: the denoised output is clamped to [0, 1];
    # the clean images are assumed to use the same range -- TODO confirm.
    psnr += PSNR(original, pred, data_range=1)

print("PSNR of the denoised images is: ", psnr / len(val_dl))

In [None]:
#PSNR of noisy images

# Running sum of per-image PSNR values; divided by the number of batches below
psnr = 0

for sample in tqdm(val_dl):
    noisy = sample['noisy'].squeeze(0).numpy()
    original = sample['clean'].squeeze(0).numpy()
    # The reference image goes first (image_true), the noisy one second.
    # Without an explicit data_range, skimage infers it from the first
    # argument: a noisy image with negative pixels would double the range
    # (inflating PSNR), and pixels above 1 would raise a ValueError.
    # A range of 1 matches the [0, 1] clamp applied to the denoised images;
    # the clean images are assumed to use that range too -- TODO confirm.
    psnr += PSNR(original, noisy, data_range=1)

print("PSNR of the noisy images is: ", psnr / len(val_dl))