Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections import defaultdict
from azure.durable_functions.models.actions.SignalEntityAction import SignalEntityAction
from azure.durable_functions.models.actions.CallEntityAction import CallEntityAction
from azure.durable_functions.models.Task import TaskBase
from azure.durable_functions.models.Task import TaskBase, TimerTask
from azure.durable_functions.models.actions.CallHttpAction import CallHttpAction
from azure.durable_functions.models.DurableHttpRequest import DurableHttpRequest
from azure.durable_functions.models.actions.CallSubOrchestratorWithRetryAction import \
Expand Down Expand Up @@ -100,7 +100,8 @@ def from_json(cls, json_string: str):
def _generate_task(self, action: Action,
retry_options: Optional[RetryOptions] = None,
id_: Optional[Union[int, str]] = None,
parent: Optional[TaskBase] = None) -> Union[AtomicTask, RetryAbleTask]:
parent: Optional[TaskBase] = None,
task_constructor=AtomicTask) -> Union[AtomicTask, RetryAbleTask, TimerTask]:
"""Generate an atomic or retryable Task based on an input.

Parameters
Expand All @@ -124,7 +125,7 @@ def _generate_task(self, action: Action,
action_payload = [action]
else:
action_payload = action
task = AtomicTask(id_, action_payload)
task = task_constructor(id_, action_payload)
task.parent = parent

# if task is retryable, provide the retryable wrapper class
Expand Down Expand Up @@ -517,7 +518,7 @@ def create_timer(self, fire_at: datetime.datetime) -> TaskBase:
A Durable Timer Task that schedules the timer to wake up the activity
"""
action = CreateTimerAction(fire_at)
task = self._generate_task(action)
task = self._generate_task(action, task_constructor=TimerTask)
return task

def wait_for_external_event(self, name: str) -> TaskBase:
Expand Down
90 changes: 82 additions & 8 deletions azure/durable_functions/models/Task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from azure.durable_functions.models.actions.Action import Action
from azure.durable_functions.models.actions.WhenAnyAction import WhenAnyAction
from azure.durable_functions.models.actions.WhenAllAction import WhenAllAction
from azure.durable_functions.models.actions.CreateTimerAction import CreateTimerAction

import enum
from typing import Any, List, Optional, Set, Type, Union
Expand Down Expand Up @@ -56,6 +57,14 @@ def __init__(self, id_: Union[int, str], actions: Union[List[Action], Action]):
self.action_repr: Union[List[Action], Action] = actions
self.is_played = False

@property
def is_completed(self) -> bool:
"""Get indicator of whether the task completed.

Note that completion is not equivalent to success.
"""
return not(self.state is TaskState.RUNNING)

def set_is_played(self, is_played: bool):
"""Set the is_played flag for the Task.

Expand Down Expand Up @@ -159,6 +168,9 @@ def __init__(self, tasks: List[TaskBase], compound_action_constructor=None):
self.completed_tasks: List[TaskBase] = []
self.children = tasks

if len(self.children) == 0:
self.state = TaskState.SUCCEEDED

def handle_completion(self, child: TaskBase):
"""Manage sub-task completion events.

Expand Down Expand Up @@ -205,7 +217,47 @@ def try_set_value(self, child: TaskBase):
class AtomicTask(TaskBase):
"""A Task with no subtasks."""

pass
def _get_action(self) -> Action:
action: Action
if isinstance(self.action_repr, list):
action = self.action_repr[0]
else:
action = self.action_repr
return action


class TimerTask(AtomicTask):
"""A Timer Task."""

def __init__(self, id_: Union[int, str], action: CreateTimerAction):
super().__init__(id_, action)
self.action_repr: Union[List[CreateTimerAction], CreateTimerAction]

@property
def is_cancelled(self) -> bool:
"""Check if the Timer is cancelled.

Returns
-------
bool
Returns whether a timer has been cancelled or not
"""
action: CreateTimerAction = self._get_action()
return action.is_cancelled

