<a href="https://colab.research.google.com/github/KravitzLab/FED3Analyses/blob/PCA_analysis/mgi_heiarchy_create.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install owlready2
!pip install --upgrade jinja2 pyvis
!pip install pygraphviz
!pip install pronto
!pip install faiss-cpu

In [None]:
# @title Install and Import libraries

#!pip install pronto
#!pip install pyvis
#!pip install sentence-transformers

import os, re, zipfile
import pandas as pd
from bs4 import BeautifulSoup
from google.colab import files
import csv
import pronto
import pyvis
import ipywidgets as widgets
import networkx as nx
from pyvis.network import Network
from google.colab import files
import networkx as nx
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import csv
from sentence_transformers import SentenceTransformer, util
from collections import Counter
import nltk
from nltk.corpus import stopwords
import html
import faiss
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
import numpy as np

In [None]:
# @title Upload Ontology File


#### Upload the ontology file through the Colab file picker ####
uploaded = files.upload()
filename = list(uploaded)[0]  # name of the first uploaded file

#### Parse the uploaded file with pronto ####
mp = pronto.Ontology(filename)


# Root terms: terms whose distance-1 superclass walk yields nothing
# (a single root is expected for MGI).
# NOTE(review): the counting helpers in this cell subtract one "for itself"
# from superclasses()/subclasses(); if the distance-1 walk also yields the
# term itself, this list is always empty -- confirm against the pronto docs.
roots = [
    term
    for term in mp.terms()
    if len(list(term.superclasses(distance=1))) == 0
]


# Depth function to calculate the minimum distance to a root
def get_depth(term):
    """Return the minimum number of parent steps from `term` up to a root.

    A root is a term that has no direct superclass other than itself. The
    walk goes upward one level at a time through `superclasses(distance=1)`,
    dropping the term itself from each result because that walk can include
    it. The first level that contains a root is the minimum depth.

    This does not rely on the module-level `roots` list, which ends up empty
    whenever the distance-1 walk yields the term itself (depth would then
    always be reported as 0).

    Parameters
    ----------
    term : object exposing `.id` and `.superclasses(distance=1)`

    Returns
    -------
    int
        0 for a root; 0 as well if no root is reachable (cyclic input).
    """
    frontier = [term]
    seen = {term.id}
    depth = 0

    while frontier:
        next_frontier = []
        for node in frontier:
            direct_parents = [
                p for p in node.superclasses(distance=1) if p.id != node.id
            ]
            if not direct_parents:
                # `node` is a root; levels are visited in order, so this is minimal
                return depth
            for parent in direct_parents:
                if parent.id not in seen:
                    seen.add(parent.id)
                    next_frontier.append(parent)
        frontier = next_frontier
        depth += 1

    return 0

# Ancestor / descendant counting helpers
def count_ancestors(term):
    """Count what `term.superclasses()` yields, minus one for the term itself."""
    total = sum(1 for _ in term.superclasses())
    return total - 1

def count_descendants(term):
    """Count what `term.subclasses()` yields, minus one for the term itself."""
    total = sum(1 for _ in term.subclasses())
    return total - 1

def is_leaf(term):
    """Return True when `term` has no direct subclass other than itself.

    The distance-1 subclass walk can include the term itself (the counting
    helpers in this cell subtract one for it), so testing for an empty walk
    would never report a leaf. Comparing ids gives the right answer whether
    or not the walk includes the term.
    """
    return all(sub.id == term.id for sub in term.subclasses(distance=1))


#### Output CSV ####
# One row per parent -> child edge of the entire ontology, with per-term metadata.
output_file = "ontology_edges_with_metadata.csv"

with open(output_file, "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow([
        "parent_id", "parent_label", "parent_definition", "parent_depth",
        "child_id", "child_label", "child_definition", "child_depth",
        "child_is_leaf", "num_ancestors_child", "num_descendants_child"
    ])

    for term in mp.terms():
        # The distance-1 walk can include the term itself; drop it so the file
        # contains no self-referential edges.
        parents = [p for p in term.superclasses(distance=1) if p.id != term.id]
        if not parents:
            continue  # no parent edge to write for this term

        # Child-side values are the same for every parent edge: compute them once.
        child_fields = [
            term.id,
            term.name,
            term.definition or "",
            get_depth(term),
            "yes" if is_leaf(term) else "no",
            count_ancestors(term),
            count_descendants(term),
        ]

        for parent in parents:
            writer.writerow([
                parent.id,
                parent.name,
                parent.definition or "",
                get_depth(parent),
                *child_fields,
            ])


### Download the parent-child ontologies ###
try:
    from google.colab import files as gfiles
except Exception:
    # Colab helper unavailable: the button only reports the local path.
    gfiles = None

btn = widgets.Button(description=f"Download {os.path.basename(output_file)}", icon="download")
status = widgets.HTML()


def _dl(_, path=output_file, status=status, gfiles=gfiles):
    """Button callback: download `path` via Colab, or report where it was saved.

    `path`, `status` and `gfiles` are bound as default arguments when this
    cell runs. Other cells reassign the globals `output_file` and `status`,
    so reading them at click time would make this button act on whichever
    file was written last.
    """
    if gfiles is not None:
        status.value = f"Starting download: <code>{os.path.basename(path)}</code>…"
        gfiles.download(path)
    else:
        status.value = f"Saved locally at <code>{path}</code>."


btn.on_click(_dl)
display(btn, status)





In [None]:
# @title Trim the ontology
# Define where the root starts and how many ancestor levels to include.

# Crop the network to the behavior subtree, e.g. limited to 4 ancestors (abnormal behavior is 2 levels down)
# Load the full edge table written by the export cell
df = pd.read_csv("ontology_edges_with_metadata.csv")

# Root of the subtree to keep
#   mammalian phenotype: MP:0000001
#   abnormal behavior:   MP:0004924
root_id = 'MP:0004924'
trim_root = mp[root_id]

# Upper bound on num_ancestors_child; the printed maximum shows which value
# keeps every row
print(df['num_ancestors_child'].max())
anc_level = 36

# IDs in the subtree: the root plus everything its subclass walk yields
trimmed_descendants = {trim_root.id} | {t.id for t in trim_root.subclasses()}

