In [None]:
import torch
import sys
import itertools
from functools import partial

from tqdm import tqdm
import clip
sys.path.append('..')
from lidarclip.anno_loader import build_anno_loader
from lidarclip.anno_loader import NUSCENES_CLASSES as CLASSES
from lidarclip.helpers import MultiLoader, try_paths, get_topk, get_topk_separate_prompts
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
from lidarclip.prompts import OBJECT_PROMPT_TEMPLATES, BUSY_PROMPTS, EMPTY_PROMPTS
# Report how many prompts/templates each evaluation subcategory will use.
print("Num prompts per subcategory:")
for label, prompt_list in (("Busy", BUSY_PROMPTS), ("Empty", EMPTY_PROMPTS), ("Objects", OBJECT_PROMPT_TEMPLATES)):
    print(f"  {label}: {len(prompt_list)}")

In [None]:
CLIP_VERSION = "ViT-L/14"

# Load the CLIP model and bind it (and the device) into the retrieval helpers.
# NOTE: this rebinds the names imported from lidarclip.helpers; re-running the cell only
# re-applies the same keyword bindings, so it stays safe to re-run.
clip_model, clip_preprocess = clip.load(CLIP_VERSION, device=device)
get_topk = partial(get_topk, clip_model=clip_model, device=device)
get_topk_separate_prompts = partial(get_topk_separate_prompts, clip_model=clip_model, device=device)

# nuScenes val loader with skip_data=True / skip_anno=False (presumably yields annotations
# without sensor data — TODO confirm in lidarclip.anno_loader).
dataset_root = try_paths("/proj/nlp4adas/datasets/nuscenes", "/Users/s0000960/data/nuscenes/")
anno_loader = build_anno_loader(dataset_root, clip_preprocess, batch_size=1, num_workers=0, split="val", skip_data=True, skip_anno=False, dataset_name="nuscenes")

# Precomputed features live in files like "nuscenes_vit-l-14__joint-trained_val_img.pt".
feature_version = CLIP_VERSION.lower().replace("/", "-")
feature_root = try_paths("/proj/nlp4adas/features", "../features")

# map_location puts the tensors straight on `device`. If the features were saved from GPU
# tensors, a plain torch.load(...) fails on a CPU-only machine before .to(device) is reached.
anno_img_feats = torch.load(f"{feature_root}/nuscenes_{feature_version}__joint-trained_val_img.pt", map_location=device)
anno_lidar_feats = torch.load(f"{feature_root}/nuscenes_{feature_version}__joint-trained_val_lidar.pt", map_location=device)
# Retrieved indices into these features are used to index per-sample masks of length
# len(anno_loader), so every feature row must correspond to exactly one loader sample.
assert len(anno_img_feats) == len(anno_lidar_feats) == len(anno_loader), (
    f"feature/sample count mismatch: {len(anno_img_feats)} img, "
    f"{len(anno_lidar_feats)} lidar, {len(anno_loader)} loader samples"
)

# Compute joint features as the plain sum of the two modalities.
# Alternative (L2-normalize each modality before summing):
# anno_joint_feats = anno_img_feats / anno_img_feats.norm(dim=1, keepdim=True) + anno_lidar_feats / anno_lidar_feats.norm(dim=1, keepdim=True)
anno_joint_feats = anno_img_feats + anno_lidar_feats

print(len(anno_loader))

In [None]:
# Build masks
BUSYNESS = ("busy", "empty")
NEARBY_CLASSES = [f"nearby {c}" for c in CLASSES]
NEARBY_CUTOFF = 20 # Longer than ONCE due to the smaller nuscenes dataset not having enough positives for eval
# Scene-busyness heuristics, computed from vehicle boxes only.
VEHICLE_CLASSES = {"car", "bus", "truck", "trailer"}  # matched against lower-cased box names
EMPTY_RADIUS = 30       # a frame is "empty" when no vehicle is closer than this
BUSY_RADIUS = 60        # a frame is "busy" when >= BUSY_MIN_VEHICLES vehicles are closer than this
BUSY_MIN_VEHICLES = 3

# One boolean per loader sample for every queried name; True marks the sample as a positive.
masks = {name: torch.zeros(len(anno_loader), dtype=torch.bool) for name in itertools.chain(CLASSES, NEARBY_CLASSES, BUSYNESS)}

# enumerate(tqdm(...)) rather than tqdm(enumerate(...)) so the progress bar knows the loader length.
for i, (_, _, anno, _meta) in enumerate(tqdm(anno_loader)):
    num_busy = 0
    num_non_empty = 0
    for name, box3d in zip(anno[0]["names"], anno[0]["boxes_3d"]):
        # Euclidean norm of the first three box parameters (presumably the box centre relative
        # to the sensor/ego origin — TODO confirm the box layout).
        dist = box3d[:3].norm().item()
        if dist < NEARBY_CUTOFF:
            masks[f"nearby {name}"][i] = True
        # Case-insensitive, so vehicles are still counted if the dataset's class names are
        # lower-case ("car") rather than capitalized ("Car"). A case mismatch would otherwise
        # mark every frame "empty" and none "busy".
        if name.lower() in VEHICLE_CLASSES:
            if dist < EMPTY_RADIUS:
                num_non_empty += 1
            if dist < BUSY_RADIUS:
                num_busy += 1
        masks[name][i] = True
    if num_busy >= BUSY_MIN_VEHICLES:
        masks["busy"][i] = True
    if num_non_empty < 1:
        masks["empty"][i] = True

