In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

# import numpy as np # linear algebra
# import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file available under the read-only Kaggle input tree
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    for file_name in file_names:
        print(os.path.join(root_dir, file_name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install transformers>=4.51.0 torch torchvision torchaudio accelerate bitsandbytes -q
!pip install sentencepiece protobuf -q
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import gc
import warnings
import re
warnings.filterwarnings('ignore')
# Report whether CUDA is usable and, if so, describe device 0
cuda_ready = torch.cuda.is_available()
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    device_name = torch.cuda.get_device_name(0)
    total_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU: {device_name}")
    print(f"GPU Memory: {total_gib:.1f} GB")

In [None]:
import torch
import gc
import re
import pandas as pd
import numpy as np
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    BitsAndBytesConfig
)
from typing import List, Dict, Tuple, Any, Optional
import warnings
from sklearn.metrics import f1_score, jaccard_score
from huggingface_hub import login
import ast

# Never hard-code access tokens in a notebook: anything written here is published
# with the notebook (the token previously embedded here should be revoked).
# Provide the token through the HF_TOKEN environment variable instead.
import os

HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)
else:
    print("⚠️ HF_TOKEN not set; skipping Hugging Face login (gated models may fail to load).")
warnings.filterwarnings('ignore')

class MultiModelArabicMedicalClassifier:
    """
    Arabic medical question classifier supporting multiple LLMs (Qwen, Unsloth-optimized
    Qwen, DeepSeek-R1 distills, Llama, Mixtral, Gemma, Phi, ...).

    Each question is wrapped in an Arabic instruction prompt. The model then generates
    a free-form answer, and the label list is parsed from a line such as
    "التصنيف النهائي: [A,B]". Valid labels are A-F and Z (see ``question_categories``).
    """

    # Category letters accepted when parsing model output.
    VALID_CATEGORIES = ('A', 'B', 'C', 'D', 'E', 'F', 'Z')

    # Model configurations keyed by short name:
    #   name           -> repo id passed to from_pretrained
    #   type           -> prompt-format family used by create_prompt
    #   context_length -> token budget shared by prompt + generation
    MODEL_CONFIGS = {
        'qwen3': {
            'name': 'Qwen/Qwen3-14B',
            'type': 'qwen',
            'context_length': 32768,
            'description': 'Qwen3 14B - Latest Qwen model'
        },
        'jais': {
            'name': 'core42/jais-13b-chat',
            'type': 'jais',
            'context_length': 2048,
            'description': 'JAIS - Arabic-focused LLM'
        },
        'qwen2': {
            'name': 'Qwen/Qwen2-7B-Instruct',
            'type': 'qwen',
            'context_length': 32768,
            'description': 'Qwen2 7B - Strong multilingual model'
        },
        'qwen2.5': {
            'name': 'unsloth/Qwen2.5-7B-Instruct',
            'type': 'qwen',
            'context_length': 131072,
            'description': 'Qwen2.5 7B Instruct - Unsloth optimized'
        },
        'qwen2.5-unsloth': {
            'name': 'unsloth/Qwen2.5-7B-Instruct',
            'type': 'qwen',
            'context_length': 131072,
            'description': 'Unsloth Qwen2.5 7B - Optimized for inference speed'
        },
        'deepseek-r1': {
            'name': 'deepseek-ai/DeepSeek-R1-Distill-Qwen-7B',
            'type': 'qwen',
            'context_length': 131072,
            'description': 'DeepSeek R1 Distill Qwen 7B - Advanced reasoning model'
        },
        'deepseek_r1': {
            # Hub repo ids need the owner prefix; the bare model name is not a valid repo id.
            'name': 'deepseek-ai/DeepSeek-R1-Distill-Qwen-7B',
            'type': 'qwen',
            'context_length': 32768,
            'description': 'DeepSeek R1 Distill Qwen 7B - Reasoning optimized'
        },
        'llama3': {
            'name': 'meta-llama/Llama-3.1-8B-Instruct',
            'type': 'llama',
            'context_length': 128000,
            'description': 'Llama 3.1 8B - Meta\'s latest model'
        },
        'mixtral': {
            'name': 'mistralai/Mixtral-8x7B-Instruct-v0.1',
            'type': 'mixtral',
            'context_length': 32768,
            'description': 'Mixtral 8x7B - Mixture of Experts'
        },
        'command_r': {
            'name': 'CohereForAI/c4ai-command-r-v01',
            'type': 'command_r',
            'context_length': 128000,
            'description': 'Command R - Cohere\'s multilingual model'
        },
        'gemma2': {
            'name': 'google/gemma-2-9b-it',
            'type': 'gemma',
            'context_length': 8192,
            'description': 'Gemma 2 9B - Google\'s instruction-tuned model'
        },
        'phi3': {
            # 'Phi-3-medium-14b-instruct' is not a published repo; the 128k-context
            # medium instruct variant matches the configured context length.
            'name': 'microsoft/Phi-3-medium-128k-instruct',
            'type': 'phi',
            'context_length': 128000,
            'description': 'Phi-3 Medium - Microsoft\'s efficient model'
        }
    }

    def __init__(self, model_key: str = 'qwen2.5-unsloth', use_quantization: bool = True, use_thinking_mode: bool = True):
        """
        Initialize the classifier and immediately load the requested model.

        Args:
            model_key: Key from MODEL_CONFIGS (default: qwen2.5-unsloth).
            use_quantization: Whether to load the model with 4-bit NF4 quantization.
            use_thinking_mode: Use the step-by-step "thinking" prompt and, for chat
                templates that support it (e.g. Qwen3), the template's thinking mode.

        Raises:
            ValueError: If ``model_key`` is not in MODEL_CONFIGS.
        """
        if model_key not in self.MODEL_CONFIGS:
            raise ValueError(f"Model '{model_key}' not supported. Available models: {list(self.MODEL_CONFIGS.keys())}")

        self.model_config = self.MODEL_CONFIGS[model_key]
        self.model_key = model_key
        self.use_quantization = use_quantization
        self.use_thinking_mode = use_thinking_mode
        self.tokenizer = None
        self.model = None

        self.question_categories = {
            'A': 'Diagnosis (questions about interpreting clinical findings)',
            'B': 'Treatment (questions about seeking treatments)',
            'C': 'Anatomy and Physiology (questions about basic medical knowledge)',
            'D': 'Epidemiology (questions about the course, prognosis, and etiology of diseases)',
            'E': 'Healthy Lifestyle (questions related to diet, exercise, and mood control)',
            'F': 'Provider Choices (questions seeking recommendations for medical professionals and facilities)',
            'Z': 'Other (questions that do not fall under the above-mentioned categories)'
        }

        self.load_model()

    def load_model(self):
        """Load tokenizer + model, preferring Unsloth's FastLanguageModel for 'unsloth/' repos."""
        print(f"Loading {self.model_config['description']}...")
        print(f"Model: {self.model_config['name']}")

        quantization_config = None
        if self.use_quantization:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )

        try:
            # Special handling for Unsloth models
            if 'unsloth' in self.model_config['name']:
                print("🚀 Loading Unsloth optimized model...")
                try:
                    # Try to use Unsloth's FastLanguageModel if available
                    from unsloth import FastLanguageModel
                    self.model, self.tokenizer = FastLanguageModel.from_pretrained(
                        model_name=self.model_config['name'],
                        max_seq_length=self.model_config['context_length'],
                        dtype=torch.float16,
                        load_in_4bit=self.use_quantization,
                    )
                    FastLanguageModel.for_inference(self.model)  # switch Unsloth into inference mode
                    print("✅ Unsloth FastLanguageModel loaded successfully!")
                except ImportError:
                    print("⚠️ Unsloth not available, falling back to standard transformers...")
                    self._load_standard_model(quantization_config)
            else:
                self._load_standard_model(quantization_config)

            # generate() needs a pad id; reuse EOS when the tokenizer has none
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            print(f"✅ {self.model_config['description']} loaded successfully!")
            self.print_model_info()

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise

    def _load_standard_model(self, quantization_config):
        """Load tokenizer and model via transformers Auto* classes (fp16, device_map='auto')."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_config['name'],
            trust_remote_code=True,
            padding_side='left'
        )

        model_kwargs = {
            'trust_remote_code': True,
            'torch_dtype': torch.float16,
            'low_cpu_mem_usage': True,
            'device_map': 'auto'
        }

        if quantization_config:
            model_kwargs['quantization_config'] = quantization_config

        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_config['name'],
            **model_kwargs
        )

    def print_model_info(self):
        """Print model configuration and current GPU memory usage."""
        print(f"\n📊 Model Information:")
        print(f"Model: {self.model_config['name']}")
        print(f"Type: {self.model_config['type']}")
        print(f"Context Length: {self.model_config['context_length']:,} tokens")
        print(f"Quantization: {'4-bit' if self.use_quantization else 'Full precision'}")
        print(f"Thinking Mode: {'Enabled' if self.use_thinking_mode else 'Disabled'}")
        print(f"Tokenizer Vocab Size: {len(self.tokenizer):,}")

        if torch.cuda.is_available():
            print(f"GPU Memory Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
            print(f"GPU Memory Reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

    def create_prompt(self, question: str) -> str:
        """
        Build the model-specific classification prompt for one question.

        The Arabic instruction text depends on the model key (DeepSeek gets a
        reasoning-oriented prompt) and on ``use_thinking_mode``. It is then wrapped in
        the prompt format of the model's ``type``.
        """
        category_descriptions = """
