# Pharmaceutical Multi-Label Classification with QLoRA Fine-tuning

## Project Overview

**Objective**: Improve pharmaceutical document classification from an 80% F1 score to 95%+ using a QLoRA fine-tuned Llama 3.1 8B model

**Business Context**: 
- Current rule-based + traditional ML system achieves 80% F1 score
- Need to classify pharmaceutical documents into multiple therapeutic areas, regulatory categories, and risk levels
- Critical for drug development, regulatory compliance, and pharmacovigilance

**Key Challenges**:
- Complex medical terminology and drug interactions
- Multiple overlapping categories per document
- Regulatory compliance requirements (FDA, EMA)
- Class imbalance in rare disease categories
- Need for explainable predictions

## Table of Contents
1. Environment Setup and Dependencies
2. Pharmaceutical Data Preparation
3. Multi-Label Classification Framework
4. QLoRA Configuration for Classification
5. Model Architecture and Training
6. Evaluation Metrics and Benchmarking
7. Ablation Studies and Optimization
8. Production Deployment and Monitoring
9. Regulatory Compliance and Explainability

## 1. Environment Setup and Dependencies

In [None]:
# Install required packages for pharmaceutical NLP and QLoRA
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install transformers>=4.31.0
!pip install peft>=0.4.0
!pip install datasets
!pip install bitsandbytes>=0.39.0
!pip install accelerate>=0.20.3
!pip install trl
!pip install scikit-learn
!pip install scipy
!pip install pandas
!pip install numpy
!pip install matplotlib
!pip install seaborn
!pip install wandb
!pip install nltk
!pip install spacy
!pip install scispacy  # For biomedical NLP
!pip install regex
!pip install wordcloud
!pip install plotly

