# DenseFusion Submission

In [1]:
import torch
from learning.densefusion import DenseFuseNet
from learning.loss_batch import quat_to_rot
from learning.utils import OBJ_NAMES, OBJ_NAMES_TO_IDX, IDX_TO_OBJ_NAMES
from icp import run_icp, rigid_transform, R_and_T, to_T

from pathlib import Path
from learning.load import PoseDataset
import pickle
from tqdm import tqdm
import numpy as np

def load_model(model: torch.nn.Module, optimizer: torch.optim.Optimizer, load_path, device=torch.device('cpu')):
    """Restore ``model`` and ``optimizer`` in place from a checkpoint file.

    Args:
        model: Module whose state is overwritten with ``checkpoint['model']``.
        optimizer: Optimizer whose state is overwritten with ``checkpoint['optimizer']``.
        load_path: Path of the file written by ``torch.save``.
        device: Passed as ``map_location`` so tensors are loaded onto this device.

    Returns:
        The same ``(model, optimizer)`` objects, for convenient reassignment.
    """
    # NOTE: torch.load may unpickle arbitrary objects; only load trusted checkpoints.
    saved = torch.load(load_path, map_location=device)
    for target, key in ((model, 'model'), (optimizer, 'optimizer')):
        target.load_state_dict(saved[key])
    return model, optimizer

def free(x: torch.Tensor) -> np.ndarray:
    """Detach ``x`` from autograd, copy it to host memory, and return it as a NumPy array."""
    host_tensor = x.detach().cpu()
    return host_tensor.numpy()

def _predict_pose(dfnet, sample, device, refine, icp_thresh):
    """Return the 4x4 pose for one ``PoseDataset`` sample (identity if its cloud is empty)."""
    cloud, rgb, model, choose, obj_idxs = sample
    if len(cloud) < 1:
        # No observed points to regress from: fall back to the identity pose.
        return np.eye(4)

    # Move every input to the device and add a leading batch dimension of size 1.
    cloud = cloud.to(device).unsqueeze(0).float()
    rgb = rgb.to(device).unsqueeze(0).float()
    model = model.to(device).unsqueeze(0).float()
    choose = choose.to(device).unsqueeze(0).float()
    obj_idxs = obj_idxs.to(device).unsqueeze(0)

    cloud_new = cloud.transpose(2, 1)
    rgb_new = torch.moveaxis(rgb, -1, 1)  # move the last (presumably channel) axis to dim 1
    choose_new = choose.view(choose.size(0), -1)
    R_quat_pred, t_pred, c_pred = dfnet(cloud_new, rgb_new, choose_new, obj_idxs)

    R_pred = quat_to_rot(R_quat_pred)

    # Keep only the hypothesis with the highest confidence score.
    best = torch.argmax(c_pred[0])
    R_pred, t_pred = free(R_pred[0][best]), free(t_pred[0][best])

    if not refine:
        return to_T(R_pred, t_pred)

    # ICP refinement: align the observed cloud to the model posed by the network prediction.
    cloud = free(cloud.squeeze(0))
    model = free(model.squeeze(0))
    pred_model = model @ R_pred.T + t_pred
    R_pred_cloud, t_pred_cloud = R_and_T(run_icp(cloud, pred_model, max_attempts=1, max_iters=1000, finish_loop_thresh=icp_thresh, acceptable_thresh=icp_thresh))
    return rigid_transform(model, (pred_model - t_pred_cloud) @ R_pred_cloud)


