In [None]:
#!git clone "https://github.com/liuh127/NTIRE-2021-Dehazing-DWGAN.git"

# Mount Google Drive

In [None]:
# Mount Google Drive at /content/drive (Google Colab only).
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Dependencies

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision.models import vgg16, VGG16_Weights
from torchvision.utils import save_image as imwrite
import cv2
import numpy as np
from torchvision import transforms
from skimage.metrics import structural_similarity as ssim
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF
import random
from math import exp, log10
import numbers
import os

try:
  import einops
  from einops import rearrange
except:
  !pip install einops
  import einops
  from einops import rearrange

import gc

# Data Loader

In [None]:
# Data augmentation: one random rotation or flip, applied identically to three aligned images.
def custom_augment(hazy, clean, clean_gray):
    """Apply the same randomly chosen geometric augmentation to three aligned images.

    ``random.choice`` picks one of six methods:
        0 -> rotate by 90, 180 or 270 degrees,
        1 -> vertical flip,
        2 -> horizontal flip,
        3, 4, 5 -> leave unchanged (so half of the samples are not augmented).
    The rotation angle is drawn on every call, even when it is not used.

    Returns:
        tuple ``(hazy, clean, clean_gray)`` with the chosen transform applied to each.
    """
    method = random.choice([0, 1, 2, 3, 4, 5])
    angle = random.choice([90, 180, 270])
    images = (hazy, clean, clean_gray)

    if method in (3, 4, 5):
        return images

    if method == 0:
        def op(img):
            return transforms.functional.rotate(img, angle)
    elif method == 1:
        op = transforms.RandomVerticalFlip(p=1)
    else:
        op = transforms.RandomHorizontalFlip(p=1)
    # Same op instance, applied in hazy -> clean -> clean_gray order.
    return tuple(op(img) for img in images)

