In [None]:
# Mount Google Drive at /content/drive so the corpus and tokenizer output
# paths under '/content/drive/My Drive/...' (see main) are reachable.
# Colab-only: google.colab does not exist outside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install extra packages (IPython shell escape; valid only in a notebook).
# NOTE(review): tqdm, cupy-cuda12x and grapheme are not imported by any of
# the cells shown here — confirm they are still needed.
!pip install tqdm cupy-cuda12x grapheme

# Standard BPE

In [None]:
from __future__ import annotations
import json
import logging
import time
import argparse
import re
import sys
import unicodedata
from pathlib import Path
from collections import defaultdict, Counter
from typing import Union, Optional, Dict, List
import numpy as np

# Send all log records (root logger included) to stdout. force=True replaces
# any handlers a previous notebook cell may have installed.
logging.basicConfig(
    force=True,
    handlers=[logging.StreamHandler(sys.stdout)],
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Special token strings shared by training, encoding and the saved files.
PAD, UNK, BOS, EOS = '<pad>', '<unk>', '<s>', '</s>'
# SentencePiece-style word-start marker (U+2581) prefixed to every word.
WHITESPACE = '▁'

class MCounter(Counter):
    """Counter that supports scaling by an int and stays an MCounter under ``+``."""

    def __mul__(self, other):
        """Return a new MCounter whose counts are all multiplied by ``other``.

        Raises:
            TypeError: If ``other`` is not an int.
        """
        if isinstance(other, int):
            scaled = MCounter()
            for key, count in self.items():
                scaled[key] = other * count
            return scaled
        raise TypeError("Non-int factor")

    def __rmul__(self, other):
        """Support ``int * MCounter`` by delegating to ``__mul__``."""
        return self.__mul__(other)

    def __add__(self, other):
        """Add counts (dropping non-positive results, as Counter does) and keep the type."""
        return MCounter(Counter.__add__(self, other))

class TamilSimpleBPE:
    """Plain frequency-based BPE tokenizer trained on Tamil text.

    Training starts from the special tokens plus every character kept by
    the coverage filter. It then repeatedly merges the most frequent
    adjacent token pair. It stops when ``vocab_size`` tokens exist, when no
    pairs remain, or when the best pair occurs at most once. Every word is
    prefixed with ``WHITESPACE`` so word-initial pieces differ from
    word-internal ones. Characters dropped by the coverage filter become
    ``UNK`` and never take part in merges.
    """

    def __init__(
        self,
        vocab_size: int,
        pad_id: int = 3,
        unk_id: int = 0,
        bos_id: int = 1,
        eos_id: int = 2,
        coverage: float = 0.9999,
    ):
        """Create an untrained tokenizer.

        Args:
            vocab_size: Target vocabulary size, special tokens included.
            pad_id, unk_id, bos_id, eos_id: Ids reserved for the special tokens.
            coverage: Fraction of character occurrences the initial alphabet
                must cover; rarer characters are mapped to ``UNK``.

        Raises:
            ValueError: If two special tokens are given the same id.
        """
        self.desired_vocab_size = vocab_size
        self.coverage = coverage

        self.special_tokens = {
            PAD: pad_id,
            UNK: unk_id,
            BOS: bos_id,
            EOS: eos_id
        }
        # Duplicate ids would make id2token silently lose a special token.
        if len(set(self.special_tokens.values())) != len(self.special_tokens):
            raise ValueError(f'Special token ids must be distinct, got {self.special_tokens}.')

        # vocabulary mappings
        self.vocab = {}  # token_str -> token_id
        self.id2token = {}  # token_id -> token_str
        self.merges = []  # ordered list of merge rules (left, right)

        for token_str, token_id in self.special_tokens.items():
            self.vocab[token_str] = token_id
            self.id2token[token_id] = token_str

        # Regular tokens are numbered after the highest reserved special id.
        self.next_id = max(self.special_tokens.values()) + 1

    def _preprocess_tamil_text(self, text: str) -> str:
        """Normalize one line of text before it is split into words.

        The steps are:
        - NFC-normalize and collapse runs of whitespace.
        - Replace characters outside the allowed class with spaces.
        - Isolate runs of ASCII letters and of Tamil digits (U+0BE6-U+0BEF)
          as separate words.

        Note: the allowed class includes the regex word-character escape,
        which in Python 3 matches letters and digits of every script, so
        non-Tamil letters are kept too. Only symbols and punctuation other
        than ``. , ? ! : ; -`` are removed.
        """
        text = unicodedata.normalize('NFC', text)

        text = re.sub(r'\s+', ' ', text)

        # Tamil block is U+0B80-U+0BFF; also keep word chars, ASCII digits
        # and . , ? ! : ; -
        text = re.sub(r'[^\u0B80-\u0BFF\s\w\u0030-\u0039\u002E\u002C\u003F\u0021\u003A\u003B\u002D]', ' ', text)

        # Split English runs out of mixed Tamil/English words.
        text = re.sub(r'([a-zA-Z]+)', r' \1 ', text)

        # Split runs of Tamil numerals out as their own words.
        text = re.sub(r'([௦-௯]+)', r' \1 ', text)

        text = re.sub(r'\s+', ' ', text)

        return text.strip()

    def _get_words(self, file: str) -> Dict[str, int]:
        """Count ``WHITESPACE``-prefixed words in a UTF-8 text file.

        Args:
            file: Path of the training corpus, read line by line.

        Returns:
            Mapping from prefixed word to its corpus frequency.
        """
        logging.info(f'Loading corpus from {file}...')
        start_time = time.time()

        word_freqs = MCounter()
        line_count = 0

        with open(file, 'r', encoding='utf-8') as f:
            for line in f:
                line_count += 1
                if not line.strip():
                    continue

                processed_line = self._preprocess_tamil_text(line)
                if not processed_line:
                    continue

                words = [WHITESPACE + word for word in processed_line.split() if word]
                word_freqs.update(words)

                if line_count % 50000 == 0:
                    logging.info(f'Processed {line_count:,} lines.')

        num_words = len(word_freqs)
        logging.info(f'Loaded {num_words:,} unique words from {line_count:,} lines in {time.time() - start_time:.2f}s.')

        return dict(word_freqs)

    def _get_characters(self, word_freqs: Dict[str, int]) -> MCounter:
        """Return corpus-weighted character frequencies over all words."""
        char_freqs = MCounter()
        for word, freq in word_freqs.items():
            for char in word:
                char_freqs[char] += freq
        return char_freqs

    def _filter_characters(self, char_freqs: MCounter) -> MCounter:
        """Keep the most frequent characters until ``coverage`` of all occurrences is reached.

        With ``coverage >= 1`` the input is returned unchanged.
        """
        if self.coverage < 1:
            total_chars = sum(char_freqs.values())
            target_chars = round(self.coverage * total_chars)

            kept_chars = MCounter()
            char_count = 0
            # most_common() yields characters by descending frequency.
            for char, freq in char_freqs.most_common():
                kept_chars[char] = freq
                char_count += freq
                if char_count >= target_chars:
                    break

            removed_count = len(char_freqs) - len(kept_chars)
            if removed_count > 0:
                logging.info(f'Filtered out {removed_count} rare characters.')

            return kept_chars
        return char_freqs

    def _initialize_vocab(self, word_freqs: Dict[str, int]) -> Dict[str, List[str]]:
        """Add the kept characters to the vocabulary and split every word into characters.

        Characters not in the vocabulary are replaced by ``UNK`` in the splits.

        Returns:
            Mapping from word to its initial list of single-character tokens.
        """
        logging.info('Initializing vocabulary...')

        char_freqs = self._get_characters(word_freqs)
        filtered_chars = self._filter_characters(char_freqs)

        for char in filtered_chars:
            if char not in self.vocab:
                self.vocab[char] = self.next_id
                self.id2token[self.next_id] = char
                self.next_id += 1

        word_splits = {
            word: [char if char in self.vocab else UNK for char in word]
            for word in word_freqs
        }

        tamil_chars = sum(1 for char in filtered_chars if '\u0B80' <= char <= '\u0BFF')
        tamil_numerals = sum(1 for char in filtered_chars if '\u0BE6' <= char <= '\u0BEF')

        logging.info(f'Initialized vocabulary with {len(filtered_chars)} characters.')
        logging.info(f'Found {tamil_chars} Tamil script characters.')
        logging.info(f'Found {tamil_numerals} Tamil numerals.')

        return word_splits

    def _get_pairs(self, word_splits: Dict[str, List[str]], word_freqs: Dict[str, int]) -> MCounter:
        """Count frequency-weighted adjacent token pairs, ignoring pairs that contain ``UNK``.

        ``UNK`` stands for many different rare characters, so merging it with
        a neighbour would create meaningless tokens such as ``"க<unk>"``.
        """
        pairs = MCounter()

        for word, splits in word_splits.items():
            freq = word_freqs[word]
            for left, right in zip(splits, splits[1:]):
                if left == UNK or right == UNK:
                    continue
                pairs[(left, right)] += freq

        return pairs

    def _merge_pair(self, pair: tuple[str, str], word_splits: Dict[str, List[str]],
                    word_freqs: Dict[str, int]) -> Dict[str, List[str]]:
        """Record the merge rule ``pair`` and apply it to every word split.

        ``word_freqs`` is accepted for interface compatibility and not used.

        Returns:
            New word splits; input lists are never mutated.
        """
        left, right = pair
        merged = left + right

        # The merged string may already exist via another decomposition
        # (e.g. "ab"+"c" vs "a"+"bc"); keep its original id in that case.
        if merged not in self.vocab:
            self.vocab[merged] = self.next_id
            self.id2token[self.next_id] = merged
            self.next_id += 1

        self.merges.append(pair)

        new_word_splits = {}
        for word, splits in word_splits.items():
            # Cheap skip: a word without the left token cannot contain the pair.
            if left not in splits:
                new_word_splits[word] = splits
                continue
            new_splits = []
            i = 0
            while i < len(splits):
                if i < len(splits) - 1 and splits[i] == left and splits[i + 1] == right:
                    new_splits.append(merged)
                    i += 2
                else:
                    new_splits.append(splits[i])
                    i += 1
            new_word_splits[word] = new_splits

        return new_word_splits

    def fit(self, input_file: str, output_dir: str, logging_step: int = 200) -> None:
        """Train on ``input_file`` and write the model files to ``output_dir``.

        Args:
            input_file: UTF-8 text corpus, read line by line.
            output_dir: Directory (created if missing). It receives
                ``simple_bpe_model.json`` and the HuggingFace-format files.
            logging_step: Log progress whenever the vocabulary size is a
                multiple of this value.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        word_freqs = self._get_words(input_file)
        word_splits = self._initialize_vocab(word_freqs)

        current_vocab_size = len(self.vocab)
        logging.info(f'Starting BPE training with {current_vocab_size} initial tokens.')

        merge_times = []
        while current_vocab_size < self.desired_vocab_size:
            start_time = time.time()

            pairs = self._get_pairs(word_splits, word_freqs)

            if not pairs:
                logging.info('No more pairs to merge. Stopping training.')
                break

            most_frequent_pair, freq = pairs.most_common(1)[0]

            if freq <= 1:
                logging.info('No pairs with frequency > 1. Stopping training.')
                break

            word_splits = self._merge_pair(most_frequent_pair, word_splits, word_freqs)
            # A merge may reproduce an existing string and not grow the
            # vocabulary, so recount instead of assuming +1.
            current_vocab_size = len(self.vocab)

            merge_times.append(time.time() - start_time)

            if current_vocab_size % logging_step == 0:
                left, right = most_frequent_pair
                avg_time = np.mean(merge_times) if merge_times else 0
                logging.info(
                    f'Vocab size: {current_vocab_size:,}/{self.desired_vocab_size:,}. '
                    f'Merged "{left}" + "{right}" (freq: {freq:,}). '
                    f'Avg merge time: {avg_time:.3f}s'
                )
                merge_times = []

        logging.info(f'Training completed with final vocabulary size: {len(self.vocab):,}')

        self._save_simple_bpe_model(output_path / 'simple_bpe_model.json')
        self._save_huggingface_files(output_path)

        logging.info(f'Files saved to {output_path}')

    def _save_simple_bpe_model(self, file_path: Path) -> None:
        """Write vocabulary, ordered merges and special-token ids as one JSON file."""
        logging.info(f'Saving Simple BPE model to {file_path}...')

        model_data = {
            'vocab': self.vocab,
            'merges': [{'left': left, 'right': right} for left, right in self.merges],
            'vocab_size': len(self.vocab),
            'special_tokens': self.special_tokens,
            'language': 'tamil',
            'script': 'tamil'
        }

        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(model_data, f, indent=2, ensure_ascii=False)

    def _save_huggingface_files(self, output_path: Path) -> None:
        """Write ``tokenizer.json`` and companion files for ``PreTrainedTokenizerFast``.

        Files written to ``output_path``: ``tokenizer.json``,
        ``tokenizer_config.json``, ``special_tokens_map.json``,
        ``added_tokens.json`` and ``vocab.json``.

        Note: the HF pipeline applies only NFC plus whitespace/Metaspace
        pre-tokenization. It does not reproduce the regex clean-up in
        ``_preprocess_tamil_text``, so results can differ from ``tokenize``
        on text that contains filtered symbols.
        """
        logging.info('Saving HuggingFace compatible files...')

        hf_merges = [f"{left} {right}" for left, right in self.merges]
        bos_id = self.special_tokens[BOS]
        eos_id = self.special_tokens[EOS]

        # Serialized TemplateProcessing templates are lists of pieces; the
        # "<s>:1 $A:0 </s>:1" shorthand is only accepted by the Python constructor.
        def special_piece(token: str, type_id: int) -> dict:
            return {"SpecialToken": {"id": token, "type_id": type_id}}

        def sequence_piece(seq: str, type_id: int) -> dict:
            return {"Sequence": {"id": seq, "type_id": type_id}}

        # Register the special tokens so the HF tokenizer never splits them.
        added_tokens = [
            {
                "id": token_id,
                "content": token_str,
                "single_word": False,
                "lstrip": False,
                "rstrip": False,
                "normalized": False,
                "special": True,
            }
            for token_str, token_id in sorted(self.special_tokens.items(), key=lambda item: item[1])
        ]

        tokenizer_data = {
            "version": "1.0",
            "truncation": None,
            "padding": None,
            "added_tokens": added_tokens,
            "normalizer": {
                "type": "NFC"
            },
            "pre_tokenizer": {
                "type": "Sequence",
                "pretokenizers": [
                    {
                        "type": "WhitespaceSplit"
                    },
                    {
                        "type": "Metaspace",
                        "replacement": WHITESPACE,
                        "add_prefix_space": True
                    }
                ]
            },
            "post_processor": {
                "type": "TemplateProcessing",
                "single": [
                    special_piece(BOS, 1),
                    sequence_piece("A", 0),
                    special_piece(EOS, 1),
                ],
                "pair": [
                    special_piece(BOS, 1),
                    sequence_piece("A", 0),
                    special_piece(EOS, 1),
                    sequence_piece("B", 0),
                    special_piece(EOS, 1),
                ],
                "special_tokens": {
                    BOS: {"id": BOS, "ids": [bos_id], "tokens": [BOS]},
                    EOS: {"id": EOS, "ids": [eos_id], "tokens": [EOS]}
                }
            },
            "decoder": {
                "type": "Metaspace",
                "replacement": WHITESPACE,
                "add_prefix_space": True
            },
            "model": {
                "type": "BPE",
                "dropout": None,
                "unk_token": UNK,
                "continuing_subword_prefix": None,
                "end_of_word_suffix": None,
                "fuse_unk": False,
                "vocab": self.vocab,
                "merges": hf_merges
            }
        }

        with open(output_path / 'tokenizer.json', 'w', encoding='utf-8') as f:
            json.dump(tokenizer_data, f, indent=2, ensure_ascii=False)

        # No "auto_map": it must name "module.ClassName" remote code, and the
        # built-in PreTrainedTokenizerFast needs none.
        config_data = {
            "tokenizer_class": "PreTrainedTokenizerFast",
            "bos_token": BOS,
            "eos_token": EOS,
            "unk_token": UNK,
            "pad_token": PAD,
            "model_max_length": 2048,
            "padding_side": "left",
            "truncation_side": "right",
            "chat_template": None,
            "clean_up_tokenization_spaces": True,
            "spaces_between_special_tokens": False,
            "language": "tamil",
            "script": "tamil"
        }

        with open(output_path / 'tokenizer_config.json', 'w', encoding='utf-8') as f:
            json.dump(config_data, f, indent=2, ensure_ascii=False)

        special_tokens_data = {
            "bos_token": BOS,
            "eos_token": EOS,
            "unk_token": UNK,
            "pad_token": PAD
        }

        with open(output_path / 'special_tokens_map.json', 'w', encoding='utf-8') as f:
            json.dump(special_tokens_data, f, indent=2, ensure_ascii=False)

        with open(output_path / 'added_tokens.json', 'w', encoding='utf-8') as f:
            json.dump({}, f, indent=2, ensure_ascii=False)

        with open(output_path / 'vocab.json', 'w', encoding='utf-8') as f:
            json.dump(self.vocab, f, indent=2, ensure_ascii=False)

        logging.info(f'Final vocabulary size: {len(self.vocab):,}')
        logging.info(f'Number of merge rules: {len(self.merges):,}')

    def tokenize(self, text: str) -> List[str]:
        """Preprocess ``text`` and return its BPE token strings, word by word."""
        text = self._preprocess_tamil_text(text)

        tokens = []
        for word in text.split():
            tokens.extend(self._tokenize_word(WHITESPACE + word))

        return tokens

    def _tokenize_word(self, word: str) -> List[str]:
        """Split ``word`` into characters (unknown ones become ``UNK``) and replay merges in training order."""
        tokens = [char if char in self.vocab else UNK for char in word]

        for left, right in self.merges:
            # Nothing left to merge once the word is a single token.
            if len(tokens) < 2:
                break
            new_tokens = []
            i = 0
            while i < len(tokens):
                if i < len(tokens) - 1 and tokens[i] == left and tokens[i + 1] == right:
                    new_tokens.append(left + right)
                    i += 2
                else:
                    new_tokens.append(tokens[i])
                    i += 1
            tokens = new_tokens

        return tokens

    def encode(self, text: str) -> List[int]:
        """Encode ``text`` to token ids; tokens missing from the vocabulary map to the ``UNK`` id."""
        tokens = self.tokenize(text)
        return [self.vocab.get(token, self.special_tokens[UNK]) for token in tokens]

    def decode(self, token_ids: List[int]) -> str:
        """Concatenate token strings, turn ``WHITESPACE`` markers into spaces and strip the ends."""
        tokens = [self.id2token.get(token_id, UNK) for token_id in token_ids]
        text = ''.join(tokens)
        text = text.replace(WHITESPACE, ' ')
        return text.strip()

def train_tamil_bpe(
    input_file: str,
    output_dir: str = "./tamil_tokenizer",
    vocab_size: int = 10000,
    coverage: float = 0.9999,
    logging_step: int = 200
):
    """Train a TamilSimpleBPE tokenizer and save its files.

    Args:
        input_file: UTF-8 training corpus.
        output_dir: Directory that receives the saved tokenizer files.
        vocab_size: Target vocabulary size (special tokens included).
        coverage: Character coverage used to build the initial alphabet.
        logging_step: Progress-logging interval passed to ``fit``.

    Returns:
        The trained TamilSimpleBPE instance.
    """
    tokenizer = TamilSimpleBPE(
        vocab_size=vocab_size,
        coverage=coverage
    )

    start_time = time.time()
    tokenizer.fit(input_file, output_dir, logging_step)
    training_time = time.time() - start_time
    # Report the wall-clock training time so the measurement is actually used.
    logging.info(f'Total training time: {training_time:.2f}s.')

    return tokenizer

def test_tamil_tokenizer(tokenizer, test_sentences=None):
    """Print tokens, ids and the decode round-trip for each sample sentence.

    Args:
        tokenizer: Object exposing ``tokenize``, ``encode`` and ``decode``.
        test_sentences: Sentences to check; five built-in Tamil sentences
            are used when None.
    """
    sentences = test_sentences if test_sentences is not None else [
        "வணக்கம், நீங்கள் எப்படி இருக்கிறீர்கள்?",
        "நான் தமிழ் மொழி கற்றுக்கொண்டிருக்கிறேன்.",
        "சென்னை தமிழ்நாட்டின் தலைநகரம்.",
        "புத்தகம் மேசையில் இருக்கிறது.",
        "இன்று வானிலை மிகவும் நல்லது.",
    ]

    print("\nTesting tokenizer with sample Tamil sentences:")
    print("=" * 70)

    for number, sentence in enumerate(sentences, start=1):
        print(f"\n{number}. Original: {sentence}")

        pieces = tokenizer.tokenize(sentence)
        print(f"   Tokens: {pieces}")
        print(f"   Count: {len(pieces)} tokens")

        ids = tokenizer.encode(sentence)
        print(f"   IDs: {ids}")

        round_trip = tokenizer.decode(ids)
        print(f"   Decoded: {round_trip}")

        matches = round_trip.strip() == sentence.strip()
        print("   Perfect reconstruction" if matches else "   Reconstruction differs")

def main():
    """Train the standard BPE tokenizer on the Drive corpus, then run the sample-sentence check."""
    settings = dict(
        input_file="/content/drive/My Drive/Colab Notebooks/LRLs/tamil/dataset/ta_reduced_train.txt",
        output_dir="/content/drive/My Drive/Colab Notebooks/LRLs/tamil/tokenizers/standard_bpe",
        vocab_size=10000,
        coverage=0.9999,
        logging_step=200,
    )
    trained = train_tamil_bpe(**settings)
    test_tamil_tokenizer(trained)

# In a notebook cell __name__ is '__main__', so running the cell trains.
if __name__ == '__main__':
    main()

# Picky BPE

In [None]:
from __future__ import annotations
import json
import logging
import time
import argparse
import re
from transformers import AutoTokenizer
from pathlib import Path
from collections import defaultdict, Counter
from typing import Union, Optional, Dict, List
import numpy as np

# Reconfigure root logging for this cell. The default StreamHandler writes to
# stderr, and force=True drops handlers installed by earlier cells.
logging.basicConfig(
    force=True,
    handlers=[logging.StreamHandler()],
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Special token strings and the U+2581 word-start marker.
PAD, UNK, BOS, EOS = '<pad>', '<unk>', '<s>', '</s>'
WHITESPACE = '▁'

class MCounter(Counter):
    """A Counter that can be scaled by an int and keeps its own type under ``+``."""

    def __mul__(self, other):
        """Return a copy with every count multiplied by the int ``other``.

        Raises:
            TypeError: For any non-int factor.
        """
        if not isinstance(other, int):
            raise TypeError("Non-int factor")
        return MCounter(dict((key, other * count) for key, count in self.items()))

    def __rmul__(self, other):
        """Make ``int * counter`` behave like ``counter * int``."""
        return self.__mul__(other)

    def __add__(self, other):
        """Counter addition (non-positive totals dropped) returning an MCounter."""
        combined = Counter.__add__(self, other)
        return MCounter(combined)

class Token:
    """One vocabulary entry used by PickyBPE.

    A token is *atomic* when its string has length 1 or it is special;
    atomic tokens cannot be removed. A non-atomic token keeps the ``left``
    and ``right`` tokens it was merged from, so once it is no longer
    present, ``walk`` expands it into present (or atomic) pieces.
    """

    def __init__(
        self,
        id: int,
        str: str,
        freq: int = 0,
        special: bool = False,
        present: bool = True,
        left: Optional[Token] = None,
        right: Optional[Token] = None,
        split: Optional[list[Token]] = None
    ):
        self.id, self.str, self.freq = id, str, freq
        self.special, self.present = special, present
        # Single characters and special tokens form the irreducible base.
        self.atomic = len(str) == 1 or special
        # Words that currently contain this token (filled in by Word).
        self.words = set()
        self.left, self.right = left, right
        self.split = split

    def __repr__(self):
        return f'{self.str} ({self.freq})'

    def walk(self) -> list[Token]:
        """Return this token, or, if removed and non-atomic, its recursively expanded parts."""
        if not (self.atomic or self.present):
            return [*self.left.walk(), *self.right.walk()]
        return [self]

    def remove(self) -> None:
        """Mark a non-atomic token as removed and reset its frequency and word set.

        Raises:
            ValueError: If the token is atomic.
        """
        if not self.atomic:
            self.present = False
            self.freq = 0
            self.words = set()
            return
        raise ValueError(f'Cannot remove an atomic token {self.str}.')

    def restore(self) -> None:
        """Mark a removed token as present again.

        Raises:
            ValueError: If the token is already present.
        """
        if not self.present:
            self.present = True
            return
        raise ValueError(f'Cannot revoke already present token {self.str}.')

    def split_if_possible(self) -> Optional[list[Token]]:
        """Mark a non-atomic token as absent and return its expansion; atomic tokens give None."""
        if not self.atomic:
            self.present = False
            return self.walk()
        return None

    def to_dict(self) -> dict:
        """Serialize to a dict, referring to neighbour tokens by id."""
        left_id = None if self.left is None else self.left.id
        right_id = None if self.right is None else self.right.id
        return dict(
            id=self.id,
            str=self.str,
            freq=self.freq,
            special=self.special,
            present=self.present,
            left=left_id,
            right=right_id,
            split=[piece.id for piece in self.walk()],
        )

class Word:
    """A distinct corpus word with its frequency and current tokenization.

    ``tokens`` holds the current Token sequence and ``pairs`` the adjacent
    token pairs weighted by ``freq``; both are None until ``encode`` runs.
    """

    def __init__(self, id: int, word: str, freq: int = 0):
        self.id = id
        self.str = word
        self.freq = freq
        self.tokens = None
        self.pairs = None

    def __repr__(self) -> str:
        return f'{self.str} ({self.freq})'

    def encode(self, str2token: dict[str, Token]) -> None:
        """Tokenize into single characters via ``str2token`` and rebuild pair counts."""
        self.tokens = [str2token[ch] for ch in self.str]
        self._recalculate()

    def _recalculate(self, update_tokens: bool = True) -> None:
        """Recount frequency-weighted adjacent pairs; optionally register this word on its tokens."""
        adjacent = zip(self.tokens, self.tokens[1:])
        self.pairs = MCounter(adjacent) * self.freq
        if not update_tokens:
            return
        for tok in self.tokens:
            tok.words.add(self)

    def merge_pair(self, pair: tuple[Token, Token], new_token: Token, update_tokens: bool = True) -> int:
        """Replace non-overlapping left-to-right occurrences of ``pair`` with ``new_token``.

        Returns:
            Number of replacements multiplied by the word frequency.
        """
        first, second = pair[0], pair[1]
        old = self.tokens
        rebuilt = []
        pos = 0
        while pos < len(old):
            if pos + 1 < len(old) and old[pos] == first and old[pos + 1] == second:
                rebuilt.append(new_token)
                pos += 2
            else:
                rebuilt.append(old[pos])
                pos += 1
        replacements = len(old) - len(rebuilt)
        if update_tokens:
            # _recalculate re-adds this word to any part token still present.
            first.words.discard(self)
            second.words.discard(self)
        self.tokens = rebuilt
        self._recalculate(update_tokens=update_tokens)
        return replacements * self.freq

    def split_token(self, token: Token, split: list[Token], update_tokens: bool = True):
        """Expand every occurrence of ``token`` into the token sequence ``split``."""
        expanded = []
        for tok in self.tokens:
            expanded.extend(split if tok == token else (tok,))
        self.tokens = expanded
        self._recalculate(update_tokens=update_tokens)

class TamilPickyBPE:
    def __init__(
        self,
        vocab_size: int,
        pad_id: int = 3,
        unk_id: int = 0,
        bos_id: int = 1,
        eos_id: int = 2,
        coverage: float = 0.9999,
        threshold: float = 0.9999,
    ):
        """Create an untrained PickyBPE tokenizer.

        Args:
            vocab_size: Target vocabulary size.
            pad_id, unk_id, bos_id, eos_id: Ids of the special tokens.
            coverage: Character coverage for the initial alphabet.
            threshold: A token is split once the share of its occurrences
                absorbed by a merge exceeds this value.
        """
        self.desired_vocab_size = vocab_size
        self.pad_token = Token(pad_id, PAD, 0, special=True)
        self.unk_token = Token(unk_id, UNK, 0, special=True)
        self.bos_token = Token(bos_id, BOS, 0, special=True)
        self.eos_token = Token(eos_id, EOS, 0, special=True)

        specials = (self.pad_token, self.unk_token, self.bos_token, self.eos_token)
        self.id2token = {tok.id: tok for tok in specials}
        # Looking up an unknown string yields (and stores) the UNK token.
        self.str2token = defaultdict(lambda: self.unk_token, ((tok.str, tok) for tok in specials))
        self.max_special_token_id = max(self.id2token)
        self.actual_vocab_size = len(self.id2token)
        self.new_id = self.max_special_token_id + 1
        self.coverage = coverage
        self.threshold = threshold
        # Chronological log of ('MERGE', ...) and ('SPLIT', ...) events.
        self.events = []

    def _preprocess_tamil_text(self, text: str) -> str:
        """preprocess Tamil text preserving Tamil script"""
        import unicodedata

        text = unicodedata.normalize('NFC', text)
        text = text.replace(' ', f' {WHITESPACE}')

        # Tamil Unicode range: U+0B80-U+0BFF
        text = re.sub(r'[^\u0B80-\u0BFF\s\w\u0030-\u0039\u002E\u002C\u003F\u0021\u003A\u003B\u002D]', ' ', text)
        text = re.sub(r'([a-zA-Z]+)', r' \1 ', text)
        text = re.sub(r'([௦-௯]+)', r' \1 ', text)
        text = re.sub(r'\s+', ' ', text)

        return text.strip()

    def _get_words(self, file: str) -> list[Word]:
        """Read ``file`` and return one Word per distinct ``WHITESPACE``-prefixed word.

        Each non-blank line is stripped, cleaned with ``_preprocess_tamil_text``
        and split on whitespace. A word gets a leading ``WHITESPACE`` unless it
        already starts with one. Word ids follow first-occurrence order.
        """
        logging.info(f'Loading Tamil corpus from {file}...')
        start_time = time.time()

        counter = MCounter()
        with open(file, 'r', encoding='utf-8') as handle:
            for line_idx, raw_line in enumerate(handle):
                stripped = raw_line.strip()
                if not stripped:
                    continue

                cleaned = self._preprocess_tamil_text(stripped)
                if not cleaned:
                    continue

                counter.update(
                    piece if piece.startswith(WHITESPACE) else WHITESPACE + piece
                    for piece in cleaned.split()
                )

                if line_idx > 0 and line_idx % 50000 == 0:
                    logging.info(f'Processed {line_idx} lines.')

        logging.info(f'Loaded {len(counter)} unique words in {time.time() - start_time:.2f}s.')

        return [Word(idx, text, count) for idx, (text, count) in enumerate(counter.items())]

    def _get_characters(self, words: list[Word]) -> MCounter:
        """Return character counts over all words, each word weighted by its frequency."""
        totals = MCounter()
        for position, entry in enumerate(words):
            scaled = MCounter(entry.str) * entry.freq
            totals.update(scaled)
            if position and position % 100000 == 0:
                logging.info(f'Processed {position} words for character extraction.')
        return totals

    def _filter_characters(self, characters: MCounter) -> MCounter:
        """filter rare characters based on coverage threshold"""
        if self.coverage < 1:
            corpus_size = sum(characters.values())
            freq_to_remove = corpus_size - round(self.coverage * corpus_size)
            if freq_to_remove > 0:
                cum_sum = np.cumsum([freq for _, freq in reversed(characters.most_common())])
                num_to_remove = np.searchsorted(cum_sum, freq_to_remove)
                characters_to_remove = [c for c, _ in characters.most_common()[-num_to_remove:]]
                for c in characters_to_remove:
                    characters.pop(c)
                logging.info(f'Replaced {num_to_remove} rare characters with UNK.')
        return characters

    def _initialize_vocab(self, words: list[Word]) -> None:
        """Create one Token per character that survives the coverage filter.

        Ids are assigned consecutively from ``self.new_id``. Each token's
        initial frequency is the character's corpus count.
        """
        logging.info('Initializing vocabulary with Tamil characters...')
        kept = self._filter_characters(self._get_characters(words))

        for offset, (character, count) in enumerate(kept.items()):
            token = Token(self.new_id + offset, character, count)
            self.id2token[token.id] = token
            self.str2token[token.str] = token

        added = len(kept)
        self.new_id += added
        self.actual_vocab_size += added

        tamil_chars = sum('\u0B80' <= ch <= '\u0BFF' for ch in kept)
        tamil_numerals = sum('\u0BE6' <= ch <= '\u0BEF' for ch in kept)

        logging.info(f'Initialized vocabulary with {added} unique characters.')
        logging.info(f'Found {tamil_chars} Tamil script characters.')
        logging.info(f'Found {tamil_numerals} Tamil numerals.')

    @staticmethod
    def _validate_pair(pair) -> bool:
        """check if pair contains only non-special tokens"""
        return not any(token.special for token in pair)

    def _encode_words(self, words: list[Word]) -> None:
        """encode words using current vocabulary"""
        logging.info('Encoding words with Tamil characters...')
        for i, word in enumerate(words):
            word.encode(self.str2token)
            if i > 0 and i % 100000 == 0:
                logging.info(f'Processed {i} words for encoding.')

    def _initialize_pairs(self, words: list[Word]) -> MCounter:
        """Sum the per-word pair counters into one global counter, excluding pairs with special tokens."""
        pairs = MCounter()
        logging.info('Counting Tamil character pairs...')
        for idx, word in enumerate(words):
            pairs.update(word.pairs)
            if idx > 0 and idx % 100000 == 0:
                logging.info(f'Processed {idx} words for pair counting.')

        # Collect first, then delete: the counter must not change size while iterated.
        invalid = [pair for pair in pairs if not self._validate_pair(pair)]
        for pair in invalid:
            pairs.pop(pair)

        return pairs

    def _remove_if_possible(self, token: Token, merged_freq: int, pairs: MCounter) -> bool:
        """remove token if it meets the threshold criteria"""
        if merged_freq / (token.freq + merged_freq) > self.threshold:
            split = token.split_if_possible()
            if split is not None:
                self.actual_vocab_size -= 1
                for t in split:
                    t.freq += token.freq
                for pair in zip(split[:-1], split[1:]):
                    pairs[pair] += token.freq

                pairs_for_update = MCounter()
                for word in token.words:
                    if token not in word.tokens:
                        raise ValueError(f'Token {token} not found in word {word}.')
                    pairs_for_update.update({
                        pair: freq for pair, freq in word.pairs.items()
                        if self._validate_pair(pair) and token in pair
                    })
                    word.split_token(token, split)

                self._update_pairs_on_remove(token, split, pairs_for_update, pairs)
                token.remove()
                return True
        return False

    @staticmethod
    def _update_pairs_on_merge(new_token: Token, pair: tuple[Token, Token],
                              pairs_for_update: MCounter, pairs: MCounter):
        """Adjust global pair counts after ``pair`` was merged into ``new_token``.

        Every pair in ``pairs_for_update`` contains ``new_token``. Its count is
        added to ``pairs``, and the neighbour pair it replaced is decremented
        (and dropped once it reaches zero or below).

        Args:
            new_token: Token produced by the merge.
            pair: The (left, right) pair that was merged.
            pairs_for_update: Counts of valid pairs containing ``new_token`` in
                the rewritten words.
            pairs: Global pair counter, updated in place.

        Raises:
            ValueError: If an entry of ``pairs_for_update`` lacks ``new_token``.

        NOTE(review): this treats every pair in ``pairs_for_update`` as newly
        created. If ``new_token`` already occurred in a word before this merge
        (possible when _merge_pair reuses an existing present token), those
        pairs may be counted twice — confirm.
        """
        # Add the pairs that now involve the merged token ...
        pairs.update(pairs_for_update)
        # ... and subtract the neighbour pair each one replaced.
        for p, freq in pairs_for_update.items():
            if new_token not in p:
                raise ValueError(f'Pair {p} does not contain the new token {new_token}.')
            if new_token is p[0]:
                if new_token is p[1]:
                    # (new, new) came from "left right left right": the middle
                    # (right, left) pair disappeared.
                    to_update = (pair[1], pair[0])
                else:
                    # (new, y) replaced (right, y).
                    to_update = (pair[1], p[1])
            else:
                # (x, new) replaced (x, left).
                to_update = (p[0], pair[0])
            # The replaced pair may be absent (e.g. the merged pair itself,
            # which _merge_pair pops before merging).
            if to_update in pairs:
                pairs[to_update] -= freq
                if pairs[to_update] <= 0:
                    pairs.pop(to_update)

    @staticmethod
    def _update_pairs_on_remove(token: Token, split: list[Token],
                               pairs_for_update: MCounter, pairs: MCounter):
        """Re-point global pair counts after ``token`` was split into ``split``.

        Each pair involving ``token`` is moved onto the pair formed with the
        adjacent end of ``split``, and the old pair is removed. Pairs internal
        to ``split`` are not handled here (_remove_if_possible adds them).

        Args:
            token: The token being removed.
            split: Tokens that replace ``token`` in every word.
            pairs_for_update: Counts of valid pairs containing ``token``.
            pairs: Global pair counter, updated in place.
        """
        for pair, freq in pairs_for_update.items():
            if token is pair[0]:
                if token is pair[1]:
                    # (token, token): boundary becomes (last part, first part).
                    to_update = (split[-1], split[0])
                else:
                    # (token, y) -> (last part, y).
                    to_update = (split[-1], pair[1])
            else:
                # (x, token) -> (x, first part).
                to_update = (pair[0], split[0])
            pairs[to_update] += freq
            # NOTE(review): dict.pop raises KeyError if ``pair`` was already
            # dropped from ``pairs`` — confirm counts can never lose it first.
            pairs.pop(pair)

    def _merge_token_in_words(self, token_to_merge: Token, pair_to_merge: tuple[Token, Token],
                             pairs: MCounter) -> int:
        """Apply a merge to every word containing the pair and update all bookkeeping.

        Rewrites affected words, updates the global pair counts, moves
        frequency from the two parts to ``token_to_merge``, and tries to
        remove each part via ``_remove_if_possible`` (recording a 'SPLIT'
        event when that succeeds).

        Args:
            token_to_merge: Token that replaces each occurrence of the pair.
            pair_to_merge: The (left, right) pair being merged.
            pairs: Global pair counter, updated in place.

        Returns:
            Total merged occurrences, weighted by word frequency.

        Raises:
            ValueError: If a part token is no longer present when its
                frequency is about to be reduced.
        """
        actual_freq = 0
        pairs_for_update = MCounter()

        # Only words that contain both tokens can contain the pair. The
        # intersection is a new set, so merge_pair may mutate token.words.
        for word in pair_to_merge[0].words & pair_to_merge[1].words:
            if pair_to_merge in word.pairs:
                word.pairs.pop(pair_to_merge)
                actual_freq += word.merge_pair(pair_to_merge, token_to_merge)
                pairs_for_update.update({
                    p: f for p, f in word.pairs.items()
                    if self._validate_pair(p) and token_to_merge in p
                })

        self._update_pairs_on_merge(token_to_merge, pair_to_merge, pairs_for_update, pairs)
        token_to_merge.freq += actual_freq

        if pair_to_merge[0] is pair_to_merge[1]:
            # Merging (t, t): each new occurrence consumed two copies of t.
            pair_to_merge[0].freq -= 2 * actual_freq
            removed = self._remove_if_possible(pair_to_merge[0], actual_freq, pairs)
            if removed:
                logging.info(f'Removed token {pair_to_merge[0].str} after merging into {token_to_merge.str}.')
                self.events.append(('SPLIT', pair_to_merge[0], pair_to_merge[0].walk()))
        else:
            for token in pair_to_merge:
                if not token.present:
                    raise ValueError(f'Token {token} is not present in vocabulary.')
                token.freq -= actual_freq
                removed = self._remove_if_possible(token, actual_freq, pairs)
                if removed:
                    logging.info(f'Removed token {token.str} after merging into {token_to_merge.str}.')
                    self.events.append(('SPLIT', token, token.walk()))

        return actual_freq

    def _merge_pair(self, pair: tuple[Token, Token], pairs: MCounter) -> int:
        """Merge *pair* into one token everywhere it occurs.

        The merged token is looked up by its string:

        - An unseen string gets a new Token, with the two halves as parents and
          the next free id.
        - A previously removed token is restored.
        - A present token simply receives additional merges.

        The merge is recorded as a ``('MERGE', pair, new_token)`` event.

        Returns:
            Number of occurrences merged, weighted by word frequency.
        """
        pairs.pop(pair)
        merged_str = pair[0].str + pair[1].str

        if merged_str not in self.str2token:
            new_token = Token(self.new_id, merged_str, 0, left=pair[0], right=pair[1])
            self.id2token[new_token.id] = new_token
            self.str2token[new_token.str] = new_token
            self.new_id += 1
        else:
            new_token = self.str2token[merged_str]
            if new_token.present:
                logging.info(f'Additional merges for {new_token.str}.')
            else:
                new_token.restore()
                logging.info(f'Restored previously removed token {new_token.str}.')

        self.events.append(('MERGE', pair, new_token))
        return self._merge_token_in_words(new_token, pair, pairs)

    def fit(self, input_file: str, output_dir: str, logging_step: int = 200) -> None:
        """Train the Tamil PickyBPE tokenizer on *input_file* and save it to *output_dir*.

        Training proceeds as follows:

        1. Build the initial vocabulary from the corpus.
        2. Repeatedly merge the most frequent pair until ``actual_vocab_size``
           reaches ``desired_vocab_size`` or no positive-count pair remains.
        3. Write ``picky_bpe_model.json`` and the Hugging Face files.

        ``actual_vocab_size`` is incremented only when a merge adds an entry to
        the vocabulary, i.e. a brand-new token or a restored previously removed
        one. A merge into an already-present token ("Additional merges") adds
        nothing. Counting it anyway would stop training below the target size.

        Args:
            input_file: Path of the training corpus.
            output_dir: Output directory; created if needed.
            logging_step: Log progress whenever the vocabulary size is a
                multiple of this value.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        logging.info("Starting Tamil PickyBPE training...")
        logging.info(f"Input: {input_file}")
        logging.info(f"Output: {output_dir}")
        logging.info(f"Target vocab size: {self.desired_vocab_size:,}")
        logging.info(f"Coverage: {self.coverage*100:.2f}%")
        logging.info(f"Threshold: {self.threshold*100:.2f}%")

        words = self._get_words(input_file)
        self._initialize_vocab(words)
        self._encode_words(words)
        pairs = self._initialize_pairs(words)

        merge_time = []
        logging.info(f'Starting PickyBPE training for Tamil with {self.actual_vocab_size} initial tokens.')

        while self.actual_vocab_size < self.desired_vocab_size:
            start_time = time.time()
            if not pairs:
                logging.info(f'No more pairs to merge. Stopping with vocab size of {self.actual_vocab_size}.')
                break

            pair, count = pairs.most_common(1)[0]
            if count <= 0:
                logging.info(f'No more pairs to merge. Stopping with vocab size of {self.actual_vocab_size}.')
                break

            # Decide before merging whether the result is a new vocabulary entry
            # (.get() does not create entries, even on a defaultdict).
            existing = self.str2token.get(pair[0].str + pair[1].str)
            adds_token = existing is None or not existing.present

            freq = self._merge_pair(pair, pairs)
            merge_time.append(time.time() - start_time)
            if not adds_token:
                continue
            self.actual_vocab_size += 1

            if self.actual_vocab_size % logging_step == 0:
                avg_time = np.mean(merge_time) if merge_time else 0
                current_speed = 1.0 / avg_time if avg_time > 0 else 0
                logging.info(
                    f'VOCABULARY SIZE: {self.actual_vocab_size:,}/{self.desired_vocab_size:,}. '
                    f'Merged "{pair[0].str}" + "{pair[1].str}" with frequency {freq:,}. '
                    f'Speed: {current_speed:.1f} merges/sec'
                )
                merge_time = []

        self._save_picky_model(output_path / 'picky_bpe_model.json')
        self._save_huggingface_files(output_path)

        logging.info(f'Training completed. Files saved to {output_path}')

    def _save_picky_model(self, file_path: Path) -> None:
        """Write the full PickyBPE training state as JSON to *file_path*.

        The file contains:

        - every token ever created, present or removed;
        - a mapping between internal ids and contiguous indices of the present
          tokens (``id2int`` / ``int2id``, with string keys);
        - the MERGE and SPLIT events, each tagged with its position in the
          event log;
        - the training settings.
        """
        logging.info(f'Saving Tamil PickyBPE model to {file_path}...')

        present_ids = [tid for tid in sorted(self.id2token.keys()) if self.id2token[tid].present]
        id_mapping = {tid: index for index, tid in enumerate(present_ids)}

        token_records = [tok.to_dict() for tok in self.id2token.values()]
        merge_records = [
            {'id': index, 'pair': [tok.to_dict() for tok in event[1]], 'new_token': event[2].to_dict()}
            for index, event in enumerate(self.events)
            if event[0] == 'MERGE'
        ]
        split_records = [
            {'id': index, 'token': event[1].to_dict(), 'split': [tok.to_dict() for tok in event[2]]}
            for index, event in enumerate(self.events)
            if event[0] == 'SPLIT'
        ]

        model_data = {
            'language': 'tamil',
            'script': 'tamil',
            'tokens': token_records,
            'id2int': {str(tid): index for tid, index in id_mapping.items()},
            'int2id': {str(index): tid for tid, index in id_mapping.items()},
            'merges': merge_records,
            'splits': split_records,
            'training_config': {
                'coverage': self.coverage,
                'threshold': self.threshold,
                'vocab_size': self.desired_vocab_size
            }
        }

        with open(file_path, 'w', encoding='utf-8') as handle:
            json.dump(model_data, handle, indent=2, ensure_ascii=False)

    def _save_huggingface_files(self, output_path: Path) -> None:
        """Export the trained vocabulary as Hugging Face tokenizer files.

        Writes ``tokenizer.json``, ``tokenizer_config.json``,
        ``special_tokens_map.json``, ``added_tokens.json`` and ``vocab.json``
        into *output_path*.

        Vocabulary:
            Only tokens still present after training are exported. They are
            renumbered contiguously in ascending order of their internal id.

        Merge rules:
            A recorded MERGE becomes a BPE merge rule only if both inputs *and*
            the merged result are in the exported vocabulary. The ``tokenizers``
            BPE loader rejects rules that mention unknown tokens, and PickyBPE
            may remove a merge result later in training. A rule that was
            recorded more than once keeps its first (highest-priority) position.
        """
        # Present tokens, renumbered 0..N-1 in internal-id order.
        vocab = {}
        next_index = 0
        for token_id in sorted(self.id2token.keys()):
            token = self.id2token[token_id]
            if token.present:
                vocab[token.str] = next_index
                next_index += 1

        # Merge rules in training order, restricted to the exported vocabulary.
        merges = []
        seen_rules = set()
        for event in self.events:
            if event[0] != 'MERGE':
                continue
            left_str = event[1][0].str
            right_str = event[1][1].str
            if (left_str, right_str) in seen_rules:
                continue
            if left_str in vocab and right_str in vocab and (left_str + right_str) in vocab:
                seen_rules.add((left_str, right_str))
                merges.append(f"{left_str} {right_str}")

        # Renumbered ids of <s> and </s>. The fallbacks 1 and 2 are the ids the
        # template used to hard-code.
        bos_id = vocab.get(BOS, 1)
        eos_id = vocab.get(EOS, 2)

        def special_piece(content: str) -> dict:
            # Template piece for a special token (type id 1, as before).
            return {"SpecialToken": {"id": content, "type_id": 1}}

        def sequence_piece(name: str) -> dict:
            # Template piece for input sequence A or B (type id 0, as before).
            return {"Sequence": {"id": name, "type_id": 0}}

        # Declare the special tokens as added tokens so they are matched whole
        # and flagged as special.
        added_tokens = [
            {
                "id": vocab[content],
                "content": content,
                "single_word": False,
                "lstrip": False,
                "rstrip": False,
                "normalized": False,
                "special": True,
            }
            for content in sorted((s for s in (UNK, BOS, EOS, PAD) if s in vocab), key=vocab.__getitem__)
        ]

        # tokenizer.json
        tokenizer_data = {
            "version": "1.0",
            "truncation": None,
            "padding": None,
            "added_tokens": added_tokens,
            "normalizer": {
                "type": "NFC"
            },
            "pre_tokenizer": {
                "type": "Sequence",
                "pretokenizers": [
                    {
                        "type": "WhitespaceSplit"
                    },
                    {
                        "type": "Metaspace",
                        "replacement": WHITESPACE,
                        "add_prefix_space": True
                    }
                ]
            },
            # Structured template form: each special token entry needs a string
            # "id" plus "ids"/"tokens" lists.
            "post_processor": {
                "type": "TemplateProcessing",
                "single": [special_piece(BOS), sequence_piece("A"), special_piece(EOS)],
                "pair": [
                    special_piece(BOS), sequence_piece("A"), special_piece(EOS),
                    sequence_piece("B"), special_piece(EOS),
                ],
                "special_tokens": {
                    BOS: {"id": BOS, "ids": [bos_id], "tokens": [BOS]},
                    EOS: {"id": EOS, "ids": [eos_id], "tokens": [EOS]}
                }
            },
            "decoder": {
                "type": "Metaspace",
                "replacement": WHITESPACE,
                "add_prefix_space": True
            },
            "model": {
                "type": "BPE",
                "dropout": None,
                "unk_token": UNK,
                "continuing_subword_prefix": None,
                "end_of_word_suffix": None,
                "fuse_unk": False,
                "vocab": vocab,
                "merges": merges
            }
        }

        with open(output_path / 'tokenizer.json', 'w', encoding='utf-8') as f:
            json.dump(tokenizer_data, f, indent=2, ensure_ascii=False)

        # tokenizer_config.json. No "auto_map": it is meant for custom-code
        # tokenizer classes, and PreTrainedTokenizerFast is built in.
        config_data = {
            "tokenizer_class": "PreTrainedTokenizerFast",
            "bos_token": BOS,
            "eos_token": EOS,
            "unk_token": UNK,
            "pad_token": PAD,
            "model_max_length": 2048,
            "padding_side": "left",
            "truncation_side": "right",
            "chat_template": None,
            "clean_up_tokenization_spaces": True,
            "spaces_between_special_tokens": False,
            "language": "tamil",
            "script": "tamil"
        }

        with open(output_path / 'tokenizer_config.json', 'w', encoding='utf-8') as f:
            json.dump(config_data, f, indent=2, ensure_ascii=False)

        # special_tokens_map.json
        special_tokens_data = {
            "bos_token": BOS,
            "eos_token": EOS,
            "unk_token": UNK,
            "pad_token": PAD
        }

        with open(output_path / 'special_tokens_map.json', 'w', encoding='utf-8') as f:
            json.dump(special_tokens_data, f, indent=2, ensure_ascii=False)

        # added_tokens.json maps token -> id and is read with .items(), so an
        # empty mapping must be a JSON object, not a list.
        with open(output_path / 'added_tokens.json', 'w', encoding='utf-8') as f:
            json.dump({}, f, indent=2, ensure_ascii=False)

        # vocab.json
        with open(output_path / 'vocab.json', 'w', encoding='utf-8') as f:
            json.dump(vocab, f, indent=2, ensure_ascii=False)

        logging.info(f'Final vocabulary size: {len(vocab):,}')
        logging.info(f'Number of merge rules: {len(merges):,}')
        logging.info(f'Number of split events: {sum(1 for event in self.events if event[0] == "SPLIT")}')

def train_tamil_picky_bpe(
    input_file: str,
    output_dir: str = "./tamil_picky_bpe",
    vocab_size: int = 10000,
    coverage: float = 0.9999,
    threshold: float = 0.9999,
    logging_step: int = 200
):
    """Train a PickyBPE tokenizer for Tamil and save it to *output_dir*.

    Args:
        input_file: Path of the training corpus.
        output_dir: Directory that receives the model and Hugging Face files.
        vocab_size: Target vocabulary size.
        coverage: Character coverage passed to ``TamilPickyBPE``.
        threshold: Token-removal threshold passed to ``TamilPickyBPE``.
        logging_step: Progress-logging interval passed to ``fit``.

    Returns:
        The trained ``TamilPickyBPE`` instance.
    """
    tokenizer = TamilPickyBPE(
        vocab_size=vocab_size,
        coverage=coverage,
        threshold=threshold
    )

    print("Training PickyBPE tokenizer for Tamil...")
    print(f"Input file: {input_file}")
    print(f"Output directory: {output_dir}")
    print(f"Target vocabulary size: {vocab_size:,}")
    print(f"Special tokens: {UNK} (0), {BOS} (1), {EOS} (2), {PAD} (3)")
    print(f"Character coverage: {coverage*100:.2f}%")
    print(f"Removal threshold: {threshold*100:.2f}%")

    start_time = time.time()
    tokenizer.fit(input_file, output_dir, logging_step)
    training_time = time.time() - start_time
    # The elapsed time was previously computed and then discarded.
    print(f"Training time: {training_time:.2f}s")

    return tokenizer

def test_tamil_picky_bpe(tokenizer_path: str):
    """Load the saved tokenizer and round-trip a few Tamil sentences.

    For every sample sentence this prints the tokens, the token count, the
    ids (encoded without special tokens), the decoded text, and whether the
    decoded text equals the input after stripping surrounding whitespace.
    """
    from transformers import AutoTokenizer

    hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    samples = (
        "வணக்கம், நீங்கள் எப்படி இருக்கிறீர்கள்?",
        "நான் தமிழ் மொழி கற்றுக்கொண்டிருக்கிறேன்.",
        "சென்னை தமிழ்நாட்டின் தலைநகரம்.",
        "புத்தகம் மேசையில் இருக்கிறது.",
        "இன்று வானிலை மிகவும் நல்லது.",
    )

    print("\nTesting Tamil PickyBPE tokenizer:")
    print("=" * 70)

    for idx, text in enumerate(samples, start=1):
        print(f"\n{idx}. Original: {text}")

        pieces = hf_tokenizer.tokenize(text)
        print(f"   Tokens: {pieces}")
        print(f"   Count: {len(pieces)} tokens")

        ids = hf_tokenizer.encode(text, add_special_tokens=False)
        print(f"   IDs: {ids}")

        restored = hf_tokenizer.decode(ids)
        print(f"   Decoded: {restored}")

        round_trip_ok = restored.strip() == text.strip()
        print("   Perfect reconstruction" if round_trip_ok else "   Reconstruction differs")

def main():
    """Train the Tamil PickyBPE tokenizer on the Colab dataset, then smoke-test it."""
    settings = dict(
        input_file="/content/drive/My Drive/Colab Notebooks/LRLs/tamil/dataset/ta_reduced_train.txt",
        output_dir="/content/drive/My Drive/Colab Notebooks/LRLs/tamil/tokenizers/picky_bpe",
        vocab_size=10000,
        coverage=0.9999,
        threshold=0.9999,
        logging_step=500,
    )

    train_tamil_picky_bpe(**settings)
    test_tamil_picky_bpe(settings["output_dir"])

# In a notebook cell __name__ is also '__main__', so this runs on execution.
if __name__ == '__main__':
    main()

# Grapheme Picky BPE

In [None]:
from __future__ import annotations
import json
import logging
import time
import argparse
import re
import sys
import os
import unicodedata
from transformers import AutoTokenizer
from pathlib import Path
from collections import defaultdict, Counter
from typing import Union, Optional, Dict, List
import numpy as np
import grapheme

# Log to stdout, the same as the first notebook cell.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[
        logging.StreamHandler(sys.stdout),
    ],
    force=True
)
logger = logging.getLogger(__name__)

