In [1]:
using Distributed

In [2]:
# Launch 4 worker processes, each started with the same project environment
# that is active in this session.
addprocs(4; exeflags = "--project=$(Base.active_project())")

4-element Vector{Int64}:
 2
 3
 4
 5

In [3]:
@everywhere begin

    using PyPlot
    using Random
    using Setfield
    using Statistics
    using JLD2
    using FileIO
    using Measurements: value
    using CUDA
    using LinearAlgebra
    using ComponentArrays

    using Revise

    using CMBLensing
    using PtsrcLens
    using MPMEstimate
    
end

In [4]:
# Give each process a GPU. The Info log below shows the resulting
# process -> device map: 5 processes over 4 devices, so device 0 is shared
# between myid 1 and myid 5.
CMBLensing.assign_GPU_workers()

┌ Info: Processes:
│  (myid = 1, host = cgpu18, device = CuDevice(0): Tesla V100-SXM2-16GB a80f8d10))
│  (myid = 2, host = cgpu18, device = CuDevice(1): Tesla V100-SXM2-16GB 75d9330e))
│  (myid = 3, host = cgpu18, device = CuDevice(2): Tesla V100-SXM2-16GB 7fcea50c))
│  (myid = 4, host = cgpu18, device = CuDevice(3): Tesla V100-SXM2-16GB 904d9fef))
│  (myid = 5, host = cgpu18, device = CuDevice(0): Tesla V100-SXM2-16GB a80f8d10))
└ @ CMBLensing /global/u1/m/marius/work/ptsrclens/dev/CMBLensing/src/util_parallel.jl:104


# Load

In [5]:
# Load the Sehgal cutout maps. The variables are listed explicitly so it is
# clear which globals this cell defines. A missing variable raises an error
# instead of silently leaving a stale global from an earlier run in place.
@load "data/sehgal_maps_h5/cutouts.jld2" ϕs κs gs_ir gs_radio Ms_radio

5-element Vector{Symbol}:
 :ϕs
 :κs
 :gs_ir
 :gs_radio
 :Ms_radio

In [6]:
# Pull the `fg_noise_radio` and `fg_noise_ir` entries out of the result of
# `get_foreground_noise`. `fg_noise_radio` is indexed by
# (survey, freq, fluxcut) further down.
@unpack fg_noise_radio, fg_noise_ir = get_foreground_noise(; Ms_radio, gs_radio, gs_ir);

# Configuration

In [16]:
# ϕ bandpower bin edges: ℓ = 2, then steps of 50 from 100 to 500, then 30
# log-spaced edges from 502 to 5000 (rounded to integers).
ℓedges = vcat(2, 100:50:500, round.(Int, 10 .^ range(log10(502), log10(5000), length=30)));

polfrac_scale = 1      # scale factor applied to the radio foreground signal and its noise level
freq          = 148    # one of (90, 148)
survey        = :deep  # one of (:deep, :wide)
fluxcut       = 5      # one of (2, 5, 10, Inf)
sim           = 1;     # simulation index, 1...40

# Run

In [17]:
# Build the simulated dataset for the configured survey/freq/fluxcut/sim.

# Noise keyword arguments for this (survey, freq); they are splatted into
# `load_sim_dataset` below.
noise_kwargs = noises[survey,freq];
# Radio foreground noise level for this configuration. It is a Measurements
# value; only its central `value` is used below.
fg_noise = fg_noise_radio[survey,freq,fluxcut];
# Radio map for this configuration and simulation, moved to the GPU
# (presumably a point-source mask — TODO confirm).
M = cu(Ms_radio[survey,freq,fluxcut][sim])
# Number of Aϕ bandpowers: one per interval between consecutive ℓedges.
nbinsϕ = length(ℓedges)-1

Cℓ = get_fiducial_Cℓ(ϕs);

# Polarization-only (:P) dataset with CuArray (GPU) storage, Nside = 300,
# θpix = 2 (arcmin per CMBLensing convention — confirm), band-limited by LowPass(5000).
@unpack ds,proj,f,ϕ = load_sim_dataset(;
    Cℓ = Cℓ,
    θpix = 2,
    storage = CuArray,    
    Nside = 300,
    pol = :P,
    bandpass_mask = LowPass(5000),
    noise_kwargs...
);

# Dataset operator applied to the foreground when building `ds_withfg`
# (presumably the beam — TODO confirm).
@unpack B = ds

# Replace the ϕ prior with a version binned on `ℓedges` and parameterized by
# `:Aϕ` (presumably one amplitude per bin, matching θ₀).
ds.Cϕ = Cℓ_to_Cov(:I, proj, (Cℓ.total.ϕϕ, ℓedges, :Aϕ));
# NOTE(review): sets ds.G to the scalar 1, presumably disabling the default
# Aϕ normalization operator — confirm.
ds.G = 1


# White-noise spectra at the (scaled) foreground level, with no beam and no 1/f
# knee. The /√2 presumably converts a temperature level to a polarization
# level — confirm. Cg is not used in the cells visible here.
Cℓg = noiseCℓs(μKarcminT=polfrac_scale*value(fg_noise)/√2, beamFWHM=0, ℓknee=0)
Cg = Cℓ_to_Cov(:P, proj, Cℓg.EE, Cℓg.BB);

In [18]:
# Starting point for MPM: every Aϕ bandpower amplitude set to 1 (Float32),
# one entry per ℓ bin.
θ₀ = ComponentArray(Aϕ = fill(1.0f0, nbinsϕ))

ComponentVector{Float32}(Aϕ = Float32[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0  …  1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])

In [19]:
# Re-simulate the data using the `sim`-th ϕ map (on the GPU), seeded by `sim`.
ds_nofg = resimulate(ds; ϕ = cu(ϕs[sim]), seed = sim).ds;

