In [1]:
import os
import numpy as np
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from scipy.spatial.transform import Rotation as R
import time
import pickle
import argparse
from utils import *

from deepVCP import DeepVCP
from CustomDataset2 import CustomDataset2
from deepVCP_loss import deepVCP_loss
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
from tqdm import tqdm

In [2]:
dataset = 'CustomDataset2'
retrain_path = 'store'
model_path = 'Optimise_final_model.pt'
full_dataset = True

In [3]:
num_epochs = 20
batch_size = 1
lr = 0.001
# loss balancing factor 
alpha = 0.5

print(f"Params: epochs: {num_epochs}, batch: {batch_size}, lr: {lr}, alpha: {alpha}\n")

# check if cuda is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"device: {device}")


Params: epochs: 20, batch: 1, lr: 0.001, alpha: 0.5

device: cuda


In [4]:
root = 'meshes/'
train_data= CustomDataset2(root=root)

Processing bun000_v2.ply
(40245, 3)
Processing bun045_v2.ply
(40091, 3)
Processing bun090_v2.ply
(30373, 3)
Processing bun180_v2.ply
(40247, 3)
Processing bun270_v2.ply
(31697, 3)
Processing bun315_v2.ply
(35334, 3)
# Total clouds 6


In [5]:
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False)

num_train = len(train_data)
print('Train dataset size: ', num_train)

use_normal = dataset == "modelnet"

# Initialize the model
model = DeepVCP(use_normal=use_normal)
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
    model = nn.DataParallel(model)

model.to(device)

Train dataset size:  30


