In [1]:
# NN packages
import jax
import jax.numpy as jnp
import optax

# Visualization packages
import matplotlib.pyplot as plt

# ML Models
from LNN.models.MDOF_LNN import Physical_Damped_LNN, Modal_MLP

# Helper functions
from LNN.helpers import save_to_file, create_modal_training_data

from runscript import run

#### Data Extraction
In this section, we extract the `pose`, `velocity`, `acceleration`, `time`, `force amplitude` and `period` from each continuation simulation file. This dataset was created for **frequencies** ranging from $10.0Hz$ to $24.0Hz$ in steps of $0.2Hz$, where for each, the continuation parameter was the **forcing amplitude**.

The output is in **modal coordinates**.

In [2]:
filename='frequency_step_frequency_'
path='LNN/results/modal'
start=10.0
stop=24.0
step=0.2

ml_data = save_to_file(filename=filename, path=path, start=start, stop=stop, step=step, check=True)

Data saved to LNN/results/modal/data.pkl
Number of files: 71

---EXAMPLE SHAPES---
pose: (2, 301, 39), vel: (2, 301, 39), acc: (2, 301, 39)
If MODAL: 2 Modes, 301 time steps per 39 points along curve
If PHYSICAL: 301 time steps per 39 points along curve
time: (301, 39), F: (39,), T: (39,), force: (2, 301, 39)


#### LNN Dataset Formation

In [3]:
train_data, test_data, info = create_modal_training_data(ml_data, path, split=0.2, seed=42)

Training data shape: (805175, 2, 4), Testing data shape: (207389, 2, 4)
Samples, # of Modes, [x, dx, ddx, force]
x_train.shape: (805175, 2), dx_train.shape: (805175, 2), ddx_train.shape: (805175, 2), force_train.shape: (805175, 2)
x_test.shape: (207389, 2), dx_test.shape: (207389, 2), ddx_test.shape: (207389, 2), force_test.shape: (207389, 2)


In [4]:
# Format dataset for LNN
# Position, velocity & total forcing conditions
train_x = train_data[:, :, :2]
train_dx = train_data[:, :, 1:3]
train_f = train_data[:, :, 3:]

test_x = test_data[:, :, :2]
test_dx = test_data[:, :, 1:3]
test_f = test_data[:, :, 3:]

train_data = train_x, train_f, train_dx
test_data = test_x, test_f, test_dx

In [5]:
train_data[0].shape, test_data[0].shape, train_data[1].shape, test_data[1].shape, train_data[2].shape, test_data[2].shape

((805175, 2, 2),
 (207389, 2, 2),
 (805175, 2, 1),
 (207389, 2, 1),
 (805175, 2, 2),
 (207389, 2, 2))

#### LNN


In [6]:
mnn_settings = {
    'name': 'MNN',
    'units': 64,
    'layers': 4,
    'input_shape': 4,
    'train_batch_size': 128,
    'test_batch_size': 16,
    'shuffle': True,
    'seed': 69
    }

knn_settings = {
    'name': 'KNN',
    'units': 64,
    'layers': 4,
    'input_shape': 4,
    }

dnn_settings = {
    'name': 'DNN',
    'units': 32,
    'layers': 4,
    'input_shape': 2,
    }

results_path = 'MDOF_LNN'
file_name='Phys'

lr = 1e-03
mnn_optimizer = optax.adam(lr)
knn_optimizer = optax.adam(lr)
dnn_optimizer = optax.adam(lr)
epochs = 20
show_every = 10

In [7]:
a = Physical_Damped_LNN(
    mnn_module=Modal_MLP, 
    knn_module=Modal_MLP,       
    dnn_module=Modal_MLP, 
    mnn_settings=mnn_settings,
    knn_settings=knn_settings,
    dnn_settings=dnn_settings, 
    mnn_optimizer=mnn_optimizer, 
    knn_optimizer=knn_optimizer, 
    dnn_optimizer=dnn_optimizer, 
    info=info, 
    activation=jax.nn.tanh)

# Start training LNN
results = None
_, _, _ = a.gather()

In [8]:
results_path = 'MDOF_LNN'
file_name='Modal'

epochs = 20
show_every = 10

