In [3]:
import os
import numpy as np
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from scipy.spatial.transform import Rotation as R
import time
import pickle
import argparse
from utils import *

from deepVCP import DeepVCP
from CustomDataset2 import CustomDataset2
from deepVCP_loss import deepVCP_loss
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
from tqdm import tqdm

In [2]:
# Run configuration.
# `dataset` is compared against "modelnet" further down to decide whether the
# model is built with normals enabled.
dataset = 'CustomDataset2'
# File the final model weights are written to after training.
model_path = 'Optimise_final_model.pt'
# NOTE(review): the two settings below are not read by any visible cell —
# confirm whether they are still needed.
retrain_path, full_dataset = 'store', True

In [3]:
# Training hyper-parameters.
num_epochs, batch_size = 20, 1
lr = 0.001
alpha = 0.5  # loss balancing factor

print(f"Params: epochs: {num_epochs}, batch: {batch_size}, lr: {lr}, alpha: {alpha}\n")

# Default to the CPU and switch to the GPU only when PyTorch can see one.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f"device: {device}")


Params: epochs: 20, batch: 1, lr: 0.001, alpha: 0.5

device: cuda


In [4]:
# Folder the dataset is built from, relative to the working directory.
root = 'meshes/'
# Build the training dataset from that folder.
train_data = CustomDataset2(root=root)

Processing bun000_v2.ply
(40245, 3)
Processing bun045_v2.ply
(40091, 3)
Processing bun090_v2.ply
(30373, 3)
Processing bun180_v2.ply
(40247, 3)
Processing bun270_v2.ply
(31697, 3)
Processing bun315_v2.ply
(35334, 3)
# Total clouds 6


In [5]:
# Batches are served in dataset order (shuffling is off).
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False)

# Sanity check: show the tensor shapes and ground-truth pose of the first
# batch, then stop iterating.
for n_batch, first_batch in enumerate(train_loader):
    src, target, R_gt, t_gt = first_batch
    print(src.shape, target.shape)
    print(R_gt, t_gt)
    break


num_train = len(train_data)
print('Train dataset size: ', num_train)

# Normals are only switched on when the configured dataset is "modelnet".
use_normal = (dataset == "modelnet")

# Build the network; when more than one GPU is visible, wrap it so each batch
# is split along dim 0 across the devices.
model = DeepVCP(use_normal=use_normal)
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print("Let's use", gpu_count, "GPUs!")
    model = nn.DataParallel(model)

# Kept as the last expression so the notebook displays the model summary.
model.to(device)

torch.Size([1, 3, 10000]) torch.Size([1, 3, 10000])
tensor([[[ 0.7743, -0.3514,  0.5262],
         [ 0.4306,  0.9020, -0.0313],
         [-0.4636,  0.2509,  0.8498]]], dtype=torch.float64) tensor([[[0.0742],
         [0.0742],
         [0.0742]]], dtype=torch.float64)
Train dataset size:  30


