## Table 3

In [13]:
import re
import pandas as pd
import numpy as np

###############################################################################
# 0)  CONFIG
###############################################################################
S_FOCAL = 8                     # “s = 8” in the paper
REDUCED_FILE = "proper_OR.csv"    # path to the proper‑OR file you generated

###############################################################################
# 1)  READ  &  PARSE  (r , z)
###############################################################################
df = pd.read_csv(REDUCED_FILE)

# the third column looks like "(7, 12)" → split it into integers
def parse_pair(txt):
    r, z = map(int, re.findall(r"\d+", txt))
    return pd.Series({"r": r, "z": z})

df[["r", "z"]] = df.iloc[:, 2].apply(parse_pair)

###############################################################################
# 2)  BUILD THE TABLE
###############################################################################
records = []
for t_v in sorted(df["z"].unique()):          # 11 unique 2→3 times
    # Number of 2→3 events at t_v
    dN = (df["z"] == t_v).sum()

    # … among those who were still in state 1 at s = 8  (⇔  r_i ≥ 8)
    dN_s8 = ((df["z"] == t_v) & (df["r"] >= S_FOCAL)).sum()

    # Risk set just before t_v:  r_i < t_v  &  z_i ≥ t_v
    at_risk = (df["r"] < t_v) & (df["z"] >= t_v)
    Y  = at_risk.sum()

    # … and who were in state 1 at s = 8
    Y_s8 = (at_risk & (df["r"] >= S_FOCAL)).sum()

    records.append((t_v, dN, dN_s8, Y, Y_s8))

table3 = pd.DataFrame(
    records,
    columns=[
        "t_v",
        r"∑ dN_i(t_v)",
        r"∑ ​I(8 ≤ r_i) dN_i(t_v)",
        r"∑ Y_i(t_v)",
        r"∑ ​I(8 ≤ r_i) Y_i(t_v)",
    ],
)

###############################################################################
# 3)  DISPLAY
###############################################################################
print(table3.to_string(index=False))

 t_v  ∑ dN_i(t_v)  ∑ ​I(8 ≤ r_i) dN_i(t_v)  ∑ Y_i(t_v)  ∑ ​I(8 ≤ r_i) Y_i(t_v)
  12            2                        0          16                      10
  13            2                        0          17                      13
  15            3                        3          22                      20
  16            4                        3          20                      18
  17            2                        1          16                      15
  18            6                        6          14                      14
  19            1                        1           8                       8
  20            2                        2           7                       7
  21            2                        2           5                       5
  22            1                        1           3                       3
  23            2                        2           2                       2


## Table 5

In [2]:
import pandas as pd
from collections import defaultdict

# Load CSV files
improper_or_df = pd.read_csv("improper_OR.csv")
improper_mi_df = pd.read_csv("improper_MI.csv")

# Extract r from Reduced (r, 23) column
reduced_rows = improper_mi_df["Reduced (r,23)"].str.extract(r"\((\d+),\s*23\)").astype(int)
improper_mi_df["r"] = reduced_rows[0]

# Count multiplicities from improper_OR: match (L, L+1] x (23, ∞)
counts = defaultdict(int)
for _, row in improper_or_df.iterrows():
    L, R, W = row['L'], row['R'], row['W']
    if W == 23 and R == L + 1:
        counts[R] += 1  # R is r in (r, 23)

# Filter MI rows with matching r and create table format
table5_df = improper_mi_df[improper_mi_df["r"].isin(counts.keys())].copy()
table5_df["multiplicity"] = table5_df["r"].map(counts)
table5_df["Improper OR=Improper MI"] = table5_df["r"].apply(lambda r: f"({r-1},{r}) × (23,∞)")
table5_df["reduced"] = table5_df["r"].apply(lambda r: f"({r}, 23)")

# Final table with desired columns
table5 = table5_df[["Improper OR=Improper MI", "multiplicity", "reduced", "p"]].copy()

# Optional: add total row
total_multiplicity = table5["multiplicity"].sum()
table5.loc["total"] = ["total", total_multiplicity, "", ""]

# Print the table
print(table5.to_string(index=False))

Improper OR=Improper MI  multiplicity  reduced       p
        (9,10) × (23,∞)             2 (10, 23)  0.0639
       (10,11) × (23,∞)             7 (11, 23)  0.1649
       (11,12) × (23,∞)             1 (12, 23)  0.0379
       (12,13) × (23,∞)             7 (13, 23)  0.1682
       (14,15) × (23,∞)             8 (15, 23)  0.1548
       (15,16) × (23,∞)             5 (16, 23)  0.0619
                  total            30                 


## Table 6

In [3]:
import pandas as pd

# --- Load ---
or_df = pd.read_csv("improper_OR.csv")      # columns: L, R, W, Z
mi_df = pd.read_csv("improper_MI.csv")      # columns: Improper MI, p, Reduced (r,23)

