In [1]:
import math
import random
import os

# 创建带有标签的数据集/构造错误样本

Input data:
```
1+1=2
1+2=3
```

希望实现的是：
1. 创建 mixed data set：每条数据包括一个计算式和与其对应的判断，正常情况下是一个正确的计算对应一个正确的判断；再从中挑选 x% 的数据构造错误样本，并为其配备对应的计算和判断。
2. 创建 judge only data：只含有判断数据，同样是每个正确的算式对应一个正确的判断，并为其中 x% 的数据构造对应的负样本（训练集与测试集均如此）。

In [2]:
# Seed the global RNG so that module-level random draws are reproducible.
# NOTE: data_creator.__init__ re-seeds the global RNG with 1234, so this seed
# only affects draws made before a data_creator instance is constructed.
random.seed(0)

def randint_exclude(start, end, exclude):
    """Return a random integer in ``[start, end]`` that is different from ``exclude``.

    Values are drawn with ``random.randint`` from the global RNG and redrawn
    until one differs from ``exclude`` (rejection sampling), so the sequence of
    draws for a given seed is the same as a plain retry loop.

    Args:
        start: inclusive lower bound.
        end: inclusive upper bound.
        exclude: value that must not be returned.

    Returns:
        An integer ``n`` with ``start <= n <= end`` and ``n != exclude``.

    Raises:
        ValueError: if ``exclude`` is the only value in the range, in which case
            no valid result exists (a retry loop would never terminate).
    """
    if start == end == exclude:
        raise ValueError(
            f"no integer in [{start}, {end}] differs from the excluded value {exclude}"
        )
    while True:
        num = random.randint(start, end)
        if num != exclude:
            return num
        
def replace_char_at_index(orig_string, index, new_char):
    """Return a copy of ``orig_string`` with the character at ``index`` replaced.

    Args:
        orig_string: the source string (left unmodified; strings are immutable).
        index: position of the character to replace. Negative values count
            from the end, as with normal Python indexing.
        new_char: the replacement text inserted in place of that character.

    Returns:
        The new string.

    Raises:
        IndexError: if ``index`` does not address an existing character.
    """
    length = len(orig_string)
    if not -length <= index < length:
        raise IndexError(
            f"index {index} out of range for string of length {length}"
        )
    # Normalise negative indices: with plain slicing, index == -1 would make
    # orig_string[index + 1:] the whole string and duplicate the content.
    if index < 0:
        index += length
    return orig_string[:index] + new_char + orig_string[index + 1:]

