# Day 18


In [None]:
def print_map(area_map):
    """Print the map to stdout, one text line per grid row.

    The bounding box is computed per axis (the smallest and largest x and y
    seen in the keys), so maps whose rows have different lengths still print
    correctly. Positions inside the box that have no tile print as a space;
    tiles without a ``char`` print as nothing. An empty map prints nothing.
    """
    if not area_map:
        return
    xs = [x for x, _ in area_map]
    ys = [y for _, y in area_map]
    min_x, max_x = min(xs), max(xs)
    for y in range(min(ys), max(ys) + 1):
        row = []
        for x in range(min_x, max_x + 1):
            tile = area_map.get((x, y))
            # Missing cells (ragged rows) are padded so columns stay aligned.
            row.append(" " if tile is None else tile.get("char", ""))
        print("".join(row))
    
def parse_map(value, debug=False):
    """Parse a puzzle grid into a dict mapping ``(x, y)`` to a tile dict.

    ``value`` is stripped of surrounding whitespace and split into rows;
    ``x`` is the column index and ``y`` the row index. Every tile has
    ``pos`` and ``char``; recognised characters also get a ``type``:

    - ``#`` -> ``"wall"``
    - ``@``, ``1``, ``2``, ``3``, ``4`` -> ``"me"`` (start tiles)
    - ``.`` -> ``"corridor"``
    - lower-case letter -> ``"key"``, plus ``door`` (its upper-case form)
    - upper-case letter -> ``"door"``, plus ``key`` (its lower-case form)

    Any other character gets no ``type``. With ``debug=True`` each row is
    echoed to stdout while it is parsed.
    """
    grid = {}
    rows = value.strip().splitlines()
    for row_index, row in enumerate(rows):
        for col_index, symbol in enumerate(row):
            if debug:
                print(symbol, end="")
            coords = (col_index, row_index)
            tile = {"pos": coords, "char": symbol}
            if symbol == "#":
                tile["type"] = "wall"
            elif symbol in "@1234":
                tile["type"] = "me"
            elif symbol == ".":
                tile["type"] = "corridor"
            elif symbol.islower():
                tile.update(type="key", door=symbol.upper())
            elif symbol.isupper():
                tile.update(type="door", key=symbol.lower())
            grid[coords] = tile
        if debug:
            print("")
    return grid
        
def find_item(area_map, item):
    """Return the first tile whose ``char`` equals ``item``, or ``None``.

    Tiles are scanned in the map's iteration (insertion) order.
    """
    return next((tile for tile in area_map.values() if tile['char'] == item), None)

def set_dist_if_not_wall(area_map, point, dist, parent):
    """Label the tile at ``point`` with a flood-fill distance and path.

    Tiles of type ``"wall"`` or ``"door"`` are left untouched, and so are
    positions that are not in ``area_map``. For any other tile, ``dist`` and
    ``path`` (``parent``'s ``path``, default ``""``, plus this tile's
    character) are set only if the tile does not already have them, so the
    first visit wins.
    """
    item = area_map.get(point)
    if item is None:
        # Off-map neighbour: there is no tile to label.
        return
    if item.get("type") in ("wall", "door"):
        return
    item.setdefault("dist", dist)
    item.setdefault("path", parent.get("path", "") + item["char"])

