In [None]:
import torch
import sys
import itertools
from functools import partial
from typing import Dict

from tqdm import tqdm
import clip
sys.path.append('..')
from lidarclip.anno_loader import build_anno_loader, CLASSES, WEATHERS
from lidarclip.helpers import MultiLoader, try_paths, logit_img_txt
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
from lidarclip.prompts import OBJECT_PROMPT_TEMPLATES
# Report how many prompt templates exist for each prompt subcategory.
print(
    "Num prompts per subcategory:",
    f"  Objects: {len(OBJECT_PROMPT_TEMPLATES)}",
    sep="\n",
)

In [None]:
CLIP_VERSION = "ViT-L/14"

# Load data and features
batch_size = 1
# Load CLIP on the same device that the pre-computed features are mapped to below.
clip_model, clip_preprocess = clip.load(CLIP_VERSION, device=device)
# Feature files are named after the CLIP variant, lower-cased with "/" replaced by "-".
feature_version = CLIP_VERSION.lower().replace("/", "-")
# NOTE(review): try_paths presumably returns the first candidate that exists — confirm in lidarclip.helpers.
feature_root = try_paths("/proj/nlp4adas/features", "../features")
# NOTE(review): torch.load unpickles the file; only point this at feature files you trust.
bev_feats = torch.load(f"{feature_root}/once_{feature_version}_val_lidar_objs_bev_debug.pt", map_location=device)
dataset_root = try_paths("/proj/nlp4adas/datasets/once", "/Users/s0000960/data/once/")
# batch_size is taken from the variable above instead of a second hard-coded literal.
loader = build_anno_loader(
    dataset_root,
    clip_preprocess,
    batch_size=batch_size,
    num_workers=0,
    split="val",
    skip_data=False,
    skip_anno=False,
)
print(bev_feats.shape)
# Take the first batch and keep only its first sample; the point cloud is converted to NumPy.
img, pc, anno, _ = next(iter(loader))
img, pc, anno = img[0], pc[0].numpy(), anno[0]

In [None]:
CATEGORIES = CLASSES


@torch.no_grad()
def gen_cls_embedding(cls_name: str) -> torch.Tensor:
    """Encode a class name with the CLIP text encoder.

    Args:
        cls_name: Class name; it is used verbatim as the only text prompt.

    Returns:
        The text features summed over the prompt axis with that axis kept,
        so the leading dimension of the result is 1.
    """
    print(f"Generating embedding for {cls_name}")
    # Only the bare class name is used as a prompt. Formatting it into each of
    # OBJECT_PROMPT_TEMPLATES would give a multi-prompt variant instead.
    prompts = [cls_name]
    tokens = clip.tokenize(prompts).to(device)
    text_features = clip_model.encode_text(tokens)
    # Collapse the prompt axis; with a single prompt this leaves the values unchanged.
    return text_features.sum(dim=0, keepdim=True)
# One text embedding per category, keyed by category name in CATEGORIES order.
cls_embeddings = {
    category: gen_cls_embedding(category)
    for category in CATEGORIES
}
print("Generated embeddings for: ", list(cls_embeddings))
# Stack the per-category embeddings row-wise into a single tensor.
cls_embeddings_pt = torch.vstack(tuple(cls_embeddings.values()))

In [None]:
from collections import defaultdict
from einops import rearrange
# Flatten the BEV grid so that every cell becomes one row of features.
bev_feats_flat = rearrange(bev_feats, "1 h w c -> (h w) c")
# NOTE(review): the rearrange below treats dim 0 of the first return value as the
# class axis and dim 1 as the flattened cells — confirm against logit_img_txt.
bev_cls_logits, _ = logit_img_txt(bev_feats_flat, cls_embeddings_pt, clip_model)
# Softmax over dim 0, then unfold the flattened cell axis back into the grid.
bev_cls_scores = rearrange(
    bev_cls_logits.softmax(dim=0),
    "c (h w) -> h w c",
    h=bev_feats.shape[1],
    w=bev_feats.shape[2],
)
print(bev_cls_scores.shape)


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Per-channel mean and standard deviation used to undo the image normalisation
# before display (applied as `img * stds + means` when plotting).
means = torch.tensor(
    (0.48145466, 0.4578275, 0.40821073),
    device="cpu",
)
stds = torch.tensor(
    (0.26862954, 0.26130258, 0.27577711),
    device="cpu",
)


