In [1]:
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
!python --version

Python 3.10.12


Run the cells below to test the model.

In [None]:
# ---------------------------------------------------------------
# Default run configuration (read as notebook globals by later cells)
# ---------------------------------------------------------------
import os

# Runtime / IO
no_gpu = False  # NOTE(review): not referenced by any cell shown here
save_name = './results_tmp'
# Pretrained KPN weights. Override with the KPN_LOAD_NAME environment variable instead
# of editing this absolute Colab path; an empty string makes create_generator()
# initialise a fresh network rather than loading one.
load_name = os.environ.get(
    'KPN_LOAD_NAME',
    '/content/derain/Models/models_effderain/models_rain1400/KPN_rainy_image_epoch60_bs4.pth')
test_batch_size = 1
num_workers = 1

# Network architecture (passed positionally to KPN by create_generator)
# NOTE: the augmentation cell later defines a function also named `color`, which
# rebinds this global; create_generator() reads `color` at call time.
color = True
burst_length = 1
blind_est = True
kernel_size = [3]
sep_conv = False
channel_att = False
spatial_att = False
upMode = 'bilinear'
core_bias = False

# Weight initialisation (only used when load_name == '')
init_type = 'xavier'
init_gain = 0.02

# Dataset root and augmentation switches (all augmentations off by default)
baseroot = 'rainy_image_dataset/testing'
crop = False
crop_size = 512
geometry_aug = False
angle_aug = False
scale_min = 1.0
scale_max = 1.0
add_noise = False
mu = 0
sigma = 30


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# ----------------------------------------
#         Initialize the networks
# ----------------------------------------
def weights_init(net, init_type = 'normal', init_gain = 0.02):
    """Initialize the weights of every Conv* and BatchNorm2d layer of ``net`` in place.

    Parameters:
        net (nn.Module)   -- network to be initialized (modified in place via ``net.apply``)
        init_type (str)   -- normal | xavier | kaiming | orthogonal; applied to Conv* weights
        init_gain (float) -- std for 'normal', gain for 'xavier' and 'orthogonal';
                             not used by 'kaiming'
    Raises:
        NotImplementedError -- if ``init_type`` is unknown (raised when the first
                               Conv* layer is visited)

    Conv biases are left untouched. BatchNorm2d weights are drawn from N(1, 0.02) and
    biases set to 0; BatchNorm2d layers without affine parameters are skipped.
    In our paper, we choose the default setting: zero mean Gaussian distribution with a standard deviation of 0.02
    """
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and classname.find('Conv') != -1:
            if init_type == 'normal':
                torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                torch.nn.init.xavier_normal_(m.weight.data, gain = init_gain)
            elif init_type == 'kaiming':
                torch.nn.init.kaiming_normal_(m.weight.data, a = 0, mode = 'fan_in')
            elif init_type == 'orthogonal':
                torch.nn.init.orthogonal_(m.weight.data, gain = init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm2d(affine=False) has weight/bias == None; nothing to initialize.
            if getattr(m, 'weight', None) is not None:
                torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            if getattr(m, 'bias', None) is not None:
                torch.nn.init.constant_(m.bias.data, 0.0)

    # apply the initialization function <init_func>
    print('initialize network with %s type' % init_type)
    net.apply(init_func)

# ----------------------------------------
#      Kernel Prediction Network (KPN)
# ----------------------------------------
class Basic(nn.Module):
    """Three 3x3 conv + ReLU layers with optional channel and/or spatial attention.

    Spatial size is preserved (stride 1, padding 1); channels go in_ch -> out_ch.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        g: reduction factor of the channel-attention bottleneck (out_ch // g channels).
        channel_att: if True, scale channels by a sigmoid gate computed from the
            concatenated global average- and max-pooled features.
        spatial_att: if True, scale pixels by a sigmoid gate computed by a 7x7 conv
            over the channel-wise mean and max maps.
    """

    def __init__(self, in_ch, out_ch, g=16, channel_att=False, spatial_att=False):
        super(Basic, self).__init__()
        self.channel_att = channel_att
        self.spatial_att = spatial_att

        # Conv, ReLU, Conv, ReLU, Conv, ReLU -- same construction order and Sequential
        # indices (convs at 0, 2, 4) as before, so state_dict keys are unchanged.
        stages = []
        for width_in in (in_ch, out_ch, out_ch):
            stages.append(nn.Conv2d(in_channels=width_in, out_channels=out_ch, kernel_size=3, stride=1, padding=1))
            stages.append(nn.ReLU())
        self.conv1 = nn.Sequential(*stages)

        if channel_att:
            squeezed = out_ch // g
            self.att_c = nn.Sequential(
                nn.Conv2d(2 * out_ch, squeezed, 1, 1, 0),
                nn.ReLU(),
                nn.Conv2d(squeezed, out_ch, 1, 1, 0),
                nn.Sigmoid(),
            )
        if spatial_att:
            self.att_s = nn.Sequential(
                nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3),
                nn.Sigmoid(),
            )

    def forward(self, data):
        """Run the conv stack, then apply the enabled attention gates.

        :param data: input feature map with ``in_ch`` channels.
        :return: feature map with ``out_ch`` channels and the same spatial size.
        """
        features = self.conv1(data)
        if self.channel_att:
            pooled = torch.cat(
                [F.adaptive_avg_pool2d(features, (1, 1)), F.adaptive_max_pool2d(features, (1, 1))],
                dim=1,
            )
            features = features * self.att_c(pooled)
        if self.spatial_att:
            channel_mean = torch.mean(features, dim=1, keepdim=True)
            channel_max = torch.max(features, dim=1, keepdim=True)[0]
            features = features * self.att_s(torch.cat([channel_mean, channel_max], dim=1))
        return features

class KPN(nn.Module):
    """Kernel Prediction Network.

    A U-Net (conv1..conv8 + outc) predicts a "core" of per-pixel kernels from
    ``data_with_est``; ``KernelConv`` applies those kernels to ``data`` at dilation
    rates 1, 2, 3 and 4, and ``conv_final`` fuses the four filtered images.

    Args:
        color: True for 3-channel images, False for 1-channel.
        burst_length: number of frames per sample.
        blind_est: if False the network input carries one extra frame.
        kernel_size: list of kernel sizes handed to KernelConv.
        sep_conv: predict separable (row x column) kernels instead of full K*K ones.
        channel_att / spatial_att: enable attention in the decoder blocks conv6..conv8.
        upMode: ``F.interpolate`` mode used for every upsampling step.
        core_bias: also predict an additive per-pixel bias.
    """

    # Dilation rates at which the predicted kernels are applied; conv_final fuses
    # one prediction per rate.
    _DILATION_RATES = (1, 2, 3, 4)

    def __init__(self, color=True, burst_length=1, blind_est=True, kernel_size=[5], sep_conv=False,
                 channel_att=False, spatial_att=False, upMode='bilinear', core_bias=False):
        super(KPN, self).__init__()
        self.upMode = upMode
        self.burst_length = burst_length
        self.core_bias = core_bias
        self.color_channel = 3 if color else 1
        in_channel = self.color_channel * (burst_length if blind_est else burst_length + 1)
        out_channel = self.color_channel * (2 * sum(kernel_size) if sep_conv else np.sum(np.array(kernel_size) ** 2)) * burst_length
        if core_bias:
            out_channel += self.color_channel * burst_length
        # Encoder: conv1 at full resolution, conv2..conv5 each after a 2x average pool.
        self.conv1 = Basic(in_channel, 64, channel_att=False, spatial_att=False)
        self.conv2 = Basic(64, 128, channel_att=False, spatial_att=False)
        self.conv3 = Basic(128, 256, channel_att=False, spatial_att=False)
        self.conv4 = Basic(256, 512, channel_att=False, spatial_att=False)
        self.conv5 = Basic(512, 512, channel_att=False, spatial_att=False)
        # Decoder: conv6..conv8 take an upsampled map concatenated with an encoder skip.
        self.conv6 = Basic(512+512, 512, channel_att=channel_att, spatial_att=spatial_att)
        self.conv7 = Basic(256+512, 256, channel_att=channel_att, spatial_att=spatial_att)
        self.conv8 = Basic(256+128, out_channel, channel_att=channel_att, spatial_att=spatial_att)
        self.outc = nn.Conv2d(out_channel, out_channel, 1, 1, 0)

        self.kernel_pred = KernelConv(kernel_size, sep_conv, self.core_bias)

        # Each per-rate prediction has as many channels as ``data`` -- assumed to be
        # color_channel * burst_length for 4-D input. For the default color=True,
        # burst_length=1 this is the original 12 -> 3 conv, so checkpoints still load;
        # the old hard-coded 12 crashed for color=False.
        self.conv_final = nn.Conv2d(in_channels=len(self._DILATION_RATES) * self.color_channel * burst_length,
                                    out_channels=self.color_channel, kernel_size=3, stride=1, padding=1)

    def forward(self, data_with_est, data, white_level=1.0):
        """
        Predict per-pixel kernels and return the fused filtered image.

        :param data_with_est: network input; same as ``data`` for blind estimation.
            Height and width must be divisible by 16 (four 2x poolings), otherwise the
            skip-connection concatenations do not line up.
        :param data: image(s) the predicted kernels are applied to.
        :param white_level: divisor applied by KernelConv to every filtered image.
        :return: output of ``conv_final`` (e.g. (B, 3, H, W) for a (B, 3, H, W) input).
        """
        conv1 = self.conv1(data_with_est)
        conv2 = self.conv2(F.avg_pool2d(conv1, kernel_size=2, stride=2))
        conv3 = self.conv3(F.avg_pool2d(conv2, kernel_size=2, stride=2))
        conv4 = self.conv4(F.avg_pool2d(conv3, kernel_size=2, stride=2))
        conv5 = self.conv5(F.avg_pool2d(conv4, kernel_size=2, stride=2))
        # Upsample, concatenating the matching encoder output (skip connection).
        conv6 = self.conv6(torch.cat([conv4, F.interpolate(conv5, scale_factor=2, mode=self.upMode)], dim=1))
        conv7 = self.conv7(torch.cat([conv3, F.interpolate(conv6, scale_factor=2, mode=self.upMode)], dim=1))
        conv8 = self.conv8(torch.cat([conv2, F.interpolate(conv7, scale_factor=2, mode=self.upMode)], dim=1))
        # conv8 is at half resolution; one more upsample gives a per-pixel core.
        core = self.outc(F.interpolate(conv8, scale_factor=2, mode=self.upMode))

        # Filter at each dilation rate and fuse the results channel-wise.
        preds = [self.kernel_pred(data, core, white_level, rate=r) for r in self._DILATION_RATES]
        return self.conv_final(torch.cat(preds, dim=1))

