Skip to content

Commit

Permalink
Merge pull request #17418 from ihji/BEAM-14343
Browse files Browse the repository at this point in the history
[BEAM-14343] Allow expansion service override in ExternalPythonTransform
  • Loading branch information
ihji committed Apr 26, 2022
2 parents a936ef8 + 43f1173 commit 07f30d2
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.beam.model.pipeline.v1.ExternalTransforms;
import org.apache.beam.repackaged.core.org.apache.commons.lang3.ClassUtils;
import org.apache.beam.runners.core.construction.External;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.schemas.JavaFieldSchema;
import org.apache.beam.sdk.schemas.NoSuchSchemaException;
Expand All @@ -46,6 +47,8 @@
import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Splitter;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;
Expand All @@ -56,6 +59,7 @@ public class ExternalPythonTransform<InputT extends PInput, OutputT extends POut

private static final SchemaRegistry SCHEMA_REGISTRY = SchemaRegistry.createDefault();
private String fullyQualifiedName;
private String expansionService;

// We preserve the order here since Schemas care about the order of fields, but the order will
// not matter when applying kwargs at the Python side.
Expand All @@ -64,8 +68,9 @@ public class ExternalPythonTransform<InputT extends PInput, OutputT extends POut
private @Nullable Object @NonNull [] argsArray;
private @Nullable Row providedKwargsRow;