In [None]:
# Evaluation groups, each a tuple of:
#   (names to query, prompt templates formatted with the name OR a {name: prompts} mapping,
#    whether the group is evaluated on the annotated features)
subcategories = [
    (CLASSES, OBJECT_PROMPT_TEMPLATES, True),
    (NEARBY_CLASSES, OBJECT_PROMPT_TEMPLATES, True),
    (BUSYNESS, {"busy": BUSY_PROMPTS, "empty": EMPTY_PROMPTS}, True),
]
# Total number of individual names queried across all groups.
total_num = sum(len(names) for names, _prompts, _uses_annos in subcategories)

In [None]:
PRINT_ELEMENTS = True
PRINT_SUBCATEGORIES = False
MODALITIES = ("image", "lidar", "joint")


def _mean(values):
    """Arithmetic mean of ``values``, or NaN when nothing was evaluated."""
    return sum(values) / len(values) if values else float("nan")


for K in [1, 10, 100]:
    # Scores of every evaluated name, per modality, across all subcategories.
    # Averages are taken over the names actually scored, so this cell does not depend on
    # the global `total_num` (which later cells may reassign).
    overall_scores = {modality: [] for modality in MODALITIES}
    for subcategory, prompt_template, needs_annos in subcategories:
        subcat_scores = {modality: [] for modality in MODALITIES}
        for name in subcategory:
            if isinstance(prompt_template, dict):
                prompts = prompt_template[name]
            else:
                prompts = [prompt.format(name) for prompt in prompt_template]
            if not needs_annos:
                # img_idxs, pc_idxs, joint_idxs = get_topk(prompts, K, noanno_img_feats, noanno_lidar_feats, noanno_joint_feats)
                raise NotImplementedError
            num_positives = masks[name].sum().item()
            if num_positives == 0:
                # P@K is undefined without positives. Scoring it as 0/0 gave NaN, which then
                # turned every subcategory and overall average into NaN.
                print(f"    P@{K} for {name}: skipped (0 matches)")
                continue
            img_idxs, pc_idxs, joint_idxs = get_topk(prompts, K, anno_img_feats, anno_lidar_feats, anno_joint_feats)
            # Normalize by the best achievable hit count so a perfect ranking scores 1.0
            # even when fewer than K positives exist.
            best_case = min(num_positives, K)
            scores = {
                modality: masks[name][idxs].sum().item() / best_case
                for modality, idxs in zip(MODALITIES, (img_idxs, pc_idxs, joint_idxs))
            }
            if PRINT_ELEMENTS:
                print(f"    P@{K} for {name}: ({num_positives} matches)")
                for modality, score in scores.items():
                    print(f"        {modality}: ", score)
            for modality, score in scores.items():
                subcat_scores[modality].append(score)
                overall_scores[modality].append(score)
        if PRINT_SUBCATEGORIES:
            print(f"Average P@{K} for {subcategory}:")
            for modality in MODALITIES:
                print(f"    {modality}: ", _mean(subcat_scores[modality]))
            print("=========================================")
    print(f"P@{K} overall:")
    for modality in MODALITIES:
        print(f"    {modality}: ", _mean(overall_scores[modality]))
    print("=========================================")
    print("=========================================")

In [None]:
# Compute the expected precision for random sampling, normalized exactly like the P@K cell
# (hits / min(#positives, K)) so the numbers are directly comparable.
# Drawing min(K, N) of N samples uniformly without replacement gives
# min(K, N) * positives / N expected hits. That equals positives / N only while
# K <= positives; for rarer classes the normalized baseline is larger.
# N is stored in `num_samples`, not `total_num`, so running this cell does not overwrite the
# name count defined with `subcategories`.
RANDOM_KS = (1, 10, 100)

num_samples = len(anno_lidar_feats)
random_scores = {k: [] for k in RANDOM_KS}
for subcategory, prompt_template, needs_annos in subcategories:
    for name in subcategory:
        if not needs_annos:
            raise NotImplementedError  # len(noanno_lidar_feats)
        num_positives = masks[name].sum().item()
        if num_positives == 0:
            # Undefined under the normalized metric; skipped just like in the P@K evaluation.
            print(f"random score for {name}: skipped (0 matches)")
            continue
        per_k = {
            k: min(k, num_samples) * num_positives / (num_samples * min(num_positives, k))
            for k in RANDOM_KS
        }
        print(f"random score for {name}: " + ", ".join(f"P@{k} {score:.3f}" for k, score in per_k.items()))
        for k, score in per_k.items():
            random_scores[k].append(score)
for k, scores in random_scores.items():
    overall_random = sum(scores) / len(scores) if scores else float("nan")
    print(f"overall random P@{k}: {overall_random:.3f}")