In [6]:
import numpy as np
import os
import time
import torch
import torch.nn as nn


from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

In [7]:
class Inc(nn.Module):
    """Inception-style block with four parallel branches.

    Branches (every conv is followed by LeakyReLU and emits ``filters`` channels):
      1. 1x1 conv -> 3x3 conv
      2. 1x1 conv -> 5x5 conv
      3. 3x3 max-pool (stride 1, padding 1) -> 1x1 conv
      4. 1x1 conv
    Every branch keeps the input's spatial size. The branch outputs are
    concatenated along dim 1, giving ``4 * filters`` output channels.

    Args:
        in_channels: number of channels of the input tensor.
        filters: number of channels produced by each branch.
    """

    def __init__(self, in_channels, filters):
        super().__init__()

        def conv_act(c_in, k):
            # Stride-1 conv with "same" padding for odd k, followed by LeakyReLU.
            return [
                nn.Conv2d(in_channels=c_in, out_channels=filters, kernel_size=(k, k),
                          stride=(1, 1), dilation=1, padding=(k - 1) // 2),
                nn.LeakyReLU(),
            ]

        # Attribute names and Sequential indices are kept exactly as before so
        # saved state_dict keys (e.g. "branch1.0.weight") keep matching.
        self.branch1 = nn.Sequential(*conv_act(in_channels, 1), *conv_act(filters, 3))
        self.branch2 = nn.Sequential(*conv_act(in_channels, 1), *conv_act(filters, 5))
        self.branch3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=(3, 3), stride=(1, 1), padding=1),
            *conv_act(in_channels, 1),
        )
        self.branch4 = nn.Sequential(*conv_act(in_channels, 1))

    def forward(self, input):
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(input) for branch in branches], dim=1)
class Flatten(nn.Module):
    """Collapse all dimensions after the batch dimension: (N, ...) -> (N, -1)."""

    def forward(self, inp):
        # reshape() matches view() whenever view() works, but also accepts
        # non-contiguous inputs (copying only when it has to) instead of raising.
        return inp.reshape(inp.size(0), -1)
class CAM(nn.Module):
    """Channel-attention gate.

    Each channel's global average is passed through a bottleneck MLP. The
    MLP is Linear -> Softsign -> Linear -> Softsign, so gate values lie in
    (-1, 1). The input is multiplied channel-wise by the resulting gate.

    Args:
        in_channels: number of channels of the (N, C, H, W) input.
        reduction_ratio: bottleneck width divisor (hidden = C // ratio).
    """

    def __init__(self, in_channels, reduction_ratio):
        super().__init__()
        hidden = in_channels // reduction_ratio
        # Layer order/indices are kept so checkpoint keys (module.2.*, module.4.*) match.
        self.module = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),    # (N, C, H, W) -> (N, C, 1, 1)
            Flatten(),                       # -> (N, C)
            nn.Linear(in_channels, hidden),
            nn.Softsign(),
            nn.Linear(hidden, in_channels),
            nn.Softsign(),
        )

    def forward(self, input):
        # (N, C) gate -> (N, C, 1, 1); broadcasting stretches it over H and W.
        gate = self.module(input)[:, :, None, None]
        return input * gate

class DICAM(nn.Module):
    """Image-to-image network with a 3-channel input and a 3-channel Sigmoid output.

    Each colour plane is processed separately by an Inc -> CAM pair, producing
    256 channels per plane. The three results are concatenated (768 channels),
    fused by another Inc (back to 256 channels) and CAM, and projected to
    3 channels by a small conv tail. All layers keep the spatial size, so the
    output has the same H x W as the input, with values in (0, 1).
    """

    def __init__(self):
        super().__init__()
        # Per-plane feature extraction: 1 input channel -> 4 * 64 = 256 channels each.
        self.layer_1_r = Inc(in_channels=1, filters=64)
        self.layer_1_g = Inc(in_channels=1, filters=64)
        self.layer_1_b = Inc(in_channels=1, filters=64)

        # Per-plane channel attention over those 256 channels.
        self.layer_2_r = CAM(256, 4)
        self.layer_2_g = CAM(256, 4)
        self.layer_2_b = CAM(256, 4)

        # Fusion of the 3 * 256 = 768 concatenated channels back down to 256.
        self.layer_3 = Inc(768, 64)
        self.layer_4 = CAM(256, 4)

        # 256 -> 24 -> 3 channels, squashed to (0, 1).
        self.layer_tail = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=24, kernel_size=(3, 3), stride=(1, 1), padding=(3 - 1) // 2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=24, out_channels=3, kernel_size=(1, 1), stride=(1, 1), padding=(1 - 1) // 2),
            nn.Sigmoid(),
        )

    def forward(self, input):
        # Channel index 0/1/2 pairs with the r/g/b sub-networks respectively.
        plane_nets = (
            (self.layer_1_r, self.layer_2_r),
            (self.layer_1_g, self.layer_2_g),
            (self.layer_1_b, self.layer_2_b),
        )
        attended = [
            cam(inc(input[:, c, :, :].unsqueeze(1)))  # (N, 1, H, W) slice of plane c
            for c, (inc, cam) in enumerate(plane_nets)
        ]
        fused = self.layer_4(self.layer_3(torch.cat(attended, dim=1)))
        return self.layer_tail(fused)

In [8]:
# Instantiate the network with PyTorch's default (random) initial weights;
# the checkpoint-loading cells below overwrite them via load_state_dict.
dicam = DICAM()