def find_accessible_key(area_map, item):
    """Return the key tiles reachable from ``item`` given the keys on its path.

    NOTE(review): not called by any of the visible cells; ``dijsktra`` uses
    ``find_all_key_paths`` instead.

    Works on a deep copy of ``area_map`` into which ``item`` itself (not a
    copy) is inserted at ``item["pos"]``:

    1. ``dist``/``path`` (and ``keys``, if present) labels left by a previous
       search are removed from every tile except ``item``.
    2. Every tile (including ``item``) whose character is ``@``, a lower-case
       key that appears in ``item["path"]``, or the upper-case door of such a
       key, is turned into a corridor.
    3. A level-by-level flood fill starts at ``item``'s ``dist`` (set to 0 if
       missing) and labels tiles via ``set_dist_if_not_wall``. Walls and
       remaining doors block it, and key tiles are not expanded further.

    Args:
        area_map: map as produced by ``parse_map``. The passed dict and its
            other tiles are not modified; only ``item`` may be.
        item: tile dict to start from, usually carrying a ``path`` string.
            NOTE(review): this dict IS mutated -- it may get ``dist`` set
            and, if its char is in the collected set, become a corridor.

    Returns:
        List of tiles of type ``"key"`` in the copy that carry a ``path``,
        each with its ``dist`` and ``path`` labels.
    """
    from copy import deepcopy
    import re

    # Work on a copy, but keep the caller's item object as the start tile.
    area_map = deepcopy(area_map)
    area_map[item["pos"]] = item

    # Keys collected so far: lower-case letters (and '@') on the item's path.
    path_keys = re.sub(r'[^a-z@]', "", item.get("path", ""))
    # Those keys, the doors they open, and the start marker become passable.
    keys = set(path_keys) | set(path_keys.upper()) | set("@")

    for v in area_map.values():
        # Clear labels from an earlier flood fill (the start keeps its own).
        if v is not item and "dist" in v:
            del v["dist"]
            del v["path"]
            
            if "keys" in v:
                del v["keys"]

        if v["char"] in keys:
            v["char"] = "."
            v["type"] = "corridor"
    
    # Flood fill one distance level at a time, scanning every tile per level.
    i = item.setdefault("dist", 0)
    while True:
        found = 0
        for p in area_map.values():
            if p.get("dist") == i:
                # Key tiles are end points: do not expand past them.
                # NOTE(review): if item is still a key tile here, nothing is
                # expanded and the loop stops immediately.
                if p.get("type") != "key":
                    found += 1
                    x,y = p["pos"]
                    set_dist_if_not_wall(area_map, (x+1,y), i+1, p)
                    set_dist_if_not_wall(area_map, (x-1,y), i+1, p)
                    set_dist_if_not_wall(area_map, (x,y+1), i+1, p)
                    set_dist_if_not_wall(area_map, (x,y-1), i+1, p)
        
        # No tile was expanded at this level: the fill is complete.
        if found == 0:
            break
            
        i += 1
        
    # Collect the keys the flood fill reached.
    # NOTE(review): p["type"] raises KeyError for tiles parse_map left untyped.
    keys = []
    for p in area_map.values():
        if p["type"] == "key" and p.get("path") is not None:
            keys.append(p)
    
    return keys




In [None]:
def set_distances(area_map, distances, point, distance, path=""):
    """Record the shortest walk from ``point`` to every reachable tile.

    Breadth-first search over ``area_map`` starting at ``point``. The start
    is ``distance`` steps from the origin and was reached via the characters
    in ``path``. For every non-wall tile reached, ``distances[pos]`` is set
    to a dict with:

    - ``dist``: number of steps from the origin;
    - ``path``: characters of the tiles walked, excluding the first one;
    - ``rpath``: the walk reversed, excluding the destination tile (i.e. the
      walk from the destination back to the origin);
    - ``doors``: upper-case characters (doors) on the walk;
    - ``keys``: lower-case characters (keys) on the walk, upper-cased.

    Doors and keys are traversable. Only tiles of type ``"wall"`` and
    positions missing from ``area_map`` block movement. A position is
    (re)recorded only when reached in strictly fewer steps than an existing
    entry. Among equally short walks, the first found in BFS order
    (neighbours tried +x, -x, +y, -y) is kept. ``distances`` is updated in
    place and nothing is returned.

    The search is iterative, so long corridors cannot hit Python's recursion
    limit, and off-map neighbours are simply skipped.
    """
    from collections import deque

    queue = deque([(point, distance, path)])
    while queue:
        pos, steps, walked = queue.popleft()
        tile = area_map.get(pos)
        # Off-map positions and walls end this branch of the walk.
        if tile is None or tile.get("type") == "wall":
            continue
        known = distances.get(pos)
        if known is not None and known["dist"] <= steps:
            continue
        walked = walked + tile["char"]
        distances[pos] = dict(dist=steps,
                              path=walked[1:],
                              rpath=walked[-2::-1],
                              doors={c for c in walked if c.isupper()},
                              keys={c.upper() for c in walked if c.islower()},
                              )
        x, y = pos
        for nxt in ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)):
            queue.append((nxt, steps + 1, walked))
        
