# Network Training

## Includes

In [None]:
# bulk imports
import os, sys, warnings
import ipdb
import torch as t
import torchvision as tv
import torchnet as tnt
from tqdm.notebook import tqdm

# make the working directory and every directory below it importable
paths = [dirpath for dirpath, _, _ in os.walk('.')]
sys.path.extend(paths)

from ipynb.fs.full.config import mainConf
from ipynb.fs.full.monitor import Visualizer
from ipynb.fs.full.network import r2rNet, gainEst
from ipynb.fs.full.dataLoader import fivekNight, valSet
from ipynb.fs.full.util import *

## Initialization

In [None]:
# for debugging only:
#   - `%pdb off` is an IPython magic (not plain Python) that turns off the
#     automatic post-mortem debugger on uncaught exceptions
#   - the filter below silences every Python warning for the whole session
%pdb off
warnings.filterwarnings('ignore')

# expose only GPU 0 (enumerated in PCI bus order); fall back to the CPU
# when CUDA is not available
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": '0',
})
if t.cuda.is_available():
    device = t.device('cuda')
else:
    device = t.device('cpu')

# configuration shared by the datasets, the optimizer and the training loop
opt = mainConf()

# converter network: weights are read from './saves' and the network is
# only used in evaluation mode
converter = r2rNet().to(device)
converter.load('./saves')
converter.eval()

# gain estimator: the network that is optimized in this notebook; resume
# from a saved state when a save root is configured
gain_est_model = gainEst().to(device)
if opt.save_root:
    _ = gain_est_model.load(opt.save_root)

# training data: shuffled mini-batches
train_dataset = fivekNight(opt)
train_loader = t.utils.data.DataLoader(
    train_dataset,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.num_workers,
    pin_memory=True,
)

# validation data: DataLoader defaults (one sample per batch, no shuffling)
val_dataset = valSet(opt)
val_loader = t.utils.data.DataLoader(val_dataset)

# loss functions and optimizer
bce_loss = t.nn.BCEWithLogitsLoss()
l2_loss = t.nn.MSELoss()
gain_est_optim = t.optim.Adam(gain_est_model.parameters(), lr=opt.lr)

# monitoring: visualizer client plus a running average of the training loss
vis = Visualizer(env='deepSelfie(gainEst)', port=8686)
gain_est_meter = tnt.meter.AverageValueMeter()

## Validation

In [None]:
def validate():
    """Evaluate the gain estimator on the validation loader.

    Runs ``gain_est_model`` in evaluation mode, without gradient tracking,
    over every batch yielded by ``val_loader`` and averages two mean
    absolute errors over the batches:

    * mask error: ``|sigmoid(pred_mask) - seg_mask|``
    * amp error: ``|pred_amp - amp / opt.amp_range[1]|``

    The model is always put back into training mode before returning, even
    when validation raises.

    Returns:
        tuple: ``(mask_error, amp_error)``. Both are scalar tensors, or
        ``float('nan')`` when the loader yields no batch at all.
    """
    # set to evaluation mode
    gain_est_model.eval()

    mask_error = 0.0
    amp_error = 0.0
    num_batches = 0
    try:
        # no gradient is needed anywhere during validation
        with t.no_grad():
            for (_, thumb_img, struct_img, seg_mask, amp, _, _,
                 _) in val_loader:
                # copy to device
                thumb_img = thumb_img.to(device)
                struct_img = struct_img.to(device)
                seg_mask = seg_mask.to(device)
                amp = amp.to(device)

                # inference
                pred_mask, pred_amp = gain_est_model(thumb_img, struct_img)

                # accumulate mean absolute errors; pred_mask holds logits,
                # and the amplitude target is scaled by opt.amp_range[1]
                mask_error += t.mean(t.abs(t.sigmoid(pred_mask) - seg_mask))
                amp_error += t.mean(
                    t.abs(pred_amp - amp / opt.amp_range[1]))
                num_batches += 1
    finally:
        # set back to training mode, even if validation failed
        gain_est_model.train(mode=True)

    # an empty loader has no meaningful error; report NaN instead of
    # failing with a division by zero
    if num_batches == 0:
        return float('nan'), float('nan')

    return mask_error / num_batches, amp_error / num_batches

## Training entry

In [None]:
# Train the gain estimator for two epochs. Only gain_est_model is optimized;
# the converter is just used to turn each synthetic batch into a training
# sample.
for epoch in range(2):
    # reset meter and gradient
    gain_est_meter.reset()
    gain_est_optim.zero_grad()

    for index, (syth_img, syth_mask) in tqdm(enumerate(train_loader),
                                             desc='epoch %d' % epoch,
                                             total=len(train_loader)):
        # copy to device
        syth_img = syth_img.to(device)
        syth_mask = syth_mask.to(device)

        # convert to training sample; the converter's parameters are not
        # handed to the optimizer, so building an autograd graph through it
        # would only waste memory and accumulate unused gradients
        with t.no_grad():
            thumb_img, struct_img, seg_mask, amp, _, _, _ = toRaw(
                converter, syth_img, syth_mask, opt)

        # inference
        pred_mask, pred_amp = gain_est_model(thumb_img, struct_img)

        # compute loss: BCE on the mask logits plus MSE on the amplitude,
        # with the amplitude target scaled by opt.amp_range[1]
        gain_est_loss = bce_loss(pred_mask, seg_mask) + l2_loss(
            pred_amp, amp / opt.amp_range[1])

        # compute gradient
        gain_est_loss.backward()

        # update parameter and reset gradient
        gain_est_optim.step()
        gain_est_optim.zero_grad()

        # add to loss meter for logging
        gain_est_meter.add(gain_est_loss.item())

        # show intermediate result
        if (index + 1) % opt.plot_freq == 0:
            vis.plot('loss (gain est)', gain_est_meter.value()[0])
            # ground truth and prediction side by side for the first sample;
            # detach so the plotted tensor carries no autograd history
            gain_est_plot = t.cat(
                [seg_mask, t.sigmoid(pred_mask.detach())],
                dim=-1)[0, 0, :, :]
            vis.img('gain est mask gt/pred', gain_est_plot.cpu() * 255)

        # save model and report validation errors
        if (index + 1) % opt.save_freq == 0:
            gain_est_model.save()
            mask_error, amp_error = validate()
            vis.log('epoch: %d, err(mask/amp): %.4f, %.4f' %
                    (epoch, mask_error, amp_error))