# 测试addition和judgement

## 创建测试数据

In [1]:
from main_utils import *
mydevice = 'cuda'

主要使用了3种测试prompt
- 为了检测model的计算能力，所有input以`T`开头，即我们只要求输出正确的计算结果，但同时我们会让模型多输出一位（判断结果是否正确），形如: `T1+2=`
- 为了检测model在extra number型错误上的判断能力，按照训练时的方法制作negative instances。只要求模型输出一位，prompt形如：`1+2=3`和`2+3=58`
- 为了检测model在其它错误上的判断能力，制作negative instances时以均匀分布对正确结果的其中一位采样，在该位上加1～9的随机数。只要求模型输出一位

In [5]:
import os
import random
import re

class create_test_data:
    def __init__(self, non_overlap_data_path, num_test_samples=100) -> None:
        if not os.path.exists(non_overlap_data_path):
            raise ValueError("There is no nonoverlap data file")
        self.non_overlap_data_path = non_overlap_data_path
        self.num_test_samples = num_test_samples
        self.samples = None
        # Open non-overlap data path to get a bunch of test samples
        with open(self.non_overlap_data_path, 'r') as f:
            lines = f.readlines()
            random.shuffle(lines)
            self.samples = lines[:self.num_test_samples]
            
        
    def create_prompt(self):
        """
        To create prompt data file like: 'Ta+b=' .
        The output should be correct answer and judgement
        """
        with open(f'prompt.txt', 'w') as f2:
            for line in self.samples:
                prompt = line.split('=')[0]+'=\n'
                f2.write(prompt)
                    
    def create_add_noise_judge_prompt(self):
        """
        To create prompt data file like: 'a+b=c' and 'a+b=d', where 'd' means wrong answer
        The output should be judgement
        """
        with open('add_noise_judge_prompt.txt', 'w') as f3:
            for line in self.samples:
                # 取出表达式部分
                prompt = line.split('T')[1]
                # prompt = prompt.split('=')[0].strip()
                new_prompt = self.modify_result(prompt, random.randint(1, 9), 'noise_add')
                f3.write(new_prompt + '\n')
                
    def create_extra_num_judge_prompt(self):
        """
        To create prompt data file like: 'a+b=c' and 'a+b=cd', where 'd' means extra number
        The output should be judgement
        """
        with open('extra_num_judge_prompt.txt', 'w') as f4:
            for line in self.samples:
                # 取出表达式部分
                prompt = line.split('T')[1]
                # prompt = prompt.split('=')[0].strip()
                new_prompt = self.modify_result(prompt, random.randint(1, 9), 'extra_num')
                f4.write(new_prompt + '\n')

    
    def modify_result(self, expression, addend, mode):
        # 使用正则表达式提取表达式中的数字
        match = re.match(r'(\d+)\+(\d+)=(\d+)', expression)
    
        if match:
            # 提取数字并计算新的结果
            num1 = int(match.group(1))
            num2 = int(match.group(2))
            result = int(match.group(3))
            num_digit = len(match.group(3).strip())
            if mode == 'extra_num':
                if random.uniform(0,1)>0.5 and random.uniform(0,1)<0.75:
                    extra = random.randint(1, 9)
                    new_expression = f"F{num1}+{num2}={result}{extra}"
                elif random.uniform(0,1)>0.75:
                    extra = random.randint(1, 9)
                    new_expression = f"F{num1}+{num2}={extra}{result}"
                else:
                    new_expression = f"T{num1}+{num2}={result}"
            elif mode == 'noise_add':
                if random.uniform(0,1)>0.5:
                    # 决定添加错误的位置
                    wrong_loc = random.randint(0, num_digit)
        
                    new_result = result + addend * (10**wrong_loc)

                    # 构建新的表达式
                    new_expression = f"F{num1}+{num2}={new_result}"
                else:
                    new_expression = f"T{num1}+{num2}={result}"
            
            else:
                return "Invalid modify pattern"
        
            return new_expression
        else:
            return "Invalid expression format"

In [3]:
non_overlap_data_path = './data/get_data_with_label/train_3digit_bilabeled10000_nonoverlap.txt'
num_examples = 10000
creator = create_test_data(non_overlap_data_path, num_examples)

In [4]:
creator.create_prompt()
creator.create_add_noise_judge_prompt()
creator.create_extra_num_judge_prompt()

## 加载模型

In [5]:
from model import GPTConfig, GPT
import torch

# init from a model saved in a specific directory
ckpt_path = 'bilabel_ckpt_acc.pt'
checkpoint = torch.load(ckpt_path, map_location=mydevice)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)

number of parameters: 10.63M


<All keys matched successfully>

In [6]:
# evaluation
config={
    'start': 'FILE:./data/addition_bilabel/prompt_3digit_10000.txt',
    'device': mydevice,
}

In [7]:
encode, decode = get_encode_decode('./data/addition_bilabel/meta.pkl')

Loading meta from ./data/addition_bilabel/meta.pkl...


