In [2]:
import os
import numpy as np
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from scipy.spatial.transform import Rotation as R
import time
import pickle
import argparse
from utils import *

from deepVCP import DeepVCP
from CustomDataset2 import CustomDataset2
from deepVCP_loss import deepVCP_loss
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
from tqdm import tqdm

# conda install numpy tqdm matplotlib
# conda install -c anaconda scikit-learn
# conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
# conda install -c conda-forge trimesh


In [3]:
dataset = 'CustomDataset2'
retrain_path = 'store'
model_path = 'Optimise_final_model.pt'
full_dataset = True

In [4]:
num_epochs = 20
batch_size = 1
lr = 0.001
# loss balancing factor 
alpha = 0.5

print(f"Params: epochs: {num_epochs}, batch: {batch_size}, lr: {lr}, alpha: {alpha}\n")

# check if cuda is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"device: {device}")


Params: epochs: 20, batch: 1, lr: 0.001, alpha: 0.5

device: cuda


In [5]:
root = 'meshes/'
train_data= CustomDataset2(root=root)

Processing bun000_v2.ply
(40245, 3)
Processing bun045_v2.ply
(40091, 3)
Processing bun090_v2.ply
(30373, 3)
Processing bun180_v2.ply
(40247, 3)
Processing bun270_v2.ply
(31697, 3)
Processing bun315_v2.ply
(35334, 3)
# Total clouds 6


In [7]:
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False)

num_train = len(train_data)
print('Train dataset size: ', num_train)

use_normal = dataset == "modelnet"

# Initialize the model
model = DeepVCP(use_normal=use_normal)
model.load_state_dict(torch.load('./first_train_epoch_19_model.pt', map_location='cuda'))
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
    model = nn.DataParallel(model)

model.to(device)

Train dataset size:  30


