# Set up environment

In [161]:
import os
import json
import networkx as nx
from sqlalchemy import create_engine, MetaData
from openai import OpenAI
from dotenv import load_dotenv
import jsonschema

In [4]:
OPENAI_API_KEY = input("Enter your OpenAI API key: ").strip()

Enter your OpenAI API key:  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


In [162]:
client = OpenAI(api_key=OPENAI_API_KEY)

# Load files

In [163]:
load_dotenv()

engine = create_engine(
    f"postgresql+psycopg2://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}@"
    f"{os.getenv('DB_HOST', 'localhost')}:{os.getenv('DB_PORT', '5432')}/{os.getenv('DB_NAME', 'dag_review_db')}"
)
metadata = MetaData()
metadata.reflect(bind=engine)

## LLM prompts and schemas

In [164]:
prompt_dir = os.path.abspath(os.path.join(os.getcwd(), "..", "src", "llm", "prompts"))
# === PROMPTS TO SUMMARIZE STRUCTURED DATA ===
SYSTEM_PROMPT_SUMMARY_TEMPLATE = open(os.path.join(prompt_dir, "summary_extract_system.txt")).read()
USER_PROMPT_SUMMARY_TEMPLATE = open(os.path.join(prompt_dir, "summary_extract_user.txt")).read()

summary_schema_path = os.path.abspath(os.path.join(os.getcwd(), "..", "data", "schemas", "summary_extract_schema.json"))
with open(summary_schema_path, "r") as summary_schema_file:
    summary_schema = summary_schema_file.read()
    
SYSTEM_PROMPT_SUMMARY = SYSTEM_PROMPT_SUMMARY_TEMPLATE.replace("{{JSON_SCHEMA}}", summary_schema)

# === PROMPTS TO GENERATE DAG OUTPUT === 
SYSTEM_PROMPT_DAG_TEMPLATE = open(os.path.join(prompt_dir, "dag_node_extract_system.txt")).read()
NODE_PROMPT_TEMPLATE = open(os.path.join(prompt_dir, "dag_node_extract_user.txt")).read()

dag_node_schema_path = os.path.abspath(os.path.join(os.getcwd(), "..", "data", "schemas", "dag_node_extract_schema.json"))
with open(dag_node_schema_path, "r") as dag_schema_file:
    dag_node_schema = dag_schema_file.read()

dag_node_schema_path = os.path.abspath(os.path.join(os.getcwd(), "..", "data", "schemas", "dag_node_extract_schema.json"))
with open(dag_node_schema_path, "r") as dag_schema_file:
    dag_node_schema_json = json.load(dag_schema_file)
    
SYSTEM_PROMPT_DAG = SYSTEM_PROMPT_DAG_TEMPLATE.replace("{{JSON_SCHEMA}}", dag_node_schema)

# Core functions

## Formatting

In [165]:
# === Prepares structured object for inclusion in prompt === 
def format_scientific_variables(scientific_variables):
    data = json.loads(scientific_variables)
    
    result_sections = []
    
    # Process objectives with * bullets
    if data.get('objectives') and len(data['objectives']) > 0:
        objectives_section = ["OBJECTIVES:"]
        for obj in data['objectives']:
            objectives_section.append(f"* {obj['shortLabel']}")
        result_sections.append('\n'.join(objectives_section))
    
    # Process eligibility with - bullets  
    if data.get('eligibility') and len(data['eligibility']) > 0:
        eligibility_section = ["ELIGIBILITY:"]
        for elig in data['eligibility']:
            eligibility_section.append(f"* {elig['shortLabel']}")
        result_sections.append('\n'.join(eligibility_section))
    
    # Process outcomes with - bullets
    if data.get('outcomes') and len(data['outcomes']) > 0:
        outcomes_section = ["OUTCOMES:"]
        for outcome in data['outcomes']:
            outcomes_section.append(f"* {outcome['shortLabel']}")
        result_sections.append('\n'.join(outcomes_section))
    
    # Join all sections with double newlines
    return '\n\n'.join(result_sections)

### DAG formats

