Skip to content

BNN model #25

@penelopeysm

Description

@penelopeysm

There's a BNN model in `test/mcmc/hmc.jl`, and there's also one in the docs (check whether the two are the same).

    @testset "multivariate support" begin
        # Smoke test: HMC must be able to sample a model that combines several
        # multivariate (MvNormal) parameters with a univariate (Normal) one. The model
        # is a small Bayesian neural network: 2 inputs -> 3 tanh hidden units ->
        # 1 logistic output.

        # Forward pass for a single 2-element input `x`.
        # `w11`, `w12`, `w13` are the input weight vectors of the three hidden units,
        # `b1` the hidden biases, and `wo` / `bo` the output weights / bias. The
        # returned value is used below as the Bernoulli success probability.
        function nn(x, b1, w11, w12, w13, bo, wo)
            h = tanh.([w11 w12 w13]' * x .+ b1)
            return logistic(dot(wo, h) + bo)
        end

        # Generating training data: four groups of M two-dimensional points. A seeded
        # RNG is used so the data set does not depend on the global RNG state.
        N = 20
        M = N ÷ 4
        data_rng = StableRNG(seed)
        x1s = rand(data_rng, M) * 5
        x2s = rand(data_rng, M) * 5
        # Target 1: original points, and points with both coordinates shifted by -6.
        xt1s = [[x1s[i], x2s[i]] for i in 1:M]
        append!(xt1s, [[x1s[i] - 6, x2s[i] - 6] for i in 1:M])
        # Target 0: points with exactly one coordinate shifted by -6.
        xt0s = [[x1s[i], x2s[i] - 6] for i in 1:M]
        append!(xt0s, [[x1s[i] - 6, x2s[i]] for i in 1:M])

        xs = [xt1s; xt0s]
        ts = [ones(M); ones(M); zeros(M); zeros(M)]

        # Define model

        alpha = 0.16                 # regularization term
        prior_var = 1.0 / alpha      # variance of the Gaussian prior
        prior_std = sqrt(prior_var)  # standard deviation of the Gaussian prior

        # MvNormal takes a covariance matrix (variances on the diagonal), whereas
        # Normal takes a standard deviation; keep both parameterisations consistent.
        cov2 = [prior_var 0.0; 0.0 prior_var]
        cov3 = [prior_var 0.0 0.0; 0.0 prior_var 0.0; 0.0 0.0 prior_var]

        @model function bnn(ts)
            b1 ~ MvNormal(zeros(3), cov3)
            w11 ~ MvNormal(zeros(2), cov2)
            w12 ~ MvNormal(zeros(2), cov2)
            w13 ~ MvNormal(zeros(2), cov2)
            bo ~ Normal(0, prior_std)

            wo ~ MvNormal(zeros(3), cov3)
            # Condition on every observation. Drawing a random subset here would make
            # the log density change between evaluations of the same parameters.
            for i in eachindex(ts)
                y = nn(xs[i], b1, w11, w12, w13, bo, wo)
                ts[i] ~ Bernoulli(y)
            end
            return b1, w11, w12, w13, bo, wo
        end

        # Sampling
        chain = sample(StableRNG(seed), bnn(ts), HMC(0.1, 5; adtype=adbackend), 10)
    end

Metadata

Metadata

Assignees

No one assigned

    Labels

    this-repo: Something to do with just this repo

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions