In [None]:
%load_ext autoreload
%autoreload 2
import os
import os.path as osp
import sys
import json
import numpy as np
import torch
import pyvista as pv
import accelerate
from tqdm import tqdm
from easydict import EasyDict as edict
from termcolor import colored

sys.path.append(os.path.join(os.getcwd(), '..'))

## Environment Setup

In [None]:
# Local Jupyter session: use pyvista's interactive trame widget for rendering.
jupyter_backend: str = "trame"

In [None]:
# remote jupyter
def is_vscode() -> bool:
    """Return True when the kernel appears to be running inside VS Code.

    Detection is based solely on the presence of the ``VSCODE_CWD`` environment
    variable.
    """
    return "VSCODE_CWD" in os.environ

# Pick the pyvista Jupyter backend: VS Code notebooks cannot host the interactive
# trame widget, so fall back to static images there.
running_in_vscode = is_vscode()
if running_in_vscode:
    print(colored("Vscode jupyter DOESN'T support pyvista interative mode", "yellow", force_color=True))
    jupyter_backend = "static"
else:
    jupyter_backend = "trame"

# Headless (remote) Linux setup.
# NOTE(review): DISPLAY is overwritten unconditionally and presumably assumes a virtual
# X server (e.g. Xvfb) on :99 — skip this cell on a machine with a real display.
os.environ["DISPLAY"] = ":99.0"
os.environ["PYVISTA_OFF_SCREEN"] = "true"
# pyvista reads PYVISTA_OFF_SCREEN only when it is imported, and it was imported at
# the top of the notebook, so the environment variable alone has no effect here.
# Set the module-level flag directly so off-screen rendering is actually enabled.
pv.OFF_SCREEN = True
# NOTE: vscode remote jupyter does not work with pyvista
if not running_in_vscode:
    pv.global_theme.trame.server_proxy_enabled = True

In [None]:
# One accelerate.Accelerator for the whole session: it is passed to the data loaders,
# prepares the models, restores the checkpoint and is handed to the sampler below.
accelerator = accelerate.Accelerator()
# Device chosen by the accelerator; all tensors and models are moved here.
device = accelerator.device
device

Helper functions to write and load point clouds as PLY files (optional)

In [None]:
from plyfile import PlyData, PlyElement

def write_ply(points, save_path):
    """Save a colored point cloud as a PLY file.

    Args:
        points: array-like of shape (N, 6) or (N, 7). Columns 0-2 are written as
            the ``x``, ``y``, ``z`` vertex properties and columns 3-5 as ``red``,
            ``green``, ``blue``. A seventh column is accepted but NOT written.
            All six properties are stored as float32.
        save_path: destination path (str or path-like); must end with ".ply".

    Raises:
        AssertionError: if ``points`` is not 2-D with 6 or 7 columns, or if
            ``save_path`` does not end with ".ply".
    """
    points = np.asarray(points)
    # Check ndim first so a 1-D input fails the assertion instead of raising IndexError.
    assert points.ndim == 2 and points.shape[1] in (6, 7), "points.shape[1] should be 6 or 7"
    save_path = str(save_path)
    assert save_path.endswith(".ply"), "save_name should end with '.ply'"
    # Fill a structured array one column at a time instead of building a Python tuple
    # per point. The float32 casting is identical, but the work is vectorized.
    property_names = ("x", "y", "z", "red", "green", "blue")
    vertex = np.empty(points.shape[0], dtype=[(name, "f4") for name in property_names])
    for column, name in enumerate(property_names):
        vertex[name] = points[:, column]
    data = PlyElement.describe(vertex, "vertex", comments=["vertices"])
    PlyData([data]).write(save_path)

def read_ply(save_path):
    """Load the colored vertices of a PLY file (e.g. one produced by ``write_ply``).

    Args:
        save_path: path of the ``.ply`` file to read.

    Returns:
        float32 array of shape (num_vertices, 6) whose columns are the vertex
        properties x, y, z, red, green, blue, in that order.
    """
    property_names = ("x", "y", "z", "red", "green", "blue")
    with open(save_path, "rb") as ply_file:
        vertex_data = PlyData.read(ply_file)["vertex"].data
        # Cast each property to float32, then lay the columns side by side.
        columns = [np.asarray(vertex_data[name], dtype=np.float32) for name in property_names]
        return np.stack(columns, axis=1)