# --- Parse MI endpoints from strings like "(5,7] × (23, ∞)" ---
uv = mi_df["Improper MI"].str.extract(r"\((\d+)\s*,\s*(\d+)\]")
mi_df["u"] = uv[0].astype(int)
mi_df["v"] = uv[1].astype(int)
mi_keep = mi_df[["u", "v", "Improper MI", "Reduced (r,23)"]].copy()

# --- Candidate improper ORs: W==23 (improper); we’ll later split adjacent vs non-adjacent ---
cand_or = or_df[or_df["W"] == 23][["L", "R"]].drop_duplicates()

# --- Containment: Type 1 = contains exactly one improper MI ---
cross = cand_or.merge(mi_keep, how="cross")
contained = cross[(cross["L"] <= cross["u"]) & (cross["v"] <= cross["R"])]

# Count # of improper MIs each OR contains
mi_counts = contained.groupby(["L", "R"]).size().reset_index(name="num_improper_MI")

# Keep only Type-1 ORs (exactly one improper MI)
type1_or = mi_counts[mi_counts["num_improper_MI"] == 1][["L", "R"]]

# --- Table 6 subset: NOT of the form (L, L+1) ---
table6_or = type1_or[type1_or["R"] != type1_or["L"] + 1]

# Attach the unique MI row for each selected OR
one_mi_detail = table6_or.merge(contained, on=["L", "R"], how="left")

# Multiplicity of each OR in raw data (how many times it appears with W=23)
mult = (or_df[or_df["W"] == 23]
        .groupby(["L", "R"]).size().reset_index(name="multiplicity"))

# Build Table 6
final = (one_mi_detail.merge(mult, on=["L", "R"], how="left")
         .drop_duplicates(subset=["L","R"])            # exactly one MI per OR
         .sort_values(by=["R", "L"])
         .assign(
             **{
                 "Improper OR":    lambda d: d.apply(lambda r: f"({r.L},{r.R}) × (23,∞)", axis=1),
                 "reduced (r,23)": lambda d: d["Reduced (r,23)"],
             }
         )[["Improper OR", "Improper MI", "multiplicity", "reduced (r,23)"]]
         .reset_index(drop=True))

# Add total row — multiplicities should sum to 9
table6 = pd.concat(
    [final, pd.DataFrame([["total", "", final["multiplicity"].sum(), ""]], columns=final.columns)],
    ignore_index=True
)

print(table6.to_string(index=False))

     Improper OR       Improper MI  multiplicity reduced (r,23)
  (1,7) × (23,∞)   (5,7] × (23, ∞)             1         (7,23)
  (5,7) × (23,∞)   (5,7] × (23, ∞)             2         (7,23)
  (7,9) × (23,∞)   (8,9] × (23, ∞)             1         (9,23)
(12,14) × (23,∞) (12,13] × (23, ∞)             2        (13,23)
(13,15) × (23,∞) (14,15] × (23, ∞)             3        (15,23)
           total                               9               


## Table 6a

In [4]:
import pandas as pd
import re
from pathlib import Path

# -------------------------
# 1) Load inputs
# -------------------------
table3 = pd.read_csv("table3.csv")
table5 = pd.read_csv("table5.csv")
table6 = pd.read_csv("table6.csv")

# -------------------------
# 2) Helpers
# -------------------------
def find_col(df, substrings):
    """Return the first column whose name contains ALL substrings (case-insensitive)."""
    subs = [s.lower() for s in substrings]
    for c in df.columns:
        cl = c.lower()
        if all(s in cl for s in subs):
            return c
    raise KeyError(f"Column with substrings {substrings} not found in {list(df.columns)}")

def extract_r(val):
    """Extract integer r from '(r,23)' allowing spaces."""
    if not isinstance(val, str):
        return None
    m = re.search(r"\((\d+)\s*,\s*23\)", val)
    return int(m.group(1)) if m else None

# -------------------------
# 3) Parse reduced points & multiplicities from Tables 5 & 6
# -------------------------
col_mult_5 = find_col(table5, ["multiplicity"])
col_red_5  = find_col(table5, ["reduced"])            # e.g., "reduced"
r_mult_5 = (
    table5.assign(r=table5[col_red_5].apply(extract_r))
          .dropna(subset=["r"])
          [["r", col_mult_5]]
          .rename(columns={col_mult_5: "multiplicity"})
)

col_mult_6 = find_col(table6, ["multiplicity"])
col_red_6  = find_col(table6, ["reduced"])            # e.g., "reduced (r,23)"
r_mult_6 = (
    table6.assign(r=table6[col_red_6].apply(extract_r))
          .dropna(subset=["r"])
          [["r", col_mult_6]]
          .rename(columns={col_mult_6: "multiplicity"})
)

# Combine 5 & 6, then aggregate multiplicities per r
r_mult = (
    pd.concat([r_mult_5, r_mult_6], ignore_index=True)
      .groupby("r", as_index=False)["multiplicity"].sum()
      .sort_values("r")
      .reset_index(drop=True)
)

