In [10]:
import json
from collections import defaultdict
import heapq
import kenlm
from gensim.models import Word2Vec

def _sent_logp(lm, tokens):
    # Score the sentence formed by joining tokens with spaces
    # (log10; bos/eos disabled to match autocomplete style)
    return lm.score(" ".join(tokens), bos=False, eos=False)

def _cond_logp(lm, history, nxt):
    """Return log P(nxt | history) under `lm`.

    Applies the chain rule to fragment scores: the score of `history`
    followed by `nxt`, minus the score of `history` alone.
    `history` is expected to be a list (it is concatenated with `[nxt]`).
    """
    with_next = _sent_logp(lm, history + [nxt])
    without_next = _sent_logp(lm, history)
    return with_next - without_next

def _topk_followers(lm, history, candidates, K):
    """Return the K candidates most likely to follow `history` under `lm`.

    Each candidate `tok` is scored as log P(tok | history), computed as
    _sent_logp(history + [tok]) - _sent_logp(history). A bounded min-heap
    keeps the best K seen so far. A later candidate only displaces the current
    minimum when its score is strictly greater.

    Args:
      lm: language model passed through to _sent_logp.
      history: list of tokens already in the sequence.
      candidates: iterable of candidate next tokens.
      K: number of followers to keep; K <= 0 yields an empty list.

    Returns:
      List of (token, logp) pairs, at most K long, sorted by logp descending.
    """
    if K <= 0:
        # Without this guard, the first candidate would reach heap[0] on an
        # empty heap (len(heap) < K is False) and raise IndexError.
        return []
    heap = []  # min-heap of (logp, token); heap[0] is the weakest kept entry
    base = _sent_logp(lm, history)  # history score is shared by every candidate
    for tok in candidates:
        lp = _sent_logp(lm, history + [tok]) - base
        if len(heap) < K:
            heapq.heappush(heap, (lp, tok))
        elif lp > heap[0][0]:
            # Strictly better than the current minimum: evict it.
            heapq.heapreplace(heap, (lp, tok))
    return sorted(((tok, lp) for lp, tok in heap), key=lambda pair: -pair[1])

def build_sequences(lm, w2v, K=5, V_LIMIT=50000, CAND_M=None, out_jsonl="seqs.jsonl"):
    """
    Build a flat JSONL of scored sequences (2/3/4 tokens) using KenLM for ranking.
    Each line: {"seq": ["A","B"], "score": ...} or length-3/4 analogs.

    For every start token A in the capped vocabulary V:
      - length 2: the top-K followers B of [A];
      - length 3: for each kept B, the top-K followers C of [A, B]. These up to
        K*K triples are pruned back to the best K by total score;
      - length 4: for each kept triple, the top-K followers D of [A, B, C].
        Up to K*K quadruples are written; there is no final pruning step.

    "score" is the sum of the per-step conditional log scores returned by
    _topk_followers. It does not include the score of A itself, so scores
    are only comparable between sequences that share the same first token.

    NOTE(review): vocabulary entries can contain spaces (e.g. "tai nạn").
    Tokens are space-joined before scoring, so different token splits of the
    same text (["x", "y z"] vs ["x", "y", "z"]) are written as separate lines
    with identical scores.

    Args:
      lm: kenlm.Model (already loaded)
      w2v: gensim Word2Vec (already loaded)
      K: beam width and per-step top-k; K <= 0 produces an empty file
      V_LIMIT: cap vocab size (Word2Vec frequency order). Applies both to the
        start tokens and to every follower candidate.
      CAND_M: if set, prefilter followers with
        w2v.wv.most_similar(token, topn=CAND_M), searching only the capped
        vocabulary; fallback to full V if OOV
      out_jsonl: output file path (overwritten)
    """
    V = w2v.wv.index_to_key[:V_LIMIT]
    # B and C tokens recur across many start tokens A. Memoise the costly
    # nearest-neighbour lookup so each token is queried at most once.
    cand_cache = {}

    def cand_for(tok):
        if CAND_M is None or tok not in w2v.wv:
            return V
        if tok not in cand_cache:
            # restrict_vocab limits the search to the first len(V) vectors, so
            # prefiltered candidates honour V_LIMIT just like the unfiltered path.
            neighbours = w2v.wv.most_similar(
                tok, topn=min(CAND_M, len(V)), restrict_vocab=len(V)
            )
            cand_cache[tok] = [w for (w, _) in neighbours]
        return cand_cache[tok]

    def emit(out, seq, score):
        # One JSON object per line; ensure_ascii=False keeps Vietnamese text readable.
        out.write(json.dumps({"seq": seq, "score": score}, ensure_ascii=False) + "\n")

    with open(out_jsonl, "w", encoding="utf-8") as out:
        if K <= 0:
            # Nothing can be kept; leave an empty output file.
            return
        for i, A in enumerate(V):
            # ---- length 2: [A, B] ----
            topB = _topk_followers(lm, [A], cand_for(A), K)  # [(B, lpB)]
            for B, sAB in topB:
                emit(out, [A, B], sAB)

            # ---- length 3: [A, B, C] (beam over B) ----
            beam3 = []
            for B, sAB in topB:
                topC = _topk_followers(lm, [A, B], cand_for(B), K)  # [(C, lpC)]
                for C, sC in topC:
                    beam3.append((A, B, C, sAB + sC))
            beam3.sort(key=lambda x: -x[3])
            beam3 = beam3[:K]
            for A_, B_, C_, sABC in beam3:
                emit(out, [A_, B_, C_], sABC)

            # ---- length 4: [A, B, C, D] (beam over triples) ----
            for A_, B_, C_, sABC in beam3:
                topD = _topk_followers(lm, [A_, B_, C_], cand_for(C_), K)  # [(D, lpD)]
                for D, sD in topD:
                    emit(out, [A_, B_, C_, D], sABC + sD)

            if (i + 1) % 500 == 0:
                print(f"processed {i+1}/{len(V)} tokens")



