In [1]:
from datasets import load_dataset
ds = load_dataset('eth-dl-rewards/pref-data-math', split='eval')

In [4]:
ds.num_rows

13544

In [2]:
from transformers import AutoModel, AutoTokenizer
import torch

MODEL = 'eth-dl-rewards/internlm2-7b-reward-math-30k'
model = AutoModel.from_pretrained(MODEL, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map='cuda')

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [3]:
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)

In [4]:
def get_reward(prompt, solution):
  messages = [
    {"role": "user", "content": prompt},
    {"role": "assistant", "content": solution}
  ]
  encoded = tokenizer.apply_chat_template(messages, tokenize=False)
  with torch.no_grad():
    inputs = tokenizer(encoded, return_tensors='pt', truncation=True, max_length=3500).to('cuda')
    print(inputs['input_ids'].shape)
    outputs = model(**inputs)
    reward = outputs.logits[0][0].item()
    del inputs
    del outputs
    gc.collect()
    torch.cuda.empty_cache()

    return reward


In [None]:
from tqdm import tqdm
import gc
correct = 0
total = 0
gaps = []
model.eval()
for problem, accepted, rejected in tqdm(zip(ds['problem'], ds['accepted'], ds['rejected']), total=len(ds), desc='Evaluating'):
  reward_accepted = get_reward(problem, accepted)
  reward_rejected = get_reward(problem, rejected)
  gaps.append(reward_accepted - reward_rejected)
  if reward_accepted > reward_rejected:
    correct += 1
  total += 1

  if total % 100 == 0:
    print(f"Correct: {correct}/{total} ({correct/total*100:.2f}%)")
    gc.collect()
    torch.cuda.empty_cache()

Evaluating:   0%|          | 0/13544 [00:00<?, ?it/s]

torch.Size([1, 648])
torch.Size([1, 723])


Evaluating:   0%|          | 1/13544 [00:00<2:44:38,  1.37it/s]

torch.Size([1, 648])
torch.Size([1, 1395])


Evaluating:   0%|          | 2/13544 [00:01<2:34:35,  1.46it/s]

torch.Size([1, 648])
torch.Size([1, 687])


Evaluating:   0%|          | 3/13544 [00:01<2:14:10,  1.68it/s]

torch.Size([1, 648])
torch.Size([1, 675])


Evaluating:   0%|          | 4/13544 [00:02<2:03:56,  1.82it/s]

torch.Size([1, 797])
torch.Size([1, 1488])


Evaluating:   0%|          | 5/13544 [00:03<2:15:23,  1.67it/s]

torch.Size([1, 797])
torch.Size([1, 1570])


Evaluating:   0%|          | 6/13544 [00:03<2:27:07,  1.53it/s]

torch.Size([1, 797])
torch.Size([1, 1268])


Evaluating:   0%|          | 7/13544 [00:04<2:29:10,  1.51it/s]

torch.Size([1, 797])
torch.Size([1, 2271])


Evaluating:   0%|          | 8/13544 [00:05<2:52:29,  1.31it/s]

torch.Size([1, 534])
torch.Size([1, 511])


Evaluating:   0%|          | 9/13544 [00:05<2:29:35,  1.51it/s]

torch.Size([1, 534])
torch.Size([1, 439])


Evaluating:   0%|          | 10/13544 [00:06<2:13:38,  1.69it/s]

torch.Size([1, 534])


Evaluating:   0%|          | 11/13544 [00:06<2:01:49,  1.85it/s]

torch.Size([1, 403])
torch.Size([1, 494])
torch.Size([1, 511])


Evaluating:   0%|          | 12/13544 [00:07<1:53:39,  1.98it/s]

torch.Size([1, 246])
torch.Size([1, 342])


Evaluating:   0%|          | 13/13544 [00:07<1:43:01,  2.19it/s]

torch.Size([1, 256])
torch.Size([1, 342])


Evaluating:   0%|          | 14/13544 [00:07<1:35:29,  2.36it/s]

torch.Size([1, 347])
torch.Size([1, 342])


Evaluating:   0%|          | 15/13544 [00:08<1:31:33,  2.46it/s]

torch.Size([1, 247])
torch.Size([1, 342])


Evaluating:   0%|          | 16/13544 [00:08<1:27:50,  2.57it/s]

torch.Size([1, 273])
torch.Size([1, 201])


Evaluating:   0%|          | 17/13544 [00:08<1:24:52,  2.66it/s]

torch.Size([1, 273])
torch.Size([1, 344])


Evaluating:   0%|          | 18/13544 [00:09<1:24:12,  2.68it/s]

torch.Size([1, 245])
torch.Size([1, 201])


Evaluating:   0%|          | 19/13544 [00:09<1:22:41,  2.73it/s]

torch.Size([1, 245])
torch.Size([1, 344])


Evaluating:   0%|          | 20/13544 [00:10<1:22:36,  2.73it/s]

torch.Size([1, 686])
torch.Size([1, 405])


Evaluating:   0%|          | 21/13544 [00:10<1:28:13,  2.55it/s]

torch.Size([1, 686])


Evaluating:   0%|          | 22/13544 [00:10<1:29:58,  2.50it/s]

torch.Size([1, 357])
torch.Size([1, 686])


Evaluating:   0%|          | 23/13544 [00:11<1:32:12,  2.44it/s]

torch.Size([1, 407])
torch.Size([1, 686])
torch.Size([1, 885])


Evaluating:   0%|          | 24/13544 [00:11<1:39:07,  2.27it/s]

torch.Size([1, 774])
torch.Size([1, 1026])


Evaluating:   0%|          | 25/13544 [00:12<1:49:11,  2.06it/s]

torch.Size([1, 774])
torch.Size([1, 3500])


Evaluating:   0%|          | 26/13544 [00:13<2:50:35,  1.32it/s]

torch.Size([1, 774])
torch.Size([1, 3500])


Evaluating:   0%|          | 27/13544 [00:15<3:38:39,  1.03it/s]

torch.Size([1, 774])
torch.Size([1, 648])


Evaluating:   0%|          | 28/13544 [00:15<3:08:54,  1.19it/s]

torch.Size([1, 1051])
torch.Size([1, 1602])


Evaluating:   0%|          | 29/13544 [00:16<3:08:51,  1.19it/s]

torch.Size([1, 1051])
torch.Size([1, 3500])


Evaluating:   0%|          | 30/13544 [00:18<3:51:29,  1.03s/it]

torch.Size([1, 1051])
torch.Size([1, 3500])


Evaluating:   0%|          | 31/13544 [00:19<4:20:06,  1.15s/it]

torch.Size([1, 1051])
torch.Size([1, 2050])


Evaluating:   0%|          | 32/13544 [00:20<4:05:02,  1.09s/it]

torch.Size([1, 1228])
torch.Size([1, 1401])


Evaluating:   0%|          | 33/13544 [00:21<3:44:09,  1.00it/s]

torch.Size([1, 675])
torch.Size([1, 1401])


