# Network Training

## Includes

In [None]:
# mass includes
import os, sys, warnings
import ipdb
import math as m
import torch as t
import torchnet as tnt
from tqdm import tqdm_notebook as tqdm

# make the working directory and every folder beneath it importable,
# so the notebook modules under ipynb.fs.full can be resolved
paths = [dirpath for dirpath, _subdirs, _files in os.walk('.')]
sys.path.extend(paths)

from ipynb.fs.full.config import Config
from ipynb.fs.full.module import BasicModule
from ipynb.fs.full.monitor import Visualizer
from ipynb.fs.full.network import UNet
from ipynb.fs.full.dataLoader import sonySet

## Initialization

In [None]:
# enable debugging
%pdb off
warnings.filterwarnings('ignore')

# choose GPU if available
device = t.device('cuda' if t.cuda.is_available() else 'cpu')

# define model
opt = Config()
model = UNet().to(device)

# load pre-trained model if necessary
if opt.save_root:
    model.load(opt.save_root, device=device)

# dataloader for training
train_dataset = sonySet(opt.data_root, 512, opt.img_size, mode='train')
train_loader = t.utils.data.DataLoader(
    train_dataset,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.num_workers,
    pin_memory=True)

# # dataloader for validation
# val_dataset = sonySet(opt.data_root, 512, opt.img_size, mode='val')
# val_loader = t.utils.data.DataLoader(val_dataset)

# optimizer
criterion = t.nn.SmoothL1Loss()
optimizer = t.optim.Adam(model.parameters(), lr=opt.lr)
scheduler = t.optim.lr_scheduler.StepLR(
    optimizer, step_size=opt.upd_freq, gamma=opt.lr_decay)

# visualizer
vis = Visualizer()
loss_meter = tnt.meter.AverageValueMeter()
psnr_meter = tnt.meter.AverageValueMeter()

## Training entry

In [None]:
for epoch in tqdm(range(opt.max_epoch), desc='epoch', total=opt.max_epoch):
    # reset both meters so every log entry reflects this epoch only
    # (previously psnr_meter was never reset and averaged over all epochs)
    loss_meter.reset()
    psnr_meter.reset()

    for index, (img_batch, gt_batch) in enumerate(train_loader):
        # reset gradient
        optimizer.zero_grad()

        # inference
        img_batch = img_batch.to(device)
        gt_batch = gt_batch.to(device)
        pred_batch = model(img_batch)

        # compute loss
        loss = criterion(pred_batch, gt_batch)

        # backpropagation
        loss.backward()
        optimizer.step()

        # add to meters for logging; metrics need no autograd graph
        with t.no_grad():
            loss_meter.add(loss.item())
            mse = t.nn.functional.mse_loss(pred_batch, gt_batch).item()
            # PSNR with a peak value of 1.0; floor the MSE so a perfect
            # batch does not raise on log10(0)
            psnr_meter.add(-10 * m.log10(max(mse, 1e-10)))

    # advance the LR schedule once per epoch, *after* this epoch's
    # optimizer.step() calls (required ordering since PyTorch 1.1);
    # stepping at the start of the epoch skipped the initial LR and
    # applied every decay one epoch early
    scheduler.step()

    # add a new log: the last batch's first prediction next to its target
    # NOTE(review): assumes per-sample tensors are (C, H, W), so dim=2 places
    # prediction and ground truth side by side — confirm
    with t.no_grad():
        img_show = t.cat((pred_batch[0, :, :, :], gt_batch[0, :, :, :]),
                         dim=2).unsqueeze(0)
        img_show = t.clamp(img_show, 0.0, 1.0)
        img_show = t.nn.functional.interpolate(img_show, scale_factor=0.3)
    vis.log('epoch: %d, loss: %.3E' % (epoch, loss_meter.value()[0]))
    vis.plot('PSNR', psnr_meter.value()[0])
    vis.img('pred/gt', img_show.squeeze())

    # save model
    if (epoch + 1) % opt.save_freq == 0:
        model.save()

#     # validation
#     if (epoch + 1) % opt.val_freq == 0:
#         model.eval()
#         val_psnr = 0.0
#         with t.no_grad():
#             for index, (img_batch, gt_batch) in enumerate(val_loader):
#                 # inference
#                 img_batch = img_batch.to(device)
#                 gt_batch = gt_batch.to(device)
#                 pred_batch = model(img_batch)
#
#                 # add to meters for logging
#                 mse = t.nn.functional.mse_loss(pred_batch, gt_batch).item()
#                 val_psnr += -10 * m.log10(max(mse, 1e-10))
#
#         # enumerate() is 0-based, so the number of batches is index + 1
#         val_psnr /= index + 1
#         vis.log('val PSNR: %.3f' % val_psnr)
#         model.train()