In [None]:
# IPython autoreload: re-import edited modules (e.g. the local posingpixels
# package) before each cell executes, so library edits take effect without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [1]:
import os  # noqa
import sys  # noqa

# The project root is the parent of the notebook's working directory; put it
# on sys.path so the project packages import.
proj_root = os.path.dirname(os.getcwd())
sys.path.append(proj_root)

# Which object / recording to process.
OBJ_NAME = "mustard_bottle"
VIDEO_NAME = "mustard0"

# Minimum visibility * confidence for a tracked point to be trusted.
VIS_CONF_THRESHOLD = 0.9

data_root = os.path.join(proj_root, "data")
video_dir = os.path.join(data_root, "inputs", VIDEO_NAME)
obj_dir = os.path.join(data_root, "objects", OBJ_NAME)
# Path prefix for tracker outputs (later cells append suffixes such as "_slide" or ".pkl").
tracker_result_video = video_dir

In [2]:
import matplotlib.pyplot as plt
import torch
from posingpixels.datasets import YCBinEOATDataset, load_video_images
from posingpixels.utils.cotracker import visualize_results
from posingpixels.utils.evaluation import get_gt_tracks
from posingpixels.pnp import GradientPnP


from posingpixels.utils.cotracker import unscale_by_crop

from posingpixels.utils.evaluation import compute_add_metrics

from posingpixels.pointselector import SelectMostConfidentPoint
from posingpixels.utils.evaluation import compute_tapvid_metrics


import mediapy
from posingpixels.utils.geometry import (
    apply_pose_to_points_batch,
    render_points_in_2d_batch,
)
from posingpixels.visualization import overlay_bounding_box_on_video

import random
import numpy as np
from posingpixels.alignment import CanonicalPointSampler
from posingpixels.cotracker import CropCoPoseTracker
from posingpixels.pnp import OpenCVePnP
from posingpixels.cotracker import CoTrackerInput
from posingpixels.pointselector import SelectMostConfidentView
from posingpixels.utils.cotracker import get_ground_truths

# Seed every RNG the notebook touches so runs are repeatable.
random.seed(0)
np.random.seed(0)
torch.manual_seed(42)

# Use the first GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")



In [3]:
from inference import create_reference_database_from_RGB_images
from posingpixels.utils.gs_pose import create_or_load_gaussian_splat_from_ycbineoat, load_model_net


# YCBInEOAT dataset for the selected video/object, with use_cad_rgb and use_cad_mask enabled.
dataset = YCBinEOATDataset(video_dir, obj_dir, use_cad_rgb=True, use_cad_mask=True)
# GS-Pose network weights, loaded from <proj_root>/checkpoints/model_weights.pth.
model_net = load_model_net(os.path.join(proj_root, "checkpoints/model_weights.pth"))

# Create (or, per the helper's name, load a cached) Gaussian-splat reference database for the object.
# NOTE(review): the recorded output of this cell shows 3DGO training at Loss=nan
# (~6830/30000 iterations) before a KeyboardInterrupt. Confirm the resulting
# `ref_database["obj_gaussians"]` is valid before relying on it downstream.
ref_database = create_or_load_gaussian_splat_from_ycbineoat(dataset, model_net, device=device)





FPS indices:  tensor([  0, 502, 103, 564,  91, 582, 112, 548], device='cuda:0')
[100/737], 13-14:54:27
[200/737], 13-14:54:30
[300/737], 13-14:54:32
[400/737], 13-14:54:35
[500/737], 13-14:54:38
[600/737], 13-14:54:41
[700/737], 13-14:54:43
Creating 3D-OGS model for mustard_bottle 
Output folder: /home/joao/Documents/repositories/GSPose/data/objects/mustard_bottle
Reading 737  training image ...
737 training samples
-----------------------------------------
4 testing samples
----------------------------------------
Loading Training Cameras
Loading Test Cameras
Number of points at initialisation :  4350


3DGO modeling progress:  23%|██▎       | 6830/30000 [06:04<03:06, 124.27it/s, Loss=nan]  

KeyboardInterrupt: 

In [None]:
from posingpixels.utils.gs_pose import render_gaussian_model, render_gaussian_model_with_info
from posingpixels.visualization import get_gaussian_splat_pointcloud, plot_pointclouds

# Sanity-check the Gaussian-splat object: plot its point cloud, render it at the
# ground-truth pose of `frame_idx`, and compare the render with the CAD depth
# and the real object mask inside a tight crop around the object.
frame_idx = 0

gaussian_object = ref_database["obj_gaussians"]

object_pointcloud = get_gaussian_splat_pointcloud(gaussian_object)
plot_pointclouds({OBJ_NAME.capitalize(): object_pointcloud}, "Gaussian Object")

initial_pose = dataset.get_gt_pose(frame_idx)
initial_R, initial_T = initial_pose[:3, :3], initial_pose[:3, 3]

render = render_gaussian_model_with_info(
    gaussian_object, dataset.K, dataset.H, dataset.W, R=initial_R, T=initial_T
)

