-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
25 changed files
with
541 additions
and
71 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,14 @@ | ||
library bandart; | ||
|
||
export 'src/dataframe.dart' show DataFrame; | ||
|
||
export 'src/models/model.dart' show Model; | ||
export 'src/models/sampling_model.dart' show SamplingModel; | ||
export 'src/models/beta.dart' show BetaModel; | ||
export 'src/models/gaussian.dart' show GaussianModel; | ||
|
||
export 'src/policies/policy.dart' show Policy; | ||
export 'src/policies/fixed_policy.dart' show FixedPolicy; | ||
export 'src/policies/thompson_sampling.dart' show ThompsonSampling; | ||
|
||
export 'src/helpers.dart' show weightedRandom; |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
library helpers; | ||
|
||
import 'dart:math'; | ||
|
||
import 'package:collection/collection.dart'; | ||
|
||
int argMax<T extends num>(List<T> list) { | ||
return list | ||
.asMap() | ||
.entries | ||
.reduce( | ||
(MapEntry<int, T> l, MapEntry<int, T> r) => l.value > r.value ? l : r) | ||
.key; | ||
} | ||
|
||
double mean(List<num> a, [double defaultValue = 0]) { | ||
if (a.isEmpty) { | ||
return defaultValue; | ||
} | ||
return a.average; | ||
} | ||
|
||
double variance(List<num> a, [double defaultValue = 0]) { | ||
if (a.isEmpty) { | ||
return defaultValue; | ||
} | ||
double m = mean(a); | ||
return a.map((e) => (e - m) * (e - m)).average; | ||
} | ||
|
||
/// Returns a number between [0, weights.length) with probability proportional to the weights | ||
int weightedRandom(List<double> weights, Random random) { | ||
assert(weights.isNotEmpty); | ||
|
||
List<num> cumulativeWeights = weights; | ||
for (int i = 1; i < weights.length; i++) { | ||
cumulativeWeights[i] += cumulativeWeights[i - 1]; | ||
} | ||
|
||
// Instead of normalizing, we just use the sum (last cumulative weight) as a factor | ||
double r = random.nextDouble() * cumulativeWeights.last; | ||
for (int i = 0; i < weights.length; i++) { | ||
if (r <= cumulativeWeights[i]) { | ||
return i; | ||
} | ||
} | ||
return weights.length - 1; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import 'package:bandart/bandart.dart'; | ||
|
||
/// Abstract class with capabilities to model data analysis in a bandit setting | ||
abstract class Model { | ||
/// The history is the sequence of interventions and outcomes the model can use to learn which interventions have which effects | ||
DataFrame? history; | ||
|
||
/// The number of interventions are the interventions the model will analyze between [0, numberOfInterventions) | ||
int numberOfInterventions; | ||
|
||
Model(this.numberOfInterventions); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
8 changes: 5 additions & 3 deletions
8
lib/policies/fixed_policy.dart → lib/src/policies/fixed_policy.dart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
import 'package:bandart/bandart.dart'; | ||
|
||
/// Policy is an abstract class that implements the logic for choosing an intervention in a bandit setting. | ||
abstract class Policy { | ||
Policy(); | ||
|
||
int choseAction(Map context, DataFrame history); | ||
} |
Oops, something went wrong.