In [294]:
from aocd.models import Puzzle

puzzle = Puzzle(year=2022, day=16)

def parses(input):
    pattern = r'Valve (\w+) has flow rate=(\d+); tunnels? leads? to valves? (.*)'
    data = [re.findall(pattern, line)[0] for line in input.strip().split('\n')]
    return [(v, int(r), t.split(', ')) for v, r, t in data]
        
data = parses(puzzle.input_data)

In [295]:
sample = parses("""Valve AA has flow rate=0; tunnels lead to valves DD, II, BB
Valve BB has flow rate=13; tunnels lead to valves CC, AA
Valve CC has flow rate=2; tunnels lead to valves DD, BB
Valve DD has flow rate=20; tunnels lead to valves CC, AA, EE
Valve EE has flow rate=3; tunnels lead to valves FF, DD
Valve FF has flow rate=0; tunnels lead to valves EE, GG
Valve GG has flow rate=0; tunnels lead to valves FF, HH
Valve HH has flow rate=22; tunnel leads to valve GG
Valve II has flow rate=0; tunnels lead to valves AA, JJ
Valve JJ has flow rate=21; tunnel leads to valve II""")

In [296]:
def floyd_warshall(graph):
    # all-to-all shortest distances graph
    D = {(i,j): 1 if i in graph[j] else float('inf') 
         for i, j in itertools.product(graph, repeat=2)}
    for k, i, j in itertools.product(graph, repeat=3):
        D[i,j] = min(D[i,j], D[i,k]+D[k,j])
    return D

In [297]:
def preprocess(data):
    rates = {}
    graph = {}
    for valve, rate, neighbors in data:
        rates[valve] = rate
        graph[valve] = neighbors
    distances = floyd_warshall(graph)
    functioning_valves = [v for v, r in rates.items() if r > 0]
    return functioning_valves, distances, rates

In [308]:
import functools

def solve_a(data):
    functioning_valves, distances, rates = preprocess(data)
    
    @functools.lru_cache(maxsize=None)
    def max_pressure(remaining, location, state):
        # Stop option
        options = [remaining * rates[location]]
        # Open valve
        for i, valve in enumerate(functioning_valves):
            travel = distances[location,valve] + 1
            if not bool((state >> i) & 0x1) and travel < remaining:
                new_state = state | (1 << i)
                pressure = max_pressure(remaining-travel, valve, new_state)
                options.append(remaining * rates[location] + pressure)
        return max(options)
    
    return max_pressure(30, 'AA', 0)

In [309]:
def solve_b(data):
    functioning_valves, distances, rates = preprocess(data)
    
    @functools.lru_cache(maxsize=None)
    def max_pressure(remaining, location, state, elephant=False):
        # Stop option
        if elephant:
            options = [remaining * rates[location]]
        else:
            elephant_pressure = max_pressure(26, 'AA', state, elephant=True)
            options = [remaining * rates[location] + elephant_pressure]
        # Open valve
        for i, valve in enumerate(functioning_valves):
            travel = distances[location,valve] + 1
            if not bool((state >> i) & 0x1) and travel < remaining:
                new_state = state | (1 << i)
                pressure = max_pressure(remaining-travel, valve, new_state, elephant=elephant)
                options.append(remaining * rates[location] + pressure)
        return max(options)
    
    return max_pressure(26, 'AA', 0)

In [307]:
%%time
solve_b(sample)

CPU times: user 9.75 ms, sys: 902 µs, total: 10.7 ms
Wall time: 10.1 ms


1707

In [303]:
%%time
solve_a(data)

CPU times: user 426 ms, sys: 7.18 ms, total: 433 ms
Wall time: 435 ms


1474

In [304]:
solve_a(sample)

1651

In [None]:


def solve_enumerate(data, part):
    functioning_valves, distances, rates = preprocess(data)

    @functools.lru_cache(maxsize=None)
    def generate_paths(remaining, location, state):
        # Stop option
        paths = [[(remaining, location)]]
        # Open valve
        for i, valve in enumerate(functioning_valves):
            travel = distances[location,valve] + 1
            if not bool((state >> i) & 0x1) and travel < remaining:
                new_state = state | (1 << i)
                for subpath in generate_paths(remaining-travel, valve, new_state):
                    paths.append([(remaining, location)] + subpath)
        return paths
    
    if part == 'a':
        return max(
            sum(remaining * rates[valve] for remaining, valve in path)
            for path in generate_paths(30, 'AA', 0)
        )

