In [None]:
using ReactiveMP, Distributions, Random, Rocket, GraphPPL, Flux, Zygote, ForwardDiff
import SpecialFunctions: loggamma

In [None]:
# Observed scores grouped by school: one inner vector per school, with
# unequal numbers of observations per school.
dataset = [
    [58.7226956168386, 45.13976354300799, 34.891234306106206, 22.049813810571656, 18.029311822124683, 42.714517401378004, 26.86775400403366, 32.107230568182345],
    [4.606339701921808, -0.43877921447069923, -0.8893574689730634],
    [6.478451882763512, 2.888032369033442, -7.49813190569346, -14.758169441040984, -14.478560265189357, -15.401117821926048, -18.672627310307654],
    [1.846619359098037, -2.68986972978257, 10.054981057055565, 23.763077651153214, -12.555004304724793],
    [6.040319868293942, 19.822216736962826, -8.83951499143022, 2.1115471241811163, -2.4026795198407163, 19.997114297241556, -9.985174142931942, -0.3973994620971628],
    [-8.68576178873407, -1.791407472467292, 9.954292431463115, -11.37327663982918, -2.78835504435457, -1.6674794730107032, 17.851944050272632, 5.547874888293355, -1.5037350724506187],
    [33.31815939613223, -13.432641838306846, 3.5993734745928982],
    [8.268877742477276, 35.26754519612162, 33.05402194461256]
]

# Concatenate school by school, so school i's scores occupy a contiguous run of
# `school_sizes[i]` entries in the flat vector.
flatten_dataset = reduce(vcat, dataset)

# Number of observations per school; passed to the model to rebuild the grouping.
school_sizes = length.(dataset);

In [None]:
# Hierarchical model of per-school scores under a mean-field factorisation:
#   x[i] ~ NormalMeanPrecision(μ, s)        per-school mean
#   w[i] ~ GammaShapeRate(α_, β)            per-school observation precision
#   y[n] ~ NormalMeanPrecision(x[i], w[i])  for each observation n of school i
# `school_sizes[i]` is the number of observations of school i. `y` holds all
# observations flattened school by school, so school i owns the next
# `school_sizes[i]` entries of `y`.
# α is routed through `identity` with a `CVIApproximation` meta to obtain α_,
# which is the shape actually used by the w[i] nodes. Presumably this is done
# because the message towards a Gamma shape is not a standard distribution.
# NOTE(review): the positional CVIApproximation arguments are not documented
# here; confirm their meaning against the ReactiveMP docs.
@model [ default_factorisation = MeanField() ] function school_model(school_sizes)
    α ~ GammaShapeRate(0.1, 0.1)
    α_ ~ identity(α) where {meta = CVIApproximation(100, 100, ADAM(), 100, 20, FactorProduct((NormalMeanVariance(0, 1),)))}
    β ~ GammaShapeRate(0.1, 0.1)
    μ ~ NormalMeanVariance(0, 10)
    s ~ GammaShapeRate(0.1, 1.0)

    n_schools = length(school_sizes)
    x = randomvar(n_schools)
    w = randomvar(n_schools)
    y = datavar(Float64, sum(school_sizes))

    # Running index into the flattened observation vector `y`.
    n_count = 0
    # One (x[i], w[i]) pair per school, sized from `school_sizes`, not a fixed
    # count. `enumerate` yields 1-based i, which matches `randomvar` indexing.
    for (i, school_size) in enumerate(school_sizes)
        x[i] ~ NormalMeanPrecision(μ, s)
        w[i] ~ GammaShapeRate(α_, β)
        for _ in 1:school_size
            n_count += 1
            y[n_count] ~ NormalMeanPrecision(x[i], w[i])
        end
    end

    return α, β, μ, s, x, w, y
end

In [None]:
# Declare that an analytical product exists between a univariate log-pdf of type T
# and a vectorised product of log-pdfs of that same type T.
# NOTE(review): this extends a ReactiveMP function on ReactiveMP types (type
# piracy). It is presumably a local workaround; consider upstreaming it.
function ReactiveMP.prod_analytical_rule(
    ::Type{T},
    ::Type{<:GenericLogPdfVectorisedProduct{T}},
) where {T<:ContinuousUnivariateLogPdf}
    return ReactiveMP.ProdAnalyticalRuleAvailable()
