# In [None]:
import os
import ast
import torch
import pandas as pd
import numpy as np
from dataclasses import dataclass
from typing import Optional, Union
from datasets import load_dataset, Dataset, concatenate_datasets
from transformers import (
    AutoTokenizer, 
    AutoModelForMultipleChoice, 
    TrainingArguments, 
    Trainer,
    PreTrainedTokenizerBase
)
from transformers.tokenization_utils_base import PaddingStrategy

# -----------------------------------------------------------------------------
# 1. Configuration & Hyperparameters (Ref: ReClor Paper Table 9)
# -----------------------------------------------------------------------------
# Model checkpoint, input files and output directory used by the script below.
MODEL_ID = "FacebookAI/roberta-large"
TRAIN_FILE = "train.json"
VAL_FILE = "val.json"
TEST_FILE = "test.csv"
OUTPUT_DIR = "./reclor_roberta_large_finetuned"

# Paper settings for RoBERTa-Large
MAX_SEQ_LENGTH = 256        # Passed to the tokenizer as max_length (with truncation)
LEARNING_RATE = 1e-5
NUM_TRAIN_EPOCHS = 10
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.1
ADAM_EPSILON = 1e-6
ADAM_BETAS = (0.9, 0.98)    # (beta1, beta2)

# Batch size handling
PER_DEVICE_BATCH_SIZE = 2   # Adjust based on your VRAM (2 fits most Colab GPUs)
EFFECTIVE_BATCH_SIZE = 24   # As per paper
# Floor division yields 0 once the per-device batch exceeds the effective
# batch; clamp to 1 so the result is always a usable accumulation step count.
GRAD_ACCUMULATION = max(1, EFFECTIVE_BATCH_SIZE // PER_DEVICE_BATCH_SIZE)

# Validation set target size (standard ReClor Val size is 500)
TARGET_VAL_SIZE = 500

# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# -----------------------------------------------------------------------------
# 2. Data Preprocessing & Collator
# -----------------------------------------------------------------------------
def preprocess_function(examples):
    """Tokenize a batch of examples as (context, question + option) pairs.

    Relies on the module-level ``tokenizer`` and ``MAX_SEQ_LENGTH``.

    Args:
        examples: Batched mapping with parallel "context", "question" and
            "answers" columns. Each "answers" entry is either a sequence of
            option strings or a string holding a Python-literal list of them.

    Returns:
        A dict with one key per tokenizer output field. Each value is a list
        with one item per example; that item is the list of the field's
        values for each of the example's answer options, in option order.
    """
    first_sentences = []
    second_sentences = []
    # Options per example, recorded so the flat tokenizer output can be
    # regrouped without assuming every example has exactly four options.
    choice_counts = []

    for context, header, options in zip(
        examples["context"], examples["question"], examples["answers"]
    ):
        # Robust parsing for stringified lists
        if isinstance(options, str):
            try:
                options = ast.literal_eval(options)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                # Best-effort fallback for unparseable text: four empty options.
                options = ["", "", "", ""]

        # Repeat the context once per option so both lists stay aligned even
        # when an example does not carry exactly four options.
        first_sentences.extend([context] * len(options))
        second_sentences.extend(f"{header} {option}" for option in options)
        choice_counts.append(len(options))

    # Tokenize the flat pair lists; padding is left off here (padding=False).
    tokenized_examples = tokenizer(
        first_sentences,
        second_sentences,
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
        padding=False,
    )

    # Un-flatten: slice each field back into one group per example.
    grouped = {}
    for key, values in tokenized_examples.items():
        rows = []
        start = 0
        for count in choice_counts:
            rows.append(values[start : start + count])
            start += count
        grouped[key] = rows
    return grouped

@dataclass
class DataCollatorForMultipleChoice:
    """Pad multiple-choice features and reshape them to (batch, choices, -1).

    Each incoming feature maps field names to a list holding one entry per
    answer choice, plus an optional "label"/"labels" value. All choices of all
    features are padded together through ``tokenizer.pad`` and each resulting
    tensor is viewed as ``(batch_size, num_choices, -1)``.
    """

    # Object whose ``pad`` method performs the padding.
    tokenizer: PreTrainedTokenizerBase
    # The three settings below are forwarded unchanged to ``tokenizer.pad``.
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        """Collate ``features`` into a dict of tensors.

        Args:
            features: Non-empty list of per-example dicts. The number of
                choices is taken from the first feature's "input_ids".

        Returns:
            Dict of tensors viewed as ``(batch_size, num_choices, -1)``; when
            the first feature carries a label, an int64 "labels" tensor is
            added as well.
        """
        label_name = "label" if "label" in features[0] else "labels"
        # Read the labels instead of popping them, so the caller's feature
        # dicts are left intact and can be collated again.
        if label_name in features[0]:
            labels = [feature[label_name] for feature in features]
        else:
            labels = None

        batch_size = len(features)
        num_choices = len(features[0]["input_ids"])

        # One flat entry per (example, choice), skipping the label field.
        flattened_features = [
            {k: v[i] for k, v in feature.items() if k != label_name}
            for feature in features
            for i in range(num_choices)
        ]

        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # Restore the (example, choice) structure on every padded tensor.
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}

        if labels is not None:
            batch["labels"] = torch.tensor(labels, dtype=torch.int64)

        return batch

