# GRPO AND SFT TRAINING NOTEBOOK

Author: Chengheng Li Chen (Template from Unsloth)

## Installation

In [1]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    !pip install --no-deps unsloth vllm

In [2]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    !pip install --no-deps unsloth vllm
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    # Skip restarting message in Colab
    import sys, re, requests; modules = list(sys.modules.keys())
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft "trl==0.15.2" triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer

    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
    with open("vllm_requirements.txt", "wb") as file:
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    !pip install -r vllm_requirements.txt

In [None]:
!pip install langchain_openai

##  Imports


In [None]:

import pandas as pd
from unsloth import FastLanguageModel
from datasets import Dataset
from trl import SFTTrainer, SFTConfig, GRPOTrainer, GRPOConfig
import re

from typing import Optional, Union
from langchain_core.language_models.base import BaseLanguageModel
from langchain_core.language_models.chat_models import BaseChatModel

from langchain_core.prompts import ChatPromptTemplate
from tqdm import tqdm

import math
import numpy as np

import json
import time


## Loading the model

In [None]:
max_seq_length = 2048 # Can increase for longer reasoning traces
lora_rank = 32 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-4B-Base",
    max_seq_length = max_seq_length,
    load_in_4bit = False, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference (Only linux)
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.7, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = lora_rank*2, # *2 speeds up training
    use_gradient_checkpointing = "unsloth", # Reduces memory usage
    random_state = 3407,
)

## Setting up the template


In [None]:
reasoning_start = "<think>" # Acts as <think>
reasoning_end   = "</think>"   # Acts as </think>
solution_start  = "<O>"
solution_end    = "</O>"

system_prompt = f"""You are a helpful assistant.
You are given a story.
Think about the story and determine if the speaker is doing the thing correctly.
Place your reasoning between {reasoning_start} and {reasoning_end}.
Then, provide your answer "RIGHT" or "WRONG" within {solution_start} and {solution_end} as the final verdict.
"""
system_prompt

'You are a helpful assistant.\nYou are given a story.\nThink about the story and determine if the speaker is doing the thing correctly.\nPlace your reasoning between <think> and </think>.\nThen, provide your answer "RIGHT" or "WRONG" within <O> and </O> as the final verdict.\n'

Trying the template

In [None]:
# Jinja chat template, built from adjacent string literals:
#   * If the first message is a system message, emit its content followed by
#     eos_token; otherwise emit the '{system_prompt}' placeholder followed by
#     eos_token.
#   * User messages are emitted verbatim; assistant messages are followed by
#     eos_token.
#   * With add_generation_prompt, the '{reasoning_start}' placeholder is
#     appended so generation begins inside the reasoning section.
chat_template = \
    "{% if messages[0]['role'] == 'system' %}"\
        "{{ messages[0]['content'] + eos_token }}"\
        "{% set loop_messages = messages[1:] %}"\
    "{% else %}"\
        "{{ '{system_prompt}' + eos_token }}"\
        "{% set loop_messages = messages %}"\
    "{% endif %}"\
    "{% for message in loop_messages %}"\
        "{% if message['role'] == 'user' %}"\
            "{{ message['content'] }}"\
        "{% elif message['role'] == 'assistant' %}"\
            "{{ message['content'] + eos_token }}"\
        "{% endif %}"\
    "{% endfor %}"\
    "{% if add_generation_prompt %}{{ '{reasoning_start}' }}"\
    "{% endif %}"

# Replace the placeholders with our specific values. The values are pasted
# inside single-quoted Jinja string literals, so they must not themselves
# contain a single quote.
chat_template = chat_template\
    .replace("'{system_prompt}'",   f"'{system_prompt}'")\
    .replace("'{reasoning_start}'", f"'{reasoning_start}'")
tokenizer.chat_template = chat_template

Let's see how our chat template behaves on an example:

In [None]:
tokenizer.apply_chat_template([
    {"role" : "user", "content" : "What is 1+1?"},
    {"role" : "assistant", "content" : f"{reasoning_start}I think it's 2.{reasoning_end}{solution_start}2{solution_end}"},
    {"role" : "user", "content" : "What is 2+2?"},
], tokenize = False, add_generation_prompt = True)

'You are a helpful assistant.\nYou are given a story.\nThink about the story and determine if the speaker is doing the thing correctly.\nPlace your reasoning between <think> and </think>.\nThen, provide your answer "RIGHT" or "WRONG" within <O> and </O> as the final verdict.\n<|endoftext|>What is 1+1?<think>I think it\'s 2.</think><O>2</O><|endoftext|>What is 2+2?<think>'

## Loading the data


In [None]:
def filter_by_title(df, title_query, case_sensitive=False):
    """
    Return the rows of ``df`` whose ``title`` contains ``title_query``.

    Parameters:
    -----------
    df : pandas.DataFrame
        Frame to search; it must have a ``title`` column.
    title_query : str
        Text to look for inside each title.
    case_sensitive : bool, default=False
        When true, letter case must match; otherwise case is ignored.

    Returns:
    --------
    pandas.DataFrame
        The subset of ``df`` whose title matched. Rows with a missing
        title never match (``na=False``).
    """
    # Only the case-insensitive search passes an explicit ``case`` argument.
    case_option = {} if case_sensitive else {"case": False}
    title_hits = df['title'].str.contains(title_query, na=False, **case_option)
    return df[title_hits]

def get_exact_title_match(df, title):
    """
    Look up the single row whose ``title`` equals ``title`` exactly.

    Parameters:
    -----------
    df : pandas.DataFrame
        Frame to search; it must have a ``title`` column.
    title : str
        The complete title to match.

    Returns:
    --------
    pandas.Series or None
        The matching row when exactly one row matches. ``None`` when there
        is no match or more than one; a message is printed in both cases.
    """
    hits = df[df['title'] == title]
    hit_count = len(hits)

    if hit_count == 1:
        return hits.iloc[0]

    # Zero or several hits: report which case occurred and give up.
    if hit_count > 1:
        print(f"Warning: Multiple matches found for title '{title}'")
    else:
        print(f"No exact match found for title '{title}'")
    return None

In [None]:
scores_df = pd.read_csv("reasoning_scores.csv")
reasoning_df = pd.read_csv("resoning_dataset.csv")

In [None]:
scores_df.head()

