In [None]:
## Libraries
import os
import sys
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error

# Typeset every text element through LaTeX, with Helvetica as the sans-serif face
# (text.usetex needs a working LaTeX installation on the machine running this)
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.sans-serif"] = "Helvetica"

# Helper notebooks live in ./myfun and are imported through the `ipynb` package.
# load_dataset, solve_TD_LWR_dataset and TD_LWR_model used below are not defined
# in this notebook; presumably they come from these star imports.
sys.path.append("myfun/")
from ipynb.fs.defs.myfun_load_dataset import *
from ipynb.fs.defs.myfun_models import *

## Data Acquisition

In [None]:
# Path of the data set folder: "NN-interaction", a sibling of the current working
# directory (the data itself is loaded later, in the solve cell)
par_dir = os.path.dirname(os.getcwd())  # parent dir
dir_name = os.path.join(par_dir, "NN-interaction")

## Solve the Lin/Log models and collect the results in a DataFrame

In [None]:
# Master switch for the save cells below (results table and figure under out/)
flag_save: bool = False

In [None]:
# Run the LWR solver on both versions of the data set ('pre' and 'post'
# processing) and stack the per-file results into a single table.
frames = []
for flag_processed in ['pre', 'post']:

    # Load the desired dflist. After the loop it holds the 'post' data,
    # which the output-directory cell relies on.
    _, dflist = load_dataset(dir_name, flag_processed)

    # Compute
    info_processed = solve_TD_LWR_dataset(dflist, TD_LWR_model, v0=30, L=5, deltat=0.05, tol=1e-8, pplot=False)
    # Tag every row with its processing stage (the scalar broadcasts to all rows)
    info_processed['processed'] = flag_processed

    frames.append(info_processed)

# Each per-stage table carries its own index, so a plain concat could repeat
# labels; ignore_index gives the combined table a clean 0..N-1 index.
info_alldataset = pd.concat(frames, ignore_index=True)

In [None]:
# Full table of solver results: the 'pre' rows followed by the 'post' rows,
# distinguished by the 'processed' column
info_alldataset

In [None]:
# Mean of every numeric column for each (LWR_flag, N. file) pair; the 'pre' and
# 'post' rows of a pair are pooled together ('processed' is non-numeric and dropped)
grouping = ["LWR_flag", "N. file"]
info_alldataset.groupby(grouping).mean(numeric_only=True)

## Prepare the output directory and save the results table

In [None]:
# Create a timestamped output directory and save the results table (only when flag_save)
if flag_save:

    ext = ".svg"  # image format used by the figure-saving cell

    # First 'N. file' value of each DataFrame in the last loaded dflist ('post'),
    # joined with '-'; .iloc is positional, so it works whatever the index labels are
    df_seen = [df['N. file'].iloc[0] for df in dflist]
    df_seen_str = '-'.join(str(x) for x in df_seen)

    # Directory name stamped with the current date and time
    now = datetime.now()
    d = f"{now:%Y-%m-%d_%H-%M-%S}_df{df_seen_str}-LINLOG-POST"

    # makedirs also creates the parent 'out/' folder when it does not exist yet
    path = os.path.join('out', d)
    os.makedirs(path)

    # Save the solution table as comma-separated text (to_csv opens/closes the file itself)
    namefile = 'info_alldataset.txt'
    info_alldataset.to_csv(os.path.join(path, namefile), sep=',', index=False)

## Plot

In [None]:
ms2kmh = 3.6  # m/s -> km/h factor (the y axis is labelled in km/h)

# Initialize the figure
width, height = 7, 5
fig, ax = plt.subplots(figsize=(width, height))

# Colour of each LWR model variant; rows with any other LWR_flag are not plotted
lwr_colors = {'Lin': 'gold', 'Log': 'olive'}

# Data-set ids present in the table: they define the x ticks, the x range and the
# extent of the mean lines (instead of a hard-coded 1..10)
file_ids = sorted(info_alldataset['N. file'].unique())

# One marker per (model, data set): mean v0 over all rows of that pair
for (flag, nf), grp in info_alldataset.groupby(['LWR_flag', 'N. file']):
    if flag in lwr_colors:
        ax.plot(nf, grp['v0_scn'].mean() * ms2kmh, color=lwr_colors[flag], marker="x")

# Mean line per model: mean v0 over all its rows, drawn across every data set
for flag, grp in info_alldataset.groupby('LWR_flag'):
    if flag in lwr_colors:
        v0_mean = grp['v0_scn'].mean() * ms2kmh
        ax.plot([file_ids[0], file_ids[-1]], [v0_mean, v0_mean],
                color=lwr_colors[flag], label=flag)

ax.set_xlim(file_ids[0] - 0.25, file_ids[-1] + 0.25)
ax.set_xticks(file_ids)

# Math mode ignores plain spaces, hence the explicit "\ " between words
ax.set_xlabel(r"$data\ set$", fontsize=14)
ax.set_ylabel(r"$v_{0}\ [km/h]$", fontsize=14)
ax.set_title(r"$Velocities\ v_{0}\ of\ the\ leading\ car$", fontsize=14)
ax.legend()

plt.show()

In [None]:
# Save the v0 figure next to the results table, in the directory created by the
# output-directory cell (path, ext and df_seen_str come from there)
if flag_save:
    title = "/v0_df" + df_seen_str + "_1"
    fig.savefig(f"{path}{title}{ext}", bbox_inches='tight')