# Lesson 8: Putting it all together - Training Wordle

<div style="background-color:#fff6ff; padding:13px; border-width:3px; border-color:#efe6ef; border-style:solid; border-radius:6px">
<p> 💻 &nbsp; <b>Access <code>requirements.txt</code> file:</b> 1) click on the <em>"File"</em> option on the top menu of the notebook and then 2) click on <em>"Open"</em>.</p>

<p> ⬇ &nbsp; <b>Download Notebooks:</b> 1) click on the <em>"File"</em> option on the top menu of the notebook and then 2) click on <em>"Download as"</em> and select <em>"Notebook (.ipynb)"</em>.</p>

<p> 📒 &nbsp; For more help, please see the <em>"Appendix – Tips, Help, and Download"</em> Lesson.</p>

</div>

Import dependencies and set up the Predibase client for training:

In [1]:
import os

from predibase import (
    Predibase,
    GRPOConfig,
    RewardFunctionsConfig,
    RewardFunctionsRuntimeConfig,
    SFTConfig,
    SamplingParamsConfig,
)
from datasets import load_dataset
from dotenv import load_dotenv

In [2]:
load_dotenv("../.env")
pb = Predibase(api_token=os.environ["PREDIBASE_API_KEY"])

The `PREDIBASE_API_TOKEN` long format you're using will be deprecated on April 15, 2024. Please upgrade your token by going to the Predibase UI and generating a new one.
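
If `PREDIBASE_API_KEY` is missing from the environment, `os.environ[...]` in the cell above raises a bare `KeyError`. The following is a minimal sketch of a more descriptive check. `require_env` is a hypothetical helper written for illustration; it is not part of Predibase or `python-dotenv`.

```python
import os


def require_env(name: str) -> str:
    """Return the value of environment variable `name`, or raise a clear error.

    Hypothetical helper for illustration only.
    """
    value = os.environ.get(name, "").strip()
    if not value:
        raise RuntimeError(
            f"Environment variable {name!r} is not set; "
            "add it to your .env file or export it in your shell."
        )
    return value
```

With this helper, the client could be created as `Predibase(api_token=require_env("PREDIBASE_API_KEY"))`.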


Load the GRPO [Wordle training dataset](https://huggingface.co/datasets/predibase/wordle-grpo) from Hugging Face:

In [3]:
# Load dataset from HuggingFace
dataset = load_dataset("predibase/wordle-grpo", split="train")
dataset = dataset.to_pandas()

# Upload dataset to Predibase
try:
    dataset = pb.datasets.from_pandas_dataframe(
        dataset,
        name="wordle_grpo_data"
    )
except Exception:
    # Upload failed (e.g. the dataset already exists), so fetch it by name
    dataset = pb.datasets.get("wordle_grpo_data")



Create a training repo and load the Wordle reward functions:

In [7]:
# Create repository in Predibase. The repo is already set up for you on the
# learning platform, where this call returns a 403 error; run it in your own environment.
repo = pb.repos.create(name="wordle", exists_ok=True)

RuntimeError: Bad request. Response status code 403. Error: {'message': 'You have insufficient permissions to perform this action.'}

<div style="background-color:#fff6ff; padding:13px; border-width:3px; border-color:#efe6ef; border-style:solid; border-radius:6px">
<b>Note:</b> You can access the full code of the reward functions, stored in <code>reward_functions.py</code>, by 1) clicking on the <em>"File"</em> option on the top menu of the notebook and then 2) clicking on <em>"Open"</em>.

</div>

In [5]:
# Import reward functions
from reward_functions import (
    guess_value,
    output_format_check,
    uses_previous_feedback,
)
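
The imported functions live in `reward_functions.py` and are not reproduced here. To give a sense of their general shape, the following is a minimal sketch of a format-checking reward. It assumes a `(prompt, completion, example)` signature that returns a numeric score, and it assumes the model is asked to answer as `<think>...</think><guess>WORD</guess>`. Both the name `format_check_sketch` and the exact output format are illustrative assumptions, not the course's actual `output_format_check`.

```python
import re

# Assumed output format (illustrative only): reasoning inside <think> tags,
# followed by a five-letter guess inside <guess> tags.
_PATTERN = re.compile(
    r"^\s*<think>.*?</think>\s*<guess>\s*([A-Za-z]{5})\s*</guess>\s*$",
    re.DOTALL,
)


def format_check_sketch(prompt: str, completion: str, example: dict) -> float:
    """Return 1.0 if the completion matches the assumed format, else 0.0."""
    return 1.0 if _PATTERN.match(completion) else 0.0
```

For example, `format_check_sketch("", "<think>try common letters</think><guess>CRANE</guess>", {})` scores the completion as well formed, while a bare `"CRANE"` does not.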

## Set up the training run

<div style="background-color:#fff6ff; padding:13px; border-width:3px; border-color:#efe6ef; border-style:solid; border-radius:6px">
<b>Note:</b> The following cell will not run on the learning platform. If you decide to run it from your own computer, update the <code>PREDIBASE_API_KEY</code> environment variable with your own API key in the setup above.

You can get free credits to try out Predibase at [this website](https://predibase.com/free-trial).

</div>

In [6]:
# Create GRPO training run in Predibase by specifying the config, 
# dataset, repository and reward functions
pb.finetuning.jobs.create(
    config=GRPOConfig(
        base_model="qwen2-5-7b-instruct",
        reward_fns=RewardFunctionsConfig(
            runtime=RewardFunctionsRuntimeConfig(
                packages=["pandas"]
            ),
            functions={
                "output_format_check": output_format_check,
                "uses_previous_feedback": uses_previous_feedback,
                "guess_value": guess_value,
            }
        ),
        sampling_params=SamplingParamsConfig(max_tokens=4096),
        num_generations=16
    ),
    dataset=dataset,
    repo="wordle",
    description="Wordle GRPO"
)

RuntimeError: Bad request. Response status code 403. Error: {'message': 'You have insufficient permissions to perform this action.'}
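
In the config above, `num_generations=16` sets how many completions are sampled for each prompt. GRPO scores every completion in that group and then compares each reward with the rest of the group. The following is a minimal sketch of that normalization in plain Python. It illustrates the general GRPO idea and is not Predibase's internal implementation; the function name and the `eps` value are assumptions.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards: list[float], eps: float = 1e-6) -> list[float]:
    """Normalize rewards within one group: (r - mean) / (std + eps)."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population standard deviation of the group
    return [(r - mu) / (sigma + eps) for r in rewards]
```

Completions that score above the group mean get a positive advantage and those below get a negative one. If every completion in a group receives the same reward, all advantages are zero and that group contributes no learning signal.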

## Try out SFT and SFT+GRPO on Predibase

You can use the code below to set up an SFT training job in Predibase, and then use the resulting checkpoint as the starting point for a GRPO run.

This example uses the [Wordle SFT dataset](https://huggingface.co/datasets/predibase/wordle-sft) available on Hugging Face.

### SFT training on Predibase

```python

dataset = load_dataset("predibase/wordle-sft", split="train")
dataset = dataset.to_pandas()

# Upload dataset to Predibase
dataset = pb.datasets.from_pandas_dataframe(dataset, name="wordle_sft_data")

# Create repository in Predibase
repo = pb.repos.create(name="wordle", exists_ok=True)

# Create SFT training run in Predibase by specifying the config, dataset, and repository
pb.finetuning.jobs.create(
    config=SFTConfig(
        base_model="qwen2-5-7b-instruct",
        epochs=10,
        rank=64,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"],
    ),
    dataset=dataset,
    repo="wordle",
    description="Wordle SFT, 10 epochs"
)
```

### SFT + GRPO training on Predibase

```python
# Use the same dataset as the GRPO training run
dataset = pb.datasets.get("wordle_grpo_data")

# Create GRPO training run in Predibase by specifying the config, dataset, repository and reward functions
pb.finetuning.jobs.create(
    config=GRPOConfig(
        base_model="qwen2-5-7b-instruct",
        reward_fns=RewardFunctionsConfig(
            runtime=RewardFunctionsRuntimeConfig(packages=["pandas"]),
            functions={
                "output_format_check": output_format_check,
                "uses_previous_feedback": uses_previous_feedback,
                "guess_value": guess_value,
            }
        ),
        epochs=3,
        enable_early_stopping=False,
        sampling_params=SamplingParamsConfig(max_tokens=4096),
        num_generations=8
    ),
    continue_from_version="wordle/1", # change "1" to the version number of the SFT training run in the repository
    dataset=dataset,
    repo="wordle",
    description="Wordle GRPO"
)
```