In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import os
import open3d as o3d
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch.optim as optim
import pytorch_volumetric as pv
import matplotlib.pyplot as plt
from tqdm import tqdm


from usdf.utils.render_utils import depth_to_free_points
from mmint_tools import tr, matrix_to_pose, pose_to_matrix
from mmint_tools.camera_tools.img_utils import project_depth_image
from mmint_tools.notebook_tools.notebook_tools import view_points, view_mesh, view_pointcloud, view_points_groups
from mmint_tools.camera_tools.pointcloud_utils import mesh_to_pointcloud, generate_partial_pc, generate_partial_view, pack_o3d_pcd, unpack_o3d_pcd, tr_pointcloud

In [None]:
# Paths to the YCB mug assets. The defaults are the original author's local paths;
# set the USDF_TEST_PATH / YCB_MUG_DIR environment variables to override them
# instead of editing the notebook.
TEST_PATH = os.environ.get('USDF_TEST_PATH', '/home/mik/Desktop/test')
YCB_MUG_DIR = os.environ.get('YCB_MUG_DIR', '/home/mik/Downloads/025_mug/google_16k')
YCB_MUG_PATH = os.path.join(YCB_MUG_DIR, 'nontextured.ply')

# Define basic helper functions

In [None]:
def rotate_model_zaxis(model, theta):
    """Rotate a point set about the z axis by one or more angles.

    Args:
        model: (k, 3) tensor of points.
        theta: tensor of angles in radians, with any batch shape (...).

    Returns:
        (..., k, 3) tensor holding one rotated copy of ``model`` per angle.
    """
    cos_t, sin_t = torch.cos(theta), torch.sin(theta)
    zero = torch.zeros_like(theta, dtype=theta.dtype)
    one = torch.ones_like(theta, dtype=theta.dtype)
    # Rows of R_z(theta) = [[c, -s, 0], [s, c, 0], [0, 0, 1]], batched over theta.
    rows = (
        torch.stack([cos_t, -sin_t, zero], dim=-1),
        torch.stack([sin_t, cos_t, zero], dim=-1),
        torch.stack([zero, zero, one], dim=-1),
    )
    rot = torch.stack(rows, dim=-2)  # (..., 3, 3)
    # p' = R p for every model point, for every angle.
    return torch.einsum('...ij,kj->...ki', rot, model)

def tr_points(points, T):
    """Apply a homogeneous transform to 3D points.

    Args:
        points: (..., 3) tensor.
        T: (4, 4) tensor.

    Returns:
        (..., 3) tensor with ``T`` applied to every point. The homogeneous
        coordinate is dropped without dividing by it (affine transforms only).
    """
    ones = torch.ones_like(points[..., :1])
    homogeneous = torch.cat((points, ones), dim=-1)  # (..., 4)
    transformed = torch.einsum('ij,...j->...i', T, homogeneous)
    return transformed[..., :3]

# Load Mugs

In [None]:
# Load the full YCB mug mesh and sample a point cloud from its surface.
# NOTE(review): later cells slice pc[..., :3], so mesh_to_pointcloud presumably
# returns extra per-point columns beyond xyz (e.g. normals) — confirm.

mesh = o3d.io.read_triangle_mesh(YCB_MUG_PATH)
pc = mesh_to_pointcloud(mesh, num_points=5000)

In [None]:
view_pointcloud(pc)

In [None]:
# Load the same mug with its handle removed. The x/y mean of these points is used
# below as the center of the mug body, correcting the offset the handle would
# introduce into the mean of the full mug.
YCB_MUG_NOHANDLE_PATH = '/home/mik/Downloads/025_mug/google_16k/nontextured_no_handle.ply'
mesh_nohandle = o3d.io.read_triangle_mesh(YCB_MUG_NOHANDLE_PATH)
pc_nohandle = mesh_to_pointcloud(mesh_nohandle, num_points=5000)
pc_nohandle_mean = np.mean(pc_nohandle, axis=0)
print(pc_nohandle_mean)
view_pointcloud(pc_nohandle)

### Center the pointcloud 

