In [1]:
import os
import numpy as np
import torch

import cv2
from skimage.metrics import structural_similarity as compare_ssim
from runpy import run_path

!pip install einops
!pip install hdf5storage
import hdf5storage
import pandas as pd 

In [None]:
def load_model():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    parameters = {'inp_channels':1, 'out_channels':1, 'dim':48, 'num_blocks':[4,6,6,8], 'num_refinement_blocks':4, 'heads':[1,2,4,8], 'ffn_expansion_factor':2.66, 'bias':False, 'LayerNorm_type':'BiasFree', 'dual_pixel_task':False}
    load_arch = run_path('/content/drive/MyDrive/SRFP/Restormer/basicsr/models/archs/restormer_arch.py')      #os.path.join('models', 'restormer_arch.py'))
    model = load_arch['Restormer'](**parameters)
    checkpoint = torch.load('/content/drive/MyDrive/SRFP/DPIR/model_zoo/gaussian_gray_denoising_sigma50.pth') #('/content/drive/MyDrive/SRFP/DPIR/model_zoo/gaussian_gray_denoising_sigma15.pth') # #torch.load('/content/drive/MyDrive/SRFP/DPIR/model_zoo/gaussian_color_denoising_blind.pth')
    model.load_state_dict(checkpoint['params'], strict=True)
    model.eval()
    for k, v in model.named_parameters():

        v.requires_grad = False
    model = model.to(device)
    return model    

In [None]:
import math
import os
def calculate_psnr(img1, img2):
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2)**2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))

In [None]:
def proj(im_input, minval, maxval):
    im_out = np.where(im_input > maxval, maxval, im_input)
    im_out = np.where(im_out < minval, minval, im_out)
    return im_out

def psnr(x,im_orig):
    norm2 = np.mean((x - im_orig) ** 2)
    psnr = -10 * np.log10(norm2)
    return psnr

def funcAtranspose(im_input, mask, fx, fy):
    m,n = im_input.shape
    fx = int(1/fx)
    fy = int(1/fy)
    im_inputres = np.zeros([m*fx, m*fy], im_input.dtype)
    for i in range(m):
        for j in range(n):
            im_inputres[fx*i,fy*j] = im_input[i,j]
 
    m,n = im_inputres.shape
    w = len(mask[0])
    r = int((w - 1) / 2)
    im_inputres = cv2.copyMakeBorder(im_inputres, r, r, r, r, borderType=cv2.BORDER_WRAP)
    im_output = cv2.filter2D(im_inputres, -1, mask)
    im_output = im_output[r:r+m, r:r+n]
    return im_output

def funcA(im_input, mask, fx, fy):
    m,n = im_input.shape
    w = len(mask[0])
    r = int((w - 1) / 2)
    im_input = cv2.copyMakeBorder(im_input, r, r, r, r, borderType=cv2.BORDER_WRAP)
    im_output = cv2.filter2D(im_input, -1, mask)
    im_output = im_output[r:r+m, r:r+n]
    im_outputres = cv2.resize(im_output, (0,0), fx=fx, fy=fy, interpolation=cv2.INTER_NEAREST)
    return im_outputres

def np_fftconvolve(A, B):
    m1,n1 = A.shape
    m2,n2 = B.shape
    m3 = m1 + m2 -1
    n3 = n1 + n2 -1
    A2 = np.zeros([m3, n3])
    B2 = np.zeros([m3, n3])
    A2[:m1, :n1] = A
    B2[:m2, :n2] = B
    return np.real(np.fft.ifft2(np.fft.fft2(A2)*np.fft.fft2(B2)))

def createAAtranspose(mask, fx, fy):
    m,n = mask.shape
    maskrot = np.rot90(mask,2)
    maskinter = np_fftconvolve(mask, maskrot)
    maskdown = cv2.resize(maskinter, (0,0), fx=fx, fy=fy, interpolation=cv2.INTER_NEAREST)
    return maskdown

def proxrhof(y, filtfft, rho, mask, fx, fy, xhat):
    b = y + rho*xhat
    vf = np.fft.fft2(funcA(b, mask, fx, fy))
    vf2 = np.real(np.fft.ifft2(np.divide(vf, filtfft + rho)))
    vf3 = funcAtranspose(vf2, mask, fx, fy)
    x = (b - vf3) / rho
    return x

