In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

# import numpy as np # linear algebra
# import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file mounted under the Kaggle input directory.
input_paths = (
    os.path.join(root_dir, name)
    for root_dir, _subdirs, names in os.walk('/kaggle/input')
    for name in names
)
for input_path in input_paths:
    print(input_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install transformers>=4.51.0 torch torchvision torchaudio accelerate bitsandbytes -q
!pip install sentencepiece protobuf -q
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import gc
import warnings
import re
warnings.filterwarnings('ignore')
# Report whether a CUDA device is present and, if so, its name and total memory.
cuda_ready = torch.cuda.is_available()
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    total_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU Memory: {total_gib:.1f} GB")

In [None]:
import torch
import gc
import re
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoModelForSequenceClassification
from typing import List, Dict, Tuple, Any, Optional
from sklearn.metrics import f1_score, jaccard_score
import ast
import warnings
from huggingface_hub import login
warnings.filterwarnings('ignore')

# Read the Hugging Face token from the environment (e.g. a Kaggle secret exported
# as HF_TOKEN) instead of hard-coding it in the notebook. Any token that was ever
# committed in plain text must be treated as leaked and revoked on huggingface.co.
import os

HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)
else:
    # Downstream code passes token=HF_TOKEN; None means anonymous access.
    print("⚠️  HF_TOKEN is not set; gated models such as meta-llama/Llama-3.1-8B-Instruct will fail to download.")

