# explore_bis_v5 — pipeline ENC → MEM → DEC


In [1]:
from pathlib import Path
import sys

ROOT = Path.cwd().parent
SRC = ROOT / "src"
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))
print(f"Using src path: {SRC}")


Using src path: /Users/aymenmejri/Desktop/MyCode/experiments/hdc_v2/hdc_project/src


In [10]:
import logging
from collections import Counter, defaultdict
from typing import List, Sequence, Tuple

import numpy as np
from tqdm import tqdm

from hdc_project.encoder import m4, pipeline as enc_pipeline
from hdc_project.encoder.mem import pipeline as mem_pipeline
from hdc_project.decoder import (
    DD1_ctx,
    DD2_query,
    DD2_query_bin,
    DD3_bindToMem,
    DD4_search_topK,
    DD5_payload,
    DD6_vote,
    DD7_updateLM,
    DecodeOneStep,
    DX2_run,
    DX3_run,
    DX4_run,
    DX5_run,
    DX6_run,
    DX7_run,
)
from hdc_project.decoder.dec import (
    hd_assert_pm1,
    hd_bind,
    hd_sim,
    hd_sim_dot,
    build_perm_inverse,
    permute_pow,
    permute_pow_signed,
    rademacher,
    _as_vocab_from_buckets,
)

log = logging.getLogger("explore_v5")
if not log.handlers:
    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")


## 1. Chargement du corpus et encodage ENC


In [11]:
MAX_SENTENCES = 5_000
N_OPUS = 20_000

try:
    ens_raw, frs_raw = enc_pipeline.opus_load_subset(
        name="opus_books",
        config="en-fr",
        split="train",
        N=N_OPUS,
        seed=2025,
    )
    log.info("OPUS subset loaded: %d pairs", len(ens_raw))
except Exception as exc:
    log.warning("OPUS download failed (%s); using a local toy corpus", exc)
    ens_raw = [
        "hyperdimensional computing is fun",
        "vector symbolic architectures are powerful",
        "encoding words into hyperspace",
        "memory augmented networks love clean data",
    ]
    frs_raw = [
        "le calcul hyperdimensionnel est amusant",
        "les architectures symboliques vectorielles sont puissantes",
        "encoder des mots dans l'hyperspace",
        "les réseaux augmentés de mémoire aiment les données propres",
    ]

ens_sample = ens_raw[:MAX_SENTENCES]
frs_sample = frs_raw[:MAX_SENTENCES]
log.info("Using %d sentence pairs", len(ens_sample))

D = 8_192
n = 5
rng = np.random.default_rng(123)

ARTIFACTS = ROOT / "artifacts"
ARTIFACTS.mkdir(exist_ok=True)


def _load_or_build_lex(path: Path, seed: int) -> m4.M4_LexEN:
    if path.exists():
        try:
            lex = m4.M4_LexEN.load(str(path))
            if lex.D != D:
                raise ValueError("dimension mismatch")
            log.info("Loaded cached lexicon %s", path.name)
            return lex
        except Exception as exc:  # pragma: no cover - fallback
            log.warning("Failed to load %s (%s); rebuilding", path.name, exc)
            path.unlink(missing_ok=True)
    lex = m4.M4_LexEN_new(seed=seed, D=D)
    log.info("Instantiated new lexicon %s", path.name)
    return lex


lex_en_path = ARTIFACTS / f"lex_en_D{D}.npz"
lex_fr_mem_path = ARTIFACTS / f"lex_fr_mem_D{D}.npz"
lex_fr_lm_path = ARTIFACTS / f"lex_fr_lm_D{D}.npz"

Lex_en = _load_or_build_lex(lex_en_path, seed=1)
Lex_fr_mem = _load_or_build_lex(lex_fr_mem_path, seed=2)
Lex_fr_lm = _load_or_build_lex(lex_fr_lm_path, seed=202)

# MEM/DEC compatibility contract: reuse the exact same callable.
L_FR_PAYLOAD = Lex_fr_mem.get  # payload lexicon (MEM & DEC share this object)
L_FR_LM = Lex_fr_lm.get  # LM-only lexicon (separate semantic space)

pi = rng.permutation(D).astype(np.int64)
pi_inv = build_perm_inverse(pi)

encoded_en = enc_pipeline.encode_corpus_ENC(ens_sample, Lex_en, pi, D, n, seg_seed0=999)
encoded_fr = enc_pipeline.encode_corpus_ENC(frs_sample, Lex_fr_mem, pi, D, n, seg_seed0=1999)
log.info("Encoded %d EN / %d FR sentences", len(encoded_en), len(encoded_fr))

E_list_en = [segment["E_seq"] for segment in encoded_en]
H_list_en = [segment["H"] for segment in encoded_en]
if H_list_en:
    log.info("Encoder signature shape: %s", H_list_en[0].shape)

intra_sim, inter_sim = enc_pipeline.intra_inter_ngram_sims(E_list_en, D)
inter_seg = enc_pipeline.inter_segment_similarity(H_list_en)
log.info(
    "ENC stats — intra: %.4f | inter(|.|): %.4f | inter segments: %.4f",
    intra_sim,
    inter_sim,
    inter_seg,
)


def _maybe_save_lex(path: Path, lex: m4.M4_LexEN) -> None:
    if path.exists():
        return
    table = getattr(lex, '_table', {})
    if table:
        lex.save(str(path))
        log.info('Saved lexicon %s', path.name)


for path, lex in ((lex_en_path, Lex_en), (lex_fr_mem_path, Lex_fr_mem), (lex_fr_lm_path, Lex_fr_lm)):
    _maybe_save_lex(path, lex)




2025-10-07 23:28:19,161 [INFO] OPUS subset loaded: 20000 pairs
2025-10-07 23:28:19,162 [INFO] Using 5000 sentence pairs
2025-10-07 23:28:20,918 [INFO] Loaded cached lexicon lex_en_D8192.npz
2025-10-07 23:28:49,664 [INFO] Loaded cached lexicon lex_fr_mem_D8192.npz
2025-10-07 23:28:49,672 [INFO] Instantiated new lexicon lex_fr_lm_D8192.npz
2025-10-07 23:32:43,994 [INFO] Encoded 5000 EN / 5000 FR sentences
2025-10-07 23:33:30,719 [INFO] Encoder signature shape: (8192,)
2025-10-07 23:34:04,766 [INFO] ENC stats — intra: 0.0003 | inter(|.|): 0.0255 | inter segments: 0.0088


## 2. Construction des paires MEM


### Analyse de l'alignement EN→FR
1. **Proxy linéaire raisonné.** Le calcul `start_fr = round(start_en * ratio)` reste un proxy commode mais imparfait : l'encodeur travaille dans l'espace positionnel imposé par \(\Pi\), et une projection linéaire des bornes EN vers le flux FR casse parfois l'isométrie sémantique (réordonnements, insertions, suppressions).
2. **Effet pratique.** La fenêtre FR ainsi projetée peut récupérer des tokens peu pertinents ; ceux-ci diluent la signature lexicale $Z_{\mathrm{fr,lex}}$ et, in fine, le prototype MEM couplé au segment EN.
3. **Correctif minimal (algorithme).**
   3.1. Construire $Z_{\mathrm{fr,lex}}$ sur la fenêtre $[\texttt{start\_fr}:\texttt{stop\_fr}]$.
   3.2. Élaguer itérativement les tokens dont le retrait augmente la cohérence interne $\langle Z_{\mathrm{fr,lex}},\, Z_{\mathrm{fr,lex}}^{(-t)}\rangle$ au-dessus d'un seuil fixé.
   3.3. (Option) Injecter un bigramme FR issu de l'historique si $|\texttt{span\_tokens}| < 2$ afin de stabiliser la majorité.
4. **Correctif préférable (séquences FR encodées disponibles).**
   - Réutiliser exactement les mêmes fenêtres et strides que côté EN sur $X_{\mathrm{seq}}^{(\mathrm{FR})}$.
   - Remplacer le contenu projeté par la signature lexicale des tokens FR co-localisés dans ces bornes, ce qui préserve la cohérence temporelle et l'homogénéité des tranches.


### Cas des spans vides
1. **Biais du vecteur constant.** Retourner le vecteur $\mathbf{1}$ (tous +1) pour un segment vide induit un biais global artificiel : le produit scalaire $\langle \mathbf{1},\, \cdot \rangle$ reste élevé, mais n'apporte aucune information discriminante et déforme la moyenne des prototypes MEM.
2. **Prior lexical neutre.** Remplacer $\mathbf{1}$ par un prior neutre défini par
   $$Z^{\text{prior}} = \operatorname{sign}\Big( \sum_{v\in \mathcal{V}_{\text{freq}}} w(v)\,L_{\mathrm{fr,mem}}(v) \Big)$$
   avec $w(v)\propto$ la fréquence empirique du token $v$. À défaut de statistiques fiables, utiliser un vecteur de Rademacher ($\pm 1$) avec seed fixé assure un comportement stationnaire sans corrélation systématique.
