In [1]:
""" same p1-p4 for all the patches """

' same p1-p4 for all the patches '

In [2]:
import os
import numpy as np
import open3d as o3d
import matplotlib.pyplot as plt
import matplotlib.cm as mpl
from scipy.interpolate import griddata

import time
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from functools import partial
from jax import random
import jaxopt
from jaxopt import GaussNewton

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:
class fft_fit():
    """Least-squares fit of a cos/sin feature model to the heights of a point cloud.

    The predicted height at a point ``(x, y)`` is a linear combination of
    ``cos(x*p1 + y*p2)`` and ``sin(x*p3 + y*p4)`` features. The frequency
    vectors ``p1..p4`` are supplied by the caller and stay fixed; only the
    linear weights (``lambdas``) are optimised, using jaxopt's Gauss-Newton
    solver. No FFT is computed despite the class name.
    """

    def __init__(self, maxiter=10, reg_weight=0.15):
        """Build the Gauss-Newton solver.

        Args:
            maxiter: maximum number of Gauss-Newton iterations (default 10).
            reg_weight: factor applied to the weights that are appended to the
                residual vector in ``error_fn`` (default 0.15).
        """
        # Set before the solver is built because error_fn reads it while tracing.
        # The methods are jitted with ``self`` static, so changing this attribute
        # after the first call does not affect already-compiled code; create a
        # new instance instead.
        self.reg_weight = reg_weight
        self.gn = GaussNewton(residual_fun=self.error_fn,maxiter=maxiter,verbose=False,implicit_diff=True,jit=True )

    @partial(jit, static_argnums = (0,))
    def get_kernel(self, xy, p1, p2, p3, p4, lambdas):
        """Evaluate the model heights at the given planar coordinates.

        Args:
            xy: array with two columns (x, y), one row per point.
            p1, p2: frequency vectors stacked as the two rows of the cosine
                projection matrix.
            p3, p4: frequency vectors stacked as the two rows of the sine
                projection matrix.
            lambdas: column of weights with one row per feature, i.e.
                ``len(p1) + len(p3)`` rows.

        Returns:
            1-D array with one predicted height per row of ``xy``.
        """
        c_term = jnp.vstack((p1, p2))
        s_term = jnp.vstack((p3, p4))
        # Cosine features first, sine features second: the weight order must match.
        kernel = jnp.hstack((jnp.cos(xy@c_term), jnp.sin(xy@s_term)))

        pred_z = (kernel@lambdas)[:,0]

        return pred_z

    @partial(jit, static_argnums = (0,))
    def solve(self,lam,p1,p2,p3,p4,pcd):
        """Run Gauss-Newton from the initial weights ``lam`` on one point cloud.

        Returns:
            The optimised weight vector; the solver state is discarded.
        """
        lamdas, state = self.gn.run(lam,p1,p2,p3,p4,pcd)
        return lamdas

    @partial(jit, static_argnums = (0,))
    def error_fn(self, lambdas, p1, p2, p3, p4, pcd):
        """Residual vector minimised by the solver.

        Args:
            lambdas: flat weight vector (reshaped to a column internally).
            p1, p2, p3, p4: frequency vectors, see ``get_kernel``.
            pcd: point array; the first two columns are used as (x, y) and the
                last column as the target height.

        Returns:
            1-D array: per-point height errors followed by the weights scaled
            by ``reg_weight``.
        """
        lambdas = lambdas.reshape(-1,1)
        xy = pcd[:, :2]
        gt_z = pcd[:, -1]
        pred_z = self.get_kernel(xy, p1, p2, p3, p4, lambdas)
        # Appending the scaled weights penalises large weights in the squared norm.
        error = jnp.hstack((pred_z - gt_z, self.reg_weight*lambdas.squeeze(axis=-1)))

        return error

In [4]:
data_dir = "./src/terrain_mlp/fft_new/data/concatenated-cloud/run-0"

# Every regular .npz file present in the run directory.
npz_names = [
    name for name in os.listdir(data_dir)
    if name.endswith(".npz") and os.path.isfile(os.path.join(data_dir, name))
]
end_frame = len(npz_names)
print("total samples: ", end_frame)


