# Code ALL-IN-ONE for the Project IIN MVA 2023/2024, David and Gabriel

The goal is to run the pre-trained models published at https://github.com/VITA-Group/All-In-One-Underwater-Image-Enhancement-using-Domain-Adversarial-Learning/tree/master on our own images. This was very time-consuming because we had to change several parts of the original code, but we eventually got it working.

## Library imports

In [62]:
#!pip install --upgrade scikit-image

In [63]:
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from torchvision import transforms, utils
from torchvision.utils import save_image
import os
from PIL import Image
from tqdm import tqdm
import random
import numpy as np
from skimage.metrics import structural_similarity as ssim_fn
from skimage.metrics import peak_signal_noise_ratio as psnr_fn
from collections import defaultdict
import click
from glob import glob


## Dataset loading

Here, we had to change several things from the original code. First, we changed the default `size` and `test_start` values. Then, we added the loading of the clear (reference) images (`cl_img`).

In [64]:


class UIEBDataset(Dataset):
    """Paired underwater / clear (reference) images read from two folders.

    Both folders are globbed for ``*.<img_format>`` files and a window of
    ``size`` files starting at the offset selected by ``mode`` is kept.
    Pairing is done by position in the sorted file lists.
    NOTE(review): this assumes both folders contain identically named
    files -- confirm for the dataset in use.

    Args:
        data_path: folder containing the underwater (input) images.
        label_path: folder containing the clear (reference) images.
        img_format: file extension to glob for, without the dot.
        size: maximum number of image pairs to keep.
        mode: 'train', 'test' or 'val'; selects which start offset is used.
            Any other value keeps every file that was found.
        train_start: first index of the window when mode is 'train'.
        val_start: first index of the window when mode is 'val'.
        test_start: first index of the window when mode is 'test'.
    """

    def __init__(self, data_path, label_path, img_format='png', size=10, mode='test', train_start=0, val_start=30000, test_start=500):
        self.data_path = data_path
        self.label_path = label_path
        self.mode = mode
        self.size = size
        self.train_start = train_start
        self.test_start = test_start
        self.val_start = val_start

        # glob() gives no ordering guarantee; sort both lists so that index i
        # refers to the same file name in both folders.
        self.uw_images = sorted(glob(os.path.join(self.data_path, '*.' + img_format)))
        print("Found uw images:", len(self.uw_images))

        self.cl_images = sorted(glob(os.path.join(self.label_path, '*.' + img_format)))
        print("Found cl images:", len(self.cl_images))

        # Apply the same window to both lists so the pairs stay aligned in
        # every mode, not only in 'test'.
        starts = {
            'train': self.train_start,
            'test': self.test_start,
            'val': self.val_start,
        }
        if self.mode in starts:
            start = starts[self.mode]
            window = slice(start, start + self.size)
            self.uw_images = self.uw_images[window]
            self.cl_images = self.cl_images[window]

        self.transform = transforms.Compose([
            transforms.ToTensor()
            ])

    def __getitem__(self, index):
        """Return ``(uw_img, cl_img, -1, file_name)`` for the pair at ``index``.

        The third element is a constant -1 placeholder; the last element is
        the base name of the underwater image file.

        Raises:
            IndexError: if ``index`` is not smaller than ``len(self)``.
        """
        if index >= len(self):
            raise IndexError(f"Index {index} is out of bounds for dataset with size {len(self)}")

        uw_img = self.transform(Image.open(self.uw_images[index]))
        cl_img = self.transform(Image.open(self.cl_images[index]))
        return uw_img, cl_img, -1, os.path.basename(self.uw_images[index])

    def __len__(self):
        """Number of usable pairs (may be smaller than the requested size)."""
        # Reporting the requested size would let a DataLoader index past the
        # end when fewer files are available than were asked for.
        return min(len(self.uw_images), len(self.cl_images))

## Models

In [65]:

class Flatten(nn.Module):
    """Collapse every dimension after the first one into a single axis."""

    def forward(self, input):
        # Keep dim 0 untouched and let view() infer the merged length.
        leading = input.size(0)
        return input.view(leading, -1)

class UNetEncoder(nn.Module):
    """Contracting path of the U-Net.

    An input convolution block is followed by four ``down`` stages with
    output widths 128, 256, 512 and 512 channels.
    """

    def __init__(self, n_channels=3):
        super().__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)

    def forward(self, x):
        """Return ``(deepest_features, (x1, x2, x3, x4))``.

        The tuple holds the outputs of ``inc``, ``down1``, ``down2`` and
        ``down3`` in that order, for use as skip connections.
        """
        features = [self.inc(x)]
        for stage in (self.down1, self.down2, self.down3, self.down4):
            features.append(stage(features[-1]))

        # The last entry is the deepest output; the rest are the skips.
        return features[-1], tuple(features[:-1])

