# CausalX Technique Benchmark (Colab)

This notebook is a **benchmarking scaffold** to compare candidate techniques for audio-visual deepfake detection across the CausalX pipeline. It is organized into phases:

- **Phase A — Extraction** (face detection/landmarks)
- **Phase B — Preprocessing** (audio/video feature representations)
- **Phase C — Fusion** (causal fusion variants vs baselines)
- **Phase D — Causal explainability** (intervention sensitivity, mediation)

> Fill in dataset paths, model checkpoints, and evaluation utilities based on your local setup.


## 0. Setup

Install and import dependencies, configure paths, and set global parameters.


In [None]:
# Optional: install dependencies in Colab
# !pip install -r /content/drive/MyDrive/CausalX/requirements.txt

from pathlib import Path
import csv
import json
import math
from statistics import mean, median
from typing import Dict, List


# --------------------------------------------------
# Paths and dataset loading
# --------------------------------------------------

def find_project_root():
    """Locate the repository root by walking up from the working directory.

    The root is the first directory, starting at the resolved cwd and moving
    towards the filesystem root, that contains ``backend/data``. When no such
    directory exists, the resolved cwd itself is returned.
    """
    start = Path.cwd().resolve()
    return next(
        (folder for folder in (start, *start.parents) if (folder / "backend" / "data").exists()),
        start,
    )


# Resolve all notebook paths relative to the detected project root.
PROJECT_ROOT = find_project_root()
BACKEND_DIR = PROJECT_ROOT.joinpath("backend")
DATA_ROOT = BACKEND_DIR.joinpath("data", "raw")
PROCESSED_DATA = BACKEND_DIR.joinpath("data", "processed", "causal_multimodal_dataset.csv")
RESULTS_DIR = BACKEND_DIR.joinpath("data", "results")
# Make sure the output folder exists before any phase writes results into it.
RESULTS_DIR.mkdir(exist_ok=True, parents=True)

print(f"Project root: {PROJECT_ROOT}")
print(f"Using processed dataset: {PROCESSED_DATA}")


# --------------------------------------------------
# Helpers
# --------------------------------------------------

def to_float(val: str, default: float = 0.0) -> float:
    """Parse ``val`` as a float, falling back to ``default`` when it cannot be parsed.

    Args:
        val: Raw cell value, typically a CSV string. ``None`` (which
            ``csv.DictReader`` yields for short rows) is also accepted.
        default: Value returned when ``val`` is missing or not numeric.

    Returns:
        The parsed float, or ``default``.
    """
    try:
        return float(val)
    except (TypeError, ValueError, OverflowError):
        # TypeError: None / non-numeric objects; ValueError: "", "abc", ...;
        # OverflowError: ints too large for a float. Anything else is a real bug
        # and should propagate instead of being silently swallowed.
        return default


def to_int(val: str, default: int = 0) -> int:
    """Parse ``val`` as an int, falling back to ``default`` when it cannot be parsed.

    The value goes through ``float()`` first so strings such as ``"3.0"`` or
    ``"1e2"`` are accepted; the fractional part is truncated toward zero.

    Args:
        val: Raw cell value, typically a CSV string (``None`` is accepted).
        default: Value returned when ``val`` is missing or not numeric.

    Returns:
        The parsed int, or ``default``.
    """
    try:
        return int(float(val))
    except (TypeError, ValueError, OverflowError):
        # ValueError also covers int(float("nan")); OverflowError covers
        # int(float("inf")).
        return default


def quantile(values: List[float], q: float) -> float:
    """Return the ``q``-quantile of ``values`` using linear interpolation.

    ``q`` is clamped into ``[0, 1]``. The fractional position ``(n - 1) * q``
    is interpolated between its two neighbouring order statistics (NumPy's
    default "linear" rule). An empty input yields ``0.0``.
    """
    if not values:
        return 0.0
    ordered = sorted(values)
    position = (len(ordered) - 1) * max(0.0, min(1.0, q))
    below = math.floor(position)
    above = math.ceil(position)
    if below == above:
        # Exact order statistic; no interpolation needed.
        return ordered[int(position)]
    fraction = position - below
    return ordered[below] + (ordered[above] - ordered[below]) * fraction


def minmax_scale(values: List[float]) -> List[float]:
    """Linearly rescale ``values`` into ``[0, 1]``.

    A (near-)constant input, where the range is below ``1e-9``, maps every
    element to ``0.5``. An empty input returns an empty list.
    """
    if not values:
        return []
    low = min(values)
    span = max(values) - low
    if span < 1e-9:
        return [0.5] * len(values)
    return [(v - low) / span for v in values]


def sigmoid(x: float) -> float:
    """Logistic function ``1 / (1 + e^-x)``.

    When ``e^-x`` overflows (very negative ``x``), the limit ``0.0`` is
    returned instead of raising.
    """
    try:
        denominator = 1.0 + math.exp(-x)
    except OverflowError:
        return 0.0 if x < 0 else 1.0
    return 1.0 / denominator


def compute_auc(labels: List[int], scores: List[float]) -> float:
    """Return the ROC AUC of ``scores`` against binary ``labels``.

    Uses the Mann-Whitney U formulation
    ``AUC = (R_pos - P * (P + 1) / 2) / (P * N)``, where ``R_pos`` is the sum of
    the 1-based ranks of the positive samples. Tied scores share the average
    of the ranks they span, so the result does not depend on input order and
    each positive/negative tie contributes 0.5.

    Args:
        labels: Ground truth; ``1`` is positive, anything else negative.
        scores: Per-sample scores; higher means "more positive".

    Returns:
        AUC in ``[0, 1]``. Returns ``0.0`` when only one class is present,
        since AUC is undefined in that case.
    """
    pairs = sorted(zip(scores, labels), key=lambda x: x[0])
    # Count positives the same way the rank sum below identifies them.
    pos = sum(1 for _, lbl in pairs if lbl == 1)
    neg = len(pairs) - pos
    if pos == 0 or neg == 0:
        return 0.0

    rank_sum = 0.0
    n = len(pairs)
    start = 0
    while start < n:
        end = start
        while end + 1 < n and pairs[end + 1][0] == pairs[start][0]:
            end += 1
        # Indices start..end share a score; their 1-based ranks are
        # start+1..end+1, whose mean is (start + end) / 2 + 1.
        avg_rank = (start + end) / 2.0 + 1.0
        tied_pos = sum(1 for k in range(start, end + 1) if pairs[k][1] == 1)
        rank_sum += avg_rank * tied_pos
        start = end + 1

    return (rank_sum - pos * (pos + 1) / 2.0) / (pos * neg)


def accuracy_at_threshold(labels: List[int], scores: List[float], threshold: float) -> float:
    """Fraction of samples whose thresholded score equals the label.

    A score ``>= threshold`` predicts ``1``; otherwise ``0``. The denominator
    is ``len(labels)``. Empty ``labels`` yields ``0.0``.
    """
    if not labels:
        return 0.0
    hits = 0
    for score, label in zip(scores, labels):
        predicted = 1 if score >= threshold else 0
        if predicted == label:
            hits += 1
    return hits / len(labels)


def best_threshold_accuracy(labels: List[int], scores: List[float], candidates=None) -> float:
    """Return the highest accuracy over a small grid of decision thresholds.

    The grid is ``candidates``, or ``[0.3, 0.4, 0.5, 0.6]`` when ``candidates``
    is empty/None. The median score is always appended. The threshold is picked
    on the same data it is scored on, so the value is optimistic. Empty
    ``labels`` yields ``0.0``.
    """
    if not labels:
        return 0.0
    if candidates:
        thresholds = list(candidates)
    else:
        thresholds = [0.3, 0.4, 0.5, 0.6]
    if scores:
        thresholds.append(median(scores))
    accuracies = [accuracy_at_threshold(labels, scores, t) for t in thresholds]
    return max([0.0, *accuracies])


# --------------------------------------------------
# Load processed dataset (face/audio/video features + labels)
# --------------------------------------------------

def load_dataset() -> List[Dict]:
    """Read the processed multimodal CSV at ``PROCESSED_DATA`` into typed row dicts.

    Each row has:
      * ``label``: int, 0 when missing/unparseable;
      * ``dataset``: the raw string from the ``dataset`` column ("unknown" if
        the column is absent);
      * ``video_fake`` / ``audio_fake``: int, ``-1`` when the cell is absent,
        empty or unparseable (later cells treat ``-1`` as "unknown");
      * the numeric feature columns as floats, 0.0 when missing/unparseable.

    Raises:
        FileNotFoundError: if ``PROCESSED_DATA`` does not exist.
    """
    if not PROCESSED_DATA.exists():
        raise FileNotFoundError(f"Processed dataset not found: {PROCESSED_DATA}")

    feature_keys = ["jitter_mean", "jitter_std", "av_correlation", "av_lag_frames", "lip_variance", "det_count"]
    rows = []
    # newline="" is what the csv module expects so quoted fields with embedded
    # newlines are parsed correctly.
    with PROCESSED_DATA.open("r", encoding="utf-8", newline="") as f:
        for raw in csv.DictReader(f):
            row = {
                "label": to_int(raw.get("label", "0")),
                "dataset": raw.get("dataset", "unknown"),
                # Fall back to -1 ("unknown") for empty cells too; the previous
                # form only did so when the column was absent, so "" became 0
                # ("real") and leaked into the mismatch/swap subsets.
                "video_fake": to_int(raw.get("video_fake"), default=-1),
                "audio_fake": to_int(raw.get("audio_fake"), default=-1),
            }
            for key in feature_keys:
                row[key] = to_float(raw.get(key, "0"))
            rows.append(row)

    return rows


DATA_ROWS = load_dataset()

FEATURE_KEYS = ["jitter_mean", "jitter_std", "av_correlation", "av_lag_frames", "lip_variance", "det_count"]
# Column-wise raw feature values and their min-max normalized counterparts.
FEATURE_ARRAYS: Dict[str, List[float]] = {k: [r[k] for r in DATA_ROWS] for k in FEATURE_KEYS}
NORM_FEATURES: Dict[str, List[float]] = {k: minmax_scale(vs) for k, vs in FEATURE_ARRAYS.items()}

# Each combined row is the original row plus one ``norm_<feature>`` entry per
# feature. A comprehension avoids leaking loop variables into the notebook.
COMBINED_ROWS: List[Dict] = [
    {**r, **{f"norm_{k}": NORM_FEATURES[k][i] for k in FEATURE_KEYS}}
    for i, r in enumerate(DATA_ROWS)
]

# MAX_DET_COUNT is used directly as a divisor, so fall back to 1.0 both when
# the column is empty and when every detection count is 0.
MAX_DET_COUNT = max(FEATURE_ARRAYS["det_count"], default=1.0) or 1.0
MAX_ABS_LAG = max((abs(v) for v in FEATURE_ARRAYS["av_lag_frames"]), default=1.0)
POSITIVE_COUNT = sum(1 for r in DATA_ROWS if r["label"] == 1)
TOTAL_COUNT = len(DATA_ROWS)

print(f"Loaded {TOTAL_COUNT} samples ({POSITIVE_COUNT} positive)")


Project root: /Users/venturit/Documents/GitHub/FYP/CausalX-Project
Using processed dataset: /Users/venturit/Documents/GitHub/FYP/CausalX-Project/backend/data/processed/causal_multimodal_dataset.csv
Loaded 21784 samples (21245 positive)


## 1. Phase A — Extraction (Face Detection & Landmarks)

Evaluate candidate face detectors and landmarkers. Record:
- detection rate
- landmark stability (temporal jitter)
- throughput (estimated frames per second)


In [None]:
def evaluate_extraction(detector_name, landmark_name):
    """Compute proxy extraction metrics for a detector/landmarker pairing.

    No detector is run here. The detector name selects a quantile of the
    precomputed ``det_count`` column as a cutoff, and metrics are averaged over
    rows at or above it. ``landmark_name`` is only echoed back; it does not
    influence the numbers.

    Returns:
        dict with ``detector``, ``landmarks``, ``detection_rate``,
        ``landmark_stability`` and ``fps``. The three metrics are ``None`` when
        there is no ``det_count`` data.
    """
    counts = FEATURE_ARRAYS.get("det_count", [])
    result = {"detector": detector_name, "landmarks": landmark_name}
    if not counts:
        result.update(detection_rate=None, landmark_stability=None, fps=None)
        return result

    # Ordered (substring, quantile) rules: the first match wins; anything else
    # (e.g. MediaPipe) uses the 0.35 quantile.
    cutoff_rules = (
        ("retina", 0.75),
        ("s3fd", 0.7),
        ("blaze", 0.65),
        ("mtcnn", 0.55),
    )
    lowered = detector_name.lower()
    level = next((q for token, q in cutoff_rules if token in lowered), 0.35)
    cutoff = quantile(counts, level)

    kept = [row for row in COMBINED_ROWS if row.get("det_count", 0) >= cutoff] or COMBINED_ROWS

    result.update(
        detection_rate=round(mean([row["det_count"] / MAX_DET_COUNT for row in kept]), 4),
        landmark_stability=round(mean([1.0 / (1.0 + row["jitter_std"]) for row in kept]), 4),
        fps=round(mean([row["det_count"] / 10.0 for row in kept]), 2),
    )
    return result


# (face detector, landmarker) pairings to benchmark.
extraction_candidates = [
    ("RetinaFace", "MediaPipe"),
    ("RetinaFace", "HRNet"),
    ("MTCNN", "dlib68"),
    ("BlazeFace", "MediaPipe"),
    ("S3FD", "HRNet"),
    ("MediaPipe", "MediaPipe"),
]

extraction_results = [evaluate_extraction(*pairing) for pairing in extraction_candidates]
extraction_results


[{'detector': 'RetinaFace',
  'landmarks': 'MediaPipe',
  'detection_rate': 0.6456,
  'landmark_stability': 0.9859,
  'fps': 9.68},
 {'detector': 'RetinaFace',
  'landmarks': 'HRNet',
  'detection_rate': 0.6456,
  'landmark_stability': 0.9859,
  'fps': 9.68},
 {'detector': 'MTCNN',
  'landmarks': 'dlib68',
  'detection_rate': 0.558,
  'landmark_stability': 0.986,
  'fps': 8.37},
 {'detector': 'BlazeFace',
  'landmarks': 'MediaPipe',
  'detection_rate': 0.5894,
  'landmark_stability': 0.9859,
  'fps': 8.84},
 {'detector': 'S3FD',
  'landmarks': 'HRNet',
  'detection_rate': 0.6202,
  'landmark_stability': 0.9859,
  'fps': 9.3},
 {'detector': 'MediaPipe',
  'landmarks': 'MediaPipe',
  'detection_rate': 0.5023,
  'landmark_stability': 0.9865,
  'fps': 7.53}]

## 2. Phase B — Preprocessing (Feature Representations)

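Compare audio and video representations:
- Audio: log-mel and MFCC (each with and without VAD), wav2vec 2.0, HuBERT and WavLM embeddings
- Video: raw frames, optical flow, face alignment, ELA/artifact maps

Record AUC/accuracy and relative compute-time costs. These metrics are currently computed over the full processed dataset; there is no held-out validation split yet.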


In [None]:
def preprocess_score(row, audio_feat, video_feat):
    """Heuristic fake-probability for one row under an audio/video representation pair.

    Cues taken from ``row``:
      * sync gap ``1 - norm_av_correlation``;
      * lag penalty ``|av_lag_frames| / max(1, MAX_ABS_LAG)``, capped at 1;
      * a jitter-based motion term;
      * the normalized lip variance.

    ``audio_feat`` selects the base weighting by substring match (first hit
    wins), and ``"vad"`` in the name adds a flat 0.05. ``video_feat`` then adds
    a representation-specific term. The total is squashed with
    ``sigmoid(3 * base - 1.5)``.
    """
    sync_gap = 1 - row["norm_av_correlation"]
    lag_penalty = min(1.0, abs(row["av_lag_frames"]) / max(1.0, MAX_ABS_LAG))
    motion = 0.6 * row["norm_jitter_mean"] + 0.4 * row["norm_jitter_std"]
    lip = row["norm_lip_variance"]

    audio_key = audio_feat.lower()
    noise_bonus = 0.0
    if "vad" in audio_key:
        noise_bonus = 0.05

    # Audio representation -> base weighting.
    if "wavlm" in audio_key or "hubert" in audio_key:
        base = 0.65 * sync_gap + 0.2 * lip + 0.15 * motion + noise_bonus
    elif "wav2vec" in audio_key:
        base = 0.6 * sync_gap + 0.25 * lip + 0.15 * motion + noise_bonus
    elif "mfcc" in audio_key:
        base = 0.5 * sync_gap + 0.3 * lag_penalty + 0.2 * lip + noise_bonus
    else:
        # log-mel and anything unrecognized.
        base = 0.55 * sync_gap + 0.45 * lip + noise_bonus

    # Video representation -> additive adjustment.
    video_key = video_feat.lower()
    if "optical" in video_key:
        base += 0.12 * motion
    elif "align" in video_key:
        base += 0.08 * (1 - lag_penalty) + 0.04 * motion
    elif "ela" in video_key or "artifact" in video_key:
        base += 0.06 * sync_gap + 0.04 * lip
    else:
        # Raw frames and anything unrecognized.
        base += 0.05 * motion

    return sigmoid(base * 3 - 1.5)


def evaluate_preprocessing(audio_feat, video_feat):
    """Score every row with ``preprocess_score`` and summarize AUC, accuracy and cost.

    Metrics are computed over all of ``COMBINED_ROWS``; no held-out split is
    made here.

    ``val_acc`` is the best accuracy over a small threshold grid. With strongly
    imbalanced labels it is dominated by the majority class, so compare it
    against the positive-class rate.

    ``compute_cost`` is a relative figure: the sum of fixed per-representation
    weights (unknown names weigh 1.0), multiplied by the dataset size in units
    of 10k rows (never less than 1x).
    """
    scores = [preprocess_score(row, audio_feat, video_feat) for row in COMBINED_ROWS]
    labels = [row["label"] for row in COMBINED_ROWS]

    audio_weights = {
        "log-mel": 1.0,
        "log-mel+vad": 1.1,
        "mfcc": 1.15,
        "mfcc+vad": 1.2,
        "wav2vec2": 1.8,
        "hubert": 1.9,
        "wavlm": 1.9,
    }
    video_weights = {
        "raw": 1.0,
        "optical_flow": 1.35,
        "aligned": 1.2,
        "face_alignment": 1.2,
        "artifacts_ela": 1.1,
    }
    size_factor = max(1.0, len(COMBINED_ROWS) / 10000.0)
    audio_cost = audio_weights.get(audio_feat.lower(), 1.0)
    video_cost = video_weights.get(video_feat.lower(), 1.0)

    return {
        "audio_feat": audio_feat,
        "video_feat": video_feat,
        "val_auc": round(compute_auc(labels, scores), 4),
        "val_acc": round(best_threshold_accuracy(labels, scores), 4),
        "compute_cost": round((audio_cost + video_cost) * size_factor, 3),
    }


# (audio representation, video representation) pairings to benchmark.
preprocess_candidates = [
    ("log-mel", "raw"),
    ("log-mel+vad", "raw"),
    ("mfcc", "raw"),
    ("mfcc+vad", "raw"),
    ("wav2vec2", "raw"),
    ("wav2vec2", "optical_flow"),
    ("hubert", "aligned"),
    ("wavlm", "artifacts_ela"),
]

preprocess_results = [evaluate_preprocessing(*pairing) for pairing in preprocess_candidates]
preprocess_results


[{'audio_feat': 'log-mel',
  'video_feat': 'raw',
  'val_auc': 0.6183,
  'val_acc': 0.9733,
  'compute_cost': 4.357},
 {'audio_feat': 'log-mel+vad',
  'video_feat': 'raw',
  'val_auc': 0.6183,
  'val_acc': 0.9748,
  'compute_cost': 4.575},
 {'audio_feat': 'mfcc',
  'video_feat': 'raw',
  'val_auc': 0.5725,
  'val_acc': 0.9433,
  'compute_cost': 4.684},
 {'audio_feat': 'mfcc+vad',
  'video_feat': 'raw',
  'val_auc': 0.5725,
  'val_acc': 0.9687,
  'compute_cost': 4.792},
 {'audio_feat': 'wav2vec2',
  'video_feat': 'raw',
  'val_auc': 0.6049,
  'val_acc': 0.9667,
  'compute_cost': 6.1},
 {'audio_feat': 'wav2vec2',
  'video_feat': 'optical_flow',
  'val_auc': 0.6014,
  'val_acc': 0.9672,
  'compute_cost': 6.862},
 {'audio_feat': 'hubert',
  'video_feat': 'aligned',
  'val_auc': 0.6115,
  'val_acc': 0.9742,
  'compute_cost': 6.753},
 {'audio_feat': 'wavlm',
  'video_feat': 'artifacts_ela',
  'val_auc': 0.6061,
  'val_acc': 0.9695,
  'compute_cost': 6.535}]

## 3. Phase C — Fusion (Causal Fusion Emphasis)

Compare causal fusion variants against baselines:
- **SCM-guided causal attention**
- **Interventional fusion** (swap/mask modalities)
- **Invariant fusion** (IRM-style)
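- Baselines: cross-attention, gated fusion, FiLM conditioning, graph fusion, co-attention, early concatenation, late weighted fusion, bilinear pooling, and temporal convolution (TCN)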

Track AUC, robustness to modality swaps, and calibration.


In [None]:
def expected_calibration_error(labels, probs, bins=10):
    """Return the expected calibration error (ECE) of positive-class probabilities.

    ``probs[i]`` is read as the predicted probability that ``labels[i] == 1``.
    Probabilities are grouped into ``bins`` equal-width bins over ``[0, 1]``.
    The last bin is closed, so ``p == 1.0`` is counted. ECE is the
    sample-weighted mean, over non-empty bins, of
    ``|fraction of positives - mean predicted probability|``, i.e. the gap
    shown in a reliability diagram. Values outside ``[0, 1]`` (or NaN) are
    ignored.

    Args:
        labels: Binary ground truth (1 = positive).
        probs: Predicted positive-class probabilities aligned with ``labels``.
        bins: Number of equal-width bins; must be >= 1.

    Returns:
        ECE in ``[0, 1]``; ``0.0`` for empty input.

    Raises:
        ValueError: if ``bins < 1``.
    """
    if bins < 1:
        raise ValueError(f"bins must be >= 1, got {bins}")
    if not labels or not probs:
        return 0.0

    buckets = [[] for _ in range(bins)]
    for p, y in zip(probs, labels):
        if not 0.0 <= p <= 1.0:  # also rejects NaN
            continue
        # min() puts p == 1.0 in the last bin instead of dropping it.
        buckets[min(int(p * bins), bins - 1)].append((p, y))

    total = sum(len(bucket) for bucket in buckets)
    ece = 0.0
    for bucket in buckets:
        if not bucket:
            continue
        mean_prob = mean(p for p, _ in bucket)
        # Compare like with like: mean P(y=1) against the observed positive rate.
        positive_rate = mean(1 if y == 1 else 0 for _, y in bucket)
        ece += (len(bucket) / total) * abs(positive_rate - mean_prob)
    return ece


# Maps a fusion candidate name to {"probs": [...], "labels": [...]}, one entry
# per row of COMBINED_ROWS. Filled by evaluate_fusion() and read by
# evaluate_causal_explainability().
FUSION_CACHE = {}


def fusion_score(row, fusion_name):
    """Heuristic fake-probability for one row under a named fusion strategy.

    Cues derived from ``row``:
      * A/V sync gap;
      * a jitter-based motion term;
      * lip variance;
      * a capped lag penalty;
      * the normalized detection count;
      * a 0/1 flag set only when both modality fake flags are known (!= -1)
        and disagree.

    The cues are weighted according to the first strategy keyword found in
    ``fusion_name``. Unrecognized names fall back to ``sync_gap + motion``.
    The result is mapped through ``sigmoid(3 * raw - 1.5)``.
    """
    sync_gap = 1 - row["norm_av_correlation"]
    motion = 0.7 * row["norm_jitter_mean"] + 0.3 * row["norm_jitter_std"]
    lip = row["norm_lip_variance"]
    lag = min(1.0, abs(row["av_lag_frames"]) / max(1.0, MAX_ABS_LAG))
    det = row["norm_det_count"]
    audio_flag, video_flag = row["audio_fake"], row["video_fake"]
    mismatch = int(audio_flag != video_flag and audio_flag != -1 and video_flag != -1)

    key = fusion_name.lower()

    def mentions(*tokens):
        # True when any token occurs as a substring of the lowered name.
        return any(token in key for token in tokens)

    if mentions("scm", "causal_attention"):
        raw = 0.45 * sync_gap + 0.35 * motion + 0.2 * lip - 0.1 * lag
    elif mentions("interventional"):
        raw = 0.4 * sync_gap + 0.25 * motion + 0.2 * mismatch + 0.15 * lag
    elif mentions("invariant", "irm"):
        domain_penalty = 0.05 if row.get("dataset") == "DFDC" else 0.0
        raw = 0.42 * sync_gap + 0.33 * motion + 0.15 * lip + 0.1 * lag + domain_penalty
    elif mentions("cross_attention") or key == "cross":
        raw = 0.5 * sync_gap + 0.5 * lip
    elif mentions("gated"):
        gate = 0.4 + 0.4 * det
        raw = gate * sync_gap + (1 - gate) * motion + 0.1 * lip
    elif mentions("film"):
        raw = 0.48 * sync_gap + 0.32 * motion + 0.2 * det
    elif mentions("graph"):
        raw = 0.38 * sync_gap + 0.42 * motion + 0.2 * lip
    elif mentions("co_attention"):
        raw = 0.46 * sync_gap + 0.36 * motion + 0.18 * lip
    elif mentions("early_concat"):
        raw = 0.35 * sync_gap + 0.35 * motion + 0.3 * lip
    elif mentions("late_weighted"):
        raw = 0.4 * sync_gap + 0.4 * motion + 0.2 * lip
    elif mentions("bilinear"):
        raw = 0.45 * sync_gap + 0.25 * motion + 0.3 * lip
    elif mentions("tcn"):
        raw = 0.42 * sync_gap + 0.33 * motion + 0.25 * lip
    else:
        raw = sync_gap + motion

    return sigmoid(raw * 3 - 1.5)


def evaluate_fusion(fusion_name):
    """Score all rows with ``fusion_score`` and report AUC, swap robustness and ECE.

    ``swap_robustness`` is the best-threshold accuracy on rows whose audio and
    video fake flags are both known and disagree. If no such rows exist, it
    falls back to all rows.

    Side effect: stores the per-row probabilities and labels in
    ``FUSION_CACHE[fusion_name]``.
    """
    probs = [fusion_score(row, fusion_name) for row in COMBINED_ROWS]
    labels = [row["label"] for row in COMBINED_ROWS]
    auc = compute_auc(labels, probs)

    def is_swapped(row):
        audio_flag, video_flag = row["audio_fake"], row["video_fake"]
        return audio_flag != -1 and video_flag != -1 and audio_flag != video_flag

    swapped = [(p, row["label"]) for p, row in zip(probs, COMBINED_ROWS) if is_swapped(row)]
    if not swapped:
        swapped = list(zip(probs, labels))

    swap_scores = [p for p, _ in swapped]
    swap_labels = [y for _, y in swapped]
    swap_acc = best_threshold_accuracy(swap_labels, swap_scores, candidates=[0.4, 0.5, 0.6, 0.7])
    ece = expected_calibration_error(labels, probs)

    FUSION_CACHE[fusion_name] = {"probs": probs, "labels": labels}
    return {
        "fusion": fusion_name,
        "val_auc": round(auc, 4),
        "swap_robustness": round(swap_acc, 4),
        "ece": round(ece, 4),
    }


# Fusion strategies to benchmark: the three causal variants first, then baselines.
fusion_candidates = [
    "scm_causal_attention",
    "interventional_fusion",
    "invariant_fusion",
    "cross_attention_baseline",
    "gated_fusion_baseline",
    "film_conditioning",
    "graph_fusion",
    "co_attention",
    "early_concat",
    "late_weighted",
    "bilinear_pooling",
    "tcn_temporal",
]

fusion_results = list(map(evaluate_fusion, fusion_candidates))
fusion_results


[{'fusion': 'scm_causal_attention',
  'val_auc': 0.6082,
  'swap_robustness': 0.5,
  'ece': 0.3342},
 {'fusion': 'interventional_fusion',
  'val_auc': 0.7655,
  'swap_robustness': 0.8002,
  'ece': 0.3358},
 {'fusion': 'invariant_fusion',
  'val_auc': 0.5651,
  'swap_robustness': 0.5,
  'ece': 0.3231},
 {'fusion': 'cross_attention_baseline',
  'val_auc': 0.6203,
  'swap_robustness': 0.862,
  'ece': 0.4243},
 {'fusion': 'gated_fusion_baseline',
  'val_auc': 0.513,
  'swap_robustness': 0.5,
  'ece': 0.3463},
 {'fusion': 'film_conditioning',
  'val_auc': 0.4675,
  'swap_robustness': 0.5,
  'ece': 0.3397},
 {'fusion': 'graph_fusion',
  'val_auc': 0.5898,
  'swap_robustness': 0.5,
  'ece': 0.3221},
 {'fusion': 'co_attention',
  'val_auc': 0.5922,
  'swap_robustness': 0.5,
  'ece': 0.3386},
 {'fusion': 'early_concat',
  'val_auc': 0.5981,
  'swap_robustness': 0.5,
  'ece': 0.3421},
 {'fusion': 'late_weighted',
  'val_auc': 0.5911,
  'swap_robustness': 0.5,
  'ece': 0.3276},
 {'fusion': 'bilin

## 4. Phase D — Causal Explainability

Measure intervention sensitivity and mediation/path-specific effects.


In [None]:
def _mean_gap(group_a, group_b):
    """Return |mean(group_a) - mean(group_b)|, or 0.0 if either group is empty."""
    return abs(mean(group_a) - mean(group_b)) if group_a and group_b else 0.0


def _pearson(xs, ys):
    """Return the Pearson correlation of ``xs`` and ``ys`` (population moments).

    Returns 0.0 for empty input or when either series has zero variance. The
    denominator is not clamped, so low-variance series are not artificially
    shrunk toward zero.
    """
    if not xs or not ys:
        return 0.0
    mean_x, mean_y = mean(xs), mean(ys)
    cov = mean([(x - mean_x) * (y - mean_y) for x, y in zip(xs, ys)])
    var_x = mean([(x - mean_x) ** 2 for x in xs])
    var_y = mean([(y - mean_y) ** 2 for y in ys])
    if var_x <= 0.0 or var_y <= 0.0:
        return 0.0
    return cov / math.sqrt(var_x * var_y)


def _mediation_analysis(probs):
    """Audio-flag effect, plus the mismatched-vs-matched gap over rows with both flags known."""
    subset = [(p, r) for p, r in zip(probs, COMBINED_ROWS) if r["audio_fake"] != -1]
    audio1 = [p for p, r in subset if r["audio_fake"] == 1]
    audio0 = [p for p, r in subset if r["audio_fake"] == 0]
    mismatched = [p for p, r in subset if r.get("video_fake") != -1 and r["audio_fake"] != r.get("video_fake")]
    matched = [p for p, r in subset if r.get("video_fake") != -1 and r["audio_fake"] == r.get("video_fake")]
    return _mean_gap(audio1, audio0), _mean_gap(mismatched, matched)


def _path_specific_effects(probs):
    """Video-flag effect, plus the cross-vs-aligned A/V gap over rows with both flags known."""
    subset = [(p, r) for p, r in zip(probs, COMBINED_ROWS) if r["video_fake"] != -1]
    vid1 = [p for p, r in subset if r["video_fake"] == 1]
    vid0 = [p for p, r in subset if r["video_fake"] == 0]
    cross = [p for p, r in subset if r.get("audio_fake") != -1 and r.get("audio_fake") != r.get("video_fake")]
    aligned = [p for p, r in subset if r.get("audio_fake") != -1 and r.get("audio_fake") == r.get("video_fake")]
    return _mean_gap(vid1, vid0), _mean_gap(cross, aligned)


def _causal_graph_discovery():
    """|corr(norm_av_correlation, label)| and half of it (a correlational proxy)."""
    corr = abs(_pearson(
        [r["norm_av_correlation"] for r in COMBINED_ROWS],
        [r["label"] for r in COMBINED_ROWS],
    ))
    return corr, corr * 0.5


def _do_audio_intervention(probs):
    """Mismatched-vs-matched A/V score gap, over rows where both flags are known."""
    # Rows with an unknown video flag (-1) are not "mismatches"; this matches
    # the mismatch definition used by fusion_score and evaluate_fusion.
    subset = [
        (p, r) for p, r in zip(probs, COMBINED_ROWS)
        if r["audio_fake"] != -1 and r.get("video_fake") != -1
    ]
    mismatch_scores = [p for p, r in subset if r["audio_fake"] != r.get("video_fake")]
    match_scores = [p for p, r in subset if r["audio_fake"] == r.get("video_fake")]
    effect = _mean_gap(mismatch_scores, match_scores)
    return effect, effect


def _irm_invariance(probs):
    """(1 - spread, spread), where spread is the max-min gap of per-dataset mean scores."""
    domains = {}
    for p, r in zip(probs, COMBINED_ROWS):
        domains.setdefault(r.get("dataset", "unknown"), []).append(p)
    spread = 0.0
    if len(domains) >= 2:
        dom_means = [mean(v) for v in domains.values() if v]
        spread = max(dom_means) - min(dom_means) if dom_means else 0.0
    return max(0.0, 1 - spread), spread


def _causal_shap(probs):
    """Mean and max |Pearson correlation| between selected features and ``probs``.

    This is a correlational attribution proxy, not Shapley values.
    """
    feature_names = ["norm_av_correlation", "norm_jitter_mean", "norm_jitter_std", "norm_lip_variance"]
    corrs = [abs(_pearson([r[feat] for r in COMBINED_ROWS], probs)) for feat in feature_names]
    return (mean(corrs) if corrs else 0.0), (max(corrs) if corrs else 0.0)


def evaluate_causal_explainability(method_name):
    """Compute proxy explainability metrics for one analysis method.

    Reads Phase C scores from ``FUSION_CACHE``. The ``scm_causal_attention``
    probabilities serve as the base model; if they are not cached they are
    recomputed with ``fusion_score``. The ``interventional_fusion``
    probabilities (falling back to the base ones) are used for
    ``path_specific_effects``. Modality flags equal to ``-1`` are treated as
    unknown and excluded from mismatch comparisons.

    Args:
        method_name: One of ``mediation_analysis``, ``path_specific_effects``,
            ``causal_graph_discovery``, ``do_audio_intervention``,
            ``irm_invariance`` or ``causal_shap``. Any other name is scored
            as ``causal_shap``.

    Returns:
        dict with ``method``, ``mediation_effect`` and
        ``intervention_sensitivity``, rounded to 4 decimals.
    """
    base = FUSION_CACHE.get("scm_causal_attention")
    interventional = FUSION_CACHE.get("interventional_fusion", base)

    base_probs = base.get("probs") if base else [fusion_score(r, "scm_causal_attention") for r in COMBINED_ROWS]
    inter_probs = interventional.get("probs") if interventional else base_probs

    if method_name == "mediation_analysis":
        mediation_effect, intervention_sensitivity = _mediation_analysis(base_probs)
    elif method_name == "path_specific_effects":
        mediation_effect, intervention_sensitivity = _path_specific_effects(inter_probs)
    elif method_name == "causal_graph_discovery":
        mediation_effect, intervention_sensitivity = _causal_graph_discovery()
    elif method_name == "do_audio_intervention":
        mediation_effect, intervention_sensitivity = _do_audio_intervention(base_probs)
    elif method_name == "irm_invariance":
        mediation_effect, intervention_sensitivity = _irm_invariance(base_probs)
    else:  # causal_shap (also the fallback for unrecognized names)
        mediation_effect, intervention_sensitivity = _causal_shap(base_probs)

    return {
        "method": method_name,
        "mediation_effect": round(mediation_effect, 4),
        "intervention_sensitivity": round(intervention_sensitivity, 4),
    }


# Explainability analyses to run. Each name selects a branch of
# evaluate_causal_explainability (unrecognized names are scored as causal_shap).
causal_methods = [
    "mediation_analysis",
    "path_specific_effects",
    "causal_shap",
    "causal_graph_discovery",
    "do_audio_intervention",
    "irm_invariance",
]

causal_results = list(map(evaluate_causal_explainability, causal_methods))
causal_results


[{'method': 'mediation_analysis',
  'mediation_effect': 0.0035,
  'intervention_sensitivity': 0.0028},
 {'method': 'path_specific_effects',
  'mediation_effect': 0.0027,
  'intervention_sensitivity': 0.1376},
 {'method': 'causal_shap',
  'mediation_effect': 0.4476,
  'intervention_sensitivity': 0.85},
 {'method': 'causal_graph_discovery',
  'mediation_effect': 0.0551,
  'intervention_sensitivity': 0.0276},
 {'method': 'do_audio_intervention',
  'mediation_effect': 0.0028,
  'intervention_sensitivity': 0.0028},
 {'method': 'irm_invariance',
  'mediation_effect': 0.9886,
  'intervention_sensitivity': 0.0114}]

## 5. Results Logging

Combine results across phases and export as CSV/JSON for tracking and comparison.


In [None]:
import csv


def save_results(name, rows, results_dir=None):
    """Write ``rows`` to ``<results_dir>/<name>.csv`` and ``<results_dir>/<name>.json``.

    Args:
        name: File stem shared by both outputs.
        rows: List of flat dicts. The CSV header is the union of all rows'
            keys in first-seen order, so rows with differing keys are written
            (missing cells are left blank) instead of making ``DictWriter``
            raise. An empty list produces an empty CSV and ``[]`` JSON.
        results_dir: Output directory, created if needed. Defaults to
            ``RESULTS_DIR``.

    Returns:
        ``(csv_path, json_path)`` as ``Path`` objects.
    """
    out_dir = RESULTS_DIR if results_dir is None else Path(results_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    csv_path = out_dir / f"{name}.csv"
    json_path = out_dir / f"{name}.json"

    if rows:
        # dict.fromkeys de-duplicates while preserving first-seen key order.
        fieldnames = list(dict.fromkeys(key for row in rows for key in row))
        with csv_path.open("w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(rows)
    else:
        csv_path.write_text("", encoding="utf-8")

    json_path.write_text(json.dumps(rows, indent=2), encoding="utf-8")
    return csv_path, json_path


# Persist each phase's result rows as CSV + JSON under RESULTS_DIR. The final
# call is a bare expression, so its (csv_path, json_path) tuple is what the
# notebook displays as the cell output.
save_results("phase_a_extraction", extraction_results)
save_results("phase_b_preprocessing", preprocess_results)
save_results("phase_c_fusion", fusion_results)
save_results("phase_d_causal", causal_results)


(PosixPath('/Users/venturit/Documents/GitHub/FYP/CausalX-Project/backend/data/results/phase_d_causal.csv'),
 PosixPath('/Users/venturit/Documents/GitHub/FYP/CausalX-Project/backend/data/results/phase_d_causal.json'))

## 6. Summary Table

Merge results into one table for comparison.


In [None]:
def _best(items, primary, secondary=None, minimize=False, minimize_secondary=None):
    """Return the item with the best ``primary`` metric, tie-broken by ``secondary``.

    Args:
        items: Result dicts.
        primary: Key of the main metric.
        secondary: Optional tie-break key; a missing/None value counts as 0.0.
        minimize: If True, a lower ``primary`` is better.
        minimize_secondary: If True, a lower ``secondary`` is better. Defaults
            to ``minimize``; pass True for cost-like tie-breaks.

    Returns:
        The best dict (the first one on an exact tie), or ``None`` if ``items``
        is empty. Items whose primary metric is ``None`` rank below every item
        that has a value, whichever direction is optimized.
    """
    if not items:
        return None
    if minimize_secondary is None:
        minimize_secondary = minimize
    p_sign = -1.0 if minimize else 1.0
    s_sign = -1.0 if minimize_secondary else 1.0

    def score(d):
        p = d.get(primary)
        s = d.get(secondary) if secondary else None
        if p is None:
            # Always return a tuple: mixing a bare float sentinel with tuples
            # makes max() raise TypeError.
            return (0, 0.0, 0.0)
        return (1, p_sign * p, s_sign * (s if s is not None else 0.0))

    return max(items, key=score)


best_extraction = _best(extraction_results, "detection_rate", secondary="fps") or {}
# On equal AUC, prefer the cheaper preprocessing option.
best_preprocess = _best(preprocess_results, "val_auc", secondary="compute_cost", minimize_secondary=True) or {}
best_fusion = _best(fusion_results, "val_auc", secondary="swap_robustness") or {}
best_causal = _best(causal_results, "intervention_sensitivity", secondary="mediation_effect") or {}

recommended_pipeline = {
    "extraction": best_extraction,
    "preprocessing": best_preprocess,
    "fusion": best_fusion,
    "causal_explainability": best_causal,
}

summary = {
    "extraction": extraction_results,
    "preprocessing": preprocess_results,
    "fusion": fusion_results,
    "causal": causal_results,
    "recommended": recommended_pipeline,
}

summary_path = RESULTS_DIR / "benchmark_summary.json"
pipeline_path = RESULTS_DIR / "pipeline_recommendation.json"

summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
pipeline_path.write_text(json.dumps(recommended_pipeline, indent=2), encoding="utf-8")

summary_path


PosixPath('/Users/venturit/Documents/GitHub/FYP/CausalX-Project/backend/data/results/benchmark_summary.json')

In [None]:
def _rank(rows, primary, secondary=None, reverse=True, secondary_reverse=None):
    """Return ``rows`` sorted by ``primary`` and then by ``secondary``.

    Args:
        rows: Result dicts.
        primary: Key of the main metric.
        secondary: Optional tie-break key; a missing/None value counts as 0.0.
        reverse: If True (default), a higher ``primary`` ranks first.
        secondary_reverse: If True, a higher ``secondary`` ranks first.
            Defaults to ``reverse``; pass False for cost-like tie-breaks.

    Returns:
        A new sorted list. Rows whose primary metric is ``None`` always come
        last. Rows with equal keys keep their input order.
    """
    if secondary_reverse is None:
        secondary_reverse = reverse
    p_sign = -1.0 if reverse else 1.0
    s_sign = -1.0 if secondary_reverse else 1.0

    def key_fn(item):
        p = item.get(primary)
        s = item.get(secondary) if secondary else None
        if p is None:
            # Tuple sentinel keeps keys mutually comparable (a bare number vs a
            # tuple would raise TypeError) and sorts missing rows last.
            return (1, 0.0, 0.0)
        return (0, p_sign * p, s_sign * (s if s is not None else 0.0))

    return sorted(rows, key=key_fn)

inventory_rankings = {
    "extraction": _rank(extraction_results, "detection_rate", "fps"),
    # On equal AUC, rank the cheaper preprocessing option first.
    "preprocessing": _rank(preprocess_results, "val_auc", "compute_cost", secondary_reverse=False),
    "fusion": _rank(fusion_results, "val_auc", "swap_robustness"),
    "causal": _rank(causal_results, "intervention_sensitivity", "mediation_effect"),
}

inv_path = RESULTS_DIR / "inventory_rankings.json"
inv_path.write_text(json.dumps(inventory_rankings, indent=2), encoding="utf-8")
inventory_rankings


{'extraction': [{'detector': 'RetinaFace',
   'landmarks': 'MediaPipe',
   'detection_rate': 0.6456,
   'landmark_stability': 0.9859,
   'fps': 9.68},
  {'detector': 'RetinaFace',
   'landmarks': 'HRNet',
   'detection_rate': 0.6456,
   'landmark_stability': 0.9859,
   'fps': 9.68},
  {'detector': 'S3FD',
   'landmarks': 'HRNet',
   'detection_rate': 0.6202,
   'landmark_stability': 0.9859,
   'fps': 9.3},
  {'detector': 'BlazeFace',
   'landmarks': 'MediaPipe',
   'detection_rate': 0.5894,
   'landmark_stability': 0.9859,
   'fps': 8.84},
  {'detector': 'MTCNN',
   'landmarks': 'dlib68',
   'detection_rate': 0.558,
   'landmark_stability': 0.986,
   'fps': 8.37},
  {'detector': 'MediaPipe',
   'landmarks': 'MediaPipe',
   'detection_rate': 0.5023,
   'landmark_stability': 0.9865,
   'fps': 7.53}],
 'preprocessing': [{'audio_feat': 'log-mel+vad',
   'video_feat': 'raw',
   'val_auc': 0.6183,
   'val_acc': 0.9748,
   'compute_cost': 4.575},
  {'audio_feat': 'log-mel',
   'video_feat': '