In [1]:
from collections import Counter, defaultdict
import re
import sys
import argparse
import io

In [2]:
# -----------------------
# Utilities
# -----------------------
def read_corpus(path):
    """Count whitespace-separated tokens in a UTF-8 text file.

    Blank or whitespace-only lines contribute nothing.

    Args:
        path: path to the corpus file (one or more space-separated tokens per line).

    Returns:
        collections.Counter mapping token -> number of occurrences.
    """
    counts = Counter()
    with io.open(path, "r", encoding="utf8") as handle:
        for raw_line in handle:
            # str.split() with no argument already discards surrounding
            # whitespace and yields [] for blank lines.
            counts.update(raw_line.split())
    return counts

In [3]:
# -----------------------
# BPE Implementation
# -----------------------
def get_word_as_symbols(word):
    """Split a word into its characters followed by the end-of-word marker "</w>"."""
    return [*word, "</w>"]

def build_token_sequences(word_counts):
    """Build the initial symbol segmentation of every word.

    Args:
        word_counts: mapping word -> count.

    Returns:
        dict mapping a tuple of symbols (characters plus "</w>") to the summed
        count of the words that produce that tuple.
    """
    sequences = {}
    for word, freq in word_counts.items():
        symbols = tuple(get_word_as_symbols(word))
        previous = sequences.get(symbols, 0)
        sequences[symbols] = previous + freq
    return sequences

def get_pair_frequencies(seqs):
    """Count adjacent symbol pairs across all segmented words.

    Args:
        seqs: mapping tuple-of-symbols -> count.

    Returns:
        Counter mapping (left, right) -> sum of counts over every position
        where `left` is immediately followed by `right`. Keys are inserted in
        order of first occurrence (sequence order, then position), which is
        what decides ties in Counter.most_common.
    """
    counts = Counter()
    for syms, weight in seqs.items():
        # zip over the sequence and its tail yields each adjacent pair once;
        # sequences shorter than two symbols yield nothing.
        for left, right in zip(syms, syms[1:]):
            counts[left, right] += weight
    return counts

def merge_pair_in_seq(seq, pair):
    """Replace every occurrence of `pair` in `seq` with its concatenation.

    Occurrences are matched left to right without overlap, so merging
    ("a", "a") in ("a", "a", "a") gives ("aa", "a").

    Args:
        seq: sequence of symbol strings.
        pair: (left, right) symbols to merge.

    Returns:
        A new tuple of symbols.
    """
    left, right = pair
    merged = left + right
    result = []
    idx = 0
    last = len(seq) - 1
    while idx <= last:
        if idx < last and seq[idx] == left and seq[idx + 1] == right:
            result.append(merged)
            idx += 2
            continue
        result.append(seq[idx])
        idx += 1
    return tuple(result)

