In [4]:
from datasets import load_dataset
import random

# Load datasets
mc4 = load_dataset("zicsx/mC4-Hindi-Cleaned-2.0", split="train")
stories = load_dataset("OmAlve/TinyStories-Hindi", split="train")

# Sample 50k texts in total: 40k mC4 + 10k stories (an 80/20 split).
SEED = 42
mc4_sample_size = 40_000
stories_sample_size = 10_000

# Random sampling without replacement. random.sample raises ValueError when
# asked for more items than the population holds, so cap each request at the
# dataset size (no effect when the datasets are large enough).
random.seed(SEED)  # for reproducibility
mc4_indices = random.sample(range(len(mc4)), min(mc4_sample_size, len(mc4)))
stories_indices = random.sample(
    range(len(stories)), min(stories_sample_size, len(stories))
)

mc4_sampled = mc4.select(mc4_indices)
stories_sampled = stories.select(stories_indices)

# Extract text (the stories dataset stores its text in the 'translated' column)
mc4_texts = [row['text'] for row in mc4_sampled]
stories_texts = [row['translated'] for row in stories_sampled]

# Combine and shuffle so the two sources are interleaved
all_texts = mc4_texts + stories_texts
random.shuffle(all_texts)

# Join all texts with newlines into one training string
training_text = '\n'.join(all_texts)

# NOTE: the file name still says "30k" from an earlier configuration; it now
# holds mc4_sample_size + stories_sample_size texts. Name kept unchanged in
# case other code reads this file.
with open('hindi_corpus_30k.txt', 'w', encoding='utf-8') as f:
    f.write(training_text)

print(f"Corpus size: {len(all_texts)} texts")
print(f"Total characters: {len(training_text):,}")

Corpus size: 50000 texts
Total characters: 77,948,347


In [5]:
# Show the first 500 characters of the shuffled corpus
preview = training_text[:500]
print(preview)

भारतीय वन डे टीम से लगभग एक साल से बहार चल रहे शानदार बल्लेबाज सुरेश रैना की वापसी पर फिर ग्रहण लग गया है। वायरल बुखार के चलते न्यूजीलैंड के खिलाफ चल रही सीरीज के पहले वनडे से बाहर होने के बाद रैना मंगलवार को टीम इंडिया से न सिर्फ जुड़ गए बल्कि उन्होंने हाथ में बल्ला भी उठा लिया।
धोनी की दिक्कत यह होती कि धर्मशाल में अच्छा प्रदर्शन करने वाले केदार जाधव को रैना के लिए कैसे नज़र अंदाज़ किया जाए। फिलहाल दोनों के लिए कोई परेशानी खड़ी नहीं होने जा रही है। देर शाम बीसीसीआई की मेडिकल टीम ने साफ़ कर दिया कि 


In [8]:
import regex as re

# Hindi-focused BPE pattern
# Pre-tokenization regex (`re` here is the third-party `regex` module, which
# supports \p{...} Unicode properties). Alternatives, tried in order:
#   1. optional space + run of Devanagari-script chars / combining marks
#   2. optional space + run of numeric chars
#   3. optional space + run of anything else that is not whitespace
#   4. a run of whitespace
HINDI_PATTERN = re.compile(
    r"( ?[\p{Devanagari}\p{M}]+| ?\p{N}+| ?[^\p{Devanagari}\p{M}\p{N}\s]+|\s+)",
    re.UNICODE,
)


def pretokenize_hindi(text):
    """Split `text` into pre-tokens (list of str) using HINDI_PATTERN.

    BPE merges are later learned and applied within each pre-token only.
    """
    return HINDI_PATTERN.findall(text)

# Sanity-check the pre-tokenizer on the first 2000 characters of the corpus
tokens = pretokenize_hindi(training_text[:2000])
first_50 = tokens[:50]
print("Sample tokens:", first_50)
print("Token count:", len(tokens))


