Skip to content

Commit

Permalink
Add 1D averaging operators and clean up average interface
Browse files Browse the repository at this point in the history
  • Loading branch information
kburns committed Feb 15, 2022
1 parent 199aa52 commit 42e8e36
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 42 deletions.
72 changes: 72 additions & 0 deletions dedalus/core/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,29 @@ def _full_matrix(input_basis, output_basis):
return integ_vector[None, :] * input_basis.COV.stretch


class AverageJacobi(operators.Average, operators.SpectralOperator1D):
"""Jacobi polynomial averaging."""

input_coord_type = Coordinate
input_basis_type = Jacobi
subaxis_dependence = [True]
subaxis_coupling = [True]

@staticmethod
def _output_basis(input_basis):
return None

@staticmethod
def _full_matrix(input_basis, output_basis):
# Build native integration vector
N = input_basis.size
a, b = input_basis.a, input_basis.b
integ_vector = jacobi.integration_vector(N, a, b)
ave_vector = integ_vector / 2
# Rescale and return with shape (1, N)
return ave_vector[None, :]


class LiftJacobi(operators.Lift, operators.Copy):
"""Jacobi polynomial lift."""

Expand Down Expand Up @@ -949,6 +972,30 @@ def _group_matrix(group, input_basis, output_basis):
raise ValueError("This should never happen.")


class AverageComplexFourier(operators.Average, operators.SpectralOperator1D):
"""ComplexFourier averaging."""

input_coord_type = Coordinate
input_basis_type = ComplexFourier
subaxis_dependence = [True]
subaxis_coupling = [False]

@staticmethod
def _output_basis(input_basis):
return None

@staticmethod
def _group_matrix(group, input_basis, output_basis):
# Rescale group (native wavenumber) to get physical wavenumber
k = group / input_basis.COV.stretch
# integ exp(1j*k*x) / L = δ(k, 0)
if k == 0:
return np.array([[1]])
else:
# Constructor should only loop over group 0.
raise ValueError("This should never happen.")


