# Fine-tuning to reason: baseline assessment

In this notebook, we will:

* Load a small pre-trained LLM (Llama 3.1 8B Instruct).
* Generate outputs that let us assess its baseline performance on the GSM8K dataset.

In [None]:
# this cell should take 2-3 minutes
!pip install bitsandbytes==0.45.4 boto3==1.37.0 datasets==3.4.1 torch==2.6.0

In [None]:
# this cell should take ~15 seconds
import boto3
from datetime import datetime
import getpass
import json
import os
from os import path
import random
import re
import string
import typing

import datasets
import huggingface_hub
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
# Prompt for the AWS credentials at runtime so they never appear in the notebook source.
AWS_ACCESS_KEY_ID = getpass.getpass("Enter your AWS_ACCESS_KEY_ID: ")
AWS_SECRET_ACCESS_KEY = getpass.getpass("Enter your AWS_SECRET_ACCESS_KEY: ")

# Export them as environment variables before the S3 client below is created.
os.environ.update(
    {
        "AWS_ACCESS_KEY_ID": AWS_ACCESS_KEY_ID,
        "AWS_SECRET_ACCESS_KEY": AWS_SECRET_ACCESS_KEY,
    }
)

# Bucket that stores the per-question evaluation records.
S3_BUCKET_NAME = "data-science-talks"
S3_CLIENT = boto3.client("s3")

# Top-level JSON containers accepted by the S3 helpers below.
JSONable = typing.Union[dict, list]

In [None]:
# Local scratch directory for the JSON records written before each S3 upload
# (see `generate_answers`). `exist_ok=True` makes re-running this cell a no-op
# and avoids the check-then-create race of `path.exists` followed by `os.mkdir`.
os.makedirs("data", exist_ok=True)

In [None]:
# Authenticate with the Hugging Face Hub; this is required to download the
# Llama model weights (MODEL_NAME below).
huggingface_hub.login()

## Helper functions

In [None]:
def _random_three_alphanumeric():
    # 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
    chars = string.ascii_letters + string.digits
    return ''.join(random.choices(chars, k=3))


In [None]:
def upload_jsonable_object_to_s3(
    obj: JSONable, object_key: str, temp_filepath: str, bucket_name: str = S3_BUCKET_NAME
) -> None:
    """Write ``obj`` as indented JSON to a local file, then upload that file to S3.

    Args:
        obj: The dict/list to serialise.
        object_key: S3 key without extension; ``.json`` is appended.
        temp_filepath: Local path the JSON is written to before uploading.
            The file is not removed afterwards.
        bucket_name: Destination bucket (default: ``S3_BUCKET_NAME``).
    """
    target_key = f"{object_key}.json"

    print(f"Dumping object temporarily to {temp_filepath}")
    with open(temp_filepath, "w") as tmp_file:
        json.dump(obj, tmp_file, indent=2)

    S3_CLIENT.upload_file(
        Filename=temp_filepath,
        Bucket=bucket_name,
        Key=target_key,
    )
    print(f"Uploaded object to S3 as {target_key}")


def download_object_from_s3(object_key: str, bucket_name: str = S3_BUCKET_NAME) -> JSONable:
    """Fetch ``"<object_key>.json"`` from ``bucket_name`` and return the parsed JSON.

    Counterpart of ``upload_jsonable_object_to_s3``: the ``.json`` suffix is
    appended to ``object_key`` in the same way.

    Args:
        object_key: S3 key without the ``.json`` extension.
        bucket_name: Source bucket (default: ``S3_BUCKET_NAME``).

    Returns:
        The decoded JSON document.

    Raises:
        json.JSONDecodeError: If the stored object is not valid JSON.
    """
    object_name = f"{object_key}.json"
    response = S3_CLIENT.get_object(Bucket=bucket_name, Key=object_name)
    print(f"Reading {object_name} into Python from S3")
    # json.loads accepts the raw bytes directly, so no explicit decode is needed.
    return json.loads(response["Body"].read())


Display one object from S3:

```python
import textwrap

def _print_wrapping_as_list_of_lines(data: dict) -> str:
    # Wrap 'question' as a list of lines
    if "question" in data and data["question"]:
        question_wrapped = textwrap.wrap(data["question"], width=80)
        data["question"] = question_wrapped

    # Wrap 'raw_text' as a list of lines
    for entry in data.get("model_outputs", []):
        if "raw_text" in entry and entry["raw_text"]:
            wrapped_lines = textwrap.wrap(entry["raw_text"], width=99)
            entry["raw_text"] = wrapped_lines

    # Return as pretty-printed JSON
    return json.dumps(data, indent=2)

# Fetch one stored evaluation record and show it with its long text fields wrapped.
data = download_object_from_s3("gsm8k_llama_7b_100_record_test_1")
s = _print_wrapping_as_list_of_lines(data)
print(s)
```

