In [10]:
from math import log10, pi
import cmath
import time
import random

BASE = 10_000  # default digit base (10^4); a base in the 10^3-10^4 range keeps FFT/Karatsuba digit lists efficient

# ---------- Utilities: integer <-> digit-list conversion ----------

def _split_to_digits(x: int, base: int):
    """Splits an integer x into a list of digits in the given base (little-endian)."""
    if x == 0:
        return [0]
    digits = []
    abs_x = abs(x)
    while abs_x:
        digits.append(abs_x % base)
        abs_x //= base
    return digits


def _digits_to_int(digits, base):
    """Converts a list of digits (little-endian) back to an integer."""
    res = 0
    powb = 1
    for d in digits:
        res += d * powb
        powb *= base
    return res


def _normalize_inplace(arr, base):
    """
    Handles carries/borrows in-place for a list of digits thatอาจมีค่าเป็นลบ
    ใช้หลังจาก Karatsuba + FFT ที่มีค่า negative ชั่วคราวได้
    """
    carry = 0
    for i in range(len(arr)):
        total = arr[i] + carry
        arr[i] = total % base
        carry = total // base
        # ปรับกรณี % กับเลขลบของ Python
        if arr[i] < 0:
            arr[i] += base
            carry -= 1

    while carry:
        arr.append(carry % base)
        carry //= base

    # ลบ 0 ด้านหน้าที่ไม่จำเป็น (most significant)
    while len(arr) > 1 and arr[-1] == 0:
        arr.pop()


def _add_digits(a, b, base):
    """Adds two digit lists (little-endian)."""
    n = max(len(a), len(b))
    c = []
    carry = 0
    for i in range(n):
        av = a[i] if i < len(a) else 0
        bv = b[i] if i < len(b) else 0
        s = av + bv + carry
        c.append(s % base)
        carry = s // base
    if carry:
        c.append(carry)
    return c


def _sub_digits(a, b, base):
    """
    Subtracts two digit lists (a - b) ในกรณี a >= b และใช้ borrow แบบปกติ
    (ใช้สำหรับงานทั่วไป ไม่ใช่ intermediate ที่ต้องการค่าติดลบได้)
    """
    c = []
    borrow = 0
    for i in range(len(a)):
        av = a[i]
        bv = b[i] if i < len(b) else 0
        s = av - bv - borrow
        if s < 0:
            s += base
            borrow = 1
        else:
            borrow = 0
        c.append(s)

    while len(c) > 1 and c[-1] == 0:
        c.pop()
    return c

# ---------- Multiplication Algorithms ----------

def schoolbook_mul(a: int, b: int):
    """Multiply ``a`` by ``b`` using Python's built-in ``*`` directly."""
    product = a * b
    return product


def karatsuba_mul(a: int, b: int, base=10**4, cutoff=64):
    """
    Multiply two integers with Karatsuba's algorithm on base-``base`` digits.

    Args:
        a, b: integers of any sign.
        base: digit base used to split the operands into digit lists.
        cutoff: once a sub-problem has at most this many digits, it is
            multiplied directly with the built-in ``*``. Must be >= 1.

    Returns:
        The exact product ``a * b``.

    Raises:
        ValueError: if ``cutoff < 1``. Such a cutoff used to recurse forever
            (a 1-digit operand splits into itself plus an empty half) and
            crash with RecursionError.
    """
    if cutoff < 1:
        raise ValueError("cutoff must be >= 1")

    # Work on magnitudes and re-apply the sign at the end.
    sign = -1 if (a < 0) != (b < 0) else 1
    a, b = abs(a), abs(b)

    Ad = _split_to_digits(a, base)
    Bd = _split_to_digits(b, base)

    # Pad both operands to the same length so the top-level split is aligned.
    n = max(len(Ad), len(Bd))
    Ad += [0] * (n - len(Ad))
    Bd += [0] * (n - len(Bd))

    def subtract_inplace(acc, digits):
        # acc -= digits position-wise, growing acc as needed. Entries may go
        # negative; the caller normalizes afterwards.
        if len(acc) < len(digits):
            acc.extend([0] * (len(digits) - len(acc)))
        for i, val in enumerate(digits):
            acc[i] -= val

    def kar(a_digits, b_digits):
        m = len(a_digits)
        if m == 0:
            return [0]

        # Base case: with few digits the built-in multiply is fastest.
        if m <= cutoff:
            prod = _digits_to_int(a_digits, base) * _digits_to_int(b_digits, base)
            return _split_to_digits(prod, base)

        # Split both operands at the same position k:
        #   A = a0 + a1*B^k,   B = b0 + b1*B^k
        # The halves of a and b may differ in length by a carry digit; the
        # identity only needs the same k for both.
        k = (m + 1) // 2
        a0, a1 = a_digits[:k], a_digits[k:]
        b0, b1 = b_digits[:k], b_digits[k:]

        z0 = kar(a0, b0)  # a0 * b0
        z2 = kar(a1, b1)  # a1 * b1
        # (a0 + a1) * (b0 + b1)
        z1 = kar(_add_digits(a0, a1, base), _add_digits(b0, b1, base))

        # z1 = P3 - z0 - z2. Individual digits may be temporarily negative
        # until the final normalization.
        subtract_inplace(z1, z0)
        subtract_inplace(z1, z2)

        # Place z0, z1, z2 at offsets B^0, B^k, B^(2k).
        res = [0] * (max(len(z0), len(z1) + k, len(z2) + 2 * k) + 1)
        for i, v in enumerate(z0):
            res[i] += v
        for i, v in enumerate(z1):
            res[i + k] += v
        for i, v in enumerate(z2):
            res[i + 2 * k] += v

        _normalize_inplace(res, base)
        return res

    return sign * _digits_to_int(kar(Ad, Bd), base)

