In [18]:
import numpy as np

class BernoulliBandit :
    """K-armed bandit whose arms pay a Bernoulli reward (0 or 1).

    Each arm's success probability is drawn once, uniformly from [0, 1),
    when the bandit is constructed.

    Attributes:
        probs: array of length K holding each arm's success probability.
        K: number of arms.
        best_idx: index of the arm with the highest success probability.
        best_prob: success probability of that best arm.
    """
    def __init__(self, K, seed=None) :
        """Create a bandit with ``K`` arms.

        Args:
            K: number of arms; must be at least 1.
            seed: optional seed. When given, the bandit draws from its own
                ``np.random.default_rng(seed)`` generator so that both the
                arm probabilities and the rewards are reproducible. When
                omitted, NumPy's global random state is used.

        Raises:
            ValueError: if ``K`` is smaller than 1.
        """
        if K < 1 :
            raise ValueError("K must be at least 1, got %r" % (K,))
        # Both the np.random module and a Generator expose uniform() and
        # random(), so one attribute can hold either source of randomness.
        self._rng = np.random if seed is None else np.random.default_rng(seed)
        self.probs = self._rng.uniform(size=K)
        self.K = K
        self.best_idx = np.argmax(self.probs)
        self.best_prob = self.probs[self.best_idx]
        
    def step(self, k:int) :
        """Pull arm ``k`` once and return the reward.

        Args:
            k: arm index in ``range(self.K)``.

        Returns:
            1 with probability ``self.probs[k]``, otherwise 0.

        Raises:
            IndexError: if ``k`` is outside ``range(self.K)``.
        """
        # Reject negative indices explicitly: plain indexing would silently
        # wrap around and pull a different arm than the caller asked for.
        if not 0 <= k < self.K :
            raise IndexError("arm index %r is outside [0, %r)" % (k, self.K))
        if self._rng.random() < self.probs[k] :
            return 1
        return 0
    def status(self) :
        """Print the arm probabilities, arm count, best probability and best index."""
        print(self.probs, self.K, self.best_prob, self.best_idx)

# Build a 10-armed bandit, pull arm 2 once, and show the sampled reward
# followed by the bandit's internal state.
bb = BernoulliBandit(K=10)
reward = bb.step(k=2)
print(reward)
bb.status()

1
[0.68023941 0.45007456 0.83983248 0.99371238 0.42765172 0.83419885
 0.53664422 0.19829556 0.14298862 0.36819451] 10 0.9937123793337669 3


In [20]:
class Solver :
    """Baseline bandit solver that picks an arm uniformly at random each step.

    Attributes:
        bandit: the bandit being played; its ``K``, ``probs`` and
            ``best_prob`` attributes are read.
        counter: per-arm count of how many times each arm was chosen.
        regrets: one regret value per step taken, in order.
    """
    def __init__(self, bandit) :
        """Start with zero pulls on every arm and an empty regret log."""
        self.bandit = bandit
        self.counter = [0] * bandit.K
        # A log of regrets with one entry per step, NOT a per-arm regret
        # table; do not confuse it with the reward.
        self.regrets = []

    def policy(self) :
        """Choose an arm and return its index (uniformly random here)."""
        n_arms = self.bandit.K
        return np.random.randint(0, n_arms)

    def update_regret(self, k) :
        """Append the regret of choosing arm ``k``: best_prob minus probs[k]."""
        self.regrets.append(self.bandit.best_prob - self.bandit.probs[k])

    def run(self, max_step) :
        """Choose an arm ``max_step`` times, updating the counts and regret log."""
        for _ in range(max_step) :
            chosen = self.policy()
            self.counter[chosen] += 1
            self.update_regret(chosen)

    def status(self) :
        """Print the per-arm pull counts and the per-step regret log."""
        print("counter:", self.counter)
        print("regrets:", self.regrets)
# Play the bandit for 10 steps and report the pull counts and per-step
# regrets.
s = Solver(bandit=bb)
s.run(max_step=10)
s.status()

0.9937123793337669 0.19829555741213145
0.9937123793337669 0.19829555741213145
0.9937123793337669 0.19829555741213145
0.9937123793337669 0.19829555741213145
0.9937123793337669 0.4276517178052863
0.9937123793337669 0.14298861709600075
0.9937123793337669 0.36819451311107
0.9937123793337669 0.9937123793337669
0.9937123793337669 0.839832479882609
0.9937123793337669 0.36819451311107
counter: [0, 0, 1, 1, 1, 0, 0, 4, 1, 2]
regrets: [0.7954168219216354, 0.7954168219216354, 0.7954168219216354, 0.7954168219216354, 0.5660606615284806, 0.8507237622377661, 0.6255178662226969, 0.0, 0.15387989945115788, 0.6255178662226969]
