Skip to content

Commit

Permalink
Merge pull request #10802 from boyuanzz/refactor
Browse files Browse the repository at this point in the history
[BEAM-8537] Move wrappers of RestrictionTracker out of iobase
  • Loading branch information
boyuanzz committed Feb 11, 2020
2 parents 6d721c1 + cd6e54b commit bcc3e13
Show file tree
Hide file tree
Showing 7 changed files with 337 additions and 260 deletions.
128 changes: 0 additions & 128 deletions sdks/python/apache_beam/io/iobase.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
import logging
import math
import random
import threading
import uuid
from builtins import object
from builtins import range
Expand Down Expand Up @@ -65,9 +64,7 @@
from apache_beam.utils.windowed_value import WindowedValue

if TYPE_CHECKING:
from apache_beam.io import restriction_trackers
from apache_beam.runners.pipeline_context import PipelineContext
from apache_beam.utils.timestamp import Timestamp

__all__ = [
'BoundedSource',
Expand Down Expand Up @@ -1246,131 +1243,6 @@ def try_claim(self, position):
raise NotImplementedError


class ThreadsafeRestrictionTracker(object):
"""A thread-safe wrapper which wraps a `RestritionTracker`.
This wrapper guarantees synchronization of modifying restrictions across
multi-thread.
"""
def __init__(self, restriction_tracker):
# type: (RestrictionTracker) -> None
if not isinstance(restriction_tracker, RestrictionTracker):
raise ValueError(
'Initialize ThreadsafeRestrictionTracker requires'
'RestrictionTracker.')
self._restriction_tracker = restriction_tracker
# Records an absolute timestamp when defer_remainder is called.
self._deferred_timestamp = None
self._lock = threading.RLock()
self._deferred_residual = None
self._deferred_watermark = None

def current_restriction(self):
with self._lock:
return self._restriction_tracker.current_restriction()

def try_claim(self, position):
with self._lock:
return self._restriction_tracker.try_claim(position)

def defer_remainder(self, deferred_time=None):
"""Performs self-checkpoint on current processing restriction with an
expected resuming time.
Self-checkpoint could happen during processing elements. When executing an
DoFn.process(), you may want to stop processing an element and resuming
later if current element has been processed quit a long time or you also
want to have some outputs from other elements. ``defer_remainder()`` can be
called on per element if needed.
Args:
deferred_time: A relative ``timestamp.Duration`` that indicates the ideal
time gap between now and resuming, or an absolute ``timestamp.Timestamp``
for resuming execution time. If the time_delay is None, the deferred work
will be executed as soon as possible.
"""

# Record current time for calculating deferred_time later.
self._deferred_timestamp = timestamp.Timestamp.now()
if (deferred_time and not isinstance(deferred_time, timestamp.Duration) and
not isinstance(deferred_time, timestamp.Timestamp)):
raise ValueError(
'The timestamp of deter_remainder() should be a '
'Duration or a Timestamp, or None.')
self._deferred_watermark = deferred_time
checkpoint = self.try_split(0)
if checkpoint:
_, self._deferred_residual = checkpoint

def check_done(self):
with self._lock:
return self._restriction_tracker.check_done()

def current_progress(self):
with self._lock:
return self._restriction_tracker.current_progress()

def try_split(self, fraction_of_remainder):
with self._lock:
return self._restriction_tracker.try_split(fraction_of_remainder)

def deferred_status(self):
# type: () -> Optional[Tuple[Any, Timestamp]]

"""Returns deferred work which is produced by ``defer_remainder()``.
When there is a self-checkpoint performed, the system needs to fulfill the
DelayedBundleApplication with deferred_work for a ProcessBundleResponse.
The system calls this API to get deferred_residual with watermark together
to help the runner to schedule a future work.
Returns: (deferred_residual, time_delay) if having any residual, else None.
"""
if self._deferred_residual:
# If _deferred_watermark is None, create Duration(0).
if not self._deferred_watermark:
self._deferred_watermark = timestamp.Duration()
# If an absolute timestamp is provided, calculate the delta between
# the absoluted time and the time deferred_status() is called.
elif isinstance(self._deferred_watermark, timestamp.Timestamp):
self._deferred_watermark = (
self._deferred_watermark - timestamp.Timestamp.now())
# If a Duration is provided, the deferred time should be:
# provided duration - the spent time since the defer_remainder() is
# called.
elif isinstance(self._deferred_watermark, timestamp.Duration):
self._deferred_watermark -= (
timestamp.Timestamp.now() - self._deferred_timestamp)
return self._deferred_residual, self._deferred_watermark
return None


class RestrictionTrackerView(object):
"""A DoFn view of thread-safe RestrictionTracker.
The RestrictionTrackerView wraps a ThreadsafeRestrictionTracker and only
exposes APIs that will be called by a ``DoFn.process()``. During execution
time, the RestrictionTrackerView will be fed into the ``DoFn.process`` as a
restriction_tracker.
"""
def __init__(self, threadsafe_restriction_tracker):
if not isinstance(threadsafe_restriction_tracker,
ThreadsafeRestrictionTracker):
raise ValueError(
'Initialize RestrictionTrackerView requires '
'ThreadsafeRestrictionTracker.')
self._threadsafe_restriction_tracker = threadsafe_restriction_tracker

def current_restriction(self):
return self._threadsafe_restriction_tracker.current_restriction()

def try_claim(self, position):
return self._threadsafe_restriction_tracker.try_claim(position)

def defer_remainder(self, deferred_time=None):
self._threadsafe_restriction_tracker.defer_remainder(deferred_time)


class RestrictionProgress(object):
"""Used to record the progress of a restriction.
Expand Down
86 changes: 1 addition & 85 deletions sdks/python/apache_beam/io/iobase_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@
# limitations under the License.
#

"""Unit tests for the SDFRestrictionProvider module."""
"""Unit tests for classes in iobase.py."""

# pytype: skip-file

from __future__ import absolute_import

import time
import unittest

import mock
Expand All @@ -31,9 +30,6 @@
from apache_beam.io.concat_source_test import RangeSource
from apache_beam.io import iobase
from apache_beam.io.iobase import SourceBundle
from apache_beam.io.restriction_trackers import OffsetRange
from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
from apache_beam.utils import timestamp
from apache_beam.options.pipeline_options import DebugOptions
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
Expand Down Expand Up @@ -227,85 +223,5 @@ def test_sdf_wrap_range_source(self):
self._run_sdf_wrapper_pipeline(RangeSource(0, 4), [0, 1, 2, 3])


class ThreadsafeRestrictionTrackerTest(unittest.TestCase):
def test_initialization(self):
with self.assertRaises(ValueError):
iobase.ThreadsafeRestrictionTracker(RangeSource(0, 1))

def test_defer_remainder_with_wrong_time_type(self):
threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
OffsetRestrictionTracker(OffsetRange(0, 10)))
with self.assertRaises(ValueError):
threadsafe_tracker.defer_remainder(10)

def test_self_checkpoint_immediately(self):
restriction_tracker = OffsetRestrictionTracker(OffsetRange(0, 10))
threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
restriction_tracker)
threadsafe_tracker.defer_remainder()
deferred_residual, deferred_time = threadsafe_tracker.deferred_status()
expected_residual = OffsetRange(0, 10)
self.assertEqual(deferred_residual, expected_residual)
self.assertTrue(isinstance(deferred_time, timestamp.Duration))
self.assertEqual(deferred_time, 0)

def test_self_checkpoint_with_relative_time(self):
threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
OffsetRestrictionTracker(OffsetRange(0, 10)))
threadsafe_tracker.defer_remainder(timestamp.Duration(100))
time.sleep(2)
_, deferred_time = threadsafe_tracker.deferred_status()
self.assertTrue(isinstance(deferred_time, timestamp.Duration))
# The expectation = 100 - 2 - some_delta
self.assertTrue(deferred_time <= 98)

def test_self_checkpoint_with_absolute_time(self):
threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
OffsetRestrictionTracker(OffsetRange(0, 10)))
now = timestamp.Timestamp.now()
schedule_time = now + timestamp.Duration(100)
self.assertTrue(isinstance(schedule_time, timestamp.Timestamp))
threadsafe_tracker.defer_remainder(schedule_time)
time.sleep(2)
_, deferred_time = threadsafe_tracker.deferred_status()
self.assertTrue(isinstance(deferred_time, timestamp.Duration))
# The expectation =
# schedule_time - the time when deferred_status is called - some_delta
self.assertTrue(deferred_time <= 98)


class RestrictionTrackerViewTest(unittest.TestCase):
def test_initialization(self):
with self.assertRaises(ValueError):
iobase.RestrictionTrackerView(
OffsetRestrictionTracker(OffsetRange(0, 10)))

def test_api_expose(self):
threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
OffsetRestrictionTracker(OffsetRange(0, 10)))
tracker_view = iobase.RestrictionTrackerView(threadsafe_tracker)
current_restriction = tracker_view.current_restriction()
self.assertEqual(current_restriction, OffsetRange(0, 10))
self.assertTrue(tracker_view.try_claim(0))
tracker_view.defer_remainder()
deferred_remainder, deferred_watermark = (
threadsafe_tracker.deferred_status())
self.assertEqual(deferred_remainder, OffsetRange(1, 10))
self.assertEqual(deferred_watermark, timestamp.Duration())

def test_non_expose_apis(self):
threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
OffsetRestrictionTracker(OffsetRange(0, 10)))
tracker_view = iobase.RestrictionTrackerView(threadsafe_tracker)
with self.assertRaises(AttributeError):
tracker_view.check_done()
with self.assertRaises(AttributeError):
tracker_view.current_progress()
with self.assertRaises(AttributeError):
tracker_view.try_split()
with self.assertRaises(AttributeError):
tracker_view.deferred_status()


if __name__ == '__main__':
unittest.main()

0 comments on commit bcc3e13

Please sign in to comment.