## Prepare data & models

In [None]:
# Locations of the trained run to inspect (placeholders: edit them for your setup).
PROJECT_TOP_DIR = "your/project/dir"
PROJECT_DIR = osp.join(PROJECT_TOP_DIR, "dir/of/model")
CHECKPOINT_DIR = osp.join(PROJECT_DIR, "checkpoints/folder/name")

# Restore the arguments saved with the run as an attribute-accessible dict.
config_path = osp.join(PROJECT_DIR, "config.json.txt")
with open(config_path, "r") as config_file:
    args = edict(json.load(config_file))

### Load data

In [None]:
from data.referit3d.in_out.neural_net_oriented import (
    compute_auxiliary_data,
    load_referential_data,
    load_scan_related_data,
    trim_scans_per_referit3d_data_,
)
# Load the scans and the referential annotations used for inference. Both paths are
# set explicitly here rather than taken from the training config.
SCANNET_PKL_FILE = "../../datasets/scannet/instruct/global_small.pkl"
REFERIT_CSV_FILE = "../../datasets/nr3d/nr3d_generative_20230825_final.csv"
all_scans_in_dict, scans_split, class_to_idx = load_scan_related_data(SCANNET_PKL_FILE)
# Use the notebook-level CSV path, consistent with SCANNET_PKL_FILE above. Before this
# change, REFERIT_CSV_FILE was defined but ignored in favor of args.referit3D_file.
referit_data = load_referential_data(args, REFERIT_CSV_FILE, scans_split)
# Prepare data & compute auxiliary meta-information.
all_scans_in_dict = trim_scans_per_referit3d_data_(referit_data, all_scans_in_dict)
mean_rgb = compute_auxiliary_data(referit_data, all_scans_in_dict)

In [None]:
from transformers import BertTokenizer
# Tokenizer loaded from the pretrained BERT path stored in the run config.
tokenizer = BertTokenizer.from_pretrained(args.bert_pretrain_path)

# Listener class vocabulary: the "pad" entry is a placeholder class, hence the -1.
n_classes = len(class_to_idx) - 1
pad_idx = class_to_idx["pad"]

# Tokenize every class name (padded to a common length) for object-type classification.
class_name_list = [class_name for class_name in class_to_idx]
class_name_tokens = tokenizer(class_name_list, return_tensors="pt", padding=True).to(device)

In [None]:
from data.referit3d.datasets import make_data_loaders
# Build the data loaders from the restored run arguments; only the "test" split is used below.
loader_kwargs = dict(
    args=args,
    accelerator=accelerator,
    referit_data=referit_data,
    class_to_idx=class_to_idx,
    scans=all_scans_in_dict,
    mean_rgb=mean_rgb,
    tokenizer=tokenizer,
)
data_loaders = make_data_loaders(**loader_kwargs)

In [None]:
from scripts.train_utils import move_batch_to_device_
# Pick one random test sample. No seed is set in this notebook, so the choice changes on every run.
test_dataset = data_loaders["test"].dataset
num_test_samples = len(test_dataset)
rand_idx = np.random.randint(num_test_samples)

In [None]:
# Fetch the selected test sample and show its instruction.
rand_data = test_dataset[rand_idx]
print(f"Original text: {rand_data['text']}")

# Scene and target object for this sample; every other object in the scan is context.
reference_data = test_dataset.get_reference_data(rand_idx)
rand_data_scan, rand_data_target_objs = reference_data[:2]
rand_data_3d_objs = rand_data_scan.three_d_objects.copy()
rand_data_3d_objs.remove(rand_data_target_objs)

# Optional: override the instruction with a custom prompt before collating.
# rand_data["text"] = "Create a light color chair in the center of the backpack and the door."
# rand_data["tokens"] = test_dataset.tokenizer(rand_data["text"], max_length=test_dataset.max_seq_len, truncation=True, padding=False)

