In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

# import numpy as np # linear algebra
# import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os

# Print the full path of every file found anywhere under the Kaggle input directory
input_file_paths = (
    os.path.join(folder, entry)
    for folder, _, entries in os.walk('/kaggle/input')
    for entry in entries
)
for input_file_path in input_file_paths:
    print(input_file_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install transformers>=4.51.0 torch torchvision torchaudio accelerate bitsandbytes -q
!pip install sentencepiece protobuf -q
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import gc
import warnings
import re
warnings.filterwarnings('ignore')
# Report whether CUDA is usable and, if so, the name and total memory of device 0
cuda_ready = torch.cuda.is_available()
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    device_name = torch.cuda.get_device_name(0)
    total_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU: {device_name}")
    print(f"GPU Memory: {total_gib:.1f} GB")

In [None]:
import torch
import gc
import re
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from typing import List, Dict, Tuple, Any, Optional
from sklearn.metrics import f1_score, jaccard_score
import ast
import warnings
warnings.filterwarnings('ignore')

class ArabicMedicalAnswerClassifier:
    """
    Arabic Medical Answer Classifier supporting multiple LLM models including Unsloth and DeepSeek
    """
    
    # Model configurations with their specific settings
    MODEL_CONFIGS = {
        'qwen2': {
            'name': 'Qwen/Qwen2-7B-Instruct',
            'type': 'qwen',
            'context_length': 32768,
            'description': 'Qwen2 7B - Strong multilingual model'
        },
        'qwen2.5-unsloth': {
            'name': 'unsloth/Qwen2.5-7B-Instruct',
            'type': 'qwen',
            'context_length': 131072,
            'description': 'Unsloth Qwen2.5 7B - Optimized for inference speed'
        },
        'deepseek-r1': {
            'name': 'deepseek-ai/DeepSeek-R1-Distill-Qwen-7B',
            'type': 'qwen',
            'context_length': 131072,
            'description': 'DeepSeek R1 Distill Qwen 7B - Advanced reasoning model'
        },
        'llama3': {
            'name': 'meta-llama/Llama-3.1-8B-Instruct',
            'type': 'llama',
            'context_length': 128000,
            'description': 'Llama 3.1 8B - Meta\'s latest model'
        },
        'gemma2': {
            'name': 'google/gemma-2-9b-it',
            'type': 'gemma',
            'context_length': 8192,
            'description': 'Gemma 2 9B - Google\'s instruction-tuned model'
        },
        'phi3': {
            'name': 'microsoft/Phi-3-medium-14b-instruct',
            'type': 'phi',
            'context_length': 128000,
            'description': 'Phi-3 Medium - Microsoft\'s efficient model'
        }
    }
    
    def __init__(self, model_key: str = 'qwen2.5-unsloth', use_quantization: bool = True, use_thinking_mode: bool = True):
        """
        Initialize classifier with specified model
        
        Args:
            model_key: Key from MODEL_CONFIGS (default: qwen2.5-unsloth for better performance)
            use_quantization: Whether to use 4-bit quantization
            use_thinking_mode: Enable thinking mode for supported models
        """
        if model_key not in self.MODEL_CONFIGS:
            raise ValueError(f"Model '{model_key}' not supported. Available models: {list(self.MODEL_CONFIGS.keys())}")
        
        self.model_config = self.MODEL_CONFIGS[model_key]
        self.model_key = model_key
        self.use_quantization = use_quantization
        self.use_thinking_mode = use_thinking_mode
        self.tokenizer = None
        self.model = None
        self.answer_strategies = {
            '1': 'Information (answers providing information, resources, etc.)',
            '2': 'Direct Guidance (answers providing suggestions, instructions, or advice)',
            '3': 'Emotional Support (answers providing approval, reassurance, or other forms of emotional support)'
        }
        self.load_model()
    
    def load_model(self):
        """
        Load the tokenizer and model described by self.model_config.

        Models whose name contains 'unsloth' are first attempted through
        Unsloth's FastLanguageModel; if that raises ImportError the standard
        transformers loader is used instead. Any other loading error is
        printed and re-raised.
        """
        config = self.model_config
        print(f"Loading {config['description']}...")
        print(f"Model: {config['name']}")

        # 4-bit NF4 settings, only built when quantization was requested
        quantization_config = (
            BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
            if self.use_quantization
            else None
        )

        try:
            if 'unsloth' not in config['name']:
                self._load_standard_model(quantization_config)
            else:
                print("🚀 Loading Unsloth optimized model...")
                try:
                    # Prefer Unsloth's own loader when the package is installed
                    from unsloth import FastLanguageModel
                    self.model, self.tokenizer = FastLanguageModel.from_pretrained(
                        model_name=config['name'],
                        max_seq_length=config['context_length'],
                        dtype=torch.float16,
                        load_in_4bit=self.use_quantization,
                    )
                    FastLanguageModel.for_inference(self.model)  # switch the model to Unsloth's inference mode
                    print("✅ Unsloth FastLanguageModel loaded successfully!")
                except ImportError:
                    print("⚠️ Unsloth not available, falling back to standard transformers...")
                    self._load_standard_model(quantization_config)

            # generate() needs a pad token; reuse EOS when the tokenizer defines none
            tokenizer = self.tokenizer
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            print(f"✅ {config['description']} loaded successfully!")
            self.print_model_info()

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise
    
    def _load_standard_model(self, quantization_config):
        """
        Load tokenizer and model through the plain transformers Auto classes.

        Args:
            quantization_config: Quantization settings forwarded to
                from_pretrained, or a falsy value to load without them.
        """
        model_name = self.model_config['name']

        # Left padding is requested for the decoder-only generation setup
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            padding_side='left'
        )

        load_options = dict(
            trust_remote_code=True,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map='auto',
        )
        if quantization_config:
            load_options['quantization_config'] = quantization_config

        self.model = AutoModelForCausalLM.from_pretrained(model_name, **load_options)
    
    def print_model_info(self):
        """Print model and memory information"""
        print(f"\n📊 Model Information:")
        print(f"Model: {self.model_config['name']}")
        print(f"Type: {self.model_config['type']}")
        print(f"Context Length: {self.model_config['context_length']:,} tokens")
        print(f"Quantization: {'4-bit' if self.use_quantization else 'Full precision'}")
        print(f"Thinking Mode: {'Enabled' if self.use_thinking_mode else 'Disabled'}")
        print(f"Tokenizer Vocab Size: {len(self.tokenizer):,}")
        
        if torch.cuda.is_available():
            print(f"GPU Memory Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
            print(f"GPU Memory Reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    
    def create_prompt(self, answer: str) -> str:
        """Create model-specific prompt for classification"""
        strategy_descriptions = """
استراتيجيات الإجابة الطبية:
(1) المعلومات - إجابات تقدم معلومات وموارد وحقائق طبية
(2) التوجيه المباشر - إجابات تقدم اقتراحات وتعليمات ونصائح محددة
(3) الدعم العاطفي - إجابات تقدم الموافقة والطمأنينة أو أشكال أخرى من الدعم العاطفي
"""
        
        # Enhanced prompt for DeepSeek R1 model with reasoning capabilities
        if self.model_key == 'deepseek-r1':
            base_prompt = f"""أنت نموذج ذكي متخصص في تحليل وتصنيف الإجابات الطبية باللغة العربية. لديك قدرات تفكير متقدمة وتحليل عميق.

{strategy_descriptions}

الإجابة المراد تصنيفها:
{answer}

تعليمات خاصة للتحليل المتقدم:
1. استخدم قدراتك في التفكير المنطقي لتحليل الإجابة بعمق
2. حلل السياق الطبي والمفاهيم المستخدمة
3. فكر في الهدف من الإجابة ونوع المساعدة المقدمة
4. اعتبر التداخل بين الاستراتيجيات المختلفة
5. اشرح منطق تفكيرك خطوة بخطوة
6. حدد الكلمات المفتاحية والمؤشرات اللغوية
7. في النهاية، قدم التصنيف النهائي بالتنسيق المطلوب

التنسيق المطلوب:
"التصنيف النهائي: [أرقام الاستراتيجيات مفصولة بفواصل]"

مثال: إذا كانت الإجابة تحتوي على معلومات وتوجيه، فاكتب: "التصنيف النهائي: [1,2]"
"""
        elif self.use_thinking_mode:
            base_prompt = f"""أنت خبير في تصنيف الإجابات الطبية باللغة العربية. مهمتك هي تصنيف الإجابة التالية إلى استراتيجية أو أكثر من الاستراتيجيات المحددة.

{strategy_descriptions}

الإجابة المراد تصنيفها:
{answer}

تعليمات:
1. فكر بعمق في محتوى الإجابة وحلل كل جانب منها
2. اشرح تفكيرك خطوة بخطوة
3. حدد الكلمات المفتاحية والمفاهيم الطبية
4. اربط الإجابة بالاستراتيجيات المناسبة مع التبرير
5. اختر الاستراتيجية أو الاستراتيجيات الأنسب (يمكن أن يكون أكثر من استراتيجية)
6. في النهاية، اكتب الإجابة بالتنسيق التالي:
   "التصنيف النهائي: [1,2,3]" (استخدم الأرقام المناسبة مفصولة بفواصل)

مثال على التفكير والتنسيق:
- حلل الإجابة: "الصداع قد يكون بسبب الإجهاد. حاول الراحة."
- الكلمات المفتاحية: "سبب، إجهاد، حاول الراحة"
- التفكير: الإجابة توفر معلومات عن سبب (الإجهاد) وتقدم نصيحة (الراحة)
- التصنيف النهائي: [1,2]
"""
        else:
            base_prompt = f"""أنت خبير في تصنيف الإجابات الطبية باللغة العربية. مهمتك هي تصنيف الإجابة التالية إلى استراتيجية أو أكثر من الاستراتيجيات المحددة.

{strategy_descriptions}

الإجابة المراد تصنيفها:
{answer}

تعليمات:
1. اقرأ الإجابة بعناية وحلل محتواها
2. حدد الاستراتيجية أو الاستراتيجيات المناسبة (يمكن أن يكون هناك أكثر من استراتيجية واحدة)
3. اشرح سبب اختيارك لكل استراتيجية
4. في النهاية، اكتب الإجابة بالتنسيق التالي:
   "التصنيف النهائي: [1,2,3]" (استخدم الأرقام المناسبة مفصولة بفواصل)

مثال على التنسيق:
- إذا كانت الإجابة معلوماتية فقط: "التصنيف النهائي: [1]"
- إذا كانت الإجابة تحتوي على معلومات وتوجيه: "التصنيف النهائي: [1,2]"
"""
        
        # Apply chat template for Qwen-based models (including Unsloth and DeepSeek)
        if self.model_config['type'] == 'qwen':
            messages = [{"role": "user", "content": base_prompt}]
            try:
                return self.tokenizer.apply_chat_template(
                    messages, 
                    tokenize=False, 
                    add_generation_prompt=True,
                    enable_thinking=self.use_thinking_mode if hasattr(self.tokenizer, 'enable_thinking') else False
                )
            except:
                # Fallback for models that don't support thinking mode
                return self.tokenizer.apply_chat_template(
                    messages, 
                    tokenize=False, 
                    add_generation_prompt=True
                )
        elif self.model_config['type'] in ['llama']:
            return f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{base_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        elif self.model_config['type'] == 'gemma':
            return f"<start_of_turn>user\n{base_prompt}<end_of_turn>\n<start_of_turn>model\n"
        elif self.model_config['type'] == 'phi':
            return f"<|user|>\n{base_prompt}<|end|>\n<|assistant|>\n"
        else:
            return f"Human: {base_prompt}\n\nAssistant:"
    
    def classify_answer(self, answer: str, max_new_tokens: int = 8000) -> List[str]:
        """
        Classify Arabic medical answer into strategy categories
        """
        prompt = self.create_prompt(answer)
        
        model_inputs = self.tokenizer(
            prompt, 
            return_tensors="pt", 
            truncation=True,
            max_length=self.model_config['context_length'] - max_new_tokens
        ).to(self.model.device)
        
        # Optimized generation parameters for different models
        gen_kwargs = {
            'max_new_tokens': max_new_tokens,
            'temperature': 0.7,
            'top_p': 0.9,
            'top_k': 50,
            'do_sample': True,
            'pad_token_id': self.tokenizer.pad_token_id,
            'eos_token_id': self.tokenizer.eos_token_id,
            'repetition_penalty': 1.1
        }
        
        # Model-specific optimizations
        if self.model_key == 'deepseek-r1':
            # DeepSeek R1 optimized parameters for reasoning
            gen_kwargs.update({
                'temperature': 0.3,  # Lower temperature for more focused reasoning
                'top_p': 0.9,
                'top_k': 40,
                'repetition_penalty': 1.05
            })
        elif self.model_key == 'qwen2.5-unsloth':
            # Unsloth Qwen2.5 optimized parameters
            gen_kwargs.update({
                'temperature': 0.6,
                'top_p': 0.95,
                'top_k': 30,
                'repetition_penalty': 1.08
            })
        elif self.model_config['type'] == 'qwen':
            gen_kwargs.update({
                'temperature': 0.6,
                'top_p': 0.95,
                'top_k': 20
            })
        elif self.model_config['type'] == 'llama':
            gen_kwargs.update({
                'temperature': 0.8,
                'top_p': 0.95
            })
        
        with torch.no_grad():
            generated_ids = self.model.generate(
                **model_inputs,
                **gen_kwargs
            )
        
        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]
        content = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip()
        
        strategies = self.extract_answer_strategies(content)
        return strategies
    
    def extract_answer_strategies(self, response: str) -> List[str]:
        """Extract answer strategies from Arabic response text"""
        patterns = [
            r'التصنيف النهائي:\s*\[([123,\s]+)\]',
            r'الاستراتيجيات:\s*\[([123,\s]+)\]',
            r'التصنيف:\s*\[([123,\s]+)\]',
            r'النتيجة:\s*\[([123,\s]+)\]',
            r'الإجابة:\s*\[([123,\s]+)\]',
        ]
        
        for pattern in patterns:
            match = re.search(pattern, response, re.IGNORECASE)
            if match:
                strategies_str = match.group(1)
                strategies = [strat.strip() for strat in strategies_str.split(',')]
                return [strat for strat in strategies if strat in ['1', '2', '3']]
        
        english_patterns = [
            r'Final Classification:\s*\[([123,\s]+)\]',
            r'Strategies:\s*\[([123,\s]+)\]',
            r'Classification:\s*\[([123,\s]+)\]',
            r'Answer:\s*\[([123,\s]+)\]',
        ]
        
        for pattern in english_patterns:
            match = re.search(pattern, response, re.IGNORECASE)
            if match:
                strategies_str = match.group(1)
                strategies = [strat.strip() for strat in strategies_str.split(',')]
                return [strat for strat in strategies if strat in ['1', '2', '3']]
        
        found_strategies = []
        for strategy in ['1', '2', '3']:
            if f'({strategy})' in response or f'[{strategy}]' in response or f' {strategy} ' in response:
                found_strategies.append(strategy)
        
        return found_strategies if found_strategies else ['1']
    
    def process_test_dataset(self, df: pd.DataFrame, max_new_tokens: int = 8000, show_progress: bool = True) -> pd.DataFrame:
        """
        Process dataset for answer classification
        """
        print(f"🚀 Starting Arabic Medical Answer Classification with {self.model_config['description']}...")
        print(f"Dataset size: {len(df)} samples")
        print("-" * 80)
        
        predictions = []
        
        for idx, row in df.iterrows():
            if show_progress and idx % 10 == 0:
                print(f"Processing sample {idx+1}/{len(df)} - Current GPU memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
            
            try:
                strategies = self.classify_answer(row['answer'], max_new_tokens=max_new_tokens)
                prediction_str = ', '.join(sorted(strategies))
                predictions.append(prediction_str)
                
                if show_progress and idx < 3:
                    print(f"Sample {idx+1} prediction: {prediction_str}")
                    
            except Exception as e:
                print(f"Error processing answer {idx}: {e}")
                predictions.append('1')
            
            if idx % 20 == 0:
                self.cleanup_memory()
        
        result_df = pd.DataFrame({'prediction': predictions})
        return result_df
    
    def cleanup_memory(self):
        """Clean up GPU memory"""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
    
    @classmethod
    def list_available_models(cls):
        """List all available models with descriptions"""
        print("📋 Available Models:")
        print("-" * 80)
        
        for key, config in cls.MODEL_CONFIGS.items():
            print(f"Key: '{key}'")
            print(f"  Model: {config['name']}")
            print(f"  Description: {config['description']}")
            print(f"  Context Length: {config['context_length']:,} tokens")
            print(f"  Type: {config['type']}")
            print()

def _parse_strategy_labels(label) -> List[str]:
    """
    Normalize one raw 'final_AS' cell into a list of label strings.

    Accepts list literals ("[1, 2]", "['1','2']"), comma-separated text
    ("1, 2"), real sequences, and bare numbers. Every element is converted to
    a stripped string so that it can be compared with '1'/'2'/'3'. Missing
    values (None / NaN) give an empty list; an unparseable list literal
    defaults to ['1'] with a warning.
    """
    def as_token(value) -> str:
        if isinstance(value, float):
            if value != value:  # NaN marks a missing value
                return ''
            if value.is_integer():
                return str(int(value))  # 2.0 -> '2', not '2.0'
        return str(value).strip()

    if label is None:
        return []

    if isinstance(label, str):
        text = label.strip()
        if text.startswith('['):
            try:
                items = ast.literal_eval(text)
            except (ValueError, SyntaxError):
                print(f"Warning: Could not parse label '{label}', defaulting to ['1']")
                return ['1']
            if not isinstance(items, (list, tuple, set)):
                items = [items]
        else:
            items = text.split(',')
    elif isinstance(label, (list, tuple, set)):
        items = label
    else:
        items = [label]

    tokens = (as_token(item) for item in items)
    return [token for token in tokens if token]


def evaluate_model_on_training_data(
    train_file_path: str,
    model_key: str = 'qwen2.5-unsloth',
    use_quantization: bool = True,
    max_new_tokens: int = 8000
) -> Optional[dict]:
    """
    Evaluate the model on the training dataset using Weighted F1 Score and Jaccard Score.

    Args:
        train_file_path: Path to the training dataset (TSV with 'answer' and 'final_AS' columns)
        model_key: Model key from MODEL_CONFIGS (default: qwen2.5-unsloth)
        use_quantization: Use 4-bit quantization
        max_new_tokens: Maximum tokens to generate per answer

    Returns:
        A dictionary with 'weighted_f1', 'jaccard_score', 'per_strategy_f1' and
        'per_strategy_jaccard', or None when the dataset or the model could
        not be loaded.
    """
    # Validate the dataset first so a bad path fails before the expensive model load
    try:
        df = pd.read_csv(train_file_path, sep='\t')
        if 'answer' not in df.columns or 'final_AS' not in df.columns:
            raise ValueError("Dataset must contain 'answer' and 'final_AS' columns")
        if df.empty:
            raise ValueError("Dataset contains no rows")
        print(f"✅ Training dataset loaded: {len(df)} samples")
    except Exception as e:
        print(f"❌ Error loading training dataset: {e}")
        return None

    print(f"🚀 Initializing {model_key} for evaluation...")
    try:
        classifier = ArabicMedicalAnswerClassifier(
            model_key=model_key,
            use_quantization=use_quantization
        )
    except Exception as e:
        print(f"❌ Failed to initialize model: {e}")
        return None

    # Ground-truth labels as lists of strings, comparable with the predictions
    df['final_AS'] = df['final_AS'].apply(_parse_strategy_labels)

    # Generate predictions
    print(f"🚀 Generating predictions for {len(df)} samples...")
    predictions = classifier.process_test_dataset(df, max_new_tokens=max_new_tokens)

    # Convert predictions and true labels to multi-label binary format
    all_strategies = ['1', '2', '3']
    y_true = []
    y_pred = []

    for true_labels, pred_labels in zip(df['final_AS'], predictions['prediction']):
        y_true.append([1 if strat in true_labels else 0 for strat in all_strategies])
        pred_strats = {part.strip() for part in pred_labels.split(',')}
        y_pred.append([1 if strat in pred_strats else 0 for strat in all_strategies])

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    # Calculate metrics
    weighted_f1 = f1_score(y_true, y_pred, average='weighted')
    jaccard = jaccard_score(y_true, y_pred, average='samples')

    # Print results
    print("\n📊 Evaluation Results:")
    print(f"Weighted F1 Score: {weighted_f1:.4f}")
    print(f"Jaccard Score (samples): {jaccard:.4f}")

    # Per-strategy metrics are computed once and reused in the returned dictionary
    per_strategy_f1 = {}
    per_strategy_jaccard = {}
    print("\n📋 Per-Strategy Metrics:")
    for i, strat in enumerate(all_strategies):
        per_strategy_f1[strat] = f1_score(y_true[:, i], y_pred[:, i])
        per_strategy_jaccard[strat] = jaccard_score(y_true[:, i], y_pred[:, i])
        print(f"Strategy {strat}:")
        print(f"  F1 Score: {per_strategy_f1[strat]:.4f}")
        print(f"  Jaccard Score: {per_strategy_jaccard[strat]:.4f}")

    # Drop the references to the weights before cleaning up so they can actually be released
    classifier.model = None
    classifier.tokenizer = None
    classifier.cleanup_memory()

    return {
        'weighted_f1': weighted_f1,
        'jaccard_score': jaccard,
        'per_strategy_f1': per_strategy_f1,
        'per_strategy_jaccard': per_strategy_jaccard
    }

def compare_all_models(
    train_file_path: str,
    use_quantization: bool = True,
    max_new_tokens: int = 6000,
    models_to_test: Optional[List[str]] = None
):
    """
    Evaluate several models on the training dataset and report the best one.

    Args:
        train_file_path: Path of the training TSV passed to the evaluator.
        use_quantization: Use 4-bit quantization for every model.
        max_new_tokens: Maximum tokens to generate per answer.
        models_to_test: Model keys to compare. Defaults to
            ['qwen2.5-unsloth', 'deepseek-r1'] when None.

    Returns:
        A dict mapping each successfully evaluated model key to its metrics
        dictionary. Models that failed or returned no result are omitted.
    """
    if models_to_test is None:
        models_to_test = ['qwen2.5-unsloth', 'deepseek-r1']
    results = {}

    print("🚀 Starting comprehensive model comparison...")
    print("=" * 80)

    for model_key in models_to_test:
        print(f"\n🔄 Testing {model_key}...")
        try:
            result = evaluate_model_on_training_data(
                train_file_path=train_file_path,
                model_key=model_key,
                use_quantization=use_quantization,
                max_new_tokens=max_new_tokens
            )
            if result:
                results[model_key] = result
        except Exception as e:
            print(f"❌ Error testing {model_key}: {e}")
        finally:
            # Clear memory between models, including after a failure (e.g. out-of-memory),
            # which is exactly when the next model needs the space most
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Print comparison results
    print("\n" + "=" * 80)
    print("📊 MODEL COMPARISON RESULTS")
    print("=" * 80)

    for model_key, result in results.items():
        print(f"\n{model_key.upper()}:")
        print(f"  Weighted F1 Score: {result['weighted_f1']:.4f}")
        print(f"  Jaccard Score: {result['jaccard_score']:.4f}")

    # The best model is the one with the highest weighted F1
    if results:
        best_model = max(results, key=lambda k: results[k]['weighted_f1'])
        print(f"\n🏆 BEST MODEL: {best_model}")
        print(f"   Weighted F1: {results[best_model]['weighted_f1']:.4f}")
        print(f"   Jaccard Score: {results[best_model]['jaccard_score']:.4f}")

    return results

if __name__ == "__main__":
    # Location of the labelled TSV; point this at your own training dataset
    TRAIN_DATASET_PATH = '/kaggle/input/train-dataset/Train_Dev.tsv'  # Replace with actual path

    # Show every model key that can be evaluated
    ArabicMedicalAnswerClassifier.list_available_models()

    short_rule = "=" * 50
    long_rule = "=" * 80

    # First run: Unsloth Qwen2.5 (recommended for speed)
    print(f"\n{short_rule}")
    print("Testing Unsloth Qwen2.5-7B-Instruct")
    print(short_rule)
    results_unsloth = evaluate_model_on_training_data(
        train_file_path=TRAIN_DATASET_PATH,
        model_key='qwen2.5-unsloth',
        use_quantization=True,
        max_new_tokens=6000
    )

    # Second run: DeepSeek R1 (recommended for reasoning)
    print(f"\n{short_rule}")
    print("Testing DeepSeek-R1-Distill-Qwen-7B")
    print(short_rule)
    results_deepseek = evaluate_model_on_training_data(
        train_file_path=TRAIN_DATASET_PATH,
        model_key='deepseek-r1',
        use_quantization=True,
        max_new_tokens=6000
    )

    # Uncomment to run comprehensive comparison
    # comprehensive_results = compare_all_models(
    #     train_file_path=TRAIN_DATASET_PATH,
    #     use_quantization=True,
    #     max_new_tokens=6000
    # )

    # The side-by-side summary is only printed when both evaluations produced metrics
    if results_unsloth and results_deepseek:
        print(f"\n{long_rule}")
        print("📊 COMPARISON SUMMARY")
        print(long_rule)

        summary_rows = (
            ("Unsloth Qwen2.5-7B:", results_unsloth, "  Advantage: Faster inference, optimized performance"),
            ("\nDeepSeek R1 Distill Qwen-7B:", results_deepseek, "  Advantage: Advanced reasoning capabilities"),
        )
        for heading, metrics, advantage in summary_rows:
            print(heading)
            print(f"  Weighted F1: {metrics['weighted_f1']:.4f}")
            print(f"  Jaccard Score: {metrics['jaccard_score']:.4f}")
            print(advantage)

        # Determine winner by weighted F1
        f1_unsloth = results_unsloth['weighted_f1']
        f1_deepseek = results_deepseek['weighted_f1']
        if f1_unsloth > f1_deepseek:
            verdict = "\n🏆 Winner: Unsloth Qwen2.5-7B (Better F1 Score)"
        elif f1_deepseek > f1_unsloth:
            verdict = "\n🏆 Winner: DeepSeek R1 (Better F1 Score)"
        else:
            verdict = "\n🤝 Tie: Both models perform equally well"
        print(verdict)

        print("\n💡 Recommendation:")
        print("   - Use 'qwen2.5-unsloth' for faster inference and production deployment")
        print("   - Use 'deepseek-r1' for complex reasoning tasks and research")

# Additional utility functions for model management and optimization

def install_unsloth():
    """
    Print the shell commands that install Unsloth for optimized inference.

    Nothing is installed by this function; it only prints the commands and a
    reminder to restart the runtime afterwards.
    """
    install_commands = [
        # The "name[extra] @ url" requirement must be quoted: it contains spaces, and
        # unquoted the shell would hand pip three separate, invalid arguments.
        'pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"',
        "pip install --no-deps trl peft accelerate bitsandbytes"
    ]

    print("📦 To install Unsloth for optimized inference, run these commands:")
    for cmd in install_commands:
        print(f"   {cmd}")

    print("\n⚠️  Note: Restart your runtime after installation!")

def benchmark_inference_speed(model_key: str, sample_answers: List[str], use_quantization: bool = True):
    """
    Benchmark inference speed for a specific model

    One warm-up classification is run on the first sample (not timed), then
    every sample is classified once while the elapsed time is measured. The
    reported token count is the number of tokens of the input answers, not of
    the generated text.

    Args:
        model_key: Model to benchmark
        sample_answers: Non-empty list of sample answers to test
        use_quantization: Whether to use quantization

    Returns:
        A dict with 'model_key', 'total_time', 'samples',
        'avg_time_per_sample' and 'tokens_per_second'.

    Raises:
        ValueError: If sample_answers is empty.
    """
    import time

    # Fail before the expensive model load; an empty list would otherwise end in
    # an IndexError on the warm-up run or a division by zero in the averages
    if not sample_answers:
        raise ValueError("sample_answers must contain at least one answer")

    print(f"🚀 Benchmarking {model_key} inference speed...")

    classifier = ArabicMedicalAnswerClassifier(
        model_key=model_key,
        use_quantization=use_quantization
    )

    # Warm-up run
    classifier.classify_answer(sample_answers[0], max_new_tokens=2000)

    # Benchmark runs; perf_counter is monotonic, unlike the wall clock
    start_time = time.perf_counter()
    total_tokens = 0

    for answer in sample_answers:
        classifier.classify_answer(answer, max_new_tokens=2000)
        total_tokens += len(classifier.tokenizer.encode(answer))

    total_time = time.perf_counter() - start_time
    sample_count = len(sample_answers)
    avg_time_per_sample = total_time / sample_count
    tokens_per_second = total_tokens / total_time if total_time > 0 else 0.0

    print(f"⏱️  Benchmark Results for {model_key}:")
    print(f"   Total Time: {total_time:.2f} seconds")
    print(f"   Samples Processed: {sample_count}")
    print(f"   Average Time per Sample: {avg_time_per_sample:.2f} seconds")
    print(f"   Total Input Tokens: {total_tokens}")
    print(f"   Tokens per Second: {tokens_per_second:.2f}")

    classifier.cleanup_memory()
    return {
        'model_key': model_key,
        'total_time': total_time,
        'samples': sample_count,
        'avg_time_per_sample': avg_time_per_sample,
        'tokens_per_second': tokens_per_second
    }

def create_optimized_pipeline(model_key: str = 'qwen2.5-unsloth', batch_size: int = 4):
    """
    Create a classifier configured for faster inference.

    The classifier is built with 4-bit quantization and thinking mode
    disabled, switched to eval mode, and wrapped with torch.compile when that
    succeeds.

    Args:
        model_key: Model to use for the pipeline
        batch_size: Only echoed in the final status message by this function;
            it is not applied to the classifier here.

    Returns:
        Optimized classifier instance
    """
    print(f"🔧 Creating optimized pipeline with {model_key}...")

    classifier = ArabicMedicalAnswerClassifier(
        model_key=model_key,
        use_quantization=True,
        use_thinking_mode=False  # Disable for faster inference
    )

    # Set model to evaluation mode for inference
    classifier.model.eval()

    # Enable torch.compile for PyTorch 2.0+ (if available)
    try:
        import torch._dynamo
        classifier.model = torch.compile(classifier.model, mode="reduce-overhead")
        print("✅ Torch compile enabled for faster inference")
    except Exception:
        # Compilation is optional; Exception (not a bare except) keeps
        # KeyboardInterrupt and SystemExit propagating
        print("⚠️  Torch compile not available, using standard inference")

    print(f"✅ Optimized pipeline ready with batch size: {batch_size}")
    return classifier

# Example usage and testing functions

def test_new_models():
    """
    Classify three fixed Arabic sample answers with each listed model and print the results.

    A failure while loading or running a model is printed and does not stop
    the remaining models.
    """
    sample_answers = [
        "يمكن أن يكون الصداع بسبب الإجهاد أو قلة النوم. أنصحك بالراحة وشرب الماء.",
        "لا تقلق، هذه الأعراض طبيعية. ستشعر بتحسن قريباً.",
        "فيتامين د مهم لصحة العظام ويمكن الحصول عليه من أشعة الشمس والأطعمة المدعمة."
    ]

    models_to_test = ['qwen2.5-unsloth']  #deepseek-r1
    banner = '=' * 50

    for model_key in models_to_test:
        print(f"\n{banner}")
        print(f"Testing {model_key}")
        print(banner)

        try:
            classifier = ArabicMedicalAnswerClassifier(
                model_key=model_key,
                use_quantization=True
            )

            for number, answer in enumerate(sample_answers, start=1):
                labels = classifier.classify_answer(answer, max_new_tokens=3000)
                # Only the first 50 characters of the answer are echoed
                print(f"\nSample {number}: {answer[:50]}...")
                print(f"Classification: {labels}")

            classifier.cleanup_memory()

        except Exception as e:
            print(f"❌ Error testing {model_key}: {e}")

# Run tests if this is the main module
if __name__ == "__main__":
    wide_rule = "=" * 80

    # Print the Unsloth installation commands
    print(f"\n{wide_rule}")
    print("🚀 INSTALLATION GUIDE")
    print(wide_rule)
    install_unsloth()

    print(f"\n{wide_rule}")
    print("🧪 TESTING NEW MODELS")
    print(wide_rule)
    # Uncomment to test the new models with sample data
    # test_new_models()

    closing_lines = (
        "\n✅ Setup complete! The classifier now supports:",
        "   - unsloth/Qwen2.5-7B-Instruct (Optimized for speed)",
        "   - deepseek-ai/DeepSeek-R1-Distill-Qwen-7B (Advanced reasoning)",
        "\n💡 Use 'qwen2.5-unsloth' as default for best performance!",
    )
    for closing_line in closing_lines:
        print(closing_line)