class UNetDecoder(nn.Module):
    """Expanding path of the U-Net.

    Applies a sigmoid to the incoming features, runs four ``up`` stages
    that each consume one encoder skip tensor, and maps the result to
    ``n_channels`` channels followed by a tanh.
    """

    def __init__(self, n_channels=3):
        super().__init__()
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.outc = outconv(64, n_channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, enc_outs):
        """Decode ``x`` using ``enc_outs[3]`` down to ``enc_outs[0]`` as skips."""
        out = self.sigmoid(x)

        # Skip tensors are consumed deepest-first.
        stages = (self.up1, self.up2, self.up3, self.up4)
        for stage, skip_index in zip(stages, (3, 2, 1, 0)):
            out = stage(out, enc_outs[skip_index])

        squash = nn.Tanh()
        return squash(self.outc(out))

class Classifier(nn.Module):
    """Classification head operating on 512-channel feature maps.

    The first linear layer takes 4096 inputs, so the flattened output of
    the convolution + pooling stage must have 4096 values per sample.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Convolutional part: halve the channels, then pool.
        conv_stage = [
            nn.Conv2d(512, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2),
        ]
        # Fully connected part, ending in one output per class.
        dense_stage = [
            Flatten(),
            nn.Linear(4096, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(True),
            nn.Dropout(0.3),
            nn.Linear(1024, num_classes),
        ]
        self.classifier = nn.Sequential(*conv_stage, *dense_stage)

    def forward(self, input):
        """Return the raw (un-normalised) class scores for ``input``."""
        scores = self.classifier(input)
        return scores
    
    
    
# UNET PARTS:

class double_conv(nn.Module):
    """Two consecutive (3x3 conv -> batch norm -> ReLU) units.

    The first unit maps ``in_ch`` to ``out_ch`` channels, the second keeps
    ``out_ch``. Padding of 1 keeps the spatial size unchanged.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # Same three-layer unit twice; only the input width differs.
        for width_in in (in_ch, out_ch):
            layers.append(nn.Conv2d(width_in, out_ch, 3, padding=1))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class inconv(nn.Module):
    """Input block of the U-Net: a single ``double_conv``."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        return self.conv(x)


class down(nn.Module):
    """Downscaling stage: 2x2 max pooling followed by a ``double_conv``."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = double_conv(in_ch, out_ch)
        self.mpconv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.mpconv(x)


class up(nn.Module):
    """Upscaling stage: upsample, align with the skip tensor, then convolve.

    Args:
        in_ch: channel count after concatenating the skip tensor with the
            upsampled tensor (input width of the ``double_conv``).
        out_ch: channel count produced by the stage.
        bilinear: when True use fixed bilinear upsampling, otherwise a
            learned transposed convolution on ``in_ch // 2`` channels.
    """

    def __init__(self, in_ch, out_ch, bilinear=True):
        super().__init__()

        #  would be a nice idea if the upsampling could be learned too,
        #  but my machine do not have enough memory to handle all those weights
        self.up = (
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            if bilinear
            else nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)
        )

        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the size of ``x2`` and fuse both."""
        upsampled = self.up(x1)

        # Dimensions 2 and 3 are height and width of the 4-D tensors.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)

        # F.pad takes (left, right, top, bottom); any odd remainder goes to
        # the right / bottom side.
        left, top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, (left, gap_w - left, top, gap_h - top))

        # Skip tensor first, upsampled tensor second, along the channel axis.
        fused = torch.cat([x2, upsampled], dim=1)
        return self.conv(fused)


class outconv(nn.Module):
    """Output projection: a 1x1 convolution from ``in_ch`` to ``out_ch``."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        return self.conv(x)

## Test

In [74]:
def to_img(x):
    """Map values from [-1, 1] to [0, 1] and clamp anything outside."""
    rescaled = 0.5 * (x + 1)
    return rescaled.clamp(0, 1)

def var_to_img(img):
    """Convert a channel-first tensor scaled to [0, 1] into a uint8 array.

    The result is channel-last (axes reordered with ``transpose(1, 2, 0)``)
    and scaled by 255; fractional values are truncated by the uint8 cast.
    """
    scaled = (img * 255).cpu().detach().numpy()
    channel_last = scaled.transpose(1, 2, 0)
    return channel_last.astype(np.uint8)

