In [31]:
from aocd.models import Puzzle

puzzle = Puzzle(year=2021, day=21)

def parses(text):
    return [i[0] for i in parse.findall('position: {:d}', text)]

data = parses(puzzle.input_data)

In [32]:
sample = parses("""Player 1 starting position: 4
Player 2 starting position: 8""")

In [137]:
def cycle(pos):
    return 1 + (pos-1) % 10

In [152]:
def solve_a(data):
    pos1, pos2 = data
    i = 1 
    sc1, sc2 = 0, 0
    for i in itertools.count()
        pos1 = cycle(pos1 +  i + i + 1 + i + 2)
        sc1 += pos1
        i += 3
        if sc1 >= 1000:
            break
        pos2 = cycle(pos2 +  i + i + 1 + i + 2)
        sc2 += pos2
        i += 3
        if sc2 >= 1000:
            break
    return (i-1) * min(sc1,sc2)

In [153]:
solve_a(sample)

739785

In [157]:
def solve_a(data):
    pos1, pos2 = data
    i = 1 
    sc1, sc2 = 0, 0
    for i in itertools.count(1,6):
        pos1 = cycle(pos1 + i + i + 1 + i + 2)
        sc1 += pos1
        if sc1 >= 1000:
            return (i+2)*sc2
        i += 3
        pos2 = cycle(pos2 + i + i + 1 + i + 2)
        sc2 += pos2
        if sc2 >= 1000:
            return (i+2)*sc1

In [159]:
solve_a(data)

506466

In [165]:
def solve_b(data):
    """Each player is represented as dict[pos, dict[score, #ways]]
    We update each player indenpendenly and compute products
    of wins/loses to account for all possible universes
    """
    pos1, pos2 = data
    mem1, mem2 = {pos1: {0:1}}, {pos2: {0:1}}
    total1, total2 = 0, 0
    # Collapse 27 combinations into the 7 distinct moves (3..9) with counts
    steps = Counter(map(sum, itertools.product((1,2,3), repeat=3)))
    
    def turn(mem):
        # Runs turn for a single player
        wins = 0
        new_mem = defaultdict(lambda: defaultdict(int))
        for pos, scores in mem.items():
            for score, ways in scores.items():
                positions = {cycle(pos+step): n*ways for step, n in steps.items()}
                for new_pos, new_ways in positions.items():
                    new_score = score + new_pos
                    if new_score >= 21:
                        wins += new_ways
                    else:
                        new_mem[new_pos][new_score] += new_ways
        return new_mem, wins
    
    def count_ways(mem):
        return sum(ways for pos, scores in mem.items() for score, ways in scores.items())

    while len(mem1) != 0 and len(mem2) != 0:
        mem1, wins1 = turn(mem1)
        total1 += wins1 * count_ways(mem2)
        mem2, wins2 = turn(mem2)
        total2 += wins2 * count_ways(mem1)

    return max(total1, total2)
        
    

In [211]:
def solve_b(data, wincondition=21):
    """Each player is represented as dict[(pos, score), #ways]]
    We update each player indenpendenly and compute products
    of wins/loses to account for all possible universes
    """
    pos1, pos2 = data
    mem1, mem2 = {(pos1,0):1}, {(pos2,0):1}
    total1, total2 = 0, 0
    # Collapse 27 combinations into the 7 distinct moves (3..9) with counts
    steps = Counter(map(sum, itertools.product((1,2,3), repeat=3)))
    
    def turn(mem):
        # Runs turn for a single player
        wins = 0
        new_mem = defaultdict(int)
        for (pos, score), ways in mem.items():
            positions = {cycle(pos+step): n*ways for step, n in steps.items()}
            for new_pos, new_ways in positions.items():
                new_score = score + new_pos
                if new_score >= wincondition:
                    wins += new_ways
                else:
                    new_mem[new_pos,new_score] += new_ways
        return new_mem, wins

    while len(mem1) != 0 and len(mem2) != 0:
        mem1, wins1 = turn(mem1)
        total1 += wins1 * sum(mem2.values())
        mem2, wins2 = turn(mem2)
        total2 += wins2 * sum(mem1.values())

    return max(total1, total2)

In [213]:
solve_b(sample,100)

55038535590428753856514661082323914715870927758485665656548544838675

In [216]:
# Solves Part3 as well 
# https://ol.reddit.com/r/adventofcode/comments/rlgfyi/2021_day_21_part_3_playing_the_full_game/

In [217]:
solve_b(sample,1000)

5421241233526473492719111827530073322043700685324887972643877167221582706632062890226633882867215122817094478226757831672569240814542964877955305731916648253248335858512057533670130530856541774540595494034042339999669679917631875732825766116784092779284811114399087382450717587331260789093785894799877563606482291952570043001217814381339302494119777384417342867283876524391498980886698692797031928912775846300310773961458685480293345030624886674682243872741738613990867919596644394338348577333930203953377127120984552032785093074890068515454918026316742238636555050830959764192402139432795516421795229983126452003483234264162966755183792554646664002598051928477710447862345835

In [171]:
solve_b(data)

632979211251440

In [172]:
%%timeit
solve_b(sample)

2.81 ms ± 114 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [173]:
%%timeit
solve_b(data)

2.83 ms ± 46.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [123]:
sample

[4, 8]

In [124]:
data

[8, 7]

In [133]:
max(solve_b(data))

632979211251440

In [134]:
max(solve_b(sample))

444356092776315

In [192]:
# %%timeit

import functools

dice = [
	i + j + k
	for i in (1, 2, 3)
	for j in (1, 2, 3)
	for k in (1, 2, 3)
]

@functools.lru_cache(maxsize=None)
def search(state):
	score, pos = state[0]
	state = [state[1], None]
	wins = [0, 0]
	for roll in dice:
		newpos = (pos + roll - 1) % 10 + 1
		newscore = score + newpos
		state[1] = (newscore, newpos)
		if newscore >= 21:
			wins[0] += 1
		else:
			myself, other = search(tuple(state))
			wins = [wins[0] + other, wins[1] + myself]
	return wins

search(((0,4),(0,8)))

[444356092776315, 341960390180808]

In [203]:
444356092776315+341960390180808

786316482957123

In [207]:
%%timeit

@functools.lru_cache(maxsize=None)
def count_wins(p1, p2, sc1, sc2):
    # Given that A is at position p1 with score s1, and B is at position p2 with score s2, and A is to move,
    # return (# of universes where player A wins, # of universes where player B wins)
    wins1, wins2 = 0, 0
    for step in map(sum, itertools.product((1,2,3),repeat=3)):
        new_p1 = (p1+step-1) % 10 + 1
        new_sc1 = sc1 + new_p1
        if new_sc1 >= 21:
            wins1 += 1
        else:
            w2, w1 = count_wins(p2, new_p1, sc2, new_sc1)
            wins1 += w1
            wins2 += w2
    return wins1, wins2


count_wins(*data,0,0)

228 ms ± 47.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


(444356092776315, 341960390180808)