In [1]:
# 02_confounder_dag_pipeline.py

import json
import os
import subprocess
import webbrowser

import matplotlib.pyplot as plt
import networkx as nx
from dotenv import load_dotenv
from sqlalchemy import create_engine, MetaData
from sqlalchemy.engine import URL

load_dotenv()

# Build the connection URL from environment variables (loaded from .env above).
# URL.create assembles the URL from components, so credentials containing
# characters such as '@', ':' or '/' are escaped correctly, and an unset
# DB_USER / DB_PASSWORD is simply omitted instead of being rendered as the
# literal text "None" (which is what f-string interpolation would produce).
db_url = URL.create(
    drivername="postgresql+psycopg2",
    username=os.getenv("DB_USER"),
    password=os.getenv("DB_PASSWORD"),
    host=os.getenv("DB_HOST", "localhost"),
    port=int(os.getenv("DB_PORT", "5432")),
    database=os.getenv("DB_NAME", "dag_review_db"),
)
engine = create_engine(db_url)

# Reflect the existing schema so tables can be addressed as metadata.tables[...]
metadata = MetaData()
metadata.reflect(bind=engine)

Matplotlib is building the font cache; this may take a moment.


In [2]:
def infer_confounders(exposures, outcomes):
    """Suggest a fixed set of common confounders for every exposure/outcome pair.

    For each (exposure, outcome) combination, every common confounder is linked
    to both the exposure and the outcome. If either ``exposures`` or
    ``outcomes`` is empty, no edges are produced.

    Args:
        exposures: iterable of hashable exposure labels.
        outcomes: iterable of hashable outcome labels.

    Returns:
        A tuple ``(edges, common_confounders)`` where ``edges`` is a list of
        unique ``(confounder, target)`` tuples in first-seen order and
        ``common_confounders`` is the list of confounder labels used.
    """
    # Simple domain-inspired heuristic: suggest age, sex, SES, comorbidities
    common_confounders = ["age", "sex", "socioeconomic_status", "baseline_health"]

    # A dict is used as an insertion-ordered set: duplicates collapse, but the
    # resulting edge order is deterministic across kernel restarts (a plain
    # set of string tuples iterates in a hash-randomised order).
    inferred = {}
    for exposure in exposures:
        for outcome in outcomes:
            for confounder in common_confounders:
                inferred[(confounder, exposure)] = None  # confounder → exposure
                inferred[(confounder, outcome)] = None   # confounder → outcome
    return list(inferred), common_confounders

