In [None]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import matplotlib.pyplot as plt
from rn.utils.utils import save_json, make_dir

In [None]:
import random
from itertools import groupby


# Every reward value a link may carry; a reward's index in this list is its id.
rewards = [-100, -20, 20, 140]
n_rewards = len(rewards)
# Reverse lookup: reward value -> index in `rewards`.
reward_id_map = {r: i for i, r in enumerate(rewards)}


def node_links(actions, **kwargs):
    """Build the outgoing-link table of a network, keyed by source node.

    Args:
        actions: iterable of dicts, each with the keys 'sourceId',
            'targetId' and 'reward'.
        **kwargs: ignored; accepted so that a whole network dict can be
            unpacked into this function.

    Returns:
        A dict mapping each 'sourceId' to a list of link dicts with the keys
        'targetId', 'reward' and 'rewardId' (the reward's index in
        ``rewards``), sorted by ascending reward.

    Raises:
        KeyError: if an action is missing one of the expected keys or its
            reward is not listed in ``reward_id_map``.
    """
    # Accumulate per source node instead of using itertools.groupby:
    # groupby only merges *consecutive* items, so actions that are not
    # already contiguous by 'sourceId' would otherwise overwrite each other
    # and silently lose links.
    links_by_source = {}
    for action in actions:
        links_by_source.setdefault(action['sourceId'], []).append({
            'targetId': action['targetId'],
            'reward': action['reward'],
            'rewardId': reward_id_map[action['reward']],
        })
    return {
        source_node: sorted(links, key=lambda l: l['reward'])
        for source_node, links in links_by_source.items()
    }


class NetworkEnvironments:
    """Episodic environment that walks a randomly chosen reward network.

    Each episode picks one of the configured networks at random, starts at
    that network's starting node and lasts exactly ``n_steps`` steps.
    """

    def __init__(self, networks, n_steps):
        """Pre-process the networks and start the first episode.

        Args:
            networks: iterable of dicts, each unpackable into ``node_links``
                (i.e. providing 'actions') and containing the keys
                'starting_node' and 'max_reward'.
            n_steps: number of steps after which an episode is done.
        """
        self.networks = [
            {
                'node_links': node_links(**n),
                'starting_node': n['starting_node'],
                'max_reward': n['max_reward']
            }
            for n in networks
        ]
        self.n_steps = n_steps
        self.reset()

    def step(self, action):
        """Follow one outgoing link of the current node.

        Args:
            action: index into the current node's list of links.

        Returns:
            A tuple ``(observation, reward, done, info)``: ``observation`` is
            the result of ``observe`` for the new node, or ``None`` once the
            episode is done; ``reward`` is the reward of the chosen link;
            ``info`` is ``{'max_reward': ...}`` for the current network.

        Raises:
            AssertionError: if the episode is already done.
        """
        # Raised explicitly (rather than via an `assert` statement) so the
        # guard is not stripped when Python runs with -O; the exception type
        # and message are the same as before.
        if self.done:
            raise AssertionError('Environment is done already.')
        selected_link = self.node_links[self.node][action]
        reward = selected_link['reward']
        self.node = selected_link['targetId']

        self.step_count += 1
        if self.step_count >= self.n_steps:
            self.done = True
            # No observation is produced for the terminal state.
            observation = None
        else:
            observation = self.observe(self.step_count, self.node_links[self.node])

        return observation, reward, self.done, {'max_reward': self.max_reward}

    @staticmethod
    def observe(step_count, node_links):
        """Return ``(step_count, rewardId_0, rewardId_1, ...)`` for the given links.

        Args:
            step_count: number of steps taken so far in the episode.
            node_links: iterable of link dicts, each with a 'rewardId' key.
        """
        return (step_count, *(nl['rewardId'] for nl in node_links))

    def reset(self):
        """Start a new episode on a randomly chosen network.

        Returns:
            The initial observation for the network's starting node.
        """
        network = random.choice(self.networks)
        self.node_links = network['node_links']
        self.node = network['starting_node']
        self.max_reward = network['max_reward']
        self.step_count = 0
        self.done = False
        return self.observe(self.step_count, self.node_links[self.node])