## READ BEFORE USING

### file structure

```
├─data
│  ├─test_data
│  │  ├─DICM
│  │  └─LIME
│  └─train_data
└─snapshots
```

### experimental facility

- If a GPU is available, uncomment the statements marked "for gpu users" and comment out the statements marked "for cpu users".
- Otherwise (CPU only), no changes are needed.

### python environment setting

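- python 3.7
- torch - pip install torch
- torchvision - pip install torchvision
- numpy, pillow, opencv-python (also imported by the notebook) - pip install numpy pillow opencv-python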

### experimental parameter settings (editable spots are marked with "param edit" comments in the code)

- paths
    - training data path - default is 'L:\\teng\\Documents\\zdce\\data\\train_data\\', edit in the 2nd cell of Sec.data_loader
- optimizer
    - weight decay - default is 0.0001, edit in the 1st cell of Sec.training_function
    - learning rate - default is 0.0001, edit in the 1st cell of Sec.training_function
    - grad_clip_norm - default is 0.1, edit in the 1st cell of Sec.training_function
- training
    - epochs - default is 200, edit in the 1st cell of Sec.training_function
    - batch_size - default is 1, edit in the 1st cell of Sec.training_function
    - num_workers (DataLoader worker processes) - default is 0, edit in the 1st cell of Sec.training_function
    - loss display interval (iterations) - default is 50, edit in the 1st cell of Sec.training_function
    - checkpoint save interval (iterations) - default is 100, edit in the 1st cell of Sec.training_function
    - checkpoint save folder path - default is "snapshots/", edit in the 1st cell of Sec.training_function

In [32]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.vgg import vgg16
import torch.utils.data as data
from torchvision import transforms

import numpy as np
from PIL import Image

import os
import sys
import glob
import math
import random
import cv2
import time

# Seed Python's `random` module; this fixes the `random.shuffle` used by
# populate_train_list. It does NOT seed torch (network weight init, DataLoader
# shuffling), which use torch's own RNG.
random.seed(1143)

## Loss

In [33]:
class L_color(nn.Module):
    """Colour-constancy loss: penalises disagreement between channel means.

    For an input of shape (B, 3, H, W) the per-image mean of each channel is
    taken, and a (B, 1, 1, 1) tensor is returned holding
    ``sqrt(Drg**2 + Drb**2 + Dgb**2)``, where each ``D`` is itself the
    *squared* difference of two channel means.
    """

    def __init__(self):
        super(L_color, self).__init__()

    def forward(self, x):
        _batch, _chans, _height, _width = x.shape  # unpacking rejects non-4-D input
        channel_means = x.mean([2, 3], keepdim=True)
        mean_r, mean_g, mean_b = channel_means.split(1, dim=1)

        d_rg = (mean_r - mean_g).pow(2)
        d_rb = (mean_r - mean_b).pow(2)
        d_gb = (mean_b - mean_g).pow(2)

        return (d_rg.pow(2) + d_rb.pow(2) + d_gb.pow(2)).pow(0.5)


class L_spa(nn.Module):
    """Spatial-consistency loss between an original and an enhanced image.

    Both images are averaged over channels and then over 4x4 patches
    (``nn.AvgPool2d(4)``). For every pooled pixel the difference to its left,
    right, upper and lower neighbour is taken with fixed 3x3 kernels (zero
    padding). The loss map is the sum over the four directions of the squared
    difference between the original's and the enhanced image's neighbour
    differences, so the result is symmetric in ``org`` and ``enhance``.

    Args (forward):
        org: original image batch, assumed (B, C, H, W).
        enhance: enhanced image batch of the same shape.

    Returns:
        An un-reduced map of shape (B, 1, H // 4, W // 4).
    """

    def __init__(self):
        super(L_spa, self).__init__()
        # Each kernel computes "centre pixel minus one neighbour"; shape (1, 1, 3, 3).
        kernel_left = torch.FloatTensor([[0, 0, 0], [-1, 1, 0], [0, 0, 0]]).unsqueeze(0).unsqueeze(0)
        kernel_right = torch.FloatTensor([[0, 0, 0], [0, 1, -1], [0, 0, 0]]).unsqueeze(0).unsqueeze(0)
        kernel_up = torch.FloatTensor([[0, -1, 0], [0, 1, 0], [0, 0, 0]]).unsqueeze(0).unsqueeze(0)
        kernel_down = torch.FloatTensor([[0, 0, 0], [0, 1, 0], [0, -1, 0]]).unsqueeze(0).unsqueeze(0)

        # Frozen parameters (kept as Parameters so state_dict/parameters() are unchanged).
        self.weight_left = nn.Parameter(data=kernel_left, requires_grad=False)
        self.weight_right = nn.Parameter(data=kernel_right, requires_grad=False)
        self.weight_up = nn.Parameter(data=kernel_up, requires_grad=False)
        self.weight_down = nn.Parameter(data=kernel_down, requires_grad=False)
        self.pool = nn.AvgPool2d(4)

    def forward(self, org, enhance):
        org_pool = self.pool(torch.mean(org, 1, keepdim=True))
        enhance_pool = self.pool(torch.mean(enhance, 1, keepdim=True))

        E = 0
        for kernel in (self.weight_left, self.weight_right, self.weight_up, self.weight_down):
            # Follow the input's device/dtype so the same code runs on CPU and GPU
            # without hand-editing (.to is a no-op when they already match).
            kernel = kernel.to(device=org_pool.device, dtype=org_pool.dtype)
            D_org = F.conv2d(org_pool, kernel, padding=1)
            D_enhance = F.conv2d(enhance_pool, kernel, padding=1)
            E = E + torch.pow(D_org - D_enhance, 2)

        return E


class L_exp(nn.Module):
    """Exposure-control loss.

    The input is averaged over channels, then over non-overlapping
    ``patch_size`` x ``patch_size`` patches (``nn.AvgPool2d`` uses
    stride = kernel size), and the mean squared distance of those patch means
    from ``mean_val`` is returned as a scalar tensor.

    Args:
        patch_size: kernel size (and stride) of the average pooling.
        mean_val: target value the patch means are pulled towards.
    """

    def __init__(self, patch_size, mean_val):
        super(L_exp, self).__init__()
        self.pool = nn.AvgPool2d(patch_size)
        self.mean_val = mean_val

    def forward(self, x):
        x = torch.mean(x, 1, keepdim=True)
        mean = self.pool(x)
        # Subtract a plain scalar instead of building a CPU FloatTensor on every
        # call: the result stays on x's device, so this works on CPU and GPU alike.
        d = torch.mean(torch.pow(mean - self.mean_val, 2))

        return d
        

class L_TV(nn.Module):
    """Total-variation smoothness penalty.

    For a 4-D input ``x`` of shape (B, C, H, W) this returns
    ``TVLoss_weight * 2 * (sum(dy**2) / ((H-1)*W) + sum(dx**2) / (H*(W-1))) / B``
    where ``dy``/``dx`` are differences between vertically/horizontally
    adjacent elements, summed over the whole batch and every channel.
    """

    def __init__(self, TVLoss_weight=1):
        super(L_TV, self).__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self, x):
        n_batch = x.size(0)
        height, width = x.size(2), x.size(3)

        vertical_diff = x[:, :, 1:, :] - x[:, :, :height - 1, :]
        horizontal_diff = x[:, :, :, 1:] - x[:, :, :, :width - 1]
        vertical_tv = vertical_diff.pow(2).sum()
        horizontal_tv = horizontal_diff.pow(2).sum()

        # Normalise each term by the number of adjacent pairs it summed over.
        n_vertical = (height - 1) * width
        n_horizontal = height * (width - 1)

        return self.TVLoss_weight * 2 * (vertical_tv / n_vertical + horizontal_tv / n_horizontal) / n_batch
    

class Sa_Loss(nn.Module):
    def __init__(self):
        super(Sa_Loss, self).__init__()
        

    def forward(self, x ):
        b,c,h,w = x.shape
        r,g,b = torch.split(x , 1, dim=1)
        mean_rgb = torch.mean(x,[2,3],keepdim=True)
        mr,mg, mb = torch.split(mean_rgb, 1, dim=1)
        Dr = r-mr
        Dg = g-mg
        Db = b-mb
        k =torch.pow( torch.pow(Dr,2) + torch.pow(Db,2) + torch.pow(Dg,2),0.5)
        k = torch.mean(k)
        
        return k


class perception_loss(nn.Module):
    """Frozen feature extractor built from torchvision's pretrained VGG-16.

    Layers ``features[0:23]`` are split into four sequential stages
    (indices 0-3, 4-8, 9-15, 16-22) named ``to_relu_1_2`` ... ``to_relu_4_3``.
    All parameters have ``requires_grad`` disabled. ``forward`` runs the four
    stages in order and returns only the output of the last one.
    """

    def __init__(self):
        super(perception_loss, self).__init__()
        vgg_layers = vgg16(pretrained=True).features
        self.to_relu_1_2 = nn.Sequential()
        self.to_relu_2_2 = nn.Sequential()
        self.to_relu_3_3 = nn.Sequential()
        self.to_relu_4_3 = nn.Sequential()

        # (stage, first layer index, one-past-last layer index); sub-modules keep
        # their VGG index as name so state_dict keys stay the same.
        stage_plan = (
            (self.to_relu_1_2, 0, 4),
            (self.to_relu_2_2, 4, 9),
            (self.to_relu_3_3, 9, 16),
            (self.to_relu_4_3, 16, 23),
        )
        for stage, start, stop in stage_plan:
            for layer_idx in range(start, stop):
                stage.add_module(str(layer_idx), vgg_layers[layer_idx])

        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, x):
        features = x
        for stage in (self.to_relu_1_2, self.to_relu_2_2, self.to_relu_3_3, self.to_relu_4_3):
            features = stage(features)
        return features

## model

In [34]:
class enhance_net_nopool(nn.Module):
    """Curve-estimation network.

    Seven 3x3 convolutions (stride 1, padding 1) with skip concatenations
    predict 24 channels. After ``tanh`` (range [-1, 1]) these are split into
    eight 3-channel maps ``r_i``, and the input image is refined eight times
    with ``x <- x + r_i * (x**2 - x)``.

    Returns a tuple ``(image after 4 steps, image after 8 steps, the 24
    curve-map channels concatenated)``. ``maxpool`` and ``upsample`` are
    registered but not used by ``forward``.
    """

    def __init__(self):
        super(enhance_net_nopool, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        width = 32
        self.e_conv1 = nn.Conv2d(3, width, kernel_size=3, stride=1, padding=1, bias=True)
        self.e_conv2 = nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1, bias=True)
        self.e_conv3 = nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1, bias=True)
        self.e_conv4 = nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1, bias=True)
        # The last three layers consume two concatenated feature maps.
        self.e_conv5 = nn.Conv2d(width * 2, width, kernel_size=3, stride=1, padding=1, bias=True)
        self.e_conv6 = nn.Conv2d(width * 2, width, kernel_size=3, stride=1, padding=1, bias=True)
        self.e_conv7 = nn.Conv2d(width * 2, 24, kernel_size=3, stride=1, padding=1, bias=True)
        self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)
        self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x):
        feat1 = self.relu(self.e_conv1(x))
        feat2 = self.relu(self.e_conv2(feat1))
        feat3 = self.relu(self.e_conv3(feat2))
        feat4 = self.relu(self.e_conv4(feat3))
        feat5 = self.relu(self.e_conv5(torch.cat([feat3, feat4], 1)))
        feat6 = self.relu(self.e_conv6(torch.cat([feat2, feat5], 1)))
        curve_maps = torch.tanh(self.e_conv7(torch.cat([feat1, feat6], 1)))

        alphas = torch.split(curve_maps, 3, dim=1)  # eight 3-channel maps
        image = x
        halfway = None
        for step, alpha in enumerate(alphas, start=1):
            image = image + alpha * (torch.pow(image, 2) - image)
            if step == 4:
                halfway = image

        return halfway, image, torch.cat(alphas, 1)


## data loader

In [35]:
def populate_train_list(lowlight_images_path):
    """Return the ``*.jpg`` files directly inside ``lowlight_images_path``, shuffled.

    The directory may be given with or without a trailing path separator
    (the pattern is built with ``os.path.join`` rather than string
    concatenation). The glob result is sorted before shuffling because glob
    order depends on the filesystem; sorting makes the shuffle repeatable
    under a fixed ``random.seed``.
    """
    pattern = os.path.join(lowlight_images_path, "*.jpg")
    train_list = sorted(glob.glob(pattern))
    random.shuffle(train_list)

    return train_list


class lowlight_loader(data.Dataset):
    """Dataset over the ``*.jpg`` images found by ``populate_train_list``.

    Each item is a single float32 tensor of shape (3, 256, 256) with values in
    [0, 1] (the image resized to ``self.size`` x ``self.size``); no target is
    returned.
    """

    def __init__(self, lowlight_images_path):
        self.train_list = populate_train_list(lowlight_images_path)
        self.size = 256
        self.data_list = self.train_list
        print("Total training examples:", len(self.train_list))

    def __getitem__(self, index):
        data_lowlight_path = self.data_list[index]
        # The context manager closes the file handle Image.open keeps open lazily.
        # convert("RGB") forces 3 channels, so grayscale/RGBA/CMYK files cannot
        # break the HWC -> CHW permute below or yield a wrong channel count.
        with Image.open(data_lowlight_path) as img:
            img = img.convert("RGB")
        # Image.ANTIALIAS was removed in Pillow 10; it was an alias of LANCZOS.
        img = img.resize((self.size, self.size), Image.LANCZOS)
        data_lowlight = np.asarray(img) / 255.0
        data_lowlight = torch.from_numpy(data_lowlight).float()

        return data_lowlight.permute(2, 0, 1)

    def __len__(self):
        return len(self.data_list)

In [36]:
# ----------param edit start----------
# training dataset path: directory holding the *.jpg training images; train()
# passes it to lowlight_loader. Keep the trailing separator, since the loader
# builds its "*.jpg" glob pattern from this string. Adjust for your machine.
train_data_pth = 'L:\\teng\\Documents\\zdce\\data\\train_data\\'
# ----------param edit end----------

## training function

In [37]:
def weights_init(m):
    """Re-initialise one module in place; meant to be used via ``nn.Module.apply``.

    * class name contains ``Conv``      -> weight ~ N(0, 0.02), bias untouched
    * class name contains ``BatchNorm`` -> weight ~ N(1, 0.02), bias = 0
    * anything else is left as is.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
        

def train(
    load_from_pretrained=False,
    pretrain_dir=None,
    *,
    # ----------param edit start----------
    lr=0.0001,                      # Adam learning rate
    weight_decay=0.0001,            # Adam weight decay
    num_epochs=200,                 # number of training epochs
    batch_size=1,                   # DataLoader batch size
    num_workers=0,                  # DataLoader worker processes
    grad_clip_norm=0.1,             # max total gradient norm (clip_grad_norm_)
    display_iter=50,                # print the loss every N iterations
    snapshot_iter=100,              # save a checkpoint every N iterations
    snapshots_folder="snapshots/",  # checkpoint output folder
    # ----------param edit end----------
):
    """Train ``enhance_net_nopool`` on the images under ``train_data_pth``.

    The network is initialised with ``weights_init`` (optionally overwritten
    from a checkpoint) and optimised with Adam on
    ``200 * TV(curve maps) + spatial + 5 * colour + 10 * exposure``.
    Every ``snapshot_iter`` iterations the state dict is saved to
    ``<snapshots_folder>/Epoch<epoch>.pth``; later saves in the same epoch
    overwrite earlier ones.

    Args:
        load_from_pretrained: if True, load weights from ``pretrain_dir``.
        pretrain_dir: path of a saved state dict; required when
            ``load_from_pretrained`` is True.
        The keyword-only arguments are the hyperparameters listed in the
        signature; their defaults match the previously hard-coded values.

    Raises:
        ValueError: ``load_from_pretrained`` is True but ``pretrain_dir`` is None.
    """
    # # for gpu users
    # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    # DCE_net = enhance_net_nopool().cuda()
    # for cpu users
    DCE_net = enhance_net_nopool()

    DCE_net.apply(weights_init)
    if load_from_pretrained:
        # Previously read `config.pretrain_dir`, but no `config` exists (NameError).
        if pretrain_dir is None:
            raise ValueError("pretrain_dir is required when load_from_pretrained=True")
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
        DCE_net.load_state_dict(torch.load(pretrain_dir, map_location="cpu"))

    train_dataset = lowlight_loader(train_data_pth)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )

    LL_color = L_color()
    LL_spa = L_spa()
    LL_exp = L_exp(16, 0.6)
    LL_TV = L_TV()

    optimizer = torch.optim.Adam(DCE_net.parameters(), lr=lr, weight_decay=weight_decay)

    # torch.save does not create missing directories.
    if snapshots_folder:
        os.makedirs(snapshots_folder, exist_ok=True)

    DCE_net.train()
    for epoch in range(num_epochs):
        for iteration, img_lowlight in enumerate(train_loader):
            # # for gpu users
            # img_lowlight = img_lowlight.cuda()
            enhanced_image_1, enhanced_image, A = DCE_net(img_lowlight)

            Loss_TV = 200 * LL_TV(A)
            # L_spa.forward takes (org, enhance); the term is symmetric, so this
            # ordering only aligns the call with the signature.
            loss_spa = torch.mean(LL_spa(img_lowlight, enhanced_image))
            loss_col = 5 * torch.mean(LL_color(enhanced_image))
            loss_exp = 10 * torch.mean(LL_exp(enhanced_image))
            loss = Loss_TV + loss_spa + loss_col + loss_exp

            optimizer.zero_grad()
            loss.backward()
            # clip_grad_norm (no underscore) is deprecated; clip_grad_norm_ is the supported API.
            torch.nn.utils.clip_grad_norm_(DCE_net.parameters(), grad_clip_norm)
            optimizer.step()

            if (iteration + 1) % display_iter == 0:
                print("Loss at iteration", iteration + 1, ":", loss.item())
            if (iteration + 1) % snapshot_iter == 0:
                torch.save(
                    DCE_net.state_dict(),
                    os.path.join(snapshots_folder, "Epoch" + str(epoch) + '.pth'),
                )

## training processing

In [None]:
# Launch training with train()'s default settings; this call returns only after
# every epoch has run.
train()

Total training examples: 2002


