# In [0]:
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, FloatType
import numpy as np

# Maximum number of points kept per tile by the sampling step.
K = 2048
# Voxel edge length used to bin the x/y/z columns (same units as those columns).
VOXEL = 0.2

# 1) Load tiled points
df = spark.table("processed_points_tiled_v2")

# 2) Voxel downsample: tag every point with its integer voxel index on each
#    axis, then replace all points sharing a voxel with their mean position.
for axis in ("x", "y", "z"):
    df = df.withColumn(f"v{axis}", F.floor(F.col(axis) / VOXEL))

voxel_keys = ["siteId", "tileId", "vx", "vy", "vz"]
voxel_means = [F.avg(axis).alias(axis) for axis in ("x", "y", "z")]
df = df.groupBy(*voxel_keys).agg(*voxel_means)

# 3) Collect points per tile: one row per (siteId, tileId) holding a list of
#    [x, y, z] triples.
xyz = F.array("x", "y", "z")
df = df.groupBy("siteId", "tileId").agg(F.collect_list(xyz).alias("points"))

# 4) Sample K points (simple FPS-style approximation)
@F.udf(returnType=ArrayType(ArrayType(FloatType())))
def sample_k_fps(points):
    """Reduce one tile's point list to at most K points by farthest point sampling.

    Args:
        points: List of [x, y, z] coordinate triples for a single tile, or
            None.

    Returns:
        None for a None input. The input points (as nested lists) when there
        are K or fewer of them. Otherwise K points picked greedily: the first
        is chosen at random, and each following one is the point whose
        distance to its nearest already-picked point is largest.
    """
    if points is None:
        # Propagate a null column value instead of failing the whole task.
        return None

    pts = np.array(points)  # shape: (N, 3)
    n_points = len(pts)
    if n_points <= K:
        return pts.tolist()

    # min_dists[i] is the distance from point i to its nearest picked point.
    min_dists = np.full(n_points, np.inf)
    picked = np.empty(K, dtype=np.intp)

    # Random seed point, so results differ between runs.
    current = np.random.randint(n_points)
    picked[0] = current

    for slot in range(1, K):
        # Only the most recently picked point can lower any entry of
        # min_dists, so one distance pass per iteration is enough.
        dist_to_current = np.linalg.norm(pts - pts[current], axis=1)
        min_dists = np.minimum(min_dists, dist_to_current)

        current = np.argmax(min_dists)
        picked[slot] = current

    return pts[picked].tolist()

# Run the sampler on each tile's collected point list.
df = df.withColumn("points_k", sample_k_fps(F.col("points")))

# 5) Normalize
@F.udf(returnType=ArrayType(ArrayType(FloatType())))
def normalize(points):
    """Center a point set on its centroid and scale it to unit radius.

    Args:
        points: List of [x, y, z] coordinate triples, or None.

    Returns:
        The points as nested lists, shifted so their mean is the origin and
        divided by the largest distance from the origin, so the farthest
        point ends up at distance 1. A None or empty input is returned
        unchanged. When every point coincides (including a single-point
        input) the centered, all-zero coordinates are returned unscaled.
    """
    if not points:
        # Nothing to normalize; mean/norm on an empty array would fail.
        return points

    pts = np.array(points)
    centered = pts - pts.mean(axis=0)
    scale = np.max(np.linalg.norm(centered, axis=1))
    if scale > 0:
        # Skipped when all points coincide: 0 / 0 would fill the output with NaN.
        centered = centered / scale
    return centered.tolist()

# Normalize each tile's sampled points into the "tensor" column.
df = df.withColumn("tensor", normalize(F.col("points_k")))

# Final ML-ready dataset: tile identifiers plus the normalized points.
output_columns = ["siteId", "tileId", "tensor"]
ml_dataset = df.select(*output_columns)

# Preview the first few rows.
ml_dataset.show(n=5)