# Keep only edges whose parent AND child are both inside the subtree
in_subtree = (
    df["parent_id"].isin(trimmed_descendants)
    & df["child_id"].isin(trimmed_descendants)
)
df_trimmed = df.loc[in_subtree]
print(df_trimmed.columns)

df_trimmed = (
    df_trimmed
    # limit the number of ancestor levels
    .loc[lambda d: d["num_ancestors_child"] <= anc_level]
    # drop self-referential edges (by id, then by label)
    .loc[lambda d: d["parent_id"] != d["child_id"]]
    .loc[lambda d: d["parent_label"] != d["child_label"]]
    .drop_duplicates(subset=["parent_id", "child_id"])
)


#### Save the trimmed parent-child edges ####
output_file = "ontology_edges_trimmed.csv"
df_trimmed.to_csv(output_file, index=False)

try:
    from google.colab import files as gfiles
except Exception:
    # Colab helper unavailable: the button only reports the local path.
    gfiles = None

btn = widgets.Button(description=f"Download {os.path.basename(output_file)}", icon="download")
status = widgets.HTML()


def _dl(_, path=output_file, status=status, gfiles=gfiles):
    """Button callback: download `path` via Colab, or report where it was saved.

    `path`, `status` and `gfiles` are bound as default arguments when this
    cell runs. Other cells reassign the globals `output_file` and `status`,
    so reading them at click time would make this button act on whichever
    file was written last.
    """
    if gfiles is not None:
        status.value = f"Starting download: <code>{os.path.basename(path)}</code>…"
        gfiles.download(path)
    else:
        status.value = f"Saved locally at <code>{path}</code>."


btn.on_click(_dl)
display(btn, status)

In [None]:
# @title Retrieve all the leaf nodes
# Work on a copy of the trimmed edge table
df_leafs = df_trimmed.copy()

# Every ID that appears as a parent / as a child
all_parents = set(df_leafs["parent_id"])
all_children = set(df_leafs["child_id"])

# Leaf nodes are children that never appear as a parent
leaf_nodes = all_children.difference(all_parents)

# One row per leaf with its label and definition
# (the row labels of df_trimmed are kept, so use positional access downstream)
leaf_rows = df_leafs["child_id"].isin(leaf_nodes)
leaf_labels = (
    df_leafs
    .loc[leaf_rows, ["child_id", "child_label", "child_definition"]]
    .drop_duplicates()
)

### Save leaf labels ###
output_file = "leaf_nodes.csv"
leaf_labels.to_csv(output_file, index=False)

try:
    from google.colab import files as gfiles
except Exception:
    # Colab helper unavailable: the button only reports the local path.
    gfiles = None

btn = widgets.Button(description=f"Download {os.path.basename(output_file)}", icon="download")
status = widgets.HTML()


def _dl(_, path=output_file, status=status, gfiles=gfiles):
    """Button callback: download `path` via Colab, or report where it was saved.

    `path`, `status` and `gfiles` are bound as default arguments when this
    cell runs. Other cells reassign the globals `output_file` and `status`,
    so reading them at click time would make this button act on whichever
    file was written last.
    """
    if gfiles is not None:
        status.value = f"Starting download: <code>{os.path.basename(path)}</code>…"
        gfiles.download(path)
    else:
        status.value = f"Saved locally at <code>{path}</code>."


btn.on_click(_dl)
display(btn, status)

In [None]:
# @title OPTIONAL - Find the most common words in leafs of MGI branch

# Make sure the NLTK stopword list is available
nltk.download('stopwords')

# Label and definition of every leaf, joined per row (missing values -> empty text)
leaf_label_text = leaf_labels["child_label"].fillna('')
leaf_definition_text = leaf_labels["child_definition"].fillna('')
text_data = leaf_label_text + " " + leaf_definition_text

# One lower-cased string, with everything except a-z and whitespace blanked out
all_text = re.sub(r'[^a-z\s]', ' ', " ".join(text_data.tolist()).lower())

# Whitespace tokenisation, then drop English stopwords and words of 1-2 letters
stop_words = set(stopwords.words("english"))
words = [
    token
    for token in all_text.split()
    if token not in stop_words and len(token) > 2
]

# Frequency table and its 50 most common entries
word_counts = Counter(words)
top_words = word_counts.most_common(50)

top_words

In [None]:
# @title Map Metrics onto the Ontology

#### Upload defined metrics ####
uploaded = files.upload()
filename = list(uploaded.keys())[0]
df_metrics = pd.read_csv(filename, encoding="latin1")

# Remove rows without a metric definition. The index reset keeps the row labels
# equal to the row positions of the similarity matrix used in the loop below.
df_metrics = df_metrics.dropna(subset=['metric_definition']).reset_index(drop=True)

# Sentence-embedding model
# model = SentenceTransformer('all-MiniLM-L6-v2')
model = SentenceTransformer('all-mpnet-base-v2')


# Columns whose name contains "keyword" are merged into one "keywords_all" column
pattern = "keyword"
matching_columns = df_metrics.columns[df_metrics.columns.str.contains(pattern)]

if len(matching_columns) == 0:
    print("no matching columns")
    df_metrics["keywords_all"] = ""
else:
    df_metrics["keywords_all"] = (
        df_metrics[matching_columns]
        .apply(lambda row: '. '.join(row.dropna().astype(str)), axis=1)
    )

# Only the definition is embedded (metric name / keywords are currently left out)
metric_texts = df_metrics["metric_definition"].astype(str).tolist()

# Leaf label + definition. Missing values are replaced by empty text first:
# a NaN definition would otherwise put a float NaN into the list handed to
# the encoder.
onto_texts = (
    leaf_labels["child_label"].fillna("").astype(str)
    + ". "
    + leaf_labels["child_definition"].fillna("").astype(str)
).tolist()

# Embeddings and the metric x leaf cosine-similarity matrix
metric_emb = model.encode(metric_texts, convert_to_tensor=True)
onto_emb = model.encode(onto_texts, convert_to_tensor=True)
sim_matrix = util.cos_sim(metric_emb, onto_emb)


