# Pursuit algorithm

In [211]:
#Init arms
class BernoulliArm():
    """Bandit arm that pays 1.0 with probability ``p`` and 0.0 otherwise."""

    def __init__(self, p):
        # p: success probability of a single draw, expected to lie in [0, 1].
        self.p = p

    def draw(self):
        """Sample one reward from the arm.

        Returns:
            1.0 with probability ``self.p``, otherwise 0.0. The result is
            always a float so downstream averages never mix int and float.
        """
        # np.random.random() is uniform on the half-open interval [0, 1), so
        # the strict ``<`` guarantees that p == 0 never pays out and p == 1
        # always does.
        return 1.0 if np.random.random() < self.p else 0.0

In [212]:
import numpy as np

In [213]:
# Arm parameters
means = [0.45, 0.55, 0, 0]  # Success probability of each arm
n = len(means)  # Number of arms, derived from `means` so the two cannot drift apart
arms = [BernoulliArm(mean) for mean in means]  # One Bernoulli arm per mean

In [214]:
# Simulation state
times = 500  # Number of turns to play
counts = np.zeros(n, dtype=int)  # How many times each arm has been played
values = np.zeros(n)  # Average reward obtained so far from each arm
probabilities = np.ones(n) / n  # Selection probabilities, uniform at the start
chosen_arms = [0.0] * times  # Arm played at each turn
rewards = [0.0] * times  # Reward received at each turn
cumulative_rewards = 0.0  # Running total of the rewards

In [215]:
def weighted_choice(weights):
    """Draw an index at random with probability proportional to its weight.

    Args:
        weights: Sequence of non-negative weights. They do not need to sum
            to 1, but their total must be strictly positive.

    Returns:
        Index into ``weights``, as returned by ``np.searchsorted``.

    Raises:
        ValueError: If the weights do not add up to a positive number, in
            which case no index can be drawn.
    """
    totals = np.cumsum(weights)
    norm = totals[-1]
    if not norm > 0:
        raise ValueError("weights must have a positive sum")
    # Uniform point on [0, norm).
    throw = np.random.rand() * norm
    # side="right" selects the index i with totals[i-1] <= throw < totals[i],
    # so an entry with zero weight is never chosen, even when the throw lands
    # exactly on a cumulative boundary (for instance throw == 0.0).
    return np.searchsorted(totals, throw, side="right")

 
# Pursuit algorithm: move the selection probabilities towards the greedy arm
def select_arm(probabilities, values, t, beta=0.05):
    """Choose the arm to play at turn ``t`` using the pursuit update.

    Args:
        probabilities: Mutable sequence of per-arm selection probabilities.
            It is updated in place on every turn except the first.
        values: Current average reward of each arm.
        t: Zero-based turn index.
        beta: Learning rate of the pursuit update; each probability moves
            this fraction of the way towards its target (1 for the greedy
            arm, 0 for the others).

    Returns:
        Index of the arm to play.
    """
    if t == 0:
        # First turn: pick uniformly at random and leave the probabilities
        # untouched.
        return np.random.randint(len(values))

    # The greedy arm cannot change while the probabilities are being updated,
    # so it is looked up once instead of on every iteration.
    best = np.argmax(values)
    for ind, probability in enumerate(probabilities):
        target = 1 if ind == best else 0
        probabilities[ind] = probability + (beta * (target - probability))
    # Trace of the updated distribution, printed once per turn.
    print(probabilities)
    return weighted_choice(probabilities)

In [216]:
# Play `times` turns, updating the statistics of the chosen arm after each one
for t in range(times):
    chosen_arm = select_arm(probabilities, values, t)  # Pursuit selection
    reward = arms[chosen_arm].draw()

    # Keep a per-turn record of what was played and what it paid
    chosen_arms[t] = chosen_arm
    rewards[t] = reward
    cumulative_rewards += reward

    # Fold the new reward into the empirical mean of the chosen arm
    counts[chosen_arm] += 1
    _n = counts[chosen_arm]
    value = values[chosen_arm]
    new_value = ((_n - 1) / float(_n)) * value + (1 / float(_n)) * reward
    values[chosen_arm] = new_value

# Summary of the run
for summary in (times, chosen_arms, rewards, cumulative_rewards):
    print(summary)

[ 0.2875  0.2375  0.2375  0.2375]
[ 0.323125  0.225625  0.225625  0.225625]
[ 0.35696875  0.21434375  0.21434375  0.21434375]
[ 0.33912031  0.25362656  0.20362656  0.20362656]
[ 0.3221643   0.29094523  0.19344523  0.19344523]
[ 0.30605608  0.32639797  0.18377297  0.18377297]
[ 0.29075328  0.36007807  0.17458432  0.17458432]
[ 0.27621561  0.39207417  0.16585511  0.16585511]
[ 0.31240483  0.37247046  0.15756235  0.15756235]
[ 0.34678459  0.35384694  0.14968423  0.14968423]
[ 0.37944536  0.33615459  0.14220002  0.14220002]
[ 0.41047309  0.31934686  0.13509002  0.13509002]
[ 0.43994944  0.30337952  0.12833552  0.12833552]
[ 0.46795197  0.28821054  0.12191874  0.12191874]
[ 0.49455437  0.27380002  0.11582281  0.11582281]
[ 0.51982665  0.26011002  0.11003167  0.11003167]
[ 0.54383532  0.24710451  0.10453008  0.10453008]
[ 0.56664355  0.23474929  0.09930358  0.09930358]
[ 0.53831137  0.27301182  0.0943384   0.0943384 ]
[ 0.56139581  0.25936123  0.08962148  0.08962148]
[ 0.53332602  0.29639317