In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import math
from IPython import display
from IPython.core.display import HTML
from IPython.core.display import display as html_width
html_width(HTML(""))
import random
from scipy.spatial import distance
import io
from matplotlib import patches
# `plt.cm.get_cmap` (i.e. `matplotlib.cm.get_cmap`) was deprecated in Matplotlib 3.7
# and removed in 3.9; `plt.get_cmap` returns the same registered colormap on old
# and new versions.
cm = plt.get_cmap('RdYlBu')
import cv2
import pickle
import point_cloud_utils as pcu

# True: train from scratch and save the result to "real_exp_2".
# False: skip training and load the "real_exp_2" checkpoint instead.
Train = False

In [None]:
manifold_resolution = 20_000

# Number of points in the volume to sample around the shape
num_vol_pts = 5_000

# Number of points on the surface to sample.
# NOTE(review): get_samples_for_mesh does not use this; it sizes its surface
# sample as 40 points per mesh vertex instead.
num_surf_pts = 20_000

def get_samples_for_mesh(mesh_path):
    """Load a mesh and draw SDF training samples around it.

    Steps:
      1. Keep only the connected component with the most faces.
      2. Build a watertight version of it with ``pcu.make_mesh_watertight``
         at ``manifold_resolution``.
      3. Sample ``num_vol_pts`` uniform points in the box
         x, y in [-0.3, 0.3), z in [-0.01, 0.59) and compute their signed
         distance to the watertight mesh.
      4. Sample ``40 * num_vertices`` points on the largest component (not the
         watertight one) and keep those with |x| < 0.3 and |y| < 0.3.

    Args:
        mesh_path: path to a mesh file readable by ``pcu.load_mesh_vf``.

    Returns:
        (vm, fm): vertices and faces of the watertight mesh.
        (p_vol, sdf): volume sample points, shape (num_vol_pts, 3), and their
            signed distances.
        v_sampled: surface sample points after cropping, shape (K, 3).
    """
    v, f = pcu.load_mesh_vf(mesh_path)
    _, _, cf, nf = pcu.connected_components(v, f.astype(np.int32))

    # Extract the mesh of the connected component with the most faces.
    comp_max = np.argmax(nf)
    v, f, _, _ = pcu.remove_unreferenced_mesh_vertices(v, f[cf == comp_max])

    # Signed distances are computed against a watertight version of the component.
    vm, fm = pcu.make_mesh_watertight(v, f.astype(np.int32), manifold_resolution)
    p_vol = np.random.rand(num_vol_pts, 3) * np.array([0.6, 0.6, 0.6]) - np.array([0.3, 0.3, 0.01])
    sdf, _, _ = pcu.signed_distance_to_mesh(p_vol, vm, fm)

    # Surface samples come from the original (non-watertight) component.
    f_i, bc = pcu.sample_mesh_random(v, f, num_samples=v.shape[0] * 40)
    # Face indices + barycentric coordinates -> sample positions.
    v_sampled = pcu.interpolate_barycentric_coords(f, f_i, bc, v)

    # Crop to the same |x|, |y| < 0.3 footprint used for the volume samples.
    in_crop = (np.abs(v_sampled[:, 0]) < 0.3) & (np.abs(v_sampled[:, 1]) < 0.3)
    v_sampled = v_sampled[in_crop]
    return (vm, fm), (p_vol, sdf), v_sampled