# Keep, per metric, the top_k most similar leaves that reach the threshold
threshold = 0.50
top_k = 5
# topk() raises if asked for more entries than there are leaf terms
k_eff = min(top_k, sim_matrix.shape[1])

match_columns = [
    "metric_id", "metric_name", "metric_keywords", "metric_definition",
    "ontology_id", "ontology_term", "ontology_definition", "similarity",
]
filtered_matches = []

for i, mrow in df_metrics.iterrows():
    best = sim_matrix[i].topk(k_eff)

    for score_val, idx in zip(best.values.tolist(), best.indices.tolist()):
        if score_val < threshold:
            continue

        # leaf_labels keeps the row labels of the trimmed table, so index by position
        filtered_matches.append({
            "metric_id": mrow["metric_id"],
            "metric_name": mrow["metric_name"],
            "metric_keywords": mrow["keywords_all"],
            "metric_definition": mrow["metric_definition"],
            "ontology_id": leaf_labels["child_id"].iloc[idx],
            "ontology_term": leaf_labels["child_label"].iloc[idx],
            "ontology_definition": leaf_labels["child_definition"].iloc[idx],
            "similarity": score_val
        })

# Explicit columns so the frame has the expected schema even with zero matches
filtered_df = pd.DataFrame(filtered_matches, columns=match_columns)
display(filtered_df.head())


#### Save mapped metrics ####
output_file = "onto_metrics_mapped.csv"
filtered_df.to_csv(output_file, index=False)

try:
    from google.colab import files as gfiles
except Exception:
    # Colab helper unavailable: the button only reports the local path.
    gfiles = None

btn = widgets.Button(description=f"Download {os.path.basename(output_file)}", icon="download")
status = widgets.HTML()


def _dl(_, path=output_file, status=status, gfiles=gfiles):
    """Button callback: download `path` via Colab, or report where it was saved.

    `path`, `status` and `gfiles` are bound as default arguments when this
    cell runs. Other cells reassign the globals `output_file` and `status`,
    so reading them at click time would make this button act on whichever
    file was written last.
    """
    if gfiles is not None:
        status.value = f"Starting download: <code>{os.path.basename(path)}</code>…"
        gfiles.download(path)
    else:
        status.value = f"Saved locally at <code>{path}</code>."


btn.on_click(_dl)
display(btn, status)


In [None]:
# @title Retrieval matching 1 - Map metrics onto Ontology

############################ Upload defined metrics ############################
# Two CSV files are needed: the metric definitions and the behavior knowledge
# (a "text" column) that is indexed next to the ontology.
uploaded = files.upload()
uploaded_names = list(uploaded.keys())
if len(uploaded_names) < 2:
    raise ValueError(
        "Expected 2 uploaded files (metrics CSV, then behavior info CSV), "
        f"got {len(uploaded_names)}: {uploaded_names}"
    )
# NOTE(review): roles are assigned by position in the upload dict -- confirm
# that the metrics file really comes first.
filename = uploaded_names[0]
behavior_name = uploaded_names[1]

df_metrics = pd.read_csv(filename, encoding="latin1")
bev_info_df = pd.read_csv(behavior_name, encoding="latin1")

# copy the ontology edge rows (one row per parent -> child edge)
meta_columns = [
    "child_id",
    "child_label",
    "child_definition",
    "parent_id",
    "parent_label",
    "parent_definition"
]
onto_df = df_trimmed[meta_columns].copy()

# Remove rows without a metric definition; the index reset keeps row labels
# equal to the row positions of metric_emb used in the matching loop.
df_metrics = df_metrics.dropna(subset=['metric_definition']).reset_index(drop=True)

######################## Embedding model ########################
# model = SentenceTransformer('all-MiniLM-L6-v2')
rag_model = SentenceTransformer('all-mpnet-base-v2')

# Behavior knowledge texts (missing rows dropped so no "nan" text is embedded)
behavior_texts = bev_info_df["text"].dropna().astype(str).tolist()

# Ontology text per row: label + definition, missing values as empty text
onto_df["embed_text"] = (
    onto_df["child_label"].fillna("").astype(str)
    + ". "
    + onto_df["child_definition"].fillna("").astype(str)
)
onto_texts = onto_df["embed_text"].tolist()

# metadata list in the same order as the ontology embeddings
onto_meta = onto_df[meta_columns].to_dict("records")
print(f"{len(onto_meta)} ontology rows prepared for embedding")

############################### Build Embeddings ###############################
# Behavior texts first, then the ontology texts
all_texts = behavior_texts + onto_texts

# Unit-length embeddings for the whole knowledge base
kb_emb = rag_model.encode(all_texts, convert_to_numpy=True, normalize_embeddings=True)
dimension = kb_emb.shape[1]

# Build FAISS index
index = faiss.IndexFlatL2(dimension)
index.add(kb_emb.astype("float32"))


# Columns whose name contains "keyword" are merged into one "keywords_all" column
pattern = "keyword"
matching_columns = df_metrics.columns[df_metrics.columns.str.contains(pattern)]

if len(matching_columns) == 0:
    df_metrics["keywords_all"] = ""
else:
    df_metrics["keywords_all"] = (
        df_metrics[matching_columns]
        .apply(lambda row: ". ".join(row.dropna().astype(str)), axis=1)
    )

# Only the definition is embedded (metric name / keywords are currently left out)
df_metrics["metric_text"] = df_metrics["metric_definition"].astype(str)
metric_texts = df_metrics["metric_text"].tolist()

# Embed the definitions
metric_emb = rag_model.encode(metric_texts, convert_to_numpy=True, normalize_embeddings=True)


#### Search the index for every metric ####
top_k = 5
threshold = 0.50

match_columns = [
    "metric_id", "metric_name", "metric_keywords", "metric_definition",
    "ontology_child_id", "ontology_child_label", "ontology_child_definition",
    "ontology_parent_id", "ontology_parent_label", "similarity",
]
filtered_matches = []

# ontology entries start after the behavior entries
start_index_onto = len(behavior_texts)

