# In [None]:
using Pkg
Pkg.activate("..")
using TGLFNN
using Plots
using Measurements

# Model file to load. `fname` is the part of the file name before the first '.',
# and is used to build the `model_dir` path.
filename = "sat2_em_ukstep_azf-1.bson"
fname = first(split(filename, "."))
model_dir = "../models/$fname"
# Create the directory (and any missing parents) only if it is not already there.
isdir(model_dir) || mkpath(model_dir)

# Load the model ensemble and keep each member's flux model. Variable names and
# the xm/xσ/ym/yσ arrays (presumably normalization means and standard deviations
# — verify) are taken from the first member only; this assumes every member
# shares them — TODO confirm.
ensemble = TGLFNN.loadmodelonce(filename)
fluxmodels = [member.fluxmodel for member in ensemble.models]
xnames, ynames, xm, xσ, ym, yσ = let ref = ensemble.models[1]
    (ref.xnames, ref.ynames, ref.xm, ref.xσ, ref.ym, ref.yσ)
end

# Seed every input with its `xm` value, keyed by the input name with any
# "_log10" suffix removed. Each value is wrapped in a one-element vector.
# The Dict is deliberately untyped (Dict{Any,Any}) so individual entries can be
# overwritten with plain scalars, e.g. by the commented-out overrides below.
xtest = Dict{Any,Any}(
    replace(name, "_log10" => "") => [xm[k]] for (k, name) in enumerate(xnames)
)

#=xtest["RLNS_1"] = 2.22
xtest["RLTS_1"] = 4.32
xtest["RLTS_2"] = 2.41
xtest["TAUS_2"] = 0.681
xtest["RMIN_LOC"] = 0.82
xtest["Q_LOC"] = 17.9
xtest["KAPPA_LOC"] = 2.48
xtest["DELTA_LOC"] = 0.297
xtest["BETAE"] = log10(0.00407)
xtest["ZEFF"] = 2.11=#

# Assemble the Float32 model input `x` in `xnames` order from `xtest`.
# Entries may be one-element vectors or plain scalars; indexing with [1] yields
# the scalar in both cases. A missing key aborts with an error.
x = zeros(Float32, length(xnames))
for (i, name) in pairs(xnames)
    key = replace(name, "_log10" => "")
    haskey(xtest, key) || error("Key '$key' not found in xtest")
    x[i] = Float32(xtest[key][1])
end

# Indices of the inputs whose name contains "RL".
# NOTE(review): `dx1` is not referenced in the visible code — confirm whether it
# is still needed.
dx1 = findall(name -> occursin("RL", name), xnames)

"""
    delog10(x, xnames)

Return a copy of `x` in which every entry whose name in `xnames` contains
`"_log10"` is replaced by `10^x[i]`; all other entries are copied unchanged.
`x` itself is never modified.

The copy is made with `collect`, so views and ranges are accepted too. The result
keeps the element type of `collect(x)`: for example, a `Vector{Float32}` input
gives a `Vector{Float32}` result, because `10.0^v` is converted back on assignment.

# Throws
- `DimensionMismatch` if `x` and `xnames` do not have the same indices.
"""
function delog10(x::AbstractVector, xnames::AbstractVector)
    y = collect(x)
    # Iterating over both arrays' indices rejects a length mismatch instead of
    # silently ignoring trailing entries.
    for i in eachindex(y, xnames)
        if contains(xnames[i], "_log10")
            y[i] = 10.0^y[i]
        end
    end
    return y
end

"""
    f(model, x1, xi; x0=x, x0names=xnames)

Evaluate `model` along a one-dimensional scan of input `xi`.

For each value `xx` in `x1`, a copy of the reference input `x0` is built with
entry `xi` replaced by `xx`. Every `"_log10"` entry (according to `x0names`) is
then mapped back to linear scale with `delog10`. The resulting vectors are stacked
as the columns of a matrix and passed to `model` in a single call with
`uncertain=true, warn_nn_train_bounds=false`. The return value is whatever
`model` returns for that call.

`x0` and `x0names` default to the script-level globals `x` and `xnames`, which
the function previously read implicitly. Pass them explicitly to scan around a
different reference point.
"""
function f(model, x1, xi; x0=x, x0names=xnames)
    # `vcat` promotes, so a Float64 `xx` widens the whole column as before.
    columns = [delog10(vcat(x0[1:xi-1], xx, x0[xi+1:end]), x0names) for xx in x1]
    # `reduce(hcat, ...)` avoids splatting an arbitrarily long list of columns.
    return model(reduce(hcat, columns); uncertain=true, warn_nn_train_bounds=false)
end

# Indices into `ynames` of the outputs to scan; add more to plot further outputs.
output_indices = [3]

# Reference prediction at the unperturbed input `x`. It depends on neither loop
# variable, so evaluate the ensemble once instead of once per subplot.
y = ensemble(delog10(x, xnames); warn_nn_train_bounds=false, uncertain=true)

# One subplot per (output, input) pair, collected in loop order.
plots = []

for yi in output_indices
    for xi in eachindex(xnames)
        # Scan input `xi` over ±3.3·xσ[xi] around its current value, in the same
        # (possibly log10) space in which `x` is stored. xσ is presumably the
        # per-input standard deviation — verify.
        xrange = LinRange(x[xi] - 3.3 * xσ[xi], x[xi] + 3.3 * xσ[xi], 10)
        tmp = f(ensemble, xrange, xi)[yi, :]

        # Red line: central value with an uncertainty ribbon.
        p = plot(xrange, Measurements.value.(tmp), ribbon=Measurements.uncertainty.(tmp),
                 xlabel=xnames[xi], ylabel="", label="", linewidth=2, color=:red)
        # Blue dot: reference prediction; dashed line: xm[xi]; solid line: zero.
        scatter!([x[xi]], [y[yi]], color=:blue, label="")
        vline!([xm[xi]], color=:black, linestyle=:dash, label="")
        hline!([0], color=:black, linestyle=:solid, label="")

        push!(plots, p)
    end
end

# Arrange all subplots in a grid with a fixed number of columns; the row count
# is rounded up so every subplot gets a cell.
n_cols = 6
n_rows = cld(length(plots), n_cols)
plot_grid = plot(plots..., layout=(n_rows, n_cols), size=(1000, 800))

display(plot_grid)
# Reuse `fname` (the model file name up to its first '.') rather than re-deriving it.
savefig(plot_grid, "./plot_spot_check_$(fname).pdf")