In [1]:
import os
gpu_id = 0
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

import cv2
import sys
import json
import time
import mmcv
import glob
import torch
import shutil
import random
import pickle
import hashlib
import numpy as np
import torch.nn as nn
from torch import optim
import mediapy as media
import matplotlib.pyplot as plt
from argparse import ArgumentParser
import torch.nn.functional as torch_F
from transforms3d import affines, quaternions
from torchvision.ops import roi_align
from pytorch3d import io as py3d_io
from pytorch3d import ops as py3d_ops
from pytorch3d import loss as py3d_loss
from pytorch3d import utils as py3d_util
from pytorch3d import structures as py3d_struct
from pytorch3d import renderer as py3d_renderer
from pytorch3d import transforms as py3d_transform
from pytorch3d.vis import plotly_vis as py3d_vis
from pytorch3d.transforms import (matrix_to_euler_angles,
                                  euler_angles_to_matrix, 
                                  matrix_to_rotation_6d, 
                                  rotation_6d_to_matrix)
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM

PROJ_ROOT = os.path.dirname(os.getcwd())
sys.path.append(PROJ_ROOT)

L1Loss = torch.nn.L1Loss(reduction='mean')
MSELoss = torch.nn.MSELoss(reduction='mean')
SSIM_METRIC = SSIM(data_range=1, size_average=True, channel=3) # channel=1 for grayscale images
MS_SSIM_METRIC = MS_SSIM(data_range=1, size_average=True, channel=3)


from inference import *
from misc_utils import gs_utils
from misc_utils.metric_utils import *
from config import inference_cfg as CFG
from model.network import model_arch as ModelNet
from dataset.demo_dataset import OnePoseCap_Dataset


ckpt_file = os.path.join(PROJ_ROOT, 'checkpoints/model_weights.pth')
device = torch.device('cuda:0')
model_net = ModelNet().to(device)
model_net.load_state_dict(torch.load(ckpt_file, map_location=device))
model_net.eval()
print('Model weights are loaded!')




Pretrained weights are loaded from  model_weights.pth
Model weights are loaded!


# 1. Capture a new object

