In [1]:
from aocd.models import Puzzle

puzzle = Puzzle(year=2024, day=16)

def parses(data):
    lines = data.strip().split('\n')
    walls = set()
    start = None
    end = None
    for i, row in enumerate(lines):
        for j, v in enumerate(row):
            z = (i, j)
            if v == 'S':
                assert start is None
                start = z
            if v == 'E':
                assert end is None
                end = z
            if v == '#':
                walls.add(z)
    return walls, start, end

# import re
# def parses(data):
#     return [[int(i) for i in re.findall("-?\d+", line)] 
#              for line in data.strip().split('\n')]

data = parses(puzzle.input_data)

In [2]:
sample = parses("""###############
#.......#....E#
#.#.###.#.###.#
#.....#.#...#.#
#.###.#####.#.#
#.#.#.......#.#
#.#.#####.###.#
#...........#.#
###.#.#####.#.#
#...#.....#.#.#
#.#.#.###.#.#.#
#.....#...#.#.#
#.###.#.#.#.#.#
#S..#.....#...#
###############""")

In [3]:
sample2 = parses("""#################
#...#...#...#..E#
#.#.#.#.#.#.#.#.#
#.#.#.#...#...#.#
#.#.#.#.###.#.#.#
#...#.#.#.....#.#
#.#.#.#.#.#####.#
#.#...#.#.#.....#
#.#.#####.#.###.#
#.#.#.......#...#
#.#.###.#####.###
#.#.#...#.....#.#
#.#.#.#####.###.#
#.#.#.........#.#
#.#.#.#########.#
#S#.............#
#################""")

In [4]:
from heapq import heappush, heappop

def solve_a(data):
    walls, start, end = data
    heap = [(0, start, (0,1))]
    visited = {}

    while heap:
        cost, pos, dir = heappop(heap)
        if (pos, dir) in visited:
            continue
            
        if pos == end:
            return cost
        
        visited[pos, dir] = cost
        
        nextpos = (pos[0]+dir[0], pos[1]+dir[1])
        if nextpos not in walls and (nextpos, dir) not in visited:
            heappush(heap, (cost+1, nextpos, dir))
            
        dx, dy = dir
        for nextdir in [(dy, -dx), (-dy, dx)]:
            if (pos, nextdir) not in visited:
                heappush(heap, (cost+1000, pos, nextdir))

In [452]:
from heapq import heappush, heappop

def solve_b_eq(data):
    walls, start, end = data
    heap_start = [(0, start, (0,1))]
    heap_end = [(0, end, d ) for d in [(1,0),(-1,0),(0,1),(0,-1)]]
    cost_start, cost_end = {}, {}

    # run twice, from start and from end
    for heap, cost_from, target in [(heap_start, cost_start, end), (heap_end, cost_end, start)]:
        while heap:
            cost, pos, dir = heappop(heap)
            if (pos, dir) in cost_from:
                continue

            cost_from[pos, dir] = cost
            
            if pos == target:
                continue
#             print(locals())
            nextpos = (pos[0]+dir[0], pos[1]+dir[1])
            if nextpos not in walls and (nextpos, dir) not in cost_from:
                heappush(heap, (cost+1, nextpos, dir))

            dx, dy = dir
            for nextdir in [(dy, -dx), (-dy, dx)]:
                if (pos, nextdir) not in cost_from:
                    heappush(heap, (cost+1000, pos, nextdir))
    
    best_cost = min([cost_start[end, d] for d in [(1,0),(-1,0),(0,1),(0,-1)] if (end,d) in cost_start])
    best_nodes = set()
    for node, (dx,dy) in cost_start:
        if (node, (-dx,-dy)) in cost_end:
            total_cost = cost_start[node,(dx,dy)]+cost_end[node,(-dx,-dy)]
            if best_cost == total_cost:
                best_nodes.add(node)
    return len(best_nodes)

In [467]:
cost_start, cost_end, end = solve_b_eq(data)

518

In [456]:
cost_start[(13,2),(-1,0)]

1001

In [458]:
# cost_end

In [460]:
cost_end[(13,2),(1,0)]