In [332]:
%%time
solve_enumerate(data, 'a')

CPU times: user 2.61 s, sys: 89.1 ms, total: 2.7 s
Wall time: 2.74 s


1474

In [340]:
solve_a(sample)

1651

In [405]:
def solve_b(data):
    functioning_valves, distances, rates = preprocess(data)

    @functools.lru_cache(maxsize=None)
    def max_pressure_state(remaining, location, state):
        # Stop option
        options = {state: remaining * rates[location]}
        # Open valve
        for i, valve in enumerate(functioning_valves):
            travel = distances[location, valve] + 1
            if not bool((state >> i) & 0x1) and travel < remaining:
                new_state = state | (1 << i)
                for final, pressure in max_pressure_state(
                    remaining - travel, valve, new_state
                ).items():
                    options[final] = max(
                        options.get(final, 0), remaining * rates[location] + pressure
                    )
        return options

    # get the best split by final state
    per_state = max_pressure_state(26, "AA", 0)
    per_state = dict(sorted(per_state.items(), key=lambda x: -x[1]))
    # consider all disjoint combinations
    best = 0
    for mine, p1 in per_state.items():
        for elephant, p2 in per_state.items():
            if p1 + p2 < best:
                break
            if mine & elephant == 0:  # disjoint
                best = max(best, per_state[mine] + per_state[elephant])
    return best

In [406]:
%%time
solve_b(data)

CPU times: user 230 ms, sys: 8.12 ms, total: 238 ms
Wall time: 239 ms


2100

In [408]:
%%time
solve_b(sample)

CPU times: user 4.39 ms, sys: 189 µs, total: 4.58 ms
Wall time: 4.58 ms


1707

In [354]:
%%time
solve_b(data)

3428
CPU times: user 1.82 s, sys: 29.4 ms, total: 1.85 s
Wall time: 1.87 s


2100

In [350]:
best = 0
for a, b in itertools.product(ps, repeat=2):
    if a & b == 0:
        best = max(best, ps[a]+ps[b])
best

2100

In [None]:
generate_paths

In [None]:


def solve_enumerate(data, part):

    
    if part == 'a':
        return max(
            sum(remaining * rates[valve] for remaining, valve in path)
            for path in generate_paths(30, 'AA', 0)
        )

In [191]:
best_path

<function __main__.best_path(remaining)>

In [186]:
for s in generate_paths(20, 'AA', 0):
    print(s)

[(20, 'AA')]
[(20, 'AA'), (18, 'BB')]
[(20, 'AA'), (18, 'BB'), (16, 'CC')]
[(20, 'AA'), (18, 'BB'), (16, 'CC'), (14, 'DD')]
[(20, 'AA'), (18, 'BB'), (16, 'CC'), (14, 'DD'), (12, 'EE')]
[(20, 'AA'), (18, 'BB'), (16, 'CC'), (14, 'DD'), (12, 'EE'), (8, 'HH')]
[(20, 'AA'), (18, 'BB'), (16, 'CC'), (14, 'DD'), (12, 'EE'), (7, 'JJ')]
[(20, 'AA'), (18, 'BB'), (16, 'CC'), (14, 'DD'), (9, 'HH')]
[(20, 'AA'), (18, 'BB'), (16, 'CC'), (14, 'DD'), (9, 'HH'), (5, 'EE')]
[(20, 'AA'), (18, 'BB'), (16, 'CC'), (14, 'DD'), (9, 'HH'), (1, 'JJ')]
[(20, 'AA'), (18, 'BB'), (16, 'CC'), (14, 'DD'), (10, 'JJ')]
[(20, 'AA'), (18, 'BB'), (16, 'CC'), (14, 'DD'), (10, 'JJ'), (5, 'EE')]
[(20, 'AA'), (18, 'BB'), (16, 'CC'), (14, 'DD'), (10, 'JJ'), (5, 'EE'), (1, 'HH')]
[(20, 'AA'), (18, 'BB'), (16, 'CC'), (14, 'DD'), (10, 'JJ'), (2, 'HH')]
[(20, 'AA'), (18, 'BB'), (16, 'CC'), (13, 'EE')]
[(20, 'AA'), (18, 'BB'), (16, 'CC'), (13, 'EE'), (11, 'DD')]
[(20, 'AA'), (18, 'BB'), (16, 'CC'), (13, 'EE'), (11, 'DD'), (6, 'HH')]

