forked from deeplearning4j/deeplearning4j
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from KonduitAI/asto_ops_wrapper
[WIP] New ops wrapper
- Loading branch information
Showing
14 changed files
with
434 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
30 changes: 30 additions & 0 deletions
30
...nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
package org.nd4j.linalg.api.ops.custom; | ||
|
||
import org.nd4j.autodiff.samediff.SDVariable; | ||
import org.nd4j.autodiff.samediff.SameDiff; | ||
import org.nd4j.base.Preconditions; | ||
import org.nd4j.linalg.api.ndarray.INDArray; | ||
import org.nd4j.linalg.api.ops.DynamicCustomOp; | ||
|
||
public class AdjustContrast extends BaseAdjustContrast { | ||
|
||
public AdjustContrast() {super();} | ||
|
||
public AdjustContrast(INDArray in, double factor, INDArray out) { | ||
super(in, factor, out); | ||
} | ||
|
||
public AdjustContrast(SameDiff sameDiff, SDVariable in, SDVariable factor) { | ||
super(sameDiff,new SDVariable[]{in,factor}); | ||
} | ||
|
||
@Override | ||
public String opName() { | ||
return "adjust_contrast"; | ||
} | ||
|
||
@Override | ||
public String tensorflowName() { | ||
return "AdjustContrast"; | ||
} | ||
} |
30 changes: 30 additions & 0 deletions
30
...4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
package org.nd4j.linalg.api.ops.custom; | ||
|
||
import org.nd4j.autodiff.samediff.SDVariable; | ||
import org.nd4j.autodiff.samediff.SameDiff; | ||
import org.nd4j.base.Preconditions; | ||
import org.nd4j.linalg.api.ndarray.INDArray; | ||
import org.nd4j.linalg.api.ops.DynamicCustomOp; | ||
|
||
public class AdjustContrastV2 extends BaseAdjustContrast { | ||
|
||
public AdjustContrastV2() {super();} | ||
|
||
public AdjustContrastV2(INDArray in, double factor, INDArray out) { | ||
super(in, factor, out); | ||
} | ||
|
||
public AdjustContrastV2(SameDiff sameDiff, SDVariable in, SDVariable factor) { | ||
super( sameDiff,new SDVariable[]{in,factor}); | ||
} | ||
|
||
@Override | ||
public String opName() { | ||
return "adjust_contrast_v2"; | ||
} | ||
|
||
@Override | ||
public String tensorflowName() { | ||
return "AdjustContrast"; | ||
} | ||
} |
25 changes: 25 additions & 0 deletions
25
...-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
package org.nd4j.linalg.api.ops.custom; | ||
|
||
import org.nd4j.autodiff.samediff.SDVariable; | ||
import org.nd4j.autodiff.samediff.SameDiff; | ||
import org.nd4j.base.Preconditions; | ||
import org.nd4j.linalg.api.ndarray.INDArray; | ||
import org.nd4j.linalg.api.ops.DynamicCustomOp; | ||
|
||
public abstract class BaseAdjustContrast extends DynamicCustomOp { | ||
public BaseAdjustContrast() { | ||
} | ||
|
||
public BaseAdjustContrast(INDArray in, double factor, INDArray out) { | ||
Preconditions.checkArgument(in.rank() >= 3, | ||
String.format("AdjustContrast: op expects rank of input array to be >= 3, but got %d instead", in.rank())); | ||
inputArguments.add(in); | ||
outputArguments.add(out); | ||
|
||
addTArgument(factor); | ||
} | ||
|
||
public BaseAdjustContrast(SameDiff sameDiff, SDVariable[] vars) { | ||
super("", sameDiff, vars); | ||
} | ||
} |
32 changes: 32 additions & 0 deletions
32
...ckends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
package org.nd4j.linalg.api.ops.custom; | ||
|
||
import org.nd4j.autodiff.samediff.SDVariable; | ||
import org.nd4j.autodiff.samediff.SameDiff; | ||
import org.nd4j.linalg.api.buffer.DataType; | ||
import org.nd4j.linalg.api.ndarray.INDArray; | ||
import org.nd4j.linalg.api.ops.DynamicCustomOp; | ||
import org.nd4j.linalg.factory.Nd4j; | ||
|
||
public class BitCast extends DynamicCustomOp { | ||
public BitCast() {} | ||
|
||
public BitCast(INDArray in, int dataType, INDArray out) { | ||
inputArguments.add(in); | ||
outputArguments.add(out); | ||
iArguments.add(Long.valueOf(dataType)); | ||
} | ||
|
||
public BitCast(SameDiff sameDiff, SDVariable in, SDVariable dataType) { | ||
super("", sameDiff, new SDVariable[]{in, dataType}); | ||
} | ||
|
||
@Override | ||
public String opName() { | ||
return "bitcast"; | ||
} | ||
|
||
@Override | ||
public String tensorflowName() { | ||
return "Bitcast"; | ||
} | ||
} |
31 changes: 31 additions & 0 deletions
31
...j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
package org.nd4j.linalg.api.ops.custom; | ||
|
||
import org.nd4j.autodiff.samediff.SDVariable; | ||
import org.nd4j.autodiff.samediff.SameDiff; | ||
import org.nd4j.linalg.api.ndarray.INDArray; | ||
import org.nd4j.linalg.api.ops.DynamicCustomOp; | ||
import org.nd4j.linalg.factory.Nd4j; | ||
|
||
public class CompareAndBitpack extends DynamicCustomOp { | ||
public CompareAndBitpack() {} | ||
|
||
public CompareAndBitpack(INDArray in, double threshold, INDArray out) { | ||
inputArguments.add(in); | ||
inputArguments.add(Nd4j.scalar(threshold)); | ||
outputArguments.add(out); | ||
} | ||
|
||
public CompareAndBitpack(SameDiff sameDiff, SDVariable threshold) { | ||
super("", sameDiff, new SDVariable[]{threshold}); | ||
} | ||
|
||
@Override | ||
public String opName() { | ||
return "compare_and_bitpack"; | ||
} | ||
|
||
@Override | ||
public String tensorflowName() { | ||
return "CompareAndBitpack"; | ||
} | ||
} |
32 changes: 32 additions & 0 deletions
32
...ds/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
package org.nd4j.linalg.api.ops.custom; | ||
|
||
import org.apache.commons.math3.analysis.function.Divide; | ||
import org.nd4j.autodiff.samediff.SDVariable; | ||
import org.nd4j.autodiff.samediff.SameDiff; | ||
import org.nd4j.linalg.api.ndarray.INDArray; | ||
import org.nd4j.linalg.api.ops.DynamicCustomOp; | ||
|
||
public class DivideNoNan extends DynamicCustomOp { | ||
public DivideNoNan() { | ||
} | ||
|
||
public DivideNoNan(INDArray in1, INDArray in2, INDArray out) { | ||
inputArguments.add(in1); | ||
inputArguments.add(in2); | ||
outputArguments.add(out); | ||
} | ||
|
||
public DivideNoNan(SameDiff sameDiff, SDVariable in1, SDVariable in2) { | ||
super("", sameDiff, new SDVariable[]{in1, in2}); | ||
} | ||
|
||
@Override | ||
public String opName() { | ||
return "divide_no_nan"; | ||
} | ||
|
||
@Override | ||
public String tensorflowName() { | ||
return "DivNoNan"; | ||
} | ||
} |
32 changes: 32 additions & 0 deletions
32
...j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
package org.nd4j.linalg.api.ops.custom; | ||
|
||
import org.nd4j.autodiff.samediff.SDVariable; | ||
import org.nd4j.autodiff.samediff.SameDiff; | ||
import org.nd4j.linalg.api.ndarray.INDArray; | ||
import org.nd4j.linalg.api.ops.DynamicCustomOp; | ||
|
||
public class DrawBoundingBoxes extends DynamicCustomOp { | ||
public DrawBoundingBoxes() {} | ||
|
||
public DrawBoundingBoxes(INDArray images, INDArray boxes, INDArray colors, | ||
INDArray output) { | ||
inputArguments.add(images); | ||
inputArguments.add(boxes); | ||
inputArguments.add(colors); | ||
outputArguments.add(output); | ||
} | ||
|
||
public DrawBoundingBoxes(SameDiff sameDiff, SDVariable boxes, SDVariable colors) { | ||
super("", sameDiff, new SDVariable[]{boxes, colors}); | ||
} | ||
|
||
@Override | ||
public String opName() { | ||
return "draw_bounding_boxes"; | ||
} | ||
|
||
@Override | ||
public String tensorflowName() { | ||
return "DrawBoundingBoxes"; | ||
} | ||
} |
36 changes: 36 additions & 0 deletions
36
...j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
package org.nd4j.linalg.api.ops.custom; | ||
|
||
import org.nd4j.autodiff.samediff.SDVariable; | ||
import org.nd4j.autodiff.samediff.SameDiff; | ||
import org.nd4j.base.Preconditions; | ||
import org.nd4j.linalg.api.ndarray.INDArray; | ||
import org.nd4j.linalg.api.ops.DynamicCustomOp; | ||
|
||
public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp { | ||
public FakeQuantWithMinMaxVarsPerChannel() {} | ||
|
||
public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max, | ||
INDArray output) { | ||
Preconditions.checkArgument(min.isVector() && max.isVector() && | ||
min.length() == max.length(), | ||
"FakeQuantWithMinMaxVarsPerChannel: min and max should be 1D tensors with the same length"); | ||
inputArguments.add(x); | ||
inputArguments.add(min); | ||
inputArguments.add(max); | ||
outputArguments.add(output); | ||
} | ||
|
||
public FakeQuantWithMinMaxVarsPerChannel(SameDiff sameDiff, SDVariable x, SDVariable min, SDVariable max) { | ||
super("", sameDiff, new SDVariable[]{x, min, max}); | ||
} | ||
|
||
@Override | ||
public String opName() { | ||
return "fake_quant_with_min_max_vars_per_channel"; | ||
} | ||
|
||
@Override | ||
public String tensorflowName() { | ||
return "FakeQuantWithMinMaxVarsPerChannel"; | ||
} | ||
} |
Oops, something went wrong.