In [None]:
import glob
import math as m
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch as t
import torch.nn as nn
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from torch.utils.data import Dataset, DataLoader

In [None]:
 
class HazeDataset(Dataset):
    """Dataset of (hazy, ground-truth) image pairs discovered on disk.

    A hazy image is paired with a ground-truth image when the ground-truth
    file's base name (text before the first '.') occurs in the hazy file's
    name. One ground-truth image may therefore be paired with several hazy
    images.

    Args:
        gt_images_path: Directory prefix of the ground-truth '.jpg' files
            (concatenated directly with '*.jpg', so it should end with a
            path separator).
        hazy_images_path: Directory prefix of the hazy '.jpg' files.
        transform: Optional callable applied to each image array after it
            has been scaled to [0, 1].
    """

    def __init__(self, gt_images_path, hazy_images_path, transform=None):
        # Sorted so the pair order does not depend on filesystem listing order.
        gt_image_paths = sorted(glob.glob(gt_images_path + '*.jpg'))
        hazy_image_paths = sorted(glob.glob(hazy_images_path + '*.jpg'))

        gt_images = []
        hazy_images = []

        for gt_image in gt_image_paths:
            img_name = os.path.basename(gt_image).split('.')[0]
            for hazy_image in hazy_image_paths:
                # Match on the file name only, so directory components can
                # never produce a spurious pairing.
                if img_name in os.path.basename(hazy_image):
                    gt_images.append(gt_image)
                    hazy_images.append(hazy_image)

        self.gt_image_paths = gt_images
        self.hazy_image_paths = hazy_images
        self.transform = transform

    def __getitem__(self, index):
        """Return the (hazy_image, gt_image) pair at `index`, scaled to [0, 1]."""
        with Image.open(self.gt_image_paths[index]) as img:
            gt_image = np.array(img, dtype=np.float32)
        with Image.open(self.hazy_image_paths[index]) as img:
            hazy_image = np.array(img, dtype=np.float32)

        gt_image /= 255
        hazy_image /= 255

        if self.transform:
            # NOTE: the transform is applied to each image independently, so a
            # random transform will not be spatially aligned across the pair.
            gt_image = self.transform(gt_image)
            hazy_image = self.transform(hazy_image)

        return hazy_image, gt_image

    def __len__(self):
        """Number of (hazy, ground-truth) pairs."""
        return len(self.gt_image_paths)


In [None]:
def random_crop(gt_image, hazy_image, target_shape = (224, 224)):
    """Cut the same random window out of a ground-truth / hazy image pair.

    Args:
        gt_image: Array indexed as (row, column, ...); its first two
            dimensions define the valid crop offsets.
        hazy_image: Array cropped with the same window as `gt_image`.
        target_shape: (height, width) of the crop.

    Returns:
        Tuple ``(hazy_crop, gt_crop)`` -- note the order is swapped relative
        to the arguments.

    Raises:
        ValueError: If `gt_image` is smaller than `target_shape`.
    """
    crop_h, crop_w = target_shape
    img_h, img_w = gt_image.shape[0], gt_image.shape[1]

    if crop_h > img_h or crop_w > img_w:
        raise ValueError(
            f'Cannot crop {crop_h}x{crop_w} from an image of size {img_h}x{img_w}'
        )

    # randint's upper bound is exclusive; the +1 makes the last valid offset
    # reachable and lets an image of exactly the target size pass through.
    x = np.random.randint(0, img_w - crop_w + 1)
    y = np.random.randint(0, img_h - crop_h + 1)

    gt_image = gt_image[y : y + crop_h, x : x + crop_w]
    hazy_image = hazy_image[y : y + crop_h, x : x + crop_w]

    return hazy_image, gt_image


# load image 
def load_image(hazy_image_path, gt_image_path):
    """Load one hazy / ground-truth pair as randomly cropped float tensors.

    Both files are decoded as 3-channel RGB, cropped with the same random
    window (see `random_crop`, default crop size) and scaled to [0, 1].

    Args:
        hazy_image_path: Path of the hazy image file.
        gt_image_path: Path of the matching ground-truth image file.

    Returns:
        Tuple ``(hazy, gt)`` of float32 tensors laid out channel-first
        (C, H, W).
    """
    # convert('RGB') guarantees an (H, W, 3) array even for grayscale or CMYK
    # files; the context manager closes the file handle deterministically.
    with Image.open(gt_image_path) as img:
        gt_image = np.array(img.convert('RGB'), dtype=np.float32)
    with Image.open(hazy_image_path) as img:
        hazy_image = np.array(img.convert('RGB'), dtype=np.float32)

    # Random Crop the image as suggested in paper
    hazy_image, gt_image = random_crop(gt_image, hazy_image)
    gt_image /= 255
    hazy_image /= 255

    gt_img_tensor = t.from_numpy(gt_image)
    hazy_img_tensor = t.from_numpy(hazy_image)

    # (H, W, C) -> (C, H, W)
    return hazy_img_tensor.permute(2, 0, 1), gt_img_tensor.permute(2, 0, 1)