def cancel(self):
"""Cancel a timer.

Raises
------
ValueError
Raises an error if the task is already completed and an attempt is made to cancel it
"""
if not self.is_completed:
action: CreateTimerAction = self._get_action()
action.is_cancelled = True
else:
raise ValueError("Cannot cancel a completed task.")


class WhenAllTask(CompoundTask):
Expand Down Expand Up @@ -238,7 +290,7 @@ def try_set_value(self, child: TaskBase):
# A WhenAll Task only completes when it has no pending tasks
# i.e _when all_ of its children have completed
if len(self.pending_tasks) == 0:
results = list(map(lambda x: x.result, self.completed_tasks))
results = list(map(lambda x: x.result, self.children))
self.set_value(is_error=False, value=results)
else: # child.state is TaskState.FAILED:
# a single error is sufficient to fail this task
Expand Down Expand Up @@ -287,14 +339,28 @@ class RetryAbleTask(WhenAllTask):
"""

def __init__(self, child: TaskBase, retry_options: RetryOptions, context):
self.id_ = str(child.id) + "_retryable_proxy"
tasks = [child]
super().__init__(tasks, context._replay_schema)

self.retry_options = retry_options
self.num_attempts = 1
self.context = context
self.actions = child.action_repr
self.is_waiting_on_timer = False

@property
def id_(self):
"""Obtain the task's ID.

Since this is an internal-only abstraction, the task ID is represented
by the ID of its inner/wrapped task _plus_ a suffix: "_retryable_proxy"

Returns
-------
[type]
[description]
"""
return str(list(map(lambda x: x.id, self.children))) + "_retryable_proxy"

def try_set_value(self, child: TaskBase):
"""Transition a Retryable Task to a terminal state and set its value.
Expand All @@ -304,6 +370,14 @@ def try_set_value(self, child: TaskBase):
child : TaskBase
A sub-task that just completed
"""
if self.is_waiting_on_timer:
# timer fired, re-scheduling original task
self.is_waiting_on_timer = False
rescheduled_task = self.context._generate_task(
action=NoOpAction("rescheduled task"), parent=self)
self.pending_tasks.add(rescheduled_task)
self.context._add_to_open_tasks(rescheduled_task)
return
if child.state is TaskState.SUCCEEDED:
if len(self.pending_tasks) == 0:
# if all pending tasks have completed,
Expand All @@ -318,11 +392,11 @@ def try_set_value(self, child: TaskBase):
else:
# still have some retries left.
# increase size of pending tasks by adding a timer task
# and then re-scheduling the current task after that
timer_task = self.context._generate_task(action=NoOpAction(), parent=self)
# when it completes, we'll retry the original task
timer_task = self.context._generate_task(
action=NoOpAction("-WithRetry timer"), parent=self)
self.pending_tasks.add(timer_task)
self.context._add_to_open_tasks(timer_task)
rescheduled_task = self.context._generate_task(action=NoOpAction(), parent=self)
self.pending_tasks.add(rescheduled_task)
self.context._add_to_open_tasks(rescheduled_task)
self.is_waiting_on_timer = True

self.num_attempts += 1
8 changes: 6 additions & 2 deletions azure/durable_functions/models/TaskOrchestrationExecutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,12 @@ def parse_history_event(directive_result):
# retrieve result
new_value = parse_history_event(event)
if task._api_name == "CallEntityAction":
new_value = ResponseMessage.from_dict(new_value)
new_value = json.loads(new_value.result)
event_payload = ResponseMessage.from_dict(new_value)
new_value = json.loads(event_payload.result)

if event_payload.is_exception:
new_value = Exception(new_value)
is_success = False
else:
# generate exception
new_value = Exception(f"{event.Reason} \n {event.Details}")
Expand Down
17 changes: 16 additions & 1 deletion azure/durable_functions/models/actions/NoOpAction.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
from azure.durable_functions.models.actions.Action import Action
from typing import Any, Dict
from typing import Any, Dict, Optional


class NoOpAction(Action):
"""A no-op action, for anonymous tasks only."""

def __init__(self, metadata: Optional[str] = None):
"""Create a NoOpAction object.
This is an internal-only action class used to represent cases when intermediate
tasks are used to implement some API. For example, in -WithRetry APIs, intermediate
timers are created. We create this NoOp action to track those the backing actions
of those tasks, which is necessary because we mimic the DF-internal replay algorithm.
Parameters
----------
metadata : Optional[str]
Used for internal debugging: metadata about the action being represented.
"""
self.metadata = metadata

def action_type(self) -> int:
"""Get the type of action this class represents."""
raise Exception("Attempted to get action type of an anonymous Action")
Expand Down
6 changes: 4 additions & 2 deletions azure/durable_functions/models/entities/ResponseMessage.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class ResponseMessage:
Specifies the response of an entity, as processed by the durable-extension.
"""