end

In [None]:
# Analytical product of one more log-pdf factor `left` with an existing vectorised
# product `right` of the same log-pdf type. `left` is appended with `push!`, which
# mutates `right` in place; the call returns whatever `push!` returns (presumably
# the updated product).
# NOTE(review): type piracy on ReactiveMP's `prod`; presumably a local workaround.
ReactiveMP.prod(
    ::ProdAnalytical,
    left::T,
    right::GenericLogPdfVectorisedProduct{T},
) where {T<:ContinuousUnivariateLogPdf} = push!(right, left)

In [None]:
# For a marginal whose payload is a `ProdFinal` wrapper, return the wrapped
# distribution (its `dist` field) instead of the wrapper itself.
# NOTE(review): type piracy on ReactiveMP's `getdata`; presumably a local workaround.
function ReactiveMP.getdata(marginal::Marginal{<:ProdFinal})
    wrapped = marginal.data
    return wrapped.dist
end

In [None]:
# Mean-field message towards the shape α of a GammaShapeRate(out; α, β) node.
# Take the expectation of
#   log p(out | α, β) = α log β - loggamma(α) + (α - 1) log out - β out
# under q(out) q(β) and keep only the α-dependent terms. This gives an
# unnormalised log-density on α ∈ [0, ∞) that is not a standard family, so it
# is returned as a `ContinuousUnivariateLogPdf`.
@rule GammaShapeRate(:α, Marginalisation) (q_out::Any, q_β::Any, ) = begin
    # Both expectations are constant in α. Compute them once here rather than
    # on every evaluation of the returned log-density.
    E_log_β   = mean(log, q_β)
    E_log_out = mean(log, q_out)
    return ContinuousUnivariateLogPdf(
        ReactiveMP.DomainSets.HalfLine(),
        (α) -> α * E_log_β + (α - 1) * E_log_out - loggamma(α),
    )
end

# Mean-field message towards the rate β. The β-dependent terms of the expected
# log-density are E[α] log β - β E[out], which is a Gamma kernel with
# shape E[α] + 1 and rate E[out].
@rule GammaShapeRate(:β, Marginalisation) (q_out::Any, q_α::Any, ) = begin
    post_shape = mean(q_α) + 1
    post_rate  = mean(q_out)
    return GammaShapeRate(post_shape, post_rate)
end

# Mean-field message towards the Gamma-distributed output: a GammaShapeRate
# parameterised by the expected shape E[α] and the expected rate E[β].
@rule GammaShapeRate(:out, Marginalisation) (q_α::Any, q_β::Any, ) = begin
    expected_shape = mean(q_α)
    expected_rate  = mean(q_β)
    return GammaShapeRate(expected_shape, expected_rate)
end

In [None]:
# Run inference on the school model for 1000 iterations with free-energy tracking
# disabled. Only α is returned, and `KeepEach()` presumably keeps its posterior
# from every iteration.
# NOTE(review): Distributions' `Gamma(a, b)` is shape/scale parameterised, unlike
# the `GammaShapeRate` used inside the model. Confirm these initial values are
# the ones intended.
res = let
    school_graph = Model(school_model, school_sizes)
    observations = (y = flatten_dataset,)
    initial_marginals = (
        α  = Gamma(0.01, 0.01),
        α_ = Gamma(0.01, 0.01),
        β  = Gamma(0.01, 0.01),
        s  = Gamma(0.01, 0.01),
        μ  = vague(GaussianMeanVariance),
        w  = Gamma(0.01, 0.01),
        x  = vague(GaussianMeanPrecision),
    )
    inference(;
        model         = school_graph,
        data          = observations,
        iterations    = 1000,
        free_energy   = false,
        initmarginals = initial_marginals,
        returnvars    = (α = KeepEach(),),
    )
end

In [None]:
# Mean of the last stored posterior of α (with `KeepEach()`, presumably the final iteration).
mean(last(res.posteriors[:α]))