测试模型

In [8]:
x = '213+199=412'
ids = encode(x)
input = (torch.tensor(ids, dtype=torch.long, device=mydevice)[None, ...])
model.to(device=mydevice)
output = model.generate(input, max_new_tokens=1)
decode(output[0].tolist())

'213+199=412T'

## 批量测试

从txt文件中读取prompt进行测试，主要分为两种测试：

- 正确性测试：只关注计算结果是否正确
- 判断测试：只关注判断结果是否正确

读取prompt.txt，测试正确率与判断正确率

In [9]:
from contextlib import nullcontext
ctx = nullcontext()
config={
    'start': 'FILE:./prompt.txt',
    'device': mydevice,
    'temperature': 0.8
}
eval_addition_batch(config, model, ctx, encode, decode, judge=True)

evaluating addition from: FILE:./prompt.txt


100%|██████████| 10000/10000 [00:02<00:00, 4554.07it/s]
100%|██████████| 81/81 [00:19<00:00,  4.12it/s]

Judgement accuracy of 10000 examples: 9271/10000 (92.71000000000001%)
accuracy of 10000 examples: 9267/10000 (92.67%)
{'carry0': 90.9090909090909, 'carry1': 89.44240022643646, 'carry2': 95.0613676212741, 'carry3': 97.15950473415877}





(92.71000000000001,
 92.67,
 {'carry0': 90.9090909090909,
  'carry1': 89.44240022643646,
  'carry2': 95.0613676212741,
  'carry3': 97.15950473415877})

读取extra_num_judge_prompt.txt，测试判断能力

In [10]:
ctx = nullcontext()
config={
    'start': 'FILE:./extra_num_judge_prompt.txt',
    'device': mydevice,
}
eval_judge_batch(config, model, ctx, encode, decode)

evaluating addition from: FILE:./extra_num_judge_prompt.txt


100%|██████████| 10000/10000 [00:02<00:00, 4552.71it/s]
100%|██████████| 82/82 [00:07<00:00, 11.50it/s]

Judgement accuracy of 10000 examples: 9231/10000 (92.31%)
No judging probability of 10000 examples: 650/10000 (6.5%)
{'carry0': 89.11483253588517, 'carry1': 91.14067364845741, 'carry2': 93.36645236703683, 'carry3': 96.5768390386016}





(92.31,
 6.5,
 {'carry0': 89.11483253588517,
  'carry1': 91.14067364845741,
  'carry2': 93.36645236703683,
  'carry3': 96.5768390386016})

读取add_noise_judge_prompt.txt，测试模型判断能力

In [11]:
ctx = nullcontext()
config={
    'start': 'FILE:./add_noise_judge_prompt.txt',
    'device': mydevice,
}
eval_judge_batch(config, model, ctx, encode, decode)

evaluating addition from: FILE:./add_noise_judge_prompt.txt


100%|██████████| 10000/10000 [00:02<00:00, 4224.60it/s]
100%|██████████| 82/82 [00:06<00:00, 12.37it/s]

Judgement accuracy of 10000 examples: 6094/10000 (60.940000000000005%)
No judging probability of 10000 examples: 1182/10000 (11.82%)
{'carry0': 53.88755980861244, 'carry1': 59.835833569204645, 'carry2': 62.85797779076563, 'carry3': 67.5892206846322}





(60.940000000000005,
 11.82,
 {'carry0': 53.88755980861244,
  'carry1': 59.835833569204645,
  'carry2': 62.85797779076563,
  'carry3': 67.5892206846322})

由此可见，
- 不需要进行预训练+微调，模型已经有判断正误的能力
- 训练集中negative instances的设置方式会对模型的判断能力产生较大影响

## 同时使用两类negative instances训练

首先创建测试数据

In [12]:
non_overlap_data_path = './data/get_data_with_label/train_3digit_bilabeled10000_nonoverlap_new.txt'
num_examples = 10000
new_creator = create_test_data(non_overlap_data_path, num_examples)

In [13]:
new_creator.create_prompt()
new_creator.create_extra_num_judge_prompt()
new_creator.create_add_noise_judge_prompt()

加载模型

In [14]:
from model import GPTConfig, GPT
import torch

# init from a model saved in a specific directory
ckpt_path = 'new_bilabel_ckpt_acc.pt'
checkpoint = torch.load(ckpt_path, map_location=mydevice)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)
model.to(mydevice)

number of parameters: 10.63M


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(15, 384)
    (wpe): Embedding(256, 384)
    (drop): Dropout(p=0.2, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=384, out_features=1152, bias=False)
          (c_proj): Linear(in_features=384, out_features=384, bias=False)
          (attn_dropout): Dropout(p=0.2, inplace=False)
          (resid_dropout): Dropout(p=0.2, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=384, out_features=1536, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=1536, out_features=384, bias=False)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=384, out_features=15, bias=False)
)

In [15]:
encode, decode = get_encode_decode('./data/addition_bilabel/meta.pkl')

