diff --git a/csp/math.py b/csp/math.py index 2e9beac2..e215cd62 100644 --- a/csp/math.py +++ b/csp/math.py @@ -1,5 +1,6 @@ import math import numpy as np +import sys import typing from functools import lru_cache @@ -363,7 +364,10 @@ def comp(x: ts["T"]): log2 = define_unary_op("log2", lambda x: math.log2(x)) log10 = define_unary_op("log10", lambda x: math.log10(x)) exp = define_unary_op("exp", lambda x: math.exp(x)) -exp2 = define_unary_op("exp2", lambda x: math.exp2(x)) +if sys.version_info < (3, 11): + exp2 = define_unary_op("exp2", lambda x: 2**x) +else: + exp2 = define_unary_op("exp2", lambda x: math.exp2) sqrt = define_unary_op("sqrt", lambda x: math.sqrt(x)) erf = define_unary_op("erf", lambda x: math.erf(x)) sin = define_unary_op("sin", lambda x: math.sin(x)) diff --git a/csp/tests/test_math.py b/csp/tests/test_math.py index 6177ceb1..a3b7a84c 100644 --- a/csp/tests/test_math.py +++ b/csp/tests/test_math.py @@ -122,7 +122,7 @@ def test_math_unary_ops(self): csp.log2: lambda x: math.log2(x), csp.log10: lambda x: math.log10(x), csp.exp: lambda x: math.exp(x), - csp.exp2: lambda x: math.exp2(x), + csp.exp2: lambda x: 2**x, csp.sin: lambda x: math.sin(x), csp.cos: lambda x: math.cos(x), csp.tan: lambda x: math.tan(x),