# Network Training

## Includes

In [None]:
# mass includes
import os, sys, warnings
import ipdb
import torch as t
import torchvision as tv
import torchnet as tnt
from tqdm.notebook import tqdm

# make the working directory and every directory below it importable
paths = [walked[0] for walked in os.walk('.')]
sys.path.extend(paths)

from ipynb.fs.full.config import mainConf
from ipynb.fs.full.monitor import Visualizer
from ipynb.fs.full.network import r2rNet, rawProcess
from ipynb.fs.full.dataLoader import fivekNight, valSet
from ipynb.fs.full.util import *

## Initialization

In [None]:
# for debugging only
# `%pdb off` is an IPython magic (only valid inside IPython/Jupyter): it keeps
# the automatic post-mortem debugger disabled; use `%pdb on` while debugging.
%pdb off
# suppress every Python warning from this point on
warnings.filterwarnings('ignore')

# choose GPU if available
# expose only device 0 to CUDA, with devices numbered by PCI bus order
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": '0',
})
# fall back to the CPU when CUDA is not usable
device = t.device('cuda') if t.cuda.is_available() else t.device('cpu')

# configuration object shared by the datasets, optimizer and training loop
opt = mainConf()

# converter network: its load() is pointed at ./saves, then it is put into
# evaluation mode
converter = r2rNet().to(device)
converter.load('./saves')
converter.eval()

# network whose parameters are optimized in this notebook
raw_process_model = rawProcess().to(device)

# resume from opt.save_root when it is set (the value returned by load() is
# kept as last_epoch); otherwise start from 0
last_epoch = raw_process_model.load(opt.save_root) if opt.save_root else 0

# training data: shuffled batches of opt.batch_size, fetched by
# opt.num_workers worker processes into pinned host memory
train_dataset = fivekNight(opt)
train_loader = t.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.num_workers,
    pin_memory=True,
)

# validation data: DataLoader defaults (one sample per batch, no shuffling)
val_dataset = valSet(opt)
val_loader = t.utils.data.DataLoader(dataset=val_dataset)

# loss function, and an Adam optimizer over raw_process_model's parameters
img_loss = imgLoss(device=device)
raw_process_optim = t.optim.Adam(
    params=raw_process_model.parameters(),
    lr=opt.lr,
)

# monitoring: averaging meter for the training loss and the visualizer client
raw_process_meter = tnt.meter.AverageValueMeter()
vis = Visualizer(env='deepSelfie(rawProcess)', port=8686)

## Validation

In [None]:
def validate():
    """Evaluate ``raw_process_model`` on ``val_loader``.

    For every validation batch two PSNR values, each computed as
    ``10 * log10(1 / MSE)`` against ``syth_img``, are accumulated:

    * ISP: ``pred_high`` and ``pred_low`` blended with channels 0 and 1 of
      ``sorted_mask`` respectively;
    * fuse: the model's own ``pred_fused`` output.

    The model is put into evaluation mode for the duration of the call and is
    always switched back to training mode afterwards, including when
    validation raises or is interrupted (e.g. from the notebook).

    Returns:
        tuple: ``(isp_psnr, fuse_psnr)``, each the sum of the per-batch PSNRs
        divided by ``len(val_loader)``.

    Raises:
        ZeroDivisionError: if ``len(val_loader)`` is 0.
    """
    # set to evaluation mode
    raw_process_model.eval()

    isp_psnr = 0.0
    fuse_psnr = 0.0
    try:
        for (syth_img, _, _, _, amp, noisy_raw, sorted_mask,
             wb) in val_loader:
            with t.no_grad():
                # copy to device
                syth_img = syth_img.to(device)
                amp = amp.to(device)
                noisy_raw = noisy_raw.to(device)
                sorted_mask = sorted_mask.to(device)
                wb = wb.to(device)

                # pre-processing: largest / smallest amp value along dim 1
                amp_high, _ = t.max(amp, 1)
                amp_low, _ = t.min(amp, 1)

                # inference
                pred_high, pred_low, pred_fused = raw_process_model(
                    noisy_raw, amp_high, amp_low, wb)

                # blend the two predictions with their mask channels
                high_mask = sorted_mask[:, 0, :, :].unsqueeze(1)
                low_mask = sorted_mask[:, 1, :, :].unsqueeze(1)
                pred_masked = high_mask * pred_high + low_mask * pred_low

                # compute mse and accumulate the corresponding psnr
                isp_mse = t.nn.functional.mse_loss(pred_masked, syth_img)
                isp_psnr += 10 * t.log10(1 / isp_mse)
                fuse_mse = t.nn.functional.mse_loss(pred_fused, syth_img)
                fuse_psnr += 10 * t.log10(1 / fuse_mse)
    finally:
        # set to training mode, even if validation failed part-way, so a
        # subsequent training step never runs with the model in eval mode
        raw_process_model.train(mode=True)

    # average over the number of validation batches
    isp_psnr /= len(val_loader)
    fuse_psnr /= len(val_loader)

    return isp_psnr, fuse_psnr