# Precompute cumulative sums over r for fast lookups of r < t_v
r_mult["cum_all"]  = r_mult["multiplicity"].cumsum()
r_mult["ge8"]      = (r_mult["r"] >= 8).astype(int) * r_mult["multiplicity"]
r_mult["cum_ge8"]  = r_mult["ge8"].cumsum()

def cum_less_than(tv, colname):
    """Cumulative sum of selected multiplicities for all r < tv."""
    eligible = r_mult[r_mult["r"] < tv]
    return int(eligible[colname].iloc[-1]) if not eligible.empty else 0

# -------------------------
# 4) Read needed columns from Table 3 and build Table 6a
# -------------------------
col_tv      = find_col(table3, ["t_v"])
col_dN      = find_col(table3, ["∑", "dn"])              # "∑ dN_i(t_v)"
col_I8dN    = find_col(table3, ["∑", "i(8", "dn"])       # "∑ I(8 ≤ r_i) dN_i(t_v)"
col_Y       = find_col(table3, ["∑", "y_i"])             # "∑ Y_i(t_v)"
col_I8Y     = find_col(table3, ["∑", "i(8", "y_i"])      # "∑ I(8 ≤ r_i) Y_i(t_v)"

rows = []
for _, r in table3.iterrows():
    tv    = int(r[col_tv])
    dN    = int(r[col_dN])
    I8dN  = int(r[col_I8dN])
    Y0    = int(r[col_Y])
    I8Y0  = int(r[col_I8Y])

    add_Y   = cum_less_than(tv, "cum_all")   # all r<tv
    add_I8Y = cum_less_than(tv, "cum_ge8")   # only r>=8, r<tv

    rows.append({
        "t_v": tv,
        "∑ dN_i(t_v)": dN,
        "∑ I(8 ≤ r_i) dN_i(t_v)": I8dN,
        "∑ Y_i(t_v)": f"{Y0} + {add_Y} = {Y0 + add_Y}",
        "∑ I(8 ≤ r_i) Y_i(t_v)": f"{I8Y0} + {add_I8Y} = {I8Y0 + add_I8Y}",
    })

table6a = pd.DataFrame(rows)

# -------------------------
# 5) Save result
# -------------------------
table6a.to_csv("table6a.csv", index=False)
print(table6a.to_string(index=False))

 t_v  ∑ dN_i(t_v)  ∑ I(8 ≤ r_i) dN_i(t_v)   ∑ Y_i(t_v) ∑ I(8 ≤ r_i) Y_i(t_v)
  12            2                       0 16 + 13 = 29          10 + 10 = 20
  13            2                       0 17 + 14 = 31          13 + 11 = 24
  15            3                       3 22 + 23 = 45          20 + 20 = 40
  16            4                       3 20 + 34 = 54          18 + 31 = 49
  17            2                       1 16 + 39 = 55          15 + 36 = 51
  18            6                       6 14 + 39 = 53          14 + 36 = 50
  19            1                       1  8 + 39 = 47           8 + 36 = 44
  20            2                       2  7 + 39 = 46           7 + 36 = 43
  21            2                       2  5 + 39 = 44           5 + 36 = 41
  22            1                       1  3 + 39 = 42           3 + 36 = 39
  23            2                       2  2 + 39 = 41           2 + 36 = 38


## Test Statistic

In [11]:
import pandas as pd
import re

# --- Load & normalize columns ---
df = pd.read_csv("table6a.csv").rename(columns={
    't_v': 't_v',
    '∑ dN_i(t_v)': 'sum_dN',
    '∑ I(8 ≤ r_i) dN_i(t_v)': 'sum_I_dN',
    '∑ Y_i(t_v)': 'sum_Y',
    '∑ I(8 ≤ r_i) Y_i(t_v)': 'sum_I_Y'
})

def to_number(x):
    """Extract the numeric total from entries like '16 + 13 = 29' or pass through numbers."""
    if pd.isna(x): 
        return pd.NA
    if isinstance(x, (int, float)): 
        return x
    s = str(x)
    nums = re.findall(r'(-?\d+(?:\.\d+)?)', s)
    return float(nums[-1]) if nums else pd.NA

for c in ['sum_dN', 'sum_I_dN', 'sum_Y', 'sum_I_Y']:
    df[c] = df[c].apply(to_number).astype(float)

# --- Compute U(t_v) ---
# U(t_v) = sum_I_dN - sum_I_Y * (sum_dN / sum_Y)
df['U_t'] = df['sum_I_dN'] - df['sum_I_Y'] * (df['sum_dN'] / df['sum_Y'])

print(df)

    t_v  sum_dN  sum_I_dN  sum_Y  sum_I_Y       U_t