فئات الأسئلة الطبية:
(A) التشخيص - أسئلة حول تفسير النتائج السريرية والأعراض
(B) العلاج - أسئلة حول البحث عن علاجات وطرق العلاج
(C) التشريح وعلم وظائف الأعضاء - أسئلة حول المعرفة الطبية الأساسية
(D) علم الأوبئة - أسئلة حول مسار المرض وتشخيصه وأسبابه
(E) نمط الحياة الصحي - أسئلة متعلقة بالنظام الغذائي والرياضة والصحة النفسية
(F) اختيار مقدم الرعاية - أسئلة تطلب توصيات للمهنيين الطبيين والمرافق
(Z) أخرى - أسئلة لا تندرج تحت الفئات المذكورة أعلاه
"""

        # Enhanced prompt for DeepSeek R1 model with reasoning capabilities
        if self.model_key in ['deepseek-r1', 'deepseek_r1']:
            base_prompt = f"""أنت نموذج ذكي متخصص في تحليل وتصنيف الأسئلة الطبية باللغة العربية. لديك قدرات تفكير متقدمة وتحليل عميق.

{category_descriptions}

السؤال المراد تصنيفه:
{question}

تعليمات خاصة للتحليل المتقدم:
1. استخدم قدراتك في التفكير المنطقي لتحليل السؤال بعمق
2. حلل السياق الطبي والمفاهيم المستخدمة
3. فكر في الهدف من السؤال ونوع المعلومات المطلوبة
4. اعتبر التداخل بين الفئات المختلفة
5. اشرح منطق تفكيرك خطوة بخطوة
6. حدد الكلمات المفتاحية والمؤشرات اللغوية
7. في النهاية، قدم التصنيف النهائي بالتنسيق المطلوب

