Skip to content

Commit

Permalink
Allow native scripts to set the value of a script field to primitive …
Browse files Browse the repository at this point in the history
…arrays

Script fields could not be set to int[] and float[] by native
scripts because StreamInput and StreamOutput could not handle
them.

closes elastic#4175
  • Loading branch information
brwe committed Nov 15, 2013
1 parent a0f3b09 commit 46461c1
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 2 deletions.
44 changes: 44 additions & 0 deletions src/main/java/org/elasticsearch/common/io/stream/StreamInput.java
Expand Up @@ -383,8 +383,52 @@ public Object readGenericValue() throws IOException {
return readText();
case 16:
return readShort();
case 17:
return readPrimitiveIntArray();
case 18:
return readPrimitiveLongArray();
case 19:
return readPrimitiveFloatArray();
case 20:
return readPrimitiveDoubleArray();
default:
throw new IOException("Can't read unknown type [" + type + "]");
}
}

private Object readPrimitiveIntArray() throws IOException {
int length = readVInt();
int[] values = new int[length];
for(int i=0; i<length; i++) {
values[i] = readInt();
}
return values;
}

private Object readPrimitiveLongArray() throws IOException {
int length = readVInt();
long[] values = new long[length];
for(int i=0; i<length; i++) {
values[i] = readLong();
}
return values;
}

private Object readPrimitiveFloatArray() throws IOException {
int length = readVInt();
float[] values = new float[length];
for(int i=0; i<length; i++) {
values[i] = readFloat();
}
return values;
}

private Object readPrimitiveDoubleArray() throws IOException {
int length = readVInt();
double[] values = new double[length];
for(int i=0; i<length; i++) {
values[i] = readDouble();
}
return values;
}
}
40 changes: 40 additions & 0 deletions src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java
Expand Up @@ -386,8 +386,48 @@ public void writeGenericValue(@Nullable Object value) throws IOException {
} else if (type == Short.class) {
writeByte((byte) 16);
writeShort((Short) value);
} else if (type == int[].class) {
writeByte((byte) 17);
writePrimitiveIntArray((int[]) value);
} else if (type == long[].class) {
writeByte((byte) 18);
writePrimitiveLongArray((long[]) value);
} else if (type == float[].class) {
writeByte((byte) 19);
writePrimitiveFloatArray((float[]) value);
} else if (type == double[].class) {
writeByte((byte) 20);
writePrimitiveDoubleArray((double[]) value);
} else {
throw new IOException("Can't write type [" + type + "]");
}
}

private void writePrimitiveIntArray(int[] value) throws IOException {
writeVInt(value.length);
for (int i=0; i<value.length; i++) {
writeInt(value[i]);
}
}

private void writePrimitiveLongArray(long[] value) throws IOException {
writeVInt(value.length);
for (int i=0; i<value.length; i++) {
writeLong(value[i]);
}
}

private void writePrimitiveFloatArray(float[] value) throws IOException {
writeVInt(value.length);
for (int i=0; i<value.length; i++) {
writeFloat(value[i]);
}
}

private void writePrimitiveDoubleArray(double[] value) throws IOException {
writeVInt(value.length);
for (int i=0; i<value.length; i++) {
writeDouble(value[i]);
}
}
}
Expand Up @@ -25,10 +25,8 @@
import org.elasticsearch.test.ElasticsearchTestCase;
import org.junit.Test;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
import static org.junit.Assume.assumeTrue;

