In [27]:
import functools
from typing import Callable

import numpy as np


def convert_ndarray(*input_names, **options):
    """Build a decorator that converts selected arguments to arrays before a call.

    Parameters
    ----------
    *input_names : str
        Parameter names of the decorated function whose values are passed
        through ``fn_asarray(value, dtype=dtype)`` before the call. Each name
        must be a real parameter (positional or keyword-only) of the decorated
        function. The value is converted whether it is supplied positionally
        or by keyword; a parameter left at its default value is not converted.
    **options
        ``dtype``: dtype forwarded to ``fn_asarray`` (default: ``float``).
        ``fn_asarray``: conversion callable invoked as
        ``fn_asarray(x, dtype=dtype)`` (default: ``np.asarray``). Other keys
        are ignored.

    Returns
    -------
    Callable
        A decorator whose wrapper keeps the wrapped function's metadata via
        ``functools.wraps``.

    Raises
    ------
    ValueError
        When decorating, if a name in ``input_names`` is not a parameter of
        the decorated function.
    """
    dtype = options.get("dtype", float)
    fn_asarray: Callable = options.get("fn_asarray", np.asarray)
    # Set gives O(1) membership tests inside the wrapper.
    selected = frozenset(input_names)

    def asarray(x):
        return fn_asarray(x, dtype=dtype)

    def decorator(func):
        code = func.__code__
        # co_varnames lists the positional parameters (positional-only first),
        # then keyword-only ones, then the *args/**kwargs names and ordinary
        # locals. Slice so that local variables never count as parameters.
        nposonly = getattr(code, "co_posonlyargcount", 0)
        nparams = code.co_argcount + code.co_kwonlyargcount
        # All positional parameters, including those that have defaults.
        posargnames = code.co_varnames[:code.co_argcount]
        paramnames = code.co_varnames[:nparams]
        # Positional-only parameters cannot be bound by keyword; a keyword with
        # the same name would land in **kwargs and must not be converted.
        kwargnames = frozenset(code.co_varnames[nposonly:nparams]) & selected
        for name in input_names:
            if name not in paramnames:
                raise ValueError("In decorator convert_ndarray(): Name '{}' "
                                 "doesn't correspond to any argument of the "
                                 "decorated function {}()."
                                 "".format(name, func.__name__))

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            args = list(args)
            # Only the first len(args) parameters were supplied positionally;
            # the others arrive through kwargs or keep their defaults.
            for i, name in enumerate(posargnames[:len(args)]):
                if name in selected:
                    args[i] = asarray(args[i])
            kwargs = {key: asarray(value) if key in kwargnames else value
                      for key, value in kwargs.items()}
            return func(*args, **kwargs)
        return wrapper
    return decorator


@convert_ndarray('cell', 'a')
def f(cell, a, b=np.array([1, ])):
    """Return the shapes of ``cell``, ``a`` and ``b`` as a tuple.

    The decorator converts ``cell`` and ``a`` to float arrays, so plain lists
    are accepted for them. ``b`` is not listed in the decorator, so whatever is
    passed for it must already have a ``.shape`` attribute.
    """
    cell_shape = cell.shape
    a_shape = a.shape
    b_shape = b.shape
    return cell_shape, a_shape, b_shape


f([1], [1, 2, 3], np.random.rand(3))
f([1], [1, 2, 3])
# A list for ``b`` (e.g. ``f([1], [1, 2, 3], [1, 2])``) would raise
# AttributeError, because ``b`` is not one of the converted names.

((1,), (3,), (1,))

In [5]:
def f(*x, y=1):
    """Echo each positional argument on its own line, then report ``y``.

    ``y`` is keyword-only because it follows ``*x``.
    """
    for value in x:
        print(value)
    print("y", y)


f(1, 1, 1, 1)


1
1
1
1
y 1


In [32]:
from jax import numpy as jnp
import jax


def f(cell, a, b=np.array([1, ])):
    """Return the shapes of ``cell``, ``a`` and ``b`` as a tuple."""
    return tuple(arg.shape for arg in (cell, a, b))


# Apply the decorator by hand: ``f`` stays undecorated and ``f1`` is the
# wrapped version, converting through ``jnp.asarray`` instead of NumPy.
f1 = convert_ndarray('cell', 'a', 'b', fn_asarray=jnp.asarray)(f)

f1([1], [1, 2, 3], np.random.rand(3))


((1,), (3,), (3,))

In [33]:
# Trace ``f1`` with JAX to see the operations recorded for these inputs.
jax.make_jaxpr(f1)(
    [1],
    [1, 2, 3],
    np.random.rand(3),
)


{ lambda ; a:i32[] b:i32[] c:i32[] d:i32[] e:f32[3]. let
    f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
    _:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] f
    g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    h:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
    i:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
    j:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] g
    k:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] h
    l:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] i
    _:f32[3] = concatenate[dimension=0] j k l
  in (1, 3, 3) }

In [39]:
def func(cell, a, b=jnp.array([1, ])):
    """Convert ``cell``, ``a`` and ``b`` inline and return their shapes.

    Each argument goes through ``jnp.asarray(..., dtype=float)`` in order
    (``cell``, then ``a``, then ``b``) before its shape is read.
    """
    converted = [jnp.asarray(arg, dtype=float) for arg in (cell, a, b)]
    return tuple(arr.shape for arr in converted)


jax.make_jaxpr(func)([1], [1, 2, 3], np.random.rand(3))


{ lambda ; a:i32[] b:i32[] c:i32[] d:i32[] e:f32[3]. let
    f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
    _:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] f
    g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    h:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
    i:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
    j:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] g
    k:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] h
    l:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] i
    _:f32[3] = concatenate[dimension=0] j k l
  in (1, 3, 3) }

In [6]:
def f(cell):
    """Return ``cell.shape``.

    Works only for objects exposing a ``.shape`` attribute (e.g. NumPy
    arrays); a plain Python list raises ``AttributeError``.
    """
    return cell.shape


a = np.random.rand(3, 3)
b = np.random.rand(5, 3)
# The third call is expected to fail: ``a.tolist()`` is a nested list with no
# ``.shape``. Catch and show the error so the notebook survives
# "Restart & Run All" instead of aborting here.
try:
    f(a), f(b), f(a.tolist())
except AttributeError as err:
    print("AttributeError:", err)


AttributeError: 'list' object has no attribute 'shape'