In [None]:
import random
import pickle
import numpy as np
from matplotlib import pyplot as plt

import torch

In [None]:
from ldm.data.nuscenes import NuScenesDataset
from ldm.data.utils import draw_projected_bbox, visualize_lidar, focus_on_bbox
from ldm.data.box_np_ops import points_in_bbox_corners
from ldm.data.lidar_converter import LidarConverter

In [None]:
# Seed every random number generator used in this notebook (Python's `random`,
# NumPy, and PyTorch on CPU and CUDA) with the same value.
seed = 3
for seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_rng(seed)
# Request deterministic cuDNN kernels.
torch.backends.cudnn.deterministic = True


# Options for the validation split: "car" objects only, with both the lidar
# and the camera modality enabled and 256x256 images.
dataset_options = dict(
    state="val",
    object_database_path="/mnt/data/mobi/mobi/data/nuscenes/nuscenes_dbinfos_pbe_val.csv",
    scene_database_path="/mnt/data/mobi/mobi/data/nuscenes/nuscenes_scene_infos_pbe_val.pkl",
    reference_image_min_h=300,
    reference_image_min_w=300,
    object_classes=["car"],
    use_lidar=True,
    use_camera=True,
    image_height=256,
    image_width=256,
    # Optional, currently disabled: rot_every_angle=30
    # NOTE(review): a truncated fragment "e0e3b665c44fa8d17d9f4770bdf9c'," was
    # left here; it looks like the tail of a token filter - confirm or remove.
)
dataset = NuScenesDataset(**dataset_options)

In [None]:
# Fetch one dataset item (index 1). A loop over the first 100 indices was used
# here during exploration and is disabled.
full_sample = dataset[1]
# Keep the 3D box, then narrow `sample` to the lidar part of the item; later
# cells read their range-image entries from it.
bbox_3d, sample = full_sample["bbox_3d"], full_sample["lidar"]
# The image to display is the range-depth entry (the "GT" entry was the
# alternative tried earlier).
image_tensor = sample["range_depth"]

def un_norm(x):
    """Linearly rescale `x` via (x + 1) / 2, which sends -1 to 0 and 1 to 1.

    Works for anything that supports `+` and `/` with floats (Python scalars,
    NumPy arrays, tensors); the result has whatever type those operators give.
    """
    shifted = x + 1.0
    return shifted / 2.0

# Reference bounding box in image coordinates; only the first two components
# of its last axis are passed to the drawing helper.
bbox_image_coords = sample['cond']['ref_bbox']

# Rescale the tensor, move axis 0 to the end and convert to 8-bit values.
img_rescaled = un_norm(image_tensor).cpu().numpy()
img_channels_last = np.transpose(img_rescaled, (1, 2, 0))
img_uint8 = (img_channels_last * 255).astype(np.uint8)

# The last axis is reversed before drawing and reversed back afterwards
# (presumably an RGB <-> BGR swap for an OpenCV-based helper - TODO confirm).
GT_img = draw_projected_bbox(img_uint8[..., ::-1], bbox_image_coords[..., :2], thickness=1)
GT_img = GT_img[..., ::-1]

# Show the annotated image without axes.
fig, ax = plt.subplots(figsize=(20, 10))
ax.imshow(GT_img)
ax.axis('off')
plt.show()

In [None]:
# Pull the three range-image entries out of the lidar sample.
# Shift value that is later used to roll the image along its last axis.
range_shift_left = sample["range_shift_left"]
# Range image stored under the "_orig" key (presumably before the dataset's
# crop/shift transforms - TODO confirm).
range_depth_orig = sample['range_depth_orig']
# Same entry that was displayed above, copied into a NumPy array and reduced
# to its first slice along axis 0.
range_depth = np.array(sample['range_depth'])[0]

In [None]:
# Roll the range image along its last axis by the stored shift amount
# (a positive shift moves entries toward higher indices, wrapping around).
range_depth_unshift = np.roll(range_depth, shift=range_shift_left, axis=-1)

In [None]:
# Display the shape of the range image taken from the sample.
range_depth.shape

In [None]:
# Grayscale view of the range image stored under "range_depth_orig"; a
# trailing singleton axis is added before display.
fig, ax = plt.subplots(figsize=(20, 10))
ax.imshow(range_depth_orig[..., None], cmap="gray")
ax.axis("off")
plt.show()

