# explore_bis_v5 — pipeline ENC → MEM → DEC


In [19]:
from pathlib import Path
import sys

# The project root is taken to be the parent of the current working directory;
# its ``src`` folder is made importable for the ``hdc_project`` imports below.
ROOT = Path.cwd().parent
SRC = ROOT.joinpath("src")
# Prepend only when absent so that re-running the cell does not grow sys.path.
if not any(entry == str(SRC) for entry in sys.path):
    sys.path.insert(0, str(SRC))
print(f"Using src path: {SRC}")


Using src path: /Users/aymenmejri/Desktop/MyCode/experiments/hdc_v2/hdc_project/src


In [20]:
import logging
from collections import Counter, defaultdict
from typing import List, Sequence, Tuple

import numpy as np
from tqdm import tqdm

from hdc_project.encoder import m4, pipeline as enc_pipeline
from hdc_project.encoder.mem import pipeline as mem_pipeline
from hdc_project.decoder import (
    DD1_ctx,
    DD2_query,
    DD2_query_bin,
    DD3_bindToMem,
    DD4_search_topK,
    DD5_payload,
    DD6_vote,
    DD7_updateLM,
    DecodeOneStep,
    DX2_run,
    DX3_run,
    DX4_run,
    DX5_run,
    DX6_run,
    DX7_run,
)
from hdc_project.decoder.dec import (
    hd_assert_pm1,
    hd_bind,
    hd_sim,
    hd_sim_dot,
    build_perm_inverse,
    permute_pow,
    permute_pow_signed,
    rademacher,
    _as_vocab_from_buckets,
)

# Notebook-wide logger. The guard looks at this logger's own handler list,
# while basicConfig() installs its handler on the root logger.
log = logging.getLogger("explore_v5")
if len(log.handlers) == 0:
    logging.basicConfig(
        format="%(asctime)s [%(levelname)s] %(message)s",
        level=logging.INFO,
    )


## 1. Chargement du corpus et encodage ENC


In [21]:
# Corpus size knobs: N_OPUS is the N passed to the OPUS loader, MAX_SENTENCES
# caps how many of the loaded pairs are actually encoded below.
MAX_SENTENCES = 2_000
N_OPUS = 5_000

# Try the OPUS "opus_books" en-fr subset first; on ANY exception fall back to a
# four-sentence toy corpus so the rest of the notebook can still run.
try:
    ens_raw, frs_raw = enc_pipeline.opus_load_subset(
        name="opus_books",
        config="en-fr",
        split="train",
        N=N_OPUS,
        seed=2025,
    )
    log.info("OPUS subset loaded: %d pairs", len(ens_raw))
except Exception as exc:
    # Deliberate best-effort: the failure is logged, not re-raised.
    log.warning("OPUS download failed (%s); using a local toy corpus", exc)
    ens_raw = [
        "hyperdimensional computing is fun",
        "vector symbolic architectures are powerful",
        "encoding words into hyperspace",
        "memory augmented networks love clean data",
    ]
    frs_raw = [
        "le calcul hyperdimensionnel est amusant",
        "les architectures symboliques vectorielles sont puissantes",
        "encoder des mots dans l'hyperspace",
        "les réseaux augmentés de mémoire aiment les données propres",
    ]

# Keep only the first MAX_SENTENCES pairs of whichever corpus was obtained.
ens_sample = ens_raw[:MAX_SENTENCES]
frs_sample = frs_raw[:MAX_SENTENCES]
log.info("Using %d sentence pairs", len(ens_sample))

# D and n are forwarded to encode_corpus_ENC below; D is also reused by later
# cells as the hypervector dimension.
# NOTE(review): n is presumably the n-gram order of the encoder — confirm.
D = 8_192
n = 5
rng = np.random.default_rng(123)

# Three independently seeded lexicons. The two French ones are built with the
# same M4_LexEN_new constructor as the English one: Lex_fr_mem is used for the
# FR encoding / MEM payloads, Lex_fr_lm for the language-model side in the DEC
# demo further down.
Lex_en = m4.M4_LexEN_new(seed=1, D=D)
Lex_fr_mem = m4.M4_LexEN_new(seed=2, D=D)
Lex_fr_lm = m4.M4_LexEN_new(seed=202, D=D)
# Random permutation of the D coordinates, shared by the encoder and decoder.
pi = rng.permutation(D).astype(np.int64)
# pi_inv is not referenced again in the visible cells.
pi_inv = build_perm_inverse(pi)

