Skip to content

Commit

Permalink
pyfunc component running without pickle
Browse files Browse the repository at this point in the history
  • Loading branch information
calgray committed Dec 6, 2021
1 parent 1ee2e0e commit d5fc1e6
Showing 1 changed file with 75 additions and 26 deletions.
101 changes: 75 additions & 26 deletions daliuge-engine/dlg/apps/pyfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,30 @@
#
"""Module implementing the PyFuncApp class"""

import ast
import base64
import collections
import importlib
import inspect
import logging
import pickle

from typing import Callable
import dill

from .. import droputils, utils
from ..drop import BarrierAppDROP
from ..exceptions import InvalidDropException

from dlg import droputils, utils
from dlg.drop import BarrierAppDROP
from dlg.exceptions import InvalidDropException
from dlg.meta import (
dlg_bool_param,
dlg_string_param,
dlg_float_param,
dlg_dict_param,
dlg_component,
dlg_batch_input,
dlg_batch_output,
dlg_streaming_input,
)

logger = logging.getLogger(__name__)

Expand All @@ -57,7 +68,7 @@ def serialize_func(f):
a = inspect.getfullargspec(f)
if a.defaults:
fdefaults = dict(
zip(a.args[-len(a.defaults) :], [serialize_data(d) for d in a.defaults])
zip(a.args[-len(a.defaults):], [serialize_data(d) for d in a.defaults])
)
logger.debug("Defaults for function %r: %r", f, fdefaults)
return fser, fdefaults
Expand Down Expand Up @@ -101,11 +112,14 @@ def import_using_code(code):
# @param[in] param/func_name Function Name//String/readwrite/
#     \~English Python function name
# @param[in] param/func_code Function Code//String/readwrite/
#     \~English Python function code, e.g. 'def function_name(args): pass'
#     \~English Python function code, e.g. 'def function_name(args): return args'
# @param[in] param/pickle Pickle//bool/readwrite/
# \~English Whether the python arguments are pickled.
# @param[in] param/func_defaults Function Defaults//String/readwrite/
# \~English Dictionary of keyword arg names to default values
# \~English Mapping from argname to default value. Should match only the last part
# of the argnames list
# @param[in] param/func_arg_mapping Function Arguments Mapping//String/readwrite/
# \~English Dictionary of keyword arg names to input drop uid
# \~English Mapping between argument name and input drop uids
# @par EAGLE_END
class PyFuncApp(BarrierAppDROP):
"""
Expand All @@ -125,55 +139,87 @@ class PyFuncApp(BarrierAppDROP):
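Both inputs and outputs are serialized using the pickle protocol when the
``pickle`` parameter is true; otherwise inputs are parsed as UTF-8 encoded
Python literals and outputs are written as their UTF-8 encoded ``repr``.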
"""

component_meta = dlg_component(
"PyFuncApp",
"Py Func App.",
[dlg_batch_input("binary/*", [])],
[dlg_batch_output("binary/*", [])],
[dlg_streaming_input("binary/*")],
)

func_name = dlg_string_param("func_name", None)

# fcode = dlg_bytes_param("func_code", None) # bytes or base64 string

pickle = dlg_bool_param("pickle", True)

func_arg_mapping = dlg_dict_param("func_arg_mapping", {})

func_defaults = dlg_dict_param("func_defaults", {})

f: Callable
fdefaults: dict

def initialize(self, **kwargs):
BarrierAppDROP.initialize(self, **kwargs)

self.fname = fname = self._getArg(kwargs, "func_name", None)
fcode = self._getArg(kwargs, "func_code", None)
if not fname and not fcode:
self.fcode = self._getArg(kwargs, "func_code", None)
if not self.func_name and not self.fcode:
raise InvalidDropException(
self, "No function specified (either via name or code)"
)

if not fcode:
self.f = import_using_name(self, fname)
# Lookup function or import bytecode as a function
if not self.fcode:
self.f = import_using_name(self, self.func_name)
else:
if not isinstance(fcode, bytes):
fcode = base64.b64decode(fcode.encode("utf8"))
self.f = import_using_code(fcode)
if not isinstance(self.fcode, bytes):
self.fcode = base64.b64decode(self.fcode.encode("utf8"))
self.f = import_using_code(self.fcode)

# Mapping from argname to default value. Should match only the last part
# of the argnames list
fdefaults = self._getArg(kwargs, "func_defaults", {}) or {}
self.fdefaults = {name: deserialize_data(d) for name, d in fdefaults.items()}
logger.debug("Default values for function %s: %r", self.fname, self.fdefaults)
if isinstance(self.func_defaults, str):
self.func_defaults = ast.literal_eval(self.func_defaults)

if self.pickle:
self.fdefaults = {name: deserialize_data(d) for name, d in self.func_defaults.items()}
else:
self.fdefaults = self.func_defaults

print(f"Default values for function {self.func_name}: {self.fdefaults}")

# Mapping between argument name and input drop uids
self.func_arg_mapping = self._getArg(kwargs, "func_arg_mapping", {})
logger.debug("Input mapping: %r", self.func_arg_mapping)
print(f"Input mapping: {self.func_arg_mapping}")

def run(self):

# Inputs are un-pickled and treated as the arguments of the function
# Their order must be preserved, so we use an OrderedDict
all_contents = lambda x: pickle.loads(droputils.allDropContents(x))
if self.pickle:
all_contents = lambda x: pickle.loads(droputils.allDropContents(x))
else:
all_contents = lambda x: ast.literal_eval(droputils.allDropContents(x).decode('utf-8'))

inputs = collections.OrderedDict()
for uid, i in self._inputs.items():
inputs[uid] = all_contents(i)
for uid, drop in self._inputs.items():
inputs[uid] = all_contents(drop)

# Keyword arguments are made up by the default values plus the inputs
# that match one of the keyword argument names
argnames = inspect.getfullargspec(self.f).args

kwargs = {
name: inputs.pop(uid)
for name, uid in self.func_arg_mapping.items()
if name in self.fdefaults or name not in argnames
}
kwargs.merge(self.fdefaults)

# The rest of the inputs are the positional arguments
args = list(inputs.values())

logger.debug("Running %s with args=%r, kwargs=%r", self.fname, args, kwargs)
print(f"Running {self.func_name} with args={args}, kwargs={kwargs}")
result = self.f(*args, **kwargs)

# Depending on how many outputs we have we treat our result
Expand All @@ -183,4 +229,7 @@ def run(self):
if len(outputs) == 1:
result = [result]
for r, o in zip(result, outputs):
o.write(pickle.dumps(r)) # @UndefinedVariable
if self.pickle:
o.write(pickle.dumps(r)) # @UndefinedVariable
else:
o.write(repr(r).encode('utf-8'))

0 comments on commit d5fc1e6

Please sign in to comment.