# Network Training

## Includes

In [None]:
# mass includes
import os, sys, warnings
import math
import torch as t
import torchnet as tnt
import random
from tqdm.notebook import tqdm

# Make every folder below the working directory importable (os.walk yields
# '.' itself first), skipping any folder whose path contains 'evals'.
paths = []
for root, _dirs, _files in os.walk('.'):
    if 'evals' in root:
        continue
    paths.append(root)
sys.path.extend(paths)

from ipynb.fs.full.config import Config
from ipynb.fs.full.monitor import Visualizer
from ipynb.fs.full.network import *
from ipynb.fs.full.dataLoader import *
from ipynb.fs.full.util import *

## Initialization

In [None]:
# load config
opt = Config()

# numpy is needed for seeding below. Import it explicitly rather than relying
# on `np` leaking in through one of the `from ... import *` lines above.
import numpy as np

# Set up the random environment. Use the configured seed when given;
# otherwise t.seed() seeds torch non-deterministically and returns the seed
# it used. That seed is reused for Python's and numpy's generators, reduced
# modulo 2**32 because numpy's legacy seeder only accepts values below 2**32.
seed = t.seed() if opt.seed is None else opt.seed
t.manual_seed(seed)
random.seed(seed % 2**32)
np.random.seed(seed % 2**32)
print('Use seed', seed)

# Define models: the encoder/summarizer produce local and global features;
# the MI discriminator and the prior discriminator score them. Each network
# is built and moved to opt.device in this order.
net_enc, net_sum, net_mi_disc, net_pr_disc = (
    model_cls().to(opt.device)
    for model_cls in (Encoder, Summarizer, MIDiscriminator, PirorDiscriminator)
)

# Training split (norm, random transforms and mask-out enabled). It is
# reshuffled every epoch and the incomplete trailing batch is dropped.
train_dataset = ImageSet(opt, mode='train', norm=True,
                         rand_trans=True, mask_out=True)
train_loader = t.utils.data.DataLoader(
    train_dataset,
    batch_size=opt.batch_size,
    num_workers=opt.num_workers,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
)

# Test split: same flags as training but without random transforms.
# NOTE(review): batch_size=opt.data_part[1] presumably equals the size of the
# test split (one batch per validation pass) -- confirm in Config.
test_dataset = ImageSet(opt, mode='test', norm=True,
                        rand_trans=False, mask_out=True)
test_loader = t.utils.data.DataLoader(test_dataset,
                                      batch_size=opt.data_part[1])

# Two Adam optimizers sharing opt.lr: one for the encoder side (encoder +
# summarizer), one for both discriminators.
enc_params = [*net_enc.parameters(), *net_sum.parameters()]
enc_optim = t.optim.Adam(enc_params, lr=opt.lr)
disc_params = [*net_mi_disc.parameters(), *net_pr_disc.parameters()]
disc_optim = t.optim.Adam(disc_params, lr=opt.lr)

# Visdom front-end (env 'InfoMax', port 8000) and a running-average meter
# used to log the per-epoch training loss.
vis = Visualizer(env='InfoMax', port=8000)
meter = tnt.meter.AverageValueMeter()

## Validation routine

In [None]:
def validate():
    """Return the mutual-information (DIM) loss of the current model on the test set.

    All four networks are switched to eval mode. Each batch of ``test_loader``
    is split into image (channels 0-2) and mask (channel 3), centre-cropped to
    ``opt.crop_size``, and passed through encoder -> summarizer -> MI
    discriminator without gradients. ``calcDIMLoss`` is then evaluated on each
    batch.

    The returned value is the batch-size-weighted mean of the per-batch losses.
    With a single test batch this equals that batch's loss.
    NOTE(review): the weighting assumes calcDIMLoss returns a per-sample mean
    -- confirm.

    The networks are put back into training mode afterwards, even if
    evaluation raises.

    Returns:
        float: weighted mean of ``calcDIMLoss`` over all test batches.

    Raises:
        RuntimeError: if ``test_loader`` yields no batches.
    """
    nets = (net_enc, net_sum, net_mi_disc, net_pr_disc)

    # set to evaluation mode
    for net in nets:
        net.eval()

    loss_sum = 0.0
    sample_count = 0
    try:
        with t.no_grad():
            for sample, _ in test_loader:
                # copy to device: channels 0-2 are the image, channel 3 the mask
                img = sample[:, :, :3, :, :].to(opt.device)
                mask = sample[:, :, 3, :, :].to(opt.device).unsqueeze(2)

                # reshape for batch processing: fold the n views into the batch
                b, n, _, h, w = sample.size()
                imgs = img.view(b * n, -1, h, w)
                masks = mask.view(b * n, -1, h, w)

                # central cropping
                crop_h = math.floor(h / 2 - opt.crop_size / 2)
                crop_w = math.floor(w / 2 - opt.crop_size / 2)
                imgs = imgs[:, :, crop_h:crop_h + opt.crop_size,
                            crop_w:crop_w + opt.crop_size]
                masks = masks[:, :, crop_h:crop_h + opt.crop_size,
                              crop_w:crop_w + opt.crop_size]

                # inference
                pred_loc_feats, pred_loc_masks = net_enc(imgs, masks)
                pred_glb_feats = net_sum(pred_loc_feats, pred_loc_masks)
                disc_loc_feats, disc_glb_feats, disc_loc_masks = net_mi_disc(
                    pred_loc_feats, pred_glb_feats, pred_loc_masks)

                # decouple batch: (b * n, ...) -> (b, n, ...)
                disc_loc_feats = disc_loc_feats.view(
                    b, n, disc_loc_feats.size(1), disc_loc_feats.size(2))
                disc_glb_feats = disc_glb_feats.view(
                    b, n, disc_glb_feats.size(1), disc_glb_feats.size(2))
                disc_loc_masks = disc_loc_masks.view(
                    b, n, disc_loc_masks.size(1), disc_loc_masks.size(2))

                # compute mutual information and accumulate over the test set
                mi_loss = calcDIMLoss(disc_loc_feats, disc_glb_feats,
                                      disc_loc_masks)
                loss_sum += mi_loss.item() * b
                sample_count += b
    finally:
        # set back to training mode even if evaluation failed
        for net in nets:
            net.train(mode=True)

    if sample_count == 0:
        raise RuntimeError('validate: test_loader produced no batches')
    return loss_sum / sample_count

