In [None]:
import torch

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
from lidar_clippin.model.sst import build_sst
from mmcv.runner import load_checkpoint
from lidar_clippin.model.sst import LidarEncoderSST
import pytorch_lightning as pl
import clip
import os
from train import LidarClippin

# Pretrained CLIP model plus the image preprocessing transform that goes with it.
clip_model, clip_preprocess = clip.load("ViT-B/32")

# Lidar encoder built from the SST encoder-only config file.
lidar_encoder = LidarEncoderSST("../lidar_clippin/model/sst_encoder_only_config.py")

# The two trailing positional arguments are defined by train.LidarClippin,
# which is not visible in this notebook -- TODO confirm their meaning.
model = LidarClippin(lidar_encoder, clip_model, 1, 1)

# Deserialize the weights on the CPU first so loading does not depend on the
# device the checkpoint was written from; the model is moved afterwards.
load_checkpoint(model, "/proj/nlp4adas/checkpoints/35vsmuyp/epoch=97-step=32842.ckpt", map_location="cpu")

# This notebook only does inference, so switch to eval mode (disables
# dropout / batch-norm updates). Assigning the result also keeps the cell
# from echoing the full module repr, which the original silenced with `pass`.
model = model.to(device).eval()

In [None]:
from lidar_clippin.loader import build_loader
# Data loader over the "val" split of the dataset rooted at the path below.
# Images go through `clip_preprocess`, the transform returned by clip.load.
batch_size = 32
loader = build_loader(
    "/proj/nlp4adas/datasets/once",
    clip_preprocess,
    batch_size=batch_size,
    num_workers=1,
    split="val",
)

In [None]:
# Cached feature tensors for the validation split (presumably produced by the
# image and lidar encoders in an earlier run -- verify against the script that
# wrote these files).
# `map_location` places the tensors on `device` while loading, so files that
# were saved from a GPU session can still be opened on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load feature files you trust.
img_feats = torch.load("img_feats_val.pt", map_location=device)
lidar_feats = torch.load("lidar_feats_val.pt", map_location=device)

In [None]:
def logit_img_txt(img_feat, txt_feat, model):
    """Compute scaled cosine-similarity logits between two sets of features.

    Each row of both inputs is divided by its L2 norm, the pairwise dot
    products are taken, and the result is multiplied by
    ``exp(model.logit_scale)``.

    Args:
        img_feat: 2-D tensor with one feature vector per row. The call sites
            in this notebook pass image or lidar embeddings here.
        txt_feat: 2-D tensor with one text feature vector per row.
        model: Any object exposing a ``logit_scale`` tensor.

    Returns:
        Tuple ``(logits_per_text, logits_per_image)`` of float32 tensors with
        shapes ``(num_txt, num_img)`` and ``(num_img, num_txt)``; the first is
        the transpose of the second.
    """
    # Unit-length rows turn the dot product below into a cosine similarity.
    unit_img = img_feat / img_feat.norm(dim=1, keepdim=True)
    unit_txt = txt_feat / txt_feat.norm(dim=1, keepdim=True)

    scale = model.logit_scale.exp().float()
    scaled_img = scale * unit_img.float()
    logits_per_image = scaled_img @ unit_txt.t().float()
    return logits_per_image.t(), logits_per_image

In [None]:
# Encode a single free-text query and rank every cached sample against it.
text = clip.tokenize(["a photo of an animal on the road"]).to(device)
# The similarity computation stays inside no_grad as well: it multiplies by
# clip_model.logit_scale (a learnable parameter in CLIP), so running it
# outside the context would attach an autograd graph to tensors that are only
# used for ranking.
with torch.no_grad():
    text_features = clip_model.encode_text(text)
    logits_per_text_l, logits_per_pc = logit_img_txt(lidar_feats, text_features, clip_model)
    logits_per_text_i, logits_per_img = logit_img_txt(img_feats, text_features, clip_model)
# Row 0 is the only prompt; keep the 16 best-scoring samples for each modality.
pc_logits, pc_idxs = torch.topk(logits_per_text_l[0,:], 16)
img_logits, img_idxs = torch.topk(logits_per_text_i[0,:], 16)

In [None]:
from einops import rearrange
from matplotlib import pyplot as plt
# Per-channel mean / std used below to undo the image normalisation before
# display (presumably the statistics applied by `clip_preprocess` -- confirm).
means = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cpu")
stds = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cpu")

dataset = loader.dataset

