<a href="https://colab.research.google.com/github/SushmitalKhan/Dissertation/blob/main/sankey_4layers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import pandas as pd
import numpy as np
import re
import plotly.graph_objects as go
import json
from itertools import combinations

In [None]:
from collections import defaultdict
from ipywidgets import widgets, Output, VBox

In [None]:
import plotly.io as pio
# Set Plotly's default renderer to "notebook"; this governs figures displayed
# through Plotly's renderer machinery (e.g. fig.show()).
# NOTE(review): no visible cell calls fig.show() — the Sankey below is served
# through a Dash dcc.Graph instead — confirm this setting is still needed.
pio.renderers.default = "notebook"

In [None]:
# Scratch note kept as a bare string, so running this cell just echoes it.
# It sketches the "split_cols" transformation performed by the split loop later
# in the notebook: each entry with several combined_cols gets one split object
# per source column, each carrying the same gpt_output.
# The JSON fragments are an illustrative template with values blanked out
# (e.g. `"uncommonness": ,`), so they are not valid JSON.
# NOTE(review): this would read better as a markdown cell.
'''
split the combined columns:
for example combined data sources -->
for combo size 2 (a, b), split them into two objects,
where the first object in combo col will be a, and second object will be b.
all other values remain same

"combined_cols": [
      "Image_Search",
      "Location_history"
    ],
    "combo_size": 2,
    "gpt_output": {
      "inferences": [
        {
          "inference": "",
          "uncommonness": ,
          "sensitivity": ,
          "product_recommendation": "",
          "label": "",
          "category": "",
          "activity": "",
          "reason": ""
        },
        {
          "inference": "",
          "uncommonness": 5,
          "sensitivity": 4,
          "product_recommendation": "",
          "label": "",
          "category": "",
          "activity": "",
          "reason": ""
        },
        {
          "inference": "",
          "uncommonness": 6,
          "sensitivity": 2,
          "product_recommendation": "",
          "label": "",
          "category": "",
          "activity": "",
          "reason": ""
        }
      ],
      "overall_product_recommendation": ""
    },

"split_cols": [
      {
        "split_col": "Image_Search",
        "gpt_output": {
          "inferences": [
            {
              "inference": "":
              "uncommonness": ,
              "sensitivity": ,
              "product_recommendation": "",
              "label": "",
              "category": "",
              "activity": "",
              "reason": ""
            },
            {
              "inference": "",
              "uncommonness": 5,
              "sensitivity": 4,
              "product_recommendation": "",
              "label": "",
              "category": "",
              "activity": "",
              "reason": ""
            },
            {
              "inference": "",
              "uncommonness": ,
              "sensitivity": ,
              "product_recommendation": "",
              "label": "",
              "category": "",
              "activity": "",
              "reason": ""
            }
          ]
'''

'\nsplit the combined columns: \nfor example combined data sources --> \nfor combo size 2 (a, b), split them into two objects, \nwhere the first object in combo col will be a, and second object will be b. \nall other values remain same\n\n"combined_cols": [\n      "Image_Search",\n      "Location_history"\n    ],\n    "combo_size": 2,\n    "gpt_output": {\n      "inferences": [\n        {\n          "inference": "",\n          "uncommonness": ,\n          "sensitivity": ,\n          "product_recommendation": "",\n          "label": "",\n          "category": "",\n          "activity": "",\n          "reason": ""\n        },\n        {\n          "inference": "",\n          "uncommonness": 5,\n          "sensitivity": 4,\n          "product_recommendation": "",\n          "label": "",\n          "category": "",\n          "activity": "",\n          "reason": ""\n        },\n        {\n          "inference": "",\n          "uncommonness": 6,\n          "sensitivity": 2,\n          "produ

In [None]:
import json
from pathlib import Path

# Load the inference records that the next cell extends with "split_cols".
# The path is resolved from the current user's home directory instead of being
# hard-coded to a single account, so the notebook runs on other machines too.
INFER_DATA_PATH = Path.home() / "Desktop" / "infer_data_exploration.json"

# JSON is UTF-8 by spec; don't depend on the platform's default encoding.
with INFER_DATA_PATH.open("r", encoding="utf-8") as f:
    data = json.load(f)

In [None]:
import copy
from pathlib import Path