for i, mrow in df_metrics.iterrows():

    q_emb = metric_emb[i:i+1]  # FAISS expects shape (1, dim)
    D, I = index.search(q_emb.astype("float32"), k=top_k)

    for score, idx in zip(D[0], I[0]):

        # Skip behavior-knowledge hits; this also skips the -1 padding FAISS
        # returns when the index holds fewer than top_k vectors.
        if idx < start_index_onto:
            continue

        onto_idx = int(idx) - start_index_onto  # index into onto_df/onto_meta
        if onto_idx >= len(onto_meta):
            continue

        # IndexFlatL2 reports squared L2 distance; for unit-length vectors
        # cosine = 1 - d/2. This is the same conversion the other retrieval
        # cells use, and at threshold 0.50 it selects the same rows as 1/(1+d).
        similarity = float(1 - (score / 2))

        if similarity >= threshold:
            meta = onto_meta[onto_idx]

            filtered_matches.append({
                "metric_id": mrow["metric_id"],
                "metric_name": mrow["metric_name"],
                "metric_keywords": mrow["keywords_all"],
                "metric_definition": mrow["metric_definition"],
                "ontology_child_id": meta["child_id"],
                "ontology_child_label": meta["child_label"],
                "ontology_child_definition": meta["child_definition"],
                "ontology_parent_id": meta["parent_id"],
                "ontology_parent_label": meta["parent_label"],
                "similarity": similarity
            })


# Explicit columns so the frame has the expected schema even with zero matches
filtered_df = pd.DataFrame(filtered_matches, columns=match_columns)
display(filtered_df.head())



#### Save mapped metrics ####
output_file = "onto_metrics_mapped_rag.csv"
filtered_df.to_csv(output_file, index=False)

try:
    from google.colab import files as gfiles
except Exception:
    # Colab helper unavailable: the button only reports the local path.
    gfiles = None

btn = widgets.Button(description=f"Download {os.path.basename(output_file)}", icon="download")
status = widgets.HTML()


def _dl(_, path=output_file, status=status, gfiles=gfiles):
    """Button callback: download `path` via Colab, or report where it was saved.

    `path`, `status` and `gfiles` are bound as default arguments when this
    cell runs. Other cells reassign the globals `output_file` and `status`,
    so reading them at click time would make this button act on whichever
    file was written last.
    """
    if gfiles is not None:
        status.value = f"Starting download: <code>{os.path.basename(path)}</code>…"
        gfiles.download(path)
    else:
        status.value = f"Saved locally at <code>{path}</code>."


btn.on_click(_dl)
display(btn, status)


In [None]:
# @title Retrieval matching 2
######################### Upload metrics & RAG info ############################
uploaded = files.upload()
uploaded_names = list(uploaded.keys())
if not uploaded_names:
    raise ValueError("No file uploaded; a metrics CSV is required.")

filename = uploaded_names[0]
df_metrics = pd.read_csv(filename, encoding="latin1")

# The index built in this cell holds ontology texts only, so the behavior info
# file is optional here; it is still loaded when a second file is supplied.
if len(uploaded_names) > 1:
    behavior_name = uploaded_names[1]
    bev_info_df = pd.read_csv(behavior_name, encoding="latin1")
else:
    behavior_name = None
    bev_info_df = pd.DataFrame({"text": []})

# copy ontology edge rows (one row per parent -> child edge)
onto_df = df_trimmed[[
    "child_id", "child_label", "child_definition",
    "parent_id", "parent_label", "parent_definition"
]].copy()

# Remove rows without a metric definition; the index reset keeps row labels
# equal to the row positions of metric_emb used in the matching loop.
df_metrics = df_metrics.dropna(subset=['metric_definition']).reset_index(drop=True)

############################### Embedding model ##################################
rag_model = SentenceTransformer('all-mpnet-base-v2')

### Prepare ontology text (missing values as empty text, not the string "nan") ###
onto_df["embed_text"] = (
    onto_df["child_label"].fillna("").astype(str)
    + ". " +
    onto_df["child_definition"].fillna("").astype(str)
)

onto_texts = onto_df["embed_text"].tolist()

### Build ontology metadata (same order as the embeddings) ###
onto_meta = onto_df.to_dict("records")

### Encode ontology as unit-length vectors ###
onto_emb = rag_model.encode(
    onto_texts, convert_to_numpy=True, normalize_embeddings=True
)

### Build FAISS index for ontology ONLY ###
dimension = onto_emb.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(onto_emb.astype("float32"))


########################### Build metric embedding #############################

# Columns whose name contains "keyword" are merged into one "keywords_all" column
pattern = "keyword"
matching_columns = df_metrics.columns[df_metrics.columns.str.contains(pattern)]

if len(matching_columns) == 0:
    df_metrics["keywords_all"] = ""
else:
    df_metrics["keywords_all"] = (
        df_metrics[matching_columns]
            .apply(lambda row: ". ".join(row.dropna().astype(str)), axis=1)
    )

# Only the definition is embedded (metric name / keywords are currently left out)
df_metrics["metric_text"] = df_metrics["metric_definition"].astype(str)

metric_texts = df_metrics["metric_text"].tolist()

metric_emb = rag_model.encode(
    metric_texts, convert_to_numpy=True, normalize_embeddings=True
)

############################## Matching loop ###################################
top_k = 5
threshold = 0.50
match_columns = [
    "metric_id", "metric_name", "metric_keywords", "metric_definition",
    "ontology_child_id", "ontology_child_label", "ontology_child_definition",
    "ontology_parent_id", "ontology_parent_label", "similarity",
]
filtered_matches = []

for i, mrow in df_metrics.iterrows():
    q_emb = metric_emb[i:i+1]  # shape (1, dim)
    D, I = index.search(q_emb.astype("float32"), k=top_k)

    for score, idx in zip(D[0], I[0]):
        # FAISS pads with -1 when the index holds fewer than top_k vectors;
        # onto_meta[-1] would silently pick the last ontology row.
        if idx < 0:
            continue

        # squared L2 distance between unit vectors -> cosine similarity
        similarity = float(1 - (score / 2))

        if similarity < threshold:
            continue

        meta = onto_meta[int(idx)]

        filtered_matches.append({
            "metric_id": mrow["metric_id"],
            "metric_name": mrow["metric_name"],
            "metric_keywords": mrow["keywords_all"],
            "metric_definition": mrow["metric_definition"],
            "ontology_child_id": meta["child_id"],
            "ontology_child_label": meta["child_label"],
            "ontology_child_definition": meta["child_definition"],
            "ontology_parent_id": meta["parent_id"],
            "ontology_parent_label": meta["parent_label"],
            "similarity": similarity
        })