In [219]:
def generate_dagitty_file(dag_json, study_id, output_dir):
    # Create mapping from node ID to label
    id_to_label = {node['id']: node['label'] for node in dag_json['nodes']}
    
    # Build graph using only labels
    dag = nx.DiGraph()
    node_labels = [node['label'] for node in dag_json['nodes']]
    dag.add_nodes_from(node_labels)
    
    # Add edges using mapped labels
    edge_tuples = []
    for edge in dag_json['edges']:
        from_label = id_to_label[edge['from']]
        to_label = id_to_label[edge['to']]
        edge_tuples.append((from_label, to_label))
    
    dag.add_edges_from(edge_tuples)
    
    # Save to database
    with engine.begin() as trans_conn:
        # Clear existing DAG for this study
        trans_conn.execute(metadata.tables['dag_edges'].delete().where(
            metadata.tables['dag_edges'].c.study_id == study_id))
        trans_conn.execute(metadata.tables['dag_nodes'].delete().where(
            metadata.tables['dag_nodes'].c.study_id == study_id))
        
        # Insert all nodes first
        for node_label in dag.nodes:
            trans_conn.execute(metadata.tables['dag_nodes'].insert().values(
                study_id=study_id,
                node_label=node_label,
                node_type='llm',
                source='llm'
            ))
        
        # Query back the inserted nodes to get their IDs
        node_id_map = {}
        node_query = metadata.tables['dag_nodes'].select().where(
            metadata.tables['dag_nodes'].c.study_id == study_id
        )
        result = trans_conn.execute(node_query)
        for row in result:
            node_id_map[row.node_label] = row.id
        
        # Insert edges using the database node IDs
        for from_label, to_label in dag.edges:
            from_node_id = node_id_map.get(from_label)
            to_node_id = node_id_map.get(to_label)
            
            if from_node_id is not None and to_node_id is not None:
                trans_conn.execute(metadata.tables['dag_edges'].insert().values(
                    study_id=study_id,
                    from_node_id=from_node_id,
                    to_node_id=to_node_id,
                    relation_type='llm',
                    confidence=1.0
                ))
            else:
                print(f"Warning: Could not find node IDs for edge {from_label} -> {to_label}")
    
    # Export DAGitty format
    dagitty = "dag {\n"
    
    # Add nodes (only labels, no duplicates)
    for node_label in dag.nodes:
        safe_node = node_label.replace('"', "'")
        dagitty += f'"{safe_node}"\n'
    
    # Add edges
    for from_label, to_label in dag.edges:
        safe_from = from_label.replace('"', "'")
        safe_to = to_label.replace('"', "'")
        dagitty += f'"{safe_from}" -> "{safe_to}"\n'
    
    dagitty += "}"
    
    # Write to file
    os.makedirs(output_dir, exist_ok=True)
    dagitty_txt_path = os.path.join(output_dir, f"dagitty_study_{study_id}.txt")
    with open(dagitty_txt_path, "w") as f:
        f.write(dagitty + "\n")

    json_path = os.path.join(output_dir, f"dagitty_study_{study_id}.json")
    print(f"json_path {json_path}")
    with open(json_path, "w") as f:
        json.dump(dag_json, f, indent=2)
    
    return dagitty

In [232]:
# Create R Markdown file to render DAG
def generate_rmd_dag(output_dir, study_id):
    rmd_path = os.path.join(output_dir, f"dagitty_study_{study_id}.Rmd")
    json_path = os.path.join(output_dir, f"dagitty_study_{study_id}.json")
    txt_path = os.path.join(output_dir, f"dagitty_study_{study_id}.txt")
    rmd_content = f"""---
title: "DAGitty DAG for Study {study_id}"
output: html_document
---
```{{r setup, include=FALSE}}
library(dagitty)
library(ggdag)
library(ggplot2)
library(jsonlite)
```

```{{r load-and-plot-dag}}
# Read the structure JSON to get node labels
structure_data <- fromJSON('{json_path}') 

# Create a mapping from ID to label
node_mapping <- setNames(structure_data$nodes$label, structure_data$nodes$id)

dag_text <- readLines("dagitty_study_2.txt")
dag <- dagitty(paste(dag_text, collapse = "\n"))

tidy_dag <- tidy_dagitty(dag)

# Plot using direct labels
ggplot(tidy_dag, aes(x = x, y = y, xend = xend, yend = yend)) +
  geom_dag_edges() +
  geom_dag_node(color = "lightblue", size = 8) +

  # Use geom_label instead of geom_dag_text for styled labels
  geom_label(
    aes(label = name, y = y - 0.2),  
    label.size = 0.2,                
    label.r = unit(0.1, "lines"),    
    color = "black",                 
    fill = "white",                  
    label.padding = unit(0.15, "lines"),
    label.color = "pink",          
    size = 3                         
  ) +

  theme_dag() +
  theme(legend.position = "none")
```"""
    with open(rmd_path, "w") as rmd_file:
        rmd_file.write(rmd_content)

    import subprocess
    subprocess.run(["Rscript", "-e", f"rmarkdown::render('{rmd_path}')"])
    html_path = os.path.join(output_dir, f"dagitty_study_{study_id}.html")
    if os.path.exists(html_path):
        import webbrowser
        webbrowser.open(f"file://{os.path.abspath(html_path)}")

    return rmd_path