# Split dataset into train and validation splits
def get_data_splits(gt_images_path, hazy_images_path, train_fraction=0.9):
    """Pair hazy and ground-truth files, shuffle the pairs and split them.

    A hazy file is paired with a ground-truth file when the ground-truth
    base name (text before the first '.') occurs in the hazy file's name.
    Pairs are shuffled with NumPy's global RNG, so seed it beforehand for a
    reproducible split.

    Args:
        gt_images_path: Directory prefix of the ground-truth '.jpg' files
            (concatenated directly with '*.jpg').
        hazy_images_path: Directory prefix of the hazy '.jpg' files.
        train_fraction: Fraction of the pairs assigned to the training split.

    Returns:
        Dict with the keys 'train_gt', 'train_hazy', 'val_gt' and 'val_hazy',
        each a list of file paths. All lists are empty when no pair is found.
    """
    # Sorted so the result depends only on the RNG state, not on the
    # filesystem listing order.
    gt_image_paths = sorted(glob.glob(gt_images_path + '*.jpg'))
    hazy_image_paths = sorted(glob.glob(hazy_images_path + '*.jpg'))

    pairs = []
    for gt_image in gt_image_paths:
        img_name = os.path.basename(gt_image).split('.')[0]
        for hazy_image in hazy_image_paths:
            # Match on the file name only so directory components cannot
            # create spurious pairs.
            if img_name in os.path.basename(hazy_image):
                pairs.append((gt_image, hazy_image))

    np.random.shuffle(pairs)

    # Comprehensions (rather than zip(*pairs)) stay valid for an empty list.
    gt_images = [gt for gt, _ in pairs]
    hazy_images = [hazy for _, hazy in pairs]

    split_at = int(len(pairs) * train_fraction)

    return {
        'train_gt': gt_images[:split_at],
        'train_hazy': hazy_images[:split_at],
        'val_gt': gt_images[split_at:],
        'val_hazy': hazy_images[split_at:]
    }


# Custom Dataset with Hazy and Ground Truth Images
class CustomDataset(Dataset):
    """Dataset that lazily loads hazy / ground-truth pairs from path lists.

    The two lists are indexed in parallel: entry ``i`` of
    `hazy_image_paths` belongs with entry ``i`` of `gt_image_paths`.
    """

    def __init__(self, hazy_image_paths, gt_image_paths):
        self.gt_image_paths = gt_image_paths
        self.hazy_image_paths = hazy_image_paths

    def __len__(self):
        # The ground-truth list defines the dataset size.
        return len(self.gt_image_paths)

    def __getitem__(self, index):
        """Load and return the pair stored at `index` via `load_image`."""
        hazy_path = self.hazy_image_paths[index]
        gt_path = self.gt_image_paths[index]
        return load_image(hazy_path, gt_path)

