In [None]:
import itertools
import os
import sys

import clip
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm

sys.path.append('..')
from lidar_clippin.anno_loader import build_loader, CLASSES, WEATHERS
from lidar_clippin.helpers import MultiLoader

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
# Object prompts: the CLIP ImageNet prompt ensemble ("prior art"), plus a few
# scene-level phrasings. Each template has exactly one "{}" slot for the class name.
OBJECT_PROMPT_TEMPLATES = [
    'a bad photo of a {}.', 'a photo of many {}.', 'a sculpture of a {}.', 'a photo of the hard to see {}.',
    'a low resolution photo of the {}.', 'a rendering of a {}.', 'graffiti of a {}.', 'a bad photo of the {}.',
    'a cropped photo of the {}.', 'a tattoo of a {}.', 'the embroidered {}.', 'a photo of a hard to see {}.',
    'a bright photo of a {}.', 'a photo of a clean {}.', 'a photo of a dirty {}.', 'a dark photo of the {}.',
    'a drawing of a {}.', 'a photo of my {}.', 'the plastic {}.', 'a photo of the cool {}.',
    'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a painting of the {}.', 'a painting of a {}.',
    'a pixelated photo of the {}.', 'a sculpture of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.',
    'a plastic {}.', 'a photo of the dirty {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.',
    'a photo of the {}.', 'a good photo of the {}.', 'a rendering of the {}.', 'a {} in a video game.',
    'a photo of one {}.', 'a doodle of a {}.', 'a close-up photo of the {}.', 'a photo of a {}.',
    'the origami {}.', 'the {} in a video game.', 'a sketch of a {}.', 'a doodle of the {}.',
    'a origami {}.', 'a low resolution photo of a {}.', 'the toy {}.', 'a rendition of the {}.',
    'a photo of the clean {}.', 'a photo of a large {}.', 'a rendition of a {}.', 'a photo of a nice {}.',
    'a photo of a weird {}.', 'a blurry photo of a {}.', 'a cartoon {}.', 'art of a {}.',
    'a sketch of the {}.', 'a embroidered {}.', 'a pixelated photo of a {}.', 'itap of the {}.',
    'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a plushie {}.', 'a photo of the nice {}.',
    'a photo of the small {}.', 'a photo of the weird {}.', 'the cartoon {}.', 'art of the {}.',
    'a drawing of the {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'the plushie {}.',
    'a dark photo of a {}.', 'itap of a {}.', 'graffiti of the {}.', 'a toy {}.',
    'itap of my {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'a tattoo of the {}.',
    'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.',
    'this is the {} in the scene.', 'this is one {} in the scene.',
]

# Generate weather and period prompts by permuting these attributes
QUALITY_MODIFIERS = ['bad', 'good', 'clean', 'dirty', 'cropped', 'close-up']
FORMAT_MODIFIERS = ['photo',]
SCENE_MODIFIERS = ['environment', 'scene', 'road', 'street', 'intersection']
CAPTURE_MODIFIERS = ['taken', 'captured',]

# The templates are collected in sets to drop duplicate permutations and then
# stored as *sorted* lists. Sorting fixes the iteration order, because string
# hashing (and hence set order) is randomised per interpreter run.
_weather_templates = set()
for quality, fmt, scene, capture in itertools.product(QUALITY_MODIFIERS, FORMAT_MODIFIERS, SCENE_MODIFIERS, CAPTURE_MODIFIERS):
    # Example: a good photo of a rainy environment
    _weather_templates.add(f'a {quality} {fmt} of a {{}} {scene}.')
    _weather_templates.add(f'a {quality} {fmt} {capture} on a {{}} day.')
    _weather_templates.add(f'a {quality} {fmt} {capture} in a {{}} {scene}.')
    _weather_templates.add(f'a {quality} {fmt} of many things in a {{}} {scene}.')
WEATHER_PROMPT_TEMPLATES = sorted(_weather_templates)

_period_templates = set()
for quality, fmt, scene, capture in itertools.product(QUALITY_MODIFIERS, FORMAT_MODIFIERS, SCENE_MODIFIERS, CAPTURE_MODIFIERS):
    if quality in ('bright', 'dark'):
        # Do not bias day/night by brightness in prompt. This only matters if
        # 'bright'/'dark' are ever added to QUALITY_MODIFIERS.
        continue
    # Example: a good photo taken at night
    _period_templates.add(f'a {quality} {fmt} {capture} at {{}}.')
    # Example: a good photo of a scene taken at night
    _period_templates.add(f'a {quality} {fmt} of a {scene} {capture} at {{}}.')
    _period_templates.add(f'a {quality} {fmt} of many things in a {scene} {capture} at {{}}.')
    _period_templates.add(f'a {quality} {fmt} of the {scene} {capture} at {{}}.')
