# ACE Evaluation (Dataset → Embeddings → Metrics)

This notebook loads the ACE events dataset, builds embeddings, rolls events up into episodes, and reproduces the retrieval and temporal metrics.
It assumes FAISS is installed (recommended for speed), but will fall back to NumPy if not.


Notes:
- This notebook produces dataset-only baselines (no live RAG events).
- If you use pool_candidates.csv, ensure event_id values match ACE dataset IDs (ace-<qid>) before running.


In [None]:
from pathlib import Path
import sys
import json
import random
import calendar
from datetime import datetime, timezone, timedelta

import numpy as np
import pandas as pd
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

# Locate the repository root: the nearest directory, starting at the working
# directory and walking upward, that contains a src/ folder.
_search_start = Path.cwd().resolve()
REPO_ROOT = next(
    (folder for folder in (_search_start, *_search_start.parents) if (folder / 'src').exists()),
    None,
)
if REPO_ROOT is None:
    raise RuntimeError('Could not find repo root containing src/')
# Make `import src....` resolvable from the notebook kernel.
sys.path.append(str(REPO_ROOT))

import src.storage_helpers as storage

# --- Data locations (all relative to the repository root) ---
DATA_DIR = REPO_ROOT / "data"
# JSONL dump of the dataset; written once by a later cell and read by the storage helpers.
EVENTS_PATH = DATA_DIR / "events" / "ace_events_h1_2025.jsonl"
EMBED_DIR = DATA_DIR / "embeddings"
EMBED_DIR.mkdir(parents=True, exist_ok=True)
EMBED_MODEL_NAME = "intfloat/e5-large-v2" # Used for current test/results
# EMBED_MODEL_NAME = "intfloat/e5-small-v2" # smaller/faster for limited resources - results may vary
# Prefixes prepended to the text before encoding: queries get the first,
# event passages (question + response) get the second.
EMBED_QUERY_PREFIX = "query: "
EMBED_PASSAGE_PREFIX = "passage: "
# File-system friendly model name, so each model gets its own embedding cache file.
EMBED_MODEL_SLUG = EMBED_MODEL_NAME.replace("/", "_").replace("-", "_")
EMBED_PATH = EMBED_DIR / f"event_embeddings_{EMBED_MODEL_SLUG}.npz"

# --- Evaluation inputs and outputs ---
EVAL_DIR = DATA_DIR / "eval"
INPUT_DIR = EVAL_DIR / "inputs"
OUTPUT_DIR = EVAL_DIR / "outputs"
METRICS_DIR = OUTPUT_DIR / "metrics"
AUDIT_DIR = OUTPUT_DIR / "audit"
MANIFEST_DIR = OUTPUT_DIR / "manifests"

# Create every directory the notebook reads from or writes to up front.
for path in [INPUT_DIR, METRICS_DIR, AUDIT_DIR, MANIFEST_DIR]:
    path.mkdir(parents=True, exist_ok=True)


# Query suites consumed by the temporal evaluation cells.
TEMPORAL_QUERY_PATH = INPUT_DIR / "temporal_queries.json"
RELATIVE_QUERY_PATH = INPUT_DIR / "relative_temporal_queries.json"

# Number of events encoded per model call when building embeddings.
BATCH_SIZE = 64
REBUILD_EMBEDDINGS = False # Set to True to force rebuilding the embeddings even if a cached file exists


In [None]:
# Reproducibility settings
# EVAL_SEED is recorded in the run manifest and seeds the shuffle of the
# evaluation queries in the retrieval-metrics cell.
EVAL_SEED = 13
print("Eval seed:", EVAL_SEED)


In [None]:
# Run manifest (reproducibility metadata)
import hashlib
import platform

RUN_MANIFEST_PATH = MANIFEST_DIR / "run_manifest.json"
manifest = {
    # datetime.utcnow() is deprecated (Python 3.12+). Build the same
    # "<naive ISO timestamp>Z" string from a timezone-aware UTC value instead.
    "timestamp_utc": datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z",
    "dataset_id": "anon-user-423/ACE",
    "dataset_config": "events",
    "embed_model": EMBED_MODEL_NAME,
    "seed": EVAL_SEED,
    "python": platform.python_version(),
}
# Only fingerprint the system prompt when an earlier cell defined one.
system_prompt = globals().get("SYSTEM_PROMPT", "")
if system_prompt:
    manifest["system_prompt_sha256"] = hashlib.sha256(system_prompt.encode()).hexdigest()
# Persist the manifest: the original cell built the dict but never wrote it
# (it re-assigned RUN_MANIFEST_PATH instead).
RUN_MANIFEST_PATH.write_text(json.dumps(manifest, indent=2, sort_keys=True) + "\n", encoding="utf-8")
print("Wrote", RUN_MANIFEST_PATH)


In [None]:
# Load dataset
# Fetches the "events" configuration (train split) of the ACE dataset via
# `datasets.load_dataset`; the dataset id/config match the run manifest.
ds = load_dataset("anon-user-423/ACE", "events", split="train")
# Quick visual check of size, schema and one example row.
print(ds)
print("Columns:", ds.column_names)
print("Sample:", ds[0])


In [None]:
# Write events to JSONL once (used by storage helpers)
EVENTS_PATH.parent.mkdir(parents=True, exist_ok=True)
if not EVENTS_PATH.exists():
    # Write to a sibling temp file and rename it once complete. Writing
    # straight to EVENTS_PATH would leave a truncated file behind if the cell
    # is interrupted, and every later run would accept it as a finished cache.
    tmp_events_path = EVENTS_PATH.with_name(EVENTS_PATH.name + ".tmp")
    with tmp_events_path.open("w", encoding="utf-8") as f:
        for row in ds:
            f.write(json.dumps(dict(row), ensure_ascii=True) + "\n")
    tmp_events_path.replace(EVENTS_PATH)
    print("Wrote", EVENTS_PATH)
else:
    print("Events file already exists:", EVENTS_PATH)


In [None]:
# Build episodes from dataset

# ------ Will take around an hour or two to run ------

episode_dir = DATA_DIR / "episodes"
# Re-use an existing roll-up whenever at least one episode JSON is already on disk.
episodes_ready = episode_dir.exists() and any(episode_dir.rglob("*.json"))
if episodes_ready:
    print("Episodes already present:", episode_dir)
else:
    summary = storage.roll_up_episodes(
        events_path=EVENTS_PATH,
        destination_dir=episode_dir,
        overwrite=True,
        show_progress=True,
    )
    print("Episode roll-up:", summary)


In [None]:
# Build or load embeddings
# The encoder is needed on both paths: to embed the events when (re)building
# the cache, and afterwards to embed incoming queries.
model = SentenceTransformer(EMBED_MODEL_NAME)
if not REBUILD_EMBEDDINGS and EMBED_PATH.exists():
    print("Using cached embeddings:", EMBED_PATH)
else:
    event_ids = []
    vector_chunks = []
    for offset in tqdm(range(0, len(ds), BATCH_SIZE), desc="Embedding events"):
        batch = ds[offset:offset + BATCH_SIZE]
        # One passage per event: prefix + question + response.
        passages = [
            f"{EMBED_PASSAGE_PREFIX}{question} {response}".strip()
            for question, response in zip(batch["question"], batch["response"])
        ]
        vector_chunks.append(
            model.encode(passages, convert_to_numpy=True, show_progress_bar=False)
        )
        event_ids.extend(batch["event_id"])
    vectors = np.vstack(vector_chunks)
    np.savez(EMBED_PATH, event_ids=np.array(event_ids), vectors=vectors)
    print("Wrote", EMBED_PATH)


def _embed_query(text):
    """Encode a single query string (with the query prefix) into one vector."""
    return model.encode([f"{EMBED_QUERY_PREFIX}{text}"], convert_to_numpy=True)[0]