# Encode both sides with the same permutation but different segment seeds.
encoded_en = enc_pipeline.encode_corpus_ENC(ens_sample, Lex_en, pi, D, n, seg_seed0=999)
encoded_fr = enc_pipeline.encode_corpus_ENC(frs_sample, Lex_fr_mem, pi, D, n, seg_seed0=1999)
log.info("Encoded %d EN / %d FR sentences", len(encoded_en), len(encoded_fr))

# Each encoded entry is a dict; "E_seq" and "H" are read here, "X_seq" later.
E_list_en = [segment["E_seq"] for segment in encoded_en]
H_list_en = [segment["H"] for segment in encoded_en]
if H_list_en:
    log.info("Encoder signature shape: %s", H_list_en[0].shape)

# Sanity statistics on the English encoding only (helpers live in enc_pipeline).
intra_sim, inter_sim = enc_pipeline.intra_inter_ngram_sims(E_list_en, D)
inter_seg = enc_pipeline.inter_segment_similarity(H_list_en)
log.info("ENC stats — intra: %.4f | inter(|.|): %.4f | inter segments: %.4f", intra_sim, inter_sim, inter_seg)


2025-10-07 02:11:57,133 [INFO] OPUS subset loaded: 5000 pairs
2025-10-07 02:11:57,134 [INFO] Using 2000 sentence pairs
2025-10-07 02:15:19,878 [INFO] Encoded 2000 EN / 2000 FR sentences
2025-10-07 02:15:34,102 [INFO] Encoder signature shape: (8192,)
2025-10-07 02:15:46,220 [INFO] ENC stats — intra: 0.0003 | inter(|.|): 0.0271 | inter segments: 0.0086


## 2. Construction des paires MEM


In [22]:
from typing import Callable


def content_signature_from_Xseq(X_seq: Sequence[np.ndarray], *, majority: str = "strict") -> np.ndarray:
    """Bundle a sequence of hypervectors into one ±1 signature by majority vote.

    Args:
        X_seq: Non-empty sequence of 1-D vectors of equal length. A list of
            arrays and a 2-D ``ndarray`` (one vector per row) are both accepted.
        majority: ``"strict"`` maps a zero sum to +1. ``"unbiased"`` replaces
            zero sums by ±1 drawn from a generator seeded with 0, so the result
            is reproducible from call to call.

    Returns:
        An ``int8`` array with entries in {-1, +1} and the length of the first
        vector of ``X_seq``.

    Raises:
        ValueError: If ``X_seq`` is empty or ``majority`` is not recognised.
    """
    # len() rather than truthiness: `not X_seq` raises "truth value of an array
    # is ambiguous" when the trace is handed over as a 2-D ndarray.
    if len(X_seq) == 0:
        raise ValueError("X_seq vide")
    # Validate the mode up front so a typo fails before any accumulation work.
    if majority not in ("strict", "unbiased"):
        raise ValueError("majority must be 'strict' or 'unbiased'")

    # int32 accumulator so that summing many int8 vectors cannot overflow.
    acc = np.zeros(X_seq[0].shape[0], dtype=np.int32)
    for x in X_seq:
        acc += x.astype(np.int32, copy=False)

    if majority == "unbiased":
        rng_local = np.random.default_rng(0)
        ties = acc == 0
        # Overwrite the tied coordinates with a random sign (0/1 -> -1/+1).
        acc[ties] = rng_local.integers(0, 2, size=int(ties.sum()), dtype=np.int32) * 2 - 1
    return np.where(acc >= 0, 1, -1).astype(np.int8, copy=False)


def span_signatures_from_trace(
    X_seq: Sequence[np.ndarray],
    *,
    win: int,
    stride: int,
    majority: str,
) -> List[Tuple[np.ndarray, int, int]]:
    """Cut a trace into sliding windows and compute one signature per window.

    Args:
        X_seq: Trace of per-position vectors.
        win: Window length in positions.
        stride: Step between window starts; values below 1 are treated as 1.
        majority: Forwarded to ``content_signature_from_Xseq``.

    Returns:
        A list of ``(signature, start, stop)`` triples. An empty trace yields
        ``[]``; a trace no longer than ``win`` yields a single span covering
        the whole trace. Otherwise only full windows are produced, so trailing
        positions that do not fill a window are not covered.
    """
    length = len(X_seq)
    if length == 0:
        return []
    if length <= win:
        whole = content_signature_from_Xseq(X_seq, majority=majority)
        return [(whole, 0, length)]
    step = max(1, stride)
    return [
        (content_signature_from_Xseq(X_seq[lo : lo + win], majority=majority), lo, lo + win)
        for lo in range(0, length - win + 1, step)
    ]


