In [1]:
import os, glob, platform, datetime, random
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.utils.data as data_utils
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.autograd import Variable
from torch import functional as F
# import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

import cv2
from PIL import Image
from tensorboardX import SummaryWriter

import numpy as np
from numpy.linalg import inv as denseinv
from scipy import sparse
from scipy.sparse import lil_matrix, csr_matrix
from scipy.sparse.linalg import spsolve
from scipy.sparse.linalg import inv as spinv
import scipy.misc

from myimagefolder import MyImageFolder
from mymodel import GradientNet
from myargs import Args
from myutils import MyUtils

# Configurations

In [2]:
# Shared helper object and the mutable run configuration read by every later cell.
myutils = MyUtils()

args = Args()
args.arch = "densenet121"
# NOTE: epoches, training_thresholds, base_lr and display_interval are
# overwritten again by the schedule cell further down before training starts.
args.epoches = 500
args.epoches_unary_threshold = 0
args.image_h = 256
args.image_w = 256
args.img_extentions = ["png"]
args.training_thresholds = [250,200,150,50,0,300]
args.base_lr = 1
args.lr = args.base_lr
args.snapshot_interval = 5000
args.debug = True

# Keyword arguments forwarded to GradientNet(...) when the network is built.
transition_scale = 2
pretrained_scale = 4
growth_rate = 32

#######
# Per-run switches.
# args.test_scene = ['alley_2', 'bamboo_2', 'bandage_2', 'cave_4', 'market_5', 'mountain_1', 'shaman_3', 'sleeping_2', 'temple_3']
args.test_scene = 'bandage_2'
gradient = False
args.gpu_num = 0
#######

# The TensorBoard run suffix and the offset added to predictions before they
# are written as images both depend on the `gradient` switch.
if gradient:
    writer_comment = '{}_gd'.format(args.test_scene)
    offset = 0.5
else:
    writer_comment = '{}_rgb'.format(args.test_scene)
    offset = 0.

args.display_interval = 50
args.display_curindex = 0

system_ = platform.system()
# platform.dist() was deprecated in Python 3.5 and removed in 3.8. Query it
# once, and fall back to an empty triple when it is unavailable so the cell
# still runs (neither Debian branch below will match in that case).
_dist_fn = getattr(platform, 'dist', None)
linux_dist = tuple(_dist_fn()) if _dist_fn is not None else ('', '', '')
system_dist, system_version = linux_dist[:2]

if system_ == "Darwin":
    args.train_dir = '/Volumes/Transcend/dataset/sintel2'
    args.pretrained = False
elif linux_dist == ('debian', 'jessie/sid', ''):
    args.train_dir = '/home/lwp/workspace/sintel2'
    args.pretrained = True
elif linux_dist == ('debian', 'stretch/sid', ''):
    args.train_dir = '/home/cad/lwp/workspace/dataset/sintel2'
    args.pretrained = True
else:
    # No dataset path is configured for this host; make that visible here
    # instead of failing later with a less obvious error.
    print('WARNING: no dataset path configured for platform {!r} / {!r}; '
          'args.train_dir and args.pretrained keep whatever Args() provides'
          .format(system_, linux_dist))

# Use the GPU only on Linux hosts that actually expose a CUDA device.
use_gpu = (system_ == 'Linux') and torch.cuda.is_available()
if use_gpu:
    torch.cuda.set_device(args.gpu_num)

print(linux_dist)

('debian', 'jessie/sid', '')


# My DataLoader

In [3]:


# Training split: samples are only converted to tensors here; cropping is
# requested from MyImageFolder itself through random_crop=True.
train_dataset = MyImageFolder(
    args.train_dir, 'train',
    transforms.Compose([transforms.ToTensor()]),
    random_crop=True,
    img_extentions=args.img_extentions,
    test_scene=args.test_scene,
    image_h=args.image_h,
    image_w=args.image_w,
)

# Test split: deterministic centre crop to the configured size.
# NOTE(review): presumably test_scene names the held-out scene - confirm in myimagefolder.
test_dataset = MyImageFolder(
    args.train_dir, 'test',
    transforms.Compose([
        transforms.CenterCrop((args.image_h, args.image_w)),
        transforms.ToTensor(),
    ]),
    random_crop=False,
    img_extentions=args.img_extentions,
    test_scene=args.test_scene,
    image_h=args.image_h,
    image_w=args.image_w,
)

