# Lab CMBLenses

## Model

In [1]:
# Notebook dependencies. Later cells use the CMBLensing problem/target types,
# MCMCDiagnosticTools.ess_rhat, LaTeX labels (L"..."), and Plots.
using CMBLensing, CMBLensingInferenceTestProblem
using CUDA, LaTeXStrings, MCMCDiagnosticTools, Plots, ProgressMeter, Random, Zygote
using LinearAlgebra, Statistics
# Presumably shortens printed cell output: Julia's displaysize consults ENV["LINES"].
# ENV stores values as strings, so this becomes "10".
ENV["LINES"] = 10;

In [2]:
# Revise must be loaded before the packages it should track. Loading it first lets
# edits to MicroCanonicalHMC's source take effect without restarting the kernel.
using Revise
using MicroCanonicalHMC

In [None]:
# Build the CMB lensing test problem with plain `Array` storage, Float32 precision
# and Nside = 64.
prob = load_cmb_lensing_problem(
    storage = Array,
    T = Float32,
    Nside = 64,
);

In [None]:
# Wrap the problem as the target object that `Sample` draws from below.
target = prob |> CMBLensingTarget;

In [None]:
# d: number of entries in the flattened starting point.
# to_vec / from_vec: presumably convert between the structured Ω and a flat vector
# (from_vec is applied to flat draw slices further down).
d = prob.Ωstart[:] |> length
(to_vec, from_vec) = CMBLensingInferenceTestProblem.to_from_vec(prob.Ωstart)

## Sample

In [None]:
# Adaptive MCHMC sampler.
# The positional arguments (500, 0.0001) are presumably the number of adaptation
# steps and the target energy variance. Confirm against the MCHMC docstring.
# init_eps / init_L seed the step size and trajectory length that adaptation starts from.
# sigma is `one(...)` of the LenseBasis-converted mass-matrix diagonal. Presumably
# that is an all-ones field of matching layout (unit scales); confirm.
# Non-adaptive alternative that was tried previously:
#   spl = MCHMC(10.0, 10.0*sqrt(50); sigma=one(LenseBasis(diag(target.Λmass))))
spl = MCHMC(
    500,
    0.0001;
    adaptive = true,
    init_eps = 10.0,
    init_L = 10.0 * sqrt(50),
    sigma = one(LenseBasis(diag(target.Λmass))),
)

In [None]:
# Run the sampler for 1000 draws; dialog=false presumably silences progress output.
samples = Sample(spl, target, 1000, dialog = false)

In [None]:
# Keep only draws in which every entry is finite. A non-finite draw (e.g. from a
# diverged step) would otherwise propagate NaN/Inf into the diagnostics and plots.
# `filter` keeps the element type of `samples`; the old `[]` accumulator was
# Vector{Any}. `all(isfinite, s)` short-circuits without allocating a BitArray.
samples_redux = filter(sample -> all(isfinite, sample), samples);

In [None]:
# Per-coordinate effective sample size and R̂.
# ess_rhat takes a (draws, chains, parameters) array; this run is a single chain.
# Only the finite draws are used, because one NaN/Inf draw makes the diagnostics of
# every affected coordinate NaN.
isempty(samples_redux) && throw(ArgumentError("no finite samples to compute ESS/R̂ from"))
_samples = let ndraws = length(samples_redux), nparams = length(first(samples_redux))
    arr = zeros(ndraws, 1, nparams)
    for (i, sample) in enumerate(samples_redux)
        arr[i, 1, :] .= sample
    end
    arr
end
ess, rhat = MCMCDiagnosticTools.ess_rhat(_samples)

In [None]:
# Median effective sample size across all coordinates.
_ess = Statistics.median(ess)

In [None]:
# Median R̂ across all coordinates.
_rhat = Statistics.median(rhat)

## Plotting

In [None]:
# Session-wide plot defaults: PNG output, 120 dpi, size (500, 300), 10 pt legend text.
Plots.default(size = (500, 300), dpi = 120, fmt = :png, legendfontsize = 10)

In [None]:
# Entries d-1 and d of each draw: the last two coordinates of the first d entries.
# They are treated as log r and log Aϕ (exponentiated before plotting).
# Index directly: `sample[1:d][end-1]` copied d entries per draw to read a single one.
rs = [sample[d - 1] for sample in samples_redux]
Aϕs = [sample[d] for sample in samples_redux];