In [9]:
# # Standard loss
# for _ in range(10):
#     results = a.train(train_data, test_data, results, epochs=epochs, show_every=show_every)
#     a.save_model(results, model_name=f"Iter_{results['last_epoch']}", folder_name=f"{results_path}/{file_name}")
# print(f"Final loss: {results['best_loss']}")

In [10]:
results = Physical_Damped_LNN.load_model("./LNN/MDOF_LNN/Modal/Iter_200/model.pkl")

In [11]:
pred_acc_, pred_energy = a._predict(results)

In [None]:
import jax
import jax.numpy as jnp
import jax.experimental.ode as jode

from diffrax import diffeqsolve, ODETerm, Tsit5, SaveAt, PIDController

def model_ode(t, X, args):
    T, F = args
    # State equation of the mass spring system. Xdot(t) = g(X(t))
    x = X[:2]
    xdot = X[2:]

    force = jnp.array([F * jnp.sin(2 * jnp.pi / T * t), 0])
    _X = jnp.concatenate(
        (x[None, :, None], xdot[None, :, None]), axis=-1)
    _force = force[None, :, None]
    xddot = pred_acc_(_X, _force)

    Xdot = jnp.concatenate((xdot, xddot[0]))
    return Xdot

term=ODETerm(model_ode)
solver=Tsit5()
saveat = SaveAt(ts=jnp.linspace(0, 2, 10 + 1))

def periodicity(X0, T, F):
    t = jnp.linspace(0, T, 10 + 1)
    # Xsol = jode.odeint(jax.jit(model_ode), X0, t, T, F, rtol=1e-8)
    Xsol = diffeqsolve(term, solver, t0=t[0], t1=t[-1], dt0=0.2, y0=X0, args=(T, F), saveat=saveat, stepsize_controller=PIDController(rtol=1.4e-8, atol=1.4e-8),max_steps=None)
    H = Xsol.ys[-1, :] - Xsol.ys[0, :]
    return H.reshape(-1, 1)

X0 = jnp.array([0.0, 0.0, 1.0, 1.0])
T = 2.0
F = 10.0

periodicity(X0, T, F)

In [None]:
import scipy.linalg as spl
import jax.numpy as jnp
from scipy.integrate import odeint

# Beam paremeters for 2 mode system
k_nl = 4250000
w_1 = 91.734505484821950
w_2 = 3.066194429903638e02
zeta_1 = 0.0
zeta_2 = 0.0
phi_L = jnp.array([[-7.382136522799137, 7.360826867549465]])

# Modal Matrices
M = jnp.eye(2)
C = jnp.array([[2 * zeta_1 * w_1, 0], [0, 2 * zeta_2 * w_2]])
K = jnp.array([[w_1**2, 0], [0, w_2**2]])
Minv = spl.inv(M)

def model_ode_anal(t, X, T, F):
    # State equation of the mass spring system. Xdot(t) = g(X(t))
    x = X[:2]
    xdot = X[2:]
    KX = K @ x
    CXdot = C @ xdot
    force = jnp.array([F * jnp.sin(2 * jnp.pi / T * t), 0])
    phi_x = phi_L @ x  # Physical displacement at nonlinear location
    fnl = k_nl * phi_L.T @ (phi_x**3)
    Xdot = jnp.concatenate((xdot, Minv @ (-KX - CXdot - fnl + force)))
    return Xdot

def periodicity_anal(X0, T, F):
    t = jnp.linspace(0, T, 10 + 1)
    Xsol = odeint(model_ode_anal, X0, t, args=(T, F), rtol=1e-8, tfirst=True)
    H = Xsol[-1, :] - Xsol[0, :]
    return H.reshape(-1, 1)

X0 = jnp.array([0.0, 0.0, 1.0, 1.0])
T = 2.0
F = 10.0

periodicity_anal(X0, T, F)

  Xsol = odeint(model_ode_anal, X0, t, args=(T, F), rtol=1e-8, tfirst=True)


array([[ 0.36862745],
       [ 0.30980392],
       [-0.36470588],
       [ 0.        ]])

In [30]:
import scipy.linalg as spl
import jax.numpy as jnp
from scipy.integrate import solve_ivp

# Beam paremeters for 2 mode system
k_nl = 4250000
w_1 = 91.734505484821950
w_2 = 3.066194429903638e02
zeta_1 = 0.0
zeta_2 = 0.0
phi_L = jnp.array([[-7.382136522799137, 7.360826867549465]])