class ArabicMedicalAnswerClassifier:
    """
    Arabic Medical Answer Classifier supporting Arabic-optimized LLM models.

    Generative models (decoder-only and encoder-decoder) are prompted to end
    with "التصنيف النهائي: [..]" and the strategy numbers are parsed from the
    generated text. BERT-style models (``classification_model: True``) are
    loaded with a 3-label head that is not part of the base checkpoints, so it
    is newly initialised and must be fine-tuned before its output is useful.

    Strategy labels are the strings '1' (information), '2' (direct guidance)
    and '3' (emotional support).
    """
    
    # Model configurations with Arabic-specific models
    MODEL_CONFIGS = {
        # Arabic-specific models
        'jais': {
            'name': 'inception-mbzuai/jais-13b-chat',
            'type': 'jais',
            'context_length': 4096,
            'description': 'Jais 13B Chat - Arabic-English bilingual model',
            'arabic_native': True
        },
        'aragpt2': {
            'name': 'aubmindlab/aragpt2-mega',
            'type': 'aragpt',
            'context_length': 1024,
            'description': 'AraGPT2 Mega - Arabic GPT model',
            'arabic_native': True
        },
        'arabert': {
            'name': 'aubmindlab/bert-base-arabertv2',
            'type': 'arabert',
            'context_length': 512,
            'description': 'AraBERT v2 - Arabic BERT for classification',
            'arabic_native': True,
            'classification_model': True
        },
        'camelbert': {
            'name': 'CAMeL-Lab/bert-base-arabic-camelbert-mix',
            'type': 'camelbert',
            'context_length': 512,
            'description': 'CAMeLBERT - NYU Arabic BERT model',
            'arabic_native': True,
            'classification_model': True
        },
        'marbert': {
            'name': 'UBC-NLP/MARBERT',
            'type': 'marbert',
            'context_length': 512,
            'description': 'MARBERT - Multilingual Arabic BERT',
            'arabic_native': True,
            'classification_model': True
        },
        'arat5': {
            'name': 'UBC-NLP/AraT5-base',
            'type': 'arat5',
            'context_length': 512,
            'description': 'AraT5 - Arabic T5 model',
            'arabic_native': True
        },
        # Multilingual models with good Arabic support
        'qwen3': {
            'name': 'Qwen/Qwen3-14B',
            'type': 'qwen',
            'context_length': 32768,
            'description': 'Qwen3 14B - Latest Qwen model with Arabic support',
            'arabic_native': False
        },
        'qwen2': {
            'name': 'Qwen/Qwen2-7B-Instruct',
            'type': 'qwen',
            'context_length': 32768,
            'description': 'Qwen2 7B - Strong multilingual model',
            'arabic_native': False
        },
        'llama3': {
            'name': 'meta-llama/Llama-3.1-8B-Instruct',
            'type': 'llama',
            'context_length': 128000,
            'description': 'Llama 3.1 8B - Meta\'s latest model',
            'arabic_native': False
        },
        'aya': {
            'name': 'CohereForAI/aya-23-8B',
            'type': 'aya',
            'context_length': 8192,
            'description': 'Aya 23 8B - Multilingual model with Arabic focus',
            'arabic_native': False
        },
        'mt5': {
            'name': 'google/mt5-large',
            'type': 'mt5',
            'context_length': 512,
            'description': 'mT5 Large - Multilingual T5 with Arabic',
            'arabic_native': False
        }
    }

    # Encoder-decoder model types. AutoModelForCausalLM cannot load T5-family
    # checkpoints, and their generate() output does not start with the prompt.
    SEQ2SEQ_TYPES = ('arat5', 'mt5')

    # Model types whose prompt is built with the tokenizer's chat template when
    # one is available. Those prompts already carry any BOS token the model
    # needs, so they are tokenized with add_special_tokens=False.
    CHAT_TEMPLATE_TYPES = ('qwen', 'llama', 'aya')

    # Upper bound on prompt tokens regardless of the model's context window.
    MAX_PROMPT_TOKENS = 2048

    # Arabic-Indic (U+0660-0669) and Extended Arabic-Indic (U+06F0-06F9) digits
    # mapped to ASCII, so "[١،٢]" parses like "[1,2]".
    _DIGIT_TRANSLATION = {
        **{0x0660 + i: str(i) for i in range(10)},
        **{0x06F0 + i: str(i) for i in range(10)},
    }

    _VALID_STRATEGIES = ('1', '2', '3')
    
    def __init__(self, model_key: str = 'jais', use_quantization: bool = True, use_thinking_mode: bool = True):
        """
        Initialize classifier with specified model and load it immediately.
        
        Args:
            model_key: Key from MODEL_CONFIGS
            use_quantization: Whether to use 4-bit quantization (ignored for
                classification models)
            use_thinking_mode: Use the step-by-step analysis prompt (and Qwen's
                enable_thinking); always off for classification models

        Raises:
            ValueError: if model_key is not in MODEL_CONFIGS.
        """
        if model_key not in self.MODEL_CONFIGS:
            raise ValueError(f"Model '{model_key}' not supported. Available models: {list(self.MODEL_CONFIGS.keys())}")
        
        self.model_config = self.MODEL_CONFIGS[model_key]
        self.model_key = model_key
        self.use_quantization = use_quantization
        self.use_thinking_mode = use_thinking_mode and not self.model_config.get('classification_model', False)
        self.tokenizer = None
        self.model = None
        self.is_classification_model = self.model_config.get('classification_model', False)
        self.answer_strategies = {
            '1': 'Information (answers providing information, resources, etc.)',
            '2': 'Direct Guidance (answers providing suggestions, instructions, or advice)',
            '3': 'Emotional Support (answers providing approval, reassurance, or other forms of emotional support)'
        }
        self.load_model()
    
    def load_model(self):
        """
        Load the tokenizer and model weights for the configured model.

        Generative models are loaded in float16 (optionally 4-bit NF4), using
        AutoModelForSeq2SeqLM for T5-family types and AutoModelForCausalLM
        otherwise. Classification models get a 3-label multi-label head.

        Raises:
            Exception: whatever transformers raises while loading (re-raised
                after printing).
        """
        print(f"Loading {self.model_config['description']}...")
        print(f"Model: {self.model_config['name']}")
        print(f"Arabic Native: {'Yes' if self.model_config.get('arabic_native', False) else 'No'}")
        
        quantization_config = None
        if self.use_quantization and not self.is_classification_model:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
        
        try:
            # Decoder-only generation needs left padding; encoders pad right.
            tokenizer_kwargs = {
                'trust_remote_code': True,
                'padding_side': 'left' if not self.is_classification_model else 'right'
            }
            if HF_TOKEN:
                tokenizer_kwargs['token'] = HF_TOKEN
                
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_config['name'],
                **tokenizer_kwargs
            )
            
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            
            model_kwargs = {
                'trust_remote_code': True,
                'low_cpu_mem_usage': True,
                'device_map': 'auto'
            }
            
            if self.is_classification_model:
                # The head is not in the base checkpoint: it is freshly
                # initialised and must be fine-tuned. problem_type matches the
                # sigmoid/threshold decoding in classify_answer_bert.
                self.model = AutoModelForSequenceClassification.from_pretrained(
                    self.model_config['name'],
                    num_labels=3,  # For our 3 strategies
                    problem_type='multi_label_classification',
                    token=HF_TOKEN,
                    **model_kwargs
                )
            else:
                # NOTE(review): T5-family models are known to be unstable in
                # float16; consider bfloat16 for arat5/mt5 if outputs are NaN.
                model_kwargs['torch_dtype'] = torch.float16
                if quantization_config:
                    model_kwargs['quantization_config'] = quantization_config

                if self.model_config['type'] in self.SEQ2SEQ_TYPES:
                    from transformers import AutoModelForSeq2SeqLM
                    model_cls = AutoModelForSeq2SeqLM
                else:
                    model_cls = AutoModelForCausalLM

                self.model = model_cls.from_pretrained(
                    self.model_config['name'],
                    token=HF_TOKEN,
                    **model_kwargs
                )
            
            print(f"✅ {self.model_config['description']} loaded successfully!")
            self.print_model_info()
            
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise
    
    def print_model_info(self):
        """Print model configuration, tokenizer size and (if CUDA) GPU memory usage."""
        print(f"\n📊 Model Information:")
        print(f"Model: {self.model_config['name']}")
        print(f"Type: {self.model_config['type']}")
        print(f"Context Length: {self.model_config['context_length']:,} tokens")
        print(f"Arabic Native: {'Yes' if self.model_config.get('arabic_native', False) else 'No'}")
        print(f"Classification Model: {'Yes' if self.is_classification_model else 'No'}")
        print(f"Quantization: {'4-bit' if self.use_quantization and not self.is_classification_model else 'Full precision'}")
        print(f"Thinking Mode: {'Enabled' if self.use_thinking_mode else 'Disabled'}")
        print(f"Tokenizer Vocab Size: {len(self.tokenizer):,}")
        
        if torch.cuda.is_available():
            print(f"GPU Memory Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
            print(f"GPU Memory Reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    
    def create_arabic_prompt(self, answer: str) -> str:
        """Create Arabic-optimized prompt for classification"""
        strategy_descriptions = """
استراتيجيات الإجابة الطبية:
(1) المعلومات - إجابات تقدم معلومات وموارد وحقائق طبية
(2) التوجيه المباشر - إجابات تقدم اقتراحات وتعليمات ونصائح محددة  
(3) الدعم العاطفي - إجابات تقدم الموافقة والطمأنينة أو أشكال أخرى من الدعم العاطفي
"""
        
        if self.use_thinking_mode:
            base_prompt = f"""أنت خبير في تحليل وتصنيف الإجابات الطبية باللغة العربية. مهمتك هي تصنيف الإجابة المعطاة إلى استراتيجية أو أكثر من الاستراتيجيات التالية:

{strategy_descriptions}

الإجابة المراد تحليلها وتصنيفها:
"{answer}"

تعليمات التحليل:
1. اقرأ الإجابة بعناية وحلل محتواها بالتفصيل
2. حدد الكلمات والعبارات المفتاحية في النص
3. اربط كل جزء من الإجابة بالاستراتيجية المناسبة
4. فكر في السياق الطبي والغرض من الإجابة
5. يمكن أن تنتمي الإجابة لأكثر من استراتيجية واحدة
6. اكتب تحليلك خطوة بخطوة
7. في النهاية، اكتب التصنيف بالتنسيق التالي بالضبط:
   "التصنيف النهائي: [الأرقام مفصولة بفواصل]"

مثال على التحليل:
إجابة: "الصداع قد يكون بسبب الإجهاد أو قلة النوم. أنصحك بأخذ قسط كاف من الراحة وشرب الماء."
التحليل: 
- "قد يكون بسبب" = معلومات طبية (استراتيجية 1)
- "أنصحك" = توجيه مباشر (استراتيجية 2)
التصنيف النهائي: [1,2]
"""
        else:
            base_prompt = f"""أنت خبير في تصنيف الإجابات الطبية باللغة العربية. حلل الإجابة التالية وصنفها حسب الاستراتيجيات المحددة:

{strategy_descriptions}

الإجابة:
"{answer}"

المطلوب:
1. حلل محتوى الإجابة وحدد الاستراتيجية أو الاستراتيجيات المناسبة
2. يمكن أن تنتمي الإجابة لأكثر من استراتيجية
3. اكتب التصنيف بالتنسيق التالي بالضبط:
   "التصنيف النهائي: [الأرقام مفصولة بفواصل]"

أمثلة:
- إجابة معلوماتية فقط: "التصنيف النهائي: [1]"
- إجابة تحتوي معلومات ونصائح: "التصنيف النهائي: [1,2]"
- إجابة داعمة عاطفياً: "التصنيف النهائي: [3]"
"""
        
        return base_prompt
    
    def create_prompt(self, answer: str) -> str:
        """
        Wrap the Arabic instruction prompt in the model's expected turn format.

        Qwen always uses the tokenizer's chat template. Llama and Aya use it
        when the tokenizer provides one, and otherwise fall back to hand-written
        markers for their respective formats.
        """
        base_prompt = self.create_arabic_prompt(answer)
        model_type = self.model_config['type']

        if model_type == 'jais':
            return f"### Instruction: {base_prompt}\n### Response:"

        has_template = bool(getattr(self.tokenizer, 'chat_template', None))
        if model_type == 'qwen' or (model_type in self.CHAT_TEMPLATE_TYPES and has_template):
            template_kwargs = {}
            if model_type == 'qwen':
                template_kwargs['enable_thinking'] = self.use_thinking_mode
            return self.tokenizer.apply_chat_template(
                [{"role": "user", "content": base_prompt}],
                tokenize=False,
                add_generation_prompt=True,
                **template_kwargs
            )
        if model_type == 'llama':
            return f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{base_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        if model_type == 'aya':
            # Aya 23 uses the Cohere Command-R turn tokens.
            return (
                "<BOS_TOKEN><|START_OF_TURN_TOKEN|><|USER_TOKEN|>"
                f"{base_prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
            )
        if model_type in ('aragpt', 'arat5', 'mt5'):
            return base_prompt
        return f"Human: {base_prompt}\n\nAssistant:"

    def _token_budget(self, max_new_tokens: int) -> Tuple[int, int]:
        """
        Return (max_new_tokens, max_prompt_tokens) that fit the model's context.

        Decoder-only models share a single window between prompt and
        generation. Generation is therefore capped at half the context, and
        the prompt gets the remainder (bounded by MAX_PROMPT_TOKENS), so
        neither budget can go negative. Encoder-decoder models have separate
        encoder and decoder lengths.
        """
        context_length = self.model_config['context_length']
        if self.model_config['type'] in self.SEQ2SEQ_TYPES:
            return max_new_tokens, min(context_length, self.MAX_PROMPT_TOKENS)
        new_tokens = max(1, min(max_new_tokens, context_length // 2))
        return new_tokens, min(context_length - new_tokens, self.MAX_PROMPT_TOKENS)
    
    def classify_answer_generative(self, answer: str, max_new_tokens: int = 2000) -> List[str]:
        """
        Classify using generative models (decoder-only or encoder-decoder).

        Args:
            answer: Arabic answer text to classify.
            max_new_tokens: Requested generation budget. It may be lowered by
                the per-model caps below and by the context window.

        Returns:
            Strategy labels parsed from the generated text.
        """
        prompt = self.create_prompt(answer)
        model_type = self.model_config['type']
        
        gen_kwargs = {
            'max_new_tokens': max_new_tokens,
            'temperature': 0.3,  # Lower temperature for more consistent results
            'top_p': 0.9,
            'top_k': 50,
            'do_sample': True,
            'pad_token_id': self.tokenizer.pad_token_id,
            'eos_token_id': self.tokenizer.eos_token_id,
            'repetition_penalty': 1.1
        }
        
        # Arabic model optimizations
        if self.model_config.get('arabic_native', False):
            gen_kwargs.update({
                'temperature': 0.2,  # Even lower for Arabic models
                'top_p': 0.85,
                'repetition_penalty': 1.05
            })
        
        if model_type == 'jais':
            gen_kwargs.update({
                'temperature': 0.1,
                'top_p': 0.8,
                'max_new_tokens': min(max_new_tokens, 1000)
            })
        elif model_type in ['aragpt', 'arat5']:
            gen_kwargs.update({
                'temperature': 0.2,
                'max_new_tokens': min(max_new_tokens, 512)
            })

        # Fit prompt + generation into the context window. For 1,024/512-token
        # models the old "context - max_new_tokens" budget was negative.
        gen_kwargs['max_new_tokens'], max_input_length = self._token_budget(gen_kwargs['max_new_tokens'])

        # NOTE(review): with the tokenizer's default right-side truncation, an
        # over-long prompt loses its END (instructions / assistant header);
        # truncating `answer` itself would be safer for very long inputs.
        model_inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=max_input_length,
            # Chat-templated prompts already contain BOS; avoid a duplicate.
            add_special_tokens=model_type not in self.CHAT_TEMPLATE_TYPES,
        ).to(self.model.device)
        
        with torch.no_grad():
            generated_ids = self.model.generate(
                **model_inputs,
                **gen_kwargs
            )

        if model_type in self.SEQ2SEQ_TYPES:
            # Encoder-decoder output contains only the decoded answer.
            output_ids = generated_ids[0]
        else:
            # Decoder-only output echoes the prompt; keep only new tokens.
            output_ids = generated_ids[0][model_inputs['input_ids'].shape[-1]:]
        content = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip()
        
        return self.extract_answer_strategies(content)
    
    def classify_answer_bert(self, answer: str) -> List[str]:
        """
        Classify using BERT-style models (classification head).

        Applies a sigmoid to each of the 3 logits and keeps labels whose
        probability exceeds 0.5; returns ['1'] when none do. The head is
        untrained unless the model has been fine-tuned, so results are only
        meaningful after fine-tuning.
        """
        inputs = self.tokenizer(
            answer, 
            return_tensors="pt", 
            truncation=True,
            padding=True,
            max_length=self.model_config['context_length']
        ).to(self.model.device)
        
        with torch.no_grad():
            outputs = self.model(**inputs)
            predictions = torch.sigmoid(outputs.logits)  # Multi-label classification
            
        threshold = 0.5
        strategies = [str(i + 1) for i, prob in enumerate(predictions[0]) if prob > threshold]
        
        return strategies if strategies else ['1']  # Default to strategy 1
    
    def classify_answer(self, answer: str, max_new_tokens: int = 2000) -> List[str]:
        """
        Classify Arabic medical answer into strategy categories.

        Dispatches to classify_answer_bert for classification models and to
        classify_answer_generative otherwise.
        """
        if self.is_classification_model:
            return self.classify_answer_bert(answer)
        return self.classify_answer_generative(answer, max_new_tokens)

    @classmethod
    def _parse_strategy_list(cls, raw: str) -> List[str]:
        """Split a captured "1, 2" / "1، 2" list into unique valid labels, in order."""
        items = (item.strip() for item in re.split(r'[,،]', raw))
        return list(dict.fromkeys(item for item in items if item in cls._VALID_STRATEGIES))
    
    def extract_answer_strategies(self, response: str) -> List[str]:
        """
        Extract answer strategies from a model response (Arabic or English).

        The following are tried in order:
        1. Explicit Arabic "label: [..]" lines.
        2. The English equivalents.
        3. Bare "(n)" / "[n]" / " n " mentions.
        4. Keyword heuristics.

        When a pattern occurs several times, the LAST valid occurrence wins,
        because step-by-step analyses end with the final answer.

        Returns:
            A non-empty, duplicate-free list of strategy strings.
        """
        text = response.translate(self._DIGIT_TRANSLATION)
        # Qwen3 thinking mode emits "<think>...</think>" before the answer and
        # the markers survive decoding; prefer the text after the last one.
        if '</think>' in text:
            tail = text.rsplit('</think>', 1)[1]
            if tail.strip():
                text = tail

        # Digits, ASCII/Arabic commas and whitespace.
        digits = r'[123,،\s]+'
        arabic_patterns = [
            rf'التصنيف النهائي:\s*\[({digits})\]',
            rf'الاستراتيجيات:\s*\[({digits})\]',
            rf'التصنيف:\s*\[({digits})\]',
            rf'النتيجة:\s*\[({digits})\]',
            rf'الإجابة:\s*\[({digits})\]',
            rf'التصنيف النهائي:\s*({digits})',
            rf'الاستراتيجية:\s*({digits})',
        ]
        english_patterns = [
            rf'Final Classification:\s*\[({digits})\]',
            rf'Strategies:\s*\[({digits})\]',
            rf'Classification:\s*\[({digits})\]',
            rf'Answer:\s*\[({digits})\]',
        ]
        
        for pattern in arabic_patterns + english_patterns:
            for match in reversed(list(re.finditer(pattern, text, re.IGNORECASE))):
                found = self._parse_strategy_list(match.group(1))
                if found:
                    return found
        
        # Fallback: look for individual numbers
        found_strategies = [
            strategy for strategy in self._VALID_STRATEGIES
            if f'({strategy})' in text or f'[{strategy}]' in text or f' {strategy} ' in text
        ]
        if found_strategies:
            return found_strategies
        
        # Last resort: analyze content for keywords
        return self.fallback_classification(text)
    
    def fallback_classification(self, response: str) -> List[str]:
        """Classify by Arabic keyword presence; defaults to ['1'] when nothing matches."""
        strategies = []
        
        info_keywords = ['معلومات', 'حقائق', 'بيانات', 'إحصائيات', 'دراسات', 'أبحاث', 'يُعرف', 'يُسمى']
        guidance_keywords = ['انصح', 'يجب', 'ينبغي', 'حاول', 'تجنب', 'اتبع', 'استشر', 'راجع']
        support_keywords = ['لا تقلق', 'طبيعي', 'شائع', 'تحسن', 'بخير', 'مطمئن', 'دعم', 'مساندة']
        
        # lower() only affects any Latin text; Arabic has no case.
        response_lower = response.lower()
        
        if any(keyword in response_lower for keyword in info_keywords):
            strategies.append('1')
        if any(keyword in response_lower for keyword in guidance_keywords):
            strategies.append('2')
        if any(keyword in response_lower for keyword in support_keywords):
            strategies.append('3')
        
        return strategies if strategies else ['1']  # Default to information
    
    def process_test_dataset(self, df: pd.DataFrame, max_new_tokens: int = 2000, show_progress: bool = True) -> pd.DataFrame:
        """
        Classify every row's 'answer' column.

        Progress and memory cleanup use the row POSITION, not the index label,
        so filtered or re-indexed frames behave the same as fresh ones.

        Returns:
            A DataFrame with one 'prediction' column of sorted,
            comma-separated labels (e.g. "1, 2"), in row order. Rows that
            raise are predicted as '1'.
        """
        print(f"🚀 Starting Arabic Medical Answer Classification with {self.model_config['description']}...")
        print(f"Dataset size: {len(df)} samples")
        print(f"Arabic Native Model: {'Yes' if self.model_config.get('arabic_native', False) else 'No'}")
        print("-" * 80)
        
        predictions = []
        total = len(df)
        
        for position, (idx, row) in enumerate(df.iterrows()):
            if show_progress and position % 10 == 0:
                gpu_gb = torch.cuda.memory_allocated() / 1024**3 if torch.cuda.is_available() else 0.0
                print(f"Processing sample {position+1}/{total} - Current GPU memory: {gpu_gb:.2f} GB")
            
            try:
                strategies = self.classify_answer(row['answer'], max_new_tokens=max_new_tokens)
                prediction_str = ', '.join(sorted(set(strategies)))
                predictions.append(prediction_str)
                
                if show_progress and position < 3:
                    print(f"Sample {position+1} prediction: {prediction_str}")
                    
            except Exception as e:
                print(f"Error processing answer {idx}: {e}")
                predictions.append('1')
            
            if position % 20 == 0:
                self.cleanup_memory()
        
        return pd.DataFrame({'prediction': predictions})
    
    def cleanup_memory(self):
        """Release cached CUDA memory (if available) and run the garbage collector."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
    
    @classmethod
    def list_available_models(cls):
        """Print every entry of MODEL_CONFIGS, Arabic-native models first."""
        print("📋 Available Arabic LLM Models:")
        print("-" * 80)
        
        groups = (
            ("🇸🇦 Arabic Native Models:", True),
            ("🌍 Multilingual Models with Arabic Support:", False),
        )
        for title, native in groups:
            print(title)
            for key, config in cls.MODEL_CONFIGS.items():
                if bool(config.get('arabic_native', False)) is not native:
                    continue
                print(f"  Key: '{key}'")
                print(f"    Model: {config['name']}")
                print(f"    Description: {config['description']}")
                print(f"    Context Length: {config['context_length']:,} tokens")
                print(f"    Type: {config['type']}")
                print()

def evaluate_model_on_training_data(
    train_file_path: str,
    model_key: str = 'jais',
    use_quantization: bool = True,
    max_new_tokens: int = 2000
) -> Optional[dict]:
    """
    Evaluate the model on the training dataset using Weighted F1 Score and Jaccard Score.

    The dataset is loaded and validated before the model, so a bad path or
    schema fails fast instead of after loading a multi-GB checkpoint.

    Args:
        train_file_path: Tab-separated file with 'answer' and 'final_AS' columns.
        model_key: Key of ArabicMedicalAnswerClassifier.MODEL_CONFIGS.
        use_quantization: Forwarded to the classifier (4-bit loading).
        max_new_tokens: Generation budget per answer for generative models.

    Returns:
        A dict with model_name, arabic_native, weighted_f1, jaccard_score
        (sample-averaged), per_strategy_f1 and per_strategy_jaccard. Returns
        None if the dataset or the model could not be loaded.
    """
    all_strategies = ['1', '2', '3']

    def parse_labels(label) -> List[str]:
        """
        Normalise one 'final_AS' cell to a list of strategy strings.

        Accepts "[1, 2]" / "['1', '2']" literals, "1, 2" strings, lists and
        bare numbers (pandas reads single-label columns as int/float).
        Elements are converted to str because ast.literal_eval("[1, 2]")
        yields ints, which would never match the '1'/'2'/'3' labels. Anything
        unparseable (including NaN) becomes ['1'] with a warning.
        """
        if isinstance(label, (list, tuple, set)):
            items = list(label)
        elif isinstance(label, str):
            text = label.strip()
            try:
                items = ast.literal_eval(text) if text.startswith('[') else text.split(',')
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                print(f"Warning: Could not parse label '{label}', defaulting to ['1']")
                return ['1']
            if not isinstance(items, (list, tuple, set)):
                items = [items]
        elif isinstance(label, (int, np.integer)) or (isinstance(label, float) and label.is_integer()):
            items = [int(label)]
        else:
            print(f"Warning: Could not parse label '{label}', defaulting to ['1']")
            return ['1']
        return [str(item).strip() for item in items]

    # Load training dataset
    try:
        df = pd.read_csv(train_file_path, sep='\t')
        if 'answer' not in df.columns or 'final_AS' not in df.columns:
            raise ValueError("Dataset must contain 'answer' and 'final_AS' columns")
        if df.empty:
            raise ValueError("Dataset contains no rows")
        print(f"✅ Training dataset loaded: {len(df)} samples")
    except Exception as e:
        print(f"❌ Error loading training dataset: {e}")
        return None

    df['final_AS'] = df['final_AS'].apply(parse_labels)

    print(f"🚀 Initializing {model_key} for evaluation...")
    try:
        classifier = ArabicMedicalAnswerClassifier(
            model_key=model_key,
            use_quantization=use_quantization
        )
    except Exception as e:
        print(f"❌ Failed to initialize model: {e}")
        return None

    # Generate predictions
    print(f"🚀 Generating predictions for {len(df)} samples...")
    predictions = classifier.process_test_dataset(df, max_new_tokens=max_new_tokens)

    # Multi-label indicator matrices, one column per strategy.
    y_true = np.array([
        [int(strat in true_labels) for strat in all_strategies]
        for true_labels in df['final_AS']
    ])
    pred_sets = [{part.strip() for part in pred.split(',')} for pred in predictions['prediction']]
    y_pred = np.array([
        [int(strat in pred_set) for strat in all_strategies]
        for pred_set in pred_sets
    ])

    # Calculate metrics (per-strategy values computed once, reused below).
    weighted_f1 = f1_score(y_true, y_pred, average='weighted')
    jaccard = jaccard_score(y_true, y_pred, average='samples')
    per_strategy_f1 = {strat: f1_score(y_true[:, i], y_pred[:, i]) for i, strat in enumerate(all_strategies)}
    per_strategy_jaccard = {strat: jaccard_score(y_true[:, i], y_pred[:, i]) for i, strat in enumerate(all_strategies)}

    print("\n📊 Evaluation Results:")
    print(f"Model: {classifier.model_config['description']}")
    print(f"Arabic Native: {'Yes' if classifier.model_config.get('arabic_native', False) else 'No'}")
    print(f"Weighted F1 Score: {weighted_f1:.4f}")
    print(f"Jaccard Score (samples): {jaccard:.4f}")

    print("\n📋 Per-Strategy Metrics:")
    for strat in all_strategies:
        print(f"Strategy {strat} ({classifier.answer_strategies[strat]}):")
        print(f"  F1 Score: {per_strategy_f1[strat]:.4f}")
        print(f"  Jaccard Score: {per_strategy_jaccard[strat]:.4f}")

    result = {
        'model_name': classifier.model_config['name'],
        'arabic_native': classifier.model_config.get('arabic_native', False),
        'weighted_f1': weighted_f1,
        'jaccard_score': jaccard,
        'per_strategy_f1': per_strategy_f1,
        'per_strategy_jaccard': per_strategy_jaccard,
    }

    # Drop the weights before emptying the CUDA cache; otherwise they are
    # still referenced and nothing is actually released.
    classifier.model = None
    classifier.cleanup_memory()

    return result

def compare_arabic_models(train_file_path: str, models_to_test: Optional[List[str]] = None):
    """
    Evaluate several models on the same dataset and print a ranked summary.

    Args:
        train_file_path: TSV passed to evaluate_model_on_training_data.
        models_to_test: MODEL_CONFIGS keys to evaluate. Defaults to jais,
            arabert, aragpt2 and aya.

    Returns:
        Dict mapping every requested model key to its evaluation dict, or to
        None if that model failed (whether it raised or returned None).
    """
    if models_to_test is None:
        models_to_test = ['jais', 'arabert', 'aragpt2', 'aya']  # Start with a few key models
    
    results = {}
    
    for model_key in models_to_test:
        print(f"\n{'='*60}")
        print(f"Testing Model: {model_key}")
        print(f"{'='*60}")
        
        result = None
        try:
            result = evaluate_model_on_training_data(
                train_file_path=train_file_path,
                model_key=model_key,
                use_quantization=True,
                max_new_tokens=1500
            )
        except Exception as e:
            print(f"❌ Error testing {model_key}: {e}")
        else:
            if result:
                print(f"✅ {model_key} completed successfully")
            else:
                print(f"❌ {model_key} failed")
        # Record failures as None so every requested key is present.
        results[model_key] = result if result else None
        
        # Clean up memory between models
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
    
    print(f"\n{'='*80}")
    print("📊 FINAL COMPARISON RESULTS")
    print(f"{'='*80}")
    
    valid_results = {k: v for k, v in results.items() if v is not None}
    
    if not valid_results:
        print("No model finished successfully.")
        return results

    # Rank by weighted F1 score, best first.
    sorted_results = sorted(valid_results.items(), key=lambda item: item[1]['weighted_f1'], reverse=True)
    
    print(f"{'Rank':<5} {'Model':<15} {'Arabic Native':<15} {'Weighted F1':<12} {'Jaccard':<10}")
    print("-" * 70)
    
    for rank, (model_key, result) in enumerate(sorted_results, 1):
        arabic_native = "Yes" if result['arabic_native'] else "No"
        print(f"{rank:<5} {model_key:<15} {arabic_native:<15} {result['weighted_f1']:<12.4f} {result['jaccard_score']:<10.4f}")
    
    best_model, best_result = sorted_results[0]
    print(f"\n🏆 Best Performing Model: {best_model}")
    print(f"Model Name: {best_result['model_name']}")
    print(f"Arabic Native: {'Yes' if best_result['arabic_native'] else 'No'}")
    print(f"Weighted F1 Score: {best_result['weighted_f1']:.4f}")
    print(f"Jaccard Score: {best_result['jaccard_score']:.4f}")
    
    return results

if __name__ == "__main__":
    # Update this path to your training dataset
    TRAIN_DATASET_PATH = '/kaggle/input/train-df/Train_Dev.tsv'  # Replace with actual path

    # List available models
    ArabicMedicalAnswerClassifier.list_available_models()

    # Evaluate a single Arabic-native model on its own.
    # NOTE: 'jais' is also in models_to_compare below, so it is loaded and
    # evaluated a second time; drop one of the runs if GPU time is limited.
    print("\n" + "="*80)
    print("Testing Arabic Native Model: Jais")
    print("="*80)
    
    jais_results = evaluate_model_on_training_data(
        train_file_path=TRAIN_DATASET_PATH,
        model_key='jais',
        use_quantization=True,
        max_new_tokens=1500
    )

    if jais_results:
        print("\n📈 Jais Results Summary:")
        print(f"Weighted F1 Score: {jais_results['weighted_f1']:.4f}")
        print(f"Jaccard Score: {jais_results['jaccard_score']:.4f}")

    # Compare multiple models
    print("\n" + "="*80)
    print("Comparing Multiple Arabic Models")
    print("="*80)
    
    # Mix of Arabic-native and multilingual models (see MODEL_CONFIGS).
    models_to_compare = [
        'jais',      # Arabic native (generative)
        'arabert',   # Arabic native (encoder; classification head is untrained)
        'aragpt2',   # Arabic native (generative)
        'aya'        # Multilingual with strong Arabic support
    ]
    
    comparison_results = compare_arabic_models(
        train_file_path=TRAIN_DATASET_PATH,
        models_to_test=models_to_compare
    )
    
    # Split successful results by whether the model is Arabic-native.
    print("\n" + "="*80)
    print("📊 ARABIC MODEL ANALYSIS")
    print("="*80)
    
    arabic_models = []
    multilingual_models = []
    
    for model_key, result in comparison_results.items():
        if result is None:
            continue
        if result['arabic_native']:
            arabic_models.append((model_key, result))
        else:
            multilingual_models.append((model_key, result))
    
    if arabic_models:
        print("\n🇸🇦 Arabic Native Models Performance:")
        for model_key, result in arabic_models:
            print(f"{model_key}: F1={result['weighted_f1']:.4f}, Jaccard={result['jaccard_score']:.4f}")
    
    if multilingual_models:
        print("\n🌍 Multilingual Models Performance:")
        for model_key, result in multilingual_models:
            print(f"{model_key}: F1={result['weighted_f1']:.4f}, Jaccard={result['jaccard_score']:.4f}")
    
    # Recommendations
    print("\n" + "="*80)
    print("💡 RECOMMENDATIONS")
    print("="*80)
    
    valid_results = {k: v for k, v in comparison_results.items() if v is not None}
    if valid_results:
        best_key, best = max(valid_results.items(), key=lambda item: item[1]['weighted_f1'])
        
        print(f"🏆 Best Overall Model: {best_key}")
        print(f"   - Model: {best['model_name']}")
        print(f"   - Arabic Native: {'Yes' if best['arabic_native'] else 'No'}")
        print(f"   - Performance: F1={best['weighted_f1']:.4f}")
        
        if best['arabic_native']:
            print("   - ✅ Recommended for Arabic medical texts due to native Arabic support")
        else:
            print("   - ⚠️  Multilingual model - consider Arabic native alternatives if available")
        
        print("\n📋 Usage Example:")
        print(f"classifier = ArabicMedicalAnswerClassifier(model_key='{best_key}')")
        print("strategies = classifier.classify_answer('your_arabic_medical_answer_here')")

# Additional utility functions for Arabic text processing
class ArabicTextPreprocessor:
    """Utility class for Arabic text preprocessing"""

    # Arabic-script Unicode blocks: Arabic, Arabic Supplement,
    # Arabic Presentation Forms-A and Arabic Presentation Forms-B.
    _ARABIC_RANGES = (
        ('\u0600', '\u06FF'),
        ('\u0750', '\u077F'),
        ('\uFB50', '\uFDFF'),
        ('\uFE70', '\uFEFF'),
    )
    
    @staticmethod
    def normalize_arabic(text: str) -> str:
        """
        Normalize Arabic text for matching.

        Removes harakat/tanween (U+064B-065F), superscript alef (U+0670) and
        tatweel (U+0640); maps إ/أ/آ to ا, ة to ه and ى to ي; strips
        surrounding whitespace.
        """
        text = re.sub(r'[\u064B-\u065F\u0670\u0640]', '', text)
        text = re.sub(r'[إأآا]', 'ا', text)
        text = re.sub(r'ة', 'ه', text)
        text = re.sub(r'[يى]', 'ي', text)
        return text.strip()
    
    @staticmethod
    def is_arabic_text(text: str) -> bool:
        """
        Return True when more than half of the letters in text are Arabic script.

        Numerator and denominator are both computed over letters
        (str.isalpha). Arabic diacritics, Arabic-Indic digits and Arabic
        punctuation therefore no longer inflate the ratio, e.g. "abc ١٢٣" is
        not Arabic. Text without letters is not Arabic.
        """
        letters = [ch for ch in text if ch.isalpha()]
        if not letters:
            return False
        arabic_letters = sum(
            any(lo <= ch <= hi for lo, hi in ArabicTextPreprocessor._ARABIC_RANGES)
            for ch in letters
        )
        return arabic_letters / len(letters) > 0.5

def create_arabic_optimized_classifier(model_key: str = 'jais') -> ArabicMedicalAnswerClassifier:
    """
    Create a classifier with the settings recommended for Arabic medical text.

    Every model is requested with 4-bit quantization (the classifier ignores
    it for BERT-style classification models) and thinking-mode prompting.
    Whether a model counts as Arabic-native is read from MODEL_CONFIGS, so
    that table stays the single source of truth.

    Args:
        model_key: Key of ArabicMedicalAnswerClassifier.MODEL_CONFIGS.

    Returns:
        A loaded ArabicMedicalAnswerClassifier.

    Raises:
        ValueError: from the classifier constructor if model_key is unknown.
    """
    print("🚀 Creating Arabic-optimized classifier...")
    
    config = ArabicMedicalAnswerClassifier.MODEL_CONFIGS.get(model_key, {})
    if config.get('arabic_native', False):
        print(f"✅ Using Arabic native model: {model_key}")
    else:
        print(f"⚠️  Using multilingual model: {model_key}")
    
    return ArabicMedicalAnswerClassifier(
        model_key=model_key,
        use_quantization=True,  # Recommended for memory efficiency on every model
        use_thinking_mode=True  # Step-by-step analysis prompt
    )