In [6]:
# imports 
import random
import copy
import re
import os
import sys
import numpy as np
import wandb
from dotenv import load_dotenv

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

from DGXutils import GetLowestGPU

def set_random_seed(seed: int = 42):
    """
    Seed every RNG used in this notebook so runs are reproducible.

    Seeds Python's `random`, NumPy's global RNG, PyTorch (CPU and, when available,
    all CUDA devices), and switches cuDNN into deterministic, non-benchmarking mode.

    Args:
        seed (int): the seed value applied to every generator (default 42)

    Returns:
        None
    """
    # Python stdlib and NumPy global generators take the seed the same way
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)

    # PyTorch CPU generator, then every visible GPU when CUDA is present
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for run-to-run determinism
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

# set the seed
set_random_seed(42)

# wandb configuration: load_dotenv() copies the entries of a local .env file into
# os.environ (without overriding variables that are already set), so no manual
# re-assignment is needed. Assigning os.environ[k] = os.getenv(k) would raise
# TypeError whenever k is missing, because os.getenv returns None.
load_dotenv()
_missing_wandb_vars = [name for name in ("WANDB_API_KEY", "WANDB_PROJECT") if not os.getenv(name)]
if _missing_wandb_vars:
    print(
        f"warning: {', '.join(_missing_wandb_vars)} not set in the environment or .env; "
        "wandb will fall back to its own login/defaults",
        file=sys.stderr,
    )


# Formatting + Answer Extraction

In [None]:
# System prompt prepended to every question. It is built with implicit string
# concatenation so that no stray indentation or line-continuation whitespace ends up
# inside the prompt text the model sees.
SYSTEM_PROMPT = (
    "You are a helpful assistant. A conversation between User and Assistant. "
    "The user asks a question, and the Assistant solves it. "
    "The Assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
    "The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, "
    "respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."
)

def extract_answer_from_model_output(text):
    """
    Return the contents of the last <answer>...</answer> pair in a model completion.

    Args:
        text (str): model-generated text that may contain XML-style <answer> tags

    Returns:
        str or None: the stripped text between the last "<answer>" and the first
        "</answer>" after it. None when the text has no "<answer>" tag, when no
        "</answer>" follows the last "<answer>", or when the extracted text is
        exactly "..." (treated as no answer).
    """
    _, opening_tag, after_last_open = text.rpartition("<answer>")
    if not opening_tag:
        return None  # the text never opens an <answer> tag

    if "</answer>" not in after_last_open:
        return None  # the last <answer> is never closed

    candidate = after_last_open.partition("</answer>")[0].strip()
    if candidate == "...":
        return None
    return candidate

def extract_answer_from_dataset(text):
    """
    Pull the final answer out of a gsm8k "answer" field.

    The final answer is marked by a '####' delimiter.

    Args:
        text (str): the raw gsm8k answer text (worked solution, then '####' and the answer)

    Returns:
        str or None: the stripped text between the first '####' and the next '####'
        (or the end of the string), or None when '####' does not occur
    """
    segments = text.split('####')
    return segments[1].strip() if len(segments) > 1 else None


# Dataset Preparation

In [15]:
def build_prompt(messages):
    """
    Flatten a chat-style message list into one prompt string.

    Only each message's "content" is used; the "role" key is ignored. Every content
    string is stripped of surrounding whitespace, and the results are joined with
    newlines in list order.

    Args:
        messages (list): message dicts, each with "role" and "content" keys

    Returns:
        str: the stripped contents joined by "\\n" ("" for an empty list)
    """
    stripped_contents = (message["content"].strip() for message in messages)
    return "\n".join(stripped_contents)


def prepare_dataset(example):
    """
    Format a single gsm8k example for training with a plain-string prompt.

    Used with `Dataset.map`, which calls it once per row. The returned "answer"
    replaces the row's raw gsm8k answer text, and "prompt" is added as a new column.

    Args:
        example (dict): one gsm8k row with a "question" string and an "answer"
            string whose final answer follows a '####' delimiter

    Returns:
        dict: {"prompt": SYSTEM_PROMPT and the question joined by a newline
               (via build_prompt), "answer": the text after '####', stripped,
               or None if the delimiter is missing}
    """
    prompt_str = build_prompt([
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": example["question"]},
    ])
    return {
        "prompt": prompt_str,
        "answer": extract_answer_from_dataset(example["answer"]),
    }

# build the gsm8k training split: prepare_dataset adds a "prompt" column and
# replaces "answer" with the extracted final answer; the raw "question" column is
# dropped as part of the same map call
gsm8k = load_dataset("openai/gsm8k", "main")["train"]
data = gsm8k.map(prepare_dataset, remove_columns=["question"])

# Evaluation Functions

In [None]:
def extract_last_number(text):
    """
    Extract the number that ends the text.

    '$' and '%' are removed first, so "$18" and "50%" read as 18 and 50. The number
    must be the final token. It must be either at the very start of the text or
    preceded by whitespace or '=', and only whitespace may follow it. Text ending in
    anything else (e.g. "42." or "42 apples") yields None.

    Args:
        text (str): the text to extract a number from

    Returns:
        float or None: the trailing number (integer or decimal, optionally negative),
        or None if the text does not end in one.
    """
    # remove $, % so currency / percentage answers still end in a bare number
    text = text.replace('$', '').replace('%', '')

    # a single boundary (start of text, whitespace, or '=') followed by an optionally
    # signed int/decimal at the very end. The boundary must be ONE alternation:
    # requiring two boundary characters in a row would reject inputs like "is 42" or "42".
    pattern = r'(?:^|\s|=)\s*(-?\d*\.?\d+)\s*$'
    match = re.search(pattern, text)
    return float(match.group(1)) if match else None

def extract_single_number(text):
    """
    Return the only number in the text, if there is exactly one.

    Numbers are optionally negative integers or decimals (e.g. "12", "-0.5", ".5").

    Args:
        text (str): the text to scan

    Returns:
        float or None: the number as a float when exactly one is present, otherwise
        None (no numbers, or more than one)
    """
    candidates = re.findall(r'-?\d*\.?\d+', text)
    if len(candidates) != 1:
        return None  # zero or ambiguous (several) numbers
    return float(candidates[0])

def evaluate_model(model, tokenizer, eval_examples, device):
    """
    Evaluates the model on a set of examples and prints the detailed results.

    Args:
        model: the language model to evaluate
        tokenizer: tokenizer for encoding inputs, decoding outputs
        eval_examples (list): list of evaluation examples, each containing a "prompt" and "answer"
        device: the device to run the model on

    Returns:
        float: accuracy percentage (correct predictions / total examples * 100)
    """