In [None]:
import torch
import tinycudann as tcnn
import numpy as np
import plotly.graph_objects as go

from nerfstudio.field_components.spatial_distortions import SceneContraction

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

%load_ext autoreload
%autoreload 2

In [None]:
# Multi-resolution hash-grid encoding of 2-D (x, y) query positions.
# 16 levels x 8 features per level -> 128 encoded features per point.
# per_level_scale ~= 2 ** (1 / 3) (written as its float32 value): each level's grid is
# ~1.26x finer than the previous one, starting from base_resolution = 16.
hash_grid_config = dict(
    otype="HashGrid",
    n_levels=16,
    n_features_per_level=8,
    log2_hashmap_size=19,
    base_resolution=16,
    per_level_scale=1.2599210739135742,
)
encoding = tcnn.Encoding(n_input_dims=2, encoding_config=hash_grid_config)

In [None]:
# The height MLP consumes the hash-grid features directly, so its input width must equal the
# encoding's output width (n_levels * n_features_per_level = 16 * 8 = 128 for the config above).
# Derive it from the encoding instead of hard-coding 128 so the two cells cannot drift apart.
tot_out_dims_2d = encoding.n_output_dims

heightcap_net = tcnn.Network(
    n_input_dims=tot_out_dims_2d,
    n_output_dims=1,  # one scalar height per query point
    network_config={
        "otype": "CutlassMLP",
        "activation": "Sine",
        "output_activation": "None",
        "n_neurons": 256,
        "n_hidden_layers": 1,
    },
)

In [None]:
# Restore the trained height MLP weights.
# The checkpoint is a state_dict (it is fed straight to load_state_dict), so weights_only=True
# suffices and avoids unpickling arbitrary objects. map_location places the tensors on `device`
# instead of whatever GPU index they were saved from.
heightcap_net.load_state_dict(
    torch.load('../models/red_rocks_height_net.pth', map_location=device, weights_only=True)
)
heightcap_net.to(device);  # trailing ';' suppresses the module repr in the notebook output

In [None]:
# Restore the trained hash-grid encoding parameters.
# Same loading policy as the MLP: a plain state_dict, so weights_only=True avoids unpickling
# arbitrary objects, and map_location places the tensors on `device`.
encoding.load_state_dict(
    torch.load('../models/red_rocks_encs.pth', map_location=device, weights_only=True)
)
encoding.to(device);  # trailing ';' suppresses the module repr in the notebook output

In [None]:
# Regular N x N evaluation grid covering the square [-bound, bound]^2.
# With indexing='xy' the first stacked channel varies along columns (x), the second along rows (y).
N = 512
bound = 0.75
axis_ticks = torch.linspace(-bound, bound, N, device=device)
XY_grid = torch.stack(torch.meshgrid(axis_ticks, axis_ticks, indexing='xy'), dim=-1)  # (N, N, 2)

# One (x, y) query per row, plus NumPy copies of the coordinates for plotting.
positions = XY_grid.reshape(-1, 2)
xy = positions.detach().cpu().numpy()
x, y = xy[:, 0], xy[:, 1]

In [None]:
# Nerfstudio scene contraction; order=None is the library default (norm taken over the last dim).
spatial_distortion = SceneContraction(order=None)

In [None]:
# Contract the grid points and normalize them for the hash encoding.
# Rebuild from XY_grid rather than from the current `positions` so re-running this cell is
# idempotent. Updating `positions` in place would append another zero column and re-apply
# the contraction/normalization on every re-execution.
positions = XY_grid.reshape(-1, 2)
# Lift (x, y) to 3-D with z = 0; a zero z component does not change a point's norm.
positions = torch.cat([positions, torch.zeros_like(positions[..., :1])], dim=-1)
positions = spatial_distortion(positions)
# Shift/scale so coordinates in [-2, 2] land in [0, 1].
positions = (positions + 2.0) / 4.0

In [None]:
# Query the learned height field on the grid. This is inference only, so autograd is
# disabled: no graph is recorded for the N*N-point batch and `heights` does not require grad.
with torch.no_grad():
    # Drop the z column: it is constant (0 before normalization, 0.5 after) and the encoding takes (x, y).
    pos_encd = encoding(positions[:, :2])
    heights = heightcap_net(pos_encd)

In [None]:
# Sanity-check the prediction with a compact summary instead of echoing an N*N-element tensor.
heights.shape, heights.dtype, heights.min().item(), heights.max().item(), bool(heights.isfinite().all())

In [None]:
# Render the predicted heights as a 3-D surface over the (x, y) evaluation grid.
# tcnn may return float16 outputs; cast to float32 so z matches the x/y coordinate precision.
z = heights.detach().float().cpu().numpy()

fig = go.Figure(
    data=[go.Surface(x=x.reshape(N, N), y=y.reshape(N, N), z=z.reshape(N, N))]
)
fig.update_layout(
    title='Elevation Model',
    width=1500,
    height=800,
    scene=dict(
        aspectmode='data',  # keep the true x:y:z proportions of the data
        xaxis_title='x',
        yaxis_title='y',
        zaxis_title='height',
    ),
)
fig.show()

In [None]:
# Export the interactive surface as a standalone HTML page in the working directory.
html_path = "red_rocks_MLP_sine.html"
fig.write_html(html_path)