<a href="https://colab.research.google.com/github/akramkhatami/IMDB_DeepInvo/blob/main/IMDN_IMDB_DeepInvo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os, shutil
from google.colab import drive
# Mount Google Drive at /content/drive (Colab only); the default dataset paths
# used by the conversion cell below live under ./drive/MyDrive.
drive.mount('/content/drive')

# png2npy

In [None]:
import os
import argparse
import skimage.io as sio
import numpy as np

def _str2bool(value):
  """Interpret a command-line value as a boolean ('false', '0', 'no', 'off', '' -> False)."""
  if isinstance(value, bool):
    return value
  return str(value).strip().lower() not in ('false', '0', 'no', 'n', 'off', '')


# Convert every .png found under --pathFrom into a .npy file under --pathTo,
# mirroring the source directory layout.
parser = argparse.ArgumentParser(description='Pre-processing .png images')
parser.add_argument('--pathFrom', default='./drive/MyDrive/dataset/DIV2K',
                    help='directory of images to convert')
parser.add_argument('--pathTo', default='./drive/MyDrive/dataset/DIV2K_Decoded',
                    help='directory of images to save')
# Without a type converter any string (even "False") would be truthy.
parser.add_argument('--split', default=True, type=_str2bool,
                    help='save individual images')
parser.add_argument('--select', default='',
                    help='select certain path')


# An explicit empty argv keeps argparse from reading the notebook kernel's
# own command line; only the defaults above are used.
args = parser.parse_args(args=[])

for (path, dirs, files) in os.walk(args.pathFrom):
  print(path)
  # Only descend into paths containing the --select substring (if given).
  if args.select and args.select not in path:
    continue

  # relpath is robust to a trailing separator on --pathFrom, unlike slicing
  # the prefix off by length.
  relative = os.path.relpath(path, args.pathFrom)
  targetDir = os.path.normpath(os.path.join(args.pathTo, relative))
  # makedirs also creates missing parents (e.g. --pathTo itself).
  os.makedirs(targetDir, exist_ok=True)

  # Only leaf directories (no sub-directories) are converted.
  if not dirs:
    n = 0
    for fileName in files:
      (idx, ext) = os.path.splitext(fileName)
      if ext == '.png':
        image = sio.imread(os.path.join(path, fileName))
        if args.split:
          np.save(os.path.join(targetDir, idx + '.npy'), image)
        n += 1
        if n % 100 == 0:
          print('Converted ' + str(n) + ' images.')


# common

In [None]:
import random
import torch
import numpy as np
import skimage.color as sc


def get_patch(*args, patch_size, scale):
  """Cut one random, spatially aligned patch out of every image in ``args``.

  ``args[0]`` is cropped to ``patch_size // scale`` pixels per side; every
  following image is cropped to ``patch_size`` pixels per side at the same
  location multiplied by ``scale``. All images are indexed as (H, W, C).

  Returns a list holding the patch of ``args[0]`` followed by the others.
  """
  low_res, high_res_images = args[0], args[1:]
  height, width = low_res.shape[:2]

  hr_side = patch_size           # side of the patches taken from args[1:]
  lr_side = hr_side // scale     # side of the patch taken from args[0]

  # Draw x before y so the random stream is consumed in a fixed order.
  left = random.randrange(0, width - lr_side + 1)
  top = random.randrange(0, height - lr_side + 1)
  hr_left = scale * left
  hr_top = scale * top

  patches = [low_res[top:top + lr_side, left:left + lr_side, :]]
  for image in high_res_images:
    patches.append(image[hr_top:hr_top + hr_side, hr_left:hr_left + hr_side, :])
  return patches


def set_channel(*args, n_channels=3):
  """Bring every image in ``args`` to ``n_channels`` channels.

  2-D images first gain a trailing channel axis. A 3-channel image is reduced
  to the first channel of ``sc.rgb2ycbcr`` when ``n_channels == 1``; a
  1-channel image is replicated when ``n_channels == 3``. Any other
  combination is returned unchanged. Returns a list, one entry per input.
  """
  def _convert(img):
    if img.ndim == 2:
      img = img[:, :, np.newaxis]

    channels = img.shape[2]
    if channels == 3 and n_channels == 1:
      luma = sc.rgb2ycbcr(img)[:, :, 0]
      return np.expand_dims(luma, 2)
    if channels == 1 and n_channels == 3:
      return np.concatenate([img] * n_channels, 2)
    return img

  return list(map(_convert, args))


def np2Tensor(*args, rgb_range):
  """Convert (H, W, C) numpy images into float (C, H, W) tensors.

  Each tensor is multiplied by ``rgb_range / 255``. Returns a list with one
  tensor per input image.
  """
  factor = rgb_range / 255

  def _convert(img):
    channels_first = np.ascontiguousarray(img.transpose((2, 0, 1)))
    result = torch.from_numpy(channels_first).float()
    result.mul_(factor)
    return result

  return [_convert(image) for image in args]


def augment(*args, hflip=True, rot=True):
  """Apply the same random flips / transpose to every (H, W, C) image in ``args``.

  ``hflip`` enables a horizontal flip; ``rot`` enables a vertical flip and an
  H/W transpose. Each enabled operation is applied with probability 0.5, and
  the decision is shared by all images. Returns a list of (views of) images.
  """
  # One random draw per enabled operation, in this fixed order.
  flip_lr = hflip and random.random() < 0.5
  flip_ud = rot and random.random() < 0.5
  swap_hw = rot and random.random() < 0.5

  def _apply(img):
    if flip_lr:
      img = img[:, ::-1, :]
    if flip_ud:
      img = img[::-1, :, :]
    if swap_hw:
      img = img.transpose(1, 0, 2)
    return img

  return list(map(_apply, args))

# DIV2K

In [None]:
import torch.utils.data as data
import os.path
import cv2
import numpy as np


def default_loader(path):
  """Read an image file with OpenCV and return it in RGB channel order.

  Args:
    path: path of the image file.

  Returns:
    The decoded array with the first three channels reordered from OpenCV's
    BGR to RGB; a single-channel image is returned as (H, W, 1).

  Raises:
    OSError: if OpenCV cannot read or decode the file.
  """
  img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
  if img is None:
    # cv2.imread signals failure by returning None instead of raising.
    raise OSError('cannot read image file: %s' % path)
  if img.ndim == 2:
    # IMREAD_UNCHANGED keeps grayscale files 2-D; there is no channel axis to reorder.
    return img[:, :, None]
  return img[:, :, [2, 1, 0]]

def npy_loader(path):
  """Load and return the array stored in the ``.npy`` file at ``path``."""
  array = np.load(path)
  return array

# File extensions treated as dataset images by make_dataset.
IMG_EXTENSIONS = ['.png', '.npy',
                  ]

def is_image_file(filename):
  """Return True when ``filename`` ends with one of ``IMG_EXTENSIONS``."""
  return filename.endswith(tuple(IMG_EXTENSIONS))

def make_dataset(dir):
  """Recursively collect the image files below ``dir``.

  Args:
    dir: root directory to scan.

  Returns:
    List of paths whose name ends with one of ``IMG_EXTENSIONS``, ordered by
    directory and then by file name.

  Raises:
    AssertionError: if ``dir`` is not an existing directory.
  """
  assert os.path.isdir(dir), '%s is not a valid directory' % dir

  # Match against IMG_EXTENSIONS directly instead of calling is_image_file:
  # a later notebook cell redefines is_image_file with a different extension
  # list (no '.npy'), which would silently empty this scan.
  extensions = tuple(IMG_EXTENSIONS)
  images = []
  for root, _, fnames in sorted(os.walk(dir)):
    for fname in sorted(fnames):
      if fname.endswith(extensions):
        images.append(os.path.join(root, fname))
  return images


