In [1]:
import matplotlib.pyplot as plt
import pandas
import re
from pathlib import Path
import numpy as np
import random
import os

import Utilities

In [2]:
# use plotly instead of matplot lib for plotly's interactability

import plotly.graph_objects as go

def PlotlyPlot(path, wantedDistanceWithin):
    dfA, dfB = Utilities.GetTruncatedScenario(path, wantedDistanceWithin)

    dfA['time'] = dfA.index
    dfB['time'] = dfB.index

    
    # Create 3D Scatter plot for the first set of data
    trace1 = go.Scatter3d(
        x=dfA['HeadPosXA'],
        y=dfA['HeadPosZA'],
        z=dfA['time'],
        mode='markers',
        marker=dict(
            size=3,
            color=dfA['time'],                
            colorscale='Viridis',           
            opacity=0.8
        ),
        name='Car A'  
    )

    # Create 3D Scatter plot for the second set of data
    trace2 = go.Scatter3d(
        x=dfB['HeadPosXB'],
        y=dfB['HeadPosZB'],
        z=dfB['time'],
        mode='markers',
        marker=dict(
            size=3,
            color=dfB['time'],                
            colorscale='Viridis',           
            opacity=0.8
        ),
        name='Car B'  
    )

    data = [trace1, trace2]

    fig = go.Figure(data=data)

    fig.update_layout(title = path[-23:-4],
                    scene = dict(
                        xaxis_title='HeadPosX',
                        yaxis_title='HeadPosZ',
                        zaxis_title='time'),
                        width=700,
                        height=700,
                        margin=dict(r=20, b=10, l=10, t=30))   
    
    # force the x and y axis of the plot to contain the full area even if the path doesn't reach there
    x_axis_range = [-wantedDistanceWithin, wantedDistanceWithin]
    y_axis_range = [-wantedDistanceWithin, wantedDistanceWithin]
    fig.update_layout(scene=dict(camera=dict(projection=dict(type='orthographic')),
                                 xaxis=dict(range=x_axis_range),
                                 yaxis=dict(range=y_axis_range)))
    return fig

In [3]:
"""

temp_path = random.choice(CPlist_NYC)

Utilities.DrawIntersectionFromPath(temp_path, 40)
PlotlyPlot(temp_path, 20)
"""

'\n\ntemp_path = random.choice(CPlist_NYC)\n\nUtilities.DrawIntersectionFromPath(temp_path, 40)\nPlotlyPlot(temp_path, 20)\n'

In [4]:
import random
import re
import plotly.graph_objects as go

# logic of multiple is quite fundimentally different from single, so I'm creating a new function
def PlotlyPlotMultiple(paths, wantedDistanceWithin, cars):
    data = []
    for i, path in enumerate(paths):
        dfA, dfB = Utilities.GetTruncatedScenario(path, wantedDistanceWithin)
        dfA = dfA.reset_index(drop=True)
        dfB = dfB.reset_index(drop=True)

        dfA['time'] = dfA.index
        dfB['time'] = dfB.index

        if cars in ['carA', 'both']:
            trace1 = go.Scatter3d(
                x=dfA['HeadPosXA'],
                y=dfA['HeadPosZA'],
                z=dfA['time'],
                mode='markers',
                marker=dict(
                    size=3,
                    color=dfA['time'],                
                    colorscale='Viridis',           
                    opacity=0.8
                ),
                name=path[-18:-4] + "A"
            )
            data.append(trace1)

        if cars in ['carB', 'both']:
            trace2 = go.Scatter3d(
                x=dfB['HeadPosXB'],
                y=dfB['HeadPosZB'],
                z=dfB['time'],
                mode='markers',
                marker=dict(
                    size=3,
                    color=dfB['time'],                
                    colorscale='Viridis',           
                    opacity=0.8
                ),
                name=path[-18:-4] + "B"  
            )
            data.append(trace2)

    fig = go.Figure(data=data)

    match = re.match(r'.*CSVScenario-CP(\d+)_Session-(.{3})\d+_\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}\.csv', paths[0])
    if match:
        my_location = match.group(2)

    fig.update_layout(title = my_location,
                    scene = dict(
                        xaxis_title='HeadPosX',
                        yaxis_title='HeadPosZ',
                        zaxis_title='time'),
                        width=900,
                        height=900,
                        
                        margin=dict(r=20, b=10, l=10, t=30))   

    x_axis_range = [-wantedDistanceWithin, wantedDistanceWithin]
    y_axis_range = [-wantedDistanceWithin, wantedDistanceWithin]
    z_axis_range = [0, 1100] 
    fig.update_layout(scene=dict(camera=dict(projection=dict(type='orthographic')),
                                 xaxis=dict(range=x_axis_range, dtick=2),
                                 yaxis=dict(range=y_axis_range, dtick=2),
                                 zaxis=dict(range=z_axis_range)
                                ))
    
    # Creating list of visible status for traces
    visible_list = [True] * len(data)  

    def hide_all(visible_list):
        return ['legendonly'] * len(visible_list)
    def show_all(visible_list):
        return [True] * len(visible_list)

    fig.update_layout(
        updatemenus=[
            dict(
                type="buttons",
                direction="right",
                showactive=True,
                x=0,  
                y=0,  
                xanchor='left',
                yanchor='bottom',
                #pad=dict(r=3, t=3), 
                buttons=list([
                    dict(
                        label="Hide All Traces",
                        method="update",
                        args=[{"visible": hide_all(visible_list)}],
                    ),
                    dict(
                        label="Show All Traces",
                        method="update",
                        args=[{"visible": show_all(visible_list)}],
                    )
                ]),
            )
        ]
    )


    return fig


