In [None]:
using EDM4hep
using EDM4hep.RootIO
using LorentzVectorHEP
using JSON
using ONNXRunTime
using PhysicalConstants
using StructArrays
using JetReconstruction
using JetTaggingFCC

# Simple Jet Flavour Tagging Example

This notebook demonstrates how to:
1. Load EDM4hep event data
2. Reconstruct jets using JetReconstruction
3. Extract features for flavour tagging
4. Run ONNX neural network inference
5. Get flavour probabilities for each jet

In [None]:
# Locate the ONNX network and its accompanying JSON configuration file.
model_dir = "data/wc_pt_7classes_12_04_2023"
onnx_path = joinpath(model_dir, "fccee_flavtagging_edm4hep_wc_v1.onnx")
json_path = joinpath(model_dir, "fccee_flavtagging_edm4hep_wc_v1.json")

# Fail early, naming the missing file, before the model loader is invoked.
isfile(onnx_path) || error("ONNX model not found at: $onnx_path")
isfile(json_path) || error("JSON config not found at: $json_path")

# `model` and `config` are reused by the feature-extraction and inference cells.
println("Loading flavour tagging model...")
model, config = JetTaggingFCC.setup_onnx_runtime(onnx_path, json_path)

# List the output class names stored under "output_names" in the configuration.
println("\nThe model predicts these flavour classes:")
foreach(class_name -> println("  - $class_name"), config["output_names"])

In [None]:
# EDM4hep input: a ROOT file opened through EDM4hep.RootIO.
edm4hep_path = "data/events_080263084.root"
isfile(edm4hep_path) || error("EDM4hep data file not found at: $edm4hep_path")

println("\nLoading EDM4hep data...")
reader = RootIO.Reader(edm4hep_path)
events = RootIO.get(reader, "events")
println("Loaded ", length(events), " events")

# Only a single event is processed in this example (event #12, as in the script).
event_id = 12
println("\nProcessing event #", event_id)
evt = events[event_id]

# Reconstructed particles and the "EFlowTrack_1" collection used as tracks.
recps = RootIO.get(reader, evt, "ReconstructedParticles")
tracks = RootIO.get(reader, evt, "EFlowTrack_1")

# MC particles plus the MC <-> reco associations, needed for vertex information.
mcps = RootIO.get(reader, evt, "Particle")
MCRecoLinks = RootIO.get(reader, evt, "MCRecoAssociations")

# Build one MC production vertex per reconstructed particle.
#
# The vector is pre-filled with (0,0,0,0) rather than allocated with `undef`
# and patched afterwards: for element types that are stored inline (plain-bits
# structs), `isassigned` reports `true` for every in-bounds slot of an `undef`
# array, so an `isassigned`-based fill cannot detect particles without a link
# and would leave uninitialised memory in those slots.
mc_vertices = fill(LorentzVector(0.0f0, 0.0f0, 0.0f0, 0.0f0), length(recps))

# Map reco index -> MC index. The stored indices are zero-based (hence the
# `+ 1` when indexing below). If several links share a reco index, the last
# one iterated wins.
reco_to_mc = Dict(link.rec_idx.index => link.sim_idx.index for link in MCRecoLinks)
for (rec_idx, mc_idx) in reco_to_mc
    # Skip links pointing outside either collection, including negative
    # (unset) indices, which would otherwise index slot 0 and throw.
    if 0 <= rec_idx < length(recps) && 0 <= mc_idx < length(mcps)
        mc_particle = mcps[mc_idx + 1]
        # NOTE(review): the components are passed positionally as
        # (x, y, z, time); confirm this matches the field order of
        # `LorentzVector` and what JetTaggingFCC.extract_features expects.
        mc_vertices[rec_idx + 1] = LorentzVector(Float32(mc_particle.vertex.x),
                                                 Float32(mc_particle.vertex.y),
                                                 Float32(mc_particle.vertex.z),
                                                 Float32(mc_particle.time))
    end
end