# Collate into a one-sample batch and move it to the device.
collate_fn = data_loaders["test"].collate_fn
batch = move_batch_to_device_(collate_fn([rand_data]), device)

### Load models with checkpoints

In [None]:
from models.referit3d_model.referit3d_net import ReferIt3DNet_transformer
from models.point_e_model.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
from models.point_e_model.diffusion.sampler import PointCloudSampler
from models.point_e_model.models.configs import MODEL_CONFIGS, model_from_config

# referit3d listener model
mvt3dvg = ReferIt3DNet_transformer(args, n_classes, class_name_tokens, ignore_index=pad_idx)
# Point-E model. Copy the registry entry before overriding keys: assigning into
# MODEL_CONFIGS[...] directly would mutate the shared module-level dict and leak these
# overrides into every later lookup (it persists across cell re-runs and autoreload).
point_e_config = dict(MODEL_CONFIGS[args.point_e_model])
point_e_config["cache_dir"] = osp.join(PROJECT_TOP_DIR, "cache", "point_e_model")
point_e_config["n_ctx"] = args.points_per_object
point_e = model_from_config(point_e_config, device)
point_e_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[args.point_e_model])
# Move both models to the device and switch them to eval mode.
mvt3dvg = mvt3dvg.to(device).eval()
point_e = point_e.to(device).eval()

In [None]:
# Wrap the listener the same way the run did before restoring the checkpoint: it is
# torch.compile'd only when the stored config has mode == "train".
# NOTE(review): presumably this keeps parameter names aligned with the saved state — confirm.
mvt3dvg = torch.compile(mvt3dvg) if args.mode == "train" else mvt3dvg
mvt3dvg, point_e = accelerator.prepare(mvt3dvg, point_e)
accelerator.load_state(CHECKPOINT_DIR)

In [None]:
from models.point_e_model.diffusion.sampler import PointCloudSampler

# Point-E sampler. Per-model settings are lists aligned with `models` (one model here).
aux_channels = ["R", "G", "B"]
sampling_options = dict(
    guidance_scale=[3.0],
    use_karras=[True],
    karras_steps=[64],
    sigma_min=[1e-3],
    sigma_max=[120],
    s_churn=[3],
)
sampler = PointCloudSampler(
    device=device,
    models=[point_e],
    diffusions=[point_e_diffusion],
    num_points=[args.points_per_object],
    aux_channels=aux_channels,
    **sampling_options,
)

## Visualization

### Inference

In [None]:
with torch.no_grad():
    # Listener forward pass: context embeddings for generation plus the other predictions.
    ctx_embeds, LOSS, CLASS_LOGITS, LANG_LOGITS, LOCATE_PREDS, pred_xyz = mvt3dvg(batch)

    prompts = batch["text"]
    # Duplicate the conditioning along the batch dimension for guidance scaling.
    ctx_embeds = torch.cat([ctx_embeds] * 2, dim=0)
    samples_it = sampler.sample_batch_progressive(
        batch_size=len(prompts),
        ctx_embeds=ctx_embeds,
        model_kwargs={"texts": prompts},
        accelerator=accelerator,
    )
    # Drain the progressive sampler; only the final denoising step is kept.
    for last_pcs in samples_it:
        pass
    # Swap the last two axes so points come before channels: (B, P, C).
    last_pcs = last_pcs.transpose(1, 2)

### Post-processing of the generated point cloud (axis-norm models) ###
**Only for axis-norm models.** If your model was not trained with the `--axis-norm` option, skip this step.

