# Use DCP as an initializer for ICP
---
## Define dataset

In [None]:
import numpy as np
from dcp import train, test
from data import SceneNet
from torch.utils.data import DataLoader
import sys
sys.path.append("dcp-master")
from model import DCP
import torch
import warnings
warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
import k3d
from util import transform_point_cloud

def visualize_pointcloud(point_cloud1, point_size, point_cloud2=None, flip_axes=False, name='point_cloud', R=None, t=None):
    """
    Render one or two point clouds in an interactive k3d plot and display it.

    The first cloud is drawn light grey (0xd0d0d0) and the optional second
    cloud green (0x0dd00d), both with the '3d' shader, inside a hidden grid
    spanning [-0.55, 0.55] on every axis.

    :param point_cloud1: positions of the first cloud (anything k3d.points accepts)
    :param point_size: point size forwarded to k3d.points
    :param point_cloud2: optional second cloud; skipped when None
    :param flip_axes: currently unused
    :param name: name given to the k3d plot
    :param R: currently unused (no transform is applied here)
    :param t: currently unused (no transform is applied here)
    :return: None
    """
    layers = [(point_cloud1, 0xd0d0d0)]
    if point_cloud2 is not None:
        layers.append((point_cloud2, 0x0dd00d))

    canvas = k3d.plot(name=name, grid_visible=False, grid=(-0.55, -0.55, -0.55, 0.55, 0.55, 0.55))
    for positions, colour in layers:
        cloud = k3d.points(positions=positions, point_size=point_size, color=colour)
        canvas += cloud
        cloud.shader = '3d'
    canvas.display()

def transform(point_cloud, R=None, t=None):
    """
    Apply the rigid transform ``x -> R @ x + t`` to every point of a cloud.

    :param point_cloud: (N, D) array of points, one point per row
    :param R: (D, D) rotation matrix; None means the identity
    :param t: translation with D entries, e.g. shape (D,), (D, 1) or (1, D);
        None means no translation
    :return: a new (N, D) array with the transformed points
    """
    points = np.asarray(point_cloud)
    # Row-vector form of (R @ points.T).T, so no transpose round-trip is needed.
    moved = points.copy() if R is None else points @ np.asarray(R).T
    if t is not None:
        # Flatten so column, row and 1-D translations all broadcast across the rows.
        moved = moved + np.asarray(t).reshape(-1)
    return moved

# Datasets and loaders used by the DCP training / testing cells below.
# The first SceneNet argument is presumably the number of points per cloud — confirm in data.SceneNet.
trainDataset = SceneNet(1024, "train", filter=False)
valDataset = SceneNet(1024, "val", filter=False)
# NOTE(review): the training loader does not shuffle; confirm that this is intentional.
train_loader = DataLoader(trainDataset, batch_size=32, shuffle=False, drop_last=False)
# batch_size=1 with shuffle=False yields validation samples one at a time, in dataset order.
test_loader = DataLoader(valDataset, batch_size=1, shuffle=False, drop_last=False)
print(len(trainDataset))
print(len(valDataset))

## Define parameters

In [None]:
# DCP configuration, passed to DCP(...) and to dcp.train / dcp.test.
# Each key appears exactly once; the original listed 'use_sgd' twice with the same value.
args = {
    "model_path": "checkpoints/dcp_v1/models/model.best.t7",
    # "model_path": "dcp-master/pretrained/dcp_v2.t7",
    "exp_name": "dcp_v1",
    "model": "dcp",
    "emb_nn": "dgcnn",
    "pointer": "identity",
    "head": "svd",
    # NOTE(review): "eval" is True although the next cell trains; confirm how dcp.train reads it.
    "eval": True,
    "emb_dims": 512,
    "cycle": False,
    "use_sgd": False,
    "lr": 0.001,
    "epochs": 350,
    "n_blocks": 1,
    "dropout": 0.0,
    "ff_dims": 1024,
    "n_heads": 4,
    # Presumably only relevant when use_sgd is True.
    "momentum": 0.9,
}
net = DCP(args)