## The demo data is captured using the OnePoseCap APP  and organized as below
This APP can be downloaded from Apple APP store by searching "OnePose Cap", please see [OnePose++](https://github.com/zju3dv/OnePose_Plus_Plus/blob/main/doc/demo.md) for details
```
--- /PROJ_ROOT/demo_data
|       |--- obj_name
|       |       |---obj_name-annotate
|       |       |---obj_name-test
```
You can download the [demo_cam](https://drive.google.com/file/d/18tIqbVaK2G9rOWSz-SPsP0wMBzvX5XfX/view?usp=sharing) object provided by OnePose++.

In [2]:
obj_name = 'demo_buzz'
test_suffix = '-from-far'

demo_data_dir = os.path.abspath(os.path.join(PROJ_ROOT, 'demo_data', obj_name))
refer_seq_dir = os.path.join(demo_data_dir, f'{obj_name}-annotate')
query_seq_dir = os.path.join(demo_data_dir, f'{obj_name}-test{test_suffix}')


# 2. Create reference database using object reference video

In [3]:
obj_refer_dataset = None
obj_database_dir = os.path.join(demo_data_dir, f'{obj_name}-database')
obj_database_path = os.path.join(obj_database_dir, 'reference_database.pkl')

if not os.path.exists(obj_database_path):
    print(f'Generate object reference database for {obj_name} ...')
    obj_refer_dataset = OnePoseCap_Dataset(obj_data_dir=refer_seq_dir, 
                                           obj_database_dir=obj_database_dir,
                                           use_binarized_mask=CFG.BINARIZE_MASK)

    reference_database = create_reference_database_from_RGB_images(model_net, 
                                                                   obj_refer_dataset, 
                                                                   save_pred_mask=True, device=device)
    
    obj_bbox3D = torch.as_tensor(obj_refer_dataset.obj_bbox3d, dtype=torch.float32)
    bbox3d_diameter = torch.as_tensor(obj_refer_dataset.bbox3d_diameter, dtype=torch.float32)
    reference_database['obj_bbox3D'] = obj_bbox3D
    reference_database['bbox3d_diameter'] = bbox3d_diameter


    parser = ArgumentParser(description="Training script parameters")
    ###### arguments for 3D-Gaussian Splatting Refiner ########
    gaussian_ModelP = ModelParams(parser)
    gaussian_PipeP  = PipelineParams(parser)
    gaussian_OptimP = OptimizationParams(parser)
    gaussian_BG = torch.zeros((3), device=device)

    if 'ipykernel_launcher.py' in sys.argv[0]:
        args = parser.parse_args(sys.argv[3:]) # if run in ipython notebook
    else:
        args = parser.parse_args() # if run in terminal
    
    print(f'Creating 3D-OGS model for {obj_name} ')
    gs_pipeData  = gaussian_PipeP.extract(args)
    gs_modelData = gaussian_ModelP.extract(args)
    gs_optimData = gaussian_OptimP.extract(args)
    
    gs_modelData.model_path = obj_database_dir
    gs_modelData.referloader = obj_refer_dataset
    gs_modelData.queryloader = obj_refer_dataset
    
    obj_gaussians = create_3D_Gaussian_object(gs_modelData, gs_optimData, gs_pipeData, return_gaussian=True)
    
    reference_database['obj_gaussians_path'] = f'{obj_database_dir}/3DGO_model.ply'
    
    for _key, _val in reference_database.items():
        if isinstance(_val, torch.Tensor):
            reference_database[_key] = _val.detach().cpu().numpy()
    with open(obj_database_path, 'wb') as df:
        pickle.dump(reference_database, df)
    print('save database to ', obj_database_path)
    
    
print('Load database from ', obj_database_path)
with open(obj_database_path, 'rb') as df:
    reference_database = pickle.load(df)

for _key, _val in reference_database.items():
    if isinstance(_val, np.ndarray):
        reference_database[_key] = torch.as_tensor(_val, dtype=torch.float32).to(device)

gs_ply_path = reference_database['obj_gaussians_path']
obj_gaussians = GaussianModel(sh_degree=3)
obj_gaussians.load_ply(gs_ply_path)
print('load 3D-OGS model from ', gs_ply_path)
reference_database['obj_gaussians'] = obj_gaussians
cannon_3D_bbox = reference_database['obj_bbox3D'].cpu()

Load database from  /home/joao/Documents/repositories/GSPose/demo_data/demo_buzz/demo_buzz-database/reference_database.pkl
load 3D-OGS model from  /home/joao/Documents/repositories/GSPose/demo_data/demo_buzz/demo_buzz-database/3DGO_model.ply


# 3. Load test data for pose estimaton and tracking

In [4]:
query_video_camKs = list()
with open(os.path.join(query_seq_dir, 'Frames.txt'), 'r') as cf:
    for row in cf.readlines():
        if len(row) > 0 and row[0] != '#':
            camk_dat = np.array([float(c) for c in row.strip().split(',')])
            camk = np.eye(3)
            camk[0, 0] = camk_dat[-4]
            camk[1, 1] = camk_dat[-3]
            camk[0, 2] = camk_dat[-2]
            camk[1, 2] = camk_dat[-1]
            query_video_camKs.append(camk)
query_video_frames = media.read_video(os.path.join(query_seq_dir, 'Frames.m4v')) # NxHxWx3    
num_frames = len(query_video_frames)
query_video_frames.shape

(248, 1280, 720, 3)

## 3.1 Perform single-frame pose estimation

In [60]:
gsp_poses = list()
gsp_video_frames = list()
gsp_segmentation_frames = list()


scale = 1
thickness = 2
color = (255, 255, 0)
font = cv2.FONT_HERSHEY_SIMPLEX

start_idx = 0
gsp_accum_runtime = 0
num_frames = len(query_video_frames)
for view_idx in range(num_frames):    
    camK = query_video_camKs[view_idx]
    image = query_video_frames[view_idx]
    image = torch.as_tensor(np.array(image), dtype=torch.float32) / 255.0
    
    camK = torch.as_tensor(camK, dtype=torch.float32)
    
    target_size = CFG.zoom_image_scale
    raw_hei, raw_wid = image.shape[:2]
    raw_long_size = max(raw_hei, raw_wid)
    raw_short_size = min(raw_hei, raw_wid)
    raw_aspect_ratio = raw_short_size / raw_long_size
    if raw_hei < raw_wid:
        new_wid = CFG.query_longside_scale
        new_hei = int(new_wid * raw_aspect_ratio)
    else:
        new_hei = CFG.query_longside_scale
        new_wid = int(new_hei * raw_aspect_ratio)
    query_rescaling_factor = CFG.query_longside_scale / raw_long_size
    que_image = image[None, ...].permute(0, 3, 1, 2).to(device)
    que_image = torch_F.interpolate(que_image, size=(new_hei, new_wid), mode='bilinear', align_corners=True)

    run_timer = time.time()
    
    obj_data = perform_segmentation_and_encoding(model_net, que_image, reference_database, device=device)
    obj_data['camK'] = camK.to(device)
    obj_data['img_scale'] = max(image.shape[:2])
    obj_data['bbox_scale'] /= query_rescaling_factor  # back to the original image scale
    obj_data['bbox_center'] /= query_rescaling_factor # back to the original image scale
     
    try:
        init_RTs = multiple_initial_pose_inference(obj_data, ref_database=reference_database, device=device)
    except Exception as e:
        print(e)
        init_RTs = torch.eye(4)[None].numpy()
        
    refiner_oupt = multiple_refine_pose_with_GS_refiner(
        obj_data, init_pose=init_RTs, gaussians=reference_database['obj_gaussians'], device=device)
            
    gsp_accum_runtime += time.time() - run_timer

    gsp_pose = refiner_oupt['refined_RT']
    iter_step = refiner_oupt['iter_step']
    bbox_scale = refiner_oupt['bbox_scale']
    bbox_center = refiner_oupt['bbox_center']
    gsp_render_frame = refiner_oupt['render_img']
    gsp_poses.append(gsp_pose)
    
    small_hei = raw_hei // 3 # downscale the image for visualization
    small_wid = raw_wid // 3
    
#     gsp_render_frame = render_Gaussian_object_model(obj_gaussians, camK=camK, pose=gsp_RT, 
#                                                     img_hei=raw_hei, img_wid=raw_wid, device=device)
    # Visualize the segmentation mask as a heatmap on top of the segmentation zone
    segmentation_zone = obj_data['rgb_image'] # 3xHxW
    segmentation_mask = obj_data['rgb_mask'] # 1xHxW
    seg_mask_np = (segmentation_mask[0].detach().cpu().numpy() * 255).astype(np.uint8)
    seg_mask_np = cv2.applyColorMap(seg_mask_np, cv2.COLORMAP_JET)
    seg_zone_np = (segmentation_zone.detach().cpu().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    seg_frame = cv2.addWeighted(seg_zone_np, 0.75, seg_mask_np, 0.25, 0)
    # Pad seg_frame with white to be small_hei x small_wid
    padded_seg_frame = np.ones((small_hei, small_wid, 3), dtype=np.uint8) * 255
    padded_seg_frame[:seg_frame.shape[0], :seg_frame.shape[1], :] = seg_frame
    # gsp_segmentation_frames.append(seg_frame)
    
    gsp_render_frame = gs_utils.zoom_out_and_uncrop_image(gsp_render_frame, # 3xSxS
                                                            bbox_scale=bbox_scale,
                                                            bbox_center=bbox_center,
                                                            orig_hei=image.shape[0],
                                                            orig_wid=image.shape[1],
                                                            ).detach().cpu().squeeze() # HxWx3
    gsp_render_frame = (torch.clamp(gsp_render_frame, 0, 1.0) * 255).numpy().astype(np.uint8)
    
    query_img_np = (image * 255).numpy().astype(np.uint8)
    gsp_RT = torch.as_tensor(gsp_pose, dtype=torch.float32)
    gsp_bbox_KRT = torch.einsum('ij,kj->ki', gsp_RT[:3, :3], cannon_3D_bbox.cpu()) + gsp_RT[:3, 3][None, :]
    gsp_bbox_KRT = torch.einsum('ij,kj->ki', camK, gsp_bbox_KRT)
    gsp_bbox_pts = (gsp_bbox_KRT[:, :2] / gsp_bbox_KRT[:, 2:3]).type(torch.int64)
    track_bbox_pts = gsp_bbox_pts.numpy()
    
    gsp_bbox3d_frame = query_img_np.copy()
    gsp_bbox3d_frame = gs_utils.draw_3d_bounding_box(gsp_bbox3d_frame, track_bbox_pts, color=color, linewidth=5)
    gsp_bbox3d_frame = cv2.resize(gsp_bbox3d_frame, (small_wid, small_hei))
        
    query_img_np = cv2.resize(query_img_np, (small_wid, small_hei))
    gsp_render_frame = cv2.resize(gsp_render_frame, (small_wid, small_hei))
    gsp_overlay_frame = cv2.addWeighted(
        cv2.cvtColor(gsp_render_frame.copy(), cv2.COLOR_BGR2HSV), 0.6, query_img_np, 0.4, 1)
            
    cv2.putText(query_img_np,     'Input Video', (10, 40), font, scale, color, thickness=thickness)
    cv2.putText(gsp_bbox3d_frame, 'Pose Estimate', (20, 40), font, scale, color, thickness=thickness)

    cv2.putText(gsp_render_frame,    'Posing 3DGS', (20, 40), font, scale, color, thickness=thickness)
    # cv2.putText(gsp_render_frame,    'with estimated pose', (10, 80), font, scale, color, thickness=thickness)
    
    cv2.putText(gsp_overlay_frame, 'Input Overlay', (20, 40), font, scale, color, thickness=thickness)
    # cv2.putText(gsp_overlay_frame, 'on input image', (60, 80), font, scale, color, thickness=thickness)

    wihite_stripe = np.ones_like(gsp_bbox3d_frame)[:, :50, :] * 255    
    concat_frame = np.concatenate([gsp_bbox3d_frame, wihite_stripe,
                                   gsp_render_frame, wihite_stripe,
                                   gsp_overlay_frame, wihite_stripe,
                                   padded_seg_frame], axis=1)
    gsp_video_frames.append(concat_frame)
    
    if (view_idx + 1) % 30 == 0:
        print('[{}/{}], \t refining_step:{}, \t {:.1f} FPS'.format(
            view_idx+1, num_frames, iter_step, (view_idx - start_idx) / gsp_accum_runtime))
    
gsp_video_frames = np.stack(gsp_video_frames, axis=0)
print(gsp_video_frames.shape)

[30/248], 	 refining_step:63, 	 0.9 FPS
[60/248], 	 refining_step:62, 	 0.9 FPS
[90/248], 	 refining_step:62, 	 0.9 FPS
[120/248], 	 refining_step:61, 	 0.9 FPS
[150/248], 	 refining_step:62, 	 0.9 FPS
[180/248], 	 refining_step:62, 	 0.9 FPS
[210/248], 	 refining_step:54, 	 0.9 FPS
[240/248], 	 refining_step:43, 	 0.9 FPS
(248, 426, 1110, 3)


In [61]:
fps = 15
media.show_video(np.stack(gsp_video_frames, axis=0), fps=fps, width=3*320)
# media.show_video(np.stack(gsp_segmentation_frames, axis=0), fps=fps, width=320)

0
This browser does not support the video tag.


## 3.2 Perform pose tracking

In [1]:
import torch
import torch.nn.functional as F
import numpy as np
import cv2
import time

# Setup configuration and constants
start_idx = 0 # Start frame index
target_size = CFG.zoom_image_scale
num_frames = len(query_video_frames)
frame_interval = 1
scale, thickness = 1, 2
color = (0, 165, 255)
font = cv2.FONT_HERSHEY_SIMPLEX

CFG.MAX_STEPS = 32
CFG.START_LR = 3e-3
CFG.END_LR = 1e-5

# Initialize tracking variables
track_poses, track_video_frames = [], []
track_accum_runtime = 0


# Helper function to prepare the image and camera intrinsics
def prepare_image_and_camK(index):
    camK = torch.as_tensor(query_video_camKs[index], dtype=torch.float32)
    image = torch.as_tensor(np.array(query_video_frames[index]), dtype=torch.float32) / 255.0
    return camK, image

# Rescale the image to models input size
# -- Get the initial image size
camK, image = prepare_image_and_camK(start_idx)
raw_hei, raw_wid = image.shape[:2]
raw_long_size = max(raw_hei, raw_wid)
raw_short_size = min(raw_hei, raw_wid)
raw_aspect_ratio = raw_short_size / raw_long_size

# -- Calculate new image size
new_wid, new_hei = (CFG.query_longside_scale, int(CFG.query_longside_scale * raw_aspect_ratio)) if raw_hei < raw_wid else (int(CFG.query_longside_scale * raw_aspect_ratio), CFG.query_longside_scale)
query_rescaling_factor = CFG.query_longside_scale / raw_long_size

# -- Scale the image
que_image = image[None, ...].permute(0, 3, 1, 2).to(device)
que_image = F.interpolate(que_image, size=(new_hei, new_wid), mode='bilinear', align_corners=True)

# Get the initial pose estimate
# -- Perform segmentation and pose inference
obj_data = perform_segmentation_and_encoding(model_net, que_image, reference_database, device=device)
obj_data.update({'camK': camK, 'img_scale': max(image.shape[:2])})
obj_data['bbox_scale'] /= query_rescaling_factor
obj_data['bbox_center'] /= query_rescaling_factor

# -- Initial pose estimation
track_pose = multiple_initial_pose_inference(obj_data, ref_database=reference_database, device=device)[0]

# Loop through each frame in the video
for view_idx in range(start_idx, num_frames, frame_interval):
    camK, image = prepare_image_and_camK(view_idx)
    image_hei, image_wid = image.shape[:2]

    # Refine the pose estimate (with GS tracker)
    track_timer = time.time()
    track_outp = GS_Tracker(model_net, frame=image, prev_pose=track_pose, camK=camK, ref_database=reference_database)
    track_accum_runtime += time.time() - track_timer

    # Update tracking data
    track_pose = track_outp['track_pose']
    bbox_scale, bbox_center = track_outp['bbox_scale'], track_outp['bbox_center']
    track_poses.append(track_pose)

    # Generate render and overlay images
    render_full_img = gs_utils.zoom_out_and_uncrop_image(track_outp['render_img'], bbox_scale=bbox_scale, bbox_center=bbox_center, orig_hei=image.shape[0], orig_wid=image.shape[1]).detach().cpu().squeeze()
    render_full_img_np = (torch.clamp(render_full_img, 0, 1.0) * 255).numpy().astype(np.uint8)
    query_img_np = (image * 255).numpy().astype(np.uint8)
    
    track_render_img = render_full_img_np.copy()
    track_overlap_frame = cv2.addWeighted(cv2.cvtColor(track_render_img.copy(), cv2.COLOR_BGR2HSV), 0.6, query_img_np, 0.4, 1)

    # Draw 3D bounding box on the image
    track_RT = torch.as_tensor(track_pose, dtype=torch.float32)
    track_bbox_KRT = torch.einsum('ij,kj->ki', track_RT[:3, :3], cannon_3D_bbox) + track_RT[:3, 3][None, :]
    track_bbox_KRT = torch.einsum('ij,kj->ki', camK, track_bbox_KRT)
    track_bbox_pts = (track_bbox_KRT[:, :2] / track_bbox_KRT[:, 2:3]).type(torch.int64).numpy()
    track_bbox3d_img = gs_utils.draw_3d_bounding_box(query_img_np.copy(), track_bbox_pts, color=color, linewidth=5)

    # Resize for visualization
    small_size = (image_wid // 3, image_hei // 3)
    resized_images = [cv2.resize(img, small_size) for img in [query_img_np, track_bbox3d_img, track_render_img, track_overlap_frame]]
    
    # Add text labels
    labels = ['Input Video', 'Tracking Result', 'Gaussian 3DGS', 'Input Overlay']
    for img, label in zip(resized_images, labels):
        cv2.putText(img, label, (10, 40), font, scale, color, thickness=thickness)
    
    # Concatenate images for output
    white_stripe = np.ones_like(resized_images[0])[:, :50, :] * 255
    concat_images = np.concatenate([resized_images[0], white_stripe, resized_images[1], white_stripe, resized_images[2], white_stripe, resized_images[3]], axis=1)
    track_video_frames.append(concat_images)

    # Log progress every 30 frames
    if (view_idx + 1) % 30 == 0:
        print(f'[{view_idx+1}/{num_frames}], \t{(view_idx - start_idx) / track_accum_runtime:.1f} FPS')

NameError: name 'CFG' is not defined

In [15]:
media.show_video(np.stack(track_video_frames, axis=0), fps=15, width=320*3)

0
This browser does not support the video tag.


# 4. Visualize the coordinate pointcloud of the 3D Gaussian object model

In [26]:
obj_gaussian_pointcloud = py3d_struct.Pointclouds(
    points=[obj_gaussians.get_xyz.squeeze().detach().cpu()],
    features=[obj_gaussians._features_dc.squeeze().detach().cpu().sigmoid()]
)

fig = py3d_vis.plot_scene(
    {" ": 
        {
            'Gaussian pointcloud': obj_gaussian_pointcloud,
        }
    },
    xaxis={"backgroundcolor":"rgb(200, 200, 230)"},
    yaxis={"backgroundcolor":"rgb(230, 200, 200)"},
    zaxis={"backgroundcolor":"rgb(200, 230, 200)"}, 
    
#     xaxis={"backgroundcolor":"rgb(255, 255, 255)"},
#     yaxis={"backgroundcolor":"rgb(255, 255, 255)"},
#     zaxis={"backgroundcolor":"rgb(255, 255, 255)"}, 
    
    pointcloud_marker_size=3,
    pointcloud_max_points=30_000,
    axis_args=py3d_vis.AxisArgs(showgrid=True)
)

fig.update_layout(width=800, height=600)
fig.show()