# Remaining inputs handed to JetTaggingFCC.extract_features further down.
# `magFieldBz` and `EFlowTrack_L` are read with `register = false`; only the
# first entry of the magnetic-field collection is kept.
bz = RootIO.get(reader, evt, "magFieldBz", register = false)[1]
trackdata = RootIO.get(reader, evt, "EFlowTrack")
trackerhits = RootIO.get(reader, evt, "TrackerHits")
gammadata = RootIO.get(reader, evt, "EFlowPhoton")
nhdata = RootIO.get(reader, evt, "EFlowNeutralHadron")
calohits = RootIO.get(reader, evt, "CalorimeterHits")
dNdx = RootIO.get(reader, evt, "EFlowTrack_2")
track_L = RootIO.get(reader, evt, "EFlowTrack_L", register = false)

# Short summary of what was loaded for this event.
println("  - ", length(recps), " reconstructed particles")
println("  - ", length(tracks), " tracks")
println("  - Magnetic field Bz = ", bz, " T")

# Report a primary vertex: the first MC vertex with any non-zero spatial
# component, or (0,0,0,0) when there is none. In the code visible here the
# value is only printed.
#
# `findfirst` is used instead of assigning to `primary_vertex` from inside a
# top-level `for` loop: that assignment only reaches the global under the
# interactive soft-scope rules (REPL/notebook). Run as a plain script it would
# create a loop-local variable and the printed vertex would always be zero.
primary_vertex_idx = findfirst(v -> v.x != 0.0 || v.y != 0.0 || v.z != 0.0, mc_vertices)
primary_vertex = isnothing(primary_vertex_idx) ?
                 LorentzVector(0.0f0, 0.0f0, 0.0f0, 0.0f0) :
                 mc_vertices[primary_vertex_idx]
println("  - Primary vertex: ($(round(primary_vertex.x, digits=3)), $(round(primary_vertex.y, digits=3)), $(round(primary_vertex.z, digits=3))) mm")

In [None]:
# Cluster the reconstructed particles with the EEKt algorithm (p = 1.0, R = 2.0).
println("\nReconstructing jets...")
cs = jet_reconstruct(recps; algorithm = JetAlgorithm.EEKt, p = 1.0, R = 2.0)

# Ask the clustering sequence for exactly two exclusive jets, returned as EEJet.
jets = exclusive_jets(cs; T = EEJet, njets = 2)
println("Found ", length(jets), " jets")

# Per-jet kinematic summary, rounded for display only.
for (jet_number, jet) in enumerate(jets)
    energy_gev = round(jet.E, digits = 2)
    pt_gev = round(JetReconstruction.pt(jet), digits = 2)
    eta_value = round(JetReconstruction.eta(jet), digits = 3)
    phi_value = round(JetReconstruction.phi(jet), digits = 3)
    mass_gev = round(JetReconstruction.mass(jet), digits = 2)

    println("\nJet ", jet_number, ":")
    println("  - Energy: ", energy_gev, " GeV")
    println("  - Pt: ", pt_gev, " GeV")
    println("  - Eta: ", eta_value)
    println("  - Phi: ", phi_value)
    println("  - Mass: ", mass_gev, " GeV")
end

In [None]:
# Look up, for every jet, the indices of its constituents in the clustering sequence.
println("\nExtracting jet constituents...")
constituent_indices = map(jet -> constituent_indexes(jet, cs), jets)

# Turn those index lists into per-jet groups of reconstructed particles.
jet_constituents = JetTaggingFCC.build_constituents_cluster(recps, constituent_indices)

foreach(enumerate(jet_constituents)) do (jet_number, constituents)
    println("  - Jet $jet_number has $(length(constituents)) constituents")
end

In [None]:
# Step 1: compute the per-jet tagging features from the event collections
# loaded earlier.
println("\nExtracting features for flavour tagging...")
feature_data = JetTaggingFCC.extract_features(
    jets, jet_constituents, tracks, bz, track_L, config,
    trackdata, trackerhits, gammadata, nhdata, calohits, dNdx, mc_vertices,
)

