In [None]:
import os, glob, platform, datetime, random
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.utils.data as data_utils
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.autograd import Variable
from torch import functional as F
# import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

import cv2
from PIL import Image
from tensorboardX import SummaryWriter

import numpy as np
from numpy.linalg import inv as denseinv
from scipy import sparse
from scipy.sparse import lil_matrix, csr_matrix
from scipy.sparse.linalg import spsolve
from scipy.sparse.linalg import inv as spinv
import scipy.misc

from myimagefoldereccv import MyImageFolder
from mymodel import GradientNet
from myargs import Args
from myutils import MyUtils

# Configurations

In [None]:
myutils = MyUtils()

args = Args()
args.arch = "densenet121"
args.epoches = 500
args.epoches_unary_threshold = 0
args.image_h = 256
args.image_w = 256
args.img_extentions = ["png"]
args.training_thresholds = [250,200,150,50,0,300]
args.base_lr = 1
args.lr = args.base_lr
args.snapshot_interval = 5000
args.debug = True


# Network hyper-parameters (passed to GradientNet in a later cell).
# growth_rate = (4*(2**(args.gpu_num)))
transition_scale = 2
pretrained_scale = 4
growth_rate = 32

#######
# Scene names handed to MyImageFolder as `test_scene` (presumably the held-out
# test split — confirm in MyImageFolder).
args.test_scene = ['alley_1', 'bamboo_1', 'bandage_1', 'cave_2', 'market_2', 'market_6', 'shaman_2', 'sleeping_1', 'temple_2']
gradient = False
args.gpu_num = 0
#######

writer_comment = 'eccv_albedo'


# 0.5 when `gradient` is enabled, otherwise 0.
offset = 0.5 if gradient else 0.

args.display_interval = 50
args.display_curindex = 0

# Per-machine dataset location.
# platform.dist() was removed in Python 3.8; fall back to an empty
# distribution tuple there instead of crashing, and query it only once.
system_ = platform.system()
platform_dist = platform.dist() if hasattr(platform, 'dist') else ('', '', '')
system_dist, system_version, _ = platform_dist
if system_ == "Darwin":
    args.train_dir = '/Volumes/Transcend/dataset/sintel2'
    args.pretrained = False
elif platform_dist == ('debian', 'jessie/sid', ''):
    args.train_dir = '/home/albertxavier/dataset/sintel2'
    args.pretrained = True
elif platform_dist == ('debian', 'stretch/sid', ''):
    args.train_dir = '/home/cad/lwp/workspace/dataset/sintel2'
    args.pretrained = True
else:
    # Unknown host: nothing is assigned here, so make the problem visible now
    # rather than failing later when the datasets are built.
    print('WARNING: unrecognised host {} {}; set args.train_dir and '
          'args.pretrained manually'.format(system_, platform_dist))

# Use the GPU only on Linux hosts where CUDA is actually available;
# torch.cuda.set_device would fail otherwise.
use_gpu = system_ == 'Linux' and torch.cuda.is_available()
if use_gpu:
    torch.cuda.set_device(args.gpu_num)


print(platform_dist)

# My DataLoader

In [None]:


# Options shared by the train and test splits.
dataset_kwargs = dict(img_extentions=args.img_extentions,
                      test_scene=args.test_scene,
                      image_h=args.image_h,
                      image_w=args.image_w)

# Training split: the dataset is asked for random crops; the transform only
# converts to a tensor.
train_dataset = MyImageFolder(args.train_dir, 'train',
                              transforms.Compose([transforms.ToTensor()]),
                              random_crop=True,
                              **dataset_kwargs)

# Test split: centre crop to the configured size, then convert to a tensor.
test_dataset = MyImageFolder(args.train_dir, 'test',
                             transforms.Compose([
                                 transforms.CenterCrop((args.image_h, args.image_w)),
                                 transforms.ToTensor(),
                             ]),
                             random_crop=False,
                             **dataset_kwargs)

# Both loaders: batch size 1, shuffled, one worker process.
train_loader = data_utils.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)
test_loader = data_utils.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=1)

# Load Pretrained Model

[Definition](https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py)
* DenseNet-121: num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16)
    * First convolution (stride-2 conv followed by stride-2 max-pool): 32M -> 16M -> 8M
    * Each transition layer halves the resolution: 8M -> 4M -> 2M (every dense block except the last is followed by a transition)

