Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@

package org.apache.beam.runners.core.construction;

import static com.google.common.base.Preconditions.checkArgument;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import com.google.protobuf.Any;
Expand All @@ -32,7 +35,7 @@
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.coders.StandardCoder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.common.runner.v1.RunnerApi;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.Components;
Expand All @@ -54,11 +57,12 @@ public class Coders {
public static final String CUSTOM_CODER_URN = "urn:beam:coders:javasdk:0.1";

// The URNs for coders which are shared across languages
private static final BiMap<Class<? extends Coder>, String> KNOWN_CODER_URNS =
ImmutableBiMap.<Class<? extends Coder>, String>builder()
@VisibleForTesting
static final BiMap<Class<? extends StandardCoder>, String> KNOWN_CODER_URNS =
ImmutableBiMap.<Class<? extends StandardCoder>, String>builder()
.put(ByteArrayCoder.class, "urn:beam:coders:bytes:0.1")
.put(KvCoder.class, "urn:beam:coders:kv:0.1")
.put(VarIntCoder.class, "urn:beam:coders:varint:0.1")
.put(VarLongCoder.class, "urn:beam:coders:varint:0.1")
.put(IntervalWindowCoder.class, "urn:beam:coders:interval_window:0.1")
.put(IterableCoder.class, "urn:beam:coders:stream:0.1")
.put(GlobalWindow.Coder.class, "urn:beam:coders:global_window:0.1")
Expand All @@ -75,11 +79,17 @@ public static RunnerApi.Coder toProto(

private static RunnerApi.Coder toKnownCoder(Coder<?> coder, SdkComponents components)
throws IOException {
checkArgument(
coder instanceof StandardCoder,
"A Known %s must implement %s, but %s of class %s does not",
Coder.class.getSimpleName(),
StandardCoder.class.getSimpleName(),
coder,
coder.getClass().getName());
StandardCoder<?> stdCoder = (StandardCoder<?>) coder;
List<String> componentIds = new ArrayList<>();
if (coder.getCoderArguments() != null) {
for (Coder<?> componentCoder : coder.getCoderArguments()) {
componentIds.add(components.registerCoder(componentCoder));
}
for (Coder<?> componentCoder : stdCoder.getComponents()) {
componentIds.add(components.registerCoder(componentCoder));
}
return RunnerApi.Coder.newBuilder()
.addAllComponentCoderIds(componentIds)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,82 +18,146 @@

package org.apache.beam.runners.core.construction;

import static com.google.common.base.Preconditions.checkState;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
import static org.junit.Assert.assertThat;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.HashSet;
import java.util.Set;
import org.apache.beam.sdk.coders.AvroCoder;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.CustomCoder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.ListCoder;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.coders.StandardCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.common.runner.v1.RunnerApi;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.Components;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCoder;
import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
import org.hamcrest.Matchers;
import org.junit.Test;
import org.junit.experimental.runners.Enclosed;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameter;
import org.junit.runners.Parameterized.Parameters;

/**
* Tests for {@link Coders}.
*/
@RunWith(Parameterized.class)
/** Tests for {@link Coders}. */
@RunWith(Enclosed.class)
public class CodersTest {
@Parameters(name = "{index}: {0}")
public static Iterable<Coder<?>> data() {
return ImmutableList.<Coder<?>>of(
StringUtf8Coder.of(),
IterableCoder.of(VarLongCoder.of()),
KvCoder.of(StringUtf8Coder.of(), ListCoder.of(VarLongCoder.of())),
SerializableCoder.of(Record.class),
new RecordCoder(),
KvCoder.of(new RecordCoder(), AvroCoder.of(Record.class)));
/**
 * Sample {@link StandardCoder} instances used as the "known coder" fixtures of this test.
 *
 * <p>These instances are added to the parameters of {@link ToFromProtoTest}, which additionally
 * asserts that none of their encoded components use {@code Coders.CUSTOM_CODER_URN}. Their classes
 * are compared against the keys of {@code Coders.KNOWN_CODER_URNS} by
 * {@link ValidateKnownCodersPresentTest}.
 */
private static final Set<StandardCoder<?>> KNOWN_CODERS =
ImmutableSet.<StandardCoder<?>>builder()
.add(ByteArrayCoder.of())
.add(KvCoder.of(VarLongCoder.of(), VarLongCoder.of()))
.add(VarLongCoder.of())
.add(IntervalWindowCoder.of())
.add(IterableCoder.of(ByteArrayCoder.of()))
.add(GlobalWindow.Coder.INSTANCE)
.add(
FullWindowedValueCoder.of(
IterableCoder.of(VarLongCoder.of()), IntervalWindowCoder.of()))
.build();

/**
 * Verifies that every coder class registered in {@code Coders.KNOWN_CODER_URNS} has at least one
 * instance in {@code KNOWN_CODERS}, the fixture set that is fed to {@link ToFromProtoTest}.
 */
@RunWith(JUnit4.class)
public static class ValidateKnownCodersPresentTest {
  @Test
  public void validateKnownCoders() {
    // Begin with every class that Coders maps to a URN, then strike out each class that has a
    // fixture instance. Anything left over is a known coder that no "Known Coder" test covers,
    // i.e. one that has not been shown to serialize via components and a specified URN rather
    // than via java serialization.
    Set<Class<? extends StandardCoder>> unvalidated =
        new HashSet<>(Coders.KNOWN_CODER_URNS.keySet());
    for (StandardCoder<?> exercised : KNOWN_CODERS) {
      unvalidated.remove(exercised.getClass());
    }
    checkState(
        unvalidated.isEmpty(),
        "Missing validation of known coder %s in %s",
        unvalidated,
        CodersTest.class.getSimpleName());
  }
}

@Parameter(0)
public Coder<?> coder;

@Test
public void toAndFromProto() throws Exception {
SdkComponents componentsBuilder = SdkComponents.create();
RunnerApi.Coder coderProto = Coders.toProto(coder, componentsBuilder);
/**
* Tests round-trip coder encodings for both known and unknown {@link Coder coders}.
*/
@RunWith(Parameterized.class)
public static class ToFromProtoTest {
@Parameters(name = "{index}: {0}")
public static Iterable<Coder<?>> data() {
return ImmutableList.<Coder<?>>builder()
.addAll(KNOWN_CODERS)
.add(
StringUtf8Coder.of(),
SerializableCoder.of(Record.class),
new RecordCoder(),
KvCoder.of(new RecordCoder(), AvroCoder.of(Record.class)))
.build();
}

Components encodedComponents = componentsBuilder.toComponents();
Coder<?> decodedCoder = Coders.fromProto(coderProto, encodedComponents);
assertThat(decodedCoder, Matchers.<Coder<?>>equalTo(coder));
}
@Parameter(0)
public Coder<?> coder;

static class Record implements Serializable {
}
@Test
public void toAndFromProto() throws Exception {
SdkComponents componentsBuilder = SdkComponents.create();
RunnerApi.Coder coderProto = Coders.toProto(coder, componentsBuilder);

private static class RecordCoder extends CustomCoder<Record> {
@Override
public void encode(Record value, OutputStream outStream, Context context)
throws CoderException, IOException {}
Components encodedComponents = componentsBuilder.toComponents();
Coder<?> decodedCoder = Coders.fromProto(coderProto, encodedComponents);
assertThat(decodedCoder, Matchers.<Coder<?>>equalTo(coder));

@Override
public Record decode(InputStream inStream, Context context) throws CoderException, IOException {
return new Record();
if (KNOWN_CODERS.contains(coder)) {
for (RunnerApi.Coder encodedCoder : encodedComponents.getCodersMap().values()) {
assertThat(
encodedCoder.getSpec().getSpec().getUrn(), not(equalTo(Coders.CUSTOM_CODER_URN)));
}
}
}

@Override
public boolean equals(Object other) {
return other != null && getClass().equals(other.getClass());
}
/**
 * Empty serializable value type; it serves as the element type for the {@code SerializableCoder},
 * {@code AvroCoder} and {@code RecordCoder} instances exercised by this test.
 */
static class Record implements Serializable {}

private static class RecordCoder extends CustomCoder<Record> {
@Override
public void encode(Record value, OutputStream outStream, Context context)
throws CoderException, IOException {}

@Override
public Record decode(InputStream inStream, Context context)
throws CoderException, IOException {
return new Record();
}

@Override
public boolean equals(Object other) {
return other != null && getClass().equals(other.getClass());
}

@Override
public int hashCode() {
return getClass().hashCode();
@Override
public int hashCode() {
return getClass().hashCode();
}
}
}
}