In [21]:
import math
import random

def categorical_draw(probs):
  z = random.random()
  cum_prob = 0.0
  for i in range(len(probs)):
    prob = probs[i]
    cum_prob += prob
    if cum_prob > z:
      return i
  
  return len(probs) - 1

class Softmax:
  def __init__(self, temperature, counts, values):
    self.temperature = temperature
    self.counts = counts
    self.values = values
    return
  
  def initialize(self, n_arms):
    self.counts = [0 for col in range(n_arms)]
    self.values = [0.0 for col in range(n_arms)]
    return
  
  def select_arm(self):
    z = sum([math.exp(v / self.temperature) for v in self.values])
    probs = [math.exp(v / self.temperature) / z for v in self.values]
    return categorical_draw(probs)

  def update(self, chosen_arm, reward):
    self.counts[chosen_arm] = self.counts[chosen_arm] + 1
    n = self.counts[chosen_arm]
    
    value = self.values[chosen_arm]
    new_value = ((n - 1) / float(n)) * value + (1 / float(n)) * reward
    self.values[chosen_arm] = new_value
    return

In [22]:
execfile("core.py")
from random import randint
#print(randint(0,9))
import random

random.seed(1)
#means = [0.1, 0.1, 0.1, 0.1, 0.9]
means= random.sample(range(1, 4000), 3999)
n_arms = len(means)
random.shuffle(means)
arms = map(lambda (mu): BernoulliArm(mu), means)
print("Best arm is " + str(ind_max(means)))
f = open("algorithms/softmax/standard_softmax_results.tsv", "w")

for temperature in [0.1, 0.2, 0.3, 0.4, 0.5]:
  algo = Softmax(temperature, [], [])
  algo.initialize(n_arms)
  results = test_algorithm(algo, arms, 5000, 250)
  for i in range(len(results[0])):
      f.write(str(temperature) + "\t")
      f.write("\t".join([str(results[j][i]) for j in range(len(results))]) + "\n")

f.close()


Best arm is 3625


In [23]:
means

[3234,
 3612,
 589,
 2055,
 2799,
 181,
 644,
 1940,
 1896,
 3220,
 3445,
 51,
 1304,
 2265,
 1455,
 2172,
 3957,
 3425,
 256,
 2585,
 3019,
 2284,
 408,
 3310,
 2401,
 244,
 1833,
 994,
 340,
 1311,
 512,
 2954,
 3885,
 3496,
 2591,
 2992,
 1050,
 1769,
 3875,
 3890,
 2395,
 2876,
 1845,
 2702,
 28,
 3738,
 3654,
 1966,
 3900,
 3746,
 3687,
 755,
 3229,
 3296,
 1406,
 2752,
 2792,
 2249,
 2108,
 3368,
 1603,
 1167,
 393,
 3301,
 343,
 251,
 1499,
 3937,
 932,
 2866,
 3514,
 2017,
 2704,
 1673,
 1192,
 670,
 355,
 2370,
 3917,
 1964,
 1523,
 310,
 68,
 3266,
 92,
 1207,
 1842,
 3524,
 3648,
 3458,
 791,
 685,
 1098,
 1758,
 1714,
 2950,
 1909,
 1718,
 19,
 1933,
 1646,
 554,
 2685,
 3274,
 91,
 3819,
 1814,
 1803,
 1356,
 3652,
 294,
 1028,
 1829,
 1717,
 2149,
 2026,
 2966,
 3318,
 2388,
 2582,
 1421,
 1472,
 788,
 854,
 3226,
 3210,
 2620,
 3247,
 2947,
 672,
 886,
 1617,
 836,
 1328,
 2725,
 3576,
 2531,
 1860,
 1669,
 884,
 2004,
 1088,
 1300,
 2904,
 1099,
 1518,
 1383,
 3653,
 33

In [24]:
means.sort()

print means[10]

11


In [25]:
means.sort()

print means[1]

2


In [28]:
means.sort()

print means[3001]

3002
