# In [5]:
from typing import Dict

import numpy as np

# Seed NumPy's legacy global RNG once, at import time, so that anything relying
# on np.random's global state below produces reproducible results across runs.
np.random.seed(54)


class Strategy:
    """Base class for multi-armed bandit strategies.

    Keeps per-arm running totals for a fixed set of reward metrics together
    with a per-arm pull counter; concrete strategies decide which arm to play
    by overriding ``choose_arm``.

    Attributes:
        n_arms: Number of arms.
        n_iters: Number of ``update_reward`` calls since construction or the
            last ``flush``.
        arms_states: Metric name ("quantity", "revenue", "margin") -> float
            array of length ``n_arms`` with each arm's summed reward.
        arms_actions: Float array of length ``n_arms`` counting pulls per arm.
    """

    def __init__(self, n_arms: int) -> None:
        self.n_arms = n_arms
        self.n_iters = 0
        tracked_metrics = ("quantity", "revenue", "margin")
        self.arms_states = {name: np.zeros(n_arms) for name in tracked_metrics}
        self.arms_actions = np.zeros(n_arms)

    def flush(self) -> None:
        """Zero the iteration counter, every metric total and the pull counts.

        Each metric array is replaced by a new zero array inside the existing
        ``arms_states`` dict; the dict object itself is reused.
        """
        self.n_iters = 0
        for name in self.arms_states:
            self.arms_states[name] = np.zeros(self.n_arms)
        self.arms_actions = np.zeros(self.n_arms)

    def update_reward(self, arm_id: int, metrics: Dict[str, float]) -> None:
        """Record one pull of ``arm_id`` and accumulate its observed metrics.

        Args:
            arm_id: Index of the arm that was played.
            metrics: Metric name -> value observed on this pull. Every name
                must already be a key of ``arms_states``.

        Raises:
            KeyError: If ``metrics`` names a metric that is not tracked.
                ``n_iters`` (and any earlier metrics) are already updated
                by then.
        """
        self.n_iters += 1
        for name, amount in metrics.items():
            totals = self.arms_states[name]
            totals[arm_id] += amount
        self.arms_actions[arm_id] += 1

    def choose_arm(self, target: str) -> int:
        """Return the index of the arm to play next, optimising ``target``.

        Raises:
            NotImplementedError: Always; concrete strategies must override.
        """
        message = "This method should be implemented by subclasses"
        raise NotImplementedError(message)


class Thompson(Strategy):
    """Thompson-sampling strategy with a Beta-Bernoulli posterior per arm.

    For the chosen ``target`` metric, each arm's accumulated value is treated
    as its number of successes and ``arms_actions - successes`` as its number
    of failures. One value is drawn from ``Beta(successes + 1, failures + 1)``
    (uniform prior) for every arm, and the arm with the largest draw is played.
    Draws use NumPy's global RNG, so ``np.random.seed`` makes choices
    reproducible.
    """

    def choose_arm(self, target: str) -> int:
        """Sample every arm's posterior for ``target`` and return the best arm.

        Args:
            target: Key of ``arms_states`` to optimise (e.g. ``"quantity"``).
                NOTE(review): a Beta posterior is only meaningful when each
                pull contributes a reward in [0, 1] (e.g. 0/1 purchases);
                confirm this for the metric being optimised.

        Returns:
            Index of the selected arm as a Python ``int``.

        Raises:
            KeyError: If ``target`` is not a tracked metric.
        """
        # Clip to >= 0 so metrics that go negative (e.g. margin) or exceed the
        # pull count (e.g. revenue) never produce a non-positive Beta parameter.
        successes = np.maximum(self.arms_states[target], 0.0)
        failures = np.maximum(self.arms_actions - successes, 0.0)
        # The +1 (uniform prior) keeps both parameters strictly positive, so
        # arms that were never pulled are sampled from Beta(1, 1) instead of
        # producing 0/0 = NaN as a plain success ratio would.
        samples = np.random.beta(successes + 1.0, failures + 1.0)
        return int(np.argmax(samples))