In [160]:
#### IMPORTS

import re
from collections import Counter, defaultdict, namedtuple, deque
from itertools   import chain, cycle, product, count as count_from
from functools   import lru_cache

from pprint import pprint as p
import toolz.curried as t
from tqdm import tqdm as tq

#### CONSTANTS

alphabet = 'abcdefghijklmnopqrstuvwxyz'
ALPHABET = alphabet.upper()
infinity = float('inf')

#### SIMPLE UTILITY FUNCTIONS

cat = ''.join

def ints(start, end, step=1):
    "The integers from start to end, inclusive: range(start, end+1)"
    return range(start, end + 1, step)

def first(iterable, default=None): 
    "The first item in an iterable, or default if it is empty."
    return next(iter(iterable), default)

def first_true(iterable, pred=None, default=None):
    """Returns the first true value in the iterable.
    If no true value is found, returns *default*
    If *pred* is not None, returns the first item
    for which pred(item) is true."""
    # first_true([a,b,c], default=x) --> a or b or c or x
    # first_true([a,b], fn, x) --> a if fn(a) else b if fn(b) else x
    return next(filter(pred, iterable), default)

def quantify(iterable, pred=bool):
    "Count how many times the predicate is true of an item in iterable."
    return sum(map(pred, iterable))

def multimap(items):
    "Given (key, val) pairs, return {key: [val, ....], ...}."
    result = defaultdict(list)
    for (key, val) in items:
        result[key].append(val)
    return result

def mapt(fn, *args): 
    "Do a map, and make the results into a tuple."
    return tuple(map(fn, *args))

def map2d(fn, grid):
    "Apply fn to every element in a 2-dimensional grid."
    return tuple(mapt(fn, row) for row in grid)

def repeat(n, fn, arg, *args, **kwds):
    "Repeat arg = fn(arg) n times, return arg."
    return nth(repeatedly(fn, arg, *args, **kwds), n)

def repeatedly(fn, arg, *args, **kwds):
    "Yield arg, fn(arg), fn(fn(arg)), ..."
    yield arg
    while True:
        arg = fn(arg, *args, **kwds)
        yield arg
        
def compose(f, g): 
    "The function that computes f(g(x))."
    return lambda x: f(g(x))

#### FILE INPUT AND PARSING

def Input(day, line_parser=str.strip, file_template='data/advent2018/{}.txt'):
    "For this day's input file, return a tuple of each line parsed by `line_parser`."
    return mapt(line_parser, open(file_template.format(day)))

def integers(text): 
    "A tuple of all integers in a string (ignore other characters)."
    return mapt(int, re.findall(r'-?\d+', text))

In [21]:
################ 2-D points implemented using (x, y) tuples

def X(point): return point[0]
def Y(point): return point[1]

origin = (0, 0)
HEADINGS = UP, LEFT, DOWN, RIGHT = (0, -1), (-1, 0), (0, 1), (1, 0)

def turn_right(heading): return HEADINGS[HEADINGS.index(heading) - 1]
def turn_around(heading):return HEADINGS[HEADINGS.index(heading) - 2]
def turn_left(heading):  return HEADINGS[HEADINGS.index(heading) - 3]

def add(A, B): 
    "Element-wise addition of two n-dimensional vectors."
    return mapt(sum, zip(A, B))

In [210]:
def head(iterable, n=5): return list(t.take(n, iterable))

def coords(two_d_arr):
    return [
        ((x, y), val)
        for y, line in enumerate(two_d_arr)
        for x, val in enumerate(line)
    ]

def duplicates(iterable):
    return { item for item, count in Counter(iterable).items() if count > 1 }

## Day 13

In [307]:
tracks = r'''
/->-\        
|   |  /----\
| /-+--+-\  |
| | |  | v  |
\-+-/  \-+--/
  \------/   
'''.strip('\n').split('\n')

tracks = r'''
/>-<\  
|   |  
| /<+-\
| | | v
\>+</ |
  |   ^
  \<->/
'''.strip('\n').split('\n')

tracks = Input(13, line_parser=t.identity)

In [308]:
Cart = namedtuple('Cart', 'id, pos, dir, turns')

cart_dirs = { 'v': DOWN, '^': UP, '<': LEFT, '>': RIGHT }

def junctions(tracks=tracks):
    return { pos: ch for pos, ch in coords(tracks) if ch in r'/\+' }

def carts(tracks=tracks):
    positions = [(pos, ch) for pos, ch in coords(tracks) if ch in cart_dirs]
    return deque(
        Cart(id, pos, cart_dirs[ch], cycle([turn_left, t.identity, turn_right]))
        for id, (pos, ch) in enumerate(positions)
    )

In [309]:
bends = {
    r'/': { LEFT: DOWN, DOWN: LEFT, RIGHT: UP, UP: RIGHT },
    '\\': { LEFT: UP, UP: LEFT, RIGHT: DOWN, DOWN: RIGHT },
}

def turn(cart, juncts=junctions()):
    _, pos, dir, turns = cart
    junction = juncts.get(add(pos, dir), False)
    
    if junction in bends: return bends[junction][dir]
    if junction == '+':   return next(turns)(dir)
    else:                 return dir

In [310]:
def move_cart(carts):
    if first(carts).id == min(cart.id for cart in carts):
        carts = deque(Cart(id, pos, dir, turns)
                      for id, (_, pos, dir, turns) in
                      enumerate(sorted(carts, key=lambda cart: (Y(cart.pos), X(cart.pos)))))
        
    cart = id, pos, dir, turns = carts.popleft()
    carts.append(Cart(id, add(pos, dir), turn(cart), turns))
    return carts

def collision_pos(carts):
    crashes = duplicates(cart.pos for cart in carts)
    return crashes if len(crashes) > 0 else False

In [311]:
first_true(collision_pos(c) for c in repeatedly(move_cart, carts()))

{(32, 8)}

In [312]:
def remove_collisions(carts):
    crashes = collision_pos(carts)
    return (deque(cart for cart in carts if cart.pos not in crashes)
            if crashes else carts)

In [313]:
first(c for c in repeatedly(compose(remove_collisions, move_cart), carts())
      if len(c) == 1 and first(c).id == 0)

deque([Cart(id=0, pos=(38, 38), dir=(0, 1), turns=<itertools.cycle object at 0x10ede5168>)])