In [None]:
# Mount Google Drive at /content/drive (Colab-only helper).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install -q sentencepiece transformers datasets huggingface_hub

In [None]:
import os
import json
import time
import random
import logging
import numpy as np
import pandas as pd
from pathlib import Path
from collections import Counter, defaultdict
from typing import List, Dict, Union, Tuple, Optional
import unicodedata
import re
import getpass
from transformers import AutoTokenizer
import sentencepiece as spm

# Configure logging at INFO level with timestamps, and create the module-level
# logger used by the code in this file.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class _SimplePickyBPEModel:
    """Vocabulary-lookup fallback used when no HuggingFace backend could be loaded.

    No BPE merges are applied: every whitespace-separated word is looked up
    whole, and a word missing from the vocabulary is encoded character by
    character, with unknown characters mapped to the UNK id.
    """

    def __init__(self, tokens_data):
        self.id2token = {}
        self.str2token = {}

        for token in tokens_data:
            # Only dict records carry the 'id' / 'str' pair we need.
            if not isinstance(token, dict):
                continue
            self.id2token[token['id']] = token['str']
            self.str2token[token['str']] = token['id']

        # Prefer '<unk>', then '[UNK]'; fall back to id 1 when neither exists.
        self.unk_id = self.str2token.get('<unk>', self.str2token.get('[UNK]', 1))
        logger.info(f"Created model with {len(self.id2token)} tokens, UNK ID: {self.unk_id}")

    def encode(self, text: str) -> List[int]:
        """Simple word + character-level encoding"""
        lookup = self.str2token
        ids = []
        for word in text.split():
            if word in lookup:
                # Whole word is a vocabulary entry.
                ids.append(lookup[word])
            else:
                # Fall back to character-level lookup.
                ids.extend(lookup.get(char, self.unk_id) for char in word)
        return ids

    def tokenize(self, text: str) -> List[str]:
        """Tokenize text to token strings"""
        return [self.id2token.get(token_id, '<unk>') for token_id in self.encode(text)]