التنسيق المطلوب:
"التصنيف النهائي: [أحرف الفئات مفصولة بفواصل]"

مثال: إذا كان السؤال يتعلق بالتشخيص والعلاج، فاكتب: "التصنيف النهائي: [A,B]"
"""
        elif self.use_thinking_mode:
            base_prompt = f"""أنت خبير في تصنيف الأسئلة الطبية باللغة العربية. مهمتك هي تصنيف السؤال التالي إلى فئة أو أكثر من الفئات المحددة.

{category_descriptions}

السؤال المراد تصنيفه:
{question}

تعليمات:
1. فكر بعمق في محتوى السؤال وحلل كل جانب منه
2. اشرح تفكيرك خطوة بخطوة
3. حدد الكلمات المفتاحية والمفاهيم الطبية
4. اربط السؤال بالفئات المناسبة مع التبرير
5. اختر الفئة أو الفئات الأنسب (يمكن أن يكون أكثر من فئة)
6. في النهاية، اكتب الإجابة بالتنسيق التالي:
   "التصنيف النهائي: [A,B,C]" (استخدم الأحرف المناسبة مفصولة بفواصل)

مثال على التفكير والتنسيق:
- حلل السؤال: "ما سبب هذا العرض؟"
- الكلمات المفتاحية: "سبب، عرض"
- التفكير: هذا سؤال يتعلق بتحديد أسباب الأعراض، وهو جزء من عملية التشخيص
- التصنيف النهائي: [A]

الآن حلل السؤال أعلاه بنفس الطريقة.
"""
        else:
            base_prompt = f"""أنت خبير في تصنيف الأسئلة الطبية باللغة العربية. مهمتك هي تصنيف السؤال التالي إلى فئة أو أكثر من الفئات المحددة.

{category_descriptions}

السؤال المراد تصنيفه:
{question}

تعليمات:
1. اقرأ السؤال بعناية وحلل محتواه
2. حدد الفئة أو الفئات المناسبة (يمكن أن يكون هناك أكثر من فئة واحدة)
3. اشرح سبب اختيارك لكل فئة
4. في النهاية، اكتب الإجابة بالتنسيق التالي:
   "التصنيف النهائي: [A,B,C]" (استخدم الأحرف المناسبة مفصولة بفواصل)

