# explore_bis_v5 — pipeline ENC → MEM → DEC


In [None]:
from pathlib import Path
import sys

# The project root is taken to be the parent of the current working directory,
# and the importable sources live in its ``src`` sub-directory.
ROOT = Path.cwd().parent
SRC = ROOT.joinpath("src")

# Prepend ``src`` to the import path only when it is not already listed, so
# re-running this cell does not stack duplicate entries.
if sys.path.count(str(SRC)) == 0:
    sys.path[:0] = [str(SRC)]
print(f"Using src path: {SRC}")


In [None]:
import logging
from collections import defaultdict
from typing import List, Sequence, Tuple

import numpy as np
from tqdm import tqdm

from hdc_project.encoder import m4, pipeline as enc_pipeline
from hdc_project.encoder.mem import pipeline as mem_pipeline
from hdc_project.decoder import (
    DD1_ctx,
    DD2_query,
    DD2_query_bin,
    DD3_bindToMem,
    DD4_search_topK,
    DD5_payload,
    DD6_vote,
    DD7_updateLM,
    DecodeOneStep,
    DX2_run,
    DX3_run,
    DX4_run,
    DX5_run,
    DX6_run,
    DX7_run,
)
from hdc_project.decoder.dec import (
    hd_assert_pm1,
    hd_bind,
    hd_sim,
    hd_sim_dot,
    build_perm_inverse,
    permute_pow,
    permute_pow_signed,
    rademacher,
    _as_vocab_from_buckets,
)

# Notebook-wide logger. ``logging.basicConfig`` (which configures the root
# logger) is only invoked while this named logger has no handlers of its own.
log = logging.getLogger("explore_v5")
if len(log.handlers) == 0:
    logging.basicConfig(
        format="%(asctime)s [%(levelname)s] %(message)s",
        level=logging.INFO,
    )


## 1. Chargement du corpus et encodage ENC


In [None]:
# Corpus size knobs: N_OPUS pairs are requested from the loader and at most
# MAX_SENTENCES of them are kept for encoding.
MAX_SENTENCES = 2_000
N_OPUS = 5_000

# Try the OPUS loader first. Any failure is logged and replaced by a 4-pair toy
# corpus so the rest of the notebook can still run (deliberate best-effort
# fallback, hence the broad ``except Exception``).
try:
    ens_raw, frs_raw = enc_pipeline.opus_load_subset(
        name="opus_books",
        config="en-fr",
        split="train",
        N=N_OPUS,
        seed=2025,
    )
    log.info("OPUS subset loaded: %d pairs", len(ens_raw))
except Exception as exc:
    log.warning("OPUS download failed (%s); using a local toy corpus", exc)
    ens_raw = [
        "hyperdimensional computing is fun",
        "vector symbolic architectures are powerful",
        "encoding words into hyperspace",
        "memory augmented networks love clean data",
    ]
    frs_raw = [
        "le calcul hyperdimensionnel est amusant",
        "les architectures symboliques vectorielles sont puissantes",
        "encoder des mots dans l'hyperspace",
        "les réseaux augmentés de mémoire aiment les données propres",
    ]

# Keep only the first MAX_SENTENCES entries of each side.
# NOTE(review): nothing here checks that both sides have the same length.
ens_sample = ens_raw[:MAX_SENTENCES]
frs_sample = frs_raw[:MAX_SENTENCES]
log.info("Using %d sentence pairs", len(ens_sample))

# D is passed as the dimension to the lexicons, the encoder and (in later
# cells) MEM/DEC. n is forwarded to encode_corpus_ENC — presumably the n-gram
# order; confirm against the encoder.
D = 8_192
n = 5
rng = np.random.default_rng(123)