# Crop window = bounding box of the object mask plus a 5 px margin. The start
# is clamped at 0: a negative start index would make the slices below wrap
# around to the far end of the image instead of cropping.
mask = dataset.get_mask(frame_idx) > 0.7
rows, cols = np.nonzero(mask)
if rows.size == 0:
    raise ValueError(f"Object mask for frame {frame_idx} is empty; cannot crop.")
y1 = max(int(rows.min()) - 5, 0)
x1 = max(int(cols.min()) - 5, 0)
y2 = min(int(rows.max()) + 5, mask.shape[0])
x2 = min(int(cols.max()) + 5, mask.shape[1])

image = render["image"][y1:y2, x1:x2]
alpha = render["alpha"][0, y1:y2, x1:x2].detach().cpu().numpy()
depth = render["depth"][0, y1:y2, x1:x2].detach().cpu().numpy()
cad_depth = dataset.get_cad_depth(frame_idx)[y1:y2, x1:x2]
mask = dataset.get_mask(frame_idx)[y1:y2, x1:x2] > 0.7
rgb = dataset.get_rgb(frame_idx)[y1:y2, x1:x2]

# Render with low-opacity pixels blacked out.
image_high_alpha = image.copy()
image_high_alpha[alpha < 0.9] = 0

plt.imshow(image)
plt.title("Gaussian Splat Rendered")
plt.show()

plt.imshow(image_high_alpha)
plt.title("Gaussian Splat Rendered, alpha > 0.9")
plt.show()

plt.imshow(cad_depth)
plt.imshow(rgb, alpha=0.5)
plt.title("CAD Depth")
plt.show()

plt.imshow(cad_depth > 0)
plt.imshow(mask, alpha=0.5)
plt.title("CAD Mask vs Actual Mask")
plt.show()

plt.imshow(image)
plt.imshow(mask, alpha=0.5)
plt.title("Actual Mask vs Rendered Splat (Showcasing Ghost Effect)")
plt.show()

# plt.imshow(alpha > 0.99)
# plt.show()
# plt.imshow(depth)
# plt.show()

# # Overlap alpha with image 50% transparency
# plt.imshow(image)
# plt.imshow(alpha > 0.99, alpha=0.7)
# plt.show()

In [None]:
# RANSAC EPnP solver. NOTE: it is not currently passed to the tracker (the
# `pnp_solver=` argument below is commented out).
pnp_solver = OpenCVePnP(min_inliers=20, ransac_inliner_threshold=2.0)

# Pose tracker driven by a canonical point sampler.
point_sampler = CanonicalPointSampler()
tracker = CropCoPoseTracker(
    canonical_point_sampler=point_sampler,
    # pnp_solver=pnp_solver,
    pose_interpolation_steps=1,
)

In [None]:
dataset.reset_frame_range()
with torch.no_grad():
    (
        pred_tracks,
        pred_visibility,
        pred_confidence,
        pred_tracks_original,
        tracker_input,
    ) = tracker(dataset)
Q = tracker_input.num_query_points
N = len(tracker_input)

# Round-trip the full-video tracking results through a pickle, so later cells
# work from exactly what is stored on disk. The crop boxes, scaling and video
# keep only the frames after the tracker's prepended frames.
import pickle

results_path = os.path.join(video_dir, "tracker_results.pkl")
prepend = tracker_input.prepend_length
with open(results_path, "wb") as out_file:
    pickle.dump(
        dict(
            pred_tracks=pred_tracks,
            pred_visibility=pred_visibility,
            pred_confidence=pred_confidence,
            pred_tracks_original=pred_tracks_original,
            N=N,
            Q=Q,
            bboxes=tracker_input.bboxes[prepend:],
            scaling=tracker_input.scaling[prepend:],
            video=load_video_images(tracker_input.video_dir)[:, prepend:],
        ),
        out_file,
    )

with open(results_path, "rb") as in_file:
    results = pickle.load(in_file)
pred_tracks = results["pred_tracks"]
pred_visibility = results["pred_visibility"]
pred_confidence = results["pred_confidence"]
pred_tracks_original = results["pred_tracks_original"]
N = results["N"]
Q = results["Q"]
bboxes = results["bboxes"]
scaling = results["scaling"]
video = results["video"]
    

In [None]:
def choose_best(
    tracker_input: CoTrackerInput, pred_tracks, pred_visibility, pred_confidence, view=False
):
    """Reduce per-query tracks to one selected track per canonical point.

    Args:
        tracker_input: supplies the query -> canonical point mapping
            (``query_to_point_indexes``), the per-query lengths, the number of
            canonical points, and the crop ``bboxes`` / ``scaling``.
        pred_tracks, pred_visibility, pred_confidence: tracker outputs; only
            batch element 0 is used.
        view: select with ``SelectMostConfidentView`` when True, otherwise
            with ``SelectMostConfidentPoint``.

    Returns:
        ``(best_coords, best_vis, best_conf, best_coords_original, best_indices)``.
        The first four carry a new leading axis of size 1. ``best_coords_original``
        is the selection mapped back through the crop boxes/scaling with
        ``unscale_by_crop``. ``best_indices`` is returned as the selector produced it.
    """
    selector_cls = SelectMostConfidentView if view else SelectMostConfidentPoint
    selector = selector_cls(
        tracker_input.num_canonical_points,
        torch.tensor(tracker_input.query_to_point_indexes, device=device),
        torch.tensor(tracker_input.query_lengths, device=device),
    )

    coords, vis, conf, best_indices = selector.query_to_point(
        pred_tracks[0], pred_visibility[0], pred_confidence[0]
    )

    # Undo the per-frame crop so the coordinates live in the full image.
    coords_original = unscale_by_crop(
        coords,
        torch.tensor(tracker_input.bboxes).to(device),
        torch.tensor(tracker_input.scaling).to(device),
    )

    return (
        coords.unsqueeze(0),
        vis.unsqueeze(0),
        conf.unsqueeze(0),
        coords_original.unsqueeze(0),
        best_indices,
    )


