Skip to content

Commit

Permalink
Add jinja preprocessing to YamlTemplate
Browse files Browse the repository at this point in the history
Signed-off-by: Jeffrey Kinard <jeff@thekinards.com>
  • Loading branch information
Polber committed May 7, 2024
1 parent 45a5b7f commit f6e9709
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ WORKDIR $WORKDIR
RUN if ! [ -f requirements.txt ] ; then echo "$BEAM_PACKAGE" > requirements.txt ; fi

# Install dependencies to launch the pipeline and download to reduce startup time
# Remove Jinja2 dependency once YAML templatization support is added to Beam
# Build the launch environment inside a throwaway venv, pre-download the
# requirements cache for Dataflow, then promote the venv's site-packages to
# the system interpreter so the container starts without re-resolving deps.
RUN python -m venv /venv \
&& /venv/bin/pip install --no-cache-dir --upgrade pip setuptools \
&& /venv/bin/pip install --no-cache-dir -U -r $REQUIREMENTS_FILE \
&& /venv/bin/pip install --no-cache-dir -U Jinja2 \
&& /venv/bin/pip download --no-cache-dir --dest /tmp/dataflow-requirements-cache -r $REQUIREMENTS_FILE \
&& rm -rf /usr/local/lib/python$PY_VERSION/site-packages \
&& mv /venv/lib/python$PY_VERSION/site-packages /usr/local/lib/python$PY_VERSION/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,13 @@ public interface YAMLTemplate {
description = "Input YAML pipeline spec file in Cloud Storage.",
helpText = "A file in Cloud Storage containing a yaml description of the pipeline to run.")
String getYamlPipelineFile();

/**
 * Optional json map of variables passed to the jinja preprocessor when the yaml
 * pipeline spec is rendered at launch time (e.g. {"INPUT_PATH": "gs://..."}).
 */
@TemplateParameter.Text(
order = 3,
name = "jinja_variables",
optional = true,
description = "Input jinja preprocessing variables.",
helpText =
"A json dict of variables used when invoking the jinja preprocessor on the provided yaml pipeline.")
String getJinjaVariables();
}
147 changes: 139 additions & 8 deletions python/src/main/python/yaml-template/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,153 @@
# License for the specific language governing permissions and limitations under
# the License.
#
#
# import argparse
# import json
# import logging
# import os
# from argparse import Namespace
#
# import jinja2
#
# from apache_beam.io.filesystems import FileSystems
# from apache_beam.yaml import cache_provider_artifacts
# from apache_beam.yaml import main
#
#
# def _configure_parser(argv):
# parser = argparse.ArgumentParser()
# parser.add_argument(
# '--jinja_variables',
# default=None,
# type=json.loads,
# help='A json dict of variables used when invoking the jinja preprocessor '
# 'on the provided yaml pipeline.')
# return parser.parse_known_args(argv)
#
#
# class _BeamFileIOLoader(jinja2.BaseLoader):
# def get_source(self, environment, path):
# source = FileSystems.open(path).read().decode()
# return source, path, lambda: True
#
#
# def run(argv=None):
# known_args, pipeline_args = _configure_parser(argv)
# known_args = Namespace(
# **vars(main._configure_parser(argv)[0]), **vars(known_args))
# pipeline_yaml = (
# jinja2.Environment(
# undefined=jinja2.StrictUndefined, loader=_BeamFileIOLoader())
# .from_string(main._pipeline_spec_from_args(known_args))
# .render(**known_args.jinja_variables or {}))
# pipeline_yaml = os.linesep.join(
# [s for s in pipeline_yaml.splitlines() if s.strip()])
#
# pipeline_args = [
# '--sdk_location=container', f'--yaml_pipeline={pipeline_yaml}'
# ] + [
# f'--{name}={value}' for name,
# value in vars(known_args).items()
# if name != "yaml_pipeline" and name != "yaml_pipeline_file"
# ]
# cache_provider_artifacts.cache_provider_artifacts()
# main.run(argv=pipeline_args)
#
#
# if __name__ == '__main__':
# logging.getLogger().setLevel(logging.INFO)
# run()


