In [None]:
import enum
import random
from types import ModuleType
from typing import Any, Protocol, Tuple, Union, runtime_checkable

import cupy as cp
import numpy as np
from numpy.typing import ArrayLike
import pandas as pd
import torch
import torch.utils.dlpack


In [None]:
# Number of synthetic samples generated for every preds/target pair below.
NUM_SAMPLES: int = 100

In [None]:
# Seed every random number generator used in this notebook (Python `random`,
# NumPy, CuPy and PyTorch) so the generated data is reproducible.
seed = 42
random.seed(seed)
np.random.seed(seed)
cp.random.seed(seed)
# The trailing `;` hides the `torch.Generator` returned by `manual_seed`,
# which would otherwise be displayed as this cell's output.
torch.manual_seed(seed);

In [None]:
# Typical container types for `preds` (values drawn from [0, 1)) and `target`
# (0/1 labels). The generation order below is fixed: each pair consumes the
# seeded RNG streams in the same sequence on every run.
preds_list = [random.random() for _ in range(NUM_SAMPLES)]
target_list = [random.randint(0, 1) for _ in range(NUM_SAMPLES)]

preds_tuple = tuple(random.random() for _ in range(NUM_SAMPLES))
target_tuple = tuple(random.randint(0, 1) for _ in range(NUM_SAMPLES))

preds_np = np.random.rand(NUM_SAMPLES)
target_np = np.random.randint(0, 2, size=NUM_SAMPLES)

preds_pd = pd.Series(np.random.rand(NUM_SAMPLES))
target_pd = pd.Series(np.random.randint(0, 2, size=NUM_SAMPLES))

preds_cp = cp.random.rand(NUM_SAMPLES)
target_cp = cp.random.randint(0, 2, size=NUM_SAMPLES)

preds_torch = torch.rand(NUM_SAMPLES, device="cuda")
target_torch = torch.randint(low=0, high=2, size=(NUM_SAMPLES,), device="cuda")

# (preds, target) pairs evaluated by every option below; the list, tuple and
# pandas pairs are currently disabled.
eval_test_data = [
    # (preds_list, target_list),
    # (preds_tuple, target_tuple),
    (preds_np, target_np),
    # (preds_pd, target_pd),
    (preds_cp, target_cp),
    (preds_torch, target_torch),
]


## Utility functions

In [None]:
# see: https://dmlc.github.io/dlpack/latest/
# and: https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html

# Opaque stand-in for the PyCapsule object returned by `__dlpack__`.
_PyCapsule = Any


@runtime_checkable
class DLPackTensor(Protocol):
    """Structural type for objects that can be exchanged through DLPack.

    An object satisfies this protocol when it defines both ``__dlpack__`` and
    ``__dlpack_device__``. Because the protocol is ``runtime_checkable``,
    ``isinstance`` only checks that both methods are present; their signatures
    are not validated at runtime.
    """

    def __dlpack__(self, *, stream: Any = None) -> _PyCapsule:
        """Export the data as a DLPack capsule.

        ``stream`` is the keyword-only synchronisation argument of the
        DLPack/Array API exchange protocol.
        """
        ...

    def __dlpack_device__(self) -> Tuple[int, int]:
        """Return ``(device_type, device_id)``; ``device_type`` is a DLDeviceType value."""
        ...


In [None]:
def convert_to_torch_tensor(
    arg: Union[ArrayLike, DLPackTensor], copy: bool = False
) -> torch.Tensor:
    """Convert an array-like or DLPack-compatible object to a ``torch.Tensor``.

    Intended inputs:
    - object types: list, tuple, numpy array, DLPack-compatible arrays/tensors
    - data types: bool, signed/unsigned integers, float and complex dtypes

    Parameters
    ----------
    arg : ArrayLike or DLPackTensor
        Object to convert. A ``torch.Tensor`` is returned unchanged unless
        ``copy`` is True. Objects satisfying ``DLPackTensor`` are imported with
        ``torch.utils.dlpack.from_dlpack``; anything else goes through
        ``torch.tensor``.
    copy : bool, default False
        If True, the result is a fresh tensor rather than the input itself or
        a DLPack import of it.

    Returns
    -------
    torch.Tensor
        The converted tensor.

    Raises
    ------
    Exception
        Errors raised by ``torch.utils.dlpack.from_dlpack`` or ``torch.tensor``
        for unsupported inputs propagate unchanged.
    """
    if isinstance(arg, torch.Tensor):
        # Honour `copy` for tensor inputs as well, so copy=True never hands
        # back the caller's own tensor object.
        return arg.clone() if copy else arg

    if isinstance(arg, DLPackTensor):
        torch_tensor = torch.utils.dlpack.from_dlpack(arg)
        # A DLPack import views the producer's buffer; clone to detach from it.
        return torch_tensor.clone() if copy else torch_tensor

    # `torch.tensor` always copies its input, so an extra clone is unnecessary
    # even when copy=True.
    return torch.tensor(arg)