# Step 2: build the input tensors.
# NOTE(review): `input_tensors` is not consumed by any later code visible in
# this notebook; `get_weights` below is handed `feature_data` directly.
println("Preparing input tensors...")
input_tensors = JetTaggingFCC.prepare_input_tensor(
    jet_constituents, jets, config, feature_data,
)

# Step 3: run the network. `weights` is read back per class in the results cell.
println("Running neural network inference...")
weights = JetTaggingFCC.get_weights(
    0,  # Thread slot
    feature_data, jets, jet_constituents, config, model,
)

In [None]:
# Extract and display results

"""
    score_bar(score; width = 30)

Return a bar of `█` characters whose length is `score * width`, rounded to the
nearest integer and clamped to `0:width`. The clamp keeps a score slightly
outside `[0, 1]` from throwing (negative repeat count) or overrunning the table.
`score` must be finite.
"""
score_bar(score; width = 30) = "█"^clamp(round(Int, score * width), 0, width)

"""
    print_flavour_results(jets, config, weights)

Print, for every jet in `jets`, a table of the per-class scores sorted in
descending order, followed by the most likely class.

Class names are read from `config["output_names"]`; the score of class `i`
(one-based position in that list) for jet `j` is read as
`JetTaggingFCC.get_weight(weights, i - 1)[j]`.

Non-finite scores are shown as `[Invalid score]` and are ignored when picking
the most likely class. Class names without a pretty-printed alias are shown
under their raw output name instead of raising a `KeyError`.
"""
function print_flavour_results(jets, config, weights)
    # Fixed-width aliases for the score table (built once, not per score).
    table_labels = Dict("recojet_isG" => "Gluon   ",
                        "recojet_isQ" => "Light q ",
                        "recojet_isS" => "Strange ",
                        "recojet_isC" => "Charm   ",
                        "recojet_isB" => "Bottom  ")
    # Human-readable names for the "most likely" summary line.
    flavour_names = Dict("recojet_isG" => "gluon",
                         "recojet_isQ" => "light quark",
                         "recojet_isS" => "strange",
                         "recojet_isC" => "charm",
                         "recojet_isB" => "bottom")

    labels = String[name for name in config["output_names"]]

    println("\n" * "="^60)
    println("FLAVOUR TAGGING RESULTS")
    println("="^60)

    for (jet_idx, jet) in enumerate(jets)
        println("\nJet $jet_idx (E=$(round(jet.E, digits=1)) GeV, pT=$(round(JetReconstruction.pt(jet), digits=1)) GeV):")
        println("-"^40)

        # One score per output class for this jet; `get_weight` takes the
        # class position shifted down by one.
        scores = Float32[JetTaggingFCC.get_weight(weights, class_idx - 1)[jet_idx]
                         for class_idx in eachindex(labels)]

        # Descending by score; the order is identical to the original
        # `sortperm(scores, rev = true)` call.
        for idx in sortperm(scores, rev = true)
            formatted_label = get(table_labels, labels[idx], labels[idx])
            score = scores[idx]
            if isfinite(score)
                percentage = round(score * 100, digits = 1)
                println("  $formatted_label: $(score_bar(score)) $(percentage)%")
            else
                println("  $formatted_label: [Invalid score]")
            end
        end

        # `argmax` treats NaN as the maximum, so restrict the search to
        # finite scores before picking the winner.
        valid = findall(isfinite, scores)
        if isempty(valid)
            println("\n  → Most likely: undetermined (no valid scores)")
        else
            max_idx = valid[argmax(view(scores, valid))]
            max_label = labels[max_idx]
            max_score = scores[max_idx]
            flavour_name = get(flavour_names, max_label, max_label)
            println("\n  → Most likely: $(flavour_name) ($(round(max_score * 100, digits=1))% confidence)")
        end
    end

    println("\n" * "="^60)
    println("Processing complete!")
    return nothing
end

print_flavour_results(jets, config, weights)