Loading meta from ./data/addition_bilabel/meta.pkl...


测试prompt

In [16]:
from contextlib import nullcontext
ctx = nullcontext()
config={
    'start': 'FILE:./prompt.txt',
    'device': mydevice,
    'temperature': 0.8
}
eval_addition_batch(config, model, ctx, encode, decode, judge=True)

evaluating addition from: FILE:./prompt.txt


100%|██████████| 10000/10000 [00:02<00:00, 4252.95it/s]
100%|██████████| 81/81 [00:20<00:00,  3.94it/s]

Judgement accuracy of 10000 examples: 8378/10000 (83.78%)
accuracy of 10000 examples: 9389/10000 (93.89%)
{'carry0': 93.84615384615384, 'carry1': 92.77440706012135, 'carry2': 94.23476968796433, 'carry3': 96.13343442001516}





(83.78,
 93.89,
 {'carry0': 93.84615384615384,
  'carry1': 92.77440706012135,
  'carry2': 94.23476968796433,
  'carry3': 96.13343442001516})

测试extra number

In [17]:
ctx = nullcontext()
config={
    'start': 'FILE:./extra_num_judge_prompt.txt',
    'device': mydevice,
}
eval_judge_batch(config, model, ctx, encode, decode)

evaluating addition from: FILE:./extra_num_judge_prompt.txt


100%|██████████| 10000/10000 [00:02<00:00, 4511.77it/s]
100%|██████████| 84/84 [00:06<00:00, 12.28it/s]

Judgement accuracy of 10000 examples: 0/10000 (0.0%)
No judging probability of 10000 examples: 10000/10000 (100.0%)
{'carry0': 0.0, 'carry1': 0.0, 'carry2': 0.0, 'carry3': 0.0}





(0.0, 100.0, {'carry0': 0.0, 'carry1': 0.0, 'carry2': 0.0, 'carry3': 0.0})

测试add noise

In [18]:
ctx = nullcontext()
config={
    'start': 'FILE:./add_noise_judge_prompt.txt',
    'device': mydevice,
}
eval_judge_batch(config, model, ctx, encode, decode)

evaluating addition from: FILE:./add_noise_judge_prompt.txt


100%|██████████| 10000/10000 [00:02<00:00, 4279.48it/s]
100%|██████████| 82/82 [00:07<00:00, 11.18it/s]

Judgement accuracy of 10000 examples: 0/10000 (0.0%)
No judging probability of 10000 examples: 10000/10000 (100.0%)
{'carry0': 0.0, 'carry1': 0.0, 'carry2': 0.0, 'carry3': 0.0}





(0.0, 100.0, {'carry0': 0.0, 'carry1': 0.0, 'carry2': 0.0, 'carry3': 0.0})

模型更本没有做判断，这是因为同时存在extra number和add noise两种错误

## 同时使用两类negative instances训练, 0.8p和0.2n, space

首先创建测试数据
发现0.7p+0.3n表现过差，但是在改为0.8p+0.2n后得到了目前为止最好的表现

In [18]:
import os
import random
import re

class create_test_data:
    def __init__(self, non_overlap_data_path, num_test_samples=100) -> None:
        if not os.path.exists(non_overlap_data_path):
            raise ValueError("There is no nonoverlap data file")
        self.non_overlap_data_path = non_overlap_data_path
        self.num_test_samples = num_test_samples
        self.samples = None
        # Open non-overlap data path to get a bunch of test samples
        with open(self.non_overlap_data_path, 'r') as f:
            lines = f.readlines()
            random.shuffle(lines)
            self.samples = lines[:self.num_test_samples]
            
        
    def create_prompt(self):
        """
        To create prompt data file like: 'Ta+b=' .
        The output should be correct answer and judgement
        """
        with open(f'prompt.txt', 'w') as f2:
            for line in self.samples:
                prompt = line.split('=')[0]+'=\n'
                f2.write(prompt)
                    
    def create_add_noise_judge_prompt(self):
        """
        To create prompt data file like: 'a+b=c' and 'a+b=d', where 'd' means wrong answer
        The output should be judgement
        """
        with open('add_noise_judge_prompt.txt', 'w') as f3:
            for line in self.samples:
                # 取出表达式部分
                prompt = line.split('T')[1].strip()
                # prompt = prompt.split('=')[0].strip()
                new_prompt = self.modify_result(prompt, random.randint(1, 9), 'noise_add')
                f3.write(new_prompt + '\n')
                
    def create_extra_num_judge_prompt(self):
        """
        To create prompt data file like: 'a+b=c' and 'a+b=cd', where 'd' means extra number
        The output should be judgement
        """
        with open('extra_num_judge_prompt.txt', 'w') as f4:
            for line in self.samples:
                # 取出表达式部分
                prompt = line.split('T')[1].strip()
                # prompt = prompt.split('=')[0].strip()
                new_prompt = self.modify_result(prompt, random.randint(1, 9), 'extra_num')
                f4.write(new_prompt + '\n')

    
    def modify_result(self, expression, addend, mode):
        # 使用正则表达式提取表达式中的数字
        match = re.match(r'(\d+)\+(\d+)=(\d+)', expression)
    
        if match:
            # 提取数字并计算新的结果
            num1 = int(match.group(1))
            num2 = int(match.group(2))
            result = int(match.group(3))
            num_digit = len(match.group(3).strip())
            if mode == 'extra_num':
                if random.uniform(0,1)>0.5 and random.uniform(0,1)<0.75:
                    extra = random.randint(1, 9)
                    new_expression = f"F{num1}+{num2}={result}{extra}"
                elif random.uniform(0,1)>0.75:
                    extra = random.randint(1, 9)
                    new_expression = f"F{num1}+{num2}={extra}{result}"
                else:
                    new_expression = f"T{num1}+{num2}={result}"
            elif mode == 'noise_add':
                if random.uniform(0,1)>0.5:
                    # 决定添加错误的位置
                    wrong_loc = random.randint(0, num_digit)
        
                    new_result = result + addend * (10**wrong_loc)

                    # 构建新的表达式
                    new_expression = f"F{num1}+{num2}={new_result}"
                else:
                    new_expression = f"T{num1}+{num2}={result}"
            
            else:
                return "Invalid modify pattern"
        
            return new_expression
        else:
            return "Invalid expression format"

