In [1]:
import os

import matplotlib
matplotlib.use('Agg')

import matplotlib.pyplot as plt
from neo4j import GraphDatabase
import networkx as nx
import pandas as pd
from tqdm import tqdm

from tools import Files

In [3]:
# Shared file helper and project directories.
files = Files()

# All data directories are resolved relative to the kernel's working directory.
# The earlier expression `os.path.dirname(os.path.abspath(__name__))` produced
# this same directory only incidentally: `__name__` is a module name, not a
# path, so `abspath` joined it onto the working directory and `dirname`
# stripped it off again.
root = os.getcwd()
raw_path = os.path.join(root, "data", "raw")
processed_path = os.path.join(root, "data", "processed")
graph_path = os.path.join(root, "data", "graph")

# Only the output directories are created here; raw_path is not.
os.makedirs(processed_path, exist_ok=True)
os.makedirs(graph_path, exist_ok=True)

In [7]:
# Prepare Database and Drug Interactions from The Raw Drug Interactions Data
database_file = os.path.join(processed_path, "database.csv")
interactions_file = os.path.join(processed_path, "interactions.json")
drugs_list_file = os.path.join(processed_path, "drugs_list.json")

# The cache is used only when all three artefacts exist; a partially written
# cache (e.g. database.csv without the JSON files) triggers a full rebuild
# instead of failing on the missing file.
if all(os.path.exists(path) for path in (database_file, interactions_file, drugs_list_file)):
    database = files.read_df(database_file)
    interactions = files.read_json(interactions_file)
    drugs_list = files.read_json(drugs_list_file)
else:
    # Read Raw Drug Interactions Data.
    # os.listdir returns entries in arbitrary order, so they are sorted to make
    # the merged row order reproducible. Hidden entries (e.g. .gitkeep) and
    # sub-directories are skipped because they are not raw data files.
    raw_frames = [
        files.read_df(os.path.join(raw_path, file))
        for file in sorted(os.listdir(raw_path))
        if not file.startswith(".") and os.path.isfile(os.path.join(raw_path, file))
    ]

    # Merge DataFrames, drop the identifier columns, normalise column names
    # and remove duplicate rows.
    database = (
        pd.concat(raw_frames, ignore_index=True)
        .drop(columns=["DDInterID_A", "DDInterID_B"])
        .rename(columns={"Drug_A": "A", "Drug_B": "B", "Level": "Severity"})
        .drop_duplicates()
    )

    # Extract Drug Interactions (one list of cell values per row) and Drugs List.
    interactions = database.to_dict(orient="split")['data']
    # Unique names across both columns, in order of first appearance, so the
    # saved list is identical from run to run (a plain set has no stable order).
    drugs_list = pd.concat([database["A"], database["B"]], ignore_index=True).unique().tolist()

    # Save Database, Interactions and Drugs List
    database.to_csv(database_file, index=False)
    files.write_json(interactions_file, interactions)
    files.write_json(drugs_list_file, drugs_list)

print(f"Database Prepared - {len(interactions)} interactions - {len(drugs_list)} drugs")

Database Prepared - 160235 interactions - 1939 drugs


In [8]:
# Connect to Neo4j Local Database
# Connection settings are taken from the environment (NEO4J_URI, NEO4J_USER,
# NEO4J_PASSWORD) so real credentials do not have to live in the notebook.
# The fallbacks are the values this notebook used before, so an unconfigured
# local setup keeps working.
uri = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
neo4j_credentials = (
    os.environ.get("NEO4J_USER", "neo4j"),
    os.environ.get("NEO4J_PASSWORD", "12345678"),
)
driver = GraphDatabase.driver(uri, auth=neo4j_credentials)

# Test Connection: run a trivial query so a bad URI or login fails in this
# cell rather than in a later one.
with driver.session() as session:
    session.run("MATCH () RETURN 1 LIMIT 1")
print("Neo4j Driver Connected")

Neo4j Driver Connected


In [11]:
# Wiping the Graph
def wipe_graph(tx):
    """Delete every node, along with its relationships, through ``tx``.

    Args:
        tx: Object exposing ``run(query)``; here it is the transaction handed
            in by ``session.execute_write``.
    """
    delete_all_query = "MATCH (n) DETACH DELETE n"
    tx.run(delete_all_query)

# Ask before wiping; any answer other than "y"/"Y" leaves the graph untouched.
wipe_answer = input("Do you want to wipe the Neo4j graph? (y/n): ")
if wipe_answer.lower() == "y":
    with driver.session() as session:
        session.execute_write(wipe_graph)
    print("Neo4j graph has been successfully wiped.")

