In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import json
import logging
import sys
import pickle
import hashlib
import getpass
import math
import re
from pathlib import Path
from datetime import datetime
from tqdm.notebook import tqdm
from typing import Dict, Optional, Tuple, List, Union
from functools import lru_cache
import numpy as np
import torch
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils import clip_grad_norm_
from transformers import (
    GemmaForCausalLM,
    AutoTokenizer,
    PreTrainedTokenizer,
    get_linear_schedule_with_warmup,
    set_seed
)
from transformers.tokenization_utils import AddedToken
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType,
    PeftModel
)

# Send INFO-and-above log records to stdout using a timestamped format.
# force=True makes basicConfig replace any handlers already attached to the
# root logger instead of silently doing nothing.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True
)
# Module-level logger used by every class and function below.
logger = logging.getLogger(__name__)

class BPETokenizer(PreTrainedTokenizer):
    """Character-level BPE tokenizer backed by a ``simple_bpe_model.json`` file.

    The model file is expected to contain the keys ``vocab`` (token -> id),
    ``id2token`` (id-as-string -> token), ``merges`` (a list of objects with
    ``left``/``right`` keys) and, optionally, ``special_tokens``.

    Text is split into single characters (characters missing from the
    vocabulary become the unknown token) and adjacent pairs listed in
    ``merges`` are then joined by repeated left-to-right scans.
    """

    def __init__(self, model_path: str, **kwargs):
        """Load the BPE model found in the directory *model_path*.

        Raises:
            FileNotFoundError: if ``simple_bpe_model.json`` is missing.
        """
        self.model_path = Path(model_path)
        self.model_data = self._load_model_data()
        self.vocab = self.model_data['vocab']
        self.id2token = {int(k): v for k, v in self.model_data['id2token'].items()}
        self.merges = [(m['left'], m['right']) for m in self.model_data['merges']]
        self._special_tokens = self.model_data.get('special_tokens', {})

        # Bounded memo of text -> token list; entries are never evicted, the
        # cache simply stops growing once it is full.
        self._token_cache = {}
        self._max_cache_size = 10000
        # Set form of the merge list for O(1) pair membership tests.
        self._merge_pairs = set(self.merges)

        # The vocabulary attributes are assigned before calling the base
        # initializer so they already exist if it queries the vocabulary.
        super().__init__(
            unk_token=AddedToken('<unk>', lstrip=False, rstrip=False),
            bos_token=AddedToken('<s>', lstrip=False, rstrip=False),
            eos_token=AddedToken('</s>', lstrip=False, rstrip=False),
            pad_token=AddedToken('<pad>', lstrip=False, rstrip=False),
            **kwargs
        )

        logger.info(f"BPE tokenizer loaded: {len(self.vocab):,} tokens, {len(self.merges):,} merges")

    def _load_model_data(self):
        """Read and return the parsed JSON model from ``self.model_path``."""
        model_file = self.model_path / 'simple_bpe_model.json'
        if not model_file.exists():
            raise FileNotFoundError(f"BPE model not found: {model_file}")

        with open(model_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    @property
    def vocab_size(self):
        """Number of entries in the BPE vocabulary."""
        return len(self.vocab)

    def get_vocab(self):
        """Return a shallow copy of the token -> id mapping."""
        return self.vocab.copy()

    def _tokenize(self, text: str, **kwargs):
        """Tokenize *text* into BPE tokens, memoizing results per input string.

        A fresh list is returned on every call so callers may mutate the
        result without corrupting the cached entry.
        """
        if not text:
            return []

        cached = self._token_cache.get(text)
        if cached is not None:
            return list(cached)

        if len(text) > 5000:
            result = self._tokenize_long_text(text)
        else:
            result = self._tokenize_standard(text)

        if len(self._token_cache) < self._max_cache_size:
            # Store a private copy; the caller owns `result`.
            self._token_cache[text] = list(result)

        return result

    def _tokenize_standard(self, text: str):
        """Split *text* into characters (unknown -> unk token) and apply merges."""
        vocab = self.vocab
        # `unk_token` is a property on the base class; read it once rather
        # than once per out-of-vocabulary character.
        unk = self.unk_token
        tokens = [char if char in vocab else unk for char in text]
        return self._apply_all_merges(tokens)

    def _tokenize_long_text(self, text: str):
        """Tokenize *text* in independent 1000-character chunks.

        Merges are never applied across a chunk boundary, so the result can
        differ from tokenizing the whole string at once.
        """
        chunk_size = 1000
        tokens = []

        for i in range(0, len(text), chunk_size):
            tokens.extend(self._tokenize_standard(text[i:i + chunk_size]))

        return tokens

    def _apply_all_merges(self, tokens):
        """Merge adjacent pairs found in ``self._merge_pairs``; mutates and returns *tokens*.

        Each pass scans left to right and joins any adjacent pair present in
        the merge set (the order of ``self.merges`` is not used as a
        priority). Passes repeat until nothing changes or ``len(self.merges)``
        passes have run.
        """
        changed = True
        iteration_count = 0
        max_iterations = len(self.merges)

        while changed and iteration_count < max_iterations:
            changed = False
            iteration_count += 1

            i = 0
            while i < len(tokens) - 1:
                current_pair = (tokens[i], tokens[i + 1])
                if current_pair in self._merge_pairs:
                    merged_token = tokens[i] + tokens[i + 1]
                    tokens[i:i + 2] = [merged_token]
                    changed = True
                    # Stay at the same index: the merged token may merge
                    # again with its new right neighbour.
                else:
                    i += 1

        return tokens

    # No lru_cache here: both lookups are plain O(1) dict reads, and caching a
    # bound method would additionally keep a reference to the instance.
    def _convert_token_to_id(self, token):
        """Map *token* to its id, falling back to the unk id (or 0)."""
        return self.vocab.get(token, self.vocab.get(self.unk_token, 0))

    def _convert_id_to_token(self, index):
        """Map *index* to its token, falling back to the unk token."""
        return self.id2token.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Concatenate *tokens* without any separator."""
        return ''.join(tokens)

class ExtendedTokenizer:
    """Tokenizer that routes text to a base tokenizer or a BPE tokenizer.

    Both tokenizers share one combined vocabulary (*combined_vocab*,
    token -> id). Ids below ``base_vocab_size`` are treated as pretrained
    ("base") tokens and the remaining ids as newly added tokens.
    """

    def __init__(self, base_tokenizer: "AutoTokenizer", bpe_tokenizer: "BPETokenizer", combined_vocab: Dict[str, int]):
        self.base_tokenizer = base_tokenizer
        self.bpe_tokenizer = bpe_tokenizer
        self.vocab = combined_vocab
        self.id2token = {token_id: token for token, token_id in combined_vocab.items()}

        base_vocab = base_tokenizer.get_vocab()
        self.base_vocab_size = len(base_vocab)
        # Snapshot of the base token strings, built once so decode() does not
        # have to rebuild a set of the whole base vocabulary on every call.
        self._base_vocab_tokens = frozenset(base_vocab)

        # Bounded memo tables; they stop growing once full (no eviction).
        self._encode_cache = {}
        self._tokenize_cache = {}
        self._max_cache_size = 10000

        # Characters used to detect Yoruba text, and plain ASCII letters.
        self.yoruba_pattern = re.compile(r'[ẹọṣáàéèíìóòúùńǹāēīōū]')
        self.english_pattern = re.compile(r'[a-zA-Z]')

        self.unk_token = base_tokenizer.unk_token
        self.bos_token = base_tokenizer.bos_token
        self.eos_token = base_tokenizer.eos_token
        self.pad_token = '<pad>'

    def __len__(self):
        return len(self.vocab)

    def get_vocab(self):
        """Return a shallow copy of the combined token -> id mapping."""
        return self.vocab.copy()

    @property
    def vocab_size(self):
        return len(self.vocab)

    @property
    def bos_token_id(self):
        return self.vocab.get(self.bos_token, 1)

    @property
    def eos_token_id(self):
        return self.vocab.get(self.eos_token, 2)

    @property
    def pad_token_id(self):
        return self.vocab.get(self.pad_token, 32000)

    @property
    def unk_token_id(self):
        return self.vocab.get(self.unk_token, 0)

    @property
    def pretrained_token_ids(self):
        """Set of ids ``0 .. base_vocab_size - 1``."""
        return set(range(self.base_vocab_size))

    @property
    def new_token_ids(self):
        """Set of ids ``base_vocab_size .. len(vocab) - 1``."""
        return set(range(self.base_vocab_size, len(self.vocab)))

    def _classify_text_type(self, text: str) -> str:
        """Return ``"bpe"`` for Yoruba-looking text, otherwise ``"base"``.

        Text counts as ``"bpe"`` when it contains at least one character
        matched by ``yoruba_pattern`` and fewer than 80% of its word
        characters are ASCII letters.
        """
        if not text.strip():
            return "base"

        # Drop everything that is not a word character before counting.
        analyzable_chars = re.sub(r'[^\w]', '', text, flags=re.UNICODE)
        if not analyzable_chars:
            return "base"

        yoruba_chars = len(self.yoruba_pattern.findall(analyzable_chars))
        english_chars = len(self.english_pattern.findall(analyzable_chars))
        english_ratio = english_chars / len(analyzable_chars)

        if yoruba_chars > 0 and english_ratio < 0.8:
            return "bpe"

        return "base"

    def tokenize(self, text: str):
        """Split *text* into token strings using the tokenizer chosen for it.

        BPE output is only accepted when every produced token exists in the
        combined vocabulary; otherwise (or if the BPE tokenizer raises) the
        base tokenizer is used.
        """
        if not text:
            return []

        if text in self._tokenize_cache:
            return self._tokenize_cache[text].copy()

        text_type = self._classify_text_type(text)

        tokens = []

        if text_type == "bpe":
            try:
                bpe_tokens = self.bpe_tokenizer._tokenize(text)
                if all(token in self.vocab for token in bpe_tokens):
                    tokens = bpe_tokens
                else:
                    tokens = self.base_tokenizer.tokenize(text)
            except Exception:
                # Deliberate best-effort: any BPE failure falls back to the
                # base tokenizer.
                tokens = self.base_tokenizer.tokenize(text)
        else:
            tokens = self.base_tokenizer.tokenize(text)

        if len(self._tokenize_cache) < self._max_cache_size:
            self._tokenize_cache[text] = tokens.copy()

        return tokens

    def encode(self, text: str, add_special_tokens=True, return_tensors=None, truncation=None, max_length=None, padding=False):
        """Convert *text* to token ids.

        Args:
            add_special_tokens: wrap the ids in BOS/EOS when those tokens
                exist in the vocabulary.
            return_tensors: ``"pt"`` returns a ``(1, n)`` tensor; anything
                else returns a plain list.
            truncation, max_length: when both are truthy the id list is cut
                to ``max_length`` (this can cut off the trailing EOS).
            padding: accepted but not used by this method.
        """
        cache_key = (text, add_special_tokens, truncation, max_length)

        if cache_key in self._encode_cache:
            token_ids = self._encode_cache[cache_key].copy()
        else:
            tokens = self.tokenize(text)
            token_ids = [self.vocab.get(token, self.unk_token_id) for token in tokens]

            if add_special_tokens:
                if self.bos_token and self.bos_token in self.vocab:
                    token_ids = [self.vocab[self.bos_token]] + token_ids
                if self.eos_token and self.eos_token in self.vocab:
                    token_ids = token_ids + [self.vocab[self.eos_token]]

            if truncation and max_length and len(token_ids) > max_length:
                token_ids = token_ids[:max_length]

            if len(self._encode_cache) < self._max_cache_size:
                self._encode_cache[cache_key] = token_ids.copy()

        if return_tensors == "pt":
            return torch.tensor([token_ids])

        return token_ids

    def decode(self, token_ids, skip_special_tokens=True):
        """Convert token ids back to text.

        Accepts a list of ids, a nested ``[[...]]`` batch of one, a single
        int, or any object with ``tolist()``. Empty input yields ``""``.
        When more than half of the decoded tokens are not base-vocabulary
        tokens the tokens are concatenated directly; otherwise the base
        tokenizer's ``convert_tokens_to_string`` is used.
        """
        if hasattr(token_ids, 'tolist'):
            token_ids = token_ids.tolist()

        # A 0-d tensor's tolist() (or a direct call) can yield a bare int.
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        # Guard before indexing: an empty sequence used to raise IndexError.
        if not token_ids:
            return ""

        if isinstance(token_ids[0], list):
            token_ids = token_ids[0]

        special_tokens = {self.bos_token, self.eos_token, self.pad_token}
        tokens = []
        for token_id in token_ids:
            token = self.id2token.get(token_id, self.unk_token)
            if not skip_special_tokens or token not in special_tokens:
                tokens.append(token)

        base_vocab = self._base_vocab_tokens
        non_base_tokens = sum(1 for token in tokens if token not in base_vocab)

        if non_base_tokens > len(tokens) * 0.5:
            return ''.join(tokens)
        else:
            return self.base_tokenizer.convert_tokens_to_string(tokens)

    def analyze_token_types(self, token_ids):
        """Return ``(pretrained_count, new_count)`` for *token_ids*."""
        if hasattr(token_ids, 'tolist'):
            token_ids = token_ids.tolist()

        pretrained_count = sum(1 for token_id in token_ids if token_id < self.base_vocab_size)
        new_count = len(token_ids) - pretrained_count
        return pretrained_count, new_count

    def save_pretrained(self, save_directory: str):
        """Write ``vocab.json``, ``tokenizer_config.json`` and a copy of the BPE model file."""
        import shutil

        save_dir = Path(save_directory)
        save_dir.mkdir(parents=True, exist_ok=True)

        with open(save_dir / 'vocab.json', 'w', encoding='utf-8') as f:
            json.dump(self.vocab, f, indent=2, ensure_ascii=False)

        config = {
            'tokenizer_class': 'ExtendedTokenizer',
            'vocab_size': len(self.vocab),
            'base_vocab_size': len(self.base_tokenizer.get_vocab()),
            'strategy': 'mean',
            'language': 'yoruba',
            'special_tokens': {
                'unk_token': self.unk_token,
                'bos_token': self.bos_token,
                'eos_token': self.eos_token,
                'pad_token': self.pad_token
            }
        }

        with open(save_dir / 'tokenizer_config.json', 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2, ensure_ascii=False)

        # The BPE model is copied only if it still exists at its source path.
        bpe_source = self.bpe_tokenizer.model_path / 'simple_bpe_model.json'
        if bpe_source.exists():
            shutil.copy2(bpe_source, save_dir / 'simple_bpe_model.json')

class TrainingConfig:
    # Plain namespace of training constants; every value is a class attribute.

    # --- Model / tokenizer / output locations (all under /content/drive) ---
    MODEL_PATH = "/content/drive/My Drive/Colab Notebooks/LRLs/yoruba/initialized/gemma_mean_init"
    TOKENIZER_PATH = "/content/drive/My Drive/Colab Notebooks/LRLs/yoruba/initialized/gemma_mean_init"
    BPE_MODEL_PATH = "/content/drive/My Drive/Colab Notebooks/LRLs/yoruba/tokenizers/grapheme_picky_bpe"
    OUTPUT_DIR = "/content/drive/My Drive/Colab Notebooks/LRLs/yoruba/models/gemma7b_mean_init_lapt"

    # --- Dataset locations and the column holding the raw text ---
    TRAIN_DATASET_PATH = "/content/drive/My Drive/Colab Notebooks/LRLs/yoruba/dataset/dataset_lapt/train"
    VAL_DATASET_PATH = "/content/drive/My Drive/Colab Notebooks/LRLs/yoruba/dataset/dataset_lapt/val"
    TEXT_COLUMN_NAME = "text"

    # --- Tokenized-dataset caching ---
    # NOTE(review): presumably forwarded to TextDataset(use_cache=...,
    # cache_dir=..., force_retokenize=..., max_examples=...) — confirm at the
    # call site. In TextDataset, force_retokenize=True skips reading the cache.
    USE_CACHED_DATASET = True
    CACHE_DIR = "/content/drive/My Drive/Colab Notebooks/LRLs/yoruba/cached_datasets"
    FORCE_RETOKENIZE = True
    MAX_EXAMPLES = 100000

    # --- LoRA settings, read by setup_peft_model() ---
    LORA_R = 16
    LORA_ALPHA = 32
    LORA_DROPOUT = 0.1
    # NOTE(review): "lm_head" and "embed_tokens" are listed here and are also
    # passed as modules_to_save in setup_peft_model() — confirm the overlap is
    # intended for the installed PEFT version.
    LORA_TARGET_MODULES = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
        "lm_head", "embed_tokens"
    ]

    # --- Per-group learning rates, read by setup_multi_optimizer() ---
    PRETRAINED_EMBEDDING_LR = 5e-6
    # NOTE(review): NEW_EMBEDDING_LR is not read by setup_multi_optimizer();
    # check whether it is used elsewhere.
    NEW_EMBEDDING_LR = 2e-5
    LORA_LR = 5e-5
    BASE_MODEL_LR = 5e-5

    # --- Core training hyperparameters ---
    BATCH_SIZE = 4
    GRADIENT_ACCUMULATION_STEPS = 4
    MAX_SEQ_LENGTH = 1024
    WEIGHT_DECAY = 0.01
    MAX_GRAD_NORM = 1.0
    NUM_EPOCHS = 1
    MAX_STEPS = 10000

    # Multiplier on the embedding-norm penalty in compute_regularization_loss().
    EMBEDDING_REGULARIZATION_WEIGHT = 0.0001
    # Factor applied to pretrained-token embedding gradients in
    # apply_gradient_scaling().
    GRADIENT_SCALING_FACTOR = 0.5

    # --- Checkpointing / evaluation cadence ---
    SAVE_STEPS = 200
    EVAL_STEPS = 100
    LOGGING_STEPS = 10
    # save_checkpoint() keeps at most this many checkpoint-* directories.
    SAVE_TOTAL_LIMIT = 3
    EARLY_STOPPING_PATIENCE = 8
    # evaluate_model() stops once batch_idx * EVAL_BATCH_SIZE reaches
    # MAX_EVAL_SAMPLES (when MAX_EVAL_SAMPLES > 0).
    EVAL_BATCH_SIZE = 2
    MAX_EVAL_SAMPLES = 200

    # --- Miscellaneous ---
    SEED = 42
    USE_GRADIENT_CHECKPOINTING = True
    MERGE_AND_SAVE_FINAL = True
    RESUME_FROM_CHECKPOINT = None

class TextDataset(Dataset):
    """Causal-LM dataset of fixed-length token-id examples.

    Texts are read from *data_path* (a Hugging Face dataset saved to disk, a
    directory of ``*.txt`` files, or a directory of ``*.json``/``*.jsonl``
    files), tokenized with *tokenizer* and split into windows of at most
    ``max_length`` ids. Tokenized examples can be cached as a pickle file in
    *cache_dir*.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: "ExtendedTokenizer",
        max_length: int = 512,
        stride: int = 256,
        text_column: str = "text",
        cache_dir: Optional[str] = None,
        use_cache: bool = True,
        force_retokenize: bool = False,
        max_examples: Optional[int] = None
    ):
        """Load (or restore from cache) and tokenize the dataset.

        Args:
            data_path: dataset directory.
            tokenizer: object providing ``encode``, ``analyze_token_types``,
                ``vocab_size``, ``pad_token_id`` and ``eos_token_id``.
            max_length: window size in tokens.
            stride: step between consecutive windows of a long text.
            text_column: preferred name of the text field/column.
            cache_dir: directory for pickle caches (created if missing).
            use_cache: read/write the cache when *cache_dir* is given.
            force_retokenize: skip reading the cache (it is still rewritten).
            max_examples: keep only the first N examples, if set.

        Raises:
            FileNotFoundError: *data_path* does not exist.
            ValueError: no supported data format was found.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.stride = stride
        self.text_column = text_column
        self.examples = []

        logger.info(f"Loading dataset from {data_path}")

        if cache_dir:
            Path(cache_dir).mkdir(parents=True, exist_ok=True)

        cache_loaded = False
        if use_cache and cache_dir and not force_retokenize:
            cache_loaded = self._try_load_from_cache(data_path, cache_dir)

        if not cache_loaded:
            self._load_dataset(data_path)
            if use_cache and cache_dir and self.examples:
                self._save_to_cache(data_path, cache_dir)

        # Truncation happens after caching, so the cache always holds the
        # full example list.
        if max_examples and len(self.examples) > max_examples:
            self.examples = self.examples[:max_examples]

        logger.info(f"Loaded {len(self.examples)} examples")

    def _get_cache_key(self, data_path: str) -> str:
        """Hash of the path, window settings and vocab size (cache file identity)."""
        cache_string = f"{data_path}_{self.max_length}_{self.stride}_{self.tokenizer.vocab_size}"
        return hashlib.md5(cache_string.encode()).hexdigest()

    def _try_load_from_cache(self, data_path: str, cache_dir: str) -> bool:
        """Populate ``self.examples`` from the cache; return True on success."""
        try:
            cache_key = self._get_cache_key(data_path)
            cache_path = Path(cache_dir) / f"{Path(data_path).name}_{cache_key}.pkl"

            if not cache_path.exists():
                return False

            logger.info(f"Loading from cache: {cache_path}")
            # SECURITY NOTE: pickle.load executes arbitrary code from the
            # file; only point cache_dir at locations you fully control.
            with open(cache_path, 'rb') as f:
                cache_data = pickle.load(f)

            if cache_data.get('vocab_size') != self.tokenizer.vocab_size:
                logger.warning("Cache vocab size mismatch, will retokenize")
                return False

            self.examples = cache_data['examples']
            logger.info(f"Loaded {len(self.examples)} examples from cache")
            return True

        except Exception as e:
            # Best-effort: any cache problem just triggers re-tokenization.
            logger.warning(f"Failed to load cache: {e}")
            return False

    def _save_to_cache(self, data_path: str, cache_dir: str):
        """Pickle ``self.examples`` plus metadata; failures are logged, not raised."""
        try:
            cache_key = self._get_cache_key(data_path)
            cache_path = Path(cache_dir) / f"{Path(data_path).name}_{cache_key}.pkl"

            cache_data = {
                'examples': self.examples,
                'vocab_size': self.tokenizer.vocab_size,
                'max_length': self.max_length,
                'stride': self.stride,
                'timestamp': datetime.now().isoformat()
            }

            with open(cache_path, 'wb') as f:
                pickle.dump(cache_data, f)

            logger.info(f"Saved cache: {cache_path}")
        except Exception as e:
            logger.warning(f"Failed to save cache: {e}")

    def _load_dataset(self, data_path: str):
        """Read texts from *data_path* and tokenize them.

        Tries, in order: a Hugging Face dataset saved with ``save_to_disk``,
        ``*.txt`` files, then ``*.json``/``*.jsonl`` files.
        """
        path = Path(data_path)
        if not path.exists():
            raise FileNotFoundError(f"Dataset path {data_path} does not exist")

        try:
            from datasets import load_from_disk
            dataset = load_from_disk(data_path)
            # Test for a dict of splits first: a DatasetDict also exposes
            # `column_names` (as a per-split mapping), so the attribute check
            # alone cannot tell it apart from a single Dataset.
            if isinstance(dataset, dict):
                split_data = next(iter(dataset.values()))
                self._process_hf_dataset(split_data)
            elif hasattr(dataset, 'column_names'):
                self._process_hf_dataset(dataset)
            return
        except Exception as e:
            # Not fatal: fall through to the raw-file loaders, but leave a
            # trace so a broken saved dataset is not mistaken for "no data".
            logger.info(f"Could not load {data_path} as a saved Hugging Face dataset ({e}); trying raw files")

        text_files = list(path.glob("*.txt"))
        if text_files:
            all_texts = []
            for file_path in text_files:
                with open(file_path, 'r', encoding='utf-8') as f:
                    text = f.read().strip()
                    if text:
                        all_texts.append(text)
            self._tokenize_texts(all_texts)
            return

        json_files = list(path.glob("*.json")) + list(path.glob("*.jsonl"))
        if json_files:
            all_texts = []
            for file_path in json_files:
                with open(file_path, 'r', encoding='utf-8') as f:
                    if file_path.suffix == '.jsonl':
                        for line in f:
                            line = line.strip()
                            # Blank lines (e.g. a trailing newline) are not
                            # valid JSON documents; skip them.
                            if not line:
                                continue
                            data = json.loads(line)
                            text = self._extract_text_from_json(data)
                            if text:
                                all_texts.append(text)
                    else:
                        data = json.load(f)
                        items = data if isinstance(data, list) else [data]
                        for item in items:
                            text = self._extract_text_from_json(item)
                            if text:
                                all_texts.append(text)
            self._tokenize_texts(all_texts)
            return

        raise ValueError(f"Could not load dataset from {data_path}")

    def _extract_text_from_json(self, data: dict) -> str:
        """Return the text carried by one JSON record, or ``""`` if none.

        A bare string record is returned stripped. For a dict, the configured
        text column and a few common key names are tried first, then the
        first non-empty string value. Any other record type yields ``""``.
        """
        if isinstance(data, str):
            return data.strip()
        if not isinstance(data, dict):
            return ""

        possible_keys = [self.text_column, 'text', 'content', 'document', 'sentence']
        for key in possible_keys:
            if key in data and data[key]:
                return str(data[key]).strip()

        for value in data.values():
            if isinstance(value, str) and value.strip():
                return value.strip()
        return ""

    def _process_hf_dataset(self, dataset):
        """Tokenize the text column of a Hugging Face dataset split.

        Raises:
            ValueError: none of the candidate column names is present.
        """
        text_column = None
        possible_columns = [self.text_column, 'text', 'content', 'document']

        for col in possible_columns:
            if col in dataset.column_names:
                text_column = col
                break

        if text_column is None:
            raise ValueError(f"No text column found in {dataset.column_names}")

        texts = [item[text_column] for item in dataset if item.get(text_column)]
        self._tokenize_texts(texts)

    def _tokenize_texts(self, texts: List[str]):
        """Encode *texts* and append the resulting windows to ``self.examples``.

        Texts shorter than 10 characters, or encoding to fewer than 10
        tokens, are dropped. Texts longer than ``max_length`` tokens are cut
        into overlapping windows advanced by ``stride``.
        """
        logger.info(f"Tokenizing {len(texts)} texts...")

        valid_texts = []
        for text in texts:
            text = text.strip() if text else ""
            if len(text) >= 10:
                valid_texts.append(text)

        if not valid_texts:
            logger.warning("No valid texts found for tokenization")
            return

        logger.info(f"Processing {len(valid_texts)} valid texts after filtering")

        batch_size = 100
        total_batches = (len(valid_texts) + batch_size - 1) // batch_size

        processed_count = 0
        error_count = 0

        batch_progress = tqdm(range(total_batches), desc="Tokenizing batches", unit="batch")

        for batch_idx in batch_progress:
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, len(valid_texts))
            batch_texts = valid_texts[start_idx:end_idx]

            batch_examples = []

            for text in batch_texts:
                try:
                    tokens = self.tokenizer.encode(text, add_special_tokens=True, truncation=False)
                    if len(tokens) < 10:
                        continue

                    pretrained_count, new_count = self.tokenizer.analyze_token_types(tokens)

                    if len(tokens) <= self.max_length:
                        batch_examples.append({
                            'tokens': tokens,
                            'pretrained_count': pretrained_count,
                            'new_count': new_count
                        })
                    else:
                        # NOTE(review): only full-length windows are emitted,
                        # so up to stride - 1 trailing tokens of a long text
                        # are not covered by any window.
                        for i in range(0, len(tokens) - self.max_length + 1, self.stride):
                            chunk = tokens[i:i + self.max_length]
                            chunk_pretrained, chunk_new = self.tokenizer.analyze_token_types(chunk)
                            batch_examples.append({
                                'tokens': chunk,
                                'pretrained_count': chunk_pretrained,
                                'new_count': chunk_new
                            })

                    processed_count += 1

                except Exception as e:
                    # Best-effort per text; only the first few failures are
                    # logged individually to keep the output readable.
                    error_count += 1
                    if error_count <= 5:
                        logger.warning(f"Error tokenizing text {processed_count + error_count}: {str(e)[:100]}")

            self.examples.extend(batch_examples)

            batch_progress.set_postfix({
                'examples': len(self.examples),
                'processed': processed_count,
                'errors': error_count
            })

        logger.info(f"Tokenization complete: {len(self.examples)} examples created from {processed_count} texts")
        if error_count > 0:
            logger.warning(f"Encountered {error_count} tokenization errors")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """Return one example padded/truncated to exactly ``max_length`` ids.

        Padding positions get attention mask 0 and label -100.
        """
        example = self.examples[idx]
        tokens = example['tokens']

        if len(tokens) > self.max_length:
            tokens = tokens[:self.max_length]
        elif len(tokens) < self.max_length:
            # Compare against None explicitly: a pad id of 0 is valid and
            # must not fall through to the EOS id.
            pad_token = self.tokenizer.pad_token_id
            if pad_token is None:
                pad_token = self.tokenizer.eos_token_id
            tokens = tokens + [pad_token] * (self.max_length - len(tokens))

        input_ids = torch.tensor(tokens, dtype=torch.long)
        labels = input_ids.clone()
        attention_mask = torch.ones_like(input_ids)

        original_length = len(example['tokens'])
        if original_length < self.max_length:
            attention_mask[original_length:] = 0
            labels[original_length:] = -100

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'pretrained_count': example.get('pretrained_count', 0),
            'new_count': example.get('new_count', 0)
        }

def collate_fn(batch):
    """Stack a list of dataset items into batch tensors.

    Every item is right-padded to the longest ``input_ids`` in the batch:
    ids and attention mask with 0, labels with -100. Missing
    ``pretrained_count``/``new_count`` entries default to 0.
    """
    target_len = max(len(sample['input_ids']) for sample in batch)

    ids_rows, mask_rows, label_rows = [], [], []
    for sample in batch:
        ids = sample['input_ids']
        mask = sample['attention_mask']
        lbls = sample['labels']

        shortfall = target_len - len(ids)
        if shortfall > 0:
            # Only shorter rows are touched; full-length rows pass through.
            ids = torch.cat([ids, torch.zeros(shortfall, dtype=torch.long)])
            mask = torch.cat([mask, torch.zeros(shortfall, dtype=torch.long)])
            lbls = torch.cat([lbls, torch.full((shortfall,), -100, dtype=torch.long)])

        ids_rows.append(ids)
        mask_rows.append(mask)
        label_rows.append(lbls)

    pretrained_counts = [sample.get('pretrained_count', 0) for sample in batch]
    new_counts = [sample.get('new_count', 0) for sample in batch]

    return {
        'input_ids': torch.stack(ids_rows),
        'attention_mask': torch.stack(mask_rows),
        'labels': torch.stack(label_rows),
        'pretrained_counts': torch.tensor(pretrained_counts),
        'new_counts': torch.tensor(new_counts)
    }

def setup_peft_model(model, config: TrainingConfig):
    """Wrap *model* with a LoRA adapter built from *config* and log parameter counts.

    Returns the object produced by ``get_peft_model``.
    """
    # NOTE(review): "embed_tokens" and "lm_head" are passed as modules_to_save
    # here and also appear in config.LORA_TARGET_MODULES; how PEFT treats a
    # module named in both lists should be confirmed for the installed version.
    adapter_settings = {
        'r': config.LORA_R,
        'lora_alpha': config.LORA_ALPHA,
        'target_modules': config.LORA_TARGET_MODULES,
        'lora_dropout': config.LORA_DROPOUT,
        'bias': "none",
        'task_type': TaskType.CAUSAL_LM,
        'modules_to_save': ["embed_tokens", "lm_head"],
    }
    model = get_peft_model(model, LoraConfig(**adapter_settings))

    # Count trainable and total parameters.
    trainable_params = 0
    total_params = 0
    for param in model.parameters():
        size = param.numel()
        total_params += size
        if param.requires_grad:
            trainable_params += size

    logger.info(f"Trainable params: {trainable_params:,}")
    logger.info(f"Total params: {total_params:,}")
    logger.info(f"Trainable %: {100 * trainable_params / total_params:.2f}")

    return model

def setup_multi_optimizer(model, tokenizer: ExtendedTokenizer, config: TrainingConfig):
    """Build an AdamW optimizer with separate groups for embeddings, LoRA and other weights.

    Trainable parameters are bucketed by name: anything containing "lora"
    goes to the LoRA group, anything containing "embed_tokens" or "lm_head"
    goes to the embeddings group (half weight decay), the rest to the base
    group. Empty groups are omitted. *tokenizer* is accepted but not used.
    """
    buckets = {'embeddings': [], 'lora': [], 'base': []}

    for param_name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        lowered = param_name.lower()
        if 'lora' in lowered:
            destination = 'lora'
        elif 'embed_tokens' in lowered or 'lm_head' in lowered:
            destination = 'embeddings'
        else:
            destination = 'base'
        buckets[destination].append(param)

    param_groups = []

    def _register(label, lr_attr, halve_decay=False):
        """Append the group for *label* if it holds any parameters."""
        members = buckets[label]
        if not members:
            return
        decay = config.WEIGHT_DECAY * 0.5 if halve_decay else config.WEIGHT_DECAY
        param_groups.append({
            'params': members,
            'lr': getattr(config, lr_attr),
            'weight_decay': decay,
            'name': label
        })

    # Registration order fixes the order of the optimizer's param groups.
    _register('embeddings', 'PRETRAINED_EMBEDDING_LR', halve_decay=True)
    _register('lora', 'LORA_LR')
    _register('base', 'BASE_MODEL_LR')

    optimizer = torch.optim.AdamW(param_groups, eps=1e-8)

    logger.info(f"Optimizer groups: {len(param_groups)}")
    for group in param_groups:
        logger.info(f"  {group['name']}: LR={group['lr']}")

    return optimizer

def test_sample_perplexity(model, tokenizer, device):
    """Log the perplexity of two fixed sentences (one English, one Yoruba).

    The model is put in eval mode for the measurement and is always switched
    to train mode before returning, whatever mode it was in on entry.
    """
    model.eval()

    test_sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "Èdè Yorùbá jẹ́ èdè àbínibí wa lórí ilẹ̀ Yorùbá."
    ]

    logger.info("Testing sample perplexity:")

    with torch.no_grad():
        for sentence in test_sentences:
            encoded = tokenizer.encode(sentence, return_tensors="pt", add_special_tokens=True)

            # Accept either a plain id tensor or an object carrying
            # input_ids / attention_mask attributes.
            if not hasattr(encoded, 'input_ids'):
                input_ids = encoded.to(device)
                attention_mask = torch.ones_like(input_ids)
            else:
                input_ids = encoded.input_ids.to(device)
                attention_mask = encoded.attention_mask.to(device)

            result = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
            ppl = torch.exp(result.loss).item()

            logger.info(f"  '{sentence}' -> PPL: {ppl:.2f}")

    model.train()

def evaluate_model(model, eval_dataloader, device, tokenizer, config):
    """Compute token-weighted loss and perplexity over *eval_dataloader*.

    Args:
        model: causal LM whose forward pass returns an object with ``.loss``.
        eval_dataloader: yields dicts with ``input_ids``, ``attention_mask``,
            ``labels`` and optionally ``pretrained_counts``/``new_counts``.
        device: device that batch tensors are moved to.
        tokenizer: forwarded to ``test_sample_perplexity``.
        config: provides ``MAX_EVAL_SAMPLES`` and ``EVAL_BATCH_SIZE``.

    Returns:
        dict with ``eval_loss``, ``eval_perplexity``, ``total_tokens``,
        ``pretrained_tokens`` and ``new_tokens``. The loss is ``inf`` when
        no tokens were evaluated.

    The model is left in train mode on return.
    """
    test_sample_perplexity(model, tokenizer, device)

    # test_sample_perplexity() switches the model back to train mode before
    # returning, so eval mode has to be entered *after* it; otherwise the
    # whole evaluation loop would run with dropout enabled.
    model.eval()

    total_loss = 0
    total_tokens = 0
    pretrained_tokens = 0
    new_tokens = 0

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(eval_dataloader, desc="Evaluating", leave=True, dynamic_ncols=True, miniters=1)):
            # Cap the evaluation at roughly MAX_EVAL_SAMPLES samples.
            if config.MAX_EVAL_SAMPLES > 0 and batch_idx * config.EVAL_BATCH_SIZE >= config.MAX_EVAL_SAMPLES:
                break

            batch_device = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}

            outputs = model(
                input_ids=batch_device['input_ids'],
                attention_mask=batch_device['attention_mask'],
                labels=batch_device['labels']
            )

            loss = outputs.loss
            if not torch.isnan(loss):
                # Weight each batch's mean loss by its number of label tokens.
                num_tokens = (batch_device['labels'] != -100).sum().item()
                total_loss += loss.item() * num_tokens
                total_tokens += num_tokens

                pretrained_tokens += batch.get('pretrained_counts', torch.zeros(len(batch['input_ids']))).sum().item()
                new_tokens += batch.get('new_counts', torch.zeros(len(batch['input_ids']))).sum().item()

    avg_loss = total_loss / total_tokens if total_tokens > 0 else float('inf')
    perplexity = np.exp(avg_loss)

    model.train()

    return {
        'eval_loss': avg_loss,
        'eval_perplexity': perplexity,
        'total_tokens': total_tokens,
        'pretrained_tokens': pretrained_tokens,
        'new_tokens': new_tokens
    }

def apply_gradient_scaling(model, tokenizer: "ExtendedTokenizer", config: "TrainingConfig"):
    """Scale the input-embedding gradients of pretrained tokens in place.

    Rows of the embedding gradient that belong to pretrained token ids are
    multiplied by ``config.GRADIENT_SCALING_FACTOR``. Nothing happens when
    the model has no input embeddings, when no gradient is present, or when
    the pretrained ids cover the whole embedding matrix (or none of it).

    Args:
        model: object that may expose ``get_input_embeddings()``.
        tokenizer: provides ``base_vocab_size`` (pretrained ids are then
            ``0 .. base_vocab_size - 1``) or, failing that,
            ``pretrained_token_ids``.
        config: provides ``GRADIENT_SCALING_FACTOR``.
    """
    if not hasattr(model, 'get_input_embeddings'):
        return

    input_embeddings = model.get_input_embeddings()
    if input_embeddings is None or not hasattr(input_embeddings, 'weight') or input_embeddings.weight.grad is None:
        return

    grad = input_embeddings.weight.grad
    num_rows = input_embeddings.weight.size(0)

    with torch.no_grad():
        base_size = getattr(tokenizer, 'base_vocab_size', None)
        if isinstance(base_size, int):
            # Fast path: the pretrained ids form the contiguous prefix
            # [0, base_size), so an in-place slice avoids materialising
            # every id on each training step.
            if 0 < base_size < num_rows:
                grad[:base_size].mul_(config.GRADIENT_SCALING_FACTOR)
            return

        # Generic path for tokenizers that only expose an id collection.
        pretrained_indices = list(tokenizer.pretrained_token_ids)
        if pretrained_indices and len(pretrained_indices) < num_rows:
            index = torch.as_tensor(pretrained_indices, dtype=torch.long, device=grad.device)
            grad[index] *= config.GRADIENT_SCALING_FACTOR

def compute_regularization_loss(model, config: TrainingConfig):
    """Return the L2 norm of the input-embedding weights times the configured weight.

    The result is a scalar tensor created on the device of the model's first
    parameter; it stays 0 when the model exposes no input embeddings.
    """
    first_param_device = next(model.parameters()).device
    reg_loss = torch.tensor(0.0, device=first_param_device)

    if not hasattr(model, 'get_input_embeddings'):
        return reg_loss

    embedding_layer = model.get_input_embeddings()
    if embedding_layer is None or not hasattr(embedding_layer, 'weight'):
        return reg_loss

    penalty = torch.norm(embedding_layer.weight, p=2) * config.EMBEDDING_REGULARIZATION_WEIGHT
    reg_loss += penalty

    return reg_loss

def save_checkpoint(model, tokenizer, optimizer, scheduler, epoch, step, best_eval_loss, config, is_best=False):
    """Write a training checkpoint and prune old ones.

    Saves the model and tokenizer via their ``save_pretrained`` methods and
    the optimizer/scheduler state to ``training_state.pt`` inside
    ``<OUTPUT_DIR>/checkpoint-<step>``. When *is_best* is true the checkpoint
    is also copied to ``<OUTPUT_DIR>/best_model`` (replacing any previous
    copy). Afterwards only the ``config.SAVE_TOTAL_LIMIT`` highest-numbered
    checkpoint directories are kept.
    """
    import shutil

    output_root = Path(config.OUTPUT_DIR)
    checkpoint_dir = output_root / f"checkpoint-{step}"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    model.save_pretrained(checkpoint_dir)
    tokenizer.save_pretrained(checkpoint_dir)

    training_state = {
        'epoch': epoch,
        'step': step,
        'best_eval_loss': best_eval_loss,
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
    }

    torch.save(training_state, checkpoint_dir / 'training_state.pt')

    if is_best:
        best_dir = output_root / "best_model"
        if best_dir.exists():
            shutil.rmtree(best_dir)
        shutil.copytree(checkpoint_dir, best_dir)

    # Collect only directories named checkpoint-<integer>; stray files or
    # differently named entries are ignored rather than crashing the sort.
    numbered = []
    for candidate in output_root.glob("checkpoint-*"):
        suffix = candidate.name.split("-", 1)[1]
        if candidate.is_dir() and suffix.isdigit():
            numbered.append((int(suffix), candidate))
    numbered.sort(key=lambda pair: pair[0])

    keep = config.SAVE_TOTAL_LIMIT
    # A limit of 0 (or less) disables pruning.
    if keep > 0 and len(numbered) > keep:
        for _, stale_dir in numbered[:-keep]:
            shutil.rmtree(stale_dir)

    logger.info(f"Saved checkpoint at step {step}")

def create_combined_vocabulary(base_tokenizer: AutoTokenizer, bpe_tokenizer: BPETokenizer):
    """Merge the BPE vocabulary into a copy of the base vocabulary.

    Tokens already present in the base vocabulary keep their ids; every other
    BPE token receives the next free id, counting up from the size of the
    base vocabulary. Returns the combined token -> id dict.
    """
    logger.info("Creating combined vocabulary...")

    combined_vocab = base_tokenizer.get_vocab().copy()
    base_size = len(combined_vocab)

    next_id = base_size
    for candidate in bpe_tokenizer.get_vocab():
        if candidate in combined_vocab:
            continue
        combined_vocab[candidate] = next_id
        next_id += 1

    new_tokens_added = next_id - base_size

    logger.info(f"Base tokens: {base_size:,}, New tokens: {new_tokens_added:,}, Combined: {len(combined_vocab):,}")
    return combined_vocab

def _train_load_tokenizer(config, hf_token):
    """Build the ExtendedTokenizer; return ``(tokenizer, combined_vocab)``."""
    base_tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b", token=hf_token)
    bpe_tokenizer = BPETokenizer(model_path=config.BPE_MODEL_PATH)

    vocab_file = Path(config.TOKENIZER_PATH) / 'vocab.json'

    if vocab_file.exists():
        logger.info("Loading existing combined vocabulary...")
        with open(vocab_file, 'r', encoding='utf-8') as f:
            combined_vocab = json.load(f)
        logger.info(f"Loaded combined vocabulary: {len(combined_vocab):,} tokens")
    else:
        logger.info("Creating combined vocabulary from base + BPE...")
        combined_vocab = create_combined_vocabulary(base_tokenizer, bpe_tokenizer)

    tokenizer = ExtendedTokenizer(base_tokenizer, bpe_tokenizer, combined_vocab)
    return tokenizer, combined_vocab


def _train_load_model(config, tokenizer_vocab_size):
    """Load the local Gemma model after checking its vocab size against the tokenizer's."""
    model_config_file = Path(config.MODEL_PATH) / 'config.json'
    if not model_config_file.exists():
        raise FileNotFoundError(f"Model config not found at {model_config_file}")

    with open(model_config_file, 'r') as f:
        model_config = json.load(f)

    model_vocab_size = model_config.get('vocab_size', 0)

    logger.info(f"Model vocab size: {model_vocab_size:,}")
    logger.info(f"Tokenizer vocab size: {tokenizer_vocab_size:,}")

    # A difference of up to 10 entries is accepted as alignment tolerance.
    if abs(model_vocab_size - tokenizer_vocab_size) > 10:
        raise ValueError(f"Vocabulary size mismatch: model={model_vocab_size:,}, tokenizer={tokenizer_vocab_size:,}")

    logger.info("Vocabulary sizes match (within alignment tolerance)")
    return GemmaForCausalLM.from_pretrained(
        config.MODEL_PATH,
        dtype=torch.float32,
        device_map="auto",
        trust_remote_code=True,
        local_files_only=True
    )


