In [None]:
#| default_exp utils

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

In [None]:
#| export
import inspect
from functools import wraps

In [None]:
#| export
def ensure_dtypes(*names):
    """Decorator factory that type-checks selected arguments at call time.

    Each entry of `names` must be a parameter of the decorated function that
    carries an annotation usable as the second argument of `isinstance`.
    On every call, each listed argument that was explicitly passed
    (positionally or by keyword) is checked against its annotation; arguments
    left at their default value are not checked.

    Raises
    ------
    ValueError
        When decorating, if a name is not a parameter of the function or the
        parameter has no annotation. When calling, if an explicitly passed
        argument is not an instance of its annotation.
    """
    def decorator(f):
        # Resolve the signature and the expected types once, at decoration
        # time, instead of on every call.
        sig = inspect.signature(f)
        params = sig.parameters
        fname = getattr(f, '__qualname__', f)
        expected = {}
        for name in names:
            if name not in params:
                raise ValueError(f"'{name}' is not a parameter of {fname}.")
            annotation = params[name].annotation
            if annotation is inspect.Parameter.empty:
                raise ValueError(f"'{name}' has no type annotation in {fname} to check against.")
            expected[name] = annotation

        @wraps(f)
        def inner(*args, **kwargs):
            # `bind` maps positional/keyword arguments to parameter names
            # correctly even with *args, keyword-only or positional-only params.
            try:
                passed = sig.bind(*args, **kwargs).arguments
            except TypeError:
                # The call itself is invalid; let `f` raise its own TypeError.
                return f(*args, **kwargs)
            for name, expected_dtype in expected.items():
                # Only explicitly passed arguments are checked (defaults are not
                # included in `bind(...).arguments`).
                if name in passed and not isinstance(passed[name], expected_dtype):
                    raise ValueError(f"'{name}' should have the following type: {expected_dtype}. Got {type(passed[name])}.")
            return f(*args, **kwargs)
        return inner
    return decorator

In [None]:
import pandas as pd
import polars as pl
from fastcore.test import test_eq, test_fail

from utilsforecast.compat import DataFrame

In [None]:
@ensure_dtypes('a')
def f(a: int, b=None):
    "Return `a` unchanged; only `a` is type-checked by the decorator."
    return a

# A correctly typed call passes straight through to the wrapped function.
test_eq(f(1, 2), 1)

# A wrongly typed `a` is rejected whether it is passed by keyword or positionally.
expected_err_msg = "'a' should have the following type: <class 'int'>. Got <class 'str'>."
test_fail(lambda: f(a='a', b=1), contains=expected_err_msg)
test_fail(lambda: f('a', 2), contains=expected_err_msg)

In [None]:
@ensure_dtypes('df')
def g(df: DataFrame):
    "Return `df` unchanged; `df` is checked against the `DataFrame` annotation."
    return df

# pandas and polars frames are both accepted, positionally or by keyword, and
# the decorator hands the very same object to the wrapped function.
pd_df = pd.DataFrame([1])
pl_df = pl.DataFrame([1])
assert g(pd_df) is pd_df
assert g(df=pd_df) is pd_df
assert g(pl_df) is pl_df

# Anything else is rejected, however it is passed.
expected_err_msg = "'df' should have the following type"
test_fail(lambda: g(1), contains=expected_err_msg)
test_fail(lambda: g(df=1), contains=expected_err_msg)