## Training loop (monitor progress in [visdom](http://localhost:8000))

In [None]:
# Best validation loss seen so far. Starting at +inf guarantees that the first
# epoch writes a checkpoint; a finite sentinel such as 1e3 would silently skip
# saving whenever the validation loss stayed above it.
min_loss = float('inf')
# Weight of the prior-matching terms relative to the MI objective.
PRIOR_WEIGHT = 0.1
# Lower bound applied before taking logs, so a saturated discriminator output
# (exactly 0 or 1) cannot turn the loss into inf/NaN.
LOG_EPS = 1e-8

for epoch in tqdm(range(opt.max_epoch), desc='epoch', total=opt.max_epoch):
    # reset loss meter
    meter.reset()

    for samples, _ in train_loader:
        # copy to device: channels 0-2 are the image, channel 3 the mask
        imgs = samples[:, :, :3, :, :].to(opt.device)
        masks = samples[:, :, 3, :, :].to(opt.device).unsqueeze(2)

        # reshape for batch processing: fold the n views into the batch
        b, n, _, h, w = samples.size()
        imgs = imgs.view(b * n, -1, h, w)
        masks = masks.view(b * n, -1, h, w)

        # reset gradient
        enc_optim.zero_grad()
        disc_optim.zero_grad()

        # local encodings -> global encodings -> MI discriminative encodings
        pred_loc_feats, pred_loc_masks = net_enc(imgs, masks)
        pred_glb_feats = net_sum(pred_loc_feats, pred_loc_masks)
        disc_loc_feats, disc_glb_feats, disc_loc_masks = net_mi_disc(
            pred_loc_feats, pred_glb_feats, pred_loc_masks)

        # decouple batch: (b * n, ...) -> (b, n, ...)
        disc_loc_feats = disc_loc_feats.view(b, n, disc_loc_feats.size(1),
                                             disc_loc_feats.size(2))
        disc_glb_feats = disc_glb_feats.view(b, n, disc_glb_feats.size(1),
                                             disc_glb_feats.size(2))
        disc_loc_masks = disc_loc_masks.view(b, n, disc_loc_masks.size(1),
                                             disc_loc_masks.size(2))

        # MI objective: minimised jointly by encoder, summarizer and MI
        # discriminator (all three receive its gradients below)
        mi_loss = calcDIMLoss(disc_loc_feats, disc_glb_feats, disc_loc_masks)

        # Prior matching, encoder side: the encoder must *fool* the prior
        # discriminator, i.e. push D(global feats) towards 1 (non-saturating
        # generator loss). Minimising -log(1 - D(global feats)) here would
        # instead make the encoder cooperate with the discriminator.
        glb_disc = net_pr_disc(pred_glb_feats)
        gen_loss = -t.log(glb_disc.clamp(min=LOG_EPS)).mean()

        # encoder-side loss
        loss = 1.0 * mi_loss + PRIOR_WEIGHT * gen_loss
        loss.backward()

        # The backward pass above also accumulated generator gradients in the
        # prior discriminator; drop them before its own (opposite) update.
        net_pr_disc.zero_grad()

        # Prior matching, discriminator side: "real" samples come from the
        # uniform prior, "fake" ones are the global features detached from
        # the graph so that this loss does not reach the encoder.
        prior = t.rand_like(pred_glb_feats)
        prior_disc = net_pr_disc(prior)
        fake_disc = net_pr_disc(pred_glb_feats.detach())
        pr_loss = -(t.log(prior_disc.clamp(min=LOG_EPS)).mean() +
                    t.log((1.0 - fake_disc).clamp(min=LOG_EPS)).mean())
        (PRIOR_WEIGHT * pr_loss).backward()

        # update network params
        enc_optim.step()
        disc_optim.step()

        # add the encoder-side loss to the meter for logging
        meter.add(loss.item())

    # save models if needed
    val_loss = validate()
    if val_loss < min_loss:
        net_enc.save()
        net_sum.save()
        min_loss = val_loss

    # show training status (the loss plot skips the first 10% of epochs)
    if epoch + 1 >= 0.1 * opt.max_epoch:
        vis.plot('training loss', meter.value()[0])

    # Min-max normalise the last training image for display; the range is
    # clamped so a constant image does not divide by zero.
    first_img = imgs[0, :, :, :]
    min_val = t.min(first_img)
    val_range = (t.max(first_img) - min_val).clamp(min=1e-8)
    last_img = (first_img - min_val) / val_range
    vis.img('input', last_img)
    vis.img('mask', masks[0, :, :, :])
    vis.log('epoch: %d, cur val loss: %.4f, min val loss: %.4f' %
            (epoch, val_loss, min_loss))