In [1]:
# Work from src/ so the project packages imported below (models, my_schedulers,
# utils, datasets) are found from the working directory, and the relative
# "../" paths used later resolve against the repository root.
%cd ../src

/home/ubuntu/SPVD_Lightning/src


In [2]:
from models.ddpm_unet_cattn import SPVUnet
import torch
import lightning as L
from models.g_spvd import GSPVD



In [3]:
# Backbone hyper-parameters. They must describe the same architecture the
# checkpoint was trained with, otherwise the strict load_state_dict further
# down will fail — TODO keep in sync with the training config.
model_args = dict(
    voxel_size=0.1,
    nfs=(32, 64, 128, 256),
    attn_chans=8,
    attn_start=3,
    cross_attn_chans=8,
    cross_attn_start=2,
    cross_attn_cond_dim=768,
)

# Wrap the SPVUnet backbone in GSPVD. lr / training_steps look like training-only
# placeholders here, since this notebook only runs inference — confirm.
model = GSPVD(model=SPVUnet(**model_args), lr=0.1, training_steps=1)

In [5]:
def print_model_parameters(model: torch.nn.Module, prefix="", max_depth=None):
    """Print the submodule tree of ``model`` with each submodule's parameter count.

    Every child module holding at least one parameter (its own or a descendant's)
    is printed as ``<name>: <ClassName>`` followed by ``Parameters: <count>``,
    then recursed into with two extra spaces of indentation. ``model`` itself is
    not printed, only its descendants.

    Args:
        model: Module whose children are listed.
        prefix: String prepended to every printed line; grows by two spaces per
            nesting level.
        max_depth: Maximum number of nesting levels to print. ``None`` (default)
            prints the whole tree; ``1`` prints only the direct children; ``0``
            or less prints nothing.
    """
    if max_depth is not None and max_depth <= 0:
        return
    next_depth = None if max_depth is None else max_depth - 1
    for name, child in model.named_children():
        # parameters() recurses, so this is the child's total incl. descendants.
        param_count = sum(p.numel() for p in child.parameters())
        if param_count == 0:
            continue  # Skip modules without parameters (e.g., ReLU)
        print(f"{prefix}{name}: {child.__class__.__name__}")
        print(f"{prefix}Parameters: {param_count:,}")
        print_model_parameters(child, prefix + "  ", next_depth)  # Recurse for nested modules

# Per-submodule parameter breakdown of the wrapped model.
print_model_parameters(model=model, prefix="")

model: SPVUnet
Parameters: 25,003,718
  conv_in: Conv3d
  Parameters: 2,592
  emb_mlp: Sequential
  Parameters: 20,800
    0: Sequential
    Parameters: 4,288
      0: BatchNorm1d
      Parameters: 64
      2: Linear
      Parameters: 4,224
    1: Sequential
    Parameters: 16,512
      1: Linear
      Parameters: 16,512
  downs: ModuleList
  Parameters: 4,768,512
    0: DownBlock
    Parameters: 71,936
      resnets: ModuleList
      Parameters: 63,744
        0: EmbResBlock
        Parameters: 63,744
          conv1: Sequential
          Parameters: 27,744
            0: BatchNorm
            Parameters: 64
            2: Conv3d
            Parameters: 27,680
          conv2: Sequential
          Parameters: 27,744
            0: BatchNorm
            Parameters: 64
            2: Conv3d
            Parameters: 27,680
          t_emb: TimeEmbeddingBlock
          Parameters: 8,256
            proj_mlp: Linear
            Parameters: 8,256
      down: Conv3d
      Parameters: 8,192
  

In [None]:
# Relative to src/ (the working directory set at the top of the notebook), like
# the "../data/ShapeNet" dataset path, instead of a machine-specific absolute path.
ckpt_path = '../checkpoints/GSPVD/all_categories_renders/checkpoints/epoch=399-step=111600.ckpt'
# map_location="cpu": deserialize tensors on the CPU regardless of the device they
# were saved from; the model is moved to the GPU in a separate step.
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
model.load_state_dict(ckpt['state_dict'])

In [None]:
# Inference setup: switch layers such as BatchNorm to evaluation mode, then move
# the weights to the default CUDA device (both calls act on the module in place).
model.eval()
model = model.cuda()

In [None]:
from my_schedulers.ddpm_scheduler import DDPMSparseScheduler
from utils.visualization import display_pointclouds_grid

# Diffusion sampler: betas from 1e-4 to 0.02 over 1000 steps (presumably the
# standard linear DDPM schedule — confirm in DDPMSparseScheduler).
ddpm_sched = DDPMSparseScheduler(
    beta_min=0.0001,
    beta_max=0.02,
    steps=1000,
)

In [None]:
from datasets.shapenet.shapenet_loader import ShapeNet

# Test split used for conditioning. 2048 is presumably the number of points per
# cloud, and load_renders=True should make each item carry the "render-features"
# entry read later — confirm both against ShapeNet's signature.
path = "../data/ShapeNet"
categories = ['skateboard']
te = ShapeNet(
    path,
    "test",
    2048,
    categories,
    load_renders=True,
)

In [None]:
import numpy as np

samples = 16
# Fixed seed so the same reference shapes are drawn on every Restart & Run All.
rng = np.random.default_rng(0)
# Draw distinct shapes (the original sampled with replacement, which easily
# repeats shapes in a small test split); only fall back to replacement when the
# split has fewer than `samples` items.
ref_idx = rng.choice(len(te), size=samples, replace=len(te) < samples)
references = [te[int(i)] for i in ref_idx]

In [None]:
# Stack each reference's conditioning features into one batch on the GPU.
reference_images = torch.stack(
    [ref["render-features"] for ref in references]
).to(device="cuda")

In [None]:
# Batch size is taken from the conditioning batch so it always matches `samples`
# (the original hardcoded 16). 2048 mirrors the point count requested from the
# test set — presumably points per generated cloud; confirm in sample().
n_points = 2048
# Pure inference: no autograd graph is needed across the denoising steps.
with torch.no_grad():
    preds = ddpm_sched.sample(model, len(reference_images), n_points, reference=reference_images)

In [None]:
# Generated clouds, laid out in a grid.
pred_points = preds.cpu().numpy()
display_pointclouds_grid(pred_points, point_size=0.3, offset=8)

In [None]:
# Ground-truth clouds of the same references, in the same grid layout for comparison.
real = torch.stack([ref["pc"] for ref in references]).numpy()
display_pointclouds_grid(real, point_size=0.3, offset=8)