def estimate_poses(
    tracker_input: CoTrackerInput, best_coords_original, best_vis, best_conf
):
    """Estimate one 4x4 object pose per frame with OpenCV EPnP.

    Args:
        tracker_input: provides the camera intrinsics (``dataset.K``), the 3D
            ``canonical_points`` and the frame count ``len(tracker_input)``.
        best_coords_original: selected 2D tracks in original-image pixels,
            batch-first with a batch axis of size 1 (as returned by ``choose_best``).
        best_vis, best_conf: matching visibility / confidence, also batch-first.

    Returns:
        ``(poses, err)``: ``poses`` is an ``(N, 4, 4)`` tensor on ``device``;
        ``err`` is the error value returned by the solver.

    Correspondences whose visibility * confidence is below
    ``VIS_CONF_THRESHOLD`` get zero weight.
    """
    N = len(tracker_input)
    K = tracker_input.dataset.K
    x = (
        torch.tensor(tracker_input.canonical_points, dtype=torch.float32)
        .to(device)
        .unsqueeze(0)
        .repeat(N, 1, 1)
    )
    y = best_coords_original.detach().clone().squeeze(0)[:N]

    # Remove the batch axis *before* truncating, so `[:N]` trims frames (matching
    # `y`) instead of the size-1 batch axis.
    scores = (best_vis * best_conf).float().squeeze(0)[:N]
    weights = torch.where(scores < VIS_CONF_THRESHOLD, torch.zeros_like(scores), scores)

    camKs = torch.tensor(K[np.newaxis, :], device=device).float()

    epnp_cv_solver = OpenCVePnP(
        X=x[0],
        K=camKs,
    )
    epnp_cv_R, epnp_cv_T, err = epnp_cv_solver(
        y, X=x, K=torch.tensor(K).to(device).float(), weights=weights
    )

    # Assemble homogeneous [R | T] matrices.
    epnp_cv_poses = torch.eye(4).to(device).unsqueeze(0).repeat(N, 1, 1)
    epnp_cv_poses[:, :3, :3] = epnp_cv_R
    epnp_cv_poses[:, :3, 3] = epnp_cv_T
    return epnp_cv_poses, err

# def improve_poses(
#     x, y, K, poses, weights
# ) -> torch.Tensor:
#     # Initialize the optimizer to the current pose
    
#     # Try every pose against every frame, take the best (using reprojection error w/ Huber loss)
    


