In [None]:
import torch
import sys
import itertools
from functools import partial

from tqdm import tqdm
import clip
sys.path.append('..')
from lidarclip.anno_loader import build_anno_loader, CLASSES, WEATHERS
from lidarclip.helpers import MultiLoader, try_paths, get_topk, get_topk_separate_prompts
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
from lidarclip.prompts import WEATHER_PROMPT_TEMPLATES, PERIOD_PROMPT_TEMPLATES, OBJECT_PROMPT_TEMPLATES, BUSY_PROMPTS, EMPTY_PROMPTS
# Summarise how many prompts/templates each subcategory provides.
prompt_count_report = [
    "Num prompts per subcategory:",
    f"  Weather: {len(WEATHER_PROMPT_TEMPLATES)}",
    f"  Period: {len(PERIOD_PROMPT_TEMPLATES)}",
    f"  Busy: {len(BUSY_PROMPTS)}",
    f"  Empty: {len(EMPTY_PROMPTS)}",
    f"  Objects: {len(OBJECT_PROMPT_TEMPLATES)}",
]
print(*prompt_count_report, sep="\n")

In [None]:
# Experiment switches: they select which precomputed feature files are used.
CLIP_VERSION = "ViT-L/14"
USE_COSINE = False
JOINT_FEATURES = False

# Load data and features
clip_model, clip_preprocess = clip.load(CLIP_VERSION)

# Bind the CLIP model and device once so later cells only pass prompts, K and features.
get_topk = partial(get_topk, clip_model=clip_model, device=device)
get_topk_separate_prompts = partial(get_topk_separate_prompts, clip_model=clip_model, device=device)

# Candidate dataset locations (cluster path, then a local checkout).
# NOTE(review): presumably try_paths returns the first candidate that exists - confirm in lidarclip.helpers.
dataset_root = try_paths("/proj/nlp4adas/datasets/once", "/Users/s0000960/data/once/")

# Arguments common to every loader. The cells below only consume the
# annotation and metadata entries of each batch, hence skip_data=True.
_shared_loader_args = dict(batch_size=1, num_workers=8, skip_data=True, skip_anno_transform=True)

# Annotated validation frames (object boxes available).
anno_loader = build_anno_loader(dataset_root, clip_preprocess, split="val", skip_anno=False, **_shared_loader_args)

# Frames used without annotations: validation and test splits, chained together.
_val_loader = build_anno_loader(dataset_root, clip_preprocess, split="val", skip_anno=True, **_shared_loader_args)
_test_loader = build_anno_loader(dataset_root, clip_preprocess, split="test", skip_anno=True, **_shared_loader_args)
noanno_loader = MultiLoader([_val_loader, _test_loader])

# Feature files are named after the CLIP variant plus optional training-mode suffixes.
feature_version = CLIP_VERSION.lower().replace("/", "-")
if USE_COSINE:
    feature_version += "_cosine"
if JOINT_FEATURES:
    feature_version += "__joint-trained"
feature_root = try_paths("/proj/nlp4adas/features", "../features")


def load_features(split, modality):
    """Load one precomputed feature file onto the CPU.

    Args:
        split: split tag used in the file name ("val-anno", "val" or "test").
        modality: "img" or "lidar".

    Returns:
        Whatever object was saved in the file (used below as a tensor).

    map_location="cpu" lets files that were saved from CUDA tensors be read
    on machines without a GPU; callers move the result to `device` afterwards.
    """
    path = f"{feature_root}/once_{feature_version}_{split}_{modality}.pt"
    return torch.load(path, map_location="cpu")


# Features of the annotated validation frames.
anno_img_feats = load_features("val-anno", "img").to(device)
anno_lidar_feats = load_features("val-anno", "lidar").to(device)

# Features of the un-annotated frames: val followed by test, the same order
# in which the two loaders are passed to MultiLoader.
noanno_img_feats = torch.cat((load_features("val", "img"), load_features("test", "img")), dim=0).to(device)
noanno_lidar_feats = torch.cat((load_features("val", "lidar"), load_features("test", "lidar")), dim=0).to(device)

# Compute joint features
# Alternative (disabled): normalise each modality before summing.
# anno_joint_feats = anno_img_feats / anno_img_feats.norm(dim=1, keepdim=True) + anno_lidar_feats / anno_lidar_feats.norm(dim=1, keepdim=True)
# noanno_joint_feats = noanno_img_feats / noanno_img_feats.norm(dim=1, keepdim=True) + noanno_lidar_feats / noanno_lidar_feats.norm(dim=1, keepdim=True)
anno_joint_feats = anno_img_feats + anno_lidar_feats
noanno_joint_feats = noanno_img_feats + noanno_lidar_feats