train_loader = data_utils.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)
# The evaluation loader is not shuffled: test_model() dumps the prediction for
# the first batch (ind == 0), and a fixed order keeps that snapshot on the
# same image so it can be compared across epochs. Averaged losses do not
# depend on the order.
test_loader = data_utils.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1)

# Load Pretrained Model

[Definition](https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py)
* DenseNet-121: num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16)
    * First Convolution: 32M -> 16M -> 8M
    * every transition: 8M -> 4M -> 2M (downsample 1/2, except the last block)

In [4]:
# Look up the torchvision constructor by architecture name and build the
# backbone, with or without pretrained weights as configured above.
backbone_factory = models.__dict__[args.arch]
densenet = backbone_factory(pretrained=args.pretrained)

# Freeze every backbone weight: none of them should receive gradients.
for frozen_param in densenet.parameters():
    frozen_param.requires_grad = False

if use_gpu:
    densenet.cuda()


In [5]:
# --- Training schedule -------------------------------------------------------
ss = 6          # length (in epochs) of one learning-rate cycle in the merge phase
s0 = ss*5       # first epoch at which the merge branch is switched on

args.epoches = 240
args.base_lr = 0.05
args.momentum = 0.9
args.power = 0.5
args.display_interval = 20
args.display_curindex = 0

# Epoch from which each of the six outputs starts being processed.
args.training_thresholds = [0,0,0,0,0,s0]
# Epoch from which each merge stage applies; the training loop maps
# index 5 -> '32M', 0 -> '16M', 1 -> '08M', 2 -> '04M', 3 -> '02M'.
args.training_merge_thresholds = [s0 + ss*3*k for k in (3, 2, 1)] + [s0, -1, s0 + ss*3*4]

# --- Network -------------------------------------------------------------------
net = GradientNet(
    densenet=densenet,
    growth_rate=growth_rate,
    transition_scale=transition_scale,
    pretrained_scale=pretrained_scale,
    gradient=gradient,
)
if use_gpu:
    net.cuda()

# --- Loss criteria -------------------------------------------------------------
# Each list holds one MSELoss instance referenced six times (one slot per output).
_new_mse = (lambda: nn.MSELoss().cuda()) if use_gpu else (lambda: nn.MSELoss())
mse_losses = [_new_mse()] * 6
test_losses = [_new_mse()] * 6
mse_merge_losses = [_new_mse()] * 6
test_merge_losses = [_new_mse()] * 6

_ ConvTranspose2d weight 0.002867696673382022
_ ConvTranspose2d weight 0.002867696673382022
_ ConvTranspose2d weight 0.003031695312954162
_ ConvTranspose2d weight 0.003031695312954162
_ ConvTranspose2d weight 0.004419417382415922


