In [1]:
from dataset import get_dataset
from mmengine import Config
import plotly.express as px
from copy import deepcopy
import random
import torch
from src.utils.visualisation import colors, get_3d_corners
import matplotlib.pyplot as plt
import numpy as np
from src.model.utils.utils import modify_file_inplace
import plotly.graph_objects as go

# Select the dataset used by the shared data config, then load the rendering config.
# NOTE(review): modify_file_inplace presumably rewrites data_select.py on disk so that
# configs inheriting from it pick up "occ3d" — confirm in src.model.utils.utils.
data_select_path = "./config/_base_/data_select.py"
render_cfg_path = "./config/tpvformer/render.py"

modify_file_inplace(data_select_path, "occ3d")
cfg = Config.fromfile(render_cfg_path)

Dataloader

In [2]:
from dataset import get_dataloader

# Build train/val loaders straight from the config. dist=None and iter_resume=None
# presumably mean "no distributed sampler" and "start from the beginning".
loader_options = dict(dist=None, iter_resume=None)
train_dataset_loader, val_dataset_loader = get_dataloader(
    cfg.train_dataset_config,
    cfg.val_dataset_config,
    cfg.train_loader,
    cfg.val_loader,
    **loader_options,
)

Model

In [3]:
import src.model  # noqa: F401  imported for its side effects (presumably registers the model classes)
from mmseg.models import build_segmentor
from IPython.display import clear_output

CHECKPOINT_PATH = "./ckpts/final/occ3d_tpv_std.pth"

# Build the segmentor from the config, initialise it and move it to the GPU.
my_model = build_segmentor(cfg.model)
my_model.init_weights()
my_model.cuda()

# Load on the CPU first; load_state_dict copies the tensors into the CUDA parameters.
# This also works when the checkpoint was saved from a GPU index absent on this machine.
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")

# strict=False tolerates key mismatches. Keep the returned report instead of
# discarding it, so a partially loaded model does not go unnoticed.
load_report = my_model.load_state_dict(checkpoint["state_dict"], strict=False)
my_model.eval()

clear_output()  # hide the verbose build / init logs

# Printed after clear_output so the mismatch summary is not wiped with the logs.
print(
    f"Loaded {CHECKPOINT_PATH}: "
    f"{len(load_report.missing_keys)} missing keys, "
    f"{len(load_report.unexpected_keys)} unexpected keys"
)

Data

In [4]:
import itertools

# Which validation batch to visualise. The loader is consumed in order, so index k
# loads k + 1 batches. 0 takes the first batch. Whether a given index is "random"
# depends on whether the val loader shuffles — TODO confirm in cfg.val_loader.
SAMPLE_INDEX = 0

data = next(itertools.islice(val_dataset_loader, SAMPLE_INDEX, None))

# Alternative: build one specific dataset item directly instead of iterating the loader
# from dataset import OPENOCC_DATASET, get_dataloader
# from dataset.utils import custom_collate_fn_temporal
# val_wrapper = OPENOCC_DATASET.build(cfg.val_dataset_config)
# data = custom_collate_fn_temporal([val_wrapper[0]])

# Move top-level tensors to the GPU. Values that are not tensors (e.g. nested dicts
# such as data["render"]) are left untouched.
for key, value in list(data.items()):
    if isinstance(value, torch.Tensor):
        data[key] = value.cuda()

# Image

In [None]:
import torch.nn.functional as F

resize_factor = 0.5  # Downscale factor applied before display (smaller = faster)

# Per-channel normalisation statistics of the validation pipeline. The step is located
# by its keys rather than a fixed position in the pipeline list, so reordering
# pipeline steps does not silently pick the wrong entry.
norm_step = next(
    step
    for step in cfg.val_dataset_config["pipeline"]
    if isinstance(step, dict) and "mean" in step and "std" in step
)
mean = torch.tensor(norm_step["mean"])
std = torch.tensor(norm_step["std"])

# First sample of the batch, channels moved last: presumably (n_cams, C, H, W) -> (n_cams, H, W, C).
images = data["img"][0].permute(0, 2, 3, 1).cpu()

