In [None]:
# IPython: automatically re-import edited modules (e.g. salad.*) before each cell runs.
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
from salad.models.phase1 import Phase1Model
from salad.models.phase2 import Phase2Model
from salad.models.language_phase1 import LangPhase1Model
from salad.models.language_phase2 import LangPhase2Model
from omegaconf import OmegaConf
import trimesh
import hydra
from salad.model_components.network import UnCondDiffNetwork, CondDiffNetwork
from salad.utils import visutil, imageutil
from salad.utils.spaghetti_util import *
from typing import Literal
from pytorch_lightning import seed_everything
# Preferred GPU for this notebook is cuda:2. Fall back to the default CUDA
# device, then CPU, so the notebook still runs on machines with fewer than
# three GPUs (moving a model to a missing ordinal raises an error).
if torch.cuda.device_count() > 2:
    device = "cuda:2"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def load_model(
    category: Literal["airplane", "chair", "table", "car"],
    model_class: Literal["phase1", "phase2"],
    device,
    ckpt_path=None,
):
    """Instantiate a SALAD model from its Hydra config and load trained weights.

    The model is built with ``hydra.utils.instantiate`` from
    ``../checkpoints/{category}/{model_class}/hparams.yaml``. Its weights are
    restored from the checkpoint's ``"state_dict"`` entry, it is put in eval
    mode with all parameters frozen, and it is moved to ``device``.

    Args:
        category: Shape category; selects the hparams directory.
        model_class: Model stage; selects the hparams directory.
        device: Target device for the returned model (e.g. ``"cuda:2"``).
        ckpt_path: Optional path to the checkpoint file. When omitted, a
            hard-coded car/phase1 training run is used. That default does NOT
            depend on ``category``/``model_class``, so pass an explicit path
            for any other combination.

    Returns:
        The frozen, eval-mode model on ``device``.
    """
    import warnings

    c = OmegaConf.load(f"../checkpoints/{category}/{model_class}/hparams.yaml")
    model = hydra.utils.instantiate(c)

    if ckpt_path is None:
        ckpt_path = "/auto/k2/ademirtas/codes/diffusion/salad/results/car/phase1/1110_015913/checkpoints/epoch=3382-train_loss=0.0000.ckpt"
        if (category, model_class) != ("car", "phase1"):
            # The historical default is a car/phase1 run; loading it for a
            # different request would silently give the wrong weights.
            warnings.warn(
                f"load_model({category!r}, {model_class!r}) is using the default "
                f"car/phase1 checkpoint; pass ckpt_path explicitly."
            )

    # Deserialize onto CPU: without map_location, tensors are restored to the
    # device they were saved from, which may not exist on this machine.
    # The model is moved to `device` below anyway.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(ckpt["state_dict"])
    model.eval()
    # Inference only: freeze every parameter.
    model.requires_grad_(False)
    model = model.to(device)
    return model

In [None]:
# Load the frozen car phase-1 model onto the notebook's device.
phase1_model = load_model(category="car", model_class="phase1", device=device)

In [None]:
# Fix all RNG seeds so the sampled poses are reproducible.
seed_everything(seed=63)

# Phase-1 sampling: draw part extrinsics from the phase-1 model
# (the argument 4 is presumably the number of samples — confirm).
extrinsics = phase1_model.sampling_poses(4)

In [None]:
# Inspect the shape of the sampled extrinsics (shown as this cell's output).
extrinsics.shape

In [None]:
# Alias the sampled extrinsics as `poses` for the visualization cell.
poses = extrinsics

In [None]:
from primitives import mesh_cuboid, mesh_cylinder

In [None]:
# Get the part poses of one sampled car (index into the sampled batch).
idx = 0
pose = poses[idx].detach().cpu().numpy()

# Split each part's parameter row into rotation, translation and scale:
# columns [0:4] quaternion, [4:7] translation, [7:10] scale.
quaternions = pose[:, :4]
translations = pose[:, 4:7]
scales = pose[:, 7:10]

# Parts drawn as each primitive type.
# NOTE(review): assumes part 0 is the body (cuboid) and parts 1-4 are the
# wheels (cylinders) — confirm against the phase-1 part ordering.
cuboid_ids = [0]
cylinder_ids = range(1, 5)

# NOTE: the scale values are half-lengths, so full extents need a factor of 2
# when converting to trimesh primitives. The cylinder's first argument
# (presumably its radius) is already a half-length and is passed unscaled.
cuboids = [mesh_cuboid(scales[i] * 2, translations[i], quaternions[i]) for i in cuboid_ids]
cylinders = [
    mesh_cylinder(scales[i, 0], scales[i, 2] * 2, translations[i], quaternions[i])
    for i in cylinder_ids
]

# Visualize them together with a coordinate-axis marker for reference.
trimesh.Scene([trimesh.creation.axis()] + cylinders + cuboids).show()