In [1]:
from datasets import load_dataset
ds = load_dataset('eth-dl-rewards/code_preference_data', split='eval')

In [2]:
from transformers import AutoModel, AutoTokenizer
import torch

MODEL = 'internlm/internlm2-7b-reward'
model = AutoModel.from_pretrained(MODEL, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map='cuda')

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [3]:
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)

In [4]:
def get_reward(prompt, solution):
  messages = [
    {"role": "user", "content": prompt},
    {"role": "assistant", "content": solution}
  ]
  encoded = tokenizer.apply_chat_template(messages, tokenize=False)
  inputs = tokenizer(encoded, return_tensors='pt', truncation=True, max_length=4096).to('cuda')
  with torch.no_grad():
    outputs = model(**inputs)
    reward = outputs.logits[0][0].item()
    del inputs, outputs
    return reward


In [5]:
from tqdm import tqdm
import numpy as np
import gc
correct = 0
total = 0
gaps = []
for problem, accepted, rejected in tqdm(zip(ds['problem'], ds['accepted'], ds['rejected']), total=len(ds), desc='Evaluating'):
  reward_accepted = get_reward(problem, accepted)
  reward_rejected = get_reward(problem, rejected)
  gaps.append(reward_accepted - reward_rejected)
  if reward_accepted > reward_rejected:
    correct += 1
  total += 1

  if total % 100 == 0:
    print(f"Correct: {correct}/{total} ({correct/total*100:.2f}%)")
    gc.collect()
    torch.cuda.empty_cache()

Evaluating:   2%|▏         | 101/4180 [00:15<12:12,  5.57it/s]

Correct: 94/100 (94.00%)


Evaluating:   5%|▍         | 200/4180 [00:29<12:46,  5.19it/s]

Correct: 194/200 (97.00%)


Evaluating:   7%|▋         | 300/4180 [00:53<16:22,  3.95it/s]

Correct: 289/300 (96.33%)


Evaluating:  10%|▉         | 400/4180 [01:14<14:19,  4.40it/s]

Correct: 385/400 (96.25%)


Evaluating:  12%|█▏        | 500/4180 [01:40<10:27,  5.86it/s]

Correct: 483/500 (96.60%)


Evaluating:  14%|█▍        | 601/4180 [01:59<11:54,  5.01it/s]

Correct: 582/600 (97.00%)


Evaluating:  17%|█▋        | 700/4180 [02:17<13:54,  4.17it/s]

Correct: 681/700 (97.29%)


Evaluating:  19%|█▉        | 800/4180 [02:35<09:20,  6.03it/s]

Correct: 777/800 (97.12%)


Evaluating:  22%|██▏       | 900/4180 [02:51<14:07,  3.87it/s]

Correct: 875/900 (97.22%)


Evaluating:  24%|██▍       | 1000/4180 [03:07<09:22,  5.65it/s]

Correct: 973/1000 (97.30%)


Evaluating:  26%|██▋       | 1100/4180 [03:29<07:42,  6.66it/s]

Correct: 1071/1100 (97.36%)


Evaluating:  29%|██▊       | 1201/4180 [03:52<11:30,  4.31it/s]

Correct: 1168/1200 (97.33%)


Evaluating:  31%|███       | 1300/4180 [04:14<14:29,  3.31it/s]

Correct: 1267/1300 (97.46%)


Evaluating:  34%|███▎      | 1401/4180 [04:35<09:11,  5.04it/s]

Correct: 1365/1400 (97.50%)


Evaluating:  36%|███▌      | 1500/4180 [04:53<07:49,  5.71it/s]

Correct: 1460/1500 (97.33%)


Evaluating:  38%|███▊      | 1600/4180 [05:13<09:41,  4.44it/s]

Correct: 1558/1600 (97.38%)


Evaluating:  41%|████      | 1702/4180 [05:31<05:12,  7.92it/s]

Correct: 1654/1700 (97.29%)


Evaluating:  43%|████▎     | 1800/4180 [05:51<18:53,  2.10it/s]

Correct: 1752/1800 (97.33%)


Evaluating:  45%|████▌     | 1900/4180 [06:09<09:32,  3.98it/s]

Correct: 1847/1900 (97.21%)


Evaluating:  48%|████▊     | 2002/4180 [06:27<04:48,  7.55it/s]

Correct: 1944/2000 (97.20%)


Evaluating:  50%|█████     | 2101/4180 [06:49<05:52,  5.89it/s]

Correct: 2040/2100 (97.14%)


Evaluating:  53%|█████▎    | 2201/4180 [07:04<05:03,  6.53it/s]

Correct: 2136/2200 (97.09%)


Evaluating:  55%|█████▌    | 2301/4180 [07:23<04:25,  7.09it/s]

Correct: 2236/2300 (97.22%)


Evaluating:  57%|█████▋    | 2400/4180 [07:39<03:48,  7.78it/s]

