diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/Model.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/Model.java index a558763d78..3f196db93c 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/Model.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/Model.java @@ -34,6 +34,9 @@ public NDList processInput(TranslatorContext translatorContext, float[] floats) return new NDList(samples); } }; + System.setProperty("ai.djl.pytorch.num_interop_threads", "1"); + System.setProperty("ai.djl.pytorch.num_threads", "1"); + System.setProperty("ai.djl.pytorch.graph_optimizer", "false"); String path = CLASResources.getResourcePath("etc/nnet/ALERT/model_AHDC/"); Criteria my_model = Criteria.builder().setTypes(float[].class, Float.class) diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreclusterSuperlayer.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreclusterSuperlayer.java index ecab32728c..f1773e73e9 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreclusterSuperlayer.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreclusterSuperlayer.java @@ -39,6 +39,10 @@ public double getY() { return y; } + public int getSuperlayer() { + return this.preclusters.get(0).get_Super_layer(); + } + public String toString() { return "PreCluster{" + "X: " + this.x + " Y: " + this.y + " phi: " + Math.atan2(this.y, this.x) + "}\n"; diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/TrackConstruction.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/TrackConstruction.java index dce5414001..d88ee4a661 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/TrackConstruction.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/TrackConstruction.java @@ -1,5 +1,6 @@ package org.jlab.rec.ahdc.AI; +import org.apache.commons.lang3.mutable.MutableBoolean; import org.jlab.rec.ahdc.Hit.Hit; import java.io.File; @@ -7,9 +8,31 @@ import java.io.IOException; import java.util.*; +/** + * The TrackConstruction class is responsible for constructing all possible track + * candidates from a set of superpreclusters. + */ public class TrackConstruction { + private int max_number_of_track_candidates = 10000; + private double max_angle = Math.toRadians(60); + + /** + * Default constructor. + */ public TrackConstruction() {} + /** + * Computes the modulo operation, which returns the remainder of the division + * of one number by another. This method handles floating-point edge cases + * to ensure accurate results within the expected range. + * + * @param x The dividend. + * @param y The divisor. If y is 0, the method returns x. + * @return The result of x modulo y. The result is in the range: + * - [0..y) if y > 0 + * - (y..0] if y < 0 + * Special cases are handled to avoid floating-point inaccuracies. + */ private double mod(double x, double y) { if (0. == y) return x; @@ -33,74 +56,135 @@ private double mod(double x, double y) { return m; } + + /** + * Wraps an angle to the range [0, 2π). + * + * @param angle The angle to wrap. + * @return The angle wrapped to the range [0, 2π). + */ private double warp_zero_two_pi(double angle) { return mod(angle, 2. * Math.PI); } + /** + * Checks if an angle is within a specified range. + * + * @param angle The angle to check. + * @param lower The lower bound of the range. + * @param upper The upper bound of the range. + * @return {@code true} if the angle is within the range, {@code false} otherwise. + */ private boolean angle_in_range(double angle, double lower, double upper) { return warp_zero_two_pi(angle - lower) <= warp_zero_two_pi(upper - lower); } + /** + * Computes the Cartesian product of two lists of integers, ensuring the number of track candidates + * does not exceed the maximum allowed limit. + * + * @param v1 The first list of integer combinations. + * @param v2 The second list of integers to combine with the first list. + * @param too_much_track_candidates A mutable boolean that is set to {@code true} if the number of track candidates exceeds the maximum limit. + * @param number_of_track_candidates The current count of track candidates. + * @return A list of all possible combinations of integers from {@code v1} and {@code v2}. + */ + private ArrayList> cartesian_product(ArrayList> v1, ArrayList v2, MutableBoolean too_much_track_candidates, int number_of_track_candidates) { + ArrayList> result = new ArrayList<>(); + for (ArrayList i : v1) { + if (too_much_track_candidates.booleanValue()) break; + for (int j : v2) { + if (too_much_track_candidates.booleanValue()) break; + ArrayList n = new ArrayList<>(i); + n.add(j); + result.add(n); + + if (number_of_track_candidates + result.size() >= max_number_of_track_candidates) { + too_much_track_candidates.setValue(true); + break; + } + } + - public ArrayList> get_all_possible_track(ArrayList preclusterSuperlayers) { - - // Get seeds to start the track finding algorithm - ArrayList seeds = new ArrayList<>(); - for (PreclusterSuperlayer precluster : preclusterSuperlayers) { - if (precluster.getPreclusters().get(0).get_hits_list().get(0).getSuperLayerId() == 1) seeds.add(precluster); } - seeds.sort(new Comparator() { - @Override - public int compare(PreclusterSuperlayer a1, PreclusterSuperlayer a2) { - return Double.compare(Math.atan2(a1.getY(), a1.getX()), Math.atan2(a2.getY(), a2.getX())); + return result; + } + + public boolean get_all_possible_track(ArrayList preclusterSuperlayers, ArrayList> all_track_candidates) { + + /* + Identify all superpreclusters located in the first superlayer. + These superpreclusters serve as seeds for constructing track candidates. + A track candidate always starts from a seed. + */ + ArrayList seed_index = new ArrayList<>(); + for (int i = 0; i < preclusterSuperlayers.size(); i++) { + if (!preclusterSuperlayers.get(i).getPreclusters().isEmpty() && + preclusterSuperlayers.get(i).getSuperlayer() == 1) { + seed_index.add(i); } - }); - // System.out.println("seeds: " + seeds); + } - // Get all possible tracks ---------------------------------------------------------------- - double max_angle = Math.toRadians(60); - ArrayList> all_combinations = new ArrayList<>(); - for (PreclusterSuperlayer seed : seeds) { - double phi_seed = warp_zero_two_pi(Math.atan2(seed.getY(), seed.getX())); + boolean sucess = true; + int number_of_track_candidates = 0; - ArrayList track = new ArrayList<>(); - for (PreclusterSuperlayer p : preclusterSuperlayers) { - double phi_p = warp_zero_two_pi(Math.atan2(p.getY(), p.getX())); - if (angle_in_range(phi_p, phi_seed - max_angle, phi_seed + max_angle)) track.add(p); - } - // System.out.println("track: " + track.size()); - - ArrayList> combinations = new ArrayList<>(List.of(new ArrayList<>(List.of(seed)))); - // System.out.println("combinations: " + combinations); - - for (int i = 1; i < 5; ++i) { - ArrayList> new_combinations = new ArrayList<>(); - for (ArrayList combination : combinations) { - - for (PreclusterSuperlayer precluster : track) { - if (precluster.getPreclusters().get(0).get_hits_list().get(0).getSuperLayerId() == seed.getPreclusters().get(0).get_hits_list().get(0).getSuperLayerId() + i) { - // System.out.printf("Good Precluster x: %.2f, y: %.2f, r: %.2f%n", precluster.getX(), precluster.getY(), Math.hypot(precluster.getX(), precluster.getY())); - // System.out.println("combination: " + combination); - - ArrayList new_combination = new ArrayList<>(combination); - new_combination.add(precluster); - // System.out.println("new_combination: " + new_combination); - new_combinations.add(new_combination); - } - } - for (ArrayList c : new_combinations) { - // System.out.println("c.size: " + c.size() + ", c: " + c); - } + // Loop over all seeds to construct track candidates + for (int s : seed_index) { + // Check if the number of track candidates exceeds the maximum limit if so, stop the loop + if (!sucess) break; + // Find all superpreclusters that have a phi angle within phi angle of the seed +/- 60 degrees + // The goal is to reduce the number of superpreclusters to loop over + double phi_seed = warp_zero_two_pi(Math.atan2(preclusterSuperlayers.get(s).getY(), preclusterSuperlayers.get(s).getX())); // phi angle of the seed + ArrayList all_superpreclusters = new ArrayList<>(); // all superpreclusters that are within phi angle of the seed + for (int i = 0; i < preclusterSuperlayers.size(); ++i) { + double phi_p = warp_zero_two_pi(Math.atan2(preclusterSuperlayers.get(i).getY(), preclusterSuperlayers.get(i).getX())); + if (angle_in_range(phi_p, phi_seed - max_angle, phi_seed + max_angle)) { + all_superpreclusters.add(i); } - combinations = new_combinations; - if (combinations.size() > 10000) break; } - for (ArrayList combination : combinations) { - if (combination.size() == 5) { - all_combinations.add(combination); + + + // Sort the superpreclusters by superlayer to have a simpler loops after + ArrayList superpreclusters_s1 = new ArrayList<>(List.of(s)); + ArrayList superpreclusters_s3 = new ArrayList<>(); + ArrayList superpreclusters_s4 = new ArrayList<>(); + ArrayList superpreclusters_s2 = new ArrayList<>(); + ArrayList superpreclusters_s5 = new ArrayList<>(); + + for (int i = 0; i < all_superpreclusters.size(); i++) { + if (preclusterSuperlayers.get(all_superpreclusters.get(i)).getPreclusters().get(0).get_Super_layer() == 2) + superpreclusters_s2.add(all_superpreclusters.get(i)); + else if (preclusterSuperlayers.get(all_superpreclusters.get(i)).getPreclusters().get(0).get_Super_layer() == 3) + superpreclusters_s3.add(all_superpreclusters.get(i)); + else if (preclusterSuperlayers.get(all_superpreclusters.get(i)).getPreclusters().get(0).get_Super_layer() == 4) + superpreclusters_s4.add(all_superpreclusters.get(i)); + else if (preclusterSuperlayers.get(all_superpreclusters.get(i)).getPreclusters().get(0).get_Super_layer() == 5) + superpreclusters_s5.add(all_superpreclusters.get(i)); + } + + MutableBoolean too_much_track_candidates = new MutableBoolean(); // Need to be a mutable boolean to be able to change it in the cartesian_product method + too_much_track_candidates.setFalse(); + + // Find all possible combinations of superpreclusters on different superlayers + ArrayList> combinations_s1_s2 = cartesian_product(new ArrayList<>(List.of(superpreclusters_s1)), superpreclusters_s2, too_much_track_candidates, number_of_track_candidates); + ArrayList> combinations_s1_s2_s3 = cartesian_product(combinations_s1_s2, superpreclusters_s3, too_much_track_candidates, number_of_track_candidates); + ArrayList> combinations_s1_s2_s3_s4 = cartesian_product(combinations_s1_s2_s3, superpreclusters_s4, too_much_track_candidates, number_of_track_candidates); + ArrayList> combinations_s1_s2_s3_s4_s5 = cartesian_product(combinations_s1_s2_s3_s4, superpreclusters_s5, too_much_track_candidates, number_of_track_candidates); + + // Keep track of the number of track candidates + number_of_track_candidates += combinations_s1_s2_s3_s4_s5.size(); + if (too_much_track_candidates.booleanValue()) sucess = false; // If the number of track candidates exceeds the maximum limit, set success to false + + // Add all track candidates to the list of all track candidates + // And switch back from index to superprecluster + for (ArrayList combination : combinations_s1_s2_s3_s4_s5) { + ArrayList track_candidate = new ArrayList<>(); + for (int index : combination) { + track_candidate.add(preclusterSuperlayers.get(index)); } + all_track_candidates.add(track_candidate); } } - return all_combinations; + return sucess; } } diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Mode.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Mode.java new file mode 100644 index 0000000000..1278ddae4c --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Mode.java @@ -0,0 +1,5 @@ +package org.jlab.rec.ahdc; + +public enum Mode { + AI_Track_Finding, CV_Track_Finding; +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java b/reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java index 5dab3df8f9..9c118a09a2 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java @@ -21,6 +21,7 @@ import org.jlab.rec.ahdc.PreCluster.PreCluster; import org.jlab.rec.ahdc.PreCluster.PreClusterFinder; import org.jlab.rec.ahdc.Track.Track; +import org.jlab.rec.ahdc.Mode; import java.io.File; import java.util.*; @@ -28,11 +29,12 @@ public class AHDCEngine extends ReconstructionEngine { private boolean simulation; - private boolean use_AI_for_trackfinding; private String findingMethod; private HashMap materialMap; private Model model; + private Mode mode = Mode.CV_Track_Finding; + public AHDCEngine() { super("ALERT", "ouillon", "1.0.1"); } @@ -41,13 +43,23 @@ public AHDCEngine() { public boolean init() { simulation = false; findingMethod = "distance"; - use_AI_for_trackfinding = true; if (materialMap == null) { materialMap = MaterialMap.generateMaterials(); } - model = new Model(); + if(this.getEngineConfigString("Mode")!=null) { + if (Objects.equals(this.getEngineConfigString("Mode"), Mode.AI_Track_Finding.name())) + mode = Mode.AI_Track_Finding; + + if (Objects.equals(this.getEngineConfigString("Mode"), Mode.CV_Track_Finding.name())) + mode = Mode.CV_Track_Finding; + + } + + if (mode == Mode.AI_Track_Finding) { + model = new Model(); + } return true; } @@ -89,8 +101,6 @@ public boolean processDataEvent(DataEvent event) { AHDC_PreClusters = preclusterfinder.get_AHDCPreClusters(); //System.out.println("AHDC_PreClusters size " + AHDC_PreClusters.size()); - - // III) Create Cluster ClusterFinder clusterfinder = new ClusterFinder(); clusterfinder.findCluster(AHDC_PreClusters); @@ -101,7 +111,10 @@ public boolean processDataEvent(DataEvent event) { ArrayList AHDC_Tracks = new ArrayList<>(); ArrayList predictions = new ArrayList<>(); - if (use_AI_for_trackfinding == false) { + // If there is too much hits, we rely on to the conventional track finding + if (AHDC_Hits.size() > 300) mode = Mode.CV_Track_Finding; + + if (mode == Mode.CV_Track_Finding) { if (findingMethod.equals("distance")) { // IV) a) Distance method //System.out.println("using distance"); @@ -116,7 +129,7 @@ public boolean processDataEvent(DataEvent event) { AHDC_Tracks = houghtransform.get_AHDCTracks(); } } - else { + if (mode == Mode.AI_Track_Finding) { // AI --------------------------------------------------------------------------------- AHDC_Hits.sort(new Comparator() { @Override @@ -128,8 +141,13 @@ public int compare(Hit a1, Hit a2) { ArrayList preClustersAI = preClustering.find_preclusters_for_AI(AHDC_Hits); ArrayList preclusterSuperlayers = preClustering.merge_preclusters(preClustersAI); TrackConstruction trackConstruction = new TrackConstruction(); - ArrayList> tracks = trackConstruction.get_all_possible_track(preclusterSuperlayers); + ArrayList> tracks = new ArrayList<>(); + boolean sucess = trackConstruction.get_all_possible_track(preclusterSuperlayers, tracks); + if (!sucess) { + System.err.println("Too much tracks candidates, exit"); + return false; + } try { AIPrediction aiPrediction = new AIPrediction(); @@ -139,7 +157,7 @@ public int compare(Hit a1, Hit a2) { } for (TrackPrediction t : predictions) { - if (t.getPrediction() > 0.5) + if (t.getPrediction() > 0.2) AHDC_Tracks.add(new Track(t.getClusters())); } } @@ -148,7 +166,7 @@ public int compare(Hit a1, Hit a2) { //Temporary track method ONLY for MC with no background; //AHDC_Tracks.add(new Track(AHDC_Hits)); - + // V) Global fit for (Track track : AHDC_Tracks) { int nbOfPoints = track.get_Clusters().size(); @@ -170,7 +188,7 @@ public int compare(Hit a1, Hit a2) { // VI) Kalman Filter // System.out.println("AHDC_Tracks = " + AHDC_Tracks); KalmanFilter kalmanFitter = new KalmanFilter(AHDC_Tracks, event); - + // VII) Write bank RecoBankWriter writer = new RecoBankWriter(); @@ -192,6 +210,7 @@ public int compare(Hit a1, Hit a2) { DataBank recoMCBank = writer.fillAHDCMCTrackBank(event); event.appendBank(recoMCBank); } + } return true; @@ -202,9 +221,9 @@ public static void main(String[] args) { double starttime = System.nanoTime(); int nEvent = 0; - int maxEvent = 1000; + int maxEvent = 10; int myEvent = 3; - String inputFile = "alert_out_update.hipo"; + String inputFile = "merged_10.hipo"; String outputFile = "output.hipo"; if (new File(outputFile).delete()) System.out.println("output.hipo is delete.");