In [1]:
# use full window width
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import numpy as np
os.chdir('..')
import virl
from matplotlib import pyplot as plt

In [21]:
# Let's import basic tools for defining the function and doing the gradient-based learning
import sklearn.pipeline
import sklearn.preprocessing
#from sklearn.preprocessing import PolynomialFeatures # you can try with polynomial basis if you want (It is difficult!)
from sklearn.linear_model import SGDRegressor # this defines the SGD function
from sklearn.kernel_approximation import RBFSampler # this is the RBF function transformation method

feature_transformer = sklearn.pipeline.FeatureUnion([
        ("rbf1", RBFSampler(gamma=5.0, n_components=100)),
        ("rbf2", RBFSampler(gamma=2.0, n_components=100)),
        ("rbf3", RBFSampler(gamma=1.0, n_components=100)),
        ("rbf4", RBFSampler(gamma=0.5, n_components=100))
        ])
feature_transformer.fit(state)

x_transformed = feature_transformer.transform(x)
x_train = np.array((-8.7,-6.5,-3.4,0.0,0.1,1.4,2.6,4.5)).reshape(-1, 1)
x_train_transformed = feature_transformer.transform(x_train)

y_train = np.array((3.4,4.4,1.1,-1.0,-1.5,4.4,6.6,-14.5))
model.fit(x_train_transformed,y_train)

func_approximator = SGDRegressor(learning_rate="constant",tol=1e-5,max_iter=1e4)

In [22]:


class LinearAprxAgent:
    def create_policy(self,func_approximator, epsilon):
        def policy_fn(state):
            """
            Input:
                state: a 2D array with the position and velocity
            Output:
                A,q_values: 
            """
            nA = self.num_of_actions
            A = np.ones(nA, dtype=float) * epsilon / nA
            q_values = self.func_approximator.predict(s=state)
           
            best_action = np.argmax(q_values)
            A[best_action] += (1.0 - epsilon)
            return A,q_values  # return the potentially stochastic policy (which is due to the exploration)

        return policy_fn # return a handle to the function so we can call it in the future
        
    def __init__(self, func_approximator,env):
        self.func_approximator = func_approximator
        self.num_of_actions = env.action_space.n
        self.env = env
        self.policy = self.create_policy(MyRandomFunctionApproximator,0.0)
        self.initial_epsilon = 0.5
        self.discount = 0.99 # gamma
        self.learning_rate = 0.25 # step size, alpha
        self.episodes = 2000
        self.print_out_every_x_episodes = int(self.episodes/50)
        
        # hyper parameters for epsilon
        self.initial_epsilon = 1 # initial
        self.decrease_factor = (1/self.episodes)/1.25 # epsilon
        self.decrease_factor = 0.00075
        
    
    def run_all_episodes(self):
        all_rewards = []
        all_q_table_exploits = []
        epislon = self.initial_epsilon # at the start only explore
        
        
        
        for episode in range(1, self.episodes + 1):
            rewards = self.run_episode(epislon)
            total_reward = np.sum(rewards)

            if episode % self.print_out_every_x_episodes == 0:
                print("Episode number: " + str(episode) + ". Total reward in episode: " + str(total_reward) + ". Episode executed with epsilon = " + str(epislon))
                print("Average total reward in last " + str(self.print_out_every_x_episodes) + " episodes: " + str(np.mean(all_rewards[-self.print_out_every_x_episodes:])))
                print("Average number of times we exploited q table in last " + str(self.print_out_every_x_episodes) + " episodes: " + str(np.mean(all_q_table_exploits[-self.print_out_every_x_episodes:])))
                print("-----")
            all_rewards.append(total_reward)
            epislon -= self.decrease_factor #hyperparameter
            
        return all_rewards
    
    def run_episode(self,epislon):
        rewards = []
        done = False
        
        state = self.env.reset()
        
        
      
        
        while not done:
            random_number = np.random.random()
            if random_number < epislon:
                #explore
                action = np.random.choice(self.num_of_actions)
            else:
                #exploit
                action,q_values = self.get_action(state)
                
            print(action)  
            print(state)
            new_state, reward, done, i = self.env.step(action=action)
            
            
            #update q table
            #self.update(new_state,action,reward)  
            rewards.append(reward)
            state = new_state
        return (rewards)
        
    def update(self):
        #update the linear function
        return
                                                       
        
    def get_action(self,state):
        #linear function get max action
        return self.policy(state)
    
    def get_action_text(self):
        return action_text
    
    def get_env(self):
        return env
    
    def get_chart_title(self):
        return "Action = " + action_text

In [23]:
actions = ["no intervention", "impose a full lockdown", "implement track & trace", "enforce social distancing and face masks"]
env = virl.Epidemic(stochastic=False, noisy=False)
agent = LinearAprxAgent(func_approximator,env)
states, rewards = agent.run_all_episodes()

