<a href="https://colab.research.google.com/github/KravitzLab/PsygeneAnalyses/blob/PCA_analysis/mgi_heiarchy_create.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#!pip install owlready2
#!pip install --upgrade jinja2 pyvis
#!pip install pygraphviz

In [27]:
# @title Install and Import libraries

#!pip install pronto
#!pip install pyvis
#!pip install sentence-transformers

import os, re, zipfile
import pandas as pd
from bs4 import BeautifulSoup
from google.colab import files
import csv
import pronto
import pyvis
import ipywidgets as widgets
import networkx as nx
from pyvis.network import Network
from google.colab import files
import networkx as nx
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import csv
from sentence_transformers import SentenceTransformer, util

In [None]:
# @title Upload Ontology File


#### Upload ontology file ####
uploaded = files.upload()
if not uploaded:
    # Cancelling the upload dialog yields an empty mapping; fail with a clear message
    raise ValueError("No ontology file was uploaded.")
filename = list(uploaded.keys())[0]

#### Load ontology ####
mp = pronto.Ontology(filename)


def _direct_parents(term):
    """Return the immediate superclasses of `term`, excluding `term` itself.

    pronto's `superclasses()` is reflexive by default (the term is yielded along
    with its parents), so the term is filtered out by id here.
    """
    return [p for p in term.superclasses(distance=1) if p.id != term.id]


# Identify roots: terms that have no parent other than themselves.
# (Should only have 1 root for MGI; any other parentless term in the file
# would also be listed here.)
roots = [t for t in mp.terms() if not _direct_parents(t)]
_root_ids = {t.id for t in roots}

# Depth per term id, filled lazily by get_depth(); reset whenever this cell re-runs
_depth_cache = {}


# Depth function to calculate the minimum distance to the root
def get_depth(term):
    """Return the minimum number of is_a steps from `term` up to any root.

    The search walks direct-parent links breadth-first, so the first level at
    which a root is reached is the shortest distance. Roots themselves have
    depth 0. If no root is reachable, 0 is returned as a fallback.

    Parameters
    ----------
    term : pronto term
        A term belonging to the loaded ontology `mp`.

    Returns
    -------
    int
        Shortest distance to a root.
    """
    cached = _depth_cache.get(term.id)
    if cached is not None:
        return cached

    found_root = term.id in _root_ids
    depth = 0
    frontier = [term]
    visited = {term.id}  # guards against revisiting shared ancestors

    while frontier and not found_root:
        depth += 1
        next_frontier = []
        for node in frontier:
            for parent in _direct_parents(node):
                if parent.id in visited:
                    continue
                visited.add(parent.id)
                if parent.id in _root_ids:
                    found_root = True
                next_frontier.append(parent)
        frontier = next_frontier

    result = depth if found_root else 0
    _depth_cache[term.id] = result
    return result

# Count ancestors and descendants of a term
def count_ancestors(term):
    """Return how many superclasses `term` has, not counting `term` itself."""
    # superclasses() also yields the term, so one entry is discounted
    total = sum(1 for _ in term.superclasses())
    return total - 1

def count_descendants(term):
    """Return how many subclasses `term` has, not counting `term` itself."""
    # subclasses() also yields the term, so one entry is discounted
    total = sum(1 for _ in term.subclasses())
    return total - 1

def is_leaf(term):
    """Return True when `term` has no direct subclass other than itself.

    `subclasses(distance=1)` yields the term itself in addition to its direct
    children, so a plain emptiness check never succeeds. The term is therefore
    ignored (matched by id) when looking for children.
    """
    return all(child.id == term.id for child in term.subclasses(distance=1))


#### Output CSV ####
# One row per parent -> child edge, for every term in the entire ontology
output_file = "ontology_edges_with_metadata.csv"

with open(output_file, "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow([
        "parent_id", "parent_label", "parent_definition", "parent_depth",
        "child_id", "child_label", "child_definition", "child_depth",
        "child_is_leaf", "num_ancestors_child", "num_descendants_child"
    ])

    for term in mp.terms():
        # superclasses(distance=1) also yields the term itself; drop it so that
        # no term -> same term self-edge rows are written
        parents = [p for p in term.superclasses(distance=1) if p.id != term.id]
        if not parents:
            continue

        # Child-side columns are identical for every parent: compute them once
        child_fields = [
            term.id,
            term.name,
            term.definition or "",
            get_depth(term),
            "yes" if is_leaf(term) else "no",
            count_ancestors(term),
            count_descendants(term),
        ]

        for parent in parents:
            writer.writerow([
                parent.id,
                parent.name,
                parent.definition or "",
                get_depth(parent),
                *child_fields,
            ])