# NOTE(review): the French lexicon is built with the same ``M4_LexEN_new``
# constructor as the English one, differing only by seed — confirm that this
# constructor is language-agnostic.
Lex_en = m4.M4_LexEN_new(seed=1, D=D)
Lex_fr = m4.M4_LexEN_new(seed=2, D=D)
# Random permutation of range(D) drawn from the seeded generator, shared by
# both encoders and by the decoder cells further down.
pi = rng.permutation(D).astype(np.int64)
# pi_inv is not used by any cell visible in this notebook.
pi_inv = build_perm_inverse(pi)

# Encode both sides with the same permutation but different segment seeds.
encoded_en = enc_pipeline.encode_corpus_ENC(ens_sample, Lex_en, pi, D, n, seg_seed0=999)
encoded_fr = enc_pipeline.encode_corpus_ENC(frs_sample, Lex_fr, pi, D, n, seg_seed0=1999)
log.info("Encoded %d EN / %d FR sentences", len(encoded_en), len(encoded_fr))

# Pull the per-sentence "E_seq" and "H" entries out of the encoder output for
# the similarity diagnostics below.
E_list_en = [segment["E_seq"] for segment in encoded_en]
H_list_en = [segment["H"] for segment in encoded_en]
if H_list_en:
    log.info("Encoder signature shape: %s", H_list_en[0].shape)

# Similarity statistics computed on the English side only.
intra_sim, inter_sim = enc_pipeline.intra_inter_ngram_sims(E_list_en, D)
inter_seg = enc_pipeline.inter_segment_similarity(H_list_en)
log.info("ENC stats — intra: %.4f | inter(|.|): %.4f | inter segments: %.4f", intra_sim, inter_sim, inter_seg)


## 2. Construction des paires MEM


In [None]:
def content_signature_from_Xseq(X_seq: Sequence[np.ndarray], *, majority: str = "strict") -> np.ndarray:
    """Collapse a sequence of vectors into one +1/-1 signature by majority vote.

    The vectors are summed component-wise in int32 and every component of the
    sum is mapped to +1 (sum >= 0) or -1 (sum < 0).

    Args:
        X_seq: Non-empty sequence of equal-length 1-D arrays. Both a
            list/tuple of arrays and a 2-D array whose rows are the vectors
            are accepted.
        majority: How components whose sum is exactly 0 are resolved.
            ``"strict"`` maps ties to +1. ``"unbiased"`` replaces ties by
            +1/-1 drawn from a generator seeded with 0, so the outcome is
            deterministic and identical across calls.

    Returns:
        An int8 array of +1/-1 values with the length of ``X_seq[0]``.

    Raises:
        ValueError: If ``X_seq`` is empty or ``majority`` is not one of the
            two supported modes.
    """
    # Use len() instead of truthiness: ``not X_seq`` raises an unrelated
    # "ambiguous truth value" error when X_seq is a multi-row ndarray.
    if len(X_seq) == 0:
        raise ValueError("X_seq vide")
    # Reject an unknown mode up front rather than after the accumulation.
    if majority not in ("strict", "unbiased"):
        raise ValueError("majority must be 'strict' or 'unbiased'")

    acc = np.zeros(X_seq[0].shape[0], dtype=np.int32)
    for x in X_seq:
        acc += x.astype(np.int32, copy=False)

    if majority == "unbiased":
        # Overwrite exact ties with a random sign before thresholding.
        rng_local = np.random.default_rng(0)
        ties = acc == 0
        acc[ties] = rng_local.integers(0, 2, size=int(ties.sum()), dtype=np.int32) * 2 - 1

    return np.where(acc >= 0, 1, -1).astype(np.int8, copy=False)