In [None]:
# For axis_norm model: turn the binned position predictions into TOPK candidate object
# centers (pred_topk_xyz) and one radius per candidate (pred_radius).
TOPK = 10
pred_xy, pred_z, pred_radius = pred_xyz
# x and y share one flattened grid of bins (x = flat % bins, y = flat // bins);
# keep the TOPK highest-scoring cells.
pred_xy_topk_bins = pred_xy.topk(TOPK, dim=-1)[1]  # (B, TOPK)
# z uses only its single best bin, repeated for every xy candidate.
pred_z_topk_bins = pred_z.argmax(dim=-1, keepdim=True).repeat(1, TOPK)  # (B, TOPK)
pred_x_topk_bins = pred_xy_topk_bins % args.axis_norm_bins  # (B, TOPK)
pred_y_topk_bins = pred_xy_topk_bins // args.axis_norm_bins  # (B, TOPK)
pred_bins = torch.stack(
    (pred_x_topk_bins, pred_y_topk_bins, pred_z_topk_bins), dim=-1
)  # (B, TOPK, 3)
# Bin index -> bin center, as a fraction in [0, 1] of the [min, max] range.
pred_bins = (pred_bins.float() + 0.5) / args.axis_norm_bins  # (B, TOPK, 3)
(
    min_box_center_axis_norm,  # (B, 3)
    max_box_center_axis_norm,  # (B, 3)
) = (
    batch["min_box_center_before_axis_norm"],
    batch["max_box_center_before_axis_norm"],
)  # all range from [-1, 1]
pred_topk_xyz = (
    min_box_center_axis_norm[:, None]
    + (max_box_center_axis_norm - min_box_center_axis_norm)[:, None] * pred_bins
)  # (B, TOPK, 3)
# Repeat the radius TOPK times so it lines up with pred_topk_xyz (a hard-coded 5 did
# not match TOPK). Assumes pred_radius is (B, 1), giving (B, TOPK, 1).
pred_radius = pred_radius.unsqueeze(-1).permute(0, 2, 1).repeat(1, TOPK, 1)

The axis-norm model predicts `TOPK` candidate positions for the generated point cloud; choose which one to use by setting `object_idx` in the following code block.

In [None]:
# Choose this or the next block.
# object_idx selects which of the TOPK candidate positions to place the object at;
# the valid range is 0 .. TOPK - 1.
object_idx = 0
assert 0 <= object_idx < TOPK, f"object_idx must be in [0, {TOPK - 1}], got {object_idx}"

vis_pc = last_pcs.squeeze(0)  # (P, 6); assumes a single-sample batch

pos = vis_pc[:, :3]
aux = vis_pc[:, 3:]

pred_box_center, pred_box_max_dist = pred_topk_xyz[:, object_idx, :], pred_radius[:, 0, :]

# Scale the normalized object into the scene at the chosen position and snap colors
# to integers in [0, 255].
coords = pos * pred_box_max_dist + pred_box_center
colors = aux.clamp(0, 255).round()  # (P, 3 or 4)
vis_pc = torch.cat((coords, colors), dim=-1)  # (P, 6)
vis_pc = vis_pc.unsqueeze(0)  # (1, P, 6)
vis_pc = vis_pc.cpu().numpy()

### Post-processing of the generated point cloud (non-axis-norm models) ###

In [None]:
# Non-axis-norm models: place the object using the first sample of LOCATE_PREDS
# (values 0-2 = box center, value 3 = max distance used as the scale).
vis_pc = last_pcs.squeeze(0)  # (P, 6)
pos, aux = vis_pc[:, :3], vis_pc[:, 3:]

pred_box_center = LOCATE_PREDS[0, :3]
pred_box_max_dist = LOCATE_PREDS[0, 3]

# Map the normalized object coordinates into the scene and snap colors to [0, 255].
coords = pos * pred_box_max_dist + pred_box_center
colors = aux.clamp(0, 255).round()  # (P, 3 or 4)
vis_pc = torch.cat((coords, colors), dim=-1)[None].cpu().numpy()  # (1, P, 6)

### Visualization

The following cell prints the scene id and the instruction text.

In [None]:
# Summarize the selected sample.
num_padding_objects = batch['ctx_key_padding_mask'].sum().item()
print(f"The scene id is: {batch['scan_id']}")
print(f"The instructed text is: {batch['text']}")
print(f"The number of padding objects is: {num_padding_objects}")