0    12     2.0       0.0   29.0     20.0 -1.379310
1    13     2.0       0.0   31.0     24.0 -1.548387
2    15     3.0       3.0   45.0     40.0  0.333333
3    16     4.0       3.0   54.0     49.0 -0.629630
4    17     2.0       1.0   55.0     51.0 -0.854545
5    18     6.0       6.0   53.0     50.0  0.339623
6    19     1.0       1.0   47.0     44.0  0.063830
7    20     2.0       2.0   46.0     43.0  0.130435
8    21     2.0       2.0   44.0     41.0  0.136364
9    22     1.0       1.0   42.0     39.0  0.071429
10   23     2.0       2.0   41.0     38.0  0.146341


In [12]:
# --- Compute total U ---
U = df['U_t'].sum()
print(f"Test Statistics: {round(U, 4)}")

Test Statistics: -3.1905


## Table 6b

In [15]:
import pandas as pd
import numpy as np

# === Load the input table (Table 6a stats) ===
df = pd.read_csv("table6a_stats.csv")

# Extract relevant columns
tv = df["t_v"].astype(int)         # time values
nV = df["sum_Y"].astype(int)       # n_V = sum of Y_i(t_v)
nv1 = df["sum_I_Y"].astype(int)    # n_{v1} = sum of I(8 <= r_i) Y_i(t_v)

# Compute the formula: n_{v1}(n_V - n_{v1}) / n_V^2
value = nv1 * (nV - nv1) / (nV ** 2)

# Build Table 6b DataFrame
table6b = pd.DataFrame({
    "t_v": tv,
    "n_V": nV,
    "n_{v1}": nv1,
    "n_{v1}(n_V - n_{v1})/n_V^2": value.round(6)
})

# Save to CSV
table6b.to_csv("table6b.csv", index=False)

print("Table 6b saved to table6b.csv")
print(table6b)

Table 6b saved to table6b.csv
    t_v  n_V  n_{v1}  n_{v1}(n_V - n_{v1})/n_V^2
0    12   29      20                    0.214031
1    13   31      24                    0.174818
2    15   45      40                    0.098765
3    16   54      49                    0.084019
4    17   55      51                    0.067438
5    18   53      50                    0.053400
6    19   47      44                    0.059756
7    20   46      43                    0.060964
8    21   44      41                    0.063533
9    22   42      39                    0.066327
10   23   41      38                    0.067817


In [16]:
# --- Compute standardized test statistic ---
df["var_term"] = df.apply(lambda row: (row["sum_I_Y"] * (row["sum_Y"] - row["sum_I_Y"])) / (row["sum_Y"]**2) if row["sum_Y"] > 0 else 0, axis=1)
V_U = df["var_term"].sum()

U_std = U / np.sqrt(V_U) if V_U > 0 else np.nan

print(f"Standardized Test Statistics: {round(U_std, 4)}")

Standardized Test Statistics: -3.1733


## Table 7

In [15]:
import re
import pandas as pd

# Ensure Pandas prints full strings without truncation
pd.set_option("display.max_colwidth", None)

# === 1) Load inputs ===
improper_or = pd.read_csv("improper_OR.csv")
table4 = pd.read_csv("table4.csv")

# Clean headers
improper_or.columns = [c.strip() for c in improper_or.columns]
table4.columns = [c.strip() for c in table4.columns]

# === 2) Parse r from Reduced column and convert p ===
def extract_r(cell: str) -> int:
    m = re.search(r"\(\s*(\d+)\s*,\s*23\s*\)", str(cell))
    if not m:
        raise ValueError(f"Cannot parse Reduced (r,23): {cell}")
    return int(m.group(1))

table4["_r"] = table4["Reduced (r,23)"].apply(extract_r)
table4["_p"] = pd.to_numeric(table4["p"], errors="coerce")
r_to_p = table4[["_r", "_p"]]

# === 3) Collapse improper_OR into unique (L,R] with multiplicity ===
improper_or["L"] = improper_or["L"].astype(int)
improper_or["R"] = improper_or["R"].astype(int)

grouped = (
    improper_or.groupby(["L", "R"], as_index=False)
    .size()
    .rename(columns={"size": "multiplicity"})
    .sort_values(["L", "R"])
    .reset_index(drop=True)
)

# === 4) Compute probability sums ===
def sum_for_interval(L: int, R: int):
    mask = (r_to_p["_r"] > L) & (r_to_p["_r"] <= R)
    probs = r_to_p.loc[mask, "_p"].dropna().tolist()
    total = float(sum(probs))
    terms = " + ".join(f"{p:.4f}".rstrip("0").rstrip(".") for p in probs)
    expr = f"{terms} = {total:.4f}".rstrip("0").rstrip(".") if terms else "0 = 0"
    return total, expr, len(probs)

totals, exprs, counts = [], [], []
for _, row in grouped.iterrows():
    total, expr, count = sum_for_interval(int(row["L"]), int(row["R"]))
    totals.append(total)
    exprs.append(expr)
    counts.append(count)

grouped["Total probability (numeric)"] = totals
grouped["Sum breakdown"] = exprs
grouped["num_MI_rects"] = counts

