Skip to content

Commit

Permalink
Add jinja preprocessing to YamlTemplate
Browse files Browse the repository at this point in the history
Signed-off-by: Jeffrey Kinard <jeff@thekinards.com>
  • Loading branch information
Polber committed May 2, 2024
1 parent 45a5b7f commit f715a98
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ WORKDIR $WORKDIR
RUN if ! [ -f requirements.txt ] ; then echo "$BEAM_PACKAGE" > requirements.txt ; fi

# Install dependencies to launch the pipeline and download to reduce startup time
# Remove Jinja2 dependency once YAML templatization support is added to Beam
RUN python -m venv /venv \
&& /venv/bin/pip install --no-cache-dir --upgrade pip setuptools \
&& /venv/bin/pip install --no-cache-dir -U -r $REQUIREMENTS_FILE \
&& /venv/bin/pip install --no-cache-dir -U Jinja2 \
&& /venv/bin/pip download --no-cache-dir --dest /tmp/dataflow-requirements-cache -r $REQUIREMENTS_FILE \
&& rm -rf /usr/local/lib/python$PY_VERSION/site-packages \
&& mv /venv/lib/python$PY_VERSION/site-packages /usr/local/lib/python$PY_VERSION/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,13 @@ public interface YAMLTemplate {
description = "Input YAML pipeline spec file in Cloud Storage.",
helpText = "A file in Cloud Storage containing a yaml description of the pipeline to run.")
String getYamlPipelineFile();

@TemplateParameter.Text(
order = 3,
name = "jinja_variables",
optional = true,
description = "Input jinja preprocessing variables.",
helpText =
"A json dict of variables used when invoking the jinja preprocessor on the provided yaml pipeline.")
String getJinjaVariables();
}
42 changes: 38 additions & 4 deletions python/src/main/python/yaml-template/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,52 @@
# License for the specific language governing permissions and limitations under
# the License.
#

#
import argparse
import json
import logging
import os
from argparse import Namespace

import jinja2

from apache_beam.io.filesystems import FileSystems
from apache_beam.yaml import cache_provider_artifacts
from apache_beam.yaml import main


def run(argv=None):
def _configure_parser(argv):
parser = argparse.ArgumentParser()
_, pipeline_args = parser.parse_known_args(argv)
pipeline_args += ['--sdk_location=container']
parser.add_argument(
'--jinja_variables',
default=None,
type=json.loads,
help='A json dict of variables used when invoking the jinja preprocessor '
'on the provided yaml pipeline.')
return parser.parse_known_args(argv)


class _BeamFileIOLoader(jinja2.BaseLoader):
def get_source(self, environment, path):
source = FileSystems.open(path).read().decode()
return source, path, lambda: True


def run(argv=None):
known_args, pipeline_args = _configure_parser(argv)
known_args = Namespace(
**vars(main._configure_parser(argv)[0]), **vars(known_args))
pipeline_yaml = (
jinja2.Environment(
undefined=jinja2.StrictUndefined, loader=_BeamFileIOLoader())
.from_string(main._pipeline_spec_from_args(known_args))
.render(**known_args.jinja_variables or {}))
pipeline_yaml = os.linesep.join(
[s for s in pipeline_yaml.splitlines() if s.strip()])

pipeline_args += [
'--sdk_location=container', f'--yaml_pipeline={pipeline_yaml}'
]
cache_provider_artifacts.cache_provider_artifacts()
main.run(argv=pipeline_args)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,19 +98,24 @@ private void testSimpleComposite(
}

private String createSimpleYamlMessage() throws IOException {
String yamlMessage =
Files.readString(Paths.get(Resources.getResource("YamlTemplateIT.yaml").getPath()));
yamlMessage = yamlMessage.replaceAll("INPUT_PATH", getGcsBasePath() + "/input/test.csv");
return yamlMessage.replaceAll("OUTPUT_PATH", getGcsBasePath() + "/output");
return Files.readString(Paths.get(Resources.getResource("YamlTemplateIT.yaml").getPath()));
}

private void runYamlTemplateTest(
Function<PipelineLauncher.LaunchConfig.Builder, PipelineLauncher.LaunchConfig.Builder>
paramsAdder)
throws IOException {
// Arrange
String inputPath = getGcsBasePath() + "/input/test.csv";
String outputPath = getGcsBasePath() + "/output";
PipelineLauncher.LaunchConfig.Builder options =
paramsAdder.apply(PipelineLauncher.LaunchConfig.builder(testName, specPath));
paramsAdder.apply(
PipelineLauncher.LaunchConfig.builder(testName, specPath)
.addParameter(
"jinja_variables",
String.format(
"{\"INPUT_PATH_PARAM\": \"%s\", \"OUTPUT_PATH_PARAM\": \"%s\"}",
inputPath, outputPath)));

// Act
PipelineLauncher.LaunchInfo info = launchTemplate(options);
Expand Down
18 changes: 5 additions & 13 deletions python/src/test/resources/YamlTemplateIT.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ pipeline:
transforms:
- type: ReadFromCsv
config:
path: "INPUT_PATH"
path: {{ INPUT_PATH_PARAM }}
- type: MapToFields
name: MapWithErrorHandling
input: ReadFromCsv
Expand Down Expand Up @@ -42,21 +42,13 @@ pipeline:
fields:
sum:
expression: num + inverse
- type: WriteToJsonPython
- type: WriteToJson
name: WriteGoodFiles
input: Sum
config:
path: "OUTPUT_PATH/good"
- type: WriteToJsonPython
path: {{ OUTPUT_PATH_PARAM }}/good
- type: WriteToJson
name: WriteBadFiles
input: TrimErrors
config:
path: "OUTPUT_PATH/bad"

# TODO(polber) - remove with https://github.com/apache/beam/pull/30777
providers:
- type: python
config:
packages: []
transforms:
'WriteToJsonPython': 'apache_beam.io.WriteToJson'
path: {{ OUTPUT_PATH_PARAM }}/bad

0 comments on commit f715a98

Please sign in to comment.