In [None]:
# Centering offset: translate by minus the handle-less mug's mean in x and y (z is untouched).
pc_mean = np.mean(pc, axis=0) # mean of the full mug; the y axis is symmetric, while x is not (handle). Not used further below.
center_tr = np.array([-pc_nohandle_mean[0], -pc_nohandle_mean[1], 0, 0, 0, 0, 1])
# NOTE(review): center_tr presumably follows pose_to_matrix's [x, y, z, qx, qy, qz, qw]
# layout (identity rotation here) — confirm against mmint_tools.

pc_centered = tr_pointcloud(pc, T=pose_to_matrix(center_tr))

view_pointcloud(pc_centered)


In [None]:
# Verify the centering: rotate the centered model about z by 6 angles in [-pi, pi]
# and overlay the copies, one color each; they should turn about a common vertical axis.
# Note that -pi and pi give the same orientation, so only 5 of the copies are distinct.
points_centered = torch.tensor(pc_centered[..., :3], dtype=torch.float32)
angles = torch.linspace(-np.pi,np.pi, 6)
models_rotated = rotate_model_zaxis(points_centered, angles)

palette = sns.color_palette("Spectral", as_cmap=True)
color_list = palette(np.linspace(0,1,len(models_rotated)))
points = models_rotated.detach().cpu().numpy()
view_points_groups(points_xyz_list=points, 
                       colors_list=color_list, 
                       marker_size_list=[5 for i in color_list])


Conclusion: the centering transform is a good fit — the rotated copies turn about a common vertical axis.

### Get a partial view

In [None]:
partial_pcd = generate_partial_pc(mesh, view_axis=np.array([-1.0, .0, .0]), look_at=np.array([0, 0, 0.04]))

# Further crop the point cloud: keep only the points whose x lies within max_d of the
# smallest x observed, i.e. a thin slab of the visible surface.
x_min = np.min(partial_pcd.points, axis=0)[0]
max_d = 0.02
partial_pcd = partial_pcd.crop(o3d.geometry.AxisAlignedBoundingBox([-np.inf, -np.inf, -np.inf], [x_min + max_d, np.inf, np.inf]))

partial_pc = unpack_o3d_pcd(partial_pcd)
print('Num of points, ', len(partial_pc))

# Subsample without replacement. Never request more points than the crop kept:
# np.random.choice(..., replace=False) raises ValueError in that case.
num_points = min(1000, len(partial_pc))
partial_pc_downsampled = partial_pc[np.random.choice(len(partial_pc), num_points, replace=False)]
print(partial_pc_downsampled.shape)
view_pointcloud(partial_pc_downsampled)

### Obtain free space information

In [None]:
# Render a depth image from the same viewpoint as the partial point cloud above. The
# returned camera's intrinsics/extrinsics are used below to back-project the image.
partial_depth_img, camera = generate_partial_view(mesh, view_axis=np.array([-1., .0, .0]), look_at=np.array([0, 0, 0.04]))
partial_depth_ar = np.asarray(partial_depth_img)
print(partial_depth_ar.shape)

In [None]:
# show the partial depth image (pixels with depth 0 are treated as missing below)
plt.imshow(partial_depth_ar)

In [None]:
def get_free_points(depth_img, K, min_depth=0., max_depth=1., num_steps=10, w_X_c=None):
    """Sample points in the observed free space in front of a depth image.

    For every pixel, ``num_steps`` candidate depths evenly spaced in
    (min_depth, max_depth] are back-projected through the intrinsics ``K``.
    A candidate is kept when its z is strictly smaller than the measured depth at
    that pixel. Pixels with depth 0 (no measurement) are treated as infinitely
    deep, so every candidate along their ray is kept.

    Args:
        depth_img: (h, w) depth image; 0 marks a missing measurement.
        K: camera intrinsic matrix (presumably 3x3), passed to project_depth_image.
        min_depth: lower end of the depth range (excluded).
        max_depth: upper end of the depth range (included).
        num_steps: number of depth samples per pixel.
        w_X_c: optional transform applied to the output with tr_pointcloud
            (camera-to-world in this notebook).

    Returns:
        (N, 3) array of free-space points, N <= num_steps * h * w, in the frame of
        project_depth_image (presumably the camera frame), or transformed by
        ``w_X_c`` when it is given.
    """
    h, w = depth_img.shape[0], depth_img.shape[1]
    depth_values = np.linspace(min_depth, max_depth, num_steps + 1)[1:]  # drop min_depth itself
    # One constant-depth "image" per depth sample: (num_steps, h, w).
    all_depths = np.tile(depth_values[:, None, None], (1, h, w))
    all_Ks = np.repeat(np.expand_dims(K, axis=0), num_steps, axis=0)
    # Project the depths.
    img_xyzs = project_depth_image(all_depths, all_Ks)  # (num_steps, h, w, 3)
    depths = img_xyzs[..., -1]  # (num_steps, h, w)
    # Keep samples whose z is closer than the measured depth. Missing pixels (0) count
    # as infinitely deep; np.where also promotes integer depth images (e.g. uint16)
    # to float, where writing inf into a copy of the original array would fail.
    depth_limit = np.where(depth_img == 0.0, np.inf, depth_img)
    mask = depths < depth_limit[None, :, :]
    pc_out = img_xyzs[mask]  # (N, 3) where N <= num_steps*h*w
    if w_X_c is not None:
        pc_out = tr_pointcloud(pc_out, T=w_X_c)
    return pc_out
    

