# day 17

https://adventofcode.com/2023/day/17

In [None]:
import logging
import logging.config
import os
import random

import yaml

In [None]:
# Load the shared logging configuration (path is relative to the notebook's
# working directory) and apply it with logging.config.dictConfig.
with open('../logging.yaml') as fp:
    # NOTE(review): yaml.safe_load would likely suffice unless the config
    # relies on Python-specific YAML tags — confirm before switching.
    logging_config = yaml.load(fp, Loader=yaml.FullLoader)

logging.config.dictConfig(logging_config)

In [None]:
# Default puzzle-input path, relative to the working directory.
FNAME = os.path.join('data', 'day17.txt')

# Notebook logger; the test_q_* helpers temporarily set it to DEBUG.
LOGGER = logging.getLogger('day17')

## part 1

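### problem statement:

Find the minimum total heat loss (the sum of the digits of every cell entered, not counting the starting cell) for a path from the top-left cell to the bottom-right cell. The crucible can move at most three cells in a straight line before it must turn 90 degrees left or right; it can never reverse.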

#### loading data

In [None]:
# Example grid used by the tests: test_q_1 expects 102 and test_q_2 expects 94.
test_data = """2413432311323
3215453535623
3255245654254
3446585845452
4546657867536
1438598798454
4457876987766
3637877979653
4654967986887
4564679986453
1224686865563
2546548887735
4322674655533"""

In [None]:
def load_data(fname=FNAME):
    """Read ``fname`` and return its text with leading/trailing whitespace removed."""
    with open(fname) as handle:
        contents = handle.read()
    return contents.strip()

#### function def

In [None]:
from enum import Enum


class Direction(str, Enum):
    UP = "U"
    DOWN = "D"
    LEFT = "L"
    RIGHT = "R"


def dir_to_delta(direction: Direction) -> complex:
    match direction:
        case Direction.UP:
            return -1
        case Direction.DOWN:
            return 1
        case Direction.LEFT:
            return -1j
        case Direction.RIGHT:
            return 1j


def parse_data(data):
    """Parse a grid of digits into ``{row + col * 1j: digit}``.

    The input is stripped, split on newlines, and each row is stripped
    before its characters are converted with ``int``.
    """
    grid = {}
    for row, text in enumerate(data.strip().split('\n')):
        for col, digit in enumerate(text.strip()):
            grid[row + col * 1j] = int(digit)
    return grid


# Quick sanity check on the example grid: origin and bottom-right corner.
start = 0j
m = parse_data(data=test_data)
end = max(m, key=abs)  # farthest key from the origin in a rectangular grid
end

In [None]:
import heapq
from dataclasses import dataclass


@dataclass
class DijkstraWalker:
    """A search state on the Dijkstra frontier, ordered by accumulated cost.

    Ties on cost are broken by location and then by direction, so heap
    ordering is deterministic.
    """

    cost: int  # total cost accumulated to reach ``loc``
    loc: complex  # grid position encoded as ``row + col * 1j``
    direction: Direction | None = None  # last direction moved; None at the start

    @property
    def sortable_repr(self) -> tuple:
        """Ordering key: (cost, row, col, direction).

        A ``None`` direction is mapped to ``""`` so that a start walker can be
        compared with a directed walker without raising ``TypeError``.
        """
        direction_key = "" if self.direction is None else self.direction
        return (self.cost, self.loc.real, self.loc.imag, direction_key)

    def __lt__(self, other: 'DijkstraWalker') -> bool:
        return self.sortable_repr < other.sortable_repr


def dir_to_turn_dirs(direction: Direction | None) -> list[Direction]:
    match direction:
        case Direction.UP | Direction.DOWN:
            return [Direction.LEFT, Direction.RIGHT]
        case Direction.LEFT | Direction.RIGHT:
            return [Direction.UP, Direction.DOWN]
        case None:
            return [Direction.LEFT, Direction.RIGHT, Direction.UP, Direction.DOWN]
        case _:
            raise ValueError()


def find_shortest_path_len(m: dict[complex, int], start: complex, end: complex,
                           min_steps: int = 1, max_steps: int = 3) -> int | None:
    """Return the minimum cost of moving from ``start`` to ``end`` on grid ``m``.

    Runs Dijkstra's algorithm over states ``(location, last direction)``.
    Each expansion turns (or, from the start, picks any direction) and then
    travels between ``min_steps`` and ``max_steps`` cells in a straight line.
    Every pushed state therefore already satisfies the run-length
    constraints, so ``end`` can be accepted as soon as it is popped.

    Args:
        m: grid mapping ``row + col * 1j`` to the cost of entering that cell.
        start: starting location; its own cost is never added.
        end: target location.
        min_steps: minimum cells to move in a line before turning or stopping.
        max_steps: maximum cells to move in a line before turning.

    Returns:
        The minimum total cost, or ``None`` if ``end`` is unreachable.
    """
    # A single-element list is already a valid heap.
    explore_heap = [DijkstraWalker(cost=0, loc=start, direction=None)]
    seen: set[tuple[complex, Direction | None]] = set()

    while explore_heap:
        walker = heapq.heappop(explore_heap)

        if walker.loc == end:
            return walker.cost

        # Lazy deletion: a state popped a second time was already reached
        # more cheaply, so skip it.
        state = (walker.loc, walker.direction)
        if state in seen:
            continue
        seen.add(state)

        # Treat the current location as a turning point. For each allowed
        # new direction, walk up to max_steps cells and enqueue every
        # position reached after at least min_steps cells.
        for new_direction in dir_to_turn_dirs(walker.direction):
            delta = dir_to_delta(direction=new_direction)
            new_loc = walker.loc
            new_cost = walker.cost
            for step in range(1, max_steps + 1):
                new_loc += delta
                incurred_cost = m.get(new_loc)
                if incurred_cost is None:
                    break  # off the map: stop extending in this direction
                new_cost += incurred_cost
                if step >= min_steps:
                    heapq.heappush(
                        explore_heap,
                        DijkstraWalker(cost=new_cost, loc=new_loc,
                                       direction=new_direction),
                    )

    return None

In [None]:
def q_1(data: str, min_steps: int = 1, max_steps: int = 3):
    """Minimum path cost from the origin to the bottom-right corner of ``data``.

    ``min_steps``/``max_steps`` bound each straight-line run and are passed
    through to ``find_shortest_path_len`` (q_2 calls this with 4 and 10).
    """
    grid = parse_data(data)
    # In a rectangular grid, the key farthest from the origin is the bottom-right corner.
    target = max(grid, key=abs)
    return find_shortest_path_len(m=grid, start=0.0j, end=target,
                                  min_steps=min_steps, max_steps=max_steps)

#### tests

In [None]:
def test_q_1():
    """Check part 1 against the example grid (expected answer: 102).

    DEBUG logging is enabled only for the check. The logger is set back to
    INFO even if the assertion fails, so a failure does not leave it verbose.
    """
    LOGGER.setLevel(logging.DEBUG)
    try:
        result = q_1(test_data)
        assert result == 102, result
    finally:
        LOGGER.setLevel(logging.INFO)

In [None]:
# Run the part 1 example check.
test_q_1()

#### answer

In [None]:
# Part 1 answer for the puzzle input (default path: data/day17.txt).
q_1(load_data())

## part 2

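### problem statement:

The grid is the same, but the "ultra crucible" must move at least four cells in a straight line before it can turn or stop at the end, and it can move at most ten cells before it must turn.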

#### function def

In [None]:
def q_2(data):
    """Part 2: every straight run must be 4 to 10 cells long, including the final run into the end."""
    return q_1(data=data, max_steps=10, min_steps=4)

#### tests

In [None]:
# Second example for part 2; test_q_2 expects 71 for it.
test_data_2 = """111111111111
999999999991
999999999991
999999999991
999999999991"""

In [None]:
def test_q_2():
    """Check part 2 against both example grids (expected answers: 94 and 71).

    DEBUG logging is enabled only for the checks. The logger is set back to
    INFO even if an assertion fails. Each result is computed once and reused
    as the assertion message.
    """
    LOGGER.setLevel(logging.DEBUG)
    try:
        result = q_2(test_data)
        assert result == 94, result
        result_2 = q_2(test_data_2)
        assert result_2 == 71, result_2
    finally:
        LOGGER.setLevel(logging.INFO)

In [None]:
# Run the part 2 example checks.
test_q_2()

#### answer

In [None]:
# Part 2 answer for the puzzle input (default path: data/day17.txt).
q_2(load_data())

fin