private ExternalPythonTransform(String fullyQualifiedName) {
private ExternalPythonTransform(String fullyQualifiedName, String expansionService) {
this.fullyQualifiedName = fullyQualifiedName;
this.expansionService = expansionService;
this.kwargsMap = new TreeMap<>();
argsArray = new Object[] {};
}
Expand All @@ -80,7 +85,21 @@ private ExternalPythonTransform(String fullyQualifiedName) {
*/
public static <InputT extends PInput, OutputT extends POutput>
ExternalPythonTransform<InputT, OutputT> from(String tranformName) {
return new ExternalPythonTransform<InputT, OutputT>(tranformName);
return new ExternalPythonTransform<>(tranformName, "");
}

/**
 * Creates a cross-language wrapper for the named Python transform that is bound to a
 * caller-supplied expansion service address.
 *
 * <p>When the address is non-empty, {@code expand} splits it on {@code ':'} to obtain the host and
 * port to wait on, and passes the whole string through as the expansion service address. When it
 * is null or empty, {@code expand} starts its own Python expansion service on a free local port.
 *
 * @param tranformName fully qualified transform name.
 * @param expansionService address and port number of an externally launched expansion service,
 *     in {@code host:port} form
 * @param <InputT> Input {@link PCollection} type
 * @param <OutputT> Output {@link PCollection} type
 * @return A {@link ExternalPythonTransform} for the given transform name and expansion service.
 */
public static <InputT extends PInput, OutputT extends POutput>
    ExternalPythonTransform<InputT, OutputT> from(String tranformName, String expansionService) {
  ExternalPythonTransform<InputT, OutputT> wrapper =
      new ExternalPythonTransform<>(tranformName, expansionService);
  return wrapper;
}

/**
Expand Down Expand Up @@ -242,64 +261,95 @@ private Schema generateSchemaFromFieldValues(
return generateSchemaDirectly(fieldValues, fieldNames);
}

@Override
public OutputT expand(InputT input) {
int port;
@VisibleForTesting
ExternalTransforms.ExternalConfigurationPayload generatePayload() {
Row argsRow = buildOrGetArgsRow();
Row kwargsRow = buildOrGetKwargsRow();
Schema.Builder schemaBuilder = Schema.builder();
schemaBuilder.addStringField("constructor");
if (argsRow.getValues().size() > 0) {
schemaBuilder.addRowField("args", argsRow.getSchema());
}
if (kwargsRow.getValues().size() > 0) {
schemaBuilder.addRowField("kwargs", kwargsRow.getSchema());
}
Schema payloadSchema = schemaBuilder.build();
payloadSchema.setUUID(UUID.randomUUID());
Row.Builder payloadRowBuilder = Row.withSchema(payloadSchema);
payloadRowBuilder.addValue(fullyQualifiedName);
if (argsRow.getValues().size() > 0) {
payloadRowBuilder.addValue(argsRow);
}
if (kwargsRow.getValues().size() > 0) {
payloadRowBuilder.addValue(kwargsRow);
}
try {
port = PythonService.findAvailablePort();
PythonService service =
new PythonService(
"apache_beam.runners.portability.expansion_service_main",
"--port",
"" + port,
"--fully_qualified_name_glob",
"*");
Schema payloadSchema =
Schema.of(
Schema.Field.of("constructor", Schema.FieldType.STRING),
Schema.Field.of("args", Schema.FieldType.row(argsRow.getSchema())),
Schema.Field.of("kwargs", Schema.FieldType.row(kwargsRow.getSchema())));
payloadSchema.setUUID(UUID.randomUUID());
Row payloadRow =
Row.withSchema(payloadSchema).addValues(fullyQualifiedName, argsRow, kwargsRow).build();
ExternalTransforms.ExternalConfigurationPayload payload =
ExternalTransforms.ExternalConfigurationPayload.newBuilder()
.setSchema(SchemaTranslation.schemaToProto(payloadSchema, true))
.setPayload(
ByteString.copyFrom(
CoderUtils.encodeToByteArray(RowCoder.of(payloadSchema), payloadRow)))
.build();
try (AutoCloseable p = service.start()) {
PythonService.waitForPort("localhost", port, 15000);
PTransform<PInput, PCollectionTuple> transform =
External.<PInput, Object>of(
"beam:transforms:python:fully_qualified_named",
payload.toByteArray(),
"localhost:" + port)
.withMultiOutputs();
PCollectionTuple outputs;
if (input instanceof PCollection) {
outputs = ((PCollection<?>) input).apply(transform);
} else if (input instanceof PCollectionTuple) {
outputs = ((PCollectionTuple) input).apply(transform);
} else if (input instanceof PBegin) {
outputs = ((PBegin) input).apply(transform);
} else {
throw new RuntimeException("Unhandled input type " + input.getClass());
}
Set<TupleTag<?>> tags = outputs.getAll().keySet();
if (tags.size() == 1) {
return (OutputT) outputs.get(Iterables.getOnlyElement(tags));
} else {
return (OutputT) outputs;
return ExternalTransforms.ExternalConfigurationPayload.newBuilder()
.setSchema(SchemaTranslation.schemaToProto(payloadSchema, true))
.setPayload(
ByteString.copyFrom(
CoderUtils.encodeToByteArray(
RowCoder.of(payloadSchema), payloadRowBuilder.build())))
.build();
} catch (CoderException e) {
throw new RuntimeException(e);
}
}

/**
 * Expands this transform by delegating to a Python expansion service.
 *
 * <p>If an expansion service address was configured, it must be of the form {@code host:port};
 * this method waits for that port and then applies the external transform against the configured
 * address. Otherwise a Python expansion service is started on a free local port, and the handle
 * returned by {@code service.start()} is closed once the transform has been applied.
 *
 * @param input the input this transform is applied to
 * @return the value produced by {@code apply} for the external transform
 * @throws IllegalArgumentException if the configured expansion service address does not contain
 *     a host and a numeric port separated by {@code ':'}
 * @throws RuntimeException wrapping any checked exception raised while starting, waiting for, or
 *     closing the expansion service; unchecked exceptions are rethrown unchanged
 */
@Override
public OutputT expand(InputT input) {
  try {
    ExternalTransforms.ExternalConfigurationPayload payload = generatePayload();
    if (!Strings.isNullOrEmpty(expansionService)) {
      // Split the address once and validate it up front, so that a malformed value fails with a
      // message naming the offending address rather than an index or number-format error.
      Iterable<String> addressParts = Splitter.on(':').split(expansionService);
      if (Iterables.size(addressParts) < 2) {
        throw new IllegalArgumentException(
            String.format(
                "Expansion service address '%s' must be of the form host:port",
                expansionService));
      }
      String host = Iterables.get(addressParts, 0);
      String portText = Iterables.get(addressParts, 1);
      int externalPort;
      try {
        externalPort = Integer.parseInt(portText);
      } catch (NumberFormatException e) {
        throw new IllegalArgumentException(
            String.format(
                "Expansion service address '%s' has a non-numeric port '%s'",
                expansionService, portText),
            e);
      }
      PythonService.waitForPort(host, externalPort, 15000);
      return apply(input, expansionService, payload);
    } else {
      int port = PythonService.findAvailablePort();
      PythonService service =
          new PythonService(
              "apache_beam.runners.portability.expansion_service_main",
              "--port",
              "" + port,
              "--fully_qualified_name_glob",
              "*");
      try (AutoCloseable p = service.start()) {
        PythonService.waitForPort("localhost", port, 15000);
        return apply(input, String.format("localhost:%s", port), payload);
      }
    }
  } catch (RuntimeException exn) {
    // Unchecked exceptions (including the validation errors above) propagate as-is.
    throw exn;
  } catch (Exception exn) {
    throw new RuntimeException(exn);
  }
}

/**
 * Builds the external transform for the {@code beam:transforms:python:fully_qualified_named} URN
 * and applies it to {@code input}.
 *
 * @param input the input to apply the transform to; must be a {@link PCollection}, a {@link
 *     PCollectionTuple} or a {@link PBegin}
 * @param expansionService address of the expansion service handed to {@code External.of}
 * @param payload configuration payload, sent in its serialized byte form
 * @return the single output collection when the expansion yields exactly one tagged output,
 *     otherwise the whole {@link PCollectionTuple}
 * @throws RuntimeException if {@code input} is none of the supported input types
 */
private OutputT apply(
    InputT input,
    String expansionService,
    ExternalTransforms.ExternalConfigurationPayload payload) {
  PTransform<PInput, PCollectionTuple> externalTransform =
      External.of(
              "beam:transforms:python:fully_qualified_named",
              payload.toByteArray(),
              expansionService)
          .withMultiOutputs();

  // Each supported input kind exposes its own apply(), so dispatch on the runtime type.
  PCollectionTuple expanded;
  if (input instanceof PCollection) {
    expanded = ((PCollection<?>) input).apply(externalTransform);
  } else if (input instanceof PCollectionTuple) {
    expanded = ((PCollectionTuple) input).apply(externalTransform);
  } else if (input instanceof PBegin) {
    expanded = ((PBegin) input).apply(externalTransform);
  } else {
    throw new RuntimeException("Unhandled input type " + input.getClass());
  }

  Set<TupleTag<?>> outputTags = expanded.getAll().keySet();
  if (outputTags.size() != 1) {
    // Zero or several outputs: hand back the tuple itself.
    return (OutputT) expanded;
  }
  // Exactly one output: unwrap it so callers get the collection directly.
  return (OutputT) expanded.get(Iterables.getOnlyElement(outputTags));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@
package org.apache.beam.sdk.extensions.python;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

import java.io.Serializable;
import java.util.Map;
import org.apache.beam.model.pipeline.v1.ExternalTransforms;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaTranslation;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.MapElements;
Expand Down Expand Up @@ -118,6 +122,38 @@ public void generateArgsWithRow() {
assertEquals(expectedRow, receivedRow);
}

/** With only positional args supplied, the payload schema has "args" but no "kwargs" field. */
@Test
public void generatePayloadWithoutKwargs() throws Exception {
  ExternalPythonTransform<?, ?> argsOnlyTransform =
      ExternalPythonTransform
          .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
              "DummyTransform")
          .withArgs("aaa", "bbb", 11, 12L, 15.6, true);

  // Decode the schema embedded in the generated payload and inspect its top-level fields.
  Schema payloadSchema =
      SchemaTranslation.schemaFromProto(argsOnlyTransform.generatePayload().getSchema());

  assertTrue(payloadSchema.hasField("args"));
  assertFalse(payloadSchema.hasField("kwargs"));
}

/** With only keyword args supplied, the payload schema has "kwargs" but no "args" field. */
@Test
public void generatePayloadWithoutArgs() {
  ExternalPythonTransform<?, ?> kwargsOnlyTransform =
      ExternalPythonTransform
          .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
              "DummyTransform")
          .withKwarg("stringField1", "aaa")
          .withKwarg("stringField2", "bbb")
          .withKwarg("intField", 11)
          .withKwarg("longField", 12L)
          .withKwarg("doubleField", 15.6)
          .withKwarg("boolField", true);

  // Decode the schema embedded in the generated payload and inspect its top-level fields.
  Schema payloadSchema =
      SchemaTranslation.schemaFromProto(kwargsOnlyTransform.generatePayload().getSchema());

  assertFalse(payloadSchema.hasField("args"));
  assertTrue(payloadSchema.hasField("kwargs"));
}

static class CustomType {
int intField;
String strField;
Expand Down

0 comments on commit 07f30d2

Please sign in to comment.