plot(agent, states, rewards)

3
[5.9996e+08 2.0000e+04 0.0000e+00 2.0000e+04]
0
[5.99915319e+08 3.83607448e+04 1.18101714e+04 3.45102420e+04]
2
[5.99700182e+08 1.71105015e+05 4.67607716e+04 8.19526301e+04]
0
[5.99458239e+08 2.27659325e+05 1.18887471e+05 1.95214630e+05]
3
[5.98185769e+08 1.01200511e+06 3.15017749e+05 4.87208289e+05]
1
[5.95943432e+08 1.92712268e+06 8.40179745e+05 1.28926525e+06]
0
[5.95889001e+08 7.82575181e+05 1.17360792e+06 2.15481565e+06]
0
[5.91597297e+08 3.41394822e+06 1.66884908e+06 3.31990574e+06]
3
[5.73566726e+08 1.43362610e+07 4.53538836e+06 7.56162474e+06]
3
[5.45150213e+08 2.48018375e+07 1.15611349e+07 1.84868146e+07]
3
[5.01285934e+08 3.90387642e+07 2.22216878e+07 3.74536137e+07]
1
[4.43090645e+08 5.36616086e+07 3.66809229e+07 6.65668239e+07]
3
[4.41970300e+08 2.15504272e+07 4.29307947e+07 9.35484782e+07]
1
[4.14720341e+08 2.63204564e+07 4.33456796e+07 1.15613523e+08]
3
[4.14206119e+08 1.05488334e+07 4.07481479e+07 1.34496900e+08]
3
[4.01886610e+08 1.21999814e+07 3.63939445e+07 1.495194

2
[5.99412876e+08 4.82170388e+03 7.22030709e+04 5.10099651e+05]
0
[5.99406061e+08 6.41308409e+03 5.84199108e+04 5.29105848e+05]
2
[5.99370161e+08 2.85545011e+04 5.16620423e+04 5.49622320e+05]
0
[5.99329812e+08 3.79742638e+04 5.39095397e+04 5.78304321e+05]
2
[5.99117354e+08 1.68985853e+05 7.90338514e+04 6.34626101e+05]
0
[5.98878785e+08 2.24574140e+05 1.42885816e+05 7.53754911e+05]
0
[5.97626517e+08 9.96010218e+05 3.30326245e+05 1.04714617e+06]
2
[5.92134386e+08 4.36749331e+06 1.22369832e+06 2.27442235e+06]
0
[5.86125247e+08 5.69075209e+06 3.03653964e+06 5.14746141e+06]
0
[5.57218393e+08 2.29843717e+07 7.64667310e+06 1.21505622e+07]
1
[4.64192496e+08 7.41118230e+07 2.49337063e+07 3.67619747e+07]
3
[4.62571249e+08 2.98080940e+07 3.93183388e+07 6.83023182e+07]
0
[4.22739880e+08 3.78380346e+07 4.46942840e+07 9.47278013e+07]
1
[3.39049367e+08 7.09333225e+07 5.74127114e+07 1.32604599e+08]
1
[3.37920168e+08 2.82729401e+07 6.36199827e+07 1.70186909e+08]
2
[3.37471154e+08 1.12685569e+07 5.70053

3
[5.89038900e+08 4.77870550e+06 2.22202867e+06 3.96036575e+06]
1
[5.78851276e+08 8.80136240e+06 4.49023932e+06 7.85712250e+06]
2
[5.78609978e+08 3.56965270e+06 5.86603974e+06 1.19543299e+07]
0
[5.73871603e+08 4.53039031e+06 6.23791856e+06 1.53600876e+07]
2
[5.51788070e+08 1.76024586e+07 8.94624091e+06 2.16632303e+07]
3
[5.30561418e+08 2.08115622e+07 1.49542586e+07 3.36727614e+07]
2
[4.95085740e+08 3.18420867e+07 2.25355785e+07 5.05365948e+07]
0
[4.63131592e+08 3.31578902e+07 3.10821129e+07 7.26284054e+07]
1
[3.74877889e+08 7.29244479e+07 4.59623432e+07 1.06235320e+08]
3
[3.73592990e+08 2.91418844e+07 5.52802688e+07 1.41984857e+08]
3
[3.45385870e+08 2.96204825e+07 5.52257392e+07 1.69767908e+08]
1
[3.19794859e+08 2.80203178e+07 5.49272800e+07 1.97257543e+08]
2
[3.19373954e+08 1.11532373e+07 5.01799919e+07 2.19292817e+08]
0
[3.13111911e+08 8.31896413e+06 4.29863916e+07 2.35582733e+08]
0
[3.00821982e+08 1.14133255e+07 3.74934401e+07 2.50271253e+08]
0
[2.85175368e+08 1.47946065e+07 3.45874

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()