# Fine-tuning an LLM to reason

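In this notebook, we'll be:

* Loading in a relatively small pre-trained LLM (Llama 3.1 8B Instruct) and attaching LoRA adapters to it.
* Fine-tuning it with GRPO on GSM8K grade-school math problems so that it learns to solve reasoning problems and "think step-by-step" before answering.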

In [None]:
# this cell should take 2-3 minutes
# Pin the library versions this notebook was run with.
# NOTE(review): peft, transformers, huggingface_hub and matplotlib are imported in the next
# cell but not pinned here — presumably preinstalled or pulled in as dependencies; confirm.
# NOTE(review): nothing visible in this notebook imports bitsandbytes — confirm it is needed.
!pip install bitsandbytes==0.45.4 boto3==1.37.0 datasets==3.4.1 torch==2.6.0 trl==0.16.0

In [None]:
# this cell should take ~15 seconds
import boto3
import getpass
import io
import os
from os import path
import re
import shutil
import typing

import datasets
import huggingface_hub
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import peft
from peft import TaskType
import torch
from torch import cuda
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import trl

In [None]:
# Prompt for AWS credentials without echoing them, then export them as environment
# variables so the boto3 client created below picks them up.
AWS_ACCESS_KEY_ID = getpass.getpass("Enter your AWS_ACCESS_KEY_ID: ")
AWS_SECRET_ACCESS_KEY = getpass.getpass("Enter your AWS_SECRET_ACCESS_KEY: ")
os.environ.update(
    AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID,
    AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY,
)

# Bucket and client shared by every S3 upload/download in this notebook.
S3_BUCKET_NAME = "data-science-talks"
S3_CLIENT = boto3.client(service_name="s3")

In [None]:
# authentication required to read in the Llama model
# No token is passed, so login() presumably prompts for one interactively
# (a widget in notebooks) — the meta-llama repo loaded below needs it.
huggingface_hub.login()

In [None]:
# Local working directory for notebook artifacts. Only created when nothing
# exists at that path yet (an existing file or directory is left alone).
_data_dir = "data"
if not path.exists(_data_dir):
    os.mkdir(_data_dir)

## Helper functions

In [None]:
def upload_file_to_s3(
    file_path: str, object_key: str, bucket_name: str = S3_BUCKET_NAME
) -> None:
    """
    Copy a local file to ``s3://<bucket_name>/<object_key>`` and print where it went.

    Args:
        file_path: path of the local file to upload.
        object_key: destination key inside the bucket.
        bucket_name: destination bucket; defaults to the notebook's shared bucket.
    """
    destination = f"s3://{bucket_name}/{object_key}"
    S3_CLIENT.upload_file(Filename=file_path, Bucket=bucket_name, Key=object_key)
    print(f"Uploaded file {file_path} to {destination}")


class S3CheckpointCallback(transformers.TrainerCallback):
    """
    Trainer callback that forces periodic checkpoint saves and mirrors each one to S3.

    Every ``steps_interval`` global steps it asks the trainer to save. Once the trainer
    has written ``<output_dir>/checkpoint-<step>``, that folder is zipped into
    ``checkpoint_<step>.zip`` in the current working directory, uploaded to
    ``s3://<S3_BUCKET_NAME>/<s3_prefix>/checkpoint_<step>.zip``, and the local zip is
    removed.

    Args:
        steps_interval: save/upload cadence, in global steps.
        output_dir: directory the trainer writes checkpoints into; should match the
            ``output_dir`` of the training arguments.
        s3_prefix: S3 key prefix ("folder") the zips are uploaded under.
    """

    def __init__(
        self, steps_interval=10, output_dir="outputs", s3_prefix="my_lora_checkpoints"
    ):
        self.steps_interval = steps_interval
        self.output_dir = output_dir
        self.s3_prefix = s3_prefix

    def on_step_end(self, args, state, control, **kwargs):
        """
        Every steps_interval steps, we request the trainer to perform a checkpoint save.
        This triggers the on_save event below.
        """
        if state.global_step > 0 and (state.global_step % self.steps_interval == 0):
            control.should_save = True  # signals the trainer to save now

    def on_save(self, args, state, control, **kwargs):
        """
        Called once the trainer has actually saved the full checkpoint to
        `<output_dir>/checkpoint-<step>`.

        We zip that folder, upload it to S3, and delete the local zip (even if the
        upload raises, so failed attempts don't leave archives behind).
        """
        checkpoint_dir = path.join(self.output_dir, f"checkpoint-{state.global_step}")
        if not path.exists(checkpoint_dir):
            print(
                f"S3CheckpointCallback: no checkpoint directory found at {checkpoint_dir}, "
                "skipping."
            )
            return

        # Zip the checkpoint folder's contents (root_dir makes paths inside the
        # archive relative to the checkpoint folder).
        zip_base = f"checkpoint_{state.global_step}"
        zip_filename = f"{zip_base}.zip"
        shutil.make_archive(base_name=zip_base, format="zip", root_dir=checkpoint_dir)
        print(f"S3CheckpointCallback: Created {zip_filename} from {checkpoint_dir}")

        # Upload, always cleaning up the local archive afterwards.
        s3_key = f"{self.s3_prefix}/{zip_filename}"
        try:
            upload_file_to_s3(zip_filename, s3_key)
        finally:
            os.remove(zip_filename)
        print(
            f"S3CheckpointCallback: checkpoint {state.global_step} zipped+uploaded to "
            f"s3://{S3_BUCKET_NAME}/{s3_key}"
        )

## Training

In [None]:
max_seq_length = 256
lora_rank = 64

MODEL_NAME = "meta-llama/meta-Llama-3.1-8B-Instruct"

# Base causal LM in bfloat16.
# NOTE(review): `max_length` is forwarded to from_pretrained, presumably as a config
# override; it does not by itself truncate training sequences — confirm it is intended.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.bfloat16, max_length=max_seq_length
)

# Padding reuses the EOS token; no separate pad token is configured here.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

# LoRA adapters on the attention and MLP projection layers (Llama module names).
_lora_target_modules = [
    # attention projections
    "q_proj", "k_proj", "v_proj", "o_proj",
    # MLP projections
    "gate_proj", "up_proj", "down_proj",
]
# lora_alpha == r, so the LoRA update is scaled by alpha / r = 1.
lora_config = peft.LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=lora_rank,
    lora_alpha=lora_rank,
    lora_dropout=0.05,
    target_modules=_lora_target_modules,
)
model = peft.get_peft_model(model, lora_config)

In [None]:
# System message placed before every GSM8K question (see _get_gsm8k_questions).
# It shows one worked example and asks for the <reasoning>...</reasoning> then
# <answer>...</answer> layout that _format_reward_func matches and
# _extract_xml_answer parses — changing the tags here means changing those too.
SYSTEM_PROMPT = """
### EXAMPLE ###
Q: 3+2
<reasoning>
3 plus 2 is 5
</reasoning>
<answer>
5
</answer>

Now follow the same format EXACTLY for each question:

<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

def _extract_hash_answer(text: str) -> str:
    return text.split("####")[1].strip()

# from 
# https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb
def _extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()
    

def _parse_number(raw: str) -> typing.Optional[float]:
    raw = raw.strip().strip("$").replace(",", "")
    # remove trailing period, etc.
    raw = re.sub(r"[.!]+$", "", raw)
    try:
        return float(raw)
    except ValueError:
        return None



# from
# https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb
def _get_gsm8k_questions(split="train") -> datasets.Dataset:
    """
    Load one split of GSM8K ("main" config) and add chat prompts and gold answers.

    Each example gets a ``prompt`` column (a system message with SYSTEM_PROMPT
    followed by the question as the user message) and its ``answer`` column is
    replaced by the stripped text after ``####`` in the original solution.
    """

    def _to_prompt_and_answer(example):
        messages = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': example['question']},
        ]
        return {'prompt': messages, 'answer': _extract_hash_answer(example['answer'])}

    gsm8k = datasets.load_dataset('openai/gsm8k', 'main')  # type: ignore
    return gsm8k[split].map(_to_prompt_and_answer)  # type: ignore

# Training data: the GSM8K "train" split with chat prompts and gold answer strings.
dataset = _get_gsm8k_questions(split="train")



def _correctness_reward_func_length_penalty(
    completions: list[list[dict[str, str]]], answer: list[str], **kwargs
) -> list[float]:
    """
    Reward correct final answers, scaled by how much reasoning precedes them.

    For each completion, the text inside ``<answer>...</answer>`` is parsed as a
    number and compared with the matching gold answer. A correct answer earns up to
    5.0, scaled linearly by the character length of the stripped ``<reasoning>``
    section (full credit at 100+ characters). Reasoning shorter than 45 characters,
    a missing ``<reasoning>`` tag, or a wrong/unparseable answer scores 0.0.

    Args:
        completions: one entry per sampled completion; the first message's
            ``content`` holds the generated text.
        answer: gold answers, aligned index-for-index with ``completions``.
        **kwargs: any other columns passed in by the trainer; ignored.

    Returns:
        One float reward per completion, in the same order as ``completions``.
    """
    responses = [completion[0]["content"] for completion in completions]
    predicted_nums = [_parse_number(_extract_xml_answer(r)) for r in responses]
    gold_nums = [_parse_number(a) for a in answer]
    rewards = []
    for r, p, g in zip(responses, predicted_nums, gold_nums):
        if p is not None and g is not None and abs(p - g) < 1e-9:  # Answer is correct
            try:
                reasoning_text = r.split("<reasoning>")[1].split("</reasoning>")[0].strip()
                reasoning_length = len(reasoning_text)
                if reasoning_length < 45:
                    # Too little reasoning: zero for THIS completion only. Returning
                    # here would drop every other completion's reward and hand the
                    # trainer a scalar instead of a per-completion list.
                    reward = 0.0
                else:
                    length_factor = min(1.0, reasoning_length / 100.0)  # Scale up to 100 characters
                    reward = 5.0 * length_factor
            except IndexError:
                reward = 0.0  # Incorrect format prevents reward
        else:
            reward = 0.0  # Incorrect answer
        rewards.append(reward)
    return rewards

# from
# https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb
def _int_reward_func(completions: list[list[dict[str, str]]], **kwargs) -> list[float]:
    """
    Give 0.5 to each completion whose extracted <answer> text is non-empty and
    made only of digit characters (so negatives and decimals score 0.0).
    """
    rewards = []
    for completion in completions:
        answer_text = _extract_xml_answer(completion[0]['content'])
        rewards.append(0.5 if answer_text.isdigit() else 0.0)
    return rewards

def _format_reward_func(completions: list[list[dict[str, str]]], **kwargs) -> list[float]:
    pattern = r"^[\s]*<reasoning>[\s\S]+?</reasoning>\s*<answer>[\s\S]+?</answer>[\s]*$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [1 if match else 0.0 for match in matches]

In [None]:
class LogRewardsGRPOTrainer(trl.GRPOTrainer):
    """
    GRPOTrainer that additionally logs and prints its latest reward metrics.

    Whenever ``compute_loss`` runs at a non-zero global step that is a multiple of
    ``args.logging_steps``, the most recent entries of the trainer's
    ``_metrics["train"]`` store (overall reward, reward std, one sub-reward per
    reward function, completion length and, if present, KL) are passed to
    ``self.log`` and printed.

    NOTE(review): ``state.global_step`` only advances per optimizer step, so with
    gradient accumulation this presumably fires once per micro-batch of that step.
    NOTE(review): confirm how the parent trainer's own ``log`` treats ``_metrics``
    after this call (it may reset the running values it later averages).
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the GRPO loss as usual, then log the latest reward metrics on logging steps."""
        loss = super().compute_loss(model, inputs, return_outputs, num_items_in_batch)

        step = self.state.global_step
        # Align with logging_steps; step 0 is skipped.
        if not step or step % self.args.logging_steps != 0:
            return loss

        metrics = self._metrics["train"]  # Since we're in training

        # (key in our log, key in the trainer's metric store), in output order.
        key_pairs = [("reward", "reward"), ("reward_std", "reward_std")]
        key_pairs += [
            (f"subreward_{reward_func.__name__}", f"rewards/{reward_func.__name__}")
            for reward_func in self.reward_funcs
        ]
        key_pairs += [("completion_length", "completion_length"), ("kl", "kl")]

        logs = {}
        for log_key, metric_key in key_pairs:
            # Read with .get() so a key that hasn't been recorded yet is skipped
            # rather than inserted (as an empty entry) into the trainer's store.
            values = metrics.get(metric_key)
            if values:
                logs[log_key] = values[-1]  # Most recent value

        # Only log/print when there is something to report.
        if logs:
            self.log(logs)
            print(logs)
        return loss


In [None]:
def _resume_from_latest_s3_checkpoint(
    s3_prefix="test_lora_checkpoints",
    local_unzip_dir="resume_checkpoint",
    bucket_name: str = S3_BUCKET_NAME,
):
    """
    Download and unzip the highest-numbered ``checkpoint_<step>.zip`` under an S3 prefix.

    Keys are listed under ``s3://<bucket_name>/<s3_prefix>/`` (the layout that
    ``S3CheckpointCallback`` uploads to). The key ending in ``checkpoint_<step>.zip``
    with the largest ``<step>`` is downloaded and extracted into
    ``<local_unzip_dir>/checkpoint-<step>``. Any existing ``local_unzip_dir`` is
    deleted first.

    Args:
        s3_prefix: S3 key prefix ("folder") holding the checkpoint zips.
        local_unzip_dir: local directory to (re)create and extract into.
        bucket_name: bucket to read from.

    Returns:
        The local checkpoint directory, or None when nothing is stored under the
        prefix or none of the keys is a checkpoint zip.
    """
    # 1) List every key under the prefix. One list_objects_v2 call returns at most
    # 1000 keys, so page through the results. The trailing "/" keeps a prefix like
    # "run" from also matching keys under "run_old/".
    list_prefix = f"{s3_prefix.rstrip('/')}/" if s3_prefix else ""
    paginator = S3_CLIENT.get_paginator("list_objects_v2")
    keys = [
        obj["Key"]  # e.g. "test_lora_checkpoints/checkpoint_10.zip"
        for page in paginator.paginate(Bucket=bucket_name, Prefix=list_prefix)
        for obj in page.get("Contents", [])
    ]
    if not keys:
        print("No objects found under that S3 prefix!")
        return None

    # 2) Find the best (highest) checkpoint_X.zip
    pattern = re.compile(r"checkpoint_(\d+)\.zip$")
    max_step = None
    best_key = None
    for key in keys:
        match = pattern.search(key)
        if match:
            step = int(match.group(1))
            if max_step is None or step > max_step:
                max_step = step
                best_key = key

    if best_key is None:
        print("No checkpoint_XX.zip found on S3.")
        return None

    print(f"Found latest checkpoint => step={max_step}, key={best_key}")

    # 3) Download that single zip, then 4) unzip it into local_unzip_dir.
    local_zip = "latest_checkpoint.zip"
    local_ckpt_dir = path.join(local_unzip_dir, f"checkpoint-{max_step}")
    try:
        print(f"Downloading s3://{bucket_name}/{best_key} to {local_zip}")
        S3_CLIENT.download_file(Bucket=bucket_name, Key=best_key, Filename=local_zip)

        if path.exists(local_unzip_dir):
            shutil.rmtree(local_unzip_dir)
        os.makedirs(local_unzip_dir, exist_ok=True)

        print(f"Unzipping {local_zip} to {local_ckpt_dir}...")
        shutil.unpack_archive(local_zip, local_ckpt_dir)
    finally:
        # Remove the downloaded archive even if download or extraction failed.
        if path.exists(local_zip):
            os.remove(local_zip)

    print(f"Checkpoint placed in {local_ckpt_dir}")
    return local_ckpt_dir


In [None]:
# S3 "folder" that this run's checkpoint zips are uploaded to (and resumed from).
S3_PREFIX_DIR = "gsm8k_lora_reward_fixes_3_lora_64_checkpoints"

# Fetch and unzip the newest checkpoint under that prefix; None if there is none.
local_ckpt_dir = _resume_from_latest_s3_checkpoint(
    S3_PREFIX_DIR, local_unzip_dir="resume_checkpoint"
)

In [None]:
# Set to True for a brand-new run. Leave False to continue from the checkpoint the
# previous cell downloaded into `local_ckpt_dir` (if it found one).
are_we_at_the_start_of_the_training_run = False

training_args = trl.GRPOConfig(
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="adamw_torch",
    logging_steps=25,
    bf16=cuda.is_bf16_supported(),
    fp16=not cuda.is_bf16_supported(),
    per_device_train_batch_size=6,
    gradient_accumulation_steps=4,
    num_generations=6,
    max_prompt_length=256,
    max_completion_length=256,
    max_steps=2_500,
    save_strategy="steps",
    # NOTE(review): -1 presumably disables the built-in periodic save so that
    # S3CheckpointCallback alone decides when checkpoints are written — confirm.
    save_steps=-1,
    max_grad_norm=0.1,
    report_to="none",
    output_dir="outputs",
    seed=250217,
    gradient_checkpointing=True,
)

# Add custom callback to do the custom saving & S3 upload every 25 steps.
# It reuses the trainer's output_dir so it looks where checkpoints are written.
checkpoint_callback = S3CheckpointCallback(
    steps_interval=25, output_dir=training_args.output_dir, s3_prefix=S3_PREFIX_DIR
)

trainer = LogRewardsGRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        _correctness_reward_func_length_penalty,
        _format_reward_func,
        _int_reward_func,
    ],
    args=training_args,
    train_dataset=dataset,
)

trainer.add_callback(checkpoint_callback)

if are_we_at_the_start_of_the_training_run or local_ckpt_dir is None:
    # Fresh run, or nothing to resume from.
    trainer.train()
else:
    # Continue optimizer/scheduler/adapter state from the downloaded checkpoint.
    trainer.train(resume_from_checkpoint=local_ckpt_dir)


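Because of the `LogRewardsGRPOTrainer`, training prints lines like the following at each logging step (these are the actual lines from step 400 of the training run). Several lines share one step because the step counter only advances once per optimizer step, so each of the `gradient_accumulation_steps=4` micro-batches at that step prints its own line: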

```
{'reward': 3.1666667461395264, 'reward_std': 2.58198881149292, 'subreward_correctness_reward_func_length_penalty': 1.6666666269302368, 'subreward_format_reward_func': 1.0, 'subreward_int_reward_func': 0.5, 'completion_length': 157.1666717529297, 'kl': 0.006420954596251249}
{'reward': 3.1666667461395264, 'reward_std': 2.58198881149292, 'subreward_correctness_reward_func_length_penalty': 1.6666666269302368, 'subreward_format_reward_func': 1.0, 'subreward_int_reward_func': 0.5, 'completion_length': 144.83334350585938, 'kl': 0.008445960469543934}
{'reward': 3.1666667461395264, 'reward_std': 2.58198881149292, 'subreward_correctness_reward_func_length_penalty': 1.6666666269302368, 'subreward_format_reward_func': 1.0, 'subreward_int_reward_func': 0.5, 'completion_length': 132.5, 'kl': 0.006176165770739317}
{'reward': 6.5, 'reward_std': 0.0, 'subreward_correctness_reward_func_length_penalty': 5.0, 'subreward_format_reward_func': 1.0, 'subreward_int_reward_func': 0.5, 'completion_length': 134.33334350585938, 'kl': 0.005382485222071409}
```

## Plot of training loss

In [None]:
# S3 keys always use "/" as the separator (as in S3CheckpointCallback), so build
# the key explicitly rather than with os.path.join, whose separator is OS-specific.
loss_plot_key = f"{S3_PREFIX_DIR}/training_loss_plot.png"
response = S3_CLIENT.get_object(Bucket=S3_BUCKET_NAME, Key=loss_plot_key)
image_data = response['Body'].read()

img = mpimg.imread(io.BytesIO(image_data), format='png')
fig, ax = plt.subplots(figsize=(10, 6))
ax.imshow(img)
ax.axis('off')  # Hide the outer axes; the image is presumably a pre-rendered plot
plt.show()
