# Experiments with 2D Gaussian Splatting

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go

# autoreload
%load_ext autoreload
%autoreload 2

# Toy Example

Fit a single 2D Gaussian, oriented in 3D space, to the 4 corner points of one grid cell.

In [5]:
# Toy grid cell: the four corners of the unit square in xy, each with a height z.
xy = np.array([[0, 0], [1, 0], [1, 1], [0, 1]])
zs = np.array([0.3, 0.0, 0.2, 0.1])

# Interactive 3D scatter of the corner points, with equal scaling on all axes
corner_trace = go.Scatter3d(
    x=xy[:, 0],
    y=xy[:, 1],
    z=zs,
    mode='markers',
    marker=dict(size=5),
)
fig = go.Figure(data=[corner_trace])
fig.update_layout(scene=dict(aspectmode='data'))
fig.show()

In [6]:
# Centroid of the four points (a natural candidate center for the splat)
centroid_z = np.mean(zs)
centroid_xy = np.mean(xy, axis=0)
centroid = np.array([centroid_xy[0], centroid_xy[1], centroid_z])

# Least-squares plane z = a*x + b*y + c through the points
A = np.c_[xy, np.ones(xy.shape[0])]
coeffs, _, _, _ = np.linalg.lstsq(A, zs, rcond=None)
# (a, b, -1) is normal to the plane a*x + b*y - z + c = 0; normalize to unit length.
# Its z-component is negative, so this normal points toward -z.
n = np.array([coeffs[0], coeffs[1], -1])
n = n / np.linalg.norm(n)

# Evaluate the plane on a grid spanning the points' xy extent
# (rather than a hard-coded [0, 1] square).
x = np.linspace(xy[:, 0].min(), xy[:, 0].max(), 10)
y = np.linspace(xy[:, 1].min(), xy[:, 1].max(), 10)
x, y = np.meshgrid(x, y)
z = coeffs[0]*x + coeffs[1]*y + coeffs[2]

# Rebuild the figure from scratch so re-running this cell does not stack
# duplicate surfaces onto the figure left over from the previous cell.
fig = go.Figure()
fig.add_trace(go.Scatter3d(x=xy[:, 0], y=xy[:, 1], z=zs, mode='markers', marker=dict(size=5)))
fig.add_trace(go.Surface(x=x, y=y, z=z))
fig.update_layout(scene=dict(aspectmode='data'))
fig.show()

In [None]:
# Simple: just pick a circle in the plane, spanned by two orthonormal tangent
# vectors derived from the unit plane normal n.

# Principal tangential vectors
# Use a tolerance instead of exact float equality: a nearly horizontal plane
# would otherwise yield a near-zero (numerically meaningless) tu.
if np.hypot(n[0], n[1]) < 1e-12:
    # Plane is horizontal (n parallel to z): the x axis lies in it.
    tu = np.array([1.0, 0.0, 0.0])
else:
    # Horizontal vector perpendicular to n, so it lies in the plane.
    # Normalize it: |(n_y, -n_x, 0)| < 1 whenever the plane is tilted.
    tu = np.array([n[1], -n[0], 0.0])
    tu = tu / np.linalg.norm(tu)
# n is unit length and tu is a unit vector perpendicular to n, so n x tu is
# already unit length. Computing tv the same way in both branches keeps the
# frame (tu, tv, n) right-handed everywhere.
tv = np.cross(n, tu)



# Lunar data test

Data available here: https://svs.gsfc.nasa.gov/4720/
* The only resolution available for both the color image and the DEM is 9600 x 3240
    * ≈ 1.14 km per pixel
* The highest-resolution color image is 27360 x 13680
    * ≈ 0.4 km per pixel
* The highest-resolution DEM is 23040 x 11520
    * ≈ 0.47 km per pixel

In [None]:
import tifffile

In [None]:
# Load the matching-resolution color image and DEM (see the resolution notes above).
img_path = '../data/lroc_color_poles_hw5x3.tif'  # 9600 x 3240
img_data = tifffile.imread(img_path)
H, W = img_data.shape[:2]

dem_path = '../data/ldem_hw5x3.tif'  # 9600 x 3240
dem_data = tifffile.imread(dem_path)

# Patches are later cut from both arrays with the same pixel window,
# so their spatial dimensions must agree.
assert dem_data.shape[:2] == (H, W), (
    f"DEM shape {dem_data.shape[:2]} does not match image shape {(H, W)}"
)

In [None]:
# Plot the patches side by side
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(img_data)
ax[1].imshow(dem_data)
plt.show()

In [None]:
# Cut a patch_size x patch_size window around `center` out of both the image
# and the DEM, using one shared window so the two patches stay pixel-aligned.
patch_size = 100
center = np.array([1000, 1000])  # (row, col): center[0] indexes axis 0

# Start at center - patch_size//2 and take exactly patch_size pixels.
# Using center ± patch_size//2 would drop one pixel when patch_size is odd.
row0 = center[0] - patch_size // 2
col0 = center[1] - patch_size // 2
# Guard the bounds explicitly: a negative start would silently wrap around
# (numpy negative indexing), and an overrun would silently truncate the patch.
assert 0 <= row0 and row0 + patch_size <= img_data.shape[0], 'patch rows out of bounds'
assert 0 <= col0 and col0 + patch_size <= img_data.shape[1], 'patch cols out of bounds'

patch_window = np.s_[row0:row0 + patch_size, col0:col0 + patch_size]
img_patch = img_data[patch_window]
dem_patch = dem_data[patch_window]

In [None]:
# Summarize the DEM patch instead of dumping every value of the array
{
    'shape': dem_patch.shape,
    'dtype': dem_patch.dtype,
    'min': dem_patch.min(),
    'max': dem_patch.max(),
    'mean': dem_patch.mean(),
}

In [None]:
# Plot the patches side by side
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(img_patch)
ax[1].imshow(dem_patch)
plt.show()