class PickyBPETokenizer:
    """Wrapper around a Picky BPE vocabulary stored as JSON.

    A path ending in ``tokenizer.json`` is first tried as a HuggingFace fast
    tokenizer. If that fails, or for any other path, the JSON is parsed and
    served by ``_SimplePickyBPEModel``. ``self.model`` is ``None`` when
    nothing usable could be loaded.
    """

    # Token strings flagged as special when records are built from a plain vocab.
    _SPECIAL_TOKENS = frozenset(
        {'<pad>', '<unk>', '<s>', '</s>', '[PAD]', '[UNK]', '[CLS]', '[SEP]'}
    )

    def __init__(self, model_path: str):
        self.model_path = model_path
        self.model = self._load_model()

    @classmethod
    def _records_from_vocab(cls, vocab, int_ids_only: bool = False) -> List[dict]:
        """Turn a ``{token: id}`` mapping into Picky-BPE style token records."""
        return [
            {
                'id': token_id,
                'str': token_str,
                'freq': 0,
                'special': token_str in cls._SPECIAL_TOKENS,
                'present': True,
            }
            for token_str, token_id in vocab.items()
            if not int_ids_only or isinstance(token_id, int)
        ]

    @classmethod
    def _extract_tokens(cls, data):
        """Find the token records in one of the supported JSON layouts."""
        if 'model' in data and 'vocab' in data['model']:
            return cls._records_from_vocab(data['model']['vocab'])
        if 'tokens' in data:
            return data['tokens']
        if 'vocab' in data:
            return cls._records_from_vocab(data['vocab'])
        # Flat {token: id} file: keep only entries whose value is an integer id.
        return cls._records_from_vocab(data, int_ids_only=True)

    def _load_model(self):
        """Load the backend model; return ``None`` (after logging) on any failure."""
        try:
            if self.model_path.endswith('tokenizer.json'):
                try:
                    from transformers import PreTrainedTokenizerFast
                    tokenizer = PreTrainedTokenizerFast(tokenizer_file=self.model_path)
                    logger.info(f"Loaded as HuggingFace tokenizer: {self.model_path}")
                    return tokenizer
                except Exception as e:
                    # Not fatal: fall through and parse the file as plain JSON.
                    logger.warning(f"Failed to load as HuggingFace tokenizer: {e}")

            with open(self.model_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            logger.info(f"Loaded JSON data with keys: {list(data.keys())}")

            model = _SimplePickyBPEModel(self._extract_tokens(data))
            if not model.id2token:
                # An empty vocabulary would map every character to UNK; treat
                # it as a load failure so callers can skip this tokenizer.
                logger.error(f"Error loading Picky BPE model: no tokens found in {self.model_path}")
                return None
            return model

        except Exception as e:
            logger.error(f"Error loading Picky BPE model: {e}")
            return None

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """Encode text to token IDs.

        ``add_special_tokens`` is forwarded to a HuggingFace backend; the
        simple fallback model has no special tokens, so it is ignored there.
        """
        if self.model is None:
            return []
        if isinstance(self.model, _SimplePickyBPEModel):
            return self.model.encode(text)
        return self.model.encode(text, add_special_tokens=add_special_tokens)

    def tokenize(self, text: str) -> List[str]:
        """Tokenize text to token strings"""
        if self.model is None:
            return []
        return self.model.tokenize(text)

    @property
    def unk_token_id(self) -> int:
        """Get UNK token ID (1 when the backend exposes none)."""
        if hasattr(self.model, 'unk_id'):
            return self.model.unk_id
        elif hasattr(self.model, 'unk_token_id'):
            return self.model.unk_token_id
        return 1

class SentencePieceTokenizer:
    """Thin wrapper around a SentencePiece model file.

    ``self.sp`` holds the loaded processor, or ``None`` when the model file is
    missing or could not be loaded; in that case every encode method returns
    an empty list instead of raising.
    """

    def __init__(self, model_path: str):
        self.model_path = model_path
        self.sp = spm.SentencePieceProcessor()
        self.load_model()

    def load_model(self):
        """Load the model at ``self.model_path``; set ``self.sp`` to ``None`` on failure."""
        try:
            if not os.path.exists(self.model_path):
                logger.error(f"SentencePiece model file not found: {self.model_path}")
                self.sp = None
                return

            self.sp.load(self.model_path)

            # Smoke-test the processor so a broken model is detected at load time.
            test_result = self.sp.encode("test", out_type=int)
            logger.info(f"SentencePiece model loaded. Test encoding: {test_result}")

        except Exception as e:
            logger.error(f"Error loading SentencePiece model: {e}")
            self.sp = None

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """encode text to token IDs

        ``add_special_tokens`` is accepted for interface compatibility with the
        other tokenizers in this file and is not forwarded to SentencePiece.
        """
        if self.sp is None:
            return []
        try:
            return self.sp.encode(text, out_type=int)
        except Exception as e:
            logger.warning(f"Error encoding text: {e}")
            return []

    def encode_as_ids(self, text: str) -> List[int]:
        """encode text to token IDs"""
        return self.encode(text)

    def encode_as_pieces(self, text: str) -> List[str]:
        """encode text to token pieces"""
        if self.sp is None:
            return []
        try:
            return self.sp.encode(text, out_type=str)
        except Exception as e:
            logger.warning(f"Error encoding text to pieces: {e}")
            return []

    def tokenize(self, text: str) -> List[str]:
        """Alias of ``encode_as_pieces``."""
        return self.encode_as_pieces(text)

    @property
    def unk_token_id(self) -> int:
        """UNK id reported by the model, or 1 when no model is loaded or the query fails."""
        if self.sp is None:
            return 1
        try:
            return self.sp.unk_id()
        except Exception:
            # Was a bare `except:`, which would also swallow KeyboardInterrupt
            # and SystemExit.
            return 1

class SimpleBPETokenizer:
    """Vocabulary-lookup tokenizer backed by a vocab file and optional merges file.

    The merge rules are read into ``self.merges`` but are not applied when
    tokenizing: a word is emitted whole if it is in the vocabulary, otherwise
    it is split into characters, with unknown characters replaced by '<unk>'.
    """

    def __init__(self, vocab_path: str, merges_path: str = None):
        self.vocab_path = vocab_path
        self.merges_path = merges_path
        self.vocab = {}
        self.merges = []
        self.load_model()

    def load_model(self):
        """Read the vocabulary and merges from disk; errors are logged, not raised."""
        try:
            with open(self.vocab_path, 'r', encoding='utf-8') as vocab_file:
                if self.vocab_path.endswith('.json'):
                    self.vocab = json.load(vocab_file)
                else:
                    # One token per line; anything after the first tab is
                    # ignored and the line number becomes the token id.
                    for index, raw_line in enumerate(vocab_file):
                        self.vocab[raw_line.strip().split('\t')[0]] = index

            if self.merges_path and os.path.exists(self.merges_path):
                with open(self.merges_path, 'r', encoding='utf-8') as merges_file:
                    for raw_line in merges_file:
                        # Skip header/comment lines such as '#version: ...'.
                        if raw_line.startswith('#'):
                            continue
                        pieces = raw_line.split()
                        if len(pieces) >= 2:
                            self.merges.append((pieces[0], pieces[1]))

            logger.info(f"Loaded BPE model with {len(self.vocab)} tokens and {len(self.merges)} merges")
        except Exception as e:
            logger.error(f"Error loading BPE model: {e}")

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """encode text to token IDs (``add_special_tokens`` is accepted but unused)"""
        ids = []
        for piece in self.tokenize(text):
            # Unknown pieces map to the '<unk>' id, or 1 if the vocab has none.
            ids.append(self.vocab.get(piece, self.vocab.get('<unk>', 1)))
        return ids

    def tokenize(self, text: str) -> List[str]:
        """Split text into vocabulary tokens; with no vocab, return its characters."""
        if not self.vocab:
            return list(text)
        return [piece for word in text.split() for piece in self._tokenize_word(word)]

    def _tokenize_word(self, word: str) -> List[str]:
        """Return the word itself if known, else its characters with '<unk>' for unknowns."""
        if word in self.vocab:
            return [word]
        return [char if char in self.vocab else '<unk>' for char in word]

    @property
    def unk_token_id(self) -> int:
        """get UNK token ID"""
        return self.vocab.get('<unk>', 1)

class TokenizerEvaluator:
    """Computes fertility, compression, coverage and token-length metrics for a tokenizer.

    Tokenizers are used through duck typing (see ``_get_tokens``). A text whose
    tokenization raises is logged and excluded from both the numerator and the
    denominator of every ratio, so a failure cannot skew the metric.
    """

    def __init__(self, test_texts: List[str]):
        self.test_texts = test_texts

    def calculate_fertility(self, tokenizer, texts: List[str]) -> float:
        """calculate average tokens per word"""
        total_tokens = 0
        total_words = 0

        for text in texts:
            try:
                tokens = self._get_tokens(tokenizer, text)
            except Exception as e:
                logger.warning(f"Error tokenizing text with {type(tokenizer)}: {e}")
                continue

            # Count the words only once the text was tokenized successfully.
            total_tokens += len(tokens)
            total_words += len(text.split())

        return total_tokens / total_words if total_words > 0 else 0

    def calculate_compression_ratio(self, tokenizer, texts: List[str]) -> float:
        """calculate compression ratio (chars per token)"""
        total_chars = 0
        total_tokens = 0

        for text in texts:
            try:
                tokens = self._get_tokens(tokenizer, text)
            except Exception as e:
                logger.warning(f"Error tokenizing text with {type(tokenizer)}: {e}")
                continue

            # Count the characters only once the text was tokenized successfully.
            total_tokens += len(tokens)
            total_chars += len(text)

        return total_chars / total_tokens if total_tokens > 0 else 0

    def calculate_coverage(self, tokenizer, texts: List[str]) -> float:
        """calculate vocabulary coverage (1 - UNK ratio)"""
        total_tokens = 0
        unk_tokens = 0

        for text in texts:
            try:
                tokens = self._get_tokens(tokenizer, text)
                total_tokens += len(tokens)

                # count UNK tokens
                if hasattr(tokenizer, 'unk_token_id'):
                    unk_id = tokenizer.unk_token_id
                    unk_tokens += sum(1 for t in tokens if t == unk_id)
                elif tokens and isinstance(tokens[0], str):
                    unk_tokens += sum(1 for t in tokens if '<unk>' in str(t).lower() or '[unk]' in str(t).lower())
                else:
                    # No UNK id exposed: treat ids 0 and 1 as unknown markers.
                    unk_tokens += sum(1 for t in tokens if t in {0, 1})

            except Exception as e:
                logger.warning(f"Error calculating coverage with {type(tokenizer)}: {e}")
                continue

        coverage = 1 - (unk_tokens / total_tokens) if total_tokens > 0 else 0
        return max(0, coverage)

    def _get_tokens(self, tokenizer, text: str) -> List:
        """Tokenize ``text`` with whichever interface the tokenizer exposes.

        Order of preference: SentencePiece-style ``encode_as_ids``, then
        ``encode`` (requested without special tokens, retried without the
        keyword if the signature rejects it), then ``tokenize``.

        Raises:
            ValueError: if the tokenizer exposes none of these methods.
        """
        if hasattr(tokenizer, 'encode_as_ids') and hasattr(tokenizer, 'sp'):
            return tokenizer.encode_as_ids(text)
        elif hasattr(tokenizer, 'encode'):
            # Every `encode`-style tokenizer is asked to omit special tokens so
            # that wrapper objects and raw HuggingFace tokenizers are measured
            # on the same footing.
            try:
                return tokenizer.encode(text, add_special_tokens=False)
            except TypeError:
                return tokenizer.encode(text)
        elif hasattr(tokenizer, 'tokenize'):
            return tokenizer.tokenize(text)
        else:
            raise ValueError(f"Unknown tokenizer interface: {type(tokenizer)}")

    def calculate_token_length_distribution(self, tokenizer, texts: List[str]) -> Dict[str, float]:
        """Mean and standard deviation of token string length over the first 100 texts.

        The word-boundary markers '▁' and 'Ġ' are stripped before measuring.
        Returns zeros when no token strings could be obtained.
        """
        token_lengths = []

        for text in texts[:100]:
            try:
                if hasattr(tokenizer, 'encode_as_pieces'):
                    tokens = tokenizer.encode_as_pieces(text)
                elif hasattr(tokenizer, 'tokenize'):
                    tokens = tokenizer.tokenize(text)
                else:
                    continue

                for token in tokens:
                    token_str = str(token).replace('▁', '').replace('Ġ', '')
                    token_lengths.append(len(token_str))
            except Exception as e:
                # Log instead of swallowing silently, like the other metrics do.
                logger.warning(f"Error measuring token lengths with {type(tokenizer)}: {e}")
                continue

        if not token_lengths:
            return {'mean_length': 0, 'std_length': 0}

        # Plain floats keep the logged metrics dict readable.
        return {
            'mean_length': float(np.mean(token_lengths)),
            'std_length': float(np.std(token_lengths))
        }

    def evaluate_tokenizer(self, tokenizer, name: str) -> Dict[str, float]:
        """Compute all metrics for ``tokenizer`` on at most the first 500 test texts.

        ``composite_score`` is ``compression_ratio * coverage / fertility``,
        or 0 when fertility is 0.
        """
        logger.info(f"Evaluating tokenizer: {name}")

        eval_texts = self.test_texts[:500]

        metrics = {
            'fertility': self.calculate_fertility(tokenizer, eval_texts),
            'compression_ratio': self.calculate_compression_ratio(tokenizer, eval_texts),
            'coverage': self.calculate_coverage(tokenizer, eval_texts)
        }

        metrics.update(self.calculate_token_length_distribution(tokenizer, eval_texts))

        metrics['composite_score'] = (
            metrics['compression_ratio'] * metrics['coverage'] / metrics['fertility']
            if metrics['fertility'] > 0 else 0
        )

        logger.info(f"Metrics for {name}: {metrics}")
        return metrics

class YorubaCorpusProcessor:
    """handles Yoruba corpus loading and preprocessing"""

    # Used to resolve file_format="auto"; unknown suffixes are read as plain text.
    _FORMAT_BY_SUFFIX = {'.txt': 'txt', '.csv': 'csv', '.tsv': 'tsv', '.jsonl': 'jsonl'}

    # Letters counted as "Latin" when computing the Latin-letter ratio.
    _LATIN_LETTERS = re.compile(r'[a-zA-ZàáâéèêíìîóòôúùûẹọṣẹṇÀÁÂÉÈÊÍÌÎÓÒÔÚÙÛẸỌṢẸṆ]')

    _COMMON_YORUBA_WORDS = ('ni', 'wa', 'ti', 'si', 'ko', 'lo', 'ba', 'se', 'ati', 'ninu', 'lati', 'won', 'ile')

    def __init__(self, max_sentences: int = 5000):
        self.max_sentences = max_sentences
        self.yoruba_diacritics = re.compile(r'[àáâéèêíìîóòôúùû]')
        self.yoruba_pattern = re.compile(r'[abdeẹfghijklmnoprstuwy]', re.IGNORECASE)

    def load_custom_dataset(self, dataset_path: str, file_format: str = "txt", text_column: str = "text",
                            max_sentences: Optional[int] = None) -> List[str]:
        """Load sentences from a file and keep those that look like Yoruba.

        Args:
            dataset_path: Path of the dataset file.
            file_format: 'txt', 'csv', 'tsv', 'jsonl', or 'auto' to pick the
                format from the file suffix (anything unrecognised is 'txt').
            text_column: Column / key holding the sentence for csv, tsv and jsonl.
            max_sentences: Cap on sentences read; defaults to ``self.max_sentences``.
                The cap is applied before filtering, so fewer sentences may
                be returned.

        Returns:
            The sentences accepted by ``is_valid_yoruba_text``.

        Raises:
            ValueError: for an unsupported ``file_format`` or a missing ``text_column``.
        """
        if max_sentences is None:
            max_sentences = self.max_sentences

        dataset_path = Path(dataset_path)

        fmt = (file_format or 'auto').lower()
        if fmt == 'auto':
            fmt = self._FORMAT_BY_SUFFIX.get(dataset_path.suffix.lower(), 'txt')

        if fmt == 'txt':
            sentences = self.load_text_file(dataset_path, max_sentences)
        elif fmt in ('csv', 'tsv'):
            delimiter = ',' if fmt == 'csv' else '\t'
            sentences = self._load_delimited_file(dataset_path, text_column, max_sentences, delimiter)
        elif fmt == 'jsonl':
            sentences = self._load_jsonl_file(dataset_path, text_column, max_sentences)
        else:
            # Previously any non-"txt" format hit an unbound `sentences` variable.
            raise ValueError(
                f"Unsupported file_format: {file_format!r} "
                "(expected 'auto', 'txt', 'csv', 'tsv' or 'jsonl')"
            )

        valid_sentences = [s for s in sentences if self.is_valid_yoruba_text(s)]
        logger.info(f"Loaded {len(valid_sentences)} valid sentences from {len(sentences)} total")

        return valid_sentences[:max_sentences]

    def load_text_file(self, file_path: Path, max_sentences: int) -> List[str]:
        """Return up to ``max_sentences`` stripped, non-empty lines of a UTF-8 text file."""
        sentences = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                # Cap on collected sentences, so blank lines do not use up the budget.
                if len(sentences) >= max_sentences:
                    break
                line = line.strip()
                if line:
                    sentences.append(line)
        return sentences

    def _load_delimited_file(self, file_path: Path, text_column: str, max_sentences: int,
                             delimiter: str) -> List[str]:
        """Read up to ``max_sentences`` non-empty ``text_column`` values from a CSV/TSV with a header row."""
        import csv  # local import: csv is not part of this file's import block

        sentences = []
        with open(file_path, 'r', encoding='utf-8', newline='') as f:
            reader = csv.DictReader(f, delimiter=delimiter)
            if not reader.fieldnames or text_column not in reader.fieldnames:
                raise ValueError(f"Column {text_column!r} not found in {file_path}")
            for row in reader:
                if len(sentences) >= max_sentences:
                    break
                value = (row.get(text_column) or '').strip()
                if value:
                    sentences.append(value)
        return sentences

    def _load_jsonl_file(self, file_path: Path, text_column: str, max_sentences: int) -> List[str]:
        """Read up to ``max_sentences`` sentences from a JSON-lines file.

        Each line may be a JSON object (the ``text_column`` key is used) or a
        bare JSON string. Records without a non-empty string value are skipped.
        """
        sentences = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if len(sentences) >= max_sentences:
                    break
                line = line.strip()
                if not line:
                    continue
                record = json.loads(line)
                value = record.get(text_column) if isinstance(record, dict) else record
                if isinstance(value, str) and value.strip():
                    sentences.append(value.strip())
        return sentences

    def is_valid_yoruba_text(self, text: str) -> bool:
        """Check if text is valid Yoruba

        Heuristic: 10-1000 characters long, at least 80% of the alphabetic
        characters are Latin/Yoruba letters, and the text either contains a
        common Yoruba word (as a substring of some word) or a tone-marked vowel.
        """
        if len(text) < 10 or len(text) > 1000:
            return False

        total_chars = sum(1 for c in text if c.isalpha())
        if total_chars == 0:
            return False

        latin_chars = len(self._LATIN_LETTERS.findall(text))
        if latin_chars / total_chars < 0.8:
            return False

        # NOTE(review): this is a substring test ('ni' matches 'running'), so it
        # is very permissive; kept as is to avoid changing which corpus
        # sentences are selected.
        words = text.lower().split()
        if any(yw in word for word in words for yw in self._COMMON_YORUBA_WORDS):
            return True

        # Tone marks may arrive as base letter + combining accent (NFD) or in
        # upper case; compose and lower-case so the precomposed pattern matches.
        composed = unicodedata.normalize('NFC', text).lower()
        return self.yoruba_diacritics.search(composed) is not None

class TokenizerLoader:
    """Loads different types of pre-trained tokenizers

    Every ``load_*`` method returns the tokenizer on success and ``None`` on
    failure (after logging), so callers can simply skip tokenizers that did
    not load.
    """

    def __init__(self, hf_token: Optional[str] = None):
        self.hf_token = hf_token

    def load_huggingface_tokenizer(self, model_name: str):
        """load HF tokenizer"""
        try:
            # NOTE(review): trust_remote_code=True lets the model repository run
            # its own Python code while loading; only use it with repositories
            # you trust.
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                # An empty string (e.g. an empty prompt answer) means "no token".
                token=self.hf_token or None,
                trust_remote_code=True
            )
            logger.info(f"Loaded HuggingFace tokenizer: {model_name}")
            return tokenizer
        except Exception as e:
            logger.error(f"Error loading HuggingFace tokenizer {model_name}: {e}")
            return None

    def load_picky_bpe_tokenizer(self, model_path: str):
        """Load Picky BPE tokenizer"""
        try:
            tokenizer = PickyBPETokenizer(model_path)
            if tokenizer.model is None:
                return None
            logger.info(f"Loaded Picky BPE tokenizer: {model_path}")
            return tokenizer
        except Exception as e:
            logger.error(f"Error loading Picky BPE tokenizer {model_path}: {e}")
            return None

    def load_sentencepiece_tokenizer(self, model_path: str):
        """Load SentencePiece tokenizer (SaGe or BPE)"""
        try:
            tokenizer = SentencePieceTokenizer(model_path)
            if tokenizer.sp is None:
                return None
            logger.info(f"Loaded SentencePiece tokenizer: {model_path}")
            return tokenizer
        except Exception as e:
            logger.error(f"Error loading SentencePiece tokenizer {model_path}: {e}")
            return None

    def load_simple_bpe_tokenizer(self, vocab_path: str, merges_path: Optional[str] = None):
        """Load simple BPE tokenizer"""
        try:
            tokenizer = SimpleBPETokenizer(vocab_path, merges_path)
            if not tokenizer.vocab:
                # SimpleBPETokenizer logs load errors instead of raising, so an
                # empty vocabulary is the only sign that loading failed. Report
                # it as a failure, as the other loaders do.
                logger.error(f"Error loading simple BPE tokenizer {vocab_path}: vocabulary is empty")
                return None
            logger.info(f"Loaded simple BPE tokenizer: {vocab_path}")
            return tokenizer
        except Exception as e:
            logger.error(f"Error loading simple BPE tokenizer {vocab_path}: {e}")
            return None