In [None]:
free_points_w = get_free_points(partial_depth_ar, K=camera.intrinsic.intrinsic_matrix, min_depth=0.04, max_depth=0.2, num_steps=15, w_X_c=np.linalg.inv(camera.extrinsic))
free_points_w.shape

In [None]:
# sample some and visualize them
num_free_points = 3000 
free_points_w_sampled = free_points_w[np.random.choice(free_points_w.shape[0], num_free_points, replace=False)]

In [None]:
view_points(free_points_w_sampled)

In [None]:
# Back-project the partial depth image into a point cloud and express it in the world frame.
# (The original referenced an undefined `intrinsic_matrix`; the camera's intrinsics are used.)
all_partial_points_from_depth_img = project_depth_image(partial_depth_ar, K=camera.intrinsic.intrinsic_matrix)
# Keep only pixels with a valid (positive) depth.
valid_depth = all_partial_points_from_depth_img[..., -1] > 0
partial_points_from_depth_img = all_partial_points_from_depth_img[valid_depth]
print(partial_points_from_depth_img.shape)
partial_points_from_depth_img = tr_pointcloud(partial_points_from_depth_img, T=np.linalg.inv(camera.extrinsic))
view_points(partial_points_from_depth_img.reshape(-1, 3))

# Process the pointcloud

In [None]:
# Matrix form of the centering transform (presumably 4x4 homogeneous).
T_center = pose_to_matrix(center_tr)

In [None]:
# Move the partial observation and the free-space samples into the centered frame
# used by the rotation models below.
partial_pc_centered = tr_pointcloud(partial_pc_downsampled, T=T_center)
free_points_w_sampled_centered = tr_pointcloud(free_points_w_sampled, T=T_center)

In [None]:
class ModelRotation(nn.Module):
    """Produce copies of a fixed point cloud rotated about the z axis.

    Only the first three columns of ``model_pc`` are kept, stored as a float32
    tensor attribute ``model`` of shape (N, 3).
    """

    def __init__(self, model_pc):
        super().__init__()
        xyz = model_pc[..., :3]
        self.model = torch.tensor(xyz, dtype=torch.float32)  # (N, 3)

    def forward(self, theta):
        """Return ``self.model`` rotated by each angle in ``theta``, shape (..., N, 3)."""
        return rotate_model_zaxis(self.model, theta)
    

In [None]:
# Rotate the centered model by a few angles in [0, pi]; point_rotated[i] is the model at thetas[i].
model_rotation = ModelRotation(pc_centered)

thetas = torch.tensor([0., np.pi*0.25, np.pi*0.5, np.pi*0.75, np.pi], dtype=torch.float32)

point_rotated = model_rotation(thetas)

point_rotated.shape  # expected: (len(thetas), num_model_points, 3)

In [None]:
def view_fit_with_partial_view(pc):
    """Plot a candidate model pose (black) against the centered partial observation (red).

    Args:
        pc: (N, >=3) points of the candidate model pose, e.g. ``point_rotated[i]``.

    Returns:
        The figure returned by ``view_points_groups``.
    """
    # Plot the argument (not the global pc_centered) so each call shows the requested pose.
    fig = view_points_groups(points_xyz_list=[pc[:, :3], partial_pc_centered[:, :3]],
                             colors_list=[np.array([0, 0., 0.]), np.array([1., 0., 0.])],
                             marker_size_list=[5, 5])
    return fig