# Sliding-window tracking. The dataset is restricted to consecutive windows of
# `step` frames (advancing by `step - overlap`). The tracker runs on each window
# starting from `start_pose`: the GT pose of frame 0 for the first window, then
# the last EPnP pose of the previous window. Per-window results, with the
# tracker's prepended frames stripped, are accumulated in the *_batch lists.
with torch.no_grad():
    (
        pred_tracks_batch,
        pred_confidence_batch,
        pred_visibility_batch,
        pred_poses_batch,
        pred_poses_err,
        best_indices_batch
    ) = [], [], [], [], [], []
    dataset.reset_frame_range()
    start_pose = dataset.get_gt_pose(0)
    step = 32
    overlap = 0  # frames shared between consecutive windows; must stay < step
    tracks = vis = conf = track_input = best_coords = best_conf = best_vis = None
    for i in range(0, dataset.max_frames, step - overlap):
        # Restrict the dataset to the current window.
        dataset.start_frame = i
        dataset.end_frame = min(i + step, dataset.max_frames)
        rgb = dataset.get_rgb(0)
        # plt.imshow(rgb)
        # plt.show()
        print(f"Processing frames {dataset.start_frame} to {dataset.end_frame}")
        # TODO: Force pose to always be in view, and if it is too far gone, do not include it
        # start_pose[:3, 3] = np.array([0, 0, dataset._get_safe_distance()]) # TODO: Doesn't give right perspective, but ensures it's always in view
        # NOTE(review): rgb/depth from this render are not used further inside the loop.
        rgb, depth, _ = dataset.render_mesh_at_pose(start_pose)
        # TODO: It's good to initialize every point (maybe with a confidence penalty for the ones in the dynamic template)
        # Initialize dynamic ones from the ones in the best_coords for that point (same with conf and vis)
        if tracks is not None and overlap > 0:
            # Seed this window's overlapping frames with the previous window's
            # tracks. Visibility/confidence are converted to logits clamped to
            # +/- tracker.init_value.
            assert (
                vis is not None
                and conf is not None
                and track_input is not None
                and best_coords is not None
                and best_conf is not None
                and best_vis is not None
            )
            # With more than 4 query groups, drop the last group's queries. Otherwise
            # `:-(-tracks.shape[2])` == `:tracks.shape[2]` keeps every query.
            # NOTE(review): the meaning of the "4" cut-off is not visible here — confirm.
            last_specific_length = track_input.query_lengths[-1] if len(track_input.query_lengths) > 4 else -tracks.shape[2]
            forced_coords = tracks[:, -overlap:, :-last_specific_length]
            forced_vis = vis[:, -overlap:, :-last_specific_length]
            forced_vis = torch.logit(forced_vis).clamp(-tracker.init_value, tracker.init_value)
            forced_conf = conf[:, -overlap:, :-last_specific_length]
            forced_conf = torch.logit(forced_conf).clamp(-tracker.init_value, tracker.init_value)
        else:
            forced_coords = forced_vis = forced_conf = None
        print(start_pose)
        tracks, vis, conf, tracks_original, track_input = tracker(
            dataset,
            start_pose=start_pose,
            query_poses=start_pose[np.newaxis],
            forced_coords=forced_coords,
            forced_vis=forced_vis,
            forced_conf=forced_conf,
        )

        Q = track_input.num_query_points
        N = len(track_input)

        # View-level selection. Its indices are what gets stored for stitching.
        best_coords, best_vis, best_conf, best_coords_original, best_indices = (
            choose_best(track_input, tracks, vis, conf, view=True)
        )
        best_indices_batch.append(best_indices[track_input.prepend_length :])
        

        # Per-frame EPnP poses; drop the prepended frames, and chain the last
        # pose into the next window's start pose.
        poses, err = estimate_poses(track_input, best_coords_original, best_vis, best_conf)
        poses = poses[
            track_input.prepend_length :
        ]
        if err is not None:
            err = err[
                track_input.prepend_length :
            ]
        start_pose = poses[-1].detach().cpu().numpy()
        
        video = load_video_images(track_input.video_dir, limit=N)
        visualize_results(
            video,
            tracks,
            vis,
            conf,
            tracker_result_video + f"_{i}",
            num_of_main_queries=track_input.num_query_points,
        )
        
        visualize_results(
            video,
            best_coords,
            best_vis,
            best_conf,
            tracker_result_video + f"_{i}_best",
            num_of_main_queries=track_input.num_canonical_points,
        )
        
        # NOTE(review): these point-level (view=False) results are only checked
        # for None in the next iteration's assert; the forced_* seeds use
        # tracks/vis/conf instead.
        best_coords, best_vis, best_conf, best_coords_original, best_indices = (
            choose_best(track_input, tracks, vis, conf, view=False)
        )


        # Strip the tracker's prepended frames before storing.
        print(tracks.shape, vis.shape, conf.shape)
        tracks = tracks[:, track_input.prepend_length :]
        vis = vis[:, track_input.prepend_length :]
        conf = conf[:, track_input.prepend_length :]
        print(tracks.shape, vis.shape, conf.shape)

        pred_tracks_batch.append(tracks.cpu().numpy())
        pred_visibility_batch.append(vis.cpu().numpy())
        pred_confidence_batch.append(conf.cpu().numpy())
        pred_poses_batch.append(poses.cpu().numpy())
        pred_poses_err.append(err)


In [None]:
# Persist the sliding-window results, plus what is needed to stitch them back
# together, then reload them so the following cells run from the saved copy.
import pickle

sliding_results_path = tracker_result_video + ".pkl"
with open(sliding_results_path, "wb") as out_file:
    pickle.dump(
        (
            pred_tracks_batch,
            pred_visibility_batch,
            pred_confidence_batch,
            pred_poses_batch,
            pred_poses_err,
            best_indices_batch,
            track_input.canonical_points,
            step,
            overlap,
        ),
        out_file,
    )

with open(sliding_results_path, "rb") as in_file:
    saved_sliding_results = pickle.load(in_file)
(
    pred_tracks_batch,
    pred_visibility_batch,
    pred_confidence_batch,
    pred_poses_batch,
    pred_poses_err,
    best_indices_batch,
    canonical_points,
    step,
    overlap,
) = saved_sliding_results

In [None]:
dataset.reset_frame_range()
N = len(dataset)
# Keep the last N frames of each clip. The tracker video also holds the
# tracker's prepended frames, so this lines both clips up with the dataset frames.
last_n_frames = slice(-N, None)
video = load_video_images(tracker_input.video_dir)[:, last_n_frames]
video_original = load_video_images(dataset.video_rgb_dir)[:, last_n_frames]
# full_video = torch.cat([init_video, video_original], dim=1)[:, :N]

In [None]:
print(pred_poses_batch)
print(len(pred_poses_batch))
print(pred_poses_batch[0].shape)

# Stitch the window poses together. Every window after the first repeats
# `overlap` frames of its predecessor, so those leading frames are skipped.
pred_poses = [
    torch.tensor(window_poses[(overlap if window_idx > 0 else 0):]).float().to(device)
    for window_idx, window_poses in enumerate(pred_poses_batch)
]
epnp_slide_poses = torch.cat(pred_poses, dim=0).detach().cpu().numpy()
print(epnp_slide_poses.shape)