# Download medical NLP models
!python -m spacy download en_core_web_sm
!python -m spacy download en_core_sci_sm

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback
)
from peft import (
    LoraConfig,
    get_peft_model,
    TaskType,
    PeftModel,
    prepare_model_for_kbit_training
)
from trl import SFTTrainer
from datasets import Dataset, DatasetDict
from sklearn.metrics import (
    f1_score, precision_score, recall_score, 
    classification_report, multilabel_confusion_matrix,
    hamming_loss, jaccard_score
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
import pandas as pd
import numpy as np
import json
import re
import time
import warnings
import gc
import matplotlib.pyplot as plt
import seaborn as sns
import spacy
from typing import Dict, List, Tuple, Optional, Union
from collections import Counter, defaultdict
import random

# Quieten library warnings and switch matplotlib to the seaborn-flavoured theme
warnings.filterwarnings('ignore')
plt.style.use('seaborn-v0_8')

# Seed Python, NumPy and PyTorch (CPU and every CUDA device) for reproducibility
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(42)
del _seed_fn

# Prefer the GPU when one is visible, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
if device.type == "cuda":
    # Report the name and total memory (in GiB) of GPU 0
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

## 2. Pharmaceutical Data Preparation

In [None]:
# Re-seed Python, NumPy and PyTorch (CPU and every CUDA device) so this cell is
# repeatable when it is run on its own
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(42)
del _seed_fn

# Prefer the GPU when one is visible, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Classification taxonomy: five label groups, each mapping to its allowed labels.
PHARMA_LABELS = {
    # Therapeutic areas (primary classification)
    'therapeutic_areas': [
        'oncology',
        'cardiovascular',
        'neurology',
        'immunology',
        'infectious_diseases',
        'respiratory',
        'endocrinology',
        'gastroenterology',
        'dermatology',
        'ophthalmology',
        'psychiatry',
        'rheumatology',
        'urology',
        'hematology',
        'rare_diseases',
    ],
    # Drug development phase
    'development_phase': [
        'preclinical',
        'phase_1',
        'phase_2',
        'phase_3',
        'phase_4',
        'post_market',
    ],
    # Regulatory categories
    'regulatory': [
        'safety_report',
        'efficacy_study',
        'pharmacokinetics',
        'pharmacodynamics',
        'toxicology',
        'clinical_trial',
        'regulatory_submission',
        'label_update',
    ],
    # Risk assessment
    'risk_level': [
        'low_risk',
        'medium_risk',
        'high_risk',
        'critical_risk',
    ],
    # Document type
    'document_type': [
        'clinical_study_report',
        'adverse_event_report',
        'drug_label',
        'investigator_brochure',
        'protocol',
        'statistical_analysis_plan',
        'pharmacovigilance_report',
        'regulatory_correspondence',
    ],
}

# Flat list of every label, in taxonomy order (group by group)
ALL_LABELS = [label for group in PHARMA_LABELS.values() for label in group]

print(f"Total classification labels: {len(ALL_LABELS)}")

In [None]:
# Synthetic pharmaceutical document generator with rule-based multi-label tagging.

# Whole-token match for "Phase I".."Phase IV" / "Phase 1".."Phase 4".
# A plain substring test for "phase i" also hits "phase ii" and "phase iii",
# which would tag every Phase II/III document as phase_1.
_PHASE_PATTERN = re.compile(r'\bphase\s+(iv|iii|ii|i|[1-4])\b', re.IGNORECASE)
_PHASE_LABELS = {
    'i': 'phase_1', '1': 'phase_1',
    'ii': 'phase_2', '2': 'phase_2',
    'iii': 'phase_3', '3': 'phase_3',
    'iv': 'phase_4', '4': 'phase_4',
}

# (label, keywords): the label applies when any keyword occurs in the lower-cased text.
_THERAPEUTIC_AREA_KEYWORDS = (
    ('oncology', ('cancer', 'tumor', 'oncology', 'melanoma')),
    ('cardiovascular', ('heart', 'cardiac', 'atrial', 'cardiovascular')),
    ('neurology', ('multiple sclerosis', 'alzheimer', 'depression', 'schizophrenia')),
    ('immunology', ('rheumatoid', 'crohn', 'psoriasis', 'adalimumab', 'infliximab')),
)


def _assign_pharma_labels(content: str, doc_type: str) -> List[str]:
    """Derive the raw (unfiltered, possibly repeated) labels for one document."""
    text = content.lower()
    labels = [doc_type]

    # Therapeutic areas: every matching area is added
    for label, keywords in _THERAPEUTIC_AREA_KEYWORDS:
        if any(keyword in text for keyword in keywords):
            labels.append(label)

    # Development phase: an explicit trial phase wins over a post-market mention
    phase_match = _PHASE_PATTERN.search(content)
    if phase_match:
        labels.append(_PHASE_LABELS[phase_match.group(1).lower()])
    elif 'post-market' in text or 'post market' in text:
        labels.append('post_market')

    # Regulatory categories follow directly from the document type
    if doc_type == 'adverse_event_report':
        labels.extend(['safety_report', 'pharmacovigilance_report'])
    elif doc_type == 'clinical_study_report':
        labels.extend(['efficacy_study', 'clinical_trial'])

    # Risk level: exactly one is assigned, strongest keyword group first
    if any(word in text for word in ('serious', 'severe', 'critical', 'death')):
        labels.append('high_risk')
    elif any(word in text for word in ('moderate', 'significant')):
        labels.append('medium_risk')
    else:
        labels.append('low_risk')

    return labels


def generate_pharma_training_data(n_samples: int = 2000) -> List[Dict]:
    """Generate synthetic pharmaceutical documents with multi-label classifications.

    Each document is produced by filling a randomly chosen template with terms
    drawn from small vocabularies via the global ``random`` module, so seeding
    ``random`` makes the output repeatable. Labels are then derived from the
    document type and keyword rules and restricted to ``ALL_LABELS``.

    Args:
        n_samples: Number of documents to generate.

    Returns:
        A list of dicts with keys ``'text'`` (document body), ``'labels'``
        (de-duplicated labels in first-assigned order) and ``'doc_type'``.
    """

    # Document templates by type; "{}" slots are filled positionally below
    document_templates = {
        'clinical_study_report': [
            "Phase {} clinical trial evaluating {} in patients with {}. Primary endpoint: {}. Secondary endpoints include safety, tolerability, and pharmacokinetics. {} patients enrolled across {} sites. Results demonstrate {} with {} adverse events reported.",
            "Randomized controlled trial of {} vs placebo in {} patients. Study duration: {} months. Primary efficacy endpoint met with statistical significance (p<0.05). Safety profile consistent with known {} risk profile.",
            "Open-label extension study of {} in {} indication. Long-term safety and efficacy data collected over {} years. No new safety signals identified. Sustained efficacy observed in {}% of patients."
        ],
        'adverse_event_report': [
            "Serious adverse event report: {} year old {} patient experienced {} after {} days of {} treatment. Event assessed as {} related to study drug. Patient {} and treatment {}.",
            "Spontaneous report of {} in patient taking {} for {}. Onset {} hours post-dose. Concomitant medications: {}. Dechallenge: {}. Rechallenge: {}.",
            "Healthcare professional report: {} observed in {} patients receiving {}. All cases {} and {} within {} days. Causality assessment: {}."
        ],
        'drug_label': [
            "{} is indicated for the treatment of {} in adult patients. Recommended dose: {}. Contraindications: {}. Warnings and precautions: monitor for {}.",
            "INDICATIONS AND USAGE: {} is a {} indicated for {} treatment. DOSAGE AND ADMINISTRATION: {}. CONTRAINDICATIONS: {}. WARNINGS: {}.",
            "Prescribing information for {}: {} tablets. Indicated for {} treatment. Common adverse reactions (≥5%): {}. Drug interactions: {}."
        ]
    }

    # Medical terminology pools
    drug_names = [
        'Pembrolizumab', 'Adalimumab', 'Rituximab', 'Bevacizumab', 'Trastuzumab',
        'Infliximab', 'Etanercept', 'Tocilizumab', 'Nivolumab', 'Atezolizumab',
        'Durvalumab', 'Avelumab', 'Ipilimumab', 'Cetuximab', 'Panitumumab'
    ]

    drug_classes = [
        'monoclonal antibody', 'TNF-alpha inhibitor', 'PD-1 inhibitor',
        'VEGF inhibitor', 'HER2 inhibitor', 'CD20 antagonist', 'IL-6 inhibitor'
    ]

    conditions = [
        'non-small cell lung cancer', 'breast cancer', 'colorectal cancer', 'melanoma',
        'rheumatoid arthritis', 'Crohn\'s disease', 'multiple sclerosis', 'psoriasis',
        'atrial fibrillation', 'heart failure', 'diabetes mellitus', 'hypertension',
        'COPD', 'asthma', 'depression', 'schizophrenia', 'Alzheimer\'s disease'
    ]

    adverse_events = [
        'neutropenia', 'thrombocytopenia', 'hepatotoxicity', 'cardiotoxicity',
        'pneumonitis', 'colitis', 'dermatitis', 'infusion reaction',
        'nausea', 'fatigue', 'diarrhea', 'headache', 'hypertension'
    ]

    doc_type_names = list(document_templates.keys())
    valid_labels = set(ALL_LABELS)  # O(1) membership for the final filter
    training_data = []

    for _ in range(n_samples):
        # Select document type and template
        doc_type = random.choice(doc_type_names)
        template = random.choice(document_templates[doc_type])

        # Fill the template; the argument count of each branch matches its "{}" slots
        if doc_type == 'clinical_study_report':
            if 'Phase {}' in template:
                content = template.format(
                    random.choice(['I', 'II', 'III']),
                    random.choice(drug_names),
                    random.choice(conditions),
                    random.choice(['overall survival', 'progression-free survival', 'response rate']),
                    random.randint(50, 500),
                    random.randint(5, 50),
                    random.choice(['significant improvement', 'positive results', 'favorable outcomes']),
                    random.choice(['minimal', 'manageable', 'expected'])
                )
            elif 'Randomized controlled' in template:
                content = template.format(
                    random.choice(drug_names),
                    random.randint(100, 1000),
                    random.randint(6, 36),
                    random.choice(['acceptable', 'manageable', 'expected'])
                )
            else:  # Open-label extension
                content = template.format(
                    random.choice(drug_names),
                    random.choice(conditions),
                    random.randint(2, 5),
                    random.randint(60, 85)
                )

        elif doc_type == 'adverse_event_report':
            if 'Serious adverse event' in template:
                content = template.format(
                    random.randint(18, 85),
                    random.choice(['male', 'female']),
                    random.choice(adverse_events),
                    random.randint(1, 30),
                    random.choice(drug_names),
                    random.choice(['possibly', 'probably', 'definitely']),
                    random.choice(['recovered', 'recovering', 'not recovered']),
                    random.choice(['discontinued', 'continued', 'dose reduced'])
                )
            elif 'Spontaneous report' in template:
                content = template.format(
                    random.choice(adverse_events),
                    random.choice(drug_names),
                    random.choice(conditions),
                    random.randint(1, 72),
                    random.choice(['aspirin', 'metformin', 'none']),
                    random.choice(['positive', 'negative', 'not done']),
                    random.choice(['positive', 'negative', 'not done'])
                )
            else:  # Healthcare professional report
                content = template.format(
                    random.choice(adverse_events),
                    random.randint(2, 10),
                    random.choice(drug_names),
                    random.choice(['resolved', 'ongoing', 'fatal']),
                    random.choice(['recovered', 'improving', 'worsened']),
                    random.randint(1, 14),
                    random.choice(['probable', 'possible', 'unlikely'])
                )

        else:  # drug_label
            if 'INDICATIONS AND USAGE' in template:
                content = template.format(
                    random.choice(drug_names),
                    random.choice(drug_classes),
                    random.choice(conditions),
                    f"{random.randint(5, 100)} mg {random.choice(['daily', 'twice daily', 'weekly'])}",
                    random.choice(['pregnancy', 'severe hepatic impairment', 'hypersensitivity']),
                    random.choice(['hepatic function', 'renal function', 'cardiac function'])
                )
            elif 'Prescribing information' in template:
                content = template.format(
                    random.choice(drug_names),
                    f"{random.randint(5, 100)} mg",
                    random.choice(conditions),
                    random.choice(['nausea, headache, fatigue', 'diarrhea, rash, dizziness']),
                    random.choice(['warfarin, digoxin', 'CYP3A4 inhibitors', 'none known'])
                )
            else:  # Simple format
                content = template.format(
                    random.choice(drug_names),
                    random.choice(conditions),
                    f"{random.randint(5, 100)} mg {random.choice(['daily', 'twice daily', 'weekly'])}",
                    random.choice(['pregnancy', 'severe hepatic impairment', 'hypersensitivity']),
                    random.choice(['hepatic function', 'renal function', 'cardiac function'])
                )

        # De-duplicate while keeping first-assigned order (set() ordering of
        # strings varies between interpreter runs) and drop unknown labels
        raw_labels = _assign_pharma_labels(content, doc_type)
        labels = [label for label in dict.fromkeys(raw_labels) if label in valid_labels]

        training_data.append({
            'text': content,
            'labels': labels,
            'doc_type': doc_type
        })

    return training_data

# Build the demonstration corpus (1000 documents keeps the run short)
print("Generating pharmaceutical training data...")
pharma_data = generate_pharma_training_data(1000)  # Reduced for demonstration
print(f"Generated {len(pharma_data)} pharmaceutical documents")

# Show the first generated record: truncated text, labels and document type
sample = pharma_data[0]
print("\nSample document:")
print("Text: {}...".format(sample['text'][:200]))
print("Labels: {}".format(sample['labels']))
print("Document type: {}".format(sample['doc_type']))


In [None]:
# Dataset statistics: sizes, label frequencies, imbalance and document types
def analyze_dataset_statistics(data: List[Dict]):
    """Print summary statistics for a labelled corpus and return the raw numbers.

    Args:
        data: Documents as dicts with ``'text'``, ``'labels'`` and ``'doc_type'`` keys.

    Returns:
        Dict with ``'label_freq'`` (Counter of label occurrences),
        ``'text_lengths'`` (whitespace-token count per document),
        ``'label_counts'`` (label count per document) and
        ``'imbalance_ratio'`` (most frequent / least frequent label count).
    """
    total_docs = len(data)

    print("Dataset Statistics Analysis")
    print("=" * 50)

    # Per-document sizes
    text_lengths = [len(doc['text'].split()) for doc in data]
    label_counts = [len(doc['labels']) for doc in data]

    print(f"Total documents: {total_docs}")
    print(f"Average text length: {np.mean(text_lengths):.1f} words")
    print(f"Text length range: {min(text_lengths)} - {max(text_lengths)} words")
    print(f"Average labels per document: {np.mean(label_counts):.1f}")
    print(f"Labels per document range: {min(label_counts)} - {max(label_counts)}")

    # How often each label occurs across the corpus
    label_freq = Counter()
    for doc in data:
        label_freq.update(doc['labels'])

    print("\nLabel Frequency Analysis:")
    print(f"Total unique labels used: {len(label_freq)}")
    print("Most common labels:")
    for label, count in label_freq.most_common(10):
        print(f"  {label}: {count} ({count/total_docs*100:.1f}%)")

    # Imbalance = ratio between the most and the least frequent label
    frequencies = label_freq.values()
    min_freq, max_freq = min(frequencies), max(frequencies)
    imbalance_ratio = max_freq / min_freq

    print("\nClass Imbalance Analysis:")
    print(f"Most frequent label: {max_freq} documents")
    print(f"Least frequent label: {min_freq} documents")
    print(f"Imbalance ratio: {imbalance_ratio:.1f}:1")

    if imbalance_ratio > 10:
        print("⚠️  Significant class imbalance detected - consider balancing strategies")

    # Share of each document type
    doc_types = Counter(doc['doc_type'] for doc in data)
    print("\nDocument Type Distribution:")
    for doc_type, count in doc_types.items():
        print(f"  {doc_type}: {count} ({count/total_docs*100:.1f}%)")

    return dict(
        label_freq=label_freq,
        text_lengths=text_lengths,
        label_counts=label_counts,
        imbalance_ratio=imbalance_ratio,
    )

# Analyze the dataset; `stats` holds the label frequencies, per-document
# lengths/label counts and the imbalance ratio returned by the function
stats = analyze_dataset_statistics(pharma_data)

In [None]:
# Visualize dataset characteristics
def visualize_dataset_statistics(data: List[Dict], stats: Dict):
    """Plot corpus statistics and a label co-occurrence heatmap.

    Draws a 2x2 figure (document-length histogram, labels-per-document
    histogram, top-15 label bar chart, document-type pie chart), then a
    heatmap of row-normalised label co-occurrence for the top-15 labels.

    Args:
        data: Documents as dicts with ``'labels'`` and ``'doc_type'`` keys.
        stats: Dict with ``'text_lengths'``, ``'label_counts'`` and
            ``'label_freq'`` (a Counter), as returned by
            ``analyze_dataset_statistics``.
    """
    text_lengths = stats['text_lengths']
    label_counts = stats['label_counts']
    # Built once and shared by the bar chart and the heatmap selection below
    top_labels = dict(stats['label_freq'].most_common(15))

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # Text length distribution
    axes[0, 0].hist(text_lengths, bins=30, alpha=0.7, color='skyblue')
    axes[0, 0].set_title('Distribution of Document Lengths')
    axes[0, 0].set_xlabel('Number of Words')
    axes[0, 0].set_ylabel('Frequency')
    axes[0, 0].axvline(np.mean(text_lengths), color='red', linestyle='--', label='Mean')
    axes[0, 0].legend()

    # Labels per document distribution (one bin per integer label count)
    axes[0, 1].hist(label_counts, bins=range(1, max(label_counts)+2),
                    alpha=0.7, color='lightgreen')
    axes[0, 1].set_title('Distribution of Labels per Document')
    axes[0, 1].set_xlabel('Number of Labels')
    axes[0, 1].set_ylabel('Frequency')
    axes[0, 1].axvline(np.mean(label_counts), color='red', linestyle='--', label='Mean')
    axes[0, 1].legend()

    # Top 15 most frequent labels
    axes[1, 0].barh(range(len(top_labels)), list(top_labels.values()), color='orange', alpha=0.7)
    axes[1, 0].set_yticks(range(len(top_labels)))
    axes[1, 0].set_yticklabels(list(top_labels.keys()))
    axes[1, 0].set_title('Top 15 Most Frequent Labels')
    axes[1, 0].set_xlabel('Frequency')

    # Document type distribution
    doc_types = Counter([doc['doc_type'] for doc in data])
    axes[1, 1].pie(doc_types.values(), labels=doc_types.keys(), autopct='%1.1f%%',
                   colors=['lightcoral', 'lightblue', 'lightgreen'])
    axes[1, 1].set_title('Document Type Distribution')

    plt.tight_layout()
    plt.show()

    # Label co-occurrence heatmap
    print("\nGenerating label co-occurrence analysis...")

    # Binary document x label indicator matrix
    mlb = MultiLabelBinarizer()
    label_matrix = mlb.fit_transform([doc['labels'] for doc in data])
    label_names = mlb.classes_

    # Entry (i, j) counts the documents carrying both label i and label j
    cooccurrence = np.dot(label_matrix.T, label_matrix)

    # Divide each row by its diagonal count: row i becomes P(label j | label i)
    diag = np.diag(cooccurrence)
    cooccurrence_norm = cooccurrence / diag[:, np.newaxis]

    # Restrict the heatmap to the most frequent labels
    top_15_indices = [i for i, label in enumerate(label_names) if label in top_labels]
    top_15_names = [label_names[i] for i in top_15_indices]

    plt.figure(figsize=(12, 10))
    sns.heatmap(cooccurrence_norm[np.ix_(top_15_indices, top_15_indices)],
                xticklabels=top_15_names,
                yticklabels=top_15_names,
                annot=True, fmt='.2f', cmap='Blues')
    plt.title('Label Co-occurrence Matrix (Top 15 Labels)')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

# Create visualizations from the generated corpus and the statistics computed above
visualize_dataset_statistics(pharma_data, stats)

In [None]:
# Display the generated corpus as the cell output.
# NOTE(review): this echoes every generated record; consider showing a slice
# if the output is too long.
pharma_data

## 3. Multi-Label Classification Framework

In [None]:
# Train/validation/test preparation for multi-label classification
def prepare_classification_data(data: List[Dict], test_size: float = 0.2, val_size: float = 0.1):
    """Split the corpus into train/val/test and fit a label encoder.

    The split is a seeded random split (``random_state=42``); it is not
    stratified by label.

    Args:
        data: Documents as dicts with ``'text'`` and ``'labels'`` keys.
        test_size: Fraction of all documents assigned to the test split.
        val_size: Fraction of all documents assigned to the validation split.

    Returns:
        Dict with ``'train'``, ``'val'`` and ``'test'`` entries (each a dict of
        ``'texts'`` and ``'labels'``) plus ``'label_encoder'``, a
        MultiLabelBinarizer fitted on the labels of all three splits.
    """
    all_texts = [doc['text'] for doc in data]
    all_labels = [doc['labels'] for doc in data]

    # First cut: training set versus a held-out pool for validation + test
    train_texts, holdout_texts, train_labels, holdout_labels = train_test_split(
        all_texts, all_labels, test_size=(test_size + val_size), random_state=42
    )

    # Second cut: divide the held-out pool so the overall fractions come out right
    val_texts, test_texts, val_labels, test_labels = train_test_split(
        holdout_texts, holdout_labels,
        test_size=test_size/(test_size + val_size),
        random_state=42
    )

    print("Data split:")
    for split_title, split_texts in (
        ("Training", train_texts),
        ("Validation", val_texts),
        ("Test", test_texts),
    ):
        print(f"  {split_title}: {len(split_texts)} documents")

    # Fit the encoder on every label that occurs in any split
    encoder = MultiLabelBinarizer()
    encoder.fit(train_labels + val_labels + test_labels)

    print(f"\nTotal unique labels: {len(encoder.classes_)}")
    print(f"Label classes: {list(encoder.classes_)}")

    return {
        'train': dict(texts=train_texts, labels=train_labels),
        'val': dict(texts=val_texts, labels=val_labels),
        'test': dict(texts=test_texts, labels=test_labels),
        'label_encoder': encoder,
    }

# Prepare the data: split the corpus and keep a module-level handle on the
# fitted label encoder
data_splits = prepare_classification_data(pharma_data)
mlb = data_splits['label_encoder']

In [None]:
# Prompt builder for instruction-style multi-label classification
def create_classification_instruction(text: str, labels: List[str], is_training: bool = True) -> str:
    """Wrap a document in the classification prompt.

    Args:
        text: Document body placed under the ``DOCUMENT:`` heading.
        labels: Gold labels; written sorted and comma-separated after
            ``LABELS:`` when ``is_training`` is true, ignored otherwise.
        is_training: True for a prompt that includes the expected answer,
            False for an inference prompt that ends at ``LABELS:``.

    Returns:
        The formatted prompt string.
    """
    # Task description followed by the label taxonomy, one paragraph per group
    prompt_header = "\n\n".join((
        "You are an expert pharmaceutical document classifier. Your task is to analyze pharmaceutical documents and assign appropriate labels from the following categories:",
        "THERAPEUTIC AREAS: oncology, cardiovascular, neurology, immunology, infectious_diseases, respiratory, endocrinology, gastroenterology, dermatology, ophthalmology, psychiatry, rheumatology, urology, hematology, rare_diseases",
        "DEVELOPMENT PHASE: preclinical, phase_1, phase_2, phase_3, phase_4, post_market",
        "REGULATORY CATEGORIES: safety_report, efficacy_study, pharmacokinetics, pharmacodynamics, toxicology, clinical_trial, regulatory_submission, label_update",
        "RISK LEVEL: low_risk, medium_risk, high_risk, critical_risk",
        "DOCUMENT TYPE: clinical_study_report, adverse_event_report, drug_label, investigator_brochure, protocol, statistical_analysis_plan, pharmacovigilance_report, regulatory_correspondence",
        "Analyze the following pharmaceutical document and provide a comma-separated list of applicable labels:",
    ))

    # Shared by both modes: header, the document, then an open "LABELS:" cue
    prompt = f"{prompt_header}\n\nDOCUMENT:\n{text}\n\nLABELS:"

    if not is_training:
        return prompt

    # Training examples append the expected answer after the cue
    return prompt + " " + ", ".join(sorted(labels))

# Build per-split instruction datasets
def create_instruction_dataset(data_dict: Dict) -> DatasetDict:
    """Turn the split dictionary into a DatasetDict of training-format prompts.

    Every entry of ``data_dict`` except ``'label_encoder'`` is treated as a
    split holding parallel ``'texts'`` and ``'labels'`` lists. Each document
    becomes one ``{'text': prompt}`` record, with the labels included in the
    prompt.

    Args:
        data_dict: Mapping of split name to ``{'texts': ..., 'labels': ...}``,
            optionally with a ``'label_encoder'`` entry that is skipped.

    Returns:
        DatasetDict with one Dataset per split.
    """
    return DatasetDict({
        split_name: Dataset.from_list([
            {'text': create_classification_instruction(document, document_labels, is_training=True)}
            for document, document_labels in zip(split['texts'], split['labels'])
        ])
        for split_name, split in data_dict.items()
        if split_name != 'label_encoder'
    })

# Build the instruction datasets for all splits
instruction_datasets = create_instruction_dataset(data_splits)

# Report how many examples each split contains
print("Created instruction-following datasets:")
for split_name, dataset in instruction_datasets.items():
    print("  {}: {} examples".format(split_name, len(dataset)))

# Preview the first training prompt, cut to 800 characters when longer
print("\nSample instruction format:")
print("=" * 80)
sample_text = instruction_datasets['train'][0]['text']
if len(sample_text) > 800:
    print(sample_text[:800] + "...")
else:
    print(sample_text)

## 4. QLoRA Configuration for Classification

In [None]:
# 4-bit quantization settings used when loading the base model
def create_classification_bnb_config():
    """Return the BitsAndBytesConfig for 4-bit NF4 loading.

    Uses NF4 quantization with double quantization enabled and bfloat16 as
    the compute dtype.
    """
    quantization_options = {
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_use_double_quant": True,
        "bnb_4bit_compute_dtype": torch.bfloat16,
    }
    return BitsAndBytesConfig(**quantization_options)

def create_classification_lora_config():
    """Return the LoraConfig used for the causal-LM fine-tune.

    Adapters of rank 128 (alpha 32, i.e. an alpha/r factor of 0.25) are
    attached to the four attention projections and the three MLP projections,
    with 5% dropout and no bias training.
    """
    attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
    mlp_projections = ["gate_proj", "up_proj", "down_proj"]

    return LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=128,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        target_modules=attention_projections + mlp_projections,
    )

