In [None]:
%matplotlib inline

import os

# Everything below resolves paths such as "portraitseg/..." and "../data/"
# relative to the directory that contains the `portraitseg` package. When the
# notebook is launched from the test directory two levels below it, move up.
# The guard keeps the cell idempotent: re-running it in the same kernel would
# otherwise keep climbing the directory tree and break those relative paths.
if not os.path.isdir("portraitseg"):
    os.chdir("../../")  # If ran from test directory.

import mkl
# Pin MKL to the maximum number of threads it reports (e.g. 12).
nproc = mkl.get_max_threads()  # e.g. 12
mkl.set_num_threads(nproc)

import random
from time import time

import numpy as np
from PIL import Image
import torch
from torch import optim, nn
from torch.autograd import Variable
import torch.nn.functional as F

from portraitseg.pytorch_dataloaders import get_train_valid_loader
from portraitseg.utils import (plots,
                               set_seed,
                               show_portrait_pred_mask,
                               scoretensor2mask)
from portraitseg.portraitfcn import PortraitFCN

###############################################################################


def cross_entropy2d(input, target, weight=None, size_average=True):
    """Pixel-wise cross-entropy loss for dense (segmentation) predictions.

    Args:
        input: Raw, unnormalised class scores of shape ``(n, c, h, w)``.
        target: Integer class labels with ``n * h * w`` elements
            (presumably shaped ``(n, h, w)`` -- confirm against the loader).
            Pixels whose label is negative are ignored.
        weight: Optional per-class weight tensor forwarded to ``F.nll_loss``.
        size_average: If True, divide the summed loss by the number of
            non-ignored pixels; otherwise return the plain sum.

    Returns:
        A scalar loss tensor on the same device as ``input``. If every
        pixel is ignored, the loss is 0 rather than NaN.
    """
    n, c, h, w = input.size()
    # Class scores live on dim 1. Pass it explicitly instead of relying on
    # the deprecated implicit-dim behaviour of log_softmax.
    log_p = F.log_softmax(input, dim=1)
    # (n, c, h, w) -> (n*h*w, c): one row of class log-probabilities per pixel.
    log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    # Keep only the rows of pixels with a non-negative label. There is no
    # .cuda() here: indexing keeps log_p on input's device, so CPU works too.
    valid_rows = target.view(n * h * w, 1).repeat(1, c) >= 0
    log_p = log_p[valid_rows].view(-1, c)
    mask = target >= 0
    target = target[mask]
    loss = F.nll_loss(log_p, target, weight=weight, size_average=False)
    if size_average:
        # float() accepts both a Python number (old PyTorch) and a 0-dim
        # tensor. Skipping the division when nothing is valid avoids 0/0 = NaN.
        num_valid = float(mask.data.sum())
        if num_valid > 0:
            loss = loss / num_valid
    return loss

###############################################################################

# Seed the random generators through the project's set_seed helper
# (presumably Python/NumPy/torch -- confirm in portraitseg.utils).
SEED = 3
set_seed(SEED)

###############################################################################

# Dataset locations, relative to the current working directory.
DATA_DIR = "../data/"
FLICKR_DIR = os.path.join(DATA_DIR, "portraits/flickr/")

###############################################################################

# Build PortraitFCN, move its parameters to the GPU (Module.cuda() works in
# place), then restore the weights stored at path_to_weights.
# NOTE(review): an earlier note called these "pretrained FCN8s" weights, but
# the file is named *_untrained.pth -- confirm which it is.
path_to_weights = "portraitseg/portraitfcn_untrained.pth"
model = PortraitFCN()
model.cuda()
model.load_state_dict(torch.load(path_to_weights))

###############################################################################

# Hyperparameters. LR is tiny, presumably because the per-pixel loss is
# summed rather than averaged (size_average=False) during training.
BATCH_SIZE = 1
AUGMENT = False
LR = 1e-10
NB_EPOCHS = 200
VALID_SIZE = 0.2  # assumed: fraction of FLICKR_DIR held out for val_loader

loss_fn = cross_entropy2d
optimizer = optim.SGD(model.parameters(), lr=LR)

# Train/validation DataLoaders over the Flickr portrait dataset.
trn_loader, val_loader = get_train_valid_loader(
    FLICKR_DIR,
    batch_size=BATCH_SIZE,
    augment=AUGMENT,
    random_seed=SEED,
    valid_size=VALID_SIZE,
    show_sample=False,
    num_workers=6,
    pin_memory=True,
)

# Pull a single training batch and keep it on the GPU for the whole run.
portraits, masks = next(iter(trn_loader))
portraits = Variable(portraits).cuda()
masks = Variable(masks).cuda()

# Detached CPU copies of the first sample in the batch.
portrait = portraits[0].data.clone().cpu()
mask = masks[0].data.clone().cpu()

# Train: repeatedly fit the single batch loaded above, to answer
# "Is the network powerful enough to at least memorize?"
# Roughly 42 seconds per 100 epochs on one sample.
preds = []
start = time()
for epoch in range(NB_EPOCHS):
    model.train()
    outputs = model(portraits)
    loss = loss_fn(outputs, masks, size_average=False)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % 100 == 99:
        # Every 100 epochs, re-score the SAME training batch in eval mode.
        # This is not a validation-set measurement: val_loader is not used.
        model.eval()
        outputs = model(portraits)
        loss = loss_fn(outputs, masks, size_average=False)
        # view(-1)[0] reads the scalar whether loss.data is a 1-element tensor
        # (old PyTorch) or 0-dim (newer releases, where data[0] raises).
        print("Training-sample loss (eval mode): %.3f"
              % float(loss.data.view(-1)[0]))
        print("Duration (1 sample, 100 epochs): %.2f seconds" % (time() - start))
        # Record the current prediction for the first sample and plot all
        # predictions collected so far next to the portrait and its mask.
        scoretensor = outputs[0].data.cpu()
        pred = scoretensor2mask(scoretensor)
        preds.append(pred)
        show_portrait_pred_mask(portrait, preds, mask)
        start = time()

print("Training complete.")
