# ACE Minimal RAG Agent

This notebook runs a minimal storage-backed RAG loop. Supported temporal queries include:
- yesterday / today / last week / last month / last year
- N days/weeks/months/years ago
- Jan 2025 / 2025-01 / 2025-03-31
- from June 1, 2025 to June 15, 2025


In [None]:
import os
import sys
import json
from pathlib import Path

# Locate the repository root. The notebook may be started from the repo itself
# (ACE.ipynb in the working directory) or from its parent directory, where the
# repo lives in `ACE_release/`.
ROOT_CANDIDATE = Path.cwd().resolve()
if (ROOT_CANDIDATE / 'ACE.ipynb').exists():
    REPO_ROOT = ROOT_CANDIDATE
elif (ROOT_CANDIDATE / 'ACE_release' / 'ACE.ipynb').exists():
    REPO_ROOT = ROOT_CANDIDATE / 'ACE_release'
else:
    # Neither layout matched; fall back to the working directory.
    REPO_ROOT = ROOT_CANDIDATE

# Storage layout: events JSONL and cached embeddings live under <repo>/data.
DEFAULT_STORAGE_DIR = REPO_ROOT / 'data'
EVENTS_PATH = DEFAULT_STORAGE_DIR / "events" / "ace_events_h1_2025.jsonl"
EMBED_DIR = DEFAULT_STORAGE_DIR / "embeddings"
EMBED_DIR.mkdir(parents=True, exist_ok=True)

EMBED_MODEL_NAME = "intfloat/e5-large-v2" # Used for current test/results
# EMBED_MODEL_NAME = "intfloat/e5-small-v2" # smaller/faster for limited resources - results may vary
# The slug keeps one cache file per embedding model, so switching models does
# not silently reuse vectors produced by a different encoder.
EMBED_MODEL_SLUG = EMBED_MODEL_NAME.replace("/", "_").replace("-", "_")
EMBED_PATH = EMBED_DIR / f"event_embeddings_{EMBED_MODEL_SLUG}.npz"
# Prefixes prepended to queries and to stored event texts before encoding.
EMBED_QUERY_PREFIX = "query: "
EMBED_PASSAGE_PREFIX = "passage: "

# Respect an ACE_STORAGE_DIR that is already set; otherwise point it at <repo>/data.
os.environ.setdefault('ACE_STORAGE_DIR', str(DEFAULT_STORAGE_DIR))
# Optional override: os.environ['ACE_USER_TZ'] = 'America/New_York'

# The notebook imports `src.storage_helpers` / `src.ace_agent`, which requires
# REPO_ROOT itself on sys.path (not only REPO_ROOT/src). Without it the imports
# fail whenever the notebook is launched from the parent directory.
# REPO_ROOT/src is kept first, as before. Existing entries are moved rather
# than duplicated so re-running the cell does not grow sys.path.
for _entry in (str(REPO_ROOT), str(REPO_ROOT / 'src')):
    if _entry in sys.path:
        sys.path.remove(_entry)
    sys.path.insert(0, _entry)

import torch
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
try:
    from huggingface_hub import snapshot_download
    from huggingface_hub.utils import LocalEntryNotFoundError
except Exception:
    snapshot_download = None
    LocalEntryNotFoundError = Exception
import src.storage_helpers as storage
from src.ace_agent import run_chat_loop


BATCH_SIZE = 64  # Number of events encoded per batch when building embeddings
REBUILD_EMBEDDINGS = False # Set to True to force rebuilding the embeddings even if a cached file exists


## To Download and use the ACE dataset:

In [None]:
# Fetch the "events" configuration (train split) of the ACE dataset and show
# its summary, its column names, and the first record as a quick sanity check.
ds = load_dataset("anon-user-423/ACE", name="events", split="train")
print(ds)
column_names = ds.column_names
print("Columns:", column_names)
first_row = ds[0]
print("Sample:", first_row)


In [None]:
# Write events to JSONL (used by storage helpers)
# One JSON object per line, one line per dataset row. The file is only written
# when it does not exist yet.
EVENTS_PATH.parent.mkdir(parents=True, exist_ok=True)
if EVENTS_PATH.exists():
    print("Events file already exists:", EVENTS_PATH)
else:
    # Write to a temporary sibling and rename it into place at the end. If the
    # cell is interrupted mid-write, no truncated file is left under EVENTS_PATH
    # for later runs to mistake for a complete export.
    partial_path = EVENTS_PATH.with_name(EVENTS_PATH.name + ".partial")
    with partial_path.open("w", encoding="utf-8") as f:
        f.writelines(json.dumps(dict(row), ensure_ascii=True) + "\n" for row in ds)
    partial_path.replace(EVENTS_PATH)
    print("Wrote", EVENTS_PATH)


In [None]:
# Roll the events file up into episodes, unless episode JSON files already exist.

