In [1]:
from collections import defaultdict
import numpy as np
import sys
from enum import Enum
class ACTION(Enum):
    """Effort level the agent can choose at each step (value doubles as feature-block index)."""
    low = 0
    high = 1


class Winner:
    """Bounded random-walk "race" environment plus a true online SARSA(lambda) learner.

    The walk starts at position 0 with a full energy budget. Each step moves +1 with
    the success probability of the chosen effort level, else -1. Reaching
    ``state_max`` pays ``win_reward`` and ends the episode; reaching ``state_min``
    ends it with no win reward. High effort uses ``P_H`` but costs ``highcost``
    energy; without enough energy it falls back to low effort (``P_L``). Each step,
    energy is refilled by ``energy_add`` with probability ``energy_prob``, capped
    at ``energy_max``.

    Q(s, B, a) is linear in a one-hot feature vector over (position, energy, action).
    """

    def __init__(self):
        self.energy_max = 10      # energy budget cap (and starting energy)
        self.energy_add = 2       # energy gained on a refill event
        self.energy_prob = 0.2    # per-step probability of a refill event
        self.P_H = 0.55           # move-success probability with high effort
        self.P_L = 0.45           # move-success probability with low effort
        self.state = 0
        self.done = 0
        self.R = 0
        self.highcost = 1         # energy consumed by a high-effort step
        self.lowcost = 0          # reward added on a low-effort step
        self.greedy = 0.1         # epsilon of the epsilon-greedy behaviour policy
        self.energy = self.energy_max
        self.discount = 0.9       # gamma
        self.alpha = 0.01         # step size
        self.lam = 0.9            # lambda (trace decay)
        self.state_min = -3       # losing terminal position
        self.state_max = 3        # winning terminal position
        self.win_reward = 1000    # reward for reaching state_max
        return

    def reset_game(self):
        """Reset the stored position, done flag and energy to their start values."""
        self.state = 0
        self.done = 0
        self.energy = self.energy_max
        return

    def env_step(self, state, energy, action):
        """Simulate one transition.

        Args:
            state: current position.
            energy: current energy budget.
            action: ``ACTION.low`` or ``ACTION.high``.

        Returns:
            ``(reward, next_state, next_energy, done)`` with ``done`` in {0, 1}.
        """
        done = 0
        reward = 0
        sample = np.random.uniform(0, 1)  # decides whether the move succeeds
        smp = np.random.uniform(0, 1)     # decides whether energy is refilled

        if action == ACTION.high and energy >= self.highcost:
            suc_prob = self.P_H
            energy = energy - self.highcost
        else:
            # Low effort, or high effort without enough energy (falls back to low).
            suc_prob = self.P_L
            reward = reward + self.lowcost

        state = state + 1 if sample < suc_prob else state - 1

        if state >= self.state_max:
            reward = reward + self.win_reward
            done = 1
        if state <= self.state_min:
            done = 1

        if smp < self.energy_prob:
            energy = energy + self.energy_add
        energy = min(energy, self.energy_max)
        return reward, state, energy, done

    def ba_policy(self, state, B):
        """Epsilon-greedy behaviour policy over the current linear Q estimate.

        With probability ``1 - self.greedy`` returns the action with the larger
        estimated value (ties go to ``ACTION.low``); otherwise returns a uniformly
        random action. Requires ``self.w`` (set by ``SARSAlambda``).
        """
        sample = np.random.uniform(0, 1)
        q_low = np.dot(self.w, self.fea_vector(state, B, ACTION.low))
        q_high = np.dot(self.w, self.fea_vector(state, B, ACTION.high))
        if sample > self.greedy:
            return ACTION.high if q_high > q_low else ACTION.low
        return ACTION.high if np.random.uniform(0, 1) > 0.5 else ACTION.low

    def SARSAlambda(self, num_episodes):
        """Learn linear Q weights with true online Sarsa(lambda) (dutch traces).

        Args:
            num_episodes: number of episodes to run; each starts at position 0
                with full energy.

        Returns:
            The learned weight vector (also stored in ``self.w``).
        """
        self.w = self.init_weight()

        for i_episode in range(1, num_episodes + 1):
            if i_episode % 1000 == 0:
                print("\rEpisode {}/{}.".format(i_episode, num_episodes), end="")
                sys.stdout.flush()
            B = self.energy_max
            state = 0
            q_old = 0.0
            z = np.zeros_like(self.w)  # eligibility trace, reset every episode
            action = self.ba_policy(state, B)
            x = self.fea_vector(state, B, action)
            gl = self.discount * self.lam
            while True:
                reward, next_state, B, done = self.env_step(state, B, action)
                if done:
                    # Terminal state: its value is zero by definition (x' = 0).
                    next_action = None
                    next_x = np.zeros_like(x)
                else:
                    next_action = self.ba_policy(next_state, B)
                    next_x = self.fea_vector(next_state, B, next_action)

                q = np.dot(self.w, x)
                next_q = np.dot(self.w, next_x)
                delta = reward + self.discount * next_q - q

                # Dutch trace and true online weight update.
                z = gl * z + (1 - self.alpha * gl * np.dot(z, x)) * x
                self.w = self.w + self.alpha * (delta + q - q_old) * z - self.alpha * (q - q_old) * x

                q_old = next_q
                x = next_x
                action = next_action
                state = next_state
                if done:
                    break

        return self.w

    def _block_size(self):
        """Number of features per action: one per (position, energy) pair."""
        return (self.state_max - self.state_min + 1) * (self.energy_max + 1)

    def _num_features(self):
        """Total feature-vector length: one block per action."""
        return len(ACTION) * self._block_size()

    def init_weight(self):
        """Return an all-zero weight vector sized to match ``fea_vector``."""
        return np.zeros(self._num_features())

    def fea_vector(self, state, B, action):
        """One-hot feature vector for the (position, energy, action) triple.

        Slot = ``action.value * block + (state - state_min) * (energy_max + 1) + B``,
        where ``block`` counts all (position, energy) pairs, so every triple gets
        a distinct slot (a tabular representation).
        """
        x = np.zeros(self._num_features())
        x[action.value * self._block_size() + (state - self.state_min) * (self.energy_max + 1) + B] = 1
        return x
        


