Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions parent/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,19 @@
</repository>
</repositories>


<dependencyManagement>
<dependencies>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>bom</artifactId>
<version>0.30.0</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>

<dependencies>
<!-- https://mvnrepository.com/artifact/junit/junit -->
<dependency>
Expand Down Expand Up @@ -52,12 +65,40 @@
<version>0.30.0</version>
<scope>compile</scope>
</dependency>

<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-model-zoo</artifactId>
<version>0.30.0</version>
</dependency>

<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>0.30.0</version>
<scope>runtime</scope>
</dependency>

<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu</artifactId>
<classifier>linux-x86_64</classifier>
<scope>runtime</scope>
<version>2.4.0</version>
</dependency>

<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-jni</artifactId>
<version>2.4.0-0.30.0</version>
<scope>runtime</scope>
</dependency>

<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
</dependency>

</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package org.jlab.rec.ahdc.AI;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import org.jlab.utils.CLASResources;

import java.io.IOException;
import java.nio.file.Paths;

public class Model {

    /** DJL model handle, loaded once in the constructor and shared read-only. */
    private ZooModel<float[], Float> model;

    /**
     * Loads the AHDC PyTorch model and wires a translator that converts a
     * {@code float[]} feature vector into a 1-D {@link NDArray} input and
     * reads a single {@code Float} score back out of the output list.
     *
     * @throws RuntimeException wrapping the underlying cause if the model
     *         directory cannot be found, read, or parsed
     */
    public Model() {
        Translator<float[], Float> translator = new Translator<float[], Float>() {
            @Override
            public NDList processInput(TranslatorContext ctx, float[] floats) {
                // Use the context-owned manager so the native NDArray is freed
                // with the predictor; creating a fresh base manager per call
                // (the previous code) leaks native memory on every prediction.
                NDManager manager = ctx.getNDManager();
                NDArray samples = manager.create(floats, new Shape(floats.length));
                return new NDList(samples);
            }

            @Override
            public Float processOutput(TranslatorContext ctx, NDList ndList) {
                // Model emits a single scalar prediction.
                return ndList.get(0).getFloat();
            }
        };

        // Resolve the model directory through CLASResources (honors the CLAS12
        // installation layout) instead of a CWD-relative literal, which only
        // worked when the JVM happened to be launched from the install root.
        String path = CLASResources.getResourcePath("etc/nnet/ALERT/model_AHDC/");
        Criteria<float[], Float> criteria = Criteria.builder()
                .setTypes(float[].class, Float.class)
                .optModelPath(Paths.get(path))
                .optEngine("PyTorch")
                .optTranslator(translator)
                .optProgress(new ProgressBar())
                .build();

        try {
            model = criteria.loadModel();
        } catch (IOException | ModelNotFoundException | MalformedModelException e) {
            throw new RuntimeException("Failed to load AHDC AI model from " + path, e);
        }
    }

    /**
     * @return the loaded DJL model; callers create predictors from it
     */
    public ZooModel<float[], Float> getModel() {
        return model;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,21 @@ public ArrayList<PreCluster> find_preclusters_for_AI(List<Hit> AHDC_hits) {
ArrayList<Hit> s5l1 = fill(AHDC_hits, 5, 1);

// Sort hits of each layers by phi:
s1l1.sort(new Comparator<Hit>() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}});
s2l1.sort(new Comparator<Hit>() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}});
s2l2.sort(new Comparator<Hit>() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}});
s3l1.sort(new Comparator<Hit>() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}});
s3l2.sort(new Comparator<Hit>() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}});
s4l1.sort(new Comparator<Hit>() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}});
s4l2.sort(new Comparator<Hit>() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}});
s5l1.sort(new Comparator<Hit>() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}});
Comparator<Hit> comparator = new Comparator<>() {
@Override
public int compare(Hit a1, Hit a2) {
return Double.compare(a1.getPhi(), a2.getPhi());
}
};

s1l1.sort(comparator);
s2l1.sort(comparator);
s2l2.sort(comparator);
s3l1.sort(comparator);
s3l2.sort(comparator);
s4l1.sort(comparator);
s4l2.sort(comparator);
s5l1.sort(comparator);

ArrayList<ArrayList<Hit>> all_super_layer = new ArrayList<>(Arrays.asList(s1l1, s2l1, s2l2, s3l1, s3l2, s4l1, s4l2, s5l1));