In [None]:
view_fit_with_partial_view(point_rotated[0])  # theta = 0

In [None]:
view_fit_with_partial_view(point_rotated[1])  # theta = pi/4

In [None]:
view_fit_with_partial_view(point_rotated[2])  # theta = pi/2

In [None]:
view_fit_with_partial_view(point_rotated[3])  # theta = 3*pi/4

In [None]:
view_fit_with_partial_view(point_rotated[4])  # theta = pi

In [None]:
def view_fit_with_partial_and_free_sapce_view(pc):
    """Plot a candidate pose with the observations, all in the centered frame.

    Black: the candidate model ``pc`` (first three columns); red: the partial
    surface observation; blue: the sampled free-space points.
    Returns the figure from ``view_points_groups``.
    """
    groups = [pc[:, :3], partial_pc_centered[:, :3], free_points_w_sampled_centered]
    colors = [np.array([0, 0., 0.]), np.array([1., 0., 0.]), np.array([.0, 0., 1.])]
    return view_points_groups(points_xyz_list=groups,
                              colors_list=colors,
                              marker_size_list=[5] * len(groups))

In [None]:
view_fit_with_partial_and_free_sapce_view(point_rotated[1])  # theta = pi/4

In [None]:
view_fit_with_partial_and_free_sapce_view(point_rotated[4])  # theta = pi

## Build the SDF to get free space and occupied space information

In [None]:
# Signed distance field of the mug mesh, in the mesh file's own (uncentered) frame.
obj = pv.MeshObjectFactory(YCB_MUG_PATH)
sdf = pv.MeshSDF(obj)
# Draw a horizontal slice: x and y span [-0.15, 0.15] and z is fixed at 0
# (presumably one [min, max] row per axis — see pv.draw_sdf_slice).
query_range = np.array([
    [-0.15, 0.15],
    [-0.15, 0.15],
    [0.0, 0.0],
])
_ = pv.draw_sdf_slice(sdf, query_range)

In [None]:
# Sanity check with the free-space samples (world frame, which is the mesh frame the SDF
# was built in): they should all be outside the mesh (sdf > EPS); sdf < -EPS counts as inside.
EPS = 1e-5
sdf_val, sdf_grad = sdf(free_points_w_sampled)
print('num_inside_points:', torch.sum(sdf_val < -EPS))
print('num_outside_points:', torch.sum(sdf_val > EPS))
print('num_surface_points:', torch.sum(torch.abs(sdf_val) < EPS))

In [None]:
torch.sum(sdf_val < 0)

In [None]:
# Same check for the observed surface points: they should lie on the surface (|sdf| < EPS).
EPS = 1e-4
sdf_val, sdf_grad = sdf(partial_pc_downsampled[...,:3])
print('num_inside_points:', torch.sum(sdf_val < -EPS))
print('num_outside_points:', torch.sum(sdf_val > EPS))
print('num_surface_points:', torch.sum(torch.abs(sdf_val) < EPS))

In [None]:
def sdf_with_gradients_wrapper(sdf_function):
    """Wrap an SDF that returns ``(values, gradients)`` as a differentiable function.

    ``sdf_function(points)`` must return the SDF values and their gradient with
    respect to ``points`` (one extra trailing coordinate axis). The returned
    callable takes ``points`` and returns only the values; in the backward pass
    the upstream gradient is multiplied by the stored SDF gradient (chain rule),
    so autograd can flow through the SDF query.
    """

    class SDFWithGradients(torch.autograd.Function):
        @staticmethod
        def forward(ctx, points):
            values, gradients = sdf_function(points)
            # Keep d(sdf)/d(points) for the backward pass.
            ctx.save_for_backward(gradients)
            return values

        @staticmethod
        def backward(ctx, grad_outputs):
            (gradients,) = ctx.saved_tensors
            # grad_outputs has the shape of values; broadcast it over the coordinate axis.
            return grad_outputs.unsqueeze(-1) * gradients

    return SDFWithGradients.apply
        
    

