# Benchmarks

In [9]:
# the notebook imports
import matplotlib.pyplot as plt
import numpy as np
# this is the convenience function
from autokoopman import auto_koopman
# for a complete example, let's create an example dataset using an included benchmark system
from autokoopman.benchmark import bio2, fhn, lalo20, prde20, robe21, spring

from sklearn.metrics import mean_squared_error
import statistics
import os
import csv

import time

In [10]:
benches = [bio2.Bio2(), fhn.FitzHughNagumo(), lalo20.LaubLoomis(), prde20.ProdDestr(), robe21.RobBench()]

In [11]:
def get_trajectories(bench, iv):
    # get the model from the experiment results
    model = experiment_results['tuned_model']

    # simulate using the learned model
    trajectory = model.solve_ivp(
        initial_state=iv,
        tspan=(0.0, 10.0),
        sampling_period=0.1
    )
    # simulate the ground truth for comparison
    true_trajectory = bench.solve_ivp(
        initial_state=iv,
        tspan=(0.0, 10.0),
        sampling_period=0.1
    )
    
    return trajectory, true_trajectory

In [12]:
def test_trajectories(bench, num_tests):
    mses = []
    mses_dim = [[] for x in range(len(bench.names))]
    for j in range(num_tests):
        iv = np.random.uniform(low=bench.init_set_low, high=bench.init_set_high, size=(len(bench.names)))
        trajectory, true_trajectory = get_trajectories(bench, iv)
        mse = mean_squared_error(trajectory.states.T, true_trajectory.states.T)
        mses.append(mse)
        
        for traj_dim, (trajectory, true_trajectory) in enumerate(zip(trajectory.states.T, true_trajectory.states.T)):
            mse = mean_squared_error(trajectory, true_trajectory)
            mses_dim[traj_dim].append(mse)
            
    return mses, mses_dim

In [13]:
def store_data(bench_name, dim, param_name, param_values, train_times, all_mses, all_mses_dim):
    if not os.path.exists('data'):
        os.makedirs('data')
        
    with open(f'data/{bench_name}', 'a') as f:
        writer = csv.writer(f)
        row = [param_name, "train_time", "Avg mse"]
        for i in range(dim):
              row.append(f'Avg mse dim {i+1}')
        writer.writerow(row)
              
        for param_value, train_time, mse, mses_dim in zip(param_values, train_times, all_mses, all_mses_dim):
              row = [param_value, train_time, mse]
              for mses in mses_dim:
                  row.append(mses)
              writer.writerow(row)
        writer.writerow([])

In [14]:
def plot(trajectory, true_trajectory):
    plt.figure(figsize=(10, 6))
    # plot the results
    plt.plot(trajectory.states.T[0], trajectory.states.T[1],label='Trajectory Prediction')
    plt.plot(true_trajectory.states.T[0], true_trajectory.states.T[1],label='Ground Truth')

    plt.xlabel("$x_1$")
    plt.ylabel("$x_2$")
    plt.grid()
    plt.legend()
    plt.title("Bio2 Test Trajectory Plot")
    plt.show()

In [15]:
all_param_values = {
    "train_size":[5, 10, 25, 50, 75, 100, 125, 150, 175, 200],
    "samp_period":[0.01, 0.025, 0.05, 0.08, 0.1, 0.2, 0.25, 0.5, 1, 2],
    "obs_type":["rff","quadratic", "id"],
    "opt":["grid", "monte-carlo"],
    "n_obs":[5, 10, 25, 50, 100, 200, 300, 400, 500, 1000],
    "grid_param_slices":[5, 10, 25, 50, 100],
    "n_splits":[5, 10, 25, 50, 100],
    "rank":[(1, 200, 40)]
}

In [None]:
for bench in benches:
    for param, param_values in all_param_values.items():
        all_mses = []
        all_mses_dim = []
        times = []
        param_dict = {"train_size":10,"samp_period":0.1,"obs_type":"rff","opt":"grid","n_obs":200,
                      "grid_param_slices":5,"n_splits":5,"rank":(1, 200, 40)}
        for i, param_value in enumerate(param_values):
            np.random.seed(0)
            param_dict[param] = param_value
            start = time.time()
            # generate training data
            training_data = bench.solve_ivps(initial_states=np.random.uniform(low=bench.init_set_low, 
                            high=bench.init_set_high, size=(param_dict["train_size"], len(bench.names))),
                            tspan=[0.0, 10.0], sampling_period=param_dict["samp_period"])
            # learn model from data
            experiment_results = auto_koopman(
                training_data,          # list of trajectories
                sampling_period=0.1,    # sampling period of trajectory snapshots
                obs_type=param_dict["obs_type"],         # use Random Fourier Features Observables
                opt=param_dict["opt"],             # grid search to find best hyperparameters
                n_obs=param_dict["n_obs"],              # maximum number of observables to try
                max_opt_iter=200,       # maximum number of optimization iterations
                grid_param_slices=param_dict["grid_param_slices"],# for grid search, number of slices for each parameter
                n_splits=param_dict["n_splits"],             # k-folds validation for tuning, helps stabilize the scoring
                rank=param_dict["rank"]       # rank range (start, stop, step) DMD hyperparameter
            )
            end = time.time()
            times.append(round(end - start, 3))

            mses, mses_dim = test_trajectories(bench, 10)
            print("The average mean square error is ", statistics.mean(mses))
            all_mses.append(statistics.mean(mses))
            all_mses_dim.append([])
            for traj_dim, mses in enumerate(mses_dim):
                all_mses_dim[i].append(statistics.mean(mses))
                print(f"The average mean square error for dim {traj_dim+1} is", statistics.mean(mses))

        store_data(bench.name, len(bench.names), param, param_values, times, all_mses, all_mses_dim)

 96%|█████████████████████████████████████████▎ | 24/25 [00:08<00:00,  2.99it/s]


The average mean square error is  4.4927353661417946e-05
The average mean square error for dim 1 is 1.9406395731246658e-07
The average mean square error for dim 2 is 2.6544153344365286e-07
The average mean square error for dim 3 is 1.904669432476599e-07
The average mean square error for dim 4 is 2.5507699592417455e-07
The average mean square error for dim 5 is 9.426335888788825e-07
The average mean square error for dim 6 is 1.9006537163115083e-06
The average mean square error for dim 7 is 3.4157038147799054e-06
The average mean square error for dim 8 is 0.0003745709205591328
The average mean square error for dim 9 is 2.2611221843730588e-05