PERIOD_PROMPT_TEMPLATES = sorted(_period_templates)

# Busy/empty prompts are complete sentences (no "{}" slot).
_busy_prompts = set()
for quality, fmt in itertools.product(QUALITY_MODIFIERS, FORMAT_MODIFIERS):
    for busy_modifier in ("busy", "crowded", "full"):
        _busy_prompts.add(f'a {quality} {fmt} of extremely {busy_modifier} traffic during rush hour with a large number of nearby vehicles.')
BUSY_PROMPTS = sorted(_busy_prompts)

_empty_prompts = set()
for quality, fmt, scene in itertools.product(QUALITY_MODIFIERS, FORMAT_MODIFIERS, SCENE_MODIFIERS):
    for empty_modifier in ("empty", "deserted", "abandoned"):
        _empty_prompts.add(f'a {quality} {fmt} of a completely {empty_modifier} {scene} with no vehicles in sight.')
EMPTY_PROMPTS = sorted(_empty_prompts)

print("Num prompts per subcategory:")
print(f"  Weather: {len(WEATHER_PROMPT_TEMPLATES)}")
print(f"  Period: {len(PERIOD_PROMPT_TEMPLATES)}")
print(f"  Busy: {len(BUSY_PROMPTS)}")
print(f"  Empty: {len(EMPTY_PROMPTS)}")
print(f"  Objects: {len(OBJECT_PROMPT_TEMPLATES)}")

In [None]:
CLIP_VERSION = "ViT-B/32"
# CLIP_VERSION = "ViT-L/14"
USE_COSINE = False
# USE_COSINE = True

# Load data and features
batch_size = 1
clip_model, clip_preprocess = clip.load(CLIP_VERSION)
# Dataset root can be overridden with the ONCE_ROOT environment variable,
# e.g. ONCE_ROOT=/proj/nlp4adas/datasets/once on the cluster.
dataset_root = os.environ.get("ONCE_ROOT", "/Users/s0000960/data/once/")
anno_loader = build_loader(dataset_root, clip_preprocess, batch_size=batch_size, num_workers=8, split="val", skip_data=True, skip_anno=False)
_val_loader = build_loader(dataset_root, clip_preprocess, batch_size=batch_size, num_workers=8, split="val", skip_data=True, skip_anno=True)
_test_loader = build_loader(dataset_root, clip_preprocess, batch_size=batch_size, num_workers=8, split="test", skip_data=True, skip_anno=True)
noanno_loader = MultiLoader([_val_loader, _test_loader])
anno_dataset_for_vis, noanno_dataset_for_vis = None, None

# Cached feature files are named ../features/once_<model>[_cosine]_<split>_<modality>.pt
feature_tag = CLIP_VERSION.lower().replace("/", "-")
if USE_COSINE:
    feature_tag += "_cosine"


def _load_features(split_and_modality):
    """Load one cached feature tensor onto the CPU.

    map_location="cpu" keeps this working on CPU-only machines even if the
    tensors were saved from a GPU. Callers move the result to `device`.
    """
    return torch.load(f"../features/once_{feature_tag}_{split_and_modality}.pt", map_location="cpu")


anno_img_feats = _load_features("val-anno_img").to(device)
anno_lidar_feats = _load_features("val-anno_lidar").to(device)
# The unannotated bank is val followed by test, matching the order of noanno_loader.
noanno_img_feats = torch.cat((_load_features("val_img"), _load_features("test_img")), dim=0).to(device)
noanno_lidar_feats = torch.cat((_load_features("val_lidar"), _load_features("test_lidar")), dim=0).to(device)

# Compute joint features
# anno_joint_feats = anno_img_feats / anno_img_feats.norm(dim=1, keepdim=True) + anno_lidar_feats / anno_lidar_feats.norm(dim=1, keepdim=True)
# noanno_joint_feats = noanno_img_feats / noanno_img_feats.norm(dim=1, keepdim=True) + noanno_lidar_feats / noanno_lidar_feats.norm(dim=1, keepdim=True)
anno_joint_feats = anno_img_feats + anno_lidar_feats
noanno_joint_feats = noanno_img_feats + noanno_lidar_feats

