In [None]:
# NN packages
import jax
import jax.numpy as jnp
import optax
import numpy as np
import json

# Visualization packages
import matplotlib.pyplot as plt

# ML Models
from LNN.models.MDOF_LNN import Physical_Damped_LNN, Modal_MLP

# Helper functions
from LNN.helpers import save_to_file, create_modal_training_data, plot_S_curves, plot_3DS_curves

from lnn_timesim import time_sim_branch, run

#### LNN

In [None]:
# Data-set location and the stepped range it covers.
# NOTE(review): the file prefix suggests these are amplitude steps; the same
# start/stop/step values are reused further down for the forcing-frequency
# sweep - confirm that is intended.
filename, path = 'amplitude_step_amplitude_', 'LNN/Conx/modal_amp'
start, stop, step = 0.1, 1.9, 0.1

# Build the ML data set from the stepped range described above.
ml_data = save_to_file(
    filename=filename,
    path=path,
    start=start,
    stop=stop,
    step=step,
    check=True,
)

# Seeded split of the data; `info` is handed to the LNN constructor later on.
train_data, test_data, info = create_modal_training_data(
    ml_data, path, split=0.2, seed=42
)

In [None]:
# Per-network settings. Each dict is passed to Physical_Damped_LNN together
# with a Modal_MLP module (as mnn_settings / knn_settings / dnn_settings).
# Only the MNN dict carries batch-size, shuffle and seed entries.
mnn_settings = dict(
    name='MNN',
    units=64,
    layers=4,
    input_shape=4,
    train_batch_size=128,
    test_batch_size=16,
    shuffle=True,
    seed=69,
)

knn_settings = dict(name='KNN', units=64, layers=4, input_shape=4)

# The DNN is the smaller network: half the units and a 2-wide input.
dnn_settings = dict(name='DNN', units=32, layers=4, input_shape=2)

# One shared learning rate; every sub-network still gets its own, separate
# Adam optimizer instance.
lr = 1e-03
mnn_optimizer, knn_optimizer, dnn_optimizer = (optax.adam(lr) for _ in range(3))

# Training-loop settings.
epochs, show_every = 20, 10

In [None]:
# Assemble the damped LNN from three Modal_MLP sub-networks, pairing each one
# with its own settings dict and optimizer.
lnn_kwargs = dict(
    mnn_module=Modal_MLP,
    knn_module=Modal_MLP,
    dnn_module=Modal_MLP,
    mnn_settings=mnn_settings,
    knn_settings=knn_settings,
    dnn_settings=dnn_settings,
    mnn_optimizer=mnn_optimizer,
    knn_optimizer=knn_optimizer,
    dnn_optimizer=dnn_optimizer,
    info=info,
    activation=jax.nn.tanh,
)
a = Physical_Damped_LNN(**lnn_kwargs)

# No model state yet; `results` is replaced once a saved model is loaded.
results = None
# gather() hands back three values, all of which are discarded here.
# NOTE(review): the previous comment read "Start training LNN", but no
# training call is visible in this cell - confirm what gather() does.
_, _, _ = a.gather()

In [None]:
# Pieces of the checkpoint path: ./LNN/<results_path>/<file_name>/Iter_<iter_num>/
results_path, file_name = 'MDOF_LNN', 'Modal'
iter_num = 200

# Same values as in the optimizer/training configuration cell; repeated here
# so this cell can be run on its own.
epochs, show_every = 20, 10

In [None]:
# Restore previously saved model state from the checkpoint selected above.
# NOTE(review): the .pkl extension suggests a pickle file - only load
# checkpoints that come from a trusted source.
model_file = f"./LNN/{results_path}/{file_name}/Iter_{iter_num}/model.pkl"
results = Physical_Damped_LNN.load_model(model_file)

In [None]:
# Derive the two prediction objects from the loaded model state.
# `pred_acc_` is what the sweep below passes to `run(pred_acc=...)`;
# `pred_energy` is not used in the cells visible here.
# NOTE(review): `_predict` is a private method of the LNN object - consider
# exposing a public wrapper if notebooks are expected to call it.
pred_acc_, pred_energy = a._predict(results)

#### NOTES:
- **TODO:** the run at $16.0$ Hz does not start; cause still to be investigated.
- **TODO:** the run at $18.6$ Hz fails after a few steps; cause still to be investigated.

In [None]:
# Sweep the forcing frequency. For every value: patch the continuation
# parameter file, run the simulation with the predictor obtained from the
# loaded model, then post-process the logged branch.
# NOTE(review): `start`/`stop`/`step` are the values defined for the training
# data set earlier in the notebook - confirm they are also the intended
# forcing-frequency range.
cont_params_file = 'contparameters.json'

# Derive the number of sweep points as an integer instead of relying on a
# float-step `np.arange`, whose length can change with rounding. The previous
# `stop+0.1` end padding also only included `stop` while `step` was 0.1.
# The small tolerance keeps `stop` in the sweep when the division lands just
# below a whole number.
n_points = int(np.floor((stop - start) / step + 1e-9)) + 1

for i in start + step * np.arange(n_points):
    # One tag per sweep value, shared by the logger entry and the
    # post-processing call so the two always refer to the same file.
    run_tag = f'freq_step_{i:.02f}'

    # Read the current continuation parameters
    with open(cont_params_file, 'r') as file:
        data = json.load(file)

    # Set the forcing frequency and the logger output name for this run
    data['forcing']['frequency'] = float(i)
    data['Logger']['file_name'] = run_tag

    # Write the modified parameters back before the run picks them up
    with open(cont_params_file, 'w') as file:
        json.dump(data, file, indent=2)

    # Run simulation
    run(pred_acc=pred_acc_)

    # Perform time-sim post-processing on the branch just logged
    time_sim_branch(file=run_tag, inplace="-i", run_bif="n", store_physical="n")