From 3e8c6df16b45eb9674d8908f5399fed3250370ec Mon Sep 17 00:00:00 2001
From: pengcheng888
Date: Fri, 14 Nov 2025 15:32:57 +0800
Subject: [PATCH] issue/596 - Split the functions in functional.py into separate modules under the functional folder
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 python/infinicore/nn/__init__.py              |   6 +-
 python/infinicore/nn/functional.py            | 101 ------------------
 python/infinicore/nn/functional/__init__.py   |  13 +++
 .../nn/functional/causal_softmax.py           |  15 +++
 .../infinicore/nn/functional/random_sample.py |  38 +++++++
 python/infinicore/nn/functional/rms_norm.py   |  26 +++++
 python/infinicore/nn/functional/silu.py       |  23 ++++
 python/infinicore/nn/functional/swiglu.py     |  15 +++
 test/infinicore/ops/random_sample.py          |  52 ++++-----
 9 files changed, 161 insertions(+), 128 deletions(-)
 delete mode 100644 python/infinicore/nn/functional.py
 create mode 100644 python/infinicore/nn/functional/__init__.py
 create mode 100644 python/infinicore/nn/functional/causal_softmax.py
 create mode 100644 python/infinicore/nn/functional/random_sample.py
 create mode 100644 python/infinicore/nn/functional/rms_norm.py
 create mode 100644 python/infinicore/nn/functional/silu.py
 create mode 100644 python/infinicore/nn/functional/swiglu.py

diff --git a/python/infinicore/nn/__init__.py b/python/infinicore/nn/__init__.py
index 9a091e628..e08b88af4 100644
--- a/python/infinicore/nn/__init__.py
+++ b/python/infinicore/nn/__init__.py
@@ -1,3 +1,3 @@
-from infinicore.nn import (
-    functional as functional,
-)
+from infinicore.nn import functional
+
+__all__ = ["functional"]
diff --git a/python/infinicore/nn/functional.py b/python/infinicore/nn/functional.py
deleted file mode 100644
index b6a1f2e67..000000000
--- a/python/infinicore/nn/functional.py
+++ /dev/null
@@ -1,101 +0,0 @@
-import infinicore
-from infinicore.lib import _infinicore
-from infinicore.tensor import Tensor
-
-__all__ = ["causal_softmax", "random_sample", "rms_norm", "silu", "swiglu"]
-
-
-def causal_softmax(input: Tensor, out=None) -> Tensor:
-    r"""Apply a causal softmax function."""
-
-    if out is None:
-        return Tensor(_infinicore.causal_softmax(input._underlying))
-
-    _infinicore.causal_softmax_(out._underlying, input._underlying)
-
-    return out
-
-
-def rms_norm(
-    input: Tensor,
-    normalized_shape: list[int],
-    weight: Tensor,
-    eps: float = 1e-5,
-    *,
-    out=None,
-) -> Tensor:
-    r"""Apply Root Mean Square Layer Normalization."""
-
-    assert normalized_shape == weight.shape, (
-        "normalized_shape does not match weight.shape."
-    )
-
-    if out is None:
-        return Tensor(_infinicore.rms_norm(input._underlying, weight._underlying, eps))
-
-    _infinicore.rms_norm_(out._underlying, input._underlying, weight._underlying, eps)
-
-    return out
-
-
-def silu(input: Tensor, inplace: bool = False, *, out=None) -> Tensor:
-    r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise."""
-
-    if infinicore.use_ntops and input.device.type in ("cuda", "musa") and out is None:
-        return infinicore.ntops.torch.silu(input, inplace=inplace)
-
-    if inplace:
-        _infinicore.silu_(input._underlying, input._underlying)
-        return input
-
-    if out is None:
-        return Tensor(_infinicore.silu(input._underlying))
-
-    _infinicore.silu_(out._underlying, input._underlying)
-
-    return out
-
-
-def swiglu(input: Tensor, other: Tensor, *, out=None):
-    r"""Apply the Swish-Gated Linear Unit (SwiGLU) function, element-wise."""
-
-    if out is None:
-        return Tensor(_infinicore.swiglu(input._underlying, other._underlying))
-
-    _infinicore.swiglu_(out._underlying, input._underlying, other._underlying)
-
-    return out
-
-
-def random_sample(
-    logits: Tensor,
-    random_val: float,
-    topp: float,
-    topk: int,
-    temperature: float,
-    *,
-    out=None,
-) -> Tensor:
-    r"""Sample an index from logits with nucleus/top-k filtering."""
-
-    if out is None:
-        return Tensor(
-            _infinicore.random_sample(
-                logits._underlying,
-                random_val,
-                topp,
-                topk,
-                temperature,
-            )
-        )
-
-    _infinicore.random_sample_(
-        out._underlying,
-        logits._underlying,
-        random_val,
-        topp,
-        topk,
-        temperature,
-    )
-
-    return out
diff --git a/python/infinicore/nn/functional/__init__.py b/python/infinicore/nn/functional/__init__.py
new file mode 100644
index 000000000..0bfdd4230
--- /dev/null
+++ b/python/infinicore/nn/functional/__init__.py
@@ -0,0 +1,13 @@
+from .causal_softmax import causal_softmax
+from .random_sample import random_sample
+from .rms_norm import rms_norm
+from .silu import silu
+from .swiglu import swiglu
+
+__all__ = [
+    "causal_softmax",
+    "random_sample",
+    "rms_norm",
+    "silu",
+    "swiglu",
+]
diff --git a/python/infinicore/nn/functional/causal_softmax.py b/python/infinicore/nn/functional/causal_softmax.py
new file mode 100644
index 000000000..7eaa59f06
--- /dev/null
+++ b/python/infinicore/nn/functional/causal_softmax.py
@@ -0,0 +1,15 @@
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+__all__ = ["causal_softmax"]
+
+
+def causal_softmax(input: Tensor, out=None) -> Tensor:
+    r"""Apply a causal softmax function."""
+
+    if out is None:
+        return Tensor(_infinicore.causal_softmax(input._underlying))
+
+    _infinicore.causal_softmax_(out._underlying, input._underlying)
+
+    return out
diff --git a/python/infinicore/nn/functional/random_sample.py b/python/infinicore/nn/functional/random_sample.py
new file mode 100644
index 000000000..624850b6c
--- /dev/null
+++ b/python/infinicore/nn/functional/random_sample.py
@@ -0,0 +1,38 @@
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+__all__ = ["random_sample"]
+
+
+def random_sample(
+    logits: Tensor,
+    random_val: float,
+    topp: float,
+    topk: int,
+    temperature: float,
+    *,
+    out=None,
+) -> Tensor:
+    r"""Sample an index from logits with nucleus/top-k filtering."""
+
+    if out is None:
+        return Tensor(
+            _infinicore.random_sample(
+                logits._underlying,
+                random_val,
+                topp,
+                topk,
+                temperature,
+            )
+        )
+
+    _infinicore.random_sample_(
+        out._underlying,
+        logits._underlying,
+        random_val,
+        topp,
+        topk,
+        temperature,
+    )
+
+    return out
diff --git a/python/infinicore/nn/functional/rms_norm.py b/python/infinicore/nn/functional/rms_norm.py
new file mode 100644
index 000000000..dd6aaf08d
--- /dev/null
+++ b/python/infinicore/nn/functional/rms_norm.py
@@ -0,0 +1,26 @@
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+__all__ = ["rms_norm"]
+
+
+def rms_norm(
+    input: Tensor,
+    normalized_shape: list[int],
+    weight: Tensor,
+    eps: float = 1e-5,
+    *,
+    out=None,
+) -> Tensor:
+    r"""Apply Root Mean Square Layer Normalization."""
+
+    assert normalized_shape == weight.shape, (
+        "normalized_shape does not match weight.shape."
+    )
+
+    if out is None:
+        return Tensor(_infinicore.rms_norm(input._underlying, weight._underlying, eps))
+
+    _infinicore.rms_norm_(out._underlying, input._underlying, weight._underlying, eps)
+
+    return out
diff --git a/python/infinicore/nn/functional/silu.py b/python/infinicore/nn/functional/silu.py
new file mode 100644
index 000000000..f67e4c3fd
--- /dev/null
+++ b/python/infinicore/nn/functional/silu.py
@@ -0,0 +1,23 @@
+import infinicore
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+__all__ = ["silu"]
+
+
+def silu(input: Tensor, inplace: bool = False, *, out=None) -> Tensor:
+    r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise."""
+
+    if infinicore.use_ntops and input.device.type in ("cuda", "musa") and out is None:
+        return infinicore.ntops.torch.silu(input, inplace=inplace)
+
+    if inplace:
+        _infinicore.silu_(input._underlying, input._underlying)
+        return input
+
+    if out is None:
+        return Tensor(_infinicore.silu(input._underlying))
+
+    _infinicore.silu_(out._underlying, input._underlying)
+
+    return out
diff --git a/python/infinicore/nn/functional/swiglu.py b/python/infinicore/nn/functional/swiglu.py
new file mode 100644
index 000000000..58b03ec38
--- /dev/null
+++ b/python/infinicore/nn/functional/swiglu.py
@@ -0,0 +1,15 @@
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+__all__ = ["swiglu"]
+
+
+def swiglu(input: Tensor, other: Tensor, *, out=None):
+    r"""Apply the Swish-Gated Linear Unit (SwiGLU) function, element-wise."""
+
+    if out is None:
+        return Tensor(_infinicore.swiglu(input._underlying, other._underlying))
+
+    _infinicore.swiglu_(out._underlying, input._underlying, other._underlying)
+
+    return out
diff --git a/test/infinicore/ops/random_sample.py b/test/infinicore/ops/random_sample.py
index 98b8dd729..f5e993f75 100644
--- a/test/infinicore/ops/random_sample.py
+++ b/test/infinicore/ops/random_sample.py
@@ -109,7 +109,11 @@ def torch_random_sample(data, random_val, topp, topk, voc, temperature):
             idx = torch.searchsorted(cum_probs, threshold)
         except Exception:
             indices = (cum_probs >= threshold).nonzero(as_tuple=True)[0]
-            idx = indices[0] if indices.numel() > 0 else torch.tensor(len(cum_probs) - 1, device=cum_probs.device)
+            idx = (
+                indices[0]
+                if indices.numel() > 0
+                else torch.tensor(len(cum_probs) - 1, device=cum_probs.device)
+            )
         return sorted_indices[idx]
     return torch.argmax(data)
 
@@ -191,41 +195,41 @@ def infinicore_operator(self, logits, out=None, **kwargs):
     def run_test(self, device, test_case, config):
         """
         Override run_test to handle random_sample's special comparison logic.
-
+
         For random_sample, if the indices differ but the logits values at
         those indices are equal, the result is still considered valid.
         This handles cases where multiple valid indices exist due to
         floating-point precision.
-
+
         This is necessary because random_sample can return different valid
         indices when multiple positions have the same logits value, especially
         with low-precision types like bfloat16 due to floating-point rounding.
         """
         # Clear stored logits before test to ensure fresh generation
         self._current_logits = None
-
+
         try:
             # Try the standard comparison first
             # This will call prepare_inputs_and_kwargs which will set self._current_logits
            return super().run_test(device, test_case, config)
-        except AssertionError:
+        except AssertionError as original_error:
             # If standard comparison fails, check if this is a valid case where
             # indices differ but logits values are equal
-
+
             # Only handle if we have stored logits (from prepare_inputs_and_kwargs)
             if self._current_logits is None:
                 raise
-
+
             logits_tensor = self._current_logits
-
+
             # Re-run operations with the same logits to get results for comparison
             # prepare_inputs_and_kwargs will reuse self._current_logits if it exists
             from framework.utils import (
                 infinicore_tensor_from_torch,
                 convert_infinicore_to_torch,
             )
-
+
             inputs, kwargs = self.prepare_inputs_and_kwargs(test_case, device)
-
+
             # Prepare infinicore inputs
             infini_inputs = []
             for inp in inputs:
@@ -235,37 +239,37 @@ def run_test(self, device, test_case, config):
                     infini_inputs.append(infini_tensor)
                 else:
                     infini_inputs.append(inp)
-
+
             infini_kwargs = kwargs.copy()
-            if "out" in infini_kwargs and isinstance(infini_kwargs["out"], torch.Tensor):
+            if "out" in infini_kwargs and isinstance(
+                infini_kwargs["out"], torch.Tensor
+            ):
                 cloned_out = infini_kwargs["out"].clone().detach()
                 infini_kwargs["out"] = infinicore_tensor_from_torch(cloned_out)
-
+
             # Run both operators
             torch_result = self.torch_operator(*inputs, **kwargs)
             infini_result = self.infinicore_operator(*infini_inputs, **infini_kwargs)
-
+
             # Extract indices from results
             comparison_target = test_case.comparison_target
             if comparison_target == "out":
                 # Compare output tensor from kwargs
                 ref_idx = kwargs["out"].item()
                 torch_result_from_infini = convert_infinicore_to_torch(
-                    infini_kwargs["out"], kwargs["out"]
+                    infini_kwargs["out"]
                 )
                 ic_idx = torch_result_from_infini.item()
             else:
                 # Compare return values
                 ref_idx = torch_result.item()
-                torch_result_from_infini = convert_infinicore_to_torch(
-                    infini_result, torch_result
-                )
+                torch_result_from_infini = convert_infinicore_to_torch(infini_result)
                 ic_idx = torch_result_from_infini.item()
-
+
             # Check if indices are equal (standard case)
             if ic_idx == ref_idx:
-                return
-
+                return True, "passed"
+
             # Special case: indices differ but logits values are equal
             # This is valid for random_sample when multiple indices have the same logits value
             try:
@@ -273,13 +277,13 @@ def run_test(self, device, test_case, config):
                 logits_ic = logits_tensor[ic_idx].item()
                 if logits_ic == logits_ref:
                     # Valid: different indices but same logits value
-                    return
+                    return True, "passed"
             except (IndexError, RuntimeError):
                 # If we can't access the logits, fall through to raise the original error
                 pass
-
+
             # If we get here, the results are truly different
-            raise
+            raise original_error
 
 
 def main():
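
Notes for reviewers (outside the diff):

The split is intended to leave the public import surface unchanged. A minimal check of that property, assuming the package layout added above, is:

    from infinicore.nn import functional as F
    from infinicore.nn.functional import rms_norm, silu

    # Both import styles should resolve to the same function objects.
    assert F.silu is silu
    assert F.rms_norm is rms_norm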
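
For context on what the relocated rms_norm wraps: the _infinicore kernel is assumed here to follow the standard RMSNorm definition, y = x / sqrt(mean(x^2 over the last dim) + eps) * weight. A plain PyTorch reference sketch under that assumption (not code from this repository):

    import torch

    def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        # Normalize by the root-mean-square over the last dimension, then scale.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        return x / rms * weight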
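
Similarly, silu is the usual x * sigmoid(x) (the same definition as torch.nn.functional.silu), and SwiGLU is commonly defined as silu(gate) * up; which of swiglu's two arguments acts as the gate is not stated in this patch, so the sketch below assumes the first one:

    import torch
    import torch.nn.functional as F

    def silu_reference(x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(x)  # same result as F.silu(x)

    def swiglu_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return F.silu(a) * b  # assumes `a` is the gated half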
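
The run_test override in test/infinicore/ops/random_sample.py accepts an index mismatch whenever both indices point at equal logits values. A standalone sketch of that acceptance rule (hypothetical helper name, not part of the test framework):

    import torch

    def indices_equivalent(logits: torch.Tensor, ref_idx: int, ic_idx: int) -> bool:
        # Identical indices always pass; a differing index is still valid when
        # both positions hold exactly the same logits value (ties are common
        # with low-precision dtypes such as bfloat16).
        if ic_idx == ref_idx:
            return True
        return logits[ic_idx].item() == logits[ref_idx].item()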