In [1]:
from dotenv import load_dotenv, find_dotenv
from llama_r1_zero.llama import Llama
from llama_r1_zero.grpo import GRPOLoss
from llama_r1_zero.prompts import SYSTEM_PROMPT
from llama_r1_zero.rewards import accuracy_reward, format_reward, complexity_reward, similarity_reward
import os
import json
import random

# Locate a .env file and load its entries into the process environment.
# Presumably this is where MODEL_PATH (read via os.environ below) comes from;
# verify against the project's .env.
env_file = find_dotenv()
load_dotenv(env_file)

  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
# Build the model from the checkpoint directory named by the MODEL_PATH
# environment variable.
# NOTE(review): os.environ.get returns None when MODEL_PATH is unset, so None
# would be passed as ckpt_dir — confirm Llama.build handles that, or make sure
# the variable is always set.
model_dir = os.environ.get("MODEL_PATH")
build_kwargs = {
    "ckpt_dir": model_dir,
    "max_batch_size": 4,
    "max_seq_len": 1024,
}
llama = Llama.build(**build_kwargs)


# Reward callables handed to GRPOLoss, kept in their original order.
reward_functions = [accuracy_reward, format_reward, complexity_reward, similarity_reward]

# GRPO loss object, configured with the model's tokenizer and the shared
# system prompt.
# NOTE(review): num_generations=2 together with the 2 prompts sampled later
# looks chosen to fit max_batch_size=4 — confirm before changing either value.
grpo_loss = GRPOLoss(
    tokenizer=llama.tokenizer,
    system_prompt=SYSTEM_PROMPT,
    reward_funcs=reward_functions,
    max_new_tokens=768,
    num_generations=2,
)

In [3]:
# Load the training questions: one JSON object per line (JSON Lines).
# The encoding is pinned to UTF-8 so decoding does not depend on the
# platform's default locale, and blank lines are skipped because json.loads
# raises on an empty/whitespace-only string.
with open('data/gsm8k_train.jsonl', 'r', encoding='utf-8') as f:
    math_qns = [json.loads(line) for line in f if line.strip()]

In [4]:
# Seed for picking the demo questions, so a fresh-kernel re-run selects the
# same prompts. Set to None to draw a different pair on every run.
SAMPLE_SEED = 0
# Number of questions fed to a single loss computation.
NUM_PROMPTS = 2

# A dedicated Random instance keeps the draw reproducible without touching
# the global `random` state used elsewhere.
qns = random.Random(SAMPLE_SEED).sample(math_qns, k=NUM_PROMPTS)

# Each record supplies a 'question' (the prompt) and an 'answer' (the ground truth).
loss = grpo_loss.compute_loss(
    model=llama,
    prompts=[qn['question'] for qn in qns],
    ground_truths=[qn['answer'] for qn in qns]
)

print(f'metrics: {grpo_loss.metrics}')
print(f'loss: {loss}')

metrics: {'reward': [0.9187671542167664], 'reward_std': [0.7020941376686096], 'kl': [0.0]}
loss: 1.4901161193847656e-07