## DCP Training

In [None]:

# Train DCP with the settings in `args`; test_loader is forwarded to dcp.train
# (presumably for per-epoch validation — confirm in dcp.train).
train(args, net, train_loader, test_loader)

## DCP Testing

In [None]:
# Load the saved DCP weights, then evaluate on the validation loader.
# map_location puts the checkpoint tensors on the device `net` currently uses,
# so a GPU-saved checkpoint still loads on a CPU-only machine.
net_device = next(net.parameters()).device
load_result = net.load_state_dict(torch.load(args['model_path'], map_location=net_device), strict=False)
# strict=False skips keys that don't match (e.g. a "module."-prefixed
# DataParallel checkpoint) without raising, so surface any mismatch explicitly.
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"Checkpoint key mismatch: missing={load_result.missing_keys}, unexpected={load_result.unexpected_keys}")

# r, t: DCP's predicted rotations / translations, indexed per validation sample.
r, t = test(args, net, test_loader)

## Visualization

In [None]:

# Visualize one validation sample at 20000 points. Target is drawn grey and the
# source green, once for each of three alignments:
#   1. the DCP prediction,
#   2. the dataset's rotation_ab / translation_ab,
#   3. no transform at all.
# NOTE(review): `r`/`t` were predicted on the 1024-point, filter=False loader.
# Pairing r[sample_idx] with this filter=True dataset assumes both index
# samples identically, and that SceneNet.__getitem__ is deterministic per index.
# Confirm against data.SceneNet.
sample_idx = 42
valDataset = SceneNet(20000, "val", filter=True)
src, target, rotation_ab, translation_ab, rotation_ba, translation_ba, euler_ab, euler_ba = valDataset[sample_idx]
# The rotation / translation carry a batch dim, so transform_point_cloud
# presumably returns (1, 3, N). `.T` on a non-2-D tensor is deprecated in
# PyTorch and would give (N, 3, 1), so drop the batch dim first to get (N, 3).
transformed_src = transform_point_cloud(torch.tensor(src), torch.tensor(r[sample_idx]).unsqueeze(0), torch.tensor(t[sample_idx]).unsqueeze(0)).squeeze(0).T
transformed_src2 = transform_point_cloud(torch.tensor(src), torch.tensor(rotation_ab).unsqueeze(0), torch.tensor(translation_ab).unsqueeze(0)).squeeze(0).T
visualize_pointcloud(target.T, .03, transformed_src)   # DCP estimate
visualize_pointcloud(target.T, .03, transformed_src2)  # dataset's rotation_ab / translation_ab
visualize_pointcloud(target.T, .03, src.T)             # untransformed source
# export_pointcloud_to_obj('71_origin_filter.obj', target.T, np.array(src.T))
# export_pointcloud_to_obj('71_dcp_filter.obj', target.T, np.array(transformed_src))

## ICP Calibration

In [None]:
# Refine the DCP estimates with ICP. The dataset is built from the DCP
# rotations / translations `r`, `t` (presumably used as the ICP initialization).
# Import under an alias: a bare `from icp import test` would rebind the
# `test` imported from `dcp`. Re-running the "DCP Testing" cell afterwards
# would then call the ICP routine with DCP arguments.
from icp import test as icp_test

valDataset = SceneNet(1024, "val", icp=True, r=r, t=t, filter=True)
print(len(valDataset))
r_icp, t_icp = icp_test(valDataset)

## Final Visualization

