# Network Training

## Includes

In [None]:
# standard-library and third-party imports
import os, sys, warnings
import ipdb
import torch as t
import torchvision as tv
import torchnet as tnt
from tqdm import tqdm_notebook as tqdm

# make every directory under the current one importable
paths = [entry[0] for entry in os.walk('.')]
sys.path.extend(paths)

from ipynb.fs.full.config import Config
from ipynb.fs.full.monitor import Visualizer
from ipynb.fs.full.network import MobileNetV2
from ipynb.fs.full.util import *

## Initialization

In [None]:
# Debugging switches for the notebook session.
# `%pdb off` is an IPython magic: it disables the automatic post-mortem
# debugger on uncaught exceptions.
%pdb off
# NOTE(review): this suppresses *all* warnings for the session, including
# deprecation warnings raised by the libraries used below.
warnings.filterwarnings('ignore')

# ImageNet per-channel mean and std; passed to `Normalize` in both the
# training and the validation transform pipelines.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# pick the compute device: CUDA when it is available, CPU otherwise
device_name = 'cuda' if t.cuda.is_available() else 'cpu'
device = t.device(device_name)

# read the run configuration, then build the network and move it to the device
opt = Config()
model = MobileNetV2()
model = model.to(device)

# when the configuration provides a save root, load the model from it
if opt.save_root:
    model.load(opt.save_root)

# training data: images are read from <data_root>/train
train_dir = os.path.join(opt.data_root, 'train')

# augmentation (random resized crop + horizontal flip), then tensor
# conversion and ImageNet normalization
train_transform = tv.transforms.Compose([
    tv.transforms.RandomResizedCrop(opt.img_size, scale=(0.2, 1.0)),
    tv.transforms.RandomHorizontalFlip(),
    tv.transforms.ToTensor(),
    tv.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
train_dataset = tv.datasets.ImageFolder(train_dir, transform=train_transform)

# shuffled mini-batches, sized and parallelized according to the config
train_loader = t.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.num_workers,
    pin_memory=True,
)

# validation data: images are read from <data_root>/val
val_dir = os.path.join(opt.data_root, 'val')

# deterministic preprocessing: resize to img_size / 0.875, center-crop back
# to img_size, then tensor conversion and ImageNet normalization
val_transform = tv.transforms.Compose([
    tv.transforms.Resize(int(opt.img_size / 0.875)),
    tv.transforms.CenterCrop(opt.img_size),
    tv.transforms.ToTensor(),
    tv.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
val_dataset = tv.datasets.ImageFolder(val_dir, transform=val_transform)

# fixed batch size of 16, kept in dataset order (no shuffling)
val_loader = t.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=opt.num_workers,
    pin_memory=True,
)

# loss function used by the training loop
criterion = t.nn.CrossEntropyLoss()

# SGD over all model parameters; hyper-parameters come from the config
sgd_settings = dict(
    lr=opt.lr,
    momentum=opt.momentum,
    weight_decay=opt.weight_decay,
)
optimizer = t.optim.SGD(model.parameters(), **sgd_settings)

# step decay: multiply the lr by opt.lr_decay every opt.upd_freq scheduler steps
scheduler = t.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=opt.upd_freq,
    gamma=opt.lr_decay,
)

# monitoring: a running average of the training loss, plus the Visualizer
# (from the project's monitor module) created on port 8866
loss_meter = tnt.meter.AverageValueMeter()
vis = Visualizer(port=8866)

## Training entry

In [None]:
def validate(net=None, loader=None, dev=None, topk=5):
    """Compute the top-k accuracy (in percent) of a model over a data loader.

    All arguments are optional; when omitted they fall back to the
    module-level ``model``, ``val_loader`` and ``device`` so that the plain
    ``validate()`` call keeps working.

    Args:
        net: network to evaluate; called as ``net(img_batch)`` and expected to
            return one score per class along dim 1.
        loader: iterable yielding ``(img_batch, gnd_batch)`` pairs, where
            ``gnd_batch`` holds one integer class index per sample.
        dev: device the batches are moved to before inference.
        topk: number of highest-scoring classes that count as a hit; clamped
            to the number of classes the network outputs.

    Returns:
        float: percentage of samples whose label is among the top-k
        predictions, or ``0.0`` when the loader yields no samples.

    Side effects:
        Puts ``net`` in evaluation mode while running and switches it to
        training mode before returning (also when an exception is raised).
    """
    net = model if net is None else net
    loader = val_loader if loader is None else loader
    dev = device if dev is None else dev

    correct = 0
    total = 0

    # evaluation mode
    net.eval()
    try:
        # no gradients are needed for inference or for the matching below
        with t.no_grad():
            for img_batch, gnd_batch in loader:
                img_batch = img_batch.to(dev)
                gnd_batch = gnd_batch.to(dev)
                pred_batch = net(img_batch)

                # never ask for more candidates than there are classes
                k = min(topk, pred_batch.size(1))
                _, index_batch = t.topk(pred_batch, k, dim=1)

                # Labels are compared as-is: the training loop feeds the same
                # kind of labels unshifted to the loss, so shifting them here
                # (the former `gnd_batch - 1`) would score the wrong class.
                hits = index_batch.eq(gnd_batch.view(-1, 1))
                correct += hits.sum().item()
                total += gnd_batch.size(0)
    finally:
        # training mode
        net.train(mode=True)

    # an empty loader has no accuracy to report; avoid dividing by zero
    if total == 0:
        return 0.0

    return float(correct) / float(total) * 100.0


# Main training loop: one pass over `train_loader` per epoch, followed by a
# checkpoint, a validation run and a learning-rate schedule step.
for epoch in tqdm(range(opt.max_epoch), desc='epoch', total=opt.max_epoch):
    # start every epoch with a fresh running-average loss
    loss_meter.reset()

    for index, (img_batch, gnd_batch) in enumerate(train_loader):
        # reset gradient
        optimizer.zero_grad()

        # inference
        img_batch = img_batch.to(device)
        gnd_batch = gnd_batch.to(device)
        pred_batch = model(img_batch)

        # compute loss
        loss = criterion(pred_batch, gnd_batch)

        # backpropagation
        loss.backward()
        optimizer.step()

        # add to loss meter for logging
        loss_meter.add(loss.item())
        if (index + 1) % opt.plot_freq == 0:
            # value()[0] is the running mean since the last reset
            mean_loss = loss_meter.value()[0]
            vis.plot('loss', mean_loss)
            vis.log('epoch: {epoch}, loss: {loss:.5f}'.format(
                epoch=epoch, loss=mean_loss))

    # save model
    model.save()

    # validation
    accuracy = validate()

    # Report the learning rate the optimizer actually used this epoch rather
    # than re-deriving it from the config, so the log cannot drift from the
    # schedule.
    current_lr = optimizer.param_groups[0]['lr']
    vis.log('lr: {lr:.5f}, acc@5: {top5:.3f}'.format(
        lr=current_lr, top5=accuracy))

    # Advance the schedule only after this epoch's optimizer updates.
    # Since PyTorch 1.1.0, calling scheduler.step() before optimizer.step()
    # skips the first value of the learning-rate schedule.
    scheduler.step()