In [None]:
class SDFRotationModel(nn.Module):
    """Signed distance to a mesh that is optionally re-framed and then rotated about z.

    The mesh SDF is defined in the mesh file's own frame. A query point given in
    the working frame is mapped back into the mesh frame by rotating it by
    ``-theta`` about z and then applying ``T = inv(init_tr)``.

    Args:
        mesh_path: path to the mesh file (loaded through ``pv.MeshObjectFactory``).
        init_tr: optional (4, 4) tensor taking mesh-frame coordinates to the
            working frame (in this notebook, the centering transform). Its inverse
            is registered as the buffer ``T``; ``T`` is None when ``init_tr`` is None.
    """

    def __init__(self, mesh_path, init_tr=None):
        super().__init__()
        obj = pv.MeshObjectFactory(mesh_path)
        self.sdf = sdf_with_gradients_wrapper(pv.MeshSDF(obj))
        T = None if init_tr is None else torch.linalg.inv(init_tr)
        # A buffer rather than a plain attribute, so module.to(device) moves it too.
        self.register_buffer('T', T)

    def forward(self, points, theta):
        """Evaluate the SDF of the rotated mesh.

        Args:
            points: (N, 3) or (..., N, 3) query points in the working frame.
            theta: (...) tensor of rotation angles about z, in radians.

        Returns:
            SDF values, expected shape (..., N), as returned by the wrapped SDF.
        """
        query_points = self._get_query_points(points, theta)  # (..., N, 3)
        sdf_val = self.sdf(query_points)  # (..., N)
        return sdf_val

    def _get_query_points(self, points, theta):
        """Map working-frame points into the mesh frame, once per angle.

        Args:
            points: (N, 3) or (..., N, 3) tensor.
            theta: (...) tensor of angles.

        Returns:
            (..., N, 3) tensor: ``points`` rotated by ``-theta`` about z, then
            transformed by ``T`` when it is set.
        """
        # Rotate by -theta: undo the model rotation to get back to the reference frame.
        R = self._build_rotation_matrix(-theta)  # (..., 3, 3)
        query_points = torch.einsum('...ij,...kj->...ki', R, points)
        if self.T is not None:
            query_points = self._tr_points(query_points, T=self.T)
        return query_points

    def _tr_points(self, points, T):
        """Apply the (4, 4) homogeneous transform ``T`` to (..., 3) points."""
        points_tr = tr_points(points, T)
        return points_tr

    def _build_rotation_matrix(self, theta):
        """Return the (..., 3, 3) z-axis rotation matrices for angles ``theta`` of shape (...)."""
        ctheta = torch.cos(theta)
        stheta = torch.sin(theta)
        ones = torch.ones_like(theta, dtype=theta.dtype)
        zeros = torch.zeros_like(theta, dtype=theta.dtype)
        R1x = torch.stack([ctheta, -stheta, zeros], dim=-1)
        R2x = torch.stack([stheta, ctheta, zeros], dim=-1)
        R3x = torch.stack([zeros, zeros, ones], dim=-1)
        R = torch.stack([R1x, R2x, R3x], dim=-2)
        return R
    

In [None]:
center_tr

In [None]:
# Torch copy of the centering transform; the model stores its inverse to map
# centered-frame query points back into the mesh file's frame.
init_tr = torch.tensor(pose_to_matrix(center_tr), dtype=torch.float32)
sdf_rotation_model = SDFRotationModel(YCB_MUG_PATH, init_tr=init_tr)

In [None]:
# Smoke test: 50 random points in [0, 1)^3 per angle in thetas.
# Note: this overwrites the `points` array created in the centering check above.
points = torch.rand((len(thetas),50,3), dtype=torch.float32)
# thetas = torch.tensor([0, 0.25*np.pi, 0.5*pi, 0.], dtype=torch.float32)
sdf_val = sdf_rotation_model(points, thetas)
sdf_val.shape  # expected: (len(thetas), 50)

