In [1]:
# Move to the repository root (this notebook lives in SPVD/notebooks).
# Only step up while we are still inside `notebooks/`, so re-running this cell
# does not keep climbing the directory tree the way a bare `%cd ..` does.
import os
from pathlib import Path

if Path.cwd().name == 'notebooks':
    os.chdir(Path.cwd().parent)
print(Path.cwd())

/home/ubuntu/SPVD_Lightning


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


# SPVD with PyTorch Lightning

This notebook includes all the essential code modifications needed to implement SPVD with PyTorch Lightning. These modifications have been exported into a Python script located at `models/lightningBase.py`.

Additionally, we provide the complete code for training the model, running inference with it, and testing it.

If you only need the training, inference, and testing pipeline, you can find it in the `TrainGeneration` notebook.

To export a Python script from this notebook, run:
`python utils/notebook2py.py notebooks/PytorchLightningIntegration.ipynb models/lightningBase.py`

# PyTorch Lightning Integration

## Imports

In [2]:
#export
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as L
from abc import ABC, abstractmethod

## Task
Using tasks allows easy integration of similar tasks, such as completion and super-resolution, as proposed in the SPVD publication.

In [3]:
#export
class Task(ABC):
    """Interface for the training objective used by `DiffusionBase`.

    A concrete task decides how a raw batch is split into
    ``(network_input, target)`` and how predictions are scored against the
    target. Both hooks are abstract, so only subclasses implementing both
    can be instantiated.
    """

    @abstractmethod
    def prep_data(self, batch):
        """Turn ``batch`` into a ``(network_input, target)`` pair."""

    @abstractmethod
    def loss_fn(self, pred, target):
        """Return the loss comparing ``pred`` with ``target``."""

In [4]:
#export
class SparseGeneration(Task):
    """Sparse generation objective: the network regresses the injected noise.

    ``batch`` is expected to be a mapping with the keys ``'input'`` (noisy
    data), ``'t'`` (timesteps) and ``'noise'``. The target is ``noise.F``
    (presumably the sparse tensor's feature matrix — confirm against the
    dataset code).
    """

    def prep_data(self, batch):
        """Return ``((noisy_data, t), noise.F)``."""
        noisy_data, timesteps, noise = (batch[key] for key in ('input', 't', 'noise'))
        return (noisy_data, timesteps), noise.F

    def loss_fn(self, preds, target):
        """Mean squared error between predicted and true noise."""
        return F.mse_loss(preds, target, reduction='mean')

## DiffusionBase