### Download the parent-child ontologies ###
try:
    from google.colab import files as gfiles
except Exception:
    # Not running in Colab: the button will only report where the file was saved
    gfiles = None

btn = widgets.Button(description=f"Download {os.path.basename(output_file)}", icon="download")
status = widgets.HTML()

def _dl(_, path=output_file, status_box=status, downloader=gfiles):
    """Button callback: download `path` in Colab, or report its local location.

    `path`, `status_box` and `downloader` are bound as defaults when this cell
    runs. Later cells reassign the globals `output_file`, `status` and `gfiles`;
    without the binding, this button would act on whichever file was written last.
    """
    if downloader is not None:
        status_box.value = f"Starting download: <code>{os.path.basename(path)}</code>…"
        downloader.download(path)
    else:
        status_box.value = f"Saved locally at <code>{path}</code>."

display(btn, status)
btn.on_click(_dl)





In [None]:
# @title Trim the ontology
# Choose the root of the sub-tree to keep and the maximum number of ancestors
# a child term may have.

# Edge table for the whole ontology
df = pd.read_csv("ontology_edges_with_metadata.csv")

# Root of the sub-tree:
#   mammalian phenotype: MP:0000001
#   abnormal behavior:   MP:0004924
#root_id = 'MP:0000001'
root_id = 'MP:0004924'
trim_root = mp[root_id]

# Ancestor cut-off. The largest ancestor count in the table is printed for
# reference; 36 was noted as the maximum, so 36 keeps every term.
print(df['num_ancestors_child'].max())
anc_level = 36

# Ids of the sub-tree root and of all of its subclasses (all depths)
trimmed_descendants = {trim_root.id} | {t.id for t in trim_root.subclasses()}

# An edge is kept only when BOTH of its endpoints lie inside the sub-tree
in_subtree = (
    df["parent_id"].isin(trimmed_descendants)
    & df["child_id"].isin(trimmed_descendants)
)
df_trimmed = df.loc[in_subtree]
print(df_trimmed.columns)

df_trimmed = (
    df_trimmed
    # apply the ancestor cut-off
    .loc[lambda d: d["num_ancestors_child"] <= anc_level]
    # drop self-referential edges (same id, then same label)
    .loc[lambda d: d["parent_id"] != d["child_id"]]
    .loc[lambda d: d["parent_label"] != d["child_label"]]
    # one row per distinct parent/child pair
    .drop_duplicates(subset=["parent_id", "child_id"])
)


#### Download the trimmed behavior parent-child ontologies ####
output_file = "ontology_edges_trimmed.csv"
df_trimmed.to_csv(output_file, index=False)

try:
    from google.colab import files as gfiles
except Exception:
    # Not running in Colab: the button will only report where the file was saved
    gfiles = None

btn = widgets.Button(description=f"Download {os.path.basename(output_file)}", icon="download")
status = widgets.HTML()

def _dl(_, path=output_file, status_box=status, downloader=gfiles):
    """Button callback: download `path` in Colab, or report its local location.

    `path`, `status_box` and `downloader` are bound as defaults when this cell
    runs. Other cells reassign the globals `output_file`, `status` and `gfiles`;
    without the binding, this button would act on whichever file was written last.
    """
    if downloader is not None:
        status_box.value = f"Starting download: <code>{os.path.basename(path)}</code>…"
        downloader.download(path)
    else:
        status_box.value = f"Saved locally at <code>{path}</code>."

display(btn, status)
btn.on_click(_dl)

In [None]:
# @title Retrieve all the leaf nodes
# Leaf nodes of the trimmed tree: ids that occur as a child but never as a parent
df_leafs = df_trimmed.copy()

all_parents = set(df_leafs["parent_id"])
all_children = set(df_leafs["child_id"])
leaf_nodes = all_children.difference(all_parents)

# Id, label and definition of every leaf, one row per distinct combination
leaf_row_mask = df_leafs["child_id"].isin(leaf_nodes)
leaf_labels = df_leafs.loc[
    leaf_row_mask, ["child_id", "child_label", "child_definition"]
].drop_duplicates()

### Download leaf labels ###
output_file = "leaf_nodes.csv"
leaf_labels.to_csv(output_file, index=False)

try:
    from google.colab import files as gfiles
except Exception:
    # Not running in Colab: the button will only report where the file was saved
    gfiles = None

btn = widgets.Button(description=f"Download {os.path.basename(output_file)}", icon="download")
status = widgets.HTML()

