In [None]:
using Revise
using Pkg
TAMBO_PATH = "/Users/jlazar/research/TAMBO-MC/Tambo/"
Pkg.activate(TAMBO_PATH)

using Tambo
using CSV
using JLD2
using Plots
using PyCall
using Glob
using StaticArrays
using Distributions
include("../paperstyle.jl")

In [None]:
# Python `awkward` module handle; `loadcorsika` uses `ak.from_parquet` to read CORSIKA output.
ak = PyCall.pyimport("awkward")

In [None]:
# Simulation file; the handle stays open because later cells read `sim[...]` from it.
const sim = jldopen("/Users/jlazar/Downloads//WhitePaper_300k.jld2")
# Rebuild the stored simulation config, overriding only `geo_spline_path` so it points
# at the local spline file (NOTE(review): presumably the stored path is machine-specific — confirm).
const config = SimulationConfig(
    ;
    geo_spline_path="../resources/tambo_spline.jld2",
    filter(((k, v),) -> k!=:geo_spline_path, sim["config"])...
)
const geo = Tambo.Geometry(config)
const plane = Tambo.Plane(whitepaper_normal_vec, whitepaper_coord, geo)

# Window on `Tambo.plane_z` used when selecting which detector modules to keep.
const zmin = -1100units.m
const zmax = 1100units.m
# CORSIKA axes; they form the rows of the matrix that is inverted below.
# x and y are unit vectors and mutually orthogonal (numerically checked from the literals).
const ycorsika = SVector{3}([0.89192975455881607, 0.18563051261662877, -0.41231374670066206])
const xcorsika = SVector{3}([0, -0.91184756344828699, -0.41052895273466672])
const zcorsika = whitepaper_normal_vec.proj
# `loadcorsika` multiplies raw CORSIKA (x, y, z) by this matrix to rotate positions.
# NOTE(review): presumably this maps CORSIKA-frame coordinates into the Tambo frame — confirm.
const xyzcorsika = inv([
  xcorsika.x xcorsika.y xcorsika.z;
  ycorsika.x ycorsika.y ycorsika.z;
  zcorsika.x zcorsika.y zcorsika.z;
 ])

In [None]:
# Triangular detector array with spacing Δs over a strip of width ℓ, rotated by the
# plane azimuth. Only modules whose `Tambo.plane_z` value lies strictly between
# zmin and zmax are kept.
ℓ = 2000units.m
Δs = 100units.m

detection_modules = Tambo.make_trianglearray(
    -2000units.m, 3000units.m, -ℓ/2, ℓ/2, Δs;
    ϕ=whitepaper_normal_vec.ϕ,
)
mask = let mod_x = getfield.(detection_modules, :x), mod_y = getfield.(detection_modules, :y)
    zmin .< Tambo.plane_z.(mod_x, mod_y, Ref(plane)) .< zmax
end;

detection_modules = detection_modules[mask]

# Plotting grid for the terrain contours: 201 x-points and 200 y-points over -5..5 km.
xs = LinRange(-5, 5, 201) .* units.km
ys = LinRange(-5, 5, 200) .* units.km


In [None]:
# Column-style access, e.g. `events["time"]`: the field named `s` of every event.
# NOTE(review): this extends Base.getindex for Base's Vector and a Tambo type from
# notebook code (type piracy). It is kept because many cells rely on `evts["x"]` etc.
function Base.getindex(events::Vector{Tambo.CorsikaEvent}, s::String)
    field = Symbol(s)
    return getfield.(events, field)
end

In [None]:
"""
    loadcorsika(files::Vector{String})

Load the CORSIKA events of every file in `files`, in order, with the single-file
`loadcorsika(file::String)` method and concatenate them into one
`Vector{Tambo.CorsikaEvent}`. A file for which the single-file method returns no
events simply contributes nothing.
"""
function loadcorsika(files::Vector{String})
    events = Tambo.CorsikaEvent[]
    for file in files
        # Append in place. `[events; ...]` re-copied everything already loaded on every file.
        append!(events, loadcorsika(file))
    end
    return events
end