class custom_dehaze_train_dataset(Dataset):
    """Paired hazy / clean PNG dataset for training (and full-image evaluation).

    Hazy and clean images are paired by sorted filename order, so ``HAZY_path`` and
    ``GT_path`` should contain matching ``*.png`` files.

    Args:
        HAZY_path: folder with the hazy images.
        GT_path: folder with the clean (ground-truth) images.
        Image_Size: size of the random training crop, passed to ``RandomCrop.get_params``.
        is_train: selects the item format (see below).

    Items:
        * ``is_train=True``: ``(hazy, clean, edge)``. There is one random ``Image_Size``
          crop shared by all images, and the same random rotation/flip (``custom_augment``)
          is applied to each. ``edge`` is the Sobel gradient magnitude of the augmented
          clean crop, divided by 255.
        * ``is_train=False``: the full-size ``(hazy, clean)`` pair as tensors.
    """

    def __init__(self, HAZY_path = None, GT_path = None, Image_Size = (256,256), is_train = True):
        self.transform = transforms.Compose([transforms.ToTensor()])
        self.HAZY_path = Path(HAZY_path)
        self.GT_path = Path(GT_path)
        self.HAZY_Image = sorted(self.HAZY_path.glob("*.png"))  # hazy image files
        self.GT_Image = sorted(self.GT_path.glob("*.png"))  # clean (ground-truth) image files
        self.Image_Size = Image_Size
        self.is_train = is_train

    @staticmethod
    def _gradient_magnitude(grad_x, grad_y):
        """Element-wise Euclidean norm ``sqrt(grad_x**2 + grad_y**2)`` of two gradient arrays."""
        return np.hypot(grad_x, grad_y)

    def __getitem__(self, index):
        hazy = Image.open(self.HAZY_Image[index])
        clean = Image.open(self.GT_Image[index])
        if not self.is_train:
            return self.transform(hazy), self.transform(clean)

        # The edge target is computed from the clean image itself
        # (a grayscale variant, clean.convert('L'), is currently disabled).
        clean_gray = clean

        # One random crop window, applied identically to every image.
        i, j, h, w = transforms.RandomCrop.get_params(hazy, output_size=self.Image_Size)
        hazy_ = TF.crop(hazy, i, j, h, w)
        clean_ = TF.crop(clean, i, j, h, w)
        clean_gray_ = TF.crop(clean_gray, i, j, h, w)

        # Same random rotation / flip for all three crops.
        hazy_arg, clean_arg, clean_gray_arg = custom_augment(hazy_, clean_, clean_gray_)
        hazy = self.transform(hazy_arg)
        clean = self.transform(clean_arg)

        # Sobel edge magnitude of the augmented clean crop. Both components must be
        # combined: np.sqrt(a, b) would treat b as numpy's `out` argument and return
        # sqrt(a) = |grad_x| only.
        edge_source = np.float32(clean_gray_arg)
        grad_x = cv2.Sobel(edge_source, cv2.CV_64F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(edge_source, cv2.CV_64F, 0, 1, ksize=3)
        # NOTE(review): the CV_64F result is float64 and ToTensor only rescales uint8
        # input, so this tensor stays float64 -- cast if the edge loss expects float32.
        clean_gray = self.transform(self._gradient_magnitude(grad_x, grad_y))
        return hazy, clean, clean_gray / 255

    def __len__(self):
        return len(self.HAZY_Image)  # number of hazy images (pairs)

class dehaze_test_dataset(Dataset):
    """Test-set dataset that returns each hazy image as two vertically overlapping strips.

    Hazy (and, if present, ground-truth) images are ``*.png`` files paired by sorted
    filename order. Each hazy tensor is cut into rows ``[0, 1152)`` and ``[48, 1200)``,
    so the model can be run on each strip separately; the strips overlap by 1104 rows.
    NOTE(review): the fixed row bounds assume 1200-pixel-tall test images -- confirm.

    Items are ``(hazy_up, hazy_down, name)`` when ``GT_PATH`` contains no PNGs, otherwise
    ``(hazy_up, hazy_down, name, clean)``, where ``name`` is the hazy file's stem.
    """

    def __init__(self, HAZY_PATH = None, GT_PATH = None):
        self.transform = transforms.Compose([transforms.ToTensor()])
        self.root_hazy, self.root_GT = Path(HAZY_PATH), Path(GT_PATH)
        self.list_test = sorted(self.root_hazy.glob("*.png"))  # hazy inputs
        self.list_GT = sorted(self.root_GT.glob("*.png"))  # optional ground truth
        self.file_len = len(self.list_test)

    def __getitem__(self, index, is_train=True):
        # `is_train` is unused; kept only for signature compatibility.
        hazy_path = self.list_test[index]
        hazy = self.transform(Image.open(hazy_path))
        # Large overlapping strips: no visible seam, but they can exhaust CUDA memory.
        # A split that fits in memory (with a visible seam) is rows [0, 768) and [432, 1200).
        hazy_up, hazy_down = hazy[:, :1152, :], hazy[:, 48:1200, :]
        name = hazy_path.stem
        if not self.list_GT:
            return hazy_up, hazy_down, name
        clean = self.transform(Image.open(self.list_GT[index]))
        return hazy_up, hazy_down, name, clean

    def __len__(self):
        return self.file_len

class CustomDataLoader(Dataset):
    """Paired hazy / ground-truth PNG dataset.

    Despite the name, this is a ``torch.utils.data.Dataset``; wrap it in a ``DataLoader``
    to batch it. Hazy and GT images are paired by sorted filename order, so both folders
    must hold the same number of ``*.png`` files.

    Args:
        HAZY_path: folder with the hazy input images.
        GT_path: folder with the ground-truth images.
        image_size: size passed to ``transforms.Resize`` when ``resize`` is truthy.
        resize: if truthy, both images are resized to ``image_size`` before ``ToTensor``.

    Each item is ``(hazy_tensor, gt_tensor, hazy_file_stem)``.
    """

    def __init__(self, HAZY_path = None, GT_path = None, image_size = None, resize = None):
        self.HAZY_path = Path(HAZY_path)
        self.GT_path = Path(GT_path)
        self.HAZY_Image = sorted(self.HAZY_path.glob("*.png"))  # hazy image files
        self.GT_Image = sorted(self.GT_path.glob("*.png"))  # ground-truth image files
        assert len(self.HAZY_Image) == len(self.GT_Image), (
            f"found {len(self.HAZY_Image)} hazy but {len(self.GT_Image)} GT images")
        self.resize = resize
        if self.resize:
            self.data_transforms = transforms.Compose([transforms.Resize(image_size),
                                                       transforms.ToTensor()])
        else:
            self.data_transforms = transforms.Compose([transforms.ToTensor()])

    def load_image(self, index: int, image_type = "HAZY") -> "Image.Image":
        """Open the ``index``-th image of the given kind.

        Args:
            index: position in the sorted file list.
            image_type: ``"HAZY"`` or ``"GT"``.

        Raises:
            ValueError: for any other ``image_type``. Without this check, the lookup
                would fail with an unbound local variable.
        """
        if image_type == "HAZY":
            image_path = self.HAZY_Image[index]
        elif image_type == "GT":
            image_path = self.GT_Image[index]
        else:
            raise ValueError(f"image_type must be 'HAZY' or 'GT', got {image_type!r}")
        return Image.open(image_path)

    def __len__(self):
        return len(self.HAZY_Image)  # number of hazy/GT pairs

    def __getitem__(self, index):
        hazy = self.load_image(index, "HAZY")
        gt = self.load_image(index, "GT")
        return self.data_transforms(hazy), self.data_transforms(gt), self.HAZY_Image[index].stem

# Test Utilities: PSNR, SSIM and MS-SSIM Metrics

In [None]:
def to_psnr(frame_out, gt, data_range=1.0):
    """Per-image PSNR in dB for a batch of predictions.

    Args:
        frame_out: predicted batch; dim 0 is the batch dimension.
        gt: reference batch with the same shape as ``frame_out``.
        data_range: peak signal value (1.0 for images in [0, 1], 255 for 8-bit images).
            The default matches images scaled to [0, 1].

    Returns:
        list with one float per image, ``10 * log10(data_range**2 / MSE)``.
        An image identical to its reference (MSE == 0) gets ``float('inf')``;
        the formula would otherwise divide by zero.
    """
    mse = F.mse_loss(frame_out, gt, reduction='none')
    per_image_mse = [chunk.mean().item() for chunk in torch.split(mse, 1, dim=0)]
    peak_sq = data_range ** 2
    return [10.0 * log10(peak_sq / m) if m > 0 else float('inf') for m in per_image_mse]

def to_ssim_skimage(dehaze, gt):
    """Per-image SSIM for a batch, computed with ``skimage.metrics.structural_similarity``.

    Args:
        dehaze: predicted batch, (B, C, H, W), values assumed to be in [0, 1].
        gt: reference batch with the same batch size.

    Returns:
        list of B SSIM values.

    Raises:
        ValueError: if the two batches have different sizes.
    """
    if dehaze.shape[0] != gt.shape[0]:
        raise ValueError(f"batch size mismatch: {dehaze.shape[0]} vs {gt.shape[0]}")
    dehaze_list = torch.split(dehaze, 1, dim=0)
    gt_list = torch.split(gt, 1, dim=0)

    # Convert each (1, C, H, W) slice to an (H, W, C) array by indexing the batch dim.
    # A plain .squeeze() would also drop a singleton channel (C == 1) or height/width,
    # and skimage would then treat the image width as the channel axis.
    dehaze_list_np = [t[0].permute(1, 2, 0).data.cpu().numpy() for t in dehaze_list]
    gt_list_np = [t[0].permute(1, 2, 0).data.cpu().numpy() for t in gt_list]
    return [ssim(d, g, data_range=1, channel_axis=-1) for d, g in zip(dehaze_list_np, gt_list_np)]

def gaussian(window_size, sigma):
    """Return a normalised 1-D Gaussian kernel of length ``window_size``.

    The kernel is centred on index ``window_size // 2`` and sums to 1.
    """
    centre = window_size // 2
    two_sigma_sq = float(2 * sigma ** 2)
    taps = torch.Tensor([exp(-(i - centre) ** 2 / two_sigma_sq) for i in range(window_size)])
    return taps / taps.sum()


def create_window(window_size, channel=1):
    """Build a depthwise 2-D Gaussian kernel (sigma = 1.5) for ``F.conv2d(..., groups=channel)``.

    Returns a contiguous tensor of shape (channel, 1, window_size, window_size), with the
    same separable Gaussian (outer product of ``gaussian``) repeated for every channel.
    """
    g = gaussian(window_size, 1.5)
    kernel = (g.unsqueeze(1) @ g.unsqueeze(0)).float()
    return kernel[None, None].expand(channel, 1, window_size, window_size).contiguous()

def ssim1(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
    """Single-scale SSIM between two (B, C, H, W) batches using a Gaussian window.

    Args:
        img1, img2: image batches of the same shape.
        window_size: side of the Gaussian window (clipped to the image height/width).
        window: optional precomputed (C, 1, k, k) kernel; built with ``create_window`` if None.
        size_average: if True, return scalars averaged over the whole batch; otherwise
            return one value per image (shape (B,)).
        full: if True, also return the contrast-structure term ``cs``.
        val_range: dynamic range L of the pixel values. If None, it is guessed from
            ``img1``. The upper bound is 255 if the max exceeds 128, else 1. The lower
            bound is -1 if the min is below -0.5, else 0. This covers [0, 255], [0, 1]
            and [-1, 1] data.

    Returns:
        ``ssim``, or ``(ssim, cs)`` when ``full`` is True. The convolutions use no
        padding, so only fully covered window positions contribute.
    """
    if val_range is not None:
        L = val_range
    else:
        upper = 255 if torch.max(img1) > 128 else 1
        lower = -1 if torch.min(img1) < -0.5 else 0
        L = upper - lower

    _, channel, height, width = img1.size()
    if window is None:
        window = create_window(min(window_size, height, width), channel=channel).to(img1.device)

    def local_mean(t):
        # Depthwise Gaussian filtering (one kernel per channel), no padding.
        return F.conv2d(t, window, padding=0, groups=channel)

    mu1 = local_mean(img1)
    mu2 = local_mean(img2)
    mu1_sq, mu2_sq, mu1_mu2 = mu1.pow(2), mu2.pow(2), mu1 * mu2

    # Local (co)variances as E[xy] - E[x]E[y].
    sigma1_sq = local_mean(img1 * img1) - mu1_sq
    sigma2_sq = local_mean(img2 * img2) - mu2_sq
    sigma12 = local_mean(img1 * img2) - mu1_mu2

    # Stabilising constants, K1 = 0.01 and K2 = 0.03 times the dynamic range.
    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2

    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2
    cs = v1 / v2  # contrast sensitivity
    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

    if size_average:
        cs, ret = cs.mean(), ssim_map.mean()
    else:
        cs = cs.mean(1).mean(1).mean(1)
        ret = ssim_map.mean(1).mean(1).mean(1)

    return (ret, cs) if full else ret


def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
    """Multi-scale SSIM (Wang, Simoncelli & Bovik, 2003) over 5 dyadic scales.

    At each scale, ``ssim1`` returns the full SSIM and the contrast-structure term ``cs``;
    between scales both images are 2x2 average-pooled. The result is
    ``prod_{j<5} cs_j ** w_j * ssim_5 ** w_5`` with the standard weights below.

    Args:
        img1, img2: (B, C, H, W) batches. The coarsest scale is reached after 4 poolings
            (about H/16 x W/16).
        window_size, val_range: forwarded to ``ssim1``.
        size_average: if True, return a scalar; otherwise return one value per image (B,).
        normalize: map ssim/cs from [-1, 1] to [0, 1] before the fractional powers. This
            avoids NaNs from negative bases but is not part of the original definition.

    Returns:
        MS-SSIM as a 0-d tensor, or with shape (B,) when ``size_average`` is False.
    """
    device = img1.device
    weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
    levels = weights.size()[0]
    mssim = []
    mcs = []
    for level in range(levels):
        sim, cs = ssim1(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
        mssim.append(sim)
        mcs.append(cs)
        # Pool only between scales. Pooling after the coarsest scale is wasted work, and
        # it fails outright when that scale is already 1 pixel high or wide.
        if level < levels - 1:
            img1 = F.avg_pool2d(img1, (2, 2))
            img2 = F.avg_pool2d(img2, (2, 2))

    mssim = torch.stack(mssim)  # (levels,) or (levels, B)
    mcs = torch.stack(mcs)

    # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
    if normalize:
        mssim = (mssim + 1) / 2
        mcs = (mcs + 1) / 2

    # Put the scale axis first so the weights broadcast in both output modes.
    weights = weights.view(-1, *([1] * (mcs.dim() - 1)))
    pow1 = mcs ** weights
    pow2 = mssim ** weights
    # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/ :
    # the cs terms of every scale but the coarsest, times the coarsest full SSIM term once.
    # The coarsest term must stay outside the product, or it is counted once per scale.
    output = torch.prod(pow1[:-1], dim=0) * pow2[-1]
    return output

# Loss Network

In [None]:
# --- Perceptual loss network  --- #
class LossNetwork(torch.nn.Module):
    """Perceptual loss: mean MSE between VGG activations of the dehazed and ground-truth images.

    Args:
        vgg_model: a module whose children are named '0', '1', ... and are applied in
            order (e.g. ``vgg16(...).features``). The outputs of children '3', '8' and
            '15' are compared; in torchvision's VGG16 ``features`` these are relu1_2,
            relu2_2 and relu3_3.
    """

    def __init__(self, vgg_model):
        super(LossNetwork, self).__init__()
        self.vgg_layers = vgg_model
        self.layer_name_mapping = {
            '3': "relu1_2",
            '8': "relu2_2",
            '15': "relu3_3"
        }

    def output_features(self, x):
        """Run ``x`` through ``vgg_layers`` and return the tapped activations in layer order.

        Stops once every tapped layer has been reached, so the deeper VGG layers,
        which the loss never uses, are not evaluated.
        """
        output = {}
        for name, module in self.vgg_layers._modules.items():
            x = module(x)
            if name in self.layer_name_mapping:
                output[self.layer_name_mapping[name]] = x
                if len(output) == len(self.layer_name_mapping):
                    break
        return list(output.values())

    def forward(self, dehaze, gt):
        """Average over the tapped layers of ``mse_loss(dehaze_feature, gt_feature)``."""
        losses = [F.mse_loss(d, g) for d, g in zip(self.output_features(dehaze), self.output_features(gt))]
        return sum(losses) / len(losses)

# Restormer

In [None]:
## Restormer: Efficient Transformer for High-Resolution Image Restoration
## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
## https://arxiv.org/abs/2111.09881



##########################################################################
## Layer Norm

def to_3d(x):
    """Flatten a (B, C, H, W) feature map into a (B, H*W, C) token sequence.

    Same result as ``rearrange(x, 'b c h w -> b (h w) c')``; token ``i*W + j`` holds pixel (i, j).
    """
    return x.flatten(2).transpose(1, 2)

def to_4d(x,h,w):
    """Inverse of ``to_3d``: reshape a (B, H*W, C) token sequence back to (B, C, H, W).

    Same result as ``rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)``.
    """
    batch, _, channels = x.shape
    return x.transpose(1, 2).reshape(batch, channels, h, w)

class BiasFree_LayerNorm(nn.Module):
    """Layer norm over the last dimension that only rescales (no mean subtraction, no bias).

    Computes ``x / sqrt(var(x) + 1e-5) * weight``, where the biased (population) variance
    is taken over the last axis.
    """

    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        shape = torch.Size(
            (normalized_shape,) if isinstance(normalized_shape, numbers.Integral) else normalized_shape
        )
        # Only normalisation over a single trailing dimension is supported.
        assert len(shape) == 1
        self.weight = nn.Parameter(torch.ones(shape))
        self.normalized_shape = shape

    def forward(self, x):
        variance = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(variance + 1e-5) * self.weight

class WithBias_LayerNorm(nn.Module):
    """Standard layer norm over the last dimension with learnable scale and shift.

    Computes ``(x - mean) / sqrt(var + 1e-5) * weight + bias``, where mean and the
    biased (population) variance are taken over the last axis.
    """

    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        shape = torch.Size(
            (normalized_shape,) if isinstance(normalized_shape, numbers.Integral) else normalized_shape
        )
        # Only normalisation over a single trailing dimension is supported.
        assert len(shape) == 1
        self.weight = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape))
        self.normalized_shape = shape

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        variance = x.var(-1, keepdim=True, unbiased=False)
        return centered / torch.sqrt(variance + 1e-5) * self.weight + self.bias


class LayerNorm(nn.Module):
    """Channel-wise layer norm for (B, C, H, W) feature maps.

    The map is flattened to (B, H*W, C) tokens, normalised over C, and reshaped back.
    ``LayerNorm_type == 'BiasFree'`` selects the scale-only variant; any other value
    selects the variant with mean subtraction and bias.
    """

    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        norm_cls = BiasFree_LayerNorm if LayerNorm_type == 'BiasFree' else WithBias_LayerNorm
        self.body = norm_cls(dim)

    def forward(self, x):
        height, width = x.shape[-2:]
        tokens = to_3d(x)
        return to_4d(self.body(tokens), height, width)



##########################################################################
## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):
    """Gated-Dconv Feed-Forward Network (GDFN) from Restormer.

    A 1x1 conv expands ``dim`` to ``2 * hidden`` channels, where
    ``hidden = int(dim * ffn_expansion_factor)``. A 3x3 depthwise conv then mixes
    spatially, and the result is split into two halves. Finally
    ``gelu(first) * second`` is projected back to ``dim`` with a 1x1 conv.
    """

    def __init__(self, dim, ffn_expansion_factor, bias):
        super().__init__()
        hidden = int(dim * ffn_expansion_factor)
        doubled = hidden * 2
        self.project_in = nn.Conv2d(dim, doubled, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(doubled, doubled, kernel_size=3, stride=1, padding=1, groups=doubled, bias=bias)
        self.project_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        gate, value = self.dwconv(self.project_in(x)).chunk(2, dim=1)
        return self.project_out(F.gelu(gate) * value)



##########################################################################
## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
    """Multi-DConv Head Transposed Self-Attention (MDTA) from Restormer.

    Attention is computed across channels rather than pixels. For each head, the
    (c_head x c_head) map is ``softmax(normalize(q) @ normalize(k)^T * temperature)``,
    so the cost grows linearly with H*W. ``temperature`` is a learnable per-head scale.
    """

    def __init__(self, dim, num_heads, bias):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        _, _, h, w = x.shape
        # q, k, v: (B, dim, H, W) -> (B, heads, dim / heads, H*W)
        q, k, v = (
            rearrange(t, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
            for t in self.qkv_dwconv(self.qkv(x)).chunk(3, dim=1)
        )
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        # Channel-to-channel attention map, (B, heads, c, c).
        attn = (q @ k.transpose(-2, -1) * self.temperature).softmax(dim=-1)
        out = rearrange(attn @ v, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        return self.project_out(out)



##########################################################################
class TransformerBlock(nn.Module):
    """Pre-norm Restormer block: ``y = x + attn(norm1(x))``, then ``y + ffn(norm2(y))``."""

    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super(TransformerBlock, self).__init__()
        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        attended = x + self.attn(self.norm1(x))
        return attended + self.ffn(self.norm2(attended))



##########################################################################
## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
    """Embed an image into ``embed_dim`` feature channels with one 3x3, stride-1 convolution.

    Neighbouring 3x3 receptive fields overlap, and padding 1 keeps the spatial size.
    """

    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super().__init__()
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, x):
        return self.proj(x)



##########################################################################
## Resizing modules
class Downsample(nn.Module):
    """Halve H and W: a 3x3 conv to ``n_feat // 2`` channels, then ``PixelUnshuffle(2)``.

    The output has ``(n_feat // 2) * 4`` channels (2 * n_feat for even ``n_feat``).
    """

    def __init__(self, n_feat):
        super().__init__()
        halve_channels = nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.body = nn.Sequential(halve_channels, nn.PixelUnshuffle(2))

    def forward(self, x):
        return self.body(x)

class Upsample(nn.Module):
    """Double H and W: a 3x3 conv to ``n_feat * 2`` channels, then ``PixelShuffle(2)``.

    The output has ``n_feat * 2 / 4`` channels (n_feat / 2).
    """

    def __init__(self, n_feat):
        super().__init__()
        double_channels = nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.body = nn.Sequential(double_channels, nn.PixelShuffle(2))

    def forward(self, x):
        return self.body(x)

##########################################################################
##---------- Restormer -----------------------
class Restormer(nn.Module):
    """Restormer: a 4-level encoder/decoder of transformer blocks with a global residual.

    Architecture from "Restormer: Efficient Transformer for High-Resolution Image
    Restoration" (Zamir et al.). Channel width doubles at every encoder level
    (dim, 2*dim, 4*dim, 8*dim) while ``Downsample`` halves H and W, so the input
    height and width must be divisible by 8.

    Args:
        inp_channels, out_channels: image channels in / out.
        dim: feature width of level 1.
        num_blocks: transformer blocks per level 1..4 (decoder level k reuses num_blocks[k-1]).
        num_refinement_blocks: extra full-resolution blocks after the decoder.
        heads: attention heads per level.
        ffn_expansion_factor, bias, LayerNorm_type: forwarded to every ``TransformerBlock``.
        dual_pixel_task: if True (dual-pixel defocus deblurring, ``inp_channels=6``), add a
            1x1-conv skip from the level-1 embedding and return ``output(...)`` without
            the input residual.
    """

    def __init__(self,
        inp_channels=3,
        out_channels=3,
        dim = 48,
        num_blocks = [4,6,6,8],
        num_refinement_blocks = 4,
        heads = [1,2,4,8],
        ffn_expansion_factor = 2.66,
        bias = False,
        LayerNorm_type = 'WithBias',   ## Other option 'BiasFree'
        dual_pixel_task = False        ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
    ):

        super(Restormer, self).__init__()

        def stage(width, n_heads, depth):
            # A stack of `depth` TransformerBlocks sharing the model-wide FFN/bias/norm settings.
            return nn.Sequential(*[TransformerBlock(dim=width, num_heads=n_heads, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for _ in range(depth)])

        # Submodules are created and registered in the same order (and under the same
        # names) as before, so state_dict keys and seeded initialisation are unchanged.
        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

        self.encoder_level1 = stage(dim, heads[0], num_blocks[0])

        self.down1_2 = Downsample(dim) ## From Level 1 to Level 2
        self.encoder_level2 = stage(int(dim*2**1), heads[1], num_blocks[1])

        self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3
        self.encoder_level3 = stage(int(dim*2**2), heads[2], num_blocks[2])

        self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4
        self.latent = stage(int(dim*2**3), heads[3], num_blocks[3])

        self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3
        self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias)
        self.decoder_level3 = stage(int(dim*2**2), heads[2], num_blocks[2])

        self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2
        self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
        self.decoder_level2 = stage(int(dim*2**1), heads[1], num_blocks[1])

        self.up2_1 = Upsample(int(dim*2**1))  ## From Level 2 to Level 1  (NO 1x1 conv to reduce channels)

        self.decoder_level1 = stage(int(dim*2**1), heads[0], num_blocks[0])

        self.refinement = stage(int(dim*2**1), heads[0], num_refinement_blocks)

        #### For Dual-Pixel Defocus Deblurring Task ####
        self.dual_pixel_task = dual_pixel_task
        if self.dual_pixel_task:
            self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias)
        ###########################

        self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, inp_img):
        # Each intermediate is `del`-ed right after its last use, so that under
        # torch.no_grad() its memory can be reused by the next stage. (With autograd
        # enabled, tensors saved for the backward pass stay alive regardless.)
        # Plain `del` is enough to drop the reference; copying to CPU first only adds
        # a device-to-host transfer.
        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)
        # Only the dual-pixel head needs the raw level-1 embedding again at the end.
        dual_pixel_skip = inp_enc_level1 if self.dual_pixel_task else None
        del inp_enc_level1

        inp_enc_level2 = self.down1_2(out_enc_level1)
        out_enc_level2 = self.encoder_level2(inp_enc_level2)
        del inp_enc_level2

        inp_enc_level3 = self.down2_3(out_enc_level2)
        out_enc_level3 = self.encoder_level3(inp_enc_level3)
        del inp_enc_level3

        inp_enc_level4 = self.down3_4(out_enc_level3)
        latent = self.latent(inp_enc_level4)
        del inp_enc_level4

        # Decoder level 3: upsample, concat the encoder skip, reduce channels.
        inp_dec_level3 = self.up4_3(latent)
        del latent
        inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
        del out_enc_level3
        inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
        out_dec_level3 = self.decoder_level3(inp_dec_level3)
        del inp_dec_level3

        # Decoder level 2.
        inp_dec_level2 = self.up3_2(out_dec_level3)
        del out_dec_level3
        inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
        del out_enc_level2
        inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
        out_dec_level2 = self.decoder_level2(inp_dec_level2)
        del inp_dec_level2

        # Decoder level 1 (no channel reduction: width stays 2*dim).
        inp_dec_level1 = self.up2_1(out_dec_level2)
        del out_dec_level2
        inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
        del out_enc_level1

        # Release cached, unused CUDA memory before the full-resolution, 2*dim-wide
        # stages (kept from the original out-of-memory mitigation).
        torch.cuda.empty_cache()

        out_dec_level1 = self.decoder_level1(inp_dec_level1)
        del inp_dec_level1

        out_dec_level1 = self.refinement(out_dec_level1)

        if self.dual_pixel_task:
            # Dual-pixel deblurring: add the level-1 skip and predict the image directly.
            # The 6-channel input cannot be added to the 3-channel output.
            out_dec_level1 = out_dec_level1 + self.skip_conv(dual_pixel_skip)
            return self.output(out_dec_level1)
        # Global residual: the network predicts a correction to the input image.
        return self.output(out_dec_level1) + inp_img