class data_creator:
    """Generate labelled "judge" samples from a file of addition equations.

    Output line formats (one sample per line):
      * ``T{a}+{b}={c}``    -- the correct equation, written only when ``mode == 'mix'``.
      * ``j{a}+{b}={c}~T``  -- positive judge sample, written for every input equation.
      * ``j{a}+{b}={r}~F``  -- negative judge sample whose result ``r`` is wrong.

    Any ``mode`` other than ``'mix'`` produces judge samples only.
    """

    # An equation receives negative samples when a uniform(0, 1) draw exceeds
    # this threshold, i.e. with probability 0.3. Each selected equation yields
    # TWO negatives, so the file is not a 70/30 positive/negative split.
    _NEGATIVE_THRESHOLD = 0.7

    def __init__(self, num_digits=3, mode='mix') -> None:
        """Store the configuration and re-seed the global ``random`` module.

        Args:
            num_digits: largest power of ten that error type 1 may perturb
                (the perturbed position is drawn from ``0..num_digits``).
            mode: ``'mix'`` to also emit the plain ``T...`` equation lines;
                any other value emits judge lines only.
        """
        # Re-seeding makes the files produced by a fresh instance reproducible,
        # but note that it resets the module-level RNG for all other users too.
        random.seed(1234)
        self.num_digits = num_digits
        self.mode = mode

    def numCarryOps(self, a, b):
        """Return the number of carry operations performed when adding ``a + b``.

        Uses the digit-sum identity: every carry lowers the digit sum of the
        result by 9 compared with the digit sums of the operands.

        Args:
            a, b: non-negative integers, or values convertible with ``int()``.

        Returns:
            The carry count as an ``int``.
        """
        a, b = int(a), int(b)

        def digit_sum(n):
            return sum(map(int, str(n)))

        # Integer division is exact here and avoids a float round trip.
        return (digit_sum(a) + digit_sum(b) - digit_sum(a + b)) // 9

    def _negative_samples(self, a, b, c, wrong_number):
        """Return the two ``~F`` lines built for the equation ``a + b = c``.

        The order of the random draws is significant for reproducibility:
        length-mismatch choice, extra digit, then the wrong-value draws.
        """
        samples = []

        # Negative 1 -- length mismatch: append any digit, or prepend a
        # non-zero digit, to the textual form of the correct result.
        if random.uniform(0, 1) > 0.5:
            extra = random.randint(0, 9)
            samples.append(f'j{a}+{b}={c}{extra}~F\n')
        else:
            extra = random.randint(1, 9)
            samples.append(f'j{a}+{b}={extra}{c}~F\n')

        # Negative 2 -- wrong value.
        if wrong_number == 1:
            # Add a non-zero multiple of a power of ten to the correct result.
            wrong_loc = random.randint(0, self.num_digits)
            addend = random.randint(1, 9)
            new_result = c + addend * (10 ** wrong_loc)
        elif wrong_number == 2:
            # Overwrite one digit with a different one; the leading digit is
            # never replaced by 0 so the result keeps its length.
            digits = str(c)
            wrong_loc = random.randint(0, len(digits) - 1)
            current_digit = int(digits[wrong_loc])
            lowest = 1 if wrong_loc == 0 else 0
            replacement = randint_exclude(lowest, 9, current_digit)
            new_result = int(replace_char_at_index(digits, wrong_loc, str(replacement)))
        else:
            raise ValueError("Wrong number should be 1 or 2")
        samples.append(f'j{a}+{b}={new_result}~F\n')
        return samples

    def create_from_addition_data(self, addition_data_path, output_path, wrong_number=1):
        """Convert a file of ``a+b=c`` lines into a labelled judge dataset.

        Only the text before ``=`` of each input line is used; the sum is
        recomputed. Blank lines are skipped.

        Args:
            addition_data_path: path of the input file, one equation per line.
            output_path: path of the file to (over)write.
            wrong_number: kind of wrong-value negative to generate --
                1 adds a multiple of a power of ten, 2 replaces a single digit.

        Raises:
            ValueError: if ``wrong_number`` is not 1 or 2 (checked before any
                file is touched), or if a non-blank line is not of the form
                ``<int>+<int>=...``.
        """
        # Fail fast: otherwise the error would only surface after the output
        # file had been truncated and partially written.
        if wrong_number not in (1, 2):
            raise ValueError("Wrong number should be 1 or 2")

        # Read everything first so the input may safely be closed (or even be
        # the same path as the output) before writing starts.
        with open(addition_data_path, 'r') as f:
            lines = f.readlines()

        with open(output_path, 'w') as f2:
            for line in lines:
                if not line.strip():
                    # Tolerate blank / trailing empty lines in the input.
                    continue
                operands = line.split('=')[0]
                a, b = (int(part) for part in operands.split('+'))
                c = a + b
                if self.mode == 'mix':
                    f2.write(f'T{a}+{b}={c}\n')
                f2.write(f'j{a}+{b}={c}~T\n')
                if random.uniform(0, 1) > self._NEGATIVE_THRESHOLD:
                    f2.writelines(self._negative_samples(a, b, c, wrong_number))
        print("All files are done!")
                    
        
    
    # def create_balanced_training_data_with_label(self):
    #     num_digit_2 = int(900*self.total_num_examples/10000)
    #     num_digit_list = [100, num_digit_2, self.total_num_examples - 100 - num_digit_2]
    #     print(num_digit_list)

    #     # create a list of number of carries - we target each number of carries to have the same number of examples
    #     target_num_carry_examples = math.ceil(self.total_num_examples / (self.num_digits+1))
    #     num_carry_list = [0 for i in range(self.num_digits+1)]
    #     # 创建字典统计错误样本的每个类型的个数
    #     num_error_list = {}
    #     for i in range(self.num_digits+1):
    #         num_error_list[i] = 0

    #     with open(self.training_data_path, 'w') as f:
    #         num_example = 0

    #         # generate all 1 digit examples
    #         # 1位加法全为正确
    #         for a in range(10):
    #             for b in range(10):
    #                 c = a + b
    #                 if self.mode == 'mix':
    #                     f.write(f'T{a}+{b}={c}\n')
    #                 f.write(f'j({a}+{b}={c})~T\n')
    #                 num_example += 1
    #                 num_carry = self.numCarryOps(a, b)
    #                 num_carry_list[num_carry] += 1

            # for num_digit in range(2, self.num_digits+1):
            #     num_digit_example = 0
            #     print(num_digit,  num_example, num_carry_list)
            #     while num_digit_example < num_digit_list[num_digit-1] and num_example < self.total_num_examples:
            #         # generate a random number between 0 and 10^(i+1) - 1
            #         a = random.randint(0, 10**(num_digit) - 1)
            #         b = random.randint(0, 10**(num_digit) - 1)
            #         c = a + b

            #         # count number of carries in c
            #         num_carry = self.numCarryOps(a, b)
            #         if num_carry_list[num_carry] < target_num_carry_examples:
            #             # 70% positive instances and 30% negative instances
            #             # if self.mode:
            #             #     f.write(f'T{a}+{b}={c}\n')
            #             f.write(f'j({a}+{b}={c})~T\n')
                        # if random.uniform(0,1)>0.7:
                        #     for j in range(self.neg_nums):
                        #         # Length not match
                        #         flag = random.uniform(0,1)
                        #         if flag > 0.5:
                        #             extra = random.randint(0, 9)
                        #             # if self.mode:
                        #             #     f.write(f'F{a}+{b}={c}{extra}\n')
                        #             f.write(f'j({a}+{b}={c}{extra})~F\n')
                        #         else:
                        #             extra = random.randint(1, 9)
                        #             # if self.mode:
                        #             #     f.write(f'F{a}+{b}={extra}{c}\n')
                        #             f.write(f'j({a}+{b}={extra}{c})~F\n')
                                    
                        #         # Wrong number 1
                        #         wrong_loc = random.randint(0, num_digit)
                        #         addend = random.randint(1, 9)
                        #         new_result = c + addend * (10**wrong_loc)
                                
                        #         # f.write(f'F{a}+{b}={new_result}\n')
                        #         f.write(f'j({a}+{b}={new_result})~F\n')
                                
                        #         # Wrong number2
                        #         char_c = str(c)
                        #         n = len(char_c)
                        #         wrong_loc = random.randint(0, n-1)
                        #         changed_number = char_c[wrong_loc]
                        #         if wrong_loc == 0:
        #                             addend = randint_exclude(1, 9, int(changed_number))
        #                         else:
        #                             addend = randint_exclude(0, 9, int(changed_number))
                                
        #                         char_c = replace_char_at_index(char_c, wrong_loc, str(addend))
        #                         new_result = int(char_c)
        #                         f.write(f'j({a}+{b}={new_result})~F\n')
                                
        #                         num_error_list[wrong_loc] += 1

        #                 num_carry_list[num_carry] += 1
        #                 num_digit_example += 1
        #                 num_example += 1
        #             else:
        #                 continue
        
        # print(num_carry_list)
        # print('addend error:', num_error_list)
    
    # def create_non_overlap_data(self):
    #     lines_to_remove = set()
    #     with open(self.training_data_path, 'r') as f:
    #         for line in f.readlines():
    #             lines_to_remove.add(line)

    #     print(len(lines_to_remove))

    #     with open(self.non_overlap_path, 'w') as f:
    #         for x in range(1000):
    #             for y in range(1000):
    #                 line_to_add = f'T{x}+{y}={x+y}\n'
    #                 if line_to_add in lines_to_remove:
    #                     lines_to_remove.remove(line_to_add)
    #                 else:
    #                     f.write(line_to_add)
    
    # def create_test_data(self):
    #     if not os.path.exists(self.non_overlap_path):
    #         raise ValueError("There is no nonoverlap data file")
    #     with open(self.non_overlap_path, 'r') as f:
    #         lines = f.readlines()
    #         random.shuffle(lines)
    #         with open(f'test_{self.num_digits}digit_mixed_{self.num_test_samples}.txt', 'w') as f2:
    #             for line in lines[:self.num_test_samples]:
    #                 prompt = line.split('=')[0]+'=\n'
    #                 f2.write(prompt)
                    
    # def create(self):
    #     self.create_balanced_training_data_with_label()
    #     self.create_non_overlap_data()
    #     self.create_test_data()
        
    #     print("All files are done!")

