# Network Pre-training

## Includes

In [None]:
# mass includes
import os, sys, warnings
import ipdb
import torch as t
import torchnet as tnt
from tqdm.notebook import tqdm

# make every sub-folder importable (walked recursively from the current
# directory), skipping any directory whose path contains 'evals'
paths = [
    root
    for root, _dirs, _files in os.walk('.')
    if 'evals' not in root
]
sys.path.extend(paths)

from ipynb.fs.full.config import Config
from ipynb.fs.full.monitor import Visualizer
from ipynb.fs.full.network import *
from ipynb.fs.full.dataLoader import *
from ipynb.fs.full.util import *

## Initialization

In [None]:
# for debugging only
%pdb off
warnings.filterwarnings('ignore')

# choose GPU if available
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
device = t.device('cuda' if t.cuda.is_available() else 'cpu')

# load configuration
opt = Config('paired')

# dataset for training
train_set = PairedSet(opt)
train_loader = t.utils.data.DataLoader(train_set,
                                       batch_size=opt.batch_size,
                                       shuffle=True,
                                       num_workers=opt.num_workers,
                                       pin_memory=True)

# visualizer
vis = Visualizer(env='paired', port=8686)

## Train generator

In [None]:
# initialize network
net_E = Enhancer(pretrain=False).to(device)
net_E_optim = t.optim.Adam(net_E.parameters(), lr=opt.lr)
# StepLR multiplies the lr by `opt.lr_decay` every `opt.upd_freq[0]` scheduler
# steps; the scheduler is stepped once per epoch, after that epoch's updates
net_E_sched = t.optim.lr_scheduler.StepLR(net_E_optim,
                                          step_size=opt.upd_freq[0],
                                          gamma=opt.lr_decay)
mse_loss = t.nn.MSELoss()
net_E_meter = tnt.meter.AverageValueMeter()

# start training
for epoch in tqdm(range(opt.max_epoch[0]),
                  desc='epoch',
                  total=opt.max_epoch[0]):
    # reset the per-epoch loss meter
    net_E_meter.reset()

    for (in_img, out_img) in train_loader:
        # copy to device
        in_img = in_img.to(device)
        out_img = out_img.to(device)

        # downsample
        in_img = downsize(in_img)
        out_img = downsize(out_img)

        # reset gradient
        net_E_optim.zero_grad()

        # train generator: predict illumination and color coefficients,
        # then apply them in sequence to the input image
        ilm_coes, clr_coes = net_E(in_img, in_img)
        pred_ilm_img = applyIlmCoes(in_img, ilm_coes)
        pred_clr_img = applyClrCoes(pred_ilm_img, clr_coes)

        # convert to grayscale by averaging the three channels
        pred_avg_img = t.mean(pred_ilm_img, dim=1, keepdim=True)
        gt_avg_img = t.mean(in_img, dim=1, keepdim=True)
        # encode the averaged input with the sRGB transfer function
        # (linear segment up to 0.0031308, 1/2.4 power curve above it)
        gt_avg_img = t.where(gt_avg_img <= 0.0031308, 12.92 * gt_avg_img,
                             1.055 * (gt_avg_img**(1 / 2.4)) - 0.055)

        # compute generator loss: mean of the illumination and color terms
        net_E_loss = (mse_loss(pred_avg_img, gt_avg_img) +
                      mse_loss(pred_clr_img, out_img)) / 2

        # update generator
        net_E_loss.backward()
        net_E_optim.step()

        # add to loss meter for logging
        net_E_meter.add(net_E_loss.item())

    # advance the lr schedule once per epoch, AFTER the optimizer updates;
    # since PyTorch 1.1, stepping before optimizer.step() skips the first
    # lr value and shifts every decay one epoch early
    net_E_sched.step()

    # show training status (images come from the epoch's last batch)
    if epoch + 1 > 5:
        vis.plot('loss (net_E)', net_E_meter.value()[0])
        # detach predictions so the displayed tensors carry no autograd graph
        disp_img = t.clamp(t.cat([pred_avg_img.detach(), gt_avg_img], dim=-1),
                           0.0, 1.0)[0, :, :, :]
        vis.img('pred illum/gt', disp_img.cpu() * 255)
        disp_img = t.clamp(t.cat([pred_clr_img.detach(), out_img], dim=-1),
                           0.0, 1.0)[0, :, :, :]
        vis.img('pred color/gt', disp_img.cpu() * 255)

    # save models
    if (epoch + 1) % opt.save_freq == 0:
        net_E.save()

## Train discriminator

In [None]:
# initialize network
net_D = Discriminator().to(device)
net_D_optim = t.optim.Adam(net_D.parameters(), lr=opt.lr)
# StepLR multiplies the lr by `opt.lr_decay` every `opt.upd_freq[1]` scheduler
# steps; the scheduler is stepped once per epoch, after that epoch's updates
net_D_sched = t.optim.lr_scheduler.StepLR(net_D_optim,
                                          step_size=opt.upd_freq[1],
                                          gamma=opt.lr_decay)
bce_loss = t.nn.BCEWithLogitsLoss()
net_D_meter = tnt.meter.AverageValueMeter()

# freeze the generator: eval mode and no gradients for its parameters
net_E.eval()
for param in net_E.parameters():
    param.requires_grad = False

# start training
for epoch in tqdm(range(opt.max_epoch[1]),
                  desc='epoch',
                  total=opt.max_epoch[1]):
    # reset the per-epoch loss meter
    net_D_meter.reset()

    for (in_img, out_img) in train_loader:
        # copy to device
        in_img = in_img.to(device)
        out_img = out_img.to(device)

        # downsample
        in_img = downsize(in_img)
        out_img = downsize(out_img)

        # inference with the frozen generator (illumination branch only)
        with t.no_grad():
            ilm_coes, _ = net_E(in_img, in_img)
            ilm_img = applyIlmCoes(in_img, ilm_coes)
            ilm_img = ilm_img.detach()

        # reset gradient
        net_D_optim.zero_grad()

        # train discriminator: the (ilm_img, out_img) ordering is labelled
        # real, the swapped ordering fake
        pred_real = net_D(ilm_img, out_img)
        pred_fake = net_D(out_img, ilm_img)

        # compute discriminator loss (logits vs. all-ones / all-zeros targets)
        net_D_loss = (bce_loss(pred_real, t.ones_like(pred_real)) +
                      bce_loss(pred_fake, t.zeros_like(pred_fake))) / 2

        # update discriminator
        net_D_loss.backward()
        net_D_optim.step()

        # add to loss meter for logging
        net_D_meter.add(net_D_loss.item())

    # advance the lr schedule once per epoch, AFTER the optimizer updates;
    # since PyTorch 1.1, stepping before optimizer.step() skips the first
    # lr value and shifts every decay one epoch early
    net_D_sched.step()

    # show training status
    if epoch + 1 > 5:
        vis.plot('loss (net_D)', net_D_meter.value()[0])

    # save models
    if (epoch + 1) % opt.save_freq == 0:
        net_D.save()