Evaluating:   0%|          | 34/13544 [00:21<3:20:55,  1.12it/s]

torch.Size([1, 1228])
torch.Size([1, 675])


Evaluating:   0%|          | 35/13544 [00:22<3:02:23,  1.23it/s]

torch.Size([1, 675])
torch.Size([1, 675])


Evaluating:   0%|          | 36/13544 [00:23<2:40:32,  1.40it/s]

torch.Size([1, 794])
torch.Size([1, 454])


Evaluating:   0%|          | 37/13544 [00:23<2:23:48,  1.57it/s]

torch.Size([1, 794])
torch.Size([1, 1375])


Evaluating:   0%|          | 38/13544 [00:24<2:27:37,  1.52it/s]

torch.Size([1, 794])
torch.Size([1, 464])


Evaluating:   0%|          | 39/13544 [00:24<2:16:22,  1.65it/s]

torch.Size([1, 794])
torch.Size([1, 546])


Evaluating:   0%|          | 40/13544 [00:25<2:09:56,  1.73it/s]

torch.Size([1, 447])
torch.Size([1, 1636])


Evaluating:   0%|          | 41/13544 [00:25<2:18:10,  1.63it/s]

torch.Size([1, 447])
torch.Size([1, 478])


Evaluating:   0%|          | 42/13544 [00:26<2:05:29,  1.79it/s]

torch.Size([1, 447])
torch.Size([1, 637])


Evaluating:   0%|          | 43/13544 [00:26<1:57:52,  1.91it/s]

torch.Size([1, 336])
torch.Size([1, 1636])


Evaluating:   0%|          | 44/13544 [00:27<2:06:57,  1.77it/s]

torch.Size([1, 471])
torch.Size([1, 1922])


Evaluating:   0%|          | 45/13544 [00:28<2:20:54,  1.60it/s]

torch.Size([1, 366])
torch.Size([1, 1922])


Evaluating:   0%|          | 46/13544 [00:28<2:29:31,  1.50it/s]

torch.Size([1, 368])
torch.Size([1, 1922])


Evaluating:   0%|          | 47/13544 [00:29<2:35:29,  1.45it/s]

torch.Size([1, 414])
torch.Size([1, 1922])


Evaluating:   0%|          | 48/13544 [00:30<2:41:43,  1.39it/s]

torch.Size([1, 2008])
torch.Size([1, 763])


Evaluating:   0%|          | 49/13544 [00:31<2:50:24,  1.32it/s]

torch.Size([1, 2008])
torch.Size([1, 858])


Evaluating:   0%|          | 50/13544 [00:32<2:57:07,  1.27it/s]

torch.Size([1, 2008])
torch.Size([1, 900])


Evaluating:   0%|          | 51/13544 [00:33<3:02:26,  1.23it/s]

torch.Size([1, 2008])
torch.Size([1, 946])


Evaluating:   0%|          | 52/13544 [00:33<3:07:20,  1.20it/s]

torch.Size([1, 528])
torch.Size([1, 399])


Evaluating:   0%|          | 53/13544 [00:34<2:39:51,  1.41it/s]

torch.Size([1, 528])
torch.Size([1, 492])


Evaluating:   0%|          | 54/13544 [00:34<2:21:13,  1.59it/s]

torch.Size([1, 528])


Evaluating:   0%|          | 55/13544 [00:35<2:07:03,  1.77it/s]

torch.Size([1, 371])
torch.Size([1, 528])
torch.Size([1, 412])


Evaluating:   0%|          | 56/13544 [00:35<1:57:57,  1.91it/s]

torch.Size([1, 433])
torch.Size([1, 549])


Evaluating:   0%|          | 57/13544 [00:36<1:51:59,  2.01it/s]

torch.Size([1, 433])
torch.Size([1, 556])


Evaluating:   0%|          | 58/13544 [00:36<1:47:52,  2.08it/s]

torch.Size([1, 433])


Evaluating:   0%|          | 59/13544 [00:36<1:43:02,  2.18it/s]

torch.Size([1, 486])
torch.Size([1, 433])
torch.Size([1, 715])


Evaluating:   0%|          | 60/13544 [00:37<1:42:49,  2.19it/s]

torch.Size([1, 404])
torch.Size([1, 493])


Evaluating:   0%|          | 61/13544 [00:37<1:39:01,  2.27it/s]

torch.Size([1, 404])
torch.Size([1, 558])


Evaluating:   0%|          | 62/13544 [00:38<1:37:27,  2.31it/s]

torch.Size([1, 404])
torch.Size([1, 264])


Evaluating:   0%|          | 63/13544 [00:38<1:33:28,  2.40it/s]

torch.Size([1, 318])
torch.Size([1, 493])


Evaluating:   0%|          | 64/13544 [00:38<1:31:32,  2.45it/s]

torch.Size([1, 338])
torch.Size([1, 426])


Evaluating:   0%|          | 65/13544 [00:39<1:30:26,  2.48it/s]

torch.Size([1, 324])
torch.Size([1, 426])


Evaluating:   0%|          | 66/13544 [00:39<1:29:03,  2.52it/s]

torch.Size([1, 352])
torch.Size([1, 426])


Evaluating:   0%|          | 67/13544 [00:40<1:28:07,  2.55it/s]

torch.Size([1, 400])
torch.Size([1, 426])


Evaluating:   1%|          | 68/13544 [00:40<1:28:21,  2.54it/s]

torch.Size([1, 1106])
torch.Size([1, 627])


Evaluating:   1%|          | 69/13544 [00:41<1:41:15,  2.22it/s]

torch.Size([1, 680])
torch.Size([1, 627])


Evaluating:   1%|          | 70/13544 [00:41<1:42:38,  2.19it/s]

torch.Size([1, 586])
torch.Size([1, 627])


Evaluating:   1%|          | 71/13544 [00:42<1:43:14,  2.18it/s]

torch.Size([1, 1106])
torch.Size([1, 1396])


Evaluating:   1%|          | 72/13544 [00:42<2:03:04,  1.82it/s]

torch.Size([1, 383])
torch.Size([1, 242])


Evaluating:   1%|          | 73/13544 [00:43<1:50:42,  2.03it/s]

torch.Size([1, 383])
torch.Size([1, 464])


Evaluating:   1%|          | 74/13544 [00:43<1:43:37,  2.17it/s]

torch.Size([1, 383])
torch.Size([1, 438])


Evaluating:   1%|          | 75/13544 [00:43<1:39:09,  2.26it/s]

torch.Size([1, 383])
torch.Size([1, 293])


Evaluating:   1%|          | 76/13544 [00:44<1:35:25,  2.35it/s]

torch.Size([1, 626])
torch.Size([1, 989])


Evaluating:   1%|          | 77/13544 [00:44<1:44:43,  2.14it/s]

torch.Size([1, 562])
torch.Size([1, 989])


Evaluating:   1%|          | 78/13544 [00:45<1:50:08,  2.04it/s]