def span_signatures_from_trace(
    X_seq: Sequence[np.ndarray],
    *,
    win: int,
    stride: int,
    majority: str,
) -> List[Tuple[np.ndarray, int, int]]:
    """Cut a trace into windows and compute one content signature per window.

    Args:
        X_seq: Sequence of vectors for one sentence.
        win: Window length, in positions of ``X_seq``.
        stride: Step between consecutive window starts; values below 1 are
            treated as 1.
        majority: Tie-breaking mode forwarded to
            ``content_signature_from_Xseq``.

    Returns:
        A list of ``(signature, start, stop)`` tuples with half-open bounds.
        An empty trace yields an empty list; a trace no longer than ``win``
        yields a single span covering all of it. Otherwise only full windows
        are emitted, so positions after the last full window are not covered
        by any span.
    """
    length = len(X_seq)
    if length == 0:
        return []
    if length <= win:
        whole = content_signature_from_Xseq(X_seq, majority=majority)
        return [(whole, 0, length)]

    step = max(1, stride)
    return [
        (
            content_signature_from_Xseq(X_seq[lo:lo + win], majority=majority),
            lo,
            lo + win,
        )
        for lo in range(0, length - win + 1, step)
    ]


def build_mem_pairs_with_meta(
    encoded_en: Sequence[dict],
    encoded_fr: Sequence[dict],
    tokens_fr: Sequence[Sequence[str]],
    *,
    win: int = 8,
    stride: int = 4,
    majority: str = "strict",
    max_pairs: int | None = None,
) -> Tuple[List[Tuple[np.ndarray, np.ndarray]], List[dict]]:
    pairs: List[Tuple[np.ndarray, np.ndarray]] = []
    meta: List[dict] = []
    N = min(len(encoded_en), len(encoded_fr))
    for idx in tqdm(range(N), desc="MEM span extraction", leave=False):
        spans_en = span_signatures_from_trace(encoded_en[idx]["X_seq"], win=win, stride=stride, majority=majority)
        spans_fr = span_signatures_from_trace(encoded_fr[idx]["X_seq"], win=win, stride=stride, majority=majority)
        tok_fr = list(tokens_fr[idx]) if idx < len(tokens_fr) else []
        L = min(len(spans_en), len(spans_fr))
        for (ze, start_en, stop_en), (zf, start_fr, stop_fr) in zip(spans_en[:L], spans_fr[:L]):
            pairs.append((ze, zf))
            span_tokens = tok_fr[start_fr:stop_fr] if tok_fr else []
            history = tok_fr[max(0, start_fr - stride):start_fr] if tok_fr else []
            meta.append(
                {
                    "sentence_idx": idx,
                    "start": start_fr,
                    "stop": stop_fr,
                    "history_tokens": history,
                    "span_tokens": span_tokens,
                    "Z_en": ze,
                    "Z_fr": zf,
                }
            )
            if max_pairs is not None and len(pairs) >= max_pairs:
                return pairs, meta
    return pairs, meta


# Tokenise the French sentences (each call gets its own fresh, empty vocab
# set), then build the MEM training pairs together with their span metadata.
# NOTE(review): the tokenizer is named ``sentence_to_tokens_EN`` but is applied
# to French text here — confirm it is language-agnostic.
tokens_fr = [
    enc_pipeline.sentence_to_tokens_EN(fr_sentence, vocab=set())
    for fr_sentence in frs_sample
]
pairs_mem, span_meta = build_mem_pairs_with_meta(
    encoded_en, encoded_fr, tokens_fr,
    win=8, stride=4, majority="strict", max_pairs=50_000,
)
log.info("Prepared %d MEM pairs", len(pairs_mem))


## 3. Entraînement MEM et diagnostic rapide


In [None]:
# MEM hyper-parameters, forwarded to MemConfig as ``k`` and ``B`` respectively.
MEM_K = 16
MEM_BUCKETS = 128
cfg = mem_pipeline.MemConfig(D=D, B=MEM_BUCKETS, k=MEM_K, seed_lsh=10, seed_gmem=11)
comp = mem_pipeline.make_mem_pipeline(cfg)
# Single training pass over the (EN, FR) signature pairs.
mem_pipeline.train_one_pass_MEM(comp, pairs_mem)
log.info("MEM training completed (B=%d)", comp.mem.B)

