# Day 11

## Imports and data loading

In [None]:
from utils import get_input, load_data

# Puzzle day number, passed to the input helpers below.
day: int = 11


In [None]:
# Fetch this day's puzzle input via the project helper from `utils`
# (NOTE(review): presumably downloads/caches the input file — confirm in utils).
get_input(day)


In [None]:
# Real puzzle input, one entry per line; number=False presumably keeps each
# line as a string rather than parsing it — TODO confirm in utils.load_data.
data = load_data(day, number=False, list_type="line")
# Example grid from the puzzle statement, one string of digits per row.
test_data = """
5483143223
2745854711
5264556173
6141336146
6357385478
4167524645
2176841721
6882881134
4846848554
5283751526
""".split()
# Expected answers for the example: flashes after 100 steps, first all-flash step.
test_answer_1, test_answer_2 = 1656, 195


## Part one

In [None]:
# Let's try this with numpy
import numpy as np

# Remember numpy matrix has y coordinate first
class Cave:
    """Grid of octopus energy levels, simulated one step at a time.

    ``matrix`` is indexed ``[y, x]`` (row first, as numpy does). The grid size
    comes from ``data``, so any rectangular grid works, not only 10x10.

    Attributes:
        matrix: current energy level of every octopus.
        flash_matrix: scratch mask of cells that have already flashed during
            the step in progress; reset to all ``False`` after each step.
        flash_count: total number of flashes over all steps so far.
        time: number of steps simulated so far.
        all_flash: the first step number in which every octopus flashed at
            once, or ``None`` if that has not happened yet.
    """

    def __init__(self, data):
        """Build the grid from an iterable of digit strings, one per row.

        Surrounding whitespace is stripped and blank rows are skipped, so a
        trailing empty line in the input does not produce a ragged array.
        """
        self.matrix = np.array(
            [[int(c) for c in row.strip()] for row in data if row.strip()],
            dtype=int,
        )
        self.flash_matrix = np.zeros(self.matrix.shape, dtype=bool)
        self.flash_count = 0
        self.time = 0
        self.all_flash = None

    def __repr__(self):
        rows = ["".join([str(r).rjust(3, " ") for r in row]) for row in self.matrix]
        return "\n".join(rows)

    def step(self):
        """Advance the simulation by one step."""
        self.matrix += 1
        self.increase()
        self.flash()
        # Every octopus that flashed this step resets to zero.
        self.matrix[self.matrix > 9] = 0
        self.time += 1
        # An all-zero grid means every octopus flashed this step. Record only
        # the first such step: a synchronised grid stays synchronised, so the
        # condition recurs every 10 steps and would otherwise be overwritten.
        if self.all_flash is None and not self.matrix.any():
            self.all_flash = self.time

    def flash(self):
        """Add the number of octopuses above energy 9 (i.e. flashing) to the total."""
        self.flash_count += np.count_nonzero(self.matrix > 9)

    def increase_neighbours(self, y, x):
        """Flash the octopus at ``(y, x)``: add 1 to each in-bounds neighbour.

        The 3x3 block slice also hits the cell itself, which is undone, and the
        cell is marked in ``flash_matrix`` so it flashes at most once per step.
        """
        height, width = self.matrix.shape
        # Clamp the lower bounds: a negative slice start would wrap to the far edge.
        rows = slice(max(y - 1, 0), min(y + 2, height))
        cols = slice(max(x - 1, 0), min(x + 2, width))
        self.matrix[rows, cols] += 1
        self.matrix[y, x] -= 1
        self.flash_matrix[y, x] = True

    def increase(self):
        """Propagate flashes until no unflashed octopus has energy above 9.

        The final energies do not depend on the order in which flashes are
        processed, so each round handles a snapshot of the cells that are
        newly above 9 and rounds repeat until none remain.
        """
        while True:
            pending = (self.matrix > 9) & ~self.flash_matrix
            if not pending.any():
                break
            for y, x in np.argwhere(pending):
                self.increase_neighbours(y, x)
        self.flash_matrix = np.zeros(self.matrix.shape, dtype=bool)


In [None]:
# Simulate 100 steps on the example grid and on the real input side by side,
# check the example against its known answer, then show the real total.
test_cave, cave = Cave(test_data), Cave(data)
for _ in range(100):
    for grid in (test_cave, cave):
        grid.step()
assert test_answer_1 == test_cave.flash_count

cave.flash_count


## Part two

In [None]:
# Step each grid until every octopus flashes in the same step.
test_cave, cave = Cave(test_data), Cave(data)

while test_cave.all_flash is None:
    test_cave.step()
assert test_answer_2 == test_cave.all_flash

while cave.all_flash is None:
    cave.step()
print(cave.all_flash)
