In [None]:
using Pkg
Pkg.activate("..")
#Pkg.instantiate()
#Pkg.resolve()
#Pkg.update()
using TGLFNN
using ONNXNaiveNASflux
using Flux

In [None]:
# Export a Julia-trained TGLFNN ensemble to ONNX format: one .onnx file per
# ensemble member, plus newline-separated text files holding the input/output
# names (xnames, ynames) and the normalization statistics (xm, xσ, ym, yσ).
filename = "sat2_em_d3d+mastu+nstx_azf-1.bson"
# Drop any directory part and only the final extension; split(filename, ".")[1]
# would truncate names that contain additional dots.
fname = first(splitext(basename(filename)))
model_dir = "../models/$fname"
mkpath(model_dir)  # no-op when the directory already exists

"""
    _write_lines(path, items) -> path

Write each element of `items` to `path` (overwriting it), one element per
newline-terminated line, using `print` formatting. Return `path`.

Floats are printed in Julia's shortest round-trip form, so
`parse(Float32, line)` recovers a written `Float32` exactly.
"""
function _write_lines(path, items)
    open(path, "w") do io
        for item in items
            println(io, item)
        end
    end
    return path
end

# Extract the Flux networks and convert them to Float32 for ONNX export
ensemble = TGLFNN.loadmodelonce(filename)
fluxmodels = [model.fluxmodel for model in ensemble.models]
fluxmodels_f32 = [f32(model) for model in fluxmodels]

# Names and normalization statistics are exported from the first member only,
# so warn if any other member disagrees (the export would then be lossy).
reference_model = ensemble.models[1]
if !all(ensemble.models) do m
        m.xnames == reference_model.xnames && m.ynames == reference_model.ynames &&
            m.xm == reference_model.xm && m.xσ == reference_model.xσ &&
            m.ym == reference_model.ym && m.yσ == reference_model.yσ
    end
    @warn "Ensemble members differ in names/normalization; exporting those of model 1 only"
end

xnames = reference_model.xnames
ynames = reference_model.ynames
xm = Float32.(reference_model.xm)
xσ = Float32.(reference_model.xσ)
ym = Float32.(reference_model.ym)
yσ = Float32.(reference_model.yσ)

# Save each ensemble member as <fname>_model_<i>.onnx, with 1-based i
for (i, model) in enumerate(fluxmodels_f32)
    ONNXNaiveNASflux.save(joinpath(model_dir, "$(fname)_model_$i.onnx"), model)
end

# Input/output names, one per line
xnames_path = _write_lines(joinpath(model_dir, "xnames.txt"), xnames)
ynames_path = _write_lines(joinpath(model_dir, "ynames.txt"), ynames)

# Normalization means and standard deviations for inputs (x) and outputs (y)
xm_path = _write_lines(joinpath(model_dir, "xm.txt"), xm)
xσ_path = _write_lines(joinpath(model_dir, "xsigma.txt"), xσ)
ym_path = _write_lines(joinpath(model_dir, "ym.txt"), ym)
yσ_path = _write_lines(joinpath(model_dir, "ysigma.txt"), yσ)

In [None]:
# Import ONNX format model to Julia
using ONNXNaiveNASflux
using TGLFNN

# Define the directory where ONNX models and normalization parameters are saved
filename = "sat3_em_mastu_azf-1.bson"
# Drop any directory part and only the final extension (split on "." would
# truncate names containing extra dots); this must match how the export
# step derived fname.
fname = first(splitext(basename(filename)))
model_dir = "../models/$fname"

# Parse a newline-separated file of numbers (one value per line) as Float32.
_read_float32s(path) = parse.(Float32, readlines(path))

# Load xnames and ynames, means and standard deviations from text files
xnames = readlines(joinpath(model_dir, "xnames.txt"))
ynames = readlines(joinpath(model_dir, "ynames.txt"))
xm = _read_float32s(joinpath(model_dir, "xm.txt"))
xσ = _read_float32s(joinpath(model_dir, "xsigma.txt"))
ym = _read_float32s(joinpath(model_dir, "ym.txt"))
yσ = _read_float32s(joinpath(model_dir, "ysigma.txt"))

# Bounds are mean ± 3.3σ per feature: column 1 is the lower bound, column 2
# the upper. The Float64 multiplier promotes the bound matrices to Float64.
bound_nsigma = 3.3
xbounds = [xm .- bound_nsigma .* xσ  xm .+ bound_nsigma .* xσ]
ybounds = [ym .- bound_nsigma .* yσ  ym .+ bound_nsigma .* yσ]