# The masks built later are indexed by loader position, so every feature
# matrix must have exactly one row per loader item.
assert noanno_img_feats.shape[0] == len(noanno_loader), "un-annotated features do not match noanno_loader"
assert anno_img_feats.shape[0] == len(anno_loader), "annotated features do not match anno_loader"

In [None]:
# Build masks
# Each mask is a boolean vector over loader positions; True means the frame
# is a ground-truth positive for that name.
PERIODS = ("night", "day")
# Rebinds the WEATHERS imported from lidarclip.anno_loader: only these two are evaluated.
WEATHERS = ("sunny", "rainy")
BUSYNESS = ("busy", "empty")
NEARBY_CLASSES = [f"nearby {c}" for c in CLASSES]
NEARBY_CUTOFF = 10

# Thresholds of the busy/empty heuristic. Distances are the norm of the first
# three box3d entries. NOTE(review): presumably metres from the sensor - confirm.
VEHICLE_CLASSES = ("Car", "Bus", "Truck")
EMPTY_CLOSE_DIST = 15    # a vehicle this close always counts as one full object
EMPTY_MID_DIST = 40      # vehicles within this range count by their 2D box extent
EMPTY_MIN_BOX_EXTENT = 100  # box2d[2] - box2d[0]; NOTE(review): presumably width in pixels - confirm
BUSY_DIST = 60           # vehicles within this range count towards "busy"
BUSY_MIN_VEHICLES = 5

masks = {name: torch.zeros(len(anno_loader), dtype=torch.bool) for name in itertools.chain(CLASSES, NEARBY_CLASSES, BUSYNESS)}
masks.update({name: torch.zeros(len(noanno_loader), dtype=torch.bool) for name in itertools.chain(PERIODS, WEATHERS)})
anno_meta_masks = {name: torch.zeros(len(anno_loader), dtype=torch.bool) for name in itertools.chain(PERIODS, WEATHERS)}


def _flag_conditions(target_masks, idx, frame_meta):
    """Set the weather and period entries of `target_masks` for frame `idx`.

    Any period other than "night" is treated as "day". Weather values that
    are not in WEATHERS are left unflagged.
    """
    weather = frame_meta["weather"]
    period = "night" if frame_meta["period"] == "night" else "day"
    if weather in WEATHERS:
        target_masks[weather][idx] = True
    target_masks[period][idx] = True


# enumerate(tqdm(...)) rather than tqdm(enumerate(...)) so tqdm can read the
# loader length and show a total.
for i, (_, _, _, meta) in enumerate(tqdm(noanno_loader)):
    _flag_conditions(masks, i, meta[0])

for i, (_, _, anno, meta) in enumerate(tqdm(anno_loader)):
    # Update anno meta masks
    _flag_conditions(anno_meta_masks, i, meta[0])

    # Update anno masks
    num_busy = 0
    num_non_empty = 0
    for name, box2d, box3d in zip(anno[0]['names'], anno[0]['boxes_2d'], anno[0]['boxes_3d']):
        dist = box3d[:3].norm()
        if dist < NEARBY_CUTOFF:
            masks[f"nearby {name}"][i] = True
        if name in VEHICLE_CLASSES:
            # A vehicle closer than EMPTY_CLOSE_DIST also passes the
            # EMPTY_MID_DIST check and is counted twice. This cannot change
            # the outcome, because one full count already crosses the
            # "empty" threshold of 1. NOTE(review): confirm this is intended.
            if dist < EMPTY_CLOSE_DIST:
                num_non_empty += 1
            if dist < EMPTY_MID_DIST:
                if box2d[2] - box2d[0] > EMPTY_MIN_BOX_EXTENT:
                    num_non_empty += 1
                else:
                    num_non_empty += 0.5
            if dist < BUSY_DIST:
                num_busy += 1
        masks[name][i] = True
    if num_busy >= BUSY_MIN_VEHICLES:
        masks["busy"][i] = True
    if num_non_empty < 1:
        masks["empty"][i] = True

In [None]:
USE_ANNOS = True
NIGHT_ONLY = False

# Each entry is (names, prompt templates, needs_annos). needs_annos selects
# whether retrieval runs over the annotated or the un-annotated feature set.
subcategories = [
    (PERIODS, PERIOD_PROMPT_TEMPLATES, False),
    (WEATHERS, WEATHER_PROMPT_TEMPLATES, False),
]
if USE_ANNOS:
    subcategories += [
        (CLASSES, OBJECT_PROMPT_TEMPLATES, True),
        (NEARBY_CLASSES, OBJECT_PROMPT_TEMPLATES, True),
        # Busy/empty use fixed prompt lists per name instead of templates.
        (BUSYNESS, {"busy": BUSY_PROMPTS, "empty": EMPTY_PROMPTS}, True),
    ]