In [2]:
# Seed NumPy's global RNG (used by env_step and ba_policy) so Restart & Run All reproduces results.
np.random.seed(0)
NUM_EPISODES = 500000

winner = Winner()
weights = winner.SARSAlambda(NUM_EPISODES)  # same array as winner.w

# Tabular view of the learned Q-function, filled in by the next cell:
# (state, energy) -> array([Q_low, Q_high]).
Q = defaultdict(lambda: np.zeros(2))


Episode 500000/500000.

In [3]:
# Evaluate the learned linear Q-function on every (position, energy) pair.
# Take the energy cap from the agent so it cannot drift from the training setup.
B_MAX = winner.energy_max
for state in range(-3, 4):  # positions -3..3; both ends are terminal
    for b in range(B_MAX + 1):
        # Entry order follows ACTION definition order: index 0 = low, 1 = high.
        Q[state, b] = np.array([np.dot(winner.w, winner.fea_vector(state, b, act)) for act in ACTION])

print(Q)

defaultdict(<function <lambda> at 0x7f902eef4bf8>, {(-3, 0): array([0., 0.]), (-3, 1): array([0., 0.]), (-3, 2): array([0., 0.]), (-3, 3): array([0., 0.]), (-3, 4): array([0., 0.]), (-3, 5): array([0., 0.]), (-3, 6): array([0., 0.]), (-3, 7): array([0., 0.]), (-3, 8): array([0., 0.]), (-3, 9): array([0., 0.]), (-3, 10): array([0., 0.]), (-2, 0): array([1.76810086, 0.0021099 ]), (-2, 1): array([0.90068246, 0.        ]), (-2, 2): array([28.53389773,  2.37244803]), (-2, 3): array([57.28657471, 18.67765224]), (-2, 4): array([47.6822711 , 58.64128081]), (-2, 5): array([57.01914762, 43.12527567]), (-2, 6): array([62.24174462, 73.17138695]), (-2, 7): array([23.572231  , 71.67051977]), (-2, 8): array([75.63968915, 88.5933672 ]), (-2, 9): array([69.34510218, 62.59466434]), (-2, 10): array([66.13227584, 93.07725228]), (-1, 0): array([1.21144705e+01, 1.95959294e-03]), (-1, 1): array([0.61270715, 7.32699072]), (-1, 2): array([85.08331213,  7.73483186]), (-1, 3): array([131.90754813,  79.98059446])

In [4]:
import pandas as pd

# Use an ordered list (not a set) so column labels line up with the data,
# which is ordered by ACTION (index 0 = low, 1 = high).
actions = [act.name for act in ACTION]
q_table = pd.DataFrame(
    data=[[Q[state, B][act.value] for act in ACTION] for state, B in Q.keys()],
    index=pd.MultiIndex.from_tuples(list(Q.keys()), names=["state", "energy"]),
    columns=actions,
)
q_table

              low        high
-3 0     0.000000    0.000000
   1     0.000000    0.000000
   2     0.000000    0.000000
   3     0.000000    0.000000
   4     0.000000    0.000000
   5     0.000000    0.000000
   6     0.000000    0.000000
   7     0.000000    0.000000
   8     0.000000    0.000000
   9     0.000000    0.000000
   10    0.000000    0.000000
-2 0     1.768101    0.002110
   1     0.900682    0.000000
   2    28.533898    2.372448
   3    57.286575   18.677652
   4    47.682271   58.641281
   5    57.019148   43.125276
   6    62.241745   73.171387
   7    23.572231   71.670520
   8    75.639689   88.593367
   9    69.345102   62.594664
   10   66.132276   93.077252
-1 0    12.114471    0.001960
   1     0.612707    7.326991
   2    85.083312    7.734832
   3   131.907548   79.980594
   4   123.632847  106.555194
   5   133.023479  149.263906
   6   151.434701  124.959619
   7   147.628572  182.723094
...           ...         ...
 1 3   366.049912  267.215187
   4   406

In [5]:
learned_weights = winner.w  # linear Q-function weights produced by SARSAlambda
learned_weights

array([  0.        ,  13.6526239 ,  63.2013982 , 180.49335898,
       352.47266825, 659.90186272,   0.        ,  24.00761432,
        60.857246  ,  76.00268915, 105.23112663, 122.58073925,
       133.70445774, 142.94243696, 151.12219688, 166.02622473,
       168.61767823, 118.62950216,  62.49636359])