In [None]:
# User-defined
from helpers import train_test_data

In [None]:
# Generate continuation data.
# `train_test_data` comes from the local `helpers` module and returns three values:
#   - train_dataset / test_dataset: indexable containers. Later cells read
#     test_dataset[0] (split into q and q_d) and test_dataset[1] (reshaped into F).
#   - info: mapping read later via info['qmax'] / info['qdmax'] and also passed to Damped_LNN.
train_dataset, test_dataset, info = train_test_data()

In [None]:
# Load necessary packages
import jax
import jax.numpy as jnp
import optax
import matplotlib.pyplot as plt

from mpl_toolkits import mplot3d
from jax.random import uniform, PRNGKey

from ANN_1DOF_Data_Gen import generate_data
from ANN_1DOF import Damped_MLP, Damped_LNN

#### LNN Training

In [None]:
# Options handed to the Damped_LNN wrapper (same keys and values as a dict literal).
settings = dict(
    name='ANN_LNN_Damped_Test',
    lag_units=128,
    damp_units=16,
    layers=3,
    input_shape=2,
    train_batch_size=32,
    test_batch_size=-1,
    shuffle=True,
    seed=0,
)

# Adam optimiser driven by a single fixed learning rate.
lr = 1e-3
optimizer = optax.adam(lr)

# NOTE(review): `phy_sys` is not defined in any cell visible in this notebook, so this line
# raises NameError on a fresh kernel unless it is created elsewhere. It is indexed later with
# the keys 'M', 'K', 'NL' and 'C' - define it (or import it) before this cell.
a = Damped_LNN(Damped_MLP, optimizer, settings, info, phy_sys)
a.gather()

In [None]:
# Load previous results, if any.
# Set LOAD_PREVIOUS to True to hand an earlier checkpoint to the training cell;
# with the default (False) no file is touched and `prev_results` is None.
LOAD_PREVIOUS = False
PREVIOUS_MODEL_PATH = "results/ANN_Damped_MKK1-C03_200epochs_loss/model.pkl"

# NOTE(review): the .pkl suffix suggests a pickle file - only load checkpoints you
# produced yourself, since unpickling untrusted data can execute arbitrary code.
prev_results = a.load_model(PREVIOUS_MODEL_PATH) if LOAD_PREVIOUS else None

In [None]:
# Start training LNN
# Passes both datasets to the model with epochs=100. `results=prev_results` forwards whatever
# the loading cell produced (None when nothing was loaded).
# `show_every=10` is presumably the reporting interval in epochs - confirm in Damped_LNN.train.
results = a.train(train_dataset, test_dataset, results=prev_results, epochs=100, show_every=10)

In [None]:
# Save results
# NOTE(review): the target name hard-codes "100epochs", matching epochs=100 in the training
# call - keep the two in sync when the epoch count changes. The loading cell reads
# ".../model.pkl" from a similarly named location, so this argument is presumably a
# directory rather than a file name - confirm against Damped_LNN.save_model.
a.save_model(results, 'results/ANN_Damped_MKK1-C03_100epochs_loss')

### Examine results

In [None]:
# Evaluate the learned Lagrangian and damping function on every test input
# (this cell only computes them; the curves are drawn in a later cell).
# NOTE(review): `_predict` is underscore-prefixed, i.e. private by convention; it returns
# two callables, one for accelerations and one for the energy functions.
pred_acc_damped, pred_energy_damped = a._predict(results)
# Split the last axis of the test inputs into two equal halves: q and q_d
q, q_d = jnp.split(test_dataset[0], 2, axis=-1)
# Slice length used by the plotting cell: each "wave" panel shows n consecutive samples
n = 500
# Lnn is plotted later as the 'Lagrangian function', Dnn as the 'Damping function'
Lnn, Dnn = pred_energy_damped(q, q_d)

In [None]:
# Comparing accelerations
# Reshape test_dataset[1] to 2-D with one row per entry of q (presumably the forcing - confirm)
F = test_dataset[1].reshape(q.shape[0], -1)
# Keep only the last column of the acceleration predictor's output.
# NOTE(review): q_dd is not used by any later cell visible here, so no comparison
# against a reference acceleration is actually shown yet.
q_dd = pred_acc_damped(test_dataset[0], F)[:, -1]

In [None]:
# One 3D panel per wave: the learned Lagrangian and damping values along the
# (q, q_d) samples of that wave. Waves are consecutive chunks of n samples.
fig = plt.figure(figsize=(14, 14))