# Model configuration
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
new_model = "llama-3.1-8b-pharma-classifier-qlora"

# Create configurations
quantization_config = create_classification_bnb_config()
lora_config = create_classification_lora_config()

# Llama 3.1 8B geometry for the trainable-parameter estimate: 32 decoder layers,
# hidden size 4096, MLP size 14336, and grouped-query attention whose k/v
# projections output 1024 features. A LoRA adapter of rank r on a linear layer
# adds r * (in_features + out_features) parameters.
_LLAMA31_8B_NUM_LAYERS = 32
_LLAMA31_8B_PROJ_SHAPES = {  # module name -> (in_features, out_features)
    "q_proj": (4096, 4096),
    "k_proj": (4096, 1024),
    "v_proj": (4096, 1024),
    "o_proj": (4096, 4096),
    "gate_proj": (4096, 14336),
    "up_proj": (4096, 14336),
    "down_proj": (14336, 4096),
}
expected_trainable = _LLAMA31_8B_NUM_LAYERS * sum(
    lora_config.r * sum(_LLAMA31_8B_PROJ_SHAPES[module_name])
    for module_name in lora_config.target_modules
    if module_name in _LLAMA31_8B_PROJ_SHAPES  # unknown modules are left out of the estimate
)

print("Classification QLoRA Configuration:")
print("=" * 40)
print(f"LoRA rank (r): {lora_config.r}")
print(f"LoRA alpha: {lora_config.lora_alpha}")
print(f"LoRA dropout: {lora_config.lora_dropout}")
print(f"Target modules: {len(lora_config.target_modules)}")
print(f"Quantization: 4-bit NF4 with double quantization")
print(f"Expected trainable parameters: ~{expected_trainable:,}")