In [None]:
def conv2(in_c, out_c):
    """Build a 3x3 convolution -> batch norm -> ReLU block.

    The convolution uses padding 1, so the spatial size is preserved.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.

    Returns:
        An `nn.Sequential` holding the three layers.
    """
    layers = [
        nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

def crop_img(tensor, target_tensor):
    """Crop `tensor` to the spatial size of `target_tensor`.

    The crop keeps the top-left corner (rows and columns starting at 0);
    batch and channel dimensions are left untouched.

    Args:
        tensor: 4-D tensor to crop; its last two dimensions are cropped.
        target_tensor: 4-D tensor whose dimensions 2 and 3 give the target
            height and width.

    Returns:
        A view of `tensor` with at most the target height and width.
    """
    size = target_tensor.size()
    # Height and width are taken separately so non-square targets work too.
    target_h, target_w = size[2], size[3]
    return tensor[:, :, :target_h, :target_w]


class DUnet(nn.Module):
    """U-Net style encoder/decoder mapping a 3-channel image to 3 channels.

    The encoder applies four 2x2 max-pool steps; after each of the first
    three, the pooled features are concatenated with a top-left crop of an
    earlier, higher-resolution feature map (see `crop_img`). The decoder
    applies four stride-2 transposed convolutions, each followed by a
    concatenation with an encoder feature map and a `conv2` block.

    NOTE(review): `conv7` is created but never used in `forward`, and
    `conv6` is applied twice in a row. Both are kept as-is so the
    parameter set (and any saved state dict) stays unchanged.
    """

    def __init__(self):
        super().__init__()

        self.maxP_2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder blocks (attribute names and creation order are part of the
        # state-dict layout, so they are kept exactly).
        self.conv1 = conv2(3, 64)
        self.conv2 = conv2(64, 64)
        self.conv3 = conv2(64, 128)
        self.conv4 = conv2(128, 128)
        self.conv5 = conv2(128, 256)
        self.conv6 = conv2(256, 256)
        self.conv7 = conv2(256, 512)
        self.conv8 = conv2(512, 512)
        self.conv9 = conv2(512, 1024)

        # Decoder: each stage doubles the resolution, then fuses the skip.
        self.up_trans1 = self._upsample(1024, 512)
        self.up_conv1 = conv2(1024, 512)

        self.up_trans2 = self._upsample(512, 256)
        self.up_conv2 = conv2(512, 256)

        self.up_trans3 = self._upsample(256, 128)
        self.up_conv3 = conv2(256, 128)

        self.up_trans4 = self._upsample(128, 64)
        self.up_conv4 = conv2(128, 64)

        # 1x1 projection back to 3 channels.
        self.out = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=1)

    @staticmethod
    def _upsample(in_channels, out_channels):
        """Transposed convolution (kernel 4, stride 2, padding 1) that doubles H and W."""
        return nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, input):
        """Run the network on a (N, 3, H, W) batch and return a (N, 3, ...) map."""

        # ---------------- Encoding ----------------
        enc1 = self.conv1(input)
        enc2 = self.conv2(enc1)
        pooled1 = self.maxP_2(enc2)

        enc3 = self.conv3(pooled1)
        skip1 = crop_img(enc1, pooled1)
        enc4 = self.conv4(t.cat([pooled1, skip1], dim=1))
        pooled2 = self.maxP_2(enc4)

        enc5 = self.conv5(pooled2)
        skip2 = crop_img(enc3, pooled2)
        enc6 = self.conv6(t.cat([pooled2, skip2], dim=1))
        # conv6 is deliberately reused here to mirror the existing weights.
        enc7 = self.conv6(enc6)
        pooled3 = self.maxP_2(enc7)

        skip3 = crop_img(enc5, pooled3)
        enc8 = self.conv8(t.cat([pooled3, skip3], dim=1))
        enc9 = self.conv9(enc8)
        features = self.maxP_2(enc9)

        # ---------------- Decoding ----------------
        decoder_stages = (
            (self.up_trans1, self.up_conv1, enc8),
            (self.up_trans2, self.up_conv2, enc7),
            (self.up_trans3, self.up_conv3, enc4),
            (self.up_trans4, self.up_conv4, enc2),
        )
        for upsample, fuse, encoder_map in decoder_stages:
            features = upsample(features)
            skip = crop_img(encoder_map, features)
            features = fuse(t.cat([features, skip], dim=1))

        return self.out(features)

In [None]:
from tqdm import tqdm 

# Run on the GPU when one is available.
device = t.device('cuda' if t.cuda.is_available() else 'cpu')

# Dataset locations (each prefix is concatenated with '*.jpg').
gt_path = '/kaggle/input/dehaze/clear_images/'
hazy_path = '/kaggle/input/dehaze/haze/'

# Hyper-parameters
batch_size = 32
n_epochs = 10
lr = 0.0001

# No torchvision transform pipeline is built here: `transforms` was never
# imported (NameError) and the result was unused. Random cropping and tensor
# conversion are performed by `load_image`.

data_splits = get_data_splits(gt_path, hazy_path)
train_dataset = CustomDataset(data_splits['train_hazy'], data_splits['train_gt'])
val_dataset = CustomDataset(data_splits['val_hazy'], data_splits['val_gt'])

# Training and Validation Data Loaders

train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
val_dataloader = DataLoader(val_dataset, batch_size = batch_size, shuffle = True)


