Download model weights

In [None]:
# Manually download the pretrained weights from https://drive.google.com/uc?id=14Fht1QQJ2gMlk4N1ERCRuElg8JfjrWWR and place them in the `models/` folder.

Initialize image-processing helpers and load the model

In [None]:
import torch
from basicsr.utils import img2tensor as _img2tensor, tensor2img, imwrite                                                                                                                                                                                  
import numpy as np
import cv2
import matplotlib.pyplot as plt
from basicsr.models import create_model
from basicsr.utils.options import parse

def imread(img_path):
    """Read an image file with OpenCV and return it with channels in RGB order.

    OpenCV decodes images in BGR order; the result is converted to RGB so it can
    be shown directly with matplotlib's ``imshow``.

    Args:
        img_path: path of the image file to read.

    Returns:
        The decoded image as a numpy array, channels in RGB order.

    Raises:
        FileNotFoundError: if OpenCV cannot read or decode ``img_path``.
    """
    img = cv2.imread(img_path)
    # cv2.imread reports failure by returning None rather than raising; without this
    # check the problem only shows up later as a cryptic cv2.error from cvtColor.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def img2tensor(img, bgr2rgb=False, float32=True):
    """Rescale ``img`` to float32 in [0, 1] and convert it with BasicSR's ``img2tensor``.

    Assumes ``img`` holds 0-255 pixel values (e.g. uint8 from ``imread``) — TODO confirm
    for other inputs. ``bgr2rgb`` defaults to False because this notebook's ``imread``
    already returns RGB.
    """
    normalized = img.astype(np.float32)
    normalized = normalized / 255.
    return _img2tensor(normalized, bgr2rgb=bgr2rgb, float32=float32)

def display(img1, img2, titles=('Input image', 'NAFNet output'), figsize=(25, 10)):
    """Show two images side by side in one matplotlib figure, axes hidden.

    NOTE: this name shadows IPython's built-in ``display()`` helper for the rest
    of the notebook.

    Args:
        img1: image for the left panel (anything ``Axes.imshow`` accepts).
        img2: image for the right panel.
        titles: ``(left, right)`` panel titles; defaults to the original labels.
        figsize: figure size in inches, passed to ``plt.figure``.
    """
    fig = plt.figure(figsize=figsize)
    left_title, right_title = titles
    # Use the explicit Axes API instead of plt.title, which targets whatever axes is current.
    for position, image, title in ((1, img1, left_title), (2, img2, right_title)):
        ax = fig.add_subplot(1, 2, position)
        ax.set_title(title, fontsize=16)
        ax.axis('off')
        ax.imshow(image)

def single_image_inference(model, img, save_path):
    """Run ``model`` on a single image tensor and write the result to ``save_path``.

    Args:
        model: a BasicSR model exposing ``feed_data``, ``test``, ``get_current_visuals``,
            ``net_g`` and an ``opt`` dict (optionally with ``opt['val']['grids']``).
        img: one image tensor without a batch axis; a batch axis is prepended here.
            Presumably (C, H, W) as returned by ``img2tensor`` — TODO confirm.
        save_path: destination file, written with BasicSR's ``imwrite``.
    """
    batch = img.unsqueeze(dim=0)
    # Match the input dtype to the network weights: after ``net_g.half()`` a float32
    # input would otherwise fail with a Float/Half type-mismatch in the first conv.
    first_param = next(model.net_g.parameters(), None)
    if first_param is not None and first_param.is_floating_point():
        batch = batch.to(first_param.dtype)
    model.feed_data(data={'lq': batch})

    # When opt['val']['grids'] is set, test() is wrapped in grids()/grids_inverse()
    # (presumably tiled inference in NAFNet — confirm against the model class).
    use_grids = model.opt.get('val', {}).get('grids', False)
    if use_grids:
        model.grids()
    model.test()
    if use_grids:
        model.grids_inverse()

    visuals = model.get_current_visuals()
    sr_img = tensor2img([visuals['result']])
    imwrite(sr_img, save_path)

# Build the NAFNet model for inference from its BasicSR option file.
opt_path = 'models/NAFNet-width64.yml'
opt = parse(opt_path, is_train=False)
opt.update(dist=False)  # disable distributed mode
model = create_model(opt)

# Last expression in the cell: show the network summary.
model

Set the model compression parameters