In [None]:
import numpy as np
import torch
from os.path import join
from torch.utils.data import Dataset
import trimesh
import glob
import open3d as o3d
import os
data_path = "demo_data/real_world_exp2_exp_ready/"
class SDF_Dataset(Dataset):
    """Per-demonstration SDF samples and joint trajectories.

    For each demonstration ``p`` in ``range(traj_count)`` the raw data is read
    from ``data_path + str(p)`` (``data.pickle`` and ``mesh_cropped.ply``). The
    first time a demonstration is seen, samples are generated with
    ``get_samples_for_mesh`` and cached under ``data_path + str(p) + '_processed/'``.
    Later runs load that cache. Raw SDF values are cached. Both the fresh and the
    cached path then apply the same target normalization (``_normalize_sdf``).

    Args:
        traj_count: number of demonstrations to load (indices 0..traj_count-1).
        num_of_samples: spatial points per item; half are on-surface points
            (target SDF 0) and half are off-surface SDF samples.
    """

    # Every stored demonstration has exactly this many time steps.
    TRAJ_LEN = 100
    # SDF targets are multiplied by SDF_SCALE and clipped from above at SDF_CLIP.
    SDF_SCALE = 5.0
    SDF_CLIP = 1.0

    @staticmethod
    def _normalize_sdf(values, scale=5.0, clip=1.0):
        """Scale raw SDF values by ``scale`` and clip them from above at ``clip``.

        Only the upper side is clipped; negative values are just scaled.
        """
        return np.minimum(np.asarray(values) * scale, clip)

    def __init__(self,traj_count,num_of_samples=10000):
        self.demo_meshes = list()
        self.demo_envs = list()
        self.traj_count = traj_count
        self.sdf_points = list()
        self.sdf_values = list()
        self.surface_points = list()
        self.demo_t = np.zeros((traj_count, self.TRAJ_LEN, 1))
        # NOTE(review): demo_x is allocated but never filled ('pos_traj' is not loaded).
        self.demo_x = np.zeros((traj_count, self.TRAJ_LEN, 3))
        self.demo_q = np.zeros((traj_count, self.TRAJ_LEN, 6))
        self.num_of_samples = num_of_samples

        for p in range(traj_count):
            raw_dir = data_path + str(p)
            processed_dir = raw_dir + '_processed/'
            # data.pickle is the last file written below, so its presence marks a
            # complete cache. Checking only the directory (created before the
            # slow processing) would accept a half-written cache after a crash.
            cache_marker = processed_dir + "data.pickle"

            if not os.path.exists(cache_marker):
                os.makedirs(processed_dir, exist_ok=True)
                # NOTE: pickle.load can execute arbitrary code; only use trusted data.
                with open(raw_dir + "/data.pickle", 'rb') as handle:
                    data = pickle.load(handle)
                mesh, (sdf_points, sdf_values), surface_points = get_samples_for_mesh(raw_dir + "/mesh_cropped.ply")
                np.save(processed_dir + 'sdf_points', sdf_points)
                np.save(processed_dir + 'sdf_values', sdf_values)
                np.save(processed_dir + 'surface_points', surface_points)
                pcu.save_mesh_vf(processed_dir + 'data.ply', mesh[0], mesh[1])
                with open(cache_marker, 'wb') as handle:
                    pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)
            else:
                with open(cache_marker, 'rb') as handle:
                    data = pickle.load(handle)
                sdf_points = np.load(processed_dir + 'sdf_points.npy')
                sdf_values = np.load(processed_dir + 'sdf_values.npy')
                surface_points = np.load(processed_dir + 'surface_points.npy')
                v, f = pcu.load_mesh_vf(processed_dir + 'data.ply')
                mesh = (v, f)

            # Same target normalization for freshly processed and cached demos.
            # Previously only the cached path was scaled/clipped, so the first run
            # trained on different targets than later runs.
            sdf_values = self._normalize_sdf(sdf_values, self.SDF_SCALE, self.SDF_CLIP)

            # Shift time stamps by -0.5.
            self.demo_t[p, :, 0] = np.array(data['t']) - 0.5
            self.demo_q[p] = np.array(data['joint_traj'])
            self.demo_envs.append(data['env_parameters'])
            self.demo_meshes.append(mesh)
            self.sdf_points.append(sdf_points)
            self.sdf_values.append(sdf_values)
            self.surface_points.append(surface_points)

    def __len__(self):
        return self.traj_count

    def __getitem__(self,index):
        """Return ``(observations, ground_truth)`` for demonstration ``index``.

        observations:
            't', 'mp': trajectory time steps drawn with replacement, shapes (T, 1) and (T, 6).
            'coords_sdf': num_of_samples // 2 on-surface points followed by the same
                number of off-surface points, shape (2 * (num_of_samples // 2), 3).
            'sdf': matching targets, shape (.., 1); 0 for the on-surface half.
            'instance_idx': 0-d long tensor holding ``index``.
        ground_truth: {'sdf': observations['sdf'], 'mp': observations['mp']}.
        """
        n_steps = self.demo_t.shape[1]
        traj_idx = np.random.choice(np.arange(n_steps), n_steps)

        t_mp = torch.from_numpy(self.demo_t[index, traj_idx]).float()
        q_mp = torch.from_numpy(self.demo_q[index, traj_idx]).float()

        half = self.num_of_samples // 2
        on_surface_idx = np.random.choice(np.arange(len(self.surface_points[index])), half)
        off_surface_idx = np.random.choice(np.arange(len(self.sdf_values[index])), half)

        coords = torch.from_numpy(np.vstack([self.surface_points[index][on_surface_idx],
                                             self.sdf_points[index][off_surface_idx]])).float()
        sdf = torch.from_numpy(np.vstack([np.zeros((half, 1)),
                                          self.sdf_values[index][off_surface_idx].reshape(-1, 1)])).float()

        observations = {
            't': t_mp,
            'mp': q_mp,
            'coords_sdf': coords,
            'sdf': sdf,
            'instance_idx': torch.tensor(index, dtype=torch.long),
        }
        ground_truth = {'sdf': observations['sdf'],
                        'mp': q_mp}
        return observations, ground_truth

