In [1]:
import numpy as np

In [2]:
class Bandit:
    def __init__(self, true_means):
        self.true_means = true_means

    def pull_arm(self, arm):
        return np.random.normal(self.true_means[arm], 1)

class DistributionModel:
    def __init__(self, num_arms):
        self.num_arms = num_arms
        self.mean_rewards = np.zeros(num_arms)
        self.variance_rewards = np.ones(num_arms)
        self.arm_counts = np.zeros(num_arms)  # Đếm số lần kéo từng arm

    # Update mean and variance of rewards for the selected arm
    def update_distribution(self, arm, reward):
        # Cập nhật số lần kéo của arm
        self.arm_counts[arm] += 1

        # Tính toán mean mới bằng công thức trung bình lũy thừa
        old_mean = self.mean_rewards[arm]
        new_mean = old_mean + (reward - old_mean) / self.arm_counts[arm]
        self.mean_rewards[arm] = new_mean

        # Tính toán variance mới
        if self.arm_counts[arm] > 1:
            self.variance_rewards[arm] = (
                (self.variance_rewards[arm] * (self.arm_counts[arm] - 1) +
                 (reward - new_mean) ** 2) / self.arm_counts[arm]
            )

In [3]:
# Define the true means of the bandit arms
true_means = [1.0, 2.0]

# Create a bandit environment with the true means
bandit = Bandit(true_means)

# Create a distribution model for the bandit
distribution_model = DistributionModel(num_arms=len(true_means))

# Pull arms and update distribution model
num_pulls = 1000
for _ in range(num_pulls):
    # Randomly select an arm to pull
    arm = np.random.randint(0, len(true_means))
    
    # Pull the selected arm and observe reward
    reward = bandit.pull_arm(arm)
    
    # Update distribution model
    distribution_model.update_distribution(arm, reward)

# Print the updated distribution model
print("Updated Distribution Model:")
print("Mean Rewards:", distribution_model.mean_rewards)
print("Variance of Rewards:", distribution_model.variance_rewards)

Updated Distribution Model:
Mean Rewards: [0.96658202 2.00365097]
Variance of Rewards: [0.97884204 0.96875667]
