In [50]:
import sys; sys.path.append("..")
import os
import dagpipe
import inspect


def foo(a,b,c):
    ...
    
    
def default_args(a, b, *args):
    ...
    
def default_kwargs(**kwargs):
    ...
    

_sig = inspect.signature(foo)
_sig.bind(1,2,3) == _sig.bind(a=1,b=2,c=3)

True

In [37]:
def final_boss(a, b, *args, d=5, **kwargs_):
    
    print(a,b,args,d,kwargs_)

_sig = inspect.signature(final_boss)

In [17]:
bound = _sig.bind(1,2,3, d=10, e=12)
final_boss(*bound.args, **bound.kwargs)


1 2 (3,) 10 {'e': 12}


In [20]:
for_update = _sig.bind_partial(dict(a=10, e=23))
for_update.args

{'a': {'a': 10, 'e': 23}}

In [38]:
from typing import Callable


    
class FuncParameters:
    def __init__(self, func: Callable, *args, **kwargs) -> None:
        self.sig = inspect.signature(func)
        self._parameters = self.sig.bind(*args, **kwargs).arguments
        
    def to_dict(self):
        return self._parameters
        
    def update(self, *args, **kwargs):
        new_arguments = self.sig.bind_partial(*args, **kwargs).arguments
        for arg_name, value in new_arguments.items():
            self._parameters[arg_name] = value
            
    @property
    def args(self):
        args = []
        for param_name, param in self._sig.parameters.items():
            if param.kind in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.KEYWORD_ONLY):
                break

            if arg := self._parameters.get(param_name, None):
                if param.kind == inspect.Parameter.VAR_POSITIONAL:
                    args.extend(arg)
                else:
                    args.append(arg)
                    
        return tuple(args)

    @property
    def kwargs(self):
        kwargs = {}
        kwargs_started = False
        for param_name, param in self._sig.parameters.items():
            if not kwargs_started:
                if param.kind in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.KEYWORD_ONLY):
                    kwargs_started = True
                else:
                    if param_name not in self._parameters:
                        kwargs_started = True
                        continue

            if not kwargs_started:
                continue
            
            if arg := self._parameters.get(param_name, None):
                if param.kind == inspect.Parameter.VAR_KEYWORD:
                    kwargs.update(arg)
                else:
                    kwargs[param_name] = arg

        return kwargs

        
        
params = FuncParameters(final_boss, 1, b=3, asd=5)
print(params.to_dict())
params.update(fifarafa=10)
params.to_dict()

{'a': 1, 'b': 3, 'kwargs_': {'asd': 5}}


{'a': 1, 'b': 3, 'kwargs_': {'fifarafa': 10}}

In [84]:
_sig = inspect.signature(final_boss)
for name, param in _sig.parameters.items():
    if param.kind == param.VAR_KEYWORD:
        print("o.o", name)

o.o kwargs_


In [79]:
import inspect

def get_kwargs_parameter_name(func):
    sig = inspect.signature(func)
    for name, param in sig.parameters.items():
        if param.kind == param.VAR_KEYWORD:
            return name
    return None

# Example functions
def example_function(a, b=10, *args, d=20, **kwargs):
    pass

def set_options(**options):
    pass

# Get the name of the keyword arguments parameter
kwargs_param_name_1 = get_kwargs_parameter_name(example_function)
kwargs_param_name_2 = get_kwargs_parameter_name(set_options)
kwargs_param_name_3 = get_kwargs_parameter_name(final_boss)

print(kwargs_param_name_1)  # Output: kwargs
print(kwargs_param_name_2)  # Output: options
print(kwargs_param_name_3)  # Output: options

kwargs
options
kwargs_


In [48]:
param.VAR_KEYWORD

inspect._ParameterKind.VAR_POSITIONAL
inspect.Parameter.VAR_POSITIONAL



<_ParameterKind.VAR_POSITIONAL: 2>

In [111]:
from dagpipe.typing import TaskType


