In [6]:
import functools
from dataclasses import dataclass, field


def load_input(path: str = "../../data/day12-input.txt") -> list[list[str]]:
    """Read the puzzle grid as a list of rows, each a list of single characters.

    Args:
        path: File to read; the default is relative to the current working
            directory (the notebook's location).

    Returns:
        One list of characters per non-blank line, with surrounding whitespace
        stripped. Blank lines (e.g. a trailing empty line) are skipped so the
        grid never contains an empty row, which would otherwise break the
        ``len(grid[0])`` bounds checks used by the traversal code.
    """
    grid = []
    with open(path) as f:
        for line in f:
            stripped = line.strip()
            if stripped:
                grid.append(list(stripped))
    return grid


# (row, col) offset applied when stepping one square in each direction.
_DIRECTION_OFFSETS: dict[str, tuple[int, int]] = {
    'up': (-1, 0),
    'down': (1, 0),
    'left': (0, -1),
    'right': (0, 1),
}


def move_to(row, col, direction) -> tuple[int, int]:
    """Return the coordinates one step from (row, col) in ``direction``.

    No bounds checking is done; callers test the result against the grid size.
    Uses a constant offset table instead of an unbounded cache: hashing the
    arguments costs more than the addition, and caching every visited cell
    only grows memory.

    Raises:
        KeyError: if ``direction`` is not one of 'up', 'down', 'left', 'right'.
    """
    d_row, d_col = _DIRECTION_OFFSETS[direction]
    return row + d_row, col + d_col



In [7]:
@dataclass
class GridSquareBorders:
    """Per-side fence flags for a single grid square.

    Every flag starts as False. A flag is set to True once that side of the
    square is found to face the grid edge or a square with a different value.
    """

    up: bool = False  # top edge
    left: bool = False  # left edge
    down: bool = False  # bottom edge
    right: bool = False  # right edge


@dataclass
class GridSquare:
    """One cell of the input grid plus the bookkeeping used while traversing it.

    The hash uses only ``row``, ``col`` and ``value``. In this notebook those
    fields are never reassigned after construction, so instances can be stored
    in sets while ``marked`` and ``borders`` are mutated.
    """

    row: int
    col: int
    value: str  # the character read from the input at (row, col)
    marked: bool = False  # set once the square has been claimed by a region
    borders: GridSquareBorders = field(default_factory=GridSquareBorders)

    def __hash__(self) -> int:
        identity = (self.row, self.col, self.value)
        return hash(identity)


In [8]:
def create_grid():
    """Wrap every character of the puzzle input in a ``GridSquare``.

    Returns a list of rows of squares. Each square starts unmarked and has
    no border flags set.
    """
    return [
        [GridSquare(row_index, col_index, char, False) for col_index, char in enumerate(raw_row)]
        for row_index, raw_row in enumerate(load_input())
    ]


def clear_grid_markers(grid: list[list[GridSquare]]) -> None:
    """Reset ``marked`` to False on every square so the grid can be traversed again."""
    all_squares = (square for grid_row in grid for square in grid_row)
    for square in all_squares:
        square.marked = False


In [9]:
def collect_region(current_square: GridSquare, value_to_match: str, grid: list[list[GridSquare]]) -> set[GridSquare]:
    """Flood-fill the region containing ``current_square``.

    Starting from ``current_square``, every orthogonally reachable, not yet
    marked square whose ``value`` equals ``value_to_match`` is marked and
    collected. The starting square is included and marked whenever its value
    matches.

    Uses an explicit stack instead of recursion, so large regions cannot hit
    Python's recursion limit. Results go into one set rather than a new union
    per level.

    Args:
        current_square: Square to start the fill from.
        value_to_match: Value every square in the region must have.
        grid: Rectangular grid of squares (bounds use ``len(grid[0])``).

    Returns:
        The squares in the region; empty if the start square does not match.
    """
    region: set[GridSquare] = set()
    if current_square.value != value_to_match:
        return region

    current_square.marked = True
    region.add(current_square)
    to_visit = [current_square]
    while to_visit:
        square = to_visit.pop()
        for direction in ('up', 'down', 'left', 'right'):
            next_row, next_col = move_to(square.row, square.col, direction)
            if not (0 <= next_row < len(grid) and 0 <= next_col < len(grid[0])):
                continue
            next_square = grid[next_row][next_col]
            if next_square.marked or next_square.value != value_to_match:
                continue
            # Mark on push (not pop) so each square is queued at most once.
            next_square.marked = True
            region.add(next_square)
            to_visit.append(next_square)
    return region


