# Network Training

## Imports

In [None]:
# mass includes
import os, sys, warnings
import torch as t
import torchnet as tnt
import random
from tqdm.notebook import tqdm

# Make modules in any sub-folder of the working directory importable by the
# cells below; every walked path containing 'evals' is left out.
paths = [folder for folder, _subdirs, _files in os.walk('.')
         if 'evals' not in folder]
sys.path.extend(paths)

from ipynb.fs.full.config import Config
from ipynb.fs.full.monitor import Visualizer
from ipynb.fs.full.network import *
from ipynb.fs.full.dataLoader import *
from ipynb.fs.full.util import *

## Initialization

In [None]:
# numpy is seeded below; import it explicitly instead of depending on one of
# the `from ipynb.fs.full.* import *` star imports happening to re-export it
import numpy as np

# load config
opt = Config()

# setup random environment: torch, python `random` and numpy share one seed
if opt.seed is None:
    # t.seed() draws a fresh non-deterministic seed; it is printed below so
    # the run can be reproduced by putting it into opt.seed
    seed = t.seed()
else:
    seed = opt.seed
t.manual_seed(seed)
# np.random.seed only accepts values in [0, 2**32) while t.seed() may return
# a 64-bit value, so fold the seed into 32 bits (used for both generators)
random.seed(seed % 2**32)
np.random.seed(seed % 2**32)
print('Use seed', seed)

# define models: encoder (local features), summarizer (global features) and
# the mutual-information discriminator, all moved to the configured device
net_enc = Encoder().to(opt.device)
net_sum = Summarizer().to(opt.device)
net_mi_disc = MIDiscriminator().to(opt.device)

# dataset for training
train_dataset = ImageSet(opt, mode='train')
# pinned host memory only speeds up host->CUDA copies; skip it otherwise
use_pinned_memory = str(opt.device).startswith('cuda')
train_loader = t.utils.data.DataLoader(train_dataset,
                                       batch_size=opt.batch_size,
                                       shuffle=True,
                                       pin_memory=use_pinned_memory,
                                       num_workers=opt.num_workers)

# optimizers: encoder + summarizer share one Adam optimizer, the
# discriminator gets its own; both use opt.lr
enc_params = list(net_enc.parameters()) + list(net_sum.parameters())
enc_optim = t.optim.Adam(enc_params, lr=opt.lr)
disc_params = list(net_mi_disc.parameters())
disc_optim = t.optim.Adam(disc_params, lr=opt.lr)

# loss function (l2_penalty disabled)
dim_loss = DIMLoss(l2_penalty=0.0)

# visualizer (visdom on port 8000) and a running-average meter for the loss
vis = Visualizer(env='InfoMax', port=8000)
meter = tnt.meter.AverageValueMeter()

## Training loop (monitor progress in Visdom at [http://localhost:8000](http://localhost:8000))

In [None]:
# Lowest epoch-mean loss seen so far. Starting at +inf guarantees the first
# epoch is checkpointed whatever the loss scale (a fixed 1e3 could block it).
min_loss = float('inf')
for epoch in tqdm(range(opt.max_epoch), desc='epoch', total=opt.max_epoch):
    # reset loss meter so it averages over this epoch only
    meter.reset()

    imgs = None
    last_loss = None
    for imgs in train_loader:
        # copy to opt.device
        imgs = imgs.to(opt.device)

        # reset gradient
        enc_optim.zero_grad()
        disc_optim.zero_grad()

        # extract local encodings
        pred_loc_feats = net_enc(imgs)

        # extract global encodings
        pred_glb_feats = net_sum(pred_loc_feats)

        # compute discriminative encodings
        disc_loc_feats, disc_glb_feats = net_mi_disc(pred_loc_feats,
                                                     pred_glb_feats)

        # update loss
        loss = dim_loss(disc_loc_feats, disc_glb_feats)

        # run bp and update params: both optimizers step on the same loss
        loss.backward()
        enc_optim.step()
        disc_optim.step()

        # add to loss meter for logging (one .item() host sync per batch)
        last_loss = loss.item()
        meter.add(last_loss)

    if last_loss is None:
        # previously this surfaced as a NameError on `loss` further down
        raise RuntimeError('train_loader yielded no batches; check the '
                           'training dataset')

    # Select checkpoints on the epoch-mean loss: the last batch's loss is a
    # single noisy sample and made checkpoint choice essentially random.
    epoch_loss = meter.value()[0]
    if epoch_loss < min_loss:
        # NOTE(review): only encoder and summarizer are saved, not the
        # discriminator or optimizer states
        net_enc.save()
        net_sum.save()
        min_loss = epoch_loss

    # show training status: first image of the last batch, min-max scaled
    disp_img = imgs[0]
    min_val = t.min(disp_img)
    max_val = t.max(disp_img)
    # guard the 0/0 (NaN) that a constant image would otherwise produce
    value_span = max((max_val - min_val).item(), 1e-12)
    disp_img = (disp_img - min_val) / value_span
    vis.img('input', disp_img.cpu() * 255)
    vis.plot('loss', epoch_loss)
    vis.log('epoch: %d, cur loss: %.2f, min loss: %.2f' %
            (epoch, epoch_loss, min_loss))