In [1]:
import os

# The project packages imported below (datasets, models, utils, pclab) and the
# relative ./data and ./checkpoints paths are resolved from the parent directory
# of this notebook. Only step up when we are not already there, so re-running this
# cell (or Run All after a partial run) does not keep climbing the directory tree.
if not os.path.isdir('datasets'):
    os.chdir('..')
print(os.getcwd())

/home/tourloid/Desktop/PhD/Code/SPVD


# Dataset

In [2]:
from datasets.partnet import get_sparse_completion_dataloaders
from pclab.utils import DataLoaders

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:
# Train/test dataloaders for the 'Chair' category of the preprocessed PartNet data
# (path is relative to the working directory set in the first cell), wrapped in a
# single DataLoaders object for the learner.
path = './data/PartNetProcessed/'
tr_dl, te_dl = get_sparse_completion_dataloaders(path, 'Chair')
dls = DataLoaders(tr_dl, te_dl)

# Model

In [4]:
from models.spvd import SPVUnet
from functools import partial

In [5]:
# Factory returning a fresh, untrained SPVUnet with the completion configuration.
# Each stage in `down_blocks` / `up_blocks` is a list of parallel tuples; what each
# tuple position means is defined by models.spvd.SPVUnet
# (NOTE(review): presumably channel widths, per-stage flags, attention heads and
# kernel sizes — confirm against SPVUnet before editing these values).
get_model = partial(
    SPVUnet,
    point_channels=4,
    voxel_size=0.1,
    num_layers=1,
    pres=1e-5,
    down_blocks=[
        [
            (32, 64, 128, 192, 192, 256),
            (True, True, True, True, False),
            (None, None, None, 8, 8),
        ],
    ],
    up_blocks=[
        # up block 1
        [
            (256, 192, 192),
            (True, True),
            (8, 8),
            (3, 3),
        ],
        # up block 2
        [
            (192, 128, 64, 32),
            (True, True, False),
            (None, None, None),
            (3, 3, 3),
        ],
    ],
)

# Training Loop

## DDPM callback and masked MSE loss (`CustomMSELoss`)

In [6]:
from pclab.learner import *
from utils.callbacks import *
from pclab.learner import Callback
from functools import partial
import torch
import torch.nn as nn

In [7]:
class DDPMCB(Callback):
    """Turn a raw dataloader batch into the ``(inputs, target)`` pair used in training.

    Reads the ``'input'``, ``'t'``, ``'noise'`` and ``'mask'`` entries of
    ``learn.batch`` and replaces it with ``((pts, t), (noise, flat_mask))``.
    The target half is the ``(noise, mask)`` pair that ``CustomMSELoss.forward``
    unpacks.
    """

    def before_batch(self, learn):
        pts = learn.batch['input']
        # as_tensor accepts a list/array/scalar as well as an existing tensor;
        # unlike torch.tensor it does not re-copy (and warn about) a tensor input.
        t = torch.as_tensor(learn.batch['t'])
        inp = (pts, t)
        # Flatten the mask so it indexes the per-point prediction rows directly.
        target = learn.batch['noise'], learn.batch['mask'].view(-1)
        learn.batch = (inp, target)

class CustomMSELoss(nn.Module):
    """MSE between predicted and true noise, restricted to the noisy points.

    ``target`` is a ``(noise, mask)`` pair. Points where ``mask`` is true are
    excluded, so only the remaining (noised) points contribute to the loss.
    """

    def __init__(self):
        super().__init__()

        self.loss_fn = nn.MSELoss()

    def forward(self, preds, target):
        """Return the mean squared error over the unmasked points.

        Args:
            preds: per-point predictions indexed as ``preds[rows, :3]``; presumably
                (num_points, C) with C >= 3 — only the first three channels are
                compared (TODO confirm channel layout).
            target: ``(noise, mask)``. ``noise`` is reshaped to (-1, 3) so it lines
                up row-for-row with ``preds``; ``mask`` is a flat per-point mask
                (bool, or 0/1 integers).

        Returns:
            Scalar loss tensor. NOTE: if every point is masked the selection is
            empty and ``nn.MSELoss`` returns nan.
        """
        noise, mask = target

        # Coerce to bool first: `~` on an integer mask is a bitwise NOT
        # (e.g. 1 -> -2), not a logical negation, and would select the wrong rows.
        keep = ~mask.bool()

        # calculate loss only for the noisy points
        preds = preds[keep, :3]

        # reshape (unlike view) also works when `noise` is non-contiguous.
        noise = noise.reshape(-1, 3)[keep]

        return self.loss_fn(preds, noise)

## LR Finder

In [8]:
# ddpm_cb = DDPMCB()
# model = get_model()
# learn = TrainLearner(model, dls, CustomMSELoss(), cbs=[ddpm_cb, DeviceCBSparse(), GradientClipCB()], opt_func=torch.optim.Adam)
# learn.lr_find(max_mult=3)

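## Training

The LR-finder and training cells are kept commented out; the Test Completion section loads trained weights from `./checkpoints/CompletionSPVD.pt` instead.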

In [9]:
# lr = 0.0002 
# epochs = 2000

# model = get_model()

# # scheduler
# total_steps = epochs * len(dls.train)
# sched = partial(torch.optim.lr_scheduler.OneCycleLR, max_lr=lr, total_steps = total_steps)

# # Callbacks
# ddpm_cb = DDPMCB()
# checkpoint_cb = CheckpointCB(1000, 'CompletionSPVD', run_params={'msg':model.msg})
# cbs = [ddpm_cb, DeviceCBSparse(), ProgressCB(plot=False), LossCB(), GradientClipCB(), checkpoint_cb, BatchSchedCB(sched)]

# learn = TrainLearner(model, dls, CustomMSELoss(), lr=lr, cbs=cbs, opt_func=torch.optim.Adam)
# learn.fit(epochs)

# Test Completion

In [56]:
from utils.completion_schedulers import DDPMSparseCompletionSchedulerGPU
from utils.visualization import quick_vis_batch, vis_pc_sphere

def pad(t, np):
    """Zero-pad a batch of point clouds along the point axis.

    Args:
        t: tensor of shape (B, N, F).
        np: target number of points (this parameter name shadows the usual
            ``numpy`` alias, but only inside this function).

    Returns:
        Tensor of shape (B, np, F) with the same dtype and device as ``t``;
        points ``[:N]`` are copied from ``t`` and the rest are zeros.

    Raises:
        ValueError: if ``t`` already has more than ``np`` points.
    """
    B, N, F = t.shape
    if N > np:
        raise ValueError(f"pad: input has {N} points, more than the target {np}")

    # new_zeros allocates directly with t's dtype/device instead of building the
    # buffer on the CPU and copying it over with .to(t).
    padded = t.new_zeros(B, np, F)
    padded[:, :N, :] = t

    return padded

In [11]:
# load model from checkpoint
# Build the model, load the trained weights and switch to inference mode on the GPU.
# map_location='cpu' lets the checkpoint load regardless of which device it was saved
# from; the model is moved to the GPU afterwards. Loading inline avoids keeping a
# second copy of the state dict alive in a global.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
model = get_model()
model.load_state_dict(torch.load('./checkpoints/CompletionSPVD.pt', map_location='cpu')['state_dict'])
model = model.eval().cuda()

In [35]:
# Completion sampler (default constructor settings); used below via sched.complete(...).
sched = DDPMSparseCompletionSchedulerGPU()

In [36]:
# First test batch, for qualitative inspection. Whether this is the same batch on
# every run depends on te_dl's shuffling, which is not visible here.
batch = next(iter(te_dl))

In [79]:
# Complete partial test shapes and show (padded input, completion) pairs.
# Assumes batch['mask'] is a (B, N) boolean tensor selecting each cloud's observed
# points, and batch['input'].F stacks the B*N per-point features — TODO confirm.
mask_batch = batch['mask']
n_clouds, n_points = mask_batch.shape
# keep the first three feature channels (presumably xyz) of every point
pc_batch = batch['input'].F.reshape(n_clouds, n_points, -1)[..., :3]

# Each sample runs the full reverse diffusion, so this loop is slow; shrink the
# range for a quicker look.
sample_idxs = range(7, n_clouds)

# Sampling only: no_grad avoids building autograd graphs across diffusion steps.
with torch.no_grad():
    for idx in sample_idxs:
        pc = pc_batch[idx][mask_batch[idx]]
        preds = sched.complete(pc.unsqueeze(0), model, n_points=n_points, save_process=False)
        quick_vis_batch(torch.cat([pad(pc.unsqueeze(0), n_points), preds], dim=0), grid=(2, 1), x_offset=6)

KeyboardInterrupt: 