Sample tokens: ['भारतीय', ' वन', ' डे', ' टीम', ' से', ' लगभग', ' एक', ' साल', ' से', ' बहार', ' चल', ' रहे', ' शानदार', ' बल्लेबाज', ' सुरेश', ' रैना', ' की', ' वापसी', ' पर', ' फिर', ' ग्रहण', ' लग', ' गया', ' है', '।', ' वायरल', ' बुखार', ' के', ' चलते', ' न्यूजीलैंड', ' के', ' खिलाफ', ' चल', ' रही', ' सीरीज', ' के', ' पहले', ' वनडे', ' से', ' बाहर', ' होने', ' के', ' बाद', ' रैना', ' मंगलवार', ' को', ' टीम', ' इंडिया', ' से', ' न']
Token count: 425


In [None]:
# Build corpus tokens using characters instead of bytes:
# pre-tokenize the whole training_text, then split every pre-token into its
# Unicode code points (list(str) splits per code point, so a consonant and a
# following vowel sign become separate symbols).
ptoks = pretokenize_hindi(training_text)
corpus_tokens = [list(pt) for pt in ptoks]

print("Total pre-tokens:", len(corpus_tokens))
print("Example token:", corpus_tokens[0])

Total pre-tokens: 17498763
Example token: ['भ', 'ा', 'र', 'त', 'ी', 'य']


In [10]:
# STEP 2: initial vocabulary = every distinct character in the corpus, sorted
unique_chars = sorted({ch for tok in corpus_tokens for ch in tok})
print("Number of unique chars:", len(unique_chars))
print("Chars:", unique_chars[:200])  # just to inspect


# STEP 3: give each character an integer id following the sorted order
sym2id = {ch: idx for idx, ch in enumerate(unique_chars)}
id2sym = {idx: ch for ch, idx in sym2id.items()}

print("First 20 id2sym:", [(i, id2sym[i]) for i in range(min(20, len(id2sym)))])


# STEP 4: corpus_tokens -> corpus_ids (one id list per pre-token)
corpus_ids = [[sym2id[ch] for ch in tok] for tok in corpus_tokens]

print("Example ids:", corpus_ids[0])


