# Imports and environment setup

In [None]:
# Change the working directory to the parent of the experiments folder so the
# project packages imported below (datasets, models, utils, pclab) presumably
# resolve from there — confirm against the repository layout.
# NOTE(review): `%cd ..` is relative, so re-running this cell moves up one more
# directory each time; restart the kernel (or run it once) to avoid that.
%cd ..

In [None]:
#import 
from datasets.shapenet_pointflow_sparse_cond import get_dataloaders
from pclab.utils import DataLoaders
from utils.callbacks import *
from pclab.learner import TrainLearner, ProgressCB, BatchSchedCB

In [None]:
from functools import partial
import torchsparse
from torchsparse.utils.collate import sparse_collate_fn
from pclab.utils import DataLoaders
import torch
import torch.nn as nn

## Datasets and Dataloaders

In [None]:
import os

# Root of the ShapeNetCore.v2.PC15k point-cloud dataset. Set the
# SHAPENET_PC15K_PATH environment variable to point at a local copy instead of
# editing this cell per machine (another known location is
# /home/vvrbeast/Desktop/Giannis/Data/ShapeNetCore.v2.PC15k). The previous
# version assigned that path and then immediately overwrote it with the default
# below, so the first assignment never took effect.
path = os.environ.get(
    'SHAPENET_PC15K_PATH',
    '/home/tourloid/Desktop/PhD/Data/ShapeNetCore.v2.PC15k/',
)

# ShapeNet categories to train/evaluate on.
categories = ['chair']

# get_dataloaders presumably returns (train, test) loaders for these categories;
# DataLoaders bundles them and exposes the training loader as `dls.train`.
tr_dl, te_dl = get_dataloaders(path, categories)
dls = DataLoaders(tr_dl, te_dl)

# Load the model

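Run exactly one of the following cells to select the model version. The SPVD-S cell is active by default; to use SPVD or SPVD-L instead, uncomment that cell and skip the SPVD-S cell.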

## SPVD-S

In [None]:
from models.ddpm_unet_attn import SPVUnet

# SPVD-S configuration of SPVUnet (from models.ddpm_unet_attn).
model = SPVUnet(
    voxel_size=0.1,
    nfs=(32, 64, 128, 256),  # presumably per-level feature widths — TODO confirm
    num_layers=1,
    attn_chans=8,
    attn_start=3,
)

# Name under which the training cell's CheckpointCB saves this model.
# NOTE(review): the name encodes widths 64/128/256/256, which does not match
# nfs=(32, 64, 128, 256) above — confirm which one is intended before loading or
# overwriting checkpoints under this name.
checkpoint_name = 'ddpm_unet_attn_64_128_256_256'

## SPVD

In [None]:
# SPVD: uncomment these lines (and skip the SPVD-S cell) to use this model.
# NOTE(review): the SPVD-S cell assigns an instance (`SPVUnet(...)`), but this
# binds `SPVD` itself. If `SPVD` is a class or factory rather than a ready-built
# module, this should be `model = SPVD()` — confirm against the `models` package.
# from models import SPVD
# model=SPVD
# checkpoint_name = 'SPVD'

## SPVD-L

In [None]:
# SPVD-L: uncomment these lines (and skip the SPVD-S cell) to use this model.
# NOTE(review): the SPVD-S cell assigns an instance (`SPVUnet(...)`), but this
# binds `SPVD_L` itself. If `SPVD_L` is a class or factory rather than a
# ready-built module, this should be `model = SPVD_L()` — confirm against the
# `models` package.
# from models import SPVD_L
# model=SPVD_L
# checkpoint_name = 'SPVD_L'

# Training

In [None]:
# Optimisation hyper-parameters.
epochs = 2000
lr = 0.001

# One-cycle LR schedule. BatchSchedCB presumably steps it once per training
# batch, which is why its horizon is epochs * batches-per-epoch.
total_steps = len(dls.train) * epochs
sched = partial(
    torch.optim.lr_scheduler.OneCycleLR,
    max_lr=lr,
    total_steps=total_steps,
)

# Callbacks, kept in the original list order.
ddpm_cb = DDPMCB()
# NOTE(review): the literal 2000 equals `epochs`; confirm whether CheckpointCB's
# first argument is meant to follow the epoch count if `epochs` is changed.
checkpoint_cb = CheckpointCB(2000, checkpoint_name, run_params={})
cbs = [
    ddpm_cb,
    DeviceCBSparse(),
    ProgressCB(plot=False),
    LossCB(),
    GradientClipCB(),
    checkpoint_cb,
    BatchSchedCB(sched),
]

# Train `model` with MSE loss and Adam for `epochs` epochs.
learn = TrainLearner(
    model,
    dls,
    nn.MSELoss(),
    lr=lr,
    cbs=cbs,
    opt_func=torch.optim.Adam,
)
learn.fit(epochs)

# Inference

In [None]:
from utils.schedulers import DDPMSparseSchedulerGPU
from utils.visualization import quick_vis_batch
# Batch visualiser with the x/y offsets fixed at 8 (presumably the spacing
# between consecutive point clouds in the rendered grid — confirm).
vis_batch = partial(quick_vis_batch, y_offset=8, x_offset=8)

In [None]:
# Sample with a 1000-step DDPM sampler (betas from 1e-4 to 0.02) and show the batch.
# NOTE(review): this sampler passes pres=1e-5, while the Test cells pass
# sigma='coef_bt' instead — confirm the difference is intentional.
ddpm_sched = DDPMSparseSchedulerGPU(
    n_steps=1000,
    beta_min=0.0001,
    beta_max=0.02,
    pres=1e-5,
)
preds = ddpm_sched.sample(model, 32, 2048)  # presumably 32 clouds x 2048 points — confirm
vis_batch(preds)

# Test

In [None]:
# Evaluate `model` on `categories` using a freshly built 1000-step sampler.
# NOTE(review): `evaluate_gen` is not imported in any visible cell — confirm it
# is in scope (e.g. via a star import) before running.
# NOTE(review): the Test section runs this same cell three times, presumably to
# report results over independent sampling runs. All runs share
# save_path='./results/', so whether later runs overwrite earlier outputs
# depends on evaluate_gen — confirm.
ddpm_sched = DDPMSparseSchedulerGPU(
    n_steps=1000,
    beta_min=0.0001,
    beta_max=0.02,
    sigma='coef_bt',
)
evaluate_gen(
    path,
    model,
    ddpm_sched,
    save_path='./results/',
    cates=categories,
)

In [None]:
# Evaluate `model` on `categories` using a freshly built 1000-step sampler.
# NOTE(review): `evaluate_gen` is not imported in any visible cell — confirm it
# is in scope (e.g. via a star import) before running.
# NOTE(review): the Test section runs this same cell three times, presumably to
# report results over independent sampling runs. All runs share
# save_path='./results/', so whether later runs overwrite earlier outputs
# depends on evaluate_gen — confirm.
ddpm_sched = DDPMSparseSchedulerGPU(
    n_steps=1000,
    beta_min=0.0001,
    beta_max=0.02,
    sigma='coef_bt',
)
evaluate_gen(
    path,
    model,
    ddpm_sched,
    save_path='./results/',
    cates=categories,
)

In [None]:
# Evaluate `model` on `categories` using a freshly built 1000-step sampler.
# NOTE(review): `evaluate_gen` is not imported in any visible cell — confirm it
# is in scope (e.g. via a star import) before running.
# NOTE(review): the Test section runs this same cell three times, presumably to
# report results over independent sampling runs. All runs share
# save_path='./results/', so whether later runs overwrite earlier outputs
# depends on evaluate_gen — confirm.
ddpm_sched = DDPMSparseSchedulerGPU(
    n_steps=1000,
    beta_min=0.0001,
    beta_max=0.02,
    sigma='coef_bt',
)
evaluate_gen(
    path,
    model,
    ddpm_sched,
    save_path='./results/',
    cates=categories,
)