In [47]:
import csv
import json
import os
import re
import warnings
from collections import defaultdict

import networkx as nx
import pandas as pd
import requests
from dotenv import load_dotenv
from openai import OpenAI
from tqdm import tqdm

# Silence DeprecationWarning messages so they do not clutter the cell outputs,
# then confirm that the import cell ran to completion.
warnings.filterwarnings('ignore', category=DeprecationWarning)
print("Libraries imported successfully.")

Libraries imported successfully.


In [35]:
# Load environment variables from the local .env file (if one exists).
load_dotenv()

# Read the API key and fail fast with an actionable message, rather than
# handing None to the client and getting a less obvious error later on.
key = os.getenv("NVIDIA_API_KEY")
if not key:
    raise RuntimeError(
        "NVIDIA_API_KEY is not set. Add it to the environment or to a .env file "
        "before running the extraction cells."
    )

# Model identifier used for the chat-completion calls in this notebook.
model_name = "nvidia/llama-3.1-nemotron-ultra-253b-v1"

# OpenAI-compatible client pointed at the NVIDIA endpoint through base_url.
client = OpenAI(
    api_key=key,
    base_url="https://integrate.api.nvidia.com/v1",
)

## Define the LLM Prompt for Extraction

In [23]:
# System prompt for the extraction call. It is a str.format template whose only
# placeholder is {disease}. It fixes the allowed predicate vocabulary and asks
# for CSV-only output starting with the header "subject,predicate,object".
system_prompt = """
You are a medical information-extraction assistant focused on respiratory diseases.
Given a passage of text about {disease}, extract Subject-Predicate-Object (SPO) triples.

Extraction rules:
────────────────
• Extract any medically meaningful relation related to {disease} and its context.
• Subjects and objects can be: symptoms, causes, treatments, diagnoses, risk factors, complications, or related medical concepts.
• Use only these predicates (verbatim, lowercase):
  has_symptom, is_diagnosed_by, is_treated_by, is_caused_by, 
  is_prevented_by, has_risk_factor, leads_to_complication, 
  treats, causes, occurs_with, associated_with
• Subjects/objects should be full terms — no pronouns or abbreviations.

Output format:
─────────────
• CSV only. First line must be:
  subject,predicate,object
• No explanation, no blank lines, no quotation marks, no punctuation at end of terms.
"""

# User prompt template for the extraction call. It is filled with str.format
# using two placeholders: {disease} and {text_chunk} (the text to extract from,
# wrapped in triple quotes inside the prompt).
user_prompt_template = """
EXTRACT SPO TRIPLES

Disease context: {disease}

Text:
\"\"\"{text_chunk}\"\"\"

Your response must be a CSV table starting with the header:
subject,predicate,object
"""

