In [162]:
#### IMPORTS

import re
import abc
import operator
import numpy as np
from collections import Counter, defaultdict, namedtuple, deque, abc
from itertools   import (permutations, combinations, chain, cycle, product, islice, 
                         takewhile, zip_longest, starmap, count as count_from)
from functools   import lru_cache, reduce
from heapq import (heappush, heappop, nlargest, nsmallest)

from pprint import pprint as p, pformat as pf
import toolz.curried as t
from tqdm import tqdm_notebook as tq
from dataclasses import dataclass

#### CONSTANTS

alphabet = 'abcdefghijklmnopqrstuvwxyz'
ALPHABET = alphabet.upper()
infinity = float('inf')

#### SIMPLE UTILITY FUNCTIONS

cat = ''.join

def ints(start, end, step=1):
    "The integers from start to end, inclusive: range(start, end+1)"
    return range(start, end + 1, step)

def first(iterable, default=None): 
    "The first item in an iterable, or default if it is empty."
    return next(iter(iterable), default)

def head(iterable, n=5):
    "The first n items in an iterable"
    return tuple(islice(iterable, n))

def tail(iterable, n=1):
    "Skip n items in an iterable"
    return islice(iterable, n, None)

def first_true(iterable, pred=None, default=None):
    """Returns the first true value in the iterable.
    If no true value is found, returns *default*
    If *pred* is not None, returns the first item
    for which pred(item) is true."""
    # first_true([a,b,c], default=x) --> a or b or c or x
    # first_true([a,b], fn, x) --> a if fn(a) else b if fn(b) else x
    return next(filter(pred, iterable), default)

def nth(iterable, n, default=None):
    "Returns the nth item of iterable, or a default value"
    return next(islice(iterable, n, None), default)

def upto(iterable, maxval):
    "From a monotonically increasing iterable, generate all the values <= maxval."
    # Why <= maxval rather than < maxval? In part because that's how Ruby's upto does it.
    return takewhile(lambda x: x <= maxval, iterable)

identity = lambda x: x

def quantify(iterable, pred=bool):
    "Count how many times the predicate is true of an item in iterable."
    return sum(map(pred, iterable))

def multimap(items):
    "Given (key, val) pairs, return {key: [val, ....], ...}."
    result = defaultdict(list)
    for (key, val) in items:
        result[key].append(val)
    return result

def overlapping(iterable, n):
    """Generate all (overlapping) n-element subsequences of iterable.
    overlapping('ABCDEFG', 3) --> ABC BCD CDE DEF EFG"""
    if isinstance(iterable, abc.Sequence):
        yield from (iterable[i:i+n] for i in range(len(iterable) + 1 - n))
    else:
        result = deque(maxlen=n)
        for x in iterable:
            result.append(x)
            if len(result) == n:
                yield tuple(result)
                
def pairwise(iterable):
    "s -> (s0,s1), (s1,s2), (s2, s3), ..."
    return overlapping(iterable, 2)

def mapt(fn, *args): 
    "Do a map, and make the results into a tuple."
    return tuple(map(fn, *args))

def map2d(fn, grid):
    "Apply fn to every element in a 2-dimensional grid."
    return tuple(mapt(fn, row) for row in grid)

def flatmap(fn, *args):
    "Do a map and a one-level flatten"
    return tuple(chain.from_iterable(map(fn, *args)))

def repeat(n, fn, arg, *args, **kwds):
    "Repeat arg = fn(arg) n times, return arg."
    return nth(repeatedly(fn, arg, *args, **kwds), n)

def repeatedly(fn, arg, *args, **kwds):
    "Yield arg, fn(arg), fn(fn(arg)), ..."
    yield arg
    while True:
        arg = fn(arg, *args, **kwds)
        yield arg
        
def repeatedly1(fn, arg, *args, **kwds):
    "Yield fn(arg), fn(fn(arg)), ..."
    return tail(repeatedly(fn, arg, *args, **kwds))