# Feature rows are indexed by loader position when building/evaluating masks,
# so every bank must line up with its loader.
assert noanno_img_feats.shape[0] == len(noanno_loader)
assert noanno_lidar_feats.shape[0] == noanno_img_feats.shape[0]
assert anno_img_feats.shape[0] == anno_lidar_feats.shape[0] == len(anno_loader)

In [None]:
# Build masks
PERIODS = ("night", "day")
WEATHERS = ("sunny", "rainy")
BUSYNESS = ("busy", "empty")
NEARBY_CLASSES = [f"nearby {c}" for c in CLASSES]

masks = {name: torch.zeros(len(anno_loader), dtype=torch.bool) for name in itertools.chain(CLASSES, NEARBY_CLASSES, BUSYNESS)}
masks.update({name: torch.zeros(len(noanno_loader), dtype=torch.bool) for name in itertools.chain(PERIODS, WEATHERS)})
for i, (_, _, anno, meta) in tqdm(enumerate(noanno_loader)):
    weather = meta[0]["weather"]
    period = meta[0]["period"]
    if period != "night":
        period = "day"
    if weather in WEATHERS:
        masks[weather][i] = True
    masks[period][i] = True

for i, (_, _, anno, meta) in tqdm(enumerate(anno_loader)):
    num_busy = 0
    num_non_empty = 0
    for name, box2d, box3d in zip(anno[0]['names'], anno[0]['boxes_2d'], anno[0]['boxes_3d']):
        dist = torch.tensor(box3d[:3]).norm()
        assert box2d[2] > box2d[0]
        if dist < 15:
            masks[f"nearby {name}"][i] = True
        if name in ("Car", "Bus", "Truck"):
            if dist < 15:
                num_non_empty += 1
            if dist < 40:
                if box2d[2] - box2d[0] > 100:
                    num_non_empty += 1
                else:
                    num_non_empty += 0.5
            if dist < 60:
                num_busy += 1
        masks[name][i] = True
    if num_busy >= 5:
        masks["busy"][i] = True
    if num_non_empty < 1:
        masks["empty"][i] = True

In [None]:
###  Helper functions ###

from typing import Callable, Iterable, List


def logit_img_txt(img_feat, txt_feat, model):
    """Scaled cosine-similarity logits between two sets of embeddings.

    Rows of both inputs are L2-normalised (along dim=1) and their dot products
    are multiplied by ``exp(model.logit_scale)``, as in CLIP's forward pass.

    Args:
        img_feat: 2-D tensor with one embedding per row (image, LiDAR or joint).
        txt_feat: 2-D tensor with one text embedding per row.
        model: object exposing a ``logit_scale`` tensor (e.g. a CLIP model).

    Returns:
        ``(logits_per_text, logits_per_image)``. ``logits_per_image`` has shape
        ``(len(img_feat), len(txt_feat))`` and ``logits_per_text`` is its
        transpose. Both are float32.
    """
    scale = model.logit_scale.exp().float()
    img_unit = (img_feat / img_feat.norm(dim=1, keepdim=True)).float()
    txt_unit = (txt_feat / txt_feat.norm(dim=1, keepdim=True)).float()
    per_image = scale * img_unit @ txt_unit.t()
    return per_image.t(), per_image

