Merge pull request #2 from KonduitAI/asto_ops_wrapper
[WIP] New ops wrapper
alexanderst committed Oct 16, 2019
2 parents 3662657 + 9e57998 commit 630bb3c
Showing 14 changed files with 434 additions and 6 deletions.
@@ -50,6 +50,7 @@
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.util.*;

import static org.junit.Assert.*;
@@ -816,6 +817,37 @@ public void weightsNotUpdated_WhenLocked_CBOW() throws Exception {
assertEquals(vec1.getWordVectorMatrix("money"), vec2.getWordVectorMatrix("money"));
}

@Test
public void testWordsNearestSum() throws IOException {
log.info("Load & Vectorize Sentences....");
SentenceIterator iter = new BasicLineIterator(inputFile);
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());

log.info("Building model....");
Word2Vec vec = new Word2Vec.Builder()
.minWordFrequency(5)
.iterations(1)
.layerSize(100)
.seed(42)
.windowSize(5)
.iterate(iter)
.tokenizerFactory(t)
.build();

log.info("Fitting Word2Vec model....");
vec.fit();
log.info("Writing word vectors to text file....");
log.info("Closest Words:");
Collection<String> lst = vec.wordsNearestSum("day", 10);
log.info("10 Words closest to 'day': {}", lst);
assertTrue(lst.contains("week"));
assertTrue(lst.contains("night"));
assertTrue(lst.contains("year"));
assertTrue(lst.contains("years"));
assertTrue(lst.contains("time"));
}

private static void printWords(String target, Collection<String> list, Word2Vec vec) {
System.out.println("Words close to [" + target + "]:");
for (String word : list) {
@@ -351,7 +351,8 @@ public Collection<String> wordsNearestSum(INDArray words, int top) {
if (lookupTable instanceof InMemoryLookupTable) {
InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;
INDArray syn0 = l.getSyn0();
INDArray weights = syn0.norm2(0).rdivi(1).muli(words);
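// norm2(0) returns a vector whose shape may not match the query array, so reshape it before the element-wise multiply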
INDArray temp = syn0.norm2(0).rdivi(1).reshape(words.shape());
INDArray weights = temp.muli(words);
INDArray distances = syn0.mulRowVector(weights).sum(1);
INDArray[] sorted = Nd4j.sortWithIndices(distances, 0, false);
INDArray sort = sorted[0];
@@ -47,7 +47,7 @@ CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, 1, 0) {
NDArray factorT(output->dataType(), block.launchContext()); // = NDArrayFactory::create(factor, block.launchContext());
factorT.p(0, factor);
// this is contrast calculation
*output = (*input - mean) * factorT + mean;
output->assign((*input - mean) * factorT + mean);
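// assign() copies the result into the pre-allocated output buffer instead of re-binding *output through operator=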

return Status::OK();
}
@@ -33,6 +33,7 @@
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.NoOp;
import org.nd4j.linalg.api.ops.custom.*;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
@@ -2649,6 +2650,33 @@ public SDVariable nextIteration(SDVariable x){
return new NextIteration(sameDiff, x).outputVariable();
}

public SDVariable adjustContrast(SDVariable in, SDVariable factor) {
return new AdjustContrast(sameDiff, in, factor).outputVariable();
}

public SDVariable adjustContrastV2(SDVariable in, SDVariable factor) {
return new AdjustContrastV2(sameDiff, in, factor).outputVariable();
}

public SDVariable bitCast(SDVariable in, SDVariable dataType) {
return new BitCast(sameDiff, in, dataType).outputVariable();
}

public SDVariable compareAndBitpack(SDVariable threshold) {
return new CompareAndBitpack(sameDiff, threshold).outputVariable();
}

public SDVariable divideNoNan(SDVariable in1, SDVariable in2) {
return new DivideNoNan(sameDiff, in1, in2).outputVariable();
}

public SDVariable drawBoundingBoxes(SDVariable boxes, SDVariable colors) {
return new DrawBoundingBoxes(sameDiff, boxes, colors).outputVariable();
}

public SDVariable fakeQuantWithMinMaxVarsPerChannel(SDVariable x, SDVariable min, SDVariable max) {
return new FakeQuantWithMinMaxVarsPerChannel(sameDiff,x,min,max).outputVariable();
}

public String toString() {
return "DifferentialFunctionFactory{methodNames=" + methodNames + "}";
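A rough sketch of how one of these factory wrappers might be exercised from a graph. This is hypothetical, not part of the change: it assumes the factory is reachable via SameDiff.f(), that SameDiff.constant and SDVariable.eval are available in this version, and that shape/type inference for the new ops is wired up; names and values are illustrative.

SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", Nd4j.rand(new int[]{4, 4, 3}));      // rank >= 3 image-like input
SDVariable factor = sd.constant("factor", Nd4j.scalar(2.0f));     // contrast factor
SDVariable adjusted = sd.f().adjustContrast(in, factor);          // new DifferentialFunctionFactory wrapper
INDArray result = adjusted.eval();                                // evaluate the graph eagerly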
@@ -581,8 +581,14 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.random.impl.ProbablisticMerge.class,
org.nd4j.linalg.api.ops.random.impl.Range.class,
org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution.class,
org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class

org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class,
org.nd4j.linalg.api.ops.custom.AdjustContrast.class,
org.nd4j.linalg.api.ops.custom.AdjustContrastV2.class,
org.nd4j.linalg.api.ops.custom.BitCast.class,
org.nd4j.linalg.api.ops.custom.CompareAndBitpack.class,
org.nd4j.linalg.api.ops.custom.DivideNoNan.class,
org.nd4j.linalg.api.ops.custom.DrawBoundingBoxes.class,
org.nd4j.linalg.api.ops.custom.FakeQuantWithMinMaxVarsPerChannel.class
);

static {
@@ -0,0 +1,30 @@
package org.nd4j.linalg.api.ops.custom;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

public class AdjustContrast extends BaseAdjustContrast {

public AdjustContrast() {super();}

public AdjustContrast(INDArray in, double factor, INDArray out) {
super(in, factor, out);
}

public AdjustContrast(SameDiff sameDiff, SDVariable in, SDVariable factor) {
super(sameDiff,new SDVariable[]{in,factor});
}

@Override
public String opName() {
return "adjust_contrast";
}

@Override
public String tensorflowName() {
return "AdjustContrast";
}
}
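For eager use, the INDArray-based constructor can be called directly. A minimal, hypothetical sketch (shapes and factor are arbitrary, the usual org.nd4j imports are assumed, and it relies on the Nd4j.exec(CustomOp) entry point dispatching to the adjust_contrast kernel shown above):

INDArray in = Nd4j.rand(new int[]{8, 8, 3});    // rank >= 3, as BaseAdjustContrast requires
INDArray out = Nd4j.zerosLike(in);              // pre-allocated output buffer
Nd4j.exec(new AdjustContrast(in, 2.0, out));    // out = (in - mean) * 2.0 + mean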
@@ -0,0 +1,30 @@
package org.nd4j.linalg.api.ops.custom;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

public class AdjustContrastV2 extends BaseAdjustContrast {

public AdjustContrastV2() {super();}

public AdjustContrastV2(INDArray in, double factor, INDArray out) {
super(in, factor, out);
}

public AdjustContrastV2(SameDiff sameDiff, SDVariable in, SDVariable factor) {
super( sameDiff,new SDVariable[]{in,factor});
}

@Override
public String opName() {
return "adjust_contrast_v2";
}

@Override
public String tensorflowName() {
return "AdjustContrast";
}
}
@@ -0,0 +1,25 @@
package org.nd4j.linalg.api.ops.custom;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

public abstract class BaseAdjustContrast extends DynamicCustomOp {
public BaseAdjustContrast() {
}

public BaseAdjustContrast(INDArray in, double factor, INDArray out) {
Preconditions.checkArgument(in.rank() >= 3,
String.format("AdjustContrast: op expects rank of input array to be >= 3, but got %d instead", in.rank()));
inputArguments.add(in);
outputArguments.add(out);

addTArgument(factor);
}

public BaseAdjustContrast(SameDiff sameDiff, SDVariable[] vars) {
super("", sameDiff, vars);
}
}
@@ -0,0 +1,32 @@
package org.nd4j.linalg.api.ops.custom;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;

public class BitCast extends DynamicCustomOp {
public BitCast() {}

public BitCast(INDArray in, int dataType, INDArray out) {
inputArguments.add(in);
outputArguments.add(out);
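// the target data type is encoded as an integer op argument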
iArguments.add(Long.valueOf(dataType));
}

public BitCast(SameDiff sameDiff, SDVariable in, SDVariable dataType) {
super("", sameDiff, new SDVariable[]{in, dataType});
}

@Override
public String opName() {
return "bitcast";
}

@Override
public String tensorflowName() {
return "Bitcast";
}
}
@@ -0,0 +1,31 @@
package org.nd4j.linalg.api.ops.custom;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;

public class CompareAndBitpack extends DynamicCustomOp {
public CompareAndBitpack() {}

public CompareAndBitpack(INDArray in, double threshold, INDArray out) {
inputArguments.add(in);
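// the threshold is handed to the native op as a second, scalar input array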
inputArguments.add(Nd4j.scalar(threshold));
outputArguments.add(out);
}

public CompareAndBitpack(SameDiff sameDiff, SDVariable threshold) {
super("", sameDiff, new SDVariable[]{threshold});
}

@Override
public String opName() {
return "compare_and_bitpack";
}

@Override
public String tensorflowName() {
return "CompareAndBitpack";
}
}
@@ -0,0 +1,32 @@
package org.nd4j.linalg.api.ops.custom;

import org.apache.commons.math3.analysis.function.Divide;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

public class DivideNoNan extends DynamicCustomOp {
public DivideNoNan() {
}

public DivideNoNan(INDArray in1, INDArray in2, INDArray out) {
inputArguments.add(in1);
inputArguments.add(in2);
outputArguments.add(out);
}

public DivideNoNan(SameDiff sameDiff, SDVariable in1, SDVariable in2) {
super("", sameDiff, new SDVariable[]{in1, in2});
}

@Override
public String opName() {
return "divide_no_nan";
}

@Override
public String tensorflowName() {
return "DivNoNan";
}
}
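A small eager-execution sketch of DivideNoNan (values are illustrative; it assumes the same Nd4j.exec(CustomOp) path as in the earlier sketch):

INDArray num = Nd4j.create(new double[]{1, 2, 3});
INDArray den = Nd4j.create(new double[]{0, 2, 4});
INDArray out = Nd4j.zerosLike(num);
Nd4j.exec(new DivideNoNan(num, den, out));
// out is [0, 1, 0.75]: a zero denominator yields 0 instead of Inf or NaN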
@@ -0,0 +1,32 @@
package org.nd4j.linalg.api.ops.custom;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

public class DrawBoundingBoxes extends DynamicCustomOp {
public DrawBoundingBoxes() {}

public DrawBoundingBoxes(INDArray images, INDArray boxes, INDArray colors,
INDArray output) {
inputArguments.add(images);
inputArguments.add(boxes);
inputArguments.add(colors);
outputArguments.add(output);
}

public DrawBoundingBoxes(SameDiff sameDiff, SDVariable boxes, SDVariable colors) {
super("", sameDiff, new SDVariable[]{boxes, colors});
}

@Override
public String opName() {
return "draw_bounding_boxes";
}

@Override
public String tensorflowName() {
return "DrawBoundingBoxes";
}
}
@@ -0,0 +1,36 @@
package org.nd4j.linalg.api.ops.custom;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp {
public FakeQuantWithMinMaxVarsPerChannel() {}

public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max,
INDArray output) {
Preconditions.checkArgument(min.isVector() && max.isVector() &&
min.length() == max.length(),
"FakeQuantWithMinMaxVarsPerChannel: min and max should be 1D tensors with the same length");
inputArguments.add(x);
inputArguments.add(min);
inputArguments.add(max);
outputArguments.add(output);
}

public FakeQuantWithMinMaxVarsPerChannel(SameDiff sameDiff, SDVariable x, SDVariable min, SDVariable max) {
super("", sameDiff, new SDVariable[]{x, min, max});
}

@Override
public String opName() {
return "fake_quant_with_min_max_vars_per_channel";
}

@Override
public String tensorflowName() {
return "FakeQuantWithMinMaxVarsPerChannel";
}
}