In [None]:
# Render the scene: the reference (target) object inside a blue box plus all context objects.
plotter = pv.Plotter()
plotter.window_size = (800, 600)
# saved_cpos is first assigned in a later cell, so look it up defensively: a bare
# `saved_cpos` raises NameError on a fresh kernel (Restart & Run All).
_cpos = globals().get("saved_cpos")
if _cpos:
    plotter.camera_position = _cpos

# add generated objects (uncomment to overlay the generated object in a red box)
# obj = vis_pc[0]
# mesh = pv.PolyData(obj[:, :3]).delaunay_3d(alpha=0.005)
# color = obj[:, 3:6].astype(np.uint8)
# bound = mesh.bounds
# plotter.add_box_widget(callback=None, bounds=bound, factor=1.25, outline_translation=False, rotation_enabled=False, color="red")
# plotter.add_mesh(mesh, scalars=color, rgb=True, preference='point')

# add reference objects
mesh = pv.PolyData(rand_data_target_objs.pc).delaunay_3d(alpha=1e-3)
color = (rand_data_target_objs.color * 255).astype(np.uint8)
bound = mesh.bounds
plotter.add_box_widget(callback=None, bounds=bound, factor=1.25, outline_translation=False, rotation_enabled=False, color="blue")
plotter.add_mesh(mesh, scalars=color, rgb=True, preference='point')

# add context
for obj in rand_data_3d_objs:
    mesh = pv.PolyData(obj.pc).delaunay_3d(alpha=1e-3)
    color = (obj.color * 255).astype(np.uint8)
    plotter.add_mesh(mesh, scalars=color, rgb=True, preference='point')

plotter.show(jupyter_backend=jupyter_backend)

In [None]:
# Remember the camera pose from the viewer above; later plotting cells apply it when saved_cpos is truthy.
# NOTE(review): reading camera_position after show() assumes the plotter is still open for the chosen backend — confirm.
saved_cpos = plotter.camera_position

In [None]:
# Export the current view as SVG. Other suffixes used for this figure:
# plotter.save_graphic(f"{batch['stimulus_id'][0]}_wo.svg")
# plotter.save_graphic(f"{batch['stimulus_id'][0]}_{object_idx}.svg")
svg_name = f"{batch['stimulus_id'][0]}_ref.svg"
plotter.save_graphic(svg_name)

In [None]:
# Forget the stored camera pose so subsequent plots keep pyvista's default view.
saved_cpos = None

**Automatically render and save every top-k candidate position**

In [None]:
# For each of the TOPK candidate positions (axis-norm models), generate an object,
# place it at that position among the context objects and save the view as SVG.
for object_idx in tqdm(range(TOPK)):
    # generate
    with torch.no_grad():
        ctx_embeds, LOSS, CLASS_LOGITS, LANG_LOGITS, LOCATE_PREDS, pred_xyz = mvt3dvg(batch)

        prompts = batch["text"]
        # stack twice for guided scale
        ctx_embeds = torch.cat((ctx_embeds, ctx_embeds), dim=0)
        samples_it = sampler.sample_batch_progressive(
            batch_size=len(prompts),
            ctx_embeds=ctx_embeds,
            model_kwargs=dict(texts=prompts),
            accelerator=accelerator,
        )
        # get the last timestep prediction
        for last_pcs in samples_it:
            pass
        last_pcs = last_pcs.permute(0, 2, 1)
    # locate: move the normalized object to the object_idx-th candidate position
    vis_pc = last_pcs.squeeze(0)  # (P, 6)

    pos = vis_pc[:, :3]
    aux = vis_pc[:, 3:]

    pred_box_center, pred_box_max_dist = pred_topk_xyz[:, object_idx, :], pred_radius[:, 0, :]

    coords = pos * pred_box_max_dist + pred_box_center
    colors = aux.clamp(0, 255).round()  # (P, 3 or 4)
    vis_pc = torch.cat((coords, colors), dim=-1)  # (P, 6)
    vis_pc = vis_pc.unsqueeze(0)  # (1, P, 6)
    vis_pc = vis_pc.cpu().numpy()
    # Create a pyvista plotter for this candidate
    plotter = pv.Plotter()
    plotter.window_size = (800, 600)
    # Look up saved_cpos defensively so this cell also works if the cells that define it were skipped.
    _cpos = globals().get("saved_cpos")
    if _cpos:
        plotter.camera_position = _cpos

    # add generated objects
    obj = vis_pc[0]
    mesh = pv.PolyData(obj[:, :3]).delaunay_3d(alpha=0.005)
    color = obj[:, 3:6].astype(np.uint8)
    bound = mesh.bounds
    plotter.add_box_widget(callback=None, bounds=bound, factor=1.25, outline_translation=False, rotation_enabled=False, color="red")
    plotter.add_mesh(mesh, scalars=color, rgb=True, preference='point')

    # add context
    for obj in rand_data_3d_objs:
        mesh = pv.PolyData(obj.pc).delaunay_3d(alpha=1e-3)
        color = (obj.color * 255).astype(np.uint8)
        plotter.add_mesh(mesh, scalars=color, rgb=True, preference='point')

    # save, then release the render window; otherwise all TOPK plotters stay alive.
    plotter.save_graphic(f"{batch['stimulus_id'][0]}_{object_idx}.svg")
    plotter.close()

