In [1]:
# pip install pgmpy

In [2]:
# pip install pyvis

In [3]:
# === Import Preamble ===
import os
import math
import random
import copy
import colorsys
import itertools
from itertools import combinations
from math import log2
from contextlib import contextmanager
from collections import defaultdict, deque

import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

from pyvis.network import Network
from graphviz import Digraph

from pgmpy.models import DiscreteBayesianNetwork
from pgmpy.factors.discrete import TabularCPD, DiscreteFactor
from pgmpy.inference import VariableElimination

# ============================================================================
g = Digraph(format="pdf", graph_attr={"rankdir":"LR", "size":"11.69,8.27!"})
g.edge("root","child", label="AS: Splatter→Pass\n0.40")
g.render("tiny_test", cleanup=True)  # writes tiny_test.pdf

# Optional: Graphviz layout availability flag
try:
    from networkx.drawing.nx_pydot import graphviz_layout as gv_layout
    _HAS_GV = True
except Exception:
    _HAS_GV = False


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Constants & Helpers 
random.seed(0)

HW_MIN, HW_MAX = 0.60, 0.95    # caps for handwriting reliability
EPS            = 1e-6          # small mass to avoid zeros

H_STATES = ["h1","h2","h3","h4","h5"]
AUTHOR_CODES = ["a1","a2","a3","a4"]

def normalise(dist, eps=EPS):
    """Replace zeros with eps, then renormalise to sum=1."""
    r = [x if x > 0 else eps for x in dist]
    S = sum(r)
    return [x / S for x in r]

def make_prior_dict(keys, values):
    """Build a prior dict and assert it sums to 1."""
    d = dict(zip(keys, values))
    assert abs(sum(d.values()) - 1.0) < 1e-8, "Priors must sum to 1"
    return d

def override_mode_timing(h_list, tup):
    """Bulk override mode_timing_pr for all modes in h_list with tuple tup."""
    for h in h_list:
        mode_timing_pr[h] = tup

def override_timing_author(h_list, t, tup):
    """Bulk override timing_author_pr[(h,t)] for h in h_list with tuple tup."""
    for h in h_list:
        timing_author_pr[(h, t)] = tup


In [5]:
# Mode-of-death priors (scenarios) 
mode_keys           = [f"h{i}" for i in range(1, 6)]
mode_pr_values      = [0.10, 0.08, 0.06, 0.46, 0.30]
mode_pr_biased_vals = [0.25, 0.04, 0.04, 0.52, 0.15]
mode_pr_fair_vals   = [0.10, 0.10, 0.10, 0.60, 0.10]
mode_pr_extreme_vals= [0.40, 0.05, 0.05, 0.40, 0.10]

mode_pr        = make_prior_dict(mode_keys, mode_pr_values)
mode_pr_biased = make_prior_dict(mode_keys, mode_pr_biased_vals)
mode_pr_fair   = make_prior_dict(mode_keys, mode_pr_fair_vals)
mode_pr_extreme= make_prior_dict(mode_keys, mode_pr_extreme_vals)

prs = {
    'True-world':        mode_pr,
    'Investigator-bias': mode_pr_biased,
    'Fair-reweighted':   mode_pr_fair,
    'Extreme-bias':      mode_pr_extreme,
}
bias_scenarios = list(prs.items())

# PARAMETER DEFINITIONS 

# 1a) Mode-of-death priors (duplicate kept; original order preserved)
mode_pr = {
    "h1": 0.10,
    "h2": 0.08,
    "h3": 0.06,
    "h4": 0.46,
    "h5": 0.30,
}

# 1b) Letter timing CPTs: (P(before), P(after))
mode_timing_pr = {
    "h1": (0.90, 0.10),
    "h2": (0.30, 0.70),
    "h3": (0.30, 0.70),  # identical to h2
    "h4": (0.80, 0.20),
    "h5": (0.80, 0.20),
}

# 1c) Author CPTs given (mode, timing): (P(victim), P(offender))
timing_author_pr = {
    # h1
    ("h1","t1"): (0.98, 0.02),
    ("h1","t2"): (0.10, 0.90),
    # h2
    ("h2","t1"): (0.98, 0.02),
    ("h2","t2"): (0.10, 0.90),
    # h3 (copy of h2)
    ("h3","t1"): (0.98, 0.02),
    ("h3","t2"): (0.10, 0.90),
    # h4
    ("h4","t1"): (0.90, 0.10),
    ("h4","t2"): (0.05, 0.95),
    # h5
    ("h5","t1"): (0.98, 0.02),
    ("h5","t2"): (0.02, 0.98),
}

# 1d) Raw 10-bin blood patterns | (h,t,a)
blood_pr_raw = {
    # h1
    ("h1","t1","a1"): [0.04,0.01,0.035,0.015,0.045,0.005,0.095,0.005,0.075,0.675],
    ("h1","t1","a2"): [0.40,0.05,0.20, 0.10, 0.10, 0.05, 0.05, 0.03, 0.02, 0.00],
    ("h1","t2","a3"): [0.30,0.10,0.20, 0.05, 0.15, 0.05, 0.10, 0.03, 0.02, 0.00],
    ("h1","t2","a4"): [0.20,0.15,0.15, 0.10, 0.10, 0.10, 0.10, 0.05, 0.05, 0.00],

    # h2
    ("h2","t1","a1"): [0.05,0.03,0.04,0.02,0.06,0.02,0.08,0.01,0.05,0.64],
    ("h2","t1","a2"): [0.35,0.05,0.25,0.10,0.10,0.05,0.05,0.02,0.03,0.00],
    ("h2","t2","a3"): [0.25,0.10,0.15,0.10,0.15,0.05,0.10,0.05,0.05,0.00],
    ("h2","t2","a4"): [0.15,0.15,0.10,0.10,0.10,0.10,0.15,0.10,0.05,0.00],

    # h3 (copy of h2)
    ("h3","t1","a1"): [0.05,0.03,0.04,0.02,0.06,0.02,0.08,0.01,0.05,0.64],
    ("h3","t1","a2"): [0.35,0.05,0.25,0.10,0.10,0.05,0.05,0.02,0.03,0.00],
    ("h3","t2","a3"): [0.25,0.10,0.15,0.10,0.15,0.05,0.10,0.05,0.05,0.00],
    ("h3","t2","a4"): [0.15,0.15,0.10,0.10,0.10,0.10,0.15,0.10,0.05,0.00],

    # h4
    ("h4","t1","a1"): [0.03,0.01,0.02,0.01,0.04,0.01,0.04,0.01,0.03,0.80],
    ("h4","t1","a2"): [0.25,0.05,0.20,0.10,0.10,0.05,0.05,0.02,0.03,0.15],
    ("h4","t2","a3"): [0.10,0.10,0.10,0.10,0.10,0.10,0.10,0.10,0.10,0.10],
    ("h4","t2","a4"): [0.10,0.15,0.10,0.10,0.10,0.10,0.10,0.10,0.05,0.10],

    # h5
    ("h5","t1","a1"): [0.02,0.005,0.015,0.005,0.03,0.005,0.08,0.001,0.04,0.799],
    ("h5","t1","a2"): [0.20,0.05,0.15,0.10,0.10,0.05,0.05,0.03,0.02,0.25],
    ("h5","t2","a3"): [0.25,0.10,0.20,0.10,0.15,0.05,0.10,0.02,0.03,0.00],
    ("h5","t2","a4"): [0.15,0.15,0.10,0.10,0.10,0.10,0.15,0.10,0.05,0.00],
}

# 1e) Handwriting reliability offsets
base_hw = 0.90
mode_offset   = {"h1":+0.02,"h2":+0.01,"h3":+0.01,"h4":-0.02,"h5":-0.03}
author_offset = {"a1":+0.01,"a2":-0.01,"a3":+0.01,"a4":-0.01}