torch.Size([1, 525])
torch.Size([1, 989])


Evaluating:   1%|          | 79/13544 [00:46<1:54:26,  1.96it/s]

torch.Size([1, 566])
torch.Size([1, 989])


Evaluating:   1%|          | 80/13544 [00:46<1:56:24,  1.93it/s]

torch.Size([1, 637])
torch.Size([1, 1357])


Evaluating:   1%|          | 81/13544 [00:47<2:05:24,  1.79it/s]

torch.Size([1, 637])
torch.Size([1, 852])


Evaluating:   1%|          | 82/13544 [00:47<2:03:01,  1.82it/s]

torch.Size([1, 637])
torch.Size([1, 767])


Evaluating:   1%|          | 83/13544 [00:48<1:58:56,  1.89it/s]

torch.Size([1, 637])
torch.Size([1, 867])


Evaluating:   1%|          | 84/13544 [00:48<1:58:08,  1.90it/s]

torch.Size([1, 709])
torch.Size([1, 2663])


Evaluating:   1%|          | 85/13544 [00:49<2:34:40,  1.45it/s]

torch.Size([1, 744])
torch.Size([1, 2663])


Evaluating:   1%|          | 86/13544 [00:50<3:00:39,  1.24it/s]

torch.Size([1, 542])
torch.Size([1, 2663])


Evaluating:   1%|          | 87/13544 [00:51<3:16:44,  1.14it/s]

torch.Size([1, 275])
torch.Size([1, 2663])


Evaluating:   1%|          | 88/13544 [00:52<3:25:28,  1.09it/s]

torch.Size([1, 221])
torch.Size([1, 297])


Evaluating:   1%|          | 89/13544 [00:53<2:48:36,  1.33it/s]

torch.Size([1, 299])
torch.Size([1, 297])


Evaluating:   1%|          | 90/13544 [00:53<2:22:54,  1.57it/s]

torch.Size([1, 336])
torch.Size([1, 297])


Evaluating:   1%|          | 91/13544 [00:54<2:05:56,  1.78it/s]

torch.Size([1, 336])
torch.Size([1, 297])


Evaluating:   1%|          | 92/13544 [00:54<1:53:45,  1.97it/s]

torch.Size([1, 215])
torch.Size([1, 204])


Evaluating:   1%|          | 93/13544 [00:54<1:43:07,  2.17it/s]

torch.Size([1, 215])
torch.Size([1, 408])


Evaluating:   1%|          | 94/13544 [00:55<1:38:47,  2.27it/s]

torch.Size([1, 215])
torch.Size([1, 505])


Evaluating:   1%|          | 95/13544 [00:55<1:35:26,  2.35it/s]

torch.Size([1, 274])
torch.Size([1, 204])


Evaluating:   1%|          | 96/13544 [00:55<1:31:32,  2.45it/s]

torch.Size([1, 212])
torch.Size([1, 222])


Evaluating:   1%|          | 97/13544 [00:56<1:27:12,  2.57it/s]

torch.Size([1, 212])
torch.Size([1, 219])


Evaluating:   1%|          | 98/13544 [00:56<1:26:06,  2.60it/s]

torch.Size([1, 212])
torch.Size([1, 257])


Evaluating:   1%|          | 99/13544 [00:57<1:28:34,  2.53it/s]

torch.Size([1, 212])
torch.Size([1, 219])
Correct: 89/100 (89.00%)


Evaluating:   1%|          | 100/13544 [00:57<1:37:36,  2.30it/s]

torch.Size([1, 223])
torch.Size([1, 200])


Evaluating:   1%|          | 101/13544 [00:57<1:33:06,  2.41it/s]

torch.Size([1, 223])
torch.Size([1, 320])


Evaluating:   1%|          | 102/13544 [00:58<1:33:37,  2.39it/s]

torch.Size([1, 262])
torch.Size([1, 200])


Evaluating:   1%|          | 103/13544 [00:58<1:31:18,  2.45it/s]

torch.Size([1, 390])
torch.Size([1, 200])


Evaluating:   1%|          | 104/13544 [00:59<1:29:09,  2.51it/s]

torch.Size([1, 1896])
torch.Size([1, 1083])


Evaluating:   1%|          | 105/13544 [01:00<2:03:57,  1.81it/s]

torch.Size([1, 1896])
torch.Size([1, 1156])


Evaluating:   1%|          | 106/13544 [01:00<2:28:13,  1.51it/s]

torch.Size([1, 1896])
torch.Size([1, 1331])


Evaluating:   1%|          | 107/13544 [01:01<2:45:04,  1.36it/s]

torch.Size([1, 1896])
torch.Size([1, 976])


Evaluating:   1%|          | 108/13544 [01:02<2:49:32,  1.32it/s]

torch.Size([1, 2421])
torch.Size([1, 3500])


Evaluating:   1%|          | 109/13544 [01:04<4:06:10,  1.10s/it]

torch.Size([1, 2421])
torch.Size([1, 3500])


Evaluating:   1%|          | 110/13544 [01:06<5:09:37,  1.38s/it]

torch.Size([1, 2421])
torch.Size([1, 2527])


Evaluating:   1%|          | 111/13544 [01:08<5:14:27,  1.40s/it]

torch.Size([1, 2421])
torch.Size([1, 722])


Evaluating:   1%|          | 112/13544 [01:09<4:42:29,  1.26s/it]

torch.Size([1, 783])
torch.Size([1, 1013])


Evaluating:   1%|          | 113/13544 [01:09<3:55:39,  1.05s/it]

torch.Size([1, 783])
torch.Size([1, 3500])


Evaluating:   1%|          | 114/13544 [01:10<4:18:13,  1.15s/it]

torch.Size([1, 783])
torch.Size([1, 633])


Evaluating:   1%|          | 115/13544 [01:11<3:34:15,  1.04it/s]

torch.Size([1, 783])
torch.Size([1, 1125])


Evaluating:   1%|          | 116/13544 [01:12<3:12:30,  1.16it/s]

torch.Size([1, 361])
torch.Size([1, 567])


Evaluating:   1%|          | 117/13544 [01:12<2:42:59,  1.37it/s]

torch.Size([1, 361])
torch.Size([1, 495])


Evaluating:   1%|          | 118/13544 [01:12<2:20:26,  1.59it/s]

torch.Size([1, 361])
torch.Size([1, 276])


Evaluating:   1%|          | 119/13544 [01:13<2:03:12,  1.82it/s]

torch.Size([1, 361])
torch.Size([1, 248])


Evaluating:   1%|          | 120/13544 [01:13<1:49:24,  2.04it/s]

torch.Size([1, 469])
torch.Size([1, 1956])


Evaluating:   1%|          | 121/13544 [01:14<2:07:04,  1.76it/s]

torch.Size([1, 469])
torch.Size([1, 1003])


Evaluating:   1%|          | 122/13544 [01:14<2:05:46,  1.78it/s]