def _train_verify_tokenizer(tokenizer):
    """Check that repeated tokenization of a probe string is stable and log a round trip."""
    test_text = "Hello Báwo ni world"
    runs = [tokenizer.tokenize(test_text) for _ in range(3)]
    # Explicit raise instead of ``assert`` so the check also runs under ``python -O``;
    # the exception type and message are unchanged.
    if not (runs[0] == runs[1] == runs[2]):
        raise AssertionError("Tokenizer is not consistent")

    tokens = tokenizer.encode(test_text, add_special_tokens=False)
    decoded = tokenizer.decode(tokens)
    pretrained, new = tokenizer.analyze_token_types(tokens)
    logger.info(f"Test: '{test_text}' -> {len(tokens)} tokens -> '{decoded}'")
    logger.info(f"Token analysis: {pretrained} pretrained, {new} new")
    logger.info("Tokenizer consistency verified")


def _train_build_dataloaders(tokenizer, config):
    """Create the train and eval datasets and return ``(train_dataloader, eval_dataloader)``."""
    train_dataset = TextDataset(
        config.TRAIN_DATASET_PATH,
        tokenizer,
        max_length=config.MAX_SEQ_LENGTH,
        text_column=config.TEXT_COLUMN_NAME,
        cache_dir=config.CACHE_DIR,
        use_cache=config.USE_CACHED_DATASET,
        force_retokenize=config.FORCE_RETOKENIZE,
        max_examples=config.MAX_EXAMPLES
    )

    # Evaluation is capped at 1000 examples.
    eval_dataset = TextDataset(
        config.VAL_DATASET_PATH,
        tokenizer,
        max_length=config.MAX_SEQ_LENGTH,
        text_column=config.TEXT_COLUMN_NAME,
        cache_dir=config.CACHE_DIR,
        use_cache=config.USE_CACHED_DATASET,
        max_examples=1000
    )

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=0,
        collate_fn=collate_fn
    )

    eval_dataloader = DataLoader(
        eval_dataset,
        batch_size=config.EVAL_BATCH_SIZE,
        shuffle=False,
        num_workers=0,
        collate_fn=collate_fn
    )
    return train_dataloader, eval_dataloader