def train_bpe(word_counts, max_merges=None, target_vocab_size=None, report_every=1000):
    """Learn byte-pair-encoding merges from a word-frequency table.

    Each word starts as its characters plus the end-of-word marker "</w>".
    Every iteration merges the most frequent adjacent symbol pair across all
    words (weighted by word count); ties go to the pair encountered first
    (Counter.most_common ordering).

    Args:
        word_counts: mapping word -> count (e.g. the Counter from read_corpus).
        max_merges: stop after this many merges; None means no limit.
        target_vocab_size: stop once the symbol set reaches this size; None
            means no limit. The set can grow by less than one per merge when a
            merged string already exists (e.g. "ab"+"c" and "a"+"bc").
        report_every: print progress every this many iterations; 0/None
            disables progress output.

    Returns:
        (merges, vocab, seqs): list of merged (left, right) pairs in learn
        order, set of all symbols, and the final segmentation dict
        (tuple-of-symbols -> count).
    """
    seqs = build_token_sequences(word_counts)
    merges = []
    vocab = {sym for symbols in seqs for sym in symbols}

    def _limit_reached():
        # Evaluated before every merge, so max_merges=0 or an already-met
        # target_vocab_size performs no merge at all.
        return ((max_merges is not None and len(merges) >= max_merges)
                or (target_vocab_size is not None and len(vocab) >= target_vocab_size))

    iteration = 0
    while not _limit_reached():
        iteration += 1
        pairs = get_pair_frequencies(seqs)
        if not pairs:
            break  # every word is a single symbol; nothing left to merge
        most_common_pair, freq = pairs.most_common(1)[0]
        merges.append(most_common_pair)
        left, right = most_common_pair
        new_seqs = {}
        for symbols, count in seqs.items():
            # A word without the left symbol cannot change; skip rebuilding it.
            if left in symbols:
                symbols = merge_pair_in_seq(symbols, most_common_pair)
            new_seqs[symbols] = new_seqs.get(symbols, 0) + count
        seqs = new_seqs
        vocab.add(left + right)
        if report_every and iteration % report_every == 0 and not _limit_reached():
            print(f"[BPE] Iter {iteration}: merges={len(merges)}, vocab_size={len(vocab)}, top_pair={most_common_pair}, freq={freq}")
    return merges, vocab, seqs

def save_merges(merges, path):
    """Write merge rules to a UTF-8 text file, one rule per line.

    The symbols of each rule are separated by a single space. A BPE pair
    ("a", "b") is written as "a b", exactly as before. Longer merge tuples,
    such as those returned by train_wordpiece_like, are now accepted too.
    Symbols are expected to contain no whitespace (true for tokens produced
    by str.split in read_corpus); otherwise the file would be ambiguous.

    Args:
        merges: iterable of tuples of symbol strings.
        path: output file path (overwritten).
    """
    with io.open(path, "w", encoding="utf8") as f:
        for merge in merges:
            f.write(" ".join(merge) + "\n")

def save_vocab(vocab, path):
    """Write the vocabulary to a UTF-8 file, one token per line in sorted order.

    Args:
        vocab: iterable of token strings (typically a set).
        path: output file path (overwritten).
    """
    with io.open(path, "w", encoding="utf8") as out:
        out.writelines(token + "\n" for token in sorted(vocab))

def bpe_encode_word(word, merges):
    """Segment one word by replaying learned BPE merges in order.

    The word is split into characters plus "</w>". Then each merge is applied
    (left to right, non-overlapping) in the order it was learned. A standalone
    trailing "</w>" is dropped. An end marker already merged into the last
    token (e.g. "ing</w>") is kept.

    Args:
        word: the word to encode.
        merges: ordered list of (left, right) pairs, as returned by train_bpe.

    Returns:
        list of subword token strings.
    """
    seq = tuple(word) + ("</w>",)
    for pair in merges:
        if len(seq) == 1:
            break  # a single symbol has no adjacent pairs; no later merge can apply
        if pair[0] not in seq:
            continue  # this merge cannot match; avoid rebuilding the tuple
        seq = merge_pair_in_seq(seq, pair)
    tokens = list(seq)
    if tokens and tokens[-1] == "</w>":
        tokens.pop()
    return tokens