# 1f) Raw DNA accuracies | (h,t,a,e)
dna_pr_raw = {
    # h1
    ("h1","t1","a1","e1"): (0.995,0.005), ("h1","t1","a1","e3"): (0.990,0.010),
    ("h1","t1","a1","e5"): (0.993,0.007), ("h1","t1","a1","e7"): (0.992,0.008),
    ("h1","t1","a1","e9"): (0.994,0.006),
    ("h1","t1","a2","e11"): (0.980,0.020), ("h1","t1","a2","e13"): (0.985,0.015),
    ("h1","t1","a2","e15"): (0.975,0.025), ("h1","t1","a2","e17"): (0.982,0.018),
    ("h1","t1","a2","e19"): (0.988,0.012),
    ("h1","t2","a3","e21"): (0.960,0.040), ("h1","t2","a3","e23"): (0.965,0.035),
    ("h1","t2","a3","e25"): (0.958,0.042), ("h1","t2","a3","e27"): (0.962,0.038),
    ("h1","t2","a3","e29"): (0.967,0.033),
    ("h1","t2","a4","e31"): (0.940,0.060), ("h1","t2","a4","e33"): (0.948,0.052),
    ("h1","t2","a4","e35"): (0.943,0.057), ("h1","t2","a4","e37"): (0.950,0.050),
    ("h1","t2","a4","e39"): (0.945,0.055),

    # h2
    ("h2","t1","a1","e1"): (0.990,0.010), ("h2","t1","a1","e3"): (0.992,0.008),
    ("h2","t1","a1","e5"): (0.988,0.012), ("h2","t1","a1","e7"): (0.991,0.009),
    ("h2","t1","a1","e9"): (0.989,0.011),
    ("h2","t1","a2","e11"): (0.970,0.030), ("h2","t1","a2","e13"): (0.975,0.025),
    ("h2","t1","a2","e15"): (0.965,0.035), ("h2","t1","a2","e17"): (0.972,0.028),
    ("h2","t1","a2","e19"): (0.968,0.032),
    ("h2","t2","a3","e21"): (0.950,0.050), ("h2","t2","a3","e23"): (0.955,0.045),
    ("h2","t2","a3","e25"): (0.948,0.052), ("h2","t2","a3","e27"): (0.952,0.048),
    ("h2","t2","a3","e29"): (0.957,0.043),
    ("h2","t2","a4","e31"): (0.930,0.070), ("h2","t2","a4","e33"): (0.938,0.062),
    ("h2","t2","a4","e35"): (0.933,0.067), ("h2","t2","a4","e37"): (0.940,0.060),
    ("h2","t2","a4","e39"): (0.935,0.065),

    # h3 (copy of h2)
    ("h3","t1","a1","e1"): (0.990,0.010), ("h3","t1","a1","e3"): (0.992,0.008),
    ("h3","t1","a1","e5"): (0.988,0.012), ("h3","t1","a1","e7"): (0.991,0.009),
    ("h3","t1","a1","e9"): (0.989,0.011),
    ("h3","t1","a2","e11"): (0.970,0.030), ("h3","t1","a2","e13"): (0.975,0.025),
    ("h3","t1","a2","e15"): (0.965,0.035), ("h3","t1","a2","e17"): (0.972,0.028),
    ("h3","t1","a2","e19"): (0.968,0.032),
    ("h3","t2","a3","e21"): (0.950,0.050), ("h3","t2","a3","e23"): (0.955,0.045),
    ("h3","t2","a3","e25"): (0.948,0.052), ("h3","t2","a3","e27"): (0.952,0.048),
    ("h3","t2","a3","e29"): (0.957,0.043),
    ("h3","t2","a4","e31"): (0.930,0.070), ("h3","t2","a4","e33"): (0.938,0.072),
    ("h3","t2","a4","e35"): (0.923,0.077), ("h3","t2","a4","e37"): (0.930,0.070),
    ("h3","t2","a4","e39"): (0.925,0.075),

    # h4
    ("h4","t1","a1","e1") : (0.980, 0.020),
    ("h4","t1","a1","e3") : (0.982, 0.018),
    ("h4","t1","a1","e5") : (0.978, 0.022),
    ("h4","t1","a1","e7") : (0.981, 0.019),
    ("h4","t1","a1","e9") : (0.979, 0.021),
    ("h4","t1","a2","e11"): (0.950, 0.050),
    ("h4","t1","a2","e13"): (0.955, 0.045),
    ("h4","t1","a2","e15"): (0.945, 0.055),
    ("h4","t1","a2","e17"): (0.952, 0.048),
    ("h4","t1","a2","e19"): (0.948, 0.052),
    ("h4","t2","a3","e21"): (0.940, 0.060),
    ("h4","t2","a3","e23"): (0.945, 0.055),
    ("h4","t2","a3","e25"): (0.938, 0.062),
    ("h4","t2","a3","e27"): (0.942, 0.058),
    ("h4","t2","a3","e29"): (0.947, 0.053),
    ("h4","t2","a4","e31"): (0.920, 0.080),
    ("h4","t2","a4","e33"): (0.928, 0.072),
    ("h4","t2","a4","e35"): (0.923, 0.077),
    ("h4","t2","a4","e37"): (0.930, 0.070),
    ("h4","t2","a4","e39"): (0.925, 0.075),

    # h5
    ("h5","t1","a1","e1") : (0.975, 0.025),
    ("h5","t1","a1","e3") : (0.978, 0.022),
    ("h5","t1","a1","e5") : (0.973, 0.027),
    ("h5","t1","a1","e7") : (0.976, 0.024),
    ("h5","t1","a1","e9") : (0.974, 0.026),
    ("h5","t1","a2","e11"): (0.950, 0.050),
    ("h5","t1","a2","e13"): (0.955, 0.045),
    ("h5","t1","a2","e15"): (0.945, 0.055),
    ("h5","t1","a2","e17"): (0.952, 0.048),
    ("h5","t1","a2","e19"): (0.948, 0.052),
    ("h5","t2","a3","e21"): (0.940, 0.060),
    ("h5","t2","a3","e23"): (0.945, 0.055),
    ("h5","t2","a3","e25"): (0.938, 0.062),
    ("h5","t2","a3","e27"): (0.942, 0.058),
    ("h5","t2","a3","e29"): (0.947, 0.053),
    ("h5","t2","a4","e31"): (0.920, 0.080),
    ("h5","t2","a4","e33"): (0.928, 0.072),
    ("h5","t2","a4","e35"): (0.923, 0.077),
    ("h5","t2","a4","e37"): (0.930, 0.070),
    ("h5","t2","a4","e39"): (0.925, 0.075),
}

# After dna_pr_raw is constructed, swap for offender-authored a2/a4
for (h, t, a, e), (pV, pO) in list(dna_pr_raw.items()):
    if a in ('a2', 'a4'):     # offender wrote
        dna_pr_raw[(h, t, a, e)] = (pO, pV)   # make DNA=O large when offender wrote

# Timing splits when a letter is NOT found (L='no')
mode_timing_pr_Lno = {
    "h1": (0.85, 0.15),
    "h2": (0.55, 0.45),
    "h3": (0.55, 0.45),
    "h4": (0.65, 0.35),
    "h5": (0.80, 0.20),
}
# Timing splits when a letter IS found (L='yes')
mode_timing_pr_Lyes = mode_timing_pr

AUTHORS_BY_T = {'t1': ['a1','a2'], 't2': ['a3','a4']}


In [6]:
# Normalise & build binary blood_yes_no_pr 
blood_pr = { key: normalise(raw) for key, raw in blood_pr_raw.items() }
for dist10 in blood_pr.values():
    assert abs(sum(dist10) - 1.0) < 1e-8

blood_yes_no_pr = {}
for (h, t, a), dist10 in blood_pr.items():
    p_yes = sum(dist10[i] for i in range(0, 10, 2))
    blood_yes_no_pr[(h, t, a)] = (p_yes, 1.0 - p_yes)

# (g) Audit-node CPTs (keep AL simple; make AS depend on BP)
alpha_S, beta_S = 0.80, 0.90
alpha_F, beta_F = 0.85, 0.95
alpha_E, beta_E = 0.90, 0.98

mode_offset["h3"] = mode_offset["h2"]

# Build handwriting_pr
handwriting_pr = {}
for (h, t, a) in blood_pr_raw:
    p = base_hw + mode_offset[h] + author_offset[a]
    p = min(max(p, HW_MIN), HW_MAX)
    handwriting_pr[(h, t, a)] = (p, 1.0 - p)

# normalised DNA
dna_pr = {}
for key, (p0, p1) in dna_pr_raw.items():
    S = p0 + p1
    dna_pr[key] = (p0 / S, p1 / S)


In [7]:
# Group-wise overrides 
override_mode_timing(["h4","h5"], (0.80, 0.20))
override_timing_author(["h4","h5"], "t1", (0.90, 0.10))
override_timing_author(["h4","h5"], "t2", (0.05, 0.95))

# Ensure h3 shares h2 CPTs (run once here)
mode_timing_pr["h3"] = mode_timing_pr["h2"]
for t in ("t1","t2"):
    timing_author_pr[("h3", t)] = timing_author_pr[("h2", t)]

# Copy raw tables h2 -> h3 (if not already)
for (h, t, a), raw in list(blood_pr_raw.items()):
    if h == "h2":
        blood_pr_raw[("h3", t, a)] = raw.copy()

for (h, t, a, e), pair in list(dna_pr_raw.items()):
    if h == "h2":
        dna_pr_raw[("h3", t, a, e)] = pair

# Rebuild blood_yes_no_pr after copies
blood_pr = {k: normalise(v) for k, v in blood_pr_raw.items()}  # each sums to 1
blood_yes_no_pr = {}
for (h, t, a), dist10 in blood_pr.items():
    p_yes = sum(dist10[i] for i in range(0, 10, 2))   # even bins = "blood yes"
    blood_yes_no_pr[(h, t, a)] = (p_yes, 1.0 - p_yes)

# 3) Handwriting per (h,t,a)
HW_MIN, HW_MAX = 0.60, 0.95
base_hw = 0.90
mode_offset["h3"] = mode_offset["h2"]  # ensure same offset as h2

handwriting_pr = {}
for (h, t, a) in blood_pr_raw:
    p = base_hw + mode_offset[h] + author_offset[a]
    p = min(max(p, HW_MIN), HW_MAX)
    handwriting_pr[(h, t, a)] = (p, 1.0 - p)

# 4) Timing splits conditioned on L
mode_timing_pr_Lyes = mode_timing_pr  # alias to be explicit

# 5) Completeness checks
missing_author = [(h, t) for h in H_STATES for t in ['t1','t2']
                  if (h, t) not in timing_author_pr]
assert not missing_author, f"timing_author_pr missing: {missing_author}"

missing_blood = [
    (h, t, a)
    for h in H_STATES
    for t in ['t1','t2']
    for a in AUTHORS_BY_T[t]
    if (h, t, a) not in blood_yes_no_pr
]
assert not missing_blood, f"blood_yes_no_pr missing: {missing_blood}"

missing_hw = [
    (h, t, a)
    for h in H_STATES
    for t in ['t1','t2']
    for a in AUTHORS_BY_T[t]
    if (h, t, a) not in handwriting_pr
]
assert not missing_hw, f"handwriting_pr missing: {missing_hw}"

missing_Lsplits = [h for h in H_STATES
                   if (h not in mode_timing_pr_Lyes) or (h not in mode_timing_pr_Lno)]