DeepVCP(
  (FE1): feat_extraction_layer(
    (sa1): PointNetSetAbstraction(
      (mlp_convs): ModuleList(
        (0): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
        (2): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))
      )
      (mlp_bns): ModuleList(
        (0-1): 2 x BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (sa2): PointNetSetAbstraction(
      (mlp_convs): ModuleList(
        (0): Conv2d(3, 32, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
      )
      (mlp_bns): ModuleList(
        (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (sa3): PointNetSetAbstraction(

In [6]:
optim = Adam(model.parameters(), lr=lr)

# begin train 
model.train()
loss_epoch_avg = []
for epoch in tqdm(range(num_epochs)):
    print(f"epoch #{epoch}")
    loss_epoch = []
    running_loss = 0.0

    for n_batch, (src, target, R_gt, t_gt, R_prior) in enumerate(train_loader):
        start_time = time.time()
        # mini batch
        src, target, R_gt, t_gt, R_prior = src.to(device), target.to(device), R_gt.to(device), t_gt.to(device), R_prior.to(device)
        t_init = torch.randn((1, 3))
        src_keypts, target_vcp = model(src, target, R_prior, t_init)
        # print('src_keypts shape', src_keypts.shape)
        # print('target_vcp shape', target_vcp.shape)
        # zero gradient 
        optim.zero_grad()
        loss, R_pred, t_pred = deepVCP_loss(src_keypts, target_vcp, R_gt, t_gt, alpha=0.5)


        # error metric for rigid body transformation
        r_pred = R.from_matrix(R_pred.squeeze(0).cpu().detach().numpy())
        r_pred_arr = torch.tensor(r_pred.as_euler('xyz', degrees=True)).reshape(1, 3)
        r_gt = R.from_matrix(R_gt.squeeze(0).cpu().detach().numpy())
        r_gt_arr = torch.tensor(r_gt.as_euler('xyz', degrees=True)).reshape(1, 3)
        pdist = nn.PairwiseDistance(p = 2)
        t_pred = t_pred.squeeze(-1)
        t_gt = t_gt.squeeze(-1)

        print("rotation error: ", pdist(r_pred_arr, r_gt_arr).item())
        print("translation error: ", pdist(t_pred, t_gt).item())

        # backward pass
        loss.backward()
        # update parameters 
        optim.step()

        running_loss += loss.item()
        loss_epoch += [loss.item()]
        print("--- %s seconds ---" % (time.time() - start_time))
        if (n_batch + 1) % 5 == 0:
            print("Epoch: [{}/{}], Batch: {}, Loss: {}".format(
                epoch, num_epochs, n_batch, loss.item()))
            running_loss = 0.0

    torch.save(model.state_dict(), "epoch_" + str(epoch) + "_model.pt")
    loss_epoch_avg += [sum(loss_epoch) / len(loss_epoch)]
    print("loss epoch", loss_epoch)
    with open("training_loss_" + str(epoch) + ".txt", "wb") as fp:   #Pickling
        pickle.dump(loss_epoch, fp)


  0%|          | 0/20 [00:00<?, ?it/s]

epoch #0
feature extraction time:  11.372282028198242
src_keypts_idx_unsqueezed:  torch.Size([1, 3, 64])
src_keypts:  torch.Size([1, 64, 3])
Grouping keypoints time:  0.025913476943969727
B:  1
K_topk:  64
nsample:  32
num_feat:  32
get_cat_feat_src time:  0.0009963512420654297


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([3542,    9, 2699,  ..., 8625, 2389, 1841], device='cuda:0')
get_cat_feat_tgt time:  0.09368777275085449
Loss: 0.27215408631277443
rotation error:  39.490341293767734
translation error:  0.47490653265780325
--- 22.766135931015015 seconds ---
feature extraction time:  11.084944486618042
src_keypts_idx_unsqueezed:  torch.Size([1, 3, 64])
src_keypts:  torch.Size([1, 64, 3])
Grouping keypoints time:  0.03189349174499512
B:  1
K_topk:  64
nsample:  32
num_feat:  32
get_cat_feat_src time:  0.0
tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([5413, 23

  5%|▌         | 1/20 [11:02<3:29:54, 662.88s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([6352, 6760, 7495,  ..., 4518, 9850, 3374], device='cuda:0')
get_cat_feat_tgt time:  0.10564661026000977
Loss: 0.02839139061696496
rotation error:  60.91031040204419
translation error:  0.052281636713803245
--- 21.997391939163208 seconds ---
Epoch: [0/20], Batch: 29, Loss: 0.02839139061696496
loss epoch [0.27215408631277443, 0.2626272534105449, 0.2759550332284413, 0.2569048879682624, 0.25531831264536714, 0.2626025544007037, 0.21656447199963555, 0.17483902541840274, 0.1430070463635794, 0.07406528233621232, 0.04105310204861015, 0.05715398043256162, 0.07417054561848246, 0.09242938907047425, 0.07078995325142337, 0.09345670959039726, 0.039199910567100725, 0.023604600978700783, 0.015499640470489778, 0.02550825338011013, 0.04486942788566296, 0.0715

 10%|█         | 2/20 [22:06<3:18:57, 663.18s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([3020, 9258, 5486,  ..., 2582, 5427, 6564], device='cuda:0')
get_cat_feat_tgt time:  0.10842609405517578
Loss: 0.018983140298821727
rotation error:  44.24672326662384
translation error:  0.029458435148865814
--- 22.244384050369263 seconds ---
Epoch: [1/20], Batch: 29, Loss: 0.018983140298821727
loss epoch [0.04568846820467607, 0.012761129136200616, 0.010121664667826272, 0.011571962497832801, 0.026447941967756263, 0.026090647680229673, 0.02186965756519354, 0.02676050733153218, 0.03915574585317827, 0.02810654745781182, 0.015405710577113513, 0.010516170716419294, 0.0346670137914504, 0.026853243332122623, 0.02967634155817417, 0.02037993821544082, 0.030125084030481416, 0.033549867511366016, 0.03441587532639362, 0.027837107744451257, 0.02263311848

 15%|█▌        | 3/20 [33:11<3:08:03, 663.73s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([2763, 4595, 8770,  ..., 8875,  562,  271], device='cuda:0')
get_cat_feat_tgt time:  0.09343242645263672
Loss: 0.021511585208157057
rotation error:  60.10495370988108
translation error:  0.0199224900767446
--- 22.303500175476074 seconds ---
Epoch: [2/20], Batch: 29, Loss: 0.021511585208157057
loss epoch [0.02654265852315514, 0.04564492406069811, 0.036653326705654896, 0.02693318010646642, 0.03370382818798216, 0.029663301491645696, 0.038137891323431344, 0.02296433344886259, 0.045562064355537446, 0.04029032078102505, 0.029167343392451653, 0.017023859978664534, 0.024965274848978213, 0.014393721395655241, 0.015556347940332196, 0.03502335742532366, 0.023208722085231948, 0.0212899967720644, 0.02771889507081896, 0.016112979741246362, 0.0268079410759

 20%|██        | 4/20 [44:17<2:57:08, 664.28s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([2580, 1495, 5776,  ..., 2884, 8300, 9513], device='cuda:0')
get_cat_feat_tgt time:  0.09667682647705078
Loss: 0.046805853644904034
rotation error:  83.58505841110261
translation error:  0.06853679921168554
--- 22.231043338775635 seconds ---
Epoch: [3/20], Batch: 29, Loss: 0.046805853644904034
loss epoch [0.018228114065411007, 0.030171642659002254, 0.031736207925020556, 0.030159834338511424, 0.010727888925925891, 0.048606673431997593, 0.030683719543462887, 0.0373445420149885, 0.02326113207447663, 0.03416133748860343, 0.030813135441424472, 0.024867373605169354, 0.026650875400917394, 0.03226621337205414, 0.04696063649407213, 0.03705485639180757, 0.026291717183869395, 0.037553580196473126, 0.016004682531298955, 0.06110417711379097, 0.0433645604

 25%|██▌       | 5/20 [55:21<2:46:05, 664.35s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([3107, 4247, 6572,  ..., 1594, 8091, 8085], device='cuda:0')
get_cat_feat_tgt time:  0.10066366195678711
Loss: 0.027161428095823288
rotation error:  83.16497048694895
translation error:  0.03628706317093349
--- 22.19798731803894 seconds ---
Epoch: [4/20], Batch: 29, Loss: 0.027161428095823288
loss epoch [0.022912565374625257, 0.04216810847030138, 0.048538164967243676, 0.04317492793461487, 0.03720310424953843, 0.047509523572013886, 0.022402830213255658, 0.0263974238865943, 0.02342327588466903, 0.017220872984195428, 0.031493337728302276, 0.031823562835224035, 0.028812729826852485, 0.005992591523040964, 0.01778544849657839, 0.03221168961382725, 0.03422216102711061, 0.04273596058000617, 0.03225986781187413, 0.0241495007357344, 0.0155266915503928

 30%|███       | 6/20 [1:06:23<2:34:51, 663.67s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([9382, 2026, 9026,  ..., 3534, 2795, 2611], device='cuda:0')
get_cat_feat_tgt time:  0.10590934753417969
Loss: 0.02551967347995016
rotation error:  40.991686978156
translation error:  0.05114561085361554
--- 22.115290880203247 seconds ---
Epoch: [5/20], Batch: 29, Loss: 0.02551967347995016
loss epoch [0.016551064594102664, 0.029262527099916977, 0.013474282210104115, 0.014755370934066465, 0.0401690488132905, 0.025627749463091577, 0.00660638142830616, 0.01534575612037923, 0.017562893106404336, 0.01221965074878594, 0.03777050808373795, 0.0334409587733846, 0.02402741208645752, 0.018807574537828047, 0.01969650334845582, 0.021721895412690415, 0.010238196017937214, 0.011423572176876399, 0.01354623054057005, 0.03239548649568526, 0.019811697794560056

 35%|███▌      | 7/20 [1:17:22<2:23:26, 662.06s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([5466, 8145, 5411,  ..., 3575, 4264, 7993], device='cuda:0')
get_cat_feat_tgt time:  0.11262345314025879
Loss: 0.02539453261826253
rotation error:  53.85850481491015
translation error:  0.04499097045790326
--- 22.315564393997192 seconds ---
Epoch: [6/20], Batch: 29, Loss: 0.02539453261826253
loss epoch [0.018481762991572337, 0.029403251734749038, 0.02668080494926179, 0.021437457918876385, 0.022173516052468565, 0.023982988633392946, 0.038497207780507414, 0.03497944453703604, 0.02694949706400479, 0.023181425034008045, 0.013872861711411532, 0.032270079783455294, 0.01809445177200423, 0.02072605295623024, 0.0345285516119358, 0.02992951547103484, 0.01668899349072593, 0.03408743586571068, 0.06286512291848423, 0.023349355791534982, 0.017581008663846

 40%|████      | 8/20 [1:28:14<2:11:49, 659.12s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([6031, 9115, 5990,  ..., 2907, 9835, 3931], device='cuda:0')
get_cat_feat_tgt time:  0.11860322952270508
Loss: 0.0176269373883818
rotation error:  241.70292011106477
translation error:  0.015061790230405154
--- 21.120352029800415 seconds ---
Epoch: [7/20], Batch: 29, Loss: 0.0176269373883818
loss epoch [0.015865225391482714, 0.027143922440359184, 0.02125928582017081, 0.030272850258086566, 0.046003485215154655, 0.03956243317897536, 0.0247374770127945, 0.03751825091949685, 0.01832304753244276, 0.040982058446555406, 0.016751763415856193, 0.04384129477908059, 0.03911145906534963, 0.04109561201183298, 0.044533609433311375, 0.039824131223966816, 0.02546987038415545, 0.025171039471312728, 0.016431851071232892, 0.013484867572342393, 0.03111325029086

 45%|████▌     | 9/20 [1:37:45<1:56:00, 632.76s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([7212, 7800, 6122,  ..., 1043, 1389, 5872], device='cuda:0')
get_cat_feat_tgt time:  0.10764002799987793
Loss: 0.034261383181295735
rotation error:  36.36000503627532
translation error:  0.04322953584919714
--- 19.242645025253296 seconds ---
Epoch: [8/20], Batch: 29, Loss: 0.034261383181295735
loss epoch [0.0205562121947264, 0.028040633179231868, 0.02561402312851505, 0.022871808367655042, 0.019631349201063845, 0.02560535564360478, 0.019929372151520443, 0.038552044363005966, 0.029678670870427505, 0.016173287725651173, 0.029779410849855206, 0.01179740128769928, 0.024014049829398327, 0.03333261832386272, 0.027408420674473744, 0.03640002278018191, 0.021309780075541847, 0.03136039339619619, 0.026449682675174616, 0.03689878102758707, 0.05364946922

 50%|█████     | 10/20 [1:45:37<1:37:24, 584.44s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([5443, 6636, 4024,  ..., 9403, 6707, 7153], device='cuda:0')
get_cat_feat_tgt time:  0.12458348274230957
Loss: 0.03177284007248178
rotation error:  49.47665345218597
translation error:  0.047214246303082734
--- 14.65701937675476 seconds ---
Epoch: [9/20], Batch: 29, Loss: 0.03177284007248178
loss epoch [0.014018014418987123, 0.02196403990348769, 0.019052323109062998, 0.03061659483183709, 0.012036960740695438, 0.020024734595312357, 0.028803527255924027, 0.016730146132701344, 0.014012058861977638, 0.014653864575808133, 0.015488801169309369, 0.01208728318673805, 0.040852553549040715, 0.02183442713433456, 0.014658644979268519, 0.007733370639648952, 0.02757575841320197, 0.009108029713777369, 0.01894079484367087, 0.026075055244731825, 0.0069075171

 55%|█████▌    | 11/20 [1:53:04<1:21:30, 543.35s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([7844, 8421, 9362,  ..., 4328, 3579, 4809], device='cuda:0')
get_cat_feat_tgt time:  0.10764002799987793
Loss: 0.026369346509773524
rotation error:  79.99334595764921
translation error:  0.04146919192792065
--- 14.900293111801147 seconds ---
Epoch: [10/20], Batch: 29, Loss: 0.026369346509773524
loss epoch [0.022160183647083526, 0.01652114317336617, 0.016556854456408004, 0.040521994797535696, 0.015435692217847848, 0.010649601574342648, 0.019910781566660152, 0.02693934343758272, 0.030351660848629565, 0.015449916893128229, 0.008294810628882085, 0.029310437536128546, 0.022167319605973068, 0.014596682960399763, 0.017948332980636226, 0.01406818519473311, 0.020398169544509213, 0.015306684818770889, 0.014845751719245235, 0.054131024422913385, 0.0439

 60%|██████    | 12/20 [2:00:31<1:08:35, 514.46s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([4176, 3996,  114,  ..., 4003, 5916, 6691], device='cuda:0')
get_cat_feat_tgt time:  0.1096339225769043
Loss: 0.04414233890730713
rotation error:  30.824001713059573
translation error:  0.05483453399921706
--- 14.610134363174438 seconds ---
Epoch: [11/20], Batch: 29, Loss: 0.04414233890730713
loss epoch [0.02080074713336815, 0.013763613443605717, 0.030343966976600228, 0.03085150051969748, 0.02282365922688398, 0.03360385277047835, 0.02882195718710988, 0.023584918092542505, 0.018542045630586864, 0.01892421956854084, 0.03299584906157594, 0.04293117874126043, 0.026982241478641693, 0.014497911176376208, 0.013611570443427963, 0.030985765371888224, 0.020162985577071225, 0.025173661873884635, 0.020021960942292, 0.039762855200269016, 0.04036313152424

 65%|██████▌   | 13/20 [2:07:54<57:30, 492.87s/it]  

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([2388, 9153, 3159,  ..., 9290, 2947, 1200], device='cuda:0')
get_cat_feat_tgt time:  0.11461687088012695
Loss: 0.03412709245514918
rotation error:  37.56630629126113
translation error:  0.04369041828654622
--- 14.63314151763916 seconds ---
Epoch: [12/20], Batch: 29, Loss: 0.03412709245514918
loss epoch [0.028401839225575036, 0.024554890243393284, 0.024621389058889543, 0.01207216206636875, 0.025961092890708977, 0.024250400603564273, 0.023077016390086025, 0.015211347462216503, 0.013031071165505352, 0.01139449003265524, 0.027713971053734694, 0.014037424639192744, 0.0126200620049802, 0.014153863781710989, 0.01524221595192483, 0.020344077504974113, 0.032757532909816985, 0.0190751126995888, 0.016500250027975904, 0.025330828801349185, 0.02720866973

 70%|███████   | 14/20 [2:15:16<47:46, 477.73s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([6942, 7788, 7254,  ..., 5470, 1228, 8409], device='cuda:0')
get_cat_feat_tgt time:  0.09468364715576172
Loss: 0.03462285920887315
rotation error:  43.95370793249765
translation error:  0.03897932135941173
--- 14.841360569000244 seconds ---
Epoch: [13/20], Batch: 29, Loss: 0.03462285920887315
loss epoch [0.02271743196014208, 0.023488999266998448, 0.010656750117570626, 0.025239853362932798, 0.03681677514074365, 0.020739939519020163, 0.02961041693774765, 0.06145755702557294, 0.06081789255235149, 0.047954442043194614, 0.04069997767552504, 0.01821445883761092, 0.0303912547980033, 0.035832866078983734, 0.02502791705813579, 0.030730660714136375, 0.02330651204306637, 0.02015344349917739, 0.0407960120824286, 0.023640216052568188, 0.03424351893604929

 75%|███████▌  | 15/20 [2:22:42<39:00, 468.09s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([ 743, 3354, 6135,  ..., 1622, 6233, 1904], device='cuda:0')
get_cat_feat_tgt time:  0.08870339393615723
Loss: 0.021974937097614486
rotation error:  291.6878856514946
translation error:  0.0072287817318787044
--- 14.853323459625244 seconds ---
Epoch: [14/20], Batch: 29, Loss: 0.021974937097614486
loss epoch [0.012756174356111658, 0.027697201103762135, 0.015504638628111439, 0.0204955338764454, 0.026323705257968063, 0.017112507555845882, 0.031576442248446156, 0.021675018206234357, 0.02529568472231797, 0.01527745713663654, 0.024111069430150477, 0.014257879276484132, 0.03757855238453418, 0.04368268141996272, 0.039905948490355156, 0.008456220870481477, 0.028648027806301435, 0.012618947048521765, 0.0160450901630923, 0.009199475606413612, 0.0302198

 80%|████████  | 16/20 [2:32:41<33:49, 507.48s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([6218, 7348, 5831,  ..., 9486, 8957, 1220], device='cuda:0')
get_cat_feat_tgt time:  0.11063003540039062
Loss: 0.030147517093487584
rotation error:  287.1325862792337
translation error:  0.03421493695985234
--- 27.017803192138672 seconds ---
Epoch: [15/20], Batch: 29, Loss: 0.030147517093487584
loss epoch [0.01968460212261643, 0.018848490598134048, 0.02505268220589247, 0.016027626288621675, 0.013577288941065486, 0.036176969943510676, 0.017730037853464123, 0.011493282471583734, 0.020102091772846966, 0.01861356239731187, 0.01645703246841647, 0.02449765412303181, 0.01487995402483863, 0.02118335267397426, 0.02827093581418123, 0.025523082577692976, 0.025877603265159803, 0.03235976095955106, 0.011949690433039959, 0.023109249745308418, 0.0203324418

 85%|████████▌ | 17/20 [2:43:26<27:26, 548.71s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([4607, 9576,  433,  ...,  376, 8101, 3251], device='cuda:0')
get_cat_feat_tgt time:  0.09966731071472168
Loss: 0.036542163026606275
rotation error:  90.66038282823325
translation error:  0.053051226406128756
--- 21.2441246509552 seconds ---
Epoch: [16/20], Batch: 29, Loss: 0.036542163026606275
loss epoch [0.022569595144395414, 0.0243295593559487, 0.026756191579346528, 0.03396395087392569, 0.01748957264508736, 0.05349323233653934, 0.017979843796310426, 0.02287498679508692, 0.021018559813076612, 0.034169392855193555, 0.018475901620025163, 0.020998950317749215, 0.021311016741358782, 0.017378168653598196, 0.011308437346863102, 0.039335003986977354, 0.013009015400714632, 0.012045952797953407, 0.010156827122710183, 0.015901802424561542, 0.03208495

 90%|█████████ | 18/20 [2:54:43<19:34, 587.26s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([6824, 8711, 3875,  ...,  765, 4381, 2301], device='cuda:0')
get_cat_feat_tgt time:  0.10764074325561523
Loss: 0.020278547130021495
rotation error:  41.65617942341326
translation error:  0.04128148651304482
--- 28.534846782684326 seconds ---
Epoch: [17/20], Batch: 29, Loss: 0.020278547130021495
loss epoch [0.025159528723424418, 0.012805155882965125, 0.010719919321463356, 0.014084939263788461, 0.016025197809309947, 0.03412153082055755, 0.024482466050894762, 0.033628049485224895, 0.0718707427746601, 0.04448973662585958, 0.026364520090611313, 0.024352798732799726, 0.01588479766818433, 0.021061408505848544, 0.03060926797971915, 0.011852017927347436, 0.029573196839467505, 0.01966941706323416, 0.03003668016941815, 0.01747331858395395, 0.0159729646

 95%|█████████▌| 19/20 [3:03:52<09:35, 575.66s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([7637, 5458, 9248,  ..., 2820, 2078, 7483], device='cuda:0')
get_cat_feat_tgt time:  0.09368777275085449
Loss: 0.014449905997528673
rotation error:  45.06288917797024
translation error:  0.0143462147469076
--- 31.634239196777344 seconds ---
Epoch: [18/20], Batch: 29, Loss: 0.014449905997528673
loss epoch [0.031309880425211524, 0.022867760643018667, 0.03140754044421282, 0.017407241957122004, 0.013695513976425948, 0.021235536337804953, 0.014848090405871267, 0.04972308832928345, 0.01796202799356089, 0.044890965757374314, 0.025098909993914766, 0.02015905136951815, 0.035977607550371354, 0.03286519683518462, 0.0447886776458387, 0.022182621210592803, 0.011008117138330213, 0.0245995621293864, 0.015735943139812802, 0.015643851925691447, 0.01769540373

100%|██████████| 20/20 [3:12:44<00:00, 578.24s/it]

tgt_pts_xyz:  torch.Size([1, 10000, 3])
ref_pts:  torch.Size([1, 10000, 3])
dist_normalize:  torch.Size([1, 32, 13824])
feat_weight_map:  torch.Size([1, 32, 32, 13824])
idx_1_mask:  tensor([[0]])
idx_1_mask_flatten:  tensor([0])
idx_2_mask:  tensor([  11, 6637, 2822,  ..., 2950, 2435, 3567], device='cuda:0')
get_cat_feat_tgt time:  0.12059712409973145
Loss: 0.05071938793496703
rotation error:  279.4913091817251
translation error:  0.04925120031905461
--- 27.15120816230774 seconds ---
Epoch: [19/20], Batch: 29, Loss: 0.05071938793496703
loss epoch [0.02800566254205541, 0.011289687004258371, 0.01516644535842585, 0.029256355733691952, 0.018905167809903936, 0.020396110681558968, 0.022379865750191116, 0.02591199734594861, 0.057757939408246625, 0.03879990450661007, 0.02083738998812816, 0.023240996068001788, 0.026063456278538562, 0.0234581076781964, 0.03744378783538765, 0.026871668837158927, 0.006430796651855191, 0.02963230389367756, 0.023416743831711932, 0.014714021138854198, 0.0312083415616




In [7]:
# save
print("Finished Training")
torch.save(model.state_dict(), model_path)

Finished Training
