In [8]:
import plotly.graph_objects as go
import pandas as pd

# Flows provided by user, written as a square matrix: each row is a pre-COVID
# commute frequency, each column a 2023 frequency, and each cell the weight of
# that (pre-COVID -> 2023) transition.
_modes = ['5 or more days a week', '2-4 days a week', '1 day a week', 'Less than weekly']
_weight_rows = [
    [1217847, 751888, 93787, 181012],
    [64833, 391080, 55484, 57317],
    [3836, 17247, 11388, 12065],
    [137161, 146542, 32193, 94904],
]

# Flatten the matrix row-major into a long-format table: one record per
# (pre-COVID mode, 2023 mode) pair, pre-COVID mode varying slowest.
data = {
    'Pre_Covid_Mode': [pre for pre in _modes for _ in _modes],
    'Work_Model': [post for _ in _modes for post in _modes],
    'Weight': [weight for row in _weight_rows for weight in row],
}

df = pd.DataFrame(data)

# Category order used to index the Sankey nodes in each column
# (left = pre-COVID, right = 2023).
# NOTE(review): Plotly may still reposition nodes vertically when laying out
# the diagram; confirm the rendered order if it matters.
left_nodes = [
    '5 or more days a week',
    '2-4 days a week',
    '1 day a week',
    'Less than weekly',
]
right_nodes = list(left_nodes)  # independent copy, identical ordering

# Prefix each category with its column so the left and right nodes are
# distinct entries even though they share category names. Left-column labels
# come first, then right-column labels.
all_nodes = [
    f"{column}: {name}"
    for column, names in (("Pre-COVID", left_nodes), ("2023", right_nodes))
    for name in names
]

# Map dataframe rows to sankey indices: these accumulators are filled below,
# one entry per link.
source = []
target = []
value = []
link_colors = []

# Node (solid) colors, one per category.
node_color_map = {
    '5 or more days a week': '#1f77b4',   # blue
    '2-4 days a week': '#2ca02c',         # green
    '1 day a week': '#ff7f0e',            # orange
    'Less than weekly': '#9467bd'         # purple
}

# Links are drawn in their source category's color at reduced opacity.
# Deriving them from node_color_map keeps the two palettes in sync; a
# hand-typed copy had drifted (green was rgba(44,160,68) vs #2ca02c = 44,160,44).
link_alpha = 0.4


def _hex_to_rgba(hex_color, alpha):
    """Convert a '#rrggbb' color to a Plotly 'rgba(r,g,b,alpha)' string."""
    r, g, b = (int(hex_color[i:i + 2], 16) for i in (1, 3, 5))
    return f"rgba({r},{g},{b},{alpha})"


# Colors for each category (used for links based on the source).
link_color_map = {
    mode: _hex_to_rgba(color, link_alpha) for mode, color in node_color_map.items()
}

# Node index of each category: left-column nodes occupy indices
# 0..len(left_nodes)-1 and right-column nodes follow, matching all_nodes.
left_index = {name: i for i, name in enumerate(left_nodes)}
right_index = {name: len(left_nodes) + i for i, name in enumerate(right_nodes)}

pre_modes = df['Pre_Covid_Mode'].tolist()
post_modes = df['Work_Model'].tolist()

# Fail early and name every offending category, instead of stopping at the
# first bad row with list.index's generic "'x' is not in list".
unknown = sorted(
    set(pre_modes).difference(left_index) | set(post_modes).difference(right_index),
    key=str,
)
if unknown:
    raise ValueError(f"Categories missing from left_nodes/right_nodes: {unknown}")

# One Sankey link per dataframe row, in row order.
source.extend(left_index[mode] for mode in pre_modes)
target.extend(right_index[mode] for mode in post_modes)
value.extend(df['Weight'].tolist())
# Each link takes the color of its source (pre-COVID) category.
link_colors.extend(link_color_map[mode] for mode in pre_modes)

# Node colors in all_nodes order (left column, then right column). Looking the
# categories up directly avoids re-parsing the "Prefix: category" display
# labels, which would break if a category name itself contained ": ".
node_color_list = [node_color_map[name] for name in left_nodes + right_nodes]

# Node styling and labels; colors are listed in the same order as the labels.
sankey_nodes = {
    "pad": 20,
    "thickness": 30,
    "line": {"color": "black", "width": 0.5},
    "label": all_nodes,
    "color": node_color_list,
}

# Weighted links between node indices, one color per link.
sankey_links = {
    "source": source,
    "target": target,
    "value": value,
    "color": link_colors,
}

fig = go.Figure(data=[go.Sankey(node=sankey_nodes, link=sankey_links)])

fig.update_layout(
    title="Commute Frequency: Pre-COVID vs 2023",
    font_size=12,
    width=1200,
    height=700,
)

fig.show()