assert not missing_Lsplits, f"Timing splits missing for some h in Lyes/Lno: {missing_Lsplits}"


In [8]:
# 2.1 Mappings between codes and BN variables/states
H_MAP = {  # h-codes -> H states used in the notebook
    "h1":"O_delib", "h2":"O_self", "h3":"O_acc", "h4":"V_acc", "h5":"Suicide"
}
T_MAP = {"t1":"Before", "t2":"After"}        # timing
W_STATES = ["Victim","Offender"]

def a_code_for(W, t):
    """Map (writer, timing) -> 'a' code used in blood/DNA tables."""
    return {"Victim":{"t1":"a1","t2":"a3"},
            "Offender":{"t1":"a2","t2":"a4"}}[W][t]

# 2.2 Provide P(Lfound|H)
LFOUND_YES = {
    "h1":0.70, "h2":0.50, "h3":0.50, "h4":0.90, "h5":0.95
}

# 2.3 Helper to normalise a vector without epsilon smoothing
def nz_normalise(v):
    """normalise a vector to sum to one; requires strictly positive sum."""
    s = sum(v)
    assert s > 0, "Zero-sum vector"
    return [x/s for x in v]


In [9]:
# Build CPDs from edge tables

def build_cpds_from_edge_tables_v2():
    """
    Build path-aware CPDs that match a probability-tree/CEG view.
    Returns dict of BN loader keys.
    """
    H_ord = ["h1","h2","h3","h4","h5"]
    T_ord = ["t1","t2"]  # t1=Before, t2=After
    W_ord = ["Victim","Offender"]
    Lf_ord = ["No","Yes"]
    Blood_states = ["No","Yes"]
    BP_states = ["Fingerprint-like","Spatter"]
    DNA_states = ["V","O"]

    # 0) Prior on H
    prior_H = [mode_pr[h] for h in H_ord]

    # 1) P(Lfound | H)
    Lfound_table = [
        [1 - LFOUND_YES[h] for h in H_ord],   # No
        [    LFOUND_YES[h] for h in H_ord],   # Yes
    ]

    # 2) P(Ltime | H, Lfound)
    cols_before, cols_after = [], []
    for h in H_ord:
        for lf in Lf_ord:
            b, a = (mode_timing_pr_Lno[h] if lf == "No" else mode_timing_pr_Lyes[h])
            cols_before.append(b)
            cols_after.append(a)
    Ltime_table = [cols_before, cols_after]  # rows Before, After

    # 3) P(W | H, Ltime)
    WV, WO = [], []
    for h in H_ord:
        for t in T_ord:
            pV, pO = timing_author_pr[(h, t)]
            WV.append(pV); WO.append(pO)
    W_table = [WV, WO]

    # Local helper
    def a_code_for(W, t):
        return {"Victim":{"t1":"a1","t2":"a3"},
                "Offender":{"t1":"a2","t2":"a4"}}[W][t]

    # 4) P(Blood | H, Ltime, W)
    blood_rows_no, blood_rows_yes = [], []
    for h in H_ord:
        for t in T_ord:
            for W in W_ord:
                a = a_code_for(W, t)
                p_yes, p_no = blood_yes_no_pr[(h, t, a)]
                blood_rows_no.append(p_no)
                blood_rows_yes.append(p_yes)
    Blood_table = [blood_rows_no, blood_rows_yes]

    # 5) P(BP | Blood)
    PHI_BP = 0.60
    BP_table = [
        [1.0,        PHI_BP],      # Fingerprint-like
        [0.0, 1.0 - PHI_BP],       # Spatter
    ]

    # 6) P(DNA | H, Ltime, W) by averaging across e
    DNA_V, DNA_O = [], []
    e_groups = {}
    for (h, t, a, e), (p0, p1) in dna_pr_raw.items():
        e_groups.setdefault((h, t, a), []).append((p0, p1))
    for h in H_ord:
        for t in T_ord:
            for W in W_ord:
                a = a_code_for(W, t)
                pairs = e_groups.get((h, t, a), [])
                if not pairs:
                    pO = 0.5
                else:
                    pO = float(np.mean([p1/(p0+p1) for (p0, p1) in pairs]))
                pV = 1.0 - pO
                DNA_V.append(pV); DNA_O.append(pO)
    DNA_table = [DNA_V, DNA_O]

    # 7) P(AL | DNA)
    AL_PASS = 0.90
    AL_table = [
        [1 - AL_PASS, 1 - AL_PASS],  # Fail
        [    AL_PASS,     AL_PASS],  # Pass
    ]

    # 8) P(AS | BP)
    AS_PASS = 0.90
    AS_table = [
        [1 - AS_PASS, 1 - AS_PASS],  # Fail
        [    AS_PASS,     AS_PASS],  # Pass
    ]

    # 9) P(Hand | W) (marginalised reliability)
    def p_hand_V_given_W(W):
        num = den = 0.0
        for h in H_ord:
            pH = mode_pr[h]
            for lf in Lf_ord:
                pLf = (1 - LFOUND_YES[h]) if lf == "No" else LFOUND_YES[h]
                b, a = (mode_timing_pr_Lno[h] if lf == "No" else mode_timing_pr_Lyes[h])
                for t, pT in zip(T_ord, [b, a]):
                    pV, pO = timing_author_pr[(h, t)]
                    w_W = pV if W == "Victim" else pO
                    wgt = pH * pLf * pT * w_W
                    if wgt == 0.0:
                        continue
                    a_code = a_code_for(W, t)
                    pV0, pO0 = handwriting_pr[(h, t, a_code)]
                    num += wgt * pV0
                    den += wgt
        return num / den if den > 0 else 0.5

    Hand_V, Hand_O = [], []
    for W in W_ord:
        pV = p_hand_V_given_W(W)
        Hand_V.append(pV); Hand_O.append(1.0 - pV)
    Hand_table = [Hand_V, Hand_O]

    return {
        'H_prior'       : prior_H,
        'Lfound|H'      : Lfound_table,
        'Ltime|H,Lfound': Ltime_table,
        'W|H,Ltime'     : W_table,
        'Blood|H,Ltime,W': Blood_table,
        'BP|Blood'      : BP_table,
        'DNA|H,Ltime,W' : DNA_table,
        'AL|DNA'        : AL_table,
        'AS|BP'         : AS_table,
        'Hand|W'        : Hand_table,
    }

# Build once
EDGE_CPDS = build_cpds_from_edge_tables_v2()


In [10]:
# For consistent results
random.seed(7)
np.random.seed(7)

# Variables, states, display names, orders


VAR_STATES = {
    'H'     : ['O_delib','O_self','O_acc','V_acc','Suicide'],
    'Lfound': ['No','Yes'],
    'Ltime' : ['Before','After'],
    'O'     : ['Absent','Present'],
    'W'     : ['Victim','Offender'],
    'Blood' : ['No','Yes'],
    'BP'    : ['Fingerprint-like','Spatter'],
    'DNA'   : ['V','O'],
    'AL'    : ['Fail','Pass'],
    'AS'    : ['Fail','Pass'],
    'Hand'  : ['Victim','Offender'],
}

STATE_LABEL = {
    'H': {'O_delib':'Offender deliberate','O_self':'Offender self-defense',
          'O_acc':'Offender accidental','V_acc':'Victim accidental','Suicide':'Suicide'},
    'Lfound': {'No':'Letter not found','Yes':'Letter found'},
    'Ltime': {'Before':'Letter before death','After':'Letter after death'},
    'O': {'Absent':'Offender absent','Present':'Offender present'},
    'W': {'Victim':'Victim wrote','Offender':'Offender wrote'},
    'Blood': {'No':'NBlood','Yes':'Blood'},
    'BP': {'Fingerprint-like':'Blood pattern: Fingerprint-like','Spatter':'Blood pattern: Spatter'},
    'DNA': {'V':'DNA: Victim','O':'DNA: Offender'},
    'AL': {'Fail':'Audit lab: Fail','Pass':'Audit lab: Pass'},
    'AS': {'Fail':'Audit scene: Fail','Pass':'Audit scene: Pass'},
    'Hand': {'Victim':'Handwriting: Victim','Offender':'Handwriting: Offender'},
}

SHORT_LABEL = {
  'H': {
    'O_delib':'O delib','O_self':'O self-def','O accident':'O accident',
    'V_acc':'V accident','Suicide':'Suicide'
  },
  'Lfound': {'No':'L? No','Yes':'L? Yes'},
  'Ltime': {'Before':'t: Before','After':'t: After'},
  'O': {'Absent':'O: Absent','Present':'O: Present'},
  'W': {'Victim':'W: Victim','Offender':'W: Offender'},
  'Blood': {'No':'Blood: No','Yes':'Blood: Yes'},
  'BP': {'Fingerprint-like':'BP: Fingerprint','Spatter':'BP: Spat'},
  'DNA': {'V':'DNA: V','O':'DNA: O'},
  'AL': {'Fail':'AL: Fail','Pass':'AL: Pass'},
  'AS': {'Fail':'AS: Fail','Pass':'AS: Pass'},
  'Hand': {'Victim':'Hand: V','Offender':'Hand: O'},
}

# Event-tree orders (display only)
ORDER_HYPOTHESIS = ('H','Lfound','Ltime','O','W','Blood','BP','AS','DNA','AL','Hand')
ORDER_CHRONO  = ('Lfound','Ltime','O','W','H','Blood','BP','AS','DNA','AL','Hand')

AUDIT_VARS = {'AL','AS'}
SINK_VAR   = 'Hand'


In [11]:
# Build BN model (path-aware CPDs)