In [None]:
from torch.utils.data import DataLoader
import time
from matplotlib import patches
%matplotlib inline
# Nine demonstrations (indices 0-8), served one per batch in shuffled order.
dataset = SDF_Dataset(traj_count=9)
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    drop_last=True,
)

In [None]:
%matplotlib notebook
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(projection='3d')
observations, _ = dataset[2]
# __getitem__ stacks num_of_samples // 2 on-surface points first, followed by
# the same number of off-surface SDF samples. Split at that boundary; a fixed
# 1000 mislabelled most surface points as SDF samples.
n_surface = dataset.num_of_samples // 2
surface_points = observations['coords_sdf'][:n_surface]
sdf_points = observations['coords_sdf'][n_surface:]
sdf_values = observations['sdf'][n_surface:].reshape(-1)
near_surface = sdf_values < 0.1
ax.scatter3D(sdf_points[near_surface, 0], sdf_points[near_surface, 1], sdf_points[near_surface, 2],
             label='SDF samples with sdf < 0.1')
ax.scatter3D(surface_points[:, 0], surface_points[:, 1], surface_points[:, 2], label='surface samples')
ax.set(xlabel='x', ylabel='y', zlabel='z', title='Training samples for demonstration 2')
ax.legend();

In [None]:
import torch
from torch import nn
from modules import *
from meta_modules import HyperNetwork
from skimage import measure
import plotly.graph_objects as go
l1_loss = torch.nn.L1Loss()

def task_loss_with_deform(model_output, gt):
    """Weighted training losses for NFSMP_Shape.

    Args:
        model_output: dict with at least 'latent_vec', 'sdf' and 'mp' tensors.
        gt: dict with 'sdf' and 'mp' targets, broadcast-compatible with the
            corresponding predictions.

    Returns:
        dict of scalar tensors (summed by the training loop):
          'sdf': mean absolute SDF error * 3e4,
          'mp': mean squared trajectory error * 1e5,
          'embeddings_constraint': mean squared latent-code entry * 1e6.
    """
    # L1Loss is already non-negative and scalar, so no abs()/mean() is needed.
    sdf_loss = l1_loss(model_output['sdf'], gt['sdf'])
    mp_loss = ((model_output['mp'] - gt['mp']) ** 2).mean()
    # Penalizes large latent codes.
    embedding_loss = (model_output['latent_vec'] ** 2).mean()
    # NOTE(review): model_output['deform'] is produced by the model but not
    # regularized here; add a term if large deformations should be discouraged.
    return {'sdf': sdf_loss * 3e4,
            'mp': mp_loss * 1e5,
            'embeddings_constraint': embedding_loss * 1e6,
           }