# Word-boundary marker and special-token strings.
WHITESPACE = '▁'
PAD = '<pad>'
UNK = '<unk>'
BOS = '<s>'
EOS = '</s>'

# Tamil consonant letters, including the Grantha consonants
# ஜ ஶ ஷ ஸ ஹ and the Tamil-specific ன ற ள ழ.
TAMIL_CONSONANTS = set(
    '\u0B95\u0B99\u0B9A\u0B9C\u0B9E\u0B9F\u0BA3\u0BA4\u0BA8\u0BA9'
    '\u0BAA\u0BAE\u0BAF\u0BB0\u0BB1\u0BB2\u0BB3\u0BB4\u0BB5'
    '\u0BB6\u0BB7\u0BB8\u0BB9'
)
# Independent vowels அ ... ஔ.
TAMIL_VOWELS = set('\u0B85\u0B86\u0B87\u0B88\u0B89\u0B8A\u0B8E\u0B8F\u0B90\u0B92\u0B93\u0B94')
# Dependent vowel signs ா ... ௌ.
TAMIL_VOWEL_SIGNS = set('\u0BBE\u0BBF\u0BC0\u0BC1\u0BC2\u0BC6\u0BC7\u0BC8\u0BCA\u0BCB\u0BCC')
TAMIL_VIRAMA = '\u0BCD'
TAMIL_ANUSVARA = '\u0B82'
TAMIL_VISARGA = '\u0B83'

