In [6]:
# imports 
import random
import copy
import re
import os
import sys
import numpy as np
import wandb
from dotenv import load_dotenv

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

from DGXutils import GetLowestGPU

def set_random_seed(seed: int=42):
    """
    Seed every random number generator this notebook relies on.

    Seeds python's `random` module, numpy's global generator and pytorch
    (plus all CUDA devices when CUDA is available), and configures cuDNN
    for deterministic behavior.

    Args:
        seed (int): random seed value

    Returns:
        None
    """

    # deterministic cuDNN kernels, no benchmark mode (may impact performance)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # python, numpy and pytorch generators all receive the same seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # CUDA generators are only seeded when a GPU is actually available
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# set the seed
set_random_seed(42)

# set wandb logging variables
# load_dotenv() pulls values from a local .env file (if any) into the environment
load_dotenv()
for _wandb_var in ("WANDB_API_KEY", "WANDB_PROJECT"):
    _wandb_value = os.getenv(_wandb_var)
    # fail with an actionable message instead of the opaque
    # "TypeError: str expected, not NoneType" that os.environ[...] = None raises
    if _wandb_value is None:
        raise RuntimeError(
            f"{_wandb_var} is not set; define it in the environment or in a .env file"
        )
    os.environ[_wandb_var] = _wandb_value


# Formatting + Answer Extraction

In [None]:
# System prompt placed ahead of every question. It is assembled from adjacent
# string literals so that source indentation cannot leak into the prompt text
# (a backslash continuation inside a triple-quoted string keeps the next line's
# leading spaces).
SYSTEM_PROMPT = (
    "You are a helpful assistant. A conversation between User and Assistant. "
    "The user asks a question, and the Assistant solves it. "
    "The Assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
    "The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, "
    "respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."
)

def extract_answer_from_model_output(text):
    """
    Pull the answer out of the final <answer>...</answer> pair in a model response.

    Args:
        text (str): model-generated text that may contain XML-style <answer> tags

    Returns:
        str or None: whitespace-stripped content of the last <answer> tag; None when
        there is no <answer> tag, when the last one is never closed, or when its
        content is the literal placeholder "..."
    """

    # everything after the final opening tag; an empty separator means no tag at all
    _, opening, tail = text.rpartition("<answer>")
    if not opening:
        return None

    # keep only what precedes the first closing tag that follows
    body, closing, _ = tail.partition("</answer>")
    if not closing:
        return None

    candidate = body.strip()
    if candidate == "...":
        return None
    return candidate

def extract_answer_from_dataset(text):
    """
    Pull the final answer out of a gsm8k answer string.

    Args:
        text (str): dataset answer text, expected to contain a '####' delimiter

    Returns:
        str or None: the stripped text that follows the first '####' delimiter (up to
        the next delimiter, if any), or None if the delimiter is absent
    """
    segments = text.split('####')

    # a single segment means the delimiter never appeared
    if len(segments) == 1:
        return None
    return segments[1].strip()


# Dataset Preparation

In [15]:
def build_prompt(messages):
    """
    Flatten a list of chat messages into one prompt string.

    Only the "content" of each message is used; roles are ignored.

    Args:
        messages (list): message dictionaries, each with "role" and "content" keys

    Returns:
        str: the stripped message contents joined by newlines, in order
    """
    stripped_contents = []
    for message in messages:
        stripped_contents.append(message["content"].strip())
    return "\n".join(stripped_contents)


def prepare_dataset(example):
    """
    Convert a single gsm8k example into a prompt/answer pair for training.

    Args:
        example (dict): one dataset example with "question" and "answer" keys

    Returns:
        dict: a dictionary with two keys:
            "prompt": the system prompt followed by the question, as one string
            "answer": the final answer extracted from the example's "answer" text
    """

    # system instructions first, then the question itself
    chat = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": example["question"]},
    ]
    return {
        "prompt": build_prompt(chat),
        "answer": extract_answer_from_dataset(example["answer"]),
    }

# load the gsm8k training split
gsm8k = load_dataset("openai/gsm8k", "main")["train"]

# format every example, then drop the raw question (it now lives inside "prompt")
data = (
    gsm8k
    .map(prepare_dataset)
    .remove_columns(["question"])
)

# Evaluation Functions

In [17]:
def extract_last_number(text):
    """
    Extracts the number that ends the text.

    Args:
        text (str): the text to extract a number from

    Returns:
        float or None: the number at the end of the text (trailing whitespace is
        allowed), or None if the text does not end with a number.
    """

    # remove $, % from text so "$18" and "50%" are read as plain numbers
    text = text.replace('$', '').replace('%', '')

    # regex to find an int or decimal appearing at the end of the text; the number
    # must start the text or directly follow whitespace or '=' (optionally padded
    # with more whitespace), so digits glued to a word are not picked up
    pattern = r'(?:^|\s|=)\s*(-?\d*\.?\d+)\s*$'
    match = re.search(pattern, text)
    return float(match.group(1)) if match else None

def extract_single_number(text):
    """
    Return the only number contained in the text.

    Args:
        text (str): The text to extract a number from.

    Returns:
        float or None: the number as a float when the text contains exactly one,
        otherwise None (no numbers, or more than one)
    """

    # collect every int/decimal (optionally negative) found in the text
    found = re.findall(r'-?\d*\.?\d+', text)

    # anything other than exactly one hit is ambiguous
    if len(found) != 1:
        return None
    (only_number,) = found
    return float(only_number)