In [5]:
import ipywidgets as widgets
from IPython.display import display

def PlotAll(CP_num, bounding_distance, plotted_car_type):
    def SetCPList(cp_num):
        CPlist_NYC = Utilities.GetScenarioCSVList(cp_num, True, "NYC")
        CPlist_ITH = Utilities.GetScenarioCSVList(cp_num, True, "ITH")
        CPlist_ISR = Utilities.GetScenarioCSVList(cp_num, True, "ISR")
        return CPlist_NYC, CPlist_ITH, CPlist_ISR

    CPlist_NYC, CPlist_ITH, CPlist_ISR = SetCPList(CP_num)

    distance_within = bounding_distance
    # car types: carA, carB, both
    car_type = plotted_car_type

    NYC_plot = PlotlyPlotMultiple(CPlist_NYC, distance_within, car_type)
    ITH_plot = PlotlyPlotMultiple(CPlist_ITH, distance_within, car_type)
    SIR_plot = PlotlyPlotMultiple(CPlist_ISR, distance_within, car_type)

    # I found plotly's subgraph a bit tricky to work with, so for now the workaround is to use ipywidgets
    # imagine attaching multiple js canvas to the website instead attemping combine everything into a single js
    out1 = widgets.Output()
    out2 = widgets.Output()
    out3 = widgets.Output()

    # Display your plots in the output widgets
    with out1:
        display(NYC_plot)

    #with out2:
        #display(ITH_plot)

    with out3:
        display(SIR_plot)

    #plot_container = widgets.HBox([out1, out2, out3])

    plot_container = widgets.HBox([out1, out3])


    display(plot_container)



In [6]:
PlotAll(1, 20, "both")

got 28 csv from CP1 in NYC with validity: True 
got 10 csv from CP1 in ITH with validity: True 
got 30 csv from CP1 in ISR with validity: True 


HBox(children=(Output(), Output()))

In [7]:
PlotAll(2, 20, "both")

got 29 csv from CP2 in NYC with validity: True 
got 11 csv from CP2 in ITH with validity: True 
got 31 csv from CP2 in ISR with validity: True 


HBox(children=(Output(), Output()))

In [8]:
PlotAll(3, 20, "both")

got 26 csv from CP3 in NYC with validity: True 
got 10 csv from CP3 in ITH with validity: True 
got 25 csv from CP3 in ISR with validity: True 


HBox(children=(Output(), Output()))

In [9]:
PlotAll(5, 20, "both")

got 26 csv from CP5 in NYC with validity: True 
got 10 csv from CP5 in ITH with validity: True 
got 22 csv from CP5 in ISR with validity: True 


HBox(children=(Output(), Output()))

In [10]:
PlotAll(6, 20, "both")

got 26 csv from CP6 in NYC with validity: True 
got 9 csv from CP6 in ITH with validity: True 
got 27 csv from CP6 in ISR with validity: True 


HBox(children=(Output(), Output()))

In [11]:
PlotAll(7, 20, "both")

got 35 csv from CP7 in NYC with validity: True 
got 10 csv from CP7 in ITH with validity: True 
got 31 csv from CP7 in ISR with validity: True 


HBox(children=(Output(), Output()))

In [12]:
PlotAll(8, 20, "both")

got 23 csv from CP8 in NYC with validity: True 
got 10 csv from CP8 in ITH with validity: True 
got 24 csv from CP8 in ISR with validity: True 


HBox(children=(Output(), Output()))