torch.Size([1, 469])
torch.Size([1, 866])


Evaluating:   1%|          | 123/13544 [01:15<2:02:10,  1.83it/s]

torch.Size([1, 469])
torch.Size([1, 947])


Evaluating:   1%|          | 124/13544 [01:15<1:58:57,  1.88it/s]

torch.Size([1, 210])
torch.Size([1, 229])


Evaluating:   1%|          | 125/13544 [01:16<1:46:51,  2.09it/s]

torch.Size([1, 248])
torch.Size([1, 229])


Evaluating:   1%|          | 126/13544 [01:16<1:37:56,  2.28it/s]

torch.Size([1, 264])
torch.Size([1, 229])


Evaluating:   1%|          | 127/13544 [01:17<1:32:34,  2.42it/s]

torch.Size([1, 240])
torch.Size([1, 229])


Evaluating:   1%|          | 128/13544 [01:17<1:26:43,  2.58it/s]

torch.Size([1, 1707])
torch.Size([1, 3500])


Evaluating:   1%|          | 129/13544 [01:18<2:41:25,  1.38it/s]

torch.Size([1, 1346])
torch.Size([1, 3500])


Evaluating:   1%|          | 130/13544 [01:20<3:33:47,  1.05it/s]

torch.Size([1, 1707])
torch.Size([1, 577])


Evaluating:   1%|          | 131/13544 [01:21<3:18:52,  1.12it/s]

torch.Size([1, 1346])
torch.Size([1, 577])


Evaluating:   1%|          | 132/13544 [01:21<2:59:50,  1.24it/s]

torch.Size([1, 724])
torch.Size([1, 1533])


Evaluating:   1%|          | 133/13544 [01:22<2:50:35,  1.31it/s]

torch.Size([1, 724])
torch.Size([1, 2073])


Evaluating:   1%|          | 134/13544 [01:23<2:57:54,  1.26it/s]

torch.Size([1, 724])
torch.Size([1, 2546])


Evaluating:   1%|          | 135/13544 [01:24<3:14:07,  1.15it/s]

torch.Size([1, 724])
torch.Size([1, 586])


Evaluating:   1%|          | 136/13544 [01:24<2:50:22,  1.31it/s]

torch.Size([1, 606])
torch.Size([1, 475])


Evaluating:   1%|          | 137/13544 [01:25<2:30:00,  1.49it/s]

torch.Size([1, 606])
torch.Size([1, 501])


Evaluating:   1%|          | 138/13544 [01:25<2:15:26,  1.65it/s]

torch.Size([1, 606])
torch.Size([1, 439])


Evaluating:   1%|          | 139/13544 [01:26<2:04:39,  1.79it/s]

torch.Size([1, 606])
torch.Size([1, 449])


Evaluating:   1%|          | 140/13544 [01:26<2:00:05,  1.86it/s]

torch.Size([1, 229])
torch.Size([1, 201])


Evaluating:   1%|          | 141/13544 [01:27<1:50:45,  2.02it/s]

torch.Size([1, 239])
torch.Size([1, 201])


Evaluating:   1%|          | 142/13544 [01:27<1:42:43,  2.17it/s]

torch.Size([1, 261])
torch.Size([1, 201])


Evaluating:   1%|          | 143/13544 [01:27<1:37:27,  2.29it/s]

torch.Size([1, 275])
torch.Size([1, 201])


Evaluating:   1%|          | 144/13544 [01:28<1:32:10,  2.42it/s]

torch.Size([1, 745])
torch.Size([1, 668])


Evaluating:   1%|          | 145/13544 [01:28<1:36:43,  2.31it/s]

torch.Size([1, 745])
torch.Size([1, 686])


Evaluating:   1%|          | 146/13544 [01:29<1:41:19,  2.20it/s]

torch.Size([1, 745])
torch.Size([1, 544])


Evaluating:   1%|          | 147/13544 [01:29<1:42:29,  2.18it/s]

torch.Size([1, 745])
torch.Size([1, 923])


Evaluating:   1%|          | 148/13544 [01:30<1:46:27,  2.10it/s]

torch.Size([1, 160])
torch.Size([1, 368])


Evaluating:   1%|          | 149/13544 [01:30<1:38:00,  2.28it/s]

torch.Size([1, 209])
torch.Size([1, 368])


Evaluating:   1%|          | 150/13544 [01:30<1:32:20,  2.42it/s]

torch.Size([1, 194])
torch.Size([1, 368])


Evaluating:   1%|          | 151/13544 [01:31<1:27:44,  2.54it/s]

torch.Size([1, 222])
torch.Size([1, 368])


Evaluating:   1%|          | 152/13544 [01:31<1:27:13,  2.56it/s]

torch.Size([1, 530])
torch.Size([1, 714])


Evaluating:   1%|          | 153/13544 [01:31<1:30:50,  2.46it/s]

torch.Size([1, 471])
torch.Size([1, 714])


Evaluating:   1%|          | 154/13544 [01:32<1:33:38,  2.38it/s]

torch.Size([1, 494])
torch.Size([1, 714])


Evaluating:   1%|          | 155/13544 [01:32<1:34:38,  2.36it/s]

torch.Size([1, 532])
torch.Size([1, 714])


Evaluating:   1%|          | 156/13544 [01:33<1:37:44,  2.28it/s]

torch.Size([1, 288])
torch.Size([1, 334])


Evaluating:   1%|          | 157/13544 [01:33<1:33:52,  2.38it/s]

torch.Size([1, 288])
torch.Size([1, 296])


Evaluating:   1%|          | 158/13544 [01:34<1:32:26,  2.41it/s]

torch.Size([1, 288])
torch.Size([1, 259])


Evaluating:   1%|          | 159/13544 [01:34<1:29:49,  2.48it/s]

torch.Size([1, 288])
torch.Size([1, 283])


Evaluating:   1%|          | 160/13544 [01:34<1:28:45,  2.51it/s]

torch.Size([1, 401])
torch.Size([1, 373])


Evaluating:   1%|          | 161/13544 [01:35<1:31:33,  2.44it/s]

torch.Size([1, 437])
torch.Size([1, 373])


Evaluating:   1%|          | 162/13544 [01:35<1:33:10,  2.39it/s]

torch.Size([1, 401])


Evaluating:   1%|          | 163/13544 [01:36<1:31:34,  2.44it/s]

torch.Size([1, 361])
torch.Size([1, 437])


Evaluating:   1%|          | 164/13544 [01:36<1:32:02,  2.42it/s]

torch.Size([1, 361])
torch.Size([1, 293])


Evaluating:   1%|          | 165/13544 [01:36<1:28:56,  2.51it/s]

torch.Size([1, 196])
torch.Size([1, 253])


Evaluating:   1%|          | 166/13544 [01:37<1:25:24,  2.61it/s]

torch.Size([1, 196])
torch.Size([1, 293])


Evaluating:   1%|          | 167/13544 [01:37<1:23:46,  2.66it/s]

