# In [5]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.colors as pc
# Load the node -> label mapping, stored as a pickled dict inside a 0-d array.
# NOTE(review): allow_pickle=True can execute arbitrary code while unpickling;
# only load this file from a trusted source.
dictA = np.load('Pre_processed_Data/mapping_dict.npy', allow_pickle=True).item()
df = pd.DataFrame(list(dictA.items()), columns=["node", "label"])

# Read the partitions from disk once, rather than once per node per scale.
robust_partitions = np.load('Results/robust_partitions_f.npy')
for j in range(3):
    # Row j is indexed by node id; tag each community with its scale so that
    # identical community numbers at different scales stay distinct.
    # The mapping covers every entry of the row instead of a fixed node count.
    node_to_community = {
        i: f"C{j}_{community}" for i, community in enumerate(robust_partitions[j])
    }
    # Nodes absent from the mapping become NaN in the new column.
    df[f'scale{j}'] = df['node'].map(node_to_community)

    
# Names of the columns holding the community assignment at each scale.
scales = [f"scale{k}" for k in range(3)]

# Distinct community values over all scale columns, in order of first
# appearance when the table is read row by row.
all_coms = pd.unique(df[scales].to_numpy().ravel())
# Each community gets its position in `all_coms` as an integer id.
com_to_id = dict(zip(all_coms, range(len(all_coms))))

unique_labels = df["label"].unique()

palette = pc.qualitative.Bold

# One color per label; the palette wraps around when labels outnumber colors.
label_to_color = dict(
    zip(
        unique_labels,
        (palette[k % len(palette)] for k in range(len(unique_labels))),
    )
)

def transitions_by_class(df, from_col, to_col, *, community_ids=None, label_colors=None):
    """Count rows per (label, from-community, to-community) and add link fields.

    Args:
        df: DataFrame with a "label" column plus the ``from_col`` and
            ``to_col`` columns.
        from_col: Name of the column holding each row's source community.
        to_col: Name of the column holding each row's target community.
        community_ids: Optional mapping from community value to node id.
            Defaults to the module-level ``com_to_id``.
        label_colors: Optional mapping from label to color. Defaults to the
            module-level ``label_to_color``.

    Returns:
        DataFrame with one row per group and the columns "label",
        ``from_col``, ``to_col``, "count", "source", "target" and "color".
        Values missing from a mapping yield NaN in the derived column.
    """
    # Fall back to the module-level lookups so existing three-argument calls
    # behave exactly as before.
    if community_ids is None:
        community_ids = com_to_id
    if label_colors is None:
        label_colors = label_to_color

    # size() counts the rows in each group; reset_index turns the group keys
    # back into ordinary columns next to the "count" column.
    g = (
        df.groupby(["label", from_col, to_col])
          .size()
          .reset_index(name="count")
    )
    g["source"] = g[from_col].map(community_ids)
    g["target"] = g[to_col].map(community_ids)
    g["color"] = g["label"].map(label_colors)
    return g

# Transition counts between consecutive scales, stacked into one link table
# with a fresh 0..n-1 index.
t0_1 = transitions_by_class(df, from_col="scale0", to_col="scale1")
t1_2 = transitions_by_class(df, from_col="scale1", to_col="scale2")
links = pd.concat((t0_1, t1_2), ignore_index=True)

fig = go.Figure()

# Sankey trace: every row of `links` becomes one flow, colored by its class.
sankey_nodes = dict(pad=20, thickness=20, color="rgba(0,0,0,0.1)")
sankey_links = dict(
    source=links["source"],
    target=links["target"],
    value=links["count"],
    color=links["color"],
)
fig.add_trace(go.Sankey(node=sankey_nodes, link=sankey_links))

# Placeholder scatter traces with no plottable points: each one carries a
# class name and that class's link color, so the classes show up as legend keys.
for lab in unique_labels:
    legend_marker = dict(size=12, color=label_to_color[lab])
    fig.add_trace(
        go.Scatter(x=[None], y=[None], mode="markers", marker=legend_marker, name=lab)
    )

# Transparent paper and plot backgrounds, fixed figure width.
fig.update_layout(
    font_size=12,
    legend_title="Class",
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(0,0,0,0)',
    width=400,
)

# Hide the x and y axes that belong to the placeholder scatter traces.
fig.update_xaxes(visible=False)
fig.update_yaxes(visible=False)

fig.show()