# Hand the cached vectors and the query encoder to the storage helpers.
storage.load_event_embeddings_npz(EMBED_PATH)
storage.register_event_embedder(_embed_query)
print('Event embeddings loaded:', len(storage.EVENT_EMBEDDINGS))


In [None]:
# Dense-only sanity check (embeddings + embedder)
print("Event embeddings loaded:", len(storage.EVENT_EMBEDDINGS))
print("Event embedder set:", storage.EVENT_EMBEDDER is not None)
if storage.EVENT_EMBEDDINGS and storage.EVENT_EMBEDDER is not None:
    # Both pieces are present: try a single encode to confirm the embedder works.
    try:
        sample_vec = storage.EVENT_EMBEDDER("sanity check")
        print("Sample embedder output shape:", getattr(sample_vec, "shape", None))
    except Exception as exc:
        # Best-effort diagnostic only; report the failure and carry on.
        print("Embedder call failed:", exc)
elif not storage.EVENT_EMBEDDINGS:
    print("No event embeddings loaded: dense-only will be zero.")
else:
    print("No event embedder registered: dense-only will be zero.")


In [None]:
# Build retrieval index
# Base index built with RetrievalIndex.build's default weights. It is used by
# `retrieve_hybrid` below and as the fallback for all retriever variants when
# BUILD_VARIANT_INDEXES is False.
index = storage.RetrievalIndex.build(
    primary_path=DATA_DIR / "events" / "custom_events.jsonl",
    extra_paths=[EVENTS_PATH],
    show_progress=True,
)
print("Index events:", len(index.events))
# faiss_index is None when no FAISS index was created
# (presumably the NumPy fallback mentioned in the header; confirm in storage_helpers).
print("FAISS enabled:", index.faiss_index is not None)


In [None]:
# Build retriever variants
# When False, every variant below re-uses the base `index` instead of building its own.
BUILD_VARIANT_INDEXES = True
# Passed to RetrievalIndex.build as lexical_weight / time_weight / half_life_days.
LEXICAL_WEIGHT = 0.1
TIME_WEIGHT = 0.2
HALF_LIFE_DAYS = 30

def _build_index(lexical_weight, time_weight, use_dense=True):
    """Build a RetrievalIndex variant with the given scoring weights.

    Args:
        lexical_weight: Forwarded to ``RetrievalIndex.build`` as ``lexical_weight``.
        time_weight: Forwarded to ``RetrievalIndex.build`` as ``time_weight``.
        use_dense: When False, ``storage.EVENT_EMBEDDINGS`` is emptied for the
            duration of the build and restored afterwards (presumably so the
            index is built without dense vectors; confirm in storage_helpers).

    Returns:
        The index returned by ``storage.RetrievalIndex.build``.
    """
    saved = None
    if not use_dense:
        saved = dict(storage.EVENT_EMBEDDINGS)
        storage.EVENT_EMBEDDINGS.clear()
    try:
        return storage.RetrievalIndex.build(
            primary_path=DATA_DIR / "events" / "custom_events.jsonl",
            extra_paths=[EVENTS_PATH],
            half_life_days=HALF_LIFE_DAYS,
            lexical_weight=lexical_weight,
            time_weight=time_weight,
        )
    finally:
        # Restore the shared embedding table even if the build raises, so a
        # failed lexical-only build cannot leave it empty for the other variants.
        if saved is not None:
            storage.EVENT_EMBEDDINGS.update(saved)

if not BUILD_VARIANT_INDEXES:
    # Every variant shares the base index.
    idx_hybrid = idx_hybrid_no_time = idx_dense_only = idx_lexical_only = index
else:
    idx_hybrid = _build_index(LEXICAL_WEIGHT, TIME_WEIGHT, use_dense=True)
    idx_hybrid_no_time = _build_index(LEXICAL_WEIGHT, 0.0, use_dense=True)
    idx_dense_only = _build_index(0.0, 0.0, use_dense=True)
    idx_lexical_only = _build_index(LEXICAL_WEIGHT, 0.0, use_dense=False)


# Each retriever takes (query, k) and returns the top-k hits of one variant.
# The index is looked up by its global name at call time.
def _search_hybrid(q, k):
    return idx_hybrid.search(q, limit=k)


def _search_hybrid_no_time(q, k):
    return idx_hybrid_no_time.search(q, limit=k)


def _search_dense_only(q, k):
    return idx_dense_only.search(q, limit=k)


def _search_lexical_only(q, k):
    return idx_lexical_only.search(q, limit=k)


RETRIEVERS = {
    "hybrid": _search_hybrid,
    "hybrid_no_time": _search_hybrid_no_time,
    "dense_only": _search_dense_only,
    "bm25": _search_lexical_only,
}


In [None]:
# Dense coverage check (index + embeddings alignment)
dense_index = idx_dense_only
print("Index events:", len(dense_index.events))
print("Embeddings loaded:", len(storage.EVENT_EMBEDDINGS))
print("Dense ids:", len(dense_index.dense_ids))
print("Dense vectors shape:", getattr(dense_index.dense_vectors, "shape", None))
print("FAISS present:", dense_index.faiss_index is not None)

# Collect the ids of indexed events that have no cached embedding.
missing = []
for event in dense_index.events:
    event_id = event.get("event_id")
    if event_id not in storage.EVENT_EMBEDDINGS:
        missing.append(event_id)
print("Missing embeddings:", len(missing))
if missing:
    print("Missing sample:", missing[:5])


In [None]:
# Load eval queries
# Read the JSON suites explicitly as UTF-8 rather than with the platform's
# locale encoding, matching how this notebook writes its own JSONL file.
absolute_queries = json.loads(TEMPORAL_QUERY_PATH.read_text(encoding="utf-8"))
relative_queries = json.loads(RELATIVE_QUERY_PATH.read_text(encoding="utf-8"))
print("Absolute queries:", len(absolute_queries))
print("Relative queries:", len(relative_queries))

# Report the reference instant of the first relative query, if it has one.
first_relative = relative_queries[0] if relative_queries else None
if isinstance(first_relative, dict):
    # `or {}` tolerates a missing or null "relative" entry instead of raising.
    ref_iso = (first_relative.get('relative') or {}).get('reference_iso')
    if ref_iso:
        print('Relative reference:', ref_iso)


In [None]:
# Retrieval metrics (uses labeled pool if provided)
# Optional CSV of graded relevance labels; when absent, the temporal queries
# are used as fallback labels further down in this cell.
LABEL_PATH = INPUT_DIR / "pool_candidates.csv"
# Cut-offs at which hit/recall/nDCG are reported; the largest one is also the
# number of hits requested from each retriever.
TOP_K_LIST = [1, 3, 5, 10]

def hit_at_k(expected, ranked, k):
    """Return 1.0 if any of the first *k* ranked ids is in *expected*, else 0.0.

    An empty *expected* collection always scores 0.0.
    """
    if not expected:
        return 0.0
    for event_id in ranked[:k]:
        if event_id in expected:
            return 1.0
    return 0.0

def recall_at_k(expected, ranked, k):
    """Return the fraction of *expected* ids found in the first *k* ranked ids.

    *expected* is a set; an empty set scores 0.0.
    """
    if expected:
        retrieved = set(ranked[:k])
        return len(expected & retrieved) / len(expected)
    return 0.0

def ndcg_at_k(relevance, ranked, k):
    """Return nDCG@k with linear gains and a log2 rank discount.

    Args:
        relevance: Mapping of event id to graded relevance; ids that are not
            in the mapping contribute a gain of 0.0.
        ranked: Retrieved event ids, best first.
        k: Cut-off rank.

    Returns:
        DCG of the first *k* ranked ids divided by the DCG of the *k* largest
        relevance values, or 0.0 when that ideal DCG is not positive.
    """
    def _discounted_sum(values):
        # Rank r (1-based) is discounted by log2(r + 1).
        return sum(value / np.log2(rank + 1) for rank, value in enumerate(values, start=1))

    observed_gains = [relevance.get(event_id, 0.0) for event_id in ranked[:k]]
    best_gains = sorted(relevance.values(), reverse=True)[:k]
    ideal_dcg = _discounted_sum(best_gains)
    if ideal_dcg > 0:
        return _discounted_sum(observed_gains) / ideal_dcg
    return 0.0

