In [None]:
using Plots, StatsPlots
using Random, Distributions
using Zygote
using Distributed

In [None]:
# Use every hardware thread but one for worker processes; the remaining one is the master.
const nmax_workers = Sys.CPU_THREADS - 1

# Count existing workers as `nprocs() - 1` (all processes minus the master).
# `nworkers()` cannot be used here: it reports 1 when no worker has been added yet,
# because the master then counts as a worker. That undercounts the processes still
# needed by one on the first run.
let nextra = nmax_workers - (nprocs() - 1)
    if nextra > 0
        addprocs(nextra)
    end
end

In [None]:
# Synthetic dataset: n i.i.d. draws from Normal(μ, σ).
# Seed the master's RNG so the data, and therefore θ_init, is reproducible.
# This matches the per-chain seeding used for the martingale runs.
data_seed = 2024
Random.seed!(data_seed)
n = 100
μ = 0.0
σ = 1.0
x = μ .+ σ .* randn(n)
# Parameterization used throughout the notebook: θ = [μ, log σ].
θ = [μ, log(σ)]                   # true parameter
θ_init = [mean(x), log(std(x))]   # plug-in estimate from the sample (`std` uses n - 1)
;

In [None]:
@everywhere begin
    using Random, Distributions, Zygote

    # score(θ, x)
    # Gradient with respect to θ = [μ, log σ] of the log-density
    # logpdf(Normal(θ[1], exp(θ[2])), x) of a single observation x.
    # Computed with Zygote's reverse-mode `gradient`.
    function score(θ, x)
        return gradient(p -> logpdf(Normal(p[1], exp(p[2])), x), θ)[1]
    end

    # martingale(θ, T, seed)
    # Run one predictive-resampling chain of length T from the start θ = [μ, log σ].
    # Returns a 2 × T matrix: column m holds the parameter after step m,
    # and column 1 is the starting point.
    # For m = 2, …, T a pseudo-observation is drawn from Normal(θ[1], exp(θ[2])),
    # then θ moves along `score(θ, x)` with step size 1/(m+1).
    # `Random.seed!(seed)` is called first, so each chain is reproducible from its seed.
    # Throws ArgumentError if T < 1.
    function martingale(θ, T, seed)
        T >= 1 || throw(ArgumentError("T must be at least 1, got $T"))
        Random.seed!(seed)
        # Promote (and copy) once so the loop variable keeps a single float type.
        # An integer start such as [0, 0] would otherwise change type on the first update.
        θm = float.(θ)
        θ_trace = zeros(2, T)
        θ_trace[:, 1] = θm
        for m in 2:T
            x = θm[1] + exp(θm[2]) * randn()
            # NOTE(review): the step size does not depend on how many observed data
            # points produced the starting θ. Confirm this is the intended schedule.
            θm = θm + 1 / (m + 1) * score(θm, x)
            θ_trace[:, m] = θm
        end
        return θ_trace
    end
end

In [None]:
# Per-chain arguments for `pmap`. Every chain starts at the sample estimate θ_init,
# runs T steps, and uses its own index as its RNG seed.
T = 100
n_chains = 3000
s_chains = 1:n_chains
θ_chains = fill(θ_init, n_chains)   # every entry refers to the same θ_init vector
T_chains = fill(T, n_chains)
;

In [None]:
# Run the chains in parallel. Chain i evaluates martingale(θ_chains[i], T_chains[i], s_chains[i]).
θ_trace = pmap(martingale, θ_chains, T_chains, s_chains)
# Final state of every chain, one column per chain (2 × n_chains).
# `reduce(hcat, …)` has a specialised method for vectors of arrays.
# It avoids splatting thousands of arguments into a single `hcat` call.
θ_samples = reduce(hcat, [chain[:, end] for chain in θ_trace])
;

In [None]:
# Trace plots of every 10th chain: θ[1] (μ) on the left, θ[2] (log σ) on the right.
plt = plot(size=(1000, 400), layout=(1, 2), bottom_margin=5Plots.mm, left_margin=5Plots.mm)
for (k, (axis_label, panel_title)) in enumerate((("θ[1]", "μ"), ("θ[2]", "logσ")))
    for chain in θ_trace[1:10:n_chains]
        plot!(plt, chain[k, :], subplot=k, label=nothing, color=:black, alpha=0.5)
    end
    plot!(plt; xlabel="Iteration", ylabel=axis_label, title=panel_title, subplot=k)
end
display(plt)

In [None]:
# Density histograms of the chains' final states.
# Each panel has reference lines for the martingale mean (green), the sample
# estimate θ_init (blue) and the true parameter θ (red).
plt = plot(size=(1000, 400), layout=(1, 2), bottom_margin=5Plots.mm, left_margin=5Plots.mm)
for (k, (axis_label, panel_title)) in enumerate((("θ[1]", "μ"), ("θ[2]", "logσ")))
    draws = θ_samples[k, :]
    stephist!(plt, draws, normalize=:pdf, subplot=k,
              label="Martingale hist", color=:black, linewidth=2)
    vline!(plt, [mean(draws)], subplot=k, label="Martingale", color=:green, linewidth=2)
    vline!(plt, [θ_init[k]], subplot=k, label="Sample", color=:blue, linewidth=2)
    vline!(plt, [θ[k]], subplot=k, label="True", color=:red, linewidth=2)
    plot!(plt; xlabel=axis_label, ylabel="Density", title=panel_title, subplot=k)
end
display(plt)