In [None]:
# Look up the torchvision constructor by name (args.arch, e.g. "densenet121")
# and build the backbone, optionally with pretrained weights.
backbone_factory = models.__dict__[args.arch]
densenet = backbone_factory(pretrained=args.pretrained)

# Freeze every backbone parameter so the requires_grad filter used when
# building the optimizer skips them.
for weight in densenet.parameters():
    weight.requires_grad = False

if use_gpu:
    densenet.cuda()


In [None]:


# Run-specific hyper-parameters; these override values from the config cell.
args.display_curindex = 0
args.base_lr = 0.01
args.display_interval = 20
args.momentum = 0.9
args.epoches = 240
args.training_thresholds = 240//4   # epoch where the two-stage LR schedule switches
args.power = 0.5
# NOTE(review): args.lr still holds the config-cell value (base_lr was 1 there)
# and is not refreshed here — confirm MyUtils.adjust_learning_rate does not read it.

net = GradientNet(densenet=densenet,
                  growth_rate=growth_rate,
                  transition_scale=transition_scale,
                  pretrained_scale=pretrained_scale,
                  gradient=gradient)
if use_gpu:
    net.cuda()


def _make_mse():
    """Return a fresh MSELoss, moved to the GPU when use_gpu is set."""
    criterion = nn.MSELoss()
    if use_gpu:
        criterion = criterion.cuda()
    return criterion


mse_losses = _make_mse()
mse_losses_dx = _make_mse()
mse_losses_dy = _make_mse()

# Only parameters that still require gradients (i.e. not the frozen backbone)
# are optimised.
parameters = (p for p in net.parameters() if p.requires_grad)
optimizer = optim.SGD(parameters, lr=args.base_lr, momentum=args.momentum)

In [None]:
def _loss_to_float(loss):
    """Return a scalar loss (Variable or tensor) as a Python float.

    ``reshape(-1)[0]`` works both for the 1-element tensors of old PyTorch
    releases and for the 0-dim tensors returned since PyTorch 0.4, where plain
    ``numpy()[0]`` indexing raises IndexError.
    """
    return float(loss.data.cpu().numpy().reshape(-1)[0])