def mrr(expected, ranked):
    """Return the reciprocal rank of the first ranked id that is in *expected*.

    Returns 0.0 when *expected* is empty or none of the ranked ids match.
    """
    if not expected:
        return 0.0
    first_match_rank = next(
        (rank for rank, event_id in enumerate(ranked, start=1) if event_id in expected),
        None,
    )
    return 0.0 if first_match_rank is None else 1.0 / first_match_rank

# Build QUERY_LABELS (query -> {event_id: graded relevance}) and EVAL_QUERIES
# (the queries that are actually scored).
if LABEL_PATH.exists():
    labels_df = pd.read_csv(LABEL_PATH)
    labels_df.columns = [c.strip() for c in labels_df.columns]
    if "label" not in labels_df.columns:
        raise KeyError("Missing label column in pool_candidates.csv")
    # Fail with a clear message instead of an opaque groupby KeyError.
    missing_columns = [c for c in ("query", "event_id") if c not in labels_df.columns]
    if missing_columns:
        raise KeyError(f"Missing {missing_columns} column(s) in pool_candidates.csv")
    QUERY_LABELS = {}
    # A (query, event) pair may be labeled more than once; keep the highest label.
    best_labels = labels_df.groupby(["query", "event_id"])["label"].max()
    for (query, event_id), max_label in best_labels.items():
        QUERY_LABELS.setdefault(query, {})[event_id] = float(max_label)
    # Filter out queries with no positive labels
    EVAL_QUERIES = sorted(
        q for q, rels in QUERY_LABELS.items() if any(rel > 0 for rel in rels.values())
    )
    print("Queries with at least one positive label:", len(EVAL_QUERIES))
    print("Using labeled pool:", len(EVAL_QUERIES), "queries")
else:
    print("Warning: pool_candidates.csv not found; using temporal queries as fallback labels.")
    # The first expected event of each temporal query becomes its only label.
    QUERY_LABELS = {
        item["query"]: {item["expected_events"][0]: 2.0}
        for item in absolute_queries
        if item.get("expected_events")
    }
    # Score only queries that have a label, each once. Unlabeled queries can
    # never score above zero and would drag every average down, which the
    # labeled branch above already avoids.
    EVAL_QUERIES = list(QUERY_LABELS)
    print("Using temporal queries as fallback labels:", len(EVAL_QUERIES))

# Deterministic shuffle, so the first-N subset used by the efficiency cell is reproducible.
rng = random.Random(EVAL_SEED)
rng.shuffle(EVAL_QUERIES)

# Score every retriever on every evaluation query and average per retriever.
rows = []
for name, retriever in RETRIEVERS.items():
    totals = {}  # metric name -> running sum over queries
    n_scored = 0
    for query in EVAL_QUERIES:
        relevance = QUERY_LABELS.get(query, {})
        expected = {eid for eid, rel in relevance.items() if rel > 0}
        retrieved = retriever(query, max(TOP_K_LIST))
        ranked = [eid for eid in (hit.get("event_id") for hit in retrieved) if eid]
        for k in TOP_K_LIST:
            scores_at_k = (
                (f"hit@{k}", hit_at_k(expected, ranked, k)),
                (f"recall@{k}", recall_at_k(expected, ranked, k)),
                (f"ndcg@{k}", ndcg_at_k(relevance, ranked, k)),
            )
            for metric_name, score in scores_at_k:
                totals[metric_name] = totals.get(metric_name, 0.0) + score
        totals["mrr"] = totals.get("mrr", 0.0) + mrr(expected, ranked)
        n_scored += 1
    # With no scored queries `totals` is empty, so the row only names the retriever.
    row = {"retriever": name}
    for metric_name, total in totals.items():
        row[metric_name] = total / n_scored
    rows.append(row)

retrieval_df = pd.DataFrame(rows)
retrieval_path = METRICS_DIR / "metrics_results.csv"
retrieval_df.to_csv(retrieval_path, index=False)
print("Wrote", retrieval_path)
display(retrieval_df)


In [None]:
# Strict metrics: label==1 only (exclude exact-match label==2)
strict_rows = []
for name, retriever in RETRIEVERS.items():
    totals = {}  # metric name -> running sum over the queries that were scored
    n_scored = 0
    for query in EVAL_QUERIES:
        relevance = QUERY_LABELS.get(query, {})
        # Only events labeled exactly 1 count as relevant here.
        expected = {eid for eid, rel in relevance.items() if rel == 1}
        if not expected:
            continue
        strict_relevance = {eid: relevance.get(eid, 0) for eid in expected}
        retrieved = retriever(query, max(TOP_K_LIST))
        ranked = [eid for eid in (hit.get('event_id') for hit in retrieved) if eid]
        for k in TOP_K_LIST:
            scores_at_k = (
                (f'hit@{k}', hit_at_k(expected, ranked, k)),
                (f'recall@{k}', recall_at_k(expected, ranked, k)),
                (f'ndcg@{k}', ndcg_at_k(strict_relevance, ranked, k)),
            )
            for metric_name, score in scores_at_k:
                totals[metric_name] = totals.get(metric_name, 0.0) + score
        totals['mrr'] = totals.get('mrr', 0.0) + mrr(expected, ranked)
        n_scored += 1
    row = {'retriever': name}
    for metric_name, total in totals.items():
        row[metric_name] = total / n_scored
    # Number of queries that had at least one label==1 event.
    row['n_queries'] = n_scored
    strict_rows.append(row)

strict_df = pd.DataFrame(strict_rows)
print('Strict metrics (label==1 only):')
display(strict_df)
strict_path = METRICS_DIR / 'metrics_results_strict.csv'
strict_df.to_csv(strict_path, index=False)
print('Wrote', strict_path)


In [None]:
# Efficiency metrics
import time
try:
    import psutil
except Exception:
    psutil = None

def _rss_mb():
    if psutil is None:
        return None
    return psutil.Process().memory_info().rss / (1024 * 1024)

# Time each retriever on (at most) the first 50 evaluation queries.
eff_rows = []
for name, retriever in RETRIEVERS.items():
    latencies = []    # per-call wall time in milliseconds
    rss_samples = []  # process RSS after each call, when psutil is available
    for query in EVAL_QUERIES[:50]:
        started = time.perf_counter()
        retriever(query, 10)
        latencies.append((time.perf_counter() - started) * 1000)
        rss_now = _rss_mb()
        if rss_now is not None:
            rss_samples.append(rss_now)
    if latencies:
        eff_rows.append({
            "retriever": name,
            "n_calls": len(latencies),
            "latency_p50_ms": float(np.percentile(latencies, 50)),
            "latency_p95_ms": float(np.percentile(latencies, 95)),
            "rss_mb_avg": float(np.mean(rss_samples)) if rss_samples else None,
        })

eff_df = pd.DataFrame(eff_rows)
eff_path = METRICS_DIR / "efficiency_metrics.csv"
eff_df.to_csv(eff_path, index=False)
print("Wrote", eff_path)
display(eff_df)


In [None]:
# Ablation table (R@10 / nDCG@10 / p50 ms)
# Display label -> key in RETRIEVERS.
# NOTE: "Hybrid (no episodes)" maps to the same retriever as
# "Hybrid (time decay)", so the two rows are identical; RETRIEVERS defines no
# separate episode-free variant.
ablation_variants = {
    "Dense-only": "dense_only",
    "Lexical-only (BM25)": "bm25",
    "Hybrid (no time)": "hybrid_no_time",
    "Hybrid (time decay)": "hybrid",
    "Hybrid (no episodes)": "hybrid",
}