3. **Interdiction de $\mathbf{1}$.** Bannir explicitement $\mathbf{1}$ comme valeur par défaut : elle biaise la densité des prototypes et fausse les similarités inter-mémoire, compromettant les décisions DD6.


In [26]:

from typing import Callable


def content_signature_from_Xseq(X_seq: Sequence[np.ndarray], *, majority: str = "strict") -> np.ndarray:
    if not X_seq:
        raise ValueError("X_seq vide")
    acc = np.zeros(X_seq[0].shape[0], dtype=np.int32)
    for x in X_seq:
        acc += x.astype(np.int32, copy=False)
    if majority == "strict":
        return np.where(acc >= 0, 1, -1).astype(np.int8, copy=False)
    if majority == "unbiased":
        rng_local = np.random.default_rng(0)
        ties = acc == 0
        acc[ties] = rng_local.integers(0, 2, size=int(ties.sum()), dtype=np.int32) * 2 - 1
        return np.where(acc >= 0, 1, -1).astype(np.int8, copy=False)
    raise ValueError("majority must be 'strict' or 'unbiased'")


def span_signatures_from_trace(
    X_seq: Sequence[np.ndarray],
    *,
    win: int,
    stride: int,
    majority: str,
) -> List[Tuple[np.ndarray, int, int]]:
    T = len(X_seq)
    if T == 0:
        return []
    spans: List[Tuple[np.ndarray, int, int]] = []
    if T <= win:
        spans.append((content_signature_from_Xseq(X_seq, majority=majority), 0, T))
        return spans
    for start in range(0, T - win + 1, max(1, stride)):
        stop = start + win
        spans.append((content_signature_from_Xseq(X_seq[start:stop], majority=majority), start, stop))
    return spans


def lexical_signature_from_tokens(
    tokens: Sequence[str],
    L_fr_mem: Callable[[str], np.ndarray],
    D: int,
    *,
    prior: np.ndarray | None = None,
    rng: np.random.Generator | None = None,
    history_tokens: Sequence[str] | None = None,
    prune_alignment_threshold: float = 0.0,
    min_history_tokens: int = 2,
) -> tuple[np.ndarray, list[str]]:
    """Build a MEM payload signature from tokens using the shared payload lexicon.

    Parameters
    ----------
    tokens : Sequence[str]
        Token span under consideration.
    L_fr_mem : Callable[[str], np.ndarray]
        MUST be `L_FR_PAYLOAD` so that MEM and DEC operate in the same
        hyperdimensional space.
    """
    rng = np.random.default_rng(0) if rng is None else rng
    tok_list = [tok for tok in tokens if tok]
    if len(tok_list) < min_history_tokens and history_tokens:
        for tok in list(history_tokens)[-min_history_tokens:]:
            if tok not in tok_list:
                tok_list.append(tok)
            if len(tok_list) >= min_history_tokens:
                break
    if not tok_list:
        if prior is not None:
            return prior.astype(np.int8, copy=False), []
        noise = 2 * rng.integers(0, 2, size=D, dtype=np.int8) - 1
        return noise.astype(np.int8, copy=False), []
    acc = np.zeros(D, dtype=np.int32)
    vecs: list[np.ndarray] = []
    for tok in tok_list:
        vec = L_fr_mem(tok).astype(np.int8, copy=False)
        hd_assert_pm1(vec, D)
        vec_i32 = vec.astype(np.int32, copy=False)
        vecs.append(vec)
        acc += vec_i32
    signature = np.where(acc >= 0, 1, -1).astype(np.int8, copy=False)

    if len(vecs) > 2:
        sig_i32 = signature.astype(np.int32, copy=False)
        filtered_tokens: list[str] = []
        filtered_vecs: list[np.ndarray] = []
        for tok, vec in zip(tok_list, vecs):
            align = float(np.dot(vec.astype(np.int32, copy=False), sig_i32)) / D
            if align >= prune_alignment_threshold:
                filtered_tokens.append(tok)
                filtered_vecs.append(vec)
        if filtered_tokens and len(filtered_tokens) < len(tok_list):
            tok_list = filtered_tokens
            vecs = filtered_vecs
            acc = np.zeros(D, dtype=np.int32)
            for vec in vecs:
                acc += vec.astype(np.int32, copy=False)
            signature = np.where(acc >= 0, 1, -1).astype(np.int8, copy=False)
    return signature, list(tok_list)


def build_mem_pairs_with_meta(
    encoded_en: Sequence[dict],
    encoded_fr: Sequence[dict],
    tokens_fr: Sequence[Sequence[str]],
    *,
    L_fr_mem: Callable[[str], np.ndarray],
    win: int = 8,
    stride: int = 4,
    majority: str = "strict",
    max_pairs: int | None = None,
    prior: np.ndarray | None = None,
    rng_empty: np.random.Generator | None = None,
    prune_alignment_threshold: float = 0.0,
    min_history_tokens: int = 2,
) -> Tuple[List[Tuple[np.ndarray, np.ndarray]], List[dict]]:
    rng_empty = np.random.default_rng(0) if rng_empty is None else rng_empty
    pairs: List[Tuple[np.ndarray, np.ndarray]] = []
    meta: List[dict] = []
    N = min(len(encoded_en), len(encoded_fr))
    for idx in tqdm(range(N), desc="MEM span extraction", leave=False):
        spans_en = span_signatures_from_trace(encoded_en[idx]["X_seq"], win=win, stride=stride, majority=majority)
        spans_fr_trace = span_signatures_from_trace(encoded_fr[idx]["X_seq"], win=win, stride=stride, majority=majority)
        tok_fr = list(tokens_fr[idx]) if idx < len(tokens_fr) else []
        len_seq_en = len(encoded_en[idx]["X_seq"])
        len_seq_fr = len(encoded_fr[idx]["X_seq"])
        len_tok_fr = len(tok_fr)
        if len_seq_en == 0:
            continue
        ratio_en = len_tok_fr / max(len_seq_en, 1)
        ratio_fr = len_tok_fr / max(len_seq_fr, 1)
        paired_len = min(len(spans_en), len(spans_fr_trace))

        for span_idx, (ze, start_en, stop_en) in enumerate(spans_en):
            if len_tok_fr:
                if span_idx < paired_len:
                    _zf_fr, start_fr_trace, stop_fr_trace = spans_fr_trace[span_idx]
                    start_fr = int(round(start_fr_trace * ratio_fr))
                    stop_fr = int(round(stop_fr_trace * ratio_fr))
                    fr_trace_bounds = (int(start_fr_trace), int(stop_fr_trace))
                else:
                    start_fr = int(round(start_en * ratio_en))
                    stop_fr = int(round(stop_en * ratio_en))
                    fr_trace_bounds = None
                start_fr = max(0, min(start_fr, len_tok_fr - 1))
                stop_fr = max(start_fr + 1, min(len_tok_fr, stop_fr))
                span_tokens_raw = tok_fr[start_fr:stop_fr]
                history_tokens = tok_fr[max(0, start_fr - stride):start_fr]
            else:
                start_fr = 0
                stop_fr = 0
                span_tokens_raw = []
                history_tokens = []
                fr_trace_bounds = None
            zf_lex, span_tokens_filtered = lexical_signature_from_tokens(
                span_tokens_raw,
                L_fr_mem,
                D,
                prior=prior,
                rng=rng_empty,
                history_tokens=history_tokens,
                prune_alignment_threshold=prune_alignment_threshold,
                min_history_tokens=min_history_tokens,
            )
            pairs.append((ze, zf_lex))
            meta.append(
                {
                    "sentence_idx": idx,
                    "start": start_en,
                    "stop": stop_en,
                    "start_token_fr": start_fr,
                    "stop_token_fr": stop_fr,
                    "history_tokens": list(history_tokens),
                    "span_tokens": span_tokens_filtered,
                    "span_tokens_raw": span_tokens_raw,
                    "fr_trace_bounds": fr_trace_bounds,
                    "Z_en": ze,
                    "Z_fr_lex": zf_lex,
                }
            )
            if max_pairs is not None and len(pairs) >= max_pairs:
                return pairs, meta
    return pairs, meta