مثال على التنسيق:
- إذا كان السؤال عن التشخيص فقط: "التصنيف النهائي: [A]"
- إذا كان السؤال عن التشخيص والعلاج: "التصنيف النهائي: [A,B]"
"""

        prompt_type = self.model_config['type']
        # Qwen-family models (incl. Unsloth and DeepSeek distills) use the tokenizer's chat template
        if prompt_type == 'qwen':
            messages = [{"role": "user", "content": base_prompt}]
            try:
                # `enable_thinking` is a chat-template variable (read by e.g. Qwen3 templates),
                # not a tokenizer attribute, so it is passed unconditionally; templates that
                # do not reference it are expected to ignore it.
                return self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                    enable_thinking=self.use_thinking_mode
                )
            except TypeError:
                # Fallback for transformers versions that reject extra template kwargs
                return self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
        elif prompt_type == 'llama':
            return f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{base_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        elif prompt_type == 'mixtral':
            # Mixtral-Instruct uses the Mistral [INST] format, not Llama-3 header tokens.
            # BOS is left to the tokenizer, which adds special tokens when encoding.
            return f"[INST] {base_prompt} [/INST]"
        elif prompt_type == 'jais':
            return f"### Instruction: {base_prompt}\n### Response:"
        elif prompt_type == 'gemma':
            return f"<start_of_turn>user\n{base_prompt}<end_of_turn>\n<start_of_turn>model\n"
        elif prompt_type == 'phi':
            return f"<|user|>\n{base_prompt}<|end|>\n<|assistant|>\n"
        else:
            return f"Human: {base_prompt}\n\nAssistant:"

    def classify_question(self, question: str, max_new_tokens: int = 8000) -> List[str]:
        """
        Classify one Arabic medical question into categories.

        Args:
            question: Question text.
            max_new_tokens: Upper bound on generated tokens. It is capped so that prompt
                plus generation fits in the configured ``context_length``.

        Returns:
            Category letters (subset of A-F, Z); ``['Z']`` when nothing can be parsed.

        Raises:
            ValueError: If the prompt alone already fills the context window.
        """
        prompt = self.create_prompt(question)
        context_length = self.model_config['context_length']

        # Tokenize WITHOUT truncation. Truncating from the right would cut off the end of
        # the prompt (question + generation header). The old limit
        # `context_length - max_new_tokens` was negative for short-context models
        # (e.g. JAIS: 2048 - 8000), and left only ~192 tokens for Gemma 2 (8192 - 8000).
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        prompt_len = inputs['input_ids'].shape[-1]
        available = context_length - prompt_len
        if available <= 0:
            raise ValueError(
                f"Prompt is {prompt_len} tokens but {self.model_config['name']} "
                f"context length is {context_length}"
            )
        max_new_tokens = min(max_new_tokens, available)

        # Sampling defaults, overridden per model family below
        gen_kwargs = {
            'max_new_tokens': max_new_tokens,
            'temperature': 0.7,
            'top_p': 0.9,
            'top_k': 50,
            'do_sample': True,
            'pad_token_id': self.tokenizer.pad_token_id,
            'eos_token_id': self.tokenizer.eos_token_id,
            'repetition_penalty': 1.1
        }

        if self.model_key in ['deepseek-r1', 'deepseek_r1']:
            gen_kwargs.update({
                'temperature': 0.3,  # lower temperature for more focused reasoning
                'top_p': 0.9,
                'top_k': 40,
                'repetition_penalty': 1.05
            })
        elif self.model_key in ['qwen2.5-unsloth', 'qwen2.5']:
            gen_kwargs.update({
                'temperature': 0.6,
                'top_p': 0.95,
                'top_k': 30,
                'repetition_penalty': 1.08
            })
        elif self.model_config['type'] == 'qwen':
            gen_kwargs.update({
                'temperature': 0.6,
                'top_p': 0.95,
                'top_k': 20
            })
        elif self.model_config['type'] == 'llama':
            gen_kwargs.update({
                'temperature': 0.8,
                'top_p': 0.95
            })

        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, **gen_kwargs)

        # Keep only the newly generated tokens (drop the echoed prompt)
        output_ids = generated_ids[0][prompt_len:]
        response = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip()

        return self.extract_question_categories(response)

    @classmethod
    def _parse_category_list(cls, categories_str: str) -> List[str]:
        """Split e.g. 'a, B ,B' into unique upper-case valid letters, keeping first-seen order."""
        letters = (cat.strip().upper() for cat in categories_str.split(','))
        return list(dict.fromkeys(c for c in letters if c in cls.VALID_CATEGORIES))

    def extract_question_categories(self, response: str) -> List[str]:
        """
        Extract category letters from the model's (Arabic or English) response text.

        Tries the explicit "label: [A,B]" patterns in priority order. For the first pattern
        that yields valid letters, the LAST occurrence is used: the prompts themselves
        contain example answers ("التصنيف النهائي: [A]") that reasoning models tend to echo
        before giving their real answer. If no pattern matches, it falls back to a loose
        letter scan of the text after the last ``</think>`` marker, if any. Returns
        ``['Z']`` when nothing is found.
        """
        arabic_patterns = [
            r'التصنيف النهائي:\s*\[([ABCDEFZ,\s]+)\]',
            r'الفئات:\s*\[([ABCDEFZ,\s]+)\]',
            r'التصنيف:\s*\[([ABCDEFZ,\s]+)\]',
            r'النتيجة:\s*\[([ABCDEFZ,\s]+)\]',
            r'الإجابة:\s*\[([ABCDEFZ,\s]+)\]',
        ]
        english_patterns = [
            r'Final Classification:\s*\[([ABCDEFZ,\s]+)\]',
            r'Categories:\s*\[([ABCDEFZ,\s]+)\]',
            r'Classification:\s*\[([ABCDEFZ,\s]+)\]',
            r'Answer:\s*\[([ABCDEFZ,\s]+)\]',
        ]

        for pattern in arabic_patterns + english_patterns:
            matches = re.findall(pattern, response, re.IGNORECASE)
            if matches:
                categories = self._parse_category_list(matches[-1])
                if categories:  # e.g. "[ , ]" parses to nothing; try the next pattern
                    return categories

        # The reasoning block typically restates the category list "(A) ... (Z)", which
        # would flag every letter, so only scan what follows the last </think>.
        tail = response.rsplit('</think>', 1)[-1]
        found_categories = [
            category for category in self.VALID_CATEGORIES
            if f'({category})' in tail or f'[{category}]' in tail or f' {category} ' in tail
        ]
        return found_categories if found_categories else ['Z']

    def process_test_dataset(self, df: pd.DataFrame, max_new_tokens: int = 8000, show_progress: bool = True) -> pd.DataFrame:
        """
        Classify every row's 'question' and return a one-column 'prediction' DataFrame.

        Predictions are comma-separated sorted letters (e.g. "A, B"), in the same row
        order as ``df`` (with a fresh RangeIndex). A row whose classification raises
        gets 'Z'. Progress and memory cleanup are keyed on row position, so any
        ``df.index`` works.
        """
        print(f"🚀 Starting Arabic Medical Question Classification with {self.model_config['description']}...")
        print(f"Test dataset size: {len(df)} samples")
        print("-" * 80)

        predictions = []
        total = len(df)

        for position, (idx, row) in enumerate(df.iterrows()):
            if show_progress and position % 10 == 0:
                gpu_gb = torch.cuda.memory_allocated() / 1024**3 if torch.cuda.is_available() else 0.0
                print(f"Processing sample {position+1}/{total} - Current GPU memory: {gpu_gb:.2f} GB")

            try:
                categories = self.classify_question(row['question'], max_new_tokens=max_new_tokens)
                prediction_str = ', '.join(sorted(categories))
                predictions.append(prediction_str)

                if show_progress and position < 3:
                    print(f"Sample {position+1} prediction: {prediction_str}")

            except Exception as e:
                print(f"Error processing question {idx}: {e}")
                predictions.append('Z')

            if position % 20 == 0:
                self.cleanup_memory()

        return pd.DataFrame({'prediction': predictions})

    def cleanup_memory(self):
        """Run the garbage collector and release cached CUDA memory (the model stays loaded)."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    @classmethod
    def list_available_models(cls):
        """Print every entry of MODEL_CONFIGS with its settings."""
        print("📋 Available Models:")
        print("-" * 80)

        for key, config in cls.MODEL_CONFIGS.items():
            print(f"Key: '{key}'")
            print(f"  Model: {config['name']}")
            print(f"  Description: {config['description']}")
            print(f"  Context Length: {config['context_length']:,} tokens")
            print(f"  Type: {config['type']}")
            print()