def find_paths_from_key(area_map, key):
    """Return ``set_distances`` results for a walk starting at tile ``key``.

    ``key`` is a tile character (a key letter or a start label such as
    ``@``). The first tile showing it, as found by ``find_item``, is the
    origin. The result maps each reachable position to its distance record.
    """
    origin_tile = find_item(area_map, key)
    reachable = {}
    set_distances(area_map, reachable, origin_tile["pos"], 0)
    return reachable

def find_all_key_paths(area_map):
    """Pre-compute walk information between every pair of keys / start tiles.

    Returns a dict keyed by the two-character string ``start + end``. It has
    an entry for every ordered pair of distinct key or start characters that
    can reach each other. The values are the records produced by
    ``set_distances`` (``dist``, ``path``, ``rpath``, ``doors``, ``keys``).

    The ``end + start`` entry is a shallow copy of the ``start + end`` entry
    with ``path`` and ``rpath`` swapped, so the ``doors``/``keys`` sets are
    shared by both directions. Unreachable pairs are reported with a print
    and omitted.
    """
    # Position of every key / start tile. The first occurrence wins, matching
    # find_item; computed once instead of rescanning the map per pair.
    positions = dict()
    for tile in area_map.values():
        if tile["type"] in ["me", "key"]:
            positions.setdefault(tile["char"], tile["pos"])

    ordered = sorted(positions)
    all_paths = dict()
    for index, start in enumerate(ordered):
        # Each unordered pair is searched once, from its smaller character.
        targets = ordered[index + 1:]
        if not targets:
            break
        distances = find_paths_from_key(area_map, start)
        for end in targets:
            d = distances.get(positions[end])
            if d is None:
                print("{} not reachable from {}".format(end, start))
                continue
            all_paths[start + end] = d

            reverse = dict(d)
            reverse["path"], reverse["rpath"] = d["rpath"], d["path"]
            all_paths[end + start] = reverse

    return all_paths

    

    