# ---------- FFT ----------

def fft(a, invert):
    """
    In-place iterative radix-2 Cooley–Tukey FFT.

    Args:
        a: mutable sequence of complex numbers. ``len(a)`` must be a power of
            two; 0 and 1 are accepted as trivial cases.
        invert: False for the forward transform, True for the inverse.
            - Forward uses the kernel ``exp(+2*pi*i*j*k/n)``.
            - Inverse uses ``exp(-2*pi*i*j*k/n)`` followed by division by ``n``.
            So ``fft(a, True)`` undoes ``fft(a, False)``.

    Raises:
        ValueError: if ``len(a)`` is not a power of two. The bit-reversal
            and butterfly passes below would otherwise silently produce
            garbage.
    """
    n = len(a)
    if n & (n - 1):
        raise ValueError("fft length must be a power of two")

    # Bit-reversal permutation.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j |= bit
        if i < j:
            a[i], a[j] = a[j], a[i]

    sign = -1 if invert else 1
    length = 2
    while length <= n:
        half = length >> 1
        ang = sign * 2 * pi / length
        # Compute each twiddle directly from its angle rather than by repeated
        # multiplication (w *= wlen), which accumulates rounding error.
        # The table is shared by every block in this stage.
        roots = [cmath.rect(1, ang * t) for t in range(half)]
        for start in range(0, n, length):
            for t in range(half):
                lo = start + t
                hi = lo + half
                u = a[lo]
                v = a[hi] * roots[t]
                a[lo] = u + v
                a[hi] = u - v
        length <<= 1

    if invert:
        for i in range(n):
            a[i] /= n


def fft_mul(a: int, b: int, base=BASE):
    """
    Multiply two integers via FFT convolution of their base-``base`` digits.

    The digit lists are convolved with complex floating-point FFTs. The
    coefficients are then rounded back to integers and their carries
    normalized.

    Args:
        a, b: integers of any sign.
        base: digit base for the split; defaults to the module's ``BASE``.

    Returns:
        The product ``a * b``.

    Raises:
        ArithmeticError: if any convolution coefficient comes back more than
            0.4 away from an integer. Coefficients grow like
            ``len * (base - 1)**2``, so this signals that double precision
            was not enough for this input size/base and the rounded product
            may be wrong; retry with a smaller ``base``. Such results used to
            be returned silently. The check is a heuristic safety net, not a
            proof of correctness.
    """
    if a == 0 or b == 0:
        return 0

    sign = -1 if (a < 0) != (b < 0) else 1
    # _split_to_digits works on |x|, so the operands need no negation here.
    A = _split_to_digits(a, base)
    B = _split_to_digits(b, base)

    m = len(A) + len(B) - 1  # number of coefficients in the exact product
    n = 1
    while n < m:
        n <<= 1

    # Zero-padded complex inputs for the FFT.
    fa = [complex(d, 0) for d in A] + [0j] * (n - len(A))
    fb = [complex(d, 0) for d in B] + [0j] * (n - len(B))

    fft(fa, invert=False)
    fft(fb, invert=False)
    fa = [x * y for x, y in zip(fa, fb)]  # point-wise product
    fft(fa, invert=True)

    # Round the real parts back to integer coefficients, tracking how far
    # the floating-point values were from an integer.
    result = [0] * (m + 2)  # spare room for the final carries
    max_err = 0.0
    for i in range(m):
        re = fa[i].real
        coeff = round(re)
        max_err = max(max_err, abs(re - coeff))
        result[i] = coeff

    if max_err > 0.4:
        raise ArithmeticError(
            f"FFT rounding error {max_err:.3f} is too large for base {base}; "
            "use a smaller base"
        )

    _normalize_inplace(result, base)
    return sign * _digits_to_int(result, base)

# ---------- Dispatcher ----------

