In [None]:
import os

from predibase import (
    Predibase,
    GRPOConfig,
    RewardFunctionsConfig,
    RewardFunctionsRuntimeConfig,
    SFTConfig,
    SamplingParamsConfig,
)
from datasets import load_dataset
from dotenv import load_dotenv

In [None]:
# Read the .env file one directory above the notebook so the Predibase
# credentials it defines are available through os.environ.
load_dotenv("../.env")

# Build the Predibase client used by every later cell. Indexing os.environ
# directly means a missing PREDIBASE_API_KEY raises KeyError right here.
pb = Predibase(
    api_token=os.environ["PREDIBASE_API_KEY"],
)

In [None]:
# Load the training split of the Wordle GRPO dataset from HuggingFace and
# convert it to a pandas DataFrame. It gets its own name so that `dataset`
# below only ever refers to the Predibase dataset handle, never to the raw
# frame (which a later cell could otherwise pick up after a failed upload).
wordle_df = load_dataset("predibase/wordle-grpo", split="train").to_pandas()

# Upload the DataFrame to Predibase under a fixed name. If the upload raises
# (presumably because "wordle_grpo_data" already exists from an earlier run
# -- TODO confirm), fall back to fetching the dataset stored under that name.
try:
    dataset = pb.datasets.from_pandas_dataframe(
        wordle_df,
        name="wordle_grpo_data",
    )
except Exception:
    dataset = pb.datasets.get("wordle_grpo_data")

In [None]:
# Import reward functions
from reward_functions import (
    guess_value,
    output_format_check,
    uses_previous_feedback,
)

In [None]:
# Reward setup: the three reward callables imported from `reward_functions`,
# keyed by name, plus a runtime config listing `pandas` as a package
# (presumably installed where the reward functions execute -- confirm
# against the Predibase docs).
reward_fns_config = RewardFunctionsConfig(
    runtime=RewardFunctionsRuntimeConfig(packages=["pandas"]),
    functions={
        "output_format_check": output_format_check,
        "uses_previous_feedback": uses_previous_feedback,
        "guess_value": guess_value,
    },
)

# GRPO training configuration: base model, the reward setup above, sampling
# capped at max_tokens=4096, and num_generations=16.
grpo_config = GRPOConfig(
    base_model="qwen2-5-7b-instruct",
    reward_fns=reward_fns_config,
    sampling_params=SamplingParamsConfig(max_tokens=4096),
    num_generations=16,
)

# Create the GRPO fine-tuning job in Predibase from the config, the uploaded
# dataset, and the target repository. Left as the cell's final expression so
# the notebook displays whatever the call returns.
pb.finetuning.jobs.create(
    config=grpo_config,
    dataset=dataset,
    repo="wordle",
    description="Wordle GRPO",
)