In [None]:
# GPU memory reporting helper
def print_memory_usage(stage=""):
    """Print allocated and reserved memory for GPU 0, labelled with ``stage``.

    Does nothing when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return

    bytes_per_gb = 1024**3
    total_gb = torch.cuda.get_device_properties(0).total_memory / bytes_per_gb
    allocated_gb = torch.cuda.memory_allocated(0) / bytes_per_gb
    reserved_gb = torch.cuda.memory_reserved(0) / bytes_per_gb

    print(f"\n{stage} GPU Memory:")
    print(f"  Allocated: {allocated_gb:.1f} GB / {total_gb:.1f} GB ({allocated_gb/total_gb*100:.1f}%)")
    print(f"  Reserved: {reserved_gb:.1f} GB ({reserved_gb/total_gb*100:.1f}%)")

# Baseline GPU memory reading before the tokenizer or model is loaded
print_memory_usage("Initial")

In [None]:
# Load tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse the end-of-sequence token as the padding token
tokenizer.pad_token = tokenizer.eos_token
# Pad on the right-hand side of each sequence
tokenizer.padding_side = "right"

print_memory_usage("After tokenizer")

In [None]:
# Load model with quantization
print("Loading Llama 3.1 8B with 4-bit quantization...")
print("This will take a few minutes...")

# FlashAttention-2 needs the separate `flash_attn` package, which the install
# cell does not provide; fall back to PyTorch's built-in SDPA kernels when it
# is missing instead of failing at load time.
import importlib.util

attn_implementation = (
    "flash_attention_2" if importlib.util.find_spec("flash_attn") is not None else "sdpa"
)
print(f"Attention implementation: {attn_implementation}")

# trust_remote_code is deliberately not enabled: Llama is implemented inside
# transformers, so there is no need to run code shipped with the checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation=attn_implementation,
)

print_memory_usage("After model loading")
print("✅ Model loaded successfully!")

In [None]:
# Prepare model for training: the quantized model is first passed through
# peft's k-bit training preparation, then wrapped with the LoRA adapters
# described by `lora_config`
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)

# Print trainable parameters (as reported by peft) and the memory now in use
model.print_trainable_parameters()
print_memory_usage("After LoRA setup")

## 5. Model Architecture and Training

In [None]:
# Training hyper-parameters for the classifier fine-tune
def create_classification_training_args(output_dir: str, num_epochs: int = 5):
    """Build the TrainingArguments for the fine-tuning run.

    Args:
        output_dir: Directory for checkpoints and trainer output.
        num_epochs: Number of training epochs.

    Returns:
        TrainingArguments using batch size 1 with 16 gradient-accumulation
        steps, bf16, a cosine schedule with 10% warmup, evaluation every 50
        steps, checkpoints every 100 steps, and best-model restore on
        ``eval_loss``.
    """
    # Epochs and batching (per-device batch 1 x 16 accumulation steps)
    schedule = dict(
        num_train_epochs=num_epochs,
        max_steps=-1,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=16,
        group_by_length=True,
    )
    # Optimizer, learning-rate schedule and numeric precision
    optimisation = dict(
        optim="paged_adamw_32bit",
        learning_rate=1e-4,
        weight_decay=0.01,
        max_grad_norm=0.3,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        fp16=False,
        bf16=True,
    )
    # Evaluation cadence, checkpointing and best-model selection
    checkpointing = dict(
        save_steps=100,
        save_total_limit=3,
        evaluation_strategy="steps",
        eval_steps=50,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
    )
    # Logging and data-loading behaviour
    runtime = dict(
        logging_steps=25,
        report_to="none",
        dataloader_num_workers=0,
        remove_unused_columns=False,
    )

    return TrainingArguments(
        output_dir=output_dir,
        **schedule,
        **optimisation,
        **checkpointing,
        **runtime,
    )

# Instantiate the training arguments for a 5-epoch run
training_args = create_classification_training_args(f"./results_{new_model}", num_epochs=5)

# Echo the key hyper-parameters; the effective batch size is the per-device
# batch size multiplied by the gradient-accumulation steps
print("Classification Training Configuration:")
print("=" * 40)
print("Epochs: {}".format(training_args.num_train_epochs))
print("Batch size: {}".format(training_args.per_device_train_batch_size))
print("Gradient accumulation: {}".format(training_args.gradient_accumulation_steps))
print("Effective batch size: {}".format(
    training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
))
print("Learning rate: {}".format(training_args.learning_rate))
print("LR scheduler: {}".format(training_args.lr_scheduler_type))
print("Warmup ratio: {}".format(training_args.warmup_ratio))
print("Weight decay: {}".format(training_args.weight_decay))

In [None]:
# Custom data collator for classification
class ClassificationDataCollator:
    """Tokenize raw ``'text'`` examples into a padded causal-LM batch.

    Labels are a copy of ``input_ids`` in which padded positions are set to
    -100, the index the loss ignores, so padding does not contribute to the
    training loss.
    """

    # Label value skipped by the cross-entropy loss
    IGNORE_INDEX = -100

    def __init__(self, tokenizer, max_length=2048):
        """Store the tokenizer and the truncation length.

        Args:
            tokenizer: Callable tokenizer returning ``input_ids`` (and
                normally ``attention_mask``) tensors.
            max_length: Maximum sequence length; longer texts are truncated.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, examples):
        """Collate a list of ``{'text': str}`` examples into one batch.

        Returns:
            The tokenizer output with an added ``"labels"`` tensor.
        """
        batch_texts = [example['text'] for example in examples]

        # Tokenize the batch, padding to the longest sequence in it
        tokenized = self.tokenizer(
            batch_texts,
            truncation=True,
            padding=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        # For causal LM the targets are the inputs themselves...
        labels = tokenized["input_ids"].clone()
        # ...except at padded positions. The attention mask is used rather than
        # the pad token id because the pad token may be the same as EOS, and
        # genuine EOS tokens must stay in the loss.
        if "attention_mask" in tokenized:
            labels[tokenized["attention_mask"] == 0] = self.IGNORE_INDEX
        tokenized["labels"] = labels

        return tokenized

# Create data collator (sequences are truncated to 2048 tokens)
data_collator = ClassificationDataCollator(tokenizer, max_length=2048)

print("✅ Data collator created")

In [None]:
# Initialize trainer with early stopping.
# `model` has already been wrapped by get_peft_model() in the LoRA setup cell,
# so no peft_config is passed here: the trainer should use the existing
# adapters rather than be handed a second, redundant adapter configuration.
trainer = SFTTrainer(
    model=model,
    train_dataset=instruction_datasets['train'],
    eval_dataset=instruction_datasets['val'],
    dataset_text_field="text",
    max_seq_length=2048,
    tokenizer=tokenizer,
    args=training_args,
    packing=False,
    data_collator=data_collator,
    # Stop when eval_loss has not improved for 3 consecutive evaluations
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)

print("✅ Trainer initialized with early stopping")
print_memory_usage("Before training")

In [None]:
# Training with comprehensive monitoring
print("Starting pharmaceutical classification training...")
print("Expected training time: ~2-3 hours on RTX 4090")
print("=" * 60)

# Clear cache and prepare for training: release cached CUDA memory and run the
# Python garbage collector before the run starts
torch.cuda.empty_cache()
gc.collect()

# Wall-clock timer around the training call only
start_time = time.time()

# Train the model
trainer.train()

training_time = time.time() - start_time

# Elapsed time is reported in hours (seconds / 3600)
print(f"\n🎉 Training completed in {training_time/3600:.1f} hours!")
print_memory_usage("After training")

# Save the trained model and its tokenizer to the `new_model` directory
trainer.model.save_pretrained(new_model)
tokenizer.save_pretrained(new_model)

print(f"\n✅ Model saved to ./{new_model}")

## 6. Evaluation Metrics and Benchmarking

In [None]:
# Comprehensive evaluation framework
class PharmaClassificationEvaluator:
    """Comprehensive evaluator for pharmaceutical multi-label classification.

    Predictions are obtained by prompting the model, reading the text that
    follows the last ``LABELS:`` marker of the decoded output and keeping only
    label names listed in ``label_encoder.classes_``. Metrics are computed with
    scikit-learn on the binarized label matrices.
    """

    def __init__(self, model, tokenizer, label_encoder):
        """Store the model, tokenizer and label encoder.

        Args:
            model: Generative model exposing ``generate`` and ``eval``.
            tokenizer: Tokenizer used to encode prompts and decode outputs.
            label_encoder: Object exposing ``classes_`` and ``transform``
                (presumably a fitted MultiLabelBinarizer -- confirm with caller).
        """
        self.model = model
        self.tokenizer = tokenizer
        self.label_encoder = label_encoder
        self.all_labels = list(label_encoder.classes_)

    def _parse_labels(self, response: str) -> List[str]:
        """Extract the valid, de-duplicated labels from a decoded response.

        Only the text after the last ``LABELS:`` marker is considered. Commas
        and newlines both act as separators, so text generated after the label
        line cannot corrupt the final label. Order of first appearance is kept.
        """
        marker = "LABELS:"
        if marker not in response:
            return []

        labels_section = response.split(marker)[-1]
        valid_labels = set(self.all_labels)
        candidates = (
            candidate.strip()
            for candidate in labels_section.replace("\n", ",").split(",")
        )
        # dict.fromkeys removes duplicates while preserving order.
        return list(dict.fromkeys(c for c in candidates if c in valid_labels))

    def predict_labels(self, text: str, max_new_tokens: int = 100) -> List[str]:
        """Predict labels for a single document.

        Args:
            text: Raw document text.
            max_new_tokens: Generation budget for the label list.

        Returns:
            Valid label names found in the model output (possibly empty).
        """
        # Create instruction format for inference
        formatted_text = create_classification_instruction(text, [], is_training=False)

        # Tokenize
        inputs = self.tokenizer(formatted_text, return_tensors="pt", truncation=True, max_length=2048).to(device)

        # Make sure dropout (e.g. in adapter layers) is disabled after training.
        self.model.eval()

        # Generate prediction. Greedy decoding keeps the evaluation
        # reproducible: the same document always yields the same labels.
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Decode prediction (prompt + continuation) and extract the labels
        full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return self._parse_labels(full_response)

    def evaluate_dataset(self, texts: List[str], true_labels: List[List[str]],
                        batch_size: int = 8) -> Dict:
        """Evaluate the model on a dataset.

        Args:
            texts: Documents to classify.
            true_labels: Gold label list for each document, aligned with ``texts``.
            batch_size: Chunk size used for progress reporting; documents are
                still sent to the model one at a time.

        Returns:
            The metrics dictionary from ``_calculate_metrics`` plus
            ``'avg_prediction_time'`` (mean seconds per document).

        Raises:
            ValueError: If the inputs are empty, misaligned, or ``batch_size``
                is not positive.
        """
        # Fail fast: inference is slow, so catch bad inputs before it starts.
        if len(texts) != len(true_labels):
            raise ValueError(
                f"texts and true_labels must have the same length "
                f"(got {len(texts)} and {len(true_labels)})"
            )
        if len(texts) == 0:
            raise ValueError("Cannot evaluate an empty dataset")
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        print(f"Evaluating on {len(texts)} documents...")

        all_predictions = []
        prediction_times = []

        # Process in batches to avoid memory issues
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i+batch_size]

            for text in batch_texts:
                start_time = time.time()
                pred_labels = self.predict_labels(text)
                prediction_times.append(time.time() - start_time)
                all_predictions.append(pred_labels)

            # Report progress every fifth batch.
            if (i // batch_size + 1) % 5 == 0:
                print(f"Processed {i + len(batch_texts)}/{len(texts)} documents")

        # Convert to binary format for sklearn metrics
        y_true = self.label_encoder.transform(true_labels)
        y_pred = self.label_encoder.transform(all_predictions)

        # Calculate comprehensive metrics
        metrics = self._calculate_metrics(y_true, y_pred, true_labels, all_predictions)
        metrics['avg_prediction_time'] = np.mean(prediction_times)

        return metrics

    def _calculate_metrics(self, y_true, y_pred, true_labels_list, pred_labels_list):
        """Calculate comprehensive evaluation metrics.

        Args:
            y_true: Binarized gold labels.
            y_pred: Binarized predicted labels.
            true_labels_list: Gold labels as lists of names (for exact match).
            pred_labels_list: Predicted labels as lists of names.

        Returns:
            Dict with micro/macro/weighted scores, Hamming loss, Jaccard score,
            exact-match ratio and per-label F1/precision/recall lists.
        """
        # zero_division=0 yields the same scores as the default ("warn" -> 0)
        # for labels without support or predictions, minus the warning noise.

        # Overall metrics
        micro_f1 = f1_score(y_true, y_pred, average='micro', zero_division=0)
        macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
        weighted_f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)

        micro_precision = precision_score(y_true, y_pred, average='micro', zero_division=0)
        macro_precision = precision_score(y_true, y_pred, average='macro', zero_division=0)

        micro_recall = recall_score(y_true, y_pred, average='micro', zero_division=0)
        macro_recall = recall_score(y_true, y_pred, average='macro', zero_division=0)

        # Multi-label specific metrics
        hamming = hamming_loss(y_true, y_pred)
        jaccard = jaccard_score(y_true, y_pred, average='macro')

        # Per-label metrics
        per_label_f1 = f1_score(y_true, y_pred, average=None, zero_division=0)
        per_label_precision = precision_score(y_true, y_pred, average=None, zero_division=0)
        per_label_recall = recall_score(y_true, y_pred, average=None, zero_division=0)

        # Exact match accuracy (all labels must match exactly)
        total_documents = len(true_labels_list)
        exact_matches = sum(
            set(true) == set(pred)
            for true, pred in zip(true_labels_list, pred_labels_list)
        )
        # Guard against an empty evaluation set instead of dividing by zero.
        exact_match_ratio = exact_matches / total_documents if total_documents else 0.0

        return {
            'micro_f1': micro_f1,
            'macro_f1': macro_f1,
            'weighted_f1': weighted_f1,
            'micro_precision': micro_precision,
            'macro_precision': macro_precision,
            'micro_recall': micro_recall,
            'macro_recall': macro_recall,
            'hamming_loss': hamming,
            'jaccard_score': jaccard,
            'exact_match_ratio': exact_match_ratio,
            'per_label_metrics': {
                'labels': list(self.label_encoder.classes_),
                'f1_scores': per_label_f1.tolist(),
                'precision_scores': per_label_precision.tolist(),
                'recall_scores': per_label_recall.tolist()
            }
        }