# Explicit columns so the frame has the expected schema even with zero matches
filtered_df = pd.DataFrame(filtered_matches, columns=match_columns)
display(filtered_df.head())


#### Save mapped metrics ####
output_file = "onto_metrics_mapped_rag.csv"
filtered_df.to_csv(output_file, index=False)

try:
    from google.colab import files as gfiles
except Exception:
    # Colab helper unavailable: the button only reports the local path.
    gfiles = None

btn = widgets.Button(description=f"Download {os.path.basename(output_file)}", icon="download")
status = widgets.HTML()


def _dl(_, path=output_file, status=status, gfiles=gfiles):
    """Button callback: download `path` via Colab, or report where it was saved.

    `path`, `status` and `gfiles` are bound as default arguments when this
    cell runs. Other cells reassign the globals `output_file` and `status`,
    so reading them at click time would make this button act on whichever
    file was written last.
    """
    if gfiles is not None:
        status.value = f"Starting download: <code>{os.path.basename(path)}</code>…"
        gfiles.download(path)
    else:
        status.value = f"Saved locally at <code>{path}</code>."


btn.on_click(_dl)
display(btn, status)


In [None]:
# @title Retrieval Matching 3
# inject embeddings that describe attributes
from google.colab import files
###################### Upload metrics & injectable info ########################
uploaded = files.upload()

# Each uploaded file is assigned a role by substring match on its lower-cased name.
# NOTE(review): a name containing substrings of two roles is counted for both.
file_roles = {
    "metrics": ["metric", "defined_metrics"],
    "info": ["information", "info"],
    "target": ["target", "targets"]
}
identified = {k: [] for k in file_roles}

# classify files
for fname in uploaded.keys():
    lname = fname.lower()
    for role, patterns in file_roles.items():
        if any(p in lname for p in patterns):
            identified[role].append(fname)

# validate: exactly one file per role.
# The loop variable must not be called `files`: that would replace the
# google.colab `files` module and break every later `files.upload()` call.
for role, matched_files in identified.items():
    if len(matched_files) != 1:
        raise ValueError(
            f"Expected exactly 1 file for '{role}', found {len(matched_files)}: {matched_files}"
        )

# load
df_metrics = pd.read_csv(identified["metrics"][0], encoding="latin1")
bev_info_df = pd.read_csv(identified["info"][0], encoding="latin1")
target_df = pd.read_csv(identified["target"][0], encoding="latin1")


############## Load the sentence-transformer embedding model ############
embedding_model = SentenceTransformer('all-mpnet-base-v2')

##################### Define ontology texts for embedding ######################
# Parent nodes
parents = df_trimmed[[
    "parent_id",
    "parent_label",
    "parent_definition"
]].rename(columns={
    "parent_id": "node_id",
    "parent_label": "label",
    "parent_definition": "definition"
})

# Child nodes
children = df_trimmed[[
    "child_id",
    "child_label",
    "child_definition"
]].rename(columns={
    "child_id": "node_id",
    "child_label": "label",
    "child_definition": "definition"
})

# Combine and remove duplicates to get one row per unique node
onto_nodes = (
    pd.concat([parents, children], ignore_index=True)
      .drop_duplicates(subset=["node_id"])
      .reset_index(drop=True)
)

# unique parent -> child ID pairs
onto_edges = df_trimmed[[
    "parent_id",
    "child_id"
]].drop_duplicates()

# Text to embed: label + definition, missing values as empty text (not "nan")
onto_nodes["embed_text"] = (
    onto_nodes["label"].fillna("").astype(str) + ". " +
    onto_nodes["definition"].fillna("").astype(str)
)

# Unit-length embeddings of the ontology nodes
onto_texts = onto_nodes["embed_text"].tolist()
onto_emb = embedding_model.encode(
    onto_texts,
    convert_to_numpy=True,
    normalize_embeddings=True
)

# node metadata, in the same order as onto_emb
onto_meta = onto_nodes[["node_id", "label", "definition"]].to_dict("records")


######## inject behavior information and ontology nodes into embeddings ########
# behavior texts (missing rows dropped, values coerced to str for the encoder)
behavior_texts = bev_info_df["text"].dropna().astype(str).tolist()

all_texts = behavior_texts + onto_texts

# Knowledge base = behavior embeddings followed by the ontology embeddings.
# The ontology vectors computed above are reused instead of being encoded again.
if behavior_texts:
    behavior_emb = embedding_model.encode(
        behavior_texts,
        convert_to_numpy=True,
        normalize_embeddings=True
    )
    kb_emb = np.vstack([behavior_emb, onto_emb])
else:
    kb_emb = onto_emb

index = faiss.IndexFlatL2(kb_emb.shape[1])
index.add(kb_emb.astype("float32"))
# position of the first ontology vector inside the index
ontology_start = len(behavior_texts)


##################### Build the embeddings for the metrics #####################
# Keep only rows that have both a definition and an interpretation; the index
# reset keeps row labels equal to the row positions of metric_emb.
df_metrics = df_metrics.dropna(subset=['metric_definition', "metric_interpretation"]).reset_index(drop=True)


# Columns whose name contains "keyword" are merged into one "keywords_all" column
pattern = "keyword"
matching_columns = df_metrics.columns[df_metrics.columns.str.contains(pattern)]

if len(matching_columns) == 0:
    df_metrics["keywords_all"] = ""
else:
    df_metrics["keywords_all"] = (
        df_metrics[matching_columns]
            .apply(lambda row: ". ".join(row.dropna().astype(str)), axis=1)
    )

# Metric text used for embedding: "<definition>: <interpretation>"
# (metric name / keywords are currently left out)
df_metrics["metric_text"] = (
    df_metrics["metric_definition"].astype(str)
    + ": " + df_metrics["metric_interpretation"].astype(str)
)


metric_texts = df_metrics["metric_text"].tolist()
print(f"{len(metric_texts)} metric texts to embed")

