In [None]:
import torch

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
from lidarclip.model.sst import build_sst
from mmcv.runner import load_checkpoint
from lidarclip.model.sst import LidarEncoderSST
import clip
import os
from train import LidarClippin

# CLIP backbone plus the SST lidar encoder, wrapped in the training module so the
# Lightning checkpoint's state dict lines up with the module tree.
clip_model, clip_preprocess = clip.load("ViT-B/32")
lidar_encoder = LidarEncoderSST("../lidarclip/model/sst_encoder_only_config.py")
# NOTE(review): the trailing positional args (1, 1) are LidarClippin constructor
# arguments defined in train.py — confirm their meaning before changing them.
model = LidarClippin(lidar_encoder, clip_model, 1, 1)

CHECKPOINT_DIR = "/proj/nlp4adas/checkpoints/"
checkpoint_name = "35vsmuyp/epoch=97-step=32842.ckpt"
load_checkpoint(model, os.path.join(CHECKPOINT_DIR, checkpoint_name), map_location="cpu")

# Newly constructed nn.Modules start in training mode; switch to eval so any
# dropout / normalisation layers behave deterministically during inference.
# Assigning (rather than a bare expression) keeps the cell from echoing the module repr.
model = model.to(device).eval()

In [None]:
from lidarclip.loader import build_loader, OnceImageLidarDataset

# Validation split of the dataset under .../once, images transformed with CLIP's preprocess.
loader = build_loader(
    "/proj/nlp4adas/datasets/once",
    clip_preprocess,
    batch_size=16,
    num_workers=2,
    split='val',
)

In [None]:
# Grab a single batch and move every point cloud and image onto the compute device.
images, lidars = next(iter(loader))

lidars = list(map(lambda cloud: cloud.to(device), lidars))
images = list(map(lambda frame: frame.to(device), images))

In [None]:
from einops import rearrange
from matplotlib import pyplot as plt
# Per-channel mean/std used to undo the normalisation in clip_preprocess for display
# (these match CLIP's image Normalize constants).
means = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cpu")
stds = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cpu")

fig, axs = plt.subplots(4, 4, figsize=(15, 15))
for idx, ax in enumerate(axs.flat):  # row-major, same order as i*4 + j
    # A short (e.g. final) batch may hold fewer images than grid slots; blank the rest
    # instead of raising IndexError.
    if idx >= len(images):
        ax.axis("off")
        continue
    # De-normalise, then clamp so imshow gets valid RGB in [0, 1] instead of warning.
    rgb = (rearrange(images[idx], "c h w -> h w c").cpu() * stds + means).clamp(0, 1)
    ax.imshow(rgb)
    ax.set_title(idx)

In [None]:
import numpy as np
# Bird's-eye scatter of every point cloud in the batch, laid out like the image grid.
fig, axs = plt.subplots(4, 4, figsize=(30, 15))
for idx, ax in enumerate(axs.flat):  # row-major, same order as i*4 + j
    # A short batch may hold fewer clouds than grid slots; blank the rest.
    if idx >= len(lidars):
        ax.axis("off")
        continue
    # Convert once to numpy so np.clip operates on an ndarray, as in the later cells.
    pc = lidars[idx].cpu().numpy()
    # Column 0 on the vertical axis, negated column 1 on the horizontal; column 3,
    # clipped to [0, 1], drives the colour (presumably intensity — TODO confirm).
    ax.scatter(-pc[:, 1], pc[:, 0], s=0.1, c=np.clip(pc[:, 3], 0, 1), cmap="coolwarm")
    ax.axis("equal")
    ax.set_xlim(-40, 40)
    ax.set_ylim(0, 60)
    ax.set_title(idx)

plt.show()

In [None]:
# Encoder output (default pooling) plus its attention over the BEV grid.
# Pure inference, so skip building an autograd graph.
with torch.no_grad():
    feature, attention = model.lidar_encoder(lidars, return_attention=True)
# One 80x80 attention map per cloud. Take the batch size from the data rather than
# hard-coding 16, so a short batch still reshapes (and a size mismatch still errors).
attention_bev = attention.reshape(len(lidars), 80, 80).detach().cpu().numpy()
torch.save(feature, "lidar_out.pt")