def split_combined_entries(entries):
    """Attach a ``split_cols`` list to every multi-source entry.

    For an entry whose ``combo_size`` is > 1 and whose ``combined_cols`` holds
    more than one column, one ``{"split_col": col, "gpt_output": ...}`` object
    is added per column, each carrying the entry's full ``gpt_output``.
    Other entries are passed through unchanged (apart from being copied).

    Args:
        entries: List of entry dicts as loaded from the inference JSON.

    Returns:
        A new list with one entry per input entry. Everything is deep-copied,
        so ``entries`` is never modified, and each split's ``gpt_output`` is
        independent of the others (mutating one no longer changes them all).

    Raises:
        KeyError: if a multi-column entry has no ``gpt_output`` key.
    """
    extended = []
    for entry in entries:
        combined_cols = entry.get("combined_cols", [])
        combo_size = entry.get("combo_size", 1)

        new_entry = copy.deepcopy(entry)
        if combo_size > 1 and len(combined_cols) > 1:
            new_entry["split_cols"] = [
                {"split_col": col, "gpt_output": copy.deepcopy(entry["gpt_output"])}
                for col in combined_cols
            ]
        extended.append(new_entry)
    return extended


extended_data = split_combined_entries(data)

# Save the extended dataset to the user's Desktop folder (home-relative rather
# than a hard-coded absolute path).
EXTENDED_DATA_PATH = Path.home() / "Desktop" / "user_data_extended_v5.json"
with EXTENDED_DATA_PATH.open("w", encoding="utf-8") as f:
    json.dump(extended_data, f, indent=2)

print(f"Original data length: {len(data)}, Extended data length: {len(extended_data)}")


Original data length: 15, Extended data length: 15


In [None]:
import dash
from dash import dcc, html
from dash.dependencies import Input, Output
import plotly.graph_objects as go
from collections import defaultdict

# Load the dataset plotted by the Dash app below.
# NOTE(review): the split step above writes user_data_extended_v5.json, but this
# cell reads user_data_extended_v4.json — confirm v4 is intended. The Sankey
# code only reads "combined_cols" and "gpt_output", not "split_cols".
from pathlib import Path

SANKEY_DATA_PATH = Path.home() / "Desktop" / "user_data_extended_v4.json"
with SANKEY_DATA_PATH.open("r", encoding="utf-8") as f:
    sankey_data = json.load(f)

In [None]:


app = dash.Dash(__name__)

# A source column can appear in many entries; collect each one once (sorted)
# so the source dropdown does not list duplicate options.
source_options = sorted({
    col
    for entry in sankey_data
    for col in entry.get("combined_cols", [])
})

app.layout = html.Div([
    # Layer order matches generate_sankey: the Activity layer is inserted
    # between Category and Recommendation, and only once a category is selected.
    html.H2("Interactive Sankey Diagram: Source → Category → Activity → Recommendation"),

    html.Div([
        html.Label("Filter by Source:"),
        dcc.Dropdown(
            id="source-dropdown",
            options=[{"label": col, "value": col} for col in source_options],
            multi=True,
            placeholder="Select Source(s)"
        ),
    ], style={"width": "30%", "display": "inline-block", "margin-right": "20px"}),

    html.Div([
        html.Label("Filter by Category:"),
        dcc.Dropdown(
            id="category-dropdown",
            options=[],  # populated by the update_category_options callback
            multi=False,
            placeholder="Select Category"
        ),
    ], style={"width": "30%", "display": "inline-block"}),

    dcc.Graph(id="sankey-graph", style={"height": "800px"})
])

