In [None]:
import torch
import sys
from tqdm import tqdm
# Add the parent directory to the import path (presumably so the
# `lidar_clippin` package imported by a later cell can be found).
sys.path.append('..')

# Use the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
import clip
from lidar_clippin.anno_loader import build_loader, CLASSES, PERIODS, WEATHERS
import os

# One sample per batch: the annotation loop in a later cell indexes `anno[0]` / `meta[0]`.
batch_size = 1

# Load CLIP on the same device the cached features and tokenised prompts are moved to.
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)

# Dataset location. Override without editing the notebook, e.g.
#   export ONCE_DATASET_ROOT=/proj/nlp4adas/datasets/once
# The fallback is the path this notebook effectively used before (the earlier
# assignment was immediately overwritten and never took effect).
dataset_root = os.environ.get("ONCE_DATASET_ROOT", "/Users/s0000960/data/once/")

# `skip_data=True` is forwarded to the project loader as-is.
# NOTE(review): presumably this skips loading images / point clouds and yields
# annotations + metadata only -- confirm in lidar_clippin.anno_loader.
loader = build_loader(
    dataset_root,
    clip_preprocess,
    batch_size=batch_size,
    num_workers=8,
    split="val",
    skip_data=True,
)

In [None]:
# Load the cached feature tensors straight onto the active device.
# `map_location` matters when the files were saved from a CUDA run and this
# notebook runs on a CPU-only machine: without it `torch.load` tries to restore
# the tensors on the original (unavailable) device and raises before `.to(device)`
# would ever run.
# NOTE(review): `torch.load` unpickles the file -- only load feature files you produced yourself.
img_feats = torch.load("img_feats_val.pt", map_location=device)
lidar_feats = torch.load("lidar_feats_val.pt", map_location=device)

In [None]:
def logit_img_txt(img_feat, txt_feat, model):
    """Compute scaled cosine-similarity logits between two sets of embeddings.

    Both inputs are L2-normalised along dim 1, so each row is treated as one
    embedding vector.

    Args:
        img_feat: 2-D tensor with one embedding per row (N rows).
        txt_feat: 2-D tensor with one embedding per row (M rows) and the same
            embedding width as ``img_feat``.
        model: object exposing a ``logit_scale`` tensor; its exponential is
            used as the multiplicative scale on the similarities.

    Returns:
        ``(logits_per_text, logits_per_image)`` where ``logits_per_image`` has
        shape (N, M) and ``logits_per_text`` is its transpose, shape (M, N).
        Both are float32.
    """
    # Unit-length rows turn the dot product below into a cosine similarity.
    img_unit = img_feat / img_feat.norm(dim=1, keepdim=True)
    txt_unit = txt_feat / txt_feat.norm(dim=1, keepdim=True)

    # The scale is stored in log space on the model.
    scale = model.logit_scale.exp().float()

    # Cast to float32 before the matmul so mixed-precision inputs are compatible.
    scaled_img = scale * img_unit.float()
    logits_per_image = scaled_img @ txt_unit.t().float()
    return logits_per_image.t(), logits_per_image

In [None]:
# Encode a single free-text query with CLIP's text encoder (no gradients needed).
text = clip.tokenize(["a photo of an animal on the road"]).to(device)
with torch.no_grad():
    text_features = clip_model.encode_text(text)

# Score the query against every cached lidar embedding and every cached image embedding.
logits_per_text_l, logits_per_pc = logit_img_txt(lidar_feats, text_features, clip_model)
logits_per_text_i, logits_per_img = logit_img_txt(img_feats, text_features, clip_model)

# Keep the best-matching samples for the (only) prompt in row 0.
# `torch.topk` raises if k exceeds the number of candidates, so clamp k to the
# number of cached samples (relevant for small subsets / debugging runs).
top_k = 16
pc_logits, pc_idxs = torch.topk(logits_per_text_l[0, :], min(top_k, logits_per_text_l.shape[1]))
img_logits, img_idxs = torch.topk(logits_per_text_i[0, :], min(top_k, logits_per_text_i.shape[1]))

In [None]:
# Automatically generate a number of prompts
import itertools
# Building blocks for the prompt grid: every object is paired with every environment.
objects = [
    "car",
    "person walking on the road",
    "semi-truck",
    "bus",
    "parked bicycle",
    "person riding a bike",
    "forest",
]
environments = [
    "on a rainy day",
    "to the left of the image",
    "on an empty road",
    "at night",
    "in winter",
]
# Cartesian product, objects varying slowest: len(objects) * len(environments) prompts.
prompts = [
    f"a photo of a {scene_object} {scene_environment}"
    for scene_object, scene_environment in itertools.product(objects, environments)
]


In [None]:
# Tokenise all generated prompts at once and move the token tensor to the active device.
text = clip.tokenize(prompts)
text = text.to(device)

# Inference only: run the text encoder without tracking gradients.
with torch.no_grad():
    text_features = clip_model.encode_text(text)

# One row of logits per prompt, against the cached lidar and image embeddings respectively.
(logits_per_text_l, logits_per_pc) = logit_img_txt(lidar_feats, text_features, clip_model)
(logits_per_text_i, logits_per_img) = logit_img_txt(img_feats, text_features, clip_model)

In [None]:
def _blank_mask(dtype):
    """Return a fresh all-zero 1-D tensor with one entry per loader batch."""
    return torch.zeros(len(loader), dtype=dtype)


# Per-class mask is an integer tensor (a later cell writes 1 / -1 into it);
# weather and period masks are plain boolean flags.
cls_mask = {cls_name: _blank_mask(torch.int32) for cls_name in CLASSES}
weather_mask = {weather_name: _blank_mask(torch.bool) for weather_name in WEATHERS}
period_mask = {period_name: _blank_mask(torch.bool) for period_name in PERIODS}


In [None]:
# Objects closer than this count as "near".
# NOTE(review): same length unit as the 3D box coordinates (presumably metres) -- confirm.
max_dist = 45

# Fill the per-frame masks from the annotations.
#   weather_mask / period_mask: True at frame i for that frame's weather / period.
#   cls_mask:  1 -> at least one instance of the class is nearer than `max_dist`
#             -1 -> the class is present, but only at `max_dist` or farther
#              0 -> the class does not appear in the frame
# tqdm wraps the loader itself (not `enumerate(...)`) so it can read `len(loader)`
# and display a total and an ETA.
# Only element 0 of each batch is read, which covers every sample because the
# loader is built with batch_size=1.
for i, (_, _, anno, meta) in enumerate(tqdm(loader)):
    frame_meta = meta[0]
    frame_anno = anno[0]
    weather_mask[frame_meta["weather"]][i] = True
    period_mask[frame_meta["period"]][i] = True
    # `boxes_2d` is kept in the zip (unused) so iteration length matches the original behaviour.
    for name, _box2d, box3d in zip(frame_anno['names'], frame_anno['boxes_2d'], frame_anno['boxes_3d']):
        # Euclidean norm of the first three box values.
        # NOTE(review): assumed to be the box centre (x, y, z) relative to the sensor -- confirm.
        # `as_tensor` accepts lists, arrays and tensors without the copy-construct warning.
        dist = torch.as_tensor(box3d[:3]).norm()
        if dist < max_dist:
            cls_mask[name][i] = 1
        elif cls_mask[name][i] == 0:
            # Only mark "far" when no near instance has been recorded for this frame.
            cls_mask[name][i] = -1