KeyError: ((13, 2), (-1, 0))

In [462]:
len(best_nodes)

45

In [6]:
assert solve_a(sample) == 7036
assert solve_a(sample2) == 11048
assert solve_a(data) == 85480

In [7]:
def solve_b(data):
    walls, start, end = data
    start_dir = (0, 1)
    heap = [(0, start, start_dir, None, None)]
    visited = {}
    preds = {}

    ends, end_cost = [], None

    # Dijkstra's algorithm but we keep track of all paths to the end
    # by storing for each state all the previous states that lead to it
    # with the same cost

    # Here the state is defined by the position and the direction

    while heap:
        cost, pos, dir, prevpos, prevdir = heappop(heap)

        if (pos, dir) not in visited:
            preds[pos, dir] = [(prevpos, prevdir)]
        elif visited[pos, dir] == cost:
            preds[pos, dir].append((prevpos, prevdir))

        if (pos, dir) in visited:
            continue

        visited[pos, dir] = cost

        if pos == end:
            if end_cost is None:
                end_cost = cost
            if cost == end_cost:
                ends.append((pos, dir))
            continue

        nextpos = (pos[0] + dir[0], pos[1] + dir[1])
        if nextpos not in walls and (nextpos, dir) not in visited:
            heappush(heap, (cost + 1, nextpos, dir, pos, dir))

        dx, dy = dir
        for nextdir in [(dy, -dx), (-dy, dx)]:
            if (pos, nextdir) not in visited:
                heappush(heap, (cost + 1000, pos, nextdir, pos, dir))

    preds[start, start_dir] = []

    # We do backwards DFS from the end to the start using the predecessors map
    stack = ends
    unique = set([(end, None)])
    while stack:
        node = stack.pop()
        for pred in preds[node]:
            if pred not in unique:
                unique.add(pred)
                stack.append(pred)
    # We count the number of unique positions, ignoring the direction
    unique_pos = set([n for n, _ in unique])
    return len(unique_pos)

In [8]:
assert solve_b(sample) == 45
assert solve_b(sample2) == 64
assert solve_b(data) == 518

In [9]:
sample_evil = parses("""###########################
#######################..E#
######################..#.#
#####################..##.#
####################..###.#
###################..##...#
##################..###.###
#################..####...#
################..#######.#
###############..##.......#
##############..###.#######
#############..####.......#
############..###########.#
###########..##...........#
##########..###.###########
#########..####...........#
########..###############.#
#######..##...............#
######..###.###############
#####..####...............#
####..###################.#
###..##...................#
##..###.###################
#..####...................#
#.#######################.#
#S........................#
###########################""")

In [11]:
assert solve_a(sample_evil) == 21148
assert solve_b(sample_evil) == 149

In [13]:
sample_evil2 = parses("""########################################################
#.........#.........#.........#.........#.........#...E#
#.........#.........#.........#.........#.........#....#
#....#....#....#....#....#....#....#....#....#....#....#
#....#....#....#....#....#....#....#....#....#....#....#
#....#....#....#....#....#....#....#....#....#....#....#
#....#....#....#....#....#....#....#....#....#....#....#
#....#.........#.........#.........#.........#.........#
#S...#.........#.........#.........#.........#.........#
########################################################""")

In [14]:
assert solve_a(sample_evil2) == 21110
assert solve_b(sample_evil2) == 264

# Networkx  for Partb

In [None]:
import networkx as nx
from aocd.models import Puzzle

