# AoC 2021

## Helpers

In [181]:
#### IMPORTS

import re
import abc
import operator
import numpy as np
from collections import Counter, defaultdict, namedtuple, deque, abc
from itertools   import (permutations, combinations, chain, cycle, product, islice, 
                         takewhile, zip_longest, starmap, count as count_from)
from functools   import lru_cache, reduce
from heapq import (heappush, heappop, nlargest, nsmallest)

from pprint import pprint as p, pformat as pf
import toolz.curried as t
from tqdm.notebook import tqdm as tq
from dataclasses import dataclass, field

#### CONSTANTS

alphabet = 'abcdefghijklmnopqrstuvwxyz'
ALPHABET = alphabet.upper()
infinity = float('inf')

#### SIMPLE UTILITY FUNCTIONS

cat = ''.join

def ints(start, end, step=1):
    "Shorthand for range(start, end + 1, step): with a positive step, end itself is included."
    stop = end + 1
    return range(start, stop, step)

def first(iterable, default=None):
    "Return the leading item of iterable, or default when iterable yields nothing."
    for item in iterable:
        return item
    return default

def head(iterable, n=5):
    "Take up to n leading items of iterable and return them as a tuple."
    leading = islice(iterable, n)
    return tuple(leading)

def tail(iterable, n=1):
    "Lazily iterate over iterable with its first n items dropped."
    remainder = islice(iterable, n, None)
    return remainder

def first_true(iterable, pred=None, default=None):
    """Return the first item of iterable that passes the test.

    With pred=None the test is the item's own truthiness; otherwise it is
    pred(item). When no item passes, *default* is returned.
    """
    passing = filter(pred, iterable)
    return next(passing, default)

def nth(iterable, n, default=None):
    "Item at zero-based position n of iterable, or default if iterable is too short."
    remainder = islice(iterable, n, None)
    return next(remainder, default)

def upto(iterable, maxval):
    "Yield values from iterable until the first one greater than maxval (maxval itself is kept)."
    # The bound is inclusive, following Ruby's upto.
    def within_bound(value):
        return value <= maxval
    return takewhile(within_bound, iterable)

identity = lambda x: x

def quantify(iterable, pred=bool):
    "Number of items in iterable for which pred(item) is true."
    return sum(pred(item) for item in iterable)

def multimap(items):
    "Group (key, val) pairs into a defaultdict of lists: {key: [val, ...], ...}."
    grouped = defaultdict(list)
    for pair in items:
        key, val = pair
        grouped[key].append(val)
    return grouped

def overlapping(iterable, n):
    """Generate every run of n consecutive items of iterable.
    overlapping('ABCDEFG', 3) --> ABC BCD CDE DEF EFG
    Sequences are sliced (runs keep the sequence's type); any other
    iterable is windowed with a deque and yields tuples."""
    if not isinstance(iterable, abc.Sequence):
        window = deque(maxlen=n)
        for item in iterable:
            window.append(item)
            if len(window) == n:
                yield tuple(window)
        return
    last_start = len(iterable) - n
    for start in range(last_start + 1):
        yield iterable[start:start + n]
                
def pairwise(iterable):
    "Adjacent pairs of iterable: (s0, s1), (s1, s2), (s2, s3), ..."
    return overlapping(iterable, n=2)

def mapt(fn, *args):
    "Map fn over the argument iterable(s) eagerly, collecting the results in a tuple."
    results = map(fn, *args)
    return tuple(results)

def map2d(fn, grid):
    "Return grid as a tuple of tuples with fn applied to every cell."
    rows = []
    for row in grid:
        rows.append(mapt(fn, row))
    return tuple(rows)

def flatmap(fn, *args):
    "Map fn over the arguments, then splice the resulting iterables together into one tuple."
    flattened = []
    for group in map(fn, *args):
        flattened.extend(group)
    return tuple(flattened)

def repeat(n, fn, arg, *args, **kwds):
    "Apply arg = fn(arg, *args, **kwds) n times in succession and return the final arg."
    iterates = repeatedly(fn, arg, *args, **kwds)
    return nth(iterates, n)

def repeatedly(fn, arg, *args, **kwds):
    "Endless generator of arg, fn(arg), fn(fn(arg)), ...; extra args/kwds are passed to every fn call."
    current = arg
    while True:
        yield current
        current = fn(current, *args, **kwds)
        
def repeatedly1(fn, arg, *args, **kwds):
    "Like repeatedly, but without the initial arg: fn(arg), fn(fn(arg)), ..."
    iterates = repeatedly(fn, arg, *args, **kwds)
    return tail(iterates)

def compose(f, g):
    "Build the one-argument function x -> f(g(x))."
    def composed(x):
        return f(g(x))
    return composed

#### FILE INPUT AND PARSING

def Input(day, line_parser=str.strip, test=False, file_template='data/2021/{}.txt'):
    """For this day's input file, return a tuple of each line parsed by `line_parser`.

    The file name is `file_template` formatted with `day`, or with
    '<day>test' when test is true.
    """
    name = f'{day}test' if test else day
    # Read inside a `with` block so the file handle is closed, not leaked.
    with open(file_template.format(name)) as lines:
        return tuple(map(line_parser, lines))

def Groups(day, group_parser=str.split, test=False):
    "Split this day's input file on blank lines and return a tuple of each group parsed by `group_parser`."
    raw_lines = Input(day, t.identity, test)
    text = ''.join(raw_lines).strip('\n')
    return mapt(group_parser, text.split('\n\n'))

@t.curry
def Tokens(line, sep=','):
    "Strip surrounding whitespace from line, then split it on sep."
    stripped = line.strip()
    return stripped.split(sep)

def integers(text):
    "Every (optionally negative) integer appearing in text, as a tuple; all other characters are ignored."
    found = re.findall(r'-?\d+', text)
    return mapt(int, found)

def digits(number):
    "The characters of str(number), each converted to int, as a tuple."
    as_text = str(number)
    return mapt(int, as_text)

################ 2-D points implemented using (x, y) tuples

def coords(rows):
    "Yield ((x, y), val) for every cell of a 2-D grid, row by row (y is the row index, x the column index)."
    for y, row in enumerate(rows):
        yield from (((x, y), val) for x, val in enumerate(row))

def X(point):
    "First coordinate of point."
    return point[0]

def Y(point):
    "Second coordinate of point."
    return point[1]

def Z(point):
    "Third coordinate of point."
    return point[2]

origin = (0, 0)
UP, LEFT, DOWN, RIGHT = (0, -1), (-1, 0), (0, 1), (1, 0)
# Listed so that each heading is a left turn from the one before it (y grows downward).
HEADINGS = (UP, LEFT, DOWN, RIGHT)

def turn_right(heading):
    "The heading one step backward (cyclically) in HEADINGS."
    return HEADINGS[(HEADINGS.index(heading) + 3) % 4]

def turn_around(heading):
    "The heading opposite to heading."
    return HEADINGS[(HEADINGS.index(heading) + 2) % 4]

def turn_left(heading):
    "The heading one step forward (cyclically) in HEADINGS."
    return HEADINGS[(HEADINGS.index(heading) + 1) % 4]

def add(A, B):
    "Vector sum of A and B, component by component (truncated to the shorter vector)."
    paired = zip(A, B)
    return mapt(sum, paired)

def sub(A, B):
    "Vector difference A - B, component by component (truncated to the shorter vector)."
    return tuple(map(operator.sub, A, B))

def neighbors4(point):
    "The four orthogonally adjacent squares, ordered: (x, y-1), (x-1, y), (x+1, y), (x, y+1)."
    x, y = point
    above, below = (x, y - 1), (x, y + 1)
    left, right = (x - 1, y), (x + 1, y)
    return (above, left, right, below)

def neighbors8(point):
    "The eight surrounding squares (orthogonal and diagonal), listed row by row from y-1 to y+1."
    x, y = point
    return tuple((x + dx, y + dy)
                 for dy in (-1, 0, 1)
                 for dx in (-1, 0, 1)
                 if (dx, dy) != (0, 0))

def cityblock_distance(P, Q=origin):
    "Manhattan distance between P and Q: the sum of absolute coordinate differences."
    total = 0
    for p, q in zip(P, Q):
        total += abs(p - q)
    return total

def distance(P, Q=origin):
    "Euclidean (straight-line) distance between P and Q."
    squared = sum((p - q) ** 2 for p, q in zip(P, Q))
    return squared ** 0.5

def king_distance(P, Q=origin):
    "Chess King distance between P and Q: the largest absolute coordinate difference."
    gaps = (abs(p - q) for p, q in zip(P, Q))
    return max(gaps)

################ Debugging 

