In [1]:
from main_utils import *
import torch

# Use the GPU when one is available; fall back to CPU so the notebook still runs without CUDA.
mydevice = 'cuda' if torch.cuda.is_available() else 'cpu'

In [16]:
def _read_judge_lines(path, data_format, reverse_c):
    """Read labelled judge prompts from `path`, one example per non-blank line.

    When data_format == 'reverse' and reverse_c is set, the right-hand side of
    each line is digit-reversed while keeping its trailing '?', e.g.
    'T5+7=12?' -> 'T5+7=21?'. Any other format keeps the line unchanged.
    """
    lines = []
    with open(path, 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.rstrip()
            if not line:
                # Blank lines (e.g. a trailing empty line) carry no example.
                continue
            if data_format == 'reverse' and reverse_c:
                lhs, rhs = line.split('=')
                # Drop the trailing '?', reverse the remaining characters, restore the '?'.
                line = f'{lhs}={rhs[:-1][::-1]}?'
            lines.append(line)
    return lines


def eval_judge_batch(config, model, ctx, encode, decode, operator='+', data_format='plain', reverse_c=False, num_digit=3, max_new_tokens=1, show_wrong=True):
    """Evaluate a model that judges labelled arithmetic statements as 'T' or 'F'.

    Each line of the prompt file starts with a one-character label ('T' is
    treated as positive, 'F' as negative; presumably "equation is correct" /
    "is wrong"), followed by the statement, e.g. 'T2+2=4?'. The label is
    stripped, the rest is fed to the model, and the first generated character
    is compared with the label.

    Args:
        config: dict with 'start' ('FILE:<path>' of the prompt file) and
            'device'; optional 'test_batch_size' (default 128), 'temperature'
            (default 0.8) and 'top_k' (default 200).
        model: object with eval(), train() and
            generate(x, max_new_tokens, temperature=..., top_k=...).
        ctx: context manager entered around generation.
        encode: str -> list of token ids.
        decode: list of token ids -> str.
        operator: unused; kept for interface compatibility.
        data_format: 'reverse' wraps each prompt in '$...$'; other values leave it as-is.
        reverse_c: with data_format='reverse', reverse the right-hand side before prompting.
        num_digit: per-carry accuracy is reported for carry counts 0..num_digit.
        max_new_tokens: tokens to generate; only the first generated character is the judgement.
        show_wrong: print every confidently wrong ('T' vs 'F') prediction.

    Returns:
        (pred_accuracy, no_judging_probability, accuracy_dictionary): the
        percentage of correct judgements, the percentage of examples with no
        'T'/'F' judgement (or an unrecognised label), and a dict mapping
        'carry{i}' to the accuracy (%) on examples with i carries (NaN when
        there are none). Both percentages are NaN when the file has no examples.

    Raises:
        NotImplementedError: if config['start'] does not start with 'FILE:'.
    """
    model.eval()
    try:
        start = config['start']
        device = config['device']
        test_batch_size = config.get('test_batch_size', 128)
        temperature = config.get('temperature', 0.8)
        top_k = config.get('top_k', 200)

        print(f'evaluating addition from: {start}')

        if not start.startswith('FILE:'):
            raise NotImplementedError("This method is not implemented yet!")
        lines = _read_judge_lines(start[5:], data_format, reverse_c)
        total = len(lines)

        # Bucket prompts by token length so every batch stacks into one tensor.
        # get_abc / get_num_digits / numCarryOps are helpers presumably provided by main_utils.
        prompt_dict = {}
        for line in tqdm(lines):
            label, prompt = line[0], line[1:]
            if data_format == 'reverse':
                prompt = '$' + prompt + '$'
            start_ids = encode(prompt)
            # Shape (1, len(start_ids)) so rows can be concatenated into a batch.
            x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
            a, b, c, _ = get_abc(prompt)
            a_d, b_d, num_carry = get_num_digits(a), get_num_digits(b), numCarryOps(a, b)
            prompt_dict.setdefault(len(start_ids), []).append(
                (x, len(prompt), label, a, b, c, a_d, b_d, num_carry))

        # Split every same-length bucket into chunks of at most test_batch_size.
        batch_list = [
            bucket[i:i + test_batch_size]
            for bucket in prompt_dict.values()
            for i in range(0, len(bucket), test_batch_size)
        ]

        pred_correct = 0
        no_judge = 0
        TP = 0  # True Positive
        FP = 0  # False Positive
        TN = 0  # True Negative
        FN = 0  # False Negative

        carry_dictionary = {f'carry{i}_correct': 0 for i in range(num_digit + 1)}
        carry_dictionary.update({f'carry{i}_total': 0 for i in range(num_digit + 1)})

        for batch in tqdm(batch_list):
            x = torch.cat([item[0] for item in batch], dim=0)
            prompt_len = x.size(1)
            with torch.no_grad(), ctx:
                y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
            for y_i, (_, _, label, a, b, c, a_d, b_d, num_carry) in zip(y, batch):
                outcome = decode(y_i.tolist())
                # Judge from the generated tokens only; '' if nothing was generated.
                pred = decode(y_i[prompt_len:].tolist())[:1]
                carry_dictionary[f'carry{num_carry}_total'] += 1
                if label not in ('T', 'F') or pred not in ('T', 'F'):
                    # Unrecognised label, or the model did not output a judgement.
                    no_judge += 1
                elif pred == label:
                    pred_correct += 1
                    carry_dictionary[f'carry{num_carry}_correct'] += 1
                    if label == 'T':
                        TP += 1
                    else:
                        TN += 1
                else:
                    if label == 'T':
                        FN += 1
                    else:
                        FP += 1
                    if show_wrong:
                        print('wrong outputs(x): ', outcome)

        pred_accuracy = pred_correct / total * 100 if total else float('nan')
        no_judging_probability = no_judge / total * 100 if total else float('nan')

        accuracy_dictionary = {}
        for i in range(num_digit + 1):
            carry_total = carry_dictionary[f'carry{i}_total']
            accuracy_dictionary[f'carry{i}'] = (
                carry_dictionary[f'carry{i}_correct'] / carry_total * 100
                if carry_total != 0 else float('nan'))

        print(f"Judgement accuracy of {total} examples: {pred_correct}/{total} ({pred_accuracy}%)")
        print(f"No judging probability of {total} examples: {no_judge}/{total} ({no_judging_probability}%)")
        print(f'True Positive Examples: {TP}/{total}')
        print(f'False Positive Examples: {FP}/{total}')
        print(f'True Negative Examples: {TN}/{total}')
        print(f'False Negative Examples: {FN}/{total}')
        print(accuracy_dictionary)

        return pred_accuracy, no_judging_probability, accuracy_dictionary
    finally:
        # Restore training mode even if evaluation raised.
        model.train()

In [17]:
from model import GPTConfig, GPT
import torch

# Rebuild the GPT from a checkpoint saved in a specific output directory.
ckpt_path = 'out-check-bilabel-label$/ckpt_10000_acc.pt'
checkpoint = torch.load(ckpt_path, map_location=mydevice)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)

