# Algorithm 2

This is a Python implementation of algorithm 2 from the article "Unsigned integer multiplication with a constant fraction" by Michael Schmidt.

TODO: Add a link to the article.

## Requirements

- Python 3.10 or later
- Jupyter Notebook

## Usage

Start by running all cells in the notebook. This will set everything up and run a few small examples.

Then skip the next cell, which contains the large chunk of implementation code. You will find further instructions below.


In [52]:
# pyright: strict
from dataclasses import dataclass
import math


@dataclass(frozen=True, kw_only=True)
class Problem:
    u: int
    t: int
    d: int
    r_d: int

    def __post_init__(self) -> None:
        assert self.u >= 0
        assert self.t >= 0
        assert self.d > 0
        assert 0 <= self.r_d < self.d

    @staticmethod
    def floor(u: int, t: int, d: int) -> "Problem":
        return Problem(u=u, t=t, d=d, r_d=0)

    @staticmethod
    def round(u: int, t: int, d: int) -> "Problem":
        return Problem(u=u, t=t, d=d, r_d=d // 2)

    @staticmethod
    def ceil(u: int, t: int, d: int) -> "Problem":
        return Problem(u=u, t=t, d=d, r_d=d - 1)

    def simplified(self) -> "Problem":
        if self.at(self.u) == 0:
            # the problem is a constant 0 function
            return Problem(u=self.u, t=0, d=1, r_d=0)

        gcd = math.gcd(self.t, self.d)
        return Problem(
            u=self.u,
            t=self.t // gcd,
            d=self.d // gcd,
            r_d=self.r_d // gcd,
        )

    def at(self, x: int) -> int:
        """
        Returns the value of the function at x.
        """
        return (x * self.t + self.r_d) // self.d


@dataclass(frozen=True, kw_only=True)
class Solution:
    f: int
    a: int
    s: int

    def original(self) -> "Solution":
        """
        Returns the smallest solution from which this solution was derived.
        """
        solution = self
        while solution.s > 0 and solution.f % 2 == 0:
            solution = Solution(f=solution.f // 2, a=solution.a // 2, s=solution.s - 1)
        return solution


@dataclass(frozen=True)
class Range:
    """
    The closed interval of integers [min, max]; both endpoints are members.
    """

    min: int  # smallest member (inclusive)
    max: int  # largest member (inclusive)

    def __post_init__(self) -> None:
        # An empty interval is a programming error.
        assert not self.max < self.min


@dataclass(frozen=True, kw_only=True)
class SolutionRange:
    """
    A multiplier f and shift s together with a closed range A of offsets a.
    """

    f: int  # multiplier
    A: Range  # offsets a, inclusive on both ends
    s: int  # right-shift amount

    def pick_any(self) -> Solution:
        """
        Pick one concrete Solution from the range.

        Preference order: a = 0 when A starts at 0, then a = f when f lies in
        A, then the upper end of A when it equals 2**s - 1, else the lower end.
        """
        lo, hi = self.A.min, self.A.max
        if lo == 0:
            chosen = 0
        elif lo <= self.f <= hi:
            chosen = self.f
        else:
            chosen = hi if hi == 2**self.s - 1 else lo
        return Solution(f=self.f, a=chosen, s=self.s)

    def original(self) -> "SolutionRange":
        """
        Returns the smallest solution from which this solution was derived.

        Factors of two are stripped from f (halving both ends of A and
        decrementing s alongside) until f is odd or s reaches 0.
        """
        result = self
        while result.s > 0 and result.f % 2 == 0:
            halved = Range(result.A.min // 2, result.A.max // 2)
            result = SolutionRange(f=result.f // 2, A=halved, s=result.s - 1)
        return result


# Largest value of each unsigned integer width, convenient as the input bound u.
U8: int = (1 << 8) - 1
U16: int = (1 << 16) - 1
U32: int = (1 << 32) - 1
U64: int = (1 << 64) - 1


def div_ceil(a: int, b: int) -> int:
    """Return ceil(a / b) using exact integer arithmetic (b must be nonzero)."""
    quotient, remainder = divmod(a, b)
    return quotient + (1 if remainder else 0)


def primitive_solution(p: Problem) -> Solution:
    """
    Return the straightforward (usually not minimal) solution for p.

    The shift s is the smallest value with 2**s >= (u + 1) * d. The
    multiplier and offset are t * 2**s / d and r_d * 2**s / d, rounded up.
    """
    # ceil(log2(n)) == (n - 1).bit_length() for every n >= 1; n >= 1 here
    # because u >= 0 and d >= 1. Integer arithmetic avoids math.log2's float
    # rounding, which under-estimates s when n is just above a power of two
    # (e.g. n = 2**60 + 1 rounds to 2**60).
    n = (p.u + 1) * p.d
    s = (n - 1).bit_length()
    f = ((p.t << s) + p.d - 1) // p.d
    a = ((p.r_d << s) + p.d - 1) // p.d
    return Solution(f=f, a=a, s=s)


def algorithm_2_minimal(
    p: Problem, force_minimal: bool = False, log: bool = True
) -> tuple[SolutionRange, bool]:
    """
    Returns a solution and whether it is *guaranteed* to be minimal.

    Even if False is returned, the solution might be minimal in some cases.

    Args:
        p: the problem. It is simplified first, so the result is expressed
            for p.simplified(), which computes the same values on 0..u.
        force_minimal: when p.u < 2 * p.d, refine anyway instead of returning
            the primitive solution unrefined. For p.u >= 2 * p.d the
            refinement always runs.
        log: print the size of the input set used for the refinement.

    Returns:
        (solution_range, guaranteed). solution_range holds the multiplier f,
        the shift s and a range A of offsets a. guaranteed is True when s is
        known to be the smallest possible shift.
    """

    p = p.simplified()

    if p.t == 0 or p.d == 1:
        # trivial solution
        return SolutionRange(f=p.t, A=Range(0, 0), s=0), True

    if p.t >= p.d:
        # reduce the problem to a smaller one:
        # (x*t + r_d) // d == k*x + (x*(t % d) + r_d) // d with k = t // d, and
        # adding k << s to the multiplier adds exactly k*x after the shift.
        k = p.t // p.d
        solution, is_minimal = algorithm_2_minimal(
            Problem(u=p.u, t=p.t % p.d, d=p.d, r_d=p.r_d),
            force_minimal=force_minimal,
            log=log,
        )
        return (
            SolutionRange(f=solution.f + (k << solution.s), A=solution.A, s=solution.s),
            is_minimal,
        )

    # use the un-derived primitive solution as the starting point
    solution = primitive_solution(p).original()

    if solution.s == 0:
        # trivially minimal (with s == 0 the only offset below 2**0 is a = 0)
        return SolutionRange(f=solution.f, A=Range(0, 0), s=0), True

    # original() stops only when s == 0 (handled above) or f is odd.
    assert solution.f % 2 == 1

    # chicken out if u<2d
    # (in that case get_input_set would cover every step in 0..u, which can be
    # large; the unrefined primitive solution is returned as a single-offset
    # range instead)
    if p.u < 2 * p.d and not force_minimal:
        # the solution is not guaranteed to be minimal
        return (
            SolutionRange(f=solution.f, A=Range(solution.a, solution.a), s=solution.s),
            False,
        )

    # refine the primitive solution
    X = get_input_set(p)
    if log:
        print(f"|X| = {len(X)}")

    # Widen the single primitive offset to the full range of offsets that are
    # valid for this f and s on X.
    solution = algorithm_1a(p, f=solution.f, s=solution.s, X=X)
    assert solution is not None

    # f is odd here, so f - 1 and f + 1 are even. If either admits an offset at
    # the same shift, original() halves it to a strictly smaller shift.
    # Repeat until neither neighbour works (s unchanged) or s reaches 0.
    while True:
        assert solution.f % 2 == 1
        original_s = solution.s
        for f in [solution.f - 1, solution.f + 1]:
            neighbor = algorithm_1a(p, f=f, s=solution.s, X=X)
            if neighbor is not None:
                solution = neighbor.original()
                break

        if solution.s == 0 or solution.s == original_s:
            # found the minimal solution range
            return solution, True


def algorithm_1f(p: Problem, a: int, s: int, X: set[int]) -> Range | None:
    """
    Find every multiplier f for a fixed offset a and shift s.

    f must satisfy (x * f + a) >> s == p.at(x) for x = p.u and for every
    nonzero x in X. Returns the closed range of such f, or None when the
    constraints contradict each other. p.u must be positive (it is used as a
    divisor).
    """
    # For an input x with expected output y the requirement is
    #     y * 2**s <= x * f + a <= (y + 1) * 2**s - 1,
    # which bounds f from below and above.
    v = (p.u * p.t + p.r_d) // p.d
    f_min = div_ceil((v << s) - a, p.u)
    # Same "- 1" as in the loop below: x * f + a must stay strictly below
    # (v + 1) * 2**s.
    f_max = (((v + 1) << s) - a - 1) // p.u

    for x in X:
        if x == 0:
            # x = 0 only constrains a, never f
            continue

        y = (x * p.t + p.r_d) // p.d

        f_min = max(f_min, div_ceil((y << s) - a, x))
        f_max = min(f_max, (((y + 1) << s) - a - 1) // x)

        if f_min > f_max:
            return None

    # Needed when X has no nonzero members, so the loop never checked.
    if f_min > f_max:
        return None
    return Range(f_min, f_max)


def algorithm_1a(p: Problem, f: int, s: int, X: set[int]) -> SolutionRange | None:
    """
    Find every offset a for a fixed multiplier f and shift s.

    a must lie in [0, 2**s) and satisfy (x * f + a) >> s == p.at(x) for
    every x in X. Returns a SolutionRange holding the closed range of such a,
    or None when no offset works.
    """
    lower = 0
    upper = (1 << s) - 1

    for x in X:
        expected = (x * p.t + p.r_d) // p.d
        product = x * f
        # expected * 2**s <= product + a <= (expected + 1) * 2**s - 1
        lower = max(lower, (expected << s) - product)
        upper = min(upper, ((expected + 1) << s) - product - 1)
        if upper < lower:
            return None

    return SolutionRange(f=f, A=Range(lower, upper), s=s)


def get_input_set(p: Problem) -> set[int]:
    """
    Return the inputs x that candidate solutions are checked against.

    NOTE(review): the set is meant to be sufficient, i.e. a candidate that is
    correct on every x in it should be correct on all of 0..u. For the reduced
    case with u >= 2d this relies on "conjecture 14" from the article; it is
    not proven here.

    Requires p.t != 0 (add_range divides by p.t).
    """

    def add_range(X: set[int], start: int, stop: int) -> set[int]:
        """
        Add inputs from start..stop to X (in place) and return X.

        The set always receives start and stop. When t < d, it also receives
        the pair (x - 1, x) around every x in (start, stop] where p.at steps
        up. Otherwise every x in start..stop is added.
        """
        X.add(start)
        X.add(stop)
        # For t < d, consecutive steps of p.at are at least d // t inputs
        # apart, so after a step we can jump d // t - 1 inputs ahead without
        # passing the next one. For t >= d, jump <= 0 and the else branch runs.
        jump = p.d // p.t - 1
        if jump > 0:
            last = p.at(start)

            # skip to the first x value for which p.at(x) > last
            # (estimate from (x * t + r_d) >= (last + 1) * d, then scan forward)
            x = ((last + 1) * p.d - p.r_d) // p.t - 1
            while x <= stop and p.at(x) <= last:
                x += 1

            if x > stop:
                # no step inside (start, stop]
                return X

            assert p.at(x) == last + 1, f"failed for {start=} {stop=} {last=} {x=}"
            X.add(x - 1)
            X.add(x)
            last = p.at(x)

            while True:
                # jump close to the next step, then scan to it exactly
                x += jump
                current = p.at(x)
                while current == last:
                    x += 1
                    current = p.at(x)

                if x > stop:
                    break

                X.add(x - 1)
                X.add(x)
                last = current
        else:
            # add the range start..stop
            X.update(range(start + 1, stop))
        return X

    if 0 < p.t < p.d and math.gcd(p.t, p.d) == 1 and p.d * 2 <= p.u:
        # conjecture 14
        # x_bc: inputs with (x * t + r_d) % d == d - 1 (just before a step)
        # x_ad: inputs with (x * t + r_d) % d == 0     (exactly at a step)
        # x1/x2 are the smallest such inputs, x3/x4 the largest ones <= u.
        t_inv = pow(p.t, -1, p.d)
        x_bc = (-t_inv * (p.r_d + 1)) % p.d
        x_ad = (-t_inv * p.r_d) % p.d
        x1 = x_bc
        x2 = x_ad
        x3 = p.u // p.d * p.d + x_bc
        x4 = p.u // p.d * p.d + x_ad
        if x3 > p.u:
            x3 -= p.d
        if x4 > p.u:
            x4 -= p.d
        return {x1, x2, x3, x4}
    elif p.d * 2 <= p.u:
        # only consider the first and last d values
        # (p.at(x + d) == p.at(x) + t, so the step pattern repeats every d
        # inputs)
        X = {0}
        add_range(X, 0, p.d - 1)
        add_range(X, p.u - p.d + 1, p.u)
        return X
    else:
        # consider all x values
        return add_range(set(), 0, p.u)


def find_a_zero(p: Problem) -> Solution | None:
    """
    Find the solution with a = 0 and the smallest shift, or None if none exists.

    The problem is simplified first, so the result is expressed for
    p.simplified(), which computes the same values on 0..u.
    """
    p = p.simplified()

    if p.t == 0 or p.d == 1:
        # trivial solution: x * t with no shift
        return Solution(f=p.t, a=0, s=0)

    if p.t >= p.d:
        # reduce the problem to a smaller one:
        # (x*t + r_d) // d == k*x + (x*(t % d) + r_d) // d with k = t // d,
        # so solve for t % d and add k << s to the multiplier.
        k = p.t // p.d
        solution = find_a_zero(Problem(u=p.u, t=p.t % p.d, d=p.d, r_d=p.r_d))
        if solution is None:
            return None
        return Solution(
            f=solution.f + (k << solution.s),
            a=0,
            s=solution.s,
        )

    X = get_input_set(p)

    # Start from a generous shift (+10 to be safe) and shrink it afterwards.
    # ceil(log2(n)) is computed exactly as (n - 1).bit_length() rather than
    # with floating-point math.log2. Here u >= 1 and d >= 2: simplified() turns
    # u == 0 into the constant-zero problem, and d == 1 returned above.
    s = (p.u - 1).bit_length() + (p.d - 1).bit_length() + 10
    F = algorithm_1f(p, a=0, s=s, X=X)
    if F is None:
        return None
    assert F.min > 0

    # The most even multiplier sheds the most factors of two. original()
    # strips them one shift at a time and never drives s below 0.
    return Solution(f=pick_most_even(F), a=0, s=s).original()


def pick_most_even(range: Range) -> int:
    """
    Return the member of the range divisible by the highest power of two.

    Returns 0 immediately when the range starts at 0.
    """
    lo, hi = range.min, range.max
    if lo == 0:
        return 0

    # Repeatedly keep only the multiples of 2 (divided by 2) until exactly one
    # candidate remains; k counts how many halvings were applied.
    k = 0
    while lo < hi:
        lo, hi = -(-lo // 2), hi >> 1
        k += 1

    return lo << k


def verify(p: Problem, s: Solution | SolutionRange, log: bool = True) -> bool | None:
    """
    Verifies that the solution (range) is correct.

    A Solution is checked by comparing (x * f + a) >> s against p.at(x). When
    2 * d < u, only the first and last d + 1 inputs are compared; otherwise
    all of 0..u are. A SolutionRange is checked at both ends of its offset
    range A.

    Returns True or False, or None when min(u, d) exceeds 1,000,000 and the
    check is skipped. Mismatches and the skip warning are always printed; the
    final ✅/❌ line only when log is True.
    """
    if min(p.u, p.d) > 1_000_000:
        print(f" ⚠️ {s}")
        print(" ⚠️ Unable to verify, because the input space is too large!")
        return None

    if isinstance(s, SolutionRange):
        # (x * f + a) >> s never decreases as a grows, so if both ends of A
        # work, every offset in between works too.
        lowest = Solution(f=s.f, a=s.A.min, s=s.s)
        highest = Solution(f=s.f, a=s.A.max, s=s.s)
        correct = verify(p, lowest, log=False) and verify(p, highest, log=False)
        assert correct is not None
    else:
        spans: tuple[range, ...]
        if p.d * 2 < p.u:
            # only the first and last d + 1 inputs are checked
            spans = (range(0, p.d + 1), range(p.u - p.d, p.u + 1))
        else:
            spans = (range(p.u + 1),)

        # first mismatching input in span order, or None
        incorrect_x = next(
            (
                x
                for span in spans
                for x in span
                if p.at(x) != (x * s.f + s.a) >> s.s
            ),
            None,
        )

        correct = incorrect_x is None
        if incorrect_x is not None:
            expected = p.at(incorrect_x)
            actual = (incorrect_x * s.f + s.a) >> s.s
            print(
                f" ❌ INCORRECT: Expected {expected} but got {actual} for x={incorrect_x}"
            )

    if log:
        print(f" ✅ {s}" if correct else f" ❌ {s}")
    return correct


def is_minimal(p: Problem, s: Solution | SolutionRange) -> bool:
    """
    If a solution with a smaller s exists, return False.

    A solution with s > 0 and an even multiplier is never minimal, because it
    can be halved. For an odd f, a smaller shift exists when either even
    neighbour f - 1 or f + 1 admits a valid offset at the same shift.
    """

    if s.s == 0:
        return True
    if s.f % 2 == 0:
        # (x*f + a) >> s == (x*(f//2) + a//2) >> (s-1) for even f and s >= 1.
        # This matches the parity check in is_minimal_zero.
        return False

    X = get_input_set(p)

    if algorithm_1a(p, s.f + 1, s.s, X=X) is not None:
        return False
    if s.f > 0 and algorithm_1a(p, s.f - 1, s.s, X=X) is not None:
        return False

    return True


def is_minimal_zero(p: Problem, s: Solution) -> bool:
    """
    If an a = 0 solution with a smaller s exists, return False.

    s must itself have a = 0. An even multiplier (with s > 0) is reducible.
    For an odd one, a smaller shift exists when an even neighbour f + 1 or
    f - 1 still admits the offset 0 at the same shift.
    """
    assert s.a == 0

    if s.s == 0:
        return True
    if s.f % 2 == 0:
        return False

    X = get_input_set(p)
    for neighbor_f in (s.f + 1, s.f - 1):
        candidate = algorithm_1a(p, neighbor_f, s.s, X=X)
        # A.min is clamped at 0, so A.min == 0 means a = 0 is allowed.
        if candidate is not None and candidate.A.min == 0:
            return False

    return True


def solve(p: Problem) -> None:
    """
    Print the primitive solution, the minimal solution range and the smallest
    a = 0 solution for p.

    The minimal range and the a = 0 solution are passed through verify().
    """
    print(f"Solving {p}")
    p = p.simplified()

    print("    Primitive solution:")
    print(f"    {primitive_solution(p).original()}")
    print()

    print("    Minimal solution range:")
    # Heuristic for how large the input set will be; the exact refinement is
    # only forced when this is small.
    estimate = min(p.u, 2 * p.d, 4 * (p.t % p.d))
    minimal, guaranteed = algorithm_2_minimal(
        p, force_minimal=estimate < 10_000, log=False
    )
    verify(p, minimal)
    if not guaranteed:
        print(" ⚠️ The solution range is NOT guaranteed to be minimal!")
    print()

    zero = find_a_zero(p)
    if zero is not None:
        print("    Smallest solution with a=0:")
        verify(p, zero)
    else:
        print("    No solutions with a=0 exist")
    print()

## Examples

The code cells below contain small examples. Each one defines a problem and finds the minimal solution range as well as the smallest solution with a=0.

Calculated solutions are verified and printed to the console.

There is a playground further below, so I recommend leaving these examples unmodified and experimenting there instead.


In [53]:
# Converting an 8-bit unsigned integer to a 5-bit unsigned integer
problem = Problem.round(u=U8, t=31, d=255)
solve(problem)

# Converting a 5-bit unsigned integer to an 8-bit unsigned integer
problem = Problem.round(u=31, t=255, d=31)
solve(problem)

Solving Problem(u=255, t=31, d=255, r_d=127)
    Primitive solution:
    Solution(f=249, a=1020, s=11)

    Minimal solution range:
 ✅ SolutionRange(f=249, A=Range(min=1014, max=1026), s=11)

    No solutions with a=0 exist

Solving Problem(u=31, t=255, d=31, r_d=15)
    Primitive solution:
    Solution(f=1053, a=62, s=7)

    Minimal solution range:
 ✅ SolutionRange(f=527, A=Range(min=23, max=23), s=6)

    No solutions with a=0 exist



In [54]:
# Dividing by 17 with different rounding modes
d = 17
solve(Problem.floor(u=U32, t=1, d=d))
solve(Problem.round(u=U32, t=1, d=d))
solve(Problem.ceil(u=U32, t=1, d=d))

Solving Problem(u=4294967295, t=1, d=17, r_d=0)
    Primitive solution:
    Solution(f=4042322161, a=0, s=36)

    Minimal solution range:
 ✅ SolutionRange(f=252645135, A=Range(min=252645135, max=252645135), s=32)

    Smallest solution with a=0:
 ✅ Solution(f=4042322161, a=0, s=36)

Solving Problem(u=4294967295, t=1, d=17, r_d=8)
    Primitive solution:
    Solution(f=4042322161, a=32338577288, s=36)

    Minimal solution range:
 ✅ SolutionRange(f=252645135, A=Range(min=2273806215, max=2273806215), s=32)

    No solutions with a=0 exist

Solving Problem(u=4294967295, t=1, d=17, r_d=16)
    Primitive solution:
    Solution(f=4042322161, a=64677154575, s=36)

    Minimal solution range:
 ✅ SolutionRange(f=252645135, A=Range(min=4294967295, max=4294967295), s=32)

    No solutions with a=0 exist



In [55]:
# Converting °F to °C involves a multiplication by 5/9
problem = Problem.round(t=5, d=9, u=U32)
solve(problem)

Solving Problem(u=4294967295, t=5, d=9, r_d=4)
    Primitive solution:
    Solution(f=9544371769, a=7635497415, s=34)

    Minimal solution range:
 ✅ SolutionRange(f=9544371769, A=Range(min=7635497415, max=9067153180), s=34)

    No solutions with a=0 exist



## Playground

### Defining the problem

The rounding function is specified by using either `Problem.floor(u, t, d)`, `Problem.round(u, t, d)`, or `Problem.ceil(u, t, d)`.

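`u` is the largest input value, i.e. inputs range over 0 to `u` inclusive. While you can use any number, the constants `U8`, `U16`, `U32`, and `U64` (the maximum values of the corresponding unsigned integer types) are provided for convenience.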

`t` and `d` define the fraction `t/d`.


In [59]:
# Playground: edit the problem below and re-run this cell
problem = Problem.floor(t=2551819, d=477878880, u=U32)
solve(problem)

Solving Problem(u=4294967295, t=2551819, d=477878880, r_d=0)
    Primitive solution:
    Solution(f=12312940052777975, a=0, s=61)

    Minimal solution range:
 ⚠️ SolutionRange(f=6156470026388987, A=Range(min=732438272, max=2477622224), s=60)
 ⚠️ Unable to verify, because the input space is too large!

    Smallest solution with a=0:
 ⚠️ Solution(f=12312940052777975, a=0, s=61)
 ⚠️ Unable to verify, because the input space is too large!



In [57]:
# TODO: remove this before release
# Exhaustive self-check on small problems: every u in 1..MAX-1, every t/d in
# lowest terms with 1 <= t, d < MAX, and every rounding offset r_d in 0..d-1.
MAX = 32

for t in range(1, MAX):
    print(f"t = {t}")
    for d in range(1, MAX):
        if math.gcd(t, d) != 1:
            continue

        for u in range(1, MAX):
            for r_d in range(d):
                # Build the problem from the loop variables so that every
                # (u, r_d) combination is actually exercised.
                problem = Problem(u=u, t=t, d=d, r_d=r_d)
                try:
                    solution, _ = algorithm_2_minimal(
                        problem, force_minimal=True, log=False
                    )
                    assert verify(problem, solution, log=False)
                    assert is_minimal(problem, solution)
                    zero = find_a_zero(problem)
                    if zero is not None:
                        assert verify(problem, zero, log=False)
                        assert is_minimal_zero(problem, zero)
                except Exception:
                    # Report the failing problem, rerun it with logging for
                    # context, then re-raise.
                    print(f"❌ failed for {problem}")
                    solution, _ = algorithm_2_minimal(problem, force_minimal=True)
                    verify(problem, solution)
                    raise

t = 1
t = 2
t = 3
t = 4
t = 5
t = 6
t = 7
t = 8
t = 9
t = 10
t = 11
t = 12
t = 13
t = 14
t = 15
t = 16
t = 17
t = 18
t = 19
t = 20
t = 21
t = 22
t = 23
t = 24
t = 25
t = 26
t = 27
t = 28
t = 29
t = 30
t = 31
