In [1]:
# Python libraries
import os
import shutil

import matplotlib.pyplot as plt

# Enable LaTeX text rendering and select a sans-serif (Helvetica) font.
# matplotlib's usetex mode needs external programs: `latex` itself and `dvipng`
# (used by raster backends such as the notebook's inline backend). If they are
# missing, every figure fails to render with a RuntimeError, so usetex is only
# switched on when both are found on PATH; otherwise matplotlib's built-in
# mathtext is used, which also renders simple `$...$` expressions.
LATEX_AVAILABLE = all(shutil.which(tool) is not None for tool in ("latex", "dvipng"))

plt.rcParams.update({
    "text.usetex": LATEX_AVAILABLE,
    "font.family": "sans-serif",
    "font.sans-serif": "Helvetica",
})

## Plotting helper functions

In [2]:
def plot_annotation(ax, scn, trajs, scale=1.0001):
    """
    Label every vehicle with its number, just past the end of its trajectory.

    INPUT:
    * ax is the object to draw on (only ``ax.annotate`` is called)
    * scn is a scene of the dataset; 'N. vehicles' gives how many labels to
      write and 'Tarr' holds the time stamps
    * trajs is a sequence of 2-D arrays (one row per vehicle); each label is
      placed at the height of that row's last value
    * scale multiplies the last time stamp to obtain the labels' x position,
      so values slightly above 1 put the numbers just right of the curves
    """
    n_vehicles = scn['N. vehicles']
    times = scn['Tarr']
    vehicle_ids = range(1, n_vehicles + 1)  # labels are 1-based

    for traj in trajs:
        x_text = times[-1] * scale
        y_ends = traj[:, -1]  # final value of every row
        for vehicle_id, y_end in zip(vehicle_ids, y_ends):
            anchor = (x_text, y_end)
            # Text is drawn exactly at the annotated point, vertically centred on it
            ax.annotate(xy=anchor, xytext=anchor, text=vehicle_id,
                        verticalalignment='center')

In [3]:
def plot_limits(ax, scn, trajs, xbal=0.01, ybal=0.05):

    """
    Set the x and y limits of ``ax`` so that all trajectories fit with a margin.

    INPUT:
    * ax is the object where we want to plot over (only ``set_xlim`` and
      ``set_ylim`` are called)
    * scn is a scene of the dataset; only its time stamps scn['Tarr'] are used
      (first and last entries bound the x range)
    * trajs is a non-empty sequence of arrays of trajectories; their global
      minimum and maximum bound the y range
    * xbal is the relative margin added on both sides of the time range
    * ybal is the relative margin added above the largest value; the margin
      below the smallest value is 2.5 times larger

    Margins are proportional to the magnitude of each bound, so the limits
    always move outwards, for negative and mixed-sign data as well.
    """

    tstamps = scn['Tarr']
    t_start, t_end = tstamps[0], tstamps[-1]

    # X LIM: for non-negative times this equals t*(1 -/+ xbal); abs() keeps the
    # padding outward when a bound is negative (plain multiplication would
    # move a negative bound inwards and clip the data).
    ax.set_xlim(t_start - xbal * abs(t_start), t_end + xbal * abs(t_end))

    # Y LIM: same idea, with a 2.5x larger margin below the data.
    mmin = min(x.min() for x in trajs)
    mmax = max(x.max() for x in trajs)
    ax.set_ylim(mmin - 2.5 * ybal * abs(mmin), mmax + ybal * abs(mmax))

## Plot functions for NN model

In [4]:
def plot_scn(scn, traj_sim, title="Trajs simulated by NN driven LWR model", xbal=0.01, ybal=0.05, scale=1.0001, set_limits=False):
    """
    Plot and annotate the true and simulated trajectories of a scene.

    Each vehicle gets its own colour: the true trajectory (``scn['Xarr'][veh]``)
    is drawn as a solid line and the simulated one (``traj_sim[veh]``) as a
    dashed line of the same colour, so each pair can be compared at a glance.
    A red dashed vertical line marks every time stamp.

    Args:
    - scn: A scene from the dataset; 'Xarr', 'Tarr' and 'N. vehicles' are read.
    - traj_sim: Simulated trajectories, indexed like scn['Xarr'] (one row per vehicle).
    - title: Title for the plot.
    - xbal: Relative x-axis margin, used only when ``set_limits`` is True.
    - ybal: Relative y-axis margin, used only when ``set_limits`` is True.
    - scale: Adjustment factor for annotation placement (see plot_annotation).
    - set_limits: If True, set the axis limits with plot_limits; the default
      False keeps matplotlib's autoscaling, as before.
    """

    # Extract necessary data from the scene
    true_trajs = scn['Xarr']
    trajs = [true_trajs, traj_sim]
    labels_plot = ["true", "traj sim"]
    tstamps = scn['Tarr']
    N = scn['N. vehicles']

    # Create a figure and axis for the plot
    width, height = 7, 5
    fig, ax = plt.subplots(figsize=(width, height))

    # Plot the true and simulated trajectories for each vehicle
    for veh in range(N):
        # Only vehicle 1's pair is labelled, so the legend shows one entry per
        # line style (solid = true, dashed = simulated).
        label_true, label_sim = labels_plot if veh == 0 else (None, None)
        (true_line,) = ax.plot(tstamps, true_trajs[veh], label=label_true)
        # Reuse the true line's colour so both curves of a vehicle match.
        ax.plot(tstamps, traj_sim[veh], '--', color=true_line.get_color(), label=label_sim)

    # Add vertical dashed lines at time stamps
    for ts in tstamps:
        ax.axvline(x=ts, color='red', linestyle='--', linewidth=0.75)

    # Annotate the names of each vehicle at the last time stamp
    plot_annotation(ax, scn, trajs, scale)

    ax.set_xlabel("$t$")
    ax.set_ylabel("$X(t)$")
    ax.set_title(title, fontsize=15)

    if set_limits:
        plot_limits(ax, scn, trajs, xbal=xbal, ybal=ybal)

    # Add a legend (only when some line carries a label) and display the grid
    if N > 0:
        ax.legend()
    ax.grid()
    plt.show()

## Plot functions for LIN/LOG models

In [5]:
def plot_TD_LWR_scn(scn, x_list_matched, title):
    """
    Plot a scene for the Lin/Log case of the time-discretized LWR
    (Lighthill-Whitham-Richards) model.

    Thin wrapper around plot_scn. The vehicle-label offset grows with the
    duration of the scene (last minus first time stamp).

    Args:
    - scn: A scene of the dataset (its 'Tarr' time stamps are read here).
    - x_list_matched: Matched trajectories, passed to plot_scn as ``traj_sim``.
    - title: Title for the plot.
    """
    times = scn['Tarr']
    duration = times[-1] - times[0]

    # Labels are placed at t_end * scale, so longer scenes push them further right
    plot_scn(scn, x_list_matched, title,
             xbal=0.01, ybal=0.05,
             scale=1 + duration / 10000)