In [1]:
%load_ext autoreload
%autoreload 2

In [5]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from hypergrid_env import HyperGrid

### Sanity check

In [161]:
# Small 2-D hypergrid. H is the grid size parameter; R0/R1/R2 are passed
# straight to HyperGrid (presumably the reward parameters — confirm in hypergrid_env).
ndim = 2
H = 8
R0 = 1e-2
R1 = 0.5
R2 = 2
env = HyperGrid(ndim, H, R0, R1, R2)

In [162]:
env.n_states  # 81 here; matches (H + 1) ** ndim if coordinates run 0..H — TODO confirm in HyperGrid

81

### DAG class

In [163]:
class DAG:
    """Minimal directed graph over hashable states, stored as adjacency lists.

    States are identified by their insertion index; ``state_dict`` maps each
    state object back to that index. A special sink state ``'sf'`` can be
    appended with :meth:`add_final_state`.
    """

    def __init__(self):
        self.n_states = 0
        self.n_edges = 0
        self.states = []       # index -> state object
        self.state_dict = {}   # state object -> index
        self.parents = []      # index -> list of parent indices
        self.children = []     # index -> list of child indices
        self.edges = []        # (parent_idx, child_idx) tuples, in insertion order

    def add_states(self, state_list):
        """Append states in order.

        Args:
            state_list: iterable of states. Each must be hashable because it
                becomes a key of ``state_dict``.
        """
        for state in state_list:
            self.states.append(state)
            self.state_dict[state] = self.n_states
            self.n_states += 1
            self.parents.append([])
            self.children.append([])

    def add_final_state(self):
        """Append the sink state ``'sf'``.

        The sink gets an (empty) children list like every other state, so that
        ``get_children(state='sf')`` returns no children instead of raising
        IndexError.
        """
        self.add_states(['sf'])

    def add_edges(self, edge_list):
        """Add directed edges ``i -> j`` given as ``(i, j)`` pairs of state indices.

        Args:
            edge_list: iterable of ``(i, j)`` index pairs.
        """
        # Materialise first: the edges are consumed twice (adjacency lists and
        # self.edges), which would silently lose them if a generator were passed.
        edge_list = list(edge_list)
        for i, j in edge_list:
            self.parents[j].append(i)
            self.children[i].append(j)
        self.edges.extend(edge_list)
        self.n_edges += len(edge_list)

    def _resolve_index(self, state, state_idx):
        """Return ``state_idx`` if given, otherwise look ``state`` up in this graph."""
        if state_idx is None:
            state_idx = self.state_dict[state]
        return state_idx

    def get_parents(self, state=None, state_idx=None):
        """Return ``(parent_indices, parent_states)`` of a state.

        The state is given either as an object (``state``) or by index
        (``state_idx``); ``state_idx`` wins when both are provided.
        Raises KeyError if ``state`` is not in the graph.
        """
        state_idx = self._resolve_index(state, state_idx)
        parent_indices = self.parents[state_idx]
        parents = [self.states[index] for index in parent_indices]
        return parent_indices, parents

    def get_children(self, state=None, state_idx=None):
        """Return ``(children_indices, children_states)`` of a state.

        The state is given either as an object (``state``) or by index
        (``state_idx``); ``state_idx`` wins when both are provided.
        Raises KeyError if ``state`` is not in the graph.
        """
        state_idx = self._resolve_index(state, state_idx)
        children_indices = self.children[state_idx]
        children = [self.states[index] for index in children_indices]
        return children_indices, children

In [164]:
# One DAG node per grid cell (as a hashable tuple of its coordinates), then the sink s_f.
grid_points = env.grid.reshape(-1, ndim).numpy()
state_list = [tuple(row) for row in grid_points]
G = DAG()
G.add_states(state_list)
G.add_final_state()

In [165]:
G.n_states  # all grid states added above plus the sink s_f (82 = 81 + 1 here)

82

In [166]:
# Connect every grid state to each neighbour obtained by incrementing a single
# coordinate (only while that coordinate is below H), and to the sink s_f,
# which is always the last state index.
sink_idx = G.n_states - 1
for state in G.states:
    idx_state = G.state_dict[state]
    if idx_state == sink_idx:
        continue
    edge_list = [
        (idx_state, G.state_dict[state[:dim] + (state[dim] + 1,) + state[dim + 1:]])
        for dim in range(ndim)
        if state[dim] < H
    ]
    edge_list.append((idx_state, sink_idx))
    G.add_edges(edge_list)

