In [1]:
import os
import numpy as np
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from scipy.spatial.transform import Rotation as R
import time
import pickle
import argparse
from utils import *

from deepVCP import DeepVCP
from CustomDataset2 import CustomDataset2
from deepVCP_loss import deepVCP_loss
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
from tqdm import tqdm

In [2]:
# --- Run configuration -------------------------------------------------
# Checkpoint file that receives the final weights in the last cell.
model_path = 'Optimise_final_model.pt'
# Dataset identifier; the model-construction cell compares it against
# "modelnet" to decide whether normals are used.
dataset = 'CustomDataset2'
# NOTE(review): the two settings below are not read by any cell visible in
# this notebook -- presumably leftovers from the script version; confirm
# before removing.
retrain_path = 'store'
full_dataset = True

In [3]:
# --- Training hyperparameters -------------------------------------------
# Number of passes over the dataset, samples per step, and Adam step size.
num_epochs, batch_size, lr = 20, 1, 0.001
# Balancing factor between the loss terms (printed with the other params).
alpha = 0.5

print(f"Params: epochs: {num_epochs}, batch: {batch_size}, lr: {lr}, alpha: {alpha}\n")

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"device: {device}")


Params: epochs: 20, batch: 1, lr: 0.001, alpha: 0.5

device: cuda


In [4]:
# Directory handed to the dataset as its `root`; the cell output shows the
# .ply scans being processed one by one from there.
root = 'meshes/'
# Build the training set.
# NOTE(review): the log reports 6 clouds while the next cell reports a
# dataset length of 30 -- presumably ordered cloud pairs; confirm in
# CustomDataset2.
train_data = CustomDataset2(root=root)

Processing bun000_v2.ply
(40245, 3)
Processing bun045_v2.ply
(40091, 3)
Processing bun090_v2.ply
(30373, 3)
Processing bun180_v2.ply
(40247, 3)
Processing bun270_v2.ply
(31697, 3)
Processing bun315_v2.ply
(35334, 3)
# Total clouds 6


In [5]:
# --- Data loader ---------------------------------------------------------
# Samples are served in dataset order (no shuffling), `batch_size` at a time.
num_train = len(train_data)
train_loader = DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=False,
)
print('Train dataset size: ', num_train)

# --- Model ---------------------------------------------------------------
# Normals are only enabled when the configured dataset is "modelnet".
use_normal = (dataset == "modelnet")
model = DeepVCP(use_normal=use_normal)

# With more than one visible GPU, wrap the model so each batch is split
# along dim 0 across devices, e.g. [30, ...] -> 3 x [10, ...] on 3 GPUs.
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print("Let's use", gpu_count, "GPUs!")
    model = nn.DataParallel(model)

# Move the parameters to the selected device. This is left as the cell's
# last expression so the notebook displays the module structure.
model.to(device)

Train dataset size:  30


