In [1]:
import math
import random
import os

In [2]:
class data_creator:
    """Generate labelled addition datasets as plain-text files.

    Three files are produced in the current working directory:

    * a training file of lines ``'<label> <a>+<b>=<result> <label>'`` where the
      label is ``T`` for a correct sum and ``F`` for a deliberately wrong one;
    * a "non-overlap" file holding the correct line for every operand pair that
      does not occur in the training file;
    * a prompt file with ``num_test_samples`` randomly chosen prompts
      (``'T <a>+<b>='``) drawn from the non-overlap file.

    Args:
        num_digits: Maximum number of digits per operand. The per-digit quota
            table used for the training file has three entries, so values
            above 3 are rejected by
            ``create_balanced_training_data_with_label``.
        total_num_examples: Upper bound on training lines written after the
            one-digit stage (the one-digit stage always writes all 100 pairs).
        num_test_samples: Maximum number of prompts written to the prompt file.
        seed: Optional seed. When given, a private ``random.Random(seed)`` is
            used so repeated runs produce identical files. When ``None`` the
            module-level ``random`` functions are used, so any global seeding
            done by the caller keeps working.
    """

    def __init__(self, num_digits=3, total_num_examples=10000, num_test_samples=10000, seed=None) -> None:
        self.num_digits = num_digits
        self.total_num_examples = total_num_examples
        self.num_test_samples = num_test_samples

        # Either the `random` module itself or a dedicated generator; both
        # expose randint / uniform / shuffle.
        self._rng = random if seed is None else random.Random(seed)

        self.training_data_path = f'train_{num_digits}digit_bilabeled{total_num_examples}_new_sp.txt'
        self.non_overlap_path = f'train_{num_digits}digit_bilabeled{total_num_examples}_nonoverlap_new_sp.txt'

    def numCarryOps(self, a, b):
        """Return the number of carry operations in the decimal addition a + b.

        Each carry lowers the digit sum of the result by exactly 9, so the
        count is ``(digitSum(a) + digitSum(b) - digitSum(a + b)) // 9``.

        Args:
            a, b: Non-negative integers (or anything ``int()`` accepts that
                converts to a non-negative integer).

        Returns:
            The carry count as an ``int``.
        """
        a, b = int(a), int(b)

        def digitSum(n):
            return sum(map(int, str(n)))

        # Integer division: the numerator is always a multiple of 9, and this
        # avoids a round trip through float.
        return (digitSum(a) + digitSum(b) - digitSum(a + b)) // 9

    def create_balanced_training_data_with_label(self):
        """Write the labelled training file to ``self.training_data_path``.

        All 100 one-digit pairs are written first as correct (``T``) examples.
        Multi-digit examples are then sampled at random until the per-digit
        quota is met, rejecting any sample whose carry count already has
        ``ceil(total_num_examples / (num_digits + 1))`` examples, so that the
        carry counts end up balanced.

        Roughly 20% of the multi-digit examples are negatives (``F``): half of
        them have a spurious digit appended or prepended to the correct result,
        the other half have one decimal position of the result increased.

        Raises:
            ValueError: If ``num_digits`` exceeds the number of entries in the
                per-digit quota table (3).
        """
        num_digit_2 = int(900 * self.total_num_examples / 10000)
        # Quotas for 1-, 2- and 3-digit examples respectively.
        num_digit_list = [100, num_digit_2, self.total_num_examples - 100 - num_digit_2]
        if self.num_digits > len(num_digit_list):
            # Fail before the output file is truncated rather than with an
            # IndexError halfway through generation.
            raise ValueError(
                f"num_digits={self.num_digits} is not supported: example quotas are only "
                f"defined for up to {len(num_digit_list)} digits"
            )
        print(num_digit_list)

        # Target the same number of examples for every possible carry count.
        target_num_carry_examples = math.ceil(self.total_num_examples / (self.num_digits + 1))
        num_carry_list = [0] * (self.num_digits + 1)

        rng = self._rng
        with open(self.training_data_path, 'w') as f:
            num_example = 0

            # Every one-digit addition is written, always as a correct example.
            for a in range(10):
                for b in range(10):
                    f.write(f'T {a}+{b}={a + b} T\n')
                    num_example += 1
                    num_carry_list[self.numCarryOps(a, b)] += 1

            for num_digit in range(2, self.num_digits + 1):
                num_digit_example = 0
                print(num_digit, num_example, num_carry_list)
                while num_digit_example < num_digit_list[num_digit - 1] and num_example < self.total_num_examples:
                    # Operands are uniform over [0, 10**num_digit - 1].
                    a = rng.randint(0, 10 ** num_digit - 1)
                    b = rng.randint(0, 10 ** num_digit - 1)
                    c = a + b

                    num_carry = self.numCarryOps(a, b)
                    if num_carry_list[num_carry] >= target_num_carry_examples:
                        # This carry bucket is already full; draw again.
                        continue

                    # 80% positive and 20% negative instances.
                    if rng.uniform(0, 1) > 0.8:
                        if rng.uniform(0, 1) > 0.5:
                            # Half of the negatives: an extra digit is
                            # appended or prepended to the correct result.
                            flag = rng.uniform(0, 1)
                            extra = rng.randint(1, 9)
                            if flag > 0.5:
                                f.write(f'F {a}+{b}={c}{extra} F\n')
                            else:
                                f.write(f'F {a}+{b}={extra}{c} F\n')
                        else:
                            # Other half: add 1-9 at one decimal position
                            # (0..num_digit inclusive), so the result is
                            # always wrong.
                            wrong_loc = rng.randint(0, num_digit)
                            addend = rng.randint(1, 9)
                            new_result = c + addend * (10 ** wrong_loc)
                            f.write(f'F {a}+{b}={new_result} F\n')
                    else:
                        f.write(f'T {a}+{b}={c} T\n')

                    num_carry_list[num_carry] += 1
                    num_digit_example += 1
                    num_example += 1

        print(num_carry_list)

    def create_non_overlap_data(self):
        """Write every correct addition whose operands are absent from training.

        The training file is scanned for operand pairs regardless of label, so
        a pair that was only used for a negative (``F``) example is also
        excluded. All pairs ``x, y`` in ``[0, 10**num_digits)`` not seen in
        training are written to ``self.non_overlap_path`` as
        ``'T x+y=<sum> T'``. The number of distinct training operand pairs is
        printed.
        """
        seen_operands = set()
        with open(self.training_data_path, 'r') as f:
            for line in f:
                # '<label> <a>+<b>=<result> <label>' -> '<a>+<b>'
                head, sep, _ = line.partition('=')
                if not sep:
                    continue  # not an example line
                seen_operands.add(head.split(' ', 1)[-1])

        print(len(seen_operands))

        upper = 10 ** self.num_digits
        with open(self.non_overlap_path, 'w') as f:
            for x in range(upper):
                for y in range(upper):
                    if f'{x}+{y}' not in seen_operands:
                        f.write(f'T {x}+{y}={x+y} T\n')

    def create_test_data(self):
        """Write up to ``num_test_samples`` shuffled prompts from the non-overlap file.

        Each prompt is the part of a non-overlap line up to and including
        ``'='`` (for example ``'T 12+34='``), one per line.

        Raises:
            ValueError: If the non-overlap file does not exist yet.
        """
        if not os.path.exists(self.non_overlap_path):
            raise ValueError("There is no nonoverlap data file")
        with open(self.non_overlap_path, 'r') as f:
            lines = f.readlines()
        self._rng.shuffle(lines)
        with open(f'prompt_{self.num_digits}digit_{self.num_test_samples}_new_sp.txt', 'w') as f2:
            for line in lines[:self.num_test_samples]:
                f2.write(line.split('=')[0] + '=\n')

    def create(self):
        """Run the full pipeline: training file, non-overlap file, prompt file."""
        self.create_balanced_training_data_with_label()
        self.create_non_overlap_data()
        self.create_test_data()

        print("All files are done!")

In [3]:
# Maximum number of digits per operand; passed through to the generator so the
# setting lives in one place.
num_digits = 3

creator = data_creator(num_digits=num_digits)

In [4]:
# Run the whole pipeline: writes the training file, the non-overlap file and
# the test prompt file (relative paths, i.e. into the current working directory).
creator.create()

[100, 900, 9000]
2 100 [55, 45, 0, 0]
3 1000 [336, 444, 220, 0]
[2500, 2500, 2500, 2500]
9938
All files are done!
