In [2]:
from __future__ import annotations

import os
import random

def to_list(num: int) -> list[str]:
    """Split the decimal representation of ``num`` into single characters.

    Args:
        num: The integer to split.

    Returns:
        A list with one string per character of ``str(num)``, most
        significant first (e.g. ``105`` -> ``['1', '0', '5']``). A leading
        ``'-'`` is kept as its own element for negative numbers.
    """
    return list(str(num))

def to_str(num: int) -> str:
    """Render ``num`` with a single space between consecutive characters.

    For example ``105`` becomes ``"1 0 5"``.
    """
    characters = to_list(num)
    spaced = " ".join(characters)
    return spaced

def gen_examples(
    *,
    n_examples,
    digits,
    randomized_digits=False,
    min_digit=1,
    max_digit=8,
    include_scrachpad,
):
    """Create random addition problems formatted as prompt strings.

    Each example has the shape ``Input:\\n<a> + <b>\\nTarget:\\n`` followed by
    an optional ``<scratch>...</scratch>`` trace of the column-by-column
    addition, then the space-separated sum and a blank line.

    Args:
        n_examples: How many examples to generate.
        digits: Exponent controlling operand size; both operands are drawn
            from ``[10 ** digits, 10 ** (digits + 1))``. Overwritten on every
            iteration when ``randomized_digits`` is true.
        randomized_digits: When true, ``digits`` is re-drawn per example via
            ``random.randrange(min_digit + 1, max_digit)``.
        min_digit: Lower bound used for the random ``digits`` draw.
        max_digit: Exclusive upper bound used for the random ``digits`` draw.
        include_scrachpad: When true, the scratchpad trace is emitted before
            the final answer.

    Returns:
        A list of ``n_examples`` formatted strings.
    """
    examples = []
    for _ in range(n_examples):
        if randomized_digits:
            digits = random.randrange(min_digit + 1, max_digit)

        # Both operands share the same bounds, hence the same digit count.
        low, high = int(10 ** digits), int(10 ** (digits + 1))
        lhs = random.randrange(low, high)
        rhs = random.randrange(low, high)

        problem = f'{to_str(lhs)} + {to_str(rhs)}'
        pieces = [f'Input:\n{problem}\nTarget:\n']

        if include_scrachpad:
            pieces.append(f'<scratch>\n{problem} , C: 0\n')
            carry = 0
            partial = ''

            # Walk the digit columns from least to most significant.
            columns = list(zip(to_list(lhs), to_list(rhs)))[::-1]
            for position, (lhs_digit, rhs_digit) in enumerate(columns):
                column_total = int(lhs_digit) + int(rhs_digit) + carry

                # The first column is already covered by the header line, so
                # intermediate state is only logged from the second column on.
                if position > 0:
                    pieces.append(
                        f'{lhs_digit} + {rhs_digit} , {partial}C: {carry}\n'
                    )

                carry = int(column_total >= 10)
                partial = f'{column_total % 10} {partial}'

            pieces.append(f', {partial}C: {carry}\n')
            # The last scratch line always shows the final carry, even if 0.
            pieces.append(f'{carry} {partial}'.strip() + '\n')
            pieces.append('</scratch>\n')

        pieces.append(to_str(lhs + rhs))
        pieces.append('\n\n')
        examples.append(''.join(pieces))

    return examples


