*Snippet kindly shared by Ofer Mendelevitch (ofer@syntegra.io)*

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from typing import List, Dict, Any, Tuple

In [1]:
def gen_sankey(df: pd.DataFrame, all_labels, label_to_color, regimens, phases=None):
    """Build a plotly Sankey figure of regimen transitions between consecutive phases.

    Every row of ``df`` contributes one transition for each pair of adjacent
    phases in which both regimen values are present. Node labels are the
    regimen value prefixed with the first five characters of the phase column
    name, so the same regimen in different phases becomes a distinct node.

    Parameters
    ----------
    df : pd.DataFrame
        One column per phase; the code concatenates a string prefix onto each
        non-null value, so values are expected to be strings.
    all_labels : list of str or None
        Full list of prefixed node labels; a label's position is its node
        index. When None, the labels are derived (sorted) from the transitions
        that survive filtering.
    label_to_color : dict
        Maps an un-prefixed regimen name to a colour, used for both nodes and
        the links leaving them.
    regimens : dict
        Maps each phase column name to the collection of regimen names to
        keep for that phase; transitions touching any other regimen are dropped.
    phases : sequence of str, optional
        Ordered phase column names. Defaults to the keys of ``regimens`` in
        insertion order, so the function no longer depends on a notebook-level
        ``phases`` global.

    Returns
    -------
    plotly.graph_objects.Figure
        A figure holding a single Sankey trace.

    Raises
    ------
    ValueError
        If fewer than two phases are available (no transition can be formed).
    KeyError
        If a surviving label is missing from ``all_labels`` or its regimen is
        missing from ``label_to_color``.
    """
    prefix_len = 5  # node labels are '<first 5 chars of phase column><regimen>'

    if phases is None:
        phases = list(regimens)
    if len(phases) < 2:
        raise ValueError("gen_sankey needs at least two phases to build transitions")

    def _prefixed(column):
        # Prefix non-null regimen values with the phase tag; keep missing values as NaN.
        prefix = column[:prefix_len]
        # np.nan rather than np.NaN: the capitalised alias was removed in NumPy 2.0.
        return df[column].map(lambda regimen: prefix + regimen if pd.notna(regimen) else np.nan)

    def _is_selected(label):
        # A label survives unless it belongs to a phase whose selected regimens exclude it.
        return all(
            label[:prefix_len] != p[:prefix_len] or label[prefix_len:] in regimens[p]
            for p in phases
        )

    # One frame of (source, target) label pairs per pair of adjacent phases.
    transitions = []
    for source, dest in zip(phases[:-1], phases[1:]):
        pairs = pd.DataFrame({
            'source_label': _prefixed(source),
            'target_label': _prefixed(dest),
        })
        transitions.append(pairs.dropna(subset=['source_label', 'target_label']))
    tx_df = pd.concat(transitions, axis=0, ignore_index=True)

    # Keep only transitions whose two endpoints are both selected regimens.
    # A boolean ndarray with .loc stays well-defined even when tx_df is empty.
    keep = np.array(
        [_is_selected(s) and _is_selected(t)
         for s, t in zip(tx_df.source_label, tx_df.target_label)],
        dtype=bool,
    )
    tx_df = tx_df.loc[keep]

    # Link weight = number of rows making that transition.
    tx_df = (
        tx_df.groupby(['source_label', 'target_label'])
        .size()
        .reset_index(name='value')
    )

    if all_labels is None:
        # Sorted so node order does not depend on set iteration order.
        all_labels = sorted(set(tx_df.source_label) | set(tx_df.target_label))
    label_map = {label: i for i, label in enumerate(all_labels)}

    tx_df['source'] = tx_df['source_label'].map(lambda lbl: label_map[lbl])
    tx_df['target'] = tx_df['target_label'].map(lambda lbl: label_map[lbl])
    # Links take the colour of the regimen they leave.
    tx_df['color'] = tx_df['source_label'].map(lambda lbl: label_to_color[lbl[prefix_len:]])

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=10,
            thickness=10,
            line=dict(color="black", width=0.5),
            label=[lbl[prefix_len:] for lbl in all_labels],
            color=[label_to_color[lbl[prefix_len:]] for lbl in all_labels],
        ),
        link=dict(
            source=tx_df.source.tolist(),
            target=tx_df.target.tolist(),
            value=tx_df.value.tolist(),
            color=tx_df.color.tolist(),
        ),
    )])
    return fig


NameError: name 'pd' is not defined

In [None]:
phases = ['pttm_regimen', 'm1tm_regimen', 'm2tm_regimen', 'm3tm_regimen', 'm4tm_regimen', 'm5tm_regimen'][:4]
n_regimens = 5

# Read input files first: everything below is derived from `real`, so it must
# exist before the labels, regimens and colours are computed.
real = pd.read_csv('real.csv')  # replace with correct file name for Real
syn = pd.read_csv('syn.csv')    # replace with correct file name for Synthetic

# Generate all labels up-front from every regimen observed in each phase
# (line of treatment) of the real data. Missing values are dropped *before*
# prefixing; otherwise NaN is formatted into a spurious '<phase>nan' label.
all_labels = []
for p in phases:
    all_labels.extend((p[:5] + real[p].dropna().astype(str)).unique())
# Sorted so node indices and colour assignment are identical on every run.
all_labels = sorted(set(all_labels))

# Pick top <n_regimens> most common regimens for each phase
regimens = {
    p: real[p].value_counts(dropna=True).index[:n_regimens].tolist()
    for p in phases
}

# calculate label_to_color - doing this upfront so that the colors are consistent across real and synthetic
colors = px.colors.qualitative.Pastel  # you can choose a different color map
nColors = len(colors)
# A regimen keeps one colour across phases; colours cycle when there are more
# regimens than palette entries.
regimen_names = list(dict.fromkeys(lbl[5:] for lbl in all_labels))
label_to_color = {name: colors[i % nColors] for i, name in enumerate(regimen_names)}

# sample dataframes to make sure they are the same size
n = min(len(real), len(syn))
real_fig = gen_sankey(real.sample(n, random_state=42), all_labels, label_to_color, regimens)
syn_fig = gen_sankey(syn.sample(n, random_state=42), all_labels, label_to_color, regimens)

# Stack the two Sankey diagrams (real on top, synthetic below) in one figure.
template = 'plotly_white'
fig = make_subplots(rows=2, cols=1,
                    specs=[[{"type": "sankey"}], [{"type": "sankey"}]],
                    row_titles=['Real', 'Synthetic'], vertical_spacing=0.02)
fig.add_trace(real_fig.data[0], row=1, col=1)
fig.add_trace(syn_fig.data[0], row=2, col=1)
fig.update_layout(autosize=False, width=1000, height=600,
                  margin=dict(t=10, l=1, r=5, b=10), template=template,
                  font=dict(color='black', size=10))
fig
