A magic square is a square array of distinct natural numbers, where each row, each column, and both long diagonals sum to the same “magic number.”

A prime magic square is a magic square consisting of only prime numbers.

Is it possible to construct a 4-by-4 prime magic square with a magic number of 2026? If so, give an example; if not, why not?
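The definition can be pinned down as a small checker. This is a sketch with illustrative names (not from the notebook), using plain trial division for primality, which is adequate for numbers this small. It should accept the square that Z3 finds further down.

```python
def is_prime(n: int) -> bool:
    """Trial division; adequate for numbers of this size."""
    if n < 2:
        return False
    i = 2
    while i * i <= n:
        if n % i == 0:
            return False
        i += 1
    return True


def is_prime_magic_square(grid: list[list[int]], magic: int) -> bool:
    """Check the definition for an n-by-n grid: distinct primes, and every row,
    every column and both long diagonals summing to `magic`."""
    n = len(grid)
    if any(len(row) != n for row in grid):
        return False
    cells = [x for row in grid for x in row]
    if len(set(cells)) != len(cells) or not all(is_prime(x) for x in cells):
        return False
    lines = [list(row) for row in grid]                          # rows
    lines += [[grid[i][j] for i in range(n)] for j in range(n)]  # columns
    lines.append([grid[i][i] for i in range(n)])                 # main diagonal
    lines.append([grid[i][n - 1 - i] for i in range(n)])         # anti-diagonal
    return all(sum(line) == magic for line in lines)
```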

In [47]:
import numpy as np
from itertools import combinations
from tqdm import tqdm

In [21]:
def prime_number_generator(n):
    """Generate prime numbers up to n."""
    for num in range(2, n + 1):
        is_prime = True
        for i in range(2, int(num**0.5) + 1):
            if num % i == 0:
                is_prime = False
                break
        if is_prime:
            yield num


def factorial(n: int) -> int:
    if n in (0, 1):
        return 1
    return n * factorial(n - 1)

In [98]:
primes = set(prime_number_generator(2026)) - {2} # 2 cannot belong to the solution, for parity reasons
nb_primes_under_2026 = len(primes)
print(f"{nb_primes_under_2026=}")

# integer division keeps the computation exact (no detour through a float)
nb_comb_four_distinct_primes = factorial(nb_primes_under_2026) // (
    factorial(4) * factorial(nb_primes_under_2026 - 4)
)
print(f"{nb_comb_four_distinct_primes=:_}")

nb_primes_under_2026=305
nb_comb_four_distinct_primes=353_518_180
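As a cross-check, the standard library's `math.comb` (Python 3.8+) gives the same count directly, and the parity claim behind dropping 2 can be checked exhaustively on a handful of odd primes. The helper names below are illustrative, not from the notebook.

```python
import math
from itertools import combinations


def n_choose_k(n: int, k: int) -> int:
    """Exact binomial coefficient via math.comb, integers only."""
    return math.comb(n, k)


def every_line_with_two_is_odd(odd_primes: list[int]) -> bool:
    """Parity argument: 2 + odd + odd + odd is odd, so a line containing 2
    can never reach the even magic number 2026."""
    return all((2 + sum(trio)) % 2 == 1 for trio in combinations(odd_primes, 3))


print(n_choose_k(305, 4))                                   # 353518180
print(every_line_with_two_is_odd([3, 5, 7, 11, 13, 2011]))  # True
```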


In [29]:
for _ in range(nb_comb_four_distinct_primes):
    ...

The loop above, which does nothing 353 million times, already takes 7.2 seconds to run, so I'd better prune!
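One way to prune, sketched under my own assumptions (the helper names are hypothetical): enumerate only the three smallest primes of each quadruple and look the fourth one up in a set, stopping each loop as soon as the remaining values can no longer fit.

```python
def odd_primes_up_to(n: int) -> list[int]:
    """Sieve of Eratosthenes; returns the odd primes <= n (2 is left out)."""
    sieve = bytearray([1]) * (n + 1)
    sieve[0:2] = b"\x00\x00"
    for i in range(2, int(n**0.5) + 1):
        if sieve[i]:
            # cross out i*i, i*i + i, ... up to n
            sieve[i * i :: i] = bytearray(len(range(i * i, n + 1, i)))
    return [p for p in range(3, n + 1, 2) if sieve[p]]


def four_prime_sums(target: int, primes: list[int]) -> list[tuple[int, int, int, int]]:
    """All quadruples a < b < c < d from `primes` with a + b + c + d == target.

    Only a, b, c are enumerated; d = target - a - b - c is looked up in a set.
    """
    ps = sorted(primes)
    pset = set(ps)
    result = []
    for i, a in enumerate(ps):
        if 4 * a >= target:          # b, c, d > a would overshoot
            break
        for j in range(i + 1, len(ps)):
            b = ps[j]
            if a + 3 * b >= target:  # c, d > b would overshoot
                break
            for k in range(j + 1, len(ps)):
                c = ps[k]
                d = target - a - b - c
                if d <= c:           # d must be the largest; a larger c only shrinks d
                    break
                if d in pset:
                    result.append((a, b, c, d))
    return result
```

With the notebook's names, `four_prime_sums(2026, sorted(primes))` should return the same 90175 quadruples as the brute-force loop, while only walking through triples.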

In [94]:
combs=[]
for comb in tqdm(combinations(primes, 4), total=nb_comb_four_distinct_primes):
    if sum(comb) == 2026:
        combs.append(comb)
        
print(len(combs))

100%|██████████| 353518180/353518180 [00:54<00:00, 6529920.36it/s]

90175





In [95]:
print(
    f"Number of combinations of four distinct primes that sum to 2026: {len(combs)} ({100 * len(combs) / nb_comb_four_distinct_primes:.4f}% of all combinations of four primes below 2026)"
)

Number of combinations of four distinct primes that sum to 2026: 90175 (0.0255% of all combinations of four primes below 2026)


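If a magic square exists, each of its entries lies on at least 2 lines (its row and its column), and on 3 lines if it sits on a long diagonal. Each of those lines is a combination of four primes summing to 2026, so every entry must belong to at least 2 such combinations (at least 3 for diagonal entries).

Let's discard the primes in `primes` that do not satisfy the weaker condition (at least 2 combinations), if any.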

In [102]:
def passes(nb:int)->bool:
    k=0
    for comb in combs:
        if nb in comb:
            k+=1
            if k ==2: return True
    return False

f"Out of {len(primes)} primes under 2026 and != 2, {sum(map(passes, primes))} belong to at least two combinations of four primes summing up to 2026"

'Out of 305 primes under 2026 and != 2, 303 belong to at least two combinations of four primes summing up to 2026'
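`passes` rescans the whole list of combinations for every prime. A single pass with `collections.Counter` gives all the counts at once; this is a sketch with illustrative names, shown on a toy list whose sums are not 2026.

```python
from collections import Counter


def occurrence_counts(combos):
    """For every number, how many of the quadruples in `combos` contain it."""
    return Counter(n for combo in combos for n in combo)


def survivors(numbers, combos, min_lines=2):
    """Numbers appearing in at least `min_lines` quadruples.

    Every cell of a 4x4 magic square lies on its row and its column (2 lines);
    a cell on a long diagonal lies on 3 lines.
    """
    counts = occurrence_counts(combos)  # one pass over combos, not one per number
    return {n for n in numbers if counts[n] >= min_lines}


toy_combos = [(3, 5, 7, 11), (3, 13, 17, 19), (5, 13, 23, 29)]
print(sorted(survivors({3, 5, 7, 11, 13, 17, 19, 23, 29}, toy_combos)))  # [3, 5, 13]
```

Applied to the notebook's data, `survivors(primes, combs)` uses the same criterion as `passes`, so it should contain the same 303 primes.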

Below are some sanity checks.

In [159]:
# the 4 rows of a magic square are 4 pairwise-disjoint combinations, so at least 4 must exist among the 90175
comb_=set(combs[0])
disjoint_combs=[comb_]
n=1
assert len(comb_)==4
for comb in combs:
    if all(not set(comb) & other for other in disjoint_combs):
        n += 1
        disjoint_combs.append(set(comb))
        if n ==4:
            print(disjoint_combs)
            break

[{3, 5, 2011, 7}, {11, 13, 971, 1031}, {1033, 953, 17, 23}, {937, 19, 31, 1039}]


The above test passes; now let's try to fit 4 rows AND one column.


In [171]:
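ad, bd, cd, dd = disjoint_combs  # the four pairwise-disjoint rows found above
t = {frozenset(comb) for comb in combs}  # set of frozensets: O(1) membership tests
# Pick one prime from each row; if the four picks also sum to 2026, they can form a column.
for a in ad:
    for b in bd:
        for c in cd:
            for d in dd:
                if frozenset((a, b, c, d)) in t:
                    print(a, b, c, d)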

3 971 1033 19
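To see what this result means, the four rows can be laid out so that the primes just found form the first column. This is a sketch (`place_column` is an illustrative name); only the rows and that column are constrained at this point, the other columns and the diagonals are not.

```python
def place_column(rows, picks):
    """Lay out the rows so that picks[i] (an element of rows[i]) sits in column 0.

    The other three entries of each row are placed in increasing order; their
    column positions are arbitrary at this stage.
    """
    grid = []
    for row, pick in zip(rows, picks):
        if pick not in row:
            raise ValueError(f"{pick} is not in row {sorted(row)}")
        grid.append([pick] + sorted(set(row) - {pick}))
    return grid


rows = [{3, 5, 2011, 7}, {11, 13, 971, 1031}, {1033, 953, 17, 23}, {937, 19, 31, 1039}]
grid = place_column(rows, [3, 971, 1033, 19])
for line in grid:
    print(*line)
# 3 5 7 2011
# 971 11 13 1031
# 1033 17 23 953
# 19 31 937 1039
```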


In [104]:
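def f():
    """Look for a corner prime and three pairwise-disjoint triples completing the
    three lines through that corner (top row, leftmost column, main diagonal).
    This is only a necessary condition: "possible" does not mean a full square exists."""
    candidates = primes

    for corner in candidates:
        look_in = candidates - {corner}
        for i in range(3):  # one triple per line through the corner
            for others in combinations(look_in, 3):
                if corner + sum(others) == 2026:
                    print(sorted({corner} | set(others)))
                    look_in -= set(others)  # the three lines only share the corner
                    if i == 2:
                        return "possible"
                    break
            else:
                break  # no completing triple left for this corner: try the next one
    return "impossible"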


f()

[3, 5, 7, 2011]
[3, 73, 919, 1031]
[3, 13, 977, 1033]


'possible'

In [161]:
def is_on_long_diagonal(row, col) -> int:
    "0 if not on any long diagonal, 1 on topleft-bottomright diagonal, 2 the other diagonal"
    if row == col:
        return 1
    if row + col == 3:
        return 2
    return 0


all_candidates = set(e for e in primes if passes(e))
nb_pos_seen = 0


def is_invalid(pos: np.ndarray) -> bool:
    for r in range(4):
        if np.count_nonzero(pos[r]) == 4 and pos[r].sum() != 2026:
            return True
    for c in range(4):
        if np.count_nonzero(pos[:, c]) == 4 and pos[:, c].sum() != 2026:
            return True
    if np.count_nonzero(np.diagonal(pos)) == 4 and np.trace(pos) != 2026:
        return True
    if (
        np.count_nonzero(np.diagonal(np.fliplr(pos))) == 4
        and np.trace(np.fliplr(pos)) != 2026
    ):
        return True
    return False


def recurse(pos: np.ndarray) -> None:
    global nb_pos_seen
    nb_pos_seen += 1
    
    if is_invalid(pos):
        return
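    if np.count_nonzero(pos) == 16:
        print('solution!')
        for row in pos:
            print(*row)
        exit()
    # Debug cut-off: print partial squares with 13+ filled cells and stop descending.
    # As long as this block is active, the 16-cell branch above is never reached.
    if np.count_nonzero(pos) >= 13:
        for row in pos:
            print(*row)
        print("\n")
        return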
    for r in range(4):
        for c in range(4):
            if pos[r, c] == 0:  # fill this cell
                row_sum, col_sum = pos[r, :].sum(), pos[:, c].sum()
                is_on_diag = is_on_long_diagonal(r, c)
                diag_sum = np.trace(pos) if is_on_diag == 1 else (np.trace(np.fliplr(pos)) if is_on_diag == 2 else 0)
                sorted_remaining_candidates = sorted(all_candidates - set(pos.flatten()))

                remaining_row = 4 - np.count_nonzero(pos[r])
                remaining_col = 4 - np.count_nonzero(pos[:, c])
                if is_on_diag:
                    if is_on_diag == 1:
                        remaining_diag = 4 - np.count_nonzero(np.diagonal(pos))
                    else:
                        remaining_diag = 4 - np.count_nonzero(np.diagonal(np.fliplr(pos)))

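                # Row pruning: the remaining_row empty cells of this row (this one included)
                # receive distinct remaining candidates, so the final row sum lies between
                # row_sum + (sum of the remaining_row smallest) and row_sum + (sum of the largest).
                # remaining_row >= 1 because this cell is empty, so the [-k:] slices are well defined.
                if (row_sum + sum(sorted_remaining_candidates[:remaining_row]) > 2026) or \
                (row_sum + sum(sorted_remaining_candidates[-remaining_row:]) < 2026):
                    return

                # Column pruning (same bounds)
                if (col_sum + sum(sorted_remaining_candidates[:remaining_col]) > 2026) or \
                (col_sum + sum(sorted_remaining_candidates[-remaining_col:]) < 2026):
                    return

                # Diagonal pruning (same bounds)
                if is_on_diag:
                    if (diag_sum + sum(sorted_remaining_candidates[:remaining_diag]) > 2026) or \
                    (diag_sum + sum(sorted_remaining_candidates[-remaining_diag:]) < 2026):
                        return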
                ################
                last_one_on_line = np.count_nonzero(pos[r]) == 3
                if last_one_on_line:
                    if 2026 - row_sum not in all_candidates - set(pos.flatten()): return
                last_one_on_column = np.count_nonzero(pos[:, c]) == 3
                if last_one_on_column:
                    if 2026 - col_sum not in all_candidates - set(pos.flatten()): return
                if is_on_diag:
                    last_one_on_diag = (np.count_nonzero(np.diagonal(pos)) == 3) if is_on_diag == 1 else (np.count_nonzero(np.diagonal(np.fliplr(pos))) == 3)
                    if last_one_on_diag:
                        if 2026 - diag_sum not in all_candidates - set(pos.flatten()): return
                ################
                candidates = [
                    nb
                    for nb in all_candidates - set(pos.flatten())
                    if nb <= min(2026 - row_sum, 2026 - col_sum, 2026 - diag_sum)
                ]
                if not candidates:
                    return
                for cand in candidates:
                    new = pos.copy()
                    new[r, c] = cand
                    recurse(new)
                return


recurse(np.zeros((4, 4), dtype=int))

KeyboardInterrupt: 
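The pruning in `recurse` rests on a simple interval bound. Isolated for a single line, it could look like this (`line_feasible` is a hypothetical name). Note that `empty_cells` counts every still-empty cell of the line, including the one about to be filled.

```python
def line_feasible(partial_sum: int, empty_cells: int, pool, target: int = 2026) -> bool:
    """Necessary condition for completing one line (row, column or diagonal).

    The line's empty cells will receive distinct values from `pool`, so its final
    sum lies between partial_sum plus the `empty_cells` smallest values and
    partial_sum plus the `empty_cells` largest ones. Passing this test does not
    guarantee that a completion exists.
    """
    if empty_cells == 0:
        return partial_sum == target
    values = sorted(pool)
    if len(values) < empty_cells:
        return False
    lowest = partial_sum + sum(values[:empty_cells])
    highest = partial_sum + sum(values[-empty_cells:])
    return lowest <= target <= highest
```

For example, a row already summing to 2011 with one empty cell and only {3, 5, 7} left can end anywhere from 2014 to 2018, so that branch can be cut immediately.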

In [174]:
! uv pip install z3-solver


In [175]:
from z3 import *
import time

def is_prime(n):
    """Check if a number is prime."""
    if n < 2:
        return False
    if n == 2:
        return True
    if n % 2 == 0:
        return False
    for i in range(3, int(n**0.5) + 1, 2):
        if n % i == 0:
            return False
    return True

def get_primes_under(n):
    """Get all primes under n, excluding 2."""
    return [p for p in range(3, n + 1, 2) if is_prime(p)]

def solve_prime_magic_square():
    """Solve the 4x4 prime magic square with magic number 2026 using Z3."""
    
    print("Generating primes...")
    primes = get_primes_under(2026)
    print(f"Found {len(primes)} odd primes under 2026")
    
    # Create Z3 solver
    s = Solver()
    
    # Create 16 integer variables for the magic square
    square = [[Int(f"cell_{i}_{j}") for j in range(4)] for i in range(4)]
    
    print("Adding constraints...")
    
    # Constraint 1: Each cell must be one of the valid primes
    for i in range(4):
        for j in range(4):
            s.add(Or([square[i][j] == p for p in primes]))
    
    # Constraint 2: All cells must be distinct
    all_cells = [square[i][j] for i in range(4) for j in range(4)]
    s.add(Distinct(all_cells))
    
    # Constraint 3: All rows sum to 2026
    for i in range(4):
        s.add(Sum([square[i][j] for j in range(4)]) == 2026)
    
    # Constraint 4: All columns sum to 2026
    for j in range(4):
        s.add(Sum([square[i][j] for i in range(4)]) == 2026)
    
    # Constraint 5: Main diagonal (top-left to bottom-right) sums to 2026
    s.add(Sum([square[i][i] for i in range(4)]) == 2026)
    
    # Constraint 6: Anti-diagonal (top-right to bottom-left) sums to 2026
    s.add(Sum([square[i][3-i] for i in range(4)]) == 2026)
    
    print("Solving...")
    start_time = time.time()
    
    result = s.check()
    
    elapsed_time = time.time() - start_time
    
    if result == sat:
        print(f"\n✓ Solution found in {elapsed_time:.2f} seconds!\n")
        m = s.model()
        
        # Extract and display the solution
        solution = [[m.evaluate(square[i][j]).as_long() for j in range(4)] for i in range(4)]
        
        print("Prime Magic Square with magic number 2026:")
        print("=" * 40)
        for row in solution:
            print(" ".join(f"{num:4d}" for num in row))
        print("=" * 40)
        
        # Verify the solution
        print("\nVerification:")
        print("-" * 40)
        for i, row in enumerate(solution):
            print(f"Row {i}: {' + '.join(map(str, row))} = {sum(row)}")
        
        print()
        for j in range(4):
            col = [solution[i][j] for i in range(4)]
            print(f"Col {j}: {' + '.join(map(str, col))} = {sum(col)}")
        
        diag1 = [solution[i][i] for i in range(4)]
        print(f"\nDiagonal 1: {' + '.join(map(str, diag1))} = {sum(diag1)}")
        
        diag2 = [solution[i][3-i] for i in range(4)]
        print(f"Diagonal 2: {' + '.join(map(str, diag2))} = {sum(diag2)}")
        
        # Verify all are prime
        all_nums = [num for row in solution for num in row]
        print(f"\nAll numbers are prime: {all(is_prime(n) for n in all_nums)}")
        print(f"All numbers are distinct: {len(all_nums) == len(set(all_nums))}")
        
        return solution
    
    elif result == unsat:
        print(f"\n✗ No solution exists (proven in {elapsed_time:.2f} seconds)")
        print("The problem is mathematically impossible.")
        return None
    
    else:
        print(f"\n? Unknown result after {elapsed_time:.2f} seconds")
        print("Z3 could not determine satisfiability.")
        return None

if __name__ == "__main__":
    solve_prime_magic_square()

Generating primes...
Found 305 odd primes under 2026
Adding constraints...
Solving...

✓ Solution found in 2.47 seconds!

Prime Magic Square with magic number 2026:
 383  677  857  109
  11 1321  401  293
 139   23  281 1583
1493    5  487   41

Verification:
----------------------------------------
Row 0: 383 + 677 + 857 + 109 = 2026
Row 1: 11 + 1321 + 401 + 293 = 2026
Row 2: 139 + 23 + 281 + 1583 = 2026
Row 3: 1493 + 5 + 487 + 41 = 2026

Col 0: 383 + 11 + 139 + 1493 = 2026
Col 1: 677 + 1321 + 23 + 5 = 2026
Col 2: 857 + 401 + 281 + 487 = 2026
Col 3: 109 + 293 + 1583 + 41 = 2026

Diagonal 1: 383 + 1321 + 281 + 41 = 2026
Diagonal 2: 109 + 401 + 23 + 1493 = 2026

All numbers are prime: True
All numbers are distinct: True


In [None]:
sol = np.array([
    [383,  677,  857,  109],
    [ 11, 1321,  401,  293],
    [139,   23,  281, 1583],
    [1493,   5,  487,   41],
])

assert all(nb in primes for nb in sol.flatten())
assert len(set(sol.flatten())) == 16  # all entries distinct
assert all(sum(row) == 2026 for row in sol)
assert all(sol[:, i].sum() == 2026 for i in range(4))
assert np.trace(sol) == 2026
assert np.trace(np.fliplr(sol)) == 2026
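Rotations and reflections map a magic square to another magic square (rows and columns are exchanged or reversed, the two diagonals are exchanged or reversed). With 16 distinct entries, the 8 images are all different grids, so an enumeration that blocks one exact grid at a time will also report the 7 symmetric copies of every square it finds. A sketch for collapsing them to one canonical form (function names are illustrative):

```python
def rotate(grid):
    """Rotate a square grid 90 degrees clockwise."""
    return [list(row) for row in zip(*grid[::-1])]


def reflect(grid):
    """Mirror a square grid left to right."""
    return [row[::-1] for row in grid]


def symmetries(grid):
    """The 8 images of `grid` under rotations and reflections (dihedral group D4)."""
    images = []
    g = [list(row) for row in grid]
    for _ in range(4):
        images.append(g)
        images.append(reflect(g))
        g = rotate(g)
    return images


def canonical(grid):
    """Lexicographically smallest symmetric image, flattened into a hashable tuple."""
    return min(tuple(x for row in img for x in row) for img in symmetries(grid))
```

Collecting `canonical(solution)` in a set would count essentially different squares. Alternatively, and untested here, symmetry could be broken inside Z3 by requiring the top-left corner to be smaller than the other three corners and the top-right corner to be smaller than the bottom-left one, which should leave one representative per class.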


In [200]:
from z3 import *
import time

def is_prime(n):
    """Check if a number is prime."""
    if n < 2:
        return False
    if n == 2:
        return True
    if n % 2 == 0:
        return False
    for i in range(3, int(n**0.5) + 1, 2):
        if n % i == 0:
            return False
    return True

def get_primes_under(n):
    """Get all primes under n, excluding 2."""
    return [p for p in range(3, n + 1, 2) if is_prime(p)]

def solve_all_prime_magic_squares():
    """Find ALL 4x4 prime magic squares with magic number 2026 using Z3."""
    
    print("Generating primes...")
    primes = get_primes_under(2026)
    print(f"Found {len(primes)} odd primes under 2026")
    
    # Create 16 integer variables for the magic square
    square = [[Int(f"cell_{i}_{j}") for j in range(4)] for i in range(4)]
    all_cells = [square[i][j] for i in range(4) for j in range(4)]
    
    # Create solver with base constraints
    s = Solver()
    
    print("Adding constraints...")
    
    # Constraint 1: Each cell must be one of the valid primes
    for i in range(4):
        for j in range(4):
            s.add(Or([square[i][j] == p for p in primes]))
    
    # Constraint 2: All cells must be distinct
    s.add(Distinct(all_cells))
    
    # Constraint 3: All rows sum to 2026
    for i in range(4):
        s.add(Sum([square[i][j] for j in range(4)]) == 2026)
    
    # Constraint 4: All columns sum to 2026
    for j in range(4):
        s.add(Sum([square[i][j] for i in range(4)]) == 2026)
    
    # Constraint 5: Main diagonal (top-left to bottom-right) sums to 2026
    s.add(Sum([square[i][i] for i in range(4)]) == 2026)
    
    # Constraint 6: Anti-diagonal (top-right to bottom-left) sums to 2026
    s.add(Sum([square[i][3-i] for i in range(4)]) == 2026)
    
    print("Finding all solutions...\n")
    start_time = time.time()
    
    solutions = []
    solution_count = 0
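    largest_seen = 0  # largest prime appearing in any solution so far
    while True:
        result = s.check()

        if result == sat:
            solution_count += 1
            m = s.model()

            # Extract the solution
            solution = [[m.evaluate(square[i][j]).as_long() for j in range(4)] for i in range(4)]
            solutions.append(solution)
            largest = max(max(row) for row in solution)  # separate name, so the model `m` is not overwritten
            if largest > largest_seen:
                print(f"New max prime found in solutions: {largest}")
                largest_seen = largest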
            
            print(f"Solution #{solution_count}:")
            print("=" * 40)
            for row in solution:
                print(" ".join(f"{num:4d}" for num in row))
            print("=" * 40)
            print()
            
            # Add constraint to exclude this solution
            # Create a constraint that at least one cell must be different
            block_constraint = Or([
                square[i][j] != solution[i][j]
                for i in range(4)
                for j in range(4)
            ])
            s.add(block_constraint)
            
        elif result == unsat:
            elapsed_time = time.time() - start_time
            print(f"\n{'='*40}")
            print(f"Found {solution_count} total solution(s) in {elapsed_time:.2f} seconds")
            print(f"{'='*40}\n")
            
            if solution_count > 0:
                # Verify one solution as example
                print("Verification of first solution:")
                print("-" * 40)
                sol = solutions[0]
                for i, row in enumerate(sol):
                    print(f"Row {i}: {' + '.join(map(str, row))} = {sum(row)}")
                
                print()
                for j in range(4):
                    col = [sol[i][j] for i in range(4)]
                    print(f"Col {j}: {' + '.join(map(str, col))} = {sum(col)}")
                
                diag1 = [sol[i][i] for i in range(4)]
                print(f"\nDiagonal 1: {' + '.join(map(str, diag1))} = {sum(diag1)}")
                
                diag2 = [sol[i][3-i] for i in range(4)]
                print(f"Diagonal 2: {' + '.join(map(str, diag2))} = {sum(diag2)}")
                
                all_nums = [num for row in sol for num in row]
                print(f"\nAll numbers are prime: {all(is_prime(n) for n in all_nums)}")
                print(f"All numbers are distinct: {len(all_nums) == len(set(all_nums))}")
            else:
                print("The problem is mathematically impossible.")
            
            return solutions
        
        else:
            print(f"\n? Unknown result after {time.time() - start_time:.2f} seconds")
            print("Z3 could not determine satisfiability.")
            return solutions

if __name__ == "__main__":
    all_solutions = solve_all_prime_magic_squares()

Generating primes...
Found 305 odd primes under 2026
Adding constraints...
Finding all solutions...

New max prime found in solutions: 1583
Solution #1:
 367  167   53 1439
1259  269  431   67
 251    7 1319  449
 149 1583  223   71

Solution #2:
 461  733  191  641
 719  281  929   97
 743  353  463  467
 103  659  443  821

Solution #3:
 443   31  953  599
 719  659  577   71
 757  743   47  479
 107  593  449  877

Solution #4:
 239  101 1109  577
 911  733  311   71
 157  419  563  887
 719  773   43  491

Solution #5:
 239  191 1109  487
 911  733  311   71
  67  419  563  977
 809  683   43  491

New max prime found in solutions: 1907
Solution #6:
  23   53 1193  757
  13 1907   47   59
 773    5   67 1181
1217   61  719   29

Solution #7:
  23   17 1229  757
  13 1907   47   59
 773    5   67 1181
1217   97  683   29

New max prime found in solutions: 1931
Solution #8:
  11  443   73 1499
1489   71    5  461
  23   19 1931   53
 503 1493   17   13

Solution #9:
  11  233  283 14