In [19]:
non_overlap_data_path = './data/get_data_with_label/train_3digit_bilabeled10000_nonoverlap_new_sp.txt'
num_examples = 10000
new_creator = create_test_data(non_overlap_data_path, num_examples)

In [20]:
new_creator.create_prompt()
new_creator.create_extra_num_judge_prompt()
new_creator.create_add_noise_judge_prompt()

加载模型

In [21]:
from model import GPTConfig, GPT
import torch

# init from a model saved in a specific directory
ckpt_path = 'bilabel_ckpt_acc_sp.pt'
checkpoint = torch.load(ckpt_path, map_location=mydevice)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)
model.to(mydevice)

number of parameters: 10.63M


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(16, 384)
    (wpe): Embedding(256, 384)
    (drop): Dropout(p=0.2, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=384, out_features=1152, bias=False)
          (c_proj): Linear(in_features=384, out_features=384, bias=False)
          (attn_dropout): Dropout(p=0.2, inplace=False)
          (resid_dropout): Dropout(p=0.2, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=384, out_features=1536, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=1536, out_features=384, bias=False)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=384, out_features=16, bias=False)
)

In [22]:
encode, decode = get_encode_decode('./data/addition_bilabel_sp/meta.pkl')

Loading meta from ./data/addition_bilabel_sp/meta.pkl...


In [10]:
from contextlib import nullcontext
ctx = nullcontext()
config={
    'start': 'FILE:./prompt.txt',
    'device': mydevice,
    'temperature': 0.8
}
eval_addition_batch(config, model, ctx, encode, decode, judge=True)

evaluating addition from: FILE:./prompt.txt


100%|██████████| 10000/10000 [00:02<00:00, 4454.05it/s]
100%|██████████| 81/81 [00:21<00:00,  3.85it/s]

Judgement accuracy of 10000 examples: 9454/10000 (94.54%)
accuracy of 10000 examples: 9457/10000 (94.57%)
{'carry0': 93.62354383813611, 'carry1': 92.82136894824707, 'carry2': 95.77960140679953, 'carry3': 97.28539985326485}





(94.54,
 94.57,
 {'carry0': 93.62354383813611,
  'carry1': 92.82136894824707,
  'carry2': 95.77960140679953,
  'carry3': 97.28539985326485})

测试extra number的judgement

In [23]:
from contextlib import nullcontext
ctx = nullcontext()
config={
    'start': 'FILE:./extra_num_judge_prompt.txt',
    'device': mydevice,
}
eval_judge_batch(config, model, ctx, encode, decode, max_new_tokens=2)

evaluating addition from: FILE:./extra_num_judge_prompt.txt


  0%|          | 0/10000 [00:00<?, ?it/s]

100%|██████████| 10000/10000 [00:02<00:00, 3971.68it/s]
100%|██████████| 83/83 [00:11<00:00,  7.22it/s]

Judgement accuracy of 10000 examples: 9386/10000 (93.86%)
No judging probability of 10000 examples: 243/10000 (2.4299999999999997%)
{'carry0': 93.17617866004963, 'carry1': 94.05286343612335, 'carry2': 93.48079161816065, 'carry3': 95.15151515151516}





(93.86,
 2.4299999999999997,
 {'carry0': 93.17617866004963,
  'carry1': 94.05286343612335,
  'carry2': 93.48079161816065,
  'carry3': 95.15151515151516})

测试add digit noise的judgement

In [24]:
ctx = nullcontext()
config={
    'start': 'FILE:./add_noise_judge_prompt.txt',
    'device': mydevice,
}
eval_judge_batch(config, model, ctx, encode, decode, max_new_tokens=2)

