In [1]:
from lib.network import PoseNetRGBOnly, PoseNet
from lib.loss import Loss
from lib.utils import setup_logger
from datasets.linemod.dataset import PoseDataset as PoseDataset_linemod
import torch
import os
from torch import optim
import numpy as np
import time
from torch.autograd import Variable
%load_ext autoreload
%autoreload 2


In [2]:
import warnings
warnings.filterwarnings("ignore")

# Load data

In [3]:
# --- Experiment configuration ----------------------------------------------
# Sizes
num_objects = 5     # handed to the estimator as `num_obj`
num_points = 500    # handed to both the dataset and the estimator

# Output locations
outf = 'trained_models/linemod'         # checkpoints are saved under this directory
log_dir = 'experiments/logs/linemod'    # one log file per epoch is written here

# Schedule
repeat_epoch = 20       # passes over the training loader inside a single epoch
decay_margin = 0.016    # threshold referenced by the (commented-out) automatic decay
decay_start = False
lr_rate = 0.3           # factor the learning rate is multiplied by on decay
w_rate = 0.3            # factor the loss weight `w` is multiplied by on decay

# Modality switches
is_RGBD = True                  # True -> PoseNet, False -> PoseNetRGBOnly
input_depth = True
add_depth_to_output = True
true_depth = False

In [5]:
# Build the LineMOD train/test datasets and their loaders.
dataset_root = "./datasets/linemod/Linemod_preprocessed"
noise_trans = 0.03      # noise argument for the training split; the test split gets 0.0
refine_start = False    # forwarded to both datasets and, later, to the loss
decay_start = False

# The depth mode follows the `true_depth` switch from the configuration cell
# instead of a hard-coded False, so changing the switch there takes effect for
# both splits. With the current value (False) nothing changes.
# NOTE(review): the third positional argument (True for train, False for test)
# is presumably the add-noise switch -- confirm against PoseDataset.
dataset = PoseDataset_linemod('train',
                              num_points,
                              True,
                              dataset_root,
                              noise_trans,
                              refine_start,
                              use_true_depth=true_depth)

# One frame per loader batch; the training loop accumulates gradients over
# several frames itself.
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=1,
                                         shuffle=True,
                                         num_workers=4)

test_dataset = PoseDataset_linemod('test',
                                   num_points,
                                   False,
                                   dataset_root,
                                   0.0,
                                   refine_start,
                                   use_true_depth=true_depth)

# Evaluation order is kept fixed (no shuffling).
testdataloader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=4)

Object 2 buffer loaded
Object 4 buffer loaded
Object 5 buffer loaded
Object 10 buffer loaded
Object 11 buffer loaded
Object 2 buffer loaded
Object 4 buffer loaded
Object 5 buffer loaded
Object 10 buffer loaded
Object 11 buffer loaded


In [6]:
# Dataset metadata that the loss constructor is given later on.
sym_list, num_points_mesh = dataset.get_sym_list(), dataset.get_num_points_mesh()

# Models

In [7]:
# Pick the network variant from the modality switch, then move it to the GPU.
if not is_RGBD:
    # RGB only
    estimator = PoseNetRGBOnly(num_points=num_points, num_obj=num_objects)
    model = 'trained_models/linemod/pose_model_9_0.03693710159944388.pth'
else:
    # RGBD
    estimator = PoseNet(num_points=num_points, num_obj=num_objects)
#     model = 'trained_models/linemod/pose_model_5_0.033827896095197464.pth'

# Checkpoint loading is disabled, so no saved weights are restored here.
# NOTE: `model` is only bound in the RGB-only branch; re-enabling the line
# below with is_RGBD=True requires restoring the RGBD path above first.
# estimator.load_state_dict(torch.load(model))
estimator.cuda();

In [8]:
# Base hyper-parameters and the optimizer built from them.
lr = 0.0001   # Adam learning rate
w = 0.015     # passed to the loss as its `w` argument
# Decay is not applied in this cell:
# lr *= lr_rate
# w *= w_rate
optimizer = optim.Adam(params=estimator.parameters(), lr=lr)

# Cases
## True depth, predict depth
1. Input with depth, output with depth
2. Input without depth, output with depth
3. Input without depth, output without depth

In [9]:
# Build the loss object from the dataset metadata fetched earlier; it is called
# once per frame in both the training and the evaluation loop.
criterion = Loss(num_points_mesh, sym_list)

In [10]:
# Epoch range consumed by the training loop as range(start_epoch, nepoch).
# NOTE: the log-directory cleanup only runs when start_epoch == 1.
start_epoch, nepoch = 6, 18
# w = 0.015
# Number of frames whose gradients are accumulated before each optimizer step
# (the DataLoader itself yields one frame at a time).
batch_size = 8

In [13]:
# Manual decay step: scale the learning rate and the loss weight, then rebuild
# Adam so the new rate is used (a new optimizer also starts with fresh state).
# NOTE(review): this cell is not idempotent -- every execution multiplies
# `lr` and `w` by their rates again, so run it exactly once per intended decay.
lr = lr * lr_rate
w = w * w_rate
optimizer = optim.Adam(params=estimator.parameters(), lr=lr)