# Some saved parameter names carry an '_orig_mod.' prefix (presumably left by
# torch.compile); strip it in place so the keys match the plain GPT module.
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for name in [key for key in state_dict if key.startswith(unwanted_prefix)]:
    state_dict[name[len(unwanted_prefix):]] = state_dict.pop(name)

model.load_state_dict(state_dict)
model.to(mydevice)

# Tokenizer functions built from the vocabulary metadata file.
encode, decode = get_encode_decode('meta_all_ascii_chars.pkl')

number of parameters: 10.66M
Loading meta from meta_all_ascii_chars.pkl...


In [19]:
from contextlib import nullcontext
# nullcontext: generation runs without any extra context (no autocast).
ctx = nullcontext()
config = dict(
    start='FILE:./test_data/extra_num_judge_prompt.txt',
    device=mydevice,
)
eval_judge_batch(
    config, model, ctx, encode, decode,
    data_format='reverse', reverse_c=True, max_new_tokens=1,
)

evaluating addition from: FILE:./test_data/extra_num_judge_prompt.txt


  0%|          | 0/10000 [00:00<?, ?it/s]

100%|██████████| 10000/10000 [00:00<00:00, 13291.06it/s]
 30%|███       | 25/82 [00:00<00:00, 121.80it/s]

wrong outputs(x):  $836+149=589?$F
wrong outputs(x):  $332+507=938?$F
wrong outputs(x):  $395+39=1434?$T
wrong outputs(x):  $419+54=6374?$T
wrong outputs(x):  $641+47=8868?$T
wrong outputs(x):  $579+135=417?$F
wrong outputs(x):  $241+39=4082?$T
wrong outputs(x):  $920+24=6449?$T
wrong outputs(x):  $245+29=2472?$T
wrong outputs(x):  $516+252=867?$F
wrong outputs(x):  $312+156=864?$F
wrong outputs(x):  $338+181=915?$F
wrong outputs(x):  $50+207=4752?$T
wrong outputs(x):  $129+520=946?$F
wrong outputs(x):  $219+502=127?$F
wrong outputs(x):  $646+102=847?$F
wrong outputs(x):  $835+160=599?$F
wrong outputs(x):  $319+158=774?$F
wrong outputs(x):  $13+655=6866?$T
wrong outputs(x):  $49+570=9163?$T
wrong outputs(x):  $111+834=549?$F
wrong outputs(x):  $319+148=764?$F
wrong outputs(x):  $29+960=7989?$T
wrong outputs(x):  $549+160=907?$F
wrong outputs(x):  $124+449=375?$F
wrong outputs(x):  $170+68=8329?$T
wrong outputs(x):  $680+140=028?$F
wrong outputs(x):  $81+874=3559?$T
wrong outputs(x):  $

 46%|████▋     | 38/82 [00:00<00:00, 116.73it/s]

