# **Installation and Imports**

In [None]:
!pip install -U transformers datasets huggingface_hub bitsandbytes accelerate

In [None]:
from datasets import load_dataset, DatasetDict, Dataset
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
import torch
import pandas as pd
from tqdm.auto import tqdm
import os
from huggingface_hub import HfApi, login

# **Load and Modify the Dataset**

In [None]:
# Load CoS-E v1.11; the returned dict exposes the "validation" split under the key "test".
ds = load_dataset(
    "Salesforce/cos_e",
    "v1.11",
    split={"train": "train", "test": "validation"},
)

def modify_example(example):
    """Convert one CoS-E record into an ``input``/``label`` pair.

    The question and its answer choices are merged into a single string with
    lettered options ``(a)``, ``(b)``, ... on separate lines; the answer text
    becomes the label.

    Args:
        example: record with ``question`` (str), ``choices`` (list of str) and
            ``answer`` (str) fields; any other fields are ignored.

    Returns:
        dict with ``input`` (question + "Answer Choices:" + lettered options)
        and ``label`` (the answer text).
    """
    # Derive each option letter from its position so rows with more than five
    # choices are labelled too (letters are meaningful for up to 26 choices).
    formatted_choices = "\n".join(
        f"({chr(ord('a') + i)}) {choice}" for i, choice in enumerate(example["choices"])
    )

    # Merge 'question' and the formatted choices into 'input'; 'answer' becomes 'label'.
    input_text = f"{example['question']}\nAnswer Choices:\n{formatted_choices}"
    return {"input": input_text, "label": example["answer"]}

# Reformat every split with modify_example and drop all of the original columns,
# keeping only the 'input'/'label' columns it returns. Using data.column_names
# rather than a hand-written list keeps this correct if the schema changes.
dataset = {
    split: data.map(modify_example, remove_columns=data.column_names)
    for split, data in ds.items()
}

In [4]:
# Show the split summary, then the first reformatted training example, each on its own line.
print(dataset, dataset['train'][0], sep="\n")

{'train': Dataset({
    features: ['input', 'label'],
    num_rows: 9741
}), 'test': Dataset({
    features: ['input', 'label'],
    num_rows: 1221
})}
{'input': '"There are 10 apples on an apple tree.  Three fall off.  Now there are X apples."  What is this an example of?\nAnswer Choices:\n(a) park\n(b) coloring book\n(c) garden center\n(d) math problem\n(e) gravity', 'label': 'math problem'}


# **Load Model**

In [None]:
model_name = "Qwen/Qwen2.5-3B-Instruct"

# Left padding: a decoder-only model continues from the last token of each row,
# so batched prompts must be padded on the left rather than the right.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')

# 8-bit weights through bitsandbytes. Passing load_in_8bit straight to
# from_pretrained is deprecated in current transformers releases (the notebook
# installs the latest with -U); a BitsAndBytesConfig is the supported form.
# No trust_remote_code: Qwen2.5 uses the Qwen2 architecture built into
# transformers, so no code from the model repository needs to be executed.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)

# Text-generation pipeline used for every batch. Sampling is enabled explicitly so
# temperature/top_p take effect regardless of the checkpoint's generation config;
# batch_size is how many prompts the pipeline runs through the model at once.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.1,
    top_p=0.95,
    repetition_penalty=1.1,
    batch_size=16,
)

# **Create Prompt Template**
#### Defines a total of 10 custom few-shot examples, as in the original paper

