In [None]:
# Colab-only: mount Google Drive at /content/drive. The model, tokenizer and
# dataset paths configured in main() all live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import getpass
import json
import logging
import os
import random
import re
import sys
import unicodedata
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import torch
from rouge_score import rouge_scorer
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    GemmaForCausalLM,
    LlamaForCausalLM,
    set_seed,
)

# Emit INFO-and-above records on stdout (not stderr) with a timestamped format,
# so progress messages appear inline with the notebook cell output.
logging.basicConfig(
    handlers=[logging.StreamHandler(stream=sys.stdout)],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger shared by every class and function in this file.
logger = logging.getLogger(name=__name__)

@dataclass
class ModelConfig:
    """Settings for loading one model/tokenizer pair into MultiModelEvaluator."""

    # Key for this model in the evaluator's models/tokenizers dicts and in the results.
    name: str
    # Hugging Face hub id or local directory passed to from_pretrained().
    model_path: str
    # Hub id or local directory for the tokenizer; with use_extended_tokenizer it must
    # be a local directory that also contains vocab.json (the combined vocabulary).
    tokenizer_path: str
    # 'llama2' or 'gemma'; any other value is rejected when the model is loaded.
    model_type: str
    # Wrap the base tokenizer in ExtendedTokenizer (base tokenizer + Tamil BPE).
    use_extended_tokenizer: bool = False
    # Directory containing simple_bpe_model.json; needed when use_extended_tokenizer is True.
    bpe_model_path: Optional[str] = None
    # Name of the environment variable holding the HF access token (not the token itself).
    hf_token_var: Optional[str] = None

@dataclass
class EvaluationConfig:
    """Dataset locations and generation settings shared by every evaluated model."""

    # JSONL file with 'text'/'summary' records; few-shot examples are drawn from it.
    train_data_path: str
    # JSONL file with 'text'/'summary' records that are summarized and scored.
    test_data_path: str
    # Passed to set_seed() when the evaluator is constructed.
    seed: int = 42
    # Token budget for prompt + generated tokens; longer prompts are truncated.
    max_context_length: int = 2048
    # Maximum number of tokens generated per summary.
    max_new_tokens: int = 150
    # Sampling temperature (generation runs with do_sample=True).
    temperature: float = 0.7
    # Nucleus-sampling probability mass.
    top_p: float = 0.9
    # Top-k sampling cutoff.
    top_k: int = 50
    # Forwarded to model.generate() as repetition_penalty.
    repetition_penalty: float = 1.05

class BPETokenizer:
    """Character-level BPE tokenizer loaded from ``simple_bpe_model.json``.

    The JSON file inside ``model_path`` must provide three keys:
    ``vocab`` (token -> id), ``id2token`` (id as a string -> token) and
    ``merges`` (a list of ``{"left": ..., "right": ...}`` objects).
    """

    def __init__(self, model_path: str):
        """Load the BPE model.

        Args:
            model_path: Directory that contains ``simple_bpe_model.json``.

        Raises:
            FileNotFoundError: If the JSON file is missing.
            KeyError: If one of the required keys is absent.
        """
        self.model_path = Path(model_path)
        with open(self.model_path / "simple_bpe_model.json", 'r', encoding='utf-8') as f:
            model_data = json.load(f)

        self.vocab = model_data['vocab']
        # JSON object keys are always strings; restore integer ids.
        self.id2token = {int(k): v for k, v in model_data['id2token'].items()}
        self.merges = [(m['left'], m['right']) for m in model_data['merges']]
        # Set for O(1) pair lookups in _apply_merges (merge order is not used there).
        self._merge_pairs = set(self.merges)

        self.unk_token = '<unk>'
        self.bos_token = '<s>'
        self.eos_token = '</s>'
        self.pad_token = '<pad>'

    @property
    def vocab_size(self):
        """Number of entries in the vocabulary."""
        return len(self.vocab)

    def get_vocab(self):
        """Return a shallow copy of the token -> id mapping."""
        return self.vocab.copy()

    def _tokenize(self, text: str):
        """Split ``text`` into characters (unknown ones become ``<unk>``) and merge them.

        Returns:
            List of string tokens; ``[]`` for empty input.
        """
        if not text:
            return []

        tokens = [char if char in self.vocab else self.unk_token for char in text]
        return self._apply_merges(tokens)

    def _apply_merges(self, tokens):
        """Greedily merge adjacent pairs found in ``self._merge_pairs``.

        Scans left to right, merging any known pair in place (and retrying at the
        same position so chains like a+b -> ab, ab+c -> abc collapse), and repeats
        full passes until a pass makes no change. ``tokens`` is modified in place
        and returned.

        NOTE(review): standard BPE applies merges in learned-rank order; this greedy
        scheme ignores ranks. Confirm it matches how the BPE model was trained.
        """
        changed = True
        while changed:
            changed = False
            i = 0
            while i < len(tokens) - 1:
                if (tokens[i], tokens[i + 1]) in self._merge_pairs:
                    tokens[i:i + 2] = [tokens[i] + tokens[i + 1]]
                    changed = True
                else:
                    i += 1
        return tokens

    # These lookups are plain dict accesses (already O(1)). They were previously
    # wrapped in @lru_cache, which caches per method on `self`: that keeps every
    # instance alive for the cache's lifetime and can return stale ids if
    # vocab/id2token are modified after the first lookup.
    def _convert_token_to_id(self, token):
        """Return the id for ``token``, falling back to the ``<unk>`` id (or 0)."""
        return self.vocab.get(token, self.vocab.get(self.unk_token, 0))

    def _convert_id_to_token(self, index):
        """Return the token for ``index``, or ``<unk>`` if the id is unknown."""
        return self.id2token.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Concatenate tokens; BPE tokens are raw substrings, so no separators are needed."""
        return ''.join(tokens)

class ExtendedTokenizer:
    """Tokenizer for a model whose vocabulary was extended with Tamil BPE tokens.

    Text classified as (almost) pure Tamil is tokenized with ``bpe_tokenizer`` and
    mapped to ids through ``combined_vocab``. Everything else, and any Tamil text
    whose BPE tokens are not all in ``combined_vocab``, is delegated unchanged to
    ``base_tokenizer``.
    """

    def __init__(self, base_tokenizer: AutoTokenizer, bpe_tokenizer: BPETokenizer, combined_vocab: Dict[str, int]):
        """
        Args:
            base_tokenizer: The model's original Hugging Face tokenizer.
            bpe_tokenizer: Tamil BPE tokenizer that produces string tokens.
            combined_vocab: token -> id mapping covering base and added BPE tokens.
        """
        self.base_tokenizer = base_tokenizer
        self.bpe_tokenizer = bpe_tokenizer
        self.vocab = combined_vocab
        self.id2token = {token_id: token for token, token_id in combined_vocab.items()}

        self.tamil_pattern = re.compile(r'[\u0B80-\u0BFF]')
        self.english_pattern = re.compile(r'[a-zA-Z]')

        self.unk_token = base_tokenizer.unk_token
        self.bos_token = base_tokenizer.bos_token
        self.eos_token = base_tokenizer.eos_token
        self.pad_token = '<pad>'

        # Built lazily by _get_base_vocab(); decode() needs it on every call.
        self._base_vocab = None

    @property
    def vocab_size(self):
        """Size of the combined vocabulary."""
        return len(self.vocab)

    @property
    def pad_token_id(self):
        """Id of ``<pad>`` in the combined vocab, otherwise the EOS id.

        Falling back to EOS mirrors how plain HF tokenizers without a pad token are
        configured in MultiModelEvaluator, instead of a fixed id (32000) that sits
        just past Llama-2's base vocabulary and is presumably a real added token in
        an extended vocab. NOTE(review): confirm against the merged vocab.json.
        """
        pad_id = self.vocab.get(self.pad_token)
        return pad_id if pad_id is not None else self.eos_token_id

    @property
    def eos_token_id(self):
        """Id of the EOS token in the combined vocab, else the base tokenizer's EOS id.

        The base tokenizer's own id replaces a hard-coded fallback of 2, because
        special-token ids differ between model families.
        """
        eos_id = self.vocab.get(self.eos_token)
        return eos_id if eos_id is not None else self.base_tokenizer.eos_token_id

    def _classify_text_type(self, text: str) -> str:
        """Return "bpe" for Tamil-dominated text without ASCII letters, else "base".

        Only word characters are considered. The result is "bpe" when there are Tamil
        characters, no ASCII letters, and Tamil makes up more than 80% of them.
        """
        analyzable_chars = re.sub(r'[^\w]', '', text, flags=re.UNICODE)
        if not analyzable_chars:
            return "base"

        tamil_chars = len(self.tamil_pattern.findall(analyzable_chars))
        if tamil_chars == 0:
            return "base"

        english_chars = len(self.english_pattern.findall(analyzable_chars))
        if english_chars == 0 and tamil_chars / len(analyzable_chars) > 0.8:
            return "bpe"

        return "base"

    def _encode_with_base(self, text, add_special_tokens, return_tensors):
        """Delegate encoding to the wrapped base tokenizer with the same arguments."""
        return self.base_tokenizer.encode(text, add_special_tokens=add_special_tokens, return_tensors=return_tensors)

    def _get_base_vocab(self):
        """Return (and cache) the set of tokens in the base tokenizer's vocabulary."""
        if self._base_vocab is None:
            self._base_vocab = frozenset(self.base_tokenizer.get_vocab())
        return self._base_vocab

    def encode(self, text: str, add_special_tokens=True, return_tensors=None):
        """Encode ``text`` to token ids.

        Args:
            text: Input string.
            add_special_tokens: On the BPE path, prepend BOS, and append EOS only if
                the base tokenizer is configured to add one. This keeps the framing
                the same as on the base path.
            return_tensors: "pt" for a ``(1, seq_len)`` torch tensor; otherwise a list.

        Returns:
            Token ids as a list or tensor. Non-"bpe" text, BPE failures, and BPE
            tokens missing from the combined vocab all go through
            ``base_tokenizer.encode``.
        """
        if self._classify_text_type(text) != "bpe":
            return self._encode_with_base(text, add_special_tokens, return_tensors)

        try:
            tokens = self.bpe_tokenizer._tokenize(text)
        except Exception:
            # Best-effort: fall back to the base tokenizer. A bare except here would
            # also swallow KeyboardInterrupt/SystemExit.
            logger.debug("BPE tokenization failed; using base tokenizer", exc_info=True)
            return self._encode_with_base(text, add_special_tokens, return_tensors)

        if not all(token in self.vocab for token in tokens):
            return self._encode_with_base(text, add_special_tokens, return_tensors)

        token_ids = [self.vocab[token] for token in tokens]

        if add_special_tokens:
            # The base branch adds BOS, and adds EOS only when the base tokenizer's
            # add_eos_token is enabled (off by default for Llama/Gemma tokenizers;
            # NOTE(review): confirm for the merged checkpoints). An unconditional EOS
            # would tell the model the prompt is finished before it generates.
            token_ids = [self.vocab[self.bos_token]] + token_ids
            if getattr(self.base_tokenizer, 'add_eos_token', False):
                token_ids.append(self.vocab[self.eos_token])

        if return_tensors == "pt":
            return torch.tensor([token_ids])
        return token_ids

    def decode(self, token_ids, skip_special_tokens=True):
        """Convert ids (list, nested list, tensor, or single int) back to text.

        If more than half of the tokens are outside the base vocabulary, they are
        concatenated directly (BPE tokens are raw substrings). Otherwise the base
        tokenizer's ``convert_tokens_to_string`` is used. Empty input gives "".
        """
        if hasattr(token_ids, 'tolist'):
            token_ids = token_ids.tolist()
        if isinstance(token_ids, int):
            # tolist() on a 0-d tensor yields a bare int.
            token_ids = [token_ids]
        elif token_ids and isinstance(token_ids[0], list):
            # Batched input: decode the first sequence only.
            token_ids = token_ids[0]

        tokens = [self.id2token.get(tid, self.unk_token) for tid in token_ids]

        if skip_special_tokens:
            special_tokens = {self.bos_token, self.eos_token, self.pad_token}
            tokens = [t for t in tokens if t not in special_tokens]

        if not tokens:
            return ""

        base_vocab = self._get_base_vocab()
        non_base_tokens = sum(1 for token in tokens if token not in base_vocab)

        if non_base_tokens > len(tokens) * 0.5:
            return ''.join(tokens)
        return self.base_tokenizer.convert_tokens_to_string(tokens)

class _UnicodeWordTokenizer:
    """Script-agnostic word tokenizer for ``rouge_score.RougeScorer``.

    rouge_score's default tokenizer lowercases the text and then replaces every
    character outside ``[a-z0-9]`` with a space. That deletes all Tamil script and
    drives Tamil ROUGE scores to ~0. This tokenizer instead lowercases, turns
    Unicode punctuation (P*) and symbol (S*) characters into spaces, and splits on
    whitespace. Letters and combining marks (Tamil vowel signs and virama) are kept.
    """

    def tokenize(self, text: str) -> List[str]:
        """Return the whitespace-separated word tokens of ``text``."""
        cleaned = ''.join(
            ' ' if unicodedata.category(ch)[0] in ('P', 'S') else ch
            for ch in text.lower()
        )
        return cleaned.split()


class MultiModelEvaluator:
    """Load several causal LMs and compare their summarization quality with ROUGE.

    Every model whose load succeeds is stored in ``self.models`` and
    ``self.tokenizers`` under its ``ModelConfig.name``. Models that fail to load are
    logged and skipped.
    """

    # Maps ModelConfig.model_type to the transformers class that loads it.
    _MODEL_CLASSES = {
        'llama2': LlamaForCausalLM,
        'gemma': GemmaForCausalLM,
    }

    # A line starting a new "Text N:" / "Summary N:" entry means the model is
    # continuing the few-shot pattern, not finishing its own summary.
    _CONTINUATION_RE = re.compile(r'\n\s*(?:Text|Summary)(?:\s*\d+)?\s*:')

    def __init__(self, config: EvaluationConfig, model_configs: List[ModelConfig]):
        """Seed all RNGs, build the ROUGE scorer and load every configured model."""
        set_seed(config.seed)
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.models = {}
        self.tokenizers = {}

        self.rouge_scorer = rouge_scorer.RougeScorer(
            ['rouge1', 'rouge2', 'rougeL'],
            use_stemmer=False,
            tokenizer=_UnicodeWordTokenizer(),
        )

        for model_config in model_configs:
            self._load_model(model_config)

    def _load_tokenizer(self, model_config: ModelConfig, token: Optional[str]):
        """Load the tokenizer for ``model_config``: a plain HF tokenizer or ExtendedTokenizer."""
        if not model_config.use_extended_tokenizer:
            tokenizer = AutoTokenizer.from_pretrained(model_config.tokenizer_path, token=token)
            # Base checkpoints may lack a pad token; generate() is given pad_token_id.
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            return tokenizer

        if not model_config.bpe_model_path:
            raise ValueError(
                f"{model_config.name}: use_extended_tokenizer=True requires bpe_model_path"
            )
        base_tokenizer = AutoTokenizer.from_pretrained(model_config.tokenizer_path, token=token)
        bpe_tokenizer = BPETokenizer(model_config.bpe_model_path)

        vocab_file = Path(model_config.tokenizer_path) / 'vocab.json'
        with open(vocab_file, 'r', encoding='utf-8') as f:
            combined_vocab = json.load(f)

        return ExtendedTokenizer(base_tokenizer, bpe_tokenizer, combined_vocab)

    def _load_model(self, model_config: ModelConfig):
        """Load one model and its tokenizer; failures are logged and the model is skipped."""
        logger.info(f"Loading {model_config.name}")

        token = os.environ.get(model_config.hf_token_var) if model_config.hf_token_var else None

        try:
            model_cls = self._MODEL_CLASSES.get(model_config.model_type)
            if model_cls is None:
                raise ValueError(f"Unsupported model type: {model_config.model_type}")

            tokenizer = self._load_tokenizer(model_config, token)

            use_cuda = torch.cuda.is_available()
            model = model_cls.from_pretrained(
                model_config.model_path,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
                device_map="auto" if use_cuda else None,
                low_cpu_mem_usage=True,
                token=token,
            )
            model.eval()

            self.models[model_config.name] = model
            self.tokenizers[model_config.name] = tokenizer

            logger.info(f"Successfully loaded {model_config.name}")

        except Exception as e:
            # Best-effort: one bad model should not abort the comparison. Log the
            # traceback so a skipped model (e.g. a wrong path) is easy to spot.
            logger.exception(f"Failed to load {model_config.name}: {e}")

    def _encode_text(self, text: str, tokenizer, add_special_tokens: bool = True):
        """Encode ``text`` to a list of ids (HF tokenizers and ExtendedTokenizer share this API)."""
        return tokenizer.encode(text, add_special_tokens=add_special_tokens)

    def _decode_tokens(self, token_ids, tokenizer, skip_special_tokens: bool = True) -> str:
        """Decode ids to text (HF tokenizers and ExtendedTokenizer share this API)."""
        return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)

    def _get_special_token_ids(self, tokenizer):
        """Return the pad/eos ids passed to ``model.generate``."""
        return {
            'pad_token_id': tokenizer.pad_token_id,
            'eos_token_id': tokenizer.eos_token_id
        }

    def generate_summary(self, prompt: str, model_name: str) -> str:
        """Sample a summary continuation of ``prompt`` from ``model_name``.

        Returns:
            The post-processed summary, or "" if the model is not loaded or
            generation fails (the failure is logged).
        """
        if model_name not in self.models:
            return ""

        model = self.models[model_name]
        tokenizer = self.tokenizers[model_name]

        try:
            input_ids = self._encode_text(prompt, tokenizer, add_special_tokens=True)

            # Reserve room for max_new_tokens. Keep the tail of the prompt so the
            # target text and the trailing "Summary:" cue survive truncation.
            prompt_budget = max(1, self.config.max_context_length - self.config.max_new_tokens)
            if len(input_ids) > prompt_budget:
                logger.debug(f"Truncating prompt for {model_name} from {len(input_ids)} to {prompt_budget} tokens")
                input_ids = input_ids[-prompt_budget:]

            input_tensor = torch.tensor([input_ids], device=self.device)
            # A single unpadded sequence: every position is real. An explicit mask
            # avoids transformers inferring one from pad_token_id, which equals
            # eos_token_id for base tokenizers (see _load_tokenizer).
            attention_mask = torch.ones_like(input_tensor)
            special_tokens = self._get_special_token_ids(tokenizer)

            with torch.no_grad():
                # early_stopping only affects beam search, so it is not passed for
                # single-beam sampling.
                generated_tokens = model.generate(
                    input_ids=input_tensor,
                    attention_mask=attention_mask,
                    max_new_tokens=self.config.max_new_tokens,
                    do_sample=True,
                    temperature=self.config.temperature,
                    top_p=self.config.top_p,
                    top_k=self.config.top_k,
                    pad_token_id=special_tokens['pad_token_id'],
                    eos_token_id=special_tokens['eos_token_id'],
                    repetition_penalty=self.config.repetition_penalty,
                    use_cache=True,
                )

            # generate() returns prompt + continuation; keep only the new tokens.
            new_tokens = generated_tokens[0][input_tensor.shape[1]:].tolist()
            summary = self._decode_tokens(new_tokens, tokenizer, skip_special_tokens=True)
            return self._post_process_summary(summary)

        except Exception as e:
            logger.error(f"Generation failed for {model_name}: {e}")
            return ""

    def _post_process_summary(self, summary: str) -> str:
        """Clean raw generated text into a single-line summary.

        The method removes leading ':'/'-'/whitespace and cuts at the first line that
        starts a new "Text N:"/"Summary N:" entry. It then collapses whitespace and
        adds a final '.' if the summary lacks terminal punctuation.
        """
        if not summary:
            return ""

        summary = re.sub(r'^[:\-\s]+', '', summary)
        summary = self._CONTINUATION_RE.split(summary, maxsplit=1)[0]
        # Strip before checking the last character; otherwise trailing whitespace
        # makes "x. " look unterminated and yields "x. .".
        summary = re.sub(r'\s+', ' ', summary).strip()

        if summary and not re.search(r'[.!?।॥]$', summary):
            summary += '.'

        return summary

    def create_prompt(self, examples: List[Dict], target_text: str, n_shots: int) -> str:
        """Build a zero-shot or ``n_shots``-shot summarization prompt for ``target_text``.

        Few-shot examples are drawn with the global ``random`` module. Example texts
        are clipped to 400 chars, example summaries to 100, and the target to 800.
        """
        selected_examples = random.sample(examples, min(n_shots, len(examples))) if n_shots > 0 else []

        truncated_target = target_text[:800] + "..." if len(target_text) > 800 else target_text

        if n_shots == 0:
            return f"Summarize the following text concisely:\n\n{truncated_target}\n\nSummary:"

        prompt_parts = ["Write a summary in தமிழ் based on the examples given below:", ""]

        for i, example in enumerate(selected_examples):
            ex_text = example['text'][:400] + "..." if len(example['text']) > 400 else example['text']
            ex_summary = example['summary'][:100] if len(example['summary']) > 100 else example['summary']

            prompt_parts.append(f"Text {i+1}: {ex_text}")
            prompt_parts.append(f"Summary {i+1}: {ex_summary}")
            prompt_parts.append("")

        prompt_parts.append(f"Text: {truncated_target}")
        prompt_parts.append("Summary:")

        return "\n".join(prompt_parts)

    def load_dataset(self, file_path: str, max_samples: Optional[int] = None) -> List[Dict]:
        """Read up to ``max_samples`` valid records from a JSONL file.

        A record is kept when its stripped 'text' has at least 20 chars and its
        stripped 'summary' at least 5. Malformed lines are skipped and counted.
        ``max_samples=None`` means no limit; ``0`` returns an empty list.

        Returns:
            Dicts with 'text', 'summary' and 'id' (falls back to ``item_<line>``).
        """
        data = []
        malformed = 0

        with open(file_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f):
                if max_samples is not None and len(data) >= max_samples:
                    break

                line = line.strip()
                if not line:
                    continue

                try:
                    example = json.loads(line)
                    text = example['text'].strip()
                    summary = example['summary'].strip()
                    item_id = example.get('id', f'item_{line_num}')
                except (json.JSONDecodeError, KeyError, TypeError, AttributeError):
                    malformed += 1
                    continue

                if len(text) < 20 or len(summary) < 5:
                    continue

                data.append({'text': text, 'summary': summary, 'id': item_id})

        if malformed:
            logger.warning(f"Skipped {malformed} malformed lines in {file_path}")
        logger.info(f"Loaded {len(data)} examples from {file_path}")
        return data

    def evaluate_rouge(self, predictions: List[str], references: List[str]) -> Dict[str, float]:
        """Average ROUGE-1/2/L F1 over all pairs that have a non-empty reference.

        An empty prediction (e.g. a failed generation) counts as 0 instead of being
        dropped. Dropping it would inflate the average for models that often fail.

        Raises:
            ValueError: If the two lists differ in length.
        """
        if len(predictions) != len(references):
            raise ValueError(
                f"predictions ({len(predictions)}) and references ({len(references)}) differ in length"
            )

        rouge_scores = {
            'rouge1': 0.0,
            'rouge2': 0.0,
            'rougeL': 0.0
        }
        scored_pairs = 0

        for pred, ref in zip(predictions, references):
            ref = ref.strip()
            if not ref:
                continue
            scored_pairs += 1

            pred = pred.strip()
            if not pred:
                continue

            try:
                scores = self.rouge_scorer.score(ref, pred)
            except Exception as e:
                logger.warning(f"ROUGE scoring failed; counting pair as 0: {e}")
                continue
            for metric in rouge_scores:
                rouge_scores[metric] += scores[metric].fmeasure

        if scored_pairs > 0:
            for metric in rouge_scores:
                rouge_scores[metric] /= scored_pairs

        return rouge_scores

    def run_evaluation(self, train_data: List[Dict], test_data: List[Dict],
                      n_shots_list: Optional[List[int]] = None, max_test_samples: int = 50) -> Dict:
        """Evaluate every loaded model for each shot count.

        Args:
            train_data: Pool of few-shot examples.
            test_data: Examples to summarize; sampled down to ``max_test_samples``.
            n_shots_list: Shot counts to evaluate (default ``[0, 3]``).
            max_test_samples: Upper bound on the number of test examples.

        Returns:
            ``{model_name: {"<n>_shot": {"rouge1", "rouge2", "rougeL"}}}``.
        """
        if n_shots_list is None:
            n_shots_list = [0, 3]

        results = {}

        test_sample = random.sample(test_data, max_test_samples) if len(test_data) > max_test_samples else test_data

        logger.info(f"Evaluating on {len(test_sample)} test samples")

        # Build every prompt once so all models see identical few-shot examples;
        # sampling inside the model loop would give each model different prompts.
        prompts_by_config = {
            f"{n_shots}_shot": [self.create_prompt(train_data, example['text'], n_shots) for example in test_sample]
            for n_shots in n_shots_list
        }
        references = [example['summary'] for example in test_sample]

        for model_name in self.models:
            logger.info(f"\nEvaluating {model_name}")
            model_results = {}

            for config_key, prompts in prompts_by_config.items():
                logger.info(f"  {config_key}")

                predictions = [
                    self.generate_summary(prompt, model_name)
                    for prompt in tqdm(prompts, desc=f"{model_name}_{config_key}")
                ]

                rouge_results = self.evaluate_rouge(predictions, references)

                model_results[config_key] = {
                    'rouge1': rouge_results['rouge1'],
                    'rouge2': rouge_results['rouge2'],
                    'rougeL': rouge_results['rougeL']
                }

                logger.info(f"    R1: {rouge_results['rouge1']:.4f}, R2: {rouge_results['rouge2']:.4f}, RL: {rouge_results['rougeL']:.4f}")

            results[model_name] = model_results

        return results

