# In [1]:
import pandas as pd
import numpy as np

# Load the two knowledge-graph extracts (entities first, then triples) from
# paths relative to the current working directory.
entities, triples = (
    pd.read_csv(csv_path)
    for csv_path in (
        "../data/extract_data/metakg_entities.csv",
        "../data/extract_data/metakg_triples.csv",
    )
)


# In [36]:
import pandas as pd
import numpy as np
from pyecharts import options as opts
from pyecharts.charts import Sankey
from pyecharts.globals import ThemeType
from collections import defaultdict
import random


# Node IDs (all carrying the "hmdb_id:" prefix) whose triples are plotted.
hmdb_list = [
    'hmdb_id:HMDB0001308',
    'hmdb_id:HMDB0062800',
    'hmdb_id:HMDB0001341',
    'hmdb_id:HMDB0000208',
    'hmdb_id:HMDB0062424',
    'hmdb_id:HMDB0014559',
    'hmdb_id:HMDB0000168',
    'hmdb_id:HMDB0000191',
    'hmdb_id:HMDB0000538',
    'hmdb_id:HMDB0000562',
    'hmdb_id:HMDB0000082',
    'hmdb_id:HMDB0001532',
    'hmdb_id:HMDB0000998',
    'hmdb_id:HMDB0000012',
    'hmdb_id:HMDB0001112',
    'hmdb_id:HMDB0001473',
    'hmdb_id:HMDB0001068',
    'hmdb_id:HMDB0001342',
    'hmdb_id:HMDB0001409',
    'hmdb_id:HMDB0001058',
    'hmdb_id:HMDB0001049',
    'hmdb_id:HMDB0001201',
    'hmdb_id:HMDB0008327',
    'hmdb_id:HMDB0001397',
    'hmdb_id:HMDB0003379',
    'hmdb_id:HMDB0015536',
    'hmdb_id:HMDB0000172',
    'hmdb_id:HMDB0000156',
    'hmdb_id:HMDB0002108',
    'hmdb_id:HMDB0001138',
    'hmdb_id:HMDB0006029',
    'hmdb_id:HMDB0011745',
    'hmdb_id:HMDB0001487',
    'hmdb_id:HMDB0000217',
    'hmdb_id:HMDB0000828',
    'hmdb_id:HMDB0250791',
    'hmdb_id:HMDB0001489',
    'hmdb_id:HMDB0000618',
    'hmdb_id:HMDB0029418',
    'hmdb_id:HMDB0060274',
    'hmdb_id:HMDB0000251',
    'hmdb_id:HMDB0000295',
    'hmdb_id:HMDB0000286',
    'hmdb_id:HMDB0000290',
    'hmdb_id:HMDB0000285',
]

# Values of the "Relationship" column that the diagram is restricted to.
select_relations = [
    'has_pathway',
    'has_disease',
    'has_reference',
    'has_tissue_location',
]

# How many of the most frequent "Tail" values are kept for each relation.
num_relations_to_select = 10

def _display_name(node_id):
    """Return the chart label for *node_id*: its last two ':'-separated parts."""
    return ":".join(node_id.split(":")[-2:])


def _category_of(node_id):
    """Return the text before the first ':' (``"hmdb_id"`` when there is none)."""
    return node_id.split(":")[0] if ":" in node_id else "hmdb_id"


def _relation_links(relation_triples, category):
    """Build one Sankey link dict (value 1) per row of *relation_triples*.

    The end of the triple that does not start with ``"hmdb_id:"`` is prefixed
    with *category* so that it can be coloured by relation later on.
    """
    links = []
    for head, tail in zip(relation_triples["Head"], relation_triples["Tail"]):
        if head.startswith("hmdb_id:"):
            source, target = head, f"{category}:{tail}"
        else:
            source, target = f"{category}:{head}", tail
        links.append({"source": source, "target": target, "value": 1})
    return links


