In [1]:
import numpy as np
import cv2
import matplotlib.image as mpimg

from PIL import Image
# The SLDEM tile opened below is 30720 x 15360 (~4.7e8) pixels, well above Pillow's
# default decompression-bomb threshold, so raise the limit before opening it.
Image.MAX_IMAGE_PIXELS = 10**9

# Lunar DEM

In [2]:
# Open the SLDEM2015 JPEG 2000 tile with Pillow. Image.open reads only the header;
# the pixel data is decoded later, on first access.
img = Image.open(fp='SLDEM2015_256_60S_0S_120_240.JP2')

In [3]:
# Image dimensions in pixels, as (width, height)
img.width, img.height

(30720, 15360)

In [4]:
# Convert to a NumPy array. This forces Pillow to decode the whole JP2 tile
# (takes around 1m 30s).
img_np = np.array(img)
# NumPy arrays are indexed (rows, cols) = (height, width), the reverse of PIL's
# (width, height) size tuple; check that the axes line up as expected.
assert img_np.shape[:2] == (img.height, img.width), img_np.shape

In [7]:
# Decimate 10x along both axes by keeping every 10th row and column.
# This is plain strided slicing: nearest-sample, no anti-aliasing. The result is
# a view that shares memory with img_np, so img_np stays alive while it exists.
img_np_ds = img_np[0::10, 0::10]