In [4]:
# -----------------------
# WordPiece-style Implementation (frequency-driven)
# -----------------------
def train_wordpiece_like(word_counts, target_vocab_size=None, max_merges=None, max_subword_len=100, report_every=1000):
    """Frequency-driven "WordPiece-style" merge learner.

    Each iteration merges the most frequent run of symbols and adds the
    concatenation to the vocabulary. Candidates are runs of 2..max_subword_len
    symbols.

    NOTE(review): the winning candidate is always a 2-symbol pair. Every
    occurrence of a longer run contains its leading pair at the same
    position, so the run's count never exceeds the pair's. The pair is also
    encountered first, and Counter.most_common breaks ties by first
    encounter. The original all-substrings scan therefore selected exactly
    the same merges as train_bpe. This version counts adjacent pairs
    directly: identical results at a fraction of the cost. Actual WordPiece
    scores pairs by freq(ab) / (freq(a) * freq(b)); switch to that scoring if
    a genuinely different tokenizer is wanted.

    Args:
        word_counts: mapping word -> count.
        target_vocab_size: stop once the symbol set reaches this size; None = no limit.
        max_merges: stop after this many merges; None = no limit.
        max_subword_len: maximum candidate length in symbols. Since winners are
            always pairs, this only matters when it is below 2 (then nothing
            is merged).
        report_every: print progress every this many iterations; 0/None disables it.

    Returns:
        (merges, vocab, seqs): list of merged symbol tuples in learn order, set
        of all symbols, and the final segmentation dict (tuple -> count).
    """
    seqs = build_token_sequences(word_counts)
    vocab = {sym for symbols in seqs for sym in symbols}
    merges = []

    def _limit_reached():
        # Checked before every merge, so max_merges=0 or an already-met
        # target_vocab_size performs no merge at all.
        return ((max_merges is not None and len(merges) >= max_merges)
                or (target_vocab_size is not None and len(vocab) >= target_vocab_size))

    if max_subword_len < 2:
        return merges, vocab, seqs  # no candidate of >= 2 symbols is allowed

    iteration = 0
    while not _limit_reached():
        iteration += 1
        candidate_freq = get_pair_frequencies(seqs)
        if not candidate_freq:
            break  # every word is a single symbol
        candidate, freq = candidate_freq.most_common(1)[0]
        new_token = "".join(candidate)
        merges.append(candidate)
        new_seqs = {}
        for symbols, count in seqs.items():
            # Only words containing the candidate's first symbol can change.
            if candidate[0] in symbols:
                symbols = merge_pair_in_seq(symbols, candidate)
            new_seqs[symbols] = new_seqs.get(symbols, 0) + count
        seqs = new_seqs
        vocab.add(new_token)
        if report_every and iteration % report_every == 0 and not _limit_reached():
            print(f"[WP] Iter {iteration}: merges={len(merges)}, vocab_size={len(vocab)}, chosen_len={len(candidate)}, freq={freq}")
    return merges, vocab, seqs

def wordpiece_encode_word(word, merges):
    """Greedy longest-match segmentation of one word using learned merge tokens.

    The word is split into characters plus the end-of-word marker "</w>",
    the same symbols train_wordpiece_like trains on. Word-final vocabulary
    entries (learned with the marker attached, e.g. "ng</w>") can therefore
    match. At each position the longest run of symbols whose concatenation is
    a learned token is emitted; otherwise a single symbol is emitted. A
    standalone trailing "</w>" is dropped, mirroring bpe_encode_word.

    Args:
        word: the word to encode.
        merges: iterable of symbol tuples, as returned by train_wordpiece_like.

    Returns:
        list of subword token strings.
    """
    learned = {"".join(m) for m in merges}
    symbols = list(word) + ["</w>"]
    n = len(symbols)
    out = []
    i = 0
    while i < n:
        # A single symbol is always acceptable, so the loop always advances.
        match_len = 1
        for length in range(n - i, 1, -1):
            if "".join(symbols[i:i + length]) in learned:
                match_len = length
                break
        out.append("".join(symbols[i:i + match_len]))
        i += match_len
    if out and out[-1] == "</w>":
        out.pop()
    return out