Unnamed: 0,title,NEUTRALITY,EVIDENCE_INTEGRATION,HYPOTHESIS_TESTING,LOGICAL_COHERENCE,NUANCE_RECOGNITION,DECISION_ALIGNMENT,TOTAL_SCORE,COGNITIVE_QUALITY_RATING,SUMMARY_DIAGNOSIS,target_decision
0,AITA for not wanting to have any contact with ...,8,7,8,7,8,9,47,FUNCTIONAL_REASONING,The reasoning demonstrates functional quality ...,RIGHT
1,AITA for telling my three best friends that th...,8,7,8,7,9,8,47,FUNCTIONAL_REASONING,The reasoning demonstrates functional quality ...,RIGHT
2,AITA for thinking showing pubes in public is a...,8,7,8,7,9,8,47,FUNCTIONAL_REASONING,The reasoning demonstrates functional analytic...,RIGHT
3,AITA for refusing to accept mental illness as ...,7,8,7,8,8,9,47,FUNCTIONAL_REASONING,The reasoning demonstrates functional analytic...,RIGHT
4,WIBTA if I told my husband we REALLY need to c...,8,7,8,7,9,8,47,FUNCTIONAL_REASONING,The reasoning demonstrates functional quality ...,RIGHT


In [None]:
titles = scores_df["title"]

In [None]:
# Parallel lists, one entry per title that resolves to exactly one row of
# reasoning_df.
stories, reasonings, target, titles_complete = [], [], [], []

for title in titles:
    row = get_exact_title_match(reasoning_df, title)
    # Titles with no match or several matches are skipped.
    if row is not None:
        text = row["text"]
        output = row["output"]
        # The three reasoning stages are joined with no separator, followed
        # by a closing sentence that states the verdict.
        reasoning = (
            row["initial_analysis"]
            + row["dialectical_challenge"]
            + row["integration"]
            + f"\n Therefore the author is doing the { output } thing."
        )

        stories.append(text)
        reasonings.append(reasoning)
        target.append(output)
        titles_complete.append(title)

len(stories), len(reasonings), len(target), len(titles_complete)



(2477, 2477, 2477, 2477)

We have to format the dataset to follow our GRPO style formatting:

In [None]:
dataset = pd.DataFrame({
    'expected_answer': target,
    'problem': stories,
    'generated_solution': reasonings,
    'title': titles_complete
})


In [None]:
def format_dataset(x):
    """
    Turn one dataset row into a system/user/assistant conversation.

    ``x`` is indexed by the keys ``expected_answer``, ``problem`` and
    ``generated_solution``. The assistant message wraps the reasoning in the
    module-level ``reasoning_start``/``reasoning_end`` tags and the expected
    answer in ``solution_start``/``solution_end``.

    Returns a list of three ``{"role", "content"}`` dicts.
    """
    expected_answer = x["expected_answer"]
    problem = x["problem"]

    # Drop any <think>/</think> tags already present in the reasoning, then
    # trim surrounding whitespace.
    thoughts = x["generated_solution"].replace("<think>", "").replace("</think>", "").strip()

    # Re-wrap the reasoning and the verdict in our own tags.
    think_block = reasoning_start + thoughts + reasoning_end
    verdict_block = solution_start + expected_answer + solution_end

    conversation = []
    conversation.append({"role" : "system",    "content" : system_prompt})
    conversation.append({"role" : "user",      "content" : problem})
    conversation.append({"role" : "assistant", "content" : think_block + verdict_block})
    return conversation



## SFT Training

### Dataset Setup

In [None]:
# Independent copy of the first 100 rows. Copying the slice (rather than
# slicing a copy) means the column assignment below writes to a frame we own
# and does not trigger pandas' chained-assignment warning.
sft_dataset = dataset.iloc[:100].copy()

# Build the chat messages only for the rows we keep, instead of formatting
# the whole dataset and discarding everything past the first 100 rows.
sft_dataset["Messages"] = sft_dataset.apply(format_dataset, axis = 1)

Check to see if it worked:

In [None]:
tokenizer.apply_chat_template(sft_dataset["Messages"][0], tokenize = False)

'You are a helpful assistant.\nYou are given a story.\nThink about the story and determine if the speaker is doing the thing correctly.\nPlace your reasoning between <think> and </think>.\nThen, provide your answer "RIGHT" or "WRONG" within <O> and </O> as the final verdict.\n<|endoftext|>So my father used to mentally and physically abuse me for a very long time, from like my 10th till I moved out (19). And now he claims that he wants to have contact again but I don\'t want it anymore. He didn\'t even tell me he is sorry for the things he did. Now everyone says I\'m the asshole because he wants to \'make things right\' but I don\'t give him the chance., am I the asshole here?\n\n&amp;#x200B;\n\np.s. excuse me for my English <think>The situation involves a complex history of abuse and the current desire of the father to reconnect. \n\n1. **Hypothesis 1: The father genuinely wants to make amends.**\n   - Evidence for: He has expressed a desire to reconnect, which could indicate a change 

Let's keep only the examples in the pre fine-tuning dataset whose length is at most `max_seq_length/2`, since we don't want reasoning traces that are too long. Longer examples are dropped, not cut.

In [None]:
# N = length of what apply_chat_template returns for each conversation.
# tokenize=False is not passed here, so this is presumably the number of
# token ids rather than characters -- TODO confirm against the tokenizer docs.
sft_dataset["N"] = sft_dataset["Messages"].apply(lambda x: len(tokenizer.apply_chat_template(x)))

# Keep only rows whose length fits in half of max_seq_length; longer rows are
# dropped entirely (nothing is truncated). .copy() detaches the result from
# the frame it was filtered from.
sft_dataset = sft_dataset.loc[sft_dataset["N"] <= max_seq_length/2].copy()
sft_dataset.shape

(45, 6)

We then render the messages to plain text with the chat template (`tokenize = False`) and convert the result to a Hugging Face compatible dataset format:

In [None]:
sft_dataset["text"] = tokenizer.apply_chat_template(sft_dataset["Messages"].values.tolist(), tokenize = False)
sft_dataset = Dataset.from_pandas(sft_dataset)
sft_dataset

Dataset({
    features: ['expected_answer', 'problem', 'generated_solution', 'title', 'Messages', 'N', 'text', '__index_level_0__'],
    num_rows: 45
})

In [None]:
sft_dataset[0]