def evaluate_model_on_training_data(
    train_file_path: str,
    model_key: str = 'qwen2.5-unsloth',
    use_quantization: bool = True,
    max_new_tokens: int = 8000
) -> dict:
    """
    Evaluate the model on the training dataset using Weighted F1 Score and Jaccard Score.

    Args:
        train_file_path: Path to the training dataset (TSV with 'question' and 'final_QT' columns)
        model_key: Model key from MODEL_CONFIGS (default: qwen2.5-unsloth)
        use_quantization: Use 4-bit quantization
        max_new_tokens: Maximum tokens to generate per question

    Returns:
        dict: Dictionary containing evaluation metrics (Weighted F1 Score, Jaccard Score)
    """
    
    # Initialize the classifier
    print(f"🚀 Initializing {model_key} for evaluation...")
    try:
        classifier = MultiModelArabicMedicalClassifier(
            model_key=model_key,
            use_quantization=use_quantization
        )
    except Exception as e:
        print(f"❌ Failed to initialize model: {e}")
        return None

    # Load training dataset
    try:
        df = pd.read_csv(train_file_path, sep='\t')
        if 'question' not in df.columns or 'final_QT' not in df.columns:
            raise ValueError("Dataset must contain 'question' and 'final_QT' columns")
        print(f"✅ Training dataset loaded: {len(df)} samples")
    except Exception as e:
        print(f"❌ Error loading training dataset: {e}")
        return None

    # Parse final_QT labels
    def parse_labels(label):
        if isinstance(label, str):
            try:
                if label.startswith('['):
                    return ast.literal_eval(label)
                else:
                    return [cat.strip() for cat in label.split(',')]
            except:
                print(f"Warning: Could not parse label '{label}', defaulting to ['Z']")
                return ['Z']
        return label

    df['final_QT'] = df['final_QT'].apply(parse_labels)

    # Generate predictions
    print(f"🚀 Generating predictions for {len(df)} samples...")
    predictions = classifier.process_test_dataset(df, max_new_tokens=max_new_tokens)

    # Convert predictions and true labels to multi-label binary format
    all_categories = ['A', 'B', 'C', 'D', 'E', 'F', 'Z']
    y_true = []
    y_pred = []

    for true_labels, pred_labels in zip(df['final_QT'], predictions['prediction']):
        true_vec = [1 if cat in true_labels else 0 for cat in all_categories]
        y_true.append(true_vec)
        pred_cats = [cat.strip() for cat in pred_labels.split(',') if cat.strip() in all_categories]
        pred_vec = [1 if cat in pred_cats else 0 for cat in all_categories]
        y_pred.append(pred_vec)

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    # Calculate metrics
    weighted_f1 = f1_score(y_true, y_pred, average='weighted')
    jaccard = jaccard_score(y_true, y_pred, average='samples')

    # Print results
    print("\n📊 Evaluation Results:")
    print(f"Weighted F1 Score: {weighted_f1:.4f}")
    print(f"Jaccard Score (samples): {jaccard:.4f}")

    # Detailed per-category metrics
    print("\n📋 Per-Category Metrics:")
    for i, cat in enumerate(all_categories):
        cat_f1 = f1_score(y_true[:, i], y_pred[:, i])
        cat_jaccard = jaccard_score(y_true[:, i], y_pred[:, i])
        print(f"Category {cat}:")
        print(f"  F1 Score: {cat_f1:.4f}")
        print(f"  Jaccard Score: {cat_jaccard:.4f}")

    # Cleanup memory
    classifier.cleanup_memory()

    return {
        'weighted_f1': weighted_f1,
        'jaccard_score': jaccard,
        'per_category_f1': {cat: f1_score(y_true[:, i], y_pred[:, i]) for i, cat in enumerate(all_categories)},
        'per_category_jaccard': {cat: jaccard_score(y_true[:, i], y_pred[:, i]) for i, cat in enumerate(all_categories)}
    }