def compose(f, g): 
    "The function that computes f(g(x))."
    return lambda x: f(g(x))

#### FILE INPUT AND PARSING

def Input(day, line_parser=str.strip, test=False, file_template='data/2020/{}.txt'):
    "For this day's input file, return a tuple of each line parsed by `line_parser`."
    return mapt(line_parser, open(file_template.format(
        f'{day}test' if test else day
    )))

def Groups(day, group_parser=str.split, test=False):
    entire = Input(day, t.identity, test)
    groups = ''.join(entire).split('\n\n')
    return mapt(group_parser, groups)

@t.curry
def Tokens(line, sep=','):
    "Splits line into delimited tokens"
    return line.strip().split(sep)

def integers(text): 
    "A tuple of all integers in a string (ignore other characters)."
    return mapt(int, re.findall(r'-?\d+', text))

def digits(number):
    "Tuple of digits in number"
    return mapt(int, str(number))

################ 2-D points implemented using (x, y) tuples

def X(point): return point[0]
def Y(point): return point[1]
def Z(point): return point[2]

origin = (0, 0)
HEADINGS = UP, LEFT, DOWN, RIGHT = (0, -1), (-1, 0), (0, 1), (1, 0)

def turn_right(heading): return HEADINGS[HEADINGS.index(heading) - 1]
def turn_around(heading):return HEADINGS[HEADINGS.index(heading) - 2]
def turn_left(heading):  return HEADINGS[HEADINGS.index(heading) - 3]

def add(A, B): 
    "Element-wise addition of two n-dimensional vectors."
    return mapt(sum, zip(A, B))

def sub(A, B): 
    "Element-wise subtraction of two n-dimensional vectors."
    return tuple(a - b for a, b in zip(A, B))

def neighbors4(point): 
    "The four neighboring squares."
    x, y = point
    return (          (x, y-1),
            (x-1, y),           (x+1, y), 
                      (x, y+1))

def neighbors8(point): 
    "The eight neighboring squares."
    x, y = point 
    return ((x-1, y-1), (x, y-1), (x+1, y-1),
            (x-1, y),             (x+1, y),
            (x-1, y+1), (x, y+1), (x+1, y+1))

def cityblock_distance(P, Q=origin): 
    "Manhatten distance between two points."
    return sum(abs(p - q) for p, q in zip(P, Q))

def distance(P, Q=origin): 
    "Straight-line (hypotenuse) distance between two points."
    return sum((p - q) ** 2 for p, q in zip(P, Q)) ** 0.5

def king_distance(P, Q=origin):
    "Number of chess King moves between two points."
    return max(abs(p - q) for p, q in zip(P, Q))

################ Debugging 

def trace1(f):
    "Print a trace of the input and output of a function on one line."
    def traced_f(*args):
        result = f(*args)
        print('{}({}) = {}'.format(f.__name__, ', '.join(map(str, args)), result))
        return result
    return traced_f

def grep(pattern, iterable):
    "Print lines from iterable that match pattern."
    for line in iterable:
        if re.search(pattern, line):
            print(line)
            
class Struct:
    "A structure that can have any fields defined."
    def __init__(self, **entries): self.__dict__.update(entries)
    def __repr__(self): 
        fields = ['{}={}'.format(f, self.__dict__[f]) 
                  for f in sorted(self.__dict__)]
        return 'Struct({})'.format(', '.join(fields))

################ A* and Breadth-First Search (tracking states, not actions)

def always(value): return (lambda *args: value)

def Astar(start, moves_func, h_func, cost_func=always(1)):
    "Find a shortest sequence of states from start to a goal state (where h_func(s) == 0)."
    frontier  = [(h_func(start), start)] # A priority queue, ordered by path length, f = g + h
    previous  = {start: None}  # start state has no previous state; other states will
    path_cost = {start: 0}     # The cost of the best path to a state.
    Path      = lambda s: ([] if (s is None) else Path(previous[s]) + [s])
    while frontier:
        (f, s) = heappop(frontier)
        if h_func(s) == 0:
            return Path(s)
        for s2 in moves_func(s):
            g = path_cost[s] + cost_func(s, s2)
            if s2 not in path_cost or g < path_cost[s2]:
                heappush(frontier, (g + h_func(s2), s2))
                path_cost[s2] = g
                previous[s2] = s

