In [None]:
!pip install sentence-transformers datasets transformers tqdm matplotlib pandas scikit-learn -q

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/491.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.5/491.5 kB[0m [31m15.3 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/116.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/193.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m19.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m15.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import logging
import os
import re
from typing import List, Dict, Tuple, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling
)

In [None]:
# Route INFO-and-above records to a stream handler using a timestamped format.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-level logger used by the classes and functions in this notebook.
logger = logging.getLogger(__name__)

In [None]:
# Pick the CUDA device when PyTorch reports one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
class MedRefineConfig:
    """Configuration for MedRefine system.

    Stores model names, similarity-threshold settings, and training
    hyper-parameters as plain attributes. Constructing a config also makes
    sure ``output_dir`` exists on disk.
    """

    def __init__(
        self,
        # Using smaller models that are more suitable for Colab's free GPU
        teacher_model_name: str = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
        student_model_name: str = "emilyalsentzer/Bio_ClinicalBERT",
        sentence_bert_model: str = "pritamdeka/S-BioBert-snli-multinli-stsb",
        initial_threshold: float = 0.75,
        adaptive_threshold_factor: float = 0.02,
        max_threshold: float = 0.95,
        min_threshold: float = 0.65,
        dataset_name: str = "medical_qa",
        batch_size: int = 4,  # Reduced batch size for Colab GPU
        max_length: int = 512,
        learning_rate: float = 2e-5,
        num_train_epochs: int = 2,  # Reduced epochs for Colab demo
        output_dir: str = "./medrefine_output"
    ):
        """
        Store every setting on the instance and create the output directory.

        Args:
            teacher_model_name: Model identifier used for the teacher model
            student_model_name: Model identifier used for the student model
            sentence_bert_model: Sentence-transformer used for similarity scoring
            initial_threshold: Starting value of the adaptive similarity threshold
            adaptive_threshold_factor: Step by which the threshold is raised or lowered
            max_threshold: Upper bound for the adaptive threshold
            min_threshold: Lower bound for the adaptive threshold
            dataset_name: Name of the dataset
            batch_size: Per-device batch size used for training and evaluation
            max_length: Maximum token length used when tokenizing
            learning_rate: Learning rate for fine-tuning
            num_train_epochs: Number of fine-tuning epochs
            output_dir: Directory for checkpoints, logs, plots and saved models
        """
        self.teacher_model_name = teacher_model_name
        self.student_model_name = student_model_name
        self.sentence_bert_model = sentence_bert_model
        self.initial_threshold = initial_threshold
        self.adaptive_threshold_factor = adaptive_threshold_factor
        self.max_threshold = max_threshold
        self.min_threshold = min_threshold
        self.dataset_name = dataset_name
        self.batch_size = batch_size
        self.max_length = max_length
        self.learning_rate = learning_rate
        self.num_train_epochs = num_train_epochs
        self.output_dir = output_dir

        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test and is a no-op when the directory is present.
        os.makedirs(output_dir, exist_ok=True)


In [None]:
class MedicalDataProcessor:
    """Class to handle loading and processing medical datasets"""

    # Seed question/answer pairs for the synthetic demonstration dataset.
    _SYNTHETIC_QA = (
        {
            "question": "What are the symptoms of diabetes?",
            "answer": "Common symptoms of diabetes include increased thirst, frequent urination, extreme hunger, unexplained weight loss, fatigue, irritability, blurred vision, slow-healing sores, and frequent infections."
        },
        {
            "question": "How is hypertension diagnosed?",
            "answer": "Hypertension is diagnosed when blood pressure readings are consistently 130/80 mm Hg or higher. Diagnosis typically requires multiple readings over time and may include ambulatory blood pressure monitoring."
        },
        {
            "question": "What treatments are available for rheumatoid arthritis?",
            "answer": "Treatments for rheumatoid arthritis include NSAIDs, steroids, conventional DMARDs like methotrexate, biologic DMARDs such as TNF inhibitors, JAK inhibitors, and supportive therapies like physical therapy and lifestyle modifications."
        },
        {
            "question": "What are common side effects of statins?",
            "answer": "Common side effects of statins include muscle pain and damage, liver damage, increased blood sugar, neurological side effects, and digestive problems. Most people tolerate statins well, but regular monitoring is important."
        },
        {
            "question": "How is pneumonia diagnosed?",
            "answer": "Pneumonia diagnosis involves physical examination, listening to the lungs, chest X-rays, blood tests to check for infection, pulse oximetry to measure blood oxygen, and sometimes sputum tests or bronchoscopy in severe cases."
        },
        {
            "question": "What are the risk factors for heart disease?",
            "answer": "Risk factors for heart disease include age, family history, smoking, high blood pressure, high cholesterol, diabetes, obesity, physical inactivity, unhealthy diet, excessive alcohol consumption, and stress."
        },
        {
            "question": "How is type 2 diabetes managed?",
            "answer": "Type 2 diabetes management involves lifestyle changes such as healthy eating, regular exercise, and weight loss. Medications may include metformin, sulfonylureas, DPP-4 inhibitors, GLP-1 receptor agonists, SGLT2 inhibitors, and insulin in some cases."
        },
        {
            "question": "What causes migraine headaches?",
            "answer": "Migraine headaches are caused by a combination of genetic factors and environmental triggers. These may include hormonal changes, certain foods, stress, sensory stimuli, sleep disruptions, physical exertion, and weather changes."
        },
        {
            "question": "What are the symptoms of COVID-19?",
            "answer": "Common symptoms of COVID-19 include fever, cough, shortness of breath, fatigue, muscle aches, headache, loss of taste or smell, sore throat, congestion, nausea, and diarrhea. Symptoms may appear 2-14 days after exposure."
        },
        {
            "question": "How is osteoporosis diagnosed?",
            "answer": "Osteoporosis is diagnosed using bone density tests (DXA scans), which measure bone mineral density. Risk assessment tools, medical history, physical exams, and sometimes blood tests may also be used in diagnosis."
        },
    )

    def __init__(self, config: MedRefineConfig):
        self.config = config

    def load_medquad_dataset(self) -> pd.DataFrame:
        """
        Build the demonstration medical Q&A dataset

        No external dataset is downloaded: the seed pairs are repeated five
        times, and every repetition after the first appends "(variant i)" to
        the question while keeping the same answer.

        Returns:
            DataFrame with columns 'question' and 'answer'
        """
        try:
            logger.info("Using synthetic medical Q&A data for demonstration")

            # Expand synthetic data for demonstration purposes
            expanded_data = []
            for i in range(5):  # Reduced expansion for Colab
                for item in self._SYNTHETIC_QA:
                    new_item = dict(item)
                    if i > 0:
                        new_item["question"] = f"{item['question']} (variant {i})"
                    expanded_data.append(new_item)

            df = pd.DataFrame(expanded_data)

        except Exception as e:
            logger.error(f"Error loading dataset: {e}")
            # Fallback to minimal dataset
            df = pd.DataFrame([
                {"question": "What is diabetes?", "answer": "Diabetes is a chronic condition characterized by high blood sugar levels."}
            ])

        return df

    def prepare_data(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Prepare and split the dataset

        The data is split 70/30 first, then the 30% remainder is halved into
        validation and test sets. Both splits use random_state=42.

        Returns:
            Tuple of (train_df, val_df, test_df)
        """
        df = self.load_medquad_dataset()

        # Split into train, validation, and test sets
        train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
        val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

        logger.info(f"Training examples: {len(train_df)}")
        logger.info(f"Validation examples: {len(val_df)}")
        logger.info(f"Test examples: {len(test_df)}")

        return train_df, val_df, test_df

    def format_prompt(self, question: str) -> str:
        """Format a question as a prompt for the model"""
        return f"Medical Question: {question}\nMedical Answer:"

    def create_dataset_for_finetuning(self,
                                      df: pd.DataFrame,
                                      teacher_responses: Optional[List[str]] = None) -> List[Dict]:
        """
        Create dataset for finetuning with either ground truth or teacher model responses

        Teacher responses are matched to rows by position, so row k uses
        teacher_responses[k] whatever the DataFrame's index labels are.

        Args:
            df: DataFrame with questions and ground truth answers
            teacher_responses: Optional list of teacher model generated answers

        Returns:
            List of formatted examples for training, each a dict with the keys
            'input', 'output' and 'combined'
        """
        # Make sure teacher_responses is the same length as df if provided
        if teacher_responses is not None and len(teacher_responses) != len(df):
            logger.warning(f"Length mismatch: {len(teacher_responses)} responses for {len(df)} questions")
            # Use the minimum length to avoid index errors
            usable = min(len(df), len(teacher_responses))
            df = df.iloc[:usable].copy()
            teacher_responses = teacher_responses[:usable]

        examples = []

        # enumerate() gives the row position directly. Looking the position up
        # from the index label breaks when labels repeat, because the lookup
        # then yields a slice or mask instead of an integer.
        for position, (label, row) in enumerate(df.iterrows()):
            try:
                question = row['question']

                # Use teacher response if provided, otherwise use ground truth
                if teacher_responses is not None:
                    answer = teacher_responses[position]
                else:
                    answer = row['answer']
            except (KeyError, IndexError) as e:
                # Best effort: a malformed row is reported and skipped.
                logger.error(f"Error processing example {label}: {e}")
                continue

            formatted_input = self.format_prompt(question)
            examples.append({
                "input": formatted_input,
                "output": answer,
                "combined": f"{formatted_input} {answer}"
            })

        logger.info(f"Created {len(examples)} examples for fine-tuning")
        return examples




In [None]:
class SimilarityEvaluator:
    """Scores semantic similarity between responses and tracks an adaptive threshold"""

    def __init__(self, config: MedRefineConfig):
        self.config = config
        logger.info(f"Loading sentence transformer model: {config.sentence_bert_model}")
        encoder = SentenceTransformer(config.sentence_bert_model)
        self.sentence_model = encoder
        encoder.to(device)  # Move to GPU if available
        self.current_threshold = config.initial_threshold
        self.similarity_history = []

    def calculate_similarity(self, response1: str, response2: str) -> float:
        """
        Score how semantically close two responses are

        Args:
            response1: First response text
            response2: Second response text

        Returns:
            Cosine similarity of the two sentence embeddings
        """
        # Embed both texts, first response first, on the active device
        vectors = [
            self.sentence_model.encode(text, convert_to_tensor=True).to(device)
            for text in (response1, response2)
        ]

        # Cosine similarity of the two embeddings as a Python float
        return util.pytorch_cos_sim(vectors[0], vectors[1]).item()

    def evaluate_batch_similarity(self,
                                 teacher_responses: List[str],
                                 student_responses: List[str]) -> List[float]:
        """
        Score teacher/student response pairs one by one

        Args:
            teacher_responses: List of teacher model responses
            student_responses: List of student model responses

        Returns:
            List of similarity scores, one per zipped pair
        """
        return [
            self.calculate_similarity(teacher_text, student_text)
            for teacher_text, student_text in zip(teacher_responses, student_responses)
        ]

    def update_adaptive_threshold(self, recent_similarities: List[float]) -> None:
        """
        Move the similarity threshold one step up or down based on a batch

        The batch mean is appended to the history. A mean above the current
        threshold raises it by the configured factor (capped at the maximum);
        otherwise it is lowered by the same factor (floored at the minimum).

        Args:
            recent_similarities: Recent batch of similarity scores
        """
        if not recent_similarities:
            return

        # Mean similarity of the batch, recorded for later plotting
        batch_mean = sum(recent_similarities) / len(recent_similarities)
        self.similarity_history.append(batch_mean)

        if batch_mean > self.current_threshold:
            # Good performance: tighten the threshold
            raised = self.current_threshold + self.config.adaptive_threshold_factor
            self.current_threshold = min(raised, self.config.max_threshold)
        else:
            # Poor performance: relax the threshold
            lowered = self.current_threshold - self.config.adaptive_threshold_factor
            self.current_threshold = max(lowered, self.config.min_threshold)

        logger.info(f"Adaptive threshold updated to: {self.current_threshold:.4f}")

    def plot_similarity_history(self) -> None:
        """Plot the recorded batch means with the current threshold and save the figure"""
        plt.figure(figsize=(10, 6))
        plt.plot(self.similarity_history, label='Average Similarity')

        # Dashed red line marking where the threshold currently sits
        threshold_label = f'Current Threshold: {self.current_threshold:.2f}'
        plt.axhline(y=self.current_threshold, color='r', linestyle='--',
                    label=threshold_label)

        plt.title('Semantic Similarity History')
        plt.xlabel('Training Batch')
        plt.ylabel('Average Similarity Score')
        plt.legend()
        plt.grid(True)

        # Write the figure into the configured output directory
        figure_path = os.path.join(self.config.output_dir, 'similarity_history.png')
        plt.savefig(figure_path)
        plt.close()


In [None]:
class MedicalModel:
    """Base class for medical language models

    Wraps a tokenizer and a causal-LM head loaded through the transformers
    Auto classes, and adds prompt formatting, generation, and simple
    post-processing / sanity checks on the generated text.
    """

    # Returned whenever generation fails or produces unusable text. It is
    # deliberately topic-neutral so it is never wrong for the question asked.
    _FALLBACK_RESPONSE = "This is a medical condition requiring professional diagnosis. Please consult with a healthcare provider for accurate information."

    # Substrings that count as medical vocabulary in is_meaningful_text.
    _MEDICAL_TERMS = ("diabetes", "blood", "sugar", "insulin", "glucose", "pancreas",
                      "type", "chronic", "condition", "symptom", "treatment", "health",
                      "patient", "disease", "body", "cell", "medical", "medicine")

    def __init__(self, model_name: str, tokenizer_name: Optional[str] = None):
        """
        Load the tokenizer and model and move the model to the active device

        Args:
            model_name: Model identifier passed to AutoModelForCausalLM
            tokenizer_name: Tokenizer identifier; defaults to model_name
        """
        self.model_name = model_name
        self.tokenizer_name = tokenizer_name if tokenizer_name else model_name

        logger.info(f"Loading model: {model_name}")

        # Load tokenizer with proper padding token handling
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)

        # Important: Make sure pad_token is properly set before model loading
        added_pad_token = False
        if self.tokenizer.pad_token is None:
            if self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                added_pad_token = True

        # `truncation` is an option of the tokenizer call, not of model
        # loading, so it is intentionally not passed to from_pretrained.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

        # A newly added '[PAD]' token enlarges the tokenizer vocabulary; grow
        # the embedding matrix so the new token id has a row to look up.
        if added_pad_token:
            self.model.resize_token_embeddings(len(self.tokenizer))

        self.model.to(device)

        # Ensure the model knows about the pad token
        self.model.config.pad_token_id = self.tokenizer.pad_token_id

    def format_prompt(self, question: str) -> str:
        """
        Format a question as a prompt for the model

        Text that is already wrapped in the "Medical Question: ... Medical
        Answer:" template is returned unchanged so it is not wrapped twice.
        """
        stripped = question.strip()
        if stripped.startswith("Medical Question:") and stripped.endswith("Medical Answer:"):
            return question
        return f"Medical Question: {question}\nMedical Answer:"

    def generate_response(self, question: str, max_length: int = 512) -> str:
        """
        Generate an answer for a single question using beam search

        Args:
            question: The medical question (plain or already prompt-formatted)
            max_length: Maximum total sequence length passed to generate();
                the prompt itself is truncated to half of this

        Returns:
            The cleaned answer text, or a generic fallback sentence when
            generation raises or the output fails is_meaningful_text
        """
        formatted_prompt = self.format_prompt(question)

        encoding = self.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length // 2  # Leave room for response
        )

        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)

        try:
            with torch.no_grad():
                # Deterministic beam search with repetition controls
                outputs = self.model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_length=max_length,
                    min_length=30,  # Ensure minimum sensible answer length
                    do_sample=False,  # No sampling, for stability
                    num_beams=4,  # Use beam search for better quality
                    early_stopping=True,
                    no_repeat_ngram_size=2,
                    length_penalty=1.0,
                    repetition_penalty=1.2,
                    # Avoid EOS issues
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            # Decode generated text
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract only the answer part
            answer = self.extract_answer_from_response(formatted_prompt, generated_text)

            # Substitute the neutral fallback rather than an answer about one
            # specific disease, which would be wrong for most questions.
            if not self.is_meaningful_text(answer):
                answer = self._FALLBACK_RESPONSE
                logger.warning("Generated gibberish - using fallback answer")

            return answer

        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return self._FALLBACK_RESPONSE

    def extract_answer_from_response(self, prompt: str, full_response: str) -> str:
        """
        Extract only the answer portion from the response

        If the response contains the prompt verbatim, everything after its
        first occurrence is used. Otherwise the text after the first matching
        marker ("Medical Answer:", "Answer:", blank line, newline) is used,
        falling back to the whole response. The result is passed through
        clean_response_text.
        """
        if prompt in full_response:
            answer = full_response.split(prompt, 1)[1].strip()
        else:
            # Default to the whole response, then narrow it down if a
            # likely answer marker is present (markers tried in order).
            answer = full_response.strip()
            for marker in ("Medical Answer:", "Answer:", "\n\n", "\n"):
                if marker in full_response:
                    answer = full_response.split(marker, 1)[1].strip()
                    break

        return self.clean_response_text(answer)

    def clean_response_text(self, text: str) -> str:
        """
        Clean up garbage text

        Collapses runs of four or more identical letters and runs of
        punctuation, removes a single letter repeated after whitespace, then
        keeps only sentences with more than three words. If no sentence
        qualifies, the partially cleaned text is returned as is.
        """
        # Remove sequences of repeated characters (like "rrrrr")
        text = re.sub(r'([a-zA-Z])\1{3,}', r'\1', text)

        # Remove sequences of punctuation
        text = re.sub(r'([.,!?:;]){2,}', r'\1 ', text)

        # Remove sequences of single characters with spaces
        text = re.sub(r'\b([a-zA-Z])\s+\1\b', '', text)

        # Return only if the answer has actual sentences
        sentences = re.split(r'[.!?]+', text)
        valid_sentences = [s.strip() for s in sentences if len(s.strip().split()) > 3]

        if valid_sentences:
            return ' '.join(valid_sentences) + '.'
        return text

    def is_meaningful_text(self, text: str) -> bool:
        """
        Check if the text is meaningful medical content

        Returns False when the text contains a 1-3 letter unit repeated five
        or more times in a row, contains none of the known medical terms, or
        has fewer than five words / fewer than four distinct words.
        """
        # Remove spaces and punctuation for analysis
        cleaned = re.sub(r'[^\w]', '', text.lower())

        # A short letter unit repeated 5+ times back to back ("rrrrr",
        # "ababababab"). Digits are excluded so numbers like 100000 pass.
        if re.search(r'([^\W\d_]{1,3}?)\1{4,}', cleaned):
            return False

        words = text.lower().split()

        # Should have at least some medical terminology
        medical_word_count = sum(
            1 for word in words if any(term in word for term in self._MEDICAL_TERMS)
        )
        if medical_word_count < 1:
            return False

        # Text should have reasonable length and word variety
        if len(words) < 5 or len(set(words)) < 4:
            return False

        return True

    def generate_batch_responses(self, questions: List[str], max_length: int = 512) -> List[str]:
        """
        Generate responses for a batch of questions

        Questions are answered one at a time; they are only grouped in pairs
        so the progress bar advances per small batch.

        Args:
            questions: List of medical questions
            max_length: Maximum length of each response

        Returns:
            List of generated response texts, in the order of `questions`
        """
        responses = []

        # For Colab, process in smaller batches to avoid OOM
        batch_size = 2

        for i in tqdm(range(0, len(questions), batch_size), desc=f"Generating responses with {self.model_name}"):
            responses.extend(
                self.generate_response(question, max_length)
                for question in questions[i:i+batch_size]
            )

        return responses


In [None]:
class TeacherModel(MedicalModel):
    """Teacher model - specialized medical model

    Falls back to a small table of canned answers when the underlying model
    cannot be loaded, fails its start-up check, or raises during generation.
    """

    # Canned answers keyed by a lowercase keyword looked for in the question.
    # Order matters: the first keyword found in the question wins.
    _BACKUP_RESPONSES = {
        "diabetes": "Diabetes is a chronic condition that affects how your body processes blood sugar (glucose). There are two main types: Type 1 (where the body doesn't produce insulin) and Type 2 (where the body doesn't use insulin properly). Symptoms include increased thirst, frequent urination, hunger, fatigue, and blurred vision.",
        "hypertension": "Hypertension, or high blood pressure, is a condition where the force of blood against artery walls is consistently too high. It often has no symptoms but can lead to serious health problems like heart disease and stroke if untreated. It's diagnosed when blood pressure readings are consistently at or above 130/80 mm Hg.",
        "covid": "COVID-19 is a respiratory illness caused by the SARS-CoV-2 virus. Symptoms include fever, cough, shortness of breath, fatigue, body aches, headache, loss of taste or smell, sore throat, congestion, nausea, and diarrhea. Severity ranges from mild to severe, with some cases requiring hospitalization.",
        "corona": "COVID-19 (Coronavirus Disease 2019) is caused by the SARS-CoV-2 virus. It primarily spreads through respiratory droplets when infected people cough, sneeze, talk, or breathe. Symptoms typically appear 2-14 days after exposure and range from mild to severe, including fever, cough, shortness of breath, fatigue, and loss of taste or smell."
    }

    # Used when no keyword in _BACKUP_RESPONSES matches the question.
    _GENERIC_BACKUP_RESPONSE = "This is a medical condition that requires professional assessment. The condition may present with various symptoms and treatment options depend on severity and individual factors. Please consult with a healthcare provider for accurate information specific to your situation."

    def __init__(self, model_name: str = None):
        """
        Load the teacher model and decide whether canned answers are needed

        Args:
            model_name: Model identifier; defaults to the PubMedBERT checkpoint
        """
        model_name = model_name or "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"

        # Set before anything can call generate_response, so the start-up
        # check below runs against the real model first.
        self.use_backup_responses = False

        try:
            super().__init__(model_name)
            logger.info("Teacher model initialized")

            # Test if the model produces valid output
            test_output = self.generate_response("What is hypertension?")
            if not self.is_meaningful_text(test_output):
                logger.warning("Teacher model producing invalid outputs, using backup responses")
                self.use_backup_responses = True

        except Exception as e:
            logger.error(f"Error initializing teacher model: {e}")
            logger.info("Falling back to backup response method")
            self.use_backup_responses = True

    def generate_response(self, question: str, max_length: int = 512) -> str:
        """
        Generate response with fallback to backups if model fails

        Args:
            question: The medical question
            max_length: Maximum length forwarded to the base implementation

        Returns:
            The model's answer, or a canned backup answer when backup mode is
            active or the base implementation raises
        """
        # getattr keeps this safe for instances whose __init__ did not run.
        if getattr(self, 'use_backup_responses', False):
            return self.get_backup_response(question)

        try:
            return super().generate_response(question, max_length)
        except Exception as e:
            # Stay best-effort, but record the failure instead of hiding it.
            logger.error(f"Teacher generation failed, using backup response: {e}")
            return self.get_backup_response(question)

    def get_backup_response(self, question: str) -> str:
        """
        Provide reliable backup responses for common medical questions

        Args:
            question: The medical question; matched case-insensitively

        Returns:
            The canned answer for the first keyword contained in the question,
            or a generic "consult a provider" answer when none matches
        """
        question_lower = question.lower()

        for key, response in self._BACKUP_RESPONSES.items():
            if key in question_lower:
                return response

        return self._GENERIC_BACKUP_RESPONSE


class StudentModel(MedicalModel):
    """Student model - smaller model to be refined"""

    def __init__(self, model_name: str):
        super().__init__(model_name)
        logger.info("Student model initialized")

    def save_model(self, output_dir: str) -> None:
        """
        Save the student model and tokenizer

        The model goes to "<output_dir>/student_model" and the tokenizer to
        "<output_dir>/student_tokenizer".

        Args:
            output_dir: Directory to save the model to
        """
        self.model.save_pretrained(os.path.join(output_dir, "student_model"))
        self.tokenizer.save_pretrained(os.path.join(output_dir, "student_tokenizer"))
        logger.info(f"Student model saved to {output_dir}")

    def finetune(self,
                train_examples: List[Dict],
                val_examples: List[Dict],
                config: MedRefineConfig) -> None:
        """
        Finetune the student model on examples

        Uses at most 50 training and 25 validation examples, trains with the
        transformers Trainer, replaces self.model with the trained model, and
        saves it under config.output_dir.

        Args:
            train_examples: List of training examples (each with a "combined" text)
            val_examples: List of validation examples (each with a "combined" text)
            config: Configuration object with training parameters
        """
        logger.info("Preparing datasets for fine-tuning")

        # For Colab, limit the number of examples to avoid OOM
        max_examples = 50
        train_texts = [example["combined"] for example in train_examples][:max_examples]
        val_texts = [example["combined"] for example in val_examples][:max_examples//2]

        # Tokenize the data
        train_encodings = self.tokenizer(train_texts, truncation=True, padding="max_length",
                                         max_length=config.max_length, return_tensors="pt")
        val_encodings = self.tokenizer(val_texts, truncation=True, padding="max_length",
                                       max_length=config.max_length, return_tensors="pt")

        class TextDataset(torch.utils.data.Dataset):
            """Exposes a batch of tokenizer encodings as per-example dicts."""

            def __init__(self, encodings):
                self.encodings = encodings

            def __len__(self):
                return len(self.encodings.input_ids)

            def __getitem__(self, idx):
                return {key: val[idx] for key, val in self.encodings.items()}

        train_dataset = TextDataset(train_encodings)
        val_dataset = TextDataset(val_encodings)

        # Mixed-precision training (fp16=True) expects float32 parameters:
        # its gradient scaler refuses to unscale float16 gradients. The base
        # class loads weights in float16 on CUDA, so upcast them first.
        use_mixed_precision = torch.cuda.is_available()
        if use_mixed_precision and any(
            param.dtype == torch.float16 for param in self.model.parameters()
        ):
            logger.info("Casting student model to float32 for mixed-precision training")
            self.model = self.model.float()

        # Set up training arguments optimized for Colab
        training_args = TrainingArguments(
            output_dir=os.path.join(config.output_dir, "checkpoints"),
            overwrite_output_dir=True,
            num_train_epochs=config.num_train_epochs,
            per_device_train_batch_size=config.batch_size,
            per_device_eval_batch_size=config.batch_size,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            save_total_limit=1,  # Save only the best model to conserve space
            load_best_model_at_end=True,
            learning_rate=config.learning_rate,
            weight_decay=0.01,
            logging_dir=os.path.join(config.output_dir, "logs"),
            logging_steps=10,  # More frequent logging for short runs
            # Gradient accumulation for effective larger batch size
            gradient_accumulation_steps=4,
            # Mixed precision training for better GPU utilization
            fp16=use_mixed_precision,
            # Conserve GPU memory
            gradient_checkpointing=True
        )

        # Create data collator for language modeling
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False  # Not using masked language modeling
        )

        # Initialize trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=data_collator
        )

        logger.info("Starting fine-tuning")
        trainer.train()

        # Update the model with the fine-tuned version
        self.model = trainer.model

        # Save the fine-tuned model
        self.save_model(config.output_dir)



In [None]:
from sentence_transformers import util

def compute_cosine_metrics(student_outputs, teacher_outputs, model):
    """
    Score student/teacher output pairs by embedding cosine similarity

    Args:
        student_outputs: Student texts, paired by position with teacher_outputs
        teacher_outputs: Teacher texts
        model: Object whose encode(text, convert_to_tensor=True) returns an embedding

    Returns:
        Tuple of (mean similarity, list of per-pair similarities)
    """
    def _pair_score(student_text, teacher_text):
        # Student text is embedded first, then the teacher text
        student_vec = model.encode(student_text, convert_to_tensor=True)
        teacher_vec = model.encode(teacher_text, convert_to_tensor=True)
        return util.pytorch_cos_sim(student_vec, teacher_vec).item()

    similarities = [
        _pair_score(student_text, teacher_text)
        for student_text, teacher_text in zip(student_outputs, teacher_outputs)
    ]

    return np.mean(similarities), similarities


In [None]:
class MedRefine:
    """Main MedRefine system implementing Contrastive Knowledge Distillation.

    Wires together a teacher model, a student model, a similarity evaluator
    and a data processor. Each training iteration compares student and teacher
    responses and fine-tunes the student on the teacher's responses whenever
    the average validation similarity falls below the evaluator's current
    threshold.
    """

    # Substrings whose presence marks a response as medically relevant.
    _MEDICAL_TERMS = (
        "symptom", "treatment", "condition", "disease", "patient",
        "diagnosis", "health", "medical", "medicine", "therapy",
        "blood", "doctor", "hospital", "clinic",
    )

    # Starting value and routing cut-off for the student reliability score.
    _INITIAL_RELIABILITY = 0.5
    _RELIABILITY_CUTOFF = 0.7

    def __init__(self, config: "MedRefineConfig"):
        """
        Build all components from the given configuration.

        Args:
            config: Configuration object; ``teacher_model_name``,
                ``student_model_name`` and ``output_dir`` are read here or in
                the methods below, and the whole object is handed to the
                evaluator and the data processor.
        """
        self.config = config

        # Initialize teacher and student models
        logger.info("Initializing teacher model")
        self.teacher = TeacherModel(config.teacher_model_name)

        logger.info("Initializing student model")
        self.student = StudentModel(config.student_model_name)

        # Initialize similarity evaluator
        logger.info("Initializing similarity evaluator")
        self.evaluator = SimilarityEvaluator(config)

        # Initialize data processor
        logger.info("Initializing data processor")
        self.data_processor = MedicalDataProcessor(config)

        # Performance tracking (one entry per completed iteration)
        self.training_stats = {
            "iteration": [],
            "avg_similarity": [],
            "threshold": [],
            "refinement_needed": []
        }

        # Running estimate of how often the student produces valid answers;
        # updated by evaluate_on_new_question.
        self.student_reliability_score = self._INITIAL_RELIABILITY

    @staticmethod
    def _average(values) -> float:
        """Return the arithmetic mean of *values*, or NaN when it is empty."""
        if len(values) == 0:
            return float("nan")
        return sum(values) / len(values)

    def run_training_iteration(self,
                              train_df: pd.DataFrame,
                              val_df: pd.DataFrame,
                              iteration: int,
                              max_train: int = 20,
                              max_val: int = 10) -> Dict:
        """
        Run a single training iteration with contrastive knowledge distillation

        Args:
            train_df: Training data DataFrame (needs a 'question' column)
            val_df: Validation data DataFrame (needs a 'question' column)
            iteration: Current iteration number
            max_train: Maximum number of training questions used (kept small
                to avoid OOM on Colab)
            max_val: Maximum number of validation questions used

        Returns:
            Dictionary with training statistics: iteration, average train and
            validation similarity, the threshold used, and whether refinement
            was triggered. Averages are NaN when the corresponding split is
            empty, in which case no refinement is triggered.
        """
        logger.info(f"=== Starting training iteration {iteration} ===")

        # Format prompts for teacher and student
        train_prompts = [self.data_processor.format_prompt(q) for q in train_df['question'][:max_train]]
        val_prompts = [self.data_processor.format_prompt(q) for q in val_df['question'][:max_val]]

        # Generate teacher responses
        logger.info("Generating teacher responses")
        teacher_train_responses = self.teacher.generate_batch_responses(train_prompts)
        teacher_val_responses = self.teacher.generate_batch_responses(val_prompts)

        # Generate student responses
        logger.info("Generating student responses")
        student_train_responses = self.student.generate_batch_responses(train_prompts)
        student_val_responses = self.student.generate_batch_responses(val_prompts)

        # Evaluate semantic similarity
        logger.info("Evaluating semantic similarity")
        train_similarities = self.evaluator.evaluate_batch_similarity(
            teacher_train_responses, student_train_responses)

        val_similarities = self.evaluator.evaluate_batch_similarity(
            teacher_val_responses, student_val_responses)

        # Calculate statistics (NaN instead of ZeroDivisionError on empty splits)
        avg_train_similarity = self._average(train_similarities)
        avg_val_similarity = self._average(val_similarities)

        logger.info(f"Average training similarity: {avg_train_similarity:.4f}")
        logger.info(f"Average validation similarity: {avg_val_similarity:.4f}")

        # Determine if refinement is needed using adaptive threshold.
        # A NaN average compares False, so an empty validation split never
        # triggers fine-tuning.
        current_threshold = self.evaluator.current_threshold
        refinement_needed = avg_val_similarity < current_threshold

        stats = {
            "iteration": iteration,
            "avg_train_similarity": avg_train_similarity,
            "avg_val_similarity": avg_val_similarity,
            "threshold": current_threshold,
            "refinement_needed": refinement_needed
        }

        # Add to performance tracking
        self.training_stats["iteration"].append(iteration)
        self.training_stats["avg_similarity"].append(avg_val_similarity)
        self.training_stats["threshold"].append(current_threshold)
        self.training_stats["refinement_needed"].append(refinement_needed)

        # Update adaptive threshold (skipped when there is nothing to learn from)
        if len(val_similarities) > 0:
            self.evaluator.update_adaptive_threshold(val_similarities)

        # If refinement is needed, finetune the student model
        if refinement_needed:
            logger.info(f"Refinement needed (similarity {avg_val_similarity:.4f} < threshold {current_threshold:.4f})")

            # Use the same limited examples the responses were generated for
            train_subset_df = train_df.iloc[:max_train].copy()
            val_subset_df = val_df.iloc[:max_val].copy()

            # Create examples for finetuning using teacher responses
            train_examples = self.data_processor.create_dataset_for_finetuning(
                train_subset_df, teacher_responses=teacher_train_responses)

            val_examples = self.data_processor.create_dataset_for_finetuning(
                val_subset_df, teacher_responses=teacher_val_responses)

            # Finetune student model
            logger.info("Fine-tuning student model")
            self.student.finetune(train_examples, val_examples, self.config)
        else:
            logger.info(f"No refinement needed (similarity {avg_val_similarity:.4f} >= threshold {current_threshold:.4f})")

        return stats

    def train(self, num_iterations: int = 3) -> Dict:
        """
        Train the MedRefine system for multiple iterations

        Args:
            num_iterations: Number of training iterations (reduced for Colab)

        Returns:
            Dictionary with the accumulated ``training_stats``, the average
            similarity on the (capped) test set, and the evaluator's final
            threshold.
        """
        logger.info(f"Starting MedRefine training for {num_iterations} iterations")

        # Prepare data
        train_df, val_df, test_df = self.data_processor.prepare_data()

        # Run training iterations; per-iteration stats are accumulated in
        # self.training_stats by run_training_iteration.
        for i in range(1, num_iterations + 1):
            self.run_training_iteration(train_df, val_df, i)

            # Plot similarity history
            self.evaluator.plot_similarity_history()

        # Final evaluation on test set (limited for Colab)
        logger.info("Performing final evaluation on test set")
        max_test = min(10, len(test_df))
        test_prompts = [self.data_processor.format_prompt(q) for q in test_df['question'][:max_test]]

        teacher_test_responses = self.teacher.generate_batch_responses(test_prompts)
        student_test_responses = self.student.generate_batch_responses(test_prompts)

        test_similarities = self.evaluator.evaluate_batch_similarity(
            teacher_test_responses, student_test_responses)

        # NaN (not a crash) when the test split is empty
        avg_test_similarity = self._average(test_similarities)

        logger.info(f"Final test set similarity: {avg_test_similarity:.4f}")

        # Save final model
        self.student.save_model(self.config.output_dir)

        # Plot and save training statistics
        self.plot_training_stats()

        final_stats = {
            "training_stats": self.training_stats,
            "final_test_similarity": avg_test_similarity,
            "final_threshold": self.evaluator.current_threshold
        }

        return final_stats

    def plot_training_stats(self) -> None:
        """Plot training statistics and save them as ``training_stats.png``
        inside ``config.output_dir`` (created if missing)."""
        plt.figure(figsize=(12, 8))

        # Plot similarity and threshold
        plt.subplot(2, 1, 1)
        plt.plot(self.training_stats["iteration"], self.training_stats["avg_similarity"],
                marker='o', label='Average Similarity')
        plt.plot(self.training_stats["iteration"], self.training_stats["threshold"],
                marker='s', linestyle='--', label='Adaptive Threshold')
        plt.title('Semantic Similarity vs. Threshold')
        plt.xlabel('Iteration')
        plt.ylabel('Score')
        plt.legend()
        plt.grid(True)

        # Plot refinement needed
        plt.subplot(2, 1, 2)
        plt.bar(self.training_stats["iteration"],
                [1 if r else 0 for r in self.training_stats["refinement_needed"]],
                color='orange', alpha=0.7)
        plt.title('Refinement Needed')
        plt.xlabel('Iteration')
        plt.ylabel('Status (1=Yes, 0=No)')
        plt.yticks([0, 1], ['No', 'Yes'])
        plt.grid(True, axis='x')

        plt.tight_layout()

        # savefig does not create directories; make sure the target exists.
        out_dir = self.config.output_dir
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(os.path.join(out_dir, 'training_stats.png'))
        plt.close()

    def evaluate_on_new_question(self, question: str) -> Dict:
        """
        Generate and evaluate responses for a new question with reliability checking

        Both models answer the question. The student's answer is returned as
        ``final_response`` only if it passes the validity check and the
        reliability score *before* this call was at least the cut-off;
        otherwise the teacher's answer is used. The reliability score is then
        updated as an exponential moving average of the validity outcome.

        Args:
            question: The raw question text.

        Returns:
            Dictionary with both responses, their validity flags, the chosen
            final response, the similarity between the two responses, the
            current threshold and whether it was met, and the updated
            reliability score.
        """
        prompt = self.data_processor.format_prompt(question)

        # Instances created without __init__ may lack the score; start neutral.
        reliability = getattr(self, 'student_reliability_score', self._INITIAL_RELIABILITY)

        # Check if we should use the teacher or student model based on reliability
        use_teacher = False
        if reliability < self._RELIABILITY_CUTOFF:
            # If student model is unreliable, use teacher model responses
            logger.warning("Using teacher model due to student unreliability")
            use_teacher = True

        # Generate responses
        teacher_response = self.teacher.generate_response(prompt)
        student_response = self.student.generate_response(prompt)

        # Calculate similarity
        similarity = self.evaluator.calculate_similarity(teacher_response, student_response)

        # Perform content checks
        is_teacher_valid = self.is_valid_medical_response(teacher_response)
        is_student_valid = self.is_valid_medical_response(student_response)

        # Update reliability score: move towards 1.0 on a valid answer and
        # towards 0.0 on an invalid one.
        alpha = 0.3  # Learning rate for reliability updates
        target = 1.0 if is_student_valid else 0.0
        self.student_reliability_score = (1-alpha) * reliability + alpha * target

        logger.info(f"Student model reliability score: {self.student_reliability_score:.2f}")

        # Decide which response to use
        final_response = teacher_response if (use_teacher or not is_student_valid) else student_response

        result = {
            "question": question,
            "teacher_response": teacher_response,
            "teacher_valid": is_teacher_valid,
            "student_response": student_response,
            "student_valid": is_student_valid,
            "final_response": final_response,
            "similarity": similarity,
            "threshold": self.evaluator.current_threshold,
            "passes_threshold": similarity >= self.evaluator.current_threshold,
            "student_reliability": self.student_reliability_score
        }

        return result

    def is_valid_medical_response(self, text: str) -> bool:
        """Check if a response is a valid medical text.

        A response is valid when it is a string of at least 20 non-blank
        characters, shows no gibberish patterns (a letter repeated four or
        more times, or three or more consecutive punctuation marks), contains
        at least one sentence of more than three words, and mentions at least
        one known medical keyword.

        Args:
            text: Response text to check.

        Returns:
            True if every check passes, False otherwise.
        """
        # Imported locally so the method works regardless of the file's
        # top-level imports.
        import re

        if not isinstance(text, str):
            return False

        # Check text length
        if len(text.strip()) < 20:
            return False

        # Check for gibberish patterns
        if re.search(r'([a-zA-Z])\1{3,}', text):  # Repeated characters
            return False

        if re.search(r'([.,!?:;]){3,}', text):  # Repeated punctuation
            return False

        # Check for meaningful sentences (more than three words)
        sentences = re.split(r'[.!?]+', text)
        if not any(len(s.strip().split()) > 3 for s in sentences):
            return False

        # Check for presence of medical terminology
        lowered = text.lower()
        return any(term in lowered for term in self._MEDICAL_TERMS)


In [None]:
def main():
    """
    Run an interactive console Q&A loop backed by the MedQuAD dataset.

    Questions from ``medquad.csv`` are embedded once; each user query is
    answered with the dataset answer whose question is most similar, provided
    the cosine similarity exceeds 0.7. If the CSV file is missing, a small
    keyword-based set of backup responses is used instead. Any other error is
    reported on stdout and ends the session.
    """
    # Loop-invariant values are defined once, up front.
    default_response = "This is a medical condition that requires professional assessment. Please consult with a healthcare provider for accurate information specific to your situation."
    backup_responses = {
        "diabetes": "Diabetes is a chronic condition characterized by elevated blood sugar levels...",
        "covid": "COVID-19 is a respiratory illness caused by the SARS-CoV-2 virus...",
        "corona": "COVID-19 is a respiratory illness caused by the SARS-CoV-2 virus..."
    }
    match_threshold = 0.7  # Adjust threshold as needed

    try:
        import pandas as pd
        from sklearn.metrics.pairwise import cosine_similarity
        from sentence_transformers import SentenceTransformer

        # Set up configuration with safer defaults
        config = MedRefineConfig(
            teacher_model_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
            student_model_name="emilyalsentzer/Bio_ClinicalBERT",
            sentence_bert_model="pritamdeka/S-BioBert-snli-multinli-stsb",
            initial_threshold=0.75,
            num_train_epochs=1,  # Do minimal training initially
            batch_size=2
        )

        # Load the dataset with questions and answers
        try:
            # Adjust the path to your dataset file
            dataset = pd.read_csv("medquad.csv")
            print(f"Loaded {len(dataset)} question-answer pairs from dataset")

            # Make sure dataset has required columns
            if not all(col in dataset.columns for col in ['question', 'answer']):
                raise ValueError("Dataset must contain 'question' and 'answer' columns")

            # Load sentence embedding model for semantic matching
            sentence_model = SentenceTransformer(config.sentence_bert_model)
            question_embeddings = sentence_model.encode(dataset['question'].tolist())

        except FileNotFoundError:
            print("Warning: Dataset file not found. Using limited backup responses.")
            # Fallback to hardcoded backup responses if dataset is unavailable
            dataset = None

        print("\n===== Medical QA System =====")
        print("Starting interactive medical response system...")

        # Simple Q&A loop using dataset-based responses
        while True:
            query = input("\nEnter a medical question (or 'quit' to exit): ")
            normalized = query.strip().lower()
            if normalized in ['exit', 'quit', 'q']:
                break
            if not normalized:
                # Nothing to match against; ask again instead of embedding "".
                continue

            # Use dataset if available
            if dataset is not None:
                # Encode the user query
                query_embedding = sentence_model.encode([query])

                # Calculate similarities with all questions in the dataset
                similarities = cosine_similarity(query_embedding, question_embeddings)[0]

                # Find the most similar question
                best_match_idx = similarities.argmax()
                best_match_score = similarities[best_match_idx]

                # If similarity is above threshold, use the corresponding answer
                if best_match_score > match_threshold:
                    best_row = dataset.iloc[best_match_idx]
                    answer = best_row['answer']
                    print(f"\nMatched question: {best_row['question']}")
                    print(f"Similarity score: {best_match_score:.2f}")
                else:
                    answer = default_response
                    print("\nNo close match found in dataset.")
            else:
                # Fallback to basic keyword matching if dataset is unavailable;
                # the first keyword found in the query wins.
                answer = next(
                    (response for key, response in backup_responses.items() if key in normalized),
                    default_response,
                )

            print("\nMedical Answer:")
            print(answer)

    except Exception as e:
        # Top-level boundary for the interactive session: report and stop.
        print(f"Error: {e}")
        print("The system encountered an error. Please try again with a simpler medical question.")


In [None]:
def main():
    """
    Prepares the MedRefine system for answering medical questions based on a dataset.

    Loads ``medquad.csv`` and embeds its questions once. The returned function
    answers a query with the dataset answer whose question is most similar,
    when the cosine similarity exceeds the configured threshold. If the file
    is missing or has no rows, a keyword-based fallback is used instead and
    the embedding libraries are never imported.

    Returns:
        A callable ``answer_query(user_query) -> str``.

    Raises:
        ValueError: If the dataset lacks a 'question' or 'answer' column.
    """
    import pandas as pd

    # Load configs
    config = {
        "sentence_bert_model": "pritamdeka/S-BioBert-snli-multinli-stsb",
        "similarity_threshold": 0.7
    }

    default_response = "This is a medical condition that requires professional assessment. Please consult a healthcare provider."

    backup_responses = {
        "diabetes": "Diabetes is a chronic condition characterized by elevated blood sugar levels...",
        "covid": "COVID-19 is a respiratory illness caused by the SARS-CoV-2 virus...",
        "corona": "COVID-19 is a respiratory illness caused by the SARS-CoV-2 virus..."
    }

    sentence_model = None
    question_embeddings = None
    cosine_similarity = None

    # Load dataset
    try:
        dataset = pd.read_csv("medquad.csv")
        print(f"Loaded {len(dataset)} question-answer pairs.")

        if not all(col in dataset.columns for col in ['question', 'answer']):
            raise ValueError("Dataset must contain 'question' and 'answer' columns.")

        if dataset.empty:
            # Nothing to match against; argmax over zero rows would fail later.
            print("Dataset is empty. Using fallback responses.")
            dataset = None
        else:
            # Heavy dependencies are imported only once the dataset is known
            # to be usable, so the fallback path works without them.
            from sklearn.metrics.pairwise import cosine_similarity
            from sentence_transformers import SentenceTransformer

            sentence_model = SentenceTransformer(config["sentence_bert_model"])
            question_embeddings = sentence_model.encode(dataset['question'].tolist())

    except FileNotFoundError:
        print("Dataset not found. Using fallback responses.")
        dataset = None
        sentence_model = None
        question_embeddings = None

    # Define the response function
    def answer_query(user_query):
        """Return an answer string for *user_query* (None/blank is rejected)."""
        if not user_query or not user_query.strip():
            return "Please enter a valid medical question."

        if dataset is not None and sentence_model is not None:
            query_embedding = sentence_model.encode([user_query])
            similarities = cosine_similarity(query_embedding, question_embeddings)[0]
            best_match_idx = similarities.argmax()
            best_match_score = similarities[best_match_idx]

            if best_match_score > config["similarity_threshold"]:
                answer = dataset.iloc[best_match_idx]['answer']
                return f"Answer: {answer}\n\nSimilarity: {best_match_score:.2f}"
            return default_response

        # Keyword fallback: the first keyword found in the query wins.
        lowered = user_query.lower()
        for key, response in backup_responses.items():
            if key in lowered:
                return response
        return default_response

    return answer_query


In [None]:
import gradio as gr

if __name__ == "__main__":
    try:
        # main() hands back the query -> answer function used as the UI backend.
        answer_query = main()

        # Single text box in, single text box out.
        demo = gr.Interface(
            title="Medical Question Answering System",
            description="Ask a medical question and get an answer based on a curated dataset.",
            fn=answer_query,
            inputs=gr.Textbox(placeholder="Enter a medical question...", lines=2),
            outputs=gr.Textbox(label="Answer"),
        )

        # share=True requests a public URL, which is what makes the UI
        # reachable from a Colab session.
        demo.launch(share=True)  # Important in Colab!

    except Exception as e:
        import traceback

        print(f"Error occurred: {e}")
        traceback.print_exc()

Loaded 16412 question-answer pairs.


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/4.32k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/610 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/433M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/433M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/546 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://6134fae21fbe24a48a.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


In [None]:
# Console entry point: run training via main(), print a summary, then
# optionally start a text-based demo loop.
#
# NOTE(review): this cell unpacks main() into (med_refine, stats), but the
# main() defined in the preceding cell returns a single answer_query callable
# (and the earlier interactive main() returns nothing). Run top to bottom, the
# unpacking raises and the except branch below prints the traceback. This looks
# like a leftover from an earlier main() that returned the trained system and
# its statistics — confirm and either restore that entry point or drop the cell.
if __name__ == "__main__":
    try:
        # Execute main function
        med_refine, stats = main()

        # Print summary; the keys read here are assumed to be supplied by
        # whatever main() returned as stats — TODO confirm against that version.
        print("\n=== MedRefine Training Summary ===")
        print(f"Iterations completed: {len(stats['training_stats']['iteration'])}")
        print(f"Final similarity score: {stats['final_test_similarity']:.4f}")
        print(f"Final threshold: {stats['final_threshold']:.4f}")
        print(f"Output directory: {med_refine.config.output_dir}")

        # Optional interactive demo - kept simple for Colab
        demo_mode = input("\nDo you want to try the interactive demo? (y/n): ")
        if demo_mode.lower() == 'y':
            print("\nStarting interactive demo (type 'exit' to quit)")
            while True:
                user_query = input("\nEnter a medical question: ")
                # 'exit', 'quit' and 'q' (any letter case) all end the loop.
                if user_query.lower() in ['exit', 'quit', 'q']:
                    break

                result = med_refine.evaluate_on_new_question(user_query)

                # Only the first 100 characters of each response are shown.
                print("\n=== Results ===")
                print(f"Teacher: {result['teacher_response'][:100]}...")
                print(f"\nStudent: {result['student_response'][:100]}...")
                print(f"\nSimilarity: {result['similarity']:.4f}")

    except Exception as e:
        # Top-level boundary: report the failure with a full traceback.
        print(f"Error occurred: {e}")
        import traceback
        traceback.print_exc()


In [None]:
!pip install gradio

Collecting gradio
  Downloading gradio-5.29.0-py3-none-any.whl.metadata (16 kB)
Collecting aiofiles<25.0,>=22.0 (from gradio)
  Downloading aiofiles-24.1.0-py3-none-any.whl.metadata (10 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.12-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.5.0-py3-none-any.whl.metadata (3.0 kB)
Collecting gradio-client==1.10.0 (from gradio)
  Downloading gradio_client-1.10.0-py3-none-any.whl.metadata (7.1 kB)
Collecting groovy~=0.1 (from gradio)
  Downloading groovy-0.1.2-py3-none-any.whl.metadata (6.1 kB)
Collecting pydub (from gradio)
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting python-multipart>=0.0.18 (from gradio)
  Downloading python_multipart-0.0.20-py3-none-any.whl.metadata (1.8 kB)
Collecting ruff>=0.9.3 (from gradio)
  Downloading ruff-0.11.8-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (25 kB)
Collecting safehttpx<0.2.0,>=0.1.6

In [None]:
import gradio as gr

def _preview(text, limit=100):
    """Return *text* cut to *limit* characters, adding '...' only if it was cut."""
    if not text:
        return "No response"
    if len(text) > limit:
        return text[:limit] + "..."
    return text


def interactive_demo(user_query):
    """Gradio callback: answer *user_query* with both models and compare them.

    Relies on the module-level ``med_refine`` object providing
    ``evaluate_on_new_question``.

    Args:
        user_query: Question text entered in the UI.

    Returns:
        Tuple ``(teacher_preview, student_preview, similarity)`` of strings.
        Previews are limited to 100 characters and carry a trailing "..." only
        when the response was actually truncated. On any failure the first
        element holds the error message and the other two are empty.
    """
    try:
        result = med_refine.evaluate_on_new_question(user_query)

        teacher_response = _preview(result['teacher_response'])
        student_response = _preview(result['student_response'])
        similarity = f"{result['similarity']:.4f}"

        return teacher_response, student_response, similarity
    except Exception as e:
        # Surface the problem in the UI instead of failing the request.
        return f"Error: {e}", "", ""

# Gradio entry point: run training via main(), print a summary, then serve
# interactive_demo through a web UI.
#
# NOTE(review): this cell unpacks main() into (med_refine, stats), but the
# most recent main() defined in this notebook returns a single answer_query
# callable. Run top to bottom, the unpacking raises and the except branch below
# prints the traceback, so the UI is never launched. It appears to depend on an
# earlier main() that returned the trained system and its statistics — confirm.
if __name__ == "__main__":
    try:
        # interactive_demo reads med_refine as a module-level name, so it must
        # be bound here before the interface is used.
        med_refine, stats = main()

        print("\n=== MedRefine Training Summary ===")
        print(f"Iterations completed: {len(stats['training_stats']['iteration'])}")
        print(f"Final similarity score: {stats['final_test_similarity']:.4f}")
        print(f"Final threshold: {stats['final_threshold']:.4f}")
        print(f"Output directory: {med_refine.config.output_dir}")

        # Launch Gradio UI: one question box in, three text boxes out, matching
        # the three values returned by interactive_demo.
        demo = gr.Interface(
            fn=interactive_demo,
            inputs=gr.Textbox(lines=2, placeholder="Enter a medical question..."),
            outputs=[
                gr.Textbox(label="Teacher Response"),
                gr.Textbox(label="Student Response"),
                gr.Textbox(label="Similarity Score"),
            ],
            title="MedRefine Interactive Demo",
            description="Ask a medical question and see how the student and teacher models respond.",
        )

        demo.launch()

    except Exception as e:
        # Top-level boundary: report the failure with a full traceback.
        print(f"Error occurred: {e}")
        import traceback
        traceback.print_exc()


In [None]:
!zip -r medrefine_output.zip medrefine_output


zip error: Nothing to do! (try: zip -r medrefine_output.zip . -i medrefine_output)


In [None]:
from tensorboardX import SummaryWriter
from sentence_transformers import util, SentenceTransformer
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# Shared TensorBoard writer for the metric cells below; it is configured to
# log under ./runs/src_kd_metrics.
writer = SummaryWriter(logdir="./runs/src_kd_metrics")


In [None]:
def compute_cosine_metrics(student_outputs, teacher_outputs, model):
    """Score how closely each student output matches its paired teacher output.

    Outputs are paired positionally; if one sequence is longer, its surplus
    items are ignored (same as iterating ``zip`` over both).

    Args:
        student_outputs: Iterable of student response strings.
        teacher_outputs: Iterable of teacher response strings.
        model: Encoder exposing ``encode(texts, convert_to_tensor=True)`` that
            returns one embedding row per input text (e.g. a
            ``SentenceTransformer``).

    Returns:
        Tuple ``(avg_similarity, similarities)``: the mean cosine similarity
        as a float (NaN when there are no pairs) and the per-pair scores as a
        list of floats.
    """
    pairs = list(zip(student_outputs, teacher_outputs))
    if not pairs:
        # Nothing to compare; avoid NumPy's "mean of empty slice" warning.
        return float("nan"), []

    student_texts = [student for student, _ in pairs]
    teacher_texts = [teacher for _, teacher in pairs]

    # Encode each side in a single batch instead of one call per text.
    student_emb = model.encode(student_texts, convert_to_tensor=True)
    teacher_emb = model.encode(teacher_texts, convert_to_tensor=True)

    # Row-wise cosine similarity: pair i is compared only with pair i.
    similarities = torch.nn.functional.cosine_similarity(
        student_emb, teacher_emb, dim=1
    ).tolist()

    avg_similarity = float(np.mean(similarities))
    return avg_similarity, similarities

def compute_acceptance_rate(similarities, threshold):
    """Return the fraction of similarity scores that meet the threshold.

    Args:
        similarities: Sequence of similarity scores.
        threshold: Minimum score (inclusive) for a response to count as
            accepted.

    Returns:
        Accepted count divided by total count, or 0.0 when *similarities* is
        empty (nothing was accepted).
    """
    total = len(similarities)
    if total == 0:
        # Avoid ZeroDivisionError on an empty batch.
        return 0.0
    accepted = sum(1 for sim in similarities if sim >= threshold)
    return accepted / total

def compute_kl_divergence(student_logits, teacher_logits, temperature=1.0):
    """Compute KL(teacher || student) between the two softmax distributions.

    Args:
        student_logits: Tensor of raw student scores; softmax is taken over
            the last dimension.
        teacher_logits: Tensor of raw teacher scores with a matching layout.
        temperature: Softening temperature applied to both sets of logits.
            The result is multiplied by ``temperature ** 2``, the usual
            scaling in knowledge distillation. The default 1.0 reproduces the
            plain, unsoftened divergence.

    Returns:
        The divergence as a Python float, summed over all elements and divided
        by the size of the first dimension (``reduction='batchmean'``).

    Raises:
        ValueError: If *temperature* is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    # F.kl_div expects log-probabilities as input and probabilities as target.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    kl_div = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
    return kl_div.item() * temperature ** 2

def plot_similarity_distribution(similarities, title="Similarity Score Distribution"):
    """Show a 30-bin histogram of cosine similarity scores.

    Args:
        similarities: Similarity scores to bin.
        title: Heading displayed above the histogram.
    """
    histogram_style = {"bins": 30, "color": 'skyblue', "edgecolor": 'black'}
    plt.hist(similarities, **histogram_style)

    # Decorations are independent of each other; all are applied before show().
    plt.grid(True)
    plt.ylabel("Frequency")
    plt.xlabel("Cosine Similarity")
    plt.title(title)

    plt.show()


In [None]:
# Demo with placeholder data — swap these in for the real training loop outputs.
teacher_outputs = ["This is the expected medical answer"] * 10
student_outputs_before = ["This is a sample answer"] * 10
student_outputs_after = ["This is a better refined answer"] * 10
sentence_bert = SentenceTransformer("pritamdeka/S-BioBert-snli-multinli-stsb")

# Similarity to the teacher before and after refinement.
before_sim, before_all = compute_cosine_metrics(student_outputs_before, teacher_outputs, sentence_bert)
after_sim, after_all = compute_cosine_metrics(student_outputs_after, teacher_outputs, sentence_bert)

# Share of refined answers that reach the acceptance threshold.
threshold = 0.75
acceptance_rate = compute_acceptance_rate(after_all, threshold)

# Random logits stand in for real model outputs here.
student_logits = torch.randn(10, 30522)
teacher_logits = torch.randn(10, 30522)
kl_loss = compute_kl_divergence(student_logits, teacher_logits)

# Log every scalar at step 0, in a fixed tag order.
_scalars = (
    ("Similarity/Before_Refinement", before_sim),
    ("Similarity/After_Refinement", after_sim),
    ("Similarity/Improvement", after_sim - before_sim),
    ("Refinement/Acceptance_Rate", acceptance_rate),
    ("Divergence/KL_Loss", kl_loss),
)
for _tag, _value in _scalars:
    writer.add_scalar(_tag, _value, 0)

writer.close()
print("Metrics logged to TensorBoard. Run `tensorboard --logdir=runs` to view.")
