In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
%matplotlib inline

# 基于状态的多臂赌博机（上下文赌博机，Contextual Bandit）

In [None]:
class ContextBandit:
    """Contextual bandit: a set of multi-armed bandits, one per state.

    Each row of ``self.bandits`` is one bandit (state) and each column one arm
    (action). Pulling an arm draws a standard-normal sample and pays +1 when the
    sample exceeds that arm's threshold, otherwise -1. The LOWER the threshold,
    the better the arm.
    """

    # Default threshold table, one row per bandit. The best arm of each row is
    # its minimum entry.
    DEFAULT_BANDITS = [[0.8, 0, 2.3, -5], [0.1, -5, 3.2, 2.2], [-5, 3.5, 5, 5]]

    def __init__(self, bandits=None):
        """Build the bandit table.

        Args:
            bandits: optional 2-D array-like of arm thresholds with shape
                (bandit_nums, action_nums). Defaults to ``DEFAULT_BANDITS``.

        Raises:
            ValueError: if ``bandits`` is not a non-empty 2-D table.
        """
        if bandits is None:
            bandits = self.DEFAULT_BANDITS
        # np.array copies, so the class-level default list is never aliased.
        self.bandits = np.array(bandits, dtype=float)
        if self.bandits.ndim != 2 or self.bandits.size == 0:
            raise ValueError("bandits must be a non-empty 2-D table of thresholds")
        self.bandit_nums = self.bandits.shape[0]
        self.action_nums = self.bandits.shape[1]
        # No bandit is selected until getBandit() is called.
        self.state = None

    def getBandit(self):
        """Pick a bandit uniformly at random, store it as the current state, and return its index."""
        self.state = np.random.randint(0, self.bandit_nums)
        return self.state

    def pullArm(self, action):
        """Pull arm ``action`` of the current bandit.

        Returns:
            1 if a standard-normal sample exceeds the arm's threshold, else -1.

        Raises:
            RuntimeError: if called before getBandit() has selected a state.
        """
        if self.state is None:
            raise RuntimeError("call getBandit() before pullArm()")
        threshold = self.bandits[self.state, action]
        # A single scalar draw; it consumes the same RNG stream as randn(1).
        sample = np.random.randn()
        return 1 if sample > threshold else -1

# 基于神经网络的Agent建模

In [None]:
class Agent:
    """Single-layer policy network for the contextual bandit (TensorFlow 1.x graph mode).

    Graph endpoints used by the training loop:
      * ``state``           -- int32 placeholder, shape [1]: index of the current bandit.
      * ``out``             -- flat vector of per-action sigmoid scores (length ``action_size``).
      * ``chosen_action``   -- argmax of ``out``, i.e. the greedy action.
      * ``reward_holder``   -- float32 placeholder, shape [1].
      * ``action_holder``   -- int32 placeholder, shape [1].
      * ``response_weight`` -- sigmoid score of the action that was taken.
      * ``loss``            -- policy-gradient loss ``-log(response_weight) * reward``.
      * ``update``          -- one gradient-descent step on ``loss``.
    """

    def __init__(self, lr, state_size, action_size):
        """Build the graph.

        Args:
            lr: gradient-descent learning rate.
            state_size: number of bandits (depth of the one-hot state encoding).
            action_size: number of arms per bandit.
        """
        # --- Inference: state -> per-action scores -> recommended action ---
        self.state = tf.placeholder(dtype=tf.int32, shape=[1])
        state_oh = tf.one_hot(self.state, state_size)
        # Keep the pre-activation logits so the loss can use a stable log-sigmoid.
        logits = tf.layers.dense(state_oh, action_size, activation=None,
                                 kernel_initializer=tf.ones_initializer())
        flat_logits = tf.reshape(logits, shape=[-1])
        self.out = tf.sigmoid(flat_logits)
        self.chosen_action = tf.argmax(self.out, 0)

        # --- Training ---
        self.reward_holder = tf.placeholder(dtype=tf.float32, shape=[1])
        self.action_holder = tf.placeholder(dtype=tf.int32, shape=[1])
        response_logit = tf.slice(flat_logits, self.action_holder, size=[1])
        self.response_weight = tf.sigmoid(response_logit)
        # log(sigmoid(x)) == -softplus(-x). This form stays finite even when the
        # sigmoid underflows to 0 in float32, where tf.log(sigmoid) gives -inf/NaN.
        log_response = -tf.nn.softplus(-response_logit)
        self.loss = -(log_response * self.reward_holder)
        optimize = tf.train.GradientDescentOptimizer(learning_rate=lr)
        self.update = optimize.minimize(self.loss)

# 训练学习（ε-greedy 探索 + 策略梯度更新）

In [None]:
# Train the agent: epsilon-greedy exploration plus one policy-gradient step per episode.
tf.reset_default_graph()
# numpy drives the bandit choice, the exploration decision and the rewards;
# a fixed seed makes Restart & Run All reproducible.
np.random.seed(0)
cBandit = ContextBandit()
agent = Agent(lr=0.001, state_size=cBandit.bandit_nums, action_size=cBandit.action_nums)
total_episodes = 10000
# Running sum of rewards received for each (bandit, arm) pair.
total_reward = np.zeros([cBandit.bandit_nums, cBandit.action_nums])
e = 0.1  # exploration probability (epsilon)

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    for i in range(total_episodes):
        s = cBandit.getBandit()

        # Explore with probability e. This must be a UNIFORM draw in [0, 1):
        # a standard-normal draw (randn) would be < 0.1 about 54% of the time.
        if np.random.rand() < e:
            action = np.random.randint(0, cBandit.action_nums)
        else:
            action = sess.run(agent.chosen_action, feed_dict={agent.state: [s]})

        reward = cBandit.pullArm(action)
        sess.run(agent.update, feed_dict={agent.action_holder: [action],
                                          agent.reward_holder: [reward],
                                          agent.state: [s]})
        total_reward[s, action] += reward

In [None]:
# For every bandit, check that the arm with the highest accumulated reward is the
# truly best arm: the one with the lowest threshold, since pullArm pays +1 when a
# standard-normal sample exceeds the threshold.
for s in range(cBandit.bandit_nums):
    learned_arm = np.argmax(total_reward[s])
    best_arm = np.argmin(cBandit.bandits[s])
    print("bandit {}: learned arm {}, best arm {} -> {}".format(
        s, learned_arm, best_arm, learned_arm == best_arm))