In [1]:
import itertools
import re
from collections import Counter, deque
from functools import reduce

import numpy as np


def get_input(n):
    """Return the stripped contents of ``input_<n>.txt``.

    ``n`` identifies the puzzle day.  It may be a string such as ``'16'`` or
    any value that formats into the file name (e.g. the int ``16``).  The
    file is read from the current working directory as UTF-8, and
    leading/trailing whitespace (including the final newline) is removed.
    """
    # f-string rather than concatenation so integer day numbers work too.
    with open(f'input_{n}.txt', 'r', encoding='utf-8') as infile:
        return infile.read().strip()


# Real input for day 16, read from input_16.txt by get_input.
# Sanity check: the day-16 text must differ from day 15's (presumably to catch
# a mis-saved input file).
puzzle = get_input('16')
assert puzzle != get_input('15')


# Worked example from the puzzle statement.  The raw string keeps every '\'
# mirror literal, including a '\' at the very end of a row.
sample_aoc = r""".|...\....
|.-.\.....
.....|-...
........|.
..........
.........\
..../.\\..
.-.-/..|..
.|....-|.\
..//.|...."""


def parse_input(puzzle):
    r"""Split raw puzzle text into a list of grid rows.

    Surrounding whitespace (including blank leading/trailing lines) is
    removed.  Both '\n' and '\r\n' line endings are accepted, so text saved
    with Windows line endings does not leave a stray '\r' tile at the end of
    every row.  Empty input yields ``['']``.
    """
    return [row.rstrip('\r') for row in puzzle.strip().split('\n')]


def solve1(puzzle, viz=False, start=(0, 0, 1, 0), cache=None):
    """Count the tiles energized by one light beam (AoC 2023 day 16, part 1).

    The beam starts *on* tile ``(x, y)`` (column, row) heading ``(dx, dy)``.
    That first tile is processed too, so a mirror or splitter there already
    deflects the beam.  Beams are traced iteratively with an explicit stack,
    so large grids cannot hit Python's recursion limit.  A beam state
    ``(x, y, dx, dy)`` is traced at most once, which stops infinite loops.

    Parameters
    ----------
    puzzle : str
        Raw grid text, parsed with ``parse_input``.
    viz : bool
        When true, print the grid with every energized tile drawn as '#'.
    start : tuple of int
        Initial beam state ``(x, y, dx, dy)``.
    cache : dict, optional
        Memo that may be shared across calls.  Keys are
        ``(tuple_of_rows, start)`` and values are the frozenset of energized
        ``(x, y)`` tiles, so entries from different puzzles never collide.

    Returns
    -------
    int
        Number of distinct energized tiles.
    """
    layout = parse_input(puzzle)

    def ray(state):
        """Return every (x, y) tile visited by a beam starting in ``state``."""
        visited = set()
        stack = [state]
        while stack:
            current = stack.pop()
            x, y, dx, dy = current
            # Rows may differ in length, so bound x by the current row.
            off_grid = not (0 <= y < len(layout) and 0 <= x < len(layout[y]))
            if off_grid or current in visited:
                continue
            visited.add(current)
            tile = layout[y][x]
            if tile == '\\':
                dx, dy = dy, dx
            elif tile == '/':
                dx, dy = -dy, -dx
            elif tile == '|' and dx != 0:
                # Horizontal beam hits the flat side: split up and down.
                stack.append((x, y - 1, 0, -1))
                stack.append((x, y + 1, 0, 1))
                continue
            elif tile == '-' and dy != 0:
                # Vertical beam hits the flat side: split left and right.
                stack.append((x - 1, y, -1, 0))
                stack.append((x + 1, y, 1, 0))
                continue
            stack.append((x + dx, y + dy, dx, dy))
        return {(x, y) for x, y, _, _ in visited}

    def draw_path(tiles):
        """Print the grid with the tiles in ``tiles`` marked '#'."""
        print('\n'.join(
            ''.join('#' if (x, y) in tiles else ch for x, ch in enumerate(row))
            for y, row in enumerate(layout)
        ))

    start = tuple(start)
    key = (tuple(layout), start)
    if cache is not None and key in cache:
        energized = cache[key]
    else:
        energized = frozenset(ray(start))
        if cache is not None:
            cache[key] = energized

    if viz:
        draw_path(energized)
    return len(energized)


# Hand-made grids for debugging the beam logic.  sample2 and sample3 have rows
# of unequal length, and sample2-4 begin with a blank line that parse_input
# strips.  Raw strings keep the '\' mirrors literal.
sample1 = """...--.
......"""

sample2 = r"""
.-\\
\./
"""
sample3 = r"""
.--\\
.\./
"""

sample4 = r"""
.-.\
.|./
.-..
"""

