In [1]:
from reasoning_gym import get_score_answer_fn, create_dataset
from prompts import *
from transformers import AutoModelForCausalLM, AutoTokenizer
from rich import print

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
SEED = 100
batch_size = 2
n_rollouts = 3
buffer_size = 6
max_new_tokens = 100
model_name = "HuggingFaceTB/SmolLM-135M-Instruct"

In [3]:
env_name = "propositional_logic"
dataset = create_dataset(env_name,seed=42, size=1)

In [4]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

In [5]:
entry = dataset[0]

In [6]:
question = entry["question"]
answer = entry['metadata']['example_answer']

In [7]:
validation_object = entry["metadata"]["source_dataset"]
score_fn = get_score_answer_fn(validation_object)

In [8]:
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": question}, # Obtained from reasoning-gym
]

In [9]:
templated_string = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        # return_tensors="pt",
        add_generation_prompt=True,
    )

In [10]:
inputs = tokenizer(
        [templated_string],
        return_tensors="pt",
        padding_side="left",
        max_length=512,
        padding='max_length',
        # # truncation=True,
    )

In [11]:
print(inputs["input_ids"].shape)  # torch.Size([1, 75])

In [12]:
generated_response = model.generate(
        input_ids=inputs['input_ids'],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens, # The max number of tokens to generate
        do_sample=True,                # Probabilistic sampling
        top_p=0.95,                    # Nucleus sampling
        num_return_sequences=n_rollouts,        # Number of sequences per question
        temperature=1,                 # Increase randomness
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )

In [13]:
from utils import *
decoded_resp = tokenizer.batch_decode(generated_response[:, inputs["input_ids"].shape[1]:]) # Remove the prompt part
extract_answers = [extract_answer(decoded_resp[i]) for i in range(n_rollouts)]
rewards = calculate_total_reward(decoded_resp, np.repeat(entry, n_rollouts))
# extract_answers

In [14]:
rewards = np.reshape(rewards, (1, n_rollouts))
advantages = (rewards - np.mean(rewards, axis=1, keepdims=True)) / (
    np.std(rewards, axis=1, keepdims=True) + 1e-8
)
# advantages = advantages.reshape(-1, 1)
advantages

array([[0., 0., 0.]])

In [15]:
inputs["attention_mask"].shape

torch.Size([1, 512])

In [16]:
log_probs = calculate_logits(model, generated_response, inputs["attention_mask"].repeat(n_rollouts, 1))

In [17]:
padded_tokens = (generated_response!=tokenizer.eos_token_id).int()
response_start_idx = padded_tokens.argmax(axis=-1)
response_end_idx = padded_tokens.shape[1] - torch.flip(padded_tokens, dims=[1]).argmax(dim=1)
response_mask = torch.zeros_like(padded_tokens)
for i in range(response_mask.shape[0]):
    response_mask[i, inputs['input_ids'].shape[-1]:response_end_idx[i]] = 1
    experience = [
        {'input_sequence': generated_response[
                i, response_start_idx[i]:response_end_idx[i]
            ],
         'response_mask': response_mask[
                i, response_start_idx[i]:response_end_idx[i]
            ],
         'log_probs': log_probs[
                i, response_start_idx[i]:response_end_idx[i]
            ],
         'advantage': advantages[i]
        }
        for i in range(1)
    ]

In [18]:
print(advantages.shape, log_probs.shape, rewards.shape)


In [19]:
advantages = torch.tensor(advantages, dtype = torch.float32)

In [20]:
(advantages@log_probs).shape

torch.Size([1, 612])

In [21]:
new_log_probs = calculate_logits(model, generated_response, inputs["attention_mask"].repeat(n_rollouts, 1))

In [22]:
def compute_policy_loss(new_log_probs, old_log_probs, advantages, response_mask):
    ratioes = new_log_probs/old_log_probs
    relavant_ratioes = ratioes*response_mask
    clipped_term = advantages@torch.clip(relavant_ratioes, 0.8, 1.2)
    unclipped_term = advantages@relavant_ratioes
    grpo_term = torch.min(unclipped_term, clipped_term)
    loss = torch.sum(grpo_term)/grpo_term.shape[1]
    return loss

In [23]:
response_mask

tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]], dtype=torch.int32)

In [24]:

compute_policy_loss(new_log_probs, log_probs, advantages, response_mask)

tensor(0., grad_fn=<DivBackward0>)