In [1]:
# Work from the repository's src/ directory: the project modules imported in later cells
# (distillation, datasets, my_schedulers, utils) and the relative "../data" and
# "../checkpoints" paths presumably assume this working directory.
%cd ../src

/home/ubuntu/SPVD_Lightning/src


In [2]:
import torch

# Allow float32 matrix multiplications to trade some precision for speed.
# "medium" is the least precise of PyTorch's three levels ("highest", "high", "medium").
MATMUL_PRECISION = "medium"
torch.set_float32_matmul_precision(MATMUL_PRECISION)

In [3]:
from distillation import DistillationProcess, Teacher, Student

# Number of diffusion steps of the starting (teacher) model.
diffusion_steps = 1000

# Checkpoint the distillation chain starts from.
checkpoint_root = "../checkpoints/distillation/GSPVD"
starting_checkpoint = f"{checkpoint_root}/starting.ckpt"

# Architecture hyper-parameters passed to both the Teacher and the Student networks.
model_args = dict(
    voxel_size=0.1,
    nfs=(32, 64, 128, 256),
    attn_chans=8,
    attn_start=3,
    cross_attn_chans=8,
    cross_attn_start=2,
    cross_attn_cond_dim=768,
)

distillation_agent = DistillationProcess(lr=1e-4)



In [4]:
from datasets.shapenet.shapenet_loader import get_dataloaders

# ShapeNet "bowl" shapes, with their rendered views loaded alongside the point clouds.
path = "../data/ShapeNet"
categories = ['bowl']
loader_options = dict(
    categories=categories,
    load_renders=True,
    n_steps=diffusion_steps,
    batch_size=32,
)
tr, te, val = get_dataloaders(path, **loader_options)

Loading (train) renders for bowl (02880940): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 363.33it/s]
Loading (test) renders for bowl (02880940): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:00<00:00, 432.80it/s]
Loading (val) renders for bowl (02880940): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:00<00:00, 429.30it/s]


In [5]:
import torch
import lightning as L
from my_schedulers.ddim_scheduler import DDIMSparseScheduler

In [6]:
# Progressive distillation: each round trains a student that uses half as many
# sampling steps as its teacher (rounded up). The student is initialised from the
# teacher's checkpoint, trained, saved, and then becomes the teacher of the next round.
checkpoint_root = "../checkpoints/distillation/GSPVD"
max_epochs = 100

N = diffusion_steps
distillation_agent.set_teacher(Teacher(model_args, starting_checkpoint, N))

# Stop once the teacher is down to a single step. A `N > 0` condition would never
# terminate: for N == 1, (1 + 1) // 2 == 1, so N would stay at 1 forever.
while N > 1:
    # Checkpoint of the current teacher; the new student is initialised from it.
    previous_checkpoint = starting_checkpoint if N == diffusion_steps else f"{checkpoint_root}/{N}-steps.ckpt"

    N = (N + 1) // 2
    student_sched = DDIMSparseScheduler(steps=N, prev_alpha=distillation_agent.teacher.diffusion_scheduler.alpha)
    # Point every split's dataset at the student's schedule.
    for loader in (tr, te, val):
        loader.dataset.set_scheduler(student_sched)

    distillation_agent.set_student(Student(model_args, previous_checkpoint, N, scheduler=student_sched))

    distillation_agent.check_initialization()

    # A fresh Trainer per round, so each student gets its own max_epochs budget.
    trainer = L.Trainer(
        max_epochs=max_epochs,
        callbacks=[],
        gradient_clip_val=10.0,
    )

    trainer.fit(distillation_agent, tr, val)
    print(f"Trained Student for {N} steps.")

    # Persist the trained student and promote it to teacher for the next round.
    new_checkpoint = f"{checkpoint_root}/{N}-steps.ckpt"
    torch.save(distillation_agent.student.state_dict(), new_checkpoint)
    distillation_agent.set_teacher(Teacher(model_args, new_checkpoint, N, student_sched))
    

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type    | Params | Mode 
--------------------------------------------
0 | teacher | Teacher | 25.0 M | eval 
1 | student | Student | 25.0 M | train
--------------------------------------------
25.0 M    Trainable params
25.0 M    Non-trainable params
50.0 M    Total params
200.030   Total estimated model params size (MB)
295       Modules in train mode
295       Modules in eval mode


Initialization check passed for 500 steps.


Sanity Checking: |                                                                                            …

/opt/conda/envs/spvd/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (4) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [7]:
from my_schedulers.ddpm_scheduler import DDPMSparseScheduler
from utils.visualization import display_pointclouds_grid

# Scheduler used for sampling below. Despite the variable name, this is whatever
# scheduler the chosen network exposes. For the student, that is presumably the
# DDIMSparseScheduler passed to Student(..., scheduler=...) during distillation (confirm).
# To sample with the teacher instead, set `sampling_source = distillation_agent.teacher`.
sampling_source = distillation_agent.student
ddpm_sched = sampling_source.diffusion_scheduler

In [8]:
import numpy as np

# Pick `samples` distinct shapes whose renders will condition the generation.
# Set `reference_dataset = te.dataset` to condition on held-out test shapes instead.
samples = 16
reference_seed = 0  # fixed so the same references are drawn on every run
reference_dataset = tr.dataset

reference_rng = np.random.default_rng(reference_seed)
reference_indices = reference_rng.choice(
    len(reference_dataset),
    size=samples,
    # Without replacement so the grid shows distinct shapes; fall back to replacement
    # only if the dataset holds fewer items than requested.
    replace=len(reference_dataset) < samples,
)
references = [reference_dataset[int(idx)] for idx in reference_indices]

In [9]:
# Batch the conditioning render features of the chosen references on the GPU.
render_features = [ref["render-features"] for ref in references]
reference_images = torch.stack(render_features).to("cuda")

# Move both networks to the GPU and switch them to eval mode (student first, then teacher).
for role in ("student", "teacher"):
    setattr(distillation_agent, role, getattr(distillation_agent, role).cuda().eval())

In [None]:
# Generate one point cloud per reference image.
n_points = 2048  # presumably the number of points per generated cloud; confirm against `sample`
# To sample with the teacher instead, use `distillation_agent.teacher.model`.
sampling_model = distillation_agent.student.model

# Sampling is pure inference. Without no_grad (unless `sample` disables grad itself),
# autograd could keep the graph for every denoising step and return a tensor that
# requires grad.
with torch.no_grad():
    preds = ddpm_sched.sample(sampling_model, samples, n_points, reference=reference_images)

In [None]:
# Show the generated clouds in a grid. detach() makes the numpy conversion safe even if
# `preds` still requires grad (Tensor.numpy() raises in that case).
display_pointclouds_grid(preds.detach().cpu().numpy(), offset=8, point_size=0.3)