def make_model_once(edge_cpds=EDGE_CPDS):
    """Build a BN that reflects the path-aware CPDs. Used for VE queries."""
    model = DiscreteBayesianNetwork([
        # letter + timing
        ('H','Lfound'),
        ('H','Ltime'), ('Lfound','Ltime'),

        # presence and authorship
        ('H','O'), ('Ltime','O'),
        ('O','W'), ('H','W'), ('Ltime','W'),

        # scene + lab
        ('H','Blood'), ('Ltime','Blood'), ('W','Blood'),
        ('Blood','BP'),
        ('BP','AS'),
        ('O','DNA'), ('W','DNA'),
        ('DNA','AL'),
        ('W','Hand'),
    ])

    def _to_table(var, table, evidence):
        vcard = len(VAR_STATES[var])
        ecards = [len(VAR_STATES[e]) for e in (evidence or [])]
        prod_e = int(np.prod(ecards)) if ecards else 1
        arr = np.asarray(table, dtype=float)
        if arr.ndim == 1:
            arr = arr.reshape((vcard, 1))
        else:
            if arr.shape == (1, vcard):
                arr = arr.T
            if arr.size == vcard * prod_e and arr.shape != (vcard, prod_e):
                arr = arr.reshape((vcard, prod_e), order='C')
        assert arr.shape == (vcard, prod_e), f"{var} has shape {arr.shape}, expected {(vcard, prod_e)}"
        return arr.tolist(), ecards

    def CPD(var, key, evidence=None):
        evidence = evidence or []
        table, ecards = _to_table(var, edge_cpds[key], evidence)
        return TabularCPD(
            variable=var,
            variable_card=len(VAR_STATES[var]),
            values=table,
            evidence=evidence,
            evidence_card=ecards,
            state_names={var: VAR_STATES[var], **{e: VAR_STATES[e] for e in evidence}},
        )

    # Pull original per-(H,Ltime) P(W=Offender) and calibrate O, W
    inv_H = {v: k for k, v in H_MAP.items()}
    inv_T = {'Before': 't1', 'After': 't2'}

    p_old_off_cols = [
        timing_author_pr[(inv_H[h_lab], inv_T[t_lab])][1]
        for h_lab in VAR_STATES['H']
        for t_lab in VAR_STATES['Ltime']
    ]

    eps_w = 1e-6
    theta = min(1-1e-6, max(p_old_off_cols))

    def clip(x, lo=1e-6, hi=1-1e-6):
        return max(lo, min(hi, x))

    omega_cols = [clip(p_off / theta) for p_off in p_old_off_cols]

    O_table = [
        [1.0 - om for om in omega_cols],
        [      om for om in omega_cols],
    ]

    # Build W table interleaving O=Absent, O=Present per (H,Ltime)
    W_cols_V_abs, W_cols_O_abs, W_cols_V_pre, W_cols_O_pre = [], [], [], []
    for _ in omega_cols:
        W_cols_V_abs.append(1.0 - eps_w); W_cols_O_abs.append(eps_w)
        W_cols_V_pre.append(1.0 - theta); W_cols_O_pre.append(theta)

    W_table = [
        [v for i in range(len(omega_cols)) for v in (W_cols_V_abs[i], W_cols_V_pre[i])],
        [o for i in range(len(omega_cols)) for o in (W_cols_O_abs[i], W_cols_O_pre[i])],
    ]

    cpd_O = TabularCPD(
        variable='O', variable_card=2, values=O_table,
        evidence=['H','Ltime'], evidence_card=[5,2],
        state_names={'O':VAR_STATES['O'], 'H':VAR_STATES['H'], 'Ltime':VAR_STATES['Ltime']}
    )
    cpd_W = TabularCPD(
        variable='W', variable_card=2, values=W_table,
        evidence=['H','Ltime','O'], evidence_card=[5,2,2],
        state_names={'W':VAR_STATES['W'], 'H':VAR_STATES['H'], 'Ltime':VAR_STATES['Ltime'], 'O':VAR_STATES['O']}
    )

    vals = [
        [0.98, 0.20, 0.85, 0.10],  # P(DNA='V' | O,W)
        [0.02, 0.80, 0.15, 0.90],  # P(DNA='O' | O,W)
    ]
    cpd_DNA = TabularCPD(
        variable='DNA', variable_card=2, values=vals,
        evidence=['O','W'], evidence_card=[2,2],
        state_names={'DNA':VAR_STATES['DNA'], 'O':VAR_STATES['O'], 'W':VAR_STATES['W']}
    )

    cpd_H       = CPD('H',       'H_prior')
    cpd_Lfound  = CPD('Lfound',  'Lfound|H',        evidence=['H'])
    cpd_Ltime   = CPD('Ltime',   'Ltime|H,Lfound',  evidence=['H','Lfound'])
    cpd_Blood   = CPD('Blood',   'Blood|H,Ltime,W', evidence=['H','Ltime','W'])
    cpd_BP      = CPD('BP',      'BP|Blood',        evidence=['Blood'])
    cpd_AL      = CPD('AL',      'AL|DNA',          evidence=['DNA'])
    cpd_AS      = CPD('AS',      'AS|BP',           evidence=['BP'])
    cpd_Hand    = CPD('Hand',    'Hand|W',          evidence=['W'])

    model.add_cpds(cpd_H, cpd_Lfound, cpd_Ltime, cpd_O, cpd_W,
                   cpd_Blood, cpd_BP, cpd_DNA, cpd_AL, cpd_AS, cpd_Hand)
    assert model.check_model()
    return model

# Chronology overlays (joint-preserving reparameterisation)

def _nz_norm(a, axis=-1):
    """normalise numpy array along axis with safe zero handling."""
    s = a.sum(axis=axis, keepdims=True)
    s[s == 0] = 1.0
    return a / s

def compute_chrono_overlays_from_model(model):
    """Compute P(Lfound), P(Ltime|Lfound), P(O|Ltime,Lfound), P(W|O,Ltime,Lfound),
    and P(H|Lfound,Ltime,O,W) directly from the BN CPDs.
    """
    H = len(VAR_STATES['H'])
    Lf = len(VAR_STATES['Lfound'])
    Lt = len(VAR_STATES['Ltime'])
    O  = len(VAR_STATES['O'])
    W  = len(VAR_STATES['W'])

    P_H = np.array(model.get_cpds('H').get_values(), dtype=float).reshape(H)
    P_Lf_given_H = np.array(model.get_cpds('Lfound').get_values(), dtype=float)
    P_Lt_given_HLf = np.array(model.get_cpds('Ltime').get_values(), dtype=float)
    P_O_given_HLt  = np.array(model.get_cpds('O').get_values(), dtype=float)
    P_W_given_HLtO = np.array(model.get_cpds('W').get_values(), dtype=float)

    P_Lf_given_H = P_Lf_given_H.reshape(Lf, H)
    P_Lt_given_HLf = P_Lt_given_HLf.reshape(Lt, H, Lf)
    P_O_given_HLt  = P_O_given_HLt.reshape(O,  H, Lt)
    P_W_given_HLtO = P_W_given_HLtO.reshape(W,  H, Lt, O)

    P_Lf = (P_Lf_given_H * P_H[None, :]).sum(axis=1)

    num_H_given_Lf = P_Lf_given_H * P_H[None, :]
    P_H_given_Lf   = _nz_norm(num_H_given_Lf, axis=1)

    P_Lt_given_Lf = np.einsum('lh, thl -> lt', P_H_given_Lf, P_Lt_given_HLf)

    num_H_given_LfLt = np.einsum('h, lh, thl -> lth', P_H, P_Lf_given_H, P_Lt_given_HLf)
    P_H_given_LfLt   = _nz_norm(num_H_given_LfLt, axis=2)

    P_O_given_LtLf = np.einsum('oht, lth -> lto', P_O_given_HLt, P_H_given_LfLt)
    P_O_given_LtLf = _nz_norm(P_O_given_LtLf, axis=2)

    num_H_given_LfLtO = np.einsum('lth,oht->ltoh', P_H_given_LfLt, P_O_given_HLt)
    P_H_given_LfLtO   = _nz_norm(num_H_given_LfLtO, axis=3)

    P_W_given_OLtLf = np.einsum('whlo,ltoh->ltow', P_W_given_HLtO, P_H_given_LfLtO)
    P_W_given_OLtLf = _nz_norm(P_W_given_OLtLf, axis=3)

    num_H_given_LfLtOW = np.einsum('ltoh,whlo->ltowh', P_H_given_LfLtO, P_W_given_HLtO)
    P_H_given_LfLtOW   = _nz_norm(num_H_given_LfLtOW, axis=4)

    return dict(
        P_Lfound                 = P_Lf,
        P_Ltime_given_Lfound     = P_Lt_given_Lf,
        P_O_given_Ltime_Lfound   = P_O_given_LtLf,
        P_W_given_O_Ltime_Lfound = P_W_given_OLtLf,
        P_H_given_LfLtOW         = P_H_given_LfLtOW,
    )

