In [1]:
import sys
import os
from dataclasses import dataclass
from dotenv import load_dotenv
from openai import OpenAI
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading

sys.path.insert(0, "../src")

from extract_valid_studies import main
from models import ValidStudy, PrimaryOutcome

load_dotenv()

True

In [2]:
MODEL = "gpt-5-mini"
RAW_STUDIES_DIR = "../data/raw_studies"
STUDIES_PROCESSED_OUTPUT_DIR= "../data/studies_processed"
CARDS_OUTPUT_DIR = "../data/cards"
MAX_WORKERS = 100
NUM_CARDS = 500

In [3]:
from pydantic import BaseModel, Field, model_validator

class LLMResponse(BaseModel):
    question: str = Field(..., description="The final question in the specified format with the appropriate placeholders filled in verbatim with the other fields.")
    intervention_fragment: str = Field(..., description="The main intervention being tested, in layperson's terms. This should be directly pluggable into the question template.")
    intervention_group_fragment: str = Field(..., description="The purpose of the clinical trial, in layperson's terms. This should be directly pluggable into the question template.")
    outcome_fragment: str = Field(..., description="The primary outcome being measured, in layperson's terms. This should be directly pluggable into the question template.")
    comparator_group_fragment: str = Field(..., description="The comparator or control condition, in layperson's terms. This should be directly pluggable into the question template.")
    timeframe_fragment: str = Field(..., description="The timeframe of the outcome measurement, in layperson's terms. This should be directly pluggable into the question template.")
    intervention_group_description: str = Field(..., description="A brief description of the intervention group.")
    comparator_group_description: str = Field(..., description="A brief description of the comparator/control group.")

    @model_validator(mode="after")
    def ensure_question_valid(self):
        expected_question = f"Did {self.intervention_fragment} improve {self.outcome_fragment} in {self.intervention_group_fragment} compared to {self.comparator_group_fragment} after {self.timeframe_fragment}?"
        if self.question != expected_question:
            raise ValueError(f"Question does not match the expected format. Got: {self.question}, Expected: {expected_question}")
        return self


@dataclass
class ProcessingInformation:
    study: ValidStudy
    outcome_id: str
    llm_response: LLMResponse

    def to_dict(self):
        def recursive_asdict(obj):
            if isinstance(obj, list):
                return [recursive_asdict(item) for item in obj]
            elif isinstance(obj, dict):
                return {key: recursive_asdict(value) for key, value in obj.items()}
            elif hasattr(obj, "__dict__"):
                return {key: recursive_asdict(value) for key, value in obj.__dict__.items()}
            else:
                return obj
        return {
            "study": recursive_asdict(self.study),
            "outcome_id": self.outcome_id,
            "llm_response": self.llm_response.model_dump(),
        }

In [4]:
@dataclass
class UsageTracker:
    total_api_calls: int = 0
    total_input_tokens: int = 0
    total_output_tokens: int = 0


    def cost(self) -> float:
        c_i = {
            "gpt-5": 1.25,
            "gpt-5-mini": 0.25,
        }
        c_o = {
            "gpt-5": 10,
            "gpt-5-mini": 2,
        }
        input_cost = (self.total_input_tokens / 1_000_000) * c_i[MODEL]
        output_cost = (self.total_output_tokens / 1_000_000) * c_o[MODEL]

        return input_cost + output_cost

    def summary(self):
        print(f"Total API calls: {self.total_api_calls}")
        print(f"Total input tokens: {self.total_input_tokens}")
        print(f"Total output tokens: {self.total_output_tokens}")
        print(f"Estimated cost: ${self.cost():.4f}")

tracker = UsageTracker()

In [5]:
def mk_prompt(v: ValidStudy, o: PrimaryOutcome) -> str:
    groups_info = []
    for g in o.groups:
        groups_info.append("\n".join([
            f"Group title: {g.title}",
            f"Description: {g.description}",
            f"Interventions: {', '.join([f"{i.name}: {i.description}" for i in g.interventions]) if g.interventions else '(uncertain)'}",
        ]))


    return f"""
We are creating flashcard summaries for a game where laypeople predict the outcomes of clinical trials (behavioral interventive).

The final question for the flashcard must be of format:
"Did [intervention_fragment] improve [outcome_fragment] in [intervention_group_fragment] compared to [comparator_group_fragment] after [timeframe_fragment]?"

Keep the questions as short as possible. Use acronyms if needed, as long as they are understandable. If there is a name given to the intervention (e.g. "The Jolly Flower Telephone Protocol for Healthy Ageing"), instead of using the name, simply describe the intervention in layperson's terms (e.g. "calling other elderly people").

Remember to keep the question short. Do not include examples. Do not include any additional text or explanation.

Recall that the question should be reconstructable by verbatim plugging in the other fields.


Please ensure that the answers are concise and easily understandable by someone without a medical background. Avoid technical jargon and use simple language. Where something is technical, give a lay description and then in parentheses the technical term. 

Please create a question based on the following clinical trial information:
• Trial Title: {v.title}
• Trial Description: {v.description}
• Measure: {o.title}
• Measure Description: {o.description}
• Timeframe: {o.timeframe}

The groups are as follows (the first is the intervention group):
{'\n\n'.join(groups_info)}


If there is missing intervention or comparator information, please either match to these interventions (if you can tell from the group title/description), or say "Control" if it is a no-treatment or standard care control group, or "Unknown" if you cannot tell.
    """