torch.Size([1, 306])
torch.Size([1, 253])


Evaluating:   1%|          | 168/13544 [01:38<1:23:20,  2.68it/s]

torch.Size([1, 306])
torch.Size([1, 646])
torch.Size([1, 467])


Evaluating:   1%|          | 169/13544 [01:38<1:28:32,  2.52it/s]

torch.Size([1, 591])
torch.Size([1, 467])


Evaluating:   1%|▏         | 170/13544 [01:38<1:30:12,  2.47it/s]

torch.Size([1, 463])


Evaluating:   1%|▏         | 171/13544 [01:39<1:29:59,  2.48it/s]

torch.Size([1, 467])
torch.Size([1, 600])
torch.Size([1, 467])


Evaluating:   1%|▏         | 172/13544 [01:39<1:31:11,  2.44it/s]

torch.Size([1, 283])
torch.Size([1, 904])


Evaluating:   1%|▏         | 173/13544 [01:40<1:35:13,  2.34it/s]

torch.Size([1, 283])
torch.Size([1, 488])


Evaluating:   1%|▏         | 174/13544 [01:40<1:34:57,  2.35it/s]

torch.Size([1, 283])
torch.Size([1, 522])


Evaluating:   1%|▏         | 175/13544 [01:40<1:32:36,  2.41it/s]

torch.Size([1, 283])
torch.Size([1, 268])


Evaluating:   1%|▏         | 176/13544 [01:41<1:28:36,  2.51it/s]

torch.Size([1, 505])
torch.Size([1, 437])


Evaluating:   1%|▏         | 177/13544 [01:41<1:28:09,  2.53it/s]

torch.Size([1, 505])
torch.Size([1, 1340])


Evaluating:   1%|▏         | 178/13544 [01:42<1:42:25,  2.18it/s]

torch.Size([1, 505])
torch.Size([1, 1105])


Evaluating:   1%|▏         | 179/13544 [01:42<1:48:02,  2.06it/s]

torch.Size([1, 505])
torch.Size([1, 1087])


Evaluating:   1%|▏         | 180/13544 [01:43<1:50:36,  2.01it/s]

torch.Size([1, 277])
torch.Size([1, 207])


Evaluating:   1%|▏         | 181/13544 [01:43<1:40:15,  2.22it/s]

torch.Size([1, 277])
torch.Size([1, 305])


Evaluating:   1%|▏         | 182/13544 [01:44<1:35:30,  2.33it/s]

torch.Size([1, 277])
torch.Size([1, 279])


Evaluating:   1%|▏         | 183/13544 [01:44<1:33:18,  2.39it/s]

torch.Size([1, 277])
torch.Size([1, 326])


Evaluating:   1%|▏         | 184/13544 [01:44<1:31:21,  2.44it/s]

torch.Size([1, 650])
torch.Size([1, 897])


Evaluating:   1%|▏         | 185/13544 [01:45<1:41:33,  2.19it/s]

torch.Size([1, 791])
torch.Size([1, 897])


Evaluating:   1%|▏         | 186/13544 [01:46<1:48:49,  2.05it/s]

torch.Size([1, 688])
torch.Size([1, 897])


Evaluating:   1%|▏         | 187/13544 [01:46<1:53:44,  1.96it/s]

torch.Size([1, 668])
torch.Size([1, 897])


Evaluating:   1%|▏         | 188/13544 [01:47<1:55:41,  1.92it/s]

torch.Size([1, 267])
torch.Size([1, 251])


Evaluating:   1%|▏         | 189/13544 [01:47<1:45:04,  2.12it/s]

torch.Size([1, 243])
torch.Size([1, 251])


Evaluating:   1%|▏         | 190/13544 [01:47<1:37:31,  2.28it/s]

torch.Size([1, 242])
torch.Size([1, 251])


Evaluating:   1%|▏         | 191/13544 [01:48<1:32:01,  2.42it/s]

torch.Size([1, 271])
torch.Size([1, 251])


Evaluating:   1%|▏         | 192/13544 [01:48<1:28:20,  2.52it/s]

torch.Size([1, 427])
torch.Size([1, 143])


Evaluating:   1%|▏         | 193/13544 [01:48<1:25:08,  2.61it/s]

torch.Size([1, 219])
torch.Size([1, 143])


Evaluating:   1%|▏         | 194/13544 [01:49<1:20:34,  2.76it/s]

torch.Size([1, 427])


Evaluating:   1%|▏         | 195/13544 [01:49<1:20:55,  2.75it/s]

torch.Size([1, 240])
torch.Size([1, 427])


Evaluating:   1%|▏         | 196/13544 [01:49<1:22:27,  2.70it/s]

torch.Size([1, 221])
torch.Size([1, 292])
torch.Size([1, 420])


Evaluating:   1%|▏         | 197/13544 [01:50<1:23:45,  2.66it/s]

torch.Size([1, 292])
torch.Size([1, 331])


Evaluating:   1%|▏         | 198/13544 [01:50<1:23:59,  2.65it/s]

torch.Size([1, 292])
torch.Size([1, 283])


Evaluating:   1%|▏         | 199/13544 [01:51<1:24:23,  2.64it/s]

torch.Size([1, 347])
torch.Size([1, 420])


Evaluating:   1%|▏         | 200/13544 [01:51<1:33:30,  2.38it/s]

Correct: 179/200 (89.50%)
torch.Size([1, 468])


Evaluating:   1%|▏         | 201/13544 [01:52<1:31:27,  2.43it/s]

torch.Size([1, 380])
torch.Size([1, 468])


Evaluating:   1%|▏         | 202/13544 [01:52<1:29:11,  2.49it/s]

torch.Size([1, 375])
torch.Size([1, 468])
torch.Size([1, 527])


Evaluating:   1%|▏         | 203/13544 [01:52<1:29:15,  2.49it/s]

torch.Size([1, 389])
torch.Size([1, 380])


Evaluating:   2%|▏         | 204/13544 [01:53<1:27:00,  2.56it/s]

torch.Size([1, 173])
torch.Size([1, 168])


Evaluating:   2%|▏         | 205/13544 [01:53<1:21:07,  2.74it/s]

torch.Size([1, 173])
torch.Size([1, 161])


Evaluating:   2%|▏         | 206/13544 [01:53<1:17:01,  2.89it/s]

torch.Size([1, 173])
torch.Size([1, 187])


Evaluating:   2%|▏         | 207/13544 [01:54<1:16:02,  2.92it/s]

torch.Size([1, 173])
torch.Size([1, 165])


Evaluating:   2%|▏         | 208/13544 [01:54<1:15:32,  2.94it/s]

torch.Size([1, 295])
torch.Size([1, 573])


Evaluating:   2%|▏         | 209/13544 [01:54<1:21:49,  2.72it/s]

torch.Size([1, 327])
torch.Size([1, 573])


Evaluating:   2%|▏         | 210/13544 [01:55<1:26:20,  2.57it/s]

