<a href="https://colab.research.google.com/github/svetakvsundhar/beam/blob/testing_blog_post/examples/notebooks/blogposts/unittests_in_beam.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [23]:
# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the "License")

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License

In [24]:
# Install the Apache Beam library

!pip install apache_beam[gcp] --quiet

In [25]:
#The following packages are used to run the example pipelines

import apache_beam as beam
from apache_beam.io import ReadFromText, WriteToText
from apache_beam.options.pipeline_options import PipelineOptions

class CustomClass(beam.DoFn):
  """Holder for the example record-processing function used by the pipeline.

  The API call and the record-length validation inside ``custom_function``
  are commented-out placeholders, so the function currently passes each
  record through unchanged.
  """

  @staticmethod
  def custom_function(x):
    """Return the record ``x`` unchanged.

    Declared as a static method because it takes no ``self`` and is always
    invoked on the class itself (``CustomClass.custom_function(x)``).

    Args:
      x: A single record read by the pipeline.

    Returns:
      The same record, unmodified.
    """
    ...
    # returned_record = requests.get("http://my-api-call.com")
    ...
    # if len(returned_record)!=10:
    # raise ValueError("Length of record does not match expected length")
    return x


# The pipeline is built at module level, after the class statement has
# finished. Inside the class body the name `CustomClass` is not bound yet,
# so the lambda below could not resolve it when the pipeline runs on a fresh
# kernel.
with beam.Pipeline() as p:
  result = (
      p
      | ReadFromText("/content/sample_data/anscombe.json")
      # beam.Map emits exactly one output per input. ParDo on a plain
      # callable treats the return value as an iterable of outputs, and Beam
      # rejects a returned str.
      | beam.Map(lambda x: CustomClass.custom_function(x))
      | WriteToText("/content/")
  )



**Example Pipeline 1**


In [26]:
# Returns the square of the integer found at index 1 of the record.
def compute_square(element):
  """Square the integer stored at index 1 of ``element``.

  Args:
    element: An indexable record (for example a string or a list) whose item
      at index 1 can be converted with ``int()``.

  Returns:
    The square of that integer.
  """
  value = int(element[1])
  return value * value

# Example pipeline 1: read the CSV file (skipping its header line), apply
# compute_square to every line, and write the results out as text.
with beam.Pipeline() as p1:
    housing_lines = p1 | ReadFromText(
        "/content/sample_data/california_housing_test.csv",
        skip_header_lines=1)
    squared_values = housing_lines | beam.Map(compute_square)
    result = squared_values | WriteToText("/content/")



**Example Pipeline 2**

In [27]:
# Example pipeline 2: read the file line by line, strip leading and trailing
# whitespace from every line, and write the results out as text.
with beam.Pipeline() as p2:
    raw_lines = p2 | ReadFromText("/content/sample_data/anscombe.json")
    stripped_lines = raw_lines | beam.Map(str.strip)
    result = stripped_lines | WriteToText("/content/sample_data/")



**Unit Tests for Pipelines**

In [28]:
# The following packages are imported for unit testing.
import unittest
import apache_beam as beam
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that, equal_to
try:
  from apitools.base.py.exceptions import HttpError
except ImportError:
  HttpError = None


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestBeam(unittest.TestCase):
  """Unit tests for the transforms used by the example pipelines."""

  # This test corresponds to pipeline p1, and is written to confirm the
  # compute_square function works as intended.
  def test_compute_square(self):
    """compute_square squares the integer at index 1 of the record."""
    with TestPipeline() as p:
      output = (
          p
          | beam.Create(["1234"])
          | beam.Map(compute_square))
      # assert_that adds a verification step to the pipeline, so it must be
      # applied inside the `with` block. The pipeline runs when the block
      # exits; an assertion attached afterwards is never executed and the
      # test would pass regardless of the output.
      assert_that(output, equal_to([4]))

In [29]:
# This test corresponds to pipeline p2, and is written to confirm the pipeline works as intended.
# NOTE(review): this function takes `self` but is defined at module level, so
# unittest will only collect it once it is placed on a unittest.TestCase
# subclass.
def test_strip_map(self):
  """str.strip removes surrounding whitespace from every element."""
  strings = [' Strawberry   \n', '   Carrot   \n', '   Eggplant   \n']
  with TestPipeline() as p:
    output = (
        p
        | beam.Create(strings)
        | beam.Map(str.strip))
    # assert_that adds a verification step to the pipeline, so it must be
    # applied inside the `with` block. The pipeline runs when the block
    # exits; an assertion attached afterwards is never executed.
    assert_that(output, equal_to(['Strawberry', 'Carrot', 'Eggplant']))

**Mocking Example**

In [30]:
!pip install mock  # Install the 'mock' module



In [31]:
# We import the mock package for mocking functionality.
import mock

@mock.patch.object(CustomClass, 'custom_function')
def test_error_message_wrong_length(self, get_record):
  """A wrong-length record surfaces as a ValueError with the expected message.

  ``mock.patch.object`` replaces ``CustomClass.custom_function`` with a mock
  for the duration of the test and passes that mock in as ``get_record``.
  The mock is configured to raise the error that the length validation would
  raise, so no real API call is needed.

  NOTE(review): this function uses ``self.assertRaisesRegex`` but is defined
  at module level; it needs to live on a unittest.TestCase subclass to be
  collected and run.

  Args:
    get_record: The mock standing in for ``CustomClass.custom_function``.
  """
  # A record with only two fields, i.e. not the expected length.
  record = ["field1", "field2"]
  # Make the patched function fail the same way the validation would. A
  # plain return_value can never trigger the ValueError asserted below.
  get_record.side_effect = ValueError(
      "Length of record does not match expected length")
  # The pattern must match the message exactly as raised (no trailing quote).
  with self.assertRaisesRegex(
      ValueError, "Length of record does not match expected length"):
    CustomClass.custom_function(record)
  # Confirm the patched function was the one invoked, with the bad record.
  get_record.assert_called_once_with(record)