In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
import torch.utils.data as data
from collections import OrderedDict
import numpy as np
import datasets.GeneralDataset as GenDataset
import time
from san_vision import transforms
import os.path as osp
from utils import AverageMeter
from model import ITN_CPM
from model import np2variable, variable2np
import torchvision.transforms as TT
from tensorboardX import SummaryWriter

In [3]:
global params
params = {
    'num_pts': 68,
    'convert68to49' : False,
    'path': 'xx', 
    'argmax_radius' : 1,
    'downsample' : 8,
    'batch_size' : 20,
    'heatmap_type' : 'gaussian',
    'dataset_name' : '300W/Original(GTB)',
    'momentum' : 0.9,
    'learning_rate' : 0.00010,
    'decay' : 0.0005,
    'total_epochs' : 100,
    'crop_width' : 256,
    'crop_height' : 256,
    'pre_crop_expand' : 0.2,
    'crop_perturb_max' : 30,
    'scale_prob' : 1.1,
    'scale_max' : 1,
    'scale_min' : 1,
    'scale_eval' : 1,
    'sigma' : 4,
    'train_list' : '/home/abhirup/Datasets/300W-Style/box-coords/300W-Original/300w.train.GTB',
    'resume' : 'checkpoint.pth.tar'
         }

In [4]:
def weights_init_cpm(m):
    classname = m.__class__.__name__
    # print(classname)
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0, 0.01)
        if m.bias is not None: m.bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1:
        m.weight.data.fill_(1)
        m.bias.data.zero_()

In [5]:
def remove_module_dict(state_dict):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:] # remove `module.`
        new_state_dict[name] = v
    return new_state_dict
def load_weight_from_dict(model, weight_state_dict, param_pair=None, remove_prefix=True):
    if remove_prefix: weight_state_dict = remove_module_dict(weight_state_dict)
    all_parameter = model.state_dict()
    all_weights   = []
    finetuned_layer, random_initial_layer = [], []
    for key, value in all_parameter.items():
        if param_pair is not None and key in param_pair:
            all_weights.append((key, weight_state_dict[ param_pair[key] ]))
        elif key in weight_state_dict:
            all_weights.append((key, weight_state_dict[key]))
            finetuned_layer.append(key)
        else:
            all_weights.append((key, value))
            random_initial_layer.append(key)
#     print ('==>[load_model] finetuned layers : {}'.format(finetuned_layer))
#     print ('==>[load_model] keeped layers : {}'.format(random_initial_layer))
    all_weights = OrderedDict(all_weights)
    model.load_state_dict(all_weights)

In [6]:
def save_checkpoint(state, filename='checkpoint.pth.tar'):
    torch.save(state, filename)

In [7]:
def compute_loss(criterion, target_var, outputs, mask_var, total_labeled_cpm):
    total_loss = 0
    each_stage_loss = []
    mask_outputs = []
    for output_var in outputs:
        stage_loss = 0
        output = torch.masked_select(output_var, mask_var)
        target = torch.masked_select(target_var, mask_var)
    #     print(output.size(), target.size())
        mask_outputs.append(output)

        stage_loss = criterion(output, target)/(total_labeled_cpm*2)
        total_loss += stage_loss
        each_stage_loss.append(stage_loss.item())
    return total_loss, each_stage_loss
# mask_var[:,:num_pts,:,:]

In [8]:
mean_fill   = tuple( [int(x*255) for x in [0.5, 0.5, 0.5] ] )
normalize   = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                      std=[0.5, 0.5, 0.5])

In [9]:
train_transform = [transforms.PreCrop(params['pre_crop_expand'])]
train_transform = [transforms.TrainScale2WH((params['crop_width'], params['crop_height']))]
train_transform += [transforms.AugScale(params['scale_prob'], params['scale_min'], params['scale_max'])]
train_transform += [transforms.AugCrop(params['crop_width'], params['crop_height'], params['crop_perturb_max'], mean_fill)]
train_transform += [transforms.ToTensor(), normalize]
train_transform = transforms.Compose( train_transform )