# === 5) Keep only type 2 ORs (more than one MI rectangle) ===
table7 = grouped[grouped["num_MI_rects"] > 1].copy()

# Nicely formatted interval column
table7.insert(0, "Imp OR (L,R]", table7.apply(lambda r: f"({int(r['L'])}, {int(r['R'])}]", axis=1))

# Drop helper columns and reset index
table7 = table7.drop(columns=["L", "R", "num_MI_rects"]).reset_index(drop=True)

# === 6) Save to CSV ===
table7.to_csv("table7.csv", index=False)

print("Table 7 (type 2 improper ORs) saved to table7.csv")
print(table7.to_string(index=False))

Table 7 (type 2 improper ORs) saved to table7.csv
Imp OR (L,R]  multiplicity  Total probability (numeric)                                                                  Sum breakdown
     (1, 11]             1                       0.2960                                      0.0437 + 0.0235 + 0.0639 + 0.1649 = 0.296
     (1, 12]             2                       0.3339                            0.0437 + 0.0235 + 0.0639 + 0.1649 + 0.0379 = 0.3339
     (1, 13]             2                       0.5021                   0.0437 + 0.0235 + 0.0639 + 0.1649 + 0.0379 + 0.1682 = 0.5021
     (1, 14]             3                       0.5021                   0.0437 + 0.0235 + 0.0639 + 0.1649 + 0.0379 + 0.1682 = 0.5021
     (1, 15]             2                       0.6569          0.0437 + 0.0235 + 0.0639 + 0.1649 + 0.0379 + 0.1682 + 0.1548 = 0.6569
     (1, 16]             1                       0.7188 0.0437 + 0.0235 + 0.0639 + 0.1649 + 0.0379 + 0.1682 + 0.1548 + 0.0619 = 0.7188
     

In [32]:
# end_to_end_or_pipeline.py
# -*- coding: utf-8 -*-
"""
End-to-end replication of the instruction images:

(1) Build OR mini-tables (per OR):
    - idx_in_or, MI interval, unconditional prob (from Table 4), conditional prob within OR
(2) Impute (draw) MI rectangles for each OR according to 'multiplicity' (Table 7)
(3) Generate I*-style indicator tables (I1..I5 + Y9/Y10 blocks)

Robust matching:
- Normalizes interval labels (whitespace, commas, bracket glyphs)
- Matches MI either by exact endpoints OR by containment when Table 7 lists a super-interval
- Parses probabilities like "3.2%" or "1,234"

Inputs (place next to this script or edit BASE below):
- table4.csv : MUST have columns ['Improper MI','p']
- table7.csv : MUST have columns ['Imp OR (L,R]','Sum breakdown'] and optional ['multiplicity']
Optional:
- yk_patterns.csv : rows 'Y1'.., cols '12'..'23' (0/1) to specify Y_k(t_v)
- mask_by_tv.csv  : one row with cols '12'..'23' (0/1) for I(8 ≤ r_i) mask

Outputs:
- ./out/or_probs/OR*.csv, OR_summary.csv
- ./out/or_draws/OR*_draws.csv, all_draws.csv
- ./out/I_tables/*.csv
"""

from pathlib import Path
import re
from typing import List, Dict, Tuple, Optional
import numpy as np
import pandas as pd

# ------------------------- CONFIG -------------------------
BASE = Path(".")
T4_PATH = BASE / "table4.csv"
T7_PATH = BASE / "table7.csv"
YK_PATH = BASE / "yk_patterns.csv"   # optional
MASK_PATH = BASE / "mask_by_tv.csv"  # optional

OUT = BASE / "out"
OR_PROB_DIR = OUT / "or_probs"
OR_DRAW_DIR = OUT / "or_draws"
I_DIR       = OUT / "I_tables"
for d in (OR_PROB_DIR, OR_DRAW_DIR, I_DIR):
    d.mkdir(parents=True, exist_ok=True)

TV_COLS = [str(v) for v in range(12, 24)]  # t_v = 12..23

# ------------------------- HELPERS -------------------------
def clean_prob(x) -> float:
    """Parse 'p' values that might include '%' or thousands separators."""
    if pd.isna(x):
        return np.nan
    s = str(x).strip()
    if s.endswith("%"):
        try:
            return float(s[:-1].replace(",", "")) / 100.0
        except:
            return np.nan
    try:
        return float(s.replace(",", ""))
    except:
        return pd.to_numeric(s, errors="coerce")

def normalize_interval_label(s: str) -> str:
    """Normalize labels like '(5, 7]' -> '(5,7]'; strip 'MI'/'Improper' prefixes; unify union glyph."""
    if pd.isna(s):
        return ""
    s0 = str(s).strip()
    s0 = re.sub(r"^\s*(MI|Improper|Rect|Rectangle)\s*", "", s0, flags=re.I)
    s0 = re.sub(r"\s*,\s*", ",", s0)
    s0 = (s0.replace("（", "(").replace("）", ")")
             .replace("，", ",").replace("［", "[").replace("］", "]")
             .replace("∪", "U").replace("u", "U"))
    s0 = re.sub(r"\s+", "", s0)
    return s0