class div2k(data.Dataset):
  """Paired low-/high-resolution DIV2K dataset.

  Configuration is read from ``opt``; the attributes used here are ``scale``,
  ``root``, ``ext``, ``phase``, ``n_train``, ``patch_size``, ``n_colors`` and
  ``rgb_range``. Each item is an ``(lr_tensor, hr_tensor)`` pair produced by
  ``np2Tensor``.
  """

  def __init__(self, opt):
    super(div2k, self).__init__()
    self.opt = opt
    self.scale = self.opt.scale
    self.root = self.opt.root
    self.ext = self.opt.ext   # '.png' or '.npy'(default)
    self.train = self.opt.phase == 'train'
    # Number of times each training image is visited per epoch (see __len__).
    self.repeat = 10  # self.opt.test_every // (self.opt.n_train // self.opt.batch_size)
    self._set_filesystem(self.root)
    self.images_hr, self.images_lr = self._scan()

  def _set_filesystem(self, dir_data):
    """Derive the HR and LR directories from the dataset root ``dir_data``."""
    self.root = dir_data + '/DIV2K_Decoded'
    self.dir_hr = os.path.join(self.root, 'DIV2K_train_HR')
    self.dir_lr = os.path.join(self.root, 'DIV2K_train_LR_bicubic/X' + str(self.scale))

  def __getitem__(self, idx):
    """Return the ``(lr_tensor, hr_tensor)`` pair for sample ``idx``."""
    lr, hr = self._load_file(idx)
    lr, hr = self._get_patch(lr, hr)
    lr, hr = set_channel(lr, hr, n_channels=self.opt.n_colors)
    lr_tensor, hr_tensor = np2Tensor(lr, hr, rgb_range=self.opt.rgb_range)
    return lr_tensor, hr_tensor

  def __len__(self):
    """Number of samples: ``n_train * repeat`` when training, else the file count."""
    if self.train:
      return self.opt.n_train * self.repeat
    # Previously this path fell through and returned None, which makes
    # len(dataset) raise TypeError outside the training phase.
    return len(self.images_lr)

  def _get_index(self, idx):
    """Map a sample index to a file index (wraps around ``n_train`` when training)."""
    if self.train:
      return idx % self.opt.n_train
    else:
      return idx

  def _get_patch(self, img_in, img_tar):
    """Random aligned crop + augmentation when training; otherwise trim HR to LR * scale."""
    patch_size = self.opt.patch_size
    scale = self.scale
    if self.train:
      img_in, img_tar = get_patch(img_in, img_tar, patch_size=patch_size, scale=scale)
      img_in, img_tar = augment(img_in, img_tar)
    else:
      ih, iw = img_in.shape[:2]
      img_tar = img_tar[0:ih * scale, 0:iw * scale, :]
    return img_in, img_tar

  def _scan(self):
    """Return the sorted HR and LR file lists; pairs are matched by position."""
    list_hr = sorted(make_dataset(self.dir_hr))
    list_lr = sorted(make_dataset(self.dir_lr))
    return list_hr, list_lr

  def _load_file(self, idx):
    """Load the LR/HR arrays for sample ``idx`` using the loader selected by ``ext``."""
    idx = self._get_index(idx)
    if self.ext == '.npy':
      lr = npy_loader(self.images_lr[idx])
      hr = npy_loader(self.images_hr[idx])
    else:
      lr = default_loader(self.images_lr[idx])
      hr = default_loader(self.images_hr[idx])
    return lr, hr

# Set5

In [None]:
import torch.utils.data as data
from os.path import join
from os import listdir

from PIL import Image
import numpy as np
from torchvision.transforms import Compose

from torchvision.transforms import ToTensor

def img_modcrop(image, modulo):
    """Crop ``image`` from its top-left corner so both sides are multiples of ``modulo``.

    ``image`` must expose ``size`` as (width, height) and a ``crop(box)``
    method; the result of ``crop`` is returned.
    """
    width, height = image.size[0], image.size[1]
    cropped_w = np.int32(width / modulo) * modulo
    cropped_h = np.int32(height / modulo) * modulo
    return image.crop((0, 0, cropped_w, cropped_h))


def np2tensor():
    """Build the image-to-tensor pipeline: a ``Compose`` holding a single ``ToTensor``."""
    steps = [ToTensor()]
    return Compose(steps)


def is_image_file(filename):
    """Return True when ``filename`` ends with ``.bmp``, ``.png`` or ``.jpg``."""
    return filename.endswith((".bmp", ".png", ".jpg"))


def load_image(filepath):
    """Open the image at ``filepath`` and return it converted to RGB mode."""
    picture = Image.open(filepath)
    return picture.convert('RGB')


class DatasetFromFolderVal(data.Dataset):
    """Validation dataset pairing LR and HR image files from two folders.

    Files are matched by their position in the sorted listings of ``lr_dir``
    and ``hr_dir``. The HR image is cropped so its sides are multiples of
    ``upscale`` before conversion to a tensor.
    """

    def __init__(self, hr_dir, lr_dir, upscale):
        super().__init__()

        def _sorted_images(folder):
            paths = [join(folder, name) for name in listdir(folder) if is_image_file(name)]
            paths.sort()
            return paths

        self.hr_filenames = _sorted_images(hr_dir)
        self.lr_filenames = _sorted_images(lr_dir)
        self.upscale = upscale

    def __getitem__(self, index):
        """Return the ``(lr_tensor, hr_tensor)`` pair at ``index``."""
        lr_image = load_image(self.lr_filenames[index])
        hr_image = load_image(self.hr_filenames[index])
        lr_tensor = np2tensor()(lr_image)
        hr_tensor = np2tensor()(img_modcrop(hr_image, self.upscale))
        return lr_tensor, hr_tensor

    def __len__(self):
        """Number of LR files found."""
        return len(self.lr_filenames)

# utils

In [None]:
import numpy as np
import os
import torch
from collections import OrderedDict
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

def compute_psnr(im1, im2):
  """Return the peak signal-to-noise ratio between ``im1`` and ``im2``."""
  return psnr(im1, im2)


def compute_ssim(im1, im2):
  """Return the structural similarity between ``im1`` and ``im2``.

  A 3-D input whose last axis has length 3 is treated as a colour image (the
  last axis is the channel axis); anything else is treated as single-channel.
  Uses Gaussian weighting with sigma 1.5, K1=0.01, K2=0.03 and population
  covariance.
  """
  is_rgb = len(im1.shape) == 3 and im1.shape[-1] == 3
  options = dict(K1=0.01, K2=0.03, gaussian_weights=True, sigma=1.5,
                 use_sample_covariance=False)
  try:
    # `channel_axis` is the current spelling; `multichannel` is deprecated.
    return ssim(im1, im2, channel_axis=-1 if is_rgb else None, **options)
  except TypeError:
    # Older scikit-image releases that predate `channel_axis`.
    return ssim(im1, im2, multichannel=is_rgb, **options)


def shave(im, border):
  """Remove ``border`` pixels from every side of the first two axes of ``im``.

  Args:
    im: array indexed as (H, W, ...).
    border: number of rows/columns to drop on each side; 0 returns the
      full extent.

  Returns:
    A view of ``im`` with ``border`` rows and columns trimmed from each edge.
  """
  height, width = im.shape[:2]
  # Explicit end indices: a negative-stop slice would turn border=0 into
  # im[0:-0] == im[0:0], i.e. an empty array.
  return im[border:height - border, border:width - border, ...]


def modcrop(im, modulo):
  """Trim the first two axes of ``im`` (from the origin) to multiples of ``modulo``."""
  shape = im.shape
  rows = np.int32(shape[0] / modulo) * modulo
  cols = np.int32(shape[1] / modulo) * modulo
  return im[0:rows, 0:cols, ...]


def get_list(path, ext):
  """Return the paths of the entries in directory ``path`` whose name ends with ``ext``."""
  matches = []
  for entry in os.listdir(path):
    if entry.endswith(ext):
      matches.append(os.path.join(path, entry))
  return matches


def convert_shape(img):
  """Scale a (C, H, W) array by 255, round, move channels last and return it as uint8."""
  scaled = (img * 255.0).round()
  channels_last = np.transpose(scaled, (1, 2, 0))
  clipped = np.clip(channels_last, 0, 255)
  return np.uint8(clipped)


def quantize(img):
  """Clip ``img`` to [0, 255], round, and return it as a uint8 array."""
  bounded = img.clip(0, 255)
  rounded = bounded.round()
  return rounded.astype(np.uint8)


def tensor2np(tensor, out_type=np.uint8, min_max=(0, 1)):
  """Convert a (C, H, W) tensor into an (H, W, C) numpy array.

  Args:
    tensor: input tensor; it is left unmodified.
    out_type: numpy dtype of the result. For ``np.uint8`` the values are
      scaled to [0, 255] and rounded; otherwise they stay in [0, 1].
    min_max: ``(low, high)`` range the input is clamped to before being
      rescaled to [0, 1].

  Returns:
    numpy array of dtype ``out_type`` with the channel axis last.
  """
  low, high = min_max[0], min_max[1]
  # Out-of-place clamp: for a tensor that is already float32 on the CPU,
  # .float().cpu() returns the very same object, so an in-place clamp_ would
  # overwrite the caller's data.
  tensor = tensor.detach().float().cpu().clamp(low, high)
  tensor = (tensor - low) / (high - low)  # to range [0, 1]
  img_np = tensor.numpy()
  img_np = np.transpose(img_np, (1, 2, 0))
  if out_type == np.uint8:
    img_np = (img_np * 255.0).round()

  return img_np.astype(out_type)

def convert2np(tensor):
  """Scale ``tensor`` by 255, clamp to [0, 255], cast to bytes and return it channels-last.

  All singleton dimensions are squeezed before the (C, H, W) -> (H, W, C) permute.
  """
  scaled = tensor.cpu().mul(255).clamp(0, 255)
  as_bytes = scaled.byte().squeeze()
  return as_bytes.permute(1, 2, 0).numpy()