def get_topk(prompts, k, img_feats, lidar_feats, joint_feats):
    """Retrieve the top-``k`` samples for a prompt ensemble, separately per modality.

    The prompts are encoded with the global ``clip_model``. Their text embeddings
    are summed into a single query, each feature bank is scored against it with
    ``logit_img_txt``, and each bank is ranked independently.

    Args:
        prompts: iterable of prompt strings passed to ``clip.tokenize``.
        k: number of indices to return per modality. It is clamped to the size
            of the smallest bank so small or subsampled banks do not make
            ``torch.topk`` raise.
        img_feats: image feature bank, one row per sample.
        lidar_feats: LiDAR feature bank, one row per sample.
        joint_feats: joint feature bank, one row per sample.

    Returns:
        ``(img_idxs, pc_idxs, joint_idxs)``: NumPy arrays of row indices into
        the banks, highest score first.
    """
    text = clip.tokenize(prompts).to(device)
    with torch.no_grad():
        text_features = clip_model.encode_text(text)
        # Prompt ensembling: sum the per-prompt embeddings into a single query.
        text_features = text_features.sum(dim=0, keepdim=True)
    logits_per_text_i, _ = logit_img_txt(img_feats, text_features, clip_model)
    logits_per_text_l, _ = logit_img_txt(lidar_feats, text_features, clip_model)
    logits_per_text_j, _ = logit_img_txt(joint_feats, text_features, clip_model)

    k = min(k, img_feats.shape[0], lidar_feats.shape[0], joint_feats.shape[0])
    _, img_idxs = torch.topk(logits_per_text_i[0, :], k)
    _, pc_idxs = torch.topk(logits_per_text_l[0, :], k)
    _, joint_idxs = torch.topk(logits_per_text_j[0, :], k)

    # Alternatives explored for the joint ranking (kept for reference):
    # Rank separately and then fuse rankings
    # pc_ranking = torch.argsort(torch.argsort(logits_per_text_l[0,:]))
    # img_ranking = torch.argsort(torch.argsort(logits_per_text_i[0,:]))
    # joint_ranking = pc_ranking * img_ranking
    # _, joint_idxs = torch.topk(joint_ranking, k)
    # Perform reranking. first round selects top 1% of the candidates, second round selects top K remaining candidates
    # second_round_scores = -torch.ones(lidar_feats.shape[0])
    # _, first_rank_winners = torch.topk(logits_per_text_i[0,:], 100*k)
    # second_round_scores[first_rank_winners] = logits_per_text_l[0,first_rank_winners]
    # _, first_rank_winners = torch.topk(logits_per_text_l[0,:], 100*k)
    # second_round_scores[first_rank_winners] = logits_per_text_i[0,first_rank_winners]
    # _, joint_idxs = torch.topk(second_round_scores, k)

    # The banks live on `device`; Tensor.numpy() raises for CUDA tensors.
    return img_idxs.cpu().numpy(), pc_idxs.cpu().numpy(), joint_idxs.cpu().numpy()


def get_topk_separate_prompts(image_prompts, lidar_prompts, k, img_feats, lidar_feats):
    """Top-``k`` retrieval with different prompt ensembles for image and LiDAR.

    Each ensemble is encoded with the global ``clip_model`` and summed into one
    query. The image bank is scored with the image query and the LiDAR bank
    with the LiDAR query. The "joint" ranking is the sum of the two logit
    vectors (late fusion); no joint feature bank is used here.

    Args:
        image_prompts: prompt strings used to score ``img_feats``.
        lidar_prompts: prompt strings used to score ``lidar_feats``.
        k: number of indices per modality. It is clamped to the bank size so
            subsampled banks do not make ``torch.topk`` raise.
        img_feats: image feature bank, one row per sample.
        lidar_feats: LiDAR feature bank, one row per sample, in the same order.

    Returns:
        ``(img_idxs, pc_idxs, joint_idxs)``: NumPy arrays of row indices into
        the banks, highest score first.
    """
    with torch.no_grad():
        text_features_image = clip_model.encode_text(clip.tokenize(image_prompts).to(device)).sum(dim=0, keepdim=True)
        text_features_lidar = clip_model.encode_text(clip.tokenize(lidar_prompts).to(device)).sum(dim=0, keepdim=True)
    logits_per_text_i, _ = logit_img_txt(img_feats, text_features_image, clip_model)
    logits_per_text_l, _ = logit_img_txt(lidar_feats, text_features_lidar, clip_model)
    logits_per_text_j = logits_per_text_i + logits_per_text_l

    k = min(k, img_feats.shape[0], lidar_feats.shape[0])
    _, pc_idxs = torch.topk(logits_per_text_l[0, :], k)
    _, img_idxs = torch.topk(logits_per_text_i[0, :], k)
    _, joint_idxs = torch.topk(logits_per_text_j[0, :], k)

    # The banks live on `device`; Tensor.numpy() raises for CUDA tensors.
    return img_idxs.cpu().numpy(), pc_idxs.cpu().numpy(), joint_idxs.cpu().numpy()



In [None]:
K = 10
VERBOSE = False
USE_ANNOS = True

# Each entry: (category names, prompt templates or {name: prompts}, evaluate on the annotated split?)
subcategories = [(PERIODS, PERIOD_PROMPT_TEMPLATES, False), (WEATHERS, WEATHER_PROMPT_TEMPLATES, False)]
if USE_ANNOS:
    subcategories += [
        (CLASSES, OBJECT_PROMPT_TEMPLATES, True),
        (NEARBY_CLASSES, OBJECT_PROMPT_TEMPLATES, True),
        (BUSYNESS, {"busy": BUSY_PROMPTS, "empty": EMPTY_PROMPTS}, True),
    ]

