In [1]:
# Switch the working directory to the sibling `src/` folder; the relative
# "../data" and "../checkpoints" paths used later are resolved from here.
# Presumably this is also what makes `datasets`, `my_models`, etc. importable — confirm.
%cd ../src

/home/ubuntu/SPVD_Lightning/src


In [2]:
import torch
# Allow float32 matrix multiplications to use a faster, lower-precision
# internal format (see the PyTorch docs for the exact meaning of 'medium').
torch.set_float32_matmul_precision('medium')

In [3]:
from datasets.shapenet.shapenet_loader import get_dataloaders

# Categories to load; later cells reuse `categories` and `path`.
categories = ['bowl']
# Dataset root, relative to the `src/` working directory set above.
path = "../data/ShapeNet"
# Unpacked as (train, test, val) loaders — presumably in that order; confirm
# against get_dataloaders. load_renders=True also loads the rendered views
# (see the "Loading (...) renders" progress bars in the output).
tr, te, val = get_dataloaders(path, categories=categories, load_renders=True)

Loading (train) renders for bowl (02880940): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 366.36it/s]
Loading (test) renders for bowl (02880940): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:00<00:00, 438.44it/s]
Loading (val) renders for bowl (02880940): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:00<00:00, 427.86it/s]


In [4]:
from my_models.spvd import SPVUnet

# Hyper-parameters forwarded verbatim to SPVUnet.
# NOTE(review): the meanings below are inferred from the key names only —
# confirm against SPVUnet's signature.
model_params = {
    # presumably channel widths per stage (5 entries)
    "features": (32, 32, 64, 128, 256),
    # presumably self-attention heads per level; None looks like "no attention" (4 entries)
    "attn_heads_list": (None, None, None, 8),
    # presumably cross-attention heads per level (4 entries)
    "cross_attn_heads_list": (None, None, 8, 8),
    # presumably the feature size of the conditioning input (render features)
    "cross_attn_cond_dim": 768,
}

model = SPVUnet(**model_params)

In [5]:
from my_models.lightning_base import DiffusionBase

# Wrap the bare network in the Lightning module used for training.
# The isinstance guard keeps this cell idempotent: re-running it must not
# wrap an already-wrapped module inside another DiffusionBase.
if not isinstance(model, DiffusionBase):
    model = DiffusionBase(model)

In [6]:
def print_model_parameters(model, prefix=""):
    """Print a nested per-module parameter count for ``model``.

    For every direct child of ``model`` that owns at least one parameter,
    prints its attribute name, class name and total parameter count, then
    recurses into that child with two extra spaces of indentation.

    Args:
        model: Any object exposing ``named_children()`` whose children expose
            ``parameters()`` (e.g. a ``torch.nn.Module``).
        prefix: String prepended to every printed line; grows with depth.
    """
    for name, child in model.named_children():
        param_count = sum(p.numel() for p in child.parameters())
        if param_count == 0:
            continue  # Skip modules without parameters (e.g., ReLU)
        print(f"{prefix}{name}: {child.__class__.__name__}")
        print(f"{prefix}Parameters: {param_count:,}")
        print_model_parameters(child, prefix + "  ")  # Recurse for nested modules

# Disabled by default because the dump is long; uncomment to inspect the model.
# print_model_parameters(model)

In [7]:
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint

# Periodic checkpoint callback: one file every 50 epochs, all of them kept.
# The filename prefix is built from the categories actually being trained
# (it used to be a stale hard-coded "chair-airplane").
checkpoint = ModelCheckpoint(
    dirpath="../checkpoints/ShapeNet",
    # Lightning expands "{epoch:03d}" to "epoch=NNN", e.g. bowl-epoch=049.ckpt
    filename="-".join(categories) + "-{epoch:03d}",
    every_n_epochs=50,            # Save interval
    save_top_k=-1,                # Keep all checkpoints (-1 = no limit)
)

