In [1]:
# Work from the repository's src/ directory so the project-local packages imported
# below (distillation, datasets, utils) resolve, as do the ../data and ../checkpoints paths.
# Re-running is harmless: ../src evaluated from src/ resolves to src/ again.
%cd ../src

/home/ubuntu/SPVD_Lightning/src


In [2]:
import torch

# Trade some float32 matmul accuracy for speed; per the PyTorch docs, "medium"
# permits reduced-precision internal computation on supporting hardware.
torch.set_float32_matmul_precision("medium")

In [3]:
# Number of diffusion steps the first teacher runs with; each distillation stage
# below halves it (rounding up).
diffusion_steps = 1000
starting_checkpoint = "../checkpoints/distillation/GSPVD/starting.ckpt"

# Architecture hyperparameters passed unchanged to every Teacher and Student.
model_args = dict(
    voxel_size=0.1,
    nfs=(32, 64, 128, 256),
    attn_chans=8,
    attn_start=3,
    cross_attn_chans=8,
    cross_attn_start=2,
    cross_attn_cond_dim=768,
)

from distillation import (
    DistillationProcess,
    Student,
    Teacher,
)

# The Lightning module that is later given a teacher and a student and passed to trainer.fit.
# NOTE(review): the training loop assigns distillation_agent.learning_rate = 2e-4 before
# fitting; confirm which of the two values the optimizer actually uses.
initial_lr = 1e-4
distillation_agent = DistillationProcess(lr=initial_lr)



In [4]:
from datasets.shapenet.shapenet_loader import get_dataloaders

# ShapeNet categories to distill on; renders are loaded because their features are
# later used as the sampler's `reference` conditioning.
categories = ["bowl"]
path = "../data/ShapeNet"
loader_kwargs = dict(
    categories=categories,
    load_renders=True,
    n_steps=diffusion_steps,
    batch_size=32,
)
# Unpacked as (presumably) train / test / validation loaders; confirm order against get_dataloaders.
tr, te, val = get_dataloaders(path, **loader_kwargs)

Loading (train) renders for bowl (02880940): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 358.30it/s]
Loading (test) renders for bowl (02880940): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:00<00:00, 427.15it/s]
Loading (val) renders for bowl (02880940): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:00<00:00, 426.43it/s]


In [5]:
import torch
import os
import lightning as L

In [None]:
# Progressive distillation: each stage trains a student that uses half as many
# sampling steps as its teacher (rounded up). In the multi-stage path, the trained
# student is saved and becomes the next stage's teacher.
N = diffusion_steps
scheduler = "ddim"
max_epochs = 100
# NOTE(review): overrides the lr=1e-4 given to DistillationProcess(...); confirm that
# `learning_rate` is the attribute the optimizer reads.
stage_learning_rate = 2e-4
# True reproduces the previous behaviour (train only the first stage, save nothing).
# Set to False to run every halving stage down to a 1-step student.
STOP_AFTER_FIRST_STAGE = True

folder_path = f"../checkpoints/distillation/GSPVD/{'-'.join(categories)}"
previous_checkpoint = starting_checkpoint
distillation_agent.set_teacher(Teacher(model_args, previous_checkpoint, N, scheduler=scheduler))

# (N + 1) // 2 maps 1 -> 1, so `while N > 0` would never terminate. Stop once the
# 1-step student has been trained.
while N > 1:
    N = (N + 1) // 2
    distillation_agent.set_student(Student(model_args, previous_checkpoint, N, scheduler=scheduler))
    # Every split must use the student's (reduced-step) scheduler.
    for loader in (tr, te, val):
        loader.dataset.set_scheduler(distillation_agent.student.diffusion_scheduler)

    # Optional sanity check before training: distillation_agent.check_initialization()

    trainer = L.Trainer(
        max_epochs=max_epochs,
        callbacks=[],
        gradient_clip_val=10.0,
    )
    distillation_agent.learning_rate = stage_learning_rate

    trainer.fit(distillation_agent, tr, val)
    print(f"Trained Student for {N} steps.")
    if STOP_AFTER_FIRST_STAGE:
        break

    os.makedirs(folder_path, exist_ok=True)
    new_checkpoint = os.path.join(folder_path, f"{N}-steps.ckpt")
    # The next Teacher loads this file, so it must be written before promotion.
    # NOTE(review): the starting checkpoint is a `.ckpt`; confirm that Teacher/Student
    # accept a bare student state_dict saved this way.
    torch.save(distillation_agent.student.state_dict(), new_checkpoint)
    distillation_agent.set_teacher(Teacher(model_args, new_checkpoint, N, scheduler=scheduler))
    previous_checkpoint = new_checkpoint

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type    | Params | Mode 
--------------------------------------------
0 | teacher | Teacher | 25.0 M | eval 
1 | student | Student | 25.0 M | train
--------------------------------------------
25.0 M    Trainable params
25.0 M    Non-trainable params
50.0 M    Total params
200.030   Total estimated model params size (MB)
295       Modules in train mode
295       Modules in eval mode


Sanity Checking: |                                                                                            …

Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

In [None]:
from utils.visualization import display_pointclouds_grid

# Sample with the student's scheduler. The student was built with scheduler="ddim"
# and its reduced step count N, so despite the `ddpm_` prefix this is presumably a DDIM scheduler.
# To sample from the teacher instead, use distillation_agent.teacher.diffusion_scheduler.
ddpm_sched = distillation_agent.student.diffusion_scheduler

In [None]:
import numpy as np

# Pick `samples` test shapes whose renders condition the sampler below.
samples = 16
# Seeded generator, so Restart & Run All draws the same references every time.
reference_rng = np.random.default_rng(0)
n_test = len(te.dataset)
# Draw without replacement so the grid shows distinct shapes. Fall back to sampling
# with replacement only if the test split is smaller than `samples`.
reference_indices = reference_rng.choice(n_test, size=samples, replace=n_test < samples)
references = [te.dataset[int(idx)] for idx in reference_indices]

In [None]:
# Stack the references' render features into a single conditioning batch on the GPU.
reference_images = torch.stack([ref["render-features"] for ref in references]).cuda()

# Move the student, then the teacher, to the GPU and switch both to eval mode.
for role in ("student", "teacher"):
    setattr(distillation_agent, role, getattr(distillation_agent, role).cuda().eval())

In [None]:
# Generate point clouds from the student, conditioned on the reference renders.
# Sampling needs no gradients. Disabling autograd avoids building a graph during
# sampling and keeps `preds` convertible with .numpy(). This is harmless if `sample`
# already disables gradients internally.
# To sample from the teacher, pass distillation_agent.teacher.model (with the teacher's scheduler).
num_points = 2048  # presumably points per generated cloud; confirm against sample()'s signature
with torch.no_grad():
    preds = ddpm_sched.sample(distillation_agent.student.model, samples, num_points, reference=reference_images)

In [None]:
# Show the generated clouds. detach() makes the .numpy() conversion safe even if
# `preds` still tracks gradients.
display_pointclouds_grid(preds.detach().cpu().numpy(), offset=8, point_size=0.3)

In [12]:
# Ground-truth point clouds of the same references, drawn with the same grid settings.
reference_pcs = [ref["pc"] for ref in references]
real = torch.stack(reference_pcs).numpy()
display_pointclouds_grid(real, point_size=0.3, offset=8)

Output()