In [9]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [10]:
# Checkpoint file; set the DICAM_CKPT environment variable to use another path
# without editing the notebook. map_location puts the tensors on `device`.
# NOTE: torch.load unpickles the file, which can run arbitrary code -- only
# load checkpoints from a trusted source.
chkpt = torch.load(
    os.environ.get("DICAM_CKPT", r"C:\DEVELOPMENT\Semester-projects\DIP\DICAM-Implementation\ckpts\UIEB\DICAM_60.pt"),
    map_location=device,
)

In [11]:
# Copy the saved weights into the model; the displayed return value reports
# any missing/unexpected keys.
model_state = chkpt["model_state_dict"]
dicam.load_state_dict(model_state)

<All keys matched successfully>

In [12]:
# Folder of input images to enhance. Set the DICAM_INP_DIR environment variable
# to point elsewhere (e.g. on another machine) without editing the notebook.
INP_DIR = os.environ.get("DICAM_INP_DIR", r"C:\DEVELOPMENT\Semester-projects\DIP\DICAM-Implementation\Data\UIEB\hazy_test")

In [13]:
# Keep `device` a torch.device rather than re-binding the name to a plain
# string; every later use (torch.load map_location, .to()) accepts it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Number of input image channels (RGB): 3

In [14]:
ch = len("RGB")  # number of input colour channels

## Preprocessing transform

In [15]:
# Resize every image to a fixed size (aspect ratio is not preserved), then
# convert the PIL image to a tensor.
IMG_SIZE = (256, 256)
transform = transforms.Compose([transforms.Resize(IMG_SIZE), transforms.ToTensor()])

In [16]:
class CustomDataset(Dataset):
    """Unlabelled image-folder dataset for inference.

    Item ``idx`` is the image named ``file_list[idx]`` in ``root_dir``, opened
    as RGB and optionally passed through ``transform``. ``file_list`` is
    sorted, so index order is deterministic and callers can map an index back
    to its source file name.

    Args:
        root_dir: directory holding the images (not searched recursively).
        transform: optional callable applied to each PIL image.
        extensions: case-insensitive file-name suffixes treated as images;
            defaults to ``CustomDataset.IMG_EXTENSIONS``.
    """

    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.webp')

    def __init__(self, root_dir, transform=None, extensions=None):
        self.root_dir = root_dir
        self.transform = transform
        if extensions is None:
            extensions = self.IMG_EXTENSIONS
        suffixes = tuple(ext.lower() for ext in extensions)
        # os.listdir order is arbitrary, so sort it. Skip sub-directories and
        # non-image files (e.g. Thumbs.db), which would crash Image.open.
        self.file_list = sorted(
            name for name in os.listdir(root_dir)
            if name.lower().endswith(suffixes)
            and os.path.isfile(os.path.join(root_dir, name))
        )

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.file_list[idx])
        # The context manager closes the file handle. convert() returns a new,
        # fully loaded image, so it stays valid after the file is closed.
        with Image.open(img_name) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image

Each image is resized to 256×256 pixels (the aspect ratio is not preserved) and then converted to a tensor.

### Dataset and Dataloader

In [17]:
dataset = CustomDataset(INP_DIR, transform=transform)
# Keep shuffle=False: the inference loop names its outputs via
# dataset.file_list, which only lines up when samples arrive in dataset order.
data_loader = DataLoader(dataset=dataset, shuffle=False, batch_size=1)

## Checkpoint path

In [18]:
# Path to the checkpoint *file* (despite the name, not a directory). Set the
# DICAM_CKPT environment variable to use another checkpoint without editing.
CHKPT_DIR = os.environ.get("DICAM_CKPT", r"C:\DEVELOPMENT\Semester-projects\DIP\DICAM-Implementation\ckpts\UIEB\DICAM_60.pt")

In [19]:
result_dir = 'results/UIEB_DICAM/'
# exist_ok makes this idempotent on re-run and avoids the check-then-create race.
os.makedirs(result_dir, exist_ok=True)

In [20]:
# NOTE: torch.load unpickles the file, which can run arbitrary code -- only
# load checkpoints from a trusted source.
checkpoint = torch.load(CHKPT_DIR, map_location=device)
dicam.load_state_dict(checkpoint['model_state_dict'])
dicam.eval()
dicam.to(device)


if __name__ == '__main__':
    st = time.perf_counter()  # monotonic, high-resolution timer for intervals
    n_done = 0  # images written so far; also the index into dataset.file_list
    # Inference only: no_grad skips building the autograd graph, saving memory and time.
    with torch.no_grad(), tqdm(total=len(dataset)) as t:
        for batch in data_loader:
            batch = batch.to(device)
            # Model output (N, 3, H, W) in (0, 1) -> (N, H, W, 3) uint8 images.
            outputs = dicam(batch).clamp(0.0, 1.0).cpu().numpy()
            outputs = (outputs.transpose(0, 2, 3, 1) * 255.0).astype('uint8')

            # Save every image in the batch, not just the first one.
            for output in outputs:
                # Relies on data_loader yielding samples in dataset order (shuffle=False).
                name = dataset.file_list[n_done]
                Image.fromarray(output).save(os.path.join(result_dir, name))
                t.set_postfix_str("name: {} | new [hw]: {}/{}".format(name, output.shape[0], output.shape[1]))
                t.update(1)
                n_done += 1

    end = time.perf_counter()
    print('Total time taken in secs: ' + str(end - st))
    # max(..., 1) avoids ZeroDivisionError when the input folder is empty.
    print('Per image (avg): ' + str(float((end - st) / max(n_done, 1))))


100%|██████████| 178/178 [22:16<00:00,  7.51s/it, name: 9_img_.png | new [hw]: 256/256]  

Total time taken in secs: 1336.227995634079
Per image (avg): 7.506898851876848



