In [1]:
import sys
import string
import itertools
from collections import Counter, defaultdict
import re

from pathlib import Path
import os

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import math
import networkx as nx

In [2]:
%load_ext line_profiler

In [3]:
data = Path('../data/day_11.txt').read_text()

In [4]:
# start at 6:34

In [5]:
octopii = [[int(k) for k in row] for row in data.splitlines()]

In [6]:
def run_step(grid, nrows, ncols):
    grid = grid + 1
    flashed = set()
    while True:
        mask = grid > 159
        flashed = mask.nonzero()
        if flashed[0].shape[0] == 0:
            break
        grid[mask] = 0
        for x, y in zip(*flashed):
            grid[0 if x == 0 else x - 1: x + 2, 0 if y == 0 else y - 1: y + 2] += 1
            # grid[x][y] = 0
    mask = grid < 150
    flashed_count = mask.sum()
    grid = np.where(mask, 150, grid)
    return grid, flashed_count

def part_a(grid, steps):
    nrow, ncol = grid.shape
    grid = grid + 150 # random - just needs to be high enough (ideally nrow * ncol + max(grid))
    flashed = 0
    for iter_count in range(steps):
        grid, flash_count = run_step(grid, nrow, ncol)
        flashed += flash_count
    return flashed

print(part_a(np.array([[int(k) for k in row] for row in '''5483143223
2745854711
5264556173
6141336146
6357385478
4167524645
2176841721
6882881134
4846848554
5283751526'''.splitlines()]), 100))
print(part_a(np.array(octopii), 100))

1656
1723


In [7]:
%%timeit grid = np.array(octopii)
part_a(grid, 100)

8.73 ms ± 72.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [8]:
%lprun -f part_a -f run_step part_a(np.array(octopii, dtype=np.int16), 100)

Timer unit: 1e-06 s

Total time: 0.016474 s
File: <ipython-input-6-5ce7b632cc84>
Function: run_step at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def run_step(grid, nrows, ncols):
     2       100        199.0      2.0      1.2      grid = grid + 1
     3       100         53.0      0.5      0.3      flashed = set()
     4                                               while True:
     5       594       1112.0      1.9      6.8          mask = grid > 159
     6       594       1142.0      1.9      6.9          flashed = mask.nonzero()
     7       594        348.0      0.6      2.1          if flashed[0].shape[0] == 0:
     8       100         30.0      0.3      0.2              break
     9       494        730.0      1.5      4.4          grid[mask] = 0
    10      2217       2365.0      1.1     14.4          for x, y in zip(*flashed):
    11      1723       9175.0      5.3     55.7              grid[0 if x =

In [9]:
def part_b(grid):
    nrow, ncol = grid.shape
    grid = grid + 150
    step = 0
    while True:
        step += 1
        grid, _ = run_step(grid, nrow, ncol)
        if (grid == 150).all():
            break
    return step

print(part_b(np.array([[int(k) for k in row] for row in '''5483143223
2745854711
5264556173
6141336146
6357385478
4167524645
2176841721
6882881134
4846848554
5283751526'''.splitlines()])))
print(part_b(np.array(octopii)))


195
327


In [10]:
%%timeit grid = np.array(octopii)
part_b(grid) 

29.4 ms ± 403 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [11]:
%lprun -f part_b -f run_step part_b(np.array(octopii))

Timer unit: 1e-06 s

Total time: 0.047571 s
File: <ipython-input-6-5ce7b632cc84>
Function: run_step at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def run_step(grid, nrows, ncols):
     2       327        576.0      1.8      1.2      grid = grid + 1
     3       327        177.0      0.5      0.4      flashed = set()
     4                                               while True:
     5      1947       2907.0      1.5      6.1          mask = grid > 159
     6      1947       4264.0      2.2      9.0          flashed = mask.nonzero()
     7      1947       1196.0      0.6      2.5          if flashed[0].shape[0] == 0:
     8       327        116.0      0.4      0.2              break
     9      1620       2359.0      1.5      5.0          grid[mask] = 0
    10      7058       6663.0      0.9     14.0          for x, y in zip(*flashed):
    11      5438      25143.0      4.6     52.9              grid[0 if x =