wrong outputs(x):  $962+3=7569?$T
wrong outputs(x):  $7+473=3084?$T
wrong outputs(x):  $218+7=7522?$T
wrong outputs(x):  $81+300=183?$F
wrong outputs(x):  $518+66=485?$F
wrong outputs(x):  $699+16=517?$F
wrong outputs(x):  $65+59=7421?$T
wrong outputs(x):  $30+439=964?$F
wrong outputs(x):  $622+68=096?$F
wrong outputs(x):  $752+40=297?$F
wrong outputs(x):  $283+69=253?$F
wrong outputs(x):  $875+86=169?$F
wrong outputs(x):  $501+96=795?$F
wrong outputs(x):  $241+73=413?$F
wrong outputs(x):  $9+467=3674?$T
wrong outputs(x):  $79+297=673?$F
wrong outputs(x):  $732+96=828?$F
wrong outputs(x):  $851+2=6358?$T
wrong outputs(x):  $1+871=1278?$T
wrong outputs(x):  $23+473=694?$F
wrong outputs(x):  $79+694=377?$F
wrong outputs(x):  $52+57=5901?$T
wrong outputs(x):  $263+36=992?$F
wrong outputs(x):  $430+2=7234?$T
wrong outputs(x):  $2+496=1894?$T
wrong outputs(x):  $64+190=452?$F
wrong outputs(x):  $813+3=3618?$T
wrong outputs(x):  $55+56=6111?$T
wrong outputs(x):  $890+29=919?$F
wrong outputs(

 77%|███████▋  | 63/82 [00:00<00:00, 117.85it/s]

wrong outputs(x):  $530+224=4579?$T
wrong outputs(x):  $800+996=6971?$F
wrong outputs(x):  $628+312=0495?$T
wrong outputs(x):  $237+224=2164?$T
wrong outputs(x):  $400+506=6097?$T
wrong outputs(x):  $963+172=5311?$F
wrong outputs(x):  $327+957=4821?$F
wrong outputs(x):  $989+372=1631?$F
wrong outputs(x):  $417+363=2087?$T
wrong outputs(x):  $702+751=3541?$F
wrong outputs(x):  $880+349=9221?$F
wrong outputs(x):  $297+309=7606?$T
wrong outputs(x):  $146+367=3154?$T
wrong outputs(x):  $590+185=2577?$T
wrong outputs(x):  $855+656=1151?$F
wrong outputs(x):  $245+272=1715?$T
wrong outputs(x):  $492+819=1131?$F
wrong outputs(x):  $466+826=2921?$F
wrong outputs(x):  $361+181=1245?$T
wrong outputs(x):  $896+155=1501?$F
wrong outputs(x):  $560+186=8647?$T
wrong outputs(x):  $203+583=4687?$T
wrong outputs(x):  $553+559=2111?$F
wrong outputs(x):  $713+699=2141?$F
wrong outputs(x):  $392+296=8862?$T
wrong outputs(x):  $357+454=1188?$T
wrong outputs(x):  $465+852=7131?$F
wrong outputs(x):  $376+969=

100%|██████████| 82/82 [00:00<00:00, 119.48it/s]

wrong outputs(x):  $784+915=9961?$F
wrong outputs(x):  $104+728=5238?$T
wrong outputs(x):  $273+184=7546?$T
wrong outputs(x):  $779+367=6411?$F
wrong outputs(x):  $217+450=7663?$T
wrong outputs(x):  $474+514=8894?$T
wrong outputs(x):  $510+325=5388?$T
wrong outputs(x):  $393+973=6631?$F
wrong outputs(x):  $385+225=0162?$T
wrong outputs(x):  $449+977=6241?$F
wrong outputs(x):  $597+230=8728?$T
wrong outputs(x):  $582+953=5351?$F
wrong outputs(x):  $588+806=4931?$F
wrong outputs(x):  $342+796=8311?$F
wrong outputs(x):  $967+303=0721?$F
wrong outputs(x):  $588+160=9847?$T
wrong outputs(x):  $211+855=6601?$F
wrong outputs(x):  $120+686=8608?$T
wrong outputs(x):  $840+589=9241?$F
wrong outputs(x):  $197+772=8969?$T
wrong outputs(x):  $113+402=6515?$T
wrong outputs(x):  $621+551=2711?$F
wrong outputs(x):  $889+587=6741?$F
wrong outputs(x):  $340+208=7845?$T
wrong outputs(x):  $662+528=0911?$F
wrong outputs(x):  $427+818=5421?$F
wrong outputs(x):  $595+183=4877?$T
wrong outputs(x):  $273+345=




(50.029999999999994,
 0.0,
 {'carry0': 49.644549763033176,
  'carry1': 48.88080581980974,
  'carry2': 50.13231402528668,
  'carry3': 53.32834704562453})

In [20]:
from contextlib import nullcontext
# nullcontext: generation runs without any extra context (no autocast).
ctx = nullcontext()
config = dict(
    start='FILE:./test_data/add_noise_judge_prompt.txt',
    device=mydevice,
)
eval_judge_batch(
    config, model, ctx, encode, decode,
    data_format='reverse', reverse_c=True, max_new_tokens=1,
)

evaluating addition from: FILE:./test_data/add_noise_judge_prompt.txt


100%|██████████| 10000/10000 [00:00<00:00, 14078.62it/s]
 16%|█▌        | 13/81 [00:00<00:00, 120.88it/s]

wrong outputs(x):  $836+149=589?$F
wrong outputs(x):  $116+507=356?$T
wrong outputs(x):  $162+465=707?$T
wrong outputs(x):  $342+158=005?$F
wrong outputs(x):  $94+966=0601?$F
wrong outputs(x):  $204+727=139?$F
wrong outputs(x):  $332+507=938?$F
wrong outputs(x):  $641+47=8831?$T
wrong outputs(x):  $395+19=4146?$T
wrong outputs(x):  $241+39=0825?$T
wrong outputs(x):  $666+272=839?$F
wrong outputs(x):  $920+24=4481?$T
wrong outputs(x):  $516+252=867?$F
wrong outputs(x):  $338+181=915?$F
wrong outputs(x):  $32+621=3565?$T
wrong outputs(x):  $442+351=897?$T
wrong outputs(x):  $256+301=365?$T
wrong outputs(x):  $129+520=946?$F
wrong outputs(x):  $219+502=037?$T
wrong outputs(x):  $101+153=852?$T
wrong outputs(x):  $646+102=847?$F
wrong outputs(x):  $616+140=467?$T
wrong outputs(x):  $72+636=8001?$T
wrong outputs(x):  $415+233=866?$T
wrong outputs(x):  $13+655=8669?$T
wrong outputs(x):  $49+570=9165?$T
wrong outputs(x):  $29+960=9881?$T
wrong outputs(x):  $519+364=319?$T
wrong outputs(x):  $

 48%|████▊     | 39/81 [00:00<00:00, 120.74it/s]

wrong outputs(x):  $43+521=465?$F
wrong outputs(x):  $2+992=4901?$T
wrong outputs(x):  $1+992=3901?$T
wrong outputs(x):  $114+52=471?$T
wrong outputs(x):  $64+294=853?$F
wrong outputs(x):  $810+80=098?$F
wrong outputs(x):  $538+23=136?$T
wrong outputs(x):  $323+15=837?$T
wrong outputs(x):  $758+95=168?$T
wrong outputs(x):  $4+280=4824?$T
wrong outputs(x):  $88+881=969?$F
wrong outputs(x):  $545+77=226?$F
wrong outputs(x):  $159+86=545?$T
wrong outputs(x):  $55+318=373?$F
wrong outputs(x):  $732+67=928?$T
wrong outputs(x):  $41+334=504?$T
wrong outputs(x):  $798+80=388?$T
wrong outputs(x):  $15+405=724?$T
wrong outputs(x):  $37+771=908?$T
wrong outputs(x):  $71+263=433?$F
wrong outputs(x):  $449+42=194?$F
wrong outputs(x):  $55+326=183?$F
wrong outputs(x):  $59+409=864?$F
wrong outputs(x):  $70+163=332?$F
wrong outputs(x):  $38+794=238?$F
wrong outputs(x):  $81+251=243?$T
wrong outputs(x):  $833+9=2486?$T
wrong outputs(x):  $34+192=622?$F
wrong outputs(x):  $668+80=457?$T
wrong outputs(

 80%|████████  | 65/81 [00:00<00:00, 118.09it/s]

wrong outputs(x):  $893+690=3856?$T
wrong outputs(x):  $854+534=8541?$T
wrong outputs(x):  $277+503=0873?$T
wrong outputs(x):  $568+675=3402?$T
wrong outputs(x):  $697+505=2021?$F
wrong outputs(x):  $782+155=7399?$T
wrong outputs(x):  $434+601=7301?$T
wrong outputs(x):  $848+534=2832?$T
wrong outputs(x):  $959+168=7251?$T
wrong outputs(x):  $918+845=3652?$T
wrong outputs(x):  $306+517=3282?$T
wrong outputs(x):  $747+793=0471?$T
wrong outputs(x):  $350+666=4201?$T
wrong outputs(x):  $749+470=9121?$F
wrong outputs(x):  $157+701=8583?$T
wrong outputs(x):  $901+409=1131?$T
wrong outputs(x):  $400+758=8511?$F
wrong outputs(x):  $697+818=5551?$T
wrong outputs(x):  $774+997=1771?$F
wrong outputs(x):  $980+212=2911?$F
wrong outputs(x):  $963+231=4911?$F
wrong outputs(x):  $455+801=6521?$F
wrong outputs(x):  $778+987=5673?$T
wrong outputs(x):  $506+892=8931?$F
wrong outputs(x):  $424+596=0201?$F
wrong outputs(x):  $682+324=6901?$T
wrong outputs(x):  $315+295=0121?$T
wrong outputs(x):  $717+456=

100%|██████████| 81/81 [00:00<00:00, 120.00it/s]

wrong outputs(x):  $286+835=12113?$T
wrong outputs(x):  $533+842=57314?$T
wrong outputs(x):  $841+566=70416?$T
wrong outputs(x):  $905+478=38315?$T
wrong outputs(x):  $710+728=83413?$T
wrong outputs(x):  $557+659=61201?$T
wrong outputs(x):  $875+734=90601?$T
wrong outputs(x):  $190+991=18113?$T
wrong outputs(x):  $749+773=22513?$T
wrong outputs(x):  $497+981=87416?$T
wrong outputs(x):  $964+107=17014?$T
wrong outputs(x):  $961+297=85216?$T
wrong outputs(x):  $749+343=29016?$T
wrong outputs(x):  $355+778=33119?$T
wrong outputs(x):  $717+603=02319?$T
wrong outputs(x):  $887+653=04501?$T
wrong outputs(x):  $309+880=98114?$T
wrong outputs(x):  $509+576=58012?$T
wrong outputs(x):  $366+655=12001?$T
wrong outputs(x):  $940+644=48514?$T
wrong outputs(x):  $406+709=51113?$T
wrong outputs(x):  $751+556=70314?$T
wrong outputs(x):  $911+979=09812?$T
wrong outputs(x):  $449+797=64219?$T
wrong outputs(x):  $992+227=91218?$T
wrong outputs(x):  $909+295=40219?$T
wrong outputs(x):  $776+408=48112?$T
w




(50.3,
 0.0,
 {'carry0': 48.992890995260666,
  'carry1': 51.42697257974258,
  'carry2': 49.89708909144369,
  'carry3': 49.962602842184})