Skip to content

Commit

Permalink
Add some helpers: basis.derivative_basis() method, field.T and field.…
Browse files Browse the repository at this point in the history
…H for regular/Hermitian transpose, and dist.IdentityTensor.
  • Loading branch information
kburns committed Mar 27, 2022
1 parent c153f2e commit 8b709a9
Show file tree
Hide file tree
Showing 11 changed files with 45 additions and 26 deletions.
17 changes: 14 additions & 3 deletions dedalus/core/basis.py
Expand Up @@ -556,6 +556,11 @@ def multiplication_matrix(self, subproblem, arg_basis, coeffs, ncc_comp, arg_com
matrix = convert @ matrix
return matrix[:N, :N]

def derivative_basis(self, order=1):
a = self.a + order
b = self.b + order
return self.clone_with(a=a, b=b)


def Legendre(*args, **kw):
return Jacobi(*args, a=0, b=0, **kw)
Expand Down Expand Up @@ -631,9 +636,7 @@ class DifferentiateJacobi(operators.Differentiate, operators.SpectralOperator1D)

@staticmethod
def _output_basis(input_basis):
a = input_basis.a + 1
b = input_basis.b + 1
return input_basis.clone_with(a=a, b=b)
return input_basis.derivative_basis(order=1)

@staticmethod
@CachedMethod
Expand Down Expand Up @@ -1944,6 +1947,10 @@ def radius_multiplication_matrix(self, m, spintotal, order, d):
operator = R2**(d//2) @ operator
return operator(self.n_size(m), self.alpha + self.k, abs(m + spintotal)).square.astype(np.float64)

def derivative_basis(self, order=1):
k = self.k + order
return self.clone_with(k=k)


class AnnulusBasis(PolarBasis):

Expand Down Expand Up @@ -3924,6 +3931,10 @@ def matrix_dependence(self, matrix_coupling):
def constant(self):
return (self.Lmax==0, self.Lmax==0, False)

def derivative_basis(self, order=1):
k = self.k + order
return self.clone_with(k=k)

@CachedAttribute
def constant_mode_value(self):
# Adjust for SWSH normalization
Expand Down
8 changes: 8 additions & 0 deletions dedalus/core/distributor.py
Expand Up @@ -216,6 +216,14 @@ def TensorField(self, *args, **kw):
from .field import TensorField
return TensorField(self, *args, **kw)

def IdentityTensor(self, coordsys):
"""Identity tensor field."""
from .field import TensorField
I = TensorField(self, (coordsys, coordsys))
for i in range(coordsys.dim):
I['g'][i, i] = 1
return I

def local_grid(self, basis, scale=None):
# TODO: remove from bases and do it all here?
if basis.dim == 1:
Expand Down
8 changes: 8 additions & 0 deletions dedalus/core/field.py
Expand Up @@ -247,7 +247,15 @@ def expression_matrices(self, subproblem, vars, **kw):
"""Build expression matrices for a specific subproblem and variables."""
raise NotImplementedError()

@property
def T(self):
from .operators import TransposeComponents
return TransposeComponents(self)

@property
def H(self):
from .operators import TransposeComponents
return TransposeComponents(np.conj(self))


class Current(Operand):
Expand Down
24 changes: 8 additions & 16 deletions dedalus/core/operators.py
Expand Up @@ -2813,8 +2813,7 @@ def __init__(self, operand, coordsys, out=None):

@staticmethod
def _output_basis(input_basis):
out = input_basis._new_k(input_basis.k + 1)
return out
return input_basis.derivative_basis(1)

def check_conditions(self):
"""Check that operands are in a proper layout."""
Expand Down Expand Up @@ -2962,8 +2961,7 @@ def __init__(self, operand, coordsys, out=None):

@staticmethod
def _output_basis(input_basis):
out = input_basis._new_k(input_basis.k + 1)
return out
return input_basis.derivative_basis(1)

def check_conditions(self):
"""Check that operands are in a proper layout."""
Expand Down Expand Up @@ -3164,8 +3162,7 @@ def __init__(self, operand, index=0, out=None):

@staticmethod
def _output_basis(input_basis):
out = input_basis._new_k(input_basis.k + 1)
return out
return input_basis.derivative_basis(1)

def check_conditions(self):
"""Check that operands are in a proper layout."""
Expand Down Expand Up @@ -3225,8 +3222,7 @@ def __init__(self, operand, index=0, out=None):

@staticmethod
def _output_basis(input_basis):
out = input_basis._new_k(input_basis.k + 1)
return out
return input_basis.derivative_basis(1)

def check_conditions(self):
"""Check that operands are in a proper layout."""
Expand Down Expand Up @@ -3371,8 +3367,7 @@ def __init__(self, operand, index=0, out=None):

@staticmethod
def _output_basis(input_basis):
out = input_basis._new_k(input_basis.k + 1)
return out
return input_basis.derivative_basis(1)

def check_conditions(self):
"""Check that operands are in a proper layout."""
Expand Down Expand Up @@ -3515,8 +3510,7 @@ def __init__(self, operand, index=0, out=None):

@staticmethod
def _output_basis(input_basis):
out = input_basis._new_k(input_basis.k + 1)
return out
return input_basis.derivative_basis(1)

def check_conditions(self):
"""Check that operands are in a proper layout."""
Expand Down Expand Up @@ -3718,8 +3712,7 @@ def __init__(self, operand, coordsys, out=None):

@staticmethod
def _output_basis(input_basis):
out = input_basis._new_k(input_basis.k + 2)
return out
return input_basis.derivative_basis(2)

def check_conditions(self):
"""Check that operands are in a proper layout."""
Expand Down Expand Up @@ -3827,8 +3820,7 @@ def __init__(self, operand, coordsys, out=None):

@staticmethod
def _output_basis(input_basis):
out = input_basis._new_k(input_basis.k + 2)
return out
return input_basis.derivative_basis(2)

def check_conditions(self):
"""Check that operands are in a proper layout."""
Expand Down
2 changes: 1 addition & 1 deletion docs/notebooks/dedalus_tutorial_3.ipynb
Expand Up @@ -179,7 +179,7 @@
"c = -1.76\n",
"\n",
"# Tau polynomials\n",
"tau_basis = xbasis.clone_with(a=1.5, b=1.5)\n",
"tau_basis = xbasis.derivative_basis(2)\n",
"p1 = dist.Field(bases=tau_basis)\n",
"p2 = dist.Field(bases=tau_basis)\n",
"p1['c'][-1] = 1\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/notebooks/dedalus_tutorial_4.ipynb
Expand Up @@ -106,7 +106,7 @@
"c = -1.76\n",
"\n",
"# Tau polynomials\n",
"tau_basis = xbasis.clone_with(a=1.5, b=1.5)\n",
"tau_basis = xbasis.derivative_basis(2)\n",
"p1 = dist.Field(bases=tau_basis)\n",
"p2 = dist.Field(bases=tau_basis)\n",
"p1['c'][-1] = 1\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/pages/tau_method.rst
Expand Up @@ -132,7 +132,7 @@ Here we'll take :math:`P(y)` to be the highest mode in the Chebyshev-U basis, in
# Substitutions
ex, ey = coords.unit_vector_fields(dist)
lift_basis = ybasis.clone_with(a=1/2, b=1/2) # Chebyshev U basis
lift_basis = ybasis.derivative_basis(1) # Chebyshev U basis
lift = lambda A, n: d3.Lift(A, lift_basis, -1) # Shortcut for multiplying by U_{N-1}(y)
grad_u = d3.grad(u) - ey*lift(tau_u1) # Operator representing G
Expand Down
2 changes: 1 addition & 1 deletion examples/evp_1d_waves_on_a_string/waves_on_a_string.py
Expand Up @@ -41,7 +41,7 @@

# Substitutions
dx = lambda A: d3.Differentiate(A, xcoord)
lift_basis = xbasis.clone_with(a=1/2, b=1/2) # First derivative basis
lift_basis = xbasis.derivative_basis(1)
lift = lambda A: d3.Lift(A, lift_basis, -1)
ux = dx(u) + lift(tau_1) # First-order reduction

Expand Down
2 changes: 1 addition & 1 deletion examples/ivp_2d_rayleigh_benard/rayleigh_benard.py
Expand Up @@ -61,7 +61,7 @@
nu = (Rayleigh / Prandtl)**(-1/2)
x, z = dist.local_grids(xbasis, zbasis)
ex, ez = coords.unit_vector_fields(dist)
lift_basis = zbasis.clone_with(a=1/2, b=1/2) # First derivative basis
lift_basis = zbasis.derivative_basis(1)
lift = lambda A: d3.Lift(A, lift_basis, -1)
grad_u = d3.grad(u) + ez*lift(tau_u1) # First-order reduction
grad_b = d3.grad(b) + ez*lift(tau_b1) # First-order reduction
Expand Down
2 changes: 1 addition & 1 deletion examples/ivp_shell_convection/shell_convection.py
Expand Up @@ -65,7 +65,7 @@
er['g'][2] = 1
rvec = dist.VectorField(coords, bases=basis.radial_basis)
rvec['g'][2] = r
lift_basis = basis.clone_with(k=1) # First derivative basis
lift_basis = basis.derivative_basis(1)
lift = lambda A: d3.Lift(A, lift_basis, -1)
grad_u = d3.grad(u) + rvec*lift(tau_u1) # First-order reduction
grad_b = d3.grad(b) + rvec*lift(tau_b1) # First-order reduction
Expand Down
2 changes: 1 addition & 1 deletion examples/lbvp_2d_poisson/poisson.py
Expand Up @@ -47,7 +47,7 @@

# Substitutions
dy = lambda A: d3.Differentiate(A, coords['y'])
lift_basis = ybasis.clone_with(a=3/2, b=3/2) # Natural output basis
lift_basis = ybasis.derivative_basis(2)
lift = lambda A, n: d3.Lift(A, lift_basis, n)

# Problem
Expand Down

0 comments on commit 8b709a9

Please sign in to comment.