metric_emb = embedding_model.encode(
    metric_texts,
    convert_to_numpy=True,
    normalize_embeddings=True
)



################## Iterate to compare metrics and ontologies ###################
top_k = 8
threshold = 0.50
result_columns = [
    "metric_id", "metric_name", "metric_definition",
    "ontology_node_id", "ontology_label", "ontology_definition", "similarity",
]
results = []

# iterate through the metrics and map to ontology
for i, mrow in df_metrics.iterrows():

    q_emb = metric_emb[i:i+1]
    D, I = index.search(q_emb.astype("float32"), top_k)

    for dist, idx in zip(D[0], I[0]):

        # Skip behavior-knowledge hits (and the -1 padding FAISS returns when
        # the index holds fewer than top_k vectors)
        if idx < ontology_start:
            continue

        node_idx = int(idx) - ontology_start
        # squared L2 distance between unit vectors -> cosine similarity
        similarity = float(1 - (dist / 2))

        if similarity >= threshold:
            node = onto_meta[node_idx]

            results.append({
                "metric_id": mrow.metric_id,
                "metric_name": mrow.metric_name,
                "metric_definition": mrow.metric_definition,
                "ontology_node_id": node["node_id"],
                "ontology_label": node["label"],
                "ontology_definition": node["definition"],
                "similarity": similarity
            })

# Explicit columns so the frame has the expected schema even with zero matches
filtered_df = pd.DataFrame(results, columns=result_columns)
display(filtered_df.head())


#### Save mapped metrics ####
output_file = "onto_metrics_mapped_rag.csv"
filtered_df.to_csv(output_file, index=False)

try:
    from google.colab import files as gfiles
except Exception:
    # Colab helper unavailable: the button only reports the local path.
    gfiles = None

btn = widgets.Button(description=f"Download {os.path.basename(output_file)}", icon="download")
status = widgets.HTML()


def _dl(_, path=output_file, status=status, gfiles=gfiles):
    """Button callback: download `path` via Colab, or report where it was saved.

    `path`, `status` and `gfiles` are bound as default arguments when this
    cell runs. Other cells reassign the globals `output_file` and `status`,
    so reading them at click time would make this button act on whichever
    file was written last.
    """
    if gfiles is not None:
        status.value = f"Starting download: <code>{os.path.basename(path)}</code>…"
        gfiles.download(path)
    else:
        status.value = f"Saved locally at <code>{path}</code>."


btn.on_click(_dl)
display(btn, status)



In [None]:
# @title OPTIONAL - compare model setups


# Load the comparison table (one row per match, with a 'model' column)
uploaded = files.upload()
filename = list(uploaded.keys())[0]

comp_df = pd.read_csv(filename, encoding="latin1")


# Get all unique model types
models = comp_df['model'].unique()
print(f"\nModels found: {len(models)}")

# The loop variable is deliberately not called `model`: that name holds the
# SentenceTransformer created by the metric-mapping cell.
for i, model_name in enumerate(models, 1):
    df_m = comp_df[comp_df['model'] == model_name]

    count = len(df_m)
    unique_metrics = df_m['metric_id'].nunique()

    # Rows with similarity > 0.7
    df_thr = df_m[df_m['similarity'] > 0.7]
    count_thr = df_thr['metric_id'].count()
    unique_thr = df_thr['metric_id'].nunique()

    print(f"  {i}. {model_name}: {count} matches, {unique_metrics} unique metrics")
    print(f"     >0.7 similarity: {count_thr} matches, {unique_thr} unique metrics")

# One dataframe per model
model_dfs = {
    model_name: comp_df[comp_df['model'] == model_name].copy()
    for model_name in models
}

#### overlap analysis ####
# TODO: the overlap matrix is allocated here but not filled in this cell
overlap_matrix = pd.DataFrame(index=models, columns=models, dtype=float)
shared_pairs_dict = {}



#### plots ####
n_models = len(models)
fig = plt.figure(figsize=(18, 12))

# Plots 1 and 2 (similarity histogram / boxplot) are currently disabled, so the
# top row of the 2x2 grid stays empty.
# ax1 = plt.subplot(2, 2, 1)
# for model_name in models:
#     ax1.hist(model_dfs[model_name]['similarity'], bins=30, alpha=0.5, label=model_name)
# ax1.set_xlabel('Similarity Score')
# ax1.set_ylabel('Frequency')
# ax1.set_title('Similarity Score Distribution')
# ax1.legend()
# ax1.grid(alpha=0.3)
#
# ax2 = plt.subplot(2, 2, 2)
# data_to_plot = [model_dfs[model_name]['similarity'] for model_name in models]
# ax2.boxplot(data_to_plot, labels=models)
# ax2.set_ylabel('Similarity Score')
# ax2.set_title('Similarity Score Comparison')
# ax2.set_xticklabels(models, rotation=45, ha='right')
# ax2.grid(alpha=0.3)

# Plot 3: mean similarity per model
ax3 = plt.subplot(2, 2, 3)
avg_sims = [model_dfs[model_name]['similarity'].mean() for model_name in models]
colors = plt.cm.viridis(np.linspace(0, 1, n_models))
ax3.bar(range(n_models), avg_sims, color=colors)
ax3.set_xticks(range(n_models))
ax3.set_xticklabels(models, rotation=45, ha='right')
ax3.set_ylabel('Average Similarity')
ax3.set_title('Average Similarity by Model')
ax3.grid(alpha=0.3, axis='y')

# Plot 4: number of matches per model
ax4 = plt.subplot(2, 2, 4)
match_counts = [len(model_dfs[model_name]) for model_name in models]
ax4.bar(range(n_models), match_counts, color=colors)
ax4.set_xticks(range(n_models))
ax4.set_xticklabels(models, rotation=45, ha='right')
ax4.set_ylabel('Number of Matches')
ax4.set_title('Total Matches by Model')
ax4.grid(alpha=0.3, axis='y')

fig.tight_layout()



In [None]:
# @title WORK IN PROGRESS - Create a network graph for the trimmed tree above

df_network = df_trimmed.copy()
df_mapped = filtered_df.copy()

