# Denoising Network Training

## Imports

In [None]:
# mass includes
import os, sys, warnings
import ipdb
import torch as t
import torchnet as tnt
from tqdm.notebook import tqdm
from torch.nn.functional import mse_loss

# Put every directory under the working directory on sys.path, skipping any
# whose path contains 'evals' (presumably so the project notebooks imported
# below can be resolved).
paths = list(filter(lambda folder: 'evals' not in folder,
                    (walked[0] for walked in os.walk('.'))))
for item in paths:
    sys.path.append(item)

from ipynb.fs.full.config import Config
from ipynb.fs.full.monitor import Visualizer
from ipynb.fs.full.network import *
from ipynb.fs.full.dataLoader import *
from ipynb.fs.full.util import *

## Initialization

In [None]:
# Camera model identifier shared by the denoiser network and both datasets.
cam_model: str = 'Pixel3'

In [None]:
# for debugging only: IPython magic that turns the automatic post-mortem
# debugger off (use `%pdb on` to drop into pdb when a cell raises)
%pdb off
# NOTE(review): this silences *all* Python warnings for the session, including
# PyTorch UserWarnings such as the lr_scheduler call-order warning; consider
# narrowing the filter.
warnings.filterwarnings('ignore')

# choose GPU if available
# Enumerate GPUs in PCI bus order and expose only device 0. These variables
# only take effect if set before CUDA is first initialised in this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
# falls back to the CPU when no CUDA device is visible
device = t.device('cuda' if t.cuda.is_available() else 'cpu')

# define models
# `opt` supplies the hyper-parameters read in this notebook (batch_size,
# num_workers, lr, upd_freq, lr_decay, max_epoch); presumably Config('denoise')
# selects the denoising preset. Confirm in the config notebook.
opt = Config('denoise')
net_N = Denoiser(cam_model=cam_model).to(device)

# dataset for training
# shuffled mini-batches; pin_memory keeps host batches in page-locked memory
# to speed up the copy to the GPU
train_dataset = DenoiseSet(opt, cam_model=cam_model, mode='train')
train_loader = t.utils.data.DataLoader(train_dataset,
                                       batch_size=opt.batch_size,
                                       shuffle=True,
                                       pin_memory=True,
                                       num_workers=opt.num_workers)

# dataset for validation
# DataLoader defaults apply here: batch_size=1, shuffle=False, num_workers=0
val_dataset = DenoiseSet(opt, cam_model=cam_model, mode='val')
val_loader = t.utils.data.DataLoader(val_dataset)

# optimizer
# Adam at opt.lr; StepLR multiplies the LR by opt.lr_decay every
# opt.upd_freq scheduler steps (the training loop steps it once per epoch).
net_N_optim = t.optim.Adam(net_N.parameters(), lr=opt.lr)
net_N_sched = t.optim.lr_scheduler.StepLR(net_N_optim,
                                          step_size=opt.upd_freq,
                                          gamma=opt.lr_decay)

# loss function
# NOTE: rebinding `mse_loss` shadows the torch.nn.functional.mse_loss imported
# in the first cell; validate() calls whichever object is bound here.
l1_loss = t.nn.L1Loss()
mse_loss = t.nn.MSELoss()

# visualizer
# presumably a visdom-backed monitor ('denoiser' env on port 8686); confirm in
# the monitor notebook. The meter accumulates per-batch losses; its
# value()[0] is plotted as the epoch's mean training loss.
vis = Visualizer(env='denoiser', port=8686)
net_N_meter = tnt.meter.AverageValueMeter()

## Validation

In [None]:
def validate():
    """Evaluate ``net_N`` on ``val_loader`` and return the mean PSNR (dB).

    Each validation batch is processed as in training. First there is an
    initial denoising pass. Then there is one extra pass per channel of
    ``ilm_coes``, in which the network receives a luminance-weighted
    residual computed from its previous output. The clean target gets the
    matching enhancement (clamped to [0, 1]). PSNR is computed between the
    final prediction and the final target with a peak value of 1.0.

    The network is put in eval mode for the duration. Afterwards it is
    restored to the mode it was in before, even if inference raises.

    Relies on the notebook globals ``net_N``, ``val_loader``, ``device``,
    ``mse_loss`` and ``rgb2lumin``.

    Returns:
        float: mean of the per-batch PSNR values.

    Raises:
        RuntimeError: if ``val_loader`` yields no batches.
    """
    # Lower bound on the MSE: a perfect reconstruction would otherwise give an
    # infinite PSNR that swamps the average (1e-10 caps PSNR at 100 dB).
    min_mse = 1e-10

    # remember the current mode so it can be restored in `finally`
    was_training = net_N.training
    net_N.eval()

    psnr_list = []
    try:
        for noisy_raw, clean_raw, noise_map, ilm_coes, cam2xyz in val_loader:
            # copy to device
            noisy_raw = noisy_raw.to(device)
            clean_raw = clean_raw.to(device)
            noise_map = noise_map.to(device)
            ilm_coes = ilm_coes.to(device)
            cam2xyz = cam2xyz.to(device)

            with t.no_grad():
                # initial denoising
                dnoise_raw = net_N(noisy_raw, noise_map)

                # illumination enhancement, one pass per coefficient channel
                for ch in range(ilm_coes.size(1)):
                    coe_slice = ilm_coes[:, ch, :, :].unsqueeze(1)

                    # for denoised image: feed the residual back to the net
                    lmn_img = rgb2lumin(dnoise_raw, cam2xyz)
                    res = coe_slice * (1.0 - lmn_img) * dnoise_raw
                    dnoise_raw = net_N(noisy_raw, noise_map, res)

                    # for clean image: apply the residual directly
                    lmn_img = rgb2lumin(clean_raw, cam2xyz)
                    res = coe_slice * (1.0 - lmn_img) * clean_raw
                    clean_raw = t.clamp(clean_raw + res, 0.0, 1.0)

                # PSNR with a peak signal of 1.0
                mse = t.clamp(mse_loss(dnoise_raw, clean_raw), min=min_mse)
                psnr_list.append(10 * t.log10(1 / mse))
    finally:
        net_N.train(mode=was_training)

    if not psnr_list:
        raise RuntimeError('validate(): val_loader yielded no batches')

    return t.mean(t.stack(psnr_list)).item()