## Loading the Llama model

In [None]:
# Passed to `from_pretrained` below as `max_length`.
# NOTE(review): presumably this only sets a default `max_length` on the model
# config; `generate_answers` passes `max_new_tokens` explicitly — confirm the
# kwarg is still needed.
max_seq_length = 256
# NOTE(review): not used in the visible code; presumably reserved for a later
# LoRA fine-tuning step.
lora_rank = 64

MODEL_NAME = "meta-llama/meta-Llama-3.1-8B-Instruct"

# Load the weights in bfloat16. No device placement happens here;
# `generate_answers` moves the model to the GPU.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    max_length=max_seq_length,
    torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse EOS as the padding token (`generate_answers` tokenizes with padding=True).
# Presumably the Llama tokenizer ships without a pad token — confirm.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# GSM8K training split. The visible evaluation code reads the test split instead.
# NOTE(review): this rebinds `data`, which the S3 display snippet above also uses.
data = datasets.load_dataset('openai/gsm8k', 'main')["train"]

In [None]:
def _extract_hash_answer(text: str) -> str:
    return text.split("####")[1].strip()

Script to measure the token lengths of the gold (reference) solutions:

```python
import pandas as pd


# Measure the token length of the reference GSM8K solutions for the first 100
# questions of the seed-250217 shuffled test split (presumably to sanity-check
# the generation length budget).
gsm8k_data = datasets.load_dataset('openai/gsm8k', 'main')
test_set = gsm8k_data['test']
test_set_shuffled = test_set.shuffle(seed=250217)

subset = test_set_shuffled.select(range(100))

gold_lens = []
for i, sample in enumerate(subset):
    gold_raw = sample["answer"]
    gold_answer = _extract_hash_answer(gold_raw)

    # Tokenize the full reference solution (reasoning + "#### answer"), not just
    # the extracted `gold_answer`.
    tok = tokenizer(gold_raw, return_tensors="pt", padding=False, truncation=False)

    # Count the tokens. This includes any special tokens the tokenizer adds by
    # default (e.g. BOS) — TODO confirm whether those should be counted.
    num_tokens = len(tok["input_ids"][0])

    gold_lens.append(
        {"index": i, "gold_answer": gold_answer, "token_length": num_tokens}
    )

# Convert to a Pandas DataFrame
df = pd.DataFrame(gold_lens)
print(df.head(30))
```

### Checking baseline performance

Goal: check how many GSM8K test questions Llama 3.1 8B Instruct answers correctly.

In [None]:
def _is_valid_number(candidate: str) -> bool:
    """
    Check if 'candidate' matches one of the patterns:
    - plain integer (e.g. '39')
    - dollar + integer (e.g. '$39')
    - comma-separated integer (e.g. '43,500')
    """
    candidate = candidate.strip()
    # e.g. remove leading/trailing '$', but handle the case if it starts with $
    # We'll use a regex approach to handle commas or $ sign:
    pattern = r"^\$?\d{1,3}(,\d{3})*(\.\d+)?$"
    return bool(re.match(pattern, candidate))


# from 
# https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb
def _extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()


def _score_prediction(predicted_answer: str, gold_answer: str) -> typing.Tuple[bool, bool]:
    """Score one extracted model answer against the gold answer.

    Args:
        predicted_answer: Text extracted from the model's ``<answer>`` tags.
        gold_answer: The answer following ``####`` in the reference solution.

    Returns:
        ``(valid_number, correct)``. ``correct`` can only be True when the
        prediction passes ``_is_valid_number``. Both sides have ``$`` and ``,``
        removed and are then compared as strings.
    """
    valid_number = _is_valid_number(predicted_answer)
    if not valid_number:
        return False, False
    pred_clean = predicted_answer.replace("$", "").replace(",", "").strip()
    gold_clean = gold_answer.replace("$", "").replace(",", "").strip()
    # NOTE(review): plain string equality, so e.g. "18.00" does not match a gold "18".
    return True, pred_clean == gold_clean


