In [3]:
%load_ext autoreload
%autoreload 2

import os
import torch as th
import numpy as np
from scripts.utils.utils import load_yaml, make_dir, save_json, load_json
from scripts.pruning_models.model import calculate_reward_transition_matrices_new, calculate_q_matrix_avpruning, calculate_traces

# --- Experiment configuration -------------------------------------------
run = '1000000'  # name of the data sub-folder under ../data
n_nodes = 6      # presumably the number of nodes per network -- TODO confirm
n_steps = 8      # number of rows (steps) of the Q table built further down
n_actions = 2    # presumably the number of outgoing links per node -- TODO confirm

# --- Input files ---------------------------------------------------------
selected_folder = f'../data/{run}/selected'
test_file, train_file = (
    os.path.join(selected_folder, file_name)
    for file_name in ('test.json', 'train.json')
)

# Load the held-out and the training networks (test first, then train).
test_networks = load_json(test_file)
train_networks = load_json(train_file)

In [8]:
import random
from itertools import groupby

# Possible link rewards, in ascending order; a reward's position in this
# list is its integer id (used as the column index of the Q table).
rewards = [-100, -20, 20, 140]
n_rewards = len(rewards)
reward_id_map = dict(zip(rewards, range(n_rewards)))

def node_links(actions, **kwargs):
    """Group a network's links by their source node.

    Links are collected per source node regardless of their order in
    ``actions``. (Grouping only *consecutive* links, as
    ``itertools.groupby`` does, silently drops links whenever the same
    source node appears in two separate runs.)

    Args:
        actions: iterable of link dicts, each providing the keys
            'sourceId', 'targetId' and 'reward'.
        **kwargs: ignored; accepted so that a whole network dict can be
            unpacked into the call (``node_links(**network)``).

    Returns:
        dict mapping every source node id to the list of its outgoing
        links sorted by ascending reward. Each link is a dict with the
        keys 'targetId', 'reward' and 'rewardId'. Source nodes appear in
        order of first occurrence in ``actions``.

    Raises:
        KeyError: if a link lacks one of the required keys or its reward
            is not a key of the module-level ``reward_id_map``.
    """
    links_by_source = {}
    for link in actions:
        links_by_source.setdefault(link['sourceId'], []).append({
            'targetId': link['targetId'],
            'reward': link['reward'],
            'rewardId': reward_id_map[link['reward']],
        })

    # Stable sort: links with equal reward keep their input order.
    for links in links_by_source.values():
        links.sort(key=lambda l: l['reward'])
    return links_by_source


class NetworkEnvironments:
    """Fixed-length episodic environment over a pool of reward networks.

    Every call to ``reset`` picks one network at random. An observation is
    a list with one ``(step_count, rewardId)`` tuple per outgoing link of
    the current node; an action is an index into that list.
    """

    def __init__(self, networks, n_steps):
        """Pre-process the networks and start a first episode.

        Args:
            networks: iterable of network dicts. Each is unpacked into
                ``node_links`` and must provide 'starting_node' and
                'max_reward'.
            n_steps: number of steps after which an episode is done.
        """
        prepared = []
        for network in networks:
            prepared.append({
                'node_links': node_links(**network),
                'starting_node': network['starting_node'],
                'max_reward': network['max_reward'],
            })
        self.networks = prepared
        self.n_steps = n_steps
        self.reset()

    def step(self, action):
        """Follow the link with index ``action`` from the current node.

        Returns:
            Tuple ``(observation, reward, done, info)``. ``observation`` is
            None once the episode is done; ``info`` carries the network's
            'max_reward'.

        Raises:
            AssertionError: if the episode is already done.
        """
        assert not self.done, 'Environment is done already.'
        link = self.node_links[self.node][action]
        reward = link['reward']  # returned as stored, without any offset
        self.node = link['targetId']
        self.step_count += 1

        info = {'max_reward': self.max_reward}
        if self.step_count >= self.n_steps:
            # Step budget used up: finish the episode without an observation.
            self.done = True
            return None, reward, self.done, info

        next_observation = self.observe(
            self.step_count, self.node_links[self.node]
        )
        return next_observation, reward, self.done, info

    @staticmethod
    def observe(step_count, node_links):
        """Pair the step counter with the reward id of each given link."""
        observation = []
        for link in node_links:
            observation.append((step_count, link['rewardId']))
        return observation

    def reset(self):
        """Start a new episode on a randomly chosen network.

        Returns:
            The observation at the chosen network's starting node.
        """
        chosen = random.choice(self.networks)
        self.node_links = chosen['node_links']
        self.node = chosen['starting_node']
        self.max_reward = chosen['max_reward']
        self.step_count, self.done = 0, False
        return self.observe(self.step_count, self.node_links[self.node])


