# Phase 4: Causal Inference with GNN Embeddings

**Goal:** Estimate the causal effect of an intervention (a new security check) on fraud probability, using the node embeddings from our trained GNN as confounder features for a causal model.

**Methodology:**
1. Load the trained GNN model from MLflow.
2. Generate node embeddings for all transactions.
3. Simulate a treatment `T`: A binary variable indicating which transactions received the new security check.
4. Define an outcome `Y`: The fraud label.
5. Define confounders `W`: The GNN embeddings, which capture rich information about the transaction and its neighborhood.
6. Use **EconML's Double Machine Learning (DML)** estimator to calculate the Average Treatment Effect (ATE).
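
As a rough sketch of what step 6 estimates, DML is commonly described with a partially linear model. The symbols $\theta$, $g$, $m$, $q$, $\varepsilon$ and $\eta$ below are notation introduced here for illustration; they do not appear elsewhere in this notebook.

$$
Y = \theta\, T + g(W) + \varepsilon, \qquad \mathbb{E}[\varepsilon \mid T, W] = 0
$$

$$
T = m(W) + \eta, \qquad \mathbb{E}[\eta \mid W] = 0
$$

With $q(W) = \mathbb{E}[Y \mid W]$ and fitted nuisance models $\hat{q}$ and $\hat{m}$, the effect is estimated from the residuals:

$$
\hat{\theta} = \frac{\sum_i \tilde{T}_i\, \tilde{Y}_i}{\sum_i \tilde{T}_i^{2}}, \qquad \tilde{Y}_i = Y_i - \hat{q}(W_i), \quad \tilde{T}_i = T_i - \hat{m}(W_i)
$$

Under the assumption that the effect is constant across transactions, $\theta$ is the ATE.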

In [None]:
import mlflow
import torch
import numpy as np
import pandas as pd
from econml.dml import DML
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
import logging
import os
import sys

# --- FIX 1: Add project root to Python path ---
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    print(f"Adding project root to path: {project_root}")
    sys.path.insert(0, project_root)
# ------------------------------------------

# --- FIX 2: Explicitly import the model class ---
# This helps PyTorch's unpickler find the class definition and provides a clear error if the path is wrong.
try:
    from src.models import GraphSAGEModel
    from src.graph_construction import build_graph_data
    print("Successfully imported custom modules from 'src' directory.")
except ImportError as e:
    print(f"Failed to import custom modules. Error: {e}")
    print("Please ensure you are running this notebook from the 'notebooks' directory and have RESTARTED the kernel.")
# -----------------------------------------------

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

### 1. Load the Trained Model

First, you need to get the `run_id` for your best GNN model from the MLflow UI.

**IMPORTANT:** A `ModuleNotFoundError` in the cell below means Python could not find the `src` package while unpickling the model. This usually happens when the path change in the import cell above has not taken effect in the current kernel session.

**To fix it:**
1. Run the cell above this one (the one with the imports).
2. From the menu in VS Code or Jupyter, select **"Restart Kernel"**.
3. Run all the cells again from the top.
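
If the error persists after a restart, a quick check like the one below can tell you whether a module is visible to the kernel at all. This is a minimal sketch using only the standard library; `can_import` is a hypothetical helper name, not part of this project.

```python
import importlib.util


def can_import(module_name):
    """Return True if the module can be located on the current sys.path."""
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ImportError, ValueError):
        # find_spec raises if a parent package (e.g. 'src') cannot be imported
        return False


# 'src.models' is the module this notebook imports GraphSAGEModel from
print(f"src.models importable: {can_import('src.models')}")
```

If this prints `False`, the project root is not on `sys.path`, so check the working directory before re-running the import cell.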

In [None]:
# Set the MLflow tracking URI to point to the project's mlruns directory
mlflow.set_tracking_uri('../mlruns')

# !!! IMPORTANT !!!
# Paste your GNN's Run ID from the MLflow UI below.
GNN_RUN_ID = 'ee25bc1644ff4cca94ec2eeea0deeba2' # <--- REPLACE THIS

W = None # Initialize W to ensure it exists
model = None # Initialize model to ensure it exists

try:
    # Load the trained model
    logged_model = f"runs:/{GNN_RUN_ID}/graphsage-model"
    model = mlflow.pytorch.load_model(logged_model)
    model.eval()
    print("Model loaded successfully.")

except mlflow.exceptions.MlflowException as e:
    print(f"MLflow Error: {e}")
    print("\nPlease make sure you have replaced the placeholder with a valid Run ID from your MLflow experiment.")
except ModuleNotFoundError as e:
    print(f"Module Not Found Error: {e}")
    print("\nThis is a common notebook issue. Please RESTART the kernel from the menu and run all cells again.")
except Exception as e:
    print(f"An unexpected error occurred: {e}")

### 2. Generate Embeddings and Simulate Treatment

In [None]:
# This cell will only run if the model was loaded successfully
if model is not None:
    # Load graph data
    PROCESSED_DATA_PATH = '../data/processed/train_merged.feather'
    GRAPH_DATA_PATH = '../data/processed/fraud_graph.pt'
    data = build_graph_data(PROCESSED_DATA_PATH, GRAPH_DATA_PATH)
    print("Graph data loaded successfully.")

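    # Generate node embeddings (W). Move the output to the CPU before converting to NumPy,
    # since .numpy() fails on tensors that live on a GPU.
    with torch.no_grad():
        W = model(data.x, data.edge_index).cpu().numpy()
    print(f"Generated embeddings (confounders W) with shape: {W.shape}")

    # Outcome Y (isFraud)
    Y = data.y.cpu().numpy()

    # Treatment T: Simulate that the security check was randomly assigned to 50% of transactions.
    # A seeded generator keeps the assignment reproducible across runs.
    rng = np.random.default_rng(42)
    T = rng.binomial(1, 0.5, size=data.num_nodes)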

    print(f"Shape of Y (outcome): {Y.shape}")
    print(f"Shape of T (treatment): {T.shape}")
else:
    print("Model was not loaded due to an error in the previous cell. Please fix the error and re-run.")
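
Because the treatment is assigned by a coin flip that ignores the fraud label, the true effect in this simulation is zero by construction. The sketch below illustrates this on synthetic data only. The 3.5% fraud rate and all `*_demo` names are assumptions made for the example, not values from the dataset.

```python
import numpy as np

rng_demo = np.random.default_rng(42)
n_demo = 100_000

# Hypothetical outcome: roughly 3.5% of transactions are fraudulent (assumed rate)
y_demo = rng_demo.binomial(1, 0.035, size=n_demo)

# Treatment assigned at random to about half the transactions, independent of y_demo
t_demo = rng_demo.binomial(1, 0.5, size=n_demo)

# Difference in fraud rates between treated and untreated transactions
diff_in_means = y_demo[t_demo == 1].mean() - y_demo[t_demo == 0].mean()

print(f"Share treated: {t_demo.mean():.3f}")
print(f"Difference in fraud rate (treated - untreated): {diff_in_means:.4f}")
```

The difference should be close to zero, which is the value the DML estimate is expected to recover in the next section.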

### 3. Estimate Average Treatment Effect (ATE) with DML

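DML uses two machine learning models to remove confounding bias:
- A model to predict the outcome `Y` from the confounders `W`.
- A model to predict the treatment `T` from the confounders `W`.

The treatment effect is then estimated by regressing the residuals of the outcome model on the residuals of the treatment model, so that only the variation not explained by `W` contributes to the estimate.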
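
To make the two-model idea concrete, here is a minimal NumPy sketch of the residual-on-residual step with cross-fitting. It is not EconML's implementation: it swaps in ordinary least squares for both nuisance models, runs on synthetic data with an assumed true effect of -0.1, and uses hypothetical names (`cross_fit_residuals`, `dml_effect`, the `*_demo` arrays).

```python
import numpy as np


def cross_fit_residuals(target, W, folds):
    """Residuals of `target` after an out-of-fold linear prediction from W."""
    n = len(target)
    design = np.column_stack([np.ones(n), W])  # add an intercept column
    residuals = np.empty(n)
    for k, held_out in enumerate(folds):
        train = np.concatenate([f for j, f in enumerate(folds) if j != k])
        coef, *_ = np.linalg.lstsq(design[train], target[train], rcond=None)
        residuals[held_out] = target[held_out] - design[held_out] @ coef
    return residuals


def dml_effect(Y, T, W, n_folds=2, seed=0):
    """Regress outcome residuals on treatment residuals (constant effect)."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(Y)), n_folds)
    y_res = cross_fit_residuals(Y.astype(float), W, folds)
    t_res = cross_fit_residuals(T.astype(float), W, folds)
    return float(t_res @ y_res / (t_res @ t_res))


# Synthetic data in which W confounds both treatment and outcome
rng = np.random.default_rng(0)
n = 20_000
W_demo = rng.normal(size=(n, 3))
propensity = 1.0 / (1.0 + np.exp(-W_demo[:, 0]))  # treatment depends on W
T_demo = rng.binomial(1, propensity)
Y_demo = -0.1 * T_demo + 0.5 * W_demo[:, 0] + rng.normal(scale=0.1, size=n)

naive = Y_demo[T_demo == 1].mean() - Y_demo[T_demo == 0].mean()
theta_hat = dml_effect(Y_demo, T_demo, W_demo)

print(f"Naive difference in means: {naive:.3f}")
print(f"Residual-on-residual estimate: {theta_hat:.3f}")
```

In this synthetic setup the naive difference in means is pushed away from the assumed effect by confounding, while the residual-based estimate lands near it. EconML follows the same pattern but lets you plug in flexible learners such as random forests for the two models.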

In [None]:
if W is not None:
    # LinearDML supplies the final-stage linear model that the generic DML class
    # requires as an explicit `model_final` argument.
    from econml.dml import LinearDML

    # Initialize the DML estimator. T is binary, so discrete_treatment=True
    # tells EconML to treat model_t as a classifier.
    est = LinearDML(
        model_y=RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42),
        model_t=RandomForestClassifier(n_estimators=100, min_samples_leaf=10, random_state=42),
        discrete_treatment=True,
        random_state=42
    )

    # Fit the estimator. We only use W (confounders) here as we are interested in the ATE,
    # not heterogeneous effects that depend on other features X.
    print("Fitting the DML estimator...")
    est.fit(Y, T, W=W)
    print("Fit complete.")

    # Get the average treatment effect and its confidence interval.
    # ate() and ate_interval() are methods; alpha=0.05 gives a 95% interval.
    ate_estimate = float(np.squeeze(est.ate()))
    ci_lower, ci_upper = (float(np.squeeze(b)) for b in est.ate_interval(alpha=0.05))

    print(f"\nEstimated Average Treatment Effect (ATE): {ate_estimate:.4f}")
    print(f"95% Confidence Interval: [{ci_lower:.4f}, {ci_upper:.4f}]")
else:
    print("Cannot estimate ATE because the previous steps failed.")

### 4. Interpretation

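The ATE tells us the average change in the probability of fraud if we were to apply the security check to all transactions, compared to none of them. Because the treatment in this notebook is simulated by random assignment, independent of the fraud label, the estimate here should be close to zero; the readings below matter once real treatment data replaces the simulation.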

- **A negative ATE** would suggest that the security check *reduces* the probability of fraud.
- **A positive ATE** would suggest it *increases* it (which would be unexpected).
- **An ATE close to zero** with a confidence interval that includes zero would suggest the check has no statistically significant effect.
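
The three cases above can be turned into a small helper for reporting. This is a sketch only: `interpret_ate` is a hypothetical function name, and it follows the notebook's convention that an interval containing zero means no statistically significant effect.

```python
def interpret_ate(ate, ci_lower, ci_upper):
    """Map an ATE estimate and its confidence interval to a plain-language reading."""
    if ci_lower > ci_upper:
        raise ValueError("ci_lower must not exceed ci_upper")
    if ci_lower <= 0.0 <= ci_upper:
        return "no statistically significant effect"
    if ate < 0:
        return "the security check reduces the probability of fraud"
    return "the security check increases the probability of fraud"


# Example with made-up numbers, not results from this notebook
print(interpret_ate(-0.012, -0.020, -0.004))
```

You could call it with the estimate and interval bounds printed in the previous cell.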