K = dataset.K
camKs = torch.tensor(K[np.newaxis, :], device=device).float()
gt_poses = torch.tensor(dataset.get_gt_poses()).float().to(device)
print(gt_poses.shape)

# Ground-truth boxes first (default colour), then the stitched EPnP boxes in red.
video_permuted = video_original[0].permute(0, 2, 3, 1)
frame_count = len(dataset)
bbox_video = overlay_bounding_box_on_video(
    video_permuted.detach().cpu().numpy(),
    dataset.bbox.float(),
    camKs.repeat(frame_count, 1, 1).cpu(),
    gt_poses.detach().cpu().numpy(),
)
bbox_video = overlay_bounding_box_on_video(
    bbox_video,
    dataset.bbox.float(),
    camKs.repeat(frame_count, 1, 1).cpu(),
    epnp_slide_poses,
    color=(255, 0, 0),
)
mediapy.show_video(bbox_video, fps=15)

In [None]:
# Stitch the per-window sliding-window results back into full-length tracks
# (one per canonical point), then refine the stitched EPnP poses with GradientPnP.
N = len(dataset)
K = dataset.K
x = (
    torch.tensor(canonical_points, dtype=torch.float32)
    .to(device)
    .unsqueeze(0)
    .repeat(N, 1, 1)
)
num_canonical_points = x.shape[1]

best_tracks_slide = torch.zeros(1, N, num_canonical_points, 2).to(device)
best_vis_slide = torch.zeros(1, N, num_canonical_points).to(device)
best_conf_slide = torch.zeros(1, N, num_canonical_points).to(device)
for i, (coords, vis, conf, indices) in enumerate(
    zip(pred_tracks_batch, pred_visibility_batch, pred_confidence_batch, best_indices_batch)
):
    coords = torch.tensor(coords).float().to(device)
    vis = torch.tensor(vis).float().to(device)
    conf = torch.tensor(conf).float().to(device)
    indices = torch.tensor(indices).int().to(device)
    # Negative selected indices mark entries with no usable track. Their
    # visibility/confidence are zeroed after the gather below.
    invalid_indices_frames, invalid_indices_idx = torch.where(indices < 0)
    left = i * (step - overlap)
    right = min(left + step, dataset.max_frames)
    frames_window = torch.arange(left, right, device=device).unsqueeze(-1)
    if i > 0:
        # Preserve the previous window's values for the frames shared with it.
        # `.clone()` is required: basic slicing returns a *view*, which the
        # writes below would overwrite, silently turning the restore into a no-op.
        right_overlap = min(left + overlap, dataset.max_frames)
        best_tracks_slide_overlap = best_tracks_slide[:, left:right_overlap].clone()
        best_vis_slide_overlap = best_vis_slide[:, left:right_overlap].clone()
        best_conf_slide_overlap = best_conf_slide[:, left:right_overlap].clone()
    # For every frame of the window, gather the query selected for each canonical point.
    # Assumes `indices` is (window frames, canonical points) — TODO confirm.
    best_tracks_slide[:, frames_window[:, 0]] = coords[:, frames_window - left, indices]
    best_vis_slide[:, frames_window[:, 0]] = vis[:, frames_window - left, indices]
    best_conf_slide[:, frames_window[:, 0]] = conf[:, frames_window - left, indices]
    best_conf_slide[:, invalid_indices_frames + left, invalid_indices_idx] = 0
    best_vis_slide[:, invalid_indices_frames + left, invalid_indices_idx] = 0
    if i > 0:
        best_tracks_slide[:, left:right_overlap] = best_tracks_slide_overlap
        best_vis_slide[:, left:right_overlap] = best_vis_slide_overlap
        best_conf_slide[:, left:right_overlap] = best_conf_slide_overlap

visualize_results(
    video,
    best_tracks_slide,
    best_vis_slide,
    best_conf_slide,
    tracker_result_video + "_slide",
    num_of_main_queries=num_canonical_points,
)
print(num_canonical_points, tracker_input.num_canonical_points)


# NOTE(review): `bboxes` / `scaling` come from the full-video run pickled in
# tracker_results.pkl, not from the per-window track_input objects. Confirm the
# crops are identical between runs.
y_slide = unscale_by_crop(
    best_tracks_slide[0],
    torch.tensor(bboxes).to(device),
    torch.tensor(scaling).to(device),
)

init_poses = torch.tensor(epnp_slide_poses, device=device).float()[:N]

# Zero the weight of correspondences below the visibility * confidence threshold.
weights_slide = (best_vis_slide * best_conf_slide).float()
weights_slide[best_vis_slide * best_conf_slide < VIS_CONF_THRESHOLD] = 0
weights_slide = weights_slide.squeeze(0)

print(x.shape, camKs.shape, init_poses.shape, y_slide.shape, weights_slide.shape)
gradient_pnp = GradientPnP(
    max_lr=0.02,
    temporal_consistency_weight=1,
    X=x[0],
    K=camKs,
)

# GradientPnP initialised from the stitched EPnP poses.
rotations, translations, all_results = gradient_pnp(
    y_slide,
    weights=weights_slide,
    R=init_poses[:, :3, :3],
    T=init_poses[:, :3, 3],
)