In [9]:
# Training environment. The episode length comes from the n_steps config
# constant (instead of a literal 8) so that it always matches the number of
# rows of the Q table, which is sized from the same constant.
env = NetworkEnvironments(train_networks, n_steps)

In [289]:
# Tabular action values, indexed as Q_mr[step (m), reward id (r)].
Q_mr = np.zeros(shape=(n_steps, n_rewards), dtype=float)

def select_action(q, epsilon):
    """Sample an action index epsilon-greedily from the action values ``q``.

    With probability ``1 - epsilon`` one of the maximal entries of ``q`` is
    chosen (ties are split evenly); with probability ``epsilon`` an index
    is drawn uniformly from all actions.

    The greedy part is normalised over the set of maximal entries, so the
    probabilities sum to one for any number of actions. (Thresholding at
    the mean of ``q`` only yields a valid distribution for exactly two
    actions and makes ``np.random.choice`` raise otherwise.)

    Args:
        q: 1-D array-like of action values, one entry per available action.
        epsilon: exploration probability in [0, 1].

    Returns:
        Index of the selected action, in ``range(len(q))``.

    Raises:
        ValueError: if ``q`` is empty, or if the resulting probabilities
            are invalid (e.g. ``q`` contains NaN or ``epsilon`` lies
            outside [0, 1]).
    """
    q = np.asarray(q, dtype=float)
    greedy = (q == q.max()).astype(float)
    greedy /= greedy.sum()  # split the greedy mass evenly between ties
    p = greedy * (1 - epsilon) + epsilon / len(q)
    return np.random.choice(len(p), p=p)

# Hyper-parameters of the tabular Q-learning run.
n_epochs = 200000  # number of training episodes
alpha = 0.01  # step size of the incremental value update
epsilon = 1  # exploration rate; at 1 select_action samples uniformly at random

# Per-episode statistics, appended once per finished episode.
all_reward, all_regret = [], []

for i in range(n_epochs):
    obs = env.reset()  # one (step, reward id) pair per available action
    done = False
    epoch_reward = 0
    while not done:
        # Current estimates of the available actions (fancy indexing copies).
        q = Q_mr[[m for m, r in obs], [r for m, r in obs]]
        action = select_action(q, epsilon)
        next_obs, reward, done, info = env.step(action)
        epoch_reward += reward

        prev_value = q[action]
        # No bootstrapping past the last step: next_obs is None once done.
        next_max = 0 if done else np.max(
            Q_mr[[m for m, r in next_obs], [r for m, r in next_obs]]
        )

        # Move the chosen entry a fraction alpha toward reward + next_max.
        new = (1 - alpha) * prev_value + alpha * (reward + next_max)
        Q_mr[obs[action]] = new
        obs = next_obs

    all_reward.append(epoch_reward)
    # Regret: shortfall against the max_reward reported for this network.
    all_regret.append(info['max_reward'] - epoch_reward)

    

In [290]:
# Show the learned table: one row per step, one column per reward id
# (columns follow the order of the `rewards` list). Entries of the final
# step are updated toward the immediate reward only, because the training
# loop uses next_max = 0 once the episode is done.
Q_mr

array([[ 393.1020332 ,  402.74919728,  381.68414756,    0.        ],
       [ 265.28979613,  344.1867531 ,  393.67604398,  504.74582329],
       [ 218.05942111,  289.86174369,  322.58579478,  419.51160907],
       [ 157.30468366,  223.90987325,  278.66380693,  371.08431145],
       [ 106.30645761,  168.6013652 ,  208.48254967,  304.27480705],
       [  39.53544853,   98.39675002,  158.16578204,  249.22611848],
       [  -9.75845497,   46.68175443,   96.78841449,  183.83081673],
       [-100.        ,  -20.        ,   20.        ,  140.        ]])

In [286]:
# Summed reward of each of the 20 most recent training episodes. These
# episodes were played with the exploration rate `epsilon` of the training
# loop, i.e. they are not a greedy evaluation of Q_mr.
all_reward[-20:]

[-440,
 -160,
 -200,
 560,
 0,
 -320,
 -280,
 -120,
 -80,
 120,
 200,
 240,
 40,
 0,
 40,
 -40,
 280,
 200,
 40,
 320]