In [1]:
from aocd.models import Puzzle

puzzle = Puzzle(year=2021, day=23)

def parses(text):
    """Return the puzzle input exactly as given; no preprocessing is applied here."""
    raw_diagram = text
    return raw_diagram

data = parses(puzzle.input_data)

In [2]:
data.index("A")

47

In [3]:
sample = parses("""#############
#...........#
###B#C#B#D###
  #A#D#C#A#
  #########""")

In [4]:
data

'#############\n#...........#\n###D#C#D#B###\n  #B#A#A#C#\n  #########'

In [5]:
# Energy per step for each amphipod kind: successive powers of ten for A, B, C, D.
energies = dict(zip("ABCD", (10 ** power for power in range(4))))

In [6]:
energies

{'A': 1, 'B': 10, 'C': 100, 'D': 1000}

In [7]:
def sign(x):
    """Return 1 when x is positive, -1 when x is negative, and 0 otherwise."""
    # Each comparison yields a bool; their difference is the integer sign.
    return (x > 0) - (x < 0)

In [32]:
from functools import cached_property

In [86]:
from dataclasses import dataclass
from typing import Dict, Tuple
import copy

@dataclass(eq=False)
class State:
    """One node of the search over burrow configurations.

    Attributes:
        positions: maps (x, y) to the letter of every amphipod that is not yet
            settled in its destination room. Settled amphipods are removed.
        final: maps a letter to the next slot to fill in its destination room
            (the deepest slot not yet correctly occupied). A letter is removed
            once its whole room is filled.
        energy: energy spent to reach this configuration.
        maxy: row index of the deepest room row.

    Equality and hashing deliberately look at ``positions`` only: ``energy`` is
    excluded so the search can recognise the same configuration reached at a
    different cost, and ``final`` is determined by ``positions``.
    """

    positions: Dict[Tuple[int, int], str]
    final: Dict[str, Tuple[int, int]]
    energy: int = 0
    maxy: int = 3

    # Hallway cells used as stopping places; the cells directly above the
    # rooms (x = 3, 5, 7, 9) are excluded.
    HALLWAY = set([(1, 1), (2, 1), (4, 1), (6, 1), (8, 1), (10, 1), (11, 1)])
    # Energy per step for each amphipod kind.
    ENERGIES = {"A": 1, "B": 10, "C": 100, "D": 1000}

    def check_collisions(self, start, end):
        """Walk from ``start`` to ``end`` (up, along the hallway, then down).

        Returns the energy of the walk, or None as soon as a cell on the way is
        occupied. A walk of zero steps returns 0.
        """
        step_cost = self.ENERGIES[self.positions[start]]
        x, y = start
        x2, y2 = end
        cost = 0
        # Move into hallway (if needed)
        while y != 1:
            y -= 1
            if (x, y) in self.positions:
                return None
            cost += step_cost
        # Move along the hallway (if needed)
        while x != x2:
            x += sign(x2 - x)
            if (x, y) in self.positions:
                return None
            cost += step_cost
        # Move into the room (if needed)
        while y != y2:
            y += sign(y2 - y)
            if (x, y) in self.positions:
                return None
            cost += step_cost
        return cost

    def move(self, start, end):
        """Return the state after moving the amphipod at ``start`` to ``end``.

        Returns None when ``end`` is occupied, the path is blocked, or the walk
        has zero length. The current state is never modified.
        """
        if end in self.positions:
            return None
        energy = self.check_collisions(start, end)
        if not energy:
            return None
        new_positions = self.positions.copy()
        new_final = self.final.copy()
        new_positions[end] = new_positions.pop(start)
        # An amphipod that reaches its final slot is settled: drop it from the
        # live positions and advance that room's next free slot.
        c = self.positions[start]
        if end == self.final[c]:
            new_positions.pop(end)
            x, y = new_final[c]
            if y == 2:
                new_final.pop(c)
            else:
                new_final[c] = (x, y - 1)
        return State(new_positions, new_final, self.energy + energy, self.maxy)

    def all_moves(self, pos):
        """Return the successor states obtained by moving the amphipod at ``pos``."""
        c = self.positions[pos]
        # If we can move into final position, it's optimal to do so
        if s := self.move(pos, self.final[c]):
            return [s]
        # We can only choose if we haven't moved into the hallway yet
        possible_moves = []
        if pos not in self.HALLWAY:
            for pos2 in self.HALLWAY:
                if pos2 not in self.positions and (s := self.move(pos, pos2)):
                    possible_moves.append(s)
        return possible_moves

    def children(self):
        """Return every state reachable from this one in a single move."""
        return [child for pos in self.positions for child in self.all_moves(pos)]

    def check_initial(self):
        """Return a copy in which amphipods that already sit correctly are settled.

        Rooms are scanned from the deepest row upwards so that an amphipod only
        counts as settled when everything below it is settled too.
        """
        new_positions = self.positions.copy()
        new_final = self.final.copy()
        for k in range(self.maxy, 1, -1):
            for c, (x, y) in list(new_final.items()):
                if y == k and new_positions.get((x, y), None) == c:
                    new_positions.pop((x, y))
                    if k == 2:
                        new_final.pop(c)
                    else:
                        new_final[c] = (x, k - 1)
        return State(new_positions, new_final, self.energy, self.maxy)

    @staticmethod
    def fromstr(diagram):
        """Build the initial state from the textual burrow diagram."""
        initial = {}
        for j, line in enumerate(diagram.strip().split("\n")):
            for i, v in enumerate(line):
                if v in "ABCD":
                    initial[i, j] = v
        maxy = max(j for _, j in initial)
        final = {c: (3 + 2 * i, maxy) for i, c in enumerate("ABCD")}
        return State(initial, final, maxy=maxy).check_initial()

    @cached_property
    def hash(self):
        """Hash of the live positions, independent of dict insertion order."""
        # move() re-inserts the moved key at the end of the dict, so the same
        # configuration can be stored in many different orders; a frozenset
        # makes all of them hash alike.
        return hash(frozenset(self.positions.items()))

    def __hash__(self):
        return self.hash

    def __eq__(self, other):
        # Energy is not part of the identity, otherwise we'd memoize with energy
        # as well; final can be derived from positions.
        if not isinstance(other, State):
            return NotImplemented
        # Compare the actual positions, not the hashes, so that a hash
        # collision can never merge two different configurations.
        return self.positions == other.positions

    def done(self):
        """Return True once every room is completely filled."""
        return len(self.final) == 0

    def remaining(self):
        """Return the summed depth (rows below the hallway) of each room's next slot."""
        return sum(y - 1 for _, y in self.final.values())

    def lower_bound_completion(self):
        """Optimistic estimate of the energy still needed to finish.

        Every live amphipod is charged for walking up to the hallway, across to
        its room and down to a distinct free slot, ignoring all blockers.
        """
        final = self.final.copy()
        energy_lb = 0
        for (x, y), c in self.positions.items():
            x2, y2 = final[c]
            movement = abs(y - 1) + abs(y2 - 1) + abs(x - x2)
            energy_lb += movement * self.ENERGIES[c]
            # The next amphipod of this kind takes the slot above.
            final[c] = (x2, y2 - 1)
        return energy_lb

    @cached_property
    def cost(self):
        """Priority tuple: (spent + estimate, spent, remaining)."""
        return (self.energy + self.lower_bound_completion(), self.energy, self.remaining())

    def __lt__(self, other):
        return self.cost < other.cost

    def render(self):
        """Print the burrow; settled slots show their letter, free room slots a dot."""
        rows = []
        for j in range(self.maxy + 2):
            row = ""
            for i in range(13):
                if (i, j) in self.positions:
                    row += self.positions[i, j]
                elif 2 <= j <= self.maxy and i in (3, 5, 7, 9):
                    c = "ABCD"[(i - 3) // 2]
                    settled = c not in self.final or j > self.final[c][1]
                    row += c if settled else "."
                else:
                    row += "." if j == 1 and 0 < i < 12 else "#"
            rows.append(row + "\n")
        print("".join(rows))

In [94]:
def solve_a(data):
    """Return the least total energy needed to organise the burrow in ``data``.

    Runs a best-first search over ``State`` objects, which the heap orders via
    ``State.__lt__``. Returns None if the heap empties without reaching a
    finished state.
    """
    from heapq import heappop, heappush

    infinity = float("inf")
    start = State.fromstr(data)
    # Cheapest energy seen so far for each configuration. A plain dict with
    # .get() is used because defaultdict is not imported in this notebook.
    best = {start: start.energy}
    heap = [start]
    while heap:
        state = heappop(heap)
        if state.done():
            return state.energy
        if state.energy > best.get(state, infinity):
            # A cheaper route to this configuration was queued after this
            # entry; expanding the stale one cannot improve anything.
            continue
        for child in state.children():
            if child.energy < best.get(child, infinity):
                best[child] = child.energy
                heappush(heap, child)
    return None


def solve_b(data):
    """Solve the unfolded puzzle: two extra rows are inserted below the top room row.

    The insertion point is found by looking for the first row that contains an
    amphipod rather than by a fixed character offset, so the result does not
    depend on the exact width of the rows above it.

    Raises:
        ValueError: if ``data`` contains no row with an amphipod.
    """
    extra_rows = ["  #D#C#B#A#", "  #D#B#A#C#"]
    lines = data.split("\n")
    for top_room_row, line in enumerate(lines):
        if any(ch in "ABCD" for ch in line):
            break
    else:
        raise ValueError("diagram contains no amphipod rows")
    unfolded = lines[: top_room_row + 1] + extra_rows + lines[top_room_row + 1 :]
    return solve_a("\n".join(unfolded))

In [95]:
# `time` is only imported in a later cell of this notebook; import it here so
# the timing cells work when the notebook is run top to bottom.
import time

# perf_counter is a monotonic, high-resolution clock intended for measuring
# elapsed intervals.
start = time.perf_counter()
print(solve_a(sample))
print(time.perf_counter() - start)

12521
0.12559819221496582


In [96]:
start = time.time()
print(solve_a(data))
print(time.time()-start)

14371
0.17653417587280273


In [97]:
start = time.time()
print(solve_b(sample))
print(time.time()-start)

44169
34.69020128250122


In [98]:
start = time.time()
print(solve_b(data))
print(time.time()-start)

40941
1.6879541873931885


In [23]:
import time

In [466]:
start = time.time()
print(solve_a(sample))
print(time.time()-start)

12521
8.376643896102905


In [467]:
start = time.time()
solve_b(sample)
print(time.time()-start)

176.0352418422699


In [468]:
start = time.time()
solve_b(data)
print(time.time()-start)

178.9384560585022


In [28]:
simple = """#############
#...........#
###B#A#C#D###
  #A#B#C#D#
  #A#B#C#D#
  #A#B#C#D#
  #########"""

In [29]:
x = State.fromstr(simple)

In [30]:
x.lower_bound_completion()

44

In [355]:
print(solve_b(sample))

#############
#~~.~.~.~.~~#
###B#C#B#D###
  #D#C#B#A#
  #D#B#A#C#
  #A#D#C#A#
  #########


In [358]:
sample

'#############\n#~~.~.~.~.~~#\n###B#C#B#D###\n  #A#D#C#A#\n  #########'

In [356]:
data

'#############\n#...........#\n###D#C#D#B###\n  #B#A#A#C#\n  #########'

In [323]:
solve_a(data)

14371

In [342]:
len("""#############
#...........#
###A#B#C#D###""")

41

In [344]:
sample[42]

' '

In [316]:
solve_a(data)

14371

In [45]:
simple = """#############
#...........#
###A#B#C#D###
  #A#B#C#D#
  #########"""

In [46]:
State.fromstr(simple).lower_bound_completion

6

In [16]:
x = State.fromstr(sample)

TypeError: State() takes no arguments

In [None]:
x.

In [227]:
from heapq import heappop, heappush
heap = [State.fromstr(sample)]

In [228]:
state = heappop(heap)
if state.done():
    print(state.energy)
for child in state.children():
    heappush(heap, child)

In [223]:
state.render()

#############
#...........#
###B#C#B#D###
###A#D#C#A###
#############
#############



In [225]:
for s in heap:
    s.render()

#############
#B..........#
###.#C#B#D###
###A#D#C#A###
#############
#############

#############
#.B.........#
###.#C#B#D###
###A#D#C#A###
#############
#############

#############
#...B.......#
###.#C#B#D###
###A#D#C#A###
#############
#############

#############
#.....B.....#
###.#C#B#D###
###A#D#C#A###
#############
#############

#############
#.......B...#
###.#C#B#D###
###A#D#C#A###
#############
#############

#############
#.........B.#
###.#C#B#D###
###A#D#C#A###
#############
#############

#############
#..........B#
###.#C#B#D###
###A#D#C#A###
#############
#############

#############
#.B.........#
###B#C#.#D###
###A#D#C#A###
#############
#############

#############
#.....B.....#
###B#C#.#D###
###A#D#C#A###
#############
#############

#############
#.........B.#
###B#C#.#D###
###A#D#C#A###
#############
#############

#############
#.....C.....#
###B#.#B#D###
###A#D#C#A###
#############
#############

#############
#.......C...#
###B#.#B#D###
###A#D#C#A###
#########

In [15]:
starting

{(1, 1): '.',
 (2, 1): '.',
 (3, 1): '.',
 (4, 1): '.',
 (5, 1): '.',
 (6, 1): '.',
 (7, 1): '.',
 (8, 1): '.',
 (9, 1): '.',
 (10, 1): '.',
 (11, 1): '.',
 (3, 2): 'B',
 (5, 2): 'C',
 (7, 2): 'B',
 (9, 2): 'D',
 (3, 3): 'A',
 (5, 3): 'D',
 (7, 3): 'C',
 (9, 3): 'A'}

In [23]:
x = parses("""#############
#...........#
###A#B#C#D###
  #A#B#C#D#
  #########
""")

In [None]:
def solve_a(positions):