# Quick sanity probe on (up to) the first 200 training pairs: look up the
# top-1 bucket for the EN signature and measure dot(prototype, Z_fr) / D.
# These are training pairs, so this is not a held-out evaluation.
probe_count = min(200, len(pairs_mem))
sim_values = []
for Z_en_vec, Z_fr_vec in tqdm(pairs_mem[:probe_count], desc="MEM probe"):
    bucket_idx, _ = mem_pipeline.infer_map_top1(comp, Z_en_vec)
    # Both operands are widened to int32 before the dot product — presumably
    # to avoid overflow when the stored dtype is narrower; confirm.
    prototype = comp.mem.H[bucket_idx].astype(np.int32, copy=False)
    sim = float(np.dot(prototype, Z_fr_vec.astype(np.int32, copy=False)) / D)
    sim_values.append(sim)

if sim_values:
    log.info(
        "Probe similarities — mean: %.4f | median: %.4f",
        float(np.mean(sim_values)),
        float(np.median(sim_values)),
    )
    # NOTE(review): the bucket population stats are only logged when the probe
    # produced at least one value, although they do not depend on the probe.
    nb = comp.mem.n
    log.info(
        "Bucket population stats — mean: %.1f | p90: %d | p99: %d",
        float(nb.mean()),
        int(np.quantile(nb, 0.90)),
        int(np.quantile(nb, 0.99)),
    )


## 4. Dictionnaire bucket → vocabulaire


In [None]:
# Map every span to its top-1 MEM bucket (looked up from the EN signature),
# record that bucket on the span's metadata, and collect the span's French
# tokens under it. A span without tokens creates no entry for its bucket.
bucket_vocab: dict[int, set[str]] = defaultdict(set)
for meta in tqdm(span_meta, desc="Bucket vocab build", leave=False):
    bucket_idx, _ = mem_pipeline.infer_map_top1(comp, meta["Z_en"])
    meta["bucket_idx"] = int(bucket_idx)
    for tok in meta["span_tokens"]:
        bucket_vocab[meta["bucket_idx"]].add(tok)

# Freeze the per-bucket sets into sorted lists and derive the global vocabulary
# as the sorted union of all bucket vocabularies.
bucket2vocab = {key: sorted(members) for key, members in bucket_vocab.items()}
all_vocab = sorted(set().union(*bucket2vocab.values()))
log.info("Bucket vocab built for %d buckets (global vocab size=%d)", len(bucket2vocab), len(all_vocab))


## 5bis. Diagnostics théorie ↔ implémentation


In [None]:
# Run the decoder's DX2..DX7 diagnostic routines with reduced trial counts and
# fixed seeds, logging one headline figure per routine. What each routine
# measures is defined in hdc_project.decoder; only the way the results are
# consumed is visible here.
log.info("Running DEC diagnostic suite (subsampled)...")

# DX2 returns a mapping; element [1] of each value is logged as a norm.
norms = DX2_run(D=D, trials=50, ells=(2, 4, 8), ratios=(1.0,), seed=2025)
log.info("DX2 ok — example median norm: %.3f", np.median([v[1] for v in norms.values()]))

# DX3 returns a (relative error, p-value) pair.
rel_err, pval = DX3_run(D=D, C=256, T=64, seed=2025, rel_tol=0.02, pmin=0.05)
log.info("DX3 ok — mean relative error=%.4f | p=%.3f", rel_err, pval)

# DX4 results are indexed by K; K=500 must be one of the requested Ks.
recalls = DX4_run(D=D, B=5_000, trials=40, Ks=(100, 500), seed=0)
log.info("DX4 ok — recall@500=%.3f", recalls[500])

# DX5 results are indexed by m; m=8 must be one of the requested ms.
accuracies = DX5_run(D=D, trials=40, ms=(4, 8, 16), seed=0)
log.info("DX5 ok — accuracy m=8: %.3f", accuracies[8])

# DX6 returns, per lambda, a dict exposing at least 'top1' and 'ppl'.
results_dx6 = DX6_run(D=D, trials=120, lam_grid=(0.0, 0.5, 1.0), rng_seed=7031)
log.info("DX6 ok — lambda grid summary: %s", {lam: (vals['top1'], vals['ppl']) for lam, vals in results_dx6.items()})