def trace1(f):
    """Print a trace of the input and output of a function on one line.

    Returns a wrapper around f that calls it, prints
    `name(arg, ..., key=value, ...) = result`, and returns the result.
    """
    from functools import wraps  # not among the file-level functools imports

    @wraps(f)  # keep f's __name__ and __doc__ on the traced function
    def traced_f(*args, **kwds):
        result = f(*args, **kwds)
        shown = [str(arg) for arg in args]
        shown += [f'{key}={value}' for key, value in kwds.items()]
        print('{}({}) = {}'.format(f.__name__, ', '.join(shown), result))
        return result
    return traced_f

def grep(pattern, iterable):
    "Print every line of iterable in which the regex pattern is found."
    matching = (line for line in iterable if re.search(pattern, line))
    for line in matching:
        print(line)
            
class Struct:
    "A bag of attributes: every keyword passed to the constructor becomes a field."
    def __init__(self, **entries):
        self.__dict__.update(entries)

    def __repr__(self):
        pairs = sorted(self.__dict__.items())  # field names are unique, so this sorts by name
        body = ', '.join(f'{name}={value}' for name, value in pairs)
        return f'Struct({body})'

################ A* and Breadth-First Search (tracking states, not actions)

def always(value): return (lambda *args: value)

def Astar(start, moves_func, h_func, cost_func=always(1)):
    """Find a shortest sequence of states from start to a goal state (where h_func(s) == 0).

    moves_func(s) gives the successor states of s, cost_func(s, s2) the cost
    of moving from s to s2 (1 by default), and h_func(s) the heuristic
    estimate of the remaining cost.
    Returns the list of states [start, ..., goal], or None if the frontier
    is exhausted without reaching a goal.
    """
    frontier  = [(h_func(start), start)] # A priority queue, ordered by path length, f = g + h
    previous  = {start: None}  # The state each state was reached from; start has none.
    path_cost = {start: 0}     # The cost of the best path to a state.
    while frontier:
        (f, s) = heappop(frontier)
        if h_func(s) == 0:
            # Walk the `previous` links back to start with a loop: a recursive
            # walk exceeds the recursion limit once the path is ~1000 states long.
            path = []
            while s is not None:
                path.append(s)
                s = previous[s]
            path.reverse()
            return path
        for s2 in moves_func(s):
            g = path_cost[s] + cost_func(s, s2)
            if s2 not in path_cost or g < path_cost[s2]:
                heappush(frontier, (g + h_func(s2), s2))
                path_cost[s2] = g
                previous[s2] = s
    return None

def bfs(start, moves_func, goals):
    "Search from start for a goal via Astar with unit step costs; goals is either a predicate or a container of goal states."
    if callable(goals):
        goal_func = goals
    else:
        goal_func = lambda s: s in goals

    def h_func(s):
        # 0 exactly at goal states, which is how Astar recognizes a goal.
        return 0 if goal_func(s) else 1

    return Astar(start, moves_func, h_func)

## Day 1

In [192]:
nums = Input(1, line_parser=int)
nums[:5]

(155, 157, 156, 172, 170)

In [193]:
increased = lambda pair: pair[0] < pair[1]
quantify(pairwise(nums), increased)

1713

In [194]:
sums = mapt(sum, overlapping(nums, 3))
quantify(pairwise(sums), increased)

1734

## Day 2

In [195]:
moves = Input(2, line_parser=str.split)
moves[:5]

(['forward', '1'], ['down', '3'], ['down', '2'], ['up', '1'], ['down', '7'])

In [196]:
def step(pos, move):
    "Apply one (direction, amount) move to an (x, y) position; an unknown direction leaves it unchanged."
    x, y = pos
    d, length = move[0], int(move[1])
    if d == 'forward':
        return (x + length, y)
    if d == 'down':
        return (x, y + length)
    if d == 'up':
        return (x, y - length)
    return (x, y)

In [197]:
x, y = reduce(step, moves, (0, 0))
x * y

1840243

In [198]:
def step2(pos, move):
    """Apply one (direction, amount) move to an (x, y, aim) state.

    'down'/'up' raise/lower the aim; 'forward' advances x by the amount and
    y by aim * amount. An unknown direction leaves the state unchanged.
    """
    x, y, aim = pos
    d, length = move[0], int(move[1])
    return ((x + length, y + aim * length, aim) if d == 'forward' else
            (x, y, aim + length)                if d == 'down' else
            (x, y, aim - length)                if d == 'up' else
            (x, y, aim))  # keep the 3-tuple shape so the next step can unpack it

In [199]:
x, y, aim = reduce(step2, moves, (0, 0, 0))
x * y

1727785422

## Day 3

In [200]:
# single_digits = lambda s: mapt(int, s.strip())
nums = Input(3)
nums[:5]

('101010000100',
 '100001010100',
 '111100000101',
 '010000000010',
 '001101100010')

In [201]:
# nums = '''
# 00100
# 11110
# 10110
# 10111
# 10101
# 01111
# 00111
# 11100
# 10000
# 11001
# 00010
# 01010
# '''.strip().split()

In [202]:
def nth_common(index, n=1):
    "The n-th most common character at position index across the global nums (the least common one if fewer than n distinct characters occur)."
    column = (num[index] for num in nums)
    ranked = Counter(column).most_common(n)
    value, _count = ranked[-1]
    return value

In [203]:
width = len(nums[0])
top1 = cat(nth_common(pos) for pos in range(width))
top2 = cat(nth_common(pos, n=2) for pos in range(width))
int(top1, base=2) * int(top2, base=2)

3148794

In [204]:
def sift(nums, co2=False):
    """Narrow nums (equal-length bit strings) down to one, a bit position at a time.

    At each position keep the numbers holding the most common bit ('1' on a
    tie) or, with co2=True, the least common bit ('0' on a tie). Stops as
    soon as a single number remains and returns the first remaining number.
    """
    flip = '01' if not co2 else '10'
    for index in range(len(nums[0])):
        counts = Counter(num[index] for num in nums)
        keep = flip[1 if counts['1'] >= counts['0'] else 0]
        kept = [num for num in nums if num[index] == keep]
        if not kept:
            # Every remaining number has the same bit here, so the "least
            # common" bit matches nothing; this position cannot discriminate.
            continue
        nums = kept
        if len(nums) == 1:
            break
    return nums[0]

In [205]:
oxy = sift(nums, co2=False)
co2 = sift(nums, co2=True)
int(oxy, base=2) * int(co2, base=2)

2795310

## Day 4

In [206]:
[nums, *boards] = Groups(4, group_parser=integers, test=False)
len(nums)

100

In [207]:
class Bingo:
    "A bingo board stored as sets of its still-unmarked numbers, one set per row and per column of a 5x5 grid."
    _board: tuple
    _rows_cols: list

    def __init__(self, board):
        "board is a flat sequence of 25 numbers, sliced here into 5 rows and 5 columns."
        size = 5
        self._board = board
        rows = [board[start:start + size] for start in range(0, size * size, size)]
        cols = [board[offset::size] for offset in range(size)]
        self._rows_cols = mapt(set, rows + cols)

    def mark_num(self, num):
        "Remove num from every row and column set that contains it."
        for line in self._rows_cols:
            line.discard(num)

    def is_winner(self):
        "True once some row or column has no unmarked numbers left."
        return any(not line for line in self._rows_cols)

    def unmarked(self):
        "The set of all numbers not yet marked."
        return set().union(*self._rows_cols)

In [208]:
def play(nums, bingos):
    "Call nums in order, marking each board in turn; return (num, board) for the first board to win, or None if none ever does."
    calls = ((num, board) for num in nums for board in bingos)
    for num, board in calls:
        board.mark_num(num)
        if board.is_winner():
            return num, board
    return None

In [209]:
bingos = [Bingo(board) for board in boards]
final, board = play(nums, bingos)

In [210]:
sum(board.unmarked()) * final

14093

In [211]:
def play_to_lose(nums, bingos):
    "Call nums, discarding boards as they win, until exactly one board is left; then return play(nums, [that board]). Returns None if that never happens."
    remaining = bingos
    for num in nums:
        for board in remaining:
            board.mark_num(num)

        remaining = [board for board in remaining if not board.is_winner()]
        if len(remaining) == 1:
            # Replaying nums from the start is harmless: already-marked numbers are no-ops.
            return play(nums, remaining)
    return None

In [212]:
bingos = [Bingo(board) for board in boards]
final, board = play_to_lose(nums, bingos)

sum(board.unmarked()) * final

17388

## Day 5

In [213]:
lines = Input(5, line_parser=integers, test=False)
lines[:6]

((405, 945, 780, 945),
 (253, 100, 954, 801),
 (518, 300, 870, 300),
 (775, 848, 20, 848),
 (586, 671, 469, 671),
 (598, 20, 900, 20))

In [214]:
import itertools

def sign(n):
    "-1, 0 or 1 according to the sign of n."
    return (-1 if n < 0 else
            0 if n == 0 else
            1)