In [2]:
# Pretrained models. The KenLM handle is named `kenml` because the
# build_sequences call further down refers to it by that name.
KENLM_BINARY = "vi_model_6gramVinToken.binary"
W2V_MODEL = "word2vec_vi_bao_st.model"
kenml = kenlm.Model(KENLM_BINARY)
w2v = Word2Vec.load(W2V_MODEL)

In [7]:
# Vocabulary size of the Word2Vec model (compare with V_LIMIT used below).
len(w2v.wv)

10755

In [11]:
SUGGESTIONS_JSONL = "suggestion.jsonl"
build_sequences(
    kenml,
    w2v,
    K=5,
    V_LIMIT=50000,
    CAND_M=None,  # no prefilter: every follower is scored against the whole capped vocab
    out_jsonl=SUGGESTIONS_JSONL,
)

processed 500/10755 tokens
processed 1000/10755 tokens
processed 1500/10755 tokens
processed 2000/10755 tokens
processed 2500/10755 tokens
processed 3000/10755 tokens
processed 3500/10755 tokens
processed 4000/10755 tokens
processed 4500/10755 tokens
processed 5000/10755 tokens
processed 5500/10755 tokens
processed 6000/10755 tokens
processed 6500/10755 tokens
processed 7000/10755 tokens
processed 7500/10755 tokens
processed 8000/10755 tokens
processed 8500/10755 tokens
processed 9000/10755 tokens
processed 9500/10755 tokens
processed 10000/10755 tokens
processed 10500/10755 tokens


In [12]:
def load_sequences(jsonl_path):
    """Read a JSONL file of scored sequences into a list of dicts.

    Whitespace-only lines are skipped; every other line is parsed with
    json.loads, and records are returned in file order.
    """
    with open(jsonl_path, "r", encoding="utf-8") as handle:
        return [json.loads(raw) for raw in handle if raw.strip()]

def find_by_prefix(sequences, prefix, topn=10, dedupe=False):
    """Return the highest-scoring records whose leading tokens equal `prefix`.

    Args:
      sequences: list of {"seq": [...], "score": number} records, e.g. as
        returned by load_sequences.
      prefix: the leading tokens to match. Any sequence of tokens is accepted
        (list or tuple); a bare string is treated as a single token.
      topn: maximum number of records to return (None returns all matches).
      dedupe: if True, drop records whose space-joined text duplicates a
        higher-ranked record (e.g. ["a", "b c"] vs ["a", "b", "c"]).

    Returns:
      Matching records sorted by score, highest first; ties keep input order.
    """
    if isinstance(prefix, str):
        # A bare string compared with a list slice would never match.
        prefix = [prefix]
    else:
        # Stored "seq" values are lists, and a list never equals a tuple.
        prefix = list(prefix)
    plen = len(prefix)
    matches = [rec for rec in sequences if rec["seq"][:plen] == prefix]
    matches.sort(key=lambda rec: -rec["score"])  # stable: ties keep input order
    if dedupe:
        seen = set()
        unique = []
        for rec in matches:
            surface = " ".join(rec["seq"])
            if surface not in seen:
                seen.add(surface)
                unique.append(rec)
        matches = unique
    return matches[:topn]


In [None]:
# Read back every scored sequence written to suggestion.jsonl.
seqs = load_sequences(jsonl_path="suggestion.jsonl")

In [18]:
res1 = find_by_prefix(seqs, prefix=["tai nạn"], topn=10)  # best 10 continuations of this token
res1


[{'seq': ['tai nạn', 'giao'], 'score': -0.3564491271972656},
 {'seq': ['tai nạn', 'giao thông'], 'score': -0.36070823669433594},
 {'seq': ['tai nạn', 'giao', 'thông'], 'score': -0.36070823669433594},
 {'seq': ['tai nạn', 'nhân'], 'score': -1.061161994934082},
 {'seq': ['tai nạn', 'lao'], 'score': -1.0780587196350098},
 {'seq': ['tai nạn', 'lao động'], 'score': -1.0887422561645508},
 {'seq': ['tai nạn', 'lao', 'động'], 'score': -1.0887422561645508},
 {'seq': ['tai nạn', 'nhân', 'chất'], 'score': -1.464848518371582},
 {'seq': ['tai nạn', 'nhân', 'chất độc'], 'score': -1.4769086837768555},
 {'seq': ['tai nạn', 'nhân', 'chất', 'độc'], 'score': -1.4769086837768555}]