# Create evaluator (mlb is passed as the label encoder; the evaluator reads
# its classes_ and calls its transform)
evaluator = PharmaClassificationEvaluator(model, tokenizer, mlb)
print("✅ Evaluator created")

In [None]:
# Evaluate on test set
print("Evaluating fine-tuned model on test set...")
print("This may take 15-30 minutes depending on test set size...")

# test_metrics is reused by the per-label analysis and visualization cells.
test_metrics = evaluator.evaluate_dataset(
    data_splits['test']['texts'],
    data_splits['test']['labels'],
    batch_size=4
)

print("\n" + "=" * 60)
print("PHARMACEUTICAL CLASSIFICATION RESULTS")
print("=" * 60)

# Each score is printed both as a fraction and as a percentage.
print(f"\n📊 OVERALL PERFORMANCE:")
print(f"  Micro F1:     {test_metrics['micro_f1']:.4f} ({test_metrics['micro_f1']*100:.1f}%)")
print(f"  Macro F1:     {test_metrics['macro_f1']:.4f} ({test_metrics['macro_f1']*100:.1f}%)")
print(f"  Weighted F1:  {test_metrics['weighted_f1']:.4f} ({test_metrics['weighted_f1']*100:.1f}%)")

print(f"\n🎯 PRECISION & RECALL:")
print(f"  Micro Precision: {test_metrics['micro_precision']:.4f} ({test_metrics['micro_precision']*100:.1f}%)")
print(f"  Macro Precision: {test_metrics['macro_precision']:.4f} ({test_metrics['macro_precision']*100:.1f}%)")
print(f"  Micro Recall:    {test_metrics['micro_recall']:.4f} ({test_metrics['micro_recall']*100:.1f}%)")
print(f"  Macro Recall:    {test_metrics['macro_recall']:.4f} ({test_metrics['macro_recall']*100:.1f}%)")

