In [1]:
import torch
import numpy as np
import pickle
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from model import *
import os
from numpy.linalg import inv
from PIL import Image
import math
import utils

def load_pickle(filename):
    with open(filename, 'rb') as f:
        return pickle.load(f)

In [2]:
# # setup the data
# train_dump = load_pickle('train_dump')
math.pi

3.141592653589793

In [3]:
# for item in train_dump.keys():
#     print(type(train_dump[item]['pcd'][0]))

In [4]:
def symmetric_orthogonalization(x):
    m = x.view(-1, 3, 3)
    u, s, v = torch.svd(m)
    vt = torch.transpose(v, 1, 2)
    det = torch.det(torch.matmul(u, vt))
    det = det.view(-1, 1, 1)
    vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), 1)
    r = torch.matmul(u, vt)
    return r

In [5]:
training_data_dir = "training_data_filtered/training_data/v2.2"
split_dir = "training_data_filtered/training_data/splits/v2"
objects_csv = 'training_data_filtered/training_data/objects_v1.csv'

def get_split_files(split_name):
    with open(os.path.join(split_dir, f"{split_name}.txt"), 'r') as f:
        prefix = [os.path.join(training_data_dir, line.strip()) for line in f if line.strip()]
        rgb = [p + "_color_kinect.png" for p in prefix]
        depth = [p + "_depth_kinect.png" for p in prefix]
        label = [p + "_label_kinect.png" for p in prefix]
        meta = [p + "_meta.pkl" for p in prefix]
    return rgb, depth, label, meta

rgb_files, depth_files, label_files, meta_files = get_split_files('train')
rgb_files_val, depth_files_val, label_files_val, meta_files_val = get_split_files('val')


def read_image(img_path):
    '''
    inputs:
    img_path : the location of the image to be read
    outputs:
    image converted to torch.tensor
    '''
    image = np.array(Image.open(img_path))
#     print(image)
    image = torch.from_numpy(image)
    return image

class mydataset(Dataset):
    # define the init method
    def __init__(self,meta_files,depth_files,label_files,transform=None, target_transform = None) -> None:
        super().__init__()
        self.meta_files = meta_files
        self.depth_files = depth_files
        self.label_files = label_files
        self.pcd = []
        self.poses = []
        self.transform = transform
        self.target_transform = target_transform

    # define the len method
    def __len__(self):
        return len(self.depth_files)

    # define the getitem() method
    def __getitem__(self,idx):
        self.pcd = []
        self.poses = []
        meta = load_pickle(self.meta_files[idx])
        intrinsic = meta['intrinsic']
        ext_inv = inv(meta['extrinsic'])
        depth = np.array(Image.open(self.depth_files[idx])) / 1000
        label = np.array(Image.open(self.label_files[idx]))
        z = depth
        v, u = np.indices(z.shape)
        uv1 = np.stack([u + 0.5, v + 0.5, np.ones_like(z)], axis=-1)
        points_viewer = uv1 @ np.linalg.inv(intrinsic).T * z[..., None]  # [H, W, 3]
        for j,l in enumerate(meta['object_ids']):
            crops_pcd = np.array([points_viewer[label==l]])
            if crops_pcd.shape[1]==0:
                continue
            crops_pcd = crops_pcd@ext_inv[:3,:3].T+ext_inv[:3,3]
            crops_pcd = crops_pcd.squeeze(0)
            if(len(crops_pcd)<800):
                continue
            crops_pcd = crops_pcd[:800,:]
            self.pcd.append(crops_pcd)
            self.poses.append(meta['poses_world'][l])
        return np.array(self.pcd), np.array(self.poses)


In [6]:
training_data = mydataset(meta_files,depth_files,label_files)
val_data = mydataset(meta_files_val,depth_files_val,label_files_val)

In [7]:
train_dataloader = DataLoader(training_data, batch_size=1, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=1,shuffle=True)

# train_pcd, train_poses = next(iter(train_dataloader))
# val_pcd, val_poses = next(iter(val_dataloader))

