In [1]:
from itertools import product

In [2]:
def neighbours(point):
    """Yield every point adjacent to `point`, diagonals included.

    `point` is a sequence of d numeric coordinates. Each of the 3**d - 1
    non-zero offsets drawn from {-1, 0, 1} is added to it, in
    `itertools.product` order, so `point` itself is never yielded.
    """
    steps = (-1, 0, 1)
    for offset in product(steps, repeat=len(point)):
        if not any(offset):
            continue  # the all-zero offset would give back `point` itself
        yield tuple(coord + step for coord, step in zip(point, offset))

In [3]:
def active_neighbours(point, state):
    """Return the number of active neighbours of `point`.

    `state` maps coordinates to 1 for active cells. Coordinates missing from
    `state` contribute 0, so only explicitly stored cells are counted.
    """
    total = 0
    for adjacent in neighbours(point):
        total += state.get(adjacent, 0)
    return total

In [4]:
def _run(state, grid_boundaries):
    """Advance `state` by a single cycle.

    `grid_boundaries` holds one half-open ``(low, high)`` range per axis. The
    ranges are widened by one cell on each side, and every cell in the
    resulting box is examined. A cell is active in the result if it has
    exactly 3 active neighbours, or if it is already active and has exactly 2.

    Returns ``(next_state, widened_boundaries)``. Note that the box grows by
    one on every side each call, regardless of where the active cells are.
    """
    grown = [(low - 1, high + 1) for low, high in grid_boundaries]
    axes = [range(low, high) for low, high in grown]
    next_state = {}
    for cell in product(*axes):
        count = active_neighbours(cell, state)
        survives = state.get(cell, 0) and count == 2
        if count == 3 or survives:
            next_state[cell] = 1
    return next_state, grown

In [5]:
def _boundaries(state, axis):
    projection = [point[axis] for point in state]
    return min(projection), max(projection) + 1

In [6]:
def run(init_state, nsteps):
    """Simulate `nsteps` cycles of the automaton starting from `init_state`.

    Parameters
    ----------
    init_state : dict
        Maps coordinate tuples to 1 for active cells. The number of
        dimensions is taken from the first key.
    nsteps : int
        Number of cycles to simulate. Values <= 0 leave the state unchanged.

    Returns
    -------
    dict
        The state after `nsteps` cycles, in the same format as `init_state`.
        An empty `init_state` yields an empty dict.
    """
    if not init_state:
        # An empty grid can never gain an active cell (nothing has 3 active
        # neighbours). Also, `next(iter(...))` below would raise StopIteration
        # and `_boundaries` would raise ValueError on an empty dict.
        return {}
    state = init_state
    ndim = len(next(iter(state)))
    grid_boundaries = [_boundaries(state, axis) for axis in range(ndim)]
    for _ in range(nsteps):
        state, grid_boundaries = _run(state, grid_boundaries)
    return state

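# Part 1

Run six cycles on the 3-D grid, with the input slice placed where the third coordinate is 0, and count the active cells.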

In [7]:
# The input is one 2-D slice where '#' marks an active cell. Each active cell
# is keyed by (row index, column index, 0), i.e. the slice sits where the
# third coordinate is 0.
with open('input', 'r') as ifile:
    state = {
        (row_idx, col_idx, 0): 1
        for row_idx, line in enumerate(ifile.read().splitlines())
        for col_idx, char in enumerate(line)
        if char == '#'
    }

In [8]:
# Number of active cells after six cycles in 3-D.
sum(run(init_state=state, nsteps=6).values())

223

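# Part 2

Same rules and number of cycles, but in 4-D: the input slice is placed where the third and fourth coordinates are both 0.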

In [9]:
# Lift the starting slice into 4-D (both extra coordinates set to 0), then count
# the active cells after six cycles.
sum(run(init_state={(row, col, 0, 0): 1 for (row, col, _) in state}, nsteps=6).values())

1884