Skip to content

Commit

Permalink
Implement ThompsonSampling
Browse files Browse the repository at this point in the history
  • Loading branch information
XPerianer committed Nov 28, 2023
1 parent a01e9ff commit e518417
Show file tree
Hide file tree
Showing 25 changed files with 541 additions and 71 deletions.
78 changes: 70 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,92 @@ and the Flutter guide for
[developing packages and plugins](https://flutter.dev/developing-packages).
-->

[![Dart CI](https://github.com/XPerianer/Bandart/actions/workflows/test-package.yml/badge.svg)](https://github.com/XPerianer/Bandart/actions/workflows/test-package.yml)

Bandart (combination of "Bandits" and "Dart") is a library for bandit algorithms in Dart. It provides Bayesian Models for statistical analysis for the data.

## Features
- Bandit algorithms:
- Fixed Schedule
- Thompson Sampling
- Bayesian Models:
- Beta Model: Approximating success probability using Beta Distributions
- Gaussian Model: Approximating mean and variance using conjugate priors

TODO: List what your package can do. Maybe include images, gifs, or videos.

## Getting started

TODO: List prerequisites and provide or point to information on how to
start using the package.
Install the package using ```dart pub add bandart```.
Now you can use
```dart
import 'package:bandart/bandart.dart';
```

## Usage

TODO: Include short and useful examples for package users. Add longer examples
to `/example` folder.
### Create A Fixed Schedule

```dart
import 'package:bandart/bandart.dart';
int numberOfInterventions = 3;
var policy = FixedPolicy(numberofInterventions);
for (int i = 0; i < 10; i++) {
print(policy.choseAction({"decisionPoint": i}))
}
// Prints 0, 1, 2, 0, 1, 2, 0, 1, 2, 0
```

### Analyze data using Normal Model
```dart
const like = 'sample';
import 'package:bandart/bandart.dart';
var history = DataFrame(
{'intervention': [0, 0, 1, 1], 'outcome': [1.0, 1.0, 2.0, 2.0]})
// Mean and L are priors for the normal model
var gaussianModel = GaussianModel(
numberOfInterventions: 2, mean: 1.0, l: 1.0, random: Random(0))
gaussianModel.history = history
// Update the samples
gaussianModel.sample()
// Look at the results
print(gaussianModel.maxProbabilities())
// Prints around [0.2, 0.8]
```

### Create an adaptive schedule using Thompson Sampling

```dart
import 'package:bandart/bandart.dart';
var history = DataFrame(
{'intervention': [0, 0, 1, 1], 'outcome': [1.0, 1.0, 2.0, 2.0]})
// Mean and L are priors for the normal model
var gaussianModel = GaussianModel(
numberOfInterventions: 2, mean: 1.0, l: 1.0, random: Random(0))
var policy = ThompsonSampling(numberOfInterventions: 2);
// Prints either 0 or 1 (randomized), but will more often pick 1 as this is the intervention with the better history
print(policy.choseAction({}, history));
```

## Additional information

### Development goals
### Development

#### Setup
Test can be run with ```dart test```.
Some tests need mock code generated by mockito.
To update the mock code, run ```dart run build_runner build```.

#### Goals
- Performance: Since Dart is most often used for mobile development, the goal of the library is to support calculations fast enough to run on smartphones.
- Seedability: When seeded with the same seed, the library should always return the same result.
- Well tested
Expand All @@ -43,4 +106,3 @@ const like = 'sample';
TODO: Tell users more about the package: where to find more information, how to
contribute to the package, how to file issues, what response they can expect
from the package authors, and more.
# Bandart
13 changes: 13 additions & 0 deletions lib/bandart.dart
Original file line number Diff line number Diff line change
@@ -1 +1,14 @@
library bandart;

export 'src/dataframe.dart' show DataFrame;

export 'src/models/model.dart' show Model;
export 'src/models/sampling_model.dart' show SamplingModel;
export 'src/models/beta.dart' show BetaModel;
export 'src/models/gaussian.dart' show GaussianModel;

export 'src/policies/policy.dart' show Policy;
export 'src/policies/fixed_policy.dart' show FixedPolicy;
export 'src/policies/thompson_sampling.dart' show ThompsonSampling;

export 'src/helpers.dart' show weightedRandom;
27 changes: 0 additions & 27 deletions lib/helpers.dart

This file was deleted.

10 changes: 0 additions & 10 deletions lib/models/model.dart

This file was deleted.

7 changes: 0 additions & 7 deletions lib/policies/policy.dart

This file was deleted.

1 change: 0 additions & 1 deletion lib/policies/thompson_sampling.dart

This file was deleted.

1 change: 0 additions & 1 deletion lib/series.dart

This file was deleted.

7 changes: 7 additions & 0 deletions lib/dataframe.dart → lib/src/dataframe.dart
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
/// A DataFrame is a data structure that stores data in a tabular format
class DataFrame {
Map data;

Expand All @@ -9,6 +10,12 @@ class DataFrame {
return data[key];
}

/// Returns a map of the values in the column [column] grouped by the values in the column [key]
/// Example:
/// ```dart
/// var df = DataFrame({'a': [0, 0, 1, 1], 'b': [1, 2, 3, 4]})
/// df.groupBy('a', 'b') // {0: [1, 2], 1: [3, 4]}
/// ```
Map<num, List<num>> groupBy(String key, String column) {
Map<num, List<num>> values = {};
var keyValues = data[key];
Expand Down
File renamed without changes.
48 changes: 48 additions & 0 deletions lib/src/helpers.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
library helpers;

import 'dart:math';

import 'package:collection/collection.dart';

int argMax<T extends num>(List<T> list) {
return list
.asMap()
.entries
.reduce(
(MapEntry<int, T> l, MapEntry<int, T> r) => l.value > r.value ? l : r)
.key;
}

double mean(List<num> a, [double defaultValue = 0]) {
if (a.isEmpty) {
return defaultValue;
}
return a.average;
}

double variance(List<num> a, [double defaultValue = 0]) {
if (a.isEmpty) {
return defaultValue;
}
double m = mean(a);
return a.map((e) => (e - m) * (e - m)).average;
}

/// Returns a number between [0, weights.length) with probability proportional to the weights
int weightedRandom(List<double> weights, Random random) {
assert(weights.isNotEmpty);

List<num> cumulativeWeights = weights;
for (int i = 1; i < weights.length; i++) {
cumulativeWeights[i] += cumulativeWeights[i - 1];
}

// Instead of normalizing, we just use the sum (last cumulative weight) as a factor
double r = random.nextDouble() * cumulativeWeights.last;
for (int i = 0; i < weights.length; i++) {
if (r <= cumulativeWeights[i]) {
return i;
}
}
return weights.length - 1;
}
9 changes: 7 additions & 2 deletions lib/models/beta.dart → lib/src/models/beta.dart
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import 'package:bandart/models/sampling_model.dart';
import 'package:bandart/bandart.dart';

import 'package:data/data.dart';

import 'package:bandart/helpers.dart' as helpers;
import 'package:bandart/src/helpers.dart' as helpers;

/// BetaModel is a SamplingModel that models the data under each intervention after a Beta distribution.
///
/// Each Beta Variable models the probability of an intervention beeing a success.
/// A success is defined as an outcome that is greater than the average outcome.
/// For example if the outcomes for an intervention are [1, 2, 3], then the average outcome is 2, and the last two outcomes are successes.
class BetaModel extends SamplingModel {
final double _a, _b;

Expand Down
16 changes: 12 additions & 4 deletions lib/models/gaussian.dart → lib/src/models/gaussian.dart
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
import 'package:bandart/helpers.dart' as helpers;
import 'package:bandart/models/sampling_model.dart';
import 'package:bandart/src/helpers.dart' as helpers;
import 'package:bandart/src/models/sampling_model.dart';

import 'package:data/data.dart';

import 'dart:math';

/// GaussianModel is a SamplingModel that models the data under each intervention after a Gaussian distribution with unknown mean and variance
///
/// The model is based on the Normal-Inverse-Gamma distribution.
/// See https://en.wikipedia.org/wiki/Normal-inverse-gamma_distribution for details.
/// The formulas used to update the priors can be found under https://en.wikipedia.org/wiki/Conjugate_prior
/// under "Normal with unkown mean and variance"
class GaussianModel extends SamplingModel {
final double _mean, _l, _alpha, _beta;

/// Initalizes a GaussianModel
///
/// mean, l, alpha, beta are the priors for the Normal-Inverse-Gamma distribution.
GaussianModel(
{required numberOfInterventions,
required mean,
Expand Down Expand Up @@ -102,8 +112,6 @@ class GaussianModel extends SamplingModel {

@override
void sample([Map? context]) {
// Calculate posterior parameters
// See https://en.wikipedia.org/wiki/Conjugate_prior and then Normal with unkown mean and variance
List<double> mean = [], l = [], beta = [], alpha = [];

for (int intervention = 0;
Expand Down
12 changes: 12 additions & 0 deletions lib/src/models/model.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import 'package:bandart/bandart.dart';

/// Abstract class with capabilities to model data analysis in a bandit setting
abstract class Model {
/// The history is the sequence of interventions and outcomes the model can use to learn which interventions have which effects
DataFrame? history;

/// The number of interventions are the interventions the model will analyze between [0, numberOfInterventions)
int numberOfInterventions;

Model(this.numberOfInterventions);
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
import 'dart:math';

import 'package:bandart/models/model.dart';
import 'package:bandart/helpers.dart' as helpers;
import 'package:bandart/src/models/model.dart';
import 'package:bandart/src/helpers.dart' as helpers;

/// A SamplingModel is an abstract representation of a model that uses sampling from a posterior distribution to analyze data.
///
/// Usually, after setting the history, one calls model.sample(), and then can use different calls to summarize the posterior distribution.
/// Example:
/// ```dart
/// var model = model
/// model.history = history
/// model.sample()
/// print(model.maxProbabilities())
/// ```
abstract class SamplingModel extends Model {
// array of floats
/// Random number generator that can be used for seeding
Random random;

/// sampleSize is the number of samples drawn on `sample()`
int sampleSize;
List<List<double>> _samples = [];

Expand All @@ -15,11 +27,16 @@ abstract class SamplingModel extends Model {
this.sampleSize = 5000})
: super(numberOfInterventions);

/// Sample from the posterior distribution.
/// This will be overwritten by calling sample().
/// Getter can be used to access the samples in order to calculate custom metrics that are not available by default.
get samples => _samples;
set samples(samples) => _samples = samples;

void sample([Map? context]);

/// Calculates the probability that intervention is the best by taking an argMax over the samples.
/// Output is an array of length numberOfInterventions, where each entry is the probability that the corresponding intervention is the best.
List<double> getSampleProbabilities({bool max = true}) {
if (samples.isEmpty) {
return List.filled(numberOfInterventions, 1 / numberOfInterventions);
Expand All @@ -39,14 +56,18 @@ abstract class SamplingModel extends Model {
return winningCounts.map((e) => e / sampleSize).toList();
}

/// Convenience Function. Calculates `getSampleProbabilities` with max:true
List<double> maxProbabilities() {
return getSampleProbabilities(max: true);
}

/// Convenience Function. Calculates `getSampleProbabilities` with max:false
List<double> minProbabilities() {
return getSampleProbabilities(max: false);
}

/// Calculates the mean of the samples for each intervention.
/// Output is a List of size numberOfInterventions, where each entry is the mean of the samples for the corresponding intervention.
List<double> interventionMeans() {
List<double> meanInterventionEffect = [];
for (int intervention = 0;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import 'package:bandart/dataframe.dart';
import 'package:bandart/exceptions.dart';
import 'package:bandart/policies/policy.dart';
import 'package:bandart/src/exceptions.dart';
import 'package:bandart/bandart.dart';

/// FixedPolicy implements a fixed schedule that repeats the sequence of increasing interventions.
///
/// Example: If numberOfInterventions is 3, then the sequence of interventions will be 0, 1, 2, 0, 1, 2, ...
class FixedPolicy implements Policy {
final int numberOfInterventions;

Expand Down
8 changes: 8 additions & 0 deletions lib/src/policies/policy.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import 'package:bandart/bandart.dart';

/// Policy is an abstract class that implements the logic for choosing an intervention in a bandit setting.
abstract class Policy {
Policy();

int choseAction(Map context, DataFrame history);
}

0 comments on commit e518417

Please sign in to comment.