In [None]:
# Notebook setup cell: mount Google Drive at /content/drive so that files
# stored under that path can be read and written by the cells below.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import json
import math
import torch
import getpass
import re
import numpy as np
from pathlib import Path
import traceback
from typing import List, Dict, Optional, Tuple
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
from transformers.tokenization_utils import AddedToken

def round_to_nearest_multiple(vocabulary_size: int, multiple: int = 8) -> int:
    """Round ``vocabulary_size`` UP to the next multiple of ``multiple``.

    Despite the name, the result is never smaller than the input: values that
    are already a multiple are returned unchanged, everything else is bumped
    to the following multiple.

    Args:
        vocabulary_size: Number of entries to align (expected to be an int).
        multiple: Alignment step; defaults to 8.

    Returns:
        The smallest multiple of ``multiple`` that is >= ``vocabulary_size``.

    Raises:
        ZeroDivisionError: If ``multiple`` is 0.
    """
    # Integer ceiling division: going through float division (math.ceil(a / b))
    # loses precision for large integers and can return a value below the input.
    return -(-vocabulary_size // multiple) * multiple

class BPETokenizer(PreTrainedTokenizer):
    """Character-level BPE tokenizer backed by a ``model.json`` file.

    The JSON file is expected to provide three keys: ``vocab`` (token -> id),
    ``id2token`` (id, stored as a string key -> token) and ``merges`` (a list
    of objects carrying ``left`` and ``right`` entries).
    """

    def __init__(self, model_path: str, **kwargs):
        """Load the BPE model found in ``model_path`` and register special tokens.

        Args:
            model_path: Directory that contains ``model.json``.
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizer.__init__``.
        """
        self.model_path = Path(model_path)

        data = self._load_model_data()
        self.model_data = data
        self.vocab = data['vocab']
        # JSON object keys are strings; the ids are converted back to int here.
        self.id2token = {int(token_id): token for token_id, token in data['id2token'].items()}
        self.merges = [(rule['left'], rule['right']) for rule in data['merges']]

        def _special(text):
            return AddedToken(text, lstrip=False, rstrip=False)

        # The vocabulary attributes above are assigned before the base-class
        # constructor runs; this ordering is kept from the original code.
        super().__init__(
            unk_token=_special('<unk>'),
            bos_token=_special('<bos>'),
            eos_token=_special('<eos>'),
            pad_token=_special('<pad>'),
            **kwargs
        )

        print(f"BPE tokenizer loaded: {len(self.vocab):,} tokens, {len(self.merges):,} merges")

    def _load_model_data(self):
        """Read and return the parsed contents of ``<model_path>/model.json``.

        Raises:
            FileNotFoundError: If the file does not exist.
        """
        model_file = self.model_path / 'model.json'
        if model_file.exists():
            with open(model_file, 'r', encoding='utf-8') as handle:
                return json.load(handle)

        raise FileNotFoundError(f"BPE model not found: {model_file}")

    @property
    def vocab_size(self):
        """Number of entries in the BPE vocabulary."""
        return len(self.vocab)

    def get_vocab(self):
        """Return a shallow copy of the token -> id mapping."""
        return self.vocab.copy()

    def _tokenize(self, text: str, **kwargs):
        """Split ``text`` into characters, then apply every merge rule in order.

        Characters missing from the vocabulary are replaced by the unknown
        token before merging. Empty input yields an empty list.
        """
        if not text:
            return []

        pieces = []
        for symbol in text:
            pieces.append(symbol if symbol in self.vocab else self.unk_token)

        for left, right in self.merges:
            pieces = self._apply_merge(pieces, left, right)

        return pieces

    def _apply_merge(self, tokens, left, right):
        """Return a new list where each adjacent ``left, right`` pair is fused.

        The scan runs left to right and a fused pair consumes both positions,
        so matches never overlap.
        """
        merged = []
        last = len(tokens) - 1
        position = 0
        while position <= last:
            is_pair = (
                position < last
                and tokens[position] == left
                and tokens[position + 1] == right
            )
            if is_pair:
                merged.append(left + right)
                position += 2
                continue
            merged.append(tokens[position])
            position += 1
        return merged

    def _convert_token_to_id(self, token):
        """Map a token to its id, falling back to the unknown token's id (or 0)."""
        fallback = self.vocab.get(self.unk_token, 0)
        return self.vocab.get(token, fallback)

    def _convert_id_to_token(self, index):
        """Map an id to its token, falling back to the unknown token."""
        fallback = self.unk_token
        return self.id2token.get(index, fallback)

    def convert_tokens_to_string(self, tokens):
        """Concatenate tokens without any separator."""
        joined = ''.join(tokens)
        return joined

class ExtendedTokenizer:
    """Tokenizer that routes text to either a base tokenizer or a BPE tokenizer.

    Each input string is classified by ``_classify_text_type``; text classified
    as ``"bpe"`` is tokenized with the BPE tokenizer (when all resulting tokens
    exist in the combined vocabulary), everything else goes through the base
    tokenizer. Ids are always looked up in the combined vocabulary.
    """

    def __init__(self, base_tokenizer: AutoTokenizer, bpe_tokenizer: BPETokenizer, combined_vocab: Dict[str, int]):
        """Store both tokenizers and build the lookup tables.

        Args:
            base_tokenizer: Tokenizer providing ``tokenize``, ``get_vocab``,
                ``convert_tokens_to_string`` and the special-token attributes.
            bpe_tokenizer: Tokenizer providing ``_tokenize`` and ``model_path``.
            combined_vocab: Token -> id mapping covering both tokenizers.
        """
        self.base_tokenizer = base_tokenizer
        self.bpe_tokenizer = bpe_tokenizer
        self.vocab = combined_vocab
        self.id2token = {token_id: token for token, token_id in combined_vocab.items()}

        # Snapshot of the base token set, taken once so that decode() does not
        # rebuild it from the base tokenizer on every call.
        self._base_tokens = frozenset(base_tokenizer.get_vocab())

        self.yoruba_pattern = re.compile(r'[ẹọṣáàéèíìóòúùńǹāēīōū]')
        self.english_pattern = re.compile(r'[a-zA-Z]')

        self.unk_token = base_tokenizer.unk_token
        self.bos_token = base_tokenizer.bos_token
        self.eos_token = base_tokenizer.eos_token
        self.pad_token = base_tokenizer.pad_token

    def __len__(self):
        """Number of entries in the combined vocabulary."""
        return len(self.vocab)

    def get_vocab(self):
        """Return a shallow copy of the combined token -> id mapping."""
        return self.vocab.copy()

    @property
    def vocab_size(self):
        """Number of entries in the combined vocabulary."""
        return len(self.vocab)

    # The numeric defaults below are only used when the special token is
    # missing from the combined vocabulary.
    @property
    def bos_token_id(self):
        """Id of the BOS token (1 if it is not in the vocabulary)."""
        return self.vocab.get(self.bos_token, 1)

    @property
    def eos_token_id(self):
        """Id of the EOS token (2 if it is not in the vocabulary)."""
        return self.vocab.get(self.eos_token, 2)

    @property
    def pad_token_id(self):
        """Id of the PAD token (0 if it is not in the vocabulary)."""
        return self.vocab.get(self.pad_token, 0)

    @property
    def unk_token_id(self):
        """Id of the UNK token (3 if it is not in the vocabulary)."""
        return self.vocab.get(self.unk_token, 3)

    def _classify_text_type(self, text: str) -> str:
        """Return ``"bpe"`` or ``"base"`` depending on the characters in ``text``.

        ``"bpe"`` is returned only when the text contains no ASCII letter and
        more than 30% of its word characters match ``yoruba_pattern``.
        """
        if not text.strip():
            return "base"

        # Keep only word characters; punctuation and whitespace are ignored.
        analyzable_chars = re.sub(r'[^\w]', '', text, flags=re.UNICODE)

        if not analyzable_chars:
            return "base"

        yoruba_chars = len(self.yoruba_pattern.findall(analyzable_chars))
        english_chars = len(self.english_pattern.findall(analyzable_chars))
        total_chars = len(analyzable_chars)

        if yoruba_chars == 0:
            return "base"

        # NOTE(review): a single ASCII letter anywhere in the text forces the
        # base route, and yoruba_pattern lists lowercase characters only —
        # confirm this is the intended routing for mixed-script input.
        if english_chars == 0 and yoruba_chars / total_chars > 0.3:
            return "bpe"

        return "base"

    def tokenize(self, text: str):
        """Tokenize ``text`` with the tokenizer selected by its classification.

        Returns an empty list for empty input. The BPE result is used only if
        every produced token is present in the combined vocabulary.
        """
        if not text:
            return []

        text_type = self._classify_text_type(text)

        if text_type == "bpe":
            try:
                bpe_tokens = self.bpe_tokenizer._tokenize(text)
                if all(token in self.vocab for token in bpe_tokens):
                    return bpe_tokens
            except Exception:
                # Best-effort routing: any failure in the BPE path falls
                # through to the base tokenizer below.
                pass

        return self.base_tokenizer.tokenize(text)

    def encode(self, text: str, add_special_tokens=True, return_tensors=None):
        """Convert ``text`` to ids from the combined vocabulary.

        Args:
            text: Input string.
            add_special_tokens: When true, prepend BOS and append EOS if those
                tokens exist in the combined vocabulary.
            return_tensors: ``"pt"`` returns a ``torch.Tensor`` with a leading
                batch dimension of 1; any other value returns a plain list.
        """
        tokens = self.tokenize(text)
        unk_id = self.unk_token_id
        token_ids = [self.vocab.get(token, unk_id) for token in tokens]

        if add_special_tokens:
            if self.bos_token and self.bos_token in self.vocab:
                token_ids = [self.vocab[self.bos_token]] + token_ids
            if self.eos_token and self.eos_token in self.vocab:
                token_ids = token_ids + [self.vocab[self.eos_token]]

        if return_tensors == "pt":
            return torch.tensor([token_ids])

        return token_ids

    def decode(self, token_ids, skip_special_tokens=True):
        """Convert ids back to text.

        Args:
            token_ids: A single id, a sequence of ids, a nested sequence (only
                the first row is decoded) or any object exposing ``tolist()``.
            skip_special_tokens: Drop BOS/EOS/PAD tokens from the output.

        Returns:
            The decoded string; an empty string when nothing is left to decode.
        """
        if hasattr(token_ids, 'tolist'):
            token_ids = token_ids.tolist()

        # A bare id (e.g. from a 0-d tensor) is treated as a one-element sequence.
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        # Guard the [0] access so that an empty sequence does not raise.
        if len(token_ids) > 0 and isinstance(token_ids[0], (list, tuple)):
            token_ids = token_ids[0]

        special_tokens = (self.bos_token, self.eos_token, self.pad_token)
        tokens = []
        for token_id in token_ids:
            token = self.id2token.get(token_id, self.unk_token)
            if not skip_special_tokens or token not in special_tokens:
                tokens.append(token)

        if not tokens:
            return ''

        non_base_tokens = sum(1 for token in tokens if token not in self._base_tokens)

        # Majority vote: mostly non-base tokens are joined verbatim, otherwise
        # the base tokenizer's own detokenization is used.
        if non_base_tokens > len(tokens) * 0.5:
            return ''.join(tokens)
        return self.base_tokenizer.convert_tokens_to_string(tokens)

    def convert_ids_to_tokens(self, token_id):
        """Map one id to its token, or a sequence of ids to a list of tokens.

        Unknown ids map to the unknown token.
        """
        if hasattr(token_id, 'tolist'):
            token_id = token_id.tolist()
        if isinstance(token_id, (list, tuple)):
            return [self.id2token.get(single_id, self.unk_token) for single_id in token_id]
        return self.id2token.get(token_id, self.unk_token)

    def convert_tokens_to_ids(self, tokens):
        """Map one token to its id, or an iterable of tokens to a list of ids.

        Unknown tokens map to ``unk_token_id``.
        """
        if isinstance(tokens, str):
            return self.vocab.get(tokens, self.unk_token_id)
        return [self.vocab.get(token, self.unk_token_id) for token in tokens]

    def save_pretrained(self, save_directory: str):
        """Write the combined vocabulary, a config file and the BPE model file.

        Creates ``save_directory`` if needed, then writes ``vocab.json`` and
        ``tokenizer_config.json`` and copies the BPE model file(s) found in
        ``bpe_tokenizer.model_path``.
        """
        save_dir = Path(save_directory)
        save_dir.mkdir(parents=True, exist_ok=True)

        with open(save_dir / 'vocab.json', 'w', encoding='utf-8') as f:
            json.dump(self.vocab, f, indent=2, ensure_ascii=False)

        config = {
            'tokenizer_class': 'ExtendedTokenizer',
            'vocab_size': len(self.vocab),
            'base_vocab_size': len(self.base_tokenizer.get_vocab()),
            'strategy': 'content_routing',
            'language': 'yoruba',
            'special_tokens': {
                'unk_token': self.unk_token,
                'bos_token': self.bos_token,
                'eos_token': self.eos_token,
                'pad_token': self.pad_token
            }
        }

        with open(save_dir / 'tokenizer_config.json', 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2, ensure_ascii=False)

        import shutil
        # BPETokenizer loads 'model.json', so that file must be copied for the
        # saved directory to be reloadable; the legacy 'simple_bpe_model.json'
        # name is still copied when present.
        for filename in ('model.json', 'simple_bpe_model.json'):
            bpe_source = Path(self.bpe_tokenizer.model_path) / filename
            if not bpe_source.exists():
                continue
            bpe_target = save_dir / filename
            # Saving into the model directory itself would make copy2 raise.
            if bpe_source.resolve() == bpe_target.resolve():
                continue
            shutil.copy2(bpe_source, bpe_target)

        print(f"Tokenizer saved to {save_directory}")

def create_combined_vocabulary(base_tokenizer: AutoTokenizer, bpe_tokenizer: BPETokenizer) -> Dict[str, int]:
    """Merge the BPE vocabulary into a copy of the base vocabulary.

    Base tokens keep their ids. BPE tokens that are not already present are
    appended in BPE-vocabulary iteration order, with ids counting up from the
    number of base tokens. Progress statistics are printed along the way.

    Args:
        base_tokenizer: Object whose ``get_vocab()`` returns token -> id.
        bpe_tokenizer: Object whose ``get_vocab()`` returns token -> id; only
            its keys are used.

    Returns:
        The combined token -> id mapping.
    """
    print("Creating combined vocabulary...")

    combined_vocab = base_tokenizer.get_vocab().copy()
    base_size = len(combined_vocab)
    bpe_vocab = bpe_tokenizer.get_vocab()

    # Collected before anything is added, so it reflects the true overlap.
    overlapping_tokens = [token for token in bpe_vocab if token in combined_vocab]

    print(f"Overlapping tokens between base and BPE: {len(overlapping_tokens):,}")
    if len(overlapping_tokens) > 20:
        print(f"  Sample overlapping tokens: {overlapping_tokens[:10]} ...")
    else:
        print(f"  Sample overlapping tokens: {overlapping_tokens[:10]}")

    next_id = base_size
    for token in bpe_vocab:
        if token in combined_vocab:
            continue
        combined_vocab[token] = next_id
        next_id += 1
    new_tokens_added = next_id - base_size

    print(f"Base tokens (preserved): {base_size:,}")
    print(f"New BPE tokens added: {new_tokens_added:,}")
    print(f"Combined vocabulary: {len(combined_vocab):,}")

    return combined_vocab

def embedding_initialization(
    source_model: AutoModelForCausalLM,
    source_tokenizer: AutoTokenizer,
    target_tokenizer: ExtendedTokenizer,
    alignment_multiple: int = 8,
    tie_word_embeddings: bool = False
) -> AutoModelForCausalLM:
    """Resize the model to the target vocabulary and randomly re-draw its embeddings.

    Every row of the input embedding matrix (and, unless
    ``tie_word_embeddings`` is set, of the output projection) is replaced by a
    sample from a per-dimension normal distribution whose mean and standard
    deviation are taken from the model's existing weights. Existing rows are
    NOT preserved.

    Args:
        source_model: Model to modify in place.
        source_tokenizer: Unused; kept for interface compatibility.
        target_tokenizer: Only ``len(target_tokenizer)`` is used.
        alignment_multiple: The new vocabulary size is rounded up to a
            multiple of this value.
        tie_word_embeddings: When true, only the input embeddings are re-drawn
            and ``source_model.tie_weights()`` is called afterwards.

    Returns:
        The same ``source_model`` object, modified in place.
    """
    print("Embedding initialization...")

    aligned_vocab_size = round_to_nearest_multiple(len(target_tokenizer), alignment_multiple)

    def _sample_like(weight):
        """Draw an (aligned_vocab_size, dim) array matching ``weight``'s column statistics."""
        # .cpu() is required because Tensor.numpy() rejects accelerator
        # tensors; .float() because numpy has no bfloat16 dtype.
        reference = weight.detach().cpu().float().numpy()
        return np.random.normal(
            np.mean(reference, axis=0),
            np.std(reference, axis=0),
            (aligned_vocab_size, reference.shape[1])
        )

    def _overwrite(weight, values):
        """Replace ``weight``'s data with ``values``, keeping its dtype and device."""
        # np.random.normal yields float64 host arrays; without this cast the
        # layer would silently become float64-on-CPU.
        weight.data = torch.from_numpy(values).to(device=weight.device, dtype=weight.dtype)

    # Statistics are taken BEFORE resizing so that rows added by the resize
    # do not influence them. The draw order (input first, then head) is kept.
    target_embeddings = _sample_like(source_model.get_input_embeddings().weight)

    if not tie_word_embeddings:
        print("You are using the output projection init.")
        target_head_embeddings = _sample_like(source_model.get_output_embeddings().weight)

    source_model.resize_token_embeddings(aligned_vocab_size, pad_to_multiple_of=alignment_multiple)
    _overwrite(source_model.get_input_embeddings().weight, target_embeddings)
    source_model.config.vocab_size = aligned_vocab_size

    if not tie_word_embeddings:
        # NOTE(review): if the checkpoint shares a single Parameter between
        # input and output embeddings, this assignment also overwrites the
        # input embeddings written above — confirm for the model in use.
        _overwrite(source_model.get_output_embeddings().weight, target_head_embeddings)
        source_model.config.tie_word_embeddings = False
    else:
        source_model.tie_weights()

    print("Initialization completed")
    return source_model

def create_gemma_yoruba_model(
    source_model_name: str,
    bpe_model_path: str,
    hf_token: Optional[str] = None,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    tie_word_embeddings: bool = False
) -> Tuple[AutoModelForCausalLM, ExtendedTokenizer]:
    """Load a base model, extend its vocabulary with BPE tokens and re-initialize embeddings.

    Args:
        source_model_name: Model identifier passed to ``from_pretrained``.
        bpe_model_path: Directory handed to ``BPETokenizer``.
        hf_token: Optional access token forwarded to ``from_pretrained``.
        device: Device the finished model is moved to; ``"cpu"`` leaves it
            where it was loaded. The default is evaluated once, at import time.
        tie_word_embeddings: Forwarded to ``embedding_initialization``.

    Returns:
        ``(extended_model, extended_tokenizer)``.
    """
    print("Loading base model and tokenizer...")
    source_model = AutoModelForCausalLM.from_pretrained(
        source_model_name,
        token=hf_token,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True
    )
    source_tokenizer = AutoTokenizer.from_pretrained(source_model_name, token=hf_token)

    print("Loading BPE tokenizer...")
    bpe_tokenizer = BPETokenizer(model_path=bpe_model_path)

    combined_vocab = create_combined_vocabulary(source_tokenizer, bpe_tokenizer)

    extended_tokenizer = ExtendedTokenizer(
        source_tokenizer, bpe_tokenizer, combined_vocab
    )

    # The embedding initialization works through numpy, which only handles
    # host tensors, so it runs while the model is still on the CPU.
    extended_model = embedding_initialization(
        source_model, source_tokenizer, extended_tokenizer, tie_word_embeddings=tie_word_embeddings
    )

    # Move to the requested device only once the weights are final.
    if device != "cpu":
        extended_model = extended_model.to(device)

    return extended_model, extended_tokenizer

if __name__ == "__main__":
    # Script entry point: build the extended model/tokenizer, save both to
    # CONFIG['output_path'] and write a small JSON summary next to them.
    CONFIG = {
        'source_model': "google/gemma-7b",
        'bpe_model_path': "/content/drive/My Drive/Colab Notebooks/LRLs/yoruba/tokenizers/grapheme_picky_bpe",
        'output_path': "/content/drive/My Drive/Colab Notebooks/LRLs/yoruba/initialized/gemma_bpe_random_init",
        'device': "cuda" if torch.cuda.is_available() else "cpu"
    }

    # Single source of truth for the tie flag, so the value passed to the
    # builder and the value recorded in extension_config.json cannot drift.
    tie_word_embeddings = False

    try:
        hf_token = getpass.getpass("Enter HuggingFace token: ")

        extended_model, extended_tokenizer = create_gemma_yoruba_model(
            source_model_name=CONFIG['source_model'],
            bpe_model_path=CONFIG['bpe_model_path'],
            hf_token=hf_token,
            device=CONFIG['device'],
            tie_word_embeddings=tie_word_embeddings
        )

        print(f"\nSaving to {CONFIG['output_path']}")
        output_dir = Path(CONFIG['output_path'])
        output_dir.mkdir(parents=True, exist_ok=True)

        extended_model.save_pretrained(CONFIG['output_path'])
        extended_tokenizer.save_pretrained(CONFIG['output_path'])

        extension_config = {
            'approach': 'random',
            'language': 'yoruba',
            'source_model': CONFIG['source_model'],
            'bpe_model_path': CONFIG['bpe_model_path'],
            'embeddings_tied': tie_word_embeddings,
            'base_embeddings_preserved': False,
            'final_vocab_size': extended_model.config.vocab_size,
            'embedding_dim': extended_model.get_input_embeddings().embedding_dim,
            'total_parameters': sum(p.numel() for p in extended_model.parameters())
        }

        # Explicit encoding, matching the other JSON files written by this file.
        with open(output_dir / 'extension_config.json', 'w', encoding='utf-8') as f:
            json.dump(extension_config, f, indent=2)

        print(f"Final vocabulary: {extended_model.config.vocab_size:,} tokens")

    except Exception as e:
        # Top-level boundary: report the failure with its traceback instead of
        # letting the exception propagate out of the script.
        print(f"Error: {e}")
        traceback.print_exc()