def test(fE, fI, dataloader, model_name, which_epoch, results_dir='C:/Users/davfa/Desktop/results'):
    """Run the encoder/decoder over ``dataloader`` and save comparison images.

    For every batch the underwater input, the network output and the clear
    reference are stacked and written as one JPEG to
    ``<results_dir>/<model_name>/<which_epoch>/<file name>_out.jpg``. The
    MSE between output and reference is recorded.

    The loop expects batches of size 1, since the leading batch dimension
    is squeezed away before the images are stacked.

    Args:
        fE: encoder; called as ``fE(uw_img)`` and expected to return
            ``(features, skip_outputs)``.
        fI: decoder; called as ``fI(features, skip_outputs)``.
        dataloader: yields ``(uw_img, cl_img, water_type, name)`` batches.
        model_name: sub-folder name inside ``results_dir``.
        which_epoch: second-level sub-folder name inside ``results_dir``.
        results_dir: root folder for the saved images.

    Returns:
        ``(ssim_scores, psnr_scores, mse_scores)``. SSIM and PSNR are
        currently not computed, so the first two lists are always empty.
    """
    mse_scores = []
    ssim_scores = []
    psnr_scores = []
    criterion_MSE = nn.MSELoss()

    # Make sure the target folder exists so save_image() cannot fail on it.
    out_dir = '{}/{}/{}'.format(results_dir, model_name, which_epoch)
    os.makedirs(out_dir, exist_ok=True)

    for uw_img, cl_img, _water_type, name in tqdm(dataloader):
        # Pure inference: skip building the autograd graph.
        with torch.no_grad():
            fE_out, enc_outs = fE(uw_img)
            fI_out = to_img(fI(fE_out, enc_outs))

        # squeeze(0) removes only the batch axis; an argument-less squeeze()
        # would also drop genuine size-1 image dimensions.
        uw_vis = uw_img.squeeze(0).cpu()
        out_vis = fI_out.squeeze(0).cpu()
        cl_vis = cl_img.squeeze(0).cpu()

        print("uw_img shape:", uw_vis.shape)
        print("fI_out shape:", out_vis.shape)
        print("cl_img shape:", cl_vis.shape)

        save_image(torch.stack([uw_vis, out_vis, cl_vis]), '{}/{}_{}.jpg'.format(out_dir, name[0], 'out'))

        mse_scores.append(criterion_MSE(fI_out, cl_img).item())

        # NOTE(review): SSIM / PSNR were disabled by the authors. Recent
        # scikit-image versions expect `channel_axis=-1` instead of
        # `multichannel=True` -- confirm before re-enabling.

    # Return real lists: the previous `_` placeholders only resolve inside
    # an interactive IPython session and raise NameError elsewhere.
    return ssim_scores, psnr_scores, mse_scores


In [78]:
def main(name, num_channels, test_dataset, data_path, label_path, which_epoch, test_size, fe_load_path, fi_load_path):
    """Load a trained encoder/decoder pair, evaluate it and print the MSE.

    Args:
        name: experiment name; used for the results sub-folder and, when
            ``which_epoch`` is truthy, for locating the checkpoints.
        num_channels: number of image channels of the encoder/decoder.
        test_dataset: 'nyu' selects ``NYUUWDataset``; any other value
            selects ``UIEBDataset``.
        data_path: folder with the underwater input images.
        label_path: folder with the clear reference images.
        which_epoch: when truthy, checkpoints are read from
            ``C:/Users/davfa/Downloads/<name>/fE_<epoch>.pth`` (and ``fI_``);
            otherwise ``fe_load_path`` / ``fi_load_path`` are used.
        test_size: number of image pairs to evaluate for the UIEB dataset.
        fe_load_path: explicit encoder checkpoint path.
        fi_load_path: explicit decoder checkpoint path.
    """
    # Creates every missing level at once and tolerates existing folders.
    os.makedirs('C:/Users/davfa/Desktop/results/{}/{}'.format(name, which_epoch), exist_ok=True)

    if which_epoch:
        fE_load_path = os.path.join('C:/Users/davfa/Downloads', name, 'fE_{}.pth'.format(which_epoch))
        fI_load_path = os.path.join('C:/Users/davfa/Downloads', name, 'fI_{}.pth'.format(which_epoch))
    else:
        fE_load_path = fe_load_path
        fI_load_path = fi_load_path

    fE = UNetEncoder(num_channels)
    fI = UNetDecoder(num_channels)

    # Map to CPU so checkpoints saved on a GPU load on any machine.
    cpu = torch.device('cpu')
    fE.load_state_dict(torch.load(fE_load_path, map_location=cpu))
    fI.load_state_dict(torch.load(fI_load_path, map_location=cpu))

    fE.eval()
    fI.eval()

    if test_dataset == 'nyu':
        # NOTE(review): NYUUWDataset is not defined in the visible part of
        # this file -- confirm it is available before using this branch.
        test_dataset = NYUUWDataset(data_path,
            label_path,
            size=3000,
            test_start=33000,
            mode='test')
    else:
        # Add more datasets
        test_dataset = UIEBDataset(data_path,
            label_path,
            size=test_size,  # honour the caller's request instead of a fixed 10
            test_start=500,
            mode='test')

    batch_size = 1
    dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    _ssim_scores, _psnr_scores, mse_scores = test(fE, fI, dataloader, name, which_epoch)

    if mse_scores:
        print ("Average MSE: {}".format(sum(mse_scores)/len(mse_scores)))
    else:
        # Avoid a ZeroDivisionError when the dataset turned out to be empty.
        print("Average MSE: n/a (no images were evaluated)")


