In [161]:
from scipy.stats import bernoulli
from scipy.stats import beta
from scipy.stats import uniform 
import math

In [162]:
class Arm:
    
    def __init__(self, q):
        self.quality     = q
        self.num_pulled  = 0
        self.num_rewards = 0
    
    def pull(self):
        reward = bernoulli.rvs(size=1,p=self.quality)[0]
        self.num_pulled +=1
        self.num_rewards += reward
        return reward

class Envi:
    
    def __init__(self):
        self.arms = []
        self.best_arm = None
        self.cumul_regret = 0
        self.cumul_regret_at_round = []
    
    def addArm(self, newArm):
        self.arms.append(newArm)
    
    def getQual(self, arm):
        return arm.quality
    
    def getBestArm(self):
        return self.best_arm
    
    def pullNewArm(self):
        qual = uniform.rvs(size=1)[0]
        newArm = Arm(qual)
        self.addArm(newArm)
        self.best_arm = max(self.arms, key=self.getQual)
        newArm.pull()
    
    def pull(self, arm):
        reward = arm.pull()
        self.cumul_regret += (self.best_arm.quality - arm.quality)
        self.cumul_regret_at_round.append(self.cumul_regret)
        return reward

In [163]:
## Testing, remove this later
e = Envi()

for i in range(10):
    e.pullNewArm()

for i in e.arms:
    print(i.quality)

print("best is", e.getBestArm().quality)

for i in range(100):
    e.pull(e.arms[0])
    
print("cumul regret", e.cumul_regret)
print("num pulls", list(map(lambda x: x.num_pulled, e.arms)))
print("num rewards", list(map(lambda x: x.num_rewards, e.arms)))

0.7266345879730363
0.614413920556071
0.32880586842831605
0.927600736099589
0.8911732224800346
0.98446574834337
0.33598460349064885
0.8620005810507441
0.7828362695829844
0.6042886789183496
best is 0.98446574834337
cumul regret 25.783116037033345
num pulls [101, 1, 1, 1, 1, 1, 1, 1, 1, 1]
num rewards [72, 1, 0, 1, 1, 1, 1, 0, 1, 0]


## Random Agent

In [164]:
class Agent:
    #pulls random arm with p = 0.5 or explores new arm with p = 0.5
    def __init__(self, envr, rounds):
        self.envr = envr
        self.total_rounds = rounds
        self.round = 0
        self.arm_pulled = []
        self.rewards = []
        self.total_reward = 0

class RandomAgent(Agent):
    
    def playRound(self):
        self.round += 1
        available_arms = self.envr.arms
        #print(available_arms)
        ran = uniform.rvs(size=1)[0]
        if len(self.envr.arms) == 0 or ran > 0.5:
            self.envr.pullNewArm()
        else:
            ran2 = uniform.rvs(size=1)[0]
            index = int(ran2*(len(self.envr.arms)))
            #print("index", index)
            self.envr.pull(self.envr.arms[index])
    
    def playFull(self):
        for i in range(self.total_rounds):
            self.playRound()

In [166]:
e = Envi()
rounds = 1000
ranAg = RandomAgent(e, rounds)
ranAg.playFull()

print("cumul regret", e.cumul_regret)
print("num pulls", list(map(lambda x: x.num_pulled, e.arms)))
print("num rewards", list(map(lambda x: x.num_rewards, e.arms)))
print("cumul_regret_at_round", e.cumul_regret_at_round)

cumul regret 245.96665142968345
num pulls [10, 3, 7, 6, 5, 8, 5, 3, 4, 4, 5, 4, 5, 8, 7, 4, 2, 3, 5, 3, 4, 6, 8, 4, 3, 4, 2, 6, 3, 9, 2, 3, 3, 6, 2, 4, 3, 4, 8, 4, 4, 2, 6, 1, 3, 1, 4, 2, 4, 1, 3, 3, 3, 2, 2, 7, 3, 5, 4, 3, 6, 3, 2, 4, 1, 3, 3, 5, 3, 2, 4, 3, 3, 2, 2, 2, 3, 5, 2, 1, 4, 4, 5, 2, 4, 2, 2, 3, 1, 3, 4, 1, 5, 1, 3, 2, 3, 1, 2, 1, 2, 1, 2, 3, 2, 4, 4, 4, 4, 2, 2, 3, 1, 5, 4, 3, 3, 2, 4, 2, 3, 2, 2, 3, 2, 2, 1, 2, 3, 1, 2, 6, 4, 1, 3, 5, 2, 1, 2, 3, 4, 3, 2, 2, 1, 1, 4, 4, 3, 2, 3, 3, 1, 2, 2, 1, 2, 3, 1, 2, 3, 1, 3, 2, 2, 1, 1, 3, 3, 2, 2, 2, 2, 2, 1, 2, 2, 2, 3, 3, 1, 3, 1, 1, 2, 2, 2, 3, 3, 3, 2, 3, 3, 1, 2, 1, 7, 1, 4, 1, 2, 1, 2, 2, 4, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 4, 2, 1, 2, 1, 1, 2, 1, 1, 1, 2, 4, 1, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 1, 1, 2, 1, 2, 1, 1, 2, 2, 1, 3, 1, 1, 1, 2, 1, 1, 3, 3, 2, 1, 1, 2, 1, 2, 2, 1, 1, 2, 1, 4, 2, 1, 2, 2, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 2, 1, 1, 2, 2, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 3, 3, 3, 2, 3, 2, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 3, 2, 1,