In [20]:
# Copy of `ds_nofg` whose data `d` additionally contains the radio foreground
# for this freq/sim, scaled by `polfrac_scale` and passed through `B` and `M`.
ds_withfg = @set(ds_nofg.d += polfrac_scale * B * M * cu(gs_radio[freq][sim]));

In [21]:
# Options handed to `mpm` as `MAP_joint_kwargs` (presumably forwarded to its
# internal MAP_joint solves — TODO confirm against MPMEstimate).
MAP_joint_kwargs = (
    progress = false,
    αtol = 1e-4,
    nsteps = 20,
    nburnin_update_hessian = Inf,
);

In [22]:
# Seeded RNG passed to the `mpm` calls below.
rng = Random.MersenneTwister(1)

MersenneTwister(1)

In [23]:
"""
    regularize(θ, σθ)

Return an array like `θ` in which every entry equals the inverse-variance
weighted mean of `θ`, using per-entry weights `1 / σθ^2`.
"""
function regularize(θ, σθ)
    weighted_total = sum(θ ./ σθ .^ 2)
    total_weight = sum(1 ./ σθ .^ 2)
    return fill!(similar(θ), weighted_total / total_weight)
end

regularize (generic function with 1 method)

In [None]:
# MPM estimate of θ from the foreground-free data.
θmpm, σθ, history = mpm(
    ds_nofg, θ₀;
    MAP_joint_kwargs, regularize,
    nsteps = 3, α = 1, nsims = 10,
    rng, progress = true, map = pmap,
);

In [None]:
# Same MPM run, but on the data that includes the radio foreground.
θmpm′, σθ′, history′ = mpm(
    ds_withfg, θ₀;
    MAP_joint_kwargs, regularize,
    nsteps = 3, α = 1, nsims = 10,
    rng, progress = true, map = pmap,
);

In [None]:
# Overlay the unregularized θ from every MPM step after the first
# (with-foreground run), labelled by step number.
for (k, entry) in enumerate(history′[2:end])
    plot(entry.θunreg; label = k)
end
# ylim(-1,0.1)
legend()

In [None]:
# For each MPM step after the first, plot the difference in unregularized θ
# between the no-foreground and with-foreground runs.
# The previous `σθ = 1 ./ std(...)` line was removed. Its value was never used,
# and under IJulia's soft global scope it overwrote the global `σθ` returned
# by the first `mpm` call.
for (i, (h, h′)) in enumerate(zip(history[2:end], history′[2:end]))
    plot(h.θunreg - h′.θunreg; label = i)
end
# ylim(-1,0.1)
legend()

# Loop

In [24]:
# For each of the 40 simulations, run MPM with and without the radio foreground
# and save both histories to one JLD2 file per sim. A failing sim is logged with
# its backtrace and skipped. An interrupt, whether raised locally or on a worker,
# aborts the whole loop.
for sim = 1:40

    try

        M = cu(Ms_radio[survey,freq,fluxcut][sim])

        ds_nofg = resimulate(ds, ϕ=cu(ϕs[sim]), seed=sim).ds;
        ds_withfg = @set(ds_nofg.d += polfrac_scale*B*M*cu(gs_radio[freq][sim]));

        # NOTE(review): both `mpm` calls share this `rng`. If `mpm` advances it,
        # the with-foreground run draws different sims than the no-foreground
        # run. Confirm this is intended.
        rng = MersenneTwister(sim)

        θmpm,  σθ,  history  = mpm(ds_nofg  , θ₀; MAP_joint_kwargs, regularize, nsteps=3, α=1, nsims=10, rng, progress=true, map=pmap);
        θmpm′, σθ′, history′ = mpm(ds_withfg, θ₀; MAP_joint_kwargs, regularize, nsteps=3, α=1, nsims=10, rng, progress=true, map=pmap);

        outfile = "data/mpm_steps/freq$(freq)_$(survey)_fluxcut$(fluxcut)/sim$(sim).jld2"
        # Create the output directory first. Otherwise a missing folder makes
        # every iteration fail in `save` and only show up as a logged warning.
        mkpath(dirname(outfile))
        save(outfile, "history_nofg", history, "history_withfg", history′)
    catch err
        # A worker-side interrupt arrives wrapped as RemoteException(captured.ex).
        if err isa InterruptException ||
           (err isa RemoteException && err.captured.ex isa InterruptException)
            rethrow()
        else
            @warn "MPM failed for sim=$sim; skipping" exception=(err, catch_backtrace())
        end
    end

end

[32mMPM: 100%|██████████████████████████████████████████████| Time: 0:01:34[39m
[32mMPM: 100%|██████████████████████████████████████████████| Time: 0:01:30[39m
[32mMPM: 100%|██████████████████████████████████████████████| Time: 0:01:28[39m
[32mMPM: 100%|██████████████████████████████████████████████| Time: 0:01:28[39m
[32mMPM: 100%|██████████████████████████████████████████████| Time: 0:01:29[39m
[32mMPM: 100%|██████████████████████████████████████████████| Time: 0:01:27[39m
[32mMPM: 100%|██████████████████████████████████████████████| Time: 0:01:29[39m
[32mMPM: 100%|██████████████████████████████████████████████| Time: 0:01:28[39m
[32mMPM: 100%|██████████████████████████████████████████████| Time: 0:01:29[39m
[32mMPM: 100%|██████████████████████████████████████████████| Time: 0:01:29[39m
[32mMPM: 100%|██████████████████████████████████████████████| Time: 0:01:28[39m
[32mMPM: 100%|██████████████████████████████████████████████| Time: 0:01:27[39m
[32mMPM: 100%|█