In [1]:
# Add the parent directory to the path so that we can import the necessary modules
import sys
sys.path.append('../')

# Hyperparameters

In [2]:
# Experiment configuration.
# NOTE(review): the training-loop cell hard-codes its own `lr` and `epochs` instead of reading
# 'lr' / 'train_epochs' from here, and the checkpoint keys are not referenced in the visible cells.
args = dict(
    categories=['chair'],  # passed to get_dataloaders
    train_epochs=50,
    lr=0.001,
    load_checkpoint=False,
    save_checkpoint='all_depth4',
    save_per_epochs=50,
)

# Dataset for Point-Voxels

In [3]:
import os

from datasets.shapenet_pointflow_sparse import get_dataloaders
from pclab.utils import DataLoaders

# Root of the ShapeNetCore.v2.PC15k point-cloud dataset. Set the SHAPENET_PC15K_PATH
# environment variable on other machines instead of editing this machine-specific default.
path = os.environ.get("SHAPENET_PC15K_PATH", "/home/tourloid/Desktop/PhD/Data/ShapeNetCore.v2.PC15k")
train_dl, valid_dl = get_dataloaders(path, args['categories'])
dls = DataLoaders(train_dl, valid_dl)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
(1, 1, 1)
Total number of data:4612
Min number of points: (train)2048 (test)2048
(1, 1, 1)
Total number of data:662
Min number of points: (train)2048 (test)2048


# Network

In [4]:
#export
from models.modelv1 import SPVUnet

# Training

## Callbacks

In [5]:
#export
from pclab.learner import *
from pclab.utils import def_device
import fastcore.all as fc
from typing import Mapping
from copy import copy
from torcheval.metrics import Mean
from utils.callbacks import GradientClipCB
from functools import partial
import torch 
import torch.nn as nn
import torchsparse

In [6]:
class DDPMCB(Callback):
    """Repack a raw batch dict into the ``(inputs, target)`` pair used by the learner.

    Reads ``'input'``, ``'t'`` and ``'noise'`` from the batch and replaces it with
    ``((pts, t), noise.F)``: the model receives the noisy points together with their
    time steps, and ``noise.F`` is used as the regression target (MSELoss in this notebook).
    """

    def before_batch(self, learn):
        pts = learn.batch['input']
        # as_tensor accepts lists/arrays and existing tensors alike; unlike torch.tensor it
        # does not copy (and emit a UserWarning for) a value that is already a tensor.
        t = torch.as_tensor(learn.batch['t'])
        noise = learn.batch['noise']  # presumably a torchsparse SparseTensor (.F = features) — confirm
        inp = (pts, t)
        learn.batch = (inp, noise.F)

In [7]:
def to_device(x, device=def_device):
    """Recursively move every dense or torchsparse tensor inside ``x`` to ``device``.

    Mappings are rebuilt as plain dicts and other iterables with their own type, with
    their elements moved recursively. Strings are returned unchanged. Any other
    non-iterable leaf is moved with ``.to(device)`` when it has that method and is
    otherwise returned as-is (e.g. Python numbers).
    """
    if isinstance(x, (torch.Tensor, torchsparse.SparseTensor)): return x.to(device)
    # recurse so that nested containers inside a dict are handled too
    if isinstance(x, Mapping): return {k: to_device(v, device) for k, v in x.items()}
    # iterating a string yields strings again -> would recurse forever
    if isinstance(x, (str, bytes)): return x
    if not hasattr(x, '__iter__'):
        return x.to(device) if hasattr(x, 'to') else x
    return type(x)(to_device(o, device) for o in x)

class DeviceCBSparse(Callback):
    """Move the model (at fit start) and each batch, including torchsparse tensors, to ``device``.

    Its ``order`` is one above ``DDPMCB``'s, presumably so that it moves the batch only
    after DDPMCB has repacked it.
    """
    order = DDPMCB.order + 1

    def __init__(self, device=def_device):
        fc.store_attr()

    def before_fit(self, learn):
        model = learn.model
        if not hasattr(model, 'to'):
            return
        model.to(self.device)

    def before_batch(self, learn):
        moved = to_device(learn.batch, device=self.device)
        learn.batch = moved

In [8]:
# Callback to monitor the loss