def parses(data):
    lines = data.strip().split('\n')
    walls = set()
    start = None
    end = None
    G = nx.DiGraph()
    
    # First pass: identify special positions
    for i, row in enumerate(lines):
        for j, v in enumerate(row):
            if v == '#':
                walls.add((i,j))
            elif v == 'S':
                assert start is None
                start = i + 1j*j
            elif v == 'E':
                assert end is None
                end = i + 1j*j
    
    # Second pass: build graph
    for i, row in enumerate(lines):
        for j, v in enumerate(row):
            if (i,j) not in walls:  # Only process non-wall positions
                z = i + 1j*j
                for orientation in (1, -1, 1j, -1j):
                    pos_orient = (z, orientation)
                    G.add_node(pos_orient)
                    
                    # Add turning edges
                    G.add_edge((z, orientation), (z, orientation*1j), weight=1000)
                    G.add_edge((z, orientation), (z, orientation*-1j), weight=1000)
                    
                    # Add forward movement edges
                    next_pos = z + orientation
                    if 0 <= next_pos.real < len(lines) and 0 <= next_pos.imag < len(lines[0]):
                        if (int(next_pos.real), int(next_pos.imag)) not in walls:
                            G.add_edge(pos_orient, (next_pos, orientation), weight=1)
    
    # Add final node and its connections
    G.add_node(end)
    for orientation in (1, -1, 1j, -1j):
        G.add_edge((end, orientation), end, weight=0)
    
    return G, start, end

# Test the implementation
sample = parses("""###############
#.......#....E#
#.#.###.#.###.#
#.....#.#...#.#
#.###.#####.#.#
#.#.#.......#.#
#.#.#####.###.#
#...........#.#
###.#.#####.#.#
#...#.....#.#.#
#.#.#.###.#.#.#
#.....#...#.#.#
#.###.#.#.#.#.#
#S..#.....#...#
###############""")

def solve_b(data):
    G, start, end = data

    # Find all shortest paths from start (with initial orientation right)
    paths = list(nx.all_shortest_paths(G, source=(start, 1j), target=end, weight='weight'))
    allps = set()
    for path in paths:
        allps |= set([i for i, _ in path[:-1]])
    return len(allps)

In [None]:
solve_b(sample)

In [None]:
data = parses(puzzle.input_data)
solve_b(data)

In [None]:
# Previous attempts

In [373]:
from aocd.models import Puzzle

puzzle = Puzzle(year=2024, day=16)

def parses(data):
    lines = data.strip().split('\n')
#     print(len(lines), len(lines[0]))
    walls = set()
    start = None
    end = None
    for i, row in enumerate(lines):
        for j, v in enumerate(row):
            z = i+1j*j
            if v == 'S':
                assert start is None
                start = z
            if v == 'E':
                assert end is None
                end = z
            if v == '#':
                walls.add(z)
    return walls, start, end

# import re
# def parses(data):
#     return [[int(i) for i in re.findall("-?\d+", line)] 
#              for line in data.strip().split('\n')]

data = parses(puzzle.input_data)

In [391]:
sample = parses("""###############
#.......#....E#
#.#.###.#.###.#
#.....#.#...#.#
#.###.#####.#.#
#.#.#.......#.#
#.#.#####.###.#
#...........#.#
###.#.#####.#.#
#...#.....#.#.#
#.#.#.###.#.#.#
#.....#...#.#.#
#.###.#.#.#.#.#
#S..#.....#...#
###############""")

In [None]:
def solve_a(data):
    walls, start, end = data
    start_dir = 1j
    
    heap = ()

In [269]:
def heuristic(a, b):
    return abs(a.real-b.real) + abs(a.imag-b.imag)

from heapq import heappush, heappop

def solve_a(data):
    walls, start, end = data
    start_dir = 1j
    stack = []
    
    visited = set()
    i = 0
    heap = [(heuristic(start, end), 0, None, start, start_dir)]
#     print(end)
    while heap:
#         if i % 1000:
#             print(len(heap), len(visited))
#         print(heap)
#         print(heap)
        _, cost, _, x, v = heappop(heap)
#         render(walls, x, v)
        if x == end:
            return cost
        x2 = x + v
        if (x2, v) not in visited and x2 not in walls:
        
            visited.add((x2, v))
            i += 1
            heappush(heap,(
                cost+1+heuristic(x2,end), cost+1, i, x2, v
            ))
        
        for rot in (1j, -1j):
            v2 = v*rot
            if (x, v2) not in visited:
                visited.add((x,v2))
                i += 1
                heappush(heap, (
                    cost+1000+heuristic(x, end), cost+1000, i, x, v2
                ))

In [270]:
sample = parses("""#####
#  E#
# # #
#S  #
#####""")

5 5


In [267]:
def heuristic(a, b):
    return abs(a.real-b.real) + abs(a.imag-b.imag)