In [None]:
# Final check on one sample after ICP refinement, at 20000 points. The dataset
# is built from the DCP estimates `r`, `t` (presumably applied to `src` as the
# initial alignment — confirm in data.SceneNet).
# r_icp / t_icp are indexed like this dataset: the commented export loop at
# the end of the notebook pairs valDataset[i] with r_icp[i]. So use the same
# index for the sample and its ICP result, not r_icp[0].
# NOTE(review): assumes icp.test returns one (R, t) per dataset sample — confirm.
sample_idx = 42
valDataset = SceneNet(20000, "val", icp=True, r=r, t=t, filter=True)
src, target, rotation_ab, translation_ab, rotation_ba, translation_ba, euler_ab, euler_ba = valDataset[sample_idx]
# Everything is cast to float64, presumably because the ICP estimates are float64.
# The batch dim is dropped before `.T`: `.T` on a non-2-D tensor is deprecated
# in PyTorch and would yield (N, 3, 1) instead of (N, 3).
transformed_src = transform_point_cloud(torch.tensor(src).double(), torch.tensor(r_icp[sample_idx]).unsqueeze(0).double(), torch.tensor(t_icp[sample_idx]).unsqueeze(0).double()).squeeze(0).T
transformed_src2 = transform_point_cloud(torch.tensor(src), torch.tensor(rotation_ab).unsqueeze(0), torch.tensor(translation_ab).unsqueeze(0)).squeeze(0).T
visualize_pointcloud(target.T, .03, transformed_src)   # ICP-refined estimate
visualize_pointcloud(target.T, .03, transformed_src2)  # dataset's rotation_ab / translation_ab
visualize_pointcloud(target.T, .03, src.T)             # source as returned by the dataset

# export_pointcloud_to_obj('71_icp_filter.obj', target.T, np.array(transformed_src))
# export_pointcloud_to_obj('24_icp.obj', target.T, np.array(transformed_src))

## Export

In [None]:
"""Export to disk"""


def export_mesh_to_obj(path, vertices, faces, vertices2):
    """
    Export a mesh, plus an optional second vertex set, as an OBJ file.

    Each vertex is written as ``v x y z r g b`` (the OBJ vertex-color
    extension):

    * ``vertices`` get the color ``0.098039 0.8117647 0.``;
    * ``vertices2`` get the grey ``0.717647 0.717647 0.717647``.

    All ``vertices`` lines come first, then all ``vertices2`` lines, then the
    faces.

    :param path: output path for the OBJ file (overwritten if it exists)
    :param vertices: Nx3 vertices, or None to skip
    :param faces: Mx3 zero-based vertex indices, or None to skip; written
        one-based, as OBJ requires
    :param vertices2: Kx3 vertices appended after ``vertices``, or None to skip
    :return: None
    """
    def _vertex_lines(points, color):
        # Every coordinate is followed by a space, so the color follows directly.
        return ["v " + "".join(str(c) + " " for c in point) + color + "\n" for point in points]

    lines = []
    if vertices is not None:
        lines += _vertex_lines(vertices, "0.098039 0.8117647 0.")
    if vertices2 is not None:
        lines += _vertex_lines(vertices2, "0.717647 0.717647 0.717647")
    if faces is not None:
        lines += ["f " + "".join(str(i + 1) + " " for i in face) + "\n" for face in faces]

    # Format everything first, then write; `with` closes the handle even if writing fails.
    with open(path, "w") as obj_file:
        obj_file.writelines(lines)
        
    # ###############


def export_pointcloud_to_obj(path, pointcloud, pointcloud2=None):
    """
    Write one or two point clouds to an OBJ file as vertices only (no faces).

    :param path: output path for the OBJ file
    :param pointcloud: Nx3 points, written first
    :param pointcloud2: optional Mx3 points, written after ``pointcloud``
    :return: None
    """
    # A point cloud is simply a mesh without faces.
    export_mesh_to_obj(path, vertices=pointcloud, faces=None, vertices2=pointcloud2)
    # ###############

# print(rigid_body_transformation_params)
# valDataset = SceneNet(10000, "val", icp=True, r=r, t=t)
# for i in range(len(valDataset)):
#     src, target, rotation_ab, translation_ab, rotation_ba, translation_ba, euler_ab, euler_ba = valDataset[i]
#     transformed_src = transform_point_cloud(torch.tensor(src).double(), torch.tensor(r_icp[i]).unsqueeze(0).double(), torch.tensor(t_icp[i]).unsqueeze(0).double()).T
#     export_pointcloud_to_obj('56_'+str(i)+'.obj', target.T, np.array(transformed_src))