In [79]:
# --- Run configuration -----------------------------------------------------
# Experiment name (used for the results sub-folder) and dataset selection.
name = "testdavid"
test_dataset = 'uieb'
test_size = 10
num_channels = 3

# Raw underwater images and their reference counterparts.
data_path = "C:/Users/davfa/Downloads/raw-890/raw-890"
label_path = "C:/Users/davfa/Downloads/reference-890/reference-890"

# which_epoch is falsy, so the explicit checkpoint paths below are loaded.
which_epoch = False
fe_load_path = "C:/Users/davfa/Downloads/fE_86.pth"
fi_load_path = "C:/Users/davfa/Downloads/fI_86.pth"

main(
    name=name,
    num_channels=num_channels,
    test_dataset=test_dataset,
    data_path=data_path,
    label_path=label_path,
    which_epoch=which_epoch,
    test_size=test_size,
    fe_load_path=fe_load_path,
    fi_load_path=fi_load_path,
)

Found uw images: 890
Found cl images: 890


1it [00:12, 12.20s/it]

uw_img shape: torch.Size([3, 333, 500])
fI_out shape: torch.Size([3, 333, 500])
cl_img shape: torch.Size([3, 333, 500])


2it [00:24, 12.05s/it]

uw_img shape: torch.Size([3, 333, 500])
fI_out shape: torch.Size([3, 333, 500])
cl_img shape: torch.Size([3, 333, 500])


3it [00:45, 16.47s/it]

uw_img shape: torch.Size([3, 480, 640])
fI_out shape: torch.Size([3, 480, 640])
cl_img shape: torch.Size([3, 480, 640])


4it [00:57, 14.43s/it]

uw_img shape: torch.Size([3, 333, 500])
fI_out shape: torch.Size([3, 333, 500])
cl_img shape: torch.Size([3, 333, 500])


5it [01:08, 13.29s/it]

uw_img shape: torch.Size([3, 333, 500])
fI_out shape: torch.Size([3, 333, 500])
cl_img shape: torch.Size([3, 333, 500])


6it [01:28, 15.68s/it]

uw_img shape: torch.Size([3, 480, 640])
fI_out shape: torch.Size([3, 480, 640])
cl_img shape: torch.Size([3, 480, 640])


7it [01:40, 14.29s/it]

uw_img shape: torch.Size([3, 333, 500])
fI_out shape: torch.Size([3, 333, 500])
cl_img shape: torch.Size([3, 333, 500])


8it [01:51, 13.49s/it]

uw_img shape: torch.Size([3, 333, 500])
fI_out shape: torch.Size([3, 333, 500])
cl_img shape: torch.Size([3, 333, 500])


9it [02:03, 12.92s/it]

uw_img shape: torch.Size([3, 333, 500])
fI_out shape: torch.Size([3, 333, 500])
cl_img shape: torch.Size([3, 333, 500])


10it [02:16, 13.67s/it]

uw_img shape: torch.Size([3, 333, 500])
fI_out shape: torch.Size([3, 333, 500])
cl_img shape: torch.Size([3, 333, 500])
Average MSE: 0.048664132878184316