from heapq import heappush, heappop

def solve_b(data):
    walls, start, end = data
    start_dir = 1j
    stack = []
    
    visited = set()
    i = 0
    heap = [(heuristic(start, end), 0, None, start, start_dir, [(start, start_dir)])]
#     print(end)
    mincost = None
    best_cost = {(start, start_dir): 0}
    while heap:
#         if i % 1000:
#             print(len(heap), len(visited))
#         print(heap)
#         print(heap)
        _, cost, _, x, v, path = heappop(heap)
        if (x,v) not in best_cost:
            best_cost[x,v] = cost
#         render(walls, x, v)
        if x == end:
#             best_cost[x,v]
            best_path = path
            break

        x2 = x + v
        if (x2, v) not in visited and x2 not in walls:
        
            visited.add((x2, v))

            i += 1
            heappush(heap,(
                cost+1+heuristic(x2,end), cost+1, i, x2, v, path + [(x2, v)]
            ))
        
        for rot in (1j, -1j):
            v2 = v*rot
            if (x, v2) not in visited:
                visited.add((x,v2))
                i += 1
                heappush(heap, (
                    cost+1000+heuristic(x, end), cost+1000, i, x, v2, path + [(x, v)]
                ))
#     return best_path
    best_path = set(best_path)
    best_cost = {k: v for k, v in best_cost.items() if k in best_path} 
#     print(start, start_dir)
#     return best_cost

    visited = set()
    i = 0
    heap = [(heuristic(start, end), 0, None, start, start_dir, [(start, start_dir)])]
#     print(end)
    mincost = None
#     best_cost = {(start, start_dir): 0}
    print(best_path)
    allps = set()
    while heap:
#         if i % 1000:
#             print(len(heap), len(visited))
#         print(heap)
#         print(heap)
        _, cost, _, x, v, path = heappop(heap)
        if (x,v) in best_cost and best_cost[x,v] == cost:
            allps |= set(path)
#         if (x,v) not in best_cost:
#             best_cost[x,v] = cost
#         render(walls, x, v)
        if x == end:
            best_path = path
            break

        x2 = x + v
        if (x2, v) not in visited and x2 not in walls:
        
            visited.add((x2, v))

            i += 1
            heappush(heap,(
                cost+1+heuristic(x2,end), cost+1, i, x2, v, path + [(x2, v)]
            ))
        
        for rot in (1j, -1j):
            v2 = v*rot
            if (x, v2) not in visited:
                visited.add((x,v2))
                i += 1
                heappush(heap, (
                    cost+1000+heuristic(x, end), cost+1000, i, x, v2, path + [(x, v)]
                ))
    return len(set([x for x, _ in allps]))

In [268]:
sample = parses("""#####
#  E#
# # #
#S  #
#####""")

5 5


In [264]:
solve_b(sample)

5

In [228]:
solve_b(sample2)

49

In [221]:
solve_a(data)

85480

In [97]:
solve_a(sample2)

1
2
4
5
7
9
11
13
15
17
19
21
23
25
27
29
31
32
33
34
35
36
36
36
36
36
36
36
36
36
36
36
36
36
36
36
36
36
36
36
36
37
37
37
39
39
39
40
40
40
40
41
43
44
45
46
47
48
49
50
51
52
54
55
57
58
60
61
63
64
66
67
69
70
72
73
75
76
78
79
81
82
84
85
85
85
87
88
89
90
92
93
95
97
98
100
100
100
102
102
103
103
104
105
106
107
108
109
110
111
112
113
114
115
117
118
119
121
123
125
127
129
131
131
133
135
137
139
141
142
142
142
142
142
142
142
142
142
142
142
142
142
142
142
142
142
141
140
139
138
137
136
135
134
133
132
131
130
129
128
127
126
125
124
123
122
121
120
119
118
117
116
115
114
113
112
111
110
109
108
107
106
105
104
103
102
101
100
99
98
100
101
100
99
98
97
98
98
97
96
95
94
93
92
92
92
92
92
91
90
92
93
92
91
92
92
92
92
92
91
90
89
88
87
86
85
84
83
82
81
80
79
78
77
76
75
74
73
72
71
70
69
69
69
70
69
68
67
66
66
65
64
65
66
66
68
69
68
67
67
67
66
65
65
65
65
65
65
65
65
65
65
65
65
65
65
66
66
68
70
72
74
76
78
80
82
84
86
88
89
90
91
92
93
94
95
96
97
98
99
100
101
10

