In [None]:
# Colab-only setup: mount Google Drive at /content/drive so files stored on
# Drive can be read and written from this notebook.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# IPython shell escape (`!`), not plain Python: installs gensim, which
# provides the Word2Vec implementation imported in the next cell.
!pip install gensim

In [None]:
"""
SaGe tokenizer training with early stopping and adaptive pruning.
"""

import os
import json
import numpy as np
from pathlib import Path
from collections import defaultdict, Counter
from typing import List, Dict, Tuple
import tempfile
import time
import logging
import sys
import shutil
import matplotlib.pyplot as plt
import sentencepiece as spm
import gensim.models
import random
import re

def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)), element-wise.

    Accepts a scalar or array-like and returns a NumPy array (0-d for scalars).

    Both branches of ``np.where`` are always evaluated, so the previous form
    computed ``np.exp(-x)`` and ``np.exp(x)`` for every element and overflowed
    (RuntimeWarning, discarded inf/nan) whenever |x| was large. Using
    ``exp(-|x|)``, which lies in (0, 1], keeps both branches finite while giving
    exactly the same values: for x >= 0 it equals exp(-x), for x < 0 it equals
    exp(x).
    """
    x = np.asarray(x)
    decay = np.exp(-np.abs(x))
    return np.where(x >= 0, 1 / (1 + decay), decay / (1 + decay))

# Send INFO-and-above records to stdout with timestamps. `force=True` removes
# any handlers already attached to the root logger (e.g. by a notebook
# runtime) so this configuration is the one that takes effect.
logging.basicConfig(
    force=True,
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)],
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)

class YorubaMonitoredSaGeTokenizer:

    def __init__(self,
                 vocab_size: int = 10000,
                 initial_vocab_multiplier: float = 2.0,
                 max_token_length: int = 5,
                 embedding_dim: int = 50,
                 window_size: int = 5,
                 negative_samples: int = 10,
                 min_token_freq: int = 20,
                 pruning_batch_size: int = 800,
                 embedding_update_frequency: int = 2,
                 # gensim parameters
                 gensim_workers: int = 4,
                 gensim_epochs: int = 5,
                 gensim_min_count: int = 1,
                 # early stopping params
                 fertility_target: float = 1.4,
                 fertility_tolerance: float = 0.05,
                 patience: int = 8,
                 min_improvement: float = 0.005):

        self.vocab_size = vocab_size
        self.initial_vocab_multiplier = initial_vocab_multiplier
        self.max_token_length = max_token_length
        self.embedding_dim = embedding_dim
        self.window_size = window_size
        self.negative_samples = negative_samples
        self.min_token_freq = min_token_freq
        self.pruning_batch_size = pruning_batch_size
        self.embedding_update_frequency = embedding_update_frequency

        # gensim parameters
        self.gensim_workers = gensim_workers
        self.gensim_epochs = gensim_epochs
        self.gensim_min_count = gensim_min_count

        # early stopping parameters
        self.fertility_target = fertility_target
        self.fertility_tolerance = fertility_tolerance
        self.patience = patience
        self.min_improvement = min_improvement

        # data structures
        self.vocabulary = {}
        self.inv_vocabulary = {}
        self.token_frequencies = Counter()
        self.embeddings = None
        self.context_embeddings = None

        # yoruba-specific patterns
        self.yoruba_diacritics = set('áàéèíìóòúùṣẹọ́ẹ̀ọ́ọ̀ṣ́ṣ̀')
        self.yoruba_nasals = {'m', 'n', 'ṅ'}

        # monitoring
        self.training_history = {
            'iteration': [],
            'vocab_size': [],
            'fertility': [],
            'skip_gram_loss': [],
            'ablation_loss': [],
            'tokens_per_char': [],
            'coverage': [],
            'diacritic_preservation': []
        }
        self.validation_lines = []
        self.total_lines = 0
        self.initial_vocab_size = None

    def normalize_yoruba_text(self, text: str) -> str:
        text = text.strip()
        text = re.sub(r'\s+', ' ', text)
        return text

    def analyze_diacritic_preservation(self, vocab: Dict[str, int], validation_lines: List[str]) -> float:
        """calculate how well the vocabulary preserves tone marks"""
        diacritic_tokens = 0
        total_tokens_with_diacritics = 0

        for token in vocab.keys():
            if any(char in self.yoruba_diacritics for char in token):
                diacritic_tokens += 1

        for line in validation_lines[:100]:
            if any(char in self.yoruba_diacritics for char in line):
                total_tokens_with_diacritics += 1

        if total_tokens_with_diacritics == 0:
            return 1.0

        return min(diacritic_tokens / max(total_tokens_with_diacritics, 1), 1.0)

    def count_corpus_lines(self, corpus_file: str) -> int:
        total = 0
        with open(corpus_file, 'r', encoding='utf-8', errors='ignore') as f:
            for _ in f:
                total += 1
        return total

    def load_validation_set(self, corpus_file: str, num_lines: int = 1000) -> List[str]:
        validation_start = max(int(self.total_lines * 0.9), self.total_lines - num_lines)
        lines = []
        with open(corpus_file, 'r', encoding='utf-8', errors='ignore') as f:
            for i, line in enumerate(f):
                if i >= validation_start:
                    line = self.normalize_yoruba_text(line)
                    if line:
                        lines.append(line)

        logger.info(f"Loaded {len(lines)} validation lines (from line {validation_start})")
        return lines

    def initialize_vocabulary(self, corpus_file: str) -> Dict[str, int]:
        """initialize vocabulary with frequency-based selection"""
        target_based_size = int(self.vocab_size * self.initial_vocab_multiplier)
        corpus_based_size = self.total_lines // 4

        self.initial_vocab_size = min(
            max(target_based_size, self.vocab_size + 8000),
            target_based_size * 2
        )

        logger.info(f"Initializing vocabulary with size {self.initial_vocab_size} (target: {self.vocab_size})")

        if self.initial_vocab_size <= self.vocab_size:
            raise ValueError(f"Initial vocab size {self.initial_vocab_size} must be > target {self.vocab_size}")

        ngram_counts = Counter()
        max_lines = int(self.total_lines * 0.9)

        with open(corpus_file, 'r', encoding='utf-8', errors='ignore') as f:
            for line_num, line in enumerate(f):
                if line_num >= max_lines:
                    break
                if line_num % 10000 == 0 and line_num > 0:
                    logger.info(f"Processing line {line_num}...")

                line = self.normalize_yoruba_text(line)
                line = '▁' + line.replace(' ', '▁')

                for n in range(1, min(len(line) + 1, self.max_token_length + 1)):
                    for i in range(len(line) - n + 1):
                        ngram = line[i:i+n]
                        boost = 1.2 if any(char in self.yoruba_diacritics for char in ngram) else 1.0
                        ngram_counts[ngram] += boost

        current_threshold = self.min_token_freq
        while current_threshold > 5:
            filtered_ngrams = {
                ngram: count for ngram, count in ngram_counts.items()
                if count >= current_threshold
            }
            if len(filtered_ngrams) >= self.initial_vocab_size:
                break
            current_threshold -= 5
            logger.info(f"Lowering frequency threshold to {current_threshold} to get more tokens")

        logger.info(f"Found {len(filtered_ngrams)} n-grams with freq >= {current_threshold}")

        essential_tokens = ['<unk>', '<s>', '</s>', '<pad>', '▁']

        for i in range(256):
            try:
                essential_tokens.append(chr(i))
            except:
                pass

        essential_tokens.extend(list(self.yoruba_diacritics))

        # build vocabulary
        vocabulary = {}
        token_id = 0

        for token in essential_tokens:
            if token not in vocabulary:
                vocabulary[token] = token_id
                self.token_frequencies[token] = ngram_counts.get(token, 1)
                token_id += 1

        sorted_ngrams = sorted(filtered_ngrams.items(),
                              key=lambda x: (any(char in self.yoruba_diacritics for char in x[0]), x[1]),
                              reverse=True)

        for ngram, freq in sorted_ngrams:
            if ngram not in vocabulary:
                vocabulary[ngram] = token_id
                self.token_frequencies[ngram] = freq
                token_id += 1

                if len(vocabulary) >= self.initial_vocab_size:
                    break

        if len(vocabulary) <= self.vocab_size:
            all_ngrams = sorted(ngram_counts.items(), key=lambda x: -x[1])
            for ngram, freq in all_ngrams:
                if ngram not in vocabulary and freq >= 2:
                    vocabulary[ngram] = token_id
                    self.token_frequencies[ngram] = freq
                    token_id += 1
                    if len(vocabulary) >= self.initial_vocab_size:
                        break

        self.vocabulary = vocabulary
        self.inv_vocabulary = {v: k for k, v in vocabulary.items()}

        diacritic_tokens = sum(1 for token in vocabulary.keys()
                             if any(char in self.yoruba_diacritics for char in token))
        logger.info(f"Initialized vocabulary with {len(vocabulary)} tokens")

        return vocabulary

    def tokenize_with_vocabulary(self, text: str, vocab: Dict[str, int]) -> List[int]:
        text = self.normalize_yoruba_text(text)
        text = '▁' + text.replace(' ', '▁')
        tokens = []
        i = 0

        while i < len(text):
            matched = False
            for length in range(min(self.max_token_length, len(text) - i), 0, -1):
                substr = text[i:i+length]
                if substr in vocab:
                    tokens.append(vocab[substr])
                    i += length
                    matched = True
                    break

            if not matched:
                char = text[i]
                tokens.append(vocab.get(char, vocab.get('<unk>', 0)))
                i += 1

        return tokens

    def compute_fertility(self, vocab: Dict[str, int], validation_lines: List[str]) -> Dict[str, float]:

        total_words = 0
        total_tokens = 0
        total_chars = 0
        covered_chars = 0

        for line in validation_lines[:100]:
            words = line.split()
            if not words:
                continue
            total_words += len(words)

            processed_line = '▁' + line.replace(' ', '▁')
            total_chars += len(processed_line)

            tokens = self.tokenize_with_vocabulary(line, vocab)
            total_tokens += len(tokens)

            for token_id in tokens:
                token = self.inv_vocabulary.get(token_id, '<unk>')
                if token != '<unk>':
                    covered_chars += len(token)

        diacritic_preservation = self.analyze_diacritic_preservation(vocab, validation_lines)

        metrics = {
            'fertility': total_tokens / max(total_words, 1),
            'tokens_per_char': total_tokens / max(total_chars, 1),
            'coverage': min(covered_chars / max(total_chars, 1), 1.0),
            'avg_tokens_per_line': total_tokens / max(min(len(validation_lines), 100), 1),
            'diacritic_preservation': diacritic_preservation
        }

        return metrics

    def compute_skip_gram_loss(self, token_ids: List[int], target_emb: np.ndarray, context_emb: np.ndarray) -> float:
        """Estimate the negative-sampling skip-gram loss of one tokenized line.

        For every (target, context) pair within ``self.window_size`` positions,
        adds ``-log(sigmoid(t·c))`` plus, for each of ``self.negative_samples``
        context rows drawn uniformly with ``np.random.randint``,
        ``-log(1 - sigmoid(t·n))``. Ids outside the embedding tables are skipped.
        Returns the mean over all positive and negative terms, or 0.0 for lines
        with fewer than two tokens. Stochastic unless ``negative_samples`` is 0.
        """
        if len(token_ids) < 2:
            return 0.0

        n_targets = len(target_emb)
        n_contexts = len(context_emb)
        line_len = len(token_ids)
        total = 0.0
        terms = 0

        for center, center_id in enumerate(token_ids):
            if center_id >= n_targets:
                continue
            center_vec = target_emb[center_id]
            lo = max(0, center - self.window_size)
            hi = min(line_len, center + self.window_size + 1)

            for other in range(lo, hi):
                if other == center:
                    continue
                ctx_id = token_ids[other]
                if ctx_id >= n_contexts:
                    continue

                # observed (positive) pair
                total -= np.log(sigmoid(np.dot(center_vec, context_emb[ctx_id])) + 1e-10)
                terms += 1

                # uniformly drawn negatives
                for _ in range(self.negative_samples):
                    noise_id = np.random.randint(0, n_contexts)
                    total -= np.log(1 - sigmoid(np.dot(center_vec, context_emb[noise_id])) + 1e-10)
                    terms += 1

        return total / max(terms, 1)

    def train_embeddings_with_gensim(self, corpus_file: str, vocab: Dict[str, int]) -> Tuple[np.ndarray, np.ndarray, float]:
        """Train skip-gram (negative sampling) embeddings for ``vocab`` with Gensim.

        The first 90% of corpus lines are tokenized with ``vocab``. Ids are
        turned back into strings via ``self.inv_vocabulary``, which must be the
        inverse of ``vocab``. The strings are written space-separated, one line
        per corpus line, to a temporary file that Word2Vec reads through
        ``corpus_file=``. The temporary file is removed even if tokenization or
        training fails.

        Returns:
            ``(target_embeddings, context_embeddings, avg_loss)``. Both arrays
            are ``(len(vocab), self.embedding_dim)`` and indexed by token id.
            Tokens Word2Vec did not learn keep small uniform random vectors.
            ``avg_loss`` is ``compute_skip_gram_loss`` averaged over up to the
            first 100 tokenized lines.
        """
        logger.info(f"Training embeddings with Gensim for Yoruba vocabulary of size {len(vocab)}")

        max_training_lines = int(self.total_lines * 0.9)
        loss_sample_size = 100  # only these lines feed the loss estimate below
        tokenized_lines = []

        temp_file = tempfile.NamedTemporaryFile(mode='w', delete=False,
                                                suffix='.txt', encoding='utf-8')
        try:
            with temp_file, open(corpus_file, 'r', encoding='utf-8', errors='ignore') as f:
                for line_num, line in enumerate(f):
                    if line_num >= max_training_lines:
                        break
                    if line_num % 10000 == 0 and line_num > 0:
                        logger.info(f"Tokenizing line {line_num} for Gensim...")

                    line = self.normalize_yoruba_text(line)
                    token_ids = self.tokenize_with_vocabulary(line, vocab)
                    token_strings = [self.inv_vocabulary.get(tid, '<unk>') for tid in token_ids]

                    if token_strings:
                        temp_file.write(' '.join(token_strings) + '\n')
                        if len(tokenized_lines) < loss_sample_size:
                            tokenized_lines.append(token_ids)

            logger.info("Training Word2Vec model...")
            start_time = time.time()

            model = gensim.models.Word2Vec(
                corpus_file=temp_file.name,
                vector_size=self.embedding_dim,
                window=self.window_size,
                min_count=self.gensim_min_count,
                workers=self.gensim_workers,
                sg=1,
                negative=self.negative_samples,
                alpha=0.025,
                min_alpha=0.0001,
                epochs=self.gensim_epochs,
                seed=42
            )

            logger.info(f"Word2Vec training completed in {time.time() - start_time:.2f}s")
        finally:
            try:
                os.unlink(temp_file.name)
            except OSError as err:
                logger.warning(f"Could not remove temporary corpus {temp_file.name}: {err}")

        vocab_size = len(vocab)
        target_embeddings = np.random.uniform(-0.5/self.embedding_dim, 0.5/self.embedding_dim,
                                              (vocab_size, self.embedding_dim))
        context_embeddings = np.random.uniform(-0.5/self.embedding_dim, 0.5/self.embedding_dim,
                                               (vocab_size, self.embedding_dim))

        # gensim 4 stores the negative-sampling output weights on the model
        # itself (model.syn1neg), row-aligned with model.wv.key_to_index. The
        # previous lookup on model.wv never matched there, so context vectors
        # silently duplicated the target vectors. model.wv is kept as fallback.
        syn1neg = getattr(model, 'syn1neg', None)
        if syn1neg is None:
            syn1neg = getattr(model.wv, 'syn1neg', None)

        learned_tokens = 0
        for token, token_id in vocab.items():
            if token not in model.wv:
                continue
            target_embeddings[token_id] = model.wv[token]
            word_index = model.wv.key_to_index[token]
            if syn1neg is not None and word_index < len(syn1neg):
                context_embeddings[token_id] = syn1neg[word_index]
            else:
                context_embeddings[token_id] = model.wv[token]
            learned_tokens += 1

        logger.info(f"Learned embeddings for {learned_tokens}/{vocab_size} tokens "
                    f"({learned_tokens/max(vocab_size, 1)*100:.1f}%)")

        total_loss = 0.0
        for token_ids in tokenized_lines:
            total_loss += self.compute_skip_gram_loss(token_ids, target_embeddings, context_embeddings)

        avg_loss = total_loss / max(len(tokenized_lines), 1)

        return target_embeddings, context_embeddings, avg_loss

    def compute_ablation_scores(self, corpus_file: str, vocab: Dict[str, int], target_emb: np.ndarray,
                                context_emb: np.ndarray, sample_size: int = 15000) -> Tuple[Dict[str, float], float]:
        """Score every token by how little the skip-gram likelihood would suffer without it.

        Tokenizes up to ``min(sample_size, int(0.9 * self.total_lines))`` corpus
        lines with ``vocab`` and collects every (target_id, context_id) pair
        within ``self.window_size`` positions, grouped by target token string.

        Returns:
            ``(ablation_scores, avg_loss)``. ``ablation_scores`` maps each token to:
            ``-inf`` for protected tokens (special tokens, single characters,
            diacritic tokens with frequency > 300); ``+inf`` for tokens never
            seen as a target in the sample; otherwise
            ``likelihood_without - likelihood_with`` (see comments below).
            ``prune_vocabulary`` removes tokens starting from the highest score.
            ``avg_loss`` is the mean ``-log(sigmoid(t·c))`` over all positive pairs.

        NOTE(review): ``random.sample`` makes scores vary between runs unless the
        ``random`` module is seeded, and every collected context pair is held in
        memory at once.
        """
        logger.info(f"Computing ablation scores (sample size: {sample_size})...")

        # token string -> list of (target_id, context_id) pairs it appeared in
        token_contexts = defaultdict(list)
        total_loss = 0.0
        total_pairs = 0

        # stay inside the training portion (validation lines start at >= 90%)
        max_lines = min(sample_size, int(self.total_lines * 0.9))

        with open(corpus_file, 'r', encoding='utf-8', errors='ignore') as f:
            lines_processed = 0

            for line in f:
                if lines_processed >= max_lines:
                    break

                line = self.normalize_yoruba_text(line)
                token_ids = self.tokenize_with_vocabulary(line, vocab)

                for i in range(len(token_ids)):
                    target_id = token_ids[i]
                    # skip ids with no entry in the current id -> token map
                    if target_id >= len(self.inv_vocabulary):
                        continue

                    target_token = self.inv_vocabulary[target_id]

                    for j in range(max(0, i - self.window_size),
                                  min(len(token_ids), i + self.window_size + 1)):
                        if i != j:
                            context_id = token_ids[j]
                            if context_id >= len(context_emb):
                                continue

                            token_contexts[target_token].append((target_id, context_id))

                            # positive-pair loss, reported as avg_loss
                            score = np.dot(target_emb[target_id], context_emb[context_id])
                            total_loss -= np.log(sigmoid(score) + 1e-10)
                            total_pairs += 1

                lines_processed += 1

                if lines_processed % 5000 == 0:
                    logger.info(f"  Processed {lines_processed}/{max_lines} lines for ablation scores")

        avg_loss = total_loss / max(total_pairs, 1)

        ablation_scores = {}

        for token, token_id in vocab.items():
            is_essential = token in ['<unk>', '<s>', '</s>', '<pad>', '▁']
            is_single_char = len(token) == 1
            is_protected_diacritic = (
                any(char in self.yoruba_diacritics for char in token) and
                (len(token) == 1 or self.token_frequencies.get(token, 0) > 300)
            )

            # -inf marks tokens that must never be pruned
            if is_essential or is_single_char or is_protected_diacritic:
                ablation_scores[token] = float('-inf')
                continue

            # .get does not insert into the defaultdict
            contexts = token_contexts.get(token, [])

            # unseen in the sample: +inf puts it first in line for removal
            if not contexts:
                ablation_scores[token] = float('inf')
                continue

            # log-likelihood of all observed pairs with the token's own vector
            likelihood_with = 0.0
            for target_id, context_id in contexts:
                score = np.dot(target_emb[target_id], context_emb[context_id])
                likelihood_with += np.log(sigmoid(score) + 1e-10)

            # Counterfactual baseline on up to 100 sampled pairs: use the mean
            # dot product of target_emb rows 0..4 (excluding target_id) with the
            # same context vector, then scale up to the full pair count.
            # NOTE(review): which tokens rows 0-4 hold depends on how ids were
            # assigned — confirm this baseline is the intended one.
            likelihood_without = 0.0
            sample_contexts = random.sample(contexts, min(100, len(contexts)))
            for target_id, context_id in sample_contexts:
                similar_score = np.mean([
                    np.dot(target_emb[tid], context_emb[context_id])
                    for tid in range(min(5, len(target_emb)))
                    if tid != target_id
                ])
                likelihood_without += np.log(sigmoid(similar_score) + 1e-10)

            if sample_contexts:
                likelihood_without *= len(contexts) / len(sample_contexts)

            # higher score = removing the token costs less likelihood
            ablation_scores[token] = likelihood_without - likelihood_with

        logger.info(f"Computed ablation scores. Average loss: {avg_loss:.4f}")
        return ablation_scores, avg_loss

    def get_adaptive_pruning_size(self, metrics: Dict[str, float], current_vocab_size: int, base_pruning_size: int) -> int:
        """Get adaptive pruning size based on current metrics"""

        # start with base size
        pruning_multiplier = 1.0

        if metrics['fertility'] > self.fertility_target + 0.3:
            pruning_multiplier = 1.2
        elif metrics['fertility'] < self.fertility_target - 0.15:
            pruning_multiplier = 0.7

        if metrics['diacritic_preservation'] < 0.7:
            pruning_multiplier *= 0.6
        elif metrics['diacritic_preservation'] > 0.85:
            pruning_multiplier *= 1.1

        # slow down when approaching target
        remaining_to_target = current_vocab_size - self.vocab_size
        if remaining_to_target < 3000:
            pruning_multiplier *= 0.5
        elif remaining_to_target < 8000:
            pruning_multiplier *= 0.8

        tokens_to_remove = int(base_pruning_size * pruning_multiplier)

        # never remove more than what's needed to reach target
        tokens_to_remove = min(tokens_to_remove, remaining_to_target)

        # minimum removal when far from target to ensure progress
        if remaining_to_target > 10000 and tokens_to_remove < 100:
            tokens_to_remove = 100

        return tokens_to_remove

    def prune_vocabulary(self, vocab: Dict[str, int], ablation_scores: Dict[str, float], num_to_remove: int) -> Dict[str, int]:

        if num_to_remove <= 0:
            return vocab.copy()

        # reduce tokens to remove if close to target
        current_vocab_size = len(vocab)
        remaining_to_target = current_vocab_size - self.vocab_size

        if remaining_to_target < 5000:
            num_to_remove = min(num_to_remove, max(50, remaining_to_target // 10))

        scored_tokens = [(score, token) for token, score in ablation_scores.items()
                        if score != float('-inf')]
        scored_tokens.sort(reverse=True)

        tokens_to_remove = set()

        for score, token in scored_tokens:
            if len(tokens_to_remove) >= num_to_remove:
                break

            is_essential = token in ['<unk>', '<s>', '</s>', '<pad>', '▁']
            is_single_char = len(token) == 1
            is_high_value_diacritic = (
                any(char in self.yoruba_diacritics for char in token) and
                (len(token) == 1 or self.token_frequencies.get(token, 0) > 200)
            )

            if not (is_essential or is_single_char or is_high_value_diacritic):
                tokens_to_remove.add(token)

        # if still not enough tokens, be more selective rather than aggressive
        if len(tokens_to_remove) < num_to_remove:
            remaining_needed = num_to_remove - len(tokens_to_remove)
            logger.info(f"Only found {len(tokens_to_remove)} safe tokens to remove out of {num_to_remove} needed")

            # only remove a fraction of what's still needed to avoid being too aggressive
            additional_to_remove = min(remaining_needed, remaining_needed // 2)

            for score, token in scored_tokens:
                if len(tokens_to_remove) >= len(tokens_to_remove) + additional_to_remove:
                    break
                if token not in tokens_to_remove:
                    is_absolutely_essential = (
                        token in ['<unk>', '<s>', '</s>', '<pad>', '▁'] or
                        (len(token) == 1 and token in self.yoruba_diacritics) or
                        self.token_frequencies.get(token, 0) > 500
                    )
                    if not is_absolutely_essential:
                        tokens_to_remove.add(token)

        new_vocab = {}
        new_id = 0

        for token in sorted(vocab.keys()):
            if token not in tokens_to_remove:
                new_vocab[token] = new_id
                new_id += 1

        logger.info(f"Pruned vocabulary from {len(vocab)} to {len(new_vocab)} tokens")

        if len(new_vocab) == 0:
            logger.error("Created empty vocabulary. Returning original vocabulary.")
            return vocab.copy()

        return new_vocab

    def should_stop_early(self) -> Tuple[bool, str]:

        if len(self.training_history['iteration']) < 2:
            return False, ""

        current_fertility = self.training_history['fertility'][-1]
        current_vocab_size = self.training_history['vocab_size'][-1]
        current_diacritic = self.training_history['diacritic_preservation'][-1]
        current_iteration = len(self.training_history['iteration'])

        # add minimum iteration requirement - don't stop too early
        MIN_ITERATIONS = 8
        if current_iteration < MIN_ITERATIONS:
            return False, ""

        # primary goal is vocabulary size - only stop when much closer to target
        if current_vocab_size <= self.vocab_size * 1.02:
            return True, f"Very close to target vocab size: {current_vocab_size} (target: {self.vocab_size})"

        # only stop on fertility if also reasonably close to vocab target
        if (abs(current_fertility - self.fertility_target) <= self.fertility_tolerance and
            current_vocab_size <= self.vocab_size * 1.15):
            if current_diacritic > 0.7:
                return True, f"Reached target fertility: {current_fertility:.3f} with reasonable vocab size"

        PATIENCE = 8
        if len(self.training_history['fertility']) >= PATIENCE:
            recent_fertilities = self.training_history['fertility'][-PATIENCE:]

            # stop if fertility is getting much worse AND not close to vocab target
            if all(recent_fertilities[i] > recent_fertilities[i-1] + 0.015 for i in range(1, len(recent_fertilities))):
                if current_fertility > self.fertility_target + 0.4 and current_vocab_size > self.vocab_size * 1.3:
                    return True, f"Fertility increasing too much: {current_fertility:.3f} and vocab still large"

            # stop on no improvement if very close to target vocab size
            fertility_improvement = abs(recent_fertilities[0] - recent_fertilities[-1])
            if fertility_improvement < self.min_improvement:
                if current_vocab_size <= self.vocab_size * 1.05:
                    return True, f"No fertility improvement in {PATIENCE} iterations and close to target"

        MAX_ITERATIONS = 50
        if current_iteration >= MAX_ITERATIONS:
            return True, f"Reached maximum iterations ({MAX_ITERATIONS})"

        return False, ""

    def train(self, corpus_file: str, output_dir: str) -> str:
        """Run the SaGe pruning loop on ``corpus_file`` and export the tokenizer.

        After counting lines, loading the validation tail and building the
        initial vocabulary, each iteration:
        1. retrains embeddings (first iteration and every
           ``embedding_update_frequency`` iterations) or reuses the cached ones;
        2. computes ablation scores and validation metrics;
        3. appends the metrics to ``self.training_history``;
        4. checks ``should_stop_early``;
        5. prunes.

        The loop ends when the vocabulary reaches ``self.vocab_size``, an
        early-stop rule fires, or a pruning step removes nothing.

        Writes ``yoruba_training_history.json`` into ``output_dir`` and returns
        what ``convert_to_sentencepiece`` returns (annotated as ``str``).
        """
        logger.info("=" * 70)
        logger.info("Starting SaGe tokenizer training")
        logger.info(f"Target vocabulary size: {self.vocab_size}")
        logger.info(f"Initial vocabulary multiplier: {self.initial_vocab_multiplier}")
        logger.info(f"Target fertility: {self.fertility_target} ± {self.fertility_tolerance}")
        logger.info("=" * 70)

        self.total_lines = self.count_corpus_lines(corpus_file)
        logger.info(f"Yoruba corpus has {self.total_lines} lines")

        self.validation_lines = self.load_validation_set(corpus_file)

        current_vocab = self.initialize_vocabulary(corpus_file)

        initial_metrics = self.compute_fertility(current_vocab, self.validation_lines)
        logger.info(f"Initial metrics: Fertility={initial_metrics['fertility']:.3f}, "
                    f"Coverage={initial_metrics['coverage']:.3f}, "
                    f"Diacritic preservation={initial_metrics['diacritic_preservation']:.3f}")

        # training loop
        iteration = 0
        embeddings_trained = False
        best_fertility_distance = float('inf')  # tracked for logging only
        update_every = max(1, self.embedding_update_frequency)  # avoid modulo by zero

        while len(current_vocab) > self.vocab_size:
            iteration += 1
            logger.info(f"\n--- Yoruba Iteration {iteration} ---")
            logger.info(f"Current vocabulary size: {len(current_vocab)}")

            # train embeddings with Gensim, or reuse the cached (re-indexed) ones
            if not embeddings_trained or iteration % update_every == 0:
                self.embeddings, self.context_embeddings, skip_gram_loss = \
                    self.train_embeddings_with_gensim(corpus_file, current_vocab)
                embeddings_trained = True
            else:
                skip_gram_loss = 0.0
                num_val_lines = min(100, len(self.validation_lines))
                for line in self.validation_lines[:num_val_lines]:
                    token_ids = self.tokenize_with_vocabulary(line, current_vocab)
                    skip_gram_loss += self.compute_skip_gram_loss(
                        token_ids, self.embeddings, self.context_embeddings
                    )
                skip_gram_loss /= max(num_val_lines, 1)

            # Compute ablation scores
            ablation_scores, ablation_loss = self.compute_ablation_scores(
                corpus_file, current_vocab,
                self.embeddings, self.context_embeddings,
                sample_size=min(15000, self.total_lines)
            )

            metrics = self.compute_fertility(current_vocab, self.validation_lines)

            # plain floats keep the history JSON-serializable
            history = self.training_history
            history['iteration'].append(iteration)
            history['vocab_size'].append(len(current_vocab))
            history['fertility'].append(float(metrics['fertility']))
            history['skip_gram_loss'].append(float(skip_gram_loss))
            history['ablation_loss'].append(float(ablation_loss))
            history['tokens_per_char'].append(float(metrics['tokens_per_char']))
            history['coverage'].append(float(metrics['coverage']))
            history['diacritic_preservation'].append(float(metrics['diacritic_preservation']))

            logger.info(f"Metrics: Fertility={metrics['fertility']:.3f}, "
                        f"Coverage={metrics['coverage']:.3f}, "
                        f"Diacritic preservation={metrics['diacritic_preservation']:.3f}, "
                        f"Skip-gram Loss={skip_gram_loss:.4f}")

            current_fertility_distance = abs(metrics['fertility'] - self.fertility_target)

            if (current_fertility_distance < best_fertility_distance or
                    (current_fertility_distance == best_fertility_distance and
                     metrics['diacritic_preservation'] > 0.8)):
                best_fertility_distance = current_fertility_distance
                logger.info(f"New best fertility distance: {current_fertility_distance:.3f}")

            should_stop, reason = self.should_stop_early()
            if should_stop:
                logger.info(f"\nEarly stopping: {reason}")
                break

            tokens_to_remove = self.get_adaptive_pruning_size(
                metrics, len(current_vocab), self.pruning_batch_size
            )

            max_removable = len(current_vocab) - self.vocab_size
            if max_removable <= 0:
                logger.info("Already at or below target vocabulary size")
                break

            tokens_to_remove = min(tokens_to_remove, max_removable)

            if tokens_to_remove == 0 and max_removable > 1000:
                tokens_to_remove = min(50, max_removable)
                logger.info(f"Forcing minimum progress: removing {tokens_to_remove} tokens")

            # prune vocabulary
            previous_vocab = current_vocab
            current_vocab = self.prune_vocabulary(
                previous_vocab, ablation_scores, tokens_to_remove
            )
            removed = len(previous_vocab) - len(current_vocab)

            # prune_vocabulary renumbers ids; re-order the cached embedding rows
            # so iterations that skip retraining index the right vectors.
            self._remap_embeddings(previous_vocab, current_vocab)

            # update internal state
            self.vocabulary = current_vocab
            self.inv_vocabulary = {v: k for k, v in current_vocab.items()}

            logger.info(f"Pruned {removed} tokens. New size: {len(current_vocab)}")
            logger.info(f"Remaining to target: {len(current_vocab) - self.vocab_size}")

            if removed == 0:
                # Pruning eligibility depends on fixed token properties, so a step
                # that removes nothing would just repeat until MAX_ITERATIONS.
                logger.warning("No tokens could be pruned; stopping to avoid a stalled loop")
                break

        final_metrics = self.compute_fertility(current_vocab, self.validation_lines)
        logger.info(f"\nYoruba tokenizer training completed!")
        logger.info(f"Training iterations: {iteration}")
        logger.info(f"Final vocabulary size: {len(current_vocab)}")
        logger.info(f"Final fertility: {final_metrics['fertility']:.3f}")
        logger.info(f"Final coverage: {final_metrics['coverage']:.3f}")
        logger.info(f"Final diacritic preservation: {final_metrics['diacritic_preservation']:.3f}")

        os.makedirs(output_dir, exist_ok=True)
        history_file = os.path.join(output_dir, 'yoruba_training_history.json')
        with open(history_file, 'w', encoding='utf-8') as f:
            json.dump(self.training_history, f, indent=2)

        model_file = self.convert_to_sentencepiece(current_vocab, corpus_file, output_dir)

        return model_file

    def _remap_embeddings(self, old_vocab: Dict[str, int], new_vocab: Dict[str, int]) -> None:
        """Re-index cached embedding rows from ``old_vocab`` ids to ``new_vocab`` ids.

        Assumes ``new_vocab`` ids are exactly 0..n-1 and every token in it also
        appears in ``old_vocab``; ``prune_vocabulary`` output satisfies both.
        Does nothing when no embeddings are cached yet.
        """
        if self.embeddings is None or self.context_embeddings is None:
            return
        old_rows = [old_vocab[token] for token in sorted(new_vocab, key=new_vocab.__getitem__)]
        self.embeddings = self.embeddings[old_rows]
        self.context_embeddings = self.context_embeddings[old_rows]

    def convert_to_sentencepiece(self, final_vocab: Dict[str, int], corpus_file: str, output_dir: str) -> "str | None":
        """Export the pruned vocabulary and train a SentencePiece unigram model.

        1. If ``final_vocab`` has more than ``max(8000, self.total_lines // 8)``
           entries, keep only the most frequent tokens (ranked by
           ``self.token_frequencies``) and re-number them from 0.
        2. Write ``yoruba_vocab.txt`` to ``output_dir`` as ``token<TAB>score``
           lines, where score is the log relative frequency. Tokens containing a
           character from ``self.yoruba_diacritics`` are listed first and get a
           +0.05 bonus.
        3. Normalize ``corpus_file`` line by line with
           ``self.normalize_yoruba_text`` into a temporary file. Train a unigram
           SentencePiece model on it and copy ``yoruba_tokenizer.model`` /
           ``yoruba_tokenizer.vocab`` into ``output_dir``.
        4. Write the Hugging Face config files via ``create_huggingface_configs``.

        NOTE(review): the SentencePiece trainer learns its own pieces from the
        normalized corpus. ``yoruba_vocab.txt`` is written alongside, but it is
        not passed to the trainer.

        Args:
            final_vocab: Token -> id mapping produced by SaGe pruning.
            corpus_file: Path to the raw UTF-8 training corpus.
            output_dir: Directory for all output artifacts (created if missing).

        Returns:
            Path to ``<output_dir>/yoruba_tokenizer.model``, or ``None`` in any of
            these cases:
              - the normalized corpus is empty;
              - SentencePiece training raised;
              - no model file was produced.
        """
        logger.info("\nConverting tokenizer to SentencePiece format...")

        max_reasonable_vocab = min(len(final_vocab), max(8000, self.total_lines // 8))

        if len(final_vocab) > max_reasonable_vocab:
            logger.warning(f"Vocabulary size {len(final_vocab)} too large for corpus, reducing to {max_reasonable_vocab}")
            by_frequency = sorted(final_vocab.keys(), key=lambda tok: self.token_frequencies.get(tok, 0), reverse=True)
            final_vocab = {tok: i for i, tok in enumerate(by_frequency[:max_reasonable_vocab])}
            logger.info(f"Reduced vocabulary to {len(final_vocab)} tokens")

        os.makedirs(output_dir, exist_ok=True)

        def has_diacritic(token: str) -> bool:
            return any(char in self.yoruba_diacritics for char in token)

        vocab_file = os.path.join(output_dir, "yoruba_vocab.txt")
        total_freq = max(sum(self.token_frequencies.values()), 1)
        # Diacritic-bearing tokens first, then by descending frequency.
        ordered_tokens = sorted(final_vocab.keys(),
                                key=lambda tok: (not has_diacritic(tok), -self.token_frequencies.get(tok, 0)))

        with open(vocab_file, 'w', encoding='utf-8') as f:
            for token in ordered_tokens:
                # Floor the count at 1 so a stored zero cannot produce log(0) = -inf.
                freq = max(self.token_frequencies.get(token, 1), 1)
                score = np.log(freq / total_freq)
                if has_diacritic(token):
                    score += 0.05
                f.write(f"{token}\t{score}\n")

        # Characters that must stay in SentencePiece's character set regardless of
        # character_coverage. They are passed as required_chars rather than
        # user_defined_symbols because user-defined symbols are always split out
        # as standalone pieces. That would prevent the model from learning any
        # subword containing ẹ, ọ, ṣ or a tone mark.
        # U+0300/U+0301 are the combining grave/acute, used where no precomposed
        # character exists (e.g. ẹ̀, ọ́).
        required_chars = 'áàéèíìóòúùṣẹọ\u0300\u0301'

        try:
            # Use the platform's default temp location instead of a hard-coded '/tmp'.
            with tempfile.TemporaryDirectory(prefix='sage_yoruba_') as temp_dir:
                temp_corpus = os.path.join(temp_dir, 'yoruba_corpus.txt')

                line_count = 0
                with open(corpus_file, 'r', encoding='utf-8') as src, \
                     open(temp_corpus, 'w', encoding='utf-8') as dst:
                    for line in src:
                        normalized = self.normalize_yoruba_text(line)
                        if normalized:
                            dst.write(normalized + '\n')
                            line_count += 1

                if line_count == 0:
                    logger.error(f"No usable lines in {corpus_file}; cannot train SentencePiece")
                    return None

                temp_model_prefix = os.path.join(temp_dir, 'yoruba_tokenizer')

                adaptive_vocab_size = min(len(final_vocab), max(6000, line_count // 10))
                logger.info(f"Using adaptive vocabulary size: {adaptive_vocab_size}")

                spm.SentencePieceTrainer.train(
                    input=temp_corpus,
                    model_prefix=temp_model_prefix,
                    vocab_size=adaptive_vocab_size,
                    model_type='unigram',
                    character_coverage=0.9995,
                    normalization_rule_name='nmt_nfkc_cf',
                    add_dummy_prefix=False,
                    unk_id=0,
                    bos_id=1,
                    eos_id=2,
                    pad_id=3,
                    input_sentence_size=min(20000, line_count),
                    shuffle_input_sentence=True,
                    num_threads=8,
                    # Treat vocab_size as a soft cap. A small corpus then yields a
                    # smaller vocabulary instead of aborting training.
                    hard_vocab_limit=False,
                    required_chars=required_chars,
                )

                for ext in ['.model', '.vocab']:
                    src_path = f"{temp_model_prefix}{ext}"
                    if os.path.exists(src_path):
                        shutil.copy(src_path, os.path.join(output_dir, f"yoruba_tokenizer{ext}"))

        except Exception as e:
            # logger.exception keeps the traceback, which the bare message loses.
            logger.exception(f"Error converting tokenizer to SentencePiece: {e}")
            return None

        model_file = os.path.join(output_dir, "yoruba_tokenizer.model")
        if not os.path.exists(model_file):
            logger.error("Failed to create SentencePiece model")
            return None

        self.create_huggingface_configs(output_dir, adaptive_vocab_size)
        logger.info(f"Created SentencePiece model: {model_file}")
        return model_file

    def create_huggingface_configs(self, output_dir: str, vocab_size: int):
        """Write Hugging Face tokenizer config files into ``output_dir``.

        Writes ``tokenizer_config.json``, which declares a ``LlamaTokenizer``
        using the <s>/</s>/<unk>/<pad> special tokens. Also writes
        ``special_tokens_map.json``.

        ``LlamaTokenizer`` loads its SentencePiece model from a file named
        ``tokenizer.model``. If ``yoruba_tokenizer.model`` exists in
        ``output_dir``, it is therefore copied to ``tokenizer.model``, so that
        ``from_pretrained(output_dir)`` can find it.

        Args:
            output_dir: Directory to write into (created if missing).
            vocab_size: Currently unused; kept for interface compatibility.
        """
        os.makedirs(output_dir, exist_ok=True)

        configs = {
            "tokenizer_config.json": {
                "tokenizer_class": "LlamaTokenizer",
                "model_max_length": 4096,
                "padding_side": "left",
                "bos_token": "<s>",
                "eos_token": "</s>",
                "unk_token": "<unk>",
                "pad_token": "<pad>",
                "add_bos_token": True,
                "add_eos_token": False,
                "clean_up_tokenization_spaces": False,
                "legacy": False,
                "name_or_path": "yoruba_sage_tokenizer",
                "special_tokens_map_file": "special_tokens_map.json"
            },
            "special_tokens_map.json": {
                "bos_token": "<s>",
                "eos_token": "</s>",
                "unk_token": "<unk>",
                "pad_token": "<pad>"
            }
        }

        for filename, config in configs.items():
            with open(os.path.join(output_dir, filename), 'w', encoding='utf-8') as f:
                json.dump(config, f, indent=2)

        sp_model = os.path.join(output_dir, "yoruba_tokenizer.model")
        if os.path.exists(sp_model):
            shutil.copyfile(sp_model, os.path.join(output_dir, "tokenizer.model"))


def run_sage(corpus_file: str, output_dir: str):
    """Train the Yoruba SaGe tokenizer with fixed settings and report on the result.

    Builds a ``YorubaMonitoredSaGeTokenizer`` with a fixed set of
    hyper-parameters and calls ``train(corpus_file, output_dir)``. If the
    returned model path exists on disk, the model is loaded with SentencePiece.
    The function then logs the vocabulary size, the last recorded training
    metrics and the tokenization of a few sample sentences.

    Returns:
        Whatever ``tokenizer.train`` returned (the model path, possibly ``None``).
    """
    sage_params = dict(
        vocab_size=10000,
        initial_vocab_multiplier=1.5,
        max_token_length=5,
        embedding_dim=50,
        window_size=5,
        negative_samples=10,
        min_token_freq=20,
        pruning_batch_size=1500,
        embedding_update_frequency=2,
        gensim_workers=4,
        gensim_epochs=5,
        gensim_min_count=1,
        fertility_target=1.4,
        fertility_tolerance=0.05,
        patience=8,
        min_improvement=0.005,
    )
    tokenizer = YorubaMonitoredSaGeTokenizer(**sage_params)

    model_file = tokenizer.train(corpus_file, output_dir)

    # Nothing to inspect if training produced no model on disk.
    if not (model_file and os.path.exists(model_file)):
        return model_file

    processor = spm.SentencePieceProcessor()
    processor.load(model_file)

    logger.info("\nFinal Yoruba Statistics:")
    logger.info(f"   Vocabulary size: {processor.vocab_size()}")

    history = tokenizer.training_history
    if history['iteration']:
        logger.info(f"   Training iterations: {len(history['iteration'])}")
        logger.info(f"   Final fertility: {history['fertility'][-1]:.3f}")
        logger.info(f"   Final coverage: {history['coverage'][-1]:.3f}")
        logger.info(f"   Final diacritic preservation: {history['diacritic_preservation'][-1]:.3f}")

    samples = (
        "Pẹlẹ o, báwo ni ẹ ṣe wa",  # Hello, how are you
        "Mo fẹ́ kọ́ èdè Yorùbá",     # I want to learn Yoruba language
        "Ilẹ̀ wa lọ́wọ́ wa",          # Our land is in our hands
        "Ẹ̀kọ́ ní koko ayé",        # Education is the essence of life
        "Òrèwá lórúkọ mi",         # My name is Orewa
    )

    logger.info("\nYoruba Tokenization Examples:")
    for text in samples:
        pieces = processor.encode_as_pieces(text)
        # Show at most the first 10 pieces, with a trailing ellipsis when truncated.
        if len(pieces) > 10:
            logger.info(f"   '{text}' -> {pieces[:10]}...")
        else:
            logger.info(f"   '{text}' -> {pieces}")

    return model_file

if __name__ == "__main__":
    # Google Drive locations (Colab mount) of the training corpus and output directory.
    corpus_file = "/content/drive/My Drive/Colab Notebooks/LRLs/yoruba/dataset/yo_train.txt"
    output_dir = "/content/drive/My Drive/Colab Notebooks/LRLs/yoruba/tokenizers/sage"

    model = run_sage(corpus_file, output_dir)

    # run_sage returns a path (or None). Report success only when the model
    # file really exists on disk, not merely when a path came back.
    if model and os.path.exists(model):
        logger.info(f"\nSaGe model saved to: {output_dir}")
    else:
        logger.error("SaGe training did not produce a SentencePiece model")