def generate_answers(
    model,
    tokenizer,
    system_prompt: str,
    dataset_subset,
    s3_prefix: str,
    num_attempts: int = 5,
    max_length_prompt: int = 256,
    max_new_tokens: int = 256,
    start_index: int = 0,
    end_index: typing.Optional[int] = None,
    generation_kwargs: typing.Optional[dict] = None,
    device: str = "cuda",
):
    """
    Generate multiple answers for each sample in a dataset subset using a specified model,
    and upload the results to S3.

    The steps for every sample with index in ``[start_index, end_index)`` are:
    - Render the question with ``system_prompt`` through the tokenizer's chat template.
    - Sample ``num_attempts`` completions.
    - Score each completion's ``<answer>...</answer>`` content against the gold answer.
    - Write one JSON record under ``data/`` and upload it to S3 as
      ``f"{s3_prefix}{i + 1}.json"`` (1-based key).

    Args:
        model: The model to use for generation (e.g., baseline model or fine-tuned LoRA model).
        tokenizer: The tokenizer corresponding to the model.
        system_prompt: The system prompt to prepend to each question.
        dataset_subset: The dataset subset to process (e.g., 100 shuffled GSM8K test samples).
        s3_prefix: Prefix for S3 object keys (e.g., 'gsm8k_llama_7b_100_record_test_').
        num_attempts: Number of answer attempts per sample (default: 5).
        max_length_prompt: Maximum token length for the input prompt (default: 256).
        max_new_tokens: Maximum new tokens to generate (default: 256).
        start_index: Index to start processing from (default: 0, useful for resuming).
        end_index: Index to stop before, exclusive (default: None, meaning the end of
            ``dataset_subset``). Values beyond the dataset length are clipped.
        generation_kwargs: Extra keyword arguments for ``model.generate``. They override
            the defaults (``max_new_tokens``, ``do_sample=True``, ``temperature=0.8``,
            ``top_p=0.95``).
        device: Device the model and the tokenized prompts are moved to (default: "cuda").
    """
    model.eval()
    model.to(device)

    # Defaults first, so caller-supplied kwargs override them instead of raising
    # "got multiple values for keyword argument".
    sampling_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": 0.8,
        "top_p": 0.95,
        **(generation_kwargs or {}),
    }

    dataset_len = len(dataset_subset)
    start_index = max(0, start_index)
    end_index = dataset_len if end_index is None else min(dataset_len, end_index)

    for i in range(start_index, end_index):
        sample = dataset_subset[i]
        question_text = sample["question"]
        gold_answer = _extract_hash_answer(sample["answer"])

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question_text},
        ]
        full_prompt_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # The rendered template already contains the special tokens (e.g. BOS),
        # so the tokenizer must not add them a second time.
        # NOTE(review): truncation removes tokens from the tokenizer's truncation
        # side (right by default), which would drop the generation prompt if a
        # prompt ever exceeds max_length_prompt.
        inputs = tokenizer(
            full_prompt_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length_prompt,
            add_special_tokens=False,
        ).to(device)
        prompt_len = inputs["input_ids"].shape[1]

        sampled_answers = []
        for _attempt in range(num_attempts):
            with torch.no_grad():
                outputs = model.generate(**inputs, **sampling_kwargs)
            # Keep only the newly generated tokens (everything after the prompt).
            completion_text = tokenizer.decode(
                outputs[0, prompt_len:], skip_special_tokens=True
            )

            predicted_answer = _extract_xml_answer(completion_text)
            valid_number, correct = _score_prediction(predicted_answer, gold_answer)
            sampled_answers.append(
                {
                    "raw_text": completion_text,
                    "predicted_answer": predicted_answer,
                    "valid_number": valid_number,
                    "correct": correct,
                }
            )

        record = {
            "index": i,
            "question": question_text,
            "gold_answer": gold_answer,
            "model_outputs": sampled_answers,
        }

        # A random prefix plus a timestamp makes local dump filenames unlikely to collide.
        timestamp_str = datetime.now().strftime("%b-%d-%Y_%I-%M-%S%p")
        filename = f"{_random_three_alphanumeric()}_{timestamp_str}.json"
        # Upload to S3 with a 1-based index in the key.
        upload_jsonable_object_to_s3(
            record, f"{s3_prefix}{i + 1}", temp_filepath=path.join("data", filename)
        )


In [None]:
gsm8k_data = datasets.load_dataset('openai/gsm8k', 'main')
test_set = gsm8k_data['test']
# Fixed seed so every run evaluates the same shuffled ordering of the test split.
test_set_shuffled = test_set.shuffle(seed=250217)

# Number of shuffled test questions to evaluate. This matches the "100_record"
# S3 prefix below; set it to None to run the whole test split.
N_SAMPLES = 100

SYSTEM_PROMPT = """
### EXAMPLE ###
Q: 3+2
<reasoning>
3 plus 2 is 5
</reasoning>
<answer>
5
</answer>

Now follow the same format EXACTLY for each question:

<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

PREFIX_NEW = "gsm8k_llama_7b_100_record_test_new_prompt_3_"

generate_answers(
    model=model,
    tokenizer=tokenizer,
    system_prompt=SYSTEM_PROMPT,
    dataset_subset=test_set_shuffled,
    s3_prefix=PREFIX_NEW,
    end_index=N_SAMPLES,
)