In [10]:
train_data = GenDataset(train_transform, params['sigma'], params['downsample'], params['heatmap_type'], params['dataset_name'])
train_data.load_list(params['train_list'], params['num_pts'], True)
# if(params['convert68to49']): train_data.convert68to49()

The general dataset initialization done, sigma is 4, downsample is 8, dataset-name : 300W/Original(GTB), self is : GeneralDataset(number of point=-1, heatmap_type=gaussian)
Load list from /home/abhirup/Datasets/300W-Style/box-coords/300W-Original/300w.train.GTB
Load [0/1]-th list : /home/abhirup/Datasets/300W-Style/box-coords/300W-Original/300w.train.GTB with 3148 images
Start load data for the general datas
Load data done for the general dataset, which has 3148 images.


In [11]:
# eval_transform  = transforms.Compose([transforms.PreCrop(params['pre_crop_expand']),
#                                       transforms.TrainScale2WH((params['crop_width'],params['crop_height'])),
#                                       transforms.ToTensor(), normalize])

# eval_data = GenDataset(eval_transform, params['sigma'], params['downsample'], params['heatmap_type'], params['dataset_name'])
# eval_data.load_list('/home/abhirup/Datasets/300W-Style/box-coords/300W-Original/test',
#                    68, True)
# eval_loader = data.DataLoader(eval_data, batch_size=5, shuffle=False, pin_memory=False)

In [12]:
writer = SummaryWriter(log_dir='./log')
net = ITN_CPM(params)
writer.add_graph(net)
net.apply(weights_init_cpm)
# net_param_dict = net.parameters()
net = net.cuda()

Error occurs, No graph saved


In [13]:
criterion = torch.nn.MSELoss(False)
criterion.cuda()
optimizer = torch.optim.Adam(
#                             net.parameters(), lr=params['learning_rate'],
                            net.specify_parameter(base_lr=params['learning_rate'], 
                                                  base_weight_decay=params['decay']),  amsgrad=False)
#                             momentum=params['momentum'],
#                             nesterov=True)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.75**(epoch//5))

In [15]:
start_epoch=0
best_loss=99999
if (params['resume'] and osp.isfile(params['resume'])):
    print("=> loading checkpoint '{}'".format(params['resume']))
    checkpoint = torch.load(params['resume'])
    start_epoch = checkpoint['epoch']
    best_loss = checkpoint['best_loss']
    net.load_state_dict(checkpoint['state_dict'])
#     optimizer.load_state_dict(checkpoint['optimizer'])
    print("=> loaded checkpoint '{}' (epoch {}) current_loss = {}"
          .format(params['resume'], checkpoint['epoch'], best_loss))
else:
    model_urls = 'http://download.pytorch.org/models/vgg16-397923af.pth'
    weights = model_zoo.load_url(model_urls)
    load_weight_from_dict(net, weights, None, False);

=> loading checkpoint 'checkpoint.pth.tar'
=> loaded checkpoint 'checkpoint.pth.tar' (epoch 27) current_loss = 19.98334503173828


In [16]:
train_loader = data.DataLoader(train_data, batch_size=params['batch_size'], shuffle=True, pin_memory=False)

In [17]:
batch_time = AverageMeter()
data_time = AverageMeter()
forward_time = AverageMeter()
visible_points = AverageMeter()