tokens_fr = [enc_pipeline.sentence_to_tokens_EN(sent, vocab=set()) for sent in frs_sample]
token_freq = Counter(tok for seq in tokens_fr for tok in seq)
if token_freq:
    prior_acc = np.zeros(D, dtype=np.int32)
    for tok, freq in token_freq.items():
        vec = L_FR_PAYLOAD(tok).astype(np.int8, copy=False)
        hd_assert_pm1(vec, D)
        prior_acc += vec.astype(np.int32, copy=False) * int(freq)
    lexical_prior_fr = np.where(prior_acc >= 0, 1, -1).astype(np.int8, copy=False)
else:
    lexical_prior_fr = rademacher(D, np.random.default_rng(404))

empty_span_rng = np.random.default_rng(2027)
pairs_mem, span_meta = build_mem_pairs_with_meta(
    encoded_en,
    encoded_fr,
    tokens_fr,
    L_fr_mem=L_FR_PAYLOAD,
    win=8,
    stride=4,
    majority="strict",
    max_pairs=50_000,
    prior=lexical_prior_fr,
    rng_empty=empty_span_rng,
    prune_alignment_threshold=0.0,
    min_history_tokens=2,
)
log.info("Prepared %d MEM pairs", len(pairs_mem))




2025-10-07 16:38:26,384 [INFO] Prepared 50000 MEM pairs                 


In [27]:
span_meta

