使用addtion-only训练的model，在test数据集上做推理，将结果保存为`answer_1st.txt`

先加载prompt

In [2]:
from main_utils import *
from model import GPT, GPTConfig

device = 'cuda'
meta_path = 'meta_all_ascii_chars.pkl'
encode, decode = get_encode_decode(meta_path)

input_data_path = 'data/bal/test_3digit_10000.txt'

test_data_list = get_data_list(input_data_path, operator='+', judge=False, test=True)
test_data_str = generate_data_str(test_data_list, operator='+', format='eval_format', train=False, shuffle=True, judge=False, label_exp=False)
      
lines = test_data_str.split('\n')[:-1]
total = len(lines)
inputs = []

for line_idx in tqdm(range(total)):
    #line_idx是所取出算式的index，取出对应行line
    line = lines[line_idx]
    line.strip('\n')
    # 对line这个string做编码
    start_ids = encode(line)
    # 将编码转换为张量，并额外加一个维度，从len(start_ids)变为(1,len(start_ids))
    x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
    inputs.append(x)


ckpt_path = 'test_out/out-check-add-only-eval/ckpt_10000_acc.pt'

checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)

state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)
model.to(device)

results = []
max_new_tokens = 3 + 2


for i, x in tqdm(enumerate(inputs)):
    model.eval()
    with torch.no_grad():
        y = model.generate(x, max_new_tokens, temperature=0.8)
        output = [decode(y_i.tolist()) for y_i in y]
        outcome = output[0].split('\n')[0]
        results.append(output)
        
import re

real_results = []
nums = [2000, 4000, 6000, 8000, 10000]
for i, line in enumerate(results):
    # 匹配如e(a+b):c的子串，提取出a,b,c的值
    match = re.search(r"e(?P<a>\d+)\+(?P<b>\d+):(?P<c>\d+)", line[0])
    # 如果匹配成功，提取出a,b,c的值
    if match:
        a = int(match.group("a"))
        b = int(match.group("b"))
        c_hat = int(match.group("c"))
        c = a + b
        print(f"a={a}, b={b}, c_hat={c_hat}, c={c}")
        # 如果c_hat等于c，说明模型计算正确, 按照正常思路构造两种错误
        if c_hat != c:
            # real_results.append(f'F{a}+{b}={c_hat}\n')
            real_results.append(f'j{a}+{b}={c_hat}~F\n')
 
        real_results.append(f'T{a}+{b}={c}\n')
        real_results.append(f'j{a}+{b}={c}~T\n')
        if random.uniform(0,1)>0.7:
            # 其中一半负样本为extra number
            flag = random.uniform(0,1)
            extra = random.randint(1, 9)
            if flag > 0.5:
                # real_results.append(f'F{a}+{b}={c}{extra}\n')
                real_results.append(f'j{a}+{b}={c}{extra}~F\n')
            else:
                # real_results.append(f'F{a}+{b}={extra}{c}\n')
                real_results.append(f'j{a}+{b}={extra}{c}~F\n')
            # 另一半样本为W1
            wrong_loc = random.randint(0, 3)
            addend = random.randint(1, 9)
            new_result = c + addend * (10**wrong_loc)
            # real_results.append(f'F{a}+{b}={new_result}\n')
            real_results.append(f'j{a}+{b}={new_result}~F\n')
            if i in nums:
                # 写入文件
                output_path = f'answer_{i}_W1.txt'
                with open(output_path, 'w') as f:
                    f.writelines(real_results)
                    print(f"写入{output_path}成功！")
    else:
        print("未找到匹配的子串")

Loading meta from meta_all_ascii_chars.pkl...


  0%|          | 0/10000 [00:00<?, ?it/s]

100%|██████████| 10000/10000 [00:05<00:00, 1675.95it/s]


number of parameters: 10.66M


10000it [03:23, 49.19it/s]


a=767, b=493, c_hat=1260, c=1260
a=319, b=504, c_hat=823, c=823
a=49, b=147, c_hat=196, c=196
a=414, b=417, c_hat=831, c=831
a=372, b=963, c_hat=1335, c=1335
a=3, b=515, c_hat=518, c=518
a=326, b=645, c_hat=971, c=971
a=353, b=688, c_hat=1041, c=1041
a=84, b=28, c_hat=112, c=112
a=636, b=669, c_hat=1305, c=1305
a=106, b=992, c_hat=1098, c=1098
a=412, b=322, c_hat=734, c=734
a=623, b=107, c_hat=730, c=730
a=899, b=828, c_hat=1727, c=1727
a=889, b=953, c_hat=1842, c=1842
a=89, b=358, c_hat=447, c=447
a=212, b=187, c_hat=399, c=399
a=667, b=171, c_hat=838, c=838
a=209, b=401, c_hat=610, c=610
a=45, b=75, c_hat=120, c=120
a=305, b=554, c_hat=859, c=859
a=166, b=252, c_hat=418, c=418
a=692, b=582, c_hat=1274, c=1274
a=835, b=985, c_hat=1820, c=1820
a=633, b=126, c_hat=759, c=759
a=820, b=362, c_hat=1182, c=1182
a=772, b=116, c_hat=888, c=888
a=649, b=498, c_hat=1147, c=1147
a=180, b=476, c_hat=656, c=656
a=846, b=840, c_hat=1686, c=1686
a=845, b=986, c_hat=1831, c=1831
a=775, b=524, c_hat=1

