In [None]:
# %matplotlib notebook 
from model.resattnet import SelfPose as SelfPose
from dataloader import XRegoDataset 
from torch.utils.data import DataLoader
import time
import copy
import torch
import numpy as np
import matplotlib.pyplot as plt
import webdataset as wds
from torchvision import transforms
import glob
from webdataset.handlers import ignore_and_continue
from mpl_toolkits.mplot3d import axes3d 
import typing as tp
import numpy as np
import matplotlib.pyplot as plt


In [None]:
# Shared image pipeline: resize every decoded PIL image to 384x384 (height, width),
# then convert it to a tensor.
preproc = transforms.Compose(
    [
        transforms.Resize((384, 384)),
        transforms.ToTensor(),
    ]
)

def transform(x):
    """Decode a comma-separated 3D pose payload and keep six joints.

    The UTF-8 bytes are parsed as floats and grouped into rows of three
    (one row per joint). Rows 14-16 and 22-24 are kept, in that order, and
    flattened to 1-D (18 values when the payload has at least 25 joints).
    """
    raw_fields = np.array(x.decode("utf-8").split(","))
    joint_rows = raw_fields.astype(float).reshape(-1, 3)
    kept_rows = np.vstack([joint_rows[14:17], joint_rows[22:25]])
    return kept_rows.reshape(-1)

def image_preproc(x):
    """Apply the module-level ``preproc`` pipeline to a single decoded image."""
    processed = preproc(x)
    return processed





## Plot Kinematic Tree

In [None]:


def get_chain_dots(joints: np.ndarray, chain_dots_indexes: tp.List[int]) -> np.ndarray:
    """Gather the joints that form one kinematic chain.

    Args:
        joints: joint coordinates, one row per joint (e.g. shape ``(n_dots, 3)``).
        chain_dots_indexes: joint indices listed in the order the bones
            should be drawn along the chain.

    Returns:
        The rows ``joints[chain_dots_indexes]``, in the same order as the indices.
    """
    chain = joints[chain_dots_indexes]
    return chain


def get_chains(
        joints,
        head: tp.List[int],
        spine: tp.List[int],
        arm1: tp.List[int],
        arm2,
        neck,
        right_leg,
        left_leg,
        ):
    """Split one skeleton into its seven kinematic chains.

    Args:
        joints: joint coordinates, one row per joint.
        head, spine, arm1, arm2, neck, right_leg, left_leg: lists of joint
            indices (in drawing order) for each chain.

    Returns:
        A 7-tuple of arrays, one per chain, in the order head, spine, arm1,
        arm2, neck, right_leg, left_leg. Each array is built with ``get_chain_dots``.
    """
    ordered_chain_indexes = (head, spine, arm1, arm2, neck, right_leg, left_leg)
    return tuple(get_chain_dots(joints, ixs) for ixs in ordered_chain_indexes)


def subplot_nodes(dots: np.ndarray, ax, size=20):
    """Scatter all joints onto a 3D axis.

    Every column of ``dots`` becomes one coordinate argument. Points are
    coloured by their last coordinate and drawn with marker size ``size``.
    Returns whatever ``ax.scatter3D`` returns.
    """
    last_coordinate = dots[:, -1]
    return ax.scatter3D(*dots.T, s=size, c=last_coordinate)


def subplot_bones(chains: tp.Tuple[np.ndarray, ...], ax):
    """Draw each chain as a connected polyline on ``ax``.

    Returns a list with one ``ax.plot`` result per chain, in input order.
    """
    drawn = []
    for chain in chains:
        drawn.append(ax.plot(*chain.T))
    return drawn


def plot_skeletons(skeletons, chains_ixs, err_size=20, std_size=20):
    """Draw up to two skeletons side by side in 3D subplots and show the figure.

    Args:
        skeletons: sequence of joint arrays; panel ``k`` shows ``skeletons[k]``.
        chains_ixs: the seven joint-index lists unpacked into ``get_chains``.
        err_size: marker size used for the first panel.
        std_size: marker size used for the second panel.
    """
    fig = plt.figure(figsize=(10, 5))
    # NOTE(review): captions and sizes look crossed. The first panel gets err_size
    # but its caption says "size=uncertainty". The second gets std_size but says
    # "size=error". Kept as-is; confirm intent.
    panel_styles = [
        (err_size, "Prediction: circle size=uncertainty"),
        (std_size, "GT: circle size=error"),
    ]
    for panel, joints in enumerate(skeletons):
        chains = get_chains(joints, *chains_ixs)
        ax = fig.add_subplot(1, 2, panel + 1, projection='3d')
        marker_size, caption = panel_styles[panel]
        ax.set_xlabel(caption, fontsize=10, rotation=100)
        subplot_nodes(joints, ax, marker_size)
        subplot_bones(chains, ax)
    plt.show()
    
def plot(gt_pose, pred_pose, err_size=20, std_size=20):
    """Plot two skeletons side by side using a fixed 17-joint kinematic tree.

    Each index list in the tree is one chain of bones, connected in listed
    order. The lists map positionally onto ``get_chains``' parameters (head,
    spine, arm1, arm2, neck, right_leg, left_leg). The largest index is 16, so
    each pose needs at least 17 joint rows.

    NOTE(review): ``plot_skeletons`` captions its first panel "Prediction", and
    that panel shows ``gt_pose``. ``plot_inference`` passes the prediction as the
    first argument, so the on-screen captions come out right despite the
    parameter names.
    """
    kinematic_tree = (
        [0],
        [1, 2, 3],
        [3, 4, 5, 6],
        [3, 7, 8, 9],
        [3, 10],
        [1, 11, 12, 13],
        [1, 14, 15, 16],
    )
    plot_skeletons([gt_pose, pred_pose], kinematic_tree, err_size, std_size)
 
       


In [None]:
import imgaug.augmenters as iaa
from PIL import Image
import io



def Gaussian(sigma):
    """Return a square, un-normalised Gaussian kernel whose centre value is 1.

    Despite its name, ``sigma`` is the kernel *size* (side length in pixels).
    ``DrawGaussian`` calls ``Gaussian(2 * ceil(3 * s) + 1)``. The kernel's
    standard deviation is ``(size - 1) / 6``, so the patch spans +/-3 standard
    deviations around its centre. For ``size == 7`` this reproduces the
    previously hard-coded table, ``exp(-(dx**2 + dy**2) / 2)`` on a 7x7 grid.

    The old code's fallback branch referenced undefined names (``n``, ``g_inp``),
    so every size other than 7 raised NameError. Any positive integer size now works.

    Raises:
        Exception: if ``sigma`` is not a positive integer size.
    """
    size = sigma
    if int(size) != size or size < 1:
        raise Exception('Gaussian {} Not Implement'.format(sigma))
    size = int(size)
    if size == 1:
        # A 1x1 patch has zero spread; avoid dividing by a zero std.
        return np.ones((1, 1))
    std = (size - 1) / 6.0
    offsets = np.arange(size) - (size - 1) / 2.0
    sq_dist = offsets[None, :] ** 2 + offsets[:, None] ** 2
    return np.exp(-sq_dist / (2.0 * std ** 2))

def DrawGaussian(img, pt, sigma):
    """Paste a Gaussian blob centred at ``pt`` into the 2-D array ``img`` (in place).

    Args:
        img: 2-D heatmap, indexed ``img[row, col]``; modified in place.
        pt: ``(x, y)`` centre, i.e. ``(col, row)``.
        sigma: sets the patch half-width ``ceil(3 * sigma)``. The kernel comes
            from ``Gaussian(2 * ceil(3 * sigma) + 1)``.

    Returns:
        ``img`` (the same object). It is unchanged when the patch lies fully
        outside the array. Covered pixels are overwritten, not max-blended.
    """
    # np.math (an alias of the stdlib math module) was removed in NumPy 2.0.
    tmp_size = int(np.ceil(3 * sigma))
    # ul is inclusive and br is exclusive, so the full (2*tmp_size+1)^2 kernel is pasted.
    # The previous exclusive use of an inclusive br dropped the kernel's last row/column.
    ul = [int(np.floor(pt[0] - tmp_size)), int(np.floor(pt[1] - tmp_size))]
    br = [int(np.floor(pt[0] + tmp_size)) + 1, int(np.floor(pt[1] + tmp_size)) + 1]
    height, width = img.shape[0], img.shape[1]

    # Patch entirely off-image: nothing to draw. This guard also prevents negative
    # kernel slice bounds below.
    if ul[0] >= width or ul[1] >= height or br[0] <= 0 or br[1] <= 0:
        return img

    g = Gaussian(2 * tmp_size + 1)
    # Kernel window, clipped to the part that overlaps the image.
    g_x = [max(0, -ul[0]), min(br[0], width) - ul[0]]
    g_y = [max(0, -ul[1]), min(br[1], height) - ul[1]]
    # Matching image window.
    img_x = [max(0, ul[0]), min(br[0], width)]
    img_y = [max(0, ul[1]), min(br[1], height)]

    img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
    return img

def get_heatmap(points, shape, new_shape=(51,51), joints=6, sigma=1):
    """Render one Gaussian heatmap per joint from 2-D pixel coordinates.

    Args:
        points: array-like of ``(x, y)`` coordinates in the source image, one row
            per joint. The first ``joints`` rows are used. Not modified; the old
            version rescaled the caller's array in place.
        shape: ``(height, width)`` of the image the points refer to.
        new_shape: ``(height, width)`` of each output heatmap.
        joints: number of heatmaps to render.
        sigma: forwarded to ``DrawGaussian`` (default 1, the previous fixed value).

    Returns:
        ``np.ndarray`` of shape ``(joints, new_shape[0], new_shape[1])``.
    """
    height, width = shape[0], shape[1]
    scaled = np.array(points, dtype=float)  # private copy
    scaled[:, 0] = (scaled[:, 0] / width) * new_shape[1]
    scaled[:, 1] = (scaled[:, 1] / height) * new_shape[0]
    pixel = scaled.astype(int)  # truncation toward zero, as before
    heatmaps = np.zeros((joints, new_shape[0], new_shape[1]))
    for i in range(joints):
        heatmaps[i] = DrawGaussian(heatmaps[i], (pixel[i][0], pixel[i][1]), sigma)
    return heatmaps



def transform2D(x):
    """Decode a comma-separated 2D pose payload into six joint heatmaps.

    The UTF-8 bytes are parsed as floats and grouped into ``(x, y)`` rows.
    Rows 14-16 and 22-24 are kept, the same joints ``transform`` uses for the
    3D target. The old slice ``[22:26]`` selected a 7th joint that
    ``get_heatmap`` (``joints=6``) never drew.
    """
    coords = np.array(x.decode("utf-8").split(",")).astype(float).reshape(-1, 2)
    selected = np.concatenate((coords[14:17], coords[22:25]))
    # NOTE(review): (800, 1200) is passed as get_heatmap's (height, width) of the
    # source image; confirm it matches the dataset's 2D annotation frame.
    return get_heatmap(selected, (800, 1200))



    





In [None]:
# Data: WebDataset shards for training and evaluation.
train_url = glob.glob("/egopose-data/web-datasets/xr*train*tar")
# test_url = glob.glob("/egopose-data/web-datasets/mo2cap2_test/mo*train*tar")
test_url = glob.glob("/egopose-data/web-datasets/xr*test*")
print(len(test_url))  # number of test shards matched; 0 means the glob found nothing

# Each sample becomes (pose_image tensor, pose_image_depth tensor, 3D pose vector, 2D heatmaps).
ds = wds.WebDataset(train_url).decode("pil").to_tuple("pose_image.png", "pose_image_depth.png", "pose_3dp.csv", "pose_2dp.csv").map_tuple(image_preproc, image_preproc, transform, transform2D)
train_loader = DataLoader(ds.batched(1), num_workers=1, batch_size=None)

# NOTE(review): the test stream is shuffled. A capped evaluation (mean_error with a
# finite total_batches) therefore sees a different subset on every run.
ds_test = (
    wds.WebDataset(test_url)
    .decode("pil")
    .shuffle(1000, initial=1000)
    .to_tuple("pose_image.png", "pose_image_depth.png", "pose_3dp.csv", "pose_2dp.csv", handler=ignore_and_continue)
    .map_tuple(image_preproc, image_preproc, transform, transform2D)
)
test_loader = DataLoader(ds_test, num_workers=8, batch_size=32)

self_pose = SelfPose()

dataloaders = {"train": train_loader, "val": test_loader}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# map_location lets a checkpoint saved from GPU load on a CPU-only machine.
self_pose.load_state_dict(
    torch.load("./ckpts/resattnet/14_epoch_1.1127434143256608e-05.pth", map_location=device)
)
self_pose.to(device)

In [None]:
# # imgs = []
# for im,_, _, _ in train_loader:
#     im = np.transpose(im.cpu().numpy()[0], (1,2,0))
#     im = cv2.resize(im, (500,368)) * 255
#     im = im.astype("uint8")
#     imgs.append(im)
#     plt.imshow(im)
#     plt.show()
#     break
# len(imgs)

In [None]:
# blur =  iaa.Sequential(

#                  iaa.OneOf([
#                     iaa.GaussianBlur((3, 5.0)), # blur images with a sigma between 0 and 3.0
#                     iaa.AverageBlur(k=(4, 9)), # blur image using local means with kernel sizes between 2 and 7
#                     iaa.MedianBlur(k=(5, 9)), # blur image using local medians with kernel sizes between 2 and 7
#                 ]) )
# contrast =iaa.LinearContrast((0.5, 2.0), per_channel=0.5)
# hue_sat = iaa.AddToHueAndSaturation()

# do= iaa.OneOf([
#                     iaa.Dropout((0.01, 0.1), per_channel=0.5), # randomly remove up to 10% of the pixels
#                     iaa.CoarseDropout((0.03, 0.15), size_percent=(0.02, 0.05), per_channel=0.2),
#                 ])
# simp = iaa.SimplexNoiseAlpha(iaa.OneOf([
#                     iaa.EdgeDetect(alpha=(0.2, 0.6)),iaa.DirectedEdgeDetect(alpha=(0.5, 1.0), direction=(0.0, 1.0))]))


                   
# dos = do(images= imgs)
# cons = contrast(images= imgs)
# hues = hue_sat(images= imgs)
# blurs = blur(images= imgs)
# sims = simp(images = imgs)

In [None]:

# for im in sims:
#     plt.axis('off')
#     plt.imshow(im)
#     plt.show()

In [None]:
def calculate_error(pred, gt, n_joints=6):
    """Per-joint Euclidean distance between predicted and ground-truth 3D poses.

    Args:
        pred: predictions, reshapeable to ``(-1, n_joints, 3)``.
        gt: ground truth, reshapeable to ``(-1, n_joints, 3)``.
        n_joints: joints per pose (default 6, the previously hard-coded value).

    Returns:
        ``np.ndarray`` of shape ``(batch, n_joints)`` holding ``||pred - gt||_2``
        per joint.
    """
    pred = np.reshape(pred, (-1, n_joints, 3))
    gt = np.reshape(gt, (-1, n_joints, 3))
    return np.linalg.norm(pred - gt, axis=2)

def inference(models):
    """Run ``models[0]`` over the validation loader and yield per-batch results.

    Every model in ``models`` is switched to eval mode, but only ``models[0]``
    is evaluated. The generator reads the module-level ``dataloaders`` and
    ``device`` and prints a running sample count and mean error per batch.

    Yields:
        ``(results, labels, error)`` for each batch:
            results: ``infer_multiple``'s scaled predictions, as a NumPy array.
            labels: the ground truth reshaped to ``(batch, 6, 3)``.
            error: ``calculate_error``'s per-joint distances, shape ``(batch, 6)``.

    Any exception raised while loading or evaluating a batch ends the iteration
    early (best-effort evaluation). The full traceback is printed instead of
    only the message.
    """
    import traceback

    for model in models:
        model.eval()  # set every model to evaluation mode
    count = 0
    try:
        for inputs, _depth, labels, _heatmaps in dataloaders["val"]:
            inputs = inputs.to(device).float()
            count += inputs.shape[0]
            results, _hm, _hm_pred = infer_multiple(models[0], inputs)
            results = results.cpu().detach().numpy()
            labels = labels.reshape(-1, 6, 3).cpu().numpy()
            error = calculate_error(results, labels)
            print("count: ", count, "err", np.mean(error))
            yield results, labels, error
    except Exception:
        print("\n\n\nException during inference (stopping iteration):")
        traceback.print_exc()
        return
            

def plot_inference(model, plot_max=5):
    """Plot the predicted and ground-truth arm joints of the first sample in each batch.

    ``plot`` draws a 17-joint kinematic tree. Its arm chains are
    ``[3, 4, 5, 6]`` and ``[3, 7, 8, 9]``, i.e. tree indices 4-9. ``inference``
    yields only six joints per sample (``labels`` is ``(batch, 6, 3)``), so
    the old ``pred[:, 4:10] = outputs[0]`` and the 17-joint tree both failed.
    The six joints are now placed at tree indices 4-9 of NaN-filled 17-joint
    arrays, so only the arm bones appear (matplotlib does not render NaN
    coordinates).

    Args:
        model: list of models, forwarded to ``inference``.
        plot_max: number of batches to plot before stopping.
    """
    n_tree_joints = 17           # joints expected by plot()'s kinematic tree
    arm_slots = slice(4, 10)     # tree indices of the two 3-joint arm chains
    for i, (outputs, labels, _error) in enumerate(inference(model)):
        pred_arms = np.asarray(outputs[0]).reshape(-1, 3)
        gt_arms = np.asarray(labels[0]).reshape(-1, 3)
        pred_vis = np.full((n_tree_joints, 3), np.nan)
        gt_vis = np.full((n_tree_joints, 3), np.nan)
        pred_vis[arm_slots] = pred_arms
        gt_vis[arm_slots] = gt_arms
        # Prediction first: plot_skeletons captions its first panel "Prediction".
        plot(pred_vis, gt_vis)
        if i + 1 >= plot_max:
            break
            
def mean_error(model, total_batches=float("inf")):
    """Collect per-sample errors over at most ``total_batches`` validation batches.

    Args:
        model: list of models, forwarded to ``inference`` (only the first is evaluated).
        total_batches: maximum number of batches to consume (default: all).
            The old check ``i == total_batches`` ran one extra batch.

    Returns:
        ``(results, gts, errors)``:
            results: a list of 21 empty lists. It is never filled and is kept
                only so existing callers can still unpack three values.
            gts: per-sample ground-truth arrays.
            errors: per-sample per-joint error arrays.
    """
    errors = []
    gts = []
    results = [[] for _ in range(21)]  # NOTE(review): unused placeholder
    for i, (_outputs, gt, error) in enumerate(inference(model)):
        errors.extend(error)
        gts.extend(gt)
        if i + 1 >= total_batches:
            break
    if errors:
        print(f"Mean Error of all batches: {np.mean(errors)*10} millimeters ")
    else:
        print("Mean Error of all batches: no batches were evaluated")
    return results, gts, errors

def infer_multiple(model: torch.nn.Module, inputs, scale=300):
    """Run one gradient-free forward pass and rescale the pose output.

    Args:
        model: a single callable model (not a list) returning a 5-tuple
            ``(outputs, hm, pred_hm, _, _)``.
        inputs: batch passed directly to ``model``.
        scale: factor applied to ``outputs`` (default 300, the previous
            hard-coded value).

    Returns:
        ``(outputs * scale, hm, pred_hm)``.
    """
    with torch.no_grad():
        outputs, hm, pred_hm, _, _ = model(inputs)
        return outputs * scale, hm, pred_hm
    

# Evaluate the loaded checkpoint on the validation stream with total_batches=400.
results, gts, errors = mean_error([self_pose], total_batches=400)

In [None]:
# Stack the per-sample error vectors collected by mean_error into a single array.
np_errors = np.stack(errors, axis=0)
# np.save("resnset18.npy", np_errors)

In [None]:
# Per-joint mean over all evaluated samples, times 10 (the factor mean_error reports as millimeters).
np_errors.mean(axis=0) * 10

In [None]:
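# Presumably recorded outputs of the per-joint mean-error cell (np.mean(np_errors, 0) * 10)
# from earlier runs. Commented out so the cell runs: `array` is not defined in this namespace.
# array([ 35.25950946,  55.16844587, 103.40663034,  30.61601325,
#         47.45479543,  93.35602984])
# VAE
#
# array([28.07453525, 52.11339807, 99.06696299, 29.21238883, 48.14259912,
#        89.86875672])
# AE
#
# array([25.46017313, 40.54446942, 65.40962568, 21.48813598, 34.28043555,
#        65.77914521])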

42.23 + 59.53 + 110.44 + 37.42 + 52.90 + 99.42

In [None]:
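# LaTeX table rows: per-joint errors truncated (not rounded) to two decimals from the arrays above.
# 28.07 & 52.11 & 99.06 & 29.21 & 48.14 & 89.86 \\
# 35.25 & 55.16 & 103.40 & 30.61 & 47.45 & 93.35 \\
# 25.46 & 40.54 & 65.40 & 21.48 & 34.28 & 65.77 \\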

# VAE uncertainty

Average standard deviation of each joint, over 10 random points per image:

| Joint          | x (mm)     | y (mm)     | z (mm)     |
|----------------|------------|------------|------------|
| Left Shoulder  | 0.08918472 | 0.09157255 | 0.0698511  |
| Left Elbow     | 0.11902293 | 0.12438361 | 0.127807   |
| Left Hand      | 0.15035595 | 0.15886429 | 0.16023856 |
| Right Shoulder | 0.07806852 | 0.08045479 | 0.05805141 |
| Right Elbow    | 0.09848355 | 0.10558292 | 0.09553152 |
| Right Hand     | 0.14335191 | 0.16634406 | 0.15219972 |

# 5-model uncertainty

Average standard deviation of each joint across 5 models trained from random initialization:

| Joint          | x (mm)     | y (mm)    | z (mm)    |
|----------------|------------|-----------|-----------|
| Left Shoulder  | 0.77400464 | 0.9956709 | 0.6167278 |
| Left Elbow     | 1.3933133  | 1.9136935 | 1.6917565 |
| Left Hand      | 2.4741955  | 2.9097495 | 3.5901506 |
| Right Shoulder | 0.86049104 | 1.019562  | 0.6316848 |
| Right Elbow    | 1.3152746  | 1.8138525 | 1.3953665 |
| Right Hand     | 2.0660796  | 2.756212  | 2.9830382 |