In [None]:
# start_epoch = 1
end = time.time()
for epoch in range(start_epoch, params['total_epochs']):
    scheduler.step(epoch)
    print('Epoch:{}, Now learning rate:{}'.format(epoch, scheduler.get_lr()))
    for i, (inputs, target, mask, points, image_index, label_sign) in enumerate(train_loader):
        # inputs : Batch, Squence, Channel, Height, Width
        # data prepare
        target = target.cuda(async=True)
        # get the real mask
        mask.masked_scatter_((1-label_sign).unsqueeze(-1).unsqueeze(-1), torch.ByteTensor(mask.size()).zero_())
        mask_var   = mask.cuda(async=True)

        batch_size, num_pts = inputs.size(0), mask.size(1)-1
        image_index = variable2np(image_index).squeeze(1).tolist()
        # check the label indicator, whether is has annotation or not
        sign_list = variable2np(label_sign).astype('bool').squeeze(1).tolist()
        data_time.update(time.time() - end)
        cvisible_points = torch.sum(mask[:,:-1,:,:]) * 1. / batch_size
        visible_points.update(cvisible_points, batch_size)

        batch_cpms, batch_locs, batch_scos = net(inputs.cuda())

#         forward_time.update(time.time() - end)

        total_labeled_cpm = int(np.sum(sign_list))

        cpm_loss, each_stage_loss_values = compute_loss(criterion, target, batch_cpms, mask_var, total_labeled_cpm)
        writer.add_scalars('Loss', {'total loss':cpm_loss, 
                                    'stage1 loss': each_stage_loss_values[0],
                                    'stage2 loss': each_stage_loss_values[1],
                                    'stage3 loss': each_stage_loss_values[2]}, epoch*209+i)
        for name, param in net.named_parameters():
            writer.add_histogram(name, param.clone().cpu().data.numpy(), epoch*209+i)
        #passing 'weight_of_idt' or identity loss as None
        optimizer.zero_grad()
        cpm_loss.backward()
        optimizer.step()

        if (cpm_loss < best_loss + 5 and i%10==0):
            best_loss = cpm_loss
            save_checkpoint({
                'epoch': epoch,
                'state_dict': net.state_dict(),
                'best_loss': best_loss,
                'optimizer' : optimizer.state_dict(),
            })
        print("batch_iter:{}, total_loss:{}, stage_losses:{}, {}, {}".format(i, cpm_loss, each_stage_loss_values[0], each_stage_loss_values[1], each_stage_loss_values[2]))
    print("End of {}th epoch\n".format(epoch))
#     if cpm_loss < best_loss:
#             save_checkpoint({
#                 'epoch': epoch,
#                 'state_dict': net.state_dict(),
#                 'best_loss': best_loss,
#                 'optimizer' : optimizer.state_dict(),
#             })

Epoch:27, Now learning rate:[2.3730468750000002e-05, 4.7460937500000004e-05, 2.3730468750000002e-05, 4.7460937500000004e-05, 2.3730468750000002e-05, 4.7460937500000004e-05, 9.492187500000001e-05, 0.00018984375000000002, 9.492187500000001e-05, 0.00018984375000000002]
batch_iter:0, total_loss:14.886838912963867, stage_losses:7.192742347717285, 3.745077610015869, 3.9490184783935547
batch_iter:1, total_loss:29.446449279785156, stage_losses:7.339168071746826, 10.28019905090332, 11.827082633972168
batch_iter:2, total_loss:24.208053588867188, stage_losses:10.54419231414795, 7.289914608001709, 6.373946666717529
batch_iter:3, total_loss:21.671680450439453, stage_losses:6.635532379150391, 7.9913763999938965, 7.044771671295166
batch_iter:4, total_loss:27.40978240966797, stage_losses:8.574284553527832, 9.302638053894043, 9.532858848571777
batch_iter:5, total_loss:23.738903045654297, stage_losses:9.726409912109375, 6.585782051086426, 7.426711559295654
batch_iter:6, total_loss:26.554901123046875, st

In [None]:
trans = TT.ToPILImage()
import matplotlib.pyplot as plt

In [None]:
%%time
for i, (inputs, target, mask, points, image_index, label_sign) in enumerate(eval_loader):
    images = [trans(x) for x in inputs]
    _, locs, _ = net(inputs.cuda())
    for j in range(len(images)):
        im = images[j]
        loc = locs[j].cpu().data.numpy()
        plt.imshow(im)
        plt.scatter(loc[:params['num_pts'],0], loc[:params['num_pts'],1], s=8, c='red')
        plt.show()