def bfs(start, moves_func, goals):
    "Breadth-first search"
    goal_func = (goals if callable(goals) else lambda s: s in goals)
    return Astar(start, moves_func, lambda s: (0 if goal_func(s) else 1))

## Day 1

In [163]:
nums = set(Input(1, line_parser=int))
len(nums)

200

In [164]:
def has_2020(n):
    return (2020 - n) in nums
n = first_true(nums, pred=has_2020)
n * (2020 - n)

969024

In [165]:
def has_2020_triplet(pair):
    return (2020 - sum(pair)) in nums
n1, n2 = first_true(combinations(nums, 2), pred=has_2020_triplet)
n1 * n2 * (2020 - n1 - n2)

230057040

## Day 2

In [166]:
TOKENS = re.compile(r'(\d+)-(\d+) (\w): (\w+)')
def toks(line):
    [(least, most, char, pw)] = TOKENS.findall(line)
    return int(least), int(most), char, pw

In [167]:
specs = Input(2, line_parser=toks)
specs[:2]

((6, 10, 's', 'snkscgszxsssscss'), (6, 7, 'b', 'bbbbbxkb'))

In [168]:
def is_valid_pw(spec):
    least, most, char, pw = spec
    return least <= pw.count(char) <= most

In [169]:
quantify(specs, is_valid_pw)

414

In [170]:
from operator import xor

def is_valid_pw2(spec):
    pos1, pos2, char, pw = spec
    return xor(pw[pos1 - 1] == char, pw[pos2 - 1] == char)

In [171]:
quantify(specs, is_valid_pw2)

413

## Day 3

In [172]:
mountain = Input(3)
width = len(mountain[0])
mountain[0:2]

('....#..#.......#........#....#.', '..##.#.#.#...................#.')

In [173]:
def at(irow):
    index, row = irow
    x = index * 3 % len(row)
    return row[x]

quantify(mapt(at, enumerate(mountain)), lambda c: c == '#')

257

In [174]:
def hops(slope):
    x, y = 0, 0
    while True:
        yield (x, y)
        x += X(slope)
        y += Y(slope)
        
def coords(slope):
    return tuple(takewhile(lambda p: X(p) < len(mountain), hops(slope)))

def at(coord):
    x, y = coord
    return mountain[x][y % width]

def count_trees(slope):
    return quantify(map(at, coords(slope)), lambda c: c == '#')

count_trees((1, 3))

257

In [175]:
trees = [
    count_trees((1, 1)),
    count_trees((1, 3)),
    count_trees((1, 5)),
    count_trees((1, 7)),
    count_trees((2, 1)),
]

reduce(operator.mul, trees)
# trees

1744787392

## Day 4

In [176]:
def fields(pp):
    return dict([line.split(':') for line in pp])

fields([ 'ecl:gry', 'pid:123' ])

{'ecl': 'gry', 'pid': '123'}

In [177]:
inp = Input(4, line_parser=t.identity)
ppl = ''.join(inp).split('\n\n')

parse = t.compose(fields, str.split)
ppl = mapt(parse, ppl)
ppl[0]

{'iyr': '2015',
 'cid': '189',
 'ecl': 'oth',
 'byr': '1947',
 'hcl': '#6c4ab1',
 'eyr': '2026',
 'hgt': '174cm',
 'pid': '526744288'}

In [178]:
exp = ['byr', 'iyr', 'eyr', 'hgt', 'hcl', 'ecl', 'pid']
def has_fields(pp):
    return all(field in pp for field in exp)

quantify(ppl, has_fields)

264

