In [51]:
from collections import deque, defaultdict

import numpy as np

In [52]:
# Orthogonal neighbour offsets as (d_row, d_col) pairs, used by the flood fill.
directions = [
    (-1, 0),  # previous row
    (1, 0),   # next row
    (0, -1),  # previous column
    (0, 1),   # next column
]

In [53]:
def get_region_plant_start(grid, plant, start):
    """Flood-fill (BFS) the connected region of ``plant`` that contains ``start``.

    Neighbours are the orthogonal offsets in the module-level ``directions``.

    Parameters
    ----------
    grid : np.ndarray
        Two-dimensional grid of plant labels (its shape is unpacked as rows, cols).
    plant :
        Label of the region to explore; must equal ``grid[start]``.
    start : tuple[int, int]
        ``(row, col)`` cell to start from.

    Returns
    -------
    (set of (row, col), int)
        The cells of the region, and the number of adjacent same-plant pairs
        inside it. Each pair is counted once: when the first of its two cells is
        dequeued, the other one is not yet visited.

    Raises
    ------
    AssertionError
        If ``grid[start]`` is not ``plant``.
    """
    assert grid[start] == plant

    height, width = grid.shape

    seen = set()
    edge_count = 0
    frontier = deque()
    frontier.append(start)

    while frontier:
        row, col = frontier.popleft()
        cell = (row, col)
        # A cell may be enqueued several times before it is processed; only the
        # first dequeue counts.
        if cell in seen:
            continue
        seen.add(cell)

        for d_row, d_col in directions:
            nb_row = row + d_row
            nb_col = col + d_col
            # Bounds check must come first: negative indices would wrap in numpy.
            if not (0 <= nb_row < height and 0 <= nb_col < width):
                continue
            neighbour = (nb_row, nb_col)
            if neighbour not in seen and grid[nb_row, nb_col] == plant:
                frontier.append(neighbour)
                edge_count += 1

    return seen, edge_count

In [68]:
def get_all_regions_one_plant(grid, plant):
    """Split all cells holding ``plant`` into their connected regions.

    Parameters
    ----------
    grid : np.ndarray
        Two-dimensional grid of plant labels.
    plant :
        Label to search for (matched with ``grid == plant``).

    Returns
    -------
    list of (str, int, set)
        One ``(plant_label, n_connexions, cells)`` tuple per region, as produced
        by ``get_region_plant_start``. ``cells`` is the set of ``(row, col)``
        positions and ``n_connexions`` is the number of adjacent same-plant pairs
        inside the region.
    """
    # Positions as plain Python-int (row, col) tuples.
    remaining = {tuple(pos.tolist()) for pos in np.argwhere(grid == plant)}
    plant_regions = []

    while remaining:
        start = remaining.pop()
        region, n_connexions = get_region_plant_start(grid, plant, start)
        plant_regions.append((str(plant), n_connexions, region))
        # In-place removal costs O(len(region)). Building a fresh set with
        # ``difference`` cost O(len(remaining)) per region, which is quadratic
        # when a plant is split into many small regions.
        remaining -= region

    return plant_regions

In [69]:
def get_all_regions(grid):
    """Return the regions of every plant label in ``grid`` as one flat list.

    Labels are processed in the order given by ``np.unique`` (sorted). For each
    label, its regions are appended in the order ``get_all_regions_one_plant``
    returns them.
    """
    return [
        region
        for plant in np.unique(grid)
        for region in get_all_regions_one_plant(grid, plant)
    ]

In [85]:
def compute_total_cost(all_regions):
    """Sum ``area * perimeter`` over all regions.

    Parameters
    ----------
    all_regions : iterable of (str, int, set)
        ``(plant_label, n_connexions, cells)`` tuples, where ``n_connexions`` is
        the number of adjacent same-plant pairs inside the region.

    Returns
    -------
    int
        Total fencing cost; 0 for an empty iterable.
    """
    total_cost = 0
    for _plant, n_connexions, cells in all_regions:
        area = len(cells)
        # Every plot has 4 sides. Each shared edge between two plots of the
        # region hides one side of each of them, so it removes 2 from the perimeter.
        perimeter = 4 * area - 2 * n_connexions
        total_cost += area * perimeter
    return total_cost

In [89]:
def main(file):
    """Read a grid of plant labels from ``file`` and return its total fencing cost.

    The file holds one grid row per line, with one character per cell.

    Parameters
    ----------
    file : str or path-like
        Path to the input file.

    Returns
    -------
    int
        Sum of ``area * perimeter`` over every region of the grid.

    Raises
    ------
    ValueError
        If the grid rows do not all have the same length.
    """
    with open(file, encoding='utf-8') as file_in:
        # Strip surrounding blank lines (e.g. an extra newline at EOF). An empty
        # trailing row would otherwise make the rows ragged and break the 2-D array.
        rows = file_in.read().strip().splitlines()

    if len({len(row) for row in rows}) > 1:
        raise ValueError(f"{file}: grid rows must all have the same length")
    grid = np.array([list(row) for row in rows])

    regions = get_all_regions(grid)
    return compute_total_cost(regions)

In [91]:
# Run the solver on the puzzle input file; the cell displays the returned total cost.
main(file='input.txt')

1431440