# Get Huggingface Data

In [1]:
# install Huggingface datasets
!pip install datasets

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (

In [2]:

# Load the Med-HALT dataset, using the `reasoning_FCT` configuration.
# This subset contains multiple-choice questions.
# Source:
# https://huggingface.co/datasets/openlifescienceai/Med-HALT/viewer/reasoning_FCT
# The train split holds ~18.9k records, far more than this notebook needs.

from datasets import load_dataset

# load_dataset returns a mapping of split name -> split; only 'train' is used here.
reasoning_FCT = load_dataset("openlifescienceai/Med-HALT", "reasoning_FCT")
train_data = reasoning_FCT['train']

# Report the number of records in the split.
print("No. of Records in Train:", len(train_data))
print("Example:\n")
# Display the first record (last expression of the cell) to inspect the schema.
train_data[0]

README.md:   0%|          | 0.00/4.80k [00:00<?, ?B/s]

reasoning_FCT/reasoning_FCT.csv:   0%|          | 0.00/9.93M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/18866 [00:00<?, ? examples/s]

No. of Records in Train: 18866
Example:



{'id': 'bc8659f4-3062-4f57-9e24-e32ad92a8d4e',
 'dataset': 'headqa_en',
 'question': 'Which of the following structural elements is characteristic of the ortopramide group drugs?',
 'options': "{'0': 'They are anilides with propyl group in ortho.', '1': 'They are benzamides with methoxy group in ortho.', '2': 'They are benzenesulfonamides with a methyl group in ortho.', '3': 'They are ortho-halogenated derivatives of phenothiazine.', 'correct answer': 'They are ortho-halogenated derivatives of phenothiazine.'}",
 'correct_answer': 'They are benzamides with methoxy group in ortho.',
 'correct_index': 1,
 'split_type': 'val',
 'subject_name': 'pharmacology',
 'topic_name': None,
 'year': 2015.0,
 'exam_name': 'Cuaderno_2015_1_F',
 'student_answer': 'They are ortho-halogenated derivatives of phenothiazine.',
 'student_index': 3}

In [3]:
# How many subjects are there, and how many records per subject

# Using pandas
def count_subjects_pandas(records):
    """
    Count how many records belong to each subject.

    Args:
        records: Anything `pd.DataFrame` can consume that yields a
            'subject_name' column (e.g. a list of dicts).

    Returns:
        pd.DataFrame: Columns 'subject_name' and 'count', ordered from the
        most to the least frequent subject. Records whose subject_name is
        missing are not counted (pandas' groupby drops missing keys by default).
    """
    import pandas as pd

    # One row per distinct subject, holding the number of records in it.
    per_subject = pd.DataFrame(records).groupby('subject_name').size()

    return (
        per_subject
        .reset_index(name='count')
        .sort_values('count', ascending=False)
    )

# Tabulate the number of records per subject in the train split and show the table.
results_pandas = count_subjects_pandas(train_data)
print(results_pandas)

                    subject_name  count
3                         Dental   2401
19                       Surgery    803
23                      medicine    684
24                       nursery    681
25                  pharmacology    679
6       Gynaecology & Obstetrics    679
26                    psychology    678
22                     chemistry    674
21                       biology    672
20                       Unknown    646
11                     Pathology    601
7                       Medicine    577
14                    Physiology    528
13                  Pharmacology    521
2                   Biochemistry    477
1                        Anatomy    427
12                    Pediatrics    358
18  Social & Preventive Medicine    353
8                   Microbiology    268
9                  Ophthalmology    209
5              Forensic Medicine    182
16                     Radiology    174
4                            ENT    117
0                    Anaesthesia     87


Each record above contains the question, the answer options (which will be appended to the prompt), and the target label, `correct_index`.

To build the prompt and extract the matching target, we use the function below.

## Preprocess data

Appending the answer options to the question

In [4]:
import json
import ast

def create_prompt(example):
    """
    Build the multiple-choice prompt for one dataset record.

    The prompt consists of a fixed instruction, the question, and the answer
    options serialised as JSON. The 'correct answer' entry stored inside the
    raw options is removed so the answer is not leaked to the model.

    Args:
        example (dict): Single record providing 'question', 'options' (a dict,
            or the string repr of a dict) and 'correct_index'.

    Returns:
        tuple: (formatted_prompt, correct_index), or (None, None) when the
        options cannot be parsed into a dict or cannot be serialised.
    """
    # Fixed instruction placed in front of every question
    introduction = "You are a medical expert and this is a multiple choice exam question. Please respond with the integer index of the CORRECT answer only; [0,1,2,3]."

    question = example['question']

    try:
        raw_options = example['options']

        # Options may arrive as the string repr of a dict or as a real dict
        if isinstance(raw_options, str):
            options_dict = ast.literal_eval(raw_options)
        else:
            options_dict = raw_options

        # literal_eval also accepts lists, numbers, ...; only a dict is usable here
        if not isinstance(options_dict, dict):
            raise ValueError(
                f"expected a dict of options, got {type(options_dict).__name__}"
            )

        # Drop the embedded answer key so it never reaches the prompt
        options_filtered = {k: v for k, v in options_dict.items() if k != 'correct answer'}

        # ensure_ascii=False keeps characters such as 'µ' or '°' readable in the
        # prompt instead of turning them into '\uXXXX' escape sequences
        options_formatted = "Options: " + json.dumps(options_filtered, ensure_ascii=False)

    except (ValueError, SyntaxError, TypeError) as e:
        print(f"Error parsing options: {e}")
        return None, None

    # Combine all parts
    prompt = f"{introduction}\n\n{question}\n\n{options_formatted}"

    return prompt, example['correct_index']

def process_dataset(data):
    """
    Turn every record of a dataset into a (prompt, label) pair.

    Records for which `create_prompt` yields no prompt are left out, so the
    two returned lists always stay aligned.

    Args:
        data: Iterable of example records

    Returns:
        tuple: (X, y) where X is the list of prompts and y the list of labels
    """
    X, y = [], []

    for record in data:
        built = create_prompt(record)

        # A missing prompt signals that the record could not be parsed
        if built[0] is None:
            continue

        X.append(built[0])
        y.append(built[1])

    return X, y

In [5]:
# Build the prompts (X) and target indices (y) for the whole train split.

# `train_data` is the Med-HALT train split loaded earlier in this notebook.
X, y = process_dataset(train_data)

# Show the first prompt and its label as a sanity check:
print("Example prompt:")
print(X[0])
print("\nCorrect index:", y[0])

Example prompt:
You are a medical expert and this is a multiple choice exam question. Please respond with the integer index of the CORRECT answer only; [0,1,2,3].

Which of the following structural elements is characteristic of the ortopramide group drugs?

Options: {"0": "They are anilides with propyl group in ortho.", "1": "They are benzamides with methoxy group in ortho.", "2": "They are benzenesulfonamides with a methyl group in ortho.", "3": "They are ortho-halogenated derivatives of phenothiazine."}

Correct index: 1


# Get Goodfire


In [6]:
!pip install goodfire

Collecting goodfire
  Downloading goodfire-0.2.39-py3-none-any.whl.metadata (1.2 kB)
Collecting ipywidgets<9.0.0,>=8.1.5 (from goodfire)
  Downloading ipywidgets-8.1.5-py3-none-any.whl.metadata (2.3 kB)
Collecting comm>=0.1.3 (from ipywidgets<9.0.0,>=8.1.5->goodfire)
  Downloading comm-0.2.2-py3-none-any.whl.metadata (3.7 kB)
Collecting widgetsnbextension~=4.0.12 (from ipywidgets<9.0.0,>=8.1.5->goodfire)
  Downloading widgetsnbextension-4.0.13-py3-none-any.whl.metadata (1.6 kB)
Collecting jedi>=0.16 (from ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire)
  Downloading jedi-0.19.2-py2.py3-none-any.whl.metadata (22 kB)
Downloading goodfire-0.2.39-py3-none-any.whl (28 kB)
Downloading ipywidgets-8.1.5-py3-none-any.whl (139 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m139.8/139.8 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading comm-0.2.2-py3-none-any.whl (7.2 kB)
Downloading widgetsnbextension-4.0.13-py3-none-any.whl (2.3 MB)
[2K   [90m━━━━━━━━━━━━━━

## Extract features using Goodfire API

Submit the Med-HALT prompts, get top_k features.

Using top_k = 50, as per example at: https://docs.goodfire.ai/examples/decision_trees.html

In [7]:
from google.colab import userdata
import goodfire
import re
import json

# Read the Goodfire API key from Colab's secret store (google.colab.userdata)
# so that it is never hard-coded in the notebook.
api_key = userdata.get('GOODFIRE_API_KEY')

# Create the Goodfire client and the model variant to query.
# NOTE(review): get_integer_response builds its own client/variant on every call,
# so these two objects are presumably meant for later cells (e.g. LLMEvaluator) — confirm.
client  = goodfire.Client(api_key)
variant = goodfire.Variant("meta-llama/Meta-Llama-3-8B-Instruct")



In [8]:
def extract_first_valid_integer(text):
    """
    Read the first integer that appears in `text` and accept it only if it is
    one of the option indices 0, 1, 2 or 3.

    Only the first run of digits is inspected; later numbers are ignored even
    when the first one is rejected.

    Args:
        text (str): Text to search for an integer

    Returns:
        int: The first integer if it is in [0,1,2,3], otherwise -1
        (also -1 when the text contains no digits at all)
    """
    first_number = re.search(r'\d+', text)

    # No digits anywhere in the text
    if first_number is None:
        return -1

    value = int(first_number.group())
    return value if value in [0, 1, 2, 3] else -1

def get_integer_response(api_key, prompt):
    """
    Submit a prompt to the Goodfire API and parse the answer index from the reply.

    A new client and model variant are created on every call. Errors raised by
    the API request itself are not caught here and propagate to the caller.

    Args:
        api_key (str): Goodfire API key
        prompt (str): Prompt to send to the model

    Returns:
        int: The option index (0-3) found in the reply, or -1 when the reply
        has an unexpected structure or contains no valid index
    """
    # Initialize Goodfire client
    client = goodfire.Client(api_key)

    # Initialize model variant
    variant = goodfire.Variant("meta-llama/Meta-Llama-3-8B-Instruct")

    # Non-streaming completion; the reply only has to contain a single integer
    response = client.chat.completions.create(
        messages=[
            {"role": "user", "content": prompt}
        ],
        model=variant,
        stream=False,
        max_completion_tokens=50  # Keep response short since we just need an integer
    )

    try:
        content = response.choices[0].message['content']
        return extract_first_valid_integer(content)
    except (AttributeError, IndexError, KeyError, TypeError) as e:
        # KeyError: message without 'content'; TypeError: message not subscriptable
        # or content that is not a string (e.g. None) reaching the regex search
        print(f"Error extracting content from response: {e}")
        print(f"Full response: {response}")
        return -1

In [9]:
# Smoke test: send the first prompt to the model and show the parsed answer
# (-1 means no valid option index could be extracted from the reply).
result = get_integer_response(api_key, X[0])
print(f"First valid integer found: {result}")

First valid integer found: 1


In [10]:
import numpy as np
import asyncio
import random
from concurrent.futures import ThreadPoolExecutor
from sklearn.metrics import accuracy_score, confusion_matrix, cohen_kappa_score
from typing import List, Tuple, Optional
import pandas as pd

def sample_data(X, y, k, random_seed = None):
    """
    Randomly sample k items from X and y, maintaining paired relationships.

    Args:
        X: Sequence of prompts
        y: Sequence of labels, aligned with X
        k (int): Number of pairs to draw (without replacement)
        random_seed: Seed for a reproducible draw; None uses the global
            `random` state

    Returns:
        tuple: (X_sampled, y_sampled, indices), where `indices` are the
        positions in X/y that were drawn

    Raises:
        ValueError: If X and y differ in length, or k exceeds len(X)
    """
    if len(X) != len(y):
        raise ValueError(
            f"X and y must have the same length (got {len(X)} and {len(y)})"
        )

    # A private generator yields the same draw as seeding the module-level one,
    # but leaves the global random state untouched for the rest of the notebook.
    rng = random.Random(random_seed) if random_seed is not None else random

    indices = rng.sample(range(len(X)), k)
    X_sampled = [X[i] for i in indices]
    y_sampled = [y[i] for i in indices]

    return X_sampled, y_sampled, indices

def evaluate_model_concurrent(api_key, X_sample, y_sample, max_workers = 10):
    """
    Query the model for every prompt using a thread pool and score the answers.

    Accuracy, Cohen's kappa and per-class accuracy are computed over the valid
    predictions only (those that are not -1). The returned DataFrame keeps every
    sample, and marks failed predictions as incorrect.

    Args:
        api_key (str): Goodfire API key
        X_sample: Prompts to send
        y_sample: True option indices, aligned with X_sample
        max_workers (int): Maximum number of concurrent API calls

    Returns:
        tuple: (accuracy, kappa, results_df)
    """
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # One task per prompt
        futures = [
            executor.submit(get_integer_response, api_key, prompt)
            for prompt in X_sample
        ]

        # Collect in submission order so predictions stay aligned with y_sample
        y_pred = []
        for future in futures:
            try:
                result = future.result()
                y_pred.append(result if result is not None else -1)  # Use -1 for failed predictions
            except Exception as e:
                print(f"Error in API call: {e}")
                y_pred.append(-1)

    # Keep only the samples for which a usable prediction came back
    valid_indices = [i for i, pred in enumerate(y_pred) if pred != -1]
    y_pred_valid = [y_pred[i] for i in valid_indices]
    y_sample_valid = [y_sample[i] for i in valid_indices]

    if len(valid_indices) > 0:
        accuracy = accuracy_score(y_sample_valid, y_pred_valid)
        kappa = cohen_kappa_score(y_sample_valid, y_pred_valid)

        conf_matrix = confusion_matrix(y_sample_valid, y_pred_valid, labels=[0,1,2,3])

        # Row sums are the number of true samples per class; a class that does not
        # occur gets NaN without triggering a divide-by-zero warning.
        class_totals = conf_matrix.sum(axis=1)
        per_class_accuracy = np.divide(
            conf_matrix.diagonal(),
            class_totals,
            out=np.full(len(class_totals), np.nan),
            where=class_totals > 0,
        )
    else:
        accuracy = 0
        kappa = 0
        per_class_accuracy = [0, 0, 0, 0]

    # Per-sample results, including failed predictions
    results_df = pd.DataFrame({
        'prompt': X_sample,
        'true_answer': y_sample,
        'predicted_answer': y_pred,
        'correct': [pred == true for pred, true in zip(y_pred, y_sample)]
    })

    # Print evaluation summary
    print(f"\nEvaluation Results:")
    print(f"Total samples: {len(X_sample)}")
    print(f"Valid predictions: {len(valid_indices)}")
    print(f"Accuracy: {accuracy:.3f}")
    print(f"Cohen's Kappa: {kappa:.3f}")
    print("\nPer-class accuracy:")
    for i, acc in enumerate(per_class_accuracy):
        print(f"Class {i}: {acc:.3f}")

    return accuracy, kappa, results_df

def run_evaluation(api_key, X, y, k, random_seed= None, max_workers = 10):
    """
    Draw k random (prompt, label) pairs and evaluate the model on them.

    Returns whatever `evaluate_model_concurrent` returns for the drawn sample.
    """
    # sample_data yields (prompts, labels, indices); the indices are not needed here
    sampled = sample_data(X, y, k, random_seed)

    return evaluate_model_concurrent(api_key, sampled[0], sampled[1], max_workers)


## Is LLM performance better than random?

There are only 4 choices, so even random guessing will score about 25%. We therefore test whether the model's accuracy is significantly above that random baseline.

In [23]:
import scipy
import numpy as np
from collections import Counter
from typing import Tuple, Dict


def find_min_successes(n: int, p: float, alpha: float = 0.05) -> int:
    """
    Find the minimum number of successes needed for statistical significance.

    Searches for the smallest k in [0, n] whose one-sided ('greater') binomial
    test p-value is at most `alpha`. The p-value never increases as k grows, so
    a lower-bound binary search applies.

    Args:
        n: Number of trials
        p: Success probability under the null hypothesis
        alpha: Significance level

    Returns:
        int: The smallest significant number of successes, or n + 1 when not
        even n successes out of n reach significance
    """
    # Search the whole range: starting at the expected value n*p would miss
    # thresholds that lie below it (possible for large alpha).
    lo, hi = 0, n + 1  # hi = n + 1 stands for "no k is significant"

    while lo < hi:
        mid = (lo + hi) // 2
        p_value = scipy.stats.binomtest(mid, n, p, alternative='greater').pvalue

        if p_value <= alpha:
            hi = mid        # mid is significant; the minimum is mid or smaller
        else:
            lo = mid + 1    # mid is not significant; the minimum is larger

    return lo

def calculate_random_baseline(y_true: list) -> Tuple[float, Dict]:
    """
    Calculate the random-guess baseline accuracy from the class distribution.

    The baseline is the sum of squared class proportions, i.e. the expected
    accuracy of a guesser that picks each class with the frequency it has in
    `y_true`. For four perfectly balanced classes this is 0.25.

    Args:
        y_true: List of true labels

    Returns:
        tuple: (random_baseline, proportions) where `proportions` maps each
        label to its share of `y_true`. An empty list yields (0, {}).
    """
    class_counts = Counter(y_true)
    total = len(y_true)

    # Share of each class among the true labels
    proportions = {label: count / total for label, count in class_counts.items()}

    # P(guess == truth) when both are drawn from the class distribution
    random_baseline = sum(share * share for share in proportions.values())

    return random_baseline, proportions


def assess_performance(y_true: list, y_pred: list, alpha: float = 0.05) -> Dict:
    """
    Assess if model performance is significantly better than random chance.
    Uses a one-sided binomial test and reports Cohen's h as effect size.

    Args:
        y_true: List of true labels
        y_pred: List of predicted labels (use only valid predictions, no -1s)
        alpha: Significance level for statistical test (default 0.05)

    Returns:
        Dictionary containing test results and metrics, plus a human-readable
        report under the 'summary' key

    Raises:
        ValueError: If no samples are given or the two lists differ in length
    """
    n_samples = len(y_true)
    if n_samples == 0:
        raise ValueError("assess_performance requires at least one valid prediction")
    if len(y_pred) != n_samples:
        raise ValueError(
            f"y_true and y_pred must have the same length (got {n_samples} and {len(y_pred)})"
        )

    # Basic metrics
    n_correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    observed_accuracy = n_correct / n_samples

    # Random baseline based on the class distribution of the true labels
    random_prob, class_proportions = calculate_random_baseline(y_true)

    # One-sided binomial test: is the number of correct answers above the baseline?
    p_value = scipy.stats.binomtest(n_correct, n_samples, p=random_prob, alternative='greater').pvalue

    # Effect size (Cohen's h) between observed accuracy and baseline
    h = 2 * (np.arcsin(np.sqrt(observed_accuracy)) - np.arcsin(np.sqrt(random_prob)))

    is_significant = p_value < alpha

    # Effect size interpretation
    if abs(h) < 0.2:
        effect_size = 'negligible'
    elif abs(h) < 0.5:
        effect_size = 'small'
    elif abs(h) < 0.8:
        effect_size = 'medium'
    else:
        effect_size = 'large'

    # Minimum number of correct answers needed for significance
    min_correct = find_min_successes(n_samples, random_prob, alpha)
    min_accuracy_needed = min_correct / n_samples

    results = {
        'better_than_random': is_significant,
        'p_value': p_value,
        'observed_accuracy': observed_accuracy,
        'effect_size': h,
        'effect_size_interpretation': effect_size,
        'n_samples': n_samples,
        'n_correct': n_correct,
        'min_accuracy_needed': min_accuracy_needed,
        'random_baseline': random_prob,
        'class_distribution': class_proportions
    }

    # Exact per-class counts; rebuilding them as int(prop * n_samples) can
    # truncate to one less than the true count because of float rounding.
    class_counts = Counter(y_true)
    dist_desc = "\nClass Distribution:\n"
    for class_label, prop in sorted(class_proportions.items()):
        count = class_counts[class_label]
        dist_desc += f"Class {class_label}: {prop:.3f} ({count}/{n_samples} samples)\n"

    # Create human-readable summary
    summary = f"""
Performance Assessment:
----------------------
Observed Accuracy: {observed_accuracy:.3f} ({n_correct}/{n_samples})
Random Baseline: {random_prob:.3f} (based on class distribution)
P-value: {p_value:.4f}
Effect Size (Cohen's h): {h:.3f} ({effect_size})

{dist_desc}
Statistical Significance:
The model {'is' if is_significant else 'is not'} performing significantly better than the random baseline
(p = {p_value:.4f}, {'<' if is_significant else '>='} α = {alpha})

For {n_samples} samples, needed {min_accuracy_needed:.3f} accuracy ({min_correct} correct)
for statistical significance at α={alpha}
"""

    results['summary'] = summary

    return results

# Example usage integrated with the evaluation function
def run_evaluation_with_stats(api_key: str, X: list, y: list, k: int,
                            random_seed: Optional[int] = None,
                            max_workers: int = 10,
                            alpha: float = 0.05) -> Tuple[float, float, pd.DataFrame, Dict]:
    """
    Run evaluation and statistical analysis.

    Args:
        api_key: Goodfire API key
        X: All prompts
        y: All true option indices, aligned with X
        k: Number of samples to evaluate
        random_seed: Seed for a reproducible sample
        max_workers: Maximum number of concurrent API calls
        alpha: Significance level for the statistical test

    Returns:
        Tuple of (accuracy, kappa, results DataFrame, statistical results).
        When no prediction was valid, the statistical results only hold
        'better_than_random' (False), 'n_samples' (0), 'n_correct' (0) and a
        'summary' message, because no test can be computed.
    """
    # Run basic evaluation
    accuracy, kappa, results_df = run_evaluation(api_key, X, y, k, random_seed, max_workers)

    # Only predictions that produced a usable option index enter the statistics
    valid_rows = results_df[results_df['predicted_answer'] != -1]

    if valid_rows.empty:
        # Keep the (expensive) evaluation results instead of failing in the analysis
        message = "No valid predictions were returned; statistical analysis skipped."
        print(message)
        stats_results = {
            'better_than_random': False,
            'n_samples': 0,
            'n_correct': 0,
            'summary': message
        }
        return accuracy, kappa, results_df, stats_results

    # Run statistical analysis
    stats_results = assess_performance(
        valid_rows['true_answer'].tolist(),
        valid_rows['predicted_answer'].tolist(),
        alpha
    )

    # Print statistical summary
    print(stats_results['summary'])

    return accuracy, kappa, results_df, stats_results



In [24]:
  # Example usage
# Evaluate on a random sample of k prompts; the fixed seed makes the sample reproducible.
k = 300
random_seed = 42
# Returns overall accuracy, Cohen's kappa, the per-sample results DataFrame
# and the dictionary of statistical test results.
accuracy, kappa, results, stats = run_evaluation_with_stats(
    api_key, X, y, k, random_seed
)


Evaluation Results:
Total samples: 300
Valid predictions: 300
Accuracy: 0.450
Cohen's Kappa: 0.240

Per-class accuracy:
Class 0: 0.238
Class 1: 0.837
Class 2: 0.462
Class 3: 0.135

Performance Assessment:
----------------------
Observed Accuracy: 0.450 (135/300)
Random Baseline: 0.258 (based on class distribution)
P-value: 0.0000
Effect Size (Cohen's h): 0.405 (small)


Class Distribution:
Class 0: 0.280 (84/300 samples)
Class 1: 0.287 (86/300 samples)
Class 2: 0.260 (78/300 samples)
Class 3: 0.173 (52/300 samples)

Statistical Significance:
The model is performing significantly better than the random baseline
(p < 0.0000)

For 300 samples, needed 0.303 accuracy (91 correct)
for statistical significance at α=0.05



## Convert to production code for GitHub

In [18]:
# Standard library imports
import os
import json
import ast
import re
import random
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from collections import Counter
from concurrent.futures import ThreadPoolExecutor

# Third-party imports
import numpy as np
import pandas as pd
import scipy.stats
from datasets import load_dataset
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    cohen_kappa_score
)
import goodfire

# Logging configuration
import logging

# Request INFO-level logging so the progress and result messages emitted below are shown.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the classes defined in this notebook.
logger = logging.getLogger(__name__)

class DataHandler:
    """Handles loading and preprocessing of medical evaluation data.

    Processed prompts, labels and subject names are cached as one JSON file
    inside `cache_dir`, so the dataset is only downloaded and processed once.
    """

    def __init__(self, cache_dir: str = ".cache/med_eval"):
        """
        Initialize DataHandler with cache directory.

        The directory is created if it does not exist yet.

        Args:
            cache_dir (str): Directory for caching downloaded data
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.data_path = self.cache_dir / "med_halt_data.json"

    def load_data(self) -> Tuple[List[str], List[int], List[str]]:
        """
        Load data, using cache if available, otherwise download fresh.

        Returns:
            Tuple[List[str], List[int], List[str]]: Processed prompts, labels, and subject names
        """
        if self.data_path.exists():
            logger.info("Loading data from cache...")
            return self._load_from_cache()

        logger.info("Downloading fresh data...")
        return self._download_and_cache()

    def _load_from_cache(self) -> Tuple[List[str], List[int], List[str]]:
        """Load processed data from cache; fall back to a fresh download if the
        cache file is unreadable, incomplete or inconsistent."""
        try:
            with open(self.data_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            prompts = data['prompts']
            labels = data['labels']
            subject_names = data['subject_names']

            # The three lists are index-aligned; unequal lengths mean a broken cache
            if not (len(prompts) == len(labels) == len(subject_names)):
                raise ValueError("cached lists have mismatched lengths")

            return prompts, labels, subject_names
        except Exception as e:
            logger.error(f"Error loading cached data: {e}")
            logger.info("Falling back to fresh download...")
            return self._download_and_cache()

    def _download_and_cache(self) -> Tuple[List[str], List[int], List[str]]:
        """Download fresh data, process it, and cache the results.

        Any failure is logged and re-raised to the caller.
        """
        try:
            dataset = load_dataset("openlifescienceai/Med-HALT", "reasoning_FCT")
            train_data = dataset['train']

            prompts, labels, subject_names = self._process_data(train_data)

            # Cache the processed data
            cache_data = {
                'prompts': prompts,
                'labels': labels,
                'subject_names': subject_names
            }
            # Explicit encoding so the cache does not depend on the platform default
            with open(self.data_path, 'w', encoding='utf-8') as f:
                json.dump(cache_data, f)

            return prompts, labels, subject_names

        except Exception as e:
            logger.error(f"Error downloading/processing data: {e}")
            raise

    def _process_data(self, dataset) -> Tuple[List[str], List[int], List[str]]:
        """Process raw dataset into prompts, labels, and subject names.

        Examples whose prompt cannot be built are skipped, so the three lists
        stay index-aligned.
        """
        prompts = []
        labels = []
        subject_names = []

        for example in dataset:
            prompt, label = self._create_prompt(example)
            if prompt is not None:
                prompts.append(prompt)
                labels.append(label)
                # A record may carry an explicit None subject; store '' so the
                # list really contains only strings
                subject_names.append(example.get('subject_name') or '')

        return prompts, labels, subject_names

    @staticmethod
    def _create_prompt(example: Dict) -> Tuple[Optional[str], Optional[int]]:
        """Create a formatted prompt from a single example.

        Returns (prompt, correct_index), or (None, None) if anything goes wrong
        while building the prompt (the error is logged).
        """
        try:
            introduction = ("You are a medical expert and this is a multiple choice exam question. "
                          "Please respond with the integer index of the CORRECT answer only; [0,1,2,3].")

            question = example['question']

            # Options may be the string repr of a dict or an actual dict
            if isinstance(example['options'], str):
                options_dict = ast.literal_eval(example['options'])
            else:
                options_dict = example['options']

            # Drop the embedded answer so it is not leaked into the prompt
            options_filtered = {k: v for k, v in options_dict.items() if k != 'correct answer'}
            # ensure_ascii=False keeps characters such as 'µ' readable in the
            # prompt instead of '\uXXXX' escape sequences
            options_formatted = "Options: " + json.dumps(options_filtered, ensure_ascii=False)

            prompt = f"{introduction}\n\n{question}\n\n{options_formatted}"
            return prompt, example['correct_index']

        except Exception as e:
            logger.error(f"Error creating prompt: {e}")
            return None, None

    def filter_by_subject(self, prompts: List[str], labels: List[int],
                        subject_names: List[str], subject_name: Optional[str] = None) -> Tuple[List[str], List[int]]:
        """Filter data by subject name (case-insensitive exact match).

        With no subject name (None or ''), the data is returned unfiltered.
        Entries without a subject never match.
        """
        if not subject_name:
            return prompts, labels

        subject_name = subject_name.lower()
        filtered_indices = [i for i, name in enumerate(subject_names)
                          if name and name.lower() == subject_name]

        return ([prompts[i] for i in filtered_indices],
                [labels[i] for i in filtered_indices])



In [19]:

class LLMEvaluator:
    """Handles evaluation of LLM performance on medical questions."""

    def __init__(self, client, variant):
        """
        Store the API client and the model variant used for every request.

        Args:
            client: Object exposing `chat.completions.create(...)`
            variant: Model passed as `model=` on each request
        """
        self.variant = variant
        self.client = client

    def evaluate(self,
                X: List[str],
                y: List[int],
                k: int,
                random_seed: Optional[int] = None,
                max_workers: int = 10) -> Tuple[float, float, pd.DataFrame]:
        """
        Evaluate model performance on a sample of questions.

        Args:
            X (List[str]): List of prompts
            y (List[int]): List of correct answers
            k (int): Number of samples to evaluate
            random_seed (Optional[int]): Random seed for reproducibility
            max_workers (int): Maximum number of concurrent API calls

        Returns:
            Tuple[float, float, pd.DataFrame]: Accuracy, kappa, and detailed results
        """
        X_sample, y_sample, _ = self._sample_data(X, y, k, random_seed)
        return self._evaluate_concurrent(X_sample, y_sample, max_workers)

    def _sample_data(self,
                    X: List[str],
                    y: List[int],
                    k: int,
                    random_seed: Optional[int] = None) -> Tuple[List[str], List[int], List[int]]:
        """Randomly sample k aligned (prompt, label) pairs without replacement.

        Returns the sampled prompts, the sampled labels and the drawn indices.
        """
        # A private generator gives the same draw as seeding the module-level one,
        # without resetting the global random state as a side effect.
        rng = random.Random(random_seed) if random_seed is not None else random

        indices = rng.sample(range(len(X)), k)
        X_sampled = [X[i] for i in indices]
        y_sampled = [y[i] for i in indices]

        return X_sampled, y_sampled, indices

    def _evaluate_concurrent(self,
                           X_sample: List[str],
                           y_sample: List[int],
                           max_workers: int) -> Tuple[float, float, pd.DataFrame]:
        """Query the model for all prompts through a thread pool and score the answers."""
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(self._get_model_response, prompt)
                for prompt in X_sample
            ]

            # Collect in submission order so predictions stay aligned with y_sample
            y_pred = []
            for future in futures:
                try:
                    result = future.result()
                    # None signals a failed request; -1 marks it as invalid
                    y_pred.append(result if result is not None else -1)
                except Exception as e:
                    logger.error(f"Error in API call: {e}")
                    y_pred.append(-1)

        return self._calculate_metrics(X_sample, y_sample, y_pred)

    def _get_model_response(self, prompt: str) -> Optional[int]:
        """Get integer response from model.

        Returns the parsed option index (-1 if the reply holds no valid index),
        or None when the request or the response handling raised an error.
        """
        try:
            response = self.client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model=self.variant,
                stream=False,
                max_completion_tokens=50
            )
            content = response.choices[0].message['content']
            return self._extract_first_valid_integer(content)
        except Exception as e:
            logger.error(f"Error getting model response: {e}")
            return None

    @staticmethod
    def _extract_first_valid_integer(text: str) -> int:
        """Return the first integer in `text` if it is 0, 1, 2 or 3, else -1.

        Only the first run of digits is considered.
        """
        pattern = r'\d+'
        match = re.search(pattern, text)

        if match:
            number = int(match.group())
            if number in [0, 1, 2, 3]:
                return number
        return -1

    def _calculate_metrics(self,
                         X_sample: List[str],
                         y_sample: List[int],
                         y_pred: List[int]) -> Tuple[float, float, pd.DataFrame]:
        """Calculate evaluation metrics.

        Accuracy, kappa and per-class accuracy use valid predictions only
        (those that are not -1); the returned DataFrame keeps every sample.
        """
        valid_indices = [i for i, pred in enumerate(y_pred) if pred != -1]
        y_pred_valid = [y_pred[i] for i in valid_indices]
        y_sample_valid = [y_sample[i] for i in valid_indices]

        if len(valid_indices) > 0:
            accuracy = accuracy_score(y_sample_valid, y_pred_valid)
            kappa = cohen_kappa_score(y_sample_valid, y_pred_valid)
            conf_matrix = confusion_matrix(y_sample_valid, y_pred_valid, labels=[0,1,2,3])

            # Row sums are the true-sample counts per class; a class without
            # samples gets NaN without triggering a divide-by-zero warning.
            class_totals = conf_matrix.sum(axis=1)
            per_class_accuracy = np.divide(
                conf_matrix.diagonal(),
                class_totals,
                out=np.full(len(class_totals), np.nan),
                where=class_totals > 0,
            )
        else:
            accuracy = 0
            kappa = 0
            per_class_accuracy = [0, 0, 0, 0]

        results_df = pd.DataFrame({
            'prompt': X_sample,
            'true_answer': y_sample,
            'predicted_answer': y_pred,
            'correct': [pred == true for pred, true in zip(y_pred, y_sample)]
        })

        # Log results
        logger.info(f"\nEvaluation Results:")
        logger.info(f"Total samples: {len(X_sample)}")
        logger.info(f"Valid predictions: {len(valid_indices)}")
        logger.info(f"Accuracy: {accuracy:.3f}")
        logger.info(f"Cohen's Kappa: {kappa:.3f}")
        logger.info("\nPer-class accuracy:")
        for i, acc in enumerate(per_class_accuracy):
            logger.info(f"Class {i}: {acc:.3f}")

        return accuracy, kappa, results_df

In [20]:

class StatisticalAnalyzer:
    """Tests whether classification accuracy beats a chance-level baseline.

    All methods are static. ``analyze`` is the entry point; the underscore
    helpers compute the baseline, the significance threshold, the effect-size
    label and the text summary.
    """

    @staticmethod
    def analyze(y_true: List[int],
                y_pred: List[int],
                alpha: float = 0.05) -> Dict:
        """
        Analyze if model performance is significantly better than random.

        The observed number of correct predictions is compared against a
        random baseline derived from the class distribution of ``y_true``
        using a one-sided binomial test, and the gap is expressed as Cohen's h.

        Args:
            y_true (List[int]): True labels
            y_pred (List[int]): Predicted labels, same length as ``y_true``
            alpha (float): Significance level

        Returns:
            Dict: Statistical analysis results with the keys
                ``better_than_random``, ``p_value``, ``observed_accuracy``,
                ``effect_size``, ``effect_size_interpretation``, ``n_samples``,
                ``n_correct``, ``min_correct``, ``min_accuracy_needed``,
                ``random_baseline``, ``class_distribution`` and ``summary``.

        Raises:
            ValueError: If no samples are given or the two inputs differ in
                length.
        """
        n_samples = len(y_true)
        if n_samples == 0:
            # Without this guard the accuracy computation below fails with a
            # bare ZeroDivisionError, e.g. when every prediction was invalid.
            raise ValueError("Cannot analyze performance: no samples were provided")
        if len(y_pred) != n_samples:
            # zip() would silently truncate and skew the accuracy.
            raise ValueError(
                f"y_true and y_pred must have the same length "
                f"(got {n_samples} and {len(y_pred)})"
            )

        # Basic metrics
        n_correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
        observed_accuracy = n_correct / n_samples

        # Calculate baselines
        random_prob, class_proportions = StatisticalAnalyzer._calculate_random_baseline(y_true)

        # One-sided test: is the success count higher than the baseline predicts?
        p_value = scipy.stats.binomtest(n_correct, n_samples, p=random_prob,
                                        alternative='greater').pvalue

        # Effect size (Cohen's h: difference of arcsine-transformed proportions)
        h = 2 * (np.arcsin(np.sqrt(observed_accuracy)) -
                 np.arcsin(np.sqrt(random_prob)))

        # Interpret results
        is_significant = p_value < alpha
        effect_size = StatisticalAnalyzer._interpret_effect_size(h)

        # Minimum needed for significance
        min_successes = StatisticalAnalyzer._find_min_successes(n_samples, random_prob, alpha)
        min_accuracy = min_successes / n_samples

        results = {
            'better_than_random': is_significant,
            'p_value': p_value,
            'observed_accuracy': observed_accuracy,
            'effect_size': h,
            'effect_size_interpretation': effect_size,
            'n_samples': n_samples,
            'n_correct': n_correct,
            'min_correct': min_successes,
            'min_accuracy_needed': min_accuracy,
            'random_baseline': random_prob,
            'class_distribution': class_proportions
        }

        # Add human-readable summary
        results['summary'] = StatisticalAnalyzer._create_summary(results, alpha)

        return results

    @staticmethod
    def _calculate_random_baseline(y_true: List[int]) -> Tuple[float, Dict[int, float]]:
        """Calculate random baseline accuracy based on class distribution.

        The baseline is the sum of squared class proportions, i.e. the expected
        accuracy of guessing each class with the frequency at which it occurs
        in ``y_true``.

        Returns:
            Tuple of (baseline accuracy, mapping of class label -> proportion).
        """
        class_counts = Counter(y_true)
        total = len(y_true)

        proportions = {k: v/total for k, v in class_counts.items()}
        random_baseline = sum(p*p for p in proportions.values())

        return random_baseline, proportions

    @staticmethod
    def _find_min_successes(n: int, p: float, alpha: float) -> int:
        """Find minimum successes needed for significance.

        Binary-searches for the smallest success count whose one-sided
        binomial p-value is below ``alpha`` (the same strict comparison that
        ``analyze`` uses). The search starts at the expected count
        ``int(n * p)`` and relies on the p-value not increasing as the count
        grows.

        Returns:
            The smallest significant success count in ``[int(n * p), n]``, or
            ``n + 1`` when even ``n`` successes would not be significant.
        """
        def is_significant(successes: int) -> bool:
            return scipy.stats.binomtest(successes, n, p,
                                         alternative='greater').pvalue < alpha

        # `high` is an exclusive sentinel: it is never tested itself, and is
        # returned only when nothing in [low, n] is significant.
        low, high = int(n * p), n + 1

        while low < high:
            mid = (low + high) // 2
            if is_significant(mid):
                high = mid
            else:
                low = mid + 1

        return low

    @staticmethod
    def _interpret_effect_size(h: float) -> str:
        """Interpret Cohen's h effect size.

        The magnitude of ``h`` is bucketed at 0.2, 0.5 and 0.8 into
        'negligible', 'small', 'medium' and 'large'.
        """
        if abs(h) < 0.2:
            return 'negligible'
        elif abs(h) < 0.5:
            return 'small'
        elif abs(h) < 0.8:
            return 'medium'
        else:
            return 'large'

    @staticmethod
    def _create_summary(results: Dict, alpha: float) -> str:
        """Create human-readable summary of results.

        Args:
            results: The dictionary assembled by ``analyze`` (before the
                ``summary`` key is added).
            alpha: Significance level quoted in the text.

        Returns:
            Multi-line report covering accuracy, baseline, p-value, effect
            size, class distribution and the significance threshold.
        """
        dist_desc = "\nClass Distribution:\n"
        for class_label, prop in sorted(results['class_distribution'].items()):
            # Round rather than truncate: proportion * n can land just below
            # the true count (29/100*100 == 28.999999999999996).
            count = int(round(prop * results['n_samples']))
            dist_desc += (f"Class {class_label}: {prop:.3f} "
                        f"({count}/{results['n_samples']} samples)\n")

        return f"""
Performance Assessment:
----------------------
Observed Accuracy: {results['observed_accuracy']:.3f} ({results['n_correct']}/{results['n_samples']})
Random Baseline: {results['random_baseline']:.3f} (based on class distribution)
P-value: {results['p_value']:.4f}
Effect Size (Cohen's h): {results['effect_size']:.3f} ({results['effect_size_interpretation']})

{dist_desc}
Statistical Significance:
The model {'is' if results['better_than_random'] else 'is not'} performing significantly better than the random baseline
(p{' < ' if results['p_value'] < alpha else ' = '}{results['p_value']:.4f})

For {results['n_samples']} samples, needed {results['min_accuracy_needed']:.3f} accuracy ({results['min_correct']} correct)
for statistical significance at α={alpha}
"""



In [21]:

class MedicalLLMEvaluator:
    """Facade combining data loading, LLM querying and statistical analysis."""

    def __init__(self, client, variant, cache_dir = ".cache/med_eval"):
        """
        Build the three collaborators used by ``run_evaluation``.

        Args:
            client: Client object handed to ``LLMEvaluator``
            variant: Model variant handed to ``LLMEvaluator``
            cache_dir (str): Directory handed to ``DataHandler``
        """
        self.data_handler = DataHandler(cache_dir)
        self.evaluator = LLMEvaluator(client, variant)
        self.analyzer = StatisticalAnalyzer()

    def run_evaluation(self,
                       k: int,
                       subject_name: Optional[str] = None,
                       random_seed: Optional[int] = None,
                       max_workers: int = 10,
                       alpha: float = 0.05) -> Tuple[float, float, pd.DataFrame, Dict]:
        """
        Load the data, query the model and test the result against chance.

        Args:
            k (int): Number of samples; forwarded to ``LLMEvaluator.evaluate``
            subject_name (Optional[str]): Subject forwarded to
                ``DataHandler.filter_by_subject``
            random_seed (Optional[int]): Forwarded to ``LLMEvaluator.evaluate``
            max_workers (int): Forwarded to ``LLMEvaluator.evaluate``
            alpha (float): Significance level for the statistical analysis

        Returns:
            Tuple of (accuracy, kappa, results DataFrame, statistics dict).

        Raises:
            ValueError: If filtering leaves no data, or fewer than ``k`` items.
        """
        prompts, labels, subject_names = self.data_handler.load_data()

        # Narrow the pool to the requested subject.
        subject_prompts, subject_labels = self.data_handler.filter_by_subject(
            prompts, labels, subject_names, subject_name
        )

        if not subject_prompts:
            raise ValueError(f"No data found for subject: {subject_name}")

        if len(subject_prompts) < k:
            raise ValueError(f"Not enough data for subject: {subject_name}")

        accuracy, kappa, results_df = self.evaluator.evaluate(
            subject_prompts, subject_labels, k, random_seed, max_workers
        )

        # Only rows whose prediction is not -1 take part in the statistics.
        answered = results_df[results_df['predicted_answer'] != -1]
        stats_results = self.analyzer.analyze(
            answered['true_answer'].tolist(),
            answered['predicted_answer'].tolist(),
            alpha,
        )

        logger.info(stats_results['summary'])

        return accuracy, kappa, results_df, stats_results



In [None]:
# EXAMPLE USAGE

import os
from google.colab import userdata
import goodfire

# Get API key (read from Colab's secret store rather than hard-coded)
api_key = userdata.get('GOODFIRE_API_KEY')

# Initialize Goodfire
client = goodfire.Client(api_key)
variant = goodfire.Variant("meta-llama/Meta-Llama-3-8B-Instruct")

# Initialize evaluator
evaluator = MedicalLLMEvaluator(client, variant)

# Run evaluation
accuracy, kappa, results, stats = evaluator.run_evaluation(
    k=100,             # number of samples
    random_seed=42,    # for reproducibility
    max_workers=10,    # concurrent API calls
    subject_name=None  # enter subject_name, eg 'psychology', or leave as None
)

# Results are already logged by the evaluator
# You can also access them programmatically:
print(f"\nAccuracy: {accuracy:.3f}")
print(f"Kappa: {kappa:.3f}")
print("\nDetailed results available in results DataFrame")
print("\nStatistical analysis results available in stats dictionary")


Accuracy: 0.600
Kappa: 0.394

Detailed results available in results DataFrame

Statistical analysis results available in stats dictionary