DeepVCP(
  (FE1): feat_extraction_layer(
    (sa1): PointNetSetAbstraction(
      (mlp_convs): ModuleList(
        (0): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
        (2): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))
      )
      (mlp_bns): ModuleList(
        (0-1): 2 x BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (sa2): PointNetSetAbstraction(
      (mlp_convs): ModuleList(
        (0): Conv2d(3, 32, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
      )
      (mlp_bns): ModuleList(
        (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (sa3): PointNetSetAbstraction(

In [6]:
# --- Training loop ---------------------------------------------------------
# For every epoch: run each (src, target) pair through the model, compute the
# DeepVCP loss against the ground-truth rigid transform, log rotation and
# translation errors, and take one Adam step. After each epoch the weights
# and the per-batch loss list are written to disk.
optim = Adam(model.parameters(), lr=lr)

# L2 distance used for both error metrics; it is loop-invariant, so build it
# once instead of once per batch.
pdist = nn.PairwiseDistance(p=2)

# begin train
model.train()
loss_epoch_avg = []  # mean training loss of each completed epoch
for epoch in tqdm(range(num_epochs)):
    print(f"epoch #{epoch}")
    loss_epoch = []  # per-batch losses of the current epoch

    for n_batch, (src, target, R_gt, t_gt, R_prior) in enumerate(train_loader):
        start_time = time.time()
        # mini batch
        src, target, R_gt, t_gt, R_prior = src.to(device), target.to(device), R_gt.to(device), t_gt.to(device), R_prior.to(device)
        # Random initial translation guess; the (1, 3) shape matches batch_size == 1.
        # NOTE(review): unlike the other inputs this tensor stays on the CPU and is
        # not seeded -- presumably the model moves it to the right device; confirm.
        t_init = torch.randn((1, 3))
        src_keypts, target_vcp = model(src, target, R_prior, t_init)
        # print('src_keypts shape', src_keypts.shape)
        # print('target_vcp shape', target_vcp.shape)
        # zero gradient
        optim.zero_grad()
        # Use the configured balancing factor rather than a hard-coded 0.5 so that
        # changing `alpha` in the hyperparameter cell actually affects training.
        loss, R_pred, t_pred = deepVCP_loss(src_keypts, target_vcp, R_gt, t_gt, alpha=alpha)

        # error metric for rigid body transformation
        # These values are for logging only, so keep them out of the autograd graph.
        with torch.no_grad():
            # Rotation error: L2 distance between the 'xyz' Euler angles (degrees) of
            # the predicted and ground-truth rotations. squeeze(0)/reshape(1, 3)
            # assume a batch of exactly one sample.
            # NOTE(review): Euler-angle differences do not account for angle
            # wrap-around, so this can overstate the true rotation error.
            r_pred = R.from_matrix(R_pred.squeeze(0).cpu().detach().numpy())
            r_pred_arr = torch.tensor(r_pred.as_euler('xyz', degrees=True)).reshape(1, 3)
            r_gt = R.from_matrix(R_gt.squeeze(0).cpu().detach().numpy())
            r_gt_arr = torch.tensor(r_gt.as_euler('xyz', degrees=True)).reshape(1, 3)
            # Drop the trailing singleton axis before measuring translation distance.
            t_pred = t_pred.squeeze(-1)
            t_gt = t_gt.squeeze(-1)

            print("rotation error: ", pdist(r_pred_arr, r_gt_arr).item())
            print("translation error: ", pdist(t_pred, t_gt).item())

        # backward pass
        loss.backward()
        # update parameters
        optim.step()

        # Read the scalar loss once and reuse it for the log and the history.
        batch_loss = loss.item()
        loss_epoch.append(batch_loss)
        print("--- %s seconds ---" % (time.time() - start_time))
        if (n_batch + 1) % 5 == 0:
            print("Epoch: [{}/{}], Batch: {}, Loss: {}".format(
                epoch, num_epochs, n_batch, batch_loss))

    # Checkpoint the weights after every epoch.
    torch.save(model.state_dict(), "epoch_" + str(epoch) + "_model.pt")
    # Guard against an empty loader, which would otherwise divide by zero.
    if loss_epoch:
        loss_epoch_avg.append(sum(loss_epoch) / len(loss_epoch))
    else:
        loss_epoch_avg.append(float("nan"))
    print("loss epoch", loss_epoch)
    # Despite the .txt suffix this file is a binary pickle of the loss list;
    # read it back with pickle.load on a file opened in "rb" mode.
    with open("training_loss_" + str(epoch) + ".txt", "wb") as fp:   #Pickling
        pickle.dump(loss_epoch, fp)


  0%|          | 0/20 [00:00<?, ?it/s]

epoch #0
feature extraction time:  11.372282028198242
src_keypts_idx_unsqueezed:  torch.Size([1, 3, 64])
src_keypts:  torch.Size([1, 64, 3])
Grouping keypoints time:  0.025913476943969727
B:  1
K_topk:  64
nsample:  32
num_feat:  32
get_cat_feat_src time:  0.0009963512420654297


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([3542,    9, 2699,  ..., 8625, 2389, 1841], device='cuda:0')
get_cat_feat_tgt time:  0.09368777275085449
Loss: 0.27215408631277443
rotation error:  39.490341293767734
translation error:  0.47490653265780325
--- 22.766135931015015 seconds ---
feature extraction time:  11.084944486618042
src_keypts_idx_unsqueezed:  torch.Size([1, 3, 64])
src_keypts:  torch.Size([1, 64, 3])
Grouping keypoints time:  0.03189349174499512
B:  1
K_topk:  64
nsample:  32
num_feat:  32
get_cat_feat_src time:  0.0
tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([5413, 23

  5%|▌         | 1/20 [11:02<3:29:54, 662.88s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([6352, 6760, 7495,  ..., 4518, 9850, 3374], device='cuda:0')
get_cat_feat_tgt time:  0.10564661026000977
Loss: 0.02839139061696496
rotation error:  60.91031040204419
translation error:  0.052281636713803245
--- 21.997391939163208 seconds ---
Epoch: [0/20], Batch: 29, Loss: 0.02839139061696496
loss epoch [0.27215408631277443, 0.2626272534105449, 0.2759550332284413, 0.2569048879682624, 0.25531831264536714, 0.2626025544007037, 0.21656447199963555, 0.17483902541840274, 0.1430070463635794, 0.07406528233621232, 0.04105310204861015, 0.05715398043256162, 0.07417054561848246, 0.09242938907047425, 0.07078995325142337, 0.09345670959039726, 0.039199910567100725, 0.023604600978700783, 0.015499640470489778, 0.02550825338011013, 0.04486942788566296, 0.0715

 10%|█         | 2/20 [22:06<3:18:57, 663.18s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([3020, 9258, 5486,  ..., 2582, 5427, 6564], device='cuda:0')
get_cat_feat_tgt time:  0.10842609405517578
Loss: 0.018983140298821727
rotation error:  44.24672326662384
translation error:  0.029458435148865814
--- 22.244384050369263 seconds ---
Epoch: [1/20], Batch: 29, Loss: 0.018983140298821727
loss epoch [0.04568846820467607, 0.012761129136200616, 0.010121664667826272, 0.011571962497832801, 0.026447941967756263, 0.026090647680229673, 0.02186965756519354, 0.02676050733153218, 0.03915574585317827, 0.02810654745781182, 0.015405710577113513, 0.010516170716419294, 0.0346670137914504, 0.026853243332122623, 0.02967634155817417, 0.02037993821544082, 0.030125084030481416, 0.033549867511366016, 0.03441587532639362, 0.027837107744451257, 0.02263311848

 15%|█▌        | 3/20 [33:11<3:08:03, 663.73s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([2763, 4595, 8770,  ..., 8875,  562,  271], device='cuda:0')
get_cat_feat_tgt time:  0.09343242645263672
Loss: 0.021511585208157057
rotation error:  60.10495370988108
translation error:  0.0199224900767446
--- 22.303500175476074 seconds ---
Epoch: [2/20], Batch: 29, Loss: 0.021511585208157057
loss epoch [0.02654265852315514, 0.04564492406069811, 0.036653326705654896, 0.02693318010646642, 0.03370382818798216, 0.029663301491645696, 0.038137891323431344, 0.02296433344886259, 0.045562064355537446, 0.04029032078102505, 0.029167343392451653, 0.017023859978664534, 0.024965274848978213, 0.014393721395655241, 0.015556347940332196, 0.03502335742532366, 0.023208722085231948, 0.0212899967720644, 0.02771889507081896, 0.016112979741246362, 0.0268079410759

In [None]:
# Training is done: report it, then persist the final weights to `model_path`.
# Only the state dict (parameters and buffers) is stored, not the module object.
print("Finished Training")
final_state = model.state_dict()
torch.save(final_state, model_path)

Finished Training