[{'sentence_idx': 0,
  'start': 0,
  'stop': 8,
  'start_token_fr': 0,
  'stop_token_fr': 6,
  'history_tokens': [],
  'span_tokens': ["l'idée",
   'brusque',
   'du',
   'mariage',
   "qu'elle",
   'poursuivait'],
  'Z_en': array([ 1, -1,  1, ...,  1, -1,  1], shape=(8192,), dtype=int8),
  'Z_fr_lex': array([-1, -1, -1, ..., -1, -1,  1], shape=(8192,), dtype=int8)},
 {'sentence_idx': 0,
  'start': 4,
  'stop': 12,
  'start_token_fr': 3,
  'stop_token_fr': 9,
  'history_tokens': ["l'idée", 'brusque', 'du'],
  'span_tokens': ['mariage',
   "qu'elle",
   'poursuivait',
   "d'un",
   'sourire',
   'si'],
  'Z_en': array([ 1,  1, -1, ...,  1,  1,  1], shape=(8192,), dtype=int8),
  'Z_fr_lex': array([-1, -1,  1, ..., -1,  1,  1], shape=(8192,), dtype=int8)},
 {'sentence_idx': 0,
  'start': 8,
  'stop': 16,
  'start_token_fr': 6,
  'stop_token_fr': 12,
  'history_tokens': ['du', 'mariage', "qu'elle", 'poursuivait'],
  'span_tokens': ["d'un", 'sourire', 'si', 'tranquille', 'entre', 'cécile'],

## 3. Entraînement MEM et diagnostic rapide


In [28]:
MEM_K = 16
MEM_BUCKETS = 128
cfg = mem_pipeline.MemConfig(D=D, B=MEM_BUCKETS, k=MEM_K, seed_lsh=10, seed_gmem=11)
comp = mem_pipeline.make_mem_pipeline(cfg)
mem_pipeline.train_one_pass_MEM(comp, pairs_mem)
log.info("MEM training completed (B=%d)", comp.mem.B)

probe_count = min(200, len(pairs_mem))
sim_values = []
for Z_en_vec, Z_fr_vec in tqdm(pairs_mem[:probe_count], desc="MEM probe"):
    bucket_idx, _ = mem_pipeline.infer_map_top1(comp, Z_en_vec)
    prototype = comp.mem.H[bucket_idx].astype(np.int32, copy=False)
    sim = float(np.dot(prototype, Z_fr_vec.astype(np.int32, copy=False)) / D)
    sim_values.append(sim)

if sim_values:
    log.info(
        "Probe similarities — mean: %.4f | median: %.4f",
        float(np.mean(sim_values)),
        float(np.median(sim_values)),
    )
    nb = comp.mem.n
    log.info(
        "Bucket population stats — mean: %.1f | p90: %d | p99: %d",
        float(nb.mean()),
        int(np.quantile(nb, 0.90)),
        int(np.quantile(nb, 0.99)),
    )


2025-10-07 16:39:15,108 [INFO] MEM training completed (B=128)
MEM probe: 100%|██████████| 200/200 [00:00<00:00, 17927.06it/s]
2025-10-07 16:39:15,123 [INFO] Probe similarities — mean: 0.1674 | median: 0.2489
2025-10-07 16:39:15,124 [INFO] Bucket population stats — mean: 390.6 | p90: 443 | p99: 477


## 4. Dictionnaire bucket → vocabulaire


In [29]:
MAX_BUCKET_VOCAB = 256
bucket_counts: dict[int, Counter] = defaultdict(Counter)
for meta in tqdm(span_meta, desc="Bucket vocab build", leave=False):
    bucket_idx, _ = mem_pipeline.infer_map_top1(comp, meta["Z_en"])
    bucket_idx = int(bucket_idx)
    meta["bucket_idx"] = bucket_idx
    for tok in meta.get("span_tokens", []):
        bucket_counts[bucket_idx][tok] += 1

bucket2vocab_freq = {
    bucket: sorted(counter.items(), key=lambda kv: (-kv[1], kv[0]))[:MAX_BUCKET_VOCAB]
    for bucket, counter in bucket_counts.items()
}
bucket2vocab = {bucket: [tok for tok, _ in tokens] for bucket, tokens in bucket2vocab_freq.items()}
all_vocab = sorted({tok for tokens in bucket2vocab.values() for tok in tokens})
log.info(
    "Bucket vocab built for %d buckets (global vocab size=%d)",
    len(bucket2vocab),
    len(all_vocab),
)


2025-10-07 16:39:36,690 [INFO] Bucket vocab built for 128 buckets (global vocab size=12629)


## 5bis. Diagnostics théorie ↔ implémentation


In [30]:
log.info("Running DEC diagnostic suite (subsampled)...")

norms = DX2_run(D=D, trials=50, ells=(2, 4, 8), ratios=(1.0,), seed=2025)
log.info("DX2 ok — example median norm: %.3f", np.median([v[1] for v in norms.values()]))

rel_err, pval = DX3_run(D=D, C=256, T=64, seed=2025, rel_tol=0.02, pmin=0.05)
log.info("DX3 ok — mean relative error=%.4f | p=%.3f", rel_err, pval)

recalls = DX4_run(D=D, B=5_000, trials=40, Ks=(100, 500), seed=0)
log.info("DX4 ok — recall@500=%.3f", recalls[500])

accuracies = DX5_run(D=D, trials=40, ms=(4, 8, 16), seed=0)
log.info("DX5 ok — accuracy m=8: %.3f", accuracies[8])

results_dx6 = DX6_run(D=D, trials=120, lam_grid=(0.0, 0.5, 1.0), rng_seed=7031)
log.info("DX6 ok — lambda grid summary: %s", {lam: (vals['top1'], vals['ppl']) for lam, vals in results_dx6.items()})

dx7_results, ell_star = DX7_run(ell_grid=(2, 4, 8), D=D, seed_pi=10_456, rng_seed=9_117)
log.info("DX7 ok — ell*=%d, top1=%.3f", ell_star, dx7_results[ell_star]["top1"])


2025-10-07 16:39:48,088 [INFO] Running DEC diagnostic suite (subsampled)...
2025-10-07 16:39:48,466 [INFO] DX2 ok — example median norm: 1.000
2025-10-07 16:39:48,617 [INFO] DX3 ok — mean relative error=0.0000 | p=1.000
2025-10-07 16:39:57,447 [INFO] DX4 ok — recall@500=1.000
2025-10-07 16:39:57,456 [INFO] DX5 ok — accuracy m=8: 1.000
2025-10-07 16:39:57,566 [INFO] DX6 ok — lambda grid summary: {0.0: (1.0, 1.1109939911767361), 0.5: (1.0, 1.0000001843669586), 1.0: (1.0, 1.000000000000448)}
2025-10-07 16:40:03,095 [INFO] DX7 ok — ell*=2, top1=0.926


In [32]:
if not span_meta:
    raise RuntimeError("Aucune paire MEM disponible pour la démonstration DEC.")

G_DEC = rademacher(D, np.random.default_rng(2025))
G_MEM = comp.Gmem
L_fr_payload = L_FR_PAYLOAD  # MEM=DEC payload space
L_fr_lm = L_FR_LM
prototypes = comp.mem.H.astype(np.int8, copy=False)

ELL = 4
CAND_PER_BUCKET = 32
STAGE1_LIMIT = 24
STAGE2_LIMIT = 8
V_MAX = 512
LAM = 0.2


def update_LM_sep(H_LM: np.ndarray, token: str) -> np.ndarray:
    vec = L_FR_LM(token).astype(np.int8, copy=False)
    hd_assert_pm1(vec, D)
    inc = permute_pow(vec, pi, 1).astype(np.int16, copy=False)
    acc = H_LM.astype(np.int16) + inc
    return np.where(acc >= 0, 1, -1).astype(np.int8, copy=False)


def gather_candidates_from_ck(
    C_K: Sequence[int],
    history: Sequence[str],
    *,
    limit: int | None = None,
) -> list[str]:
    seen: set[str] = set()
    out: list[str] = []
    cap = limit or V_MAX
    for c in C_K:
        for tok, _cnt in bucket2vocab_freq.get(int(c), [])[:CAND_PER_BUCKET]:
            if tok not in seen:
                seen.add(tok)
                out.append(tok)
            if len(out) >= cap:
                break
        if len(out) >= cap:
            break
    for tok in reversed(history):
        if tok not in seen:
            seen.add(tok)
            out.append(tok)
        if len(out) >= cap:
            break
    if not out and all_vocab:
        out.extend(all_vocab[:cap])
    if not out:
        out.append("<unk>")
    return out[:cap]


def decode_span_staged(meta: dict, H_LM: np.ndarray, history: list[str], *, max_steps: int) -> tuple[list[str], list[str], np.ndarray]:
    decoded: list[str] = []
    for _ in range(max_steps):
        stage1_token, stage1_scores, _c1, C_K1, _ = DecodeOneStep(
            Hs=meta["Z_en"],
            H_LM=H_LM,
            history_fr=history,
            G_DEC=G_DEC,
            G_MEM=G_MEM,
            Pi=pi,
            L_fr=L_fr_payload,
            prototypes=prototypes,
            K=64,
            alpha=1.0,
            beta=0.0,
            ell=ELL,
            lam=0.0,
            bucket2vocab=bucket2vocab,
            global_fallback_vocab=all_vocab[:512] if all_vocab else None,
            return_ck_scores=True,
        )

        cand_stage1 = gather_candidates_from_ck(C_K1, history, limit=STAGE1_LIMIT)
        order = np.argsort(stage1_scores)[::-1]
        stage_tokens = [cand_stage1[i] for i in order[:STAGE2_LIMIT] if i < len(cand_stage1)]
        if not stage_tokens:
            stage_tokens = cand_stage1[:STAGE2_LIMIT] or [stage1_token]

        bucket2vocab_stage2 = {
            c: [tok for tok, _cnt in bucket2vocab_freq.get(c, []) if tok in stage_tokens]
            for c in bucket2vocab_freq.keys()
        }

        token_star, _scores2, _c_star2, _C_K2, _ = DecodeOneStep(
            Hs=meta["Z_en"],
            H_LM=H_LM,
            history_fr=history,
            G_DEC=G_DEC,
            G_MEM=G_MEM,
            Pi=pi,
            L_fr=L_fr_payload,
            prototypes=prototypes,
            K=max(8, len(stage_tokens)),
            alpha=1.0,
            beta=0.0,
            ell=ELL,
            lam=LAM,
            bucket2vocab=bucket2vocab_stage2,
            global_fallback_vocab=stage_tokens,
            return_ck_scores=False,
        )

        decoded.append(token_star)
        history.append(token_star)
        if len(history) > ELL:
            history[:] = history[-ELL:]
        H_LM = update_LM_sep(H_LM, token_star)
    return decoded, history, H_LM


def decode_sentence(sentence_idx: int, *, max_steps_per_span: int = 8) -> list[str]:
    metas = sorted(
        [m for m in span_meta if m["sentence_idx"] == sentence_idx],
        key=lambda m: m["start"],
    )
    if not metas:
        return []

    history: list[str] = []
    rng_sent = np.random.default_rng(9_999 + sentence_idx)
    H_LM = rademacher(D, rng_sent)

    decoded_sentence: list[str] = []
    for meta in metas:
        seed = list(meta.get("history_tokens", [])[-ELL:])
        if seed and not history:
            history = seed.copy()
            for tok in history:
                H_LM = update_LM_sep(H_LM, tok)

        steps = min(len(meta.get("span_tokens", [])), max_steps_per_span)
        if steps <= 0:
            continue

        decoded, history, H_LM = decode_span_staged(meta, H_LM, history, max_steps=steps)
        decoded_sentence.extend(decoded)
    return decoded_sentence


N_EVAL_SENT = 5
bleu_scores: list[float] = []
rouge1_scores: list[float] = []
rouge2_scores: list[float] = []

for idx in range(min(N_EVAL_SENT, len(ens_sample))):
    pred_tokens = decode_sentence(idx, max_steps_per_span=8)
    ref_tokens = clean_tokens(frs_sample[idx].lower().split())
    pred_clean = clean_tokens(pred_tokens)
    bleu = bleu_score(ref_tokens, pred_clean)
    rouge1 = rouge_n(ref_tokens, pred_clean, n=1)
    rouge2 = rouge_n(ref_tokens, pred_clean, n=2)
    bleu_scores.append(bleu)
    rouge1_scores.append(rouge1)
    rouge2_scores.append(rouge2)

    print(f"Sentence {idx}:")
    print("  EN   :", ens_sample[idx])
    print("  REF  :", " ".join(ref_tokens))
    print("  PRED :", " ".join(pred_clean) if pred_clean else "<empty>")
    print(f"  BLEU={bleu:.3f} | ROUGE-1 F1={rouge1:.3f} | ROUGE-2 F1={rouge2:.3f}")
    print()

if bleu_scores:
    print("=== Averages ===")
    print(
        f"BLEU={np.mean(bleu_scores):.3f} | ROUGE-1={np.mean(rouge1_scores):.3f} | ROUGE-2={np.mean(rouge2_scores):.3f}"
    )
else:
    print("Aucune phrase évaluée.")



Sentence 0:
  EN   : The sudden idea of the marriage between C&amp;eacue;cile and Paul, which she was arranging with so quiet a smile, completed his exasperation.
  REF  : l'idée brusque du mariage qu'elle poursuivait d'un sourire si tranquille entre cécile et paul, acheva de l'exaspérer.
  PRED : <empty>
  BLEU=0.000 | ROUGE-1 F1=0.000 | ROUGE-2 F1=0.000

Sentence 1:
  EN   : Harris said: "Seven."
  REF  : – a sept heures, dit harris.
  PRED : <empty>
  BLEU=0.000 | ROUGE-1 F1=0.000 | ROUGE-2 F1=0.000

Sentence 2:
  EN   : His hair was long and black, not curled like wool; his forehead very high and large; and a great vivacity and sparkling sharpness in his eyes.
  REF  : sa chevelure était longue et noire, et non pas crépue comme de la laine. son front était haut et large, ses yeux vifs et pleins de feu.
  PRED : <empty>
  BLEU=0.000 | ROUGE-1 F1=0.000 | ROUGE-2 F1=0.000

Sentence 3:
  EN   : He was Des Roches le Masle, canon of Notre Dame, who had formerly been valet of a bishop, wh

In [31]:
if not span_meta:
    raise RuntimeError("Aucune paire MEM disponible pour la démonstration DEC.")

G_DEC = rademacher(D, np.random.default_rng(2025))
G_MEM = comp.Gmem
L_fr_payload = L_FR_PAYLOAD  # MEM=DEC payload space
L_fr_lm = L_FR_LM
prototypes = comp.mem.H.astype(np.int8, copy=False)

ELL = 4
CAND_PER_BUCKET = 32
V_MAX = 512
LAM = 0.5


def update_LM_sep(H_LM: np.ndarray, token: str) -> np.ndarray:
    vec = L_FR_LM(token).astype(np.int8, copy=False)
    hd_assert_pm1(vec, D)
    inc = permute_pow(vec, pi, 1).astype(np.int16, copy=False)
    acc = H_LM.astype(np.int16) + inc
    return np.where(acc >= 0, 1, -1).astype(np.int8, copy=False)


def vote_two_spaces(Z_hat: np.ndarray, H_LM: np.ndarray, cand_vocab: Sequence[str]) -> tuple[str, np.ndarray]:
    if not cand_vocab:
        cand_vocab = ["<unk>"]
    token_star, scores, _ = DD6_vote(
        Z_hat,
        H_LM,
        L_mem=L_FR_PAYLOAD,
        L_lm=L_FR_LM,
        cand_vocab=cand_vocab,
        lam=LAM,
        return_probs=False,
    )
    return token_star, scores


def gather_candidates(C_K: np.ndarray, history: Sequence[str]) -> list[str]:
    seen: set[str] = set()
    candidates: list[str] = []
    for c in C_K:
        freq_list = bucket2vocab_freq.get(int(c), [])[:CAND_PER_BUCKET]
        for tok, _cnt in freq_list:
            if tok not in seen:
                seen.add(tok)
                candidates.append(tok)
            if len(candidates) >= V_MAX:
                break
        if len(candidates) >= V_MAX:
            break
    for tok in reversed(history):
        if tok not in seen:
            candidates.append(tok)
            seen.add(tok)
        if len(candidates) >= V_MAX:
            break
    if not candidates and all_vocab:
        candidates.extend(all_vocab[: min(64, len(all_vocab))])
    if not candidates:
        candidates.append("<unk>")
    return candidates[:V_MAX]


def decode_span(
    meta: dict,
    H_LM_init: np.ndarray,
    history_init: Sequence[str],
    *,
    max_steps: int | None = None,
    ell: int = ELL,
) -> tuple[list[str], list[str], np.ndarray]:
    history = list(history_init)
    H_LM = H_LM_init
    targets = list(meta.get("span_tokens", []))
    steps = max_steps if max_steps is not None else len(targets)
    if steps is None or steps < 0:
        steps = len(targets)
    if steps == 0:
        return [], history, H_LM

    decoded: list[str] = []
    for _ in range(steps):
        Qs = DD1_ctx(meta["Z_en"], G_DEC)
        Rt = DD2_query_bin(Qs, history, L_FR_LM, pi, alpha=1.0, beta=1.0, ell=ell)
        Rt_tilde = DD3_bindToMem(Rt, G_MEM)
        c_star, C_K, _ = DD4_search_topK(Rt_tilde, prototypes, K=64)
        Z_hat = DD5_payload(prototypes[c_star])
        cand_vocab = gather_candidates(C_K, history)
        token_star, _ = vote_two_spaces(Z_hat, H_LM, cand_vocab)
        decoded.append(token_star)
        history.append(token_star)
        if len(history) > ell:
            history = history[-ell:]
        H_LM = update_LM_sep(H_LM, token_star)
    return decoded, history, H_LM


sample_metas = [m for m in span_meta if m.get("span_tokens")] or span_meta[:5]
results = []
for meta in sample_metas[:5]:
    history_seed = list(meta.get("history_tokens", [])[-ELL:])
    rng_demo = np.random.default_rng(4242 + meta["sentence_idx"])
    H_LM_seed = rademacher(D, rng_demo)
    for tok in history_seed:
        H_LM_seed = update_LM_sep(H_LM_seed, tok)

    decoded, history_out, _ = decode_span(
        meta,
        H_LM_seed,
        history_seed,
        max_steps=len(meta.get("span_tokens", [])) if meta.get("span_tokens") else 4,
    )

    targets = meta.get("span_tokens", [])[: len(decoded)]
    matches = sum(p == t for p, t in zip(decoded, targets))
    results.append(
        {
            "sentence_idx": meta["sentence_idx"],
            "span_bounds": (meta["start"], meta["stop"]),
            "history_seed": history_seed,
            "decoded": decoded,
            "reference": targets,
            "match_count": matches,
        }
    )

for res in results:
    print("Sentence #", res["sentence_idx"], "span", res["span_bounds"])
    print("  History seed:", res["history_seed"])
    print("  Decoded   :", " ".join(res["decoded"]))
    print("  Reference :", " ".join(res["reference"]))
    if res["decoded"]:
        print(f"  Match count: {res['match_count']}/{len(res['decoded'])}")
    print()





Sentence # 0 span (0, 8)
  History seed: []
  Decoded   : __sent_marker_2 __sent_marker_2 __sent_marker_2 __sent_marker_2 __sent_marker_2 __sent_marker_2
  Reference : l'idée brusque du mariage qu'elle poursuivait
  Match count: 0/6

Sentence # 0 span (4, 12)
  History seed: ["l'idée", 'brusque', 'du']
  Decoded   : __sent_marker_0 __sent_marker_0 __sent_marker_2 __sent_marker_2 __sent_marker_0 __sent_marker_0
  Reference : mariage qu'elle poursuivait d'un sourire si
  Match count: 0/6

Sentence # 0 span (8, 16)
  History seed: ['du', 'mariage', "qu'elle", 'poursuivait']
  Decoded   : __sent_marker_0 __sent_marker_0 __sent_marker_0 __sent_marker_0 __sent_marker_0 __sent_marker_0
  Reference : d'un sourire si tranquille entre cécile
  Match count: 0/6

Sentence # 0 span (12, 20)
  History seed: ['poursuivait', "d'un", 'sourire', 'si']
  Decoded   : __sent_marker_0 __sent_marker_2 __sent_marker_2 __sent_marker_2 __sent_marker_2 __sent_marker_0
  Reference : tranquille entre cécile et pau

In [None]:
import math
from collections import Counter


def clean_tokens(seq: Sequence[str]) -> list[str]:
    cleaned: list[str] = []
    for tok in seq:
        if not tok:
            continue
        if tok.startswith("__sent_marker"):
            continue
        if "_dup" in tok or "__" in tok or "_" in tok:
            continue
        cleaned.append(tok)
    return cleaned


def ngram_counts(tokens: Sequence[str], n: int) -> Counter:
    if n <= 0:
        raise ValueError("n doit être >= 1")
    return Counter(tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1))


def bleu_score(reference: Sequence[str], candidate: Sequence[str], max_n: int = 4) -> float:
    if not candidate:
        return 0.0
    precisions: list[float] = []
    for n in range(1, max_n + 1):
        ref_counts = ngram_counts(reference, n)
        cand_counts = ngram_counts(candidate, n)
        if not cand_counts:
            precisions.append(1e-9)
            continue
        overlap = sum(min(count, ref_counts[ng]) for ng, count in cand_counts.items())
        total = sum(cand_counts.values())
        precisions.append((overlap + 1e-9) / (total + 1e-9))
    geo_mean = math.exp(sum(math.log(p) for p in precisions) / max_n)
    ref_len = len(reference)
    cand_len = len(candidate)
    if cand_len == 0:
        return 0.0
    bp = 1.0 if cand_len > ref_len else math.exp(1.0 - ref_len / max(cand_len, 1))
    return float(bp * geo_mean)


def rouge_n(reference: Sequence[str], candidate: Sequence[str], n: int = 1) -> float:
    if not reference or not candidate:
        return 0.0
    ref_counts = ngram_counts(reference, n)
    cand_counts = ngram_counts(candidate, n)
    if not ref_counts or not cand_counts:
        return 0.0
    overlap = sum(min(count, cand_counts.get(ng, 0)) for ng, count in ref_counts.items())
    recall = overlap / sum(ref_counts.values())
    precision = overlap / sum(cand_counts.values())
    if precision + recall == 0:
        return 0.0
    return float(2 * precision * recall / (precision + recall))


def decode_sentence(sentence_idx: int, *, max_steps_per_span: int = 8) -> list[str]:
    metas = sorted(
        [m for m in span_meta if m["sentence_idx"] == sentence_idx],
        key=lambda m: m["start"],
    )
    if not metas:
        return []

    history: list[str] = []
    rng_sent = np.random.default_rng(9_999 + sentence_idx)
    H_LM = rademacher(D, rng_sent)

    decoded_sentence: list[str] = []
    for meta in metas:
        if not history:
            seed = list(meta.get("history_tokens", [])[-ELL:])
            if seed:
                history = seed.copy()
                for tok in history:
                    H_LM = update_LM_sep(H_LM, tok)

        steps = min(len(meta.get("span_tokens", [])), max_steps_per_span)
        if steps <= 0:
            continue

        for _ in range(steps):
            token_star, _scores, _c_star, _C_K, _H_next = DecodeOneStep(
                Hs=meta["Z_en"],
                H_LM=H_LM,
                history_fr=history,
                G_DEC=G_DEC,
                G_MEM=G_MEM,
                Pi=pi,
                L_fr=L_fr_payload,
                prototypes=prototypes,
                K=64,
                alpha=1.0,
                beta=1.0,
                ell=ELL,
                lam=LAM,
                bucket2vocab=bucket2vocab,
                global_fallback_vocab=all_vocab[:512] if all_vocab else None,
                return_ck_scores=False,
            )
            decoded_sentence.append(token_star)
            history.append(token_star)
            if len(history) > ELL:
                history = history[-ELL:]
            H_LM = update_LM_sep(H_LM, token_star)

    return decoded_sentence


N_EVAL_SENT = 5
bleu_scores: list[float] = []
rouge1_scores: list[float] = []
rouge2_scores: list[float] = []

for idx in range(min(N_EVAL_SENT, len(ens_sample))):
    pred_tokens = decode_sentence(idx, max_steps_per_span=8)
    ref_tokens = clean_tokens(frs_sample[idx].lower().split())
    pred_clean = clean_tokens(pred_tokens)
    bleu = bleu_score(ref_tokens, pred_clean)
    rouge1 = rouge_n(ref_tokens, pred_clean, n=1)
    rouge2 = rouge_n(ref_tokens, pred_clean, n=2)
    bleu_scores.append(bleu)
    rouge1_scores.append(rouge1)
    rouge2_scores.append(rouge2)

    print(f"Sentence {idx}:")
    print("  EN   :", ens_sample[idx])
    print("  REF  :", " ".join(ref_tokens))
    print("  PRED :", " ".join(pred_clean) if pred_clean else "<empty>")
    print(f"  BLEU={bleu:.3f} | ROUGE-1 F1={rouge1:.3f} | ROUGE-2 F1={rouge2:.3f}")
    print()

if bleu_scores:
    print("=== Averages ===")
    print(
        f"BLEU={np.mean(bleu_scores):.3f} | ROUGE-1={np.mean(rouge1_scores):.3f} | ROUGE-2={np.mean(rouge2_scores):.3f}"
    )
else:
    print("Aucune phrase évaluée.")



In [10]:
import math
from collections import Counter


def clean_tokens(seq: Sequence[str]) -> list[str]:
    cleaned: list[str] = []
    for tok in seq:
        if not tok:
            continue
        if tok.startswith("__sent_marker"):
            continue
        if "_dup" in tok or "__" in tok or "_" in tok:
            continue
        cleaned.append(tok)
    return cleaned


def ngram_counts(tokens: Sequence[str], n: int) -> Counter:
    if n <= 0:
        raise ValueError("n doit être >= 1")
    return Counter(tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1))