def lexical_signature_from_tokens(tokens: Sequence[str], L_fr_mem: Callable[[str], np.ndarray], D: int) -> np.ndarray:
    """Bundle the lexicon vectors of ``tokens`` into one ±1 signature.

    Args:
        tokens: Tokens to look up.
        L_fr_mem: Callable mapping a token to its hypervector.
        D: Expected vector dimension.

    Returns:
        An ``int8`` array of length ``D``: the sign of the summed token
        vectors, with a zero sum mapped to +1. With no tokens the all-ones
        vector is returned.
    """
    if not tokens:
        return np.ones(D, dtype=np.int8)

    tally = np.zeros(D, dtype=np.int32)
    for word in tokens:
        word_hv = L_fr_mem(word).astype(np.int8, copy=False)
        # Every looked-up vector is validated before it contributes.
        hd_assert_pm1(word_hv, D)
        tally += word_hv

    signature = np.full(D, -1, dtype=np.int8)
    signature[tally >= 0] = 1
    return signature


def build_mem_pairs_with_meta(
    encoded_en: Sequence[dict],
    encoded_fr: Sequence[dict],
    tokens_fr: Sequence[Sequence[str]],
    *,
    L_fr_mem: Callable[[str], np.ndarray],
    win: int = 8,
    stride: int = 4,
    majority: str = "strict",
    max_pairs: int | None = None,
) -> Tuple[List[Tuple[np.ndarray, np.ndarray]], List[dict]]:
    pairs: List[Tuple[np.ndarray, np.ndarray]] = []
    meta: List[dict] = []
    N = min(len(encoded_en), len(encoded_fr))
    for idx in tqdm(range(N), desc="MEM span extraction", leave=False):
        spans_en = span_signatures_from_trace(encoded_en[idx]["X_seq"], win=win, stride=stride, majority=majority)
        spans_fr = span_signatures_from_trace(encoded_fr[idx]["X_seq"], win=win, stride=stride, majority=majority)
        tok_fr = list(tokens_fr[idx]) if idx < len(tokens_fr) else []
        L = min(len(spans_en), len(spans_fr))
        for (ze, start_en, stop_en), (_, start_fr, stop_fr) in zip(spans_en[:L], spans_fr[:L]):
            span_tokens = tok_fr[start_fr:stop_fr] if tok_fr else []
            zf_lex = lexical_signature_from_tokens(span_tokens, L_fr_mem, D)
            pairs.append((ze, zf_lex))
            meta.append(
                {
                    "sentence_idx": idx,
                    "start": start_fr,
                    "stop": stop_fr,
                    "history_tokens": tok_fr[max(0, start_fr - stride):start_fr] if tok_fr else [],
                    "span_tokens": span_tokens,
                    "Z_en": ze,
                    "Z_fr_lex": zf_lex,
                }
            )
            if max_pairs is not None and len(pairs) >= max_pairs:
                return pairs, meta
    return pairs, meta


# Tokenise the French sentences with the encoder's sentence_to_tokens_EN helper,
# passing a fresh empty vocab set per sentence.
# NOTE(review): the span bounds used to slice these tokens come from X_seq
# positions — confirm that tokens and X_seq positions are aligned one-to-one.
tokens_fr = [enc_pipeline.sentence_to_tokens_EN(sent, vocab=set()) for sent in frs_sample]
# Build the MEM training pairs plus per-span metadata, capped at 50k pairs.
pairs_mem, span_meta = build_mem_pairs_with_meta(
    encoded_en,
    encoded_fr,
    tokens_fr,
    L_fr_mem=Lex_fr_mem.get,
    win=8,
    stride=4,
    majority="strict",
    max_pairs=50_000,
)
log.info("Prepared %d MEM pairs", len(pairs_mem))


2025-10-07 02:16:47,982 [INFO] Prepared 38892 MEM pairs                  


## 3. Entraînement MEM et diagnostic rapide