def adjust_learning_rate(optimizer, epoch, step_size, lr_init, gamma):
  """Step-decay schedule: set every param group's lr to ``lr_init * gamma ** (epoch // step_size)``."""
  completed_steps = epoch // step_size
  new_lr = lr_init * (gamma ** completed_steps)
  for group in optimizer.param_groups:
    group['lr'] = new_lr

def load_state_dict(path, map_location=None):
  """Load a checkpoint and strip the ``module.`` prefix from its keys.

  Args:
    path: checkpoint file readable by ``torch.load``.
    map_location: forwarded to ``torch.load`` (default ``None`` keeps the
      devices recorded in the checkpoint).

  Returns:
    OrderedDict with the same entries and order as the checkpoint; keys that
    start with ``module.`` have that prefix removed, all others are kept as is.
  """
  prefix = 'module.'
  state_dict = torch.load(path, map_location=map_location)
  new_state_dict = OrderedDict()
  for key, value in state_dict.items():
    # Only a leading prefix is stripped; a key that merely contains
    # "module" elsewhere (e.g. "submodule.x") is left untouched.
    name = key[len(prefix):] if key.startswith(prefix) else key
    # Store every entry, prefixed or not.
    new_state_dict[name] = value
  return new_state_dict


# Involution

In [None]:
!pip install mmcv==2.0.0 -f https://download.openmmlab.com/mmcv/dist/cu118/torch2.0/index.html

In [None]:
from torch.autograd import Function
import torch
from torch.nn.modules.utils import _pair
import torch.nn.functional as F
import torch.nn as nn
from mmcv.cnn import ConvModule


from collections import namedtuple

import cupy
from string import Template


# Minimal holder for a raw CUDA stream handle; instances are passed as the
# `stream=` argument when the compiled kernels are launched.
Stream = namedtuple('Stream', ['ptr'])


def Dtype(t):
    """Return the C type name ('float' or 'double') for a CUDA tensor ``t``.

    Returns None for anything that is neither a CUDA float nor a CUDA double tensor.
    """
    if isinstance(t, torch.cuda.FloatTensor):
        return 'float'
    if isinstance(t, torch.cuda.DoubleTensor):
        return 'double'
    return None


@cupy._util.memoize(for_each_device=True)
def load_kernel(kernel_name, code, **kwargs):
    """Fill the ``${...}`` placeholders in ``code``, compile it and return ``kernel_name``.

    Results are memoized per device by the decorator.
    """
    # NOTE(review): cupy.cuda.compile_with_cache may be unavailable in newer
    # CuPy releases -- confirm against the installed version.
    rendered_source = Template(code).substitute(**kwargs)
    compiled_module = cupy.cuda.compile_with_cache(rendered_source)
    return compiled_module.get_function(kernel_name)


# Threads per block used for every kernel launch in this cell.
CUDA_NUM_THREADS = 1024

# C macro prepended to each kernel source. It iterates `i` from the thread's
# global index up to n in steps of blockDim.x * gridDim.x.
# The trailing backslashes sit in a non-raw Python string, so Python consumes
# each backslash-newline pair and the macro reaches the compiler as one line.
kernel_loop = '''
#define CUDA_KERNEL_LOOP(i, n)                        \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
      i < (n);                                       \
      i += blockDim.x * gridDim.x)
'''


def GET_BLOCKS(N):
    """Number of blocks of ``CUDA_NUM_THREADS`` threads needed to cover ``N`` items (ceiling division)."""
    full_blocks, leftover = divmod(N, CUDA_NUM_THREADS)
    return full_blocks + (1 if leftover else 0)


# CUDA source of the forward pass; the ${...} placeholders are filled in by
# string.Template inside load_kernel.
# Each loop index is one output element, decoded into (n, c, h, w). The kernel
# accumulates weight * input over the kernel_h x kernel_w window and skips taps
# that fall outside the input. The weight offset indexes weight_data as
# (num, groups, kernel_h, kernel_w, top_height, top_width); channel c uses
# group g = c / (channels / groups).
_involution_kernel = kernel_loop + '''
extern "C"
__global__ void involution_forward_kernel(
const ${Dtype}* bottom_data, const ${Dtype}* weight_data, ${Dtype}* top_data) {
  CUDA_KERNEL_LOOP(index, ${nthreads}) {
    const int n = index / ${channels} / ${top_height} / ${top_width};
    const int c = (index / ${top_height} / ${top_width}) % ${channels};
    const int h = (index / ${top_width}) % ${top_height};
    const int w = index % ${top_width};
    const int g = c / (${channels} / ${groups});
    ${Dtype} value = 0;
    #pragma unroll
    for (int kh = 0; kh < ${kernel_h}; ++kh) {
      #pragma unroll
      for (int kw = 0; kw < ${kernel_w}; ++kw) {
        const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h};
        const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w};
        if ((h_in >= 0) && (h_in < ${bottom_height})
          && (w_in >= 0) && (w_in < ${bottom_width})) {
          const int offset = ((n * ${channels} + c) * ${bottom_height} + h_in)
            * ${bottom_width} + w_in;
          const int offset_weight = ((((n * ${groups} + g) * ${kernel_h} + kh) * ${kernel_w} + kw) * ${top_height} + h)
            * ${top_width} + w;
          value += weight_data[offset_weight] * bottom_data[offset];
        }
      }
    }
    top_data[index] = value;
  }
}
'''


# CUDA source for the gradient with respect to the input.
# Each loop index is one input element, decoded into (n, c, h, w). For every
# kernel tap it derives the output position that read this input element; the
# tap contributes only when that position is divisible by the stride and lies
# inside the output. Contributions are weight * top_diff, summed into
# bottom_diff[index].
_involution_kernel_backward_grad_input = kernel_loop + '''
extern "C"
__global__ void involution_backward_grad_input_kernel(
    const ${Dtype}* const top_diff, const ${Dtype}* const weight_data, ${Dtype}* const bottom_diff) {
  CUDA_KERNEL_LOOP(index, ${nthreads}) {
    const int n = index / ${channels} / ${bottom_height} / ${bottom_width};
    const int c = (index / ${bottom_height} / ${bottom_width}) % ${channels};
    const int h = (index / ${bottom_width}) % ${bottom_height};
    const int w = index % ${bottom_width};
    const int g = c / (${channels} / ${groups});
    ${Dtype} value = 0;
    #pragma unroll
    for (int kh = 0; kh < ${kernel_h}; ++kh) {
      #pragma unroll
      for (int kw = 0; kw < ${kernel_w}; ++kw) {
        const int h_out_s = h + ${pad_h} - kh * ${dilation_h};
        const int w_out_s = w + ${pad_w} - kw * ${dilation_w};
        if (((h_out_s % ${stride_h}) == 0) && ((w_out_s % ${stride_w}) == 0)) {
          const int h_out = h_out_s / ${stride_h};
          const int w_out = w_out_s / ${stride_w};
          if ((h_out >= 0) && (h_out < ${top_height})
                && (w_out >= 0) && (w_out < ${top_width})) {
            const int offset = ((n * ${channels} + c) * ${top_height} + h_out)
                  * ${top_width} + w_out;
            const int offset_weight = ((((n * ${groups} + g) * ${kernel_h} + kh) * ${kernel_w} + kw) * ${top_height} + h_out)
                  * ${top_width} + w_out;
            value += weight_data[offset_weight] * top_diff[offset];
          }
        }
      }
    }
    bottom_diff[index] = value;
  }
}
'''


# CUDA source for the gradient with respect to the per-pixel weights.
# Each loop index is one weight element, decoded into (n, g, kh, kw, h, w).
# The kernel sums top_diff * bottom_data over the channels belonging to group g.
# When the tap falls outside the input it writes 0, so every element of
# buffer_data is assigned.
_involution_kernel_backward_grad_weight = kernel_loop + '''
extern "C"
__global__ void involution_backward_grad_weight_kernel(
    const ${Dtype}* const top_diff, const ${Dtype}* const bottom_data, ${Dtype}* const buffer_data) {
  CUDA_KERNEL_LOOP(index, ${nthreads}) {
    const int h = (index / ${top_width}) % ${top_height};
    const int w = index % ${top_width};
    const int kh = (index / ${kernel_w} / ${top_height} / ${top_width})
          % ${kernel_h};
    const int kw = (index / ${top_height} / ${top_width}) % ${kernel_w};
    const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h};
    const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w};
    if ((h_in >= 0) && (h_in < ${bottom_height})
          && (w_in >= 0) && (w_in < ${bottom_width})) {
      const int g = (index / ${kernel_h} / ${kernel_w} / ${top_height} / ${top_width}) % ${groups};
      const int n = (index / ${groups} / ${kernel_h} / ${kernel_w} / ${top_height} / ${top_width}) % ${num};
      ${Dtype} value = 0;
      #pragma unroll
      for (int c = g * (${channels} / ${groups}); c < (g + 1) * (${channels} / ${groups}); ++c) {
        const int top_offset = ((n * ${channels} + c) * ${top_height} + h)
              * ${top_width} + w;
        const int bottom_offset = ((n * ${channels} + c) * ${bottom_height} + h_in)
              * ${bottom_width} + w_in;
        value += top_diff[top_offset] * bottom_data[bottom_offset];
      }
      buffer_data[index] = value;
    } else {
      buffer_data[index] = 0;
    }
  }
}
'''


