In [None]:
import plotly.graph_objects as go
import pandas as pd
    
Categorise BMI and MRI-derived adiposity values.
    
# Load the dataset. The path is a placeholder and has to be replaced with the
# location of a real CSV file; the statements below read its 'BMI',
# 'Whole_body_fat_mass' and 'sex' columns.
data_for_plot = pd.read_csv('C://your_data_path')

# Derive one categorical column per measurement.
# NOTE(review): categorize_bmi and categorize_mri are not defined in this
# section; they are expected to be in scope already (e.g. from an earlier cell).
data_for_plot['BMI Category'] = data_for_plot['BMI'].apply(categorize_bmi)
data_for_plot['MRI Category'] = data_for_plot['Whole_body_fat_mass'].apply(categorize_mri)

# Readable sex labels for the plot: code 0 becomes 'Female', code 1 'Male'.
data_for_plot['Sex'] = data_for_plot['sex'].map({0: 'Female', 1: 'Male'})

# Count the records in each (BMI Category, Sex, MRI Category) group and expose
# the group sizes as a regular column named 'Counts'.
grouped_data = (
    data_for_plot
    .groupby(['BMI Category', 'Sex', 'MRI Category'])
    .size()
    .rename('Counts')
    .reset_index()
)
    
# Build the node numbering and the link lists for the alluvial (Sankey) plot.
#
# Nodes are numbered in three consecutive ranges, in this order:
#   BMI categories -> sexes -> MRI categories
# so that the index of a node equals its position in the concatenated label
# list (BMI labels + sex labels + MRI labels).
bmi_categories = grouped_data['BMI Category'].unique()
sex_categories = grouped_data['Sex'].unique()
mri_types = grouped_data['MRI Category'].unique()

# First node index of the sex range and of the MRI range.
sex_offset = len(bmi_categories)
mri_offset = sex_offset + len(sex_categories)

# Create a mapping of categories to node indices
category_indices = {cat: idx for idx, cat in enumerate(bmi_categories)}
sex_indices = {sex: idx for idx, sex in enumerate(sex_categories, start=sex_offset)}
mri_indices = {mri: idx for idx, mri in enumerate(mri_types, start=mri_offset)}

# Parallel lists: link i goes from node source[i] to node target[i] and
# carries value[i] records.
source = []
target = []
value = []

# Every grouped row contributes two links with the same weight:
# BMI -> Sex and Sex -> MRI. Both links must be added inside the loop so that
# each row is represented, not only the last one.
for bmi, sex, mri, count in zip(
    grouped_data['BMI Category'],
    grouped_data['Sex'],
    grouped_data['MRI Category'],
    grouped_data['Counts'],
):
    bmi_idx = category_indices[bmi]
    sex_idx = sex_indices[sex]
    mri_idx = mri_indices[mri]

    # First link (BMI to Sex), second link (Sex to MRI)
    source.extend((bmi_idx, sex_idx))
    target.extend((sex_idx, mri_idx))
    value.extend((count, count))
    
    
# Create the alluvial plot.
# The node labels are concatenated in the order BMI -> Sex -> MRI, which is
# the order the node indices in source/target were assigned in, so label i
# belongs to node i.
fig = go.Figure(data=[go.Sankey(
    node=dict(
        pad=15,
        thickness=20,
        line=dict(color='black', width=0.5),
        label=list(bmi_categories) + list(sex_categories) + list(mri_types),
    ),
    link=dict(
        source=source,
        target=target,
        value=value,
    ),
)])

# Set the title and increase the size of the plot
fig.update_layout(
    title_text='Alluvial Plot: BMI Categories, Sex Distribution, and MRI Categories',
    font_size=12,
    width=1000,  # Width of the plot
    height=600,  # Height of the plot
)

fig.show()