def collect_regions(grid: list[list[GridSquare]]) -> list[set[GridSquare]]:
    """Partition the grid into regions of connected squares sharing a value.

    Squares are scanned row by row. Each square not yet marked seeds a new
    region via ``collect_region``, which marks the squares it claims so later
    seeds skip them. Non-empty regions are returned in discovery order.
    """
    found = []
    all_squares = (square for grid_row in grid for square in grid_row)
    for seed in all_squares:
        if not seed.marked:
            region = collect_region(seed, seed.value, grid)
            if region:
                found.append(region)
    return found


def count_borders(square: GridSquare, grid: list[list[GridSquare]]) -> int:
    """Count how many of ``square``'s four sides need a fence.

    A side is fenced when the neighbouring position is off the grid or holds
    a square with a different ``value``.
    """

    def needs_fence(direction: str) -> bool:
        r, c = move_to(square.row, square.col, direction)
        inside = 0 <= r < len(grid) and 0 <= c < len(grid[0])
        return not inside or grid[r][c].value != square.value

    return sum(needs_fence(d) for d in ('up', 'down', 'left', 'right'))


def calculate_cost(region: set[GridSquare], grid: list[list[GridSquare]]) -> int:
    """Part-1 fence price of a region: number of squares times total fenced sides."""
    fences_per_square = [count_borders(square, grid) for square in region]
    return len(fences_per_square) * sum(fences_per_square)


def calculate_part1_cost():
    """Print the part-1 total: sum over all regions of area times perimeter."""
    grid = create_grid()
    regions = collect_regions(grid)
    # clear_grid_markers expects the 2-D grid of squares, not the list of
    # region sets, so reset the traversal flags on the grid itself.
    clear_grid_markers(grid)
    total_cost = sum(calculate_cost(region, grid) for region in regions)
    print(total_cost)


calculate_part1_cost()

1375476


## Part 2

In [10]:
def mark_borders(region: set[GridSquare], grid: list[list[GridSquare]]) -> list[GridSquare]:
    """Set the fence flags on every square of ``region``.

    For each square, ``square.borders.<direction>`` is set to True for every
    side that faces the grid edge or a square with a different ``value``.
    Flags are only ever set, never cleared, so this expects freshly created
    squares.

    Returns:
        The squares with at least one fenced side, each listed once, in
        iteration order of ``region``.
    """
    border_squares = []
    for square in region:
        on_boundary = False
        for direction in ('up', 'down', 'left', 'right'):
            next_row, next_col = move_to(square.row, square.col, direction)
            inside = 0 <= next_row < len(grid) and 0 <= next_col < len(grid[0])
            if not inside or grid[next_row][next_col].value != square.value:
                setattr(square.borders, direction, True)
                on_boundary = True
        # Append once per square rather than once per fenced side.
        if on_boundary:
            border_squares.append(square)
    return border_squares


def _wall_positions(border_squares: list[GridSquare]) -> dict[str, dict[int, set[int]]]:
    """Group fenced sides by direction and by the grid line they lie on.

    'up'/'down' fences are keyed by row and collect column indices.
    'left'/'right' fences are keyed by column and collect row indices.
    Duplicate squares in ``border_squares`` are harmless (sets absorb them).
    """
    walls: dict[str, dict[int, set[int]]] = {'up': {}, 'down': {}, 'left': {}, 'right': {}}
    for square in border_squares:
        for direction in ('up', 'down'):
            if getattr(square.borders, direction):
                walls[direction].setdefault(square.row, set()).add(square.col)
        for direction in ('left', 'right'):
            if getattr(square.borders, direction):
                walls[direction].setdefault(square.col, set()).add(square.row)
    return walls


def _count_runs(positions: set[int]) -> int:
    """Number of maximal runs of consecutive integers in a non-empty ``positions``."""
    ordered = sorted(positions)
    # Every gap between neighbours in sorted order starts a new run.
    return 1 + sum(1 for prev, cur in zip(ordered, ordered[1:]) if cur != prev + 1)


def calculate_part2_cost():
    """Print the part-2 total: sum over all regions of area times number of sides."""
    grid = create_grid()
    total_cost = 0
    for region in collect_regions(grid):
        walls = _wall_positions(mark_borders(region, grid))
        # Fences with the same direction on the same line at consecutive
        # positions form one straight side, so each run counts as one side.
        sides = sum(
            _count_runs(positions)
            for lines in walls.values()
            for positions in lines.values()
        )
        total_cost += sides * len(region)
    print(total_cost)


calculate_part2_cost()

821372