In [6]:
valid_studies = main(RAW_STUDIES_DIR, NUM_CARDS)

 85%|████████▍ | 29359/34562 [00:07<00:01, 4084.48it/s]

Loaded 5000 raw studies with results
1736 out of 5000 (34.72%) studies have p-values reported in primary outcomes analyses.





In [7]:
def process_single_outcome(study: ValidStudy, o: PrimaryOutcome, tracker_lock: threading.Lock) -> ProcessingInformation:
    """Process a single study and return the result or raise an exception."""
    MAX_TRIES = 3
    success = False
    n_tries = 0

    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    while not success and n_tries < MAX_TRIES:
        n_tries += 1

        try:
            response = client.responses.parse(
                model=MODEL,
                input=[
                    {
                        "role": "system",
                        "content": "You are an expert clinical trial analyst."
                    },
                    {"role": "user", "content": mk_prompt(study, o)},
                ],
                text_format=LLMResponse,
            )

            # Track token usage (thread-safe)
            with tracker_lock:
                tracker.total_api_calls += 1
                tracker.total_input_tokens += response.usage.input_tokens
                tracker.total_output_tokens += response.usage.output_tokens

            # Parse response
            llm_response = response.output_parsed

            return ProcessingInformation(
                study=study,
                outcome_id=o.id,
                llm_response=llm_response
            )
        except Exception as e:
            # Re-raise with study info for better error handling
            if n_tries >= MAX_TRIES - 1:
                raise Exception(f"Failed processing {study.nct_id}: {str(e)}")


def process_studies_multithreaded(studies: list[ValidStudy], max_workers: int = MAX_WORKERS) -> tuple[list[ProcessingInformation], list[tuple[str, str]]]:
    """Process studies using multiple threads."""
    results = []
    failures = []
    tracker_lock = threading.Lock()
    
    # Use ThreadPoolExecutor for concurrent API calls
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks
        future_to_study = {
            executor.submit(process_single_outcome, study, study.primary_outcomes[0], tracker_lock): study
            for study in studies
        }
        
        # Process completed tasks with progress bar
        with tqdm(total=len(studies), desc="Processing studies") as pbar:
            for future in as_completed(future_to_study):
                study = future_to_study[future]
                try:
                    result = future.result()
                    results.append(result)
                except Exception as e:
                    failures.append((study.nct_id, str(e)))
                    pbar.set_postfix_str(f"Failures: {len(failures)}")
                
                pbar.update(1)
    
    return results, failures

In [8]:
# Process studies using multithreading (adjust max_workers based on API rate limits)
results, failures = process_studies_multithreaded(valid_studies[:NUM_CARDS], max_workers=MAX_WORKERS)

tracker.summary()

print(f"Successfully processed: {len(results)} studies")
print(f"Failed: {len(failures)} studies")
if failures:
    print("Failures:")
    for nct_id, error in failures[:5]:  # Show first 5 failures
        print(f"  {nct_id}: {error}")

Processing studies: 100%|██████████| 500/500 [03:15<00:00,  2.55it/s, Failures: 52]