{'expected_answer': 'RIGHT',
 'problem': "So my father used to mentally and physically abuse me for a very long time, from like my 10th till I moved out (19). And now he claims that he wants to have contact again but I don't want it anymore. He didn't even tell me he is sorry for the things he did. Now everyone says I'm the asshole because he wants to 'make things right' but I don't give him the chance., am I the asshole here?\n\n&amp;#x200B;\n\np.s. excuse me for my English ",
 'generated_solution': 'The situation involves a complex history of abuse and the current desire of the father to reconnect. \n\n1. **Hypothesis 1: The father genuinely wants to make amends.**\n   - Evidence for: He has expressed a desire to reconnect, which could indicate a change of heart.\n   - Evidence against: He has not apologized or acknowledged the harm he caused, which raises doubts about his sincerity.\n\n2. **Hypothesis 2: The father is seeking to alleviate his own guilt.**\n   - Evidence for: His lac

### Trainer Setup

In [None]:
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = sft_dataset,
    args = SFTConfig(
        dataset_text_field = "text",
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 1, # Use GA to mimic batch size!
        warmup_steps = 5,
        num_train_epochs = 2, # Set this for 1 full training run.
        learning_rate = 2e-4, # Reduce to 2e-5 for long training runs
        logging_steps = 5,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        report_to = "none", # Use this for WandB etc
    ),
)

### Train the model

In [17]:
trainer.train()

## GRPO Training
### Reward definition

Defining Regex rules to detect the patterns

In [19]:
def extract_model_answer(content):
    """
    Return the text found between the first two ``<|endoftext|>`` markers.

    Args:
        content (str): The text to search.

    Returns:
        str: The text between the first and second marker. When a marker is
        missing, or an exception is raised while searching, a human-readable
        message describing the problem is returned instead.
    """
    marker = '<|endoftext|>'
    try:
        first_at = content.find(marker)
        if first_at == -1:
            return "First <|endoftext|> tag not found."

        # The answer begins right after the first marker and runs up to the
        # next one.
        answer_from = first_at + len(marker)
        second_at = content.find(marker, answer_from)
        if second_at == -1:
            return "Second <|endoftext|> tag not found."

        return content[answer_from:second_at]

    except Exception as e:
        return "Error occurred: " + str(e)

In [None]:
# Closing solution tag, then optional whitespace, then an optional EOS token
# (escaped so its characters are matched literally).
solution_end_regex = r"</O>[\s]{0,}" + \
    "(?:" + re.escape(tokenizer.eos_token) + ")?"

# Full-format pattern: reasoning_end, then lazily anything, then
# solution_start, the captured answer (group 1), the closing tag / EOS, and
# only whitespace up to an end of line. DOTALL lets "." span newlines;
# MULTILINE lets "$" match at the end of any line. The tag values are
# interpolated unescaped, so they are treated as regex text.
match_format = re.compile(
    rf"{reasoning_end}.*?"\
    rf"{solution_start}(.+?){solution_end_regex}"\
    rf"[\s]{{0,}}$",
    flags = re.MULTILINE | re.DOTALL
)
# Bare expression: it is not the last line of the cell, so nothing is shown.
match_format

# Group 1 = text between the first two <|endoftext|> markers.
endoftext_pattern = re.compile(r'<\|endoftext\|>(.*?)<\|endoftext\|>', re.DOTALL)
# Group 1 = everything before the first <O> tag.
think_pattern = re.compile(r'(.*?)<O>', re.DOTALL)
# Group 1 = text inside the first <O>...</O> pair.
output_pattern = re.compile(r'<O>(.*?)</O>', re.DOTALL)

Defining the reward function

In [None]:
def match_format_exactly(completions, **kwargs):
    """
    Reward completions that contain the full expected output format.

    Each completion is a list whose first element is a dict with a
    ``"content"`` string. A completion scores 3.0 when the module-level
    ``match_format`` pattern is found in that content, and 0 otherwise.

    Returns a list with one score per completion.
    """
    return [
        3.0 if match_format.search(completion[0]["content"]) is not None else 0
        for completion in completions
    ]

In [None]:
# Call counter and period used to throttle the debug printout in
# check_verdict. (A `global` statement at module level is a no-op, so plain
# assignments are enough.)
PRINTED_TIMES = 0
PRINT_EVERY_STEPS = 5

def check_verdict(prompts, completions, answer, **kwargs):
    """
    Reward completions whose <O>...</O> verdict contains the expected answer.

    Parameters:
    -----------
    prompts : list
        Prompts, each a list of messages; only the content of the last
        message of the first prompt is read (for the debug printout).
    completions : list
        Completions, each a list whose first element has a "content" string.
    answer : list
        Expected verdicts, paired with the completions in order.

    Returns:
    --------
    list: one score per completion
        -2.5  no <O>...</O> section was found
         3.5  the expected answer occurs (case-insensitively) in the verdict
        -1.5  it does not
         0    the expected answer could not be compared (it is not a string)
    """
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]

    extracted_responses = [
        guess.group(1)
        if (guess := output_pattern.search(r)) is not None else None
        for r in responses
    ]

    # Print only every few steps
    global PRINTED_TIMES
    if PRINTED_TIMES % PRINT_EVERY_STEPS == 0:
        print(
            '*'*20 + f"Question:\n{question}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}"
        )
    PRINTED_TIMES += 1

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(-2.5)
            continue
        # NOTE(review): this is a substring test, so a verdict that lists
        # both labels matches either expected answer. Tighten it if that
        # turns out to be exploited during training.
        try:
            scores.append(3.5 if true_answer.upper() in guess.upper() else -1.5)
        except (AttributeError, TypeError):
            # The expected answer is not a string (e.g. None or NaN). Only
            # these errors are swallowed, so KeyboardInterrupt and genuine
            # bugs still surface.
            scores.append(0)
    return scores

If it fails, we want to reward the model if it at least follows the format partially, by counting each symbol:

In [None]:



def match_format_approximately(completions, **kwargs):
    """
    Give partial credit for each format tag that appears exactly once.

    For every completion (a list whose first element has a ``"content"``
    string), each tag adds its reward when it occurs exactly once in the
    content and its penalty otherwise. The opening reasoning tag is not
    scored, presumably because the chat template already appends it as the
    generation prompt.

    Returns a list with one score per completion.
    """
    # (tag, reward when seen exactly once, penalty otherwise)
    tag_rules = (
        (reasoning_end,  0.5, -1.0),
        (solution_start, 1,   -1.5),
        (solution_end,   0.5, -1.0),
    )

    scores = []
    for completion in completions:
        response = completion[0]["content"]
        score = 0
        for tag, reward, penalty in tag_rules:
            score += reward if response.count(tag) == 1 else penalty
        scores.append(score)
    return scores

