diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/FloatToStringTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/FloatToStringTransform.java
new file mode 100644
index 00000000..c642b382
--- /dev/null
+++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/FloatToStringTransform.java
@@ -0,0 +1,94 @@
+package com.airbnb.aerosolve.core.transforms;
+
+import com.airbnb.aerosolve.core.FeatureVector;
+import com.airbnb.aerosolve.core.util.Util;
+import com.typesafe.config.Config;
+import lombok.extern.slf4j.Slf4j;
+
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Similar to MoveFloatToStringAndFloat, but only moves the configured float values into a
+ * string feature and does not use buckets. This is used when a certain amount of the data is
+ * incorrect, i.e. x = 0 doesn't mean it is worse than x = 0.00001; somewhere in the pipeline
+ * null was turned into 0, so until the pipeline is fixed, convert such values into a string
+ * feature.
+ *
+ * <p>Values are matched through {@code Set<Double>.contains}, i.e. {@link Double#equals}, so
+ * -0.0 does not match a configured 0.0.
+ */
+@Slf4j
+public class FloatToStringTransform implements Transform {
+  // Key into featureVector.floatFeatures whose map is scanned (config "field1").
+  private String fieldName;
+  // Optional keys to inspect (config "keys"); when null every key of the map is inspected.
+  private Collection<String> keys;
+  // Float values that get moved to the string feature (config "values").
+  private Set<Double> values;
+  // Name of the string feature that receives the "key=value" entries (config "string_output").
+  private String stringOutputName;
+
+  /** Reads "field1", the optional "keys", "values" and "string_output" below {@code key}. */
+  @Override
+  public void configure(Config config, String key) {
+    fieldName = config.getString(key + ".field1");
+    if (config.hasPath(key + ".keys")) {
+      keys = config.getStringList(key + ".keys");
+    }
+    values = new HashSet<>(config.getDoubleList(key + ".values"));
+    stringOutputName = config.getString(key + ".string_output");
+  }
+
+  /**
+   * Moves every inspected entry whose value is one of the configured values out of the float
+   * map and into the string output as "key=value". Returns early when the float features or
+   * the configured float map are missing or empty.
+   */
+  @Override
+  public void doTransform(FeatureVector featureVector) {
+    Map<String, Map<String, Double>> floatFeatures = featureVector.floatFeatures;
+
+    if (floatFeatures == null || floatFeatures.isEmpty()) {
+      return;
+    }
+
+    Map<String, Double> input = floatFeatures.get(fieldName);
+
+    if (input == null || input.isEmpty()) {
+      return;
+    }
+
+    Util.optionallyCreateStringFeatures(featureVector);
+    Map<String, Set<String>> stringFeatures = featureVector.getStringFeatures();
+    Set<String> stringOutput = Util.getOrCreateStringFeature(stringOutputName, stringFeatures);
+    // Matching entries are removed from input inside the loop, so iterate over a snapshot of
+    // the key set; looping over input.keySet() itself risks ConcurrentModificationException.
+    Collection<String> localKeys = (keys == null) ? new HashSet<String>(input.keySet()) : keys;
+    log.debug("k {} {}", localKeys, input);
+    for (String key : localKeys) {
+      moveFloatToString(input, key, values, stringOutput);
+    }
+  }
+
+  /**
+   * If {@code input} maps {@code key} to one of {@code values}, adds "key=value" to
+   * {@code stringOutput} and removes the entry from {@code input}.
+   */
+  private void moveFloatToString(
+      Map<String, Double> input,
+      String key, Set<Double> values,
+      Set<String> stringOutput) {
+    if (input.containsKey(key)) {
+      Double inputFloatValue = input.get(key);
+
+      if (values.contains(inputFloatValue)) {
+        String movedFloat = key + "=" + inputFloatValue;
+        stringOutput.add(movedFloat);
+        input.remove(key);
+      }
+    }
+  }
+}
diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/MoveFloatToStringAndFloatTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/MoveFloatToStringAndFloatTransform.java index 23687d98..f7ad6ac6 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/MoveFloatToStringAndFloatTransform.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/MoveFloatToStringAndFloatTransform.java @@ -58,12 +58,9 @@ public void doTransform(FeatureVector featureVector) { Set stringOutput = Util.getOrCreateStringFeature(stringOutputName, stringFeatures); Map floatOutput = Util.getOrCreateFloatFeature(floatOutputName, floatFeatures); + Collection localKeys = (keys == null)? 
input.keySet() : keys; - if (keys == null) { - keys = input.keySet(); - } - - for (String key : keys) { + for (String key : localKeys) { moveFloatToStringAndFloat( input, key, bucket, minBucket, maxBucket, stringOutput, floatOutput); }
diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/FloatToStringTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/FloatToStringTransformTest.java
new file mode 100644
index 00000000..adf46752
--- /dev/null
+++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/FloatToStringTransformTest.java
@@ -0,0 +1,101 @@
+package com.airbnb.aerosolve.core.transforms;
+
+import com.airbnb.aerosolve.core.FeatureVector;
+import com.typesafe.config.Config;
+import com.typesafe.config.ConfigFactory;
+import org.junit.Test;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Tests for the float_to_string transform, using a config that restricts it to the keys
+ * a, b and g and moves the values 0.0 and 10.0.
+ */
+public class FloatToStringTransformTest {
+  /** Returns a config that moves values 0.0 and 10.0 of keys a, b, g into "stringOutput". */
+  public String makeConfig() {
+    return "test_float_to_string_and_float {\n" +
+           " transform : float_to_string\n" +
+           " field1 : floatFeature1\n" +
+           " keys : [a, b, g]\n" +
+           " values : [0.0, 10.0]\n" +
+           " string_output : stringOutput\n" +
+           "}";
+  }
+
+  /** Builds a vector whose only float feature "floatFeature1" holds the keys a through g. */
+  public FeatureVector makeFeatureVector() {
+    Map<String, Map<String, Double>> floatFeatures = new HashMap<>();
+
+    Map<String, Double> floatFeature1 = new HashMap<>();
+
+    floatFeature1.put("a", 0.0);
+    floatFeature1.put("b", 10.0);
+    floatFeature1.put("c", 21.3);
+    floatFeature1.put("d", 10.1);
+    floatFeature1.put("e", 11.01);
+    floatFeature1.put("f", -1.01);
+    floatFeature1.put("g", 0d);
+
+    floatFeatures.put("floatFeature1", floatFeature1);
+
+    FeatureVector featureVector = new FeatureVector();
+    featureVector.setFloatFeatures(floatFeatures);
+    return featureVector;
+  }
+
+  /** A fresh vector must come out of the transform with both feature maps still null. */
+  @Test
+  public void testEmptyFeatureVector() {
+    Config config = ConfigFactory.parseString(makeConfig());
+    Transform transform = TransformFactory.createTransform(
+        config, "test_float_to_string_and_float");
+    FeatureVector featureVector = new FeatureVector();
+    transform.doTransform(featureVector);
+
+    assertNull(featureVector.getStringFeatures());
+    assertNull(featureVector.getFloatFeatures());
+  }
+
+  /** Keys a, b and g hold configured values and move to the string output; c-f stay floats. */
+  @Test
+  public void testTransform() {
+    Config config = ConfigFactory.parseString(makeConfig());
+    Transform transform = TransformFactory.createTransform(
+        config, "test_float_to_string_and_float");
+    FeatureVector featureVector = makeFeatureVector();
+
+    transform.doTransform(featureVector);
+
+    Map<String, Set<String>> stringFeatures = featureVector.getStringFeatures();
+    Map<String, Map<String, Double>> floatFeatures = featureVector.getFloatFeatures();
+
+    assertNotNull(stringFeatures);
+    assertEquals(1, stringFeatures.size());
+
+    assertNotNull(floatFeatures);
+    assertEquals(1, floatFeatures.size());
+
+    Set<String> stringOutput = stringFeatures.get("stringOutput");
+    Map<String, Double> floatOutput = floatFeatures.get("floatFeature1");
+
+    assertEquals(3, stringOutput.size());
+    assertEquals(4, floatOutput.size());
+
+    assertTrue(stringOutput.contains("a=0.0"));
+    assertTrue(stringOutput.contains("b=10.0"));
+    assertTrue(stringOutput.contains("g=0.0"));
+
+    assertEquals(21.3, floatOutput.get("c"), 0.0);
+    assertEquals(10.1, floatOutput.get("d"), 0.0);
+    assertEquals(11.01, floatOutput.get("e"), 0.0);
+    assertEquals(-1.01, floatOutput.get("f"), 0.0);
+  }
+}