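# Day 9

Part one finds the low points of a digit height map (points strictly lower than all orthogonal neighbours) and sums `height + 1` over them. Part two grows a basin around each low point (bounded by height-9 cells and the map edge) and multiplies the sizes of the three largest basins.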

## Imports and data loading

In [65]:
from math import prod

from utils import get_input, load_data

day: int = 9  # puzzle day number, passed to get_input / load_data below


In [2]:
# Project helper from utils; presumably fetches and stores the puzzle input for
# `day` (its printed output below reads "Data saved"). NOTE(review): confirm
# whether it re-downloads on every Restart & Run All or reuses a cached copy.
get_input(day)


Data saved


In [2]:
# Real puzzle input: one string per line (number=False keeps rows as text,
# which Grid parses digit by digit).
data = load_data(day, number=False, list_type="line")
# Small worked example: one string of digit heights per row.
test_data = """
2199943210
3987894921
9856789892
8767896789
9899965678
""".split()
# Expected results for `test_data`, asserted in the Part one / Part two cells.
test_answer_1 = 15
test_answer_2 = 1134


## Part one

In [66]:
class Grid:
    """Height map of single-digit heights, addressed by ``(x, y)`` coordinates.

    ``x`` is the column index and ``y`` the row index. Only orthogonal cells
    (left, right, up, down) count as neighbours. When exploring basins, height
    9 (and anything off the edge of the map) acts as a wall.
    """

    def __init__(self, input_rows):
        """Parse ``input_rows`` and locate the low points.

        Args:
            input_rows: Iterable of strings, one per row, each made only of
                digit characters. Surrounding whitespace such as a trailing
                newline is ignored.

        Raises:
            ValueError: If a row contains a non-digit character.
        """
        # Orthogonal (dx, dy) offsets; diagonals are deliberately excluded.
        self.neighbours = [(-1, 0), (1, 0), (0, -1), (0, 1)]
        self.rows = [[int(char) for char in row.strip()] for row in input_rows]
        # Coordinate -> height lookup; a missing key means "off the map".
        self.grid = {
            (x, y): height
            for y, row in enumerate(self.rows)
            for x, height in enumerate(row)
        }
        self.low_points = self._find_low_points()

    def __repr__(self):
        """Render the map back as newline-separated rows of digits."""
        return "\n".join("".join(str(height) for height in row) for row in self.rows)

    def point(self, x, y):
        """Return the height at ``(x, y)``; raises ``KeyError`` if off the map."""
        return self.grid[(x, y)]

    def get_neighbours(self, x, y):
        """Return the heights of the on-map orthogonal neighbours of ``(x, y)``."""
        return [
            self.grid[(x + dx, y + dy)]
            for dx, dy in self.neighbours
            if (x + dx, y + dy) in self.grid
        ]

    def is_low_point(self, x, y):
        """Return True if ``(x, y)`` is strictly lower than every neighbour.

        Edge and corner points simply have fewer neighbours to compare with.
        Uses the coordinate lookup (like every other method) so out-of-range
        coordinates raise ``KeyError`` instead of silently wrapping around.
        """
        height = self.point(x, y)
        return all(height < neighbour for neighbour in self.get_neighbours(x, y))

    def add_low_points(self):
        """Sum ``height + 1`` over all low points.

        Note that this adds 1 to each low point to match the question text.
        """
        return sum(height + 1 for height in self.low_points.values())

    def _find_low_points(self):
        """Map ``(x, y)`` -> height for every low point, in row-major order."""
        return {
            (x, y): height
            for (x, y), height in self.grid.items()
            if self.is_low_point(x, y)
        }

    def explore_basin(self, x, y, so_far=None):
        """Get the set of points in the basin containing ``(x, y)``.

        The basin is grown from ``(x, y)`` through orthogonal neighbours whose
        height is below 9. The search uses an explicit stack instead of
        recursion, so long, winding basins cannot exceed Python's recursion
        limit.

        Args:
            x, y: Starting point, normally a low point.
            so_far: Optional set of points already known to be in the basin
                (they are not re-explored). It is updated in place and
                returned. When omitted or empty, a new set is created.

        Returns:
            The set of ``(x, y)`` coordinates in the basin, including the start.
        """
        basin = so_far if so_far else set()
        basin.add((x, y))
        frontier = [(x, y)]
        while frontier:
            cx, cy = frontier.pop()
            for dx, dy in self.neighbours:
                candidate = (cx + dx, cy + dy)
                # Off-map cells default to 9, so they behave like walls.
                if candidate not in basin and self.grid.get(candidate, 9) < 9:
                    basin.add(candidate)
                    frontier.append(candidate)
        return basin

    def measure_basin(self, x, y):
        """Return the number of points in the basin containing ``(x, y)``."""
        return len(self.explore_basin(x, y))

    def multiply_biggest_basins(self, count=3):
        """Return the product of the ``count`` largest basin sizes.

        One basin is measured per low point. If there are fewer than ``count``
        basins, all of them are multiplied (the product of none is 1).
        """
        basin_sizes = sorted(
            (self.measure_basin(x, y) for x, y in self.low_points), reverse=True
        )
        return prod(basin_sizes[:count])


In [67]:
# Validate part one on the worked example before solving the real input.
test_grid = Grid(test_data)
assert test_grid.add_low_points() == test_answer_1, (
    f"part one example: expected {test_answer_1}, got {test_grid.add_low_points()}"
)
grid = Grid(data)
grid.add_low_points()


506

## Part two

In [68]:
# Validate part two on the worked example before solving the real input.
assert test_grid.multiply_biggest_basins() == test_answer_2, (
    f"part two example: expected {test_answer_2}, "
    f"got {test_grid.multiply_biggest_basins()}"
)
grid.multiply_biggest_basins()


931200