Finally, we want to reward the quality of the reasoning itself. An LLM judge scores each reasoning trace against a six-criterion rubric, and those scores are then converted into a single reward:

In [None]:
def get_system_prompt():
    """
    Return the system prompt given to the LLM judge.

    The prompt holds the six scoring criteria, the score-to-rating bands, the
    expected JSON output layout and the evaluation rules. The braces of the
    JSON layout are written four at a time, presumably so that they survive
    template formatting applied downstream -- TODO confirm with the caller.
    """
    return """
    You are a cognitive psychologist specializing in decision-making and reasoning assessment, you will evaluate the quality of analytical reasoning in "Am I The Asshole" case analyses. Your evaluation aims to identify cognitive biases, logical structures, and psychological underpinnings of the reasoning process.

    IMPORTANT: IF THE REASONING IS INCOHERENCE AND REALLY NOISY GIVE 0 IN ALL CRITERIA.

    EVALUATION CRITERIA:

    Each criterion is scored on a scale of 0-10 based on psychological principles of sound reasoning.
    1. NEUTRALITY (0-10): Assessment of cognitive bias and perspective-taking ability
    - 0-2: Severe fundamental attribution error; attributes negative motives to one party while excusing another's similar behavior
    - 3-4: Demonstrates significant confirmation bias, seeking evidence that supports preconceived judgment
    - 5-6: Shows moderate actor-observer bias but attempts perspective-taking
    - 7-8: Displays good psychological distancing with minimal in-group favoritism
    - 9-10: Exhibits exceptional cognitive empathy and perspective-taking abilities across all parties
    2. EVIDENCE_INTEGRATION (0-10): Cognitive processing of available information
    - 0-2: Severe availability bias; focuses only on salient or recent information
    - 3-4: Shows anchoring bias; overly influenced by first or strongest impressions
    - 5-6: Demonstrates selective attention but includes major evidence points
    - 7-8: Displays good information processing with minimal cherry-picking
    - 9-10: Exhibits comprehensive information synthesis, including subtle contextual cues and metacognitive awareness of evidence quality
    3. HYPOTHESIS_TESTING (0-10): Cognitive flexibility and consideration of alternatives
    - 0-2: Strong belief perseverance; fails to consider alternative explanations
    - 3-4: Minimal cognitive flexibility; briefly acknowledges alternatives without exploration
    - 5-6: Moderate hypothesis generation with uneven evaluation (satisficing)
    - 7-8: Good falsification testing; actively seeks disconfirming evidence
    - 9-10: Exceptional counter-factual thinking with systematic evaluation of multiple causal models
    4. LOGICAL_COHERENCE (0-10): Structure and validity of cognitive reasoning chains
    - 0-2: Demonstrates illusory correlation or post hoc fallacies
    - 3-4: Shows conjunction fallacy or base rate neglect in probabilistic reasoning
    - 5-6: Generally coherent with occasional modus ponens errors or hasty generalizations
    - 7-8: Strong syllogistic reasoning with minor imperfections
    - 9-10: Exemplary deductive and inductive reasoning without cognitive fallacies
    5. NUANCE_RECOGNITION (0-10): Tolerance for ambiguity and cognitive complexity
    - 0-2: Black-and-white thinking characteristic of cognitive rigidity
    - 3-4: Low tolerance for ambiguity; forces complex situations into simple frameworks
    - 5-6: Moderate recognition of complexity but still seeks premature cognitive closure
    - 7-8: Good integration of emotional and contextual factors in judgment
    - 9-10: High integrative complexity; recognizes dialectical tensions and multiple valid perspectives
    6. DECISION_ALIGNMENT (0-10): Metacognitive consistency and rational judgment process
    - 0-2: Severe cognitive dissonance between analysis and conclusion
    - 3-4: Motivated reasoning evident in selectively weighing evidence
    - 5-6: General alignment with some unjustified inferential leaps
    - 7-8: Good metacognitive awareness with most key factors appropriately weighted
    - 9-10: Exceptional rational decision process with explicit weighting of competing values


    SCORING CALCULATION:
    - Individual scores: Each criterion receives a score from 0-10
    - TOTAL_SCORE: Sum of all individual criterion scores (0-60)
    - COGNITIVE_QUALITY_RATING:
    * 0-20: SEVERELY_IMPAIRED_REASONING
    * 21-35: MODERATELY_IMPAIRED_REASONING
    * 36-50: FUNCTIONAL_REASONING
    * 51-60: OPTIMAL_REASONING
    OUTPUT FORMAT:
    {{{{
    "NEUTRALITY": <score>,
    "EVIDENCE_INTEGRATION": <score>,
    "HYPOTHESIS_TESTING": <score>,
    "LOGICAL_COHERENCE": <score>,
    "NUANCE_RECOGNITION": <score>,
    "DECISION_ALIGNMENT": <score>,
    "TOTAL_SCORE": <sum>,
    "COGNITIVE_QUALITY_RATING": "<rating>",
    "PSYCHOLOGICAL_ASSESSMENT": {{{{
        "NEUTRALITY": {{{{
        "cognitive_strengths": <"specific strength with example">,
        "cognitive_biases": <"specific bias with example">,
        "improvement_suggestions": <"specific suggestion">
        }}}},
        <repeat for all criteria>
    }}}},
    "SUMMARY_DIAGNOSIS": "Brief psychological assessment of overall reasoning pattern and primary cognitive strengths/limitations"
    }}}}

    EVALUATION PROCESS:
    1. Conduct thorough cognitive assessment of the reasoning process
    2. Identify specific psychological mechanisms, biases, and heuristics evident in the reasoning
    3. Evaluate reasoning against normative standards of rational decision-making
    4. Provide constructive feedback framed in terms of cognitive psychology principles
    5. Maintain professional, objective stance without pathologizing reasonable variations in thinking styles
    6. Do NOT put scores into brackets.
    7. Take into account the length of the reasoning trace.
    8. IMPORTANT: Check if the reasoning reach an conclusion and is aligned with the conclusion, if not penalize the DECISION_ALIGNMENT.
    9. IMPORTANT: Check if the reasoning incorporate different perspectives, if not penalize the NEUTRALITY.
    10. IMPORTANT: Check if the reasoning follows a coherent flow, if not penalize the LOGICAL_COHERENCE.
    11. IMPORTANT: Check if the reasoning is nuanced, if not penalize the NUANCE_RECOGNITION.
    12. IMPORTANT: Penalize all criteria if the reasoning if it is incoherent, short or really noisy.
    13. IMPORTANT: Check if the reasoning is aligned with the story, if not penalize all criteria.


    As a input you will be given the following:
    - The original AITA case text story, which will be given after "Original AITA case:"
    - The reasoning analysis to be evaluated, which will be given after "Reasoning to evaluate:"

    IMPORTANT: FOCUS TO EVALUATE TE REASONING NOT THE STORY.
    IMPORTANT: IF THE REASONING IS INCOHERENCE AND REALLY NOISY GIVE 0 IN ALL CRITERIA.
    IMPORTANT: CHECK THE REASONING DETAILY ACCORDING TO THE PROVIDED STORY, IF THEY ARE NOT RELATED GIVE 0 IN ALL CRITERIA.
    IMPORTANT: IF THE REASONING DOES NOT REACH A CLEAR CONCLUSION OF IF THE SPEACKER IS WRONG OR RIGHT GIVE PENALIZE DECISION_ALIGNMENT.
    """

