In [1]:
from functools import lru_cache
import numpy as np
import copy
import itertools as it
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

In [2]:
@lru_cache(maxsize=None)
def factorial(n):
    return 1 if n == 0 else n * factorial(n - 1)

@lru_cache(maxsize=None)
def choose(n, k):
    if n < 0:
        return 1
    if k < 0 or k > n:
        return 0
    return factorial(n) / (factorial(n - k) * factorial(k))

@lru_cache(maxsize=None)
def layer_size(m, h):
    """
    returns the size of the h layer of the ingeter lattice in dimension m
    """
    return sum([choose(m, i) * choose(h - 1, i - 1) * 2 ** i for i in range(h + 1)])

@lru_cache(maxsize=None)
def layer_sum(m, n):
    return sum([layer_size(m, h) for h in range(n + 1)])

@lru_cache(maxsize=None)
def find_upper_bound(m):
    """
    finds the last time before the binary (exponential) growth overtakes the spacial (polynomial) expansion
    """
    n = 0
    while layer_sum(m, n) >= 2**(n+1) - 1:
        n += 1
    return n - 1

In [3]:
for m in range(2, 10):
    print(m, find_upper_bound(m))

2 4
3 9
4 13
5 18
6 22
7 27
8 31
9 36


In [4]:
@lru_cache(maxsize=None)
def gen_rot_matrices(m):
    """
    generate a list of the rotation (by pi/2 in each of the standard orthogonal directions) matrices in m dimensional space
    """
    mats = []
    for i in range(m - 1):
        M = np.identity(m, dtype=int)
        M[i][i] = 0          # cos(pi/2)
        M[i][i + 1] = -1     # -sin(pi/2)
        M[i + 1][i] = 1      # sin(pi/2)
        M[i + 1][i + 1] = 0  # cos(pi/2)
        mats.append(M)
    return mats

@lru_cache(maxsize=None)
def gen_group(m):
    """
    generate a list of the group elements generated by `gen_rot_matrices(m)`
    """
    mats = gen_rot_matrices(m)
    S = np.identity(m, dtype=int)
    S[0][0] = -1
    group = []
    for tup in it.product(range(4), repeat = m - 1):
        M = np.identity(m, dtype=int)
        for i, num in enumerate(tup):
            for _ in range(num):
                M = M @ mats[i]
        group.append(M)
        group.append(M @ S)
    return group

In [5]:
gen_rot_matrices(2)

[array([[ 0, -1],
        [ 1,  0]])]

In [6]:
gen_group(2)

[array([[1, 0],
        [0, 1]]),
 array([[-1,  0],
        [ 0,  1]]),
 array([[ 0, -1],
        [ 1,  0]]),
 array([[ 0, -1],
        [-1,  0]]),
 array([[-1,  0],
        [ 0, -1]]),
 array([[ 1,  0],
        [ 0, -1]]),
 array([[ 0,  1],
        [-1,  0]]),
 array([[0, 1],
        [1, 0]])]

In [7]:
gen_rot_matrices(3)

[array([[ 0, -1,  0],
        [ 1,  0,  0],
        [ 0,  0,  1]]),
 array([[ 1,  0,  0],
        [ 0,  0, -1],
        [ 0,  1,  0]])]

In [8]:
def check_np_contain(L, arr):
    for arr2 in L:
        if (arr == arr2).all():
            return True
    return False

In [9]:
EMPTY = -1

class State:
    def __init__(self, m, size = None):
        if size == None:
            size = find_upper_bound(m) * 2 + 1
        self.m = m # dimension
        if size % 2 == 0:
            size += 1
        self.size = size
        self.c = int((self.size - 1) / 2)
        self.grid = EMPTY * np.ones((size,) * m, dtype = int)
        self.leaves = []
        self.n = -1 # size of the tree

    def set_center(self):
        self.n = 0
        self.grid[tuple([self.c] * self.m)] = self.n
        self.leaves = [[self.c] * self.m]

    def copy(self):
        cpy = State(self.m, self.size)
        cpy.grid = copy.deepcopy(self.grid)
        cpy.n = self.n
        cpy.c = self.c
        return cpy

    def get_sprouts_lists(self):
        sprouts_lists = [[] for _ in self.leaves]
        for l, leaf in enumerate(self.leaves):
            for i in range(self.m):
                for sign in [1, -1]:
                    indx = leaf.copy()
                    if indx[i] + sign < self.size and indx[i] + sign >= 0:
                        indx[i] += sign
                        if self.grid[tuple(indx)] == EMPTY:
                            sprouts_lists[l].append(indx)
        return sprouts_lists

    def disp(self):
        if self.m == 2:
            for row in self.grid:
                print(' '.join(' ' if el == -1 else str(el) for el in row))
        elif self.m == 3:
            cods = [[], [], []]
            labels = []
            for indx in it.product(range(self.size), repeat = 3):
                if self.grid[indx] != -1:
                    for i in range(3):
                        cods[i].append(indx[i])
                    labels.append(self.grid[indx])
            fig = plt.figure()
            ax = fig.add_subplot(111, projection = '3d')
            ax.plot(*cods, ',')
            for x,y,z,i in zip(*cods,labels):
                ax.text(x,y,z,i)
            plt.show()
        else:
            print(self.grid)

    def get_children(self):
        children = []
        grids = []
        sprouts_lists = self.get_sprouts_lists()
        for sprouts in sprouts_lists:
            if len(sprouts) < 2:
                return children
        pairs_lists = [it.combinations(sprouts, 2) for sprouts in sprouts_lists]
        for arr in it.product(*pairs_lists):
            bad = False
            c = self.copy()
            c.n += 1
            for pair, leaf in zip(arr, self.leaves):
                for indx in pair:
                    if c.grid[tuple(indx)] != EMPTY:
                        bad = True
                        break
                    c.grid[tuple(indx)] = c.n
                    c.leaves.append(indx)
                if bad:
                    break
            else:
                # check for symmetric duplicates
                for R in gen_group(self.m):
                    if check_np_contain(grids, c.rotate(R)):
                        break
                else:
                    grids.append(c.grid)
                    children.append(c)
        return children
    
    def get_spread(self):
        s = 0
        for indx in it.product(range(self.size), repeat = self.m):
            l = sum([abs(self.c - el) for el in indx])
            s += l * (self.grid[tuple(indx)] + 1)
        return s
    
    def rotate(self, R):
        M = np.zeros(self.grid.shape, dtype=int)
        for indx in it.product(range(self.size), repeat = self.m):
            indx1 = np.array(indx) - (np.ones(self.m) * self.c)
            indx2 = (R @ indx1) + (np.ones(self.m) * self.c)
            indx3 = tuple(round(el) for el in indx2.data)
            M[indx3] = self.grid[indx]
        return M

In [10]:
def find_best_of_children(state):
    best = state.copy()
    stack = [state]
    group = gen_group(state.m)
    while len(stack) > 0:
        state = stack.pop()
        children = state.get_children()
        if len(children) == 0:
            if state.n > best.n:
                best = state
                if best.n == state.c:
                    return best
        else:
            children.sort(key=lambda x: -x.get_spread())
            stack.extend(children)
    return best

def find_best(m, size=None):
    init = State(m, size)
    init.set_center()
    return find_best_of_children(init)

In [11]:
find_best(1).disp()

[1 0 1]


In [12]:
find_best(2).disp()

                 
      4   4      
    4 3 4 3 4    
  4 3 2 1 2 3 4  
    4   0   4    
  4 3 2 1 2 3 4  
    4 3 4 3 4    
      4   4      
                 
