In [None]:
import plotly.graph_objects as go
import pandas as pd


# Survey question for pre_covid_mode:  Before March 15, 2020, how did <you/name> usually commute to <your/their> primary workplace?
# Survey question for work_mode:       Currently, how <do you/does name> typically travel to <your/their> primary workplace?

# data were calculated in Tableau https://github.com/BayAreaMetro/Travel-Diary-Surveys/blob/master/BATS-2023/BATS-2023-SurveyDataViz_PreCovidModeSankey.twb
# Tab named "pre post mode table"
# Flow table, one record per (pre-COVID mode, current mode) pair:
#   (Pre_Covid_Mode, Work_Model, Weight)
# NOTE(review): 'Work_Model' looks like a typo for 'Work_Mode' (the survey
# variable is work_mode); it is kept because it is a runtime column name.
_flows = [
    ('Active Transportation', 'Active Transportation', 108607),
    ('Active Transportation', 'Drive', 67204),
    ('Active Transportation', 'Transit', 27841),
    ('Drive', 'Active Transportation', 40375),
    ('Drive', 'Drive', 2148263),
    ('Drive', 'Transit', 75416),
    ('Transit', 'Active Transportation', 31496),
    ('Transit', 'Drive', 181133),
    ('Transit', 'Transit', 201361),
]

# Transpose the records into the column-oriented dict the DataFrame is built from.
data = {
    column: list(cells)
    for column, cells in zip(('Pre_Covid_Mode', 'Work_Model', 'Weight'), zip(*_flows))
}

df = pd.DataFrame(data)

# The same three modes appear on both sides of the diagram; the right-hand
# list is an independent copy so the two sides can be edited separately.
left_nodes = ['Active Transportation', 'Drive', 'Transit']
right_nodes = list(left_nodes)

# Sankey nodes share one index space: left-side labels first, then right-side.
all_nodes = [
    *(f"Pre-COVID: {mode}" for mode in left_nodes),
    *(f"2023: {mode}" for mode in right_nodes),
]

# Semi-transparent fill used for links, keyed by mode.
colors = {
    'Active Transportation': 'rgba(44, 160, 68, 0.4)',  # Green
    'Drive': 'rgba(31, 119, 180, 0.4)',  # Blue
    'Transit': 'rgba(255, 127, 14, 0.4)',  # Orange
}

# Opaque fill used for the node bars, keyed by mode.
node_colors = {
    'Active Transportation': '#2ca02c',  # Solid green
    'Drive': '#1f77b4',  # Solid blue
    'Transit': '#ff7f0e',  # Solid orange
}

# Link arrays, one entry per row of df and in the same row order.
# Sources index into the left-hand nodes; targets are offset past them
# so they land on the right-hand nodes.
source = [left_nodes.index(mode) for mode in df['Pre_Covid_Mode']]
target = [len(left_nodes) + right_nodes.index(mode) for mode in df['Work_Model']]
value = df['Weight'].tolist()
# Each link takes the colour of the mode it originates from (Pre-COVID mode).
link_colors = [colors[mode] for mode in df['Pre_Covid_Mode']]

# Recover the mode from each "<prefix>: <mode>" label to pick the node colour;
# grey is the fallback for a mode missing from the palette.
node_color_list = [
    node_colors.get(label.split(': ')[1], '#808080')
    for label in all_nodes
]

# Node appearance: labelled, coloured bars with a thin black outline.
node_spec = {
    "pad": 20,
    "thickness": 30,
    "line": {"color": "black", "width": 0.5},
    "label": all_nodes,
    "color": node_color_list,
}

# Link definition: parallel arrays giving each flow's endpoints, size and colour.
link_spec = {
    "source": source,
    "target": target,
    "value": value,
    "color": link_colors,
}

# Single-trace figure holding the Sankey diagram.
sankey_trace = go.Sankey(node=node_spec, link=link_spec)
fig = go.Figure(data=[sankey_trace])

# Fixed canvas size, title and base font size.
fig.update_layout(
    title="Self-Reported Journey-to-Work Mode Pre-Pandemic vs. 2023",
    font_size=12,
    height=700,
    width=1200,
)

# Render the figure in the active plotly renderer.
fig.show()