def make_chrono_overrides(overlays):
    """Build callables that return probability vectors for overrides."""
    def _Lfound_vec(hist):
        return overlays['P_Lfound']
    def _Ltime_vec(hist):
        lf = _hist_get_state(hist, 'Lfound'); i = _idx('Lfound', lf)
        return overlays['P_Ltime_given_Lfound'][i, :]
    def _O_vec(hist):
        lf = _hist_get_state(hist, 'Lfound'); lt = _hist_get_state(hist, 'Ltime')
        i = _idx('Lfound', lf); j = _idx('Ltime', lt)
        return overlays['P_O_given_Ltime_Lfound'][i, j, :]
    def _W_vec(hist):
        lf = _hist_get_state(hist, 'Lfound'); lt = _hist_get_state(hist, 'Ltime'); o = _hist_get_state(hist, 'O')
        i = _idx('Lfound', lf); j = _idx('Ltime', lt); k = _idx('O', o)
        return overlays['P_W_given_O_Ltime_Lfound'][i, j, k, :]
    def _H_vec(hist):
        lf = _hist_get_state(hist, 'Lfound'); lt = _hist_get_state(hist, 'Ltime')
        o  = _hist_get_state(hist, 'O');      w  = _hist_get_state(hist, 'W')
        i = _idx('Lfound', lf); j = _idx('Ltime', lt); k = _idx('O', o); m = _idx('W', w)
        return overlays['P_H_given_LfLtOW'][i, j, k, m, :]

    return {'Lfound': _Lfound_vec, 'Ltime': _Ltime_vec, 'O': _O_vec, 'W': _W_vec, 'H': _H_vec}

# Build model, VE, and overlays
MODEL = make_model_once()
VE    = VariableElimination(MODEL)

_CHRONO_OVERLAYS = compute_chrono_overlays_from_model(MODEL)
CHRONO_OVERRIDES = make_chrono_overrides(_CHRONO_OVERLAYS)


In [12]:
# Canonical labels & BN query 

def canonical_label(var, state):
    """Format Var:State labels used on edges and histories."""
    return f"{var}:{state}"

def parse_label(label):
    """Split a canonical label into (var, state)."""
    var, state = label.split(':', 1)
    return var, state

def history_to_evidence(history):
    """Convert a tuple/list of 'Var:State' labels to a pgmpy evidence dict."""
    hyp = {}
    for lab in history:
        v, s = parse_label(lab)
        hyp[v] = s
    return hyp

def get_prob_bn(model, history, query_var, query_state):
    """P(query_var=query_state | history) via VE. If already fixed, return 0/1."""
    evidence = history_to_evidence(history)
    if query_var in evidence:
        return 1.0 if evidence[query_var] == query_state else 0.0
    q = VE.query([query_var], evidence=evidence, show_progress=False)
    states = model.get_cpds(query_var).state_names[query_var]
    idx = states.index(query_state)
    return float(q.values[idx])

def get_prob_mixed(model, history, query_var, query_state, overrides=None):
    """Use overlay vector for query_var when provided, otherwise fall back to BN."""
    if overrides and query_var in overrides:
        vec = np.asarray(overrides[query_var](history), dtype=float)
        s = float(vec.sum())
        if s <= 0:
            return get_prob_bn(model, history, query_var, query_state)
        vec = vec / s
        idx = VAR_STATES[query_var].index(query_state)
        return float(vec[idx])
    return get_prob_bn(model, history, query_var, query_state)

# Helpers for overrides
def _hist_get_state(history, var):
    for lab in history:
        v, st = parse_label(lab)
        if v == var:
            return st
    return None

def _idx(var, st):
    return VAR_STATES[var].index(st)


In [13]:
# Event-tree builder + deterministic contraction

def build_event_tree(model, order, var_states, root_history=(), overrides=None):
    """Construct the event tree following a variable order, optionally using overlays."""
    G = nx.DiGraph()
    root = root_history
    G.add_node(root, history=root)
    frontier = [root]

    for var in order:
        new_frontier = []
        for u in frontier:
            hist = list(G.nodes[u]['history'])
            hist_vars = {parse_label(l)[0] for l in hist if ':' in l}
            if var in hist_vars:
                new_frontier.append(u)
                continue
            states = var_states[var]
            probs  = [get_prob_mixed(model, hist, var, s, overrides=overrides) for s in states]
            s = sum(probs); assert s > 0
            probs = [p/s for p in probs]
            labels = [canonical_label(var, s) for s in states]
            assert len(labels) == len(set(labels))
            for state, p, lab in zip(states, probs, labels):
                if p <= 0.0:
                    continue
                v = tuple(hist + [lab])
                G.add_node(v, history=v)
                G.add_edge(u, v, prob=float(p), label=lab, var=var, state=state)
                new_frontier.append(v)
        frontier = new_frontier

    contract_prob_one_edges(G)
    return G

def contract_prob_one_edges(G, tol=1e-12):
    """Contract edges with probability approximately 1 to avoid trivial nodes."""
    changed = True
    while changed:
        changed = False
        for u in list(G.nodes()):
            succ = list(G.successors(u))
            if not succ:
                continue
            if len(succ) == 1:
                v = succ[0]
                p = G[u][v]['prob']
                if abs(p - 1.0) < tol:
                    preds = list(G.predecessors(u))
                    if preds:
                        for a in preds:
                            pa = G[a][u]['prob'] * p
                            lab = G[a][u]['label']
                            if G.has_edge(a, v):
                                G[a][v]['prob'] += pa
                            else:
                                G.add_edge(a, v, prob=pa, label=lab)
                            G.remove_edge(a, u)
                    if u in G:
                        G.remove_node(u)
                    changed = True
                    break


In [14]:
# Stage & position (CEG) via partition refinement

def compute_positions(G, prob_round=12, depth_lock=True):
    """Return maps: node->stage_id and node->position_id using refinement."""
    roots = [n for n in G.nodes() if G.in_degree(n) == 0]
    assert len(roots) == 1, "Graph must have a single root"
    root = roots[0]
    depth = {root: 0}
    for u in nx.topological_sort(G):
        for v in G.successors(u):
            depth[v] = depth[u] + 1

    def stage_sig(u):
        items = []
        for v in G.successors(u):
            lab = G[u][v]['label']
            p   = round(float(G[u][v]['prob']), prob_round)
            items.append((lab, p))
        sig = tuple(sorted(items))
        return (depth[u], sig) if depth_lock else sig

    stage = {u: stage_sig(u) for u in G.nodes()}
    pos   = {u: stage[u] for u in G.nodes()}

    changed = True
    while changed:
        changed = False
        new_pos = {}
        for u in G.nodes():
            items = []
            for v in G.successors(u):
                lab = G[u][v]['label']
                p   = round(float(G[u][v]['prob']), prob_round)
                items.append((lab, p, pos.get(v, ('LEAF',))))
            sig = (depth[u], tuple(sorted(items))) if depth_lock else tuple(sorted(items))
            new_pos[u] = sig
        if any(new_pos[u] != pos[u] for u in G.nodes()):
            pos = new_pos
            changed = True

    def compress_to_ids(mapping):
        uniq, ids = {}, {}
        k = 0
        for u, sig in mapping.items():
            if sig not in uniq:
                uniq[sig] = k; k += 1
            ids[u] = uniq[sig]
        return ids

    return compress_to_ids(stage), compress_to_ids(pos)


In [15]:
# Audit elimination (handles chains) + renormalise

def is_audit_node(node):
    """True if the last label in the node history belongs to an audit variable."""
    if not node:
        return False
    var, _ = parse_label(node[-1])
    return var in AUDIT_VARS

def strip_audits_from_history(history):
    """Return a tuple of labels with all AUDIT_VARS removed."""
    return tuple(lab for lab in history if parse_label(lab)[0] not in AUDIT_VARS)

def eliminate_audits(G):
    """Remove all audit vertices by bridging over audit-only chains."""
    H = nx.DiGraph()

    def ensure_node(hist_clean):
        if hist_clean not in H:
            H.add_node(hist_clean, history=hist_clean)

    def first_non_audit_children(x, p_in):
        stack = [(x, p_in)]
        while stack:
            u, pu = stack.pop()
            for v in G.successors(u):
                pv = pu * float(G[u][v]['prob'])
                if is_audit_node(v):
                    stack.append((v, pv))
                else:
                    yield v, pv

    root = [n for n in G.nodes() if G.in_degree(n) == 0][0]
    q = deque([root])
    seen = set([root])

    while q:
        u_orig = q.popleft()
        u_clean = strip_audits_from_history(u_orig)
        ensure_node(u_clean)

        for v in G.successors(u_orig):
            p0 = float(G[u_orig][v]['prob'])
            if is_audit_node(v):
                for w, pw in first_non_audit_children(v, p0):
                    lab = G[v][w]['label']
                    child_clean = u_clean + (lab,)
                    ensure_node(child_clean)
                    var, state = parse_label(lab)
                    if H.has_edge(u_clean, child_clean):
                        H[u_clean][child_clean]['prob'] += pw
                    else:
                        H.add_edge(u_clean, child_clean, prob=pw,
                                   label=lab, var=var, state=state)
                    if w not in seen:
                        seen.add(w); q.append(w)
            else:
                lab = G[u_orig][v]['label']
                child_clean = u_clean + (lab,)
                ensure_node(child_clean)
                var, state = parse_label(lab)
                if H.has_edge(u_clean, child_clean):
                    H[u_clean][child_clean]['prob'] += p0
                else:
                    H.add_edge(u_clean, child_clean, prob=p0,
                               label=lab, var=var, state=state)
                if v not in seen:
                    seen.add(v); q.append(v)

    root_clean = strip_audits_from_history(root)
    reach = {root_clean}
    dq = deque([root_clean])
    while dq:
        x = dq.popleft()
        for y in H.successors(x):
            if y not in reach:
                reach.add(y); dq.append(y)
    H = H.subgraph(reach).copy()

    for u in H.nodes():
        succ = list(H.successors(u))
        if succ:
            s = sum(float(H[u][w]['prob']) for w in succ)
            if s > 0:
                for w in succ:
                    H[u][w]['prob'] = float(H[u][w]['prob']) / s

    return H


In [16]:
# Invariants / sanity checks

