Prep
==

In [None]:
# Mount Google Drive at /content/drive/ (Colab-only helper); later cells read and write paths under it.
from google.colab import drive
drive.mount('/content/drive/')


Mounted at /content/drive/


In [None]:
# Install third-party dependencies for this notebook (versions are not pinned).
!pip install tensorboard einops scikit-image opencv-python pytorch_msssim warmup-scheduler

Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
Collecting pytorch_msssim
  Downloading pytorch_msssim-1.0.0-py3-none-any.whl (7.7 kB)
Collecting warmup-scheduler
  Downloading warmup_scheduler-0.3.tar.gz (2.1 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: warmup-scheduler
  Building wheel for warmup-scheduler (setup.py) ... [?25l[?25hdone
  Created wheel for warmup-scheduler: filename=warmup_scheduler-0.3-py3-none-any.whl size=2968 sha256=c2e88f5dc27e4cd6b2804768cc4c4981eb898907c4b846508544cc04c447bece
  Stored in directory: /root/.cache/pip/wheels/59/01/9e/d1820991c32916e9808c940f572b462f3e46427f3e76c4d852
Successfully built warmup-scheduler
Installing collected packages: warmup-scheduler, einops, pytorch_msssim
Successfully installed einops-0.7.0 pytorch_msssim-1.0.0 warmup-scheduler-0.3


In [None]:
# Dataset root on the mounted Drive; the training Args uses it as data_dir.
path = "/content/drive/MyDrive/reside-indoor"

Data Augmentation
==

In [None]:
import random
import torchvision.transforms as transforms
import torchvision.transforms.functional as FUNCTIONAL


class PairRandomCrop(transforms.RandomCrop):
    """Random crop that applies the same crop window to an image and its label.

    Padding options inherited from ``transforms.RandomCrop`` (``padding``,
    ``pad_if_needed``, ``fill``, ``padding_mode``) are applied to both inputs.
    Sizes are read through the ``.size`` attribute as ``(width, height)``.
    """

    def __call__(self, image, label):
        """Pad (if configured) and crop both inputs with one shared window.

        Args:
            image: input image.
            label: target image; expected to have the same size as ``image``
                because a single crop window is used for both.

        Returns:
            Tuple ``(cropped_image, cropped_label)``.
        """
        if self.padding is not None:
            image = FUNCTIONAL.pad(image, self.padding, self.fill, self.padding_mode)
            label = FUNCTIONAL.pad(label, self.padding, self.fill, self.padding_mode)

        # pad the width if needed
        if self.pad_if_needed and image.size[0] < self.size[1]:
            image = FUNCTIONAL.pad(image, (self.size[1] - image.size[0], 0), self.fill, self.padding_mode)
            label = FUNCTIONAL.pad(label, (self.size[1] - label.size[0], 0), self.fill, self.padding_mode)
        # pad the height if needed
        if self.pad_if_needed and image.size[1] < self.size[0]:
            image = FUNCTIONAL.pad(image, (0, self.size[0] - image.size[1]), self.fill, self.padding_mode)
            # Use the label's own (still unpadded) height: the image has already
            # grown on the line above, so its height would give a wrong amount.
            label = FUNCTIONAL.pad(label, (0, self.size[0] - label.size[1]), self.fill, self.padding_mode)

        # One window, sampled from the image, is reused for the label.
        i, j, h, w = self.get_params(image, self.size)

        return FUNCTIONAL.crop(image, i, j, h, w), FUNCTIONAL.crop(label, i, j, h, w)


class PairCompose(transforms.Compose):
    """Chain transforms that each take and return an (image, label) pair."""

    def __call__(self, image, label):
        """Feed the pair through every transform in order and return the result."""
        pair = image, label
        for step in self.transforms:
            first, second = step(*pair)
            pair = first, second
        return pair


class PairRandomHorizontalFilp(transforms.RandomHorizontalFlip):
    """Horizontal flip applied to an image and its label together."""

    def __call__(self, img, label):
        """
        Args:
            img: Image to be flipped.
            label: Paired target, flipped whenever ``img`` is.

        Returns:
            Tuple ``(img, label)``: both flipped with probability ``self.p``
            (one draw from ``random.random()``), otherwise both unchanged.
        """
        flip_both = random.random() < self.p
        if not flip_both:
            return img, label
        return FUNCTIONAL.hflip(img), FUNCTIONAL.hflip(label)


class PairToTensor(transforms.ToTensor):
    """Tensor conversion applied to an image and its label together."""

    def __call__(self, pic, label):
        """
        Args:
            pic: Image to be converted to tensor.
            label: Paired target, converted the same way.

        Returns:
            Tuple ``(pic_tensor, label_tensor)``.
        """
        convert = FUNCTIONAL.to_tensor
        return convert(pic), convert(label)

Data Loader
==

In [None]:
import os
import torch
import numpy as np
from PIL import Image as Image
from torchvision.transforms import functional as Functional
from torch.utils.data import Dataset, DataLoader


def train_dataloader(path, batch_size=1, num_workers=0, use_transform=True):
    """Build the shuffled training DataLoader over ``<path>/train``.

    The default ``batch_size`` is 1; per the original inline note, the upstream
    default was 64.

    Args:
        path: dataset root; the ``train`` sub-directory is used.
        batch_size: samples per batch.
        num_workers: DataLoader worker processes.
        use_transform: when true, apply a 256-pixel paired random crop, a paired
            random horizontal flip and tensor conversion; when false the dataset
            gets ``transform=None``.

    Returns:
        A ``DataLoader`` with ``shuffle=True`` and ``pin_memory=True``.
    """
    if use_transform:
        augment = PairCompose([PairRandomCrop(256), PairRandomHorizontalFilp(), PairToTensor()])
    else:
        augment = None

    dataset = DeblurDataset(os.path.join(path, 'train'), transform=augment)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )


def test_dataloader(path, batch_size=1, num_workers=0):
    """Build the unshuffled test DataLoader over ``<path>/test``.

    The dataset is created with ``is_test=True``, so each item also carries the
    file name.

    Returns:
        A ``DataLoader`` with ``shuffle=False`` and ``pin_memory=True``.
    """
    dataset = DeblurDataset(os.path.join(path, 'test'), is_test=True)
    loader_options = dict(
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
    return DataLoader(dataset, **loader_options)


def valid_dataloader(path, batch_size=1, num_workers=0):
    """Build the unshuffled validation DataLoader.

    Validation reads the same ``<path>/test`` directory as the test loader, but
    items are plain ``(image, label)`` pairs (no file name).
    """
    dataset = DeblurDataset(os.path.join(path, 'test'))
    loader_options = dict(batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return DataLoader(dataset, **loader_options)


class DeblurDataset(Dataset):
    """Paired dataset of hazy inputs and ground-truth targets.

    Expects ``image_dir`` to contain a ``hazy`` and a ``gt`` sub-directory. The
    target for a hazy file is ``gt/<prefix>.png`` where ``<prefix>`` is the part
    of the hazy file name before the first underscore.

    Args:
        image_dir: directory holding ``hazy/`` and ``gt/``.
        transform: optional callable ``(image, label) -> (image, label)``; when
            absent both images are converted with ``to_tensor``.
        is_test: when true, ``__getitem__`` also returns the hazy file name.

    Raises:
        ValueError: if ``hazy/`` contains a file whose extension is not
            ``png``, ``jpg`` or ``jpeg``.
    """

    def __init__(self, image_dir, transform=None, is_test=False):
        self.image_dir = image_dir
        self.image_list = os.listdir(os.path.join(image_dir, 'hazy/'))
        self._check_image(self.image_list)
        # Sorted so that indices map to files deterministically.
        self.image_list.sort()
        self.transform = transform
        self.is_test = is_test

    def __len__(self):
        """Number of hazy images."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Return ``(image, label)`` or, in test mode, ``(image, label, name)``."""
        image = Image.open(os.path.join(self.image_dir, 'hazy', self.image_list[idx]))
        # Ground truth shares the hazy file's prefix up to the first '_'.
        label = Image.open(os.path.join(self.image_dir, 'gt', self.image_list[idx].split('_')[0]+'.png'))

        if self.transform:
            image, label = self.transform(image, label)
        else:
            image = Functional.to_tensor(image)
            label = Functional.to_tensor(label)
        if self.is_test:
            name = self.image_list[idx]
            return image, label, name
        return image, label

    @staticmethod
    def _check_image(lst):
        """Raise ``ValueError`` naming the first entry with an unsupported extension."""
        allowed = ['png', 'jpg', 'jpeg']
        for x in lst:
            splits = x.split('.')
            if splits[-1] not in allowed:
                raise ValueError(
                    'Unsupported file %r in image list; expected an extension in %s' % (x, allowed))

Utils
==

In [None]:
import time
import numpy as np


class Adder(object):
    """Running accumulator: call it with values, then read their mean."""

    def __init__(self):
        self.count, self.num = 0, float(0)

    def reset(self):
        """Forget everything accumulated so far."""
        self.count, self.num = 0, float(0)

    def __call__(self, num):
        """Record one more value."""
        self.count = self.count + 1
        self.num = self.num + num

    def average(self):
        """Mean of the recorded values (divides by the number of calls)."""
        total, how_many = self.num, self.count
        return total / how_many


class Timer(object):
    """Stopwatch reporting elapsed wall-clock time in a chosen unit.

    ``option`` selects the divisor applied to seconds: ``'s'`` -> 1,
    ``'m'`` -> 60, anything else -> 3600.
    """

    def __init__(self, option='s'):
        self.tm = 0
        self.option = option
        self.devider = 1 if option == 's' else 60 if option == 'm' else 3600

    def tic(self):
        """Start (or restart) the stopwatch."""
        self.tm = time.time()

    def toc(self):
        """Time since the last ``tic`` divided by the unit divisor."""
        elapsed_seconds = time.time() - self.tm
        return elapsed_seconds / self.devider


def check_lr(optimizer):
    """Return the ``'lr'`` entry of the optimizer's last parameter group."""
    for group in optimizer.param_groups:
        lr = group['lr']
    return lr

Eval
==

In [None]:
import os
import torch
from torchvision.transforms import functional as F
import numpy as np
from skimage.metrics import peak_signal_noise_ratio
import time
from pytorch_msssim import ssim
import torch.nn.functional as f

from skimage import img_as_ubyte
import cv2

def _eval(model, args):
    """Evaluate ``model`` on the test split and print per-image and average metrics.

    Loads the ``'model'`` entry of the checkpoint at ``args.test_model``, then
    iterates ``test_dataloader(args.data_dir)`` one image at a time. Each input
    is reflect-padded on the right/bottom to a multiple of 32, passed through the
    model, and the model's output at index 2 is cropped back to the input size
    and clamped to [0, 1] before PSNR and SSIM are computed.

    Args:
        model: network whose forward returns an indexable of outputs.
        args: object providing ``test_model``, ``data_dir``, ``save_image`` and,
            when ``save_image`` is true, ``result_dir``.
    """
    # Imported locally instead of through the notebook-level alias `F`: other
    # cells rebind `F` to torch.nn.functional, which has no `to_pil_image`.
    from torchvision.transforms.functional import to_pil_image

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only runtime.
    state_dict = torch.load(args.test_model, map_location=device)
    model.load_state_dict(state_dict['model'])
    dataloader = test_dataloader(args.data_dir, batch_size=1, num_workers=0)
    torch.cuda.empty_cache()
    adder = Adder()
    model.eval()
    factor = 32
    with torch.no_grad():
        psnr_adder = Adder()
        ssim_adder = Adder()

        for iter_idx, data in enumerate(dataloader):
            input_img, label_img, name = data

            input_img = input_img.to(device)
            # Follow the computed device rather than assuming CUDA is present.
            label_img = label_img.to(device)

            h, w = input_img.shape[2], input_img.shape[3]
            H, W = ((h+factor)//factor)*factor, ((w+factor)//factor*factor)
            padh = H-h if h%factor!=0 else 0
            padw = W-w if w%factor!=0 else 0
            input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect')

            tm = time.time()

            pred = model(input_img)[2]
            pred = pred[:,:,:h,:w]

            elapsed = time.time() - tm
            adder(elapsed)

            pred_clip = torch.clamp(pred, 0, 1)

            pred_numpy = pred_clip.squeeze(0).cpu().numpy()
            label_numpy = label_img.squeeze(0).cpu().numpy()

            psnr_val = (10 * torch.log10(1 / f.mse_loss(pred_clip, label_img))).item()
            # NOTE(review): H and W are h+32 / w+32 when the size is already a
            # multiple of 32, so the pooled size is not always derived from the
            # real image size. Kept as-is so reported SSIM values do not change.
            down_ratio = max(1, round(min(H, W) / 256))
            ssim_val = ssim(f.adaptive_avg_pool2d(pred_clip, (int(H / down_ratio), int(W / down_ratio))),
                            f.adaptive_avg_pool2d(label_img, (int(H / down_ratio), int(W / down_ratio))),
                            data_range=1, size_average=False).item()
            print('%d iter PSNR_dehazing: %.2f ssim: %f' % (iter_idx + 1, psnr_val, ssim_val))
            ssim_adder(ssim_val)

            if args.save_image:
                save_name = os.path.join(args.result_dir, name[0])
                # Out-of-place add: on CPU `pred_numpy` shares memory with
                # `pred_clip`, so an in-place update would change the PSNR below.
                pred = to_pil_image((pred_clip + 0.5 / 255).squeeze(0).cpu(), 'RGB')
                pred.save(save_name)

            psnr_mimo = peak_signal_noise_ratio(pred_numpy, label_numpy, data_range=1)
            psnr_adder(psnr_val)

            print('%d iter PSNR: %.2f time: %f' % (iter_idx + 1, psnr_mimo, elapsed))

        print('==========================================================')
        print('The average PSNR is %.2f dB' % (psnr_adder.average()))
        # SSIM is a unitless index, so no "dB" suffix.
        print('The average SSIM is %.5f' % (ssim_adder.average()))

        print("Average time: %f" % adder.average())

Layers
==

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class BasicConv(nn.Module):
    """Convolution (or transposed convolution) with optional BatchNorm and GELU.

    Args:
        in_channel, out_channel: channel counts of the convolution.
        kernel_size: square kernel size; padding is ``kernel_size // 2`` for the
            regular convolution and ``kernel_size // 2 - 1`` when transposed.
        stride: convolution stride.
        bias: whether the convolution has a bias; forced off when ``norm`` is set.
        norm: append ``BatchNorm2d(out_channel)``.
        relu: append ``GELU`` (the flag keeps its historical name).
        transpose: use ``ConvTranspose2d`` instead of ``Conv2d``.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False):
        super().__init__()
        use_bias = False if (bias and norm) else bias

        if transpose:
            conv = nn.ConvTranspose2d(in_channel, out_channel, kernel_size,
                                      padding=kernel_size // 2 - 1, stride=stride, bias=use_bias)
        else:
            conv = nn.Conv2d(in_channel, out_channel, kernel_size,
                             padding=kernel_size // 2, stride=stride, bias=use_bias)

        stack = [conv]
        if norm:
            stack.append(nn.BatchNorm2d(out_channel))
        if relu:
            stack.append(nn.GELU())
        self.main = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the assembled layer stack to ``x``."""
        return self.main(x)


class ResBlock(nn.Module):
    """Residual block: conv+GELU, optional DeepPoolLayer, conv, plus the input.

    Args:
        in_channel, out_channel: channel counts; the residual sum in ``forward``
            adds the block output to the unmodified input.
        filter: insert a ``DeepPoolLayer`` between the two convolutions
            (``nn.Identity`` otherwise).
    """

    def __init__(self, in_channel, out_channel, filter=False):
        super().__init__()
        head = BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True)
        middle = DeepPoolLayer(in_channel, out_channel) if filter else nn.Identity()
        tail = BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
        self.main = nn.Sequential(head, middle, tail)

    def forward(self, x):
        """Return ``main(x) + x``."""
        branch = self.main(x)
        return branch + x


class DeepPoolLayer(nn.Module):
    """Multi-scale pooling branch that feeds coarse features into finer ones.

    The input is average-pooled by 8, 4 and 2. Each pooled map goes through a
    3x3 convolution and a ``dynamic_filter``; from the second scale on, the
    previous scale's result (upsampled by 2) is added before the convolution.
    Every scale's result is resized to the input resolution and summed onto the
    input, followed by GELU and a final 3x3 convolution to ``k_out`` channels.

    Args:
        k: input channel count (also used inside every scale).
        k_out: output channel count of the final convolution.
    """

    def __init__(self, k, k_out):
        super(DeepPoolLayer, self).__init__()
        self.pools_sizes = [8,4,2]
        pools, convs, dynas = [],[],[]
        for i in self.pools_sizes:
            pools.append(nn.AvgPool2d(kernel_size=i, stride=i))
            convs.append(nn.Conv2d(k, k, 3, 1, 1, bias=False))
            dynas.append(dynamic_filter(inchannels=k, kernel_size=3))
        self.pools = nn.ModuleList(pools)
        self.convs = nn.ModuleList(convs)
        self.dynas = nn.ModuleList(dynas)
        self.relu = nn.GELU()
        self.conv_sum = nn.Conv2d(k, k_out, 3, 1, 1, bias=False)

    def forward(self, x):
        """Fuse the three pooled scales into ``x`` and project to ``k_out`` channels."""
        # Resolved through `nn` rather than the notebook-level alias `F`, which
        # other cells rebind to torchvision.transforms.functional.
        interpolate = nn.functional.interpolate
        x_size = x.size()
        resl = x
        y_up = None
        last = len(self.pools_sizes) - 1
        for i in range(len(self.pools_sizes)):
            pooled = self.pools[i](x)
            if i != 0:
                # Carry the coarser scale's result into this scale.
                pooled = pooled + y_up
            y = self.dynas[i](self.convs[i](pooled))
            resl = torch.add(resl, interpolate(y, x_size[2:], mode='bilinear', align_corners=True))
            if i != last:
                # Next pool size is half the current one, hence the factor 2.
                y_up = interpolate(y, scale_factor=2, mode='bilinear', align_corners=True)
        resl = self.relu(resl)
        resl = self.conv_sum(resl)

        return resl

class dynamic_filter(nn.Module):
    """Input-dependent k x k filtering with learnable low/high mixing weights.

    A per-sample filter bank of ``group * kernel_size**2`` coefficients is
    predicted from the globally averaged input (1x1 conv, BatchNorm, Tanh) and
    applied to reflection-padded patches, with channels split into ``group``
    groups that share a filter. ``lamb_l``, ``lamb_h`` and ``inside_all`` start
    at zero, so a freshly constructed layer returns its input unchanged.

    Args:
        inchannels: channel count of the input; the reshape in ``forward``
            splits it into ``group`` equal parts.
        kernel_size: spatial size of the predicted filter.
        stride: stored on the instance; not used by ``forward``.
        group: number of channel groups sharing one predicted filter.
    """

    def __init__(self, inchannels, kernel_size=3, stride=1, group=8):
        super(dynamic_filter, self).__init__()

        self.stride = stride
        self.kernel_size = kernel_size
        self.group = group

        self.conv = nn.Conv2d(inchannels, group*kernel_size**2, kernel_size=1, stride=1, bias=False)
        self.bn = nn.BatchNorm2d(group*kernel_size**2)
        self.act = nn.Tanh()

        nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu')
        self.lamb_l = nn.Parameter(torch.zeros(inchannels), requires_grad=True)
        self.lamb_h = nn.Parameter(torch.zeros(inchannels), requires_grad=True)
        self.pad = nn.ReflectionPad2d(kernel_size//2)

        self.ap = nn.AdaptiveAvgPool2d((1, 1))
        self.gap = nn.AdaptiveAvgPool2d(1)

        self.inside_all = nn.Parameter(torch.zeros(inchannels,1,1), requires_grad=True)

    def forward(self, x):
        """Return the weighted low-frequency part plus the re-scaled input."""
        identity_input = x
        # the Conv_{3x3} layer in eq.3 is included in DeepPoolLayer.convs
        low_filter = self.ap(x)
        low_filter = self.conv(low_filter)
        low_filter = self.bn(low_filter)

        n, c, h, w = x.shape
        # `nn.functional` is used instead of the notebook-level alias `F`, which
        # other cells rebind to torchvision.transforms.functional.
        x = nn.functional.unfold(self.pad(x), kernel_size=self.kernel_size).reshape(n, self.group, c//self.group, self.kernel_size**2, h*w)

        n,c1,p,q = low_filter.shape
        low_filter = low_filter.reshape(n, c1//self.kernel_size**2, self.kernel_size**2, p*q).unsqueeze(2)
        low_filter = self.act(low_filter)
        low_part = torch.sum(x * low_filter, dim=3).reshape(n, c, h, w)

        # the variables here are slightly different from the paper: (code) --> (paper)
        # low_filter --> A (eq.3)
        # In Eq.7, X*A'= X*(A_{l} + WA_{h})
        #              = X*A_{l} + WX*(A - A_{l})
        #              = X*A_{l} + WX*A - WX*A_{l}
        #              = WX*A - X*A_{l}(W-1)
        # we substitute gap for A_{l} for simplicity, which is a coarser low-frequency filter
        out_low = low_part * (self.inside_all + 1.) - self.inside_all * self.gap(identity_input)

        out_low = out_low * self.lamb_l[None,:,None,None]
        out_high = (identity_input) * (self.lamb_h[None,:,None,None] + 1.)

        return out_low + out_high

IRNeXt
==

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F

class EBlock(nn.Module):
    """Encoder stage: ``num_res`` residual blocks, the last one with filtering.

    Args:
        out_channel: channel count used by every residual block.
        num_res: total number of residual blocks in the stage.
    """

    def __init__(self, out_channel, num_res=8):
        super().__init__()

        blocks = []
        for _ in range(num_res - 1):
            blocks.append(ResBlock(out_channel, out_channel))
        # Only the final block carries the DeepPoolLayer.
        blocks.append(ResBlock(out_channel, out_channel, filter=True))
        self.layers = nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through the stacked residual blocks."""
        out = self.layers(x)
        return out


class DBlock(nn.Module):
    """Decoder stage: ``num_res`` residual blocks, the last one with filtering.

    Args:
        channel: channel count used by every residual block.
        num_res: total number of residual blocks in the stage.
    """

    def __init__(self, channel, num_res=8):
        super().__init__()

        blocks = []
        for _ in range(num_res - 1):
            blocks.append(ResBlock(channel, channel))
        # Only the final block carries the DeepPoolLayer.
        blocks.append(ResBlock(channel, channel, filter=True))
        self.layers = nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through the stacked residual blocks."""
        out = self.layers(x)
        return out


class SCM(nn.Module):
    """Small convolutional stem lifting a 3-channel image to ``out_plane`` channels.

    Four convolutions (3x3, 1x1, 3x3, 1x1) widen the features from a quarter to
    half to the full ``out_plane`` width, followed by an affine ``InstanceNorm2d``.
    """

    def __init__(self, out_plane):
        super().__init__()

        quarter, half = out_plane // 4, out_plane // 2
        stem = [
            BasicConv(3, quarter, kernel_size=3, stride=1, relu=True),
            BasicConv(quarter, half, kernel_size=1, stride=1, relu=True),
            BasicConv(half, half, kernel_size=3, stride=1, relu=True),
            BasicConv(half, out_plane, kernel_size=1, stride=1, relu=False),
            nn.InstanceNorm2d(out_plane, affine=True),
        ]
        self.main = nn.Sequential(*stem)

    def forward(self, x):
        """Apply the stem to ``x``."""
        return self.main(x)

class FAM(nn.Module):
    """Merge two ``channel``-wide feature maps into one with a 3x3 convolution."""

    def __init__(self, channel):
        super().__init__()

        self.merge = BasicConv(2 * channel, channel, kernel_size=3, stride=1, relu=False)

    def forward(self, x1, x2):
        """Concatenate ``x1`` and ``x2`` along channels and convolve."""
        stacked = torch.cat([x1, x2], dim=1)
        return self.merge(stacked)

class IRNeXt(nn.Module):
    """Three-scale encoder/decoder network producing one output per scale.

    ``forward`` returns a list of three tensors: the quarter-resolution,
    half-resolution and full-resolution predictions, each formed as a predicted
    residual added to the correspondingly downsampled input.

    Args:
        num_res: number of residual blocks in every encoder and decoder stage.
    """

    def __init__(self, num_res=4):
        super(IRNeXt, self).__init__()

        base_channel = 32
        self.Encoder = nn.ModuleList([
            EBlock(base_channel, num_res),
            EBlock(base_channel*2, num_res),
            EBlock(base_channel*4, num_res),
        ])

        # Indices 0-2 go down the encoder, 3-4 upsample in the decoder, and 5
        # projects back to 3 channels.
        self.feat_extract = nn.ModuleList([
            BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1),
            BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2),
            BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2),
            BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True),
            BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True),
            BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1)
        ])

        self.Decoder = nn.ModuleList([
            DBlock(base_channel * 4, num_res),
            DBlock(base_channel * 2, num_res),
            DBlock(base_channel, num_res)
        ])

        # 1x1 convolutions that shrink the skip-concatenated features.
        self.Convs = nn.ModuleList([
            BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
            BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1),
        ])

        # Heads for the two lower-resolution outputs.
        self.ConvsOut = nn.ModuleList(
            [
                BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
                BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
            ]
        )

        self.FAM1 = FAM(base_channel * 4)
        self.SCM1 = SCM(base_channel * 4)
        self.FAM2 = FAM(base_channel * 2)
        self.SCM2 = SCM(base_channel * 2)

    def forward(self, x):
        """Return ``[quarter_scale, half_scale, full_scale]`` predictions for ``x``."""
        # Resolved through `nn` rather than the notebook-level alias `F`, which
        # other cells rebind to torchvision.transforms.functional.
        interpolate = nn.functional.interpolate
        x_2 = interpolate(x, scale_factor=0.5)
        x_4 = interpolate(x_2, scale_factor=0.5)
        z2 = self.SCM2(x_2)
        z4 = self.SCM1(x_4)

        outputs = list()
        # 256
        x_ = self.feat_extract[0](x)
        res1 = self.Encoder[0](x_)
        # 128
        z = self.feat_extract[1](res1)
        z = self.FAM2(z, z2)
        res2 = self.Encoder[1](z)
        # 64
        z = self.feat_extract[2](res2)
        z = self.FAM1(z, z4)
        z = self.Encoder[2](z)

        z = self.Decoder[0](z)
        z_ = self.ConvsOut[0](z)
        # 128
        z = self.feat_extract[3](z)
        outputs.append(z_+x_4)

        z = torch.cat([z, res2], dim=1)
        z = self.Convs[0](z)
        z = self.Decoder[1](z)
        z_ = self.ConvsOut[1](z)
        # 256
        z = self.feat_extract[4](z)
        outputs.append(z_+x_2)

        z = torch.cat([z, res1], dim=1)
        z = self.Convs[1](z)
        z = self.Decoder[2](z)
        z = self.feat_extract[5](z)
        outputs.append(z+x)

        return outputs