In [None]:
def create_prompt_template():
    """Build the few-shot prompt template used to elicit rationales.

    The template is an instruction header, ten worked examples (question with
    lettered choices, gold label, rationale), and a final section for the new
    question. Every part is written flush-left so the new question uses exactly
    the same "Input:/Label:/Rationale:" layout as the examples.

    The final section contains the literal placeholders ``{input}`` and
    ``{label}``, to be substituted textually. It is not a ``str.format``
    template: the header also contains literal braces such as ``{start}``.

    Returns:
        str: the template, ending with ``"Rationale:"`` so the model's
        continuation is the rationale for the new question.
    """
    header = (
        "I'll show you several examples of multiple-choice questions with their answers and rationales.\n"
        "For each new question, I want you to:\n"
        "1. Think through each answer choice carefully\n"
        "2. Explain why the correct answer makes sense\n"
        "3. Generate a unique and specific rationale for why the given label is correct\n"
        "4. Make sure your rationale relates directly to the specific content of the question\n"
        "5. Exclude any bracketed placeholders, markers, or artifacts such as `{start}`, `{end}`, `{rationale end}`, `{example}`, or similar text.\n"
        "6. Maintain coherence throughout the explanation, avoiding abrupt jumps, missing reasoning, or unclear conclusions.\n"
        "\n"
        "Examples:\n"
    )

    examples = [
        {
            "input": "What home entertainment equipment requires cable?\n"
                     "Answer Choices:\n"
                     "(a) radio shack\n"
                     "(b) substation\n"
                     "(c) television\n"
                     "(d) cabinet\n"
                     "(e) desk",
            "label": "television",
            "rationale": "To determine what home entertainment equipment requires cable, I need to analyze each option. Radio shack is a store, not equipment. A substation is related to electricity distribution, not home entertainment. Television is home entertainment equipment that typically requires cable connections to receive signals from cable providers, antennas, or satellite dishes. Cabinet and desk are furniture items, not electronic equipment. Therefore, among the given options, television is the only home entertainment equipment that requires cable."
        },
        {
            "input": "The fox walked from the city into the forest, what was it looking for?\n"
                     "Answer Choices:\n"
                     "(a) food\n"
                     "(b) shelter\n"
                     "(c) escape\n"
                     "(d) friends\n"
                     "(e) exercise",
            "label": "food",
            "rationale": "To determine what the fox was looking for, I need to consider the natural behavior of foxes and the context of moving from a city to a forest. Foxes are wild animals that primarily hunt for prey in natural environments. While foxes might seek shelter, escape from threats, companions, or exercise, their primary motivation for movement is usually finding food. Forests provide better hunting opportunities than cities. Given that the fox is deliberately moving from an urban area to a more natural environment, it's most likely searching for food, which is the primary driver of animal migration between habitats."
        },
        {
            "input": "What could cause a plant to die?\n"
                     "Answer Choices:\n"
                     "(a) lack of sunlight\n"
                     "(b) over cultivation\n"
                     "(c) lack of space\n"
                     "(d) all of the above\n"
                     "(e) none of the above",
            "label": "all of the above",
            "rationale": "To determine what could cause a plant to die, I need to evaluate each option. Lack of sunlight prevents photosynthesis, which plants need to produce energy and survive. Over cultivation can deplete soil nutrients and damage root systems, causing plant death. Lack of space restricts root growth and can cause plants to become root-bound, eventually leading to death. Since all three factors—lack of sunlight, over cultivation, and lack of space—can independently cause a plant to die, the correct answer is 'all of the above'. None of these options would promote plant health, so 'none of the above' is incorrect."
        },
        {
            "input": "Where does the energy in food originally come from?\n"
                     "Answer Choices:\n"
                     "(a) water\n"
                     "(b) minerals\n"
                     "(c) soil\n"
                     "(d) the Sun\n"
                     "(e) carbon dioxide",
            "label": "the Sun",
            "rationale": "To determine the original source of energy in food, I need to trace the energy flow in ecosystems. Plants convert light energy into chemical energy through photosynthesis, which becomes stored in the plants' tissues. When animals eat plants or other animals, they obtain this stored energy. Water, minerals, and soil provide necessary materials for plant growth but don't supply energy. Carbon dioxide is used in photosynthesis but isn't an energy source. The Sun provides the initial light energy that drives photosynthesis, making it the original source of energy that eventually becomes stored in food. Therefore, the Sun is the correct answer."
        },
        {
            "input": "What is the most essential feature of a non-fiction novel?\n"
                     "Answer Choices:\n"
                     "(a) real events\n"
                     "(b) specific details\n"
                     "(c) personal experiences\n"
                     "(d) research\n"
                     "(e) social context",
            "label": "real events",
            "rationale": "To identify the most essential feature of a non-fiction novel, I need to understand what fundamentally defines non-fiction. Non-fiction, by definition, is writing based on real events, facts, and information, as opposed to fiction, which is imaginary. While specific details enhance non-fiction writing, they're not the defining characteristic. Personal experiences may be included but aren't required in all non-fiction. Research is important for accuracy but is a method rather than a defining feature. Social context adds depth but isn't essential to all non-fiction works. The foundational characteristic that makes something non-fiction is that it's based on real events, making this the most essential feature."
        },
        {
            "input": "What do people use to reduce glare from the sun while driving?\n"
                     "Answer Choices:\n"
                     "(a) visor\n"
                     "(b) gas\n"
                     "(c) horn\n"
                     "(d) radio\n"
                     "(e) mirror",
            "label": "visor",
            "rationale": "To determine what people use to reduce glare while driving, I need to consider each option's function. A visor is a hinged flap in cars specifically designed to block direct sunlight and reduce glare, improving visibility and safety. Gas is a fuel for the vehicle and has no relation to sun glare. The horn is used for signaling to other drivers and doesn't affect sunlight. A radio is for entertainment and communication and doesn't block light. Mirrors reflect light rather than block it, potentially increasing glare. Therefore, among the given options, a visor is specifically designed to reduce sun glare while driving."
        },
        {
            "input": "What is the main purpose of a zipper?\n"
                     "Answer Choices:\n"
                     "(a) join things together\n"
                     "(b) release air\n"
                     "(c) lock items away\n"
                     "(d) create shade\n"
                     "(e) provide insulation",
            "label": "join things together",
            "rationale": "To determine the main purpose of a zipper, I need to analyze its primary function. A zipper consists of two rows of interlocking teeth that can be joined or separated by moving a slider. Its fundamental purpose is to join fabric or other materials together in a way that can be easily opened and closed. Zippers don't primarily release air, though they might allow airflow when opened. While zippers can secure items, 'lock items away' suggests a security function that's secondary to joining. Zippers don't create shade or primarily provide insulation, though closed zippers on insulated clothing help maintain warmth. Therefore, the main purpose of a zipper is to join things together."
        },
        {
            "input": "What do cowboys typically sleep under when camping outdoors?\n"
                     "Answer Choices:\n"
                     "(a) stars\n"
                     "(b) rocks\n"
                     "(c) overhang\n"
                     "(d) tent\n"
                     "(e) trailer",
            "label": "stars",
            "rationale": "To determine what cowboys typically sleep under when camping outdoors, I need to consider historical practices and common imagery associated with cowboys. Traditionally, cowboys on cattle drives would sleep outdoors without elaborate shelter, often using only bedrolls. They wouldn't sleep under rocks, as this doesn't provide shelter and is impractical. An overhang might be used during bad weather but isn't the typical situation. Modern camping equipment like tents or trailers would be anachronistic for traditional cowboys and impractical to carry on horseback during cattle drives. The phrase 'sleeping under the stars' is commonly associated with cowboys camping in the open air, making 'stars' the most appropriate answer for what cowboys typically sleep under."
        },
        {
            "input": "What natural disaster is measured using the Richter scale?\n"
                     "Answer Choices:\n"
                     "(a) tornado\n"
                     "(b) earthquake\n"
                     "(c) hurricane\n"
                     "(d) drought\n"
                     "(e) flood",
            "label": "earthquake",
            "rationale": "To identify which natural disaster is measured using the Richter scale, I need to recall what the Richter scale specifically measures. The Richter scale was developed to quantify the magnitude or energy release of seismic events by measuring the amplitude of the largest seismic wave. Tornadoes are measured using the Enhanced Fujita (EF) scale based on damage patterns. Hurricanes are measured using the Saffir-Simpson scale based on wind speed. Droughts are measured using indices like the Palmer Drought Severity Index. Floods are measured by water levels, flow rates, and extent of inundation. Only earthquakes are measured using the Richter scale (though modern seismologists often use the moment magnitude scale for larger earthquakes). Therefore, earthquake is the correct answer."
        },
        {
            "input": "What is the most important safety feature of a car?\n"
                     "Answer Choices:\n"
                     "(a) seatbelt\n"
                     "(b) airbag\n"
                     "(c) anti-lock brakes\n"
                     "(d) backup camera\n"
                     "(e) blind spot detection",
            "label": "seatbelt",
            "rationale": "To determine the most important safety feature of a car, I need to consider the effectiveness, universality, and historical impact of each option. Seatbelts are considered the most fundamental safety device in vehicles, reducing fatalities by 45-60% according to safety studies. They're the primary restraint system that keeps occupants in position during a crash, enabling other safety features to work effectively. Airbags are supplemental restraint systems designed to work with seatbelts, not replace them. Anti-lock brakes, backup cameras, and blind spot detection enhance safety but primarily prevent accidents rather than protect during collisions. Safety experts and regulatory bodies consistently identify seatbelts as the single most important safety feature, with mandatory seatbelt laws preceding other safety requirements. Therefore, the seatbelt is the most important safety feature."
        }
    ]

    # One block per example, numbered from 1 and separated by a blank line.
    formatted_examples = "".join(
        f"EXAMPLE {i}:\n"
        f"Input: {example['input']}\n"
        f"Label: {example['label']}\n"
        f"Rationale: {example['rationale']}\n\n"
        for i, example in enumerate(examples, 1)
    )

    # Explicit new-task marker, in the same layout as the examples, with an
    # instruction not to add extra text. Plain (non-f) strings: {input} and
    # {label} stay literal placeholders.
    new_input_section = (
        "====================\n"
        "NEW QUESTION REQUIRING RATIONALE:\n"
        "====================\n"
        "\n"
        "Input: {input}\n"
        "Label: {label}\n"
        "\n"
        "Generate a factual and specific rationale that explains why the label is correct.\n"
        "Do not add any text asking if I want another question or example.\n"
        "Do not add anything after your explanation.\n"
        "\n"
        "Rationale:"
    )

    return header + formatted_examples + new_input_section