def evaluate_tokenizers(test_texts: List[str], tokenizer_configs: List[Dict], hf_token: str = None) -> pd.DataFrame:
    """Load and evaluate each configured tokenizer, returning one result row per success.

    Each config needs the keys 'name', 'type' and 'path'; 'standard_bpe'
    configs may also carry 'merges_path'. Configs with an unknown type, a
    tokenizer that fails to load, or an evaluation error are logged and skipped.
    """
    scorer = TokenizerEvaluator(test_texts)
    factory = TokenizerLoader(hf_token=hf_token)
    rows = []

    for cfg in tokenizer_configs:
        label = cfg['name']
        kind = cfg['type']
        location = cfg['path']

        logger.info(f"Loading tokenizer: {label}")

        if kind == 'huggingface':
            loaded = factory.load_huggingface_tokenizer(location)
        elif kind == 'picky_bpe':
            loaded = factory.load_picky_bpe_tokenizer(location)
        elif kind == 'sentencepiece':
            loaded = factory.load_sentencepiece_tokenizer(location)
        elif kind == 'standard_bpe':
            loaded = factory.load_simple_bpe_tokenizer(location, cfg.get('merges_path', None))
        else:
            logger.error(f"Unknown tokenizer type: {kind}")
            continue

        if loaded is None:
            logger.warning(f"Skipping {label} - failed to load")
            continue

        try:
            scores = scorer.evaluate_tokenizer(loaded, label)
            row = {"tokenizer": label, "type": kind}
            row.update(scores)
            rows.append(row)
        except Exception as e:
            logger.error(f"Error evaluating {label}: {e}")

    return pd.DataFrame(rows)

