In [None]:
import numpy as np
import plotly.graph_objects as go

import torch
import torch.nn as nn
import torch.nn.functional as F

# Select the compute device: the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from siren import Siren

# autoreload: in mode 2 IPython reloads modules (e.g. the local `siren` module)
# before executing each cell, so edits to them take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
# Load the pretrained SIREN elevation model: 2 input features (x, y) -> 1 output
# (elevation). The constructor arguments must match those used when the checkpoint
# was saved, otherwise load_state_dict (strict by default) raises.
siren = Siren(in_features=2, out_features=1, hidden_features=256,
              hidden_layers=3, outermost_linear=True).to(device)

# map_location remaps tensors that were saved on a CUDA device onto `device`, so
# the checkpoint also loads on CPU-only machines.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
siren.load_state_dict(torch.load('models/mt_bruno_siren.pt', map_location=device))

# Inference mode; the trailing ';' suppresses the module repr in the notebook output.
siren.eval();

In [None]:
# Visualize the elevation model by evaluating it on a regular grid over [-1, 1]^2.
GRID_RES = 100  # samples per axis; also used to reshape the predictions below

xs = torch.linspace(-1, 1, steps=GRID_RES, device=device)
ys = torch.linspace(-1, 1, steps=GRID_RES, device=device)
# indexing='xy': x varies along columns and y along rows, which matches the
# row-major reshape of the flattened predictions back to (GRID_RES, GRID_RES).
x, y = torch.meshgrid(xs, ys, indexing='xy')
# (GRID_RES**2, 2) batch of query coordinates. Named grid_xy so it is not
# confused with the 2-D waypoint array `xy` used in later cells.
grid_xy = torch.hstack((x.reshape(-1, 1), y.reshape(-1, 1)))

with torch.no_grad():
    # The model returns a pair; only the first element (the prediction) is plotted.
    pred, _ = siren(grid_xy)

# Plot the predicted elevation surface
heights = pred.cpu().numpy().reshape(GRID_RES, GRID_RES)
fig = go.Figure(data=[go.Surface(z=heights, x=x.cpu().numpy(), y=y.cpu().numpy())])
fig.update_layout(title='SIREN elevation model', width=1200, height=700,
                  scene_aspectmode='data')
fig.show()

In [None]:
from scipy.interpolate import CubicSpline

# Five (x, y) waypoints to interpolate. CubicSpline requires the x values
# (first column) to be strictly increasing.
xy = np.array([[-1.0, -1.0],
               [-0.2, 0.3],
               [0.0, 0.0],
               [0.5, 0.6],
               [1.0, 1.0]])

# Fit a cubic spline y = cs(x) through the waypoints
cs = CubicSpline(xy[:, 0], xy[:, 1])

# Sample the spline densely over the waypoints' x-range for plotting
x_plot = np.linspace(xy[:, 0].min(), xy[:, 0].max(), 100)
y_plot = cs(x_plot)

fig = go.Figure()
fig.add_trace(go.Scatter(x=x_plot, y=y_plot, name='cubic spline'))
fig.add_trace(go.Scatter(x=xy[:, 0], y=xy[:, 1], mode='markers', name='waypoints'))
# Equal axis scaling. scene_aspectmode only applies to 3-D scenes; for a 2-D
# figure the y-axis scale has to be anchored to the x-axis instead.
fig.update_layout(width=500, height=500, yaxis_scaleanchor='x', yaxis_scaleratio=1)
fig.show()

In [None]:
# Inspect the fitted spline's piecewise-polynomial coefficients. Per the SciPy
# CubicSpline docs, c[k, i] multiplies (x - x_i)**(3 - k) on the interval starting at x_i.
cs.c

In [None]:
# Seed PyTorch's RNG so the network's random weight initialization (and hence
# the training run below) is reproducible across kernel restarts.
torch.manual_seed(0)

# Trajectory network: 1 input feature (the parameter t) -> 2 outputs (x, y).
siren_traj = Siren(in_features=1, out_features=2, hidden_features=256,
                   hidden_layers=3, outermost_linear=True).to(device)

In [None]:
# Waypoints the trajectory network is trained to pass through, in order: row i is
# the target (x, y) at the i-th of len(xy) evenly spaced t values in [0, 1].
# These are the same points as in the cubic-spline cell.
xy = np.array([[-1.0, -1.0],
               [-0.2, 0.3],
               [0.0, 0.0],
               [0.5, 0.6],
               [1.0, 1.0]])
assert xy.ndim == 2 and xy.shape[1] == 2, "waypoints must be an (N, 2) array"

In [None]:
# Fit siren_traj so that siren_traj(t_i)[0] ≈ xy[i], with the t_i evenly spaced in [0, 1].
N_STEPS = 5000   # number of optimization steps
LOG_EVERY = 500  # print the loss every LOG_EVERY steps (and at the last step)

# Loss function
criterion = nn.MSELoss()

# Optimizer. NOTE: re-running this cell builds a fresh optimizer but keeps the
# network's current weights, so training resumes rather than restarts.
optimizer = torch.optim.Adam(siren_traj.parameters(), lr=1e-5)

# Training inputs, shape (N, 1), and targets, shape (N, 2), created directly on `device`
t_train = torch.linspace(0, 1, len(xy), device=device)[:, None]
xy_tensor = torch.as_tensor(xy, dtype=torch.float32, device=device)

for step in range(N_STEPS):
    # Forward pass; the second returned value is not needed for the loss
    pred, _ = siren_traj(t_train)

    # Compute loss
    loss = criterion(pred, xy_tensor)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Periodic progress log, plus the last step so the final loss is always shown
    if step % LOG_EVERY == 0 or step == N_STEPS - 1:
        print(f"Step {step}, Loss {loss.item()}")

In [None]:
# Evaluate the trained trajectory densely over t in [0, 1] and overlay the waypoints.
t_eval = torch.linspace(0, 1, 100, device=device)[:, None]
with torch.no_grad():  # inference only: no autograd graph, so no .detach() needed
    pred_traj, _ = siren_traj(t_eval)
traj_np = pred_traj.cpu().numpy()

fig = go.Figure()
fig.add_trace(go.Scatter(x=traj_np[:, 0], y=traj_np[:, 1], name='SIREN trajectory'))
fig.add_trace(go.Scatter(x=xy[:, 0], y=xy[:, 1], mode='markers', name='waypoints'))
# Equal axis scaling. scene_aspectmode only applies to 3-D scenes; for this 2-D
# plot the y-axis scale is anchored to the x-axis instead.
fig.update_layout(width=500, height=500, yaxis_scaleanchor='x', yaxis_scaleratio=1)
fig.show()