# **Generate Rationales**

In [None]:
# Substrings that end the rationale for the current question: past them the model
# is echoing prompt scaffolding or inventing a further example.
_RATIONALE_STOP_MARKERS = (
    "====================",
    "NEW QUESTION",
    "Input:",
    "\nEXAMPLE ",
    "Rationale:",
)


def _fill_prompt(template, input_text, label):
    """Return ``template`` with its ``{input}`` and ``{label}`` placeholders filled.

    Splitting on ``{input}`` before replacing ``{label}`` leaves a literal
    ``{label}`` inside ``input_text`` untouched; chained ``str.replace`` calls
    would rewrite it too.
    """
    return input_text.join(piece.replace("{label}", label) for piece in template.split("{input}"))


def _extract_rationale(completion):
    """Return the rationale at the start of a model continuation, or "" if none."""
    text = completion.strip()
    # The prompt already ends with "Rationale:"; tolerate the model repeating it.
    if text.startswith("Rationale:"):
        text = text[len("Rationale:"):]
    # Keep only what precedes the earliest stop marker.
    for marker in _RATIONALE_STOP_MARKERS:
        text = text.split(marker, 1)[0]
    return text.strip()


def generate_rationales(dataset, pipe, batch_size=16, max_retries=3):
    """Generate a rationale for every training example using the few-shot prompt.

    Only the "train" split is sent to the model. Rows of every other split are
    copied with an empty "rationale". Train rows are checkpointed to
    ``interim_results_<n>.csv`` whenever at least 200 more rows (counted across
    all splits) have been processed. The finished train rows are written to
    ``complete_rationale_results.csv``.

    Args:
        dataset: mapping of split name -> indexable split whose rows have
            "input" and "label" fields.
        pipe: Hugging Face text-generation pipeline. It is called with a list of
            prompts and ``return_full_text=False``, and returns one list of
            ``{"generated_text": ...}`` dicts per prompt.
        batch_size: number of prompts passed to ``pipe`` per call.
        max_retries: extra single-prompt attempts for a train row whose first
            continuation contains no rationale.

    Returns:
        dict mapping split name -> list of {"input", "label", "rationale"} dicts.
        The "train" and "test" keys are always present.

    Raises:
        ValueError: if a train row still has no rationale after ``max_retries`` retries.
    """
    prompt_template = create_prompt_template()
    results = {"train": [], "test": []}
    total_processed = 0
    last_saved_count = 0

    for split_name in dataset:
        split_data = dataset[split_name]
        # Splits other than train/test get their own list instead of a KeyError.
        split_results = results.setdefault(split_name, [])
        print(f"Processing {split_name} split with {len(split_data)} examples...")

        # Process in batches
        for i in tqdm(range(0, len(split_data), batch_size)):
            batch = [split_data[idx] for idx in range(i, min(i + batch_size, len(split_data)))]

            if split_name == "train":
                prompts = [_fill_prompt(prompt_template, ex["input"], ex["label"]) for ex in batch]
                # Request only the continuation. With the default (prompt + continuation)
                # every output contains the prompt's own "Rationale:" lines, so a missing
                # rationale could never be detected.
                outputs = pipe(prompts, return_full_text=False)

                for example, prompt, output in zip(batch, prompts, outputs):
                    rationale_text = _extract_rationale(output[0]["generated_text"])

                    # Retry one prompt at a time when nothing usable came back.
                    if not rationale_text:
                        for _ in range(max_retries):
                            # A list input yields a list of lists: [[{"generated_text": ...}]].
                            retry_output = pipe([prompt], return_full_text=False)[0][0]["generated_text"]
                            rationale_text = _extract_rationale(retry_output)
                            if rationale_text:
                                break
                        else:
                            # If all retries fail, raise an error
                            raise ValueError(f"Failed to generate rationale for input: {example['input']}")

                    split_results.append({
                        "input": example["input"],
                        "label": example["label"],
                        "rationale": rationale_text,
                    })
            else:
                # Non-train splits are kept without generated rationales.
                split_results.extend(
                    {"input": ex["input"], "label": ex["label"], "rationale": ""}
                    for ex in batch
                )

            total_processed += len(batch)

            # Checkpoint the train rows once at least 200 more rows have been processed
            if total_processed - last_saved_count >= 200:
                temp_df = pd.DataFrame(results["train"])
                temp_df.to_csv(f"interim_results_{total_processed}.csv", index=False)
                print(f"Saved interim results at {total_processed} rows")
                last_saved_count = total_processed

    # Save all train rows to the final CSV
    train_df = pd.DataFrame(results["train"])
    train_df.to_csv("complete_rationale_results.csv", index=False)
    print(f"Processing complete! Total train examples processed: {len(results['train'])}")
    print(f"📂 Final results saved to complete_rationale_results.csv")

    return results

