In [1]:
import re
from aocd import get_data, submit 
from dataclasses import dataclass

In [109]:
import numpy as np
puzzle_input = get_data(year=2022, day=19)
# Example from the puzzle statement; uncomment to check against the known answers.
# puzzle_input = """Blueprint 1: Each ore robot costs 4 ore. Each clay robot costs 2 ore. Each obsidian robot costs 3 ore and 14 clay. Each geode robot costs 2 ore and 7 obsidian.
# Blueprint 2: Each ore robot costs 2 ore. Each clay robot costs 3 ore. Each obsidian robot costs 3 ore and 8 clay. Each geode robot costs 3 ore and 12 obsidian."""


def parse_blueprint(line):
    """Parse one blueprint line into four robot cost vectors.

    Each cost vector is ``(ore, clay, obsidian, geode)``. The outer tuple is
    ordered ``(ore robot, clay robot, obsidian robot, geode robot)``, which is
    the layout that ``buy_options`` and ``buy_bot`` index into.

    Raises ValueError if the line does not contain exactly seven integers
    (the blueprint id followed by the six costs).
    """
    # Match every integer explicitly instead of relying on numbers being
    # surrounded by spaces; the blueprint id is unpacked and discarded.
    _blueprint_id, ore_ore, clay_ore, obs_ore, obs_clay, geo_ore, geo_obs = map(
        int, re.findall(r"\d+", line)
    )
    return (
        (ore_ore, 0, 0, 0),
        (clay_ore, 0, 0, 0),
        (obs_ore, obs_clay, 0, 0),
        (geo_ore, 0, geo_obs, 0),
    )


# One entry per non-blank input line. Later cells use position + 1 as the
# blueprint id (assumes the input lists blueprints in id order).
data = [parse_blueprint(line) for line in puzzle_input.splitlines() if line.strip()]

bp = data[0]
bp


((4, 0, 0, 0), (4, 0, 0, 0), (4, 14, 0, 0), (3, 0, 16, 0))

In [77]:
# Plain scalar comparisons on purpose: the author measured ~10x over a numpy version.
def buy_options(bp, m):
    """Report which of the four robots the stock ``m`` can currently pay for.

    ``bp[k]`` is the cost vector of robot ``k``. Only the resources each robot
    actually consumes are compared:
    ore robot -> ore, clay robot -> ore,
    obsidian robot -> ore + clay, geode robot -> ore + obsidian.

    Returns a 4-tuple of bools in robot order (ore, clay, obsidian, geode).
    """
    can_ore_bot = m[0] >= bp[0][0]
    can_clay_bot = m[0] >= bp[1][0]
    can_obsidian_bot = m[0] >= bp[2][0] and m[1] >= bp[2][1]
    can_geode_bot = m[0] >= bp[3][0] and m[2] >= bp[3][2]
    return (can_ore_bot, can_clay_bot, can_obsidian_bot, can_geode_bot)

# Hand-unrolled check; the author measured ~30x over np.all(x >= 0).
def all_nonnegative(x):
    """Return True when the first four entries of ``x`` are all >= 0.

    Stops at the first negative entry; entries past index 3 are ignored.
    """
    return x[0] >= 0 and x[1] >= 0 and x[2] >= 0 and x[3] >= 0

In [78]:
def produce(b, m):
    """Return stock ``m`` after one minute of output from robot counts ``b``.

    Slot k of the result is ``m[k] + b[k]`` for k = 0..3; further entries are
    ignored. A new tuple is returned and the inputs are left untouched.
    """
    ore = m[0] + b[0]
    clay = m[1] + b[1]
    obsidian = m[2] + b[2]
    geode = m[3] + b[3]
    return (ore, clay, obsidian, geode)
def add_bot(bots, idx):
    """Return ``bots`` as a tuple with the count at position ``idx`` raised by one.

    If ``idx`` matches no position, the counts are returned unchanged (as a tuple).
    """
    updated = []
    for position, count in enumerate(bots):
        if position == idx:
            updated.append(count + 1)
        else:
            updated.append(count)
    return tuple(updated)

In [79]:
def buy_bot(bp, m, idx):
    """Return stock ``m`` after paying the cost of robot ``idx`` from blueprint ``bp``.

    No affordability check is done here; callers are expected to consult
    ``buy_options`` first.
    """
    cost = bp[idx]
    return (m[0] - cost[0], m[1] - cost[1], m[2] - cost[2], m[3] - cost[3])