# DX7 returns per-ell results plus the selected ell*, used to index them.
dx7_results, ell_star = DX7_run(ell_grid=(2, 4, 8), D=D, seed_pi=10_456, rng_seed=9_117)
log.info("DX7 ok — ell*=%d, top1=%.3f", ell_star, dx7_results[ell_star]["top1"])


## 6. Démonstration DEC (un pas)


In [None]:
# One-step decoding demo on a single span taken from the MEM metadata.
if not span_meta:
    raise RuntimeError("Aucune paire MEM disponible pour la démonstration DEC.")

# Prefer the first span that has ground-truth tokens; otherwise fall back to
# the very first span.
demo = next((m for m in span_meta if m["span_tokens"]), span_meta[0])
# Decoder-side key drawn from a fixed seed; the MEM-side key comes from the
# trained pipeline.
G_DEC = rademacher(D, np.random.default_rng(2025))
G_MEM = comp.Gmem
L_fr = Lex_fr.get

# Keep at most the last 4 history tokens and fold them, in order, into a
# language-model state that starts from a fixed-seed random vector.
history = list(demo["history_tokens"][-4:])
Hs = demo["Z_en"]
H_LM = rademacher(D, np.random.default_rng(4242))
for tok in history:
    H_LM = DD7_updateLM(H_LM, tok, L_fr, pi)

# NOTE(review): the MEM probe cell widens rows of comp.mem.H to int32 before
# use; if H holds values outside the int8 range this cast would wrap —
# confirm the dtype/range of comp.mem.H.
prototypes = comp.mem.H.astype(np.int8, copy=False)

# DD1 -> DD5 are chained: context, query, bind to MEM, top-K bucket search,
# then payload extraction from the best bucket's prototype.
Qs = DD1_ctx(Hs, G_DEC)
# ell is the history length, floored at 1 so it stays positive with an empty
# history.
Rt = DD2_query_bin(Qs, history, L_fr, pi, alpha=1.0, beta=1.0, ell=max(1, len(history)))
Rt_tilde = DD3_bindToMem(Rt, G_MEM)
c_star, C_K, scores_CK = DD4_search_topK(Rt_tilde, prototypes, K=32)
Z_hat = DD5_payload(prototypes[c_star])

# Candidate tokens come from the vocabularies of the top-K buckets, with the
# first 256 entries of the global vocabulary offered as fallback.
cand_vocab = _as_vocab_from_buckets(
    C_K=C_K,
    bucket2vocab=bucket2vocab,
    history_fr=history,
    global_fallback_vocab=all_vocab[:256] if all_vocab else None,
    min_size=1,
)

# The French lexicon is used for both the memory and the LM side of the vote.
token_star, scores_cand, _ = DD6_vote(
    Z_hat=Z_hat,
    H_LM=H_LM,
    L_mem=L_fr,
    L_lm=L_fr,
    cand_vocab=cand_vocab,
    lam=0.5,
    normalize="sqrtD",
    return_probs=False,
)

print("History tokens:", history)
print("Span tokens (ground truth):", demo["span_tokens"][:10])
print("Decoded token*:", token_star)
print("Top-5 candidats (score brut):")
# Indices of the five highest raw scores, best first.
order = np.argsort(scores_cand)[::-1][:5]
for rank, idx in enumerate(order, start=1):
    # Guard against a score index that has no matching candidate token.
    tok = cand_vocab[idx] if idx < len(cand_vocab) else "<unk>"
    print(f"  {rank}. {tok:20s} -> {scores_cand[idx]:8.1f}")


## 7. Tests automatisés


In [None]:
# Run the ENC/MEM/DEC test module in-process, in quiet mode.
# NOTE(review): the test path is relative to the current working directory,
# whereas the first cell treats the *parent* of the cwd as the project root —
# confirm that tests/ really lives next to this notebook.
import pytest
pytest.main(['tests/test_enc_mem_dec.py', '-q'])