Neo4j graph has been successfully wiped.


In [13]:
# Constructing the Graph or Reconstructing the Graph
def create_graph(tx, batch):
    """Merge one batch of drug interactions into the graph.

    Each interaction becomes two ``Drug`` nodes (matched or created by name)
    joined by a ``CONTRADICTS`` relationship carrying the severity. ``MERGE``
    is used throughout, so re-running the same batch does not add duplicates.

    The whole batch is sent as a single parameterised ``UNWIND`` query instead
    of one query per interaction, which removes a round trip per row.

    Args:
        tx: Object exposing ``run(query, **parameters)``; here it is the
            transaction handed in by ``session.execute_write``.
        batch: Iterable of ``(drug_a, drug_b, severity_level)`` triples.

    Raises:
        ValueError: If an item of ``batch`` does not unpack into three values.
    """
    rows = [
        {"drug_a": drug_a, "drug_b": drug_b, "severity_level": severity_level}
        for drug_a, drug_b, severity_level in batch
    ]
    # Nothing to write: skip the query entirely for an empty batch.
    if not rows:
        return

    tx.run("UNWIND $rows AS row "
           "MERGE (d1:Drug {name: row.drug_a}) "
           "MERGE (d2:Drug {name: row.drug_b}) "
           "MERGE (d1)-[r:CONTRADICTS {severity: row.severity_level}]->(d2)",
           rows=rows)

# Ask before (re)building; any answer other than "y"/"Y" skips the import.
construct_answer = input("Do you want to construct or reconstruct the Neo4j graph? (y/n): ")
if construct_answer.lower() == "y":
    batch_size = 100
    # One write transaction per slice of `batch_size` interactions.
    batch_starts = range(0, len(interactions), batch_size)
    with driver.session() as session:
        for start in tqdm(batch_starts):
            session.execute_write(create_graph, interactions[start:start + batch_size])
    print("Neo4j graph has been successfully constructed.")

100%|██████████| 1603/1603 [03:05<00:00,  8.66it/s]

Neo4j graph has been successfully constructed.





In [14]:
# Exporting the Graph
def fetch_data(query):
    """Run ``query`` in its own session and return the records as dictionaries.

    Args:
        query: Cypher query string passed to ``session.run``.

    Returns:
        A list with one ``record.data()`` result per returned record.
    """
    with driver.session() as session:
        # Records are converted while the session is still open.
        return [row.data() for row in session.run(query)]

# Fetch every node and every directed relationship, then dump both to JSON.
nodes_data = fetch_data("MATCH (n) RETURN n")
# record.data() reduces a relationship to a (start, type, end) tuple, which
# drops its properties (including `severity`). The properties are therefore
# returned as an extra `r_properties` field; the existing `n`, `r` and `m`
# fields are unchanged.
edges_data = fetch_data("MATCH (n)-[r]->(m) RETURN n, r, m, properties(r) AS r_properties")

files.write_json(os.path.join(graph_path, "nodes.json"), nodes_data)
files.write_json(os.path.join(graph_path, "edges.json"), edges_data)
print("Neo4j graph has been successfully exported.")

Neo4j graph has been successfully exported.


In [None]:
# Optional rendering of the exported graph to a PNG file.
if input("Do you want to visualize the Neo4j graph? (y/n): ").lower() == "y":
    G = nx.DiGraph()

    # Adding nodes to the graph, keyed by drug name and carrying all exported
    # node properties as attributes
    for node in nodes_data:
        node_properties = node["n"]
        G.add_node(node_properties["name"], **node_properties)

    # Adding edges to the graph
    for edge in edges_data:
        start_node = edge["n"]["name"]
        end_node = edge["m"]["name"]
        edge_type = edge["r"][1]  # Or "CONTRADICTS"

        # Adding the edge with the edge type as a label
        G.add_edge(start_node, end_node, label=edge_type)

    figure_size = (100, 100)
    fig = plt.figure(figsize=figure_size)

    # spring_layout starts from random positions; a fixed seed makes the saved
    # picture reproducible between runs.
    pos = nx.spring_layout(G, k=0.1, seed=42)
    nx.draw(G, pos, with_labels=True, node_size=50, font_size=4, edge_color='black', width=0.5)
    labels = nx.get_edge_attributes(G, 'label')
    nx.draw_networkx_edge_labels(G, pos, edge_labels=labels, font_size=3)
    plt.savefig(os.path.join(graph_path, "graph.png"), format="PNG", dpi=300)

    # The figure is very large; close it once written so it does not stay in
    # kernel memory for the rest of the session.
    plt.close(fig)

    print(f"Graph plot has been successfully created and saved to {graph_path}.")