In [1]:
import torch
import torchvision
import tarfile
from torchvision.datasets.utils import download_url
from torch.utils.data import random_split

from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor

import pandas as pd
from torchvision.io import read_image

from torch.utils.data.dataset import Dataset
from glob import glob
import os
from PIL import Image
from torchvision import transforms
import numpy as np
from torchvision.utils import save_image
from torch.utils.data import DataLoader, Dataset

import torch.nn as nn
import torch.nn.functional as F

from torchvision import models
from torch.autograd import Variable

from tqdm import tqdm
import random

import datetime

from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

# Dataset loader

In [2]:
class NYUUWData(Dataset):
    """Paired underwater / clean image dataset described by a CSV label file.

    The CSV is read with the column names ``uw_images`` (underwater image path),
    ``cl_images`` (clean image path) and ``w_type`` (water-type label). Image
    paths are joined onto ``data_path``.

    Args:
        data_path: Directory that the image paths in the CSV are relative to.
        label_path: Path of the CSV label file.
        size: Number of rows to keep for the selected split.
        mode: ``'train'``, ``'val'`` or ``'test'``; selects which start offset
            is used. Any other value keeps the whole table (original behavior).
        train_start: First row of the training split.
        val_start: First row of the validation split.
        test_start: First row of the test split.
    """

    def __init__(self, data_path, label_path, size=4500, mode='val', train_start=0, val_start=4500, test_start=5000):

        self.data_path = data_path
        self.mode = mode
        self.size = size
        self.train_start = train_start
        self.test_start = test_start
        self.val_start = val_start

        table = pd.read_csv(label_path, names=['uw_images', 'cl_images', 'w_type'])

        # Pick the row window that belongs to the requested split.
        split_starts = {'train': train_start, 'test': test_start, 'val': val_start}
        if mode in split_starts:
            first_row = split_starts[mode]
            table = table.iloc[first_row:first_row + size]

        # reset_index() returns a new frame (no in-place mutation of a slice) and,
        # like the previous in-place call, keeps the old row numbers in an
        # 'index' column.
        self.label_path = table.reset_index()

        self.transform = transforms.Compose([
           transforms.ToPILImage(),
           transforms.Resize(size = (270, 360)),
           transforms.CenterCrop((256, 256)),
           transforms.ToTensor()
        ])

    @staticmethod
    def _sample_name(path):
        """Return the file name of ``path`` without directory or extension."""
        # splitext handles extensions of any length ('.jpeg', '.png', none),
        # unlike slicing off a fixed four characters.
        return os.path.splitext(os.path.basename(path))[0]

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.label_path)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            Tuple ``(uw_images, cl_images, water_type, name)``: the transformed
            underwater image, the transformed clean image, the integer water
            type, and the underwater image's file name without extension.
        """
        row = self.label_path.iloc[idx]

        uw_images = read_image(os.path.join(self.data_path, row['uw_images']))
        cl_images = read_image(os.path.join(self.data_path, row['cl_images']))

        water_type = int(row['w_type'])
        name = self._sample_name(row['uw_images'])

        if self.transform is not None:
            uw_images = self.transform(uw_images)
            cl_images = self.transform(cl_images)

        return uw_images, cl_images, water_type, name

## Convert network output to an image

In [3]:
def to_img(x):
    """Rescale a batch with ``0.5 * (x + 1)``, clamp to [0, 1], and view it as (N, 3, 256, 256).

    Intended for tanh-range network outputs; the batch size N is taken from
    ``x.size(0)``.
    """
    batch = x.size(0)
    rescaled = (0.5 * (x + 1)).clamp(0, 1)
    return rescaled.view(batch, 3, 256, 256)


## Set `requires_grad` on networks

In [4]:
def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of one network or a list of networks.

    ``None`` entries are skipped. Returns the ``requires_grad`` value that was
    applied so callers can store it as the current state.
    """
    networks = nets if isinstance(nets, list) else [nets]
    for network in networks:
        if network is None:
            continue
        for weight in network.parameters():
            weight.requires_grad = requires_grad

    return requires_grad

## Compute validation metrics

MSE, SSIM, PSNR, and water-type classifier accuracy

