In [None]:
# The frequency vectors p1-p4 are sampled once and shared by every point-cloud
# patch; only the per-patch weights (lambdas) are optimized.
""" same p1-p4 for all the patches """

In [None]:
import os
import numpy as np
import open3d as o3d

import time
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, random
from functools import partial
from jaxopt import GaussNewton

In [None]:
class fft_fit:
    """Fit a random-Fourier-feature height model z = f(x, y) to a point-cloud patch.

    The model is ``z ~= [cos(xy @ [p1; p2]), sin(xy @ [p3; p4])] @ lambdas``.
    The frequency vectors ``p1``-``p4`` are supplied by the caller and held
    fixed; only the weights ``lambdas`` are optimized by a jaxopt Gauss-Newton
    solver, with an L2 penalty ``reg_weight * lambdas`` appended to the residual.

    NOTE: the methods are jitted with ``self`` as a static argument, so ``self``
    is hashed by identity and attributes read inside jitted code (``reg_weight``)
    are baked into the compiled trace. Changing such an attribute on an instance
    after its first call has no effect for that instance.
    """

    def __init__(self, maxiter=10, reg_weight=0.15):
        """
        Args:
            maxiter: maximum number of Gauss-Newton iterations per ``solve`` call.
            reg_weight: scale of the regularisation residual ``reg_weight * lambdas``.
        """
        self.reg_weight = reg_weight
        self.gn = GaussNewton(residual_fun=self.error_fn, maxiter=maxiter,
                              verbose=False, implicit_diff=True, jit=True)

    @partial(jit, static_argnums=(0,))
    def get_kernel(self, xy, p1, p2, p3, p4, lambdas):
        """Evaluate the Fourier height model at planar points.

        Despite its name, this returns predicted heights, not the feature matrix.

        Args:
            xy: (N, 2) planar coordinates.
            p1, p2, p3, p4: (P,) frequency vectors. (p1, p2) weight x and y inside
                the cosine features; (p3, p4) do the same inside the sine features.
            lambdas: feature weights, either flat (2P,) or a column (2P, 1);
                cosine weights come first.

        Returns:
            (N,) predicted heights.
        """
        cos_freqs = jnp.vstack((p1, p2))  # (2, P)
        sin_freqs = jnp.vstack((p3, p4))  # (2, P)
        features = jnp.hstack((jnp.cos(xy @ cos_freqs),
                               jnp.sin(xy @ sin_freqs)))  # (N, 2P)
        if jnp.ndim(lambdas) == 1:
            # Flat weights, i.e. the layout returned by ``solve``.
            return features @ lambdas
        # Column weights: keep the original behaviour of returning column 0.
        return (features @ lambdas)[:, 0]

    @partial(jit, static_argnums=(0,))
    def solve(self, lam, p1, p2, p3, p4, pcd):
        """Optimize the feature weights for one patch, starting from ``lam``.

        Args:
            lam: initial weights (2P,).
            p1, p2, p3, p4: fixed (P,) frequency vectors.
            pcd: patch points; see ``error_fn``.

        Returns:
            The optimized weights. The solver state is discarded.
        """
        lambdas, _state = self.gn.run(lam, p1, p2, p3, p4, pcd)
        return lambdas

    @partial(jit, static_argnums=(0,))
    def error_fn(self, lambdas, p1, p2, p3, p4, pcd):
        """Residual vector minimized by the Gauss-Newton solver.

        Args:
            lambdas: feature weights; any shape that flattens to (2P,).
            p1, p2, p3, p4: fixed (P,) frequency vectors.
            pcd: (N, D) points. x and y are taken from columns 0-1 and z from the
                last column (D is presumably 3; confirm for other callers).

        Returns:
            (N + 2P,) residuals: height errors followed by ``reg_weight * lambdas``.
        """
        flat_lambdas = jnp.reshape(lambdas, (-1,))
        xy = pcd[:, :2]
        gt_z = pcd[:, -1]
        pred_z = self.get_kernel(xy, p1, p2, p3, p4, flat_lambdas)
        return jnp.hstack((pred_z - gt_z, self.reg_weight * flat_lambdas))

In [None]:
data_dir = "./src/terrain_mlp/fft_new/data/concatenated-cloud/run-0"


def _frame_index(name):
    """Return i for a file named ``data<i>.npz`` (i = decimal digits), else None."""
    if not (name.startswith("data") and name.endswith(".npz")):
        return None
    digits = name[len("data"):-len(".npz")]
    return int(digits) if digits.isdecimal() else None


# Use the frames that actually exist, ordered by numeric index (data2 before
# data10). Inferring names from the file count would drop a frame when
# numbering starts at 0 and would crash on gaps in the numbering.
trainingfile_names = sorted(
    (name for name in os.listdir(data_dir)
     if _frame_index(name) is not None
     and os.path.isfile(os.path.join(data_dir, name))),
    key=_frame_index,
)
end_frame = len(trainingfile_names)
print("total samples: ", end_frame)

trainingfile_paths = [os.path.join(data_dir, file) for file in trainingfile_names]

In [None]:
N = 35000            # points kept per frame (fixed so all frames stack into one array)
VOXEL_SIZE = 0.06    # voxel edge length used for down-sampling
SEED = 0             # makes the random point sub-sampling reproducible
rng = np.random.default_rng(SEED)

