In [1]:
from collections import deque, defaultdict, Counter
from heapq import heapify, heappush, heappop
import numpy as np
from copy import deepcopy
import math
import time
from functools import cache, reduce, cmp_to_key
import graphviz
from itertools import product
import matplotlib.pyplot as plt
from bisect import bisect_left, bisect_right
import json

In [2]:
# Parse the puzzle input into:
#   flow[name] -> the integer after "rate=" on that valve's line
#   adj[name]  -> set of neighbouring valve names listed on that line
# After removing commas and splitting on single spaces, word 1 is the valve's
# name, word 4 is its "rate=<n>;" token and words 9.. are its neighbours
# (assumes every non-blank line follows this layout).
valve_names = set()
flow = {}
adj = defaultdict(set)

with open("./data/day16.txt") as f:
    for raw_line in f:
        if not raw_line.strip():
            continue  # tolerate blank lines, e.g. an extra newline at the end of the file
        words = raw_line.rstrip().replace(',', '').split(' ')
        valve, neighbours = words[1], words[9:]
        flow[valve] = int(words[4].replace(';', '').split('=')[1])
        valve_names.add(valve)
        valve_names.update(neighbours)
        for next_valve in neighbours:
            adj[valve].add(next_valve)

# A sorted list (unlike the set above) gives every valve a stable index,
# which is also its bit position in the `state` bitmasks used by the searches.
valves = sorted(valve_names)
valve_to_ind = {name: i for i, name in enumerate(valves)}
len(valves), valves[0], flow[valves[0]], adj[valves[0]]

(55, 'AA', 0, {'DI', 'NB', 'UV', 'VS', 'XO'})

# Part 1

First idea: a DFS that keeps track of the valve states, so that we may come back to a valve only if the state has changed, i.e. a new valve has been opened in the meantime. It works, but it is prohibitively slow on the real input.

In [3]:
def dfs(ind, state, visited, time_left):
    """Exhaustive search (no memoization) for the best score reachable from valve ``ind``.

    ``state`` is a bitmask of opened valves (bit ``i`` <-> ``valves[i]``).
    ``visited`` holds the ``(valve index, state)`` pairs already on the current
    path. A move onto such a pair is skipped, because it would return to the same
    situation with less time. The set is mutated during the search but restored
    before a normal return. Opening a valve changes ``state``, so the search after
    an opening starts from a fresh set containing only the new pair.

    Returns the sum, over the valves opened from here on, of
    ``flow * (minutes the valve stays open)``. A valve opened with ``time_left``
    minutes left stays open for ``time_left - 1`` minutes.
    """
    if time_left <= 1:
        return 0
    here = valves[ind]
    bit = 1 << ind
    best = 0
    if not (state & bit) and flow[here] != 0:
        opened = state | bit
        best = (time_left - 1) * flow[here] + dfs(ind, opened, {(ind, opened)}, time_left - 1)
    for name in adj[here]:
        step = (valve_to_ind[name], state)
        if step in visited:
            continue
        visited.add(step)
        best = max(best, dfs(step[0], state, visited, time_left - 1))
        visited.remove(step)
    return best

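Second idea: dynamic programming (implemented top-down, i.e. recursion with memoization). The best achievable result depends only on the current valve, the set of open valves and the time left, so `(ind, state, time_left)` can serve as the memoization key and the `visited` bookkeeping is no longer needed.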

In [4]:
@cache
def dp(ind, state, time_left):
    """Best score reachable from valve ``ind`` with ``time_left`` minutes remaining.

    Top-down DP memoized on ``(ind, state, time_left)``, where ``state`` is a
    bitmask of opened valves (bit ``i`` <-> ``valves[i]``). Each minute we either
    open the current valve (only if it is closed and has non-zero flow), which
    scores ``(time_left - 1) * flow``, or walk to a neighbouring valve.
    """
    if time_left <= 1:
        return 0
    here = valves[ind]
    bit = 1 << ind
    best = 0
    if not (state & bit) and flow[here] != 0:
        # Spend this minute opening the valve; it then stays open for time_left - 1 minutes.
        best = (time_left - 1) * flow[here] + dp(ind, state | bit, time_left - 1)
    walks = (dp(valve_to_ind[name], state, time_left - 1) for name in adj[here])
    return max([best, *walks])

In [5]:
# Start at AA (index 0 once the valves are sorted) with every valve closed and 30 minutes.
dp(valve_to_ind["AA"], 0, 30)

1617

# Part 2

### Idea 1:
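The DP gets an extra index for the second agent's position, and at each step we need to loop over all combinations of moves for the agents at `ind1` and `ind2`. Since the two agents are interchangeable, the pair is always stored with `ind1 <= ind2`.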

In [6]:
@cache
def dp2(ind1, ind2, state, time_left):
    """Best score two cooperating agents can still reach.

    The agents stand on valves ``ind1`` and ``ind2``. ``state`` is the bitmask of
    open valves and ``time_left`` the minutes remaining. Call it with
    ``ind1 <= ind2``: the agents are interchangeable, so the pair is kept ordered
    on every recursive call to shrink the memo table.

    Each minute, each agent either opens the valve it stands on (if that valve is
    closed and has non-zero flow) or walks to a neighbour. When both agents share
    a valve, only the first one may open it.
    """
    if time_left <= 1:
        return 0
    later = time_left - 1
    here1, here2 = valves[ind1], valves[ind2]
    bit1, bit2 = 1 << ind1, 1 << ind2
    first_can_open = not (state & bit1) and flow[here1] != 0
    # The second agent never opens the valve the first agent is standing on.
    second_can_open = ind1 != ind2 and not (state & bit2) and flow[here2] != 0

    best = 0
    if first_can_open:
        gain1 = later * flow[here1]
        if second_can_open:
            gain2 = later * flow[here2]
            best = max(best, gain1 + gain2 + dp2(ind1, ind2, state | bit1 | bit2, later))
        for nb2 in adj[here2]:
            j = valve_to_ind[nb2]
            best = max(best, gain1 + dp2(*sorted((ind1, j)), state | bit1, later))
    for nb1 in adj[here1]:
        i = valve_to_ind[nb1]
        if second_can_open:
            best = max(best, later * flow[here2] + dp2(*sorted((i, ind2)), state | bit2, later))
        for nb2 in adj[here2]:
            j = valve_to_ind[nb2]
            best = max(best, dp2(*sorted((i, j)), state, later))
    return best

It starts eating up too much memory. That is not really surprising, since there are now about 28 times as many position combinations to memoize as in part 1 (55·56/2 = 1540 pairs with `ind1 <= ind2`, versus 55 single positions). For example, let's compare the cache sizes after running the entire part 1 with those after running part 2 with only 15 minutes.

In [7]:
# Both agents start at AA (index 0) with every valve closed; only 15 minutes,
# so the memo table stays small enough to compare against part 1.
dp2(valve_to_ind["AA"], valve_to_ind["AA"], 0, 15)

691

In [8]:
# Memo-table statistics: the whole of part 1 vs. part 2 limited to 15 minutes.
tuple(fn.cache_info() for fn in (dp, dp2))

(CacheInfo(hits=510247, misses=484150, maxsize=None, currsize=484150),
 CacheInfo(hits=4699263, misses=1248667, maxsize=None, currsize=1248667))

### Idea 2:
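Let's add a check so that we skip computing `dp3(ind1, ind2, state, time_left)` whenever the same `(ind1, ind2, state)` has already been reached with at least as much time left. This makes caching the function sketchy. The function now has a side effect (updating `visited`) that only happens the first time it is called with a given set of arguments, and the 0 returned by a pruned call gets cached as well. It is also only a heuristic: a path that reaches the same positions and valve state later may still have collected more by opening its valves earlier. But let's see what happens.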

In [9]:
# Maps (ind1, ind2, state) -> the largest time_left with which dp3 has been entered
# for that triple. Module-level and never reset by dp3, so it persists across calls
# (including after dp3.cache_clear()).
visited = defaultdict(int)

@cache
def dp3(ind1, ind2, state, time_left): # remember to always make sure ind1 <= ind2 to make space smaller
    """Variant of dp2 that prunes (ind1, ind2, state) triples already reached with at least as much time left.

    The recursion itself mirrors dp2: each agent either opens the valve it stands on
    or walks to a neighbour, and the position pair is kept ordered with min/max.

    NOTE(review): the pruning is a heuristic, not a sound DP optimization:
      * a later path that reaches the same triple with less time may have opened its
        valves earlier (scoring more so far), yet it is cut off and scored 0;
      * the 0 returned by a pruned call is memoized by @cache, so every later call
        with the same arguments also gets 0;
      * which calls get pruned depends on traversal order (iteration order of the
        adj sets) and on earlier calls, because `visited` is global.
    The result can therefore be lower than dp2's for the same arguments.
    """
    # Prune: this triple was already entered with at least this much time left.
    # (Reading a missing key from the defaultdict inserts it with value 0.)
    if visited[(ind1, ind2, state)] >= time_left:
        return 0
    visited[(ind1, ind2, state)] = time_left
    if time_left <= 1:
        return 0
    res = 0
    valve_open1 = (state & (1 << ind1)) != 0
    valve_open2 = (state & (1 << ind2)) != 0
    new_state1 = state | (1 << ind1)
    new_state2 = state | (1 << ind2)
    if not valve_open1 and flow[valves[ind1]] != 0: # open valve at ind1
        if ind1 != ind2 and not valve_open2 and flow[valves[ind2]] != 0: # open both valves
            res = max(res, (time_left-1)*flow[valves[ind1]] + (time_left-1)*flow[valves[ind2]] + dp3(ind1, ind2, new_state1 | new_state2, time_left-1))
        for a2 in adj[valves[ind2]]: # open only valve at ind1; agent 2 walks
            a_ind2 = valve_to_ind[a2]
            res = max(res, (time_left-1)*flow[valves[ind1]] + dp3(min(ind1, a_ind2), max(ind1, a_ind2), new_state1, time_left-1))
    for a1 in adj[valves[ind1]]: # don't open valve at ind1; agent 1 walks
        a_ind1 = valve_to_ind[a1]
        # ind1 == ind2 is excluded here: by symmetry it is covered by "open only valve at ind1" above
        if ind1 != ind2 and not valve_open2 and flow[valves[ind2]] != 0: # open valve at ind2
            res = max(res, (time_left-1)*flow[valves[ind2]] + dp3(min(a_ind1, ind2), max(a_ind1, ind2), new_state2, time_left-1))
        for a2 in adj[valves[ind2]]: # don't open any valves; both agents walk
            a_ind2 = valve_to_ind[a2]
            res = max(res, dp3(min(a_ind1, a_ind2), max(a_ind1, a_ind2), state, time_left-1))
    return res

In [10]:
# Same 15-minute run as for dp2 (both agents start at AA, index 0), now with the time-based pruning.
dp3(valve_to_ind["AA"], valve_to_ind["AA"], 0, 15)

691

In [11]:
# Memo-table statistics for dp3 after the 15-minute run, for comparison with dp2's.
dp3.cache_info()

CacheInfo(hits=3968529, misses=1207786, maxsize=None, currsize=1207786)

Nope. Still eats too much memory for my liking.

### Idea 3: Memory-optimized DP

If I write a bottom-up DP that only keeps the solutions for `time_left-1` while building up the solutions for `time_left`, how big would each of those arrays be?

In [12]:
# Number of states = 2**(number of valves with non-zero flow):
c = sum(flow[valve] != 0 for valve in valves)
state_count = 2**c

# Number of ind1, ind2 (ind1 <= ind2) pairs: n + (n-1) + ... + 1 = n*(n+1)/2
s = len(valves) * (len(valves) + 1) // 2

# Entries in one table covering every (ind1, ind2, state) for a single time_left.
total_size = s * state_count
total_size

50462720

50 million integers to store while building an array of 50 million more integers? Nope.