Total API calls: 448
Total input tokens: 641911
Total output tokens: 588290
Estimated cost: $1.3371
Successfully processed: 448 studies
Failed: 52 studies
Failures:
  NCT01744977: Failed processing NCT01744977: 1 validation error for LLMResponse
  Value error, Question does not match the expected format. Got: Did blister-packaging plus adherence education improve statin adherence in people with high LDL or poor adherence compared to the education-only group after 12 months?, Expected: Did blister-packaging plus adherence education improve statin adherence (measured by pill refill) in people with high LDL or poor statin adherence compared to education-only group after 12 months? [type=value_error, input_value={'question': 'Did blister... adherence packaging).'}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.11/v/value_error
  NCT03270943: Failed processing NCT03270943: 1 validation error for LLMResponse
  Value error, Question does not match the expect




In [9]:
import ujson as json
for result in results:
    with open(os.path.join(STUDIES_PROCESSED_OUTPUT_DIR, f"{result.study.nct_id}_{result.outcome_id}_llm_response.json"), "w") as f:
        json.dump(result.to_dict(), f, indent=4)

In [17]:
def get_num_participants(groups) -> int:
    n = 0
    for g in groups:
        num = g.num_participants
        if isinstance(num, int):
            n += num
        elif isinstance(num, str):
            new_num = ''
            for c in num:
                if c.isdigit():
                    new_num += c
            n += int(new_num)

    return n

def mk_card(r: ProcessingInformation) -> dict:
    """Create a card dictionary from the processing information."""
    o = next(o for o in r.study.primary_outcomes if o.id == r.outcome_id)
    p_value = f"{o.p_value.comparator}{o.p_value.value}"
    if p_value.startswith("="):
        p_value = p_value[1:]

    success = (o.p_value.value < 0.05 and o.p_value.comparator != ">") or p_value == "<0.05"

    return {
        "study": {
            "nct_id": r.study.nct_id,
            "title": r.study.title,
            "brief_description": r.study.brief_description,
        },
        "card_id": r.outcome_id,
        "front_details": {
            "question": r.llm_response.question,
            "intervention_fragment": r.llm_response.intervention_fragment,
            "intervention_group_fragment": r.llm_response.intervention_group_fragment,
            "outcome_fragment": r.llm_response.outcome_fragment,
            "comparator_group_fragment": r.llm_response.comparator_group_fragment,
            "timeframe_fragment": r.llm_response.timeframe_fragment,
        },
        "p_value": p_value,
        "num_participants": get_num_participants(o.groups),
        "success": success,
        "conditions": r.study.conditions,
        "keywords": r.study.keywords,
    }

cards = [mk_card(r) for r in results]
with open(os.path.join(CARDS_OUTPUT_DIR, "cards.json"), "w") as f:
    json.dump(cards, f, indent=4)

In [18]:
for c in cards:
    print(c["front_details"]["question"])
    print(c["p_value"])
    print(c["num_participants"])

Did 6-week diabetes self-management classes improve blood sugar control (HbA1c) in adults with T2DM compared to usual care (control) after 12 months?
0.771
376
Did using a web-based lung cancer screening decision aid improve decisional conflict (uncertainty about the screening decision) in Veterans who used the decision tool compared to Veterans given general prevention info (not about lung cancer) after 1 month?
0.18
124
Did practicing Tai Chi improve knee pain (WOMAC pain score) in people with knee osteoarthritis compared to wellness education and stretching after 12 weeks?
0.0005
40
Did patient-parent-coach support with problem-solving, if‑then action plans, and dose reminders improve percent of days with full medication taking (100% taking adherence) in teen kidney transplant recipients compared to attention-control coach visits (no adherence support) after 12 months?
0.006
138
Did ibuprofen (200 mg twice daily) plus a 6-week home walking and resistance-band exercise program improv

In [None]:
from pprint import pprint
for r in results:
    o = next(o for o in r.study.primary_outcomes if o.id == r.outcome_id)
    succ = o.p_value.value < 0.05 and o.p_value.comparator is not ">"
    print(f"Study {r.study.nct_id} {r.study.title}:")
    pprint(f"  Question: {r.llm_response.question}")
    print(f"  Answer: {'YES' if succ else 'NO'}")
    print('-'*60)
    
o

Study NCT02665468 eHealth Partnered Evaluation Initiative:
('  Question: Did support to encourage secure messaging (reminders, '
 'motivational messages, education) improve secure messaging use (sent at '
 'least one secure message) in Veterans given the supported-adoption '
 'intervention compared to Veterans receiving usual care (no extra support) '
 'after six months?')
  Answer: YES
------------------------------------------------------------
Study NCT01872338 Mindfulness-Based Cognitive Therapy for Suicide Prevention:
('  Question: Did mindfulness meditation plus suicide safety planning (MBCT-S) '
 'improve suicide events (attempts, preparatory acts, or suicidal thoughts '
 'needing emergency care) in high-risk Veterans (VA High Risk for Suicide '
 'List) compared to VA usual care (treatment as usual) after 12 months?')
  Answer: YES
------------------------------------------------------------
Study NCT01221090 Diabetes Self-Management Models to Reduce Health Disparities:
('  Ques

  succ = o.p_value.value < 0.05 and o.p_value.comparator is not ">"


PrimaryOutcome(nct_id='NCT03344796', id='NCT03344796_po_0', title='Acceptability of Wearing UV Sensor and Receiving Text Messages', description='Online system usability 6 item scale (Likert 7 items range from strongly disagree to strongly agree Minimum value 6, maximum value 42, higher score better outcome', population_description='Focus groups and structured interviews did not have data collected for acceptability of wearing the UV sensor and receiving text messages.', timeframe='cohort study 1 (arms 1 and 2) at end of 21 days, cohort study 2 (only one arm) at end of 28 days', groups=[Group(id='OG001', title='Cohort Study 1 Arm1', description='31 subjects\n\nFirst cohort study: It is expected that each of 60 melanoma survivors in arms 1+2 will wear the sensor and transmit data for 21 days in the warm weather months of June-Aug 2019. It will take about 2 months to enroll the subjects. Subjects will receive daily text messages in a sequence starting with behavioral facilitation, outcome