class KernelConv(nn.Module):
    """
    Apply per-pixel kernels predicted by KPN to a stack of frames.

    Supports full K x K kernels or separable (row x column) kernels, several kernel
    sizes at once (their filtered results are averaged), dilated taps via ``rate`` and
    an optional additive per-pixel bias.
    """
    def __init__(self, kernel_size=[5], sep_conv=False, core_bias=False):
        super(KernelConv, self).__init__()
        # sorted() returns a copy, so the shared default list is never mutated.
        self.kernel_size = sorted(kernel_size)
        self.sep_conv = sep_conv
        self.core_bias = core_bias

    def _sep_conv_core(self, core, batch_size, N, color, height, width):
        """
        Expand separable kernels into full K x K kernels.

        :param core: raw core, viewed as (batch, N, -1, color, height, width). Along
            dim 2 it holds sum(K) row factors, then sum(K) column factors, then (if
            core_bias) one bias slot.
        :return: (dict K -> (batch, N, K*K, color, height, width), bias) where bias is
            (batch, N, color, height, width) or None.
        """
        kernel_total = sum(self.kernel_size)
        core = core.view(batch_size, N, -1, color, height, width)
        if self.core_bias:
            core_1, core_2, core_3 = torch.split(core, kernel_total, dim=2)
            # Squeeze only the single bias slot so batch/N/color axes of size 1 survive.
            bias = core_3.squeeze(2)
        else:
            core_1, core_2 = torch.split(core, kernel_total, dim=2)
            bias = None
        core_out = {}
        cur = 0
        for K in self.kernel_size:
            rows = core_1[:, :, cur:cur + K, ...].reshape(batch_size, N, K, 1, color, height, width)
            cols = core_2[:, :, cur:cur + K, ...].reshape(batch_size, N, 1, K, color, height, width)
            # Outer product -> (batch, N, K, K, color, H, W), flattened row-major so that
            # index i*K + j matches the tap order built in forward().
            core_out[K] = (rows * cols).reshape(batch_size, N, K * K, color, height, width)
            cur += K
        return core_out, bias

    def _convert_dict(self, core, batch_size, N, color, height, width):
        """
        Split a full (non-separable) core into a dict of per-size kernels.

        :param core: raw core, viewed as (batch, N, -1, color, height, width). Along
            dim 2 it is assumed to hold consecutive K*K blocks in ascending K order
            (a single block for one kernel size), then, if core_bias, a final bias slot.
        :return: (dict K -> (batch, N, K*K, color, height, width), bias) where bias is
            the last slot along dim 2, or None.
        """
        core = core.view(batch_size, N, -1, color, height, width)
        core_out = {}
        cur = 0
        for K in self.kernel_size:
            if K not in core_out:  # duplicate sizes reuse the first block
                core_out[K] = core[:, :, cur:cur + K * K, ...]
            cur += K * K
        bias = core[:, :, -1, ...] if self.core_bias else None
        return core_out, bias

    def forward(self, frames, core, white_level=1.0, rate=1):
        """
        Filter ``frames`` with the per-pixel kernels in ``core``.

        :param frames: (batch, N, color, height, width), or (batch, N, height, width)
            which is treated as color = 1.
        :param core: raw KPN output, reshaped internally to (batch, N, -1, color, H, W).
        :param white_level: divisor applied to the result.
        :param rate: dilation of the kernel taps (tap spacing in pixels).
        :return: mean over kernel sizes of the filtered frames (plus bias if enabled),
            divided by white_level; shape (batch, N, color, H, W) with the color axis
            squeezed when color == 1.
        """
        if frames.dim() == 5:
            batch_size, N, color, height, width = frames.size()
        else:
            batch_size, N, height, width = frames.size()
            color = 1
            frames = frames.view(batch_size, N, color, height, width)
        if self.sep_conv:
            core, bias = self._sep_conv_core(core, batch_size, N, color, height, width)
        else:
            core, bias = self._convert_dict(core, batch_size, N, color, height, width)

        # Largest kernel first: smaller tap stacks are centred crops of the largest one.
        kernels = self.kernel_size[::-1]
        largest = kernels[0]
        pad = (largest // 2) * rate
        frame_pad = F.pad(frames, [pad, pad, pad, pad])
        # img_stack[:, :, i * K + j] is the frame shifted by (i * rate, j * rate) in the padded image.
        img_stack = torch.stack(
            [frame_pad[..., i * rate:i * rate + height, j * rate:j * rate + width]
             for i in range(largest) for j in range(largest)],
            dim=2,
        )
        pred_img = []
        prev_K = largest
        for K in kernels:
            if K != prev_K:
                # Keep the centred K x K window of the prev_K x prev_K taps. The flat K*K
                # axis must be unflattened first; slicing it directly picks wrong taps.
                k_diff = (prev_K - K) // 2
                img_stack = img_stack.reshape(batch_size, N, prev_K, prev_K, color, height, width)
                img_stack = img_stack[:, :, k_diff:k_diff + K, k_diff:k_diff + K, ...]
                img_stack = img_stack.reshape(batch_size, N, K * K, color, height, width)
                prev_K = K
            pred_img.append(torch.sum(core[K].mul(img_stack), dim=2, keepdim=False))
        pred_img_i = torch.mean(torch.stack(pred_img, dim=0), dim=0, keepdim=False)
        if self.core_bias:
            if bias is None:
                raise ValueError('The bias should not be None.')
            # Add before squeezing so bias (batch, N, color, H, W) lines up for color == 1.
            pred_img_i = pred_img_i + bias
        # Drop the color axis when it has size 1 (the 4-D input case).
        pred_img_i = pred_img_i.squeeze(2)
        pred_img_i = pred_img_i / white_level
        return pred_img_i

class LossFunc(nn.Module):
    """
    KPN training loss: a basic term on the final prediction plus an annealed term on
    the per-frame predictions, each multiplied by its own coefficient.
    """
    def __init__(self, coeff_basic=1.0, coeff_anneal=1.0, gradient_L1=True, alpha=0.9998, beta=100):
        super(LossFunc, self).__init__()
        self.coeff_basic = coeff_basic
        self.coeff_anneal = coeff_anneal
        self.loss_basic = LossBasic(gradient_L1)
        self.loss_anneal = LossAnneal(alpha, beta)

    def forward(self, pred_img_i, pred_img, ground_truth, global_step):
        """
        :param pred_img_i: per-frame predictions, indexed along dim 1 by LossAnneal.
        :param pred_img: final prediction, compared directly with ``ground_truth``.
        :param ground_truth: target image.
        :param global_step: int training step driving the annealing schedule.
        :return: tuple (coeff_basic * basic loss, coeff_anneal * annealed loss).
        """
        basic_term = self.coeff_basic * self.loss_basic(pred_img, ground_truth)
        anneal_term = self.coeff_anneal * self.loss_anneal(global_step, pred_img_i, ground_truth)
        return basic_term, anneal_term

class LossBasic(nn.Module):
    """
    Pixel MSE between prediction and target plus the L1 distance between their
    TensorGradient maps.
    """
    def __init__(self, gradient_L1=True):
        super(LossBasic, self).__init__()
        self.l1_loss = nn.L1Loss()
        self.l2_loss = nn.MSELoss()
        self.gradient = TensorGradient(gradient_L1)

    def forward(self, pred, ground_truth):
        pixel_term = self.l2_loss(pred, ground_truth)
        gradient_term = self.l1_loss(self.gradient(pred), self.gradient(ground_truth))
        return pixel_term + gradient_term

class LossAnneal(nn.Module):
    """
    Per-frame LossBasic averaged over dim 1 and scaled by ``beta * alpha ** global_step``,
    so the term decays as training progresses (for alpha < 1).
    """
    def __init__(self, alpha=0.9998, beta=100):
        super(LossAnneal, self).__init__()
        self.global_step = 0  # NOTE(review): never read; forward() receives the step as an argument
        self.loss_func = LossBasic(gradient_L1=True)
        self.alpha = alpha
        self.beta = beta

    def forward(self, global_step, pred_i, ground_truth):
        """
        :param global_step: int training step.
        :param pred_i: per-frame predictions; frame k is ``pred_i[:, k, ...]``.
        :param ground_truth: target compared with every frame.
        :return: beta * alpha ** global_step * mean per-frame loss.
        """
        num_frames = pred_i.size(1)
        total = sum(self.loss_func(pred_i[:, k, ...], ground_truth) for k in range(num_frames))
        mean_loss = total / num_frames
        return self.beta * self.alpha ** global_step * mean_loss

class TensorGradient(nn.Module):
    """
    Finite-difference gradient magnitude over the last two axes of a tensor.

    dx / dy are the differences between each pixel's left / upper neighbour and the
    pixel itself (zero padding at the border). Returns |dx| + |dy| when L1 is True,
    otherwise sqrt(dx**2 + dy**2).
    """
    def __init__(self, L1=True):
        super(TensorGradient, self).__init__()
        self.L1 = L1

    def forward(self, img):
        rows, cols = img.size(-2), img.size(-1)
        # Pad one side, subtract the other-side padding, crop back to the input size.
        dx = (F.pad(img, [1, 0, 0, 0]) - F.pad(img, [0, 1, 0, 0]))[..., 0:rows, 0:cols]
        dy = (F.pad(img, [0, 0, 1, 0]) - F.pad(img, [0, 0, 0, 1]))[..., 0:rows, 0:cols]
        if not self.L1:
            return torch.sqrt(torch.pow(dx, 2) + torch.pow(dy, 2))
        return torch.abs(dx) + torch.abs(dy)

if __name__ == '__main__':
    # Smoke test (runs in the notebook, where __name__ == '__main__'): a random batch
    # must come back with the same shape. Falls back to CPU when CUDA is unavailable,
    # and skips autograd bookkeeping since only the output shape is inspected.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    kpn = KPN().to(device)
    a = torch.randn(4, 3, 224, 224, device=device)
    with torch.no_grad():
        b = kpn(a, a)
    print(b.shape)


torch.Size([4, 3, 224, 224])


In [14]:
import os
import cv2
import skimage
import numpy as np
import torch
import torch.nn as nn
import torchvision as tv
import math

# ----------------------------------------
#                 Network
# ----------------------------------------
def create_generator():
    """Build the KPN generator from the notebook's configuration globals.

    Reads color, burst_length, blind_est, kernel_size, sep_conv, channel_att,
    spatial_att, upMode, core_bias, load_name, init_type and init_gain from the global
    namespace at call time. If ``load_name`` is '' the weights are initialised with
    ``weights_init``; otherwise the state dict stored at ``load_name`` is loaded via
    ``load_dict`` (keys the network does not have are ignored).

    NOTE: the augmentation cell defines a function named ``color``, which rebinds the
    boolean ``color`` config flag; a function object is truthy, so KPN then always
    builds a 3-channel network.

    Returns:
        the KPN module (on CPU; this function does not move it).
    """
    generator = KPN(color, burst_length, blind_est, kernel_size, sep_conv,
                    channel_att, spatial_att, upMode, core_bias)
    if load_name == '':
        # Init the network
        weights_init(generator, init_type = init_type, init_gain = init_gain)
        print('Generator is created!')
    else:
        # map_location='cpu' lets checkpoints saved from a GPU run load on CPU-only
        # machines; the freshly built generator lives on CPU anyway.
        # SECURITY: torch.load unpickles the file -- only load trusted checkpoints.
        pretrained_net = torch.load(load_name, map_location = 'cpu')
        load_dict(generator, pretrained_net)
        print('Generator is loaded!')
    return generator

def load_dict(process_net, pretrained_net):
    """Copy matching entries of a state dict into ``process_net`` and return it.

    Args:
        process_net: module to update in place.
        pretrained_net: mapping of parameter/buffer name -> tensor (a state dict).
    Keys that ``process_net`` does not have are ignored; entries missing from
    ``pretrained_net`` keep their current values.
    """
    merged = process_net.state_dict()
    for key, value in pretrained_net.items():
        if key in merged:
            merged[key] = value
    process_net.load_state_dict(merged)
    return process_net

# ----------------------------------------
#    Validation and Sample at training
# ----------------------------------------
def save_sample_png(sample_folder, sample_name, img_list, name_list, pixel_max_cnt = 255, height = -1, width = -1):
    """Write every tensor of ``img_list`` as ``<sample_folder>/<sample_name>_<name_list[i]>.png``.

    Each tensor is multiplied by 255, batch element 0 is taken in HWC order, clipped to
    [0, pixel_max_cnt], cast to uint8 and channel-swapped with COLOR_BGR2RGB before
    cv2.imwrite. When both ``height`` and ``width`` are given, it is resized first.
    """
    resize = (height != -1) and (width != -1)
    for idx, tensor in enumerate(img_list):
        # Multiplication creates a new tensor, so the caller's data is not modified.
        pixels = (tensor * 255.0).clone().data.permute(0, 2, 3, 1).cpu().numpy()
        pixels = np.clip(pixels, 0, pixel_max_cnt).astype(np.uint8)[0, :, :, :]
        pixels = cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB)
        if resize:
            pixels = cv2.resize(pixels, (width, height))
        out_path = os.path.join(sample_folder, sample_name + '_' + name_list[idx] + '.png')
        cv2.imwrite(out_path, pixels)

