In [None]:
from typing import Any, Protocol, Union, runtime_checkable

import numpy as np
from numpy.typing import ArrayLike
import pandas as pd
import torch
import torch.utils.dlpack


In [None]:
# see: https://dmlc.github.io/dlpack/latest/
# and: https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html

# define a type that has __dlpack__ and __dlpack_device__ methods
_PyCapsule = Any


@runtime_checkable
class DLPackTensor(Protocol):
    def __dlpack__(self) -> _PyCapsule:
        ...

    def __dlpack_device__(self) -> Any:
        ...


In [None]:
def convert_to_torch_tensor(
    arg: Union[ArrayLike, DLPackTensor], copy: bool = False
) -> torch.Tensor:
    """Convert ``arg`` to a ``torch.Tensor``.

    Intended inputs are torch tensors, DLPack-compatible arrays/tensors (any
    object matching ``DLPackTensor``, e.g. recent numpy arrays), and anything
    ``torch.tensor`` accepts, such as scalars, lists and tuples of bool,
    signed/unsigned integer, float or complex values.

    Args:
        arg: The object to convert.
        copy: If True, the result never shares memory with ``arg``. If False,
            a torch tensor is returned unchanged and DLPack inputs are imported
            zero-copy, so in-place edits to the result are visible via ``arg``.

    Returns:
        The converted tensor.

    Raises:
        TypeError: If ``arg`` cannot be converted (e.g. string data).
    """
    if isinstance(arg, torch.Tensor):
        # Honour ``copy`` here too; previously a tensor was always aliased.
        return arg.clone() if copy else arg

    if isinstance(arg, DLPackTensor):
        try:
            torch_tensor = torch.utils.dlpack.from_dlpack(arg)
        except BufferError:
            # The producer refused to export its data (numpy does this for
            # unsupported dtypes such as strings). Fall through to the copying
            # ``torch.tensor`` path so callers still get a uniform TypeError.
            pass
        else:
            # DLPack import shares memory with ``arg``.
            return torch_tensor.clone() if copy else torch_tensor

    try:
        # ``torch.tensor`` always allocates new storage, so ``copy`` needs no
        # extra clone on this path.
        return torch.tensor(arg)
    except (TypeError, ValueError) as err:
        raise TypeError(
            f"Expected argument {arg} to be a DLPack-compatible tensor or a list,"
            " tuple, or numpy array that can be converted to a torch tensor."
            f" Got {type(arg)} instead."
        ) from err


In [None]:
# Exercise convert_to_torch_tensor on scalars, strings, lists, numpy arrays,
# torch tensors and pandas Series. Inputs that cannot be converted are
# expected to fail; their error message is printed instead of aborting.
test_data: list[Any] = [
    1,
    1.0,
    "1",
    False,
    np.array([1]),
    np.array([1.0, 2.0, 3.0]),
    np.array(["1", "2", "3"]),
    np.array([[True, False], [False, True]]),
    np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
    [1.0],
    ["1"],
    [1.0, 2.0, 3.0],
    ["1", "2", "3"],
    [False, True],
    torch.tensor([1]),
    torch.tensor([1, 2, 3]),
    torch.tensor([True, False]),
    torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
    pd.Series([1, 2, 3]),
    pd.Series([1.0, 2.0, 3.0]),
    pd.Series(["1", "2", "3"]),
    pd.Series([True, False]),
]

for data in test_data:
    print(f"Input: {data} ({type(data)})")
    # BufferError is what a DLPack producer raises when it cannot export its
    # data (e.g. numpy string arrays). It is neither a TypeError nor a
    # ValueError, so it is listed explicitly to keep the loop going.
    try:
        print(f"Output: {convert_to_torch_tensor(data)}")
    except (TypeError, ValueError, BufferError) as err:
        print(err)
    print()


In [None]:
from torchmetrics.classification import (
    Accuracy,
    AUROC,
    F1Score,
    FBetaScore,
    Precision,
    PrecisionRecallCurve,
    Recall,
    ROC,
    Specificity,
    StatScores,
)
from torchmetrics import MetricCollection

# Binary-classification metric suite. Every metric is built with
# task="binary"; FBetaScore additionally gets beta=0.5. The order of the
# pairs below is the order in which the metrics enter the collection.
metrics = MetricCollection(
    [
        metric_cls(task="binary", **extra_kwargs)
        for metric_cls, extra_kwargs in (
            (Accuracy, {}),
            (AUROC, {}),
            (Precision, {}),
            (Recall, {}),
            (F1Score, {}),
            (FBetaScore, {"beta": 0.5}),
            (PrecisionRecallCurve, {}),
            (ROC, {}),
            (Specificity, {}),
            (StatScores, {}),
        )
    ]
)


In [None]:
# The same predictions / labels wrapped in every container type that
# convert_to_torch_tensor is meant to accept: list, tuple, numpy array,
# pandas Series and torch tensor (in that order).
preds = [
    as_container([0.0, 0.0, 1.0, 1.0])
    for as_container in (list, tuple, np.array, pd.Series, torch.tensor)
]

target = [
    as_container([0, 0, 1, 1])
    for as_container in (list, tuple, np.array, pd.Series, torch.tensor)
]

# CUDA-resident inputs are not included, e.g.
# torch.tensor([0.0, 1.0, 0.0, 1.0], device="cuda:0") for preds and
# torch.tensor([0, 1, 0, 1], device="cuda:0") for target; add them only when
# a GPU is available.


In [None]:
# Compute the metrics for every input representation and check that the
# metrics do not mutate the data they are given.
from copy import deepcopy

# The loop variables deliberately do not reuse the name `target`: rebinding it
# would replace the list from the previous cell with its last element, so a
# re-run of this cell would zip `preds` against that tensor's scalars.
for pred_input, target_input in zip(preds, target):
    pred_copy = deepcopy(pred_input)
    target_copy = deepcopy(target_input)

    preds_tensor = convert_to_torch_tensor(pred_input)
    target_tensor = convert_to_torch_tensor(target_input)

    result = metrics(preds_tensor, target_tensor)

    # For tensor / DLPack inputs the converted tensor aliases the original
    # input, so comparing it with a conversion of the untouched deep copy
    # detects in-place modification of either.
    assert torch.equal(preds_tensor, convert_to_torch_tensor(pred_copy)), (
        f"preds of type {type(pred_input)} were mutated"
    )
    assert torch.equal(target_tensor, convert_to_torch_tensor(target_copy)), (
        f"target of type {type(target_input)} was mutated"
    )