def _train_finalize(model, tokenizer, eval_dataloader, device, config, global_step, best_eval_loss):
    """Save the final adapter, evaluate, write the summary, then optionally merge LoRA."""
    logger.info("Saving final model...")
    final_dir = Path(config.OUTPUT_DIR) / "final_model"
    model.save_pretrained(final_dir)
    tokenizer.save_pretrained(final_dir)

    # Evaluate and persist the summary BEFORE merging: merge_and_unload() rewrites
    # the wrapped model in place, and this way the metrics describe the adapter
    # model that was just saved and survive a failure during the merge.
    final_metrics = evaluate_model(model, eval_dataloader, device, tokenizer, config)

    summary = {
        'total_steps': global_step,
        'best_eval_loss': best_eval_loss,
        'final_eval_loss': final_metrics['eval_loss'],
        'final_eval_perplexity': final_metrics['eval_perplexity'],
        'model_vocab_size': model.config.vocab_size,
        'tokenizer_vocab_size': tokenizer.vocab_size,
        'pretrained_tokens': len(tokenizer.pretrained_token_ids),
        'new_tokens': len(tokenizer.new_token_ids)
    }

    with open(final_dir / 'training_summary.json', 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2)

    if config.MERGE_AND_SAVE_FINAL:
        logger.info("Merging LoRA and saving...")
        merged_model = model.merge_and_unload()
        merged_dir = Path(config.OUTPUT_DIR) / "final_merged_model"
        merged_model.save_pretrained(merged_dir)
        tokenizer.save_pretrained(merged_dir)

    logger.info("Training completed.")
    logger.info(f"Best eval loss: {best_eval_loss:.4f}")
    logger.info(f"Final eval loss: {final_metrics['eval_loss']:.4f}")
    logger.info(f"Final perplexity: {final_metrics['eval_perplexity']:.2f}")
    logger.info(f"Pretrained tokens: {len(tokenizer.pretrained_token_ids):,}")
    logger.info(f"New tokens: {len(tokenizer.new_token_ids):,}")