# Haze Density Map

In [None]:
def blockUNet1(in_c, out_c, name, transposed=False, bn=False, relu=True, dropout=False):
    """Build one U-Net stage as an ``nn.Sequential``.

    The stage is: activation, then a 4x4 / stride-2 / padding-1 conv, then optional
    BatchNorm2d, then optional Dropout2d(0.5).

    Args:
        in_c, out_c: input / output channels of the convolution.
        name: prefix for the submodule names ('<name>_relu' or '<name>_leakyrelu',
            '<name>_conv' or '<name>_tconv', '<name>_bn', '<name>_dropout'). These names
            appear in state_dict keys.
        transposed: use ``ConvTranspose2d`` (doubles H, W) instead of ``Conv2d`` (halves H, W).
        bn: append ``BatchNorm2d(out_c)``.
        relu: ReLU if True, otherwise LeakyReLU(0.2); both in-place.
        dropout: append in-place ``Dropout2d(0.5)``.
    """
    if relu:
        activation = ('%s_relu' % name, nn.ReLU(inplace=True))
    else:
        activation = ('%s_leakyrelu' % name, nn.LeakyReLU(0.2, inplace=True))

    if transposed:
        conv = ('%s_tconv' % name, nn.ConvTranspose2d(in_c, out_c, 4, 2, 1, bias=False))
    else:
        conv = ('%s_conv' % name, nn.Conv2d(in_c, out_c, 4, 2, 1, bias=False))

    stages = [activation, conv]
    if bn:
        stages.append(('%s_bn' % name, nn.BatchNorm2d(out_c)))
    if dropout:
        stages.append(('%s_dropout' % name, nn.Dropout2d(0.5, inplace=True)))

    block = nn.Sequential()
    for child_name, child in stages:
        block.add_module(child_name, child)
    return block


