Skip to content

Commit

Permalink
Gamma(shape, scale) should be Number #204
Browse files Browse the repository at this point in the history
  • Loading branch information
walterxie committed May 31, 2022
1 parent e3fc97e commit 49b43d3
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions lphy/src/main/java/lphy/core/distributions/Gamma.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,34 @@

import static lphy.core.distributions.DistributionConstants.scaleParamName;
import static lphy.core.distributions.DistributionConstants.shapeParamName;
import static lphy.graphicalModel.ValueUtils.doubleValue;

/**
* Gamma distribution
*/
public class Gamma implements GenerativeDistribution1D<Double> {

private Value<Double> shape;
private Value<Double> scale;
private Value<Number> shape;
private Value<Number> scale;

GammaDistribution gammaDistribution;

public Gamma(@ParameterInfo(name = shapeParamName, description = "the shape of the distribution.") Value<Double> shape,
@ParameterInfo(name = scaleParamName, description = "the scale of the distribution.") Value<Double> scale) {
public Gamma(@ParameterInfo(name = shapeParamName, description = "the shape of the distribution.") Value<Number> shape,
@ParameterInfo(name = scaleParamName, description = "the scale of the distribution.") Value<Number> scale) {

this.shape = shape;
if (shape == null) throw new IllegalArgumentException("The shape value can't be null!");
this.scale = scale;
if (scale == null) throw new IllegalArgumentException("The scale value can't be null!");

constructGammaDistribution();
constructDistribution();
}

@GeneratorInfo(name = "Gamma", verbClause = "has", narrativeName = "gamma distribution prior",
category = GeneratorCategory.PROB_DIST, examples = {"covidDPG.lphy"},
description = "The gamma probability distribution.")
public RandomVariable<Double> sample() {
constructGammaDistribution();
// constructDistribution() only required in constructor and setParam
double x = gammaDistribution.sample();
return new RandomVariable<>("x", x, this);
}
Expand All @@ -57,15 +58,16 @@ public void setParam(String paramName, Value value) {
else if (paramName.equals(scaleParamName)) scale = value;
else throw new RuntimeException("Unrecognised parameter name: " + paramName);

constructGammaDistribution();
constructDistribution();
}

private void constructGammaDistribution() {
@Override
public void constructDistribution() {
// in case the shape is type integer
double sh = ((Number) shape.value()).doubleValue();
double sh = doubleValue(shape);

// in case the scale is type integer
double sc = ((Number) scale.value()).doubleValue();
double sc = doubleValue(scale);

gammaDistribution = new GammaDistribution(Utils.getRandom(), sh, sc);
}
Expand All @@ -74,11 +76,11 @@ public String toString() {
return getName();
}

public Value<Double> getScale() {
public Value<Number> getScale() {
return scale;
}

public Value<Double> getShape() {
public Value<Number> getShape() {
return shape;
}

Expand Down

0 comments on commit 49b43d3

Please sign in to comment.