# Problem Statement

Triangle, square, pentagonal, hexagonal, heptagonal, and octagonal numbers are all figurate (polygonal) numbers and are generated by the following formulae:

| Type | Formula | Sequence |
|---|---|---|
| Triangle | $P_{3,n}=n(n+1)/2$ | 1, 3, 6, 10, 15, ... |
| Square | $P_{4,n}=n^2$ | 1, 4, 9, 16, 25, ... |
| Pentagonal | $P_{5,n}=n(3n-1)/2$ | 1, 5, 12, 22, 35, ... |
| Hexagonal | $P_{6,n}=n(2n-1)$ | 1, 6, 15, 28, 45, ... |
| Heptagonal | $P_{7,n}=n(5n-3)/2$ | 1, 7, 18, 34, 55, ... |
| Octagonal | $P_{8,n}=n(3n-2)$ | 1, 8, 21, 40, 65, ... |

The ordered set of three 4-digit numbers: 8128, 2882, 8281, has three interesting properties.

1. The set is cyclic, in that the last two digits of each number are the first two digits of the next number (including the last number with the first).
2. Each polygonal type: triangle ($P_{3,127}=8128$), square ($P_{4,91}=8281$), and pentagonal ($P_{5,44}=2882$), is represented by a different number in the set.
3. This is the only set of 4-digit numbers with this property.

Find the sum of the only ordered set of six cyclic 4-digit numbers for which each polygonal type: triangle, square, pentagonal, hexagonal, heptagonal, and octagonal, is represented by a different number in the set.
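As a quick sanity check of the example, this sketch confirms that 8128, 2882, 8281 link up cyclically and that the quoted indices produce those values. The helper name `ends_match` is my own hypothetical choice, and the three formulas are copied from the table with integer division.

```python
def ends_match(a, b):
    """True if the last two digits of a are the first two digits of b."""
    return str(a)[2:] == str(b)[:2]

def triangle(n):
    return n * (n + 1) // 2

def square(n):
    return n * n

def pentagonal(n):
    return n * (3 * n - 1) // 2

example = [8128, 2882, 8281]

# Each number links to the next; the modulo wraps the last back to the first.
links = [ends_match(example[i], example[(i + 1) % len(example)])
         for i in range(len(example))]
print(links)  # [True, True, True]

# The indices quoted in the problem statement.
print(triangle(127), square(91), pentagonal(44))  # 8128 8281 2882
```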

## Graph Theory?

First, some function definitions; we'll need them to generate numbers to test with:

```python
# Integer division (//) keeps everything exact; each numerator is always even.
def triangle(n):
    return n * (n + 1) // 2

def square(n):
    return n * n

def pentagonal(n):
    return n * (3 * n - 1) // 2

def hexagonal(n):
    return n * (2 * n - 1)

def heptagonal(n):
    return n * (5 * n - 3) // 2

def octagonal(n):
    return n * (3 * n - 2)

funks = [
    ("triangle", triangle),
    ("square", square),
    ("pentagonal", pentagonal),
    ("hexagonal", hexagonal),
    ("heptagonal", heptagonal),
    ("octagonal", octagonal)
]

# Quick check all our functions
for name, func in funks:
    print(name, end=" ")
    for i in range(1, 5):
        print(f"{i}: {func(i)}", end=" ")
    print()
```

```
triangle 1: 1 2: 3 3: 6 4: 10 
square 1: 1 2: 4 3: 9 4: 16 
pentagonal 1: 1 2: 5 3: 12 4: 22 
hexagonal 1: 1 2: 6 3: 15 4: 28 
heptagonal 1: 1 2: 7 3: 18 4: 34 
octagonal 1: 1 2: 8 3: 21 4: 40 
```
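All six formulas are special cases of one general formula for the $n$-th $s$-gonal number, $P_{s,n} = \frac{(s-2)n^2 - (s-4)n}{2}$. The sketch below uses a hypothetical `polygonal` helper (an illustration, not something the rest of this notebook relies on) to reproduce the first few terms of each sequence; the numerator is always even, so `//` is exact.

```python
def polygonal(s, n):
    """n-th s-gonal number, P(s, n) = ((s - 2) * n**2 - (s - 4) * n) / 2."""
    # (s-2)n^2 - (s-4)n is even for every integer s and n,
    # so integer division loses nothing.
    return ((s - 2) * n * n - (s - 4) * n) // 2

for s in range(3, 9):
    print(s, [polygonal(s, n) for n in range(1, 5)])
```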


Then, for each type of number, we construct the set of all four-digit numbers of that type.

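```python
# Collect every 4-digit value of each polygonal type.
sets = {}
for name, func in funks:
    sets[name] = set()
    i = 1
    val = func(i)
    while val <= 9999:  # the sequences are increasing, so stop once past 4 digits
        if val >= 1000:
            sets[name].add(val)
        i += 1
        val = func(i)
```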
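The same sets can be built a little more compactly with `itertools.takewhile`, which is only valid because every formula is increasing. This sketch uses its own hypothetical `FORMULAS` table and `four_digit` helper rather than the `funks` list, and prints how many candidates each type has before any pruning.

```python
from itertools import count, takewhile

# The six formulas from the table, with exact integer division.
FORMULAS = {
    "triangle": lambda n: n * (n + 1) // 2,
    "square": lambda n: n * n,
    "pentagonal": lambda n: n * (3 * n - 1) // 2,
    "hexagonal": lambda n: n * (2 * n - 1),
    "heptagonal": lambda n: n * (5 * n - 3) // 2,
    "octagonal": lambda n: n * (3 * n - 2),
}

def four_digit(formula):
    """Every 4-digit value of an increasing integer sequence."""
    # takewhile stops at the first value above 9999; that is only safe
    # because each formula here is strictly increasing for n >= 1.
    values = takewhile(lambda v: v <= 9999, (formula(n) for n in count(1)))
    return {v for v in values if v >= 1000}

four_digit_sets = {name: four_digit(f) for name, f in FORMULAS.items()}
print({name: len(s) for name, s in four_digit_sets.items()})
# {'triangle': 96, 'square': 68, 'pentagonal': 56, 'hexagonal': 48, 'heptagonal': 43, 'octagonal': 40}
```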
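There will be a lot of connections between these numbers, so we can make the search easier by pruning impossible values. Any number with a 0 in the tens place can't be valid: the next number in the cycle would have to start with 0, which would make it a 3-digit number. It is tempting to also remove numbers that appear in more than one set, but that would be a mistake, because every hexagonal number is also a triangle number ($P_{6,n} = P_{3,2n-1}$). Removing shared numbers would throw away every hexagonal candidate, including 8128.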

```python
# Drop numbers whose tens digit is 0: the next number would have to start
# with 0, so it could not be a 4-digit number.
sets = {name: {n for n in s if str(n)[2] != "0"} for name, s in sets.items()}

# Disabled: removing numbers shared between sets is unsafe, because every
# hexagonal number is also a triangle number (8128 is both).
# import collections
# removals = collections.defaultdict(set)
# for from_name in sets:
#     for to_name in sets:
#         if from_name == to_name:
#             continue
#         for a in sets[from_name] & sets[to_name]:
#             removals[from_name].add(a)
#             removals[to_name].add(a)
#
# for name, s in removals.items():
#     sets[name] -= s
```
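To see why removing shared numbers is unsafe, here is a small check that every 4-digit hexagonal number is also a triangle number, since $n(2n-1) = \frac{(2n-1)(2n)}{2}$. It redefines `triangle` and `hexagonal` so that it runs on its own.

```python
def triangle(n):
    return n * (n + 1) // 2

def hexagonal(n):
    return n * (2 * n - 1)

# P6(n) = P3(2n - 1) for every n.
assert all(hexagonal(n) == triangle(2 * n - 1) for n in range(1, 200))

# The ranges comfortably cover all 4-digit values of each type.
four_digit_hex = {hexagonal(n) for n in range(1, 100) if 1000 <= hexagonal(n) <= 9999}
four_digit_tri = {triangle(n) for n in range(1, 200) if 1000 <= triangle(n) <= 9999}

print(four_digit_hex <= four_digit_tri)                   # True
print(8128 in four_digit_hex and 8128 in four_digit_tri)  # True
```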
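Then we construct a giant graph of all the numbers: there is an edge from `a` to `b` whenever the last two digits of `a` are the first two digits of `b`. We ignore connections within the same type, since each type may appear only once in the cycle.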

```python
from pprint import pprint

names = list(sets.keys())
# edges[type][number][other_type] -> numbers of other_type that `number` links to
edges = {}

def is_cyclic(a, b):
    """True if the last two digits of a are the first two digits of b."""
    return str(a)[2:] == str(b)[:2]

for from_name in names:
    edges[from_name] = {}
    for to_name in names:
        if from_name == to_name:
            continue  # a type never links to itself
        for a in sets[from_name]:
            edges[from_name].setdefault(a, {})
            for b in sets[to_name]:
                if is_cyclic(a, b):
                    edges[from_name][a].setdefault(to_name, []).append(b)
pprint(edges)
```

```
{'heptagonal': {1071: {'hexagonal': [7140], 'triangle': [7140]},
                1177: {'pentagonal': [7740],
                       'square': [7744],
                       'triangle': [7750]},
                1288: {'pentagonal': [8855], 'square': [8836]},
                1525: {'hexagonal': [2556], 'triangle': [2556]},
                1651: {'hexagonal': [5151],
                       'pentagonal': [5192],
                       'square': [5184],
                       'triangle': [5151]},
                1782: {'square': [8281], 'triangle': [8256]},
                1918: {'hexagonal': [1891],
                       'octagonal': [1825],
                       'pentagonal': [1820],
                       'square': [1849],
                       'triangle': [1830, 1891]},
                2059: {'hexagonal': [5995],
                       'octagonal': [5985],
                       'pentagonal': [5922],
                       'square': [5929],
                       'triangle': [5995]}
```
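The nested loops above compare every pair of numbers across every pair of types. One alternative, sketched here with a hypothetical `build_edges` helper that keys nodes by `(type, number)` instead of nested dicts, buckets the numbers by their first two digits so that finding a node's successors is a single dictionary lookup. The tiny `demo` input is made up for illustration.

```python
from collections import defaultdict

def build_edges(sets):
    """Map each (type, number) to the (type, number) pairs it links to."""
    # Bucket every number by its first two digits.
    by_prefix = defaultdict(list)
    for name, values in sets.items():
        for v in values:
            by_prefix[str(v)[:2]].append((name, v))

    # A number's successors are the bucket matching its last two digits,
    # minus anything of the same type.
    edges = {}
    for name, values in sets.items():
        for v in values:
            edges[(name, v)] = [
                (other, w) for other, w in by_prefix[str(v)[2:]] if other != name
            ]
    return edges

# Hypothetical hand-made input, not the full sets.
demo = {"triangle": {8128, 8256}, "square": {5625, 8281}, "pentagonal": {2882}}
demo_edges = build_edges(demo)
print(sorted(demo_edges[("pentagonal", 2882)]))  # [('square', 8281), ('triangle', 8256)]
```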

That is a massive number of edges, and naturally someone has written an efficient algorithm for this, but I'mma just brute force it with a recursive backtracking search. Since we know that one node of each type must be included, we can simply go through every heptagonal node (any type would do) and assume that it is in the final loop. Then, for each of its edges, we make the same assumption about the next node, backing out as soon as we reach a contradiction: a type that is already used, or a six-number chain that doesn't loop back to its start.

```python
names = list(sets.keys())

def find_chain(node, node_type, types_considered=None, chain=None):
    """Backtracking search for a closed chain that uses every type exactly once.

    Returns a list of (number, type) pairs, or None if no valid cycle
    extends the current partial chain.
    """
    # None defaults avoid Python's shared mutable default arguments.
    if types_considered is None:
        types_considered = [node_type]
    if chain is None:
        chain = [(node, node_type)]

    if len(chain) == len(names):
        # A full chain only counts if the last number links back to the first.
        return chain if is_cyclic(chain[-1][0], chain[0][0]) else None

    used_numbers = {value for value, _ in chain}
    for n in names:
        if n in types_considered or n not in edges[node_type][node]:
            continue
        for val in edges[node_type][node][n]:
            if val in used_numbers:
                continue  # each type must be represented by a different number
            c = find_chain(val, n, types_considered + [n], chain + [(val, n)])
            if c is not None:
                return c
    return None

answer = None
for node in edges['heptagonal']:
    answer = find_chain(node, 'heptagonal')
    if answer is not None:
        break

print(answer)
print(sum(value for value, _ in answer))
```

```
[(2512, 'heptagonal'), (1281, 'octagonal'), (8128, 'hexagonal'), (2882, 'pentagonal'), (8256, 'triangle'), (5625, 'square')]
28684
```
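As an independent check on the printed answer, this sketch re-verifies the puzzle's rules directly: the cyclic links, six distinct numbers, every type used once, and each number really belonging to its claimed type. `is_valid_cycle` and `is_member` are hypothetical helpers, not part of the search above.

```python
# The six formulas, keyed by the type names used in the notebook.
FORMULAS = {
    "triangle": lambda n: n * (n + 1) // 2,
    "square": lambda n: n * n,
    "pentagonal": lambda n: n * (3 * n - 1) // 2,
    "hexagonal": lambda n: n * (2 * n - 1),
    "heptagonal": lambda n: n * (5 * n - 3) // 2,
    "octagonal": lambda n: n * (3 * n - 2),
}

def is_member(value, formula):
    """True if value appears in the increasing sequence formula(1), formula(2), ..."""
    n = 1
    while formula(n) < value:
        n += 1
    return formula(n) == value

def is_valid_cycle(chain):
    """Check a list of (number, type) pairs against the puzzle's rules."""
    numbers = [v for v, _ in chain]
    types = [t for _, t in chain]
    closed = all(
        str(numbers[i])[2:] == str(numbers[(i + 1) % len(numbers)])[:2]
        for i in range(len(numbers))
    )
    distinct = len(set(numbers)) == len(numbers)
    every_type_once = sorted(types) == sorted(FORMULAS)
    members = all(is_member(v, FORMULAS[t]) for v, t in chain)
    return closed and distinct and every_type_once and members

answer = [(2512, 'heptagonal'), (1281, 'octagonal'), (8128, 'hexagonal'),
          (2882, 'pentagonal'), (8256, 'triangle'), (5625, 'square')]
print(is_valid_cycle(answer), sum(v for v, _ in answer))  # True 28684
```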