In [None]:
class CustomLoss(torch.nn.Module):
    """Pose-fitting cost for a rotation-parameterized SDF model.

    For each candidate angle the cost is

        surface_loss + free_weight * free_loss

    where
      * surface_loss is the mean |SDF| at the observed surface points, with
        values satisfying |SDF| <= eps treated as zero (dead band), and
      * free_loss is the mean penetration depth max(0, -SDF) of points observed
        to be free space (they should never be inside the object).

    Args:
        free_space_points: (M, 3) tensor of points observed to be empty.
        surface_points: (P, 3) tensor of observed surface points.
        eps: dead band of the surface term (default 3.5e-4).
        free_weight: weight of the free-space term (default 100).
    """

    def __init__(self, free_space_points, surface_points, eps=3.5e-4, free_weight=100.0):
        super().__init__()
        self.eps = eps
        self.free_weight = free_weight
        self.free_space_points = free_space_points  # (M, 3)
        self.surface_points = surface_points  # (P, 3)

    def forward(self, sdf_model, thetas):
        """Return the per-angle loss.

        Args:
            sdf_model: callable ``sdf_model(points, thetas)`` returning SDF values of
                shape (N, num_points).
            thetas: (N,) tensor of candidate angles.

        Returns:
            (N,) tensor of losses.
        """
        surface_loss = self._compute_surface_loss(sdf_model, thetas)  # (N,)
        free_loss = self._compute_free_loss(sdf_model, thetas)  # (N,)
        return surface_loss + self.free_weight * free_loss

    def _compute_surface_loss(self, sdf_model, thetas):
        """Mean |SDF| of the surface points per angle, ignoring |SDF| <= eps. Shape (N,)."""
        sdf_surface_val = sdf_model(self.surface_points, thetas)  # (N, P)
        # Zero out entries with |sdf| <= eps without in-place modification
        # (F.threshold keeps x only where x > threshold).
        sdf_surface_val = F.threshold(sdf_surface_val, threshold=self.eps, value=0) - F.threshold(-sdf_surface_val, threshold=self.eps, value=0)
        return torch.mean(torch.abs(sdf_surface_val), dim=1)  # (N,)

    def _compute_free_loss(self, sdf_model, thetas):
        """Mean penetration depth of the free-space points per angle. Shape (N,)."""
        sdf_free_val = sdf_model(self.free_space_points, thetas)  # (N, M)
        # Only negative SDF (free points inside the object) is penalized: |min(0, s)| == relu(-s).
        return torch.mean(F.relu(-sdf_free_val), dim=1)  # (N,)

In [None]:
# Observations in the centered frame as float32 tensors, and the pose-fitting loss built from them.
free_points = torch.tensor(free_points_w_sampled_centered, dtype=torch.float32)
surface_points = torch.tensor(partial_pc_centered[...,:3], dtype=torch.float32)
loss_fn = CustomLoss(free_points, surface_points)

In [None]:
thetas

In [None]:
# Loss for each of the angles defined earlier (lower cost = better fit).
loss_values = loss_fn(sdf_rotation_model, thetas)
print('Thetas:', thetas)
print(loss_values)

In [None]:
# Add small perturbations around 0 (+/- 0.05*pi) to see how the loss changes near it.
# Note: this redefines `thetas`, which the following cells use.
thetas = torch.tensor([0., 0.05*np.pi, -0.05*np.pi, np.pi*0.25, np.pi*0.5, np.pi*0.75, np.pi], dtype=torch.float32)

loss_values = loss_fn(sdf_rotation_model, thetas)
print('Thetas:', thetas)
print(loss_values)

In [None]:
point_rotated = model_rotation(thetas)

In [None]:
view_fit_with_partial_and_free_sapce_view(point_rotated[0])  # theta = 0

In [None]:
view_fit_with_partial_and_free_sapce_view(point_rotated[2])  # theta = -0.05*pi

In [None]:
query_range = np.array([
    [-0.15, 0.15],
    [-0.15, 0.15],
    [0.0235, 0.0235],
])
_ = pv.draw_sdf_slice(sdf, query_range)

In [None]:
query_range = np.array([
    [-0.15, 0.15],
    [-0.15, 0.15],
    [0.0235, 0.0235],
])

class SDFRotated(nn.Module):
    def __init__(self, sdf_rot_model, theta):
        super().__init__()
        self.sdf_rot_model = sdf_rot_model
        self.theta = theta
        
    def forward(self, points):
        return self.sdf_rot_model(points, self.theta)
        
        
sdf_rot = SDFRotated(sdf_rotation_model, torch.tensor([np.pi*.75], dtype=torch.float32))
_ = pv.draw_sdf_slice(sdf_rot, query_range)

