Skip to content

Commit

Permalink
Implement softmax as splittable aggregation
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasZahradnik committed Dec 5, 2022
1 parent 66afa63 commit 2b9e714
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 3 deletions.
2 changes: 2 additions & 0 deletions neuralogic/core/constructs/function/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from neuralogic.core.constructs.function.function import Transformation, Combination, Aggregation, Function
from neuralogic.core.constructs.function.reshape import Reshape
from neuralogic.core.constructs.function.slice import Slice
from neuralogic.core.constructs.function.softmax import Softmax

_special_namings = {"LEAKY_RELU": "LEAKYRELU", "TRANSP": "TRANSPOSE"}

Expand All @@ -25,6 +26,7 @@


Aggregation.CONCAT = Concat("CONCAT")
Aggregation.SOFTMAX = Softmax("SOFTMAX")


__all__ = ["Transformation", "Combination", "Aggregation", "Function"]
1 change: 1 addition & 0 deletions neuralogic/core/constructs/function/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,4 @@ class Aggregation(Function):
SUM: "Aggregation"
COUNT: "Aggregation"
CONCAT: "Aggregation"
SOFTMAX: "Aggregation"
11 changes: 9 additions & 2 deletions neuralogic/core/constructs/function/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ def __init__(
cols: Union[type(Ellipsis), Tuple[int, int]] = ...,
):
super().__init__(name)

if cols is not Ellipsis:
cols = [int(x) for x in self.cols]

if rows is not Ellipsis:
rows = [int(x) for x in self.rows]

self.rows = rows
self.cols = cols

Expand All @@ -33,8 +40,8 @@ def is_parametrized(self) -> bool:
return True

def get(self):
cols = None if self.cols is Ellipsis else list(int(x) for x in self.cols)
rows = None if self.rows is Ellipsis else list(int(x) for x in self.rows)
cols = None if self.cols is Ellipsis else self.cols
rows = None if self.rows is Ellipsis else self.rows

return jpype.JClass("cz.cvut.fel.ida.algebra.functions.transformation.joint.Slice")(rows, cols)

Expand Down
35 changes: 35 additions & 0 deletions neuralogic/core/constructs/function/softmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Sequence

import jpype

from neuralogic.core.constructs.function.function import Aggregation


class Softmax(Aggregation):
__slots__ = ("agg_terms",)

def __init__(
self,
name: str,
*,
agg_terms: Sequence[int] = None,
):
super().__init__(name)
if agg_terms is not None:
agg_terms = tuple(int(i) for i in agg_terms)
self.agg_terms = agg_terms

def __call__(self, entity=None, *, agg_terms: Sequence[int] = None):
softmax = Softmax(self.name, agg_terms=agg_terms)
return Aggregation.__call__(softmax, entity)

def is_parametrized(self) -> bool:
return self.agg_terms is not None

def get(self):
return jpype.JClass("cz.cvut.fel.ida.algebra.functions.combination.Softmax")(self.agg_terms)

def __str__(self):
if self.agg_terms is None:
return "softmax"
return f"softmax(agg_terms={self.agg_terms})"
Binary file modified neuralogic/jar/NeuraLogic.jar
Binary file not shown.
6 changes: 5 additions & 1 deletion neuralogic/nn/functional.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, Tuple
from typing import Union, Tuple, Sequence

from neuralogic.core.constructs.relation import BaseRelation
from neuralogic.core.constructs.function import Transformation, Combination, Function, Aggregation
Expand Down Expand Up @@ -203,3 +203,7 @@ def count(entity: BaseRelation = None) -> Union[BaseRelation, Function]:

def concat(entity: BaseRelation = None, *, axis: int = -1) -> Union[BaseRelation, Function]:
return Aggregation.CONCAT(entity, axis=axis)


def softmax_agg(entity: BaseRelation = None, *, agg_terms: Sequence[int] = None) -> Union[BaseRelation, Function]:
return Aggregation.SOFTMAX(entity, agg_terms=agg_terms)

0 comments on commit 2b9e714

Please sign in to comment.