def create_sankey_plot(triples, select_relations, hmdb_list, num_relations_to_select):
    """Render a Sankey diagram linking the given metabolites to related entities.

    Args:
        triples: DataFrame with the columns ``"Head"``, ``"Relationship"`` and
            ``"Tail"``.
        select_relations: Relationship names to plot. Each one is also used as
            a colour category and as a legend entry.
        hmdb_list: Node IDs of interest; a triple is kept when its head or its
            tail is in this collection.
        num_relations_to_select: For every relation, only triples whose tail is
            among this many most frequent tails are plotted.

    Returns:
        Whatever ``Sankey.render`` returns for the written file
        ``metabolite_relationships_sankey.html``.
    """
    triples = triples[triples["Relationship"].isin(select_relations)]
    # A missing head/tail cannot be classified with ``str.startswith`` and
    # would otherwise surface as a literal "nan" node.
    triples = triples.dropna(subset=["Head", "Tail"])
    involves_metabolite = triples["Head"].isin(hmdb_list) | triples["Tail"].isin(hmdb_list)

    all_links = []
    for relation in select_relations:
        relation_triples = triples[(triples["Relationship"] == relation) & involves_metabolite]
        # Honour the caller's limit (it used to be overwritten with 10 here).
        top_tails = relation_triples["Tail"].value_counts().index[:num_relations_to_select]
        top_n_relation = relation_triples[relation_triples["Tail"].isin(top_tails)]
        all_links.extend(_relation_links(top_n_relation, relation))

    # A dict keeps first-seen order, so the node list is the same on every run.
    all_nodes = {}
    for link in all_links:
        all_nodes.setdefault(link["source"])
        all_nodes.setdefault(link["target"])

    # The project -> metabolite flow equals the number of links leaving it.
    metabolite_values = defaultdict(int)
    for link in all_links:
        if link["source"].startswith("hmdb_id:"):
            metabolite_values[link["source"]] += link["value"]

    project_node = "project_id:001"
    metabolite_nodes = [node for node in all_nodes if node.startswith("hmdb_id:")]
    all_nodes.setdefault(project_node)
    for metabolite in metabolite_nodes:
        all_links.append(
            {"source": project_node, "target": metabolite, "value": metabolite_values[metabolite]}
        )

    color_palette = [
        "#4ECDC4",  # aqua green
        "#FF6F61",  # coral
        "#6A5ACD",  # azalea purple
        "#FFB400",  # bright yellow
        "#FF6F91",  # light red
        "#1A1A1D",  # dark gray
        "#F7B7A3",  # light pink
        "#C7D2FE",  # light blue
        "#FF9A00",  # orange
        "#A0D9CE",  # light green
    ]
    categories = ["project_id", "hmdb_id", *select_relations]
    color_map = {
        category: color_palette[i % len(color_palette)]
        for i, category in enumerate(categories)
    }

    # Different relations can yield the same display name once the relation
    # prefix is dropped; keep only the first so every label is listed once.
    nodes = []
    seen_names = set()
    for node_id in all_nodes:
        name = _display_name(node_id)
        if name in seen_names:
            continue
        seen_names.add(name)
        nodes.append({"name": name, "itemStyle": {"color": color_map[_category_of(node_id)]}})

    for link in all_links:
        # Colour is taken from the prefixed target before it is shortened.
        link["lineStyle"] = {"color": color_map[_category_of(link["target"])]}
        # Both ends must use the display name, otherwise a link whose source
        # carries a relation prefix refers to a node that is not in `nodes`.
        link["source"] = _display_name(link["source"])
        link["target"] = _display_name(link["target"])

    sankey = Sankey(init_opts=opts.InitOpts(width="1200px", height="800px", theme=ThemeType.LIGHT))

    # Empty series, one per category, exist only to produce coloured legend entries.
    for category, color in color_map.items():
        sankey.add(
            series_name=category,
            nodes=[],
            links=[],
            label_opts=opts.LabelOpts(color=color),
            linestyle_opt=opts.LineStyleOpts(opacity=0.3, curve=0.5, color=color),
            itemstyle_opts=opts.ItemStyleOpts(color=color),
        )

    sankey.add(
        series_name="Metabolite Relationships",
        nodes=nodes,
        links=all_links,
        linestyle_opt=opts.LineStyleOpts(opacity=0.2, curve=0.5),
        label_opts=opts.LabelOpts(position="right"),
        node_gap=4,
    )

    sankey.set_global_opts(
        title_opts=opts.TitleOpts(title="Metabolite Relationships Sankey Diagram"),
        tooltip_opts=opts.TooltipOpts(trigger="item", trigger_on="mousemove"),
        legend_opts=opts.LegendOpts(
            orient="horizontal",
            pos_left="30%",
            pos_top="top",
            item_width=20,
            item_height=15,
            textstyle_opts=opts.TextStyleOpts(font_size=12),
        ),
    )

    return sankey.render("metabolite_relationships_sankey.html")

# Output: '/Users/colton/metakg-ori/analysis/nature metab/metabolite_relationships_sankey.html'