evaluating addition from: FILE:./add_noise_judge_prompt.txt


100%|██████████| 10000/10000 [00:02<00:00, 3948.66it/s]
100%|██████████| 82/82 [00:11<00:00,  7.06it/s]

Judgement accuracy of 10000 examples: 6498/10000 (64.98%)
No judging probability of 10000 examples: 432/10000 (4.32%)
{'carry0': 59.42928039702233, 'carry1': 63.298458149779734, 'carry2': 67.05471478463329, 'carry3': 70.98484848484848}





(64.98,
 4.32,
 {'carry0': 59.42928039702233,
  'carry1': 63.298458149779734,
  'carry2': 67.05471478463329,
  'carry3': 70.98484848484848})

## 使用新数据
生成10000个训练样本，其中40%的样本除了正确的结果与label还包括两类负样本

In [16]:
import os
import random
import re

class create_test_data:
    def __init__(self, non_overlap_data_path, num_test_samples=100) -> None:
        if not os.path.exists(non_overlap_data_path):
            raise ValueError("There is no nonoverlap data file")
        self.non_overlap_data_path = non_overlap_data_path
        self.num_test_samples = num_test_samples
        self.samples = None
        # Open non-overlap data path to get a bunch of test samples
        with open(self.non_overlap_data_path, 'r') as f:
            lines = f.readlines()
            random.shuffle(lines)
            self.samples = lines[:self.num_test_samples]
            
        
    def create_prompt(self):
        """
        To create prompt data file like: 'Ta+b=' .
        The output should be correct answer and judgement
        """
        with open(f'prompt.txt', 'w') as f2:
            for line in self.samples:
                prompt = line.split('=')[0]+'=\n'
                f2.write(prompt)
                    
    def create_add_noise_judge_prompt(self):
        """
        To create prompt data file like: 'a+b=c' and 'a+b=d', where 'd' means wrong answer
        The output should be judgement
        """
        with open('add_noise_judge_prompt.txt', 'w') as f3:
            for line in self.samples:
                # 取出表达式部分
                prompt = line.split('T')[1].strip()
                # prompt = prompt.split('=')[0].strip()
                new_prompt = self.modify_result(prompt, random.randint(1, 9), 'noise_add')
                f3.write(new_prompt + '?\n')
                
    def create_extra_num_judge_prompt(self):
        """
        To create prompt data file like: 'a+b=c' and 'a+b=cd', where 'd' means extra number
        The output should be judgement
        """
        with open('extra_num_judge_prompt.txt', 'w') as f4:
            for line in self.samples:
                # 取出表达式部分
                prompt = line.split('T')[1].strip()
                # prompt = prompt.split('=')[0].strip()
                new_prompt = self.modify_result(prompt, random.randint(1, 9), 'extra_num')
                f4.write(new_prompt + '?\n')

    
    def modify_result(self, expression, addend, mode):
        # 使用正则表达式提取表达式中的数字
        match = re.match(r'(\d+)\+(\d+)=(\d+)', expression)
    
        if match:
            # 提取数字并计算新的结果
            num1 = int(match.group(1))
            num2 = int(match.group(2))
            result = int(match.group(3))
            num_digit = len(match.group(3).strip())
            if mode == 'extra_num':
                if random.uniform(0,1)>0.5 and random.uniform(0,1)<0.75:
                    extra = random.randint(1, 9)
                    new_expression = f"F{num1}+{num2}={result}{extra}"
                elif random.uniform(0,1)>0.75:
                    extra = random.randint(1, 9)
                    new_expression = f"F{num1}+{num2}={extra}{result}"
                else:
                    new_expression = f"T{num1}+{num2}={result}"
            elif mode == 'noise_add':
                if random.uniform(0,1)>0.5:
                    # 决定添加错误的位置
                    wrong_loc = random.randint(0, num_digit)
        
                    new_result = result + addend * (10**wrong_loc)

                    # 构建新的表达式
                    new_expression = f"F{num1}+{num2}={new_result}"
                else:
                    new_expression = f"T{num1}+{num2}={result}"
            
            else:
                return "Invalid modify pattern"
        
            return new_expression
        else:
            return "Invalid expression format"

In [17]:
non_overlap_data_path = './data/get_data_with_label/train_3digit_bilabeled10000_nonoverlap.txt'
num_examples = 10000
new_creator = create_test_data(non_overlap_data_path, num_examples)
new_creator.create_prompt()
new_creator.create_extra_num_judge_prompt()
new_creator.create_add_noise_judge_prompt()

加载模型

In [18]:
from model import GPTConfig, GPT
import torch

# init from a model saved in a specific directory
ckpt_path = 'ckpt_acc_bilabel(0.6p).pt'
checkpoint = torch.load(ckpt_path, map_location=mydevice)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)
model.to(mydevice)

