In [None]:
## ODE
import numpy as np
import pandas as pd # for data manipulation
import time
from scipy.integrate import odeint, solve_ivp
from sklearn.metrics import mean_squared_error
import scipy.optimize as optimize

from ipynb.fs.full.myfun_nn import *
from ipynb.fs.defs.myfun_plot import *

# Useful functions

In [None]:
def time_discretization(t0, tend, deltat=0.05):
    """
    Uniform time grid from t0 to tend, both end points included.

    The number of intervals is (tend - t0) / deltat rounded to the nearest
    integer, so the grid spacing is as close as possible to `deltat`.

    tspan = time_discretization(t0, tend, deltat=0.05)
    """
    # round() instead of a plain int() cast: e.g. int(0.3 / 0.1) == 2, not 3
    n_intervals = round((tend - t0) / deltat)
    return np.linspace(t0, tend, int(n_intervals) + 1)

In [None]:
def seq2scn(df):
    """
    Split a DataFrame into a list of scenes, one per row.

    Each scene is the pandas Series yielded by `DataFrame.iterrows()`
    (indexed by the column names), kept in row order.

    seq = seq2scn(df)
    """
    return [scene for _, scene in df.iterrows()]

In [None]:
def match_timestamps_scene(t, x, deltat = 0.05, deltat_scene = 0.2):
    
    """
    Match the computed solution to the same timestamps of the scene.

    The computed solution is sampled every `deltat`, while the scene data are
    sampled every `deltat_scene`. One solver sample is kept every
    round(deltat_scene / deltat) samples, so both share the same time grid.

    t_matched, x_matched = match_timestamps_scene(t, x, deltat = 0.05)

    Parameters
    ----------
    t : sequence of float
        Time stamps of the computed solution.
    x : iterable of sliceable sequences
        One trajectory per entry, each aligned with `t`.
    deltat : float, optional
        Time step of the computed solution (default 0.05). It is expected to
        divide `deltat_scene` (up to floating-point error).
    deltat_scene : float, optional
        Sampling step of the scene data (default 0.2, the value that was
        previously hard-coded).

    Returns
    -------
    t_matched : np.ndarray
        Down-sampled time stamps.
    x_matched : np.ndarray
        Down-sampled trajectories, one row per entry of `x`.

    Raises
    ------
    ValueError
        If `deltat` is not positive or is larger than `deltat_scene`
        (the solver grid would be coarser than the scene grid).
    """
    if deltat <= 0 or deltat > deltat_scene:
        raise ValueError(
            f"deltat must be in (0, {deltat_scene}], got {deltat}"
        )

    # Round instead of truncating: floating-point division can land just below
    # the intended integer, e.g. 0.3 / 0.1 == 2.9999999999999996 -> int() gives 2.
    # Since deltat <= deltat_scene, the rounded factor is always >= 1.
    factor = int(round(deltat_scene / deltat))
    
    t_matched = np.array(t)[::factor]
    x_matched = np.array([traj[::factor] for traj in x])
    
    return t_matched, x_matched

In [None]:
def update_sol_lists(N, tspan_ann, sol_ann, x_list, t_list):
    """
    Append the solution of one ODE sub-interval of a scene to the running
    per-vehicle trajectories and to the running list of time stamps.

    The first sample of the sub-interval (both in time and in the solution)
    is skipped; presumably it repeats the last sample already stored from the
    previous sub-interval.

    x_list, t_list = update_sol_lists(N, tspan_ann, sol_ann, x_list, t_list)

    Notes
    -----
    - Row `j` of `sol_ann` is taken as the trajectory of vehicle `j`, for
      j = 0 .. N-1 (assumes a (N, n_times)-like layout — TODO confirm at the
      call site).
    - `x_list` is modified in place (its first N entries are replaced by new
      numpy arrays) and also returned; `t_list` is returned as a new list.
    """
    rows = sol_ann.tolist()
    t_new = tspan_ann[1:]          # skip the first time stamp

    for veh in range(N):
        x_list[veh] = np.concatenate([x_list[veh], rows[veh][1:]])

    t_list = np.concatenate([t_list, t_new]).tolist()
    return x_list, t_list

## Recovering numpy arrays after storing csv

In [None]:
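# Recipe: columns whose cells were saved to csv as strings like "array([...])"
# can be turned back into numpy arrays by applying this chain to that column:
#   <column>.apply(lambda x: x.replace('array','np.array')).apply(eval).apply(np.array)
# WARNING: eval executes arbitrary code — use it only on trusted csv files.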