In [3]:
# Mixed-mode training sets: one output file per (error type, dataset size).
# `num_digits` drives both the creator and the file names so the two cannot
# drift apart when the digit count is changed.
num_digits = 3

creator = data_creator(num_digits=num_digits, mode='mix')
for error_type in [1, 2]:
    for num in [2000, 4000, 6000, 8000, 10000]:
        addition_data_path = f'../bal/train_{num_digits}digit_{num}.txt'
        output_path = f'../mixed/train_{num_digits}digit_mixed_W{error_type}_{num}.txt'
        creator.create_from_addition_data(addition_data_path, output_path, error_type)

All files are done!
All files are done!
All files are done!
All files are done!
All files are done!
All files are done!
All files are done!
All files are done!
All files are done!
All files are done!


In [4]:
# Mixed-mode test sets: one output file per error type.
# `num_digits` drives both the creator and the file names so the two cannot
# drift apart when the digit count is changed.
num_digits = 3

creator = data_creator(num_digits=num_digits, mode='mix')
for error_type in [1, 2]:
    for num in [10000]:
        addition_data_path = f'../bal/test_{num_digits}digit_{num}.txt'
        output_path = f'../mixed/test_{num_digits}digit_mixed_W{error_type}_{num}.txt'
        creator.create_from_addition_data(addition_data_path, output_path, error_type)