number of parameters: 10.63M


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(16, 384)
    (wpe): Embedding(256, 384)
    (drop): Dropout(p=0.2, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=384, out_features=1152, bias=False)
          (c_proj): Linear(in_features=384, out_features=384, bias=False)
          (attn_dropout): Dropout(p=0.2, inplace=False)
          (resid_dropout): Dropout(p=0.2, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=384, out_features=1536, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=1536, out_features=384, bias=False)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=384, out_features=16, bias=False)
)

In [19]:
encode, decode = get_encode_decode('./data/addition_bilabel/meta.pkl')

Loading meta from ./data/addition_bilabel/meta.pkl...


In [20]:
from contextlib import nullcontext
ctx = nullcontext()
config={
    'start': 'FILE:./prompt.txt',
    'device': mydevice,
    'temperature': 0.8
}
eval_addition_batch(config, model, ctx, encode, decode, judge=True)

evaluating addition from: FILE:./prompt.txt


  0%|          | 0/10000 [00:00<?, ?it/s]

100%|██████████| 10000/10000 [00:00<00:00, 12625.53it/s]
100%|██████████| 81/81 [00:01<00:00, 44.69it/s]

Judgement accuracy of 10000 examples: 9476/10000 (94.76%)
accuracy of 10000 examples: 9478/10000 (94.78%)
{'carry0': 95.66787003610109, 'carry1': 93.37062937062936, 'carry2': 95.57089444923785, 'carry3': 95.41213063763608}





(94.76,
 94.78,
 {'carry0': 95.66787003610109,
  'carry1': 93.37062937062936,
  'carry2': 95.57089444923785,
  'carry3': 95.41213063763608})

测试extra number

In [21]:
from contextlib import nullcontext
ctx = nullcontext()
config={
    'start': 'FILE:./extra_num_judge_prompt.txt',
    'device': mydevice,
}
eval_judge_batch(config, model, ctx, encode, decode, max_new_tokens=1)

evaluating addition from: FILE:./extra_num_judge_prompt.txt


  0%|          | 0/10000 [00:00<?, ?it/s]

100%|██████████| 10000/10000 [00:00<00:00, 13763.47it/s]
100%|██████████| 83/83 [00:00<00:00, 152.41it/s]

Judgement accuracy of 10000 examples: 9371/10000 (93.71000000000001%)
No judging probability of 10000 examples: 0/10000 (0.0%)
{'carry0': 92.35860409145607, 'carry1': 93.2027972027972, 'carry2': 94.36295657175727, 'carry3': 95.10108864696734}





(93.71000000000001,
 0.0,
 {'carry0': 92.35860409145607,
  'carry1': 93.2027972027972,
  'carry2': 94.36295657175727,
  'carry3': 95.10108864696734})

测试add noise

In [22]:
ctx = nullcontext()
config={
    'start': 'FILE:./add_noise_judge_prompt.txt',
    'device': mydevice,
}
eval_judge_batch(config, model, ctx, encode, decode, max_new_tokens=1)

evaluating addition from: FILE:./add_noise_judge_prompt.txt


100%|██████████| 10000/10000 [00:00<00:00, 11979.32it/s]
100%|██████████| 82/82 [00:00<00:00, 145.56it/s]

Judgement accuracy of 10000 examples: 6717/10000 (67.17%)
No judging probability of 10000 examples: 0/10000 (0.0%)
{'carry0': 62.15403128760529, 'carry1': 65.65034965034965, 'carry2': 68.99626114466494, 'carry3': 72.93934681181959}





(67.17,
 0.0,
 {'carry0': 62.15403128760529,
  'carry1': 65.65034965034965,
  'carry2': 68.99626114466494,
  'carry3': 72.93934681181959})

使用best judgement acc model测试

In [23]:
from model import GPTConfig, GPT
import torch

# init from a model saved in a specific directory
ckpt_path = 'ckpt_judge_acc_bilabel(0.6p).pt'
checkpoint = torch.load(ckpt_path, map_location=mydevice)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)
model.to(mydevice)

number of parameters: 10.63M


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(16, 384)
    (wpe): Embedding(256, 384)
    (drop): Dropout(p=0.2, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=384, out_features=1152, bias=False)
          (c_proj): Linear(in_features=384, out_features=384, bias=False)
          (attn_dropout): Dropout(p=0.2, inplace=False)
          (resid_dropout): Dropout(p=0.2, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=384, out_features=1536, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=1536, out_features=384, bias=False)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=384, out_features=16, bias=False)
)

In [24]:
encode, decode = get_encode_decode('./data/addition_bilabel/meta.pkl')

Loading meta from ./data/addition_bilabel/meta.pkl...


In [25]:
from contextlib import nullcontext
ctx = nullcontext()
config={
    'start': 'FILE:./prompt.txt',
    'device': mydevice,
    'temperature': 0.8
}
eval_addition_batch(config, model, ctx, encode, decode, judge=True)

evaluating addition from: FILE:./prompt.txt


100%|██████████| 10000/10000 [00:00<00:00, 16885.27it/s]
100%|██████████| 81/81 [00:01<00:00, 52.43it/s]