# R@K = (#positives among the top-K) / min(#positives, K), computed per modality.
# Subcategory and overall scores are unweighted means over categories.
all_scores = []  # one (image, lidar, joint) tuple per evaluated category
for subcategory, prompt_template, use_annos in subcategories:
    if use_annos:
        feats = (anno_img_feats, anno_lidar_feats, anno_joint_feats)
    else:
        feats = (noanno_img_feats, noanno_lidar_feats, noanno_joint_feats)
    subcat_scores = []
    for name in subcategory:
        num_positives = masks[name].sum().item()
        if num_positives == 0:
            # R@K would be 0/0 = NaN here, which would turn every average into NaN.
            print(f"WARNING: no positives for {name}, excluded from R@{K} averages")
            continue
        if isinstance(prompt_template, dict):
            prompts = prompt_template[name]
        else:
            prompts = [prompt.format(name) for prompt in prompt_template]
        img_idxs, pc_idxs, joint_idxs = get_topk(prompts, K, *feats)
        best_case = min(num_positives, K)
        image_score = masks[name][img_idxs].sum().item() / best_case
        lidar_score = masks[name][pc_idxs].sum().item() / best_case
        joint_score = masks[name][joint_idxs].sum().item() / best_case
        subcat_scores.append((image_score, lidar_score, joint_score))
        if VERBOSE:
            print(f"R@{K} for {name}: ({num_positives} matches)")
            print("    image: ", image_score)
            print("    lidar: ", lidar_score)
            print("    joint: ", joint_score)
    all_scores.extend(subcat_scores)
    if VERBOSE and subcat_scores:
        image_subcat, lidar_subcat, joint_subcat = np.mean(subcat_scores, axis=0)
        print(f"Average R@{K} for {subcategory}:")
        print("    image: ", image_subcat)
        print("    lidar: ", lidar_subcat)
        print("    joint: ", joint_subcat)
        print("=========================================")

if all_scores:
    image_overall, lidar_overall, joint_overall = np.mean(all_scores, axis=0)
else:
    image_overall = lidar_overall = joint_overall = float("nan")
print(f"R@{K} overall:")
print("    image: ", image_overall)
print("    lidar: ", lidar_overall)
print("    joint: ", joint_overall)

In [None]:
# Chance level per category: the fraction of positives among the candidates that
# category is retrieved from. Annotated categories (objects, nearby objects,
# busyness) index the anno split; period/weather index the val+test split, so the
# denominator must be the length of that category's own mask.
for subcategory, _, _ in subcategories:
    for name in subcategory:
        num_positives = masks[name].sum().item()
        num_candidates = len(masks[name])
        print(f"random score for {name}: {num_positives/num_candidates:.2f}")

In [None]:
# Build (once) the datasets that return actual images/point clouds for plotting.
# The loaders built earlier use skip_data=True and only provide metadata.
if any(ds is None for ds in (anno_dataset_for_vis, noanno_dataset_for_vis)):
    split = "val"
    use_annos = True
    anno_dataset_for_vis = build_loader(
        dataset_root, clip_preprocess,
        batch_size=batch_size, num_workers=1, split=split,
        skip_data=False, skip_anno=not use_annos,
    ).dataset
    # Same val-then-test order as the unannotated feature bank.
    noanno_dataset_for_vis = MultiLoader([
        build_loader(dataset_root, clip_preprocess, batch_size=1, num_workers=1, split=part, skip_data=False, skip_anno=True)
        for part in ("val", "test")
    ])
# Per-channel mean/std of CLIP's image preprocessing, used to undo normalisation for display.
means = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cpu")
stds = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cpu")


