# Large language models (LLMs) post-training

![logo](imgs/for_printing.jpg)


All checkpoints from this notebook can be downloaded [here](https://drive.google.com/drive/folders/1fgZQCNbbEyRaj0UFdtnhGF-upBRkwJPY?usp=sharing).

Table of contents:
- LLM lifecycle overview
- In-context learning
- Supervised Fine Tuning (SFT)
- Reinforcement learning (RL)
- Direct Preference Optimization (DPO)
- Reasoning models
- Group Relative Policy Optimization (GRPO)

## LLM lifecycle overview

In a standard LLM learning pipeline there are two main steps:
1. Pre-training
2. Post-training

![LLM lifecycle](imgs/LLM-overview.png)

For a deeper dive into this topic, I recommend Andrej Karpathy's talk [State of GPT](https://www.youtube.com/watch?v=bZQun8Y4L2A).

A model release usually includes two variants: base and instruct. For example, there are two versions of the same Llama 3.2 1B (1 billion parameters) model:
- Base: [meta-llama/Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B)
- Instruct: [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct)

Note: Be careful, naming conventions differ between companies. For example, in the Qwen3 family, the base model is named [Qwen/Qwen3-0.6B-Base](https://huggingface.co/Qwen/Qwen3-0.6B-Base) and the instruct model is simply named [Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B).

### Base model

Base models are trained only to continue text; they are not trained to follow user instructions. They are usually used by ML researchers for developing AI algorithms. Let's take a look at one of them:

In [19]:
import os
os.environ["TRANSFORMERS_VERBOSITY"] = "error"

from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "Qwen/Qwen2.5-0.5B"

# To run inference, we first need to tokenize our text,
# i.e. convert the plain text into token ids that the LLM understands.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
base_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

text = "Hey! What's the capital of Great Britain?"
input_ids = tokenizer(
    text,
    return_tensors="pt",  # pytorch format of output tensors
)['input_ids']

generated_ids = base_model.generate(
    input_ids.to(base_model.device),
    max_new_tokens=64,
)

# Convert generated tokens back to human readable text
generated_text = tokenizer.decode(generated_ids[0])
print(f"Base model answer:\n{generated_text}")

del base_model

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


Base model answer:
Hey! What's the capital of Great Britain? A. London B. Paris C. New York D. Tokyo
London
Paris
New York
Tokyo
答案:
A

下列关于“三会一课”制度的说法，正确的是____。
A. 党支部应当组织党员按期参加党员大会、党小组会和上党课


### Instruct model

These models are trained to be helpful assistants for people with different questions. Every instruct model is trained to work with a specialized format, which is described in a chat template. This template is pre-installed in the model's tokenizer class and can be easily applied by the method `apply_chat_template`.

#### Chat format:
Usually, every message for an instruct model is marked with one of three roles:
- System: The first message, which gives the model general instructions and tells it what we expect from it. For example:
    - `"You are a helpful assistant. Reply in a truthful and polite manner"`
    - `"Imagine that you are a financial analyst. Give truthful and correct financial advice."` -- Don't do that!
- User: The exact question that we want to ask our model about. For example:
    - `"What's the capital of Great Britain?"`
    - `"How many words are there in the song 'The Scientist' by Coldplay?"`
- Assistant: Replies from your LLM. 

Note: Instruct models usually expect User and Assistant roles to alternate, so it is good practice to keep this order when prompting the model.

In [29]:
import os
os.environ["TRANSFORMERS_VERBOSITY"] = "error"

from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "Qwen/Qwen2.5-0.5B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
instruct_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

messages = [
    {"role": "user", "content": "Hey! What's the capital of Great Britain?"},
]

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,  # Append the start of the assistant's turn to the prompt
    return_tensors="pt",  # pytorch format of output tensors
)

generated_ids = instruct_model.generate(
    input_ids.to(instruct_model.device),
    max_new_tokens=64,
)

# Convert generated tokens back to human readable text
generated_text = tokenizer.decode(generated_ids[0])
print(f"Instruct model answer:\n{generated_text}")

Instruct model answer:
<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
Hey! What's the capital of Great Britain?<|im_end|>
<|im_start|>assistant
The capital of Great Britain is London.<|im_end|>


As you can see, there is a default system prompt for Qwen2.5 family instruct models:
- `"You are Qwen, created by Alibaba Cloud. You are a helpful assistant."`

Notice also that every message is wrapped in special tokens (`<|im_start|>` and `<|im_end|>`). This layout is defined by the chat template.

In [21]:
print(f"Chat template of Qwen2 models:\n{tokenizer.chat_template}")

Chat template of Qwen2 models:
{%- if tools %}
    {{- '<|im_start|>system\n' }}
    {%- if messages[0]['role'] == 'system' %}
        {{- messages[0]['content'] }}
    {%- else %}
        {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}
    {%- endif %}
    {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
    {%- for tool in tools %}
        {{- "\n" }}
        {{- tool | tojson }}
    {%- endfor %}
    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
    {%- if messages[0]['role'] == 'system' %}
        {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
    {%- else %}
        {{- '<|im_start|>system\nYou

The template is not meant to be read by humans: it is a Jinja template, essentially a small program that builds the input text the model expects.
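
To make the idea concrete, here is a minimal hand-written sketch of the layout seen in the output above. It is only an approximation: `render_chatml` is a hypothetical helper, and it ignores the default system prompt and the tool-calling branch that the real template handles.

```python
def render_chatml(messages, add_generation_prompt=True):
    """Hand-written approximation of a ChatML-style chat template."""
    parts = []
    for message in messages:
        # Each message is wrapped as <|im_start|>{role}\n{content}<|im_end|>\n
        parts.append(
            f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n"
        )
    if add_generation_prompt:
        # Open the assistant turn so the model continues from here.
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)


messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What's the capital of Great Britain?"},
]
prompt = render_chatml(messages)
print(prompt)
```

In practice you should always call `tokenizer.apply_chat_template`, because each model family ships its own template.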

We can also remove all special tokens, like `<|im_start|>`, from the final answer by setting `skip_special_tokens=True`.

In [None]:
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(f"Instruct model answer without special symbols:\n{generated_text}")

Instruct model answer without special symbols:
system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.
user
Hey! What's the capital of Great Britain?
assistant
The capital of Great Britain is London.


In [27]:
# Or even better, keep only the newly generated tokens:
prompt_length = input_ids.shape[1]
generated_text = tokenizer.decode(generated_ids[0][prompt_length:], skip_special_tokens=True)
print(f"Instruct model answer without special symbols:\n{generated_text}")

Instruct model answer without special symbols:
The capital of Great Britain is London.


By providing a list of sample question-and-answer pairs, we can show the model what we expect to see in an answer. It's useful when we want to get answers in a specific format:

In [58]:
messages = [
    {"role": "user", "content": "What's the capital of Great Britain?"},
    {"role": "assistant", "content": "**London**"},
    {"role": "user", "content": "What's the capital of France?"},
]

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,  # Append the start of the assistant's turn to the prompt
    return_tensors="pt",  # pytorch format of output tensors
)

prompt_length = input_ids.shape[1]

generated_ids = instruct_model.generate(
    input_ids.to(instruct_model.device),
    max_new_tokens=64,
)

# Convert generated tokens back to human readable text
generated_text = tokenizer.decode(generated_ids[0][prompt_length:], skip_special_tokens=True)
print(f"Instruct model answer in special format:\n{generated_text}")

del instruct_model

Instruct model answer in special format:
The capital of France is **Paris**.


## Question:  What should we do if we want better performance on a specific task?

## In-context learning

One of the simplest ways to enhance the performance of an LLM is in-context learning. In this method, we add useful information, such as instructions or worked examples, to the system or user message. The model's weights are not changed.

## Question: Is this method really effective? What do you think?


Let's find out! First, we need a way to evaluate the model.

In [None]:
import os
os.environ["TRANSFORMERS_VERBOSITY"] = "error"

# API provider keys are needed to calculate the 'oracle' metric and for dataset generation.
# In my case, I use Nebius AI Studio, but you can change it to any other provider that supports the OpenAI API.
# os.environ['ORACLE_BASE_URL'] = "https://api.studio.nebius.ai/v1/"
# os.environ['ORACLE_API_KEY'] = ""

import torch

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

from src.metrics import ModelEvaluator

# Allow the async metrics calculator to run inside a notebook
import nest_asyncio
nest_asyncio.apply()



In [2]:
# This dataset contains errors, as described at https://www.logical-fallacy.com/articles/dataset-review/
DATASET_NAME = "tasksource/logical-fallacy"

dataset = load_dataset(
  DATASET_NAME,
  revision="main",
)

In [3]:
prompt_id = 0

user_prompt = dataset['test'][prompt_id]['source_article']
answer = dataset['test'][prompt_id]['logical_fallacies']

print(
    f"Statement: {user_prompt}\nLogical fallacies: {answer}"
)

Statement: People who drive big cars probably hate the environment.
Logical fallacies: fallacy of extension


In [4]:
SYSTEM_PROMPT = "You are an expert in critical thinking. Analyze the following text, identify which logical fallacy it contains, and write only the name of this logical fallacy."
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.bfloat16,
    device_map="auto",
)

In [5]:
messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": user_prompt},
]
# Get the tokeniser of the model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)

# Create the input prompt based on the roles
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,  # Append the start of the assistant's turn to the prompt
    return_tensors="pt",  # pytorch format of output tensors
)

# Generate a response
prompt_length = input_ids.shape[1]

generated_ids = model.generate(
    input_ids.to(model.device),
    max_new_tokens=512,
    do_sample=False,
)

assistant_ids = generated_ids[0][prompt_length:]
response = tokenizer.decode(assistant_ids, skip_special_tokens=True)
print(f"Prompt: {user_prompt}\nModel answer: {response}\nGround truth: {answer}")


Prompt: People who drive big cars probably hate the environment.
Model answer: Hasty generalization
Ground truth: fallacy of extension


## Question: How can we evaluate the correctness of an LLM's output?


We will use the following metrics (or a set of them):
- Exact match (`em`): A binary, all-or-nothing metric showing if the model response exactly matches the correct answer.
- [F1-score](https://en.wikipedia.org/wiki/F-score) (`f1`): Measures the word overlap between the prediction and the ground truth as the harmonic mean of precision and recall, averaged over all examples.
- Oracle (`oracle`): A binary evaluation by a large LLM to determine if the answer is correct or not.
- Length (`len`): Average length of the LLM's answer.
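
A minimal sketch of how `em` and `f1` can be computed for a single prediction is shown here. The normalization (lowercasing and whitespace splitting) is a simplifying assumption, and `exact_match` and `token_f1` are hypothetical helpers, not the code used in `src.metrics`.

```python
from collections import Counter


def normalize(text):
    # Simplified normalization: lowercase and split on whitespace.
    return text.lower().split()


def exact_match(prediction, truth):
    # 1.0 only if the normalized token sequences are identical.
    return float(normalize(prediction) == normalize(truth))


def token_f1(prediction, truth):
    pred_tokens = normalize(prediction)
    true_tokens = normalize(truth)
    # Multiset intersection: a shared token counts as often as it appears in both.
    overlap = sum((Counter(pred_tokens) & Counter(true_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(true_tokens)
    return 2 * precision * recall / (precision + recall)


print(exact_match("Hasty generalization", "fallacy of extension"))
print(token_f1("fallacy of logic", "fallacy of extension"))
```

Note that `f1` gives partial credit for shared words such as "fallacy of", even when the predicted fallacy is wrong. This is one reason to look at several metrics together.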

In [15]:
# Print all logical fallacies present in the dataset:
print("Logical fallacies present in the dataset:")

true_labels = dataset['test'].unique('logical_fallacies')
for label in true_labels:
    print(label)

Logical fallacies present in the dataset:
fallacy of extension
faulty generalization
fallacy of logic
false causality
fallacy of credibility
circular reasoning
ad hominem
ad populum
intentional
fallacy of relevance
appeal to emotion
equivocation
false dilemma


In [7]:
import transformers.data.metrics.squad_metrics as squad_metrics

first_fallacy = "fallacy of extension"
second_fallacy = "fallacy of logic"

score = squad_metrics.compute_f1(first_fallacy, second_fallacy)
print(f"F1 score between '{first_fallacy}' and '{second_fallacy}': {score}")


first_fallacy = "fallacy of logic"
second_fallacy = "Logic fallacy"

score = squad_metrics.compute_f1(first_fallacy, second_fallacy)
print(f"F1 score between '{first_fallacy}' and '{second_fallacy}': {score}")

F1 score between 'fallacy of extension' and 'fallacy of logic': 0.6666666666666666
F1 score between 'fallacy of logic' and 'Logic fallacy': 0.8


In [12]:
ORACLE_SYSTEM_PROMPT = """You are a critical thinking oracle specializing in identifying logical fallacies. You will receive a statement and a proposed logical fallacy. Your task is to evaluate if the proposed logical fallacy accurately describes the error in reasoning within the statement.

You must respond exclusively in a valid JSON format that conforms to the following schema:
{
    "type": "object",
    "properties": {
        "is_correct": {
            "title": "Is Correct",
            "type": "boolean",
            "description": "A boolean value, 'true' if the logical fallacy correctly identifies the error in the statement, and 'false' otherwise."
        }
    },
    "title": "OracleAnswer",
    "required": ["is_correct"]
}

Do not include any text, markdown formatting, or explanations outside of the JSON object. Your entire output must be the JSON itself."""


evaluator = ModelEvaluator(model, system_prompt=SYSTEM_PROMPT)
metrics = evaluator.eval(
    DATASET_NAME,
    metrics=["f1", "em", "oracle"],
    oracle_system_prompt=ORACLE_SYSTEM_PROMPT,
    oracle_model_name="deepseek-ai/DeepSeek-R1-0528",
)

for key, value in metrics.items():
    print(f"Metric: {key}, Score: {value:.3}")

Model eval: 100%|██████████| 64/64 [00:34<00:00,  1.87it/s]
Running Oracle Evaluation: 100%|██████████| 511/511 [00:09<00:00, 54.59it/s]

Metric: f1, Score: 0.0628
Metric: em, Score: 0.0137
Metric: oracle, Score: 0.112
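
The oracle replies with a JSON object, so its answer has to be parsed before it can be scored. How `src.metrics` does this is not shown in the notebook. The sketch below is only an assumption about one reasonable approach, and `parse_oracle_answer` is a hypothetical helper.

```python
import json


def parse_oracle_answer(raw_reply):
    """Return True/False from the oracle's JSON reply, or None if it is malformed."""
    try:
        payload = json.loads(raw_reply)
    except json.JSONDecodeError:
        return None
    if not isinstance(payload, dict):
        return None
    value = payload.get("is_correct")
    # Accept only real booleans; strings like "true" do not match the schema.
    return value if isinstance(value, bool) else None


replies = ['{"is_correct": true}', '{"is_correct": false}', "Sure! The answer is correct."]
verdicts = [parse_oracle_answer(reply) for reply in replies]
# Design choice in this sketch: malformed replies count as incorrect.
oracle_score = sum(v is True for v in verdicts) / len(verdicts)
print(verdicts, oracle_score)
```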





For in-context learning, let's add one example of each class from the train split to the system prompt.

In [26]:
all_labels = dataset['train']['logical_fallacies']
unique_labels = dataset['test'].unique('logical_fallacies')
first_indices = [all_labels.index(label) for label in unique_labels]

unique_fallacies_dataset = dataset['train'].select(first_indices)

enhanced_system_prompt = [SYSTEM_PROMPT, "\n--- EXAMPLES ---"]
for example in unique_fallacies_dataset:
    formatted_example = (
        f"Text: \"{example['source_article']}\"\n"
        f"Logical fallacy: {example['logical_fallacies']}"
    )
    enhanced_system_prompt.append(formatted_example)

enhanced_system_prompt.append("--- END OF EXAMPLES ---")

enhanced_system_prompt = "\n\n".join(enhanced_system_prompt)


print(f"New system prompt: {enhanced_system_prompt}")

New system prompt: You are an expert in critical thinking. Analyze the following text, identify which logical fallacy it contains, and write only the name of this logical fallacy.


--- EXAMPLES ---

Text: "John: I think we should hire someone to redesign our website.
Lola: You're saying we should throw our money away on external resources instead of building up our in-house design team? That's going to hurt our company in the long run."
Logical fallacy: fallacy of extension

Text: "If we ban Hummers because they are bad for the environment, eventually the government will ban all cars, so we should not ban Hummers."
Logical fallacy: faulty generalization

Text: "If Joe eats greasy food, he will feel sick.
Joe feels sick.
Therefore, Joe ate greasy food."
Logical fallacy: fallacy of logic

Text: "The bigger a child's shoe size, the better the child's handwriting"
Logical fallacy: false causality

Text: "This herbal supplement is made from a plant that grows in Zambia. It must be healthie

In [27]:
evaluator = ModelEvaluator(model, system_prompt=enhanced_system_prompt)
metrics = evaluator.eval(
    DATASET_NAME,
    metrics=["f1", "em", "oracle"],
    oracle_system_prompt=ORACLE_SYSTEM_PROMPT,
    oracle_model_name="deepseek-ai/DeepSeek-R1-0528",
)

for key, value in metrics.items():
    print(f"Metric: {key}, Score: {value:.3}")

Model eval: 100%|██████████| 64/64 [00:42<00:00,  1.51it/s]
Running Oracle Evaluation: 100%|██████████| 511/511 [00:10<00:00, 51.01it/s]

Metric: f1, Score: 0.169
Metric: em, Score: 0.0587
Metric: oracle, Score: 0.121





As we can see, this approach boosts `f1` (from 0.0628 to 0.169) and `em` (from 0.0137 to 0.0587), but the `oracle` metric is almost unchanged (0.112 vs. 0.121).

## Question: Why did in-context learning improve `f1` and `em` metrics but not the `oracle` metric?

Can we improve our metrics even further? Yes, we can!

## Supervised fine-tuning

`Supervised fine-tuning` (SFT) is a post-training phase that adapts an LLM to a specific task using a small, specialized dataset ([Hugging Face training guide](https://huggingface.co/docs/transformers/en/training)). It is a popular way to improve LLM performance on a specific task and, compared with pre-training, does not require much compute.

To fine-tune an LLM we first need to collect data. For SFT we collect pairs: `(prompt, ideal_response)`.
- The prompt is the question or instruction you give the model.
- The ideal_response is the high-quality, perfect answer you want the model to learn to generate.

There are a few ways to create this dataset:
- Collected by humans: This approach provides the highest quality and most nuanced data, but it is also the most expensive and time-consuming.
- Generated by a more powerful LLM: This is fast and cheap, but the data can contain mistakes or reflect the biases of the model that created it.
- A combination of both: Often the best method is to have a powerful LLM generate the data and then have humans review and correct it, giving you a good balance of speed, cost, and quality.


There are many open-source datasets on different topics, for example our `tasksource/logical-fallacy`. You can find many more on Hugging Face via search. But be careful: they come from different sources and can contain errors, class imbalances, etc.

To get a good result after SFT, you must be sure the dataset is balanced. This means every category or class you want the model to learn is represented fairly. For example, if you train an LLM on a dataset of 100,000 customer reviews where 90,000 are positive and only 10,000 are negative, the model will develop a strong bias. It will learn that guessing "positive" is almost always the right answer and will frequently choose that option, even for negative reviews.
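
The arithmetic behind this example can be checked in a few lines. This is a toy illustration using the review counts from the paragraph above, not real data.

```python
labels = ["positive"] * 90_000 + ["negative"] * 10_000
# A degenerate "model" that always guesses the majority class.
predictions = ["positive"] * len(labels)

accuracy = sum(p == y for p, y in zip(predictions, labels)) / len(labels)

# Recall on the minority class: how many negative reviews were recognized.
negative_predictions = [p for p, y in zip(predictions, labels) if y == "negative"]
negative_recall = sum(p == "negative" for p in negative_predictions) / len(negative_predictions)

print(f"accuracy={accuracy}, negative_recall={negative_recall}")
```

The accuracy looks high even though the model never recognizes a single negative review, which is why class balance matters.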

In [None]:
import os

# Make only the first GPU visible to the training process
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["TRANSFORMERS_VERBOSITY"] = "error"

# API provider keys are needed to calculate the 'oracle' metric and for dataset generation.
# In my case, I use Nebius AI Studio, but you can change it to any other provider that supports the OpenAI API.
# os.environ['ORACLE_BASE_URL'] = "https://api.studio.nebius.ai/v1/"
# os.environ['ORACLE_API_KEY'] = "..."

import torch

from collections import Counter
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel
from trl import SFTConfig, SFTTrainer

from src.metrics import ModelEvaluator

# Allow the async metrics calculator to run inside a notebook
import nest_asyncio
nest_asyncio.apply()

Let's look at the class balance in the logical-fallacy dataset:

In [2]:
# This dataset contains errors, as described at https://www.logical-fallacy.com/articles/dataset-review/
DATASET_NAME = "tasksource/logical-fallacy"

dataset = load_dataset(
  DATASET_NAME,
  revision="main",
)

# Count the frequency of each logical fallacy
all_labels = dataset['train']['logical_fallacies']
label_counts = Counter(all_labels)

print("Logical fallacy counts:")
for label, count in label_counts.items():
    print(f"- {label}: {count}")

Logical fallacy counts:
- appeal to emotion: 217
- false causality: 212
- ad populum: 209
- circular reasoning: 140
- fallacy of relevance: 175
- faulty generalization: 401
- ad hominem: 289
- fallacy of extension: 139
- equivocation: 58
- fallacy of logic: 176
- fallacy of credibility: 200
- intentional: 321
- false dilemma: 143


As you can see, our dataset is not balanced. To address this, we can do one of two things:
- **Oversampling**. This method increases the number of examples in under-represented classes, either by gathering more data or by repeating existing examples. Candidates here are 'equivocation' (58), 'fallacy of extension' (139), 'circular reasoning' (140), and 'false dilemma' (143).
- **Undersampling**. This method reduces the number of examples in over-represented classes, such as 'faulty generalization' (401) and 'intentional' (321). While this can balance the dataset, you risk removing important information that the model could learn from.
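
Both options can be sketched on a toy list of `(text, label)` pairs. This is an illustration under assumptions: `rebalance` is a hypothetical helper, and oversampling is done by duplicating existing examples as a stand-in for gathering new data.

```python
import random
from collections import Counter, defaultdict


def rebalance(examples, target_per_class, seed=0):
    """Resample (text, label) pairs so every class has exactly target_per_class items."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for text, label in examples:
        by_label[label].append((text, label))

    balanced = []
    for items in by_label.values():
        if len(items) >= target_per_class:
            # Undersampling: keep a random subset of the large class.
            balanced.extend(rng.sample(items, target_per_class))
        else:
            # Oversampling by duplication: keep all items, then repeat random ones.
            extra = rng.choices(items, k=target_per_class - len(items))
            balanced.extend(items + extra)
    rng.shuffle(balanced)
    return balanced


toy = [(f"text {i}", "faulty generalization") for i in range(8)]
toy += [(f"text {i}", "equivocation") for i in range(2)]
balanced = rebalance(toy, target_per_class=4)
print(Counter(label for _, label in balanced))
```

Duplicated examples add no new information, so heavy oversampling of a tiny class can lead to overfitting on those few texts.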

Let's first run SFT on the original, unbalanced dataset.

For that, we need to make our dataset compatible with Hugging Face's SFTTrainer. You can read more here: https://huggingface.co/docs/trl/en/dataset_formats#working-with-conversational-datasets-in-trl

In [4]:
SYSTEM_PROMPT = "You are an expert in critical thinking. Analyze the following text, identify which logical fallacy it contains, and write only the name of this logical fallacy."

def preprocess_function(example):
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example['source_article']},
        ],
        "completion": [
            {"role": "assistant", "content": example['logical_fallacies']}
        ]
    }

# Preprocessing the dataset
processed_dataset = dataset['train'].map(preprocess_function, remove_columns=["config", "source_article", "logical_fallacies"])

print(f"Original dataset example:\n{dataset['train'][0]}\n\n")
print(f"The same example in processed dataset:\n{processed_dataset[0]}")

Original dataset example:
{'config': 'edu', 'source_article': 'company\'s slogan "Expect More. Pay Less."', 'logical_fallacies': 'appeal to emotion'}


The same example in processed dataset:
{'prompt': [{'content': 'You are an expert in critical thinking. Analyze the following text, identify which logical fallacy it contains, and write only the name of this logical fallacy.', 'role': 'system'}, {'content': 'company\'s slogan "Expect More. Pay Less."', 'role': 'user'}], 'completion': [{'content': 'appeal to emotion', 'role': 'assistant'}]}


In [5]:
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"

args = SFTConfig(
    output_dir="./checkpoints/sft",
    # Do not report training metrics to external trackers (e.g. Weights & Biases)
    report_to="none",
    # Batch size per GPU
    per_device_train_batch_size=16,
    # Controls the size of the steps taken during training.
    learning_rate=2e-5,
    # Number of times we will go through our dataset
    num_train_epochs=1,
    # Compute the loss only on the completion (the assistant's reply), not on the prompt
    completion_only_loss=True,
)

trainer = SFTTrainer(
    model=MODEL_NAME,
    args=args,
    train_dataset=processed_dataset,
)

trainer.train()

Tokenizing train dataset: 100%|██████████| 2680/2680 [00:01<00:00, 1894.19 examples/s]
Truncating train dataset: 100%|██████████| 2680/2680 [00:00<00:00, 235329.20 examples/s]


{'loss': 0.916, 'grad_norm': 27.853498458862305, 'learning_rate': 1.892857142857143e-05, 'entropy': 2.564848256111145, 'num_tokens': 13372.0, 'mean_token_accuracy': 0.8271097838878632, 'epoch': 0.05952380952380952}
{'loss': 0.373, 'grad_norm': 18.72037696838379, 'learning_rate': 1.7738095238095237e-05, 'entropy': 2.7334084033966066, 'num_tokens': 27237.0, 'mean_token_accuracy': 0.8696904420852661, 'epoch': 0.11904761904761904}
{'loss': 0.3304, 'grad_norm': 13.7752046585083, 'learning_rate': 1.6547619047619046e-05, 'entropy': 2.6462952136993407, 'num_tokens': 40701.0, 'mean_token_accuracy': 0.8765043675899505, 'epoch': 0.17857142857142858}
{'loss': 0.2828, 'grad_norm': 15.849040031433105, 'learning_rate': 1.535714285714286e-05, 'entropy': 2.5765918254852296, 'num_tokens': 54459.0, 'mean_token_accuracy': 0.8981055080890655, 'epoch': 0.23809523809523808}
{'loss': 0.2865, 'grad_norm': 17.42316246032715, 'learning_rate': 1.416666666666667e-05, 'entropy': 2.546259617805481, 'num_tokens': 677

TrainOutput(global_step=168, training_loss=0.28065220089185805, metrics={'train_runtime': 29.4327, 'train_samples_per_second': 91.055, 'train_steps_per_second': 5.708, 'train_loss': 0.28065220089185805, 'entropy': 2.655073642730713, 'num_tokens': 226729.0, 'mean_token_accuracy': 0.9364175200462341, 'epoch': 1.0})
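
The `completion_only_loss=True` setting used above restricts the loss to the assistant's tokens. The following sketch shows the masking idea with made-up per-token log-probabilities. It is not TRL's actual implementation, and `masked_nll` is a hypothetical helper.

```python
def masked_nll(token_logprobs, completion_mask):
    """Average negative log-likelihood over tokens where completion_mask is 1."""
    selected = [-lp for lp, keep in zip(token_logprobs, completion_mask) if keep]
    return sum(selected) / len(selected)


# Hypothetical log-probabilities the model assigns to each target token.
token_logprobs = [-2.0, -1.5, -3.0, -0.2, -0.4]
# The first three tokens belong to the prompt, the last two to the completion.
completion_mask = [0, 0, 0, 1, 1]

full_loss = masked_nll(token_logprobs, [1] * len(token_logprobs))
completion_loss = masked_nll(token_logprobs, completion_mask)
print(f"full={full_loss:.2f}, completion_only={completion_loss:.2f}")
```

With the mask, the model is not penalized for failing to predict the prompt, so the gradient focuses on producing the right answer.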

In [13]:
def get_logical_fallacy(model: PreTrainedModel, user_prompt: str, system_prompt: str):
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    # Get the tokeniser of the model
    tokenizer = AutoTokenizer.from_pretrained(model.name_or_path, use_fast=False)

    # Generate the input prompt based on the roles
    input_ids = tokenizer.apply_chat_template(
        messages,
        # Append the start of the assistant's turn to the prompt
        add_generation_prompt=True,
        return_tensors="pt",  # pytorch format of output tensors
    )

    prompt_length = input_ids.shape[1]

    generated_ids = model.generate(
        input_ids.to(model.device),
        max_new_tokens=512,
        do_sample=False,
    )

    assistant_ids = generated_ids[0][prompt_length:]
    response = tokenizer.decode(assistant_ids, skip_special_tokens=True)
    return response


sft_model = AutoModelForCausalLM.from_pretrained(
    "./checkpoints/sft/checkpoint-168",
    dtype=torch.bfloat16,
    device_map="auto",
)

data_id = 3
user_prompt = dataset['test'][data_id]['source_article']
answer = dataset['test'][data_id]['logical_fallacies']
response = get_logical_fallacy(sft_model, user_prompt, SYSTEM_PROMPT)
print(f"Prompt: {user_prompt}\nModel response: {response}\nCorrect answer: {answer}")

Prompt: "Why are you hitting your computer!?"
"The last time my wifi was weak, I hit my computer and it got better." What is this?
Model response: false causality
Correct answer: false causality


In [15]:
evaluator = ModelEvaluator(sft_model, system_prompt=SYSTEM_PROMPT)
metrics = evaluator.eval(
    DATASET_NAME,
    metrics=["f1", "em", "oracle"],
    oracle_model_name="deepseek-ai/DeepSeek-R1-0528",
)

for key, value in metrics.items():
    print(f"Metric: {key}, Score: {value:.3}")

Model eval: 100%|██████████| 64/64 [00:39<00:00,  1.64it/s]
Running Oracle Evaluation: 100%|██████████| 511/511 [00:10<00:00, 49.99it/s]

Metric: f1, Score: 0.498
Metric: em, Score: 0.47
Metric: oracle, Score: 0.354





## Question: What can we do better to improve model performance even more?


## Reinforcement learning


**Reinforcement Learning** (RL) is a powerful way of training a model through trial and error, much like teaching a dog a new trick. You give a command, it performs an action, and you give it a treat if it's the right one. Over time, the dog learns to perform the correct action to maximize its treats.

For Large Language Models (LLMs), RL is a method to fine-tune their behavior. It allows us to provide feedback on their answers, steering them toward a desired outcome.

Here are a couple of key uses for RL with LLMs:
- **Alignment**. This is about teaching the LLM to be a good digital citizen. Sometimes, models can generate harmful, biased, or simply made-up answers (called "hallucinations"). RL is a powerful tool to discourage this behavior and "align" the model with human values.
- **Improving Skills**. We can also use RL to make an LLM better at specific tasks, like following complex instructions, writing in a particular style (e.g., more professional or more humorous), or generating better code.

Generally speaking, RL is used far beyond LLMs. It's the technology that has taught machines to master complex games like Chess and Go, and helped robots learn to walk. Here’s a simple diagram of the process:


![RL](imgs/reinforcement-learning-figure-1.png)
Source: https://www.ibm.com/think/topics/reinforcement-learning


Where:
- Agent: The learner or decision-maker. In our case, this is the LLM itself.
- Environment: The world the agent interacts with. For an LLM, this is the task context, like a chat window. It provides the prompt and receives the output.
- State $`S_t`$: A snapshot of the current situation. For the LLM, this is the prompt plus all the text it has generated so far.
- Action $`A_t`$: A move made by the agent. For the LLM, this is the generation of the very next word (or token). A complete response is simply a series of these actions.
- Reward $`R_t`$: The feedback signal, or the "treat." This is a score that evaluates the quality of the agent's final, complete response. This score often comes from a separate "reward model" that is trained to predict which answers a human would prefer.
- Policy $`\pi_t`$: The agent's strategy or "brain" that decides what action to take in a given state. For an LLM, this is the probability distribution over the next token given the text so far, which is determined by the model's weights.
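
The terms above can be mapped onto text generation with a toy loop. Everything here is hypothetical: `toy_policy` is a fixed lookup table instead of a neural network, and `toy_reward` is a hand-written rule instead of a trained reward model.

```python
def toy_policy(state):
    """Hypothetical fixed policy: maps the text so far to the next token."""
    next_token = {
        "What is the capital of France?": " Paris",
        "What is the capital of France? Paris": "<eos>",
    }
    return next_token.get(state, "<eos>")


def toy_reward(response):
    """Hypothetical reward: 1.0 if the finished response mentions Paris."""
    return 1.0 if "Paris" in response else 0.0


prompt = "What is the capital of France?"
state = prompt                      # S_0: the prompt
while True:
    action = toy_policy(state)      # A_t: the next token
    if action == "<eos>":
        break
    state = state + action          # S_{t+1}: the prompt plus all tokens so far

response = state[len(prompt):]
reward = toy_reward(response)       # R: scored once, on the complete response
print(repr(response), reward)
```
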
---
### How Reinforcement Learning from Human Feedback Works

Reinforcement Learning from Human Feedback (RLHF) is a common way to apply reinforcement learning to LLMs, which involves several main steps:
1. Pre-train the base model.
2. Generate responses: For various prompts, generate pairs of responses.
3. Collect human feedback: Have humans rank these pairs according to their preferences.
4. Train a reward model: Train a model that mimics these human rankings.
5. Train the LLM: Fine-tune the LLM to generate responses that receive a high reward from the reward model. (See the picture below).

Note: During the last stage, we add a term to the loss function (typically a KL-divergence penalty) that prevents the output distribution of the fine-tuned model from deviating too far from that of the original model. In other words, as the model learns its new task, it doesn't forget the core knowledge of the original model. This keeps the model's responses sensible and prevents them from becoming too strange.
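
One common way to write this objective is the KL-regularized form below. The notation is an assumption introduced here for illustration: $`\pi_\theta`$ is the model being trained, $`\pi_{\mathrm{ref}}`$ is the frozen original model, $`r_\phi`$ is the reward model, $`D`$ is the prompt dataset, and $`\beta`$ controls the strength of the penalty.

```math
\max_{\pi_\theta} \; \mathbb{E}_{x \sim D,\; y \sim \pi_\theta(\cdot \mid x)} \big[ r_\phi(x, y) \big] \;-\; \beta \, \mathbb{D}_{\mathrm{KL}} \big( \pi_\theta(y \mid x) \,\|\, \pi_{\mathrm{ref}}(y \mid x) \big)
```

A larger $`\beta`$ keeps the model closer to the original, while a smaller $`\beta`$ lets it chase the reward more aggressively.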

---
Today, RLHF approaches can be split into two main camps. To understand them, let's use a chess analogy:
- **On-Policy**: The model learns by doing. It's like a chess player improving by actively playing new games. It makes moves, gets feedback (wins or loses), and updates its strategy based on its own, recent experiences. Proximal Policy Optimization (PPO) is a popular on-policy method.
- **Off-Policy**: The model learns by observing. It's like a chess student improving by studying a database of grandmaster games. The student analyzes moves made by other experts (another policy) to learn what makes a good or bad move, without having to play the games themselves.

Note: While these aren't strict technical definitions, they provide a common-sense way to think about the difference. Generally speaking, an on-policy method is one where the model obtains rewards for generations created by its current policy. In contrast, an off-policy method can learn from responses that were generated at a previous step or even earlier, like the original policy.

Source: [article](https://huggingface.co/blog/NormalUhr/rlhf-pipeline) (For a deep dive into RL for LLMs, I highly recommend this article, though it's technical and contains a lot of math.)

## Direct Preference Optimization

![DPO](imgs/DPO_picture.png)
Source: https://arxiv.org/pdf/2305.18290

**Direct Preference Optimization** (DPO) is a clever and efficient off-policy method. The core idea is that instead of a complex feedback loop of playing games and getting rewards, we can just teach the model directly from a "handbook" of good and bad examples.

In contrast to the original RLHF idea, where you first train a separate reward model to score answers and then train the main LLM to get a high score from that reward model, DPO takes a shortcut. It directly teaches the main LLM to distinguish between good and bad responses. The model learns to increase the likelihood of generating the `chosen` (positive) answers while decreasing the likelihood of generating the `rejected` (negative) ones, all in a single, more efficient step.

To train with DPO, we need a special kind of dataset: a preference dataset. Each item in this dataset contains three parts:
- `prompt`: The initial question or instruction.
- `chosen`: A high-quality, preferred answer to the prompt.
- `rejected`: A less-good answer to the same prompt.

By training on thousands of these examples, the model learns a very simple but powerful lesson: "Given this prompt, I should increase the probability of generating an answer like the chosen one and decrease the probability of generating an answer like the rejected one." It’s like showing a student two solutions to a problem: one correct and one with a common mistake and simply saying, "Do this, not that."
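The "do this, not that" lesson corresponds to the loss from the DPO paper. Below is a minimal sketch for a single preference pair, using plain Python floats instead of tensors. The function name and argument names are made up for this example, and each log-probability is assumed to be the sum over all tokens of the response.

```python
import math


def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """DPO loss for one (prompt, chosen, rejected) triple.

    Each argument is the summed log-probability of a full response under
    the trained policy or under the frozen reference model.
    """
    # Implicit rewards: how much more likely the policy makes each response
    # compared with the reference model, scaled by beta.
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    margin = chosen_reward - rejected_reward
    # -log(sigmoid(margin)), written as log(1 + exp(-margin)).
    loss = math.log1p(math.exp(-margin))
    return loss, chosen_reward, rejected_reward


# Before any training the policy equals the reference, so the margin is 0.
loss, chosen_reward, rejected_reward = dpo_loss(-40.0, -35.0, -40.0, -35.0)
print(loss, chosen_reward, rejected_reward)
```

The loss falls as the margin between the two implicit rewards grows, which is why reward margins are a useful quantity to watch during DPO training.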

There are a few ways to create a DPO dataset. Each has its pros and cons.
1. **Use Your Model, Rank by Hand**: Your model generates answers, and a human or small LLM picks the best (`chosen`) and worst (`rejected`).
    - Pros: Sets realistic goals, as the model learns from its own abilities.
    - Cons: If your model is weak, you're just teaching it to be "less bad," not great.
2. **Use Your Model, Enhance with a Pro Model**: Your model generates an answer (`rejected`), and a more powerful model improves it (`chosen`).
    - Pros: Provides a high-quality "gold standard" for your model to aim for.
    - Cons: The quality gap between answers might be too large, making the lesson too difficult for your model to learn.
3. **Use a Pro Model for Everything**: A powerful model generates both the good (`chosen`) and bad (`rejected`) answers.
    - Pros: The fastest way to generate a large, high-quality dataset.
    - Cons: The data is in another model's "style," which might not be a good fit for your model.
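As an illustration of the first approach, the sketch below turns several scored candidate answers into one preference row. The helper name and the `(answer, score)` input format are hypothetical; the scores would come from a human annotator or a judge model.

```python
def build_preference_pair(prompt, scored_candidates):
    """Turn scored candidate answers into one DPO row, or None if there is no contrast.

    scored_candidates: list of (answer_text, score) tuples; higher is better.
    """
    if len(scored_candidates) < 2:
        return None
    ranked = sorted(scored_candidates, key=lambda pair: pair[1])
    (worst, worst_score), (best, best_score) = ranked[0], ranked[-1]
    # If every candidate got the same score, there is nothing to learn from.
    if best_score == worst_score:
        return None
    return {"prompt": prompt, "chosen": best, "rejected": worst}


row = build_preference_pair(
    "Which fallacy is this?",
    [("It is a strawman.", 0.2), ("It is an ad hominem.", 0.9), ("Not sure.", 0.5)],
)
print(row)
```

Skipping prompts where all candidates score the same avoids pairs in which `chosen` and `rejected` are equally good.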

For this exercise, let's teach our model to name the logical fallacy and give a short explanation for it.

In [None]:
import os

# Needed to train on GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["TRANSFORMERS_VERBOSITY"] = "error"

# API provider keys are needed to calculate the 'oracle' metric and for dataset generation.
# In my case, I use Nebius AI Studio, but you can change it to any other provider that supports the OpenAI API.
# os.environ['ORACLE_BASE_URL'] = "https://api.studio.nebius.ai/v1/"
# os.environ['ORACLE_API_KEY'] = ""

import torch

from pydantic import BaseModel
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel
from trl import SFTConfig, SFTTrainer

from src.metrics import ModelEvaluator
from src.generate_dataset import generate_preprocessed_dataset

from trl import DPOConfig
# The version with a fixed bug
from src.dpo_trainer import DPOTrainer

# Allows the metrics calculator to run inside a notebook
import nest_asyncio
nest_asyncio.apply()


First step: Generate an SFT dataset to teach the model to produce a short explanation.

In [None]:
SYSTEM_PROMPT_WITH_EXPLANATION = "You are an expert in critical thinking. Analyze the following text, identify which logical fallacy it contains, and write the name of this logical fallacy and a short explanation."

DATASET_NAME = "tasksource/logical-fallacy"

dataset = load_dataset(
  DATASET_NAME,
  revision="main"
)

In [None]:
# Code to create sft dataset with explanation

oracle_system_prompt = """
You are an AI assistant creating educational data for a Supervised Fine-Tuning (SFT) dataset. Your task is to provide a standard, helpful explanation of a logical fallacy.

For the given text and fallacy, generate a single, balanced explanation. The explanation should **name the fallacy, briefly define it, and connect it to the text**. The total length should be around two sentences.

Please provide the output as a single string.

---
**EXAMPLE**
**Input Text:** "You can't trust Professor Jones's theory on economics because he's a known socialist."
**Fallacy:** "ad hominem"

**Your Explanation:**
The text contains an **ad hominem** fallacy. This is when an argument attacks the person making it, which is seen here as the speaker focuses on Professor Jones's personal beliefs rather than his theory.
---

**YOUR TASK**
**Input Text:** "{{text_with_fallacy}}"
**Fallacy:** "{{fallacy_label}}"

**Your Explanation:**
"""

class StructuredSFTResponse(BaseModel):
    explanation: str


# Builds the user prompt that is sent to the large (oracle) LLM
def row_preprocessor(data_row: dict[str, str]):
    return f"Input Text: {data_row['source_article']}, Fallacy: {data_row['logical_fallacies']}"


# Builds a training row for our small model from the oracle's response
def row_postprocessor(
  response: StructuredSFTResponse,
  system_prompt: str,
  data_row: dict[str, str],
):
    return {
        "prompt": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": data_row['source_article']},
        ],
        "completion": [
            {"role": "assistant", "content": response.explanation}
        ]
    }

# Uncomment this to regenerate the dataset
# generate_preprocessed_dataset(
#   original_dataset=dataset['train'],
#   response_model=StructuredSFTResponse,
#   row_preprocessor=row_preprocessor,
#   row_postprocessor=row_postprocessor,
#   oracle_system_prompt=oracle_system_prompt,
#   dataset_system_prompt=SYSTEM_PROMPT_WITH_EXPLANATION,
#   result_dataset_name="sft_explanation.jsonl",
# )

Creating dataset: 100%|██████████| 2680/2680 [02:32<00:00, 17.56it/s]

Writing result dataset to datasets/sft_explanation.jsonl





In [None]:
sft_dataset = load_dataset('json', data_files="datasets/sft_explanation.jsonl")

completion_len = []
for row in sft_dataset["train"]:
    completion_len.append(len(row["completion"][0]["content"]))

# sum / len is the arithmetic mean (in characters), not the median
print(f"Mean response length: {(sum(completion_len) / len(completion_len)):.4}")

Mean response length: 260.9


In [30]:
print(f"Example dataset sample:\n{sft_dataset['train'][0]}")

Example dataset sample:
{'prompt': [{'role': 'system', 'content': 'You are an expert in critical thinking. Analyze the following text, identify which logical fallacy it contains, and write the name of this logical fallacy and a short explanation.'}, {'role': 'user', 'content': 'company\'s slogan "Expect More. Pay Less."'}], 'completion': [{'role': 'assistant', 'content': 'The slogan employs an **appeal to emotion** fallacy, which tries to persuade by triggering positive feelings—here the desire for getting more while spending less—rather than presenting a logical argument about the product’s actual value.'}]}


In [None]:
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"

args = SFTConfig(
    output_dir="./checkpoints/sft_explanation",
    # Do not report training metrics 
    report_to="none",
    # Batch size per GPU
    per_device_train_batch_size=8,
    # Controls the size of the steps taken during training.
    learning_rate=2e-5,
    # Number of times we will go through our dataset
    num_train_epochs=1,
    # Compute the loss only on the completion (the assistant's reply)
    completion_only_loss=True,
)

trainer = SFTTrainer(
    model=MODEL_NAME,
    args=args,
    train_dataset=sft_dataset['train'],
)

trainer.train()

Tokenizing train dataset: 100%|██████████| 2680/2680 [00:02<00:00, 1153.29 examples/s]
Truncating train dataset: 100%|██████████| 2680/2680 [00:00<00:00, 219674.32 examples/s]


{'loss': 1.4465, 'grad_norm': 11.633830070495605, 'learning_rate': 1.892857142857143e-05, 'entropy': 2.1654189586639405, 'num_tokens': 21466.0, 'mean_token_accuracy': 0.6500340580940247, 'epoch': 0.05952380952380952}
{'loss': 1.1601, 'grad_norm': 10.52587604522705, 'learning_rate': 1.7738095238095237e-05, 'entropy': 2.166372513771057, 'num_tokens': 43529.0, 'mean_token_accuracy': 0.6921733915805817, 'epoch': 0.11904761904761904}
{'loss': 1.0854, 'grad_norm': 10.185722351074219, 'learning_rate': 1.6547619047619046e-05, 'entropy': 2.09293292760849, 'num_tokens': 65152.0, 'mean_token_accuracy': 0.7049457490444183, 'epoch': 0.17857142857142858}
{'loss': 1.0344, 'grad_norm': 10.531123161315918, 'learning_rate': 1.535714285714286e-05, 'entropy': 2.007581281661987, 'num_tokens': 87139.0, 'mean_token_accuracy': 0.7231891691684723, 'epoch': 0.23809523809523808}
{'loss': 0.9894, 'grad_norm': 10.130460739135742, 'learning_rate': 1.416666666666667e-05, 'entropy': 2.0070352792739867, 'num_tokens': 

TrainOutput(global_step=168, training_loss=0.9437200512204852, metrics={'train_runtime': 29.8026, 'train_samples_per_second': 89.925, 'train_steps_per_second': 5.637, 'train_loss': 0.9437200512204852, 'entropy': 1.8114690035581589, 'num_tokens': 362629.0, 'mean_token_accuracy': 0.7726730033755302, 'epoch': 1.0})

In [8]:
def get_logical_fallacy(model: PreTrainedModel, user_prompt: str, system_prompt: str):
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    # Get the tokeniser of the model
    tokenizer = AutoTokenizer.from_pretrained(model.name_or_path, use_fast=False)

    # Generate the input prompt based on the roles
    input_ids = tokenizer.apply_chat_template(
        messages,
        # Append the start of the assistant's turn so the model continues from there
        add_generation_prompt=True,
        return_tensors="pt",  # pytorch format of output tensors
    )

    prompt_length = input_ids.shape[1]

    generated_ids = model.generate(
        input_ids.to(model.device),
        max_new_tokens=512,
        do_sample=False,
    )

    assistant_ids = generated_ids[0][prompt_length:]
    response = tokenizer.decode(assistant_ids, skip_special_tokens=True)
    return response


sft_model = AutoModelForCausalLM.from_pretrained(
    "./checkpoints/sft_explanation/checkpoint-168",
    dtype=torch.bfloat16,
    device_map="auto",
)

data_id = 3
user_prompt = dataset['test'][data_id]['source_article']
answer = dataset['test'][data_id]['logical_fallacies']
response = get_logical_fallacy(sft_model, user_prompt, SYSTEM_PROMPT_WITH_EXPLANATION)
print(f"Prompt: {user_prompt}\nModel response: {response}\nCorrect answer: {answer}")

Prompt: "Why are you hitting your computer!?"
"The last time my wifi was weak, I hit my computer and it got better." What is this?
Model response: The passage commits a **false causality** (or post‑hoc) fallacy, which assumes that because one event followed another, they must be caused by each other; here the speaker links the weak WiFi to the computer’s improvement without providing evidence of a direct causal link.
Correct answer: false causality


In [None]:
evaluator = ModelEvaluator(sft_model, system_prompt=SYSTEM_PROMPT_WITH_EXPLANATION)
metrics = evaluator.eval(
    DATASET_NAME,
    metrics=["length", "oracle"],
    batch_size=16,
    oracle_model_name="deepseek-ai/DeepSeek-R1-0528",
)

for key, value in metrics.items():
    print(f"Metric: {key}, Score: {value:.4}")

Model eval: 100%|██████████| 32/32 [06:46<00:00, 12.69s/it]
Running Oracle Evaluation: 100%|██████████| 511/511 [00:14<00:00, 35.07it/s]

Metric: length, Score: 2.75e+02
Metric: oracle, Score: 0.307





We now have an SFT model that generates the answer together with a short explanation, so let's apply DPO.

We will use the second approach to generate the DPO dataset. This approach is heavily inspired by the [Constitutional AI paper](https://arxiv.org/pdf/2212.08073).

For DPO we will use the 'dev' split of our dataset, which the SFT model has not been trained on.

One goal of this DPO run is to make the model's answers shorter, which is why we track the `length` metric.

In [43]:
import tqdm
import transformers

# First, we need to generate answers with the SFT model.
def _create_prompt(user_prompt: str):
    return [
        {"role": "system", "content": SYSTEM_PROMPT_WITH_EXPLANATION},
        {"role": "user", "content": user_prompt},
    ]

def eval_model_on_prompts(model, dataset, batch_size=16):
    tokenizer = AutoTokenizer.from_pretrained(model.name_or_path, use_fast=False)
    pipeline = transformers.pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        model_kwargs={
            "dtype": torch.bfloat16,
        },
        device_map="auto",
    )

    results = []
    for i in tqdm.tqdm(range(0, len(dataset), batch_size)):
        batch_articles = dataset[i : i + batch_size]['source_article']
        batch_prompts = [_create_prompt(article) for article in batch_articles]

        answers = pipeline(
            batch_prompts,
            do_sample=True,
            return_full_text=False,
        )

        pred_answers = [ans[0]['generated_text'] for ans in answers]
        true_answers = dataset[i : i + batch_size]['logical_fallacies']
        
        for true_answer, pred_answer, article in zip(true_answers, pred_answers, batch_articles):
            results.append(
                {
                    "prompt": article,
                    "generated_text": pred_answer,
                    "true_answer": true_answer,
                }
            )
    return results

model_evals = eval_model_on_prompts(
    sft_model,
    dataset['dev'],
)

100%|██████████| 36/36 [08:19<00:00, 13.87s/it]


In [None]:
model_evals[0]

{'prompt': '"Just like students are given a couple of weeks of preparation before taking exams, doctors should also be given few days or weeks to prepare themselves before an operation or surgery, after all surgery is not as easy task"\nIs an example of....',
 'generated_text': 'The passage commits an **appeal to emotion** fallacy, which tries to persuade by evoking fear or anxiety rather than presenting logical evidence. It does this by emphasizing that surgery can be difficult because it requires less time for patients than exams, using the perceived threat of hardship to stir concern instead of providing factual support.',
 'true_answer': 'fallacy of logic'}

In [47]:
# Code to create DPO dataset with explanation

oracle_system_prompt = """
You are an AI assistant that acts as a writing coach. Your task is to refine and improve a given response.

You will be provided with three pieces of information:
1.  **Original Prompt:** The question that was asked.
2.  **Model's Generated Answer:** An initial attempt to answer the prompt.
3.  **Ideal Answer:** A reference for the correct information.

Your job is to **optimise the 'Model's Generated Answer'**. You must follow these rules:
- Make the answer shorter and more concise with some explanation related to the original prompt.
- Improve clarity and directness.
- Ensure the final answer is factually aligned with the 'Ideal Answer'.
- The output should be a single, polished sentence.
- **It must contain short explanation**

Please provide the output in a JSON format with a single key: "optimised_answer".

---
**EXAMPLE**
**Original Prompt:** "Analyze the following text and identify the logical fallacy: 'You can't trust Professor Jones's theory on economics because he's a known socialist.'"

**Model's Generated Answer:** "The fallacy in the text is an ad hominem. This is a type of logical error where an argument is rebutted by attacking the character of the person making it, rather than the substance of the argument itself."

**Ideal Answer:** "This is an **ad hominem** fallacy, where the speaker attacks Professor Jones's personal beliefs instead of his economic theory."

**Your JSON Output:**
{
  "optimised_answer": "This is an **ad hominem** fallacy, as the speaker attacks Professor Jones's character instead of his economic theory."
}
---

**YOUR TASK**
**Original Prompt:** {{prompt}}
**Model's Generated Answer:** {{sft_model_output}}
**Ideal Answer:** {{sft_ideal_output}}

**Your JSON Output:**
"""

class StructuredDPOResponse(BaseModel):
    optimised_answer: str

# Builds the user prompt that is sent to the large (oracle) LLM
def row_preprocessor(data_row: dict[str, str]):
    return f"**Original Prompt:** {data_row['prompt']}\n**Model's Generated Answer:** {data_row['generated_text']}\n**Ideal Answer:** {data_row['true_answer']}"

# Builds a training row for our small model from the oracle's response
def row_postprocessor(
  response: StructuredDPOResponse,
  system_prompt: str,
  data_row: dict[str, str],
):
    return {
        "prompt": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": data_row['prompt']},
        ],
        "chosen": [{"role": "assistant", "content": response.optimised_answer}],
        "rejected": [{"role": "assistant", "content": data_row['generated_text']}],
    }

# Uncomment this to regenerate the dataset
# generate_preprocessed_dataset(
#   original_dataset=model_evals,
#   response_model=StructuredDPOResponse,
#   row_preprocessor=row_preprocessor,
#   row_postprocessor=row_postprocessor,
#   oracle_system_prompt=oracle_system_prompt,
#   dataset_system_prompt=SYSTEM_PROMPT_WITH_EXPLANATION,
#   result_dataset_name="dpo_explanation.jsonl",
# )



Creating dataset: 100%|██████████| 570/570 [00:33<00:00, 17.14it/s]

Writing result dataset to datasets/dpo_explanation.jsonl





In [3]:
dpo_dataset = load_dataset('json', data_files="datasets/dpo_explanation.jsonl")

print("Example of DPO dataset:")
for key, value in dpo_dataset['train'][0].items():
    print(f"{key}: {value}")
    
print()

chosen_len = []
rejected_len = []
for row in dpo_dataset["train"]:
    chosen_len.append(len(row["chosen"][0]["content"]))
    rejected_len.append(len(row["rejected"][0]["content"]))

# sum / len is the arithmetic mean (in characters), not the median
print(f"Mean chosen length: {(sum(chosen_len) / len(chosen_len)):.4}")
print(f"Mean rejected length: {(sum(rejected_len) / len(rejected_len)):.4}")

Example of DPO dataset:
prompt: [{'role': 'system', 'content': 'You are an expert in critical thinking. Analyze the following text, identify which logical fallacy it contains, and write the name of this logical fallacy and a short explanation.'}, {'role': 'user', 'content': '"Just like students are given a couple of weeks of preparation before taking exams, doctors should also be given few days or weeks to prepare themselves before an operation or surgery, after all surgery is not as easy task"\nIs an example of....'}]
chosen: [{'role': 'assistant', 'content': 'This is a logical fallacy—a false analogy that improperly equates exam preparation with surgical preparation.'}]
rejected: [{'role': 'assistant', 'content': 'The passage commits an **appeal to emotion** fallacy, which tries to persuade by evoking fear or anxiety rather than presenting logical evidence. It does this by emphasizing that surgery can be difficult because it requires less time for patients than exams, using the perce

In [None]:
training_args = DPOConfig(
    # Trains in bfloat16 precision, which usually reduces memory use and speeds up training.
    bf16=True,
    # Controls the deviation from the reference model. A higher beta
    # means less divergence from the initial policy (the original model weights).
    beta=0.1,
    output_dir="./checkpoints/dpo",
    # Disables reporting logs to any external provider (e.g., Weights & Biases).
    report_to="none",
    # Batch size per device (GPU).
    per_device_train_batch_size=16,
    # The learning rate, which determines the step size for model weight updates.
    learning_rate=1e-6,
    lr_scheduler_type="cosine",
    # The total number of times the model will iterate over the entire dataset.
    num_train_epochs=1,
)

trainer = DPOTrainer(
    model="./checkpoints/sft_explanation/checkpoint-168",
    args=training_args,
    train_dataset=dpo_dataset['train']
)

trainer.train()

{'loss': 0.2626, 'grad_norm': 10.14705753326416, 'learning_rate': 8.535533905932737e-07, 'rewards/chosen': 1.2404437065124512, 'rewards/rejected': -0.27815794944763184, 'rewards/accuracies': 0.8999999761581421, 'rewards/margins': 1.518601655960083, 'logps/chosen': -52.332130432128906, 'logps/rejected': -36.59880447387695, 'logits/chosen': -1.047598958015442, 'logits/rejected': -1.0941046476364136, 'epoch': 0.2777777777777778}
{'loss': 0.0416, 'grad_norm': 2.301567792892456, 'learning_rate': 4.5642212862617085e-07, 'rewards/chosen': 2.349787473678589, 'rewards/rejected': -1.2131578922271729, 'rewards/accuracies': 1.0, 'rewards/margins': 3.562945604324341, 'logps/chosen': -41.968441009521484, 'logps/rejected': -45.02446365356445, 'logits/chosen': -0.9047890901565552, 'logits/rejected': -0.9145330190658569, 'epoch': 0.5555555555555556}
{'loss': 0.018, 'grad_norm': 1.2117276191711426, 'learning_rate': 9.042397785550404e-08, 'rewards/chosen': 2.531099796295166, 'rewards/rejected': -2.039825

TrainOutput(global_step=36, training_loss=0.09230444062915114, metrics={'train_runtime': 28.0059, 'train_samples_per_second': 20.353, 'train_steps_per_second': 1.285, 'train_loss': 0.09230444062915114, 'epoch': 1.0})

In [9]:
dpo_model = AutoModelForCausalLM.from_pretrained(
    "./checkpoints/dpo/checkpoint-36",
    dtype=torch.bfloat16,
    device_map="auto",
)

data_id = 3
user_prompt = dataset['test'][data_id]['source_article']
answer = dataset['test'][data_id]['logical_fallacies']
response = get_logical_fallacy(dpo_model, user_prompt, SYSTEM_PROMPT_WITH_EXPLANATION)
print(f"Prompt: {user_prompt}\nModel response: {response}\nCorrect answer: {answer}")

Prompt: "Why are you hitting your computer!?"
"The last time my wifi was weak, I hit my computer and it got better." What is this?
Model response: The statement commits a **false causality** fallacy, assuming that because the computer improved after the weak WiFi, the WiFi must have caused the improvement.
Correct answer: false causality


In [12]:
evaluator = ModelEvaluator(dpo_model, system_prompt=SYSTEM_PROMPT_WITH_EXPLANATION)
metrics = evaluator.eval(
    DATASET_NAME,
    metrics=["length", "oracle"],
    batch_size=16,
    oracle_model_name="deepseek-ai/DeepSeek-R1-0528",
)

for key, value in metrics.items():
    print(f"Metric: {key}, Score: {value:.4}")

Model eval: 100%|██████████| 32/32 [03:49<00:00,  7.16s/it]
Running Oracle Evaluation: 100%|██████████| 511/511 [00:10<00:00, 47.86it/s]

Metric: length, Score: 163.6
Metric: oracle, Score: 0.3014





## Question: What can we improve in our DPO training? How can we do it better?

.

.

.

.

.

.

.

.

.

.

.

.

.

.

.

.

.

.

.

**Note**: Feeding a large model both a question and its correct answer, as we did in the SFT and DPO examples, is not a best practice. It can lead the model to produce incorrect explanations and flawed reasoning, because we are asking it to justify a connection that it might not actually see.

A better approach is to sample several responses from an LLM at a high temperature. This produces diverse results, from which we keep only the outputs with the correct answer. We can then more reliably assume that the reasoning the large LLM generated is also valid.

Finally, the resulting responses must be cleaned before being used to create the dataset: duplicates should be removed, and the dataset must be balanced.
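The filtering, de-duplication, and balancing steps described above could look roughly like this. It is only a sketch: the function name, the dictionary keys, and the "cap per label" balancing strategy are assumptions for illustration, and the sampling itself (several high-temperature generations per prompt) is not shown.

```python
from collections import defaultdict


def select_training_samples(samples, max_per_label):
    """Keep correct, unique samples and cap how many each label contributes.

    samples: dicts with the hypothetical keys "prompt", "predicted", "label",
    and "reasoning", e.g. several high-temperature generations per prompt.
    """
    seen = set()
    per_label = defaultdict(list)
    for sample in samples:
        # 1. Keep only generations whose final answer matches the ground truth.
        if sample["predicted"].strip().lower() != sample["label"].strip().lower():
            continue
        # 2. Drop exact duplicates of the same prompt and reasoning.
        key = (sample["prompt"], sample["reasoning"])
        if key in seen:
            continue
        seen.add(key)
        per_label[sample["label"]].append(sample)
    # 3. Balance: cap the number of samples per label.
    selected = []
    for label_samples in per_label.values():
        selected.extend(label_samples[:max_per_label])
    return selected
```

Capping frequent labels is the simplest way to balance; oversampling rare labels is an alternative when the dataset is already small.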

## Reasoning models

Generally speaking, the format we used above (answer first, explanation after) is not "native" to LLMs. It's more effective if the model thinks first and then produces the answer. This is how modern reasoning LLMs work: they reason before they answer. Almost all top-tier LLMs support a reasoning mode: GPT-5, Gemini, DeepSeek-R1, gpt-oss, etc. To make this work, such models are typically trained with RL algorithms such as Group Relative Policy Optimization (GRPO).

Before producing the final answer, this type of model generates some reasoning, placed between tags such as `<think>...</think>` or `<reasoning>...</reasoning>`, and then produces the answer.

Let's see how it works.

In [1]:
import os
os.environ["TRANSFORMERS_VERBOSITY"] = "error"

from transformers import AutoModelForCausalLM, AutoTokenizer

# This is the instruct version; Alibaba changed the naming for Qwen3.
model_name = "Qwen/Qwen3-0.6B"

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
instruct_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

messages = [
    {"role": "user", "content": "What's the capital of Great Britain?"},
]

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
    # Switches between thinking (reasoning) and non-thinking modes. Default is True.
    enable_thinking=True,
)

prompt_length = input_ids.shape[1]

generated_ids = instruct_model.generate(
    input_ids.to(instruct_model.device),
    max_new_tokens=256,
)

# Convert generated tokens back to human readable text
generated_text = tokenizer.decode(generated_ids[0][prompt_length:], skip_special_tokens=True)
print(f"Instruct model answer:\n{generated_text}")



Instruct model answer:
<think>
Okay, the user is asking about the capital of Great Britain. Let me think. I know that the capital of the United Kingdom is London. But wait, are there any other cities with the same name? I recall that London is the capital, but maybe there's a city called London that's a different location. Let me double-check. No, London is the capital. So the answer should be London.
</think>

The capital of Great Britain is **London**.


This format is also very useful for our case because it generates the final answer at the end, which makes it possible to calculate the exact match metric on the answer alone. Let's try to make our model use reasoning and train it using RL. 
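A minimal sketch of how an exact match metric can be computed on the answer alone. It assumes the model wraps its final answer in `<answer>...</answer>` tags; the function names are made up for this example and are not the helpers from `src.metrics`.

```python
import re

ANSWER_PATTERN = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def extract_answer(completion):
    """Return the text inside the last <answer>...</answer> block, or None."""
    matches = ANSWER_PATTERN.findall(completion)
    if not matches:
        return None
    return matches[-1].strip()


def exact_match(completions, labels):
    """Share of completions whose extracted answer equals the label (case-insensitive)."""
    hits = 0
    for completion, label in zip(completions, labels):
        answer = extract_answer(completion)
        if answer is not None and answer.lower() == label.strip().lower():
            hits += 1
    return hits / len(labels) if labels else 0.0
```

Completions without an `<answer>` block count as wrong, so the metric also penalizes outputs that break the format.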

## Group Relative Policy Optimization

Group Relative Policy Optimization (GRPO) is an advanced RL method introduced in the [DeepSeekMath paper](https://arxiv.org/pdf/2402.03300). GRPO is an example of an on-policy method, which modifies the traditional Proximal Policy Optimization (PPO). Here is how it works:
1. **Sampling**. Generate multiple responses for every prompt with our LLM.
2. **Rewarding**. Each generation is scored by a reward function defined by the user.
3. **Advantage Calculation**. Instead of using absolute reward values, it's common to compute the advantage: the reward minus a baseline value. The advantage tells the model how much better or worse a response is than expected, so it keeps looking for improvements even when all rewards are already high. In GRPO, the baseline is the average reward of the group of outputs generated for the same prompt.
4. **Policy Update**. The policy tries to maximize the GRPO objective, which includes the calculated advantage term, while also preventing the output distribution from changing too rapidly from the original one.
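The advantage calculation from step 3 can be sketched in a few lines. In the DeepSeekMath formulation the centered reward is also divided by the group's standard deviation; the `eps` term and the use of the population standard deviation are choices made for this sketch, not details taken from the paper.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-4):
    """Advantages for one group of responses sampled for the same prompt."""
    baseline = mean(rewards)   # the group average acts as the baseline
    spread = pstdev(rewards)   # normalizes the scale of the rewards
    # eps avoids division by zero when every response gets the same reward.
    return [(r - baseline) / (spread + eps) for r in rewards]


# Four responses to one prompt: two were rewarded, two were not.
print(group_relative_advantages([1.0, 0.0, 0.0, 1.0]))
```

If every response in a group receives the same reward, all advantages are zero and that prompt contributes no learning signal, which is one reason the task should be neither too easy nor too hard for the model.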

![GRPO](imgs/grpo_visual.png)

Source: [hf article](https://huggingface.co/learn/cookbook/en/fine_tuning_llm_grpo_trl), [blogpost](https://www.philschmid.de/deepseek-r1), [blogpost](https://normaluhr.github.io/2025/02/07/grpo/)

A major advantage of using GRPO is that, in the general case, we don't need a reasoning dataset. We can train a model to reason by providing a large number of tasks and constraining the format, as shown in the demo below.

This implementation is highly inspired by [GRPO demo](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb).
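To illustrate what "constraining the format" can mean in practice, here is a sketch of two simple reward functions: one for following a `<reasoning>`/`<answer>` layout and one for answering correctly. The tag layout, function names, and reward values (0.5 and 2.0) are assumptions for illustration; wiring such functions into a trainer is not shown.

```python
import re

FORMAT_PATTERN = re.compile(
    r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n(.*?)\n</answer>\s*$",
    re.DOTALL,
)


def format_reward(completion):
    """0.5 if the completion follows the reasoning/answer layout, else 0.0."""
    return 0.5 if FORMAT_PATTERN.match(completion) else 0.0


def correctness_reward(completion, label):
    """2.0 if the text inside <answer> equals the label, else 0.0."""
    match = FORMAT_PATTERN.match(completion)
    if match is None:
        return 0.0
    answer = match.group(1).strip().lower()
    return 2.0 if answer == label.strip().lower() else 0.0


def total_reward(completion, label):
    """Sum of the format and correctness rewards for one completion."""
    return format_reward(completion) + correctness_reward(completion, label)
```

Giving a small reward for the format alone provides a learning signal even when the answer is wrong, while the larger correctness reward keeps the model from settling for well-formatted but incorrect outputs.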


## Question: What steps should we take to teach the model to reason?

In [None]:
import os

# Needed to train on GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["TRANSFORMERS_VERBOSITY"] = "error"

# API provider keys are needed to calculate the 'oracle' metric and for dataset generation.
# In my case, I use Nebius AI Studio, but you can change it to any other provider that supports the OpenAI API.
# os.environ['ORACLE_BASE_URL'] = "https://api.studio.nebius.ai/v1/"
# os.environ['ORACLE_API_KEY'] = ""

import re
import torch

from pydantic import BaseModel
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel
from trl import GRPOConfig, GRPOTrainer, SFTConfig, SFTTrainer

from src.metrics import ModelEvaluator, extract_xml_answer
from src.generate_dataset import generate_preprocessed_dataset

# Allows the metrics calculator to run inside a notebook
import nest_asyncio
nest_asyncio.apply()


In [None]:
SYSTEM_PROMPT_REASONING = "You are an expert in critical thinking. Analyze the following text, identify which logical fallacy it contains, and write only the name of this logical fallacy."

DATASET_NAME = "tasksource/logical-fallacy"

dataset = load_dataset(
  DATASET_NAME,
  revision="main",
)

We will perform Supervised Fine-Tuning (SFT) first for two main reasons:
1. Format. Our model is too small and we have limited data, making it difficult to teach the model the required output format without SFT.
2. Basic Knowledge. As you remember, the base model answers correctly in only 1.5% of cases. That's too low; we would have to wait too long to generate enough correct results for the next stage.

However, generally speaking, we could take an instruction-tuned version and train it directly, as was done in the [GRPO demo](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb).

In [None]:
class StructuredReasoningResponse(BaseModel):
    answer: str
    reasoning: str | None = None


# Builds the user prompt that is sent to the large (oracle) LLM
def row_preprocessor(data_row: dict[str, str]):
  return data_row['source_article']


# Builds a training row for our small model from the oracle's response
def row_postprocessor(
  response: StructuredReasoningResponse,
  system_prompt: str,
  data_row: dict[str, str],
):
  model_answer = response.answer.lower()
  if model_answer != data_row['logical_fallacies'].lower():
    return None

  assistant_response = f"<reasoning>\n{response.reasoning}\n</reasoning>\n<answer>\n{model_answer}\n</answer>\n"
  return {
    "prompt": [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": data_row['source_article']},
    ],
    "completion": [
        {"role": "assistant", "content": assistant_response}
    ]
  }

# generate_preprocessed_dataset(
#   original_dataset=dataset['train'],
#   response_model=StructuredReasoningResponse,
#   row_preprocessor=row_preprocessor,
#   row_postprocessor=row_postprocessor,
#   oracle_system_prompt=SYSTEM_PROMPT_REASONING,
#   dataset_system_prompt=SYSTEM_PROMPT_REASONING,
#   result_dataset_name="sft_resoning.jsonl",
# )

**Note**: For the sake of simplicity, we have not used oversampling and high-temperature generation in this case.

In [8]:
sft_resoning_dataset = load_dataset('json', data_files="datasets/sft_resoning.jsonl")

print("Example of Reasoning dataset:")
for key, value in sft_resoning_dataset['train'][0].items():
    print(f"{key}: {value}")

print(f"\nLength of the dataset: {len(sft_resoning_dataset['train'])}")

Example of Reasoning dataset:
prompt: [{'role': 'system', 'content': 'You are an expert in critical thinking. Analyze the following text, identify which logical fallacy it contains, and write only the name of this logical fallacy.'}, {'role': 'user', 'content': '"In his class president election video, he called his student opponent \'a brown-nosing, suck up who only wanted to get on the teacher\'s good side,\' which got him disqualified" IS an example of THIS fallacy.'}]
completion: [{'role': 'assistant', 'content': '<reasoning>\nWe need to identify logical fallacy in the statement: "In his class president election video, he called his student opponent \'a brown-nosing, suck up who only wanted to get on the teacher\'s good side,\' which got him disqualified" IS an example of THIS fallacy.\n\nThe statement is an example of an ad hominem (specifically, abusive ad hominem) because attacking the opponent\'s character rather than the argument. So answer: "ad hominem". Possibly "ad hominem a

To build this dataset properly, we generated a response for every prompt with a larger model and kept only those whose answer matched the ground truth from the dataset. Unfortunately, this left only 301 examples, which is too few.

In [9]:
args = SFTConfig(
    output_dir="./checkpoints/sft_reasoning",
    report_to="none",
    per_device_train_batch_size=8,
    num_train_epochs=1,
    completion_only_loss=True,
)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    args=args,
    train_dataset=sft_resoning_dataset['train'],
)

trainer.train()

{'loss': 1.2994, 'grad_norm': 9.554824829101562, 'learning_rate': 1.5263157894736846e-05, 'entropy': 1.7163094401359558, 'num_tokens': 17568.0, 'mean_token_accuracy': 0.701541143655777, 'epoch': 0.2631578947368421}
{'loss': 0.7909, 'grad_norm': 8.181174278259277, 'learning_rate': 1e-05, 'entropy': 1.4930537700653077, 'num_tokens': 34895.0, 'mean_token_accuracy': 0.7951819419860839, 'epoch': 0.5263157894736842}
{'loss': 0.6889, 'grad_norm': 7.62333869934082, 'learning_rate': 4.736842105263158e-06, 'entropy': 1.342898416519165, 'num_tokens': 52341.0, 'mean_token_accuracy': 0.8160412609577179, 'epoch': 0.7894736842105263}
{'train_runtime': 14.6233, 'train_samples_per_second': 20.584, 'train_steps_per_second': 2.599, 'train_loss': 0.8865696380012914, 'entropy': 1.332326889038086, 'num_tokens': 65786.0, 'mean_token_accuracy': 0.8017690926790237, 'epoch': 1.0}


TrainOutput(global_step=38, training_loss=0.8865696380012914, metrics={'train_runtime': 14.6233, 'train_samples_per_second': 20.584, 'train_steps_per_second': 2.599, 'train_loss': 0.8865696380012914, 'entropy': 1.332326889038086, 'num_tokens': 65786.0, 'mean_token_accuracy': 0.8017690926790237, 'epoch': 1.0})
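
The `global_step=38` in the output above follows directly from the dataset size and the batch size. A quick check, assuming a single device, the default `gradient_accumulation_steps=1`, and that the last partial batch is kept:

```python
import math

# Values taken from the text and the SFTConfig above.
num_examples = 301
per_device_train_batch_size = 8

# Assumptions: one device, no gradient accumulation, last partial batch kept.
steps_per_epoch = math.ceil(num_examples / per_device_train_batch_size)
print(steps_per_epoch)  # 38

# Fraction of the epoch completed after 10 optimizer steps.
print(round(10 / steps_per_epoch, 4))  # 0.2632
```

The second number matches the `epoch` value of the first log line, which suggests that metrics are logged every 10 steps.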

In [3]:
def get_logical_fallacy(model: PreTrainedModel, user_prompt: str, system_prompt: str):
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    # Load the tokenizer that matches the model
    tokenizer = AutoTokenizer.from_pretrained(model.name_or_path, use_fast=False)

    # Build the input prompt from the role-tagged messages
    input_ids = tokenizer.apply_chat_template(
        messages,
        # Append the start of the assistant's turn to the end of the prompt
        add_generation_prompt=True,
        return_tensors="pt",  # return PyTorch tensors
    )

    prompt_length = input_ids.shape[1]

    generated_ids = model.generate(
        input_ids.to(model.device),
        max_new_tokens=512,
        do_sample=False,
    )

    assistant_ids = generated_ids[0][prompt_length:]
    response = tokenizer.decode(assistant_ids, skip_special_tokens=True)
    return response


sft_reasoning_model = AutoModelForCausalLM.from_pretrained(
    "./checkpoints/sft_reasoning/checkpoint-38",
    dtype=torch.bfloat16,
    device_map="auto",
)

data_id = 0
user_prompt = dataset['test'][data_id]['source_article']
answer = dataset['test'][data_id]['logical_fallacies']
response = get_logical_fallacy(sft_reasoning_model, user_prompt, SYSTEM_PROMPT_REASONING)
print(f"Prompt: {user_prompt}\nModel response: {response}\nCorrect answer: {answer}")

Prompt: People who drive big cars probably hate the environment.
Model response: <reasoning>
We need to identify logical fallacy. The statement: "People who drive big cars probably hate the environment." This is a false cause? It says because they drive big cars, they hate the environment. That's a circular reasoning (begging the question) or appeal to emotion? Actually it's a straw man? It's attacking the person rather than argument. So answer: "False cause". Probably "appeal to emotion" but more precisely "false cause".
</reasoning>
<answer>
false cause
</answer>

Correct answer: fallacy of extension


In [16]:
evaluator = ModelEvaluator(sft_reasoning_model, system_prompt=SYSTEM_PROMPT_REASONING)
metrics = evaluator.eval(
    DATASET_NAME,
    metrics=["em"],
    batch_size=16,
    parse_output=True,
    oracle_model_name="deepseek-ai/DeepSeek-R1-0528",
)

for key, value in metrics.items():
    print(f"Metric: {key}, Score: {value:.4}")

Model eval: 100%|██████████| 32/32 [13:17<00:00, 24.91s/it]

Metric: em, Score: 0.1213





Although the SFT dataset contained only 301 examples, the exact match score rose to about 12%, which is much better than I expected.
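
For reference, the exact match metric can be sketched as follows. This is a minimal illustration, not the actual `ModelEvaluator` code: the helper names `extract_answer` and `exact_match` are hypothetical, and I assume that the answer is parsed from the `<answer>` tags and compared case-insensitively.

```python
import re


def extract_answer(text: str) -> str:
    """Hypothetical parser: return the text inside <answer>...</answer>, or "" if absent."""
    match = re.search(r"<answer>(.*?)</answer>", text, flags=re.DOTALL)
    return match.group(1).strip() if match else ""


def exact_match(predictions: list[str], references: list[str]) -> float:
    """Share of predictions whose parsed answer equals the reference (case-insensitive)."""
    hits = sum(
        extract_answer(pred).lower() == ref.strip().lower()
        for pred, ref in zip(predictions, references)
    )
    return hits / len(references)


predictions = [
    "<reasoning>\n...\n</reasoning>\n<answer>\nfalse cause\n</answer>\n",
    "<reasoning>\n...\n</reasoning>\n<answer>\nCircular reasoning\n</answer>\n",
]
references = ["fallacy of extension", "circular reasoning"]
print(exact_match(predictions, references))  # 0.5
```

Exact match is strict: a near-synonym of the label counts as a miss, which is one reason why the scores on this task look low.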

In [None]:
# Reward for correctness
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r.lower() == a.lower() else 0.0 for r, a in zip(extracted_responses, answer)]

# Reward for the correct format of the model's output
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]
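
Before training, it is worth checking by hand how these two rewards combine on a few strings. The sketch below reuses the same regular expression and reward values, but works on plain strings instead of the chat-format completions the trainer passes in. `extract_answer` is a hypothetical stand-in for `extract_xml_answer`, and summing the two rewards with equal weights is an assumption about how the trainer combines them.

```python
import re

FORMAT_PATTERN = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"


def extract_answer(text: str) -> str:
    """Hypothetical stand-in for extract_xml_answer: text between the answer tags."""
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()


def total_reward(completion_text: str, gold: str) -> float:
    """Format reward (0.5) plus correctness reward (2.0), assuming equal weights."""
    has_format = re.match(FORMAT_PATTERN, completion_text, flags=re.DOTALL)
    format_reward = 0.5 if has_format else 0.0
    is_correct = extract_answer(completion_text).lower() == gold.lower()
    correctness_reward = 2.0 if is_correct else 0.0
    return format_reward + correctness_reward


well_formed = "<reasoning>\nAttacks the person.\n</reasoning>\n<answer>\nAd Hominem\n</answer>\n"
no_newlines = "<reasoning>Attacks the person.</reasoning><answer>ad hominem</answer>"

print(total_reward(well_formed, "ad hominem"))    # 2.5
print(total_reward(no_newlines, "ad hominem"))    # 2.0
print(total_reward(well_formed, "false cause"))   # 0.5
```

The format check is deliberately strict: a correct answer without the exact newline layout still earns the correctness reward, but loses the format bonus.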

In [5]:
grpo_dataset = dataset["train"].map(
    lambda x: {
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT_REASONING},
            {'role': 'user', 'content': x['source_article']},
        ],
        'answer': x['logical_fallacies'],
    },
    remove_columns=["config", "source_article", "logical_fallacies"],
)

Map: 100%|██████████| 2680/2680 [00:00<00:00, 38150.87 examples/s]


In [None]:
training_args = GRPOConfig(
    output_dir="./checkpoints/grpo/",
    # The initial learning rate, which determines the step size for model weight updates.
    learning_rate=5e-6,
    # Applies regularization to prevent overfitting by penalizing large weights.
    weight_decay=0.1,
    # The proportion of total training steps used for a linear learning rate warmup from 0.
    warmup_ratio=0.1,
    # The type of learning rate scheduler to use for adjusting the learning rate during training.
    lr_scheduler_type='cosine',
    # The frequency (in steps) at which to log training metrics.
    logging_steps=1,
    # Enables training in bf16 (bfloat16) mixed-precision, which can improve performance.
    bf16=True,
    # The number of samples processed per device in one step. In GRPO this counts completions,
    # so 16 here with num_generations=16 means one prompt per step.
    per_device_train_batch_size=16,
    # Number of steps to accumulate gradients over before performing an optimizer update.
    # Effective batch size = per_device_train_batch_size * gradient_accumulation_steps (per device).
    gradient_accumulation_steps=1,
    # The number of candidate responses to generate for each prompt during training.
    num_generations=16,
    # Maximum token length for input prompts; longer prompts will be truncated.
    max_prompt_length=256,
    # Maximum token length for generated completions.
    max_completion_length=786,
    # The total number of times the model will iterate over the entire dataset.
    num_train_epochs=1,
    # The frequency (in steps) at which to save a model checkpoint.
    save_steps=100,
    # The maximum norm for gradient clipping, used to prevent exploding gradients and stabilize training.
    max_grad_norm=0.1,
    # The integration to report logs and results to (in this case, Weights & Biases). See the report below.
    report_to="wandb",
)
        
tokenizer = AutoTokenizer.from_pretrained(sft_reasoning_model.name_or_path)
tokenizer.pad_token = tokenizer.eos_token

trainer = GRPOTrainer(
    model=sft_reasoning_model,
    processing_class=tokenizer,
    reward_funcs=[
        strict_format_reward_func,
        correctness_reward_func
    ],
    args=training_args,
    train_dataset=grpo_dataset,
)

trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33msergeyskv[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Detected [huggingface_hub.inference, openai] in use.
[34m[1mwandb[0m: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
[34m[1mwandb[0m: For more information, check out the docs at: https://weave-docs.wandb.ai/


-------------------- Question:
Time is up on Gore ’ s “ point of no return ” and Hansen ’ s “ critical tipping point. ” 
Answer:
fallacy of logic 
Extracted:
circular reasoning
{'loss': 0.0, 'grad_norm': 0.0, 'learning_rate': 0.0, 'num_tokens': 3364.0, 'completions/mean_length': 141.25, 'completions/min_length': 88.0, 'completions/max_length': 264.0, 'completions/clipped_ratio': 0.0, 'completions/mean_terminated_length': 141.25, 'completions/min_terminated_length': 88.0, 'completions/max_terminated_length': 264.0, 'rewards/strict_format_reward_func/mean': 0.5, 'rewards/strict_format_reward_func/std': 0.0, 'rewards/correctness_reward_func/mean': 0.0, 'rewards/correctness_reward_func/std': 0.0, 'reward': 0.5, 'reward_std': 0.0, 'frac_reward_zero_std': 1.0, 'entropy': 0.9693878889083862, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0, 'clip_ratio/high_max': 0.0, 'clip_ratio/region_mean': 0.0, 'epoch': 0.00037313432835820896}
-------------------- Questio

TrainOutput(global_step=2680, training_loss=0.005519580252738123, metrics={'train_runtime': 11853.4047, 'train_samples_per_second': 0.226, 'train_steps_per_second': 0.226, 'total_flos': 0.0, 'train_loss': 0.005519580252738123})

You can see the report with the training charts [here](https://api.wandb.ai/links/sergeyskv/3bexijx1)
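
The first logged step shows `reward_std: 0.0`, `frac_reward_zero_std: 1.0` and `grad_norm: 0.0`. This is consistent with how GRPO scores completions: each reward is compared with the other rewards in the same group, so a group where all completions get the same reward carries no learning signal. A minimal sketch of this normalization, assuming the population standard deviation and a small `eps` (the function name is hypothetical, and library implementations may differ in these details):

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Normalize each reward by the mean and standard deviation of its group."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


# All 16 completions have the right format but the wrong answer, as in the first step.
same = group_relative_advantages([0.5] * 16)
print(same[:3])  # [0.0, 0.0, 0.0]

# One completion out of 16 is also correct: it is pushed up, the rest slightly down.
mixed = group_relative_advantages([2.5] + [0.5] * 15)
print(round(mixed[0], 2), round(mixed[1], 2))
```

With a weak starting model, many groups contain no correct answer at all, so a large share of steps may contribute little or nothing to training.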

In [9]:
grpo_model = AutoModelForCausalLM.from_pretrained(
    "./checkpoints/grpo/checkpoint-2680/",
    dtype=torch.bfloat16,
    device_map="auto",
)

evaluator = ModelEvaluator(grpo_model, system_prompt=SYSTEM_PROMPT_REASONING)
metrics = evaluator.eval(
    DATASET_NAME,
    metrics=["em"],
    batch_size=16,
    parse_output=True,
    oracle_model_name="deepseek-ai/DeepSeek-R1-0528",
)

for key, value in metrics.items():
    print(f"Metric: {key}, Score: {value:.4}")

Model eval: 100%|██████████| 32/32 [18:03<00:00, 33.87s/it]

Metric: em, Score: 0.182





In [30]:
data_id = 20
user_prompt = dataset['test'][data_id]['source_article']
answer = dataset['test'][data_id]['logical_fallacies']
response = get_logical_fallacy(grpo_model, user_prompt, SYSTEM_PROMPT_REASONING)
print(f"Prompt: {user_prompt}\nModel response: {response}\nCorrect answer: {answer}")

Prompt: Repeating the same thing as if you’re proving something
Model response: <reasoning>
We need to identify logical fallacy. The statement: "Repeating the same thing as if you're proving something". This is a circular reasoning? It's saying that repeating the same thing is proof, but then they claim it's not proof because they repeat it. That's circular reasoning (begging the question). So answer: Circular reasoning.
</reasoning>
<answer>
circular reasoning
</answer>

Correct answer: circular reasoning