def run_multiple_model_evaluation(
    train_file_path: str,
    models_to_test: Optional[List[str]] = None,
    use_quantization: bool = True,
    max_new_tokens: int = 8000,
):
    """
    Evaluate several models in sequence and print a comparison table.

    Args:
        train_file_path: TSV passed to evaluate_model_on_training_data.
        models_to_test: MODEL_CONFIGS keys; defaults to ['qwen2.5-unsloth', 'deepseek-r1'].
        use_quantization: Forwarded to each evaluation (previously hard-coded True).
        max_new_tokens: Forwarded to each evaluation (previously hard-coded 8000).

    Returns:
        Dict mapping model key -> metrics dict for every model that evaluated successfully.
    """
    if models_to_test is None:
        models_to_test = ['qwen2.5-unsloth', 'deepseek-r1']

    results = {}

    for model_key in models_to_test:
        print(f"\n{'='*100}")
        print(f"🔄 Evaluating Model: {model_key}")
        print(f"{'='*100}")

        try:
            model_results = evaluate_model_on_training_data(
                train_file_path=train_file_path,
                model_key=model_key,
                use_quantization=use_quantization,
                max_new_tokens=max_new_tokens
            )

            if model_results:
                results[model_key] = model_results
                print(f"✅ {model_key} evaluation completed successfully!")
            else:
                print(f"❌ {model_key} evaluation failed!")

        except Exception as e:
            print(f"❌ Error evaluating {model_key}: {e}")
        finally:
            # Collect first so freed tensors are returned by empty_cache
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    if results:
        print(f"\n{'='*100}")
        print("📊 COMPARISON RESULTS")
        print(f"{'='*100}")

        print(f"{'Model':<20} {'Weighted F1':<15} {'Jaccard Score':<15}")
        print("-" * 50)

        for model_key, result in results.items():
            print(f"{model_key:<20} {result['weighted_f1']:<15.4f} {result['jaccard_score']:<15.4f}")

    return results

