# Setup

In [None]:
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment.
# NOTE(review): presumably this supplies the Advent of Code session token that
# aocd reads when fetching input / submitting — confirm against the .env used.
# Binding the result to `_` keeps the returned bool from being echoed as cell output.
_ = load_dotenv()

In [None]:
from aocd import submit
from aocd.models import Puzzle

In [None]:
# Handle on the 2023 day 5 puzzle; later cells read its examples and input data.
puzzle = Puzzle(year=2023, day=5)

In [None]:
# First worked example from the puzzle text, plus its expected answers for
# parts A and B (unpacking fails loudly if aocd does not supply exactly two).
example_input = puzzle.examples[0].input_data
example_soln_a, example_soln_b = puzzle.examples[0].answers
# Real puzzle input. NOTE: this module-level name shadows the builtin `input`.
input = puzzle.input_data

# Part A

In [None]:
def solution_a(input: str) -> int:
    """Return the lowest location number that any listed seed maps to.

    ``input`` is the almanac text: a ``seeds:`` line of integers followed by
    ``<source>-to-<destination> map:`` sections whose rule lines read
    ``destination_start source_start length``. Sections are applied in the
    order they appear. Within one section, the first rule whose half-open
    source interval ``[source_start, source_start + length)`` contains the
    current value shifts it by ``destination_start - source_start``. A value
    matched by no rule passes through unchanged.

    Raises:
        ValueError: if the almanac lists no seeds, or a rule line appears
            before any ``map:`` header.
    """
    seeds: list[int] = []
    # Dicts keep insertion order, so sections are chained in the order they
    # appear in the almanac text.
    mappings: dict[tuple[str, str], list[tuple[int, int, int]]] = {}
    rules = None
    # splitlines + strip tolerates "\r\n" endings and stray whitespace.
    for raw_line in input.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith("seeds:"):
            seeds = [int(token) for token in line.split(":", 1)[1].split()]
        elif line.endswith("map:"):
            source, _, destination = line.split()[0].split("-")
            rules = mappings[(source, destination)] = []
        elif line[0].isdigit():
            if rules is None:
                raise ValueError(f"map rule before any map header: {line!r}")
            destination_start, source_start, length = map(int, line.split())
            rules.append((source_start, destination_start, length))

    if not seeds:
        raise ValueError("almanac lists no seeds")

    def locate(value: int) -> int:
        """Push one value through every section, in order."""
        for section in mappings.values():
            for source_start, destination_start, length in section:
                # Half-open interval: source_start + length is NOT covered.
                if source_start <= value < source_start + length:
                    value += destination_start - source_start
                    break
        return value

    return min(locate(seed) for seed in seeds)

In [None]:
# Sanity check: our answer on the worked example vs. the answer aocd parsed.
print(f"Part A example solution: {solution_a(input=example_input)}")
print(f"Part A example answer: {example_soln_a}")

In [None]:
# Solve the real puzzle input (the module-level `input` bound in the setup
# cell) and submit it. year/day are read from the `puzzle` object rather than
# repeated as literals, so the submission always targets the fetched puzzle.
solution_a_output = solution_a(input=input)
print("Part A solution:", solution_a_output, "\n" + "-" * 60)
submit(solution_a_output, day=puzzle.day, year=puzzle.year, part="a")

# Part B

In [None]:
def solution_b(input: str) -> int:
    """Return the lowest location number for any seed in the seed ranges.

    Same almanac format as part A, except the ``seeds:`` line holds
    ``start length`` pairs. Each pair describes the half-open seed interval
    ``[start, start + length)``.

    Rather than enumerating individual seeds, whole intervals are pushed
    through each section:

    * an interval is split at each rule's source boundaries;
    * the piece overlapping the rule is shifted by
      ``destination_start - source_start``;
    * pieces no rule claims pass through unchanged.

    Rules are tried in listed order and each piece is claimed by at most one
    rule (first match wins, as in part A). The answer is the smallest lower
    bound among the final location intervals.

    Raises:
        ValueError: if the seeds line has an odd number of values, no
            non-empty seed range is given, or a rule line appears before any
            ``map:`` header.
    """
    seed_numbers: list[int] = []
    # Dicts keep insertion order, so sections are chained in the order they
    # appear in the almanac text.
    mappings: dict[tuple[str, str], list[tuple[int, int, int]]] = {}
    rules = None
    # splitlines + strip tolerates "\r\n" endings and stray whitespace.
    for raw_line in input.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith("seeds:"):
            seed_numbers = [int(token) for token in line.split(":", 1)[1].split()]
        elif line.endswith("map:"):
            source, _, destination = line.split()[0].split("-")
            rules = mappings[(source, destination)] = []
        elif line[0].isdigit():
            if rules is None:
                raise ValueError(f"map rule before any map header: {line!r}")
            destination_start, source_start, length = map(int, line.split())
            rules.append((source_start, destination_start, length))

    if len(seed_numbers) % 2:
        raise ValueError("seeds line must hold start/length pairs")
    # Half-open [start, start + length) intervals; empty ranges contribute nothing.
    intervals = [
        (start, start + length)
        for start, length in zip(seed_numbers[::2], seed_numbers[1::2])
        if length > 0
    ]
    if not intervals:
        raise ValueError("almanac lists no seed ranges")

    def apply_section(spans, section):
        """Map half-open spans through one section's rules, splitting as needed."""
        shifted = []  # pieces already moved by a rule of this section
        pending = spans  # pieces not yet claimed by any rule
        for source_start, destination_start, length in section:
            source_end = source_start + length
            offset = destination_start - source_start
            unclaimed = []
            for low, high in pending:
                # Piece lying left of the rule's source interval.
                if low < min(high, source_start):
                    unclaimed.append((low, min(high, source_start)))
                # Piece inside the rule's source interval is shifted.
                overlap_low = max(low, source_start)
                overlap_high = min(high, source_end)
                if overlap_low < overlap_high:
                    shifted.append((overlap_low + offset, overlap_high + offset))
                # Piece lying right of the rule's source interval.
                if max(low, source_end) < high:
                    unclaimed.append((max(low, source_end), high))
            pending = unclaimed
        # Whatever no rule claimed maps to itself.
        return shifted + pending

    for section in mappings.values():
        intervals = apply_section(intervals, section)

    # Each final interval's lower bound is the smallest location it contains.
    return min(low for low, _ in intervals)

In [None]:
# Sanity check: our answer on the worked example vs. the answer aocd parsed.
print(f"Part B example solution: {solution_b(input=example_input)}")
print(f"Part B example answer: {example_soln_b}")

In [None]:
# Solve the real puzzle input (the module-level `input` bound in the setup
# cell) and submit it. year/day are read from the `puzzle` object rather than
# repeated as literals, so the submission always targets the fetched puzzle.
solution_b_output = solution_b(input=input)
print("Part B solution:", solution_b_output, "\n" + "-" * 60)
submit(solution_b_output, day=puzzle.day, year=puzzle.year, part="b")