# The mapping cells name their ontology columns differently; bring whichever
# variant produced `filtered_df` to one schema (absent keys are ignored).
df_mapped = df_mapped.rename(columns={
    'ontology_node_id': 'child_id',
    'ontology_label': 'ontology_term',
    'ontology_child_id': 'child_id',
    'ontology_child_label': 'ontology_term',
    'ontology_child_definition': 'ontology_definition',
    'ontology_id': 'child_id',
})
# An empty or partial result would otherwise break the merge / lookups below
for required_col in ("child_id", "ontology_term", "metric_name", "metric_definition"):
    if required_col not in df_mapped.columns:
        df_mapped[required_col] = pd.Series(dtype=object)

df_network = pd.merge(df_network, df_mapped, on='child_id', how='left')


def _txt(value):
    """Return `value` as text, mapping None/NaN to an empty string."""
    return "" if value is None or pd.isna(value) else str(value)


UNMAPPED_COLOR = "#D3D3D3"   # light gray
DEFAULT_COLOR = "#66ccff"    # nodes/edges that belong to no colored branch


#### Build directed graph ####
G = nx.DiGraph()

for _, row in df_network.iterrows():
    parent = row["parent_id"]
    child = row["child_id"]

    # Create both nodes up front so the attribute writes below never hit a
    # missing node (a self-edge row adds no edge).
    G.add_node(parent)
    G.add_node(child)
    if parent != child:
        G.add_edge(parent, child)

    # Node attributes (NaN definitions become empty text; `nan or ''` stays NaN)
    G.nodes[parent]['label'] = row['parent_label']
    G.nodes[parent]['title'] = f"{row['parent_label']} ({parent})\n{_txt(row['parent_definition'])}"

    G.nodes[child]['label'] = row['child_label']
    G.nodes[child]['title'] = f"{row['child_label']} ({child})\n{_txt(row['child_definition'])}"

    # Metrics mapped to this child
    if 'metrics' not in G.nodes[child]:
        G.nodes[child]['metrics'] = []
    if pd.notna(row.get("metric_name")):
        G.nodes[child]['metrics'].append(
            f"{row['metric_name']}: {_txt(row['metric_definition'])}"
        )

#### Remove self-loops ####
G.remove_edges_from(list(nx.selfloop_edges(G)))


#### Compute descendants for sizing ####
descendant_counts = {n: len(nx.descendants(G, n)) for n in G.nodes()}


#### Assign unique colors to branches (one per direct child of the root) ####
direct_children = list(G.successors(root_id)) if root_id in G else []
num_branches = len(direct_children)
# cm.get_cmap was removed from matplotlib; pyplot.get_cmap still takes the
# (name, lut) pair. max(..., 1) avoids requesting a zero-entry colormap.
cmap = plt.get_cmap('tab20', max(num_branches, 1))
branch_colors = {child: mcolors.to_hex(cmap(i)) for i, child in enumerate(direct_children)}

#### Propagate branch color to all nodes ####
node_colors = {}
for branch_root, color in branch_colors.items():
    nodes_in_branch = nx.descendants(G, branch_root)
    nodes_in_branch.add(branch_root)
    for n in nodes_in_branch:
        node_colors[n] = color


#### Gray out unmapped nodes and propagate upward from mapped nodes ####
color_by_map = True

if color_by_map:
    mapped_nodes = set(df_network.loc[df_network["ontology_term"].notna(), "child_id"])

    # Gray out all unmapped nodes
    for n in G.nodes():
        if n not in mapped_nodes:
            node_colors[n] = UNMAPPED_COLOR

    # Color upward from mapped nodes. A mapped node outside every branch (the
    # root itself, or an orphan) has no branch color, hence the .get fallback.
    for mapped in mapped_nodes:
        mapped_color = node_colors.get(mapped, DEFAULT_COLOR)
        ancestors = nx.ancestors(G, mapped)
        for a in ancestors:
            if node_colors.get(a, UNMAPPED_COLOR) == UNMAPPED_COLOR:
                node_colors[a] = mapped_color

# Highlight root
node_colors[root_id] = "#ff9999"


#### Compute node sizes ####
# Should add logic to compute sizes based on number of mapped metrics
node_sizes = {n: 10 + descendant_counts.get(n, 0) * 2 for n in G.nodes()}
# default covers a graph that has no node besides the root
max_descendant_size = max(
    (size for n, size in node_sizes.items() if n != root_id), default=10
)
node_sizes[root_id] = max_descendant_size + 3


#### Prepare edge colors ####
edge_colors = {}
for branch_root, color in branch_colors.items():
    if G.has_edge(root_id, branch_root):
        edge_colors[(root_id, branch_root)] = color
    for u, v in nx.edge_dfs(G, branch_root):
        edge_colors[(u, v)] = color

# Gray edges connecting to gray nodes
for u, v in G.edges():
    if (node_colors.get(u, DEFAULT_COLOR) == UNMAPPED_COLOR
            or node_colors.get(v, DEFAULT_COLOR) == UNMAPPED_COLOR):
        edge_colors[(u, v)] = "#C0C0C0"


#### Create PyVis network ####
net = Network(
    notebook=True,
    directed=True,
    height="800px",
    width="100%",
    cdn_resources='in_line'
)

# HTML clean function
def clean(x):
    """Return `x` as HTML-escaped text; `None` becomes an empty string."""
    return "" if x is None else html.escape(str(x))

# Add nodes with HTML tooltip
for node_id, attrs in G.nodes(data=True):
    # Tooltip starts with the ontology text stored on the node, if any
    tooltip = attrs.get('title', '')

    # Append the de-duplicated, sorted metric names when the node has any
    node_metrics = attrs.get("metrics")
    if node_metrics:
        tooltip = tooltip + "\n\nMapped Metrics:\n" + "\n".join(sorted(set(node_metrics)))

    net.add_node(
        node_id,
        label=attrs.get("label", node_id),
        title=tooltip,
        color=node_colors.get(node_id, "#66ccff"),
        size=node_sizes.get(node_id, 10),
    )

# Add edges
for src, dst in G.edges():
    net.add_edge(src, dst, color=edge_colors.get((src, dst), "#66ccff"))





#### Download ####
filename = "mgi_network.html"

# Step 1: Let PyVis write the raw network page to disk
net.save_graph(filename)