Judgement accuracy of 10000 examples: 9499/10000 (94.99%)
accuracy of 10000 examples: 9500/10000 (95.0%)
{'carry0': 95.66787003610109, 'carry1': 93.84615384615384, 'carry2': 95.71469657750936, 'carry3': 95.41213063763608}





(94.99,
 95.0,
 {'carry0': 95.66787003610109,
  'carry1': 93.84615384615384,
  'carry2': 95.71469657750936,
  'carry3': 95.41213063763608})

In [26]:
from contextlib import nullcontext
ctx = nullcontext()
config={
    'start': 'FILE:./extra_num_judge_prompt.txt',
    'device': mydevice,
}
eval_judge_batch(config, model, ctx, encode, decode, max_new_tokens=1)

evaluating addition from: FILE:./extra_num_judge_prompt.txt


100%|██████████| 10000/10000 [00:00<00:00, 16776.22it/s]
100%|██████████| 83/83 [00:00<00:00, 176.36it/s]

Judgement accuracy of 10000 examples: 9415/10000 (94.15%)
No judging probability of 10000 examples: 0/10000 (0.0%)
{'carry0': 93.14079422382672, 'carry1': 93.7062937062937, 'carry2': 94.16163359217717, 'carry3': 96.65629860031105}





(94.15,
 0.0,
 {'carry0': 93.14079422382672,
  'carry1': 93.7062937062937,
  'carry2': 94.16163359217717,
  'carry3': 96.65629860031105})

In [27]:
ctx = nullcontext()
config={
    'start': 'FILE:./add_noise_judge_prompt.txt',
    'device': mydevice,
}
eval_judge_batch(config, model, ctx, encode, decode, max_new_tokens=1)

evaluating addition from: FILE:./add_noise_judge_prompt.txt


100%|██████████| 10000/10000 [00:00<00:00, 15096.86it/s]
100%|██████████| 82/82 [00:00<00:00, 157.35it/s]

Judgement accuracy of 10000 examples: 6644/10000 (66.44%)
No judging probability of 10000 examples: 0/10000 (0.0%)
{'carry0': 62.15403128760529, 'carry1': 64.47552447552447, 'carry2': 68.47857348288755, 'carry3': 71.9284603421462}





(66.44,
 0.0,
 {'carry0': 62.15403128760529,
  'carry1': 64.47552447552447,
  'carry2': 68.47857348288755,
  'carry3': 71.9284603421462})

## 新数据：20%positive + 80%negative

In [2]:
import os
import random
import re

class create_test_data:
    def __init__(self, non_overlap_data_path, num_test_samples=100) -> None:
        if not os.path.exists(non_overlap_data_path):
            raise ValueError("There is no nonoverlap data file")
        self.non_overlap_data_path = non_overlap_data_path
        self.num_test_samples = num_test_samples
        self.samples = None
        # Open non-overlap data path to get a bunch of test samples
        with open(self.non_overlap_data_path, 'r') as f:
            lines = f.readlines()
            random.shuffle(lines)
            self.samples = lines[:self.num_test_samples]
            
        
    def create_prompt(self):
        """
        To create prompt data file like: 'Ta+b=' .
        The output should be correct answer and judgement
        """
        with open(f'prompt.txt', 'w') as f2:
            for line in self.samples:
                prompt = line.split('=')[0]+'=\n'
                f2.write(prompt)
                    
    def create_add_noise_judge_prompt(self):
        """
        To create prompt data file like: 'a+b=c' and 'a+b=d', where 'd' means wrong answer
        The output should be judgement
        """
        with open('add_noise_judge_prompt.txt', 'w') as f3:
            for line in self.samples:
                # 取出表达式部分
                prompt = line.split('T')[1].strip()
                # prompt = prompt.split('=')[0].strip()
                new_prompt = self.modify_result(prompt, random.randint(1, 9), 'noise_add')
                f3.write(new_prompt + '?\n')
                
    def create_extra_num_judge_prompt(self):
        """
        To create prompt data file like: 'a+b=c' and 'a+b=cd', where 'd' means extra number
        The output should be judgement
        """
        with open('extra_num_judge_prompt.txt', 'w') as f4:
            for line in self.samples:
                # 取出表达式部分
                prompt = line.split('T')[1].strip()
                # prompt = prompt.split('=')[0].strip()
                new_prompt = self.modify_result(prompt, random.randint(1, 9), 'extra_num')
                f4.write(new_prompt + '?\n')

    
    def modify_result(self, expression, addend, mode):
        # 使用正则表达式提取表达式中的数字
        match = re.match(r'(\d+)\+(\d+)=(\d+)', expression)
    
        if match:
            # 提取数字并计算新的结果
            num1 = int(match.group(1))
            num2 = int(match.group(2))
            result = int(match.group(3))
            num_digit = len(match.group(3).strip())
            if mode == 'extra_num':
                if random.uniform(0,1)>0.5 and random.uniform(0,1)<0.75:
                    extra = random.randint(1, 9)
                    new_expression = f"F{num1}+{num2}={result}{extra}"
                elif random.uniform(0,1)>0.75:
                    extra = random.randint(1, 9)
                    new_expression = f"F{num1}+{num2}={extra}{result}"
                else:
                    new_expression = f"T{num1}+{num2}={result}"
            elif mode == 'noise_add':
                if random.uniform(0,1)>0.5:
                    # 决定添加错误的位置
                    wrong_loc = random.randint(0, num_digit)
        
                    new_result = result + addend * (10**wrong_loc)

                    # 构建新的表达式
                    new_expression = f"F{num1}+{num2}={new_result}"
                else:
                    new_expression = f"T{num1}+{num2}={result}"
            
            else:
                return "Invalid modify pattern"
        
            return new_expression
        else:
            return "Invalid expression format"