# ------ Will take around an hour or two to run ------

episode_dir = DEFAULT_STORAGE_DIR / "episodes"
episodes_present = episode_dir.exists() and any(episode_dir.rglob("*.json"))
if episodes_present:
    print("Episodes already present:", episode_dir)
else:
    # No episode files found (missing or empty directory): build them from the events JSONL.
    summary = storage.roll_up_episodes(
        events_path=EVENTS_PATH,
        destination_dir=episode_dir,
        overwrite=True,
        show_progress=True,
    )
    print("Episode roll-up:", summary)


In [None]:
# Build or load embeddings
# The sentence encoder is kept under its own name (`embed_model`). The
# model-loading cell further down rebinds the global `model` to the causal LM,
# and the query embedder registered below resolves its encoder at call time;
# if it referenced `model`, it would end up calling `.encode` on the LM
# instead of the sentence encoder.
embed_model = SentenceTransformer(EMBED_MODEL_NAME)
if REBUILD_EMBEDDINGS or not EMBED_PATH.exists():
    event_ids = []
    vectors = []
    for start in tqdm(range(0, len(ds), BATCH_SIZE), desc="Embedding events"):
        # Slicing the dataset yields a dict of column -> list for the batch.
        batch = ds[start:start + BATCH_SIZE]
        # Each event is embedded as "<passage prefix><question> <response>".
        texts = [
            f"{EMBED_PASSAGE_PREFIX}{q} {r}".strip()
            for q, r in zip(batch["question"], batch["response"])
        ]
        emb = embed_model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
        vectors.append(emb)
        event_ids.extend(batch["event_id"])
    # Stack the per-batch arrays into one matrix, row-aligned with event_ids.
    vectors = np.vstack(vectors)
    np.savez(EMBED_PATH, event_ids=np.array(event_ids), vectors=vectors)
    print("Wrote", EMBED_PATH)
else:
    print("Using cached embeddings:", EMBED_PATH)


# Hand the stored vectors to the storage helpers and register the function
# they use to embed incoming queries (query prefix + single-text encode).
storage.load_event_embeddings_npz(EMBED_PATH)
storage.register_event_embedder(
    lambda text: embed_model.encode([f"{EMBED_QUERY_PREFIX}{text}"], convert_to_numpy=True)[0]
)
print('Event embeddings loaded:', len(storage.EVENT_EMBEDDINGS))


# Load the Model

In [None]:
# Runtime configuration. Every setting can be overridden through an ACE_*
# environment variable; the second argument is the string default used when
# the variable is unset.
def _env_int(name, default):
    """Return the environment setting `name` parsed as an int."""
    return int(os.getenv(name, default))


def _env_float(name, default):
    """Return the environment setting `name` parsed as a float."""
    return float(os.getenv(name, default))


def _env_flag(name, default):
    """Return True when the setting is '1', 'true' or 'yes' (case-insensitive)."""
    return os.getenv(name, default).lower() in {'1', 'true', 'yes'}


MODEL_NAME = os.getenv('ACE_MODEL_NAME', 'Qwen/Qwen2.5-3B-Instruct')
DEVICE = os.getenv('ACE_DEVICE', 'auto')
CHAT_ID = os.getenv('ACE_CHAT_ID', 'default')
TOP_K = _env_int('ACE_TOP_K', '5')
MAX_NEW_TOKENS = _env_int('ACE_MAX_NEW_TOKENS', '400')
EXTRACT_MAX_NEW_TOKENS = _env_int('ACE_EXTRACT_MAX_NEW_TOKENS', '200')
EVENT_CONTEXT_BUDGET_TOKENS = _env_int('ACE_EVENT_CONTEXT_BUDGET_TOKENS', '1200')
TEMPERATURE = _env_float('ACE_TEMPERATURE', '0.2')
TOP_P = _env_float('ACE_TOP_P', '0.9')
REFRESH_EVERY = _env_int('ACE_REFRESH_EVERY', '3')
EPISODE_LOOKBACK_MONTHS = _env_int('ACE_EPISODE_LOOKBACK_MONTHS', '12')
EPISODE_EVENTS_PER_CONTEXT = _env_int('ACE_EPISODE_EVENTS_PER_CONTEXT', '5')
EPISODE_TOP_K = _env_int('ACE_EPISODE_TOP_K', '5')
EPISODE_PREFILTER = _env_flag('ACE_EPISODE_PREFILTER', '1')
MEMORY_SCORE_THRESHOLD = _env_float('ACE_MEMORY_SCORE_THRESHOLD', '0.35')
ALLOW_GENERAL_FALLBACK = _env_flag('ACE_ALLOW_GENERAL_FALLBACK', '1')
MODEL_DTYPE = os.getenv('ACE_MODEL_DTYPE', 'float16')
LOW_CPU_MEM = _env_flag('ACE_LOW_CPU_MEM', '1')
FORCE_DOWNLOAD = _env_flag('ACE_FORCE_DOWNLOAD', '0')
HF_CACHE_DIR = os.getenv('ACE_HF_CACHE_DIR')  # None when unset
STREAM = _env_int('ACE_STREAM', '1')  # 1=stream tokens, 0=full response