def parse_endpoints(label: str) -> Optional[Tuple[str, int, int, str]]:
    """Return (left_bracket, a, b, right_bracket) for '(a,b]' style; None if not parseable."""
    lab = normalize_interval_label(label)
    m = re.match(r"^([\(\[])\s*(\d+)\s*,\s*(\d+)\s*([\)\]])$", lab)
    if not m:
        return None
    return (m.group(1), int(m.group(2)), int(m.group(3)), m.group(4))

def bracket_insensitive_key(label: str) -> Optional[Tuple[int, int]]:
    """Reduce interval to (a,b) ignoring bracket types so '(5,7]' matches '[5,7)' etc."""
    ep = parse_endpoints(label)
    if not ep:
        return None
    _, a, b, _ = ep
    return (a, b)

def contains(super_interval: str, sub_interval: str) -> bool:
    """Check if sub_interval is fully contained in super_interval (numeric only, ignore brackets)."""
    S = parse_endpoints(super_interval)
    s = parse_endpoints(sub_interval)
    if not S or not s:
        return False
    _, A, B, _ = S
    _, a, b, _ = s
    return (A <= a) and (b <= B)

def parse_union_list(s: str) -> List[str]:
    ss = str(s).replace("∪", "U").replace("u", "U")
    return [normalize_interval_label(p) for p in re.split(r"\s*U\s*", ss) if p.strip()]

# ------------------------- LOAD INPUTS -------------------------
# Table 4
t4 = pd.read_csv(T4_PATH)
if "Improper MI" not in t4.columns or "p" not in t4.columns:
    raise RuntimeError("table4.csv must include columns 'Improper MI' and 'p'.")
t4["_interval_raw"]  = t4["Improper MI"]
t4["_interval_norm"] = t4["_interval_raw"].map(normalize_interval_label)
t4["_endpoints_key"] = t4["_interval_norm"].map(bracket_insensitive_key)
t4["_prob_raw"]      = t4["p"]
t4["_prob"]          = t4["_prob_raw"].map(clean_prob)

# Table 7
t7 = pd.read_csv(T7_PATH)
if "Imp OR (L,R]" not in t7.columns or "Sum breakdown" not in t7.columns:
    raise RuntimeError("table7.csv must include columns 'Imp OR (L,R]' and 'Sum breakdown'.")
t7["_or_id"]         = np.arange(1, len(t7)+1)
t7["_or_label"]      = t7["Imp OR (L,R]"].map(normalize_interval_label)
t7["_or_endpoints"]  = t7["_or_label"].map(bracket_insensitive_key)
t7["_multiplicity"]  = pd.to_numeric(t7.get("multiplicity", 1), errors="coerce").fillna(1).astype(int)
t7["_components_norm"] = t7["Sum breakdown"].map(parse_union_list)

t4_key_set = set(k for k in t4["_endpoints_key"] if k is not None)

# ---------------------- OR BLOCK (ROBUST) ----------------------
def robust_or_block(components_norm: List[str], t4_df: pd.DataFrame) -> pd.DataFrame:
    """
    Build an OR block by:
      1) exact endpoint matches when the component is an MI bin,
      2) otherwise, include all MI rectangles fully contained in the component (super-interval).
    """
    chosen_idx = []
    for comp in components_norm:
        key = bracket_insensitive_key(comp)
        if key and key in t4_key_set:
            chosen_idx.extend(t4_df.index[t4_df["_endpoints_key"] == key].tolist())
        else:
            for j, lab in enumerate(t4_df["_interval_norm"]):
                if contains(comp, lab):
                    chosen_idx.append(j)
    chosen_idx = sorted(set(chosen_idx))
    if not chosen_idx:
        return pd.DataFrame(columns=["idx_in_or", "interval", "uncond_prob", "cond_prob_in_or"])

    sub = t4_df.loc[chosen_idx].copy().reset_index(drop=True)
    sub["idx_in_or"] = np.arange(1, len(sub) + 1)
    sub["interval"] = sub["_interval_norm"]
    sub["uncond_prob"] = sub["_prob"]
    total = sub["uncond_prob"].sum(skipna=True)
    sub["cond_prob_in_or"] = sub["uncond_prob"] / total if (pd.notna(total) and total > 0) else np.nan
    return sub[["idx_in_or", "interval", "uncond_prob", "cond_prob_in_or"]]

# -------------------- OR MINI-TABLES & SUMMARY --------------------
summary_rows: List[Dict] = []
for _, r in t7.iterrows():
    oid   = int(r["_or_id"])
    comps = r["_components_norm"]
    blk   = robust_or_block(comps, t4)
    blk.to_csv(OR_PROB_DIR / f"OR{oid}.csv", index=False)
    summary_rows.append({
        "OR_id": oid,
        "OR_interval": r["_or_label"],
        "multiplicity": int(r["_multiplicity"]),
        "components_count": len(comps),
        "components_union": " ∪ ".join(comps),
        "num_mi_in_or": int(len(blk)) if not blk.empty else 0,
        "sum_uncond_prob": float(blk["uncond_prob"].sum()) if not blk.empty else np.nan
    })