class RealFourier(FourierBase, metaclass=CachedClass):
"""
Fourier real sine/cosine basis.
Expand Down Expand Up @@ -1152,6 +1199,31 @@ def _group_matrix(group, input_basis, output_basis):
raise ValueError("This should never happen.")


class AverageRealFourier(operators.Average, operators.SpectralOperator1D):
"""RealFourier averaging."""

input_coord_type = Coordinate
input_basis_type = RealFourier
subaxis_dependence = [True]
subaxis_coupling = [False]

@staticmethod
def _output_basis(input_basis):
return None

@staticmethod
def _group_matrix(group, input_basis, output_basis):
# Rescale group (native wavenumber) to get physical wavenumber
k = group / input_basis.COV.stretch
# integ cos(k*x) / L = δ(k, 0)
# integ -sin(k*x) / L = 0
if k == 0:
return np.array([[1, 0]])
else:
# Constructor should only loop over group 0.
raise ValueError("This should never happen.")


# class HilbertTransformFourier(operators.HilbertTransform):
# """Fourier series Hilbert transform."""

Expand Down
4 changes: 2 additions & 2 deletions dedalus/core/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,8 +578,8 @@ def setup_file(self, file, virtual_file=False):
scale_group.create_dataset(name='timestep', shape=(0,), maxshape=(None,), dtype=np.float64)
scale_group.create_dataset(name='world_time', shape=(0,), maxshape=(None,), dtype=np.float64)
scale_group.create_dataset(name='wall_time', shape=(0,), maxshape=(None,), dtype=np.float64)
scale_group.create_dataset(name='iteration', shape=(0,), maxshape=(None,), dtype=np.int)
scale_group.create_dataset(name='write_number', shape=(0,), maxshape=(None,), dtype=np.int)
scale_group.create_dataset(name='iteration', shape=(0,), maxshape=(None,), dtype=int)
scale_group.create_dataset(name='write_number', shape=(0,), maxshape=(None,), dtype=int)
scale_group.create_dataset(name='constant', data=np.array([0.], dtype=np.float64))
scale_group['constant'].make_scale()

Expand Down
56 changes: 25 additions & 31 deletions dedalus/core/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,27 +1073,16 @@ def _expand_multiply(self, operand, vars):
return prod([self.new_operand(arg) for arg in operand.args])


@parseable('integrate', 'integ')
def integrate(arg, spaces=None):
if spaces is None:
spaces = tuple(b.coords for b in arg.domain.bases)
# Identify domain
#domain = unify_attributes((arg,)+spaces, 'domain', require=False)
# Apply iteratively
for space in spaces:
#space = domain.get_space_object(space)
arg = Integrate(arg, space)
return arg

@alias("integ")
class Integrate(LinearOperator, metaclass=MultiClass):
"""
Definite integration over operand bases.
Integrate over operand bases.
Parameters
----------
operand : number or Operand object
coords : Coordinate or CoordinateSystem object, or list of these
"""

name = "Integrate"
Expand Down Expand Up @@ -1154,17 +1143,38 @@ def new_operand(self, operand, **kw):
@alias("ave")
class Average(LinearOperator, metaclass=MultiClass):
"""
Average along one dimension.
Average over operand bases.
Parameters
----------
operand : number or Operand object
space : Space object
coords : Coordinate or CoordinateSystem object, or list of these
"""

name = "Average"

@classmethod
def _preprocess_args(cls, operand, coord=None):
# Handle numbers
if isinstance(operand, Number):
raise SkipDispatchException(output=operand)
# Average over all operand bases by default
if coord is None:
coord = [basis.coordsystem for basis in operand.domain.bases]
# Recurse over multiple coordinates
if isinstance(coord, (tuple, list)):
if len(coord) > 1:
operand = Average(operand, coord[:-1])
coord = coord[-1]
# Resolve strings to coordinates
if isinstance(coord, str):
coord = operand.domain.get_coord(coord)
# Check coordinate type
if not isinstance(coord, (coords.Coordinate, coords.CoordinateSystem)):
raise ValueError("coords must be Coordinate or str")
return (operand, coord), {}

@classmethod
def _check_args(cls, operand, coords):
# Dispatch by operand basis
Expand All @@ -1175,22 +1185,6 @@ def _check_args(cls, operand, coords):
return True
return False

@classmethod
def _preprocess_args(cls, operand, coord=None):
if isinstance(operand, Number):
raise SkipDispatchException(output=operand)
if coord is None:
coord = operand.dist.single_coordsys
if coord is False:
raise ValueError("coordsys must be specified.")
elif isinstance(coord, (coords.Coordinate, coords.CoordinateSystem)):
pass
elif isinstance(coord, str):
coord = operand.domain.get_coord(coord)
else:
raise ValueError("coord must be Coordinate or str")
return (operand, coord), {}

def __init__(self, operand, coord):
SpectralOperator.__init__(self, operand)
# Require integrand is a scalar
Expand Down
19 changes: 10 additions & 9 deletions docs/notebooks/dedalus_tutorial_2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
"name": "stdout",
"output_type": "stream",
"text": [
"2022-02-11 14:29:22,407 dedalus 0/1 WARNING :: Threading has not been disabled. This may massively degrade Dedalus performance.\n",
"2022-02-11 14:29:22,408 dedalus 0/1 WARNING :: We strongly suggest setting the \"OMP_NUM_THREADS\" environment variable to \"1\".\n",
"2022-02-11 14:29:22,543 numexpr.utils 0/1 INFO :: Note: NumExpr detected 10 cores but \"NUMEXPR_MAX_THREADS\" not set, so enforcing safe limit of 8.\n",
"2022-02-11 14:29:22,543 numexpr.utils 0/1 INFO :: NumExpr defaulting to 8 threads.\n"
"2022-02-15 17:16:52,150 dedalus 0/1 WARNING :: Threading has not been disabled. This may massively degrade Dedalus performance.\n",
"2022-02-15 17:16:52,151 dedalus 0/1 WARNING :: We strongly suggest setting the \"OMP_NUM_THREADS\" environment variable to \"1\".\n",
"2022-02-15 17:16:52,255 numexpr.utils 0/1 INFO :: Note: NumExpr detected 10 cores but \"NUMEXPR_MAX_THREADS\" not set, so enforcing safe limit of 8.\n",
"2022-02-15 17:16:52,256 numexpr.utils 0/1 INFO :: NumExpr defaulting to 8 threads.\n"
]
}
],
Expand Down Expand Up @@ -224,7 +224,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/gl/8q1_pm2s1490lvyfvm_8yby80000gn/T/ipykernel_82069/1812536267.py:4: RuntimeWarning: divide by zero encountered in log10\n",
"/var/folders/gl/8q1_pm2s1490lvyfvm_8yby80000gn/T/ipykernel_62829/1812536267.py:4: RuntimeWarning: divide by zero encountered in log10\n",
" log_mag = lambda xmesh, ymesh, data: (xmesh, ymesh, np.log10(np.abs(data)))\n"
]
},
Expand Down Expand Up @@ -751,18 +751,19 @@
"name": "stdout",
"output_type": "stream",
"text": [
"f integral: [[9.42458659]]\n"
"f integral: [[9.42458659]]\n",
"f average: [[0.74998477]]\n"
]
}
],
"source": [
"# Total integral of the field\n",
"f_int = d3.Integrate(d3.Integrate(f, 'x'), 'y')\n",
"f_int = d3.Integrate(f, ('x', 'y'))\n",
"print('f integral:', f_int.evaluate()['g'])\n",
"\n",
"# Average of the field\n",
"#f_ave = d3.Average(d3.Average(f, 'x'), 'y')\n",
"#print('f average:', f_ave.evaluate()['g'])"
"f_ave = d3.Average(f, ('x', 'y'))\n",
"print('f average:', f_ave.evaluate()['g'])"
]
},
{
Expand Down

0 comments on commit 42e8e36

Please sign in to comment.