def generate_sankey(selected_sources=None, selected_category=None, data=None):
    """Build the Sankey figure for the current dropdown selection.

    Every inference of every entry contributes one path:
    ``source column -> category (inference "label") -> recommendation``.
    When a category is selected, only inferences with that label are drawn and
    an activity layer is inserted:
    ``source -> category -> activity -> recommendation``.

    Args:
        selected_sources: Optional collection of source column names; when
            truthy, only those sources are drawn.
        selected_category: Optional inference label; when truthy, restricts the
            view to that label and adds the activity layer.
        data: Optional list of entries to plot; defaults to the notebook-level
            ``sankey_data``.

    Returns:
        A ``plotly.graph_objects.Figure``; an empty figure if nothing matches.
    """
    missing = "(not specified)"  # shown for absent/empty activity or recommendation
    entries = sankey_data if data is None else data

    def format_label(text, max_len=25):
        """Break long labels into ``max_len``-character lines joined by ``<br>``."""
        if len(text) <= max_len:
            return text
        return "<br>".join(text[i:i + max_len] for i in range(0, len(text), max_len))

    def pick_recommendation(inf, gpt_output):
        """Return the first non-empty recommendation, preferring per-inference keys."""
        overall = gpt_output.get("overall_recommendation")
        # "overall_recommendation" may be a dict, a string or missing; only a
        # dict can carry a nested "recommended_product".
        nested = overall.get("recommended_product") if isinstance(overall, dict) else None
        return (
            inf.get("recommended_product")
            or inf.get("product_recommendation")
            or gpt_output.get("overall_recommended_product")
            or gpt_output.get("overall_product_recommendation")
            or gpt_output.get("recommended_product_combined")
            or nested
        )

    # Nodes are keyed by (layer, text) so equal strings in different layers
    # (e.g. a missing activity and a missing recommendation, or a source name
    # equal to a category label) stay separate nodes instead of merging into
    # one node and producing self-loops / cycles.
    node_index = {}
    labels = []
    link_counts = defaultdict(int)  # (source node id, target node id) -> paths

    def node_id(layer, name):
        """Return the node id for ``name`` in ``layer``, creating it if needed."""
        text = missing if name is None or name == "" else str(name)
        key = (layer, text)
        if key not in node_index:
            node_index[key] = len(labels)
            labels.append(format_label(text))
        return node_index[key]

    for entry in entries:
        gpt_output = entry.get("gpt_output") or {}
        inferences = gpt_output.get("inferences") or []
        # A single-source entry is already a one-element list, so combined_cols
        # can be iterated directly; an empty list is simply skipped.
        for col in entry.get("combined_cols", []):
            if selected_sources and col not in selected_sources:
                continue
            for inf in inferences:
                label = inf.get("label")
                if selected_category and label != selected_category:
                    continue

                node_names = [col, label, pick_recommendation(inf, gpt_output)]
                if selected_category:
                    # Drill-down view: source -> category -> activity -> recommendation.
                    node_names.insert(2, inf.get("activity"))

                ids = [node_id(layer, name) for layer, name in enumerate(node_names)]
                for src, tgt in zip(ids, ids[1:]):
                    link_counts[(src, tgt)] += 1

    if not labels:
        return go.Figure()

    # Scale padding/thickness with the longest raw label text (the <br> markup
    # inserted by format_label is not counted).
    max_label_length = max(len(text) for _, text in node_index)
    node_pad = max(15, min(60, max_label_length * 1.5))
    node_thickness = max(20, min(40, max_label_length * 0.8))

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=node_pad,
            thickness=node_thickness,
            line=dict(color="black", width=0.5),
            label=labels,
            color="skyblue"
        ),
        link=dict(
            source=[src for src, _ in link_counts],
            target=[tgt for _, tgt in link_counts],
            value=list(link_counts.values())
        )
    )])

    fig.update_layout(title_text="Interactive Sankey Diagram", font_size=10)
    return fig

# Update category dropdown based on selected sources
@app.callback(
    Output("category-dropdown", "options"),
    Input("source-dropdown", "value")
)
def update_category_options(selected_sources):
    """Return category dropdown options for the currently selected sources.

    Args:
        selected_sources: List of source columns chosen in the source dropdown,
            or ``None``/empty meaning "all sources".

    Returns:
        A list of ``{"label": c, "value": c}`` dicts, sorted by ``c``, one per
        distinct non-empty inference label in the matching entries. Empty or
        missing labels are skipped: a blank option is meaningless, and an empty
        value would act as "no filter" in the Sankey callback.
    """
    categories = set()
    for entry in sankey_data:
        # Iterate combined_cols directly; an empty list no longer raises IndexError.
        combo_cols = entry.get("combined_cols", [])
        if selected_sources and not any(col in selected_sources for col in combo_cols):
            continue
        inferences = (entry.get("gpt_output") or {}).get("inferences") or []
        categories.update(inf.get("label") for inf in inferences if inf.get("label"))
    return [{"label": c, "value": c} for c in sorted(categories)]

# Update Sankey diagram
@app.callback(
    Output("sankey-graph", "figure"),
    Input("source-dropdown", "value"),
    Input("category-dropdown", "value")
)
def update_sankey(selected_sources, selected_category):
    """Dash callback: redraw the Sankey for the current dropdown values.

    Both values are forwarded unchanged to generate_sankey, where a
    ``None``/empty value means "no filter".
    """
    figure = generate_sankey(
        selected_sources=selected_sources,
        selected_category=selected_category,
    )
    return figure


# Start the Dash server (debug mode) when this code runs as the main module.
if "__main__" == __name__:
    app.run(debug=True)