In [None]:
# Grayscale view of the rolled-back range image, for visual comparison with
# the "range_depth_orig" plot.
fig, ax = plt.subplots(figsize=(20, 10))
ax.imshow(range_depth_unshift[..., None], cmap="gray")
ax.axis("off")
plt.show()

In [None]:
# Display the shape of `range_depth` again (np.roll above returns a new array,
# so `range_depth` itself is unchanged by the un-shift).
range_depth.shape

In [None]:
# Use the converter to undo the dataset's default transforms, giving it the
# shift amount, the "orig" range image and the range image from the sample.
# A mask could be passed as well but is not used here.
lidar_converter = LidarConverter()
undo_inputs = {
    "range_depth": range_depth_orig,
    "range_depth_crop": range_depth,
}
# Only the first returned value is kept.
range_depth_new, _ = lidar_converter.undo_default_transforms(range_shift_left, **undo_inputs)

In [None]:
# Grayscale view of the range image returned by `undo_default_transforms`.
fig, ax = plt.subplots(figsize=(20, 10))
ax.imshow(range_depth_new, cmap="gray")

In [None]:
# Convert both range images to point clouds; only the first value returned by
# `range2pcd` is kept.
points_new, _ = lidar_converter.range2pcd(range_depth_new)
points, _ = lidar_converter.range2pcd(range_depth_orig)

# Apply `focus_on_bbox` to each cloud with the same 3D box. The box returned
# by the second call is the one used for visualisation, so no separate
# `bbox_3d_new = bbox_3d` alias is needed beforehand.
points, _ = focus_on_bbox(points, bbox_3d)
points_new, bbox_3d_new = focus_on_bbox(points_new, bbox_3d)

# Optional, currently disabled: keep only the points inside the box. The
# helper imported in this notebook is `points_in_bbox_corners`.
# mask = points_in_bbox_corners(points, bbox_3d_new[None])
# points = points[mask[:, 0]]

In [None]:
# Render both point clouds with the same box (`bbox_3d_new`, returned by
# `focus_on_bbox`) so the two views can be compared directly.
lidar_vis_new = visualize_lidar(points_new, bboxes=bbox_3d_new)
lidar_vis = visualize_lidar(points, bboxes=bbox_3d_new)

In [None]:
# Plot them side by side
plt.figure(figsize=(20, 10))
plt.subplot(1, 2, 1)
plt.title("Before inpainting")
plt.imshow(lidar_vis)
plt.axis('off')
plt.subplot(1, 2, 2)
plt.title("After inpainting")
plt.imshow(lidar_vis_new)
plt.axis('off')


plt.show()

### Find the scene containing a given image

In [None]:
# File name of the camera image whose scene we want to look up.
image = "n015-2018-10-02-10-50-40+0800__CAM_FRONT__1538448761512460.jpg"

# Load the scene-info dictionary from disk.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
scene_infos_path = "data/nuscenes/nuscenes_scene_infos_pbe_val.pkl"
with open(scene_infos_path, "rb") as scene_infos_file:
    scenes_info = pickle.load(scene_infos_file)

In [None]:
# Print the token of every scene that lists the requested image.
# The former per-scene debug print of `image_paths[0]` was removed: it emitted
# one line per scene (hiding the match) and failed on scenes with no images.
for scene_token, scene_info in scenes_info.items():
    # `image` is a bare file name, so a substring test finds it inside a path.
    if any(image in image_path for image_path in scene_info['image_paths']):
        print(scene_token)

### Convert to video

In [None]:
import os
import cv2
# Directory that holds the rendered result frames.
results_dir = "/mnt/data/mobi/mobi/results_test_rotate/exp/results"


def _frame_index(file_name):
    """Return the integer that follows the last '-' in the file's stem.

    The extension is removed with os.path.splitext; str.strip('.png') would
    instead strip any of the characters '.', 'p', 'n', 'g' from both ends.
    """
    stem, _extension = os.path.splitext(file_name)
    return int(stem.split('-')[-1])


# Keep only PNG frames (other files would break the numeric sort) and order
# them by their numeric suffix rather than lexicographically.
image_paths = sorted(
    (name for name in os.listdir(results_dir) if name.lower().endswith(".png")),
    key=_frame_index,
)

In [None]:
# load images
images = []
for image_path in image_paths:
    img = cv2.imread(os.path.join("/mnt/data/mobi/mobi/results_test_rotate/exp/results", image_path))
    images.append(img)