11048

In [124]:
solve_b(sample2)

- [(15+1j), (14+1j), (13+1j), (12+1j), (11+1j), (10+1j), (9+1j), (8+1j), (7+1j), (6+1j), (5+1j), (5+2j), (5+3j), (6+3j), (7+3j), (8+3j), (9+3j), (10+3j), (11+3j), (12+3j), (13+3j), (14+3j), (15+3j), (15+4j), (15+5j), (14+5j), (13+5j), (12+5j), (11+5j), (11+6j), (11+7j), (10+7j), (9+7j), (9+8j), (9+9j), (9+10j), (9+11j), (8+11j), (7+11j), (7+12j), (7+13j), (7+14j), (7+15j), (6+15j), (5+15j), (4+15j), (3+15j), (2+15j), (1+15j)]
- [(15+1j), (14+1j), (13+1j), (12+1j), (11+1j), (10+1j), (9+1j), (8+1j), (7+1j), (6+1j), (5+1j), (5+2j), (5+3j), (6+3j), (7+3j), (8+3j), (9+3j), (10+3j), (11+3j), (12+3j), (13+3j), (14+3j), (15+3j), (15+4j), (15+5j), (14+5j), (13+5j), (12+5j), (11+5j), (11+6j), (11+7j), (10+7j), (9+7j), (9+8j), (9+9j), (8+9j), (7+9j), (6+9j), (5+9j), (5+10j), (5+11j), (5+12j), (5+13j), (4+13j), (3+13j), (2+13j), (1+13j), (1+14j), (1+15j)]


62

In [None]:
# def solve_b(data):
#     walls, start, end = data
    
#     pos = {(start, 1j): 0}
#     visited = start
#     while True:

In [134]:
from toolz import valmap

In [325]:
def heuristic(a, b):
    return abs(a.real-b.real) + abs(a.imag-b.imag)

from heapq import heappush, heappop

def solve_b_1(data):
    walls, start, end = data
    start_dir = 1j
    stack = []
    
    visited = set()
    i = 0
    heap = [(heuristic(start, end), 0, None, start, start_dir)]
    
    best_cost = {}
    arrived_from = defaultdict(lambda: defaultdict(lambda: float('inf')))
#     print(end)
    while heap:
#         if i % 1000:
#             print(len(heap), len(visited))
#         print(heap)
#         print(heap)
        _, cost, _, x, v = heappop(heap)
#         render(walls, x, v)
        if x == end:
#             return dict(valmap(dict, arrived_from)), end
            new = {}
            for src, d in arrived_from.items():
                d2 = {k:v for k, v in d.items() if v != float('inf')}
                if len(d2) > 0:
                    new[src] = d2
            return new, end
            
            
        x2 = x + v
        if x2 not in walls:
            arrived_from[x2,v][x,v] = min(arrived_from[x2][x], cost+1)
        if (x2, v) not in visited and x2 not in walls:
            visited.add((x2, v))
            i += 1
            heappush(heap,(
                cost+1+heuristic(x2,end), cost+1, i, x2, v
            ))
        
        for rot in (1j, -1j):
            v2 = v*rot
            arrived_from[x,v2][x,v] = min(arrived_from[x2][x], cost+1)
            if (x, v2) not in visited:
                visited.add((x,v2))
                i += 1
                heappush(heap, (
                    cost+1000+heuristic(x, end), cost+1000, i, x, v2
                ))
                

In [326]:
def solve_b_2(data):
    arrived, end = data
    arrived2 = {}
    for dst in arrived:
        for src, val in arrived[dst].items():
            arrived2.setdefault(dst[0], {})
            arrived2[dst[0]][src[0]] = val % 1000
    arrived = arrived2

    visited = set()
    stack = [end]
    while stack:
        dst = stack.pop()
        if dst in arrived:
            m = min(arrived[dst].values())
            for src in [k for k, v in arrived[dst].items() if v == m]:
                if src not in visited:
                    visited.add(src)
                    stack.append(src)
    return len(visited)+1