# Number of evaluated names across all subcategories.
total_num = sum(len(names) for names, _, _ in subcategories)

if NIGHT_ONLY:
    # Only evaluate on a subset of the annotated frames
    filter_mask = anno_meta_masks['night']
    if len(anno_img_feats) != len(filter_mask):
        # The features were already narrowed by an earlier run of this cell.
        # Filtering them again would fail with an opaque indexing error.
        raise RuntimeError(
            "Annotated features no longer match the annotated loader; "
            "re-run the feature-loading and mask-building cells before filtering again."
        )
    anno_img_feats = anno_img_feats[filter_mask]
    anno_lidar_feats = anno_lidar_feats[filter_mask]
    anno_joint_feats = anno_joint_feats[filter_mask]
    # Narrow exactly the masks that live on the annotated frames, selected by
    # name. Selecting by length alone would also hit the period/weather masks
    # if both loaders ever had the same number of frames.
    for mask_name in itertools.chain(CLASSES, NEARBY_CLASSES, BUSYNESS):
        mask = masks[mask_name]
        if len(mask) == len(filter_mask):  # skip masks that are already filtered
            masks[mask_name] = mask[filter_mask]

In [None]:
PRINT_ELEMENTS = True
PRINT_SUBCATEGORIES = False
MODALITIES = ("image", "lidar", "joint")


def _mean(values):
    """Arithmetic mean of a list, or nan when the list is empty."""
    return sum(values) / len(values) if values else float("nan")


# Precision at K, normalised by the best achievable number of hits
# (min(#positives, K)), reported per name, per subcategory and overall.
# Averages are taken over the scores collected in this cell, so they do not
# depend on any counter defined in another cell.
# for K in [1,10,100]:
for K in [10, 100]:
    overall_scores = {modality: [] for modality in MODALITIES}
    for subcategory, prompt_template, needs_annos in subcategories:
        subcat_scores = {modality: [] for modality in MODALITIES}
        if needs_annos:
            feats = (anno_img_feats, anno_lidar_feats, anno_joint_feats)
        else:
            feats = (noanno_img_feats, noanno_lidar_feats, noanno_joint_feats)
        for name in subcategory:
            mask = masks[name]
            num_positives = mask.sum().item()
            best_case = min(num_positives, K)
            if best_case == 0:
                # Without positives the score would be 0/0, and the resulting
                # nan would poison every average. Report it and leave it out.
                print(f"    P@{K} for {name}: undefined (0 matches), excluded from averages")
                continue
            if isinstance(prompt_template, dict):
                prompts = prompt_template[name]
            else:
                prompts = [prompt.format(name) for prompt in prompt_template]
            img_idxs, pc_idxs, joint_idxs = get_topk(prompts, K, *feats)
            scores = {
                modality: mask[idxs].sum().item() / best_case
                for modality, idxs in zip(MODALITIES, (img_idxs, pc_idxs, joint_idxs))
            }
            if PRINT_ELEMENTS:
                print(f"    P@{K} for {name}: ({num_positives} matches)")
                for modality in MODALITIES:
                    print(f"        {modality}: ", scores[modality])
            for modality in MODALITIES:
                subcat_scores[modality].append(scores[modality])
                overall_scores[modality].append(scores[modality])
        if PRINT_SUBCATEGORIES:
            print(f"Average P@{K} for {subcategory}:")
            for modality in MODALITIES:
                print(f"    {modality}: ", _mean(subcat_scores[modality]))
            print("=========================================")
    print(f"P@{K} overall:")
    for modality in MODALITIES:
        print(f"    {modality}: ", _mean(overall_scores[modality]))
    print("=========================================")
    print("=========================================")

In [None]:
# Compute the expected precision for random sampling
# For each name this is the fraction of positives in the pool it is retrieved from.
for subcategory, prompt_template, needs_annos in subcategories:
    # Deliberately not named `total_num`: that global holds the number of
    # evaluated names, and rebinding it here would corrupt it for any cell
    # that is re-run after this one.
    num_samples = len(anno_lidar_feats) if needs_annos else len(noanno_lidar_feats)
    for name in subcategory:
        num_positives = masks[name].sum().item()
        print(f"random score for {name}: {num_positives/num_samples:.2f}")