class HazeDensityMap(nn.Module):
    """
    Haze density map generator, taken from the Trident paper.

    Pre-trained weights are in the graduation project Drive folder:
        Graduation Project/HM.pt

    Its feed-forward usage is a bit unusual: the helper functions are in the utils
    file, and an example of the usage is in the notebook:
        https://colab.research.google.com/drive/1Ngj5rMHFh1BMWUotsgEVJulpwIbgLP6x#scrollTo=Z3Xr6hqAfuXC

    Architecture: a 6-level U-Net.
      * layer1..layer6 are 4x4 / stride-2 convolutions; each halves H and W.
      * dlayer6..dlayer1 are 4x4 / stride-2 transposed convolutions; each doubles H and W.
      * dlayer5..dlayer1 take the previous decoder output concatenated with the matching
        encoder output (skip connections).
      * tail_conv (3x3) maps the final nf * 2 channels to ``output_nc``.

    Args:
        input_nc: number of input image channels.
        output_nc: number of channels of the predicted map.
        nf: base number of filters (encoder widths nf, 2nf, 4nf, 8nf, 8nf, 8nf).
    """

    def __init__(self, input_nc=3, output_nc=3, nf=8):
        super(HazeDensityMap, self).__init__()
        # The spatial sizes in the comments below assume a 256 x 256 input.
        # input is 256 x 256
        layer_idx = 1
        name = 'layer%d' % layer_idx
        layer1 = nn.Sequential()
        layer1.add_module(name, nn.Conv2d(input_nc, nf, 4, 2, 1, bias=False))
        # input is 128 x 128
        layer_idx += 1
        name = 'layer%d' % layer_idx
        layer2 = blockUNet1(nf, nf * 2, name, transposed=False, bn=True, relu=False, dropout=False)
        # input is 64 x 64
        layer_idx += 1
        name = 'layer%d' % layer_idx
        layer3 = blockUNet1(nf * 2, nf * 4, name, transposed=False, bn=True, relu=False, dropout=False)
        # input is 32
        layer_idx += 1
        name = 'layer%d' % layer_idx
        layer4 = blockUNet1(nf * 4, nf * 8, name, transposed=False, bn=True, relu=False, dropout=False)
        # input is 16
        layer_idx += 1
        name = 'layer%d' % layer_idx
        layer5 = blockUNet1(nf * 8, nf * 8, name, transposed=False, bn=True, relu=False, dropout=False)
        # input is 8
        layer_idx += 1
        name = 'layer%d' % layer_idx
        layer6 = blockUNet1(nf * 8, nf * 8, name, transposed=False, bn=False, relu=False, dropout=False)

        ## NOTE: decoder
        # NOTE: layer_idx is decremented *before* each decoder name is built, so the
        # decoder submodule names lag by one: dlayer6 is built with name 'dlayer5', ...,
        # dlayer1 with 'dlayer0'. These names become state_dict keys. Presumably the
        # pretrained HM.pt checkpoint was saved with them, so do not "fix" them.
        # Decoder inputs with 16 * nf / 8 * nf / 4 * nf / 2 * nf channels are the
        # previous decoder output concatenated with the matching encoder output.
        # input is 4
        layer_idx -= 1
        name = 'dlayer%d' % layer_idx

        # dlayer6 = blockUNet1(nf*16, nf*8, name, transposed=True, bn=True, relu=True, dropout=True)
        dlayer6 = blockUNet1(nf * 8, nf * 8, name, transposed=True, bn=True, relu=True, dropout=False)
        # input is 8
        layer_idx -= 1
        name = 'dlayer%d' % layer_idx

        dlayer5 = blockUNet1(nf * 16, nf * 8, name, transposed=True, bn=True, relu=True, dropout=False)
        # input is 16
        layer_idx -= 1
        name = 'dlayer%d' % layer_idx

        dlayer4 = blockUNet1(nf * 16, nf * 4, name, transposed=True, bn=True, relu=True, dropout=False)
        # input is 32
        layer_idx -= 1
        name = 'dlayer%d' % layer_idx

        dlayer3 = blockUNet1(nf * 8, nf * 2, name, transposed=True, bn=True, relu=True, dropout=False)
        # input is 64
        layer_idx -= 1
        name = 'dlayer%d' % layer_idx

        dlayer2 = blockUNet1(nf * 4, nf, name, transposed=True, bn=True, relu=True, dropout=False)
        # input is 128
        layer_idx -= 1
        name = 'dlayer%d' % layer_idx
        dlayer1 = blockUNet1(nf * 2, nf * 2, name, transposed=True, bn=True, relu=True, dropout=False)

        # Attribute assignment order fixes the module registration (and state_dict) order.
        self.layer1 = layer1
        self.layer2 = layer2
        self.layer3 = layer3
        self.layer4 = layer4
        self.layer5 = layer5
        self.layer6 = layer6
        self.dlayer6 = dlayer6
        self.dlayer5 = dlayer5
        self.dlayer4 = dlayer4
        self.dlayer3 = dlayer3
        self.dlayer2 = dlayer2
        self.dlayer1 = dlayer1
        self.tail_conv = nn.Conv2d(nf * 2, output_nc, 3, padding=1, bias=True)

    def forward(self, x):
        # Expects a (B, C, H, W) batch. The six stride-2 encoder stages need H and W
        # divisible by 2**6 = 64, so reflect-pad the bottom / right edges up to the next
        # multiple of 64 and crop the padding off the output at the end.
        b, c, h, w = x.shape
        mod1 = h % 64
        mod2 = w % 64
        if (mod1):
            down1 = 64 - mod1
            x = F.pad(x, (0, 0, 0, down1), "reflect")  # pad bottom of H
        if (mod2):
            down2 = 64 - mod2
            x = F.pad(x, (0, down2, 0, 0), "reflect")  # pad right of W

        # Encoder.
        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)
        out5 = self.layer5(out4)
        out6 = self.layer6(out5)
        # Decoder with skip connections (channel-wise concatenation).
        dout6 = self.dlayer6(out6)
        dout6_out5 = torch.cat([dout6, out5], 1)
        dout5 = self.dlayer5(dout6_out5)
        dout5_out4 = torch.cat([dout5, out4], 1)
        dout4 = self.dlayer4(dout5_out4)
        dout4_out3 = torch.cat([dout4, out3], 1)
        dout3 = self.dlayer3(dout4_out3)
        dout3_out2 = torch.cat([dout3, out2], 1)
        dout2 = self.dlayer2(dout3_out2)
        dout2_out1 = torch.cat([dout2, out1], 1)
        dout1 = self.dlayer1(dout2_out1)
        dout1 = self.tail_conv(dout1)

        # Remove the reflection padding so the map matches the input's H x W.
        if (mod1): dout1 = dout1[:, :, :-down1, :]
        if (mod2): dout1 = dout1[:, :, :, :-down2]

        return dout1