In [5]:
#export 
class DiffusionBase(L.LightningModule):
    """Lightning wrapper that trains ``model`` on the objective defined by a `Task`.

    The task turns each batch into ``(network_input, target)`` (``prep_data``)
    and scores the predictions (``loss_fn``). This class wires those into
    Lightning's training/validation loops and sets up the optimizer and its
    per-step OneCycle learning-rate schedule.

    Args:
        model: network invoked as ``model(inp)`` on the input returned by
            ``task.prep_data``.
        task: `Task` defining data preparation and loss. When omitted, a new
            ``SparseGeneration()`` is created for this instance.
        lr: learning rate for AdamW, also used as ``max_lr`` of the OneCycle schedule.
    """

    def __init__(self, model, task=None, lr=0.0002):
        super().__init__()
        self.model = model
        # Build the default task per instance instead of sharing one object that
        # was created once when the signature was evaluated.
        self.task = task if task is not None else SparseGeneration()
        self.learning_rate = lr
        # Batch sizes passed to `self.log`; refreshed by the on_*_start hooks.
        # None lets Lightning fall back to inferring the batch size itself.
        self.tr_batch_size = None
        self.vl_batch_size = None

    def forward(self, x):
        """Run the wrapped network on a prepared input."""
        return self.model(x)

    def _compute_loss(self, batch):
        """Prepare ``batch`` with the task, run the network, and return the task loss."""
        inp, target = self.task.prep_data(batch)
        preds = self(inp)
        return self.task.loss_fn(preds, target)

    def training_step(self, batch, batch_idx):
        """Compute, log, and return the training loss (noise-prediction objective)."""
        loss = self._compute_loss(batch)
        self.log('train_loss', loss, batch_size=self.tr_batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss."""
        loss = self._compute_loss(batch)
        self.log('val_loss', loss, batch_size=self.vl_batch_size)

    def configure_optimizers(self):
        """AdamW (weight decay 0.05) with a OneCycle LR schedule stepped every optimizer step.

        Raises:
            ValueError: if the Trainer has no finite length (neither ``max_epochs``
                nor ``max_steps`` bounds training), which OneCycleLR requires.
        """
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=0.05)

        # OneCycleLR derives its warm-up/annealing phase boundaries from `total_steps`
        # inside its constructor. The length must therefore be correct here: assigning
        # `scheduler.total_steps` afterwards does not move those boundaries.
        # Lightning's estimate accounts for max_epochs/max_steps, gradient accumulation
        # and multi-device sharding.
        total_steps = self.trainer.estimated_stepping_batches
        if total_steps == float('inf'):
            raise ValueError(
                'OneCycleLR needs a finite schedule length; set `max_epochs` or `max_steps` on the Trainer.'
            )
        self.lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=self.learning_rate, total_steps=max(1, int(total_steps))
        )
        return [optimizer], [{'scheduler': self.lr_scheduler, 'interval': 'step'}]

    def on_train_start(self):
        """Record the training loader's batch size for explicit `train_loss` logging."""
        self.tr_batch_size = self._loader_batch_size(self.trainer.train_dataloader)

    def on_validation_start(self):
        """Record the validation loader's batch size for explicit `val_loss` logging."""
        self.vl_batch_size = self._loader_batch_size(self.trainer.val_dataloaders)

    @staticmethod
    def _loader_batch_size(loader):
        """Return ``loader.batch_size``; for a list/tuple of loaders use the first one.

        Returns None when no loader is given or it has no ``batch_size`` attribute.
        Avoids ``bool(loader)``, which calls ``len()`` and fails for loaders without a length.
        """
        if isinstance(loader, (list, tuple)):
            loader = loader[0] if loader else None
        return getattr(loader, 'batch_size', None)

# Training

In [6]:
# Imports
from models import SPVD
from lightning.pytorch.callbacks import ModelCheckpoint
from datasets.shapenet_pointflow_sparse import get_dataloaders
import os

# Dataset root. Set the SHAPENET_PC_PATH environment variable to point elsewhere
# instead of editing the notebook; the original location remains the default.
path = os.environ.get('SHAPENET_PC_PATH', '/home/ubuntu/ShapeNetPC')

# Seed Python, NumPy and PyTorch (including dataloader workers) so that model
# initialization, training and sampling are reproducible across Restart & Run All.
L.seed_everything(42, workers=True)

# Optimization for speed: allow reduced internal precision for float32 matmuls.
torch.set_float32_matmul_precision('medium')

/opt/conda/envs/lightning/lib/python3.10/site-packages/timm/models/layers/__init__.py:48: Importing from timm.models.layers is deprecated, please import via timm.layers


In [7]:
# Wrap a freshly built SPVD network with the Lightning training logic
# (default task and learning rate).
model = DiffusionBase(model=SPVD())

In [8]:
# ShapeNet categories to train on; the evaluation cells further down reuse this list.
categories = ['car']
tr_dl, te_dl = get_dataloaders(
    path,
    categories,
)

(1, 1, 1)
Total number of data:2458
Min number of points: (train)2048 (test)2048
(1, 1, 1)
Total number of data:352
Min number of points: (train)2048 (test)2048


/opt/conda/envs/lightning/lib/python3.10/site-packages/torch/utils/data/dataloader.py:617: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 4, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.


In [None]:
# Checkpoints are written to ./checkpoints/ relative to the current working directory.
checkpoint_callback = ModelCheckpoint(dirpath='checkpoints/')

# 10 training epochs with a gradient-clipping threshold of 10.
trainer = L.Trainer(
    callbacks=[checkpoint_callback],
    gradient_clip_val=10.0,
    max_epochs=10,
)

In [None]:
# NOTE: `te_dl` (presumably the test split) doubles as the validation loader here.
trainer.fit(model, tr_dl, te_dl)

# Inference

In [16]:
# Imports
from utils.schedulers import DDPMSparseSchedulerGPU
from utils.visualization import quick_vis_batch
from functools import partial
# Batch previewer with fixed x/y offsets of 8
# (presumably the spacing between shapes in the preview layout — confirm in utils.visualization).
vis_batch = partial(
    quick_vis_batch,
    x_offset=8,
    y_offset=8,
)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [17]:
# Move the weights to the current CUDA device, then switch the module to evaluation mode.
# Both calls act in place on `model`.
model.cuda()
model.eval()

In [18]:
# Generate a batch with the DDPM scheduler and preview it.
ddpm_sched = DDPMSparseSchedulerGPU(n_steps=1000, beta_min=0.0001, beta_max=0.02, pres=1e-5)

# Sampling never needs gradients. Disabling autograd guarantees the 1000 denoising
# steps cannot build (and keep alive) a computation graph, in case `sample` does not
# already do this itself.
with torch.no_grad():
    # presumably 32 shapes of 2048 points each — confirm against DDPMSparseSchedulerGPU.sample
    preds = ddpm_sched.sample(model.cuda(), 32, 2048)
vis_batch(preds)

KeyboardInterrupt: 

# Test

In [None]:
from test_generation import evaluate_gen

In [None]:
# Evaluate generation quality for `categories`; `save_path` presumably receives the output files.
# NOTE(review): the following cells repeat this evaluation with identical arguments. If the intent
# is to gauge run-to-run variance of the stochastic sampler, prefer a loop with a distinct
# save_path per run — with a shared save_path, later runs may overwrite earlier output.
ddpm_sched = DDPMSparseSchedulerGPU(
    n_steps=1000,
    beta_min=0.0001,
    beta_max=0.02,
    sigma='coef_bt',
)
evaluate_gen(
    path,
    model,
    ddpm_sched,
    save_path='./results/',
    cates=categories,
)

In [None]:
# Repeat of the previous evaluation cell with identical arguments.
# NOTE(review): writes to the same save_path, so it may overwrite the previous run's output;
# confirm whether this duplicate is intentional (e.g. checking run-to-run variance).
ddpm_sched = DDPMSparseSchedulerGPU(
    n_steps=1000,
    beta_min=0.0001,
    beta_max=0.02,
    sigma='coef_bt',
)
evaluate_gen(
    path,
    model,
    ddpm_sched,
    save_path='./results/',
    cates=categories,
)

In [None]:
# Another repeat of the evaluation cell with identical arguments.
# NOTE(review): shares save_path with the earlier runs, so output may be overwritten;
# consider collapsing the repeated cells into one loop with per-run save paths.
ddpm_sched = DDPMSparseSchedulerGPU(
    n_steps=1000,
    beta_min=0.0001,
    beta_max=0.02,
    sigma='coef_bt',
)
evaluate_gen(
    path,
    model,
    ddpm_sched,
    save_path='./results/',
    cates=categories,
)