In [8]:
# defined model
device = "cuda"

net = STN3d()
# net = torch.load('poitnet_3.pth')
# net = PointNetCls()

net.float().to(device)

STN3d(
  (conv1): Conv1d(3, 64, kernel_size=(1,), stride=(1,))
  (conv2): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
  (conv3): Conv1d(128, 1024, kernel_size=(1,), stride=(1,))
  (fc1_r): Linear(in_features=1024, out_features=512, bias=True)
  (fc1_t): Linear(in_features=1024, out_features=512, bias=True)
  (fc2_r): Linear(in_features=512, out_features=256, bias=True)
  (fc2_t): Linear(in_features=512, out_features=256, bias=True)
  (fc3_r): Linear(in_features=256, out_features=128, bias=True)
  (fc3_t): Linear(in_features=256, out_features=128, bias=True)
  (fc4_r): Linear(in_features=128, out_features=9, bias=True)
  (fc4_t): Linear(in_features=128, out_features=3, bias=True)
  (relu): ReLU()
  (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn4_r): BatchNorm1d

In [9]:
def loss_torch(R_est,t_est,R_gt,t_gt,src_pcd):
    with torch.no_grad():
        rre = torch.acos(torch.clamp(0.5 * ((torch.bmm(torch.permute(R_est,(0,2,1)), R_gt)).diagonal(offset=0, dim1=-1, dim2=-2).sum(-1) - 1), -1.0, 1.0)).sum()
        rte = torch.linalg.norm(t_est - t_gt)
    src_pcd = torch.permute(src_pcd,(0,2,1))
    tgt_pcd = torch.bmm(src_pcd,torch.permute(R_gt,(0,2,1)))
    tgt_pcd = torch.add(tgt_pcd, t_gt.unsqueeze(1))
    pred_pcd = torch.bmm(src_pcd,torch.permute(R_est,(0,2,1)))
    pred_pcd = torch.add(pred_pcd, t_est.unsqueeze(1))
    point_loss = torch.linalg.norm(pred_pcd-tgt_pcd)
    final_loss = point_loss#+rre+rte
    return final_loss,rre,rte,point_loss

In [10]:
def compute_rre_symmetry(R_est: np.ndarray, R_gt: np.ndarray,
                         sym_rots: np.ndarray, rot_axis=None):
    assert R_est.shape == (3, 3), 'R_est: expected shape (3, 3), received shape {}.'.format(R_est.shape)
    assert R_gt.shape == (3, 3), 'R_gt: expected shape (3, 3), received shape {}.'.format(R_gt.shape)

    if rot_axis is None:
        R_gt_sym = R_gt @ sym_rots
        rre_sym_all = np.arccos(np.clip(0.5 * (np.trace(R_est.T @ R_gt_sym, axis1=-2, axis2=-1) - 1), -1.0, 1.0))
        rre_best = np.min(rre_sym_all)
    else:
        R_gt_sym = R_gt @ sym_rots
        rot_axis_gt = R_gt_sym @ rot_axis  # [?, 3]
        rot_axis_est = R_est @ rot_axis  # [3]
        rre_sym = np.arccos(np.clip(np.dot(rot_axis_gt, rot_axis_est), -1.0, 1.0))  # [?]
        rre_best = np.min(rre_sym)
    return rre_best

In [11]:
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/exp_normals')

data_val = next(iter(val_dataloader))

In [12]:
from tqdm import tqdm,trange
import matplotlib.pyplot as plt
from gradient_flow import *


batch_size = 1
epochs = 25
print_freq = 20
print_freq_val = 100
epoch_save = 100
batch_loss = []
batch_loss_val = []
optim = torch.optim.Adam(net.parameters(), lr=0.0001)
# define the train routine
for epoch in trange(epochs):
    print_count = 0
    print_loss = 0
    print_rre = 0
    print_rte = 0
    print_ptloss = 0
    epoch_step = 0
    for data in train_dataloader:
        epoch_step += 1
        print_count += 1
        optim.zero_grad()
        pcd = data[0].float().squeeze(0).permute(0,2,1).to(device)
        r_pred,t_pred = net(pcd)
#         r_pred = r_pred.squeeze(2)
#         t_pred = t_pred.squeeze(2)
        r_pred = symmetric_orthogonalization(r_pred)
#         r_pred = torch.permute(r_pred,(0,2,1))
        train_poses = data[1].squeeze(0).to(device)
        R_gt = train_poses[:,:3,:3]
        t_gt = train_poses[:,:3,3]
        
        loss,rre,rte,tr_ptloss = loss_torch(r_pred,t_pred,R_gt,t_gt,pcd)
        print_loss += loss.item()
        print_rre +=rre.item()
        print_rte += rte.item()
        print_ptloss += tr_ptloss.item()
        loss.backward()
        optim.step()
        # print(print_loss)
        if(print_count % print_freq == 0): 
            print(f"[{epoch+1}/{epochs}][{epoch_step}/{len(train_dataloader)}]") 
            print(f"loss_train: {print_loss / print_freq}")
            print(f"rre_train: {print_rre / print_freq}")
            print(f"rte_train: {print_rte / print_freq}")
            print(f"ptloss_train: {print_ptloss / print_freq}")
            writer.add_scalar('loss',print_loss / print_freq,print_count)
            batch_loss.append(print_loss / print_freq)
            print_loss = 0
            print_rre = 0
            print_rte = 0
            print_ptloss = 0
        if(print_count % print_freq_val == 0):
            with torch.no_grad():
                r_pred,t_pred = net(data_val[0].float().squeeze(0).permute(0,2,1).to(device))
                r_pred = symmetric_orthogonalization(r_pred)
                r_pred = torch.permute(r_pred,(0,2,1))
                train_poses = data_val[1].squeeze(0).to(device)
                R_gt = train_poses[:,:3,:3]
                t_gt = train_poses[:,:3,3]
                val_loss,rre_val,rte_val,val_ptloss = loss_torch(r_pred,t_pred,R_gt,t_gt,data_val[0].float().squeeze(0).permute(0,2,1).to(device))
                print(f"[{epoch+1}/{epochs}][{epoch_step}/{len(train_dataloader)}]") 
                print(f"loss_val: {val_loss.item()}")
                print(f"rre_val: {rre_val.item()}")
                print(f"rte_val: {rte_val.item()}")
                writer.add_scalar('val_loss',val_loss,print_count)
                batch_loss_val.append(val_loss.item())

  0%|          | 0/25 [00:00<?, ?it/s]

[1/25][20/3964]
loss_train: 63.125390625
rre_train: 11.501401233673096
rte_train: 1.549858459830284
ptloss_train: 63.125390625
[1/25][40/3964]
loss_train: 60.488168716430664
rre_train: 13.876017332077026
rte_train: 1.560651034116745
ptloss_train: 60.488168716430664
[1/25][60/3964]
loss_train: 58.08301486968994
rre_train: 12.777997469902038
rte_train: 1.4887207806110383
ptloss_train: 58.08301486968994
[1/25][80/3964]
loss_train: 58.523474311828615
rre_train: 13.267211103439331
rte_train: 1.4663214683532715
ptloss_train: 58.523474311828615
[1/25][100/3964]
loss_train: 53.80232830047608
rre_train: 12.677518224716186
rte_train: 1.3328192681074142
ptloss_train: 53.80232830047608
[1/25][100/3964]
loss_val: 44.27958679199219
rre_val: 8.00511360168457
rte_val: 1.1612155437469482
[1/25][120/3964]
loss_train: 52.98549537658691
rre_train: 11.308639311790467
rte_train: 1.362018659710884
ptloss_train: 52.98549537658691
[1/25][140/3964]
loss_train: 52.512145614624025
rre_train: 11.431504034996033
rt

  crops_pcd /= np.max(np.linalg.norm(crops_pcd, axis=1))


[1/25][360/3964]
loss_train: 48.95852079391479
rre_train: 11.969983863830567
rte_train: 1.156521674990654
ptloss_train: 48.95852079391479
[1/25][380/3964]
loss_train: 49.316233444213864
rre_train: 12.143868112564087
rte_train: 1.0614352643489837
ptloss_train: 49.316233444213864
[1/25][400/3964]
loss_train: 51.87393684387207
rre_train: 12.664512181282044
rte_train: 1.1514399349689484
ptloss_train: 51.87393684387207
[1/25][400/3964]
loss_val: 50.6492805480957
rre_val: 7.84267520904541
rte_val: 0.8639469742774963
[1/25][420/3964]
loss_train: 48.41273746490479
rre_train: 11.615524506568908
rte_train: 1.0333091974258424
ptloss_train: 48.41273746490479
[1/25][440/3964]
loss_train: 51.71680870056152
rre_train: 12.757634043693542
rte_train: 1.241715431213379
ptloss_train: 51.71680870056152
[1/25][460/3964]
loss_train: 50.296982002258304
rre_train: 10.774740076065063
rte_train: 1.0687736988067627
ptloss_train: 50.296982002258304
[1/25][480/3964]
loss_train: 48.69978065490723
rre_train: 10.60382

[1/25][1400/3964]
loss_train: 49.74625358581543
rre_train: 11.285971784591675
rte_train: 1.018052165210247
ptloss_train: 49.74625358581543
[1/25][1400/3964]
loss_val: 27.228330612182617
rre_val: 4.6400675773620605
rte_val: 0.6860799789428711
[1/25][1420/3964]
loss_train: 45.74273939132691
rre_train: 8.837004113197327
rte_train: 0.9477012664079666
ptloss_train: 45.74273939132691
[1/25][1440/3964]
loss_train: 50.69987564086914
rre_train: 11.827348279953004
rte_train: 1.0599725276231766
ptloss_train: 50.69987564086914
[1/25][1460/3964]
loss_train: 46.4590576171875
rre_train: 10.271747958660125
rte_train: 0.9948330879211426
ptloss_train: 46.4590576171875
[1/25][1480/3964]
loss_train: 48.435959815979004
rre_train: 10.835756587982178
rte_train: 1.0485084682703019
ptloss_train: 48.435959815979004
[1/25][1500/3964]
loss_train: 45.581419944763184
rre_train: 9.783911156654359
rte_train: 1.000473329424858
ptloss_train: 45.581419944763184
[1/25][1500/3964]
loss_val: 29.0168399810791
rre_val: 4.150

[1/25][2420/3964]
loss_train: 46.5183780670166
rre_train: 9.592026329040527
rte_train: 1.0605160176753998
ptloss_train: 46.5183780670166
[1/25][2440/3964]
loss_train: 46.265254974365234
rre_train: 10.966628909111023
rte_train: 1.067648708820343
ptloss_train: 46.265254974365234
[1/25][2460/3964]
loss_train: 48.851992988586424
rre_train: 10.897075700759888
rte_train: 1.0371643483638764
ptloss_train: 48.851992988586424
[1/25][2480/3964]
loss_train: 48.16656265258789
rre_train: 10.957598853111268
rte_train: 1.042963621020317
ptloss_train: 48.16656265258789
[1/25][2500/3964]
loss_train: 45.030336952209474
rre_train: 9.895600175857544
rte_train: 1.0786398857831956
ptloss_train: 45.030336952209474
[1/25][2500/3964]
loss_val: 28.30492401123047
rre_val: 4.410979747772217
rte_val: 0.756943941116333
[1/25][2520/3964]
loss_train: 42.389288902282715
rre_train: 8.37906836271286
rte_train: 0.8973910927772522
ptloss_train: 42.389288902282715
[1/25][2540/3964]
loss_train: 45.702693462371826
rre_train: 

[1/25][3460/3964]
loss_train: 46.67507133483887
rre_train: 10.162968826293945
rte_train: 1.0869253367185592
ptloss_train: 46.67507133483887
[1/25][3480/3964]
loss_train: 48.92015342712402
rre_train: 12.31663362979889
rte_train: 1.1076056778430938
ptloss_train: 48.92015342712402
[1/25][3500/3964]
loss_train: 47.88215742111206
rre_train: 11.50086851119995
rte_train: 1.0617171794176101
ptloss_train: 47.88215742111206
[1/25][3500/3964]
loss_val: 39.31097412109375
rre_val: 7.76375150680542
rte_val: 0.7169038653373718
[1/25][3520/3964]
loss_train: 47.74767694473267
rre_train: 11.775865387916564
rte_train: 1.061858233809471
ptloss_train: 47.74767694473267
[1/25][3540/3964]
loss_train: 51.8808479309082
rre_train: 11.604334425926208
rte_train: 1.0609694510698318
ptloss_train: 51.8808479309082
[1/25][3560/3964]
loss_train: 45.150223064422605
rre_train: 10.032131552696228
rte_train: 1.1009619265794754
ptloss_train: 45.150223064422605
[1/25][3580/3964]
loss_train: 45.95290641784668
rre_train: 10.1

  4%|▍         | 1/25 [07:42<3:04:55, 462.30s/it]

[2/25][20/3964]
loss_train: 46.14980020523071
rre_train: 10.93286907672882
rte_train: 1.0898528426885605
ptloss_train: 46.14980020523071
[2/25][40/3964]
loss_train: 42.03432769775391
rre_train: 7.822733974456787
rte_train: 0.9543469339609146
ptloss_train: 42.03432769775391
[2/25][60/3964]
loss_train: 44.62418842315674
rre_train: 10.894415903091431
rte_train: 1.01043860912323
ptloss_train: 44.62418842315674
[2/25][80/3964]
loss_train: 45.668055629730226
rre_train: 10.150637555122376
rte_train: 1.0042741477489472
ptloss_train: 45.668055629730226
[2/25][100/3964]
loss_train: 44.95657339096069
rre_train: 9.416127526760102
rte_train: 0.9901788771152497
ptloss_train: 44.95657339096069
[2/25][100/3964]
loss_val: 35.578041076660156
rre_val: 5.01793098449707
rte_val: 0.7782472372055054
[2/25][120/3964]
loss_train: 45.812108516693115
rre_train: 8.960058557987214
rte_train: 1.0153710067272186
ptloss_train: 45.812108516693115
[2/25][140/3964]
loss_train: 44.15742092132568
rre_train: 9.611085546016

[2/25][1060/3964]
loss_train: 47.38900241851807
rre_train: 10.720283961296081
rte_train: 1.0741570889949799
ptloss_train: 47.38900241851807
[2/25][1080/3964]
loss_train: 43.809246349334714
rre_train: 8.80141624212265
rte_train: 0.9794567793607711
ptloss_train: 43.809246349334714
[2/25][1100/3964]
loss_train: 46.04090328216553
rre_train: 10.418111062049865
rte_train: 1.0177903413772582
ptloss_train: 46.04090328216553
[2/25][1100/3964]
loss_val: 33.328182220458984
rre_val: 4.512075424194336
rte_val: 0.728563129901886
[2/25][1120/3964]
loss_train: 45.8508903503418
rre_train: 10.37390034198761
rte_train: 0.9521826148033142
ptloss_train: 45.8508903503418
[2/25][1140/3964]
loss_train: 43.88059635162354
rre_train: 9.400303554534911
rte_train: 0.9372862502932549
ptloss_train: 43.88059635162354
[2/25][1160/3964]
loss_train: 44.0775749206543
rre_train: 8.736354422569274
rte_train: 0.9041238486766815
ptloss_train: 44.0775749206543
[2/25][1180/3964]
loss_train: 47.27066707611084
rre_train: 10.3262

[2/25][2100/3964]
loss_train: 46.35370712280273
rre_train: 9.989339518547059
rte_train: 1.0267360866069795
ptloss_train: 46.35370712280273
[2/25][2100/3964]
loss_val: 38.600215911865234
rre_val: 6.088656902313232
rte_val: 0.7920262217521667
[2/25][2120/3964]
loss_train: 47.78571701049805
rre_train: 10.591129398345947
rte_train: 1.1048101782798767
ptloss_train: 47.78571701049805
[2/25][2140/3964]
loss_train: 43.65990085601807
rre_train: 9.21570817232132
rte_train: 1.0349441677331925
ptloss_train: 43.65990085601807
[2/25][2160/3964]
loss_train: 46.963347816467284
rre_train: 10.05650336742401
rte_train: 0.9924909144639968
ptloss_train: 46.963347816467284
[2/25][2180/3964]
loss_train: 48.11359577178955
rre_train: 10.926735353469848
rte_train: 1.0109656035900116
ptloss_train: 48.11359577178955
[2/25][2200/3964]
loss_train: 48.97670936584473
rre_train: 11.01619267463684
rte_train: 1.1163792610168457
ptloss_train: 48.97670936584473
[2/25][2200/3964]
loss_val: 33.442501068115234
rre_val: 4.804

[2/25][3120/3964]
loss_train: 48.73891410827637
rre_train: 11.609408736228943
rte_train: 1.07051320374012
ptloss_train: 48.73891410827637
[2/25][3140/3964]
loss_train: 45.317347049713135
rre_train: 9.671516239643097
rte_train: 0.9419434100389481
ptloss_train: 45.317347049713135
[2/25][3160/3964]
loss_train: 43.63851757049561
rre_train: 9.495643591880798
rte_train: 1.0344009667634964
ptloss_train: 43.63851757049561
[2/25][3180/3964]
loss_train: 46.14883451461792
rre_train: 9.212859582901
rte_train: 0.9648331731557847
ptloss_train: 46.14883451461792
[2/25][3200/3964]
loss_train: 45.45982151031494
rre_train: 10.478418469429016
rte_train: 0.9482278570532798
ptloss_train: 45.45982151031494
[2/25][3200/3964]
loss_val: 32.681819915771484
rre_val: 4.227687835693359
rte_val: 0.7597943544387817
[2/25][3220/3964]
loss_train: 47.64389019012451
rre_train: 11.03567556142807
rte_train: 1.0485693246126175
ptloss_train: 47.64389019012451
[2/25][3240/3964]
loss_train: 44.64687337875366
rre_train: 9.1232

  8%|▊         | 2/25 [15:23<2:56:53, 461.44s/it]

[3/25][20/3964]
loss_train: 44.83081159591675
rre_train: 9.74857827425003
rte_train: 1.005927786231041
ptloss_train: 44.83081159591675
[3/25][40/3964]
loss_train: 43.00432348251343
rre_train: 8.863599300384521
rte_train: 0.9917214512825012
ptloss_train: 43.00432348251343
[3/25][60/3964]
loss_train: 43.29317646026611
rre_train: 10.217138874530793
rte_train: 0.9455001056194305
ptloss_train: 43.29317646026611
[3/25][80/3964]
loss_train: 45.515505409240724
rre_train: 9.847513127326966
rte_train: 0.9656359612941742
ptloss_train: 45.515505409240724
[3/25][100/3964]
loss_train: 51.911603927612305
rre_train: 12.652207636833191
rte_train: 1.082817757129669
ptloss_train: 51.911603927612305
[3/25][100/3964]
loss_val: 32.720218658447266
rre_val: 5.70694637298584
rte_val: 0.7124598026275635
[3/25][120/3964]
loss_train: 50.02585945129395
rre_train: 11.488562417030334
rte_train: 1.0793961048126222
ptloss_train: 50.02585945129395
[3/25][140/3964]
loss_train: 45.34312314987183
rre_train: 10.24441583156

[3/25][1060/3964]
loss_train: 47.53367757797241
rre_train: 10.462063550949097
rte_train: 1.0109390288591384
ptloss_train: 47.53367757797241
[3/25][1080/3964]
loss_train: 45.69572086334229
rre_train: 11.764898943901063
rte_train: 1.0571907430887222
ptloss_train: 45.69572086334229
[3/25][1100/3964]
loss_train: 48.12954339981079
rre_train: 10.916172540187835
rte_train: 0.9578298270702362
ptloss_train: 48.12954339981079
[3/25][1100/3964]
loss_val: 51.46546173095703
rre_val: 9.386091232299805
rte_val: 0.7576428651809692
[3/25][1120/3964]
loss_train: 45.11178417205811
rre_train: 9.721683371067048
rte_train: 0.9551817625761032
ptloss_train: 45.11178417205811
[3/25][1140/3964]
loss_train: 40.44072027206421
rre_train: 8.408514046669007
rte_train: 0.8969777941703796
ptloss_train: 40.44072027206421
[3/25][1160/3964]
loss_train: 44.642963123321536
rre_train: 10.069967198371888
rte_train: 0.9952666133642196
ptloss_train: 44.642963123321536
[3/25][1180/3964]
loss_train: 46.6043004989624
rre_train: 1

[3/25][2100/3964]
loss_train: 45.09231681823731
rre_train: 9.718958950042724
rte_train: 0.8288314372301102
ptloss_train: 45.09231681823731
[3/25][2100/3964]
loss_val: 39.6145133972168
rre_val: 5.725735187530518
rte_val: 0.6847114562988281
[3/25][2120/3964]
loss_train: 49.40126190185547
rre_train: 11.902115368843079
rte_train: 0.9332563042640686
ptloss_train: 49.40126190185547
[3/25][2140/3964]
loss_train: 42.200437259674075
rre_train: 7.875551652908325
rte_train: 0.9364455133676529
ptloss_train: 42.200437259674075
[3/25][2160/3964]
loss_train: 48.452499866485596
rre_train: 11.383396124839782
rte_train: 1.0307782590389252
ptloss_train: 48.452499866485596
[3/25][2180/3964]
loss_train: 45.53764381408691
rre_train: 10.690991115570068
rte_train: 0.9425623506307602
ptloss_train: 45.53764381408691
[3/25][2200/3964]
loss_train: 49.30118465423584
rre_train: 10.707093811035156
rte_train: 0.984084066748619
ptloss_train: 49.30118465423584
[3/25][2200/3964]
loss_val: 33.16266632080078
rre_val: 4.69

[3/25][3120/3964]
loss_train: 42.648281288146975
rre_train: 9.019437885284423
rte_train: 0.9032479494810104
ptloss_train: 42.648281288146975
[3/25][3140/3964]
loss_train: 49.02192163467407
rre_train: 11.175740838050842
rte_train: 0.9198022902011871
ptloss_train: 49.02192163467407
[3/25][3160/3964]
loss_train: 44.38603038787842
rre_train: 9.408256363868713
rte_train: 0.9626613855361938
ptloss_train: 44.38603038787842
[3/25][3180/3964]
loss_train: 44.94854335784912
rre_train: 9.755979776382446
rte_train: 0.9361491844058036
ptloss_train: 44.94854335784912
[3/25][3200/3964]
loss_train: 43.89461278915405
rre_train: 9.985981225967407
rte_train: 0.9364152699708939
ptloss_train: 43.89461278915405
[3/25][3200/3964]
loss_val: 41.63776397705078
rre_val: 6.9435553550720215
rte_val: 0.7019118070602417


  8%|▊         | 2/25 [21:40<4:09:11, 650.08s/it]


KeyboardInterrupt: 

In [None]:
torch.save(net,'poitnet_4.pth')
# t_pred

In [None]:
def get_sample(idx):
    pcd = []
    poses = []
    meta = load_pickle(meta_files[idx])
    intrinsic = meta['intrinsic']
    ext_inv = inv(meta['extrinsic'])
    depth = np.array(Image.open(depth_files[idx])) / 1000
    label = np.array(Image.open(label_files[idx]))
    z = depth
    v, u = np.indices(z.shape)
    uv1 = np.stack([u + 0.5, v + 0.5, np.ones_like(z)], axis=-1)
    points_viewer = uv1 @ np.linalg.inv(intrinsic).T * z[..., None]  # [H, W, 3]
    for j,l in enumerate(meta['object_ids']):
        crops_pcd = np.array([points_viewer[label==l]])
        crops_pcd = crops_pcd@ext_inv[:3,:3].T+ext_inv[:3,3]
        crops_pcd = crops_pcd.squeeze(0)
        if(len(crops_pcd)<1000):
            continue
        crops_pcd = crops_pcd[:1000,:]
        pcd.append(crops_pcd)
        poses.append(meta['poses_world'][l])
    return np.array(pcd), np.array(poses)

In [None]:
training_data_dir = "training_data_filtered/training_data/v2.2"
split_dir = "training_data_filtered/training_data/splits/v2"
objects_csv = 'training_data_filtered/training_data/objects_v1.csv'

def get_split_files(split_name):
    with open(os.path.join(split_dir, f"{split_name}.txt"), 'r') as f:
        prefix = [os.path.join(training_data_dir, line.strip()) for line in f if line.strip()]
        rgb = [p + "_color_kinect.png" for p in prefix]
        depth = [p + "_depth_kinect.png" for p in prefix]
        label = [p + "_label_kinect.png" for p in prefix]
        meta = [p + "_meta.pkl" for p in prefix]
    return rgb, depth, label, meta

rgb_files, depth_files, label_files, meta_files = get_split_files('train')
rgb_files_val, depth_files_val, label_files_val, meta_files_val = get_split_files('val')

def get_sample(idx):
    pcd = []
    poses = []
    meta = load_pickle(meta_files[idx])
    intrinsic = meta['intrinsic']
    ext_inv = inv(meta['extrinsic'])
    depth = np.array(Image.open(depth_files[idx])) / 1000
    label = np.array(Image.open(label_files[idx]))
    z = depth
    v, u = np.indices(z.shape)
    uv1 = np.stack([u + 0.5, v + 0.5, np.ones_like(z)], axis=-1)
    points_viewer = uv1 @ np.linalg.inv(intrinsic).T * z[..., None]  # [H, W, 3]
    for j,l in enumerate(meta['object_ids']):
        crops_pcd = np.array([points_viewer[label==l]])
        crops_pcd = crops_pcd@ext_inv[:3,:3].T+ext_inv[:3,3]
        crops_pcd = crops_pcd.squeeze(0)
        if(len(crops_pcd)<1000):
            continue
        crops_pcd = crops_pcd[:1000,:]
        pcd.append(crops_pcd)
        poses.append(meta['poses_world'][l])
    return np.array(pcd), np.array(poses)


s =1
# scene = scenes[s]
boxed_image = np.array(Image.open(rgb_files[s])) / 255 
meta = load_pickle(meta_files[s])
pcd,gt_pose = get_sample(s)
pcd = torch.from_numpy(pcd)
print(gt_pose)
# poses_world = np.array([test_dump[scene]['poses_world'][idx] for idx in meta['object_ids']])
with torch.no_grad():
    poses_world = net(pcd.float().permute(0,2,1).to(device))
    print(poses_world[0].size())
    poses_world_r = symmetric_orthogonalization(poses_world[0]).detach().cpu().numpy()
    poses_world_t = poses_world[1].detach().cpu().numpy()
    print(poses_world_r)
box_sizes = np.array([meta['extents'][idx] * meta['scales'][idx] for idx in meta['object_ids']])
for i in range(len(poses_world_r)):
    print(i)
    utils.draw_projected_box3d(
        boxed_image, poses_world_t[i], box_sizes[i], poses_world_r[i], meta['extrinsic'], meta['intrinsic'],
        thickness=2)

Image.fromarray((boxed_image * 255).astype(np.uint8))