# Reinforcement Learning :

MultiArmed Bandit

In [1]:
import gym
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import logging

In [3]:
# Creating environment multiarmed bandit
class MultiArmedBandit:
    # 4 armed bandit
    def __init__(self):
        # Probabilities of payout for 4 arms.
        self.bandit = [0.2, 0.0, 0.1, -4.0]
        self.num_actions = 4
    
    # Pull an arm if random number > bandit[arm] number i.e if random number > prob of arm reward is 1 else -1.
    def pull(self, arm):
        return 1 if np.random.randn(1) > self.bandit[arm] else -1

In [5]:
# Agent class which will work with MultiArmedBandit.
class Agent:
    def __init__(self, actions = 4):
        self.num_actions = actions
        self.reward_in = tf.placeholder(tf.float32, [1], name='reward_in')
        self.action_in = tf.placeholder(tf.int32, [1], name='action_in')
        
        self.W = tf.get_variable('W', [self.num_actions])
        self.best_action = tf.argmax(self.W, axis=0)
        
        actions_weight = tf.slice(self.W, self.action_in, [1])
        policy_loss = -(tf.log(actions_weight) * self.reward_in)
        self.optimizer = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(policy_loss)
        
    def predict(self, sess):
        return sess.run(self.best_action)
    
    def random_or_predict(self, sess, epsilon):
        if np.random.rand(1) < epsilon:
            return np.random.randint(self.num_actions)
        else:
            return self.predict(sess)
        
    def train(self, sess, action, reward):
        sess.run(self.optimizer, {self.action_in:[action], self.reward_in:[reward]})

In [6]:
env = MultiArmedBandit()
agent = Agent()
num_episodes = 50000
EPSILON = 0.1

Instructions for updating:
Colocations handled automatically by placer.


In [10]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(num_episodes):
        action = agent.random_or_predict(sess, EPSILON)
        reward = env.pull(action)
        agent.train(sess, action, reward)
        
    print(np.argmin(np.array(env.bandit)))
    print(agent.predict(sess))

3
1
