forked from spotify/luigi
-
Notifications
You must be signed in to change notification settings - Fork 0
/
task_register.py
238 lines (189 loc) · 7.35 KB
/
task_register.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Define the centralized register of all :class:`~luigi.task.Task` classes.
"""
import abc
from collections import OrderedDict
from luigi import six
import logging
logger = logging.getLogger('luigi-interface')
class TaskClassException(Exception):
    """Common base class of the task-class lookup errors defined in this module."""
class TaskClassNotFoundException(TaskClassException):
    """Raised by ``Register.get_task_cls`` when no registered task has the requested name."""
class TaskClassAmbigiousException(TaskClassException):
    """Raised by ``Register.get_task_cls`` when a task name maps to several unrelated classes.

    The misspelling ("Ambigious") is part of the public name and is kept for
    backward compatibility.
    """
class Register(abc.ABCMeta):
    """
    The Metaclass of :py:class:`Task`.

    Acts as a global registry of Tasks with the following properties:

    1. Cache instances of objects so that eg. ``X(1, 2, 3)`` always returns the
       same object.
    2. Keep track of all subclasses of :py:class:`Task` and expose them.
    """
    # Maps (class, tuple(param_values)) -> instance; ``None`` means caching is
    # disabled (see disable_instance_cache).
    __instance_cache = {}
    # Namespace assigned to classes that don't declare ``task_namespace`` themselves.
    _default_namespace = None
    # Every class created through this metaclass, in creation order.
    _reg = []

    #: Placeholder denoting an error. If this value is returned by
    #: :py:meth:`_get_reg` then there is an ambiguous task name (two
    #: :py:class:`Task` have the same name and neither subclasses the other).
    AMBIGUOUS_CLASS = object()

    def __new__(metacls, classname, bases, classdict):
        """
        Custom class creation for namespacing.

        Also register all subclasses.

        Set the task namespace to whatever the currently declared namespace is.
        """
        if "task_namespace" not in classdict:
            classdict["task_namespace"] = metacls._default_namespace

        cls = super(Register, metacls).__new__(metacls, classname, bases, classdict)
        metacls._reg.append(cls)
        return cls

    def __call__(cls, *args, **kwargs):
        """
        Custom class instantiation utilizing instance cache.

        If a Task has already been instantiated with the same parameters,
        the previous instance is returned to reduce number of object instances.
        """
        def instantiate():
            return super(Register, cls).__call__(*args, **kwargs)

        h = cls.__instance_cache

        if h is None:  # disabled
            return instantiate()

        # get_params / get_param_values are expected to be provided by the
        # Task class being instantiated (not defined here).
        params = cls.get_params()
        param_values = cls.get_param_values(params, args, kwargs)

        k = (cls, tuple(param_values))

        try:
            hash(k)
        except TypeError:
            logger.debug("Not all parameter values are hashable so instance isn't coming from the cache")
            return instantiate()  # unhashable types in parameters

        if k not in h:
            h[k] = instantiate()
        return h[k]

    @classmethod
    def clear_instance_cache(cls):
        """
        Clear/Reset the instance cache.
        """
        cls.__instance_cache = {}

    @classmethod
    def disable_instance_cache(cls):
        """
        Disables the instance cache.
        """
        cls.__instance_cache = None

    @property
    def task_family(cls):
        """
        The task family for the given class.

        If ``cls.task_namespace is None`` then it's the name of the class.
        Otherwise, ``<task_namespace>.`` is prefixed to the class name.
        """
        if cls.task_namespace is None:
            return cls.__name__
        else:
            return "%s.%s" % (cls.task_namespace, cls.__name__)

    @classmethod
    def _get_reg(cls):
        """Return all of the registered classes.

        :return: an ``collections.OrderedDict`` of task_family -> class. The
                 value is :py:attr:`AMBIGUOUS_CLASS` when a family name is
                 shared by classes that are not in a subclass chain.
        """
        # We have to do this on-demand in case task names have changed later
        # We return this in a topologically sorted list of inheritance: this is useful in some cases (#822)
        reg = OrderedDict()
        for task_cls in cls._reg:
            name = task_cls.task_family
            previous = reg.get(name)
            if previous is None:
                reg[name] = task_cls
            elif previous is not cls.AMBIGUOUS_CLASS and issubclass(task_cls, previous):
                # One class is a subclass of the other: keep the most-derived
                # one (this fixes some issues with decorator wrappers).
                reg[name] = task_cls
            else:
                # Two different classes share the name, so we can't instantiate
                # them by name. Once ambiguous, a name stays ambiguous; a later
                # class must not silently clear the error.
                reg[name] = cls.AMBIGUOUS_CLASS
        return reg

    @classmethod
    def _set_reg(cls, reg):
        """The writing complement of _get_reg
        """
        cls._reg = [task_cls for task_cls in reg.values() if task_cls is not cls.AMBIGUOUS_CLASS]

    @classmethod
    def task_names(cls):
        """
        List of task names as strings
        """
        return sorted(cls._get_reg().keys())

    @classmethod
    def tasks_str(cls):
        """
        Human-readable register contents dump.
        """
        return ','.join(cls.task_names())

    @classmethod
    def get_task_cls(cls, name):
        """
        Returns an unambiguous class or raises an exception.

        :raises TaskClassNotFoundException: if no task is registered under ``name``.
        :raises TaskClassAmbigiousException: if ``name`` maps to several unrelated classes.
        """
        task_cls = cls._get_reg().get(name)
        if task_cls is None:
            raise TaskClassNotFoundException(cls._missing_task_msg(name))

        if task_cls is cls.AMBIGUOUS_CLASS:
            raise TaskClassAmbigiousException('Task %r is ambiguous' % name)
        return task_cls

    @classmethod
    def get_all_params(cls):
        """
        Compiles and returns all parameters for all :py:class:`Task`.

        :return: a generator of tuples
                 ``(task_name, is_without_section, param_name, param_obj)``
                 where ``is_without_section`` is ``not task_cls.use_cmdline_section``
                 (TODO: we should make this more elegant)
        """
        for task_name, task_cls in cls._get_reg().items():
            if task_cls is cls.AMBIGUOUS_CLASS:
                continue
            for param_name, param_obj in task_cls.get_params():
                yield task_name, (not task_cls.use_cmdline_section), param_name, param_obj

    @staticmethod
    def _editdistance(a, b):
        """ Simple unweighted Levenshtein distance """
        # r0 is the previous DP row: distance from a[:i] to every prefix of b.
        r0 = list(range(len(b) + 1))
        for i in range(len(a)):
            r1 = [i + 1] + [0] * len(b)
            for j in range(len(b)):
                # Compare by value: ``is`` only works for interned/cached chars.
                c = 0 if a[i] == b[j] else 1
                r1[j + 1] = min(r1[j] + 1, r0[j + 1] + 1, r0[j] + c)
            r0 = r1
        # Reading from r0 also covers empty ``a`` (distance == len(b)).
        return r0[len(b)]

    @classmethod
    def _missing_task_msg(cls, task_name):
        """Build the error message for an unknown task, suggesting close names."""
        weighted_tasks = [(Register._editdistance(task_name, task_name_2), task_name_2) for task_name_2 in cls.task_names()]
        ordered_tasks = sorted(weighted_tasks, key=lambda pair: pair[0])
        # Only suggest names within 5 edits that aren't "everything changed".
        candidates = [task for (dist, task) in ordered_tasks if dist <= 5 and dist < len(task)]
        if candidates:
            return "No task %s. Did you mean:\n%s" % (task_name, '\n'.join(candidates))
        else:
            return "No task %s. Candidates are: %s" % (task_name, cls.tasks_str())
def load_task(module, task_name, params_str):
    """
    Dynamically look up a task class by name and build an instance of it.

    :param module: dotted name of a module to import before the lookup, or
        ``None`` to skip importing.
    :param task_name: task family name to resolve via ``Register.get_task_cls``.
    :param params_str: parameter values handed to the class's
        ``from_str_params`` (presumably a mapping of parameter name to string
        value -- confirm against ``Task.from_str_params``).
    :return: whatever ``from_str_params`` returns for the resolved class.
    :raises TaskClassNotFoundException: if no task has that name.
    :raises TaskClassAmbigiousException: if the name is ambiguous.
    """
    if module is not None:
        # Importing runs the module's class definitions; any class created
        # with the Register metaclass is appended to the registry.
        __import__(module)
    return Register.get_task_cls(task_name).from_str_params(params_str)