In [None]:
def plot_weighted_cloud(pc, weight, img, plot_weight_grid=False, use_colorbar=False, cmap="coolwarm", half_cloud=False, title=None, gamma=0.3):
    """Scatter a point cloud in bird's-eye view, coloured by a per-cell weight, then show the camera image.

    Each point is binned into a 0.5-unit grid cell and coloured by ``weight`` at that cell.

    Args:
        pc: numpy array with at least 2 columns. Column 0 is drawn on the vertical
            axis and column 1 (negated) on the horizontal axis. Presumably metres
            in the lidar frame — TODO confirm.
        weight: 2-D numpy array indexed as ``weight[y_idx, x_idx]``, where ``y_idx``
            bins column 1 and ``x_idx`` bins column 0. It must be at least 80x80
            for the full view and 40x40 for the half view.
            NOTE(review): the half view indexes from the grid corner, so it only
            lines up with a weight map cropped to that region.
        img: (C, H, W) tensor normalised with the notebook-level ``means`` / ``stds``.
        plot_weight_grid: also draw ``weight.T ** 0.5`` as a matrix image.
        use_colorbar: add a colorbar to the scatter plot.
        cmap: matplotlib colormap for the scatter.
        half_cloud: keep column 0 in [0, 20) and column 1 in (-10, 10),
            instead of [0, 40) and (-20, 20).
        title: optional title for the scatter plot.
        gamma: exponent applied to float weights before colouring (compresses the
            dynamic range). Integer weights, e.g. argmax class labels, are
            treated as categorical and passed through unchanged.
    """
    # Grid extent: column 0 spans [0, x_max), column 1 spans (-y_half, y_half).
    x_max, y_half = (20, 10) if half_cloud else (40, 20)
    cell_size = 0.5
    # The lower bound on column 0 is required: negative coordinates would otherwise
    # become negative cell indices, which numpy silently wraps to the far end of the grid.
    keep = (pc[:, 0] >= 0) & (pc[:, 0] < x_max) & (pc[:, 1] > -y_half) & (pc[:, 1] < y_half)
    pc = pc[keep]
    x_idx = (pc[:, 0] / cell_size).astype(np.int32)
    y_idx = ((pc[:, 1] + y_half) / cell_size).astype(np.int32)

    col = weight[y_idx, x_idx]
    # A gamma curve on categorical labels would squeeze neighbouring labels into the
    # same bin of a qualitative colormap (e.g. Set1), so only float weights get it.
    if not np.issubdtype(col.dtype, np.integer):
        col = col ** gamma

    plt.figure(dpi=200)
    plt.scatter(-pc[:, 1], pc[:, 0], s=0.1, c=col, cmap=cmap)
    plt.axis("scaled")
    if half_cloud:
        plt.ylim(0, 20)
        plt.xlim(-10, 10)
    else:
        plt.ylim(0, 40)
        plt.xlim(-13, 13)
    if use_colorbar:
        plt.colorbar()
    if title:
        plt.title(title)
    plt.show()

    if plot_weight_grid:
        plt.matshow(weight.T ** 0.5)
        plt.show()

    plt.figure(dpi=200)
    plt.axis('off')
    # Undo the normalisation and clamp so imshow receives valid RGB in [0, 1].
    plt.imshow((rearrange(img, "c h w -> h w c").cpu() * stds + means).clamp(0, 1))
    plt.show()

In [None]:
# Attention-weighted view of sample 15 (the last one in a 16-sample batch),
# together with the raw attention grid.
idx = 15
weight = attention_bev[idx]
pc = lidars[idx].cpu().numpy()
plot_weighted_cloud(pc, weight, images[idx], plot_weight_grid=True)

In [None]:
# Un-pooled per-cell BEV features plus the encoder's cell mask. This is visualisation
# only, so don't build an autograd graph (saves memory for the whole batch).
with torch.no_grad():
    bev_features, mask = model.lidar_encoder(lidars, no_pooling=True)