import math
import os
def calculate_psnr(img1, img2):
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2)**2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))

In [None]:
def pnp_admm_superresolution(model, im_input, im_ref, fx, fy, mask, **opts):

    lamda = opts.get('lamda', 2.0)
    rho = opts.get('rho', 1.0)   
    maxitr = opts.get('maxitr', 100)
    verbose = opts.get('verbose',1)
    sigma = opts.get('sigma', 5)
    max_list = []

    """ Initialization. """

    y = funcAtranspose(im_input, mask, fx, fy)
    m, n = y.shape
    maskdown = createAAtranspose(mask, fx, fy)
    filtfft = np.absolute(np.fft.fft2(maskdown, im_input.shape))
    v = cv2.resize(im_input, (m, n))
    x = np.copy(v)
    u = np.zeros((m,n), dtype=np.float64)

    """ Main loop. """
    for i in range(maxitr):

        xold = np.copy(x)
        vold = np.copy(v)
        uold = np.copy(u)

        """ Update variables. """

        x = proxrhof(y, filtfft, rho, mask, fx, fy, v-u)

        """ Denoising step. """

        vtilde = np.copy(x+u)
        
        # This is the proposed denoiser part
        vtilde_torch = np.reshape(vtilde, (1,1,m,n))
        vtilde_torch = torch.from_numpy(vtilde_torch).type(torch.FloatTensor).cuda()
        r = model(vtilde_torch).cpu().numpy()
        r = np.reshape(r, (m,n))
        r = proj(r, -2., 2.)
        v = 0.5*r + 0.5*vtilde
        
        """ Update variables. """
        u = uold + x - v


        """ Monitoring. """
        max_list.append( psnr(x,im_ref) )
        index = max_list.index(np.max(max_list))
        if verbose:
            print("i: {}, \t psnr: {}  ssim= {}"\
                  .format(i+1, psnr(x,im_ref), compare_ssim(x, im_ref, data_range=1.)))

    return v,np.max(max_list),index

In [None]:
def iterate(input_array, rho_ = 0.05, itr_ = 30):
    with torch.no_grad():

        # ---- load the ground truth ----
        im_orig = input_array #cv2.imread(input_str, 0)/255.0
        m,n = im_orig.shape

        # ---- blur the image 
        kernel = cv2.getGaussianKernel(9, 1)
        mask = np.outer(kernel, kernel.transpose())
        w = len(mask[0])
        r = int((w - 1) / 2)
        im_orig = cv2.copyMakeBorder(im_orig, r, r, r, r, borderType=cv2.BORDER_WRAP)
        im_blur = cv2.filter2D(im_orig, -1, mask)
        im_blur = im_blur[r:r+m, r:r+n]
        im_orig = im_orig[r:r+m, r:r+m]

        # ---- Downsample the image
        fx = 1./K
        fy = 1./K 
        im_down = cv2.resize(im_blur, (0,0), fx=fx, fy=fy, interpolation=cv2.INTER_NEAREST)

        # ---- add noise -----
        noise_level = 5.0 / 255.0
        gauss = np.random.normal(0.0, noise_level, im_down.shape)
        im_noisy = im_down + gauss
        psnr_final = 0.

        # ---- set options -----
        sigma = 25
        rho = rho_
        maxiter = itr_
        # ---- load the model ----
        model = load_model()

        opts = dict(sigma = sigma, rho = rho, maxitr = maxiter, verbose = False)

        # ---- plug and play -----
        out,max, index= pnp_admm_superresolution(model, im_noisy, im_orig, fx, fy, mask, **opts)
        
        # ---- results ----
        #cv2.imwrite('lena_output.jpg', out * 255.0)
        
        psnr_ours = psnr(out, im_orig)
        ssim_ours = compare_ssim(out, im_orig, data_range=1.)
        print('sigma = {}, rho = {} - PNSR: {}, SSIM = {}   Max : {}, Index : {}'.format(sigma, rho, psnr_ours, ssim_ours,max,index))
    return  (out, sigma, rho, psnr_ours, ssim_ours,max,index)