class LossCB(Callback):
    """Accumulate the loss (plus optional torcheval metrics) per epoch and print a summary.

    Positional metrics are keyed by their class name, keyword metrics by their keyword.
    NOTE(review): the metric update in ``after_batch`` is commented out, so extra metrics
    are reset and computed every epoch but never receive data; only the loss is tracked.
    """

    def __init__(self, *ms, **metrics):
        metrics.update({type(m).__name__: m for m in ms})
        self.metrics = metrics
        self.all_metrics = copy(metrics)
        self.loss = Mean()
        self.all_metrics['loss'] = self.loss

    def _log(self, d): print(d)

    def before_fit(self, learn):
        # expose this callback to the learner as `learn.metrics`
        learn.metrics = self

    def before_epoch(self, learn):
        for metric in self.all_metrics.values():
            metric.reset()

    def after_epoch(self, learn):
        log = {name: f'{metric.compute():.3f}' for name, metric in self.all_metrics.items()}
        log['epoch'] = learn.epoch
        # the model's training flag distinguishes the train pass from the eval pass
        log['train'] = 'train' if learn.model.training else 'eval'
        self._log(log)

    def after_batch(self, learn):
        inputs, target, *_ = learn.batch  # unused while the metric update below is disabled
        #for m in self.metrics.values(): m.update(to_cpu(learn.preds), target)
        # NOTE(review): with torcheval's weighted Mean a constant weight gives a plain
        # per-batch average, so a smaller last batch counts as much as a full one.
        self.loss.update(to_cpu(learn.loss), weight=2)

## Train

In [9]:
# Factory for fresh, untrained SPVUnet models, used by the parameter-count, LR-finder and training cells.
# pres=1e-5 is the same value given to DDPMSparseScheduler in the inference section.
get_model = partial(SPVUnet, pres=1e-5, voxel_size=0.1)

In [10]:
from utils import model_num_params

# Build a throw-away model only to report its number of parameters.
model = get_model()
model_num_params(model)

15860739

### LR Finder

In [11]:
# Learner for the learning-rate finder. The `lr_find` call is currently disabled;
# uncomment it to sweep learning rates before choosing `lr` for the training loop.
model = get_model()
ddpm_cb = DDPMCB()
learn = TrainLearner(
    model, dls, nn.MSELoss(),
    cbs=[ddpm_cb, DeviceCBSparse(), GradientClipCB()],
    opt_func=torch.optim.Adam,
)
#learn.lr_find(max_mult=3)

## Training loop

In [12]:
# NOTE(review): these override args['lr'] (0.001) and args['train_epochs'] (50) from the
# hyperparameter cell — confirm which values are intended for this experiment.
lr = 0.005
epochs = 15  # values tried before: 500, args['train_epochs'], 400

model = get_model()

# One-cycle LR schedule sized for `epochs` passes over the training loader
# (presumably stepped once per training batch by BatchSchedCB).
total_steps = len(dls.train) * epochs
sched = partial(
    torch.optim.lr_scheduler.OneCycleLR,
    total_steps=total_steps,
    max_lr=lr,
)

ddpm_cb = DDPMCB()
cbs = [
    ddpm_cb,                 # repacks raw batches into ((pts, t), noise.F)
    DeviceCBSparse(),        # moves model and (sparse) batches to the device
    ProgressCB(plot=False),
    LossCB(),                # prints the per-epoch mean loss
    BatchSchedCB(sched),
    GradientClipCB(),
]

learn = TrainLearner(model, dls, nn.MSELoss(), lr=lr, cbs=cbs, opt_func=torch.optim.Adam)
learn.fit(epochs)

loss,epoch,train
0.558,0,train
0.339,0,eval
0.265,1,train
0.267,1,eval
0.262,2,train
0.277,2,eval
0.264,3,train
0.275,3,eval
0.267,4,train
0.27,4,eval


KeyboardInterrupt: 

# Inference

In [None]:
from tqdm import tqdm
import math
import numpy as np
from utils.visualization import quick_vis_batch
from torchsparse.utils.quantize import sparse_quantize
from torchsparse import SparseTensor
from torchsparse.utils.collate import sparse_collate_fn
# quick_vis_batch with both offsets fixed to 8 (presumably the spacing between the rendered clouds — confirm)
vis_batch = partial(quick_vis_batch, y_offset=8, x_offset=8)

## Noise Batch Generation (DDPM reverse-process samplers)