def main():
    """Prompt for HF tokens, evaluate all configured models on Tamil XL-Sum, save results.

    Returns:
        ``{model_name: {"<n>_shot": {"rouge1", "rouge2", "rougeL"}}}`` from
        ``MultiModelEvaluator.run_evaluation`` (``{}`` if no model loaded), or
        None if a token prompt is left empty.
    """
    print("=" * 60)
    print("TAMIL SUMMARIZATION EVALUATION")
    print("=" * 60)

    llama_token = getpass.getpass("Enter HuggingFace token for Llama-2: ").strip()
    if not llama_token:
        print("Error: Llama token is required.")
        return
    os.environ['LLAMA_HF_TOKEN'] = llama_token

    gemma_token = getpass.getpass("Enter HuggingFace token for Gemma: ").strip()
    if not gemma_token:
        print("Error: Gemma token is required.")
        return
    os.environ['GEMMA_HF_TOKEN'] = gemma_token

    # Every Drive path hangs off one root so model and data paths cannot drift
    # apart. The model paths previously used "My Drive" while the data paths used
    # "MyDrive". A wrong path makes _load_model log an error and silently skip
    # that model.
    project_root = "/content/drive/MyDrive/Colab Notebooks/LRLs/tamil"
    bpe_model_path = f"{project_root}/tokenizers/grapheme_picky_bpe"
    llama_lapt_path = f"{project_root}/models/llama2_mean_init_lapt/final_merged_model"
    gemma_lapt_path = f"{project_root}/models/gemma_mean_init_lapt/final_merged_model"
    data_dir = f"{project_root}/4_evaluation/tamil_XLSum_v2.0"

    n_shots_list = [0, 3]

    model_configs = [
        ModelConfig(
            name="llama2_7b_base",
            model_path="meta-llama/Llama-2-7b-hf",
            tokenizer_path="meta-llama/Llama-2-7b-hf",
            model_type="llama2",
            hf_token_var="LLAMA_HF_TOKEN"
        ),
        ModelConfig(
            name="gemma_7b_base",
            model_path="google/gemma-7b",
            tokenizer_path="google/gemma-7b",
            model_type="gemma",
            hf_token_var="GEMMA_HF_TOKEN"
        ),
        ModelConfig(
            name="llama2_mean_init_lapt",
            model_path=llama_lapt_path,
            tokenizer_path=llama_lapt_path,
            model_type="llama2",
            use_extended_tokenizer=True,
            bpe_model_path=bpe_model_path
        ),
        ModelConfig(
            name="gemma_mean_init_lapt",
            model_path=gemma_lapt_path,
            tokenizer_path=gemma_lapt_path,
            model_type="gemma",
            use_extended_tokenizer=True,
            bpe_model_path=bpe_model_path
        )
    ]

    config = EvaluationConfig(
        train_data_path=f"{data_dir}/tamil_train.jsonl",
        test_data_path=f"{data_dir}/tamil_test.jsonl",
        seed=42,
        max_context_length=2048,
        max_new_tokens=150,
        temperature=0.5,
        top_p=0.9,
        top_k=50
    )

    try:
        evaluator = MultiModelEvaluator(config, model_configs)
        if not evaluator.models:
            logger.error("No models were loaded successfully; nothing to evaluate.")
            return {}

        train_data = evaluator.load_dataset(config.train_data_path, max_samples=100)
        test_data = evaluator.load_dataset(config.test_data_path, max_samples=20)

        logger.info("\nStarting evaluation")

        results = evaluator.run_evaluation(
            train_data=train_data,
            test_data=test_data,
            n_shots_list=n_shots_list,
            max_test_samples=10
        )

        logger.info("\n" + "="*60)
        logger.info("COMPARATIVE RESULTS")
        logger.info("="*60)

        # Derive the printed configurations from the same list that was evaluated.
        for config_name in (f"{n_shots}_shot" for n_shots in n_shots_list):
            logger.info(f"\n{config_name.upper()}:")
            logger.info("-" * 50)
            for model_name, model_results in results.items():
                if config_name in model_results:
                    metrics = model_results[config_name]
                    logger.info(f"{model_name:25s}: R1={metrics['rouge1']:.4f}, R2={metrics['rouge2']:.4f}, RL={metrics['rougeL']:.4f}")

        output_path = Path("tamil_evaluation_results.json")
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        logger.info(f"\nResults saved to {output_path}")

        return results

    except Exception as e:
        logger.error(f"Evaluation failed: {e}")
        raise

if __name__ == "__main__":
    # Run the full evaluation. main() returns the results dict, or None if a
    # token prompt was left empty; bind it to `results` for later inspection.
    results = main()