# NAFNet: Nonlinear Activation Free Network

**Image Deblurring**

## 1. Setup

**Clone the repo and install the dependencies.**

In [23]:
!git clone https://github.com/megvii-research/NAFNet
%cd NAFNet

Cloning into 'NAFNet'...
remote: Enumerating objects: 517, done.[K
remote: Counting objects: 100% (161/161), done.[K
remote: Compressing objects: 100% (49/49), done.[K
remote: Total 517 (delta 140), reused 112 (delta 112), pack-reused 356[K
Receiving objects: 100% (517/517), 16.19 MiB | 39.01 MiB/s, done.
Resolving deltas: 100% (271/271), done.
/content/NAFNet/NAFNet/NAFNet


In [24]:
!pip install -r requirements.txt
!pip install --upgrade --no-cache-dir gdown
!python3 setup.py develop --no_cuda_ext

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
running develop
running egg_info
creating basicsr.egg-info
writing basicsr.egg-info/PKG-INFO
writing dependency_links to basicsr.egg-info/dependency_links.txt
writing top-level names to basicsr.egg-info/top_level.txt
writing manifest file 'basicsr.egg-info/SOURCES.txt'
adding license file 'LICENSE'
writing manifest file 'basicsr.egg-info/SOURCES.txt'
running build_ext
Creating /usr/local/lib/python3.8/dist-packages/basicsr.egg-link (link to .)
Removing basicsr 1.2.0+50cb149 from easy-install.pth file
Adding basicsr 1.2.0+50cb149 to easy-install.pth file

Installed /content/NAFNet/NAFNet/NAFNet
Processing dependencies for basicsr==1.2.0+50cb149
Finished processing dependencies for basicsr==1.2.0+50cb149


## 2. Download pretrained models

In [28]:
import os
import gdown

# Google Drive file IDs of the pretrained NAFNet checkpoints; choose one with TASK.
# The deblurring checkpoint is saved as NAFNet-REDS-width64.pth and pairs with
# options/test/REDS/NAFNet-width64.yml.
PRETRAINED_IDS = {
    'deblurring': '14D4V4raNYIOhETfcuuLI3bGLB-OYIv6X',
    'denoising': '14Fht1QQJ2gMlk4N1ERCRuElg8JfjrWWR',
}
TASK = 'deblurring'

pretrained_dir = './experiments/pretrained_models/'
os.makedirs(pretrained_dir, exist_ok=True)  # make sure the destination folder exists
gdown.download(f'https://drive.google.com/uc?id={PRETRAINED_IDS[TASK]}', pretrained_dir, quiet=False)

Downloading...
From: https://drive.google.com/uc?id=14D4V4raNYIOhETfcuuLI3bGLB-OYIv6X
To: /content/NAFNet/NAFNet/NAFNet/experiments/pretrained_models/NAFNet-REDS-width64.pth
100%|██████████| 272M/272M [00:03<00:00, 72.4MB/s]


'./experiments/pretrained_models/NAFNet-REDS-width64.pth'

## 3. Mount Google Drive to read the input images and save the outputs

In [25]:
from google.colab import drive

# Mount Google Drive: the blurry inputs are read from it and the NAFNet outputs are written back to it.
drive.mount('/content/google-drive')

drive_dir = "/content/google-drive/Shareddrives/DeepLearning/skulldatasets"

# Dataset version v1 sub-folders.
input_dir = f"{drive_dir}/v1/input-blurry"      # images fed to NAFNet
out_dir = f"{drive_dir}/v1/Nafnet-deblurring"   # restored images are saved here

Drive already mounted at /content/google-drive; to attempt to forcibly remount, call drive.mount("/content/google-drive", force_remount=True).


## 4. Prepare the model and load the checkpoint

In [29]:
import torch

from basicsr.models import create_model
from basicsr.utils import img2tensor as _img2tensor, tensor2img, imwrite
from basicsr.utils.options import parse
import numpy as np
import cv2
import matplotlib.pyplot as plt

def imread(img_path):
  """Read an image file with OpenCV and return it in RGB channel order.

  Args:
    img_path: path of the image file.

  Returns:
    The image as a NumPy array, converted from OpenCV's BGR order to RGB.

  Raises:
    FileNotFoundError: if OpenCV cannot read `img_path` (missing or undecodable file).
  """
  img = cv2.imread(img_path)
  # cv2.imread signals failure by returning None instead of raising; without this
  # check the failure surfaces later as an opaque error inside cvtColor.
  if img is None:
    raise FileNotFoundError(f"Could not read image: {img_path}")
  return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def img2tensor(img, bgr2rgb=False, float32=True):
    """Scale an image by 1/255 (as float32) and hand it to basicsr's img2tensor.

    `bgr2rgb` defaults to False because images from this notebook's `imread`
    are already converted to RGB.
    """
    normalized = np.divide(img.astype(np.float32), 255.)
    return _img2tensor(normalized, bgr2rgb=bgr2rgb, float32=float32)

def display(img1, img2, title1='Input image', title2='NAFNet output'):
  """Show two images side by side and render the figure immediately.

  NOTE: this name shadows IPython's built-in `display()` in the notebook namespace.

  Args:
    img1: image for the left panel (e.g. the blurry input).
    img2: image for the right panel (e.g. the NAFNet output).
    title1: title of the left panel.
    title2: title of the right panel.
  """
  fig, axes = plt.subplots(1, 2, figsize=(25, 10))
  for ax, img, title in zip(axes, (img1, img2), (title1, title2)):
    ax.imshow(img)
    ax.set_title(title, fontsize=16)
    ax.axis('off')
  # Show and close right away: this is called once per image in a loop, and
  # figures left open would otherwise pile up until the cell finishes.
  plt.show()
  plt.close(fig)

def single_image_inference(model, img, save_path):
    """Run `model` on one image tensor and write the restored image to `save_path`.

    A leading batch dimension is added to `img` before it is fed as the 'lq'
    input (assumed to be a single (C, H, W) tensor — TODO confirm). When the
    model's validation options set 'grids', `model.test()` is wrapped by
    `model.grids()` / `model.grids_inverse()`.
    """
    batch = img.unsqueeze(0)
    model.feed_data(data={'lq': batch})

    use_grids = model.opt['val'].get('grids', False)
    if use_grids:
        model.grids()
    model.test()
    if use_grids:
        model.grids_inverse()

    restored = model.get_current_visuals()['result']
    imwrite(tensor2img([restored]), save_path)


In [30]:
# Test-time options for the REDS (deblurring) checkpoint.
# For the SIDD (denoising) checkpoint use 'options/test/SIDD/NAFNet-width64.yml' instead.
opt_path = 'options/test/REDS/NAFNet-width64.yml'
opt = parse(opt_path, is_train=False)
opt.update(dist=False)  # turn off the 'dist' (distributed) option for this single notebook process
NAFNet = create_model(opt)

 load net keys <built-in method keys of dict object at 0x7f36e586c940>


## 5. Inference

In [31]:
import glob
import os

# Extensions OpenCV can decode; sub-folders and other files in the input folder are
# skipped instead of reaching imread, where they would fail.
INFERENCE_EXTS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

os.makedirs(out_dir, exist_ok=True)
input_list = sorted(
    path for path in glob.glob(os.path.join(input_dir, '*'))
    if os.path.isfile(path) and os.path.splitext(path)[1].lower() in INFERENCE_EXTS
)
for input_path in input_list:
  img_input = imread(input_path)
  inp = img2tensor(img_input)
  # Keep the input file name so outputs can be matched back to their inputs.
  output_path = os.path.join(out_dir, os.path.basename(input_path))
  single_image_inference(NAFNet, inp, output_path)

## 6. Visualize Results

In [None]:
# visualize: pair every input with the output that has the same file name.
# Pairing by name (instead of zipping two independently sorted listings) keeps
# pairs aligned even when out_dir has extra or missing files.
input_list = sorted(glob.glob(os.path.join(input_dir, '*')))
for input_path in input_list:
  if not os.path.isfile(input_path):
    continue  # e.g. sub-folders
  output_path = os.path.join(out_dir, os.path.basename(input_path))
  if not os.path.isfile(output_path):
    print(f'No NAFNet output for {os.path.basename(input_path)}, skipping')
    continue
  display(imread(input_path), imread(output_path))

## 7. PSNR measurement on center-cropped patches

In [33]:
import os
import numpy as np
from torch.utils.data import Dataset
import torch
from PIL import Image
import torchvision.transforms.functional as TF
from pdb import set_trace as stx
import random

def is_image_file(filename):
    """Return True if `filename` has a .jpeg/.jpg/.png/.gif extension (case-insensitive).

    The real extension (after the last dot) is compared, so names such as
    'notapng' are rejected and mixed-case extensions such as '.Png' are accepted.
    """
    return os.path.splitext(filename)[1].lower() in ('.jpeg', '.jpg', '.png', '.gif')


class DataLoaderVal(Dataset):
    """Validation dataset of (target, input) image pairs read from two sub-folders.

    Image files in `input_dir` and `target_dir` (both relative to `rgb_dir`) are
    paired by their position after sorting by name, so both folders must hold the
    same number of images whose names sort into the same order.

    Args:
        rgb_dir: root directory containing both sub-folders.
        input_dir: sub-folder with the images to evaluate (e.g. restored outputs).
        target_dir: sub-folder with the reference images.
        img_options: optional dict; `img_options['patch_size']` (int or None)
            selects a square center crop of that size. None/omitted means the
            full images are used.
        rgb_dir2: unused; kept for signature compatibility.

    Raises:
        ValueError: if the two folders contain different numbers of image files.

    Each item is `(target_tensor, input_tensor, filename_stem)`.
    """

    def __init__(self, rgb_dir, input_dir, target_dir, img_options=None, rgb_dir2=None):
        super().__init__()

        inp_files = sorted(os.listdir(os.path.join(rgb_dir, input_dir)))
        tar_files = sorted(os.listdir(os.path.join(rgb_dir, target_dir)))

        self.inp_filenames = [os.path.join(rgb_dir, input_dir, x) for x in inp_files if is_image_file(x)]
        self.tar_filenames = [os.path.join(rgb_dir, target_dir, x) for x in tar_files if is_image_file(x)]

        # Pairs are formed by sorted position, so a count mismatch would silently
        # compare unrelated images (or hit an IndexError in __getitem__).
        if len(self.inp_filenames) != len(self.tar_filenames):
            raise ValueError(
                f"{len(self.inp_filenames)} input images in {os.path.join(rgb_dir, input_dir)} but "
                f"{len(self.tar_filenames)} target images in {os.path.join(rgb_dir, target_dir)}; "
                "images are paired by sorted position, so the counts must match"
            )

        self.img_options = img_options
        self.sizex = len(self.tar_filenames)  # get the size of target

        # img_options defaults to None; treat that (or a missing key) as "no cropping".
        self.ps = (img_options or {}).get('patch_size')

    def __len__(self):
        return self.sizex

    def __getitem__(self, index):
        index_ = index % self.sizex
        ps = self.ps

        inp_path = self.inp_filenames[index_]
        tar_path = self.tar_filenames[index_]

        inp_img = Image.open(inp_path)
        tar_img = Image.open(tar_path)

        # Validate on center crop
        if ps is not None:
            inp_img = TF.center_crop(inp_img, (ps, ps))
            tar_img = TF.center_crop(tar_img, (ps, ps))

        inp_img = TF.to_tensor(inp_img)
        tar_img = TF.to_tensor(tar_img)

        # File name without directory and extension, taken from the target image.
        filename = os.path.splitext(os.path.split(tar_path)[-1])[0]

        return tar_img, inp_img, filename

In [34]:
def torchPSNR(tar_img, prd_img):
    """PSNR in dB between two tensors after clamping both to [0, 1].

    Computed as 20 * log10(1 / RMSE) with a peak value of 1; identical inputs
    give +inf.
    """
    diff = prd_img.clamp(0, 1) - tar_img.clamp(0, 1)
    rmse = diff.pow(2).mean().sqrt()
    return 20 * torch.log10(1 / rmse)

def get_validation_data(rgb_dir, img_options, input_dir, target_dir):
    """Build a DataLoaderVal rooted at `rgb_dir`.

    Args:
        rgb_dir: root directory holding the `input_dir` and `target_dir` sub-folders.
        img_options: options dict forwarded to DataLoaderVal (e.g. {'patch_size': 256}).
        input_dir: sub-folder with the images to evaluate.
        target_dir: sub-folder with the reference images.

    Raises:
        FileNotFoundError: if `rgb_dir` does not exist.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not os.path.exists(rgb_dir):
        raise FileNotFoundError(f"Validation directory not found: {rgb_dir}")
    return DataLoaderVal(rgb_dir, input_dir, target_dir, img_options)

def get_pnsr(data_loader, model=None, device=None):
    """Average PSNR (dB) over every image yielded by `data_loader`.

    Args:
        data_loader: iterable of batches `(target, input, ...)` whose first two
            entries are batched image tensors (e.g. a DataLoader over DataLoaderVal).
        model: optional callable applied to each input batch under torch.no_grad().
            If it returns a list/tuple, the first element is used as the restored
            batch. When None (default), the inputs themselves are scored against
            the targets, which is what you want when the inputs are already
            restored images. The caller is responsible for putting an nn.Module
            into eval mode.
        device: device to run on; defaults to CUDA when available, else CPU.

    Returns:
        The mean per-image PSNR as a Python float.

    Raises:
        ValueError: if `data_loader` yields no images.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    psnr_val_rgb = []
    for data_val in data_loader:
        target = data_val[0].to(device)
        input_ = data_val[1].to(device)

        if model is None:
            restored = input_
        else:
            with torch.no_grad():
                restored = model(input_)
            if isinstance(restored, (list, tuple)):
                restored = restored[0]

        # One PSNR value per image in the batch.
        for res, tar in zip(restored, target):
            psnr_val_rgb.append(torchPSNR(res, tar))

    if not psnr_val_rgb:
        raise ValueError("data_loader yielded no images; cannot compute PSNR")

    return torch.stack(psnr_val_rgb).mean().item()

In [36]:
from torch.utils.data import DataLoader

# 'Nafnet-deblurring' holds the restored images written in section 5; 'traget'
# (sic — that is the folder name on the drive) holds the reference images.
# PSNR is computed on 256x256 center crops of each pair.
val_dir = '/content/google-drive/Shareddrives/DeepLearning/skulldatasets/v1/'
val_dataset = get_validation_data(val_dir, {'patch_size': 256}, input_dir='Nafnet-deblurring/',
                                  target_dir='traget')
val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False, num_workers=4,
                        drop_last=False, pin_memory=True)

# The checkpoint used above is the REDS deblurring model, not the denoising one.
print("Average PSNR between restored images and original images with Deblurring Pre-trained model")
print(get_pnsr(val_loader, model=None))

Average PSNR between restored images and original images with Denoising Pre-trained model
27.468708038330078