In [None]:
# Compression settings read by the cells below.
prune_amount = 0.5            # fraction of each Conv2d's weights to zero by L1 magnitude; 0 disables pruning
apply_quantization = True     # attach PyTorch quantization observers to model.net_g
apply_half_precision = False  # cast model.net_g to float16 on CUDA; ignored when apply_quantization is True (elif)
save_model = True             # intended to control whether the compressed weights are written to models/

# torch.nn.utils.prune interprets a float amount as a fraction, which must lie in [0, 1].
assert 0.0 <= prune_amount <= 1.0, f"prune_amount must be in [0, 1], got {prune_amount}"

Prune model weights

In [None]:
import torch.nn.utils.prune as prune

if prune_amount != 0:
    print(f"Pruning {prune_amount*100}% of weights...")
    # Pruning only reparametrizes parameters; it does not change the module tree,
    # so the Conv2d layers can be collected once and reused for both passes.
    conv_layers = [layer for layer in model.net_g.modules() if isinstance(layer, torch.nn.Conv2d)]
    for conv in conv_layers:
        # Zero the lowest-|w| fraction of this layer's weights (weight_orig * weight_mask).
        prune.l1_unstructured(conv, name='weight', amount=prune_amount)
    # Make pruning permanent (optional): fold the mask into `weight` and drop the reparametrization.
    for conv in conv_layers:
        if hasattr(conv, 'weight_orig'):
            prune.remove(conv, 'weight')

Fine-tune Pruned Model

Part 1: Dataset Preparation

- Download the training set from https://drive.google.com/file/d/1UHjWZzLPGweA9ZczmV8lFSRcIxqiOVJw/view?usp=sharing and extract the `train` folder from the zip into `./datasets/SIDD/Data`.

- Run `python sidd.py` to crop the training image pairs into 512×512 patches and convert the data to LMDB format.

- Download the evaluation data from https://drive.google.com/file/d/1gZx_K2vmiHalRNOb1aj93KuUQ2guOlLp/view?usp=sharing and extract the `SIDD` folder from the zip into `./datasets`, so that `./datasets/SIDD/val/input_crops.lmdb` and `./datasets/SIDD/val/gt_crops.lmdb` exist.

- **Blocked (TODO):** fine-tuning was not completed. BasicSR's training functions require Linux, and GPU acceleration could not be set up under WSL2, so the cells below continue with the pruned model without fine-tuning.

In [None]:
# Sanity check: can PyTorch see a CUDA GPU? (torch is imported in the setup cell;
# the half-precision branch later calls .cuda().)
torch.cuda.is_available()

Apply Quantization

In [None]:
import torch.ao.quantization as quant

if apply_quantization:
    # Eager-mode static quantization setup: attach a qconfig for the fbgemm (x86 CPU)
    # backend, using the same torch.ao.quantization namespace imported above.
    model.net_g.qconfig = quant.get_default_qconfig('fbgemm')
    # With inplace=True, prepare() inserts observers into model.net_g and returns that same module.
    prep = quant.prepare(model.net_g, inplace=True)
    # NOTE(review): no calibration pass and no quant.convert() are run here, so the weights are
    # still floating point; the network only carries observers at this point. Converting would
    # require moving to CPU (fbgemm kernels) and probably QuantStub/DeQuantStub around the float
    # input/output unless the architecture already has them — TODO confirm before labelling the
    # saved weights as quantized.
elif apply_half_precision:
    # Alternative to quantization: cast the network weights to float16 on the GPU.
    model.net_g.cuda().half()

Demo with a test image

In [None]:
# Run the model on one sample image, save the result, then show input and output side by side.
input_path = 'demo_images/noisy.png'
output_path = 'demo_output/noisy.png'

img_input = imread(input_path)
inp = img2tensor(img_input)
single_image_inference(model, inp, output_path)

# Reload from disk so the figure shows exactly what was written.
img_output = imread(output_path)
display(img_input, img_output)

Save Model

In [None]:
from os import replace

if save_model:
    # model.save is expected to write the generator weights to models/net_g_latest.pth.
    model.save(epoch=-1, current_iter=-1)
    # Rename to a descriptive name such as models/pruned30_quantized.pth.
    # round() avoids float truncation: 0.29 * 100 == 28.999..., which %d alone would print as 28.
    suffix = "_quantized" if apply_quantization else ""
    target = "models/pruned%d%s.pth" % (round(prune_amount * 100), suffix)
    # os.replace overwrites an existing target on every platform; os.rename raises on Windows.
    replace("models/net_g_latest.pth", target)