print(f"\n📈 MULTI-LABEL METRICS:")
print(f"  Hamming Loss:      {test_metrics['hamming_loss']:.4f}")
print(f"  Jaccard Score:     {test_metrics['jaccard_score']:.4f} ({test_metrics['jaccard_score']*100:.1f}%)")
print(f"  Exact Match Ratio: {test_metrics['exact_match_ratio']:.4f} ({test_metrics['exact_match_ratio']*100:.1f}%)")

print(f"\n⚡ PERFORMANCE:")
print(f"  Avg Prediction Time: {test_metrics['avg_prediction_time']:.3f} seconds")

# Improvement calculation
# The baseline is a hard-coded constant compared against micro F1, and the
# result is a RELATIVE change in percent (not percentage points).
# NOTE(review): confirm the 80% baseline was also a micro-averaged F1.
baseline_f1 = 0.80  # Original model performance
improvement = (test_metrics['micro_f1'] - baseline_f1) / baseline_f1 * 100

print(f"\n🚀 IMPROVEMENT vs BASELINE:")
print(f"  Baseline F1:  {baseline_f1:.1%}")
print(f"  Current F1:   {test_metrics['micro_f1']:.1%}")
print(f"  Improvement:  {improvement:+.1f}%")

# Verdict thresholds: >= 0.95 meets the target, >= 0.90 counts as close.
if test_metrics['micro_f1'] >= 0.95:
    print(f"\n🎉 SUCCESS! Target of 95%+ F1 score achieved!")