def assert_outflows_sum_to_1(G, tol=1e-12):
    bad = []
    for u in G.nodes():
        succ = list(G.successors(u))
        if not succ:
            continue
        s = sum(float(G[u][v]['prob']) for v in succ)
        if not math.isclose(s, 1.0, rel_tol=0, abs_tol=tol):
            bad.append((u, s))
    assert not bad, f"Outgoing probs not 1.0 at {len(bad)} nodes; first: {bad[:3]}"

def assert_unique_labels_per_floret(G):
    bad = []
    for u in G.nodes():
        labs = [G[u][v]['label'] for v in G.successors(u)]
        if len(labs) != len(set(labs)):
            bad.append((u, labs))
    assert not bad, f"Duplicate labels in florets: {bad[:3]}"

def assert_same_marginals(model, G1, G2, var='H', tol=1e-9):
    """Placeholder: BN marginals should not depend on display order."""
    pass


In [17]:
# Build both trees, compute positions, remove audits, check invariants

def build_all():
    G_hyp  = build_event_tree(MODEL, ORDER_HYPOTHESIS, VAR_STATES)
    G_ch   = build_event_tree(MODEL, ORDER_CHRONO,  VAR_STATES,
                              overrides=CHRONO_OVERRIDES)

    for G in (G_hyp, G_ch):
        assert_unique_labels_per_floret(G)
        assert_outflows_sum_to_1(G)

    st_hyp, pos_hyp = compute_positions(G_hyp)
    st_ch,  pos_ch  = compute_positions(G_ch)

    G_hyp_noaudit = eliminate_audits(G_hyp)
    G_ch_noaudit  = eliminate_audits(G_ch)
    for G in (G_hyp_noaudit, G_ch_noaudit):
        assert_unique_labels_per_floret(G)
        assert_outflows_sum_to_1(G)

    return {
        'hyp' : {'G':G_hyp, 'stage':st_hyp, 'pos':pos_hyp, 'noaudit':G_hyp_noaudit},
        'chr' : {'G':G_ch,  'stage':st_ch,  'pos':pos_ch,  'noaudit':G_ch_noaudit},
    }

BUILDS = build_all()


In [18]:
# Plotting helpers 

def _tab20():
    return [
        "#1f77b4","#ff7f0e","#2ca02c","#d62728","#9467bd",
        "#8c564b","#e377c2","#7f7f7f","#bcbd22","#17becf",
        "#aec7e8","#ffbb78","#98df8a","#ff9896","#c5b0d5",
        "#c49c94","#f7b6d2","#c7c7c7","#dbdb8d","#9edae5"
    ]

def stage_node_colors(G, stage_map):
    pal = _tab20()
    colors = {}
    for n in G.nodes():
        sid = stage_map[n]
        colors[n] = pal[sid % len(pal)]
    return colors

def pos_node_colors(Q):
    pal = _tab20()
    colors = {}
    for p in Q.nodes():
        colors[p] = pal[int(p) % len(pal)]
    return colors

def _lerp(a, b, t):  # linear interpolate 0..1
    return a + (b-a)*t

def prob_to_hex(p):
    c0 = (210, 224, 255)
    c1 = ( 31,  78, 121)
    r = int(_lerp(c0[0], c1[0], p))
    g = int(_lerp(c0[1], c1[1], p))
    b = int(_lerp(c0[2], c1[2], p))
    return f"#{r:02x}{g:02x}{b:02x}"

def _edge_text(p):
    return f"{p:.2f}"

def _labels_for_nodes(H):
    lab = {}
    for n in H.nodes():
        if len(n) == 0:
            lab[n] = "Start"
        else:
            var, state = parse_label(n[-1])
            lab[n] = STATE_LABEL[var][state]
    return lab

# Graphviz (pydot) safe layout wrapper (kept; reimport guarded below)
try:
    from networkx.drawing.nx_pydot import graphviz_layout as _nx_graphviz_layout
    _HAS_GV = True
except Exception:
    _HAS_GV = False

def gv_layout(G, prog='dot', rankdir='LR'):
    """Compute Graphviz positions while avoiding ':' issues via relabeling."""
    import numpy as _np
    import networkx as _nx

    mapping = {n: f"n{i}" for i, n in enumerate(G.nodes())}

    Hsafe = _nx.DiGraph()
    Hsafe.add_nodes_from(mapping.values())
    Hsafe.add_edges_from((mapping[u], mapping[v]) for u, v in G.edges())

    Hsafe.graph.setdefault('graph', {})
    Hsafe.graph['graph']['rankdir'] = rankdir

    pos_safe = _nx_graphviz_layout(Hsafe, prog=prog)

    inv = {v: k for k, v in mapping.items()}
    return {inv[k]: _np.array(v, dtype=float) for k, v in pos_safe.items()}

def _pretty_node_label(node):
    if not node:
        return "Start"
    var, state = parse_label(node[-1])
    return STATE_LABEL[var][state]

def _pretty_edge_label(edge_label, p):
    var, state = parse_label(edge_label)
    return f"{STATE_LABEL[var][state]}\n{p:.2f}"

def _stage_palette(n):
    cols = []
    for i in range(n):
        h = (i * 0.145) % 1.0
        s = 0.45 + 0.25*((i % 2))
        v = 0.92
        r, g, b = colorsys.hsv_to_rgb(h, s, v)
        cols.append(f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}")
    return cols

def draw_floret(model, history, next_var, figsize=(8,4)):
    """Plot a single floret P(next_var | history)."""
    states = VAR_STATES[next_var]
    probs  = [get_prob_bn(model, history, next_var, s) for s in states]
    s = sum(probs); probs = [p/s for p in probs]
    labels = [STATE_LABEL[next_var][s] for s in states]

    plt.figure(figsize=figsize)
    plt.title(
        f"Floret: {next_var} | " +
        " → ".join([STATE_LABEL[parse_label(h)[0]][parse_label(h)[1]] if ':' in h else h
                          for h in history])
    )

    x = np.arange(len(states))
    plt.bar(x, probs)
    plt.xticks(x, labels, rotation=30, ha='right')
    plt.ylabel("Probability")
    plt.ylim(0, 1)
    plt.show()

def draw_scenario_path(labels_with_p, title="Scenario path"):
    """Draw a linear path with edge labels and probabilities."""
    fig, ax = plt.subplots(figsize=(12, 1.8))
    ax.axis('off'); x=0.0
    for lab, p in labels_with_p:
        ax.add_patch(plt.Rectangle((x,0.2), 1.8, 0.6, fill=False))
        ax.text(x+0.9, 0.5, lab, ha='center', va='center')
        if p is not None:
            ax.text(x+0.0, 0.9, f"{p:.2f}")
        x += 2.0
        ax.arrow(x-0.2, 0.5, 0.3, 0.0, head_width=0.05, head_length=0.1, length_includes_head=True)
    plt.title(title)
    plt.show()


In [19]:
def render_staged_subtree_pdf(
        G, stage_map, history_seq, pdf_path, title="",
        depth=None, rankdir='LR', a4='landscape',
        *, min_edge_prob=0.0, mass_cut=0.0,
        summarize_tail=False, abbreviate=False):
    """Render a staged subtree to PDF with optional pruning and tail summaries."""

    root = [n for n in G.nodes() if G.in_degree(n)==0][0]
    u = root
    for (var, st) in history_seq:
        lab = canonical_label(var, st)
        nxt = next((v for v in G.successors(u) if G[u][v]['label']==lab), None)
        if nxt is None:
            raise ValueError(f"History not found: {history_seq}")
        u = nxt

    keep = {u}; frontier=[u]; d=0
    while frontier and (depth is None or d < depth):
        nxt=[]
        for x in frontier:
            for y in G.successors(x):
                keep.add(y); nxt.append(y)
        frontier=nxt; d+=1
    H = G.subgraph(keep).copy()

    missing = [n for n in H.nodes() if n not in stage_map]
    if missing:
        raise ValueError("stage_map must be computed on the SAME graph you render.")

    def local_masses(Gsub, start):
        mass = {start: 1.0}
        for x in nx.topological_sort(Gsub):
            mu = float(mass.get(x, 0.0))
            for y in Gsub.successors(x):
                mass[y] = mass.get(y, 0.0) + mu * float(Gsub[x][y]['prob'])
        return mass
    mass = local_masses(H, u)

    kept_edges = []
    tails = {}
    for a, b, d in H.edges(data=True):
        p = float(d['prob'])
        if p < min_edge_prob:
            if summarize_tail:
                t = tails.setdefault(a, {'mass_sum':0.0, 'count':0})
                t['mass_sum'] += mass.get(a,0.0) * p
                t['count'] += 1
            continue
        flow = mass.get(a, 0.0) * p
        if flow < mass_cut:
            if summarize_tail:
                t = tails.setdefault(a, {'mass_sum':0.0, 'count':0})
                t['mass_sum'] += flow
                t['count'] += 1
            continue
        kept_edges.append((a, b, d))

    used_nodes = {u}
    for a, b, _ in kept_edges:
        used_nodes.add(a); used_nodes.add(b)

    W, Hsize = ('11.69','8.27') if a4=='landscape' else ('8.27','11.69')
    dot = Digraph(comment=title, format='pdf')
    dot.attr(rankdir=rankdir, size=f'{W},{Hsize}!', margin='0.25',
             nodesep='0.30', ranksep='0.60', splines='spline')

    stages = sorted({stage_map[n] for n in used_nodes})
    pal = {}
    for i, s in enumerate(stages):
        h = (i * 0.123) % 1.0
        r, g, b = colorsys.hsv_to_rgb(h, 0.35, 1.0)
        pal[s] = f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}"

    leaf_override = {}
    for n in used_nodes:
        if H.out_degree(n) == 0 and n:
            var, st = parse_label(n[-1])
            if var == 'Hand':
                leaf_override[n] = {'Victim':'#e9f8ef', 'Offender':'#fde9ea'}[st]

    id_of = {n: f"n{i}" for i, n in enumerate(used_nodes)}

    dot.attr('node', shape='circle', style='filled', penwidth='1.2',
             color='#475e88', fontsize='12')
    for n in used_nodes:
        stid = stage_map[n]
        fill = leaf_override.get(n, pal[stid])
        label = "Start" if n == u else f"s{stid}"
        dot.node(id_of[n], label, fillcolor=fill)

    dot.attr('edge', color='#2f3b4a', penwidth='1.2', fontsize='11')
    def _edge_label(lab, p):
        var, st = parse_label(lab)
        return f"{STATE_LABEL[var][st]}\n{p:.2f}"
    for a, b, d in kept_edges:
        dot.edge(id_of[a], id_of[b], label=_edge_label(d['label'], float(d['prob'])))

    if summarize_tail:
        for a, info in tails.items():
            if a not in used_nodes or info['count']==0 or info['mass_sum']<=0.0:
                continue
            tid = f"tail_{id_of[a]}"
            dot.attr('node', shape='box', style='filled', fillcolor='#f1f1f1',
                     color='#8c8c8c', fontsize='10')
            dot.node(tid, f"{info['count']} tiny branch(es)\nSigma flow approx {info['mass_sum']:.2g}")
            dot.attr('edge', color='#a0a0a0', style='dashed')
            dot.edge(id_of[a], tid, label="pruned")

    if title:
        dot.attr(label=title, labelloc='t', fontsize='18')
    dot.render(filename=pdf_path, cleanup=True)