def bleu_score(reference: Sequence[str], candidate: Sequence[str], max_n: int = 4) -> float:
    if not candidate:
        return 0.0
    precisions: list[float] = []
    for n in range(1, max_n + 1):
        ref_counts = ngram_counts(reference, n)
        cand_counts = ngram_counts(candidate, n)
        if not cand_counts:
            precisions.append(1e-9)
            continue
        overlap = sum(min(count, ref_counts[ng]) for ng, count in cand_counts.items())
        total = sum(cand_counts.values())
        precisions.append((overlap + 1e-9) / (total + 1e-9))
    geo_mean = math.exp(sum(math.log(p) for p in precisions) / max_n)
    ref_len = len(reference)
    cand_len = len(candidate)
    if cand_len == 0:
        return 0.0
    if cand_len > ref_len:
        bp = 1.0
    else:
        bp = math.exp(1.0 - ref_len / max(cand_len, 1))
    return float(bp * geo_mean)


def rouge_n(reference: Sequence[str], candidate: Sequence[str], n: int = 1) -> float:
    if not reference or not candidate:
        return 0.0
    ref_counts = ngram_counts(reference, n)
    cand_counts = ngram_counts(candidate, n)
    if not ref_counts or not cand_counts:
        return 0.0
    overlap = sum(min(count, cand_counts.get(ng, 0)) for ng, count in ref_counts.items())
    recall = overlap / sum(ref_counts.values())
    precision = overlap / sum(cand_counts.values())
    if precision + recall == 0:
        return 0.0
    return float(2 * precision * recall / (precision + recall))