def evaluate_model(model, tokenizer, eval_examples, device):
    """
    Evaluates the model on a set of examples and prints the detailed results.

    The model is switched to eval mode while generating and is put back in train
    mode before returning.

    Args:
        model: the language model to evaluate
        tokenizer: tokenizer for encoding inputs, decoding outputs
        eval_examples (list): list of evaluation examples, each containing a "prompt" and "answer"
        device: the device to run the model on

    Returns:
        float: accuracy percentage (correct predictions / total examples * 100),
        or 0.0 when eval_examples is empty
    """

    # initialize variables
    model.eval()
    correct = 0
    total = len(eval_examples)
    print("\n" + "="*50)
    print(f"EVALUATING ON {total} EXAMPLES")
    print("="*50)

    for example in eval_examples:
        # get prompt, expected answer
        full_prompt = example["prompt"]
        expected = example["answer"]

        # tokenize and generate response
        inputs = tokenizer.encode(full_prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=512,
                temperature=0.7,
                num_return_sequences=1,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                forced_eos_token_id=tokenizer.eos_token_id,
                early_stopping=False
                )
        # NOTE(review): if outputs[0] also holds the prompt tokens, the <answer>
        # example inside the prompt may be extracted when the model emits no
        # tags of its own -- confirm and slice off the prompt if so
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        try:
            # extract answer and check correctness
            predicted = extract_answer_from_model_output(response)

            # a missing answer on either side can never count as a match
            if predicted is None or expected is None:
                is_correct = False

            # exact match
            elif predicted == expected:
                is_correct = True

            # single number matching
            else:
                pred_num = extract_single_number(str(predicted))
                exp_num = extract_single_number(str(expected))
                if pred_num is not None and exp_num is not None and pred_num == exp_num:
                    is_correct = True
                else:
                    # try last number matching
                    pred_num = extract_last_number(str(predicted))
                    exp_num = extract_last_number(str(expected))
                    is_correct = (pred_num is not None and exp_num is not None and pred_num == exp_num)

            # update counter if correct
            if is_correct:
                correct += 1

            # print results
            print("\nPrompt:")
            print(full_prompt)
            print("\nExpected Answer:")
            print(expected)
            print("\nExtracted Answer:")
            print(predicted)
            print("\nFull Generated Response:")
            print(response)
            # label and verdict are separate arguments so the label is printed
            # for incorrect answers too
            print("\nCorrect:", "YES" if is_correct else "NO")
            print("-"*50)

        except Exception as e:
            print("\nFailed to parse model output for prompt:")
            print(full_prompt)
            print("Error:", e)
            print("-"*50)

    # calculate and print final accuracy (an empty evaluation set scores 0.0)
    accuracy = (correct / total * 100) if total else 0.0
    print(f"\nAccuracy: {accuracy:.2f}% ({correct}/{total})")
    print("="*50)

    # put back in train mode
    model.train()
    return accuracy

# Reward Functions

In [None]:
def correctness_reward(prompts, completions, answer, **kwargs):
    """
    Assigns a reward based on the correctness of the model's answer.

    Args:
        prompts (list): list of input prompts (not used for scoring)
        completions (list): list of model-generated completions; the text of each is
            read from completion[0]['content']
        answer (list): list of expected answers
        **kwargs: additional keyword arguments (ignored)

    Returns:
        list: a list of numerical rewards for each completion

    Rewards:
        2.0 points for an exact match
        1.5 points for numeric equivalence (values match but format differs)
        0.0 points for incorrect answers, including completions with no extractable answer
    """

    # extract answers from model completions
    responses = [completion[0]['content'] for completion in completions]
    extracted = [extract_answer_from_model_output(r) for r in responses]
    rewards = []
    for r, a in zip(extracted, answer):
        # exact match; a missing extraction never counts, even if the reference is also None
        if r is not None and r == a:
            rewards.append(2.0)
        # try numeric equivalence
        else:
            r_num = extract_single_number(str(r))
            a_num = extract_single_number(str(a))
            if r_num is not None and a_num is not None and r_num == a_num:
                rewards.append(1.5)
            else:
                rewards.append(0.0)
    return rewards

def format_reward(completions, **kwargs):
    """
    Assigns a reward for adhering to the desired XML format.

    Args:
        completions (list): list of model completions; the text of each is read
            from completion[0]['content']
        **kwargs: additional keyword arguments (ignored)

    Returns:
        list: a list of format compliance scores for each completion

    Rewards:
        0.2 points for each tag present (<think>, </think>, <answer>, </answer>)
        0.8 points maximum score
    """

    # tags are checked for presence only; order and pairing are not validated
    required_tags = ("<think>", "</think>", "<answer>", "</answer>")

    # extract responses
    responses = [completion[0]['content'] for completion in completions]

    # score responses
    rewards = []
    for response in responses:
        score = 0.0
        for tag in required_tags:
            if tag in response:
                score += 0.2
        rewards.append(score)
    return rewards

def combined_reward(prompts, completions, answer):
    """
    Sum the correctness reward and the format reward of every completion.

    Args:
        prompts (list[str]): list of prompt texts
        completions (list[list[dict]]): list of model completions
        answer (list[str]): list of expected answers

    Returns:
        list[float]: one combined reward per completion

    Rewards:
        Correctness score range: 0.0 to 2.0
        Format score range: 0.0 to 0.8
        Total possible range: 0.0 to 2.8
    """

    # pair each completion's correctness score with its format score
    score_pairs = zip(
        correctness_reward(prompts=prompts, completions=completions, answer=answer),
        format_reward(completions=completions),
    )

    # add the two components together
    return [correct_part + format_part for correct_part, format_part in score_pairs]