In [11]:
import torch
from main_utils import *
from model import GPT, GPTConfig
import re
import random

def randint_exclude(start, end, exclude):
    while True:
        num = random.randint(start, end)
        if num != exclude:
            return num
        
def replace_char_at_index(orig_string, index, new_char):
    new_string = orig_string[:index] + new_char + orig_string[index+1:]
    return new_string

def generate_results(input_data_path, ckpt_path, device='cuda', max_new_tokens=5, mode=1, temperature=0.8):
    meta_path = 'meta_all_ascii_chars.pkl'
    encode, decode = get_encode_decode(meta_path)

    test_data_list = get_data_list(input_data_path, operator='+', judge=False, test=True)
    test_data_str = generate_data_str(test_data_list, operator='+', format='eval_format', train=False, shuffle=True, judge=False, label_exp=False)

    lines = test_data_str.split('\n')[:-1]
    total = len(lines)
    inputs = []

    for line_idx in tqdm(range(total)):
        line = lines[line_idx]
        line.strip('\n')
        start_ids = encode(line)
        x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
        inputs.append(x)

    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)

    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    model.to(device)

    results = []

    for i, x in tqdm(enumerate(inputs)):
        model.eval()
        with torch.no_grad():
            y = model.generate(x, max_new_tokens, temperature=temperature)
            output = [decode(y_i.tolist()) for y_i in y]
            outcome = output[0].split('\n')[0]
            results.append(output)

    real_results = []
    nums = [2000, 4000, 6000, 8000, 10000]
    for i, line in enumerate(results):
        match = re.search(r"e(?P<a>\d+)\+(?P<b>\d+):(?P<c>\d+)", line[0])
        if match:
            a = int(match.group("a"))
            b = int(match.group("b"))
            c_hat = int(match.group("c"))
            c = a + b
            if c_hat != c:
                real_results.append(f'j{a}+{b}={c_hat}~F\n')

            real_results.append(f'T{a}+{b}={c}\n')
            real_results.append(f'j{a}+{b}={c}~T\n')
            if random.uniform(0,1)>0.7:
                flag = random.uniform(0,1)
                extra = random.randint(1, 9)
                if flag > 0.5:
                    real_results.append(f'j{a}+{b}={c}{extra}~F\n')
                else:
                    real_results.append(f'j{a}+{b}={extra}{c}~F\n')
                if mode == 1:
                    wrong_loc = random.randint(0, 3)
                    addend = random.randint(1, 9)
                    new_result = c + addend * (10**wrong_loc)
                    real_results.append(f'j{a}+{b}={new_result}~F\n')
                else:
                    # Wrong number 2
                    char_c = str(c)
                    n = len(char_c)
                    wrong_loc = random.randint(0, n-1)
                    changed_number = char_c[wrong_loc]
                    if wrong_loc == 0:
                        addend = randint_exclude(1, 9, int(changed_number))
                    else:
                        addend = randint_exclude(0, 9, int(changed_number))
                            
                    char_c = replace_char_at_index(char_c, wrong_loc, str(addend))
                    new_result = int(char_c)
                    real_results.append(f'j{a}+{b}={new_result}~F\n')
                wrong_loc = random.randint(0, 3)
                addend = random.randint(1, 9)
                new_result = c + addend * (10**wrong_loc)
                real_results.append(f'j{a}+{b}={new_result}~F\n')
        else:
            print("未找到匹配的子串")
        # 若到达指定的i值，则写入文件
        if i in nums:
            output_path = f'answer_{i}_W_{mode}.txt'
            with open(output_path, 'w') as f:
                f.writelines(real_results)
                print(f"写入{output_path}成功！")
                
    output_path = f'answer_10000_W_{mode}.txt'
    with open(output_path, 'w') as f:
        f.writelines(real_results)
        print(f"写入{output_path}成功！")

    return real_results

In [12]:
ckpt_path = 'test_out/out-check-add-only-eval/ckpt_10000_acc.pt'
input_data_path = 'data/bal/test_3digit_10000.txt'
for W in [1, 2]:
    generate_results(input_data_path, ckpt_path, device='cuda', max_new_tokens=5, mode=W, temperature=0.8)

Loading meta from meta_all_ascii_chars.pkl...


  8%|▊         | 828/10000 [00:00<00:01, 8278.69it/s]

100%|██████████| 10000/10000 [00:00<00:00, 26976.38it/s]


number of parameters: 10.66M


10000it [03:17, 50.55it/s]


写入answer_2000_W_1.txt成功！
未找到匹配的子串
写入answer_4000_W_1.txt成功！
写入answer_6000_W_1.txt成功！
写入answer_8000_W_1.txt成功！
写入answer_10000_W_1.txt成功！
Loading meta from meta_all_ascii_chars.pkl...


100%|██████████| 10000/10000 [00:00<00:00, 11897.31it/s]


number of parameters: 10.66M


10000it [03:28, 48.00it/s]


写入answer_2000_W_2.txt成功！
写入answer_4000_W_2.txt成功！
写入answer_6000_W_2.txt成功！
写入answer_8000_W_2.txt成功！
写入answer_10000_W_2.txt成功！


: 