# create mp4 video
out = cv2.VideoWriter('output.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 1, (800, 450))
for i in range(len(images)):
    out.write(images[i])
out.release()


### Model

In [None]:
from omegaconf import OmegaConf
from scripts.inference import load_model_from_config
from ldm.util import instantiate_from_config

In [None]:
def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by `config.model` and load `ckpt` into it.

    Note: this definition shadows the `load_model_from_config` imported from
    `scripts.inference` earlier in the notebook.

    Args:
        config: Configuration object whose `model` attribute is passed to
            `instantiate_from_config`.
        ckpt: Path to a checkpoint file containing a "state_dict" entry
            (and optionally "global_step", which is printed when present).
        verbose: When True, print the missing and unexpected keys reported by
            the non-strict state-dict load.

    Returns:
        The model in eval mode, moved to the GPU when CUDA is available and
        left on the CPU otherwise.

    Raises:
        KeyError: If the checkpoint has no "state_dict" entry.
    """
    print(f"Loading model from {ckpt}")
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False: mismatching keys are reported instead of raising.
    missing_keys, unexpected_keys = model.load_state_dict(sd, strict=False)
    if verbose:
        if len(missing_keys) > 0:
            print("missing keys:")
            print(missing_keys)
        if len(unexpected_keys) > 0:
            print("unexpected keys:")
            print(unexpected_keys)

    # The weights were loaded on the CPU; only move them to the GPU when one
    # is present so the function also works on CPU-only machines.
    if torch.cuda.is_available():
        model.cuda()
    else:
        print("CUDA is not available; keeping the model on the CPU.")
    model.eval()
    return model

In [None]:
# Read the experiment configuration, then build the model and load the
# checkpoint weights into it.
config_path = "configs/nusc.yaml"
checkpoint_path = "checkpoints/model.ckpt"
config = OmegaConf.load(config_path)
model = load_model_from_config(config, checkpoint_path)

In [None]:
# Display the loaded model's representation as the cell output.
model

### Checkpoint

In [None]:
import torch

In [None]:
# Load the "state_dict" entry of the two checkpoints to compare, both onto the CPU.
ckpt_path_1 = "checkpoints/model.ckpt"
ckpt_path_2 = "/mnt/data/transient/mobi/mobi/models/Paint-by-Example/2024-03-25T16-58-51_nusc/checkpoints/last.ckpt"
model1, model2 = (
    torch.load(ckpt_path, map_location="cpu")['state_dict']
    for ckpt_path in (ckpt_path_1, ckpt_path_2)
)

In [None]:
# Walk over the second state dict and report every entry that is absent from
# the first one or whose tensor differs from it. Entries that exist only in
# `model1` are not reported.
for k, tensor_in_model2 in model2.items():
    if k not in model1:
        print(f"{k} not in model1")
        continue
    if not torch.equal(model1[k], tensor_in_model2):
        print(f"{k} is not equal")

### Scheduler

In [None]:
from ldm.lr_scheduler import LambdaLinearScheduler, LambdaWarmUpCosineScheduler2

In [None]:
# One cycle of 50000 steps with no warm-up steps. f_max and f_min are both 1
# (presumably a constant multiplier of 1 after warm-up - TODO confirm against
# LambdaLinearScheduler).
scheduler = LambdaLinearScheduler(
    cycle_lengths=[50000],
    warm_up_steps=[0],
    f_start=[1e-3],
    f_min=[1],
    f_max=[1],
)

In [None]:
# Evaluate the schedule at every step 0..49999, in order.
lr_list = list(map(scheduler.schedule, range(50000)))

In [None]:
import matplotlib.pyplot as plt
# Plot the sampled schedule values against the step index.
fig, ax = plt.subplots()
ax.plot(lr_list)

### Edit state dict

In [None]:
import torch

# Load the whole checkpoint (a mapping with a 'state_dict' entry) onto the CPU.
# NOTE(review): this rebinds `model`, which earlier in the notebook held the
# instantiated network returned by load_model_from_config.
model = torch.load("checkpoints/model.ckpt", map_location="cpu")

In [None]:
# Rebuild the state dict without any entry whose key contains
# "first_stage_model" (presumably the VAE weights, given the output file name
# used when saving - TODO confirm). The order of the remaining entries is kept.
kept_weights = {}
for weight_name, weight in model['state_dict'].items():
    if "first_stage_model" in weight_name:
        continue
    kept_weights[weight_name] = weight
model['state_dict'] = kept_weights

In [None]:
# Write the edited checkpoint dictionary to a new file, leaving the original
# checkpoint file untouched.
torch.save(model, "checkpoints/model_no-vae.ckpt")