In [None]:
def dijsktra(area_map, start, approximate=False, prune_size=200000, prune_pct=0.1, solution_limit=100, debug_step=1e4):
    """Search for short walks that collect every key in ``area_map``.

    Best-first search over queue entries ``(distance, robot positions, key
    order, collected keys)``. It uses the key-to-key walks pre-computed by
    ``find_all_key_paths``. A move from key/start ``u`` to ``key`` is
    allowed only if every door on that walk is already unlocked, and the
    move collects every key on the walk.

    Parameters
    ----------
    area_map : dict
        Map as produced by ``parse_map``.
    start : str or iterable of str
        Character(s) of the starting tile(s), e.g. ``'@'`` or ``set("1234")``.
    approximate : bool
        If True, a ``(target key, keys held before the move)`` state is pushed
        only the first time it is reached. Otherwise it is pushed again
        whenever its distance is not larger than the one recorded.
    prune_size : int or None
        At a progress report, when the queue holds at least this many
        entries, it is pruned. ``None`` disables pruning.
    prune_pct : float
        When pruning, keep the smallest ``prune_pct`` fraction (rounded up)
        of each group of entries with the same number of collected keys.
    solution_limit : int
        Return once more than this many entries are in ``lengths``, i.e.
        after ``solution_limit`` complete walks have been found.
    debug_step : int or float
        Print progress (and possibly prune) every ``debug_step`` iterations.

    Returns
    -------
    (lengths, found) : tuple of lists
        ``lengths`` holds ``math.inf`` followed by the distance of each
        complete walk, in discovery order. ``found`` holds a description
        string per complete walk.
    """
    import heapq, math, time
    from IPython.display import clear_output

    vertices = find_all_key_paths(area_map)
    print("Calculated {} distances".format(len(vertices)))

    # Vertex names are two characters (start+end), so the union of their
    # characters is every key/start label; drop the start labels.
    start = set(start)
    keys = set()
    for name in vertices:
        keys |= set(name)
    keys = keys - start

    q = []
    heapq.heappush(q, (0, start, "", frozenset()))

    # Best distance pushed for each (target key, keys held before the move).
    seen = dict()

    found = []
    lengths = [math.inf]
    iterations = 0
    start_time = time.time()
    max_steps_seen = []

    while q:
        iterations += 1
        d, u_set, p, seen_keys = heapq.heappop(q)
        key_length = len(seen_keys)

        # Remember when a state with more collected keys is first popped.
        if key_length > len(max_steps_seen):
            max_steps_seen.append(dict(iterations=iterations, time=time.time() - start_time, steps=len(seen_keys)))

        if iterations % debug_step == 0:
            clear_output(wait=True)
            print(iterations, len(q), len(seen_keys), len(found), min(lengths))

            if len(q) == 0:
                print("Empty q")
            else:
                # Group the queue by number of collected keys.
                sorted_queue = dict()
                for q_item in q:
                    sorted_queue.setdefault(len(q_item[3]), []).append(q_item)
                for k in range(len(keys) + 1):
                    v = len(sorted_queue.get(k, []))
                    print("{:2d} - {:6d} times ({:3.0f}%)".format(k, v, 100 * v / len(q)))

                for step in max_steps_seen:
                    print(step)
                print("\n".join(found[-10:]))

                if prune_size is not None and len(q) >= prune_size:
                    print("Reducing q", len(q), end="")
                    q = []
                    for group in sorted_queue.values():
                        q += heapq.nsmallest(math.ceil(len(group) * prune_pct), group)
                    # The concatenated groups are not a valid heap; restore the
                    # heap invariant before the next heappop.
                    heapq.heapify(q)
                    print(" to ", len(q))

        if len(keys) == len(seen_keys):
            lengths.append(d)
            found.append("FOUND iter={} d={} p={} keys={}".format(iterations, d, p, sorted(seen_keys)))
            if len(lengths) > solution_limit:
                return lengths, found

        for u in u_set:
            for key in keys - set(u):
                path = vertices.get(u + key)
                if path is None:
                    continue
                # Every door on the walk must already be unlocked.
                if len(path["doors"] - seen_keys) > 0:
                    continue

                seen_u = seen.get((key, seen_keys))
                new_dist = d + path["dist"]

                if approximate:
                    if seen_u is not None:
                        continue
                elif seen_u is not None and seen_u < new_dist:
                    continue

                seen[(key, seen_keys)] = new_dist
                # Robot at u moves to key; all keys on the walk are collected.
                heapq.heappush(q, (new_dist, u_set - set(u) | set(key), p + key, frozenset(seen_keys | path["keys"])))

    return lengths, found

In [None]:
sample1 = """
#########
#b.A.@.a#
#########
"""
area_map = parse_map(sample1, debug=True)
# 3 rows of 9 tiles each.
assert len(area_map) == 27, len(area_map)
assert area_map[(3, 1)]["char"] == "A"

item = find_item(area_map, "@")
assert item is not None
assert item["pos"] == (5, 1)

print(area_map.keys())

d = dijsktra(area_map, "@", approximate=False)
lengths, found = d
print(min(lengths))
print(found)

In [None]:
sample2 = """
########################
#f.D.E.e.C.b.A.@.a.B.c.#
######################.#
#d.....................#
########################
"""
area_map = parse_map(sample2, debug=True)
d = dijsktra(area_map, "@")
lengths, found = d
print(min(lengths))
print(found)

In [None]:
sample3 = """
########################
#...............b.C.D.f#
#.######################
#.....@.a.B.c.d.A.e.F.g#
########################
"""
area_map = parse_map(sample3, debug=True)
d = dijsktra(area_map, "@")
lengths, found = d
print(min(lengths))
print(found)

In [None]:
sample4 = """
#################
#i.G..c...e..H.p#
########.########
#j.A..b...f..D.o#
########@########
#k.E..a...g..B.n#
########.########
#l.F..d...h..C.m#
#################
"""
area_map = parse_map(sample4, debug=True)
d = dijsktra(area_map, "@", approximate=False)
lengths, found = d
print(min(lengths))
print(found)