# -----------------------------------------------------------------------------
# 3. Main Execution
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Pipeline: load the data, drop train/val rows whose (context, question)
    # also appears in the test file, top up the validation split from the
    # training split, fine-tune, then predict on the test file and write
    # submission.csv.

    # --- Load Tokenizer ---
    # `tokenizer` is bound at module level here; preprocess_function reads it
    # as a global.
    print(f"Loading tokenizer from {MODEL_ID}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    # --- Load Local Data ---
    print("Loading datasets...")
    data_files = {"train": TRAIN_FILE, "validation": VAL_FILE}
    dataset = load_dataset("json", data_files=data_files)

    # -------------------------------------------------------------------------
    # STEP A: Remove Data Leakage
    # -------------------------------------------------------------------------
    print("Checking for data leakage (overlapping questions in Test set)...")

    test_df_ref = pd.read_csv(TEST_FILE)
    # Create signature tuple (context, question) for unique identification
    test_signatures = set(zip(test_df_ref['context'].str.strip(), test_df_ref['question'].str.strip()))

    print(f"Test set contains {len(test_signatures)} unique context/question pairs.")

    def filter_leakage(example):
        """Return True when the example's stripped (context, question) is not in the test file."""
        c = example['context'].strip()
        q = example['question'].strip()
        # Keep example only if it is NOT in the test signatures
        return (c, q) not in test_signatures

    # Store original sizes
    orig_train_len = len(dataset['train'])
    orig_val_len = len(dataset['validation'])

    # Apply filter
    dataset = dataset.filter(filter_leakage)

    print(f"Train filtered: {orig_train_len} -> {len(dataset['train'])} (Removed {orig_train_len - len(dataset['train'])})")
    print(f"Val filtered:   {orig_val_len} -> {len(dataset['validation'])} (Removed {orig_val_len - len(dataset['validation'])})")

    # -------------------------------------------------------------------------
    # STEP B: Rebalance Validation Set (if too small)
    # -------------------------------------------------------------------------
    current_val_len = len(dataset['validation'])

    if current_val_len < TARGET_VAL_SIZE:
        needed = TARGET_VAL_SIZE - current_val_len
        print(f"\nValidation set is below target ({current_val_len} < {TARGET_VAL_SIZE}).")
        print(f"Transferring {needed} samples from Training to Validation...")

        # Ensure we don't drain the training set completely
        if needed > len(dataset['train']) * 0.2:
            print("Warning: Requested transfer is large relative to training set. Capping transfer.")
            needed = int(len(dataset['train']) * 0.2)

        if needed > 0:
            # Split the training set
            # We use a fixed seed for reproducibility
            split_data = dataset['train'].train_test_split(test_size=needed, seed=42)

            new_train_set = split_data['train']
            moved_samples = split_data['test']

            # Combine existing val with moved samples
            new_val_set = concatenate_datasets([dataset['validation'], moved_samples])

            # Update main dataset object
            dataset['train'] = new_train_set
            dataset['validation'] = new_val_set

            print("Rebalancing Complete.")
            print(f"Final Train Size: {len(dataset['train'])}")
            print(f"Final Val Size:   {len(dataset['validation'])}")
        else:
            # The cap rounds down to zero for an empty or very small training
            # split; a zero-sized split request is not valid, so skip it.
            print("Training set is too small to transfer any samples; validation set left unchanged.")
    else:
        print("\nValidation set size is sufficient.")

    # -------------------------------------------------------------------------
    # Training
    # -------------------------------------------------------------------------
    print("\nPreprocessing datasets for training...")
    tokenized_reclor = dataset.map(preprocess_function, batched=True)

    print(f"Loading {MODEL_ID} for Multiple Choice...")
    model = AutoModelForMultipleChoice.from_pretrained(MODEL_ID)
    model.to(device)

    # Paper Specs via TrainingArguments
    # Evaluation and checkpointing both run once per epoch; the checkpoint with
    # the best "accuracy" is reloaded when training ends.
    # NOTE(review): one checkpoint is written per epoch and no
    # save_total_limit is set — confirm the available disk space is enough.
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
        per_device_eval_batch_size=PER_DEVICE_BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUMULATION,
        num_train_epochs=NUM_TRAIN_EPOCHS,
        weight_decay=WEIGHT_DECAY,
        warmup_ratio=WARMUP_RATIO,
        adam_epsilon=ADAM_EPSILON,
        adam_beta1=ADAM_BETAS[0],
        adam_beta2=ADAM_BETAS[1],
        # NOTE(review): None is presumably meant to disable gradient clipping
        # — TODO confirm against the installed transformers version.
        max_grad_norm=None,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        logging_steps=50,
        fp16=torch.cuda.is_available(),
        report_to="none"
    )

    def compute_metrics(eval_pred):
        """Return {"accuracy": fraction of rows whose arg-max prediction equals the label}."""
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        # Convert the NumPy scalar to a built-in float.
        return {"accuracy": float((predictions == labels).mean())}

    # NOTE(review): newer transformers releases appear to prefer
    # `processing_class=` over `tokenizer=` — confirm against the installed
    # version before changing.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_reclor["train"],
        eval_dataset=tokenized_reclor["validation"],
        tokenizer=tokenizer,
        data_collator=DataCollatorForMultipleChoice(tokenizer),
        compute_metrics=compute_metrics,
    )

    print("Starting training...")
    trainer.train()

    print(f"Saving best model to {OUTPUT_DIR}...")
    trainer.save_model(OUTPUT_DIR)

    # -------------------------------------------------------------------------
    # Inference
    # -------------------------------------------------------------------------
    print(f"\nLoading test data from {TEST_FILE}...")
    test_df = pd.read_csv(TEST_FILE)

    # Answers read from CSV as strings are parsed back into Python objects.
    test_df['answers'] = test_df['answers'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)
    # Placeholder label for every row; the predictions below are taken from
    # the model outputs only.
    test_df['label'] = 0

    test_dataset = Dataset.from_pandas(test_df)

    print("Preprocessing test data...")
    tokenized_test = test_dataset.map(preprocess_function, batched=True)

    print("Running predictions...")
    predictions_output = trainer.predict(tokenized_test)
    preds = np.argmax(predictions_output.predictions, axis=1)

    # Save Submission
    submission_df = pd.DataFrame({
        "id": test_df["id"],
        "label": preds
    })

    submission_file = "submission.csv"
    submission_df.to_csv(submission_file, index=False)
    print(f"Done! Predictions saved to {submission_file}")