In [43]:
def textualize_json(file_path):
    """Flatten a JSON file into plain text suitable for an LLM prompt.

    Dictionaries are rendered as one "key:" line followed by the value on the
    next line; nested dictionaries are rendered recursively. Entries whose key
    is 'url' are dropped. A top-level list is rendered item by item, with the
    items separated by a blank line.

    Args:
        file_path: Path of the UTF-8 encoded JSON file to read.

    Returns:
        The textual rendering of the JSON document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    def format_dict(d):
        lines = []
        for field, value in d.items():
            if field == 'url':
                continue
            if isinstance(value, dict):
                lines.append(f"{field}:\n{format_dict(value)}")
            else:
                lines.append(f"{field}:\n{value}")
        return '\n'.join(lines)

    def format_item(item):
        # Non-dict items (strings, numbers, ...) have no .items(); render them
        # with str() instead of crashing inside format_dict.
        return format_dict(item) if isinstance(item, dict) else str(item)

    if isinstance(data, list):
        return '\n\n'.join(format_item(item) for item in data)
    return format_item(data)

def format_llm_output(output_text, save_path):
    """Clean raw LLM output and write the valid SPO triples to a CSV file.

    A line is kept only if it splits on commas into exactly three non-empty
    fields. Lines that are exactly the "subject,predicate,object" header
    (case-insensitive) are skipped; the header is written once by this
    function. If no valid triple is found, a warning is printed and no file is
    created.

    Args:
        output_text: Raw text returned by the model (None is treated as empty).
        save_path: Destination CSV path; parent directories are created.

    Returns:
        None.
    """
    header = ["subject", "predicate", "object"]
    cleaned_lines = []

    for line in (output_text or "").strip().splitlines():
        parts = [part.strip() for part in line.split(',')]
        if len(parts) != 3 or not all(parts):
            continue
        # Only skip a real header row; a prefix check on "subject" would also
        # drop triples whose subject merely starts with that word.
        if [part.lower() for part in parts] == header:
            continue
        cleaned_lines.append(parts)

    if not cleaned_lines:
        print(f"Warning: No valid SPO triples found in output: {save_path}")
        return

    # os.makedirs('') raises, so only create a directory when the path has one.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(save_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(cleaned_lines)

# Persist one LLM response as a cleaned CSV named "<file_name>.csv"
def save_llm_output(file_name, output_text, output_path):
    """Write the triples found in output_text to '<output_path>/<file_name>.csv'.

    The cleaning and the actual writing are delegated to format_llm_output.
    """
    csv_name = f"{file_name}.csv"
    destination = os.path.join(output_path, csv_name)
    format_llm_output(output_text, destination)

## LLM Interaction - Extracting Triples (File by File)

In [51]:
def extract_spo(raw_data_path, output_path, client, model=None):
    """Extract SPO triples from every file under raw_data_path with the LLM.

    Each sub-folder of raw_data_path is treated as one disease (the folder name
    is inserted into the prompts). Every regular file in a folder is converted
    to text with textualize_json, sent to the chat-completions endpoint, and
    the reply is stored through save_llm_output as
    '<folder>_<file stem>.csv' under output_path.

    Errors on individual files are printed and counted; they do not stop the
    run. A one-line summary with the number of failures is printed at the end.

    Args:
        raw_data_path: Directory that contains one sub-folder per disease.
        output_path: Directory that receives the extracted CSV files.
        client: Client object exposing `chat.completions.create`.
        model: Model identifier to use. Defaults to the notebook-level
            `model_name` when omitted.

    Returns:
        None.
    """
    model = model if model is not None else model_name
    failed_files = 0

    # os.listdir returns entries in arbitrary order; sort for reproducible runs.
    folders = sorted(
        entry for entry in os.listdir(raw_data_path)
        if os.path.isdir(os.path.join(raw_data_path, entry))
    )

    for folder_name in tqdm(folders, desc="Processing folders"):
        folder_path = os.path.join(raw_data_path, folder_name)
        # Keep regular files only, so nested directories are not reported as
        # failed files.
        files = sorted(
            entry for entry in os.listdir(folder_path)
            if os.path.isfile(os.path.join(folder_path, entry))
        )

        for file_name in tqdm(files, desc=f"→ {folder_name}", leave=False):
            file_path = os.path.join(folder_path, file_name)
            try:
                raw_text = textualize_json(file_path)
                completion = client.chat.completions.create(
                    model=model,
                    messages=[
                        {"role": "system", "content": system_prompt.format(disease=folder_name)},
                        {"role": "user", "content": user_prompt_template.format(text_chunk=raw_text, disease=folder_name)}
                    ],
                    temperature=0.2,
                    top_p=1.0,
                    max_tokens=2048,
                    frequency_penalty=0,
                    presence_penalty=0,
                    stream=False
                )

                output_text = completion.choices[0].message.content
                new_file_name = f"{folder_name}_{os.path.splitext(file_name)[0]}"
                save_llm_output(new_file_name, output_text, output_path)

            except Exception as e:
                # Best-effort batch job: report the failure and keep going.
                print(f"Error processing file {file_name} in folder {folder_name}: {e}")
                failed_files += 1

    # The counter was previously computed but never surfaced to the user.
    print(f"\nCompleted extraction with {failed_files} failure(s).")

In [53]:
# Run the extraction: read the per-disease folders under ./raw_files and write the raw SPO CSVs to ./spo/raw.
extract_spo('./raw_files', './spo/raw', client)

Processing folders:   0%|          | 0/5 [00:00<?, ?it/s]
→ asthma:   0%|          | 0/11 [00:00<?, ?it/s][A
→ asthma:   9%|▉         | 1/11 [00:08<01:28,  8.85s/it][A
→ asthma:  18%|█▊        | 2/11 [00:34<02:48, 18.77s/it][A
→ asthma:  27%|██▋       | 3/11 [00:41<01:48, 13.53s/it][A
→ asthma:  36%|███▋      | 4/11 [01:17<02:34, 22.13s/it][A
→ asthma:  45%|████▌     | 5/11 [01:24<01:39, 16.62s/it][A
→ asthma:  55%|█████▍    | 6/11 [01:35<01:14, 14.82s/it][A
→ asthma:  64%|██████▎   | 7/11 [01:43<00:50, 12.50s/it][A
→ asthma:  73%|███████▎  | 8/11 [01:49<00:31, 10.67s/it][A
→ asthma:  82%|████████▏ | 9/11 [01:59<00:20, 10.49s/it][A
→ asthma:  91%|█████████ | 10/11 [02:07<00:09,  9.66s/it][A
→ asthma: 100%|██████████| 11/11 [02:25<00:00, 12.09s/it][A
Processing folders:  20%|██        | 1/5 [02:25<09:41, 145.31s/it]
→ chronic-obstructive-pulmonary-disease-copd:   0%|          | 0/9 [00:00<?, ?it/s][A
→ chronic-obstructive-pulmonary-disease-copd:  11%|█         | 1/9 [00:21<

### Extraction Summary 


In [82]:
def count_spos_per_disease(spo_folder_path):
    """Load every SPO CSV in a folder and count the triples per disease.

    The disease name is the part of the file name before the first underscore.
    Files without the 'subject', 'predicate' and 'object' columns are skipped
    with a message; unreadable files are reported and skipped. A per-disease
    summary and the totals are printed.

    Args:
        spo_folder_path: Directory containing the '<disease>_<name>.csv' files.

    Returns:
        A tuple (spo_by_disease, diseases, counts) where spo_by_disease maps a
        disease name to its list of (subject, predicate, object) tuples, and
        diseases / counts are parallel lists sorted by disease name.
    """
    required_columns = ['subject', 'predicate', 'object']
    spo_dict = defaultdict(list)

    # Sorted listing keeps the triple order reproducible across machines.
    for file_name in sorted(os.listdir(spo_folder_path)):
        if not file_name.endswith('.csv'):
            continue

        disease = file_name.split('_')[0]
        file_path = os.path.join(spo_folder_path, file_name)

        try:
            df = pd.read_csv(file_path)
            if set(required_columns).issubset(df.columns):
                # itertuples avoids building one Series per row as iterrows does.
                triples = list(df[required_columns].itertuples(index=False, name=None))
                if triples:
                    spo_dict[disease].extend(triples)
            else:
                print(f"Skipped file (missing columns): {file_name}")
        except Exception as e:
            print(f"Error reading file {file_name}: {e}")

    diseases = []
    counts = []
    for disease, spo_list in sorted(spo_dict.items()):
        diseases.append(disease)
        counts.append(len(spo_list))
        print(f"{disease}: {len(spo_list)} SPO triples")

    total_spo = sum(counts)
    print(f"\nTotal SPO triples: {total_spo}")
    print(f"Total diseases: {len(spo_dict)}")

    return dict(spo_dict), diseases, counts

# Load all raw SPO files; `diseases` and `original_counts` are parallel lists sorted by disease name.
all_extracted_triples, diseases, original_counts = count_spos_per_disease('./spo/raw')

asthma: 677 SPO triples
chronic-obstructive-pulmonary-disease-copd: 459 SPO triples
coronavirus-disease-(covid-19): 283 SPO triples
pneumonia: 373 SPO triples
tuberculosis: 298 SPO triples

Total SPO triples: 2090
Total diseases: 5


## Post-processing & canonicalisation

### Normalize and De-duplicate Triples

Processes each disease separately.  
Normalizes all triples (trimming and lowercasing).  
Deduplicates based on the cleaned (subject, predicate, object) tuples.  
Tracks and prints detailed stats per disease.

In [99]:
# Normalize (trim + lowercase) and de-duplicate the extracted triples, one
# disease at a time, keeping the first occurrence of every triple.
normalized_spo_by_disease = {}
dedup_stats = {}

for disease in diseases:
    triples = all_extracted_triples[disease]

    normalized_triples = []
    seen_triples = set()
    empty_removed_count = 0
    duplicates_removed_count = 0

    for s, p, o in triples:
        # Non-string cells (for example missing values) are treated as empty.
        s = s.strip().lower() if isinstance(s, str) else ''
        # Predicates additionally get runs of whitespace collapsed to one space.
        p = re.sub(r'\s+', ' ', p.strip().lower()) if isinstance(p, str) else ''
        o = o.strip().lower() if isinstance(o, str) else ''

        if not all([s, p, o]):
            empty_removed_count += 1
            continue

        # Named triple_key (not `key`) so the module-level API key variable is
        # not overwritten by this cell.
        triple_key = (s, p, o)
        if triple_key in seen_triples:
            duplicates_removed_count += 1
            continue

        seen_triples.add(triple_key)
        normalized_triples.append({'subject': s, 'predicate': p, 'object': o})

    normalized_spo_by_disease[disease] = normalized_triples
    dedup_stats[disease] = {
        "original": len(triples),
        "kept": len(normalized_triples),
        "duplicates_removed": duplicates_removed_count,
        "empty_removed": empty_removed_count
    }

# Summary
total = 0
print("Deduplication Summary Per Disease:")
for disease in sorted(dedup_stats):
    # Look the stats up BEFORE using them: the previous order read `stats`
    # before assignment, which fails on a fresh kernel.
    stats = dedup_stats[disease]
    total += stats['kept']
    print(f"{disease}: Kept {stats['kept']} / {stats['original']} | Duplicates: {stats['duplicates_removed']}, Empty: {stats['empty_removed']}")
print(f"Total remaining SPO: {total}")

Deduplication Summary Per Disease:
asthma: Kept 533 / 677 | Duplicates: 144, Empty: 0
chronic-obstructive-pulmonary-disease-copd: Kept 418 / 459 | Duplicates: 41, Empty: 0
coronavirus-disease-(covid-19): Kept 271 / 283 | Duplicates: 12, Empty: 0
pneumonia: Kept 328 / 373 | Duplicates: 45, Empty: 0
tuberculosis: Kept 279 / 298 | Duplicates: 19, Empty: 0
Total remaining SPO: 1829


In [104]:
def export_spo_per_disease(normalized_spo_by_disease, path):
    """Write one CSV of triples per disease into the directory `path`.

    The file name is the lowercased disease name with spaces replaced by
    underscores, plus '.csv'. Each file has a 'subject,predicate,object'
    header followed by one row per triple dict. The directory is created if
    needed, and a confirmation line is printed once all files are written.
    """
    columns = ["subject", "predicate", "object"]
    os.makedirs(path, exist_ok=True)

    for disease_name in normalized_spo_by_disease:
        target = os.path.join(path, f"{disease_name.lower().replace(' ', '_')}.csv")

        with open(target, mode='w', newline='', encoding='utf-8') as handle:
            out = csv.DictWriter(handle, fieldnames=columns)
            out.writeheader()
            for triple in normalized_spo_by_disease[disease_name]:
                out.writerow(triple)

    print(f"Exported SPO CSV files to: {path}")

# Write one normalized, de-duplicated CSV per disease; this folder is the input of the refinement step.
export_spo_per_disease(normalized_spo_by_disease, './spo/normalized&deduplicated')

Exported SPO CSV files to: ./spo_normalized&deduplicated


### Refine the SPO triples with a second LLM call

In [112]:
# System prompt for the refinement call. It is a str.format template whose only
# placeholder is {disease}; it asks the model to deduplicate, clarify and
# connect the triples and to answer with lowercase CSV only.
refining_system_prompt = """
You are a medical SPO-triple refinement assistant for respiratory-disease knowledge graphs.