torch.Size([1, 311])
torch.Size([1, 573])


Evaluating:   2%|▏         | 211/13544 [01:55<1:29:39,  2.48it/s]

torch.Size([1, 322])
torch.Size([1, 573])


Evaluating:   2%|▏         | 212/13544 [01:56<1:30:58,  2.44it/s]

torch.Size([1, 898])
torch.Size([1, 2973])


Evaluating:   2%|▏         | 213/13544 [01:57<2:27:43,  1.50it/s]

torch.Size([1, 898])
torch.Size([1, 1002])


Evaluating:   2%|▏         | 214/13544 [01:58<2:24:07,  1.54it/s]

torch.Size([1, 898])
torch.Size([1, 995])


Evaluating:   2%|▏         | 215/13544 [01:58<2:20:20,  1.58it/s]

torch.Size([1, 898])
torch.Size([1, 1263])


Evaluating:   2%|▏         | 216/13544 [01:59<2:22:33,  1.56it/s]

torch.Size([1, 241])
torch.Size([1, 360])


Evaluating:   2%|▏         | 217/13544 [01:59<2:04:03,  1.79it/s]

torch.Size([1, 367])
torch.Size([1, 360])


Evaluating:   2%|▏         | 218/13544 [02:00<1:52:12,  1.98it/s]

torch.Size([1, 334])
torch.Size([1, 360])


Evaluating:   2%|▏         | 219/13544 [02:00<1:42:37,  2.16it/s]

torch.Size([1, 191])
torch.Size([1, 360])


Evaluating:   2%|▏         | 220/13544 [02:00<1:35:10,  2.33it/s]

torch.Size([1, 151])
torch.Size([1, 422])


Evaluating:   2%|▏         | 221/13544 [02:01<1:31:07,  2.44it/s]

torch.Size([1, 151])
torch.Size([1, 232])


Evaluating:   2%|▏         | 222/13544 [02:01<1:26:48,  2.56it/s]

torch.Size([1, 260])
torch.Size([1, 422])


Evaluating:   2%|▏         | 223/13544 [02:01<1:25:23,  2.60it/s]

torch.Size([1, 193])
torch.Size([1, 422])


Evaluating:   2%|▏         | 224/13544 [02:02<1:23:56,  2.64it/s]

torch.Size([1, 285])
torch.Size([1, 262])


Evaluating:   2%|▏         | 225/13544 [02:02<1:23:32,  2.66it/s]

torch.Size([1, 285])
torch.Size([1, 321])


Evaluating:   2%|▏         | 226/13544 [02:02<1:22:09,  2.70it/s]

torch.Size([1, 285])
torch.Size([1, 298])


Evaluating:   2%|▏         | 227/13544 [02:03<1:21:10,  2.73it/s]

torch.Size([1, 285])
torch.Size([1, 231])


Evaluating:   2%|▏         | 228/13544 [02:03<1:20:00,  2.77it/s]

torch.Size([1, 907])
torch.Size([1, 827])


Evaluating:   2%|▏         | 229/13544 [02:04<1:33:36,  2.37it/s]

torch.Size([1, 1082])
torch.Size([1, 827])


Evaluating:   2%|▏         | 230/13544 [02:04<1:46:19,  2.09it/s]

torch.Size([1, 907])
torch.Size([1, 900])


Evaluating:   2%|▏         | 231/13544 [02:05<1:53:28,  1.96it/s]

torch.Size([1, 907])
torch.Size([1, 934])


Evaluating:   2%|▏         | 232/13544 [02:06<1:58:58,  1.86it/s]

torch.Size([1, 985])
torch.Size([1, 1568])


Evaluating:   2%|▏         | 233/13544 [02:06<2:15:09,  1.64it/s]

torch.Size([1, 779])
torch.Size([1, 1568])


Evaluating:   2%|▏         | 234/13544 [02:07<2:25:58,  1.52it/s]

torch.Size([1, 985])
torch.Size([1, 1280])


Evaluating:   2%|▏         | 235/13544 [02:08<2:31:36,  1.46it/s]

torch.Size([1, 779])
torch.Size([1, 1280])


Evaluating:   2%|▏         | 236/13544 [02:09<2:31:11,  1.47it/s]

torch.Size([1, 550])
torch.Size([1, 800])


Evaluating:   2%|▏         | 237/13544 [02:09<2:18:57,  1.60it/s]

torch.Size([1, 550])


Evaluating:   2%|▏         | 238/13544 [02:09<2:04:57,  1.77it/s]

torch.Size([1, 469])
torch.Size([1, 550])
torch.Size([1, 1181])


Evaluating:   2%|▏         | 239/13544 [02:10<2:08:31,  1.73it/s]

torch.Size([1, 550])


Evaluating:   2%|▏         | 240/13544 [02:10<1:57:46,  1.88it/s]

torch.Size([1, 422])
torch.Size([1, 253])


Evaluating:   2%|▏         | 241/13544 [02:11<1:47:02,  2.07it/s]

torch.Size([1, 312])
torch.Size([1, 253])


Evaluating:   2%|▏         | 242/13544 [02:11<1:38:25,  2.25it/s]

torch.Size([1, 227])
torch.Size([1, 253])
torch.Size([1, 331])


Evaluating:   2%|▏         | 243/13544 [02:12<1:34:03,  2.36it/s]

torch.Size([1, 253])
torch.Size([1, 263])


Evaluating:   2%|▏         | 244/13544 [02:12<1:30:10,  2.46it/s]

torch.Size([1, 602])
torch.Size([1, 593])


Evaluating:   2%|▏         | 245/13544 [02:12<1:33:25,  2.37it/s]

torch.Size([1, 602])


Evaluating:   2%|▏         | 246/13544 [02:13<1:33:45,  2.36it/s]

torch.Size([1, 483])
torch.Size([1, 602])


Evaluating:   2%|▏         | 247/13544 [02:13<1:34:55,  2.33it/s]

torch.Size([1, 435])
torch.Size([1, 602])
torch.Size([1, 629])


Evaluating:   2%|▏         | 248/13544 [02:14<1:36:50,  2.29it/s]

torch.Size([1, 257])
torch.Size([1, 427])


Evaluating:   2%|▏         | 249/13544 [02:14<1:34:02,  2.36it/s]

torch.Size([1, 257])
torch.Size([1, 358])


Evaluating:   2%|▏         | 250/13544 [02:14<1:31:34,  2.42it/s]

torch.Size([1, 257])
torch.Size([1, 291])


Evaluating:   2%|▏         | 251/13544 [02:15<1:29:56,  2.46it/s]

torch.Size([1, 257])
torch.Size([1, 322])


Evaluating:   2%|▏         | 252/13544 [02:15<1:28:31,  2.50it/s]

torch.Size([1, 823])
torch.Size([1, 622])


Evaluating:   2%|▏         | 253/13544 [02:16<1:35:53,  2.31it/s]

torch.Size([1, 1021])
torch.Size([1, 622])


