In [6]:
import torch
from learning.densefusion import DenseFuseNet
from learning.loss import quat_to_rot
from learning.utils import OBJ_NAMES, OBJ_NAMES_TO_IDX, IDX_TO_OBJ_NAMES

from pathlib import Path
from learning.load import PoseDataset
import pickle
from tqdm import tqdm
import numpy as np

def load_model(model, optimizer, load_path, device=torch.device('cpu')):
    """Restore a model and its optimizer from a checkpoint file.

    Args:
        model: object exposing ``load_state_dict``; receives ``checkpoint['model']``.
        optimizer: object exposing ``load_state_dict``; receives
            ``checkpoint['optimizer']``.
        load_path: location of the checkpoint, passed to ``torch.load``.
        device: forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The same ``(model, optimizer)`` pair, updated in place.
    """
    saved_state = torch.load(load_path, map_location=device)
    # Model first, then optimizer -- same order as the checkpoint keys are consumed.
    for target, key in ((model, 'model'), (optimizer, 'optimizer')):
        target.load_state_dict(saved_state[key])
    return model, optimizer

def free(x: torch.Tensor):
    """Convert a tensor to a NumPy array.

    The tensor is detached, moved to the CPU and then converted with
    ``.numpy()``, in that order.
    """
    host_tensor = x.detach().cpu()
    return host_tensor.numpy()

def to_T(R, t):
    """Assemble a 4x4 homogeneous transform from a rotation and a translation.

    Args:
        R: value assignable to the upper-left 3x3 block.
        t: value assignable to the first three rows of the last column.

    Returns:
        A new 4x4 float array whose bottom row is ``[0, 0, 0, 1]``.
    """
    pose = np.identity(4)
    pose[:3, 3] = t
    pose[:3, :3] = R
    return pose

def _read_scene_names(fp, levels=(1, 2)):
    """Return the stripped lines of ``fp`` whose first character, as an int, is in ``levels``."""
    with open(fp, 'r') as fh:
        stripped = [line.strip() for line in fh]
    # Blank lines (e.g. a trailing newline at EOF) are skipped rather than
    # crashing on int(); any other non-numeric prefix still raises ValueError.
    return [line for line in stripped if line and int(line[0]) in levels]


def _predict_pose(dfnet, cloud, rgb, choose, obj_idxs, device):
    """Run one forward pass and return the highest-confidence pose as a 4x4 array."""
    # Move to the target device, add a leading batch axis, and cast to float
    # (obj_idxs keeps its original dtype).
    cloud = cloud.to(device).unsqueeze(0).float()
    rgb = rgb.to(device).unsqueeze(0).float()
    choose = choose.to(device).unsqueeze(0).float()
    obj_idxs = obj_idxs.to(device).unsqueeze(0)

    # Rearrange the inputs into the layout passed to the network call below.
    cloud_new = cloud.transpose(2, 1)
    rgb_new = torch.moveaxis(rgb, -1, 1)
    choose_new = choose.view(choose.size(0), -1)

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        R_quat_pred, t_pred, c_pred = dfnet(cloud_new, rgb_new, choose_new, obj_idxs)
        R_pred = quat_to_rot(R_quat_pred)
        # Index of the most confident candidate for the single batch element.
        best = torch.argmax(c_pred[0])
        R_best = R_pred[0][best]
        t_best = t_pred[0][best]

    return to_T(free(R_best), free(t_best))


def pred_over_raw_data(output_json_name='dfnet_pred.json', processed_data_dir='processed_data',
                       checkpoint_path='checkpoints/sub1.pt', raw_data_dir='raw_data',
                       num_pose_slots=79):
    """Predict a pose for every object in every test scene and dump them to JSON.

    Scene names are read from ``<raw_data_dir>/testing_data/test.txt``; for each
    scene the object ids are taken from ``v2.2/<scene>_meta.pkl`` and one sample
    per object is pulled from the processed test dataset, in order.

    Args:
        output_json_name: path of the JSON file to write.
        processed_data_dir: directory containing the processed ``test`` split.
        checkpoint_path: checkpoint handed to ``load_model``.
        raw_data_dir: root directory of the raw dataset.
        num_pose_slots: length of each scene's ``poses_world`` list, which is
            indexed by object id. The default of 79 is kept from the original
            code -- presumably the number of object ids in the dataset; TODO confirm.

    Returns:
        None. The predictions are written to ``output_json_name`` as
        ``{scene_name: {"poses_world": [4x4 nested list or null, ...]}}``.
    """
    import json

    # Use the GPU when present; otherwise fall back to the CPU instead of failing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    num_objects = len(OBJ_NAMES)
    dfnet = torch.nn.DataParallel(DenseFuseNet(num_objects)).to(device)
    # The optimizer exists only because load_model restores both state dicts.
    optimizer = torch.optim.Adam(dfnet.parameters(), lr=1e-4)
    dfnet, optimizer = load_model(dfnet, optimizer, checkpoint_path, device=device)
    dfnet = dfnet.eval()

    raw_test_dir = Path(raw_data_dir) / 'testing_data'
    raw_test_obj_dir = raw_test_dir / 'v2.2'
    processed_test_dir = Path(processed_data_dir) / 'test'

    test_scene_names = _read_scene_names(raw_test_dir / 'test.txt')
    test_ds = PoseDataset(data_dir=processed_test_dir, train=False, cloud=True, rgb=True, choose=True)

    # Running index into test_ds. NOTE(review): this relies on the dataset
    # being ordered scene by scene, object by object, exactly like the loops
    # below -- confirm against PoseDataset.
    data_point_num = 0
    all_data = dict()
    pbar = tqdm(test_scene_names)
    for scene_name in pbar:
        meta_path = raw_test_obj_dir / f'{scene_name}_meta.pkl'
        # pickle can execute arbitrary code on load; only use trusted metadata files.
        with open(meta_path, 'rb') as fh:
            meta = pickle.load(fh)

        scene_data = dict(poses_world=[None] * num_pose_slots)

        for obj_id, _obj_name in zip(meta['object_ids'], meta['object_names']):
            pbar.set_description(f'dp_num={data_point_num}')

            cloud, rgb, choose, obj_idxs = test_ds[data_point_num]
            if len(cloud) < 1:
                # Empty cloud for this sample: emit the identity transform as a placeholder.
                T = np.eye(4)
            else:
                T = _predict_pose(dfnet, cloud, rgb, choose, obj_idxs, device)

            scene_data['poses_world'][obj_id] = T.tolist()
            data_point_num += 1

        all_data[scene_name] = scene_data

    with open(output_json_name, 'w') as fp:
        json.dump(all_data, fp)

In [7]:
# Run inference over the test scenes and write the predictions to dfnet_pred.json.
# Both arguments are spelled out explicitly but match the function's defaults.
pred_over_raw_data(output_json_name='dfnet_pred.json', processed_data_dir='processed_data')

  return self._call_impl(*args, **kwargs)
dp_num=1499: 100%|██████████| 200/200 [00:27<00:00,  7.15it/s]