DeepVCP(
  (FE1): feat_extraction_layer(
    (sa1): PointNetSetAbstraction(
      (mlp_convs): ModuleList(
        (0): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
        (2): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))
      )
      (mlp_bns): ModuleList(
        (0-1): 2 x BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (sa2): PointNetSetAbstraction(
      (mlp_convs): ModuleList(
        (0): Conv2d(3, 32, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
      )
      (mlp_bns): ModuleList(
        (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (sa3): PointNetSetAbstraction(

In [6]:
# Training loop.
#
# For every epoch: run each batch through the model with a perturbed rotation
# prior, compute the DeepVCP loss against the ground-truth pose, log
# rotation/translation error, and take one Adam step. At the end of each
# epoch the weights and the per-batch losses are written to disk.
optim = Adam(model.parameters(), lr=lr)

# Only used for the logged error metrics; it holds no state, so it is built
# once here instead of once per batch.
pdist = nn.PairwiseDistance(p = 2)

# begin train 
model.train()
loss_epoch_avg = []
for epoch in tqdm(range(num_epochs)):
    print(f"epoch #{epoch}")
    loss_epoch = []
    # NOTE(review): accumulated and periodically reset below but never
    # reported; kept so the notebook's variables stay the same.
    running_loss = 0.0

    for n_batch, (src, target, R_gt, t_gt, ) in enumerate(train_loader):
        start_time = time.time()
        # mini batch
        src, target, R_gt, t_gt = src.to(device), target.to(device), R_gt.to(device), t_gt.to(device)
        # NOTE(review): t_init is created on the CPU with a fixed leading
        # dimension of 1; how the model consumes it is not visible here, so it
        # is left untouched — confirm before using batch_size > 1.
        t_init = torch.randn((1, 3))
        # Rotation prior = ground-truth rotation with every entry scaled by an
        # independent factor drawn from [0.9, 1.1). The noise is sized from the
        # actual batch (a final batch may be smaller than batch_size) and moved
        # to the selected device rather than unconditionally to CUDA, so the
        # cell also runs on CPU-only machines.
        corruption = torch.empty(R_gt.shape[0], 3, 3, dtype=torch.float32).uniform_(0.9, 1.1).to(device)
        R_prior = R_gt.clone() * corruption
        src_keypts, target_vcp = model(src, target, R_prior, t_init)
        # print('src_keypts shape', src_keypts.shape)
        # print('target_vcp shape', target_vcp.shape)
        # zero gradient 
        optim.zero_grad()
        # Use the configured balancing factor instead of a literal 0.5.
        loss, R_pred, t_pred = deepVCP_loss(src_keypts, target_vcp, R_gt, t_gt, alpha=alpha)


        # error metric for rigid body transformation
        # Rotations are compared as 'xyz' Euler angles in degrees.
        # NOTE(review): squeeze(0) only drops the batch axis when the batch
        # holds a single sample.
        r_pred = R.from_matrix(R_pred.squeeze(0).cpu().detach().numpy())
        r_pred_arr = torch.tensor(r_pred.as_euler('xyz', degrees=True)).reshape(1, 3)
        r_gt = R.from_matrix(R_gt.squeeze(0).cpu().detach().numpy())
        r_gt_arr = torch.tensor(r_gt.as_euler('xyz', degrees=True)).reshape(1, 3)
        t_pred = t_pred.squeeze(-1)
        t_gt = t_gt.squeeze(-1)

        print("rotation error: ", pdist(r_pred_arr, r_gt_arr).mean())
        # Detach so the logged metric is not tracked by autograd.
        print("translation error: ", pdist(t_pred.detach(), t_gt).mean())

        # backward pass
        loss.backward()
        # update parameters 
        optim.step()

        running_loss += loss.item()
        loss_epoch += [loss.item()]
        print("--- %s seconds ---" % (time.time() - start_time))
        if (n_batch + 1) % 5 == 0:
            print("Epoch: [{}/{}], Batch: {}, Loss: {}".format(
                epoch, num_epochs, n_batch, loss.item()))
            running_loss = 0.0

    # Per-epoch checkpoint of the weights.
    torch.save(model.state_dict(), "epoch_" + str(epoch) + "_model.pt")
    loss_epoch_avg += [sum(loss_epoch) / len(loss_epoch)]
    print("loss epoch", loss_epoch)
    # The file is a binary pickle despite its .txt extension.
    with open("training_loss_" + str(epoch) + ".txt", "wb") as fp:   #Pickling
        pickle.dump(loss_epoch, fp)


  0%|          | 0/20 [00:00<?, ?it/s]

epoch #0
feature extraction time:  7.882020473480225
src_keypts_idx_unsqueezed:  torch.Size([1, 3, 64])
src_keypts:  torch.Size([1, 64, 3])
Grouping keypoints time:  0.013999700546264648
B:  1
K_topk:  64
nsample:  32
num_feat:  32
get_cat_feat_src time:  0.0


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([9579, 8410,  355,  ..., 5979, 3797, 3177], device='cuda:0')
get_cat_feat_tgt time:  0.100982666015625
Loss: 0.27637573912479096
rotation error:  tensor(3.1098, dtype=torch.float64)
translation error:  tensor(0.4787, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
--- 15.347990989685059 seconds ---
feature extraction time:  6.659988164901733
src_keypts_idx_unsqueezed:  torch.Size([1, 3, 64])
src_keypts:  torch.Size([1, 64, 3])
Grouping keypoints time:  0.021999835968017578
B:  1
K_topk:  64
nsample:  32
num_feat:  32
get_cat_feat_src time:  0.0
tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  ten

  5%|▌         | 1/20 [06:38<2:06:17, 398.82s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([7602, 5331, 4691,  ..., 9096, 4951, 6647], device='cuda:0')
get_cat_feat_tgt time:  0.10900020599365234
Loss: 0.01725959058851028
rotation error:  tensor(1.6776, dtype=torch.float64)
translation error:  tensor(0.0325, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
--- 13.094977378845215 seconds ---
Epoch: [0/20], Batch: 29, Loss: 0.01725959058851028
loss epoch [0.27637573912479096, 0.27459086797206944, 0.2652749522319714, 0.2606825147270588, 0.24915880266668983, 0.23499054500243122, 0.21185771209261753, 0.1778115809094979, 0.1421393383286807, 0.0889993066817635, 0.023345632686382727, 0.046724583807539855, 0.07894302370666809, 0.08599324252742525, 0.07001036745033014, 0.04650262087204836, 0.03056697957241051, 0.01287928073637

 10%|█         | 2/20 [13:17<1:59:41, 398.95s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([6477, 7129, 3122,  ..., 9980, 8826, 9047], device='cuda:0')
get_cat_feat_tgt time:  0.10899972915649414
Loss: 0.015012037196973806
rotation error:  tensor(1.2423, dtype=torch.float64)
translation error:  tensor(0.0302, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
--- 13.33097767829895 seconds ---
Epoch: [1/20], Batch: 29, Loss: 0.015012037196973806
loss epoch [0.015304348836430294, 0.0110405396049468, 0.008237461626192467, 0.010255235907166406, 0.01191419168291328, 0.0013734941506750975, 0.0014994653506217944, 0.0051812526089602615, 0.008459169051559181, 0.005267598864951439, 0.005038349044157623, 0.003981297729960101, 0.003440634483102438, 0.0028016310710255177, 0.004582461646532388, 0.0030136540907688483, 0.0039467611331

 15%|█▌        | 3/20 [20:14<1:55:22, 407.19s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([9310, 3017, 7048,  ..., 7598,  136, 5473], device='cuda:0')
get_cat_feat_tgt time:  0.09500002861022949
Loss: 0.021952892934289828
rotation error:  tensor(1.2952, dtype=torch.float64)
translation error:  tensor(0.0380, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
--- 13.071055173873901 seconds ---
Epoch: [2/20], Batch: 29, Loss: 0.021952892934289828
loss epoch [0.013189813578130923, 0.004735353081428427, 0.011944040788760562, 0.01584228274318112, 0.017272541857352636, 0.0142005744556047, 0.006109897775853867, 0.007682343850072219, 0.019174093174585453, 0.023103579305540724, 0.015067824712314517, 0.006429770241871087, 0.011772557115718785, 0.021414599788516512, 0.020045780102800982, 0.018764405272305062, 0.00898838137228832

 20%|██        | 4/20 [26:48<1:47:07, 401.74s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([ 924, 2947, 6173,  ..., 3514, 3612,   68], device='cuda:0')
get_cat_feat_tgt time:  0.10699987411499023
Loss: 0.006482177476594958
rotation error:  tensor(4.1399, dtype=torch.float64)
translation error:  tensor(0.0134, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
--- 12.90297794342041 seconds ---
Epoch: [3/20], Batch: 29, Loss: 0.006482177476594958
loss epoch [0.0158796336149726, 0.010003935028530656, 0.003213993027722013, 0.011676942245211621, 0.011576225598629006, 0.010964341502855697, 0.003525249181574477, 0.002570302574666042, 0.007865145221873968, 0.006526691479395701, 0.006151153392007927, 0.005570342322794722, 0.0047010121853871805, 0.004124436260106454, 0.006585682689857121, 0.01002236984502039, 0.00442257625565304

 25%|██▌       | 5/20 [33:24<1:39:54, 399.64s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([8875, 8884, 9009,  ..., 1663, 1730, 1526], device='cuda:0')
get_cat_feat_tgt time:  0.1100008487701416
Loss: 0.012766972609331082
rotation error:  tensor(0.8555, dtype=torch.float64)
translation error:  tensor(0.0230, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
--- 13.200085878372192 seconds ---
Epoch: [4/20], Batch: 29, Loss: 0.012766972609331082
loss epoch [0.0029675925463045496, 0.0033533933554441633, 0.002049991084822852, 0.0029180902241660485, 0.0036255444926530083, 0.004430667185283699, 0.0032945563607307275, 0.0017140268426869179, 0.0033398977785971933, 0.004856192603693638, 0.0034489067241628537, 0.00556987833658938, 0.0034193700661544835, 0.004402988089033466, 0.007110837587510697, 0.0047901148026827985, 0.002458

 30%|███       | 6/20 [40:44<1:36:27, 413.36s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([ 796, 5327, 4056,  ..., 9794, 2688, 5948], device='cuda:0')
get_cat_feat_tgt time:  0.11299991607666016
Loss: 0.012846002908631955
rotation error:  tensor(2.6168, dtype=torch.float64)
translation error:  tensor(0.0222, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
--- 13.12019944190979 seconds ---
Epoch: [5/20], Batch: 29, Loss: 0.012846002908631955
loss epoch [0.019185223184028672, 0.016249654886637925, 0.008818113446767879, 0.0024970416908455455, 0.009944636685264274, 0.015001024185377937, 0.009392415753777666, 0.0033975100046553363, 0.013250423438391366, 0.01721033612183267, 0.020917624882227003, 0.015191977478438467, 0.006954103485634357, 0.00953969723041127, 0.017112282747488102, 0.016955020332256283, 0.013646721739456

 35%|███▌      | 7/20 [47:25<1:28:41, 409.36s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([8267, 1401, 7659,  ..., 4080, 5132, 1741], device='cuda:0')
get_cat_feat_tgt time:  0.10599994659423828
Loss: 0.018034960045012798
rotation error:  tensor(1.8835, dtype=torch.float64)
translation error:  tensor(0.0351, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
--- 13.54501748085022 seconds ---
Epoch: [6/20], Batch: 29, Loss: 0.018034960045012798
loss epoch [0.014665431842361859, 0.009777865825837296, 0.0039779150842124975, 0.006892169735654823, 0.012926556206751925, 0.011470944648548293, 0.011174788333815043, 0.004327094721947877, 0.01256163036640645, 0.01678431789823493, 0.02113594954741608, 0.01698822507691826, 0.007164676838786787, 0.006243103766620408, 0.0142850204781852, 0.019304895667826148, 0.019170108339852988, 

 40%|████      | 8/20 [54:02<1:21:04, 405.41s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([3713,  245, 9118,  ..., 5979, 3789, 9675], device='cuda:0')
get_cat_feat_tgt time:  0.10500025749206543
Loss: 0.011806767811530724
rotation error:  tensor(3.3000, dtype=torch.float64)
translation error:  tensor(0.0218, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
--- 12.732006072998047 seconds ---
Epoch: [7/20], Batch: 29, Loss: 0.011806767811530724
loss epoch [0.017310831849713854, 0.012735595952570393, 0.006372611346323859, 0.007341488091073524, 0.01193774593949216, 0.014072308104580966, 0.008797098088717389, 0.008182448113708421, 0.006265470294557756, 0.00912082266661576, 0.014480716332187312, 0.009513055774499922, 0.0052110429576554575, 0.007523833856928168, 0.01303077603645956, 0.009181576725703428, 0.0089804897644522

 45%|████▌     | 9/20 [1:01:10<1:15:38, 412.58s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([4683,  362, 1211,  ..., 9583, 2500, 2766], device='cuda:0')
get_cat_feat_tgt time:  0.11100029945373535
Loss: 0.0064648287538249015
rotation error:  tensor(2.7186, dtype=torch.float64)
translation error:  tensor(0.0128, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
--- 13.457998991012573 seconds ---
Epoch: [8/20], Batch: 29, Loss: 0.0064648287538249015
loss epoch [0.017155433568677547, 0.01933062676162459, 0.01761454901843912, 0.006872480054112908, 0.005854415259117767, 0.013271951927057574, 0.012822319667181628, 0.008818736943898562, 0.004337543517568891, 0.006187963329363501, 0.014808708505635378, 0.015509065723872752, 0.010425292488694018, 0.005475714918602061, 0.007979814301550729, 0.013833917247743539, 0.01269137490381

 50%|█████     | 10/20 [1:07:52<1:08:12, 409.29s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([6091, 7794, 7976,  ..., 8674, 3974, 5328], device='cuda:0')
get_cat_feat_tgt time:  0.10400032997131348
Loss: 0.006860575897432112
rotation error:  tensor(3.9676, dtype=torch.float64)
translation error:  tensor(0.0175, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
--- 13.265988826751709 seconds ---
Epoch: [9/20], Batch: 29, Loss: 0.006860575897432112
loss epoch [0.00610172938959257, 0.004735450244568139, 0.004350349946263871, 0.0036831781828209533, 0.002661374279812095, 0.0035453290530784243, 0.005324546407360797, 0.005751136174706387, 0.006160556671602464, 0.005414168613978572, 0.004196462911908228, 0.0029545129421463243, 0.001797460903567104, 0.006836998244054334, 0.006367434876914928, 0.0051402296046927885, 0.00649220407

 55%|█████▌    | 11/20 [1:14:27<1:00:44, 404.91s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([9555, 8025, 6800,  ..., 6766, 3807, 3469], device='cuda:0')
get_cat_feat_tgt time:  0.10200023651123047
Loss: 0.006965990801277942
rotation error:  tensor(1.9664, dtype=torch.float64)
translation error:  tensor(0.0190, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
--- 13.436017990112305 seconds ---
Epoch: [10/20], Batch: 29, Loss: 0.006965990801277942
loss epoch [0.006703719684595865, 0.005439987332597616, 0.0036207911899333992, 0.0020483199300100905, 0.005078478639092475, 0.0030735073566192624, 0.003378241583074519, 0.007745081196146165, 0.0016193515476732342, 0.0019867353468854625, 0.0015297810116109544, 0.002853150014468885, 0.007649392937534421, 0.003777359759470069, 0.0025961820311342964, 0.005644612912526302, 0.006076

 60%|██████    | 12/20 [1:21:11<53:56, 404.62s/it]  

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([5119, 3498, 9737,  ..., 3058, 5769, 1023], device='cuda:0')
get_cat_feat_tgt time:  0.08999991416931152
Loss: 0.015162313869089014
rotation error:  tensor(2.1353, dtype=torch.float64)
translation error:  tensor(0.0272, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
--- 13.416018009185791 seconds ---
Epoch: [11/20], Batch: 29, Loss: 0.015162313869089014
loss epoch [0.004689541691193531, 0.0048319977210094876, 0.004256640624889801, 0.006205275195063165, 0.0018586856833989528, 0.005967111851843582, 0.010408254303091006, 0.005928998555954523, 0.005003808553182159, 0.009594636519148441, 0.018297679197407088, 0.01535883788715715, 0.010767284101048708, 0.001441085707316378, 0.011016382077332637, 0.0142679232832997, 0.01672504026895

 65%|██████▌   | 13/20 [1:28:01<47:23, 406.21s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([9810, 9415, 6403,  ..., 9039, 1480, 9972], device='cuda:0')
get_cat_feat_tgt time:  0.10699939727783203
Loss: 0.0054471112405891494
rotation error:  tensor(2.6383, dtype=torch.float64)
translation error:  tensor(0.0102, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
--- 13.356982469558716 seconds ---
Epoch: [12/20], Batch: 29, Loss: 0.0054471112405891494
loss epoch [0.012968825509001752, 0.008398630362854294, 0.00518110581014661, 0.00957516015327645, 0.014212056809595341, 0.01914414680684922, 0.01515509096437437, 0.013890488357341966, 0.005992130850844171, 0.007854100432102329, 0.013300717278485995, 0.010699442130847527, 0.005678285894765999, 0.005648527970683411, 0.008184397930042902, 0.008247324766633441, 0.006073784063487

 70%|███████   | 14/20 [1:34:53<40:48, 408.02s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([8081,  823, 2361,  ..., 8753, 7719, 6208], device='cuda:0')
get_cat_feat_tgt time:  0.09599947929382324
Loss: 0.004812662681292307
rotation error:  tensor(1.0477, dtype=torch.float64)
translation error:  tensor(0.0179, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
--- 13.244981527328491 seconds ---
Epoch: [13/20], Batch: 29, Loss: 0.004812662681292307
loss epoch [0.004415447711734919, 0.001402557952964134, 0.003729773839041486, 0.002733008887925368, 0.00475278850914546, 0.0036440260733006293, 0.0030430802179922676, 0.006708188305916166, 0.0030624782690873136, 0.0028333610844978744, 0.003514958408921179, 0.0020772057640296012, 0.004698642194474918, 0.0040956934023159385, 0.0030365205322921327, 0.005452557103012763, 0.0058009

 75%|███████▌  | 15/20 [1:41:44<34:04, 408.87s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([5569, 3640,   24,  ..., 1463, 3798, 3643], device='cuda:0')
get_cat_feat_tgt time:  0.1100006103515625
Loss: 0.0048540046998841666
rotation error:  tensor(1.7683, dtype=torch.float64)
translation error:  tensor(0.0126, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
--- 14.034082889556885 seconds ---
Epoch: [14/20], Batch: 29, Loss: 0.0048540046998841666
loss epoch [0.007876800034868823, 0.009035819702609485, 0.005643273074642763, 0.004853158452935812, 0.008895330927165034, 0.006352823661125887, 0.004169696156819842, 0.003408607710060874, 0.00822648356109527, 0.008029455177065273, 0.007804110300959835, 0.004147501292112182, 0.0044852327338105085, 0.004476364973446212, 0.0047524585175974824, 0.006022810301438618, 0.00820781220

 80%|████████  | 16/20 [1:48:31<27:13, 408.26s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([4492, 1778, 3312,  ..., 1181, 2045, 3025], device='cuda:0')
get_cat_feat_tgt time:  0.09131550788879395
Loss: 0.0030806187220169252
rotation error:  tensor(3.8653, dtype=torch.float64)
translation error:  tensor(0.0095, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
--- 13.403400659561157 seconds ---
Epoch: [15/20], Batch: 29, Loss: 0.0030806187220169252
loss epoch [0.0032380096422005597, 0.004904478879993826, 0.006873570106394871, 0.0018001746014341402, 0.006835925387695884, 0.008003820221760745, 0.0032204777297918695, 0.0028168623594712686, 0.0055971028017717025, 0.005425255768000554, 0.007609782913340728, 0.008814050582780359, 0.004070503109275456, 0.0025113407486346386, 0.004548018242832805, 0.004316843108865058, 0.00398

 85%|████████▌ | 17/20 [1:55:23<20:28, 409.62s/it]

loss epoch [0.005473708862111491, 0.0028370990035789106, 0.003924253351731269, 0.004248186939485811, 0.00329185616295658, 0.0037968272689274055, 0.002667597770586223, 0.005256153343627201, 0.003984135670244375, 0.004280072132733629, 0.00480606266577634, 0.0036718133238766433, 0.005713854262483097, 0.006839053271409752, 0.004239610979390999, 0.007740472881233688, 0.00683441290113994, 0.003973129560849942, 0.005307149689428434, 0.005667454042964632, 0.004749643589139332, 0.004853079009014543, 0.004132477122540635, 0.00212523167359343, 0.005992572204794442, 0.016216186360442968, 0.0088491942229393, 0.007704541025791169, 0.005258917179875533, 0.005689854186658676]
epoch #17
feature extraction time:  6.896001577377319
src_keypts_idx_unsqueezed:  torch.Size([1, 3, 64])
src_keypts:  torch.Size([1, 64, 3])
Grouping keypoints time:  0.021000146865844727
B:  1
K_topk:  64
nsample:  32
num_feat:  32
get_cat_feat_src time:  0.0
tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000

 90%|█████████ | 18/20 [2:02:17<13:41, 410.92s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([6966, 4313, 1675,  ..., 8593, 9829, 6680], device='cuda:0')
get_cat_feat_tgt time:  0.11700034141540527
Loss: 0.0072778938010690485
rotation error:  tensor(2.3618, dtype=torch.float64)
translation error:  tensor(0.0171, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
--- 13.82204532623291 seconds ---
Epoch: [17/20], Batch: 29, Loss: 0.0072778938010690485
loss epoch [0.007101844453778752, 0.0053059041301837424, 0.003577665635534089, 0.008198778137431029, 0.01025476023662067, 0.00899944269687792, 0.004555304080557743, 0.0076846761745672666, 0.010440039772362405, 0.009368126965572158, 0.013893098527825474, 0.009010420561295718, 0.0019767683554266753, 0.014755546026550792, 0.020468534645466267, 0.021674740556916147, 0.01879588804

 95%|█████████▌| 19/20 [2:09:09<06:51, 411.25s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([3514, 2237, 1997,  ..., 1010,   11, 3676], device='cuda:0')
get_cat_feat_tgt time:  0.10899996757507324
Loss: 0.00913322552420936
rotation error:  tensor(2.6720, dtype=torch.float64)
translation error:  tensor(0.0153, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
--- 14.24802541732788 seconds ---
Epoch: [18/20], Batch: 29, Loss: 0.00913322552420936
loss epoch [0.004158405221119931, 0.008284845365978238, 0.006362304733642519, 0.008818809417136093, 0.005600680331517733, 0.008279658196878712, 0.005996144963810844, 0.007752921791168706, 0.006128740131422974, 0.005819107442149906, 0.010086421061762555, 0.007461287455531037, 0.00487284186529633, 0.0050385199816559145, 0.006164372838116749, 0.005391919337560261, 0.0050891567244544

100%|██████████| 20/20 [2:16:03<00:00, 408.16s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([ 733, 7868, 3823,  ..., 9569, 8071, 4259], device='cuda:0')
get_cat_feat_tgt time:  0.1100003719329834
Loss: 0.010870042277999254
rotation error:  tensor(2.3333, dtype=torch.float64)
translation error:  tensor(0.0183, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
--- 13.630985260009766 seconds ---
Epoch: [19/20], Batch: 29, Loss: 0.010870042277999254
loss epoch [0.0024786231472144864, 0.003632966313616842, 0.002870603620226793, 0.0026043047959128326, 0.004516337857406418, 0.00420837519842149, 0.0027043296125505377, 0.004761109324859243, 0.006086536667981743, 0.007174659715140915, 0.00720774022177792, 0.002703543520521084, 0.0011873903168894318, 0.0034071105618079787, 0.004332734126124681, 0.0048197837961049695, 0.0041780059




In [None]:
# Training is done: announce it, then write the final weights (state_dict
# only) to `model_path`.
print("Finished Training")
final_weights = model.state_dict()
torch.save(final_weights, model_path)

Finished Training