def _indexed_by_retriever(frame):
    """Return *frame* indexed by its 'retriever' column, or None if it has none."""
    if "retriever" not in frame.columns:
        return None
    return frame.set_index("retriever")


def _metric_or_nan(indexed, key, column):
    """Return indexed.loc[key, column] as a float, or NaN when it is unavailable."""
    if indexed is None or key not in indexed.index or column not in indexed.columns:
        return float("nan")
    return float(indexed.loc[key, column])


# Index both tables once instead of on every lookup.
retrieval_by_name = _indexed_by_retriever(retrieval_df)
eff_by_name = _indexed_by_retriever(eff_df)

rows = []
for label, key in ablation_variants.items():
    if retrieval_by_name is None or key not in retrieval_by_name.index:
        continue
    rows.append({
        "variant": label,
        "r@10": _metric_or_nan(retrieval_by_name, key, "recall@10"),
        "ndcg@10": _metric_or_nan(retrieval_by_name, key, "ndcg@10"),
        # The efficiency table is empty when no query was timed; report NaN
        # instead of raising KeyError.
        "p50_ms": _metric_or_nan(eff_by_name, key, "latency_p50_ms"),
    })

ablation_df = pd.DataFrame(rows)
ablation_path = METRICS_DIR / "ablation_metrics.csv"
ablation_df.to_csv(ablation_path, index=False)
print("Wrote", ablation_path)
display(ablation_df)


In [None]:
# Build event lookup + episode mapping
EVENTS = list(storage.iter_normalized_events(EVENTS_PATH))
# event_id -> unix timestamp, used for window filtering and time-error metrics.
EVENT_TS_BY_ID = {}
for evt in EVENTS:
    EVENT_TS_BY_ID[evt["event_id"]] = float(evt["ts_unix"])

# event_id -> id of the first episode file (in traversal order) that lists it.
EVENT_EPISODE = {}
for path in (DATA_DIR / "episodes").rglob("*.json"):
    try:
        data = json.loads(path.read_text())
    except Exception:
        # Unreadable or malformed episode files are skipped on purpose.
        continue
    episode_id = data.get("episode_id") or path.stem
    for evt in data.get("events", []):
        eid = evt.get("event_id")
        if eid:
            EVENT_EPISODE.setdefault(eid, episode_id)

def retrieve_hybrid(query: str, top_k: int = 50):
    """Return up to *top_k* hits for *query* from the base ``index``.

    NOTE(review): this searches ``index`` (built with default weights), not
    ``idx_hybrid`` from the retriever-variants cell; confirm the two are meant
    to be equivalent, since results from both are reported as "hybrid".
    """
    return index.search(query, limit=top_k)


In [None]:
# Temporal evaluation helpers
def _window_from_spec(spec):
    start = datetime.fromisoformat(spec["start_iso"])
    end = datetime.fromisoformat(spec["end_iso"])
    label = spec.get("label", "window")
    return start, end, label

def _shift_months(dt, months):
    # months > 0 means shift backwards
    total = (dt.year * 12 + (dt.month - 1)) - months
    year = total // 12
    month = (total % 12) + 1
    last_day = calendar.monthrange(year, month)[1]
    day = min(dt.day, last_day)
    tz = dt.tzinfo
    return datetime(year, month, day, tzinfo=tz)

def _window_from_relative(spec):
    ref_iso = spec.get("reference_iso")
    if ref_iso:
        now = datetime.fromisoformat(ref_iso)
    else:
        now = datetime.now(tz=timezone.utc)
    tz = now.tzinfo or timezone.utc
    rel_type = spec.get("type")
    value = int(spec.get("value", 0))
    if rel_type == "days_ago":
        target = now - timedelta(days=value)
        start = datetime(target.year, target.month, target.day, tzinfo=tz)
        end = start + timedelta(days=1) - timedelta(seconds=1)
        label = f"days_ago:{value}"
        return start, end, label
    if rel_type == "weeks_ago":
        target = now - timedelta(weeks=value)
        week_start = datetime(target.year, target.month, target.day, tzinfo=tz) - timedelta(days=target.weekday())
        start = datetime(week_start.year, week_start.month, week_start.day, tzinfo=tz)
        end = start + timedelta(days=7) - timedelta(seconds=1)
        label = f"weeks_ago:{value}"
        return start, end, label
    if rel_type == "months_ago":
        target = _shift_months(now, value)
        start = datetime(target.year, target.month, 1, tzinfo=tz)
        last_day = calendar.monthrange(target.year, target.month)[1]
        end = datetime(target.year, target.month, last_day, tzinfo=tz) + timedelta(days=1) - timedelta(seconds=1)
        label = f"months_ago:{value}"
        return start, end, label
    if rel_type == "years_ago":
        target = _shift_months(now, value * 12)
        start = datetime(target.year, 1, 1, tzinfo=tz)
        end = datetime(target.year, 12, 31, tzinfo=tz) + timedelta(days=1) - timedelta(seconds=1)
        label = f"years_ago:{value}"
        return start, end, label
    raise ValueError(f"Unsupported relative spec: {spec}")

def _within_window(event_id: str, start_ts: float, end_ts: float) -> bool:
    """Return True if the event's timestamp lies in ``[start_ts, end_ts]``.

    Events without an entry in ``EVENT_TS_BY_ID`` are treated as outside.
    """
    event_ts = EVENT_TS_BY_ID.get(event_id)
    return event_ts is not None and start_ts <= event_ts <= end_ts

def _episode_hit(expected_ids, ranked):
    if not expected_ids:
        return 0
    expected_episode = EVENT_EPISODE.get(expected_ids[0])
    if not expected_episode:
        return 0
    for eid in ranked:
        if EVENT_EPISODE.get(eid) == expected_episode:
            return 1
    return 0

def evaluate_temporal(queries, top_k_prec=(1, 3), top_k_recall=3):
    """Score window-filtered retrieval from ``retrieve_hybrid`` on temporal queries.

    For each query that lists ``expected_events``, its time window is resolved
    (absolute ``window`` spec if present, otherwise the ``relative`` spec), 50
    candidates are retrieved, candidates outside the window are dropped, and
    the survivors are scored.

    Args:
        queries: Iterable of query dicts with ``query``, ``expected_events``
            and either ``window`` or ``relative``.
        top_k_prec: Cut-offs for citation precision.
        top_k_recall: Cut-off for citation recall, episode hit and time error.

    Returns:
        Dict with ``win_acc``, ``cite_r@<k>``, ``episode_hit``,
        ``med_abs_err_hours``, ``n_queries`` and one ``cite_p@<k>`` per
        precision cut-off, or ``{}`` when no query had expected events.
    """
    precision_totals = {k: 0.0 for k in top_k_prec}
    recall_total = 0.0
    answered_in_window = 0
    episode_matches = 0
    top1_errors_hours = []
    n_scored = 0

    for item in queries:
        query = item["query"]
        gold = item.get("expected_events", [])
        if not gold:
            continue

        if "window" in item:
            start, end, _ = _window_from_spec(item["window"])
        else:
            start, end, _ = _window_from_relative(item["relative"])
        lo, hi = start.timestamp(), end.timestamp()

        # Keep only retrieved events that have an id and fall inside the window.
        candidate_ids = (hit.get("event_id") for hit in retrieve_hybrid(query, 50))
        kept = [eid for eid in candidate_ids if eid and _within_window(eid, lo, hi)]
        top_recall = kept[:top_k_recall]
        gold_set = set(gold)

        for k in precision_totals:
            top_k = kept[:k]
            if top_k:
                precision_totals[k] += len(gold_set.intersection(top_k)) / k
        recall_total += len(gold_set.intersection(top_recall)) / len(gold_set)
        episode_matches += _episode_hit(gold, top_recall)

        if top_recall:
            # At least one in-window result survived the filter.
            answered_in_window += 1
            ts_top = EVENT_TS_BY_ID.get(top_recall[0])
            ts_gold = EVENT_TS_BY_ID.get(gold[0])
            if ts_top is not None and ts_gold is not None:
                top1_errors_hours.append(abs(ts_top - ts_gold) / 3600.0)
        n_scored += 1

    if n_scored == 0:
        return {}
    metrics = {
        "win_acc": answered_in_window / n_scored,
        f"cite_r@{top_k_recall}": recall_total / n_scored,
        "episode_hit": episode_matches / n_scored,
        "med_abs_err_hours": float(np.median(top1_errors_hours)) if top1_errors_hours else 0.0,
        "n_queries": n_scored,
    }
    for k in top_k_prec:
        metrics[f"cite_p@{k}"] = precision_totals[k] / n_scored
    return metrics


