## Setup

In [20]:
# use full window width
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import numpy as np

# Change to the parent directory only once per kernel session: a bare
# os.chdir('..') would climb one more level every time this cell is re-run.
# The project modules imported below (virl, helper_methods) are presumably
# resolved relative to that directory — TODO confirm.
if "_PROJECT_ROOT" not in globals():
    _PROJECT_ROOT = os.path.abspath("..")
os.chdir(_PROJECT_ROOT)

import virl
from helper_methods import run, plot

## Agent Implementation

In [89]:
class QLearningAgent:
    """Tabular Q-learning agent with epsilon-greedy exploration.

    Continuous observations are discretised per component with
    ``continous_to_discrete``; the resulting tuple of bin indices is the key of
    a sparse Q-table (``dict`` mapping state tuple -> array of action values).

    Args:
        env: environment exposing ``action_space.n``, ``reset()`` and
            ``step(action=...)`` returning ``(observation, reward, done, info)``.
        number_bins: number of bin edges used when discretising each
            observation component.
        episodes: number of training episodes run by ``run_all_episodes``.
        discount: discount factor (gamma).
        learning_rate: step size (alpha).
        decrease_factor: amount subtracted from epsilon after each episode;
            epsilon is clamped so it never drops below 0.
    """

    def __init__(self, env, number_bins, episodes=20, discount=0.99,
                 learning_rate=0.2, decrease_factor=0.1):
        self.env = env
        self.num_of_actions = env.action_space.n
        self.number_bins = number_bins

        # Sparse Q-table: only states that are actually visited get an entry.
        self.q_table = {}

        # hyper parameters
        self.discount = discount            # gamma
        self.learning_rate = learning_rate  # step size, alpha
        self.episodes = episodes
        self.decrease_factor = decrease_factor

    def continous_to_discrete(self, continous, highest=600000000, lowest=0):
        """Map each component of a continuous observation to a bin index.

        Uses ``number_bins`` evenly spaced edges in ``[lowest, highest]`` with
        ``np.digitize``: values below ``lowest`` map to 0 and values at or
        above ``highest`` map to ``number_bins``.
        """
        # NOTE(review): default `highest` is presumably an upper bound on the
        # observation values (e.g. population size) — TODO confirm.
        bins = np.linspace(lowest, highest, num=self.number_bins)
        return np.digitize(continous, bins)

    def run_all_episodes(self):
        """Train for ``self.episodes`` episodes with linearly decaying epsilon.

        Returns:
            (all_states, all_rewards): one list of observations and one list
            of rewards per episode, in episode order.
        """
        all_states = []
        all_rewards = []
        epsilon = 1.0  # at the start only explore

        for _ in range(self.episodes):
            states, rewards = self.run_episode(epsilon)
            all_states.append(states)
            all_rewards.append(rewards)
            # Clamp so epsilon stays a valid probability once fully decayed.
            epsilon = max(0.0, epsilon - self.decrease_factor)

        return all_states, all_rewards

    def run_episode(self, epislon):
        """Run one episode, applying a Q-learning update after every step.

        Args:
            epislon: exploration probability (sic): with this probability a
                uniformly random action is taken, otherwise the greedy one.

        Returns:
            (states, rewards): the raw observations visited, including the
            initial one (so ``len(states) == len(rewards) + 1``), and the
            reward received at each step. The discretised form is only used
            internally as the Q-table key.
        """
        observation = self.env.reset()
        state = self.continous_to_discrete(observation)
        states = [observation]
        rewards = []
        done = False

        while not done:
            # np.random.random() is uniform on [0, 1), so this branch is taken
            # with probability `epislon`.
            if np.random.random() < epislon:
                action = int(np.random.choice(self.num_of_actions))  # explore
            else:
                action = self.get_action(state)  # exploit

            observation, reward, done, _ = self.env.step(action=action)
            new_state = self.continous_to_discrete(observation)

            self.update_q_table(state, new_state, action, reward, done)

            states.append(observation)
            rewards.append(reward)
            state = new_state

        return states, rewards

    def update_q_table(self, state, new_state, action, reward, done=False):
        """Apply one Q-learning update for the transition (s, a, r, s').

        Q(s, a) += alpha * (r + gamma * max_a' Q(s', a') - Q(s, a)).
        When ``done`` is True the bootstrap term is dropped, since a terminal
        transition has no successor value.
        """
        q_values = self.value_from_q(state)
        if done:
            target = reward
        else:
            target = reward + self.discount * np.max(self.value_from_q(new_state))
        # Target is computed before the write, so this is correct even when
        # state and new_state share the same Q-table row.
        q_values[action] += self.learning_rate * (target - q_values[action])

    def get_action(self, state):
        """Return the greedy action for ``state`` (ties go to the lowest index)."""
        return int(np.argmax(self.value_from_q(state)))

    def value_from_q(self, state):
        """Return the mutable action-value array stored for ``state``.

        Unseen states are inserted with all-zero values, so the returned array
        can always be indexed and updated in place.
        """
        key = tuple(state)
        if key not in self.q_table:
            self.q_table[key] = np.zeros(self.num_of_actions)
        return self.q_table[key]

    def get_action_text(self):
        # TODO: placeholder. Action labels noted by the author:
        # ["no intervention", "impose a full lockdown", "implement track & trace", "enforce social distancing and face masks"]
        return "Text here"

    def get_env(self):
        """Return the environment this agent was constructed with."""
        return self.env

    def get_chart_title(self):
        # TODO: placeholder title.
        return "Title here"


## Analysis

In [88]:
# Seed NumPy's global RNG: the agent's epsilon-greedy exploration draws from
# np.random, so this makes the training run reproducible on Restart & Run All.
np.random.seed(0)

env = virl.Epidemic(stochastic=False, noisy=False)

NUMBER_BINS = 60  # bin edges per observation component for discretisation
agent = QLearningAgent(env, NUMBER_BINS)
states, rewards = agent.run_all_episodes()

plot(agent, states, rewards)

---------
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 ...
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
[-0.00628583 -0.00628583 -0.00628583 -0.00628583]
1
---------
---------
[[ 0.          0.          0.          0.        ]
 [ 0.         -0.00628583  0.          0.        ]
 [ 0.          0.          0.          0.        ]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-4.71943968e-06 -4.71943968e-06 -4.71943968e-06 -4.71943968e-06]
0
---------
---------
[[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
 [-4.71943968e-06 -6.28582715e-03  0.00000000e+00  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
 ...
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]
[-2.31598535e-

---------
[[ 0.          0.          0.          0.        ]
 [-0.00148781 -0.01133908 -0.02226173 -0.02531551]
 [-0.00878735 -0.01267767 -0.02492682 -0.00772459]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.02025395 -0.02497852 -0.00926932 -0.00926932]
3
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.00148781 -0.01133908 -0.02226173 -0.02497852]
 [-0.00878735 -0.01267767 -0.02492682 -0.00772459]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.04967557 -0.02269979 -0.02483186 -0.0048904 ]
2
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.00148781 -0.01133908 -0.02269979 -0.02497852]
 [-0.00878735 -0.01267767 -0.02483186 -0.00772459]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.        

---------
[[ 0.          0.          0.          0.        ]
 [-0.0021553  -0.02220592 -0.01861971 -0.01410915]
 [-0.00878735 -0.01267767 -0.02483186 -0.01069266]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.00780532 -0.00211092 -0.00211092 -0.00191122]
0
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.00211092 -0.02220592 -0.01861971 -0.01410915]
 [-0.00878735 -0.01267767 -0.02483186 -0.01069266]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.02875008 -0.0177008  -0.0177008  -0.00840297]
2
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.00211092 -0.02220592 -0.0177008  -0.01410915]
 [-0.00878735 -0.01267767 -0.02483186 -0.01069266]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.        

---------
[[ 0.          0.          0.          0.        ]
 [-0.00425862 -0.04069674 -0.01609834 -0.02272645]
 [-0.01009746 -0.0312425  -0.03095102 -0.01614757]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.01107008 -0.02334791 -0.02531019 -0.00516675]
3
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.00425862 -0.04069674 -0.01609834 -0.02334791]
 [-0.01009746 -0.0312425  -0.03095102 -0.01614757]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.01337887 -0.02320114 -0.01744087 -0.00865621]
3
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.00425862 -0.04069674 -0.01609834 -0.02320114]
 [-0.01009746 -0.0312425  -0.03095102 -0.01744087]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.        

---------
[[ 0.          0.          0.          0.        ]
 [-0.00938642 -0.03732199 -0.01639094 -0.01491888]
 [-0.01037581 -0.03596408 -0.02807032 -0.01809602]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.02189282 -0.03640496 -0.03640496 -0.03640496]
1
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.00938642 -0.03640496 -0.01639094 -0.01491888]
 [-0.01037581 -0.03596408 -0.02807032 -0.01809602]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.00664267 -0.01372436 -0.01372436 -0.01372436]
3
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.00938642 -0.03640496 -0.01639094 -0.01372436]
 [-0.01037581 -0.03596408 -0.02807032 -0.01809602]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.        

---------
[[ 0.          0.          0.          0.        ]
 [-0.00730338 -0.0407257  -0.01936461 -0.02598248]
 [-0.01547899 -0.03596408 -0.0387018  -0.01893084]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.01215406 -0.02343752 -0.02343752 -0.00689255]
3
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.00730338 -0.0407257  -0.01936461 -0.02343752]
 [-0.01547899 -0.03596408 -0.0387018  -0.01893084]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.01212607 -0.02115283 -0.02115283 -0.00791686]
3
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.00730338 -0.0407257  -0.01936461 -0.02115283]
 [-0.01547899 -0.03596408 -0.0387018  -0.01893084]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.        

---------
[[ 0.          0.          0.          0.        ]
 [-0.00490821 -0.03508195 -0.01425189 -0.01655832]
 [-0.01547899 -0.03596408 -0.0387018  -0.01893084]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.01353747 -0.01418878 -0.01418878 -0.01418878]
2
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.00490821 -0.03508195 -0.01418878 -0.01655832]
 [-0.01547899 -0.03596408 -0.0387018  -0.01893084]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.00181812 -0.00430188 -0.00430188 -0.00430188]
0
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.00430188 -0.03508195 -0.01418878 -0.01655832]
 [-0.01547899 -0.03596408 -0.0387018  -0.01893084]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.        

---------
[[ 0.          0.          0.          0.        ]
 [-0.00537447 -0.03833942 -0.02102377 -0.01684242]
 [-0.02082103 -0.03596408 -0.03701266 -0.02834846]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.01396079 -0.01994983 -0.01994983 -0.01994983]
2
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.00537447 -0.03833942 -0.01994983 -0.01684242]
 [-0.02082103 -0.03596408 -0.03701266 -0.02834846]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.00357671 -0.00502068 -0.00502068 -0.00502068]
0
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.00502068 -0.03833942 -0.01994983 -0.01684242]
 [-0.02082103 -0.03596408 -0.03701266 -0.02834846]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.        

---------
[[ 0.          0.          0.          0.        ]
 [-0.00432266 -0.03549378 -0.01748413 -0.01353209]
 [-0.02082103 -0.03596408 -0.03701266 -0.02834846]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.01566839 -0.01763087 -0.01763087 -0.01763087]
2
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.00432266 -0.03549378 -0.01763087 -0.01353209]
 [-0.02082103 -0.03596408 -0.03701266 -0.02834846]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.00474047 -0.01884517 -0.01884517 -0.01884517]
2
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.00432266 -0.03549378 -0.01884517 -0.01353209]
 [-0.02082103 -0.03596408 -0.03701266 -0.02834846]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.        

---------
[[ 0.          0.          0.          0.        ]
 [-0.00505661 -0.03593568 -0.02293451 -0.01213562]
 [-0.02082103 -0.04081807 -0.03032169 -0.02558162]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.01788923 -0.04276935 -0.04276935 -0.04276935]
1
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.00505661 -0.03593568 -0.02293451 -0.01213562]
 [-0.02082103 -0.04276935 -0.03032169 -0.02558162]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.01459222 -0.02418669 -0.03009643 -0.03071345]
2
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.00505661 -0.03593568 -0.02418669 -0.01213562]
 [-0.02082103 -0.04276935 -0.03009643 -0.02558162]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.        

---------
[[ 0.          0.          0.          0.        ]
 [-0.01231383 -0.04030461 -0.02510574 -0.01267445]
 [-0.0344236  -0.04694898 -0.03009643 -0.03494493]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.01169608 -0.012582   -0.012582   -0.012582  ]
3
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.01231383 -0.04030461 -0.02510574 -0.012582  ]
 [-0.0344236  -0.04694898 -0.03009643 -0.03494493]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.01618864 -0.023437   -0.023437   -0.023437  ]
2
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.01231383 -0.04030461 -0.023437   -0.012582  ]
 [-0.0344236  -0.04694898 -0.03009643 -0.03494493]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.        

---------
[[ 0.          0.          0.          0.        ]
 [-0.01599282 -0.04112075 -0.02940579 -0.02871549]
 [-0.03985637 -0.04607725 -0.02766009 -0.03501242]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.00163225 -0.01359799 -0.01359799 -0.00163225]
0
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.01359799 -0.04112075 -0.02940579 -0.02871549]
 [-0.03985637 -0.04607725 -0.02766009 -0.03501242]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.01709083 -0.02656775 -0.02656775 -0.01565244]
2
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.01359799 -0.04112075 -0.02656775 -0.02871549]
 [-0.03985637 -0.04607725 -0.02766009 -0.03501242]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.        

---------
[[ 0.          0.          0.          0.        ]
 [-0.0095623  -0.03633487 -0.02767767 -0.01393067]
 [-0.03985637 -0.05723495 -0.0333192  -0.03501242]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.02431972 -0.03563886 -0.04136985 -0.02431972]
2
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.0095623  -0.03633487 -0.03563886 -0.01393067]
 [-0.03985637 -0.05723495 -0.0333192  -0.03501242]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.04357356 -0.02131353 -0.05330351 -0.01325347]
3
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.0095623  -0.03633487 -0.03563886 -0.02131353]
 [-0.03985637 -0.05723495 -0.0333192  -0.03501242]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.        

---------
[[ 0.          0.          0.          0.        ]
 [-0.00689814 -0.03557208 -0.0181158  -0.01451662]
 [-0.03231911 -0.05526463 -0.0333192  -0.03501242]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.01748113 -0.01814497 -0.01814497 -0.01814497]
2
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.00689814 -0.03557208 -0.01814497 -0.01451662]
 [-0.03231911 -0.05526463 -0.0333192  -0.03501242]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.01301214 -0.01440493 -0.01440493 -0.01440493]
3
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.00689814 -0.03557208 -0.01814497 -0.01440493]
 [-0.03231911 -0.05526463 -0.0333192  -0.03501242]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.        

---------
[[ 0.          0.          0.          0.        ]
 [-0.00700426 -0.03835424 -0.01852942 -0.02930048]
 [-0.01998614 -0.05831078 -0.03313015 -0.03129498]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.02368303 -0.03377144 -0.0642848  -0.01033106]
3
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.00700426 -0.03835424 -0.01852942 -0.03377144]
 [-0.01998614 -0.05831078 -0.03313015 -0.03129498]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.05259419 -0.02399363 -0.06530653 -0.02164274]
2
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.00700426 -0.03835424 -0.02399363 -0.03377144]
 [-0.01998614 -0.05831078 -0.03313015 -0.03129498]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.        

---------
[[ 0.          0.          0.          0.        ]
 [-0.00777265 -0.03784566 -0.02646327 -0.02478063]
 [-0.02279607 -0.05769544 -0.03313015 -0.03351474]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.03546679 -0.07039782 -0.06501277 -0.0577386 ]
2
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.00777265 -0.03784566 -0.02646327 -0.02478063]
 [-0.02279607 -0.05769544 -0.03313015 -0.03351474]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.02089752 -0.05530577 -0.06864807 -0.03390054]
0
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.00777265 -0.03784566 -0.02646327 -0.02478063]
 [-0.02279607 -0.05769544 -0.03313015 -0.03351474]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.        

---------
[[ 0.          0.          0.          0.        ]
 [-0.01209574 -0.04818738 -0.03096229 -0.01914846]
 [-0.02649837 -0.05049975 -0.03957924 -0.038821  ]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.00626181 -0.01213226 -0.02365436 -0.01538327]
0
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.01213226 -0.04818738 -0.03096229 -0.01914846]
 [-0.02365436 -0.05049975 -0.03957924 -0.038821  ]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
[-0.02716617 -0.02914027 -0.03603382 -0.02078364]
2
---------
---------
[[ 0.          0.          0.          0.        ]
 [-0.01213226 -0.04818738 -0.02914027 -0.01914846]
 [-0.02365436 -0.05049975 -0.03603382 -0.038821  ]
 ...
 [ 0.          0.          0.          0.        ]
 [ 0.        

AssertionError: 

## Evaluation

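**TODO:** evaluate the trained agent here. For example, roll out its greedy policy (exploration disabled) and report the total reward per episode.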