def _dl(_, path=output_file, status_box=status, downloader=gfiles):
    """Button callback: download `path` in Colab, or report its local location.

    `path`, `status_box` and `downloader` are bound as defaults when this cell
    runs. Other cells reassign the globals `output_file`, `status` and `gfiles`;
    without the binding, this button would act on whichever file was written last.
    """
    if downloader is not None:
        status_box.value = f"Starting download: <code>{os.path.basename(path)}</code>…"
        downloader.download(path)
    else:
        status_box.value = f"Saved locally at <code>{path}</code>."

display(btn, status)
btn.on_click(_dl)

In [None]:
# @title Create a network graph for the trimmed tree above
df_network = df_trimmed.copy()


def _node_title(label, node_id, definition):
    """Build the hover text for a node: label, id and, when present, the definition.

    Definitions that were empty in the CSV come back from pandas as NaN. NaN is
    truthy, so `definition or ''` would render the literal text "nan"; it is
    replaced by an empty string instead.
    """
    definition_text = "" if pd.isna(definition) else definition
    return f"{label} ({node_id})\n{definition_text}"


#### Build directed graph
G = nx.DiGraph()

for _, row in df_network.iterrows():
    parent = row["parent_id"]
    child = row["child_id"]

    # Register both endpoints explicitly so their attributes exist even when
    # the edge itself is skipped below
    G.add_node(
        parent,
        label=row["parent_label"],
        title=_node_title(row["parent_label"], parent, row["parent_definition"]),
    )
    G.add_node(
        child,
        label=row["child_label"],
        title=_node_title(row["child_label"], child, row["child_definition"]),
    )

    # Self-loops are never added, so no clean-up pass is needed afterwards
    if parent != child:
        G.add_edge(parent, child)

#### Compute descendants for node size
descendant_counts = {n: len(nx.descendants(G, n)) for n in G.nodes()}

#### Assign unique colors to branches
# A "branch" is the sub-tree below one direct child of the root.
direct_children = list(G.successors(root_id))
num_branches = len(direct_children)

# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap takes
# the same (name, lut) arguments and is still available.
from matplotlib import pyplot as plt
cmap = plt.get_cmap('tab20', max(num_branches, 1))  # at least 1 entry so an empty root does not break
branch_colors = {child: mcolors.to_hex(cmap(i)) for i, child in enumerate(direct_children)}

#### Propagate branch color to all nodes
# A node reachable from several branches ends up with the color of the last one.
node_colors = {}
for branch_root, color in branch_colors.items():
    nodes_in_branch = nx.descendants(G, branch_root)
    nodes_in_branch.add(branch_root)
    for n in nodes_in_branch:
        node_colors[n] = color

#### Root color
node_colors[root_id] = "#ff9999"  # root highlighted

#### Compute node sizes (dynamic root sizing)
node_sizes = {n: 10 + descendant_counts.get(n, 0) * 2 for n in G.nodes()}
# default=10 (the base node size) covers a graph that holds nothing but the root
max_descendant_size = max(
    (size for n, size in node_sizes.items() if n != root_id),
    default=10,
)
node_sizes[root_id] = max_descendant_size + 3  # root slightly bigger

#### Prepare edge colors (propagate branch color up to root)
edge_colors = {}

for branch_root, color in branch_colors.items():
    # Include the edge from root -> branch_root
    if G.has_edge(root_id, branch_root):
        edge_colors[(root_id, branch_root)] = color
    # All other edges in the branch
    for u, v in nx.edge_dfs(G, branch_root):
        edge_colors[(u, v)] = color


#### Create PyVis network
net = Network(
    notebook=True,
    directed=True,
    height="800px",
    width="100%",
    cdn_resources='in_line'
)

# Add nodes.
# The ontology id is used as the PyVis node id and the label is display text
# only: ids are unique strings, whereas two terms sharing a label would be
# merged into one node and a missing (NaN) label is not a valid node id.
for node, data in G.nodes(data=True):
    net.add_node(
        node,
        label=data['label'],
        title=data['title'],
        color=node_colors.get(node, "#66ccff"),
        size=node_sizes.get(node, 10)
    )

# Add edges with propagated branch colors
for u, v in G.edges():
    net.add_edge(u, v, color=edge_colors.get((u, v), "#66ccff"))

#### Download
filename = "mgi_network.html"
net.show(filename)
files.download(filename)


In [None]:
# @title Map Metrics onto the Ontology

### Read in the metric descriptions ###
# The metrics list is uploaded by the user as a CSV file.
uploaded = files.upload()
filename = [*uploaded][0]
df_metrics = pd.read_csv(filename, encoding="latin1")