def _frame_index(name):
    """Return i for a file named data{i}.npz, or None for any other name."""
    stem = name[:-len(".npz")]
    digits = stem[len("data"):]
    if stem.startswith("data") and digits.isdecimal():
        return int(digits)
    return None


# Build the list from the files that actually exist, in numeric order.
# Generating names with range(1, end_frame) produced only end_frame - 1 names,
# so one file was always skipped (the first if numbering starts at 0, the last
# if it starts at 1), and any gap in the numbering would fail at load time.
trainingfile_names = sorted(
    (name for name in npz_names if _frame_index(name) is not None),
    key=_frame_index,
)
trainingfile_paths = [os.path.join(data_dir, file) for file in trainingfile_names]

total samples:  466


In [5]:
pcd_data = []
N = 35000  # number of points kept per cloud, so every patch has the same shape

# Seeded generator: the random sub-sampling below is repeatable across kernel restarts.
rng = np.random.default_rng(0)

for file_path in trainingfile_paths:
    # Context manager closes the .npz archive once the array has been read.
    with np.load(file_path) as data:
        cloud = data['cloud']

    pc = o3d.geometry.PointCloud()
    pc.points = o3d.utility.Vector3dVector(cloud)
    pc = pc.voxel_down_sample(voxel_size=0.06)
    pcd = np.asarray(pc.points, dtype=np.float32)

    # Clouds with fewer than N points after down-sampling are skipped.
    # ">=" keeps clouds with exactly N points; drawing N indices without
    # replacement is valid in that case.
    if pcd.shape[0] >= N:
        indices = rng.choice(len(pcd), N, replace=False)
        sampled_pcd = pcd[indices]
        pcd_data.append(sampled_pcd)

# reshape leaves a non-empty stack unchanged and gives (0, N, 3) instead of
# (0,) when no cloud qualified.
pcd_array = np.array(pcd_data, dtype=np.float32).reshape(-1, N, 3)
print(pcd_array.shape)

# Replace NaN and +/-inf coordinates with 0.0.
xyz = np.nan_to_num(pcd_array,posinf=0.0,neginf=0.0)

# Disabled alternative: shift x and y of every patch so their minimum is 0.
# batch_size = xyz.shape[0]
# patch_size = xyz.shape[1]
# pcd_shift = np.zeros((batch_size,patch_size,3))
# temp = np.min(xyz[:, :, :2], axis=1)
# pcd_shift[:,:,0] = xyz[:,:,0]-np.expand_dims(temp[:,0], axis=-1)
# pcd_shift[:,:,1] = xyz[:,:,1]-np.expand_dims(temp[:,1], axis=-1)
# pcd_shift[:,:,2] = xyz[:,:,2]

pcd_shift = xyz

(185, 35000, 3)


In [6]:
# fig, ax = plt.subplots()
# plt.rcParams["font.weight"] = "bold"
# ax.set_xlabel('x (m)')
# ax.set_ylabel('y (m)')
# ax.scatter(pcd_shift[10,:,0],pcd_shift[10,:,1])
# plt.show()

# print(pcd_shift[0,:,0].min())
# print(pcd_shift[0,:,1].min())

In [7]:
pcd_jax = jnp.asarray(pcd_shift)
num_params = 100
# One weight per cosine feature plus one per sine feature.
num_lamdas = 2 * num_params

# Fitted weights, one row per point-cloud patch.
lam_data = jnp.zeros((pcd_jax.shape[0],num_lamdas))

key1 = jax.random.PRNGKey(0)
key2 = jax.random.PRNGKey(1)
key3 = jax.random.PRNGKey(2)
key4 = jax.random.PRNGKey(3)
key5 = jax.random.PRNGKey(4)

# aux_vars: frequency vectors drawn once and shared by all the patches
p1 = random.normal(key1, shape=(num_params,), dtype=jnp.float32)
p2 = random.normal(key2, shape=(num_params,), dtype=jnp.float32)
p3 = random.normal(key3, shape=(num_params,), dtype=jnp.float32)
p4 = random.normal(key4, shape=(num_params,), dtype=jnp.float32)

fft_obj = fft_fit()

