diff --git a/neuralogic/core/constructs/function/__init__.py b/neuralogic/core/constructs/function/__init__.py index 546fa2d1..35e6af01 100644 --- a/neuralogic/core/constructs/function/__init__.py +++ b/neuralogic/core/constructs/function/__init__.py @@ -2,6 +2,7 @@ from neuralogic.core.constructs.function.function import Transformation, Combination, Aggregation, Function from neuralogic.core.constructs.function.reshape import Reshape from neuralogic.core.constructs.function.slice import Slice +from neuralogic.core.constructs.function.softmax import Softmax _special_namings = {"LEAKY_RELU": "LEAKYRELU", "TRANSP": "TRANSPOSE"} @@ -25,6 +26,7 @@ Aggregation.CONCAT = Concat("CONCAT") +Aggregation.SOFTMAX = Softmax("SOFTMAX") __all__ = ["Transformation", "Combination", "Aggregation", "Function"] diff --git a/neuralogic/core/constructs/function/function.py b/neuralogic/core/constructs/function/function.py index 966645d1..b504c06c 100644 --- a/neuralogic/core/constructs/function/function.py +++ b/neuralogic/core/constructs/function/function.py @@ -85,3 +85,4 @@ class Aggregation(Function): SUM: "Aggregation" COUNT: "Aggregation" CONCAT: "Aggregation" + SOFTMAX: "Aggregation" diff --git a/neuralogic/core/constructs/function/slice.py b/neuralogic/core/constructs/function/slice.py index ed6372ed..eb4533d9 100644 --- a/neuralogic/core/constructs/function/slice.py +++ b/neuralogic/core/constructs/function/slice.py @@ -16,6 +16,13 @@ def __init__( cols: Union[type(Ellipsis), Tuple[int, int]] = ..., ): super().__init__(name) + + if cols is not Ellipsis: + cols = [int(x) for x in self.cols] + + if rows is not Ellipsis: + rows = [int(x) for x in self.rows] + self.rows = rows self.cols = cols @@ -33,8 +40,8 @@ def is_parametrized(self) -> bool: return True def get(self): - cols = None if self.cols is Ellipsis else list(int(x) for x in self.cols) - rows = None if self.rows is Ellipsis else list(int(x) for x in self.rows) + cols = None if self.cols is Ellipsis else self.cols + rows = None if self.rows is Ellipsis else self.rows return jpype.JClass("cz.cvut.fel.ida.algebra.functions.transformation.joint.Slice")(rows, cols) diff --git a/neuralogic/core/constructs/function/softmax.py b/neuralogic/core/constructs/function/softmax.py new file mode 100644 index 00000000..f6531e78 --- /dev/null +++ b/neuralogic/core/constructs/function/softmax.py @@ -0,0 +1,35 @@ +from typing import Sequence + +import jpype + +from neuralogic.core.constructs.function.function import Aggregation + + +class Softmax(Aggregation): + __slots__ = ("agg_terms",) + + def __init__( + self, + name: str, + *, + agg_terms: Sequence[int] = None, + ): + super().__init__(name) + if agg_terms is not None: + agg_terms = tuple(int(i) for i in agg_terms) + self.agg_terms = agg_terms + + def __call__(self, entity=None, *, agg_terms: Sequence[int] = None): + softmax = Softmax(self.name, agg_terms=agg_terms) + return Aggregation.__call__(softmax, entity) + + def is_parametrized(self) -> bool: + return self.agg_terms is not None + + def get(self): + return jpype.JClass("cz.cvut.fel.ida.algebra.functions.combination.Softmax")(self.agg_terms) + + def __str__(self): + if self.agg_terms is None: + return "softmax" + return f"softmax(agg_terms={self.agg_terms})" diff --git a/neuralogic/jar/NeuraLogic.jar b/neuralogic/jar/NeuraLogic.jar index a2eb18ad..fa77fa53 100644 Binary files a/neuralogic/jar/NeuraLogic.jar and b/neuralogic/jar/NeuraLogic.jar differ diff --git a/neuralogic/nn/functional.py b/neuralogic/nn/functional.py index 8513091e..e5fd3db2 100644 --- a/neuralogic/nn/functional.py +++ b/neuralogic/nn/functional.py @@ -1,4 +1,4 @@ -from typing import Union, Tuple +from typing import Union, Tuple, Sequence from neuralogic.core.constructs.relation import BaseRelation from neuralogic.core.constructs.function import Transformation, Combination, Function, Aggregation @@ -203,3 +203,7 @@ def count(entity: BaseRelation = None) -> Union[BaseRelation, Function]: def concat(entity: BaseRelation = None, *, axis: int = -1) -> Union[BaseRelation, Function]: return Aggregation.CONCAT(entity, axis=axis) + + +def softmax_agg(entity: BaseRelation = None, *, agg_terms: Sequence[int] = None) -> Union[BaseRelation, Function]: + return Aggregation.SOFTMAX(entity, agg_terms=agg_terms)