In [None]:
sample5 = """
########################
#@..............ac.GI.b#
###d#e#f################
###A#B#C################
###g#h#i################
########################
"""
area_map = parse_map(sample5, debug=True)
d = dijsktra(area_map, "@", approximate=True)
lengths, found = d
print(min(lengths))
print(found)

In [None]:
# Read the puzzle input with an explicit encoding so the result does not
# depend on the platform's default locale.
with open("18-input.txt", "rt", encoding="utf-8") as handle:
    data = handle.read().strip()

area_map = parse_map(data, debug=True)
d = dijsktra(area_map, '@', approximate=True)
print(min(d[0]))
print(d[1])

# Part 2

In [None]:
def fix_starting_points(area_map):
    """Split the single ``@`` start into four start tiles, in place.

    The ``@`` tile and its four orthogonal neighbours become walls. The four
    diagonal neighbours become ``"me"`` tiles labelled ``1`` (x-1, y-1),
    ``2`` (x+1, y-1), ``3`` (x-1, y+1) and ``4`` (x+1, y+1). Returns ``None``.
    """
    centre = next((tile for tile in area_map.values() if tile['char'] == '@'), None)
    cx, cy = centre["pos"]
    walls = [(cx, cy), (cx - 1, cy), (cx + 1, cy), (cx, cy - 1), (cx, cy + 1)]
    for pos in walls:
        area_map[pos] = dict(pos=pos, char='#', type='wall')
    corners = [(cx - 1, cy - 1), (cx + 1, cy - 1), (cx - 1, cy + 1), (cx + 1, cy + 1)]
    for label, pos in enumerate(corners, start=1):
        area_map[pos] = dict(pos=pos, char=str(label), type='me')
        


In [None]:
sample2_1 = """
#######
#a.#Cd#
##...##
##.@.##
##...##
#cB#Ab#
#######
"""
area_map = parse_map(sample2_1, debug=False)
# Replace the central '@' with four start tiles 1-4.
fix_starting_points(area_map)
print_map(area_map)
d = dijsktra(area_map, set("1234"), approximate=False)
lengths, found = d
print(min(lengths))
print(found)

In [None]:
sample2_2 = """
###############
#d.ABC.#.....a#
######1#2######
###############
######4#3######
#b.....#.....c#
###############
"""
area_map = parse_map(sample2_2, debug=False)
# The sample already contains start tiles 1-4, so fix_starting_points is not applied.
print_map(area_map)
d = dijsktra(area_map, set("1234"), approximate=False)
lengths, found = d
print(min(lengths))
print(found)

In [None]:
sample2_3 = """
#############
#DcBa.#.GhKl#
#.###1#2#I###
#e#d#####j#k#
###C#4#3###J#
#fEbA.#.FgHi#
#############
"""
area_map = parse_map(sample2_3, debug=False)
# The sample already contains start tiles 1-4, so fix_starting_points is not applied.
print_map(area_map)
d = dijsktra(area_map, set("1234"), approximate=False)
lengths, found = d
print(min(lengths))
print(found)

In [None]:
import math
sample2_4 = """
#############
#g#f.D#..h#l#
#F###e#E###.#
#dCba1#2BcIJ#
#############
#nK.L4#3G...#
#M###N#H###.#
#o#m..#i#jk.#
#############
"""
area_map = parse_map(sample2_4, debug=False)
# The sample already contains start tiles 1-4, so fix_starting_points is not applied.
print_map(area_map)
d = dijsktra(
    area_map,
    set("1234"),
    approximate=True,
    prune_size=None,
    prune_pct=0.1,
    solution_limit=10,
)
lengths, found = d
print(min(lengths))
print("\n".join(found))

In [None]:
area_map = parse_map(data, debug=False)
# Replace the central '@' with four start tiles 1-4.
fix_starting_points(area_map)
print_map(area_map)
d = dijsktra(
    area_map,
    set("1234"),
    approximate=True,
    prune_size=None,
    prune_pct=0.1,
    solution_limit=1,
)
lengths, found = d
print(min(lengths))
# print(found)