In [None]:
def driver(rho_ = 10, itr_= 5):
  temp = '/content/drive/MyDrive/SRFP/DPIR/testsets/set3c/leaves.png'
  im_orig = cv2.imread(temp)/255.0

  #------------------------- splitting each channel 
  im_orig_r = im_orig[:,:,0]
  im_orig_g = im_orig[:,:,1]
  im_orig_b = im_orig[:,:,2]

  #-------------------------Predicting each channel 
  sum = 0
  avg_psnr = 0 
  _rho_ = rho_
  _itr_ = itr_
  out_r , sigma, rho, psnr_ours, ssim_ours,max,index = iterate(im_orig_r,rho_ = _rho_, itr_ = _itr_) ; sum = sum + index  
  avg_psnr = avg_psnr + max
  out_g , sigma, rho, psnr_ours, ssim_ours,max,index = iterate(im_orig_g,rho_ = _rho_, itr_ = _itr_) ; sum = sum + index 
  avg_psnr = avg_psnr + max
  out_b , sigma, rho, psnr_ours, ssim_ours,max,index = iterate(im_orig_b,rho_ = _rho_, itr_ = _itr_) ; sum = sum + index  
  avg_psnr = avg_psnr + max

  print('give_itr',math.ceil(sum/3), 'with_iter_set ****',avg_psnr/3)

  #-------------------------Merginng all channels together
  im_recon = np.ones(im_orig.shape, out_b.dtype)
  im_recon[:,:,0] = np.uint8((out_r*255.0).round())
  im_recon[:,:,1] = np.uint8((out_g*255.0).round())
  im_recon[:,:,2] = np.uint8((out_b*255.0).round()) 
  im_recon = im_recon.astype('uint8')

  #-------------------------Psnr
  print(im_recon.shape,cv2.imread(temp).shape )
  print(im_recon.dtype,cv2.imread(temp).dtype )
  psnr = calculate_psnr(im_recon,cv2.imread(temp))
  print('current psnr',psnr)
  
  cv2.imwrite('LADMMSUPRER50'+str(rho_)+'itr'+str(itr_)+str(psnr) +'.png', im_recon)




In [None]:
driver(rho_ = .15, itr_ =15)

sigma = 25, rho = 0.15 - PNSR: 26.213789147361865, SSIM = 0.9203361981597346   Max : 26.21401667973074, Index : 14
sigma = 25, rho = 0.15 - PNSR: 27.701649414243995, SSIM = 0.9118973419919121   Max : 27.70116621344204, Index : 14
sigma = 25, rho = 0.15 - PNSR: 27.180494419890056, SSIM = 0.9282622607266057   Max : 27.17935700491053, Index : 14
give_itr 14 with_iter_set **** 27.031513299361105
(256, 256, 3) (256, 256, 3)
uint8 uint8
current psnr 26.364167020938257


In [None]:
for rho in [1,2.5,3,3.5,5]:
  driver(rho_ = rho, itr_ = 50)

sigma = 25, rho = 1 - PNSR: 20.205993314912924, SSIM = 0.8232580524672649   Max : 22.181692922536385, Index : 20
sigma = 25, rho = 1 - PNSR: 20.3373677842296, SSIM = 0.7550241276216346   Max : 22.557792738865107, Index : 0
sigma = 25, rho = 1 - PNSR: 20.184256796729763, SSIM = 0.8120855415864356   Max : 23.006757379820662, Index : 21
give_itr 14 with_iter_set **** 22.58208101374072
(256, 256, 3) (256, 256, 3)
uint8 uint8
current psnr 12.25644946122211
sigma = 25, rho = 2.5 - PNSR: 19.02493028442742, SSIM = 0.767663725373298   Max : 19.42659270324515, Index : 0
sigma = 25, rho = 2.5 - PNSR: 16.60054281055459, SSIM = 0.6245441547233775   Max : 22.207477304370524, Index : 0
sigma = 25, rho = 2.5 - PNSR: 19.149837333227, SSIM = 0.7601724783642526   Max : 20.05970235520286, Index : 0
give_itr 0 with_iter_set **** 20.56459078760618
(256, 256, 3) (256, 256, 3)
uint8 uint8
current psnr 10.260286020386873
sigma = 25, rho = 3 - PNSR: 18.71961118651034, SSIM = 0.7583044380120568   Max : 19.380332