In [33]:
import os, sys
from pathlib import Path
import numpy as np

import glob
import cv2
import json
import csv
import shutil
import PIL

import torch
from torchvision import transforms

## This notebook demonstrates the pipeline of the improved baseline, i.e., SAM+DINOv2.

### step 1: Computing per-frame 2D detections with cosine simlarity  

In [2]:
@torch.inference_mode()
def extract_features_dinov2(model, img_tensor, use_cuda=True):
    samples = img_tensor.cuda(non_blocking=True)
    feats = model(samples).float().clone()
    return feats

def extract_features(encoder_arch, encoder, visual_crop_np, preprocess, device):
    with torch.no_grad():
        if encoder_arch == "DINOv2":
            input_list = []
            for vc in visual_crop_np:
                image_input = preprocess(PIL.Image.fromarray(vc)).unsqueeze(0).to(device)
                input_list.append(image_input)
            all_input = torch.cat(input_list, dim=0)
            visual_crop_features = extract_features_dinov2(encoder, all_input) # N x 1024

    return visual_crop_features

def load_online_enrollment(annot_data_path, xywh2xyxy=False):
    bbox = {}
    frame = {}
    with open(os.path.join(annot_data_path, "svoe.txt"), "r") as f:
        for l in f:
            l_list = l.rstrip("\n").split()
            temp = np.array(l_list[2:]).astype(int)
            if xywh2xyxy:
                temp[2:] += temp[0:2]
            bbox[l_list[0]] = temp
            frame[l_list[0]] = l_list[1]
    return bbox, frame

In [5]:
def load_frame_proposals(detection_f, detector="sam", unnormalize=False, img_W=None, img_H=None, include_conf=True):
    res = [] # XYHW format
    if detector == "sam":
        with open(detection_f, "r") as f:
            for l in f:
                l = l.strip().split()
                if len(l) == 0:
                    continue
                x, y, w, h = float(l[0]), float(l[1]), float(l[2]), float(l[3])
                if len(l) > 4 and include_conf:
                    conf = float(l[4])
                xyxy = np.array([x, y, x+w, y+h], dtype=np.float32)
                if include_conf:
                    res.append(np.array([xyxy[0], xyxy[1], xyxy[2], xyxy[3], conf], dtype=np.float32))
                else:
                    res.append(np.array([xyxy[0], xyxy[1], xyxy[2], xyxy[3]], dtype=np.float32))

    return res

def extract_windows_from_detections(frame, all_detections):
    # add dets must in XYXY mode
    res = []
    for det in all_detections:
        x1, y1, x2, y2 = det[:4]
        window = frame[int(y1):int(y2), int(x1):int(x2)]
        res.append(window)
    return res

def extract_proposal_scores(proposal_features, visual_crop_features, det_score=None, dist_metric="cos", score_thresh=None, topk=1):
    # unsqueeze to M x N x D
    proposal_features = proposal_features.unsqueeze(0).expand(visual_crop_features.shape[0], -1, -1)
    visual_crop_features = visual_crop_features.unsqueeze(1).expand(-1, proposal_features.shape[1], -1)

    # compute proposal scores
    if dist_metric == "cos":
        proposal_scores = torch.nn.functional.cosine_similarity(proposal_features, visual_crop_features, dim=-1) # M x N

    if score_thresh is not None:
        proposal_scores[proposal_scores < score_thresh] = 0

    # extract top k results
    proposal_ret = torch.topk(proposal_scores, k=topk, dim=1)

    return proposal_ret

def load_instance_id_pairs(label_csv):
    with open(label_csv, "r") as f:
        data = list(csv.reader(f, delimiter=","))
    res = {}
    for i, v in enumerate(data):
        res[i] = v[0] # 0: "shoe_1"
    return res

def _convert_image_to_rgb(image):
    return image.convert("RGB")

def print_msg_box(msg, indent=1, width=None, title=None):
    """Print message-box with optional title."""
    lines = msg.split('\n')
    space = " " * indent
    if not width:
        width = max(map(len, lines))
    box = f'╔{"═" * (width + indent * 2)}╗\n'  # upper_border
    if title:
        box += f'║{space}{title:<{width}}{space}║\n'  # title
        box += f'║{space}{"-" * len(title):<{width}}{space}║\n'  # underscore
    box += ''.join([f'║{space}{line:<{width}}{space}║\n' for line in lines])
    box += f'╚{"═" * (width + indent * 2)}╝'  # lower_border
    print(box)