# Expected energized-tile counts for solve1's default start (top-left tile,
# heading right).  Disabled here, presumably because solve1 did not pass them yet.
# assert solve1(sample1) == 6
# assert solve1(sample2) == 6
# assert solve1(sample3) == 7
# assert solve1(sample4) == 11

# assert solve1(sample_aoc, viz=False) == 46

In [2]:
def solve2(puzzle, viz=False):
    """Best energized-tile count over every edge entry (AoC 2023 day 16, part 2).

    Starts a beam on every edge tile, heading into the grid, counts each one
    with ``solve1``, and returns the largest count.  One memo dict is shared
    by all the ``solve1`` calls.

    NOTE: a later notebook cell defines another ``solve2``; whichever cell
    ran last is the one in scope.

    Parameters
    ----------
    puzzle : str
        Raw grid text.
    viz : bool
        When true, print the grid energized by the best start.

    Returns
    -------
    int
        Maximum number of energized tiles.
    """
    layout = parse_input(puzzle)
    cache = {}
    initial_states = list(itertools.chain(
        ((i, 0, 0, 1) for i in range(len(layout[0]))),  # top edge, heading down
        ((i, len(layout) - 1, 0, -1)
         for i in range(len(layout[-1]))),  # bottom edge, heading up
        ((0, i, 1, 0) for i in range(len(layout))),  # left edge, heading right
        # Right edge: the last column of each row (x = width - 1), heading left.
        ((len(row) - 1, i, -1, 0) for i, row in enumerate(layout)),
    ))
    results = {
        ist: solve1(puzzle, viz=False, start=ist, cache=cache)
        for ist in initial_states
    }

    ist_max, n = max(results.items(), key=lambda item: item[1])
    if viz:
        solve1(puzzle, viz=True, start=ist_max, cache=cache)
    return n

Currently stuck on a caching problem. I want to remember the result from a given state, but I only skip forward along the path when I hit a known node, so I'm effectively choosing one specific branch and ignoring all the other branches.

The cache should return the set of all downstream nodes (tiles) reachable from that state.

In [47]:
def traverse(x0, y0, dx0, dy0, layout, previous_path=[], cache={}):
    x, y, dx, dy = x0, y0, dx0, dy0
    path = previous_path
    seen = set()
    ymax, xmax = len(layout)-1, len(layout[0])-1
    def get_pos(r): return (r[0], r[1])
    already_seen = False

    while True:
        if (state := (x, y, dx, dy)) in cache.keys():
            return cache[state]
        in_bounds = (0 <= x <= xmax) and (0 <= y <= ymax)

        if already_seen or not in_bounds:
            break

        # print(x, y, dx, dy)
        pos = layout[y][x]
        path.append((x, y, dx, dy))

        if pos == '\\':
            dx, dy = dy, dx
        elif pos == r'/':
            dx, dy = -dy, -dx

        elif pos == '|' and dx != 0:
            for dyi in [-1, 1]:
                seen = seen.union(traverse(x, y+dyi, 0, dyi, layout, path, cache))
            break
        elif pos == '-' and dy != 0:
            for dxi in [-1, 1]:
                seen = seen.union(traverse(x, y, dxi, 0, layout, path, cache))
            break
        x += dx
        y += dy
        already_seen = (x, y, dx, dy) in path

    res = set(get_pos(p) for p in path)
    cache[(x0, y0, dx0, dy0)] = res
    return res.union(seen)


def solve_better(puzzle, ray_func=traverse, viz=False, start=(0, 0, 1, 0), cache=None):
    """Count the tiles energized by one beam, using ``ray_func`` to trace it.

    Parameters
    ----------
    puzzle : str
        Raw grid text, parsed with ``parse_input``.
    ray_func : callable
        Called as ``ray_func(x, y, dx, dy, layout, [], cache=cache)``; must
        return a collection of energized ``(x, y)`` tiles.
    viz : bool
        When true, print the grid with energized tiles drawn as '#'.
    start : tuple of int
        Initial beam state ``(x, y, dx, dy)``.
    cache : dict, optional
        Memo forwarded to ``ray_func``.  When omitted, a fresh dict is used
        for this call only.

    Returns
    -------
    int
        Number of energized tiles.
    """
    if cache is None:
        # Fresh memo per call: a shared default dict leaked cached beams from
        # one puzzle into the next.
        cache = {}
    layout = parse_input(puzzle)
    hits = ray_func(*start, layout, [], cache=cache)
    if viz:
        print('\n'.join(
            ''.join('#' if (x, y) in hits else tile for x, tile in enumerate(row))
            for y, row in enumerate(layout)
        ))
    return len(hits)


# Smoke test: from sample3's top-left tile heading down, the beam crosses the
# two '.' tiles of column 0 and leaves the grid, so 2 tiles are energized.
layout = parse_input(sample3)
assert len(traverse(0, 0, 0, 1, layout)) == 2