elif test_metrics['micro_f1'] >= 0.90:
    print(f"\n✅ Excellent performance! Close to 95% target.")
else:
    print(f"\n⚠️ Performance improvement needed for 95% target.")

In [None]:
# Detailed per-label analysis
def analyze_per_label_performance(metrics: Dict):
    """Print a per-label score report and return the scores as a DataFrame.

    Args:
        metrics: Metrics dictionary whose ``'per_label_metrics'`` entry holds
            the parallel lists ``'labels'``, ``'f1_scores'``,
            ``'precision_scores'`` and ``'recall_scores'``.

    Returns:
        DataFrame with columns Label, F1, Precision and Recall, sorted by F1
        in descending order.
    """
    scores = metrics['per_label_metrics']

    # One row per label, best F1 first
    ranked = pd.DataFrame({
        'Label': scores['labels'],
        'F1': scores['f1_scores'],
        'Precision': scores['precision_scores'],
        'Recall': scores['recall_scores'],
    }).sort_values('F1', ascending=False)

    print("\nPER-LABEL PERFORMANCE ANALYSIS:")
    print("=" * 70)
    print(f"{'Label':<30} {'F1':<8} {'Precision':<10} {'Recall':<8}")
    print("-" * 70)

    # One section per label category; categories without rows are skipped
    for category, category_labels in PHARMA_LABELS.items():
        in_category = ranked[ranked['Label'].isin(category_labels)]
        if in_category.empty:
            continue
        print(f"\n{category.upper().replace('_', ' ')}:")
        for label, f1, precision, recall in zip(
            in_category['Label'],
            in_category['F1'],
            in_category['Precision'],
            in_category['Recall'],
        ):
            print(f"{label:<30} {f1:<8.3f} {precision:<10.3f} {recall:<8.3f}")

    # Five best labels, and the five weakest among labels with a non-zero F1
    strongest = ranked.head(5)
    weakest = ranked.loc[ranked['F1'] > 0].tail(5)

    for heading, subset in (
        ("\n🏆 TOP 5 PERFORMING LABELS:", strongest),
        ("\n⚠️ LABELS NEEDING IMPROVEMENT:", weakest),
    ):
        print(heading)
        for label, f1 in zip(subset['Label'], subset['F1']):
            print(f"  {label}: F1={f1:.3f}")

    return ranked

# Analyze per-label performance; the returned DataFrame (sorted by F1) is
# reused by the visualization cell
label_performance_df = analyze_per_label_performance(test_metrics)

In [None]:
# Visualize performance results
def visualize_classification_results(metrics: Dict, label_df: pd.DataFrame, baseline_scores=None):
    """Create comprehensive performance visualizations.

    Draws a 2x2 dashboard (overall metrics, top-20 labels by F1, precision vs
    recall scatter, baseline comparison) and, when at least one label belongs
    to a ``PHARMA_LABELS`` category, a second figure with the average F1 per
    category.

    Args:
        metrics: Metrics dictionary with the keys ``'micro_f1'``,
            ``'macro_f1'``, ``'weighted_f1'``, ``'jaccard_score'`` and
            ``'exact_match_ratio'``.
        label_df: Per-label DataFrame with ``Label``, ``F1``, ``Precision``
            and ``Recall`` columns; ``head(20)`` is plotted as the "top 20",
            so it is expected to be sorted by F1 descending.
        baseline_scores: Optional baseline values in the order micro F1,
            macro F1, weighted F1, Jaccard, exact match. Defaults to the
            simulated baseline used by this notebook.

    Raises:
        ValueError: If ``baseline_scores`` does not contain one value per
            overall metric.
    """
    # Overall metrics comparison
    overall_metrics = {
        'Micro F1': metrics['micro_f1'],
        'Macro F1': metrics['macro_f1'],
        'Weighted F1': metrics['weighted_f1'],
        'Jaccard Score': metrics['jaccard_score'],
        'Exact Match': metrics['exact_match_ratio']
    }
    metric_names = list(overall_metrics.keys())
    current_scores = list(overall_metrics.values())

    if baseline_scores is None:
        baseline_scores = [0.80, 0.75, 0.82, 0.60, 0.45]  # Simulated baseline
    if len(baseline_scores) != len(metric_names):
        raise ValueError(
            f"baseline_scores must contain {len(metric_names)} values, "
            f"got {len(baseline_scores)}"
        )

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    bars = axes[0, 0].bar(metric_names, current_scores,
                         color=['skyblue', 'lightgreen', 'orange', 'pink', 'lightyellow'])
    axes[0, 0].set_title('Overall Performance Metrics')
    axes[0, 0].set_ylabel('Score')
    axes[0, 0].set_ylim(0, 1)
    axes[0, 0].tick_params(axis='x', rotation=45)

    # Add value labels on bars
    for bar, value in zip(bars, current_scores):
        axes[0, 0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                       f'{value:.3f}', ha='center', va='bottom')

    # F1 scores of the 20 best labels
    top_20_labels = label_df.head(20)
    axes[0, 1].barh(range(len(top_20_labels)), top_20_labels['F1'].values,
                    color='lightcoral')
    axes[0, 1].set_yticks(range(len(top_20_labels)))
    axes[0, 1].set_yticklabels(top_20_labels['Label'].values)
    # barh draws the first row at the bottom; flip the axis so the best label
    # is at the top of the ranking.
    axes[0, 1].invert_yaxis()
    axes[0, 1].set_title('Top 20 Labels by F1 Score')
    axes[0, 1].set_xlabel('F1 Score')

    # Precision vs Recall scatter
    scatter = axes[1, 0].scatter(label_df['Precision'], label_df['Recall'],
                                c=label_df['F1'], cmap='viridis', alpha=0.7, s=60)
    axes[1, 0].set_xlabel('Precision')
    axes[1, 0].set_ylabel('Recall')
    axes[1, 0].set_title('Precision vs Recall (colored by F1)')
    # Diagonal marks labels whose precision equals recall.
    axes[1, 0].plot([0, 1], [0, 1], 'k--', alpha=0.3)
    plt.colorbar(scatter, ax=axes[1, 0], label='F1 Score')

    # Comparison with baseline (grouped bars: baseline left, fine-tuned right)
    x = np.arange(len(metric_names))
    width = 0.35

    axes[1, 1].bar(x - width/2, baseline_scores, width, label='Baseline', color='lightgray')
    axes[1, 1].bar(x + width/2, current_scores, width, label='QLoRA Fine-tuned', color='lightblue')

    axes[1, 1].set_xlabel('Metrics')
    axes[1, 1].set_ylabel('Score')
    axes[1, 1].set_title('Performance Comparison: Baseline vs Fine-tuned')
    axes[1, 1].set_xticks(x)
    axes[1, 1].set_xticklabels(metric_names, rotation=45)
    axes[1, 1].legend()
    axes[1, 1].set_ylim(0, 1)

    plt.tight_layout()
    plt.show()

    # Average F1 per label category
    category_performance = []
    for category, category_labels in PHARMA_LABELS.items():
        category_df = label_df[label_df['Label'].isin(category_labels)]
        if len(category_df) > 0:
            avg_f1 = category_df['F1'].mean()
            category_performance.append((category.replace('_', ' ').title(), avg_f1))

    # Only create the second figure when there is something to draw, so no
    # empty figure is left open.
    if not category_performance:
        return

    fig, ax = plt.subplots(figsize=(12, 8))
    categories, f1_scores = zip(*category_performance)
    bars = ax.bar(categories, f1_scores, color='lightseagreen', alpha=0.8)
    ax.set_title('Average F1 Score by Label Category')
    ax.set_ylabel('Average F1 Score')
    ax.tick_params(axis='x', rotation=45)

    # Add value labels
    for bar, score in zip(bars, f1_scores):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
               f'{score:.3f}', ha='center', va='bottom')

    plt.tight_layout()
    plt.show()

