Merge pull request #416 from PrefectHQ/update-ids
Add Task IDs and remove related cruft
jlowin committed Dec 19, 2018
2 parents 05ff882 + dd6e1dd commit 1ca8c99
Showing 14 changed files with 280 additions and 265 deletions.
48 changes: 25 additions & 23 deletions src/prefect/core/flow.py
@@ -143,10 +143,10 @@ def __init__(
) -> None:
self._cache = {} # type: dict

self._id = str(uuid.uuid4())
self.logger = logging.get_logger("Flow")
# set random id
self.id = str(uuid.uuid4())

self.task_info = dict() # type: Dict[Task, dict]
self.logger = logging.get_logger("Flow")

self.name = name or type(self).__name__
self.schedule = schedule
@@ -207,7 +207,10 @@ def copy(self) -> "Flow":
Create and returns a copy of the current Flow.
"""
new = copy.copy(self)
# create a new cache
new._cache = dict()
# create new id
new.id = str(uuid.uuid4())
new.tasks = self.tasks.copy()
new.edges = self.edges.copy()
new.set_reference_tasks(self._reference_tasks)
@@ -273,7 +276,6 @@ def replace(self, old: Task, new: Task, validate: bool = True) -> None:

# update tasks
self.tasks.remove(old)
self.task_info.pop(old)
self.add_task(new)

self._cache.clear()
@@ -310,13 +312,25 @@ def replace(self, old: Task, new: Task, validate: bool = True) -> None:
def id(self) -> str:
return self._id

@id.setter
def id(self, value: str) -> None:
"""
Args:
- value (str): a UUID-formatted string
"""
try:
uuid.UUID(value)
except Exception:
raise ValueError("Badly formatted UUID string: {}".format(value))
self._id = value

@property # type: ignore
@cache
def task_ids(self) -> Dict[str, Task]:
"""
Returns a dictionary of {task_id: Task} pairs.
"""
return {self.task_info[task]["id"]: task for task in self.tasks}
return {task.id: task for task in self.tasks}

# Context Manager ----------------------------------------------------------

@@ -452,11 +466,6 @@ def add_task(self, task: Task) -> Task:

if task not in self.tasks:
self.tasks.add(task)
self.task_info[task] = {
"id": str(uuid.uuid4()),
"type": to_qualified_name(type(task)),
"mapped": False,
}
self._cache.clear()

return task
@@ -526,9 +535,6 @@ def add_edge(
}
inspect.signature(downstream_task.run).bind_partial(**edge_keys)

if mapped:
self.task_info[downstream_task]["mapped"] = True

self._cache.clear()

# check for cycles
@@ -701,9 +707,6 @@ def validate(self) -> None:
if any(t not in self.tasks for t in self.reference_tasks()):
raise ValueError("Some reference tasks are not contained in this flow.")

if self.tasks.difference(self.task_info):
raise ValueError("Some tasks are not in the task_info dict.")

def sorted_tasks(self, root_tasks: Iterable[Task] = None) -> Tuple[Task, ...]:
"""
Get the tasks in this flow in a sorted manner. This allows us to find if any
@@ -1095,16 +1098,15 @@ def generate_local_task_ids(
# Generate an ID for each task by hashing:
# - its flow's name
#
# This "fingerprints" each task in terms of its own characteristics and the parent flow.
# Note that the fingerprint does not include the flow version, meaning task IDs can
# remain stable across versions of the same flow.
# This "fingerprints" each task in terms of its own characteristics
#
# -----------------------------------------------------------

ids = {
t: _hash(json.dumps((t.serialize(), self.name), sort_keys=True))
for t in tasks
}
ids = {}
for t in tasks:
serialized = t.serialize()
del serialized["id"] # remove the ID since it is unique but random
ids[t] = _hash(json.dumps(serialized, sort_keys=True))

if _debug_steps:
debug_steps[1] = ids.copy()
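Aside: a minimal sketch of the fingerprinting idea in generate_local_task_ids above — hash the serialized task with its random per-instance "id" removed, so the result depends only on the task's own characteristics. The _hash helper and the serialized fields shown are stand-ins, not Prefect's exact implementation.

import hashlib
import json


def _hash(value: str) -> str:
    # stand-in for the module's hashing helper
    return hashlib.sha256(value.encode()).hexdigest()


def local_task_id(serialized_task: dict) -> str:
    stable = dict(serialized_task)
    stable.pop("id", None)  # the UUID is unique but random, so exclude it
    return _hash(json.dumps(stable, sort_keys=True))


a = local_task_id({"id": "6d1f...", "name": "extract", "slug": None})
b = local_task_id({"id": "9c2e...", "name": "extract", "slug": None})
assert a == b  # same characteristics, different random UUIDs -> same local ID
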
20 changes: 20 additions & 0 deletions src/prefect/core/task.py
@@ -3,6 +3,7 @@
import collections
import copy
import inspect
import uuid
import warnings
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Set, Tuple
@@ -140,6 +141,7 @@ def __init__(
self.name = name or type(self).__name__
self.slug = slug

self.id = str(uuid.uuid4())
self.logger = logging.get_logger("Task")

# avoid silently iterating over a string
@@ -197,6 +199,22 @@ def __repr__(self) -> str:
def __hash__(self) -> int:
return id(self)

@property
def id(self) -> str:
return self._id

@id.setter
def id(self, value: str) -> None:
"""
Args:
- value (str): a UUID-formatted string
"""
try:
uuid.UUID(value)
except Exception:
raise ValueError("Badly formatted UUID string: {}".format(value))
self._id = value

# Run --------------------------------------------------------------------

def run(self) -> None:
@@ -240,6 +258,8 @@ def copy(self) -> "Task":
)

new = copy.copy(self)
# assign new id
new.id = str(uuid.uuid4())

new.tags = copy.deepcopy(self.tags)
tags = set(prefect.context.get("_tags", set()))
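Aside: both Flow and Task now use the same UUID-validated id property. A standalone sketch of the pattern, under an illustrative class name (not Prefect's):

import uuid


class HasRandomId:
    def __init__(self) -> None:
        self.id = str(uuid.uuid4())  # assignment is routed through the setter below

    @property
    def id(self) -> str:
        return self._id

    @id.setter
    def id(self, value: str) -> None:
        try:
            uuid.UUID(value)
        except Exception:
            raise ValueError("Badly formatted UUID string: {}".format(value))
        self._id = value


obj = HasRandomId()
obj.id = str(uuid.uuid4())  # accepted
# obj.id = "not-a-uuid"     # would raise ValueError
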
16 changes: 10 additions & 6 deletions src/prefect/engine/flow_runner.py
@@ -338,6 +338,10 @@ def get_flow_run_state(
task_contexts = task_contexts or {}
throttle = throttle or {}

# this set keeps track of any tasks that are mapped over - meaning they are the
# downstream task of at least one mapped edge
mapped_tasks = set()

# -- process each task in order

with executor.start():
@@ -357,6 +361,8 @@
# -- process each edge to the task
for edge in self.flow.edges_to(task):
upstream_states[edge] = task_states[edge.upstream_task]
if edge.mapped:
mapped_tasks.add(task)

# if a task is provided as a start_task and its state is also
# provided, we assume that means it requires cached_inputs
@@ -377,11 +383,11 @@
queues.get(tag) for tag in sorted(task.tags) if queues.get(tag)
]

if not self.flow.task_info[task]["mapped"]:
if not task in mapped_tasks:
upstream_mapped = {
e: executor.wait(f) # type: ignore
for e, f in upstream_states.items()
if self.flow.task_info[e.upstream_task]["mapped"]
if e.upstream_task in mapped_tasks
}
upstream_states.update(upstream_mapped)

@@ -392,12 +398,10 @@
inputs=task_inputs,
ignore_trigger=(task in start_tasks),
context=dict(
prefect.context,
task_id=self.flow.task_info[task]["id"],
**task_contexts.get(task, {})
prefect.context, task_id=task.id, **task_contexts.get(task, {})
),
queues=task_queues,
mapped=self.flow.task_info[task]["mapped"],
mapped=task in mapped_tasks,
executor=executor,
)

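Aside: a toy illustration of the mapped_tasks bookkeeping above — a task counts as "mapped over" when it sits downstream of at least one mapped edge. The Edge tuple and task names here are made up for the example; the real FlowRunner builds the set incrementally while iterating tasks.

from collections import namedtuple

Edge = namedtuple("Edge", ["upstream_task", "downstream_task", "mapped"])

edges = [
    Edge("generate", "transform", mapped=True),
    Edge("transform", "summarize", mapped=False),
]

# downstream of at least one mapped edge -> mapped over
mapped_tasks = {e.downstream_task for e in edges if e.mapped}
assert mapped_tasks == {"transform"}
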
18 changes: 2 additions & 16 deletions src/prefect/serialization/flow.py
@@ -51,17 +51,6 @@ class Meta:
)
environment = fields.Nested(EnvironmentSchema, allow_none=True)

@pre_dump
def put_task_ids_in_context(self, flow: "prefect.core.Flow") -> "prefect.core.Flow":
"""
Adds task ids to context so they may be used by nested TaskSchemas and EdgeSchemas.
If the serialized object is not a Flow (like a dict), this step is skipped.
"""
if isinstance(flow, prefect.core.Flow):
self.context["task_ids"] = {t: i["id"] for t, i in flow.task_info.items()}
return flow

@post_load
def create_object(self, data):
"""
@@ -79,9 +68,6 @@ def create_object(self, data):
"""
data["validate"] = False
flow = super().create_object(data)
flow._id = data.get("id", None)

for t in flow.tasks:
flow.task_info[t].update({"id": t._id, "type": t._type})

if "id" in data:
flow.id = data["id"]
return flow
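Aside: a minimal marshmallow sketch of the pattern create_object now follows — build the object in a @post_load hook, then copy the serialized "id" back onto it. The Thing/ThingSchema names are illustrative; the real FlowSchema is a VersionedSchema with many more fields.

from marshmallow import Schema, fields, post_load


class Thing:
    def __init__(self, name):
        self.name = name
        self.id = None


class ThingSchema(Schema):
    id = fields.String(allow_none=True)
    name = fields.String()

    @post_load
    def create_object(self, data, **kwargs):
        thing = Thing(name=data["name"])
        if "id" in data:
            thing.id = data["id"]
        return thing


thing = ThingSchema().load({"id": "abc", "name": "example"})
assert thing.id == "abc"
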
94 changes: 27 additions & 67 deletions src/prefect/serialization/task.py
@@ -1,69 +1,29 @@
import uuid
from collections import OrderedDict

import marshmallow
import prefect
from marshmallow import (
ValidationError,
fields,
pre_dump,
post_dump,
post_load,
pre_dump,
pre_load,
post_dump,
ValidationError,
)

import prefect
from prefect.utilities.serialization import (
UUID,
FunctionReference,
JSONCompatible,
VersionedSchema,
version,
to_qualified_name,
from_qualified_name,
to_qualified_name,
version,
)
from prefect.serialization.schedule import ScheduleSchema
from prefect.utilities.serialization import JSONCompatible


class FunctionReference(fields.Field):
"""
Field that stores a reference to a function as a string and reloads it when
deserialized.
The valid functions must be provided as a dictionary of {qualified_name: function}
"""

def __init__(self, valid_functions, **kwargs):
self.valid_functions = {to_qualified_name(f): f for f in valid_functions}
super().__init__(**kwargs)

def _serialize(self, value, attr, obj, **kwargs):
return to_qualified_name(value)

def _deserialize(self, value, attr, data, **kwargs):
return self.valid_functions.get(value, value)


class TaskMethodsMixin:
def dump_task_id(self, obj):
"""
Helper for serializing task IDs that may have been placed in the context dict
Args:
- obj (Task): the object being serialized
Returns:
- str: the object ID
"""
if isinstance(obj, prefect.core.Task) and "task_ids" in self.context:
return self.context["task_ids"].get(obj, None)

def load_task_id(self, data):
"""
Helper for loading task IDs (required because `id` is a Method field)
Args:
- data (str): the id of the object
Returns:
- str: the object ID
"""
return data

def get_attribute(self, obj, key, default):
"""
By default, Marshmallow attempts to index an object, then get its attributes.
@@ -83,14 +43,16 @@ def create_object(self, data):
deserialized a matching task. In that case, we reload the task from a shared
cache.
"""
task_id = data.get("id", None)
if task_id not in self.context.setdefault("task_cache", {}) or task_id is None:
task_id = data.get("id", str(uuid.uuid4()))

# if the id is not in the task cache, create a task object and add it
if task_id not in self.context.setdefault("task_id_cache", {}):
task = super().create_object(data)
task._id = task_id
task._type = data.get("type", None)
self.context["task_cache"][task_id] = task
task.id = task_id
self.context["task_id_cache"][task_id] = task

return self.context["task_cache"][task_id]
# return the task object from the cache
return self.context["task_id_cache"][task_id]


@version("0.3.3")
@@ -99,7 +61,7 @@ class Meta:
object_class = lambda: prefect.core.Task
object_class_exclude = ["id", "type"]

id = fields.Method("dump_task_id", "load_task_id", allow_none=True)
id = UUID()
type = fields.Function(lambda task: to_qualified_name(type(task)), lambda x: x)
name = fields.String(allow_none=True)
slug = fields.String(allow_none=True)
@@ -118,6 +80,8 @@ class Meta:
prefect.triggers.any_successful,
prefect.triggers.any_failed,
],
# don't reject custom functions, just leave them as strings
reject_invalid=False,
allow_none=True,
)
skip_on_upstream_skip = fields.Boolean(allow_none=True)
@@ -131,6 +95,8 @@ class Meta:
prefect.engine.cache_validators.partial_inputs_only,
prefect.engine.cache_validators.partial_parameters_only,
],
# don't reject custom functions, just leave them as strings
reject_invalid=False,
allow_none=True,
)

@@ -141,16 +107,10 @@ class Meta:
object_class = lambda: prefect.core.task.Parameter
object_class_exclude = ["id", "type"]

id = fields.Method("dump_task_id", "load_task_id", allow_none=True)
id = UUID()
type = fields.Function(lambda task: to_qualified_name(type(task)), lambda x: x)
name = fields.String()
name = fields.String(required=True)
default = JSONCompatible(allow_none=True)
required = fields.Boolean(allow_none=True)
description = fields.String(allow_none=True)
tags = fields.List(fields.String())

@pre_dump
def validate_name(self, data):
if self.get_attribute(data, "name", None) is None:
raise ValidationError("name is required.")
return data
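Aside: a toy illustration of the task_id_cache behaviour above — when the same id appears more than once while deserializing a flow, every reference resolves to the single object created the first time. Greatly simplified; the real schema reconstructs Task instances, not plain dicts.

import uuid

context: dict = {}


def load_task(data: dict) -> dict:
    task_id = data.get("id", str(uuid.uuid4()))
    cache = context.setdefault("task_id_cache", {})
    if task_id not in cache:
        cache[task_id] = {"id": task_id, "name": data.get("name")}
    return cache[task_id]


shared_id = str(uuid.uuid4())
a = load_task({"id": shared_id, "name": "extract"})
b = load_task({"id": shared_id, "name": "extract"})
assert a is b  # both references resolve to one shared task object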
