Context: The Pythagorean theorem
  
If $a$, $b$, $c$ are the sides of a right triangle with longest side $c$, then
$a^2 + b^2 = c^2$.

This theorem is attributed to the Greek mathematician Pythagoras, who was born around 570 BC,
but it was also known to mathematicians in Mesopotamia, China, and India.

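The equation has infinitely many solutions in positive integers. Euclid discovered how to generate
all integer solutions (Euclid's Elements, c. 300 BC): for coprime integers $m > n > 0$ of opposite parity,
$a = m^2 - n^2$, $b = 2mn$, $c = m^2 + n^2$ is a primitive solution, every primitive solution arises this way
(up to swapping $a$ and $b$), and every other solution is an integer multiple of a primitive one.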

In [9]:
from math import gcd

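def generate_primitive_pythagorean_triples(N):
    """Yield primitive triples from Euclid's formula for 2 <= m < N."""
    for m in range(2, N):
        # n runs over the values below m whose parity is opposite to m's.
        for n in range(1 + (m % 2), m, 2):
            if gcd(m, n) == 1: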
                a = (m * m) - (n * n)
                b = 2 * m * n
                c = (m * m) + (n * n)
                yield (a, b, c)

for triple in generate_primitive_pythagorean_triples(10):
    print(triple)

(3, 4, 5)
(5, 12, 13)
(15, 8, 17)
(7, 24, 25)
(21, 20, 29)
(9, 40, 41)
(35, 12, 37)
(11, 60, 61)
(45, 28, 53)
(33, 56, 65)
(13, 84, 85)
(63, 16, 65)
(55, 48, 73)
(39, 80, 89)
(15, 112, 113)
(77, 36, 85)
(65, 72, 97)
(17, 144, 145)
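
The generator above yields only primitive triples. As a minimal sketch of how the remaining integer solutions follow from them, the cell below scales each primitive triple by $k = 1, 2, 3, \dots$ and collects every triple whose hypotenuse stays under a bound. The names `primitive_triples` and `all_triples_up_to` are illustrative and not part of the original notebook.

```python
from math import gcd, isqrt

def primitive_triples(limit):
    """Euclid's formula for 2 <= m < limit (same loops as the cell above)."""
    for m in range(2, limit):
        for n in range(1 + (m % 2), m, 2):
            if gcd(m, n) == 1:
                yield (m * m - n * n, 2 * m * n, m * m + n * n)

def all_triples_up_to(max_c):
    """Return every triple (a, b, c) with a < b and c <= max_c."""
    found = set()
    # c = m*m + n*n <= max_c forces m <= isqrt(max_c).
    for a, b, c in primitive_triples(isqrt(max_c) + 1):
        k = 1
        while k * c <= max_c:
            lo, hi = sorted((k * a, k * b))
            found.add((lo, hi, k * c))
            k += 1
    return found

for triple in sorted(all_triples_up_to(30), key=lambda t: t[2]):
    a, b, c = triple
    assert a * a + b * b == c * c  # each triple really is a solution
    print(triple)
```

Sorting the two legs before storing them means $(15, 8, 17)$ and $(8, 15, 17)$ count as the same triple.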


Fermat's conjecture (1637): The equation $a^n + b^n = c^n$ has no solutions in positive integers when $n \ge 3$.

Fermat claimed to have found a proof, but he never wrote it down. The conjecture was finally proved in 1994 by Andrew Wiles, a British mathematician.
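
A finite search proves nothing, but it shows what the statement says in practice. The sketch below looks for $a^n + b^n = c^n$ with all three values under a small bound, using a dictionary of $n$-th powers so that no roots are taken. The function name `fermat_counterexamples` and the bounds are illustrative choices, not something from the original notebook.

```python
def fermat_counterexamples(limit, exponents):
    """Return all (a, b, c, n) with a <= b, a**n + b**n == c**n, and a, b, c < limit."""
    hits = []
    for n in exponents:
        # Exact integer lookup table: n-th power -> its base.
        powers = {k ** n: k for k in range(1, limit)}
        for a in range(1, limit):
            for b in range(a, limit):
                c = powers.get(a ** n + b ** n)
                if c is not None:
                    hits.append((a, b, c, n))
    return hits

# n = 2 is the Pythagorean case, so solutions are expected here.
print(fermat_counterexamples(20, [2]))

# For n >= 3 the theorem says this list must be empty.
print(fermat_counterexamples(100, range(3, 8)))
```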

<img src="euler-conj.png" width="1200">

In [13]:
# 0. Brute force with a floating-point fifth root (unreliable, see step 1).

def find_solutions():
    solutions = []
    for a in range(1, 200):
        for b in range(1, 200):
            for c in range(1, 200):
                for d in range(1, 200):
                    e = (a**5 + b**5 + c**5 + d**5)**(1/5)
                    if e.is_integer() and 1 <= e < 200:
                        solutions.append((a, b, c, d, int(e)))
    return solutions

In [47]:
# 1. Use integer arithmetic instead of floating point!

from sympy import integer_nthroot

def find_solutions():
    solutions = []
    for a in range(1, 200):
        for b in range(1, 200):
            for c in range(1, 200):
                for d in range(1, 200):
                    e = a**5 + b**5 + c**5 + d**5
                    root, is_exact = integer_nthroot(e, 5)
                    if is_exact and 1 <= root < 200:
                        solutions.append((a, b, c, d, root))
    return solutions
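
To see why step 1 matters without depending on `sympy`, here is a minimal sketch of an exact integer fifth root by binary search. The helper `integer_fifth_root` is a hypothetical stand-in that mimics the `(root, is_exact)` return shape of `integer_nthroot`; it is not how `sympy` implements it. The last lines list the values of $k$ for which the floating-point test `((k**5) ** (1/5)).is_integer()` fails to recognize a perfect fifth power; the list may be non-empty because `1/5` has no exact binary representation.

```python
def integer_fifth_root(n):
    """Return (r, exact): r is the floor of the fifth root of n >= 0, exact is r**5 == n."""
    lo, hi = 0, 1
    # Grow hi until hi**5 > n, so that lo**5 <= n < hi**5.
    while hi ** 5 <= n:
        hi *= 2
    # Binary search keeps the invariant lo**5 <= n < hi**5.
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if mid ** 5 <= n:
            lo = mid
        else:
            hi = mid
    return lo, lo ** 5 == n

# Every perfect fifth power is recognized exactly.
assert all(integer_fifth_root(k ** 5) == (k, True) for k in range(1, 200))

# Values where the floating-point check would miss a perfect fifth power.
missed = [k for k in range(1, 200) if not ((k ** 5) ** (1 / 5)).is_integer()]
print(missed)
```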

In [20]:
# 2. Take advantage of symmetry.

from sympy import integer_nthroot

def find_solutions():
    solutions = []
    for a in range(1, 200):
        for b in range(a, 200):
            for c in range(b, 200):
                for d in range(c, 200):
                    e = a**5 + b**5 + c**5 + d**5
                    root, is_exact = integer_nthroot(e, 5)
                    if is_exact and 1 <= root < 200:
                        solutions.append((a, b, c, d, root))
    return solutions
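
The saving from step 2 can be estimated by counting loop iterations. With `range(1, 200)` each variable takes 199 values, so the unrestricted loops visit $199^4$ tuples, while the ordered loops visit one tuple per multiset of size 4. The sketch below compares the two counts; `count_all` and `count_ordered` are illustrative helper names.

```python
from math import comb

def count_all(n):
    """Number of 4-tuples with every entry in 1..n."""
    return n ** 4

def count_ordered(n):
    """Number of 4-tuples with 1 <= a <= b <= c <= d <= n (multisets of size 4)."""
    return comb(n + 3, 4)

n = 199  # range(1, 200) has 199 values
print(count_all(n), count_ordered(n), count_all(n) / count_ordered(n))
```

For large `n` the ratio approaches $4! = 24$, the number of orderings of four distinct values.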

In [21]:
# 3. Use itertools.

from itertools import combinations_with_replacement
from sympy import integer_nthroot

def find_solutions():
    solutions = []
    for combination in combinations_with_replacement(range(1, 200), 4):
        a, b, c, d = combination
        e = a**5 + b**5 + c**5 + d**5
        root, is_exact = integer_nthroot(e, 5)
        if is_exact and 1 <= root < 200:
            solutions.append((a, b, c, d, root))
    return solutions
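
Step 3 changes only the shape of the code, not the tuples visited. A quick check on a small range, assuming nothing beyond the standard library, confirms that `combinations_with_replacement` produces the same tuples in the same order as the four nested loops from step 2. The wrapper names `nested_loops` and `with_itertools` are illustrative.

```python
from itertools import combinations_with_replacement

def nested_loops(lo, hi):
    """Ordered 4-tuples from explicit nested loops, as in step 2."""
    return [(a, b, c, d)
            for a in range(lo, hi)
            for b in range(a, hi)
            for c in range(b, hi)
            for d in range(c, hi)]

def with_itertools(lo, hi):
    """The same tuples from itertools, as in step 3."""
    return list(combinations_with_replacement(range(lo, hi), 4))

assert nested_loops(1, 8) == with_itertools(1, 8)
print(len(with_itertools(1, 8)), "tuples, identical in both versions")
```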

In [35]:
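# 4. Pre-compute fifth powers.

from itertools import combinations_with_replacement

def find_solutions(N):
    # Map each fifth power to its base, so membership tests replace root extraction.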
    fifth_powers = {k**5 : k for k in range(1, N)}
    solutions = []
    for a, b, c, d in combinations_with_replacement(fifth_powers, 4):
        e = a + b + c + d
        if e in fifth_powers:
            solutions.append((fifth_powers[a], fifth_powers[b], fifth_powers[c],
                              fifth_powers[d], fifth_powers[e]))
    return solutions

In [38]:
%time find_solutions(200)

CPU times: user 9.15 s, sys: 15.8 ms, total: 9.16 s
Wall time: 9.16 s


[(27, 84, 110, 133, 144)]

In [45]:
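# 5. Pre-compute values of a^5 + b^5
from math import gcd  # gcd() with more than two arguments needs Python 3.9+

def find_solutions(N):
    # Map each sum a^5 + b^5 (with a <= b) to the pair that produces it.
    lookup = {a**5 + b**5: (a, b)
              for a in range(1, N)
              for b in range(a, N)}
    S = sorted(lookup)
    solutions = []
    for e in range(1, N):
        e5 = e ** 5
        halfway = e5 // 2
        for s in S:
            # The smaller of the two pair sums is at most e^5 / 2, so stop there.
            if s > halfway:
                break
            if (e5 - s) in lookup:
                a, b = lookup[s]
                c, d = lookup[e5 - s]
                # Keep only the ordering a <= b <= c <= d and skip
                # multiples of smaller solutions.
                if b <= c and gcd(a, b, c, d) == 1:
                    solutions.append((a, b, c, d, e))
    return solutions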

In [46]:
%time find_solutions(200)

CPU times: user 143 ms, sys: 3.09 ms, total: 146 ms
Wall time: 145 ms


[(27, 84, 110, 133, 144)]
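
Since Python integers are exact, the reported solution can be checked directly. The sketch below verifies it and also shows the split into two pair sums that step 5 relies on; `is_solution` is an illustrative helper, not part of the original notebook.

```python
def is_solution(a, b, c, d, e):
    """Exact integer check of a^5 + b^5 + c^5 + d^5 == e^5."""
    return a**5 + b**5 + c**5 + d**5 == e**5

a, b, c, d, e = 27, 84, 110, 133, 144
assert is_solution(a, b, c, d, e)

# Step 5 finds this as two pair sums that add up to e^5,
# with the smaller pair sum no larger than half of e^5.
small = a**5 + b**5
large = c**5 + d**5
assert small + large == e**5
assert small <= e**5 // 2
print(small, large, e**5)
```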

Lessons:

1. Avoid floating point
2. Take advantage of symmetry
3. Avoid repeating the same calculation