class NFSMP_Shape(nn.Module):
    """Latent-conditioned model that decodes a joint trajectory and a shape SDF.

    Each training instance ``i`` has a learned code ``latent_codes[i]``. Two
    HyperNetworks map that code to the weights of ``mp_net`` (time ``t`` ->
    ``mp_out_size`` joint values) and ``deform_net`` (3-D point -> 3-D offset).
    ``sdf_net`` has no hypernetwork and is always called without ``params``, so
    it is shared by all instances. It is evaluated at the deformed points
    ``coords_sdf + deform``.

    Args:
        num_instances: number of latent codes (one per training demonstration).
        latent_dim: size of each latent code.
        hyper_hidden_layers, hyper_hidden_features: passed to both HyperNetworks.
        mp_out_size: output size of ``mp_net``.

    NOTE(review): ``model_type``, ``hidden_num``, ``affine`` and ``**kwargs`` are
    accepted but not used anywhere in this class body.
    """
    def __init__(self, num_instances,latent_dim=128, model_type='relu', hyper_hidden_layers=2,
                 hyper_hidden_features=64,mp_out_size=6,hidden_num=128, affine=True, **kwargs):
        super().__init__()

        self.latent_dim = latent_dim
        self.latent_codes = nn.Embedding(num_instances, self.latent_dim)
        # std=0 makes every latent code start at exactly zero (identical for all
        # instances). NOTE(review): confirm this is intended rather than a small std.
        nn.init.normal_(self.latent_codes.weight, mean=0, std=0)

        # `L` and `pos_encoding` are SingleBVPNet options (defined in `modules`);
        # presumably L is the number of positional-encoding frequency bands — TODO confirm.
        self.mp_net = SingleBVPNet(in_features=1,
                                   out_features=mp_out_size,L =8,pos_encoding=True)
        self.deform_net = SingleBVPNet(in_features=3,
                                    out_features=3,L =2,pos_encoding=True)
        # Shared across instances: no hypernetwork is built for sdf_net.
        self.sdf_net=SingleBVPNet(in_features=3,out_features=1,L =8,pos_encoding=True)

        # Ratio of the L values above (8 vs. 2 for deform_net). deform_net is called
        # with epoch * this factor, presumably so an epoch-dependent encoding
        # schedule in SingleBVPNet advances proportionally — TODO confirm.
        self.deform_epoch_multiplier = 8.0/2.0
        # Hyper-Net
        # Each HyperNetwork maps a latent code to the parameters of its hypo_module.
        # NOTE: construction order here also determines the random initialization
        # sequence, so reordering these statements changes the initial weights.
        self.hyper_net_mp = HyperNetwork(hyper_in_features=self.latent_dim,
                                         hyper_hidden_layers=hyper_hidden_layers,
                                         hyper_hidden_features=hyper_hidden_features,
                                         hypo_module=self.mp_net)
        self.hyper_net_deform= HyperNetwork(hyper_in_features=self.latent_dim,
                                         hyper_hidden_layers=hyper_hidden_layers,
                                         hyper_hidden_features=hyper_hidden_features,
                                         hypo_module=self.deform_net)

        # Zero the last BatchLinear layer of deform_net's *own* weights.
        # NOTE(review): forward() and inference() always call deform_net with
        # params=hypo_params_deform. If those generated params replace the module's
        # weights, this zero-init does not affect outputs — confirm against
        # HyperNetwork/BatchLinear. It also runs after hyper_net_deform was built
        # from deform_net, which may matter if HyperNetwork reads these weights.
        last_layer = [layer for layer in self.deform_net.modules() if isinstance(layer, BatchLinear)][-1]
        torch.nn.init.zeros_(last_layer.weight)
        torch.nn.init.zeros_(last_layer.bias)

    def get_hypo_net_weights(self, model_input):
        """Look up the latent codes for model_input['instance_idx'] and generate hypo-net params.

        Returns:
            (hypo_params_mp, hypo_params_deform, embedding)
        """
        instance_idx = model_input['instance_idx']
        embedding = self.latent_codes(instance_idx)
        hypo_params_mp = self.hyper_net_mp(embedding)
        hypo_params_deform = self.hyper_net_deform(embedding)
        return hypo_params_mp, hypo_params_deform, embedding
    def get_latent_code(self,instance_idx):
        """Return the latent code(s) for the given instance index tensor."""

        embedding = self.latent_codes(instance_idx)

        return embedding
    def inference(self,coords_sdf,embedding,epoch = 1000):
        """Predict SDF values at coords_sdf for a given latent code, without gradients.

        The points are first offset by deform_net (weights generated from
        ``embedding``) and then passed to the shared sdf_net.

        Returns:
            dict with key 'sdf' holding sdf_net's 'model_out'.
        """

        with torch.no_grad():
            model_out = dict()

            hypo_params_deform = self.hyper_net_deform(embedding)
            model_in = {'coords': coords_sdf}
            deform =self.deform_net(model_in,self.deform_epoch_multiplier*epoch, params=hypo_params_deform)['model_out']

            model_in = {'coords': coords_sdf+deform}
            model_out['sdf'] =self.sdf_net(model_in,epoch)['model_out']
            return model_out
    def inference_mp(self,coords_mp,embedding,epoch = 1000):
        """Predict trajectory values at time inputs coords_mp for a latent code, without gradients.

        Returns:
            dict with key 'mp' holding mp_net's 'model_out'.
        """

        with torch.no_grad():
            model_out = dict()
            hypo_params_mp = self.hyper_net_mp(embedding)
            model_in = {'coords': coords_mp}
            model_out['mp'] =self.mp_net(model_in,epoch, params=hypo_params_mp)['model_out']
            return model_out


    def forward(self, model_input,gt,epoch,**kwargs):
        """Run both decoders for a batch and return the loss dict.

        Args:
            model_input: dict with 'instance_idx', 'coords_sdf' and 't' tensors
                (as produced by SDF_Dataset.__getitem__).
            gt: ground-truth dict passed to task_loss_with_deform.
            epoch: forwarded to the sub-networks (scaled for deform_net).
            **kwargs: ignored.

        Returns:
            dict of losses from task_loss_with_deform (not the predictions).
        """

        instance_idx = model_input['instance_idx']
        coords_sdf  = model_input['coords_sdf']
        t = model_input['t']
        model_in1 = {'coords': t}
        model_in2 = {'coords': coords_sdf}

        hypo_params_mp, hypo_params_deform, embedding = self.get_hypo_net_weights(model_input)
        mp = self.mp_net(model_in1,epoch, params=hypo_params_mp)['model_out']
        deform =self.deform_net(model_in2,self.deform_epoch_multiplier*epoch, params=hypo_params_deform)['model_out']
        # The shared SDF net sees the instance-specifically deformed points.
        model_in3 = {'coords': coords_sdf+deform}

        sdf = self.sdf_net(model_in3,epoch)['model_out']

        model_out = {'t': t,
                     'coords_sdf':coords_sdf,
                     'mp':mp,
                     'sdf':sdf,
                     'deform':deform,
                     'latent_vec':embedding,}

        losses = task_loss_with_deform(model_out, gt)
        return losses