In [167]:
G.edges[:10]  # first (parent_idx, child_idx) edges; index G.n_states - 1 (81 here) is s_f

[(0, 1),
 (0, 9),
 (0, 81),
 (1, 2),
 (1, 10),
 (1, 81),
 (2, 3),
 (2, 11),
 (2, 81),
 (3, 4)]

### Algorithm 3: propagate terminal rewards backward through the DAG to compute edge flows and state flows

In [168]:
# Define P_B
def uniform_P_B(s_idx, s_prime_idx, dag=None):
    """Uniform backward policy P_B(s | s').

    Every parent of ``s_prime_idx`` is equally likely, so the probability is
    ``1 / (number of parents of s')``. ``s_idx`` is accepted so the signature
    matches a general P_B(s, s'), but it does not affect the uniform result.

    Args:
        s_idx: index of the parent state s (unused by the uniform policy).
        s_prime_idx: index of the child state s'.
        dag: graph exposing ``get_parents(state_idx=...)``. Defaults to the
            notebook-global ``G`` so existing two-argument calls keep working.

    Returns:
        float: ``1 / |parents(s')|``.

    Raises:
        ValueError: if ``s_prime_idx`` has no parents (e.g. the initial state),
            where P_B is undefined.
    """
    if dag is None:
        dag = G
    parents_idx, _ = dag.get_parents(state_idx=s_prime_idx)
    if not parents_idx:
        raise ValueError(f"state {s_prime_idx} has no parents; P_B is undefined")
    return 1. / len(parents_idx)

In [169]:
# Initialise the state used by the backward pass:
#   Y - set of state indices already processed; starts with only s_f (index G.n_states - 1)
#   U - FIFO work list of (state_idx, state_flow) pairs whose flow is final
#   F - edge flows keyed by (parent_idx, child_idx); None until computed
#   V - state flows, seeded with each state's reward R(s)
sink_idx = G.n_states - 1
# Evaluate each reward once and reuse it for F, U and V.
# NOTE(review): the printed outputs show plain scalar rewards (immutable). If env.reward
# ever returns an array, copy before sharing, because the next cell updates V with `+=`.
rewards = {state_idx: env.reward(np.array(G.states[state_idx]))
           for state_idx in range(sink_idx)}
Y = {sink_idx}
U = []
F = {edge: None for edge in G.edges}
for state_idx, reward in rewards.items():
    # The terminating edge s -> s_f carries exactly R(s).
    F[(state_idx, sink_idx)] = reward
    children_idx, _ = G.get_children(state_idx=state_idx)
    if len(children_idx) == 1:
        # The only child is s_f, so this state's flow is already final: enqueue it.
        assert children_idx[0] == sink_idx
        U.append((state_idx, reward))
V = dict(rewards)

In [170]:
# Backward pass: take a state whose flow is final, push that flow to each of its
# parents through P_B, and enqueue a parent once every one of its children has
# been processed (so its accumulated flow V[s] is final too).
while U:
    s_prime, t = U.pop(0)
    Y.add(s_prime)
    parent_indices, _ = G.get_parents(state_idx=s_prime)
    for s in parent_indices:
        edge_flow = t * uniform_P_B(s, s_prime)
        F[(s, s_prime)] = edge_flow
        V[s] = V[s] + edge_flow if s in V else edge_flow
        child_indices, _ = G.get_children(state_idx=s)
        if all(child in Y for child in child_indices):
            U.append((s, V[s]))

In [171]:
# Flow-conservation check: the total flow leaving the initial state s_0 should equal
# both V[s_0] and the sum of rewards over all non-sink states.
# s_0 is taken to be index 0, the first state added to G — presumably the grid origin.
initial_edge_flows = {edge: flow for edge, flow in F.items() if edge[0] == 0}
print(initial_edge_flows)
total_flow = sum(initial_edge_flows.values())
print(total_flow)
print(V[0])
manual_total_reward = sum(env.reward(np.array(G.states[state_idx]))
                          for state_idx in range(G.n_states - 1))
print(manual_total_reward)
# Compare with a tolerance: the three totals are accumulated in different orders and
# can differ in the last bits (e.g. 16.81 vs 16.809999999999988).
assert np.isclose(total_flow, V[0]), (total_flow, V[0])
assert np.isclose(total_flow, manual_total_reward), (total_flow, manual_total_reward)

{(0, 1): 8.149999999999999, (0, 9): 8.15, (0, 81): 0.51}
16.81
16.81
16.809999999999988
