In [6]:
from lib.network import DepthNet
from lib.loss_depth import LossDepth
from lib.utils import setup_logger, im_convert, depth_to_img
from datasets.linemod.dataset import DepthDataset
import torch
import os
from torch import optim
import numpy as np
import time
from torch.autograd import Variable
import matplotlib.pyplot as plt

# Enable IPython's autoreload extension in mode 2 so edits to imported modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
import warnings
# NOTE(review): this suppresses every warning category for the whole kernel
# session, including deprecation warnings from torch/numpy. Consider narrowing
# the filter to the specific warning that prompted it.
warnings.filterwarnings("ignore") 

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# --- Filesystem layout -------------------------------------------------------
# dataset_root: root folder handed to DepthDataset
# log_dir:      where the per-epoch log files of the training cell are written
# outf:         where model checkpoints are written
dataset_root = "./datasets/linemod/Linemod_preprocessed"
log_dir = 'experiments/logs/depth_prediction'
outf = 'trained_models/depth_prediction'

# --- Training data -----------------------------------------------------------
# The loader yields one sample at a time in shuffled order; the training cell
# accumulates gradients over several samples before each optimizer step.
dataset = DepthDataset('train', dataset_root)
dataloader = torch.utils.data.DataLoader(
    dataset,
    shuffle=True,
    batch_size=1,
    num_workers=1,
)

Object 6 buffer loaded
Object 8 buffer loaded
Object 9 buffer loaded
Object 12 buffer loaded
Object 13 buffer loaded
Object 14 buffer loaded
Object 15 buffer loaded


In [8]:
# --- Evaluation data ---------------------------------------------------------
# Same dataset class as for training, but built for the 'test' split and read
# one sample at a time without shuffling.
test_dataset = DepthDataset('test', dataset_root)
testdataloader = torch.utils.data.DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=1,
    num_workers=1,
)

Object 6 buffer loaded
Object 8 buffer loaded
Object 9 buffer loaded
Object 12 buffer loaded
Object 13 buffer loaded
Object 14 buffer loaded
Object 15 buffer loaded


In [12]:
# --- Hyper-parameters --------------------------------------------------------
lr = 1e-4
# Number of single-sample forward/backward passes accumulated per optimizer step
# (the DataLoader itself is built with batch_size=1).
batch_size = 32
# The training cell prints its progress line when the frame counter is a
# multiple of this value.
print_every = 32

# --- Model, optimizer and loss -----------------------------------------------
estimator = DepthNet()
estimator.cuda()
optimizer = optim.Adam(estimator.parameters(), lr=lr)
criterion = LossDepth()

In [13]:
start_epoch = 1
nepoch = 20
repeat_epoch = 1
# np.inf (lowercase) works on every NumPy version; the np.Inf alias was removed
# in NumPy 2.0.
best_test = np.inf

# Create the output folders up front so that os.listdir / torch.save below do
# not fail on a fresh checkout.
os.makedirs(log_dir, exist_ok=True)
os.makedirs(outf, exist_ok=True)

if start_epoch == 1:
    # Fresh run: drop log files from previous runs. Entries whose name contains
    # '.ipyn' are kept, and sub-directories are skipped because os.remove
    # cannot delete them.
    for log in os.listdir(log_dir):
        log_path = os.path.join(log_dir, log)
        if '.ipyn' not in log and os.path.isfile(log_path):
            os.remove(log_path)
st_time = time.time()


def _elapsed():
    """Return the wall-clock time since ``st_time`` formatted as 'HHh MMm SSs'."""
    return time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time))