def between(n1, n2):
    "The integers from n1 to n2 inclusive, counting up or down; when n1 == n2, an endless repetition of n1 (meant to be zipped against a finite range)."
    step = sign(n2 - n1)
    return (range(n1, n2 + step, step) if step != 0 else
            itertools.repeat(n1))

def points(line):
    "The integer points from (x1, y1) to (x2, y2) for a line (x1, y1, x2, y2), pairing x and y values step by step."
    x1, y1, x2, y2 = line
    if (x1, y1) == (x2, y2):
        # Both coordinate streams would be endless, so zip would never finish.
        return ((x1, y1),)
    return tuple(zip(between(x1, x2), between(y1, y2)))

In [215]:
def axis_aligned(line):
    "True if the line (x1, y1, x2, y2) is vertical (equal x) or horizontal (equal y)."
    x1, y1, x2, y2 = line
    return x1 == x2 or y1 == y2

axis_aligned(lines[0])

True

In [216]:
straights = filter(axis_aligned, lines)
vents = Counter(point
                for line in straights
                for point in points(line))
quantify(vents.values(), lambda n: n > 1)

7438

In [217]:
vents = Counter(point
                for line in lines
                for point in points(line))
quantify(vents.values(), lambda n: n > 1)

21406

## Day 6

In [218]:
(initial, ) = Input(6, line_parser=integers, test=False)
fish = Counter(initial)
fish

Counter({1: 99, 2: 57, 5: 53, 4: 48, 3: 43})

In [219]:
def step(fish):
    "Advance a Counter of {timer: number of fish} by one day: timers drop by 1; fish that were at 0 restart at 6 and each add a new fish at 8."
    tick = Counter()
    for timer, count in fish.items():
        tick[timer - 1] = count
    spawning = tick[-1]  # fish whose timer just ran out (0 if there are none)
    tick[8] += spawning
    tick[6] += spawning
    del tick[-1]
    return tick

In [220]:
end = repeat(80, step, fish)
sum(end.values())

356190

In [221]:
end = repeat(256, step, fish)
sum(end.values())

1617359101538

## Day 7

In [222]:
(crabs, ) = Input(7, line_parser=integers, test=False)
crabs = np.array(crabs)

In [223]:
opt = int(np.median(crabs))
fuel = sum(np.abs(crabs - opt))
fuel

364898

In [224]:
def cost(diff):
    "Fuel to move diff steps when each step costs one more than the previous: 1 + 2 + ... + diff."
    # diff * (diff + 1) is always even, so floor division is exact and keeps
    # the result an integer instead of a float.
    return diff * (diff + 1) // 2

def total(pos):
    "Sum of cost(distance) over every crab in the global crabs array for gathering at pos."
    fuel = 0
    for diff in np.abs(crabs - pos):
        fuel += cost(diff)
    return fuel

In [225]:
# range() excludes its stop value, so add 1 to consider max(crabs) itself;
# this also keeps the range non-empty when every crab is at the same position.
final = min(range(min(crabs), max(crabs) + 1), key=total)
final

500

In [226]:
int(total(final))

104149091

## Day 8

In [227]:
def sort_str(s):
    "The characters of s in sorted order, joined back into one string."
    ordered = sorted(s)
    return cat(ordered)

def parse_note(line):
    "Parse 'patterns | output' into (set of letter-sorted pattern strings, list of output strings)."
    patterns, output = line.strip().split(' | ')
    signals = {sort_str(pattern) for pattern in patterns.split()}
    return signals, output.split()

notes = Input(8, line_parser=parse_note, test=False)
len(notes)

200

In [228]:
def unique_outputs(note):
    "How many of note's output strings have length 2, 3, 4 or 7."
    _, out = note
    return sum(len(s) in [2, 3, 4, 7] for s in out)
sum(mapt(unique_outputs, notes))

288

In [229]:
segments = 'abcdefg'
digits = {'cf': 1, 'acf': 7, 'bcdf': 4, 'acdeg': 2, 'acdfg': 3, 'abdfg': 5,
          'abcefg': 0, 'abdefg': 6, 'abcdfg': 9, 'abcdefg': 8}

In [230]:
def decode(signals, config):
    "Rewire each signal by mapping config's letters onto segments, then look it up in the digits table (None when it is not a digit pattern)."
    wiring = str.maketrans(cat(config), segments)
    output = []
    for sig in signals:
        rewired = sort_str(sig.translate(wiring))
        output.append(digits.get(rewired))
    return output

@t.curry
def valid(signals, config):
    "True if, under config, the signals decode to exactly the ten digits 0-9."
    decoded = decode(signals, config)
    return set(decoded) == set(range(10))

In [231]:
test = set(mapt(sort_str, 'acedgfb cdfbe gcdfa fbcad dab cefabd cdfgeb eafb cagedb ab'.split()))
config = first_true(permutations(segments, r=7), valid(test))
config

('d', 'e', 'a', 'f', 'g', 'b', 'c')

In [232]:
def solve(note):
    "Find, by trying permutations of the segments, a wiring under which note's signals are valid; decode the output with it and read the digits as one integer."
    signals, output = note
    candidates = permutations(segments, r=7)
    config = first_true(candidates, valid(signals))
    decoded = decode(output, config)
    return int(cat(mapt(str, decoded)))

sum(mapt(solve, notes))

KeyboardInterrupt: 

## Day 9

In [None]:
def heights(line):
    "Each character of the stripped line converted to int, as a tuple."
    stripped = line.strip()
    return mapt(int, stripped)

heightmap = Input(9, line_parser=heights, test=False)

In [None]:
h, w = len(heightmap), len(heightmap[0])
def at(pos, default=9):
    "Height at pos = (x, y) in the global heightmap, or default when pos is outside the w-by-h grid."
    x, y = pos
    if 0 <= x < w and 0 <= y < h:
        return heightmap[y][x]
    return default

def is_low(pos):
    "True if the height at pos is strictly below that of all four neighbors."
    this = at(pos)
    for adj in neighbors4(pos):
        if not this < at(adj):
            return False
    return True

def risk(pos):
    "Risk level of pos: its height plus one."
    height = at(pos)
    return height + 1

In [None]:
valleys = filter(is_low, product(range(w), range(h)))
sum(mapt(risk, valleys))

In [None]:
def concat(lsts):
    "Concatenate an iterable of lists into one new list."
    # chain runs in time linear in the total length; sum(lsts, start=[])
    # re-copies the accumulated list on every addition.
    return list(chain.from_iterable(lsts))

def basins():
    "Flood-fill the grid; return a list of basins, each a list of its positions. Cells of height 9 belong to no basin."
    wall = 9
    visited = set()

    def fill(pos):
        "Unvisited non-wall positions reachable from pos (pos first), marking each as visited."
        if pos in visited or at(pos) == wall:
            return []
        visited.add(pos)
        found = [pos]
        for adj in neighbors4(pos):
            found.extend(fill(adj))
        return found

    return [fill(pos)
            for pos in product(range(w), range(h))
            if pos not in visited and at(pos) != wall]

In [None]:
from math import prod
prod(sorted(mapt(len, basins()))[-3:])

## Day 10

In [None]:
lines = Input(10, test=False)
lines[:2]

In [None]:
opens = '([{<'
closes = ')]}>'
match = dict(zip(opens, closes))

def corrupted(line):
    """Return the first illegal closing character in line, or False if there is none.

    A closer is illegal when it does not match the most recent unclosed
    opener, or when there is no unclosed opener left at all.
    """
    stack = []
    for char in line:
        if char in closes:
            # Check for an empty stack first: popping it would raise IndexError.
            if not stack or char != match[stack.pop()]:
                return char
        else:
            stack.append(char)
    return False

score = {')': 3, ']': 57, '}': 1197, '>': 25137, False: 0}
sum(mapt(score.get, mapt(corrupted, lines)))

In [None]:
def complete(line):
    "The closing characters, in order, needed to close every opener still open at the end of line (closers are not checked against their openers)."
    pending = []
    for char in line:
        if char in closes:
            pending.pop()
        else:
            pending.append(char)
    return [match[opener] for opener in pending[::-1]]

def add_end(score, char):
    "Fold one closing char into a running score: multiply by 5, then add the char's 1-based position in closes."
    points = closes.index(char) + 1
    return score * 5 + points

incompletes = filter(lambda line: not corrupted(line), lines)
scores = [reduce(add_end, line, 0) for line in mapt(complete, incompletes)]

from statistics import median
median(scores)
# mapt(complete, incompletes)

## Day 11

In [None]:
def parse(line):
    "Each character of the stripped line converted to int, as a tuple."
    stripped = line.strip()
    return mapt(int, stripped)
energies = Input(11, line_parser=parse, test=False)
w, h = len(energies[0]), len(energies)

