diff --git a/reconstruction/ai/src/main/java/org/jlab/service/ai/DCDenoiseEngine.java b/reconstruction/ai/src/main/java/org/jlab/service/ai/DCDenoiseEngine.java index 8187762ba..45e84159b 100644 --- a/reconstruction/ai/src/main/java/org/jlab/service/ai/DCDenoiseEngine.java +++ b/reconstruction/ai/src/main/java/org/jlab/service/ai/DCDenoiseEngine.java @@ -11,13 +11,16 @@ import ai.djl.training.util.ProgressBar; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; +import ai.djl.translate.Batchifier; import ai.djl.inference.Predictor; import ai.djl.repository.zoo.ModelNotFoundException; -import ai.djl.translate.Batchifier; import ai.djl.translate.TranslateException; + import java.io.IOException; import java.util.concurrent.BlockingQueue; import java.util.concurrent.ArrayBlockingQueue; +import java.util.logging.Level; +import java.util.logging.Logger; import org.jlab.clas.reco.ReconstructionEngine; import org.jlab.io.base.DataBank; @@ -27,28 +30,36 @@ public class DCDenoiseEngine extends ReconstructionEngine { final static String[] BANK_NAMES = {"DC::tot","DC::tdc"}; - final static String CONF_THRESHOLD = "threshold"; + final static String CONF_MODEL_FILE = "modelFile"; + final static String CONF_THRESHOLD = "threshold"; final static String CONF_THREADS = "threads"; + final static int LAYERS = 36; final static int WIRES = 112; + final static int SECTORS= 6; - float threshold = 0.025f; - Criteria criteria; - ZooModel model; + String modelFile = "cnn_autoenc_sector1_2b_48f_4x6k.pt"; + float threshold = 0.03f; + Criteria criteria; + ZooModel model; PredictorPool predictors; - + + // -------- Predictor Pool -------- public static class PredictorPool { - final BlockingQueue pool; - public PredictorPool(int size, ZooModel model) { + final BlockingQueue> pool; + public PredictorPool(int size, ZooModel model) { pool = new ArrayBlockingQueue<>(size); - for (int i=0; i take() throws InterruptedException { return pool.take(); } + public void put(Predictor p) throws InterruptedException { if (p!=null) pool.put(p); } + public void shutdownAll() { for (Predictor p: pool) { try { p.close(); } catch (Exception ignored) {} } } } public DCDenoiseEngine() { @@ -62,173 +73,138 @@ public boolean init() { System.setProperty("ai.djl.pytorch.graph_optimizer", "false"); if (getEngineConfigString(CONF_THRESHOLD) != null) threshold = Float.parseFloat(getEngineConfigString(CONF_THRESHOLD)); + if (getEngineConfigString(CONF_MODEL_FILE) != null) + modelFile = getEngineConfigString(CONF_MODEL_FILE); + try { + String modelPath = ClasUtilsFile.getResourceDir("CLAS12DIR", "etc/data/nnet/dn/" + modelFile); + criteria = Criteria.builder() - .setTypes(float[][].class, float[][].class) - .optModelPath(Paths.get(ClasUtilsFile.getResourceDir("CLAS12DIR","etc/data/nnet/dn/cnn_autoenc_sector1_nBlocks2.pt"))) + .setTypes(float[][][].class, float[][][].class) + .optModelPath(Paths.get(modelPath)) .optEngine("PyTorch") - .optTranslator(DCDenoiseEngine.getTranslator()) + .optTranslator(DCDenoiseEngine.getBatchTranslator()) .optProgress(new ProgressBar()) .build(); + model = criteria.loadModel(); + int threads = Integer.parseInt(getEngineConfigString(CONF_THREADS,"64")); predictors = new PredictorPool(threads, model); return true; } catch (NullPointerException | MalformedModelException | IOException | ModelNotFoundException ex) { - System.getLogger(DCDenoiseEngine.class.getName()).log(System.Logger.Level.ERROR, (String) null, ex); + Logger.getLogger(DCDenoiseEngine.class.getName()).log(Level.SEVERE, null, ex); return false; } } - public static void main(String args[]){ - DCDenoiseEngine dn = new DCDenoiseEngine(); - dn.init(); - for (int i=0; i<10000; i++) { - dn.processFakeEvent(); - } - } - @Override public boolean processDataEvent(DataEvent event) { + for (String bankName : BANK_NAMES) { + if (!event.hasBank(bankName)) continue; - //if (true) return processFakeEvent(); - - for (int i=0; i predictor = predictors.get(); - for (int sector=0; sector<6; sector++) { - float[][] input = DCDenoiseEngine.read(bank, sector+1); - float[][] output = predictor.predict(input); - //System.out.println("IN:");show(input); - //System.out.println("OUT:");show(output); - update(bank, threshold, output, sector); + DataBank bank = event.getBank(bankName); + try { + // Build batch for 6 sectors + float[][][] batchInput = new float[SECTORS][LAYERS][WIRES]; + boolean anySectorPresent = false; + int rows = bank.rows(); + for (int r=0; r SECTORS) continue; + int layer = bank.getByte(1,r); + int wire = bank.getShort(2,r); + byte order = bank.getByte(3,r); + if ((order==0)||(order==10)) { + batchInput[sector-1][layer-1][wire-1]=1.0f; + anySectorPresent = true; } - predictors.put(predictor); - event.removeBank(BANK_NAMES[i]); - event.appendBank(bank); } - catch (TranslateException | InterruptedException e) { - throw new RuntimeException(e); - } - break; - } - } - return true; - } - boolean processFakeEvent() { - try { - Predictor predictor = model.newPredictor(); - float[][] input = getAlmostStraightSlightlyBendingTrack(); - float[][] output = predictor.predict(input); - //System.out.println("IN:");show(input); - //System.out.println("OUT:");show(output); - } - catch (TranslateException e) { - throw new RuntimeException(e); - } - return true; - } - - /** - * Reject sub-threshold hits by modifying the bank's order variable. - * WARNING: This is not a full implementation of OrderType enum and - * all its names, but for now a copy of the subset in C++ DC denoising, see: - * https://code.jlab.org/hallb/clas12/coatjava/denoising/-/blob/main/denoising/code/drift.cc?ref_type=heads#L162-198 - */ - static void update(DataBank b, float threshold, float[][] data, int sector) { - //System.out.println("IN:");b.show(); - for (int row=0; row predictor = predictors.take(); + float[][][] batchOutput; + try { + batchOutput = predictor.predict(batchInput); + } finally { + predictors.put(predictor); + } - /** - * Print all hits for one sector. - */ - static void show(float[][] data) { - System.out.println("Shape: [" + data.length + "," + data[0].length + "]"); - for (int i = 0; i < LAYERS; i++) { - for (int j = 0; j < WIRES; j++) - System.out.printf("%.3f ", data[i][j]); - System.out.println(); - } - } + for (int sectorIdx=0; sectorIdx getTranslator() { - return new Translator() { + // -------- Translator for batch -------- + public static Translator getBatchTranslator() { + return new Translator() { @Override - public NDList processInput(TranslatorContext ctx, float[][] input) throws Exception { + public NDList processInput(TranslatorContext ctx, float[][][] input) { + int batch = input.length; + int height = input[0].length; + int width = input[0][0].length; + float[] flat = new float[batch*height*width]; + int pos=0; + for (int b=0; b [1,1,36,112] - x = x.expandDims(0).expandDims(0); + NDArray x = manager.create(flat, new Shape(batch,1,height,width)); return new NDList(x); } + @Override - public float[][] processOutput(TranslatorContext ctx, NDList list) throws Exception { + public float[][][] processOutput(TranslatorContext ctx, NDList list) { NDArray result = list.get(0); - // Remove batch and channel dims -> [36,112] - result = result.squeeze(); - // Convert to 1D float array - float[] flat = result.toFloatArray(); - // Reshape manually into 2D array long[] shape = result.getShape().getShape(); - int height = (int) shape[0]; - int width = (int) shape[1]; - float[][] output2d = new float[height][width]; - for (int i = 0; i < height; i++) { - System.arraycopy(flat, i * width, output2d[i], 0, width); - } - return output2d; + int batch = (int)shape[0]; + int height, width; + if (shape.length==4 && shape[1]==1) { + height=(int)shape[2]; width=(int)shape[3]; + result = result.squeeze(1); + } else if (shape.length==3) { + height=(int)shape[1]; width=(int)shape[2]; + } else throw new IllegalStateException("Unexpected output shape: "+java.util.Arrays.toString(shape)); + float[] flat = result.toFloatArray(); + float[][][] out = new float[batch][height][width]; + int pos=0; + for (int b=0;b=data.length) continue; + if (wire<0 || wire>=data[0].length) continue; + if (data[layer][wire]