def train_eval_model_per_epoch(epoch, net, args, train_loader, test_loader, phase='train'):
    """Run one training or evaluation epoch of the albedo network.

    Uses notebook-level globals defined in earlier cells: ``optimizer``,
    ``writer``, ``myutils``, ``mse_losses`` and ``use_gpu``.

    Args:
        epoch: Zero-based epoch index (LR schedule, tensorboard step, image names).
        net: Network to run. It is kept in train mode in both phases.
        args: Config object. Reads ``training_thresholds``, ``epoches``,
            ``display_interval`` and ``gpu_num``; increments
            ``display_curindex`` once per batch.
        train_loader: Loader iterated when ``phase == 'train'``.
        test_loader: Loader iterated for any other phase.
        phase: ``'train'`` updates the weights; ``'test'`` only evaluates.

    Side effects:
        Adjusts the optimizer learning rate, logs the learning rate and the mean
        loss to tensorboard, writes sample images under ``snapshot<gpu_num>/``
        and, in the train phase, saves a snapshot via ``myutils.save_snapshot``.
    """
    is_train = phase == 'train'
    # volatile=True disables graph construction for inputs on pre-0.4 PyTorch.
    volatile = not is_train
    # net.eval() is deliberately left disabled: the net stays in train mode
    # during evaluation as well (TODO confirm this is still intended).
    net.train()

    print('epoch: {} [{}]'.format(epoch, datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))

    # Two-stage learning-rate schedule split at args.training_thresholds.
    if epoch < args.training_thresholds:
        myutils.adjust_learning_rate(optimizer, args, epoch, beg=0, end=args.training_thresholds-1)
    else:
        myutils.adjust_learning_rate(optimizer, args, epoch, beg=args.training_thresholds, end=args.epoches)
    writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], global_step=epoch)

    # Running sums of per-batch losses and the number of batches seen.
    run_losses_unary = 0.
    run_losses_dx = 0.
    run_losses_dy = 0.
    run_cnts = 0

    loader = train_loader if is_train else test_loader

    for data in loader:
        input_img, gt_albedo, gt_shading, cur_scene, img_path = data
        # The default collate function wraps string fields of a batch in a list
        # (hence img_path[0]); unwrap the scene name the same way so the
        # comparison below can match.
        scene_name = cur_scene[0] if isinstance(cur_scene, (list, tuple)) else cur_scene
        cur_frame = img_path[0].split('/')[-1]

        input_img = Variable(input_img, volatile=volatile)
        gt_albedo = Variable(gt_albedo)
        # gt_shading is not used by any of the albedo losses below.
        if use_gpu:
            input_img, gt_albedo = input_img.cuda(), gt_albedo.cuda()

        if is_train:
            optimizer.zero_grad()

        res = net(input_img)

        # x/y gradients of the prediction and the ground truth.
        gt_dx = myutils.makeGradientTorch(gt_albedo, direction='x')
        gt_dy = myutils.makeGradientTorch(gt_albedo, direction='y')
        res_dx = myutils.makeGradientTorch(res, direction='x')
        res_dy = myutils.makeGradientTorch(res, direction='y')

        # Losses are computed in BOTH phases so the reported test loss is real
        # (computing them only in training made the test loss always 0).
        # Scaling both sides by 2 weights the gradient MSE terms by 4.
        loss_unary = mse_losses(res, gt_albedo)
        loss_dx = mse_losses(2*res_dx, 2*gt_dx)
        loss_dy = mse_losses(2*res_dy, 2*gt_dy)

        run_losses_unary += _loss_to_float(loss_unary)
        run_losses_dx += _loss_to_float(loss_dx)
        run_losses_dy += _loss_to_float(loss_dy)
        run_cnts += 1

        if is_train:
            # One backward pass on the summed loss gives the same gradients as
            # three separate backward calls, without retaining the graph.
            (loss_unary + loss_dx + loss_dy).backward()
            optimizer.step()

        is_display_step = (is_train and args.display_curindex % args.display_interval == 0) or \
            (phase == 'test' and scene_name == 'alley_1' and cur_frame == 'frame_0001.png')
        if is_display_step:
            # Copy tensors to numpy only when an image is actually written.
            # The last axis is reversed (presumably RGB -> BGR for cv2.imwrite).
            display_im = myutils.tensor2Numpy(input_img)[:,:,::-1]*255
            display_gt = myutils.tensor2Numpy(gt_albedo)[:,:,::-1]*255
            display_res = myutils.tensor2Numpy(res)[:,:,::-1]*255
            cv2.imwrite('snapshot{}/input.png'.format(args.gpu_num), display_im)
            cv2.imwrite('snapshot{}/{}-gt-{}.png'.format(args.gpu_num, phase, epoch), display_gt)
            cv2.imwrite('snapshot{}/{}-rs-{}.png'.format(args.gpu_num, phase, epoch), display_res)

        args.display_curindex += 1

    # Mean losses over the epoch; max(..., 1) guards against an empty loader.
    n_batches = max(run_cnts, 1)
    mean_unary = run_losses_unary / n_batches
    mean_pairwise = (run_losses_dx + run_losses_dy) / n_batches
    mean_total = (run_losses_unary + run_losses_dx + run_losses_dy) / n_batches

    loss_output = '{} loss: '.format(phase)
    loss_output += 'unary: %6f ' % mean_unary
    loss_output += 'pairwise: %6f ' % mean_pairwise
    loss_output += 'crf: %6f' % mean_total
    print(loss_output)

    writer.add_scalars('loss', {phase: np.array([mean_total])}, global_step=epoch)

    if is_train:
        myutils.save_snapshot(epoch, args, net, optimizer)
    
    

In [None]:
"""training loop"""
writer = SummaryWriter(comment='-{}'.format(writer_comment))

# Every 5th epoch (when epoch+1 is divisible by 5) runs an evaluation pass on
# the test split instead of a training pass.
try:
    for epoch in range(args.epoches):
        phase = 'test' if (epoch + 1) % 5 == 0 else 'train'
        train_eval_model_per_epoch(epoch, net, args, train_loader, test_loader, phase=phase)
finally:
    # Flush pending tensorboard events and release the event file, also when
    # training is interrupted.
    writer.close()

# Visualize Graph

In [None]:
# Optional: build the network graph for visualisation.
# NOTE: `make_dot` is not imported anywhere in this notebook (presumably
# torchviz's make_dot) — add that import before uncommenting. The `.cuda()`
# call also assumes a GPU is available.
# x = Variable(torch.zeros(1,3,256,256))
# y = net(x.cuda())
# g = make_dot(y[-1])


In [None]:
# Optional: render the graph `g` built in the commented-out cell above
# (uncomment that cell first).
# g.render('net-transition_scale_{}'.format(transition_scale)) 