In [None]:
import torch
import numpy as np
from pathlib import Path

from nemo.util.plotting import plot_surface
from nemo.dem import DEM
from nemo.nemov2 import NEMoV2

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model fitting below is stochastic; seed torch (CPU and CUDA generators) and
# numpy so Restart & Run All reproduces the same fits. NOTE(review): assumes
# NEMoV2 draws its randomness from these global generators — confirm.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Reload changed modules (e.g. edits to the local `nemo` package) before each
# cell executes, so the kernel need not be restarted after editing them.
%load_ext autoreload
%autoreload 2

In [None]:
# Moon_Map surface stored as a .dat file. (An earlier variant loaded the Site01
# GeoTIFF here and applied `dem.downsample(10)`; that file is now loaded in the
# LDEM section further down.)
dat_path = "../../data/Moon_Map_01_0_rep0.dat"
dem = DEM.from_file(dat_path)

In [None]:
# Look at the elevation grid exactly as loaded, before any fitting.
input_surface = dem.data
plot_surface(input_surface)

In [None]:
# Flatten the DEM grid into (N, 2) planar coordinates and (N,) elevations,
# as float32 tensors on `device`.
xyz = dem.get_xyz_combined()
xy = torch.from_numpy(xyz[:, :2]).float().to(device)
z = torch.from_numpy(xyz[:, 2]).float().to(device)

# This cell runs on the Moon_Map .dat surface, not the LDEM GeoTIFF, so the
# summary is labelled generically.
print(f"DEM data shape: xy={xy.shape}, z={z.shape}")
print(
    f"Coordinate ranges: X({xy[:, 0].min():.3f}, {xy[:, 0].max():.3f}), Y({xy[:, 1].min():.3f}, {xy[:, 1].max():.3f})"
)
print(f"Elevation range: Z({z.min():.3f}, {z.max():.3f})")

# Fresh model: derive its scaling parameters from this data, then train.
# early_stopping=False presumably means all max_epochs epochs run — confirm.
nemo = NEMoV2(device=device)
nemo.compute_scaling_parameters(xy, z)
nemo.fit(
    xy,
    z,
    lr=1e-3,
    max_epochs=5000,
    batch_size=20000,
    verbose=False,
    early_stopping=False,
)

In [None]:
# Raw forward pass over every grid point, displayed for inspection only.
# NOTE(review): this runs with autograd enabled and IPython keeps the displayed
# result in `Out`, which may keep the autograd graph alive for the session —
# consider `torch.no_grad()` or `.detach()` if forward() does not need gradients.
nemo(xy)

In [None]:
# Predict a height for every grid point and write it into channel 2 of a copy
# of the input grid, so the reconstruction can be plotted with the same helper.
pred_z = nemo.predict_height(xy)
pred_grid = dem.data.copy()
# Reshape to the (rows, cols) of the very slice being assigned, so the target
# and the prediction always agree. Assumes xy follows the row-major order of
# dem.data (the xyz reshape plot further down checks this) — TODO confirm.
pred_grid[:, :, 2] = pred_z.detach().cpu().numpy().reshape(pred_grid.shape[:2])

plot_surface(pred_grid)

In [None]:
# Inspect the raw predicted heights returned by predict_height above.
pred_z

In [None]:
# Re-plot the input grid so it sits right after the NEMoV2 reconstruction above,
# for side-by-side comparison.
input_surface = dem.data
plot_surface(input_surface)

In [None]:
# Grid shape: rank 3 with at least 3 channels, since the prediction cell above
# overwrites channel index 2 with predicted heights.
dem.data.shape

In [None]:
# Sanity check: fold the flattened xyz array back into the grid layout of
# dem.data. If this plot matches the input surface, the point order matches the
# grid's row-major order, which the prediction reshape relies on.
xyz_grid = xyz.reshape(dem.data.shape)
plot_surface(xyz_grid)

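# LDEM

Fit a fresh NEMoV2 model to the Site01 GeoTIFF (`Site01_final_adj_5mpp_surf.tif`), using the same training settings (learning rate, epochs, batch size) as the first experiment.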

In [None]:
# Switch to the Site01 GeoTIFF. This rebinds `dem`, so every later cell works on
# this surface. "5mpp" in the filename presumably means 5 m/pixel — confirm.
tif_path = Path("../../data/Site01_final_adj_5mpp_surf.tif")
dem = DEM.from_file(tif_path)

In [None]:
# Flatten the Site01 grid into (N, 2) planar coordinates and (N,) elevations,
# as float32 tensors on `device`.
xyz = dem.get_xyz_combined()
xy = torch.from_numpy(xyz[:, :2]).float().to(device)
z = torch.from_numpy(xyz[:, 2]).float().to(device)

print(f"Real LDEM data shape: xy={xy.shape}, z={z.shape}")
print(
    f"Coordinate ranges: X({xy[:, 0].min():.3f}, {xy[:, 0].max():.3f}), Y({xy[:, 1].min():.3f}, {xy[:, 1].max():.3f})"
)
print(f"Elevation range: Z({z.min():.3f}, {z.max():.3f})")

# Fresh model for this site. Reuse the notebook-wide `device` rather than
# re-deriving it, and compute scaling parameters from this data before fitting,
# exactly as the first experiment does. NOTE(review): drop the scaling call only
# if fit() is confirmed to compute them itself.
nemov2 = NEMoV2(device=device)
nemov2.compute_scaling_parameters(xy, z)

# Same training settings as the first fit, so the two runs are comparable.
print("\nFitting NEMoV2 to real LDEM data...")
nemov2.fit(
    xy,
    z,
    lr=1e-3,
    max_epochs=5000,
    batch_size=20000,
    verbose=False,
    early_stopping=False,
)