def plot_weighted_cloud(pc, weight, img, plot_weight_grid=False, use_colorbar=False, cmap="coolwarm", half_cloud=False, title=None, voxel_size=0.5):
    """Plot a point cloud coloured by a 2-D weight grid, followed by the image.

    The cloud is cropped to the plotted region, every remaining point looks up
    its colour in ``weight`` through the grid cell it falls into, and the
    result is drawn as a scatter plot. The image is then shown in a second
    figure after undoing its normalisation with the module-level ``means`` and
    ``stds``.

    Args:
        pc: NumPy array of points; column 0 is used as x and column 1 as y.
        weight: 2-D grid indexed as ``weight[y_idx, x_idx]``. A torch tensor
            (on any device) or any array-like is accepted.
        img: Image tensor laid out as ``(c, h, w)``.
        plot_weight_grid: If true, also show ``weight.T ** 0.5`` as a matrix plot.
        use_colorbar: If true, add a colorbar to the scatter plot.
        cmap: Matplotlib colormap name for the scatter plot.
        half_cloud: If true, crop to x < 20 and |y| < 10 instead of x < 40 and
            |y| < 20. ``weight`` has to cover the selected region.
        title: Optional title for the scatter plot.
        voxel_size: Edge length of one grid cell, in the same units as ``pc``.
    """
    # Extent of the crop: x in [0, max_x), y in (-max_y, max_y).
    if half_cloud:
        max_x, max_y, x_limits = 20, 10, (-10, 10)
    else:
        max_x, max_y, x_limits = 40, 20, (-13, 13)

    # Work on a host-side NumPy grid so that indexing with NumPy index arrays
    # and plotting also work when the weights live on an accelerator.
    if isinstance(weight, torch.Tensor):
        weight = weight.detach().cpu().numpy()
    else:
        weight = np.asarray(weight)

    # The grid starts at x = 0, so points with negative x have no cell. Without
    # the lower bound their negative indices would wrap around to the far side
    # of the grid (or raise IndexError) and distort the colour normalisation.
    in_range = (
        (pc[:, 0] >= 0)
        & (pc[:, 0] < max_x)
        & (pc[:, 1] < max_y)
        & (pc[:, 1] > -max_y)
    )
    pc = pc[in_range]
    x_idx = (pc[:, 0] / voxel_size).astype(np.int32)
    # Shift y so that -max_y maps to row 0 of the grid.
    y_idx = ((pc[:, 1] + max_y) / voxel_size).astype(np.int32)

    plt.figure(dpi=200)
    col = weight[y_idx, x_idx]
    # The 0.3 exponent compresses the range of the colour values.
    plt.scatter(-pc[:, 1], pc[:, 0], s=0.1, c=col**0.3, cmap=cmap)
    plt.axis("scaled")
    plt.ylim(0, max_x)
    plt.xlim(*x_limits)
    if use_colorbar:
        plt.colorbar()
    if title:
        plt.title(title)
    plt.show()

    if plot_weight_grid:
        plt.matshow(weight.T**0.5)
        plt.show()

    # Show the image with the normalisation undone (channels moved last for imshow).
    plt.figure(dpi=200)
    plt.axis('off')
    plt.imshow(rearrange(img, "c h w -> h w c").cpu()*stds+means)
    plt.show()

In [None]:
# Use the scores at index 2 of the last axis of bev_cls_scores as the per-cell weight.
# NOTE(review): that index presumably corresponds to the third entry of CATEGORIES — confirm.
# The earlier np.argmax assignment was dropped: its result was overwritten immediately
# and np.argmax cannot convert tensors that live on a CUDA device.
# The weight is moved to the CPU so NumPy indexing and matplotlib can consume it.
weight = bev_cls_scores[..., 2].detach().cpu()
plot_weighted_cloud(pc, weight, img, plot_weight_grid=False, cmap="coolwarm")


In [None]:
from skimage.draw import polygon2mask

VOXEL_SIZE=0.5
# bev_feats is laid out as (1, y, x, channels).
_, y_dim, x_dim, _ = bev_feats.shape

# Rasterise every annotated 3D box into a boolean mask over the BEV feature grid.
for box3d in anno["boxes_3d"]:
    # Convert box to feature grid space
    box_center = box3d[:2] / VOXEL_SIZE
    # Compensate y coordinate for the fact that the lidar features are
    # centered around the ego vehicle in the y direction (x starts from 0)
    box_center[1] += y_dim / 2
    # Elements 3 and 4 are used as the box extents along x and y.
    box_size = box3d[3:5] / VOXEL_SIZE
    # Element 6 is used as the rotation angle; it is negated before use.
    box_rotation = -torch.Tensor([box3d[6]])

    # Create the corner points of the bounding box
    box_points = torch.tensor(
        [
            [-box_size[0] / 2, -box_size[1] / 2],
            [-box_size[0] / 2, box_size[1] / 2],
            [box_size[0] / 2, box_size[1] / 2],
            [box_size[0] / 2, -box_size[1] / 2],
        ]
    )
    # Create a rotation matrix from the box rotation
    rotation_matrix = torch.tensor(
        [
            [torch.cos(box_rotation), -torch.sin(box_rotation)],
            [torch.sin(box_rotation), torch.cos(box_rotation)],
        ]
    )
    # Rotate the corner points of the bounding box
    box_points = torch.matmul(box_points, rotation_matrix) + box_center
    # Create a mask of the pixels that are within the bounding box
    # Flip x and y (since y is height and x is width): the vertices become
    # (row, col) = (y, x), so the mask shape has to be (rows, cols) = (y_dim, x_dim)
    # to line up with the feature grid.
    mask = polygon2mask((y_dim, x_dim), box_points.cpu().numpy()[:, ::-1])
    # Pool the features within the bounding box
    # pooled_features = bev_feature[mask].mean(dim=(0))
    plt.matshow(mask)
    plt.show()
# Show the BEV features summed over the channel axis. The tensor is moved to the
# CPU first because matplotlib cannot convert tensors that live on a CUDA device.
plt.matshow(bev_feats[0].sum(-1).detach().cpu())
plt.show()