# Train Causal DAG — ARGUS Phase 2

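This notebook discovers and saves the causal DAG for the Beas–Brahmaputra
river basins from synthetic hydro-meteorological data. It uses a
correlation-based approximation of the PC algorithm (constraint-based causal
discovery), with edge directions fixed by a domain-knowledge causal ordering.
It then validates the graph with interventional queries.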

**Owner:** Rogesh · **Service:** `causal_engine`

In [None]:
import numpy as np
import pandas as pd
import json
from pathlib import Path
from datetime import datetime, timedelta

# Synthetic hydro-met data for DAG discovery
np.random.seed(42)
N = 2000  # samples

# Ground-truth structural causal model encoded by the generator below
# (every parent -> child dependency, read off the assignments):
#   rainfall_upper                                   -> rainfall_lower
#   rainfall_upper                                   -> soil_moisture
#   rainfall_upper, snowmelt                         -> upstream_level
#   upstream_level, dam_release, tributary_flow,
#   soil_moisture                                    -> midstream_level
#   midstream_level, rainfall_lower                  -> downstream_level
#   downstream_level (through a sigmoid), soil_moisture -> flood_risk
# Exogenous roots: rainfall_upper, snowmelt, dam_release, tributary_flow.
# All mechanisms are linear with additive Gaussian noise except flood_risk
# (sigmoid + clip) and soil_moisture (clipped to [0, 1]). Use this list as the
# reference when judging the discovered DAG.

rainfall_upper = np.random.exponential(10, N)
rainfall_lower = 0.6 * rainfall_upper + np.random.normal(0, 3, N)
snowmelt = np.random.uniform(0, 5, N)
# Rainfall is scaled by the sample maximum, so this mapping depends on the draw.
soil_moisture = 0.3 + 0.4 * (rainfall_upper / rainfall_upper.max()) + np.random.normal(0, 0.1, N)
soil_moisture = np.clip(soil_moisture, 0, 1)
dam_release = np.random.uniform(10, 100, N)
tributary_flow = np.random.uniform(5, 50, N)

upstream_level = 1.5 + 0.08 * rainfall_upper + 0.03 * snowmelt + np.random.normal(0, 0.3, N)
midstream_level = 0.5 * upstream_level + 0.02 * dam_release + 0.01 * tributary_flow + 0.3 * soil_moisture + np.random.normal(0, 0.2, N)
downstream_level = 0.7 * midstream_level + 0.02 * rainfall_lower + np.random.normal(0, 0.2, N)

# Flood risk: sigmoid of downstream level (centred at 3.5, slope 2), plus a
# small direct soil-moisture term, clipped back into [0, 1].
flood_risk = 1 / (1 + np.exp(-(downstream_level - 3.5) * 2))
flood_risk += 0.1 * soil_moisture
flood_risk = np.clip(flood_risk, 0, 1)

df = pd.DataFrame({
    'rainfall_upper': rainfall_upper,
    'rainfall_lower': rainfall_lower,
    'snowmelt': snowmelt,
    'soil_moisture': soil_moisture,
    'dam_release': dam_release,
    'tributary_flow': tributary_flow,
    'upstream_level': upstream_level,
    'midstream_level': midstream_level,
    'downstream_level': downstream_level,
    'flood_risk': flood_risk,
})

print(f'Dataset shape: {df.shape}')
df.describe().round(3)

## Correlation-based DAG Discovery (PC Algorithm Approximation)

In [None]:
# Discover directed edges by thresholding *partial* correlations.
#
# Directions come from a domain-knowledge causal ordering: a variable may only
# cause variables listed after it in `causal_order`. Given that ordering, `src`
# is kept as a parent of `tgt` when |partial correlation(src, tgt)|, conditioned
# on every *other* variable that precedes `tgt`, exceeds EDGE_THRESHOLD.
# Conditioning removes mediated and confounded associations that plain
# (marginal) correlations would turn into spurious edges -- e.g. rainfall_upper
# reaches downstream_level only through the intermediate river levels.
# This is a linear-Gaussian conditional-independence proxy, not a full PC run.

# Marginal (Pearson) correlations; the partial correlations are derived from these.
corr = df.corr()

# Threshold on |partial correlation| for edge existence.
# NOTE(review): weak mechanisms in the synthetic generator (small coefficients,
# e.g. 0.03 * snowmelt in upstream_level) may fall below this cut-off -- tune
# it against the known ground truth.
EDGE_THRESHOLD = 0.15

variables = list(df.columns)
edges = []

# Known causal ordering (domain knowledge)
causal_order = [
    'rainfall_upper', 'rainfall_lower', 'snowmelt',
    'soil_moisture', 'dam_release', 'tributary_flow',
    'upstream_level', 'midstream_level', 'downstream_level',
    'flood_risk',
]

# A column missing from (or extra in) the ordering would otherwise be silently
# excluded from discovery while still being emitted as a node.
_order_mismatch = set(variables) ^ set(causal_order)
if _order_mismatch:
    raise ValueError(f'causal_order and df columns differ: {sorted(_order_mismatch)}')


def _partial_corr_given_predecessors(corr_matrix, order):
    """Return partial correlations of ordered pairs given the target's other predecessors.

    Entry ``[src, tgt]`` (``src`` listed before ``tgt`` in *order*) is the
    partial correlation of ``src`` and ``tgt`` conditioned on every other
    variable preceding ``tgt``. It is read off the (pseudo-)inverse of the
    correlation sub-matrix over ``order[: j + 1]`` as ``-P_ij / sqrt(P_ii P_jj)``.
    Entries on or below the diagonal are 0.
    """
    n = len(order)
    pcorr = np.zeros((n, n))
    for j in range(1, n):
        subset = order[: j + 1]
        # pinv (SVD-based) does not raise on (near-)collinear predecessors,
        # unlike np.linalg.inv.
        precision = np.linalg.pinv(corr_matrix.loc[subset, subset].to_numpy())
        scale = np.sqrt(np.diag(precision))
        pcorr[:j, j] = -precision[:j, j] / (scale[:j] * scale[j])
    return pd.DataFrame(pcorr, index=order, columns=order)


pcorr = _partial_corr_given_predecessors(corr, causal_order)

# Discover edges: only from earlier to later in causal order.
# 'weight' is the absolute partial correlation (its sign is dropped).
for i, src in enumerate(causal_order):
    for j, tgt in enumerate(causal_order[i + 1:], start=i + 1):
        r = abs(pcorr.loc[src, tgt])
        if r > EDGE_THRESHOLD:
            edges.append({
                'source': src,
                'target': tgt,
                'weight': round(float(r), 3),
                # Heuristic: 2 h per step of distance in causal_order; this is
                # not a physical travel time.
                'lag_hours': round((j - i) * 2.0, 1),
                'mechanism': 'hydrological' if 'level' in tgt or 'level' in src else 'meteorological',
            })

print(f'Discovered {len(edges)} causal edges')
for e in edges:
    print(f"  {e['source']:20s} → {e['target']:20s}  w={e['weight']:.3f}  lag={e['lag_hours']}h")

## Build & Save Causal DAG

In [None]:
# Index the discovered edges by endpoint, preserving discovery order:
#   child_map[source]  -> list of targets
#   parent_map[target] -> list of sources
child_map = {}
parent_map = {}
for edge in edges:
    child_map.setdefault(edge['source'], []).append(edge['target'])
    parent_map.setdefault(edge['target'], []).append(edge['source'])

# One node record per DataFrame column; a variable without incident edges
# gets empty parent / child lists.
nodes = [
    {
        'node_id': var_name,
        'variable': var_name,
        'parents': parent_map.get(var_name, []),
        'children': child_map.get(var_name, []),
    }
    for var_name in variables
]

# Serialisable DAG document (key order is preserved in the JSON output).
dag = {
    'dag_id': 'beas_brahmaputra_v1',
    'nodes': nodes,
    'edges': edges,
    'version': '1.0.0',
    'created_at': datetime.now().isoformat(),
}

# Persist as pretty-printed JSON, creating the target directory on demand.
out_path = Path('../shared/causal_dag/beas_brahmaputra_v1.json')
out_path.parent.mkdir(parents=True, exist_ok=True)
serialized = json.dumps(dag, indent=2)
out_path.write_text(serialized)

print(f'DAG saved to {out_path}')
print(f'  Nodes: {len(nodes)}, Edges: {len(edges)}')

## Validate: Interventional Queries

In [None]:
# Put the parent directory on sys.path so `services.*` and `shared.*` import
# (assumes the working directory is one level below the repository root).
import sys
sys.path.insert(0, '..')

from services.causal_engine.dag import load_dag
from services.causal_engine.gnn import CausalGNNEngine
from services.causal_engine.interventions import InterventionAPI
from shared.models.phase2 import InterventionRequest

# Reload the saved DAG and layer the engine and the intervention API on it.
dag_model = load_dag(str(out_path))
engine = CausalGNNEngine(dag_model)
api = InterventionAPI(engine)

# do(rainfall_upper = 0): expected to reduce flood_risk. Results are printed
# for inspection only; nothing is asserted.
# NOTE(review): in the synthetic data dam_release lies in [10, 100] and
# rainfall_upper is exponential with mean 10, so a context such as
# dam_release=0.3 is only in-distribution if the engine works in normalised
# units -- confirm against CausalGNNEngine.
request = InterventionRequest(
    variable='rainfall_upper',
    value=0.0,
    target_variables=['flood_risk', 'downstream_level'],
    context={'soil_moisture': 0.5, 'dam_release': 0.3},
)
result = api.run(request)

print('Intervention: do(rainfall_upper = 0)')
print(f'  Original:       {result.original_values}')
print(f'  Counterfactual: {result.counterfactual_values}')
print(f'  Causal Effects: {result.causal_effects}')
print(f'  Confidence:     {result.confidence}')

In [None]:
# Sensitivity analysis: apply do(rainfall_upper = v) for 11 evenly spaced v in
# [0, 1] and print the counterfactual flood_risk with the reported effect (Δ).
# NOTE(review): synthetic rainfall_upper has mean 10 and dam_release lies in
# [10, 100]; this grid and context are only meaningful if the engine works in
# normalised units -- confirm.
rain_grid = np.linspace(0, 1, 11).tolist()
sweep_context = {'soil_moisture': 0.5, 'dam_release': 0.3}
sweep = api.sensitivity_analysis(
    variable='rainfall_upper',
    values=rain_grid,
    target='flood_risk',
    context=sweep_context,
)

print('Sensitivity: rainfall_upper → flood_risk')
for s in sweep:
    print(f"  rain={s['do_value']:.1f} → risk_cf={s['counterfactual']:.4f}  (Δ={s['effect']:+.4f})")