In [None]:
cave = {(x, y): octo
        for y, row in enumerate(energies)
        for x, octo in enumerate(row)}

In [None]:
def flash(cave, pos):
    "Raise the energy at pos by one; exactly when it reaches 10, do the same to each of the eight neighbors that are in cave."
    cave[pos] += 1
    if cave[pos] == 10:
        for adj in neighbors8(pos):
            if adj in cave:
                flash(cave, adj)

def step(cave):
    "Call flash once on every position of cave, in place; returns the same cave."
    for pos in tuple(cave):
        flash(cave, pos)
    return cave

def settle(cave):
    "Reset every entry of cave with energy >= 10 to 0, in place; return how many were reset."
    flashed = [pos for pos, energy in cave.items() if energy >= 10]
    for pos in flashed:
        cave[pos] = 0
    return len(flashed)

def simulate(cave, steps=100):
    "Run `steps` rounds of step + settle on a copy of cave; return the total number of resets that settle reported."
    working = cave.copy()
    total = 0
    for _ in range(steps):
        step(working)
        total += settle(working)
    return total

In [None]:
simulate(cave, steps=100)

In [None]:
def synchonized(cave):
    "Number of the first round (counting from 1) of step + settle in which every entry of cave is reset; runs on a copy of cave."
    cave = cave.copy()
    # Take the target from the cave itself rather than the notebook globals
    # w and h, which other cells rebind to different grids.
    target = len(cave)
    steps = 0
    while True:
        step(cave)
        flashes = settle(cave)
        steps += 1

        if flashes == target:
            return steps

In [None]:
synchonized(cave)

## Day 12

In [None]:
def route(line):
    "Split the stripped line on '-' into a list of its parts."
    stripped = line.strip()
    return stripped.split('-')

routes = Input(12, line_parser=route, test=False)

In [None]:
def make_caves(routes=routes):
    "Build an undirected adjacency map {cave: set of directly connected caves} from (start, end) routes."
    graph = defaultdict(set)
    for a, b in routes:
        graph[a].add(b)
        graph[b].add(a)
    return graph

caves = make_caves()

In [None]:
def n_paths(now, seen=None, path=None):
    """Count the paths in the global caves graph from now to 'end' that enter no lowercase cave twice.

    seen holds the lowercase caves already visited (none by default); path
    is the list of caves visited so far and is kept only for debugging.
    """
    # None defaults replace the shared mutable set()/[] defaults.
    seen = set() if seen is None else set(seen)
    if now == 'end':
        # print(','.join([*path, 'end']))
        return 1
    if now.islower():
        seen.add(now)

    path = [*(path or []), now]
    return sum(n_paths(cave, seen, path)
               for cave in caves[now]
               if cave not in seen)

In [None]:
n_paths('start')

In [None]:
def n_paths2(now, seen=None, backtracked=False, path=None):
    """Count the paths in the global caves graph from now to 'end' where at most one lowercase cave is entered a second time.

    'start' is never re-entered. seen holds the lowercase caves already
    visited, backtracked tells whether the single revisit has been used,
    and path (caves visited so far) is kept only for debugging.
    """
    if now == 'end':
        # print(','.join([*path, 'end']))
        return 1

    # None defaults replace the shared mutable set()/[] defaults.
    seen = set() if seen is None else set(seen)

    if now.islower():
        seen.add(now)

    path = [*(path or []), now]
    total = 0
    for cave in caves[now]:
        if cave == 'start':
            continue
        if cave not in seen:
            total += n_paths2(cave, seen, backtracked, path)
        elif not backtracked:
            # Spend the one allowed revisit on this cave.
            total += n_paths2(cave, seen, True, path)
    return total

In [None]:
n_paths2('start')

## Day 13

In [None]:
newline = lambda s: s.split('\n')
manual = Groups(13, group_parser=newline, test=False)

prefix = 'fold along '
dots = { mapt(int, s.split(',')) for s in manual[0] }
instrs = [s[len(prefix):].split('=') for s in manual[1]]

In [None]:
def fold(dot, ins):
    "Mirror dot across the fold ins = (axis, position) when it lies beyond that position; otherwise return dot unchanged."
    x, y = dot
    axis, line = ins[0], int(ins[1])
    if axis == 'x' and x > line:
        return (2 * line - x, y)
    if axis == 'y' and y > line:
        return (x, 2 * line - y)
    return dot

def fold_all(dots, ins):
    "The set of positions obtained by folding every dot with instruction ins (coinciding dots merge)."
    return set(fold(dot, ins) for dot in dots)

In [None]:
len(fold_all(dots, instrs[0]))

In [None]:
def show_paper(dots):
    "Print dots as a grid of '#' (dot) and '.' (no dot), sized to the largest x and y present."
    w = max(X(dot) for dot in dots) + 1
    h = max(Y(dot) for dot in dots) + 1
    for y in range(h):
        row = ''.join('#' if (x, y) in dots else '.' for x in range(w))
        print(row)

In [None]:
paper = reduce(fold_all, instrs, dots)
show_paper(paper)

## Day 14

In [None]:
newline = lambda s: s.split('\n')
[[template], rules] = Groups(14, group_parser=newline, test=False)
rules = { s[:2]: s[-1] for s in rules }

In [None]:
pair_rules = {
    f'{a}{c}': [f'{a}{b}', f'{b}{c}']
    for (a, c), b in rules.items()
}

In [None]:
def pairs(template):
    "Counter of the adjacent pairs in template."
    adjacent = pairwise(template)
    return Counter(adjacent)

In [None]:
def step(counts):
    """Apply one round of pair insertion to a Counter of pair counts.

    Each pair found in the global pair_rules is replaced by the pairs listed
    for it there; a pair without an entry is carried over unchanged.
    """
    res = Counter()
    for pair, count in counts.items():
        # .get with the pair itself as fallback: a pair without a rule used
        # to raise KeyError.
        for produced in pair_rules.get(pair, (pair,)):
            res[produced] += count
    return res