Evaluating:   2%|▏         | 254/13544 [02:16<1:44:41,  2.12it/s]

torch.Size([1, 813])
torch.Size([1, 622])


Evaluating:   2%|▏         | 255/13544 [02:17<1:47:32,  2.06it/s]

torch.Size([1, 761])
torch.Size([1, 622])


Evaluating:   2%|▏         | 256/13544 [02:17<1:47:19,  2.06it/s]

torch.Size([1, 345])
torch.Size([1, 354])


Evaluating:   2%|▏         | 257/13544 [02:18<1:40:41,  2.20it/s]

torch.Size([1, 345])
torch.Size([1, 310])


Evaluating:   2%|▏         | 258/13544 [02:18<1:35:03,  2.33it/s]

torch.Size([1, 345])
torch.Size([1, 631])


Evaluating:   2%|▏         | 259/13544 [02:19<1:34:53,  2.33it/s]

torch.Size([1, 334])
torch.Size([1, 354])


Evaluating:   2%|▏         | 260/13544 [02:19<1:32:01,  2.41it/s]

torch.Size([1, 1015])
torch.Size([1, 1120])


Evaluating:   2%|▏         | 261/13544 [02:20<1:45:21,  2.10it/s]

torch.Size([1, 1122])
torch.Size([1, 1120])


Evaluating:   2%|▏         | 262/13544 [02:20<1:57:57,  1.88it/s]

torch.Size([1, 1015])
torch.Size([1, 739])


Evaluating:   2%|▏         | 263/13544 [02:21<1:57:50,  1.88it/s]

torch.Size([1, 1122])
torch.Size([1, 739])


Evaluating:   2%|▏         | 264/13544 [02:21<2:00:30,  1.84it/s]

torch.Size([1, 1026])
torch.Size([1, 801])


Evaluating:   2%|▏         | 265/13544 [02:22<2:02:03,  1.81it/s]

torch.Size([1, 1026])
torch.Size([1, 678])


Evaluating:   2%|▏         | 266/13544 [02:22<2:03:55,  1.79it/s]

torch.Size([1, 1026])
torch.Size([1, 583])


Evaluating:   2%|▏         | 267/13544 [02:23<2:02:27,  1.81it/s]

torch.Size([1, 1026])
torch.Size([1, 549])


Evaluating:   2%|▏         | 268/13544 [02:23<2:00:17,  1.84it/s]

torch.Size([1, 1297])
torch.Size([1, 402])


Evaluating:   2%|▏         | 269/13544 [02:24<2:03:38,  1.79it/s]

torch.Size([1, 1297])
torch.Size([1, 674])


Evaluating:   2%|▏         | 270/13544 [02:25<2:18:23,  1.60it/s]

torch.Size([1, 1297])
torch.Size([1, 959])


Evaluating:   2%|▏         | 271/13544 [02:26<2:25:05,  1.52it/s]

torch.Size([1, 563])


Evaluating:   2%|▏         | 272/13544 [02:26<2:10:06,  1.70it/s]

torch.Size([1, 402])
torch.Size([1, 191])
torch.Size([1, 458])


Evaluating:   2%|▏         | 273/13544 [02:26<1:56:37,  1.90it/s]

torch.Size([1, 191])


Evaluating:   2%|▏         | 274/13544 [02:27<1:46:39,  2.07it/s]

torch.Size([1, 148])
torch.Size([1, 220])
torch.Size([1, 458])


Evaluating:   2%|▏         | 275/13544 [02:27<1:40:19,  2.20it/s]

torch.Size([1, 220])
torch.Size([1, 148])


Evaluating:   2%|▏         | 276/13544 [02:27<1:30:51,  2.43it/s]

torch.Size([1, 505])
torch.Size([1, 176])


Evaluating:   2%|▏         | 277/13544 [02:28<1:27:40,  2.52it/s]

torch.Size([1, 505])
torch.Size([1, 290])


Evaluating:   2%|▏         | 278/13544 [02:28<1:26:38,  2.55it/s]

torch.Size([1, 187])
torch.Size([1, 176])


Evaluating:   2%|▏         | 279/13544 [02:29<1:22:14,  2.69it/s]

torch.Size([1, 237])
torch.Size([1, 176])


Evaluating:   2%|▏         | 280/13544 [02:29<1:19:18,  2.79it/s]

torch.Size([1, 1254])
torch.Size([1, 1196])


Evaluating:   2%|▏         | 281/13544 [02:30<1:46:35,  2.07it/s]

torch.Size([1, 649])
torch.Size([1, 1196])


Evaluating:   2%|▏         | 282/13544 [02:30<1:55:39,  1.91it/s]

torch.Size([1, 675])
torch.Size([1, 1196])


Evaluating:   2%|▏         | 283/13544 [02:31<2:01:04,  1.83it/s]

torch.Size([1, 1254])
torch.Size([1, 639])


Evaluating:   2%|▏         | 284/13544 [02:31<2:05:49,  1.76it/s]

torch.Size([1, 730])
torch.Size([1, 462])


Evaluating:   2%|▏         | 285/13544 [02:32<1:59:15,  1.85it/s]

torch.Size([1, 730])
torch.Size([1, 872])


Evaluating:   2%|▏         | 286/13544 [02:33<1:59:45,  1.85it/s]

torch.Size([1, 730])


Evaluating:   2%|▏         | 287/13544 [02:33<1:54:06,  1.94it/s]

torch.Size([1, 422])
torch.Size([1, 730])
torch.Size([1, 801])


Evaluating:   2%|▏         | 288/13544 [02:33<1:54:08,  1.94it/s]

torch.Size([1, 322])
torch.Size([1, 374])


Evaluating:   2%|▏         | 289/13544 [02:34<1:44:47,  2.11it/s]

torch.Size([1, 399])


Evaluating:   2%|▏         | 290/13544 [02:34<1:40:09,  2.21it/s]

torch.Size([1, 374])
torch.Size([1, 322])


Evaluating:   2%|▏         | 291/13544 [02:35<1:36:04,  2.30it/s]

torch.Size([1, 319])
torch.Size([1, 322])
torch.Size([1, 412])


Evaluating:   2%|▏         | 292/13544 [02:35<1:34:50,  2.33it/s]

torch.Size([1, 1105])
torch.Size([1, 1165])


Evaluating:   2%|▏         | 293/13544 [02:36<1:56:49,  1.89it/s]

torch.Size([1, 1105])
torch.Size([1, 1169])


Evaluating:   2%|▏         | 294/13544 [02:37<2:11:45,  1.68it/s]

torch.Size([1, 1105])
torch.Size([1, 500])


Evaluating:   2%|▏         | 295/13544 [02:37<2:11:50,  1.67it/s]

torch.Size([1, 1105])
torch.Size([1, 1106])


Evaluating:   2%|▏         | 296/13544 [02:38<2:20:26,  1.57it/s]

torch.Size([1, 302])
torch.Size([1, 747])


