In [1]:
import numpy as np

In [2]:
def parse_input(input):
    """Read a text file into a 2-D NumPy array of single characters.

    Each line of the file becomes one row of the array and each character
    of the line becomes one cell.

    Parameters
    ----------
    input : str or path-like
        Path of the text file to read.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_lines, line_length)`` holding one character per
        cell. An empty file (or one containing only newlines) yields an
        array of shape ``(0, 0)``.

    Raises
    ------
    OSError
        If the file cannot be opened.
    """
    with open(input) as file_in:
        # Strip newline characters at both ends only: blank leading/trailing
        # lines (e.g. an extra newline at EOF) would otherwise become empty
        # rows and make the grid ragged. Other characters are kept untouched.
        lines = file_in.read().strip("\r\n").splitlines()

    if not lines:
        # Keep the result 2-D so callers can always unpack `grid.shape`.
        return np.empty((0, 0), dtype="<U1")

    return np.array([list(line) for line in lines])

def count_adjacent_rolls(grid, x, y, directions):
    """Count the cells around ``(x, y)`` that contain a roll (``"@"``).

    Parameters
    ----------
    grid : numpy.ndarray
        2-D array of characters.
    x, y : int
        Row and column index of the cell whose surroundings are inspected.
    directions : iterable of (int, int)
        Row/column offsets, relative to ``(x, y)``, of the cells to look at.

    Returns
    -------
    int
        Number of offsets that land on an in-bounds cell equal to ``"@"``.
    """
    height, width = grid.shape

    def holds_roll(row, col):
        # Positions outside the grid never count; checking bounds first also
        # prevents negative indices from wrapping around to the other side.
        if not 0 <= row < height:
            return False
        if not 0 <= col < width:
            return False
        return grid[row, col] == "@"

    return sum(1 for d_row, d_col in directions if holds_roll(x + d_row, y + d_col))


def get_accessible_rolls(grid):
    """Return the positions of all rolls that have fewer than 4 adjacent rolls.

    Parameters
    ----------
    grid : numpy.ndarray
        2-D array of characters in which ``"@"`` marks a roll.

    Returns
    -------
    list of (int, int)
        ``(row, column)`` index pairs, in row-major order, of every ``"@"``
        cell with fewer than 4 ``"@"`` cells among its 8 surrounding cells.
    """
    height, width = grid.shape

    # The 8 surrounding offsets (horizontal, vertical and diagonal).
    neighbour_offsets = [
        (-1, -1), (-1, 0), (-1, 1),
        (0, -1),           (0, 1),
        (1, -1),  (1, 0),  (1, 1),
    ]

    return [
        (row, col)
        for row in range(height)
        for col in range(width)
        # A roll counts as accessible when fewer than 4 rolls surround it.
        if grid[row, col] == "@"
        and count_adjacent_rolls(grid, row, col, neighbour_offsets) < 4
    ]

In [3]:
def main(input, part):
    """Solve one part of the puzzle for the grid stored in ``input``.

    Parameters
    ----------
    input : str or path-like
        Path of the grid file, forwarded to ``parse_input``.
    part : int
        ``1`` to count the rolls that are currently accessible, ``2`` to
        count how many rolls get removed when accessible rolls are removed
        repeatedly until none is left.

    Returns
    -------
    int
        The answer for the requested part.

    Raises
    ------
    ValueError
        If ``part`` is neither 1 nor 2 (previously this silently returned
        ``None``).
    """
    # Reject unsupported parts up front, before doing any file I/O.
    if part not in (1, 2):
        raise ValueError(f"part must be 1 or 2, got {part!r}")

    grid = parse_input(input)

    if part == 1:
        return len(get_accessible_rolls(grid))

    # Part 2: removing rolls can expose further rolls, so keep sweeping the
    # grid until a pass finds nothing left to remove.
    total_n_removed_rolls = 0
    while True:
        removable_rolls = get_accessible_rolls(grid)
        if not removable_rolls:
            return total_n_removed_rolls

        for x, y in removable_rolls:
            grid[x, y] = "."
        total_n_removed_rolls += len(removable_rolls)

In [14]:
# Part 1: check the example file against its expected answer (13), then run input.txt.
assert main("example.txt", part=1) == 13
main("input.txt", part=1)

1428

In [15]:
# Part 2: check the example file against its expected answer (43), then run input.txt.
assert main("example.txt", part=2) == 43
main("input.txt", part=2)

8936