import argparse
import logging
import json

import jinja2
import yaml

import apache_beam as beam
from apache_beam.io.filesystems import FileSystems
from apache_beam.typehints.schemas import LogicalType
from apache_beam.typehints.schemas import MillisInstant
from apache_beam.yaml import cache_provider_artifacts
from apache_beam.yaml import main
from apache_beam.yaml import yaml_transform

# Workaround for https://github.com/apache/beam/issues/28151.
LogicalType.register_logical_type(MillisInstant)

def run(argv=None):

def _configure_parser(argv):
parser = argparse.ArgumentParser()
_, pipeline_args = parser.parse_known_args(argv)
pipeline_args += ['--sdk_location=container']
cache_provider_artifacts.cache_provider_artifacts()
main.run(argv=pipeline_args)
parser.add_argument(
'--yaml_pipeline',
'--pipeline_spec',
help='A yaml description of the pipeline to run.')
parser.add_argument(
'--yaml_pipeline_file',
'--pipeline_spec_file',
help='A file containing a yaml description of the pipeline to run.')
parser.add_argument(
'--json_schema_validation',
default='generic',
help='none: do no pipeline validation against the schema; '
'generic: validate the pipeline shape, but not individual transforms; '
'per_transform: also validate the config of known transforms')
parser.add_argument(
'--jinja_variables',
default=None,
type=json.loads,
help='A json dict of variables used when invoking the jinja preprocessor '
'on the provided yaml pipeline.')
return parser.parse_known_args(argv)


def _pipeline_spec_from_args(known_args):
if known_args.yaml_pipeline_file and known_args.yaml_pipeline:
raise ValueError(
"Exactly one of yaml_pipeline or yaml_pipeline_file must be set.")
elif known_args.yaml_pipeline_file:
with FileSystems.open(known_args.yaml_pipeline_file) as fin:
pipeline_yaml = fin.read().decode()
elif known_args.yaml_pipeline:
pipeline_yaml = known_args.yaml_pipeline
else:
raise ValueError(
"Exactly one of yaml_pipeline or yaml_pipeline_file must be set.")

return pipeline_yaml


class _BeamFileIOLoader(jinja2.BaseLoader):
  """Jinja template loader that resolves paths via Beam's FileSystems.

  Lets jinja directives in the pipeline spec load templates from any
  filesystem Beam supports (e.g. gs:// paths), not just the local disk.
  """

  def get_source(self, environment, path):
    """Load the template at *path*.

    Defect fixed: the original left the FileSystems handle unclosed
    (FileSystems.open(path).read()); the handle is now released
    deterministically via a context manager.
    """
    with FileSystems.open(path) as handle:
      source = handle.read().decode()
    # Third element is the "uptodate" callable; templates are treated as
    # never stale, matching the original behavior.
    return source, path, lambda: True