# Sentence-embedding model, loaded by name
# (alternative: SentenceTransformer('all-MiniLM-L6-v2'))
model = SentenceTransformer('all-mpnet-base-v2')


# Quick sanity check of the similarity scores on three hand-written sentences.
# t0: dummy sentence that reuses common key words
t0 = "Vecna is a Lich in DND, he has many abilities and can cast modify memory, wants to learn secrets of eveyone, and his motivation is beyond understanding."
# t1: MGI description of impaired short-term object recognition memory
t1 = "impaired ability of short-term memory to recognize objects during the first few minutes after training"
# t2: Win-Stay definition
t2 = "The ability to measure an organisms learning to stay at the device after a winning trial measuring cognition and motviation."

# Embed the three sentences, in order
emb0, emb1, emb2 = (
    model.encode(sentence, convert_to_tensor=True) for sentence in (t0, t1, t2)
)

# Cosine similarity of each sentence against the Win-Stay definition
score0 = util.cos_sim(emb0, emb2).item()
print("Dummy vs winstay Similarity:", score0)
score1 = util.cos_sim(emb1, emb2).item()
print("Short term memory vs winstay Similarity:", score1)


#### Do for a matrix ####
def _name_and_definition(names, definitions):
    """Join two text columns as "<name>. <definition>", one string per row.

    Cells that are empty in a CSV are read back by pandas as NaN; concatenating
    NaN yields NaN (a float) rather than text, which must not be handed to the
    encoder. Missing values are therefore replaced by empty strings first.
    """
    return (
        names.fillna("").astype(str) + ". " + definitions.fillna("").astype(str)
    ).tolist()


metric_texts = _name_and_definition(
    df_metrics["metric_name"], df_metrics["metric_definition"]
)

onto_texts = _name_and_definition(
    leaf_labels["child_label"], leaf_labels["child_definition"]
)

metric_emb = model.encode(metric_texts, convert_to_tensor=True)
onto_emb = model.encode(onto_texts, convert_to_tensor=True)


# compute the similarity matrix (rows: metrics, columns: ontology leaf terms)
sim_matrix = util.cos_sim(metric_emb, onto_emb)


# Extract the top matches
# Define similarity threshold score
threshold = 0.4
top_k = 5

# topk() raises when asked for more entries than there are ontology terms
k = min(top_k, sim_matrix.shape[1])

match_columns = [
    "metric_id", "metric_name", "metric_definition",
    "ontology_id", "ontology_term", "ontology_definition",
    "similarity",
]
filtered_matches = []

# sim_matrix rows follow the row ORDER of df_metrics, so the position is used
# to index it rather than the DataFrame index label
for pos, (_, mrow) in enumerate(df_metrics.iterrows()):
    scores = sim_matrix[pos]

    # top-k candidate ontology terms for this metric
    best = scores.topk(k)

    for score_val, idx in zip(best.values.tolist(), best.indices.tolist()):
        if score_val < threshold:
            continue
        onto_row = leaf_labels.iloc[idx]
        filtered_matches.append({
            "metric_id": mrow["metric_id"],
            "metric_name": mrow["metric_name"],
            "metric_definition": mrow["metric_definition"],
            "ontology_id": onto_row["child_id"],
            "ontology_term": onto_row["child_label"],
            "ontology_definition": onto_row["child_definition"],
            "similarity": score_val
        })

# Explicit columns keep the header present even when nothing passes the threshold
filtered_df = pd.DataFrame(filtered_matches, columns=match_columns)
# A bare expression in the middle of a cell is not rendered; display it explicitly
display(filtered_df.head())



### Download mapped metrics ###
output_file = "onto_metrics_mapped.csv"
filtered_df.to_csv(output_file, index=False)

try:
    from google.colab import files as gfiles
except Exception:
    # Not running in Colab: the button will only report where the file was saved
    gfiles = None

btn = widgets.Button(description=f"Download {os.path.basename(output_file)}", icon="download")
status = widgets.HTML()

def _dl(_, path=output_file, status_box=status, downloader=gfiles):
    """Button callback: download `path` in Colab, or report its local location.

    `path`, `status_box` and `downloader` are bound as defaults when this cell
    runs. Other cells reassign the globals `output_file`, `status` and `gfiles`;
    without the binding, this button would act on whichever file was written last.
    """
    if downloader is not None:
        status_box.value = f"Starting download: <code>{os.path.basename(path)}</code>…"
        downloader.download(path)
    else:
        status_box.value = f"Saved locally at <code>{path}</code>."

display(btn, status)
btn.on_click(_dl)
