<a href="https://colab.research.google.com/github/SNCA-24/5218_dl_snca.py/blob/main/LLM_Distractor_Ranking_Part_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Part 1 - Preparation & Training

In [1]:
# cell 1 - dependencies
# %pip (instead of !pip) installs into the environment of the running kernel.
%pip install torch
%pip install transformers==4.28.0
# scikit-learn is listed because cell 2 imports sklearn.metrics.
%pip install datasets scipy scikit-learn matplotlib seaborn pandas numpy

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [2]:
# cell 2 - import libraries
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import logging

from datasets import Dataset
from scipy.stats import spearmanr
from sklearn.metrics import cohen_kappa_score

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Trainer, TrainingArguments

# Set random seed for reproducibility
np.random.seed(42)

In [3]:
# Cell 3 – configs

import os

# Run profile: 'smoke' restricts HYPERPARAM_SETTINGS to 'low'; any other value selects every setting
TEST_MODE = 'full'

# Input CSV locations
HUMAN_RANKED_PATH = '/content/human_ranked.csv'
MMLU_PATH = '/content/training_data.csv'

# Hugging Face model identifiers to run
MODEL_NAMES = [
    'google/t5-efficient-mini',
    'google-t5/t5-small',
    'google/flan-t5-small',
    'sshleifer/distilbart-cnn-6-6',
    'sshleifer/distilbart-xsum-12-3',
]

# Number of classification labels
NUM_LABELS = 4

# Named hyperparameter presets: batch size, epochs and learning rate per setting
HYPERPARAMS = dict(
    low=dict(batch_size=4, epochs=1, learning_rate=1e-5),
    medium=dict(batch_size=8, epochs=3, learning_rate=5e-5),
    high=dict(batch_size=16, epochs=5, learning_rate=1e-4),
)

# Presets that are actually run for the chosen TEST_MODE
HYPERPARAM_SETTINGS = ['low'] if TEST_MODE == 'smoke' else list(HYPERPARAMS)

# Where experiment artefacts and plots are written
OUTPUT_DIR = './experiments'
VISUALIZATIONS_DIR = './plots'