In [23]:
# MEM configuration: B buckets and parameter k, with separate seeds for the
# LSH and Gmem components (meanings as defined by mem_pipeline.MemConfig).
MEM_K = 16
MEM_BUCKETS = 128
cfg = mem_pipeline.MemConfig(D=D, B=MEM_BUCKETS, k=MEM_K, seed_lsh=10, seed_gmem=11)
comp = mem_pipeline.make_mem_pipeline(cfg)
# Single training pass over all prepared pairs.
mem_pipeline.train_one_pass_MEM(comp, pairs_mem)
log.info("MEM training completed (B=%d)", comp.mem.B)

# Quick probe on the FIRST probe_count training pairs (not held-out data):
# route each EN signature to its top-1 bucket and compare that bucket's row of
# comp.mem.H with the paired FR signature.
probe_count = min(200, len(pairs_mem))
sim_values = []
for Z_en_vec, Z_fr_vec in tqdm(pairs_mem[:probe_count], desc="MEM probe"):
    bucket_idx, _ = mem_pipeline.infer_map_top1(comp, Z_en_vec)
    prototype = comp.mem.H[bucket_idx].astype(np.int32, copy=False)
    # Dot product divided by D; int32 casts avoid int8 overflow in the dot.
    sim = float(np.dot(prototype, Z_fr_vec.astype(np.int32, copy=False)) / D)
    sim_values.append(sim)

# Both summaries are skipped when there were no pairs to probe.
if sim_values:
    log.info(
        "Probe similarities — mean: %.4f | median: %.4f",
        float(np.mean(sim_values)),
        float(np.median(sim_values)),
    )
    # comp.mem.n is logged as the per-bucket population.
    nb = comp.mem.n
    log.info(
        "Bucket population stats — mean: %.1f | p90: %d | p99: %d",
        float(nb.mean()),
        int(np.quantile(nb, 0.90)),
        int(np.quantile(nb, 0.99)),
    )


2025-10-07 02:16:51,086 [INFO] MEM training completed (B=128)
MEM probe: 100%|██████████| 200/200 [00:00<00:00, 22613.85it/s]
2025-10-07 02:16:51,097 [INFO] Probe similarities — mean: 0.2666 | median: 0.2729
2025-10-07 02:16:51,097 [INFO] Bucket population stats — mean: 303.8 | p90: 344 | p99: 375


## 4. Dictionnaire bucket → vocabulaire


In [24]:
# Token frequency per MEM bucket: each span is routed to its top-1 bucket via
# its EN signature and its FR span tokens are counted under that bucket.
bucket_counts: dict[int, Counter] = defaultdict(Counter)
for meta in tqdm(span_meta, desc="Bucket vocab build", leave=False):
    bucket_idx, _ = mem_pipeline.infer_map_top1(comp, meta["Z_en"])
    bucket_idx = int(bucket_idx)
    # Side effect: the bucket assignment is written back into span_meta.
    meta["bucket_idx"] = bucket_idx
    # A bucket only gets an entry once at least one token is counted for it.
    for tok in meta.get("span_tokens", []):
        bucket_counts[bucket_idx][tok] += 1

# Per bucket: (token, count) pairs sorted by descending count, ties broken
# alphabetically so the ordering is deterministic.
bucket2vocab_freq = {
    bucket: sorted(counter.items(), key=lambda kv: (-kv[1], kv[0]))
    for bucket, counter in bucket_counts.items()
}
# Same ordering with the counts dropped.
bucket2vocab = {bucket: [tok for tok, _cnt in tokens] for bucket, tokens in bucket2vocab_freq.items()}
# Sorted union of all bucket vocabularies; used as a fallback candidate pool.
all_vocab = sorted({tok for tokens in bucket2vocab.values() for tok in tokens})
log.info(
    "Bucket vocab built for %d buckets (global vocab size=%d)",
    len(bucket2vocab),
    len(all_vocab),
)


2025-10-07 02:16:57,409 [INFO] Bucket vocab built for 128 buckets (global vocab size=113742)


## 5bis. Diagnostics théorie ↔ implémentation


In [25]:
# Runs the DX2..DX7 diagnostics imported from hdc_project.decoder with reduced
# trial counts. Their internals are not visible in this notebook; only the
# shape of each return value, as consumed by the log calls, is relied on here.
log.info("Running DEC diagnostic suite (subsampled)...")

# DX2 returns a mapping; the median is taken over element [1] of each value.
norms = DX2_run(D=D, trials=50, ells=(2, 4, 8), ratios=(1.0,), seed=2025)
log.info("DX2 ok — example median norm: %.3f", np.median([v[1] for v in norms.values()]))