pd.DataFrame(summary_rows).sort_values("OR_id").to_csv(OR_PROB_DIR / "OR_summary.csv", index=False)

# ---------------------------- IMPUTATION ----------------------------
rng = np.random.default_rng(20250910)  # fixed seed for reproducibility
draw_records: List[Dict] = []

for _, r in t7.iterrows():
    oid   = int(r["_or_id"])
    mult  = int(r["_multiplicity"])
    blk   = robust_or_block(r["_components_norm"], t4)

    if blk.empty or mult <= 0:
        pd.DataFrame(columns=["draw_num", "mi_idx_in_or", "mi_interval", "prob_used"]).to_csv(
            OR_DRAW_DIR / f"OR{oid}_draws.csv", index=False
        )
        continue

    probs  = blk["cond_prob_in_or"].to_numpy(dtype=float)
    labels = blk["interval"].tolist()
    idxs   = blk["idx_in_or"].to_numpy()

    # normalize or fallback to uniform
    ps = np.nansum(probs)
    if not np.isfinite(ps) or ps <= 0:
        probs = np.ones_like(probs, dtype=float) / len(probs)
    else:
        probs = probs / ps

    draws = rng.choice(len(labels), size=mult, p=probs, replace=True)

    per_or = pd.DataFrame({
        "draw_num":     np.arange(1, mult+1),
        "mi_idx_in_or": idxs[draws],
        "mi_interval":  [labels[d] for d in draws],
        "prob_used":    probs[draws],
    })
    per_or.to_csv(OR_DRAW_DIR / f"OR{oid}_draws.csv", index=False)

    for row2 in per_or.itertuples(index=False):
        draw_records.append({
            "OR_id":        oid,
            "draw_num":     int(row2.draw_num),
            "mi_idx_in_or": int(row2.mi_idx_in_or),
            "mi_interval":  row2.mi_interval,
            "prob_used":    float(row2.prob_used),
        })

all_draws = pd.DataFrame(draw_records, columns=["OR_id","draw_num","mi_idx_in_or","mi_interval","prob_used"])
if not all_draws.empty:
    all_draws = all_draws.sort_values(["OR_id","draw_num"])
all_draws.to_csv(OR_DRAW_DIR / "all_draws.csv", index=False)

# ---------------------------- I*-TABLES ----------------------------
# Load Yk patterns if available; else default to all ones.
YMAP: Dict[str, np.ndarray] = {}
if YK_PATH.exists():
    yk = pd.read_csv(YK_PATH)
    label_col = yk.columns[0]
    ok_cols = [c for c in yk.columns if str(c) in TV_COLS]
    for _, rr in yk.iterrows():
        name = str(rr[label_col]).strip()
        m = re.search(r"(\d+)", name)
        if not m:
            continue
        key = f"Y{int(m.group(1))}"
        vec = rr[ok_cols].to_numpy(dtype=int)
        if vec.size == len(TV_COLS):
            YMAP[key] = vec
else:
    for k in range(1, 11):  # default Y1..Y10 = all ones
        YMAP[f"Y{k}"] = np.ones(len(TV_COLS), dtype=int)

# Mask for I(8 ≤ r_i): use mask_by_tv.csv if present; else default step at t_v >= 16
if MASK_PATH.exists():
    mdf = pd.read_csv(MASK_PATH)
    mask = None
    for _, rr in mdf.iterrows():
        vals = rr[[c for c in mdf.columns if str(c) in TV_COLS]].to_numpy(dtype=int)
        if vals.size == len(TV_COLS):
            mask = vals
            break
    mask_I_8_le_r = mask if mask is not None else np.array([1 if int(c) >= 16 else 0 for c in TV_COLS])
else:
    mask_I_8_le_r = np.array([1 if int(c) >= 16 else 0 for c in TV_COLS], dtype=int)

def yrow(ykey: str, masked: bool) -> np.ndarray:
    base = YMAP.get(ykey, np.zeros(len(TV_COLS), dtype=int))
    return (base & mask_I_8_le_r) if masked else base

def write_I(specs: List[Tuple[str, str, bool]], name: str):
    rows, idx = [], []
    for label, ykey, masked in specs:
        rows.append(yrow(ykey, masked))
        idx.append(label)
    df = pd.DataFrame(rows, columns=TV_COLS, index=idx)
    df.to_csv(I_DIR / f"{name}.csv")
    return df