DeepVCP(
  (FE1): feat_extraction_layer(
    (sa1): PointNetSetAbstraction(
      (mlp_convs): ModuleList(
        (0): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
        (2): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))
      )
      (mlp_bns): ModuleList(
        (0): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (sa2): PointNetSetAbstraction(
      (mlp_convs): ModuleList(
        (0): Conv2d(3, 32, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
      )
      (mlp_bns): ModuleList(
        (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, 

In [9]:
optim = Adam(model.parameters(), lr=lr)

# begin train 
model.train()
loss_epoch_avg = []
for epoch in tqdm(range(num_epochs)):
    print(f"epoch #{epoch}")
    loss_epoch = []
    running_loss = 0.0

    for n_batch, (src, target, R_gt, t_gt, R_prior) in enumerate(train_loader):
        start_time = time.time()
        # mini batch
        src, target, R_gt, t_gt, R_prior = src.to(device), target.to(device), R_gt.to(device), t_gt.to(device), R_prior.to(device)
        t_init = torch.randn((1, 3))
        if epoch < 3:
            src_keypts, target_vcp = model(src, target, R_gt, t_init)
        else:
            src_keypts, target_vcp = model(src, target, R_prior, t_init)
        # print('src_keypts shape', src_keypts.shape)
        # print('target_vcp shape', target_vcp.shape)
        # zero gradient 
        optim.zero_grad()
        loss, R_pred, t_pred = deepVCP_loss(src_keypts, target_vcp, R_gt, t_gt, alpha=0.5)


        # error metric for rigid body transformation
        r_pred = R.from_matrix(R_pred.squeeze(0).cpu().detach().numpy())
        r_pred_arr = torch.tensor(r_pred.as_euler('xyz', degrees=True)).reshape(1, 3)
        r_gt = R.from_matrix(R_gt.squeeze(0).cpu().detach().numpy())
        r_gt_arr = torch.tensor(r_gt.as_euler('xyz', degrees=True)).reshape(1, 3)
        pdist = nn.PairwiseDistance(p = 2)
        t_pred = t_pred.squeeze(-1)
        t_gt = t_gt.squeeze(-1)

        print("rotation error: ", pdist(r_pred_arr, r_gt_arr).item())
        print("translation error: ", pdist(t_pred, t_gt).item())

        # backward pass
        loss.backward()
        # update parameters 
        optim.step()

        running_loss += loss.item()
        loss_epoch += [loss.item()]
        print("--- %s seconds ---" % (time.time() - start_time))
        if (n_batch + 1) % 5 == 0:
            print("Epoch: [{}/{}], Batch: {}, Loss: {}".format(
                epoch, num_epochs, n_batch, loss.item()))
            running_loss = 0.0

    torch.save(model.state_dict(), "train2_epoch_" + str(epoch) + "_model.pt")
    loss_epoch_avg += [sum(loss_epoch) / len(loss_epoch)]
    print("loss epoch", loss_epoch)
    with open("training_loss_" + str(epoch) + ".txt", "wb") as fp:   #Pickling
        pickle.dump(loss_epoch, fp)


  0%|          | 0/20 [00:00<?, ?it/s]

epoch #0
feature extraction time:  6.514231443405151
src_keypts_idx_unsqueezed:  torch.Size([1, 3, 64])
src_keypts:  torch.Size([1, 64, 3])
Grouping keypoints time:  0.014949798583984375
B:  1
K_topk:  64
nsample:  32
num_feat:  32
get_cat_feat_src time:  0.0


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([6647,  331, 5217,  ..., 6068,  710, 5802], device='cuda:0')
get_cat_feat_tgt time:  0.10564661026000977
Loss: 0.009534380564656116
rotation error:  0.5830700465061389
translation error:  0.018363410093700242
--- 13.175766229629517 seconds ---
feature extraction time:  5.948279619216919
src_keypts_idx_unsqueezed:  torch.Size([1, 3, 64])
src_keypts:  torch.Size([1, 64, 3])
Grouping keypoints time:  0.015946626663208008
B:  1
K_topk:  64
nsample:  32
num_feat:  32
get_cat_feat_src time:  0.0009970664978027344
tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mas

  5%|▌         | 1/20 [06:32<2:04:26, 392.99s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([7111, 6671, 8274,  ..., 5871, 7099, 7447], device='cuda:0')
get_cat_feat_tgt time:  0.08571338653564453
Loss: 0.011340488127863847
rotation error:  1.684585901520797
translation error:  0.025077068109433
--- 12.881672382354736 seconds ---
Epoch: [0/20], Batch: 29, Loss: 0.011340488127863847
loss epoch [0.009534380564656116, 0.0289460798490175, 0.027248642048610047, 0.01287297203202614, 0.009899172638327867, 0.013431346016354597, 0.010420513220999437, 0.0024087996180834, 0.011687378788191217, 0.016439070369957374, 0.008711001751530376, 0.004177744110385644, 0.006483061964839447, 0.009487632056372952, 0.009456632267819357, 0.006081971650677381, 0.003864824855263435, 0.007297041967699407, 0.008007611262633786, 0.0068134359512429495, 0.00791405

 10%|█         | 2/20 [13:07<1:58:14, 394.12s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([4590, 5505, 2842,  ...,  793, 8591,  451], device='cuda:0')
get_cat_feat_tgt time:  0.09966659545898438
Loss: 0.009523812904433859
rotation error:  1.6035487633886156
translation error:  0.018287188696785053
--- 13.188870191574097 seconds ---
Epoch: [1/20], Batch: 29, Loss: 0.009523812904433859
loss epoch [0.008218081927973762, 0.004211456075801618, 0.0033743481419934067, 0.007655996853436484, 0.008112031264531635, 0.007811046140674535, 0.0041859663373815335, 0.006882031127524826, 0.008104870044218476, 0.00795531342780694, 0.009281226642050933, 0.005748528923793285, 0.004287760735529958, 0.007166221533380425, 0.009639544804104627, 0.006812624345002246, 0.004124522576254514, 0.0038397763107092704, 0.0063910632584177115, 0.005584540724610499,

 15%|█▌        | 3/20 [19:41<1:51:33, 393.73s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([4354, 3060, 5497,  ..., 3195, 3347, 5095], device='cuda:0')
get_cat_feat_tgt time:  0.0847165584564209
Loss: 0.0063220229538571504
rotation error:  1.1815612157768272
translation error:  0.019195024869227844
--- 12.7914719581604 seconds ---
Epoch: [2/20], Batch: 29, Loss: 0.0063220229538571504
loss epoch [0.008156840456064782, 0.006413460355720785, 0.008348724000038195, 0.007612358711601241, 0.0031042225706759572, 0.0017941486427526156, 0.006910690421908292, 0.009776710373762694, 0.007728072077723155, 0.004599568705177178, 0.004679530617495873, 0.00449028950480047, 0.009950615967689798, 0.008404629687763966, 0.00692832040396274, 0.006595499992475647, 0.005189633312705725, 0.0053852936498893144, 0.0039034317005929736, 0.006704589378644847, 0

 20%|██        | 4/20 [26:15<1:45:02, 393.91s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([3713, 9708, 4433,  ..., 1638, 4891, 3653], device='cuda:0')
get_cat_feat_tgt time:  0.09867095947265625
Loss: 0.01804175500589776
rotation error:  30.327653366778296
translation error:  0.044153141024343485
--- 13.094534635543823 seconds ---
Epoch: [3/20], Batch: 29, Loss: 0.01804175500589776
loss epoch [0.015972263039500792, 0.023672634769995253, 0.019792364744800353, 0.011074696602435452, 0.019742995307026435, 0.03437075397515377, 0.014688633524706257, 0.052075104758496, 0.040958237596922256, 0.028988865440382275, 0.036996581434800016, 0.03820132859320609, 0.049776599250084874, 0.05414900596289546, 0.032693126272688194, 0.022449587953795952, 0.018455131367443322, 0.008908412602773964, 0.03168684355690406, 0.03447919867893273, 0.0292162968

 25%|██▌       | 5/20 [32:52<1:38:44, 394.98s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([7482, 4583, 2646,  ...,  377, 8061, 2288], device='cuda:0')
get_cat_feat_tgt time:  0.08073043823242188
Loss: 0.02552076556703975
rotation error:  55.38817600943025
translation error:  0.014577770757211346
--- 13.32192063331604 seconds ---
Epoch: [4/20], Batch: 29, Loss: 0.02552076556703975
loss epoch [0.023365394237208424, 0.015511220219095602, 0.01765129971019422, 0.013453208080810882, 0.01840968413035846, 0.02427579864477028, 0.022286089659983837, 0.015956176617133788, 0.02510389759851156, 0.021224789696585776, 0.010548489404358779, 0.020178970271000757, 0.01411042374931099, 0.03490287227672255, 0.01678813146473766, 0.017016129889682898, 0.005334209424311178, 0.014477954663277394, 0.012441486854307869, 0.01779761350675886, 0.033126256091

 30%|███       | 6/20 [39:36<1:32:52, 398.07s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([4388, 2165, 3505,  ..., 8308, 5076, 7429], device='cuda:0')
get_cat_feat_tgt time:  0.09867000579833984
Loss: 0.004781143778623183
rotation error:  8.241443693162797
translation error:  0.003015827189880632
--- 13.841875791549683 seconds ---
Epoch: [5/20], Batch: 29, Loss: 0.004781143778623183
loss epoch [0.014404780956858433, 0.024847325788528635, 0.032427015216450236, 0.016629764003417563, 0.019269768178542968, 0.03515448782272406, 0.00652319568959755, 0.025891444648670305, 0.014484617520034099, 0.012476943291707851, 0.023683873145029618, 0.009913921912021473, 0.015409761631251448, 0.017998794795824287, 0.022998709618844617, 0.0227155845782762, 0.014986561691792157, 0.020548494497559265, 0.025503333468369027, 0.039147063747628705, 0.03665

 35%|███▌      | 7/20 [46:16<1:26:23, 398.76s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([6584, 6600, 6879,  ..., 4205,  769, 1923], device='cuda:0')
get_cat_feat_tgt time:  0.10265660285949707
Loss: 0.03653713538825597
rotation error:  343.662391504387
translation error:  0.02536946284634065
--- 13.136067152023315 seconds ---
Epoch: [6/20], Batch: 29, Loss: 0.03653713538825597
loss epoch [0.022847539214066574, 0.02676737041080087, 0.014674243342210304, 0.02169896509152942, 0.027067653821751997, 0.0311521724437708, 0.005829338157186629, 0.01274528611131303, 0.06734755147175037, 0.03291772250233015, 0.0202957924058074, 0.02507473320314364, 0.014347505151826704, 0.019054037348623917, 0.030417011903087862, 0.018472901558030157, 0.008081772184301948, 0.013029921551597733, 0.01097302985858031, 0.03781832237323787, 0.02097615762498562

 40%|████      | 8/20 [52:59<1:20:00, 400.07s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([5752,  534, 8917,  ..., 9664, 4177, 4225], device='cuda:0')
get_cat_feat_tgt time:  0.1066434383392334
Loss: 0.013299972427160585
rotation error:  32.0797098812562
translation error:  0.029529204432172128
--- 13.077393054962158 seconds ---
Epoch: [7/20], Batch: 29, Loss: 0.013299972427160585
loss epoch [0.009046642709162177, 0.02518252044079434, 0.019117031462908173, 0.02828089245442792, 0.04679292944935141, 0.035774575363963154, 0.037077189423695874, 0.030184974611180776, 0.031697054974133665, 0.014084687327012816, 0.03381483465675503, 0.016408179568269787, 0.04322488081998351, 0.027583707127355875, 0.0253670970233867, 0.029960223805959113, 0.018877387464396562, 0.06836606738530061, 0.03585792719479292, 0.01693729221277965, 0.0301669918652

 45%|████▌     | 9/20 [59:41<1:13:27, 400.67s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([1417,  767, 3367,  ..., 2109, 8511, 4538], device='cuda:0')
get_cat_feat_tgt time:  0.10265684127807617
Loss: 0.0574835013739924
rotation error:  405.57942999898415
translation error:  0.03686255515655941
--- 13.660361766815186 seconds ---
Epoch: [8/20], Batch: 29, Loss: 0.0574835013739924
loss epoch [0.028418443924203236, 0.017373306864638373, 0.009812418005550947, 0.018682167131832395, 0.012771643135278368, 0.010173817565733868, 0.009642459410108752, 0.01427880412024832, 0.020479691044446503, 0.0275374077280321, 0.04466951831118728, 0.029333711039448103, 0.014353267390553093, 0.023888489593048873, 0.020728756867420775, 0.012897700320745331, 0.008439174928817694, 0.011632862287649856, 0.012423261655567792, 0.013189042693834065, 0.022145956

 50%|█████     | 10/20 [1:06:22<1:06:48, 400.89s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([8927, 2810, 5931,  ..., 6978, 5817,  160], device='cuda:0')
get_cat_feat_tgt time:  0.09966683387756348
Loss: 0.022452335026382383
rotation error:  60.7746456553419
translation error:  0.01506343050783304
--- 13.175997734069824 seconds ---
Epoch: [9/20], Batch: 29, Loss: 0.022452335026382383
loss epoch [0.008368464731214366, 0.027614605422652726, 0.017518531705519848, 0.028978552060236223, 0.012272801267809027, 0.016231420987666983, 0.024849239444356512, 0.028245000120112634, 0.03652348566471342, 0.025000506673891358, 0.013648258120322155, 0.02152865091073064, 0.04198408341156861, 0.016532552750594256, 0.024375202562539593, 0.02438608410682589, 0.022618023345808605, 0.017539380364517312, 0.02323149970251026, 0.014300133296996958, 0.01346647

 55%|█████▌    | 11/20 [1:12:59<59:55, 399.56s/it]  

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([3262, 8372,  582,  ..., 1778, 7405, 5755], device='cuda:0')
get_cat_feat_tgt time:  0.08671092987060547
Loss: 0.040623515526476146
rotation error:  33.18703495662716
translation error:  0.03540406503474525
--- 13.111541986465454 seconds ---
Epoch: [10/20], Batch: 29, Loss: 0.040623515526476146
loss epoch [0.03184923537263086, 0.016999517018877894, 0.02478275541022197, 0.012781626658890129, 0.01792747403298464, 0.0357021906899534, 0.027096979721843363, 0.020806577907346527, 0.025048591836918658, 0.029006987846634566, 0.026718146201507597, 0.016344464474963714, 0.0291022805380406, 0.039367813870556184, 0.012467845646860867, 0.011339143348103282, 0.016508388016268964, 0.026540194645922957, 0.02474614307792126, 0.028587911157337303, 0.028285888

 60%|██████    | 12/20 [1:19:40<53:20, 400.01s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([2416, 1104, 3431,  ..., 5619, 6186, 3057], device='cuda:0')
get_cat_feat_tgt time:  0.09468293190002441
Loss: 0.016542560070822165
rotation error:  28.157979526634346
translation error:  0.03208302642574137
--- 13.913800954818726 seconds ---
Epoch: [11/20], Batch: 29, Loss: 0.016542560070822165
loss epoch [0.02007955630136877, 0.02643258368718618, 0.028654192239862444, 0.04084568805558997, 0.024854211873577718, 0.03962746409677126, 0.017284810212323764, 0.015285522420360895, 0.028714007709716324, 0.017792591127098082, 0.030757273789791078, 0.019631441692901506, 0.026504031643935166, 0.020411716281549773, 0.026238347837263785, 0.0419217713933252, 0.01330409527211701, 0.013747470138295045, 0.03357033590792619, 0.017725357270403044, 0.02245651

 65%|██████▌   | 13/20 [1:26:20<46:41, 400.16s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([5982, 9959,  721,  ..., 2672, 1314, 5124], device='cuda:0')
get_cat_feat_tgt time:  0.08571434020996094
Loss: 0.015533054213780689
rotation error:  264.09306255954255
translation error:  0.04380331215293736
--- 13.42535662651062 seconds ---
Epoch: [12/20], Batch: 29, Loss: 0.015533054213780689
loss epoch [0.01464849912879329, 0.01413929901015648, 0.01969215193023416, 0.016459341833589387, 0.02033620656739773, 0.026787642215245926, 0.03917527034069769, 0.021751921405658195, 0.02523360704658579, 0.032401698074138674, 0.01706559394556233, 0.02377978978965112, 0.010832447717542645, 0.00864882964628444, 0.016265357052814525, 0.01475084782476542, 0.03074257099939698, 0.016700961896462942, 0.013907870414409594, 0.013425785497920954, 0.041595210827

 70%|███████   | 14/20 [1:32:59<39:58, 399.71s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([3684, 2129, 3742,  ..., 3367, 9778, 3161], device='cuda:0')
get_cat_feat_tgt time:  0.0849766731262207
Loss: 0.04152813089361161
rotation error:  287.8931295726725
translation error:  0.03139353139625502
--- 12.951126098632812 seconds ---
Epoch: [13/20], Batch: 29, Loss: 0.04152813089361161
loss epoch [0.013752993047470456, 0.021756378956080923, 0.0159815054278195, 0.012302656020288294, 0.022184065854362844, 0.03839925016599266, 0.02156393995039497, 0.025361669498254615, 0.025969167915668508, 0.025346189948071116, 0.014359391075693548, 0.011882134901342307, 0.020930128143408895, 0.011855130033762335, 0.015941380990649526, 0.043707496492439074, 0.01426395663841501, 0.015921288508582876, 0.024609585960633604, 0.01867285015046808, 0.0121268635

 75%|███████▌  | 15/20 [1:39:36<33:14, 398.94s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([2980, 3566, 4834,  ..., 5227, 9056, 7349], device='cuda:0')
get_cat_feat_tgt time:  0.09780406951904297
Loss: 0.01837866612612684
rotation error:  37.70273438540242
translation error:  0.037671790733940884
--- 13.42911958694458 seconds ---
Epoch: [14/20], Batch: 29, Loss: 0.01837866612612684
loss epoch [0.017900353057017523, 0.023437845780272647, 0.029913176564888525, 0.026145477081910052, 0.023186280846046386, 0.032246108494527845, 0.04723787578407987, 0.03951691652516123, 0.026667202687278825, 0.014073985578106716, 0.026706194594720754, 0.020287152527260204, 0.01861096987726876, 0.020188712663402032, 0.014241761579382176, 0.03700647293489569, 0.02449886650221823, 0.04331728791484411, 0.022250424021964487, 0.02464140314460531, 0.0300857684

 80%|████████  | 16/20 [1:46:13<26:33, 398.27s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([ 811,  598, 3860,  ..., 9140, 3680, 1127], device='cuda:0')
get_cat_feat_tgt time:  0.0954594612121582
Loss: 0.022521945392092155
rotation error:  163.07283757060202
translation error:  0.023518942690128962
--- 13.069441080093384 seconds ---
Epoch: [15/20], Batch: 29, Loss: 0.022521945392092155
loss epoch [0.02604088750815914, 0.030707292406077513, 0.01696351607758962, 0.027207370750909073, 0.009541423250043118, 0.018171706780779083, 0.011517289424227379, 0.021830687917886466, 0.016424819522128972, 0.03331069959952723, 0.04388710702939548, 0.03973297307963378, 0.054151679321463, 0.048110293361269076, 0.04461092853437123, 0.036934031833274504, 0.018261732091890547, 0.016359530117292527, 0.06156189654123226, 0.03548389096919405, 0.02006999210

 85%|████████▌ | 17/20 [1:52:53<19:56, 398.72s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([ 560, 8633, 1355,  ..., 5870, 5538, 4492], device='cuda:0')
get_cat_feat_tgt time:  0.08308577537536621
Loss: 0.014325641424865544
rotation error:  34.309911101513315
translation error:  0.060076482529901276
--- 13.16728162765503 seconds ---
Epoch: [16/20], Batch: 29, Loss: 0.014325641424865544
loss epoch [0.025348771196000906, 0.005411869638293071, 0.02878697310323378, 0.036242078641102494, 0.015594022702629225, 0.016314281473579804, 0.02068078444749469, 0.03496032568481725, 0.006317828492929233, 0.024662870488012908, 0.015349691138671982, 0.018804062068884665, 0.04580968029880959, 0.02195077471649327, 0.021676462529851604, 0.019623562704894154, 0.01686891490089564, 0.009412577354974666, 0.009819282661123939, 0.019521239615892257, 0.045117

 90%|█████████ | 18/20 [1:59:30<13:16, 398.26s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([7523, 8474,  733,  ..., 7223, 3953, 9460], device='cuda:0')
get_cat_feat_tgt time:  0.08597040176391602
Loss: 0.007787480133105397
rotation error:  14.748453978097801
translation error:  0.015778939050414314
--- 13.296915054321289 seconds ---
Epoch: [17/20], Batch: 29, Loss: 0.007787480133105397
loss epoch [0.017988467901074924, 0.014929597536992812, 0.014051644811355814, 0.01346862373606823, 0.014149149689632297, 0.033352138252149, 0.04334540281837509, 0.02447982354359341, 0.012078266919527172, 0.021429760345037893, 0.013378007828669979, 0.01737637031775681, 0.011026565920312516, 0.022637305354800692, 0.028261344196979732, 0.01623319267742608, 0.0188045682163005, 0.015341042675578197, 0.022103298123521246, 0.023889053390875443, 0.022076558

 95%|█████████▌| 19/20 [2:06:07<06:37, 397.91s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([3590, 7228, 6686,  ..., 7190,  289,   86], device='cuda:0')
get_cat_feat_tgt time:  0.1110990047454834
Loss: 0.033061196890197
rotation error:  294.74161455485097
translation error:  0.0238951847704778
--- 13.19560718536377 seconds ---
Epoch: [18/20], Batch: 29, Loss: 0.033061196890197
loss epoch [0.01108912998780085, 0.01053433839196433, 0.02100991841767452, 0.014903463949402711, 0.039447512794041774, 0.02762999840900294, 0.027356327533015146, 0.02218421463246936, 0.01296949571962638, 0.02678747499545726, 0.01164322707938738, 0.015927034019823726, 0.03128199522034543, 0.01515188123501748, 0.02879241551955447, 0.00530612015458646, 0.030246563203156857, 0.013640715957048763, 0.00951324159605807, 0.011438359897282143, 0.032720652508936784, 0.

100%|██████████| 20/20 [2:12:48<00:00, 398.45s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([ 333, 8645, 2555,  ..., 2304, 2051,  859], device='cuda:0')
get_cat_feat_tgt time:  0.08765649795532227
Loss: 0.00925053492557177
rotation error:  25.303267896419555
translation error:  0.033886668592071414
--- 13.021884679794312 seconds ---
Epoch: [19/20], Batch: 29, Loss: 0.00925053492557177
loss epoch [0.019930899763159, 0.018847588280161913, 0.022835652000832582, 0.02876474134440855, 0.03217326517464673, 0.017569888328435722, 0.015523842212944151, 0.010062925407132235, 0.029266787283699768, 0.020479908945930863, 0.015226771188028633, 0.026073182897309914, 0.020838413084548285, 0.016122840347331045, 0.013217706228020838, 0.010255050661525617, 0.023686653465880555, 0.020038702524214073, 0.02482146924131238, 0.013615765945689588, 0.0124387




In [10]:
# save
print("Finished Training")

Finished Training