# Modal Matrices
M = jnp.eye(2)
C = jnp.array([[2 * zeta_1 * w_1, 0], [0, 2 * zeta_2 * w_2]])
K = jnp.array([[w_1**2, 0], [0, w_2**2]])
Minv = spl.inv(M)

def model_ode_anal(t, X, T, F):
    # State equation of the mass spring system. Xdot(t) = g(X(t))
    x = X[:2]
    xdot = X[2:]
    KX = K @ x
    CXdot = C @ xdot
    force = jnp.array([F * jnp.sin(2 * jnp.pi / T * t), 0])
    phi_x = phi_L @ x  # Physical displacement at nonlinear location
    fnl = k_nl * phi_L.T @ (phi_x**3)
    Xdot = jnp.concatenate((xdot, Minv @ (-KX - CXdot - fnl + force)))
    return Xdot

def periodicity_anal(X0, T, F):
    t = jnp.linspace(0, T, 10 + 1)
    Xsol = solve_ivp(model_ode_anal, (t[0], t[-1]), X0, args=(T, F), t_eval=jnp.linspace(0, 2, 10 + 1), rtol=1e-8)
    H = Xsol.y[:, -1] - Xsol.y[:, 0]
    return H.reshape(-1, 1)

X0 = jnp.array([0.0, 0.0, 1.0, 1.0])
T = 2.0
F = 10.0

periodicity_anal(X0, T, F)

array([[-0.00490432],
       [-0.00292757],
       [-1.94054514],
       [-1.08723555]])

In [None]:
import scipy.linalg as spl
import jax.numpy as jnp
import jax.experimental.ode as jode

# Beam paremeters for 2 mode system
k_nl = 4250000
w_1 = 91.734505484821950
w_2 = 3.066194429903638e02
zeta_1 = 0.0
zeta_2 = 0.0
phi_L = jnp.array([[-7.382136522799137, 7.360826867549465]])

# Modal Matrices
M = jnp.eye(2)
C = jnp.array([[2 * zeta_1 * w_1, 0], [0, 2 * zeta_2 * w_2]])
K = jnp.array([[w_1**2, 0], [0, w_2**2]])
Minv = spl.inv(M)

def model_ode_anal(X, t, T, F):
    # State equation of the mass spring system. Xdot(t) = g(X(t))
    x = X[:2]
    xdot = X[2:]
    KX = K @ x
    CXdot = C @ xdot
    force = jnp.array([F * jnp.sin(2 * jnp.pi / T * t), 0])
    phi_x = phi_L @ x  # Physical displacement at nonlinear location
    fnl = k_nl * phi_L.T @ (phi_x**3)
    Xdot = jnp.concatenate((xdot, Minv @ (-KX - CXdot - fnl + force)))
    return Xdot

def periodicity_anal(X0, T, F):
    t = jnp.linspace(0, T, 10 + 1)
    Xsol = jode.odeint(model_ode_anal, X0, t, T, F, rtol=1e-8)

    H = Xsol[-1, :] - Xsol[0, :]
    return H.reshape(-1, 1)

X0 = jnp.array([0.0, 0.0, 1.0, 1.0])
T = 2.0
F = 10.0

periodicity_anal(X0, T, F)

Array([[-0.00528779],
       [-0.00311958],
       [-1.8358393 ],
       [-0.9015873 ]], dtype=float32)

In [None]:
import scipy.linalg as spl
import jax
import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, Tsit5, SaveAt, PIDController

# Beam paremeters for 2 mode system
k_nl = 4250000
w_1 = 91.734505484821950
w_2 = 3.066194429903638e02
zeta_1 = 0.0
zeta_2 = 0.0
phi_L = jnp.array([[-7.382136522799137, 7.360826867549465]])

# Modal Matrices
M = jnp.eye(2)
C = jnp.array([[2 * zeta_1 * w_1, 0], [0, 2 * zeta_2 * w_2]])
K = jnp.array([[w_1**2, 0], [0, w_2**2]])
Minv = spl.inv(M)