In [None]:
K2 = 6
# Qualitative retrieval: rank the unannotated val+test bank for a hand-written
# prompt ensemble and show the top-K2 hits per modality in one figure.
# Earlier prompt experiments are kept commented out below.
# img_idxs, pc_idxs, joint_idxs = get_topk(EMPTY_PROMPTS, K2, anno_img_feats, anno_lidar_feats, anno_joint_feats)
# image_prompts = ["a photo with large water droplets on the lens", "an extremely blurry photo where no details can be discerned", "an extremely rainy scene with huge water sprays"]
image = ["a nearby semi truck", "a semi truck in the middle of the road", "a semi truck on the street"]  # NOTE(review): unused; image_prompts below is what is queried
# image_prompts = ["a photo with extreme glare/", "an extremely corrupted photo where no details can be discerned", "an extremely blinded and glared photo", "a corrupted photo with extreme glare and no objects can be seen"]
# lidar_prompts = ["a nearby car", "a car in the middle of the road", "a car on the street"]
# image_prompts = ["a nearby pedestrian", "a semi large number of nearby people", "a person in the middle of the road", "a person on the street"]
# image_prompts = ["a photo of a dog", "animal", "pet", "person walking his dog", "dog crossing the street"]
# image_prompts = ["a photo of a nearby tuk-tuk, a type of three-wheeler", "a three-wheeler in the middle of the road", "a tuk-tuk, three-wheeler, on the street"]
# image_prompts = [p.format("night") for p in PERIOD_PROMPT_TEMPLATES]
# image_prompts = [p.format("foggy not wet") for p in WEATHER_PROMPT_TEMPLATES]
image_prompts = ["a small child on the sidewalk", "a kid walking with his parents", "a very tiny person"]
lidar_prompts = image_prompts
SUBSAMPLING = 1  # avoid picking extremely similar samples
OFFSET = 0
# "joint" here is the sum of image and LiDAR logits (late fusion), not the joint feature bank.
img_idxs, pc_idxs, joint_idxs = get_topk_separate_prompts(image_prompts, lidar_prompts, K2, noanno_img_feats[OFFSET::SUBSAMPLING], noanno_lidar_feats[OFFSET::SUBSAMPLING])
dataset_for_vis = noanno_dataset_for_vis
# Map positions in the subsampled bank back to indices into the full bank / dataset.
pc_idxs = pc_idxs * SUBSAMPLING + OFFSET
img_idxs = img_idxs * SUBSAMPLING + OFFSET
joint_idxs = joint_idxs * SUBSAMPLING + OFFSET

PLOT_IMAGE = True
PLOT_LIDAR = True
PLOT_JOINT = False

rows = int(PLOT_IMAGE) + int(PLOT_LIDAR) + int(PLOT_JOINT)

# One tile per retrieved sample; tiles are placed manually with add_axes([left, bottom, w, h]).
fig = plt.figure(figsize=(10*K2,10*rows), dpi=200)
width, height = 1/K2, 1/(rows+0.1)
# ax = fig.add_axes([0, 0, 1, 1])

# NOTE(review): row placement is hand-tuned rather than derived consistently from
# curr_row. The image row sits at bottom=1/rows, the lidar row at bottom=curr_row/rows,
# and the joint row at bottom=0. With PLOT_JOINT=False (rows=2) the lidar row starts
# at bottom=1.0, i.e. above the figure canvas; it presumably only shows because
# inline rendering uses a tight bbox. Confirm before saving this figure to a file.
curr_row = 1
if PLOT_IMAGE:
    for i, idx in enumerate(img_idxs):
        ax = fig.add_axes([i/K2, 1/rows, width, height])
        # Undo CLIP normalisation: CHW -> HWC, then x*std + mean.
        ax.imshow((dataset_for_vis[idx][0].permute(1,2,0)*stds + means).numpy())
        ax.axis('off')
    curr_row+=1

if PLOT_LIDAR:
    # Shows the *camera image* of the samples retrieved with LiDAR features.
    for i, idx in enumerate(pc_idxs):
        ax = fig.add_axes([i/K2, curr_row/rows, width, height])
        ax.imshow((dataset_for_vis[idx][0].permute(1,2,0)*stds + means).numpy())
        ax.axis('off')

    curr_row+=1

if PLOT_JOINT:
    for i, idx in enumerate(joint_idxs):
        img, pc = dataset_for_vis[idx][:2]
        # The joint row is pinned to bottom=0 rather than derived from curr_row;
        # the original author found this necessary to get the desired row ordering.
        ax = fig.add_axes([i/K2, 0, width, height])
        ax.imshow((img.permute(1,2,0)*stds + means).numpy())
        ax.axis('off')
        
        # Small inset (1/3 of the tile size) offset towards the tile's upper right, holding a top-down point cloud view.
        ax = fig.add_axes([i/K2 + 1.96/K2/3, 1.93/3/rows, width/3, height/3])
        # draw a white box around the image
        rect = patches.Rectangle((-10,4), 20, 16, alpha=1.0, facecolor='white')
        ax.add_patch(rect)
        # draw the point cloud: keep points with column 0 < 20 and -10 < column 1 < 10
        # (presumably forward/lateral coordinates in metres; confirm).
        pc = pc[pc[:,0] < 20]
        pc = pc[pc[:,1] < 10]
        pc = pc[pc[:,1] > -10]
        # Colour by column 3 (presumably intensity; confirm), gamma-compressed with **0.3.
        col = pc[:,3]
        # Plot -column1 horizontally and column0 vertically, so column 0 points up.
        ax.scatter(-pc[:,1], pc[:,0], s=0.1, c=col**0.3, cmap="coolwarm")
        ax.axis("scaled")
        ax.axis("off")
        ax.set_ylim(0, 20)
        ax.set_xlim(-10, 10)
    curr_row+=1

