In [None]:
from tabulate import tabulate

# Puzzle input files; relative paths, resolved against the current working directory.
EXAMPLE, INPUT = "../example.txt", "../input.txt"

In [None]:
def get_plant_map(input_file_name):
    """Read a garden map file into a grid of single-character plant labels.

    Args:
        input_file_name: Path of a text file with one grid row per line.

    Returns:
        A list of rows, each row a list of one-character strings. Leading and
        trailing whitespace (including "\\n" / "\\r\\n") is stripped from every
        line, and blank lines (e.g. a trailing empty line at the end of the
        file) are skipped so that every returned row describes real cells.
    """
    with open(input_file_name, 'r') as f:
        # Named plant_map rather than `map` so the builtin map() is not shadowed.
        plant_map = [list(row) for row in (line.strip() for line in f) if row]
    return plant_map

In [None]:
# Load the example grid and show it as a table for a quick visual check.
# Bound to plant_map (not `map`) so the builtin map() stays usable in later cells.
plant_map = get_plant_map(EXAMPLE)
height = len(plant_map)
# An empty grid has no first row; report zero width instead of raising IndexError.
width = len(plant_map[0]) if plant_map else 0
print(tabulate(plant_map))

In [None]:
class Region:
    """Running statistics for one connected region of identical plants.

    A freshly created region stands for a single cell: area 1, perimeter 4
    and 4 sides. The owner grows it by adding to these counters.
    """

    # Identity token from id(self); unique only among simultaneously alive objects.
    _id: int
    plant: str  # plant label shared by every cell in the region
    area: int  # number of cells
    perimeter: int  # number of unit edges on the region boundary
    sides: int  # number of straight boundary sides

    def __init__(self, plant):
        self.plant = plant
        self.area = 1
        self.perimeter = 4
        self.sides = 4
        self._id = id(self)

    def cost(self, criteria):
        """Return the fence price of this region.

        Args:
            criteria: "perimeter" for area * perimeter, or "sides" for
                area * sides.

        Raises:
            ValueError: if criteria is anything else. An unknown criteria
                must not silently price the region at 0, because that would
                turn a typo into a plausible-looking wrong answer.
        """
        if criteria == "perimeter":
            return self.area * self.perimeter
        if criteria == "sides":
            return self.area * self.sides
        raise ValueError(f"unknown cost criteria: {criteria!r}")

    def __hash__(self):
        return hash(self._id)

    def __eq__(self, other):
        # Identity-based equality: two Region objects are equal only if they
        # are the same object (same _id).
        return isinstance(other, Region) and self._id == other._id


class RegionReference:
    """Mutable pointer to a Region.

    Grid cells store a RegionReference rather than the Region itself. When
    two regions are merged, reassigning ``region`` on the affected references
    re-points every cell that shares them in one step.
    """

    region: Region  # the region currently pointed to; reassigned on merges

    def __init__(self, region) -> None:
        self.region = region


