Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.translate.Batchifier;
import ai.djl.inference.Predictor;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslateException;

import java.io.IOException;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.logging.Level;
import java.util.logging.Logger;

import org.jlab.clas.reco.ReconstructionEngine;
import org.jlab.io.base.DataBank;
Expand All @@ -27,28 +30,36 @@
public class DCDenoiseEngine extends ReconstructionEngine {

final static String[] BANK_NAMES = {"DC::tot","DC::tdc"};
final static String CONF_THRESHOLD = "threshold";
final static String CONF_MODEL_FILE = "modelFile";
final static String CONF_THRESHOLD = "threshold";
final static String CONF_THREADS = "threads";

final static int LAYERS = 36;
final static int WIRES = 112;
final static int SECTORS= 6;

float threshold = 0.025f;
Criteria<float[][],float[][]> criteria;
ZooModel<float[][], float[][]> model;
String modelFile = "cnn_autoenc_sector1_2b_48f_4x6k.pt";
float threshold = 0.03f;
Criteria<float[][][], float[][][]> criteria;
ZooModel<float[][][], float[][][]> model;
PredictorPool predictors;


// -------- Predictor Pool --------
public static class PredictorPool {
final BlockingQueue<Predictor> pool;
public PredictorPool(int size, ZooModel model) {
final BlockingQueue<Predictor<float[][][], float[][][]>> pool;
public PredictorPool(int size, ZooModel<float[][][], float[][][]> model) {
pool = new ArrayBlockingQueue<>(size);
for (int i=0; i<size; i++) pool.add(model.newPredictor());
}
public Predictor get() throws InterruptedException {
return pool.poll();
}
public void put(Predictor p) {
if (p != null) pool.offer(p);
for (int i=0; i<size; i++) {
try {
pool.add(model.newPredictor());
} catch (Exception e) {
Logger.getLogger(PredictorPool.class.getName()).log(Level.WARNING, "Failed to create predictor", e);
}
}
}
public Predictor<float[][][], float[][][]> take() throws InterruptedException { return pool.take(); }
public void put(Predictor<float[][][], float[][][]> p) throws InterruptedException { if (p!=null) pool.put(p); }
public void shutdownAll() { for (Predictor p: pool) { try { p.close(); } catch (Exception ignored) {} } }
}

public DCDenoiseEngine() {
Expand All @@ -62,173 +73,138 @@ public boolean init() {
System.setProperty("ai.djl.pytorch.graph_optimizer", "false");
if (getEngineConfigString(CONF_THRESHOLD) != null)
threshold = Float.parseFloat(getEngineConfigString(CONF_THRESHOLD));
if (getEngineConfigString(CONF_MODEL_FILE) != null)
modelFile = getEngineConfigString(CONF_MODEL_FILE);

try {
String modelPath = ClasUtilsFile.getResourceDir("CLAS12DIR", "etc/data/nnet/dn/" + modelFile);

criteria = Criteria.builder()
.setTypes(float[][].class, float[][].class)
.optModelPath(Paths.get(ClasUtilsFile.getResourceDir("CLAS12DIR","etc/data/nnet/dn/cnn_autoenc_sector1_nBlocks2.pt")))
.setTypes(float[][][].class, float[][][].class)
.optModelPath(Paths.get(modelPath))
.optEngine("PyTorch")
.optTranslator(DCDenoiseEngine.getTranslator())
.optTranslator(DCDenoiseEngine.getBatchTranslator())
.optProgress(new ProgressBar())
.build();

model = criteria.loadModel();

int threads = Integer.parseInt(getEngineConfigString(CONF_THREADS,"64"));
predictors = new PredictorPool(threads, model);
return true;
} catch (NullPointerException | MalformedModelException | IOException | ModelNotFoundException ex) {
System.getLogger(DCDenoiseEngine.class.getName()).log(System.Logger.Level.ERROR, (String) null, ex);
Logger.getLogger(DCDenoiseEngine.class.getName()).log(Level.SEVERE, null, ex);
return false;
}
}

public static void main(String args[]){
DCDenoiseEngine dn = new DCDenoiseEngine();
dn.init();
for (int i=0; i<10000; i++) {
dn.processFakeEvent();
}
}

@Override
public boolean processDataEvent(DataEvent event) {
for (String bankName : BANK_NAMES) {
if (!event.hasBank(bankName)) continue;

//if (true) return processFakeEvent();

for (int i=0; i<BANK_NAMES.length; i++){
if (event.hasBank(BANK_NAMES[i])) {
DataBank bank = event.getBank(BANK_NAMES[i]);
try {
// WARNING: Predictor is *not* thread safe.
Predictor<float[][], float[][]> predictor = predictors.get();
for (int sector=0; sector<6; sector++) {
float[][] input = DCDenoiseEngine.read(bank, sector+1);
float[][] output = predictor.predict(input);
//System.out.println("IN:");show(input);
//System.out.println("OUT:");show(output);
update(bank, threshold, output, sector);
DataBank bank = event.getBank(bankName);
try {
// Build batch for 6 sectors
float[][][] batchInput = new float[SECTORS][LAYERS][WIRES];
boolean anySectorPresent = false;
int rows = bank.rows();
for (int r=0; r<rows; r++) {
int sector = bank.getByte(0,r); // 1..6
if (sector < 1 || sector > SECTORS) continue;
int layer = bank.getByte(1,r);
int wire = bank.getShort(2,r);
byte order = bank.getByte(3,r);
if ((order==0)||(order==10)) {
batchInput[sector-1][layer-1][wire-1]=1.0f;
anySectorPresent = true;
}
predictors.put(predictor);
event.removeBank(BANK_NAMES[i]);
event.appendBank(bank);
}
catch (TranslateException | InterruptedException e) {
throw new RuntimeException(e);
}
break;
}
}
return true;
}

boolean processFakeEvent() {
try {
Predictor<float[][], float[][]> predictor = model.newPredictor();
float[][] input = getAlmostStraightSlightlyBendingTrack();
float[][] output = predictor.predict(input);
//System.out.println("IN:");show(input);
//System.out.println("OUT:");show(output);
}
catch (TranslateException e) {
throw new RuntimeException(e);
}
return true;
}

/**
* Reject sub-threshold hits by modifying the bank's order variable.
* WARNING: This is not a full implementation of OrderType enum and
* all its names, but for now a copy of the subset in C++ DC denoising, see:
* https://code.jlab.org/hallb/clas12/coatjava/denoising/-/blob/main/denoising/code/drift.cc?ref_type=heads#L162-198
*/
static void update(DataBank b, float threshold, float[][] data, int sector) {
//System.out.println("IN:");b.show();
for (int row=0; row<b.rows(); row++) {
if (b.getByte(0,row)-1 != sector) continue;
if (data[b.getByte(1,row)-1][b.getShort(2,row)-1] < threshold) {
if(b.getByte(3,row) == 0) b.setByte(3, row, (byte)(60));
if(b.getByte(3,row) == 10) b.setByte(3, row, (byte)(90));
}
}
//System.out.println("OUT:");b.show();
}
if (!anySectorPresent) continue;

/**
* Get one-sector data with weights set to 0/1.
*/
static float[][] read(DataBank bank, int sector) {
float[][] data = new float[LAYERS][WIRES];
for (int i=0; i<bank.rows(); ++i) {
if (bank.getByte(0,i) == sector) {
byte o = bank.getByte(3,i);
if (0==o || 10==o)
// got a hit, set weight to one:
data[bank.getByte(1,i)-1][bank.getShort(2,i)-1] = 1.0f;
}
}
return data;
}
Predictor<float[][][], float[][][]> predictor = predictors.take();
float[][][] batchOutput;
try {
batchOutput = predictor.predict(batchInput);
} finally {
predictors.put(predictor);
}

/**
* Print all hits for one sector.
*/
static void show(float[][] data) {
System.out.println("Shape: [" + data.length + "," + data[0].length + "]");
for (int i = 0; i < LAYERS; i++) {
for (int j = 0; j < WIRES; j++)
System.out.printf("%.3f ", data[i][j]);
System.out.println();
}
}
for (int sectorIdx=0; sectorIdx<SECTORS; sectorIdx++) {
update(bank, threshold, batchOutput[sectorIdx], sectorIdx);
}

/**
* @return a dummy sector/track
*/
static float[][] getAlmostStraightSlightlyBendingTrack() {
float[][] data = new float[LAYERS][WIRES];
for (int y = 0; y < LAYERS; y++) {
int x = 50 + (y / 10);
data[y][x] = 1.0f;
event.removeBank(bankName);
event.appendBank(bank);
} catch (TranslateException | InterruptedException e) {
throw new RuntimeException(e);
}
break;
}
return data;
return true;
}

public static Translator<float[][],float[][]> getTranslator() {
return new Translator<float[][],float[][]>() {
// -------- Translator for batch --------
public static Translator<float[][][], float[][][]> getBatchTranslator() {
return new Translator<float[][][], float[][][]>() {
@Override
public NDList processInput(TranslatorContext ctx, float[][] input) throws Exception {
public NDList processInput(TranslatorContext ctx, float[][][] input) {
int batch = input.length;
int height = input[0].length;
int width = input[0][0].length;
float[] flat = new float[batch*height*width];
int pos=0;
for (int b=0; b<batch; b++)
for (int h=0; h<height; h++) {
System.arraycopy(input[b][h],0,flat,pos,width);
pos+=width;
}
NDManager manager = ctx.getNDManager();
int height = input.length;
int width = input[0].length;
float[] flat = new float[height * width];
for (int i = 0; i < height; i++) {
System.arraycopy(input[i], 0, flat, i * width, width);
}
NDArray x = manager.create(flat, new Shape(height, width));
// Add batch and channel dims -> [1,1,36,112]
x = x.expandDims(0).expandDims(0);
NDArray x = manager.create(flat, new Shape(batch,1,height,width));
return new NDList(x);
}

@Override
public float[][] processOutput(TranslatorContext ctx, NDList list) throws Exception {
public float[][][] processOutput(TranslatorContext ctx, NDList list) {
NDArray result = list.get(0);
// Remove batch and channel dims -> [36,112]
result = result.squeeze();
// Convert to 1D float array
float[] flat = result.toFloatArray();
// Reshape manually into 2D array
long[] shape = result.getShape().getShape();
int height = (int) shape[0];
int width = (int) shape[1];
float[][] output2d = new float[height][width];
for (int i = 0; i < height; i++) {
System.arraycopy(flat, i * width, output2d[i], 0, width);
}
return output2d;
int batch = (int)shape[0];
int height, width;
if (shape.length==4 && shape[1]==1) {
height=(int)shape[2]; width=(int)shape[3];
result = result.squeeze(1);
} else if (shape.length==3) {
height=(int)shape[1]; width=(int)shape[2];
} else throw new IllegalStateException("Unexpected output shape: "+java.util.Arrays.toString(shape));
float[] flat = result.toFloatArray();
float[][][] out = new float[batch][height][width];
int pos=0;
for (int b=0;b<batch;b++)
for (int h=0;h<height;h++) {
System.arraycopy(flat,pos,out[b][h],0,width);
pos+=width;
}
return out;
}

@Override
public Batchifier getBatchifier() {
return null; // no batching
}
public Batchifier getBatchifier() { return null; }
};
}

// -------- Update single sector in bank --------
static void update(DataBank b, float threshold, float[][] data, int sectorIdx) {
for (int row=0; row<b.rows(); row++) {
if (b.getByte(0,row)-1 != sectorIdx) continue;
int layer=b.getByte(1,row)-1;
int wire=b.getShort(2,row)-1;
if (layer<0 || layer>=data.length) continue;
if (wire<0 || wire>=data[0].length) continue;
if (data[layer][wire]<threshold) {
if(b.getByte(3,row)==0) b.setByte(3,row,(byte)60);
if(b.getByte(3,row)==10) b.setByte(3,row,(byte)90);
}
}
}
}