In [179]:
def hgt(h):
    n, unit = h[:-2], h[-2:]
    return (150 <= int(n) <= 193 if unit == 'cm' else
            59 <= int(n) <= 76 if unit == 'in' else
            False)

hcl_re = re.compile(r'^#[0-9a-f]{6}$')
pid_re = re.compile(r'^[0-9]{9}$')

validators = {
    'byr': lambda v: 1920 <= int(v) <= 2002,
    'iyr': lambda v: 2010 <= int(v) <= 2020,
    'eyr': lambda v: 2020 <= int(v) <= 2030,
    'hgt': hgt,
    'hcl': hcl_re.match,
    'ecl': lambda v: v in ['amb', 'blu', 'brn', 'gry', 'grn', 'hzl', 'oth'],
    'pid': pid_re.match,
    'cid': lambda v: True
}

In [180]:
def is_valid(pp):
    return has_fields(pp) and all(
        validators[f](val)
        for f, val in pp.items()
    )

In [181]:
quantify(ppl, is_valid)

224

## Day 5

In [182]:
tix = Input(5)
tix[:2]

('FFFFBFBLLR', 'BFBFFBBLLR')

In [183]:
from math import floor, ceil

def split(front, lower, upper, tik):
    if not tik: return lower
    
    first, rest = tik[0], tik[1:]
    med = lower + (upper - lower) / 2
    med1 = floor(med)
    med2 = ceil(med)
    return (split(front, lower, med1, rest) if first == front else
            split(front, med2, upper, rest))

row = t.curry(split, 'F', 0, 127)
col = t.curry(split, 'L', 0, 7)

In [184]:
def seat(tik):
    rows, cols = tik[:-3], tik[-3:]
    r, c = row(rows), col(cols)
    return r * 8 + c

In [185]:
ids = set(mapt(seat, tix))
max(ids)

835

In [186]:
first_true(ints(min(ids), max(ids)),
           lambda n: n not in ids)

649

## Day 6

In [198]:
qs = Groups(6)
qs[:2]

(['ymw', 'w', 'wm', 'vsw', 'wm'], ['vs', 'lqn', 'ti', 'uvl'])

In [199]:
def answered(g):
    ppl = mapt(set, g)
    return reduce(operator.or_, ppl)

answered(['abc', 'bcd'])

{'a', 'b', 'c', 'd'}

In [200]:
sum(mapt(t.compose(len, answered), qs))

6735

In [201]:
def answered(g):
    ppl = mapt(set, g)
    return reduce(operator.and_, ppl)

answered(['abc', 'bcd'])

{'b', 'c'}

In [202]:
sum(mapt(t.compose(len, answered), qs))

3221

## Day 7

In [217]:
Tokens('helo, world')

['helo', ' world']

In [246]:
def parse_children(rest):
    toks = Tokens(rest)
    children = [Tokens(c, ' ') for c in toks]
    return [(int(c[0]), f'{c[1]} {c[2]}')
            for c in children
            if c[0] != 'no']

def parse_bags(line):
    parent, rest = line.split(' bags contain ')
    children = mapt(parse_children, rest.split(', '))
    return parent, list(t.concat(children))

In [273]:
rules = Input(7, line_parser=parse_bags, test=True)
rules = dict(rules)
head(rules)

('shiny gold', 'dark red', 'dark orange', 'dark yellow', 'dark green')

In [270]:
gold_cache = defaultdict(lambda: None)

search = 'shiny gold'

def gets_gold(bag):
    if bag == search: return True
    if gold_cache[bag] != None:
        return gold_cache[bag]
    
    children = rules[bag]
    yes = any(gets_gold(c) for _, c in children)
    gold_cache[bag] = yes
    return yes

gets_gold('bright white')

False

In [271]:
quantify(rules.keys(), gets_gold) - 1

151

In [272]:
def inside(bag):
    children = rules[bag]
    return sum(count * inside(child) for count, child in children)

In [274]:
inside('shiny gold')

0