# System prompt for memory-grounded answers.
MEMORY_SYSTEM_PROMPT = (
    'You are a personal assistant with access to prior interactions. '
    'Answer only using the provided context; if the answer is not in memory, say so.'
)
# System prompt for general-knowledge answers.
GENERAL_SYSTEM_PROMPT = (
    'You are a helpful assistant. Use general knowledge when needed. '
    'If memory context is provided, prefer it and cite it.'
)
def pick_device(choice: str) -> str:
    """Resolve the torch device string to run the model on.

    Any value other than 'auto' is returned unchanged. For 'auto', MPS is
    preferred when available, then CUDA, and CPU is the fallback.
    """
    if choice == 'auto':
        if torch.backends.mps.is_available():
            return 'mps'
        return 'cuda' if torch.cuda.is_available() else 'cpu'
    return choice

def should_use_memory(query: str, window, hits, threshold: float) -> bool:
    """Decide whether a query should be answered from memory.

    Memory is used when any of the following holds, checked in this order:

    1. ``window`` is not None.
    2. The lower-cased query contains one of the memory hint phrases, or the
       standalone word "ago".
    3. There is at least one hit and the first hit's ``'score'`` is greater
       than or equal to ``threshold``.

    Args:
        query: The user's message.
        window: Any non-None value forces memory mode.
        hits: Sequence of dict-like retrieval results; only ``hits[0]`` is
            inspected. Presumably sorted best-first -- confirm with the caller.
        threshold: Minimum score of the first hit for memory to be used.

    Returns:
        True when memory context should be used, False otherwise.
    """
    import re  # local import: the notebook's import cell does not bring in `re`

    if window is not None:
        return True

    lowered = query.lower()
    phrase_hints = (
        'remember', 'recall', 'what did i', 'what did we', 'what did you',
        'earlier', 'previous', 'last time', 'yesterday', 'today', 'last week',
        'last month', 'last year',
    )
    if any(hint in lowered for hint in phrase_hints):
        return True
    # "ago" is matched as a whole word only. As a plain substring it also fires
    # on unrelated words such as "Chicago" or "dragon".
    if re.search(r'\bago\b', lowered):
        return True

    if not hits:
        return False
    # A missing or None score counts as 0.0 instead of raising on comparison.
    top_score = hits[0].get('score')
    if top_score is None:
        top_score = 0.0
    return top_score >= threshold



In [None]:
def _download_snapshot(model_name, cache_dir, force_download=False):
    """Download the model snapshot from the Hub and return its local directory.

    `resume_download` is deliberately not passed: recent huggingface_hub
    releases deprecate the flag and always resume when possible.
    """
    return snapshot_download(
        model_name,
        cache_dir=cache_dir,
        local_files_only=False,
        force_download=force_download,
    )


def resolve_model_source(model_name, cache_dir, force_download):
    """Work out where the model should be loaded from.

    Args:
        model_name: A local path or a Hugging Face Hub model id.
        cache_dir: Hugging Face cache directory, or None for the default.
        force_download: When true, re-download the snapshot even if cached.

    Returns:
        A ``(source, local_only)`` tuple. ``source`` is a local directory when
        one could be resolved (existing path, cache hit, or fresh snapshot
        download); otherwise it is ``model_name`` unchanged. ``local_only`` is
        True when ``source`` is local, and False when loading is left to
        transformers.

    Raises:
        Whatever ``snapshot_download`` raises if an actual download fails.
    """
    if Path(model_name).exists():
        print(f'Using local model path: {model_name}')
        return model_name, True
    if snapshot_download is None:
        print('huggingface_hub not available; loading via transformers.')
        return model_name, False
    if force_download:
        print('Force download enabled; downloading model files with per-file progress.')
        # force_download=True is required here; without it snapshot_download
        # reuses files that are already cached and nothing is re-fetched.
        return _download_snapshot(model_name, cache_dir, force_download=True), True
    try:
        # Offline probe: succeeds only if the snapshot is already in the cache.
        cache_path = snapshot_download(
            model_name,
            cache_dir=cache_dir,
            local_files_only=True,
        )
    except LocalEntryNotFoundError:
        print('Cache miss: downloading model files with per-file progress.')
        return _download_snapshot(model_name, cache_dir), True
    except Exception:
        # Best-effort: any other failure of the cache probe defers to transformers.
        print('Cache lookup failed; downloading via transformers.')
        return model_name, False
    print(f'Cache hit: model snapshot already present at {cache_path}')
    return cache_path, True