def model_ode_anal(t, X, args):
    T, F = args
    # State equation of the mass spring system. Xdot(t) = g(X(t))
    x = X[:2]
    xdot = X[2:]
    KX = K @ x
    CXdot = C @ xdot
    force = jnp.array([F * jnp.sin(2 * jnp.pi / T * t), 0])
    phi_x = phi_L @ x  # Physical displacement at nonlinear location
    fnl = k_nl * phi_L.T @ (phi_x**3)
    Xdot = jnp.concatenate((xdot, Minv @ (-KX - CXdot - fnl + force)))
    return Xdot

term=ODETerm(model_ode_anal)
solver=Tsit5()
saveat = SaveAt(ts=jnp.linspace(0, 2, 10 + 1))

def periodicity_anal(X0, T, F):
    t = jnp.linspace(0, T, 10 + 1)
    Xsol = diffeqsolve(term, solver, t0=t[0], t1=t[-1], dt0=0.2, y0=X0, args=(T, F), saveat=saveat, stepsize_controller=PIDController(rtol=1e-8, atol=1e-8),max_steps=None)

    H = Xsol.ys[-1, :] - Xsol.ys[0, :]
    return H.reshape(-1, 1)

X0 = jnp.array([0.0, 0.0, 1.0, 1.0])
T = 2.0
F = 10.0

periodicity_anal(X0, T, F)

Array([[-0.00473348],
       [-0.0035218 ],
       [-1.7558949 ],
       [-1.2415923 ]], dtype=float32)

In [None]:
run(pred_acc=pred_acc_)

In [None]:
a.plot_results(results)

In [None]:
pred_acc_, pred_energy = a._predict(results)

In [None]:
limq11, limq12, limqd11, limqd12 = info["q1min"], info["q1max"], info["qd1max"], info["qd1min"]
limq21, limq22, limqd21, limqd22 = info["q2min"], info["q2max"], info["qd2max"], info["qd2min"]

q1a, q1da = jnp.linspace(limq11, limq12, 100), jnp.linspace(
    limqd11, limqd12, 100)
q1aa, q1daa = jnp.meshgrid(q1a, q1da)

q2a, q2da = jnp.linspace(limq21, limq22, 100), jnp.linspace(
    limqd21, limqd22, 100)
q2aa, q2daa = jnp.meshgrid(q2a, q2da)

q1a.shape, q1da.shape, q1aa.shape, q1daa.shape, q2a.shape, q2da.shape, q2aa.shape, q2daa.shape

In [None]:
M, K, C = jax.vmap(pred_energy)(jnp.concatenate([q1aa.reshape(-1, 1), q2aa.reshape(-1, 1)], axis=1), jnp.concatenate([q1daa.reshape(-1, 1), q2daa.reshape(-1, 1)], axis=1))

L = M - K
M.shape, K.shape, C.shape, L.shape

In [None]:
fig = plt.figure(figsize=(12, 12), tight_layout=True)
fig.suptitle(f"Final Test Loss: {results['best_loss']:.3e}")

# --------------------------------- FUNCTIONS

# -------------------------------- Lagrangian
ax = fig.add_subplot(421, projection="3d")
m = ax.plot_surface(q1aa, q1daa, L[:, 0].reshape(q1aa.shape), cmap="RdGy", lw=0)
ax.set_xlabel("q")
ax.set_ylabel(r"$\dot{q}$")
ax.set_zlabel(r"$\mathcal{L}_{NN}$", fontsize=16, labelpad=3)
ax.set_title(f"Mode 1 Lagrangian")
fig.colorbar(m, ax=ax, shrink=0.3, pad=0.1)

ax = fig.add_subplot(422, projection="3d")
m = ax.plot_surface(q2aa, q2daa, L[:, 1].reshape(q2aa.shape), cmap="RdGy", lw=0)
ax.set_xlabel("q")
ax.set_ylabel(r"$\dot{q}$")
ax.set_zlabel(r"$\mathcal{L}_{NN}$", fontsize=16, labelpad=3)
ax.set_title(f"Mode 2 Lagrangian")
fig.colorbar(m, ax=ax, shrink=0.3, pad=0.1)