/**
*
Expand All @@ -48,6 +46,14 @@ public void testSimpleStreams() throws Exception {
out.writeVLong(4);
out.writeFloat(1.1f);
out.writeDouble(2.2);
int[] intArray = {1, 2, 3};
out.writeGenericValue(intArray);
long[] longArray = {1, 2, 3};
out.writeGenericValue(longArray);
float[] floatArray = {1.1f, 2.2f, 3.3f};
out.writeGenericValue(floatArray);
double[] doubleArray = {1.1, 2.2, 3.3};
out.writeGenericValue(doubleArray);
out.writeString("hello");
out.writeString("goodbye");
BytesStreamInput in = new BytesStreamInput(out.bytes().toBytes(), false);
Expand All @@ -60,6 +66,10 @@ public void testSimpleStreams() throws Exception {
assertThat(in.readVLong(), equalTo((long) 4));
assertThat((double) in.readFloat(), closeTo(1.1, 0.0001));
assertThat(in.readDouble(), closeTo(2.2, 0.0001));
assertThat(in.readGenericValue(), equalTo((Object)intArray));
assertThat(in.readGenericValue(), equalTo((Object)longArray));
assertThat(in.readGenericValue(), equalTo((Object)floatArray));
assertThat(in.readGenericValue(), equalTo((Object)doubleArray));
assertThat(in.readString(), equalTo("hello"));
assertThat(in.readString(), equalTo("goodbye"));
}
Expand Down
152 changes: 152 additions & 0 deletions src/test/java/org/elasticsearch/script/ScriptFieldTests.java
@@ -0,0 +1,152 @@
/*
* Licensed to ElasticSearch and Shay Banon under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. ElasticSearch licenses this
* file to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.script;

import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.plugins.AbstractPlugin;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.test.ElasticsearchIntegrationTest;
import org.elasticsearch.test.ElasticsearchIntegrationTest.ClusterScope;
import org.elasticsearch.test.ElasticsearchIntegrationTest.Scope;

import java.util.Map;
import java.util.concurrent.ExecutionException;

import static org.elasticsearch.common.settings.ImmutableSettings.settingsBuilder;
import static org.hamcrest.Matchers.equalTo;

@ClusterScope(scope = Scope.SUITE, numNodes = 3)
public class ScriptFieldTests extends ElasticsearchIntegrationTest {

@Override
protected Settings nodeSettings(int nodeOrdinal) {
return settingsBuilder().put("plugin.types", CustomScriptPlugin.class.getName()).put(super.nodeSettings(nodeOrdinal)).build();
}

static int[] intArray = { Integer.MAX_VALUE, Integer.MIN_VALUE, 3 };
static long[] longArray = { Long.MAX_VALUE, Long.MIN_VALUE, 9223372036854775807l };
static float[] floatArray = { Float.MAX_VALUE, Float.MIN_VALUE, 3.3f };
static double[] doubleArray = { Double.MAX_VALUE, Double.MIN_VALUE, 3.3d };

public void testNativeScript() throws InterruptedException, ExecutionException {

indexRandom(true, client().prepareIndex("test", "type1", "1").setSource("text", "doc1"), client()
.prepareIndex("test", "type1", "2").setSource("text", "doc2"),
client().prepareIndex("test", "type1", "3").setSource("text", "doc3"), client().prepareIndex("test", "type1", "4")
.setSource("text", "doc4"), client().prepareIndex("test", "type1", "5").setSource("text", "doc5"), client()
.prepareIndex("test", "type1", "6").setSource("text", "doc6"));

client().admin().indices().prepareFlush("test").execute().actionGet();
SearchResponse sr = client().prepareSearch("test").setQuery(QueryBuilders.matchAllQuery())
.addScriptField("int", "native", "int", null).addScriptField("float", "native", "float", null)
.addScriptField("double", "native", "double", null).addScriptField("long", "native", "long", null).execute().actionGet();
assertThat(sr.getHits().hits().length, equalTo(6));
for (SearchHit hit : sr.getHits().getHits()) {
Object result = hit.getFields().get("int").getValues().get(0);
assertThat(result, equalTo((Object) intArray));
result = hit.getFields().get("long").getValues().get(0);
assertThat(result, equalTo((Object) longArray));
result = hit.getFields().get("float").getValues().get(0);
assertThat(result, equalTo((Object) floatArray));
result = hit.getFields().get("double").getValues().get(0);
assertThat(result, equalTo((Object) doubleArray));
}
}

static class IntArrayScriptFactory implements NativeScriptFactory {
@Override
public ExecutableScript newScript(@Nullable Map<String, Object> params) {
return new IntScript();
}
}

static class IntScript extends AbstractSearchScript {
@Override
public Object run() {
return intArray;
}
}

static class LongArrayScriptFactory implements NativeScriptFactory {
@Override
public ExecutableScript newScript(@Nullable Map<String, Object> params) {
return new LongScript();
}
}

static class LongScript extends AbstractSearchScript {
@Override
public Object run() {
return longArray;
}
}

static class FloatArrayScriptFactory implements NativeScriptFactory {
@Override
public ExecutableScript newScript(@Nullable Map<String, Object> params) {
return new FloatScript();
}
}

static class FloatScript extends AbstractSearchScript {
@Override
public Object run() {
return floatArray;
}
}

static class DoubleArrayScriptFactory implements NativeScriptFactory {
@Override
public ExecutableScript newScript(@Nullable Map<String, Object> params) {
return new DoubleScript();
}
}

static class DoubleScript extends AbstractSearchScript {
@Override
public Object run() {
return doubleArray;
}
}

public static class CustomScriptPlugin extends AbstractPlugin {

@Override
public String name() {
return "custom_script";
}

@Override
public String description() {
return "script ";
}

public void onModule(ScriptModule scriptModule) {
scriptModule.registerScript("int", IntArrayScriptFactory.class);
scriptModule.registerScript("long", LongArrayScriptFactory.class);
scriptModule.registerScript("float", FloatArrayScriptFactory.class);
scriptModule.registerScript("double", DoubleArrayScriptFactory.class);
}

}
}

0 comments on commit 46461c1

Please sign in to comment.