pcd_data = []
skipped = 0
for file_path in trainingfile_paths:
    # NpzFile keeps the file open; the context manager closes it after reading.
    with np.load(file_path) as data:
        cloud = data['cloud']
    pc = o3d.geometry.PointCloud()
    pc.points = o3d.utility.Vector3dVector(cloud)
    pc = pc.voxel_down_sample(voxel_size=VOXEL_SIZE)
    pcd = np.asarray(pc.points, dtype=np.float32)
    if pcd.shape[0] < N:
        # Too few points to sample N without replacement.
        skipped += 1
        continue
    indices = rng.choice(len(pcd), N, replace=False)
    pcd_data.append(pcd[indices])

if skipped:
    print(f"skipped {skipped} frame(s) with fewer than {N} points after down-sampling")

# reshape keeps a (num_frames, N, 3) layout even when no frame qualified.
pcd_array = np.array(pcd_data, dtype=np.float32).reshape(-1, N, 3)
print(pcd_array.shape)

xyz = np.nan_to_num(pcd_array, posinf=0.0, neginf=0.0)

In [None]:
pcd_jax = jnp.asarray(xyz)
num_params = 100                # frequencies per p1..p4 vector
num_lamdas = 2 * num_params     # one weight per cosine feature plus one per sine feature

lam_data = jnp.zeros((pcd_jax.shape[0], num_lamdas))

key1 = jax.random.PRNGKey(0)
key2 = jax.random.PRNGKey(1)
key3 = jax.random.PRNGKey(2)
key4 = jax.random.PRNGKey(3)
key5 = jax.random.PRNGKey(4)

# aux_vars: frequency vectors shared by every patch (sampled once, never optimized)
p1 = random.normal(key1, shape=(num_params,), dtype=jnp.float32)
p2 = random.normal(key2, shape=(num_params,), dtype=jnp.float32)
p3 = random.normal(key3, shape=(num_params,), dtype=jnp.float32)
p4 = random.normal(key4, shape=(num_params,), dtype=jnp.float32)

fft_obj = fft_fit()

for s in range(pcd_jax.shape[0]):
    patch = pcd_jax[s]

    # Draw with the fresh subkey and carry key5 forward only for splitting,
    # so that no JAX key is consumed twice.
    key5, subkey5 = jax.random.split(key5)

    # opt_vars: random initial weights for this patch
    lam = random.normal(subkey5, shape=(num_lamdas,), dtype=jnp.float32)

    t1 = time.time()
    # JAX dispatches asynchronously; block so the timing covers the actual solve.
    # The first iteration also includes JIT compilation.
    lamdas = fft_obj.solve(lam, p1, p2, p3, p4, patch).block_until_ready()
    t2 = time.time()
    print("case = ", s, "Time to construct FFT = ", t2-t1)

    z_new = fft_obj.get_kernel(patch[:, 0:2], p1, p2, p3, p4, lamdas.reshape(-1, 1))
    z_new = jnp.clip(z_new, -6.0, 6.0)

    error = z_new - patch[:, 2]
    print("error", jnp.linalg.norm(error))

    lam_data = lam_data.at[s, :].set(lamdas)


In [None]:
GRID_STEP = 0.2  # spacing of the evaluation grid, in the cloud's xy units

# Pick any fitted patch. JAX clamps out-of-range indices silently, so the
# bound must come from the data rather than a fixed constant.
num_patches = pcd_jax.shape[0]
assert num_patches > 0, "no patches were fitted"
cn = np.random.randint(0, num_patches)

pc_orig = pcd_jax[cn, :, :]
min_xy = jnp.min(pc_orig, axis=0)  # per-column minimum (x, y, z)
max_xy = jnp.max(pc_orig, axis=0)  # per-column maximum (x, y, z)
min_x = min_xy[0]
max_x = max_xy[0]
min_y = min_xy[1]
max_y = max_xy[1]

# Regular xy test grid over the patch's bounding box; z is filled in below.
xs = jnp.arange(float(min_x), float(max_x), GRID_STEP, dtype=jnp.float32)
ys = jnp.arange(float(min_y), float(max_y), GRID_STEP, dtype=jnp.float32)
X, Y = jnp.meshgrid(xs, ys)
xp = X.flatten()
yp = Y.flatten()
zp = jnp.zeros(xp.shape, dtype=jnp.float32)
xyz_test = jnp.vstack((xp, yp, zp)).T

z_new = fft_obj.get_kernel(xyz_test[:, :2], p1, p2, p3, p4, lam_data[cn, :].reshape(-1, 1))
z_new = jnp.clip(z_new, -6.0, 6.0)

pc_new = xyz_test.at[:, -1].set(z_new)

# Open3D expects float64 host arrays, so convert the JAX float32 arrays explicitly.
gt_pc = o3d.geometry.PointCloud()
gt_pc.points = o3d.utility.Vector3dVector(np.asarray(pc_orig, dtype=np.float64))
t_pc = o3d.geometry.PointCloud()
t_pc.points = o3d.utility.Vector3dVector(np.asarray(pc_new, dtype=np.float64))
gt_pc.paint_uniform_color([1, 0, 0])  # ground-truth cloud in red
t_pc.paint_uniform_color([0, 1, 0])   # fitted surface samples in green
o3d.visualization.draw_geometries([gt_pc, t_pc])