def run(argv=None):
  """Entry point: render the jinja-templated yaml spec and run the pipeline.

  The yaml spec (inline or from a file) is first passed through the jinja
  preprocessor so {{ var }} placeholders are substituted from the
  --jinja_variables json dict, then parsed and expanded into a Beam
  pipeline.

  Args:
    argv: Optional command-line arguments; unknown args become pipeline
      options.
  """
  known_args, pipeline_args = _configure_parser(argv)
  # StrictUndefined makes a missing jinja variable fail fast at render time
  # instead of silently producing an empty string in the yaml.
  pipeline_yaml = (  # keep formatting
      jinja2.Environment(
          undefined=jinja2.StrictUndefined, loader=_BeamFileIOLoader())
      .from_string(_pipeline_spec_from_args(known_args))
      # Parentheses make the fallback explicit; jinja_variables defaults
      # to None when the flag is absent.
      .render(**(known_args.jinja_variables or {})))
  # SafeLineLoader annotates nodes with line numbers for error reporting.
  pipeline_spec = yaml.load(pipeline_yaml, Loader=yaml_transform.SafeLineLoader)

  with beam.Pipeline(  # linebreak for better yapf formatting
      options=beam.options.pipeline_options.PipelineOptions(
          pipeline_args,
          pickle_library='cloudpickle',
          # Options may also be declared inline in the spec under 'options';
          # strip the line-number metadata SafeLineLoader attached.
          **yaml_transform.SafeLineLoader.strip_metadata(pipeline_spec.get(
              'options', {}))),
      display_data={'yaml': pipeline_yaml}) as p:
    # Use logging (configured in __main__) rather than print so progress
    # shows up in the template worker logs.
    logging.info("Building pipeline...")
    yaml_transform.expand_pipeline(
        p, pipeline_spec, validate_schema=known_args.json_schema_validation)
    logging.info("Running pipeline...")


if __name__ == '__main__':
  # Defect fixed: this span contained stale removed-diff lines (a local
  # 'import logging' shadowing the module-level import, plus a duplicate
  # cache_provider_artifacts()/run() pair from the legacy version).
  # INFO level so pipeline construction progress is visible in job logs.
  logging.getLogger().setLevel(logging.INFO)
  run()
Original file line number Diff line number Diff line change
Expand Up @@ -98,19 +98,24 @@ private void testSimpleComposite(
}

private String createSimpleYamlMessage() throws IOException {
  // Defect fixed: stale removed-diff lines (the old INPUT_PATH/OUTPUT_PATH
  // replaceAll substitution) were interleaved here. Placeholder substitution
  // is now done by the jinja preprocessor at launch time via the
  // jinja_variables template parameter, so the resource is returned raw.
  return Files.readString(Paths.get(Resources.getResource("YamlTemplateIT.yaml").getPath()));
}

private void runYamlTemplateTest(
Function<PipelineLauncher.LaunchConfig.Builder, PipelineLauncher.LaunchConfig.Builder>
paramsAdder)
throws IOException {
// Arrange
String inputPath = getGcsBasePath() + "/input/test.csv";
String outputPath = getGcsBasePath() + "/output";
PipelineLauncher.LaunchConfig.Builder options =
paramsAdder.apply(PipelineLauncher.LaunchConfig.builder(testName, specPath));
paramsAdder.apply(
PipelineLauncher.LaunchConfig.builder(testName, specPath)
.addParameter(
"jinja_variables",
String.format(
"{\"INPUT_PATH_PARAM\": \"%s\", \"OUTPUT_PATH_PARAM\": \"%s\"}",
inputPath, outputPath)));

// Act
PipelineLauncher.LaunchInfo info = launchTemplate(options);
Expand Down
18 changes: 5 additions & 13 deletions python/src/test/resources/YamlTemplateIT.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ pipeline:
transforms:
- type: ReadFromCsv
config:
path: "INPUT_PATH"
path: {{ INPUT_PATH_PARAM }}
- type: MapToFields
name: MapWithErrorHandling
input: ReadFromCsv
Expand Down Expand Up @@ -42,21 +42,13 @@ pipeline:
fields:
sum:
expression: num + inverse
- type: WriteToJsonPython
- type: WriteToJson
name: WriteGoodFiles
input: Sum
config:
path: "OUTPUT_PATH/good"
- type: WriteToJsonPython
path: {{ OUTPUT_PATH_PARAM }}/good
- type: WriteToJson
name: WriteBadFiles
input: TrimErrors
config:
path: "OUTPUT_PATH/bad"

# TODO(polber) - remove with https://github.com/apache/beam/pull/30777
providers:
- type: python
config:
packages: []
transforms:
'WriteToJsonPython': 'apache_beam.io.WriteToJson'
path: {{ OUTPUT_PATH_PARAM }}/bad

0 comments on commit f6e9709

Please sign in to comment.