In [None]:
from tabulate import tabulate

# Relative paths of the text files that get_map() reads in the cells below.
EXAMPLE = "../example.txt"
INPUT = "../input.txt"

In [None]:
def get_map(input_file_name):
    """Read a grid of single-digit integers from a text file.

    Every non-blank line of the file becomes one row of the grid, and every
    character on that line (after stripping surrounding whitespace) is
    converted with ``int``.

    Args:
        input_file_name: Path of the text file to read.

    Returns:
        A list of rows, each row being a list of ints. Blank lines (for
        example a stray empty line at the end of the file) are skipped, so
        the result never contains an empty row.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a line contains a character ``int`` cannot parse.
    """
    grid = []
    with open(input_file_name, 'r') as f:
        for line in f:
            digits = line.strip()
            if not digits:
                # An empty row would break code that indexes every row up
                # to the width of the first row, so drop blank lines here.
                continue
            grid.append([int(c) for c in digits])
    return grid

In [None]:
# Load the example grid into notebook-level names and print it as a table.
# NOTE(review): the global name `map` shadows Python's built-in `map`;
# later cells read the grid under this name.
map = get_map(EXAMPLE)
height = len(map)  # number of rows
width = len(map[0])  # number of columns, taken from the first row
print(tabulate(map))

In [None]:
def find_trail_heads(map):
    """Collect the coordinates of every cell whose value is 0.

    Args:
        map: Grid given as a sequence of rows, each row a sequence of ints.
            An empty grid and rows of differing lengths are accepted; every
            cell that is actually present is examined.

    Returns:
        A set of ``(row_index, column_index)`` tuples, one per 0-valued
        cell. The set is empty when the grid has no rows or no zeros.
    """
    # Iterating the rows directly (instead of indexing up to the width of
    # the first row) avoids an IndexError on an empty or ragged grid.
    return {
        (i, j)
        for i, row in enumerate(map)
        for j, value in enumerate(row)
        if value == 0
    }

In [None]:
# Locate every 0-valued cell of the grid loaded above and print the
# resulting set of (row, col) coordinates.
trail_heads = find_trail_heads(map)
print(trail_heads)

In [None]:
def find_next_positions(map, current_row, current_col):
    """Return the orthogonal neighbours that are exactly one higher.

    Args:
        map: Grid given as a list of rows of ints. The column bound is
            taken from the length of the first row.
        current_row: Row index of the cell to step from.
        current_col: Column index of the cell to step from.

    Returns:
        A list of ``(row, col)`` tuples for the in-bounds neighbours whose
        value equals the current cell's value plus one, in the order
        down, up, right, left.
    """
    n_rows, n_cols = len(map), len(map[0])
    neighbours = []
    for d_row, d_col in ((1, 0), (-1, 0), (0, 1), (0, -1)):
        r = current_row + d_row
        c = current_col + d_col
        # Reject anything off the grid before touching map[r][c].
        if not (0 <= r < n_rows):
            continue
        if not (0 <= c < n_cols):
            continue
        if map[r][c] == map[current_row][current_col] + 1:
            neighbours.append((r, c))
    return neighbours

In [None]:
def find_final_positions(map, start_row, start_col, unique=True):
    """Walk every strictly +1 path from a start cell and collect the 9s hit.

    The search advances one step at a time: each entry of the frontier is
    expanded through ``find_next_positions``; a step landing on a 9-valued
    cell is recorded, any other step joins the next frontier. The frontier
    is not deduplicated, so each entry stands for one distinct partial
    path. The start cell itself is never tested for being a 9.

    Args:
        map: Grid given as a list of rows of ints.
        start_row: Row index of the starting cell.
        start_col: Column index of the starting cell.
        unique: When true, return a set of distinct end coordinates; when
            false, return a list holding one entry per path that reached
            a 9 (coordinates may repeat).

    Returns:
        A set (``unique=True``) or list (``unique=False``) of
        ``(row, col)`` tuples.
    """
    summits = []
    frontier = [(start_row, start_col)]
    while frontier:
        upcoming = []
        for row, col in frontier:
            for step_row, step_col in find_next_positions(map, row, col):
                # A 9 ends the path; anything else keeps it going.
                bucket = summits if map[step_row][step_col] == 9 else upcoming
                bucket.append((step_row, step_col))
        frontier = upcoming
    return set(summits) if unique else summits
        

In [None]:
# For each 0-valued start cell found above, print the set of distinct
# 9-valued cells reachable from it (unique defaults to True) and its size.
for (row, col) in trail_heads:
    final_positions = find_final_positions(map, row, col)
    print(final_positions)
    print(len(final_positions))

In [None]:
def part_1(input_file_name):
    """Print the sum, over all 0-valued cells, of distinct reachable 9s.

    Args:
        input_file_name: Path of the grid file, passed to ``get_map``.

    Returns:
        None. The total is written to standard output.
    """
    grid = get_map(input_file_name)
    total = sum(
        len(find_final_positions(grid, head_row, head_col))
        for head_row, head_col in find_trail_heads(grid)
    )
    print(total)

In [None]:
# Run part 1 on the example file; the total is printed.
part_1(EXAMPLE)

In [None]:
# Run part 1 on the input file; the total is printed.
part_1(INPUT)

In [None]:
def part_2(input_file_name):
    """Print the sum, over all 0-valued cells, of paths that reach a 9.

    Unlike ``part_1``, end cells are not deduplicated: every path that
    reaches a 9 contributes one to the total.

    Args:
        input_file_name: Path of the grid file, passed to ``get_map``.

    Returns:
        None. The total is written to standard output.
    """
    grid = get_map(input_file_name)
    total = sum(
        len(find_final_positions(grid, head_row, head_col, unique=False))
        for head_row, head_col in find_trail_heads(grid)
    )
    print(total)

In [None]:
# Run part 2 on the example file; the total is printed.
part_2(EXAMPLE)

In [None]:
# Run part 2 on the input file; the total is printed.
part_2(INPUT)