def build_net():
    """Create an ``IRNeXt`` model with its default constructor arguments."""
    net = IRNeXt()
    return net

Valid
==

In [None]:
import torch
from torchvision.transforms import functional as F
import os
from skimage.metrics import peak_signal_noise_ratio
import torch.nn.functional as f


def _valid(model, args, ep):
    """Compute the average PSNR of ``model`` over the validation loader.

    Each input is reflect-padded on the right/bottom to a multiple of 32, passed
    through the model, and the output at index 2 is cropped back and clamped to
    [0, 1] before PSNR against the label is measured. The model is put in eval
    mode for the pass and switched back to train mode before returning.

    Args:
        model: network whose forward returns an indexable of outputs.
        args: object providing ``data_dir`` and ``result_dir``.
        ep: epoch number; a sub-directory of that name is created under
            ``args.result_dir``.

    Returns:
        Mean PSNR over all validation images.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    its = valid_dataloader(args.data_dir, batch_size=1, num_workers=0)
    model.eval()
    psnr_adder = Adder()

    # Created once per call (not per image); makedirs also creates a missing
    # args.result_dir instead of failing.
    os.makedirs(os.path.join(args.result_dir, '%d' % (ep)), exist_ok=True)

    with torch.no_grad():
        print('Start Evaluation')
        factor = 32
        for idx, data in enumerate(its):
            input_img, label_img = data
            input_img = input_img.to(device)

            h, w = input_img.shape[2], input_img.shape[3]
            # Amount needed to reach the next multiple of `factor` (0 if already aligned).
            padh = (factor - h % factor) % factor
            padw = (factor - w % factor) % factor
            input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect')

            pred = model(input_img)[2]
            pred = pred[:,:,:h,:w]

            pred_clip = torch.clamp(pred, 0, 1)
            p_numpy = pred_clip.squeeze(0).cpu().numpy()
            label_numpy = label_img.squeeze(0).cpu().numpy()

            psnr = peak_signal_noise_ratio(p_numpy, label_numpy, data_range=1)

            psnr_adder(psnr)
            print('\r%03d'%idx, end=' ')

    print('\n')
    model.train()
    return psnr_adder.average()

Train
==

In [None]:
import os
import torch
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
import torch.nn as nn

from warmup_scheduler import GradualWarmupScheduler

def _fft_real_imag(img):
    """2-D FFT over the last two dims, real and imaginary parts stacked on a new last axis."""
    spectrum = torch.fft.fft2(img, dim=(-2, -1))
    return torch.stack((spectrum.real, spectrum.imag), -1)


def _train(model, args):
    """Train ``model`` with a multi-scale L1 + frequency-domain L1 loss.

    Uses AdamW with a 3-epoch gradual warm-up followed by cosine annealing. The
    three model outputs are compared against the label at 1/4, 1/2 and full
    resolution; the total loss is ``content + 0.1 * fft``. After every epoch a
    resumable checkpoint ``model.pkl`` is written; every ``args.save_freq``
    epochs a numbered snapshot is written; every ``args.valid_freq`` epochs the
    model is validated and ``Best.pkl`` is updated when PSNR does not drop below
    the best seen in this run. ``Final.pkl`` is written at the end.

    Args:
        model: network returning three outputs (1/4, 1/2, full resolution).
        args: object providing ``learning_rate``, ``weight_decay``, ``data_dir``,
            ``batch_size``, ``num_worker``, ``num_epoch``, ``resume``,
            ``print_freq``, ``save_freq``, ``valid_freq``, ``model_save_dir``
            (and what ``_valid`` reads).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    criterion = torch.nn.L1Loss()
    # Resolved through `nn` rather than the notebook-level alias `F`, which
    # other cells rebind to torchvision.transforms.functional.
    interpolate = nn.functional.interpolate

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=args.weight_decay)
    dataloader = train_dataloader(args.data_dir, args.batch_size, args.num_worker)
    max_iter = len(dataloader)
    warmup_epochs=3
    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_epoch-warmup_epochs, eta_min=1e-6)
    scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
    scheduler.step()
    epoch = 1
    if args.resume:
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only runtime.
        state = torch.load(args.resume, map_location=device)
        epoch = state['epoch']
        optimizer.load_state_dict(state['optimizer'])
        model.load_state_dict(state['model'])
        print('Resume from %d'%epoch)
        epoch += 1
        # Replay one scheduler step per completed epoch so the LR schedule lines up.
        for i in range(epoch-1):
            scheduler.step()

    writer = SummaryWriter()
    epoch_pixel_adder = Adder()
    epoch_fft_adder = Adder()
    iter_pixel_adder = Adder()
    iter_fft_adder = Adder()
    epoch_timer = Timer('m')
    iter_timer = Timer('m')
    # Best validation PSNR of this run (not stored in checkpoints, so it starts
    # over after a resume).
    best_psnr=-1

    for epoch_idx in range(epoch, args.num_epoch + 1):

        epoch_timer.tic()
        iter_timer.tic()
        for iter_idx, batch_data in enumerate(dataloader):

            input_img, label_img = batch_data
            input_img = input_img.to(device)
            label_img = label_img.to(device)

            optimizer.zero_grad()
            pred_img = model(input_img)
            label_img2 = interpolate(label_img, scale_factor=0.5, mode='bilinear')
            label_img4 = interpolate(label_img, scale_factor=0.25, mode='bilinear')
            # (prediction, target) pairs from coarsest to finest scale.
            scale_pairs = (
                (pred_img[0], label_img4),
                (pred_img[1], label_img2),
                (pred_img[2], label_img),
            )

            l1, l2, l3 = (criterion(pred, target) for pred, target in scale_pairs)
            loss_content = l1+l2+l3

            f1, f2, f3 = (criterion(_fft_real_imag(pred), _fft_real_imag(target))
                          for pred, target in scale_pairs)
            loss_fft = f1+f2+f3

            loss = loss_content + 0.1 * loss_fft
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.001)
            optimizer.step()

            iter_pixel_adder(loss_content.item())
            iter_fft_adder(loss_fft.item())

            epoch_pixel_adder(loss_content.item())
            epoch_fft_adder(loss_fft.item())

            if (iter_idx + 1) % args.print_freq == 0:
                print("Time: %7.4f Epoch: %03d Iter: %4d/%4d LR: %.10f Loss content: %7.4f Loss fft: %7.4f" % (
                    iter_timer.toc(), epoch_idx, iter_idx + 1, max_iter, scheduler.get_lr()[0], iter_pixel_adder.average(),
                    iter_fft_adder.average()))
                writer.add_scalar('Pixel Loss', iter_pixel_adder.average(), iter_idx + (epoch_idx-1)* max_iter)
                writer.add_scalar('FFT Loss', iter_fft_adder.average(), iter_idx + (epoch_idx - 1) * max_iter)

                iter_timer.tic()
                iter_pixel_adder.reset()
                iter_fft_adder.reset()
        # Resumable checkpoint, overwritten after every epoch.
        overwrite_name = os.path.join(args.model_save_dir, 'model.pkl')
        torch.save({'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch_idx}, overwrite_name)

        if epoch_idx % args.save_freq == 0:
            save_name = os.path.join(args.model_save_dir, 'model_%d.pkl' % epoch_idx)
            torch.save({'model': model.state_dict()}, save_name)
        print("EPOCH: %02d\nElapsed time: %4.2f Epoch Pixel Loss: %7.4f Epoch FFT Loss: %7.4f" % (
            epoch_idx, epoch_timer.toc(), epoch_pixel_adder.average(), epoch_fft_adder.average()))
        epoch_fft_adder.reset()
        epoch_pixel_adder.reset()
        scheduler.step()
        if epoch_idx % args.valid_freq == 0:
            val = _valid(model, args, epoch_idx)
            print('%03d epoch \n Average PSNR %.2f dB' % (epoch_idx, val))
            writer.add_scalar('PSNR', val, epoch_idx)
            if val >= best_psnr:
                # Remember the new best so a later, worse epoch cannot overwrite Best.pkl.
                best_psnr = val
                torch.save({'model': model.state_dict()}, os.path.join(args.model_save_dir, 'Best.pkl'))
    save_name = os.path.join(args.model_save_dir, 'Final.pkl')
    torch.save({'model': model.state_dict()}, save_name)
    # Flush pending TensorBoard events and release the event file.
    writer.close()

Main (train model)
==

In [None]:
import os
import torch
import torchvision.transforms as transforms
# from torch.backends import cudnn # Uncomment if you need it

class Args:
    """Configuration for the training run (stands in for parsed CLI arguments)."""
    model_name = 'IRNeXt'
    mode = 'train'
    data_dir = path

    # Train
    batch_size = 4
    learning_rate = 1e-4
    weight_decay = 5e-4
    num_epoch = 300
    print_freq = 100   # iterations between loss printouts
    num_worker = 8
    save_freq = 10     # epochs between numbered snapshots
    valid_freq = 10    # epochs between validation passes
    # Checkpoint to resume from; any non-empty value is loaded by _train, so the
    # file must already exist. Set to '' to start from scratch.
    resume = '/content/drive/MyDrive/reside-indoor/results/IRNeXt/ITS/model.pkl'

    # Test
    # NOTE: _eval reads args.test_model and args.save_image; define them before
    # switching mode to 'test' with this class.
    # test_model = ''
    # save_image = False

    # Directories (set these as per your requirement)
    # Both directories derive from model_name so they stay together if it changes.
    model_save_dir = os.path.join('/content/drive/MyDrive/reside-indoor/results/', model_name, 'ITS/')
    result_dir = os.path.join('/content/drive/MyDrive/reside-indoor/results/', model_name, 'test')

def main(args):
    """Create the output directories, build the model and run the selected mode.

    Args:
        args: configuration object providing ``model_save_dir``, ``result_dir``
            and ``mode`` (``'train'`` or ``'test'``), plus whatever ``_train`` or
            ``_eval`` read from it.
    """
    # CUDNN
    # cudnn.benchmark = True # Uncomment if you need it

    # makedirs creates all missing parents, so these two calls also cover the
    # results root and the per-model directory.
    for directory in (args.model_save_dir, args.result_dir):
        os.makedirs(directory, exist_ok=True)

    model = build_net()  # Make sure to define build_net or import it if it's defined elsewhere
    print(model)

    if torch.cuda.is_available():
        model.cuda()
    if args.mode == 'train':
        _train(model, args)  # Make sure to define _train or import it if it's defined elsewhere

    elif args.mode == 'test':
        _eval(model, args)   # Make sure to define _eval or import it if it's defined elsewhere

# Replace parser.parse_args() with an instance of the Args class
import shutil

args = Args()
os.makedirs(args.model_save_dir, exist_ok=True)
# Snapshot the source files next to the checkpoints when they exist in the
# working directory. shutil.copy avoids building shell commands from paths.
for source_file in ('models/layers.py', 'models/IRNeXt.py', 'train.py', 'main.py'):
    if os.path.isfile(source_file):
        shutil.copy(source_file, args.model_save_dir)
print(args)
main(args)


<__main__.Args object at 0x7fd867735360>
IRNeXt(
  (Encoder): ModuleList(
    (0): EBlock(
      (layers): Sequential(
        (0): ResBlock(
          (main): Sequential(
            (0): BasicConv(
              (main): Sequential(
                (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): GELU(approximate='none')
              )
            )
            (1): Identity()
            (2): BasicConv(
              (main): Sequential(
                (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              )
            )
          )
        )
        (1): ResBlock(
          (main): Sequential(
            (0): BasicConv(
              (main): Sequential(
                (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): GELU(approximate='none')
              )
            )
            (1): Identity()
            (2): BasicConv(
              (main): Sequential(
      

In [None]:
# Load a saved checkpoint into memory for inspection.
# NOTE(review): this reads from 'results1/', whereas the Args classes write under 'results/' — confirm the intended path.
state_last = torch.load('/content/drive/MyDrive/reside-indoor/results1/IRNeXt/ITS/model.pkl')

In [None]:
# Extract the SOTS archive from Drive into ./SOTS/ (the test configuration reads ./SOTS/SOTS/indoor).
!unzip  /content/drive/MyDrive/reside-indoor/SOTS.zip -d ./SOTS/

Archive:  /content/drive/MyDrive/reside-indoor/SOTS.zip
   creating: ./SOTS/SOTS/
   creating: ./SOTS/SOTS/indoor/
   creating: ./SOTS/SOTS/indoor/gt/
  inflating: ./SOTS/SOTS/indoor/gt/1401.png  
  inflating: ./SOTS/SOTS/indoor/gt/1400.png  
  inflating: ./SOTS/SOTS/indoor/gt/1402.png  
  inflating: ./SOTS/SOTS/indoor/gt/1403.png  
  inflating: ./SOTS/SOTS/indoor/gt/1404.png  
  inflating: ./SOTS/SOTS/indoor/gt/1405.png  
  inflating: ./SOTS/SOTS/indoor/gt/1406.png  
  inflating: ./SOTS/SOTS/indoor/gt/1407.png  
  inflating: ./SOTS/SOTS/indoor/gt/1408.png  
  inflating: ./SOTS/SOTS/indoor/gt/1409.png  
  inflating: ./SOTS/SOTS/indoor/gt/1410.png  
  inflating: ./SOTS/SOTS/indoor/gt/1411.png  
  inflating: ./SOTS/SOTS/indoor/gt/1412.png  
  inflating: ./SOTS/SOTS/indoor/gt/1413.png  
  inflating: ./SOTS/SOTS/indoor/gt/1414.png  
  inflating: ./SOTS/SOTS/indoor/gt/1415.png  
  inflating: ./SOTS/SOTS/indoor/gt/1416.png  
  inflating: ./SOTS/SOTS/indoor/gt/1417.png  
  inflating: ./SOTS/S

Main (test model)
==

In [None]:
import os
import torch
import torchvision.transforms as transforms
# from torch.backends import cudnn # Uncomment if you need it

class Args:
    """Configuration for the evaluation run (stands in for parsed CLI arguments)."""
    model_name = 'IRNeXt'
    mode = 'test'
    data_dir = './SOTS/SOTS/indoor'

    # Train (kept so the object has the same fields as the training configuration)
    batch_size = 4
    learning_rate = 1e-4
    weight_decay = 0
    num_epoch = 5
    print_freq = 100
    num_worker = 8
    save_freq = 10
    valid_freq = 10
    resume = ''

    # Test
    test_model = '/content/drive/MyDrive/reside-indoor/results/IRNeXt/ITS/Best.pkl'  # checkpoint loaded by _eval
    save_image = False  # when True, _eval writes each prediction into result_dir

    # Directories (set these as per your requirement)
    # Both directories derive from model_name so they stay together if it changes.
    model_save_dir = os.path.join('/content/drive/MyDrive/reside-indoor/results/', model_name, 'ITS/')
    result_dir = os.path.join('/content/drive/MyDrive/reside-indoor/results/', model_name, 'test')

def main(args):
    """Create the output directories, build the model and run the selected mode.

    This cell redefines ``main``; the definition executed last is the one in effect.

    Args:
        args: configuration object providing ``model_save_dir``, ``result_dir``
            and ``mode`` (``'train'`` or ``'test'``), plus whatever ``_train`` or
            ``_eval`` read from it.
    """
    # CUDNN
    # cudnn.benchmark = True # Uncomment if you need it

    # makedirs creates all missing parents, so these two calls also cover the
    # results root and the per-model directory.
    for directory in (args.model_save_dir, args.result_dir):
        os.makedirs(directory, exist_ok=True)

    model = build_net()  # Make sure to define build_net or import it if it's defined elsewhere
    print(model)

    if torch.cuda.is_available():
        model.cuda()
    if args.mode == 'train':
        _train(model, args)  # Make sure to define _train or import it if it's defined elsewhere

    elif args.mode == 'test':
        _eval(model, args)   # Make sure to define _eval or import it if it's defined elsewhere

# Replace parser.parse_args() with an instance of the Args class
import shutil

args = Args()
os.makedirs(args.model_save_dir, exist_ok=True)
# Snapshot the source files next to the checkpoints when they exist in the
# working directory. shutil.copy avoids building shell commands from paths.
for source_file in ('models/layers.py', 'models/IRNeXt.py', 'train.py', 'main.py'):
    if os.path.isfile(source_file):
        shutil.copy(source_file, args.model_save_dir)
print(args)
main(args)


<__main__.Args object at 0x7c4cc873a200>
IRNeXt(
  (Encoder): ModuleList(
    (0): EBlock(
      (layers): Sequential(
        (0): ResBlock(
          (main): Sequential(
            (0): BasicConv(
              (main): Sequential(
                (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): GELU(approximate='none')
              )
            )
            (1): Identity()
            (2): BasicConv(
              (main): Sequential(
                (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              )
            )
          )
        )
        (1): ResBlock(
          (main): Sequential(
            (0): BasicConv(
              (main): Sequential(
                (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): GELU(approximate='none')
              )
            )
            (1): Identity()
            (2): BasicConv(
              (main): Sequential(
      