# Resolve the device and dtype, then load the tokenizer and the causal LM.
device = pick_device(DEVICE)
# Only 'float16' and 'bfloat16' are recognized; any other value leaves dtype
# as None, so no dtype is passed to from_pretrained.
dtype = {'float16': torch.float16, 'bfloat16': torch.bfloat16}.get(MODEL_DTYPE)
print(f'Loading model {MODEL_NAME} on {device}...')
if HF_CACHE_DIR:
    print(f'Using Hugging Face cache dir: {HF_CACHE_DIR}')
model_source, local_only = resolve_model_source(MODEL_NAME, HF_CACHE_DIR, FORCE_DOWNLOAD)

# Options about *where* files come from apply to both tokenizer and model.
hub_kwargs = {
    'cache_dir': HF_CACHE_DIR,
    'force_download': FORCE_DOWNLOAD,
    'local_files_only': local_only,
}
# Weight-loading options apply to the model only. They used to be forwarded to
# the tokenizer as well, where they have no meaning.
load_kwargs = dict(hub_kwargs)
if dtype is not None:
    # NOTE(review): `dtype` is the keyword in recent transformers releases;
    # older ones call it `torch_dtype` -- confirm against the pinned version.
    load_kwargs['dtype'] = dtype
if LOW_CPU_MEM:
    load_kwargs['low_cpu_mem_usage'] = True

print('Loading tokenizer...')
tokenizer = AutoTokenizer.from_pretrained(model_source, **hub_kwargs)
if tokenizer.pad_token_id is None:
    # No pad token defined: reuse the EOS token for padding.
    tokenizer.pad_token = tokenizer.eos_token
print('Loading model weights...')
model = AutoModelForCausalLM.from_pretrained(model_source, **load_kwargs)
if device != 'cpu':
    print(f'Moving model to {device}...')
    model.to(device)

model.eval()
if hasattr(model, 'generation_config') and model.generation_config is not None:
    # Presumably disables top-k filtering so sampling is governed by
    # temperature / top_p -- confirm against the generation docs.
    model.generation_config.top_k = 0
print('Model ready')




In [None]:
# Make sure the retrieval index is in place before chatting. refresh=True
# presumably forces a rebuild instead of reusing an existing index -- confirm
# in src/storage_helpers.
storage.ensure_retrieval_index(refresh=True)

In [None]:
# Optional demo: build episodes from faux_events.jsonl
# This is for reviewer sanity-checks only and is not used in paper results.

# Copy faux_events.jsonl from data/Demo_Dataset/ to data/events/ to run this cell,
# or change faux_path below to point to its current location.

# NOTE: Remove the faux dataset events and episodes or skip this cell during actual metric evaluations
# to avoid contaminating the dataset with synthetic events.



# faux_path = REPO_ROOT / 'data' / 'events' / 'faux_events.jsonl'
# if faux_path.exists():
#     print(f'Building episodes from {faux_path}...')
#     print(f'Episodes destination: {storage.STORAGE.episodes}')
#     storage.build_episodes_from_jsonl(faux_path, overwrite=False)
#     storage.ensure_retrieval_index(refresh=True)
#     print('Done.')
# else:
#     print('faux_events.jsonl not found; skipping demo cell.')


In [None]:
# Collect every chat-loop setting in one mapping, then start the interactive
# loop. The values are the model/tokenizer loaded above and the ACE_* settings.
chat_settings = {
    'model': model,
    'tokenizer': tokenizer,
    'storage': storage,
    'chat_id': CHAT_ID,
    # Retrieval and generation limits
    'top_k': TOP_K,
    'max_new_tokens': MAX_NEW_TOKENS,
    'extract_max_new_tokens': EXTRACT_MAX_NEW_TOKENS,
    'event_context_budget_tokens': EVENT_CONTEXT_BUDGET_TOKENS,
    'temperature': TEMPERATURE,
    'top_p': TOP_P,
    'refresh_every': REFRESH_EVERY,
    # Episode retrieval
    'episode_lookback_months': EPISODE_LOOKBACK_MONTHS,
    'episode_events_per_context': EPISODE_EVENTS_PER_CONTEXT,
    'episode_top_k': EPISODE_TOP_K,
    'episode_prefilter': EPISODE_PREFILTER,
    # Memory vs. general-knowledge routing
    'memory_score_threshold': MEMORY_SCORE_THRESHOLD,
    'allow_general_fallback': ALLOW_GENERAL_FALLBACK,
    'memory_system_prompt': MEMORY_SYSTEM_PROMPT,
    'general_system_prompt': GENERAL_SYSTEM_PROMPT,
    'stream': STREAM,
    'should_use_memory_fn': should_use_memory,
}
run_chat_loop(**chat_settings)