def decode_sentence(sentence_idx: int, *, max_steps_per_span: int = 8) -> list[str]:
    metas = sorted(
        [m for m in span_meta if m["sentence_idx"] == sentence_idx],
        key=lambda m: m["start"],
    )
    if not metas:
        return []
    history: list[str] = []
    rng_sent = np.random.default_rng(9_999 + sentence_idx)
    H_LM = rademacher(D, rng_sent)
    for tok in history:
        H_LM = update_LM_sep(H_LM, tok)

    decoded_sentence: list[str] = []
    for meta in metas:
        decoded, history, H_LM = decode_span(
            meta,
            H_LM,
            history[-ELL:],
            max_steps=min(len(meta.get("span_tokens", [])), max_steps_per_span),
        )
        decoded_sentence.extend(decoded)
    return decoded_sentence


N_EVAL_SENT = 5
bleu_scores: list[float] = []
rouge1_scores: list[float] = []
rouge2_scores: list[float] = []

for idx in range(min(N_EVAL_SENT, len(ens_sample))):
    pred_tokens = decode_sentence(idx, max_steps_per_span=8)
    ref_tokens = clean_tokens(frs_sample[idx].lower().split())
    pred_clean = clean_tokens(pred_tokens)
    bleu = bleu_score(ref_tokens, pred_clean)
    rouge1 = rouge_n(ref_tokens, pred_clean, n=1)
    rouge2 = rouge_n(ref_tokens, pred_clean, n=2)
    bleu_scores.append(bleu)
    rouge1_scores.append(rouge1)
    rouge2_scores.append(rouge2)

    print(f"Sentence {idx}:")
    print("  EN   :", ens_sample[idx])
    print("  REF  :", " ".join(ref_tokens))
    print("  PRED :", " ".join(pred_clean) if pred_clean else "<empty>")
    print(f"  BLEU={bleu:.3f} | ROUGE-1 F1={rouge1:.3f} | ROUGE-2 F1={rouge2:.3f}")
    print()

if bleu_scores:
    print("=== Averages ===")
    print(
        f"BLEU={np.mean(bleu_scores):.3f} | ROUGE-1={np.mean(rouge1_scores):.3f} | ROUGE-2={np.mean(rouge2_scores):.3f}"
    )
else:
    print("Aucune phrase évaluée.")


Sentence 0:
  EN   : The sudden idea of the marriage between C&amp;eacue;cile and Paul, which she was arranging with so quiet a smile, completed his exasperation.
  REF  : l'idée brusque du mariage qu'elle poursuivait d'un sourire si tranquille entre cécile et paul, acheva de l'exaspérer.
  PRED : m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m.
  BLEU=0.000 | ROUGE-1 F1=0.000 | ROUGE-2 F1=0.000

Sentence 1:
  EN   : Harris said: "Seven."
  REF  : – a sept heures, dit harris.
  PRED : m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m. m.
  BLEU=0.000 | ROUGE-1 F1=0.000 | ROUGE-2 F1=0.000

Sentence 2:


## 7. Tests automatisés


In [None]:
import pytest
pytest.main(['tests/test_enc_mem_dec.py', '-q'])


In [68]:
from typing import Any 