def save_sample_png_test(sample_folder, sample_name, img_list, name_list, pixel_max_cnt = 255):
    """Test-time variant of save_sample_png without resizing.

    Each tensor is multiplied by 255, batch element 0 is taken in HWC order, clipped to
    [0, pixel_max_cnt], truncated through uint8, cast to float32, channel-swapped with
    COLOR_BGR2RGB and written as ``<sample_folder>/<sample_name>_<name_list[i]>.png``.
    """
    for idx, tensor in enumerate(img_list):
        pixels = (tensor * 255.0).clone().data.permute(0, 2, 3, 1).cpu().numpy()
        pixels = np.clip(pixels, 0, pixel_max_cnt)
        # The uint8 round-trip truncates to integer levels before the float32 cast.
        pixels = pixels.astype(np.uint8)[0, :, :, :].astype(np.float32)
        pixels = cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB)
        out_path = os.path.join(sample_folder, sample_name + '_' + name_list[idx] + '.png')
        cv2.imwrite(out_path, pixels)

def recover_process(img, height = -1, width = -1):
    """Convert batch element 0 of an NCHW tensor into an HWC float32 array.

    Steps: multiply by 255, clip to [0, 255], truncate through uint8, cast to float32,
    swap channels with COLOR_BGR2RGB and, when both ``height`` and ``width`` are
    given, resize to (width, height) with cv2.resize.
    """
    arr = (img * 255.0).clone().data.permute(0, 2, 3, 1).cpu().numpy()
    arr = np.clip(arr, 0, 255).astype(np.uint8)[0, :, :, :].astype(np.float32)
    arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
    needs_resize = height != -1 and width != -1
    return cv2.resize(arr, (width, height)) if needs_resize else arr