def generate_prompt(story, reasoning):
    """
    Build the user prompt asking the judge to evaluate one reasoning trace.

    Parameters:
    -----------
    story : str
        The original AITA case text

    reasoning : str
        The reasoning analysis to be evaluated

    Returns:
    --------
    str: The prompt text, with the story and the reasoning embedded between
    "---" separators under the "Original AITA case:" and
    "Reasoning to evaluate:" headings.
    """

    # The rating field is named COGNITIVE_QUALITY_RATING, the same key the
    # system prompt's output format uses and calculate_optimized_reward reads.
    prompt = f"""
    Evaluate the reasoning quality in the following AITA case analysis.

    Original AITA case:
    ---
    {story}
    ---

    Reasoning to evaluate:
    ---
    {reasoning}
    ---

    Return your evaluation in JSON format with scores for NEUTRALITY, EVIDENCE_INTEGRATION,
    HYPOTHESIS_TESTING, LOGICAL_COHERENCE, NUANCE_RECOGNITION, DECISION_ALIGNMENT,
    plus TOTAL_SCORE and COGNITIVE_QUALITY_RATING.

    IMPORTANT: IF THE REASONING IS INCOHERENCE, SHORT OR REALLY NOISY GIVE 0 IN ALL CRITERIA.

    Include brief explanations for each score that identify specific strengths and weaknesses
    in the reasoning process.
    """

    return prompt

In [None]:
def calculate_optimized_reward(assessment_output):
    """
    Calculate an optimized reward value based on cognitive reasoning assessment output.
    Uses a balanced hybrid approach with priority boosts for important criteria.

    Args:
        assessment_output (dict): The judge's assessment. The six criterion
            keys are read (a missing key counts as 0), plus the optional
            "COGNITIVE_QUALITY_RATING".

    Returns:
        dict: A breakdown of the reward whose "final_reward" entry is the
        value to use (at most 1.0, higher is better). When every criterion
        is 0 the result is just {"final_reward": 0}.

    Raises:
        ValueError: If a criterion value is not a number and cannot be
            converted to one.
    """

    score_keys = ["NEUTRALITY", "EVIDENCE_INTEGRATION", "HYPOTHESIS_TESTING",
                 "LOGICAL_COHERENCE", "NUANCE_RECOGNITION", "DECISION_ALIGNMENT"]

    # Extract the individual scores and validate that they are numbers
    scores = {}
    for key in score_keys:
        value = assessment_output.get(key, 0)
        if not isinstance(value, (int, float)):
            try:
                value = float(value)
            except (ValueError, TypeError):
                # Fall back to the first element, e.g. a score wrapped in a
                # list such as [8]. An empty list or a dict is reported as
                # the same ValueError instead of leaking IndexError/KeyError.
                try:
                    value = float(value[0])
                except (ValueError, TypeError, IndexError, KeyError):
                    raise ValueError(f"Score '{key}' with value '{value}' is not a number and cannot be converted to one.")
        # The rubric is 0-10; clamp so an out-of-range judge score cannot
        # push any component outside its intended range.
        scores[key] = max(0, min(10, value))

    # Importance weight of each criterion (1 is the baseline)
    importance_weights = {
        "NEUTRALITY": 1.3,           # 30% above baseline
        "EVIDENCE_INTEGRATION": 1.3, # 30% above baseline
        "HYPOTHESIS_TESTING": 0.9,   # 10% below baseline
        "LOGICAL_COHERENCE": 1.1,    # 10% above baseline
        "NUANCE_RECOGNITION": 0.9,   # 10% below baseline
        "DECISION_ALIGNMENT": 1      # baseline
    }

    weighted_scores = {key: scores[key] * importance_weights[key] for key in score_keys}

    raw_total = sum(scores.values())                # unweighted, 0-60
    weighted_total = sum(weighted_scores.values())

    if weighted_total == 0: return { "final_reward": 0}

    # COMPONENT 1 (40%): Sigmoid Reward with importance-weighted total
    # Non-linear reward using sigmoid function to emphasize medium-quality improvements
    # Adjust the midpoint (30) proportionally to account for importance weights
    midpoint = 30 * (sum(importance_weights.values()) / 6)
    sigmoid_reward = 1.0 / (1.0 + math.exp(-0.15 * (weighted_total - midpoint)))

    # COMPONENT 2 (20%): Threshold Reward
    # Discrete reward levels based on cognitive quality rating
    thresholds = {
        "SEVERELY_IMPAIRED_REASONING": 0.1,  # (0-20 points)
        "MODERATELY_IMPAIRED_REASONING": 0.4,  # (21-35 points)
        "FUNCTIONAL_REASONING": 0.7,  # (36-50 points)
        "OPTIMAL_REASONING": 1.0,  # (51-60 points)
    }

    cognitive_quality = assessment_output.get("COGNITIVE_QUALITY_RATING", "")
    if isinstance(cognitive_quality, str) and cognitive_quality in thresholds:
        threshold_reward = thresholds[cognitive_quality]
    else:
        try:
            # The rating may have been given as a number (or numeric string)
            threshold_reward = float(cognitive_quality) / 60.0
        except (ValueError, TypeError):
            # Missing or unrecognised rating: derive it from the raw total
            threshold_reward = raw_total / 60.0

    # COMPONENT 3 (20%): Minimum Criterion Reward
    # Rewards based on the weakest aspect of reasoning (unweighted scores)
    min_criterion_score = min(scores.values())
    min_criterion_reward = min_criterion_score / 10.0  # Normalize to 0-1

    # COMPONENT 4 (20%): Variance Reward
    # Rewards balanced reasoning across all criteria (unweighted scores)
    variance = np.var(list(scores.values()))
    # np.var is the population variance; for values in [0, 10] it is largest
    # when half the scores are 0 and half are 10, which gives (10 - 0)^2 / 4.
    max_possible_variance = 25.0
    variance_reward = 1.0 - (variance / max_possible_variance)

    # Calculate optimized hybrid reward with component weights
    optimized_reward = (
        0.4 * sigmoid_reward +       # 40% weight - primary non-linear signal (with importance weighting)
        0.2 * threshold_reward +     # 20% weight - quality tier recognition
        0.2 * min_criterion_reward + # 20% weight - weakest aspect emphasis
        0.2 * variance_reward        # 20% weight - balanced reasoning emphasis
    )

    # Add boost for high scores in priority criteria
    priority_criteria = ["NEUTRALITY", "HYPOTHESIS_TESTING", "DECISION_ALIGNMENT"]
    boost_threshold = 8  # Apply boost for scores of 8 or higher in priority criteria
    boost_value = 0.0

    for criterion in priority_criteria:
        if scores[criterion] >= boost_threshold:
            boost_value += 0.02  # 2% boost per high-scoring priority criterion

    # Apply boost to final reward (capped at 1.0)
    final_reward = min(1.0, optimized_reward + boost_value)

    # Return detailed breakdown for analysis
    return {
        "component_rewards": {
            "sigmoid_reward": sigmoid_reward,
            "threshold_reward": threshold_reward,
            "min_criterion_reward": min_criterion_reward,
            "variance_reward": variance_reward
        },
        "component_weights": {
            "sigmoid_weight": 0.4,
            "threshold_weight": 0.2,
            "min_criterion_weight": 0.2,
            "variance_weight": 0.2
        },
        "importance_weights": importance_weights,
        "priority_boost": boost_value,
        "raw_total": raw_total,
        "weighted_total": weighted_total,
        "optimized_reward": optimized_reward,
        "final_reward": final_reward
    }