## Shape Diversity

In [None]:
# Summarize the sample used for the shape-diversity check.
num_padding_objects = batch['ctx_key_padding_mask'].sum().item()
print(f"The scene id is: {batch['scan_id']}")
print(f"The instructed text is: {batch['text']}")
print(f"The number of padding objects is: {num_padding_objects}")

In [None]:
# Shape diversity: sample a new object for the same batch and view it on its own,
# in normalized object coordinates (no placement in the scene).
with torch.no_grad():
    ctx_embeds, LOSS, CLASS_LOGITS, LANG_LOGITS, LOCATE_PREDS, pred_xyz = mvt3dvg(batch)

    prompts = batch["text"]
    # stack twice for guided scale
    ctx_embeds = torch.cat((ctx_embeds, ctx_embeds), dim=0)
    samples_it = sampler.sample_batch_progressive(
        batch_size=len(prompts),
        ctx_embeds=ctx_embeds,
        model_kwargs=dict(texts=prompts),
        accelerator=accelerator,
    )
    # get the last timestep prediction
    for last_pcs in samples_it:
        pass
    last_pcs = last_pcs.permute(0, 2, 1)
# Process the generated point cloud (coordinates are left untransformed)
vis_pc = last_pcs.squeeze(0)  # (P, 6)
pos = vis_pc[:, :3]
aux = vis_pc[:, 3:]
coords = pos
colors = aux.clamp(0, 255).round()  # (P, 3 or 4)
vis_pc = torch.cat((coords, colors), dim=-1)  # (P, 6)
vis_pc = vis_pc.unsqueeze(0)  # (1, P, 6)
vis_pc = vis_pc.cpu().numpy()

# Create a pyvista plotter for the generated object only
plotter = pv.Plotter()
plotter.window_size = (800, 600)
# Look up saved_cpos defensively so this section also runs if earlier cells were skipped.
_cpos = globals().get("saved_cpos")
if _cpos:
    plotter.camera_position = _cpos

# add generated objects
obj = vis_pc[0]
mesh = pv.PolyData(obj[:, :3]).delaunay_3d(alpha=0.005)
color = obj[:, 3:6].astype(np.uint8)
plotter.add_mesh(mesh, scalars=color, rgb=True, preference='point')
plotter.show(jupyter_backend=jupyter_backend)

In [None]:
# Remember the camera pose of the shape-diversity view so the next render reuses it.
# NOTE(review): assumes the plotter is still open after show() for the chosen backend — confirm.
saved_cpos = plotter.camera_position

In [None]:
# Save the current shape-diversity view. The "_2" suffix is fixed, so change it for each
# new sample to avoid overwriting the previous SVG.
plotter.save_graphic(f"{batch['stimulus_id'][0]}_2.svg")

In [None]:
# Forget the stored camera pose so the next render starts from pyvista's default view.
saved_cpos = None