class MCounter(Counter):
    """Counter that can be scaled by an int (``c * n`` or ``n * c``) and stays an MCounter under ``+``."""

    def __mul__(self, other):
        if isinstance(other, int):
            scaled = MCounter()
            for key, count in self.items():
                scaled[key] = other * count
            return scaled
        raise TypeError("Non-int factor")

    def __rmul__(self, other):
        return self.__mul__(other)

    def __add__(self, other):
        return MCounter(Counter.__add__(self, other))

class Token:
    """A vocabulary entry: a base grapheme cluster, a merged token, or a special token.

    Attributes:
        id: Internal id of the token.
        str: Surface string.
        freq: Running occurrence count maintained by the trainer.
        special: True for special tokens.
        present: False once the token has been removed from the vocabulary.
        atomic: True if the token can never be split or removed.
        words: Words whose current segmentation contains this token.
        left, right: The two tokens merged to create this one; None for base
            and special tokens.
        split: Stored as given; not read by this class.
    """

    def __init__(
        self,
        id: int,
        str: str,
        freq: int = 0,
        special: bool = False,
        present: bool = True,
        left: Optional[Token] = None,
        right: Optional[Token] = None,
        split: Optional[list[Token]] = None
    ):
        self.id = id
        self.str = str
        self.freq = freq
        self.special = special
        self.present = present
        # Special tokens, single code points, and tokens without parents
        # cannot be decomposed. Base tokens are grapheme clusters that may span
        # several code points (e.g. 'கா'). With a length test alone they would
        # not be atomic, and split_if_possible() would mark them absent and
        # "split" them into [self].
        self.atomic = special or len(str) == 1 or (left is None and right is None)
        self.words = set()
        self.left = left
        self.right = right
        self.split = split

    def __repr__(self):
        return f'{self.str} ({self.freq})'

    def walk(self) -> list[Token]:
        """Return the present/atomic tokens this token decomposes into.

        A present or atomic token returns ``[self]``. An absent merged token
        returns the concatenated walks of its parents.
        """
        if self.atomic or self.present:
            return [self]
        result = []
        if self.left is not None:
            result.extend(self.left.walk())
        if self.right is not None:
            result.extend(self.right.walk())
        if not result:
            return [self]
        return result

    def remove(self) -> None:
        """Mark the token as removed and clear its frequency and word set."""
        if self.atomic:
            raise ValueError(f'Cannot remove an atomic token {self.str}.')
        self.present = False
        self.freq = 0
        self.words = set()

    def restore(self) -> None:
        """Mark a previously removed token as present again."""
        if self.present:
            raise ValueError(f'Cannot revoke already present token {self.str}.')
        self.present = True

    def split_if_possible(self) -> Optional[list[Token]]:
        """Mark a non-atomic token absent and return its decomposition; None if atomic."""
        if self.atomic:
            return None
        self.present = False
        return self.walk()

    def to_dict(self) -> dict:
        """Serialise the token, referencing related tokens by id."""
        return {
            'id': self.id,
            'str': self.str,
            'freq': self.freq,
            'special': self.special,
            'present': self.present,
            'left': self.left.id if self.left is not None else None,
            'right': self.right.id if self.right is not None else None,
            'split': [t.id for t in self.walk()]
        }

