Skip to content

Commit

Permalink
fix: many dispatches on same function (#144)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jul 31, 2024
1 parent 7ba6dcb commit 02b1edf
Showing 1 changed file with 21 additions and 32 deletions.
53 changes: 21 additions & 32 deletions src/unxt/_quantity/register_dispatches.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,41 @@
# pylint: disable=import-error

from collections.abc import Callable
from typing import Any, TypeVar

import jax
import jax.core
import jax.experimental.array_api as jax_xp
import numpy as np
from jax import Device
from plum import Dispatcher, Function
from plum.parametric import type_unparametrized as type_np

from quaxed._types import DType
from quaxed.array_api._dispatch import dispatcher as dispatcher_
from quaxed.numpy._dispatch import dispatcher as np_dispatcher_
from quaxed.array_api._dispatch import dispatcher as xp_dispatcher
from quaxed.numpy._dispatch import dispatcher as np_dispatcher

from .base import AbstractQuantity
from .core import Quantity

T = TypeVar("T")


def dispatcher(f: T) -> T: # TODO: figure out mypy stub issue.
"""Dispatcher that makes mypy happy."""
return dispatcher_(f)
def chain_dispatchers(*dispatchers: Dispatcher) -> Callable[[Any], Function]:
"""Apply many dispatchers to a function."""

def decorator(method: Any) -> Function:
for dispatcher in dispatchers:
f = dispatcher(method)
return f

def np_dispatcher(f: T) -> T: # TODO: figure out mypy stub issue.
"""Dispatcher that makes mypy happy."""
return np_dispatcher_(f)
return decorator


# -----------------------------------------------


@np_dispatcher
@dispatcher
@chain_dispatchers(np_dispatcher, xp_dispatcher)
def arange(
start: Quantity,
stop: Quantity | None = None,
Expand All @@ -58,8 +60,7 @@ def arange(
# -----------------------------------------------


@np_dispatcher
@dispatcher
@chain_dispatchers(np_dispatcher, xp_dispatcher)
def empty_like(
x: AbstractQuantity, /, *, dtype: Any = None, device: Any = None
) -> AbstractQuantity:
Expand All @@ -70,8 +71,7 @@ def empty_like(
# -----------------------------------------------


@np_dispatcher
@dispatcher
@chain_dispatchers(np_dispatcher, xp_dispatcher)
def full_like(
x: AbstractQuantity,
/,
Expand All @@ -83,7 +83,8 @@ def full_like(
return full_like(x, fill_value, dtype=dtype, device=device)


def full_like( # type: ignore[no-redef]
@chain_dispatchers(np_dispatcher, xp_dispatcher) # type: ignore[no-redef]
def full_like(
x: AbstractQuantity,
fill_value: AbstractQuantity,
/,
Expand All @@ -97,12 +98,8 @@ def full_like( # type: ignore[no-redef]
)


# TODO: fix when https://github.com/beartype/plum/pull/186
np_dispatcher(full_like)
dispatcher(full_like)


def full_like( # type: ignore[no-redef]
@chain_dispatchers(np_dispatcher, xp_dispatcher) # type: ignore[no-redef]
def full_like(
x: AbstractQuantity,
fill_value: bool | int | float | complex,
/,
Expand All @@ -115,16 +112,10 @@ def full_like( # type: ignore[no-redef]
)


# TODO: fix when https://github.com/beartype/plum/pull/186
np_dispatcher(full_like)
dispatcher(full_like)


# -----------------------------------------------


@np_dispatcher
@dispatcher
@chain_dispatchers(np_dispatcher, xp_dispatcher)
def linspace(
start: Quantity,
stop: Quantity,
Expand All @@ -149,17 +140,15 @@ def linspace(
)


@np_dispatcher
@dispatcher
@chain_dispatchers(np_dispatcher, xp_dispatcher)
def ones_like(
x: AbstractQuantity, /, *, dtype: Any = None, device: Any = None
) -> AbstractQuantity:
out = type_np(x)(jax_xp.ones_like(x.value, dtype=dtype), unit=x.unit)
return jax.device_put(out, device=device)


@np_dispatcher
@dispatcher
@chain_dispatchers(np_dispatcher, xp_dispatcher)
def zeros_like(
x: AbstractQuantity, /, *, dtype: Any = None, device: Any = None
) -> AbstractQuantity:
Expand Down

0 comments on commit 02b1edf

Please sign in to comment.