def pred_over_raw_data(
        output_json_name='dfnet_pred.json',
        processed_data_dir='processed_data_unet',
        raw_data_dir='raw_data_all',
        checkpoint='select-checkpoints/dfnet.pt',
        refine=False,
        icp_thresh=1e-30,
        device='cuda',
    ):
    """Predict a pose for every object in the test split with DenseFusion and write them to JSON.

    Scene names come from ``<raw_data_dir>/testing_data/test.txt``. Only lines whose first
    character is level 1, 2 or 3 are kept. For each scene, ``<scene>_meta.pkl`` in
    ``testing_data/v2.2`` lists the object ids. Each object consumes the next sample of the
    ``PoseDataset`` at ``<processed_data_dir>/test``. NOTE(review): this assumes the dataset
    order matches this scene/object iteration order; nothing here checks it.

    Args:
        output_json_name: Path of the JSON file to write.
        processed_data_dir: Root of the preprocessed data; its ``test`` subdirectory is loaded.
        raw_data_dir: Root of the raw data containing ``testing_data``.
        checkpoint: Checkpoint file holding ``'model'`` and ``'optimizer'`` state dicts.
        refine: If True, refine each network prediction with ICP.
        icp_thresh: Threshold passed to ``run_icp`` as both finish and acceptance threshold.
        device: Torch device (string or ``torch.device``) used for inference.

    The JSON written has the form
    ``{scene_name: {"poses_world": [entry_0, ..., entry_78]}}``. Entries for objects present
    in the scene are 4x4 nested lists; all other entries are ``None``.
    """
    import json

    def get_stripped_lines(fp, levels=(1, 2, 3)):
        # Strip first and skip blank lines so a trailing newline does not crash int().
        with open(fp, 'r') as f:
            lines = [line.strip() for line in f]
        return [line for line in lines if line and int(line[0]) in levels]

    device = torch.device(device)

    num_objects = len(OBJ_NAMES)
    dfnet = torch.nn.DataParallel(DenseFuseNet(num_objects)).to(device)
    # The optimizer only exists because load_model restores its state as well.
    optimizer = torch.optim.Adam(dfnet.parameters(), lr=1e-4)

    dfnet, optimizer = load_model(dfnet, optimizer, checkpoint, device=device)
    dfnet = dfnet.eval()

    raw_data_dir = Path(raw_data_dir)
    raw_test_dir = raw_data_dir / 'testing_data'
    raw_test_obj_dir = raw_test_dir / 'v2.2'
    processed_data_dir = Path(processed_data_dir)
    processed_test_dir = processed_data_dir / 'test'

    test_scene_names = get_stripped_lines(raw_test_dir / 'test.txt')
    test_ds = PoseDataset(data_dir=processed_test_dir, train=False, cloud=True, rgb=True, model=True, choose=True)

    data_point_num = 0
    all_data = dict()
    pbar = tqdm(test_scene_names)
    # Pure inference: no autograd graph is needed, which saves memory and time.
    with torch.no_grad():
        for scene_name in pbar:
            meta_path = raw_test_obj_dir / f'{scene_name}_meta.pkl'
            # NOTE: pickle can execute arbitrary code; only load trusted meta files.
            with open(meta_path, 'rb') as f:
                meta = pickle.load(f)

            # NOTE(review): 79 is presumably the total number of object classes; confirm.
            scene_data = dict(poses_world=[None] * 79)

            # zip() kept so iteration stops at the shorter of ids/names, as before.
            for obj_id, _obj_name in zip(meta['object_ids'], meta['object_names']):
                pbar.set_description(f'dp_num={data_point_num}')

                T = _predict_pose(dfnet, test_ds[data_point_num], device, refine, icp_thresh)
                scene_data['poses_world'][obj_id] = T.tolist()

                data_point_num += 1

            all_data[scene_name] = scene_data

    with open(output_json_name, 'w') as fp:
        json.dump(all_data, fp)

## HW2: Levels 1-2

DenseFusion (no ICP refinement)

In [5]:
pred_over_raw_data(
    output_json_name='dfnet_pred.json',
    processed_data_dir='processed_data',
    raw_data_dir='raw_data',
    checkpoint='select-checkpoints/dfnet.pt',
    refine=False,
)

dp_num=1499: 100%|██████████| 200/200 [00:43<00:00,  4.64it/s]


DenseFusion + ICP Refinement

In [3]:
pred_over_raw_data(
    output_json_name='dfnet_pred_refine.json',
    processed_data_dir='processed_data',
    raw_data_dir='raw_data',
    checkpoint='select-checkpoints/dfnet.pt',
    refine=True,
)

dp_num=1499: 100%|██████████| 200/200 [00:54<00:00,  3.64it/s]


## HW3: Levels 1-3

DenseFusion (no ICP refinement)

In [6]:
pred_over_raw_data(
    output_json_name='dfnet_pred_unet.json',
    processed_data_dir='processed_data_unet',
    raw_data_dir='raw_data_all',
    checkpoint='select-checkpoints/dfnet.pt',
    refine=False,
)

dp_num=4999: 100%|██████████| 600/600 [02:17<00:00,  4.35it/s]


DenseFusion + ICP Refinement

In [4]:
pred_over_raw_data(
    output_json_name='dfnet_pred_unet_refine.json',
    processed_data_dir='processed_data_unet',
    raw_data_dir='raw_data_all',
    checkpoint='select-checkpoints/dfnet.pt',
    refine=True,
)

dp_num=4999: 100%|██████████| 600/600 [03:41<00:00,  2.71it/s]