In [None]:
class DLDeviceType(enum.IntEnum):
    """DLDeviceType enum from the DLPack specification (``dlpack.h``)."""

    kDLCPU = 1  # CPU device
    kDLGPU = 2  # CUDA GPU device (named ``kDLCUDA`` in later DLPack releases)
    kDLCUDAHost = 3  # Pinned CUDA CPU memory by cudaMallocHost
    kDLOpenCL = 4  # OpenCL devices.
    kDLVulkan = 7  # Vulkan buffer for next generation graphics.
    kDLMetal = 8  # Metal for Apple GPU
    kDLVPI = 9  # Verilog simulator buffer
    kDLROCM = 10  # ROCm GPUs for AMD GPUs
    kDLROCMHost = 11  # Pinned ROCm CPU memory allocated by hipMallocHost
    kDLExtDev = 12  # Reserved extension device type
    kDLCUDAManaged = 13  # CUDA managed/unified memory (cudaMallocManaged)
    kDLOneAPI = 14  # Unified shared memory on a oneAPI non-partitioned device
    kDLWebGPU = 15  # GPU support for the WebGPU standard
    kDLHexagon = 16  # Qualcomm Hexagon DSP


# Names of the DLDeviceType members treated as GPU devices (routed to cupy).
# Derived from the enum so a misspelt name cannot silently never match.
SUPPORTED_DLDEVICE_GPUS = [DLDeviceType.kDLGPU.name, DLDeviceType.kDLROCM.name]


def _get_dl_device_type(arr: Any) -> DLDeviceType:
    """Return the DLPack device type reported by a DLPack-compatible object.

    Parameters
    ----------
    arr : Any
        Object expected to implement ``__dlpack_device__``.

    Returns
    -------
    DLDeviceType
        The first element of ``arr.__dlpack_device__()``, as a DLDeviceType.

    Raises
    ------
    TypeError
        If ``arr`` has no ``__dlpack_device__`` attribute.
    """
    try:
        report_device = arr.__dlpack_device__
    except AttributeError:
        raise TypeError(
            "Expected `arr` to have a `__dlpack_device__` attribute. "
            "Got {}.".format(type(arr))
        ) from None

    device_info = report_device()
    return DLDeviceType(device_info[0])


def get_array_module(*args: Any) -> ModuleType:
    """Choose the array module based on the input arguments.

    If at least one argument reports (via DLPack) a device listed in
    ``SUPPORTED_DLDEVICE_GPUS`` (CUDA or ROCm), ``cupy`` is returned. Otherwise
    the choice is delegated to ``cupy.get_array_module``.

    Parameters
    ----------
    *args : Any
        Objects for which to determine the array module.

    Returns
    -------
    ModuleType
        ``cupy`` if any object is on a supported GPU, otherwise the module
        selected by ``cupy.get_array_module(*args)`` (``numpy`` for host data).
    """
    for arg in args:
        # Only objects implementing the full DLPack protocol can report their
        # device; an object exposing `__dlpack__` alone is left to cupy's own
        # check instead of making `_get_dl_device_type` raise TypeError.
        if not (hasattr(arg, "__dlpack__") and hasattr(arg, "__dlpack_device__")):
            continue
        if _get_dl_device_type(arg).name in SUPPORTED_DLDEVICE_GPUS:
            return cp

    return cp.get_array_module(*args)