In [5]:
def compute_val_metrics(fE, fI, fN, dataloader, no_adv_loss):
    """
        Compute SSIM, PSNR, MSE and classifier accuracy for the validation set.

        The three networks are switched to eval mode for the pass and the whole
        loop runs under ``torch.no_grad()`` so no autograd graph is built.
        Afterwards ``fE`` and ``fI`` are put back in train mode; ``fN`` is put
        back in train mode only when ``no_adv_loss`` is false.

        The loop assumes the dataloader yields batches of size 1 (``.item()``
        and ``squeeze(0)`` are used on per-batch tensors).

        Returns:
            Tuple ``(ssim, psnr, mse, accuracy)``, each a sum over the batches
            divided by ``len(dataloader)``.
    """

    fE.eval()
    fI.eval()
    fN.eval()

    mse_scores = []
    ssim_scores = []
    psnr_scores = []
    corr = 0

    criterion_MSE = nn.MSELoss().cuda()

    # Inference only: skip gradient tracking to save memory and time.
    with torch.no_grad():
        for uw_img, cl_img, water_type, _ in tqdm(dataloader):
            uw_img = uw_img.cuda()
            cl_img = cl_img.cuda()

            fE_out, enc_outs = fE(uw_img)
            fI_out = to_img(fI(fE_out, enc_outs))
            fN_out = F.softmax(fN(fE_out), dim=1)

            # Count a hit when the most probable class matches the label.
            if int(fN_out.argmax(dim=1).item()) == int(water_type.item()):
                corr += 1

            mse_scores.append(criterion_MSE(fI_out, cl_img).item())

            # Convert both images to HWC uint8 arrays for the skimage metrics.
            fI_np = (fI_out * 255).squeeze(0).cpu().numpy().transpose(1, 2, 0).astype(np.uint8)
            cl_np = (cl_img * 255).squeeze(0).cpu().numpy().transpose(1, 2, 0).astype(np.uint8)

            ssim_scores.append(ssim(fI_np, cl_np, channel_axis = -1))
            psnr_scores.append(psnr(cl_np, fI_np))

    fE.train()
    fI.train()
    if not no_adv_loss:
        fN.train()

    num_batches = len(dataloader)
    return sum(ssim_scores)/num_batches, sum(psnr_scores)/num_batches, sum(mse_scores)/num_batches, corr/num_batches

## LOSSES

In [6]:
# adversarial loss

def backward_adv_loss(fN, fE_out, water_type, lambda_adv_loss, num_classes, neg_entropy):
    """
        Run the classifier on the encoder output, weight the adversarial loss
        by ``lambda_adv_loss`` and backpropagate it.

        ``water_type`` is accepted but not used. Returns the weighted loss.
    """

    classifier_logits = fN(fE_out)
    unweighted = calc_adv_loss(classifier_logits, num_classes, neg_entropy)
    adv_loss = unweighted * lambda_adv_loss

    adv_loss.backward()

    return adv_loss

def calc_adv_loss(fN_out, num_classes, neg_entropy):
    """
        Adversarial loss on the classifier logits ``fN_out``.

        With ``neg_entropy`` set, this is the batch mean of ``sum(p * log p)``
        over classes (probabilities clamped to [1e-10, 1] inside the log).
        Otherwise it is the cross entropy against a uniform distribution:
        minus the batch mean of ``sum(log_softmax) / num_classes``.
    """

    if not neg_entropy:
        log_probs = F.log_softmax(fN_out, dim=1)
        per_sample = torch.div(log_probs.sum(dim=1), num_classes)
        return -per_sample.mean()

    probs = F.softmax(fN_out, dim=1)
    safe_log = torch.log(probs.clamp(min=1e-10, max=1.0))
    return (probs * safe_log).sum(dim=1).mean()

In [7]:
# new reconstruction loss

class VGG19_PercepLoss(nn.Module):
    """Perceptual loss: mean squared difference of VGG19 feature maps.

    The VGG19 ``features`` stack is stored in ``self.vgg`` with all of its
    parameters frozen.
    """

    def __init__(self, _pretrained_=True):
        super(VGG19_PercepLoss, self).__init__()
        backbone = models.vgg19(pretrained=_pretrained_)
        self.vgg = backbone.features
        for weight in self.vgg.parameters():
            weight.requires_grad_(False)

    def get_features(self, image, layers=None):
        """Run ``image`` through every layer of ``self.vgg`` and collect selected outputs.

        ``layers`` maps a layer's name inside ``self.vgg`` to the key used in
        the returned dict; it defaults to ``{'30': 'conv5_2'}``.
        """
        wanted = {'30': 'conv5_2'} if layers is None else layers  # may add other layers
        collected = {}
        activation = image
        # The full stack is always traversed, as in the original; a captured
        # tensor may still be touched by a following layer if that layer works
        # in place (NOTE(review): confirm whether that matters here).
        for layer_name, module in self.vgg._modules.items():
            activation = module(activation)
            if layer_name in wanted:
                collected[wanted[layer_name]] = activation
        return collected

    def forward(self, pred, true, layer='conv5_2'):
        """Mean squared error between the ``layer`` features of ``true`` and ``pred``."""
        target_feats = self.get_features(true)
        pred_feats = self.get_features(pred)
        difference = target_feats[layer] - pred_feats[layer]
        return torch.mean(difference ** 2)
    