# Create visualizations from the test-set metrics and the per-label DataFrame
visualize_classification_results(test_metrics, label_performance_df)

## 7. Ablation Studies and Optimization

In [None]:
# Test model with sample predictions
def test_sample_predictions():
    """Test the model with sample pharmaceutical documents.

    Runs the module-level ``evaluator`` on three hand-written documents,
    prints expected versus predicted labels and a per-document precision,
    recall and F1 computed on the label sets.
    """

    test_documents = [
        {
            'text': "Phase III randomized controlled trial of pembrolizumab versus chemotherapy in patients with advanced non-small cell lung cancer. Primary endpoint was overall survival. 1200 patients enrolled across 150 sites globally. Results demonstrate significant improvement in overall survival with pembrolizumab (HR=0.73, p<0.001). Safety profile consistent with known immune-related adverse events including grade 3 pneumonitis in 2% of patients.",
            'expected_labels': ['oncology', 'phase_3', 'clinical_trial', 'efficacy_study', 'clinical_study_report', 'medium_risk']
        },
        {
            'text': "Serious adverse event report: 67-year-old female patient experienced severe hepatotoxicity after 14 days of treatment with investigational drug XYZ-123. Event assessed as probably related to study drug. Patient hospitalized and treatment permanently discontinued. Liver function tests showed ALT 15x ULN. Patient recovered with supportive care.",
            'expected_labels': ['adverse_event_report', 'safety_report', 'high_risk', 'pharmacovigilance_report']
        },
        {
            'text': "INDICATIONS AND USAGE: Adalimumab is a TNF-alpha inhibitor indicated for the treatment of rheumatoid arthritis, Crohn's disease, and psoriasis in adult patients. DOSAGE: 40 mg subcutaneous injection every other week. CONTRAINDICATIONS: Active tuberculosis, serious infections. WARNINGS: Monitor for serious infections and malignancies.",
            'expected_labels': ['drug_label', 'immunology', 'rheumatology', 'post_market', 'medium_risk']
        }
    ]

    print("SAMPLE PREDICTION TESTING")
    print("=" * 80)

    for i, doc in enumerate(test_documents, 1):
        print(f"\n📄 Test Document {i}:")
        print(f"Text: {doc['text'][:150]}...")

        predicted_labels = evaluator.predict_labels(doc['text'])
        expected_labels = doc['expected_labels']

        print(f"\n🎯 Expected: {expected_labels}")
        print(f"🤖 Predicted: {predicted_labels}")

        # Score on sets: the overlap is a set, so the denominators must be
        # sets too, otherwise a repeated label would skew precision/recall.
        predicted_set = set(predicted_labels)
        expected_set = set(expected_labels)
        overlap = predicted_set & expected_set
        precision = len(overlap) / len(predicted_set) if predicted_set else 0
        recall = len(overlap) / len(expected_set) if expected_set else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        print(f"📊 Sample F1: {f1:.3f} (P: {precision:.3f}, R: {recall:.3f})")

        if f1 > 0.8:
            print("✅ Excellent prediction!")
        elif f1 > 0.6:
            print("✅ Good prediction")
        else:
            print("⚠️ Needs improvement")

        print("-" * 80)

# Test sample predictions
test_sample_predictions()  # prints a per-document report; nothing is returned

In [None]:
# Error analysis and improvement recommendations
def analyze_prediction_errors(true_labels: List[List[str]], predicted_labels: List[List[str]]):
    """Print an error breakdown and improvement hints for a set of predictions.

    For every (gold, predicted) pair the labels are compared as sets. The
    report lists the most frequent false positives, false negatives and
    (false positive, false negative) pairs that occur in the same document,
    followed by recommendations derived from those counts.

    Args:
        true_labels: Gold labels per document.
        predicted_labels: Predicted labels per document, aligned with
            ``true_labels``; pairs are formed with ``zip``.
    """
    print("ERROR ANALYSIS & IMPROVEMENT RECOMMENDATIONS")
    print("=" * 70)

    # Error tallies accumulated over all documents
    false_positives = Counter()
    false_negatives = Counter()
    label_confusions = Counter()

    for gold, guessed in zip(true_labels, predicted_labels):
        gold_set, guessed_set = set(gold), set(guessed)

        spurious = guessed_set - gold_set   # predicted but not true
        missed = gold_set - guessed_set     # true but not predicted

        false_positives.update(spurious)
        false_negatives.update(missed)

        # Every spurious label is paired with every label missed in the
        # same document.
        label_confusions.update(
            (fp_label, fn_label) for fp_label in spurious for fn_label in missed
        )

    for heading, tally in (
        ("\n🔴 MOST COMMON FALSE POSITIVES:", false_positives),
        ("\n🔴 MOST COMMON FALSE NEGATIVES:", false_negatives),
    ):
        print(heading)
        for label, count in tally.most_common(10):
            print(f"  {label}: {count} times")

    print("\n🔄 COMMON LABEL CONFUSIONS:")
    for (fp_label, fn_label), count in label_confusions.most_common(5):
        print(f"  {fp_label} ↔ {fn_label}: {count} times")

    print("\n💡 IMPROVEMENT RECOMMENDATIONS:")

    # Labels most often predicted by mistake
    worst_fp = [label for label, _ in false_positives.most_common(3)]
    if worst_fp:
        print(f"  1. Reduce false positives for: {', '.join(worst_fp)}")
        print("     → Consider adding negative examples or adjusting decision threshold")

    # Labels most often missed
    worst_fn = [label for label, _ in false_negatives.most_common(3)]
    if worst_fn:
        print(f"  2. Improve recall for: {', '.join(worst_fn)}")
        print("     → Add more training examples or improve feature representation")

    # Generic advice, shown whenever at least one error was recorded
    if false_positives or false_negatives:
        print("  3. Address class imbalance with balanced sampling or cost-sensitive learning")
        print("  4. Consider ensemble methods or multiple specialized models")
        print("  5. Implement confidence-based filtering for high-stakes predictions")

# Perform error analysis on a subset (at most the first 100 test documents)
sample_size = min(100, len(data_splits['test']['texts']))
sample_texts = data_splits['test']['texts'][:sample_size]
sample_true_labels = data_splits['test']['labels'][:sample_size]

print(f"Performing error analysis on {sample_size} test samples...")
# Predictions are generated again here, one document at a time;
# evaluate_dataset does not return its per-document predictions.
sample_predictions = [evaluator.predict_labels(text) for text in sample_texts]

analyze_prediction_errors(sample_true_labels, sample_predictions)

## 8. Production Deployment and Monitoring