In [1]:
%matplotlib notebook 
from model.self_pose import SelfPose
from dataloader import XRegoDataset 
from torch.utils.data import DataLoader
import time
import copy
import torch
import numpy as np
import matplotlib.pyplot as plt
import webdataset as wds
from torchvision import transforms
import glob
from webdataset.handlers import ignore_and_continue
from mpl_toolkits.mplot3d import axes3d 
import typing as tp
import numpy as np
import matplotlib.pyplot as plt


In [2]:
preproc = transforms.Compose([
    transforms.Resize((368,368)),
    transforms.ToTensor(),
])

def transform(x):
    x = x.decode("utf-8").split(",")
    x = np.array(x).astype(float).reshape(-1,3)
    x = np.concatenate((x[6:7], x[2:5], x[14:17], x[22:25], x[5:6], x[29:32], x[35:38]))

    return x.reshape(-1)

def image_preproc(x):
    return preproc(x)



## Plot Kinematic Tree

In [3]:


def get_chain_dots(
        joints: np.ndarray,   # shape == (n_dots, 3)
        chain_dots_indexes: tp.List[int], # length == n_dots_in_chain
                                          # in continuous order, i.e. 
                                          # left_hand_ix >>> chest_ix >>> right_hand_ix
        ) -> np.ndarray:    # chain of dots
    return joints[chain_dots_indexes]


def get_chains(
        joints,   # shape == (n_dots, 3)
        head: tp.List[int], # pelvis >>> chest >>> head
        spine: tp.List[int], # left_leg >>> pelvis >>> right_leg
        arm1: tp.List[int],
        arm2,
        neck,
        right_leg,
        left_leg,
        ):
    return (get_chain_dots(joints, head),
            get_chain_dots(joints, spine),
            get_chain_dots(joints, arm1),
            get_chain_dots(joints, arm2),
            get_chain_dots(joints, neck),
            get_chain_dots(joints, right_leg),
            get_chain_dots(joints, left_leg),
            )


def subplot_nodes(dots: np.ndarray, # shape == (n_dots, 3)
                  ax):
    return ax.scatter3D(*dots.T, c=dots[:, -1])


def subplot_bones(chains: tp.Tuple[np.ndarray, ...], ax):
    return [ax.plot(*chain.T) for chain in chains]


def plot_skeletons(skeletons, chains_ixs):
    fig = plt.figure(figsize=(10, 5))
    for i, joints in enumerate(skeletons, start=1):
        chains = get_chains(joints, *chains_ixs)
        ax = fig.add_subplot(1, 2, i, projection='3d')
        subplot_nodes(joints, ax)
        subplot_bones(chains, ax)
    plt.show()
    
def plot(gt_pose, pred_pose):
    # Kinematic Tree for visualization
    # each list create bones between joints
    chains_ixs = ([0],[1, 2, 3],[3, 4, 5, 6],[3, 7, 8, 9],[3, 10],[1, 11,12,13],[1, 14,15,16]) 
    plot_skeletons([gt_pose, pred_pose], chains_ixs)
 
       


In [6]:
def load_models():
    paths = ["./ckpts/57mm.pth", "./ckpts/model2/5_epoch_0.00025778871799314704.pth",
             "./ckpts/model3/6_epoch_2.2495238057805626e-05.pth", "./ckpts/model4/5_epoch_0.0002509588166198228.pth",
             "./ckpts/model5/6_epoch_0.0002385608078126097.pth"]
    models = []
    for path in paths:
        models.append(SelfPose())
        models[-1].load_state_dict(torch.load(path))
        models[-1].cuda()
    return models

In [7]:
train_url = glob.glob("/egopose-data/web-datasets/xr*train*tar")
"/egopose-data/web"
test_url = glob.glob("/egopose-data/web-datasets/xr*test*tar")
print(len(test_url))
ds = wds.WebDataset(train_url).decode("pil").to_tuple("pose_image.png", "pose_image_depth.png", "pose_3dp.csv").map_tuple(image_preproc, image_preproc, transform)
train_loader = DataLoader(ds.batched(1), num_workers=1, batch_size=None)

ds_test = wds.WebDataset(test_url).decode("pil").to_tuple("pose_image.png", "pose_image_depth.png","pose_3dp.csv", handler=ignore_and_continue).map_tuple(image_preproc, image_preproc, transform)
test_loader = DataLoader(ds_test, num_workers=8, batch_size=64)
 
# self_pose = SelfPose()
dataloaders = {"train": train_loader, "val": test_loader}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 
# self_pose.load_state_dict(torch.load("./ckpts/57mm.pth"))
# self_pose.to(device)
models = load_models()
print()

26



In [12]:
def calculate_error(pred, gt):
    pred = pred.reshape((-1, 6, 3))
    gt = gt.reshape((-1, 6, 3))
    error = np.sqrt(np.sum((pred-gt) ** 2, axis=2))
    return np.mean(error)

def inference(models):
    results = []
    for model in models:
        model.eval()   # Set model to evaluate mode
    phase="val"
    errors = []
    count = 0
#     try:
    for inputs, depth, labels in dataloaders["val"]: 
        inputs = inputs.to(device).float()
        depth =  depth.to(device).float()
        count += inputs.shape[0]
        results = infer_multiple(models, inputs)
        labels = labels.reshape(-1, 17, 3).cpu().numpy() # 17 joints for visualizatoin
#         err = calculate_error(outputs, labels[:, 4:10]) # 4-10 are both arm joints
        print( "count: ", count)
        yield results, labels, "err"
#     except Exception as e:
#         print("\n\n\nException as e:", e)
#         return
            

def plot_inference(model, plot_max=5):
    for i, (outputs, labels, error) in enumerate(inference(model)):
        pred = labels.copy()
        pred[:, 4:10] = outputs # replace orignal arm joints with predicted arm joints in kinematic tree
        plot(pred[0], labels[0]) #plot first prediction in batch
        if i >= plot_max:
            break
            
def mean_error(model, total_batches=float("inf")):
    errors = []
    for i, (outputs, gt, error) in enumerate(inference(model)):
        errors.append(error)
        if i >= total_batches:
            break
            
    print(f"Mean Error of all batches: {np.mean(errors)*10} millimeters ")
        
def infer_multiple(models: list, inputs):
    results = []
    with torch.set_grad_enabled(False):
        for model in models:
            outputs, _, _ = model(inputs)
            outputs = outputs.view(-1, 6, 3).cpu().numpy() * 300 # both arms 
            results.append(outputs)
    return results
    
def process_models(models, batches=float('inf')):
    output_list = [[]* len(models)]
    for outputs, gt, _ in (inference(models)):
        for i, output in enumerate(outputs):
            output_list[i].extend(output)
    return output_list
        

# mean_error(self_pose)
output_list = process_models(models, 5)

count:  64


NameError: name 'err' is not defined