In [None]:
# Install required packages
!pip install -q transformers torch

import json
import pandas as pd
import os
import time
from datetime import datetime
import numpy as np
import re
from google.colab import drive
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import glob

class MIRModelComparator:
    """Compare the base FLAN-T5 model with fine-tuned variants on MIR multiple-choice questions.

    For every model in ``models_to_evaluate`` the comparator:
    1. loads the model,
    2. asks it each question from ``questions_file``,
    3. extracts the chosen option letter,
    4. writes per-question results plus aggregated metrics under ``output_dir``.
    """

    # Option letters that count as a well-formed answer.
    VALID_LETTERS = ('A', 'B', 'C', 'D')

    # Deterministic decoding (beam search, no sampling). Repeated runs of the
    # same model on the same question must give the same answer; otherwise the
    # reported accuracies are not reproducible.
    GENERATION_KWARGS = {
        'max_length': 50,
        'num_beams': 4,
        'do_sample': False,
        'early_stopping': False,
    }

    # A standalone upper-case option letter. Word boundaries prevent matching
    # letters inside words such as "ANSWER" or the Spanish article "LA".
    _LETTER_RE = re.compile(r'\b([ABCD])\b')

    def __init__(self, questions_file: str):
        """Initialize the comparator with questions file.

        Args:
            questions_file: Path to a JSON file containing a list of question
                dicts (read in ``evaluate_model``).
        """
        self.questions_file = questions_file
        self.base_model_path = "google/flan-t5-base"
        self.models_base_dir = "/content/drive/MyDrive/TFM2/models"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Find latest model versions
        self.models_to_evaluate = self._find_latest_models()
        print("\nModels to evaluate:")
        for model_name, path in self.models_to_evaluate.items():
            print(f"{model_name}: {path}")

    def _find_latest_models(self):
        """Return ``{model_name: path}`` for the vanilla model and the latest trained ones.

        The vanilla entry is always present. The ``qa_trained`` and
        ``fulltext_trained`` entries are only added when a matching model
        directory exists under ``models_base_dir``.
        """
        models = {'vanilla': self.base_model_path}

        for key, label, prefix in (
            ('qa_trained', 'QA', 'mir_flan_t5_qa_v2_'),
            ('fulltext_trained', 'fulltext', 'mir_flan_t5_fulltext_v2_'),
        ):
            latest = self._latest_model_dir(prefix)
            if latest:
                models[key] = latest
                print(f"Found latest {label} model: {latest}")
            else:
                print(f"No {label} models found!")

        return models

    def _latest_model_dir(self, prefix):
        """Return the most recent directory in ``models_base_dir`` named ``prefix*``, or None."""
        # Only directories can hold a saved model; ignore stray files (archives,
        # logs) that happen to share the prefix.
        candidates = [
            path
            for path in glob.glob(os.path.join(self.models_base_dir, prefix + "*"))
            if os.path.isdir(path)
        ]
        if not candidates:
            return None
        # NOTE(review): on Linux, getctime is the inode-change time, not the
        # creation time. The directory names embed a YYYYMMDD_HHMM timestamp,
        # which may be a more reliable ordering key on Drive mounts — confirm.
        return max(candidates, key=os.path.getctime)

    def load_model(self, model_path: str):
        """Load a seq2seq model and its tokenizer, and move the model to ``self.device``.

        Args:
            model_path: Hugging Face hub id or local directory.

        Returns:
            ``(model, tokenizer)``.

        Raises:
            Re-raises any exception from ``from_pretrained`` after printing it.
        """
        print(f"\nLoading model from: {model_path}")
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
            model = model.to(self.device)
            return model, tokenizer
        except Exception as e:
            print(f"Error loading model from {model_path}: {str(e)}")
            raise

    def format_prompt(self, question):
        """Build the instruction prompt for one question dict.

        Reads ``question['question']`` and ``question['options']`` (a mapping
        with keys 'A'..'D'). Missing fields become empty strings.
        """
        question_text = question.get('question', '')
        options = question.get('options', {})

        prompt = f"""Task: Medical Multiple Choice Question
Instructions: Select the correct answer (A, B, C, or D).
Respond only with the letter of the correct answer.

Question: {question_text}

Options:
A) {options.get('A', '')}
B) {options.get('B', '')}
C) {options.get('C', '')}
D) {options.get('D', '')}

Your answer (A/B/C/D):"""

        return prompt

    @classmethod
    def _extract_letter(cls, response):
        """Return the option letter ('A'-'D') chosen in ``response``, or None.

        Two checks are applied in order:
        - A bare reply such as "b", "C)" or "(D)." is accepted case-insensitively.
        - Otherwise the first standalone upper-case A-D is taken. Letters inside
          words are ignored, so "The answer is B" yields "B", not the "A" of "answer".
        """
        text = (response or '').strip()
        bare = text.strip('()[].: ').upper()
        if bare in cls.VALID_LETTERS:
            return bare
        match = cls._LETTER_RE.search(text)
        return match.group(1) if match else None

    def get_model_response(self, model, tokenizer, prompt):
        """Generate the model's answer for ``prompt``.

        Returns:
            A dict with the decoded ``full_response`` and the
            ``extracted_letter`` (None if no option letter was found).
        """
        inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
        inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}

        with torch.no_grad():
            outputs = model.generate(**inputs, **self.GENERATION_KWARGS)

        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        return {
            'full_response': full_response,
            'extracted_letter': self._extract_letter(full_response),
        }

    def evaluate_model(self, model_name: str, model, tokenizer, test_run: bool = False):
        """Ask ``model`` every question (only the first 5 if ``test_run``) and collect results.

        A failure on one question is recorded as a result row with an
        ``error`` message; it does not abort the run.

        Returns:
            A list of per-question result dicts.
        """
        print(f"\nEvaluating model: {model_name}")

        # Load questions
        with open(self.questions_file, 'r', encoding='utf-8') as f:
            questions = json.load(f)
        print(f"Loaded {len(questions)} questions")

        questions_to_evaluate = questions[:5] if test_run else questions
        results = []

        for i, question in enumerate(questions_to_evaluate):
            try:
                start_time = time.perf_counter()
                prompt = self.format_prompt(question)

                print(f"\nProcessing Question {i+1}/{len(questions_to_evaluate)}:")
                print(f"Question: {question['question']}")

                response = self.get_model_response(model, tokenizer, prompt)
                elapsed = time.perf_counter() - start_time

                model_answer = response['extracted_letter']
                # Normalise the gold label so that values like "b" or " B" still
                # match. Presumably labels are already upper-case letters — confirm.
                expected = str(question['correct_answer']).strip().upper()
                is_correct = model_answer is not None and model_answer == expected

                result = {
                    'question_id': question.get('id', f'Q{i+1}'),
                    'model': model_name,
                    'question': question['question'],
                    'full_response': response['full_response'],
                    'extracted_letter': model_answer,
                    'correct_answer': question['correct_answer'],
                    'correct': is_correct,
                    'time': elapsed,
                    'error': None if model_answer in self.VALID_LETTERS else 'Invalid answer format'
                }

                results.append(result)

                print("\nResults:")
                print(f"Full response: {response['full_response']}")
                print(f"Extracted answer: {model_answer}")
                print(f"Correct answer: {question['correct_answer']}")
                print(f"Correct: {is_correct}")
                print(f"Time: {result['time']:.2f}s")
                print("-" * 40)

            except Exception as e:
                error_msg = str(e)
                print(f"Error on question {i+1}: {error_msg}")
                results.append({
                    'question_id': question.get('id', f'Q{i+1}'),
                    'model': model_name,
                    'question': question.get('question', ''),
                    'full_response': None,
                    'extracted_letter': None,
                    'correct_answer': question.get('correct_answer', ''),
                    'correct': False,
                    'time': None,
                    'error': error_msg
                })

        return results

    @classmethod
    def _compute_metrics(cls, results):
        """Aggregate per-question result dicts into summary metrics for one model.

        An empty ``results`` list yields all-zero metrics instead of raising.
        """
        df_model = pd.DataFrame(results)
        if df_model.empty:
            return {
                'total_questions': 0,
                'valid_responses': 0,
                'invalid_responses': 0,
                'correct': 0,
                'accuracy': 0.0,
                'valid_accuracy': 0.0,
                'avg_time': 0.0,
            }

        valid_answers = df_model[df_model['extracted_letter'].isin(cls.VALID_LETTERS)]
        total = len(df_model)
        n_valid = len(valid_answers)
        n_correct = int(df_model['correct'].sum())

        return {
            'total_questions': total,
            'valid_responses': n_valid,
            'invalid_responses': total - n_valid,
            'correct': n_correct,
            'accuracy': float(n_correct / total),
            # Accuracy restricted to questions where a letter could be extracted.
            'valid_accuracy': float(valid_answers['correct'].sum() / n_valid) if n_valid else 0.0,
            # Failed questions have time=None; skip them in the mean.
            'avg_time': float(df_model['time'].mean()) if not df_model['time'].isna().all() else 0.0,
        }

    def run_comparison(self, test_run: bool = False, output_dir: str = '/content/drive/MyDrive/TFM2/TFM-DATASETS/evaluations'):
        """Evaluate every model and save detailed and comparative results.

        Writes three timestamped files to ``output_dir``: detailed results
        (CSV), metrics (JSON) and a comparison table (CSV). A model that fails
        to load or evaluate is reported and skipped.

        Returns:
            ``{model_name: metrics_dict}`` for the models evaluated successfully.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        os.makedirs(output_dir, exist_ok=True)

        all_results = []
        comparative_metrics = {}

        for model_name, model_path in self.models_to_evaluate.items():
            model = tokenizer = None
            try:
                model, tokenizer = self.load_model(model_path)
                results = self.evaluate_model(model_name, model, tokenizer, test_run)
                all_results.extend(results)

                # Calculate metrics for this model
                metrics = self._compute_metrics(results)
                comparative_metrics[model_name] = metrics

                print(f"\nResults for {model_name}:")
                print(f"Total Questions: {metrics['total_questions']}")
                print(f"Valid Responses: {metrics['valid_responses']}")
                print(f"Correct Answers: {metrics['correct']}")
                print(f"Overall Accuracy: {metrics['accuracy']:.2%}")
                print(f"Valid Answers Accuracy: {metrics['valid_accuracy']:.2%}")

            except Exception as e:
                print(f"Error evaluating {model_name}: {str(e)}")
            finally:
                # Release this model before loading the next one, so that
                # two models are never resident on the device at once.
                del model, tokenizer
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        # Save all results
        df_all = pd.DataFrame(all_results)
        base_path = os.path.join(output_dir, f"comparative_evaluation_{timestamp}")

        df_all.to_csv(f"{base_path}_detailed_results.csv", index=False)

        with open(f"{base_path}_comparative_metrics.json", 'w', encoding='utf-8') as f:
            json.dump(comparative_metrics, f, indent=2)

        # Create comparison table
        comparison_df = pd.DataFrame.from_dict(comparative_metrics, orient='index')
        comparison_df.to_csv(f"{base_path}_comparison_table.csv")

        print("\nComparative Results:")
        print(comparison_df)

        return comparative_metrics

def main():
    """Mount Google Drive, run a 5-question smoke test, then optionally the full evaluation."""
    # Mount Google Drive
    drive.mount('/content/drive')

    # Initialize comparator
    questions_file = '/content/drive/MyDrive/TFM2/TFM-DATASETS/structured_questions.json'
    comparator = MIRModelComparator(questions_file)

    # Run test evaluation first; its results are saved to disk by run_comparison.
    print("\nRunning test evaluation with 5 questions...")
    comparator.run_comparison(test_run=True)

    # Strip the reply so inputs such as " y" or "Y " are still accepted.
    if input("\nContinue with full evaluation? (y/n): ").strip().lower() == 'y':
        print("\nRunning full evaluation...")
        comparator.run_comparison(test_run=False)


if __name__ == "__main__":
    main()

Mounted at /content/drive
Using device: cpu
Found latest QA model: /content/drive/MyDrive/TFM2/models/mir_flan_t5_qa_v2_20250119_1623
Found latest fulltext model: /content/drive/MyDrive/TFM2/models/mir_flan_t5_fulltext_v2_20250119_1623

Models to evaluate:
vanilla: google/flan-t5-base
qa_trained: /content/drive/MyDrive/TFM2/models/mir_flan_t5_qa_v2_20250119_1623
fulltext_trained: /content/drive/MyDrive/TFM2/models/mir_flan_t5_fulltext_v2_20250119_1623

Running test evaluation with 5 questions...

Loading model from: google/flan-t5-base


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.40k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m

Results:
Full response: A
Extracted answer: A
Correct answer: A
Correct: True
Time: 2.21s
----------------------------------------

Processing Question 73/174:
Question: Mujer de 52 años, carnicera de profesión, sin antecedentes de interés ni caídas, que presenta dolor en el hombro de 4 meses de evolución al levantar el brazo. El dolor es de características inflamatorias y las maniobras de impingement positivas. Señale la actitud INCORRECTA:

Results:
Full response: A
Extracted answer: A
Correct answer: B
Correct: False
Time: 1.87s
----------------------------------------

Processing Question 74/174:
Question: Mujer de 56 años que consulta porque hace 4 meses se torció el tobillo en la playa y desde entonces no ha dejado de molestarle. Se cansa subiendo escaleras y le cuesta usar calzado plano. En la exploración presenta dolor en el seno del tarso, el talón y la cara medial del tobillo. Al examen podoscópico presenta el 