In [None]:
# Run temporal metrics
abs_metrics = evaluate_temporal(absolute_queries, top_k_prec=(1, 3), top_k_recall=3)
rel_metrics = evaluate_temporal(relative_queries, top_k_prec=(1, 3), top_k_recall=3)

# One-row tables, tagged with the retriever and the suite they describe.
abs_df = pd.DataFrame([{**abs_metrics, "retriever": "hybrid", "suite": "absolute"}])
rel_df = pd.DataFrame([{**rel_metrics, "retriever": "hybrid", "suite": "relative"}])

abs_path = METRICS_DIR / "temporal_metrics_absolute.csv"
rel_path = METRICS_DIR / "temporal_metrics_relative.csv"
temporal_outputs = ((abs_df, abs_path), (rel_df, rel_path))

# Write both files first, then report and show them.
for frame, target in temporal_outputs:
    frame.to_csv(target, index=False)

for frame, target in temporal_outputs:
    print("Wrote", target)
    display(frame)


In [None]:
# Temporal baselines: no-window + counterfactual time
COUNTERFACTUAL_SHIFT_DAYS = 30

def _shift_window_by_days(start, end, days):
    return start + timedelta(days=days), end + timedelta(days=days)

def evaluate_temporal_variant(queries, top_k_prec=(1, 3), top_k_recall=3, apply_window=True, counterfactual_days=None):
    """Temporal evaluation with optional window filtering and window shifting.

    Args:
        queries: Iterable of query dicts with ``query``, ``expected_events``
            and either ``window`` or ``relative``.
        top_k_prec: Cut-offs for citation precision.
        top_k_recall: Cut-off for citation recall, episode hit and time error.
        apply_window: When True, score only retrieved events inside the
            window; when False, score the raw ranking.
        counterfactual_days: If truthy, shift the window by this many days
            before it is used.

    Returns:
        Dict with the same keys as ``evaluate_temporal``, or ``{}`` when no
        query had expected events. Here ``win_acc`` counts queries whose raw
        top ``top_k_recall`` results contain at least one in-window event.
    """
    precision_totals = {k: 0.0 for k in top_k_prec}
    recall_total = 0.0
    window_hit_count = 0
    episode_matches = 0
    top1_errors_hours = []
    n_scored = 0

    for item in queries:
        query = item["query"]
        gold = item.get("expected_events", [])
        if not gold:
            continue

        if "window" in item:
            start, end, _ = _window_from_spec(item["window"])
        else:
            start, end, _ = _window_from_relative(item["relative"])
        if counterfactual_days:
            start, end = _shift_window_by_days(start, end, counterfactual_days)
        lo, hi = start.timestamp(), end.timestamp()

        raw_ids = [eid for eid in (hit.get("event_id") for hit in retrieve_hybrid(query, 50)) if eid]
        windowed_ids = [eid for eid in raw_ids if _within_window(eid, lo, hi)]
        scored_ids = windowed_ids if apply_window else raw_ids
        top_recall = scored_ids[:top_k_recall]

        # Window accuracy always looks at the unfiltered top results.
        if any(_within_window(eid, lo, hi) for eid in raw_ids[:top_k_recall]):
            window_hit_count += 1

        gold_set = set(gold)
        for k in precision_totals:
            top_k = scored_ids[:k]
            if top_k:
                precision_totals[k] += len(gold_set.intersection(top_k)) / k
        recall_total += len(gold_set.intersection(top_recall)) / len(gold_set)
        episode_matches += _episode_hit(gold, top_recall)

        if top_recall:
            ts_top = EVENT_TS_BY_ID.get(top_recall[0])
            ts_gold = EVENT_TS_BY_ID.get(gold[0])
            if ts_top is not None and ts_gold is not None:
                top1_errors_hours.append(abs(ts_top - ts_gold) / 3600.0)
        n_scored += 1

    if n_scored == 0:
        return {}
    metrics = {
        "win_acc": window_hit_count / n_scored,
        f"cite_r@{top_k_recall}": recall_total / n_scored,
        "episode_hit": episode_matches / n_scored,
        "med_abs_err_hours": float(np.median(top1_errors_hours)) if top1_errors_hours else 0.0,
        "n_queries": n_scored,
    }
    for k in top_k_prec:
        metrics[f"cite_p@{k}"] = precision_totals[k] / n_scored
    return metrics

# No-window baseline (raw top-k)
abs_nowindow = evaluate_temporal_variant(absolute_queries, top_k_prec=(1, 3), top_k_recall=3, apply_window=False)
rel_nowindow = evaluate_temporal_variant(relative_queries, top_k_prec=(1, 3), top_k_recall=3, apply_window=False)

# Counterfactual time (shifted window)
abs_counter = evaluate_temporal_variant(
    absolute_queries, top_k_prec=(1, 3), top_k_recall=3,
    apply_window=True, counterfactual_days=COUNTERFACTUAL_SHIFT_DAYS,
)
rel_counter = evaluate_temporal_variant(
    relative_queries, top_k_prec=(1, 3), top_k_recall=3,
    apply_window=True, counterfactual_days=COUNTERFACTUAL_SHIFT_DAYS,
)

# One-row tables, tagged with the retriever and the suite they describe.
abs_nowindow_df = pd.DataFrame([{**abs_nowindow, "retriever": "hybrid", "suite": "absolute_nowindow"}])
rel_nowindow_df = pd.DataFrame([{**rel_nowindow, "retriever": "hybrid", "suite": "relative_nowindow"}])
abs_counter_df = pd.DataFrame([{**abs_counter, "retriever": "hybrid", "suite": "absolute_counterfactual"}])
rel_counter_df = pd.DataFrame([{**rel_counter, "retriever": "hybrid", "suite": "relative_counterfactual"}])

abs_nowindow_path = METRICS_DIR / "temporal_metrics_absolute_nowindow.csv"
rel_nowindow_path = METRICS_DIR / "temporal_metrics_relative_nowindow.csv"
abs_counter_path = METRICS_DIR / "temporal_metrics_absolute_counterfactual.csv"
rel_counter_path = METRICS_DIR / "temporal_metrics_relative_counterfactual.csv"

baseline_outputs = (
    (abs_nowindow_df, abs_nowindow_path),
    (rel_nowindow_df, rel_nowindow_path),
    (abs_counter_df, abs_counter_path),
    (rel_counter_df, rel_counter_path),
)

# Write all files, then report all paths, then show all tables.
for frame, target in baseline_outputs:
    frame.to_csv(target, index=False)

for _, target in baseline_outputs:
    print("Wrote", target)
for frame, _ in baseline_outputs:
    display(frame)


In [None]:
# Temporal split (recent vs old) using hybrid retriever
# Queries whose event is at most RECENCY_DAYS days older than the reference
# date count as "recent"; the rest are "old".
RECENCY_DAYS = 60
RECENT_REFERENCE_ISO = "2025-07-01T00:00:00+00:00"
RECENT_REFERENCE = datetime.fromisoformat(RECENT_REFERENCE_ISO)