print(pc_idxs)
print(img_idxs)
print(joint_idxs)

In [None]:
# Save the camera image of each LiDAR-retrieved sample to its own PNG.
# Each image gets its own figure, closed after saving. Drawing repeatedly onto the
# implicit current axes would stack every previous image underneath and
# re-render all of them on every save.
SAVE_PREFIX = "headlights"
for i, idx in enumerate(pc_idxs):
    tile_fig, tile_ax = plt.subplots()
    # Undo CLIP normalisation; clamp to [0, 1] (imshow would clip anyway, with a warning).
    tile_img = (dataset_for_vis[idx][0].permute(1,2,0)*stds + means).clamp(0, 1)
    tile_ax.imshow(tile_img.numpy())
    tile_ax.axis('off')
    tile_fig.savefig("{}_{}.png".format(SAVE_PREFIX, i), bbox_inches='tight', pad_inches=0)
    plt.close(tile_fig)

# Zero-Shot Classification

In [None]:
# zero_shot_classes =  ["car", "bus", "truck", "person", "bike or moped",  "animal", "tree or bush", "three-wheeler"]
# Reversed so the bar chart (bars drawn bottom-up from y=0) lists "car" at the top.
zero_shot_classes = list(reversed(["car", "bus", "truck", "person", "two-wheeler",  "animal", "three-wheeler"]))

# One text embedding per class: the sum of the class's embeddings over all object prompt templates.
per_class_embeddings = []
for class_label in tqdm(zero_shot_classes, "computing class embeddings..."):
    class_tokens = clip.tokenize([template.format(class_label) for template in OBJECT_PROMPT_TEMPLATES]).to(device)
    with torch.no_grad():
        per_class_embeddings.append(clip_model.encode_text(class_tokens).sum(axis=0, keepdim=True))
class_embeddings = torch.cat(per_class_embeddings, dim=0)

In [None]:
# Zero-shot classification demo: for each selected sample from the unannotated
# val+test bank, draw the camera image, a top-down point cloud inset, and the
# per-modality class probabilities as a horizontal bar chart.
# SAMPLE_IDXS = [
#     61687,  # greenery
#     0,      # bus
#     1947,   # car
#     8601,   # bus-car mix
#     88888,  # nearby car
#     58777,  # person biking
#     113859, # moped
#     13406,  # person walking unclear object
#     96150,  # person with umbrella
#     34170,  # person with backpack
#     58854,  # doggo
# ]
SAMPLE_IDXS = [109737]
dataset_for_vis = noanno_dataset_for_vis
lidar_feat = noanno_lidar_feats[SAMPLE_IDXS]
image_feat = noanno_img_feats[SAMPLE_IDXS]
joint_feat = noanno_joint_feats[SAMPLE_IDXS]
with torch.no_grad():
    # logit_img_txt(...)[0] has shape (num_classes, num_samples); softmax over classes.
    # .cpu(): the features live on `device`, and matplotlib cannot plot CUDA tensors.
    lidar_scores_all = logit_img_txt(lidar_feat, class_embeddings, clip_model)[0].softmax(0).cpu()
    image_scores_all = logit_img_txt(image_feat, class_embeddings, clip_model)[0].softmax(0).cpu()
    joint_scores_all = logit_img_txt(joint_feat, class_embeddings, clip_model)[0].softmax(0).cpu()