def __init__(self, result: str):
def __init__(self, result: str, is_exception: bool = False):
"""Instantiate a ResponseMessage.

Specifies the response of an entity, as processed by the durable-extension.
Expand All @@ -18,6 +18,7 @@ def __init__(self, result: str):
The result provided by the entity
"""
self.result = result
self.is_exception = is_exception
# TODO: JS has an additional exceptionType field, but does not use it

@classmethod
Expand All @@ -34,5 +35,6 @@ def from_dict(cls, d: Dict[str, Any]) -> 'ResponseMessage':
ResponseMessage:
The ResponseMessage built from the provided dictionary
"""
result = cls(d["result"])
is_error = "exceptionType" in d.keys()
result = cls(d["result"], is_error)
return result
5 changes: 3 additions & 2 deletions tests/orchestrator/orchestrator_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ def assert_entity_state_equals(expected, result):
assert_attribute_equal(expected, result, "signals")

def assert_results_are_equal(expected: Dict[str, Any], result: Dict[str, Any]) -> bool:
assert_attribute_equal(expected, result, "result")
assert_attribute_equal(expected, result, "isError")
for (payload_expected, payload_result) in zip(expected, result):
assert_attribute_equal(payload_expected, payload_result, "result")
assert_attribute_equal(payload_expected, payload_result, "isError")

def assert_attribute_equal(expected, result, attribute):
if attribute in expected:
Expand Down
37 changes: 36 additions & 1 deletion tests/orchestrator/test_create_timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,20 @@ def generator_function(context):
yield context.create_timer(fire_at)
return "Done!"

def generator_function_timer_can_be_cancelled(context):
time_limit1 = context.current_utc_datetime + timedelta(minutes=5)
timer_task1 = context.create_timer(time_limit1)

time_limit2 = context.current_utc_datetime + timedelta(minutes=10)
timer_task2 = context.create_timer(time_limit2)

winner = yield context.task_any([timer_task1, timer_task2])
if winner == timer_task1:
timer_task2.cancel()
return "Done!"
else:
raise Exception("timer task 1 should complete before timer task 2")

def add_timer_action(state: OrchestratorState, fire_at: datetime):
action = CreateTimerAction(fire_at=fire_at)
state._actions.append([action])
Expand Down Expand Up @@ -64,4 +78,25 @@ def test_timers_comparison_with_relaxed_precision():
#assert_valid_schema(result)
# TODO: getting the following error when validating the schema
# "Additional properties are not allowed ('fireAt', 'isCanceled' were unexpected)">
assert_orchestration_state_equals(expected, result)
assert_orchestration_state_equals(expected, result)

def test_timers_can_be_cancelled():

context_builder = ContextBuilder("test_timers_can_be_cancelled")
fire_at1 = context_builder.current_datetime + timedelta(minutes=5)
fire_at2 = context_builder.current_datetime + timedelta(minutes=10)
add_timer_fired_events(context_builder, 0, str(fire_at1))
add_timer_fired_events(context_builder, 1, str(fire_at2))

result = get_orchestration_state_result(
context_builder, generator_function_timer_can_be_cancelled)

expected_state = base_expected_state(output='Done!')
expected_state._actions.append(
[CreateTimerAction(fire_at=fire_at1), CreateTimerAction(fire_at=fire_at2, is_cancelled=True)])

expected_state._is_done = True
expected = expected_state.to_json()

assert_orchestration_state_equals(expected, result)
assert result["actions"][0][1]["isCanceled"]
Loading