# Undo the normalisation and map to [0, 1] for display.
images = ((images * std + mean) / 255.0).clip(0, 1)
# Reverse the channel order. NOTE(review): assumes the pipeline keeps images in BGR
# (mmcv loading default), so this yields RGB for matplotlib — confirm `to_rgb` in norm_step.
images = images[..., [2, 1, 0]]

# Downscale for faster display (F.interpolate expects channels-first).
new_size = [int(images.shape[1] * resize_factor), int(images.shape[2] * resize_factor)]
images = F.interpolate(images.permute(0, 3, 1, 2), size=new_size, mode="bilinear", align_corners=False)
images = images.permute(0, 2, 3, 1)  # back to (n_cams, H, W, C)

# Grid layout. The 6-camera display order is applied only when exactly six images are
# present; any other count is shown in loader order instead of raising an IndexError.
DISPLAY_ORDER = [1, 0, 2, 4, 3, 5]
num_images = images.shape[0]
if num_images == len(DISPLAY_ORDER):
    images = images[DISPLAY_ORDER]

ncols = 3
nrows = max(1, -(-num_images // ncols))  # ceiling division
fig, axes = plt.subplots(nrows, ncols, figsize=(10, 2 * nrows), squeeze=False)

for i, ax in enumerate(axes.flat):
    if i < num_images:
        ax.imshow(images[i].numpy())
    ax.axis("off")  # hide axes on used panels and blank out unused ones

plt.tight_layout()
plt.show()

# Rendering

### Customize renderer

In [6]:
renderer = my_model.renderer
renderer_prep = my_model.aggregator.renderer_prep

# Camera placement mode ("sensor" presumably = the real camera poses), how many
# cameras to render, and which camera indices to use.
n_render_cams = 6
renderer.render_gt_mode = "sensor"
renderer.render_ncam = n_render_cams
renderer.cam_idx = list(range(n_render_cams))
renderer.dataset_tag = cfg.dataset_tag

# The renderer and its preparation step receive the same Gaussian scale.
gaussian_scale = 0.27
for target in (renderer, renderer_prep):
    target.gaussian_scale = gaussian_scale
renderer_prep.overwrite_opacity = True

In [7]:
from src.model.utils.utils import set_gt_render, set_matrix_to_render

# One inference pass without autograd: prepare the render matrices, run the model,
# then build the GT renders restricted to the renderer's valid mask.
# NOTE(review): the meaning of the leading False flag is not visible here — TODO document.
with torch.no_grad():
    set_matrix_to_render(False, my_model, data)
    result_dict = my_model(imgs=data["img"], metas=data)
    render_valid = result_dict["render"]["valid"]
    set_gt_render(False, my_model, data, render_valid)

clear_output()  # hide logs emitted during the pass

### Plot GT & Pred

In [8]:
# plotly.express, numpy and torch are already imported in the first cell; only
# skimage's resize is new here (used for optional downscaling below).
from skimage.transform import resize

# Facet titles for the rendered cameras, indexed by camera position in the batch.
# NOTE(review): the common nuScenes loader order ends with CAM_BACK_LEFT, CAM_BACK_RIGHT;
# confirm the last two entries against this dataset's camera ordering.
camera_labels = [
    "CAM_FRONT",
    "CAM_FRONT_RIGHT",
    "CAM_FRONT_LEFT",
    "CAM_BACK",
    "CAM_BACK_RIGHT",
    "CAM_BACK_LEFT",
]

# Display downscaling: when `rescale` is True each image dimension is divided by
# `factor` (4 -> one quarter of the width and one quarter of the height).
factor = 4
rescale = True

#### Cams

In [None]:
# Class-colour table, moved to the GPU once and reused for every camera.
palette = torch.from_numpy(colors).cuda()
# Number of rendered cameras, read from the data instead of hard-coded.
n_cams = data["render_cam"].shape[1]


def _scores_to_rgb(scores, color_table):
    """Colour a per-pixel class-score map by its argmax class.

    Args:
        scores: tensor whose last dim holds class scores; its first two dims are
            used as (H, W).
        color_table: tensor of 0-255 colours indexed by class id; only the first
            three columns are used.

    Returns:
        (H, W, 3) float numpy array with values in [0, 1].
    """
    cls_ids = scores.argmax(-1).flatten().to(torch.int32)
    rgb = torch.index_select(color_table, 0, cls_ids)[:, :3]
    return (rgb.reshape(*scores.shape[:2], 3) / 255).cpu().numpy()


def _downscale_label_image(img, scale):
    """Shrink a colour-coded label image by `scale` per spatial dimension.

    Nearest-neighbour sampling (order=0, no anti-aliasing) keeps every pixel an
    exact palette colour. Smooth interpolation would blend neighbouring classes
    into colours that belong to no class.
    """
    return resize(
        img,
        (img.shape[0] // scale, img.shape[1] // scale),
        order=0,
        anti_aliasing=False,
    )


def _cam_facet_title(k):
    """Facet title for camera k; generic name if k is past the label list."""
    return camera_labels[k] if k < len(camera_labels) else f"CAM_{k}"


img_arrays_gt = []
img_arrays_pred = []
for i in range(n_cams):
    gt_img = _scores_to_rgb(data["render_cam"][0, i], palette)  # ground truth
    pred_img = _scores_to_rgb(result_dict["render"]["cam"][0, i], palette)  # prediction
    if rescale:
        gt_img = _downscale_label_image(gt_img, factor)
        pred_img = _downscale_label_image(pred_img, factor)
    img_arrays_gt.append(gt_img)
    img_arrays_pred.append(pred_img)

# One facet per camera. plotly titles facets "facet_col=<k>"; replace with camera names.
fig = px.imshow(
    np.stack(img_arrays_gt),
    facet_col=0,
    facet_col_wrap=3,
    title="Ground Truth: Plot all sensor images",
)
fig.for_each_annotation(lambda a: a.update(text=_cam_facet_title(int(a.text.split("=")[-1]))))

fig_pred = px.imshow(
    np.stack(img_arrays_pred),
    facet_col=0,
    facet_col_wrap=3,
    title="Predictions: Plot all sensor images",
)
fig_pred.for_each_annotation(lambda a: a.update(text=_cam_facet_title(int(a.text.split("=")[-1]))))

fig.show()
fig_pred.show()

In [None]:
from src.model.utils.utils import set_up_metrics

# set_up_metrics returns several metric objects; only the second one (the image-space
# mIoU metric) is used here.
(_, miou_metric_img, *_unused_metrics) = set_up_metrics(cfg.dataset_tag)

# Per-pixel class ids (argmax over the last dim), flattened over all leading dims.
gt_pixel_cls = data["render_cam"].argmax(-1).flatten()
pred_pixel_cls = result_dict["render"]["cam"].argmax(-1).flatten()

# NOTE(review): _compute_iou is a private helper; the trailing None is passed as before.
per_class_iou = miou_metric_img._compute_iou(gt_pixel_cls, pred_pixel_cls, None)
np.mean(per_class_iou)

#### Depth

In [None]:
# GT depth comes from data["render_cam_depth"], predicted depth from
# result_dict["render"]["depth"]; both are indexed [batch 0, camera i].
n_depth_cams = data["render_cam_depth"].shape[1]


def _downscale_depth(depth, scale):
    """Shrink a depth map by `scale` per spatial dimension (anti-aliased).

    skimage keeps any trailing channel dimension when only (rows, cols) is given.
    """
    return resize(
        depth,
        (depth.shape[0] // scale, depth.shape[1] // scale),
        anti_aliasing=True,
    )


def _depth_facet_title(k):
    """Facet title for camera k; generic name if k is past the label list."""
    return camera_labels[k] if k < len(camera_labels) else f"CAM_{k}"


depth_arrays_gt = []
depth_arrays_pred = []
for i in range(n_depth_cams):
    gt_depth = data["render_cam_depth"][0, i].cpu().numpy()
    pred_depth = result_dict["render"]["depth"][0, i].cpu().numpy()
    if rescale:
        gt_depth = _downscale_depth(gt_depth, factor)
        pred_depth = _downscale_depth(pred_depth, factor)
    depth_arrays_gt.append(gt_depth)
    depth_arrays_pred.append(pred_depth)

# Drop the trailing singleton channel -> (n_cams, H, W).
depth_gt = np.array(depth_arrays_gt).squeeze(-1)
depth_pred = np.array(depth_arrays_pred).squeeze(-1)

# Shared colour range: without it each figure auto-scales to its own min/max, and the
# same colour would mean different depths in the GT and prediction plots.
depth_min = float(min(depth_gt.min(), depth_pred.min()))
depth_max = float(max(depth_gt.max(), depth_pred.max()))

fig_gt = px.imshow(
    depth_gt,
    facet_col=0,
    facet_col_wrap=3,
    zmin=depth_min,
    zmax=depth_max,
    title="Ground Truth: Plot all depths",
)
fig_gt.for_each_annotation(lambda a: a.update(text=_depth_facet_title(int(a.text.split("=")[-1]))))

fig_pred = px.imshow(
    depth_pred,
    facet_col=0,
    facet_col_wrap=3,
    zmin=depth_min,
    zmax=depth_max,
    title="Predictions: Plot all depths",
)
fig_pred.for_each_annotation(lambda a: a.update(text=_depth_facet_title(int(a.text.split("=")[-1]))))

fig_gt.show()
fig_pred.show()

#### BEV

In [None]:
# Facet titles for the two panels (GT first, prediction second).
bev_labels = [
    "Ground Truth BEV",
    "Predicted BEV",
]

# BEV-specific display downscaling: each dimension is divided by bev_factor
# (1 = full resolution). Separate names so this cell does not overwrite the
# `factor` / `rescale` used by the camera cells.
bev_factor = 1
bev_rescale = True

# Class-colour table on the CPU (the BEV maps are moved to the CPU below).
bev_palette = torch.from_numpy(colors)


def _bev_scores_to_rgb(bev_scores):
    """Colour a BEV class-score map (first two dims = H, W; last = classes) by argmax.

    Returns an (H, W, 3) float numpy array in [0, 1].
    """
    sem_cls = bev_scores.argmax(-1).flatten().to(torch.int32)
    if cfg.dataset_tag == "occ3d":
        # Draw class 17 with palette entry 20. NOTE(review): presumably 17 is Occ3D's
        # empty class and entry 20 a neutral background colour — confirm against `colors`.
        sem_cls[sem_cls == 17] = 20
    rgb = torch.index_select(bev_palette, 0, sem_cls)[:, :3]
    return (rgb.reshape(*bev_scores.shape[:2], 3) / 255).numpy()


def _downscale_bev_labels(img):
    """Optionally shrink a colour-coded BEV map with nearest-neighbour sampling,
    so that only palette colours appear (no blended class boundaries)."""
    if not bev_rescale:
        return img
    return resize(
        img,
        (img.shape[0] // bev_factor, img.shape[1] // bev_factor),
        order=0,
        anti_aliasing=False,
    )


bev_arrays_gt = [_downscale_bev_labels(_bev_scores_to_rgb(data["render_bev"][0].cpu()))]
bev_arrays_pred = [_downscale_bev_labels(_bev_scores_to_rgb(result_dict["render"]["bev"][0].cpu()))]

# GT and prediction side by side as two facets of one figure.
fig_gt = px.imshow(
    np.concatenate([np.array(bev_arrays_gt), np.array(bev_arrays_pred)]),
    facet_col=0,
    title="Ground Truth / Predicted BEV",
)
fig_gt.for_each_annotation(
    lambda a: a.update(text=bev_labels[int(a.text.split("=")[-1])])
)
fig_gt.show()

#### BEV depth

In [None]:
# plotly.express, numpy and torch are already imported in the first cell; resize is
# re-imported so this cell also runs without the camera-plot constants cell.
from skimage.transform import resize

# Facet titles for the two panels (GT first, prediction second).
bev_labels = [
    "Ground Truth BEV",
    "Predicted BEV",
]

# BEV-depth display downscaling: each dimension is divided by bev_depth_factor
# (1 = full resolution). Separate names so this cell does not overwrite the
# `factor` / `rescale` used by the camera cells.
bev_depth_factor = 1
bev_depth_rescale = True


def _downscale_bev_depth(depth_map):
    """Optionally shrink a BEV depth map (anti-aliased) by bev_depth_factor per dimension."""
    if not bev_depth_rescale:
        return depth_map
    return resize(
        depth_map,
        (depth_map.shape[0] // bev_depth_factor, depth_map.shape[1] // bev_depth_factor),
        anti_aliasing=True,
    )


# GT from data["render_bev_depth"], prediction from result_dict["render"]["bev_depth"].
# Raw depth values are plotted; no normalisation is applied.
render_bev_gt = data["render_bev_depth"][0].cpu().numpy()
render_bev_pred = result_dict["render"]["bev_depth"][0].cpu().numpy()

bev_arrays_gt = [_downscale_bev_depth(render_bev_gt)]
bev_arrays_pred = [_downscale_bev_depth(render_bev_pred)]

# Both panels are facets of one figure, so (presumably) they share a single colour
# axis and are directly comparable. squeeze(-1) drops the trailing channel dim.
fig = px.imshow(
    np.concatenate([np.array(bev_arrays_gt), np.array(bev_arrays_pred)], axis=0).squeeze(-1),
    facet_col=0,
    facet_col_wrap=2,
    title="Ground Truth / Predicted BEV",
)
fig.for_each_annotation(
    lambda a: a.update(text=bev_labels[int(a.text.split("=")[-1])])
)

fig.show()

# 3D

## Render voxels

### GT

Get data

In [14]:
# Class id of empty space per dataset; voxels of this class are not drawn.
EMPTY_CLASS_BY_DATASET = {"occ3d": 17, "surroundocc": 0}
# Label value excluded from the plot (presumably "ignore / unobserved").
IGNORE_LABEL = 255

if cfg.dataset_tag not in EMPTY_CLASS_BY_DATASET:
    raise ValueError(f"No empty-class id configured for dataset_tag={cfg.dataset_tag!r}")
empty_cls = EMPTY_CLASS_BY_DATASET[cfg.dataset_tag]

# flatten(0, 2) can return a view of data["occ_xyz"] / data["occ_label"]. Voxels are
# therefore selected with a mask instead of overwriting labels in place, which leaves
# `data` untouched and keeps this cell safe to re-run. Ignore voxels are dropped for
# every dataset rather than being drawn with class 0's colour.
gt_xyz_all = data["occ_xyz"][0].flatten(0, 2)
gt_labels_all = data["occ_label"][0].flatten(0, 2)
keep = (gt_labels_all != empty_cls) & (gt_labels_all != IGNORE_LABEL)

gt_xyz = gt_xyz_all[keep]
gt_colors = gt_labels_all[keep]
map_colors = torch.index_select(
    torch.from_numpy(colors).to("cuda"), 0, gt_colors.to(torch.int32)
)

Gather camera poses (sensor cameras + BEV camera)

In [15]:
# Append the BEV camera's transforms/intrinsics after the sensor cameras (dim 1).
render_meta = data["render"]
ref2cams = torch.cat(
    (render_meta["ref2cams"], render_meta["ref2cams_bev"]), dim=1
).clone()
intrins = torch.cat((render_meta["intrins"], render_meta["intrins_bev"]), dim=1)

# Inverting reference->camera gives camera->reference; the translation column of
# the first sample's matrices is each camera's centre.
c2w = torch.linalg.inv(ref2cams)
pts = c2w[0, :, :3, 3]

Render

In [16]:
depth_coeff: float = 10.0  # Ray depth passed to get_3d_corners when drawing the camera frustums

In [None]:
# Plot the points
fig = go.Figure()
Ncam = intrins.shape[1]
colors_cam = ["blue", "steelblue", "cyan", "royalblue", "dodgerblue", "lightgreen", "red"]

# Plot coarse points
for _ in range(Ncam-1):
    fig.add_trace(
        go.Scatter3d(
            x=pts[_ : _ + 1, 0].cpu(),
            y=pts[_ : _ + 1, 1].cpu(),
            z=pts[_ : _ + 1, 2].cpu(),
            mode="markers",
            marker=dict(size=3, color=colors_cam[_]),
            showlegend=False  # Hide legend for each trace
        )
    )

    # Get 3D corners in the world space
    H, W = 900, 1600
    corners_3d = get_3d_corners(
        H, W, intrins[0, _, :3, :3].cpu(), c2w[0, _].cpu(), depth_coeff
    )
    fig.add_trace(
        go.Scatter3d(
            x=corners_3d[:, 0],
            y=corners_3d[:, 1],
            z=corners_3d[:, 2],
            mode="markers",
            marker=dict(size=3, color=colors_cam[_]),
            showlegend=False  # Hide legend for each trace
        )
    )

    # Create lines from camera center to each of the 4 borders
    for i in range(4):
        fig.add_trace(
            go.Scatter3d(
                x=np.concatenate([pts[_ : _ + 1, 0].cpu(), corners_3d[i : i + 1, 0]]),
                y=np.concatenate([pts[_ : _ + 1, 1].cpu(), corners_3d[i : i + 1, 1]]),
                z=np.concatenate([pts[_ : _ + 1, 2].cpu(), corners_3d[i : i + 1, 2]]),
                mode="lines",
                line=dict(width=3, color=colors_cam[_]),
                showlegend=False,  # Hide legend for each trace
            )
        )

# Plot 3D occ
fig.add_trace(
    go.Scatter3d(
        x=gt_xyz[..., 0].cpu(),
        y=gt_xyz[..., 1].cpu(),
        z=gt_xyz[..., 2].cpu(),
        mode="markers",
        marker=dict(size=1, color=map_colors.cpu()),
        showlegend=False,  # Hide legend for this trace too
    )
)

# Set plot layout
fig.update_layout(
    scene=dict(
        xaxis=dict(visible=False),  # Hide x-axis
        yaxis=dict(visible=False),  # Hide y-axis
        zaxis=dict(visible=False),  # Hide z-axis
        bgcolor="white",  # Set the scene background to white
    ),
    paper_bgcolor="white",  # Set the entire figure's background to white
    plot_bgcolor="white",  # Set the plotting area background to white
)
fig["layout"]["scene"]["aspectmode"] = "data"

# Show the plot
fig.show()


### Pred

In [None]:
# Empty-space class id, derived here so this cell does not depend on the GT cell.
empty_cls = {"occ3d": 17, "surroundocc": 0}[cfg.dataset_tag]

# Predicted class per voxel: argmax over dim 0 of the last prediction for sample 0.
# flatten() makes the labels 1-D so they align with occ_xyz flattened over its three
# spatial dims. NOTE(review): assumes both use the same voxel ordering — confirm.
pred_segm = result_dict["pred_occ"][-1][0].argmax(0).flatten().cpu()
pred_xyz_all = data["occ_xyz"][0].flatten(0, 2).cpu()

# Keep occupied voxels first, then colour only those (cheaper than colouring every voxel).
occupied = pred_segm != empty_cls
pred_segm = pred_segm[occupied]
pred_xyz = pred_xyz_all[occupied]
pred_map_colors = torch.index_select(torch.from_numpy(colors), 0, pred_segm.to(torch.int32))

# Predicted occupied voxels, coloured by class.
fig = go.Figure()
fig.add_trace(
    go.Scatter3d(
        x=pred_xyz[..., 0],
        y=pred_xyz[..., 1],
        z=pred_xyz[..., 2],
        mode="markers",
        marker=dict(size=1, color=pred_map_colors),
        showlegend=False,
    )
)

# Plain white scene without axes; equal aspect ratio so geometry is not distorted.
fig.update_layout(
    scene=dict(
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        zaxis=dict(visible=False),
        bgcolor="white",
    ),
    paper_bgcolor="white",
    plot_bgcolor="white",
)
fig["layout"]["scene"]["aspectmode"] = "data"

fig.show()