# DX3 returns a (relative error, p-value) pair.
rel_err, pval = DX3_run(D=D, C=256, T=64, seed=2025, rel_tol=0.02, pmin=0.05)
log.info("DX3 ok — mean relative error=%.4f | p=%.3f", rel_err, pval)

# DX4 result is indexed by K; only K=500 is reported.
recalls = DX4_run(D=D, B=5_000, trials=40, Ks=(100, 500), seed=0)
log.info("DX4 ok — recall@500=%.3f", recalls[500])

# DX5 result is indexed by m; only m=8 is reported.
accuracies = DX5_run(D=D, trials=40, ms=(4, 8, 16), seed=0)
log.info("DX5 ok — accuracy m=8: %.3f", accuracies[8])

# DX6 returns, per lambda, a dict with at least 'top1' and 'ppl'.
results_dx6 = DX6_run(D=D, trials=120, lam_grid=(0.0, 0.5, 1.0), rng_seed=7031)
log.info("DX6 ok — lambda grid summary: %s", {lam: (vals['top1'], vals['ppl']) for lam, vals in results_dx6.items()})

# DX7 returns per-ell results plus the selected ell; only its top1 is logged.
dx7_results, ell_star = DX7_run(ell_grid=(2, 4, 8), D=D, seed_pi=10_456, rng_seed=9_117)
log.info("DX7 ok — ell*=%d, top1=%.3f", ell_star, dx7_results[ell_star]["top1"])


2025-10-07 02:16:58,646 [INFO] Running DEC diagnostic suite (subsampled)...
2025-10-07 02:16:59,064 [INFO] DX2 ok — example median norm: 1.000
2025-10-07 02:16:59,216 [INFO] DX3 ok — mean relative error=0.0000 | p=1.000
2025-10-07 02:17:08,235 [INFO] DX4 ok — recall@500=1.000
2025-10-07 02:17:08,244 [INFO] DX5 ok — accuracy m=8: 1.000
2025-10-07 02:17:08,344 [INFO] DX6 ok — lambda grid summary: {0.0: (1.0, 1.1109939911767361), 0.5: (1.0, 1.0000001843669586), 1.0: (1.0, 1.000000000000448)}
2025-10-07 02:17:14,703 [INFO] DX7 ok — ell*=2, top1=0.926


## 6. Démonstration DEC (décodage pas à pas, jusqu'à 8 pas par span)


In [26]:
# The demo needs at least one span record from the MEM pair construction.
if not span_meta:
    raise RuntimeError("Aucune paire MEM disponible pour la démonstration DEC.")

# Decoder-side key drawn from a fixed seed; the MEM-side key is taken from the
# trained pipeline.
G_DEC = rademacher(D, np.random.default_rng(2025))
G_MEM = comp.Gmem
# Two separate French lexicons: one for scoring against MEM payloads, one for
# the language-model state and the query history.
L_mem = Lex_fr_mem.get
L_lm = Lex_fr_lm.get
# NOTE(review): assumes every value of comp.mem.H fits in int8 — confirm.
prototypes = comp.mem.H.astype(np.int8, copy=False)

# ELL: maximum history length kept during decoding (also passed as ell=ELL).
ELL = 4
# CAND_PER_BUCKET: number of most frequent tokens taken from each bucket.
CAND_PER_BUCKET = 32
# V_MAX: upper bound on the candidate list size per decoding step.
V_MAX = 512
# LAM: weight of the language-model term in the vote score.
LAM = 0.5

def update_LM_sep(H_LM: np.ndarray, token: str) -> np.ndarray:
    """Fold one token into the language-model state and re-binarise it.

    The token vector comes from the LM lexicon (``L_lm``), is validated, moved
    through one application of ``permute_pow`` with ``pi``, added to ``H_LM``
    and thresholded at zero.

    Args:
        H_LM: Current LM state vector.
        token: Token to add.

    Returns:
        A new ``int8`` state with entries in {-1, +1}; ``H_LM`` is not modified.

    NOTE(review): if both ``H_LM`` and the permuted token vector are ±1, every
    coordinate where they disagree sums to 0 and is mapped to +1, which biases
    the state toward +1 at each update — confirm this is intended.
    """
    token_hv = L_lm(token).astype(np.int8, copy=False)
    hd_assert_pm1(token_hv, D)
    shifted = permute_pow(token_hv, pi, 1).astype(np.int16, copy=False)
    # int16 so the sum of two int8 vectors cannot wrap around.
    total = shifted + H_LM.astype(np.int16)
    return np.where(total < 0, -1, 1).astype(np.int8, copy=False)