class RegionMap:
    """Split a plant grid into connected regions and price their fences.

    The grid is scanned once, row by row, left to right. Each cell is compared
    only with already-visited neighbours. It then starts a new Region, joins
    the region on top or to the left, or joins both and merges them. Area,
    perimeter and side counts are updated incrementally as cells are added.
    """

    plant_map: list[list[str]]  # input grid, indexed plant_map[row][col]
    height: int  # number of rows
    width: int  # number of columns (taken from the first row)
    # Per-cell pointer to the cell's region; every entry is filled by build_region_map.
    region_map: list[list[RegionReference | None]]
    regions: set[Region]  # live regions; regions absorbed by a merge are removed
    region_refs: set[RegionReference]  # every reference created during the scan

    def __init__(self, plant_map):
        """Build the region map for ``plant_map`` immediately.

        Args:
            plant_map: Rectangular grid of plant labels (list of rows). An
                empty grid yields no regions, so cost() returns 0.
        """
        self.plant_map = plant_map
        self.height = len(self.plant_map)
        # An empty grid has no first row; treat it as zero-width instead of
        # raising IndexError.
        self.width = len(self.plant_map[0]) if self.plant_map else 0
        self.region_map = [[None for _ in range(self.width)] for _ in range(self.height)]
        self.regions = set()
        self.region_refs = set()
        self.build_region_map()

    def _plant_at(self, i, j):
        """Return the plant at (i, j), or None when (i, j) lies outside the grid."""
        if 0 <= i < self.height and 0 <= j < self.width:
            return self.plant_map[i][j]
        return None

    def merge_regions(
        self, top_region_ref: RegionReference, left_region_ref: RegionReference
    ):
        """Fold the left reference's region into the top reference's region.

        Counters are summed into the top region, every reference to the left
        region is redirected to the top one, and the left region is dropped
        from ``regions``. No-op when both already point at the same region.
        """
        if top_region_ref.region == left_region_ref.region:
            # Top and left cells are already in the same region.
            return
        top_region_ref.region.area += left_region_ref.region.area
        top_region_ref.region.perimeter += left_region_ref.region.perimeter
        top_region_ref.region.sides += left_region_ref.region.sides
        region_to_remove = left_region_ref.region
        # Scans every reference ever created: cost is linear in region_refs per merge.
        for ref in self.region_refs:
            if ref.region == region_to_remove:
                ref.region = top_region_ref.region
        self.regions.remove(region_to_remove)

    def add_to_region(self, region_ref: RegionReference, area=0, perimeter=0, sides=0):
        """Add the given deltas to the counters of the region behind region_ref."""
        region_ref.region.area += area
        region_ref.region.perimeter += perimeter
        region_ref.region.sides += sides

    def build_region_map(self):
        """Assign every cell to a region in a single raster scan.

        A region has as many sides as corners, so the ``sides`` deltas are
        corner-count deltas. When cell (i, j) is added, only its top-left,
        top, top-right and left neighbours have been visited, so only they
        affect the delta:

        * no same-plant top/left neighbour: new region with 4 sides
        * joins top only: +2 for each of top-left / top-right with the same plant
        * joins left only: +2 if top-left has the same plant
        * joins top and left: -2 unless top-right has the same plant
        """
        for i in range(self.height):
            for j in range(self.width):
                p = self.plant_map[i][j]
                same_top = self._plant_at(i - 1, j) == p
                same_left = self._plant_at(i, j - 1) == p
                same_top_left = self._plant_at(i - 1, j - 1) == p
                same_top_right = self._plant_at(i - 1, j + 1) == p
                if same_top:
                    top_region_ref = self.region_map[i - 1][j]
                    if same_left:
                        # The cell bridges the top and left regions: merge them first.
                        self.merge_regions(top_region_ref, self.region_map[i][j - 1])
                        # Two of the cell's four edges are shared, so the perimeter
                        # is unchanged.
                        sides = 0 if same_top_right else -2
                        self.add_to_region(top_region_ref, area=1, sides=sides)
                    else:
                        sides = (2 if same_top_left else 0) + (2 if same_top_right else 0)
                        self.add_to_region(top_region_ref, area=1, perimeter=2, sides=sides)
                    self.region_map[i][j] = top_region_ref
                elif same_left:
                    left_region_ref = self.region_map[i][j - 1]
                    sides = 2 if same_top_left else 0
                    self.add_to_region(left_region_ref, area=1, perimeter=2, sides=sides)
                    self.region_map[i][j] = left_region_ref
                else:
                    # No same-plant visited neighbour: start a new single-cell region.
                    new_region = Region(plant=p)
                    self.regions.add(new_region)
                    new_region_ref = RegionReference(new_region)
                    self.region_refs.add(new_region_ref)
                    self.region_map[i][j] = new_region_ref

    def cost(self, criteria):
        """Return the summed Region.cost(criteria) over all live regions."""
        return sum(region.cost(criteria) for region in self.regions)


In [None]:
def part_1(input_file_name):
    """Print the total fence price with every region priced by area * perimeter."""
    regions = RegionMap(get_plant_map(input_file_name))
    total_price = regions.cost("perimeter")
    print(total_price)

In [None]:
# Part 1 on the example input.
part_1(input_file_name=EXAMPLE)

In [None]:
# Part 1 on the full puzzle input.
part_1(input_file_name=INPUT)

In [None]:
def part_2(input_file_name):
    """Print the total fence price with every region priced by area * sides."""
    regions = RegionMap(get_plant_map(input_file_name))
    total_price = regions.cost("sides")
    print(total_price)

In [None]:
# Part 2 on the example input.
part_2(input_file_name=EXAMPLE)

In [None]:
# Part 2 on the full puzzle input.
part_2(input_file_name=INPUT)