In [None]:
# Mount Google Drive so the notebook can read the corpus and write the
# trained tokenizer under /content/drive.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
!pip install gensim

In [None]:
import os
import json
import numpy as np
from pathlib import Path
from collections import defaultdict, Counter
from typing import List, Dict, Tuple
import tempfile
import time
from scipy.special import expit
import logging
import sys
import shutil
import matplotlib.pyplot as plt
import sentencepiece as spm
import gensim.models
import random
import unicodedata

# Route all log records to stdout (visible in the notebook output) with a
# timestamped format; force=True replaces any handlers installed earlier.
_console_handler = logging.StreamHandler(sys.stdout)

logging.basicConfig(
    force=True,
    handlers=[_console_handler],
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

class SaGeTokenizer:
    """Learn a Nepali subword vocabulary by SaGe-style iterative pruning.

    ``train`` builds an over-sized n-gram vocabulary from the first 90% of the
    corpus lines, then repeatedly:

    1. trains (or reuses) skip-gram embeddings for the current vocabulary with
       gensim ``Word2Vec``;
    2. scores every token with an ablation heuristic over context windows;
    3. removes a batch of the highest-scoring tokens,

    while tracking word fertility on held-out lines taken from the end of the
    corpus. The size of the final vocabulary is then used to train a
    SentencePiece unigram model (``convert_to_sentencepiece``).

    NOTE(review): SentencePiece learns its own pieces; only the size of the
    SaGe vocabulary (plus the auxiliary ``nepali_vocab.txt`` dump) carries over
    into the saved model.
    """

    # Tokens (besides every single-character token) that pruning never removes.
    _PROTECTED_TOKENS = ('<unk>', '<s>', '</s>', '<pad>', '▁', '।', '॥', '॰')

    def __init__(self,
                 vocab_size: int = 10000,
                 initial_vocab_multiplier: float = 2,
                 max_token_length: int = 20,
                 embedding_dim: int = 50,
                 window_size: int = 5,
                 negative_samples: int = 10,
                 min_token_freq: int = 12,
                 pruning_batch_size: int = 3500,
                 embedding_update_frequency: int = 5,
                 gensim_workers: int = 4,
                 gensim_epochs: int = 5,
                 gensim_min_count: int = 1,
                 fertility_target: float = 1.4,
                 fertility_tolerance: float = 0.08,
                 patience: int = 3,
                 min_improvement: float = 0.01):
        """Store hyper-parameters and initialise empty training state.

        Args:
            vocab_size: Target size of the pruned vocabulary.
            initial_vocab_multiplier: Initial vocabulary is
                ``int(vocab_size * initial_vocab_multiplier)`` tokens.
            max_token_length: Longest n-gram / token (in characters).
            embedding_dim: Word2Vec vector size.
            window_size: Skip-gram context window (each side).
            negative_samples: Negative samples for Word2Vec and loss estimates.
            min_token_freq: Minimum (weighted) n-gram count to enter the
                initial vocabulary.
            pruning_batch_size: Base number of tokens removed per iteration.
            embedding_update_frequency: Retrain embeddings every N iterations.
            gensim_workers / gensim_epochs / gensim_min_count: Word2Vec options.
            fertility_target / fertility_tolerance: Desired tokens-per-word.
            patience: Window of iterations inspected by ``should_stop_early``.
            min_improvement: Stored but not used by this class.
        """
        self.vocab_size = vocab_size
        self.initial_vocab_size = int(vocab_size * initial_vocab_multiplier)
        self.max_token_length = max_token_length
        self.embedding_dim = embedding_dim
        self.window_size = window_size
        self.negative_samples = negative_samples
        self.min_token_freq = min_token_freq
        self.pruning_batch_size = pruning_batch_size
        self.embedding_update_frequency = embedding_update_frequency

        self.gensim_workers = gensim_workers
        self.gensim_epochs = gensim_epochs
        self.gensim_min_count = gensim_min_count

        self.fertility_target = fertility_target
        self.fertility_tolerance = fertility_tolerance
        self.patience = patience
        self.min_improvement = min_improvement

        self.vocabulary = {}
        self.inv_vocabulary = {}
        self.token_frequencies = Counter()
        self.embeddings = None
        self.context_embeddings = None

        self.training_history = {
            'iteration': [],
            'vocab_size': [],
            'fertility': [],
            'skip_gram_loss': [],
            'ablation_loss': [],
            'tokens_per_char': [],
            'coverage': []
        }
        self.validation_lines = []
        self.total_lines = 0

        # Unicode Devanagari block U+0900..U+097F (inclusive).
        self.devanagari_range = range(0x0900, 0x097F + 1)

    @staticmethod
    def _invert_vocab(vocab: Dict[str, int]) -> Dict[int, str]:
        """Return the id -> token mapping for ``vocab``."""
        return {token_id: token for token, token_id in vocab.items()}

    def is_devanagari_char(self, char: str) -> bool:
        """Return True if ``char`` lies in the Devanagari Unicode block."""
        return ord(char) in self.devanagari_range

    def normalize_nepali_text(self, text: str) -> str:
        """Apply Unicode NFKC normalisation."""
        return unicodedata.normalize('NFKC', text)

    def preprocess_nepali_line(self, line: str) -> str:
        """Strip, NFKC-normalise and mark word starts with '▁'.

        Returns "" for blank lines; otherwise the line with a leading '▁' and
        every ASCII space replaced by '▁'.
        """
        line = line.strip()
        if not line:
            return ""

        line = self.normalize_nepali_text(line)
        line = '▁' + line.replace(' ', '▁')

        return line

    def count_corpus_lines(self, corpus_file: str) -> int:
        """Count lines in ``corpus_file`` (undecodable bytes are ignored)."""
        with open(corpus_file, 'r', encoding='utf-8', errors='ignore') as f:
            return sum(1 for _ in f)

    def load_validation_set(self, corpus_file: str, num_lines: int = 1000) -> List[str]:
        """Load preprocessed held-out lines from the end of the corpus.

        Starts at ``max(90% of total_lines, total_lines - num_lines)``, so the
        validation lines never overlap the first 90% used for training.
        Requires ``self.total_lines`` to be set.
        """
        logger.info(f"Loading validation lines...")

        validation_start = max(int(self.total_lines * 0.9), self.total_lines - num_lines)

        lines = []
        with open(corpus_file, 'r', encoding='utf-8', errors='ignore') as f:
            for i, line in enumerate(f):
                if i >= validation_start:
                    line = self.preprocess_nepali_line(line)
                    if line:
                        lines.append(line)

        logger.info(f"Loaded {len(lines)} validation lines (from line {validation_start})")
        return lines

    def tokenize_with_vocabulary(self, text: str, vocab: Dict[str, int]) -> List[int]:
        """Tokenize ``text`` greedily with longest-match against ``vocab``.

        The text is preprocessed first (see ``preprocess_nepali_line``). A
        character with no vocabulary match maps to the id of '<unk>' (or 0 if
        '<unk>' is absent).
        """
        text = self.preprocess_nepali_line(text)
        unk_id = vocab.get('<unk>', 0)
        text_len = len(text)
        tokens = []
        i = 0

        while i < text_len:
            for length in range(min(self.max_token_length, text_len - i), 0, -1):
                piece = text[i:i + length]
                if piece in vocab:
                    tokens.append(vocab[piece])
                    i += length
                    break
            else:
                # No prefix (not even the single char) is in the vocabulary.
                tokens.append(vocab.get(text[i], unk_id))
                i += 1

        return tokens

    def compute_fertility(self, vocab: Dict[str, int], validation_lines: List[str]) -> Dict[str, float]:
        """Measure how ``vocab`` tokenizes up to 100 validation lines.

        Returns a dict with ``fertility`` (tokens per whitespace word),
        ``tokens_per_char``, ``coverage`` (share of characters covered by
        non-<unk> tokens, capped at 1), ``avg_tokens_per_line`` and
        ``devanagari_ratio``.
        """
        # Map ids back through *this* vocabulary: ``self.inv_vocabulary`` can
        # belong to a different vocabulary (e.g. after reverting to best_vocab).
        inv_vocab = self._invert_vocab(vocab)
        sample = validation_lines[:100]

        total_words = 0
        total_tokens = 0
        total_chars = 0
        covered_chars = 0
        devanagari_chars = 0

        for line in sample:
            original_line = line.replace('▁', ' ').strip()
            words = original_line.split()
            if not words:
                continue
            total_words += len(words)
            total_chars += len(line)
            devanagari_chars += sum(1 for char in line if self.is_devanagari_char(char))

            tokens = self.tokenize_with_vocabulary(original_line, vocab)
            total_tokens += len(tokens)

            for token_id in tokens:
                token = inv_vocab.get(token_id, '<unk>')
                if token != '<unk>':
                    covered_chars += len(token)

        return {
            'fertility': total_tokens / max(total_words, 1),
            'tokens_per_char': total_tokens / max(total_chars, 1),
            'coverage': min(covered_chars / max(total_chars, 1), 1.0),
            'avg_tokens_per_line': total_tokens / max(len(sample), 1),
            'devanagari_ratio': devanagari_chars / max(total_chars, 1)
        }

    def compute_skip_gram_loss(self, token_ids: List[int], target_emb: np.ndarray, context_emb: np.ndarray) -> float:
        """Estimate the mean negative-sampling skip-gram loss for one sequence.

        Uses ``negative_samples`` uniformly random negatives per positive pair
        (non-deterministic). Returns 0.0 for sequences shorter than 2 tokens.
        """
        if len(token_ids) < 2:
            return 0.0

        loss = 0.0
        count = 0

        for i, target_id in enumerate(token_ids):
            if target_id >= len(target_emb):
                continue

            context_start = max(0, i - self.window_size)
            context_end = min(len(token_ids), i + self.window_size + 1)

            for j in range(context_start, context_end):
                if i == j:
                    continue

                context_id = token_ids[j]
                if context_id >= len(context_emb):
                    continue

                score = np.dot(target_emb[target_id], context_emb[context_id])
                loss -= np.log(expit(score) + 1e-10)
                count += 1

                for _ in range(self.negative_samples):
                    neg_id = np.random.randint(0, len(context_emb))
                    score = np.dot(target_emb[target_id], context_emb[neg_id])
                    loss -= np.log(1 - expit(score) + 1e-10)
                    count += 1

        return loss / max(count, 1)

    def initialize_vocabulary(self, corpus_file: str) -> Dict[str, int]:
        """Build the initial vocabulary from frequent n-grams.

        Counts every n-gram (length 1..max_token_length-1) of each preprocessed
        training line (first 90% of the corpus), weighting n-grams containing
        Devanagari characters by 2. Keeps special tokens, all chars 0..255 and
        the danda marks, then the most frequent n-grams with count >=
        ``min_token_freq`` up to ``initial_vocab_size`` tokens. Also resets and
        fills ``self.token_frequencies`` and sets ``self.vocabulary``.
        """
        logger.info(f"Initializing vocabulary with target size {self.initial_vocab_size}")
        # Start fresh so a repeated call does not keep stale frequencies.
        self.token_frequencies = Counter()

        ngram_counts = Counter()
        max_lines = int(self.total_lines * 0.9)

        with open(corpus_file, 'r', encoding='utf-8', errors='ignore') as f:
            for line_num, line in enumerate(f):
                if line_num >= max_lines:
                    break
                if line_num % 10000 == 0 and line_num > 0:
                    logger.info(f"Processing line {line_num}...")

                line = self.preprocess_nepali_line(line)
                if not line:
                    continue

                # dev_prefix[k] = number of Devanagari chars in line[:k], so the
                # "contains Devanagari" test per n-gram is O(1).
                dev_prefix = [0]
                for char in line:
                    dev_prefix.append(dev_prefix[-1] + (1 if self.is_devanagari_char(char) else 0))

                for n in range(1, min(len(line) + 1, self.max_token_length + 1)):
                    for i in range(len(line) - n + 1):
                        weight = 2 if dev_prefix[i + n] > dev_prefix[i] else 1
                        ngram_counts[line[i:i + n]] += weight

        filtered_ngrams = {
            ngram: count for ngram, count in ngram_counts.items()
            if count >= self.min_token_freq
        }

        logger.info(f"Found {len(filtered_ngrams)} n-grams with freq >= {self.min_token_freq}")

        essential_tokens = ['<unk>', '<s>', '</s>', '<pad>', '▁']
        essential_tokens.extend(chr(i) for i in range(256))
        essential_tokens.extend(['।', '॥', '॰'])

        vocabulary = {}
        token_id = 0

        for token in essential_tokens:
            vocabulary[token] = token_id
            self.token_frequencies[token] = ngram_counts.get(token, 1)
            token_id += 1

        sorted_ngrams = sorted(filtered_ngrams.items(), key=lambda x: x[1], reverse=True)

        for ngram, freq in sorted_ngrams:
            if ngram not in vocabulary:
                vocabulary[ngram] = token_id
                self.token_frequencies[ngram] = freq
                token_id += 1

                if len(vocabulary) >= self.initial_vocab_size:
                    break

        self.vocabulary = vocabulary
        self.inv_vocabulary = self._invert_vocab(vocabulary)

        logger.info(f"Initialized vocabulary with {len(vocabulary)} tokens")

        devanagari_tokens = sum(1 for token in vocabulary
                                if any(self.is_devanagari_char(c) for c in token))
        logger.info(f"Vocabulary contains {devanagari_tokens} tokens with Devanagari characters")

        return vocabulary

    def train_embeddings_with_gensim(self, corpus_file: str, vocab: Dict[str, int]) -> Tuple[np.ndarray, np.ndarray, float]:
        """Train skip-gram embeddings for ``vocab`` with gensim Word2Vec.

        The training lines (first 90% of the corpus) are tokenized with
        ``vocab`` and written, space-separated, to a temporary file that
        Word2Vec reads. Returns ``(target_embeddings, context_embeddings,
        avg_loss)``; both arrays have one row per vocabulary id. Tokens Word2Vec
        did not see keep small random vectors. ``avg_loss`` is estimated on the
        first 100 tokenized lines.
        """
        logger.info(f"Training embeddings with vocabulary of size {len(vocab)}")

        inv_vocab = self._invert_vocab(vocab)
        max_training_lines = int(self.total_lines * 0.9)
        loss_sample_lines = 100
        tokenized_lines = []

        temp_file = tempfile.NamedTemporaryFile(mode='w', delete=False,
                                                suffix='.txt', encoding='utf-8')
        try:
            with temp_file:
                with open(corpus_file, 'r', encoding='utf-8', errors='ignore') as f:
                    for line_num, line in enumerate(f):
                        if line_num >= max_training_lines:
                            break
                        if line_num % 10000 == 0 and line_num > 0:
                            logger.info(f"Tokenizing line {line_num} for Gensim...")

                        original_line = line.strip()
                        if not original_line:
                            continue

                        token_ids = self.tokenize_with_vocabulary(original_line, vocab)
                        token_strings = [inv_vocab.get(tid, '<unk>') for tid in token_ids]

                        if token_strings:
                            temp_file.write(' '.join(token_strings) + '\n')
                            if len(tokenized_lines) < loss_sample_lines:
                                tokenized_lines.append(token_ids)

            logger.info("Training Word2Vec model...")
            start_time = time.time()

            model = gensim.models.Word2Vec(
                corpus_file=temp_file.name,
                vector_size=self.embedding_dim,
                window=self.window_size,
                min_count=self.gensim_min_count,
                workers=self.gensim_workers,
                sg=1,
                negative=self.negative_samples,
                alpha=0.025,
                min_alpha=0.0001,
                epochs=self.gensim_epochs,
                seed=42
            )

            logger.info(f"Word2Vec training completed in {time.time() - start_time:.2f}s")
        finally:
            # Remove the temp corpus even if tokenization or training fails.
            os.unlink(temp_file.name)

        vocab_size = len(vocab)
        init_scale = 0.5 / self.embedding_dim
        target_embeddings = np.random.uniform(-init_scale, init_scale, (vocab_size, self.embedding_dim))
        context_embeddings = np.random.uniform(-init_scale, init_scale, (vocab_size, self.embedding_dim))

        # In gensim 4.x the negative-sampling output weights are stored on the
        # model (``model.syn1neg``), not on ``model.wv``; check both.
        output_weights = getattr(model, 'syn1neg', None)
        if output_weights is None:
            output_weights = getattr(model.wv, 'syn1neg', None)
        has_output_weights = output_weights is not None and len(output_weights) > 0

        found_embeddings = 0
        for token, token_id in vocab.items():
            if token in model.wv:
                target_embeddings[token_id] = model.wv[token]
                if has_output_weights:
                    context_embeddings[token_id] = output_weights[model.wv.key_to_index[token]]
                else:
                    context_embeddings[token_id] = model.wv[token]
                found_embeddings += 1

        logger.info(f"Found embeddings for {found_embeddings}/{vocab_size} tokens")

        total_loss = sum(
            self.compute_skip_gram_loss(token_ids, target_embeddings, context_embeddings)
            for token_ids in tokenized_lines
        )
        avg_loss = total_loss / max(len(tokenized_lines), 1)

        return target_embeddings, context_embeddings, avg_loss

    def compute_ablation_scores(self, corpus_file: str, vocab: Dict[str, int],
                                target_emb: np.ndarray, context_emb: np.ndarray,
                                sample_size: int = 20000) -> Tuple[Dict[str, float], float]:
        """Score each token by how little the corpus likelihood relies on it.

        Collects (token, context) pairs from up to ``sample_size`` non-empty
        training lines. For each token, compares the positive-pair log
        likelihood with a "without" estimate in which the token's embedding is
        replaced by the mean over the first five vocabulary ids. Higher scores
        mean "cheaper to remove".

        Returns ``(scores, avg_loss)``. Special tokens and single characters get
        ``-inf`` (never pruned). Tokens never seen get ``+inf``. ``avg_loss`` is
        the mean positive-pair loss ``-log sigma(t.c)``.
        """
        logger.info(f"Computing ablation scores (sample size: {sample_size})...")

        inv_vocab = self._invert_vocab(vocab)
        num_targets = len(target_emb)
        num_contexts = len(context_emb)

        # token id -> every context id observed within its window.
        contexts_by_id = defaultdict(list)

        max_lines = min(sample_size, int(self.total_lines * 0.9))

        with open(corpus_file, 'r', encoding='utf-8', errors='ignore') as f:
            lines_processed = 0

            for line in f:
                if lines_processed >= max_lines:
                    break

                original_line = line.strip()
                if not original_line:
                    continue

                token_ids = self.tokenize_with_vocabulary(original_line, vocab)
                seq_len = len(token_ids)

                for i, target_id in enumerate(token_ids):
                    if target_id not in inv_vocab or target_id >= num_targets:
                        continue
                    window_lo = max(0, i - self.window_size)
                    window_hi = min(seq_len, i + self.window_size + 1)
                    for j in range(window_lo, window_hi):
                        if j != i and token_ids[j] < num_contexts:
                            contexts_by_id[target_id].append(token_ids[j])

                lines_processed += 1

                if lines_processed % 5000 == 0:
                    logger.info(f"  Processed {lines_processed}/{max_lines} lines for ablation scores")

        # Sum of log sigma(t . c) over each token's observed pairs (vectorized).
        likelihood_with_by_id = {}
        total_loss = 0.0
        total_pairs = 0
        for target_id, context_ids in contexts_by_id.items():
            scores = context_emb[context_ids] @ target_emb[target_id]
            log_likelihood = float(np.log(expit(scores) + 1e-10).sum())
            likelihood_with_by_id[target_id] = log_likelihood
            total_loss -= log_likelihood
            total_pairs += len(context_ids)

        avg_loss = total_loss / max(total_pairs, 1)

        surrogate_ids = list(range(min(5, num_targets)))
        ablation_scores = {}

        for token, token_id in vocab.items():
            if token in self._PROTECTED_TOKENS or len(token) == 1:
                ablation_scores[token] = float('-inf')
                continue

            devanagari_ratio = sum(1 for c in token if self.is_devanagari_char(c)) / len(token)
            protection_factor = 0.5 if devanagari_ratio > 0.8 and len(token) > 3 else 1.0

            contexts = contexts_by_id.get(token_id, [])
            if not contexts:
                ablation_scores[token] = float('inf')
                continue

            likelihood_with = likelihood_with_by_id[token_id]

            alternatives = [tid for tid in surrogate_ids if tid != token_id]
            if alternatives:
                sample_contexts = random.sample(contexts, min(100, len(contexts)))
                likelihood_without = 0.0
                for context_id in sample_contexts:
                    similar_score = float(np.mean(target_emb[alternatives] @ context_emb[context_id]))
                    likelihood_without += np.log(expit(similar_score) + 1e-10)
                # Scale the sampled estimate up to the full context count.
                likelihood_without *= len(contexts) / len(sample_contexts)
            else:
                # No surrogate embeddings available: no evidence either way.
                likelihood_without = likelihood_with

            base_score = likelihood_without - likelihood_with
            # Protection must always *lower* removal priority. Multiplying a
            # negative score by 0.5 would raise it, so shrink by magnitude.
            ablation_scores[token] = base_score - abs(base_score) * (1.0 - protection_factor)

        logger.info(f"Computed ablation scores. Average loss: {avg_loss:.4f}")
        return ablation_scores, avg_loss

    def prune_vocabulary(self, vocab: Dict[str, int], ablation_scores: Dict[str, float],
                         num_to_remove: int) -> Dict[str, int]:
        """Remove up to ``num_to_remove`` of the highest-scoring tokens.

        Tokens scored ``-inf`` are never removed. Up to 20% of the batch may be
        skipped to protect mostly-Devanagari tokens with score < 1.0. A relaxed
        second pass runs if fewer than 80% of the target were selected.

        Returns a new vocabulary whose ids are re-numbered contiguously while
        preserving the relative order of the original ids, so special tokens
        keep their low ids.
        """
        scored_tokens = sorted(
            ((score, token) for token, score in ablation_scores.items()
             if score != float('-inf')),
            reverse=True,
        )

        tokens_to_remove = set()
        protected_skipped = 0

        for score, token in scored_tokens:
            if len(tokens_to_remove) >= num_to_remove:
                break

            is_devanagari_token = any(self.is_devanagari_char(c) for c in token)

            if is_devanagari_token and len(token) > 2:
                devanagari_ratio = sum(1 for c in token if self.is_devanagari_char(c)) / len(token)

                if devanagari_ratio > 0.8 and score < 1.0 and protected_skipped < num_to_remove * 0.2:
                    protected_skipped += 1
                    logger.debug(f"Protected Devanagari token: {token} (score: {score:.4f})")
                    continue

            tokens_to_remove.add(token)

        if len(tokens_to_remove) < num_to_remove * 0.8:
            logger.warning(f"Only selected {len(tokens_to_remove)} tokens for removal, need {num_to_remove}")
            logger.info("Reducing protection to meet pruning targets")

            for score, token in scored_tokens:
                if token in tokens_to_remove:
                    continue
                if len(tokens_to_remove) >= num_to_remove:
                    break

                is_devanagari_token = any(self.is_devanagari_char(c) for c in token)
                if not (is_devanagari_token and len(token) > 4 and score < 0.5):
                    tokens_to_remove.add(token)

        kept = sorted((token_id, token) for token, token_id in vocab.items()
                      if token not in tokens_to_remove)
        new_vocab = {token: new_id for new_id, (_, token) in enumerate(kept)}

        logger.info(f"Pruned vocabulary from {len(vocab)} to {len(new_vocab)} tokens")
        logger.info(f"Actually removed {len(vocab) - len(new_vocab)} tokens (target: {num_to_remove})")

        devanagari_removed = sum(1 for token in tokens_to_remove
                                 if any(self.is_devanagari_char(c) for c in token))
        total_removed = len(tokens_to_remove)
        logger.info(f"Removed {devanagari_removed}/{total_removed} tokens containing Devanagari characters")

        return new_vocab

    def should_stop_early(self) -> Tuple[bool, str]:
        """Decide from ``training_history`` whether pruning should stop.

        Returns ``(stop, reason)``. Stops when the vocabulary is within 5% of
        the target size; when fertility is within tolerance and the vocabulary
        is within 30% of the target; when fertility has risen strictly over the
        last ``patience`` iterations and is far above target; or when the size
        has stagnated over 5 iterations while still > 1.5x target.
        """
        if len(self.training_history['iteration']) < 2:
            return False, ""

        current_fertility = self.training_history['fertility'][-1]
        current_vocab_size = self.training_history['vocab_size'][-1]

        if current_vocab_size <= self.vocab_size * 1.05:
            return True, f"Close to target vocab size: {current_vocab_size} (target: {self.vocab_size})"

        if (abs(current_fertility - self.fertility_target) <= self.fertility_tolerance and
                current_vocab_size <= self.vocab_size * 1.3):
            return True, f"Reached target fertility: {current_fertility:.3f} and reasonable vocab size: {current_vocab_size}"

        if len(self.training_history['fertility']) >= self.patience:
            recent_fertilities = self.training_history['fertility'][-self.patience:]

            if all(recent_fertilities[i] > recent_fertilities[i - 1] + 0.02 for i in range(1, len(recent_fertilities))):
                if current_fertility > self.fertility_target + 0.6:
                    return True, f"Fertility increasing too much: {current_fertility:.3f}"

        if len(self.training_history['vocab_size']) >= 5:
            recent_vocab_sizes = self.training_history['vocab_size'][-5:]
            if (max(recent_vocab_sizes) - min(recent_vocab_sizes) < 100 and
                    current_vocab_size > self.vocab_size * 1.5):
                return True, f"Vocabulary size stagnant at {current_vocab_size} (target: {self.vocab_size})"

        return False, ""

    def train(self, corpus_file: str, output_dir: str) -> str:
        """Run the SaGe pruning loop on ``corpus_file`` and export the result.

        Writes ``nepali_training_history.json`` and the SentencePiece model
        files into ``output_dir``. Returns the value of
        ``convert_to_sentencepiece`` (a model path, or ``None`` on failure).
        """
        logger.info("=" * 70)
        logger.info("Starting SaGe tokenizer training for Nepali with Gensim embeddings")
        logger.info(f"Target vocabulary size: {self.vocab_size}")
        logger.info(f"Target fertility: {self.fertility_target} +/- {self.fertility_tolerance}")
        logger.info("=" * 70)

        self.total_lines = self.count_corpus_lines(corpus_file)
        logger.info(f"Corpus has {self.total_lines} lines")

        self.validation_lines = self.load_validation_set(corpus_file)

        current_vocab = self.initialize_vocabulary(corpus_file)

        initial_metrics = self.compute_fertility(current_vocab, self.validation_lines)
        logger.info(f"Initial metrics: Fertility={initial_metrics['fertility']:.3f}, "
                    f"Coverage={initial_metrics['coverage']:.3f}, "
                    f"Devanagari ratio={initial_metrics['devanagari_ratio']:.3f}")

        iteration = 0
        embeddings_trained = False
        best_fertility = float('inf')
        best_vocab = current_vocab.copy()
        # Guard the modulo below against a zero/negative update frequency.
        update_every = max(1, self.embedding_update_frequency)

        while len(current_vocab) > self.vocab_size:
            iteration += 1
            logger.info(f"\nIteration {iteration}")
            logger.info(f"Current vocabulary size: {len(current_vocab)}")

            if not embeddings_trained or iteration % update_every == 0:
                self.embeddings, self.context_embeddings, skip_gram_loss = \
                    self.train_embeddings_with_gensim(corpus_file, current_vocab)
                embeddings_trained = True
            else:
                skip_gram_loss = 0.0
                num_val_lines = min(100, len(self.validation_lines))
                for line in self.validation_lines[:num_val_lines]:
                    original_line = line.replace('▁', ' ').strip()
                    token_ids = self.tokenize_with_vocabulary(original_line, current_vocab)
                    skip_gram_loss += self.compute_skip_gram_loss(
                        token_ids, self.embeddings, self.context_embeddings
                    )
                skip_gram_loss /= max(num_val_lines, 1)

            ablation_scores, ablation_loss = self.compute_ablation_scores(
                corpus_file, current_vocab,
                self.embeddings, self.context_embeddings,
                sample_size=min(20000, self.total_lines)
            )

            metrics = self.compute_fertility(current_vocab, self.validation_lines)

            self.training_history['iteration'].append(iteration)
            self.training_history['vocab_size'].append(len(current_vocab))
            self.training_history['fertility'].append(metrics['fertility'])
            self.training_history['skip_gram_loss'].append(skip_gram_loss)
            self.training_history['ablation_loss'].append(ablation_loss)
            self.training_history['tokens_per_char'].append(metrics['tokens_per_char'])
            self.training_history['coverage'].append(metrics['coverage'])

            logger.info(f"Metrics: Fertility={metrics['fertility']:.3f}, "
                        f"Coverage={metrics['coverage']:.3f}, "
                        f"Devanagari ratio={metrics['devanagari_ratio']:.3f}, "
                        f"Skip-gram Loss={skip_gram_loss:.4f}")

            if abs(metrics['fertility'] - self.fertility_target) < \
                    abs(best_fertility - self.fertility_target):
                best_fertility = metrics['fertility']
                best_vocab = current_vocab.copy()
                logger.info(f"New best fertility for Nepali: {best_fertility:.3f}")

            should_stop, reason = self.should_stop_early()
            if should_stop:
                logger.info(f"\nEarly stopping: {reason}")
                # NOTE(review): this may restore a vocabulary larger than
                # vocab_size even when the stop reason is the size target.
                current_vocab = best_vocab
                break

            if metrics['fertility'] > self.fertility_target + 0.3:
                prune_multiplier = 1.5
            elif metrics['fertility'] < self.fertility_target - 0.2:
                prune_multiplier = 0.5
            else:
                prune_multiplier = 1.0

            tokens_to_remove = min(
                int(self.pruning_batch_size * prune_multiplier),
                len(current_vocab) - self.vocab_size
            )

            if len(current_vocab) > self.vocab_size and tokens_to_remove == 0:
                tokens_to_remove = 1

            previous_vocab = current_vocab
            current_vocab = self.prune_vocabulary(
                previous_vocab, ablation_scores, tokens_to_remove
            )

            # Pruning re-numbers ids. Carry the matching embedding rows over so
            # the cached embeddings stay aligned until they are retrained.
            kept_rows = [previous_vocab[token]
                         for token in sorted(current_vocab, key=current_vocab.get)]
            self.embeddings = self.embeddings[kept_rows]
            self.context_embeddings = self.context_embeddings[kept_rows]

            self.vocabulary = current_vocab
            self.inv_vocabulary = self._invert_vocab(current_vocab)

            logger.info(f"Pruned {tokens_to_remove} tokens. New size: {len(current_vocab)}")

        # Early stopping may have reverted to best_vocab; keep mappings in sync.
        self.vocabulary = current_vocab
        self.inv_vocabulary = self._invert_vocab(current_vocab)

        final_metrics = self.compute_fertility(current_vocab, self.validation_lines)
        logger.info(f"\nTraining completed")
        logger.info(f"Final vocabulary size: {len(current_vocab)}")
        logger.info(f"Final fertility: {final_metrics['fertility']:.3f}")
        logger.info(f"Final coverage: {final_metrics['coverage']:.3f}")
        logger.info(f"Final Devanagari ratio: {final_metrics['devanagari_ratio']:.3f}")

        final_devanagari_tokens = sum(1 for token in current_vocab
                                      if any(self.is_devanagari_char(c) for c in token))
        logger.info(f"Final vocabulary contains {final_devanagari_tokens} Devanagari tokens")

        os.makedirs(output_dir, exist_ok=True)
        history_file = os.path.join(output_dir, 'nepali_training_history.json')
        with open(history_file, 'w', encoding='utf-8') as f:
            json.dump(self.training_history, f, indent=2)

        return self.convert_to_sentencepiece(current_vocab, corpus_file, output_dir)

    def convert_to_sentencepiece(self, final_vocab: Dict[str, int],
                                 corpus_file: str, output_dir: str) -> str:
        """Train and save a SentencePiece unigram model sized to ``final_vocab``.

        Writes ``nepali_vocab.txt`` (token and log relative frequency, sorted
        by frequency). Trains SentencePiece on the NFKC-normalised corpus in a
        temporary directory, then copies ``nepali_tokenizer.model`` / ``.vocab``
        and HuggingFace config files into ``output_dir``.

        Returns the expected ``.model`` path (callers should check it exists),
        or ``None`` if SentencePiece training raised.
        """
        os.makedirs(output_dir, exist_ok=True)

        vocab_file = os.path.join(output_dir, "nepali_vocab.txt")
        total_freq = sum(self.token_frequencies.values())

        with open(vocab_file, 'w', encoding='utf-8') as f:
            for token in sorted(final_vocab.keys(),
                                key=lambda x: self.token_frequencies.get(x, 0),
                                reverse=True):
                freq = self.token_frequencies.get(token, 1)
                score = np.log(freq / max(total_freq, 1))
                f.write(f"{token}\t{score}\n")

        try:
            # Use the platform temp dir rather than a hard-coded '/tmp'.
            with tempfile.TemporaryDirectory(prefix='sage_nepali_') as temp_dir:
                temp_corpus = os.path.join(temp_dir, 'nepali_corpus.txt')

                # errors='ignore' matches every other corpus reader in this class.
                with open(corpus_file, 'r', encoding='utf-8', errors='ignore') as src, \
                        open(temp_corpus, 'w', encoding='utf-8') as dst:
                    for line_num, line in enumerate(src):
                        if line_num % 10000 == 0 and line_num > 0:
                            logger.info(f"Preprocessing line {line_num} for SentencePiece...")

                        processed_line = self.normalize_nepali_text(line.strip())
                        if processed_line:
                            dst.write(processed_line + '\n')

                temp_model_prefix = os.path.join(temp_dir, 'nepali_tokenizer')

                spm.SentencePieceTrainer.train(
                    input=temp_corpus,
                    model_prefix=temp_model_prefix,
                    vocab_size=len(final_vocab),
                    model_type='unigram',
                    character_coverage=0.9999,
                    normalization_rule_name='nfkc',
                    add_dummy_prefix=False,
                    unk_id=0,
                    bos_id=1,
                    eos_id=2,
                    pad_id=3,
                    input_sentence_size=min(100000, self.total_lines),
                    shuffle_input_sentence=True,
                    num_threads=16,
                    required_chars='।॥॰',
                    byte_fallback=True,
                    split_digits=True
                )

                for ext in ['.model', '.vocab']:
                    src_path = f"{temp_model_prefix}{ext}"
                    if os.path.exists(src_path):
                        shutil.copy(src_path, os.path.join(output_dir, f"nepali_tokenizer{ext}"))
        except Exception as e:
            # Best-effort export: log with traceback and signal failure via None.
            logger.exception(f"Error converting Nepali tokenizer to SentencePiece: {e}")
            return None

        self.create_huggingface_configs(output_dir, len(final_vocab))

        model_file = os.path.join(output_dir, "nepali_tokenizer.model")
        if os.path.exists(model_file):
            logger.info(f"Created Nepali SentencePiece model: {model_file}")
        else:
            logger.error("Failed to create Nepali SentencePiece model")

        return model_file

    def create_huggingface_configs(self, output_dir: str, vocab_size: int):
        """Write ``tokenizer_config.json`` and ``special_tokens_map.json``.

        ``vocab_size`` is accepted for interface compatibility but is not
        written to either file.
        """
        configs = {
            "tokenizer_config.json": {
                # NOTE(review): "sage" is presumably not a tokenizer class that
                # transformers' AutoTokenizer can resolve — confirm before use.
                "tokenizer_class": "sage",
                "model_max_length": 4096,
                "padding_side": "left",
                "bos_token": "<s>",
                "eos_token": "</s>",
                "unk_token": "<unk>",
                "pad_token": "<pad>",
                "add_bos_token": True,
                "add_eos_token": False,
                "clean_up_tokenization_spaces": False,
                "legacy": False,
                "sp_model_kwargs": {
                    "normalization_rule_name": "nfkc"
                }
            },
            "special_tokens_map.json": {
                "bos_token": "<s>",
                "eos_token": "</s>",
                "unk_token": "<unk>",
                "pad_token": "<pad>"
            }
        }

        for filename, config in configs.items():
            with open(os.path.join(output_dir, filename), 'w', encoding='utf-8') as f:
                json.dump(config, f, indent=2)

        logger.info("Created HuggingFace configuration files")


def run__sage(corpus_file: str, output_dir: str):
    """Train a SaGeTokenizer on ``corpus_file`` and log stats for the result.

    Args:
        corpus_file: Path to the UTF-8 training corpus (one sentence per line).
        output_dir: Directory receiving the SentencePiece model and configs.

    Returns:
        Whatever ``SaGeTokenizer.train`` returned: the model path, or ``None``
        if the SentencePiece conversion failed.
    """
    logger.info("Starting SaGe Tokenizer")

    tokenizer = SaGeTokenizer(
        vocab_size=10000,
        initial_vocab_multiplier=2,
        max_token_length=20,
        embedding_dim=50,
        window_size=5,
        negative_samples=10,
        min_token_freq=12,
        pruning_batch_size=3500,
        embedding_update_frequency=5,
        gensim_workers=4,
        gensim_epochs=5,
        gensim_min_count=1,
        fertility_target=1.4,
        fertility_tolerance=0.08,
        patience=3,
        min_improvement=0.01
    )

    model_file = tokenizer.train(corpus_file, output_dir)

    if model_file and os.path.exists(model_file):
        sp = spm.SentencePieceProcessor()
        sp.load(model_file)

        logger.info(f"\nFinal Statistics for Nepali:")
        logger.info(f"   Vocabulary size: {sp.vocab_size()}")

        history = tokenizer.training_history
        if history['iteration']:
            logger.info(f"   Training iterations: {len(history['iteration'])}")
            logger.info(f"   Final fertility: {history['fertility'][-1]:.3f}")
            logger.info(f"   Final coverage: {history['coverage'][-1]:.3f}")

        test_sentences = [
            "नमस्कार संसार",
            "नेपाली भाषा धेरै राम्रो छ",
            "नेपालको काठमाडौं उपत्यका",
            "हिमालयको सुन्दरता अतुलनीय छ",
            "धन्यवाद र नमस्ते"
        ]

        logger.info("\nNepali Tokenization Examples:")
        for sentence in test_sentences:
            tokens = sp.encode_as_pieces(sentence)
            logger.info(f"   '{sentence}' -> {tokens[:10]}..." if len(tokens) > 10 else f"   '{sentence}' -> {tokens}")

        # Reuse the tokenizer's own Devanagari test instead of a second range.
        devanagari_count = sum(
            1 for i in range(sp.vocab_size())
            if any(tokenizer.is_devanagari_char(c) for c in sp.id_to_piece(i))
        )

        logger.info(f"\nFinal vocabulary contains {devanagari_count} tokens with Devanagari characters")

    return model_file


# The script entry point calls ``run_sage``; keep both spellings working.
run_sage = run__sage


if __name__ == "__main__":
    # Paths assume Google Drive is mounted at /content/drive (Colab).
    corpus_file = "/content/drive/My Drive/Colab Notebooks/LRLs/nepali/dataset/ne_reduced_train.txt"
    output_dir = "/content/drive/My Drive/Colab Notebooks/LRLs/nepali/tokenizers/sage"

    # The runner is defined as ``run__sage``; ``run_sage`` was a NameError.
    model = run__sage(corpus_file, output_dir)

    # train() can return a path whose file was never created; check it exists.
    if model and os.path.exists(model):
        logger.info(f"\nModel saved to: {output_dir}")