In [80]:
# Quick look at which robots the current `bp` allows with 5 ore, 5 clay, 5 obsidian.
buy_options(bp=bp, m=(5, 5, 5, 0))

(True, True, False, False)

In [92]:
def is_better_money(a, b):
    """Return True when stock ``a`` strictly dominates stock ``b``.

    Strict dominance: ``a != b`` and each of the first four entries of ``a``
    is >= the matching entry of ``b``. Equal stocks are never "better".
    """
    if a == b:
        return False
    return a[0] >= b[0] and a[1] >= b[1] and a[2] >= b[2] and a[3] >= b[3]

In [82]:
# Identical stocks must not count as "better" (expected: False).
is_better_money(a=(8, 4, 0, 0), b=(8, 4, 0, 0))

False

In [96]:
def is_dumb_state_heuristic(b, limit=10):
    """Heuristic filter: flag robot counts with more than ``limit`` ore or clay robots.

    Parameters
    ----------
    b : robot counts; only ``b[0]`` (ore) and ``b[1]`` (clay) are inspected.
    limit : int, default 10
        Threshold that used to be hard-coded; the default keeps the old behavior.

    This is not a provably safe prune, because the threshold is not derived
    from the blueprint's costs. Treat it as an approximate speed-up only.
    """
    return b[0] > limit or b[1] > limit

In [106]:
def is_not_productive_heuristic(i, m, cutoff=21):
    """Heuristic filter: flag a state with no geodes yet once minute index ``i`` reaches ``cutoff``.

    Parameters
    ----------
    i : 0-based minute index of the simulation loop.
    m : stock tuple; ``m[3]`` is the geode count.
    cutoff : int, default 21
        Previously hard-coded. The default only suits a 24-minute run, so
        scale it when simulating more minutes.

    Approximate: it can discard a state that would still produce geodes,
    for example one whose first geode robot was just built.
    """
    return i >= cutoff and m[3] == 0

In [115]:
def find_max_geodes(bp, minutes=24, verbose=True):
    """Breadth-first search over reachable robot/stock states for one blueprint.

    Parameters
    ----------
    bp : tuple of four robot cost vectors (as built by the parsing cell).
    minutes : int, default 24
        Number of minutes to simulate. Previously hard-coded to 24; part 2 of
        the puzzle needs 32.
    verbose : bool, default True
        Print the number of live states after each minute (the original behavior).

    Returns
    -------
    (states, max_geodes)
        The set of final ``(bots, money, prev_buy_options)`` states and the
        largest geode count (``money[3]``) among them.

    Each minute every state may wait or buy one affordable robot. Two prunes apply:

    * Dominance: a state is dropped when the best stock recorded so far for the
      same robot counts (recorded whenever a purchase creates those counts) is
      >= in every resource and not equal.
    * No late buys: a state that waited last minute and faces exactly the same
      buy options now does not buy. The reasoning is that buying a minute
      earlier would have been at least as good.
    """
    # Marks a state that was just created by a purchase. It never equals a real
    # options tuple (those hold bools), so such a state is free to buy again.
    fresh_purchase = (-1, -1, -1, -1)

    # State layout: (robot counts, stock, buy options seen when the state last waited).
    states = {((1, 0, 0, 0), (0, 0, 0, 0), (0, 0, 0, 0))}
    # The sentinel stock is below any real stock, so the start state is never pruned.
    new_best_money_per_bots = {(1, 0, 0, 0): (-1, -1, -1, -1)}

    for i in range(minutes):
        # Snapshot records from earlier minutes; this minute's purchases update the live dict.
        best_money_per_bots = new_best_money_per_bots.copy()
        new_states = set()
        for bots, money, prev_buy_options in states:
            if is_better_money(best_money_per_bots[bots], money):
                continue
            # Affordability uses the stock *before* this minute's production.
            options = buy_options(bp, money)
            money = produce(bots, money)

            # Waiting is always allowed; remember what could have been bought.
            new_states.add((bots, money, options))
            if options == prev_buy_options:
                continue

            for idx, option in enumerate(options):
                if not option:
                    continue
                bots_option = add_bot(bots, idx)
                money_option = buy_bot(bp, money, idx)
                recorded = new_best_money_per_bots.get(bots_option)
                if recorded is None or is_better_money(money_option, recorded):
                    new_best_money_per_bots[bots_option] = money_option
                new_states.add((bots_option, money_option, fresh_purchase))

        states = new_states
        if verbose:
            print(f"{i}: {len(states)} state{'s' if len(states) > 1 else ''}")

    geodes = [s[1][3] for s in states]
    return states, max(geodes)