# Row specs from the images
I1_specs = [
    ("if 1 ≤ i ≤ 4, Y1(t_v)", "Y1", False),
    ("if i = 1, I(8 ≤ r_i)Y1(t_v)", "Y1", True),
    ("if 2 ≤ i ≤ 4, I(8 ≤ r_i)Y1(t_v)", "Y1", True),
]
I2_specs = [
    ("if 1 ≤ i ≤ 4, Y2(t_v)", "Y2", False),
    ("if i = 5, Y2(t_v)", "Y2", False),
    ("if i = 1, I(8 ≤ r_i)Y2(t_v)", "Y2", True),
    ("if 2 ≤ i ≤ 4, I(8 ≤ r_i)Y2(t_v)", "Y2", True),
    ("if i = 5, I(8 ≤ r_i)Y2(t_v)", "Y2", True),
]
I3_specs = [
    ("if 1 ≤ i ≤ 4, Y3(t_v)", "Y3", False),
    ("if i = 5, Y3(t_v)", "Y3", False),
    ("if i = 6, Y3(t_v)", "Y3", False),
    ("if i = 1, I(8 ≤ r_i)Y3(t_v)", "Y3", True),
    ("if 2 ≤ i ≤ 4, I(8 ≤ r_i)Y3(t_v)", "Y3", True),
    ("if i = 5, I(8 ≤ r_i)Y3(t_v)", "Y3", True),
    ("if i = 6, I(8 ≤ r_i)Y3(t_v)", "Y3", True),
]
I4_specs = [
    ("if 1 ≤ i ≤ 4, Y7(t_v)", "Y7", False),
    ("if i = 5, Y7(t_v)", "Y7", False),
    ("if i = 6, Y7(t_v)", "Y7", False),
    ("if i = 1, I(8 ≤ r_i)Y7(t_v)", "Y7", True),
    ("if 2 ≤ i ≤ 4, I(8 ≤ r_i)Y7(t_v)", "Y7", True),
    ("if i = 5, I(8 ≤ r_i)Y7(t_v)", "Y7", True),
    ("if i = 6, I(8 ≤ r_i)Y7(t_v)", "Y7", True),
]
I5_specs = [
    ("if 1 ≤ i ≤ 4, Y5(t_v)", "Y5", False),
    ("if i = 5, Y5(t_v)", "Y5", False),
    ("if i = 6, Y5(t_v)", "Y5", False),
    ("if i = 1, I(8 ≤ r_i)Y5(t_v)", "Y5", True),
    ("if 2 ≤ i ≤ 4, I(8 ≤ r_i)Y5(t_v)", "Y5", True),
    ("if i = 5, I(8 ≤ r_i)Y5(t_v)", "Y5", True),
    ("if i = 6, I(8 ≤ r_i)Y5(t_v)", "Y5", True),
    ("if i = 7, Y5(t_v)", "Y5", False),  # explicitly shown as missing row in image
]
Y9_specs = [
    ("i = 1 or 2, Y9(t_v)", "Y9", False),
    ("i = 1 or 2, I(8 ≤ r_i)Y9(t_v)", "Y9", True),
]
Y10_specs = [
    ("1 ≤ i ≤ 3, Y10(t_v)", "Y10", False),
    ("i = 4, Y10(t_v)", "Y10", False),
    ("i = 5, Y10(t_v)", "Y10", False),
    ("i = 6, Y10(t_v)", "Y10", False),
]

write_I(I1_specs, "I1")
write_I(I2_specs, "I2")
write_I(I3_specs, "I3")
write_I(I4_specs, "I4")
write_I(I5_specs, "I5")
write_I(Y9_specs, "Y9_block")
write_I(Y10_specs, "Y10_block")

# --------------------------- DONE ---------------------------
print("✓ OR mini-tables     ->", OR_PROB_DIR / "OR_summary.csv")
print("✓ Imputed draws      ->", OR_DRAW_DIR / "all_draws.csv", f"(rows={len(all_draws)})")
print("✓ I*-tables (CSVs)   ->", I_DIR)
print("Notes:")
print(" - If OR blocks are empty, check that 'Sum breakdown' entries numerically contain MI bins in table4.")
print(" - Add 'yk_patterns.csv' and/or 'mask_by_tv.csv' to control Y_k(t_v) and I(8 ≤ r_i).")

✓ OR mini-tables     -> out/or_probs/OR_summary.csv
✓ Imputed draws      -> out/or_draws/all_draws.csv (rows=0)
✓ I*-tables (CSVs)   -> out/I_tables
Notes:
 - If OR blocks are empty, check that 'Sum breakdown' entries numerically contain MI bins in table4.
 - Add 'yk_patterns.csv' and/or 'mask_by_tv.csv' to control Y_k(t_v) and I(8 ≤ r_i).


In [28]:
I1_specs

[('if 1 ≤ i ≤ 4, Y1(t_v)', 'Y1', False),
 ('if i = 1, I(8 ≤ r_i)Y1(t_v)', 'Y1', True),
 ('if 2 ≤ i ≤ 4, I(8 ≤ r_i)Y1(t_v)', 'Y1', True)]