now = RECENT_REFERENCE
recent_queries = []
old_queries = []
for item in absolute_queries:
    ts = item.get("event_timestamp")
    if ts:
        # Accept a trailing "Z" as UTC.
        ts_dt = datetime.fromisoformat(ts.replace("Z", "+00:00"))
        age_days = (now.date() - ts_dt.date()).days
        bucket = recent_queries if age_days <= RECENCY_DAYS else old_queries
        bucket.append(item)

def evaluate_temporal_with_retriever(queries, retriever, top_k_prec=(1, 3), top_k_recall=3):
    """Score window-filtered retrieval for an arbitrary retriever.

    Same procedure as ``evaluate_temporal``, but the candidates come from
    ``retriever(query, 50)`` instead of ``retrieve_hybrid``.

    Args:
        queries: Iterable of query dicts with ``query``, ``expected_events``
            and either ``window`` or ``relative``.
        retriever: Callable ``(query, k)`` returning hits that expose
            ``event_id`` through ``.get``.
        top_k_prec: Cut-offs for citation precision.
        top_k_recall: Cut-off for citation recall, episode hit and time error.

    Returns:
        Dict with ``win_acc``, ``cite_r@<k>``, ``episode_hit``,
        ``med_abs_err_hours``, ``n_queries`` and one ``cite_p@<k>`` per
        precision cut-off, or ``{}`` when no query had expected events.
    """
    precision_totals = {k: 0.0 for k in top_k_prec}
    recall_total = 0.0
    answered_in_window = 0
    episode_matches = 0
    top1_errors_hours = []
    n_scored = 0

    for item in queries:
        query = item["query"]
        gold = item.get("expected_events", [])
        if not gold:
            continue

        if "window" in item:
            start, end, _ = _window_from_spec(item["window"])
        else:
            start, end, _ = _window_from_relative(item["relative"])
        lo, hi = start.timestamp(), end.timestamp()

        # Keep only retrieved events that have an id and fall inside the window.
        candidate_ids = (hit.get("event_id") for hit in retriever(query, 50))
        kept = [eid for eid in candidate_ids if eid and _within_window(eid, lo, hi)]
        top_recall = kept[:top_k_recall]
        gold_set = set(gold)

        for k in precision_totals:
            top_k = kept[:k]
            if top_k:
                precision_totals[k] += len(gold_set.intersection(top_k)) / k
        recall_total += len(gold_set.intersection(top_recall)) / len(gold_set)
        episode_matches += _episode_hit(gold, top_recall)

        if top_recall:
            # At least one in-window result survived the filter.
            answered_in_window += 1
            ts_top = EVENT_TS_BY_ID.get(top_recall[0])
            ts_gold = EVENT_TS_BY_ID.get(gold[0])
            if ts_top is not None and ts_gold is not None:
                top1_errors_hours.append(abs(ts_top - ts_gold) / 3600.0)
        n_scored += 1

    if n_scored == 0:
        return {}
    metrics = {
        "win_acc": answered_in_window / n_scored,
        f"cite_r@{top_k_recall}": recall_total / n_scored,
        "episode_hit": episode_matches / n_scored,
        "med_abs_err_hours": float(np.median(top1_errors_hours)) if top1_errors_hours else 0.0,
        "n_queries": n_scored,
    }
    for k in top_k_prec:
        metrics[f"cite_p@{k}"] = precision_totals[k] / n_scored
    return metrics

# Evaluate the recent and old query buckets with the hybrid retriever.
hybrid_retriever = RETRIEVERS["hybrid"]
recent_metrics = evaluate_temporal_with_retriever(recent_queries, hybrid_retriever, top_k_prec=(1, 3), top_k_recall=3)
old_metrics = evaluate_temporal_with_retriever(old_queries, hybrid_retriever, top_k_prec=(1, 3), top_k_recall=3)

split_rows = [
    {**recent_metrics, "retriever": "hybrid", "suite": "recent"},
    {**old_metrics, "retriever": "hybrid", "suite": "old"},
]
split_df = pd.DataFrame(split_rows)
split_path = METRICS_DIR / "temporal_metrics_split.csv"
split_df.to_csv(split_path, index=False)
print("Wrote", split_path)
display(split_df)


## Human audit template (manual labeling)
Generate an audit sheet for the existing queries. Fill correctness/citation scores (0/1/2).


In [None]:
# Build a human-audit template (no model inference here)


AUDIT_N = None  # set an int to sample, or None to use all
AUDIT_SEED = 13
# Two sheets with identical content, one per rater.
AUDIT_OUT = AUDIT_DIR / 'audit_sheet_1.csv'
AUDIT_OUT_RATER2 = AUDIT_DIR / 'audit_sheet_2.csv'
# Optional curated question list; preferred over the other query sources when it exists.
AUDIT_QUESTIONS = INPUT_DIR / 'audit_questions.json'

import random
# Seeds the module-level RNG used by random.sample below, so the sampled
# subset (when AUDIT_N is set) is reproducible.
random.seed(AUDIT_SEED)

def _is_low_quality_response(text: str) -> bool:
    if not text:
        return True
    t = text.strip().lower()
    if not t:
        return True
    if t.startswith('asker comments'):
        return True
    if t.startswith('a: asker comments'):
        return True
    return False

queries = []
# Prefer curated audit questions if available
if AUDIT_QUESTIONS.exists():
    data = json.loads(AUDIT_QUESTIONS.read_text())
    # Build event_id -> response lookup for filtering (first occurrence wins)
    event_lookup = {}
    events_dir = DATA_DIR / 'events'
    event_files = list(events_dir.glob('*.jsonl')) if events_dir.exists() else []
    for path in event_files:
        for evt in storage.iter_normalized_events(path):
            eid = evt.get('event_id')
            if eid:
                event_lookup.setdefault(eid, evt.get('response') or '')

    for item in data:
        if not isinstance(item, dict):
            # Plain entries are used as the query text itself.
            q = str(item)
        else:
            q = item.get('query')
            eid = item.get('event_id')
            # Skip questions whose linked event has a missing or low-quality response.
            if eid and _is_low_quality_response(event_lookup.get(eid, '')):
                continue
        if q:
            queries.append(q)
    queries = list(dict.fromkeys(queries))
# Otherwise, fall back to labeled pool queries if available
elif (INPUT_DIR / 'pool_candidates.csv').exists():
    df_pool = pd.read_csv(INPUT_DIR / 'pool_candidates.csv')
    queries = list(dict.fromkeys(df_pool['query'].dropna().tolist()))
# Fallback to temporal queries
else:
    if 'absolute_queries' in globals():
        queries.extend(entry['query'] for entry in absolute_queries)
    if 'relative_queries' in globals():
        queries.extend(entry['query'] for entry in relative_queries)
    # Drop duplicates while keeping first-seen order.
    queries = list(dict.fromkeys(queries))

# Optionally down-sample (reproducible through the seed set above).
if AUDIT_N is not None and len(queries) > AUDIT_N:
    queries = random.sample(queries, AUDIT_N)

# Every column except 'query' starts empty and is filled in later.
blank_audit_columns = [
    'response',
    'cited_ids',
    'correctness',  # 0/1/2
    'citation_usefulness',  # 0/1/2
    'notes',
    'cited_events_text',
    'cited_evidence_taken',
    'cited_evidence_suggested',
    'cited_action_taken',
    'cited_action_suggested',
]
audit_sheet = {'query': queries}
for column in blank_audit_columns:
    audit_sheet[column] = [''] * len(queries)
audit_df = pd.DataFrame(audit_sheet)
audit_df.to_csv(AUDIT_OUT, index=False)
# Create a second sheet with the same queries (for second rater)
audit_df[['query', *blank_audit_columns]].to_csv(AUDIT_OUT_RATER2, index=False)
print('Wrote', AUDIT_OUT)
print('Wrote', AUDIT_OUT_RATER2)