for i in range(2):
    # Samples belonging to wave i
    wave = slice(n * i, n * (i + 1))
    q_wave = q.squeeze()[wave]
    qd_wave = q_d.squeeze()[wave]

    # Panels fill the top row of a 2x2 grid (positions 1 and 2)
    ax = fig.add_subplot(2, 2, i + 1, projection='3d')
    for curve, colour, legend_label in (
        (Lnn, 'black', 'Lagrangian function'),
        (Dnn, 'red', 'Damping function'),
    ):
        ax.plot3D(q_wave, qd_wave, curve.squeeze()[wave], color=colour, label=legend_label)

    ax.set_title('Output for wave ' + str(i))
    ax.set_xlabel('q')
    ax.set_ylabel(r'$\dot{q}$')
    # The freshly added subplot is the current axes, so this matches plt.legend()
    ax.legend()

In [None]:
# Phase-space extent taken from the dataset metadata
lim1 = info['qmax']
lim2 = info['qdmax']

# 100-point axes, symmetric about zero, and the 2-D grid built from them
qa = jnp.linspace(-lim1, lim1, 100)
qda = jnp.linspace(-lim2, lim2, 100)
qaa, qdaa = jnp.meshgrid(qa, qda)

# Get all energy functions here: the learned pair is evaluated point by point
# (vmap over the flattened grid, each point passed with shape (1, 1))
q_flat = qaa.reshape(-1, 1, 1)
qd_flat = qdaa.reshape(-1, 1, 1)
L, D = jax.vmap(pred_energy_damped)(q_flat, qd_flat)

# Closed-form references built from the phy_sys coefficients M, K, NL and C
kinetic = 0.5*phy_sys['M']*qdaa**2
linear_spring = 0.5*phy_sys['K']*qaa**2
cubic_spring = 0.25*phy_sys['NL']*qaa**4
Lagrange_analy = kinetic - linear_spring - cubic_spring
Dissipation_analy = 0.5*phy_sys['C']*qdaa**2

In [None]:
# 2x2 comparison over the whole phase-space grid:
#   top row    - learned Lagrangian    vs. analytical Lagrangian
#   bottom row - learned dissipation   vs. analytical dissipation
# Every panel is a semi-transparent surface with contour lines on top and its own colorbar.
fig = plt.figure(figsize=(15,12), tight_layout=True)

# Axis labels shared by the two learned panels and by the two analytical panels.
# Each label is a (text, keyword-arguments) pair so the per-panel font sizes are kept.
nn_axes = dict(
    xlabel=('q', {}),
    ylabel=(r'$\dot{q}$', {}),
)
analytical_axes = dict(
    xlabel=(r'$q \ (m)$', dict(fontsize=12)),
    ylabel=(r'$\dot{q} \ (m \ s^{-1})$ ', dict(fontsize=12)),
)

# One entry per panel, listed in subplot order (row-major on the 2x2 grid).
# NOTE(review): if phy_sys['C'] is a viscous damping coefficient, the analytical
# dissipation has units of power rather than energy - confirm the 'D (J)' label.
panels = [
    dict(values=L, alpha=0.5, cmap='plasma',
         title='Lagrangian prediction - overall phase space',
         zlabel=(r'$\mathcal{L}_{NN}$', dict(fontsize=16, labelpad=3)),
         **nn_axes),
    dict(values=Lagrange_analy, alpha=0.5, cmap='plasma',
         title='Analytical Lagrangian',
         zlabel=(r'L (J)', dict(fontsize=14, labelpad=2)),
         **analytical_axes),
    dict(values=D, alpha=0.7, cmap='magma',
         title='Dissipation prediction - overall phase space',
         zlabel=(r'$\mathcal{D}_{NN}$', dict(fontsize=16, labelpad=3)),
         **nn_axes),
    dict(values=Dissipation_analy, alpha=0.7, cmap='magma',
         title='Analytical Dissipation',
         zlabel=(r'D (J)', dict(fontsize=14, labelpad=2)),
         **analytical_axes),
]

for position, panel in enumerate(panels, start=1):
    # Bring the values onto the grid's 2-D shape once and reuse for surface + contours
    surface = panel['values'].reshape(qaa.shape)

    ax = fig.add_subplot(2, 2, position, projection='3d')
    m = ax.plot_surface(qaa, qdaa, surface, alpha=panel['alpha'], cmap=panel['cmap'])
    ax.contour3D(qaa, qdaa, surface, cmap='binary')

    x_text, x_kwargs = panel['xlabel']
    y_text, y_kwargs = panel['ylabel']
    z_text, z_kwargs = panel['zlabel']
    ax.set_xlabel(x_text, **x_kwargs)
    ax.set_ylabel(y_text, **y_kwargs)
    ax.set_zlabel(z_text, **z_kwargs)
    ax.set_title(panel['title'])

    fig.colorbar(m, ax=ax, shrink=0.5, pad=0.075)