def _get_percep_loss():
    """Return the shared VGG19 perceptual-loss module, building it on first use."""
    cached = getattr(_get_percep_loss, "_instance", None)
    if cached is None:
        # Loading VGG19 and moving it to the GPU is expensive; do it only once
        # instead of on every training iteration.
        cached = VGG19_PercepLoss().cuda()
        _get_percep_loss._instance = cached
    return cached


def back_I_loss(fI, fE_out, enc_outs, uw_img, cl_img, criterion_MSE, optimizer_fI, retain_graph):
    """Compute the reconstruction loss for the decoder, backpropagate, and step ``optimizer_fI``.

    The loss is ``MSE(fI_out, uw_img) + 7 * L1(fI_out, cl_img) + 3 * VGG(fI_out, cl_img)``
    where ``fI_out = to_img(fI(fE_out, enc_outs))``.

    Args:
        fI: Decoder network.
        fE_out: Encoder bottleneck output.
        enc_outs: Encoder skip-connection outputs.
        uw_img: Underwater input batch.
        cl_img: Clean target batch.
        criterion_MSE: Accepted for interface compatibility; not used.
        optimizer_fI: Optimizer for ``fI``; zeroed before and stepped after backward.
        retain_graph: Forwarded to ``backward`` so the graph can be reused by a
            later backward pass through ``fE_out``.

    Returns:
        Tuple ``(fI_out, I_loss)``.
    """
    lambda_1, lambda_con = 7, 3 # 7:3 (as in paper)
    L_vgg = _get_percep_loss() # content loss (vgg), built once and reused

    fI_out = to_img(fI(fE_out, enc_outs))

    # NOTE(review): this term compares the output with the *underwater input*,
    # not with a discriminator output, despite the "GAN" name — confirm intent.
    loss_GAN = F.mse_loss(fI_out, uw_img)
    loss_1 = F.l1_loss(fI_out, cl_img) # similarity loss
    loss_con = L_vgg(fI_out, cl_img) # content loss

    I_loss = loss_GAN + lambda_1 * loss_1  + lambda_con * loss_con

    optimizer_fI.zero_grad()
    I_loss.backward(retain_graph=retain_graph)
    optimizer_fI.step()

    return fI_out,  I_loss


In [8]:
# Focal loss