Correct: 2331/2400 (97.12%)


Evaluating:  60%|█████▉    | 2502/4180 [07:57<03:56,  7.11it/s]

Correct: 2428/2500 (97.12%)


Evaluating:  62%|██████▏   | 2600/4180 [08:12<06:43,  3.91it/s]

Correct: 2525/2600 (97.12%)


Evaluating:  65%|██████▍   | 2701/4180 [08:37<07:16,  3.39it/s]

Correct: 2621/2700 (97.07%)


Evaluating:  67%|██████▋   | 2801/4180 [08:58<06:38,  3.46it/s]

Correct: 2720/2800 (97.14%)


Evaluating:  69%|██████▉   | 2900/4180 [09:18<06:52,  3.10it/s]

Correct: 2815/2900 (97.07%)


Evaluating:  72%|███████▏  | 3001/4180 [09:37<03:00,  6.53it/s]

Correct: 2908/3000 (96.93%)


Evaluating:  74%|███████▍  | 3101/4180 [09:54<02:28,  7.29it/s]

Correct: 3002/3100 (96.84%)


Evaluating:  77%|███████▋  | 3201/4180 [10:12<02:55,  5.57it/s]

Correct: 3098/3200 (96.81%)


Evaluating:  79%|███████▉  | 3301/4180 [10:36<02:12,  6.62it/s]

Correct: 3193/3300 (96.76%)


Evaluating:  81%|████████▏ | 3400/4180 [10:59<02:04,  6.28it/s]

Correct: 3290/3400 (96.76%)


Evaluating:  84%|████████▍ | 3501/4180 [11:17<02:05,  5.41it/s]

Correct: 3385/3500 (96.71%)


Evaluating:  86%|████████▌ | 3601/4180 [11:35<02:23,  4.03it/s]

Correct: 3481/3600 (96.69%)


Evaluating:  89%|████████▊ | 3700/4180 [11:51<01:31,  5.27it/s]

Correct: 3578/3700 (96.70%)


Evaluating:  91%|█████████ | 3800/4180 [12:11<01:08,  5.53it/s]

Correct: 3673/3800 (96.66%)


Evaluating:  93%|█████████▎| 3901/4180 [12:28<00:32,  8.50it/s]

Correct: 3768/3900 (96.62%)


Evaluating:  96%|█████████▌| 4000/4180 [12:49<00:44,  4.04it/s]

Correct: 3868/4000 (96.70%)


Evaluating:  98%|█████████▊| 4100/4180 [13:11<00:11,  7.26it/s]

Correct: 3966/4100 (96.73%)


Evaluating: 100%|██████████| 4180/4180 [13:24<00:00,  5.19it/s]


In [6]:
print(f"Correct: {correct}/{total} ({correct/total*100:.2f}%)")

Correct: 4062/4180 (97.18%)


In [2]:
from tqdm import tqdm
import numpy as np
import gc
import torch

def get_rewards_batch(batch):
    prompts, solutions = zip(*batch)
    messages = [
        [{"role": "user", "content": prompt}, {"role": "assistant", "content": solution}]
        for prompt, solution in zip(prompts, solutions)
    ]
    encoded = tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = tokenizer(encoded, return_tensors='pt', truncation=True, max_length=4096, padding=True).to('cuda')
    with torch.no_grad():
        outputs = model(**inputs)
        rewards = outputs.logits[:, 0].tolist()
        del inputs, outputs
        return rewards

def precompute_rewards(ds, batch_size=10):
    rewards = []
    for i in tqdm(range(0, len(ds), batch_size), desc='Precomputing Rewards'):
        batch = list(zip(ds['problem'][i:i + batch_size], ds['accepted'][i:i + batch_size], ds['rejected'][i:i + batch_size]))
        accepted_batch = [(problem, accepted) for problem, accepted, _ in batch]
        rejected_batch = [(problem, rejected) for problem, _, rejected in batch]

        reward_accepted = get_rewards_batch(accepted_batch)
        reward_rejected = get_rewards_batch(rejected_batch)

        batch_rewards = list(zip(reward_accepted, reward_rejected))
        rewards.extend(batch_rewards)

        gc.collect()
        torch.cuda.empty_cache()

    return rewards

def evaluate_from_precomputed(rewards):
    correct = 0
    total = 0
    gaps = []

    for reward_accepted, reward_rejected in rewards:
        gap = reward_accepted - reward_rejected
        gaps.append(gap)
        if reward_accepted > reward_rejected:
            correct += 1
        total += 1

    return correct, total, gaps

# Example usage:
# ds should be a list of tuples in the form [(problem, accepted, rejected), ...]
batch_size = 8
precomputed_rewards = precompute_rewards(ds, batch_size=batch_size)
correct, total, gaps = evaluate_from_precomputed(precomputed_rewards)
print(f"Final Results: Correct: {correct}/{total} ({correct/total*100:.2f}%)")


NameError: name 'ds' is not defined