6 changes: 3 additions & 3 deletions python/infinicore/nn/__init__.py
@@ -1,3 +1,3 @@
-from infinicore.nn import (
-    functional as functional,
-)
+from infinicore.nn import functional
+
+__all__ = ["functional"]
101 changes: 0 additions & 101 deletions python/infinicore/nn/functional.py

This file was deleted.

13 changes: 13 additions & 0 deletions python/infinicore/nn/functional/__init__.py
@@ -0,0 +1,13 @@
from .causal_softmax import causal_softmax
from .random_sample import random_sample
from .rms_norm import rms_norm
from .silu import silu
from .swiglu import swiglu

__all__ = [
    "causal_softmax",
    "random_sample",
    "rms_norm",
    "silu",
    "swiglu",
]
15 changes: 15 additions & 0 deletions python/infinicore/nn/functional/causal_softmax.py
@@ -0,0 +1,15 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor

__all__ = ["causal_softmax"]


def causal_softmax(input: Tensor, out=None) -> Tensor:
    r"""Apply a causal softmax function."""

    if out is None:
        return Tensor(_infinicore.causal_softmax(input._underlying))

    _infinicore.causal_softmax_(out._underlying, input._underlying)

    return out
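All of the new functional wrappers follow the convention shown above: allocate and return a new Tensor when out is None, otherwise write into out through the in-place binding. A minimal usage sketch; infinicore_tensor_from_torch is the test-framework helper used later in this PR, and the shape and dtype are illustrative assumptions:

import torch
import infinicore.nn.functional as F
from framework.utils import infinicore_tensor_from_torch  # test-suite helper (see test diff below)

scores = infinicore_tensor_from_torch(torch.randn(1, 8, 8, dtype=torch.float16))

y = F.causal_softmax(scores)        # allocates and returns a new Tensor
out = infinicore_tensor_from_torch(torch.empty(1, 8, 8, dtype=torch.float16))
F.causal_softmax(scores, out=out)   # fills the preallocated `out` and returns it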
38 changes: 38 additions & 0 deletions python/infinicore/nn/functional/random_sample.py
@@ -0,0 +1,38 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor

__all__ = ["random_sample"]


def random_sample(
    logits: Tensor,
    random_val: float,
    topp: float,
    topk: int,
    temperature: float,
    *,
    out=None,
) -> Tensor:
    r"""Sample an index from logits with nucleus/top-k filtering."""

    if out is None:
        return Tensor(
            _infinicore.random_sample(
                logits._underlying,
                random_val,
                topp,
                topk,
                temperature,
            )
        )

    _infinicore.random_sample_(
        out._underlying,
        logits._underlying,
        random_val,
        topp,
        topk,
        temperature,
    )

    return out
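A hedged usage sketch for the sampling wrapper; the parameter meanings are inferred from the docstring and the torch_random_sample reference further down, and the vocabulary size, dtype, and helper import are assumptions of this example:

import random
import torch
import infinicore.nn.functional as F
from framework.utils import infinicore_tensor_from_torch  # test-suite helper

logits = infinicore_tensor_from_torch(torch.randn(32000, dtype=torch.float16))

token = F.random_sample(
    logits,
    random_val=random.random(),  # uniform draw consumed by the sampler
    topp=0.9,                    # nucleus (top-p) threshold
    topk=50,                     # top-k cutoff
    temperature=1.0,             # logits scaling before sampling
)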
26 changes: 26 additions & 0 deletions python/infinicore/nn/functional/rms_norm.py
@@ -0,0 +1,26 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor

__all__ = ["rms_norm"]


def rms_norm(
    input: Tensor,
    normalized_shape: list[int],
    weight: Tensor,
    eps: float = 1e-5,
    *,
    out=None,
) -> Tensor:
    r"""Apply Root Mean Square Layer Normalization."""

    assert normalized_shape == weight.shape, (
        "normalized_shape does not match weight.shape."
    )

    if out is None:
        return Tensor(_infinicore.rms_norm(input._underlying, weight._underlying, eps))

    _infinicore.rms_norm_(out._underlying, input._underlying, weight._underlying, eps)

    return out
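A sketch of a typical call. It assumes Tensor.shape is a plain Python list, which is what the equality assert above implies; the hidden size and dtype are illustrative:

import torch
import infinicore.nn.functional as F
from framework.utils import infinicore_tensor_from_torch  # test-suite helper

hidden = 4096  # illustrative hidden size
x = infinicore_tensor_from_torch(torch.randn(2, 16, hidden, dtype=torch.float16))
w = infinicore_tensor_from_torch(torch.ones(hidden, dtype=torch.float16))

# normalized_shape must equal weight.shape or the assert above fires.
y = F.rms_norm(x, [hidden], w, eps=1e-6)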
23 changes: 23 additions & 0 deletions python/infinicore/nn/functional/silu.py
@@ -0,0 +1,23 @@
import infinicore
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor

__all__ = ["silu"]


def silu(input: Tensor, inplace: bool = False, *, out=None) -> Tensor:
    r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise."""

    if infinicore.use_ntops and input.device.type in ("cuda", "musa") and out is None:
        return infinicore.ntops.torch.silu(input, inplace=inplace)

    if inplace:
        _infinicore.silu_(input._underlying, input._underlying)
        return input

    if out is None:
        return Tensor(_infinicore.silu(input._underlying))

    _infinicore.silu_(out._underlying, input._underlying)

    return out
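An illustrative call pattern. Note that the ntops fast path above only triggers when infinicore.use_ntops is enabled, the tensor lives on a CUDA or MUSA device, and no out tensor is supplied; the helper import and dtype here are assumptions:

import torch
import infinicore.nn.functional as F
from framework.utils import infinicore_tensor_from_torch  # test-suite helper

x = infinicore_tensor_from_torch(torch.randn(4, 1024, dtype=torch.float16))

y = F.silu(x)             # out-of-place: returns a new Tensor
F.silu(x, inplace=True)   # in-place: overwrites x via _infinicore.silu_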
15 changes: 15 additions & 0 deletions python/infinicore/nn/functional/swiglu.py
@@ -0,0 +1,15 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor

__all__ = ["swiglu"]


def swiglu(input: Tensor, other: Tensor, *, out=None):
    r"""Apply the Swish-Gated Linear Unit (SwiGLU) function, element-wise."""

    if out is None:
        return Tensor(_infinicore.swiglu(input._underlying, other._underlying))

    _infinicore.swiglu_(out._underlying, input._underlying, other._underlying)

    return out
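Call-shape sketch only: the diff specifies an element-wise SwiGLU over two same-shaped tensors but does not say which operand is passed through the SiLU gate, so that detail is left to the underlying kernel. The helper import and sizes are assumptions:

import torch
import infinicore.nn.functional as F
from framework.utils import infinicore_tensor_from_torch  # test-suite helper

up = infinicore_tensor_from_torch(torch.randn(4, 11008, dtype=torch.float16))
gate = infinicore_tensor_from_torch(torch.randn(4, 11008, dtype=torch.float16))

y = F.swiglu(up, gate)  # element-wise gated combination of the two inputs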
52 changes: 28 additions & 24 deletions test/infinicore/ops/random_sample.py
@@ -109,7 +109,11 @@ def torch_random_sample(data, random_val, topp, topk, voc, temperature):
             idx = torch.searchsorted(cum_probs, threshold)
         except Exception:
             indices = (cum_probs >= threshold).nonzero(as_tuple=True)[0]
-            idx = indices[0] if indices.numel() > 0 else torch.tensor(len(cum_probs) - 1, device=cum_probs.device)
+            idx = (
+                indices[0]
+                if indices.numel() > 0
+                else torch.tensor(len(cum_probs) - 1, device=cum_probs.device)
+            )

Contributor Author: The changes below are automatic formatting applied by the ruff plugin.

         return sorted_indices[idx]
 
     return torch.argmax(data)
@@ -191,41 +195,41 @@ def infinicore_operator(self, logits, out=None, **kwargs):
     def run_test(self, device, test_case, config):
         """
         Override run_test to handle random_sample's special comparison logic.
 
         For random_sample, if the indices differ but the logits values at those
         indices are equal, the result is still considered valid. This handles
         cases where multiple valid indices exist due to floating-point precision.
 
         This is necessary because random_sample can return different valid indices
         when multiple positions have the same logits value, especially with
         low-precision types like bfloat16 due to floating-point rounding.
         """
         # Clear stored logits before test to ensure fresh generation
         self._current_logits = None
 
         try:
             # Try the standard comparison first
             # This will call prepare_inputs_and_kwargs which will set self._current_logits
             return super().run_test(device, test_case, config)
-        except AssertionError:
+        except AssertionError as original_error:
             # If standard comparison fails, check if this is a valid case where
             # indices differ but logits values are equal
 
             # Only handle if we have stored logits (from prepare_inputs_and_kwargs)
             if self._current_logits is None:
                 raise
 
             logits_tensor = self._current_logits
 
             # Re-run operations with the same logits to get results for comparison
             # prepare_inputs_and_kwargs will reuse self._current_logits if it exists
             from framework.utils import (
                 infinicore_tensor_from_torch,
                 convert_infinicore_to_torch,
             )
 
             inputs, kwargs = self.prepare_inputs_and_kwargs(test_case, device)
 
             # Prepare infinicore inputs
             infini_inputs = []
             for inp in inputs:
@@ -235,51 +239,51 @@ def run_test(self, device, test_case, config):
                     infini_inputs.append(infini_tensor)
                 else:
                     infini_inputs.append(inp)
 
             infini_kwargs = kwargs.copy()
-            if "out" in infini_kwargs and isinstance(infini_kwargs["out"], torch.Tensor):
+            if "out" in infini_kwargs and isinstance(
+                infini_kwargs["out"], torch.Tensor
+            ):
                 cloned_out = infini_kwargs["out"].clone().detach()
                 infini_kwargs["out"] = infinicore_tensor_from_torch(cloned_out)
 
             # Run both operators
             torch_result = self.torch_operator(*inputs, **kwargs)
             infini_result = self.infinicore_operator(*infini_inputs, **infini_kwargs)
 
             # Extract indices from results
             comparison_target = test_case.comparison_target
             if comparison_target == "out":
                 # Compare output tensor from kwargs
                 ref_idx = kwargs["out"].item()
                 torch_result_from_infini = convert_infinicore_to_torch(
-                    infini_kwargs["out"], kwargs["out"]
+                    infini_kwargs["out"]
                 )

Contributor Author: The parameters of the convert_infinicore_to_torch function in utils were reduced, so this call was updated to match.

                 ic_idx = torch_result_from_infini.item()
             else:
                 # Compare return values
                 ref_idx = torch_result.item()
-                torch_result_from_infini = convert_infinicore_to_torch(
-                    infini_result, torch_result
-                )
+                torch_result_from_infini = convert_infinicore_to_torch(infini_result)
                 ic_idx = torch_result_from_infini.item()
 
             # Check if indices are equal (standard case)
             if ic_idx == ref_idx:
-                return
+                return True, "passed"
 
             # Special case: indices differ but logits values are equal
             # This is valid for random_sample when multiple indices have the same logits value
             try:
                 logits_ref = logits_tensor[ref_idx].item()
                 logits_ic = logits_tensor[ic_idx].item()
                 if logits_ic == logits_ref:
                     # Valid: different indices but same logits value
-                    return
+                    return True, "passed"
             except (IndexError, RuntimeError):
                 # If we can't access the logits, fall through to raise the original error
                 pass
 
             # If we get here, the results are truly different
-            raise
+            raise original_error


def main():
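A standalone illustration, in plain PyTorch, of the acceptance rule that run_test's docstring describes: when the reference and infinicore indices differ but the logits at those indices tie, the case still counts as a pass. The indices here are hypothetical:

import torch

logits = torch.tensor([0.5, 2.0, 2.0, 1.0])
ref_idx, ic_idx = 1, 2  # hypothetical results from torch vs. infinicore

# Mirrors the fallback comparison in run_test above.
is_pass = ic_idx == ref_idx or logits[ic_idx].item() == logits[ref_idx].item()
assert is_pass  # different indices, identical logits value -> still valid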