In [1]:
import os
import numpy as np
import torch

import cv2
from skimage.metrics import structural_similarity as compare_ssim
from runpy import run_path


In [None]:
# Extra dependencies for this Colab session. The recorded output below shows
# einops 0.4.1 and hdf5storage 0.1.18 being installed.
# NOTE(review): einops is presumably required by the Restormer architecture
# file executed in load_model -- confirm. hdf5storage and pandas are not used
# by any of the cells visible here.
# TODO: pin versions (e.g. `%pip install einops==0.4.1`) so a fresh runtime
# reproduces the environment these results were produced with.
!pip install einops
!pip install hdf5storage
import hdf5storage
import pandas as pd 

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Installing collected packages: einops
Successfully installed einops-0.4.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting hdf5storage
  Downloading hdf5storage-0.1.18-py2.py3-none-any.whl (53 kB)
[K     |████████████████████████████████| 53 kB 680 kB/s 
Installing collected packages: hdf5storage
Successfully installed hdf5storage-0.1.18


In [None]:
def load_model(arch_path='/content/drive/MyDrive/SRFP/Restormer/basicsr/models/archs/restormer_arch.py',
               weights_path='/content/drive/MyDrive/SRFP/DPIR/model_zoo/gaussian_gray_denoising_blind.pth',
               device=None):
    """Build a single-channel Restormer denoiser, load its weights and freeze it.

    The architecture is obtained by executing ``arch_path`` with
    ``runpy.run_path`` and instantiating the ``Restormer`` symbol it defines;
    the weights are the ``'params'`` entry of the checkpoint at
    ``weights_path`` (loaded with ``strict=True``).

    Args:
        arch_path: path of the Python file that defines ``Restormer``.
        weights_path: path of a checkpoint dict containing a ``'params'``
            state dict.
        device: target device (``torch.device`` or string). ``None`` selects
            CUDA when available and the CPU otherwise.

    Returns:
        The model in eval mode, with ``requires_grad`` disabled on every
        parameter, moved to ``device``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    parameters = {'inp_channels': 1, 'out_channels': 1, 'dim': 48,
                  'num_blocks': [4, 6, 6, 8], 'num_refinement_blocks': 4,
                  'heads': [1, 2, 4, 8], 'ffn_expansion_factor': 2.66,
                  'bias': False, 'LayerNorm_type': 'BiasFree',
                  'dual_pixel_task': False}
    load_arch = run_path(arch_path)
    model = load_arch['Restormer'](**parameters)

    # Map tensors to the CPU first: without map_location a checkpoint saved
    # from a GPU cannot be loaded on a CPU-only machine. The model is moved to
    # `device` once the weights are in place.
    checkpoint = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(checkpoint['params'], strict=True)
    model.eval()
    # Inference only: no gradients are needed for the denoiser's weights.
    model.requires_grad_(False)
    return model.to(device)

In [None]:
def proj(im_input, minval, maxval):
    """Clamp ``im_input`` element-wise into ``[minval, maxval]``.

    The upper bound is enforced first and the lower bound second, so if
    ``minval > maxval`` every element ends up equal to ``minval``. NaNs are
    left unchanged because comparisons against NaN are always False.
    """
    upper_clipped = np.where(im_input > maxval, maxval, im_input)
    return np.where(upper_clipped < minval, minval, upper_clipped)

def psnr(x, im_orig, data_range=1.0):
    """Peak signal-to-noise ratio, in dB, of ``x`` against reference ``im_orig``.

    Args:
        x: estimated image (array broadcastable against ``im_orig``).
        im_orig: reference image.
        data_range: peak signal value; the default 1.0 suits images in [0, 1].

    Returns:
        ``10 * log10(data_range**2 / MSE)``, or ``inf`` when the inputs are
        identical (MSE == 0) -- returned explicitly instead of letting
        ``log10(0)`` emit a divide-by-zero RuntimeWarning.
    """
    mse = np.mean((x - im_orig) ** 2)
    if mse == 0:
        return float('inf')
    # Written as -10*log10(mse) + 20*log10(peak) so the default data_range=1.0
    # reproduces the previous -10*log10(mse) value exactly (log10(1.0) == 0.0).
    return -10 * np.log10(mse) + 20 * np.log10(data_range)

def funcAtranspose(im_input, mask, fx, fy):
    """Adjoint of ``funcA``: zero-fill upsampling followed by periodic filtering.

    Pixel ``(i, j)`` of the input is placed at ``(i * sx, j * sy)`` of a zero
    image of shape ``(m * sx, n * sy)``, where ``sx = 1/fx`` and ``sy = 1/fy``
    are the integer upsampling strides. That image is wrap-padded
    (``BORDER_WRAP``), filtered with the 180-degree-rotated ``mask`` and
    cropped back to the upsampled size.

    Args:
        im_input: 2-D low-resolution image of shape ``(m, n)``.
        mask: 2-D filter kernel; an odd width is assumed so the padding is
            centred on the kernel anchor.
        fx, fy: downsampling factors used by ``funcA`` (e.g. 0.5 for 2x).

    Returns:
        2-D array of shape ``(m * sx, n * sy)`` with ``im_input``'s dtype.
    """
    m, n = im_input.shape
    # round() guards against 1/fx evaluating to just below an integer.
    sx = int(round(1 / fx))
    sy = int(round(1 / fy))

    # Bug fix: the column count comes from n (it used m, which breaks
    # non-square images).
    upsampled = np.zeros((m * sx, n * sy), im_input.dtype)
    upsampled[::sx, ::sy] = im_input

    rows, cols = upsampled.shape
    r = int((len(mask[0]) - 1) / 2)
    padded = cv2.copyMakeBorder(upsampled, r, r, r, r, borderType=cv2.BORDER_WRAP)
    # cv2.filter2D computes a correlation. The adjoint of funcA's circular
    # correlation is a correlation with the kernel rotated by 180 degrees; for
    # point-symmetric kernels (e.g. an outer product of a symmetric 1-D
    # Gaussian) the rotation is a no-op.
    flipped = np.ascontiguousarray(np.asarray(mask)[::-1, ::-1])
    im_output = cv2.filter2D(padded, -1, flipped)
    return im_output[r:r + rows, r:r + cols]

def funcA(im_input, mask, fx, fy):
    """Forward operator: periodic filtering with ``mask`` followed by decimation.

    The image is wrap-padded (``BORDER_WRAP``) by ``(width(mask) - 1) / 2``
    pixels on every side, filtered with ``cv2.filter2D``, cropped back to its
    original ``(m, n)`` size and finally rescaled by ``(fx, fy)`` with
    nearest-neighbour interpolation.
    """
    rows, cols = im_input.shape
    pad = int((len(mask[0]) - 1) / 2)
    padded = cv2.copyMakeBorder(im_input, pad, pad, pad, pad, borderType=cv2.BORDER_WRAP)
    blurred = cv2.filter2D(padded, -1, mask)[pad:pad + rows, pad:pad + cols]
    return cv2.resize(blurred, (0, 0), fx=fx, fy=fy, interpolation=cv2.INTER_NEAREST)

In [None]:
def pnp_fbs_superresolution(model, im_input, im_ref, fx, fy, mask, **opts):
    """Plug-and-play forward-backward splitting for super-resolution.

    Each iteration takes a gradient step on the data term (via ``funcA`` /
    ``funcAtranspose``), passes the result through ``model`` as a denoiser,
    averages the denoised image with the pre-denoising iterate (weights
    0.5 / 0.5) and clips the result to [0, 1].

    Args:
        model: torch module called on a ``(1, 1, m, n)`` float32 tensor; its
            output must have ``m * n`` elements (it is reshaped to ``(m, n)``).
        im_input: observed low-resolution 2-D image.
        im_ref: ground-truth high-resolution image, used only for monitoring.
        fx, fy: downsampling factors of the forward operator.
        mask: blur kernel of the forward operator.
        **opts: ``rho`` (step size, default 1.0), ``maxitr`` (number of
            iterations, default 100), ``verbose`` (print per-iteration
            PSNR/SSIM, default 1). ``lamda`` and ``sigma`` are accepted for
            compatibility but are not used.

    Returns:
        ``(x, best_psnr, best_index)``: the final iterate, the highest PSNR
        against ``im_ref`` over all iterations, and the 0-based iteration at
        which it was first reached.

    Raises:
        ValueError: if ``maxitr`` is smaller than 1.
    """
    rho = opts.get('rho', 1.0)
    maxitr = opts.get('maxitr', 100)
    verbose = opts.get('verbose', 1)
    if maxitr < 1:
        raise ValueError('maxitr must be at least 1')

    # Run the denoiser wherever the model lives instead of assuming CUDA.
    device = next((p.device for p in model.parameters()), torch.device('cpu'))

    # Initialization. y = A^T b does not change across iterations.
    y = funcAtranspose(im_input, mask, fx, fy)
    m, n = y.shape
    # cv2.resize takes dsize as (width, height), i.e. (n, m); (m, n) transposed
    # the shape of non-square images.
    x = cv2.resize(im_input, (n, m))

    psnr_history = []
    for i in range(maxitr):
        # Gradient step on 0.5 * ||A x - b||^2, whose gradient is A^T A x - A^T b.
        gradx = funcAtranspose(funcA(x, mask, fx, fy), mask, fx, fy) - y
        xtilde = x - rho * gradx

        # Denoising step.
        with torch.no_grad():
            xtilde_t = torch.as_tensor(xtilde.reshape(1, 1, m, n),
                                       dtype=torch.float32, device=device)
            denoised = model(xtilde_t).cpu().numpy().reshape(m, n)
        x = proj(0.5 * denoised + 0.5 * xtilde, 0.0, 1.0)

        # Monitoring.
        current_psnr = psnr(x, im_ref)
        psnr_history.append(current_psnr)
        if verbose:
            print("i: {}, \t psnr: {} ssim= {} "
                  .format(i + 1, current_psnr, compare_ssim(x, im_ref, data_range=1.)))

    best_index = int(np.argmax(psnr_history))
    return x, np.max(psnr_history), best_index


In [None]:
def iterate(input_array, rho_=4.0, itr_=50, model=None, seed=None, scale=2,
            noise_level=5.0 / 255.0):
    """Simulate a degraded observation of one image channel and restore it.

    The ground truth is blurred with a 9x9 Gaussian kernel (sigma 1) under
    periodic boundaries, decimated by ``scale`` with nearest-neighbour
    sampling (both done by ``funcA``), corrupted with additive Gaussian noise
    and then reconstructed with ``pnp_fbs_superresolution``.

    Args:
        input_array: 2-D ground-truth image (presumably in [0, 1], since the
            reconstruction is clipped to that range and SSIM uses
            ``data_range=1``).
        rho_: gradient step size passed to the solver.
        itr_: number of solver iterations.
        model: pre-loaded denoiser. ``None`` calls ``load_model()``; pass a
            model to avoid reloading the checkpoint on every call.
        seed: seed for the observation noise. ``None`` keeps using the global
            NumPy RNG.
        scale: integer downsampling factor.
        noise_level: standard deviation of the additive Gaussian noise.

    Returns:
        ``(out, sigma, rho, psnr, ssim, best_psnr, best_index)``.
    """
    sigma = 50  # reported only; the solver does not use it
    rho = rho_

    with torch.no_grad():
        im_orig = np.ascontiguousarray(input_array)

        # ---- blur + downsample (the forward operator A) ----
        kernel = cv2.getGaussianKernel(9, 1)
        mask = np.outer(kernel, kernel.transpose())
        fx = fy = 1. / scale
        im_down = funcA(im_orig, mask, fx, fy)

        # ---- add noise ----
        rng = np.random if seed is None else np.random.default_rng(seed)
        im_noisy = im_down + rng.normal(0.0, noise_level, im_down.shape)

        # ---- plug and play ----
        if model is None:
            model = load_model()
        opts = dict(sigma=sigma, rho=rho, maxitr=itr_, verbose=False)
        out, best_psnr, best_index = pnp_fbs_superresolution(
            model, im_noisy, im_orig, fx, fy, mask, **opts)

        # ---- results ----
        psnr_ours = psnr(out, im_orig)
        ssim_ours = compare_ssim(out, im_orig, data_range=1.)

    print('sigma = {}, rho = {} - PSNR: {}, SSIM = {}  Max : {}, Index : {}'.format(
        sigma, rho, psnr_ours, ssim_ours, best_psnr, best_index))
    return (out, sigma, rho, psnr_ours, ssim_ours, best_psnr, best_index)


In [None]:
import math
import os
def calculate_psnr(img1, img2):
    """PSNR in dB for images on a 0-255 scale (peak value 255).

    Both inputs are cast to float64 before differencing, so uint8 images do
    not wrap around. Identical images yield ``inf``.
    """
    diff = img1.astype(np.float64) - img2.astype(np.float64)
    mse = np.mean(diff ** 2)
    return float('inf') if mse == 0 else 20 * math.log10(255.0 / math.sqrt(mse))

In [None]:
def driver(rho_=10, itr_=5,
           image_path='/content/drive/MyDrive/SRFP/DPIR/testsets/set3c/leaves.png'):
    """Super-resolve every channel of a colour image and save the reconstruction.

    Each channel is restored independently with ``iterate``. The per-channel
    best-iteration index and best PSNR are averaged and printed. The channels
    are reassembled into an 8-bit image, its PSNR against the original file is
    printed, and the image is written to the current working directory.

    Args:
        rho_: step size forwarded to ``iterate``.
        itr_: iteration count forwarded to ``iterate``.
        image_path: colour image to process.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``image_path``.
    """
    im_ref_u8 = cv2.imread(image_path)
    if im_ref_u8 is None:
        raise FileNotFoundError('could not read image: {}'.format(image_path))
    im_orig = im_ref_u8 / 255.0

    # ---- restore each channel (cv2.imread returns BGR order) ----
    channel_outputs = []
    index_total = 0
    best_psnr_total = 0
    for c in range(im_orig.shape[2]):
        out, *_, best_psnr, best_index = iterate(im_orig[:, :, c], rho_=rho_, itr_=itr_)
        channel_outputs.append(out)
        index_total += best_index
        best_psnr_total += best_psnr

    n_channels = len(channel_outputs)
    print('give_itr', math.ceil(index_total / n_channels), 'with_iter_set ****',
          best_psnr_total / n_channels)

    # ---- merge channels; iterate's output is clipped to [0, 1] ----
    im_recon = np.stack([np.uint8((out * 255.0).round()) for out in channel_outputs], axis=2)

    # ---- PSNR against the original 8-bit image, then save ----
    psnr_value = calculate_psnr(im_recon, im_ref_u8)
    print('current psnr', psnr_value)
    cv2.imwrite('LISTASUPERBLIND' + str(rho_) + 'itr' + str(itr_) + str(psnr_value) + '.png', im_recon)
