# 这个notebook用来测试addition和judgement

## 创建测试数据

In [1]:
from main_utils import *
mydevice = 'mps'



主要使用了3种测试prompt
- 为了检测model的计算能力，所有input以`T`开头，即我们只要求输出正确的计算结果，但同时我们会让模型多输出一位（判断结果是否正确），形如: `T1+2=`
- 为了检测model在extra number型错误上的判断能力，按照训练时的方法制作negative instances。只要求模型输出一位，prompt形如：`1+2=3`和`2+3=58`
- 为了检测model在其它错误上的判断能力，制作negative instances时以均匀分布对正确结果的其中一位采样，在该位上加1～9的随机数。只要求模型输出一位

In [2]:
import os
import random
import re

class create_test_data:
    def __init__(self, non_overlap_data_path, num_test_samples=100) -> None:
        if not os.path.exists(non_overlap_data_path):
            raise ValueError("There is no nonoverlap data file")
        self.non_overlap_data_path = non_overlap_data_path
        self.num_test_samples = num_test_samples
        self.samples = None
        # Open non-overlap data path to get a bunch of test samples
        with open(self.non_overlap_data_path, 'r') as f:
            lines = f.readlines()
            random.shuffle(lines)
            self.samples = lines[:self.num_test_samples]
            
        
    def create_prompt(self):
        """
        To create prompt data file like: 'Ta+b=' .
        The output should be correct answer and judgement
        """
        with open(f'prompt.txt', 'w') as f2:
            for line in self.samples:
                prompt = line.split('=')[0]+'=\n'
                f2.write(prompt)
                    
    def create_add_noise_judge_prompt(self):
        """
        To create prompt data file like: 'a+b=c' and 'a+b=d', where 'd' means wrong answer
        The output should be judgement
        """
        with open('add_noise_judge_prompt.txt', 'w') as f3:
            for line in self.samples:
                # 取出表达式部分
                prompt = line.split('T')[1]
                # prompt = prompt.split('=')[0].strip()
                new_prompt = self.modify_result(prompt, random.randint(1, 9), 'noise_add')
                f3.write(new_prompt + '\n')
                
    def create_extra_num_judge_prompt(self):
        """
        To create prompt data file like: 'a+b=c' and 'a+b=cd', where 'd' means extra number
        The output should be judgement
        """
        with open('extra_num_judge_prompt.txt', 'w') as f4:
            for line in self.samples:
                # 取出表达式部分
                prompt = line.split('T')[1]
                # prompt = prompt.split('=')[0].strip()
                new_prompt = self.modify_result(prompt, random.randint(1, 9), 'extra_num')
                f4.write(new_prompt + '\n')

    
    def modify_result(self, expression, addend, mode):
        # 使用正则表达式提取表达式中的数字
        match = re.match(r'(\d+)\+(\d+)=(\d+)', expression)
    
        if match:
            # 提取数字并计算新的结果
            num1 = int(match.group(1))
            num2 = int(match.group(2))
            result = int(match.group(3))
            num_digit = len(match.group(3).strip())
            if mode == 'extra_num':
                if random.uniform(0,1)>0.5 and random.uniform(0,1)<0.75:
                    extra = random.randint(1, 9)
                    new_expression = f"F{num1}+{num2}={result}{extra}"
                elif random.uniform(0,1)>0.75:
                    extra = random.randint(1, 9)
                    new_expression = f"F{num1}+{num2}={extra}{result}"
                else:
                    new_expression = f"T{num1}+{num2}={result}"
            elif mode == 'noise_add':
                if random.uniform(0,1)>0.5:
                    # 决定添加错误的位置
                    wrong_loc = random.randint(0, num_digit)
        
                    new_result = result + addend * (10**wrong_loc)

                    # 构建新的表达式
                    new_expression = f"F{num1}+{num2}={new_result}"
                else:
                    new_expression = f"T{num1}+{num2}={result}"
            
            else:
                return "Invalid modify pattern"
        
            return new_expression
        else:
            return "Invalid expression format"

In [3]:
non_overlap_data_path = './data/get_data_with_label/train_3digit_bilabeled10000_nonoverlap.txt'
num_examples = 10000
creator = create_test_data(non_overlap_data_path, num_examples)