In [6]:
def test_model(epoch, go_through_merge=False, phase='train'):
    """Run ``net`` over ``test_loader`` and log the averaged MSE losses.

    Relies on notebook globals: ``net``, ``test_loader``, ``args``,
    ``myutils``, ``gradient``, ``offset``, ``use_gpu``, ``mse_losses``,
    ``mse_merge_losses`` and ``writer``.

    Args:
        epoch: current epoch. Outputs whose ``args.training_thresholds``
            entry has not been reached are skipped; the value is also used
            in snapshot file names and as the TensorBoard ``global_step``.
        go_through_merge: ``False`` or a merge stage name ('02M', '04M',
            '08M', '16M', '32M'); forwarded to ``net`` and used to decide
            which merged outputs are scored.
        phase: ``'train'`` puts the network in train mode, anything else in
            eval mode. Both variants iterate ``test_loader``; the value is
            also embedded in scalar names and snapshot file names.
    """
    if phase == 'train':
        net.train()
    else:
        net.eval()

    # Output indices whose merged prediction is scored at each merge stage.
    # Index 4 is never scored through the merge branch and index 5 only in
    # the '32M' stage; any other value (including False) scores none.
    merge_levels_by_stage = {
        '32M': (0, 1, 2, 3, 5),
        '16M': (0, 1, 2, 3),
        '08M': (1, 2, 3),
        '04M': (2, 3),
        '02M': (3,),
    }
    active_merge = merge_levels_by_stage.get(go_through_merge, ())
    level_tags = ['16M', '8M', '4M', '2M', '1M', 'merged']

    # Make sure the snapshot directory exists before images are dumped into it.
    snapshot_dir = 'snapshot{}'.format(args.gpu_num)
    if not os.path.isdir(snapshot_dir):
        os.makedirs(snapshot_dir)

    def loss_value(loss):
        # Works for both a 1-element loss tensor (older PyTorch) and a
        # 0-dim one (PyTorch >= 0.4), where plain [0] indexing fails.
        return loss.data.cpu().numpy().reshape(-1)[0]

    def to_bgr_image(prediction):
        # CHW prediction -> HWC, first three channels, channel order
        # reversed for cv2, shifted by `offset` and scaled to 0..255.
        arr = prediction.cpu().data.numpy().transpose(1, 2, 0)[:, :, 0:3]
        return (arr[:, :, ::-1] + offset) * 255

    n_levels = len(args.training_thresholds)
    # Counts start slightly above zero so the averages below never divide by zero.
    plain_sums = [0] * n_levels
    plain_cnts = [0.00001] * n_levels
    merge_sums = [0] * n_levels
    merge_cnts = [0.00001] * n_levels

    for ind, data in enumerate(test_loader, 0):
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        if use_gpu:
            input_img = input_img.cuda(args.gpu_num)

        ft_test, merged_RGB = net(input_img, go_through_merge=go_through_merge)

        for i, v in enumerate(ft_test):
            if epoch < args.training_thresholds[i]:
                continue

            # Ground-truth scale factor for this output; index 5 uses factor 1.
            s = 1 if i == 5 else 2 ** (i + 1)
            gt0 = gt_albedo.cpu().data.numpy()
            gt, display = myutils.processGt(gt0, scale_factor=s, gd=gradient, return_image=True)
            gt_mg, display_mg = myutils.processGt(gt0, scale_factor=s//2, gd=gradient, return_image=True)
            if use_gpu:
                gt = gt.cuda()
                gt_mg = gt_mg.cuda()

            if i != 5:
                loss = mse_losses[i](ft_test[i], gt)
                plain_sums[i] += loss_value(loss)
                plain_cnts[i] += 1

            if i in active_merge:
                # The merged output at index 5 is compared with `gt`,
                # every other merged output with `gt_mg`.
                gt2 = gt if i == 5 else gt_mg
                loss = mse_merge_losses[i](merged_RGB[i], gt2)
                merge_sums[i] += loss_value(loss)
                merge_cnts[i] += 1

            # Dump predictions for the first batch only.
            if ind == 0:
                if i != 5:
                    cv2.imwrite('snapshot{}/test-phase_{}-{}-{}.png'.format(args.gpu_num, phase, epoch, i),
                                to_bgr_image(v[0]))
                if i in active_merge:
                    cv2.imwrite('snapshot{}/test-mg-phase_{}-{}-{}.png'.format(args.gpu_num, phase, epoch, i),
                                to_bgr_image(merged_RGB[i][0]))

    # Every scalar name carries the phase so that train-mode and eval-mode
    # passes never write to the same curve (the 2M entries used to omit it).
    for prefix, sums, cnts in (('test', plain_sums, plain_cnts),
                               ('mg test', merge_sums, merge_cnts)):
        for idx, tag in enumerate(level_tags):
            writer.add_scalars(
                '{} loss'.format(tag),
                {'{} {} phase {}'.format(prefix, tag, phase): np.array([sums[idx] / cnts[idx]])},
                global_step=epoch)

In [7]:
# training loop

# TensorBoard writer; the run name is suffixed with the scene / target tag.
writer = SummaryWriter(comment='-{}'.format(writer_comment))

# Hand SGD only the parameters that still have requires_grad set.
parameters = (weight for weight in net.parameters() if weight.requires_grad)
optimizer = optim.SGD(parameters, lr=args.base_lr, momentum=args.momentum)

def adjust_learning_rate(optimizer, epoch, beg, end, reset_lr=None, base_lr=args.base_lr):
    """Set the learning rate of every parameter group of ``optimizer``.

    When ``reset_lr`` is given it is applied as-is and the other schedule
    arguments are ignored. Otherwise the rate follows a polynomial decay

        lr = base_lr * ((end - epoch) / (end - beg)) ** args.power

    and is never allowed to drop below 1e-8.

    Args:
        optimizer: optimizer whose ``param_groups`` are updated in place.
        epoch: current position inside the ``[beg, end]`` decay window.
        beg: position at which the rate equals ``base_lr``.
        end: position at which the rate reaches the floor.
        reset_lr: optional fixed rate that overrides the schedule.
        base_lr: rate at the start of the decay window.
    """
    if reset_lr is not None:
        new_lr = reset_lr
    else:
        span = end - beg
        # An empty window is treated as fully decayed instead of dividing by zero.
        remaining = float(end - epoch) / span if span != 0 else 0.0
        # Past the end of the window the ratio turns negative; raising a
        # negative number to a fractional power yields a complex value on
        # Python 3, so clamp it to zero first.
        remaining = max(remaining, 0.0)
        new_lr = base_lr * remaining ** (args.power)
        if new_lr < 1.0e-8:
            new_lr = 1.0e-8

    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr
        

# Output indices whose merged prediction is supervised at each merge stage.
# Index 4 never receives a merge loss and index 5 only in the '32M' stage;
# any other stage value (including False) supervises none.
MERGE_LEVELS_BY_STAGE = {
    '32M': (0, 1, 2, 3, 5),
    '16M': (0, 1, 2, 3),
    '08M': (1, 2, 3),
    '04M': (2, 3),
    '02M': (3,),
}
LEVEL_TAGS = ['16M', '8M', '4M', '2M', '1M', 'merged']

# Create the snapshot directory up front so image dumps and checkpoints
# have somewhere to go.
snapshot_dir = 'snapshot{}'.format(args.gpu_num)
if not os.path.isdir(snapshot_dir):
    os.makedirs(snapshot_dir)

for epoch in range(args.epoches):
    net.train()
    print('epoch: {} [{}]'.format(epoch, datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))

    # Learning-rate schedule: one decay before the merge branch is enabled,
    # a restart every `ss` epochs while merge stages are being added, and a
    # final decay over the remaining epochs.
    if epoch < args.training_thresholds[-1]:
        adjust_learning_rate(optimizer, epoch, beg=0, end=s0-1)
    elif epoch < args.training_merge_thresholds[-1]:
        adjust_learning_rate(optimizer, (epoch-s0)%(ss), beg=0, end=ss-1, base_lr=args.base_lr)
    else:
        adjust_learning_rate(optimizer, epoch, beg=args.training_merge_thresholds[-1], end=args.epoches-1, base_lr=args.base_lr)

    # Pick the merge stage for this epoch (checked from the latest stage down).
    if epoch < args.training_thresholds[-1]: go_through_merge = False
    elif epoch >= args.training_merge_thresholds[5]: go_through_merge = '32M'
    elif epoch >= args.training_merge_thresholds[0]: go_through_merge = '16M'
    elif epoch >= args.training_merge_thresholds[1]: go_through_merge = '08M'
    elif epoch >= args.training_merge_thresholds[2]: go_through_merge = '04M'
    elif epoch >= args.training_merge_thresholds[3]: go_through_merge = '02M'
    active_merge = MERGE_LEVELS_BY_STAGE.get(go_through_merge, ())

    # Counts start slightly above zero so the averages never divide by zero.
    run_losses = [0] * len(args.training_thresholds)
    run_cnts   = [0.00001] * len(args.training_thresholds)
    run_merge_losses = [0] * len(args.training_thresholds)
    run_merge_cnts   = [0.00001] * len(args.training_thresholds)

    # Restart from the base rate whenever a new output or merge stage begins.
    if epoch in args.training_thresholds or epoch in args.training_merge_thresholds:
        adjust_learning_rate(optimizer, epoch, reset_lr=args.base_lr, beg=-1, end=-1)

    writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], global_step=epoch)

    for ind, data in enumerate(train_loader, 0):
        # --- prepare training data ---
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        # Images are dumped every `display_interval` iterations; the counter
        # only advances at the end of the iteration, so decide once here.
        save_images = (args.display_curindex % args.display_interval == 0)
        if save_images:
            im = input_img[0,:,:,:].numpy().transpose(1,2,0)
            im = im[:,:,::-1]*255
            cv2.imwrite('snapshot{}/input.png'.format(args.gpu_num), im)

        input_img, gt_albedo, gt_shading = Variable(input_img), Variable(gt_albedo), Variable(gt_shading)
        if use_gpu:
            input_img, gt_albedo, gt_shading = input_img.cuda(), gt_albedo.cuda(), gt_shading.cuda()

        optimizer.zero_grad()
        ft_predict, merged_RGB = net(input_img, go_through_merge=go_through_merge)

        for i, threshold in enumerate(args.training_thresholds):
            if epoch < threshold:
                continue

            # --- prepare resized ground truth ---
            # Scale factor for this output; index 5 uses factor 1.
            s = 1 if i == 5 else 2 ** (i + 1)
            gt0 = gt_albedo.cpu().data.numpy()
            gt, display = myutils.processGt(gt0, scale_factor=s, gd=gradient, return_image=True)
            gt_mg, display_mg = myutils.processGt(gt0, scale_factor=s//2, gd=gradient, return_image=True)
            if use_gpu:
                gt = gt.cuda()
                gt_mg = gt_mg.cuda()
            if save_images:
                display = display[:,:,0:3]
                cv2.imwrite('snapshot{}/gt-{}-{}.png'.format(args.gpu_num, epoch, i), display[:,:,::-1]*255)

            # --- compute losses ---
            # reshape(-1)[0] reads the scalar from both a 1-element loss
            # tensor (older PyTorch) and a 0-dim one (PyTorch >= 0.4), where
            # plain [0] indexing fails.
            if i != 5:
                loss = mse_losses[i](ft_predict[i], gt)
                run_losses[i] += loss.data.cpu().numpy().reshape(-1)[0]
                # The graph is reused by the remaining losses of this batch.
                loss.backward(retain_graph=True)
                run_cnts[i] += 1

            if i in active_merge:
                # The merged output at index 5 is compared with `gt`,
                # every other merged output with `gt_mg`.
                gt2 = gt if i == 5 else gt_mg
                loss = mse_merge_losses[i](merged_RGB[i], gt2)
                run_merge_losses[i] += loss.data.cpu().numpy().reshape(-1)[0]
                loss.backward(retain_graph=True)
                run_merge_cnts[i] += 1

            # --- save training images ---
            if save_images:
                if i != 5:
                    im = (ft_predict[i].cpu().data.numpy()[0].transpose((1,2,0))+offset) * 255
                    im = im[:,:,0:3]
                    cv2.imwrite('snapshot{}/train-{}-{}.png'.format(args.gpu_num, epoch, i), im[:,:,::-1])
                if i in active_merge:
                    im = (merged_RGB[i].cpu().data.numpy()[0].transpose((1,2,0))+offset) * 255
                    im = im[:,:,0:3]
                    cv2.imwrite('snapshot{}/train-mg-{}-{}.png'.format(args.gpu_num, epoch, i), im[:,:,::-1])

        # Gradients of all losses of this batch have been accumulated; apply them.
        optimizer.step()
        args.display_curindex += 1

    # --- per-epoch summary ---
    last = len(run_losses) - 1
    loss_output = ''
    for i in range(len(run_losses)):
        mean_loss = run_losses[i] / run_cnts[i]
        if i == last:
            loss_output += ' merged: %6f' % mean_loss
        else:
            loss_output += ' %2dM: %6f' % ((2**(4-i)), mean_loss)
    print(loss_output)

    last = len(run_merge_losses) - 1
    loss_output = ''
    for i in range(len(run_merge_losses)):
        mean_loss = run_merge_losses[i] / run_merge_cnts[i]
        if i == last:
            loss_output += 'mg merged: %6f' % mean_loss
        else:
            loss_output += ' mg %2dM: %6f' % ((2**(4-i)), mean_loss)
    print(loss_output)

    # --- checkpoint every 10 epochs ---
    if (epoch+1) % 10 == 0:
        torch.save({
            'epoch': epoch,
            'args' : args,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }, 'snapshot{}/snapshot-{}.pth.tar'.format(args.gpu_num, epoch))

    # --- evaluate and log every 5 epochs ---
    if (epoch+1) % 5 == 0:
        test_model(epoch, phase='train', go_through_merge=go_through_merge)
        test_model(epoch, phase='test', go_through_merge=go_through_merge)

        # NOTE(review): run_losses[5] is never accumulated above (the plain
        # loss is skipped for i == 5), so 'train merged ' always logs 0;
        # the merged-branch averages live in run_merge_losses.
        for idx, tag in enumerate(LEVEL_TAGS):
            writer.add_scalars('{} loss'.format(tag),
                               {'train {} '.format(tag): np.array([run_losses[idx] / run_cnts[idx]])},
                               global_step=epoch)


epoch: 0 [2017-12-20 13:10:34]
 16M: 0.047860  8M: 0.046290  4M: 0.052096  2M: 0.049988  1M: 0.046738 merged: 0.000000
 mg 16M: 0.000000 mg  8M: 0.000000 mg  4M: 0.000000 mg  2M: 0.000000 mg  1M: 0.000000mg merged: 0.000000
epoch: 1 [2017-12-20 13:11:53]
 16M: 0.035208  8M: 0.036065  4M: 0.036911  2M: 0.038711  1M: 0.038810 merged: 0.000000
 mg 16M: 0.000000 mg  8M: 0.000000 mg  4M: 0.000000 mg  2M: 0.000000 mg  1M: 0.000000mg merged: 0.000000
epoch: 2 [2017-12-20 13:13:11]
 16M: 0.035811  8M: 0.026222  4M: 0.029935  2M: 0.026393  1M: 0.027452 merged: 0.000000
 mg 16M: 0.000000 mg  8M: 0.000000 mg  4M: 0.000000 mg  2M: 0.000000 mg  1M: 0.000000mg merged: 0.000000
epoch: 3 [2017-12-20 13:14:29]
 16M: 0.025005  8M: 0.020065  4M: 0.023321  2M: 0.020171  1M: 0.021057 merged: 0.000000
 mg 16M: 0.000000 mg  8M: 0.000000 mg  4M: 0.000000 mg  2M: 0.000000 mg  1M: 0.000000mg merged: 0.000000
epoch: 4 [2017-12-20 13:15:46]
 16M: 0.021601  8M: 0.017884  4M: 0.021716  2M: 0.015772  1M: 0.016284 me

 16M: 0.007830  8M: 0.007170  4M: 0.006077  2M: 0.005046  1M: 0.005122 merged: 0.000000
 mg 16M: 0.000000 mg  8M: 0.000000 mg  4M: 0.000000 mg  2M: 0.009380 mg  1M: 0.000000mg merged: 0.000000
epoch: 38 [2017-12-20 14:02:21]
 16M: 0.007393  8M: 0.006496  4M: 0.005574  2M: 0.004648  1M: 0.004746 merged: 0.000000
 mg 16M: 0.000000 mg  8M: 0.000000 mg  4M: 0.000000 mg  2M: 0.007945 mg  1M: 0.000000mg merged: 0.000000
epoch: 39 [2017-12-20 14:03:53]
 16M: 0.007124  8M: 0.005906  4M: 0.004961  2M: 0.003965  1M: 0.004195 merged: 0.000000
 mg 16M: 0.000000 mg  8M: 0.000000 mg  4M: 0.000000 mg  2M: 0.006841 mg  1M: 0.000000mg merged: 0.000000
epoch: 40 [2017-12-20 14:05:32]
 16M: 0.006373  8M: 0.005445  4M: 0.004478  2M: 0.003647  1M: 0.003720 merged: 0.000000
 mg 16M: 0.000000 mg  8M: 0.000000 mg  4M: 0.000000 mg  2M: 0.006194 mg  1M: 0.000000mg merged: 0.000000
epoch: 41 [2017-12-20 14:07:05]
 16M: 0.006545  8M: 0.005604  4M: 0.004558  2M: 0.003788  1M: 0.003596 merged: 0.000000
 mg 16M: 0.0

 16M: 0.005465  8M: 0.004871  4M: 0.003870  2M: 0.003400  1M: 0.003512 merged: 0.000000
 mg 16M: 0.000000 mg  8M: 0.005961 mg  4M: 0.004905 mg  2M: 0.004989 mg  1M: 0.000000mg merged: 0.000000
epoch: 75 [2017-12-20 15:14:46]
 16M: 0.005440  8M: 0.004633  4M: 0.003695  2M: 0.002943  1M: 0.003319 merged: 0.000000
 mg 16M: 0.000000 mg  8M: 0.005381 mg  4M: 0.004496 mg  2M: 0.004564 mg  1M: 0.000000mg merged: 0.000000
epoch: 76 [2017-12-20 15:17:26]
 16M: 0.005002  8M: 0.004212  4M: 0.003360  2M: 0.002723  1M: 0.002931 merged: 0.000000
 mg 16M: 0.000000 mg  8M: 0.004794 mg  4M: 0.004110 mg  2M: 0.004210 mg  1M: 0.000000mg merged: 0.000000
epoch: 77 [2017-12-20 15:19:58]
 16M: 0.004874  8M: 0.004126  4M: 0.003507  2M: 0.002487  1M: 0.002709 merged: 0.000000
 mg 16M: 0.000000 mg  8M: 0.004689 mg  4M: 0.004017 mg  2M: 0.004136 mg  1M: 0.000000mg merged: 0.000000
epoch: 78 [2017-12-20 15:22:36]
 16M: 0.005930  8M: 0.005122  4M: 0.004214  2M: 0.003249  1M: 0.003396 merged: 0.000000
 mg 16M: 0.0

 16M: 0.005360  8M: 0.004250  4M: 0.003569  2M: 0.002733  1M: 0.003095 merged: 0.000000
 mg 16M: 0.004472 mg  8M: 0.004006 mg  4M: 0.003722 mg  2M: 0.004068 mg  1M: 0.000000mg merged: 0.004472
epoch: 112 [2017-12-20 17:39:03]
 16M: 0.005133  8M: 0.004165  4M: 0.003475  2M: 0.002639  1M: 0.002986 merged: 0.000000
 mg 16M: 0.004278 mg  8M: 0.003895 mg  4M: 0.003582 mg  2M: 0.003889 mg  1M: 0.000000mg merged: 0.004278
epoch: 113 [2017-12-20 17:44:07]
 16M: 0.005456  8M: 0.004404  4M: 0.003596  2M: 0.003254  1M: 0.003119 merged: 0.000000
 mg 16M: 0.004489 mg  8M: 0.004052 mg  4M: 0.003829 mg  2M: 0.004360 mg  1M: 0.000000mg merged: 0.004489
epoch: 114 [2017-12-20 17:49:06]
 16M: 0.004939  8M: 0.003974  4M: 0.003112  2M: 0.002552  1M: 0.002898 merged: 0.000000
 mg 16M: 0.004069 mg  8M: 0.003703 mg  4M: 0.003446 mg  2M: 0.003902 mg  1M: 0.000000mg merged: 0.004069
epoch: 115 [2017-12-20 17:54:12]
 16M: 0.004965  8M: 0.003964  4M: 0.003150  2M: 0.002607  1M: 0.002916 merged: 0.000000
 mg 16M:

 16M: 0.003920  8M: 0.003058  4M: 0.002403  2M: 0.001912  1M: 0.002189 merged: 0.000000
 mg 16M: 0.002706 mg  8M: 0.002750 mg  4M: 0.002579 mg  2M: 0.002970 mg  1M: 0.000000mg merged: 0.002706
epoch: 149 [2017-12-20 20:45:00]
 16M: 0.003859  8M: 0.003038  4M: 0.002392  2M: 0.001891  1M: 0.002190 merged: 0.000000
 mg 16M: 0.002623 mg  8M: 0.002682 mg  4M: 0.002537 mg  2M: 0.002959 mg  1M: 0.000000mg merged: 0.002623
epoch: 150 [2017-12-20 20:50:07]
 16M: 0.004167  8M: 0.003146  4M: 0.002490  2M: 0.001980  1M: 0.002263 merged: 0.000000
 mg 16M: 0.002661 mg  8M: 0.002759 mg  4M: 0.002626 mg  2M: 0.003024 mg  1M: 0.000000mg merged: 0.002661
epoch: 151 [2017-12-20 20:55:09]
 16M: 0.003930  8M: 0.003036  4M: 0.002354  2M: 0.001912  1M: 0.002135 merged: 0.000000
 mg 16M: 0.002612 mg  8M: 0.002641 mg  4M: 0.002499 mg  2M: 0.002892 mg  1M: 0.000000mg merged: 0.002612
epoch: 152 [2017-12-20 21:00:07]
 16M: 0.003882  8M: 0.003084  4M: 0.002420  2M: 0.001863  1M: 0.002112 merged: 0.000000
 mg 16M:

epoch: 185 [2017-12-20 23:46:39]
 16M: 0.003441  8M: 0.002699  4M: 0.002053  2M: 0.001601  1M: 0.001749 merged: 0.000000
 mg 16M: 0.002203 mg  8M: 0.002344 mg  4M: 0.002212 mg  2M: 0.002584 mg  1M: 0.000000mg merged: 0.002203
epoch: 186 [2017-12-20 23:51:45]
 16M: 0.003330  8M: 0.002607  4M: 0.001974  2M: 0.001616  1M: 0.001771 merged: 0.000000
 mg 16M: 0.002168 mg  8M: 0.002317 mg  4M: 0.002196 mg  2M: 0.002581 mg  1M: 0.000000mg merged: 0.002168
epoch: 187 [2017-12-20 23:56:47]
 16M: 0.003517  8M: 0.002793  4M: 0.002063  2M: 0.001614  1M: 0.001725 merged: 0.000000
 mg 16M: 0.002193 mg  8M: 0.002332 mg  4M: 0.002215 mg  2M: 0.002567 mg  1M: 0.000000mg merged: 0.002193
epoch: 188 [2017-12-21 00:01:55]
 16M: 0.003288  8M: 0.002583  4M: 0.001963  2M: 0.001538  1M: 0.001681 merged: 0.000000
 mg 16M: 0.002118 mg  8M: 0.002259 mg  4M: 0.002153 mg  2M: 0.002535 mg  1M: 0.000000mg merged: 0.002118
epoch: 189 [2017-12-21 00:06:57]
 16M: 0.003311  8M: 0.002619  4M: 0.001965  2M: 0.001576  1M: 0

 16M: 0.003007  8M: 0.002342  4M: 0.001747  2M: 0.001390  1M: 0.001499 merged: 0.000000
 mg 16M: 0.001873 mg  8M: 0.002066 mg  4M: 0.001965 mg  2M: 0.002306 mg  1M: 0.000000mg merged: 0.001873
epoch: 223 [2017-12-21 02:58:23]
 16M: 0.002974  8M: 0.002336  4M: 0.001748  2M: 0.001414  1M: 0.001488 merged: 0.000000
 mg 16M: 0.001859 mg  8M: 0.002050 mg  4M: 0.001956 mg  2M: 0.002318 mg  1M: 0.000000mg merged: 0.001859
epoch: 224 [2017-12-21 03:03:26]
 16M: 0.002949  8M: 0.002326  4M: 0.001736  2M: 0.001380  1M: 0.001442 merged: 0.000000
 mg 16M: 0.001823 mg  8M: 0.002012 mg  4M: 0.001920 mg  2M: 0.002260 mg  1M: 0.000000mg merged: 0.001823
epoch: 225 [2017-12-21 03:08:35]
 16M: 0.002923  8M: 0.002311  4M: 0.001717  2M: 0.001389  1M: 0.001490 merged: 0.000000
 mg 16M: 0.001842 mg  8M: 0.002052 mg  4M: 0.001958 mg  2M: 0.002334 mg  1M: 0.000000mg merged: 0.001842
epoch: 226 [2017-12-21 03:13:37]
 16M: 0.002939  8M: 0.002280  4M: 0.001727  2M: 0.001377  1M: 0.001478 merged: 0.000000
 mg 16M:

# Visualize Graph

In [8]:
from graphviz import Digraph
import torch
from torch.autograd import Variable


def make_dot(var, params=None):
    """Produce a Graphviz representation of the autograd graph behind ``var``.

    Light-blue nodes are the Variables that require grad, orange nodes are
    tensors saved for backward in a ``torch.autograd.Function``; every other
    node is labelled with the class name of the graph node.

    Args:
        var: output Variable; the traversal starts at ``var.grad_fn``.
        params: optional dict of (name, Variable) used to label the nodes of
            Variables that require grad. Variables missing from the dict are
            drawn without a name.

    Returns:
        The populated ``Digraph`` (configured for SVG output).
    """
    param_map = {}
    if params is not None:
        # dict views cannot be indexed on Python 3, so peek at the first
        # value through an iterator instead of ``params.values()[0]``.
        first_param = next(iter(params.values()), None)
        assert first_param is None or isinstance(first_param, Variable)
        param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="10240,10240"), format='svg')
    seen = set()

    def size_to_str(size):
        return '('+(', ').join(['%d' % v for v in size])+')'

    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                # Fall back to an empty name for Variables not listed in params.
                name = param_map.get(id(u), '')
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot

In [9]:
# x = Variable(torch.zeros(1,3,256,256))
# y = net(x.cuda())
# g = make_dot(y[-1])


In [10]:
# g.render('net-transition_scale_{}'.format(transition_scale)) 