class TaskParams:
    def __init__(self, func: Callable, *init_args, **init_kwargs) -> None:
        self._sig = inspect.signature(func)
        self._parameters = self._sig.bind(*init_args, **init_kwargs).arguments
        self._task_params_names = self._filter_tasks_from_parameters()
        
        self._varargs_name = self._find_param_name(inspect.Parameter.VAR_POSITIONAL)
        self._varkwargs_name = self._find_param_name(inspect.Parameter.VAR_KEYWORD)
        
        self._all_varargs_are_tasks = self._are_all_varargs_tasks()
        self._tasks_in_varkwargs = self._get_tasks_from_varkwargs()
        
    @property
    def evaluated_args(self) -> tuple:
        return tuple(a.evaluated_result if isinstance(a, TaskType) else a for a in self.args)
    
    @property
    def evaluated_kwargs(self) -> dict:
        return {k: v.evaluated_result if isinstance(v, TaskType) else v for k, v in self.kwargs.items()}
        
    @property
    def args(self) -> tuple:
        args = []
        for param_name, param in self._sig.parameters.items():
            if param.kind in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.KEYWORD_ONLY):
                break

            if arg := self._parameters.get(param_name, None):
                if param.kind == inspect.Parameter.VAR_POSITIONAL:
                    args.extend(arg)
                else:
                    args.append(arg)
                    
        return tuple(args)

    @property
    def kwargs(self) -> dict:
        kwargs = {}
        kwargs_started = False
        for param_name, param in self._sig.parameters.items():
            if not kwargs_started:
                if param.kind in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.KEYWORD_ONLY):
                    kwargs_started = True
                else:
                    if param_name not in self._parameters:
                        kwargs_started = True
                        continue

            if not kwargs_started:
                continue
            
            if arg := self._parameters.get(param_name, None):
                if param.kind == inspect.Parameter.VAR_KEYWORD:
                    kwargs.update(arg)
                else:
                    kwargs[param_name] = arg

        return kwargs

        
    def update(self, *args, **kwargs):
        new_arguments = self._sig.bind_partial(*args, **kwargs).arguments
        self.__assert_tasks_are_not_overwritten(new_arguments)
        new_arguments = self.__update_varkwargs_with_varkwargs_tasks(new_arguments)
        for arg_name, value in new_arguments.items():
            self._parameters[arg_name] = value

    def __assert_tasks_are_not_overwritten(self, new_arguments):
        self.__assert_tasks_are_not_overwritten_in_arguments(new_arguments)
        self.__assert_tasks_are_not_overwritten_in_varargs(new_arguments)
        self.__assert_tasks_are_not_overwritten_in_varkwargs(new_arguments)

    def __assert_tasks_are_not_overwritten_in_arguments(self, new_arguments: dict):
        for name in self._task_params_names:
            if name in new_arguments:
                raise TypeError(f"Tried to overwrite {self._parameters[name]}"
                                f" with {new_arguments[name]}"
                                f" in parameter {name}."
                                " Task overwriting is not allowed.")   
    
    def __assert_tasks_are_not_overwritten_in_varkwargs(self, new_arguments: dict):
        if self._varkwargs_name:
            if new_varkwargs := new_arguments.get(self._varkwargs_name, None):
                for name in self._tasks_in_varkwargs:
                    if name in new_varkwargs:
                        raise TypeError(f"Tried to overwrite {self._parameters[name]}"
                                        f" with {new_arguments[name]}"
                                        f" in parameter {name}."
                                        " Task overwriting is not allowed.")        

    def __assert_tasks_are_not_overwritten_in_varargs(self, new_arguments: dict):
        if self._all_varargs_are_tasks:
            if new_arguments.get(self._varargs_name, None):
                raise TypeError("Overwriting varargs, is not allowed"
                                    " when varargs are tasks.")
   
    def __update_varkwargs_with_varkwargs_tasks(self, new_arguments: dict):
        if self._varkwargs_name in new_arguments:
            new_arguments[self._varkwargs_name].update(self._tasks_in_varkwargs)
        return new_arguments
                            
    def _filter_tasks_from_parameters(self):
        return [name for name, param in self._parameters.items() if isinstance(param, TaskType)]
        
    def _are_all_varargs_tasks(self) -> bool:
        if varargs := self._parameters.get(self._varargs_name, None):
            if any(isinstance(arg, TaskType) for arg in varargs):
                self.__assert_all_varargs_are_tasks(varargs)
                return True
        return False
    
    def __assert_all_varargs_are_tasks(self, varargs):
        if not all(isinstance(arg, TaskType) for arg in varargs):
            raise ValueError("Either all or none varargs needs to be a TaskType.")
    
    def _find_param_name(self, kind):
        for name, param in self._sig.parameters.items():
            if param.kind == kind:
                return name
        return None
    
    def _get_tasks_from_varkwargs(self) -> dict:
        if varkwargs := self._parameters.get(self._varkwargs_name, None):
                return {k: v for k, v in varkwargs.items() if isinstance(v, TaskType)}
        return {}
        
    def to_dict(self):
        return self._parameters

## Tests

In [112]:
@dagpipe.task()
def foo(x):
    return x

t = foo(1)
t.run("evaluated")
t.evaluated_result

'evaluated'

In [113]:
(
    TaskParams(final_boss, 1, 2, 3, t=t)._task_params_names,
    TaskParams(final_boss, 1, 2, 3, t=t)._all_varargs_are_tasks,
    TaskParams(final_boss, 1, 2, 3, t=t)._tasks_in_varkwargs,    
)

([], False, {'t': Task<foo>})

In [114]:
try:
    (
        TaskParams(final_boss, 1, 2, 3, t)._task_params_names,
        TaskParams(final_boss, 1, 2, 3, t)._all_varargs_are_tasks,
        TaskParams(final_boss, 1, 2, 3, t)._tasks_in_varkwargs,    
    )
except ValueError as e:
    print("Error occurs")
    print(e)

Error occurs
Either all or none varargs needs to be a TaskType.


In [115]:
(
    TaskParams(final_boss, t, 2, 3)._task_params_names,
    TaskParams(final_boss, t, 2, 3)._all_varargs_are_tasks,
    TaskParams(final_boss, t, 2, 3)._tasks_in_varkwargs,    
)

(['a'], False, {})

In [116]:
(
    TaskParams(final_boss, t, t, t)._task_params_names,
    TaskParams(final_boss, t, t, t)._all_varargs_are_tasks,
    TaskParams(final_boss, t, t, t)._tasks_in_varkwargs,    
)

(['a', 'b'], True, {})

In [120]:
(
    TaskParams(final_boss, 1, t, t, t, t=t, not_task=123).args,
    TaskParams(final_boss, 1, t, t, t, t=t, not_task=123).kwargs,
    TaskParams(final_boss, 1, t, t, t, t=t, not_task=123).evaluated_args,
    TaskParams(final_boss, 1, t, t, t, t=t, not_task=123).evaluated_kwargs,
)

((1, Task<foo>, Task<foo>, Task<foo>),
 {'t': Task<foo>, 'not_task': 123},
 (1, 'evaluated', 'evaluated', 'evaluated'),
 {'t': 'evaluated', 'not_task': 123})