def train():
    """Run the complete fine-tuning pipeline.

    Steps, in order: prompt for a Hugging Face token, build ``TrainingConfig``,
    seed the RNGs, load the extended tokenizer and the local Gemma model (with a
    vocabulary-size check), wrap the model with PEFT, build the data loaders,
    train with gradient accumulation, periodic evaluation, checkpointing and
    early stopping, and finally save the adapter, a training summary and,
    if ``config.MERGE_AND_SAVE_FINAL`` is set, the merged model.

    Raises:
        FileNotFoundError: If ``config.json`` is missing under ``config.MODEL_PATH``.
        ValueError: If model and tokenizer vocabulary sizes differ by more than 10.
        AssertionError: If the tokenizer gives different results for the same text.
    """
    hf_token = getpass.getpass("Enter Hugging Face token: ")

    config = TrainingConfig()
    set_seed(config.SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Device: {device}")

    Path(config.OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

    logger.info("Loading initialized model and tokenizers...")

    tokenizer, combined_vocab = _train_load_tokenizer(config, hf_token)
    model = _train_load_model(config, len(combined_vocab))

    logger.info(f"Model vocab size: {model.config.vocab_size}")
    logger.info(f"Tokenizer vocab size: {tokenizer.vocab_size}")

    _train_verify_tokenizer(tokenizer)

    model = setup_peft_model(model, config)

    if config.USE_GRADIENT_CHECKPOINTING:
        model.enable_input_require_grads()
        model.gradient_checkpointing_enable()

    logger.info("Loading datasets...")
    train_dataloader, eval_dataloader = _train_build_dataloaders(tokenizer, config)

    accumulation_steps = config.GRADIENT_ACCUMULATION_STEPS
    # Floor division: a trailing partial accumulation window is not counted as an update.
    num_update_steps = len(train_dataloader) // accumulation_steps
    max_steps = num_update_steps * config.NUM_EPOCHS
    if config.MAX_STEPS > 0:
        max_steps = min(config.MAX_STEPS, max_steps)

    optimizer = setup_multi_optimizer(model, tokenizer, config)
    # NOTE(review): warmup is fixed at 500 steps; for runs with fewer total steps
    # the learning rate never reaches its configured peak - confirm this is intended.
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=500,
        num_training_steps=max_steps
    )

    logger.info(f"Total training steps: {max_steps}")

    global_step = 0
    best_eval_loss = float('inf')
    patience_counter = 0

    logger.info("Starting training...")

    initial_metrics = evaluate_model(model, eval_dataloader, device, tokenizer, config)
    logger.info(f"Initial eval loss: {initial_metrics['eval_loss']:.4f}, PPL: {initial_metrics['eval_perplexity']:.2f}")

    model.train()

    for epoch in range(config.NUM_EPOCHS):
        epoch_loss = 0
        epoch_tokens = 0

        # Drop gradients left over from a partial accumulation window at the end of
        # the previous epoch so they cannot leak into this epoch's first update
        # (consistent with the floor division used for num_update_steps).
        optimizer.zero_grad()

        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{config.NUM_EPOCHS}")

        for step, batch in enumerate(progress_bar):
            batch_device = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}

            outputs = model(
                input_ids=batch_device['input_ids'],
                attention_mask=batch_device['attention_mask'],
                labels=batch_device['labels']
            )

            # Both terms are divided by the accumulation factor: gradients are summed
            # over the window, so an undivided regularizer would be applied
            # `accumulation_steps` times per optimizer update.
            loss = outputs.loss / accumulation_steps
            reg_loss = compute_regularization_loss(model, config) / accumulation_steps
            total_loss = loss + reg_loss

            # Skip Inf as well as NaN: backpropagating an infinite loss yields
            # non-finite gradients that would corrupt the weights.
            if not torch.isfinite(total_loss):
                logger.warning("NaN/Inf loss detected, skipping batch")
                optimizer.zero_grad()
                continue

            total_loss.backward()

            num_tokens = (batch_device['labels'] != -100).sum().item()
            epoch_loss += outputs.loss.item() * num_tokens
            epoch_tokens += num_tokens

            # Only the last micro-batch of each accumulation window triggers an update.
            if (step + 1) % accumulation_steps != 0:
                continue

            apply_gradient_scaling(model, tokenizer, config)
            clip_grad_norm_(model.parameters(), config.MAX_GRAD_NORM)

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            global_step += 1

            if global_step % config.LOGGING_STEPS == 0:
                # Token-weighted running average of the language-model loss for this epoch.
                avg_loss = epoch_loss / epoch_tokens if epoch_tokens > 0 else 0
                perplexity = np.exp(avg_loss)

                progress_bar.set_postfix({
                    'loss': f'{avg_loss:.4f}',
                    'ppl': f'{perplexity:.2f}',
                    'lr': f'{scheduler.get_last_lr()[0]:.2e}'
                })

                logger.info(f"Step {global_step} | Loss: {avg_loss:.4f} | PPL: {perplexity:.2f}")

            if global_step % config.EVAL_STEPS == 0:
                eval_metrics = evaluate_model(model, eval_dataloader, device, tokenizer, config)
                # evaluate_model is defined elsewhere and may leave the model in eval
                # mode; re-entering train mode is a no-op if it already restored it.
                model.train()

                logger.info(f"Eval - Loss: {eval_metrics['eval_loss']:.4f}, PPL: {eval_metrics['eval_perplexity']:.2f}")
                logger.info(f"Tokens - Pretrained: {eval_metrics['pretrained_tokens']:,}, New: {eval_metrics['new_tokens']:,}")

                is_best = eval_metrics['eval_loss'] < best_eval_loss
                if is_best:
                    best_eval_loss = eval_metrics['eval_loss']
                    patience_counter = 0
                    logger.info(f"New best model! Loss: {best_eval_loss:.4f}")
                else:
                    patience_counter += 1

                save_checkpoint(model, tokenizer, optimizer, scheduler, epoch, global_step, best_eval_loss, config, is_best)

                if patience_counter >= config.EARLY_STOPPING_PATIENCE:
                    logger.info("Early stopping triggered!")
                    break

            elif global_step % config.SAVE_STEPS == 0:
                save_checkpoint(model, tokenizer, optimizer, scheduler, epoch, global_step, best_eval_loss, config)

            if global_step >= max_steps:
                break

        # Propagate an inner-loop stop (early stopping or step budget) to the epoch loop.
        if patience_counter >= config.EARLY_STOPPING_PATIENCE or global_step >= max_steps:
            break

    _train_finalize(model, tokenizer, eval_dataloader, device, config, global_step, best_eval_loss)

# Entry point: start the training run only when this module is executed directly,
# not when it is imported.
if __name__ == "__main__":
    train()