{('AA', 'AA'): 2,
 ('AA', 'BB'): 1,
 ('AA', 'CC'): 2,
 ('AA', 'DD'): 1,
 ('AA', 'EE'): 2,
 ('AA', 'FF'): 3,
 ('AA', 'GG'): 4,
 ('AA', 'HH'): 5,
 ('AA', 'II'): 1,
 ('AA', 'JJ'): 2,
 ('BB', 'AA'): 1,
 ('BB', 'BB'): 2,
 ('BB', 'CC'): 1,
 ('BB', 'DD'): 2,
 ('BB', 'EE'): 3,
 ('BB', 'FF'): 4,
 ('BB', 'GG'): 5,
 ('BB', 'HH'): 6,
 ('BB', 'II'): 2,
 ('BB', 'JJ'): 3,
 ('CC', 'AA'): 2,
 ('CC', 'BB'): 1,
 ('CC', 'CC'): 2,
 ('CC', 'DD'): 1,
 ('CC', 'EE'): 2,
 ('CC', 'FF'): 3,
 ('CC', 'GG'): 4,
 ('CC', 'HH'): 5,
 ('CC', 'II'): 3,
 ('CC', 'JJ'): 4,
 ('DD', 'AA'): 1,
 ('DD', 'BB'): 2,
 ('DD', 'CC'): 1,
 ('DD', 'DD'): 2,
 ('DD', 'EE'): 1,
 ('DD', 'FF'): 2,
 ('DD', 'GG'): 3,
 ('DD', 'HH'): 4,
 ('DD', 'II'): 2,
 ('DD', 'JJ'): 3,
 ('EE', 'AA'): 2,
 ('EE', 'BB'): 3,
 ('EE', 'CC'): 2,
 ('EE', 'DD'): 1,
 ('EE', 'EE'): 2,
 ('EE', 'FF'): 1,
 ('EE', 'GG'): 2,
 ('EE', 'HH'): 3,
 ('EE', 'II'): 3,
 ('EE', 'JJ'): 4,
 ('FF', 'AA'): 3,
 ('FF', 'BB'): 4,
 ('FF', 'CC'): 3,
 ('FF', 'DD'): 2,
 ('FF', 'EE'): 1,
 ('FF', 'F

In [134]:
# import z3
# graph, rates = make_graph(sample)

# valves = list(rates)

# location = {v: [z3.Int(f'loc_{v}_{t}') for t in range(30)] for v in valves}
# openv = {v: [z3.Int(f'openv_{v}_{t}') for t in range(30)] for v in valves}

# opt = z3.Optimize()

# T = 30

# for t in range(T):
#     for v in valves:
#         for var in location, openv:
#             opt.add(var[v][t] >= 0)
#             opt.add(var[v][t] <= 1)

# opt.add(location['AA'][0] == 1)
# # Only one location per timestep
# for t in range(T):
#     opt.add( z3.Sum([location[v][t] for v in valves]) == 1)
    
# for t in range(T):
#     for v in valves:
#         opt.add( location[v][t] >= openv[v][t])
#         if t < T-1:
#             opt.add( location[v][t+1] >= openv[v][t])
            
# for v in valves:
#     opt.add( z3.Sum([openv[v][t] for t in range(T)]) <= 1)
    
# for t in range(T-1):
#     for v in valves:
#         neighbors = [v]+list(graph[v])
#         opt.add( z3.Sum([location[w][t] for w in neighbors]) >= location[v][t+1] )
            
# pressure = z3.Sum([
#     z3.Sum([ openv[v][t] * (T-1-t) * rates[v] for v in valves ])
#     for t in range(T)
# ])

# c1 = opt.maximize(pressure)
# opt.check()

In [135]:
# graph, rates = make_graph(sample)

# valves = list(rates)

# location = {v: [z3.Bool(f'loc_{v}_{t}') for t in range(30)] for v in valves}
# openv = {v: [z3.Bool(f'openv_{v}_{t}') for t in range(30)] for v in valves}

# opt = z3.Optimize()

# T = 30

# opt.add(location['AA'][0] == True)
# # Only one location per timestep
# for t in range(T):
#     opt.add( z3.PbEq([(location[v][t],1) for v in valves], 1))
    
# for t in range(T):
#     for v in valves:
#         opt.add( z3.Implies(openv[v][t], location[v][t]))
#         if t < T-1:
#             opt.add( z3.Implies(openv[v][t], location[v][t+1]))
            
# for v in valves:
#     opt.add( z3.PbLe([(openv[v][t],1) for t in range(T)], 1))
    
# for t in range(T-1):
#     for v in valves:
#         neighbors = [v]+list(graph[v])
#         opt.add( z3.Implies(location[v][t+1], z3.Or([location[w][t] for w in neighbors])))
            
# pressure = z3.Sum([
#     z3.Sum([ z3.If(openv[v][t], (T-1-t) * rates[v],0) for v in valves ])
#     for t in range(T)
# ])

# c1 = opt.maximize(pressure)
# opt.check()
# c1.value()

In [136]:
# from pyrsistent import pmap
# def solve_a(data):
#     graph, rates = make_graph(sample)
#     valves = list(rates)
#     nonzero_valves = [v != 0 for k, v in rates.items()]
#     stack = [(0, 'AA', 0, pmap({v: False for v in valves}) )]
#     visited = set()
#     visited.add(stack[0])
#     max_pressure = 0
#     T = 30
#     while stack:
#         print(max_pressure, stack[-1])
#         time, location, pressure, opened = stack.pop()
#         if time == T or sum(v for k, v in opened.items()) == nonzero_valves:
#             max_pressure = max(max_pressure, pressure)
#             continue
#         if not opened[location] and rates[location] > 0:
#             release = (T-time-1)*rates[location]
#             next_ = (time+1, location, pressure+release, opened.set(location, True))
#             if next_ not in visited:
#                 stack.append(next_)
#                 visited.add(next_)
#         for neighbor in graph[location]:
#             next_ = (time+1, neighbor, pressure, opened)
#             if next_ not in visited:
#                 stack.append(next_)
#                 visited.add(next_)
#     return max_pressure
        
            

In [137]:
make_graph(sample)

({'AA': ['DD', 'II', 'BB'],
  'BB': ['CC', 'AA'],
  'CC': ['DD', 'BB'],
  'DD': ['CC', 'AA', 'EE'],
  'EE': ['FF', 'DD'],
  'FF': ['EE', 'GG'],
  'GG': ['FF', 'HH'],
  'HH': ['GG'],
  'II': ['AA', 'JJ'],
  'JJ': ['II']},
 {'AA': 0,
  'BB': 13,
  'CC': 2,
  'DD': 20,
  'EE': 3,
  'FF': 0,
  'GG': 0,
  'HH': 22,
  'II': 0,
  'JJ': 21})

In [184]:
def solve_a(data, n=30):
    graph, rates = make_graph(data)
    G = nx.Graph(graph)
        D = nx.floyd_warshall(G)
    functioning_valves = [k for k, v in rates.items() if v > 0]

    cache = {}
    
    def test_open(state, v):
        i = functioning_valves.index(v)
        return bool((state >> i) & 0x1)
    
    # opened is a bit encoded var 
    def dp(remaining, location, opened):
        if (remaining, location, opened) not in cache:
            if remaining == 0:
                pressure = 0
            else:
                open_valves = set([v for v in functioning_valves if test_open(opened, v)])
                pressure_t = sum(rates[v] for v in open_valves)
                pressure = 0
                for v in functioning_valves:
                    if v not in open_valves:
                        d = int(D[location][v]) + 1
                        i = functioning_valves.index(v)
                        new_opened = opened | (1 << i)
                        if remaining >= d:
                            # go to v
                            pressure = max(pressure, pressure_t*d + dp(remaining-d, v, new_opened))
                # Stay put until the end
                pressure = max(pressure, pressure_t*remaining)
            
            cache[remaining, location, opened] = pressure
        return cache[remaining, location, opened]
        
    return dp(n, 'AA', 0)

IndentationError: unexpected indent (<ipython-input-184-b1a969a8d30e>, line 4)

In [123]:
def solve_b(data, n=26):
    graph, rates = make_graph(data)
    G = nx.Graph(graph)
    D = nx.floyd_warshall(G)
    functioning_valves = [k for k, v in rates.items() if v > 0]

    cache = {}
    
    def test_open(state, v):
        i = functioning_valves.index(v)
        return bool((state >> i) & 0x1)
    
    def open_valve(state, v):
        if v == 'AA': return state
        i = functioning_valves.index(v)
        return state | (1 << i)
    
    # opened is a bit encoded var 
    def dp(remaining, loc1, moving1, loc2, moving2, opened):
        if (remaining, loc1, moving1, loc2, moving2, opened) not in cache:
            if remaining == 0:
                pressure = 0
            else:
                open_valves = set([v for v in functioning_valves if test_open(opened, v)])
                pressure_t = sum(rates[v] for v in open_valves)
                pressure = 0
                closed_valves = [v for v in functioning_valves if v not in open_valves]
                
                # move 1
                if moving1 == 0:
                    opened = open_valve(opened, loc1)
                    locs1 = [(v, int(D[loc1][v]) + 1) for v in closed_valves+[loc1]]
                else:
                    locs1 = [(loc1, moving1-1)]
                    
                if moving2 == 0:
                    opened = open_valve(opened, loc2)
                    locs2 = [(v, int(D[loc2][v]) + 1) for v in closed_valves+[loc2]]
                else:
                    locs2 = [(loc2, moving2-1)]
                    
                for (loc1, moving1), (loc2, moving2) in itertools.product(locs1, locs2):
                    if loc1 != loc2:
                        pressure = max(pressure, pressure_t + dp(remaining-1, loc1, moving1, loc2, moving2, opened) )
            
            cache[remaining, loc1, moving1, loc2, moving2, opened] = pressure
        return cache[remaining, loc1, moving1, loc2, moving2, opened]
        
    dp(n, 'AA', 0, 'AA', 0, 0)
    return cache

In [130]:
solve_b(sample, 5)

{(0, 'BB', 0, 'CC', 1, 3): 0,
 (0, 'BB', 0, 'DD', 2, 3): 0,
 (0, 'BB', 0, 'EE', 3, 3): 0,
 (0, 'BB', 0, 'HH', 6, 3): 0,
 (0, 'BB', 0, 'JJ', 5, 3): 0,
 (1, 'BB', 0, 'CC', 1, 3): 13,
 (0, 'DD', 2, 'CC', 1, 3): 0,
 (0, 'DD', 2, 'EE', 3, 3): 0,
 (0, 'DD', 2, 'HH', 6, 3): 0,
 (0, 'DD', 2, 'JJ', 5, 3): 0,
 (1, 'DD', 2, 'CC', 1, 3): 13,
 (0, 'EE', 3, 'CC', 1, 3): 0,
 (0, 'EE', 3, 'DD', 2, 3): 0,
 (0, 'EE', 3, 'HH', 6, 3): 0,
 (0, 'EE', 3, 'JJ', 5, 3): 0,
 (1, 'EE', 3, 'CC', 1, 3): 13,
 (0, 'HH', 6, 'CC', 1, 3): 0,
 (0, 'HH', 6, 'DD', 2, 3): 0,
 (0, 'HH', 6, 'EE', 3, 3): 0,
 (0, 'HH', 6, 'JJ', 5, 3): 0,
 (1, 'HH', 6, 'CC', 1, 3): 13,
 (0, 'JJ', 3, 'CC', 1, 3): 0,
 (0, 'JJ', 3, 'DD', 2, 3): 0,
 (0, 'JJ', 3, 'EE', 3, 3): 0,
 (0, 'JJ', 3, 'HH', 6, 3): 0,
 (1, 'JJ', 3, 'CC', 1, 3): 13,
 (2, 'BB', 1, 'CC', 0, 1): 13,
 (3, 'BB', 0, 'CC', 1, 0): 13,
 (4, 'BB', 1, 'CC', 2, 0): 13,
 (0, 'BB', 0, 'CC', 1, 5): 0,
 (1, 'BB', 0, 'CC', 1, 5): 33,
 (0, 'BB', 0, 'DD', 0, 5): 0,
 (1, 'BB', 0, 'DD', 0, 5): 33,


In [102]:
%%time
solve_a(sample)

CPU times: user 3.8 ms, sys: 110 µs, total: 3.91 ms
Wall time: 4.04 ms


1651

In [103]:
%%time
solve_a(data)

CPU times: user 954 ms, sys: 29.1 ms, total: 983 ms
Wall time: 1 s


1474

In [74]:
import networkx as nx

In [75]:

# from pyrsistent import pmap

# from collections import Counter

# def solve_a(data):
#     graph, rates = make_graph(data)
#     G = nx.Graph(graph)
#     D = nx.floyd_warshall(G)
#     functioning_valves = [k for k, v in rates.items() if v > 0]
        
#     stack = [(0, 'AA', 0, pmap({v: False for v in functioning_valves}) )]

#     max_pressure = 0
#     T = 30
    
#     while stack:
#         time, location, pressure, opened = stack.pop()
#         max_pressure = max(max_pressure, pressure)
#         if sum(v for k, v in opened.items()) == len(functioning_valves):
#             continue
        
#         for next_location in functioning_valves:
#             if not opened[next_location]:
#                 # opening valve
#                 next_time = time + int(D[location][next_location]) + 1
#                 if next_time < T:
#                     next_opened = opened.set(next_location, True)
#                     next_pressure = pressure + (T-next_time) * rates[next_location]
#                     stack.append((next_time, next_location, next_pressure, next_opened))

#     return max_pressure


In [8]:
from collections import Counter

In [9]:
def solve_a(data):
    graph, rates = make_graph(data)
    G = nx.Graph(graph)
    D = nx.floyd_warshall(G)
    functioning_valves = [k for k, v in rates.items() if v > 0]
        
    stack = [(0, 'AA', 0, pmap({v: False for v in functioning_valves}) )]

    max_pressure = 0
    T = 30
    
    while stack:
        time, location, pressure, opened = stack.pop()
        max_pressure = max(max_pressure, pressure)
        if sum(v for k, v in opened.items()) == len(functioning_valves):
            continue
        
        for next_location in functioning_valves:
            if not opened[next_location]:
                # opening valve
                next_time = time + int(D[location][next_location]) + 1
                if next_time < T:
                    next_opened = opened.set(next_location, True)
                    next_pressure = pressure + (T-next_time) * rates[next_location]
                    stack.append((next_time, next_location, next_pressure, next_opened))

    return max_pressure

In [10]:
%%time
solve_a(sample)

CPU times: user 29.5 ms, sys: 3.08 ms, total: 32.5 ms
Wall time: 32.2 ms


1651

In [11]:
%%time
solve_a(data)

CPU times: user 3.95 s, sys: 83.1 ms, total: 4.04 s
Wall time: 4.11 s


1474

In [76]:

# %%time
# solve_a(sample)

# %%time
# solve_a(data)

In [77]:
def solve_a(data):
    graph, rates = make_graph(data)
    G = nx.Graph(graph)
    D = nx.floyd_warshall(G)
    functioning_valves = [k for k, v in rates.items() if v > 0]
    order = functioning_valves + ['AA']
    distances = [[D[v][w] for w in order] for v in order]
    graph = {order.index(v): [order.index(w) for w in neigh if w in order] for v, neigh in graph.items() if v in order}
    rates = [rates[v] for v in order]
    
    def test_open(state, v):
        return bool((state >> v) & 0x1)
        
    def open_valve(state, v):
        return state | (1 << v)


    functioning_valves = [i for i, v in enumerate(rates) if v > 0]
        
    stack = [(0, order.index('AA'), 0, 0)]

    max_pressure = 0
    T = 30
    
    while stack:
        time, location, pressure, opened = stack.pop()
        max_pressure = max(max_pressure, pressure)
        if bin(opened).count('1') == len(functioning_valves):
            continue
        
        for next_location in functioning_valves:
            if not test_open(opened, next_location):
                # opening valve
                next_time = time + int(distances[location][next_location]) + 1
                if next_time < T:
                    next_opened = open_valve(opened, next_location)
                    next_pressure = pressure + (T-next_time) * rates[next_location]
                    stack.append((next_time, next_location, next_pressure, next_opened))

    return max_pressure

In [78]:
%%time
solve_a(sample)

CPU times: user 5.71 ms, sys: 194 µs, total: 5.91 ms
Wall time: 6.01 ms


1651

In [79]:
%%time
solve_a(data)

CPU times: user 1.25 s, sys: 29 ms, total: 1.28 s
Wall time: 1.31 s


1474

In [120]:
from collections import namedtuple

In [None]:
def solve_a(data):
    graph, rates = make_graph(data)
    G = nx.Graph(graph)
    D = nx.floyd_warshall(G)
    functioning_valves = [k for k, v in rates.items() if v > 0]
    order = functioning_valves + ['AA']
    distances = [[D[v][w] for w in order] for v in order]
    graph = {order.index(v): [order.index(w) for w in neigh if w in order] for v, neigh in graph.items() if v in order}
    rates = [rates[v] for v in order]
    
    def test_open(state, v):
        return bool((state >> v) & 0x1)
        
    def open_valve(state, v):
        return state | (1 << v)


    functioning_valves = [i for i, v in enumerate(rates) if v > 0]
        
    stack = [(0, order.index('AA'), 0, 0)]

    max_pressure = 0
    T = 30
    
    while stack:
        time, location, pressure, opened = stack.pop()
        max_pressure = max(max_pressure, pressure)
        if bin(opened).count('1') == len(functioning_valves):
            continue
        
        for next_location in functioning_valves:
            if not test_open(opened, next_location):
                # opening valve
                next_time = time + int(distances[location][next_location]) + 1
                if next_time < T:
                    next_opened = open_valve(opened, next_location)
                    next_pressure = pressure + (T-next_time) * rates[next_location]
                    stack.append((next_time, next_location, next_pressure, next_opened))

    return max_pressure

In [171]:
from collections import deque

In [224]:
def solve_b(data):
    graph, rates = make_graph(data)
    G = nx.Graph(graph)
    D = nx.floyd_warshall(G)
    functioning_valves = [k for k, v in rates.items() if v > 0]
    order = functioning_valves + ['AA']
    distances = [[D[v][w] for w in order] for v in order]
    graph = {order.index(v): [order.index(w) for w in neigh if w in order] for v, neigh in graph.items() if v in order}
    rates = [rates[v] for v in order]
    
    def test_open(state, v):
        return bool((state >> v) & 0x1)
        
    def open_valve(state, v):
        return state | (1 << v)

    functioning_valves = [i for i in range(len(functioning_valves))]
                
    stack = [(0, order.index('AA'), 0, order.index('AA'), 0, 0)]
    stack = stack
    
    max_pressure = 0
    T = 26
    
    while stack:

        time1, location1, time2, location2, pressure, opened = stack.pop()
        max_pressure = max(max_pressure, pressure)
        if bin(opened).count('1') == len(functioning_valves):
            continue
        
        for next_location in functioning_valves:
            if not test_open(opened, next_location):
                # opening valve
                next_time = time1 + int(distances[location1][next_location]) + 1
                if next_time < T:
                    next_opened = open_valve(opened, next_location)
                    next_pressure = pressure + (T-next_time) * rates[next_location]
                    stack.append((next_time, next_location, time2, location2, next_pressure, next_opened))
                    
        for next_location in functioning_valves:
            if not test_open(opened, next_location):
                # opening valve
                next_time = time2 + int(distances[location2][next_location]) + 1
                if next_time < T:
                    next_opened = open_valve(opened, next_location)
                    next_pressure = pressure + (T-next_time) * rates[next_location]
                    stack.append((time1, location1, next_time, next_location, next_pressure, next_opened))

    return max_pressure

In [259]:
def solve_b(data):
    graph, rates = make_graph(data)
    G = nx.Graph(graph)
    D = nx.floyd_warshall(G)
    functioning_valves = [k for k, v in rates.items() if v > 0]
    order = functioning_valves + ['AA']
    distances = [[D[v][w] for w in order] for v in order]
    graph = {order.index(v): [order.index(w) for w in neigh if w in order] for v, neigh in graph.items() if v in order}
    rates = [rates[v] for v in order]
    
    def test_open(state, v):
        return bool((state >> v) & 0x1)
        
    def open_valve(state, v):
        return state | (1 << v)

    functioning_valves = [i for i in range(len(functioning_valves))]
                
    stack = [(0, order.index('AA'), 0, order.index('AA'), 0, 0)]
    stack = stack
    
    visited = set(stack)
    
    max_pressure = 0
    T = 26
        
    def hypothetical_pressure(valve, time, location):
        t = 1 + int(distances[location][valve]) + time
        if t >= T:
            return 0
        return rates[valve] * (T-t)
    
    def remaining_upper_bound(t1, l1, t2, l2, opened):
        remaining = 0
        for valve in functioning_valves:
            if not test_open(opened, valve):
                remaining += max(
                    hypothetical_pressure(valve,t1,l1), 
                    hypothetical_pressure(valve,t2,l2)
                )
        return remaining
    
    i = 0
    while stack:
        i += 1
        if i % 100_000 == 0:
            print(i, max_pressure, len(stack))

        time1, location1, time2, location2, pressure, opened = stack.pop()
        max_pressure = max(max_pressure, pressure)
        if bin(opened).count('1') == len(functioning_valves):
            continue
        
        for next_location in functioning_valves:
            if not test_open(opened, next_location):
                # opening valve
                next_time = time1 + int(distances[location1][next_location]) + 1
                if next_time < T:
                    next_opened = open_valve(opened, next_location)
                    next_pressure = pressure + (T-next_time) * rates[next_location]
                    
                    new = (next_time, next_location, time2, location2, next_pressure, next_opened)
                    if new not in visited:
                        if next_pressure + remaining_upper_bound(next_time, next_location, time2, location2, next_opened) > max_pressure:
                            visited.add(new)
                            stack.append(new)
                            
                    
        for next_location in functioning_valves:
            if not test_open(opened, next_location):
                # opening valve
                next_time = time2 + int(distances[location2][next_location]) + 1
                if next_time < T:
                    next_opened = open_valve(opened, next_location)
                    next_pressure = pressure + (T-next_time) * rates[next_location]
                    new = (time1, location1, next_time, next_location, next_pressure, next_opened)
                    if new not in visited:
                        if next_pressure + remaining_upper_bound(time1, location1, next_time, next_location, next_opened) > max_pressure:
                            visited.add(new)
                            stack.append(new)
                            

    return max_pressure

In [260]:
solve_b(sample)

1707

In [262]:
%%time
solve_b(data)

100000 1890 85
200000 1890 72
300000 1890 69
400000 2021 74
500000 2076 67
600000 2100 67
700000 2100 62
800000 2100 39
900000 2100 54
1000000 2100 45
1100000 2100 92
1200000 2100 58
1300000 2100 38
1400000 2100 99


KeyboardInterrupt: 

In [242]:
2**15

32768

In [None]:
def solve_b(data):
    graph, rates = make_graph(data)
    G = nx.Graph(graph)
    D = nx.floyd_warshall(G)
    functioning_valves = [k for k, v in rates.items() if v > 0]
    order = functioning_valves + ['AA']
    distances = [[D[v][w] for w in order] for v in order]
    graph = {order.index(v): [order.index(w) for w in neigh if w in order] for v, neigh in graph.items() if v in order}
    rates = [rates[v] for v in order]
    
    def test_open(state, v):
        return bool((state >> v) & 0x1)
        
    def open_valve(state, v):
        return state | (1 << v)

    functioning_valves = [i for i in range(len(functioning_valves))]
    
    T = 26
    dp_pressure = [defaultdict(int) for t in range(T+1)]
    dp_pressure[0] = {(order.index('AA'), order.index('AA'), 0)}
    
    for t in range(T):
        for (loc1, loc2, opened), pressure in dp_pressure.items():
            