In [None]:
class DDPMSchedulerBase:
    """Base class for DDPM-style reverse-diffusion samplers.

    Builds the beta / alpha schedule and implements the generic sampling loop.
    Subclasses implement ``update_rule`` (and usually ``predict_x0_from_noise``) and may
    override ``create_noise`` / ``append_prediction`` for non-dense representations.
    """

    def __init__(self, beta_min=0.0001, beta_max=0.02, n_steps=1000, mode='linear'):
        """
            Args:
                - beta_min : first value of the beta schedule
                - beta_max : last value of the beta schedule
                - n_steps  : number of diffusion steps
                - mode     : schedule type; only 'linear' is implemented
            Raises:
                - NotImplementedError for any other mode
        """
        self.beta_min, self.beta_max, self.n_steps = beta_min, beta_max, n_steps

        if mode == 'linear':
            self.beta, self.alpha, self.alpha_hat, self.alpha_hat_prev1 = self._linear_scheduling()
        else:
            raise NotImplementedError

    def _linear_scheduling(self):
        """Return (beta, alpha, alpha_hat, alpha_hat_prev1), each a 1-D tensor of length n_steps."""
        beta = torch.linspace(self.beta_min, self.beta_max, self.n_steps)
        alpha = 1. - beta
        alpha_hat = torch.cumprod(alpha, dim=0)
        # alpha_hat shifted by one step with a leading 1, i.e. alpha_hat_prev1[t] == alpha_hat[t-1]
        # for t >= 1 and == 1 at t == 0 (necessary for some scheduling strategies)
        alpha_hat_prev1 = torch.ones_like(alpha_hat)
        alpha_hat_prev1[1:] = alpha_hat[:-1]

        return beta, alpha, alpha_hat, alpha_hat_prev1

    def append_prediction(self, preds, x_t):
        """Append a detached NumPy copy of ``x_t`` to the list ``preds`` (in place)."""
        preds.append(x_t.detach().cpu().numpy())

    def create_noise(self, shape, device):
        """Draw the initial standard-normal sample of ``shape`` and move it to ``device``."""
        return torch.randn(shape).to(device)

    @torch.no_grad()  # pure inference: no autograd graph is needed across the reverse steps
    def sample(self, model, bs, n_points=2048, nf=3, emb=None, save_process=False):
        """
            Run the full reverse process from pure noise down to t = 0.

            Args:
                - model        : neural net for noise prediction
                - bs           : number of point clouds to generate
                - n_points     : number of points per cloud
                - nf           : number of features per point (sample shape is bs x n_points x nf)
                - emb          : conditional embedding, if None it will be ignored
                - save_process : save the intermediate point clouds of the generation process
            Returns:
                - save_process=True : list of NumPy arrays, the initial noise followed by one per step
                - otherwise         : the final sample as a NumPy array
        """
        device = next(model.parameters()).device
        shape = (bs, n_points, nf)
        x_t = self.create_noise(shape, device)

        preds = []
        if save_process: self.append_prediction(preds, x_t)

        for t in reversed(range(self.n_steps)):
            x_t = self.sample_step(model, x_t, t, emb, shape, device)
            if save_process: self.append_prediction(preds, x_t)

        if save_process: return preds

        # append_prediction fills a list and returns None, so convert the final sample
        # through it (x_t may not be a torch.Tensor) and read it back from the list.
        final = []
        self.append_prediction(final, x_t)
        return final[0]

    def sample_step(self, model, x_t, t, emb, shape, device):
        """
            Args:
                - model  : neural net for noise prediction
                - x_t    : previous point cloud
                - t      : current time step
                - emb    : conditional embedding, if None it will be ignored
                - shape  : shape of the point cloud
                - device : device to run the computations
            Returns:
                - the point cloud for step t-1, as produced by ``update_rule``
        """
        bs = shape[0]

        # creating the time embedding variable
        t_batch = torch.full((bs,), t, device=device, dtype=torch.long)

        # activate the model to predict the noise
        noise_pred = model((x_t, t_batch)) if emb is None else model((x_t, t_batch, emb))

        # calculate the new point coordinates
        x_t = self.update_rule(x_t, noise_pred, t, shape, device)

        return x_t

    def predict_x0_from_noise(self, x_t, noise_pred, t, shape, device):
        # x_t.shape : B x N x F
        # noise_pred.shape : B x N x F
        raise NotImplementedError

    def update_rule(self, x_t, noise_pred, t, shape, device):
        # x_t.shape : B x N x F
        # noise_pred.shape : B x N x F
        raise NotImplementedError

    def noisify_sample(self, x0, step):
        raise NotImplementedError

