In [1]:
from main_utils import *
# Run on the GPU when one is available; fall back to CPU so the rest of the
# notebook (e.g. torch.load(..., map_location=mydevice)) works on CPU-only hosts.
import torch
mydevice = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
def _read_judge_lines(path, data_format='plain', reverse_c=False):
    """Return the non-blank lines of ``path``, each stripped of trailing whitespace.

    When ``data_format == 'reverse'`` and ``reverse_c`` is set, every line is
    rewritten as ``<lhs>=<reversed rhs>?``: the text after '=' loses its last
    character (presumably the '?' terminator), is reversed, and gets '?'
    appended again.
    """
    with open(path, 'r', encoding='utf-8') as f:
        # Skip blank lines (e.g. an empty last line): they carry no label
        # character and would make the label extraction fail with IndexError.
        lines = [stripped for stripped in (raw.rstrip() for raw in f) if stripped]
    if data_format == 'reverse' and reverse_c:
        rewritten = []
        for line in lines:
            lhs, rhs = line.split('=')
            rewritten.append(f'{lhs}={rhs[:-1][::-1]}?')
        lines = rewritten
    return lines


def _group_prompts_by_length(lines, encode, device, data_format):
    """Encode each labelled line and bucket its metadata tuple by prompt length.

    Each line is ``<label><prompt>``: the first character is the ground-truth
    label and the rest is the prompt fed to the model (wrapped in '$...$' when
    ``data_format == 'reverse'``). Bucketing by encoded length lets every batch
    be stacked with ``torch.cat`` without padding.

    Returns ``{prompt_length: [(x, len(prompt), label, a, b, c, a_d, b_d, num_carry), ...]}``
    where ``x`` is a (1, prompt_length) LongTensor of token ids on ``device``.
    """
    prompt_dict = {}
    for line in tqdm(lines):
        label, prompt = line[0], line[1:]
        if data_format == 'reverse':
            prompt = '$' + prompt + '$'
        start_ids = encode(prompt)
        # Add a leading batch dimension: (len(start_ids),) -> (1, len(start_ids)).
        x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
        a, b, c, _op = get_abc(prompt)
        a_d, b_d = get_num_digits(a), get_num_digits(b)
        num_carry = numCarryOps(a, b)
        # NOTE: len(start_ids) != len(prompt) unless tokenization is character level.
        input_tuple = (x, len(prompt), label, a, b, c, a_d, b_d, num_carry)
        prompt_dict.setdefault(len(start_ids), []).append(input_tuple)
    return prompt_dict