# The encoder returns cells first ((h w) n c); move the batch to the front -> n (h w) c.
# The detour through an explicit 80x80 grid checks the grid size and is where an
# optional spatial crop goes (later cells assume h = w = 80, so update them too if enabled).
bev_features = rearrange(bev_features, '(h w) n c -> n h w c', h=80, w=80)
# bev_features = bev_features[:, :40, 20:60]
bev_features = rearrange(bev_features, 'n h w c -> n (h w) c')
mask = rearrange(mask, 'n (h w) -> n h w', h=80, w=80)
# mask = mask[:, :40, 20:60]
mask = rearrange(mask, 'n h w -> n (h w)')

mask.shape

In [None]:
# Build one CLIP prompt per class of interest and embed it with CLIP's text encoder.
words = [
    f"a picture of a {word}"
    for word in ("car", "bus", "truck", "person", "bike or moped", "tree or bush", "road", "traffic light")
]
text = clip.tokenize(words).to(device)

with torch.no_grad():
    text_features = clip_model.encode_text(text)
text_features.shape

In [None]:
# Scale every embedding to unit L2 norm along the last axis, so that their dot
# products are cosine similarities.
txt_feat = text_features.div(text_features.norm(dim=-1, keepdim=True))
bev_feat = bev_features.div(bev_features.norm(dim=-1, keepdim=True))

In [None]:
# Cosine similarity between every BEV cell and every prompt, scaled by CLIP's
# logit_scale -> shape (n, h*w, num_prompts). logit_scale is presumably a trainable
# parameter; no_grad keeps it from dragging an autograd graph into this inference-only code.
with torch.no_grad():
    logit_scale = model.clip.logit_scale.exp().float()
    logits_per_bev = logit_scale * bev_feat.float() @ txt_feat.t().float()
logits_per_bev.shape

In [None]:
# Push masked cells far below any real logit. This assumes mask is 0/1 (or bool),
# with 1 = cell to exclude — TODO confirm.
# masked_fill (instead of an in-place subtract) has two benefits:
# - re-running the cell gives the same result;
# - it avoids float32 round-off: float spacing near 1e8 is 8, so subtracting 1e8 from
#   the logits would wipe out their differences.
logits_per_bev = logits_per_bev.masked_fill(mask[..., None].bool(), -99999999.0)

In [None]:
# Which sample of the batch to inspect, and which prompt in `words` to visualise.
idx, text_idx = -3, 4

In [None]:
# Per-cell probability of prompt `text_idx`, normalised across prompts (softmax over the last axis).
print(words[text_idx])
bev_softmaxed = (
    rearrange(torch.softmax(logits_per_bev, dim=-1), 'n (h w) c -> n h w c', h=80, w=80)
    .detach()
    .cpu()
    .numpy()
)
weight = bev_softmaxed[idx, ..., text_idx]

pc = lidars[idx].cpu().numpy()
plot_weighted_cloud(
    pc,
    weight,
    images[idx],
    plot_weight_grid=False,
    use_colorbar=True,
    title=words[text_idx],
)


In [None]:
# Hard assignment per BEV cell: index (into the printed `words`) of the most probable prompt.
print(words)
bev_softmaxed = (
    rearrange(torch.softmax(logits_per_bev, dim=-1), 'n (h w) c -> n h w c', h=80, w=80)
    .detach()
    .cpu()
    .numpy()
)
weight = bev_softmaxed[idx].argmax(axis=-1)
pc = lidars[idx].cpu().numpy()
plot_weighted_cloud(pc, weight, images[idx], plot_weight_grid=False, cmap="Set1")


In [None]:
# Where in the scene prompt `text_idx` fires: softmax over the BEV cells (dim=1)
# instead of over prompts, so for each sample/prompt the weights sum to 1 across cells.
bev_softmaxed = (
    rearrange(torch.softmax(logits_per_bev, dim=1), 'n (h w) c -> n h w c', h=80, w=80)
    .detach()
    .cpu()
    .numpy()
)
weight = bev_softmaxed[idx, ..., text_idx]

pc = lidars[idx].cpu().numpy()
# Extra square root compresses the range further, before the plot's own exponent.
plot_weighted_cloud(pc, weight ** 0.5, images[idx], plot_weight_grid=False)