def main(
    n_samples = 10,
    context_examples = 50,
    examples_per_prompt = 1,
    val_start = 6,
    max_digits = 10,
    include_scrachpad = True,
    fixed_examples = False,
    randomized_digits = False,
):
    """Generate and print addition prompts for every operand size.

    For each ``digits`` in ``range(max_digits)`` (operands then have
    ``digits + 1`` decimal digits) this creates the matching output folder if
    needed, builds ``n_samples`` prompts terminated by ``<|endoftext|>`` and
    prints them as one string.

    Args:
        n_samples: Number of prompts produced per operand size.
        context_examples: Size of the per-size few-shot pool generated when
            ``fixed_examples`` is true and ``randomized_digits`` is false.
        examples_per_prompt: Examples per prompt. With ``fixed_examples`` the
            prompt holds ``examples_per_prompt - 1`` pool examples plus one
            fresh example; otherwise all of them are freshly generated.
        val_start: Operand sizes with ``digits + 1 >= val_start`` go to the
            ``val`` folder, the rest to ``train``.
        max_digits: Number of operand sizes to iterate over.
        include_scrachpad: Emit scratchpad traces; also selects the
            ``_scratch`` or ``_direct`` folder suffix.
        fixed_examples: Draw the few-shot context from a pre-generated pool.
        randomized_digits: With ``fixed_examples``, build one shared pool with
            randomly sized operands instead of one pool per operand size.
    """
    if fixed_examples and randomized_digits:
        # gen_examples only accepts `n_examples`; the previous `n_samples`
        # keyword raised TypeError as soon as this branch was taken.
        few_shot_str = gen_examples(
            n_examples        = examples_per_prompt - 1,
            digits            = None,
            randomized_digits = randomized_digits,
            min_digit         = 1,
            max_digit         = val_start - 1,
            include_scrachpad = include_scrachpad,
        )

    for digits in range(max_digits):
        folder_name = 'val' if digits + 1 >= val_start else 'train'
        folder_name += '_scratch' if include_scrachpad else '_direct'

        # Build the per-size pool regardless of the scratchpad setting; it
        # used to be created only in scratch mode, leaving `few_shot_str`
        # undefined for fixed_examples=True with include_scrachpad=False.
        if fixed_examples and not randomized_digits:
            few_shot_str = gen_examples(
                n_examples        = context_examples,
                digits            = digits,
                randomized_digits = randomized_digits,
                include_scrachpad = include_scrachpad,
            )

        os.makedirs(folder_name, exist_ok=True)

        n_gen_examples = 1 if fixed_examples else examples_per_prompt
        prompt_parts = []
        for _ in range(n_samples):
            if fixed_examples:
                prompt_parts.extend(
                    random.sample(few_shot_str, examples_per_prompt - 1)
                )
            prompt_parts.extend(gen_examples(
                n_examples        = n_gen_examples,
                digits            = digits,
                include_scrachpad = include_scrachpad,
            ))
            prompt_parts.append('<|endoftext|>')

        print("".join(prompt_parts))


# Run the generator with its default arguments; the prompts are printed below.
main()


Input:
9 + 2
Target:
<scratch>
9 + 2 , C: 0
, 1 C: 1
1 1
</scratch>
1 1

<|endoftext|>Input:
6 + 8
Target:
<scratch>
6 + 8 , C: 0
, 4 C: 1
1 4
</scratch>
1 4

<|endoftext|>Input:
9 + 1
Target:
<scratch>
9 + 1 , C: 0
, 0 C: 1
1 0
</scratch>
1 0

<|endoftext|>Input:
3 + 6
Target:
<scratch>
3 + 6 , C: 0
, 9 C: 0
0 9
</scratch>
9

<|endoftext|>Input:
7 + 4
Target:
<scratch>
7 + 4 , C: 0
, 1 C: 1
1 1
</scratch>
1 1

<|endoftext|>Input:
6 + 1
Target:
<scratch>
6 + 1 , C: 0
, 7 C: 0
0 7
</scratch>
7

<|endoftext|>Input:
1 + 6
Target:
<scratch>
1 + 6 , C: 0
, 7 C: 0
0 7
</scratch>
7

<|endoftext|>Input:
8 + 2
Target:
<scratch>
8 + 2 , C: 0
, 0 C: 1
1 0
</scratch>
1 0

<|endoftext|>Input:
5 + 6
Target:
<scratch>
5 + 6 , C: 0
, 1 C: 1
1 1
</scratch>
1 1

<|endoftext|>Input:
9 + 8
Target:
<scratch>
9 + 8 , C: 0
, 7 C: 1
1 7
</scratch>
1 7

<|endoftext|>
Input:
8 9 + 1 6
Target:
<scratch>
8 9 + 1 6 , C: 0
8 + 1 , 5 C: 1
, 0 5 C: 1
1 0 5
</scratch>
1 0 5

<|endoftext|>Input:
4 7 + 8 9
Target:
<scra

