# DDXPlus Causal QA Dataset Builder

This notebook builds a causal reasoning dataset from DDXPlus medical diagnosis data.

## Pipeline Overview:
1. **Load Data**: Evidence definitions, conditions, patient records
2. **KB-based Diagnosis Model**: Simple model that scores each disease by its matching symptoms and antecedents (minus a severity penalty) and normalizes the scores into probabilities
3. **Causal Intervention**: Apply do(X) operations on a patient's evidence (a symptom or antecedent related to the diagnosed disease)
4. **Effect Computation**: Calculate more/less/no_effect based on probability changes
5. **QA Generation**: Create natural language questions with reasoning chains

## 1. Imports and Configuration

In [None]:
import json
import math
import random
import ast
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from collections import defaultdict

import pandas as pd

print("✓ Imports loaded successfully")

In [None]:
# ------------------------------------------------------------
# Configuration: input locations, output target, model weights
# ------------------------------------------------------------

# Folder holding the DDXPlus release files read by the loaders.
BASE_DIR = Path(r"e:\PHD\01\Dataset\DDXPlus\22687585")
EVIDENCE_FILE = BASE_DIR.joinpath("release_evidences.json")
CONDITION_FILE = BASE_DIR.joinpath("release_conditions.json")
# The patient table is stored without a ".csv" suffix.
PATIENT_FILE = BASE_DIR.joinpath("release_validate_patients")

# Destination of the generated dataset (one JSON object per line).
OUTPUT_FILE = Path(r"e:\PHD\01").joinpath("DDXPlus_CausalQA.jsonl")

# Upper bound on the number of QA items produced in one run.
MAX_QA = 10_000

# Weights used by the KB-based diagnosis model: reward per matching symptom,
# reward per matching antecedent, and penalty per unit of disease severity.
SYMPTOM_WEIGHT = 1.0
ANTECEDENT_WEIGHT = 0.5
SEVERITY_WEIGHT = 0.1

print(
    "✓ Configuration loaded",
    f"  - Data directory: {BASE_DIR}",
    f"  - Output file: {OUTPUT_FILE}",
    f"  - Max QA items: {MAX_QA}",
    sep="\n",
)

## 2. Load DDXPlus Data

In [None]:
def load_evidences() -> Dict[str, dict]:
    """Load evidence definitions (symptoms, vital signs, etc.) keyed by name.

    Reads ``EVIDENCE_FILE`` as UTF-8 JSON. Two top-level layouts are accepted:

    - a JSON array of evidence objects, or
    - a JSON object whose *values* are the evidence objects.

    Returns:
        Mapping from each definition's ``"name"`` field to the definition.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        KeyError: if a definition has no ``"name"`` field.
    """
    with open(EVIDENCE_FILE, "r", encoding="utf-8") as f:
        payload = json.load(f)

    # Iterating a dict yields its string keys, on which ``rec["name"]`` would
    # raise TypeError, so take the values when the file is an object.
    # NOTE(review): the public DDXPlus release appears to use the object
    # layout (keyed by evidence code) - confirm against the local copy.
    records = payload.values() if isinstance(payload, dict) else payload
    return {rec["name"]: rec for rec in records}

def load_conditions() -> Dict[str, dict]:
    """Load disease/condition definitions keyed by condition name.

    Reads ``CONDITION_FILE`` as UTF-8 JSON. Two top-level layouts are
    accepted:

    - a JSON array of condition objects, or
    - a JSON object whose *values* are the condition objects.

    Returns:
        Mapping from each definition's ``"condition_name"`` field to the
        definition.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        KeyError: if a definition has no ``"condition_name"`` field.
    """
    with open(CONDITION_FILE, "r", encoding="utf-8") as f:
        payload = json.load(f)

    # Iterating a dict yields its string keys, on which
    # ``rec["condition_name"]`` would raise TypeError, so take the values
    # when the file is an object.
    # NOTE(review): the public DDXPlus release appears to use the object
    # layout (keyed by condition name) - confirm against the local copy.
    records = payload.values() if isinstance(payload, dict) else payload
    return {rec["condition_name"]: rec for rec in records}

def load_patients() -> pd.DataFrame:
    """Read the patient table at ``PATIENT_FILE`` into a DataFrame.

    The file is parsed by ``pandas.read_csv`` with its default options.
    """
    return pd.read_csv(PATIENT_FILE)

# Read the three data sources once; the rest of the notebook uses these globals.
EVIDENCE_MAP = load_evidences()
CONDITION_MAP = load_conditions()
PATIENTS = load_patients()

# Summary of what was loaded, followed by the first five keys of each map.
print(
    "✓ Data loaded successfully",
    f"  - Evidences: {len(EVIDENCE_MAP)}",
    f"  - Conditions: {len(CONDITION_MAP)}",
    f"  - Patients: {len(PATIENTS)}",
    sep="\n",
)
print(f"\nSample evidence types: {list(EVIDENCE_MAP)[:5]}")
print(f"Sample conditions: {list(CONDITION_MAP)[:5]}")

## 3. Evidence Parsing Utilities

In [None]:
def parse_evidences(evid_str) -> List[str]:
    """
    Parse a patient EVIDENCES field into a list of evidence tokens.

    Accepted inputs, tried in this order:
    - an actual ``list``: every element is converted with ``str``
    - a missing value (``pd.isna``) or blank string: returns ``[]``
    - a Python list literal such as ``"['E_1', 'E_2_@_V_3']"``
    - a bracketed list that is not a valid literal, e.g. ``"[E_1, E_2]"``:
      the enclosing brackets are dropped and parsing continues below
    - semicolon-separated tokens
    - comma-separated tokens (surrounding spaces and quotes are stripped)
    - a single token

    Empty tokens are never returned by the separator-based branches.
    """
    if isinstance(evid_str, list):
        return [str(x) for x in evid_str]

    if pd.isna(evid_str):
        return []

    s = str(evid_str).strip()
    if not s:
        return []

    if s.startswith("[") and s.endswith("]"):
        try:
            parsed = ast.literal_eval(s)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # Not a valid literal (e.g. unquoted tokens). Drop the brackets so
            # the separator-based fallbacks do not emit "[E_1" / "E_2]".
            s = s[1:-1].strip()
            if not s:
                return []
        else:
            return [str(x) for x in parsed]

    if ";" in s:
        return [t.strip() for t in s.split(";") if t.strip()]

    if "," in s:
        # Filter after stripping quotes so a quotes-only piece such as ''
        # does not become an empty token.
        stripped = (t.strip(" '\"") for t in s.split(","))
        return [t for t in stripped if t]

    # Single evidence code
    return [s]

def get_ev_name(token: str) -> str:
    """Return the evidence-name part of a token: 'E_130_@_V_86' -> 'E_130'.

    The name is everything before the first ``"_@_"`` separator; a token
    without the separator is returned whole. The token is converted with
    ``str`` first.
    """
    return str(token).split("_@_", 1)[0]

def get_ev_value(token: str) -> Optional[str]:
    """Return the value part of a token: 'E_130_@_V_86' -> 'V_86'.

    The value is the text between the first and (if any) second ``"_@_"``
    separator. Returns ``None`` when the token has no separator. The token
    is converted with ``str`` first.
    """
    pieces = str(token).split("_@_")
    if len(pieces) < 2:
        return None
    return pieces[1]

print("✓ Evidence parsing utilities defined")

## 4. Natural Language Context Builder

In [None]:
def decode_values_for_evidence(ev_name: str, tokens_for_ev: List[str]) -> List[str]:
    """
    Convert the tokens of one evidence into human-readable answer texts.

    - Binary (B): always returns ["Yes"].
    - Any other data type (C / M): for each token value, returns the "en"
      entry of ``value_meaning`` when one exists, otherwise the raw value.
      Tokens without a value and tokens carrying the evidence's default
      value are skipped. Duplicates are removed, keeping first-seen order.

    Args:
        ev_name: evidence name; must be a key of ``EVIDENCE_MAP``.
        tokens_for_ev: tokens that belong to ``ev_name``.

    Raises:
        KeyError: if ``ev_name`` is not in ``EVIDENCE_MAP`` or the definition
            has no ``"data_type"``.
    """
    ev_def = EVIDENCE_MAP[ev_name]
    if ev_def["data_type"] == "B":
        return ["Yes"]

    value_meaning = ev_def.get("value_meaning") or {}

    # Token values are always strings, while the definition may store the
    # default as a number (NOTE(review): e.g. 0 for numeric scales - confirm
    # against the evidence file). Compare as text so such defaults are skipped.
    default_val = ev_def.get("default_value")
    default_key = None if default_val is None else str(default_val)

    texts = []
    for tok in tokens_for_ev:
        v = get_ev_value(tok)
        if not v or v == default_key:
            continue
        meaning = value_meaning.get(v)
        texts.append(meaning.get("en", v) if meaning else v)

    # Remove duplicates while preserving order
    return list(dict.fromkeys(texts))

def build_context_from_evidences(evid_tokens: List[str]) -> str:
    """
    Render a patient's evidence tokens as newline-separated text.

    Tokens are grouped by evidence name in first-seen order. A group yields
    one line of the form "<question_en>: <answer>, <answer>, ..." when the
    evidence is in EVIDENCE_MAP, has a non-blank English question, and has at
    least one decoded answer; otherwise the group is omitted.
    """
    tokens_by_name = {}
    for token in evid_tokens:
        tokens_by_name.setdefault(get_ev_name(token), []).append(token)

    rendered = []
    for name, group in tokens_by_name.items():
        if name not in EVIDENCE_MAP:
            continue
        prompt = EVIDENCE_MAP[name].get("question_en", "").strip()
        if not prompt:
            continue
        decoded = decode_values_for_evidence(name, group)
        if decoded:
            rendered.append(prompt + ": " + ", ".join(decoded))

    return "\n".join(rendered)

print("✓ Context builder defined")

## 5. KB-Based Diagnosis Model

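A simple knowledge-based diagnostic model: each disease is scored by how many of its known symptoms and antecedents appear in the patient's evidence, minus a severity penalty, and the scores are softmax-normalized into probabilities.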

In [None]:
def kb_model_predict(evid_tokens: List[str]) -> Dict[str, float]:
    """
    Predict disease probabilities using knowledge-based evidence matching.

    Score per condition in ``CONDITION_MAP``:
    - +SYMPTOM_WEIGHT for each of its symptoms present in the patient
    - +ANTECEDENT_WEIGHT for each of its antecedents present in the patient
    - -SEVERITY_WEIGHT * severity (severity defaults to 3 when absent)

    Only evidence *names* are matched; token values are ignored.

    Args:
        evid_tokens: patient evidence tokens.

    Returns:
        {disease_name: probability} after softmax normalization, or an empty
        dict when ``CONDITION_MAP`` is empty.
    """
    ev_names_present = {get_ev_name(t) for t in evid_tokens}

    scores: Dict[str, float] = {}
    for cond_name, cond in CONDITION_MAP.items():
        sym_match = len(ev_names_present.intersection(cond.get("symptoms", {})))
        ant_match = len(ev_names_present.intersection(cond.get("antecedents", {})))
        scores[cond_name] = (
            SYMPTOM_WEIGHT * sym_match
            + ANTECEDENT_WEIGHT * ant_match
            - SEVERITY_WEIGHT * cond.get("severity", 3)
        )

    if not scores:
        # max() would raise on an empty sequence; nothing to normalize.
        return {}

    # Softmax with the maximum subtracted for numerical stability. The
    # top-scoring condition contributes exp(0) == 1, so the normalizer is
    # always >= 1 and needs no zero-division fallback.
    max_score = max(scores.values())
    exps = {k: math.exp(v - max_score) for k, v in scores.items()}
    z = sum(exps.values())
    return {k: v / z for k, v in exps.items()}

print("✓ KB-based diagnosis model defined")

## 6. Causal Intervention Engine

Implements do(X) interventions on a single piece of patient evidence (a symptom or an antecedent).

In [None]:
def _remove_evidence_value(tokens: List[str], ev_name: str, value: str) -> List[str]:
    """Return ``tokens`` without the ``ev_name`` token(s) carrying ``value``."""
    return [
        t for t in tokens
        if not (get_ev_name(t) == ev_name and get_ev_value(t) == value)
    ]


def intervene_evidences(evid_tokens: List[str], ev_name: str) -> Optional[List[str]]:
    """
    Apply causal intervention do(X) on evidence ev_name.

    Intervention strategies by data type:
    - Binary (B): toggle presence/absence of the bare ``ev_name`` token.
    - Categorical (C): replace all ``ev_name`` tokens with one token holding a
      randomly chosen value that is neither the current nor the default value.
    - Multi-choice (M): with probability 0.5 prefer removing one currently
      selected non-default value, otherwise prefer adding one unselected
      non-default value; if the preferred action is impossible the other one
      is tried.

    Uses the module-level ``random`` generator. The input list is never
    mutated.

    Args:
        evid_tokens: patient evidence tokens.
        ev_name: name of the evidence to intervene on.

    Returns:
        New evidence list after intervention, or None if ``ev_name`` is
        unknown, its data type is not B/C/M, or no alternative value exists.
    """
    ev_def = EVIDENCE_MAP.get(ev_name)
    if ev_def is None:
        return None

    dtype = ev_def["data_type"]

    if dtype == "B":
        if ev_name in evid_tokens:
            return [t for t in evid_tokens if t != ev_name]
        return [*evid_tokens, ev_name]

    if dtype not in ("C", "M"):
        # Unknown data type
        return None

    # Token values are strings, whereas the definition may list numeric
    # possible/default values (NOTE(review): confirm against the evidence
    # file). Normalize everything to text so comparisons are meaningful and
    # the current value is never re-selected as the "new" one.
    default_val = ev_def.get("default_value")
    default_key = None if default_val is None else str(default_val)
    selectable = [
        str(v) for v in (ev_def.get("possible-values") or [])
        if str(v) != default_key
    ]
    current = [
        v
        for v in (get_ev_value(t) for t in evid_tokens if get_ev_name(t) == ev_name)
        if v
    ]

    if dtype == "C":
        cur_val = current[0] if current else default_key
        candidates = [v for v in selectable if v != cur_val]
        if not candidates:
            return None
        new_val = random.choice(candidates)
        new_tokens = [t for t in evid_tokens if get_ev_name(t) != ev_name]
        new_tokens.append(f"{ev_name}_@_{new_val}")
        return new_tokens

    # Multi-choice. Sort the selected values: choosing from a set would make
    # the result depend on hash ordering even under a fixed random seed.
    selected = sorted({v for v in current if v != default_key})
    addable = [v for v in selectable if v not in selected]

    prefer_remove = random.random() < 0.5
    if selected and (prefer_remove or not addable):
        return _remove_evidence_value(evid_tokens, ev_name, random.choice(selected))
    if addable:
        return [*evid_tokens, f"{ev_name}_@_{random.choice(addable)}"]
    return None

print("✓ Intervention engine defined")

## 7. Reasoning Chain Builder & Effect Computation

In [None]:
def build_reasoning_chain(disease_name: str, evid_tokens: List[str]) -> str:
    """
    Describe which symptoms of a disease are present in the patient.

    Returns "" when the disease is not in CONDITION_MAP, a one-line
    "not strongly supported" message when none of the disease's symptoms is
    among the patient's evidence names, and otherwise a header line followed
    by one "- <question> → <answers>" line per matching symptom that is
    defined in EVIDENCE_MAP and has at least one decoded answer.
    """
    condition = CONDITION_MAP.get(disease_name)
    if condition is None:
        return ""

    known_symptoms = set(condition.get("symptoms", {}).keys())

    # Group the patient's tokens by symptom name, keeping first-seen order.
    tokens_by_symptom = {}
    for token in evid_tokens:
        name = get_ev_name(token)
        if name in known_symptoms:
            tokens_by_symptom.setdefault(name, []).append(token)

    if not tokens_by_symptom:
        return f"{disease_name} is not strongly supported by currently observed symptoms."

    chain = [f"{disease_name} can cause the following symptoms observed in the patient:"]
    for name, group in tokens_by_symptom.items():
        definition = EVIDENCE_MAP.get(name)
        if not definition:
            continue
        prompt = definition.get("question_en", "").strip()
        answers = decode_values_for_evidence(name, group)
        if answers:
            chain.append("- " + prompt + " → " + ", ".join(answers))

    return "\n".join(chain)

def compute_effect(
    before_probs: Dict[str, float],
    after_probs: Dict[str, float],
    disease: str,
    eps: float = 1e-6,
) -> str:
    """
    Classify how an intervention changed a disease's probability.

    Looks up ``disease`` in both distributions (missing entries count as 0.0)
    and returns "more" if the probability rose by more than ``eps``, "less"
    if it fell by more than ``eps``, and "no_effect" otherwise.
    """
    prior = before_probs.get(disease, 0.0)
    posterior = after_probs.get(disease, 0.0)

    if posterior > prior + eps:
        return "more"
    if posterior < prior - eps:
        return "less"
    return "no_effect"

print("✓ Reasoning chain builder and effect computation defined")

## 8. Main QA Dataset Builder

In [None]:
def _intervention_question(ev_name: str) -> str:
    """Return the stripped English question of ``ev_name``; "" if undefined or blank."""
    ev_def = EVIDENCE_MAP.get(ev_name)
    if not ev_def:
        return ""
    return (ev_def.get("question_en") or "").strip()


def build_causal_qa_dataset(max_qa: int = 10000) -> List[dict]:
    """
    Build causal QA dataset with real do(X) interventions.

    For each row of ``PATIENTS`` (until ``max_qa`` items exist):
    1. Parse the EVIDENCES field.
    2. Take the ground-truth disease Y from PATHOLOGY.
    3. Pick at random an evidence X that is present in the patient, listed
       among Y's symptoms or antecedents, and has a usable English question.
    4. Apply the do(X) intervention.
    5. Compute P(Y|evidences) and P(Y|do(X)) with the KB model.
    6. Label the effect: more / less / no_effect.
    7. Emit a QA item with context, question and reasoning chain.

    A patient is skipped (and counted) when it has no evidences, an unknown
    pathology, no eligible X, or the intervention is not possible.

    Args:
        max_qa: maximum number of QA items to generate.

    Returns:
        List of QA item dicts. Also prints a one-line summary.
    """
    qa_items: List[dict] = []
    skipped = 0

    for idx, row in PATIENTS.iterrows():
        if len(qa_items) >= max_qa:
            break

        evid_tokens = parse_evidences(row.get("EVIDENCES", ""))
        pathology = row.get("PATHOLOGY")
        if not evid_tokens or pathology not in CONDITION_MAP:
            skipped += 1
            continue

        cond = CONDITION_MAP[pathology]
        ev_names_present = {get_ev_name(t) for t in evid_tokens}
        related_evs = set(cond.get("symptoms", {})) | set(cond.get("antecedents", {}))

        # Keep only targets that can actually be asked about. This avoids a
        # KeyError for evidences missing from EVIDENCE_MAP and avoids throwing
        # the patient away because the random pick had no question text.
        # Sorted so the random choice does not depend on set ordering.
        candidate_X = sorted(
            name for name in ev_names_present & related_evs
            if _intervention_question(name)
        )
        if not candidate_X:
            skipped += 1
            continue

        X_name = random.choice(candidate_X)
        q_text = _intervention_question(X_name)

        # Intervene first: if it is impossible there is no point in running
        # the model. The model itself draws no random numbers, so this
        # ordering does not affect the random stream.
        after_evid_tokens = intervene_evidences(evid_tokens, X_name)
        if after_evid_tokens is None:
            skipped += 1
            continue

        before_probs = kb_model_predict(evid_tokens)       # P(Y|evidences)
        after_probs = kb_model_predict(after_evid_tokens)  # P(Y|do(X))
        label = compute_effect(before_probs, after_probs, pathology)

        # NOTE(review): the question says the answer "changes" without
        # stating the new value, while the label depends on which change was
        # sampled - consider exposing the post-intervention value.
        question = (
            f"Suppose the answer to the following question about the patient "
            f"changes due to an intervention:\n"
            f"\"{q_text}\"\n\n"
            f"How will this affect the likelihood of {pathology}?"
        )

        qa_items.append({
            "id": f"ddxplus_causalqa_{idx}_{X_name}",
            "context": build_context_from_evidences(evid_tokens),
            "question": question,
            "choices": ["more", "less", "no_effect"],
            "label": label,
            "X_evidence": X_name,
            "Y_disease": pathology,
            "before_prob": round(before_probs.get(pathology, 0.0), 4),
            "after_prob": round(after_probs.get(pathology, 0.0), 4),
            "reasoning_chain": build_reasoning_chain(pathology, evid_tokens),
        })

    print(f"✓ Generated {len(qa_items)} QA items (skipped {skipped} patients)")
    return qa_items

print("✓ QA dataset builder defined")

## 9. Generate Dataset and Save

In [None]:
# Build the dataset, capped at MAX_QA items.
qa_items = build_causal_qa_dataset(max_qa=MAX_QA)

# Write JSON Lines: one JSON document per line, UTF-8, non-ASCII kept as-is.
# The output directory is created if it does not exist yet.
OUTPUT_FILE.parent.mkdir(parents=True, exist_ok=True)
with OUTPUT_FILE.open("w", encoding="utf-8") as f:
    f.writelines(
        json.dumps(record, ensure_ascii=False) + "\n" for record in qa_items
    )

print(
    "\n" + "=" * 80,
    f"✓ DDXPlus Causal QA dataset saved to: {OUTPUT_FILE}",
    f"  Total items: {len(qa_items)}",
    "=" * 80,
    sep="\n",
)

## 10. Inspect Sample QA Items

In [None]:
if qa_items:
    # Show the first generated item; long text fields are truncated for display.
    sample = qa_items[0]
    print("Sample QA Item #1:", "=" * 80, f"ID: {sample['id']}", sep="\n")
    print("\nContext (first 200 chars):", sample['context'][:200] + "...", sep="\n")
    print("\nQuestion:", sample['question'], sep="\n")
    print(f"\nChoices: {sample['choices']}")
    print(f"Label: {sample['label']}")
    print(f"\nIntervention: {sample['X_evidence']}")
    print(f"Target Disease: {sample['Y_disease']}")
    print(f"Prob Before: {sample['before_prob']:.4f}")
    print(f"Prob After: {sample['after_prob']:.4f}")
    print(
        "\nReasoning Chain (first 300 chars):",
        sample['reasoning_chain'][:300] + "...",
        "=" * 80,
        sep="\n",
    )

    # Frequency of each label across the dataset, most common first.
    from collections import Counter
    label_counts = Counter(item['label'] for item in qa_items)
    print("\nLabel Distribution:")
    for label, count in label_counts.most_common():
        print(f"  {label}: {count} ({count/len(qa_items)*100:.1f}%)")