In [327]:
def solve_b(data):
    return solve_b_2(solve_b_1(data))
    

In [332]:
solve_b(data)

589

In [330]:
solve_b(sample)

45

In [331]:
solve_b(sample2)

64

In [None]:
def solve_b(data):
    walls, start, end = data
    start_dir = 1j
    
    states = {0: (start, start_dir), set([start])}
    visited = set()
    while states:
        new_states = {}
        for state in states.items()
        
        

In [333]:
import networkx

In [305]:
arrived, end = solve_b(sample)


In [307]:
arrived[end, -1]

{((2+13j), (-1+0j)): 7036}

44

In [296]:
arrived

{((13+2j), 1j): {((13+1j), 1j): 1,
  ((13+2j), (-1+0j)): 2,
  ((13+2j), (1-0j)): 2},
 ((13+1j), (-1+0j)): {((13+1j), 1j): 1, ((13+1j), (-0-1j)): 1},
 ((13+1j), (1-0j)): {((13+1j), 1j): 1,
  ((13+1j), (-0-1j)): 1,
  ((12+1j), (1-0j)): 2},
 ((13+3j), 1j): {((13+2j), 1j): 2,
  ((13+3j), (-1+0j)): 3,
  ((13+3j), (1-0j)): 3},
 ((13+2j), (-1+0j)): {((13+2j), 1j): 2, ((13+2j), (-0-1j)): 2},
 ((13+2j), (1-0j)): {((13+2j), 1j): 2, ((13+2j), (-0-1j)): 2},
 ((13+3j), (-1+0j)): {((13+3j), 1j): 3, ((13+3j), (-0-1j)): 3},
 ((13+3j), (1-0j)): {((13+3j), 1j): 3, ((13+3j), (-0-1j)): 3},
 ((12+1j), (-1+0j)): {((13+1j), (-1+0j)): 1,
  ((12+1j), (-0-1j)): 2,
  ((12+1j), 1j): 2},
 ((13+1j), (-0-1j)): {((13+1j), (-1+0j)): 1,
  ((13+1j), (1-0j)): 1,
  ((13+2j), (-0-1j)): 2},
 ((13+1j), 1j): {((13+1j), (-1+0j)): 1, ((13+1j), (1-0j)): 1},
 ((13+2j), (-0-1j)): {((13+2j), (-1+0j)): 2,
  ((13+2j), (1-0j)): 2,
  ((13+3j), (-0-1j)): 3},
 ((11+1j), (-1+0j)): {((12+1j), (-1+0j)): 2,
  ((11+1j), (-0-1j)): 3,
  ((11+1j

In [292]:
arrived2

{((13+2j), 1j): {((13+1j), 1j): 1,
  ((13+2j), (-1+0j)): 2,
  ((13+2j), (1-0j)): 2},
 ((13+1j), (-1+0j)): {((13+1j), 1j): 1, ((13+1j), (-0-1j)): 1},
 ((13+1j), (1-0j)): {((13+1j), 1j): 1,
  ((13+1j), (-0-1j)): 1,
  ((12+1j), (1-0j)): 2},
 ((13+3j), 1j): {((13+2j), 1j): 2,
  ((13+3j), (-1+0j)): 3,
  ((13+3j), (1-0j)): 3},
 ((13+2j), (-1+0j)): {((13+2j), 1j): 2, ((13+2j), (-0-1j)): 2},
 ((13+2j), (1-0j)): {((13+2j), 1j): 2, ((13+2j), (-0-1j)): 2},
 ((13+3j), (-1+0j)): {((13+3j), 1j): 3, ((13+3j), (-0-1j)): 3},
 ((13+3j), (1-0j)): {((13+3j), 1j): 3, ((13+3j), (-0-1j)): 3},
 ((12+1j), (-1+0j)): {((13+1j), (-1+0j)): 1,
  ((12+1j), (-0-1j)): 2,
  ((12+1j), 1j): 2},
 ((13+1j), (-0-1j)): {((13+1j), (-1+0j)): 1,
  ((13+1j), (1-0j)): 1,
  ((13+2j), (-0-1j)): 2},
 ((13+1j), 1j): {((13+1j), (-1+0j)): 1, ((13+1j), (1-0j)): 1},
 ((13+2j), (-0-1j)): {((13+2j), (-1+0j)): 2,
  ((13+2j), (1-0j)): 2,
  ((13+3j), (-0-1j)): 3},
 ((11+1j), (-1+0j)): {((12+1j), (-1+0j)): 2,
  ((11+1j), (-0-1j)): 3,
  ((11+1j

In [183]:
def heuristic(a, b):
    return abs(a.real-b.real) + abs(a.imag-b.imag)

from heapq import heappush, heappop

def solve_b(data):
    walls, start, end = data
    start_dir = 1j
    stack = []
    
    visited = set()
    i = 0
    heap = [(heuristic(start, end), 0, None, start, start_dir)]
    
    best_cost = {}
    arrived_from = defaultdict(lambda: defaultdict(lambda: float('inf')))
#     print(end)
    while heap:
#         if i % 1000:
#             print(len(heap), len(visited))
#         print(heap)
#         print(heap)
        _, cost, _, x, v = heappop(heap)
#         render(walls, x, v)
        if x == end:
#             return dict(valmap(dict, arrived_from)), end
            new = {}
            for src, d in arrived_from.items():
                d2 = {k:v for k, v in d.items() if v != float('inf')}
                if len(d2) > 0:
                    new[src] = d2
            return new, e
            
            
        x2 = x + v
        if x2 not in walls:
            arrived_from[x2,v][x,v] = min(arrived_from[x2][x], cost+1)
        if (x2, v) not in visited and x2 not in walls:
            visited.add((x2, v))
            i += 1
            heappush(heap,(
                cost+1+heuristic(x2,end), cost+1, i, x2, v
            ))
        
        for rot in (1j, -1j):
            v2 = v*rot
            arrived_from[x,v2][x,v] = min(arrived_from[x2][x], cost+1)
            if (x, v2) not in visited:
                visited.add((x,v2))
                i += 1
                heappush(heap, (
                    cost+1000+heuristic(x, end), cost+1000, i, x, v2
                ))
                

In [212]:
arrived, e = solve_b(sample)

In [213]:
arrived

{((13+2j), 1j): {((13+1j), 1j): 2001,
  ((13+2j), (-1+0j)): 1002,
  ((13+2j), (1-0j)): 1002},
 ((13+1j), (-1+0j)): {((13+1j), 1j): 2001, ((13+1j), (-0-1j)): 2001},
 ((13+1j), (1-0j)): {((13+1j), 1j): 2001,
  ((13+1j), (-0-1j)): 2001,
  ((12+1j), (1-0j)): 3002},
 ((13+3j), 1j): {((13+2j), 1j): 2,
  ((13+3j), (-1+0j)): 1003,
  ((13+3j), (1-0j)): 1003},
 ((13+2j), (-1+0j)): {((13+2j), 1j): 2, ((13+2j), (-0-1j)): 2002},
 ((13+2j), (1-0j)): {((13+2j), 1j): 2, ((13+2j), (-0-1j)): 2002},
 ((13+3j), (-1+0j)): {((13+3j), 1j): 3, ((13+3j), (-0-1j)): 2003},
 ((13+3j), (1-0j)): {((13+3j), 1j): 3, ((13+3j), (-0-1j)): 2003},
 ((12+1j), (-1+0j)): {((13+1j), (-1+0j)): 1001,
  ((12+1j), (-0-1j)): 2002,
  ((12+1j), 1j): 2002},
 ((13+1j), (-0-1j)): {((13+1j), (-1+0j)): 1001,
  ((13+1j), (1-0j)): 1001,
  ((13+2j), (-0-1j)): 2002},
 ((13+1j), 1j): {((13+1j), (-1+0j)): 1001, ((13+1j), (1-0j)): 1001},
 ((13+2j), (-0-1j)): {((13+2j), (-1+0j)): 1002,
  ((13+2j), (1-0j)): 1002,
  ((13+3j), (-0-1j)): 2003},
 ((1

In [214]:
visited = set()
stack = [(e,1),(e,-1),(e,1j),(e,-1)]
while stack:
    pos = stack.pop()
    if pos not in arrived:
        print(pos)
        continue
    m = min(arrived[pos].values())
    prev = {k: v for k, v in arrived[pos].items() if v%1000 == m%1000}
#     print(len(prev))
    for p in prev:
        if p not in visited:
            visited.add(p)
            stack.append(p)

((1+13j), 1j)
((1+13j), 1)


In [215]:
len(set([x for x, _ in visited]))

36

In [216]:
len(visited)

55

In [158]:
len(visited)

74

In [137]:
af

{(13+2j): {(13+1j): 1, (13+3j): 2003},
 (13+3j): {(13+2j): 2},
 (12+1j): {(13+1j): 1001, (11+1j): 3003},
 (11+1j): {(12+1j): 1002, (10+1j): 3004, (11+2j): 4004},
 (10+1j): {(11+1j): 1003, (9+1j): 3005},
 (9+1j): {(10+1j): 1004, (9+2j): 4006},
 (13+1j): {(13+2j): 2002, (12+1j): 3002},
 (11+2j): {(11+1j): 2003, (11+3j): 4005},
 (11+3j): {(11+2j): 2004, (10+3j): 3008, (11+4j): 4006},
 (9+2j): {(9+1j): 2005, (9+3j): 4007},
 (11+4j): {(11+3j): 2005, (11+5j): 4007},
 (9+3j): {(9+2j): 2006, (10+3j): 3006, (8+3j): 5008},
 (11+5j): {(11+4j): 2006, (10+5j): 5008, (12+5j): 5008},
 (10+3j): {(11+3j): 3005, (9+3j): 3007},
 (8+3j): {(9+3j): 3007, (7+3j): 5009},
 (10+5j): {(11+5j): 3007, (9+5j): 5009},
 (12+5j): {(11+5j): 3007, (13+5j): 5009},
 (7+3j): {(8+3j): 3008, (7+4j): 4012, (6+3j): 5010, (7+2j): 6010},
 (9+5j): {(10+5j): 3008, (8+5j): 5010, (9+6j): 6010},
 (6+3j): {(7+3j): 3009, (5+3j): 5011},
 (8+5j): {(9+5j): 3009, (7+5j): 5011},
 (5+3j): {(6+3j): 3010},
 (7+5j): {(8+5j): 3010, (7+4j): 4010,

In [None]:
    qqaa

In [92]:
solve_a(data)

KeyboardInterrupt: 

In [74]:
def render(walls, pos, ori):
    print('🔥', {1:'v', -1: '^', 1j: '>', -1j: '<'}[ori])
    N = int(max(i.real for i in walls))+1
    M = int(max(i.imag for i in walls))+1
    s = ''
    for i in range(N):
        for j in range(M):
            z = i+1j*j
            if z in walls:
                s += '#'
            elif z == pos:
                s += '\033[91m@\033[0m'
            else:
                s += ' '
        s += '\n'
    print(s)

In [40]:
render(sample[0], None)

(14, 14)

In [34]:
solve_a(data)

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



85395

In [31]:
sample2 = parses("""#################
#...#...#...#..E#
#.#.#.#.#.#.#.#.#
#.#.#.#...#...#.#
#.#.#.#.###.#.#.#
#...#.#.#.....#.#
#.#.#.#.#.#####.#
#.#...#.#.#.....#
#.#.#####.#.###.#
#.#.#.......#...#
#.#.###.#####.###
#.#.#...#.....#.#
#.#.#.#####.###.#
#.#.#.........#.#
#.#.#.#########.#
#S#.............#
#################""")

In [32]:
solve_a(sample2)

In [None]:
solve_a(sample)

In [None]:
solve_a(data)

In [None]:
def solve_b(data):
    pass

In [None]:
solve_b(sample)

In [None]:
solve_b(data)