In [None]:
# Training / evaluation driver.
#
# For every epoch in range(start_epoch, nepoch):
#   1. iterate `repeat_epoch` times over the training loader, one frame at a
#      time, accumulating gradients and stepping the optimizer every
#      `batch_size` frames;
#   2. evaluate on the test loader and log the per-frame and average distance;
#   3. save a checkpoint whenever the average test distance does not get worse.
best_test = np.inf  # the `np.Inf` alias was removed in NumPy 2.0
print_every = 50

# Make sure the output locations exist before anything is listed or written.
os.makedirs(log_dir, exist_ok=True)
os.makedirs(outf, exist_ok=True)

if start_epoch == 1:
    # Fresh run: remove old log files, skipping entries whose name contains
    # '.ipyn' (e.g. notebook checkpoint folders).
    for log in os.listdir(log_dir):
        if '.ipyn' not in log:
            os.remove(os.path.join(log_dir, log))
st_time = time.time()


def _elapsed():
    """Return the wall-clock time since `st_time`, formatted as 'HHh MMm SSs'."""
    return time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time))


for epoch in range(start_epoch, nepoch):
    logger = setup_logger('epoch%d' % epoch, os.path.join(log_dir, 'epoch_%d_log.txt' % epoch))
    mess = 'Train time {0}, Training started'.format(_elapsed())
    print(mess)
    logger.info(mess)

    train_count = 0
    train_dis_avg = 0.0

    estimator.train()
    # Start the epoch with clean gradients; anything accumulated but not yet
    # stepped at the end of the previous epoch is discarded here.
    optimizer.zero_grad()

    for rep in range(repeat_epoch):
        for i, data in enumerate(dataloader, 0):
            # Tensors are moved to the GPU directly; wrapping them in the
            # deprecated `Variable` is unnecessary.
            points, choose, img, target, model_points, idx = (t.cuda() for t in data)

            if is_RGBD:
                pred_r, pred_t, pred_c, emb = estimator(img, points, choose, idx)
            else:
#                 if not add_depth_to_output:
#                     points[0, :, 2] = 0
                pred_r, pred_t, pred_c, emb = estimator(img, choose, idx)

            loss, dis, new_points, new_target = criterion(pred_r,
                                                          pred_t,
                                                          pred_c,
                                                          target,
                                                          model_points,
                                                          idx,
                                                          points,
                                                          w, refine_start)
            # Gradients accumulate across frames until the next optimizer step.
            loss.backward()

            train_dis_avg += dis.item()
            train_count += 1
            if train_count % batch_size == 0:
                mess = 'Train time {0} Epoch {1} Batch {2} Frame {3} Avg_dis:{4}'.format(
                        _elapsed(),
                        epoch, train_count // batch_size,
                        train_count,
                        train_dis_avg / batch_size)
                logger.info(mess)

                optimizer.step()
                optimizer.zero_grad()
                train_dis_avg = 0

                # Console output only fires when the frame count is a multiple
                # of both `batch_size` and `print_every`.
                if train_count % print_every == 0:
                    print(mess)

            # Rolling checkpoint every 500 frames (train_count is >= 1 here).
            if train_count % 500 == 0:
                torch.save(estimator.state_dict(), '{0}/pose_model_current.pth'.format(outf))

    print('>>>>>>>>----------epoch {0} train finish---------<<<<<<<<'.format(epoch))

    test_dis = 0.0
    test_count = 0
    estimator.eval()

    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every test frame.
    with torch.no_grad():
        for j, data in enumerate(testdataloader, 0):
            points, choose, img, target, model_points, idx = (t.cuda() for t in data)

            if is_RGBD:
                pred_r, pred_t, pred_c, emb = estimator(img, points, choose, idx)
            else:
#                 if not add_depth_to_output:
#                     points[0, :, 2] = 0
                pred_r, pred_t, pred_c, emb = estimator(img, choose, idx)

            _, dis, new_points, new_target = criterion(pred_r, pred_t, pred_c,
                                                       target,
                                                       model_points, idx, points,
                                                       w, refine_start)
            test_dis += dis.item()
            # Log the scalar value rather than the tensor object's repr.
            mess = 'Test time {0} Test Frame No.{1} dis:{2}'.format(_elapsed(), test_count, dis.item())
            logger.info(mess)

            test_count += 1

    test_dis = test_dis / test_count
    mess = 'Test time {0} Epoch {1} TEST FINISH Avg dis: {2}'.format(_elapsed(), epoch, test_dis)
    print(mess)
    logger.info(mess)

    if test_dis <= best_test:
        best_test = test_dis
        torch.save(estimator.state_dict(), '{0}/pose_model_{1}_{2}.pth'.format(outf, epoch, test_dis))
        print(epoch, '>>>>>>>>----------BEST TEST MODEL SAVED---------<<<<<<<<')

    # Automatic decay is disabled; lr and w are currently decayed by hand.
#     if best_test < decay_margin and not decay_start:
#         decay_start = True
#         lr *= lr_rate
#         w *= w_rate
#         optimizer = optim.Adam(estimator.parameters(), lr=lr)