def eval_judge_batch(config, model, ctx, encode, decode, operator='+', data_format='plain', reverse_c=False, num_digit=3, max_new_tokens=1, verbose=True):
    """Measure how often ``model`` judges labelled arithmetic prompts correctly.

    Every non-blank line of the prompt file is ``<label><prompt>`` with a
    one-character label ('T' or 'F'). The model continues each prompt, and the
    first generated character is compared with the label (presumably 'T'
    marks a correct equation and 'F' a wrong one).

    Args:
        config: dict with ``'start'`` (``'FILE:<path>'``) and ``'device'``;
            optional ``'test_batch_size'`` (default 128), ``'temperature'``
            (default 0.8) and ``'top_k'`` (default 200).
        model: object with ``eval()``, ``train()`` and
            ``generate(x, max_new_tokens, temperature=..., top_k=...)``, which is
            expected to return the prompt tokens followed by the generated ones.
        ctx: context manager wrapped around generation (e.g. ``nullcontext()``).
        encode: maps a string to a list of token ids.
        decode: maps a list of token ids back to a string.
        operator: unused; kept for call compatibility.
        data_format: 'plain', or 'reverse' to wrap each prompt in '$...$'.
        reverse_c: with data_format='reverse', reverse the text after '=' in
            every line before evaluation.
        num_digit: carry counts 0..num_digit each get an accuracy bucket.
        max_new_tokens: number of tokens to generate; only the first is judged.
        verbose: print every misjudged example (prompt plus generation).

    Returns:
        (pred_accuracy, no_judging_probability, accuracy_dictionary): two
        percentages over all examples, and per-carry-count accuracy in percent
        (NaN for buckets without examples).

    Raises:
        NotImplementedError: if ``config['start']`` does not start with 'FILE:'.
        ValueError: if the prompt file contains no non-blank lines.

    The model is switched to eval mode for the call and back to train mode
    afterwards, also when an exception is raised.
    """
    model.eval()
    try:
        start = config['start']
        device = config['device']
        test_batch_size = config.get('test_batch_size', 128)
        # Generation samples with temperature/top_k, so results vary between
        # runs unless the caller seeds torch's RNG.
        temperature = config.get('temperature', 0.8)
        top_k = config.get('top_k', 200)

        print(f'evaluating addition from: {start}')

        if not start.startswith('FILE:'):
            raise NotImplementedError("This method is not implemented yet!")
        data_path = start[len('FILE:'):]
        lines = _read_judge_lines(data_path, data_format, reverse_c)
        total = len(lines)
        if total == 0:
            raise ValueError(f'no prompts found in {data_path!r}')

        prompt_dict = _group_prompts_by_length(lines, encode, device, data_format)
        # Split each equal-length bucket into chunks of test_batch_size so that a
        # chunk concatenates into one (batch, prompt_length) tensor.
        batch_list = [
            tuples[i:i + test_batch_size]
            for tuples in prompt_dict.values()
            for i in range(0, len(tuples), test_batch_size)
        ]

        pred_correct = 0
        no_judge = 0
        # Confusion-matrix counts with 'T' as the positive class.
        TP = FP = TN = FN = 0
        carry_dictionary = {f'carry{i}_correct': 0 for i in range(num_digit + 1)}
        carry_dictionary.update({f'carry{i}_total': 0 for i in range(num_digit + 1)})

        for batch in tqdm(batch_list):
            x = torch.cat([input_tuple[0] for input_tuple in batch], dim=0)
            prompt_length = x.size(1)
            with torch.no_grad():
                with ctx:
                    y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
            for input_tuple, y_i in zip(batch, y):
                label, num_carry = input_tuple[2], input_tuple[8]
                outcome = decode(y_i.tolist())
                # Judge the first *generated* token; outcome[-1] would only match
                # it when exactly one single-character token is generated.
                pred = decode(y_i[prompt_length:].tolist())[:1]

                if pred not in ('T', 'F') or label not in ('T', 'F'):
                    # The model gave no usable judgement (or the label is
                    # malformed): counted separately, neither right nor wrong.
                    no_judge += 1
                elif pred == label:
                    pred_correct += 1
                    carry_dictionary[f'carry{num_carry}_correct'] += 1
                    if pred == 'T':
                        TP += 1
                    else:
                        TN += 1
                else:
                    if pred == 'T':
                        FP += 1
                    else:
                        FN += 1
                    if verbose:
                        print('wrong outputs(x): ', outcome)
                carry_dictionary[f'carry{num_carry}_total'] += 1
    finally:
        # Restore training mode even if evaluation raised.
        model.train()

    pred_accuracy = pred_correct / total * 100
    no_judging_probability = no_judge / total * 100

    accuracy_dictionary = {
        f'carry{i}': carry_dictionary[f'carry{i}_correct'] / carry_dictionary[f'carry{i}_total'] * 100
        if carry_dictionary[f'carry{i}_total'] != 0 else float('nan')
        for i in range(num_digit + 1)
    }

    print(f"Judgement accuracy of {total} examples: {pred_correct}/{total} ({pred_accuracy}%)")
    print(f"No judging probability of {total} examples: {no_judge}/{total} ({no_judging_probability}%)")
    print(f'True Positive Examples: {TP}/{total}')
    print(f'False Positive Examples: {FP}/{total}')
    print(f'True Negative Examples: {TN}/{total}')
    print(f'False Negative Examples: {FN}/{total}')
    print(accuracy_dictionary)

    return pred_accuracy, no_judging_probability, accuracy_dictionary

In [3]:
from model import GPTConfig, GPT
import torch

# Rebuild the GPT model from a checkpoint saved in a specific run directory.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
ckpt_path = 'out-check-bilabel-last-dig/ckpt_10000_acc.pt'
checkpoint = torch.load(ckpt_path, map_location=mydevice)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)

# Parameter names may carry an '_orig_mod.' prefix (presumably from a
# torch.compile-wrapped model); rename those entries in place so the keys
# match the plain GPT module.
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
prefixed_keys = [key for key in state_dict if key.startswith(unwanted_prefix)]
for old_key in prefixed_keys:
    state_dict[old_key[len(unwanted_prefix):]] = state_dict.pop(old_key)

model.load_state_dict(state_dict)
model.to(mydevice)
encode, decode = get_encode_decode('meta_all_ascii_chars.pkl')

number of parameters: 10.66M
Loading meta from meta_all_ascii_chars.pkl...


In [4]:
from contextlib import nullcontext
# nullcontext: generation runs without any wrapping context (no autocast).
ctx = nullcontext()
config = {
    'start': 'FILE:./test_data/easy_judge_prompt_noise_add_last_dig.txt',
    'device': mydevice,
}
# eval_judge_batch samples with temperature/top_k (defaults 0.8/200), so fix the
# torch RNG with an arbitrary seed to make the reported accuracies reproducible.
torch.manual_seed(1337)
eval_judge_batch(config, model, ctx, encode, decode, max_new_tokens=1, data_format='plain', reverse_c=False)