In [27]:
# Downsample the image by a factor of 100 (Image.thumbnail preserves aspect ratio).
# thumbnail() resizes `img` IN PLACE. The target size is therefore taken from the
# full-resolution array img_np (shape (height, width, ...)), not from img.size.
# Otherwise, re-running this cell would shrink the already-shrunk image by another
# 100x. After this cell `img` holds the thumbnail; use img_np for full-resolution data.
img.thumbnail((img_np.shape[1] // 100, img_np.shape[0] // 100))

In [None]:
# Alternative reader, kept for comparison. matplotlib.image.imread hands non-PNG
# formats to Pillow, so this decodes the full JP2 tile again (presumably taking about
# as long as the np.array conversion above) and keeps a second full-resolution copy
# in memory alongside img_np.
# Store the array and show only its shape and dtype. A bare imread(...) call would
# echo the array repr, and the data would only be reachable through `_`.
img_mpl = mpimg.imread('SLDEM2015_256_60S_0S_120_240.JP2')
img_mpl.shape, img_mpl.dtype

# SIREN

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict

import plotly.graph_objects as go

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from siren import Siren

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Toy data

In [12]:
# Synthetic elevation: concentric ripples z = 5 * sin(r / 10), where r is the
# distance from the origin, sampled on a 256 x 256 grid of integer-valued
# coordinates running from -128 to 127 along each axis.
grid_size = 256
XY = np.mgrid[-grid_size/2:grid_size/2, -grid_size/2:grid_size/2]
xvals, yvals = XY
Z = 5 * np.sin(np.sqrt(xvals**2 + yvals**2) / 10)

# Render as an interactive 3D surface using the data's true aspect ratio
fig = go.Figure(data=[go.Surface(x=xvals, y=yvals, z=Z)])
fig.update_layout(scene_aspectmode='data', width=1200, height=700)
fig.show()

### Mt Bruno elevation data

In [20]:
import pandas as pd

z_data = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/api_docs/mt_bruno_elevation.csv')
# NOTE(review): the first CSV column may be a saved row index (0, 1, 2, ...) rather
# than elevation. Inspect z_data.head(); if so, read with index_col=0 (this changes
# the array shape, so check downstream reshapes too).

# Guess the xy scale: 200 units between samples, centred on the origin.
# The grid is built from z_data's shape so it always matches the elevation array
# (for 25 x 25 data this is exactly np.mgrid[-12:13, -12:13]).
n_rows, n_cols = z_data.shape
xy = 200 * np.mgrid[-(n_rows // 2):n_rows - n_rows // 2,
                    -(n_cols // 2):n_cols - n_cols // 2]
xvals = xy[0]
yvals = xy[1]

# Plot the data
fig = go.Figure(data=[go.Surface(z=z_data.values, x=xvals, y=yvals)])
fig.update_layout(width=1200, height=700, scene_aspectmode='data')
fig.show()

In [36]:
# Peak elevation in the raw data (used below to normalise z)
z_data.values.max()

372.8826

In [39]:
# Scale the XY data to -1 to 1 by dividing by the largest absolute grid coordinate
# (2400 for the 25 x 25 grid above). Deriving it from xy keeps the scaling correct
# if the grid changes.
xy_scaled = xy / np.abs(xy).max()
x_scaled = xy_scaled[0]
y_scaled = xy_scaled[1]
# z is divided by its maximum, so the peak maps to 1; the minimum is NOT forced to -1.
z_scaled = z_data.values / np.max(z_data.values)

In [40]:
# Same surface after normalisation: x, y in [-1, 1] and z scaled so the peak is 1
fig = go.Figure(data=[go.Surface(x=x_scaled, y=y_scaled, z=z_scaled)])
fig.update_layout(scene_aspectmode='data', width=1200, height=700)
fig.show()

In [43]:
# Fit a Siren network to the data: 2 inputs (x, y) -> 1 output (elevation).
# Seed torch first so the random weight initialisation is reproducible. This
# assumes Siren draws its initial weights from torch's RNG; GPU kernels may still
# introduce tiny run-to-run differences during training.
SEED = 0
torch.manual_seed(SEED)
siren = Siren(in_features=2, out_features=1, hidden_features=256,
              hidden_layers=3, outermost_linear=True).to(device)

In [44]:
# Train the network on the scaled Mt Bruno samples. This is full-batch training:
# every grid point is used at every step.
N_STEPS = 5000
LOG_EVERY = 500
LEARNING_RATE = 1e-5

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(siren.parameters(), lr=LEARNING_RATE)

# Convert the data to torch tensors, created directly on the target device.
# xy_scaled is (2, H, W); reshape(2, -1).T gives one (x, y) row per grid point,
# in the same row-major order as z_scaled.reshape(-1, 1).
xy_tensor = torch.tensor(xy_scaled, dtype=torch.float32, device=device).reshape(2, -1).T
z_tensor = torch.tensor(z_scaled, dtype=torch.float32, device=device).reshape(-1, 1)

for step in range(N_STEPS):
    # Forward pass. siren returns a pair; only the prediction is needed here.
    pred = siren(xy_tensor)[0]
    loss = criterion(pred, z_tensor)

    # Backward pass and parameter update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % LOG_EVERY == 0:
        print(f"Step {step}, Loss {loss.item()}")

Step 0, Loss 0.04450414702296257
Step 500, Loss 2.5993063900386915e-05
Step 1000, Loss 6.537982699228451e-06
Step 1500, Loss 9.818478474699077e-07
Step 2000, Loss 6.709277755589937e-08
Step 2500, Loss 3.5772604878303582e-09
Step 3000, Loss 2.2083193051969374e-09
Step 3500, Loss 3.834375091743558e-12
Step 4000, Loss 1.4103086713322839e-14
Step 4500, Loss 2.907750605509081e-15


In [46]:
# Sample the Siren network at the training grid to get the predicted elevation
with torch.no_grad():
    pred, coords = siren(xy_tensor)

# Plot the predictions, reshaped back onto the elevation grid. Using z_scaled.shape
# instead of a hard-coded (25, 25) keeps this in step with the data.
pred_grid = pred.cpu().numpy().reshape(z_scaled.shape)
fig = go.Figure(data=[go.Surface(z=pred_grid, x=x_scaled, y=y_scaled)])
fig.update_layout(width=1200, height=700, scene_aspectmode='data')
fig.show()

In [47]:
# Mean squared error between the final predictions and the scaled targets
F.mse_loss(pred, z_tensor)

tensor(1.1817e-15, device='cuda:0')