def to_xpy(*arrays: Any) -> Tuple[Any, ...]:
    """Convert array/tensor-like objects to either cupy or numpy arrays.

    If multiple objects are passed, they must all be on the same device.
    The target module is chosen with ``get_array_module``. Objects implementing
    the DLPack protocol are imported with the module's ``from_dlpack``; all
    other objects go through ``xp.asarray``.

    Parameters
    ----------
    *arrays : Any
        Objects to convert to cupy or numpy arrays.

    Returns
    -------
    out : tuple
        One converted array per input, in input order, all from the same
        module (cupy or numpy).

    Raises
    ------
    ValueError
        If a DLPack object's device (CPU vs GPU) does not match the chosen
        array module.
    TypeError
        If a non-DLPack object cannot be converted with ``xp.asarray``.
    """
    xp = get_array_module(*arrays)
    is_np = xp is np
    out = []

    for array in arrays:
        if hasattr(array, "__dlpack__") and hasattr(array, "__dlpack_device__"):
            device_type = _get_dl_device_type(array)
            on_gpu = device_type.name in SUPPORTED_DLDEVICE_GPUS
            # numpy can only hold host data and cupy only device data, so the
            # object's device must agree with the chosen module.
            if on_gpu == is_np:
                raise ValueError(
                    f"Expected to convert `arrays` to {xp.__name__} arrays. "
                    f"Got array object on {device_type.name} device. Please "
                    "move all arrays to the same device before calling this function."
                )

            # NumPy 1.22 only had the private `np._from_dlpack`; it was made
            # public as `np.from_dlpack` in NumPy 1.23. cupy exposes `from_dlpack`.
            from_dlpack = getattr(xp, "from_dlpack", None) or xp._from_dlpack
            xp_arr = from_dlpack(array)
        else:
            try:
                xp_arr = xp.asarray(array)
            except Exception as err:
                # Catch `Exception` only: a bare `except` would also turn
                # KeyboardInterrupt/SystemExit into a TypeError.
                raise TypeError(
                    "Expected `arrays` to contain objects of one of the following types: "
                    "Sequence | numpy.ndarray | cupy.ndarray | torch.Tensor. "
                    "Got {}.".format(type(array))
                ) from err
        out.append(xp_arr)

    return tuple(out)

### Tests for the utility functions

In [None]:
# Inputs used to exercise `convert_to_torch_tensor`: python scalars, lists and
# tuples, numpy arrays, pandas series, cupy arrays and torch tensors (CPU and
# CUDA), holding int, float, bool and str values. One group per line.
test_data: list[Any] = [
    # python scalars
    1, 1.0, "1", False,
    # python lists
    [1, 2, 3], [1.0, 2.0, 3.0], ["1", "2", "3"], [False, True, False],
    # python tuples
    (1, 2, 3), (1.0, 2.0, 3.0), ("1", "2", "3"), (False, True, False),
    # numpy arrays
    np.array([1, 2, 3]), np.array([1.0, 2.0, 3.0]), np.array(["1", "2", "3"]),
    np.array([[True, False], [False, True]]),
    np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
    # pandas series
    pd.Series([1, 2, 3]), pd.Series([1.0, 2.0, 3.0]),
    pd.Series(["1", "2", "3"]), pd.Series([True, False, True]),
    # cupy arrays; string arrays are unsupported by cupy, and bool arrays are
    # left out (original note: "dlpack only supports numeric types")
    cp.array([1, 2, 3]), cp.array([1.0, 2.0, 3.0]),
    cp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
    # torch tensors - cpu
    torch.tensor([1, 2, 3]), torch.tensor([[True, False], [False, True]]),
    torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
    # torch tensors - gpu
    torch.tensor([1, 2, 3], device="cuda"),
    torch.tensor([[True, False], [False, True]], device="cuda"),
    torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], device="cuda"),
]


In [None]:
# Report how `convert_to_torch_tensor` handles every sample; failures are
# printed instead of raised so that all cases are shown.
for data in test_data:
    print(f"Input: {data} ({type(data)})")
    try:
        outcome = f"Output: {convert_to_torch_tensor(data)}"
    except Exception as err:
        outcome = str(err)
    print(outcome)
    print()


## Option #1: Array API standard

In [None]:
import array_api_compat as apc