In [34]:
data_root = os.path.join(os.getcwd(), "../benchmark_data")
detection_root = os.path.join(os.getcwd(), "../hololens_detection")
seq_list = os.listdir(os.path.join(os.getcwd(), "../benchmark_data/raw_video_seqs"))

img_ext = ".png"
detector = "sam"
enrollment_type = "SVOE"

# create encoder
encoder_arch = "DINOv2"
score_thresh_dict = {"DINOv2": 0.6}
model_name = "{}_{}_{}_{}".format(enrollment_type, detector, encoder_arch, score_thresh_dict[encoder_arch])

print_msg_box("model: {}".format(model_name))
save_dir = os.path.join(os.getcwd(), "results/association_results")
os.makedirs(save_dir, exist_ok=True)
cache_feature = True
recompute_feature = False
cache_dir = os.path.join(os.getcwd(), "cache_features")

device = "cuda" if torch.cuda.is_available() else "cpu"
if encoder_arch == "DINOv2":
    encoder = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14', verbose=False)
    n_px = 256
    crop_px = 224
    preprocess = transforms.Compose([
        transforms.Resize(n_px, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(crop_px),
         _convert_image_to_rgb,
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    encoder.to(device)
    encoder.eval()

else:
    raise RuntimeError("Unknown encoder architecture")

for seq in seq_list:
    cam_pose = json.load(open(os.path.join(data_root, "raw_video_seqs", seq, "pv_pose.json")))
    all_frames = sorted(cam_pose.keys())
    save_dir_seq = os.path.join(save_dir, seq, model_name)
    if os.path.exists(save_dir_seq):
        continue
    os.makedirs(save_dir_seq, exist_ok=True)

    # check if feature is cached
    obj_feature_fname = "{}_{}".format(enrollment_type, encoder_arch)
    visual_crop_features_cache_p = os.path.join(cache_dir, seq, "objects", 
        "{}.pth".format(obj_feature_fname))

    if os.path.exists(visual_crop_features_cache_p) and not recompute_feature:
        print("load cached features!")
        visual_crop_features = torch.load(visual_crop_features_cache_p)
    else:
        if enrollment_type == "SVOE":
            bbox_dict, frame_dict = load_online_enrollment(os.path.join(data_root, "annotations", seq.replace("raw", "annotation")))
            visual_crop_np = []
            for obj_idx in frame_dict.keys():
                f_p = os.path.join(data_root, "raw_video_seqs", seq, "pv", frame_dict[obj_idx] + img_ext)
                f_img = cv2.cvtColor(cv2.imread(f_p), cv2.COLOR_BGR2RGB)
                f_bbox = bbox_dict[obj_idx]
                visual_crop_np += extract_windows_from_detections(f_img, [f_bbox])

        visual_crop_features = extract_features(encoder_arch, encoder, visual_crop_np, preprocess, device)
        if cache_feature:
            os.makedirs(os.path.join(cache_dir, seq, "objects"), exist_ok=True)
            torch.save(visual_crop_features.cpu(), os.path.join(cache_dir, seq, "objects", "{}.pth".format(obj_feature_fname)))

    num_obj = visual_crop_features.shape[0]
    count = 0
    for f in all_frames:
        frame_p = os.path.join(data_root, "raw_video_seqs", seq, "pv", f+img_ext)
        frame = cv2.cvtColor(cv2.imread(frame_p), cv2.COLOR_BGR2RGB)

        # load detection
        detection_f = os.path.join(detection_root, detector, seq, "labels", f+".txt")
        if not os.path.exists(detection_f):
            # still save an empty results
            output_f = os.path.join(save_dir_seq, f+".txt")
            with open(output_f, "w") as f:
                pass
            continue

        all_detections = load_frame_proposals(detection_f, detector=detector, unnormalize=True, img_W=frame.shape[1], img_H=frame.shape[0])
        proposal_features_cache_p = os.path.join(cache_dir, seq, "proposals", "{}_{}_{}.pth".format(f, detector, encoder_arch))

        if os.path.exists(proposal_features_cache_p):
            proposal_features = torch.load(proposal_features_cache_p)
        else:
            detection_windows = extract_windows_from_detections(frame, all_detections) # a list of windows: H x W x 3
            proposal_features = extract_features(encoder_arch, encoder, detection_windows, preprocess, device)

            if cache_feature:
                # check whether feature exists
                os.makedirs(os.path.join(cache_dir, seq, "proposals"), exist_ok=True)
                cache_f_seq = "{}_{}_{}.pth".format(f, detector, encoder_arch)
                torch.save(proposal_features.cpu(), os.path.join(cache_dir, seq, "proposals", cache_f_seq))

        # extract features and find the closest detection proposal
        det_scores = torch.tensor(np.array([det[4] for det in all_detections], dtype=np.float32))
        proposal_ret = extract_proposal_scores(proposal_features.cpu().float(), visual_crop_features.cpu().float(), det_scores,
                                               dist_metric="cos", score_thresh=score_thresh_dict[encoder_arch], topk=1)
        proposal_ret_val, proposal_ret_idx = proposal_ret.values.cpu().numpy(), proposal_ret.indices.cpu().numpy() # M x topk

        '''
            output format: visual_crop_id, associated detection #1, detection conf, similarity score
        '''
        output_f = os.path.join(save_dir_seq, f+".txt")
        with open(output_f, "w") as f:
            for i in range(num_obj):
                for j in range(proposal_ret_idx.shape[-1]):
                    if proposal_ret_val[i, j] == 0.: continue
                    f.write("{} {} {} {} {} {} {}\n".format(i, *all_detections[proposal_ret_idx[i, j]], proposal_ret_val[i, j]))
print("Step 1 done!")

╔════════════════════════════╗
║ model: SVOE_sam_DINOv2_0.6 ║
╚════════════════════════════╝
Step 1 done!


### step 2: Lifting the center of 2D bounding boxes to 3D  

In [35]:
import re
from collections import defaultdict

from proj_utils import compute_depth_scale_map_wrapper, pv_bbox2depth_unproj, to_homogeneous

sys.path.insert(0, os.getcwd())
import hl2ss_3dcv

def singular_matrix(matrix):
    return np.linalg.det(matrix) == 0

def split_letter_number(s):
    return re.split('(\d+)',s)

def load_2d_dets(annot_l, include_pred_score=False):
    res = defaultdict(list) # key is frame number, values are boxes
    for idx, annot_f in enumerate(annot_l):
        frame_temp = os.path.basename(annot_f).split(".")[0]
        with open(annot_f, "r") as f:
            for l in f:
                if len(l.split()) > 2:
                    splits = l.rstrip("\n").split()
                    obj_idx, x1, y1, x2, y2 = [int(float(xx)) for xx in splits[:5]]
                    if int(x1) <= 1 and int(y1) <= 1:
                        continue

                    if include_pred_score and len(splits) == 6:
                        if splits[-1].isnumeric():
                            score = float(splits[-1])
                        else:
                            score = 1.0
                        res[frame_temp].append(np.array([int(x1), int(y1), int(x2), int(y2), obj_idx, score], dtype=float))
                    else:
                        res[frame_temp].append(np.array([int(x1), int(y1), int(x2), int(y2), obj_idx], dtype=int))

    return res

In [36]:
associate_root = os.path.join(os.getcwd(), "results/association_results")
save_dir = os.path.join(os.getcwd(), "results/lifted_3d_results")

os.makedirs(save_dir, exist_ok=True)
cam2world = True

print_msg_box("model: {}".format(model_name))
ahat_calibration = hl2ss_3dcv._load_calibration_rm_depth_ahat(os.path.join(data_root, "calibrations", "rm_depth_ahat"))

for seq in seq_list:
    cam_pose = json.load(open(os.path.join(data_root, "raw_video_seqs", seq, "pv_pose.json")))
    depth_poses = json.load(open(os.path.join(data_root, "raw_video_seqs", seq, "depth_ahat_pose.json")))
    all_frames = sorted(cam_pose.keys(), key=lambda x:int(x))
    save_dir_seq = os.path.join(save_dir, seq, model_name)
    if os.path.exists(save_dir_seq):
        continue
    os.makedirs(save_dir_seq, exist_ok=True)

    tracker_annot_dir = os.path.join(associate_root, seq, model_name, "*.txt")
    tracker_annot_l = glob.glob(tracker_annot_dir)
    tracker_annot_dict = load_2d_dets(tracker_annot_l, include_pred_score=True)

    hw = {'image': (720, 1280), 'ahat_depth': (512, 512)}
    count = 0
    for f in all_frames:
        frame_abs_path = os.path.join(data_root, "raw_video_seqs", seq, "pv", f + img_ext)
        reproj_xyz = compute_depth_scale_map_wrapper(frame_abs_path, seq, hw=hw, depth_pose=depth_poses[f], rgb_pose=cam_pose[f])

        assoc_dets = defaultdict(list)
        scores = defaultdict(list)
        if f not in tracker_annot_dict.keys():
            det_flag = 0
        else:
            det_flag = 1
            bbox_l = tracker_annot_dict[f] # [x1, y1, x2, y2, idx, (pred_score)], XYXY format already
            for bbox in bbox_l:
                    object_id = bbox[4]
                    det = [float(x) for x in bbox[:4]]
                    det_tmp, det_flag = pv_bbox2depth_unproj(det, reproj_xyz) # convert 2D detection into 3D pts
                    if det_flag == 1:
                        if cam2world and not singular_matrix(cam_pose[f]):
                            matrix_t = hl2ss_3dcv.camera_to_rignode(ahat_calibration.extrinsics) @ hl2ss_3dcv.reference_to_world(depth_poses[f])
                            det_tmp = to_homogeneous(det_tmp).reshape(1, 4) @ matrix_t
                        assoc_dets[object_id].append(det_tmp.squeeze()[:3])
                        if len(bbox) > 5:
                            scores[object_id].append(bbox[-1])
                        else:
                            scores[object_id].append(1) # placeholder

        # save results
        '''
            results format: visual_crop_id, associated 3d Point in world coord #1, detection conf, similarity score
        '''
        if det_flag == 1:
            output_f = os.path.join(save_dir_seq, f+".txt")
            with open(output_f, "w") as f2:
                for i in assoc_dets.keys():
                    for j in range(len(assoc_dets[i])):
                        f2.write("{} {} {} {} {}\n".format(i, *assoc_dets[i][j], scores[i][j]))

print("Step 2 done!")

╔════════════════════════════╗
║ model: SVOE_sam_DINOv2_0.6 ║
╚════════════════════════════╝
Step 2 done!


### step 3: Adopting the simple memory mechanism for the frames without valid predictions

In [22]:
save_dir = os.path.join(os.getcwd(), "results/baseline_results")
os.makedirs(save_dir, exist_ok=True)

print_msg_box("model: {}".format(model_name))### step 3: Adopting the simple memory mechanism for the frames without valid predictions
for seq in seq_list:
    # load per-frame association results, camera poses, and object id
    cam_pose = json.load(open(os.path.join(data_root, "raw_video_seqs", seq, "pv_pose.json")))
    all_frames = sorted(cam_pose.keys(), key=lambda x:int(x))
    save_dir_seq = os.path.join(save_dir, seq, model_name)
    os.makedirs(save_dir_seq, exist_ok=True)

    memory = {}
    score_memory = {}
    for f in all_frames:
        # if 3d assocation results exist, load it == also update memory
        if os.path.exists(os.path.join(associate_root, seq, model_name, f+".txt")):
            with open(os.path.join(associate_root, seq, model_name, f+".txt"), "r") as f1:
                lines = f1.readlines()
                for line in lines:
                    line = line.strip().split(" ")
                    loc_np = np.array(line[1:4], dtype=np.float32)
                    scores = float(line[-1])

                    if line[0] not in memory:
                        memory[line[0]] = [loc_np]
                        score_memory[line[0]] = [scores]
                    elif (loc_np == memory[line[0]][-1]).all():
                        continue
                    else:
                        memory[line[0]].append(loc_np)
                        score_memory[line[0]].append(scores)

        # output 3d association results,
        with open(os.path.join(save_dir_seq, f+".txt"), "w") as f2:
            for k, v in memory.items():
                if len(v) >= 1:
                    f2.write("{} {} {} {} {}\n".format(int(float(k)), *v[-1], score_memory[k][-1]))

print("Step 3 done!")

╔════════════════════════════╗
║ model: SVOE_sam_DINOv2_0.6 ║
╚════════════════════════════╝
Step 3 done!


### step 4: Evaluation

In [23]:
def load_gt(gt_path: str):
    gt_3d = defaultdict(list)
    gt_3d_range = defaultdict(list)
    with open(gt_path, "r") as f:
        for line in f:
            line = line.rstrip("\n").split()
            start_f, end_f = int(line[0]), int(line[1])
            instance_id = int(line[2])
            x, y, z = float(line[3]), float(line[4]), float(line[5])
            gt_3d[instance_id].append(np.array([x, y, z]))
            gt_3d_range[instance_id].append([start_f, end_f])
    return gt_3d, gt_3d_range

In [29]:
baseline_root = os.path.join(os.getcwd(), "results/baseline_results")
# predefined thresholds
thresholds = [0.25, 0.5, 0.75, 1.0, 1.5]

# load predictions
num_gts_all = 0
num_preds_all = 0
tp_all, fp_all, fn_all = np.array([0] * len(thresholds)), np.array([0] * len(thresholds)), np.array([0] * len(thresholds))
for seq_id, seq in enumerate(seq_list):
    # per sequence precision and recall
    cam_pose = json.load(open(os.path.join(data_root, "raw_video_seqs", seq, "pv_pose.json")))
    all_frames = sorted(cam_pose.keys(), key=lambda x:int(x))

    # load gt 3D points for the sequence
    gt_path = os.path.join(data_root, "annotations", seq.replace("raw", "annotation"), "3d_center_annot.txt")
    gt_3d, gt_3d_range = load_gt(gt_path)

    tp, fp, fn = np.array([0] * len(thresholds)), np.array([0] * len(thresholds)), np.array([0] * len(thresholds))
    num_gts = 0
    num_preds = 0
    for f in all_frames:
        # load baseline 
        baseline_f = os.path.join(baseline_root, seq, model_name, f+".txt")
        baseline_res = {}
        bsaeline_res_detScore = {}
        with open(baseline_f, "r") as f1:
            for l in f1:
                l = l.rstrip("\n").split()
                instance_id = int(l[0])
                baseline_res[instance_id] = [np.array([float(l[1]), float(l[2]), float(l[3])])]
                bsaeline_res_detScore[instance_id] = [float(l[4])]

        num_preds += len(baseline_res.keys())
        matched_ids = []
        # compute distance to gt
        for instance_id in gt_3d:
            for i in range(len(gt_3d[instance_id])):
                if int(f) >= gt_3d_range[instance_id][i][0] and int(f) <= gt_3d_range[instance_id][i][1]:
                    num_gts += 1
                    if instance_id not in baseline_res:
                        fn += 1
                        continue

                    for t_idx, t in enumerate(thresholds):
                        tp_flag = False
                        assert len(baseline_res[instance_id]) == 1
                        for j in range(len(baseline_res[instance_id])):
                            dist = np.linalg.norm(gt_3d[instance_id][i] - baseline_res[instance_id][j])
                            if dist <= t:
                                tp[t_idx] += 1
                                tp_flag = True
                                # break
                        if not tp_flag:
                            fp[t_idx] += 1
                            fn[t_idx] += 1

                    matched_ids.append(instance_id)
        fp += len(baseline_res.keys()) - len(matched_ids)

    # compute precision and recall
    precision = tp / num_preds
    assert num_preds == (tp+fp)[0]
    recall = tp / num_gts
    tp_all += tp
    fp_all += fp
    fn_all += fn
    num_gts_all += num_gts
    num_preds_all += num_preds

# overall Stats
precision_all = tp_all / num_preds_all
recall_all = tp_all / num_gts_all
print("Finished evaluation!")
for t_idx, t in enumerate(thresholds):
    print("Threshold: {:.2f} -- Precision: {:.3f}, Recall: {:.3f}".format(t, precision_all[t_idx], recall_all[t_idx]))
print()

Finished evaluation!
Threshold: 0.25 -- Precision: 0.233, Recall: 0.249
Threshold: 0.50 -- Precision: 0.264, Recall: 0.281
Threshold: 0.75 -- Precision: 0.331, Recall: 0.353
Threshold: 1.00 -- Precision: 0.433, Recall: 0.463
Threshold: 1.50 -- Precision: 0.594, Recall: 0.634

