Puzzle from the MindYourDecisions YouTube channel.

Find all pairs of non-negative integers `(a, b)` such that `sqrt(a) + sqrt(b) = sqrt(2023)`.

Bonus:
Find all pairs of non-negative integers `(a, b)` such that `sqrt(a) + sqrt(b) = sqrt(n)`, for an arbitrary non-negative integer `n`.

This is from https://youtu.be/pxHd8tLI65Q

In [68]:
# Basic way:
import math

def calc_b(a, n=2023):
  """Return the b satisfying sqrt(a) + sqrt(b) == sqrt(n), as a float.

  Squaring sqrt(b) = sqrt(n) - sqrt(a) gives b = n + a - 2*sqrt(a*n).  That
  value is an integer exactly when a*n is a perfect square; this case is
  detected and computed with exact integer arithmetic (math.isqrt), so it is
  never lost to float rounding.  Otherwise the float approximation is
  returned (for extremely large inputs it may still round to a whole number,
  so exact callers should test a*n with math.isqrt themselves).

  Only meaningful for integers 0 <= a <= n, so that sqrt(n) - sqrt(a) >= 0.

  Args:
    a: non-negative integer.
    n: right-hand-side integer; defaults to 2023 (the original puzzle).

  Returns:
    b as a float, so callers can keep using ``.is_integer()``.

  Raises:
    ValueError: if a * n is negative.
  """
  product = a * n
  root = math.isqrt(product)
  if root * root == product:
    # Exact path: b is an integer, computed without any float sqrt.
    return float(n + a - 2 * root)
  return n + a - 2 * math.sqrt(product)