evaluating addition from: FILE:./test_data/easy_judge_prompt_noise_add_last_dig.txt


  0%|          | 0/10000 [00:00<?, ?it/s]

100%|██████████| 10000/10000 [00:00<00:00, 17730.71it/s]
 19%|█▊        | 15/81 [00:00<00:01, 50.41it/s]

wrong outputs(x):  87+877=965?T
wrong outputs(x):  65+130=196?T
wrong outputs(x):  645+32=678?T
wrong outputs(x):  90+130=221?T
wrong outputs(x):  881+45=927?T
wrong outputs(x):  792+65=858?T
wrong outputs(x):  58+856=915?T
wrong outputs(x):  88+451=540?T
wrong outputs(x):  185+88=274?T
wrong outputs(x):  58+894=953?T
wrong outputs(x):  59+735=795?T
wrong outputs(x):  453+94=548?T
wrong outputs(x):  319+68=388?T
wrong outputs(x):  915+33=949?T
wrong outputs(x):  173+26=200?T
wrong outputs(x):  854+57=912?T
wrong outputs(x):  16+353=370?T
wrong outputs(x):  261+63=325?T
wrong outputs(x):  18+980=999?T
wrong outputs(x):  99+457=557?T
wrong outputs(x):  30+220=251?T
wrong outputs(x):  816+87=904?T
wrong outputs(x):  86+799=885?F
wrong outputs(x):  204+22=227?T
wrong outputs(x):  813+24=838?T
wrong outputs(x):  34+615=650?T
wrong outputs(x):  199+33=233?T
wrong outputs(x):  98+883=982?T
wrong outputs(x):  16+334=351?T
wrong outputs(x):  401+11=412?F
wrong outputs(x):  814+20=835?T
wrong ou

 57%|█████▋    | 46/81 [00:00<00:00, 105.28it/s]

wrong outputs(x):  403+239=643?T
wrong outputs(x):  150+786=937?T
wrong outputs(x):  375+319=695?T
wrong outputs(x):  162+613=776?T
wrong outputs(x):  189+234=424?T
wrong outputs(x):  761+102=864?T
wrong outputs(x):  237+695=933?T
wrong outputs(x):  328+635=964?T
wrong outputs(x):  413+350=764?T
wrong outputs(x):  246+420=667?T
wrong outputs(x):  547+340=888?T
wrong outputs(x):  482+267=750?T
wrong outputs(x):  410+333=743?F
wrong outputs(x):  82+920=1003?T
wrong outputs(x):  215+132=348?T
wrong outputs(x):  310+685=996?T
wrong outputs(x):  390+389=780?T
wrong outputs(x):  178+636=815?T
wrong outputs(x):  470+195=666?T
wrong outputs(x):  702+109=812?T
wrong outputs(x):  138+163=302?T
wrong outputs(x):  159+624=784?T
wrong outputs(x):  204+701=906?T
wrong outputs(x):  513+146=660?T
wrong outputs(x):  141+551=693?T
wrong outputs(x):  463+101=565?T
wrong outputs(x):  830+119=950?T
wrong outputs(x):  303+348=652?T
wrong outputs(x):  425+327=752?F
wrong outputs(x):  288+112=401?T
wrong outp

100%|██████████| 81/81 [00:00<00:00, 98.92it/s] 

wrong outputs(x):  165+891=1056?F
wrong outputs(x):  119+932=1051?F
wrong outputs(x):  633+930=1564?T
wrong outputs(x):  858+946=1805?T
wrong outputs(x):  446+921=1368?T
wrong outputs(x):  657+725=1383?T
wrong outputs(x):  550+533=1084?T
wrong outputs(x):  238+790=1029?T
wrong outputs(x):  774+350=1125?T
wrong outputs(x):  995+323=1318?F
wrong outputs(x):  914+282=1197?T
wrong outputs(x):  453+846=1300?T
wrong outputs(x):  674+660=1335?T
wrong outputs(x):  972+396=1368?F
wrong outputs(x):  602+731=1334?T
wrong outputs(x):  103+988=1092?T
wrong outputs(x):  939+132=1072?T
wrong outputs(x):  665+616=1281?F
wrong outputs(x):  985+684=1670?T
wrong outputs(x):  914+933=1848?T
wrong outputs(x):  944+114=1059?T
wrong outputs(x):  162+957=1120?T
wrong outputs(x):  253+897=1151?T
wrong outputs(x):  868+169=1037?F
wrong outputs(x):  870+179=1050?T
wrong outputs(x):  805+203=1009?T
wrong outputs(x):  755+829=1585?T
wrong outputs(x):  898+716=1615?T
wrong outputs(x):  976+483=1460?T
wrong outputs(




(50.38,
 0.0,
 {'carry0': 48.49710982658959,
  'carry1': 50.60342408083076,
  'carry2': 51.51425178147269,
  'carry3': 49.36519790888723})