All files are done!
All files are done!


In [5]:
# Judge-only training sets: one output file per (error type, dataset size).
# NOTE(review): the output name previously contained the typo "jugde"; it now
# reads "judge", matching the judge-only test files. Update any consumer that
# still loads the misspelled file names.
num_digits = 3

creator = data_creator(num_digits=num_digits, mode='judge')
for error_type in [1, 2]:
    for num in [2000, 4000, 6000, 8000, 10000]:
        addition_data_path = f'../bal/train_{num_digits}digit_{num}.txt'
        output_path = f'../bal/train_{num_digits}digit_judge_W{error_type}_{num}.txt'
        creator.create_from_addition_data(addition_data_path, output_path, error_type)

All files are done!
All files are done!
All files are done!
All files are done!
All files are done!
All files are done!
All files are done!
All files are done!
All files are done!
All files are done!


In [6]:
# Judge-only test sets: one output file per error type.
# `num_digits` drives both the creator and the file names so the two cannot
# drift apart when the digit count is changed.
num_digits = 3

creator = data_creator(num_digits=num_digits, mode='judge')
for error_type in [1, 2]:
    for num in [10000]:
        addition_data_path = f'../bal/test_{num_digits}digit_{num}.txt'
        output_path = f'../bal/test_{num_digits}digit_judge_W{error_type}_{num}.txt'
        creator.create_from_addition_data(addition_data_path, output_path, error_type)

All files are done!
All files are done!