def multiply(a: int, b: int, method='auto', **kwargs):
    """
    Multiply ``a`` and ``b`` with the selected algorithm.

    Methods:
      - 'school'    : built-in ``*`` (via ``schoolbook_mul``)
      - 'karatsuba' : ``karatsuba_mul``; kwargs ``base`` (default ``BASE``)
                      and ``cutoff`` (default 64)
      - 'fft'       : ``fft_mul``; kwarg ``base`` (default ``BASE``)
      - 'auto'      : choose by the decimal digit count of the larger operand:
                      < 50 digits -> built-in, < 3000 -> Karatsuba, else FFT.
                      In auto mode ``base`` defaults to 1000, and ``cutoff``
                      is forwarded to Karatsuba.

    Raises:
        ValueError: for an unrecognized ``method``. Previously a typo
            silently fell through to 'auto'.
    """
    if method not in ('school', 'karatsuba', 'fft', 'auto'):
        raise ValueError(f"unknown multiplication method: {method!r}")

    if a == 0 or b == 0:
        return 0

    cutoff = kwargs.get('cutoff', 64)

    if method == 'school':
        return schoolbook_mul(a, b)

    if method == 'karatsuba':
        return karatsuba_mul(a, b, base=kwargs.get('base', BASE), cutoff=cutoff)

    if method == 'fft':
        return fft_mul(a, b, base=kwargs.get('base', BASE))

    # --- auto selection ---
    # Estimate decimal digit counts with log10, which accepts arbitrarily
    # large ints and avoids an expensive int -> str conversion.
    # Both operands are non-zero here, so log10 is defined.
    digits_a = int(log10(abs(a))) + 1
    digits_b = int(log10(abs(b))) + 1
    maxd = max(digits_a, digits_b)

    if maxd < 50:
        # Small numbers: the built-in multiply is fastest.
        return schoolbook_mul(a, b)
    if maxd < 3000:
        # Medium sizes: Karatsuba.
        return karatsuba_mul(a, b, base=kwargs.get('base', 1000), cutoff=cutoff)
    # Very large numbers: FFT.
    return fft_mul(a, b, base=kwargs.get('base', 1000))

# ---------- Benchmark ----------

def benchmark_example():
    """Run a light benchmark of every ``multiply`` method on random 2000-digit numbers.

    Each method's result is checked against the built-in product. Timing
    uses ``time.perf_counter``, the clock meant for measuring short
    intervals; ``time.time`` is a wall clock that can be adjusted.
    """
    A_DIGITS = 2000
    B_DIGITS = 2000

    def random_number(ndigits):
        # The first digit is 1-9 so the number really has ``ndigits`` digits;
        # a random leading '0' would silently shorten it.
        head = str(random.randint(1, 9))
        tail = ''.join(str(random.randint(0, 9)) for _ in range(ndigits - 1))
        return int(head + tail)

    a = random_number(A_DIGITS)
    b = random_number(B_DIGITS)

    print("--- Benchmark ---")
    print(f"Digits: {len(str(a))} (A), {len(str(b))} (B)")

    expected_res = a * b
    print(f"Built-in result length: {len(str(expected_res))}")
    print("-" * 20)

    for method in ('school', 'karatsuba', 'fft', 'auto'):
        # FFT runs with base 10**4; every other method with base 1000.
        base = 10**4 if method == 'fft' else 1000
        t0 = time.perf_counter()
        res = multiply(a, b, method=method, base=base)
        dt = time.perf_counter() - t0

        correct = (res == expected_res)
        print(f"{method:10s} -> Time: {dt:.3f}s, Result digits: {len(str(res))}, Correct: {correct}")

    print("-" * 20)


if __name__ == "__main__":
    # Cross-check every multiplication routine against the built-in product.
    print("--- Test Cases ---")
    cases = (
        (1234, 5678),
        (10**50 + 12345, 10**40 + 54321),
        (-12345678901234567890, 98765432109876543210),
        (0, 123456),
        (1, 10**1000 + 1),
    )

    for x, y in cases:
        reference = x * y
        outcomes = (
            ("Karatsuba: ", karatsuba_mul(x, y, base=BASE)),
            ("FFT:       ", fft_mul(x, y, base=BASE)),
            ("Auto:      ", multiply(x, y, method='auto', base=BASE)),
        )

        print(f"a = {x}")
        print(f"b = {y}")
        print(f"builtin:  {reference}")
        for label, value in outcomes:
            print(f"{label}{value} ({value == reference})")
        print("-----")

    # Uncomment the next line to run the benchmark.
    # benchmark_example()


--- Test Cases ---
a = 1234
b = 5678
builtin:  7006652
Karatsuba: 7006652 (True)
FFT:       7006652 (True)
Auto:      7006652 (True)
-----
a = 100000000000000000000000000000000000000000000012345
b = 10000000000000000000000000000000000054321
builtin:  1000000000000000000000000000000000005432100000123450000000000000000000000000000000670592745
Karatsuba: 1000000000000000000000000000000000005432100000123450000000000000000000000000000000670592745 (True)
FFT:       1000000000000000000000000000000000005432100000123450000000000000000000000000000000670592745 (True)
Auto:      1000000000000000000000000000000000005432100000123450000000000000000000000000000000670592745 (True)
-----
a = -12345678901234567890
b = 98765432109876543210
builtin:  -1219326311370217952237463801111263526900
Karatsuba: -1219326311370217952237463801111263526900 (True)
FFT:       -1219326311370217952237463801111263526900 (True)
Auto:      -1219326311370217952237463801111263526900 (True)
-----
a = 0
b = 123456
builtin:  0
Kar