Evaluating:   2%|▏         | 297/13544 [02:38<2:08:01,  1.72it/s]

torch.Size([1, 302])
torch.Size([1, 353])


Evaluating:   2%|▏         | 298/13544 [02:39<1:53:50,  1.94it/s]

torch.Size([1, 354])
torch.Size([1, 747])


Evaluating:   2%|▏         | 299/13544 [02:39<1:48:55,  2.03it/s]

torch.Size([1, 354])
torch.Size([1, 353])


Evaluating:   2%|▏         | 300/13544 [02:40<1:48:00,  2.04it/s]

Correct: 269/300 (89.67%)
torch.Size([1, 394])
torch.Size([1, 407])


Evaluating:   2%|▏         | 301/13544 [02:40<1:41:55,  2.17it/s]

torch.Size([1, 394])
torch.Size([1, 451])


Evaluating:   2%|▏         | 302/13544 [02:40<1:37:55,  2.25it/s]

torch.Size([1, 394])
torch.Size([1, 462])


Evaluating:   2%|▏         | 303/13544 [02:41<1:34:29,  2.34it/s]

torch.Size([1, 394])
torch.Size([1, 339])


Evaluating:   2%|▏         | 304/13544 [02:41<1:32:23,  2.39it/s]

torch.Size([1, 1332])
torch.Size([1, 3500])


Evaluating:   2%|▏         | 305/13544 [02:43<2:48:58,  1.31it/s]

torch.Size([1, 1611])
torch.Size([1, 3500])


Evaluating:   2%|▏         | 306/13544 [02:44<3:48:42,  1.04s/it]

torch.Size([1, 1332])
torch.Size([1, 3500])


Evaluating:   2%|▏         | 307/13544 [02:46<4:27:26,  1.21s/it]

torch.Size([1, 1611])
torch.Size([1, 3500])


Evaluating:   2%|▏         | 308/13544 [02:48<4:59:13,  1.36s/it]

torch.Size([1, 367])


Evaluating:   2%|▏         | 309/13544 [02:48<3:55:36,  1.07s/it]

torch.Size([1, 379])
torch.Size([1, 367])
torch.Size([1, 694])


Evaluating:   2%|▏         | 310/13544 [02:49<3:13:40,  1.14it/s]

torch.Size([1, 367])
torch.Size([1, 382])


Evaluating:   2%|▏         | 311/13544 [02:49<2:40:15,  1.38it/s]

torch.Size([1, 367])
torch.Size([1, 1714])


Evaluating:   2%|▏         | 312/13544 [02:50<2:38:32,  1.39it/s]

torch.Size([1, 240])
torch.Size([1, 227])


Evaluating:   2%|▏         | 313/13544 [02:50<2:14:09,  1.64it/s]

torch.Size([1, 173])
torch.Size([1, 227])


Evaluating:   2%|▏         | 314/13544 [02:50<1:54:53,  1.92it/s]

torch.Size([1, 240])
torch.Size([1, 190])


Evaluating:   2%|▏         | 315/13544 [02:51<1:42:17,  2.16it/s]

torch.Size([1, 173])
torch.Size([1, 190])


Evaluating:   2%|▏         | 316/13544 [02:51<1:32:48,  2.38it/s]

torch.Size([1, 315])
torch.Size([1, 357])


Evaluating:   2%|▏         | 317/13544 [02:51<1:29:24,  2.47it/s]

torch.Size([1, 319])
torch.Size([1, 357])


Evaluating:   2%|▏         | 318/13544 [02:52<1:27:01,  2.53it/s]

torch.Size([1, 355])
torch.Size([1, 357])


Evaluating:   2%|▏         | 319/13544 [02:52<1:25:56,  2.56it/s]

torch.Size([1, 361])
torch.Size([1, 357])


Evaluating:   2%|▏         | 320/13544 [02:53<1:26:13,  2.56it/s]

torch.Size([1, 204])
torch.Size([1, 189])


Evaluating:   2%|▏         | 321/13544 [02:53<1:24:40,  2.60it/s]

torch.Size([1, 218])
torch.Size([1, 189])


Evaluating:   2%|▏         | 322/13544 [02:53<1:21:39,  2.70it/s]

In [6]:
print(f"Correct: {correct}/{total} ({correct/total*100:.2f}%)")

Correct: 4062/4180 (97.18%)


In [2]:
from tqdm import tqdm
import numpy as np
import gc
import torch

def get_rewards_batch(batch):
    prompts, solutions = zip(*batch)
    messages = [
        [{"role": "user", "content": prompt}, {"role": "assistant", "content": solution}]
        for prompt, solution in zip(prompts, solutions)
    ]
    encoded = tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = tokenizer(encoded, return_tensors='pt', truncation=True, max_length=4096, padding=True).to('cuda')
    with torch.no_grad():
        outputs = model(**inputs)
        rewards = outputs.logits[:, 0].tolist()
        del inputs, outputs
        return rewards

def precompute_rewards(ds, batch_size=10):
    rewards = []
    for i in tqdm(range(0, len(ds), batch_size), desc='Precomputing Rewards'):
        batch = list(zip(ds['problem'][i:i + batch_size], ds['accepted'][i:i + batch_size], ds['rejected'][i:i + batch_size]))
        accepted_batch = [(problem, accepted) for problem, accepted, _ in batch]
        rejected_batch = [(problem, rejected) for problem, _, rejected in batch]

        reward_accepted = get_rewards_batch(accepted_batch)
        reward_rejected = get_rewards_batch(rejected_batch)

        batch_rewards = list(zip(reward_accepted, reward_rejected))
        rewards.extend(batch_rewards)

        gc.collect()
        torch.cuda.empty_cache()

    return rewards

def evaluate_from_precomputed(rewards):
    correct = 0
    total = 0
    gaps = []

    for reward_accepted, reward_rejected in rewards:
        gap = reward_accepted - reward_rejected
        gaps.append(gap)
        if reward_accepted > reward_rejected:
            correct += 1
        total += 1

    return correct, total, gaps

# Example usage:
# ds should be a list of tuples in the form [(problem, accepted, rejected), ...]
batch_size = 8
precomputed_rewards = precompute_rewards(ds, batch_size=batch_size)
correct, total, gaps = evaluate_from_precomputed(precomputed_rewards)
print(f"Final Results: Correct: {correct}/{total} ({correct/total*100:.2f}%)")


NameError: name 'ds' is not defined

In [5]:
from transformers import AutoConfig, AutoModel
config = AutoConfig.from_pretrained("internlm/internlm2-7b")

config.json:   0%|          | 0.00/840 [00:00<?, ?B/s]

configuration_internlm2.py:   0%|          | 0.00/8.84k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/internlm/internlm2-7b:
- configuration_internlm2.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


In [10]:
config.auto_map['AutoModel'] = 'internlm/internlm2-7b--modeling_internlm2.InternLM2ForSequenceClassification'

'internlm/internlm2-7b--modeling_internlm2.InternLM2ForCausalLM'