In [2]:
import os
import numpy as np
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from scipy.spatial.transform import Rotation as R
import time
import pickle
import argparse
from utils import *

from deepVCP import DeepVCP
from CustomDataset2 import CustomDataset2
from deepVCP_loss import deepVCP_loss
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
from tqdm import tqdm

# conda install numpy tqdm matplotlib
# conda install -c anaconda scikit-learn
# conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
# conda install -c conda-forge trimesh


In [3]:
dataset = 'CustomDataset2'
retrain_path = 'store'
model_path = 'Optimise_final_model.pt'
full_dataset = True

In [4]:
num_epochs = 20
batch_size = 1
lr = 0.001
# loss balancing factor 
alpha = 0.5

print(f"Params: epochs: {num_epochs}, batch: {batch_size}, lr: {lr}, alpha: {alpha}\n")

# check if cuda is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"device: {device}")


Params: epochs: 20, batch: 1, lr: 0.001, alpha: 0.5

device: cuda


In [5]:
root = 'meshes/'
train_data= CustomDataset2(root=root)

Processing bun000_v2.ply
(40245, 3)
Processing bun045_v2.ply
(40091, 3)
Processing bun090_v2.ply
(30373, 3)
Processing bun180_v2.ply
(40247, 3)
Processing bun270_v2.ply
(31697, 3)
Processing bun315_v2.ply
(35334, 3)
# Total clouds 6


In [7]:
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False)

num_train = len(train_data)
print('Train dataset size: ', num_train)

use_normal = dataset == "modelnet"

# Initialize the model
model = DeepVCP(use_normal=use_normal)
model.load_state_dict(torch.load('./first_train_epoch_19_model.pt', map_location='cuda'))
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
    model = nn.DataParallel(model)

model.to(device)

Train dataset size:  30