In [20]:
# Basic information theory utilities 

def _entropy(pvec):
    p = np.asarray(pvec, dtype=float)
    p = p[p > 0]
    return -float(np.sum(p * np.log2(p)))

def _kl(p, q, eps=1e-12):
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    p = np.clip(p, eps, 1.0)
    q = np.clip(q, eps, 1.0)
    p = p / p.sum(); q = q / q.sum()
    return float(np.sum(p * (np.log2(p) - np.log2(q))))

# Posterior helpers 

def _posterior_of(var, evidence):
    """Return a dict state->prob for a variable given evidence."""
    q = VE.query([var], evidence=evidence, show_progress=False)
    states = MODEL.get_cpds(var).state_names[var]
    return {s: float(q.values[i]) for i, s in enumerate(states)}

def posterior_H(evidence=None):
    evidence = evidence or {}
    return _posterior_of('H', evidence)

def _normalise_dict(d):
    s = sum(d.values())
    return {k: v/s for k, v in d.items()} if s > 0 else d


In [21]:
# Expected information gain (EIG) for a candidate test

def expected_information_gain(candidate_var, given_evidence=None):
    """I(H; X | evidence) by enumeration over states of X."""
    given_evidence = given_evidence or {}

    # Prior/posterior on H under current evidence
    pH = posterior_H(given_evidence)
    H_states = list(VAR_STATES['H'])
    pH_vec = np.array([pH[h] for h in H_states], dtype=float)

    # Distribution of candidate under evidence
    qX = _posterior_of(candidate_var, given_evidence)
    X_states = list(VAR_STATES[candidate_var])
    qX_vec = np.array([qX[x] for x in X_states], dtype=float)

    # Expected KL between H|X and H
    eig = 0.0
    for x in X_states:
        ev = dict(given_evidence)
        ev[candidate_var] = x
        pHx = posterior_H(ev)
        pHx_vec = np.array([pHx[h] for h in H_states], dtype=float)
        eig += qX[x] * _kl(pHx_vec, pH_vec)
    return float(eig)


In [22]:
# -CONFIG: variable names & labels
VAR_H   = "H"
VAR_O   = "O"
VAR_DNA = "DNA"
VAR_HAND= "Hand"
VAR_BLD = "Blood"

LABEL_O = ("Absent", "Present")            # order: 0=Absent, 1=Present
LABEL_D = ("Victim", "Offender")           # order: 0=Victim, 1=Offender
LABEL_HA = ("Victim", "Offender")          # for Hand (sinks)

# tiny utilities 
def _factor_state_names(f):
    # pgmpy Factor returned by infer.query
    try:
        return f.state_names
    except AttributeError:
        # pgmpy<0.1 fallback
        return {v: list(range(f.get_cardinality([v])[v])) for v in f.scope()}

def entropy_base2(p_dict):
    p = np.array([p for p in p_dict.values()], dtype=float)
    p = p[p > 0]
    return -np.sum(p * np.log2(p))

def marginal(infer, var, evidence=None):
    q = infer.query(variables=[var], evidence=evidence, show_progress=False)
    names = _factor_state_names(q)[var]
    vals  = q.values
    return {names[i]: float(vals[i]) for i in range(len(names))}

# ---- drop-in replacements ----
def joint2(infer, varx, vary, evidence=None):
    q = infer.query(variables=[varx, vary], evidence=evidence, show_progress=False)
    names = _factor_state_names(q)  # already defined in your notebook
    xs = names[varx]; ys = names[vary]
    order = list(q.variables)       # pgmpy’s axis order (may differ)
    ax_x = order.index(varx); ax_y = order.index(vary)
    tbl = {}
    for ix, x in enumerate(xs):
        for iy, y in enumerate(ys):
            if ax_x == 0 and ax_y == 1:
                val = q.values[ix, iy]
            else:
                val = q.values[iy, ix]
            tbl[(x, y)] = float(val)
    return xs, ys, tbl

def conditional_table(infer, child, parent, evidence=None, eps=0.0):
    xs, ys, joint = joint2(infer, child, parent, evidence=evidence)
    table = {}
    for y in ys:
        denom = sum(joint[(x, y)] for x in xs)
        if denom <= eps:
            table[y] = {x: float("nan") for x in xs}
        else:
            table[y] = {x: joint[(x, y)]/denom for x in xs}
    return xs, ys, table

def mi_bits(infer, varx, vary, evidence=None):
    xs, ys, joint = joint2(infer, varx, vary, evidence=evidence)
    px = {x: sum(joint[(x,y)] for y in ys) for x in xs}
    py = {y: sum(joint[(x,y)] for x in xs) for y in ys}
    I = 0.0
    for x in xs:
        for y in ys:
            pxy = joint[(x,y)]
            if pxy > 0 and px[x] > 0 and py[y] > 0:
                I += pxy * math.log2(pxy/(px[x]*py[y]))
    return I

def likelihood_ratios_binary(infer, test_var, H1, H0, evidence=None):
    p1 = marginal(infer, test_var, {**(evidence or {}), VAR_H: H1})
    p0 = marginal(infer, test_var, {**(evidence or {}), VAR_H: H0})
    lr = {}
    for t in p1:
        lr[t] = p1[t] / p0[t] if p0[t] > 0 else float('inf')
    return lr

def print_audit_neutrality(infer, audit_var, e=None):
    print(f"Neutrality check for {audit_var} across H at e={e}")
    for h, _ in marginal(infer, VAR_H).items():
        pa = marginal(infer, audit_var, {**(e or {}), VAR_H: h})
        print(f"P({audit_var} | H={h}) = { {k: round(v,4) for k,v in pa.items()} }")


def check_neutrality(infer, audit_var, target_var, e=None):
    for t in infer.query([target_var], show_progress=False).state_names[target_var]:
        q = infer.query([audit_var], evidence={**(e or {}), target_var: t}, show_progress=False)
        print(f"P({audit_var} | {target_var}={t}) = "
              f"{ {s: round(float(p),3) for s,p in zip(q.state_names[audit_var], q.values)} }")


In [23]:
# Pretty printers 

def _fmt_pct(x):
    return f"{100.0*float(x):.1f}%"

def print_sorted_posterior(title, dist):
    print("\n" + title)
    print("-" * len(title))
    for k, v in sorted(dist.items(), key=lambda kv: kv[1], reverse=True):
        print(f"{STATE_LABEL['H'][k]:28s}  {v: .4f} ({_fmt_pct(v)})")

def print_table(rows, headers=None, colwidth=14):
    if headers:
        print(" ".join(h.ljust(colwidth) for h in headers))
        print("-" * (colwidth * len(headers)))
    for r in rows:
        print(" ".join(str(c).ljust(colwidth) for c in r))

def entropy_categorical(pdict):
    return -sum(p*log2(p) for p in pdict.values() if p > 0)

def print_entropy_H(infer, e=None):
    post = infer.query([VAR_H], evidence=e, show_progress=False)
    p_post = {state: float(prob) for state, prob in zip(post.state_names[VAR_H], post.values)}
    H_post = entropy_categorical(p_post)
    prior = infer.query([VAR_H], show_progress=False)
    p_prior = {state: float(prob) for state, prob in zip(prior.state_names[VAR_H], prior.values)}
    H_prior = entropy_categorical(p_prior)
    print(f"H(H) [bits]   : {H_prior:.3f}")
    print(f"P(H | e={e}) : {p_post}")
    print(f"H(H|e) [bits] : {H_post:.3f}")
    
def print_dna_channel_given_O(infer, e=None):
    # Get conditional with the engine’s own state names
    xs, ys, table = conditional_table(
        infer, child=VAR_DNA, parent=VAR_O, evidence=e, eps=1e-15
    )
    def fmt(val):
        if val is None:                 # missing key
            return "     —     "
        if isinstance(val, float) and math.isnan(val):  # undefined (zero denom)
            return "     —     "
        return f"{val:12.4f}"

    # Header uses returned child labels (xs)
    print(f"P({VAR_DNA} | {VAR_O}, e={e})")
    print(" " * 13 + "".join(f"{x:>12}" for x in xs))
    # Rows iterate over returned parent labels (ys)
    for o in ys:
        row = [table[o].get(x) for x in xs]
        print(f"{o:>12} " + " ".join(fmt(v) for v in row))