def compare_all_models(
    train_file_path: str,
    use_quantization: bool = True,
    max_new_tokens: int = 6000,
    models_to_test: Optional[List[str]] = None,
):
    """
    Evaluate a set of models on the training dataset and report the best by weighted F1.

    Despite the name, only ``models_to_test`` are evaluated. The default is
    ['qwen2.5-unsloth', 'deepseek-r1'], not every entry in MODEL_CONFIGS.

    Returns:
        Dict mapping model key -> metrics dict for every model that evaluated successfully.
    """
    if models_to_test is None:
        models_to_test = ['qwen2.5-unsloth', 'deepseek-r1']
    results = {}

    print("🚀 Starting comprehensive model comparison...")
    print("=" * 80)

    for model_key in models_to_test:
        print(f"\n🔄 Testing {model_key}...")
        try:
            result = evaluate_model_on_training_data(
                train_file_path=train_file_path,
                model_key=model_key,
                use_quantization=use_quantization,
                max_new_tokens=max_new_tokens
            )
            if result:
                results[model_key] = result
        except Exception as e:
            print(f"❌ Error testing {model_key}: {e}")
        finally:
            # Release memory between models even when the evaluation raised
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    print("\n" + "=" * 80)
    print("📊 MODEL COMPARISON RESULTS")
    print("=" * 80)

    for model_key, result in results.items():
        print(f"\n{model_key.upper()}:")
        print(f"  Weighted F1 Score: {result['weighted_f1']:.4f}")
        print(f"  Jaccard Score: {result['jaccard_score']:.4f}")

    if results:
        best_model = max(results, key=lambda k: results[k]['weighted_f1'])
        print(f"\n🏆 BEST MODEL: {best_model}")
        print(f"   Weighted F1: {results[best_model]['weighted_f1']:.4f}")
        print(f"   Jaccard Score: {results[best_model]['jaccard_score']:.4f}")

    return results

def install_unsloth():
    """
    Print the shell commands that install Unsloth for optimized inference (run once).

    The git requirement is quoted: unquoted, the shell splits
    "unsloth[colab-new] @ git+..." at the spaces, and may glob-expand the
    brackets, so a copy-pasted command fails.
    """
    install_commands = [
        'pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"',
        "pip install --no-deps trl peft accelerate bitsandbytes"
    ]

    print("📦 To install Unsloth for optimized inference, run these commands:")
    for cmd in install_commands:
        print(f"   {cmd}")

    print("\n⚠️  Note: Restart your runtime after installation!")

def benchmark_inference_speed(
    model_key: str,
    sample_questions: List[str],
    use_quantization: bool = True,
    max_new_tokens: int = 2000,
):
    """
    Benchmark end-to-end classification latency for one model.

    Args:
        model_key: Model to benchmark.
        sample_questions: Non-empty list of questions; the first is also used for warm-up.
        use_quantization: Whether to use 4-bit quantization.
        max_new_tokens: Generation cap per question (previously hard-coded 2000).

    Returns:
        Dict with total/average wall-clock time. Note that 'tokens_per_second' is
        INPUT question tokens divided by total time, not generation throughput.

    Raises:
        ValueError: If ``sample_questions`` is empty (checked before loading the model).
    """
    import time

    if not sample_questions:
        raise ValueError("sample_questions must contain at least one question")

    print(f"🚀 Benchmarking {model_key} inference speed...")

    classifier = MultiModelArabicMedicalClassifier(
        model_key=model_key,
        use_quantization=use_quantization
    )

    # Count input tokens outside the timed region so tokenization doesn't skew timing
    total_tokens = sum(len(classifier.tokenizer.encode(q)) for q in sample_questions)

    # Warm-up run (kernel compilation / cache allocation)
    classifier.classify_question(sample_questions[0], max_new_tokens=max_new_tokens)

    start_time = time.perf_counter()
    for question in sample_questions:
        classifier.classify_question(question, max_new_tokens=max_new_tokens)
    total_time = time.perf_counter() - start_time

    n_samples = len(sample_questions)
    print(f"⏱️  Benchmark Results for {model_key}:")
    print(f"   Total Time: {total_time:.2f} seconds")
    print(f"   Samples Processed: {n_samples}")
    print(f"   Average Time per Sample: {total_time/n_samples:.2f} seconds")
    print(f"   Total Input Tokens: {total_tokens}")
    print(f"   Input Tokens per Second: {total_tokens/total_time:.2f}")

    # Release the model before returning
    del classifier
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return {
        'model_key': model_key,
        'total_time': total_time,
        'samples': n_samples,
        'avg_time_per_sample': total_time/n_samples,
        'tokens_per_second': total_tokens/total_time
    }

def create_optimized_pipeline(model_key: str = 'qwen2.5-unsloth', batch_size: int = 4, use_compile: bool = True):
    """
    Create a classifier configured for faster inference.

    Args:
        model_key: Model to use for the pipeline.
        batch_size: Only reported in the log message. The returned classifier still
            processes one question per ``classify_question`` call.
        use_compile: Wrap the model with ``torch.compile`` when available.

    Returns:
        A MultiModelArabicMedicalClassifier with thinking mode disabled and the model in eval mode.
    """
    print(f"🔧 Creating optimized pipeline with {model_key}...")

    classifier = MultiModelArabicMedicalClassifier(
        model_key=model_key,
        use_quantization=True,
        use_thinking_mode=False  # shorter prompts / outputs
    )

    classifier.model.eval()

    if not use_compile:
        print("⚠️  Torch compile disabled, using standard inference")
    elif not hasattr(torch, 'compile'):
        print("⚠️  Torch compile not available, using standard inference")
    else:
        # NOTE: torch.compile is lazy; failures may only surface on the first generate() call.
        try:
            classifier.model = torch.compile(classifier.model, mode="reduce-overhead")
            print("✅ Torch compile enabled for faster inference")
        except Exception as e:  # never a bare except: keep Ctrl-C working
            print(f"⚠️  Torch compile not available, using standard inference ({e})")

    print(f"✅ Optimized pipeline ready with batch size: {batch_size}")
    return classifier