In [3]:
non_overlap_data_path = './data/get_data_with_label/train_3digit_bilabeled10000_nonoverlap.txt'
num_examples = 10000
new_creator = create_test_data(non_overlap_data_path, num_examples)
new_creator.create_prompt()
new_creator.create_extra_num_judge_prompt()
new_creator.create_add_noise_judge_prompt()

In [4]:
from model import GPTConfig, GPT
import torch

# init from a model saved in a specific directory
ckpt_path = 'ckpt_judge_acc_bilabel(0.2p).pt'
checkpoint = torch.load(ckpt_path, map_location=mydevice)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)
model.to(mydevice)

number of parameters: 10.63M


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(16, 384)
    (wpe): Embedding(256, 384)
    (drop): Dropout(p=0.2, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=384, out_features=1152, bias=False)
          (c_proj): Linear(in_features=384, out_features=384, bias=False)
          (attn_dropout): Dropout(p=0.2, inplace=False)
          (resid_dropout): Dropout(p=0.2, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=384, out_features=1536, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=1536, out_features=384, bias=False)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=384, out_features=16, bias=False)
)

In [5]:
encode, decode = get_encode_decode('./data/addition_bilabel/meta.pkl')

Loading meta from ./data/addition_bilabel/meta.pkl...


In [6]:
from contextlib import nullcontext
ctx = nullcontext()
config={
    'start': 'FILE:./prompt.txt',
    'device': mydevice,
    'temperature': 0.8
}
eval_addition_batch(config, model, ctx, encode, decode, judge=True)

evaluating addition from: FILE:./prompt.txt


100%|██████████| 10000/10000 [00:03<00:00, 2800.92it/s]
100%|██████████| 81/81 [00:04<00:00, 17.82it/s]

Judgement accuracy of 10000 examples: 9434/10000 (94.34%)
accuracy of 10000 examples: 9434/10000 (94.34%)
{'carry0': 95.57522123893806, 'carry1': 92.58426966292134, 'carry2': 95.56204379562044, 'carry3': 94.31818181818183}





(94.34,
 94.34,
 {'carry0': 95.57522123893806,
  'carry1': 92.58426966292134,
  'carry2': 95.56204379562044,
  'carry3': 94.31818181818183})

In [8]:
from contextlib import nullcontext
ctx = nullcontext()
config={
    'start': 'FILE:./extra_num_judge_prompt.txt',
    'device': mydevice,
}
eval_judge_batch(config, model, ctx, encode, decode, max_new_tokens=1)

evaluating addition from: FILE:./extra_num_judge_prompt.txt


100%|██████████| 10000/10000 [00:03<00:00, 3183.20it/s]
100%|██████████| 83/83 [00:01<00:00, 73.97it/s]

Judgement accuracy of 10000 examples: 8865/10000 (88.64999999999999%)
No judging probability of 10000 examples: 0/10000 (0.0%)
{'carry0': 90.4424778761062, 'carry1': 87.80898876404495, 'carry2': 87.97080291970802, 'carry3': 90.37878787878788}





(88.64999999999999,
 0.0,
 {'carry0': 90.4424778761062,
  'carry1': 87.80898876404495,
  'carry2': 87.97080291970802,
  'carry3': 90.37878787878788})

In [10]:
ctx = nullcontext()
config={
    'start': 'FILE:./add_noise_judge_prompt.txt',
    'device': mydevice,
}
eval_judge_batch(config, model, ctx, encode, decode, max_new_tokens=1)

evaluating addition from: FILE:./add_noise_judge_prompt.txt


100%|██████████| 10000/10000 [00:05<00:00, 1682.61it/s]
100%|██████████| 82/82 [01:26<00:00,  1.06s/it] 

Judgement accuracy of 10000 examples: 6695/10000 (66.95%)
No judging probability of 10000 examples: 0/10000 (0.0%)
{'carry0': 65.66371681415929, 'carry1': 66.01123595505618, 'carry2': 66.97810218978103, 'carry3': 71.06060606060606}





(66.95,
 0.0,
 {'carry0': 65.66371681415929,
  'carry1': 66.01123595505618,
  'carry2': 66.97810218978103,
  'carry3': 71.06060606060606})