# Last expression, shown as the cell output (expected 6 for sample1).
solve_better(sample1)

6

In [48]:
# Regression checks for the traverse-based solver on the hand-made samples.
# The expected counts match the commented-out solve1 asserts in the first cell.
print(solve_better(sample1))
assert solve_better(sample1) == 6
assert solve_better(sample2) == 6
assert solve_better(sample3) == 7
assert solve_better(sample4) == 11

6


AssertionError: 

In [44]:
def calc_energized(grid, start):
    # (row, col, movement row, movement col)
    queue = [start]
    seen = set()

    while queue:
        row, col, drow, dcol = queue.pop(0)
        row += drow
        col += dcol

        if row < 0 or row >= len(grid) or col < 0 or col >= len(grid[0]):
            continue

        new_pos = grid[row][col]

        if (
            new_pos == "."
            or (new_pos == "-" and dcol != 0)
            or (new_pos == "|" and drow != 0)
        ):
            queue.append((row, col, drow, dcol))
            seen.add((row, col, drow, dcol))

        elif new_pos == "\\":
            drow, dcol = dcol, drow
            if (row, col, drow, dcol) not in seen:
                queue.append((row, col, drow, dcol))
                seen.add((row, col, drow, dcol))

        elif new_pos == "/":
            drow, dcol = -dcol, -drow
            if (row, col, drow, dcol) not in seen:
                queue.append((row, col, drow, dcol))
                seen.add((row, col, drow, dcol))

        else:
            for dr, dc in [(1, 0), (-1, 0)] if new_pos == "|" else [(0, 1), (0, -1)]:
                if (row, col, dr, dc) not in seen:
                    queue.append((row, col, dr, dc))
                    seen.add((row, col, dr, dc))

    visited = {(row, col) for (row, col, _, _) in seen}

    return len(visited)


def solve2(puzzle, viz=False, ray_func=traverse):
    """Best energized-tile count over every edge entry (AoC 2023 day 16, part 2).

    Starts a beam on every edge tile, heading into the grid, counts each one
    with ``solve_better`` (tracing with ``ray_func``), and returns the
    largest count.  One memo dict is shared by all the calls.

    Parameters
    ----------
    puzzle : str
        Raw grid text.
    viz : bool
        When true, re-run the best start with visualization enabled.
    ray_func : callable
        Beam tracer forwarded to ``solve_better``.

    Returns
    -------
    int
        Maximum number of energized tiles.
    """
    layout = parse_input(puzzle)
    cache = {}
    initial_states = list(itertools.chain(
        ((i, 0, 0, 1) for i in range(len(layout[0]))),  # top edge, heading down
        ((i, len(layout) - 1, 0, -1) for i in range(len(layout[-1]))),  # bottom edge, heading up
        ((0, i, 1, 0) for i in range(len(layout))),  # left edge, heading right
        # Right edge: the last column of each row (x = width - 1), heading left.
        ((len(row) - 1, i, -1, 0) for i, row in enumerate(layout)),
    ))
    results = {
        ist: solve_better(puzzle, ray_func=ray_func, viz=False, start=ist, cache=cache)
        for ist in initial_states
    }

    best_start, best = max(results.items(), key=lambda item: item[1])
    if viz:
        solve_better(puzzle, ray_func=ray_func, viz=True, start=best_start, cache=cache)
    return best

In [45]:
# Part 2 on the worked example (the captured output below is 51).
solve2(sample_aoc)

51

In [46]:
# Part 2 on the real day-16 input.
solve2(puzzle)

7661

In [43]:
# Grid as {complex(column, row): tile}.  Reading through get_input closes the
# file promptly instead of leaving an open handle for the garbage collector.
g = {complex(i, j): c for j, r in enumerate(get_input('16').splitlines())
     for i, c in enumerate(r.strip())}

def fn(todo):
    done = set()
    while todo:
        pos, dir = todo.pop()
        while not (pos, dir) in done:
            done.add((pos, dir))
            pos += dir
            match g.get(pos):
                case '|': dir = 1j
                todo.append((pos, -dir))
                case '-': dir = -1
                todo.append((pos, -dir))
                case '/': dir = -complex(dir.imag, dir.real)
                case '\\': dir = complex(dir.imag, dir.real)
                case None: break

    return len(set(pos for pos, _ in done)) - 1


# Part 1: a beam entering the top-left tile from the left, heading right.
print(fn([(-1, 1)]))

# Part 2: best result over every edge tile, entering from just outside the grid.
print(max(
    fn([(pos - step, step)])
    for step in (1, 1j, -1, -1j)
    for pos in g
    if pos - step not in g
))

7307
7635