# Sobel_UNet

In [None]:
class Sobel_UNet(nn.Module):
    """Six-level encoder/decoder U-Net built from ``blockUNet1`` stages.

    Encoder: a stride-2 4x4 convolution followed by five ``blockUNet1`` down
    stages. Decoder: six ``blockUNet1`` transposed stages, each (except the
    first) fed the concatenation of the previous decoder output and the
    matching encoder feature, then a 3x3 ``tail_conv`` to ``output_nc``
    channels. Inputs are reflect-padded to a multiple of 64 and the output is
    cropped back to the input's height/width.

    Args:
        input_nc: number of input channels.
        output_nc: number of output channels.
        nf: base feature width; deeper stages use multiples of it.
    """

    def __init__(self, input_nc=3, output_nc=3, nf=8):
        super().__init__()
        # Stem: one strided conv wrapped in a Sequential whose child is named
        # 'layer1' (the child name is part of the state_dict keys).
        self.layer1 = nn.Sequential()
        self.layer1.add_module('layer1', nn.Conv2d(input_nc, nf, 4, 2, 1, bias=False))

        # Remaining encoder stages: (name, in_channels, out_channels, batch-norm?).
        encoder_specs = (
            ('layer2', nf, nf * 2, True),
            ('layer3', nf * 2, nf * 4, True),
            ('layer4', nf * 4, nf * 8, True),
            ('layer5', nf * 8, nf * 8, True),
            ('layer6', nf * 8, nf * 8, False),
        )
        for stage_name, c_in, c_out, use_bn in encoder_specs:
            setattr(self, stage_name,
                    blockUNet1(c_in, c_out, stage_name, transposed=False, bn=use_bn, relu=False, dropout=False))

        # Decoder stages: (attribute, in_channels, out_channels, block name).
        # The name handed to blockUNet1 is one lower than the attribute index
        # (dlayer6 -> 'dlayer5', ..., dlayer1 -> 'dlayer0'). blockUNet1
        # presumably uses it to name its submodules, so keep it unchanged to stay
        # compatible with saved checkpoints.
        # Input widths of dlayer5..dlayer1 are doubled by the skip concatenation.
        decoder_specs = (
            ('dlayer6', nf * 8, nf * 8, 'dlayer5'),
            ('dlayer5', nf * 16, nf * 8, 'dlayer4'),
            ('dlayer4', nf * 16, nf * 4, 'dlayer3'),
            ('dlayer3', nf * 8, nf * 2, 'dlayer2'),
            ('dlayer2', nf * 4, nf, 'dlayer1'),
            ('dlayer1', nf * 2, nf * 2, 'dlayer0'),
        )
        for attr, c_in, c_out, block_name in decoder_specs:
            setattr(self, attr,
                    blockUNet1(c_in, c_out, block_name, transposed=True, bn=True, relu=True, dropout=False))

        self.tail_conv = nn.Conv2d(nf * 2, output_nc, 3, padding=1, bias=True)

    def forward(self, x):
        """Return a ``output_nc``-channel map with the same H/W as ``x``."""
        _, _, height, width = x.shape
        # Six stride-2 stages need H and W divisible by 2**6 = 64.
        pad_bottom = (64 - height % 64) % 64
        pad_right = (64 - width % 64) % 64
        if pad_bottom:
            x = F.pad(x, (0, 0, 0, pad_bottom), "reflect")
        if pad_right:
            x = F.pad(x, (0, pad_right, 0, 0), "reflect")

        # Encoder: keep every stage output as a skip connection.
        skips = []
        feat = x
        for stage in (self.layer1, self.layer2, self.layer3,
                      self.layer4, self.layer5, self.layer6):
            feat = stage(feat)
            skips.append(feat)

        # Decoder: consume skips deepest-first.
        dec = self.dlayer6(skips.pop())
        for stage in (self.dlayer5, self.dlayer4, self.dlayer3, self.dlayer2, self.dlayer1):
            dec = stage(torch.cat([dec, skips.pop()], 1))
        out = self.tail_conv(dec)

        if pad_bottom:
            out = out[:, :, :-pad_bottom, :]
        if pad_right:
            out = out[:, :, :, :-pad_right]
        return out

# Adaptive White Balancing (AWB)

## deep_wb_blocks

In [None]:
"""
 Main blocks of the network
 Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 If you use this code, please cite the following paper:
 Mahmoud Afifi and Michael S Brown. Deep White-Balance Editing. In CVPR, 2020.
"""
__author__ = "Mahmoud Afifi"
__credits__ = ["Mahmoud Afifi"]

class DoubleConvBlock(nn.Module):
    """Two 3x3 convolutions (padding 1), each followed by an in-place ReLU.

    Spatial size is preserved; channels go ``in_channels -> out_channels``
    (the second conv keeps ``out_channels``). Layers live in the Sequential
    ``double_conv`` at indices 0-3, which fixes the state_dict key names.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        stack = []
        for conv_in in (in_channels, out_channels):
            stack.append(nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1))
            stack.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stack)

    def forward(self, x):
        return self.double_conv(x)


class DownBlock(nn.Module):
    """Encoder stage: 2x2 max-pool (halves H and W), then a DoubleConvBlock.

    Sub-modules are stored as ``maxpool_conv[0]`` (pool) and
    ``maxpool_conv[1]`` (convs) to keep state_dict keys stable.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConvBlock(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)


