# Network Training

## Includes

In [None]:
# standard-library and third-party imports
import os, sys, warnings
import ipdb
import torch as t
import torchnet as tnt
from tqdm.notebook import tqdm

# Collect the current directory and every directory below it (the first
# element of each os.walk() entry), then append each one to sys.path.
paths = [walked[0] for walked in os.walk('.')]
for item in paths:
    sys.path.append(item)

from ipynb.fs.full.config import r2rNetConf
from ipynb.fs.full.monitor import Visualizer
from ipynb.fs.full.network import r2rNet
from ipynb.fs.full.dataLoader import r2rSet
from ipynb.fs.full.util import *

## Initialization

In [None]:
# for debugging only
%pdb off
warnings.filterwarnings('ignore')

# choose GPU if available
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
device = t.device('cuda' if t.cuda.is_available() else 'cpu')

# define model
opt = r2rNetConf()
model = r2rNet().to(device)

# load pre-trained model if necessary
if opt.save_root:
    last_epoch = model.load(opt.save_root)
    last_epoch += opt.save_epoch
else:
    last_epoch = 0

# dataloader for training
train_dataset = r2rSet(opt, mode='train')
train_loader = t.utils.data.DataLoader(train_dataset,
                                       batch_size=opt.batch_size,
                                       shuffle=True,
                                       num_workers=opt.num_workers,
                                       pin_memory=True)

# dataloader for validation
val_dataset = r2rSet(opt, mode='val')
val_loader = t.utils.data.DataLoader(val_dataset)

# optimizer
last_lr = opt.lr * opt.lr_decay**(last_epoch // opt.upd_freq)
optimizer = t.optim.Adam(model.parameters(), lr=last_lr)
scheduler = t.optim.lr_scheduler.StepLR(optimizer,
                                        step_size=opt.upd_freq,
                                        gamma=opt.lr_decay)

# visualizer
vis = Visualizer(env='r2rNet', port=8686)
loss_meter = tnt.meter.AverageValueMeter()

## Validation

In [None]:
def validate():
    """Evaluate the current model on the validation loader.

    Every batch of ``val_loader`` is moved to ``device``, the sRGB patch is
    converted with ``toRGGB`` and fed to ``model`` together with the camera
    white balance.  The prediction is clamped to [0, 1] and compared to the
    raw patch; the PSNR is computed as ``10 * log10(1 / mse)``.  No gradients
    are tracked.  The model is switched to evaluation mode for the pass and
    back to training mode before returning.

    Returns:
        The sum of the per-batch PSNR values divided by ``len(val_loader)``.
    """
    # set to evaluation mode
    model.eval()

    total_psnr = 0.0
    with t.no_grad():
        for raw_gt, srgb, wb in val_loader:
            # copy to device
            raw_gt = raw_gt.to(device)
            srgb = srgb.to(device)
            rggb = toRGGB(srgb)
            wb = wb.to(device)

            # inference, restricted to the valid [0, 1] range
            prediction = t.clamp(model(rggb, wb), 0.0, 1.0)

            # accumulate the psnr of this batch
            squared_err = (prediction - raw_gt)**2
            total_psnr += 10 * t.log10(1 / t.mean(squared_err))
    mean_psnr = total_psnr / len(val_loader)

    # set to training mode
    model.train(mode=True)

    return mean_psnr

## Training entry

In [None]:
# Main training loop: one pass over train_loader per epoch, minimising the
# mean absolute error between the predicted and the ground-truth raw patch.
# After each epoch the average loss and a gt/prediction image pair are sent
# to the visualizer, and the model is periodically saved and validated.
for epoch in tqdm(range(last_epoch, opt.max_epoch),
                  desc='epoch',
                  total=opt.max_epoch - last_epoch):
    # start every epoch with an empty running-average loss
    loss_meter.reset()

    for (raw_patch, srgb_patch, cam_wb) in train_loader:
        # reset gradient
        optimizer.zero_grad()

        # copy to device
        raw_patch = raw_patch.to(device)
        srgb_patch = srgb_patch.to(device)
        rggb_patch = toRGGB(srgb_patch)
        cam_wb = cam_wb.to(device)

        # inference
        pred_patch = model(rggb_patch, cam_wb)

        # compute loss (L1)
        loss = t.mean(t.abs(pred_patch - raw_patch))

        # backpropagation
        loss.backward()
        optimizer.step()

        # add to loss meter for logging
        loss_meter.add(loss.item())

    # update learning rate. Since PyTorch 1.1.0 the scheduler must be stepped
    # AFTER the optimizer; stepping it at the top of the epoch skips the
    # first value of the schedule, so the decay would kick in one epoch
    # earlier than opt.upd_freq prescribes.
    scheduler.step()

    # show training status, using the last batch of the epoch; the tensors
    # are detached because no gradients are needed for visualization
    vis.plot('loss', loss_meter.value()[0])
    gt_img = raw2Img(raw_patch[0, :, :, :].detach(),
                     wb=opt.d65_wb,
                     cam_matrix=opt.cam_matrix)
    pred_img = raw2Img(pred_patch[0, :, :, :].detach(),
                       wb=opt.d65_wb,
                       cam_matrix=opt.cam_matrix)
    vis.img('gt/pred/mask', t.cat([gt_img, pred_img], dim=2).cpu() * 255)

    # save model and do validation: on every epoch once past opt.save_epoch,
    # and additionally on every 50th epoch
    if (epoch + 1) > opt.save_epoch or (epoch + 1) % 50 == 0:
        model.save()
        psnr = validate()
        vis.log('epoch: %d, psnr: %.2f' % (epoch, psnr))