In [None]:
"""
    loadcorsika(file::String)

Read one CORSIKA parquet file (through awkward's `from_parquet`) into a
`Vector{Tambo.CorsikaEvent}`.

- Positions: columns `x`, `y`, `z`, scaled by `units.m` and then rotated with `xyzcorsika`.
- Times: column `time`, scaled by `units.second`.
- Weights: column `weight`, converted to floating point.
- Particle IDs: column `pdg`.
- Kinetic energies: column `kinetic_energy`, scaled by `units.GeV`.

Returns an empty vector when the file cannot be read or contains no rows. A read
failure is reported with `@warn` instead of being swallowed silently. Interrupts
are re-thrown.
"""
function loadcorsika(file::String)
    events = Tambo.CorsikaEvent[]
    table = try
        ak.from_parquet(file)
    catch err
        err isa InterruptException && rethrow()
        @warn "Skipping unreadable CORSIKA file" file exception = err
        return events
    end
    length(table) == 0 && return events

    raw_x = table["x"].to_numpy() .* units.m
    raw_y = table["y"].to_numpy() .* units.m
    raw_z = table["z"].to_numpy() .* units.m
    # Rotate every position with xyzcorsika.
    positions = [xyzcorsika * [px, py, pz] for (px, py, pz) in zip(raw_x, raw_y, raw_z)]
    xs = getindex.(positions, 1)
    ys = getindex.(positions, 2)
    zs = getindex.(positions, 3)

    ts = table["time"].to_numpy() .* units.second
    ws = table["weight"].to_numpy() .* 1.0  # promote integer weights to Float64
    ids = table["pdg"].to_numpy()
    es = table["kinetic_energy"].to_numpy() .* units.GeV
    # Positional argument order expected by Tambo.CorsikaEvent.
    for tup in zip(ids, es, xs, ys, zs, ts, ws)
        push!(events, Tambo.CorsikaEvent(tup...))
    end
    return events
end

## Event displays

In [None]:
# One directory of CORSIKA output per run number; FILES[i] lists all files of run i.
DATA_BASEDIR = "/Users/jlazar/research/TAMBO-MC/resources/test_data"
FILES = [glob("*", "$(DATA_BASEDIR)/$(runid)/") for runid in (10543, 1005, 10161, 10013, 10172)];

# Select one run. Its number is the name of the directory that holds its files.
idx = 1
run_files = FILES[idx]
run_number = parse(Int, split(first(run_files), "/")[end-1])
events = loadcorsika(run_files);

In [None]:
"""
    lerp(x, y, λ)

Linearly interpolate, element-wise, between `x` (at `λ = 0`) and `y` (at `λ = 1`).
All arguments broadcast, so scalars, vectors, or a vector of `λ` values all work.
"""
function lerp(x, y, λ)
    return @. x - (x - y) * λ
end

In [None]:
# Map of the loaded run's CORSIKA particles, drawn over the terrain contour and the
# detector array. The particles are the ones arriving before 25 μs, coloured by
# log10(time) and thinned to every `thin`-th one.
base_plt = plot(
    size=(500, 500),
    xlimits=(first(xs), last(xs))./units.km,
    ylimits=(first(ys), last(ys))./units.km,
    xlabel=L"x~\left[\mathrm{km}\right]",
    ylabel=L"y~\left[\mathrm{km}\right]",
    bottommargin=2mm
)

println(sum(mask))