class FocalLoss(nn.Module):
    """Focal loss built on per-sample cross entropy.

    Each sample's cross entropy ``ce`` is scaled by
    ``alpha * (1 - exp(-ce)) ** gamma``. ``reduction`` is ``'mean'``, ``'sum'``,
    or anything else to get the unreduced per-sample values.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        prob_true = torch.exp(-per_sample_ce)
        modulating = (1 - prob_true) ** self.gamma
        per_sample = self.alpha * modulating * per_sample_ce

        if self.reduction == 'sum':
            return torch.sum(per_sample)
        if self.reduction == 'mean':
            return torch.mean(per_sample)
        return per_sample

        
def new_back_N(fN, fE_out, num_classes, actual_target, criterion_CE, optimizer_fN):
    """Train the classifier ``fN`` for one step with a focal loss.

    The encoder output is detached, so only ``fN`` receives gradients. The
    integer labels in ``actual_target`` are turned into one-hot float targets
    of width ``num_classes``.

    Args:
        fN: Classifier network.
        fE_out: Encoder output fed to the classifier (detached here).
        num_classes: Number of water-type classes.
        actual_target: Tensor of integer class labels, one per sample.
        criterion_CE: Accepted for interface compatibility; not used.
        optimizer_fN: Optimizer for ``fN``. When given, its gradients are
            zeroed before backward and it is stepped afterwards. Pass ``None``
            to only accumulate gradients.

    Returns:
        The focal loss tensor.
    """

    ## Focal loss
    criterion_focal = FocalLoss(alpha=1, gamma=2)

    # One-hot encode the labels on the same device as the classifier input.
    class_ids = actual_target.detach().reshape(-1).long().to(fE_out.device)
    target = F.one_hot(class_ids, num_classes=num_classes).float()

    fN_out = fN(fE_out.detach())

    N_loss = criterion_focal(fN_out, target)

    # Previously the optimizer was never zeroed or stepped, so the classifier
    # weights did not change during its training phase.
    if optimizer_fN is not None:
        optimizer_fN.zero_grad()
    N_loss.backward()
    if optimizer_fN is not None:
        optimizer_fN.step()

    return N_loss


## Log files: save training status

In [9]:
def write_to_log(log_file_path, status):
    """
        Append ``status`` followed by a newline to the file at ``log_file_path``.
    """
    with open(log_file_path, "a") as handle:
        handle.writelines((status, '\n'))

## Model

In [10]:
class double_conv(nn.Module):
    '''Two consecutive (3x3 conv, padding 1 => BatchNorm => ReLU) stages.'''

    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second keeps out_ch -> out_ch.
        for stage_in in (in_ch, out_ch):
            stages.append(nn.Conv2d(stage_in, out_ch, 3, padding=1))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)


class inconv(nn.Module):
    """Input block: a single ``double_conv`` from ``in_ch`` to ``out_ch`` channels."""

    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        return self.conv(x)


class down(nn.Module):
    """Downsampling block: 2x2 max pooling followed by a ``double_conv``."""

    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        pool = nn.MaxPool2d(2)
        convs = double_conv(in_ch, out_ch)
        self.mpconv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.mpconv(x)

class up(nn.Module):
    """Upsampling block: upsample ``x1`` by 2, pad it to ``x2``'s size, concatenate, ``double_conv``.

    With ``bilinear`` the upsampling is bilinear interpolation; otherwise it is
    a stride-2 transposed convolution over ``in_ch // 2`` channels.
    """

    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()
        half_ch = in_ch // 2
        self.up = (
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            if bilinear
            else nn.ConvTranspose2d(half_ch, half_ch, 2, stride=2)
        )
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        upsampled = self.up(x1)

        # Height / width gap between the skip tensor and the upsampled tensor
        # (dims 2 and 3).
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        left, top = gap_w // 2, gap_h // 2

        upsampled = F.pad(upsampled, (left, gap_w - left, top, gap_h - top))

        # Skip connection goes first along the channel axis.
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)

class outconv(nn.Module):
    """Output block: a 1x1 convolution from ``in_ch`` to ``out_ch`` channels."""

    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class Flatten(nn.Module):
    """View the input as (batch, -1), keeping dimension 0."""

    def forward(self, input):
        batch = input.size(0)
        return input.view(batch, -1)


In [11]:
# Encoder
class UNetEncoder(nn.Module):
    """U-Net encoder: an input block followed by four downsampling blocks.

    ``forward`` returns the deepest feature map and a tuple with the outputs of
    the input block and the first three downsampling blocks, in that order.
    """

    def __init__(self, n_channels=3):
        super(UNetEncoder, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)

    def forward(self, x):
        skips = []
        features = self.inc(x)
        for stage in (self.down1, self.down2, self.down3, self.down4):
            # Keep the input of every downsampling stage as a skip connection.
            skips.append(features)
            features = stage(features)

        return features, tuple(skips)




In [12]:
# Decoder

class UNetDecoder(nn.Module):
    """U-Net decoder: sigmoid on the bottleneck, four upsampling blocks, 1x1 conv, tanh.

    ``enc_outs`` supplies the skip connections; they are consumed from index 3
    down to index 0.
    """

    def __init__(self, n_channels=3):
        super(UNetDecoder, self).__init__()
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.outc = outconv(64, n_channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, enc_outs):
        features = self.sigmoid(x)
        stages = (self.up1, self.up2, self.up3, self.up4)
        for stage, skip_index in zip(stages, (3, 2, 1, 0)):
            features = stage(features, enc_outs[skip_index])
        return torch.tanh(self.outc(features))

In [13]:
# Classifier

class Classifier(nn.Module):
    """Classifier head applied to a 512-channel feature map.

    A strided conv / BatchNorm / ReLU / max-pool stage is followed by flattening
    and two linear layers that end in ``num_classes`` logits. The first linear
    layer expects 4096 flattened features.
    """

    def __init__(self, num_classes):
        super(Classifier, self).__init__()
        conv_stage = [
            nn.Conv2d(512, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2),
        ]
        dense_stage = [
            Flatten(),
            nn.Linear(4096, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(True),
            nn.Dropout(0.3),
            nn.Linear(1024, num_classes),
        ]
        self.classifier = nn.Sequential(*conv_stage, *dense_stage)

    def forward(self, input):
        logits = self.classifier(input)
        return logits

## Main

In [None]:
def main():
    """Train the encoder (fE), decoder (fI) and water-type classifier (fN).

    Training alternates in a 30-epoch cycle tracked by the counter ``i``:
    epochs 1-10 and 21-30 of the cycle update the encoder-decoder (plus the
    adversarial loss unless ``no_adv_loss``), epochs 11-20 train the
    classifier. After every epoch the models are saved and, unless
    ``no_adv_loss``, validation metrics are recomputed.

    Side effects: reads images and labels under ``./data``, writes preview
    images to ``./results`` and checkpoints plus a log file to
    ``./checkpoints/<name>``. Requires CUDA.
    """

    # Define data path and label path
    data_path = './data/'
    label_path = './data/label.csv'

    train_size = 4500
    val_size = 500
    test_size = 340

    batch_size = 4
    learning_rate = 1e-3
    start_epoch = 1
    end_epoch = 201

    num_classes = 6
    num_channels = 3
    save_interval = 5

    lambda_adv_loss = 1

    name = 'Checkpoints_values'

    continue_train = False
    neg_entropy = True
    no_adv_loss = False


    train_dataset = NYUUWData(data_path, label_path, size=train_size,train_start=0, mode='train')
    val_dataset = NYUUWData(data_path, label_path, size=val_size,val_start=4500, mode='val')
    test_dataset = NYUUWData(data_path, label_path, size=test_size,test_start=5000, mode='test')

    # NOTE(review): the training loader is not shuffled while val/test are —
    # this looks inverted; confirm whether training order is intentional.
    train_dataloader= DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
    val_dataloader= DataLoader(val_dataset, batch_size=1, shuffle=True)
    test_dataloader= DataLoader(test_dataset, batch_size=1, shuffle=True)

    fN = Classifier(num_classes).cuda()
    fN_req_grad = True
    fN.train()
    criterion_CE = nn.CrossEntropyLoss().cuda()
    optimizer_fN = torch.optim.Adam(fN.parameters(), lr=learning_rate, weight_decay=1e-5)

    fE = UNetEncoder(num_channels).cuda()
    fI = UNetDecoder(num_channels).cuda()

    criterion_MSE = nn.MSELoss().cuda()

    optimizer_fE = torch.optim.Adam(fE.parameters(), lr=learning_rate, weight_decay=1e-5)
    optimizer_fI = torch.optim.Adam(fI.parameters(), lr=learning_rate, weight_decay=1e-5)

    fE.train()
    fI.train()

    # Optionally resume from saved checkpoints.
    if continue_train:
        # path of checkpoints
        fE_load_path = './checkpoints/unet_adv/fE_86.pth'
        fI_load_path = './checkpoints/unet_adv/fI_86.pth'
        fN_load_path = './checkpoints/unet_adv/fN_86.pth'

        if fE_load_path:
            fE.load_state_dict(torch.load(fE_load_path))
            print ('Loaded fE from {}'.format(fE_load_path))
        if fI_load_path:
            fI.load_state_dict(torch.load(fI_load_path))
            print ('Loaded fI from {}'.format(fI_load_path))
        if not no_adv_loss and fN_load_path:
            fN.load_state_dict(torch.load(fN_load_path))
            print ('Loaded fN from {}'.format(fN_load_path))

    # Create the output folders; makedirs also creates missing parents and is
    # a no-op when the folder already exists.
    checkpoint_dir = './checkpoints/{}'.format(name)
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs('./results', exist_ok=True)

    log_file_path = './checkpoints/{}/log_file.txt'.format(name)

    now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M")

    status = '\nTRAINING SESSION STARTED ON {}\n'.format(now)
    write_to_log(log_file_path, status)

    if continue_train and not no_adv_loss:
        fI_val_ssim, _, _, fN_val_acc = compute_val_metrics(fE, fI, fN, val_dataloader, no_adv_loss)
    else:
        fI_val_ssim = -1
        fN_val_acc = -1

    print("Encoder SSIM value = {}, Decoder Acc = {}".format(fI_val_ssim, fN_val_acc))

    # Position inside the 30-epoch training cycle (1..30).
    i = 1

    for epoch in range(start_epoch, end_epoch):
        # Main training loop

        if not no_adv_loss:
            # Print the current cross-validation scores
            status = 'Avg fI val SSIM: {}, Avg fN val acc: {}'.format(fI_val_ssim, fN_val_acc)
            print (status)
            write_to_log(log_file_path, status)

        # Cycle epochs 11-20 train the nuisance classifier; all other epochs
        # (1-10 and 21-30) train the encoder-decoder.
        train_classifier = 10 < i <= 20

        for idx, data in enumerate(tqdm(train_dataloader)):
            uw_img, cl_img, water_type, _ = data
            uw_img = uw_img.cuda()
            cl_img = cl_img.cuda()
            actual_target = water_type.cuda()

            fE_out, enc_outs = fE(uw_img)

            if train_classifier:
                # Train the nuisance classifier only
                if not fN_req_grad:
                    fN_req_grad = set_requires_grad(fN, requires_grad=True)

                # Runs on every iteration (it used to sit inside the branch
                # above and so ran only once per phase) and calls the function
                # that is actually defined in this file.
                N_loss = new_back_N(fN, fE_out, num_classes, actual_target, criterion_CE, optimizer_fN)
                progress = "\tEpoch: {}\tIter: {}\tN_loss: {}".format(epoch, idx, N_loss.item())

            else:
                # Train the encoder-decoder only
                optimizer_fE.zero_grad()
                fI_out, I_loss = back_I_loss(fI, fE_out, enc_outs, uw_img, cl_img, criterion_MSE, optimizer_fI,  retain_graph=not no_adv_loss)

                if not no_adv_loss:
                    # The classifier is frozen while it provides the adversarial signal.
                    if fN_req_grad:
                        fN_req_grad = set_requires_grad(fN, requires_grad=False)
                    adv_loss = backward_adv_loss(fN, fE_out, water_type, lambda_adv_loss, num_classes, neg_entropy)
                    progress = "\tEpoch: {}\tIter: {}\tI_loss: {}\tadv_loss: {}".format(epoch, idx, I_loss.item(), adv_loss.item())
                else:
                    progress = "\tEpoch: {}\tIter: {}\tI_loss: {}".format(epoch, idx, I_loss.item())

                optimizer_fE.step()

                if idx % 50 == 0:
                    save_image(uw_img.cpu().data, './results/uw_img.png')
                    save_image(fI_out.cpu().data, './results/fI_out.png')
                    save_image(cl_img.cpu().data, './results/cl_img.png')

            if idx % 50 == 0:
                print (progress)
                write_to_log(log_file_path, progress)

        # Advance the cycle counter, wrapping back to 1 after 30.
        if i >= 30:
            i = 1
        else :
            i = i+1

        # Save models
        torch.save(fE.state_dict(), './checkpoints/{}/fE_latest.pth'.format(name))
        torch.save(fI.state_dict(), './checkpoints/{}/fI_latest.pth'.format(name))
        if not no_adv_loss:
            torch.save(fN.state_dict(), './checkpoints/{}/fN_latest.pth'.format(name))

        if epoch % save_interval == 0:
            torch.save(fE.state_dict(), './checkpoints/{}/fE_{}.pth'.format(name, epoch))
            torch.save(fI.state_dict(), './checkpoints/{}/fI_{}.pth'.format(name, epoch))
            if not no_adv_loss:
                torch.save(fN.state_dict(), './checkpoints/{}/fN_{}.pth'.format(name, epoch))

        status = 'End of epoch. Models saved.'
        print (status)
        write_to_log(log_file_path, status)

        if not no_adv_loss:
            # Compute the cross validation scores after the epoch
            fI_val_ssim, _, _, fN_val_acc = compute_val_metrics(fE, fI, fN, val_dataloader, no_adv_loss)


In [None]:
if __name__ == "__main__":
    # Make sure the output folders exist before training starts.
    for output_dir in ('./results', './checkpoints'):
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)
    main()