class BridgeDown(nn.Module):
    """Bottleneck entry: 2x2 max-pool, one 3x3 conv (padding 1), in-place ReLU.

    Halves H and W and maps ``in_channels -> out_channels``. The conv sits at
    ``maxpool_conv[1]``, which fixes its state_dict key names.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            nn.MaxPool2d(2),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.maxpool_conv = nn.Sequential(*stages)

    def forward(self, x):
        pooled_features = self.maxpool_conv(x)
        return pooled_features


class BridgeUP(nn.Module):
    """Upscale bottleneck block: conv -> ReLU -> transpose conv.

    A 3x3 convolution (padding 1, channel count unchanged) and an in-place
    ReLU are followed by a kernel-2, stride-2 transposed convolution that
    doubles H and W and maps ``in_channels -> out_channels``. No activation
    is applied after the transposed convolution.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Sequential indices 0 and 2 appear in the state_dict keys; keep them stable.
        self.conv_up = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        )

    def forward(self, x):
        return self.conv_up(x)



class UpBlock(nn.Module):
    """Decoder stage: concat(x2, x1) -> DoubleConvBlock -> 2x transpose conv -> ReLU.

    The two inputs are concatenated on the channel axis and must together
    carry ``2 * in_channels`` channels. The DoubleConvBlock reduces them to
    ``in_channels``, and the kernel-2/stride-2 transposed conv doubles H and W
    while mapping to ``out_channels``.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        doubled = in_channels * 2
        self.conv = DoubleConvBlock(doubled, in_channels)
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x1, x2):
        # x2 (the encoder skip in deepWBNet) goes first; the trained conv
        # weights depend on this channel order.
        stacked = torch.cat((x2, x1), dim=1)
        upsampled = self.up(self.conv(stacked))
        return torch.relu(upsampled)


class OutputBlock(nn.Module):
    """Final decoder stage: concat(x2, x1) -> DoubleConvBlock -> 1x1 conv.

    The inputs must together carry ``2 * in_channels`` channels; the result
    has ``out_channels`` channels and the same spatial size. Sub-modules live
    at ``out_conv[0]`` and ``out_conv[1]`` (state_dict key names).
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        fuse = DoubleConvBlock(in_channels * 2, in_channels)
        project = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.out_conv = nn.Sequential(fuse, project)

    def forward(self, x1, x2):
        # Skip tensor (x2) first, matching the channel order of the weights.
        return self.out_conv(torch.cat((x2, x1), dim=1))

## deep_wb_model

In [None]:
"""
 Constructs network architecture
 Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 If you use this code, please cite the following paper:
 Mahmoud Afifi and Michael S Brown. Deep White-Balance Editing. In CVPR, 2020.
"""
__author__ = "Mahmoud Afifi"
__credits__ = ["Mahmoud Afifi"]

class deepWBNet(nn.Module):
    """AWB branch of Deep White-Balance Editing (Afifi & Brown, CVPR 2020).

    A U-Net with four 2x downsamplings (three DownBlocks plus BridgeDown) and
    matching upsamplings with skip connections. It maps a 3-channel image to a
    3-channel image of the same height and width.

    Attribute names are kept as-is because pretrained weights are loaded into
    this module by ``load_state_dict``.
    """
    def __init__(self):
        super(deepWBNet, self).__init__()
        self.n_channels = 3
        self.encoder_inc = DoubleConvBlock(self.n_channels, 24)
        self.encoder_down1 = DownBlock(24, 48)
        self.encoder_down2 = DownBlock(48, 96)
        self.encoder_down3 = DownBlock(96, 192)
        self.encoder_bridge_down = BridgeDown(192, 384)
        self.awb_decoder_bridge_up = BridgeUP(384, 192)
        self.awb_decoder_up1 = UpBlock(192, 96)
        self.awb_decoder_up2 = UpBlock(96, 48)
        self.awb_decoder_up3 = UpBlock(48, 24)
        self.awb_decoder_out = OutputBlock(24, self.n_channels)

    def forward(self, x):
        """White-balance ``x`` and return a tensor with the same height/width.

        The four 2x downsamplings require H and W to be multiples of 16.
        Otherwise the skip concatenations fail with a size mismatch, so such
        inputs are padded on the bottom/right and the output is cropped back.
        Inputs that are already multiples of 16 are processed exactly as before.
        """
        height, width = x.shape[-2:]
        pad_h = (-height) % 16
        pad_w = (-width) % 16
        padded = bool(pad_h or pad_w)
        if padded:
            # 'reflect' requires pad < size; fall back to 'replicate' for tiny inputs.
            mode = "reflect" if (pad_h < height and pad_w < width) else "replicate"
            x = F.pad(x, (0, pad_w, 0, pad_h), mode)

        x1 = self.encoder_inc(x)
        x2 = self.encoder_down1(x1)
        x3 = self.encoder_down2(x2)
        x4 = self.encoder_down3(x3)
        x5 = self.encoder_bridge_down(x4)
        x_awb = self.awb_decoder_bridge_up(x5)
        x_awb = self.awb_decoder_up1(x_awb, x4)
        x_awb = self.awb_decoder_up2(x_awb, x3)
        x_awb = self.awb_decoder_up3(x_awb, x2)
        awb = self.awb_decoder_out(x_awb, x1)

        if padded:
            awb = awb[..., :height, :width]
        return awb


# Custom fusion net

In [None]:
class GBlock(nn.Module):
    """Gated fusion: a sigmoid gate computed from (x1, x2) scales x3.

    ``x1`` and ``x2`` are concatenated on the channel axis (together they must
    carry ``2 * in_c`` channels). A 1x1 convolution mixes them into ``out_c``
    channels and the result is squashed to (0, 1). That gate is multiplied
    element-wise with ``x3``.
    """
    def __init__(self, in_c, out_c):
        super().__init__()
        self.c = nn.Conv2d(2 * in_c, out_c, 1)
        self.sig = nn.Sigmoid()

    def forward(self, x1, x2, x3):
        gate = self.sig(self.c(torch.cat((x1, x2), dim=1)))
        return gate * x3
class Custom_fusion_net(nn.Module):
    """Dehazing generator: Restormer -> gated fusion -> AWB refinement.

    Pipeline (see ``forward``):
      1. ``restormer`` restores the hazy input.
      2. ``haze_density`` is applied to the hazy input.
      3. ``sobel_UNet`` is applied to the Restormer output.
      4. ``GBlock`` gates the Restormer output with (density, sobel).
      5. ``awb`` (deepWBNet) refines the gated image.

    Pretrained weights for Restormer, HazeDensityMap and deepWBNet are loaded
    at construction time. Sobel_UNet and GBlock start from fresh initialisation.

    Args:
        restormer_weights: checkpoint whose ``'params'`` entry is the Restormer state_dict.
        haze_density_weights: HazeDensityMap state_dict file.
        awb_weights: checkpoint whose ``'state_dict'`` entry is the deepWBNet state_dict.
        map_location: passed to ``torch.load``. The default ``"cpu"`` lets
            checkpoints saved from GPU tensors load on CPU-only runtimes;
            ``load_state_dict`` copies the values into this module's parameters.
    """
    def __init__(self,
                 restormer_weights="/content/drive/Shareddrives/Untitled shared drive/CANT_Haze/Weights/motion_deblurring.pth",
                 haze_density_weights="/content/drive/MyDrive/Copy of HM.pt",
                 awb_weights="/content/drive/Shareddrives/Untitled shared drive/CANT_Haze/Weights/net_awb.pth",
                 map_location="cpu"):
        super(Custom_fusion_net, self).__init__()
        # Sub-modules are created in the original order so parameter
        # registration (and random initialisation order) is unchanged.
        self.restormer = Restormer()
        checkpoint = torch.load(restormer_weights, map_location=map_location)
        self.restormer.load_state_dict(checkpoint['params'])
        self.sobel_UNet = Sobel_UNet()
        self.haze_density = HazeDensityMap()
        self.haze_density.load_state_dict(torch.load(haze_density_weights, map_location=map_location))
        # GBlock(3, 3) concatenates the density map with the 3-channel Sobel
        # output, so HazeDensityMap must produce 3 channels.
        self.GBlock = GBlock(3, 3)
        self.awb = deepWBNet()
        checkpoints = torch.load(awb_weights, map_location=map_location)
        self.awb.load_state_dict(checkpoints['state_dict'])

    def forward(self, input):
        """Return ``(dehazed, hazy_sobel)``.

        ``hazy_sobel`` is the Sobel U-Net output on the Restormer result. It is
        returned so a caller can supervise it separately (the training cell
        compares it to a Sobel edge target of the clean image).
        """
        restored = self.restormer(input)
        density = self.haze_density(input)
        hazy_sobel = self.sobel_UNet(restored)
        gated = self.GBlock(density, hazy_sobel, restored)
        dehazed = self.awb(gated)
        return dehazed, hazy_sobel


# Discriminator