contour!(
    base_plt,
    xs ./ units.km,
    ys ./ units.km,
    @. (geo(xs', ys) + geo.tambo_offset.z) / units.km;
    fill=true,
    color=palette(:lapaz),
    clims=(1.5, 5),
    colorbar_title=L"\mathrm{Altitude}~\left[\mathrm{km}\right]"
)

scatter!(
    base_plt,
    getfield.(detection_modules, :x) ./ units.km,
    getfield.(detection_modules, :y) ./ units.km,
    alpha=0.5,
    markersize=3,
    color="black",
    markerstrokewidth=0
)
# Prompt particles only: arrival time below 0.000025 s (25 μs).
evts = filter((x,)-> x.time < 0.000025units.second, events)

cg = cgrad(:roma, rev=false)

thin = 50

logt = log.(10, evts["time"]/units.second)
ltmin, ltmax = extrema(logt)
# Guard the normalisation: if every particle has the same time, avoid 0/0 = NaN colours.
lt_span = ltmax > ltmin ? ltmax - ltmin : one(ltmax)
cs = get.(Ref(cg), (logt .- ltmin) / lt_span)
# Plot the same `evts` the colours were computed from. This previously referenced
# `prompt_events`, which is undefined at this point in the notebook.
scatter!(
    base_plt,
    evts[begin:thin:end]["x"] / units.km,
    evts[begin:thin:end]["y"] / units.km,
    color=cs[begin:thin:end]
)

# savefig(base_plt, "../figures/corsika_contour.pdf")
display(base_plt)

In [None]:
# Hits between the loaded CORSIKA events and the kept detection modules. The next cell
# reads `hit.mod.idx`, `hit.event.weight` and `hit.event.time` from each element.
# NOTE(review): the "near" criterion lives inside Tambo.find_near_hits — check it there.
hits = Tambo.find_near_hits(events, detection_modules);

In [None]:
# Per-module tallies, keyed by module index:
#   d[idx]  -> summed hit weight (non-unit weights replaced by a Poisson draw)
#   od[idx] -> arrival times of every hit on that module
d = Dict{Int, Int}()
od = Dict{Int, Any}()
for h in hits
    modidx = h.mod.idx
    if !haskey(d, modidx)
        d[modidx] = 0
        od[modidx] = []
    end
    w = h.event.weight
    # A weight other than 1 is treated as a Poisson mean and sampled.
    w == 1 || (w = rand(Poisson(w)))
    d[modidx] += w
    push!(od[modidx], h.event.time)
end
    

In [None]:
# (module index, median hit time) pair for every module that recorded a hit.
plz = Any[(k, median(v)) for (k, v) in od]

In [None]:
# Trigger selection. A module is triggered when its tallied hit count d[idx] is at least 3.
# Triggered modules get marker size max(3, 3*sqrt(count)) and a colour from their
# median hit time, normalised to the [tmin, tmax] range over all hit modules.
trigger_mask = fill(false, size(detection_modules))
sizes = fill(1.0, size(detection_modules))
cs = fill(get(cgrad(:lightrainbow),0.0), size(detection_modules))

med_times = [x[2] for x in plz]
tmin, tmax = extrema(med_times)
Δt = tmax - tmin
# Built once here instead of on every loop iteration.
time_grad = cgrad(:lightrainbow, rev=true)

for (idx, detmod) in enumerate(detection_modules)
    nhit = get(d, detmod.idx, 0)  # modules with no hits count as 0 and are skipped
    nhit >= 3 || continue
    trigger_mask[idx] = true
    sizes[idx] = max(3, 3*sqrt(nhit))
    # If every hit module shares one median time, Δt is 0; use colour 0 instead of NaN.
    frac = iszero(Δt) ? zero(Δt) : (median(od[detmod.idx]) - tmin) / Δt
    cs[idx] = get(time_grad, frac)
end

In [None]:
# Modules with at least 3 tallied hits, shown as module index => count.
filter(kv -> kv.second >= 3, d)

In [None]:
# Detector map after the trigger. Modules in `trigger_mask` are drawn in yellow with
# marker sizes from `sizes`; all other modules are small and black.
base_plt = plot(
  size=(500, 500),
  xlimits=(first(xs), last(xs))./units.km,
  ylimits=(first(ys), last(ys))./units.km
)


contour!(
    base_plt,
    xs ./ units.km,
    ys ./ units.km,
    @. (geo(xs', ys) + geo.tambo_offset.z) / units.km;
    fill=true,
    color=palette(:lapaz),
    clims=(1.5, 5)
)

scatter!(
    base_plt,
    getfield.(detection_modules[.~trigger_mask], :x) ./ units.km,
    getfield.(detection_modules[.~trigger_mask], :y) ./ units.km,
    alpha=0.5,
    markersize=3,
    color=:black,
    markerstrokewidth=0
)

scatter!(
    base_plt,
    getfield.(detection_modules[trigger_mask], :x) ./ units.km,
    getfield.(detection_modules[trigger_mask], :y) ./ units.km,
    alpha=0.5,
    markersize=sizes[trigger_mask],
    markercolor=:yellow,
    markerstrokewidth=0
)

# Optional overlay: the straight path from the interaction vertex to the decay vertex.
# No cell shown here defines `injected_event` / `propped_event`. The overlay is drawn
# only when they exist, so the figure is still saved instead of failing with an UndefVarError.
if @isdefined(injected_event) && @isdefined(propped_event)
    start_xy = [injected_event.final_state.position.x, injected_event.final_state.position.y]
    end_xy = [propped_event.propped_state.position.x, propped_event.propped_state.position.y]
    path = [lerp(start_xy, end_xy, λ) for λ in 0:0.01:1]
    heat = cgrad(:heat, rev=true)

    plot!(
        base_plt,
        getindex.(path, 1) / units.km,
        getindex.(path, 2) / units.km,
        color=get.(Ref(heat), 0:0.01:1),
    )

    scatter!(
        base_plt,
        [path[1][1]] / units.km,
        [path[1][2]] / units.km,
        color=get(heat, 0),
        label="Interaction vertex"
    )

    scatter!(
        base_plt,
        [path[end][1]] / units.km,
        [path[end][2]] / units.km,
        color=get(heat, 1),
        label="Decay vertex"
    )
else
    @warn "injected_event / propped_event not defined; skipping vertex path overlay"
end

savefig(base_plt, "../figures/triggered_modules_contour.pdf")

display(base_plt)

In [None]:
# Fraction of CORSIKA particles above a kinetic-energy cut, for 31 log-spaced cuts
# from 1 GeV to 1 TeV. `ecuts` was never defined, and `ns` only appeared in a later
# cell, so both are computed here to make the cell self-contained.
ecuts = 10 .^ LinRange(0, 3, 31) .* units.GeV
ns = [count(e -> e.kinetic_energy > ecut, events) for ecut in ecuts]
plot(ecuts / units.GeV, ns ./ maximum(ns), xscale=:log10, yscale=:log10)

## Array size

In [None]:
# Pre-computed array-size event files for a single module spacing Δs (in metres).
Δs = 100.0units.m
files = glob(string("test_events_*_", Int(Δs / units.m), ".0.npy"), "/Users/jlazar/Downloads/array_size/")

In [None]:
# Expected event rate vs. number of modules, one curve per module spacing Δs.
# Each file `test_events_<ℓ>_<Δs>.0.npy` stores, in its first row, indices into
# sim["injected_events"]. Those events are weighted with `oneweight` under the power
# law `hese_pl`, and the rate is multiplied by 10^7.5 s ≈ 1 yr to match the yr⁻¹ axis.
np = pyimport("numpy")  # `np.load` is used below but numpy was never imported

plt = plot(
    xlabel=L"N_{\mathrm{mod}}",
    ylabel=L"\Gamma_{\mathrm{evnt}}~\left[\mathrm{yr}^{-1}\right]",
    bottommargin=1mm,
    leftmargin=1mm
)
Δses = [50, 100.0, 150] .* units.m

# Loop-invariant flux model and injector, built once instead of once per file.
hese_γ = 2.37
hese_norm = 6.37e-18 / 3 / units.GeV / units.cm^2 / units.second * (100units.TeV)^hese_γ
hese_pl = Tambo.PowerLaw(hese_γ, 100units.GeV, 1e9units.GeV, hese_norm)
injector = Tambo.Injector(config)

# NOTE(review): in IJulia/REPL soft scope, assigning `files`, `ℓ`, `detection_modules`,
# `mask` and `events` inside this loop overwrites the existing globals of those names.
# Later cells then see the last iteration's values — confirm this is intended.
for Δs in Δses
    println()
    println(Δs)
    files = glob("test_events_*_$(Int(Δs / units.m)).0.npy", "/Users/jlazar/Downloads/array_size/")
    res = []

    for file in files
        ℓ = parse(Float64, split(file, "_")[end-1]) * units.m

        event_numbers = Int.(np.load(file)[1,:])

        detection_modules = Tambo.make_trianglearray(-2000units.m, 3000units.m, -ℓ/2, ℓ/2, Δs, ϕ=whitepaper_normal_vec.ϕ)
        mask = zmin .< Tambo.plane_z.(getfield.(detection_modules, :x), getfield.(detection_modules, :y), Ref(plane)) .< zmax;
        detection_modules = detection_modules[mask]

        nmodules = sum(mask)
        println((ℓ/units.m, nmodules))

        events = sim["injected_events"][event_numbers]
        fluxes = hese_pl.(getfield.(getfield.(events, :initial_state), :energy))
        # NOTE(review): `injector.xs` is passed for two consecutive arguments, and the
        # 1/1e5 normalisation differs from the 1/3e5 used in the follow-up cell. Confirm both.
        weights = Tambo.oneweight.(events, Ref(injector.xs), Ref(injector.xs), Ref(injector.powerlaw), Ref(injector.anglesampler), Ref(injector.injectionvolume), Ref(geo)) / 1e5

        nevents = sum(fluxes .* weights) * 10^7.5 * units.second
        push!(res, (nmodules, nevents))
    end
    sort!(res)
    plot!(plt, [r[1] for r in res], [r[2] for r in res], label="$(Int(Δs / units.m)) m")
end
savefig(plt, "../figures/event_rate_vs_nmod.pdf")
display(plt)


In [None]:
# Event rate for whatever `events`, `hese_pl` and `injector` are currently in scope.
# NOTE(review): these come from the array-size cell. Confirm that cell has run and that
# `events` holds injected events (with `initial_state`), not CORSIKA particles.
# NOTE(review): weights are divided by 3e5 here but by 1e5 in the array-size scan;
# confirm which equals the number of generated events.
energies = getfield.(getfield.(events, :initial_state), :energy)
fluxes = hese_pl.(energies)
weights = Tambo.oneweight.(
    events,
    (injector.xs,),
    (injector.xs,),
    (injector.powerlaw,),
    (injector.anglesampler,),
    (injector.injectionvolume,),
    (geo,),
) / 3e5

nevents = sum(fluxes .* weights) * 10^7.5 * units.second

In [None]:
# The prompt-particle map (arrival time < 25 μs), repeated for increasing kinetic-energy
# cuts. Each cut is applied to the full prompt sample `evts`, which is never reassigned.
# Previously `evts` was overwritten inside the loop, so it was filtered cumulatively,
# and in IJulia soft scope the global was left at the last cut.
evts = filter((x,)-> x.time < 0.000025units.second, events)
cg = cgrad(:roma, rev=false)
thin = 50

for ecut in [1, 30, 100, 300] .* units.GeV

    base_plt = plot(
        size=(500, 500),
        xlimits=(first(xs), last(xs))./units.km,
        ylimits=(first(ys), last(ys))./units.km,
        xlabel=L"x~\left[\mathrm{km}\right]",
        ylabel=L"y~\left[\mathrm{km}\right]",
        bottommargin=2mm
    )

    cut_evts = filter((e,)->e.kinetic_energy > ecut, evts)
    contour!(
        base_plt,
        xs ./ units.km,
        ys ./ units.km,
        @. (geo(xs', ys) + geo.tambo_offset.z) / units.km;
        fill=true,
        color=palette(:lapaz),
        clims=(1.5, 5),
        colorbar_title=L"\mathrm{Altitude}~\left[\mathrm{km}\right]"
    )

    scatter!(
        base_plt,
        getfield.(detection_modules, :x) ./ units.km,
        getfield.(detection_modules, :y) ./ units.km,
        alpha=0.5,
        markersize=3,
        color="black",
        markerstrokewidth=0
    )

    # A high cut can leave no particles; extrema would throw on an empty vector.
    if !isempty(cut_evts)
        logt = log.(10, cut_evts["time"]/units.second)
        ltmin, ltmax = extrema(logt)
        lt_span = ltmax > ltmin ? ltmax - ltmin : one(ltmax)
        cs = get.(Ref(cg), (logt .- ltmin) / lt_span)
        scatter!(
            base_plt,
            cut_evts[begin:thin:end]["x"] / units.km,
            cut_evts[begin:thin:end]["y"] / units.km,
            color=cs[begin:thin:end]
        )
    end

    display(base_plt)
end

In [None]:

println(sum(mask))

# Optional vertex-path overlay (interaction vertex -> decay vertex). It does not depend
# on the energy cut, so it is computed once. No cell shown here defines
# `injected_event` / `propped_event`; previously the first iteration failed with an
# UndefVarError, so the overlay is now drawn only when both exist.
vertex_path = nothing
if @isdefined(injected_event) && @isdefined(propped_event)
    start_xy = [injected_event.final_state.position.x, injected_event.final_state.position.y]
    end_xy = [propped_event.propped_state.position.x, propped_event.propped_state.position.y]
    vertex_path = [lerp(start_xy, end_xy, λ) for λ in 0:0.01:1]
else
    @warn "injected_event / propped_event not defined; skipping vertex path overlay"
end

# One map per kinetic-energy cut. Prompt particles (< 25 μs) are coloured by log10(time).
for ecut in [1,3,10,30,100]*units.GeV
    base_plt = plot(size=(500, 500), xlimits=(first(xs), last(xs))./units.km, ylimits=(first(ys), last(ys))./units.km)

    cut_events = filter((e,)-> e.kinetic_energy>ecut, events)
    println(length(cut_events))
    contour!(
        base_plt,
        xs ./ units.km,
        ys ./ units.km,
        @. (geo(xs', ys) + geo.tambo_offset.z) / units.km;
        fill=true,
        color=palette(:lapaz),
        clims=(1.5, 5)
    )

    scatter!(
        base_plt,
        getfield.(detection_modules, :x) ./ units.km,
        getfield.(detection_modules, :y) ./ units.km,
        alpha=0.5,
        markersize=3,
        color="black",
        markerstrokewidth=0
    )
    prompt_events = filter((x,)-> x.time < 0.000025units.second, cut_events)

    cg = cgrad(:lightrainbow, rev=true)

    thin = 50

    # A high cut can leave no prompt particles; extrema would throw on an empty vector.
    if !isempty(prompt_events)
        logt = log.(10, prompt_events["time"]/units.second)
        ltmin, ltmax = extrema(logt)
        lt_span = ltmax > ltmin ? ltmax - ltmin : one(ltmax)
        cs = get.(Ref(cg), (logt .- ltmin) / lt_span)
        scatter!(
            base_plt,
            prompt_events[begin:thin:end]["x"] / units.km,
            prompt_events[begin:thin:end]["y"] / units.km,
            color=cs[begin:thin:end],
        )
    end

    if !isnothing(vertex_path)
        heat = cgrad(:heat, rev=true)

        plot!(
            base_plt,
            getindex.(vertex_path, 1) / units.km,
            getindex.(vertex_path, 2) / units.km,
            color=get.(Ref(heat), 0:0.01:1),
        )

        scatter!(
            base_plt,
            [vertex_path[1][1]] / units.km,
            [vertex_path[1][2]] / units.km,
            color=get(heat, 0),
            label="Interaction vertex"
        )

        scatter!(
            base_plt,
            [vertex_path[end][1]] / units.km,
            [vertex_path[end][2]] / units.km,
            color=get(heat, 1),
            label="Decay vertex"
        )
    end

    display(base_plt)
end

In [None]:
# Number of CORSIKA particles above each kinetic-energy cut, for 31 log-spaced cuts
# from 1 GeV to 1 TeV. `count` avoids building a filtered copy per cut and yields a
# concretely typed Vector{Int} instead of an untyped `[]`.
ns = [count(e -> e.kinetic_energy > ecut, events) for ecut in 10 .^ LinRange(0, 3, 31) * units.GeV]

In [None]:
# Surviving-particle fraction vs. energy cut in GeV, on log-log axes.
scatter(10 .^ LinRange(0, 3, 31), ns ./ maximum(ns); xscale=:log10, yscale=:log10)

In [None]:
# Count the first 99875 propagated positions for which `Tambo.inside` is false.
let positions = sim["proposal_events"][1:99875]["propped_state"]["position"]
    inside = Tambo.inside.(positions, Ref(geo))
    sum(.~inside)
end