## LLM

In [200]:
def standardize_variables_llm(study_id):
    with engine.connect() as conn:
        objectives = [row._mapping['content'] for row in conn.execute(
            metadata.tables['objectives'].select().where(
                metadata.tables['objectives'].c.study_id == study_id))]

        eligibility = [row._mapping['criteria'] for row in conn.execute(
            metadata.tables['eligibility_criteria'].select().where(
                metadata.tables['eligibility_criteria'].c.study_id == study_id))]

        outcomes = []
        for row in conn.execute(metadata.tables['outcomes'].select().where(
            metadata.tables['outcomes'].c.study_id == study_id)):
            data = row._mapping.get('data')
            if isinstance(data, list) and data and isinstance(data[0], dict) and 'value' in data[0]:
                outcomes.append(data[0]['value'])
            else:
                outcomes.append(str(data))

        # Step 1: Summarize terms
        object_summarization = f"""
OBJECTIVES:
{json.dumps(objectives, indent=2)}

ELIGIBILITY CRITERIA:
{json.dumps(eligibility, indent=2)}

OUTCOMES:
{json.dumps(outcomes, indent=2)}
"""
        
        USER_PROMPT_SUMMARY = USER_PROMPT_SUMMARY_TEMPLATE.replace("{{OBJECT_SUMMARY}}", object_summarization)
        summary_response = client.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT_SUMMARY},
                {"role": "user", "content": USER_PROMPT_SUMMARY}
            ],
            temperature=0.3
        )
        
        scientific_variables = summary_response.choices[0].message.content

        summarized_variables = format_scientific_variables(scientific_variables)

        return summarized_variables

In [201]:
def propose_dag_llm(variable_summary, study_id, output_dir):
        DAG_PROMPT = NODE_PROMPT_TEMPLATE.replace("{{SCIENTIFIC_VARIABLES}}", variable_summary)

        assistant = client.beta.assistants.create(
            name="RWE DAG Generator",
            instructions=SYSTEM_PROMPT_DAG,
            model="gpt-4",
            tools=[{"type": "function", "function": {"name": "generate_dag", "parameters": dag_node_schema_json}}]
        )
        
        thread = client.beta.threads.create()
        
        client.beta.threads.messages.create(
            thread_id=thread.id,
            role="user",
            content=DAG_PROMPT
        )
        
        run = client.beta.threads.runs.create_and_poll(thread_id=thread.id, assistant_id=assistant.id)
        
        messages = client.beta.threads.messages.list(thread_id=thread.id)
        response_content = messages.data[0].content[0].text.value
        
        try:
            structured_dag_proposal = json.loads(response_content)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON: {e}\n\nContent:\n{response_content}")
        
        try:
            jsonschema.validate(instance=structured_dag_proposal, schema=dag_node_schema_json)
        except jsonschema.ValidationError as e:
            raise ValueError(f"Response did not match schema: {e.message}")
        
        os.makedirs(output_dir, exist_ok=True)
        structured_dag_path = os.path.join(output_dir, f"structured_dag_study_{study_id}.json")
        with open(structured_dag_path, "w") as f:
            json.dump(structured_dag_proposal, f, indent=2)
        
        return structured_dag_proposal

# Routines

In [197]:
study_id = 2
output_dir = os.path.abspath(os.path.join(os.getcwd(), "..", "output"))
print(f"\n ... \n ... running standardize_variables_llm")
variable_summary = standardize_variables_llm(study_id)
print(f" ... \n\n  1 | VARIABLE SUMMARY\n\n {variable_summary}")
print(f"\n ... \n ... running propose_dag_llm")
structured_dag_proposal = propose_dag_llm(variable_summary, study_id, output_dir) 
print(f" ... \n\n  2 | STRUCTURED DAG PROPOSAL \n\n {structured_dag_proposal}")
print(f"\n ... \n ... running generate_dagitty_file")
dagitty = generate_dagitty_file(structured_dag_proposal, study_id, output_dir)
print(f" ... \n ... running generate_rmd_dag")
print(f"\n ... \n\n  3 | DAGITTY \n\n {dagitty}")
generate_rmd_dag(dagitty_file_directory, study_id)


 ... 
 ... running standardize_variables_llm
 ... 

  1 | VARIABLE SUMMARY

 OBJECTIVES:
* SBHF Prevalence in T2DM

ELIGIBILITY:
* T2DM Patients without CVD