In [None]:
# -----------------------
# Command-line / Run
# -----------------------
def main(argv=None):
    """Train BPE and WordPiece-style tokenizers on a tokenized corpus and save the results.

    Four trainings are run: BPE and WordPiece-style, each once with a fixed
    merge count and once until a target vocab size. Each writes a merges file
    and a vocab file into the current directory.

    Args:
        argv: command-line arguments without the program name. None lets
            argparse read sys.argv[1:]. Pass an explicit list from a notebook,
            where sys.argv holds the kernel's own arguments.
    """
    parser = argparse.ArgumentParser(description="Train BPE and WordPiece-style tokenizers.")
    parser.add_argument("--input", type=str, default="bengali_tokenized.txt", help="Path to tokenized text (space-separated tokens).")
    parser.add_argument("--bpe_merges", type=int, default=32000, help="Number of BPE merges to run (if set).")
    parser.add_argument("--wp_merges", type=int, default=32000, help="Number of WordPiece-style merges to run (if set).")
    parser.add_argument("--vocab_size", type=int, default=32000, help="Target vocab size for vocab-driven training (if used).")
    parser.add_argument("--debug", action="store_true")  # NOTE(review): currently never read
    args = parser.parse_args(argv)

    print("Reading corpus...")
    word_counts = read_corpus(args.input)
    if not word_counts:
        # next(iter(...)) below would raise StopIteration on an empty corpus.
        print(f"No tokens found in {args.input}; nothing to train.")
        return
    print(f"Unique word types: {len(word_counts)}, token (types) example: {next(iter(word_counts))}")

    def _write_merges(merges, path):
        # One merge per line, symbols space-separated ("a b" for BPE pairs).
        with io.open(path, "w", encoding="utf8") as f:
            for m in merges:
                f.write(" ".join(m) + "\n")

    def _train_and_save(start_msg, train_fn, limits, merges_path, vocab_path, done_msg):
        # Run one training configuration and persist its merges and vocab.
        print(*start_msg)
        merges, vocab, _ = train_fn(word_counts, report_every=2000, **limits)
        _write_merges(merges, merges_path)
        save_vocab(vocab, vocab_path)
        print(done_msg)

    _train_and_save(
        ("Training BPE with fixed merge steps =", args.bpe_merges),
        train_bpe,
        {"max_merges": args.bpe_merges, "target_vocab_size": None},
        f"merges_bpe_{args.bpe_merges}merges.txt",
        f"vocab_bpe_{args.bpe_merges}merges.txt",
        "Saved BPE merges and vocab for fixed merges.",
    )
    _train_and_save(
        ("Training BPE until vocab size reaches", args.vocab_size),
        train_bpe,
        {"max_merges": None, "target_vocab_size": args.vocab_size},
        f"merges_bpe_vocab_{args.vocab_size}.txt",
        f"vocab_bpe_vocab_{args.vocab_size}.txt",
        "Saved BPE merges and vocab for vocab-size target.",
    )
    _train_and_save(
        ("Training WordPiece-style with fixed merges =", args.wp_merges),
        train_wordpiece_like,
        {"target_vocab_size": None, "max_merges": args.wp_merges},
        f"merges_wp_{args.wp_merges}merges.txt",
        f"vocab_wp_{args.wp_merges}merges.txt",
        "Saved WordPiece-style merges and vocab for fixed merges.",
    )
    _train_and_save(
        ("Training WordPiece-style until vocab size reaches", args.vocab_size),
        train_wordpiece_like,
        {"target_vocab_size": args.vocab_size, "max_merges": None},
        f"merges_wp_vocab_{args.vocab_size}.txt",
        f"vocab_wp_vocab_{args.vocab_size}.txt",
        "Saved WordPiece-style merges and vocab for vocab-size target.",
    )

    print("Done. Files written to current directory.")

if __name__ == "__main__":
    # In Jupyter, sys.argv holds the kernel launcher's arguments, so the CLI
    # arguments are passed explicitly instead of overwriting sys.argv.
    # NOTE(review): the --input default is "bengali_tokenized.txt" but this run
    # reads "tokenized_bengali.txt" — confirm which file name is canonical.
    main(["--input", "tokenized_bengali.txt"])



Reading corpus...
Unique word types: 38949, token (types) example: নয়াদিল্লি
Training BPE with fixed merge steps = 32000