In [None]:
class Discriminator(nn.Module):
    """Image-level GAN discriminator returning one probability per sample.

    Structure: a 3x3 conv + LeakyReLU stem, then seven (conv, BatchNorm,
    LeakyReLU) stages. The stages alternate stride-2 downsampling with
    channel-widening stride-1 convs, up to 512 channels. The head is global
    average pooling, a 1x1 conv to 1024, LeakyReLU and a 1x1 conv to a single
    channel, followed by a sigmoid. Layer order (27 modules in ``self.net``)
    determines the state_dict keys.
    """
    def __init__(self):
        super().__init__()
        modules = [nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.LeakyReLU(0.2)]
        # (in_channels, out_channels, stride) for each conv/BN/LeakyReLU stage.
        stage_specs = (
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
            (256, 512, 1),
            (512, 512, 2),
        )
        for c_in, c_out, stride in stage_specs:
            modules.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ])
        modules.extend([
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1),
        ])
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        # The head yields (batch, 1, 1, 1); flatten to (batch,) before the sigmoid.
        logits = self.net(x).view(x.size(0))
        return torch.sigmoid(logits)

# Testing Without Resizing

In [None]:
# # VAL_HAZY_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/valid_dense/HAZY/"
# # VAL_GT_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/valid_dense/GT/"

# # VAL_HAZY_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/NH-HAZE/Test_Hazy/"
# # VAL_GT_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/NH-HAZE/Test_GT/"

# VAL_HAZY_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/NH-HAZE/Test_Hazy/"
# VAL_GT_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/NH-HAZE/Test_GT/"

# #----------- Gives Cuda out of memory -----------
# RESIZE = False 

# VAL_BATCH_SIZE = 1
# NUM_WORKERS = 0


# # --- output picture and check point --- #
# #G_model_save_dir = "/content/drive/MyDrive/Graduation Project/CANT_Haze/Weights/Generator_NH_Restormer_Twice_HM_GBlock_AWB_Sobel_Best.pth"
# G_model_save_dir = "/content/drive/Shareddrives/Untitled shared drive/CANT_Haze/Weights/Generator_NH_Restormer_Twice_HM_GBlock_AWB_Sobel_Best.pth"

# # --- Gpu device --- #
# device_ids = [Id for Id in range(torch.cuda.device_count())]
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# # --- Define the network --- #
# MyEnsembleNet = Custom_fusion_net().float()
# print('MyEnsembleNet parameters:', sum(param.numel() for param in MyEnsembleNet.parameters()))

# # --- Load testing data --- #
# val_data = CustomDataLoader(HAZY_path = VAL_HAZY_IMAGES_PATH,
#                             GT_path = VAL_GT_IMAGES_PATH,
#                             resize = RESIZE)

# val_loader = DataLoader(val_data, 
#                         batch_size = VAL_BATCH_SIZE, 
#                         num_workers = NUM_WORKERS)
# MyEnsembleNet = MyEnsembleNet.to(device)

# # --- Load the network weight --- #
# try:
#     MyEnsembleNet.load_state_dict(torch.load(G_model_save_dir))
#     print('--- weight loaded ---')
# except:
#     print('--- no weight loaded ---')
# # --- Start training --- #
# print("-----Testing-----")     
# with torch.inference_mode():
#     psnr_list = []
#     ssim_list = []
#     MyEnsembleNet.eval()
#     for batch_idx, (hazy, clean, data_name) in enumerate(val_loader): 
#         clean = clean.to(device)
#         hazy = hazy.to(device)
#         frame_out, _ = MyEnsembleNet(hazy)
#         psnr_list.extend(to_psnr(frame_out, clean))
#         ssim_list.extend(to_ssim_skimage(frame_out, clean))
#         if not os.path.exists('test/'):
#             os.makedirs('test/')
#         imwrite(frame_out, 'test/' + ''.join(data_name) + '.png', range=(0, 1))

# avr_psnr = sum(psnr_list) / len(psnr_list)
# avr_ssim = sum(ssim_list) / len(ssim_list)
# print('PSNR: ', avr_psnr, 'SSIM: ', avr_ssim)

# Testing With Resizing

In [None]:
# VAL_HAZY_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/valid_dense/HAZY/"
# VAL_GT_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/valid_dense/GT/"

VAL_HAZY_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/NH-HAZE/Test_Hazy/"
VAL_GT_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/NH-HAZE/Test_GT/"

# Resizing to TEST_IMAGE_SIZE avoids the CUDA out-of-memory error hit at full resolution.
RESIZE = True

TEST_IMAGE_SIZE = (768, 1024)  # won't be used in the data loader if RESIZE is set to False

VAL_BATCH_SIZE = 1
NUM_WORKERS = 2
OUTPUT_DIR = 'test/'


# --- output picture and check point --- #
G_model_save_dir = "/content/drive/MyDrive/Graduation Project/CANT_Haze/Weights/Generator_NH_Restormer_Twice_HM_GBlock_AWB_Sobel_Best.pth"
# --- Gpu device --- #
device_ids = [Id for Id in range(torch.cuda.device_count())]
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Define the network --- #
MyEnsembleNet = Custom_fusion_net().float()
print('MyEnsembleNet parameters:', sum(param.numel() for param in MyEnsembleNet.parameters()))

# --- Load testing data --- #
val_data = CustomDataLoader(HAZY_path = VAL_HAZY_IMAGES_PATH,
                            GT_path = VAL_GT_IMAGES_PATH,
                            image_size = TEST_IMAGE_SIZE,
                            resize = RESIZE)

val_loader = DataLoader(val_data,
                        batch_size = VAL_BATCH_SIZE,
                        num_workers = NUM_WORKERS)
MyEnsembleNet = MyEnsembleNet.to(device)

# --- Load the network weight --- #
try:
    # map_location lets a GPU-saved checkpoint load on a CPU-only runtime.
    MyEnsembleNet.load_state_dict(torch.load(G_model_save_dir, map_location=device))
    print('--- weight loaded ---')
except (FileNotFoundError, RuntimeError) as err:
    # Best effort: keep evaluating, but show why loading failed, since the
    # metrics below then describe a model without the fine-tuned weights.
    print('--- no weight loaded ---')
    print(f'    reason: {err}')

# --- Start testing --- #
os.makedirs(OUTPUT_DIR, exist_ok=True)
print("-----Testing-----")
psnr_list = []
ssim_list = []
MyEnsembleNet.eval()
with torch.inference_mode():
    for hazy, clean, data_name in val_loader:
        clean = clean.to(device)
        hazy = hazy.to(device)
        frame_out, _ = MyEnsembleNet(hazy)
        psnr_list.extend(to_psnr(frame_out, clean))
        ssim_list.extend(to_ssim_skimage(frame_out, clean))
        # The old `range=` keyword is deprecated in torchvision (renamed
        # `value_range`), and it only takes effect with normalize=True, so
        # omitting it does not change the saved image.
        imwrite(frame_out, os.path.join(OUTPUT_DIR, ''.join(data_name) + '.png'))

if not psnr_list:
    raise RuntimeError(f'No validation images were loaded from {VAL_HAZY_IMAGES_PATH}')
avr_psnr = sum(psnr_list) / len(psnr_list)
avr_ssim = sum(ssim_list) / len(ssim_list)
print('PSNR: ', avr_psnr, 'SSIM: ', avr_ssim)

MyEnsembleNet parameters: 31417322
--- weight loaded ---
-----Testing-----
PSNR:  20.746215161663343 SSIM:  0.753701651096344


# Testing With Cropping & Fusing

In [None]:
# # VAL_HAZY_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/valid_dense/HAZY/"
# # VAL_GT_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/valid_dense/GT/"

# VAL_HAZY_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/NH-HAZE/Test_Hazy/"
# VAL_GT_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/NH-HAZE/Test_GT/"

# VAL_BATCH_SIZE = 1
# NUM_WORKERS = 0

# # --- output picture and check point --- #
# G_model_save_dir = "/content/drive/MyDrive/Graduation Project/CANT_Haze/Weights/Generator_NH_Restormer_Twice_HM_GBlock_AWB_Sobel_Best.pth"
# # --- Gpu device --- #
# device_ids = [Id for Id in range(torch.cuda.device_count())]
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# # --- Define the network --- #
# MyEnsembleNet = Custom_fusion_net().float()
# print('MyEnsembleNet parameters:', sum(param.numel() for param in MyEnsembleNet.parameters()))

# # --- Load testing data --- #
# val_data = dehaze_test_dataset(VAL_HAZY_IMAGES_PATH, VAL_GT_IMAGES_PATH)
# val_loader = DataLoader(dataset=val_data, batch_size=VAL_BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# MyEnsembleNet = MyEnsembleNet.to(device)

# # --- Load the network weight --- #
# try:
#     MyEnsembleNet.load_state_dict(torch.load(G_model_save_dir))
#     print('--- weight loaded ---')
# except:
#     print('--- no weight loaded ---')
# # --- Start training --- #
# print("-----Testing-----")     
# with torch.inference_mode():
#     psnr_list = []
#     ssim_list = []
#     MyEnsembleNet.eval()
#     for batch_idx, (hazy_up,hazy_down,name,clean) in enumerate(val_loader):
#         hazy_up = hazy_up.to(device)
#         hazy_down = hazy_down.to(device)
#         clean = clean.to(device)
#         frame_out_up, _ = MyEnsembleNet(hazy_up)
#         frame_out_down, _ = MyEnsembleNet(hazy_down)
#         #----------- With Cuda out of memory -----------
#         frame_out = (torch.cat([frame_out_up[:, :, :600, :].permute(0, 2, 3, 1), frame_out_down[:, :, 552:, :].permute(0, 2, 3, 1)],1)).permute(0, 3, 1, 2)
        