class Word:
    """A distinct corpus word and its current token segmentation.

    Attributes:
        id: Index of the word in the training word list.
        str: Surface string of the word.
        freq: Corpus frequency of the word.
        tokens: Current segmentation (set by ``encode``).
        pairs: MCounter of adjacent token pairs, each weighted by ``freq``.
    """

    def __init__(self, id: int, word: str, freq: int = 0):
        self.id = id
        self.str = word
        self.freq = freq
        self.tokens = None
        self.pairs = None

    def __repr__(self) -> str:
        return f'{self.str} ({self.freq})'

    def encode(self, str2token: dict[str, Token]) -> None:
        """Segment the word into grapheme clusters and map each cluster to its token."""
        self.tokens = [str2token[cluster] for cluster in grapheme.graphemes(self.str)]
        self._recalculate()

    def _recalculate(self, update_tokens: bool = True) -> None:
        """Recount adjacent pairs; optionally register this word with each of its tokens."""
        neighbours = zip(self.tokens, self.tokens[1:])
        self.pairs = MCounter(neighbours) * self.freq
        if not update_tokens:
            return
        for token in self.tokens:
            token.words.add(self)

    def merge_pair(self, pair: tuple[Token, Token], new_token: Token, update_tokens: bool = True) -> int:
        """Replace non-overlapping left-to-right occurrences of *pair* with *new_token*.

        Returns:
            Number of replacements multiplied by the word frequency.
        """
        merged = []
        pos = 0
        last = len(self.tokens) - 1
        while pos <= last:
            if pos < last and (self.tokens[pos], self.tokens[pos + 1]) == pair:
                merged.append(new_token)
                pos += 2
                continue
            merged.append(self.tokens[pos])
            pos += 1
        occurrences = len(self.tokens) - len(merged)
        if update_tokens:
            # Tokens that still occur are re-registered by _recalculate below.
            for part in pair:
                part.words.discard(self)
        self.tokens = merged
        self._recalculate(update_tokens=update_tokens)
        return occurrences * self.freq

    def split_token(self, token: Token, split: list[Token], update_tokens: bool = True):
        """Replace every occurrence of *token* with the token sequence *split*."""
        rebuilt = []
        for current in self.tokens:
            rebuilt.extend(split if current == token else (current,))
        self.tokens = rebuilt
        self._recalculate(update_tokens=update_tokens)