# ---------------------------------- Mass
ax = fig.add_subplot(423, projection="3d")
m = ax.plot_surface(q1aa, q1daa, M[:, 0].reshape(q1aa.shape), cmap="PiYG", lw=0)
ax.set_xlabel("q")
ax.set_ylabel(r"$\dot{q}$")
ax.set_zlabel(r"$\mathcal{M}_{NN}$", fontsize=16, labelpad=3)
ax.set_title("Mode 1 Mass")
fig.colorbar(m, ax=ax, shrink=0.3, pad=0.1)

ax = fig.add_subplot(424, projection="3d")
m = ax.plot_surface(q2aa, q2daa, M[:, 1].reshape(q2aa.shape), cmap="PiYG", lw=0)
ax.set_xlabel("q")
ax.set_ylabel(r"$\dot{q}$")
ax.set_zlabel(r"$\mathcal{M}_{NN}$", fontsize=16, labelpad=3)
ax.set_title("Mode 2 Mass")
fig.colorbar(m, ax=ax, shrink=0.3, pad=0.1)

# ---------------------------------- Stiffness
ax = fig.add_subplot(425, projection="3d")
m = ax.plot_surface(q1aa, q1daa, K[:, 0].reshape(q1aa.shape), cmap="PiYG", lw=0)
ax.set_xlabel("q")
ax.set_ylabel(r"$\dot{q}$")
ax.set_zlabel(r"$\mathcal{K}_{NN}$", fontsize=16, labelpad=3)
ax.set_title("Mode 1 Stiffness")
fig.colorbar(m, ax=ax, shrink=0.3, pad=0.1)

ax = fig.add_subplot(426, projection="3d")
m = ax.plot_surface(q2aa, q2daa, K[:, 1].reshape(q2aa.shape), cmap="PiYG", lw=0)
ax.set_xlabel("q")
ax.set_ylabel(r"$\dot{q}$")
ax.set_zlabel(r"$\mathcal{K}_{NN}$", fontsize=16, labelpad=3)
ax.set_title("Mode 2 Stiffness")
fig.colorbar(m, ax=ax, shrink=0.3, pad=0.1)

# ---------------------------------- Damping
ax = fig.add_subplot(427, projection="3d")
m = ax.plot_surface(q1aa, q1daa, C[:, 0].reshape(q1aa.shape), cmap="PiYG", lw=0)
ax.set_xlabel("q")
ax.set_ylabel(r"$\dot{q}$")
ax.set_zlabel(r"$\mathcal{D}_{NN}$", fontsize=16, labelpad=3)
ax.set_title("Mode 1 Damping")
fig.colorbar(m, ax=ax, shrink=0.3, pad=0.1)

ax = fig.add_subplot(428, projection="3d")
m = ax.plot_surface(q2aa, q2daa, C[:, 1].reshape(q2aa.shape), cmap="PiYG", lw=0)
ax.set_xlabel("q")
ax.set_ylabel(r"$\dot{q}$")
ax.set_zlabel(r"$\mathcal{D}_{NN}$", fontsize=16, labelpad=3)
ax.set_title("Mode 2 Damping")
fig.colorbar(m, ax=ax, shrink=0.3, pad=0.1)
# fig.savefig(f"./Modal_LNN/{file_name}-LD.png")


In [None]:
ddx = pred_acc_(test_x, test_f)
ddx.shape, test_dx.shape

In [None]:
fig = plt.figure(figsize=(12, 12), tight_layout=True)
fig.suptitle(f"Final Test Loss: {results['best_loss']:.3e}")

ax = fig.add_subplot(211)
ax.plot(ddx[:301*4, 0], label="LNN Mode 1")
ax.plot(test_dx[:301*4, 0, -1], label="Truth Mode 1", linestyle='dashed')

ax = fig.add_subplot(212)
ax.plot(ddx[:301*4, 1], label="LNN Mode 2")
ax.plot(test_dx[:301*4, 1, -1], label="Truth Mode 2", linestyle='dashed')

plt.legend()

In [None]:
fig = plt.figure(figsize=(12, 6), tight_layout=True)

ax = fig.add_subplot(121)
ax.plot(test_x[:301*20, 0, 0], test_x[:301*20, 0, 1], label="Mode 1")
ax.set_title("Mode 1 Orbits")

ax = fig.add_subplot(122)
ax.plot(test_x[:301*20, 1, 0], test_x[:301*20, 1, 1], label="Mode 2")
ax.set_title("Mode 2 Orbits")

plt.legend()