In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftModel, PeftConfig
from typing import List, Dict, Union
import numpy as np

class MedicalQAInference:
    """Inference wrapper around a PEFT-adapted sequence classifier for
    four-option (A-D) multiple-choice medical questions.

    The adapter directory supplies the PEFT config, which in turn names the
    base checkpoint used for both the tokenizer and the classification model.
    """

    # Class index -> answer letter, in order. The number of labels passed to
    # the base model is derived from this so the two cannot drift apart.
    OPTION_LABELS = ('A', 'B', 'C', 'D')

    # Token budget per formatted question; longer inputs are truncated.
    MAX_LENGTH = 512

    def __init__(self,
                 model_path: str = "./final_model",
                 device: Union[str, torch.device, None] = None):
        """
        Initialize the inference pipeline with the fine-tuned model

        Args:
            model_path: Path to the saved fine-tuned (PEFT adapter) model
            device: Optional device to move the model to (e.g. "cuda").
                ``None`` (default) leaves the model wherever it was loaded.
        """
        # Load the adapter configuration; it records the base checkpoint name
        self.peft_config = PeftConfig.from_pretrained(model_path)

        # Load the tokenizer of the base checkpoint
        self.tokenizer = AutoTokenizer.from_pretrained(self.peft_config.base_model_name_or_path)

        # Define option mapping (class index -> option letter)
        self.option_mapping = dict(enumerate(self.OPTION_LABELS))

        # Load the base model with a classification head sized to the options.
        # NOTE(review): this step reports the classifier/pooler weights as
        # "newly initialized". The trained head presumably has to come from
        # the adapter checkpoint loaded below (e.g. via PEFT modules_to_save);
        # if it was not saved there, predictions will be near-uniform --
        # verify against the training script.
        self.base_model = AutoModelForSequenceClassification.from_pretrained(
            self.peft_config.base_model_name_or_path,
            num_labels=len(self.option_mapping),
        )

        # Attach the fine-tuned PEFT adapter weights to the base model
        self.model = PeftModel.from_pretrained(
            self.base_model,
            model_path,
        )

        # Only relocate the model when the caller asked for a specific device
        if device is not None:
            self.model.to(device)

        # Set model to evaluation mode
        self.model.eval()

    def format_question(self, question: str, options: Dict[str, str]) -> str:
        """
        Format the question and options in the same way as training data

        Args:
            question: The medical question
            options: Dictionary containing options A, B, C, D

        Returns:
            Formatted question string: a "Question: ..." line followed by one
            "<letter>) <text>" line per option, with no trailing newline

        Raises:
            KeyError: If any of the option letters is missing from ``options``
        """
        lines = [f"Question: {question}"]
        lines.extend(f"{label}) {options[label]}" for label in self.OPTION_LABELS)
        return "\n".join(lines)

    def _run_model(self, texts: List[str]):
        """Tokenize ``texts``, run one forward pass and return the tuple
        ``(probabilities, predictions)``: softmax scores per class and the
        argmax class index for each input text."""
        inputs = self.tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            # Padding is only required to stack several sequences; skipping it
            # for a single text mirrors the original single-question path.
            padding=len(texts) > 1,
            max_length=self.MAX_LENGTH,
        ).to(self.model.device)

        with torch.no_grad():
            logits = self.model(**inputs).logits
            probabilities = torch.softmax(logits, dim=-1)
            predictions = torch.argmax(logits, dim=-1)
        return probabilities, predictions

    def predict(self,
                question: str,
                options: Dict[str, str],
                return_probabilities: bool = False
               ) -> Union[str, Dict[str, Union[str, float, Dict[str, float]]]]:
        """
        Make a prediction for a single medical question

        Args:
            question: The medical question
            options: Dictionary containing options A, B, C, D
            return_probabilities: Whether to return probability scores

        Returns:
            The predicted option letter, or, when ``return_probabilities`` is
            True, a dictionary with keys 'prediction' (letter),
            'probabilities' (letter -> probability) and 'confidence'
            (probability of the predicted letter)
        """
        formatted_input = self.format_question(question, options)
        probabilities, predictions = self._run_model([formatted_input])

        # Convert prediction to option letter
        predicted_option = self.option_mapping[predictions[0].item()]
        if not return_probabilities:
            return predicted_option

        # Map each class probability to its option letter
        prob_dict = {
            self.option_mapping[i]: prob.item()
            for i, prob in enumerate(probabilities[0])
        }
        return {
            'prediction': predicted_option,
            'probabilities': prob_dict,
            'confidence': prob_dict[predicted_option]
        }

    def batch_predict(self,
                      questions: List[Dict[str, Union[str, Dict[str, str]]]],
                      batch_size: int = 32
                     ) -> List[Dict[str, Union[int, str, float]]]:
        """
        Make predictions for a batch of questions

        Args:
            questions: List of dictionaries containing questions and options
                (keys 'question' and 'options'; extra keys are ignored)
            batch_size: Maximum number of questions per forward pass, so that
                long lists do not have to fit in memory at once

        Returns:
            List of dictionaries with 'question_id' (index into ``questions``),
            'prediction' (option letter) and 'confidence' (probability of the
            predicted option). An empty input yields an empty list.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        formatted_inputs = [
            self.format_question(q['question'], q['options'])
            for q in questions
        ]

        results = []
        # An empty input never enters the loop, so the tokenizer and model are
        # not invoked with an empty batch.
        for start in range(0, len(formatted_inputs), batch_size):
            chunk = formatted_inputs[start:start + batch_size]
            probabilities, predictions = self._run_model(chunk)

            for offset, (pred, probs) in enumerate(zip(predictions, probabilities)):
                results.append({
                    # Index relative to the full input list, not the chunk
                    'question_id': start + offset,
                    'prediction': self.option_mapping[pred.item()],
                    'confidence': probs[pred].item()
                })

        return results

In [7]:
# Create a diverse set of test questions across difficulty levels
batch_questions = [
    # --- Basic level -------------------------------------------------------
    dict(
        question="Which of the following is the most common cause of iron deficiency anemia?",
        options=dict(
            A='Chronic blood loss',
            B='Decreased iron absorption',
            C='Increased iron demand',
            D='Dietary deficiency',
        ),
        difficulty='basic',
        correct_answer='A',  # answer key, used only for scoring
    ),
    dict(
        question="The normal range for adult body temperature measured orally is:",
        options=dict(
            A='35.0-36.0°C',
            B='36.0-37.0°C',
            C='36.5-37.5°C',
            D='37.5-38.5°C',
        ),
        difficulty='basic',
        correct_answer='C',  # answer key, used only for scoring
    ),

    # --- Intermediate level ------------------------------------------------
    dict(
        question="A 45-year-old patient presents with recurrent episodes of facial flushing, diarrhea, and bronchospasm. CT scan reveals a small mass in the terminal ileum. Which of the following is the most likely diagnosis?",
        options=dict(
            A='Gastrinoma',
            B='Carcinoid syndrome',
            C='VIPoma',
            D='Pheochromocytoma',
        ),
        difficulty='intermediate',
        correct_answer='B',  # answer key, used only for scoring
    ),
    dict(
        question="In the treatment of acute bacterial meningitis, which of the following factors most strongly influences the choice of empiric antibiotic therapy?",
        options=dict(
            A="Patient's age",
            B='Recent antibiotic use',
            C='Duration of symptoms',
            D='Presence of skin rash',
        ),
        difficulty='intermediate',
        correct_answer='A',  # answer key, used only for scoring
    ),

    # --- Advanced level ----------------------------------------------------
    dict(
        question="A 28-year-old woman with systemic lupus erythematosus develops sudden onset of left-sided weakness and slurred speech. Laboratory studies reveal a prolonged PTT, positive lupus anticoagulant, and anti-cardiolipin antibodies. Brain MRI shows multiple small cortical infarcts. Which of the following is the most appropriate initial treatment?",
        options=dict(
            A='Aspirin',
            B='Unfractionated heparin',
            C='Cyclophosphamide',
            D='Tissue plasminogen activator',
        ),
        difficulty='advanced',
        correct_answer='B',  # answer key, used only for scoring
    ),
    # NOTE(review): for the labs quoted below the standard MELD formula appears
    # to give a score in the mid-30s rather than 28 -- confirm this answer key
    # with a clinician before relying on it for evaluation.
    dict(
        question="A 62-year-old male with chronic hepatitis B develops hepatorenal syndrome. Despite treatment with vasoconstrictors and albumin, his renal function continues to worsen. Laboratory studies show: Creatinine 4.2 mg/dL, INR 2.1, Total bilirubin 5.8 mg/dL, Albumin 2.8 g/dL. Calculate his MELD score and determine the most appropriate next step in management.",
        options=dict(
            A='MELD 25; Continue medical management',
            B='MELD 28; Evaluate for liver transplantation',
            C='MELD 32; Initiate hemodialysis',
            D='MELD 35; Palliative care consultation',
        ),
        difficulty='advanced',
        correct_answer='B',  # answer key, used only for scoring
    ),
]

# Function to evaluate model performance across difficulty levels
def evaluate_by_difficulty(inference_model, questions):
    """Score batch predictions against the answer key, grouped by difficulty.

    Args:
        inference_model: Object exposing ``batch_predict(questions)`` that
            returns one dict per question with 'prediction' and 'confidence'.
        questions: List of question dicts; each must carry 'difficulty' and
            'correct_answer' in addition to whatever ``batch_predict`` needs.

    Returns:
        Dict mapping difficulty label -> {'correct': int, 'total': int,
        'confidence': list of float}. The labels 'basic', 'intermediate' and
        'advanced' are always present (possibly with zero totals); any other
        label found in ``questions`` is added after them.

    Side effects:
        Prints accuracy and mean confidence for every level.
    """
    results = inference_model.batch_predict(questions)

    # The three standard levels are pre-seeded so they are always reported,
    # and always first, even when no question of that level is supplied.
    performance = {
        level: {'correct': 0, 'total': 0, 'confidence': []}
        for level in ('basic', 'intermediate', 'advanced')
    }

    for question, result in zip(questions, results):
        # setdefault lets unexpected difficulty labels be tallied instead of
        # aborting the whole evaluation with a KeyError.
        stats = performance.setdefault(
            question['difficulty'],
            {'correct': 0, 'total': 0, 'confidence': []},
        )
        stats['total'] += 1
        stats['confidence'].append(result['confidence'])
        if result['prediction'] == question['correct_answer']:
            stats['correct'] += 1

    # Calculate and display statistics
    print("\nPerformance Analysis by Difficulty Level:")
    print("----------------------------------------")
    for difficulty, stats in performance.items():
        # Guard both ratios so empty levels print 0 rather than dividing by 0
        accuracy = (stats['correct'] / stats['total']) * 100 if stats['total'] > 0 else 0
        avg_confidence = sum(stats['confidence']) / len(stats['confidence']) if stats['confidence'] else 0

        print(f"\n{str(difficulty).upper()} Level Questions:")
        print(f"Accuracy: {accuracy:.1f}%")
        print(f"Average Confidence: {avg_confidence:.3f}")

    return performance

# Example usage:
if __name__ == "__main__":
    # Build the pipeline from the saved adapter directory
    inference = MedicalQAInference("./final_model")

    # Aggregate accuracy / confidence per difficulty level (also prints a summary)
    performance_metrics = evaluate_by_difficulty(inference, batch_questions)

    # Per-question breakdown: each question is re-scored individually so the
    # full probability distribution over the options can be shown.
    print("\nDetailed Predictions:")
    print("--------------------")
    separator = "-" * 50
    for number, item in enumerate(batch_questions, start=1):
        details = inference.predict(
            item['question'],
            item['options'],
            return_probabilities=True,
        )
        level = item['difficulty'].upper()

        print(f"\nQuestion {number} ({level} Level):")
        print(f"Q: {item['question']}")
        print(f"Predicted: Option {details['prediction']}")
        print(f"Correct Answer: Option {item['correct_answer']}")
        print("Confidence Scores:")
        for letter, score in details['probabilities'].items():
            print(f"Option {letter}: {score:.4f}")
        print(separator)

Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-small and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Performance Analysis by Difficulty Level:
----------------------------------------

BASIC Level Questions:
Accuracy: 0.0%
Average Confidence: 0.316

INTERMEDIATE Level Questions:
Accuracy: 0.0%
Average Confidence: 0.322

ADVANCED Level Questions:
Accuracy: 0.0%
Average Confidence: 0.322

Detailed Predictions:
--------------------

Question 1 (BASIC Level):
Q: Which of the following is the most common cause of iron deficiency anemia?
Predicted: Option D
Correct Answer: Option A
Confidence Scores:
Option A: 0.2412
Option B: 0.2748
Option C: 0.1729
Option D: 0.3111
--------------------------------------------------

Question 2 (BASIC Level):
Q: The normal range for adult body temperature measured orally is:
Predicted: Option D
Correct Answer: Option C
Confidence Scores:
Option A: 0.2342
Option B: 0.2649
Option C: 0.1806
Option D: 0.3203
--------------------------------------------------

Question 3 (INTERMEDIATE Level):
Q: A 45-year-old patient presents with recurrent episodes of facial 

In [5]:
# Example usage
def main():
    """Demonstrate single-question and batch inference with the saved model.

    Loads the pipeline from "./final_model", prints the predicted option and
    per-option probabilities for one sample question, then prints the
    prediction and confidence for a two-question batch.
    """
    pipeline = MedicalQAInference("./final_model")

    # Sample item for the single-question path.
    # NOTE(review): spelling is kept exactly as found; presumably it mirrors a
    # dataset sample verbatim -- confirm before "correcting" it.
    sample_question = "Chronic urethral obstruction due to benign prismatic hyperplasia can lead to the following change in kidney parenchyma"
    sample_options = dict(
        A='Hyperplasia',
        B='Hyperophy',
        C='Atrophy',
        D='Dyplasia',
    )

    # Single prediction, including the probability of every option
    single = pipeline.predict(sample_question, sample_options, return_probabilities=True)
    print("\nSingle Question Prediction:")
    print(f"Question: {sample_question}")
    print(f"Predicted Answer: Option {single['prediction']}")
    print("Confidence Scores:")
    for letter, score in single['probabilities'].items():
        print(f"Option {letter}: {score:.4f}")

    # Batch path: the sample item again plus a placeholder question
    placeholder_item = dict(
        question="Another medical question...",
        options=dict(
            A='Option A',
            B='Option B',
            C='Option C',
            D='Option D',
        ),
    )
    demo_batch = [
        dict(question=sample_question, options=sample_options),
        placeholder_item,
    ]

    print_header = "\nBatch Prediction Results:"
    batch_output = pipeline.batch_predict(demo_batch)
    print(print_header)
    for entry in batch_output:
        print(f"Question {entry['question_id']}: Option {entry['prediction']} (Confidence: {entry['confidence']:.4f})")

In [6]:
# Run the demo defined by main(). The guard keeps it from running if this
# code is ever imported as a module rather than executed directly.
if __name__ == "__main__":
    main()

Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-small and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Single Question Prediction:
Question: Chronic urethral obstruction due to benign prismatic hyperplasia can lead to the following change in kidney parenchyma
Predicted Answer: Option A
Confidence Scores:
Option A: 0.2870
Option B: 0.2112
Option C: 0.2386
Option D: 0.2632

Batch Prediction Results:
Question 0: Option A (Confidence: 0.2870)
Question 1: Option A (Confidence: 0.2901)