In [None]:
def parse_json_response(response_text):
    """
    Attempts to parse a JSON response, handling common issues.

    A surrounding markdown code fence (with or without the ``json`` tag) is
    stripped first. The function then tries to decode a JSON object starting
    at each ``{`` in turn and returns the first one that decodes, so prose,
    stray braces or a second object after the first one do not break parsing.
    If no object can be decoded, the whole cleaned text is tried as a last
    resort, which covers non-object JSON such as a bare list.

    Parameters:
    -----------
    response_text : str
        The text response from an LLM

    Returns:
    --------
    dict or None: The parsed JSON (normally a dict) or None if parsing failed,
        including when response_text is not a string
    """
    if not isinstance(response_text, str):
        return None

    # Strip any markdown code block syntax
    cleaned_text = re.sub(r'^\s*```(?:json)?\s*', '', response_text)
    cleaned_text = re.sub(r'\s*```\s*$', '', cleaned_text)

    # raw_decode stops at the end of the first complete JSON value, so text
    # after the object (even text containing braces) is ignored.
    decoder = json.JSONDecoder()
    start = cleaned_text.find('{')
    while start != -1:
        try:
            parsed, _end = decoder.raw_decode(cleaned_text, start)
            return parsed
        except json.JSONDecodeError:
            # Not a valid object at this brace; try the next one
            start = cleaned_text.find('{', start + 1)

    # If no JSON object could be decoded, try the whole text
    try:
        return json.loads(cleaned_text)
    except json.JSONDecodeError:
        return None

def check_reasoning(prompts, completions, **kwargs):
    """
    Evaluates the reasoning quality in model completions.

    For each completion the reasoning text is extracted with `match_format`
    (falling back to `think_pattern`), sent to the teacher model `gpt_chat`
    together with the question from that completion's own prompt, and the
    parsed assessment is turned into a reward by `calculate_optimized_reward`.

    Parameters:
    -----------
    prompts : list
        List of prompts where each prompt is a list of messages
    completions : list
        List of completions where each completion is a list of messages

    Returns:
    --------
    list: A list of reward scores, one per completion:
        0 when no reasoning could be extracted, -1 when the reasoning is
        shorter than 10 characters, `final_reward * 10` when the evaluation
        succeeds, and 0.0 when all evaluation attempts fail.
    """
    if not completions:
        return []

    # Get the system prompt for evaluation (identical for every completion)
    system_prompt = get_system_prompt()

    max_attempts = 3
    json_reminder = "\n\nPlease ensure your response is in valid JSON format with numeric values only for all metrics."

    # Pair each completion with its own prompt; a batch can hold several
    # different questions. If the lengths do not line up, fall back to the
    # first prompt for every completion.
    prompts_are_aligned = len(prompts) == len(completions)

    all_rewards = []

    for index, completion in enumerate(completions):
        response_text = completion[0]["content"]

        # Prefer the strict format; fall back to the looser think pattern
        found = match_format.search(response_text)
        if found is None:
            found = think_pattern.search(response_text)
        reasoning = found.group(1) if found is not None else None

        if reasoning is None:
            all_rewards.append(0.0)
            continue

        if len(reasoning) < 10:
            all_rewards.append(-1.0)
            continue

        # Extract the question (story) from the prompt of this completion
        prompt_messages = prompts[index] if prompts_are_aligned else prompts[0]
        question = prompt_messages[-1]["content"]

        # Generate the evaluation prompt
        base_prompt = generate_prompt(question, reasoning)
        user_prompt = base_prompt

        # Try parsing and calculating rewards for up to max_attempts attempts
        reward = None

        for attempt in range(max_attempts):
            is_last_attempt = attempt == max_attempts - 1
            try:
                # Run the evaluation
                response = gpt_chat.run(
                    prompt=user_prompt,
                    system_prompt=system_prompt
                )

                # Parse the response
                assessment = parse_json_response(response)

                # If successfully parsed, try to calculate reward
                if assessment is not None:
                    try:
                        reward_results = calculate_optimized_reward(assessment)
                        reward = reward_results["final_reward"] * 10
                        break
                    except ValueError as e:
                        print(f"Reward calculation failed on attempt {attempt+1}: {str(e)}")
                else:
                    print(f"Failed to parse JSON on attempt {attempt+1}")
                    # Emphasize JSON formatting on the retry; rebuilt from the
                    # base prompt so the reminder is never stacked twice.
                    user_prompt = base_prompt + json_reminder

            except Exception as e:
                print(f"Attempt {attempt+1} failed with error: {str(e)}")
                # Short delay before retrying to avoid rate limits; there is
                # nothing to wait for after the final attempt.
                if not is_last_attempt:
                    time.sleep(1)

        # If we exited the loop without calculating a reward, add zero
        if reward is None:
            print("Failed to calculate reward after all attempts. Returning zero reward.")
            reward = 0.0

        all_rewards.append(reward)

    return all_rewards



