Skip to content

Commit

Permalink
Add the exponential function
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjm97 committed Jul 6, 2023
1 parent d9b302e commit 4b926ec
Showing 1 changed file with 27 additions and 4 deletions.
31 changes: 27 additions & 4 deletions teg/math/smooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class Sqrt(SmoothFunc):
NOTE: Does not check inputs are valid (x must be positive)
"""

def __init__(self, expr: ITeg, name: str = "Sqrt"):
super(Sqrt, self).__init__(expr=expr, name=name)

Expand All @@ -33,7 +34,8 @@ def output_size(input_size):


class Sqr(SmoothFunc):
"""y = x**2 """
"""y = x**2"""

def __init__(self, expr: ITeg, name: str = "Sqr"):
super(Sqr, self).__init__(expr=expr, name=name)

Expand All @@ -51,7 +53,8 @@ def output_size(input_size):


class Sin(SmoothFunc):
"""y = sin(x) """
"""y = sin(x)"""

def __init__(self, expr: ITeg, name: str = "Sin"):
super(Sin, self).__init__(expr=expr, name=name)

Expand All @@ -70,6 +73,7 @@ def output_size(input_size):

class Cos(SmoothFunc):
"""y = cos(x)"""

def __init__(self, expr: ITeg, name: str = "Cos"):
super(Cos, self).__init__(expr=expr, name=name)

Expand All @@ -87,7 +91,8 @@ def output_size(input_size):


class ASin(SmoothFunc):
"""theta = asin(x) """
"""theta = asin(x)"""

def __init__(self, expr: ITeg, name: str = "ASin"):
super(ASin, self).__init__(expr=expr, name=name)

Expand All @@ -105,7 +110,8 @@ def output_size(input_size):


class ATan2(SmoothFunc):
"""theta = atan2(x, y) """
"""theta = atan2(x, y)"""

def __init__(self, expr: ITeg, name: str = "ATan2"):
super(ATan2, self).__init__(expr=expr, name=name)

Expand All @@ -121,3 +127,20 @@ def operation(self, in_value):
def output_size(input_size):
assert input_size == 2
return 1


class Exp(SmoothFunc):
def __init__(self, expr: ITeg, name: str = "Exp"):
super(Exp, self).__init__(expr=expr, name=name)

def fwd_deriv(self, in_deriv_expr: ITeg):
return Exp(self.expr) * in_deriv_expr

def rev_deriv(self, out_deriv_expr: ITeg):
return out_deriv_expr * Exp(self.expr)

def operation(self, in_value):
return np.exp(in_value)

def output_size(input_size):
return input_size

0 comments on commit 4b926ec

Please sign in to comment.