In [4]:
creator.create_prompt()
creator.create_add_noise_judge_prompt()
creator.create_extra_num_judge_prompt()

## 加载模型

In [5]:
from model import GPTConfig, GPT
import torch

# init from a model saved in a specific directory
ckpt_path = 'bilabel_ckpt_acc.pt'
checkpoint = torch.load(ckpt_path, map_location=mydevice)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)

number of parameters: 10.63M


<All keys matched successfully>

In [6]:
# evaluation
config={
    'start': 'FILE:./data/addition_bilabel/prompt_3digit_10000.txt',
    'device': mydevice,
}

In [7]:
encode, decode = get_encode_decode('./data/addition_bilabel/meta.pkl')

Loading meta from ./data/addition_bilabel/meta.pkl...


测试模型

In [8]:
x = '213+199=412'
ids = encode(x)
input = (torch.tensor(ids, dtype=torch.long, device=mydevice)[None, ...])
model.to(device=mydevice)
output = model.generate(input, max_new_tokens=1)
decode(output[0].tolist())

'213+199=412T'

## 批量测试

从txt文件中读取prompt进行测试，主要分为两种测试：

- 正确性测试：只关注计算结果是否正确
- 判断测试：只关注判断结果是否正确

读取prompt.txt，测试正确率与判断正确率

In [9]:
from contextlib import nullcontext
ctx = nullcontext()
config={
    'start': 'FILE:./prompt.txt',
    'device': mydevice,
    'temperature': 0.8
}
eval_addition_batch(config, model, ctx, encode, decode, judge=True)

evaluating addition from: FILE:./prompt.txt


100%|██████████| 10000/10000 [00:02<00:00, 4372.09it/s]
100%|██████████| 80/80 [00:19<00:00,  4.00it/s]

Judgement accuracy of 10000 examples: 9332/10000 (93.32000000000001%)
accuracy of 10000 examples: 9326/10000 (93.26%)
{'carry0': 92.07800121876905, 'carry1': 90.0326797385621, 'carry2': 95.55223880597015, 'carry3': 97.83096484667165}





(93.32000000000001,
 93.26,
 {'carry0': 92.07800121876905,
  'carry1': 90.0326797385621,
  'carry2': 95.55223880597015,
  'carry3': 97.83096484667165})

读取extra_num_judge_prompt.txt，测试判断能力

In [10]:
ctx = nullcontext()
config={
    'start': 'FILE:./extra_num_judge_prompt.txt',
    'device': mydevice,
}
eval_judge_batch(config, model, ctx, encode, decode)

evaluating addition from: FILE:./extra_num_judge_prompt.txt


100%|██████████| 10000/10000 [00:02<00:00, 4183.97it/s]
100%|██████████| 82/82 [00:07<00:00, 10.96it/s]

Judgement accuracy of 10000 examples: 9300/10000 (93.0%)
No judging probability of 10000 examples: 592/10000 (5.92%)
{'carry0': 89.4576477757465, 'carry1': 92.12962962962963, 'carry2': 94.2089552238806, 'carry3': 96.70905011219148}





(93.0,
 5.92,
 {'carry0': 89.4576477757465,
  'carry1': 92.12962962962963,
  'carry2': 94.2089552238806,
  'carry3': 96.70905011219148})

读取add_noise_judge_prompt.txt，测试模型判断能力

In [11]:
ctx = nullcontext()
config={
    'start': 'FILE:./add_noise_judge_prompt.txt',
    'device': mydevice,
}
eval_judge_batch(config, model, ctx, encode, decode)

evaluating addition from: FILE:./add_noise_judge_prompt.txt


100%|██████████| 10000/10000 [00:02<00:00, 4330.17it/s]
100%|██████████| 83/83 [00:07<00:00, 11.74it/s]

Judgement accuracy of 10000 examples: 6074/10000 (60.74%)
No judging probability of 10000 examples: 1174/10000 (11.74%)
{'carry0': 54.96648385131018, 'carry1': 58.08823529411765, 'carry2': 63.67164179104478, 'carry3': 67.76364996260284}





(60.74,
 11.74,
 {'carry0': 54.96648385131018,
  'carry1': 58.08823529411765,
  'carry2': 63.67164179104478,
  'carry3': 67.76364996260284})