Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cellpose v2 #10

Merged
merged 7 commits into from
Apr 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
79 changes: 61 additions & 18 deletions src/main/java/qupath/ext/biop/cellpose/Cellpose2D.java
Expand Up @@ -29,6 +29,7 @@
import org.locationtech.jts.geom.Envelope;
import org.locationtech.jts.geom.Geometry;
import org.locationtech.jts.index.strtree.STRtree;
import org.locationtech.jts.simplify.VWSimplifier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.ext.biop.cmd.VirtualEnvironmentRunner;
Expand Down Expand Up @@ -92,6 +93,14 @@ public class Cellpose2D {
private final static Logger logger = LoggerFactory.getLogger(Cellpose2D.class);
public Double learningRate = null;
public Integer batchSize = null;
public double simplifyDistance = 0.0;
public double normPercentileMin = -1.0;
public double normPercentileMax = -1.0;

protected boolean useCellposeNormalization = true;
protected boolean useGlobalNorm = false;
protected int globalNormalizationScale = 8;

protected Integer channel1 = 0;
protected Integer channel2 = 0;
protected Double iouThreshold = 0.1;
Expand Down Expand Up @@ -181,6 +190,16 @@ private static PathObject cellToObject(PathObject cell, Function<ROI, PathObject
return parent;
}

private Geometry simplify(Geometry geom) {
if (simplifyDistance <= 0)
return geom;
try {
return VWSimplifier.simplify(geom, simplifyDistance);
} catch (Exception e) {
return geom;
}
}

/**
* Detect cells within one or more parent objects, firing update events upon completion.
*
Expand Down Expand Up @@ -259,16 +278,37 @@ public void detectObjects(ImageData<BufferedImage> imageData, Collection<? exten
File ori = tilefile.getFile();
File maskFile = new File(ori.getParent(), FilenameUtils.removeExtension(ori.getName()) + "_cp_masks.tif");
if (maskFile.exists()) {
try {
logger.info("Getting objects for {}", maskFile);
logger.info("Getting objects for {}", maskFile);

// thank you Pete for the ContourTracing Class
List<PathObject> detections = ContourTracing.labelsToDetections(maskFile.toPath(), tilefile.getTile());
// thank you Pete for the ContourTracing Class
List<PathObject> detections = null;
try {
detections = ContourTracing.labelsToDetections( maskFile.toPath(), tilefile.getTile());


// Clean Detections
detections = detections.parallelStream().map(det -> {
if (det.getROI().getGeometry().getNumGeometries() > 1) {
// Detemine largest one
Geometry geom = det.getROI().getGeometry();
double largestArea = geom.getGeometryN(0).getArea();
int idx = 0;
for (int i = 0; i < geom.getNumGeometries(); i++) {
if (geom.getGeometryN(i).getArea() > largestArea) idx = i;
}
ROI newROI = GeometryTools.geometryToROI(geom.getGeometryN(idx), det.getROI().getImagePlane());
return PathObjects.createDetectionObject(newROI, det.getPathClass(), det.getMeasurementList());
} else {
return det;
}
}).collect(Collectors.toList());

allDetections.addAll(detections);
} catch (IOException e) {
e.printStackTrace();
}
logger.info("Getting objects for {} Done", maskFile);

allDetections.addAll(detections);
}
});

Expand Down Expand Up @@ -346,6 +386,8 @@ public void detectObjects(ImageData<BufferedImage> imageData, Collection<? exten

private PathObject convertToObject(PathObject object, ImagePlane plane, double cellExpansion, Geometry mask) {
var geomNucleus = object.getROI().getGeometry();
geomNucleus = simplify(geomNucleus);

PathObject pathObject;
if (cellExpansion > 0) {
var geomCell = CellTools.estimateCellBoundary(geomNucleus, cellExpansion, cellConstrainScale);
Expand Down Expand Up @@ -522,9 +564,10 @@ private void runCellpose() throws IOException, InterruptedException {
}

if (!maskThreshold.isNaN()) {
if (!cellposeSetup.getVersion().equals(CellposeSetup.CellposeVersion.CELLPOSE))
if (cellposeSetup.getVersion().equals(CellposeSetup.CellposeVersion.CELLPOSE))
cellposeArguments.add("--cellprob_threshold");
else
cellposeArguments.add("--mask_threshold");
else cellposeArguments.add("--cellprob_threshold");


cellposeArguments.add("" + maskThreshold);
Expand All @@ -540,12 +583,14 @@ private void runCellpose() throws IOException, InterruptedException {

cellposeArguments.add("--no_npy");

if (!cellposeSetup.getVersion().equals(CellposeSetup.CellposeVersion.CELLPOSE_1))
if ( !cellposeSetup.getVersion().equals(CellposeSetup.CellposeVersion.CELLPOSE_1) ||
cellposeSetup.getVersion().equals(CellposeSetup.CellposeVersion.CELLPOSE_2) )
cellposeArguments.add("--resample");

if (useGPU) cellposeArguments.add("--use_gpu");

if (cellposeSetup.getVersion().equals(CellposeSetup.CellposeVersion.CELLPOSE_1))
if ( cellposeSetup.getVersion().equals(CellposeSetup.CellposeVersion.CELLPOSE_1) ||
cellposeSetup.getVersion().equals(CellposeSetup.CellposeVersion.CELLPOSE_2) )
cellposeArguments.add("--verbose");

veRunner.setArguments(cellposeArguments);
Expand Down Expand Up @@ -803,17 +848,15 @@ private void saveTrainingImages() throws IOException {
logger.error(ex.getMessage());
}
});

}


/**
* Checks the default folder where cellpose drops a trained model (../train/models/)
* and moves it to the defined modelDirectory using {@link #modelDirectory}
*
* @return the File of the moved model
* @throws IOException in case there was a problem moving the file
*/
/**
* Checks the default folder where cellpose drops a trained model (../train/models/)
* and moves it to the defined modelDirectory using {@link #modelDirectory}
*
* @return the File of the moved model
* @throws IOException in case there was a problem moving the file
*/
private File moveAndReturnModelFile() throws IOException {
File cellPoseModelFolder = new File(trainDirectory, "models");
// Find the first file in there
Expand Down
102 changes: 75 additions & 27 deletions src/main/java/qupath/ext/biop/cellpose/CellposeBuilder.java
Expand Up @@ -80,6 +80,12 @@ public class CellposeBuilder {

private transient boolean saveBuilder;
private transient String builderName;
private double simplifyDistance = 0.0;
private boolean useCellposeNormalization = true;
private boolean useGlobalNorm = false;
private int globalNormalizationScale = 8;
private double normPercentileMin = -1.0;
private double normPercentileMax = -1.0;

/**
* can create a cellpose builder from a serialized JSON version of this builder.
Expand Down Expand Up @@ -131,6 +137,31 @@ public CellposeBuilder pixelSize(double pixelSize) {
return this;
}

/**
* Apply percentile normalization to the input image channels.
* <p>
* Note that this can be used in combination with {@link #preprocess(ImageOp...)},
* in which case the order in which the operations are applied depends upon the order
* in which the methods of the builder are called.
* <p>
* Warning! This is applied on a per-tile basis. This can result in artifacts and false detections
* without background/constant regions.
* Consider using {@link #inputAdd(double...)} and {@link #inputScale(double...)} as alternative
* normalization strategies, if appropriate constants can be determined to apply globally.
*
* @param min minimum percentile
* @param max maximum percentile
* @return this builder
*/
public CellposeBuilder normalizePercentiles(double min, double max) {
this.normPercentileMin = min;
this.normPercentileMax = max;

//this.ops.add(ImageOps.Normalize.percentile(min, max));
return this;
}


/**
* Specify channels. Useful for detecting nuclei for one channel
* within a multi-channel image, or potentially for trained models that
Expand Down Expand Up @@ -187,27 +218,6 @@ public CellposeBuilder preprocess(ImageOp... ops) {
return this;
}

/**
* Apply percentile normalization to the input image channels.
* <p>
* Note that this can be used in combination with {@link #preprocess(ImageOp...)},
* in which case the order in which the operations are applied depends upon the order
* in which the methods of the builder are called.
* <p>
* Warning! This is applied on a per-tile basis. This can result in artifacts and false detections
* without background/constant regions.
* Consider using {@link #inputAdd(double...)} and {@link #inputScale(double...)} as alternative
* normalization strategies, if appropriate constants can be determined to apply globally.
*
* @param min minimum percentile
* @param max maximum percentile
* @return this builder
*/
public CellposeBuilder normalizePercentiles(double min, double max) {
this.ops.add(ImageOps.Normalize.percentile(min, max));
return this;
}

/**
* Add an offset as a preprocessing step.
* Usually the value will be negative. Along with {@link #inputScale(double...)} this can be used as an alternative (global) normalization.
Expand Down Expand Up @@ -267,12 +277,28 @@ public CellposeBuilder tileSize(int tileWidth, int tileHeight) {
return this;
}

/**
* Sets the channels to use by cellpose, in case there is an issue with the order or the number of exported channels
* @param channel1 the main channel
* @param channel2 the second channel (typically nuclei)
* @return this builder
*/
public CellposeBuilder useCellposeNormalization( boolean useCellposeNorm){
this.useCellposeNormalization = useCellposeNorm;
return this;
}

public CellposeBuilder useGlobalNormalization( boolean useGlobalNorm) {
this.useGlobalNorm = useGlobalNorm;
return this;
}

public CellposeBuilder globalNormalizationScale( int globalNormDownsampling ) {
this.globalNormalizationScale = globalNormDownsampling;
return this;
}


/**
* Sets the channels to use by cellpose, in case there is an issue with the order or the number of exported channels
* @param channel1 the main channel
* @param channel2 the second channel (typically nuclei)
* @return this builder
*/
public CellposeBuilder cellposeChannels(int channel1, int channel2) {
this.channel1 = channel1;
this.channel2 = channel2;
Expand Down Expand Up @@ -322,6 +348,11 @@ public CellposeBuilder invert() {
return this;
}

public CellposeBuilder simplify(double distance) {
this.simplifyDistance = distance;
return this;
}

/**
* Use Omnipose implementation: Adds --omni flag to command
*
Expand Down Expand Up @@ -623,7 +654,20 @@ public Cellpose2D build() {
if(diameter.isNaN()) cellpose.diameter = 0.0;
else cellpose.diameter = diameter;

cellpose.simplifyDistance = simplifyDistance;

cellpose.invert = isInvert;

if(cellpose.useCellposeNormalization) logger.info("Using Cellpose Normalization (per tile).");
cellpose.useCellposeNormalization = useCellposeNormalization;

if(cellpose.useGlobalNorm && cellpose.useCellposeNormalization) {
logger.warn("You cannot use global normalization and enable 'use cellpose normalization' at the same time!. Will default to cellpose normalization (per tile).");
} else {
logger.info("Using global normalization with a downsampling factor of {}", globalNormalizationScale);
cellpose.useGlobalNorm = useGlobalNorm;
cellpose.globalNormalizationScale = globalNormalizationScale;
}
cellpose.doCluster = doCluster;
cellpose.excludeEdges = excludeEdges;
cellpose.useOmnipose = useOmnipose;
Expand Down Expand Up @@ -657,6 +701,10 @@ public Cellpose2D build() {
cellpose.learningRate = learningRate;
cellpose.batchSize = batchSize;

cellpose.normPercentileMax = normPercentileMax;
cellpose.normPercentileMin = normPercentileMin;



// Overlap for segmentation of tiles. Should be large enough that any object must be "complete"
// in at least one tile for resolving overlaps
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/qupath/ext/biop/cellpose/CellposeSetup.java
Expand Up @@ -9,7 +9,9 @@ public class CellposeSetup {
public enum CellposeVersion {
CELLPOSE("Cellpose before v0.7.0"),
OMNIPOSE("Omnipose after v0.7.2"),
CELLPOSE_1("Cellpose Version 1.0");
CELLPOSE_1("Cellpose Version 1.0"),
CELLPOSE_2("Cellpose Version 2.0");


private final String description;

Expand Down
Expand Up @@ -75,7 +75,7 @@ private List<String> getActivationCommand() {
case CONDA:
switch (platform) {
case WINDOWS:
cmd.addAll(Arrays.asList("conda", "activate", environmentNameOrPath, "&", "python"));
cmd.addAll(Arrays.asList("CALL", "conda.bat", "activate", environmentNameOrPath, "&", "python"));
break;
case UNIX:
case OSX:
Expand Down