def vote_two_spaces(Z_hat: np.ndarray, H_LM: np.ndarray, cand_vocab: Sequence[str]) -> tuple[str, np.ndarray]:
    """Pick the candidate with the best combined MEM + LM dot-product score.

    Each candidate is looked up in both lexicons. Its score is the dot product
    of its MEM-lexicon vector with ``Z_hat`` plus ``LAM`` times the dot product
    of its LM-lexicon vector with ``H_LM``.

    Args:
        Z_hat: Payload estimate scored in the MEM lexicon space.
        H_LM: Language-model state scored in the LM lexicon space.
        cand_vocab: Candidate tokens; an empty list is replaced by ``["<unk>"]``.

    Returns:
        ``(best_token, scores)`` where ``scores`` is a ``float64`` array aligned
        with the candidate list; ties go to the earliest candidate.
    """
    pool = cand_vocab if cand_vocab else ["<unk>"]

    mem_rows = np.vstack([L_mem(word).astype(np.int8, copy=False) for word in pool])
    lm_rows = np.vstack([L_lm(word).astype(np.int8, copy=False) for word in pool])

    # Products are computed in int32 to avoid int8 overflow, then widened.
    payload32 = Z_hat.astype(np.int32, copy=False)
    state32 = H_LM.astype(np.int32, copy=False)
    mem_term = (mem_rows.astype(np.int32) @ payload32).astype(np.float64)
    lm_term = (lm_rows.astype(np.int32) @ state32).astype(np.float64)

    scores = mem_term + float(LAM) * lm_term
    winner = int(np.argmax(scores))
    return pool[winner], scores


def gather_candidates(C_K: np.ndarray, history: Sequence[str]) -> list[str]:
    """Assemble the ordered, de-duplicated candidate tokens for one vote.

    Order of preference:
      1. for each bucket id in ``C_K`` (in the given order), its
         ``CAND_PER_BUCKET`` most frequent tokens;
      2. the history tokens, most recent first;
      3. if nothing was collected, the first 64 entries of ``all_vocab``;
      4. if still empty, the single token ``"<unk>"``.

    Collection in steps 1 and 2 stops once ``V_MAX`` tokens are held, and the
    returned list is truncated to ``V_MAX``.
    """
    # An insertion-ordered dict doubles as the "seen" set and the ordered list.
    ordered: dict[str, None] = {}

    def quota_reached() -> bool:
        return len(ordered) >= V_MAX

    for bucket_id in C_K:
        ranked = bucket2vocab_freq.get(int(bucket_id), [])
        for word, _freq in ranked[:CAND_PER_BUCKET]:
            ordered.setdefault(word, None)
            if quota_reached():
                break
        if quota_reached():
            break

    for word in reversed(history):
        ordered.setdefault(word, None)
        if quota_reached():
            break

    pool = list(ordered)
    if not pool and all_vocab:
        pool = list(all_vocab[: min(64, len(all_vocab))])
    if not pool:
        pool = ["<unk>"]
    return pool[:V_MAX]