# Collect <fname>_model_1.onnx, _2, ... until the first missing index, rather
# than assuming a fixed ensemble size.
onnx_paths = collect(Iterators.takewhile(isfile,
    (joinpath(model_dir, "$(fname)_model_$i.onnx") for i in Iterators.countfrom(1))))
isempty(onnx_paths) && error("No ONNX models found in $model_dir")
loaded_fluxmodels = [ONNXNaiveNASflux.load(path) for path in onnx_paths]

# TODO: build an ensemble from the loaded ONNX models and the loaded
# normalization parameters [not yet implemented], e.g.
# ensemble_reloaded = TGLFNN.flux_to_tglfnn(loaded_fluxmodels, fname, xnames, ynames, xm, xσ, ym, yσ, xbounds, ybounds)

# TODO: save the reloaded ONNX-based ensemble [not yet implemented], e.g.
# TGLFNN.savemodel(ensemble_reloaded, "$model_dir/$fname")

In [None]:
# Test run comparing Julia and ONNX models
using Pkg
Pkg.activate("..")
#Pkg.instantiate()
#Pkg.resolve()
#Pkg.update()
using Test
using ONNXNaiveNASflux
using TGLFNN

# Define the directory where ONNX models and normalization parameters are saved
filename = "sat1geo_em_nstx_azf-1.bson"
# Drop any directory part and only the final extension (split on "." would
# truncate names containing extra dots); must match the export step.
fname = first(splitext(basename(filename)))
model_dir = "../models/$fname"

# A single sample input, keyed by raw (suffix-free) input name
xtest = Dict(
    "AS_2" => [0.6614277],
    "AS_3" => [0.056261905],
    "BETAE" => [0.00019011849],
    "DEBYE" => [0.055129163],
    "DELTA_LOC" => [0.042043645],
    "DRMAJDX_LOC" => [-0.22123617],
    "DZMAJDX_LOC" => [-0.48096865],
    "KAPPA_LOC" => [2.1243148],
    "P_PRIME_LOC" => [-0.00024872847],
    "Q_LOC" => [2.8812475],
    "Q_PRIME_LOC" => [-23.442652],
    "RLNS_1" => [2.32089],
    "RLNS_2" => [-1.6702774],
    "RLNS_3" => [3.135723],
    "RLTS_1" => [1.168423],
    "RLTS_2" => [2.404535],
    "RLTS_3" => [2.404535],
    "RMAJ_LOC" => [3.7501109],
    "RMIN_LOC" => [0.94476163],
    "S_DELTA_LOC" => [-0.25010625],
    "S_KAPPA_LOC" => [0.047604818],
    "S_ZETA_LOC" => [0.01721785],
    "TAUS_2" => [1.0998588],
    "TAUS_3" => [1.0998588],
    "VEXB_SHEAR" => [-0.055943374],
    "VPAR_1" => [0.23399605],
    "VPAR_SHEAR_1" => [0.6398101],
    "XNUE" => [0.6738171],
    "ZEFF" => [2.6878572],
    "ZETA_LOC" => [-0.14186314],
    "ZMAJ_LOC" => [-0.49191883],
    "ZS_3" => [6.0]
)

# Load Julia model
ensemble = TGLFNN.loadmodelonce(filename)

# Load one ONNX file per ensemble member so that index i always pairs the
# Julia network with its own exported copy (no hard-coded ensemble size).
loaded_fluxmodels = [
    ONNXNaiveNASflux.load(joinpath(model_dir, "$(fname)_model_$i.onnx"))
    for i in eachindex(ensemble.models)
]

# Build the input vector in the ensemble's xnames order. Names may carry a
# "_log10" suffix, which is stripped to find the raw entry in xtest.
# NOTE(review): values are fed as-is (no log10, no normalization). That is fine
# for this equivalence check because both networks receive the same x.
x = zeros(Float32, length(ensemble.xnames))
for (i, name) in enumerate(ensemble.xnames)
    clean_name = replace(name, "_log10" => "")
    haskey(xtest, clean_name) || error("Key '$clean_name' not found in xtest")
    x[i] = Float32(xtest[clean_name][1])  # each xtest entry is a 1-element array
end

# Compare every ensemble member with its ONNX counterpart. The outer testset
# reports all failures together instead of aborting at the first mismatch.
@testset "Julia vs ONNX ($fname)" begin
    @testset "ensemble member $i" for (i, model) in enumerate(ensemble.models)
        y = model.fluxmodel(x)
        y_onnx = loaded_fluxmodels[i](x)
        @test y ≈ y_onnx
    end
end