In [None]:
# Optional: auto-fill responses + cited_ids for human audit
# RUN_AUDIT_MODEL is currently True, i.e. model responses are requested by
# default (presumably set it to False to skip generation; confirm against
# the code that reads this flag).
import os
RUN_AUDIT_MODEL = True
# Each setting below can be overridden through the named environment variable.
AUDIT_MODEL_NAME = os.getenv('ACE_MODEL_NAME', 'Qwen/Qwen2.5-3B-Instruct')
AUDIT_DEVICE = os.getenv('ACE_DEVICE', 'auto')
AUDIT_MAX_NEW_TOKENS = int(os.getenv('ACE_MAX_NEW_TOKENS', '400'))
AUDIT_TEMPERATURE = float(os.getenv('ACE_TEMPERATURE', '0.2'))
AUDIT_TOP_P = float(os.getenv('ACE_TOP_P', '0.9'))
AUDIT_TOP_K = int(os.getenv('ACE_TOP_K', '5'))
# Environment-driven alternative, currently disabled:
# AUDIT_EVENT_CONTEXT_BUDGET_TOKENS = int(os.getenv('ACE_EVENT_CONTEXT_BUDGET_TOKENS', '0'))
AUDIT_EVENT_CONTEXT_BUDGET_TOKENS=None  # Uses default 1200 in summarize_events()

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from src.ace_agent import detect_time_request, summarize_events

# System prompt used when the audit loop decides to answer from retrieved
# events (should_use_memory() returned True): restricts the model to the context.
MEMORY_SYSTEM_PROMPT = (
    'You are a personal assistant with access to prior interactions. '
    'Answer only using the provided context; if the answer is not in memory, say so.'
)
# System prompt used otherwise: the model may answer from general knowledge.
GENERAL_SYSTEM_PROMPT = (
    'You are a helpful assistant. Use general knowledge when needed. '
    'If memory context is provided, prefer it and cite it.'
)

def pick_device(choice: str) -> str:
    """Resolve the torch device name to run the audit model on.

    Args:
        choice: Either an explicit device string, which is returned untouched,
            or ``'auto'`` to probe the available backends.

    Returns:
        For ``'auto'``: ``'mps'`` if available, otherwise ``'cuda'`` if
        available, otherwise ``'cpu'``. For anything else: ``choice`` itself.
    """
    if choice == 'auto':
        # Probe order matters: MPS is preferred over CUDA, CPU is the last resort.
        if torch.backends.mps.is_available():
            choice = 'mps'
        elif torch.cuda.is_available():
            choice = 'cuda'
        else:
            choice = 'cpu'
    return choice

def build_chat_prompt(messages):
    """Turn a list of chat messages into a single prompt string.

    Args:
        messages: Sequence of dicts with ``'role'`` and ``'content'`` keys.

    Returns:
        The output of the global tokenizer's ``apply_chat_template`` (untokenized,
        with the generation prompt appended) when the tokenizer has that method;
        otherwise a plain ``role: content`` transcript ending in ``assistant:``.
    """
    if not hasattr(tokenizer, 'apply_chat_template'):
        # Fallback transcript format for tokenizers without chat-template support.
        turns = [f"{m['role']}: {m['content']}" for m in messages]
        return '\n'.join(turns) + '\nassistant:'
    return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

def generate_text(messages):
    """Generate the model's reply to a chat-style message list.

    Uses the module-level ``tokenizer`` and ``model`` and the ``AUDIT_*``
    generation settings.

    Args:
        messages: Sequence of dicts with ``'role'`` and ``'content'`` keys, as
            accepted by ``build_chat_prompt``.

    Returns:
        The decoded continuation only (prompt tokens removed, special tokens
        skipped), stripped of surrounding whitespace.
    """
    prompt = build_chat_prompt(messages)
    encoded = tokenizer(prompt, return_tensors='pt')
    encoded = {key: value.to(model.device) for key, value in encoded.items()}

    # Some tokenizers define no pad token; fall back to EOS so generate()
    # does not receive pad_token_id=None.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    gen_kwargs = {
        'max_new_tokens': AUDIT_MAX_NEW_TOKENS,
        'repetition_penalty': 1.05,
        'pad_token_id': pad_token_id,
        'eos_token_id': tokenizer.eos_token_id,
    }
    if AUDIT_TEMPERATURE > 0:
        gen_kwargs.update(
            do_sample=True,
            temperature=AUDIT_TEMPERATURE,
            top_p=AUDIT_TOP_P,
        )
    else:
        # Sampling requires a strictly positive temperature; ACE_TEMPERATURE=0 is
        # the usual way to ask for deterministic output, so decode greedily instead
        # of letting generate() reject the setting.
        gen_kwargs['do_sample'] = False

    with torch.no_grad():
        generated = model.generate(**encoded, **gen_kwargs)

    # generate() returns prompt + continuation; keep only the new tokens.
    prompt_len = encoded['input_ids'].shape[-1]
    gen_tokens = generated[:, prompt_len:]
    return tokenizer.decode(gen_tokens[0], skip_special_tokens=True).strip()

def should_use_memory(query, window, hits, threshold=0.35):
    """Decide whether a query should be answered from retrieved memory.

    Args:
        query: The user query text.
        window: A detected time window, or ``None`` if the query has none.
        hits: Retrieved event hits (dicts); only the first hit's ``'score'`` is read.
        threshold: Minimum score of the first hit that counts as relevant.

    Returns:
        ``True`` when a time window was detected, when the lowercased query
        contains any memory cue phrase, or when the first hit scores at least
        ``threshold``; ``False`` otherwise (including when there are no hits).

    Note:
        Cue phrases are matched as plain substrings, so e.g. ``'ago'`` also
        matches inside longer words.
    """
    text = query.lower()
    cue_phrases = (
        'remember', 'recall', 'what did i', 'what did we', 'what did you',
        'earlier', 'previous', 'last time', 'yesterday', 'today', 'last week',
        'last month', 'last year', 'ago',
    )
    if window is not None or any(cue in text for cue in cue_phrases):
        return True
    if hits:
        # A hit without a score is treated as scoring 0.0.
        return hits[0].get('score', 0.0) >= threshold
    return False


def summarize_events_audit(event_hits):
    """Render retrieved events as a bulleted context block for the audit prompt.

    Each hit becomes one line of the form
    ``- (score) [timestamp] Q: question | A: response`` followed by optional
    `` | label: text`` segments. Evidence fields take precedence: the action
    fields are only shown for a hit that has no evidence text at all.

    Args:
        event_hits: Iterable of dict-like hits. Keys read: ``score``,
            ``timestamp``, ``question``, ``response``, ``evidence_taken``,
            ``evidence_suggested``, ``action_taken``, ``action_suggested``.
            The evidence/action values may be strings, lists, or missing.

    Returns:
        The rendered lines joined by newlines, or ``"- None"`` when there are
        no hits.
    """

    def _flatten(value):
        # Lists are joined with "; " after dropping blank items. Any other
        # falsy value becomes ''. Remaining scalars are stringified so a
        # non-string value no longer raises on .strip().
        if isinstance(value, list):
            parts = (str(item).strip() for item in value)
            return "; ".join(part for part in parts if part).strip()
        if not value:
            return ''
        return str(value).strip()

    lines = []
    for hit in event_hits:
        timestamp = hit.get('timestamp', 'unknown')
        question = (hit.get('question') or '').strip()
        response = (hit.get('response') or '').strip()

        evidence_taken = _flatten(hit.get('evidence_taken'))
        evidence_suggested = _flatten(hit.get('evidence_suggested'))

        if evidence_taken or evidence_suggested:
            labelled = [
                ("Evidence taken", evidence_taken),
                ("Evidence suggested", evidence_suggested),
            ]
        else:
            labelled = [
                ("Action taken", _flatten(hit.get('action_taken'))),
                ("Action suggested", _flatten(hit.get('action_suggested'))),
            ]
        extras = [f"{label}: {text}" for label, text in labelled if text]
        extra_text = f" | {' | '.join(extras)}" if extras else ""

        # A missing or None score is shown as 0.000 instead of breaking the format spec.
        score = hit.get('score') or 0.0
        lines.append(
            f"- ({score:.3f}) [{timestamp}] "
            f"Q: {question} | A: {response}{extra_text}"
        )
    return "\n".join(lines) if lines else "- None"