In [None]:
# Entries 2d-1 and 2d of each draw, plotted below as Π(r) and Π(A_ϕ).
# The old `sample[d:2d]` slice had length d + 1; its [end-1]/[end] were exactly
# sample[2d-1]/sample[2d]. Index them directly and avoid the per-draw copy.
# NOTE(review): this assumes each draw holds a second d-length block at d+1:2d.
# The `from_vec(sample[1:end-2])` calls below imply length(sample) == d + 2, which
# would make these indices out of bounds. Confirm the draw layout.
p_rs = [sample[2d - 1] for sample in samples_redux]
p_Aϕs = [sample[2d] for sample in samples_redux];

In [None]:
# Traces of r and A_ϕ (exponentiated) against step number, overlaid.
plot(exp.(rs); xlabel = "step", label = L"r")
plot!(exp.(Aϕs); label = L"A_\phi")


In [None]:
# Traces of the Π(r) and Π(A_ϕ) entries against step number, overlaid.
# NOTE(review): these are exponentiated like the positions above; confirm intended.
plot(exp.(p_rs); xlabel = "step", label = L"\Pi(r)")
plot!(exp.(p_Aϕs); label = L"\Pi(A_\phi)")


In [None]:
# Entry end-1 of each draw is read as the energy. Report its variance per dimension,
# presumably to compare with the 0.0001 passed to MCHMC.
# `var` computes the variance directly instead of squaring `std` (sqrt round-trip).
Energy = [sample[end-1] for sample in samples_redux];
var(Energy) / d

In [None]:
# Energy per dimension along the chain.
# The label is a plain string: inside L"..." the word "Energy" is typeset as a
# product of italic single-letter variables.
plot(Energy ./ d, label = "Energy/d", xlabel = "step")

In [None]:
# Side-by-side histograms of r and A_ϕ (exponentiated) over the finite draws.
let hist_r = histogram(exp.(rs); xlabel = L"r", label = nothing, lw = 1),
    hist_Aϕ = histogram(exp.(Aϕs); xlabel = L"A_\phi", label = nothing, lw = 1)
    plot(hist_r, hist_Aϕ)
end

In [None]:
# Rebuild a structured Ω from the most recent *finite* draw. The trailing two entries
# of each draw are dropped before from_vec (the energy read earlier sits at end-1).
# Using samples_redux keeps a diverged (NaN/Inf) final draw out of the spectra plot.
to_vec, from_vec = CMBLensingInferenceTestProblem.to_from_vec(prob.Ωstart)
last_sample = from_vec(samples_redux[end][1:end-2])

In [None]:
# get_Cℓ spectra of the true, starting and last-sampled fields for ϕ° (I component)
# and the E and B components of f°, one panel each on shared log-log axes.
# Every x-label is a LaTeXString. A plain "\ell" is wrong because "\e" is the ESC
# escape in ordinary Julia strings.
ps = map([(:ϕ°, :I, L"L", L"\phi^\circ"),
          (:f°, :E, L"\ell", L"E^\circ"),
          (:f°, :B, L"\ell", L"B^\circ")]) do (k1, k2, xlabel, title)
    plot(get_Cℓ(prob.Ωtrue[k1][k2]); label="true", xlabel, title)
    plot!(get_Cℓ(prob.Ωstart[k1][k2]); label="start", xlabel, title)
    plot!(get_Cℓ(last_sample[k1][k2]); label="last sample", xlabel, title)
end
plot(ps..., layout=(1,3), xscale=:log10, yscale=:log10, size=(1000,300), legend=:bottomleft)

In [None]:
# κ = ∇²ϕ°/2 maps for every draw from index 100 on (earlier draws skipped, presumably
# as burn-in). They are animated with a 3-frame weighted moving average ("motion blur").
samps = [∇²*from_vec(sample[1:end-2]).ϕ°/2 for sample in samples[100:end]]
anim = let weights = [0.5, 1, 0.5], n = length(samps)
    @animate for i in 1:n
        # Same frames as circshift(samps, i)[1:3] (entry k is samps[mod1(k - i, n)]),
        # without copying all of samps on every frame.
        frame = sum(samps[mod1(k - i, n)] * w for (k, w) in pairs(weights)) / sum(weights)
        plot(frame, clims=(-0.5, 0.5), c=:thermal, title="κ samples")
    end
end;

In [None]:
# Encode the κ animation as an MP4 at 25 frames per second.
mp4(anim, "kappa_samples.mp4"; fps=25)