gradient_poses_slide = torch.eye(4).to(device).unsqueeze(0).repeat(N, 1, 1)
gradient_poses_slide[:, :3, :3] = rotations
gradient_poses_slide[:, :3, 3] = translations

In [None]:
from posingpixels.pointselector import SelectMostConfidentView


# Reduce the full-video tracker output to one track per canonical point by
# selecting the most confident view. This reuses `choose_best`, which performs
# exactly this selection (SelectMostConfidentView + unscale_by_crop), instead
# of duplicating its logic here.
best_coords, best_vis, best_conf, best_coords_original, best_indices = choose_best(
    tracker_input, pred_tracks, pred_visibility, pred_confidence, view=True
)

In [None]:
prepend = tracker_input.prepend_length

# All tracked queries, with the prepended frames removed.
visualize_results(
    video,
    pred_tracks[:, prepend:],
    pred_visibility[:, prepend:],
    pred_confidence[:, prepend:],
    tracker_result_video,
    num_of_main_queries=Q,
)

# The single selected track per canonical point.
visualize_results(
    video,
    best_coords[:, prepend:],
    best_vis[:, prepend:],
    best_conf[:, prepend:],
    tracker_result_video,
    filename="selected_video",
)

# Ground-truth tracks for comparison, with confidence fixed at 1.
gt_tracks, gt_visibility = get_gt_tracks(tracker_input)
gt_tracks_tensor = torch.tensor(gt_tracks).to(device).unsqueeze(0).float()
gt_visibility_tensor = torch.tensor(gt_visibility).to(device)
visualize_results(
    video,
    gt_tracks_tensor[:, prepend:],
    gt_visibility_tensor.unsqueeze(0).float()[:, prepend:],
    torch.ones_like(gt_visibility_tensor).unsqueeze(0).float()[:, prepend:],
    tracker_result_video,
    num_of_main_queries=Q,
    filename="gt_video",
    threshold=VIS_CONF_THRESHOLD,
)
# pred_tracks_original = unscale_by_crop(
#     pred_tracks[0],
#     torch.tensor(tracker_input.bboxes).to(device),
#     torch.tensor(tracker_input.scaling).to(device),
# ).unsqueeze(0)

# full_video = torch.cat([init_video, video_original], dim=1)[:, : N]
# visualize_results(
#     full_video,
#     pred_tracks_original,
#     pred_visibility,
#     pred_confidence,
#     tracker_result_video,
#     num_of_main_queries=Q,
#     filename="original",
#     threshold=VIS_CONF_THRESHOLD,
# )

# gt_tracks_original, gt_visibility_original = get_gt_tracks(tracker_input, crop=False)
# visualize_results(
#     full_video,
#     torch.tensor(gt_tracks_original).to(device).unsqueeze(0).float(),
#     torch.tensor(gt_visibility_original).to(device).unsqueeze(0).float(),
#     torch.ones_like(torch.tensor(gt_visibility_original).to(device))
#     .unsqueeze(0)
#     .float(),
#     tracker_result_video,
#     num_of_main_queries=Q,
#     filename="gt_original",
#     threshold=VIS_CONF_THRESHOLD,
# )

In [None]:
# TAP-Vid metrics for the selected per-canonical-point tracks against the
# ground truth, restricted to confident predictions.
query_points = tracker_input.input_query[
    np.newaxis, : tracker_input.num_canonical_points
]
print(query_points.shape)
gt_occluded = (gt_visibility.T < 0.5)[np.newaxis, :]
print(gt_occluded.shape)
gt_tracks_ = np.transpose(gt_tracks[np.newaxis, :], (0, 2, 1, 3))
print(gt_tracks_.shape)
pred_visibility_ = best_vis.permute(0, 2, 1).cpu().numpy()
pred_occluded = pred_visibility_ < 0.5
print(pred_occluded.shape)
pred_tracks_ = best_coords.permute(0, 2, 1, 3).cpu().numpy()
print(pred_tracks_.shape)
pred_confidence_ = best_conf.cpu().permute(0, 2, 1).numpy()
print(pred_confidence_.shape)

# Evaluate only points where confidence * visibility > VIS_CONF_THRESHOLD
# (arrays are batch x points x time after the permutes above).
threshold = VIS_CONF_THRESHOLD
evaluation_points = pred_confidence_ * pred_visibility_ > threshold

metrics = compute_tapvid_metrics(
    query_points=query_points,
    gt_occluded=gt_occluded,
    gt_tracks=gt_tracks_,
    pred_occluded=pred_occluded,
    pred_tracks=pred_tracks_,
    query_mode="first",
    evaluation_points=evaluation_points,
)

# Headline metrics, then position accuracy / Jaccard at each pixel threshold.
for key in ("occlusion_accuracy", "average_jaccard", "average_pts_within_thresh"):
    print(key, metrics[key])
for px in (1, 2, 4, 8, 16):
    for key in (f"pts_within_{px}", f"jaccard_{px}"):
        print(key, metrics[key])
print(metrics.keys())