### Defining the teacher model

In [None]:
class Chat:
    """
    Small wrapper that picks a LangChain model from a model name and runs
    single-turn prompts against it.

    Parameters:
    -----------
    model_name : str
        Name used both to select the provider (OpenAI, Anthropic or Ollama,
        matched by substring) and as the model identifier passed to it.
    api_key : str, optional
        API key passed to the OpenAI / Anthropic client.
    base_url : str, optional
        Custom endpoint. When given it is forwarded to the selected client;
        when omitted each client keeps its own default endpoint.
    verbose : bool
        Stored on the instance; not otherwise used by this class.

    Raises:
    -------
    ValueError: if the model name matches none of the supported providers.
    """

    def __init__(self, model_name: str, api_key: Optional[str] = None, base_url: Optional[str] = None, verbose: bool = False):
        self.model_name = model_name
        self.api_key = api_key
        self.base_url = base_url  # This can be None, will use default in _initialize_model
        self.verbose = verbose
        self.llm = self._initialize_model()

    @staticmethod
    def _strip_ollama_prefix(model_name: str) -> str:
        """Return the bare Ollama model name, dropping an "ollama/" or "ollama:" prefix."""
        if model_name.lower().startswith("ollama/"):
            return model_name[len("ollama/"):]
        if "ollama:" in model_name.lower():
            return model_name.split(":", 1)[1]
        return model_name

    def _initialize_model(self) -> Union[BaseChatModel, BaseLanguageModel]:
        """Initialize the appropriate language model based on the model name."""
        lowered_name = self.model_name.lower()

        # Only forward base_url when the caller supplied one, so that each
        # client falls back to its own default endpoint otherwise.
        endpoint_kwargs = {"base_url": self.base_url} if self.base_url else {}

        # Handle OpenAI models
        if "gpt" in lowered_name or "openai" in lowered_name:
            from langchain_openai import ChatOpenAI

            return ChatOpenAI(
                model_name=self.model_name,
                openai_api_key=self.api_key,
                temperature=0.1,
                **endpoint_kwargs
            )

        # Handle Anthropic models
        elif "anthropic" in lowered_name or "claude" in lowered_name:
            from langchain_anthropic import ChatAnthropic
            return ChatAnthropic(
                model_name=self.model_name,
                anthropic_api_key=self.api_key,
                temperature=0.1,
                **endpoint_kwargs
            )

        # Handle Ollama models
        elif "ollama" in lowered_name or lowered_name in ["gemma", "llama", "mistral", "vicuna", "phi"]:
            from langchain_ollama.llms import OllamaLLM
            return OllamaLLM(
                model=self._strip_ollama_prefix(self.model_name),
                temperature=0.1,
                **endpoint_kwargs
            )
        else:
            error_msg = f"Unsupported model: {self.model_name}"
            raise ValueError(error_msg)

    def run(self, prompt: str, system_prompt: Optional[str] = None) -> str:
        """
        Run a completion using a chain.

        Parameters:
        -----------
        prompt : str
            User message; passed as the value of the ``{input}`` variable.
        system_prompt : str, optional
            System message. It is handed to ChatPromptTemplate as template
            text, so literal braces in it must be doubled (``{{`` / ``}}``).

        Returns:
        --------
        str: ``response.content`` when the model returns a message object,
            otherwise ``str(response)``.

        Raises:
        -------
        RuntimeError: if invoking the chain fails (the original exception is
            attached as the cause).
        """
        # Create prompt template based on whether system prompt is provided
        if system_prompt:
            prompt_template = ChatPromptTemplate.from_messages([
                ("system", system_prompt),
                ("human", "{input}")
            ])
        else:
            prompt_template = ChatPromptTemplate.from_messages([
                ("human", "{input}")
            ])
        # Create and invoke the chain
        chain = prompt_template | self.llm
        try:
            response = chain.invoke({"input": prompt})

            if hasattr(response, 'content'):
                return response.content
            else:
                return str(response)
        except Exception as e:
            error_msg = f"Error invoking chain: {str(e)}"
            raise RuntimeError(error_msg) from e


In [None]:
import os

MODEL_NAME = "gpt-4o-mini"  # or "ollama/llama2"
# Read the key from the environment so it is never saved inside the notebook.
# Set OPENAI_API_KEY before running (use the matching variable if you switch
# provider); when it is unset this falls back to the empty placeholder.
API_KEY = os.environ.get("OPENAI_API_KEY", "")

# Teacher model used by check_reasoning to grade the student's reasoning
gpt_chat = Chat(model_name=MODEL_NAME, api_key=API_KEY, verbose=False)

### Preparing the GRPO dataset