# Step 2: Load the generated page as text so the search bar + lookup JS can be injected
with open(filename, mode="r", encoding="utf-8") as html_file:
    html_data = html_file.read()

# Fixed lookup search bar + JS function
# HTML/JS snippet injected into the saved PyVis page. It adds a fixed search
# box that highlights nodes whose label or tooltip contains the query:
# red for matches with mapped metrics, yellow for other matches, gray otherwise.
# The original colours are backed up exactly once (either on the
# stabilization event or via a timed fallback, whichever comes first), so a
# late stabilization event can no longer overwrite the backup with highlight
# colours or attach a second input listener.
lookup_js = """
<!-- Search Bar -->
<div style="position: fixed; top: 10px; left: 10px;
     z-index: 9999; background: rgba(255,255,255,0.95);
     padding: 10px; border-radius: 8px; box-shadow: 0 2px 8px rgba(0,0,0,0.2);">
    <input id="nodeSearch" type="text"
           placeholder="Search Network..."
           style="padding:6px; width:280px; border-radius:6px; border: 1px solid #ccc; font-size: 14px;">
    <div id="matchCount" style="margin-top: 5px; font-size: 12px; color: #666;"></div>
</div>

<script type="text/javascript">
var originalNodeColors = {};
var searchNetwork = null;
var searchTimeout = null;
var isSearchReady = false;

// Initialize after page loads
window.addEventListener('load', function() {
    setTimeout(initSearch, 1000);
});

// Return a node's colour as a plain string (colour may be a string or an object)
function getNodeColor(node) {
    if (node && node.color) {
        if (typeof node.color === 'string') {
            return node.color;
        }
        if (node.color.background) {
            return node.color.background;
        }
    }
    return '#97C2FC';
}

// Back up the original colours and attach the search handler.
// Runs at most once: whichever of the stabilization event or the fallback
// timer arrives second is ignored, so the backup is never overwritten with
// highlight colours and the handler is never attached twice.
function enableSearch(net, label) {
    if (isSearchReady) {
        return;
    }

    try {
        var nodesDataset = net.body.data.nodes;
        var allIds = nodesDataset.getIds();

        allIds.forEach(function(id) {
            originalNodeColors[id] = getNodeColor(nodesDataset.get(id));
        });

        isSearchReady = true;
        console.log('Search ready' + label + ' with ' + allIds.length + ' nodes');

        // Attach search handler (debounced)
        document.getElementById('nodeSearch').addEventListener('input', function() {
            clearTimeout(searchTimeout);
            searchTimeout = setTimeout(lookupNode, 300);
        });
    } catch(e) {
        console.error('Error enabling search' + label + ':', e);
    }
}

function initSearch() {
    var net = null;

    // Find network object
    if (typeof network !== 'undefined') {
        net = network;
    } else {
        for (var key in window) {
            if (key.indexOf('network') === 0 && window[key] && window[key].body) {
                net = window[key];
                break;
            }
        }
    }

    // Network not built yet: retry shortly
    if (!net || !net.body || !net.body.data || !net.body.data.nodes) {
        setTimeout(initSearch, 500);
        return;
    }

    searchNetwork = net;

    // Preferred path: wait for stabilization before backing up colors
    net.once('stabilizationIterationsDone', function() {
        enableSearch(net, '');
    });

    // Fallback if stabilization never fires (or already fired)
    setTimeout(function() {
        enableSearch(net, ' (fallback)');
    }, 3000);
}

function lookupNode() {
    if (!isSearchReady || !searchNetwork) {
        return;
    }

    try {
        var query = document.getElementById('nodeSearch').value.toLowerCase().trim();
        var nodesDataset = searchNetwork.body.data.nodes;
        var allIds = nodesDataset.getIds();
        var matchedIds = [];
        var updates = [];

        // Clear search - restore all colors
        if (query === '') {
            allIds.forEach(function(id) {
                updates.push({
                    id: id,
                    color: originalNodeColors[id]
                });
            });
            nodesDataset.update(updates);
            document.getElementById('matchCount').textContent = '';
            searchNetwork.unselectAll();
            return;
        }

        // Single pass: classify each node and record its highlight colour
        allIds.forEach(function(id) {
            var color = '#DDDDDD';  // Gray for non-matches
            try {
                var node = nodesDataset.get(id);
                // String(): labels/titles are not guaranteed to be strings
                var label = String(node.label || '').toLowerCase();
                var title = String(node.title || '').toLowerCase();

                // Check if query matches anywhere in label or title
                if (label.indexOf(query) !== -1 || title.indexOf(query) !== -1) {
                    matchedIds.push(id);

                    if (title.indexOf('mapped metrics:') !== -1) {
                        color = '#FF0000';  // Red for nodes with mapped metrics
                    } else {
                        color = '#FFFF00';  // Yellow for nodes without mapped metrics
                    }
                }
            } catch(e) {
                // Nodes that cause errors are treated as non-matches
            }
            updates.push({
                id: id,
                color: color
            });
        });
        nodesDataset.update(updates);

        // Update match count
        var countText = matchedIds.length === 0 ? 'No matches found' :
                        matchedIds.length === 1 ? '1 node found' :
                        matchedIds.length + ' nodes found';
        document.getElementById('matchCount').textContent = countText;

        // Select matched nodes
        if (matchedIds.length > 0) {
            searchNetwork.selectNodes(matchedIds);
        } else {
            searchNetwork.unselectAll();
        }

    } catch(e) {
        console.error('Search error:', e);
    }
}
</script>
"""

# Step 3: Insert the search bar before the closing </body> tag
# Only the LAST "</body>" is targeted: str.replace would inject the snippet at
# every literal occurrence of that text in the page (e.g. inside an embedded
# script), and would silently inject nothing if the tag were missing.
html_head, body_close, html_tail = html_data.rpartition("</body>")
if body_close:
    html_data = html_head + lookup_js + "\n" + body_close + html_tail
else:
    # No closing tag found: append the snippet rather than dropping it
    html_data = html_data + lookup_js

# Step 4: Write out modified HTML
with open(filename, "w", encoding="utf-8") as f:
    f.write(html_data)

# Step 5: Download final file
files.download(filename)