def psnr(pred, target, pixel_max = 255.0):
    """Peak signal-to-noise ratio in dB between two arrays of the same shape.

    Inputs are converted to float64 before subtracting, so uint8 images do not wrap
    around. Returns 100 for identical inputs (zero MSE).

    Args:
        pred, target: array-likes of identical shape.
        pixel_max: peak signal value (255 for 8-bit images).
    """
    diff = np.asarray(pred, dtype=np.float64) - np.asarray(target, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return 100
    return 20 * math.log10(pixel_max / math.sqrt(mse))

def grey_psnr(pred, target, pixel_max_cnt = 255):
    """PSNR in dB between the dim-0 sums of two tensors.

    Both tensors are summed over dim 0 (presumably the channel axis of a CHW image --
    confirm with callers), and the peak value is taken as 3 * pixel_max_cnt, which
    assumes three summed channels. Returns 100 for identical inputs, matching psnr().
    """
    pred_sum = torch.sum(pred, dim = 0)
    target_sum = torch.sum(target, dim = 0)
    mse = torch.mul(target_sum - pred_sum, target_sum - pred_sum)
    rmse_avg = (torch.mean(mse).item()) ** 0.5
    if rmse_avg == 0:
        # A perfect match would otherwise raise ZeroDivisionError.
        return 100
    return 20 * np.log10(pixel_max_cnt * 3 / rmse_avg)

def ssim(pred, target, data_range = None):
    """Structural similarity between batch element 0 of two NCHW tensors.

    Uses ``skimage.metrics.structural_similarity`` (scikit-image >= 0.16) and falls back
    to the legacy ``skimage.measure.compare_ssim``, which was removed in 0.18.

    Args:
        pred, target: NCHW tensors; element 0 of each is compared in HWC layout.
        data_range: dynamic range of the images. None reproduces the legacy
            compare_ssim default (the dtype range: 2.0 for floating point, e.g. 255
            for uint8), so previously reported values are unchanged; pass 1.0 for
            float images in [0, 1].
    Returns:
        SSIM as a float.
    """
    pred_np = pred.detach().permute(0, 2, 3, 1).cpu().numpy()[0]
    target_np = target.detach().permute(0, 2, 3, 1).cpu().numpy()[0]
    try:
        from skimage.metrics import structural_similarity
    except ImportError:
        # scikit-image < 0.16 only has the legacy API.
        return skimage.measure.compare_ssim(target_np, pred_np, multichannel = True, data_range = data_range)
    if data_range is None:
        if np.issubdtype(target_np.dtype, np.floating):
            data_range = 2.0
        else:
            info = np.iinfo(target_np.dtype)
            data_range = float(info.max) - float(info.min)
    version = tuple(int(part) for part in skimage.__version__.split('.')[:2] if part.isdigit())
    if version >= (0, 19):
        # `multichannel` was deprecated in 0.19 in favour of `channel_axis`.
        return structural_similarity(target_np, pred_np, channel_axis = -1, data_range = data_range)
    return structural_similarity(target_np, pred_np, multichannel = True, data_range = data_range)

# ----------------------------------------
#             PATH processing
# ----------------------------------------
def check_path(path):
    """Create directory ``path`` (including missing parents) if nothing exists there yet.

    ``exist_ok=True`` closes the race where the directory is created by someone else
    (e.g. a parallel worker) between the existence check and ``makedirs``. If ``path``
    already exists (even as a file) nothing is done.
    """
    if not os.path.exists(path):
        os.makedirs(path, exist_ok = True)

def savetxt(name, loss_log):
    """Save ``loss_log`` (anything np.array accepts) to the text file ``name`` via np.savetxt."""
    np.savetxt(name, np.array(loss_log))


def get_files(path, num_rain_variants = 14):
    """Return [rainy_path, ground_truth_path] pairs for a deraining dataset folder.

    * If ``path`` contains 'rain1400': every ``<id>.jpg`` under ``path/ground_truth`` is
      paired with ``path/rainy_image/<id>_1.jpg`` ... ``<id>_<num_rain_variants>.jpg``.
    * Otherwise (layout of e.g. Rain100H/L and SPA): every ``*.png`` under ``path/rain``
      is paired with the file of the same name under ``path/norain``.

    File names are sorted per directory and extensions are matched case-sensitively;
    files without an extension are skipped. NOTE: sub-directories are walked too, but
    the returned paths always point directly under the top-level folders.
    """
    ret = []
    if 'rain1400' in path:
        path_rainy = path + "/rainy_image"
        path_gt = path + "/ground_truth"
        for root, dirs, files in os.walk(path_gt):
            for name in sorted(files):
                # splitext handles names with no dot or several dots.
                stem, ext = os.path.splitext(name)
                if ext != ".jpg":
                    continue
                file_gt = path_gt + "/" + stem + ".jpg"
                for i in range(1, num_rain_variants + 1):
                    file_rainy = path_rainy + "/" + stem + "_" + str(i) + ".jpg"
                    ret.append([file_rainy, file_gt])
    else:
        path_rainy = path + "/rain"
        path_gt = path + "/norain"
        for root, dirs, files in os.walk(path_rainy):
            for name in sorted(files):
                if os.path.splitext(name)[1] != '.png':
                    continue
                ret.append([path_rainy + "/" + name, path_gt + "/" + name])
    return ret

def get_jpgs(path):
    """Return the bare file names (no directory part) of every file under ``path``, recursively.

    Despite the name, files are not filtered by extension.
    """
    return [fname for _root, _dirs, fnames in os.walk(path) for fname in fnames]

def get_last_2paths(path):
    """Return '<parent dir>/<file name>' for every .png file under ``path`` (recursive).

    The parent directory name is taken by splitting the joined path on '/'.
    """
    result = []
    for root, _dirs, fnames in os.walk(path):
        for fname in fnames:
            if not fname.endswith('.png'):
                continue
            parts = os.path.join(root, fname).split('/')
            result.append(os.path.join(parts[-2], parts[-1]))
    return result

def text_readlines(filename):
    """Read a text file and return its lines without their trailing newline.

    Returns [] if the file cannot be opened. Only a trailing '\\n' is removed, so a
    final line that lacks a newline keeps all of its characters.
    """
    try:
        file = open(filename, 'r')
    except IOError:
        return []
    with file:
        content = file.readlines()
    # Strip exactly one trailing '\n' per line; the last line may not have one.
    return [line[:-1] if line.endswith('\n') else line for line in content]

def text_save(content, filename, mode = 'a'):
    """Write ``str(content[i])`` for every index of ``content`` to ``filename``.

    No separators or newlines are added. ``mode`` is passed to open(); the default
    'a' appends to an existing file.
    """
    with open(filename, mode) as handle:
        handle.writelines(str(content[i]) for i in range(len(content)))


In [7]:
"""Base augmentations operators."""

import numpy as np
from PIL import Image, ImageOps, ImageEnhance

# ImageNet code should change this value
#IMAGE_SIZE = 256


def int_parameter(level, maxval):
  """Scale ``maxval`` by ``level / 10`` and truncate toward zero.

  Args:
    level: operation strength, nominally in [0, 10].
    maxval: value returned when ``level`` is 10.

  Returns:
    int(level * maxval / 10).
  """
  scaled = level * maxval / 10
  return int(scaled)


def float_parameter(level, maxval):
  """Scale ``maxval`` by ``level / 10`` as a float.

  Args:
    level: operation strength, nominally in [0, 10].
    maxval: value returned when ``level`` is 10.

  Returns:
    float(level) * maxval / 10.
  """
  product = float(level) * maxval
  return product / 10.


def sample_level(n):
  """Draw an operation level uniformly from [0.1, n) with NumPy's global RNG."""
  low, high = 0.1, n
  return np.random.uniform(low=low, high=high)


def autocontrast(pil_img, _):
  """Stretch contrast with ImageOps.autocontrast; the level argument is ignored."""
  stretched = ImageOps.autocontrast(pil_img)
  return stretched


def equalize(pil_img, _):
  """Histogram-equalize with ImageOps.equalize; the level argument is ignored."""
  equalized = ImageOps.equalize(pil_img)
  return equalized


def posterize(pil_img, level):
  """Posterize, keeping 4 - int_parameter(sample_level(level), 4) bits per channel."""
  bits_dropped = int_parameter(sample_level(level), 4)
  bits_kept = 4 - bits_dropped
  return ImageOps.posterize(pil_img, bits_kept)


def rotate(pil_img, level):
  """Rotate by int_parameter(sample_level(level), 30) degrees, sign chosen at random."""
  angle = int_parameter(sample_level(level), 30)
  flip_sign = np.random.uniform() > 0.5
  return pil_img.rotate(-angle if flip_sign else angle, resample=Image.BILINEAR)


def solarize(pil_img, level):
  """Solarize with threshold 256 - int_parameter(sample_level(level), 256)."""
  threshold = 256 - int_parameter(sample_level(level), 256)
  return ImageOps.solarize(pil_img, threshold)


def shear_x(pil_img, level):
  """Horizontal affine shear by float_parameter(sample_level(level), 0.3), random sign."""
  shear = float_parameter(sample_level(level), 0.3)
  shear = -shear if np.random.uniform() > 0.5 else shear
  size = (pil_img.width, pil_img.height)
  matrix = (1, shear, 0, 0, 1, 0)
  return pil_img.transform(size, Image.AFFINE, matrix, resample=Image.BILINEAR)


def shear_y(pil_img, level):
  """Vertical affine shear by float_parameter(sample_level(level), 0.3), random sign."""
  shear = float_parameter(sample_level(level), 0.3)
  shear = -shear if np.random.uniform() > 0.5 else shear
  size = (pil_img.width, pil_img.height)
  matrix = (1, 0, 0, shear, 1, 0)
  return pil_img.transform(size, Image.AFFINE, matrix, resample=Image.BILINEAR)


def translate_x(pil_img, level):
  """Shift horizontally by up to a third of the width (int_parameter scaling), random sign."""
  offset = int_parameter(sample_level(level), pil_img.width / 3)
  offset = -offset if np.random.random() > 0.5 else offset
  size = (pil_img.width, pil_img.height)
  matrix = (1, 0, offset, 0, 1, 0)
  return pil_img.transform(size, Image.AFFINE, matrix, resample=Image.BILINEAR)


def translate_y(pil_img, level):
  """Shift vertically by up to a third of the height (int_parameter scaling), random sign."""
  offset = int_parameter(sample_level(level), pil_img.height / 3)
  offset = -offset if np.random.random() > 0.5 else offset
  size = (pil_img.width, pil_img.height)
  matrix = (1, 0, 0, 0, 1, offset)
  return pil_img.transform(size, Image.AFFINE, matrix, resample=Image.BILINEAR)


# operation that overlaps with ImageNet-C's test set
def color(pil_img, level):
    """Scale colour saturation by float_parameter(sample_level(level), 1.8) + 0.1.

    NOTE: this def rebinds the module-level name ``color``, which the config cell sets
    to a boolean flag (``color = True``) and ``create_generator`` reads when building KPN.
    """
    factor = float_parameter(sample_level(level), 1.8) + 0.1
    enhancer = ImageEnhance.Color(pil_img)
    return enhancer.enhance(factor)


# operation that overlaps with ImageNet-C's test set
def contrast(pil_img, level):
    """Scale contrast by float_parameter(sample_level(level), 1.8) + 0.1."""
    factor = float_parameter(sample_level(level), 1.8) + 0.1
    enhancer = ImageEnhance.Contrast(pil_img)
    return enhancer.enhance(factor)


# operation that overlaps with ImageNet-C's test set
def brightness(pil_img, level):
    """Scale brightness by float_parameter(sample_level(level), 1.8) + 0.1."""
    factor = float_parameter(sample_level(level), 1.8) + 0.1
    enhancer = ImageEnhance.Brightness(pil_img)
    return enhancer.enhance(factor)


# operation that overlaps with ImageNet-C's test set
def sharpness(pil_img, level):
    """Scale sharpness by float_parameter(sample_level(level), 1.8) + 0.1."""
    factor = float_parameter(sample_level(level), 1.8) + 0.1
    enhancer = ImageEnhance.Sharpness(pil_img)
    return enhancer.enhance(factor)

def zoom_x(pil_img, level):
    """Horizontal affine rescale with input-x = x / f + offset.

    f = float_parameter(sample_level(level), 6.0). With probability ~1/2 the offset
    width * (1 - 1/f) keeps the right edge fixed; otherwise the left edge stays fixed.
    """
    scale = 1.0 / float_parameter(sample_level(level), 6.0)
    offset = pil_img.width * (1 - scale) if np.random.random() > 0.5 else 0
    matrix = (scale, 0, offset, 0, 1, 0)
    return pil_img.transform((pil_img.width, pil_img.height), Image.AFFINE, matrix,
                             resample=Image.BILINEAR)


def zoom_y(pil_img, level):
    """Vertical affine rescale with input-y = y / f + offset.

    f = float_parameter(sample_level(level), 6.0). With probability ~1/2 the offset
    height * (1 - 1/f) keeps the bottom edge fixed; otherwise the top edge stays fixed.
    """
    scale = 1.0 / float_parameter(sample_level(level), 6.0)
    offset = pil_img.height * (1 - scale) if np.random.random() > 0.5 else 0
    matrix = (1, 0, 0, 0, scale, offset)
    return pil_img.transform((pil_img.width, pil_img.height), Image.AFFINE, matrix,
                             resample=Image.BILINEAR)


# Geometric ops sampled by augment_and_mix; each is called as op(pil_img, level).
augmentations = [
    rotate, shear_x, shear_y,
    translate_x, translate_y,
    zoom_x, zoom_y,
]

# Earlier alternative list, kept for reference:
# augmentations = [
#     autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y,
#     translate_x, translate_y
# ]

# Full op list including photometric ops (not referenced in the cells shown here).
augmentations_all = [
    autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y,
    translate_x, translate_y, color, contrast, brightness, sharpness,
]


In [8]:
"""Reference implementation of AugMix's data augmentation method in numpy."""
import numpy as np
import random
from PIL import Image

# CIFAR-10 per-channel constants. In the cells shown, only the disabled code inside
# normalize() refers to them; normalize() currently returns its input unchanged.
MEAN = [0.4914, 0.4822, 0.4465]
STD = [0.2023, 0.1994, 0.2010]


def normalize(image):
  """Return ``image`` unchanged.

  Channel-wise normalisation (subtract MEAN, divide by STD) is disabled, so the image
  passes through as-is. The former implementation is kept below as a comment.
  """
  # image = image.transpose(2, 0, 1)  # Switch to channel-first
  # mean, std = np.array(MEAN), np.array(STD)
  # image = (image - mean[:, None, None]) / std[:, None, None]
  # return image.transpose(1, 2, 0)
  return image

def apply_op(image, op, severity):
  """Run a PIL augmentation ``op`` on a float image and return a float image.

  ``image`` is multiplied by 255, clipped to [0, 255] and cast to uint8 for PIL; the
  result of ``op(pil_img, severity)`` is converted back to an array and divided by 255.
  """
  as_uint8 = np.clip(image * 255., 0, 255).astype(np.uint8)
  augmented = op(Image.fromarray(as_uint8), severity)
  return np.asarray(augmented) / 255.


def augment_and_mix(image, severity=3, width=3, depth=-1, alpha=1.):
  """Perform AugMix-style augmentation chains and blend them with the input.

  Args:
    image: input image as a float np.ndarray of shape (h, w, c); apply_op scales it
      by 255 before handing it to PIL.
    severity: level passed to every augmentation op.
    width: number of augmentation chains mixed together.
    depth: ops per chain. Any value <= 0 draws a fresh depth uniformly from {2, 3}
      for every chain.
    alpha: concentration parameter of the Beta and Dirichlet draws.

  Returns:
    max(1 - m, 0.7) * image + max(m, 0.5 / max(ws)) * mix, where ws ~ Dirichlet(alpha),
    m ~ Beta(alpha, alpha) and mix is the ws-weighted sum of the chains.
  """
  ws = np.float32(np.random.dirichlet([alpha] * width))
  m = np.float32(np.random.beta(alpha, alpha))

  mix = np.zeros_like(image)
  for i in range(width):
    image_aug = image.copy()
    # Drawn per chain and not written back to `depth`, so every chain gets its own
    # random depth (writing back made every chain reuse the first draw).
    chain_depth = depth if depth > 0 else np.random.randint(2, 4)
    for _ in range(chain_depth):
      op = np.random.choice(augmentations)
      image_aug = apply_op(image_aug, op, severity)
    # Preprocessing commutes since all coefficients are convex
    mix += ws[i] * normalize(image_aug)

  # Lower-bound the mix weight so the strongest chain contributes at least 0.5.
  rate = 1.0 / max(ws)
  mixed = max((1 - m), 0.7) * normalize(image) + max(m, rate * 0.5) * mix
  return mixed



In [9]:
import os
import random
import numpy as np
import cv2
import math
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class RandomCrop(object):
    """Crop one random window, fixed at construction, from HWC or HW arrays.

    Args:
        image_size: (height, width) of the arrays to be cropped.
        crop_size: (crop_height, crop_width); random.randint raises ValueError if it
            exceeds image_size.
    Every call uses the same window, so paired images are cropped consistently.
    """
    def __init__(self, image_size, crop_size):
        self.ch, self.cw = crop_size
        img_h, img_w = image_size
        # Top-left corner: row offset drawn first, then column offset.
        self.h1 = random.randint(0, img_h - self.ch)
        self.w1 = random.randint(0, img_w - self.cw)
        self.h2, self.w2 = self.h1 + self.ch, self.w1 + self.cw

    def __call__(self, img):
        rows = slice(self.h1, self.h2)
        cols = slice(self.w1, self.w2)
        return img[rows, cols, :] if len(img.shape) == 3 else img[rows, cols]

class DenoisingDataset(Dataset):
    """Training dataset of (rainy, clean) image pairs.

    Reads module-level settings: ``baseroot`` (passed to ``get_files``),
    ``rainaug`` (whether to composite a synthetic rain-streak layer onto each
    pair), ``crop_size`` and ``angle_aug``. Items are ``(rainy, gt)`` float32
    CHW tensors obtained by dividing [0, 255] images by 255.
    """
    def __init__(self):
        # Entries are indexed as [rainy_path, gt_path] in __getitem__.
        self.imglist = get_files(baseroot)
        self.rainaug = rainaug

    @staticmethod
    def _read_rgb(path):
        """Read ``path`` with OpenCV and return it in RGB channel order.

        ``cv2.imread`` returns ``None`` rather than raising when a file is
        missing or unreadable. Without this check, the failure would only
        surface later as an opaque error from ``cvtColor``.

        Raises:
            FileNotFoundError: if the image cannot be read.
        """
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError('Could not read image: %s' % path)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _to_tensor(img):
        """Divide an HWC image by 255 and return it as a contiguous float32 CHW tensor."""
        img = img.astype(np.float32) / 255.0
        return torch.from_numpy(img.transpose(2, 0, 1)).contiguous()

    def getRainLayer2(self, rand_id1, rand_id2):
        """Load streak layer ``<rand_id1>-<rand_id2>.png`` as float32 RGB divided by 255."""
        path_img_rainlayer_src = "./rainmix/Streaks_Garg06/" + str(rand_id1) + "-" + str(rand_id2) + ".png"
        return self._read_rgb(path_img_rainlayer_src).astype(np.float32) / 255.0

    def getRandRainLayer2(self):
        """Load a randomly chosen streak layer (id1 in 1..165, id2 in 4..8, inclusive)."""
        rand_id1 = random.randint(1, 165)
        rand_id2 = random.randint(4, 8)
        return self.getRainLayer2(rand_id1, rand_id2)

    def rain_aug(self, img_rainy, img_gt):
        """Screen-blend an AugMix-augmented rain-streak layer onto a base image.

        The base is the real rainy image for 7 of the 11 equally likely
        ``randint(0, 10)`` outcomes, and the clean image otherwise. The rain
        layer and the image pair are randomly cropped to a common size.

        Returns:
            ``(rainy, gt)`` float arrays in [0, 255]. The rainy image is clipped
            to that range.
        """
        img_rainy = img_rainy.astype(np.float32) / 255.0
        img_gt = img_gt.astype(np.float32) / 255.0
        if random.randint(0, 10) > 3:
            img_rainy_ret = img_rainy
        else:
            img_rainy_ret = img_gt
        img_gt_ret = img_gt

        rainlayer_rand2 = self.getRandRainLayer2()
        rainlayer_aug2 = augment_and_mix(rainlayer_rand2, severity = 3, width = 3, depth = -1) * 1

        # Common crop size shared by the rain layer and the image pair.
        height = min(img_gt.shape[0], rainlayer_aug2.shape[0])
        width = min(img_gt.shape[1], rainlayer_aug2.shape[1])

        cropper = RandomCrop(rainlayer_aug2.shape[:2], (height, width))
        rainlayer_aug2_crop = cropper(rainlayer_aug2)
        # One cropper for the pair so input and target stay aligned.
        cropper = RandomCrop(img_gt_ret.shape[:2], (height, width))
        img_rainy_ret = cropper(img_rainy_ret)
        img_gt_ret = cropper(img_gt_ret)

        # Screen blend: a + b - a * b.
        img_rainy_ret = img_rainy_ret + rainlayer_aug2_crop - img_rainy_ret * rainlayer_aug2_crop
        # np.clip returns a new array. Its result must be assigned, otherwise
        # blended values outside [0, 1] (e.g. from an augmented rain layer
        # above 1) leak into the training input.
        img_rainy_ret = np.clip(img_rainy_ret, 0.0, 1.0)

        return img_rainy_ret * 255, img_gt_ret * 255

    def __getitem__(self, index):
        img_rainy = self._read_rgb(self.imglist[index][0])
        img_gt = self._read_rgb(self.imglist[index][1])

        if self.rainaug:
            img_rainy, img_gt = self.rain_aug(img_rainy, img_gt)

        # Random crop: a single RandomCrop instance gives input and target the same window.
        cropper = RandomCrop(img_gt.shape[:2], (crop_size, crop_size))
        img_rainy = cropper(img_rainy)
        img_gt = cropper(img_gt)

        # Optional geometric augmentation: rotation by a random multiple of 90
        # degrees, plus a random vertical flip (flipCode = 0 flips around the x-axis).
        if angle_aug:
            rotate = random.randint(0, 3)
            if rotate != 0:
                img_rainy = np.rot90(img_rainy, rotate)
                img_gt = np.rot90(img_gt, rotate)
            if np.random.random() >= 0.5:
                img_rainy = cv2.flip(img_rainy, flipCode = 0)
                img_gt = cv2.flip(img_gt, flipCode = 0)

        return self._to_tensor(img_rainy), self._to_tensor(img_gt)

    def __len__(self):
        return len(self.imglist)

class DenoisingValDataset(Dataset):
    """Evaluation dataset of (rainy, clean) pairs, resized up to multiples of 16.

    Reads module-level settings ``baseroot`` (passed to ``get_files``),
    ``crop``, ``crop_size`` and ``angle_aug``. Each item is
    ``(rainy, gt, height_origin, width_origin)``. The first two are float32
    CHW tensors divided by 255. The last two are the rainy image's height
    and width before resizing.
    """
    def __init__(self):
        # Entries are indexed as [rainy_path, gt_path] in __getitem__.
        self.imglist = get_files(baseroot)

    @staticmethod
    def _imread_checked(path):
        """``cv2.imread`` that raises instead of returning ``None`` on failure.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image.
        """
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError('Could not read image: %s' % path)
        return img

    @staticmethod
    def _to_tensor(img):
        """Divide an HWC image by 255 and return it as a contiguous float32 CHW tensor."""
        img = img.astype(np.float32) / 255.0
        return torch.from_numpy(img.transpose(2, 0, 1)).contiguous()

    def __getitem__(self, index):
        img_rainy = self._imread_checked(self.imglist[index][0])
        img_gt = self._imread_checked(self.imglist[index][1])

        height_origin, width_origin = img_rainy.shape[:2]
        # Round each side up to the next multiple of 16. Presumably the
        # generator's down/up-sampling requires this (confirm against the model).
        height = ((height_origin + 15) // 16) * 16
        width = ((width_origin + 15) // 16) * 16
        # Both images are resized to the rainy image's rounded-up size.
        img_rainy = cv2.resize(img_rainy, (width, height))
        img_gt = cv2.resize(img_gt, (width, height))

        img_rainy = cv2.cvtColor(img_rainy, cv2.COLOR_BGR2RGB)
        img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2RGB)

        # Optional random crop; one RandomCrop instance keeps input and target aligned.
        if crop:
            cropper = RandomCrop(img_rainy.shape[:2], (crop_size, crop_size))
            img_rainy = cropper(img_rainy)
            img_gt = cropper(img_gt)

        # Optional geometric augmentation: rotation by a random multiple of 90
        # degrees, plus a random vertical flip (flipCode = 0 flips around the x-axis).
        if angle_aug:
            rotate = random.randint(0, 3)
            if rotate != 0:
                img_rainy = np.rot90(img_rainy, rotate)
                img_gt = np.rot90(img_gt, rotate)
            if np.random.random() >= 0.5:
                img_rainy = cv2.flip(img_rainy, flipCode = 0)
                img_gt = cv2.flip(img_gt, flipCode = 0)

        return self._to_tensor(img_rainy), self._to_tensor(img_gt), height_origin, width_origin

    def __len__(self):
        return len(self.imglist)

In [10]:
!pip install gdown



In [16]:
# Download the synthetic deraining datasets archive (saved as datasets_Synthetic_EfficientDerain.zip) from Google Drive.
!gdown 1IBENJqN6XU5HMlSQ2W7jkdW22Z1x4W9K

Downloading...
From: https://drive.google.com/uc?id=1IBENJqN6XU5HMlSQ2W7jkdW22Z1x4W9K
To: /content/datasets_Synthetic_EfficientDerain.zip
100% 1.82G/1.82G [00:23<00:00, 76.4MB/s]


In [14]:
# Download the pretrained model checkpoints archive (saved as Models.zip) from Google Drive.
!gdown 1gXckj9HTsJpzt9VZGlr8Akbp3XILAqZ4

Downloading...
From: https://drive.google.com/uc?id=1gXckj9HTsJpzt9VZGlr8Akbp3XILAqZ4
To: /content/Models.zip
100% 4.41G/4.41G [01:00<00:00, 73.3MB/s]


In [15]:
!unzip '/content/datasets_Synthetic_EfficientDerain.zip' -d '/content/derain/'
!unzip '/content/Models.zip' -d '/content/derain/'

Archive:  /content/TestData.zip
   creating: /content/derain/TestData/
   creating: /content/derain/TestData/input_100L/
  inflating: /content/derain/TestData/input_100L/rain-001.png  
  inflating: /content/derain/TestData/input_100L/rain-002.png  
  inflating: /content/derain/TestData/input_100L/rain-003.png  
  inflating: /content/derain/TestData/input_100L/rain-004.png  
  inflating: /content/derain/TestData/input_100L/rain-005.png  
  inflating: /content/derain/TestData/input_100L/rain-006.png  
  inflating: /content/derain/TestData/input_100L/rain-007.png  
  inflating: /content/derain/TestData/input_100L/rain-008.png  
  inflating: /content/derain/TestData/input_100L/rain-009.png  
  inflating: /content/derain/TestData/input_100L/rain-010.png  
  inflating: /content/derain/TestData/input_100L/rain-011.png  
  inflating: /content/derain/TestData/input_100L/rain-012.png  
  inflating: /content/derain/TestData/input_100L/rain-013.png  
  inflating: /content/derain/TestData/input_100

In [18]:
!unzip '/content/derain/datasets/rain100H.zip' -d '/content/derain/datasets/'
!unzip '/content/derain/datasets/rain100L.zip' -d '/content/derain/datasets/'
!unzip '/content/derain/datasets/rain1400.zip' -d '/content/derain/datasets/'

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: /content/derain/datasets/rain1400/train/rainy_image/18_5.jpg  
  inflating: /content/derain/datasets/rain1400/train/rainy_image/190_13.jpg  
  inflating: /content/derain/datasets/rain1400/train/rainy_image/190_7.jpg  
  inflating: /content/derain/datasets/rain1400/train/rainy_image/191_13.jpg  
  inflating: /content/derain/datasets/rain1400/train/rainy_image/191_6.jpg  
  inflating: /content/derain/datasets/rain1400/train/rainy_image/192_8.jpg  
  inflating: /content/derain/datasets/rain1400/train/rainy_image/194_10.jpg  
  inflating: /content/derain/datasets/rain1400/train/rainy_image/194_9.jpg  
  inflating: /content/derain/datasets/rain1400/train/rainy_image/195_9.jpg  
  inflating: /content/derain/datasets/rain1400/train/rainy_image/196_13.jpg  
  inflating: /content/derain/datasets/rain1400/train/rainy_image/197_11.jpg  
  inflating: /content/derain/datasets/rain1400/train/rainy_image/19_14.jpg  
  infla

Training

In [29]:
# ----------------------------------------
# Training configuration (module-level globals)
# ----------------------------------------
# NOTE: Pre_train and DenoisingDataset read these names as globals. Running
# this cell overwrites same-named values from the earlier testing-config cell
# (e.g. load_name, num_workers, baseroot, crop_size). Re-run that cell before
# testing if those values are needed again.

# Checkpointing / sampling
save_path, sample_path = './models', './samples'
save_mode = 'epoch'                  # 'epoch' or 'iter' (see save_model in Pre_train)
save_by_epoch, save_by_iter = 10, 100000
load_name = ''

# GPU
no_gpu, multi_gpu = False, False
gpu_ids = '0, 1, 2, 3'
cudnn_benchmark = True

# Optimisation
epochs, train_batch_size = 100, 16
lr_g = 0.0002
b1, b2 = 0.5, 0.999                  # Adam betas
weight_decay = 0
lr_decrease_epoch = 20               # linear LR decay begins at this (1-based) epoch
num_workers = 8

# Network initialisation
color, burst_length, blind_est = True, 1, True
kernel_size = [3]
sep_conv, channel_att, spatial_att = False, False, False
upMode = 'bilinear'
core_bias = False
init_type, init_gain = 'xavier', 0.02

# Dataset / augmentation
baseroot = '/content/derain/datasets/rain100H/train/'
rainaug = False
crop_size = 256
geometry_aug, angle_aug = False, False
scale_min, scale_max = 1, 1
mu, sigma = 0, 30

In [21]:
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp

def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length ``window_size`` that sums to 1.

    The kernel is centred at index ``window_size // 2`` and has standard
    deviation ``sigma``.
    """
    centre = window_size // 2
    denom = float(2 * sigma ** 2)
    weights = [exp(-(idx - centre) ** 2 / denom) for idx in range(window_size)]
    kernel = torch.Tensor(weights)
    return kernel / kernel.sum()

def create_window(window_size, channel, sigma = 1.5):
    """Build a per-channel 2-D Gaussian window for SSIM.

    Args:
        window_size: side length of the square window.
        channel: number of input channels. The window is replicated once per
            channel so it can be used with ``F.conv2d(..., groups = channel)``.
        sigma: standard deviation of the Gaussian. The default of 1.5 matches
            the usual SSIM formulation and the previously hard-coded value.

    Returns:
        A contiguous float tensor of shape (channel, 1, window_size, window_size).
    """
    _1D_window = gaussian(window_size, sigma).unsqueeze(1)
    # Outer product of the 1-D kernel with itself gives the separable 2-D kernel.
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # `Variable` has been a no-op wrapper since PyTorch 0.4. The plain tensor
    # (requires_grad=False) is returned directly.
    return _2D_window.expand(channel, 1, window_size, window_size).contiguous()

def _ssim(img1, img2, window, window_size, channel, size_average = True):
    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

class SSIM(torch.nn.Module):
    """SSIM as an ``nn.Module`` that caches its Gaussian window between calls.

    The window is rebuilt, and moved to ``img1``'s device and tensor type,
    when the channel count or the tensor type of ``img1`` differs from the
    cached window.
    """
    def __init__(self, window_size = 11, size_average = True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        _, channel, _, _ = img1.size()

        reusable = channel == self.channel and self.window.data.type() == img1.data.type()
        if not reusable:
            fresh = create_window(self.window_size, channel)
            if img1.is_cuda:
                fresh = fresh.cuda(img1.get_device())
            self.window = fresh.type_as(img1)
            self.channel = channel

        return _ssim(img1, img2, self.window, self.window_size, channel, self.size_average)

def ssim(img1, img2, window_size = 11, size_average = True):
    """Functional SSIM that builds a fresh window matching ``img1`` on every call.

    The window is created for ``img1``'s channel count and moved to its device
    and tensor type before delegating to ``_ssim``.
    """
    _, channel, _, _ = img1.size()
    window = create_window(window_size, channel)
    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    return _ssim(img1, img2, window.type_as(img1), window_size, channel, size_average)


In [31]:
import time
import datetime
import os
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.autograd as autograd
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
from torchvision import transforms

def Pre_train():
    """Train the deraining generator on ``DenoisingDataset``.

    Configuration comes from module-level globals (save_path, sample_path,
    save_mode, save_by_epoch, save_by_iter, no_gpu, multi_gpu,
    cudnn_benchmark, epochs, train_batch_size, lr_g, b1, b2, weight_decay,
    lr_decrease_epoch, num_workers). The loss minimised with Adam is
    ``L1 + 0.2 * (-SSIM)`` between prediction and target. Checkpoints are
    written by the nested ``save_model``. After every epoch, one
    input/prediction/target sample from the last batch is saved.
    """
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    #torch.cuda.set_device(1)

    # cudnn benchmark
    cudnn.benchmark = cudnn_benchmark

    # configurations
    save_folder = save_path
    sample_folder = sample_path
    check_path(save_folder)
    check_path(sample_folder)

    # Loss functions
    # NOTE: criterion_L2 is constructed but never used in the loss below.
    if no_gpu == False:
        criterion_L1 = torch.nn.L1Loss().cuda()
        criterion_L2 = torch.nn.MSELoss().cuda()
        #criterion_rainypred = torch.nn.L1Loss().cuda()
        criterion_ssim = SSIM().cuda()
    else:
        criterion_L1 = torch.nn.L1Loss()
        criterion_L2 = torch.nn.MSELoss()
        #criterion_rainypred = torch.nn.L1Loss().cuda()
        criterion_ssim = SSIM()

    # Initialize Generator
    generator = create_generator()

    # To device
    if no_gpu == False:
        if multi_gpu:
            generator = nn.DataParallel(generator)
            generator = generator.cuda()
        else:
            generator = generator.cuda()

    # Optimizers
    # Only parameters with requires_grad are handed to Adam.
    optimizer_G = torch.optim.Adam(filter(lambda p: p.requires_grad, generator.parameters()), lr = lr_g, betas = (b1, b2), weight_decay = weight_decay)
    # NOTE(review): nothing is loaded in this function; this message is printed
    # unconditionally. Any checkpoint loading would have to happen inside
    # create_generator (not visible here). Confirm before trusting this log line.
    print("pretrained models loaded")

    # Learning rate decrease
    def adjust_learning_rate(epoch, optimizer):
        # Linear decay for 1-based epoch >= lr_decrease_epoch:
        #   lr = lr_g * (epochs - epoch) / (epochs - lr_decrease_epoch)
        # This equals lr_g at epoch == lr_decrease_epoch and 0 at epoch == epochs.
        # The function is called after every optimizer step, so every iteration
        # of the final epoch after its first one runs with lr 0.
        target_epoch = epochs - lr_decrease_epoch
        remain_epoch = epochs - epoch
        if epoch >= lr_decrease_epoch:
            lr = lr_g * remain_epoch / target_epoch
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Checkpoint helper, called after every iteration
    def save_model(epoch, iteration, len_dataset, generator):
        """Save ``generator``'s state_dict under save_path when a save point is reached.

        ``epoch`` is 1-based and ``iteration`` is the 1-based global iteration.
        In 'epoch' mode a checkpoint is written on the last iteration of every
        ``save_by_epoch``-th epoch. In 'iter' mode one is written every
        ``save_by_iter`` iterations. With multi_gpu the wrapped ``.module`` is
        saved, so the state_dict keys carry no DataParallel prefix.
        ``model_name`` is only bound for save_mode 'epoch' or 'iter', so any
        other value raises UnboundLocalError.
        """
        # Define the name of trained model
        if save_mode == 'epoch':
            model_name = 'KPN_rainy_image_epoch%d_bs%d.pth' % (epoch, train_batch_size)
        if save_mode == 'iter':
            model_name = 'KPN_rainy_image_iter%d_bs%d.pth' % (iteration, train_batch_size)
        save_model_path = os.path.join(save_path, model_name)
        if multi_gpu == True:
            if save_mode == 'epoch':
                if (epoch % save_by_epoch == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.module.state_dict(), save_model_path)
                    print('The trained model is successfully saved at epoch %d' % (epoch))
            if save_mode == 'iter':
                if iteration % save_by_iter == 0:
                    torch.save(generator.module.state_dict(), save_model_path)
                    print('The trained model is successfully saved at iteration %d' % (iteration))
        else:
            if save_mode == 'epoch':
                if (epoch % save_by_epoch == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.state_dict(), save_model_path)
                    print('The trained model is successfully saved at epoch %d' % (epoch))
            if save_mode == 'iter':
                if iteration % save_by_iter == 0:
                    torch.save(generator.state_dict(), save_model_path)
                    print('The trained model is successfully saved at iteration %d' % (iteration))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Handle multiple GPUs
    #os.environ["CUDA_VISIBLE_DEVICES"] = ""
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)

    #print(opt.multi_gpu)
    '''
    print(opt.no_gpu == False)
    print(opt.no_gpu)
    print(gpu_num)
    print(opt.train_batch_size)
    '''

    # Define the dataset
    trainset = DenoisingDataset()
    print('The overall number of training images:', len(trainset))

    # Define the dataloader
    train_loader = DataLoader(trainset, batch_size = train_batch_size, shuffle = True, num_workers = num_workers, pin_memory = True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # For loop training
    for epoch in range(epochs):
        for i, (true_input, true_target) in enumerate(train_loader):

            #print("in epoch %d" % i)

            if no_gpu == False:
                # To device
                true_input = true_input.cuda()
                true_target = true_target.cuda()

            # Train Generator
            optimizer_G.zero_grad()
            # The rainy batch is passed as both arguments. The role of the
            # second argument is defined by the model from create_generator,
            # which is not visible here.
            fake_target = generator(true_input, true_input)


            # Negated so that minimising the loss maximises SSIM.
            ssim_loss = -criterion_ssim(true_target, fake_target)

            # L1 Loss
            Pixellevel_L1_Loss = criterion_L1(fake_target, true_target)

            # Overall Loss and optimize
            loss = Pixellevel_L1_Loss + 0.2*ssim_loss
            #loss = Pixellevel_L1_Loss
            #loss = Pixellevel_L1_Loss + Pixellevel_L2_Loss + Loss_rainypred
            loss.backward()
            optimizer_G.step()

            # Determine approximate time left
            # (extrapolated from the wall time of the most recent iteration only)
            iters_done = epoch * len(train_loader) + i
            iters_left = epochs * len(train_loader) - iters_done
            time_left = datetime.timedelta(seconds = iters_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            # print() appends a newline, so the leading "\r" does not overwrite
            # the previous line; one log line is emitted per iteration.
            print("\r[Iter %d] [Epoch %d/%d] [Batch %d/%d] [Loss: %.4f %.4f] Time_left: %s" %
                ((iters_done + 1), (epoch + 1), epochs, i, len(train_loader), Pixellevel_L1_Loss.item(), ssim_loss.item(), time_left))


            # Save model at certain epochs or iterations
            save_model((epoch + 1), (iters_done + 1), len(train_loader), generator)

            # Learning rate decrease at certain epochs
            adjust_learning_rate((epoch + 1), optimizer_G)

        ### Sample data every epoch
        # (the condition is always true). This uses the tensors left over from
        # the last batch, so it raises NameError if train_loader yields no batches.
        if (epoch + 1) % 1 == 0:
            img_list = [true_input, fake_target, true_target]
            name_list = ['in', 'pred', 'gt']
            save_sample_png(sample_folder = sample_folder, sample_name = 'train_epoch%d' % (epoch + 1), img_list = img_list, name_list = name_list, pixel_max_cnt = 255)

In [32]:
#os.environ["CUDA_VISIBLE_DEVICES"] = "0"

def str2bool(v):
    """Parse a command-line style boolean.

    Accepts (case-insensitively) yes/true/t/y/1 as True and no/false/f/n/0 as
    False. Actual ``bool`` values are returned unchanged.

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    # Imported locally so the function does not depend on cell execution order.
    # argparse is not imported in this cell, and without this import the error
    # path would raise NameError instead of ArgumentTypeError.
    import argparse
    if isinstance(v, bool):
        return v
    value = v.lower()
    if value in ('yes', 'true', 't', 'y', '1'):
        return True
    if value in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Unsupported value encountered.')

if __name__ == "__main__":
    Pre_train()

initialize network with xavier type
Generator is created!
pretrained models loaded
There are 1 GPUs used
The overall number of training images: 1254
[Iter 1] [Epoch 1/100] [Batch 0/79] [Loss: 0.3705 -0.0193] Time_left: 5:40:27.410746
[Iter 2] [Epoch 1/100] [Batch 1/79] [Loss: 0.4111 -0.0308] Time_left: 0:54:58.629273
[Iter 3] [Epoch 1/100] [Batch 2/79] [Loss: 0.4061 -0.0443] Time_left: 1:22:43.397643
[Iter 4] [Epoch 1/100] [Batch 3/79] [Loss: 0.4340 -0.0497] Time_left: 1:27:54.540653
[Iter 5] [Epoch 1/100] [Batch 4/79] [Loss: 0.3461 -0.0873] Time_left: 1:28:18.601954
[Iter 6] [Epoch 1/100] [Batch 5/79] [Loss: 0.3750 -0.0736] Time_left: 1:28:18.486187
[Iter 7] [Epoch 1/100] [Batch 6/79] [Loss: 0.3893 -0.0976] Time_left: 1:30:48.569377
[Iter 8] [Epoch 1/100] [Batch 7/79] [Loss: 0.4184 -0.1132] Time_left: 1:29:32.383590
[Iter 9] [Epoch 1/100] [Batch 8/79] [Loss: 0.3570 -0.1160] Time_left: 1:28:44.933903
[Iter 10] [Epoch 1/100] [Batch 9/79] [Loss: 0.3234 -0.1486] Time_left: 1:27:28.207040


KeyboardInterrupt: ignored

Testing

In [18]:
import argparse
import os
import torch
import numpy as np
import cv2
from skimage.metrics import structural_similarity
import time


def str2bool(v):
    """Parse a command-line style boolean.

    Accepts (case-insensitively) yes/true/t/y/1 as True and no/false/f/n/0 as
    False. Actual ``bool`` values are returned unchanged.

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    # Imported locally so the function does not depend on cell execution order.
    import argparse
    if isinstance(v, bool):
        return v
    value = v.lower()
    if value in ('yes', 'true', 't', 'y', '1'):
        return True
    if value in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Unsupported value encountered.')

if __name__ == "__main__":

    # Build the generator. create_generator is defined elsewhere in the
    # notebook; whether and how it loads `load_name` is not visible here.
    if no_gpu:
        generator = create_generator()
    else:
        generator = create_generator().cuda()

    test_dataset = DenoisingValDataset()
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size = test_batch_size, shuffle = False, num_workers = num_workers, pin_memory = True)
    sample_folder = save_name
    check_path(sample_folder)

    # psnr_ave / ssim_ave are filled in by the follow-up metrics cell.
    psnr_sum, psnr_ave, ssim_sum, ssim_ave, eval_cnt = 0, 0, 0, 0, 0
    time_test = 0
    count = 0
    # forward
    for i, (true_input, true_target, height_origin, width_origin) in enumerate(test_loader):

        # To device
        if not no_gpu:
            true_input = true_input.cuda()
            true_target = true_target.cuda()

        # Timed forward pass. CUDA kernels run asynchronously, so synchronize
        # before reading the clock; otherwise only the kernel launch is timed.
        if not no_gpu:
            torch.cuda.synchronize()
        start_time = time.time()
        with torch.no_grad():
            fake_target = generator(true_input, true_input)
        if not no_gpu:
            torch.cuda.synchronize()
        time_test += time.time() - start_time
        count += 1

        # Save
        print('The %d-th iteration' % (i))
        img_list = [true_input, fake_target, true_target]
        name_list = ['in', 'pred', 'gt']
        sample_name = '%d' % (i + 1)
        # NOTE(review): image_height / image_width are not defined in this cell.
        # The loader also yields height_origin / width_origin (the pre-resize
        # size). Confirm which output size is intended.
        save_sample_png(sample_folder = sample_folder, sample_name = sample_name, img_list = img_list, name_list = name_list, pixel_max_cnt = 255, height = image_height, width = image_width)

        # Evaluation
        img_pred_recover = recover_process(fake_target, height = image_height, width = image_width)
        img_gt_recover = recover_process(true_target, height = image_height, width = image_width)
        psnr_sum = psnr_sum + psnr(img_pred_recover, img_gt_recover)
        # `multichannel = True` is deprecated in scikit-image (it emitted a
        # warning here). `channel_axis = -1` is its documented replacement.
        ssim_sum = ssim_sum + structural_similarity(img_gt_recover, img_pred_recover, channel_axis = -1, data_range = 255)
        eval_cnt = eval_cnt + 1

    # Guard against an empty loader instead of raising ZeroDivisionError.
    if count:
        print('Avg. time:', time_test / count)
    else:
        print('No test images were evaluated.')





Generator is loaded!
The 0-th iteration


  ssim_sum = ssim_sum + structural_similarity(img_gt_recover, img_pred_recover, multichannel = True, data_range = 255)


The 1-th iteration
The 2-th iteration
The 3-th iteration
The 4-th iteration
The 5-th iteration
The 6-th iteration
The 7-th iteration
The 8-th iteration
The 9-th iteration
The 10-th iteration
The 11-th iteration
The 12-th iteration
The 13-th iteration
The 14-th iteration
The 15-th iteration
The 16-th iteration
The 17-th iteration
The 18-th iteration
The 19-th iteration
The 20-th iteration
The 21-th iteration
The 22-th iteration
The 23-th iteration
The 24-th iteration
The 25-th iteration
The 26-th iteration
The 27-th iteration
The 28-th iteration
The 29-th iteration
The 30-th iteration
The 31-th iteration
The 32-th iteration
The 33-th iteration
The 34-th iteration
The 35-th iteration
The 36-th iteration
The 37-th iteration
The 38-th iteration
The 39-th iteration
The 40-th iteration
The 41-th iteration
The 42-th iteration
The 43-th iteration
The 44-th iteration
The 45-th iteration
The 46-th iteration
The 47-th iteration
The 48-th iteration
The 49-th iteration
The 50-th iteration
The 51-th

FileNotFoundError: ignored

In [19]:
# Average the per-image scores accumulated by the evaluation loop and append a
# "<load_name>: <score>" line to each metric log via text_save.
psnr_ave, ssim_ave = (metric_sum / eval_cnt for metric_sum in (psnr_sum, ssim_sum))
psnr_file, ssim_file = "./derain/psnr_data.txt", "./derain/ssim_data.txt"
psnr_content, ssim_content = (load_name + ": " + str(score) + "\n" for score in (psnr_ave, ssim_ave))
text_save(content = psnr_content, filename = psnr_file)
text_save(content = ssim_content, filename = ssim_file)