def solve_all_integers(n=2023):
  """Yield every pair of non-negative integers (a, b) with sqrt(a) + sqrt(b) == sqrt(n).

  Brute force.  By symmetry it suffices to scan a with 4*a <= n, which makes
  a <= b.  For such a the equation forces b = n + a - 2*sqrt(a*n), an integer
  exactly when a*n is a perfect square.  That test is done with exact integer
  arithmetic (math.isqrt), so it works for any n and never depends on float
  rounding.

  Each solution is yielded as (a, b) and, when a != b, also as (b, a), so no
  pair is produced twice.  A negative n yields nothing.

  Args:
    n: integer right-hand side; defaults to 2023 (the original puzzle).
  """
  # n // 4 is inclusive: a == n/4 gives the symmetric solution a == b.
  for a in range(n // 4 + 1):
    product = a * n
    root = math.isqrt(product)
    if root * root != product:
      continue  # b would be irrational
    b = n + a - 2 * root
    yield (a, b)
    if a != b:
      yield (b, a)

# Print every solution for the default n = 2023, ordered by a.
# sorted() consumes the generator directly; no intermediate list is needed.
for pair in sorted(solve_all_integers()):
  print(pair)

(0, 2023)
(7, 1792)
(28, 1575)
(63, 1372)
(112, 1183)
(175, 1008)
(252, 847)
(343, 700)
(448, 567)
(567, 448)
(700, 343)
(847, 252)
(1008, 175)
(1183, 112)
(1372, 63)
(1575, 28)
(1792, 7)
(2023, 0)


In [69]:
# Faster way:
import math
from sympy.ntheory import factorint

def calc_b(a, n=2023):
  """Return the b satisfying sqrt(a) + sqrt(b) == sqrt(n), as a float.

  Redefined in this cell so that the cell runs on its own; it is not used by
  solve_all_integers_faster below.

  Squaring sqrt(b) = sqrt(n) - sqrt(a) gives b = n + a - 2*sqrt(a*n).  That
  value is an integer exactly when a*n is a perfect square; this case is
  detected and computed with exact integer arithmetic (math.isqrt).
  Otherwise the float approximation is returned.

  Only meaningful for integers 0 <= a <= n.

  Args:
    a: non-negative integer.
    n: right-hand-side integer; defaults to 2023 (the original puzzle).

  Returns:
    b as a float, so callers can keep using ``.is_integer()``.

  Raises:
    ValueError: if a * n is negative.
  """
  product = a * n
  root = math.isqrt(product)
  if root * root == product:
    # Exact path: b is an integer, computed without any float sqrt.
    return float(n + a - 2 * root)
  return n + a - 2 * math.sqrt(product)

def solve_all_integers_faster(n=2023):
  """Yield every pair of non-negative integers (a, b) with sqrt(a) + sqrt(b) == sqrt(n).

  Factor n = s**2 * m with m squarefree, so sqrt(n) = s * sqrt(m).  Every
  solution then has the form a = m*x**2, b = m*y**2 with x + y = s and
  x, y >= 0, so only s // 2 + 1 candidates need to be generated.

  Each solution is yielded as (a, b) and, when a != b, also as (b, a), so the
  symmetric pair is not produced twice.  n == 0 yields only (0, 0); a
  negative n yields nothing.

  Args:
    n: integer right-hand side; defaults to 2023 (the original puzzle).
  """
  if n <= 0:
    # factorint() is not meaningful here: factorint(0) reports 0 as a
    # "prime" and a negative n brings in -1.
    if n == 0:
      yield (0, 0)
    return

  square_constant = 1     # s: square root of the largest square dividing n
  nonsquare_constant = 1  # m: squarefree part of n
  for prime_factor, power in factorint(n).items():
    # Integer exponent (power // 2) keeps this exact for large factors.
    square_constant *= prime_factor ** (power // 2)
    if power % 2:
      nonsquare_constant *= prime_factor

  # x <= s // 2 guarantees x <= y, i.e. a <= b.
  for x in range(square_constant // 2 + 1):
    y = square_constant - x
    a = nonsquare_constant * x * x
    b = nonsquare_constant * y * y
    yield (a, b)
    if a != b:
      yield (b, a)

# Print every solution found by the factorisation-based solver, ordered by a.
# sorted() consumes the generator directly; no intermediate list is needed.
for pair in sorted(solve_all_integers_faster()):
  print(pair)


(0, 2023)
(7, 1792)
(28, 1575)
(63, 1372)
(112, 1183)
(175, 1008)
(252, 847)
(343, 700)
(448, 567)
(567, 448)
(700, 343)
(847, 252)
(1008, 175)
(1183, 112)
(1372, 63)
(1575, 28)
(1792, 7)
(2023, 0)


In [70]:
# Cross-check: both solvers must return the same pairs (with multiplicity) for n = 2023.
sorted(solve_all_integers_faster()) == sorted(solve_all_integers())

True

In [80]:
# Index the sorted solutions from the faster solver for easy inspection.
list(enumerate(sorted(solve_all_integers_faster())))

[(0, (0, 2023)),
 (1, (7, 1792)),
 (2, (28, 1575)),
 (3, (63, 1372)),
 (4, (112, 1183)),
 (5, (175, 1008)),
 (6, (252, 847)),
 (7, (343, 700)),
 (8, (448, 567)),
 (9, (567, 448)),
 (10, (700, 343)),
 (11, (847, 252)),
 (12, (1008, 175)),
 (13, (1183, 112)),
 (14, (1372, 63)),
 (15, (1575, 28)),
 (16, (1792, 7)),
 (17, (2023, 0))]

In [81]:
# Index the sorted solutions from the basic solver for easy inspection.
list(enumerate(sorted(solve_all_integers())))

[(0, (0, 2023)),
 (1, (7, 1792)),
 (2, (28, 1575)),
 (3, (63, 1372)),
 (4, (112, 1183)),
 (5, (175, 1008)),
 (6, (252, 847)),
 (7, (343, 700)),
 (8, (448, 567)),
 (9, (567, 448)),
 (10, (700, 343)),
 (11, (847, 252)),
 (12, (1008, 175)),
 (13, (1183, 112)),
 (14, (1372, 63)),
 (15, (1575, 28)),
 (16, (1792, 7)),
 (17, (2023, 0))]

In [88]:
from time import process_time_ns
def benchmark(f):
  """Return the process CPU time, in nanoseconds, spent exhausting *f*.

  *f* is an iterable (in this notebook, a solver generator).  It is drained
  completely via list(), so it is spent afterwards.  The result is a single
  measurement.
  """
  started_ns = process_time_ns()
  list(f)  # drain the iterable; the resulting list is discarded immediately
  finished_ns = process_time_ns()
  return finished_ns - started_ns
# A single sub-millisecond CPU-time sample is noisy, so report the fastest of
# several runs for each solver.  Every run needs a fresh generator because
# benchmark() consumes the one it is given.
BENCHMARK_RUNS = 25
print("Basic:\t{} ns".format(
  min(benchmark(solve_all_integers()) for _ in range(BENCHMARK_RUNS))))
print("Faster:\t{} ns".format(
  min(benchmark(solve_all_integers_faster()) for _ in range(BENCHMARK_RUNS))))


Basic:	388847 ns
Faster:	72335 ns