In [None]:
# Run rationale generation over every split (model calls happen for "train" only).
results = generate_rationales(dataset=dataset, pipe=pipe)

# **Push to Hugging Face**
#### This is to make it easier to load the dataset later when training the T5-small model.

In [None]:
def upload_dataset(results, dataset_name, username, token):
    """Log in to the Hugging Face Hub and push the train/test rationale splits.

    Args:
        results: mapping with "train" and "test" lists of row dicts, such as the
            value returned by generate_rationales.
        dataset_name: repository name; it is prefixed with ``username/``.
        username: Hub account or organisation that owns the repository.
        token: access token passed to ``login``.

    Returns:
        The DatasetDict that was pushed.
    """
    login(token=token)

    repo_id = f"{username}/{dataset_name}"
    hf_dataset = DatasetDict(
        {split: Dataset.from_list(results[split]) for split in ("train", "test")}
    )

    hf_dataset.push_to_hub(repo_id)
    print(f"Dataset successfully pushed to {repo_id}")
    return hf_dataset

# Hub repo ids have the form "<owner>/<name>". upload_dataset prepends the owner,
# so dataset_name must be a bare name without another "/".
# The token is read from the HF_TOKEN environment variable instead of being
# hard-coded; when it is unset, login() prompts for one.
uploaded_dataset = upload_dataset(
    results,
    dataset_name="cos_e-rationale",
    username="your-username",
    token=os.environ.get("HF_TOKEN"),
)