def stat_scores_array_api(preds, target):
    """Compute binary stat scores using only Array API standard operations.

    Parameters
    ----------
    preds : array
        Predictions. float32/float64 inputs are treated as probabilities, or as
        logits when any value lies outside ``[0, 1]``, and are thresholded at
        0.5. Inputs of other dtypes are compared with ``target`` directly.
    target : array
        Ground-truth labels, compared against 0 and 1.

    Returns
    -------
    array
        1-D array ``[tp, fp, tn, fn, tp + fn]`` from the inputs' namespace.
    """
    xp = apc.get_namespace(preds, target)

    # NOTE: float16 and bfloat16 are not supported
    if preds.dtype in (xp.float32, xp.float64):
        # `&` rather than `*`: the standard defines `*` for numeric dtypes only.
        if not xp.all((preds >= 0) & (preds <= 1)):
            # preds is logits, convert with sigmoid. `sigmoid` is not part of
            # the Array API (numpy/cupy arrays have no `.sigmoid()`), so it is
            # expressed with `xp.exp`.
            preds = 1 / (1 + xp.exp(-preds))
        preds = preds > 0.5

    # Keep the first axis, flatten the rest (standard `xp.reshape`, not the
    # non-standard `.reshape` method).
    preds = xp.reshape(preds, (preds.shape[0], -1))
    target = xp.reshape(target, (target.shape[0], -1))

    # Full reductions already yield 0-d arrays, so no `.squeeze()` is needed.
    tp = xp.sum((target == preds) & (target == 1))
    fn = xp.sum((target != preds) & (target == 1))
    fp = xp.sum((target != preds) & (target == 0))
    tn = xp.sum((target == preds) & (target == 0))

    return xp.stack([tp, fp, tn, fn, tp + fn], axis=0)

In [None]:
# Option #1 scores for every (preds, target) pair.
array_api_results = [
    stat_scores_array_api(pred, target) for pred, target in eval_test_data
]

print(array_api_results)


## Option #2: DLPack + numpy/cupy

In [None]:
def stat_scores_xp(preds, target):
    """Compute binary stat scores with numpy or cupy, selected from the inputs.

    Both inputs are converted with ``to_xpy`` first, so they must live on the
    same device; the matching module (numpy or cupy) is then used.

    Parameters
    ----------
    preds : array-like
        Predictions. float32/float64 inputs are treated as probabilities, or as
        logits when any value lies outside ``[0, 1]``, and are thresholded at
        0.5. Inputs of other dtypes are compared with ``target`` directly.
    target : array-like
        Ground-truth labels, compared against 0 and 1.

    Returns
    -------
    numpy.ndarray or cupy.ndarray
        1-D array ``[tp, fp, tn, fn, tp + fn]``.
    """
    preds, target = to_xpy(preds, target)
    xp = get_array_module(preds, target)

    # NOTE: float16 and bfloat16 are not supported
    if preds.dtype in (xp.float32, xp.float64):
        if not xp.all((preds >= 0) & (preds <= 1)):
            # preds is logits, convert with sigmoid. numpy/cupy arrays have no
            # `.sigmoid()` method, so compute it with `xp.exp`.
            preds = 1 / (1 + xp.exp(-preds))
        preds = preds > 0.5

    # Keep the first axis, flatten the rest.
    preds = preds.reshape(preds.shape[0], -1)
    target = target.reshape(target.shape[0], -1)

    # Full reductions are already scalar, so no `.squeeze()` is needed.
    tp = ((target == preds) & (target == 1)).sum()
    fn = ((target != preds) & (target == 1)).sum()
    fp = ((target != preds) & (target == 0)).sum()
    tn = ((target == preds) & (target == 0)).sum()

    return xp.stack([tp, fp, tn, fn, tp + fn], axis=0)


In [None]:
# Option #2 scores; each entry of `eval_test_data` is a (preds, target) pair.
xp_results = [stat_scores_xp(*pair) for pair in eval_test_data]

print(xp_results)


## Option #3: DLPack + Torchmetrics

In [None]:
# compute metrics for different data types
# check that the inputs are not mutated
from torchmetrics.functional.classification import stat_scores
from copy import deepcopy

tm_results = []
for pred, target in eval_test_data:
    # Snapshot the inputs first; tensors rebuilt from these snapshots are
    # compared against the converted inputs after the metric has run.
    pred_copy = deepcopy(pred)
    target_copy = deepcopy(target)

    preds_tensor = convert_to_torch_tensor(pred)
    target_tensor = convert_to_torch_tensor(target)

    result = stat_scores(preds_tensor, target_tensor, task="binary")
    # `to_xpy` always returns a tuple; unpack its single element so that
    # `tm_results` holds arrays, like `array_api_results` and `xp_results`.
    (result_xp,) = to_xpy(result)
    tm_results.append(result_xp)

    # check that the values of preds and target are not mutated
    assert torch.equal(
        preds_tensor, convert_to_torch_tensor(pred_copy)
    ), "preds were mutated"
    assert torch.equal(
        target_tensor, convert_to_torch_tensor(target_copy)
    ), "target was mutated"

print(tm_results)