def print_sink_masses(infer, e=None, hand_labels=LABEL_HA):
    ph = marginal(infer, VAR_HAND, evidence=e)
    off = ph.get(hand_labels[1], float("nan"))  # Offender
    vic = ph.get(hand_labels[0], float("nan"))  # Victim
    print(f"Sinks (Hand report) at e={e}")
    print(f"P(Offender sink | e) = {off:.6f}")
    print(f"P(Victim   sink | e) = {vic:.6f}")

def print_q_and_MI(infer, e=None, present_label=LABEL_O[1]):
    pO = marginal(infer, VAR_O, evidence=e)
    q  = pO.get(present_label, float("nan"))
    I_dna = mi_bits(infer, VAR_DNA, VAR_O, evidence=e)
    I_hand= mi_bits(infer, VAR_HAND, VAR_O, evidence=e)
    print(f"q(e)=P(O=Present | e={e}) = {q:.3f}")
    print(f"I(DNA; O | e)   [bits]   = {I_dna:.3f}")
    print(f"I(Hand; O | e)  [bits]   = {I_hand:.3f}")
    return q, I_dna, I_hand

def compare_builds(infer_hyp, infer_chrono, e=None, tol=1e-6):
    def snapshot(infer):
        pO = marginal(infer, VAR_O, evidence=e)
        q  = pO[LABEL_O[1]]
        I  = mi_bits(infer, VAR_DNA, VAR_O, evidence=e)
        ph = marginal(infer, VAR_HAND, evidence=e)
        sink_off = ph[LABEL_HA[1]]
        sink_vic = ph[LABEL_HA[0]]
        return q, I, sink_off, sink_vic
    a = snapshot(infer_hyp)
    b = snapshot(infer_chrono)
    names = ["q(e)", "I(DNA;O|e)", "Offender sink", "Victim sink"]
    ok = True
    for i, nm in enumerate(names):
        if abs(a[i]-b[i]) > tol:
            ok = False
            print(f"[DIFF] {nm}: hyp={a[i]:.6f}, chrono={b[i]:.6f}")
    if ok:
        print(f"[OK] Hypothesis-first vs Chronological are equal within tol={tol:g} on q, MI, and sink masses.")


In [24]:
# Orchestration: a quick, reproducible end-to-end run 

EVIDENCE_BASE = {VAR_BLD: "Yes", "Lfound": "Yes"}

def run_quick_demo():
    os.makedirs("outputs", exist_ok=True)

    # 1) Build trees and compute overlays (already built globally)
    builds = BUILDS
    G_hyp, st_hyp = builds['hyp']['G'], builds['hyp']['stage']
    G_chr, st_chr = builds['chr']['G'], builds['chr']['stage']

    # 2) Render small staged subtrees (if graphviz available)
    try:
        render_staged_subtree_pdf(
            G_hyp, st_hyp,
            history_seq=[('Lfound','Yes')],
            pdf_path=os.path.join('outputs','subtree_hyp_LfoundYes'),
            title='Hypothesis order — subtree from Lfound=Yes',
            depth=3, summarize_tail=True, min_edge_prob=0.02, mass_cut=1e-4
        )
        render_staged_subtree_pdf(
            G_chr, st_chr,
            history_seq=[('Lfound','Yes')],
            pdf_path=os.path.join('outputs','subtree_chr_LfoundYes'),
            title='Chronology order — subtree from Lfound=Yes',
            depth=3, summarize_tail=True, min_edge_prob=0.02, mass_cut=1e-4
        )
        print("Rendered PDF subtrees in ./outputs/ (Graphviz).")
    except Exception as ex:
        print("Graphviz render skipped:", ex)

    # 3) Key posterior with evidence Blood=Yes (mirrors your print index)
    ev = {'Blood':'Yes'}
    post_H_B = posterior_H(ev)
    print_sorted_posterior("Table 3 — Posterior over H | Blood=Yes", post_H_B)

    # 4) Next best test by Expected Information Gain (EIG)
    exclude = set(ev.keys())
    candidates = [v for v in VAR_STATES.keys() if v not in exclude and v not in ('H',)]
    eig_rows = []
    for v in candidates:
        try:
            val = expected_information_gain(v, ev)
            eig_rows.append((v, val))
        except Exception:
            pass
    eig_rows.sort(key=lambda t: t[1], reverse=True)

    print("\nNext best test by EIG | Blood=Yes")
    print("-------------------------------")
    if eig_rows:
        best_var, best_eig = eig_rows[0]
        print(f"Best: {best_var}  (EIG={best_eig:.4f} bits)")
    print_table([(v, f"{e:.4f}") for v, e in eig_rows], headers=["Test","EIG(bits)"])

    # 5) Likelihood ratios for AS=Pass under Blood=Yes (Appendix C analogue)
    try:
        lrs = likelihood_ratios_binary('AS', 'Pass', given_evidence=ev)
        rows = [(STATE_LABEL['H'][h], f"{lrs[h]:.3f}") for h in VAR_STATES['H']]
        print("\nAppendix C — LRs for AS=Pass | Blood=Yes (H vs ~H)")
        print("----------------------------------------------")
        print_table(rows, headers=["H","LR"])
    except Exception as ex:
        print("LR computation skipped:", ex)

    # 6) Sensitivity: EIG(DNA; H | Blood=Yes)
    try:
        eig_dna = expected_information_gain('DNA', ev)
        print(f"\nSensitivity — EIG(DNA;H | Blood=Yes): {eig_dna:.4f} bits")
    except Exception:
        pass

    # 7) Conditional sink masses: P(Hand | Blood=Yes)
    try:
        p_hand = _posterior_of('Hand', ev)
        print("\nConditional sink masses — P(Hand | Blood=Yes)")
        print("--------------------------------------------")
        print_table([(k, f"{v:.4f}") for k, v in p_hand.items()], headers=["Hand","Prob"]) 
    except Exception:
        pass

    # 8) Invariance & sanity checks already enforced; summarise headline
    p0 = posterior_H()
    p1 = post_H_B
    def _top(d):
        k = max(d, key=d.get)
        return STATE_LABEL['H'][k], d[k]
    h0, m0 = _top(p0)
    h1, m1 = _top(p1)
    print("\nSummary — headline posteriors")
    print("-----------------------------")
    print(f"Prior argmax H: {h0} ({m0:.3f})")
    print(f"Posterior argmax H | Blood=Yes: {h1} ({m1:.3f})")


    infer_hyp    = VariableElimination(MODEL)
    infer_chrono = VariableElimination(MODEL)

    e = {VAR_BLD: "Yes"}  # base scene evidence used in the thesis

    print_entropy_H(infer_hyp, e=e)                 # prints H(H) and H(H|e)
    print_dna_channel_given_O(infer_hyp, e=e)       # prints the 2x2 scene-conditional DNA|O table
    print_sink_masses(infer_hyp, e=e)               # prints Offender/Victim sink masses (Hand report)
    print_q_and_MI(infer_hyp, e=e)                  # prints q(e), I(DNA;O|e), I(Hand;O|e)
    compare_builds(infer_hyp, infer_chrono, e=e)    # asserts equality across builds (q, MI, sinks)

    e = dict(EVIDENCE_BASE)  # {'Blood':'Yes','Lfound':'Yes'}

    print_entropy_H(infer_hyp, e=e)
    print_dna_channel_given_O(infer_hyp, e=e)          # now prints proper numbers (≈ 0.98 / 0.02 and ≈ 0.14 / 0.86)
    print_sink_masses(infer_hyp, e=e)
    q, I_dna, I_hand = print_q_and_MI(infer_hyp, e=e)

    # Optional: audit neutrality columns (baseline should be equal across H)
    print_audit_neutrality(infer_hyp, audit_var="AS", e=e)
    print_audit_neutrality(infer_hyp, audit_var="AL", e=e)

    # Parity across builds
    compare_builds(infer_hyp, infer_chrono, e=e)


    # Run neutrality checks for AS and AL across H, O, and W 
    for a in sorted(AUDIT_VARS):          # deterministic order
        for t in ['H','O','W']:           # fixed order for readability
            print(f"\nNeutrality check for {a} across {t} at e={EVIDENCE_BASE}")
            check_neutrality(infer_hyp, audit_var=a, target_var=t, e=EVIDENCE_BASE)

# Script entrypoint 
if __name__ == '__main__':
    run_quick_demo()


Graphviz render skipped: History not found: [('Lfound', 'Yes')]

Table 3 — Posterior over H | Blood=Yes
--------------------------------------
Victim accidental              0.3959 (39.6%)
Suicide                        0.2849 (28.5%)
Offender self-defense          0.1188 (11.9%)
Offender deliberate            0.1114 (11.1%)
Offender accidental            0.0891 (8.9%)

Next best test by EIG | Blood=Yes
-------------------------------
Best: Lfound  (EIG=0.1109 bits)
Test           EIG(bits)     
----------------------------
Lfound         0.1109        
Ltime          0.0755        
O              0.0456        
W              0.0404        
DNA            0.0321        
Hand           0.0000        
BP             0.0000        
AS             0.0000        
AL             -0.0000       
LR computation skipped: likelihood_ratios_binary() got an unexpected keyword argument 'given_evidence'

Sensitivity — EIG(DNA;H | Blood=Yes): 0.0321 bits

Conditional sink masses — P(Hand | Blood=Yes)