epochs = 100
trainer = L.Trainer(
    log_every_n_steps=5,
    max_epochs=epochs,
    # NOTE: the callback above is currently NOT registered, so no periodic
    # checkpoints are written to dirpath. Uncomment the next line to enable it.
    # callbacks=[checkpoint],
    gradient_clip_val=10.0,       # Clip gradients to limit update size
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [8]:
# Train on the train loader and validate on the val loader.
# NOTE(review): in the recorded run below this call aborted during the
# "Sanity Checking" phase with
#   RuntimeError: shape '[17, -1, 128]' is invalid for input of size 319488
# 319488 is not divisible by 17 * 128, so a reshape inside the model is
# presumably assuming an equal number of points per sample. The 17 matches
# the number of val items loaded above, i.e. presumably one batch containing
# the whole val split. The cause is in the model/loader code, not this cell — confirm.
trainer.fit(model=model, train_dataloaders=tr, val_dataloaders=val)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type    | Params | Mode 
------------------------------------------
0 | model | SPVUnet | 25.2 M | train
------------------------------------------
25.2 M    Trainable params
0         Non-trainable params
25.2 M    Total params
100.863   Total estimated model params size (MB)
299       Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …

RuntimeError: shape '[17, -1, 128]' is invalid for input of size 319488

In [None]:
import os

# Manually persist the trained weights plus the hyper-parameters needed to
# rebuild the network.
# The previous version stored `down_blocks` / `up_blocks`, which are not
# defined anywhere in this notebook (NameError). `model_params` is the dict
# the network was built from, so it is saved instead.
# NOTE(review): any loader that expects the old "down_blocks"/"up_blocks"
# keys must be updated accordingly.
ckpt_dir = "../checkpoints/ShapeNet/"
os.makedirs(ckpt_dir, exist_ok=True)

# Named differently from the `checkpoint` ModelCheckpoint callback so the
# callback object is not shadowed by a plain dict.
manual_checkpoint = {
    "state_dict": model.state_dict(),
    "model_params": model_params,
}

# File name reflects what was trained and for how long, e.g. "bowl-100.ckpt".
ckpt_name = "-".join(categories) + f"-{epochs}.ckpt"
torch.save(manual_checkpoint, os.path.join(ckpt_dir, ckpt_name))

In [None]:
import numpy as np

# Number of reference items to condition on.
samples = 16

# Draw `samples` random training indices (with replacement, as before).
# Passing an int to np.random.choice samples from range(n) directly, so the
# intermediate list is unnecessary, and the size now follows `samples`
# instead of a second hard-coded 16.
reference_ids = np.random.choice(len(tr.dataset), size=(samples,))
references = [tr.dataset[idx] for idx in reference_ids]

# Stack each item's "render-features" entry into one batch on the GPU.
reference_images = torch.stack([r["render-features"] for r in references]).to("cuda")

In [None]:
from my_schedulers.ddpm_scheduler import DDPMSparseScheduler
# Scheduler used for sampling below.
# NOTE(review): beta_min/beta_max/steps presumably define the noise schedule
# range and the number of diffusion steps — confirm against DDPMSparseScheduler.
ddpm_sched = DDPMSparseScheduler(beta_min=0.0001, beta_max=0.02, steps=1024)

In [None]:
# Generate one sample per reference image. The count is taken from the
# conditioning batch itself so the two can never get out of sync (it was a
# second hard-coded 16).
samples = reference_images.shape[0]

# Move the model to the GPU and switch it to eval mode before sampling.
model = model.cuda().eval()

# NOTE(review): 4096 is presumably the number of points per generated cloud — confirm.
preds = ddpm_sched.sample(model, samples, 4096, reference=reference_images, stochastic=True)

In [None]:
from utils.visualization import display_pointclouds_grid

In [None]:
# Show the generated samples as a grid; `preds` is moved to the CPU and
# converted to a NumPy array first.
display_pointclouds_grid(preds.cpu().numpy(), offset=8, point_size=0.3)

In [None]:
# Collect the "pc" entry of each reference (presumably its ground-truth point
# cloud — confirm) and show it with the same grid settings as the predictions,
# for side-by-side comparison.
real = np.array([r["pc"] for r in references])
display_pointclouds_grid(real, offset=8, point_size=0.3)

In [None]:
import math

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Show the render that each reference was conditioned on.
# The grid is sized from the number of references (4 columns), so it no longer
# raises IndexError when fewer than 16 references exist. With 16 references
# this is the same 4x4, 10x10-inch figure as before.
n_cols = 4
n_rows = max(1, math.ceil(len(references) / n_cols))
fig, axs = plt.subplots(
    n_rows, n_cols, figsize=(2.5 * n_cols, 2.5 * n_rows), squeeze=False
)

# Hide axes on every panel, including any left empty.
for ax in axs.flat:
    ax.axis('off')

# axs.flat iterates row-major, matching the original idx = i * 4 + j order.
for ax, ref in zip(axs.flat, references):
    view = ref["selected-view"]
    file = ref["filename"]

    # Reuse the dataset root from the loading cell instead of repeating it.
    # NOTE(review): the "00" prefix assumes single-digit view ids — confirm
    # the on-disk naming for views >= 10.
    img = mpimg.imread(f"{path}/renders/{file}/00{view}.png")
    ax.imshow(img)