def main(dataset_path: str = None, file_format: str = "auto", text_column: str = "text", max_sentences: int = 2000,
         hf_token: str = None, custom_tokenizer_paths: Dict = None):
    """Evaluate the baseline and custom tokenizers on a Yoruba dataset and report the results.

    Args:
        dataset_path: Dataset file to evaluate on; the run is aborted when missing.
        file_format: Passed through to ``YorubaCorpusProcessor.load_custom_dataset``.
        text_column: Passed through to ``YorubaCorpusProcessor.load_custom_dataset``.
        max_sentences: Cap on the number of sentences loaded.
        hf_token: Hugging Face access token used when loading hub tokenizers.
        custom_tokenizer_paths: Optional mapping with any of the keys
            'standard_bpe' (or its alias 'simple_bpe'), 'picky_bpe',
            'picky_bpe_custom', 'sage_sp' and 'sage'.

    Prints a results table and writes it to
    'yoruba_tokenizer_evaluation_results.csv'. Returns ``None``.
    """
    logger.info("Starting Yoruba tokenizer evaluation...")

    random.seed(42)
    np.random.seed(42)

    processor = YorubaCorpusProcessor(max_sentences=max_sentences)

    if dataset_path:
        logger.info(f"Using custom dataset: {dataset_path}")
        test_texts = processor.load_custom_dataset(dataset_path, file_format, text_column, max_sentences)
    else:
        logger.error("No dataset path provided!")
        return

    logger.info(f"Loaded {len(test_texts)} sentences for evaluation")

    tokenizer_configs = [
        {
            'name': 'Llama-2-7B',
            'type': 'huggingface',
            'path': 'meta-llama/Llama-2-7b-hf'
        },
        {
            'name': 'Gemma-7B',
            'type': 'huggingface',
            'path': 'google/gemma-7b'
        }
    ]

    if custom_tokenizer_paths:
        # The launcher in this file spells this key 'simple_bpe'; accept both
        # spellings so the tokenizer is not silently left out.
        for bpe_key in ('standard_bpe', 'simple_bpe'):
            if bpe_key in custom_tokenizer_paths:
                tokenizer_configs.append({
                    'name': 'Simple BPE (HF)',
                    'type': 'huggingface',
                    'path': custom_tokenizer_paths[bpe_key]
                })
                break

        if 'picky_bpe' in custom_tokenizer_paths:
            tokenizer_configs.append({
                'name': 'Picky BPE (HF)',
                'type': 'huggingface',
                'path': custom_tokenizer_paths['picky_bpe']
            })

        if 'picky_bpe_custom' in custom_tokenizer_paths:
            tokenizer_configs.append({
                'name': 'Picky BPE (Custom)',
                'type': 'picky_bpe',
                'path': custom_tokenizer_paths['picky_bpe_custom']
            })

        if 'sage_sp' in custom_tokenizer_paths:
            tokenizer_configs.append({
                'name': 'SentencePiece Model',
                'type': 'sentencepiece',
                'path': custom_tokenizer_paths['sage_sp']
            })

        # Skip the entry while it still holds the template placeholder path.
        if 'sage' in custom_tokenizer_paths and '/path/to/your/' not in custom_tokenizer_paths['sage']:
            tokenizer_configs.append({
                'name': 'SaGe',
                'type': 'sentencepiece',
                'path': custom_tokenizer_paths['sage']
            })

    logger.info(f"Configured {len(tokenizer_configs)} tokenizers for evaluation")

    print("Evaluating tokenizers...")
    results_df = evaluate_tokenizers(test_texts, tokenizer_configs, hf_token=hf_token)

    # With no successful evaluation the frame has no columns, and the report
    # below would fail on column selection and idxmax().
    if results_df.empty or 'composite_score' not in results_df.columns:
        logger.error("No tokenizer could be evaluated; nothing to report")
        return

    print("\n" + "="*120)
    print("EVALUATION RESULTS")
    print("="*120)

    display_df = results_df.copy()
    numeric_cols = ['fertility', 'compression_ratio', 'coverage', 'composite_score', 'mean_length']
    for col in numeric_cols:
        if col in display_df.columns:
            display_df[col] = display_df[col].round(4)

    print(display_df[['tokenizer', 'type', 'fertility', 'compression_ratio', 'coverage', 'composite_score']].to_string(index=False))
    print("="*120)

    best_tokenizer = results_df.loc[results_df['composite_score'].idxmax()]
    print(f"\nBest tokenizer: {best_tokenizer['tokenizer']}")
    print(f"  - Composite Score: {best_tokenizer['composite_score']:.4f}")
    print(f"  - Fertility: {best_tokenizer['fertility']:.4f} (lower is better)")
    print(f"  - Compression Ratio: {best_tokenizer['compression_ratio']:.4f} (higher is better)")
    print(f"  - Coverage: {best_tokenizer['coverage']:.4f} (higher is better)")

    custom_tokenizers = results_df[results_df['type'].isin(['picky_bpe', 'sentencepiece', 'standard_bpe'])]
    if len(custom_tokenizers) > 0:
        print("\nCUSTOM TOKENIZER INSIGHTS:")
        for _, row in custom_tokenizers.iterrows():
            print(f"  - {row['tokenizer']}: Fertility={row['fertility']:.3f}, Coverage={row['coverage']:.3f}, Compression={row['compression_ratio']:.3f}")

    results_df.to_csv('yoruba_tokenizer_evaluation_results.csv', index=False)
    logger.info("Results saved to yoruba_tokenizer_evaluation_results.csv")

if __name__ == "__main__":

    # Ask for the token without echoing it; an empty answer is passed on as None.
    hf_token = getpass.getpass("Hugging Face token : ") or None

    # Keys must match the ones main() looks up ('standard_bpe', 'picky_bpe', 'sage', ...).
    CUSTOM_TOKENIZER_PATHS = {
        'standard_bpe': '/content/drive/My Drive/Colab Notebooks/LRLs/yoruba/tokenizers/standard_bpe',
        'picky_bpe': '/content/drive/My Drive/Colab Notebooks/LRLs/yoruba/tokenizers/picky_bpe',
        'sage': '/content/drive/My Drive/Colab Notebooks/LRLs/yoruba/tokenizers/sage',
    }

    # Keyword arguments forwarded verbatim to main().
    DATASET_CONFIG = {
        'dataset_path': '/content/drive/My Drive/Colab Notebooks/LRLs/yoruba/dataset/yo_eval.txt',
        'file_format': 'txt',
        'text_column': 'text',
        'max_sentences': 5000,
        'hf_token': hf_token,
        'custom_tokenizer_paths': CUSTOM_TOKENIZER_PATHS
    }

    main(**DATASET_CONFIG)