diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/model/model.pt b/etc/nnet/ALERT/model_AHDC/model_AHDC.pt
similarity index 100%
rename from reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/model/model.pt
rename to etc/nnet/ALERT/model_AHDC/model_AHDC.pt
diff --git a/parent/pom.xml b/parent/pom.xml
index 9de1ce0e22..f75ed4666a 100644
--- a/parent/pom.xml
+++ b/parent/pom.xml
@@ -18,6 +18,19 @@
+
+
+
+
+ ai.djl
+ bom
+ 0.30.0
+ pom
+ import
+
+
+
+
@@ -52,12 +65,40 @@
0.30.0
compile
+
ai.djl.pytorch
pytorch-model-zoo
0.30.0
+
+ ai.djl.pytorch
+ pytorch-engine
+ 0.30.0
+ runtime
+
+
+
+ ai.djl.pytorch
+ pytorch-native-cpu
+ linux-x86_64
+ runtime
+ 2.4.0
+
+
+
+ ai.djl.pytorch
+ pytorch-jni
+ 2.4.0-0.30.0
+ runtime
+
+
+
+ ai.djl
+ api
+
+
diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/Model.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/Model.java
new file mode 100644
index 0000000000..7c9f26c755
--- /dev/null
+++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/Model.java
@@ -0,0 +1,58 @@
+package org.jlab.rec.ahdc.AI;
+
+import ai.djl.MalformedModelException;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.types.Shape;
+import ai.djl.repository.zoo.Criteria;
+import ai.djl.repository.zoo.ModelNotFoundException;
+import ai.djl.repository.zoo.ZooModel;
+import ai.djl.training.util.ProgressBar;
+import ai.djl.translate.Translator;
+import ai.djl.translate.TranslatorContext;
+import org.jlab.utils.CLASResources;
+
+import java.io.IOException;
+import java.nio.file.Paths;
+
+public class Model {
+ private ZooModel model;
+
+ public Model() {
+ Translator my_translator = new Translator() {
+ @Override
+ public Float processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception {
+ return ndList.get(0).getFloat();
+ }
+
+ @Override
+ public NDList processInput(TranslatorContext translatorContext, float[] floats) throws Exception {
+ NDManager manager = NDManager.newBaseManager();
+ NDArray samples = manager.zeros(new Shape(floats.length));
+ samples.set(floats);
+ return new NDList(samples);
+ }
+ };
+
+ String path = CLASResources.getResourcePath("etc/nnet/ALERT/model_AHDC/");
+ Criteria my_model = Criteria.builder().setTypes(float[].class, Float.class)
+ .optModelPath(Paths.get("etc/nnet/ALERT/model_AHDC/"))
+ .optEngine("PyTorch")
+ .optTranslator(my_translator)
+ .optProgress(new ProgressBar())
+ .build();
+
+
+ try {
+ model = my_model.loadModel();
+ } catch (IOException | ModelNotFoundException | MalformedModelException e) {
+ throw new RuntimeException(e);
+ }
+
+ }
+
+ public ZooModel getModel() {
+ return model;
+ }
+}
diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreClustering.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreClustering.java
index 17cb3f365f..ad88fc1f88 100644
--- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreClustering.java
+++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreClustering.java
@@ -32,14 +32,21 @@ public ArrayList find_preclusters_for_AI(List AHDC_hits) {
ArrayList s5l1 = fill(AHDC_hits, 5, 1);
// Sort hits of each layers by phi:
- s1l1.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}});
- s2l1.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}});
- s2l2.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}});
- s3l1.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}});
- s3l2.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}});
- s4l1.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}});
- s4l2.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}});
- s5l1.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}});
+ Comparator comparator = new Comparator<>() {
+ @Override
+ public int compare(Hit a1, Hit a2) {
+ return Double.compare(a1.getPhi(), a2.getPhi());
+ }
+ };
+
+ s1l1.sort(comparator);
+ s2l1.sort(comparator);
+ s2l2.sort(comparator);
+ s3l1.sort(comparator);
+ s3l2.sort(comparator);
+ s4l1.sort(comparator);
+ s4l2.sort(comparator);
+ s5l1.sort(comparator);
ArrayList> all_super_layer = new ArrayList<>(Arrays.asList(s1l1, s2l1, s2l2, s3l1, s3l2, s4l1, s4l2, s5l1));
@@ -55,12 +62,19 @@ public ArrayList find_preclusters_for_AI(List AHDC_hits) {
ArrayList temp = new ArrayList<>();
temp.add(hit);
hit.setUse(true);
+ int expected_wire_plus = hit.getWireId() + 1;
+ int expected_wire_minus = hit.getWireId() - 1;
+ if (hit.getWireId() == 1)
+ expected_wire_minus = hit.getNbOfWires();
+ if (hit.getWireId() == hit.getNbOfWires() )
+ expected_wire_plus = 1;
+
boolean has_next = true;
while (has_next) {
has_next = false;
for (Hit hit1 : p) {
- if (hit1.is_NoUsed() && (hit1.getWireId() == temp.get(temp.size() - 1).getWireId() + 1 || hit1.getWireId() == temp.get(temp.size() - 1).getWireId() - 1)) {
+ if (hit1.is_NoUsed() && (hit1.getWireId() == expected_wire_minus || hit1.getWireId() == expected_wire_plus)) {
temp.add(hit1);
hit1.setUse(true);
has_next = true;
diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/Track.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/Track.java
index 10d2351ab6..8a968ba512 100644
--- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/Track.java
+++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/Track.java
@@ -48,7 +48,7 @@ public Track(ArrayList hitslist) {
double p = 150.0;//MeV/c
//take first hit.
Hit hit = hitslist.get(0);
- double phi = Math.atan2(hit.getX(), hit.getY());
+ double phi = Math.atan2(hit.getY(), hit.getX());
//hitslist.
this.px0 = p*Math.sin(phi);
this.py0 = p*Math.cos(phi);
diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java b/reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java
index ed27d104e9..7eb6570075 100644
--- a/reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java
+++ b/reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java
@@ -1,29 +1,12 @@
package org.jlab.rec.service;
-import ai.djl.MalformedModelException;
-import ai.djl.ndarray.NDArray;
-import ai.djl.ndarray.NDList;
-import ai.djl.ndarray.NDManager;
-import ai.djl.ndarray.types.Shape;
-import ai.djl.repository.zoo.Criteria;
-import ai.djl.repository.zoo.ModelNotFoundException;
-import ai.djl.repository.zoo.ZooModel;
-import ai.djl.training.util.ProgressBar;
-import ai.djl.translate.TranslateException;
-import ai.djl.translate.Translator;
-import ai.djl.translate.TranslatorContext;
import org.jlab.clas.reco.ReconstructionEngine;
import org.jlab.clas.tracking.kalmanfilter.Material;
import org.jlab.io.base.DataBank;
import org.jlab.io.base.DataEvent;
import org.jlab.io.hipo.HipoDataSource;
import org.jlab.io.hipo.HipoDataSync;
-import org.jlab.jnp.hipo4.data.SchemaFactory;
-import org.jlab.rec.ahdc.AI.AIPrediction;
-import org.jlab.rec.ahdc.AI.PreClustering;
-import org.jlab.rec.ahdc.AI.PreclusterSuperlayer;
-import org.jlab.rec.ahdc.AI.TrackConstruction;
-import org.jlab.rec.ahdc.AI.TrackPrediction;
+import org.jlab.rec.ahdc.AI.*;
import org.jlab.rec.ahdc.Banks.RecoBankWriter;
import org.jlab.rec.ahdc.Cluster.Cluster;
import org.jlab.rec.ahdc.Cluster.ClusterFinder;
@@ -40,8 +23,6 @@
import org.jlab.rec.ahdc.Track.Track;
import java.io.File;
-import java.io.IOException;
-import java.nio.file.Paths;
import java.util.*;
public class AHDCEngine extends ReconstructionEngine {
@@ -50,7 +31,7 @@ public class AHDCEngine extends ReconstructionEngine {
private boolean use_AI_for_trackfinding;
private String findingMethod;
private HashMap materialMap;
- private ZooModel model;
+ private Model model;
public AHDCEngine() {
super("ALERT", "ouillon", "1.0.1");
@@ -66,35 +47,7 @@ public boolean init() {
materialMap = MaterialMap.generateMaterials();
}
- Translator my_translator = new Translator() {
- @Override
- public Float processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception {
- return ndList.get(0).getFloat();
- }
-
- @Override
- public NDList processInput(TranslatorContext translatorContext, float[] floats) throws Exception {
- NDManager manager = NDManager.newBaseManager();
- NDArray samples = manager.zeros(new Shape(floats.length));
- samples.set(floats);
- return new NDList(samples);
- }
- };
-
- Criteria my_model = Criteria.builder().setTypes(float[].class, Float.class)
- .optModelPath(Paths.get(System.getenv("CLAS12DIR") + "/../reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/model/"))
- .optEngine("PyTorch")
- .optTranslator(my_translator)
- .optProgress(new ProgressBar())
- .build();
-
-
- try {
- model = my_model.loadModel();
- } catch (IOException | ModelNotFoundException | MalformedModelException e) {
- throw new RuntimeException(e);
- }
-
+ model = new Model();
return true;
}
@@ -180,8 +133,8 @@ public int compare(Hit a1, Hit a2) {
try {
AIPrediction aiPrediction = new AIPrediction();
- predictions = aiPrediction.prediction(tracks, model);
- } catch (ModelNotFoundException | MalformedModelException | IOException | TranslateException e) {
+ predictions = aiPrediction.prediction(tracks, model.getModel());
+ } catch (Exception e) {
throw new RuntimeException(e);
}