## Training loop

In [None]:
def _raw_to_display(raw, cam2xyz):
    """Convert a camera-space raw batch to gamma-encoded sRGB for display.

    Applies the project's ``cam2sRGB`` colour conversion, then the sRGB
    transfer curve: a linear segment up to 0.0031308 and a 1/2.4 power law
    above it.
    """
    img = cam2sRGB(raw, cam2xyz)
    return t.where(img <= 0.0031308, 12.92 * img,
                   1.055 * (img**(1 / 2.4)) - 0.055)


# best validation PSNR so far; -inf guarantees the first epoch is checkpointed
prev_psnr = float('-inf')
for epoch in tqdm(range(opt.max_epoch), desc='epoch', total=opt.max_epoch):
    # reset meter
    net_N_meter.reset()

    for noisy_raw, clean_raw, noise_map, ilm_coes, cam2xyz in train_loader:
        # copy to device
        noisy_raw = noisy_raw.to(device)
        clean_raw = clean_raw.to(device)
        noise_map = noise_map.to(device)
        ilm_coes = ilm_coes.to(device)
        cam2xyz = cam2xyz.to(device)

        # reset gradient
        net_N_optim.zero_grad()

        # initial denoising
        dnoise_raw = net_N(noisy_raw, noise_map)
        net_N_loss = l1_loss(dnoise_raw, clean_raw)

        # illumination enhancement, one pass per coefficient channel; every
        # pass adds its own L1 term to the loss
        for ch in range(ilm_coes.size(1)):
            coe_slice = ilm_coes[:, ch, :, :].unsqueeze(1)

            # for denoised image: feed the residual back to the net
            lmn_img = rgb2lumin(dnoise_raw, cam2xyz)
            res = coe_slice * (1.0 - lmn_img) * dnoise_raw
            dnoise_raw = net_N(noisy_raw, noise_map, res)

            # for clean image: apply the residual directly
            lmn_img = rgb2lumin(clean_raw, cam2xyz)
            res = coe_slice * (1.0 - lmn_img) * clean_raw
            clean_raw = t.clamp(clean_raw + res, 0.0, 1.0)

            # update loss
            net_N_loss = net_N_loss + l1_loss(dnoise_raw, clean_raw)

        # update network params
        net_N_loss.backward()
        net_N_optim.step()

        # add to loss meter for logging
        net_N_meter.add(net_N_loss.item())

    # Step the LR schedule once per epoch, *after* this epoch's optimizer
    # updates. Since PyTorch 1.1, calling scheduler.step() before
    # optimizer.step() skips the first value of the schedule. The previous
    # placement at the top of the epoch therefore decayed the LR one epoch early.
    net_N_sched.step()

    # show training status for the last training batch; this is display-only
    # work, so no autograd graph is built
    with t.no_grad():
        for ch in range(ilm_coes.size(1)):
            coe_slice = ilm_coes[:, ch, :, :].unsqueeze(1)

            # for noisy image: same enhancement the clean target received
            lmn_img = rgb2lumin(noisy_raw, cam2xyz)
            res = coe_slice * (1.0 - lmn_img) * noisy_raw
            noisy_raw = t.clamp(noisy_raw + res, 0.0, 1.0)

        # noisy | prediction | ground truth, concatenated along dim 3;
        # only the first sample of the batch is shown
        disp_img = t.cat([_raw_to_display(noisy_raw, cam2xyz),
                          _raw_to_display(dnoise_raw, cam2xyz),
                          _raw_to_display(clean_raw, cam2xyz)],
                         dim=3)[0, :, :, :]
    vis.img('noisy/pred/gt', disp_img.cpu() * 255)
    vis.plot('loss(net_N)', net_N_meter.value()[0])

    # perform validation and save models if needed
    cur_psnr = validate()
    if cur_psnr > prev_psnr:
        net_N.save()
        prev_psnr = cur_psnr
        vis.log('epoch: %d, val psnr: %.2f' % (epoch, cur_psnr))