Task
────
Given SPO triples about the disease {disease}, return a single, clean, fully-connected graph:

• Eliminate semantic duplicates.  
• Clarify vague terms (use formal medical wording, no abbreviations or pronouns).  
• Ensure every node belongs to ONE connected component anchored on {disease}.  
  – If a node is isolated but clearly relates to {disease}, ADD one factual triple to link it.  
  – If connection is uncertain, leave the triple unchanged (do NOT fabricate facts).  
• All text must be lowercase.

Output format
─────────────
csv only — first line literally:
subject,predicate,object
No explanations, blank lines, quotes, or trailing punctuation.
"""

# User prompt template for the refinement call, filled with str.format using
# {disease} and {csv_triples} (one batch of triples rendered as CSV text).
# NOTE(review): task 4 mentions "an allowed predicate", but neither refining
# prompt lists the allowed predicates — consider adding the list used by the
# extraction prompt.
refining_user_prompt_template = """
REFINE SPO TRIPLES  –  DISEASE: {disease}

Below is a csv table of extracted triples.

Your tasks
──────────
1. Deduplicate rows that convey the same fact.  
2. Rephrase unclear wording for precision.  
3. Ensure every node appears in at least one edge connected (directly or indirectly) to "{disease}".  
4. If a node is isolated yet obviously related, add ONE triple using an allowed predicate to connect it.  
5. Do NOT invent new medical facts.  
6. Return lowercase csv, starting with the header.

