In [None]:
import numpy as np
import plotly.graph_objects as go

import torch
import torch.nn as nn
import torch.nn.functional as F

from siren import Siren
from global_planner import GlobalPlanner

# Use the GPU when one is available; the model and grid tensors below are moved/created on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# autoreload: re-import edited local modules (e.g. siren, global_planner) before each cell executes,
# so changes to those .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

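# Mt Bruno SIREN

Load a pretrained SIREN that maps (x, y) in [-1, 1]² to elevation, use its autograd gradients as a slope cost map, and plan a path across that map with A*.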

In [None]:
# Load the siren elevation model: 2 inputs (x, y) -> 1 output (elevation).
siren = Siren(in_features=2, out_features=1, hidden_features=256,
              hidden_layers=3, outermost_linear=True).to(device)

# map_location remaps stored tensors onto `device`, so a checkpoint saved on a GPU
# also loads on a CPU-only machine (and vice versa).
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
state_dict = torch.load('models/mt_bruno_siren.pt', map_location=device)
siren.load_state_dict(state_dict)

# Switch to evaluation mode; the trailing ';' suppresses the module repr in the notebook output.
siren.eval();

In [None]:
# Visualize the SIREN surface on a regular grid over [-1, 1] x [-1, 1]
xs = torch.linspace(-1, 1, steps=100, device=device)
ys = torch.linspace(-1, 1, steps=100, device=device)
# indexing='xy': x[i, j] = xs[j] and y[i, j] = ys[i], so grid rows follow y and columns follow x
x, y = torch.meshgrid(xs, ys, indexing='xy')
# Flatten to an (N, 2) batch of (x, y) query points in row-major grid order
xy = torch.hstack((x.reshape(-1, 1), y.reshape(-1, 1)))

# Keep the autograd graph: `pred` and `coords` are differentiated in the gradient cell below
pred, coords = siren(xy)

# Reshape predictions back onto the grid using the grid's own shape (not a hard-coded size)
z_grid = pred.detach().cpu().numpy().reshape(tuple(x.shape))

# Plot the predictions
fig = go.Figure(data=[go.Surface(z=z_grid, x=x.cpu().numpy(), y=y.cpu().numpy())])
fig.update_layout(width=1200, height=700, scene_aspectmode='data',
                  title='Mt Bruno SIREN elevation',
                  scene_xaxis_title='x', scene_yaxis_title='y',
                  scene_zaxis_title='elevation (model output)')
fig.show()

In [None]:
# Get gradients
def gradient(y, x, grad_outputs=None, create_graph=True):
    """Return the gradient of ``y`` with respect to ``x`` via ``torch.autograd.grad``.

    Args:
        y: Tensor computed from ``x`` with autograd enabled.
        x: Tensor (``requires_grad=True``) to differentiate with respect to.
        grad_outputs: Weights for the vector-Jacobian product. Defaults to ones, which makes
            the result the gradient of ``y.sum()`` with respect to ``x``.
        create_graph: If True (the default, matching the original behaviour), build a graph of
            the gradient computation so higher-order derivatives can be taken. Pass False when
            the result is only read (e.g. detached for plotting or a cost map) to save memory.

    Returns:
        Tensor with the same shape as ``x``.
    """
    if grad_outputs is None:
        grad_outputs = torch.ones_like(y)
    grad = torch.autograd.grad(y, [x], grad_outputs=grad_outputs, create_graph=create_graph)[0]
    return grad

# Gradient of the predicted elevation at every query point of the grid
z_xy_grad = gradient(pred, coords)

# Split the (N, 2) gradient into its two columns and put each back on the 100x100 grid.
# Presumably column 0 is d/dx and column 1 is d/dy, since xy was stacked as (x, y) — assumes
# `coords` is that same input returned by siren.
x_grad, y_grad = (
    z_xy_grad[:, col].detach().cpu().numpy().reshape(100, 100)
    for col in (0, 1)
)

In [None]:
GRID_LEN = 100  # must equal the number of linspace steps used to sample the SIREN grid

# Start/goal as grid indices. The (x, y) comments assume index order (x_idx, y_idx), matching
# the later mapping path[:, 0] -> xs and path[:, 1] -> ys.
start_idx = (0, 0)                   # (x, y) = (-1, -1)
end_idx = (GRID_LEN-1, GRID_LEN-1)   # (x, y) = (1, 1)
# Alternative (anti-diagonal) start/goal:
# start_idx = (0, GRID_LEN-1)           # (x, y) = (-1, 1)
# end_idx = (GRID_LEN-1, 0)             # (x, y) = (1, -1)

# Per-cell cost = L1 norm of the surface gradient: steeper cells are more expensive
grad_costmat = np.abs(x_grad) + np.abs(y_grad)
# Guard against GRID_LEN drifting out of sync with the sampled grid
assert grad_costmat.shape == (GRID_LEN, GRID_LEN), grad_costmat.shape

In [None]:
# Run A* over the gradient cost map, then lift the resulting index path onto the SIREN surface
gp = GlobalPlanner(grad_costmat)
path = gp.plan(start_idx, end_idx)

# NOTE(review): path[:, 0] is used as the x-grid index and path[:, 1] as the y-grid index.
# grad_costmat is laid out [row = y, col = x] (it was reshaped from an indexing='xy' grid),
# so this is only consistent if GlobalPlanner.plan returns (col, row) pairs — confirm.
path_xs = xs[path[:, 0]]
path_ys = ys[path[:, 1]]
# (N, 2) query points, same (x, y) layout as the grid batch fed to the network
path_xy = torch.stack((path_xs, path_ys), dim=1)
# Re-query the network so the path markers sit on the predicted surface
path_zs, _ = siren(path_xy)

In [None]:
# Plot the planned path (red markers) on the SIREN surface
surface_z = pred.detach().cpu().numpy().reshape(tuple(x.shape))  # grid shape, not a hard-coded size

fig = go.Figure()
fig.add_trace(go.Surface(z=surface_z, x=x.cpu().numpy(), y=y.cpu().numpy(),
                         name='SIREN surface'))
fig.add_trace(go.Scatter3d(x=path_xs.detach().cpu().numpy(),
                           y=path_ys.detach().cpu().numpy(),
                           z=path_zs.detach().cpu().numpy().flatten(),
                           mode='markers', marker=dict(size=3, color='red'),
                           name='A* path'))
fig.update_layout(width=1200, height=700, scene_aspectmode='data',
                  title='A* path over the Mt Bruno SIREN surface',
                  scene_xaxis_title='x', scene_yaxis_title='y',
                  scene_zaxis_title='elevation (model output)')
fig.show()

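# Lunar DEM

Repeat the slope-cost planning directly on a gridded lunar digital elevation model (DEM), using finite-difference gradients instead of autograd.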

In [None]:
# Load the DEM and keep the top-left TILE x TILE crop
dem_full = np.load('data/lunar_dem.npy')
TILE = 1000
Z = dem_full[:TILE, :TILE]

# Surface plot approximately to scale.
# NOTE(review): extents are presumably metres spanned by the cropped tile — confirm against the DEM source.
X_EXTENT = 118500
Y_EXTENT = 118400
# With indexing='xy' below, x/y get shape (len(ys), len(xs)); for them to match Z, columns of Z
# must run along x (len(xs) == Z.shape[1]) and rows along y (len(ys) == Z.shape[0]).
xs = torch.linspace(0, X_EXTENT, steps=Z.shape[1], device=device)
ys = torch.linspace(0, Y_EXTENT, steps=Z.shape[0], device=device)
x, y = torch.meshgrid(xs, ys, indexing='xy')

fig = go.Figure(data=[go.Surface(z=Z, x=x.cpu().numpy(), y=y.cpu().numpy())])
fig.update_layout(width=1200, height=700, scene_aspectmode='data',
                  title='Lunar DEM tile',
                  scene_xaxis_title='x', scene_yaxis_title='y', scene_zaxis_title='elevation')
fig.show()

In [None]:
# Compute gradients of the DEM.
# Z is plotted with x along columns (axis=1) and y along rows (axis=0), so x_grad/y_grad are
# named to match — consistent with the Mt Bruno section, where x_grad is d(elevation)/dx.
# np.gradient uses unit spacing here, i.e. elevation change per grid cell.
x_grad = np.gradient(Z, axis=1)
y_grad = np.gradient(Z, axis=0)

# Per-cell cost = L1 norm of the gradient: steeper cells are more expensive
grad_costmat = np.abs(x_grad) + np.abs(y_grad)

In [None]:
# Plan between opposite corners of the cost map; the goal is derived from its shape,
# not hard-coded, so it follows any change to the crop size.
start_idx = (0, 0)
# NOTE(review): assumes plan() takes indices in grad_costmat's own axis order. For a square tile
# this is the same corner under either (row, col) or (x, y) ordering.
end_idx = tuple(n - 1 for n in grad_costmat.shape)

gp = GlobalPlanner(grad_costmat)
path = gp.plan(start_idx, end_idx)

In [None]:
# Convert the index path to plot coordinates and look up DEM elevations along it.
# NOTE(review): treats path[:, 0] as the x (column) index and path[:, 1] as the y (row) index,
# hence Z is indexed [second, first] — confirm this matches GlobalPlanner's output order.
first, second = path[:, 0], path[:, 1]
path_xs = xs[first]
path_ys = ys[second]
path_zs = Z[second, first]

In [None]:
# Plot the planned path (red markers) on the lunar DEM surface
fig = go.Figure()
fig.add_trace(go.Surface(z=Z, x=x.cpu().numpy(), y=y.cpu().numpy(), name='Lunar DEM'))
fig.add_trace(go.Scatter3d(x=path_xs.detach().cpu().numpy(),
                           y=path_ys.detach().cpu().numpy(),
                           z=path_zs,
                           mode='markers', marker=dict(size=3, color='red'),
                           name='A* path'))
fig.update_layout(width=1200, height=700, scene_aspectmode='data',
                  title='A* path over the lunar DEM tile',
                  scene_xaxis_title='x', scene_yaxis_title='y', scene_zaxis_title='elevation')
fig.show()