In [None]:
class DDPMSparseScheduler(DDPMSchedulerBase):
    """DDPM sampler whose intermediate samples are batched torchsparse ``SparseTensor``s.

    After every reverse step the dense points are re-quantised with ``sparse_quantize``
    (step ``pres``) and collated, so the sparse point-voxel model can consume them.
    """

    def __init__(self, beta_min=0.0001, beta_max=0.02, n_steps=1000, pres=1e-8, mode='linear'):
        super().__init__(beta_min, beta_max, n_steps, mode)
        self.pres = pres  # quantisation step handed to sparse_quantize

    def sparse_from_pts(self, pts:torch.Tensor, shape):
        """
            Quantise a batch of points and collate it into one SparseTensor.

            Args:
                - pts   : CPU tensor holding bs * n_points * nf values
                - shape : (bs, n_points, nf); the first 3 features are used as coordinates
            Returns:
                - the collated SparseTensor; its features are the full per-point vectors
                  of the points selected by sparse_quantize
        """
        pts = pts.reshape(shape)
        # make coordinates positive: shift every cloud by its own per-axis minimum
        coords = pts[:, :, :3]
        coords = coords - coords.min(dim=1, keepdim=True)[0]
        coords = coords.numpy()

        # sparse_quantize works on one NumPy array at a time, so loop over the batch on the CPU.
        # NOTE(review): if two points of a cloud fall into the same voxel, sparse_quantize presumably
        # keeps only one of them; the point count then drops and the next reshape(shape) fails.
        # A small `pres` is assumed to make this practically impossible — confirm.
        batch = []
        for b in range(shape[0]):
            c, indices = sparse_quantize(coords[b], self.pres, return_index=True)
            f = pts[b][indices]
            batch.append({'pc': SparseTensor(coords=torch.tensor(c), feats=f)})

        return sparse_collate_fn(batch)['pc']

    def append_prediction(self, preds, x_t):
        """Append the features of the SparseTensor ``x_t`` to ``preds`` as a NumPy array."""
        preds.append(x_t.F.detach().cpu().numpy())

    def create_noise(self, shape, device):
        """Sample standard-normal points on the CPU and return them as a SparseTensor on ``device``."""
        noise = torch.randn(shape)
        noise = self.sparse_from_pts(noise, shape).to(device)
        return noise

    def update_rule(self, x_t, noise_pred, t, shape, device):
        """
            One reverse step x_t -> x_{t-1} using the DDPM posterior q(x_{t-1} | x_t, x_0):
                x0      = predict_x0_from_noise(x_t, noise_pred)
                mean    = coef_x0 * x0 + coef_xt * x_t
                x_{t-1} = mean + std * z,   z ~ N(0, I)
            The result is re-quantised into a SparseTensor on ``device``.
        """
        x_t = x_t.F

        # predict x0 from noise
        x0 = self.predict_x0_from_noise(x_t, noise_pred, t, shape, device)

        # alpha_hat_prev1[t] equals alpha_hat[t-1] for t >= 1 and 1 at t == 0; indexing
        # alpha_hat[t-1] directly would wrap around to the LAST step when t == 0.
        alpha_hat_t = self.alpha_hat[t]
        alpha_hat_prev = self.alpha_hat_prev1[t]

        coef_x0 = self.beta[t] * alpha_hat_prev.sqrt() / (1 - alpha_hat_t)
        coef_xt = (1 - alpha_hat_prev) * self.alpha[t].sqrt() / (1 - alpha_hat_t)

        # the posterior mean combines the x0 estimate with the current sample x_t
        mean = coef_x0 * x0 + coef_xt * x_t

        # posterior standard deviation; the variance is exactly 0 at t == 0, so floor it at 1e-8
        posterior_var = self.beta[t] * (1 - alpha_hat_prev) / (1 - alpha_hat_t)
        std = posterior_var.clamp(min=1e-8).sqrt()

        x_t = mean + std * torch.randn_like(mean)

        return self.sparse_from_pts(x_t.detach().cpu(), shape).to(device)

    def predict_x0_from_noise(self, x_t, noise_pred, t, shape, device):
        """Solve x_t = sqrt(alpha_hat_t) * x0 + sqrt(1 - alpha_hat_t) * noise for x0."""
        x0 = (1 / self.alpha_hat[t]).sqrt() * x_t - (1 / self.alpha_hat[t] - 1).sqrt() * noise_pred
        return x0

In [None]:
# Sparse DDPM sampler; pres=1e-5 is the same value passed to SPVUnet via get_model.
# NOTE(review): beta_max=0.01 differs from the class default (0.02); it must match the noise
# schedule used to corrupt the training data inside get_dataloaders — confirm.
# This rebinds `sched`, which held the OneCycleLR factory in the training cell.
sched = DDPMSparseScheduler(n_steps=1000, beta_min=0.0001, beta_max=0.01, pres=1e-5)

In [None]:
# Generate 32 point clouds. Switch the network to inference mode first (affects layers such
# as BatchNorm/Dropout, if the model has any; it may still be in training mode after the
# interrupted fit) and disable autograd, since the reverse process needs no gradients.
model.eval()
with torch.no_grad():
    batch = sched.sample(model, 32)

In [None]:
# Render the generated clouds (same call as vis_batch: quick_vis_batch with both offsets set to 8).
quick_vis_batch(batch, x_offset=8, y_offset=8)