DeepVCP(
  (FE1): feat_extraction_layer(
    (sa1): PointNetSetAbstraction(
      (mlp_convs): ModuleList(
        (0): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
        (2): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))
      )
      (mlp_bns): ModuleList(
        (0): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (sa2): PointNetSetAbstraction(
      (mlp_convs): ModuleList(
        (0): Conv2d(3, 32, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
      )
      (mlp_bns): ModuleList(
        (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, 

In [9]:
optim = Adam(model.parameters(), lr=lr)

# begin train 
model.train()
loss_epoch_avg = []
for epoch in tqdm(range(num_epochs)):
    print(f"epoch #{epoch}")
    loss_epoch = []
    running_loss = 0.0

    for n_batch, (src, target, R_gt, t_gt, R_prior) in enumerate(train_loader):
        start_time = time.time()
        # mini batch
        src, target, R_gt, t_gt, R_prior = src.to(device), target.to(device), R_gt.to(device), t_gt.to(device), R_prior.to(device)
        t_init = torch.randn((1, 3))
        if epoch < 3:
            src_keypts, target_vcp = model(src, target, R_gt, t_init)
        else:
            src_keypts, target_vcp = model(src, target, R_prior, t_init)
        # print('src_keypts shape', src_keypts.shape)
        # print('target_vcp shape', target_vcp.shape)
        # zero gradient 
        optim.zero_grad()
        loss, R_pred, t_pred = deepVCP_loss(src_keypts, target_vcp, R_gt, t_gt, alpha=0.5)


        # error metric for rigid body transformation
        r_pred = R.from_matrix(R_pred.squeeze(0).cpu().detach().numpy())
        r_pred_arr = torch.tensor(r_pred.as_euler('xyz', degrees=True)).reshape(1, 3)
        r_gt = R.from_matrix(R_gt.squeeze(0).cpu().detach().numpy())
        r_gt_arr = torch.tensor(r_gt.as_euler('xyz', degrees=True)).reshape(1, 3)
        pdist = nn.PairwiseDistance(p = 2)
        t_pred = t_pred.squeeze(-1)
        t_gt = t_gt.squeeze(-1)

        print("rotation error: ", pdist(r_pred_arr, r_gt_arr).item())
        print("translation error: ", pdist(t_pred, t_gt).item())

        # backward pass
        loss.backward()
        # update parameters 
        optim.step()

        running_loss += loss.item()
        loss_epoch += [loss.item()]
        print("--- %s seconds ---" % (time.time() - start_time))
        if (n_batch + 1) % 5 == 0:
            print("Epoch: [{}/{}], Batch: {}, Loss: {}".format(
                epoch, num_epochs, n_batch, loss.item()))
            running_loss = 0.0

    torch.save(model.state_dict(), "train2_epoch_" + str(epoch) + "_model.pt")
    loss_epoch_avg += [sum(loss_epoch) / len(loss_epoch)]
    print("loss epoch", loss_epoch)
    with open("training_loss_" + str(epoch) + ".txt", "wb") as fp:   #Pickling
        pickle.dump(loss_epoch, fp)


  0%|          | 0/20 [00:00<?, ?it/s]

epoch #0
feature extraction time:  6.514231443405151
src_keypts_idx_unsqueezed:  torch.Size([1, 3, 64])
src_keypts:  torch.Size([1, 64, 3])
Grouping keypoints time:  0.014949798583984375
B:  1
K_topk:  64
nsample:  32
num_feat:  32
get_cat_feat_src time:  0.0


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([6647,  331, 5217,  ..., 6068,  710, 5802], device='cuda:0')
get_cat_feat_tgt time:  0.10564661026000977
Loss: 0.009534380564656116
rotation error:  0.5830700465061389
translation error:  0.018363410093700242
--- 13.175766229629517 seconds ---
feature extraction time:  5.948279619216919
src_keypts_idx_unsqueezed:  torch.Size([1, 3, 64])
src_keypts:  torch.Size([1, 64, 3])
Grouping keypoints time:  0.015946626663208008
B:  1
K_topk:  64
nsample:  32
num_feat:  32
get_cat_feat_src time:  0.0009970664978027344
tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mas

  5%|▌         | 1/20 [06:32<2:04:26, 392.99s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([7111, 6671, 8274,  ..., 5871, 7099, 7447], device='cuda:0')
get_cat_feat_tgt time:  0.08571338653564453
Loss: 0.011340488127863847
rotation error:  1.684585901520797
translation error:  0.025077068109433
--- 12.881672382354736 seconds ---
Epoch: [0/20], Batch: 29, Loss: 0.011340488127863847
loss epoch [0.009534380564656116, 0.0289460798490175, 0.027248642048610047, 0.01287297203202614, 0.009899172638327867, 0.013431346016354597, 0.010420513220999437, 0.0024087996180834, 0.011687378788191217, 0.016439070369957374, 0.008711001751530376, 0.004177744110385644, 0.006483061964839447, 0.009487632056372952, 0.009456632267819357, 0.006081971650677381, 0.003864824855263435, 0.007297041967699407, 0.008007611262633786, 0.0068134359512429495, 0.00791405

 10%|█         | 2/20 [13:07<1:58:14, 394.12s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([4590, 5505, 2842,  ...,  793, 8591,  451], device='cuda:0')
get_cat_feat_tgt time:  0.09966659545898438
Loss: 0.009523812904433859
rotation error:  1.6035487633886156
translation error:  0.018287188696785053
--- 13.188870191574097 seconds ---
Epoch: [1/20], Batch: 29, Loss: 0.009523812904433859
loss epoch [0.008218081927973762, 0.004211456075801618, 0.0033743481419934067, 0.007655996853436484, 0.008112031264531635, 0.007811046140674535, 0.0041859663373815335, 0.006882031127524826, 0.008104870044218476, 0.00795531342780694, 0.009281226642050933, 0.005748528923793285, 0.004287760735529958, 0.007166221533380425, 0.009639544804104627, 0.006812624345002246, 0.004124522576254514, 0.0038397763107092704, 0.0063910632584177115, 0.005584540724610499,

 15%|█▌        | 3/20 [19:41<1:51:33, 393.73s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([4354, 3060, 5497,  ..., 3195, 3347, 5095], device='cuda:0')
get_cat_feat_tgt time:  0.0847165584564209
Loss: 0.0063220229538571504
rotation error:  1.1815612157768272
translation error:  0.019195024869227844
--- 12.7914719581604 seconds ---
Epoch: [2/20], Batch: 29, Loss: 0.0063220229538571504
loss epoch [0.008156840456064782, 0.006413460355720785, 0.008348724000038195, 0.007612358711601241, 0.0031042225706759572, 0.0017941486427526156, 0.006910690421908292, 0.009776710373762694, 0.007728072077723155, 0.004599568705177178, 0.004679530617495873, 0.00449028950480047, 0.009950615967689798, 0.008404629687763966, 0.00692832040396274, 0.006595499992475647, 0.005189633312705725, 0.0053852936498893144, 0.0039034317005929736, 0.006704589378644847, 0

 20%|██        | 4/20 [26:15<1:45:02, 393.91s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([3713, 9708, 4433,  ..., 1638, 4891, 3653], device='cuda:0')
get_cat_feat_tgt time:  0.09867095947265625
Loss: 0.01804175500589776
rotation error:  30.327653366778296
translation error:  0.044153141024343485
--- 13.094534635543823 seconds ---
Epoch: [3/20], Batch: 29, Loss: 0.01804175500589776
loss epoch [0.015972263039500792, 0.023672634769995253, 0.019792364744800353, 0.011074696602435452, 0.019742995307026435, 0.03437075397515377, 0.014688633524706257, 0.052075104758496, 0.040958237596922256, 0.028988865440382275, 0.036996581434800016, 0.03820132859320609, 0.049776599250084874, 0.05414900596289546, 0.032693126272688194, 0.022449587953795952, 0.018455131367443322, 0.008908412602773964, 0.03168684355690406, 0.03447919867893273, 0.0292162968

 25%|██▌       | 5/20 [32:52<1:38:44, 394.98s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([7482, 4583, 2646,  ...,  377, 8061, 2288], device='cuda:0')
get_cat_feat_tgt time:  0.08073043823242188
Loss: 0.02552076556703975
rotation error:  55.38817600943025
translation error:  0.014577770757211346
--- 13.32192063331604 seconds ---
Epoch: [4/20], Batch: 29, Loss: 0.02552076556703975
loss epoch [0.023365394237208424, 0.015511220219095602, 0.01765129971019422, 0.013453208080810882, 0.01840968413035846, 0.02427579864477028, 0.022286089659983837, 0.015956176617133788, 0.02510389759851156, 0.021224789696585776, 0.010548489404358779, 0.020178970271000757, 0.01411042374931099, 0.03490287227672255, 0.01678813146473766, 0.017016129889682898, 0.005334209424311178, 0.014477954663277394, 0.012441486854307869, 0.01779761350675886, 0.033126256091

In [None]:
# save
print("Finished Training")

Finished Training