Input triples:
{csv_triples}

Your output must start with:
subject,predicate,object
"""

In [132]:
def read_csv_rows(file_path):
    """Read a CSV file with a header row into a list of dicts, one per row.

    Args:
        file_path: Path of the UTF-8 encoded CSV file.

    Returns:
        A list with one dict per data row, keyed by the header names.
    """
    # newline='' hands newline handling to the csv module, as its documentation
    # requires; without it, newlines embedded in quoted fields get rewritten.
    with open(file_path, 'r', newline='', encoding='utf-8') as f:
        return list(csv.DictReader(f))

def rows_to_csv_string(rows):
    """Render triple dicts as CSV text: the header line plus one line per row.

    Fields are joined with plain commas (no quoting) and lines with a newline;
    there is no trailing newline.
    """
    rendered = [
        f"{record['subject']},{record['predicate']},{record['object']}"
        for record in rows
    ]
    return "\n".join(["subject,predicate,object", *rendered])

def adaptive_chunks(rows, max_chars):
    """Yield consecutive batches of rows whose CSV rendering fits in max_chars.

    The size of a batch is the header line plus every "subject,predicate,object"
    line, newlines included. A batch is closed as soon as the next row would
    push it over max_chars; a single row that is too large on its own is still
    yielded as a one-row batch.
    """
    header_cost = len("subject,predicate,object\n")
    pending = []
    used = header_cost

    for record in rows:
        cost = len(f"{record['subject']},{record['predicate']},{record['object']}\n")
        overflow = used + cost > max_chars
        if overflow and pending:
            yield pending
            pending, used = [], header_cost
        pending.append(record)
        used += cost

    if pending:
        yield pending

def _parse_refined_triples(output_text):
    """Parse raw LLM CSV text into a list of subject/predicate/object dicts."""
    header = ("subject", "predicate", "object")
    triples = []
    for line in (output_text or "").splitlines():
        parts = tuple(part.strip() for part in line.split(","))
        # Keep only lines with exactly three non-empty fields.
        if len(parts) != 3 or not all(parts):
            continue
        # Skip header rows wherever they appear, instead of blindly dropping
        # the first line (which loses a triple when the header is missing).
        if tuple(part.lower() for part in parts) == header:
            continue
        triples.append(dict(zip(header, parts)))
    return triples


def refine_spo_from_csv(input_path, output_path, client, model_name, system_prompt, user_prompt_template, max_chars_per_batch=5000):
    """Refine every per-disease SPO CSV in input_path with the LLM.

    For each '*.csv' file, the rows are split into batches with
    adaptive_chunks, each batch is sent to the chat-completions endpoint, the
    replies are parsed, merged and de-duplicated (first occurrence wins, order
    preserved), and the result is written to
    '<output_path>/<file stem>_refined.csv'.

    A failure on one file is printed and counted; the remaining files are
    still processed. A summary line with the failure count is printed last.

    Args:
        input_path: Directory with the input CSV files.
        output_path: Directory that receives the refined CSV files (created).
        client: Client object exposing `chat.completions.create`.
        model_name: Model identifier passed to the endpoint.
        system_prompt: str.format template with a {disease} placeholder.
        user_prompt_template: str.format template with {disease} and
            {csv_triples} placeholders.
        max_chars_per_batch: Upper bound on the CSV text size of one batch.

    Returns:
        None.
    """
    os.makedirs(output_path, exist_ok=True)
    failed_files = 0

    # os.listdir order is arbitrary; sort so runs are reproducible.
    csv_files = sorted(name for name in os.listdir(input_path) if name.endswith('.csv'))

    for file_name in tqdm(csv_files, desc="Refining SPO per disease"):
        disease = os.path.splitext(file_name)[0].replace('_', ' ')
        file_path = os.path.join(input_path, file_name)

        try:
            rows = read_csv_rows(file_path)
            all_refined_rows = []

            for batch in adaptive_chunks(rows, max_chars=max_chars_per_batch):
                csv_chunk = rows_to_csv_string(batch)

                completion = client.chat.completions.create(
                    model=model_name,
                    messages=[
                        {"role": "system", "content": system_prompt.format(disease=disease)},
                        {"role": "user", "content": user_prompt_template.format(disease=disease, csv_triples=csv_chunk)}
                    ],
                    temperature=0.2,
                    top_p=1.0,
                    max_tokens=2048,
                    frequency_penalty=0,
                    presence_penalty=0,
                    stream=False
                )

                all_refined_rows.extend(
                    _parse_refined_triples(completion.choices[0].message.content)
                )

            # Deduplicate after merging batches. A set would scramble the row
            # order between runs, so track seen triples and keep input order.
            seen = set()
            unique_rows = []
            for row in all_refined_rows:
                triple = (row["subject"], row["predicate"], row["object"])
                if triple not in seen:
                    seen.add(triple)
                    unique_rows.append(row)

            # Save final merged and deduplicated file
            out_filename = f"{file_name.replace('.csv', '')}_refined.csv"
            out_path = os.path.join(output_path, out_filename)

            with open(out_path, 'w', newline='', encoding='utf-8') as f:
                writer = csv.DictWriter(f, fieldnames=["subject", "predicate", "object"])
                writer.writeheader()
                writer.writerows(unique_rows)

        except Exception as e:
            # Best-effort batch job: report the failure and continue.
            print(f"❌ Error processing {file_name}: {e}")
            failed_files += 1

    print(f"\n✅ Completed refinement with {failed_files} failure(s).")

In [134]:
# Refine every normalized per-disease CSV and write '<name>_refined.csv' files
# to ./spo/refined; max_chars_per_batch is left at its default value.
refine_spo_from_csv(
    input_path='./spo/normalized&deduplicated',
    output_path='./spo/refined',
    client=client,  
    model_name=model_name,
    system_prompt=refining_system_prompt,
    user_prompt_template=refining_user_prompt_template,
)

Refining SPO per disease: 100%|██████████| 5/5 [10:06<00:00, 121.21s/it]


✅ Completed refinement with 0 failure(s).





In [165]:
from collections import defaultdict
import os
import pandas as pd

def summarize_refinement_process(raw_path, refined_path):
    """Print and return per-disease triple counts before and after refinement.

    Both folders are scanned for '*.csv' files. The disease label is the file
    name without '_refined' and '.csv', with underscores turned into spaces.
    Files lacking the subject/predicate/object columns, or that cannot be
    read, are reported and left out of the counts.

    Returns:
        A list of dicts with the keys 'disease', 'raw', 'refined' and 'delta'
        (refined minus raw), sorted by disease label.
    """
    required_columns = {'subject', 'predicate', 'object'}
    rule = "-" * 52

    # Count the rows of every valid SPO file in one folder, keyed by disease.
    def load_spo_counts(folder_path):
        counts_by_disease = {}
        csv_names = [name for name in os.listdir(folder_path) if name.endswith('.csv')]
        for file_name in csv_names:
            disease = file_name.replace('_refined', '').replace('.csv', '').replace('_', ' ')
            try:
                frame = pd.read_csv(os.path.join(folder_path, file_name))
                if not required_columns.issubset(frame.columns):
                    print(f"Skipped file (missing columns): {file_name}")
                    continue
                counts_by_disease[disease] = len(frame)
            except Exception as e:
                print(f"Error reading file {file_name}: {e}")
        return counts_by_disease

    before = load_spo_counts(raw_path)
    after = load_spo_counts(refined_path)

    summary = []
    for name in sorted(before.keys() | after.keys()):
        n_raw = before.get(name, 0)
        n_refined = after.get(name, 0)
        summary.append({
            "disease": name,
            "raw": n_raw,
            "refined": n_refined,
            "delta": n_refined - n_raw
        })

    print(f"{'Disease':<50} {'Raw':>6} {'Refined':>8} {'Δ':>6}")
    print(rule)
    for entry in summary:
        print(f"{entry['disease']:<50} {entry['raw']:>6} {entry['refined']:>8} {entry['delta']:>6}")
    print(rule)

    total_raw = sum(entry["raw"] for entry in summary)
    total_refined = sum(entry["refined"] for entry in summary)
    print(f"{'Total':<50} {total_raw:>6} {total_refined:>8} {total_refined - total_raw:>6}")

    return summary

# Compare per-disease triple counts before (normalized & de-duplicated) and after refinement.
summary = summarize_refinement_process(
    raw_path='./spo/normalized&deduplicated',
    refined_path='./spo/refined'
)

Disease                                               Raw  Refined      Δ
----------------------------------------------------
asthma                                                533      570     37
chronic-obstructive-pulmonary-disease-copd            418      403    -15
coronavirus-disease-(covid-19)                        271      228    -43
pneumonia                                             328      562    234
tuberculosis                                          279      367     88
----------------------------------------------------
Total                                                1829     2130    301


### Graph count and formatting

In [169]:
def read_csv_rows(file_path):
    """Read a CSV file with a header row into a list of dicts, one per row.

    NOTE(review): this redefines the read_csv_rows declared earlier in the
    notebook; consider keeping a single definition.
    """
    # newline='' hands newline handling to the csv module, as its documentation
    # requires; without it, newlines embedded in quoted fields get rewritten.
    with open(file_path, "r", newline="", encoding="utf-8") as f:
        return list(csv.DictReader(f))

def csv_file_to_disease(file_name: str) -> str:
    """Turn a '<disease>_refined.csv' file name into a disease label.

    The '_refined.csv' suffix is removed and every remaining underscore
    becomes a space. Names that do not follow the pattern are only subject to
    the underscore replacement.
    """
    stem = file_name.replace("_refined.csv", "")
    return " ".join(stem.split("_"))

def extract_graphs_from_directory(
    input_folder: str,
    output_json_path: str,
    verbose: bool = False,
):
    """Build one graph per refined CSV and export its connected components.

    Every '*_refined.csv' file in input_folder is loaded into a
    networkx MultiDiGraph (subject -> object, with the predicate stored on the
    edge). Each weakly connected component becomes one sub-graph, represented
    as a list of [subject, predicate, object] triples. All sub-graphs are
    written to output_json_path as a JSON list of lists, and a per-disease
    summary table (graphs, nodes, edges) is printed.

    Args:
        input_folder: Directory containing the '<disease>_refined.csv' files.
        output_json_path: Destination of the JSON export; its parent directory
            is created when missing.
        verbose: When True, print one line per connected component.

    Returns:
        A tuple (all_graphs, disease_stats): the list of sub-graphs and a
        mapping from disease to its 'graphs', 'nodes' and 'edges' totals.
    """
    all_graphs = []                   # list of all sub-graphs (for JSON)
    disease_stats = defaultdict(lambda: {"graphs": 0, "nodes": 0, "edges": 0})

    # os.listdir order is arbitrary; sorting keeps the order of the graphs in
    # the JSON export stable between runs and machines.
    refined_files = sorted(
        name for name in os.listdir(input_folder) if name.endswith("_refined.csv")
    )

    for file_name in tqdm(refined_files, desc="Building graphs"):
        disease = csv_file_to_disease(file_name)
        file_path = os.path.join(input_folder, file_name)

        rows = read_csv_rows(file_path)

        # Build MultiDiGraph (parallel edges allowed: two nodes can be linked
        # by several predicates).
        G = nx.MultiDiGraph()
        for row in rows:
            G.add_edge(row["subject"], row["object"], predicate=row["predicate"])

        # Weakly connected components = individual sub-graphs
        for comp_id, comp_nodes in enumerate(nx.weakly_connected_components(G), 1):
            sub = G.subgraph(comp_nodes)
            triples = [
                [u, d["predicate"], v] for u, v, d in sub.edges(data=True)
            ]

            # Stats
            n_nodes = sub.number_of_nodes()
            n_edges = sub.number_of_edges()

            disease_stats[disease]["graphs"] += 1
            disease_stats[disease]["nodes"] += n_nodes
            disease_stats[disease]["edges"] += n_edges

            if verbose:
                print(
                    f"{disease:<45} | graph {comp_id:>2} | "
                    f"{n_nodes:>3} nodes | {n_edges:>3} edges"
                )

            all_graphs.append(triples)

    # Make sure the destination directory exists before writing the export.
    output_dir = os.path.dirname(output_json_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Write the list-of-lists JSON
    with open(output_json_path, "w", encoding="utf-8") as fp:
        json.dump(all_graphs, fp, indent=2, ensure_ascii=False)

    print("\nSummary by disease")
    print(f"{'disease':<45} {'#graphs':>8} {'nodes':>10} {'edges':>10}")
    print("-" * 55)
    total_g = total_n = total_e = 0
    for dis, s in sorted(disease_stats.items()):
        print(
            f"{dis:<45} {s['graphs']:>8} {s['nodes']:>10} {s['edges']:>10}"
        )
        total_g += s["graphs"]
        total_n += s["nodes"]
        total_e += s["edges"]
    print("-" * 55)
    print(f"{'TOTAL':<45} {total_g:>8} {total_n:>10} {total_e:>10}")

    return all_graphs, disease_stats


# NOTE: extract_graphs_from_directory returns an (all_graphs, disease_stats)
# tuple, so `graphs` holds both values, not just the list of graphs.
graphs = extract_graphs_from_directory(
    input_folder='./spo/refined',
    output_json_path='./spo/graphs.json',
    verbose=True
)

Building graphs: 100%|██████████| 5/5 [00:00<00:00, 96.15it/s]

asthma                                        | graph  1 | 435 nodes | 570 edges
chronic-obstructive-pulmonary-disease-copd    | graph  1 | 316 nodes | 403 edges
coronavirus-disease-(covid-19)                | graph  1 | 197 nodes | 228 edges
pneumonia                                     | graph  1 | 420 nodes | 562 edges
tuberculosis                                  | graph  1 | 285 nodes | 367 edges

Summary by disease
disease                                        #graphs      nodes      edges
-------------------------------------------------------
asthma                                               1        435        570
chronic-obstructive-pulmonary-disease-copd           1        316        403
coronavirus-disease-(covid-19)                       1        197        228
pneumonia                                            1        420        562
tuberculosis                                         1        285        367
-------------------------------------------------------
T




## Store refined triples in Neo4j

In [249]:
# from neo4j import GraphDatabase
# 
# uri = "bolt://localhost:7687"  # Neo4j Desktop default
# user = "neo4j"
# password = "password"  # Set this when you created your local DB
# 
# driver = GraphDatabase.driver(uri, auth=(user, password))
# 
# def add_triple(tx, subj, pred, obj):
#     tx.run("""
#         MERGE (s:Entity {name: $subj})
#         MERGE (o:Entity {name: $obj})
#         MERGE (s)-[:RELATION {type: $pred}]->(o)
#     """, subj=subj, obj=obj, pred=pred)
# 
# with driver.session() as session:
#     for subj, pred, obj in refined_triples:
#         session.write_transaction(add_triple, subj, pred, obj)