Number of unique chars: 214
Chars: ['\n', ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', '¢', '¦', '©', 'Â', 'Ã', 'â', 'œ', 'ँ', 'ं', 'ः', 'अ', 'आ', 'इ', 'ई', 'उ', 'ऊ', 'ऋ', 'ऌ', 'ऍ', 'ऎ', 'ए', 'ऐ', 'ऑ', 'ऒ', 'ओ', 'औ', 'क', 'ख', 'ग', 'घ', 'ङ', 'च', 'छ', 'ज', 'झ', 'ञ', 'ट', 'ठ', 'ड', 'ढ', 'ण', 'त', 'थ', 'द', 'ध', 'न', 'ऩ', 'प', 'फ', 'ब', 'भ', 'म', 'य', 'र', 'ऱ', 'ल', 'ळ', 'ऴ', 'व', 'श', 'ष', 'स', 'ह', '़', 'ऽ', 'ा', 'ि', 'ी', 'ु', 'ू', 'ृ', 'ॄ', 'ॅ', 'ॆ', 'े', 'ै', 'ॉ', 'ॊ', 'ो', 'ौ', '्', 'ॎ', 'ॐ', '॑', '॒', '॓', '॔', 'क़', 'ख़', 'ग़', 'ज़', 'ड़', 'ढ़', 'फ़', 'य़', 'ॠ', 'ॢ', '।

In [11]:
from collections import Counter, defaultdict

def build_initial_stats(corpus_ids):
    """
    Count every adjacent id pair in the corpus and record where it occurs.

    Args:
      corpus_ids: sequence of id lists, one per pre-token.

    Returns:
      global_pair_counts: Counter mapping (left_id, right_id) -> number of
        occurrences across the whole corpus (repeats inside one sequence
        are all counted).
      pair_to_sequences: defaultdict(set) mapping each pair -> set of
        indices into corpus_ids whose sequence contains that pair.
    """
    counts = Counter()
    locations = defaultdict(set)

    for idx, ids in enumerate(corpus_ids):
        for pair in zip(ids, ids[1:]):
            counts[pair] += 1
            locations[pair].add(idx)

    return counts, locations


In [12]:
def merge_pair(
    a, b, new_id,
    corpus_ids,
    global_pair_counts,
    pair_to_sequences
):
    """
    Merge every occurrence of the adjacent pair (a, b) into `new_id`, in place.

    Only sequences listed in pair_to_sequences[(a, b)] are visited. For each
    one, its old adjacent-pair contributions are subtracted from the stats,
    the sequence is rewritten with a greedy left-to-right scan, and the
    rewritten sequence's pairs are added back.

    Args:
      a, b: ids of the left and right symbols of the pair to merge.
      new_id: id that replaces each (a, b) occurrence. Expected to be a
        fresh id (the training loop allocates it from next_token_id), so
        (a, b) cannot reappear in a rewritten sequence.
      corpus_ids: list of id lists; affected entries are replaced by new lists.
      global_pair_counts: Counter of pair -> total occurrences (as returned
        by build_initial_stats). Relies on Counter yielding 0 for missing keys.
      pair_to_sequences: defaultdict(set) of pair -> indices of sequences
        containing it. Relies on defaultdict creating missing keys (see the
        removal loop below).

    Returns:
      None; all three containers are mutated in place.

    Note: the order in which keys are deleted and re-added here determines
    the key order of global_pair_counts, which Counter.most_common() uses to
    break ties between equally frequent pairs. Reordering these statements
    can therefore change which merges get learned.
    """

    # the pair we are merging
    pair = (a, b)

    # sequences that contain this pair
    affected = pair_to_sequences.get(pair, set())

    # Drop the merged pair's entries up front (before rewriting). `affected`
    # still references the removed set, which nothing touches while we iterate
    # it: any later pair_to_sequences[pair] access creates a new, separate set.
    if pair in pair_to_sequences:
        del pair_to_sequences[pair]
    if pair in global_pair_counts:
        del global_pair_counts[pair]

    # iterate only those sequences
    for seq_id in affected:
        seq = corpus_ids[seq_id]

        # --------------------------
        # 1. REMOVE OLD PAIR COUNTS
        # --------------------------
        # remove all pairs from this sequence from global counts
        for x, y in zip(seq, seq[1:]):
            # For (x, y) == (a, b) the count was deleted above, so this goes
            # 0 -> -1 and the entry is removed again just below.
            global_pair_counts[(x, y)] -= 1
            # defaultdict access (re)creates the key, so the `del` below
            # cannot raise KeyError even if the entry was deleted earlier.
            pair_to_sequences[(x, y)].discard(seq_id)
            if global_pair_counts[(x, y)] <= 0:
                del global_pair_counts[(x, y)]
                del pair_to_sequences[(x, y)]

        # --------------------------
        # 2. MERGE THE SEQUENCE
        # --------------------------
        # Greedy left-to-right: a matched (a, b) consumes both positions, so
        # overlapping occurrences (e.g. a == b with "a a a") merge only once.
        new_seq = []
        i = 0
        L = len(seq)
        while i < L:
            if i < L - 1 and seq[i] == a and seq[i+1] == b:
                new_seq.append(new_id)
                i += 2
            else:
                new_seq.append(seq[i])
                i += 1

        corpus_ids[seq_id] = new_seq

        # --------------------------
        # 3. ADD NEW PAIR COUNTS
        # --------------------------
        # Re-register every pair of the rewritten sequence (including the new
        # pairs formed around new_id).
        for x, y in zip(new_seq, new_seq[1:]):
            global_pair_counts[(x, y)] += 1
            pair_to_sequences[(x, y)].add(seq_id)


In [None]:
# Build pair statistics once; merge_pair() keeps them up to date afterwards.
global_pair_counts, pair_to_sequences = build_initial_stats(corpus_ids=corpus_ids)


In [14]:
vocab_size_target = 12000
initial_vocab_size = len(sym2id)
next_token_id = initial_vocab_size

# Each entry is (left_id, right_id, merged_id), in the order merges are learned.
merges = []

for it in range(vocab_size_target - initial_vocab_size):
    if len(global_pair_counts) == 0:
        print("No more pairs left to merge.")
        break

    # Most frequent adjacent pair (equal counts: the pair first encountered wins).
    best, freq = global_pair_counts.most_common(1)[0]
    a, b = best
    if freq < 2:
        print("Stopping early, freq < 2")
        break

    # Allocate a fresh id and record the merged symbol's text.
    new_id, next_token_id = next_token_id, next_token_id + 1
    left_sym, right_sym = id2sym[a], id2sym[b]
    new_sym = left_sym + right_sym
    id2sym[new_id] = new_sym
    merges.append((a, b, new_id))

    if not it % 100:
        print(f"[{it}] merge '{left_sym}' + '{right_sym}' -> '{new_sym}' (freq={freq})")

    # Rewrite affected sequences and update the pair statistics in place.
    merge_pair(a, b, new_id, corpus_ids, global_pair_counts, pair_to_sequences)

print("Done.")
print(f"Final vocab size: {next_token_id}")


[0] merge ' ' + 'क' -> ' क' (freq=2659325)
[100] merge 'त' + '्' -> 'त्' (freq=84654)
[200] merge 'ि' + 'र' -> 'िर' (freq=38373)
[300] merge ' ' + '"' -> ' "' (freq=25154)
[400] merge ' उन' + 'के' -> ' उनके' (freq=18503)
[500] merge ' क' + '्यों' -> ' क्यों' (freq=14240)
[600] merge 'े' + 'ड' -> 'ेड' (freq=11318)
[700] merge 'ेश' + 'न' -> 'ेशन' (freq=9532)
[800] merge ' ह' + 'ूं' -> ' हूं' (freq=8184)
[900] merge ' ब' + 'ल्' -> ' बल्' (freq=7055)
[1000] merge 'आ' + 'र' -> 'आर' (freq=6415)
[1100] merge ' अ' + 'ंदर' -> ' अंदर' (freq=5669)
[1200] merge 'क्' + 'शन' -> 'क्शन' (freq=5081)
[1300] merge 'त्र' + 'ी' -> 'त्री' (freq=4630)
[1400] merge ' इत' + 'ना' -> ' इतना' (freq=4257)
[1500] merge 'ग' + 'म' -> 'गम' (freq=3878)
[1600] merge 'ॉ' + 'ट' -> 'ॉट' (freq=3598)
[1700] merge 'प' + 'े' -> 'पे' (freq=3311)
[1800] merge ' पा' + 'या' -> ' पाया' (freq=3067)
[1900] merge ' निक' + 'ाल' -> ' निकाल' (freq=2867)
[2000] merge ' वाल' + 'ों' -> ' वालों' (freq=2677)
[2100] merge ' ले' + 'ख' -> ' लेख'

In [15]:
# merge_ranks[(a, b)] = position of that merge in `merges`.
# A lower rank means the merge was learned earlier and is applied first.
merge_ranks = {(a, b): rank for rank, (a, b, _new_id) in enumerate(merges)}


In [22]:
def apply_merges_to_sequence(seq, ranks=None, merge_list=None):
    """Apply the learned BPE merges to one pre-token's id sequence.

    Repeatedly finds the adjacent pair with the lowest merge rank and merges
    every non-overlapping occurrence of it in one left-to-right pass (the same
    greedy rule `merge_pair` uses during training). This gives the same result
    as merging the leftmost occurrence one step at a time, because a freshly
    created id can only take part in merges learned later (higher rank),
    assuming, as in the training loop, that each merge's new id is fresh.

    Args:
        seq: list of token ids for one pre-token.
        ranks: optional mapping (a, b) -> rank; defaults to the global
            ``merge_ranks``.
        merge_list: optional list of (a, b, new_id) whose index equals the
            rank stored in ``ranks``; defaults to the global ``merges``.

    Returns:
        The merged list of ids (``seq`` itself if nothing can be merged).
    """
    if ranks is None:
        ranks = merge_ranks
    if merge_list is None:
        merge_list = merges

    while len(seq) >= 2:
        best_rank = None
        for pair in zip(seq, seq[1:]):
            rank = ranks.get(pair)
            if rank is not None and (best_rank is None or rank < best_rank):
                best_rank = rank
        if best_rank is None:
            return seq

        # merge_ranks is built with enumerate(merges), so the rank indexes the
        # merge directly: O(1) instead of scanning the whole merge list.
        a, b, new_id = merge_list[best_rank]

        merged = []
        i = 0
        n = len(seq)
        while i < n:
            if i < n - 1 and seq[i] == a and seq[i + 1] == b:
                merged.append(new_id)
                i += 2
            else:
                merged.append(seq[i])
                i += 1
        seq = merged

    return seq


In [23]:
def encode(text, unk_id=None):
    """Encode `text` into a flat list of BPE token ids.

    The text is pre-tokenized with `pretokenize_hindi`, each pre-token is
    split into characters and mapped through the global `sym2id`, and the
    learned merges are applied per pre-token with `apply_merges_to_sequence`.

    Args:
        text: input string.
        unk_id: optional id to substitute for characters missing from
            `sym2id` (which only holds characters seen in the training
            corpus). When None (the default), an unknown character raises
            KeyError as before, now with a descriptive message.

    Returns:
        list of int token ids.

    Raises:
        KeyError: `text` contains a character missing from `sym2id` and
            `unk_id` is None.
    """
    output_ids = []

    for tok in pretokenize_hindi(text):
        seq = []
        for ch in tok:
            sym_id = sym2id.get(ch)
            if sym_id is None:
                if unk_id is None:
                    raise KeyError(
                        f"character {ch!r} (U+{ord(ch):04X}) is not in the "
                        "vocabulary; pass unk_id to map unknown characters"
                    )
                sym_id = unk_id
            seq.append(sym_id)
        output_ids.extend(apply_merges_to_sequence(seq))

    return output_ids


In [24]:
# Round-trip a short sentence through encode() and the id -> symbol table
text = "भारत में चुनाव हुआ।"
ids = encode(text)
print(ids)
decoded = "".join(id2sym[i] for i in ids)
print("Decoded:", decoded)


[2755, 232, 839, 639, 192]
Decoded: भारत में चुनाव हुआ।


In [25]:
import json

def save_vocab_json(path="vocab.json", symbols=None):
    """Write the vocabulary as a JSON object mapping symbol string -> id.

    Args:
        path: output file path (UTF-8, non-ASCII characters written as-is,
            indented by 2).
        symbols: optional mapping id -> symbol string; defaults to the global
            ``id2sym``.

    Several ids can carry the same symbol string, because two different
    merges can concatenate to the same text. A JSON object cannot hold
    duplicate keys, so only the last id for each string is kept (unchanged
    behavior). A warning now reports the dropped ids instead of losing them
    silently.
    """
    import warnings

    if symbols is None:
        symbols = id2sym

    vocab = {}
    dropped = []
    for i, sym in symbols.items():
        if sym in vocab:
            dropped.append((sym, vocab[sym]))
        vocab[sym] = i

    if dropped:
        example_sym, example_id = dropped[0]
        warnings.warn(
            f"{len(dropped)} id(s) share a symbol string with a later id and "
            f"are missing from {path} (e.g. {example_sym!r} id {example_id})",
            stacklevel=2,
        )

    with open(path, "w", encoding="utf-8") as f:
        json.dump(vocab, f, ensure_ascii=False, indent=2)

save_vocab_json(path="vocab.json")


In [26]:
def save_merges_txt(path="merges.txt", merge_list=None, symbols=None):
    """Write the learned merges, one "LEFT RIGHT" line per merge, in learned order.

    The file starts with a "#version: BPE" header line.

    Args:
        path: output file path (UTF-8).
        merge_list: optional list of (a, b, new_id); defaults to the global
            ``merges``.
        symbols: optional mapping id -> symbol string; defaults to the global
            ``id2sym``.

    Caveat: symbols are written raw and separated by a single space. Symbols
    that themselves contain whitespace (e.g. " क", or whitespace runs from
    the pre-tokenizer's whitespace group) make their line impossible to split
    back unambiguously, and a newline inside a symbol breaks the
    one-merge-per-line layout. The file is still written the same way, but a
    warning reports how many lines are affected.
    """
    import warnings

    if merge_list is None:
        merge_list = merges
    if symbols is None:
        symbols = id2sym

    ambiguous = 0
    with open(path, "w", encoding="utf-8") as f:
        f.write("#version: BPE\n")
        for a, b, _new_id in merge_list:
            A = symbols[a]
            B = symbols[b]
            if any(ch.isspace() for ch in A + B):
                ambiguous += 1
            f.write(f"{A} {B}\n")

    if ambiguous:
        warnings.warn(
            f"{ambiguous} merge line(s) in {path} contain symbols with "
            "whitespace and cannot be parsed back by splitting on spaces",
            stacklevel=2,
        )

save_merges_txt(path="merges.txt")


In [28]:
import pickle
from collections import defaultdict, Counter

# You already have this from before:
import regex as re
# Same pre-tokenization regex as used for training (`re` is the `regex`
# module imported above). Alternatives, tried in order: optional space +
# Devanagari/mark run, optional space + numeric run, optional space + run of
# other non-space chars, or a whitespace run.
HINDI_PATTERN = re.compile(
    r"( ?[\p{Devanagari}\p{M}]+| ?\p{N}+| ?[^\p{Devanagari}\p{M}\p{N}\s]+|\s+)",
    re.UNICODE,
)


def pretokenize_hindi(text):
    """Return the list of pre-token strings of `text` matched by HINDI_PATTERN."""
    return HINDI_PATTERN.findall(text)


class HindiBPETokenizer:
    """Character-level BPE tokenizer built from the learned vocabulary and merges.

    Args:
        id2sym: mapping token id -> symbol string (base characters and merged
            symbols); used by `decode`.
        sym2id: mapping character -> id used to seed encoding (the training
            cells fill it with base characters only).
        merges: list of (a, b, new_id) tuples in the order they were learned.
    """

    def __init__(self, id2sym, sym2id, merges):
        self.id2sym = id2sym
        self.sym2id = sym2id
        self.merges = merges

        # rank = position in `merges`; lower rank = learned earlier = applied first
        self.merge_ranks = {
            (a, b): rank for rank, (a, b, new_id) in enumerate(merges)
        }

        # direct pair -> new_id lookup so merging never scans `merges`
        self.pair2newid = {(a, b): new_id for (a, b, new_id) in merges}

    # -------------------------
    # APPLY MERGES TO ONE SEQUENCE
    # -------------------------
    def _apply_merges(self, seq):
        """Apply learned merges to one pre-token's id sequence.

        Repeatedly picks the lowest-rank mergeable adjacent pair and merges
        every non-overlapping occurrence of it in one greedy left-to-right
        pass, the same rule used when the merges were learned. This matches
        merging the leftmost occurrence one step at a time, because a new id
        can only take part in merges learned later (higher rank), assuming
        each merge's new_id is fresh, as the training loop allocates it.

        Args:
            seq: list of token ids.

        Returns:
            The merged list of ids (``seq`` itself if nothing can be merged).
        """
        ranks = self.merge_ranks
        while len(seq) >= 2:
            best_rank = None
            best_pair = None
            for pair in zip(seq, seq[1:]):
                rank = ranks.get(pair)
                if rank is not None and (best_rank is None or rank < best_rank):
                    best_rank, best_pair = rank, pair
            if best_pair is None:
                break

            a, b = best_pair
            new_id = self.pair2newid[best_pair]

            merged = []
            i = 0
            n = len(seq)
            while i < n:
                if i < n - 1 and seq[i] == a and seq[i + 1] == b:
                    merged.append(new_id)
                    i += 2
                else:
                    merged.append(seq[i])
                    i += 1
            seq = merged
        return seq

    # -------------------------
    # ENCODE
    # -------------------------
    def encode(self, text, unk_id=None):
        """Encode `text` into a flat list of token ids.

        Uses the module-level `pretokenize_hindi`, maps each character via
        ``sym2id`` and applies the merges per pre-token.

        Args:
            text: input string.
            unk_id: optional id to use for characters missing from
                ``sym2id``. When None (the default), an unknown character
                raises KeyError as before, now with a descriptive message.

        Returns:
            list of int token ids.

        Raises:
            KeyError: a character of `text` is not in ``sym2id`` and
                ``unk_id`` is None.
        """
        output_ids = []

        for tok in pretokenize_hindi(text):
            seq = []
            for ch in tok:
                sym_id = self.sym2id.get(ch)
                if sym_id is None:
                    if unk_id is None:
                        raise KeyError(
                            f"character {ch!r} (U+{ord(ch):04X}) is not in the "
                            "tokenizer vocabulary; pass unk_id to map unknown "
                            "characters"
                        )
                    sym_id = unk_id
                seq.append(sym_id)
            output_ids.extend(self._apply_merges(seq))

        return output_ids

    # -------------------------
    # DECODE
    # -------------------------
    def decode(self, ids):
        """Concatenate the symbol strings of `ids` back into text."""
        return "".join(self.id2sym[i] for i in ids)

    # -------------------------
    # SAVE
    # -------------------------
    def save(self, path="hindi_tokenizer.pkl"):
        """Pickle id2sym, sym2id and merges to `path`."""
        data = {
            "id2sym": self.id2sym,
            "sym2id": self.sym2id,
            "merges": self.merges
        }
        with open(path, "wb") as f:
            pickle.dump(data, f)

    # -------------------------
    # LOAD
    # -------------------------
    @classmethod
    def load(cls, path="hindi_tokenizer.pkl"):
        """Rebuild a tokenizer from a file written by `save`.

        SECURITY: pickle.load can execute arbitrary code embedded in the
        file; only load tokenizer files you created or otherwise trust.
        """
        with open(path, "rb") as f:
            data = pickle.load(f)
        return cls(
            data["id2sym"],
            data["sym2id"],
            data["merges"]
        )


In [29]:
# Wrap the trained tables in a tokenizer object and persist it
tok = HindiBPETokenizer(id2sym=id2sym, sym2id=sym2id, merges=merges)
tok.save(path="hindi_tokenizer.pkl")


In [30]:
tok = HindiBPETokenizer.load(path="hindi_tokenizer.pkl")


In [31]:
sample_sentence = "भारत में चुनाव हुआ।"
ids = tok.encode(sample_sentence)
print(ids)


[2755, 232, 839, 639, 192]


In [32]:
decoded_text = tok.decode(ids)
print(decoded_text)


भारत में चुनाव हुआ।


In [33]:
# Average characters per token over the whole training corpus
n_chars = len(training_text)
n_tokens = len(tok.encode(training_text))
compression_ratio = n_chars / n_tokens


In [34]:
print(str(compression_ratio))

3.86336353891803
