diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/GNNConstants.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/GNNConstants.java new file mode 100644 index 0000000000..502d3476c6 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/GNNConstants.java @@ -0,0 +1,42 @@ +package org.jlab.rec.ahdc.AI; + +/** Normalization and graph-construction constants for the GNN track finder. + * Mirrors track-finding/gnn/config.py — keep in sync with the training config. + */ +final class GNNConstants { + private GNNConstants() {} + + static final int NODE_FEAT_DIM = 10; + static final int EDGE_FEAT_DIM = 9; + + // Model architecture parameters (control the minimum graph size at inference). + // GravNet progressive-k reaches 2*k, topk uses k+1 → N_nodes >= 2*k + 2. + // The exported model clamps topk(k+1) to N internally (see + // track-finding/export_torchscript.py::_knn_indices), so any graph with + // >=3 nodes runs without crashing. Smaller graphs can't form any edge + // with the MAX_LAYER_GAP rule anyway, so we skip them here. + static final int MIN_NODES = 3; + + // Graph construction + static final int MAX_LAYER_GAP = 2; + static final double MAX_EDGE_DISTANCE = 35.0; // mm + static final double MAX_EDGE_DIST_SQ = MAX_EDGE_DISTANCE * MAX_EDGE_DISTANCE; + + // Feature normalization + static final double MAX_R = 100.0; // mm + static final double DOCA_STD = 10.0; // mm + static final double Z_HALF_LENGTH = 200.0; // mm + static final double STEREO_ANGLE_MAX = 0.03; // rad + static final double STEREO_SCALE = 1.0 / STEREO_ANGLE_MAX; + + // ATOF abs_layer convention from Python's build_graph + static final int ATOF_BAR_ABS_LAYER = 10; // component == 10 + static final int ATOF_WEDGE_ABS_LAYER = 11; // all other components + + // Track extraction: connected components at a single score threshold, matching + // gnn/evaluate.py (extract_tracks(..., method="cc", threshold=0.1)). 
Drop tracks + // with fewer than MIN_TRACK_NODES total nodes — same filter evaluate.py applies + // after the method call. + static final double TRACK_SCORE_THRESHOLD = 0.1; + static final int MIN_TRACK_NODES = 3; +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/GNNGraphBuilder.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/GNNGraphBuilder.java new file mode 100644 index 0000000000..fb83898dc3 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/GNNGraphBuilder.java @@ -0,0 +1,224 @@ +package org.jlab.rec.ahdc.AI; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.jlab.geom.prim.Line3D; +import org.jlab.geom.prim.Point3D; +import org.jlab.geom.prim.Vector3D; +import org.jlab.io.base.DataBank; +import org.jlab.rec.ahdc.Hit.Hit; + +/** Builds the graph tensors expected by the exported GNN edge scorer. + * Ports track-finding/gnn/dataset.py::build_graph — must stay byte-compatible + * with the training-time feature layout and normalization. + */ +final class GNNGraphBuilder { + + /** Container for the tensors + node provenance that the caller needs. */ + static final class GraphInput { + final float[][] nodeFeatures; // shape [N, 10] + final long[][] edgeIndex; // shape [2, E] + final float[][] edgeAttr; // shape [E, 9] + /** nodeToSource[i] is the backing Hit for AHDC nodes, or null for ATOF nodes. */ + final Hit[] nodeToSource; + + GraphInput(float[][] nodeFeatures, long[][] edgeIndex, float[][] edgeAttr, Hit[] nodeToSource) { + this.nodeFeatures = nodeFeatures; + this.edgeIndex = edgeIndex; + this.edgeAttr = edgeAttr; + this.nodeToSource = nodeToSource; + } + } + + private GNNGraphBuilder() {} + + /** Build a graph from AHDC hits (required) plus the ATOF::hits bank (optional). */ + static GraphInput build(List ahdcHits, DataBank atofHitsBank) { + int nAhdc = ahdcHits == null ? 
0 : ahdcHits.size(); + + // Node state buffers (grow as we append AHDC then ATOF nodes). + List nodeBuf = new ArrayList<>(); // per-node raw floats (see NodeField indexes) + List nodeLine = new ArrayList<>(); // wire line for AHDC; null for ATOF + List nodeHit = new ArrayList<>(); // backing Hit for AHDC; null for ATOF + + // --- AHDC nodes ------------------------------------------------------------- + for (int i = 0; i < nAhdc; i++) { + Hit h = ahdcHits.get(i); + Line3D line = h.getLine(); + if (line == null) continue; // missing geometry → skip (shouldn't happen after setWirePosition) + + Point3D mid = line.midpoint(); + Vector3D dir = line.toVector(); + double len = Math.max(dir.mag(), 1e-12); + double ux = dir.x() / len, uy = dir.y() / len, uz = dir.z() / len; + double stereo = Math.atan2(Math.sqrt(ux*ux + uy*uy), uz); + + int absLayer = (h.getSuperLayerId() - 1) * 2 + (h.getLayerId() - 1); + nodeBuf.add(new double[]{ + absLayer, // 0: abs_layer + h.getPhi(), // 1: phi + h.getRadius(), // 2: r + stereo, // 3: stereo_angle + mid.x(), // 4: x_mid + mid.y(), // 5: y_mid + mid.z(), // 6: z_mid + ux, // 7: ux + uy, // 8: uy + uz, // 9: uz + h.getX(), // 10: x (raw, for edge distance mask) + h.getY(), // 11: y (raw, for edge distance mask) + 0.0, // 12: det_type = 0 (AHDC) + }); + nodeLine.add(line); + nodeHit.add(h); + } + + // --- ATOF nodes ------------------------------------------------------------- + // Deduplicate by (sector, layer, component) — inference-time variant of the + // Python dedup which also keys on track id (only needed at training time). 
+ if (atofHitsBank != null) { + Set seen = new HashSet<>(); + int rows = atofHitsBank.rows(); + for (int r = 0; r < rows; r++) { + int sector = atofHitsBank.getInt("sector", r); + int layer = atofHitsBank.getInt("layer", r); + int component = atofHitsBank.getInt("component", r); + long key = (((long)sector * 1000L) + layer) * 1000L + component; + if (!seen.add(key)) continue; + + double x = atofHitsBank.getFloat("x", r); + double y = atofHitsBank.getFloat("y", r); + double radius = Math.hypot(x, y); + double phi = Math.atan2(y, x); + int absLayer = (component == 10) ? GNNConstants.ATOF_BAR_ABS_LAYER + : GNNConstants.ATOF_WEDGE_ABS_LAYER; + + nodeBuf.add(new double[]{ + absLayer, phi, radius, + 0.0, // stereo + x, y, 0.0, // mid + 0.0, 0.0, 1.0, // (ux, uy, uz) + x, y, // raw x, y (for edge mask) + 1.0, // det_type = 1 (ATOF) + }); + nodeLine.add(null); + nodeHit.add(null); + } + } + + int n = nodeBuf.size(); + if (n < 2) { + return new GraphInput(new float[0][GNNConstants.NODE_FEAT_DIM], + new long[][]{new long[0], new long[0]}, + new float[0][GNNConstants.EDGE_FEAT_DIM], + new Hit[0]); + } + + // --- Node feature tensor [N, 10] -------------------------------------------- + float[][] nodeFeatures = new float[n][GNNConstants.NODE_FEAT_DIM]; + for (int i = 0; i < n; i++) { + double[] v = nodeBuf.get(i); + nodeFeatures[i][0] = (float)(v[0] / 11.0); + nodeFeatures[i][1] = (float)(v[1] / Math.PI); + nodeFeatures[i][2] = (float)(v[2] / GNNConstants.DOCA_STD); + nodeFeatures[i][3] = (float)(v[3] / GNNConstants.STEREO_ANGLE_MAX); + nodeFeatures[i][4] = (float)(v[4] / GNNConstants.MAX_R); + nodeFeatures[i][5] = (float)(v[5] / GNNConstants.MAX_R); + nodeFeatures[i][6] = (float)(v[6] / GNNConstants.Z_HALF_LENGTH); + nodeFeatures[i][7] = (float)(v[7] * GNNConstants.STEREO_SCALE); + nodeFeatures[i][8] = (float)(v[8] * GNNConstants.STEREO_SCALE); + nodeFeatures[i][9] = (float)(v[9]); + } + + // --- Edge construction (directed, layer_gap in [1, MAX_LAYER_GAP]) ----------- + // 
Mirrors Python's np.where(mask) on a non-symmetric mask. + int[] absLayer = new int[n]; + double[] xRaw = new double[n]; + double[] yRaw = new double[n]; + double[] rRaw = new double[n]; + double[] phiRaw = new double[n]; + double[] stereoRaw = new double[n]; + double[] detTypeRaw = new double[n]; + for (int i = 0; i < n; i++) { + double[] v = nodeBuf.get(i); + absLayer[i] = (int) v[0]; + phiRaw[i] = v[1]; + rRaw[i] = v[2]; + stereoRaw[i] = v[3]; + xRaw[i] = v[10]; + yRaw[i] = v[11]; + detTypeRaw[i] = v[12]; + } + + List edgePairs = new ArrayList<>(); + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + if (i == j) continue; + int gap = absLayer[j] - absLayer[i]; + if (gap < 1 || gap > GNNConstants.MAX_LAYER_GAP) continue; + double dx = xRaw[i] - xRaw[j]; + double dy = yRaw[i] - yRaw[j]; + if (dx*dx + dy*dy > GNNConstants.MAX_EDGE_DIST_SQ) continue; + edgePairs.add(new long[]{i, j}); + } + } + + int e = edgePairs.size(); + long[][] edgeIndex = new long[2][e]; + float[][] edgeAttr = new float[e][GNNConstants.EDGE_FEAT_DIM]; + + for (int k = 0; k < e; k++) { + long[] p = edgePairs.get(k); + int s = (int) p[0]; + int d = (int) p[1]; + edgeIndex[0][k] = s; + edgeIndex[1][k] = d; + + // dphi wrapped into [-pi, pi] + double dphi = phiRaw[s] - phiRaw[d]; + dphi = ((dphi + Math.PI) % (2.0 * Math.PI) + 2.0 * Math.PI) % (2.0 * Math.PI) - Math.PI; + double dlayer = (double)(absLayer[d] - absLayer[s]) / GNNConstants.MAX_LAYER_GAP; + + double doca, z1, z2; + Line3D ls = nodeLine.get(s); + Line3D ld = nodeLine.get(d); + if (ls != null && ld != null) { + doca = ls.distance(ld).length(); + // Python: z1 = cp_d.z, where cp_d is the point on line_s closest to line_d's midpoint + // z2 = cp_s.z, where cp_s is the point on line_d closest to line_s's midpoint + z1 = clampZ(ls.distance(ld.midpoint()).origin().z()); + z2 = clampZ(ld.distance(ls.midpoint()).origin().z()); + } else { + double ex = xRaw[s] - xRaw[d]; + double ey = yRaw[s] - yRaw[d]; + doca = Math.hypot(ex, ey); 
+ z1 = 0.0; + z2 = 0.0; + } + + double edgeDetType = 0.5 * (detTypeRaw[s] + detTypeRaw[d]); + + edgeAttr[k][0] = (float)(dphi / Math.PI); + edgeAttr[k][1] = (float) dlayer; + edgeAttr[k][2] = (float)(doca / GNNConstants.MAX_R); + edgeAttr[k][3] = (float)(z1 / GNNConstants.Z_HALF_LENGTH); + edgeAttr[k][4] = (float)(z2 / GNNConstants.Z_HALF_LENGTH); + edgeAttr[k][5] = (float)(rRaw[s] / GNNConstants.DOCA_STD); + edgeAttr[k][6] = (float)(rRaw[d] / GNNConstants.DOCA_STD); + edgeAttr[k][7] = (float)((stereoRaw[s] - stereoRaw[d]) / (2.0 * GNNConstants.STEREO_ANGLE_MAX)); + edgeAttr[k][8] = (float) edgeDetType; + } + + Hit[] nodeToHit = nodeHit.toArray(new Hit[0]); + return new GraphInput(nodeFeatures, edgeIndex, edgeAttr, nodeToHit); + } + + private static double clampZ(double z) { + if (z < -GNNConstants.Z_HALF_LENGTH) return -GNNConstants.Z_HALF_LENGTH; + if (z > GNNConstants.Z_HALF_LENGTH) return GNNConstants.Z_HALF_LENGTH; + return z; + } +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/GNNPrediction.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/GNNPrediction.java new file mode 100644 index 0000000000..ddea3f0631 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/GNNPrediction.java @@ -0,0 +1,117 @@ +package org.jlab.rec.ahdc.AI; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.logging.Logger; + +import org.jlab.io.base.DataBank; +import org.jlab.rec.ahdc.Cluster.Cluster; +import org.jlab.rec.ahdc.Hit.Hit; +import org.jlab.rec.ahdc.PreCluster.PreCluster; +import org.jlab.rec.ahdc.PreCluster.PreClusterFinder; +import org.jlab.rec.ahdc.Track.Track; + +/** Orchestrates GNN-based track finding: builds the graph, runs the exported + * edge scorer, extracts tracks via connected components on edge scores + * thresholded at 0.1, and converts each node-set back into a {@link Track} + * carrying per-superlayer Clusters so the downstream 
helix fit / Kalman + * stages can consume it. + */ +public final class GNNPrediction { + + private static final Logger LOGGER = Logger.getLogger(GNNPrediction.class.getName()); + + public ArrayList prediction(List ahdcHits, + DataBank atofHitsBank, + ModelTrackFindingGNN model) { + ArrayList out = new ArrayList<>(); + if (ahdcHits == null || ahdcHits.isEmpty() || model == null) return out; + + GNNGraphBuilder.GraphInput g = GNNGraphBuilder.build(ahdcHits, atofHitsBank); + int nNodes = g.nodeToSource.length; + int nEdges = g.edgeIndex[0].length; + if (nNodes < GNNConstants.MIN_NODES || nEdges == 0) { + return out; // model cannot run on graphs this small + } + + float[] edgeScores; + try { + edgeScores = model.predictEdgeScores(g.nodeFeatures, g.edgeIndex, g.edgeAttr); + } catch (Exception ex) { + LOGGER.warning(() -> "GNN inference failed: " + ex); + return out; + } + + // Connected components at TRACK_SCORE_THRESHOLD, filtered to + // components of size >= MIN_TRACK_NODES — mirrors gnn/evaluate.py. + List trackNodeSets = SeedExtendTrackExtractor.extract(edgeScores, g.edgeIndex, nNodes); + + for (int[] nodes : trackNodeSets) { + // Collect just the AHDC Hits in this track — ATOF nodes were graph + // context only, they don't belong in AHDC::track or AHDC::hits. + ArrayList trackHits = new ArrayList<>(nodes.length); + for (int n : nodes) { + Hit h = g.nodeToSource[n]; + if (h != null) trackHits.add(h); + } + if (trackHits.isEmpty()) continue; + + ArrayList clusters = buildSuperlayerClusters(trackHits); + if (clusters.size() < 3) continue; // matches the downstream >=3 filter + + out.add(new Track(clusters)); + } + + return out; + } + + /** One {@link Cluster} per superlayer built from two {@link PreCluster}s (one + * per layer within the superlayer). 
Using real PreClusters — instead of the + * 3-arg {@code Cluster(x,y,z)} constructor — keeps + * {@code Track.generateHitList()} and {@code DocaClusterRefiner}'s stereo + * pairing working for GNN-discovered tracks just like they do for MLP tracks. + */ + private static ArrayList buildSuperlayerClusters(List hits) { + // Feed the track's hits through the same preclustering the MLP path uses. + // findPreclusters mutates its input (it calls setUse(true) on consumed + // hits), so pass a copy and ensure each hit starts unmarked. + ArrayList hitsForPre = new ArrayList<>(hits.size()); + for (Hit h : hits) { h.setUse(false); hitsForPre.add(h); } + PreClusterFinder pcf = new PreClusterFinder(); + pcf.findPreclusters(hitsForPre); + ArrayList preclusters = pcf.get_AHDCPreClusters(); + + // Index by (superlayer, layer). If the GNN assigns two PreClusters of the + // same superlayer+layer to one track (rare — it would mean two disjoint + // wire runs on the same layer), keep the largest and drop the rest. + Map bySuperlayer = new HashMap<>(); + for (PreCluster pc : preclusters) { + int sl = pc.get_Super_layer(); + int layerIdx = pc.get_Layer() - 1; // layer is 1-based, slots are [0,1] + if (layerIdx < 0 || layerIdx > 1) continue; + PreCluster[] slot = bySuperlayer.computeIfAbsent(sl, k -> new PreCluster[2]); + PreCluster prev = slot[layerIdx]; + if (prev == null || pc.get_Num_wire() > prev.get_Num_wire()) slot[layerIdx] = pc; + } + + ArrayList clusters = new ArrayList<>(); + // Iterate superlayers in ascending order to keep downstream output stable. + // If both stereo layers have a PreCluster, pair them (full stereo cluster). + // If only one has hits, use the single-layer Cluster(PreCluster) ctor — + // DocaClusterRefiner handles PreClusters_list.size() != 2 with a + // degenerate DocaCluster fallback, so the helix fit still runs. 
+ for (int sl = 1; sl <= 5; sl++) { + PreCluster[] slot = bySuperlayer.get(sl); + if (slot == null) continue; + if (slot[0] != null && slot[1] != null) { + clusters.add(new Cluster(slot[0], slot[1])); + } else { + PreCluster single = (slot[0] != null) ? slot[0] : slot[1]; + if (single != null) clusters.add(new Cluster(single)); + } + } + return clusters; + } +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/ModelTrackFindingGNN.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/ModelTrackFindingGNN.java new file mode 100644 index 0000000000..57e8a39b33 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/ModelTrackFindingGNN.java @@ -0,0 +1,96 @@ +package org.jlab.rec.ahdc.AI; + +import java.io.IOException; +import java.nio.file.Paths; + +import ai.djl.MalformedModelException; +import ai.djl.inference.Predictor; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.Shape; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.training.util.ProgressBar; +import ai.djl.translate.NoopTranslator; +import org.jlab.utils.CLASResources; + +/** DJL wrapper around the GravNet TorchScript model exported from + * track-finding/export_torchscript.py. Runs per-event edge scoring. + * + * Exported forward signature (see SingleGraphEdgeScorer): + * forward(x: float32[N, 10], edge_index: int64[2, E], edge_attr: float32[E, 9]) + * -> float32[E] (sigmoid edge scores in [0, 1]) + */ +public class ModelTrackFindingGNN { + + private final ZooModel model; + /** Reused across every call. DJL's Predictor is NOT thread-safe, but the + * ALERT reconstruction engine is single-threaded, so one instance is fine + * and avoids the allocation/libtorch-graph-prep cost per event that + * dominated predictEdgeScores on small graphs. 
+ */ + private final Predictor predictor; + + public ModelTrackFindingGNN() { + // Let libtorch pick sensible defaults: GravNet's cdist + topk + gather + // chain benefits from the graph optimizer and intra-op parallelism the + // MLP copy-paste was pinning off. Keep num_interop_threads=1 — there is + // only one event in flight at a time. + System.setProperty("ai.djl.pytorch.num_interop_threads", "1"); + + String path = CLASResources.getResourcePath("etc/data/nnet/rg-l/model_AHDC_GNN/"); + Criteria criteria = Criteria.builder() + .setTypes(NDList.class, NDList.class) + .optModelPath(Paths.get(path)) + .optEngine("PyTorch") + .optTranslator(new NoopTranslator()) + .optProgress(new ProgressBar()) + .build(); + + try { + model = criteria.loadModel(); + } catch (IOException | ModelNotFoundException | MalformedModelException ex) { + throw new RuntimeException(ex); + } + predictor = model.newPredictor(new NoopTranslator()); + } + + /** Score every edge in the input graph. + * + * @param nodeFeatures shape [N, 10] — see GNNConstants.NODE_FEAT_DIM + * @param edgeIndex shape [2, E] — int64 source / destination node ids + * @param edgeAttr shape [E, 9] — see GNNConstants.EDGE_FEAT_DIM + * @return float[E] of sigmoid edge scores in [0, 1] + */ + public float[] predictEdgeScores(float[][] nodeFeatures, long[][] edgeIndex, float[][] edgeAttr) throws Exception { + if (nodeFeatures == null || nodeFeatures.length == 0) return new float[0]; + int n = nodeFeatures.length; + int e = edgeIndex[0].length; + if (e == 0) return new float[0]; + + try (NDManager manager = NDManager.newBaseManager()) { + // Flatten x into a contiguous float[] so DJL builds a [N, node_dim] tensor. + int nodeDim = nodeFeatures[0].length; + float[] xFlat = new float[n * nodeDim]; + for (int i = 0; i < n; i++) System.arraycopy(nodeFeatures[i], 0, xFlat, i * nodeDim, nodeDim); + NDArray x = manager.create(xFlat, new Shape(n, nodeDim)); + + // edge_index is int64[2, E]; flatten row-major. 
+ long[] edgeIndexFlat = new long[2 * e]; + System.arraycopy(edgeIndex[0], 0, edgeIndexFlat, 0, e); + System.arraycopy(edgeIndex[1], 0, edgeIndexFlat, e, e); + NDArray edgeIndexNd = manager.create(edgeIndexFlat, new Shape(2, e)); + + int edgeDim = edgeAttr[0].length; + float[] edgeAttrFlat = new float[e * edgeDim]; + for (int i = 0; i < e; i++) System.arraycopy(edgeAttr[i], 0, edgeAttrFlat, i * edgeDim, edgeDim); + NDArray edgeAttrNd = manager.create(edgeAttrFlat, new Shape(e, edgeDim)); + + NDList output = predictor.predict(new NDList(x, edgeIndexNd, edgeAttrNd)); + NDArray scoresNd = output.get(0); + return scoresNd.toFloatArray(); + } + } +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/SeedExtendTrackExtractor.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/SeedExtendTrackExtractor.java new file mode 100644 index 0000000000..abe3a9d4f1 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/SeedExtendTrackExtractor.java @@ -0,0 +1,63 @@ +package org.jlab.rec.ahdc.AI; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** Track extraction from per-edge scores via union-find connected components + * at a single threshold. Ports the {@code method="cc"} branch of + * {@code track-finding/gnn/inference.py::extract_tracks}, which is the + * extractor that gnn/evaluate.py uses. + */ +final class SeedExtendTrackExtractor { + + private SeedExtendTrackExtractor() {} + + /** @return list of node-index arrays, one per connected component of size + * ≥ {@link GNNConstants#MIN_TRACK_NODES}. 
*/
+    static List<int[]> extract(float[] scores, long[][] edgeIndex, int nNodes) {
+        if (nNodes <= 0 || scores == null || edgeIndex == null || edgeIndex[0].length != scores.length) {
+            return new ArrayList<>();
+        }
+        long[] src = edgeIndex[0];
+        long[] dst = edgeIndex[1];
+
+        int[] parent = new int[nNodes];
+        for (int i = 0; i < nNodes; i++) parent[i] = i;
+        for (int e = 0; e < scores.length; e++) {
+            if (scores[e] >= GNNConstants.TRACK_SCORE_THRESHOLD) {
+                union(parent, (int) src[e], (int) dst[e]);
+            }
+        }
+
+        Map<Integer, List<Integer>> groups = new HashMap<>();
+        for (int i = 0; i < nNodes; i++) {
+            int r = find(parent, i);
+            groups.computeIfAbsent(r, k -> new ArrayList<>()).add(i);
+        }
+
+        List<int[]> out = new ArrayList<>();
+        for (List<Integer> members : groups.values()) {
+            if (members.size() < GNNConstants.MIN_TRACK_NODES) continue;
+            int[] arr = new int[members.size()];
+            for (int i = 0; i < arr.length; i++) arr[i] = members.get(i);
+            out.add(arr);
+        }
+        return out;
+    }
+
+    private static int find(int[] parent, int x) {
+        while (parent[x] != x) {
+            parent[x] = parent[parent[x]];
+            x = parent[x];
+        }
+        return x;
+    }
+
+    private static void union(int[] parent, int a, int b) {
+        int ra = find(parent, a);
+        int rb = find(parent, b);
+        if (ra != rb) parent[ra] = rb;
+    }
+}
diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Cluster/Cluster.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Cluster/Cluster.java
index f0df2e98da..d183294c2c 100644
--- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Cluster/Cluster.java
+++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Cluster/Cluster.java
@@ -86,6 +86,35 @@ public Cluster(double X, double Y, double Z) {
         this._Z = Z;
     }
 
+    /** Build a Cluster from a single PreCluster (one layer of a superlayer).
+ * Used by the GNN path when a track covers a superlayer on only one + * stereo layer — no stereo pair is available, so Z is taken from the + * average wire-midpoint z of the PreCluster's hits rather than from a + * stereo-angle computation. DocaClusterRefiner falls back to a degenerate + * DocaCluster when {@code get_PreClusters_list().size() != 2}, so + * downstream is unaffected. */ + public Cluster(PreCluster precluster) { + this._PreClusters_list = new ArrayList<>(); + _PreClusters_list.add(precluster); + this._Radius = precluster.get_Radius(); + this._Phi = precluster.get_Phi(); + this._X = precluster.get_X(); + this._Y = precluster.get_Y(); + this._Num_wire = (int) precluster.get_Num_wire(); + double r2 = this._X * this._X + this._Y * this._Y; + if (r2 > 0.0) { + this._U = this._X / r2; + this._V = this._Y / r2; + } + double zSum = 0.0; + int zCount = 0; + for (Hit h : precluster.get_hits_list()) { + Line3D line = h.getLine(); + if (line != null) { zSum += line.midpoint().z(); zCount++; } + } + this._Z = (zCount > 0) ? 
zSum / zCount : 0.0; + } + @Override public String toString() { return "Cluster{" + "_X=" + _X + ", _Y=" + _Y + ", _Z=" + _Z + '}'; diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/ModeTrackFinding.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/ModeTrackFinding.java index ec500c3ad9..72a260ba52 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/ModeTrackFinding.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/ModeTrackFinding.java @@ -1,7 +1,8 @@ package org.jlab.rec.ahdc; public enum ModeTrackFinding { - AI_Track_Finding, + MLP_Track_Finding, CV_Distance, CV_Hough, + GNN_Track_Finding, } diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/AITrackFinder.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/AITrackFinder.java new file mode 100644 index 0000000000..93f864a795 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/AITrackFinder.java @@ -0,0 +1,93 @@ +package org.jlab.rec.ahdc.TrackFinding; + +import org.jlab.rec.ahdc.AI.AIPrediction; +import org.jlab.rec.ahdc.AI.InterCluster; +import org.jlab.rec.ahdc.AI.ModelTrackFinding; +import org.jlab.rec.ahdc.AI.PreClustering; +import org.jlab.rec.ahdc.AI.TrackCandidatesGenerator; +import org.jlab.rec.ahdc.AI.TrackPrediction; +import org.jlab.rec.ahdc.Hit.Hit; +import org.jlab.rec.ahdc.PreCluster.PreCluster; +import org.jlab.rec.ahdc.PreCluster.PreClusterFinder; +import org.jlab.rec.ahdc.Track.Track; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.Set; +import java.util.logging.Logger; + +public class AITrackFinder implements TrackFinder { + + private static final Logger LOGGER = Logger.getLogger(AITrackFinder.class.getName()); + + private static final double TRACK_FINDING_AI_THRESHOLD = 0.2; + private static final int MAX_HITS_FOR_AI = 300; + + private final ModelTrackFinding model; + private final TrackFinder fallback; + + public AITrackFinder() { 
+ this.model = new ModelTrackFinding(); + this.fallback = new DistanceTrackFinder(); + } + + @Override + public TrackFinderResult findTracks(ArrayList hits) { + // Safety: too many hits → fall back to conventional track finding for this event + if (hits.size() > MAX_HITS_FOR_AI) { + LOGGER.info("Too many AHDC_Hits in AHDC::adc, rely on conventional track finding for this event"); + return fallback.findTracks(hits); + } + + // Preclustering + PreClusterFinder pcf = new PreClusterFinder(); + pcf.findPreclusters(hits); + ArrayList preclusters = pcf.get_AHDCPreClusters(); + + // 1) Create inter-clusters from pre-clusters + PreClustering preClustering = new PreClustering(); + ArrayList inter_clusters = preClustering.mergePreclusters(preclusters); + + // 2) Create track candidates from inter-clusters + ArrayList> tracks_candidates = new ArrayList<>(); + TrackCandidatesGenerator trackCandidatesGenerator = new TrackCandidatesGenerator(); + boolean success = trackCandidatesGenerator.getAllPossibleTrack(inter_clusters, tracks_candidates); + + if (!success) { + LOGGER.severe("Too many track candidates find by the AI, exiting..."); + return TrackFinderResult.invalid(); + } + + // 3) Use AI model to evaluate track candidates + ArrayList predictions; + try { + AIPrediction aiPrediction = new AIPrediction(); + predictions = aiPrediction.prediction(tracks_candidates, model); + } catch (Exception e) { + throw new RuntimeException(e); + } + + // 4) Select good tracks via greedy non-overlap: sort predictions by score + // descending, accept the highest-scoring prediction, mark its PreClusters + // as claimed, and skip any later prediction that reuses a claimed PreCluster. 
+        // The AI candidate generator routinely emits overlapping predictions (each
+        // PreCluster can feed several combinations), and because set_trackId mutates
+        // the shared Hit references in place, a naive "accept all above threshold"
+        // pass would let later tracks silently steal earlier tracks' hits and leave
+        // them orphaned in AHDC::hits. Greedy selection enforces one-hit-one-track.
+        predictions.sort((a, b) -> Float.compare(b.getPrediction(), a.getPrediction()));
+        Set<PreCluster> claimedPreclusters = new HashSet<>();
+        ArrayList<Track> tracks = new ArrayList<>();
+        for (TrackPrediction t : predictions) {
+            if (t.getPrediction() <= TRACK_FINDING_AI_THRESHOLD) continue;
+            boolean overlaps = false;
+            for (PreCluster pc : t.getPreclusters()) {
+                if (claimedPreclusters.contains(pc)) { overlaps = true; break; }
+            }
+            if (overlaps) continue;
+            claimedPreclusters.addAll(t.getPreclusters());
+            tracks.add(new Track(t.getClusters()));
+        }
+        return TrackFinderResult.ok(tracks);
+    }
+}
diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/DistanceTrackFinder.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/DistanceTrackFinder.java
new file mode 100644
index 0000000000..bc4f0c0dc0
--- /dev/null
+++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/DistanceTrackFinder.java
@@ -0,0 +1,31 @@
+package org.jlab.rec.ahdc.TrackFinding;
+
+import org.jlab.rec.ahdc.Cluster.Cluster;
+import org.jlab.rec.ahdc.Cluster.ClusterFinder;
+import org.jlab.rec.ahdc.Distance.Distance;
+import org.jlab.rec.ahdc.Hit.Hit;
+import org.jlab.rec.ahdc.PreCluster.PreCluster;
+import org.jlab.rec.ahdc.PreCluster.PreClusterFinder;
+import org.jlab.rec.ahdc.Track.Track;
+
+import java.util.ArrayList;
+
+public class DistanceTrackFinder implements TrackFinder {
+
+    @Override
+    public TrackFinderResult findTracks(ArrayList<Hit> hits) {
+        PreClusterFinder pcf = new PreClusterFinder();
+        pcf.findPreclusters(hits);
+        ArrayList<PreCluster> preclusters =
pcf.get_AHDCPreClusters(); + + ClusterFinder cf = new ClusterFinder(); + cf.findCluster(preclusters); + ArrayList clusters = cf.get_AHDCClusters(); + + Distance distance = new Distance(); + distance.find_track(clusters); + ArrayList tracks = distance.get_AHDCTracks(); + + return TrackFinderResult.ok(tracks); + } +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/GNNTrackFinder.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/GNNTrackFinder.java new file mode 100644 index 0000000000..cac719bf84 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/GNNTrackFinder.java @@ -0,0 +1,52 @@ +package org.jlab.rec.ahdc.TrackFinding; + +import org.jlab.io.base.DataBank; +import org.jlab.rec.ahdc.AI.GNNPrediction; +import org.jlab.rec.ahdc.AI.ModelTrackFindingGNN; +import org.jlab.rec.ahdc.Hit.Hit; +import org.jlab.rec.ahdc.Track.Track; + +import java.util.ArrayList; +import java.util.logging.Logger; + +/** GravNet-based track finder. Builds a per-event hit graph from the AHDC + * hits and (when present) the ATOF::hits bank, runs the exported edge + * scorer, extracts tracks via connected components on edges with score + * ≥ 0.1, and packages each surviving track as a {@link Track} backed by + * per-superlayer {@link org.jlab.rec.ahdc.Cluster.Cluster}s. + */ +public class GNNTrackFinder implements TrackFinder { + + private static final Logger LOGGER = Logger.getLogger(GNNTrackFinder.class.getName()); + + /** Above this hit count the graph builder + GNN inference is too slow to + * be useful, so the event is skipped (no GNN tracks produced). */ + private static final int MAX_HITS_FOR_GNN = 500; + + private final ModelTrackFindingGNN model; + private final GNNPrediction predictor; + + public GNNTrackFinder() { + this.model = new ModelTrackFindingGNN(); + this.predictor = new GNNPrediction(); + } + + /** Without an ATOF bank the GNN still runs on AHDC-only graphs. 
*/ + @Override + public TrackFinderResult findTracks(ArrayList hits) { + return findTracks(hits, null); + } + + @Override + public TrackFinderResult findTracks(ArrayList ahdcHits, DataBank atofHitsBank) { + if (ahdcHits == null || ahdcHits.size() > MAX_HITS_FOR_GNN) { + if (ahdcHits != null) { + LOGGER.info("Too many AHDC_Hits in AHDC::hits, skipping GNN track finding for this event"); + } + return TrackFinderResult.ok(new ArrayList<>()); + } + + ArrayList tracks = predictor.prediction(ahdcHits, atofHitsBank, model); + return TrackFinderResult.ok(tracks); + } +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/HoughTrackFinder.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/HoughTrackFinder.java new file mode 100644 index 0000000000..21880f35b5 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/HoughTrackFinder.java @@ -0,0 +1,31 @@ +package org.jlab.rec.ahdc.TrackFinding; + +import org.jlab.rec.ahdc.Cluster.Cluster; +import org.jlab.rec.ahdc.Cluster.ClusterFinder; +import org.jlab.rec.ahdc.Hit.Hit; +import org.jlab.rec.ahdc.HoughTransform.HoughTransform; +import org.jlab.rec.ahdc.PreCluster.PreCluster; +import org.jlab.rec.ahdc.PreCluster.PreClusterFinder; +import org.jlab.rec.ahdc.Track.Track; + +import java.util.ArrayList; + +public class HoughTrackFinder implements TrackFinder { + + @Override + public TrackFinderResult findTracks(ArrayList hits) { + PreClusterFinder pcf = new PreClusterFinder(); + pcf.findPreclusters(hits); + ArrayList preclusters = pcf.get_AHDCPreClusters(); + + ClusterFinder cf = new ClusterFinder(); + cf.findCluster(preclusters); + ArrayList clusters = cf.get_AHDCClusters(); + + HoughTransform hough = new HoughTransform(); + hough.find_tracks(clusters); + ArrayList tracks = hough.get_AHDCTracks(); + + return TrackFinderResult.ok(tracks); + } +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/TrackFinder.java 
b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/TrackFinder.java new file mode 100644 index 0000000000..e0d40f803d --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/TrackFinder.java @@ -0,0 +1,20 @@ +package org.jlab.rec.ahdc.TrackFinding; + +import org.jlab.io.base.DataBank; +import org.jlab.rec.ahdc.Hit.Hit; + +import java.util.ArrayList; + +public interface TrackFinder { + + /** AHDC-only track finding. Implementations that don't need ATOF context + * (MLP / Distance / Hough) only need to implement this method. */ + TrackFinderResult findTracks(ArrayList hits); + + /** Track finding with ATOF context (e.g. GNN, which builds a joint + * AHDC + ATOF hit graph). The default delegates to the AHDC-only + * version, ignoring the ATOF bank. */ + default TrackFinderResult findTracks(ArrayList ahdcHits, DataBank atofHitsBank) { + return findTracks(ahdcHits); + } +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/TrackFinderResult.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/TrackFinderResult.java new file mode 100644 index 0000000000..30fdd9e0ae --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/TrackFinderResult.java @@ -0,0 +1,34 @@ +package org.jlab.rec.ahdc.TrackFinding; + +import org.jlab.rec.ahdc.Track.Track; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class TrackFinderResult { + + private final List tracks; + private final boolean valid; + + public TrackFinderResult(List tracks, boolean valid) { + this.tracks = tracks; + this.valid = valid; + } + + public static TrackFinderResult ok(List tracks) { + return new TrackFinderResult(tracks, true); + } + + public static TrackFinderResult invalid() { + return new TrackFinderResult(Collections.emptyList(), false); + } + + public List getTracks() { + return tracks; + } + + public boolean isValid() { + return valid; + } +}
diff --git a/reconstruction/alert/src/main/java/org/jlab/service/ahdc/AHDCEngine.java b/reconstruction/alert/src/main/java/org/jlab/service/ahdc/AHDCEngine.java index 66b965825c..f5beeb1591 100644 --- a/reconstruction/alert/src/main/java/org/jlab/service/ahdc/AHDCEngine.java +++ b/reconstruction/alert/src/main/java/org/jlab/service/ahdc/AHDCEngine.java @@ -5,23 +5,13 @@ import org.jlab.io.base.DataEvent; import org.jlab.io.hipo.HipoDataSource; import org.jlab.io.hipo.HipoDataSync; -import org.jlab.rec.ahdc.AI.*; import org.jlab.rec.ahdc.Banks.RecoBankWriter; -import org.jlab.rec.ahdc.Cluster.Cluster; -import org.jlab.rec.ahdc.Cluster.ClusterFinder; -import org.jlab.rec.ahdc.DocaCluster.DocaClusterRefiner; -import org.jlab.rec.ahdc.DocaCluster.DocaCluster; -import org.jlab.rec.ahdc.Distance.Distance; -import org.jlab.rec.ahdc.HelixFit.HelixFitJava; import org.jlab.rec.ahdc.Hit.Hit; import org.jlab.rec.ahdc.Hit.HitReader; -import org.jlab.rec.ahdc.HoughTransform.HoughTransform; -import org.jlab.rec.ahdc.PreCluster.PreCluster; -import org.jlab.rec.ahdc.PreCluster.PreClusterFinder; -import org.jlab.rec.ahdc.Track.Track; -import org.jlab.rec.ahdc.ModeTrackFinding; import java.io.File; -import java.util.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; import java.util.logging.Logger; import org.jlab.detector.calib.utils.DatabaseConstantProvider; @@ -32,22 +22,14 @@ /** AHDCEngine reconstruction service. * - * AHDC Reconstruction using only AHDC information. - * - * Reconstruction utilizing other detectors (i.e. ATOF) are - * implemented in ALERTEngine. + * Reads AHDC::adc, applies calibration, and writes calibrated AHDC::hits. * + * Track finding (preclustering, AI/CV track finders, DOCA refinement, + * helix fit) lives in ALERTEngine. 
*/ public class AHDCEngine extends ReconstructionEngine { static final Logger LOGGER = Logger.getLogger(AHDCEngine.class.getName()); - private boolean simulation = false; - - private ModelTrackFinding modelTrackFinding; - private ModeTrackFinding modeTrackFinding = ModeTrackFinding.AI_Track_Finding; - static final double TRACK_FINDING_AI_THRESHOLD = 0.2; - static final int MAX_HITS_FOR_AI = 300; - private AlertDCDetector factory = null; private ModeAHDC ahdcExtractor = new ModeAHDC(); @@ -62,20 +44,11 @@ public class AHDCEngine extends ReconstructionEngine { public AHDCEngine() { super("ALERT", "ouillon", "1.0.1"); } - public boolean init(ModeTrackFinding m) { - modeTrackFinding = m; - return init(); - } - @Override public boolean init() { factory = (new AlertDCFactory()).createDetectorCLAS(new DatabaseConstantProvider()); - String modeConfig = this.getEngineConfigString("Mode"); - if (modeConfig != null) modeTrackFinding = ModeTrackFinding.valueOf(modeConfig); - if (modeTrackFinding == ModeTrackFinding.AI_Track_Finding) modelTrackFinding = new ModelTrackFinding(); - Map tableMap = new HashMap<>(); tableMap.put("/calibration/alert/ahdc/time_offsets", 3); tableMap.put("/calibration/alert/ahdc/time_to_distance_wire", 3); @@ -84,8 +57,8 @@ public boolean init() { tableMap.put("/calibration/alert/ahdc/time_over_threshold", 3); requireConstants(tableMap); - this.getConstantsManager().setVariation("default"); - this.registerOutputBank("AHDC::hits","AHDC::preclusters","AHDC::clusters","AHDC::track","AHDC::mc","AHDC::ai:prediction","AHDC::interclusters","AHDC::docaclusters"); + this.getConstantsManager().setVariation("default"); + this.registerOutputBank("AHDC::hits"); return true; } @@ -93,8 +66,6 @@ public boolean init() { @Override public boolean processDataEvent(DataEvent event) { - if(event.hasBank("MC::Particle")) simulation = true; - ahdcExtractor.update(30, null, event, "AHDC::wf", "AHDC::adc"); if (event.hasBank("RUN::config")) { @@ -115,155 +86,15 @@ public 
boolean processDataEvent(DataEvent event) { } if (event.hasBank("AHDC::adc")) { - // I) Read raw hits + boolean simulation = event.hasBank("MC::Particle"); HitReader hitReader = new HitReader(event, factory, simulation, ahdcRawHitCutsTable, ahdcTimeOffsetsTable, ahdcTimeToDistanceWireTable, ahdcTimeOverThresholdTable, ahdcAdcGainsTable); ArrayList AHDC_Hits = hitReader.get_AHDCHits(); - // II) Create PreClusters - PreClusterFinder preclusterfinder = new PreClusterFinder(); - preclusterfinder.findPreclusters(AHDC_Hits); - ArrayList AHDC_PreClusters = preclusterfinder.get_AHDCPreClusters(); - - - // III) Track Finding: Input = PreClusters, Output = Tracks - // During track finding we build Clusters and InterClusters. Each of these objects must be assigned a Track ID so we can: - // - identify which track they belong to, - // - write them properly into the output banks later, - // - and reuse them downstream in the ALERT Engine. - // - // If using AI-based track finding, tracks are identified using inter-clusters. - // Otherwise, the conventional methods (Hough Transform or distance) use clusters. 
- - // Safety check: if too many hits, rely on conventional track finding - ModeTrackFinding effectiveMode = modeTrackFinding; - if (AHDC_Hits.size() > MAX_HITS_FOR_AI) { - LOGGER.info("Too many AHDC_Hits in AHDC::adc, rely on conventional track finding for this event"); - effectiveMode = ModeTrackFinding.CV_Distance; - } - - ArrayList AHDC_Tracks = new ArrayList<>(); - - if (effectiveMode == ModeTrackFinding.AI_Track_Finding) { - // 1) Create inter-clusters from pre-clusters - PreClustering preClustering = new PreClustering(); - ArrayList inter_clusters = preClustering.mergePreclusters(AHDC_PreClusters); - - // 2) Create track candidates from inter-clusters - ArrayList> tracks_candidates = new ArrayList<>(); - TrackCandidatesGenerator trackCandidatesGenerator = new TrackCandidatesGenerator(); - boolean success = trackCandidatesGenerator.getAllPossibleTrack(inter_clusters, tracks_candidates); - - if (!success) { - LOGGER.severe("Too many track candidates find by the AI, exiting..."); - return false; - } - - // 3) Use AI model to evaluate track candidates - ArrayList predictions = new ArrayList<>(); - try { - AIPrediction aiPrediction = new AIPrediction(); - predictions = aiPrediction.prediction(tracks_candidates, modelTrackFinding); - } catch (Exception e) { - throw new RuntimeException(e); - } - - // 4) Select good tracks via greedy non-overlap: sort predictions by score - // descending, accept the highest-scoring prediction, mark its PreClusters - // as claimed, and skip any later prediction that reuses a claimed PreCluster. - // The AI candidate generator routinely emits overlapping predictions (each - // PreCluster can feed several combinations), and because set_trackId mutates - // the shared Hit references in place, a naive "accept all above threshold" - // pass would let later tracks silently steal earlier tracks' hits and leave - // them orphaned in AHDC::hits. Greedy selection enforces one-hit-one-track. 
- predictions.sort((a, b) -> Float.compare(b.getPrediction(), a.getPrediction())); - Set claimedPreclusters = new HashSet<>(); - for (TrackPrediction t : predictions) { - if (t.getPrediction() <= TRACK_FINDING_AI_THRESHOLD) continue; - boolean overlaps = false; - for (PreCluster pc : t.getPreclusters()) { - if (claimedPreclusters.contains(pc)) { overlaps = true; break; } - } - if (overlaps) continue; - claimedPreclusters.addAll(t.getPreclusters()); - AHDC_Tracks.add(new Track(t.getClusters())); - } - } - else { - // Conventional Track Finding: Hough Transform or Distance: use cluster informations to find tracks - // 1) Create clusters from pre-clusters - ClusterFinder clusterfinder = new ClusterFinder(); - clusterfinder.findCluster(AHDC_PreClusters); - ArrayList AHDC_Clusters = clusterfinder.get_AHDCClusters(); - - // 2) Find tracks using the selected conventional method - if (effectiveMode == ModeTrackFinding.CV_Distance) { - Distance distance = new Distance(); - distance.find_track(AHDC_Clusters); - AHDC_Tracks = distance.get_AHDCTracks(); - } - else if (effectiveMode == ModeTrackFinding.CV_Hough) { - HoughTransform houghtransform = new HoughTransform(); - houghtransform.find_tracks(AHDC_Clusters); - AHDC_Tracks = houghtransform.get_AHDCTracks(); - } - } - - - //Temporary track method ONLY for MC with no background; - //AHDC_Tracks.add(new Track(AHDC_Hits)); - - // V) Global fit - int trackid = 0; - ArrayList all_docaClusters = new ArrayList<>(); - AHDC_Tracks.removeIf(track -> track.get_Clusters().size() < 3); - for (Track track : AHDC_Tracks) { - trackid++; - track.set_trackId(trackid); - List originalClusters = track.get_Clusters(); - ArrayList docaClusters = DocaClusterRefiner.buildRefinedClusters(originalClusters); - all_docaClusters.addAll(docaClusters); - if (docaClusters == null || docaClusters.size() < 3 || originalClusters == null || originalClusters.size() < 3) { - // not enough points, skip helix fit - continue; - } - HelixFitJava h = new 
HelixFitJava(); - track.setPositionAndMomentum(h.helix_fit_with_doca_selection(docaClusters, 1)); - } - - // VII) Write bank RecoBankWriter writer = new RecoBankWriter(); - - DataBank recoHitsBank = writer.fillAHDCHitsBank(event, AHDC_Hits); - DataBank recoPreClusterBank = writer.fillPreClustersBank(event, AHDC_PreClusters); - ArrayList AHDC_Clusters = new ArrayList<>(); - for (Track track : AHDC_Tracks) { - AHDC_Clusters.addAll(track.get_Clusters()); - } - DataBank recoClusterBank = writer.fillClustersBank(event, AHDC_Clusters); - DataBank recoTracksBank = writer.fillAHDCTrackBank(event, AHDC_Tracks); - DataBank clustersDocaBank = writer.fillAHDCDocaClustersBank(event, all_docaClusters); - - ArrayList all_interclusters = new ArrayList<>(); - for (Track track : AHDC_Tracks) { - all_interclusters.addAll(track.getInterclusters()); - } - DataBank recoInterClusterBank = writer.fillInterClusterBank(event, all_interclusters); - - //event.removeBanks("AHDC::hits","AHDC::preclusters","AHDC::clusters","AHDC::track","AHDC::kftrack","AHDC::mc","AHDC::ai:prediction"); - event.appendBank(recoHitsBank); - event.appendBank(recoPreClusterBank); - event.appendBank(recoClusterBank); - event.appendBank(recoTracksBank); - event.appendBank(recoInterClusterBank); - event.appendBank(clustersDocaBank); - - if (simulation) { - DataBank recoMCBank = writer.fillAHDCMCTrackBank(event); - event.appendBank(recoMCBank); - } - + DataBank recoHitsBank = writer.fillAHDCHitsBank(event, AHDC_Hits); + if (recoHitsBank != null) event.appendBank(recoHitsBank); } return true; } @@ -274,7 +105,6 @@ public static void main(String[] args) { int nEvent = 0; int maxEvent = 10; - int myEvent = 3; String inputFile = "output1.hipo"; String outputFile = "output.hipo"; @@ -286,24 +116,17 @@ public static void main(String[] args) { HipoDataSource reader = new HipoDataSource(); - // en.init(); - en.init(ModeTrackFinding.AI_Track_Finding); + en.init(); reader.open(inputFile); - // SchemaFactory factory = 
reader.getReader().getSchemaFactory(); HipoDataSync writer = new HipoDataSync(); writer.open(outputFile); while (reader.hasEvent() && nEvent < maxEvent) { nEvent++; - // if (nEvent % 100 == 0) System.out.println("nEvent = " + nEvent); DataEvent event = reader.getNextEvent(); System.out.println("Event: " + nEvent); - // if (nEvent != myEvent) continue; - // System.out.println("*********** NEXT EVENT ************"); - // event.show(); - en.processDataEvent(event); writer.writeEvent(event); @@ -312,4 +135,4 @@ public static void main(String[] args) { System.out.println("finished " + (System.nanoTime() - starttime) * Math.pow(10, -9)); } -} \ No newline at end of file +} diff --git a/reconstruction/alert/src/main/java/org/jlab/service/alert/ALERTEngine.java b/reconstruction/alert/src/main/java/org/jlab/service/alert/ALERTEngine.java index a899e20c8a..6d63fcb10e 100644 --- a/reconstruction/alert/src/main/java/org/jlab/service/alert/ALERTEngine.java +++ b/reconstruction/alert/src/main/java/org/jlab/service/alert/ALERTEngine.java @@ -24,13 +24,28 @@ import org.jlab.rec.alert.banks.RecoBankWriter; import org.jlab.rec.alert.projections.TrackProjector; import org.jlab.rec.atof.hit.ATOFHit; +import org.jlab.rec.ahdc.AI.InterCluster; +import org.jlab.rec.ahdc.Cluster.Cluster; +import org.jlab.rec.ahdc.DocaCluster.DocaCluster; +import org.jlab.rec.ahdc.DocaCluster.DocaClusterRefiner; +import org.jlab.rec.ahdc.HelixFit.HelixFitJava; import org.jlab.rec.ahdc.KalmanFilter.KalmanFilter; +import org.jlab.rec.ahdc.ModeTrackFinding; +import org.jlab.rec.ahdc.PreCluster.PreCluster; +import org.jlab.rec.ahdc.PreCluster.PreClusterFinder; import org.jlab.rec.ahdc.Hit.Hit; +import org.jlab.rec.ahdc.TrackFinding.AITrackFinder; +import org.jlab.rec.ahdc.TrackFinding.DistanceTrackFinder; +import org.jlab.rec.ahdc.TrackFinding.GNNTrackFinder; +import org.jlab.rec.ahdc.TrackFinding.HoughTrackFinder; +import org.jlab.rec.ahdc.TrackFinding.TrackFinder; +import 
org.jlab.rec.ahdc.TrackFinding.TrackFinderResult; import org.jlab.geom.detector.alert.AHDC.AlertDCDetector; import org.jlab.geom.detector.alert.AHDC.AlertDCFactory; import org.jlab.rec.ahdc.Track.Track; import org.jlab.clas.pdg.PDGDatabase; import org.jlab.clas.pdg.PDGParticle; +import java.util.List; import java.util.logging.Logger; @@ -75,6 +90,11 @@ public class ALERTEngine extends ReconstructionEngine { private ModelTrackMatching modelTrackMatching; private ModelPrePID modelPrePID; + // AHDC track-finding strategy (driven by ALERT.Mode YAML key) + private TrackFinder trackFinder; + private final org.jlab.rec.ahdc.Banks.RecoBankWriter ahdcWriter + = new org.jlab.rec.ahdc.Banks.RecoBankWriter(); + // AHDC calibration table (refreshed on run change) private IndexedTable ahdcAdcGainsTable; @@ -115,10 +135,22 @@ public boolean init() { requireConstants(tableMap); this.getConstantsManager().setVariation("default"); - if(this.getEngineConfigString("Mode")!=null) { - //if (Objects.equals(this.getEngineConfigString("Mode"), Mode.AI_Track_Finding.name())) - // mode = Mode.AI_Track_Finding; + ModeTrackFinding mode = ModeTrackFinding.MLP_Track_Finding; + String modeConfig = this.getEngineConfigString("Mode"); + if (modeConfig != null) mode = ModeTrackFinding.valueOf(modeConfig); + switch (mode) { + case MLP_Track_Finding: trackFinder = new AITrackFinder(); break; + case CV_Distance: trackFinder = new DistanceTrackFinder(); break; + case CV_Hough: trackFinder = new HoughTrackFinder(); break; + case GNN_Track_Finding: trackFinder = new GNNTrackFinder(); break; } + + this.registerOutputBank( + "AHDC::preclusters", "AHDC::clusters", "AHDC::track", + "AHDC::interclusters", "AHDC::docaclusters", "AHDC::ai:prediction", + "AHDC::mc", "AHDC::kftrack", + "ALERT::projections", "ALERT::ai:projections", "ALERT::prePID"); + return true; } @@ -134,9 +166,7 @@ public boolean init() { @Override public boolean processDataEvent(DataEvent event) { - if (!event.hasBank("AHDC::adc")) - return 
false; - if (!event.hasBank("ATOF::tdc")) + if (!event.hasBank("AHDC::adc")) return false; if (!event.hasBank("RUN::config")) { @@ -154,7 +184,111 @@ public boolean processDataEvent(DataEvent event) { run.set(newRun); ahdcAdcGainsTable = this.getConstantsManager().getConstants(newRun, "/calibration/alert/ahdc/gains"); } - + + // =========================================================================== + // AHDC track-finding pipeline (preclustering, track finder, DOCA, helix fit) + // Originally lived in AHDCEngine; runs here so AHDCEngine is hits-only. + // Reads AHDC::hits produced by AHDCEngine, mutates Hit.trackId during finding, + // then rewrites AHDC::hits and writes the cluster/track/intercluster banks. + // =========================================================================== + boolean simulation = event.hasBank("MC::Particle"); + + if (event.hasBank("AHDC::hits")) { + + // I) Reconstruct Hit list from AHDC::hits bank + DataBank ahdcHitBank = event.getBank("AHDC::hits"); + ArrayList AHDC_Hits = new ArrayList<>(); + for (int row = 0; row < ahdcHitBank.rows(); row++) { + int id = ahdcHitBank.getShort("id", row); + int superlayer = ahdcHitBank.getByte("superlayer", row); + int layer = ahdcHitBank.getByte("layer", row); + int wire = ahdcHitBank.getInt("wire", row); + int adc = ahdcHitBank.getInt("adc", row); + double doca = ahdcHitBank.getDouble("doca", row); + double time = ahdcHitBank.getDouble("time", row); + double tot = ahdcHitBank.getDouble("timeOverThreshold", row); + Hit hit = new Hit(id, superlayer, layer, wire, doca, adc, time); + hit.setWirePosition(AHDC); + hit.setADC(adc); + hit.setToT(tot); + AHDC_Hits.add(hit); + } + + // II) Track Finding via the strategy selected in init() (ALERT.Mode YAML key). + // The implementation owns its own preclustering, cluster building, and any + // mode-specific safety fallbacks (e.g. AITrackFinder delegates to Distance + // when the hit count exceeds its MAX_HITS_FOR_AI threshold). 
The ATOF bank + // is passed for finders that build joint AHDC+ATOF graphs (GNN); the + // AHDC-only finders inherit the default and ignore it. + DataBank atofHitsBankForGNN = event.hasBank("ATOF::hits") ? event.getBank("ATOF::hits") : null; + TrackFinderResult trackResult = trackFinder.findTracks(AHDC_Hits, atofHitsBankForGNN); + if (!trackResult.isValid()) { + return false; + } + ArrayList AHDC_Tracks = new ArrayList<>(trackResult.getTracks()); + + // III) Preclusters are also written to AHDC::preclusters as a diagnostic bank; + // PreClusterFinder is idempotent on Hit.use, so re-running it here is safe. + PreClusterFinder preclusterfinder = new PreClusterFinder(); + preclusterfinder.findPreclusters(AHDC_Hits); + ArrayList AHDC_PreClusters = preclusterfinder.get_AHDCPreClusters(); + + // IV) Global fit: DOCA refinement + helix fit + int trackid = 0; + ArrayList all_docaClusters = new ArrayList<>(); + AHDC_Tracks.removeIf(track -> track.get_Clusters().size() < 3); + for (Track track : AHDC_Tracks) { + trackid++; + track.set_trackId(trackid); + List originalClusters = track.get_Clusters(); + ArrayList docaClusters = DocaClusterRefiner.buildRefinedClusters(originalClusters); + all_docaClusters.addAll(docaClusters); + if (docaClusters == null || docaClusters.size() < 3 || originalClusters == null || originalClusters.size() < 3) { + // not enough points, skip helix fit + continue; + } + HelixFitJava h = new HelixFitJava(); + track.setPositionAndMomentum(h.helix_fit_with_doca_selection(docaClusters, 1)); + } + + // V) Replace AHDC::hits (now with trackId) and write track-finding output banks + DataBank recoHitsBank = ahdcWriter.fillAHDCHitsBank(event, AHDC_Hits); + DataBank recoPreClusterBank = ahdcWriter.fillPreClustersBank(event, AHDC_PreClusters); + ArrayList AHDC_Clusters = new ArrayList<>(); + for (Track track : AHDC_Tracks) { + AHDC_Clusters.addAll(track.get_Clusters()); + } + DataBank recoClusterBank = ahdcWriter.fillClustersBank(event, AHDC_Clusters); + DataBank
recoTracksBank = ahdcWriter.fillAHDCTrackBank(event, AHDC_Tracks); + DataBank clustersDocaBank = ahdcWriter.fillAHDCDocaClustersBank(event, all_docaClusters); + + ArrayList all_interclusters = new ArrayList<>(); + for (Track track : AHDC_Tracks) { + all_interclusters.addAll(track.getInterclusters()); + } + DataBank recoInterClusterBank = ahdcWriter.fillInterClusterBank(event, all_interclusters); + + event.removeBank("AHDC::hits"); + event.appendBank(recoHitsBank); + event.appendBank(recoPreClusterBank); + event.appendBank(recoClusterBank); + event.appendBank(recoTracksBank); + event.appendBank(recoInterClusterBank); + event.appendBank(clustersDocaBank); + + if (simulation) { + DataBank recoMCBank = ahdcWriter.fillAHDCMCTrackBank(event); + event.appendBank(recoMCBank); + } + } + // =========================================================================== + + // ATOF-dependent processing follows. Bail out for events without ATOF::tdc + // so the AHDC track-finding output above stands on its own (matches the + // pre-refactor flow where AHDCEngine ran independently of ATOF presence). + if (!event.hasBank("ATOF::tdc")) + return false; + //Do we need to read the event vx,vy,vz? //If not, this part can be moved in the initialization of the engine. double eventVx=0,eventVy=0,eventVz=0; //They should be in CM @@ -362,7 +496,7 @@ public boolean processDataEvent(DataEvent event) { // Initialise the position and the momentum using the information of the AHDC::track // position : mm // momentum : MeV - // Invariant: AHDC_hits is non-empty. AHDCEngine's AI_Track_Finding path uses greedy + // Invariant: AHDC_hits is non-empty. The MLP_Track_Finding path (AITrackFinder) uses greedy // non-overlap selection so each PreCluster (and thus each Hit) belongs to at most one // surviving track, so the set_trackId stamping is unambiguous and every AHDC::track // row has matching AHDC::hits rows.
If this invariant ever flips, the get(0) inside diff --git a/reconstruction/alert/src/test/java/org/jlab/service/alert/AHDCTest.java b/reconstruction/alert/src/test/java/org/jlab/service/alert/AHDCTest.java index 0c372b919c..640ab7fea4 100644 --- a/reconstruction/alert/src/test/java/org/jlab/service/alert/AHDCTest.java +++ b/reconstruction/alert/src/test/java/org/jlab/service/alert/AHDCTest.java @@ -7,7 +7,6 @@ import org.jlab.detector.base.DetectorType; import org.jlab.analysis.physics.TestEvent; import org.jlab.service.ahdc.AHDCEngine; -import org.jlab.rec.ahdc.ModeTrackFinding; /** * @@ -23,12 +22,11 @@ public void run() { DataEvent event = TestEvent.get(DetectorType.AHDC); AHDCEngine engine = new AHDCEngine(); - engine.init(ModeTrackFinding.AI_Track_Finding); + engine.init(); engine.processDataEvent(event); event.show(); event.getBank("AHDC::hits").show(); - event.getBank("AHDC::clusters").show(); assertEquals(event.hasBank("FAKE::Bank"), false); assertEquals(event.hasBank("AHDC::wf"), true);