In [None]:
def count_pairs(counts):
    "Element counts recovered from pair counts: add each pair's count to both of its elements, add 1 for the global template's first and last element, then halve."
    res = Counter()
    for pair, count in counts.items():
        a, b = pair
        res[a] += count
        res[b] += count

    # Every element sits in two pairs except the two ends of the template;
    # topping those up by one makes halving exact.
    for end in (template[0], template[-1]):
        res[end] += 1
    return Counter({k: v // 2 for k, v in res.items()})

In [None]:
final = repeat(10, step, pairs(template))
[most, *_, least] = count_pairs(final).most_common()
most[1] - least[1]

In [None]:
final = repeat(40, step, pairs(template))
[most, *_, least] = count_pairs(final).most_common()
most[1] - least[1]

## Day 15

In [None]:
risks = lambda s: mapt(int, s.strip())
cavern = Input(15, line_parser=risks, test=False)

In [None]:
w = len(cavern)
def moves(pos):
    "The four neighbors of pos whose x and y both lie in range(w) (global w)."
    def inside(point):
        x, y = point
        return 0 <= x < w and 0 <= y < w
    return list(filter(inside, neighbors4(pos)))

def h_func(pos):
    "Manhattan distance from pos to the corner (w - 1, w - 1); 0 exactly at that corner."
    goal = (w - 1, w - 1)
    return cityblock_distance(pos, goal)

def cost(_, pos):
    "Value of the global cavern at pos = (x, y); the first argument is ignored."
    col, row = pos
    return cavern[row][col]

def total_risk(path):
    "Sum of cost(a, b) over each consecutive pair (a, b) of path."
    return sum(cost(a, b) for a, b in pairwise(path))

In [None]:
path = Astar((0, 0), moves, h_func, cost)
total_risk(path)

In [None]:
orig_w = len(cavern)
w = orig_w * 5

def cost(_, pos):
    "Risk at pos on the tiled cavern: the base cell's value plus one per tile step in x and in y, reduced by 9 when it exceeds 9. The first argument is ignored."
    x, y = pos
    tile_x, base_x = divmod(x, orig_w)
    tile_y, base_y = divmod(y, orig_w)
    risk = cavern[base_y][base_x] + (tile_x + tile_y)
    if risk > 9:
        risk -= 9
    return risk

In [None]:
path = Astar((0, 0), moves, h_func, cost)
total_risk(path)

## Day 16

In [None]:
(raw, ) = Input(16, test=False)

In [None]:
def to_binary(raw):
    "Convert a hex string to a binary string, zero-padded on the left to four bits per hex character."
    value = int(raw, 16)
    width = 4 * len(raw)
    return bin(value)[2:].zfill(width)

In [None]:
def int2(binary):
    "Parse base-2 digits given either as one string or as an iterable of strings to be joined."
    text = binary if isinstance(binary, str) else cat(binary)
    return int(text, 2)

In [None]:
def parts(s, indexes):
    "Cut s at the given index (an int) or indexes (a list); return the consecutive slices as a list."
    if isinstance(indexes, int):
        indexes = [indexes]

    bounds = [0, *indexes, len(s)]
    return [s[start:end] for start, end in pairwise(bounds)]

In [None]:
# A literal packet, as built by parse for type 4.
@dataclass
class Lit:
    ver: int  # packet version
    typ: int  # packet type
    val: int  # the decoded literal value

# An operator packet, as built by parse for every other type.
@dataclass
class Op:
    ver: int  # packet version
    typ: int  # packet type; selects the operation in eval_tree
    packets: list  # the parsed sub-packets, in order

In [None]:
O = '0'
I = '1'

In [None]:
def parse_lit(data):
    "Consume 5-character groups from data, collecting the last four characters of each, until a group starts with O; return (value of the collected bits, remaining data). Returns None if data runs out first."
    nibbles = []
    while data:
        group, data = data[:5], data[5:]
        nibbles.append(group[1:])
        if not group.startswith(O):
            continue
        return int2(nibbles), data

In [None]:
def parse_many(data, n=infinity):
    "Parse packets from data until it is empty or n packets have been read; return (packets, remaining data)."
    packets = []
    remaining = n
    while data and remaining > 0:
        packet, data = parse(data)
        packets.append(packet)
        remaining -= 1
    return packets, data

In [None]:
def parse(msg):
    "Parse one packet from the front of the bit string msg; return (Lit or Op, remaining bits)."
    ver_bits, typ_bits, data = parts(msg, [3, 6])
    ver, typ = int2(ver_bits), int2(typ_bits)

    if typ == 4:
        val, rest = parse_lit(data)
        return Lit(ver, typ, val), rest

    length_type, data = parts(data, 1)
    if length_type != O:
        # The next 11 bits give the number of sub-packets.
        n, data = parts(data, 11)
        packets, rest = parse_many(data, int2(n))
    else:
        # The next 15 bits give the total bit length of the sub-packets.
        length, data = parts(data, 15)
        data, rest = parts(data, int2(length))
        packets, _ = parse_many(data)
    return Op(ver, typ, packets), rest

In [None]:
def version_sum(tree):
    "Sum of the ver fields of tree and of every packet nested inside it."
    if isinstance(tree, Lit):
        return tree.ver
    return tree.ver + sum(version_sum(packet) for packet in tree.packets)

In [None]:
tree, _ = parse(to_binary(raw))
version_sum(tree)

In [None]:
# Operation for each operator type; each takes the tuple of evaluated sub-packets.
funcs = {
    0: sum,
    1: t.reduce(operator.mul),          # product
    2: min,
    3: max,
    5: lambda p: int(p[0] > p[1]),      # 1 if first > second, else 0
    6: lambda p: int(p[0] < p[1]),      # 1 if first < second, else 0
    7: lambda p: int(p[0] == p[1]),     # 1 if first == second, else 0
}

def eval_tree(tree):
    "Value of a packet tree: a Lit's val, or the Op's function from funcs applied to its evaluated sub-packets."
    if isinstance(tree, Lit):
        return tree.val
    agg = funcs[tree.typ]
    operands = mapt(eval_tree, tree.packets)
    return agg(operands)

In [None]:
tree, _ = parse(to_binary(raw))
eval_tree(tree)

## Day 17

In [None]:
def command(line):
    "Drop the first 13 characters of the stripped line, split the rest on ', ' into an x part and a y part, and return the integers found in each."
    x_part, y_part = line.strip()[13:].split(', ')
    return integers(x_part), integers(y_part)

(target, ) = Input(17, line_parser=command, test=False)
((xmin, xmax), (ymin, ymax)) = target
start = (0, 0)

In [None]:
def State(vel):
    "Initial probe state for velocity vel: the pair (global start position, vel)."
    return start, vel

def step(state):
    "Advance a (position, velocity) state one tick: move by the velocity, then reduce a positive x velocity by 1 and the y velocity by 1."
    pos, vel = state
    vx, vy = X(vel), Y(vel)
    if vx > 0:
        vx -= 1
    return (add(pos, vel), (vx, vy - 1))

def stop(state):
    """True once the simulation of a probe can end.

    That is when the probe has reached x >= xmin with y <= ymax, or has
    passed the right edge (x > xmax), or has dropped to y <= ymin. The
    caller then checks in_target on this state.
    """
    ((x, y), _) = state
    return ((x >= xmin and y <= ymax) or
            # Strictly greater: a probe exactly on x == xmax whose x velocity
            # has reached 0 can still fall straight down into the target.
            x > xmax or y <= ymin)

def in_target(state):
    "True if the state's position lies within the target rectangle, bounds included."
    (x, y), _vel = state
    return (xmin <= x <= xmax) and (ymin <= y <= ymax)

def max_y(state):
    "Simulate from state until stop(); return the highest y seen before stopping if the final state is in the target, else 0."
    heights = []
    while not stop(state):
        (_x, y), _vel = state
        heights.append(y)
        state = step(state)
    return max(heights) if in_target(state) else 0

In [None]:
starts = product(range(0, xmax), range(0, xmax))
final = max(map(State, starts), key=max_y)
max_y(final)

In [None]:
def lands(state):
    "True if, simulating from state, the first state for which stop() holds is in the target."
    trajectory = repeatedly(step, state)
    final = first_true(trajectory, stop)
    return in_target(final)

In [None]:
starts = product(range(0, xmax + 1), range(ymin, xmax))
quantify(starts, compose(lands, State))

## Day 18

In [None]:
# The code below expects each line to parse to nested lists; literal_eval only
# accepts literals, whereas eval would execute whatever code the file contains.
from ast import literal_eval
raws = Input(18, line_parser=literal_eval, test=False)

In [None]:
@dataclass
class S:
    "A pair node of a snailfish number: a left and a right child plus a link to the parent pair."
    left: 'S | L'
    right: 'S | L'
    parent: 'S | None'

    def __init__(self, pair, parent=None):
        "Build the subtree for a two-item pair; list items become S nodes, anything else an L leaf."
        children = [S(item, self) if isinstance(item, list) else L(item, self)
                    for item in pair]
        self.left, self.right = children
        self.parent = parent

    def __repr__(self):
        return f'[{self.left!r} , {self.right!r}]'

    def __add__(self, other):
        "Join self and other under a new root pair, re-parenting both (no reduction is done here)."
        root = S([None, None])  # placeholder leaves, replaced right away
        root.left, root.right = self, other
        self.parent = other.parent = root
        return root

    @property
    def is_root(self):
        "True if this pair has no parent."
        return self.parent is None

@dataclass
class L:
    "A leaf of a snailfish number: a value plus a link to the parent pair."
    val: int
    parent: S

    def __repr__(self):
        return f'{self.val}'

# Stand-in leaf returned by left_wing/right_wing when no neighboring leaf exists.
null = L(-99, None)

In [None]:
def is_L(n):
    "True if n is an L leaf."
    return isinstance(n, L)

def is_S(n):
    "True if n is an S pair."
    return isinstance(n, S)

In [None]:
def leftmost(s):
    "Follow left children from s until reaching a node that is not an S; return that node."
    node = s
    while isinstance(node, S):
        node = node.left
    return node

def rightmost(s):
    "Follow right children from s until reaching a node that is not an S; return that node."
    node = s
    while isinstance(node, S):
        node = node.right
    return node

In [None]:
def explode_one(s, depth=0):
    "Explode the first pair, searching left to right, that is found at depth 4 (depth counts levels below the initial call); return True if one was exploded."
    if is_L(s):
        return False

    if depth >= 4:
        if is_S(s):
            explode(s)
            return True
        return False

    return (explode_one(s.left, depth + 1) or
            explode_one(s.right, depth + 1))

In [None]:
def explode(s):
    "Explode pair s: add its left value to the nearest leaf on its left and its right value to the nearest leaf on its right (or to the null stand-in), then replace s in its parent by a 0 leaf."
    left_wing(s).val += s.left.val
    right_wing(s).val += s.right.val

    parent = s.parent
    zero = L(0, parent)
    if parent.left is s:
        parent.left = zero
    else:
        parent.right = zero

def left_wing(s):
    "The nearest leaf to the left of subtree s, or null if there is none."
    node = s
    while not node.is_root:
        parent = node.parent
        if parent.left is not node:
            # node is the right child: the neighbor is the last leaf of the left sibling.
            return rightmost(parent.left)
        node = parent
    return null

def right_wing(s):
    "The nearest leaf to the right of subtree s, or null if there is none."
    node = s
    while not node.is_root:
        parent = node.parent
        if parent.right is not node:
            # node is the left child: the neighbor is the first leaf of the right sibling.
            return leftmost(parent.right)
        node = parent
    return null

In [None]:
from math import floor, ceil

def split(n, parent):
    "New pair [n / 2 rounded down, n / 2 rounded up] attached to parent."
    half = n / 2
    return S([floor(half), ceil(half)], parent)

def split_one(s):
    "Split the first leaf, searching left to right under s, whose value is 10 or more, replacing it in its parent; return True if one was split."
    if is_S(s):
        return split_one(s.left) or split_one(s.right)

    if s.val < 10:
        return False

    parent = s.parent
    replacement = split(s.val, parent)
    if parent.left is s:
        parent.left = replacement
    else:
        parent.right = replacement
    return True

In [None]:
def action(s):
    "Perform one reduction on s: an explode if one applies, otherwise a split; return True if anything changed."
    if explode_one(s):
        return True
    return split_one(s)

In [None]:
def simplify(s):
    "Apply action to s repeatedly until it reports no change; return s."
    while action(s):
        pass
    return s

In [None]:
def snail_add(s1, s2):
    "Join s1 and s2 with + and simplify the result."
    return simplify(s1 + s2)

def snail_sum(raws):
    "Build an S tree from each raw nested list and fold them together, left to right, with snail_add."
    trees = map(S, raws)
    return reduce(snail_add, trees)

In [None]:
def magnitude(s):
    "A leaf's value, or for a pair 3 * magnitude(left) + 2 * magnitude(right)."
    if not is_L(s):
        return 3 * magnitude(s.left) + 2 * magnitude(s.right)
    return s.val

In [None]:
magnitude(snail_sum(raws))

In [None]:
pairs = permutations(raws, 2)
max(map(compose(magnitude, snail_sum), pairs))

## Day 19

In [None]:
def scan(group):
    "Parse one group: skip its first line, then turn each remaining comma-separated line into a row of ints in a numpy array."
    _header, *rows = group.split('\n')
    return np.array([mapt(int, row.split(',')) for row in rows])

report = Groups(19, group_parser=scan, test=False)

In [None]:
from scipy.spatial.transform import Rotation as R

I = np.identity(4, dtype=int)

def affine(points):
    "Append a column of ones to points."
    ones = np.ones(len(points), dtype=int)
    return np.column_stack([points, ones])

def apply(tr, points):
    "Multiply the rows of affine(points) by the matrix tr and return the first three columns of the result."
    transformed = affine(points) @ tr
    return transformed[:, :3]

def rotate(axis, deg):
    "Copy of I whose upper-left 3x3 block is the rotation R.from_euler(axis, deg, degrees=True), rounded to ints."
    rotation = R.from_euler(axis, deg, degrees=True)
    res = I.copy()
    res[:3, :3] = rotation.as_matrix().round().astype(int)
    return res

def translate(vec3):
    "Copy of I with vec3 written into the first three entries of its last row."
    res = I.copy()
    res[3, :3] = vec3
    return res

In [None]:
# `axes` was not defined anywhere; rotate() hands the axis name to
# R.from_euler, so use the three axis letters.
axes = 'xyz'
x, y, z = [rotate(ax, 90) for ax in axes]
# Candidate scanner orientations, built from quarter turns about each axis.
transforms = [
    I,
    z,
    z @ z,
    z @ z @ z,
    y,
    y @ z,
    y @ z @ z,
    y @ z @ z @ z,
    y @ y,
    y @ y @ z,
    y @ y @ z @ z,
    y @ y @ z @ z @ z,
    y @ y @ y,
    y @ y @ y @ z,
    y @ y @ y @ z @ z,
    y @ y @ y @ z @ z @ z,
    x,
    x @ z,
    x @ z @ z,
    x @ z @ z @ z,
    x @ y @ y,
    x @ y @ y @ z,
    x @ y @ y @ z @ z,
    x @ y @ y @ z @ z @ z,
]

In [None]:
def matches(i, j):
    "Try each entry of transforms on report[j]; if the most frequent difference to report[i]'s points occurs at least 12 times, return that transform combined with the translation by it. Returns False if no transform qualifies."
    r1, r2 = report[i], report[j]
    for A in transforms:
        points = apply(A, r2)
        diffs = (tuple(p - q) for p, q in product(r1, points))
        [(offset, n)] = Counter(diffs).most_common(1)
        if n >= 12:
            return A @ translate(offset)
    return False

In [None]:
nodes = list(range(len(report)))

def beacon_map(start=0):
    "Starting from node start (mapped to I), repeatedly match visited nodes against unvisited ones; return {node: matches(...) result composed with the transform of the node it was matched from}."
    pending = [start]
    unvisited = set(nodes)
    transforms = { start: I }
    while pending:
        current = pending.pop()
        if current not in unvisited:
            continue
        unvisited.remove(current)
        for node in unvisited:
            A = matches(current, node)
            if A is False:
                continue
            transforms[node] = A @ transforms[current]
            pending.append(node)
    return transforms

bmap = beacon_map()

In [None]:
def absolute(node):
    "report[node] transformed by bmap[node], as a tuple of coordinate tuples."
    moved = apply(bmap[node], report[node])
    return mapt(tuple, moved)

beacons = mapt(compose(set, absolute), nodes)
len(reduce(operator.or_, beacons))

In [None]:
def offset(transform):
    "The first three entries of the last row (row 3) of transform."
    return transform[3, :3]

def dist(pair):
    "Sum of absolute coordinate differences between the two points of pair."
    a, b = pair
    delta = a - b
    return np.abs(delta).sum()

pairs = combinations(mapt(offset, bmap.values()), 2)
max(mapt(dist, pairs))

## Day 20

In [None]:
[enhance], image = Groups(20, test=False)

def Image(px, fill):
    "A defaultdict holding the entries of px, whose missing keys read as fill."
    factory = always(fill)
    return defaultdict(factory, px)

Digit = '.#'.index

image = Image({(x, y): Digit(val)
               for y, row in enumerate(image)
               for x, val in enumerate(row)},
              0)

In [None]:
def square(center):
    "Yield the nine coordinates of the 3x3 block around `center`, one row (fixed dy) at a time."
    cx, cy = center
    for dy, dx in product(ints(-1, 1), ints(-1, 1)):
        yield cx + dx, cy + dy

In [None]:
def output(digits):
    "Read `digits` as a binary number (first digit most significant) and return Digit(enhance[that index])."
    index = 0
    for weight, bit in enumerate(reversed(digits)):
        index += bit << weight
    return Digit(enhance[index])

In [None]:
import itertools
def next_fill(image):
    "The output() value for a block of nine pixels that all equal the image's current default value."
    background = image.default_factory()
    return output([background] * 9)

In [None]:
def step(image):
    """Return a new Image: one enhancement pass over `image`, grown by one pixel on every side.

    Each coordinate of the grown box maps to output() of the nine pixels
    that square() yields around it; the fill value of the result comes from
    next_fill(image).

    Bounds are taken per axis.  min()/max() over the (x, y) key tuples
    compare lexicographically and only agree with the true bounding box
    when the keys form a complete rectangle.

    Pixels are read with image[...]; since Image() builds a defaultdict,
    reads outside the stored keys are added to `image` with its fill value.
    """
    xs = [x for x, _ in image]
    ys = [y for _, y in image]
    coords = product(ints(min(xs) - 1, max(xs) + 1),
                     ints(min(ys) - 1, max(ys) + 1))
    fill = next_fill(image)
    pixels = {coord: output(mapt(image.__getitem__, square(coord)))
              for coord in coords}
    return Image(pixels, fill)

In [None]:
def display(image):
    """Print `image` row by row: '#' for a truthy pixel, '.' for a falsy one.

    The printed box spans the smallest and largest x and y found among the
    keys, each axis taken separately (min()/max() over the key tuples would
    compare lexicographically and can pick the wrong corners when the keys
    do not form a full rectangle).  An empty image prints nothing.
    """
    if not image:
        return
    xs = [x for x, _ in image]
    ys = [y for _, y in image]
    columns = range(min(xs), max(xs) + 1)
    for y in range(min(ys), max(ys) + 1):
        print(''.join('#' if image[(x, y)] else '.' for x in columns))

In [None]:
# Part 1: presumably applies `step` to `image` 2 times (repeat is defined
# elsewhere -- confirm), then counts pixels via quantify().
result = repeat(2, step, image)
quantify(result.values())

In [None]:
# Part 2: same as above with 50 applications of `step`.
result = repeat(50, step, image)
quantify(result.values())

## Day 21

In [302]:
# Starting positions: start[0] for the first player, start[1] for the second.
start = (8, 4)

# test
# start = (4, 8)

In [303]:
def Player(pos, score=0):
    "A player is just the tuple (pos, score)."
    return (pos, score)

def move(player, n):
    "Advance `player` n spaces around the 1..10 track and add the landing space to the score."
    pos, score = player
    # Shift to 0..9, wrap, shift back: space 10 stays 10 instead of becoming 0.
    landed = (pos + n - 1) % 10 + 1
    return Player(landed, score + landed)

In [304]:
def Die():
    "A fresh die that cycles through 1, 2, ..., 100 forever."
    return cycle(ints(1, 100))

@dataclass
class Game:
    # the two players, each a (pos, score) tuple
    players: list
    # index into `players` of the player who moves next
    turn: int = 0
    # number of die rolls made so far
    n_rolls: int = 0
    # source of rolls; a new Die() per game, left out of the repr
    die: Die = field(default_factory=Die, repr=False)

In [305]:
def turn(game):
    """Play one turn, mutating and returning `game`.

    The current player moves by the sum of the next three die values, the
    roll counter grows by 3, and the turn passes to the other player.
    """
    who = game.turn
    roll_total = sum(head(game.die, 3))
    game.players[who] = move(game.players[who], roll_total)
    game.n_rolls += 3
    game.turn = 1 - who
    return game

def won(game):
    "True when at least one of the two players has a score of 1000 or more."
    (_, first_score), (_, second_score) = game.players
    return max(first_score, second_score) >= 1000

def loser(game):
    "The first player's score if it is below 1000, otherwise the second player's score."
    (_, first_score), (_, second_score) = game.players
    if first_score >= 1000:
        return second_score
    return first_score

In [306]:
# Part 1: take turns until won(game) holds, then multiply the losing score
# by the number of die rolls.  (`repeatedly` is defined elsewhere.)
game = Game([Player(start[0]), Player(start[1])])
game = first_true(repeatedly(turn, game), won)
loser(game) * game.n_rolls

504972

In [310]:
# Faces of the three-sided die.
dirac = [1, 2, 3]
# For each possible total of three rolls: how many of the 27 roll
# combinations produce it.
distribution = Counter(a + b + c for a in dirac for b in dirac for c in dirac)

In [311]:
# Part 2 game state: the two players plus whose turn it is (0 or 1).
# Rebinding `Game` replaces the earlier dataclass of the same name.
Game = namedtuple('Game', 'p1,p2,turn')

# Score at which universes() stops recursing and declares a winner.
stop = 21

def mult(c, A):
    "Tuple of c * x for each x in A."
    return tuple(map(lambda x: c * x, A))

def next_game(game, roll):
    "State after the player whose turn it is moves by `roll`; the turn flips to the other player."
    p1, p2, turn = game
    if turn == 0:
        return Game(move(p1, roll), p2, 1)
    return Game(p1, move(p2, roll), 0)

@lru_cache(maxsize=None)
def universes(game):
    """Return a pair (wins for player 1, wins for player 2) over all continuations of `game`.

    A finished game (some score >= stop) counts as a single win.  Otherwise
    each roll total in `distribution` is followed recursively, its result is
    scaled by the total's frequency, and the scaled pairs are combined with
    `add`.  Results are memoised per state.
    """
    (_, score1), (_, score2), _ = game
    if score1 >= stop:
        return (1, 0)
    if score2 >= stop:
        return (0, 1)

    branches = [mult(freq, universes(next_game(game, roll)))
                for roll, freq in distribution.items()]
    return reduce(add, branches)

In [312]:
# Part 2: count winning universes from the starting state with player 1
# (turn 0) to move; the notebook shows player 1's count.
game = Game(Player(start[0]), Player(start[1]), 0)
w1, w2 = universes(game)
w1

446968027750017

## Day 22

In [494]:
@dataclass
class Range:
    "An interval a..b with both ends included."
    a: int
    b: int

    @property
    def width(self):
        "How many integers the interval covers; zero or negative when b < a."
        return 1 + self.b - self.a

    def __and__(self, other):
        "Overlap of two intervals (may have non-positive width)."
        start = max(self.a, other.a)
        end = min(self.b, other.b)
        return Range(start, end)

    def __bool__(self):
        "Truthy only when the width is positive."
        return 0 < self.width

    def __repr__(self):
        return f'{self.a}..{self.b}'

In [495]:
@dataclass
class Cube:
    x: Range ; y: Range ; z: Range
    sign: int = 1
    
    @property
    def vol(self):
        return self.sign * self.x.width * self.y.width * self.z.width
    
    def __neg__(self):
        return Cube(self.x, self.y, self.z, -1 * self.sign)
    
    def __and__(self, other):
        return Cube(self.x & other.x, self.y & other.y, self.z & other.z,
                    other.sign)
    
    def __bool__(self):
        return bool(self.x and self.y and self.z)
    
    def __repr__(self):
        sign = '' if self.sign == 1 else '-'
        return f'{sign}({self.x}, {self.y}, {self.z})'

In [496]:
def step(line):
    "Parse one input line into (power, Cube): the first word, and a Cube built from the six integers in the second."
    power, coords = line.strip().split()
    x1, x2, y1, y2, z1, z2 = integers(coords)
    spans = (Range(x1, x2), Range(y1, y2), Range(z1, z2))
    return power, Cube(*spans)

In [501]:
# Day 22 input, each line parsed by step() into (power, Cube).
steps = Input(22, line_parser=step, test=False)

In [502]:
def execute(steps):
    """Apply (power, cube) steps in order and return the summed signed volume.

    A list of signed cubes is kept whose volumes always sum to the lit
    volume.  For every incoming cube, each non-empty overlap with a cube
    already on the list is appended with the opposite sign of that listed
    cube, cancelling its contribution inside the overlap.  An 'on' cube is
    then appended itself; any other power only cancels.

    The incoming cube's own sign is irrelevant here: Cube.__and__ gives the
    overlap the sign of the listed cube, so no negation of 'off' cubes is
    needed.
    """
    cubes = []
    for power, cube in steps:
        # Take the overlaps from a snapshot before appending anything, so the
        # list is never mutated while it is being iterated.  Newest-first
        # order is kept from the original traversal.
        overlaps = [cube & other for other in reversed(cubes)]
        cubes.extend(-overlap for overlap in overlaps if overlap)

        if power == 'on':
            cubes.append(cube)
    return sum(cube.vol for cube in cubes)

In [503]:
# The -50..50 region on all three axes.
small = Cube(*(Range(-50, 50) for _ in range(3)))

def within(step, space=small):
    "Overlap of `space` with the step's cube (step[1]); falsy when they do not overlap."
    cube = step[1]
    return space & cube

# Part 1: run only the steps whose cube overlaps `small`.
execute(filter(within, steps))

582644

In [506]:
# Part 2: run every step; tq (tqdm) shows a progress bar.
execute(tq(steps))

  0%|          | 0/420 [00:00<?, ?it/s]

1263804707062415

## Day 23

In [117]:
burrow = list(Input(23, test=False))
# Wrap rows 3 and 4 in '##' on both sides (in the output shown below every
# row then has the same width).
burrow[3] = f'##{burrow[3]}##'
burrow[4] = f'##{burrow[4]}##'
burrow

['#############',
 '#...........#',
 '###B#A#B#C###',
 '###D#A#D#C###',
 '#############']

In [119]:
# Every coordinate of the burrow that is not a wall ('#').
spots = {coord for coord, val in coords(burrow) if val != '#'}

# For each open coordinate, the open coordinates among its neighbors4().
edges = {spot: set(neighbors4(spot)) & spots for spot in spots}

# Row-1 cells marked '.' in the template; the cells marked 'x' are left out.
hallway = set((x, 1)
              for x, c in enumerate('#..x.x.x.x..#')
              if c == '.')

# Cells of each pod type's room, row 3 first and then row 2.
rooms = {pod: ((column, 3), (column, 2))
         for pod, column in zip('ABCD', (3, 5, 7, 9))}

# All room cells as one flat tuple.
all_rooms = sum(rooms.values(), ())

# Multiplier applied per step in energy(): 1, 10, 100, 1000.
multipliers = {pod: 10 ** power for power, pod in enumerate('ABCD')}

pods = set('ABCD')

# tuple[(coord, pod)]
State = tuple

In [120]:
# Initial state: one (coord, pod) entry for every burrow cell holding a pod letter.
state = State((coord, pod)
              for coord, pod in coords(burrow)
              if pod in pods)

In [121]:
def reachable(start, moves_func):
    "Set of all nodes reachable from `start` by repeatedly following moves_func(node), `start` itself excluded."
    seen = set()
    stack = [start]
    while stack:
        node = stack.pop()
        if node not in seen:
            seen.add(node)
            stack.extend(moves_func(node))
    return seen - {start}

In [122]:
def adjacent(pods):
    "Return a function: coord -> the cells in edges[coord] that are not keys of `pods`."
    taken = set(pods)
    return lambda coord: edges[coord] - taken

In [123]:
null = set()

@t.curry
def pod_can_go(state, coord):
    """Set of cells the pod standing on `coord` may move to in `state`.

    * Empty when the pod is already in its own room and that room holds no
      other kind of pod, or when it is in the hallway and its room still
      holds another kind of pod.
    * From any room cell: the reachable hallway cells.
    * Otherwise: the first free cell of its own room (in `rooms` order), if
      that cell is reachable.
    """
    placed = dict(state)
    pod = placed[coord]
    home = rooms[pod]

    # Ready = every occupied cell of the home room holds this kind of pod.
    home_ready = all(placed.get(cell, pod) == pod for cell in home)

    if home_ready and coord in home:           # settled
        return null
    if coord in hallway and not home_ready:    # waiting
        return null

    reach = reachable(coord, adjacent(placed))

    if coord in all_rooms:
        return reach & hallway

    target = first(cell for cell in home if cell not in placed)
    return reach & {target}

In [124]:
def moves(state):
    """Yield every state reachable from `state` by moving one pod.

    In a yielded tuple the moved pod comes first, followed by the entries
    that came after it in `state`, then the entries that came before it.
    """
    # NOTE(review): tuples that place the same pods in a different order are
    # distinct values; a canonical order might shrink the search -- confirm
    # against how Astar compares states.
    can_go = pod_can_go(state)
    state = tuple(state)
    for k, (coord, pod) in enumerate(state):
        others = state[k + 1:] + state[:k]
        for next_coord in can_go(coord):
            yield ((next_coord, pod),) + others

In [125]:
def pods_left(state):
    "quantify() over the pods in `state` whose coordinate is not a cell of their own room."
    def misplaced(entry):
        coord, pod = entry
        return coord not in rooms[pod]

    return quantify(pod for _, pod in filter(misplaced, state))

In [126]:
def energy(s1, s2):
    """Cost of going from state s1 to state s2.

    The two coordinates in which occupied(s1) and occupied(s2) differ are the
    ends of the move; the cost is their cityblock distance times the
    multiplier of the pod that stands on one of them in s1.
    """
    a, b = occupied(s1) ^ occupied(s2)  # exactly two cells may differ
    mover = first(kind for where, kind in s1 if where in (a, b))
    return cityblock_distance(a, b) * multipliers[mover]

In [127]:
%%time

# Search from `state`; presumably Astar(start, successors, heuristic, step cost) -- confirm against its definition.
path = Astar(state, moves, pods_left, energy)

CPU times: user 31.5 s, sys: 116 ms, total: 31.6 s
Wall time: 31.8 s


In [128]:
# Total cost: energy() summed over consecutive states of the path.
sum(energy(s1, s2) for s1, s2 in pairwise(path))

16506

**Part 2**

In [129]:
# Part 2 burrow: same padding of rows 3 and 4 as in part 1, then two extra
# rows inserted at indices 3 and 4.
burrow = list(Input(23, test=False))
burrow[3] = f'##{burrow[3]}##'
burrow[4] = f'##{burrow[4]}##'
burrow.insert(3, '###D#C#B#A###')
burrow.insert(4, '###D#B#A#C###')
burrow

['#############',
 '#...........#',
 '###B#A#B#C###',
 '###D#C#B#A###',
 '###D#B#A#C###',
 '###D#A#D#C###',
 '#############']

In [130]:
# Initial state for the taller burrow: one (coord, pod) entry per pod letter.
state = State((coord, pod)
              for coord, pod in coords(burrow)
              if pod in pods)

In [131]:
# Rebuild the lookup tables for the taller burrow.
spots = {coord for coord, val in coords(burrow) if val != '#'}

edges = {spot: set(neighbors4(spot)) & spots for spot in spots}

# Four cells per room now, row 5 first and row 2 last.
rooms = {pod: tuple((column, row) for row in (5, 4, 3, 2))
         for pod, column in zip('ABCD', (3, 5, 7, 9))}

all_rooms = sum(rooms.values(), ())

In [132]:
%%time

# Same search as in part 1, now using the rebuilt tables and the new start state.
path = Astar(state, moves, pods_left, energy)

CPU times: user 1min 16s, sys: 551 ms, total: 1min 17s
Wall time: 1min 18s


In [133]:
# Total cost: energy() summed over consecutive states of the path.
sum(energy(s1, s2) for s1, s2 in pairwise(path))

48304

## Day 24

In [87]:
def parse(str):
    "Split one instruction line into its whitespace-separated tokens."
    # NOTE: the parameter name shadows the builtin `str`; kept so the
    # callable's signature stays as it was.
    return str.strip().split()

instrs = Input(24, line_parser=parse, test=False)

In [92]:
# Opcode -> Python statement template; {0} and {1} stand for the first and
# second operand of the instruction.
commands = {
    'inp': '{0} = next(serial)',
    'add': '{0} = {0} + {1}',
    'mul': '{0} = {0} * {1}',
    'div': '{0} = {0} // {1}',
    'mod': '{0} = {0} % {1}',
    'eql': '{0} = int({0} == {1})',
}

In [127]:
# The program is read in blocks of 18 instructions; the last token of
# instructions 4, 5 and 15 of each block is that block's parameter triple.
block_len = 18
block_params = [4, 5, 15]

def comple(instrs):
    "Return one [a, b, c] list of ints per block_len-sized block of `instrs`."
    starts = range(0, len(instrs), block_len)
    return [[int(instrs[s:s + block_len][k][-1]) for k in block_params]
            for s in starts]

In [143]:
Constraint = namedtuple('C', 'i, j, delta')
def constraints(params):
    """Pair up blocks with a stack and return one Constraint per pair.

    A block with a == 1 pushes (its index, its c).  Any other block pops the
    latest push and yields Constraint(pushed index, own index, pushed c + own b).
    """
    found = []
    pending = []
    for idx, (a, b, c) in enumerate(params):
        if a != 1:
            opener, pushed_c = pending.pop()
            found.append(Constraint(opener, idx, pushed_c + b))
            continue
        pending.append((idx, c))
    return found

In [170]:
# Allowed digit values (digits() is defined elsewhere; presumably 1..9).
valid = set(digits(123456789))

@t.curry
def fulfill(rules, tiebreak=max):
    """Build a 14-entry digit list that satisfies every (i, j, delta) rule.

    For each rule, all pairs (di, di + delta) with both members in `valid`
    are candidates; `tiebreak` (max or min) picks one, which is stored at
    positions i and j.  Positions no rule mentions stay 0.
    """
    serial = [0] * 14
    for i, j, delta in rules:
        options = ((di, di + delta) for di in valid if di + delta in valid)
        serial[i], serial[j] = tiebreak(options)
    return serial

In [173]:
# Part 1 pipeline, applied left to right: block parameters -> constraints ->
# digits (fulfill's default tiebreak is max) -> joined string -> printed.
process = t.compose_left(
    comple,
    constraints,
    fulfill,
    t.map(str),
    cat,
    print,
)
process(instrs)

99911993949684


In [174]:
# Part 2 pipeline: identical except that fulfill picks with min.
process = t.compose_left(
    comple,
    constraints,
    fulfill(tiebreak=min),
    t.map(str),
    cat,
    print,
)
process(instrs)

62911941716111


## Day 25

In [208]:
# Day 25 input; indexed below as rows of characters.
sea = Input(25, test=False)

In [209]:
# Grid width (length of the first row) and height (number of rows).
w, h = len(sea[0]), len(sea)

# coord -> 'v' or '>' for every occupied cell; other cells are left out.
cukes = {coord: cuke for coord, cuke in coords(sea) if cuke in 'v>'}

In [210]:
def step(last):
    """Advance the sea floor by one step.

    `last` is a pair whose first item is the coord -> cuke dict (the second
    item is ignored).  On a copy of that dict, every '>' whose cell to the
    east (wrapping at width w) is free moves, all decided before any of them
    moves; then the same happens for every 'v' towards the south (wrapping at
    height h).  Returns (new dict, number of cukes that moved).
    """
    previous, _ = last
    grid = previous.copy()

    def advance(kind, target):
        # Decide all movers against the current grid first, then move them.
        movers = {pos for pos, cuke in grid.items()
                  if cuke == kind and target(pos) not in grid}
        for pos in movers:
            del grid[pos]
            grid[target(pos)] = kind
        return len(movers)

    moved = advance('>', lambda pos: ((pos[0] + 1) % w, pos[1]))
    moved += advance('v', lambda pos: (pos[0], (pos[1] + 1) % h))
    return grid, moved

In [211]:
# Iterate step() from (cukes, None) and keep the enumerate index of the first
# result whose move count is 0.  (`repeatedly` is defined elsewhere; whether
# index 0 is the initial pair or the first step is not visible here.)
steps, _ = first_true(enumerate(repeatedly(step, (cukes, None))),
                      lambda state: state[1][1] == 0)
steps

549