In [None]:
n = 1000
# Take the first n//2 rows of each label so the two classes are balanced
wrong_subset = dataset[dataset['expected_answer'] == "WRONG"].head(n//2)
right_subset = dataset[dataset['expected_answer'] == "RIGHT"].head(n//2)
combined_dataset = pd.concat([wrong_subset, right_subset], ignore_index=True)
# Shuffle with a fixed seed so a Restart & Run All produces the same order
combined_dataset = combined_dataset.sample(frac=1, random_state=3407).reset_index(drop=True)

grpo_dataset = Dataset.from_pandas(combined_dataset)

# Add the chat-formatted prompt and the reference answer to every row
grpo_dataset = grpo_dataset.map(lambda x: {
    "prompt" : [
        {"role": "system", "content": system_prompt},
        {"role": "user",   "content": x["problem"]},
    ],
    "answer": x["expected_answer"],
})
grpo_dataset[0]

In [None]:
# Tokenize every chat-formatted prompt with the generation prompt appended
def _prompt_tokens(batch):
    return {"tokens" : tokenizer.apply_chat_template(batch["prompt"], add_generation_prompt = True, tokenize = True)}

tokenized = grpo_dataset.map(_prompt_tokens, batched = True)
print(tokenizer.decode(tokenized[0]["tokens"]))

# Record the token count of each prompt in column "L"
tokenized = tokenized.map(lambda row: {"L" : len(row["tokens"])})
prompt_lengths = np.array(tokenized["L"])

# The 90th percentile of prompt lengths becomes the cut-off
maximum_length = int(np.quantile(prompt_lengths, 0.9))
print("Max Length = ", maximum_length)

# Keep only the prompts that are no longer than the cut-off
keep_indices = np.where(prompt_lengths <= maximum_length)[0]
grpo_dataset = grpo_dataset.select(keep_indices)

# Reset the indices to be continuous
grpo_dataset = grpo_dataset.select(range(len(grpo_dataset)))
del tokenized


### Trainer Setup



In [None]:
max_prompt_length = maximum_length + 1 # + 1 just in case!
# Whatever is left of the context window after the prompt goes to the completion
max_completion_length = max_seq_length - max_prompt_length
assert max_completion_length > 0, (
    f"max_seq_length ({max_seq_length}) leaves no room for a completion "
    f"after a prompt of up to {max_prompt_length} tokens"
)

from vllm import SamplingParams
vllm_sampling_params = SamplingParams(
    min_p = 0.1,
    top_p = 1.0,
    top_k = -1,
    seed = 3407,
    stop = [tokenizer.eos_token],
    include_stop_str_in_output = True,
)

training_args = GRPOConfig(
    vllm_sampling_params = vllm_sampling_params,
    temperature = 1.0,
    learning_rate = 5e-6,
    weight_decay = 0.01,
    warmup_ratio = 0.1,
    lr_scheduler_type = "linear",
    optim = "adamw_8bit",
    logging_steps = 1,
    #per_device_train_batch_size = 1,
    gradient_accumulation_steps = 8, # Gradients are accumulated over 8 steps before each update
    num_generations = 4, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_completion_length,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 1000,
    save_steps = 25,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",

    # For optional training + evaluation
    # fp16_full_eval = True,
    per_device_eval_batch_size = 4,
    # eval_accumulation_steps = 1,
    # eval_strategy = "steps",
    # eval_steps = 1,
)

### Train the model

In [None]:
# Reward functions handed to the trainer, in the order they are listed
grpo_reward_funcs = [
    match_format_exactly,
    match_format_approximately,
    check_reasoning,
    check_verdict,
]

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = grpo_reward_funcs,
    args = training_args,
    train_dataset = grpo_dataset,
)

# To continue an interrupted run, call train(resume_from_checkpoint = True) instead
trainer.train()


## Inference and Save Model
Now let's try the model we just trained! First, let's try the model without the GRPO-trained LoRA applied:

In [None]:
# Baseline: raw prompt, no chat template and no LoRA adapter
text = "What is the sqrt of 101?"

from vllm import SamplingParams
sampling_params = SamplingParams(temperature = 1.0, top_k = 50, max_tokens = 1024)

generations = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)
# Text of the first candidate for the first (and only) prompt
output = generations[0].outputs[0].text

output

And now with the LoRA we just trained with GRPO - first we save the LoRA:

In [None]:
# Save the trained LoRA adapter to the "grpo_saved_lora" directory;
# the following cells read the adapter back from this location.
model.save_lora("grpo_saved_lora")

Verify LoRA is actually trained!

In [None]:
from safetensors import safe_open

tensors = {}
with safe_open("grpo_saved_lora/adapter_model.safetensors", framework = "pt") as f:
    # Verify both A and B are non zero
    for key in f.keys():
        tensor = f.get_tensor(key)
        # Fail when every element is zero: an all-zero adapter tensor
        # suggests the LoRA was not actually trained.
        assert (tensor != 0).any().item(), f"LoRA tensor {key!r} is entirely zero"

Now we load the LoRA and test:

In [None]:
# Same question as before, now wrapped in the chat format used for training
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user",   "content": "What is the sqrt of 101?"},
]

# Render the conversation to a plain string; the generation prompt must be added
text = tokenizer.apply_chat_template(messages, add_generation_prompt = True, tokenize = False)

from vllm import SamplingParams
sampling_params = SamplingParams(temperature = 1.0, top_k = 50, max_tokens = 2048)

# Load the adapter saved earlier and generate with it applied
grpo_lora_request = model.load_lora("grpo_saved_lora")
generations = model.fast_generate(
    text,
    sampling_params = sampling_params,
    lora_request = grpo_lora_request,
)
output = generations[0].outputs[0].text

output

Our reasoning model is much better - it's not always correct, since we only trained it for an hour or so - it'll be better if we extend the sequence length and train for longer!


### Saving to float16 for VLLM


In [None]:
# Merge to 16bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_16bit", token = "")

# Merge to 4bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_4bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_4bit", token = "")

# Just LoRA adapters
if False: model.save_pretrained_merged("model", tokenizer, save_method = "lora",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "lora", token = "")

### GGUF / llama.cpp Conversion

* `q8_0` - Fast conversion. High resource use, but generally acceptable.
* `q4_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K.
* `q5_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K.



In [None]:
# Every export below is disabled; change `False` to `True` on the ones you want to run.
# Remember to go to https://huggingface.co/settings/tokens for a token!
# And change hf to your username!

# Save to 8bit Q8_0
if False:
    model.save_pretrained_gguf("model", tokenizer)
if False:
    model.push_to_hub_gguf("hf/model", tokenizer, token = "")

# Save to 16bit GGUF
if False:
    model.save_pretrained_gguf("model", tokenizer, quantization_method = "f16")
if False:
    model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "f16", token = "")

# Save to q4_k_m GGUF
if False:
    model.save_pretrained_gguf("model", tokenizer, quantization_method = "q4_k_m")
if False:
    model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "q4_k_m", token = "")

# Save to multiple GGUF options - much faster if you want multiple!
if False:
    model.push_to_hub_gguf(
        "hf/model", # Change hf to your username!
        tokenizer,
        quantization_method = ["q4_k_m", "q8_0", "q5_k_m"],
        token = "",
    )