In [95]:
def propose_dag(study_id):
    """Build, persist, export and render a candidate DAG for one study.

    Steps:
      1. Read the study's exposures and outcomes from the database. When no
         exposures exist, eligibility criteria are used as graph inputs instead.
      2. Build a ``networkx.DiGraph`` with an edge from every graph input to
         every outcome, plus the edges suggested by ``infer_confounders``.
      3. Replace the study's rows in ``dag_nodes`` / ``dag_edges`` with the new
         graph (single transaction).
      4. Write ``dagitty_study_<id>.txt`` (DAGitty text) and
         ``dagitty_study_<id>.Rmd`` to the working directory, render the Rmd
         with ``Rscript`` and open the resulting HTML in a browser.

    Args:
        study_id: value matched against ``studies.id`` and the ``study_id``
            column of the related tables; also embedded in output file names.

    Returns:
        ``None`` if the study does not exist (a message is printed), otherwise
        a dict with keys ``graph``, ``dagitty``, ``nodes``, ``edges`` and
        ``dagviz`` (always ``None``).
    """
    # ---- 1. Read phase: the read connection is released before any slow work.
    with engine.connect() as conn:
        studies_tbl = metadata.tables['studies']
        study = conn.execute(
            studies_tbl.select().where(studies_tbl.c.id == study_id)).fetchone()
        if not study:
            print("Study not found.")
            return

        exposures_tbl = metadata.tables['exposures']
        exposures = [row._mapping['exposure'] for row in conn.execute(
            exposures_tbl.select().where(exposures_tbl.c.study_id == study_id))]

        outcomes_tbl = metadata.tables['outcomes']
        outcome_rows = conn.execute(
            outcomes_tbl.select().where(
                outcomes_tbl.c.study_id == study_id)).fetchall()
        outcomes = []
        for row in outcome_rows:
            data = row._mapping.get('data')
            # Use the first entry's 'value' when data is a non-empty list of
            # dicts; otherwise fall back to the string form of the raw value.
            if isinstance(data, list) and data and isinstance(data[0], dict) and 'value' in data[0]:
                outcomes.append(data[0]['value'])
            else:
                outcomes.append(str(data))

        if exposures:
            inferred_edges, confounders = infer_confounders(exposures, outcomes)
            graph_inputs = exposures + confounders
            print("Building DAG using exposures and inferred confounders.")
        else:
            print("No exposures found — switching to eligibility criteria and baseline characteristics.")
            inferred_edges = []
            eligibility_tbl = metadata.tables['eligibility_criteria']
            eligibility_rows = conn.execute(
                eligibility_tbl.select().where(
                    eligibility_tbl.c.study_id == study_id)).fetchall()
            confounders = [row._mapping['criteria'] for row in eligibility_rows]
            graph_inputs = confounders

    # ---- 2. Build the graph.
    dag = nx.DiGraph()
    # Empty / None labels are skipped: networkx rejects None as a node, and the
    # edge comprehensions below already ignore falsy endpoints.
    dag.add_nodes_from({label for label in graph_inputs + outcomes if label})
    dag.add_edges_from([(c, o) for c in graph_inputs for o in outcomes if c and o])
    dag.add_edges_from([edge for edge in inferred_edges if edge[0] and edge[1]])

    def shorten(label, max_len=40):
        """Return the label as text, truncated to max_len characters plus '...'."""
        text = str(label)
        return text if len(text) <= max_len else text[:max_len] + "..."

    for node in dag.nodes:
        if node in confounders:
            node_type = 'confounder'
        elif node in exposures:
            node_type = 'exposure'
        else:
            node_type = 'outcome'
        dag.nodes[node]['type'] = node_type
        dag.nodes[node]['label'] = shorten(node)

    # ---- 3. Persist: replace the study's stored DAG atomically.
    with engine.begin() as trans_conn:
        dag_edges_tbl = metadata.tables['dag_edges']
        dag_nodes_tbl = metadata.tables['dag_nodes']

        # First delete edges to avoid FK violation
        trans_conn.execute(
            dag_edges_tbl.delete().where(dag_edges_tbl.c.study_id == study_id))
        trans_conn.execute(
            dag_nodes_tbl.delete().where(dag_nodes_tbl.c.study_id == study_id))

        for node in dag.nodes:
            trans_conn.execute(dag_nodes_tbl.insert().values(
                study_id=study_id,
                node_label=node,
                node_type=dag.nodes[node]['type'],
                source='llm'
            ))

        node_id_cache = {}

        def lookup_node_id(label):
            """Fetch (once per label) the stored id of this study's node."""
            if label not in node_id_cache:
                # .scalar() yields the first column of the first matching row;
                # presumably that is the dag_nodes primary key — confirm schema.
                node_id_cache[label] = trans_conn.execute(
                    dag_nodes_tbl.select().where(
                        dag_nodes_tbl.c.study_id == study_id,
                        dag_nodes_tbl.c.node_label == label
                    )).scalar()
            return node_id_cache[label]

        for from_label, to_label in dag.edges:
            from_node_id = lookup_node_id(from_label)
            to_node_id = lookup_node_id(to_label)
            # Compare against None so a legitimate id of 0 is not dropped.
            if from_node_id is not None and to_node_id is not None:
                trans_conn.execute(dag_edges_tbl.insert().values(
                    study_id=study_id,
                    from_node_id=from_node_id,
                    to_node_id=to_node_id,
                    relation_type='llm',
                    confidence=1.0
                ))

    # ---- 4a. Export DAGitty text format.
    def quote(label):
        """Double-quote a label, replacing embedded double quotes with single quotes."""
        return '"' + str(label).replace('"', "'") + '"'

    dagitty_lines = ["dag {"]
    dagitty_lines.extend(quote(node) for node in dag.nodes)
    # Edge endpoints get the same escaping as node declarations so the names match.
    dagitty_lines.extend(f"{quote(src)} -> {quote(tgt)}" for src, tgt in dag.edges)
    dagitty_lines.append("}")
    dagitty = "\n".join(dagitty_lines)

    dagitty_txt_path = f"dagitty_study_{study_id}.txt"
    with open(dagitty_txt_path, "w", encoding="utf-8") as f:
        f.write(dagitty + "\n")

    # ---- 4b. Create the R Markdown file that renders this study's DAG.
    rmd_path = f"dagitty_study_{study_id}.Rmd"
    rmd_content = "\n".join([
        "---",
        f'title: "DAGitty DAG for Study {study_id}"',
        "output: html_document",
        "---",
        "",
        "```{r setup, include=FALSE}",
        "library(dagitty)",
        "library(ggdag)",
        "```",
        "",
        "```{r load-and-plot-dag}",
        # Read the file written above for *this* study, not a fixed study id.
        f'dag_text <- readLines("{dagitty_txt_path}")',
        'dag <- dagitty(paste(dag_text, collapse = "\\n"))',
        "",
        'ggdag(dag, text = FALSE, use_labels = "none") +',
        "  geom_dag_edges() +",
        "  geom_dag_text_repel(aes(label = name), size = 3, segment.size = 0.2, box.padding = 0.3, point.padding = 0.25) +",
        "  theme_dag()",
        "```",
        "",
    ])
    with open(rmd_path, "w", encoding="utf-8") as f:
        f.write(rmd_content)

    # ---- 4c. Render and open. Rendering is a convenience: the DAG has already
    # been persisted and exported, so a missing R installation is reported
    # rather than raised.
    try:
        render = subprocess.run(
            ["Rscript", "-e", f"rmarkdown::render('{rmd_path}')"], check=False)
    except FileNotFoundError:
        print("Rscript not found - skipping HTML rendering.")
        render = None

    html_path = f"dagitty_study_{study_id}.html"
    # Only open the HTML when this run produced it; an older file left on disk
    # would otherwise show a stale graph.
    if render is not None and render.returncode == 0 and os.path.exists(html_path):
        webbrowser.open(f"file://{os.path.abspath(html_path)}")

    return {
        "graph": dag,
        "dagitty": dagitty,
        "nodes": list(dag.nodes),
        "edges": list(dag.edges),
        "dagviz": None  # placeholder if needed later
    }