def show_image(hazy_image, gt_image, predicted_image):
    """Plot a hazy input, its ground truth and the prediction side by side.

    Each argument is a channel-first (C, H, W) tensor; it is moved to the
    CPU and permuted to (H, W, C) before being passed to `plt.imshow`.
    """
    plt.figure(figsize=(15, 15))

    images = [
        hazy_image.cpu().permute(1, 2, 0).numpy(),
        gt_image.cpu().permute(1, 2, 0).numpy(),
        # The prediction may still be attached to the autograd graph.
        predicted_image.detach().cpu().permute(1, 2, 0).numpy(),
    ]
    captions = ('Hazy Image', 'Ground Truth Image', 'Predicted')

    for position, (caption, image) in enumerate(zip(captions, images), start=1):
        plt.subplot(1, 3, position)
        plt.title(caption)
        plt.imshow(image)
        plt.axis('off')

    plt.show()


def init_weights(m):
    """Initialise convolution layers; intended for use with `Module.apply`.

    `nn.Conv2d` and `nn.ConvTranspose2d` modules get weights drawn from
    N(0, 0.008^2) and, when they have a bias, a constant bias of 0.01.
    Every other module type is left untouched.

    Args:
        m: The module to (possibly) initialise.
    """
    if type(m) in (nn.Conv2d, nn.ConvTranspose2d):
        nn.init.normal_(m.weight, mean=0.0, std=0.008)
        # Layers created with bias=False have m.bias set to None.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)
        
# Model, optimiser and loss
net = DUnet().to(device)
net.apply(init_weights)
opt = t.optim.Adam(net.parameters(), lr = lr)
criterion = nn.MSELoss()


for epoch in range(n_epochs):
    total_train_loss = 0
    total_val_loss = 0
    ssim_score = 0.0
    psnr_score = 0.0
    n_val_images = 0  # number of images the metric sums were accumulated over
    start_time = time.time()

    print(f'Epoch {epoch + 1} started...')

    # ---------------- Training ----------------
    net.train()
    for (hazy_images, gt_images) in tqdm(train_dataloader):
        hazy_images = hazy_images.to(device)
        gt_images = gt_images.to(device)

        # The network predicts a per-pixel map k; the clean image is
        # reconstructed as k * hazy - k + 1.
        k = net(hazy_images)
        outputs = k*hazy_images - k + 1.0

        train_loss = criterion(outputs, gt_images)
        opt.zero_grad()
        train_loss.backward()
        opt.step()
        total_train_loss += train_loss.item()

    # ---------------- Validation ----------------
    # eval() makes BatchNorm use its running statistics; no_grad() disables
    # gradient tracking for the whole validation pass.
    net.eval()
    with t.no_grad():
        for (hazy_images, gt_images) in val_dataloader:
            hazy_images = hazy_images.to(device)
            gt_images = gt_images.to(device)
            k = net(hazy_images)
            outputs = k*hazy_images - k + 1.0

            # Visualise a sample every fifth epoch.
            if epoch % 5 == 0:
                show_image(hazy_images[0], gt_images[0], outputs[0])

            val_loss = criterion(outputs, gt_images)
            total_val_loss += val_loss.item()

            outputs_np = outputs.cpu().numpy()
            gt_np = gt_images.cpu().numpy()

            for j in range(outputs_np.shape[0]):
                # SSIM is computed per channel on matching 2-D planes and
                # averaged, which avoids depending on a version-specific
                # multichannel keyword of skimage.
                ssim_score += float(np.mean([
                    ssim(gt_np[j, c], outputs_np[j, c], data_range=1)
                    for c in range(outputs_np.shape[1])
                ]))
                psnr_score += psnr(gt_np[j], outputs_np[j], data_range=1)
                n_val_images += 1

    # Average over the images actually seen (the last batch may be smaller);
    # max(..., 1) avoids a division by zero for an empty validation set.
    metric_count = max(n_val_images, 1)
    print(f'SSMI_score: {ssim_score / metric_count}')
    print(f'PSNR_score: {psnr_score / metric_count}')

    print(f'Total train loss: {total_train_loss}')
    print(f'Total validation loss: {total_val_loss}')

    end_time = time.time()

    print(f'Epoch {epoch + 1} ended, time taken: {end_time - start_time}s')


# Persist the trained weights.
t.save(net.state_dict(), 'state_dict_model.pt')