In [1]:
import os
import sys
import math
import timeit
import random
import json
import argparse
from argparse import Namespace

import numpy as np
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from mpl_toolkits.mplot3d import Axes3D

# Make the project root importable. This is the parent of the current working
# directory, i.e. the project root when the kernel runs from the notebook's folder.
# Insert it at the FRONT of sys.path so the project's generically named packages
# (datasets, model, utils) take precedence over same-named installed packages such
# as Hugging Face `datasets`; appending would let site-packages shadow them.
module_path = os.path.abspath('..')

if module_path not in sys.path:
    sys.path.insert(0, module_path)
    
from datasets.NYUDataset import NYUDataset
from model.DepthGenLM import DepthGenLM
from utils.util import normalize_batch


# Default skeleton: (start, end) joint-index pairs for the 22-joint layout drawn in
# this notebook. Joint 21 is the hub that joints 0, 5, 9, 13 and 17 connect to
# (presumably the palm centre); each of those starts a chain of further joints.
HAND_BONES = (
    (0, 1), (1, 2), (2, 3), (3, 4),
    (0, 21),
    (21, 5), (5, 6), (6, 7), (7, 8),
    (21, 9), (9, 10), (10, 11), (11, 12),
    (21, 13), (13, 14), (14, 15), (15, 16),
    (21, 17), (17, 18), (18, 19), (19, 20),
)


def plot_hands(ax, points, color, linewidth=3, bones=HAND_BONES):
    """Draw a hand skeleton as line segments on a 3D matplotlib axis.

    Args:
        ax: 3D axes (created with ``projection='3d'``); one ``ax.plot`` call is
            made per bone.
        points: indexable joint coordinates (list, numpy array or CPU tensor);
            row ``i`` holds the (x, y, z) of joint ``i``. Must contain a row for
            every index referenced by ``bones``.
        color: matplotlib colour applied to every segment.
        linewidth: segment width passed straight to ``ax.plot``. Numeric is
            preferred; the old string form (e.g. ``'3'``) is still accepted.
        bones: iterable of ``(start, end)`` joint-index pairs to connect.
            Defaults to ``HAND_BONES``.
    """
    # Convert once instead of per bone; np.asarray accepts lists, arrays and
    # detached CPU tensors alike.
    pts = np.asarray(points)
    for start, end in bones:
        segment = pts[[start, end]]  # shape (2, n_dims): the bone's two endpoints
        ax.plot(segment[:, 0], segment[:, 1], segment[:, 2], c=color, linewidth=linewidth)

In [2]:
# Image pipeline: tensor -> PIL -> resize so the shorter side is 120 px using
# nearest-neighbour sampling (pixel values are copied, never interpolated) -> tensor.
tsfms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(120, Image.NEAREST),
    transforms.ToTensor(),
])

# Set NYU_DATA_PATH to point at the NYU training split on other machines; the
# default is the original author's local copy.
data_path = os.environ.get("NYU_DATA_PATH", "/home/alex/Data/nyu/dataset/train/")
# NOTE(review): the same transform is passed twice, followed by ([], 8192, True);
# what these positional arguments mean is defined by NYUDataset's signature, which
# is not visible here — confirm there and consider passing them by keyword.
dataset = NYUDataset(data_path, tsfms, tsfms, [], 8192, True)

In [4]:
checkpoint_path = "../saved/pretrain.ckpt"
# Set DEPTHGEN_CONFIG to point at the experiment config on other machines; the
# default is the original author's absolute path.
config_path = os.environ.get(
    "DEPTHGEN_CONFIG",
    "/home/alex/dev/projects/3dhpe-udd/configs/alexnet_ss_nyu.json",
)

# Context manager so the config file handle is closed deterministically.
with open(config_path) as config_file:
    hparams = Namespace(**json.load(config_file))

# Load onto the CPU first: this works on machines lacking the GPU the checkpoint
# was saved from, and the model is moved to CUDA explicitly below.
# torch.load unpickles the file, so only load trusted checkpoints.
pretrained_weights = torch.load(checkpoint_path, map_location="cpu")
model = DepthGenLM(hparams)
model.load_state_dict(pretrained_weights['state_dict'])
model.cuda()

DepthGenLM(
  (estimator): AlexNetHM(
    (features): Sequential(
      (0): Conv2d(1, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
      (1): ReLU(inplace=True)
      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): ReLU(inplace=True)
      (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU(inplace=True)
      (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): ReLU(inplace=True)
      (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (estimator): Sequential(
      (0): Linear(in_features=1024, out_features=4096, bias=True)
      (1): ReLU(inplace=True)

In [5]:
import torch.nn.functional as F

joint_idxs = [1, 2, 3, 4, 5, 7, 8, 9, 10, 12, 13, 14, 15, 17, 18, 19, 20, 22, 23, 24, 25, 6]
idx = torch.randint(0, len(dataset), [1]).item()
idx = 0
_, target, kps, kps14, center, norm_size, bbox, padding, has_anno = dataset[idx]
target = target.unsqueeze(0)

# Render Points
target[:, :, 0] = (target[:, :, 0] + 1.0) * 0.5 * 119
target[:, :, 1] = (target[:, :, 1] + 1.0) * 0.5 * 119

with torch.no_grad():
    depth = model.depth_gen.renderer(target)
    depth = normalize_batch(depth)
    
preds = model(depth.unsqueeze(0).cuda())
img = preds[0]
coords = preds[1]

kps_vis = kps.cpu()
samples = depth
pred_imgs = img.detach().cpu()
sample = samples[0]
pred_img = pred_imgs[0]
diff_img = (sample - pred_img).abs()
loss = diff_img.mean()
print(loss)


%matplotlib notebook
img_fig = plt.figure()
img_ax1 = img_fig.add_subplot(131)
img_ax2 = img_fig.add_subplot(132)
img_ax3 = img_fig.add_subplot(133)
img_ax1.imshow(pred_img)
img_ax2.imshow(sample)
img_ax3.imshow(diff_img)

kp_vis = kps.cpu()
coord_vis = coords[0, joint_idxs].detach().cpu()

coord_fig = plt.figure()
coord_ax = coord_fig.add_subplot(111, projection='3d')
plot_hands(coord_ax, kp_vis, 'r')
plot_hands(coord_ax, coord_vis, 'b')
# coord_ax.scatter(kp_vis[:, 0], kp_vis[:, 1], kp_vis[:, 2], 'r')
# coord_ax.scatter(coord_vis[:, 0], coord_vis[:, 1], coord_vis[:, 2], 'b')
coord_ax.view_init(90, -90)
annotations = [str(i) for i in range(21)]
for i, anno in enumerate(annotations):
    coord_ax.text(kp_vis[i, 0], kp_vis[i, 1], kp_vis[i, 2], anno)
    coord_ax.text(coord_vis[i, 0], coord_vis[i, 1], coord_vis[i, 2], anno)

tensor(0.1059)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>