for s in range(pcd_jax.shape[0]):

    # key5 is carried to the next iteration; subkey5 is consumed here. Sampling
    # with key5 itself would reuse a key that is split again afterwards.
    key5,subkey5 = jax.random.split(key5)

    # opt_vars: random initial weights for this patch
    lam = random.normal(subkey5, shape=(num_lamdas,), dtype=jnp.float32)

    t1 = time.time()
    # JAX dispatches asynchronously; wait for the result so the timer covers
    # the actual solve rather than only the dispatch.
    lamdas = fft_obj.solve(lam,p1,p2,p3,p4,pcd_jax[s,:,:]).block_until_ready()
    t2 = time.time()
    print("case = ", s, "Time to construct FFT = ", t2-t1)

    # Re-evaluate the fitted model on the patch and report the clipped height error.
    z_new = fft_obj.get_kernel(pcd_jax[s,:,0:2], p1, p2, p3, p4, lamdas.reshape(-1,1))
    z_new = jnp.clip(z_new,-6.0,6.0)

    error = z_new - pcd_jax[s,:,2]
    print("error", jnp.linalg.norm(error))

    lam_data = lam_data.at[s,:].set(lamdas)


case =  0 Time to construct FFT =  1.7108113765716553
error 7.561141
case =  1 Time to construct FFT =  1.124314308166504
error 7.541844
case =  2 Time to construct FFT =  1.1340575218200684
error 7.5473533
case =  3 Time to construct FFT =  1.1400468349456787
error 7.5563173
case =  4 Time to construct FFT =  1.1503703594207764
error 7.803787
case =  5 Time to construct FFT =  1.1491847038269043
error 19.975752
case =  6 Time to construct FFT =  1.1524317264556885
error 7.1939597
case =  7 Time to construct FFT =  1.140822410583496
error 18.183222
case =  8 Time to construct FFT =  1.139664888381958
error 8.260234
case =  9 Time to construct FFT =  1.1514763832092285
error 8.62023
case =  10 Time to construct FFT =  1.1397755146026611
error 9.192831
case =  11 Time to construct FFT =  1.1356730461120605
error 10.8666
case =  12 Time to construct FFT =  1.1367783546447754
error 10.297227
case =  13 Time to construct FFT =  1.146758794784546
error 8.392444
case =  14 Time to construct F

In [8]:
# cn = np.random.randint(0, 25)
# min_xy = jnp.min(pcd_jax[cn,:,:],axis=0)
# max_xy = jnp.max(pcd_jax[cn,:,:],axis=0)
# min_x = min_xy[0]
# max_x = max_xy[0]
# min_y = min_xy[1]
# max_y = max_xy[1]
# # create test data
# xs = jnp.arange(min_x,max_x,0.2, dtype=jnp.float32)
# ys = jnp.arange(min_y,max_y,0.2, dtype=jnp.float32)
# X,Y = jnp.meshgrid(xs,ys)
# xp = X.flatten()
# yp = Y.flatten()
# zp = jnp.zeros(xp.shape, dtype=jnp.float32)
# xyz_test = jnp.vstack((xp,yp,zp)).T

# z_new = fft_obj.get_kernel(xyz_test[:, :2], p1_data[cn,:], p2_data[cn,:], p3_data[cn,:], p4_data[cn,:], lam_data[cn,:].reshape(-1,1))
# z_new = jnp.clip(z_new,-6.0,6.0)

# pc_orig = pcd_jax[cn,:,:]
# pc_new = xyz_test

# pc_new = pc_new.at[:,-1].set(z_new)
# gt_pc = o3d.geometry.PointCloud()
# gt_pc.points = o3d.utility.Vector3dVector(pc_orig)
# t_pc = o3d.geometry.PointCloud()
# t_pc.points = o3d.utility.Vector3dVector(pc_new)
# gt_pc.paint_uniform_color([1, 0, 0])
# t_pc.paint_uniform_color([0, 1, 0])
# o3d.visualization.draw_geometries([gt_pc, t_pc])

In [None]:
# lam_data = np.asarray(lam_data)
# p1_data = np.asarray(p1)
# p2_data = np.asarray(p2)
# p3_data = np.asarray(p3)
# p4_data = np.asarray(p4)

# np.savez("./src/terrain_mlp/fft_new/data/pcd_fit_data_new",pcd=pcd_shift,lam_data=lam_data,p1_data=p1_data,p2_data=p2_data,p3_data=p3_data,p4_data=p4_data)