# In[24]:
import plotly.graph_objects as go
import matplotlib.colors as mcolors
import numpy as np

# Define the dataset: one row per intermediate node (the LABEL column); every
# other key is a column of flow counts aligned with LABEL.
# NOTE(review): flows are not conserved at most intermediate nodes. Summing the
# source columns vs. the target columns per row gives S 11 in / 10 out,
# F 6 / 4, D 6 / 5, N 6 / 5 (only I balances, 1 / 1), so the incoming and
# outgoing link widths of those nodes will differ -- confirm this is intended.
data = {
    "LABEL": ["S", "F", "D", "N", "I"],
    "PS": [3, 0, 1, 1, 0],
    "OMP": [4, 1, 1, 1, 1],
    "CNP": [1, 2, 2, 1, 0],
    "NRP": [1, 1, 0, 1, 0],
    "NMCCC": [0, 1, 0, 0, 0],
    "PEC": [0, 0, 0, 1, 0],
    "NCDM": [1, 0, 1, 1, 0],
    "RGS": [1, 1, 1, 0, 0],
    "Reg": [2, 2, 1, 2, 0],
    "Aca": [7, 2, 3, 2, 1],
    "Oth": [1, 0, 1, 1, 0],
}

# Define node groups: sources and targets are column names of `data`, the
# intermediate nodes are its row labels.
sources = ['PS', 'OMP', 'CNP', 'NRP', 'NMCCC', 'PEC', 'NCDM', 'RGS']
intermediates = list(data["LABEL"])  # copy, so editing this list can't mutate `data`
targets = ['Reg', 'Aca', 'Oth']

# Combined node label list; a node's position here is its Sankey node index.
node_labels = sources + intermediates + targets

# Every link endpoint is resolved by label below, so a label that appeared in
# two groups would silently merge two distinct nodes -- refuse that up front.
if len(set(node_labels)) != len(node_labels):
    raise ValueError(f"Sankey node labels must be unique across groups: {node_labels}")

# Map each node label to its index in node_labels.
node_indices = {label: idx for idx, label in enumerate(node_labels)}

num_rows = len(data["LABEL"])

# Parallel link lists for go.Sankey. All source->intermediate links are
# appended first, then all intermediate->target links; the link colours are
# assigned by position in these lists, so this order is kept deliberately.
source_indices = []
target_indices = []
values = []

# Links from each source column into the row's intermediate node. Zero cells
# are skipped: a zero-valued link contributes nothing to the diagram.
for row_pos, row_label in enumerate(intermediates):
    row_node = node_indices[row_label]
    for source in sources:
        flow_value = data[source][row_pos]
        if flow_value > 0:
            source_indices.append(node_indices[source])
            target_indices.append(row_node)
            values.append(flow_value)

# Links from each row's intermediate node out to the target columns.
for row_pos, row_label in enumerate(intermediates):
    row_node = node_indices[row_label]
    for target in targets:
        flow_value = data[target][row_pos]
        if flow_value > 0:
            source_indices.append(row_node)
            target_indices.append(node_indices[target])
            values.append(flow_value)

# Link colours: cycle through matplotlib's Tableau palette in link order.
# NOTE: the palette is shorter than the link list for this data, so colours
# repeat (link k and link k + len(color_palette) share a colour) -- they are
# not unique per edge.
color_palette = list(mcolors.TABLEAU_COLORS.values())
num_edges = len(source_indices)
palette_size = len(color_palette)
colors = [color_palette[edge % palette_size] for edge in range(num_edges)]

# Node styling and labels; node i of the trace is node_labels[i].
sankey_nodes = dict(
    pad=15,
    thickness=20,
    line=dict(color="black", width=0.5),
    label=node_labels,
)

# Links as parallel lists: source index, target index, flow value, colour.
# The colours come from a cycled palette, so some links share a colour.
sankey_links = dict(
    source=source_indices,
    target=target_indices,
    value=values,
    color=colors,
)

# Build the figure, set the title/font size, and render it.
fig = go.Figure(data=[go.Sankey(node=sankey_nodes, link=sankey_links)])
fig.update_layout(title_text="Sankey Diagram", font_size=10)
fig.show()