In [None]:
# How many points pass the visibility * confidence filter at each time step?
# `evaluation_points` is batch x points x time, so sum out the first two axes.
plt.plot(evaluation_points.sum(axis=(0, 1)))
plt.title(
    f"Number of points considered per time step with visibility * confidence > {threshold}"
)
plt.xlabel("Time step")
plt.ylabel("Number of points");

In [None]:
metric_name = "per_point_pts_within_8"
values = metrics[metric_name][0]
is_time_metric = "time" in metric_name

if is_time_metric:
    # Per-frame metric: plot it over time and mark where the tracker's
    # prepended frames end.
    plt.plot(values)
    plt.axvline(tracker_input.prepend_length, color="r", linestyle="--")
    plt.xlabel("Frame")
    plt.ylabel("Value")
else:
    # Per-point metric: a histogram of its distribution across points. The
    # x-axis holds metric values and the y-axis counts points.
    plt.hist(values, bins=20)
    plt.xlabel("Value")
    plt.ylabel("Number of points")
plt.title(f"{metric_name} (threshold={threshold})");

In [None]:
# ==========
# Input
# ==========
import time

import tqdm
from posingpixels.pnp import OpenCVePnP
from posingpixels.query_refiner import QueryRefiner
tracker_input.dataset.reset_frame_range()

# Pose estimation on the full-video tracks: RANSAC EPnP first, then GradientPnP
# initialised from the EPnP poses.
# NOTE(review): uses the global `N` set by an earlier cell (len(dataset)).
prepend = tracker_input.prepend_length
eval_frames = slice(prepend, prepend + N)

K = tracker_input.dataset.K
x = (
    torch.tensor(tracker_input.canonical_points, dtype=torch.float32)
    .to(device)
    .unsqueeze(0)
    .repeat(N, 1, 1)
)
gt_poses = torch.tensor(tracker_input.gt_poses[eval_frames]).float().to(device)
gt_posed_x = apply_pose_to_points_batch(x, gt_poses[:, :3, :3], gt_poses[:, :3, 3])
y_gt = render_points_in_2d_batch(gt_posed_x, torch.tensor(K[:3, :3]).float().to(device))

y = best_coords_original.detach().clone().squeeze(0)[eval_frames]

# Zero the weight of correspondences below the visibility * confidence threshold.
weights = (best_vis * best_conf).float()
weights[best_vis * best_conf < VIS_CONF_THRESHOLD] = 0
weights = weights.squeeze(0)[eval_frames]

camKs = torch.tensor(K[np.newaxis, :], device=device).float()

# ---------- OpenCV EPnP (RANSAC) ----------
# perf_counter is monotonic and high-resolution, unlike the wall-clock time.time().
start_time = time.perf_counter()
epnp_cv_solver = OpenCVePnP(
    X=x[0],
    K=camKs,
    ransac_iterations=5000,
    ransac_inliner_threshold=2.0,
)
epnp_cv_R, epnp_cv_T, _ = epnp_cv_solver(
    y,
    K=torch.tensor(K).to(device).float(),
    weights=weights,
)
epnp_cv_poses = torch.eye(4).to(device).unsqueeze(0).repeat(N, 1, 1)
epnp_cv_poses[:, :3, :3] = epnp_cv_R
epnp_cv_poses[:, :3, 3] = epnp_cv_T
end_time = time.perf_counter()
print(f"Time to run OpenCV ePnP: {end_time - start_time}")

# ---------- GradientPnP refinement, initialised from EPnP ----------
gradient_pnp = GradientPnP(
    max_lr=0.02,
    temporal_consistency_weight=1,
    X=x[0],
    K=camKs,
)

rotations, translations, all_results = gradient_pnp(
    y,
    weights=weights,
    R=epnp_cv_poses[:, :3, :3].clone(),
    T=epnp_cv_poses[:, :3, 3].clone(),
)

gradient_poses = torch.eye(4).to(device).unsqueeze(0).repeat(N, 1, 1)
gradient_poses[:, :3, :3] = rotations
gradient_poses[:, :3, 3] = translations

my_predicted_poses = gradient_poses

In [None]:
# query_refiner_poses = torch.eye(4).to(device).unsqueeze(0).repeat(N, 1, 1)
# query_refiner_poses[:, :3, :3] = query_refiner.R
# query_refiner_poses[:, :3, 3] = query_refiner.T

# Visual check of the GradientPnP poses over the first N frames: ground-truth
# boxes first (default colour), then the predicted boxes in red.
my_predicted_poses = gradient_poses

video_permuted = video_original[0].permute(0, 2, 3, 1)
bbox_video = video_permuted[:N].detach().cpu().numpy()
for overlay_poses, overlay_kwargs in (
    (gt_poses, {}),
    (my_predicted_poses, {"color": (255, 0, 0)}),
):
    bbox_video = overlay_bounding_box_on_video(
        bbox_video,
        dataset.bbox.float(),
        camKs.repeat(N, 1, 1).cpu(),
        overlay_poses.detach().cpu().numpy(),
        **overlay_kwargs,
    )
mediapy.show_video(bbox_video[:N], fps=15)

In [None]:
# Visual check of the RANSAC EPnP poses over the first N frames: ground-truth
# boxes first (default colour), then the predicted boxes in red.
my_predicted_poses = epnp_cv_poses