def estimate_mesh_from_model(model, epoch, subject_idx, resolution = 64, scale = 0.3, level=0.001, max_batch=64 ** 3):
    """Extract a surface mesh for one training instance by marching cubes on the predicted SDF.

    The SDF is queried, in chunks of ``max_batch`` points, on a
    ``resolution``^3 grid covering ``scale * ([-1, 1] x [-1, 1] x [0, 2])``,
    using the latent code of instance ``subject_idx``.

    Args:
        model: NFSMP_Shape-like model exposing get_latent_code and inference.
        epoch: forwarded to model.inference.
        subject_idx: integer instance index.
        resolution: grid points per axis (N).
        scale: spatial scale applied to the [-1, 1]^2 x [0, 2] grid.
        level: iso-value passed to marching cubes.
        max_batch: number of points per inference call.

    Returns:
        (vertices, faces, normals) from skimage.measure.marching_cubes, with the
        vertices mapped back into query coordinates.

    Side effect: switches ``model`` to eval mode (not restored).
    """
    with torch.no_grad():
        N = resolution
        model.eval()
        # Follow the model instead of hard-coding CUDA.
        device = next(model.parameters()).device

        # NOTE: voxel_origin is the (bottom, left, down) corner of the grid, not its middle.
        voxel_origin = [-1, -1, 0]
        voxel_size = 2.0 / (N - 1)

        # Flat index -> (x, y, z) grid index with z varying fastest, so the SDF
        # column reshapes to an (N, N, N) volume indexed [x, y, z].
        overall_index = torch.arange(N ** 3)
        grid_index = torch.stack(
            [(overall_index // (N * N)) % N, (overall_index // N) % N, overall_index % N], dim=1
        ).float()
        coords = grid_index * (scale * voxel_size) + scale * torch.tensor(voxel_origin, dtype=torch.float32)
        coords = coords.to(device)

        subject = torch.tensor([subject_idx], dtype=torch.long, device=device)
        embedding = model.get_latent_code(subject)

        num_samples = N ** 3
        sdf_values = torch.empty(num_samples)
        for head in range(0, num_samples, max_batch):
            chunk = coords[head:head + max_batch][None, ...]
            sdf_values[head:head + max_batch] = model.inference(chunk, embedding, epoch)['sdf'].reshape(-1).cpu()

        # spacing must equal the grid step 2 / (N - 1) so that vertex positions
        # match the query coordinates (2 / N shrank the mesh by (N - 1) / N).
        v, t, n, _ = measure.marching_cubes(sdf_values.reshape(N, N, N).numpy(), level,
                                            step_size=1, spacing=[voxel_size] * 3)
        return scale * (v + np.array(voxel_origin)), t, n
def show_mesh(vertices, faces,colors=[]):
    """Render a triangle mesh with Plotly, show it, and return the figure.

    Args:
        vertices: iterable of (x, y, z) vertex positions.
        faces: iterable of (i, j, k) vertex-index triples.
        colors: optional per-vertex intensity values (colored with the 'jet' scale).

    Returns:
        The plotly Figure that was shown.
    """
    xs, ys, zs = zip(*vertices)
    face_i, face_j, face_k = (list(column) for column in zip(*faces))

    mesh_trace = go.Mesh3d(
        x=xs,
        y=ys,
        z=zs,
        i=face_i,
        j=face_j,
        k=face_k,
        colorscale='jet',
        intensity=colors,
    )
    figure = go.Figure(data=[mesh_trace])
    figure.show()
    return figure


In [None]:
# One latent code per demonstration.
traj_count = 9
model = NFSMP_Shape(traj_count, latent_dim=128, hidden_num=128)
model.to(device=torch.device('cuda:0'))

# Networks use lr=1e-3; the per-instance latent codes use a smaller lr=1e-4.
# NOTE(review): this rebinds `optim`, shadowing the `torch.optim` module alias
# imported in the first cell; the training cell relies on this name.
optim = torch.optim.Adam(
    [{'params': sub.parameters()}
     for sub in (model.sdf_net, model.mp_net, model.deform_net,
                 model.hyper_net_deform, model.hyper_net_mp)]
    + [{'params': model.latent_codes.parameters(), 'lr': 1e-4}],
    lr=1e-3,
)

In [None]:
from tqdm import tqdm
import time
import copy
%matplotlib inline

if Train:
    total_steps = 0
    epochs = 1500
    # Track the best epoch by mean 'mp' loss. Start at +inf: with a finite
    # threshold (previously 1000, while 'mp' is scaled by 1e5) no epoch might
    # qualify, leaving `checkpoint` undefined at load_state_dict below.
    lowest_mp_loss = float('inf')
    checkpoint = None

    with tqdm(total=len(dataloader) * epochs) as pbar:
        train_losses = []
        for epoch in range(epochs):
            if epoch % 200 == 199:
                v, t, n = estimate_mesh_from_model(model, epoch, 0, 128)
                show_mesh(v, t)
            # estimate_mesh_from_model switches to eval mode; switch back.
            model.train()
            epoch_mp_losses = []
            for step, (model_input, gt) in enumerate(dataloader):

                start_time = time.time()
                model_input = {key: value.cuda() for key, value in model_input.items()}
                gt = {key: value.cuda() for key, value in gt.items()}

                losses = model(model_input, gt, epoch)

                train_loss = 0.
                for loss_name, loss in losses.items():
                    single_loss = loss.mean()
                    if epoch % 100 == 0 and step == 0:
                        print(loss_name, single_loss)
                    train_loss += single_loss
                epoch_mp_losses.append(losses['mp'].item())
                train_losses.append(train_loss.item())
                optim.zero_grad()
                train_loss.backward()
                optim.step()
                pbar.update(1)
                # Pass a Python float so tqdm formats it as a number, not a tensor repr.
                pbar.set_postfix(loss=train_loss.item(), time=time.time() - start_time, epoch=epoch)
                total_steps += 1
            epoch_mp_loss = np.mean(epoch_mp_losses)
            if epoch_mp_loss < lowest_mp_loss:
                lowest_mp_loss = epoch_mp_loss
                checkpoint = copy.deepcopy(model.state_dict())
    # Restore the best epoch (None only if no epoch produced a finite mp loss).
    if checkpoint is not None:
        model.load_state_dict(checkpoint)
    torch.save(model.state_dict(), "real_exp_2")

In [None]:
if not Train:
    # map_location restores the saved tensors directly onto the model's device,
    # so a checkpoint written from another device still loads.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load("real_exp_2", map_location=next(model.parameters()).device))
    model.eval()

In [None]:
def traj_field(model,grid_size = 100,epoch=1000):
    """Decode the trajectory of every training instance on a uniform time grid.

    Args:
        model: NFSMP_Shape-like model exposing ``latent_codes`` (nn.Embedding),
            ``get_latent_code`` and ``inference_mp``.
        grid_size: number of time samples, evenly spaced in [-0.5, 0.5].
            NOTE(review): assumes demonstration times were shifted into that
            range (SDF_Dataset subtracts 0.5) — confirm.
        epoch: forwarded to ``model.inference_mp``.

    Returns:
        list with one (grid_size, D) numpy array per latent code, where D is
        the trajectory output size of the model.
    """
    device = next(model.parameters()).device
    # The time grid is identical for every instance, so build it once.
    times = torch.from_numpy(np.linspace(-0.5, 0.5, grid_size).reshape(-1, 1)).float().to(device)
    Z_mp = list()
    with torch.no_grad():
        for index in range(model.latent_codes.num_embeddings):
            subject_idx = torch.tensor([index], dtype=torch.long, device=device)
            embedding = model.get_latent_code(subject_idx)
            out = model.inference_mp(times, embedding, epoch=epoch)['mp']
            Z_mp.append(out.detach().cpu().numpy().reshape(grid_size, -1))
    return Z_mp
trajectories = traj_field(model, 100, 1000)
# Compare each decoded trajectory with its training demonstration.
# `errors` holds per-demonstration Frobenius norms of the whole trajectory
# difference, which is not a mean squared error, so MSE is computed separately.
errors = list()
squared_errors = list()
for i in range(len(trajectories)):
    diff = trajectories[i] - dataset.demo_q[i]
    errors.append(np.linalg.norm(diff))
    squared_errors.append(np.mean(diff ** 2))
print('MSE', np.mean(squared_errors))
print('Mean Frobenius error', np.mean(errors))

In [None]:
%matplotlib inline
# Reconstruct instance 1 on a 128^3 grid and render it.
v, t, n = estimate_mesh_from_model(model, 1000, 1, 128)
# show_mesh already calls fig.show(); the trailing ';' keeps Jupyter from
# rendering the returned figure a second time as the cell's output.
show_mesh(v, t);

In [None]:
import plotly.offline  
def estimate_mesh_traj_from_model(model, epoch, subject_idx,subject_idx2, resolution = 64,traj_len = 50,scale = 0.3,
                                  level=0.0035, max_batch=64 ** 3):
    """Reconstruct meshes along a straight line between two instances' latent codes.

    Step ``k`` of ``traj_len`` uses ``e1 + (e2 - e1) * k / (traj_len - 1)``, so the
    first mesh uses ``subject_idx``'s code and the last uses ``subject_idx2``'s.
    Each SDF is queried on a ``resolution``^3 grid covering
    ``scale * ([-1, 1] x [-1, 1] x [0, 2])`` and meshed with marching cubes.

    Args:
        model: NFSMP_Shape-like model exposing get_latent_code and inference.
        epoch: forwarded to model.inference.
        subject_idx, subject_idx2: integer instance indices of the two endpoints.
        resolution: grid points per axis (N).
        traj_len: number of interpolation steps (meshes returned).
        scale: spatial scale applied to the grid.
        level: iso-value passed to marching cubes.
        max_batch: number of points per inference call.

    Returns:
        list of ``traj_len`` tuples (vertices, faces, normals).

    Side effect: switches ``model`` to eval mode (not restored).
    """
    with torch.no_grad():
        N = resolution
        model.eval()
        device = next(model.parameters()).device
        embedding1 = model.get_latent_code(torch.tensor([subject_idx], dtype=torch.long, device=device))
        embedding2 = model.get_latent_code(torch.tensor([subject_idx2], dtype=torch.long, device=device))

        # NOTE: voxel_origin is the (bottom, left, down) corner of the grid, not its middle.
        voxel_origin = [-1, -1, 0]
        voxel_size = 2.0 / (N - 1)

        # The query grid is the same for every interpolation step: build it once.
        # Flat index -> (x, y, z) grid index with z varying fastest.
        overall_index = torch.arange(N ** 3)
        grid_index = torch.stack(
            [(overall_index // (N * N)) % N, (overall_index // N) % N, overall_index % N], dim=1
        ).float()
        coords = grid_index * (scale * voxel_size) + scale * torch.tensor(voxel_origin, dtype=torch.float32)
        coords = coords.to(device)
        num_samples = N ** 3

        mesh_list = list()
        for interp in range(traj_len):
            # Guard traj_len == 1, which previously divided by zero (NaN embedding).
            alpha = interp / (traj_len - 1.0) if traj_len > 1 else 0.0
            embedding = embedding1 + (embedding2 - embedding1) * alpha

            sdf_values = torch.empty(num_samples)
            for head in range(0, num_samples, max_batch):
                chunk = coords[head:head + max_batch][None, ...]
                sdf_values[head:head + max_batch] = model.inference(chunk, embedding, epoch)['sdf'].reshape(-1).cpu()

            # spacing must equal the grid step 2 / (N - 1) so vertices land on the
            # query coordinates (2 / N shrank the mesh by (N - 1) / N).
            v, t, n, _ = measure.marching_cubes(sdf_values.reshape(N, N, N).numpy(), level,
                                                step_size=1, spacing=[voxel_size] * 3)
            mesh_list.append((scale * (v + np.array(voxel_origin)), t, n))
        return mesh_list

def draw_mesh_list(mesh_list,fname):
    """Write an HTML animation of a mesh sequence to ``fname + '.html'``.

    Each ``(vertices, faces, _)`` tuple in ``mesh_list`` becomes one Plotly
    animation frame named ``frame{i}`` containing a gray Mesh3d. The figure
    starts with an empty Mesh3d trace and has Play/Pause buttons, a frame
    slider, hidden [-1, 1] axes, a cube aspect ratio and a fixed camera. The
    file is written with plotly.offline.plot without opening a browser.
    """
    def frame_args(duration):
        # Animation settings shared by the slider steps and the Play/Pause buttons.
        return {
                "frame": {"duration": duration},
                "mode": "immediate",
                "fromcurrent": True,
                "transition": {"duration": duration, "easing": "linear"},
                }

    frames = []
    for idx, (verts, tris, _) in enumerate(mesh_list):
        xs, ys, zs = zip(*verts)
        ii, jj, kk = zip(*tris)
        trace = go.Mesh3d(x=xs, y=ys, z=zs,
                          i=list(ii), j=list(jj), k=list(kk),
                          color='gray',
                          opacity=1,
                          name='Predicted Surface Mesh')
        frames.append(go.Frame(data=[trace], name=f'frame{idx}'))

    placeholder = go.Mesh3d(x=[], y=[], z=[], i=[], j=[], k=[], colorscale='jet')
    layout = go.Layout(title="Interpolation Animation",
                       paper_bgcolor='rgba(0,0,0,0)',
                       plot_bgcolor='rgba(0,0,0,0)')
    fig = go.Figure(data=[placeholder], layout=layout, frames=frames)

    # Identical hidden, fixed-range settings for all three axes.
    fig.update_layout(scene={
        axis_name: dict(range=[-1, 1], autorange=False,
                        showgrid=True, zeroline=False, visible=False)
        for axis_name in ('xaxis', 'yaxis', 'zaxis')
    })

    slider_steps = [
        {"args": [[f'frame{k}'], frame_args(0)],
         "label": str(k),
         "method": "animate",
         }
        for k in range(len(frames))
    ]
    sliders = [{"pad": {"b": 10, "t": 60},
                "len": 0.9,
                "x": 0.1,
                "y": 0,
                "steps": slider_steps,
                }]
    play_button = {"args": [None, frame_args(50)], "label": "Play", "method": "animate"}
    pause_button = {"args": [[None], frame_args(0)], "label": "Pause", "method": "animate"}
    fig.update_layout(
        updatemenus=[{"buttons": [play_button, pause_button],
                      "direction": "left",
                      "pad": {"r": 10, "t": 20},
                      "type": "buttons",
                      "x": 0.1,
                      "y": 0,
                      }],
        sliders=sliders,
    )

    camera = dict(
        up=dict(x=0, y=-0.707, z=0.707),
        center=dict(x=0, y=0, z=0),
        eye=dict(x=0, y=1, z=-1.25),
    )
    fig.update_layout(scene_aspectmode='cube')
    fig.update_layout(scene_camera=camera)
    plotly.offline.plot(fig, filename=fname+'.html', auto_open=False)

In [None]:
# Chain interpolations between consecutive demonstrations: 0 -> 1 -> ... -> last.
num_demos = model.latent_codes.num_embeddings
traj = [[k, k + 1] for k in range(num_demos - 1)]
mesh_traj = []
for pair in traj:
    segment = estimate_mesh_traj_from_model(model, 2000, pair[0], pair[1], 64, 10)
    # Each segment starts at the previous segment's end code; skip that repeated
    # mesh so the animation does not pause on duplicate frames.
    mesh_traj.extend(segment if not mesh_traj else segment[1:])
draw_mesh_list(mesh_traj, 'Interpolation_Test')
# Open Interpolation_Test.html to view the animation. For smoother playback, press Pause once and then Play again.
# When the file is first opened the object is seen from below; rotate it by dragging with the left mouse button.