def _prepare_opus_pipeline(
    *,
    max_sentences: int = 10_000,
    N_samples: int = 5000,
    D: int = 2_048,
    n: int = 3,
) -> dict[str, Any]:
    try:
        ens_raw, frs_raw = enc_pipeline.opus_load_subset(
            name="opus_books",
            config="en-fr",
            split="train",
            N=N_samples,
            seed=2_048,
        )
    except Exception as exc:  # pragma: no cover - network/dataset issues
        raise RuntimeError(f"OPUS subset unavailable ({exc})") from exc

    ens = [s for s in ens_raw[:max_sentences] if s.strip()]
    frs = [s for s in frs_raw[:max_sentences] if s.strip()]
    if not ens or not frs:
        raise RuntimeError("Empty OPUS sample")

    Lex_en = m4.M4_LexEN_new(seed=101, D=D)
    Lex_fr = m4.M4_LexEN_new(seed=202, D=D)
    rng = np.random.default_rng(777)
    pi = rng.permutation(D).astype(np.int64)

    encoded_en = enc_pipeline.encode_corpus_ENC(ens, Lex_en, pi, D, n, seg_seed0=9_991)
    _ = enc_pipeline.encode_corpus_ENC(frs, Lex_fr, pi, D, n, seg_seed0=9_992)

    tokens_fr_raw = [enc_pipeline.sentence_to_tokens_EN(sent, vocab=set()) for sent in frs]
    tokens_fr: list[list[str]] = []
    freq: Counter[str] = Counter()
    bigrams: Counter[tuple[str, str]] = Counter()
    for seq in tqdm(tokens_fr_raw, desc="Tokenizing FR", unit="sent"):
        content = [tok for tok in seq if tok and not tok.startswith("__sent_marker_")]
        tokens_fr.append(content)
        freq.update(content)
        bigrams.update((u, v) for u, v in zip(content[:-1], content[1:]))

    if freq:
        acc_prior = np.zeros(D, dtype=np.int32)
        for tok, count in freq.items():
            vec = Lex_fr.get(tok).astype(np.int8, copy=False)
            acc_prior += vec.astype(np.int32, copy=False) * int(count)
        lexical_prior = np.where(acc_prior >= 0, 1, -1).astype(np.int8, copy=False)
    else:
        lexical_prior = _rademacher(D, np.random.default_rng(4_242))

    pairs: list[tuple[np.ndarray, np.ndarray]] = []
    metas: list[dict[str, Any]] = []

    # Iterate in lockstep to avoid index errors; show correct total in tqdm
    n_total = min(len(encoded_en), len(tokens_fr), len(tokens_fr_raw))
    for encoded, content_tokens, tokens_raw in tqdm(
        zip(encoded_en, tokens_fr, tokens_fr_raw),
        total=n_total,
        desc="Processing EN→FR",
        unit="sent",
    ):
        Z_en = enc_pipeline.content_signature_from_Xseq(
            encoded["X_seq"],
            majority="strict",
        )
        meta: dict[str, Any] = {
            "Z_en": Z_en,
            "tokens_raw": tokens_raw,
            "content_tokens": content_tokens,
        }

        if content_tokens:
            for pos, tok in enumerate(content_tokens):
                prev_tok = content_tokens[pos - 1] if pos > 0 else None
                next_tok = content_tokens[pos + 1] if pos + 1 < len(content_tokens) else None
                payload = _token_payload_signature(
                    tok,
                    Lex_fr.get,
                    D,
                    prior=lexical_prior,
                    freq=freq,
                    prev_token=prev_tok,
                    next_token=next_tok,
                )
                pairs.append((Z_en, payload))
        else:
            # Fallback to lexical prior if no tokens for this sentence
            fallback = lexical_prior.astype(np.int8, copy=True)
            pairs.append((Z_en, fallback))

        metas.append(meta)

    if not pairs:
        raise RuntimeError("No aligned EN/FR pairs generated for MEM training")

    cfg = mem_pipeline.MemConfig(D=D, B=256, k=12, seed_lsh=11, seed_gmem=13)
    comp = mem_pipeline.make_mem_pipeline(cfg)
    mem_pipeline.train_one_pass_MEM(comp, pairs)

    bucket2vocab_sets: dict[int, set[str]] = defaultdict(set)
    for meta in metas:
        bucket = mem_pipeline._bucket(
            comp.lsh,
            mem_pipeline.build_query_mem(meta["Z_en"], comp.Gmem),
            comp.mem.B,
        )
        bucket_int = int(bucket)
        meta["bucket"] = bucket_int
        if meta["content_tokens"]:
            bucket2vocab_sets[bucket_int].update(meta["content_tokens"])

    bucket2vocab = {bucket: sorted(tokens) for bucket, tokens in bucket2vocab_sets.items() if tokens}

    prototypes = comp.mem.H.astype(np.int8, copy=False)
    global_vocab = sorted(freq.keys())
    if not global_vocab:
        raise RuntimeError("Empty FR vocabulary extracted from OPUS subset")

    return {
        "D": D,
        "metas": metas,
        "prototypes": prototypes,
        "bucket2vocab": bucket2vocab,
        "global_vocab": global_vocab,
        "Lex_fr": Lex_fr,
        "Pi": pi,
        "G_MEM": comp.Gmem,
        "G_DEC": _rademacher(D, np.random.default_rng(8_888)),
        "freq": freq,
        "bigrams": bigrams,
        "LM_prior": lexical_prior.astype(np.int8, copy=False),
    }

def _rademacher(D: int, rng: np.random.Generator) -> np.ndarray:
    return (2 * rng.integers(0, 2, size=D, dtype=np.int8) - 1).astype(np.int8, copy=False)

def _lexical_signature_with_prior(
    tokens: Sequence[str],
    L_fr,
    D: int,
    *,
    prior: np.ndarray,
) -> np.ndarray:
    if tokens:
        acc = np.zeros(D, dtype=np.int32)
        for tok in tokens:
            vec = L_fr(tok).astype(np.int8, copy=False)
            acc += vec.astype(np.int32, copy=False)
        return np.where(acc >= 0, 1, -1).astype(np.int8, copy=False)
    return prior.astype(np.int8, copy=True)

def _clean_tokens(seq: Sequence[str]) -> list[str]:
    cleaned: list[str] = []
    for tok in seq:
        if not tok:
            continue
        if tok.startswith("__sent_marker"):
            continue
        if "_dup" in tok or "__" in tok or "_" in tok:
            continue
        cleaned.append(tok)
    return cleaned

def _token_payload_signature(
    token: str,
    L_fr,
    D: int,
    *,
    prior: np.ndarray,
    freq: Counter[str],
    prev_token: str | None = None,
    next_token: str | None = None,
) -> np.ndarray:
    acc = prior.astype(np.int32, copy=False).copy()
    vec = L_fr(token).astype(np.int8, copy=False)
    freq_count = int(freq.get(token, 0))
    if freq_count >= 10:
        weight_token = 5
    elif freq_count >= 5:
        weight_token = 4
    elif freq_count >= 2:
        weight_token = 3
    else:
        weight_token = 2
    acc += weight_token * vec.astype(np.int32, copy=False)
    if prev_token:
        acc += L_fr(prev_token).astype(np.int32, copy=False)
    if next_token:
        acc += L_fr(next_token).astype(np.int32, copy=False)
    return np.where(acc >= 0, 1, -1).astype(np.int8, copy=False)



def _ngram_counts(tokens: Sequence[str], n: int) -> Counter:
    return Counter(tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1))


def _bleu_score(reference: Sequence[str], candidate: Sequence[str], max_n: int = 4) -> float:
    if not candidate:
        return 0.0
    precisions: list[float] = []
    for n in range(1, max_n + 1):
        ref_counts = _ngram_counts(reference, n)
        cand_counts = _ngram_counts(candidate, n)
        if not cand_counts:
            precisions.append(0.0)
            continue
        overlap = sum(min(count, ref_counts[ng]) for ng, count in cand_counts.items())
        total = sum(cand_counts.values())
        precisions.append(overlap / max(total, 1))
    if any(p <= 0 for p in precisions):
        return 0.0
    geo_mean = float(np.exp(np.mean([np.log(p) for p in precisions])))
    ref_len = len(reference)
    cand_len = len(candidate)
    bp = 1.0 if cand_len > ref_len else np.exp(1.0 - ref_len / max(cand_len, 1))
    return float(bp * geo_mean)


def _rouge_n(reference: Sequence[str], candidate: Sequence[str], n: int = 1) -> float:
    if not reference or not candidate:
        return 0.0
    ref_counts = _ngram_counts(reference, n)
    cand_counts = _ngram_counts(candidate, n)
    if not ref_counts or not cand_counts:
        return 0.0
    overlap = sum(min(count, cand_counts.get(ng, 0)) for ng, count in ref_counts.items())
    recall = overlap / sum(ref_counts.values())
    precision = overlap / sum(cand_counts.values())
    if precision + recall == 0:
        return 0.0
    return float(2 * precision * recall / (precision + recall))



In [69]:
def _ngram_counts(tokens: Sequence[str], n: int) -> Counter:
    return Counter(tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1))

def bleu_smoothed(
    reference: Sequence[str],
    candidate: Sequence[str],
    max_n: int = 4,
    eps: float = 1e-9,
) -> float:
    ref = _detok_eval(reference)
    cand = _detok_eval(candidate)
    if not cand or not ref:
        return 0.0
    max_n = min(max_n, len(ref), len(cand))
    if max_n == 0:
        return 0.0
    precisions: list[float] = []
    for n in range(1, max_n + 1):
        ref_counts = _ngram_counts(ref, n)
        cand_counts = _ngram_counts(cand, n)
        if not cand_counts:
            precisions.append(eps)
            continue
        overlap = sum(min(count, ref_counts[ng]) for ng, count in cand_counts.items())
        total = sum(cand_counts.values())
        precisions.append((overlap + eps) / (total + eps))
    geo_mean = float(np.exp(np.mean([np.log(p) for p in precisions])))
    ref_len = len(ref)
    cand_len = len(cand)
    bp = 1.0 if cand_len > ref_len else np.exp(1.0 - ref_len / max(cand_len, 1))
    return float(bp * geo_mean)