# 4x4 grid of the camera images belonging to the 16 samples that were ranked
# highest by the lidar embedding (`pc_idxs`), shown in rank order.
fig, axs = plt.subplots(4,4, figsize=(15, 15))
# .tolist() turns the top-k indices into plain Python ints, so the dataset is
# not indexed with 0-dim (possibly CUDA) tensors.
for idx, (ax, d_idx) in enumerate(zip(axs.flat, pc_idxs.tolist())):
    # Each dataset item is an (image, point cloud) pair; only the image is used.
    img, _ = dataset[d_idx]

    # Channels-last layout for imshow, then invert the normalisation. The
    # clamp keeps float RGB values inside [0, 1], the range imshow would clip
    # to anyway (with a warning).
    rgb = (rearrange(img, "c h w -> h w c").cpu()*stds+means).clamp(0, 1)
    ax.imshow(rgb)
    ax.set_title(idx)


In [None]:
import numpy as np
# 4x4 grid of scatter plots for the point clouds of the 16 samples ranked
# highest by the lidar embedding (`pc_idxs`), shown in rank order.
fig, axs = plt.subplots(4,4, figsize=(30, 15))
# .tolist() yields plain Python ints, so the dataset is not indexed with
# 0-dim (possibly CUDA) tensors.
for idx, (ax, d_idx) in enumerate(zip(axs.flat, pc_idxs.tolist())):
    # Each dataset item is an (image, point cloud) pair; only the cloud is used.
    _, pc = dataset[d_idx]

    # Convert once to a NumPy array instead of relying on implicit tensor
    # coercion inside np.clip and matplotlib.
    pc = pc.detach().cpu().numpy()
    # Column 0 is drawn on the vertical axis and negated column 1 on the
    # horizontal one; column 3 (clipped to [0, 1]) drives the colour.
    # NOTE(review): presumably columns are x, y, z, intensity -- confirm
    # against the dataset implementation.
    ax.scatter(-pc[:,1], pc[:,0], s=0.1, c=np.clip(pc[:, 3], 0, 1), cmap="coolwarm")
    ax.axis("equal")
    ax.set_xlim(-40, 40)
    ax.set_ylim(0, 40)
    ax.set_title(idx)

plt.show()

In [None]:
# Automatically generate a number of prompts
import itertools
# Subjects and scene descriptions that are combined into text prompts.
objects = [
    "car",
    "person walking on the road",
    "semi-truck",
    "bus",
    "parked bicycle",
    "person riding a bike",
    "forest",
]
environments = [
    "on a rainy day",
    "to the left of the image",
    "on an empty road",
    "at night",
    "in winter",
]
# One prompt per (subject, scene) pair; subjects vary slowest, so all scenes
# for the first subject come first.
prompts = [
    f"a photo of a {subject} {scene}"
    for subject in objects
    for scene in environments
]
prompts

In [None]:
# Tokenize every generated prompt and score all cached samples against each.
text = clip.tokenize(prompts).to(device)
# The similarity computation stays inside no_grad as well: it multiplies by
# clip_model.logit_scale (a learnable parameter in CLIP), so running it
# outside the context would attach an autograd graph to tensors that are only
# used for ranking.
with torch.no_grad():
    text_features = clip_model.encode_text(text)
    # Row i of each logits_per_text_* tensor holds the scores for prompts[i].
    logits_per_text_l, logits_per_pc = logit_img_txt(lidar_feats, text_features, clip_model)
    logits_per_text_i, logits_per_img = logit_img_txt(img_feats, text_features, clip_model)

In [None]:
from einops import rearrange
from matplotlib import pyplot as plt
# Per-channel mean / std used below to undo the image normalisation before
# display (presumably the statistics applied by `clip_preprocess` -- confirm).
means = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cpu")
stds = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cpu")

dataset = loader.dataset

num_prompts = len(prompts)
num_samples = 4

# For every prompt, print it and show one row with the camera images of the
# samples that the lidar embedding ranks highest for that prompt.
for i in range(num_prompts):
    pc_logits, pc_idxs = torch.topk(logits_per_text_l[i,:], num_samples)
    # The image-embedding ranking is computed too, but only the lidar ranking
    # is plotted in this cell.
    img_logits, img_idxs = torch.topk(logits_per_text_i[i,:], num_samples)
    print(prompts[i])
    # squeeze=False keeps `axs` two-dimensional, so the row can be iterated
    # even when num_samples == 1 (otherwise subplots returns a bare Axes).
    fig, axs = plt.subplots(1, num_samples, figsize=(15, 15), squeeze=False)
    # .tolist() yields plain Python ints, so the dataset is not indexed with
    # 0-dim (possibly CUDA) tensors.
    for ax, d_idx in zip(axs[0], pc_idxs.tolist()):
        # Each dataset item is an (image, point cloud) pair; only the image is used.
        img, _ = dataset[d_idx]

        # Channels-last layout, inverted normalisation, and a clamp to the
        # [0, 1] range that imshow expects for float RGB data.
        rgb = (rearrange(img, "c h w -> h w c").cpu()*stds+means).clamp(0, 1)
        ax.imshow(rgb)
    plt.show()
    # Release the figure so the loop does not accumulate one open figure per prompt.
    plt.close(fig)