# Spot-check one blueprint: list position 6 of the parsed input.
bp = data[6]
states, g = find_max_geodes(bp=bp)
g

0: 1 state
1: 1 state
2: 2 states
3: 4 states
4: 7 states
5: 12 states
6: 25 states
7: 42 states
8: 71 states
9: 134 states
10: 238 states
11: 400 states
12: 717 states
13: 1289 states
14: 2363 states
15: 5116 states
16: 10852 states
17: 22144 states
18: 46006 states
19: 94475 states
20: 184248 states
21: 362035 states
22: 697470 states
23: 1365950 states


2

In [117]:
# Part 1: best geode count for every blueprint. The puzzle answer is the sum of
# quality levels (blueprint id * geodes), not the plain geode total. Ids are
# taken as 1-based list positions, assuming the input lists blueprints in id order.
geodes = []
for bp in data:
    print(bp)
    states, g = find_max_geodes(bp)
    print(g)
    geodes.append(g)

print(sum(blueprint_id * g for blueprint_id, g in enumerate(geodes, start=1)))

((4, 0, 0, 0), (4, 0, 0, 0), (4, 14, 0, 0), (3, 0, 16, 0))
0: 1 state
1: 1 state
2: 1 state
3: 1 state
4: 3 states
5: 3 states
6: 3 states
7: 5 states
8: 7 states
9: 11 states
10: 15 states
11: 29 states
12: 46 states
13: 85 states
14: 137 states
15: 236 states
16: 478 states
17: 990 states
18: 2113 states
19: 5110 states
20: 10641 states
21: 23131 states
22: 47255 states
23: 97234 states
0
((3, 0, 0, 0), (3, 0, 0, 0), (2, 19, 0, 0), (2, 0, 12, 0))
0: 1 state
1: 1 state
2: 1 state
3: 3 states
4: 3 states
5: 5 states
6: 7 states
7: 11 states
8: 21 states
9: 35 states
10: 64 states
11: 116 states
12: 194 states
13: 325 states
14: 635 states
15: 1246 states
16: 2471 states
17: 5364 states
18: 11842 states
19: 24249 states
20: 49290 states
21: 99133 states
22: 198202 states
23: 405199 states
1
((4, 0, 0, 0), (4, 0, 0, 0), (4, 9, 0, 0), (4, 0, 16, 0))
0: 1 state
1: 1 state
2: 1 state
3: 1 state
4: 3 states
5: 3 states
6: 3 states
7: 5 states
8: 7 states
9: 11 states
10: 15 states
11: 29 sta

21: 45050 states
22: 89883 states
23: 203120 states
2
((2, 0, 0, 0), (4, 0, 0, 0), (3, 19, 0, 0), (4, 0, 13, 0))
0: 1 state
1: 1 state
2: 2 states
3: 2 states
4: 5 states
5: 9 states
6: 18 states
7: 30 states
8: 55 states
9: 100 states
10: 182 states
11: 287 states
12: 471 states
13: 754 states
14: 1347 states
15: 2264 states
16: 4069 states
17: 8595 states
18: 17270 states
19: 36180 states
20: 73174 states
21: 147270 states
22: 293191 states
23: 609801 states
2
((3, 0, 0, 0), (3, 0, 0, 0), (3, 17, 0, 0), (4, 0, 8, 0))
0: 1 state
1: 1 state
2: 1 state
3: 3 states
4: 3 states
5: 5 states
6: 7 states
7: 11 states
8: 21 states
9: 35 states
10: 64 states
11: 116 states
12: 194 states
13: 342 states
14: 714 states
15: 1399 states
16: 3044 states
17: 6940 states
18: 14830 states
19: 30306 states
20: 63126 states
21: 133626 states
22: 307443 states
23: 727628 states
4
((3, 0, 0, 0), (3, 0, 0, 0), (3, 20, 0, 0), (2, 0, 12, 0))
0: 1 state
1: 1 state
2: 1 state
3: 3 states
4: 3 states
5: 5 state

In [123]:
# Sum of quality levels: 1-based blueprint position times its best geode count.
np.dot(np.arange(1, len(geodes) + 1), np.array(geodes))

1418