<a href="https://colab.research.google.com/github/alisonmitchell/Biomedical-Knowledge-Graph/blob/main/06_Knowledge_Graphs/REBEL_KG.ipynb"
   target="_parent">
   <img src="https://colab.research.google.com/assets/colab-badge.svg"
      alt="Open in Colab">
</a>


# REBEL Knowledge Graph

## 1. Introduction

REBEL extracted 7149 triples covering 82 relation types after end-to-end NER and relation extraction. A subset of the 12 most frequent relation types was selected, reducing the dataset to 5774 triples, which will be visualised as a knowledge graph.

## 2. Install/import libraries

In [None]:
!pip install pyvis

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import networkx as nx
import pickle

from IPython.display import display, HTML
from pyvis.network import Network

## 3. Load data

In [None]:
# load REBEL triples
with open('2024-11-02_rebel_triples_df_updated_7149.pickle', "rb") as f:
    rebel_triples_df_updated = pickle.load(f)

In [None]:
rebel_triples_df_updated.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7149 entries, 0 to 7148
Data columns (total 7 columns):
 #   Column     Non-Null Count  Dtype 
---  ------     --------------  ----- 
 0   head       7149 non-null   object
 1   relation   7149 non-null   object
 2   tail       7149 non-null   object
 3   head_type  7149 non-null   object
 4   head_id    7149 non-null   object
 5   tail_type  7149 non-null   object
 6   tail_id    7149 non-null   object
dtypes: object(7)
memory usage: 391.1+ KB


## 4. Filter relations

We will select a subset of the 12 most frequent of the 82 relation types REBEL identified and visualise the triples as a knowledge graph.
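
One way to arrive at such a shortlist is to count how often each relation occurs and keep the top `n`. The sketch below uses a small made-up DataFrame in place of `rebel_triples_df_updated`; the helper name `top_relations` and the toy rows are illustrative assumptions, not the notebook's own code.

```python
import pandas as pd

def top_relations(df, n):
    """Return the n most frequent values of the 'relation' column, most frequent first."""
    return df["relation"].value_counts().head(n).index.tolist()

# Toy stand-in for rebel_triples_df_updated (made-up rows, for illustration only)
toy_df = pd.DataFrame({
    "head": ["a", "b", "c", "d", "e", "f"],
    "relation": ["subclass of", "subclass of", "subclass of",
                 "part of", "part of", "uses"],
    "tail": ["x", "y", "z", "x", "y", "z"],
})

# Keep only the triples whose relation is among the 2 most frequent
selected = top_relations(toy_df, n=2)
toy_filtered = toy_df[toy_df["relation"].isin(selected)].reset_index(drop=True)

print(selected)
print(len(toy_filtered))
```

A hand-written list, as used in the next cell, makes the selection explicit and also allows near-duplicate relation labels to be included deliberately.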

In [None]:
# List of relations to filter
relations_to_keep = [
    'subclass of',
    'instance of',
    'part of',
    'has part',
    'facet of',
    'has effect',
    'subject has role',
    'has cause',
    'use',
    'drug used for treatment',
    'medical condition treated',
    'drug',
    'pathogen',
    'drug used for',
    'uses'
]

In [None]:
# Filter DataFrame for rows with specified relations
filtered_df = rebel_triples_df_updated[rebel_triples_df_updated['relation'].isin(relations_to_keep)]
filtered_df = filtered_df.reset_index(drop=True)

In [None]:
len(filtered_df)

5774

In [None]:
filtered_df

| | head | relation | tail | head_type | head_id | tail_type | tail_id |
|---|---|---|---|---|---|---|---|
| 0 | consensus among all approaches | instance of | consensus | other | consensus among all approaches | other | consensus |
| 1 | hydrogen | subclass of | atoms | drug | chembl:CHEMBL4297766 | other | atoms |
| 2 | distal airway | part of | lungs | other | distal airway | other | lungs |
| 3 | sepsis dataset | has part | b | disease | HP:0100806 | disease | MONDO:0005737 |
| 4 | mapk | has part | p38 mapk | gene | hgnc.genegroup:651 | gene | ensembl:ENSG00000112062 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 5769 | meta-analysis | uses | mortality rate | measurement | STATO:0000155 | measurement | STATO:0000414 |
| 5770 | ards | has effect | sepsis | disease | MONDO:0006502 | disease | HP:0100806 |
| 5771 | abiraterone | medical condition treated | covid-19 | drug | chembl:CHEMBL254328 | disease | covid |
| 5772 | genetic risk genes | subclass of | gene | other | genetic risk genes | other | gene |


In [None]:
with open('2024-11-15_filtered_df_5774.pickle', "wb") as f:
    pickle.dump(filtered_df, f)

In [None]:
filtered_df.to_csv('2024-11-15_rebel_filtered_triples_5774.csv', index=False)

In [None]:
# Filter rows where either 'head_type' or 'tail_type' is 'other'
filtered_df_other = filtered_df.loc[(filtered_df['head_type'] == 'other') | (filtered_df['tail_type'] == 'other')]

# Filter rows where both 'head_type' and 'tail_type' are not 'other'
filtered_df_not_other = filtered_df.loc[(filtered_df['head_type'] != 'other') & (filtered_df['tail_type'] != 'other')]

In [None]:
len(filtered_df_other)

4314

In [None]:
len(filtered_df_not_other)

1460

Most of the filtered triples (4314 of 5774, roughly 75%) contain at least one entity of type 'other'. This label was assigned during a preprocessing step that fuzzy-matched REBEL entities against KAZU entities.

For each entity, the matching step returned a tuple: the first value was the entity type (or 'other' if no match was found), and the second value was the entity ID (or the original entity text if it was not found in `entity_id`).
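
The tuple-returning lookup described above could look roughly like the sketch below. It is not the preprocessing code that was actually run: it assumes a plain dictionary of KAZU entity names (here called `kazu_lookup`, a hypothetical name) and uses the standard library's `difflib` for the fuzzy match, with an arbitrary similarity cutoff.

```python
from difflib import get_close_matches

# Hypothetical lookup: KAZU entity name -> (entity type, entity ID)
kazu_lookup = {
    "sepsis": ("disease", "HP:0100806"),
    "abiraterone": ("drug", "chembl:CHEMBL254328"),
}

def match_entity(entity, lookup, cutoff=0.85):
    """Return (entity_type, entity_id), or ('other', entity) if nothing is close enough."""
    candidates = get_close_matches(entity.lower(), list(lookup), n=1, cutoff=cutoff)
    if candidates:
        return lookup[candidates[0]]
    # No sufficiently similar KAZU entity: keep the original text as its own ID
    return ("other", entity)

print(match_entity("Sepsis", kazu_lookup))
print(match_entity("distal airway", kazu_lookup))
```

Using the original text as the fallback ID is what makes `head == head_id` a usable test for "no ID found" later in the plotting code.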




## 5. Plot Knowledge Graph

In [None]:
entity_types = ["disease", "drug", "gene", "measurement", "cell_type", "species", "anatomy", "go_cc", "go_mf", "go_bp", "cell_line", "covid lineage", "other"]

In [None]:
# Define colour codes for entity types
color_codes = {
    "disease": "lightgreen",
    "drug": "orange",
    "gene": "pink",
    "measurement": "lightblue",
    "cell_type": "orchid",
    "species": "aquamarine",
    "anatomy": "burlywood",
    "go_cc": "silver",
    "go_mf": "salmon",
    "go_bp": "darkseagreen",
    "cell_line": "gold",
    "covid lineage": "yellow",
    "other": "teal"
}
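
Because the plotting function indexes `color_codes` directly by entity type, a type missing from the dictionary would raise a `KeyError`. A quick check along the following lines can catch that before plotting; the helper name `missing_colours` and the sample inputs are illustrative assumptions.

```python
def missing_colours(observed_types, colour_map):
    """Return the sorted entity types that have no colour assigned."""
    return sorted(set(observed_types) - set(colour_map))

# Abbreviated copy of the colour map, for illustration only
sample_colours = {"disease": "lightgreen", "drug": "orange", "other": "teal"}

print(missing_colours(["drug", "disease", "other", "drug"], sample_colours))
print(missing_colours(["drug", "gene"], sample_colours))
```

In the notebook, the observed types could be gathered with something like `pd.concat([filtered_df['head_type'], filtered_df['tail_type']]).unique()` and checked against `color_codes`.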

We will use the [`pyvis`](https://github.com/WestHealth/pyvis) library, which is designed for quickly generating visual network graphs with minimal Python code.

In [None]:
def create_pyvis_graph(df, output_file):
    # Initialise a pyvis network object
    graph = Network(notebook=True,
                    cdn_resources="in_line",
                    directed=True,
                    select_menu=True,
                    filter_menu=True)

    # Create dictionaries to store node data with node IDs as keys
    node_titles = {}
    node_colors = {}  # Dictionary to store node colours based on entity type

    # Add nodes and edges to the network
    for index, row in df.iterrows():
        head = row['head']
        tail = row['tail']
        relation = row['relation']
        head_type = row['head_type']
        tail_type = row['tail_type']
        head_id = row['head_id']
        tail_id = row['tail_id']

        # Add nodes
        graph.add_node(head, title=head, label=head)
        graph.add_node(tail, title=tail, label=tail)

        # Add edges
        graph.add_edge(head, tail, title=relation, label=relation, color="gray")

        # Set tooltips and colours for head node
        if head_id == 'covid':  # Special case for 'covid'
            head_tooltip = f"Entity: {head}<br>Entity type: {head_type}<br>ID: covid"
        elif head != head_id:  # Head entity has an ID
            head_tooltip = f"Entity: {head}<br>Entity type: {head_type}<br>ID: <a href='https://bioregistry.io/{head_id}' target='_blank'>{head_id}</a>"
        else:  # Head entity does not have an ID
            head_tooltip = f"Entity: {head}<br>Entity type: {head_type}"

        node_titles[head] = head_tooltip
        node_colors[head] = color_codes[head_type]

        # Set tooltips and colours for tail node
        if tail_id == 'covid':  # Special case for 'covid'
            tail_tooltip = f"Entity: {tail}<br>Entity type: {tail_type}<br>ID: covid"
        elif tail != tail_id:  # Tail entity has an ID
            tail_tooltip = f"Entity: {tail}<br>Entity type: {tail_type}<br>ID: <a href='https://bioregistry.io/{tail_id}' target='_blank'>{tail_id}</a>"
        else:  # Tail entity does not have an ID
            tail_tooltip = f"Entity: {tail}<br>Entity type: {tail_type}"

        node_titles[tail] = tail_tooltip
        node_colors[tail] = color_codes[tail_type]

    # Set the calculated tooltips and colours to the graph nodes
    for node_data in graph.nodes:
        node_id = node_data["id"]
        node_data["title"] = node_titles.get(node_id, node_id)
        node_data["color"] = node_colors[node_id]  # Use the colour from the dictionary

    # Configure the graph's repulsion settings to control node spacing and spring dynamics
    graph.repulsion(
        node_distance=300,
        central_gravity=0.2,
        spring_length=300,
        spring_strength=0.05,
        damping=0.09
    )
    # Set the edges to have a dynamic smoothness
    graph.set_edge_smooth('dynamic')

    # Print the number of nodes and edges
    print(f"Graph with {len(graph.nodes)} nodes and {len(graph.edges)} edges.")

    # Generate the graph and save to HTML file
    graph.save_graph(output_file)

    # Display the saved HTML file inline
    display(HTML(filename=output_file))
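
The head and tail branches in the function apply the same three-way rule, so the tooltip logic could be factored into one helper. The following is a sketch of that idea rather than part of the notebook; `build_tooltip` is a hypothetical name.

```python
def build_tooltip(entity, entity_type, entity_id):
    """Build the HTML tooltip for one node, mirroring the three cases in the function."""
    base = f"Entity: {entity}<br>Entity type: {entity_type}"
    if entity_id == "covid":
        # Special case: the ID is shown as plain text, without a link
        return f"{base}<br>ID: covid"
    if entity != entity_id:
        # The entity has an ID: link it through Bioregistry
        return (f"{base}<br>ID: <a href='https://bioregistry.io/{entity_id}' "
                f"target='_blank'>{entity_id}</a>")
    # No ID was found, so the entity text doubles as its ID
    return base

print(build_tooltip("sepsis", "disease", "HP:0100806"))
print(build_tooltip("lungs", "other", "lungs"))
```

With such a helper, the loop body would reduce to one call for the head and one for the tail.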

The graph has 2532 nodes and 5774 edges, so it can take a while to load.
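
Before rendering a graph of this size, it can help to check how many nodes and edges a candidate subset would produce. The sketch below counts distinct head and tail values as nodes and rows as edges, which matches how the plotting function builds the graph; `graph_size` is a hypothetical helper, and the toy rows are a small sample in the shape of `filtered_df`.

```python
import pandas as pd

def graph_size(df):
    """Return (number of distinct nodes, number of edges) for a triples DataFrame."""
    nodes = pd.unique(df[["head", "tail"]].to_numpy().ravel())
    return len(nodes), len(df)

# Small sample in the shape of filtered_df, for illustration only
toy = pd.DataFrame({
    "head": ["ards", "mapk", "distal airway"],
    "relation": ["has effect", "has part", "part of"],
    "tail": ["sepsis", "p38 mapk", "lungs"],
    "head_type": ["disease", "gene", "other"],
    "tail_type": ["disease", "gene", "other"],
})

# Same condition as filtered_df_not_other: drop triples touching an 'other' entity
typed_only = toy[(toy["head_type"] != "other") & (toy["tail_type"] != "other")]

print(graph_size(toy))
print(graph_size(typed_only))
```

In the notebook, passing `filtered_df_not_other` (1460 triples) to `create_pyvis_graph` instead of `filtered_df` would be one way to obtain a lighter graph.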

In [None]:
create_pyvis_graph(filtered_df, 'rebel_12_relations.html')

Graph with 2532 nodes and 5774 edges.