def test_new_models(models_to_test: Optional[List[str]] = None, max_new_tokens: int = 3000):
    """
    Smoke-test models by classifying a few fixed Arabic sample questions.

    Args:
        models_to_test: MODEL_CONFIGS keys; defaults to ['qwen2.5-unsloth', 'deepseek-r1'].
        max_new_tokens: Generation cap per question.
    """
    sample_questions = [
        "ما أسباب الصداع المستمر؟",
        "كيف يمكن علاج آلام المفاصل؟",
        "ما هي وظائف الكبد في الجسم؟",
        "هل يمكنك أن تنصحني بطبيب جيد لعلاج الأطفال؟",
        "ما هي أفضل التمارين للحفاظ على الصحة؟"
    ]

    if models_to_test is None:
        models_to_test = ['qwen2.5-unsloth', 'deepseek-r1']

    for model_key in models_to_test:
        print(f"\n{'='*50}")
        print(f"Testing {model_key}")
        print(f"{'='*50}")

        classifier = None
        try:
            classifier = MultiModelArabicMedicalClassifier(
                model_key=model_key,
                use_quantization=True
            )

            for i, question in enumerate(sample_questions):
                result = classifier.classify_question(question, max_new_tokens=max_new_tokens)
                print(f"\nSample {i+1}: {question}")
                print(f"Classification: {result}")

        except Exception as e:
            print(f"❌ Error testing {model_key}: {e}")
        finally:
            # Drop the model BEFORE the next iteration loads another one; otherwise the
            # old model stays referenced by `classifier` while the new one is loading.
            del classifier
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

if __name__ == "__main__":
    import os

    # Training dataset path; override with the TRAIN_DATASET_PATH environment variable
    TRAIN_DATASET_PATH = os.environ.get(
        'TRAIN_DATASET_PATH', '/kaggle/input/train-dataset/Train_Dev.tsv'
    )

    # List available models
    MultiModelArabicMedicalClassifier.list_available_models()

    # Show installation instructions for Unsloth
    print("\n" + "="*80)
    print("🚀 INSTALLATION GUIDE")
    print("="*80)
    install_unsloth()

    # Evaluate with Unsloth Qwen2.5 (recommended for speed)
    # print("\n" + "=" * 50)
    # print("Testing Unsloth Qwen2.5-7B-Instruct")
    # print("=" * 50)
    # results_unsloth = evaluate_model_on_training_data(
    #     train_file_path=TRAIN_DATASET_PATH,
    #     model_key='qwen2.5-unsloth',
    #     use_quantization=True,
    #     max_new_tokens=6000
    # )

    # Evaluate with DeepSeek R1 (recommended for reasoning)
    print("\n" + "=" * 50)
    print("Testing DeepSeek-R1-Distill-Qwen-7B")
    print("=" * 50)
    results_deepseek = evaluate_model_on_training_data(
        train_file_path=TRAIN_DATASET_PATH,
        model_key='deepseek-r1',
        use_quantization=True,
        max_new_tokens=6000
    )

    # Uncomment to run comprehensive comparison
    # comprehensive_results = compare_all_models(
    #     train_file_path=TRAIN_DATASET_PATH,
    #     use_quantization=True,
    #     max_new_tokens=6000
    # )

    # Print a summary if the DeepSeek evaluation produced results
    if results_deepseek:
        print("\n" + "=" * 80)

        print("\nDeepSeek R1 Distill Qwen-7B:")
        print(f"  Weighted F1: {results_deepseek['weighted_f1']:.4f}")
        print(f"  Jaccard Score: {results_deepseek['jaccard_score']:.4f}")
        print("  Advantage: Advanced reasoning capabilities")

        print("\n💡 Recommendation:")
        print("   - Use 'qwen2.5-unsloth' for faster inference and production deployment")
        print("   - Use 'deepseek-r1' for complex reasoning tasks and research")

    print("\n" + "="*80)
    print("🧪 TESTING NEW MODELS")
    print("="*80)
    # Uncomment to test the new models with sample data
    # test_new_models()

    print("\n✅ Setup complete! The classifier now supports:")
    print("   - unsloth/Qwen2.5-7B-Instruct (Optimized for speed)")
    print("   - deepseek-ai/DeepSeek-R1-Distill-Qwen-7B (Advanced reasoning)")
    print("\n💡 Use 'qwen2.5-unsloth' as default for best performance!")