In [None]:
points_centered = torch.tensor(pc_centered[:, :3], dtype=torch.float32)


def debug_theta(theta):
    """Show model, surface and free-space points mapped back to the mesh frame for one angle.

    Each point set goes through ``sdf_rotation_model._get_query_points`` (rotate by
    -theta, then apply the model's T) and is drawn over the original uncentered
    point cloud ``pc`` (black): model in green, observed surface in red, free
    space in blue. Only the first angle in ``theta`` is shown.
    """
    to_reference = sdf_rotation_model._get_query_points
    free_points_ref = to_reference(free_points, theta)[0]
    surface_points_ref = to_reference(surface_points, theta)[0]
    points_ref = to_reference(points_centered, theta)[0]
    groups = [pc[:, :3], points_ref, surface_points_ref, free_points_ref]
    colors = [np.array([0, 0., 0.]), np.array([0, 1., 0.]), np.array([1., 0., 0.]), np.array([.0, 0., 1.])]
    return view_points_groups(points_xyz_list=groups,
                              colors_list=colors,
                              marker_size_list=[5, 5, 5, 5])
    

In [None]:
debug_theta(torch.tensor([0.], dtype=torch.float32))  # theta = 0

In [None]:
debug_theta(torch.tensor([0.5*np.pi], dtype=torch.float32))  # theta = pi/2

In [None]:
# Plot the total loss and its two components across a sweep of angles.

theta_range = torch.linspace(-np.pi, np.pi, 100)
with torch.no_grad():
    free_loss = loss_fn._compute_free_loss(sdf_rotation_model, theta_range)
    surface_loss = loss_fn._compute_surface_loss(sdf_rotation_model, theta_range)
    loss_values = loss_fn(sdf_rotation_model, theta_range)

fig, ax = plt.subplots()
ax.plot(theta_range, loss_values, label='loss')
ax.plot(theta_range, free_loss, label='free_loss')
ax.plot(theta_range, surface_loss, label='surface_loss')
ax.set(xlabel='theta [rad]', ylabel='loss', title='Loss vs. rotation angle')
ax.legend();

# Testing SVGD (Stein Variational Gradient Descent)

# Convert the cost function into an unnormalized probability: $p(\theta) \propto \exp(-\mathrm{cost}(\theta))$

In [None]:
class RBF(torch.nn.Module):
    """Gaussian (RBF) kernel matrix between two particle sets.

    k(x, y) = exp(-||x - y||^2 / (2 sigma^2)). When ``sigma`` is None, the
    bandwidth is chosen on every call with the median heuristic
    sigma^2 = median(||x_i - y_j||^2) / (2 log(n + 1)), n = number of rows of X.
    1-D inputs are treated as column vectors (n, 1).
    """

    def __init__(self, sigma=None):
        super().__init__()
        self.sigma = sigma

    def forward(self, X, Y):
        X = X.unsqueeze(-1) if X.dim() == 1 else X
        Y = Y.unsqueeze(-1) if Y.dim() == 1 else Y

        # ||x - y||^2 = -2 x.y + x.x + y.y, built from Gram matrices.
        cross = X.matmul(X.t()) if False else X.matmul(Y.t())
        sq_x = X.matmul(X.t()).diag().unsqueeze(1)
        sq_y = Y.matmul(Y.t()).diag().unsqueeze(0)
        dnorm2 = -2 * cross + sq_x + sq_y

        if self.sigma is not None:
            sigma = self.sigma
        else:
            # Median heuristic, computed in numpy to get a true median.
            median_sq = np.median(dnorm2.detach().cpu().numpy())
            sigma = np.sqrt(median_sq / (2 * np.log(X.size(0) + 1))).item()

        gamma = 1.0 / (1e-8 + 2 * sigma ** 2)
        return (-gamma * dnorm2).exp()


# Reusable kernel instance (median-heuristic bandwidth).
K = RBF()