class TamilGraphemePickyBPE:
    def __init__(
        self,
        vocab_size: int,
        pad_id: int = 3,
        unk_id: int = 0,
        bos_id: int = 1,
        eos_id: int = 2,
        coverage: float = 0.9999,
        threshold: float = 0.9999,
    ):
        """Set up the special tokens and empty training state.

        Args:
            vocab_size: Target vocabulary size, special tokens included.
            pad_id, unk_id, bos_id, eos_id: Ids of the special tokens.
            coverage: Fraction of grapheme occurrences kept by
                ``_filter_graphemes``.
            threshold: Merged-share threshold used by ``_remove_if_possible``.
        """
        self.desired_vocab_size = vocab_size
        self.pad_token = Token(pad_id, PAD, 0, special=True)
        self.unk_token = Token(unk_id, UNK, 0, special=True)
        self.bos_token = Token(bos_id, BOS, 0, special=True)
        self.eos_token = Token(eos_id, EOS, 0, special=True)

        specials = (self.pad_token, self.unk_token, self.bos_token, self.eos_token)
        self.id2token = {tok.id: tok for tok in specials}
        # Looking up an unknown string yields (and stores) the <unk> token.
        self.str2token = defaultdict(lambda: self.unk_token, {tok.str: tok for tok in specials})
        self.max_special_token_id = max(self.id2token)
        self.actual_vocab_size = len(self.id2token)
        self.new_id = self.max_special_token_id + 1
        self.coverage = coverage
        self.threshold = threshold
        self.events = []
        self.grapheme_vocab = set()
        self.grapheme_vocab = set()

    @staticmethod
    def _validate_pair(pair) -> bool:
        """check if pair contains only non-special tokens"""
        return not any(token.special for token in pair)

    def _preprocess_tamil_text(self, text: str) -> str:
        """preprocess Tamil text preserving grapheme clusters"""
        text = unicodedata.normalize('NFC', text)
        text = text.replace(' ', f' {WHITESPACE}')
        text = re.sub(r'[^\u0B80-\u0BFF\s\w\u0030-\u0039\u002E\u002C\u003F\u0021\u003A\u003B\u002D]', ' ', text)
        text = re.sub(r'([a-zA-Z]+)', r' \1 ', text)
        text = re.sub(r'([௦-௯]+)', r' \1 ', text)
        text = re.sub(r'\s+', ' ', text)
        return text.strip()

    def _get_words(self, file: str) -> list[Word]:
        """Read the corpus *file* and return one Word per distinct word.

        Each non-empty line is cleaned with ``_preprocess_tamil_text`` and
        split on whitespace. Every word is prefixed with WHITESPACE (unless it
        already starts with it), and occurrences are counted. Word ids follow
        first-occurrence order.
        """
        logging.info(f'Loading corpus from {file}...')
        started = time.time()

        word_counts = MCounter()
        with open(file, 'r', encoding='utf-8') as handle:
            for line_no, raw_line in enumerate(handle):
                stripped = raw_line.strip()
                if not stripped:
                    continue

                cleaned = self._preprocess_tamil_text(stripped)
                if not cleaned:
                    continue

                word_counts.update(
                    piece if piece.startswith(WHITESPACE) else WHITESPACE + piece
                    for piece in cleaned.split()
                )

                if line_no > 0 and line_no % 50000 == 0:
                    logging.info(f'Processed {line_no} lines.')

        logging.info(f'Loaded {len(word_counts)} unique words in {time.time() - started:.2f}s.')

        return [Word(idx, text, count) for idx, (text, count) in enumerate(word_counts.items())]

    def _extract_graphemes(self, words: list[Word]) -> MCounter:
        """Count grapheme clusters over *words*, weighted by word frequency.

        Whitespace-only clusters are skipped. Every counted cluster is also
        added to ``self.grapheme_vocab``.

        Returns:
            MCounter mapping grapheme cluster -> corpus frequency.
        """
        logging.info('Extracting Tamil grapheme clusters from corpus...')
        started = time.time()

        counts = MCounter()
        for idx, word in enumerate(words):
            for cluster in grapheme.graphemes(word.str):
                if not cluster.strip():
                    continue
                counts[cluster] += word.freq
                self.grapheme_vocab.add(cluster)

            if idx > 0 and idx % 100000 == 0:
                logging.info(f'Processed {idx} words for Tamil grapheme extraction.')

        elapsed = time.time() - started
        logging.info(f'Extracted {len(counts)} Tamil grapheme clusters in {elapsed:.2f}s.')

        # Clusters containing at least one code point of the Tamil block.
        tamil_clusters = sum(
            any('\u0B80' <= char <= '\u0BFF' for char in cluster) for cluster in counts
        )
        logging.info(f'Found {tamil_clusters} Tamil script clusters out of {len(counts)} total.')

        return counts

    def _filter_graphemes(self, graphemes: MCounter) -> MCounter:
        """filter rare graphemes based on coverage threshold"""
        if self.coverage < 1:
            corpus_size = sum(graphemes.values())
            freq_to_remove = corpus_size - round(self.coverage * corpus_size)
            if freq_to_remove > 0:
                cum_sum = np.cumsum([freq for _, freq in reversed(graphemes.most_common())])
                num_to_remove = np.searchsorted(cum_sum, freq_to_remove)
                graphemes_to_remove = [g for g, _ in graphemes.most_common()[-num_to_remove:]]
                for g in graphemes_to_remove:
                    graphemes.pop(g)
                    self.grapheme_vocab.discard(g)
                logging.info(f'Replaced {num_to_remove} rare graphemes with UNK.')
        return graphemes

    def _initialize_vocab(self, graphemes: MCounter) -> None:
        """Create one base token per grapheme cluster, in sorted string order.

        New tokens get consecutive ids starting at ``self.new_id``. Strings
        already present in ``str2token`` (e.g. special tokens) are skipped.

        Each base token is set up in two ways:

        - ``freq`` is seeded with its corpus count from *graphemes*, so the
          later ``freq -= merged`` bookkeeping starts from the true count
          instead of 0.
        - It is explicitly marked ``atomic``. A base token has no parent tokens
          to split into, and clusters such as 'கா' span several code points, so
          atomicity cannot be inferred from the string length.

        Args:
            graphemes: Grapheme cluster -> corpus frequency, after filtering.
        """
        next_id = self.new_id

        for grapheme_str in sorted(graphemes.keys()):
            if grapheme_str not in self.str2token:
                token_obj = Token(str=grapheme_str, id=next_id, freq=graphemes[grapheme_str], special=False)
                # Base grapheme tokens must never be split or removed.
                token_obj.atomic = True
                self.id2token[next_id] = token_obj
                self.str2token[grapheme_str] = token_obj
                next_id += 1

        self.actual_vocab_size = len(self.id2token)
        self.new_id = next_id

        special_count = len([t for t in self.id2token.values() if t.special])
        logging.info(f"Initialized vocab: {self.actual_vocab_size} tokens ({special_count} special + {self.actual_vocab_size - special_count} learned)")

    def _validate_tamil_merge(self, pair: tuple) -> bool:
        """validate that merging two tokens won't break Tamil grapheme rules"""
        if not self._validate_pair(pair):
            return False

        left_str = pair[0].str
        right_str = pair[1].str
        merged_str = left_str + right_str

        original_graphemes = list(grapheme.graphemes(left_str)) + list(grapheme.graphemes(right_str))
        merged_graphemes = list(grapheme.graphemes(merged_str))

        return len(merged_graphemes) <= len(original_graphemes)

    def _encode_words(self, words: list[Word]) -> None:
        """Encode every word into grapheme tokens; log the first five as a sample."""
        logging.info('Encoding words with Tamil graphemes...')

        for idx, word in enumerate(words):
            word.encode(self.str2token)

            if idx < 5:
                logging.info(f'Word "{word.str}" -> tokens: {[tok.str for tok in word.tokens]}')
                if word.pairs:
                    preview = [(f"{left.str}+{right.str}", count) for (left, right), count in word.pairs.items()]
                    logging.info(f'  Pairs: {preview[:5]}')

            if idx > 0 and idx % 100000 == 0:
                logging.info(f'Processed {idx} words for grapheme encoding.')

        logging.info(f'Encoding complete.')

    def _initialize_pairs(self, words: list[Word]) -> MCounter:
        """Sum the adjacent-pair counts of all encoded words.

        Pairs that contain a special token are dropped.

        Returns:
            MCounter mapping (left, right) token pair -> frequency.
        """
        pairs = MCounter()
        logging.info('Counting Tamil grapheme pairs...')

        for idx, word in enumerate(words):
            pairs.update(word.pairs)
            if idx > 0 and idx % 100000 == 0:
                logging.info(f'Processed {idx} words for pair counting.')

        invalid = [candidate for candidate in pairs if not self._validate_pair(candidate)]
        for candidate in invalid:
            pairs.pop(candidate)

        logging.info(f'Final pair count: {len(pairs)} unique pairs, {sum(pairs.values())} total instances.')

        if pairs:
            best = pairs.most_common(10)
            logging.info(f'Top 10 pairs: {[(f"{p[0].str}+{p[1].str}", freq) for p, freq in best]}')

        return pairs

    def _remove_if_possible(self, token: Token, merged_freq: int, pairs: MCounter) -> bool:
        """Remove *token* from the vocabulary if it now occurs almost only inside a merge.

        ``_merge_token_in_words`` calls this right after *token* took part in
        a merge. By then it has already subtracted the merged occurrences from
        ``token.freq``, and ``merged_freq`` is the number of occurrences the
        merge absorbed.

        The token is removed when both conditions hold:

        - ``merged_freq / (token.freq + merged_freq)`` exceeds
          ``self.threshold``;
        - the token can be split, i.e. it is not atomic.

        On removal, every remaining occurrence is replaced in each word by
        ``token.walk()``. Token frequencies and the global pair counts are
        updated, and the token is marked as removed.

        Args:
            token: Token that was just merged into a larger token.
            merged_freq: Occurrences of *token* consumed by the merge.
            pairs: Global pair -> frequency counter; updated in place.

        Returns:
            True if the token was removed, False otherwise.

        Raises:
            ValueError: If a word registered in ``token.words`` does not
                contain *token*.
        """
        # Nothing to compare (and avoids dividing by zero).
        if token.freq + merged_freq == 0:
            return False

        if merged_freq / (token.freq + merged_freq) > self.threshold:
            # Marks the token absent and returns its constituents, or returns
            # None for atomic tokens, which are therefore never removed.
            split = token.split_if_possible()
            if split is not None:
                self.actual_vocab_size -= 1
                # The remaining standalone occurrences are credited to the constituents...
                for t in split:
                    t.freq += token.freq
                # ...and each adjacent pair inside the split gains one count per occurrence.
                for pair in zip(split[:-1], split[1:]):
                    pairs[pair] += token.freq

                # Collect each word's pairs that touch the token before the word
                # is rewritten; they are re-targeted onto the split's ends below.
                pairs_for_update = MCounter()
                for word in token.words:
                    if token not in word.tokens:
                        raise ValueError(f'Token {token} not found in word {word}.')
                    pairs_for_update.update({
                        pair: freq for pair, freq in word.pairs.items()
                        if self._validate_pair(pair) and token in pair
                    })
                    word.split_token(token, split)

                self._update_pairs_on_remove(token, split, pairs_for_update, pairs)
                token.remove()
                return True
        return False

    @staticmethod
    def _update_pairs_on_merge(new_token: Token, pair: tuple[Token, Token],
                              pairs_for_update: MCounter, pairs: MCounter):
        """update pair frequencies after merge operation"""
        pairs.update(pairs_for_update)
        for p, freq in pairs_for_update.items():
            if new_token not in p:
                raise ValueError(f'Pair {p} does not contain the new token {new_token}.')
            if new_token is p[0]:
                if new_token is p[1]:
                    to_update = (pair[1], pair[0])
                else:
                    to_update = (pair[1], p[1])
            else:
                to_update = (p[0], pair[0])
            if to_update in pairs:
                pairs[to_update] -= freq
                if pairs[to_update] <= 0:
                    pairs.pop(to_update)

    @staticmethod
    def _update_pairs_on_remove(token: Token, split: list[Token],
                               pairs_for_update: MCounter, pairs: MCounter):
        """update pair frequencies after split operation"""
        for pair, freq in pairs_for_update.items():
            if token is pair[0]:
                if token is pair[1]:
                    to_update = (split[-1], split[0])
                else:
                    to_update = (split[-1], pair[1])
            else:
                to_update = (pair[0], split[0])
            pairs[to_update] += freq
            pairs.pop(pair)

    def _merge_token_in_words(self, token_to_merge: Token, pair_to_merge: tuple[Token, Token],
                             pairs: MCounter) -> int:
        """Apply the merge ``pair_to_merge -> token_to_merge`` in every affected word.

        The method does four things:

        1. Rewrites the words that contain the pair.
        2. Corrects the global pair counts for the pairs the merge created and
           destroyed.
        3. Moves frequency from the two constituents to the new token.
        4. Offers each constituent to ``_remove_if_possible``. Every removal is
           logged and recorded as a ``('SPLIT', token, token.walk())`` event.

        Args:
            token_to_merge: Token that replaces each occurrence of the pair.
            pair_to_merge: The ``(left, right)`` pair being merged.
            pairs: Global pair -> frequency counter; updated in place.

        Returns:
            Number of pair occurrences merged, weighted by word frequency.

        Raises:
            ValueError: If a constituent of a non-repeated pair is no longer
                present in the vocabulary.
        """
        actual_freq = 0
        # Pairs that contain the new token after merging, summed over all words.
        pairs_for_update = MCounter()

        # Only words containing both constituents can contain the pair.
        for word in pair_to_merge[0].words & pair_to_merge[1].words:
            if pair_to_merge in word.pairs:
                word.pairs.pop(pair_to_merge)
                actual_freq += word.merge_pair(pair_to_merge, token_to_merge)
                pairs_for_update.update({
                    p: f for p, f in word.pairs.items()
                    if self._validate_pair(p) and token_to_merge in p
                })

        # Add the new pairs and subtract the neighbouring pairs they replaced.
        self._update_pairs_on_merge(token_to_merge, pair_to_merge, pairs_for_update, pairs)
        token_to_merge.freq += actual_freq

        if pair_to_merge[0] is pair_to_merge[1]:
            # (X, X) merge: every merged occurrence consumed two copies of X.
            pair_to_merge[0].freq -= 2 * actual_freq
            removed = self._remove_if_possible(pair_to_merge[0], actual_freq, pairs)
            if removed:
                logging.info(f'Removed token {pair_to_merge[0].str} after merging into {token_to_merge.str}.')
                self.events.append(('SPLIT', pair_to_merge[0], pair_to_merge[0].walk()))
        else:
            for token in pair_to_merge:
                if not token.present:
                    raise ValueError(f'Token {token} is not present in vocabulary.')
                token.freq -= actual_freq
                removed = self._remove_if_possible(token, actual_freq, pairs)
                if removed:
                    logging.info(f'Removed token {token.str} after merging into {token_to_merge.str}.')
                    self.events.append(('SPLIT', token, token.walk()))

        return actual_freq

    def _merge_pair(self, pair: tuple[Token, Token], pairs: MCounter) -> int:
        """merge a token pair with Tamil validation"""
        if not self._validate_tamil_merge(pair):
            return 0

        pairs.pop(pair)
        merged_str = pair[0].str + pair[1].str

        if merged_str in self.str2token:
            new_token = self.str2token[merged_str]
            if not new_token.present:
                new_token.restore()
                logging.info(f'Restored previously removed token {new_token.str}.')
            else:
                logging.info(f'Additional merges for {new_token.str}.')
        else:
            new_token = Token(self.new_id, merged_str, 0, left=pair[0], right=pair[1])
            self.id2token[new_token.id] = new_token
            self.str2token[new_token.str] = new_token
            self.new_id += 1

        self.events.append(('MERGE', pair, new_token))
        actual_freq = self._merge_token_in_words(new_token, pair, pairs)
        return actual_freq

    def fit(self, input_file: str, output_dir: str, logging_step: int = 200) -> None:
        """Train the grapheme-aware PickyBPE tokenizer and save it to *output_dir*.

        Training proceeds as follows:

        1. Load the words and extract and filter grapheme clusters.
        2. Build the base vocabulary, encode the words and count pairs.
        3. Repeatedly merge the most frequent pair until ``actual_vocab_size``
           reaches ``desired_vocab_size`` or no positive-count pair remains.
        4. Write ``grapheme_picky_bpe_model.json`` and the Hugging Face files.

        Pairs rejected by ``_validate_tamil_merge`` are dropped without being
        counted. ``actual_vocab_size`` grows only when a merge yields a new or
        restored token; a merge into an already-present token ("Additional
        merges") adds no vocabulary entry.

        Args:
            input_file: Path of the training corpus.
            output_dir: Output directory; created if needed.
            logging_step: Log progress whenever the vocabulary size is a
                multiple of this value.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        logging.info("Starting Tamil Grapheme-Aware PickyBPE training...")
        logging.info(f"Input: {input_file}")
        logging.info(f"Output: {output_dir}")
        logging.info(f"Target vocab size: {self.desired_vocab_size:,}")
        logging.info(f"Coverage: {self.coverage*100:.2f}%")
        logging.info(f"Threshold: {self.threshold*100:.2f}%")

        words = self._get_words(input_file)
        graphemes = self._extract_graphemes(words)
        filtered_graphemes = self._filter_graphemes(graphemes)
        self._initialize_vocab(filtered_graphemes)
        self._encode_words(words)
        pairs = self._initialize_pairs(words)

        merge_time = []
        logging.info(f'Starting BPE training with {self.actual_vocab_size} initial tokens.')

        while self.actual_vocab_size < self.desired_vocab_size:
            start_time = time.time()
            if not pairs:
                logging.info(f'No more pairs to merge. Stopping with vocab size of {self.actual_vocab_size}.')
                break

            pair, count = pairs.most_common(1)[0]
            if count <= 0:
                logging.info(f'No more pairs to merge. Stopping with vocab size of {self.actual_vocab_size}.')
                break

            # A rejected pair would otherwise be reselected forever while the
            # vocabulary counter kept growing without any merge happening.
            if not self._validate_tamil_merge(pair):
                pairs.pop(pair)
                continue

            # Decide before merging whether the result is a new vocabulary entry
            # (.get() does not create entries on the defaultdict).
            existing = self.str2token.get(pair[0].str + pair[1].str)
            adds_token = existing is None or not existing.present

            freq = self._merge_pair(pair, pairs)
            merge_time.append(time.time() - start_time)
            if not adds_token:
                continue
            self.actual_vocab_size += 1

            if self.actual_vocab_size % logging_step == 0:
                avg_time = np.mean(merge_time) if merge_time else 0
                current_speed = 1.0 / avg_time if avg_time > 0 else 0
                logging.info(
                    f'VOCABULARY SIZE: {self.actual_vocab_size:,}/{self.desired_vocab_size:,}. '
                    f'Merged "{pair[0].str}" + "{pair[1].str}" with frequency {freq:,}. '
                    f'Speed: {current_speed:.1f} merges/sec'
                )
                merge_time = []

        self._save_picky_model(output_path / 'grapheme_picky_bpe_model.json')
        self._save_huggingface_files(output_path)

        logging.info(f'Training completed. Files saved to {output_path}')

    def _save_picky_model(self, file_path: Path) -> None:
        logging.info(f'Saving Tamil Grapheme PickyBPE model to {file_path}...')

        assigned_ids = sorted(self.id2token.keys())
        id_mapping = {}
        id_counter = 0

        for i in assigned_ids:
            if self.id2token[i].present:
                id_mapping[i] = id_counter
                id_counter += 1

        model_data = {
            'language': 'tamil',
            'script': 'tamil',
            'algorithm': 'Grapheme-Aware PickyBPE',
            'tokens': [token.to_dict() for token in self.id2token.values()],
            'id2int': {str(k): v for k, v in id_mapping.items()},
            'int2id': {str(v): k for k, v in id_mapping.items()},
            'merges': [
                {'id': i, 'pair': [token.to_dict() for token in merge[1]], 'new_token': merge[2].to_dict()}
                for i, merge in enumerate(self.events) if merge[0] == 'MERGE'
            ],
            'splits': [
                {'id': i, 'token': merge[1].to_dict(), 'split': [token.to_dict() for token in merge[2]]}
                for i, merge in enumerate(self.events) if merge[0] == 'SPLIT'
            ],
            'training_config': {
                'coverage': self.coverage,
                'threshold': self.threshold,
                'vocab_size': self.desired_vocab_size
            },
            'grapheme_vocab': list(self.grapheme_vocab)
        }

        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(model_data, f, indent=2, ensure_ascii=False)

    def _save_huggingface_files(self, output_path: Path) -> None:
        """Export the trained vocabulary and merges as HuggingFace tokenizer files.

        Writes ``tokenizer.json``, ``vocab.json``, ``added_tokens.json``,
        ``special_tokens_map.json`` and ``tokenizer_config.json`` into
        ``output_path`` so the directory can be loaded with
        ``AutoTokenizer.from_pretrained``.

        Ids: the special tokens get 0-3 in the order UNK, BOS, EOS, PAD;
        present, non-special tokens follow in ``self.id2token`` iteration order
        (a string already in the vocab keeps its first id).

        Merges: PickyBPE can remove a token after creating it, and the
        ``tokenizers`` BPE loader fails on any merge whose left part, right part
        or merged result is missing from the vocab. Only merges whose three
        strings are all exported are written, each distinct pair once at its
        first rank. A token reachable only through a removed intermediate token
        therefore cannot be produced by the exported tokenizer.

        Args:
            output_path: Existing directory to write the files into.
        """
        special_tokens = [UNK, BOS, EOS, PAD]

        vocab = {}
        added_tokens = []
        for tok in special_tokens:
            vocab[tok] = len(vocab)
            added_tokens.append({
                "id": vocab[tok],
                "content": tok,
                "special": True,
                "single_word": False,
                "lstrip": False,
                "rstrip": False,
                "normalized": False,
            })

        for token in self.id2token.values():
            if token.present and not token.special and token.str not in vocab:
                vocab[token.str] = len(vocab)

        logging.info(f"Final vocab size: {len(vocab)}")

        merges = []
        seen_pairs = set()
        skipped = 0
        for event in self.events:
            if event[0] != 'MERGE':
                continue
            left, right = event[1]
            pair_key = (left.str, right.str)
            if pair_key in seen_pairs:
                # Same pair merged again (e.g. after a restore): keep first rank only.
                continue
            if (left.str in vocab and right.str in vocab
                    and left.str + right.str in vocab):
                seen_pairs.add(pair_key)
                merges.append(f"{left.str} {right.str}")
            else:
                skipped += 1

        logging.info(f"Number of merge rules: {len(merges)}")
        if skipped:
            logging.info(f"Skipped {skipped} merge rules referencing tokens absent from the final vocab.")

        # Serialized TemplateProcessing format: templates are lists of pieces
        # and the ":N" suffix of the string syntax is a *type id*, not a token id.
        def special_piece(tok, type_id):
            return {"SpecialToken": {"id": tok, "type_id": type_id}}

        def sequence_piece(seq, type_id):
            return {"Sequence": {"id": seq, "type_id": type_id}}

        post_processor = {
            "type": "TemplateProcessing",
            "single": [
                special_piece(BOS, 0),
                sequence_piece("A", 0),
                special_piece(EOS, 0),
            ],
            "pair": [
                special_piece(BOS, 0),
                sequence_piece("A", 0),
                special_piece(EOS, 0),
                sequence_piece("B", 1),
                special_piece(EOS, 1),
            ],
            "special_tokens": {
                BOS: {"id": BOS, "ids": [vocab[BOS]], "tokens": [BOS]},
                EOS: {"id": EOS, "ids": [vocab[EOS]], "tokens": [EOS]},
            },
        }

        tokenizer_data = {
            "version": "1.0",
            "truncation": None,
            "padding": None,
            "added_tokens": added_tokens,
            "normalizer": {"type": "NFC"},
            "pre_tokenizer": {
                "type": "Metaspace",
                "replacement": WHITESPACE,
                "add_prefix_space": True,
                "prepend_scheme": "always"
            },
            "post_processor": post_processor,
            "decoder": {
                "type": "Metaspace",
                "replacement": WHITESPACE,
                "add_prefix_space": True,
                "prepend_scheme": "always"
            },
            "model": {
                "type": "BPE",
                "unk_token": UNK,
                "vocab": vocab,
                "merges": merges
            }
        }

        (output_path / "tokenizer.json").write_text(
            json.dumps(tokenizer_data, indent=2, ensure_ascii=False), encoding="utf-8"
        )
        (output_path / "vocab.json").write_text(
            json.dumps(vocab, indent=2, ensure_ascii=False), encoding="utf-8"
        )
        # transformers reads added_tokens.json as a {token: id} mapping.
        (output_path / "added_tokens.json").write_text(
            json.dumps({tok: vocab[tok] for tok in special_tokens}, indent=2, ensure_ascii=False),
            encoding="utf-8"
        )
        (output_path / "special_tokens_map.json").write_text(
            json.dumps({"bos_token": BOS, "eos_token": EOS, "unk_token": UNK, "pad_token": PAD},
                       indent=2, ensure_ascii=False),
            encoding="utf-8"
        )
        # No "auto_map": that key points at custom tokenizer *classes*, and
        # PreTrainedTokenizerFast is resolved from "tokenizer_class" directly.
        (output_path / "tokenizer_config.json").write_text(
            json.dumps({
                "tokenizer_class": "PreTrainedTokenizerFast",
                "bos_token": BOS,
                "eos_token": EOS,
                "unk_token": UNK,
                "pad_token": PAD,
                "model_max_length": 2048,
                "padding_side": "left",
                "truncation_side": "right",
                # The Metaspace decoder already restores spaces exactly; the
                # English-oriented cleanup would alter text such as " ,".
                "clean_up_tokenization_spaces": False,
                "language": "tamil",
                "script": "tamil",
                "algorithm": "Grapheme-Aware PickyBPE"
            }, indent=2, ensure_ascii=False),
            encoding="utf-8"
        )

        logging.info("HuggingFace-compatible files saved.")

def train_tamil_grapheme_tokenizer(
    input_file: str,
    output_dir: str = "./tamil_grapheme_tokenizer",
    vocab_size: int = 10000,
    coverage: float = 0.9999,
    threshold: float = 0.9999,
    logging_step: int = 200
):
    """Build a ``TamilGraphemePickyBPE`` and train it on ``input_file``.

    Only the ``fit`` call is timed; construction is excluded. The elapsed
    wall-clock time and the output directory are printed afterwards.

    Args:
        input_file: Training corpus path, forwarded to ``fit``.
        output_dir: Directory where ``fit`` writes the model files.
        vocab_size: Target vocabulary size.
        coverage: Forwarded to the tokenizer constructor.
        threshold: Forwarded to the tokenizer constructor.
        logging_step: Progress-logging interval, forwarded to ``fit``.

    Returns:
        The trained tokenizer instance.
    """
    model = TamilGraphemePickyBPE(vocab_size=vocab_size, coverage=coverage, threshold=threshold)

    t0 = time.time()
    model.fit(input_file, output_dir, logging_step)
    elapsed = time.time() - t0

    print(f"\nTamil tokenizer training completed in {elapsed:.2f} seconds")
    print(f"Files saved to: {output_dir}")
    return model

def test_tamil_grapheme_tokenizer(tokenizer_path: str):
    """Load a saved tokenizer and check that sample Tamil phrases round-trip.

    For each sample phrase this prints:
    - the original text;
    - its tokens;
    - the text decoded from its ids (encoded without special tokens);
    - whether that decode equals the original exactly.

    It then prints an overall verdict.

    Args:
        tokenizer_path: Directory accepted by ``AutoTokenizer.from_pretrained``.
    """
    hf_tok = AutoTokenizer.from_pretrained(tokenizer_path)

    samples = (
        "வணக்கம்",
        "தமிழ்நாடு",
        "கணினி அறிவியல்",
        "வணக்கம், நான் தமிழ் பேசுகிறேன்",
    )

    print("\nTesting Tamil tokenizer:")
    print("=" * 60)

    mismatches = 0
    for text in samples:
        pieces = hf_tok.tokenize(text)
        ids = hf_tok.encode(text, add_special_tokens=False)
        round_trip = hf_tok.decode(ids)
        exact = text == round_trip
        if not exact:
            mismatches += 1

        print(f"Original: '{text}'")
        print(f"Tokens: {pieces}")
        print(f"Decoded: '{round_trip}'")
        print(f"Perfect: {exact}")
        print("-" * 40)

    verdict = (
        "All test phrases reconstructed perfectly"
        if mismatches == 0
        else "Some reconstruction issues detected"
    )
    print(verdict)

def main(
    input_file: str = "/content/drive/My Drive/Colab Notebooks/LRLs/tamil/dataset/ta_reduced_train.txt",
    output_dir: str = "/content/drive/My Drive/Colab Notebooks/LRLs/tamil/tokenizers/grapheme_picky_bpe",
    vocab_size: int = 10000,
    coverage: float = 0.9999,
    threshold: float = 0.9999,
    logging_step: int = 200,
):
    """Train the Tamil grapheme PickyBPE tokenizer and smoke-test the saved files.

    All settings are keyword parameters whose defaults are the original
    Colab/Drive configuration, so ``main()`` behaves as before. They can be
    overridden to train on another corpus without editing the function.

    Args:
        input_file: Training corpus; training is skipped if it does not exist.
        output_dir: Directory for the trained model and HuggingFace files.
        vocab_size: Target vocabulary size.
        coverage: Forwarded to the trainer.
        threshold: Forwarded to the trainer.
        logging_step: Progress-logging interval during training.

    Returns:
        The trained tokenizer, or ``None`` when ``input_file`` is missing.
    """
    if not Path(input_file).exists():
        print(f"Error: Input file not found: {input_file}")
        return None

    tokenizer = train_tamil_grapheme_tokenizer(
        input_file=input_file,
        output_dir=output_dir,
        vocab_size=vocab_size,
        coverage=coverage,
        threshold=threshold,
        logging_step=logging_step
    )

    test_tamil_grapheme_tokenizer(output_dir)
    return tokenizer

if __name__ == '__main__':
    main()