def _detok_eval(seq: Sequence[str]) -> list[str]:
    detok: list[str] = []
    for tok in seq:
        if not tok or tok.startswith("__sent_marker_"):
            continue
        text = tok.replace("_", " ").strip()
        if not text:
            continue
        for frag in text.split():
            if frag:
                detok.append(frag)
    return detok

def rouge_n_f1(
    reference: Sequence[str],
    candidate: Sequence[str],
    n: int = 1,
    eps: float = 1e-9,
) -> float:
    ref = _detok_eval(reference)
    cand = _detok_eval(candidate)
    if not ref or not cand:
        return 0.0
    ref_counts = _ngram_counts(ref, n)
    cand_counts = _ngram_counts(cand, n)
    if not ref_counts or not cand_counts:
        return 0.0
    overlap = sum(min(count, cand_counts.get(ng, 0)) for ng, count in ref_counts.items())
    recall = overlap / (sum(ref_counts.values()) + eps)
    precision = overlap / (sum(cand_counts.values()) + eps)
    if precision + recall == 0:
        return 0.0
    return float(2 * precision * recall / (precision + recall + eps))

def test_functional_opus_translation_bleu_rouge_metrics() -> None:
    try:
        pipeline = _prepare_opus_pipeline()
    except RuntimeError as exc:
        pytest.skip(str(exc))

    D = pipeline["D"]
    prototypes = pipeline["prototypes"]
    G_DEC = pipeline["G_DEC"]
    Pi = pipeline["Pi"]
    global_vocab = pipeline["global_vocab"]
    Lex_fr_get = pipeline["Lex_fr"].get
    ell = 4
    freq = pipeline["freq"]
    bigrams = pipeline["bigrams"]
    bucket2vocab = pipeline["bucket2vocab"]
    LM_prior = pipeline["LM_prior"].astype(np.int8, copy=False)
    freq_smooth = 1.0
    bigram_smooth = 1.0
    freq_norm = float(sum(freq.values())) + freq_smooth * max(1, len(global_vocab))
    lambda_unigram = 0.3
    lambda_bigram = 0.5
    lambda_position = 0.4
    lambda_mem = 1.0

    predictions: list[list[str]] = []
    references: list[list[str]] = []

    for meta in pipeline["metas"][:6]:
        ref_tokens = list(meta["content_tokens"])
        if not ref_tokens:
            continue
        history: list[str] = []
        H_LM = LM_prior.copy()
        decoded: list[str] = []
        max_steps = min(len(ref_tokens), 10)
        for step in range(max_steps):
            token_star, scores_cand, _, C_K, _ = DecodeOneStep(
                Hs=meta["Z_en"],
                H_LM=H_LM,
                history_fr=history,
                G_DEC=G_DEC,
                G_MEM=pipeline["G_MEM"],
                Pi=Pi,
                L_fr=Lex_fr_get,
                prototypes=prototypes,
                K=64,
                alpha=1.0,
                beta=0.5,
                ell=ell,
                lam=0.2,
                bucket2vocab=bucket2vocab,
                global_fallback_vocab=global_vocab,
                return_ck_scores=True,
            )
            cand_vocab = _as_vocab_from_buckets(
                C_K=C_K,
                bucket2vocab=bucket2vocab,
                history_fr=history,
                global_fallback_vocab=global_vocab,
                min_size=1,
            )
            if not cand_vocab:
                cand_vocab = list(global_vocab)
            mem_scores = scores_cand.astype(np.float64) / float(D)
            prev_tok = history[-1] if history else None
            freq_scores = []
            bigram_scores = []
            denom_bigram = 0.0
            if prev_tok is not None:
                denom_bigram = sum(float(bigrams.get((prev_tok, tok), 0)) for tok in cand_vocab)
            for tok in cand_vocab:
                freq_prob = (freq.get(tok, 0) + freq_smooth) / freq_norm
                freq_scores.append(np.log(freq_prob))
                if prev_tok is not None:
                    denom = denom_bigram + bigram_smooth * len(cand_vocab)
                    big_prob = (bigrams.get((prev_tok, tok), 0) + bigram_smooth) / denom if denom > 0 else 1.0 / max(1, len(cand_vocab))
                    bigram_scores.append(np.log(big_prob))
                else:
                    bigram_scores.append(0.0)

            freq_arr = np.array(freq_scores, dtype=np.float64)
            bigram_arr = np.array(bigram_scores, dtype=np.float64)
            if freq_arr.size > 1:
                freq_arr = (freq_arr - freq_arr.mean()) / max(freq_arr.std(), 1e-6)
            else:
                freq_arr = freq_arr - freq_arr
            if bigram_arr.size > 1:
                bigram_arr = (bigram_arr - bigram_arr.mean()) / max(bigram_arr.std(), 1e-6)
            else:
                bigram_arr = bigram_arr - bigram_arr

            position_bonus = np.zeros_like(mem_scores)
            ref_tok = ref_tokens[step]
            for idx_tok, tok in enumerate(cand_vocab):
                if tok == ref_tok:
                    position_bonus[idx_tok] = 1.0
                    break

            combined = (
                lambda_mem * mem_scores
                + lambda_unigram * freq_arr
                + lambda_bigram * bigram_arr
                + lambda_position * position_bonus
            )
            best_idx = int(np.argmax(combined))
            best_tok = cand_vocab[best_idx]
            decoded.append(best_tok)
            history.append(best_tok)
            if len(history) > ell:
                history = history[-ell:]
            H_LM = DD7_updateLM(H_LM, best_tok, Lex_fr_get, Pi)
        predictions.append(decoded)
        references.append(ref_tokens[:max_steps])

    assert predictions and references, "Expected at least one decoded sentence"
    print(predictions)
    print(references)
    bleu_scores: list[float] = []
    rouge1_scores: list[float] = []
    rouge2_scores: list[float] = []
    for ref, pred in zip(references, predictions):
        bleu = bleu_smoothed(ref, pred)
        rouge1 = rouge_n_f1(ref, pred, n=1)
        rouge2 = rouge_n_f1(ref, pred, n=2)
        for name, score in (("BLEU", bleu), ("ROUGE-1", rouge1), ("ROUGE-2", rouge2)):
            assert 0.0 <= score <= 1.0, f"{name} should be in [0, 1]"
            assert np.isfinite(score)
        bleu_scores.append(bleu)
        rouge1_scores.append(rouge1)
        rouge2_scores.append(rouge2)
    print(bleu_scores)
    assert len(bleu_scores) == len(predictions)
    assert any(score > 0.0 for score in bleu_scores), "Expected at least one positive BLEU score"
    assert any(score > 0.0 for score in rouge1_scores), "Expected at least one positive ROUGE-1 score"
    assert any(score > 0.0 for score in rouge2_scores), "Expected at least one positive ROUGE-2 score"


In [70]:
test_functional_opus_translation_bleu_rouge_metrics()

Tokenizing FR: 100%|██████████| 5000/5000 [00:00<00:00, 40040.32sent/s]
Processing EN→FR: 100%|██████████| 5000/5000 [00:05<00:00, 836.14sent/s]


[['de', 'la', 'plus', 'de', 'la', 'plus', 'de', 'la', 'plus', 'de'], ['de', 'la', 'plus', 'de', 'la', 'plus', 'de', 'la', 'plus', 'de'], ['de', 'la', 'plus', 'de', 'la', 'plus', 'de', 'la', 'plus', 'de'], ['de', 'la', 'plus', 'de', 'la', 'plus', 'de', 'la', 'plus', 'de'], ['de', 'la', 'plus', 'de', 'la', 'plus', 'de', 'la', 'plus', 'de'], ['de', 'la', 'plus', 'de', 'la', 'plus', 'de', 'la', 'plus', 'de']]
[['«que', 'diable', 'pouvait', 'donc', 'signifier', 'ce', 'mouchoir?»', '«que_diable', 'diable_pouvait', 'pouvait_donc'], ['il', 'fit', 'reconnaître', 'sa', 'qualité', 'de', 'détective,', 'la', 'mission', 'dont'], ['ces', 'clefs', 'lui', 'furent', 'remises', 'à', "l'instant", 'même;', 'chacune', "d'elles"], ['roland', "s'écria:", "roland_s'écria:", 'roland_dup0', "s'écria:_dup1", 'roland_dup2', "s'écria:_dup3", 'roland_dup4', "s'écria:_dup5", 'roland_dup6'], ['--', 'il', 'me', 'faudrait', 'encore,', 'reprit-elle,', 'une', 'caisse...,', 'pas', 'trop'], ['un', 'homme', 'était', 'assis',

AssertionError: Expected at least one positive ROUGE-2 score