In [None]:
import torch

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
from lidar_clippin.model.sst import build_sst
from mmcv.runner import load_checkpoint
from lidar_clippin.model.sst import LidarEncoderSST
import pytorch_lightning as pl
import clip
import os


class LidarClippin(pl.LightningModule):
    """Lightning module that regresses lidar features onto frozen CLIP image features.

    The CLIP model is frozen in ``__init__``; only the lidar encoder's
    parameters are handed to the optimizer.

    Args:
        lidar_encoder: Encoder applied to the point clouds of a batch.
        clip_model: CLIP model providing ``encode_image``; all of its
            parameters get ``requires_grad = False``.
        batch_size: Batch size used to derive the scheduler's steps per epoch.
    """

    # Sample count used to size the one-cycle schedule.
    # NOTE(review): presumably the number of training samples — confirm.
    NUM_TRAIN_SAMPLES = 3618846

    def __init__(self, lidar_encoder: LidarEncoderSST, clip_model, batch_size: int):
        super().__init__()
        self.lidar_encoder = lidar_encoder
        self.clip = clip_model
        self.batch_size = batch_size
        # CLIP is only a fixed target; never update it.
        for param in self.clip.parameters():
            param.requires_grad = False

    def training_step(self, batch, batch_idx):
        """Return the MSE between CLIP image features and lidar features.

        Args:
            batch: An ``(image, point_cloud)`` pair.
            batch_idx: Index of the batch (unused).

        Returns:
            The scalar loss, also logged as ``"train_loss"``.
        """
        image, point_cloud = batch
        with torch.no_grad():
            image_features = self.clip.encode_image(image)
        encoder_out = self.lidar_encoder(point_cloud)
        # The inference loop in this notebook unpacks the encoder output as a
        # 2-tuple ``(features, _)``; accept both that and a bare tensor.
        if isinstance(encoder_out, (tuple, list)):
            lidar_features = encoder_out[0]
        else:
            lidar_features = encoder_out
        # ``F`` was never imported in this notebook; use the fully qualified name.
        loss = torch.nn.functional.mse_loss(image_features, lidar_features)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Build Adam over the lidar encoder plus a per-step OneCycleLR schedule.

        Returns:
            ``([optimizer], [scheduler_config])`` where the scheduler config
            asks Lightning to step the scheduler every optimizer step.
        """
        # OneCycleLR drives the learning rate from ``max_lr``; the ``lr`` given
        # to Adam is only the value the optimizer is constructed with.
        optimizer = torch.optim.Adam(self.lidar_encoder.parameters(), lr=1e-5)
        steps_per_epoch = (
            self.NUM_TRAIN_SAMPLES // self.batch_size
        ) // self.trainer.accumulate_grad_batches
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=1e-3,
            #total_steps=self.trainer.estimated_stepping_batches,
            pct_start=0.1,
            steps_per_epoch=steps_per_epoch,
            epochs=self.trainer.max_epochs,
        )
        scheduler = {"scheduler": scheduler, "interval" : "step"}
        return [optimizer], [scheduler]


# Build the CLIP model and the lidar encoder, wrap them, and restore the
# trained weights from the Lightning checkpoint (loaded onto the CPU first).
clip_model, clip_preprocess = clip.load("ViT-B/32")
lidar_encoder = LidarEncoderSST("model/sst_encoder_only_config.py")
model = LidarClippin(lidar_encoder, clip_model, 1)
load_checkpoint(model, "/proj/nlp4adas/checkpoints/35vsmuyp/epoch=97-step=32842.ckpt", map_location="cpu")
model.to(device)
# Newly constructed modules start in training mode. This notebook only runs
# inference, so switch to eval mode (affects dropout / batch-norm layers, if
# the encoder has any). ``eval()`` returns the module, so the cell still
# displays the model summary as before.
model.eval()

In [None]:
from lidar_clippin.loader import build_loader
# Validation-split loader; images go through CLIP's own preprocessing.
# `batch_size` is kept as a module-level name because later cells read it.
batch_size = 32
loader = build_loader(
    "/proj/nlp4adas/datasets/once",
    clip_preprocess,
    split="val",
    batch_size=batch_size,
    num_workers=1,
)

In [None]:
# Encode every validation sample with both encoders and keep the features on
# the CPU. Every 50 batches a snapshot of everything gathered so far is saved.
img_feats = []
lidar_feats = []
num_samples = 0  # samples actually accumulated in the two lists
with torch.no_grad():
    for batch_idx, (images, point_clouds) in enumerate(loader):
        if batch_idx % 50 == 0 and batch_idx > 0:
            # Label the snapshot with the number of rows it really contains
            # (the batch at `batch_idx` has not been processed yet).
            print(f"iter: {num_samples}")
            torch.save(torch.cat(img_feats, dim=0), f"img_feats_val_{num_samples}.pt")
            torch.save(torch.cat(lidar_feats, dim=0), f"lidar_feats_val_{num_samples}.pt")
        point_clouds = [pc.to(device) for pc in point_clouds]
        # The loader yields a sequence of per-sample image tensors; stack them
        # into one batch along a new leading dimension.
        images = torch.stack([img.to(device) for img in images])
        image_features = model.clip.encode_image(images)
        lidar_features, _ = model.lidar_encoder(point_clouds)
        img_feats.append(image_features.detach().cpu())
        lidar_feats.append(lidar_features.detach().cpu())
        num_samples += image_features.shape[0]

print("done")

In [None]:
# Collapse the per-batch feature lists into one tensor each (rows = samples),
# then persist both tensors next to the notebook.
img_feats, lidar_feats = (torch.cat(chunks) for chunks in (img_feats, lidar_feats))

torch.save(img_feats, "img_feats_val.pt")
torch.save(lidar_feats, "lidar_feats_val.pt")

In [None]:
# Reload cached features onto the compute device.
# The "_46535" files are not written by any cell in this notebook, so a fresh
# run would fail here. Keep them as the first choice (unchanged behaviour when
# they exist) and otherwise fall back to the files saved by the previous cell.
_have_snapshot = os.path.exists("img_feats_val_46535.pt") and os.path.exists("lidar_feats_val_46535.pt")
_feats_suffix = "_46535" if _have_snapshot else ""
img_feats = torch.load(f"img_feats_val{_feats_suffix}.pt").to(device)
lidar_feats = torch.load(f"lidar_feats_val{_feats_suffix}.pt").to(device)

In [None]:
def logit_img_txt(img_feat, txt_feat, model, weights=None, *, weigths=None):
    """Compute CLIP-style scaled cosine-similarity logits between two feature sets.

    Args:
        img_feat: 2-D tensor of image (or lidar) features, one row per sample.
        txt_feat: 2-D tensor of text features, one row per prompt.
        model: Object exposing ``logit_scale`` (a tensor); its exponential is
            used as the temperature.
        weights: Optional 1-D tensor with one entry per prompt. When given,
            the per-prompt logits are combined into a single weighted column.
        weigths: Deprecated misspelling of ``weights``, accepted by keyword
            only so existing callers keep working.

    Returns:
        ``(logits_per_text, logits_per_image)`` where ``logits_per_text`` is
        the transpose of ``logits_per_image``.
    """
    # The parameter used to be misspelled, which made the body silently read a
    # global named `weights` instead of the argument.
    if weights is None:
        weights = weigths

    img_feat = img_feat / img_feat.norm(dim=1, keepdim=True)
    txt_feat = txt_feat / txt_feat.norm(dim=1, keepdim=True)

    # cosine similarity as logits
    logit_scale = model.logit_scale.exp().float()
    logits_per_image = logit_scale * img_feat.float() @ txt_feat.t().float()
    if weights is not None:
        logits_per_image = logits_per_image @ weights.unsqueeze(1)
    logits_per_text = logits_per_image.t()
    return logits_per_text, logits_per_image

In [None]:
# Encode the text prompts with CLIP's text tower.
# Alternative prompt sets that were tried; enable one by replacing `class_names`.
#text = clip.tokenize(["a car", "a bus", "a truck", "a person", "a construction vehicle", "a tree", "a bush"]).to(device)
#text = clip.tokenize(["a rural road", "a urban road", "a highway", "an overhead bridge", "a cross-walk"]).to(device)
#text = clip.tokenize(["an image at night", "an image during the day"]).to(device)
#text = clip.tokenize(["a road with cars", "an empty road"]).to(device)
#text = clip.tokenize(["a tiny car", "a small car", "a normal-size car", "a big car", "a huge car"]).to(device)
#text = clip.tokenize(["a white car", "a black car", "a red car", "a green car", "a blue car", "a yellow car"]).to(device)
#text = clip.tokenize(["a cross-walk", "an empty road", "a crossing"]).to(device)
#text = clip.tokenize(["a wet road", "a dry road"]).to(device)
#text = clip.tokenize(["the front of a car", "the back of a car", "the side of a car"]).to(device)
class_names = ["a photo of a three-wheeler", "photo of a car", "photo of a bus", "photo of a semi-truck", "photo of a person walking", "photo of a bike or moped", "photo of trees or bushes", "photo of an empty road"]
text = clip.tokenize(class_names).to(device)
#text = clip.tokenize(["a vehicle with umbrellas"]).to(device)

with torch.no_grad():
    text_features = clip_model.encode_text(text)

# Optional per-prompt weights for `logit_img_txt` (one entry per prompt);
# `None` keeps one logit row per prompt.
#weights = torch.tensor([1,0], device=device).float()
weights = None

In [None]:
# Score both modalities against the same text prompts: lidar features first,
# then image features.
(logits_per_text_l, logits_per_pc), (logits_per_text_i, logits_per_img) = (
    logit_img_txt(feats, text_features, clip_model, weights)
    for feats in (lidar_feats, img_feats)
)


In [None]:
# The 16 best-matching samples for the LAST text prompt, per modality.
pc_logits, pc_idxs = logits_per_text_l[-1].topk(16)
img_logits, img_idxs = logits_per_text_i[-1].topk(16)

# Ranking visualised by the plotting cells below (currently the lidar one).
logits, idxs = pc_logits, pc_idxs

In [None]:
from einops import rearrange
from matplotlib import pyplot as plt
import numpy as np
# Per-channel statistics used to undo the image normalisation before display.
# NOTE(review): presumably the constants applied by `clip_preprocess` — confirm.
means = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cpu")
stds = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cpu")

dataset = loader.dataset

fig, axs = plt.subplots(4,4, figsize=(30, 15))

# The 4x4 grid holds 8 (image, point cloud) pairs: two pairs per row, each
# pair occupying two neighbouring columns. Walk the top-8 ranks in order
# (the previous `i*4 + j//2` indexing skipped ranks 2, 3, 6, 7, ...).
for idx in range(8):
    row, pair = divmod(idx, 2)
    img_ax = axs[row, 2*pair]
    pc_ax = axs[row, 2*pair + 1]

    logit = logits[idx]
    d_idx = idxs[idx]

    # Load the sample once and use it for both panels.
    img, pc = dataset[d_idx]
    pc = pc.cpu()

    img_ax.imshow(rearrange(img, "c h w -> h w c").cpu()*stds+means)
    img_ax.set_title(logit.item())
    img_ax.axis('off')

    # Top-down view; colour comes from the 4th point column clipped to [0, 1].
    pc_ax.scatter(-pc[:,1], pc[:,0], s=0.1, c=np.clip(pc[:, 3], 0, 1), cmap="coolwarm")
    pc_ax.set_xlim(-40, 40)
    pc_ax.set_ylim(0, 40)
    pc_ax.set_title(d_idx.item())
        

In [None]:
# Inspect one dataset sample: image, point cloud, and the lidar-based
# distribution over the text prompts.
d_idx = 5717

dataset = loader.dataset

img, pc = dataset[d_idx]
pc = pc.cpu()  # same handling as in the grid cells
plt.imshow(rearrange(img, "c h w -> h w c").cpu()*stds+means)
plt.axis('off')
plt.show()

plt.scatter(-pc[:,1], pc[:,0], s=0.1, c=np.clip(pc[:, 3], 0, 1), cmap="coolwarm")
plt.gca().set_xlim(-20, 20)
plt.gca().set_ylim(0, 40)
plt.gca().set_aspect('equal', adjustable='box')
plt.show()

# Softmax over the prompts for this sample. Stored under its own name: the
# global `logits` holds the top-k scores that the grid cells index with
# ranks 0..15, and overwriting it here made those cells fail.
class_probs = logits_per_text_l[:,d_idx].softmax(dim=0)

#plt.xticks(rotation='vertical')
plt.yticks(fontsize=16)
plt.xticks(fontsize=16)
# Short axis labels: drop "photo of" and any leading articles. Done on whole
# words because `str.lstrip(" an")` strips a character set, not a prefix, and
# would also eat the start of words such as "normal".
x = []
for name in class_names:
    words = name.replace("photo of", "").split()
    while words and words[0] in ("a", "an"):
        words.pop(0)
    x.append(" ".join(words))
plt.barh(x,class_probs.cpu().numpy())
print(class_names)
print(class_probs)

In [None]:
import numpy as np
# Top-down point-cloud view of the 16 top-ranked samples, in rank order
# (row-major over the 4x4 grid). The panel title is the rank.
fig, axs = plt.subplots(4,4, figsize=(30, 15))
for idx, ax in enumerate(axs.flat):
    # Only the dataset index is needed here; the former unused
    # `logit = logits[idx]` lookup was dropped because it raised once `logits`
    # no longer held the 16 top-k scores.
    d_idx = idxs[idx]
    _, pc = dataset[d_idx]

    pc = pc.cpu()
    ax.scatter(-pc[:,1], pc[:,0], s=0.1, c=np.clip(pc[:, 3], 0, 1), cmap="coolwarm")
    ax.axis("equal")
    ax.set_xlim(-40, 40)
    ax.set_ylim(0, 40)
    ax.set_title(idx)

plt.show()

In [None]:
from einops import rearrange
from matplotlib import pyplot as plt
# Per-channel statistics used to undo the image normalisation before display.
# NOTE(review): presumably the constants applied by `clip_preprocess` — confirm.
means = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cpu")
stds = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cpu")

dataset = loader.dataset

# Images of the 16 top-ranked samples in rank order (row-major); each panel
# is titled with its score.
fig, axs = plt.subplots(4,4, figsize=(15, 15))
for idx, ax in enumerate(axs.flat):
    logit, d_idx = logits[idx], idxs[idx]
    img, _ = dataset[d_idx]

    channels_last = rearrange(img, "c h w -> h w c").cpu()
    ax.imshow(channels_last*stds+means)
    ax.set_title(logit.item())
    ax.axis('off')
        