class _involution(Function):
    """autograd Function that launches the involution CUDA kernels.

    ``forward`` takes a 4-D CUDA ``input`` and a 6-D CUDA ``weight`` laid out
    as (batch, groups, kernel_h, kernel_w, out_h, out_w), plus ``stride``,
    ``padding`` and ``dilation`` given as 2-element sequences.
    """

    @staticmethod
    def forward(ctx, input, weight, stride, padding, dilation):
        assert input.dim() == 4 and input.is_cuda
        assert weight.dim() == 6 and weight.is_cuda
        # The kernels index raw data pointers assuming a dense row-major
        # layout, so make both operands contiguous before taking data_ptr().
        input = input.contiguous()
        weight = weight.contiguous()
        batch_size, channels, height, width = input.size()
        kernel_h, kernel_w = weight.size()[2:4]
        output_h = int((height + 2 * padding[0] - (dilation[0] * (kernel_h - 1) + 1)) / stride[0] + 1)
        output_w = int((width + 2 * padding[1] - (dilation[1] * (kernel_w - 1) + 1)) / stride[1] + 1)

        # Uninitialised buffer: the forward kernel writes every element.
        output = input.new(batch_size, channels, output_h, output_w)
        n = output.numel()

        with torch.cuda.device_of(input):
            f = load_kernel('involution_forward_kernel', _involution_kernel, Dtype=Dtype(input), nthreads=n,
                            num=batch_size, channels=channels, groups=weight.size()[1],
                            bottom_height=height, bottom_width=width,
                            top_height=output_h, top_width=output_w,
                            kernel_h=kernel_h, kernel_w=kernel_w,
                            stride_h=stride[0], stride_w=stride[1],
                            dilation_h=dilation[0], dilation_w=dilation[1],
                            pad_h=padding[0], pad_w=padding[1])
            f(block=(CUDA_NUM_THREADS,1,1),
              grid=(GET_BLOCKS(n),1,1),
              args=[input.data_ptr(), weight.data_ptr(), output.data_ptr()],
              stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

        ctx.save_for_backward(input, weight)
        ctx.stride, ctx.padding, ctx.dilation = stride, padding, dilation
        return output

    @staticmethod
    def backward(ctx, grad_output):
        assert grad_output.is_cuda
        # Normalise the layout instead of asserting on it: the incoming
        # gradient is read through its raw pointer as well.
        grad_output = grad_output.contiguous()
        input, weight = ctx.saved_tensors
        stride, padding, dilation = ctx.stride, ctx.padding, ctx.dilation

        batch_size, channels, height, width = input.size()
        kernel_h, kernel_w = weight.size()[2:4]
        output_h, output_w = grad_output.size()[2:]

        grad_input, grad_weight = None, None

        opt = dict(Dtype=Dtype(grad_output),
                   num=batch_size, channels=channels, groups=weight.size()[1],
                   bottom_height=height, bottom_width=width,
                   top_height=output_h, top_width=output_w,
                   kernel_h=kernel_h, kernel_w=kernel_w,
                   stride_h=stride[0], stride_w=stride[1],
                   dilation_h=dilation[0], dilation_w=dilation[1],
                   pad_h=padding[0], pad_w=padding[1])

        with torch.cuda.device_of(input):
            if ctx.needs_input_grad[0]:
                grad_input = input.new(input.size())

                n = grad_input.numel()
                opt['nthreads'] = n

                f = load_kernel('involution_backward_grad_input_kernel',
                                _involution_kernel_backward_grad_input, **opt)
                f(block=(CUDA_NUM_THREADS,1,1),
                  grid=(GET_BLOCKS(n),1,1),
                  args=[grad_output.data_ptr(), weight.data_ptr(), grad_input.data_ptr()],
                  stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

            if ctx.needs_input_grad[1]:
                grad_weight = weight.new(weight.size())

                n = grad_weight.numel()
                opt['nthreads'] = n

                f = load_kernel('involution_backward_grad_weight_kernel',
                                _involution_kernel_backward_grad_weight, **opt)
                f(block=(CUDA_NUM_THREADS,1,1),
                  grid=(GET_BLOCKS(n),1,1),
                  args=[grad_output.data_ptr(), input.data_ptr(), grad_weight.data_ptr()],
                  stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

        # One gradient slot per forward argument; stride/padding/dilation get none.
        return grad_input, grad_weight, None, None, None


def _involution_cuda(input, weight, bias=None, stride=1, padding=0, dilation=1):
    """ involution kernel
    """
    assert input.size(0) == weight.size(0)
    assert input.size(-2)//stride == weight.size(-2)
    assert input.size(-1)//stride == weight.size(-1)
    if input.is_cuda:
        out = _involution.apply(input, weight, _pair(stride), _pair(padding), _pair(dilation))
        if bias is not None:
            out += bias.view(1,-1,1,1)
    else:
        raise NotImplementedError
    return out


class involution(nn.Module):
    """Involution layer: per-pixel kernels are generated from the input and applied to it.

    Args:
        channels: number of input (and output) channels.
        kernel_size: side of the square involution kernel.
        stride: spatial stride; when > 1 the kernel-generating branch first
            average-pools the input by the same factor.
        reduction_ratio: channel reduction of the first 1x1 conv (default 4).
        group_channels: channels sharing one kernel; ``channels // group_channels``
            groups are used (default 16).

    Raises:
        ValueError: if ``channels`` is smaller than ``group_channels`` (zero groups).
    """

    def __init__(self,
                 channels,
                 kernel_size,
                 stride,
                 reduction_ratio=4,
                 group_channels=16):
        super(involution, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.channels = channels
        self.group_channels = group_channels
        self.groups = self.channels // self.group_channels
        if self.groups < 1:
            # Zero groups would create an empty kernel-generating conv and a
            # division by zero inside the CUDA kernels.
            raise ValueError(
                'channels ({}) must be at least group_channels ({})'.format(channels, group_channels))
        # Bottleneck 1x1 conv (+BN +ReLU) ...
        self.conv1 = ConvModule(
            in_channels=channels,
            out_channels=channels // reduction_ratio,
            kernel_size=1,
            conv_cfg=None,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU'))
        # ... followed by a plain 1x1 conv that emits kernel_size**2 weights per group.
        self.conv2 = ConvModule(
            in_channels=channels // reduction_ratio,
            out_channels=kernel_size**2 * self.groups,
            kernel_size=1,
            stride=1,
            conv_cfg=None,
            norm_cfg=None,
            act_cfg=None)
        if stride > 1:
            self.avgpool = nn.AvgPool2d(stride, stride)

    def forward(self, x):
        """Generate the per-pixel kernels from ``x`` and apply them to ``x``."""
        weight = self.conv2(self.conv1(x if self.stride == 1 else self.avgpool(x)))
        b, c, h, w = weight.shape
        # (b, groups * k * k, h, w) -> (b, groups, k, k, h, w)
        weight = weight.view(b, self.groups, self.kernel_size, self.kernel_size, h, w)
        out = _involution_cuda(x, weight, stride=self.stride, padding=(self.kernel_size-1)//2)
        return out

# block

In [None]:
import torch.nn as nn
from collections import OrderedDict
import torch


def conv_layer(in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
  """Build an ``nn.Conv2d`` whose padding is ``int((kernel_size - 1) / 2) * dilation``."""
  half_kernel = int((kernel_size - 1) / 2)
  return nn.Conv2d(
      in_channels,
      out_channels,
      kernel_size,
      stride,
      padding=half_kernel * dilation,
      dilation=dilation,
      groups=groups,
      bias=bias)


def norm(norm_type, nc):
  """Return a normalisation layer for ``nc`` channels.

  ``norm_type`` is case-insensitive: 'batch' gives an affine BatchNorm2d,
  'instance' a non-affine InstanceNorm2d. Anything else raises
  NotImplementedError.
  """
  kind = norm_type.lower()
  if kind == 'batch':
    return nn.BatchNorm2d(nc, affine=True)
  if kind == 'instance':
    return nn.InstanceNorm2d(nc, affine=False)
  raise NotImplementedError('normalization layer [{:s}] is not found'.format(kind))


def pad(pad_type, padding):
  """Return an explicit padding layer, or None when ``padding`` is 0.

  ``pad_type`` is case-insensitive: 'reflect' gives ReflectionPad2d,
  'replicate' gives ReplicationPad2d. Other values raise NotImplementedError.
  """
  kind = pad_type.lower()
  if padding == 0:
    return None
  if kind == 'reflect':
    return nn.ReflectionPad2d(padding)
  if kind == 'replicate':
    return nn.ReplicationPad2d(padding)
  raise NotImplementedError('padding layer [{:s}] is not implemented'.format(kind))


def get_valid_padding(kernel_size, dilation):
  """Return half of the dilated kernel extent minus one, i.e. ``(effective - 1) // 2``."""
  # Extent of the kernel once the gaps introduced by dilation are counted.
  effective_size = kernel_size + (kernel_size - 1) * (dilation - 1)
  return (effective_size - 1) // 2


def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
               pad_type='zero', norm_type=None, act_type='relu'):
  """Assemble [explicit padding] -> conv -> [norm] -> [activation] via ``sequential``.

  With ``pad_type == 'zero'`` the padding is handled by the convolution
  itself; any other non-empty ``pad_type`` adds an explicit padding layer and
  the convolution uses padding 0. ``norm_type`` / ``act_type`` may be falsy
  to omit the respective layer.
  """
  full_padding = get_valid_padding(kernel_size, dilation)

  explicit_pad = None
  if pad_type and pad_type != 'zero':
    explicit_pad = pad(pad_type, full_padding)
  conv_padding = full_padding if pad_type == 'zero' else 0

  conv = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=conv_padding,
                   dilation=dilation, bias=bias, groups=groups)
  act = activation(act_type) if act_type else None
  normalization = norm(norm_type, out_nc) if norm_type else None
  return sequential(explicit_pad, conv, normalization, act)


def activation(act_type, inplace=True, neg_slope=0.05, n_prelu=1):
  """Return the activation layer named by ``act_type`` (case-insensitive).

  'relu' -> ReLU(inplace); 'lrelu' -> LeakyReLU(neg_slope, inplace);
  'prelu' -> PReLU(n_prelu parameters initialised to neg_slope).
  Other names raise NotImplementedError.
  """
  kind = act_type.lower()
  if kind == 'relu':
    return nn.ReLU(inplace)
  if kind == 'lrelu':
    return nn.LeakyReLU(neg_slope, inplace)
  if kind == 'prelu':
    return nn.PReLU(num_parameters=n_prelu, init=neg_slope)
  raise NotImplementedError('activation layer [{:s}] is not found'.format(kind))


class ShortcutBlock(nn.Module):
  """Residual wrapper: returns ``x + submodule(x)``."""

  def __init__(self, submodule):
    super().__init__()
    self.sub = submodule

  def forward(self, x):
    residual = self.sub(x)
    return x + residual



def mean_channels(F):
  """Per-channel spatial mean of a 4-D tensor; the result keeps singleton axes 2 and 3."""
  assert(F.dim() == 4)
  pixel_count = F.size(2) * F.size(3)
  summed_over_width = F.sum(3, keepdim=True)
  total = summed_over_width.sum(2, keepdim=True)
  return total / pixel_count

def stdv_channels(F):
  """Per-channel spatial standard deviation (population form) of a 4-D tensor.

  The result keeps singleton axes 2 and 3.
  """
  assert(F.dim() == 4)
  centered = F - mean_channels(F)
  pixel_count = F.size(2) * F.size(3)
  variance = centered.pow(2).sum(3, keepdim=True).sum(2, keepdim=True) / pixel_count
  return variance.pow(0.5)

def sequential(*args):
  """Flatten the given modules into one ``nn.Sequential``.

  A single argument is returned as is (an OrderedDict raises
  NotImplementedError). Otherwise ``nn.Sequential`` arguments are expanded
  into their children, other ``nn.Module`` arguments are kept, and anything
  else (e.g. None) is dropped.
  """
  if len(args) == 1:
    only = args[0]
    if isinstance(only, OrderedDict):
      raise NotImplementedError('sequential does not support OrderedDict input.')
    return only

  flattened = []
  for item in args:
    if isinstance(item, nn.Sequential):
      flattened.extend(item.children())
    elif isinstance(item, nn.Module):
      flattened.append(item)
  return nn.Sequential(*flattened)



# contrast-aware channel attention module
class CCALayer(nn.Module):
  """Channel attention driven by per-channel contrast (std) plus mean.

  The per-channel descriptor ``stdv_channels(x) + avg_pool(x)`` goes through a
  1x1 conv bottleneck (``channel // reduction`` channels) ending in a sigmoid;
  the input is then scaled channel-wise by the result.
  """

  def __init__(self, channel, reduction=16):
    super().__init__()
    squeezed = channel // reduction

    self.contrast = stdv_channels
    self.avg_pool = nn.AdaptiveAvgPool2d(1)
    self.conv_du = nn.Sequential(
        nn.Conv2d(channel, squeezed, 1, padding=0, bias=True),
        nn.ReLU(inplace=True),
        nn.Conv2d(squeezed, channel, 1, padding=0, bias=True),
        nn.Sigmoid(),
    )

  def forward(self, x):
    descriptor = self.contrast(x) + self.avg_pool(x)
    attention = self.conv_du(descriptor)
    return x * attention


class IMDModule1(nn.Module):
  """Information multi-distillation block built from plain convolutions only.

  After each of the first three 3x3 convs the features are split into a
  distilled part (kept) and a remaining part (fed to the next conv). The three
  distilled parts and the output of ``c4`` are concatenated, passed through
  ``CCALayer`` and a 1x1 conv, and added to the block input.
  """

  def __init__(self, in_channels, distillation_rate=0.25):
    super().__init__()
    self.distilled_channels = int(in_channels * distillation_rate)
    self.remaining_channels = int(in_channels - self.distilled_channels)
    self.c1 = conv_layer(in_channels, in_channels, 3)
    self.c2 = conv_layer(self.remaining_channels, in_channels, 3)
    self.c3 = conv_layer(self.remaining_channels, in_channels, 3)
    self.c4 = conv_layer(self.remaining_channels, self.distilled_channels, 3)
    self.act = activation('lrelu', neg_slope=0.05)
    self.c5 = conv_layer(in_channels, in_channels, 1)
    self.cca = CCALayer(self.distilled_channels * 4)

  def forward(self, input):
    split_sizes = (self.distilled_channels, self.remaining_channels)

    kept_parts = []
    features = input
    for conv in (self.c1, self.c2, self.c3):
      activated = self.act(conv(features))
      kept, features = torch.split(activated, split_sizes, dim=1)
      kept_parts.append(kept)
    kept_parts.append(self.c4(features))

    merged = torch.cat(kept_parts, dim=1)
    return self.c5(self.cca(merged)) + input


class IMDModule2(nn.Module):
  """Distillation block whose third stage is an involution.

  ``c1`` and ``c2`` are 3x3 convs. ``c3`` is an involution that receives the
  full (unsplit) output of the second stage. The distilled parts of stages
  1-3 and the output of ``c4`` are concatenated, passed through ``CCALayer``
  and a 1x1 conv, and added to the block input.
  """

  def __init__(self, in_channels, distillation_rate=0.25):
    super().__init__()
    self.distilled_channels = int(in_channels * distillation_rate)
    self.remaining_channels = int(in_channels - self.distilled_channels)
    self.c1 = conv_layer(in_channels, in_channels, 3)
    self.c2 = conv_layer(self.remaining_channels, in_channels, 3)
    self.c3 = involution(in_channels, 3, 1)  # channels, kernel_size, stride
    self.c4 = conv_layer(self.remaining_channels, self.distilled_channels, 3)
    self.act = activation('lrelu', neg_slope=0.05)
    self.c5 = conv_layer(in_channels, in_channels, 1)
    self.cca = CCALayer(self.distilled_channels * 4)

  def forward(self, input):
    split_sizes = (self.distilled_channels, self.remaining_channels)

    stage1 = self.act(self.c1(input))
    kept1, rest1 = torch.split(stage1, split_sizes, dim=1)

    stage2 = self.act(self.c2(rest1))
    kept2, _ = torch.split(stage2, split_sizes, dim=1)

    # The involution consumes all channels of stage 2, not only the remainder.
    stage3 = self.act(self.c3(stage2))
    kept3, rest3 = torch.split(stage3, split_sizes, dim=1)

    tail = self.c4(rest3)
    merged = torch.cat([kept1, kept2, kept3, tail], dim=1)
    return self.c5(self.cca(merged)) + input


class IMDModule3(nn.Module):
  """Distillation block whose first and third stages are involutions.

  ``c1`` (involution) runs on the block input, ``c2`` (3x3 conv) on the
  remaining part of stage 1, and ``c3`` (involution) on the full output of
  stage 2. The distilled parts of stages 1-3 and the output of ``c4`` are
  concatenated, passed through ``CCALayer`` and a 1x1 conv, and added to the
  block input.
  """

  def __init__(self, in_channels, distillation_rate=0.25):
    super().__init__()
    self.distilled_channels = int(in_channels * distillation_rate)
    self.remaining_channels = int(in_channels - self.distilled_channels)
    self.c1 = involution(in_channels, 3, 1)  # channels, kernel_size, stride
    self.c2 = conv_layer(self.remaining_channels, in_channels, 3)
    self.c3 = involution(in_channels, 3, 1)
    self.c4 = conv_layer(self.remaining_channels, self.distilled_channels, 3)
    self.act = activation('lrelu', neg_slope=0.05)
    self.c5 = conv_layer(in_channels, in_channels, 1)
    self.cca = CCALayer(self.distilled_channels * 4)

  def forward(self, input):
    split_sizes = (self.distilled_channels, self.remaining_channels)

    stage1 = self.act(self.c1(input))
    kept1, rest1 = torch.split(stage1, split_sizes, dim=1)

    stage2 = self.act(self.c2(rest1))
    kept2, _ = torch.split(stage2, split_sizes, dim=1)

    # The involution consumes all channels of stage 2, not only the remainder.
    stage3 = self.act(self.c3(stage2))
    kept3, rest3 = torch.split(stage3, split_sizes, dim=1)

    tail = self.c4(rest3)
    merged = torch.cat([kept1, kept2, kept3, tail], dim=1)
    return self.c5(self.cca(merged)) + input


class IMDModule4(nn.Module):
  """Distillation block whose first three stages are all involutions.

  Each involution receives the full output of the previous stage. The
  distilled parts of stages 1-3 and the output of ``c4`` (a 3x3 conv on the
  remaining part of stage 3) are concatenated, passed through ``CCALayer``
  and a 1x1 conv, and added to the block input.
  """

  def __init__(self, in_channels, distillation_rate=0.25):
    super().__init__()
    self.distilled_channels = int(in_channels * distillation_rate)
    self.remaining_channels = int(in_channels - self.distilled_channels)
    self.c1 = involution(in_channels, 3, 1)
    self.c2 = involution(in_channels, 3, 1)  # channels, kernel_size, stride
    self.c3 = involution(in_channels, 3, 1)
    self.c4 = conv_layer(self.remaining_channels, self.distilled_channels, 3)
    self.act = activation('lrelu', neg_slope=0.05)
    self.c5 = conv_layer(in_channels, in_channels, 1)
    self.cca = CCALayer(self.distilled_channels * 4)

  def forward(self, input):
    split_sizes = (self.distilled_channels, self.remaining_channels)

    stage1 = self.act(self.c1(input))
    kept1, _ = torch.split(stage1, split_sizes, dim=1)

    stage2 = self.act(self.c2(stage1))
    kept2, _ = torch.split(stage2, split_sizes, dim=1)

    stage3 = self.act(self.c3(stage2))
    kept3, rest3 = torch.split(stage3, split_sizes, dim=1)

    tail = self.c4(rest3)
    merged = torch.cat([kept1, kept2, kept3, tail], dim=1)
    return self.c5(self.cca(merged)) + input

def pixelshuffle_block(in_channels, out_channels, upscale_factor=2, kernel_size=3, stride=1):
  """Upsampler: a conv producing ``out_channels * upscale_factor**2`` channels, then PixelShuffle."""
  expanded_channels = out_channels * (upscale_factor ** 2)
  expand = conv_layer(in_channels, expanded_channels, kernel_size, stride)
  shuffle = nn.PixelShuffle(upscale_factor)
  return sequential(expand, shuffle)


#  model

In [None]:
import torch.nn as nn
import torch



class IMDN(nn.Module):
  """Super-resolution network built from six distillation blocks.

  Pipeline: 3x3 feature conv -> six blocks in series (two IMDModule1, two
  IMDModule2, one IMDModule3, one IMDModule4) -> 1x1 fusion of all six block
  outputs -> 3x3 conv plus a skip from the first features -> pixel-shuffle
  upsampler.

  Args:
    in_nc: input image channels.
    nf: feature channels used throughout the body.
    num_modules: number of block outputs fused by the 1x1 conv. The block
      layout is fixed, so this must be 6.
    out_nc: output image channels.
    upscale: upscaling factor of the pixel-shuffle upsampler.

  Raises:
    ValueError: if ``num_modules`` is not 6.
  """

  def __init__(self, in_nc=3, nf=64, num_modules=6, out_nc=3, upscale=4):
    super(IMDN, self).__init__()
    if num_modules != 6:
      # forward() always concatenates exactly six block outputs; any other
      # value would only surface later as a channel mismatch in self.c.
      raise ValueError('IMDN is built from exactly 6 blocks, got num_modules={}'.format(num_modules))

    self.fea_conv = conv_layer(in_nc, nf, kernel_size=3)

    # IMDBs
    self.IMDB1 = IMDModule1(in_channels=nf)
    self.IMDB2 = IMDModule1(in_channels=nf)
    self.IMDB3 = IMDModule2(in_channels=nf)
    self.IMDB4 = IMDModule2(in_channels=nf)
    self.IMDB5 = IMDModule3(in_channels=nf)
    self.IMDB6 = IMDModule4(in_channels=nf)
    # 1x1 conv fusing the concatenated outputs of all blocks back to nf channels.
    self.c = conv_block(nf * num_modules, nf, kernel_size=1, act_type='lrelu')

    self.LR_conv = conv_layer(nf, nf, kernel_size=3)

    upsample_block = pixelshuffle_block
    self.upsampler = upsample_block(nf, out_nc, upscale_factor=upscale)

  def forward(self, input):
    """Map a low-resolution batch to its upscaled output."""
    out_fea = self.fea_conv(input)
    out_B1 = self.IMDB1(out_fea)
    out_B2 = self.IMDB2(out_B1)
    out_B3 = self.IMDB3(out_B2)
    out_B4 = self.IMDB4(out_B3)
    out_B5 = self.IMDB5(out_B4)
    out_B6 = self.IMDB6(out_B5)

    out_B = self.c(torch.cat([out_B1, out_B2, out_B3, out_B4, out_B5, out_B6], dim=1))
    # Long skip connection from the shallow features.
    out_lr = self.LR_conv(out_B) + out_fea
    output = self.upsampler(out_lr)
    return output


# loss

In [None]:
import torch
import torchvision

class VGGPerceptualLoss(torch.nn.Module):
    """Perceptual loss computed on frozen, pretrained VGG16 activations.

    The ``features`` part of torchvision's VGG16 is cut into four consecutive
    stages (layers 0-3, 4-8, 9-15 and 16-22). ``forward`` pushes both images
    through the stages and accumulates L1 distances between their
    activations and, optionally, between their Gram matrices.

    Args:
        resize: when true, both images are bilinearly resized to 224x224
            before being fed to the network.
    """

    # (start, stop) indices into ``vgg16().features`` for the four stages.
    _STAGE_SLICES = ((0, 4), (4, 9), (9, 16), (16, 23))

    def __init__(self, resize=True):
        super().__init__()
        # Load the pretrained network once and slice it, instead of building
        # (and loading the weights of) a full VGG16 for every stage.
        features = torchvision.models.vgg16(pretrained=True).features
        blocks = [features[start:stop].eval() for start, stop in self._STAGE_SLICES]
        # The loss network is a fixed feature extractor: never train it.
        for bl in blocks:
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        self.resize = resize
        # Per-channel normalisation constants (the usual ImageNet statistics),
        # stored as buffers so they follow the module across devices.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, input, target, feature_layers=(0, 1, 2, 3), style_layers=()):
        """Return the perceptual loss between ``input`` and ``target``.

        Args:
            input: 4-D image batch with channels on dimension 1.
            target: reference batch, treated exactly like ``input``.
            feature_layers: indices of the stages whose activations contribute
                an L1 term.
            style_layers: indices of the stages whose Gram matrices contribute
                an L1 term.

        Returns:
            The sum of all selected terms; the float ``0.0`` if no stage is
            selected.
        """
        if input.shape[1] != 3:
            # Replicate the channel dimension three times
            # (presumably for single-channel inputs).
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input - self.mean) / self.std
        target = (target - self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        loss = 0.0
        x = input
        y = target
        for i, block in enumerate(self.blocks):
            x = block(x)
            y = block(y)
            if i in feature_layers:
                loss += torch.nn.functional.l1_loss(x, y)
            if i in style_layers:
                # Gram matrix: channel-by-channel correlations of the
                # spatially flattened activations.
                act_x = x.reshape(x.shape[0], x.shape[1], -1)
                act_y = y.reshape(y.shape[0], y.shape[1], -1)
                gram_x = act_x @ act_x.permute(0, 2, 1)
                gram_y = act_y @ act_y.permute(0, 2, 1)
                loss += torch.nn.functional.l1_loss(gram_x, gram_y)
        return loss

# train model

In [None]:
!pip install lpips

In [None]:
import argparse, os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import lpips
import skimage.color as sc
import random
from collections import OrderedDict


# Training settings
parser = argparse.ArgumentParser(description="IMDN")
parser.add_argument("--batch_size", type=int, default=16,
                    help="training batch size")#16
parser.add_argument("--testBatchSize", type=int, default=1,
                    help="testing batch size")
parser.add_argument("-nEpochs", type=int, default=100,
                    help="number of epochs to train")#1000
parser.add_argument("--lr", type=float, default=2e-4,
                    help="Learning Rate. Default=2e-4")
parser.add_argument("--step_size", type=int, default=200,
                    help="learning rate decay per N epochs")#200
# The decay factor is fractional, so it must be parsed as a float; with
# type=int any value given on the command line (e.g. "0.5") was rejected.
parser.add_argument("--gamma", type=float, default=0.5,
                    help="learning rate decay factor for step decay")
parser.add_argument("--cuda", action="store_true", default=True,
                    help="use cuda")
parser.add_argument("--resume", default="", type=str,
                    help="path to checkpoint")
parser.add_argument("--start-epoch", default=1, type=int,
                    help="manual epoch number")
parser.add_argument("--threads", type=int, default=8,
                    help="number of threads for data loading")#8
parser.add_argument("--root", type=str, default="./drive/MyDrive/dataset",
                    help='dataset directory')#training_data/
parser.add_argument("--n_train", type=int, default=800,
                    help="number of training set")#800
parser.add_argument("--n_val", type=int, default=1,
                    help="number of validation set")
parser.add_argument("--test_every", type=int, default=1000)
parser.add_argument("--scale", type=int, default=2,
                    help="super-resolution scale")#2,3,4
parser.add_argument("--patch_size", type=int, default=192,
                    help="output patch size")#192 ,96 ,144
parser.add_argument("--rgb_range", type=int, default=1,
                    help="maxium value of RGB")
parser.add_argument("--n_colors", type=int, default=3,
                    help="number of color channels to use")
parser.add_argument("--pretrained", default="", type=str,
                    help="path to pretrained models")
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--isY", action="store_true", default=True)
parser.add_argument("--ext", type=str, default='.npy')
parser.add_argument("--phase", type=str, default='train')

# Parse an empty argument list: inside a notebook sys.argv belongs to the
# kernel, so only the defaults above are used.
args = parser.parse_args(args = [])

# Validation data: HR images and the matching LR images for the chosen scale.
testset = DatasetFromFolderVal("/content/drive/MyDrive/Test_Datasets/Set5/Set5/X2/",
                               "/content/drive/MyDrive/Test_Datasets/Set5/Set5_LR/X{}/".format(args.scale),
                              args.scale)

testing_data_loader = DataLoader(dataset=testset, num_workers=args.threads, batch_size=args.testBatchSize,shuffle=False)
len_testing_data= len(testing_data_loader)

print(args)
torch.backends.cudnn.benchmark = True
# random seed
seed = args.seed
if seed is None:
  seed = random.randint(1, 10000)

print("Ramdom Seed: ", seed)
random.seed(seed)
torch.manual_seed(seed)

# Only use CUDA when it is actually available; otherwise fall back to the CPU
# instead of failing on the first ``.to(device)``.
cuda = args.cuda and torch.cuda.is_available()
if args.cuda and not cuda:
  print("===> CUDA requested but not available, falling back to CPU")
device = torch.device('cuda' if cuda else 'cpu')

print("===> Loading datasets")

trainset = div2k(args)

training_data_loader = DataLoader(dataset=trainset, num_workers=args.threads, batch_size=args.batch_size, shuffle=True, pin_memory=True, drop_last=True)


print("===> Building models")
args.is_train = True

model = IMDN(upscale=args.scale)
l1_criterion = nn.L1Loss()
vgg_criterion = VGGPerceptualLoss()

print("===> Setting GPU")
if cuda:
  model = model.to(device)
  l1_criterion = l1_criterion.to(device)
  vgg_criterion = vgg_criterion.to(device)

if args.pretrained:
  if os.path.isfile(args.pretrained):
    print("===> loading models '{}'".format(args.pretrained))
    # map_location lets a checkpoint saved on a GPU be loaded on whatever
    # device this run uses.
    checkpoint = torch.load(args.pretrained, map_location=device)
    new_state_dcit = OrderedDict()
    for k, v in checkpoint.items():
      # Strip only a leading "module." (the prefix nn.DataParallel puts in
      # front of every key). The previous substring test would have chopped
      # seven characters off any key that merely contained "module".
      if k.startswith('module.'):
        name = k[len('module.'):]
      else:
        name = k
      new_state_dcit[name] = v
    model_dict = model.state_dict()
    pretrained_dict = {k: v for k, v in new_state_dcit.items() if k in model_dict}

    # Report model parameters the checkpoint does not provide; the strict
    # load below will then raise for them.
    for k, v in model_dict.items():
      if k not in pretrained_dict:
        print(k)
    model.load_state_dict(pretrained_dict, strict=True)

  else:
    print("===> no models found at '{}'".format(args.pretrained))

print("===> Setting Optimizer")

optimizer = optim.Adam(model.parameters(), lr=args.lr)


def train(epoch):
    """Run one training epoch over ``training_data_loader``.

    Sets the learning rate for ``epoch`` through ``adjust_learning_rate``,
    then optimises the module-level ``model`` on the sum of the L1 loss and
    0.008 times the VGG perceptual loss. Progress is printed every 100
    iterations.

    Args:
        epoch: current epoch number, used for the learning-rate schedule and
            for logging.
    """
    model.train()
    adjust_learning_rate(optimizer, epoch, args.step_size, args.lr, args.gamma)
    print('epoch =', epoch, 'lr = ', optimizer.param_groups[0]['lr'])

    for step, (lr_batch, hr_batch) in enumerate(training_data_loader, start=1):
        if args.cuda:
            # The original author notes that both batches range over [0, 1].
            lr_batch, hr_batch = lr_batch.to(device), hr_batch.to(device)

        optimizer.zero_grad()
        sr_batch = model(lr_batch)

        # Pixel-wise term plus a lightly weighted perceptual term.
        pixel_term = l1_criterion(sr_batch, hr_batch)
        perceptual_term = vgg_criterion(sr_batch, hr_batch)
        total_loss = pixel_term + 0.008*perceptual_term

        total_loss.backward()
        optimizer.step()

        if step % 100 != 0:
            continue
        print("===> Epoch[{}]({}/{}): Loss_l1: {:.5f}".format(epoch, step, len(training_data_loader),
                                                              total_loss.item()))




# LPIPS perceptual metrics used during validation, one per backbone
# (VGG first, then AlexNet).
valid_vgg, valid_alex = (lpips.LPIPS(net=backbone) for backbone in ('vgg', 'alex'))


def valid():
    """Evaluate the module-level ``model`` on ``testing_data_loader``.

    For every batch only the first image is scored. The super-resolved and
    ground-truth images are converted with ``tensor2np``, shaved by
    ``args.scale`` pixels and, when ``args.isY`` is set, reduced to the first
    channel of ``rgb2ycbcr``. PSNR, SSIM and both LPIPS variants are averaged
    over ``len_testing_data`` and printed; nothing is returned.
    """
    model.eval()

    avg_psnr, avg_ssim, avg_lpips_vgg, avg_lpips_alex = 0, 0, 0, 0
    for batch in testing_data_loader:
        lr_tensor, hr_tensor = batch[0], batch[1]
        if args.cuda:
            lr_tensor = lr_tensor.to(device)
            hr_tensor = hr_tensor.to(device)

        with torch.no_grad():
            pre = model(lr_tensor)

        sr_img = tensor2np(pre.detach()[0])
        gt_img = tensor2np(hr_tensor.detach()[0])
        crop_size = args.scale
        cropped_sr_img = shave(sr_img, crop_size)
        cropped_gt_img = shave(gt_img, crop_size)
        if args.isY is True:
            im_label = quantize(sc.rgb2ycbcr(cropped_gt_img)[:, :, 0])
            im_pre = quantize(sc.rgb2ycbcr(cropped_sr_img)[:, :, 0])
        else:
            im_label = cropped_gt_img
            im_pre = cropped_sr_img
        avg_psnr += compute_psnr(im_pre, im_label)
        avg_ssim += compute_ssim(im_pre, im_label)

        # NOTE(review): the lpips package documents its inputs as
        # (N, 3, H, W) tensors scaled to [-1, 1]. The arrays passed here come
        # straight from quantize() / the cropped images - confirm their range
        # and layout before comparing these numbers with published LPIPS.
        pre_tensor = torch.tensor(im_pre).float()
        label_tensor = torch.tensor(im_label).float()
        # Metrics only: no autograd graph is needed for the LPIPS networks.
        with torch.no_grad():
            avg_lpips_alex += valid_alex(pre_tensor, label_tensor)
            avg_lpips_vgg += valid_vgg(pre_tensor, label_tensor)

    # "\t" replaces the literal "/t" that used to be printed.
    print("===> Valid. psnr: {:.4f}, ssim: {:.4f} ".format(avg_psnr / len_testing_data,
                                                           avg_ssim / len_testing_data),
          "lpips_vgg: ", avg_lpips_vgg.item() / len_testing_data,
          "\t lpips_alex: ", avg_lpips_alex.item() / len_testing_data)



def save_checkpoint(epoch):
  """Save the model weights for ``epoch`` under ``checkpoint_x<scale>/``.

  The folder name is derived from ``args.scale`` and is created when it does
  not exist yet. Only ``model.state_dict()`` is written; the destination
  path is printed afterwards.
  """
  checkpoint_dir = "checkpoint_x{}/".format(args.scale)
  if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

  checkpoint_path = checkpoint_dir + "epoch_{}.pth".format(epoch)
  weights = model.state_dict()
  torch.save(weights, checkpoint_path)
  print("===> Checkpoint saved to {}".format(checkpoint_path))

def print_network(net):
  """Print ``net`` followed by the total number of elements in its parameters."""
  total_elements = sum(tensor.numel() for tensor in net.parameters())
  print(net)
  print('Total number of parameters: %d' % total_elements)


print("===> Training")
print_network(model)

# Epochs run from args.start_epoch up to and including args.nEpochs.
epoch_schedule = range(args.start_epoch, args.nEpochs + 1)
for epoch in epoch_schedule:
  # Optimise for one epoch, evaluate on the test loader, then snapshot
  # the weights.
  train(epoch)
  valid()
  save_checkpoint(epoch)

# test

In [None]:
import argparse
import torch
import os
import numpy as np
from torch.utils.data import DataLoader
import skimage.color as sc
import cv2
import lpips

# Testing settings

parser = argparse.ArgumentParser(description='IMDN')
parser.add_argument("--test_hr_folder", type=str, default='/content/drive/MyDrive/Test_Datasets/Set5/Set5/X2/',
                    help='the folder of the target images')
parser.add_argument("--test_lr_folder", type=str, default='/content/drive/MyDrive/Test_Datasets/Set5/Set5_LR/X2/',
                    help='the folder of the input images')
parser.add_argument("--output_folder", type=str, default='/content/drive/MyDrive/IMDN/checkpoint_x2/deep invo/ressult')
parser.add_argument("--checkpoint", type=str, default='/content/drive/MyDrive/IMDN/checkpoint_x2/deep invo/epoch_100.pth',
                    help='checkpoint folder to use')
parser.add_argument('--cuda', action='store_true', default=True,
                    help='use cuda')
parser.add_argument("--isY", action="store_true", default=True)
# Default is 2 to agree with the X2 folders, the checkpoint_x2 weights and the
# x2 model below (it used to be 4, which mislabelled the outputs as "x4" and
# shaved a 4-pixel border for an x2 evaluation).
parser.add_argument("--upscale_factor", type=int, default=2,
                    help='upscaling factor')
parser.add_argument("--is_y", action='store_true', default=True,
                    help='evaluate on y channel, if False evaluate on RGB channels')
# Loader settings live on this cell's own parser, so the cell no longer
# depends on the ``args`` namespace left behind by the training cell.
parser.add_argument("--threads", type=int, default=8,
                    help="number of threads for data loading")
parser.add_argument("--testBatchSize", type=int, default=1,
                    help="testing batch size")
opt = parser.parse_args(args=[])

print(opt)


# Build the evaluation set from this cell's own options.
testset = DatasetFromFolderVal(opt.test_hr_folder,
                               opt.test_lr_folder,
                               opt.upscale_factor)

testing_data_loader = DataLoader(dataset=testset, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False)
len_testing_data = len(testing_data_loader)

print(len_testing_data)


# Only use CUDA when it is actually available; otherwise run on the CPU.
cuda = opt.cuda and torch.cuda.is_available()
if opt.cuda and not cuda:
  print("===> CUDA requested but not available, falling back to CPU")
device = torch.device('cuda' if cuda else 'cpu')

filepath = opt.test_hr_folder
# The dataset name is taken from the third-from-last "/"-separated component,
# which relies on the folder path ending with a trailing slash.
if filepath.split('/')[-3] in ('Set5', 'Set14', 'BSDS100', 'Urban100'):
  ext = '.png'
else:
  ext = '.bmp'
filelist = get_list(filepath, ext=ext)
print(filelist)


# LPIPS perceptual metrics (VGG and AlexNet backbones).
valid_vgg = lpips.LPIPS(net='vgg')
valid_alex = lpips.LPIPS(net='alex')


# Per-image result buffers, one slot per batch of the test loader.
psnr_list = np.zeros(len_testing_data)
ssim_list = np.zeros(len_testing_data)
lpips_list_vgg = np.zeros(len_testing_data)
lpips_list_alex = np.zeros(len_testing_data)
time_list = np.zeros(len_testing_data)

# The model scale must match the checkpoint being loaded (strict=True).
model = IMDN(upscale=opt.upscale_factor).to(device)
model_dict = load_state_dict(opt.checkpoint)
model.load_state_dict(model_dict, strict=True)

def test():
    """Evaluate ``model`` on ``testing_data_loader`` and save its outputs.

    For every batch only the first image is processed: it is super-resolved,
    scored with PSNR, SSIM and both LPIPS variants, and written as a PNG into
    ``opt.output_folder``. The file name is taken from the entry of
    ``filelist`` at the same position (so ``filelist`` is assumed to be in
    loader order) plus an ``x<upscale_factor>`` suffix. The averaged metrics
    are printed; nothing is returned.
    """
    model.eval()

    # Create the output directory once, up front, instead of re-checking it
    # for every image.
    os.makedirs(opt.output_folder, exist_ok=True)

    avg_psnr, avg_ssim, avg_lpips_vgg, avg_lpips_alex = 0, 0, 0, 0
    for i, batch in enumerate(testing_data_loader):
        lr_tensor, hr_tensor = batch[0], batch[1]
        if cuda:
            lr_tensor = lr_tensor.to(device)
            hr_tensor = hr_tensor.to(device)

        with torch.no_grad():
            pre = model(lr_tensor)

        sr_img = tensor2np(pre.detach()[0])
        gt_img = tensor2np(hr_tensor.detach()[0])
        crop_size = opt.upscale_factor
        cropped_sr_img = shave(sr_img, crop_size)
        cropped_gt_img = shave(gt_img, crop_size)
        if opt.isY is True:
            im_label = quantize(sc.rgb2ycbcr(cropped_gt_img)[:, :, 0])
            im_pre = quantize(sc.rgb2ycbcr(cropped_sr_img)[:, :, 0])
        else:
            im_label = cropped_gt_img
            im_pre = cropped_sr_img
        avg_psnr += compute_psnr(im_pre, im_label)
        avg_ssim += compute_ssim(im_pre, im_label)

        # NOTE(review): the lpips package documents its inputs as
        # (N, 3, H, W) tensors scaled to [-1, 1]. The arrays passed here come
        # straight from quantize() / the cropped images - confirm their range
        # and layout before comparing these numbers with published LPIPS.
        pre_tensor = torch.tensor(im_pre).float()
        label_tensor = torch.tensor(im_label).float()
        # Metrics only: no autograd graph is needed for the LPIPS networks.
        with torch.no_grad():
            avg_lpips_alex += valid_alex(pre_tensor, label_tensor)
            avg_lpips_vgg += valid_vgg(pre_tensor, label_tensor)

        image_name = filelist[i].split('/')[-1].split('.')[0]
        output_path = os.path.join(opt.output_folder, image_name + 'x' + str(opt.upscale_factor) + '.png')

        # Reverse the channel order before writing (presumably RGB -> BGR,
        # the order OpenCV expects). imwrite reports failure through its
        # return value rather than an exception, so surface it.
        if not cv2.imwrite(output_path, sr_img[:, :, [2, 1, 0]]):
            print("===> Failed to write {}".format(output_path))

    # "\t" replaces the literal "/t" that used to be printed.
    print("===> Valid. psnr: {:.4f}, ssim: {:.4f} ".format(avg_psnr / len_testing_data,
                                                           avg_ssim / len_testing_data),
          "lpips_vgg: ", avg_lpips_vgg.item() / len_testing_data,
          "\t lpips_alex: ", avg_lpips_alex.item() / len_testing_data)



# Run the evaluation over the whole test loader and print the averaged metrics.
test()