In [None]:
class SVGD:
    """Stein Variational Gradient Descent driven by a torch optimizer.

    Args:
        P: target exposing ``log_prob(X)``, differentiable with respect to X.
        K: kernel called as ``K(X, Y)`` returning the (n, n) kernel matrix.
        optimizer: torch optimizer over the particle tensor passed to ``step``.
    """

    def __init__(self, P, K, optimizer):
        self.P = P
        self.K = K
        self.optim = optimizer

    def phi(self, X):
        """Return the SVGD update direction for particles ``X`` (same shape as X).

        The direction is the kernel-weighted average of the scores (grad log p)
        plus a repulsive term, minus the gradient of sum_j K(X_i, X_j) with
        respect to the first argument, all divided by the number of particles.
        """
        particles = X.detach().requires_grad_(True)

        log_p = self.P.log_prob(particles)
        score = autograd.grad(log_p.sum(), particles)[0]  # grad log p at each particle

        # Only the first argument carries gradient.
        kernel = self.K(particles, particles.detach())
        repulsion = -autograd.grad(kernel.sum(), particles)[0]

        n = particles.size(0)
        return (kernel.detach().matmul(score) + repulsion) / n

    def step(self, X):
        """Take one optimizer step moving ``X`` along ``phi(X)``."""
        self.optim.zero_grad()
        # Optimizers minimize, so store the negated ascent direction as the gradient.
        X.grad = -self.phi(X)
        self.optim.step()

In [None]:
class CostProb:
    """Unnormalized log-density built from a cost: log p(x) = -cost(x)."""

    def __init__(self, cost_fnc):
        self.cost = cost_fnc

    def log_prob(self, x):
        """Return ``-cost(x)``."""
        return -self.cost(x)
        

In [None]:
def cost_function_wrapper(thetas):
    """Pose cost for each angle in ``thetas``: ``loss_fn`` applied to ``sdf_rotation_model``."""
    cost = loss_fn(sdf_rotation_model, thetas)
    return cost

# Run on CPU. NOTE(review): switching to CUDA would also require the SDF model and the
# loss's point tensors on that device — confirm MeshSDF supports it before changing.
device = torch.device('cpu')

# SVGD settings. The seed makes the initial particle draw, and therefore the run, reproducible.
SVGD_SEED = 0
n = 20  # number of particles
NUM_SVGD_STEPS = 1000
SVGD_LR = 3e-3

rng = np.random.default_rng(SVGD_SEED)
X_init = torch.tensor(rng.uniform(-np.pi, np.pi, (n,)), dtype=torch.float32).to(device)

X = X_init.clone()  # particles, updated by the optimizer on every step
cost_prob = CostProb(cost_function_wrapper)
svgd = SVGD(cost_prob, K, optim.Adam([X], lr=SVGD_LR))
for _ in tqdm(range(NUM_SVGD_STEPS)):
    svgd.step(X)

In [None]:
def plot_particles(X_star, X_init):
    """Plot the loss curve with the initial (left) and final (right) particles on it.

    Args:
        X_star: (n,) tensor of particle angles after optimization.
        X_init: (n,) tensor of initial particle angles.
    """
    with torch.no_grad():
        costs_init = cost_function_wrapper(X_init)
        costs_star = cost_function_wrapper(X_star)
        # Span at least [-pi, pi], extended to cover any particle that left that range.
        min_theta = min(-np.pi, float(X_star.min()), float(X_init.min()))
        max_theta = max(np.pi, float(X_star.max()), float(X_init.max()))
        theta_range = torch.linspace(min_theta, max_theta, 100)
        loss_values = loss_fn(sdf_rotation_model, theta_range)

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    panels = ((axes[0], X_init, costs_init, 'initial particles'),
              (axes[1], X_star, costs_star, 'final particles'))
    for ax, particles, costs, title in panels:
        ax.plot(theta_range, loss_values, label='loss')
        # detach() so particles that carry autograd state can still be plotted.
        ax.scatter(particles.detach().cpu(), costs.detach().cpu(), color='red', label='particles')
        ax.set(xlabel='theta [rad]', ylabel='loss', title=title)
        ax.legend()
    

In [None]:
plot_particles(X, X_init)

In [None]:
# plot the loss values acrross a set of thetas

theta_range = torch.linspace(-np.pi, np.pi, 100)
loss_values = loss_fn(sdf_rotation_model, theta_range)
cost_fun_values = torch.exp(-loss_values)
plt.plot(theta_range, cost_fun_values, label='loss')
plt.legend()