Expand All @@ -55,12 +62,19 @@ public ArrayList<PreCluster> find_preclusters_for_AI(List<Hit> AHDC_hits) {
ArrayList<Hit> temp = new ArrayList<>();
temp.add(hit);
hit.setUse(true);
int expected_wire_plus = hit.getWireId() + 1;
int expected_wire_minus = hit.getWireId() - 1;
if (hit.getWireId() == 1)
expected_wire_minus = hit.getNbOfWires();
if (hit.getWireId() == hit.getNbOfWires() )
expected_wire_plus = 1;


boolean has_next = true;
while (has_next) {
has_next = false;
for (Hit hit1 : p) {
if (hit1.is_NoUsed() && (hit1.getWireId() == temp.get(temp.size() - 1).getWireId() + 1 || hit1.getWireId() == temp.get(temp.size() - 1).getWireId() - 1)) {
if (hit1.is_NoUsed() && (hit1.getWireId() == expected_wire_minus || hit1.getWireId() == expected_wire_plus)) {
temp.add(hit1);
hit1.setUse(true);
has_next = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public Track(ArrayList<Hit> hitslist) {
double p = 150.0;//MeV/c
//take first hit.
Hit hit = hitslist.get(0);
double phi = Math.atan2(hit.getX(), hit.getY());
double phi = Math.atan2(hit.getY(), hit.getX());
//hitslist.
this.px0 = p*Math.sin(phi);
this.py0 = p*Math.cos(phi);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,12 @@
package org.jlab.rec.service;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import org.jlab.clas.reco.ReconstructionEngine;
import org.jlab.clas.tracking.kalmanfilter.Material;
import org.jlab.io.base.DataBank;
import org.jlab.io.base.DataEvent;
import org.jlab.io.hipo.HipoDataSource;
import org.jlab.io.hipo.HipoDataSync;
import org.jlab.jnp.hipo4.data.SchemaFactory;
import org.jlab.rec.ahdc.AI.AIPrediction;
import org.jlab.rec.ahdc.AI.PreClustering;
import org.jlab.rec.ahdc.AI.PreclusterSuperlayer;
import org.jlab.rec.ahdc.AI.TrackConstruction;
import org.jlab.rec.ahdc.AI.TrackPrediction;
import org.jlab.rec.ahdc.AI.*;
import org.jlab.rec.ahdc.Banks.RecoBankWriter;
import org.jlab.rec.ahdc.Cluster.Cluster;
import org.jlab.rec.ahdc.Cluster.ClusterFinder;
Expand All @@ -40,8 +23,6 @@
import org.jlab.rec.ahdc.Track.Track;

import java.io.File;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.*;

public class AHDCEngine extends ReconstructionEngine {
Expand All @@ -50,7 +31,7 @@ public class AHDCEngine extends ReconstructionEngine {
private boolean use_AI_for_trackfinding;
private String findingMethod;
private HashMap<String, Material> materialMap;
private ZooModel<float[], Float> model;
private Model model;

public AHDCEngine() {
super("ALERT", "ouillon", "1.0.1");
Expand All @@ -66,35 +47,7 @@ public boolean init() {
materialMap = MaterialMap.generateMaterials();
}

Translator<float[], Float> my_translator = new Translator<float[], Float>() {
@Override
public Float processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception {
return ndList.get(0).getFloat();
}

@Override
public NDList processInput(TranslatorContext translatorContext, float[] floats) throws Exception {
NDManager manager = NDManager.newBaseManager();
NDArray samples = manager.zeros(new Shape(floats.length));
samples.set(floats);
return new NDList(samples);
}
};

Criteria<float[], Float> my_model = Criteria.builder().setTypes(float[].class, Float.class)
.optModelPath(Paths.get(System.getenv("CLAS12DIR") + "/../reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/model/"))
.optEngine("PyTorch")
.optTranslator(my_translator)
.optProgress(new ProgressBar())
.build();


try {
model = my_model.loadModel();
} catch (IOException | ModelNotFoundException | MalformedModelException e) {
throw new RuntimeException(e);
}

model = new Model();

return true;
}
Expand Down Expand Up @@ -180,8 +133,8 @@ public int compare(Hit a1, Hit a2) {

try {
AIPrediction aiPrediction = new AIPrediction();
predictions = aiPrediction.prediction(tracks, model);
} catch (ModelNotFoundException | MalformedModelException | IOException | TranslateException e) {
predictions = aiPrediction.prediction(tracks, model.getModel());
} catch (Exception e) {
throw new RuntimeException(e);
}

Expand Down
Loading