## Training entry

In [None]:
# Training entry: one pass over train_loader, optimizing raw_process_model.
# Every opt.plot_freq steps the average loss and a preview image are sent to
# the visualizer; every opt.save_freq steps the model is saved and validated.

# reset meter and gradient
raw_process_meter.reset()
raw_process_optim.zero_grad()

for index, (syth_img, syth_mask) in tqdm(enumerate(train_loader),
                                         desc='progress',
                                         total=len(train_loader)):
    # copy to device
    syth_img = syth_img.to(device)
    syth_mask = syth_mask.to(device)

    # convert to training sample
    _, _, _, amp, noisy_raw, sorted_mask, wb = toRaw(converter, syth_img,
                                                     syth_mask, opt)
    # largest / smallest amp value along dim 1
    amp_high, _ = t.max(amp, 1)
    amp_low, _ = t.min(amp, 1)

    # inference
    pred_high, pred_low, pred_fused = raw_process_model(
        noisy_raw, amp_high, amp_low, wb)

    # compute loss: blend the high/low predictions with mask channels 0 and 1
    pred_masked = sorted_mask[:, 0, :, :].unsqueeze(
        1) * pred_high + sorted_mask[:, 1, :, :].unsqueeze(1) * pred_low
    raw_process_loss = img_loss(pred_masked, pred_fused, syth_img)

    # compute gradient
    raw_process_loss.backward()

    # update parameter and reset gradient
    raw_process_optim.step()
    raw_process_optim.zero_grad()

    # add to loss meter for logging
    raw_process_meter.add(raw_process_loss.item())

    # show intermediate result
    if (index + 1) % opt.plot_freq == 0:
        # first element of the meter's value() (the mean for AverageValueMeter)
        vis.plot('loss (raw process)', raw_process_meter.value()[0])
        # Build the preview without autograd tracking: the predictions still
        # require grad, and raw_process_plot is a module-level name, so a
        # tracked preview would keep graph nodes and their saved tensors alive
        # until the next plot.
        with t.no_grad():
            # ground truth and the three predictions side by side (last dim),
            # clamped to [0, 1], downscaled by 2; only the first sample is kept
            raw_process_plot = t.nn.functional.interpolate(
                t.clamp(
                    t.cat([syth_img, pred_high, pred_low, pred_fused],
                          dim=-1), 0.0, 1.0),
                scale_factor=0.5)[0, :, :, :]
        vis.img('raw process gt/hi/lo/fuse', raw_process_plot.cpu() * 255)

    # save model, then report validation psnr
    if (index + 1) % opt.save_freq == 0:
        raw_process_model.save()
        isp_psnr, fuse_psnr = validate()
        vis.log('psnr(isp/fuse): %.2f, %.2f' % (isp_psnr, fuse_psnr))