#         #----------- Without Cuda out of memory -----------
#         # frame_out = (torch.cat([frame_out_up[:, :, :600, :].permute(0, 2, 3, 1), frame_out_down[:, :, 168:, :].permute(0, 2, 3, 1)],1)).permute(0, 3, 1, 2)
        
#         psnr_list.extend(to_psnr(frame_out, clean))
#         ssim_list.extend(to_ssim_skimage(frame_out, clean))
#         if not os.path.exists('output/'):
#             os.makedirs('output/')
#         imwrite(frame_out, 'output/' + ''.join(name) + '.png', range=(0, 1))
# avr_psnr = sum(psnr_list) / len(psnr_list)
# avr_ssim = sum(ssim_list) / len(ssim_list)
# print('PSNR: ', avr_psnr, 'SSIM: ', avr_ssim)

# Training

In [None]:
# # --- train --- #
# train_epoch = 100 # Currently at 1700 epochs
# best_psnr = 20.75
# # TRAIN_HAZY_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/train_dense/haze/"
# # TRAIN_GT_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/train_dense/GT/"
# # VAL_HAZY_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/valid_dense/HAZY/"
# # VAL_GT_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/valid_dense/GT/"
# TRAIN_HAZY_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/NH-HAZE/train_NH/haze/"
# TRAIN_GT_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/NH-HAZE/train_NH/clear_images/"
# VAL_HAZY_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/NH-HAZE/Test_Hazy/"
# VAL_GT_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/NH-HAZE/Test_GT/"
# IMAGE_SIZE = (256,256)
# TRAIN_BATCH_SIZE = 1
# VAL_BATCH_SIZE = 1
# NUM_WORKERS = 0
# SHUFFLE = True

# # --- output picture and check point --- #
# G_model_save_dir = "/content/drive/MyDrive/Graduation Project/CANT_Haze/Weights/Generator_NH_Restormer_Twice_HM_GBlock_AWB_Sobel.pth"
# D_model_save_dir = "/content/drive/MyDrive/Graduation Project/CANT_Haze/Weights/Discriminator_NH_Restormer_Twice_HM_GBlock_AWB_Sobel.pth"
# G_best_model_save_dir = "/content/drive/MyDrive/Graduation Project/CANT_Haze/Weights/Generator_NH_Restormer_Twice_HM_GBlock_AWB_Sobel_Best.pth"
# D_best_model_save_dir = "/content/drive/MyDrive/Graduation Project/CANT_Haze/Weights/Discriminator_NH_Restormer_Twice_HM_GBlock_AWB_Sobel_Best.pth"

# # --- Gpu device --- #
# device_ids = [Id for Id in range(torch.cuda.device_count())]
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# # --- Define the network --- #
# MyEnsembleNet = Custom_fusion_net().float()
# DNet = Discriminator()

# print('MyEnsembleNet parameters:', sum(param.numel() for param in MyEnsembleNet.parameters()))

# # --- Build optimizer --- #
# G_optimizer = torch.optim.Adam(MyEnsembleNet.parameters(), lr=0.0001)
# D_optim = torch.optim.Adam(DNet.parameters(), lr=0.0001)

# # --- Load training data --- #
# dataset = custom_dehaze_train_dataset(HAZY_path = TRAIN_HAZY_IMAGES_PATH, GT_path = TRAIN_GT_IMAGES_PATH, Image_Size = IMAGE_SIZE,is_train = True)
# train_loader = DataLoader(dataset=dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)


# # --- Load testing data --- #
# val_data = dehaze_test_dataset(VAL_HAZY_IMAGES_PATH, VAL_GT_IMAGES_PATH)
# val_loader = DataLoader(dataset=val_data, batch_size=VAL_BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# MyEnsembleNet = MyEnsembleNet.to(device)
# DNet = DNet.to(device)
# # --- Load the network weight --- #
# try:
#     MyEnsembleNet.load_state_dict(torch.load(G_model_save_dir))
#     DNet.load_state_dict(torch.load(D_model_save_dir))
#     print('--- weight loaded ---')
# except:
#     print('--- no weight loaded ---')

# # --- Define the perceptual loss network --- #
# backbone_model = vgg16(weights = VGG16_Weights.DEFAULT)

# backbone_model = backbone_model.features[:16].to(device)
# for param in backbone_model.parameters():
#      param.requires_grad = False

# loss_network = LossNetwork(backbone_model)
# loss_network.eval()
# msssim_loss = msssim

# # --- Start training --- #
# for epoch in range(train_epoch):
#     psnr_list = []
#     ssim_list = []
#     MyEnsembleNet.train()
#     DNet.train()
#     avg_loss = 0
#     print("We are in epoch: " + str(epoch+1))

#     for batch_idx, (hazy, clean, clean_sobel) in enumerate(train_loader): 
#             hazy = hazy.to(device)
#             clean = clean.to(device)
#             clean_sobel = clean_sobel.to(device)
#             output, hazy_sobel = MyEnsembleNet(hazy.float())
#             real_out = DNet(clean)
#             fake_out = DNet(output)
#             real_loss = F.binary_cross_entropy(real_out, torch.ones(real_out.size()).to(device))
#             fake_loss = F.binary_cross_entropy(fake_out, torch.zeros(fake_out.size()).to(device))
#             D_loss = (real_loss + fake_loss) / 2
#             DNet.zero_grad()
#             D_loss.backward(retain_graph=True)
#             smooth_loss_l1 = F.smooth_l1_loss(output, clean)
#             perceptual_loss = loss_network(output, clean)
#             msssim_loss_ = 1 - msssim_loss(output, clean, normalize=True)
#             calc_psnr = to_psnr(output, clean)
#             calc_ssim = to_ssim_skimage(output, clean)
#             sobel_l1_loss = F.smooth_l1_loss(hazy_sobel, clean_sobel.float())
#             sobel_msssim_loss = 1 - msssim_loss(hazy_sobel, clean_sobel.float(), normalize=True)
#             total_loss = (smooth_loss_l1 + sobel_l1_loss)/2 + 0.05 * perceptual_loss +  (msssim_loss_ + sobel_msssim_loss)/2
#             avg_loss += total_loss.item()
#             MyEnsembleNet.zero_grad()
#             total_loss.backward()
#             G_optimizer.step()
#             D_optim.step()
#             psnr_list.extend(calc_psnr)
#             ssim_list.extend(calc_ssim)
    
#     avr_psnr = sum(psnr_list) / len(psnr_list)
#     avr_ssim = sum(ssim_list) / len(ssim_list)
#     print('AVG PSNR: ', avr_psnr, 'AVG SSIM: ', avr_ssim, 'AVG Loss: ', avg_loss/len(psnr_list))

#     if (epoch+1) % 5 == 0: 
#       print("-----Testing-----")     
#       with torch.inference_mode():
#           psnr_list = []
#           ssim_list = []
#           MyEnsembleNet.eval()
#           for batch_idx, (hazy_up,hazy_down,name,clean) in enumerate(val_loader):
#               hazy_up = hazy_up.to(device)
#               hazy_down = hazy_down.to(device)
#               clean = clean.to(device)
#               frame_out_up, _ = MyEnsembleNet(hazy_up)
#               frame_out_down, _ = MyEnsembleNet(hazy_down)
#               frame_out = (torch.cat([frame_out_up[:, :, 0:600, :].permute(0, 2, 3, 1), frame_out_down[:, :, 552:, :].permute(0, 2, 3, 1)],1)).permute(0, 3, 1, 2)
#               psnr_list.extend(to_psnr(frame_out, clean))
#               ssim_list.extend(to_ssim_skimage(frame_out, clean))
#               imwrite(frame_out, '/content/drive/MyDrive/Graduation Project/CANT_Haze/Restormer Twice & AWB Results/' + ''.join(name) + '.png', range=(0, 1))

#       avr_psnr = sum(psnr_list) / len(psnr_list)
#       avr_ssim = sum(ssim_list) / len(ssim_list)
#       print('PSNR: ', avr_psnr, 'SSIM: ', avr_ssim)    
#       torch.save(MyEnsembleNet.state_dict(), G_model_save_dir)
#       torch.save(DNet.state_dict(), D_model_save_dir)
#       print("-----Model Saved-----")
#       if(avr_psnr > best_psnr):
#           best_psnr = avr_psnr
#           torch.save(MyEnsembleNet.state_dict(), G_best_model_save_dir)
#           torch.save(DNet.state_dict(), D_best_model_save_dir)
#           print("-----Best Model Saved-----")