In [3]:
import argparse
import json
import random
from functools import partial
from pathlib import Path
from typing import List
from tqdm import tqdm

ones = {
    0: "",
    1: "one",
    2: "two",
    3: "three",
    4: "four",
    5: "five",
    6: "six",
    7: "seven",
    8: "eight",
    9: "nine",
    10: "ten",
    11: "eleven",
    12: "twelve",
    13: "thirteen",
    14: "fourteen",
    15: "fifteen",
    16: "sixteen",
    17: "seventeen",
    18: "eighteen",
    19: "nineteen",
}
tens = {2: "twenty", 3: "thirty", 4: "forty", 5: "fifty", 6: "sixty", 7: "seventy", 8: "eighty", 9: "ninety"}

magnitudes = [
    "ones",
    "tens",
    "hundreds",
    "thousands",
    "ten-thousands",
    "hundred-thousands",
    "millions",
    "ten-millions",
    "hundred-millions",
    "billions",
]


def say_number(i):
    """
    Convert an integer in to it's word representation.

    say_number(i: integer) -> string
    """
    if i < 0:
        return _join("negative", _say_number_pos(-i))
    if i == 0:
        return "zero"
    return _say_number_pos(i)


def _say_number_pos(i):
    if i < 20:
        return ones[i]
    if i < 100:
        return _join(tens[i // 10], ones[i % 10])

    return _divide(i, 100, "hundred")


def _divide(dividend, divisor, magnitude):
    return _join(
        _say_number_pos(dividend // divisor),
        magnitude,
        _say_number_pos(dividend % divisor),
    )


def _join(*args):
    return " ".join(filter(bool, args))


def _say_magnitude(i: int):
    assert i < len(magnitudes), f"magnitude not supported: {i}, max is {len(magnitudes)}"
    return magnitudes[i]

def generate_number(i):
    assert i > 0
    return random.randint(10 ** (i - 1), 10 ** i - 1)


def digits(x: int, size: int = 0, sep=0) -> List[int]:
    n = x
    digits = []
    while n > 0:
        digits.append(n % 10)
        n //= 10

    while len(digits) < size:
        digits.append(sep)
       

    return digits


def sweep(start_a: int, end_a: int, start_b: int, end_b: int):
    for x in range(start_a, end_a):
        for y in range(start_b, end_b):
            yield x, y


def sample(digits_x: int, digits_y: int, size: int):
    for _ in range(size):
        yield generate_number(digits_x), generate_number(digits_y)



# TODO: fix the algorithm, there is more recursion that needed

def add(x:list, y:list) -> list:
    """
    Adds two lists of digits representing integers.
    E.g., [1,2,3] + [4,5,6] = [5,7,9]
    """
    max_len = max(len(x), len(y))
    # Pad the shorter list with leading zeros
    x = [0] * (max_len - len(x)) + x
    y = [0] * (max_len - len(y)) + y

    result = []
    carry = 0

    for i in range(max_len - 1, -1, -1):
        total = x[i] + y[i] + carry
        result_digit = total % 10
        carry = total // 10
        result.insert(0, result_digit)

    if carry > 0:
        result.insert(0, carry)

    return result


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--task', type=str, default='addition', help='task to generate scratchpads for'
    )
    parser.add_argument(
        '--num-digit', type=int, default=4, help='maximum number of digits')
    parser.add_argument(
        "--number-prompts", type=int, default=0, help="Number of prompts, defaults to 0, meaning all combinations"
    )
    parser.add_argument(
        "--use-alternative-function", action='store_true',
        help="Compute an alternative function to multiplication, defined to be computationally similar to "
             "multiplication"
    )
    parser.add_argument("--output-path", type=str, default=None, help="output path")
    parser.add_argument(
        "--no-scratchpad", action='store_true',
        help="If set, generates only question prompts without scratchpads"
    )
    parser.add_argument(
        "--one-digit", action='store_true',
        help="If set, generates scratchpad without partial products logic"
    )
    args = parser.parse_args()

    print(f"one-digit: {args.one_digit}")

    output_path = Path(args.output_path)
    output_path.mkdir(parents=True, exist_ok=True)

    digits = list(range(1, args.num_digit + 1))
    for k in digits:
        start_k, end_k = 10 ** (k - 1), 10 ** k

        if args.one_digit:
            y_digits = [1]
        else:
            y_digits = digits[:k]
        for p in y_digits:
            outputs = []
            start_p, end_p = 10 ** (p - 1), 10 ** p
            total = (end_p - start_p) * (end_k - start_k)

            if 0 < args.number_prompts < total:
                num_prompts = args.number_prompts
                generator_fn = partial(sample, digits_x=k, digits_y=p, size=num_prompts)
            else:
                num_prompts = total
                generator_fn = partial(sweep, start_a=start_k, end_a=end_k, start_b=start_p, end_b=end_p)

            for i, (x, y) in tqdm(enumerate(generator_fn()), total=num_prompts, desc=f"{k}_by_{p}"):

                print(f"x: {x}, y: {y}")
                
                if args.task=="multiplication":
                    if args.one_digit==True:
                        prompt, question, expected_answer, partials = generate_prompt_multiplication_one_digit(x, y)
                    else:
                        prompt, question, expected_answer = generate_prompt_multiplication(x, y)

                outputs.append(
                    {
                        "question": question,
                        "prompt": prompt,
                        "expected_answer": expected_answer,
                        "partials": partials if args.one_digit else None,
                    }
                )

            alt_str = 'add_' if args.use_alternative_function else ''
            output_finale = (
                f"{alt_str}scratchpad_{k}_by_{p}_{args.number_prompts if args.number_prompts > 0 else ''}prompts.json"
            )
            result_path = output_path / output_finale
            with open(result_path, "w") as w:
                json.dump(outputs, w, indent=2)

In [4]:
def generate_prompt_multiplication_one_digit(x: int, y: int, size: int = 0, include_always_carry: bool = False, zfill: bool = False):

    digits_x = digits(x, size=size) if zfill else digits(x, size=size, sep=' ')
    digits_y = digits(y)
    
    question = f"What is the product between {str(x).zfill(size) if zfill else x:>4} and {y}?"
    steps = [f"Multiply {str(x).zfill(size) if zfill else x:>4} by {y} step by step."]

    dy = digits_y[0]
    carry_over = 0
    step_count = 0
    results = []
    partials = []

    # Iterate over xâ€™s digits from least to most significant
    for i, dx in enumerate(digits_x):
        # step_count += 1       # NOTE: removed to avoid more number
        if dx == ' ':
            dx = 0
        next_step = dx * dy + carry_over
        old_carry = carry_over

        if (next_step >= 10 and i < len(digits_x) - 1) or include_always_carry:
            residual = next_step % 10
            carry_over = next_step // 10
        else:
            residual = next_step
            carry_over = 0


        if old_carry > 0 or include_always_carry:
            step = (
                f"Multiply {dy} by the digit in the {_say_magnitude(i)} place of {str(x).zfill(size) if zfill else x:>4}, which is {dx}. "
                f"Add the carryover {old_carry}: ({dx}x{dy})+{old_carry}={str(next_step).zfill(2) if zfill else next_step:>2}. "
                f"Write down {residual}{' and carry over ' + str(carry_over) if (carry_over > 0 or include_always_carry) else ''}."
            )
        else:
            step = (
                f"Multiply {dy} by the digit in the {_say_magnitude(i)} place of {x}, which is {dx}: "
                f"{dx} x {dy} = {next_step}. "
                f"Write down {residual}{' and carry over ' + str(carry_over) if carry_over > 0 else ''}."
            )

        steps.append(step)
        results.insert(0, residual)
        partials.append((dx, dy, carry_over))

    product = x * y
    steps.append(f"The final product is {str(product).zfill(size) if zfill else product:>5}.")

    final_prompt = "\n".join(steps)     # NOTE: remove \n to avoid new lines
    return final_prompt, question, product, partials

x = 12
y = 9
size = 4
include_always_carry = True
zfill = True

text = generate_prompt_multiplication_one_digit(x, y, size, include_always_carry, zfill)[0]
print(text)


Multiply 0012 by 9 step by step.
Multiply 9 by the digit in the ones place of 0012, which is 2. Add the carryover 0: (2x9)+0=18. Write down 8 and carry over 1.
Multiply 9 by the digit in the tens place of 0012, which is 1. Add the carryover 1: (1x9)+1=10. Write down 0 and carry over 1.
Multiply 9 by the digit in the hundreds place of 0012, which is 0. Add the carryover 1: (0x9)+1=01. Write down 1 and carry over 0.
Multiply 9 by the digit in the thousands place of 0012, which is 0. Add the carryover 0: (0x9)+0=00. Write down 0 and carry over 0.
The final product is  0108.


In [5]:
x = 99

print(f"{x:>4d}")
print(f"{str(x).zfill(4)}")

  99
0099


In [9]:
from itertools import product

def compute_carry(x_digits, k):
    """Compute the carries for a 4-digit number x_digits and multiplier k."""
    carries = []
    carry = 0
    for d in x_digits[::-1]:  # start from units
        prod = d * k + carry
        carry = prod // 10
        carries.append(carry)
    return carries[::-1]

def find_carry_changers_full_number(x, k):
    """
    Find combinations of x (modifying one digit) and k that change the carry
    in each of the 4 intermediate steps.
    
    Returns a dict where each step contains a list of tuples:
      (new full number as int, k used, new carry)
    """
    if isinstance(x, int):
        x = f"{x:04d}"
    x_digits = [int(c) for c in x]
    orig_carry = compute_carry(x_digits, k)

    results = {i: [] for i in range(4)}

    # Modify each digit (0..9)
    for idx in range(4):
        for new_digit in range(10):
            if new_digit == x_digits[idx]:
                continue
            x_new_digits = x_digits.copy()
            x_new_digits[idx] = new_digit
            new_carry = compute_carry(x_new_digits, k)
            for step in range(4):
                if new_carry[step] != orig_carry[step]:
                    x_new_number = int(''.join(map(str, x_new_digits)))
                    results[step].append({
                        'new_number': x_new_number,
                        'k': k,
                        'new_carry': new_carry[step]
                    })

    # Modify k (0..9)
    for new_k in range(10):
        if new_k == k:
            continue
        new_carry = compute_carry(x_digits, new_k)
        for step in range(4):
            if new_carry[step] != orig_carry[step]:
                x_number = int(''.join(map(str, x_digits)))
                results[step].append({
                    'new_number': x_number,
                    'k': new_k,
                    'new_carry': new_carry[step]
                })

    return results

# --- Example ---
x = 12   # can also be '0012'
k = 9
changes = find_carry_changers_full_number(x, k)

for step, lst in changes.items():
    print(f"\nStep {step} (digit {['d','c','b','a'][step]}):")
    for change in lst:
        print(f"  New number: {change['new_number']}, k: {change['k']}, new carry: {change['new_carry']}")



Step 0 (digit d):
  New number: 2012, k: 9, new carry: 1
  New number: 3012, k: 9, new carry: 2
  New number: 4012, k: 9, new carry: 3
  New number: 5012, k: 9, new carry: 4
  New number: 6012, k: 9, new carry: 5
  New number: 7012, k: 9, new carry: 6
  New number: 8012, k: 9, new carry: 7
  New number: 9012, k: 9, new carry: 8

Step 1 (digit c):
  New number: 112, k: 9, new carry: 1
  New number: 212, k: 9, new carry: 1
  New number: 312, k: 9, new carry: 2
  New number: 412, k: 9, new carry: 3
  New number: 512, k: 9, new carry: 4
  New number: 612, k: 9, new carry: 5
  New number: 712, k: 9, new carry: 6
  New number: 812, k: 9, new carry: 7
  New number: 912, k: 9, new carry: 8

Step 2 (digit b):
  New number: 2, k: 9, new carry: 0
  New number: 32, k: 9, new carry: 2
  New number: 42, k: 9, new carry: 3
  New number: 52, k: 9, new carry: 4
  New number: 62, k: 9, new carry: 5
  New number: 72, k: 9, new carry: 6
  New number: 82, k: 9, new carry: 7
  New number: 92, k: 9, new car

In [18]:
compute_carry([0, 0, 9, 3], 9)

[0, 0, 8, 2]

In [8]:
changes

{0: [{'changed_pos': 0, 'new_digit': 2, 'k': 9, 'new_carry': 1},
  {'changed_pos': 0, 'new_digit': 3, 'k': 9, 'new_carry': 2},
  {'changed_pos': 0, 'new_digit': 4, 'k': 9, 'new_carry': 3},
  {'changed_pos': 0, 'new_digit': 5, 'k': 9, 'new_carry': 4},
  {'changed_pos': 0, 'new_digit': 6, 'k': 9, 'new_carry': 5},
  {'changed_pos': 0, 'new_digit': 7, 'k': 9, 'new_carry': 6},
  {'changed_pos': 0, 'new_digit': 8, 'k': 9, 'new_carry': 7},
  {'changed_pos': 0, 'new_digit': 9, 'k': 9, 'new_carry': 8}],
 1: [{'changed_pos': 1, 'new_digit': 1, 'k': 9, 'new_carry': 1},
  {'changed_pos': 1, 'new_digit': 2, 'k': 9, 'new_carry': 1},
  {'changed_pos': 1, 'new_digit': 3, 'k': 9, 'new_carry': 2},
  {'changed_pos': 1, 'new_digit': 4, 'k': 9, 'new_carry': 3},
  {'changed_pos': 1, 'new_digit': 5, 'k': 9, 'new_carry': 4},
  {'changed_pos': 1, 'new_digit': 6, 'k': 9, 'new_carry': 5},
  {'changed_pos': 1, 'new_digit': 7, 'k': 9, 'new_carry': 6},
  {'changed_pos': 1, 'new_digit': 8, 'k': 9, 'new_carry': 7},
 

In [43]:
def find_carry_changers_no_propagation(x, k):
    """
    Find combinations of x (modifying one digit) and k that change the carry
    at a given step without causing propagation to more significant steps.
    
    Returns a dict: each step has a list of dicts with:
      - 'new_number' : the full 4-digit number after change
      - 'k' : multiplier used
      - 'new_carry' : carry at that step
    """
    if isinstance(x, int):
        x = f"{x:04d}"
    x_digits = [int(c) for c in x]
    orig_carry = compute_carry(x_digits, k)

    results = {i: [] for i in range(4)}

    # Modify each digit
    for idx in range(4):
        for new_digit in range(10):
            if new_digit == x_digits[idx]:
                continue
            x_new_digits = x_digits.copy()
            x_new_digits[idx] = new_digit
            new_carry = compute_carry(x_new_digits, k)
            # Only keep if carry at this step flips AND all higher steps unchanged
            for step in range(4):
                if new_carry[step] != orig_carry[step] and new_carry[step+1:] == orig_carry[step+1:] and new_carry[:step] == orig_carry[:step]:
                    x_new_number = int(''.join(map(str, x_new_digits)))
                    results[step].append({
                        'new_number': x_new_number,
                        'k': k,
                        'new_carry': new_carry[step]
                    })

    # Modify k
    for new_k in range(10):
        if new_k == k:
            continue
        new_carry = compute_carry(x_digits, new_k)
        for step in range(4):
            if new_carry[step] != orig_carry[step] and new_carry[step+1:] == orig_carry[step+1:] and new_carry[:step] == orig_carry[:step]:
                x_number = int(''.join(map(str, x_digits)))
                results[step].append({
                    'new_number': x_number,
                    'k': new_k,
                    'new_carry': new_carry[step]
                })

    return results

# --- Example ---
x = 123
k = 9
changes = find_carry_changers_no_propagation(x, k)

for step, lst in changes.items():
    print(f"\nStep {step} (digit {['d','c','b','a'][step]}):")
    for change in lst:
        print(f"  New number: {change['new_number']}, k: {change['k']}, new carry: {change['new_carry']}")



Step 0 (digit d):
  New number: 1123, k: 9, new carry: 1
  New number: 2123, k: 9, new carry: 1
  New number: 3123, k: 9, new carry: 2
  New number: 4123, k: 9, new carry: 3
  New number: 5123, k: 9, new carry: 4
  New number: 6123, k: 9, new carry: 5
  New number: 7123, k: 9, new carry: 6
  New number: 8123, k: 9, new carry: 7
  New number: 9123, k: 9, new carry: 8

Step 1 (digit c):
  New number: 23, k: 9, new carry: 0
  New number: 223, k: 9, new carry: 2
  New number: 323, k: 9, new carry: 2
  New number: 423, k: 9, new carry: 3
  New number: 523, k: 9, new carry: 4
  New number: 623, k: 9, new carry: 5
  New number: 723, k: 9, new carry: 6
  New number: 823, k: 9, new carry: 7
  New number: 923, k: 9, new carry: 8

Step 2 (digit b):
  New number: 113, k: 9, new carry: 1
  New number: 143, k: 9, new carry: 3
  New number: 153, k: 9, new carry: 4
  New number: 163, k: 9, new carry: 5
  New number: 173, k: 9, new carry: 6
  New number: 183, k: 9, new carry: 7
  New number: 193, k: 9

Multiply 0123 by 9 step by step.
Multiply 9 by the digit in the ones place of 0123, which is 3. Add the carryover 0: (3x9)+0=27. Write down 7 and carry over 2.
Multiply 9 by the digit in the tens place of 0123, which is 2. Add the carryover 2: (2x9)+2=20. Write down 0 and carry over 2.
Multiply 9 by the digit in the hundreds place of 0123, which is 1. Add the carryover 2: (1x9)+2=11. Write down 1 and carry over 1.
Multiply 9 by the digit in the thousands place of 0123, which is 0. Add the carryover 1: (0x9)+1=01. Write down 1 and carry over 0.
The final product is  1107.


In [100]:
step = 3
sample = -1
new_number = changes[step][sample]['new_number']
new_k = changes[step][sample]['k']

text, *_ = generate_prompt_multiplication_one_digit(x, k, size=4, include_always_carry=True, zfill=True)
print(text)
print(compute_carry(digits(x, size=4), k), "\n")
print(step)

text_modified, *_= generate_prompt_multiplication_one_digit(new_number, new_k, size=4, include_always_carry=True, zfill=True)
print(compute_carry(digits(new_number, size=4), new_k), "\n")
print(text_modified)

Multiply 0123 by 9 step by step.
Multiply 9 by the digit in the ones place of 0123, which is 3. Add the carryover 0: (3x9)+0=27. Write down 7 and carry over 2.
Multiply 9 by the digit in the tens place of 0123, which is 2. Add the carryover 2: (2x9)+2=20. Write down 0 and carry over 2.
Multiply 9 by the digit in the hundreds place of 0123, which is 1. Add the carryover 2: (1x9)+2=11. Write down 1 and carry over 1.
Multiply 9 by the digit in the thousands place of 0123, which is 0. Add the carryover 1: (0x9)+1=01. Write down 1 and carry over 0.
The final product is  1107.
[2, 1, 0, 0] 

3
[8, 1, 0, 0] 

Multiply 0129 by 9 step by step.
Multiply 9 by the digit in the ones place of 0129, which is 9. Add the carryover 0: (9x9)+0=81. Write down 1 and carry over 8.
Multiply 9 by the digit in the tens place of 0129, which is 2. Add the carryover 8: (2x9)+8=26. Write down 6 and carry over 2.
Multiply 9 by the digit in the hundreds place of 0129, which is 1. Add the carryover 2: (1x9)+2=11. Wri

In [105]:
from itertools import product

def compute_carry(x_digits, k):
    """Compute the carries for a 4-digit number x_digits and multiplier k."""
    carries = []
    carry = 0
    for d in x_digits[::-1]:  # start from units
        prod = d * k + carry
        carry = prod // 10
        carries.append(carry)
    return carries[::-1]

def find_carry_changers_no_propagation(x, k):
    """
    Find combinations of x (modifying one digit) and k that change the carry
    at a given step without causing propagation to more significant steps.
    
    Returns a dict: each step has a list of dicts with:
      - 'new_number' : the full 4-digit number after change
      - 'k' : multiplier used
      - 'new_carry' : carry at that step
    """
    if isinstance(x, int):
        x = f"{x:04d}"
    x_digits = [int(c) for c in x]
    orig_carry = compute_carry(x_digits, k)

    results = {i: [] for i in range(4)}

    # Modify each digit
    for idx in range(4):
        for new_digit in range(10):
            if new_digit == x_digits[idx]:
                continue
            x_new_digits = x_digits.copy()
            x_new_digits[idx] = new_digit
            new_carry = compute_carry(x_new_digits, k)
            # Only keep if carry at this step flips AND all other steps unchanged
            for step in range(4):
                if new_carry[step] != orig_carry[step] and new_carry[step+1:] == orig_carry[step+1:] and new_carry[:step] == orig_carry[:step]:
                    x_new_number = int(''.join(map(str, x_new_digits)))
                    results[step].append({
                        'new_number': x_new_number,
                        'k': k,
                        'new_carry': new_carry[step]
                    })

    # Modify k
    for new_k in range(10):
        if new_k == k:
            continue
        new_carry = compute_carry(x_digits, new_k)
        for step in range(4):
            if new_carry[step] != orig_carry[step] and new_carry[step+1:] == orig_carry[step+1:] and new_carry[:step] == orig_carry[:step]:
                x_number = int(''.join(map(str, x_digits)))
                results[step].append({
                    'new_number': x_number,
                    'k': new_k,
                    'new_carry': new_carry[step]
                })

    return results

def get_all_combinations_per_step(changes):
    """
    Given the changes dict from find_carry_changers_no_propagation,
    return a dict where each step maps to all combinations of (new_number, k, new_carry).
    """
    results = {}
    for step, change_list in changes.items():
        results[step] = []
        for change in change_list:
            results[step].append((change['new_number'], change['k']))
    return results


# --- Example ---
x = 123
k = 9
changes = find_carry_changers_no_propagation(x, k)
combinations_per_step = get_all_combinations_per_step(changes)
combinations_per_step

{0: [(1123, 9),
  (2123, 9),
  (3123, 9),
  (4123, 9),
  (5123, 9),
  (6123, 9),
  (7123, 9),
  (8123, 9),
  (9123, 9)],
 1: [(23, 9),
  (223, 9),
  (323, 9),
  (423, 9),
  (523, 9),
  (623, 9),
  (723, 9),
  (823, 9),
  (923, 9)],
 2: [(113, 9), (143, 9), (153, 9), (163, 9), (173, 9), (183, 9), (193, 9)],
 3: [(124, 9), (125, 9), (126, 9), (127, 9), (128, 9), (129, 9)]}