OUTCOMES:
* Echocardiographic Parameters

 ... 
 ... running propose_dag_llm
 ... 

  2 | STRUCTURED DAG PROPOSAL 

 {'nodes': [{'id': '1', 'label': 'T2DM'}, {'id': '2', 'label': 'SBHF Prevalence in T2DM'}, {'id': '3', 'label': 'Echocardiographic Parameters'}, {'id': '4', 'label': 'Age'}, {'id': '5', 'label': 'Sex'}, {'id': '6', 'label': 'BMI'}, {'id': '7', 'label': 'Hypertension'}, {'id': '8', 'label': 'Smoking Status'}, {'id': '9', 'label': 'Cholesterol Levels'}], 'edges': [{'from': '1', 'to': '2'}, {'from': '4', 'to': '1'}, {'from': '5', 'to': '1'}, {'from': '6', 'to': '1'}, {'from': '7', 'to': '1'}, {'from': '8', 'to': '1'}, {'from': '9', 'to': '1'}, {'from': '4', 'to': '3'}, {'from': '5', 'to': '3'}, {'from': '6', 'to': '3'}, {'from': '7', 'to': '3'}, {'from': '2', 'to': '3'}]}

 ... 
 ... running generate_dagitty_file
 ... 



processing file: dagitty_study_2.Rmd


1/4                    
2/4 [setup]            
3/4                    
4/4 [load-and-plot-dag]


Error in `open.connection()`:
! cannot open the connection
Backtrace:
    ▆
 1. └─jsonlite::fromJSON("/Users/aimeeharrison/causal-commons/dag_review/output")
 2.   └─jsonlite:::parse_and_simplify(...)
 3.     └─jsonlite:::parseJSON(txt, bigint_as_char)
 4.       └─jsonlite:::parse_con(txt, bigint_as_char)
 5.         ├─base::open(con, "rb")
 6.         └─base::open.connection(con, "rb")

Quitting from dagitty_study_2.Rmd:12-37 [load-and-plot-dag]
Execution halted


NameError: name 'output_dir' is not defined

In [226]:
output_dir = os.path.abspath(os.path.join(os.getcwd(), "..", "output"))
print(f"\n ... \n ... running generate_dagitty_file")
dagitty = generate_dagitty_file(structured_dag_proposal, study_id, output_dir)
print(f"\n ... \n\n  3 | DAGITTY \n\n {dagitty}")


 ... 
 ... running generate_dagitty_file
json_path /Users/aimeeharrison/causal-commons/dag_review/output/dagitty_study_2.json

 ... 

  3 | DAGITTY 

 dag {
"T2DM"
"SBHF Prevalence in T2DM"
"Echocardiographic Parameters"
"Age"
"Sex"
"BMI"
"Hypertension"
"Smoking Status"
"Cholesterol Levels"
"T2DM" -> "SBHF Prevalence in T2DM"
"SBHF Prevalence in T2DM" -> "Echocardiographic Parameters"
"Age" -> "T2DM"
"Age" -> "Echocardiographic Parameters"
"Sex" -> "T2DM"
"Sex" -> "Echocardiographic Parameters"
"BMI" -> "T2DM"
"BMI" -> "Echocardiographic Parameters"
"Hypertension" -> "T2DM"
"Hypertension" -> "Echocardiographic Parameters"
"Smoking Status" -> "T2DM"
"Cholesterol Levels" -> "T2DM"
}


In [233]:

print(f" ... \n ... running generate_rmd_dag")
generate_rmd_dag(output_dir, study_id)

 ... 
 ... running generate_rmd_dag




processing file: dagitty_study_2.Rmd


1/4                    
2/4 [setup]            
3/4                    
4/4 [load-and-plot-dag]


output file: dagitty_study_2.knit.md



/opt/homebrew/bin/pandoc +RTS -K512m -RTS dagitty_study_2.knit.md --to html4 --from markdown+autolink_bare_uris+tex_math_single_backslash --output dagitty_study_2.html --lua-filter /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/library/rmarkdown/rmarkdown/lua/pagebreak.lua --lua-filter /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/library/rmarkdown/rmarkdown/lua/latex-div.lua --lua-filter /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/library/rmarkdown/rmarkdown/lua/table-classes.lua --embed-resources --standalone --variable bs3=TRUE --section-divs --template /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/library/rmarkdown/rmd/h/default.html --no-highlight --variable highlightjs=1 --variable theme=bootstrap --mathjax --variable 'mathjax-url=https://mathjax.rstudio.com/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML' --include-in-header /var/folders/g5/fk0fxc8n0b56nqrtwj8ywdlm0000gn/T//RtmpmZPhgc/rmarkdown-str172676ecf2dff.html 



Output created: dagitty_study_2.html


'/Users/aimeeharrison/causal-commons/dag_review/output/dagitty_study_2.Rmd'

In [None]:
jsonlite::fromJSON("/Users/aimeeharrison/causal-commons/dag_review/output/dagitty_study_2.json")