# Make sure both output locations exist before any cell writes to them
os.makedirs(VISUALIZATIONS_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [4]:
# cell 4 - load datasets
def load_datasets(human_ranked_path, mmlu_path):
    """
    Load the human-ranked and MMLU CSV files into pandas DataFrames and check
    that each contains the columns the rest of the notebook relies on.

    Args:
        human_ranked_path (str): Path to the human-ranked CSV file.
        mmlu_path (str): Path to the MMLU training CSV file.

    Returns:
        tuple: (human_ranked_df, mmlu_df)

    Raises:
        FileNotFoundError: If an input file is not found (raised by pd.read_csv).
        ValueError: If expected columns are missing from either file.
    """
    # Read from the paths that were passed in (not the module-level constants),
    # so the function works for any pair of files.
    human_ranked_df = pd.read_csv(human_ranked_path)
    mmlu_df = pd.read_csv(mmlu_path)

    # Expected columns
    human_ranked_columns = [
        'subject', 'question', 'correct_answer', 'option_0', 'option_1',
        'option_2', 'option_3', 'distractor_ranking_best_to_worst_Annotator_1',
        'distractor_ranking_best_to_worst_Annotator_2'
    ]

    mmlu_columns = ['question', 'subject', 'choices', 'correct_answer', 'option_0', 'option_1', 'option_2', 'option_3']

    # Validate columns; the human-ranked frame is checked first so its error wins
    for frame_name, frame, expected in (
        ('human_ranked_df', human_ranked_df, human_ranked_columns),
        ('mmlu_df', mmlu_df, mmlu_columns),
    ):
        missing = [col for col in expected if col not in frame.columns]
        if missing:
            raise ValueError(f"Missing columns in {frame_name}: {missing}")

    return human_ranked_df, mmlu_df

# Load datasets using constants from Cell 3
human_ranked_df, mmlu_df = load_datasets(HUMAN_RANKED_PATH, MMLU_PATH)



In [5]:
# Cell 5 - Preprocess Data

def preprocess_data(human_ranked_df, mmlu_df):
    """
    Remove rows with missing values in the critical columns of both frames.

    Rows that are removed are written to ./dropped_human_ranked_rows.csv and
    ./dropped_mmlu_rows.csv (the files are written even when nothing was
    removed). Critical columns absent from mmlu_df are reported and ignored.

    Args:
        human_ranked_df (pd.DataFrame): Human-ranked questions.
        mmlu_df (pd.DataFrame): MMLU training questions.

    Returns:
        tuple: (human_ranked_df_keep, mmlu_df_keep) as independent copies.
    """
    def _partition(frame, required):
        # Split a frame into (rows complete in `required`, remaining rows)
        complete = frame[required].notna().all(axis=1)
        return frame.loc[complete].copy(), frame.loc[~complete].copy()

    option_cols = ['option_0', 'option_1', 'option_2', 'option_3']
    shared_cols = ['subject', 'question', 'correct_answer'] + option_cols

    # Diagnostic listing of the MMLU columns
    print("mmlu_df columns:", mmlu_df.columns.tolist())

    # Human-ranked frame: the first annotator's ranking must be present as well
    human_kept, human_rejected = _partition(
        human_ranked_df,
        shared_cols + ['distractor_ranking_best_to_worst_Annotator_1'],
    )
    print(f"Dropped {len(human_rejected)} rows from human_ranked_df. Saved to ./dropped_human_ranked_rows.csv")
    human_rejected.to_csv('./dropped_human_ranked_rows.csv', index=False)

    # MMLU frame: tolerate absent critical columns by checking only those present
    mmlu_required = shared_cols
    absent = [name for name in shared_cols if name not in mmlu_df.columns]
    if absent:
        print(f"Warning: Columns {absent} not found in mmlu_df. Available columns: {mmlu_df.columns.tolist()}")
        mmlu_required = [name for name in shared_cols if name in mmlu_df.columns]

    mmlu_kept, mmlu_rejected = _partition(mmlu_df, mmlu_required)
    print(f"Dropped {len(mmlu_rejected)} rows from mmlu_df. Saved to ./dropped_mmlu_rows.csv")
    mmlu_rejected.to_csv('./dropped_mmlu_rows.csv', index=False)

    # Report (without raising) any option column missing from the kept human rows
    absent_options = [name for name in option_cols if name not in human_kept.columns]
    if absent_options:
        print(f"Error: Required option columns {absent_options} missing in human_ranked_df.")

    return human_kept, mmlu_kept

# Apply preprocessing
human_ranked_df, mmlu_df = preprocess_data(human_ranked_df, mmlu_df)


mmlu_df columns: ['question', 'subject', 'choices', 'correct_answer', 'option_0', 'option_1', 'option_2', 'option_3']
Dropped 0 rows from human_ranked_df. Saved to ./dropped_human_ranked_rows.csv
Dropped 10 rows from mmlu_df. Saved to ./dropped_mmlu_rows.csv


In [6]:
human_ranked_df.head()

Unnamed: 0,subject,question_id,question,correct_answer,option_0,option_1,option_2,option_3,distractor_ranking_best_to_worst_Annotator_1,distractor_ranking_best_to_worst_Annotator_2
0,high_school_microeconomics,1,Which of the following is not a reason why dem...,1,An increase in consumer income (for normal goods),A fall in the price of the good itself,A rise in the price of a substitute,An expected increase in future prices,230,
1,high_school_microeconomics,2,"At the profit-maximizing quantity, a perfectly...",1,Marginal revenue,Marginal cost,Average cost,None of the them,23,
2,high_school_microeconomics,3,A negative externality results in...,0,Overproduction and underpricing,Underproduction and overpricing,Efficient market output,None of the them,132,
3,high_school_microeconomics,4,Which statement is false about price ceilings?,2,They can lead to shortages,They are set below the equilibrium price,They always increase producer surplus,They distort market efficiency,13,
4,high_school_microeconomics,5,Cross-price elasticity of demand between two g...,1,The goods are substitutes,The goods are complements,They are unrelated,Cannot be determined,23,


In [7]:
mmlu_df.head()

Unnamed: 0,question,subject,choices,correct_answer,option_0,option_1,option_2,option_3
0,Find the degree for the given field extension ...,abstract_algebra,['0' '4' '2' '6'],1,0,4,2,6
1,"Let p = (1, 2, 5, 4)(2, 3) in S_5 . Find the i...",abstract_algebra,['8' '2' '24' '120'],2,8,2,24,120
2,Find all zeros in the indicated finite field o...,abstract_algebra,"['0' '1' '0,1' '0,4']",3,0,1,01,04
3,Statement 1 | A factor group of a non-Abelian ...,abstract_algebra,"['True, True' 'False, False' 'True, False' 'Fa...",1,"True, True","False, False","True, False","False, True"
4,Find the product of the given polynomials in t...,abstract_algebra,['2x^2 + 5' '6x^2 + 4x + 6' '0' 'x^2 + 1'],1,2x^2 + 5,6x^2 + 4x + 6,0,x^2 + 1


## Prediction Method Finalised - Log Probability Scoring

In [8]:
# Cell 6 – Distractor Analysis Class (add log-prob scoring)

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Trainer, TrainingArguments
import torch
import torch.nn.functional as F
import numpy as np
import os
import logging

class DistractorAnalysis:
    """
    One (model, hyperparameter setting) experiment: loads a seq2seq language
    model with its tokenizer, fine-tunes it with the Hugging Face Trainer, and
    scores the answer letters A-D by log-probability.
    """

    def __init__(self, model_name, hyperparam_setting):
        """
        Load the tokenizer and model and build the TrainingArguments.

        Args:
            model_name (str): Model identifier passed to `from_pretrained`.
            hyperparam_setting (str): Key of the module-level HYPERPARAMS dict
                selecting batch size, epochs and learning rate.

        Raises:
            ValueError: If hyperparam_setting is not a key of HYPERPARAMS.
        """
        self.model_name = model_name
        self.hyperparam_setting = hyperparam_setting

        if hyperparam_setting not in HYPERPARAMS:
            raise ValueError(f"Invalid hyperparam_setting: {hyperparam_setting}")

        # Load full encoder–decoder LM
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model     = AutoModelForSeq2SeqLM.from_pretrained(model_name)

        settings = HYPERPARAMS[hyperparam_setting]
        # Directory-safe name shared by the output and logging directories
        run_name = f"{model_name.replace('/', '_')}_{hyperparam_setting}"

        self.training_args = TrainingArguments(
            output_dir=os.path.join(OUTPUT_DIR, run_name),
            per_device_train_batch_size=settings['batch_size'],
            per_device_eval_batch_size= settings['batch_size'],
            num_train_epochs=         settings['epochs'],
            learning_rate=            settings['learning_rate'],
            evaluation_strategy=      "steps",
            eval_steps=               1000,
            logging_steps=            500,
            save_steps=               1000,
            save_total_limit=         2,
            load_best_model_at_end=   True,
            metric_for_best_model=    "loss",  # or "eval_loss"
            greater_is_better=        False,
            # keep every dataset column: the tokenized datasets carry explicit decoder_input_ids
            remove_unused_columns=    False,
            logging_dir=os.path.join(OUTPUT_DIR, f"{run_name}_logs"),
        )

    def train(self, tokenized_datasets):
        """
        Fine‐tune the seq2seq LM in a standard teacher‐forcing way,
        where the target is the single letter token (A/B/C/D).

        Args:
            tokenized_datasets (dict): Maps model name to a dict with 'train'
                and 'eval' datasets; the entry for self.model_name is used.

        After training, the model and tokenizer are saved to
        OUTPUT_DIR/<model>_<setting>_best.
        """
        trainer = Trainer(
            model=self.model,
            args=self.training_args,
            train_dataset=tokenized_datasets[self.model_name]['train'],
            eval_dataset= tokenized_datasets[self.model_name]['eval'],
        )
        trainer.train()
        ckpt = os.path.join(
            OUTPUT_DIR, f"{self.model_name.replace('/', '_')}_{self.hyperparam_setting}_best"
        )
        trainer.save_model(ckpt)
        self.tokenizer.save_pretrained(ckpt)

    def predict_logprob(self, dataset, batch_size=1, device='cuda', max_length=512):
        """
        For each input_text, score each letter token (A/B/C/D) by the
        total log‐prob of generating that letter as the decoder output given
        the prompt as encoder input, then pick the best.

        Args:
            dataset: Iterable of examples, each with an 'input_text' entry.
            batch_size (int): Accepted for interface compatibility; examples
                are currently scored one at a time.
            device (str): Preferred device; falls back to CPU when CUDA is
                unavailable.
            max_length (int): Prompts are truncated to this many tokens.

        Returns:
          preds_arr: np.array of shape [N] with 0–3
          logprobs:  np.array of shape [N,4] with the log‐prob for each letter
        """
        dev = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.model.to(dev).eval()

        # map letter idx → string
        letter_str = ["A", "B", "C", "D"]
        # The letter token ids do not depend on the example: tokenize them once.
        letter_ids = [
            self.tokenizer(letter, add_special_tokens=False, return_tensors='pt').input_ids.to(dev)
            for letter in letter_str
        ]

        all_preds    = []
        all_logprobs = []

        with torch.no_grad():
            for ex in dataset:
                # tokenize prompt once; it is the encoder input for all four letters
                enc = self.tokenizer(
                    ex['input_text'],
                    return_tensors='pt',
                    truncation=True,
                    max_length=max_length,
                ).to(dev)

                # score each letter
                letter_scores = []
                for lab_ids in letter_ids:
                    # Teacher-force the letter as the decoder target. The prompt
                    # stays on the encoder side only, so the loss covers the
                    # letter token(s) and nothing else. With only `labels` given,
                    # Hugging Face seq2seq models derive the decoder inputs from them.
                    # NOTE(review): the training examples built in this notebook carry
                    # explicit decoder_input_ids; confirm they start with the model's
                    # decoder start token so training and scoring agree.
                    out = self.model(
                        input_ids=enc.input_ids,
                        attention_mask=enc.attention_mask,
                        labels=lab_ids,
                    )
                    # out.loss is average NLL per token → multiply by #tokens
                    total_nll = out.loss.item() * lab_ids.size(1)
                    letter_scores.append(-total_nll)

                # pick best
                scores_arr = np.array(letter_scores)
                all_preds.append(int(scores_arr.argmax()))
                all_logprobs.append(scores_arr)

        # np.stack rejects an empty list, so return correctly shaped empty arrays
        if not all_logprobs:
            return np.array([], dtype=int), np.zeros((0, len(letter_str)))

        return np.array(all_preds), np.stack(all_logprobs, axis=0)

In [9]:
# Cell 7 - Prepare Datasets for Training

from datasets import Dataset
from transformers import AutoTokenizer
import pandas as pd
import logging

def prepare_datasets(human_ranked_df, mmlu_df):
    """
    Build the training (MMLU) and evaluation (human-ranked) datasets and
    tokenize them once for every model in MODEL_NAMES.

    Rows whose correct_answer cannot be mapped to a choice A-D are dropped and
    written to ./dropped_mmlu_choice_rows.csv or
    ./dropped_human_ranked_choice_rows.csv.

    Args:
        human_ranked_df (pd.DataFrame): Evaluation questions.
        mmlu_df (pd.DataFrame): Training questions.

    Returns:
        tuple: (tokenized_datasets, train_dataset, eval_dataset), where
        tokenized_datasets maps model name -> {'train': ..., 'eval': ...} and
        the other two are the untokenized datasets.

    Raises:
        NameError: If MODEL_NAMES is not defined.
        ValueError: If required columns are missing from either frame, or the
            tokenized datasets lack a required field.
    """
    # Verify required global variables
    if 'MODEL_NAMES' not in globals():
        raise NameError("MODEL_NAMES global variable not defined. Please run Cell #3 first.")

    option_columns = ['option_0', 'option_1', 'option_2', 'option_3']
    required_columns = ['question', 'correct_answer'] + option_columns
    dataset_columns = ['input_text', 'correct_choice', 'question'] + option_columns + ['correct_index']

    # Remember the incoming sizes; both frames are filtered further down.
    original_mmlu_rows = len(mmlu_df)
    original_human_rows = len(human_ranked_df)

    # Map answer index to choice token (A, B, C, D)
    def map_answer_to_choice(row, answer_col):
        try:
            answer_idx = int(row[answer_col])
            if answer_idx not in [0, 1, 2, 3]:
                return None
            return chr(65 + answer_idx)  # 0 -> 'A', 1 -> 'B', 2 -> 'C', 3 -> 'D'
        except (ValueError, TypeError, KeyError):
            return None

    def add_answer_columns(df, frame_name, dropped_path):
        # Add correct_choice / correct_index to a copy of df and drop (and save)
        # the rows for which either could not be derived.
        df = df.copy()
        df['correct_choice'] = df.apply(lambda row: map_answer_to_choice(row, 'correct_answer'), axis=1)
        df['correct_index'] = df['correct_answer'].apply(
            lambda x: int(x) if pd.notna(x) and str(x).strip().isdigit() else None
        )
        valid = df['correct_choice'].notna() & df['correct_index'].notna()
        dropped = df.loc[~valid].copy()
        df = df.loc[valid].copy()
        if len(dropped) > 0:
            print(f"Dropped {len(dropped)} rows from {frame_name}. Saved to {dropped_path}")
            dropped.to_csv(dropped_path, index=False)
        return df

    def drop_pandas_index(dataset):
        # Dataset.from_pandas may carry the DataFrame index as '__index_level_0__';
        # remove_columns returns a new dataset, so the result must be returned.
        if '__index_level_0__' in dataset.column_names:
            return dataset.remove_columns('__index_level_0__')
        return dataset

    # Validate columns in mmlu_df
    missing_mmlu_cols = [col for col in required_columns if col not in mmlu_df.columns]
    if missing_mmlu_cols:
        print(f"Warning: Missing columns in mmlu_df: {missing_mmlu_cols}. Available columns: {mmlu_df.columns.tolist()}")
        fallback_cols = ['choice1', 'choice2', 'choice3', 'choice4']
        if all(col in mmlu_df.columns for col in fallback_cols):
            print("Falling back to choice1, choice2, choice3, choice4 for mmlu_df")
            mmlu_df = mmlu_df.rename(columns={
                'choice1': 'option_0',
                'choice2': 'option_1',
                'choice3': 'option_2',
                'choice4': 'option_3'
            })
        else:
            raise ValueError("Cannot proceed: mmlu_df missing required option columns")

    # Validate columns in human_ranked_df (no fallback naming is supported here)
    missing_human_cols = [col for col in required_columns if col not in human_ranked_df.columns]
    if missing_human_cols:
        print(f"Warning: Missing columns in human_ranked_df: {missing_human_cols}. Available columns: {human_ranked_df.columns.tolist()}")
        raise ValueError("Cannot proceed: human_ranked_df missing required columns")

    # Derive the answer columns and drop rows with unusable answers
    mmlu_df = add_answer_columns(mmlu_df, 'mmlu_df', './dropped_mmlu_choice_rows.csv')
    human_ranked_df = add_answer_columns(
        human_ranked_df, 'human_ranked_df', './dropped_human_ranked_choice_rows.csv'
    )

    # Create input text
    def create_input_text(row):
        return f"Question: {row['question']} Options: A: {row['option_0']} B: {row['option_1']} C: {row['option_2']} D: {row['option_3']}"

    mmlu_df['input_text'] = mmlu_df.apply(create_input_text, axis=1)
    human_ranked_df['input_text'] = human_ranked_df.apply(create_input_text, axis=1)

    # Create datasets, without the pandas index column
    train_dataset = drop_pandas_index(Dataset.from_pandas(mmlu_df[dataset_columns]))
    eval_dataset = drop_pandas_index(Dataset.from_pandas(human_ranked_df[dataset_columns]))

    # Tokenization
    model_fields = ['input_ids', 'attention_mask', 'labels', 'decoder_input_ids']
    required_fields = ['input_ids', 'attention_mask', 'labels']
    tokenized_datasets = {}
    for model_name in MODEL_NAMES:
        try:
            print(f"Tokenizing datasets for model: {model_name}")
            tokenizer = AutoTokenizer.from_pretrained(model_name)

            def tokenize_function(examples):
                model_inputs = tokenizer(
                    examples['input_text'],
                    max_length=512,
                    truncation=True,
                    padding='max_length',
                    return_tensors=None
                )
                # Label = the single token id of the correct letter; the unk id is
                # used when the letter does not tokenize to exactly one token
                labels = tokenizer(examples['correct_choice'], add_special_tokens=False).input_ids
                model_inputs['labels'] = [label[0] if len(label) == 1 else tokenizer.unk_token_id for label in labels]
                # One-token decoder input: the pad id, or the EOS id if there is no pad token
                start_token = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
                model_inputs['decoder_input_ids'] = [[start_token]] * len(examples['input_text'])
                return model_inputs

            # Only keep columns needed by the model
            tokenized_train = train_dataset.map(
                tokenize_function,
                batched=True,
                remove_columns=[col for col in train_dataset.column_names if col not in model_fields]
            )
            tokenized_eval = eval_dataset.map(
                tokenize_function,
                batched=True,
                remove_columns=[col for col in eval_dataset.column_names if col not in model_fields]
            )

            # Verify the tokenized datasets have the fields training needs
            missing_train = [f for f in required_fields if f not in tokenized_train.column_names]
            missing_eval = [f for f in required_fields if f not in tokenized_eval.column_names]
            if missing_train or missing_eval:
                print(f"Warning: Missing fields for {model_name}. Train: {missing_train}, Eval: {missing_eval}")
                raise ValueError(f"Tokenized datasets missing required fields for model {model_name}")

            tokenized_datasets[model_name] = {'train': tokenized_train, 'eval': tokenized_eval}
            print(f"Tokenized datasets for {model_name}: train={len(tokenized_train)} rows, eval={len(tokenized_eval)} rows")
        except Exception as e:
            print(f"Failed to tokenize for {model_name}: {str(e)}")
            logging.error(f"Failed to tokenize for {model_name}: {str(e)}")
            raise

    # Log sizes for sanity check (sizes before vs. after answer filtering)
    print(f"Original sizes - MMLU: {original_mmlu_rows}, Human Ranked: {original_human_rows}")
    print(f"After processing - Train dataset: {len(train_dataset)}, Eval dataset: {len(eval_dataset)}")

    return tokenized_datasets, train_dataset, eval_dataset

# Build the tokenized and raw datasets; a failure is reported on stdout and in the log, then re-raised
try:
    tokenized_datasets, train_dataset, eval_dataset = prepare_datasets(human_ranked_df, mmlu_df)
    print("Datasets prepared successfully!")
except Exception as err:
    failure_message = f"Error preparing datasets: {str(err)}"
    print(failure_message)
    logging.error(failure_message)
    raise

Tokenizing datasets for model: sshleifer/distilbart-xsum-12-3


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.51k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

Map:   0%|          | 0/13821 [00:00<?, ? examples/s]

Map:   0%|          | 0/480 [00:00<?, ? examples/s]

Tokenized datasets for sshleifer/distilbart-xsum-12-3: train=13821 rows, eval=480 rows
Original sizes - MMLU: 13821, Human Ranked: 480
After processing - Train dataset: 13821, Eval dataset: 480
Datasets prepared successfully!


In [10]:
# Cell 8 - Train Models

import os
import logging
from transformers import Trainer, TrainingArguments

def train_all_models(tokenized_datasets, train_dataset, eval_dataset):
    """
    Fine-tune every model in MODEL_NAMES with every setting selected in
    HYPERPARAM_SETTINGS, so the TEST_MODE choice made in the config cell applies
    to training just as it does to prediction.

    A variant that fails is reported (with traceback in the log) and skipped so
    that the remaining variants still run.

    Args:
        tokenized_datasets (dict): model name -> {'train': ..., 'eval': ...}.
        train_dataset: Untokenized training dataset (used for size checks only).
        eval_dataset: Untokenized evaluation dataset (used for size checks only).

    Raises:
        ValueError: If train_dataset or eval_dataset is empty.
    """
    # set up logging to file
    # (logging.basicConfig does nothing if the root logger already has handlers)
    log_file = os.path.join(OUTPUT_DIR, 'training_log.txt')
    logging.basicConfig(
        filename=log_file,
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

    # Settings actually trained; generate_predictions iterates the same list
    settings_to_run = list(HYPERPARAM_SETTINGS)

    # start message
    print(f"Starting train_all_models with {len(MODEL_NAMES)} models × {len(settings_to_run)} hyperparam settings")
    logging.info(f"MODEL_NAMES={MODEL_NAMES}, HYPERPARAM_SETTINGS={settings_to_run}")

    # dataset size checks
    print(f"Raw train_dataset size: {len(train_dataset)} rows")
    print(f"Raw eval_dataset size: {len(eval_dataset)} rows")
    logging.info(f"Raw train_dataset size: {len(train_dataset)} rows")
    logging.info(f"Raw eval_dataset size: {len(eval_dataset)} rows")

    if len(train_dataset) == 0:
        logging.error("train_dataset is empty.")
        raise ValueError("train_dataset is empty.")
    if len(eval_dataset) == 0:
        logging.error("eval_dataset is empty.")
        raise ValueError("eval_dataset is empty.")
    if len(eval_dataset) < 50:
        print("Warning: eval_dataset is very small (<50 rows).")
        logging.warning("eval_dataset is very small (<50 rows).")

    expected_variants = len(MODEL_NAMES) * len(settings_to_run)
    print(f"Expecting to train {expected_variants} variants")
    logging.info(f"Expecting to train {expected_variants} variants")

    for model_name in MODEL_NAMES:
        # log tokenized dataset columns
        train_cols = tokenized_datasets[model_name]['train'].column_names
        eval_cols = tokenized_datasets[model_name]['eval'].column_names
        print(f"{model_name}: train cols={train_cols}, eval cols={eval_cols}")
        logging.info(f"{model_name} columns: train={train_cols}, eval={eval_cols}")

        # log a sample label
        sample = tokenized_datasets[model_name]['train'][0]
        logging.info(f"Sample train label for {model_name}: {sample['labels']}")

        for hyperparam_setting in settings_to_run:
            variant_name = f"{model_name.replace('/', '_')}_{hyperparam_setting}"
            print(f"\n### Training {variant_name} ###")
            logging.info(f"Starting training for {variant_name}")

            try:
                da = DistractorAnalysis(model_name, hyperparam_setting)
                da.train(tokenized_datasets)
                logging.info(f"trainer.train() completed for {variant_name}")

                # verify checkpoint
                checkpoint_path = os.path.join(OUTPUT_DIR, f"{variant_name}_best")
                if os.path.exists(checkpoint_path):
                    print(f"✔️  Checkpoint found at {checkpoint_path}")
                    logging.info(f"Checkpoint verified at {checkpoint_path}")
                else:
                    print(f"⚠️  Checkpoint not found at {checkpoint_path}")
                    logging.warning(f"Checkpoint not found at {checkpoint_path}")

                print(f"Completed training for {variant_name}")
                logging.info(f"Completed training for {variant_name}")

            except Exception as e:
                # Best effort: record the failure with its traceback and move on
                error_msg = f"Failed training for {variant_name}: {e}"
                print(error_msg)
                logging.exception(error_msg)
                continue

# Execute training
train_all_models(tokenized_datasets, train_dataset, eval_dataset)

Starting train_all_models with 1 models × 3 hyperparam settings
Raw train_dataset size: 13821 rows
Raw eval_dataset size: 480 rows
Expecting to train 3 variants
sshleifer/distilbart-xsum-12-3: train cols=['input_ids', 'attention_mask', 'labels', 'decoder_input_ids'], eval cols=['input_ids', 'attention_mask', 'labels', 'decoder_input_ids']

### Training sshleifer_distilbart-xsum-12-3_low ###




pytorch_model.bin:   0%|          | 0.00/716M [00:00<?, ?B/s]



<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mnc-lonestar-tx[0m ([33mnc-lonestar-tx-university-of-north-texas[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss
1000,1.4574,1.42051
2000,1.4322,1.373131
3000,1.4046,1.38589


✔️  Checkpoint found at ./experiments/sshleifer_distilbart-xsum-12-3_low_best
Completed training for sshleifer_distilbart-xsum-12-3_low

### Training sshleifer_distilbart-xsum-12-3_medium ###




Step,Training Loss,Validation Loss
1000,1.4604,1.411631
2000,1.4142,1.39172
3000,1.4162,1.395612
4000,1.401,1.409791
5000,1.3944,1.40628


✔️  Checkpoint found at ./experiments/sshleifer_distilbart-xsum-12-3_medium_best
Completed training for sshleifer_distilbart-xsum-12-3_medium

### Training sshleifer_distilbart-xsum-12-3_high ###




Step,Training Loss,Validation Loss
1000,1.4284,1.397711
2000,1.4096,1.394878
3000,1.3957,1.396047


Step,Training Loss,Validation Loss
1000,1.4284,1.397711
2000,1.4096,1.394878
3000,1.3957,1.396047
4000,1.3927,1.393892


✔️  Checkpoint found at ./experiments/sshleifer_distilbart-xsum-12-3_high_best
Completed training for sshleifer_distilbart-xsum-12-3_high


In [11]:
# Cell 9 – Generate Predictions (log‐prob method)

import os
import pandas as pd
import torch

def generate_predictions(eval_dataset):
    """
    Score eval_dataset with every trained variant and save the results.

    For each model in MODEL_NAMES and setting in HYPERPARAM_SETTINGS the saved
    weights are loaded from OUTPUT_DIR/<variant>_best/pytorch_model.bin,
    predict_logprob() is run, and one row per example is collected. All rows
    are written to OUTPUT_DIR/predictions.csv.

    Each eval_dataset row must provide 'question', 'input_text',
    'option_0'…'option_3' and 'correct_index'.

    Returns:
        pd.DataFrame: The collected prediction rows.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    letters = ['A', 'B', 'C', 'D']
    variants = [(name, setting) for name in MODEL_NAMES for setting in HYPERPARAM_SETTINGS]

    rows = []
    for model_name, hp in variants:
        variant = f"{model_name.replace('/', '_')}_{hp}"
        weights_path = os.path.join(OUTPUT_DIR, f"{variant}_best", 'pytorch_model.bin')
        print(f"→ Predicting {variant}")

        # Rebuild the wrapper, then overwrite its weights with the fine-tuned ones
        da = DistractorAnalysis(model_name, hp)
        state = torch.load(weights_path, map_location=device)
        da.model.load_state_dict(state)

        # Predicted index and per-letter log-probabilities for every example
        preds, logps = da.predict_logprob(eval_dataset, batch_size=1, device=device)

        for i, ex in enumerate(eval_dataset):
            record = {
                'question':             ex['question'],
                'model_name':           model_name,
                'variant':              hp,
                'predicted_choice':     int(preds[i]),
                'correct_choice_index': int(ex['correct_index']),
            }
            for col, letter in enumerate(letters):
                record[f'logprob_{letter}'] = float(logps[i, col])
            for col in range(4):
                record[f'option_{col}'] = ex[f'option_{col}']
            rows.append(record)
        print(f"✔ Done {variant}")

    # Assemble the table and persist it next to the checkpoints
    predictions_df = pd.DataFrame(rows)
    out_csv = os.path.join(OUTPUT_DIR, 'predictions.csv')
    predictions_df.to_csv(out_csv, index=False)
    print(f"All predictions saved to {out_csv}")

    return predictions_df

# Execute
predictions_df = generate_predictions(eval_dataset)

→ Predicting sshleifer_distilbart-xsum-12-3_low




✔ Done sshleifer_distilbart-xsum-12-3_low
→ Predicting sshleifer_distilbart-xsum-12-3_medium
✔ Done sshleifer_distilbart-xsum-12-3_medium
→ Predicting sshleifer_distilbart-xsum-12-3_high
✔ Done sshleifer_distilbart-xsum-12-3_high
All predictions saved to ./experiments/predictions.csv


## Prediction Method 1 - Initial Probability Estimation

In [None]:
# Cell 6 - Distractor Analysis Class
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Trainer, TrainingArguments
import torch
import numpy as np
import os
import logging

class DistractorAnalysis:
    """Fine-tune and query a seq2seq LM that answers 4-way multiple-choice questions.

    The answer is read from the first decoder position: the logits of the
    single-token choices "A", "B", "C" and "D" are used as per-option scores.

    Relies on the notebook-level globals ``HYPERPARAMS`` (one dict per setting with
    the keys ``batch_size``, ``epochs`` and ``learning_rate``) and ``OUTPUT_DIR``.
    """

    def __init__(self, model_name, hyperparam_setting):
        """Load tokenizer and model and build the TrainingArguments for one variant.

        Args:
            model_name: Identifier passed to ``from_pretrained``.
            hyperparam_setting: Key into ``HYPERPARAMS``.

        Raises:
            ValueError: If the setting is unknown, or if one of the choice letters
                is not encoded as exactly one token by this tokenizer.
            RuntimeError: If the tokenizer or the model cannot be loaded.
        """
        self.model_name = model_name
        self.hyperparam_setting = hyperparam_setting

        if hyperparam_setting not in HYPERPARAMS:
            raise ValueError(f"Invalid hyperparam_setting: {hyperparam_setting}")

        # Load tokenizer and model
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        except Exception as e:
            raise RuntimeError(f"Failed to load model or tokenizer for {model_name}: {str(e)}") from e

        # Each choice letter must map to exactly one vocabulary id, because scoring
        # reads one logit column per choice.
        choice_tokens = ["A", "B", "C", "D"]
        self.choice_token_ids = []
        for token in choice_tokens:
            token_ids = self.tokenizer(token, add_special_tokens=False).input_ids
            if len(token_ids) != 1:
                raise ValueError(
                    f"Token '{token}' is not a single token in the tokenizer for {self.model_name}"
                )
            self.choice_token_ids.append(token_ids[0])

        settings = HYPERPARAMS[self.hyperparam_setting]
        variant_stem = f"{self.model_name.replace('/', '_')}_{self.hyperparam_setting}"
        self.training_args = TrainingArguments(
            output_dir=os.path.join(OUTPUT_DIR, variant_stem),
            per_device_train_batch_size=settings['batch_size'],
            per_device_eval_batch_size=settings['batch_size'],
            num_train_epochs=settings['epochs'],
            learning_rate=settings['learning_rate'],
            save_steps=1000,
            save_total_limit=2,
            logging_dir=os.path.join(OUTPUT_DIR, f"{variant_stem}_logs"),
            logging_steps=500,
            evaluation_strategy="steps",
            eval_steps=1000,
            save_strategy="steps",
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            greater_is_better=True,
            remove_unused_columns=False
        )

    def compute_metrics(self, eval_pred):
        """Compute choice accuracy from first-token logits.

        Accepts an ``EvalPrediction``-like object (``predictions`` / ``label_ids``)
        or a plain ``(predictions, labels)`` tuple.

        Labels that are vocabulary ids of the choice tokens are translated to choice
        positions 0..3 before being compared with the argmax over the four choice
        logits; comparing raw token ids against positions would always report ~0
        accuracy. Labels that are not one of the choice token ids are left untouched,
        so callers that already pass positions 0..3 keep working.

        Returns:
            ``{"accuracy": float}``
        """
        # Unpack predictions and labels
        if hasattr(eval_pred, 'predictions') and hasattr(eval_pred, 'label_ids'):
            raw_preds = eval_pred.predictions
            labels = eval_pred.label_ids
        else:
            raw_preds, labels = eval_pred

        # If raw_preds is a tuple/list (e.g., (logits, _)), take the first element
        if isinstance(raw_preds, (tuple, list)):
            raw_preds = raw_preds[0]

        # Tensors may live on an accelerator and carry grad; move them to host first.
        if isinstance(raw_preds, torch.Tensor):
            raw_preds = raw_preds.detach().cpu().numpy()
        elif hasattr(raw_preds, 'numpy'):
            raw_preds = raw_preds.numpy()
        raw_preds = np.asarray(raw_preds)

        if isinstance(labels, torch.Tensor):
            labels = labels.detach().cpu().numpy()
        labels = np.asarray(labels)
        if labels.ndim > 1:
            # One label per example: keep the first position of each row.
            labels = labels.reshape(labels.shape[0], -1)[:, 0]

        # If seq2seq LM output [batch, seq_len, vocab_size], get only first-token logits
        if raw_preds.ndim == 3:
            first_logits = raw_preds[:, 0, :]
            # Keep only our A/B/C/D token columns
            label_logits = first_logits[:, self.choice_token_ids]
        else:
            # Already [batch, num_labels]
            label_logits = raw_preds

        # Translate token-id labels into choice positions so both sides of the
        # comparison live in the same 0..3 space.
        token_to_position = {tok: pos for pos, tok in enumerate(self.choice_token_ids)}
        labels = np.array(
            [token_to_position.get(int(label), int(label)) for label in labels],
            dtype=np.int64,
        )

        # Argmax and accuracy
        preds = np.argmax(label_logits, axis=-1)
        acc = float(np.mean(preds == labels))
        return {"accuracy": acc}

    def train(self, tokenized_datasets):
        """Fine-tune the model and save model + tokenizer to ``<OUTPUT_DIR>/<variant>_best``.

        Args:
            tokenized_datasets: Mapping ``{model_name: {'train': ..., 'eval': ...}}``;
                only the entry for ``self.model_name`` is used.

        Raises:
            Exception: Any training error is logged and re-raised unchanged.
        """
        best_dir = os.path.join(
            OUTPUT_DIR, f"{self.model_name.replace('/', '_')}_{self.hyperparam_setting}_best"
        )
        try:
            logging.info(f"Starting training for {self.model_name} with {self.hyperparam_setting}")
            trainer = Trainer(
                model=self.model,
                args=self.training_args,
                train_dataset=tokenized_datasets[self.model_name]['train'],
                eval_dataset=tokenized_datasets[self.model_name]['eval'],
                compute_metrics=self.compute_metrics
            )
            trainer.train()
            trainer.save_model(best_dir)
            self.tokenizer.save_pretrained(best_dir)
            logging.info(f"Training completed for {self.model_name} with {self.hyperparam_setting}")
        except Exception as e:
            logging.error(f"Training failed for {self.model_name} with {self.hyperparam_setting}: {str(e)}")
            raise

    def predict(self, dataset, batch_size=16, device='cuda'):
        """Score every example and return the predicted choice plus the four choice logits.

        Args:
            dataset: Supports ``len()`` and slice indexing that yields a mapping with
                equal-length ``input_ids`` and ``attention_mask`` lists.
            batch_size: Number of examples per forward pass.
            device: Requested device; falls back to CPU when CUDA is unavailable.

        Returns:
            ``(predictions, logits)``: ``predictions`` holds the argmax position per
            example and ``logits`` has one row per example and one column per choice.
            Both are empty arrays for an empty dataset.
        """
        # Device fallback
        actual_device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.model.to(actual_device).eval()

        logging.info(f"Running predict on {len(dataset)} samples with batch size {batch_size} on {actual_device}")

        if len(dataset) == 0:
            # np.vstack rejects an empty list, so return well-shaped empties instead.
            return (
                np.empty((0,), dtype=np.int64),
                np.empty((0, len(self.choice_token_ids)), dtype=np.float32),
            )

        # Decoder start token: pad if defined, otherwise eos. This mirrors the
        # decoder_input_ids written by the notebook's tokenization step.
        # NOTE(review): this need not equal model.config.decoder_start_token_id — confirm intended.
        start_token_id = (
            self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None
            else self.tokenizer.eos_token_id
        )

        all_logits = []
        with torch.no_grad():
            for i in range(0, len(dataset), batch_size):
                batch = dataset[i:i + batch_size]
                # Create input tensors directly from batch data
                input_ids = torch.tensor(batch['input_ids']).to(actual_device)
                attention_mask = torch.tensor(batch['attention_mask']).to(actual_device)

                # One decoder step per example, seeded with the start token
                decoder_input_ids = torch.full(
                    (input_ids.size(0), 1),
                    start_token_id,
                    dtype=torch.long,
                    device=actual_device
                )

                # Get model outputs
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    decoder_input_ids=decoder_input_ids
                )

                # Extract logits for the first generated token
                logits = outputs.logits[:, 0, :]
                choice_logits = logits[:, self.choice_token_ids]  # [batch_size, 4]
                all_logits.append(choice_logits.cpu().numpy())

        logits = np.vstack(all_logits)  # [num_samples, 4]
        predictions = logits.argmax(axis=1)

        return predictions, logits

In [None]:
# Cell 7 - Prepare Datasets for Training

from datasets import Dataset
from transformers import AutoTokenizer
import pandas as pd
import logging

def prepare_datasets(human_ranked_df, mmlu_df):
    """Build the raw and per-model tokenized train/eval datasets for the seq2seq setup.

    ``mmlu_df`` becomes the training set and ``human_ranked_df`` the evaluation set.
    Rows whose ``correct_answer`` is not one of 0..3 are dropped and written to a CSV
    in the working directory. Neither input DataFrame is modified.

    Args:
        human_ranked_df: DataFrame with ``question``, ``correct_answer`` and
            ``option_0`` .. ``option_3`` columns.
        mmlu_df: DataFrame with the same columns; ``choice1`` .. ``choice4`` are
            accepted as a fallback for the option columns.

    Returns:
        ``(tokenized_datasets, train_dataset, eval_dataset)`` where
        ``tokenized_datasets`` maps each name in the global ``MODEL_NAMES`` to
        ``{'train': ..., 'eval': ...}``.

    Raises:
        NameError: If the global ``MODEL_NAMES`` is not defined.
        ValueError: If required columns are missing or tokenization lacks a required field.
    """
    # Verify required global variables
    if 'MODEL_NAMES' not in globals():
        raise NameError("MODEL_NAMES global variable not defined. Please run Cell #3 first.")

    option_columns = ['option_0', 'option_1', 'option_2', 'option_3']
    required_columns = ['question', 'correct_answer'] + option_columns
    dataset_columns = ['input_text', 'correct_choice', 'question',
                       'option_0', 'option_1', 'option_2', 'option_3', 'correct_index']
    model_columns = ['input_ids', 'attention_mask', 'labels', 'decoder_input_ids']

    # Remember the sizes before any filtering so the final report is truthful.
    raw_mmlu_rows, raw_human_rows = len(mmlu_df), len(human_ranked_df)

    # Map answer index to choice token (A, B, C, D)
    def map_answer_to_choice(row, answer_col):
        try:
            answer_idx = int(row[answer_col])
            if answer_idx not in [0, 1, 2, 3]:
                return None
            return chr(65 + answer_idx)  # 0 -> 'A', 1 -> 'B', 2 -> 'C', 3 -> 'D'
        except (ValueError, TypeError, KeyError):
            return None

    # Add correct_choice / correct_index to a copy of df and drop rows where either is invalid.
    def label_and_filter(df, df_name, dropped_csv):
        labelled = df.copy()
        labelled['correct_choice'] = labelled.apply(
            lambda row: map_answer_to_choice(row, 'correct_answer'), axis=1
        )
        labelled['correct_index'] = labelled['correct_answer'].apply(
            lambda x: int(x) if pd.notna(x) and str(x).strip().isdigit() else None
        )
        invalid = labelled['correct_choice'].isna() | labelled['correct_index'].isna()
        dropped = labelled.loc[invalid].copy()
        if len(dropped) > 0:
            print(f"Dropped {len(dropped)} rows from {df_name}. Saved to {dropped_csv}")
            dropped.to_csv(dropped_csv, index=False)
        kept = labelled.loc[~invalid].copy()
        # Mixing ints with None upcasts the column; restore an integer dtype.
        kept['correct_index'] = kept['correct_index'].astype(int)
        return kept

    # Validate columns in mmlu_df
    missing_mmlu_cols = [col for col in required_columns if col not in mmlu_df.columns]
    if missing_mmlu_cols:
        print(f"Warning: Missing columns in mmlu_df: {missing_mmlu_cols}. Available columns: {mmlu_df.columns.tolist()}")
        fallback_map = {
            'choice1': 'option_0',
            'choice2': 'option_1',
            'choice3': 'option_2',
            'choice4': 'option_3'
        }
        if all(col in mmlu_df.columns for col in fallback_map):
            print("Falling back to choice1, choice2, choice3, choice4 for mmlu_df")
            mmlu_df = mmlu_df.rename(columns=fallback_map)
        else:
            raise ValueError("Cannot proceed: mmlu_df missing required option columns")
        # The rename only supplies option columns; question/correct_answer may still be absent.
        still_missing = [col for col in required_columns if col not in mmlu_df.columns]
        if still_missing:
            raise ValueError(f"Cannot proceed: mmlu_df missing required columns: {still_missing}")

    # Validate columns in human_ranked_df
    missing_human_cols = [col for col in required_columns if col not in human_ranked_df.columns]
    if missing_human_cols:
        print(f"Warning: Missing columns in human_ranked_df: {missing_human_cols}. Available columns: {human_ranked_df.columns.tolist()}")
        raise ValueError("Cannot proceed: human_ranked_df missing required columns")

    mmlu_df = label_and_filter(mmlu_df, 'mmlu_df', './dropped_mmlu_choice_rows.csv')
    human_ranked_df = label_and_filter(
        human_ranked_df, 'human_ranked_df', './dropped_human_ranked_choice_rows.csv'
    )

    # Create input text
    def create_input_text(row):
        return f"Question: {row['question']} Options: A: {row['option_0']} B: {row['option_1']} C: {row['option_2']} D: {row['option_3']}"

    mmlu_df['input_text'] = mmlu_df.apply(create_input_text, axis=1)
    human_ranked_df['input_text'] = human_ranked_df.apply(create_input_text, axis=1)

    # preserve_index=False keeps the filtered pandas index from leaking into the
    # datasets as an '__index_level_0__' column.
    train_dataset = Dataset.from_pandas(mmlu_df[dataset_columns], preserve_index=False)
    eval_dataset = Dataset.from_pandas(human_ranked_df[dataset_columns], preserve_index=False)

    # Tokenization
    tokenized_datasets = {}
    for model_name in MODEL_NAMES:
        try:
            print(f"Tokenizing datasets for model: {model_name}")
            tokenizer = AutoTokenizer.from_pretrained(model_name)

            def tokenize_function(examples):
                model_inputs = tokenizer(
                    examples['input_text'],
                    max_length=512,
                    truncation=True,
                    padding='max_length',
                    return_tensors=None
                )
                # Label = vocabulary id of the answer letter (unk if it is not a single token)
                labels = tokenizer(examples['correct_choice'], add_special_tokens=False).input_ids
                model_inputs['labels'] = [label[0] if len(label) == 1 else tokenizer.unk_token_id for label in labels]
                # Single-step decoder input: pad if defined, otherwise eos
                start_token = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
                model_inputs['decoder_input_ids'] = [[start_token]] * len(examples['input_text'])
                return model_inputs

            # Only keep columns needed by the model
            tokenized_train = train_dataset.map(
                tokenize_function,
                batched=True,
                remove_columns=[col for col in train_dataset.column_names if col not in model_columns]
            )
            tokenized_eval = eval_dataset.map(
                tokenize_function,
                batched=True,
                remove_columns=[col for col in eval_dataset.column_names if col not in model_columns]
            )

            # Verify the tokenized datasets have the required fields for DistractorAnalysis.predict
            required_fields = ['input_ids', 'attention_mask', 'labels']
            missing_train = [f for f in required_fields if f not in tokenized_train.column_names]
            missing_eval = [f for f in required_fields if f not in tokenized_eval.column_names]

            if missing_train or missing_eval:
                print(f"Warning: Missing fields for {model_name}. Train: {missing_train}, Eval: {missing_eval}")
                raise ValueError(f"Tokenized datasets missing required fields for model {model_name}")

            tokenized_datasets[model_name] = {'train': tokenized_train, 'eval': tokenized_eval}
            print(f"Tokenized datasets for {model_name}: train={len(tokenized_train)} rows, eval={len(tokenized_eval)} rows")
        except Exception as e:
            print(f"Failed to tokenize for {model_name}: {str(e)}")
            logging.error(f"Failed to tokenize for {model_name}: {str(e)}")
            raise

    # Log sizes for sanity check
    print(f"Original sizes - MMLU: {raw_mmlu_rows}, Human Ranked: {raw_human_rows}")
    print(f"After processing - Train dataset: {len(train_dataset)}, Eval dataset: {len(eval_dataset)}")

    return tokenized_datasets, train_dataset, eval_dataset

# Execute: build the datasets; a failure is reported on stdout and in the log, then re-raised.
try:
    tokenized_datasets, train_dataset, eval_dataset = prepare_datasets(human_ranked_df, mmlu_df)
    print("Datasets prepared successfully!")
except Exception as exc:
    failure_text = f"Error preparing datasets: {str(exc)}"
    print(failure_text)
    logging.error(failure_text)
    raise

In [None]:
# Cell 8 - Train Models

import os
import logging
from transformers import Trainer, TrainingArguments

def train_all_models(tokenized_datasets, train_dataset, eval_dataset):
    """Train one DistractorAnalysis per (model, hyperparameter setting) pair.

    Progress goes to stdout and to ``<OUTPUT_DIR>/training_log.txt``. A variant that
    fails is reported and skipped; the remaining variants are still trained.

    Args:
        tokenized_datasets: ``{model_name: {'train': ..., 'eval': ...}}``.
        train_dataset: Raw training dataset, used only for size checks.
        eval_dataset: Raw evaluation dataset, used only for size checks.

    Raises:
        ValueError: If either raw dataset is empty.
    """
    logging.basicConfig(
        filename=os.path.join(OUTPUT_DIR, 'training_log.txt'),
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

    def train_variant(model_name, hyperparam_setting):
        # Train a single variant and confirm its "_best" directory exists; never raises.
        variant_name = f"{model_name.replace('/', '_')}_{hyperparam_setting}"
        print(f"\n### Training {variant_name} ###")
        logging.info(f"Starting training for {variant_name}")

        try:
            analysis = DistractorAnalysis(model_name, hyperparam_setting)
            analysis.train(tokenized_datasets)
            logging.info(f"trainer.train() completed for {variant_name}")

            # verify checkpoint
            checkpoint_path = os.path.join(OUTPUT_DIR, f"{variant_name}_best")
            if not os.path.exists(checkpoint_path):
                print(f"⚠️  Checkpoint not found at {checkpoint_path}")
                logging.warning(f"Checkpoint not found at {checkpoint_path}")
            else:
                print(f"✔️  Checkpoint found at {checkpoint_path}")
                logging.info(f"Checkpoint verified at {checkpoint_path}")

            done_msg = f"Completed training for {variant_name}"
            print(done_msg)
            logging.info(done_msg)
        except Exception as exc:
            error_msg = f"Failed training for {variant_name}: {exc}"
            print(error_msg)
            logging.error(error_msg)

    # start message
    print(f"Starting train_all_models with {len(MODEL_NAMES)} models × {len(HYPERPARAMS)} hyperparam settings")
    logging.info(f"MODEL_NAMES={MODEL_NAMES}, HYPERPARAMS={list(HYPERPARAMS.keys())}")

    # dataset size report: stdout first, then the log
    size_lines = [
        f"Raw train_dataset size: {len(train_dataset)} rows",
        f"Raw eval_dataset size: {len(eval_dataset)} rows",
    ]
    for line in size_lines:
        print(line)
    for line in size_lines:
        logging.info(line)

    # refuse to continue with an empty split (train is checked first)
    for split_label, split in (("train_dataset", train_dataset), ("eval_dataset", eval_dataset)):
        if len(split) == 0:
            logging.error(f"{split_label} is empty.")
            raise ValueError(f"{split_label} is empty.")

    if len(eval_dataset) < 50:
        print("Warning: eval_dataset is very small (<50 rows).")
        logging.warning("eval_dataset is very small (<50 rows).")

    expected_msg = f"Expecting to train {len(MODEL_NAMES) * len(HYPERPARAMS)} variants"
    print(expected_msg)
    logging.info(expected_msg)

    for model_name in MODEL_NAMES:
        per_model = tokenized_datasets[model_name]

        # log tokenized dataset columns
        train_cols = per_model['train'].column_names
        eval_cols = per_model['eval'].column_names
        print(f"{model_name}: train cols={train_cols}, eval cols={eval_cols}")
        logging.info(f"{model_name} columns: train={train_cols}, eval={eval_cols}")

        # log a sample label
        first_example = per_model['train'][0]
        logging.info(f"Sample train label for {model_name}: {first_example['labels']}")

        for hyperparam_setting in HYPERPARAMS:
            train_variant(model_name, hyperparam_setting)

# Execute training: one run per model × hyperparameter setting. A variant that
# raises is printed/logged and skipped inside train_all_models, so this call
# itself only fails on the up-front dataset checks.
train_all_models(tokenized_datasets, train_dataset, eval_dataset)

In [None]:
# Cell 9 - Generate predictions

import pandas as pd
import torch
import logging
import os

def generate_predictions(eval_dataset, tokenized_datasets):
    """Score the evaluation set with every trained variant and write ``predictions.csv``.

    For each (model, hyperparameter setting) pair the fine-tuned weights are loaded
    from ``<OUTPUT_DIR>/<variant>_best/pytorch_model.bin``. A variant that fails
    (missing checkpoint, misaligned datasets, ...) is reported and skipped, and
    contributes no rows at all.

    Args:
        eval_dataset: Raw evaluation dataset supplying the metadata columns
            (``question``, ``correct_index``, ``option_0`` .. ``option_3``).
        tokenized_datasets: ``{model_name: {'eval': tokenized dataset, ...}}``; row
            ``i`` of the tokenized eval set must correspond to row ``i`` of
            ``eval_dataset``.

    Returns:
        DataFrame with one row per (question, variant); the same table is written to
        ``<OUTPUT_DIR>/predictions.csv``.
    """
    log_file = os.path.join(OUTPUT_DIR, 'predictions_log.txt')
    logging.basicConfig(filename=log_file, level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    results = []
    expected_rows = len(eval_dataset)

    for model_name in MODEL_NAMES:
        for hyperparam_setting in HYPERPARAMS:
            variant_name = f"{model_name.replace('/', '_')}_{hyperparam_setting}"
            checkpoint_path = os.path.join(OUTPUT_DIR, f"{variant_name}_best")

            print(f"Generating predictions for {variant_name}...")
            logging.info(f"Generating predictions for {variant_name}")

            try:
                da = DistractorAnalysis(model_name, hyperparam_setting)
                checkpoint_file = os.path.join(checkpoint_path, 'pytorch_model.bin')
                if not os.path.exists(checkpoint_file):
                    raise FileNotFoundError(f"Checkpoint not found at {checkpoint_file}")

                da.model.load_state_dict(torch.load(checkpoint_file, map_location=device))

                # Use the tokenized eval dataset specific to this model
                tokenized_eval = tokenized_datasets[model_name]['eval']

                predictions, choice_logits = da.predict(
                    tokenized_eval,
                    batch_size=HYPERPARAMS[hyperparam_setting]['batch_size'],
                    device=device
                )

                # Predictions are matched to metadata purely by position, so a length
                # mismatch would silently pair answers with the wrong questions.
                if len(predictions) != expected_rows:
                    raise ValueError(
                        f"Got {len(predictions)} predictions for {expected_rows} eval rows"
                    )

                # Collect locally and publish only on success, so a failure halfway
                # through never leaves partial rows for this variant. Each row is
                # fetched once instead of once per column.
                variant_rows = []
                for i, row in enumerate(eval_dataset):
                    variant_rows.append({
                        'question': row['question'],
                        'model_name': model_name,
                        'variant': hyperparam_setting,
                        'predicted_choice': int(predictions[i]),
                        'correct_choice_index': int(row['correct_index']),
                        'logit_A': float(choice_logits[i][0]),
                        'logit_B': float(choice_logits[i][1]),
                        'logit_C': float(choice_logits[i][2]),
                        'logit_D': float(choice_logits[i][3]),
                        'option_0': row['option_0'],
                        'option_1': row['option_1'],
                        'option_2': row['option_2'],
                        'option_3': row['option_3']
                    })
                results.extend(variant_rows)

                print(f"Completed predictions for {variant_name}")
                logging.info(f"Completed predictions for {variant_name}")
            except Exception as e:
                error_msg = f"Failed predictions for {variant_name}: {str(e)}"
                print(error_msg)
                logging.error(error_msg)

    results_df = pd.DataFrame(results)
    output_csv = os.path.join(OUTPUT_DIR, 'predictions.csv')
    results_df.to_csv(output_csv, index=False)
    print(f"Predictions saved to {output_csv}")
    logging.info(f"Predictions saved to {output_csv}")

    return results_df

# Execute the function with both the original evaluation dataset (for metadata)
# and the tokenized datasets (for model input). The returned DataFrame is the
# same table that generate_predictions writes to predictions.csv.
predictions_df = generate_predictions(eval_dataset, tokenized_datasets)

## Prediction Method 2 - Classification Head (Encoder Only)

In [None]:
# Cell 6 – Distractor Analysis Class (Classification Head)

from transformers import AutoTokenizer, Trainer, TrainingArguments
import torch
import torch.nn as nn
import numpy as np
import os
import logging

# 4-way classification head on top of T5/BART encoders
class ClassificationModel(nn.Module):
    """Encoder of a pretrained T5 or BART model followed by a linear classification head.

    The encoder output at sequence position 0 is passed through dropout and a linear
    layer that produces ``num_labels`` logits.
    NOTE(review): position 0 is presumably the ``<s>`` token for BART; whether T5
    inputs carry an equivalent leading token is not visible here — confirm.
    """

    def __init__(self, model_name, num_labels):
        """Load the pretrained backbone and keep only its encoder.

        Args:
            model_name: Checkpoint name; a name containing "t5" (case-insensitive)
                selects ``T5Model``, anything else ``BartModel``.
            num_labels: Number of output classes.
        """
        super().__init__()
        if 't5' in model_name.lower():
            from transformers import T5Model as backbone_cls
        else:
            from transformers import BartModel as backbone_cls
        self.encoder = backbone_cls.from_pretrained(model_name).get_encoder()
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.encoder.config.d_model, num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Return ``{'logits': [batch, num_labels], 'loss': cross-entropy or None}``."""
        hidden_states = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        first_position = hidden_states[:, 0, :]
        logits = self.classifier(self.dropout(first_position))

        outputs = {'logits': logits, 'loss': None}
        if labels is not None:
            # Functional form of nn.CrossEntropyLoss() with its default settings.
            outputs['loss'] = nn.functional.cross_entropy(logits, labels)
        return outputs

class DistractorAnalysis:
    """Train and query a ``ClassificationModel`` for one (model, hyperparameter) variant.

    Uses the notebook globals ``HYPERPARAMS``, ``OUTPUT_DIR`` and ``NUM_LABELS``.
    """

    def __init__(self, model_name, hyperparam_setting):
        """Create tokenizer, classification model and TrainingArguments.

        Raises:
            ValueError: If ``hyperparam_setting`` is not a key of ``HYPERPARAMS``.
        """
        self.model_name = model_name
        self.hyperparam_setting = hyperparam_setting

        if hyperparam_setting not in HYPERPARAMS:
            raise ValueError(f"Invalid hyperparam_setting: {hyperparam_setting}")

        # Tokenizer and encoder + linear head
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = ClassificationModel(model_name, NUM_LABELS)

        settings = HYPERPARAMS[hyperparam_setting]
        run_stem = f"{model_name.replace('/', '_')}_{hyperparam_setting}"
        self.training_args = TrainingArguments(
            output_dir=os.path.join(OUTPUT_DIR, run_stem),
            logging_dir=os.path.join(OUTPUT_DIR, f"{run_stem}_logs"),
            per_device_train_batch_size=settings['batch_size'],
            per_device_eval_batch_size=settings['batch_size'],
            num_train_epochs=settings['epochs'],
            learning_rate=settings['learning_rate'],
            evaluation_strategy="steps",
            eval_steps=1000,
            logging_steps=500,
            save_steps=1000,
            save_total_limit=2,
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            greater_is_better=True,
            remove_unused_columns=False,
        )

    def compute_metrics(self, eval_pred):
        """Return ``{'accuracy': ...}`` for an EvalPrediction-like object or a ``(logits, labels)`` pair."""
        try:
            logits, labels = eval_pred.predictions, eval_pred.label_ids
        except AttributeError:
            # Plain tuple input
            logits, labels = eval_pred
        hits = np.argmax(logits, axis=-1) == labels
        return {'accuracy': np.mean(hits)}

    def train(self, tokenized_datasets):
        """Fine-tune on this model's tokenized splits and save to ``<OUTPUT_DIR>/<variant>_best``."""
        splits = tokenized_datasets[self.model_name]
        trainer = Trainer(
            model=self.model,
            args=self.training_args,
            train_dataset=splits['train'],
            eval_dataset=splits['eval'],
            compute_metrics=self.compute_metrics
        )
        trainer.train()

        best_dir = os.path.join(
            OUTPUT_DIR,
            f"{self.model_name.replace('/', '_')}_{self.hyperparam_setting}_best"
        )
        trainer.save_model(best_dir)
        self.tokenizer.save_pretrained(best_dir)

    def predict(self, dataset, batch_size=16, device='cuda'):
        """Classify every example of ``dataset``.

        Args:
            dataset: Supports ``len()`` and slice indexing that yields a mapping with
                an ``input_text`` list; the text is tokenized here.
            batch_size: Examples per forward pass.
            device: Requested device; CPU is used when CUDA is unavailable.

        Returns:
            ``(predictions, logits)`` as numpy arrays, one entry/row per example.
        """
        target = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.model.to(target).eval()

        logit_chunks = []
        pred_chunks = []
        with torch.no_grad():
            for start in range(0, len(dataset), batch_size):
                texts = dataset[start : start + batch_size]['input_text']
                encoded = self.tokenizer(
                    texts,
                    padding='max_length',
                    truncation=True,
                    max_length=512,
                    return_tensors='pt'
                ).to(target)

                # Encoder-only classifier: no decoder inputs required
                batch_logits = self.model(
                    input_ids=encoded.input_ids,
                    attention_mask=encoded.attention_mask
                )['logits']                                   # [batch, NUM_LABELS]
                batch_preds = torch.argmax(batch_logits, dim=-1)  # [batch]

                logit_chunks.append(batch_logits.cpu().numpy())
                pred_chunks.append(batch_preds.cpu().numpy())

        return np.concatenate(pred_chunks), np.vstack(logit_chunks)

In [None]:
# Cell 7 – Prepare Datasets

from datasets import Dataset
import pandas as pd
from transformers import AutoTokenizer

def prepare_datasets(human_ranked_df, mmlu_df):
    """Build raw and per-model tokenized datasets for the classification-head setup.

    ``mmlu_df`` becomes the training set and ``human_ranked_df`` the evaluation set.
    Both inputs are copied first, so the caller's DataFrames are left unmodified.

    Args:
        human_ranked_df: DataFrame with ``question``, ``correct_answer`` and
            ``option_0`` .. ``option_3`` columns.
        mmlu_df: DataFrame with the same columns.

    Returns:
      tokenized_datasets: {model_name: {'train','eval'}}
      train_dataset (HF Dataset), eval_dataset (HF Dataset)

    Raises:
        ValueError: If a ``correct_answer`` value does not index one of the options.
    """
    option_cols = ['option_0','option_1','option_2','option_3']

    # Build unified input_text
    def make_input(qrow):
        opts = " ".join(f"{chr(65+i)}: {qrow[c]}" for i,c in enumerate(option_cols))
        return f"Question: {qrow['question']} Options: {opts}"

    # Return a copy of df with input_text and integer labels added.
    def with_text_and_labels(df, df_name):
        frame = df.copy()
        frame['input_text'] = frame.apply(make_input, axis=1)
        frame['labels'] = frame['correct_answer'].astype(int)
        # A label outside the option range has no matching answer; fail here with a
        # readable message instead of deep inside the loss computation.
        out_of_range = ~frame['labels'].between(0, len(option_cols) - 1)
        if out_of_range.any():
            raise ValueError(
                f"{df_name} has {int(out_of_range.sum())} rows whose correct_answer "
                f"is outside 0..{len(option_cols) - 1}"
            )
        return frame

    mmlu_frame = with_text_and_labels(mmlu_df, 'mmlu_df')
    human_frame = with_text_and_labels(human_ranked_df, 'human_ranked_df')

    # Create HF Datasets
    train_dataset = Dataset.from_pandas(mmlu_frame[['input_text','labels']].reset_index(drop=True))
    eval_dataset  = Dataset.from_pandas(
        human_frame[['input_text','labels','question'] + option_cols].reset_index(drop=True)
    )

    # Tokenize per model
    tokenized_datasets = {}
    for model_name in MODEL_NAMES:
        tok = AutoTokenizer.from_pretrained(model_name)

        def tokenize_fn(examples):
            out = tok(examples['input_text'],
                      padding='max_length',
                      truncation=True,
                      max_length=512)
            out['labels'] = examples['labels']
            return out

        ttrain = train_dataset.map(tokenize_fn, batched=True, remove_columns=train_dataset.column_names)
        teval  = eval_dataset.map(tokenize_fn, batched=True, remove_columns=eval_dataset.column_names)
        tokenized_datasets[model_name] = {'train': ttrain, 'eval': teval}
        print(f"Tokenized for {model_name}: train={len(ttrain)}, eval={len(teval)}")

    return tokenized_datasets, train_dataset, eval_dataset



# Execute: build the datasets; a failure is reported on stdout and in the log, then re-raised.
try:
    tokenized_datasets, train_dataset, eval_dataset = prepare_datasets(human_ranked_df, mmlu_df)
    print("Datasets prepared successfully!")
except Exception as exc:
    failure_text = f"Error preparing datasets: {str(exc)}"
    print(failure_text)
    logging.error(failure_text)
    raise

In [None]:
# Cell 8 – Train All Models

import os
import logging

def train_all_models(tokenized_datasets, train_dataset, eval_dataset):
    """Train every (model, hyperparameter setting) variant of the classification model.

    Logs to ``<OUTPUT_DIR>/training_log.txt``. A failing variant is printed, logged
    and skipped; the remaining variants are still trained.

    Args:
        tokenized_datasets: ``{model_name: {'train': ..., 'eval': ...}}``.
        train_dataset: Raw training dataset, used only for the emptiness check.
        eval_dataset: Raw evaluation dataset, used only for the emptiness check.

    Raises:
        ValueError: If either raw dataset is empty.
    """
    logging.basicConfig(
        filename=os.path.join(OUTPUT_DIR, 'training_log.txt'),
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
    )

    # Both splits must contain at least one row
    if not (len(train_dataset) and len(eval_dataset)):
        raise ValueError("Empty train or eval dataset.")

    variants = [(name, setting) for name in MODEL_NAMES for setting in HYPERPARAM_SETTINGS]
    for model_name, hp in variants:
        variant_name = f"{model_name.replace('/', '_')}_{hp}"
        print(f"Training {variant_name}...")
        logging.info(f"Training {variant_name}")
        try:
            analysis = DistractorAnalysis(model_name, hp)
            analysis.train(tokenized_datasets)
        except Exception as exc:
            print(f"✗ Failed {variant_name}: {exc}")
            logging.error(f"{variant_name} failed: {exc}")
        else:
            print(f"✓ Completed {variant_name}")

# Execute: train every model × hyperparameter variant. Per-variant failures are
# printed/logged and skipped inside train_all_models.
train_all_models(tokenized_datasets, train_dataset, eval_dataset)

In [None]:
# Cell 9 – Generate Predictions

import os
import pandas as pd
import torch

def generate_predictions(eval_dataset):
    """
    Run each trained model variant over ``eval_dataset`` and save the results.

    For every combination of the notebook-level ``MODEL_NAMES`` and
    ``HYPERPARAM_SETTINGS`` the weights are loaded from
    ``OUTPUT_DIR/<variant>_best/pytorch_model.bin`` into a fresh
    ``DistractorAnalysis`` and ``predict`` is called on ``eval_dataset``.
    One row per (question, variant) is collected and everything is written
    to ``OUTPUT_DIR/predictions.csv``.

    Variants whose checkpoint file does not exist are reported and skipped
    rather than aborting the run, so predictions from the variants that did
    train are not lost.

    Args:
        eval_dataset: Iterable of records providing the keys ``question``,
            ``labels`` and ``option_0`` .. ``option_3``.

    Returns:
        pandas.DataFrame with the same content as the written CSV.

    Raises:
        FileNotFoundError: If checkpoints are missing for every variant.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    num_choices = 4  # logit_0..logit_3 / option_0..option_3
    rows = []
    missing_variants = []
    predicted_count = 0

    for model_name in MODEL_NAMES:
        for hp in HYPERPARAM_SETTINGS:
            variant_name = f"{model_name.replace('/', '_')}_{hp}"
            ckpt_dir = os.path.join(OUTPUT_DIR, f"{variant_name}_best")
            model_bin = os.path.join(ckpt_dir, 'pytorch_model.bin')

            # A variant may have no checkpoint; skip it instead of crashing
            # and throwing away the rows already collected.
            if not os.path.isfile(model_bin):
                print(f"✗ Skipping {variant_name}: checkpoint not found at {model_bin}")
                missing_variants.append(variant_name)
                continue

            print(f"Predicting {variant_name}…")

            da = DistractorAnalysis(model_name, hp)
            da.model.load_state_dict(torch.load(model_bin, map_location=device))

            preds, logits = da.predict(eval_dataset, batch_size=HYPERPARAMS[hp]['batch_size'], device=device)

            # Build rows (column order: metadata, logits, then option texts)
            for i, qrow in enumerate(eval_dataset):
                row = {
                    'question':            qrow['question'],
                    'model_name':          model_name,
                    'variant':             hp,
                    'predicted_choice':    int(preds[i]),
                    'correct_choice_index': int(qrow['labels']),
                }
                for k in range(num_choices):
                    row[f'logit_{k}'] = float(logits[i, k])
                for k in range(num_choices):
                    row[f'option_{k}'] = qrow[f'option_{k}']
                rows.append(row)

            predicted_count += 1
            print(f"✓ Done {variant_name}")

            # Drop references to this variant before the next one is built.
            # NOTE(review): assumes the model may live on the GPU; releasing
            # cached blocks is harmless if it does not.
            del da, preds, logits
            if device == 'cuda':
                torch.cuda.empty_cache()

    if missing_variants:
        print(f"Skipped {len(missing_variants)} variant(s) without checkpoint: "
              f"{', '.join(missing_variants)}")
        if predicted_count == 0:
            raise FileNotFoundError(
                f"No checkpoints found under {OUTPUT_DIR} for any variant; "
                "nothing to predict."
            )

    # Compile DataFrame + write CSV
    predictions_df = pd.DataFrame(rows)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    out_csv = os.path.join(OUTPUT_DIR, 'predictions.csv')
    predictions_df.to_csv(out_csv, index=False)
    print(f"All predictions saved to {out_csv}")
    return predictions_df

# Run inference for all variants and keep the returned DataFrame in predictions_df
predictions_df = generate_predictions(eval_dataset=eval_dataset)