# Fill the 'response' and 'cited_ids' columns of the audit sheet by running the
# audit model over every query, then write the sheet back to AUDIT_OUT.
if RUN_AUDIT_MODEL:
    device = pick_device(AUDIT_DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(AUDIT_MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(AUDIT_MODEL_NAME)
    model.to(device)
    model.eval()

    audit_df = pd.read_csv(AUDIT_OUT)
    responses = []
    cited_list = []
    for _, row in audit_df.iterrows():
        query = str(row['query'])
        # Timezone-aware "now" in UTC. datetime.utcnow() is deprecated since
        # Python 3.12; datetime.now(timezone.utc) yields the same aware value.
        # NOTE(review): relative phrases are resolved against the wall clock at
        # run time, not against the dataset's period -- confirm this is intended.
        timestamp = datetime.now(timezone.utc)
        window = detect_time_request(query, now=timestamp)
        time_range = None
        if window is not None:
            time_range = (window.start.timestamp(), window.end.timestamp())
        event_hits = storage.retrieve(query, limit=AUDIT_TOP_K, time_window=time_range)
        # Force audit retrieval to use only the main dataset file
        # audit_index = storage.RetrievalIndex.build(
        #     primary_path=storage.STORAGE.normalized_events / "ace_events_h1_2025.jsonl",
        #     extra_paths=[]
        # )
        # event_hits = audit_index.search(query, limit=AUDIT_TOP_K, time_window=time_range)

        use_memory = should_use_memory(query, window, event_hits)
        event_context = summarize_events_audit(event_hits) if use_memory else '- None'
        window_note = ''
        if window is not None:
            window_note = f"\nTime window: {window.label} ({window.start.isoformat()} -> {window.end.isoformat()})\n"
        if use_memory:
            context = f"Relevant past events:\n{event_context}\n{window_note}\nUser question: {query}"
        else:
            context = f"User question: {query}"
        system_prompt = MEMORY_SYSTEM_PROMPT if use_memory else GENERAL_SYSTEM_PROMPT
        messages = [
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': context},
        ]
        response = generate_text(messages)
        responses.append(response)
        # NOTE(review): ids of all retrieved hits are recorded even when
        # use_memory is False and the events were never shown to the model.
        # str() keeps the join from raising if an event id is not a string.
        cited_ids = [str(h['event_id']) for h in event_hits if h.get('event_id')]
        cited_list.append(','.join(cited_ids))

    audit_df['response'] = responses
    audit_df['cited_ids'] = cited_list
    audit_df.to_csv(AUDIT_OUT, index=False)
    print('Updated', AUDIT_OUT)


In [None]:
# Build audit context with cited event text
# Destination CSV for the audit sheet enriched with the text of the cited events.
AUDIT_CONTEXT_OUT = AUDIT_DIR / 'audit_context.csv'
# Load events from all JSONL files under data/events

def _normalize_field(value):
    if value is None:
        return ''
    if isinstance(value, list):
        value = '; '.join(str(item).strip() for item in value if str(item).strip())
    return str(value).strip()

def _unique_join(items, sep=' | '):
    seen = set()
    out = []
    for item in items:
        if item and item not in seen:
            seen.add(item)
            out.append(item)
    return sep.join(out)

# Index every event found in data/events/*.jsonl by event_id, keeping the
# normalised text fields needed to show raters what each cited event said.
# NOTE(review): an earlier cell appears to read event_lookup.get(eid, '') as a
# response value; this cell rebinds the name to a dict of dicts -- confirm
# the earlier cell is not re-run afterwards.
event_lookup = {}
events_dir = DATA_DIR / 'events'
# Sorted so that "first occurrence wins" (below) resolves duplicate event ids
# the same way on every machine instead of depending on directory listing order.
event_files = sorted(events_dir.glob('*.jsonl')) if events_dir.exists() else []
for path in event_files:
    for evt in storage.iter_normalized_events(path):
        eid = evt.get('event_id')
        # Skip events without an id and ids that were already indexed.
        if not eid or eid in event_lookup:
            continue
        event_lookup[eid] = {
            'question': _normalize_field(evt.get('question')),
            'response': _normalize_field(evt.get('response')),
            'action_taken': _normalize_field(evt.get('action_taken')),
            'action_suggested': _normalize_field(evt.get('action_suggested')),
            'evidence_taken': _normalize_field(evt.get('evidence_taken')),
            'evidence_suggested': _normalize_field(evt.get('evidence_suggested')),
        }

# For every audit row, expand its comma-separated cited_ids into the text of the
# cited events (question/answer plus action and evidence fields) for the raters.
audit_df = pd.read_csv(AUDIT_OUT)
contexts = []
action_taken_col = []
action_suggested_col = []
evidence_taken_col = []
evidence_suggested_col = []
for _, row in audit_df.iterrows():
    # A blank CSV cell is read back as NaN; str(NaN) would be the non-empty
    # string 'nan' and get looked up as an event id, so map missing values to ''.
    raw_cited = row.get('cited_ids', '')
    cited = '' if pd.isna(raw_cited) else str(raw_cited).strip()
    ids = [x.strip() for x in cited.split(',') if x.strip()]

    parts = []
    actions_taken = []
    actions_suggested = []
    evidences_taken = []
    evidences_suggested = []
    for eid in ids:
        evt = event_lookup.get(eid)
        if not evt:
            # Cited id not present in any loaded events file.
            continue
        q = evt.get('question') or ''
        a = evt.get('response') or ''
        if q or a:
            parts.append(f"[{eid}] Q: {q}\nA: {a}")
        if evt.get('action_taken'):
            actions_taken.append(f"[{eid}] {evt['action_taken']}")
        if evt.get('action_suggested'):
            actions_suggested.append(f"[{eid}] {evt['action_suggested']}")
        if evt.get('evidence_taken'):
            evidences_taken.append(f"[{eid}] {evt['evidence_taken']}")
        if evt.get('evidence_suggested'):
            evidences_suggested.append(f"[{eid}] {evt['evidence_suggested']}")

    # Rows without (resolvable) citations get empty strings in every column.
    contexts.append('\n\n'.join(parts))
    action_taken_col.append(_unique_join(actions_taken))
    action_suggested_col.append(_unique_join(actions_suggested))
    evidence_taken_col.append(_unique_join(evidences_taken))
    evidence_suggested_col.append(_unique_join(evidences_suggested))

audit_df['cited_events_text'] = contexts
audit_df['cited_evidence_taken'] = evidence_taken_col
audit_df['cited_evidence_suggested'] = evidence_suggested_col
audit_df['cited_action_taken'] = action_taken_col
audit_df['cited_action_suggested'] = action_suggested_col

audit_context_cols = ['query', 'response', 'cited_ids', 'correctness', 'citation_usefulness', 'notes', 'cited_events_text', 'cited_evidence_taken', 'cited_evidence_suggested', 'cited_action_taken', 'cited_action_suggested']
audit_context_df = audit_df.reindex(columns=audit_context_cols)
audit_context_df.to_csv(AUDIT_CONTEXT_OUT, index=False)
# Also update audit sheets with the filled action/evidence columns
# NOTE(review): this overwrites the second rater's sheet with the contents of
# AUDIT_OUT, including any scores already entered in either file -- run this
# cell before rating starts.
audit_df = audit_df.reindex(columns=audit_context_cols)
audit_df.to_csv(AUDIT_OUT, index=False)
audit_df.to_csv(AUDIT_OUT_RATER2, index=False)
print('Wrote', AUDIT_CONTEXT_OUT)
