Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
124 lines (122 sloc) 4.16 KB
#' @export
ThompsonSamplingPolicy <- R6::R6Class(
  # Thompson Sampling policy for Bernoulli (0/1) rewards.
  #
  # Per arm, the policy tracks two counters ("succes" and "failure") that act
  # as the parameters of a Beta distribution. On each step it draws one value
  # per arm from that distribution and plays the arm with the largest draw.
  #
  # NOTE: portable = FALSE means methods can reach fields without `self$`;
  # get_action() relies on this for `action`.
  portable = FALSE,
  class = FALSE,
  inherit = Policy,
  public = list(
    # Initial values of the per-arm "succes" / "failure" counters.
    alpha = 1,
    beta = 1,
    class_name = "ThompsonSamplingPolicy",

    # Store the Beta parameters used to seed every arm.
    #
    # alpha: first shape parameter of the Beta distribution.
    # beta:  second shape parameter of the Beta distribution.
    initialize = function(alpha = 1, beta = 1) {
      self$alpha <- alpha
      self$beta <- beta
    },

    # Declare the per-arm parameters this policy keeps track of.
    #
    # context_params: not used by this policy.
    #
    # The key 'succes' (sic) is read back by name in get_action() and
    # set_reward(), so its spelling must stay in sync with those methods.
    # NOTE(review): presumably the framework expands theta_to_arms into
    # per-arm lists under self$theta -- confirm against Policy.
    set_parameters = function(context_params) {
      self$theta_to_arms <- list('succes' = self$alpha, 'failure' = self$beta, "mu" = 0)
    },

    # Sample one value per arm and choose the arm with the highest sample.
    #
    # t:       not used by this method.
    # context: list-like object; only context$k (the number of arms) is read.
    #
    # Returns `action` with its `choice` element set to the selected arm.
    get_action = function(t, context) {
      # seq_len() gives an empty sequence for k == 0, whereas 1:k would
      # iterate over c(1, 0) and index arms that do not exist.
      for (arm in seq_len(context$k)) {
        self$theta$mu[[arm]] <- stats::rbeta(
          1,
          self$theta$succes[[arm]],
          self$theta$failure[[arm]]
        )
      }
      # NOTE(review): `action` is not defined in this block; presumably it is
      # a field inherited from Policy, visible here because portable = FALSE.
      # which_max_list() is defined elsewhere -- presumably it returns the
      # index of the largest element; confirm its tie-breaking behaviour.
      action$choice <- which_max_list(self$theta$mu)
      action
    },

    # Update the chosen arm's counters with the observed reward.
    #
    # t, context: not used by this method.
    # action:     list-like object; action$choice is the arm that was played.
    # reward:     list-like object; reward$reward is the observed reward.
    #
    # Returns the updated self$theta.
    set_reward = function(t, context, action, reward) {
      arm <- action$choice
      observed <- reward$reward
      # `inc<-` is a replacement function defined elsewhere -- presumably it
      # adds the right-hand side to its target; confirm. With a 0/1 reward,
      # exactly one of the two counters grows by one.
      inc(self$theta$succes[[arm]]) <- observed
      inc(self$theta$failure[[arm]]) <- 1 - observed
      self$theta
    }
  )
)
#' Policy: Thompson Sampling
#'
#' \code{ThompsonSamplingPolicy} works by maintaining a prior on the mean rewards of its arms.
#' In this, it follows a beta-binomial model with parameters alpha and beta, sampling values
#' for each arm from its prior and picking the arm with the highest value.
#' When an arm is pulled and a Bernoulli reward is observed, it modifies the prior based on the reward.
#' This procedure is repeated for the next arm pull.
#'
#' @name ThompsonSamplingPolicy
#'
#' @section Usage:
#' \preformatted{
#' policy <- ThompsonSamplingPolicy$new(alpha = 1, beta = 1)
#' }
#'
#' @section Arguments:
#'
#' \describe{
#' \item{\code{alpha}}{
#' integer, a natural number N>0 - first parameter of the Beta distribution
#' }
#' \item{\code{beta}}{
#' integer, a natural number N>0 - second parameter of the Beta distribution
#' }
#' }
#'
#' @section Methods:
#'
#' \describe{
#' \item{\code{new(alpha = 1, beta = 1)}}{ Generates a new \code{ThompsonSamplingPolicy} object.
#' Arguments are defined in the Argument section above.}
#' }
#'
#' \describe{
#' \item{\code{set_parameters(context_params)}}{each policy needs to assign the parameters it wants to keep track of
#' to list \code{self$theta_to_arms} that has to be defined in \code{set_parameters()}'s body.
#' The parameters defined here can later be accessed by arm index in the following way:
#' \code{self$theta$parameter_name[[index_of_arm]]}
#' }
#' }
#'
#' \describe{
#' \item{\code{get_action(t, context)}}{
#' here, a policy decides which arm to choose, based on the current values
#' of its parameters and, potentially, the current context.
#' }
#' }
#'
#' \describe{
#' \item{\code{set_reward(t, context, action, reward)}}{
#' in \code{set_reward(t, context, action, reward)}, a policy updates its parameter values
#' based on the reward received, and, potentially, the current context.
#' }
#' }
#'
#' @references
#'
#' Thompson, W. R. (1933). On the likelihood that one unknown probability exceeds another in view of
#' the evidence of two samples. Biometrika, 25(3/4), 285-294.
#'
#' Chapelle, O., & Li, L. (2011). An empirical evaluation of Thompson sampling. In Advances in neural
#' information processing systems (pp. 2249-2257).
#'
#' Agrawal, S., & Goyal, N. (2013, February). Thompson sampling for contextual bandits with linear payoffs.
#' In International Conference on Machine Learning (pp. 127-135).
#'
#' @seealso
#'
#' Core contextual classes: \code{\link{Bandit}}, \code{\link{Policy}}, \code{\link{Simulator}},
#' \code{\link{Agent}}, \code{\link{History}}, \code{\link{Plot}}
#'
#' Bandit subclass examples: \code{\link{BasicBernoulliBandit}}, \code{\link{ContextualLogitBandit}},
#' \code{\link{OfflineReplayEvaluatorBandit}}
#'
#' Policy subclass examples: \code{\link{EpsilonGreedyPolicy}}, \code{\link{ContextualLinTSPolicy}}
#'
#' @examples
#'
#' horizon <- 100L
#' simulations <- 100L
#' weights <- c(0.9, 0.1, 0.1)
#'
#' policy <- ThompsonSamplingPolicy$new(alpha = 1, beta = 1)
#' bandit <- BasicBernoulliBandit$new(weights = weights)
#' agent <- Agent$new(policy, bandit)
#'
#' history <- Simulator$new(agent, horizon, simulations, do_parallel = FALSE)$run()
#'
#' plot(history, type = "cumulative")
NULL
You can’t perform that action at this time.