# NOTE: range() excludes its upper bound, so the last epoch run is nepoch - 1.
for epoch in range(start_epoch, nepoch):
    logger = setup_logger('epoch%d' % epoch, os.path.join(log_dir, 'epoch_%d_log.txt' % epoch))
    mess = 'Train time {0}'.format(_elapsed() + ', ' + 'Training started')
    print(mess)
    logger.info(mess)

    train_count = 0
    loss_avg = 0.0

    estimator.train()
    # NOTE: gradients left over from a trailing partial accumulation window of
    # the previous epoch are discarded here without an optimizer step.
    optimizer.zero_grad()

    # ---- training: gradient accumulation over `batch_size` samples ----------
    for i, data in enumerate(dataloader, 0):
        img, depth = data
        img, depth = img.float().cuda(), depth.float().cuda()

        pred_log_depth = estimator(img)
        loss = criterion(pred_log_depth, depth)

        # Gradients are summed into .grad until the next optimizer.step().
        loss.backward()

        loss_avg += loss.item()
        train_count += 1

        if train_count % batch_size == 0:
            mess = 'Train time {0} Epoch {1} Batch {2} Frame {3} Avg_dis:{4}'.format(
                    _elapsed(),
                    epoch, int(train_count / batch_size),
                    train_count,
                    loss_avg / batch_size)
            logger.info(mess)

            optimizer.step()
            optimizer.zero_grad()
            loss_avg = 0

            if train_count % print_every == 0:
                print(mess)

        # Rolling checkpoint every 500 frames (train_count is >= 1 at this point).
        if train_count % 500 == 0:
            torch.save(estimator.state_dict(), '{0}/pose_model_current.pth'.format(outf))

    print('>>>>>>>>----------epoch {0} train finish---------<<<<<<<<'.format(epoch))

    # ---- evaluation ----------------------------------------------------------
    test_loss = 0.0
    test_count = 0
    estimator.eval()

    # No gradients are needed for evaluation; no_grad avoids building autograd
    # graphs for every test frame.
    with torch.no_grad():
        for j, data in enumerate(testdataloader, 0):
            img, depth = data
            img, depth = img.float().cuda(), depth.float().cuda()

            pred_log_depth = estimator(img)
            loss = criterion(pred_log_depth, depth)

            # Log the scalar value rather than the tensor repr.
            frame_loss = loss.item()
            test_loss += frame_loss
            mess = 'Test time {0} Test Frame No.{1} dis:{2}'.format(_elapsed(), test_count, frame_loss)
            logger.info(mess)
            test_count += 1

    test_loss = test_loss / test_count
    mess = 'Test time {0} Epoch {1} TEST FINISH Avg dis: {2}'.format(_elapsed(), epoch, test_loss)
    print(mess)
    logger.info(mess)

    # Keep a checkpoint whenever the mean test loss matches or beats the best so far.
    if test_loss <= best_test:
        best_test = test_loss
        torch.save(estimator.state_dict(), '{0}/pose_model_{1}_{2}.pth'.format(outf, epoch, test_loss))
        print(epoch, '>>>>>>>>----------BEST TEST MODEL SAVED---------<<<<<<<<')
        
#     if best_test < decay_margin and not decay_start:
#         decay_start = True
#         lr *= lr_rate
#         optimizer = optim.Adam(estimator.parameters(), lr=lr)

Train time 00h 00m 00s, Training started
Train time 00h 00m 06s Epoch 1 Batch 1 Frame 32 Avg_dis:24.05642569065094
Train time 00h 00m 13s Epoch 1 Batch 2 Frame 64 Avg_dis:21.240368366241455
Train time 00h 00m 19s Epoch 1 Batch 3 Frame 96 Avg_dis:10.0992983430624
Train time 00h 00m 26s Epoch 1 Batch 4 Frame 128 Avg_dis:7.479442711919546
Train time 00h 00m 32s Epoch 1 Batch 5 Frame 160 Avg_dis:2.087237721309066
Train time 00h 00m 39s Epoch 1 Batch 6 Frame 192 Avg_dis:4.018476247787476
Train time 00h 00m 45s Epoch 1 Batch 7 Frame 224 Avg_dis:5.101112507283688
Train time 00h 00m 52s Epoch 1 Batch 8 Frame 256 Avg_dis:2.6054070815443993
Train time 00h 00m 58s Epoch 1 Batch 9 Frame 288 Avg_dis:1.8963651731610298
Train time 00h 01m 05s Epoch 1 Batch 10 Frame 320 Avg_dis:4.194527769461274
Train time 00h 01m 11s Epoch 1 Batch 11 Frame 352 Avg_dis:1.2372991442680359
Train time 00h 01m 18s Epoch 1 Batch 12 Frame 384 Avg_dis:2.04804923851043
Train time 00h 01m 24s Epoch 1 Batch 13 Frame 416 Avg_dis

KeyboardInterrupt: 

In [6]:
# Grab the first batch from the test loader for manual inspection. The loader
# is built with shuffle=False and batch_size=1, so this is a single sample.
t = next(iter(testdataloader))

In [7]:
# Unpack the batch in the same (image, depth) order the training loop uses.
# NOTE: this overwrites the global `img` / `depth` left behind by that loop.
img, depth = t

In [18]:
# Path of a saved checkpoint; the file name matches the
# 'pose_model_{epoch}_{test_loss}.pth' pattern written by the training cell.
model = 'trained_models/depth_prediction/pose_model_2_0.12568162765364324.pth'
# NOTE(review): the constructor lines below are disabled, so this cell only
# works if an `estimator` instance already exists in the kernel.
# estimator = DepthNet()
# estimator.cuda()
estimator.load_state_dict(torch.load(model))
# NOTE(review): switching to eval mode is disabled; re-enable it before using
# the restored weights for inference.
# estimator.eval();

<All keys matched successfully>

In [None]:
# Elementwise mask of the depth tensor: 1.0 where the value equals 0, else 0.0,
# returned as float64.
depth.eq(0.).double()

In [None]:
depth_norm.shape

In [None]:
max_d = torch.max(depth)
min_d = torch.min(depth)
depth_norm = (depth - min_d) * 255 / (max_d - min_d)

In [None]:
# Two panels side by side, both with axes hidden.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(20, 10))

# Left: the input image after im_convert.
rgb_view = im_convert(img)
ax1.imshow(rgb_view)
ax1.axis("off")

# Right: the normalised depth after depth_to_img, drawn in grayscale.
depth_view = depth_to_img(depth_norm)
ax2.imshow(depth_view, cmap='gray')
ax2.axis("off")