In [96]:
# Build, persist and render the proposed DAG for study 2; the returned dict is shown as the cell output.
propose_dag(2)

No exposures found — switching to eligibility criteria and baseline characteristics.




processing file: dagitty_study_2.Rmd


1/4                    
2/4 [setup]            
3/4                    
4/4 [load-and-plot-dag]
/opt/homebrew/bin/pandoc +RTS -K512m -RTS dagitty_study_2.knit.md --to html4 --from markdown+autolink_bare_uris+tex_math_single_backslash --output dagitty_study_2.html --lua-filter /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/library/rmarkdown/rmarkdown/lua/pagebreak.lua --lua-filter /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/library/rmarkdown/rmarkdown/lua/latex-div.lua --lua-filter /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/library/rmarkdown/rmarkdown/lua/table-classes.lua --embed-resources --standalone --variable bs3=TRUE --section-divs --template /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/library/rmarkdown/rmd/h/default.html --no-highlight --variable highlightjs=1 --variable theme=bootstrap --mathjax --variable 'mathjax-url=https://mathjax.rstudio.com/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML' --include-in-header 

output file: dagitty_study_2.knit.md


Output created: dagitty_study_2.html


{'graph': <networkx.classes.digraph.DiGraph at 0x11f6d77c0>,
 'dagitty': 'dag {\n"Echocardiographic parameters such as left ventricular ejection fraction, left ventricular mass index (LVMI), left ventricular hypertrophy, left atrial enlargement and diastolic function"\n"Patients with type 2 diabetes mellitus (T2DM) with no history of cardiovascular disease (CVD)"\n"Patients with type 2 diabetes mellitus (T2DM) with no history of cardiovascular disease (CVD)" -> "Echocardiographic parameters such as left ventricular ejection fraction, left ventricular mass index (LVMI), left ventricular hypertrophy, left atrial enlargement and diastolic function"\n}',
 'nodes': ['Echocardiographic parameters such as left ventricular ejection fraction, left ventricular mass index (LVMI), left ventricular hypertrophy, left atrial enlargement and diastolic function',
  'Patients with type 2 diabetes mellitus (T2DM) with no history of cardiovascular disease (CVD)'],
 'edges': [('Patients with type 2 diabete