In [None]:
def eval_output(args, output, answers, context, example_classes, accuracy, target_save, tokenizer, show=False, direct=False, endoftext="<|endoftext|>"):
    """Score decoded generations against their targets and save the hits.

    Every row of ``output[1][0][:, :, 0]`` is decoded with
    ``tokenizer.decode`` and split into lines. A predicted answer is picked
    from those lines and compared case-insensitively with the matching entry
    of ``answers``. For each match the base context, the generation up to and
    including the answer line, and ``endoftext`` are appended to
    ``target_save``.

    Args:
        args: Settings object; ``dataset_mode`` must be ``"arithmetic"`` (this
            is asserted) and ``few_shot_train`` must be falsy when an example
            is saved, otherwise ``NotImplementedError`` is raised.
        output: Model output; only ``output[1][0][:, :, 0]`` is read.
            NOTE(review): presumably a (batch, length, 1) array of token ids -
            confirm against the caller.
        answers: Expected answer string per example.
        context: Prompt text that is prepended to each saved example.
        example_classes: Key per example under which accuracy is tallied.
        accuracy: Mapping updated in place; each class maps to a dict with
            ``'accurate'`` and ``'total'`` counters.
        target_save: Path handed to ``basic_open(..., 'a+')``.
        tokenizer: Object providing ``decode``.
        show: Also print each saved example to stdout.
        direct: When no end-of-text or ``</scratch>`` line is found, take the
            second line as the prediction (if there is one) instead of the
            first.
        endoftext: Terminator appended to every saved example.

    Returns:
        Indices (positions in the batch) of the examples answered correctly.
    """
    assert args.dataset_mode == "arithmetic", args.dataset_mode
    eot_marker = "<|endoftext|>"
    kept = []
    token_rows = output[1][0][:, :, 0]
    batch = zip(enumerate(token_rows), answers, context, example_classes)

    for (idx, tokens), target, base_context, example_class in batch:
        # `generated` starts out as a list of lines; the cqa/gsm branches
        # below rebind it to a single string.
        generated = tokenizer.decode(tokens).split('\n')

        if example_class not in accuracy:
            accuracy[example_class] = {'accurate': 0, 'total': 0}
        accuracy[example_class]['total'] += 1

        if not generated:
            continue

        try:
            if args.dataset_mode == "cqa":
                generated = generated[0].partition(eot_marker)[0]
                prediction = generated[-3]
            elif args.dataset_mode == "gsm":
                prediction = ""
                for line_idx, line in enumerate(generated):
                    if "####" not in line:
                        continue
                    generated = "\n".join(generated[:line_idx + 1])
                    generated = generated.partition(eot_marker)[0]
                    prediction = generated.split("####")[-1].strip()
                    break
            elif args.dataset_mode == "arithmetic":
                if not generated:
                    continue
                if eot_marker in generated:
                    # Answer is the line right before a bare end-of-text line.
                    prediction_index = generated.index(eot_marker) - 1
                elif "</scratch>" in generated:
                    # Answer is the line right after the scratchpad closes.
                    prediction_index = generated.index("</scratch>") + 1
                    if prediction_index == len(generated):
                        continue
                else:
                    prediction_index = 1 if direct and len(generated) > 1 else 0
                prediction = generated[prediction_index]

            # Drop anything generated after an inline end-of-text marker.
            prediction = prediction.partition(eot_marker)[0]

            if prediction.lower() != target.lower():
                continue

            accuracy[example_class]['accurate'] += 1
            with basic_open(target_save, 'a+') as sink:
                if args.dataset_mode in ("cqa", "gsm"):
                    new_example = base_context + generated + endoftext
                elif args.dataset_mode == "arithmetic":
                    if args.few_shot_train:
                        raise NotImplementedError
                    kept_text = "\n".join(generated[:prediction_index + 1])
                    kept_text = kept_text.partition(eot_marker)[0]
                    new_example = base_context + kept_text + endoftext
                if show:
                    print(new_example)
                print(new_example, file=sink, end="")
            kept.append(idx)
        except IndexError:
            # Best effort: an out-of-range line lookup just skips the example.
            pass

    return kept