def decode_span(meta: dict, *, max_steps: int = 6) -> dict:
    """Greedily decode tokens for one span record and pair them with the reference.

    The LM state starts from a random vector seeded by ``4242 + sentence_idx``
    and is primed with the last ``ELL`` history tokens of the span. Each step
    runs DD1 -> DD2 -> DD3 -> DD4 -> DD5 from the span's EN signature, gathers
    candidates from the returned buckets and the running history, votes, and
    feeds the chosen token back into both the history and the LM state.

    Args:
        meta: Span record; ``"sentence_idx"``, ``"Z_en"``, ``"start"`` and
            ``"stop"`` are required, ``"history_tokens"`` and ``"span_tokens"``
            default to empty.
        max_steps: Maximum number of tokens to decode; lowered to the number of
            reference tokens when the span has any.

    Returns:
        A dict with keys ``sentence_idx``, ``span_bounds``, ``history_seed``,
        ``decoded`` and ``reference`` (reference cut to the decoded length).
    """
    seed_tokens = meta.get("history_tokens", [])
    context = list(seed_tokens[-ELL:])
    lm_state = rademacher(D, np.random.default_rng(4242 + meta["sentence_idx"]))
    for past_tok in context:
        lm_state = update_LM_sep(lm_state, past_tok)

    reference = list(meta.get("span_tokens", []))
    source_sig = meta["Z_en"]
    n_steps = max_steps
    if reference:
        n_steps = min(max_steps, len(reference))

    emitted: list[str] = []
    for _step in range(n_steps):
        query_ctx = DD1_ctx(source_sig, G_DEC)
        query = DD2_query_bin(query_ctx, context, L_lm, pi, alpha=1.0, beta=1.0, ell=ELL)
        mem_query = DD3_bindToMem(query, G_MEM)
        best_bucket, top_buckets, _ = DD4_search_topK(mem_query, prototypes, K=64)
        payload = DD5_payload(prototypes[best_bucket])
        pool = gather_candidates(top_buckets, context)
        chosen, _scores = vote_two_spaces(payload, lm_state, pool)
        emitted.append(chosen)
        # The running history is a sliding window of at most ELL tokens.
        context.append(chosen)
        if len(context) > ELL:
            context = context[-ELL:]
        lm_state = update_LM_sep(lm_state, chosen)

    return {
        "sentence_idx": meta["sentence_idx"],
        "span_bounds": (meta["start"], meta["stop"]),
        "history_seed": seed_tokens,
        "decoded": emitted,
        "reference": reference[: len(emitted)],
    }


# Prefer spans that actually have reference tokens; if none do, fall back to
# the first five span records.
sample_metas = [m for m in span_meta if m.get("span_tokens")] or span_meta[:5]
# Only the first five selected spans are decoded, up to 8 tokens each.
results = [decode_span(meta, max_steps=8) for meta in tqdm(sample_metas[:5])]

# Position-wise exact-match accuracy accumulated over all decoded tokens.
correct = 0
count = 0
for res in results:
    print("Sentence #", res["sentence_idx"], "span", res["span_bounds"])
    print("  History seed:", res["history_seed"])
    print("  Decoded   :", " ".join(res["decoded"]))
    print("  Reference :", " ".join(res["reference"]))
    # zip() stops at the shorter list, so positions without a reference token
    # are not counted as matches but still count in the denominator below.
    matches = sum(p == t for p, t in zip(res["decoded"], res["reference"]))
    correct += matches
    count += len(res["decoded"])
    print(f"  Match count: {matches}/{len(res['decoded'])}")
    print()

if count:
    print(f"Top-1 token accuracy on decoded spans: {correct / count:.3f} (over {count} tokens)")
else:
    print("Decoded spans are empty; check MEM span generation.")


100%|██████████| 5/5 [00:00<00:00, 35.30it/s]

Sentence # 0 span (0, 8)
  History seed: []
  Decoded   : m. si m. m. m. m. m. m.
  Reference : l'idée brusque du mariage qu'elle poursuivait d'un sourire
  Match count: 0/8

Sentence # 0 span (4, 12)
  History seed: ["l'idée", 'brusque', 'du', 'mariage']
  Decoded   : m. m. m. m. m. m. m. m.
  Reference : qu'elle poursuivait d'un sourire si tranquille entre cécile
  Match count: 0/8

Sentence # 0 span (8, 16)
  History seed: ["qu'elle", 'poursuivait', "d'un", 'sourire']
  Decoded   : m. m. m. de_ses de_ses de_ses de_ses de_ses
  Reference : si tranquille entre cécile et paul, acheva __sent_marker_0
  Match count: 0/8

Sentence # 0 span (12, 20)
  History seed: ['si', 'tranquille', 'entre', 'cécile']
  Decoded   : de_ses m. m. m. de_ses m. m. m.
  Reference : et paul, acheva __sent_marker_0 de l'exaspérer. l'idée_brusque brusque_du
  Match count: 0/8

Sentence # 0 span (16, 24)
  History seed: ['et', 'paul,', 'acheva', '__sent_marker_0']
  Decoded   : m. de_ses de_ses de_ses de_ses de_




## 7. Tests automatisés


In [None]:
# Runs the project's ENC/MEM/DEC test module in-process, in quiet mode.
# NOTE(review): the test path is relative to the kernel's working directory,
# whereas the first cell resolves the project from Path.cwd().parent — confirm
# that tests/ really lives next to this notebook and not under ROOT.
import pytest
pytest.main(['tests/test_enc_mem_dec.py', '-q'])