for i, sample_idx in enumerate(SAMPLE_IDXS):
    img, pc = dataset_for_vis[sample_idx][:2]
    lidar_scores = lidar_scores_all[:, i]
    image_scores = image_scores_all[:, i]
    joint_scores = joint_scores_all[:, i]

    fig = plt.figure(figsize=(7,5), dpi=150)

    # Draw the image (left 5/7 of the figure). Undo CLIP normalisation and clamp to
    # [0, 1], which imshow would otherwise do with a warning.
    ax = fig.add_axes([0, 0, 5/7, 1])
    ax.imshow((img.permute(1,2,0)*stds + means).clamp(0, 1).numpy())
    ax.axis('off')

    # Draw the point cloud as an inset over the upper part of the image.
    ax = fig.add_axes([0.455, 0.698, 0.3, 0.3])
    rect = patches.Rectangle((-10,0), 20, 20, alpha=1.0, facecolor='white')
    ax.add_patch(rect)
    # Keep points with column 0 < 20 and -10 < column 1 < 10.
    pc = pc[pc[:,0] < 20]
    pc = pc[pc[:,1] < 10]
    pc = pc[pc[:,1] > -10]
    col = pc[:,3]
    ax.scatter(-pc[:,1], pc[:,0], s=0.1, c=col**0.3, cmap="coolwarm")
    ax.axis("scaled")
    ax.axis("off")
    ax.set_ylim(0, 20)
    ax.set_xlim(-10, 10)

    # Draw the class bar chart
    ax = fig.add_axes([5/7+0.001, 0.05, 2/7-0.001, 0.9])
    # plot class scores as a horizontal bar chart with image lidar and joint scores side by side
    # top down not down up
    ax.barh(np.arange(len(zero_shot_classes)) + 0.34, image_scores, height=0.15, color='#8DD376', label='image')
    ax.barh(np.arange(len(zero_shot_classes)) + 0.17, lidar_scores, height=0.15, color='#DE8B8A', label='lidar')
    ax.barh(np.arange(len(zero_shot_classes)) , joint_scores, height=0.15, color='#B172E0', label='joint')
    # Class labels sit just above each group of three bars and are drawn inside the axes (negative pad).
    ax.set_yticks(np.arange(len(zero_shot_classes))+0.57)
    ax.set_yticklabels(zero_shot_classes)
    ax.set_xticks([0.1, 0.4, 0.7, 1.0])
    ax.yaxis.tick_right()
    ax.tick_params(axis="y", direction="in", pad=-142, labelsize=10, length=0)
    ax.tick_params(axis="x", direction='inout', length=5)
    ax.set_xlim(0, 1)
    ax.spines.right.set_visible(False)
    ax.spines.left.set_visible(False)
    ax.spines.top.set_visible(False)
    ax.set_title("Classification scores")
    ax.legend()
    fig.savefig("{}.png".format(sample_idx), bbox_inches='tight', pad_inches=0)


In [None]:
# (num_classes, num_samples): one softmax column per entry of SAMPLE_IDXS.
lidar_scores_all.size()

In [None]:
ID_TO_SAVE = 109737
# Export one sample's LiDAR feature for reuse outside this notebook.
# Taken from the in-memory val+test bank instead of re-reading the files from disk,
# so it always matches CLIP_VERSION / USE_COSINE from the loading cell.
# .cpu() makes the file loadable on any machine. .clone() copies just this row;
# saving a view would serialise the storage of the whole bank.
save_feat = noanno_lidar_feats[ID_TO_SAVE].detach().cpu().clone()
torch.save(save_feat, f"lidar_feat_{ID_TO_SAVE}.pt")

In [None]:

# Camera image of the showcase sample, de-normalised (CHW -> HWC, x*std + mean).
front_page_img = dataset_for_vis[ID_TO_SAVE][0].permute(1, 2, 0) * stds + means
plt.imshow(front_page_img)
plt.axis("off")
plt.savefig("front_page_image.png", bbox_inches='tight', pad_inches=0)

In [None]:
plt.figure(figsize=(5,5), dpi=200)

# Top-down view of the showcase sample's point cloud, cropped to the same window
# as the other figures (column 0 < 20, -10 < column 1 < 10).
points = dataset_for_vis[ID_TO_SAVE][1].numpy()
in_window = (points[:, 0] < 20) & (points[:, 1] < 10) & (points[:, 1] > -10)
points = points[in_window]
plt.scatter(-points[:, 1], points[:, 0], s=0.1, c=points[:, 3] ** 0.3, cmap="coolwarm")
plt.axis("scaled")
plt.axis("off")
plt.ylim(0, 20)
plt.xlim(-10, 10)
plt.savefig("front_page_lidar.png", bbox_inches='tight', pad_inches=0)

In [None]:
# LiDAR class probabilities of the last plotted sample, next to their class names.
for shown in (lidar_scores, zero_shot_classes):
    print(shown)