diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/model/model.pt b/etc/nnet/ALERT/model_AHDC/model_AHDC.pt similarity index 100% rename from reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/model/model.pt rename to etc/nnet/ALERT/model_AHDC/model_AHDC.pt diff --git a/parent/pom.xml b/parent/pom.xml index 9de1ce0e22..f75ed4666a 100644 --- a/parent/pom.xml +++ b/parent/pom.xml @@ -18,6 +18,19 @@ + + + + + ai.djl + bom + 0.30.0 + pom + import + + + + @@ -52,12 +65,40 @@ 0.30.0 compile + ai.djl.pytorch pytorch-model-zoo 0.30.0 + + ai.djl.pytorch + pytorch-engine + 0.30.0 + runtime + + + + ai.djl.pytorch + pytorch-native-cpu + linux-x86_64 + runtime + 2.4.0 + + + + ai.djl.pytorch + pytorch-jni + 2.4.0-0.30.0 + runtime + + + + ai.djl + api + + diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/Model.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/Model.java new file mode 100644 index 0000000000..7c9f26c755 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/Model.java @@ -0,0 +1,58 @@ +package org.jlab.rec.ahdc.AI; + +import ai.djl.MalformedModelException; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.Shape; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.training.util.ProgressBar; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorContext; +import org.jlab.utils.CLASResources; + +import java.io.IOException; +import java.nio.file.Paths; + +public class Model { + private ZooModel model; + + public Model() { + Translator my_translator = new Translator() { + @Override + public Float processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception { + return ndList.get(0).getFloat(); + } + + @Override + public NDList processInput(TranslatorContext translatorContext, float[] floats) throws Exception { + NDManager manager = NDManager.newBaseManager(); + NDArray samples = manager.zeros(new Shape(floats.length)); + samples.set(floats); + return new NDList(samples); + } + }; + + String path = CLASResources.getResourcePath("etc/nnet/ALERT/model_AHDC/"); + Criteria my_model = Criteria.builder().setTypes(float[].class, Float.class) + .optModelPath(Paths.get("etc/nnet/ALERT/model_AHDC/")) + .optEngine("PyTorch") + .optTranslator(my_translator) + .optProgress(new ProgressBar()) + .build(); + + + try { + model = my_model.loadModel(); + } catch (IOException | ModelNotFoundException | MalformedModelException e) { + throw new RuntimeException(e); + } + + } + + public ZooModel getModel() { + return model; + } +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreClustering.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreClustering.java index 17cb3f365f..ad88fc1f88 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreClustering.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreClustering.java @@ -32,14 +32,21 @@ public ArrayList find_preclusters_for_AI(List AHDC_hits) { ArrayList s5l1 = fill(AHDC_hits, 5, 1); // Sort hits of each layers by phi: - s1l1.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}}); - s2l1.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}}); - s2l2.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}}); - s3l1.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}}); - s3l2.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}}); - s4l1.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}}); - s4l2.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}}); - s5l1.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}}); + Comparator comparator = new Comparator<>() { + @Override + public int compare(Hit a1, Hit a2) { + return Double.compare(a1.getPhi(), a2.getPhi()); + } + }; + + s1l1.sort(comparator); + s2l1.sort(comparator); + s2l2.sort(comparator); + s3l1.sort(comparator); + s3l2.sort(comparator); + s4l1.sort(comparator); + s4l2.sort(comparator); + s5l1.sort(comparator); ArrayList> all_super_layer = new ArrayList<>(Arrays.asList(s1l1, s2l1, s2l2, s3l1, s3l2, s4l1, s4l2, s5l1)); @@ -55,12 +62,19 @@ public ArrayList find_preclusters_for_AI(List AHDC_hits) { ArrayList temp = new ArrayList<>(); temp.add(hit); hit.setUse(true); + int expected_wire_plus = hit.getWireId() + 1; + int expected_wire_minus = hit.getWireId() - 1; + if (hit.getWireId() == 1) + expected_wire_minus = hit.getNbOfWires(); + if (hit.getWireId() == hit.getNbOfWires() ) + expected_wire_plus = 1; + boolean has_next = true; while (has_next) { has_next = false; for (Hit hit1 : p) { - if (hit1.is_NoUsed() && (hit1.getWireId() == temp.get(temp.size() - 1).getWireId() + 1 || hit1.getWireId() == temp.get(temp.size() - 1).getWireId() - 1)) { + if (hit1.is_NoUsed() && (hit1.getWireId() == expected_wire_minus || hit1.getWireId() == expected_wire_plus)) { temp.add(hit1); hit1.setUse(true); has_next = true; diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/Track.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/Track.java index 10d2351ab6..8a968ba512 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/Track.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/Track.java @@ -48,7 +48,7 @@ public Track(ArrayList hitslist) { double p = 150.0;//MeV/c //take first hit. Hit hit = hitslist.get(0); - double phi = Math.atan2(hit.getX(), hit.getY()); + double phi = Math.atan2(hit.getY(), hit.getX()); //hitslist. this.px0 = p*Math.sin(phi); this.py0 = p*Math.cos(phi); diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java b/reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java index ed27d104e9..7eb6570075 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java @@ -1,29 +1,12 @@ package org.jlab.rec.service; -import ai.djl.MalformedModelException; -import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.NDList; -import ai.djl.ndarray.NDManager; -import ai.djl.ndarray.types.Shape; -import ai.djl.repository.zoo.Criteria; -import ai.djl.repository.zoo.ModelNotFoundException; -import ai.djl.repository.zoo.ZooModel; -import ai.djl.training.util.ProgressBar; -import ai.djl.translate.TranslateException; -import ai.djl.translate.Translator; -import ai.djl.translate.TranslatorContext; import org.jlab.clas.reco.ReconstructionEngine; import org.jlab.clas.tracking.kalmanfilter.Material; import org.jlab.io.base.DataBank; import org.jlab.io.base.DataEvent; import org.jlab.io.hipo.HipoDataSource; import org.jlab.io.hipo.HipoDataSync; -import org.jlab.jnp.hipo4.data.SchemaFactory; -import org.jlab.rec.ahdc.AI.AIPrediction; -import org.jlab.rec.ahdc.AI.PreClustering; -import org.jlab.rec.ahdc.AI.PreclusterSuperlayer; -import org.jlab.rec.ahdc.AI.TrackConstruction; -import org.jlab.rec.ahdc.AI.TrackPrediction; +import org.jlab.rec.ahdc.AI.*; import org.jlab.rec.ahdc.Banks.RecoBankWriter; import org.jlab.rec.ahdc.Cluster.Cluster; import org.jlab.rec.ahdc.Cluster.ClusterFinder; @@ -40,8 +23,6 @@ import org.jlab.rec.ahdc.Track.Track; import java.io.File; -import java.io.IOException; -import java.nio.file.Paths; import java.util.*; public class AHDCEngine extends ReconstructionEngine { @@ -50,7 +31,7 @@ public class AHDCEngine extends ReconstructionEngine { private boolean use_AI_for_trackfinding; private String findingMethod; private HashMap materialMap; - private ZooModel model; + private Model model; public AHDCEngine() { super("ALERT", "ouillon", "1.0.1"); @@ -66,35 +47,7 @@ public boolean init() { materialMap = MaterialMap.generateMaterials(); } - Translator my_translator = new Translator() { - @Override - public Float processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception { - return ndList.get(0).getFloat(); - } - - @Override - public NDList processInput(TranslatorContext translatorContext, float[] floats) throws Exception { - NDManager manager = NDManager.newBaseManager(); - NDArray samples = manager.zeros(new Shape(floats.length)); - samples.set(floats); - return new NDList(samples); - } - }; - - Criteria my_model = Criteria.builder().setTypes(float[].class, Float.class) - .optModelPath(Paths.get(System.getenv("CLAS12DIR") + "/../reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/model/")) - .optEngine("PyTorch") - .optTranslator(my_translator) - .optProgress(new ProgressBar()) - .build(); - - - try { - model = my_model.loadModel(); - } catch (IOException | ModelNotFoundException | MalformedModelException e) { - throw new RuntimeException(e); - } - + model = new Model(); return true; } @@ -180,8 +133,8 @@ public int compare(Hit a1, Hit a2) { try { AIPrediction aiPrediction = new AIPrediction(); - predictions = aiPrediction.prediction(tracks, model); - } catch (ModelNotFoundException | MalformedModelException | IOException | TranslateException e) { + predictions = aiPrediction.prediction(tracks, model.getModel()); + } catch (Exception e) { throw new RuntimeException(e); }