video_permuted = video_original[0].permute(0, 2, 3, 1)
bbox_video = video_permuted[:N].detach().cpu().numpy()
for overlay_poses, overlay_kwargs in (
    (gt_poses, {}),
    (my_predicted_poses, {"color": (255, 0, 0)}),
):
    bbox_video = overlay_bounding_box_on_video(
        bbox_video,
        dataset.bbox.float(),
        camKs.repeat(N, 1, 1).cpu(),
        overlay_poses.detach().cpu().numpy(),
        **overlay_kwargs,
    )
mediapy.show_video(bbox_video[:N], fps=15)

In [None]:
def compute_and_plot_add_metrics(
    model_3D_pts,
    diameter,
    predicted_poses: np.ndarray,
    gt_poses: np.ndarray,
    percentage=0.1,
    vert_lines=None,
):
    """Compute the per-frame ADD error, print a summary and plot it over time.

    Args:
        model_3D_pts: object model points, forwarded to ``compute_add_metrics``.
        diameter: object diameter. The pass threshold is ``diameter * percentage``.
        predicted_poses: predicted poses indexed by frame along axis 0.
        gt_poses: ground-truth poses indexed like ``predicted_poses``; it needs
            at least ``predicted_poses.shape[0]`` entries.
        percentage: fraction of the diameter used as the ADD threshold.
        vert_lines: optional iterable of frame indices to mark with dashed
            green vertical lines (e.g. window boundaries). ``None`` draws none.

    Prints the fraction (0-1) of frames whose ADD error is below the threshold
    and the mean ADD error, then shows the plot with the threshold as a dashed
    red line. Assumes ``compute_add_metrics(..., return_error=True)`` returns a
    scalar error per frame — TODO confirm.
    """
    # `None` instead of a mutable `[]` default avoids sharing one list across calls.
    if vert_lines is None:
        vert_lines = ()

    add_metrics = [
        compute_add_metrics(
            model_3D_pts,
            diameter,
            predicted_poses[i],
            gt_poses[i],
            percentage=percentage,
            return_error=True,
        )
        for i in range(predicted_poses.shape[0])
    ]
    threshold = diameter * percentage
    print(
        f"Percentage of ADD error less than {threshold}: {np.mean(np.array(add_metrics) < threshold, axis=0)}"
    )
    print(f"Mean ADD error: {np.mean(add_metrics)}")
    plt.plot(add_metrics)
    plt.axhline(threshold, color="r", linestyle="--")
    for vert_line in vert_lines:
        plt.axvline(vert_line, color="g", linestyle="--")
    plt.title("ADD Error over time")
    plt.xlabel("Frame")
    plt.ylabel("ADD Error")
    plt.show()


# ADD error over time for four pose sources:
# - "sliding window": EPnP poses stitched across windows, and their GradientPnP refinement;
# - "full video": EPnP / GradientPnP on the single full-video tracker pass.
# The labels distinguish the two pipelines, which otherwise print identical headings.
mesh_vertices = np.array(dataset.get_mesh().vertices)
gt_poses_np = gt_poses.detach().cpu().numpy()
# First frame of each sliding window after its overlap with the previous one.
window_starts = [
    i * (step - overlap) + overlap * int(i > 0) for i in range(N // (step - overlap))
]

for label, predicted, extra_kwargs in (
    ("RANSAC CV ePnP (sliding window)", epnp_slide_poses, {"vert_lines": window_starts}),
    ("Adam Optimizer (sliding window)", gradient_poses_slide.detach().cpu().numpy(), {}),
    ("RANSAC CV ePnP (full video)", epnp_cv_poses.detach().cpu().numpy(), {}),
    ("Adam Optimizer (full video)", gradient_poses.detach().cpu().numpy(), {}),
):
    print(label)
    compute_and_plot_add_metrics(
        mesh_vertices,
        dataset.obj_diameter,
        predicted,
        gt_poses_np,
        percentage=0.1,
        **extra_kwargs,
    )
# print("QueryRefiner Optimizer")
# compute_and_plot_add_metrics(
#     np.array(dataset.get_mesh().vertices),
#     dataset.obj_diameter,
#     query_refiner_poses.detach().cpu().numpy()[tracker_input.prepend_length :],
#     gt_poses.detach().cpu().numpy()[tracker_input.prepend_length :],
#     percentage=0.1,
# )

In [None]:
import cv2 as cv
# SIFT keypoints on the mesh rendered at the canonical pose, drawn over the render.
sift = cv.SIFT_create()
canonical_pose = dataset.get_canonical_pose()
rgb, _, _ = dataset.render_mesh_at_pose(canonical_pose)
kp = sift.detect(rgb, None)

img = cv.drawKeypoints(rgb, kp, None)
plt.imshow(img)

# One (x, y) pixel position per detected keypoint.
kp_np = np.array([keypoint.pt for keypoint in kp])
print(kp)
print(kp_np.shape)

In [None]:
# Distribution of the tracker's confidence over every frame and query of the
# full-video run (batch element 0).
plt.hist(pred_confidence[0].cpu().numpy().ravel(), bins=20)
plt.title("Tracker confidence distribution")
plt.xlabel("Confidence")
plt.ylabel("Count");