Skip to content

Commit

Permalink
Add time_taper option to apply Hann window
Browse files Browse the repository at this point in the history
Deepwave samples the integration over timesteps to calculate
gradients with respect to model parameters at a frequency that
relies on signals being bandlimited. When signals do not start
from zero amplitude, high frequencies are introduced which can
break this assumption. To ensure that this does not occur, this
commit adds a new time_taper option that applies a Hann window
to source and receiver amplitudes. It is off by default, and is
probably mainly useful during testing of the propagators.
  • Loading branch information
ar4 committed Sep 24, 2023
1 parent 51fa84b commit 79a8349
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 24 deletions.
37 changes: 30 additions & 7 deletions src/deepwave/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def setup_propagator(
nt: Optional[int] = None,
model_gradient_sampling_interval: int = 1,
freq_taper_frac: float = 0.0,
time_pad_frac: float = 0.0
time_pad_frac: float = 0.0,
time_taper: bool = False
) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor],
List[Tensor], List[Tensor], float, float, float, int, int, int, int,
int, List[int]]:
Expand Down Expand Up @@ -114,7 +115,8 @@ def setup_propagator(
sa = upsample(sa,
step_ratio,
freq_taper_frac=freq_taper_frac,
time_pad_frac=time_pad_frac)
time_pad_frac=time_pad_frac,
time_taper=time_taper)
if sa.device == torch.device('cpu'):
sa = torch.movedim(sa, -1, 1)
else:
Expand Down Expand Up @@ -195,6 +197,7 @@ def downsample_and_movedim(receiver_amplitudes: Tensor,
step_ratio: int,
freq_taper_frac: float = 0.0,
time_pad_frac: float = 0.0,
time_taper: bool = False,
shift: float = 0.0) -> Tensor:
if receiver_amplitudes.numel() > 0:
if receiver_amplitudes.device == torch.device('cpu'):
Expand All @@ -205,6 +208,7 @@ def downsample_and_movedim(receiver_amplitudes: Tensor,
step_ratio,
freq_taper_frac=freq_taper_frac,
time_pad_frac=time_pad_frac,
time_taper=time_taper,
shift=shift)
return receiver_amplitudes

Expand Down Expand Up @@ -463,7 +467,8 @@ def zero_last_element_of_final_dimension(signal: Tensor) -> Tensor:
def upsample(signal: Tensor,
step_ratio: int,
freq_taper_frac: float = 0.0,
time_pad_frac: float = 0.0) -> Tensor:
time_pad_frac: float = 0.0,
time_taper: bool = False) -> Tensor:
"""Upsamples the final dimension of a Tensor by a factor.
Low-pass upsampling is used to produce an upsampled signal without
Expand All @@ -475,8 +480,8 @@ def upsample(signal: Tensor,
The Tensor that will have its final dimension upsampled.
step_ratio:
The integer factor by which the signal will be upsampled.
The input signal is returned if this is 1 (freq_taper_frac
and time_pad_frac will be ignored).
The input signal is returned if this is 1 (freq_taper_frac,
time_pad_frac, and time_taper will be ignored).
freq_taper_frac:
A float specifying the fraction of the end of the signal
amplitude in the frequency domain to cosine taper. This
Expand All @@ -490,6 +495,12 @@ def upsample(signal: Tensor,
dimension of the input signal. This might be useful to reduce
wraparound artifacts. A value of 0.1 means that zero padding
of 10% of the length of the signal will be used. Default 0.0.
time_taper:
A bool specifying whether to apply a Hann window in time.
This is useful during correctness tests of the propagators
as it ensures that signals taper to zero at their edges in
time, avoiding the possibility of high frequencies being
introduced.
Returns:
The signal after upsampling.
Expand All @@ -515,13 +526,16 @@ def upsample(signal: Tensor,
signal = torch.fft.irfft(signal_f, n=up_nt, norm='ortho')
if time_pad_frac > 0.0:
signal = signal[..., :signal.shape[-1] - n_time_pad * step_ratio]
if time_taper:
signal = signal * torch.hann_window(signal.shape[-1], periodic=False)
return signal


def downsample(signal: Tensor,
step_ratio: int,
freq_taper_frac: float = 0.0,
time_pad_frac: float = 0.0,
time_taper: bool = False,
shift: float = 0.0) -> Tensor:
"""Downsamples the final dimension of a Tensor by a factor.
Expand All @@ -533,8 +547,9 @@ def downsample(signal: Tensor,
The Tensor that will have its final dimension downsampled.
step_ratio:
The integer factor by which the signal will be downsampled.
The input signal is returned if this is 1 (freq_taper_frac
and time_pad_frac will be ignored).
The input signal is returned if this is 1 and shift is 0
(freq_taper_frac, time_pad_frac, and time_taper will be
ignored).
freq_taper_frac:
A float specifying the fraction of the end of the signal
amplitude in the frequency domain to cosine taper. This
Expand All @@ -549,6 +564,12 @@ def downsample(signal: Tensor,
wraparound artifacts. A value of 0.1 means that zero padding
of 10% of the length of the output signal will be used.
Default 0.0.
time_taper:
A bool specifying whether to apply a Hann window in time.
This is useful during correctness tests of the propagators
as it ensures that signals taper to zero at their edges in
time, avoiding the possibility of high frequencies being
introduced.
shift:
Amount (in units of time samples) to shift the data in time
before downsampling.
Expand All @@ -559,6 +580,8 @@ def downsample(signal: Tensor,
"""
if step_ratio == 1 and shift == 0.0:
return signal
if time_taper:
signal = signal * torch.hann_window(signal.shape[-1], periodic=False)
if time_pad_frac > 0.0:
n_time_pad = int(time_pad_frac * (signal.shape[-1] // step_ratio))
signal = torch.nn.functional.pad(signal, (0, n_time_pad * step_ratio))
Expand Down
17 changes: 12 additions & 5 deletions src/deepwave/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ def elastic(
nt: Optional[int] = None,
model_gradient_sampling_interval: int = 1,
freq_taper_frac: float = 0.0,
time_pad_frac: float = 0.0
time_pad_frac: float = 0.0,
time_taper: bool = False
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,
Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
"""Elastic wave propagation (functional interface).
Expand Down Expand Up @@ -400,6 +401,12 @@ def elastic(
useful to reduce wraparound artifacts. A value of 0.1 means that
zero padding of 10% of the number of time samples will be used.
Default 0.0.
time_taper:
A bool specifying whether to apply a Hann window in time to
source and receiver amplitudes (if present). This is useful
during correctness tests of the propagators as it ensures that
signals taper to zero at their edges in time, avoiding the
possibility of high frequencies being introduced.
Returns:
Tuple[Tensor]:
Expand Down Expand Up @@ -478,7 +485,7 @@ def elastic(
accuracy, pml_width, pml_freq, max_vel,
survey_pad,
origin, nt, model_gradient_sampling_interval,
freq_taper_frac, time_pad_frac)
freq_taper_frac, time_pad_frac, time_taper)
lamb, mu, buoyancy = models
source_amplitudes_y, source_amplitudes_x = source_amplitudes
(vy, vx, sigmayy, sigmaxy, sigmaxx, m_vyy, m_vyx, m_vxy, m_vxx, m_sigmayyy,
Expand Down Expand Up @@ -524,13 +531,13 @@ def average_adjacent(receiver_amplitudes: Tensor) -> Tensor:
receiver_amplitudes_x = average_adjacent(receiver_amplitudes_x)
receiver_amplitudes_y = downsample_and_movedim(receiver_amplitudes_y,
step_ratio, freq_taper_frac,
time_pad_frac)
time_pad_frac, time_taper)
receiver_amplitudes_x = downsample_and_movedim(receiver_amplitudes_x,
step_ratio, freq_taper_frac,
time_pad_frac)
time_pad_frac, time_taper)
receiver_amplitudes_p = downsample_and_movedim(receiver_amplitudes_p,
step_ratio, freq_taper_frac,
time_pad_frac)
time_pad_frac, time_taper)

return (vy, vx, sigmayy, sigmaxy, sigmaxx, m_vyy, m_vyx, m_vxy, m_vxx,
m_sigmayyy, m_sigmaxyy, m_sigmaxyx, m_sigmaxxx,
Expand Down
13 changes: 10 additions & 3 deletions src/deepwave/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ def scalar(
nt: Optional[int] = None,
model_gradient_sampling_interval: int = 1,
freq_taper_frac: float = 0.0,
time_pad_frac: float = 0.0
time_pad_frac: float = 0.0,
time_taper: bool = False
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
"""Scalar wave propagation (functional interface).
Expand Down Expand Up @@ -339,6 +340,12 @@ def scalar(
useful to reduce wraparound artifacts. A value of 0.1 means that
zero padding of 10% of the number of time samples will be used.
Default 0.0.
time_taper:
A bool specifying whether to apply a Hann window in time to
source and receiver amplitudes (if present). This is useful
during correctness tests of the propagators as it ensures that
signals taper to zero at their edges in time, avoiding the
possibility of high frequencies being introduced.
Returns:
Tuple[Tensor]:
Expand Down Expand Up @@ -373,7 +380,7 @@ def scalar(
accuracy, pml_width, pml_freq, max_vel,
survey_pad,
origin, nt, model_gradient_sampling_interval,
freq_taper_frac, time_pad_frac)
freq_taper_frac, time_pad_frac, time_taper)
v = models[0]
wfc, wfp, psiy, psix, zetay, zetax = wavefields
source_amplitudes = source_amplitudes_l[0]
Expand All @@ -390,7 +397,7 @@ def scalar(

receiver_amplitudes = downsample_and_movedim(receiver_amplitudes,
step_ratio, freq_taper_frac,
time_pad_frac)
time_pad_frac, time_taper)

return wfc, wfp, psiy, psix, zetay, zetax, receiver_amplitudes

Expand Down
9 changes: 5 additions & 4 deletions src/deepwave/scalar_born.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ def scalar_born(
nt: Optional[int] = None,
model_gradient_sampling_interval: int = 1,
freq_taper_frac: float = 0.0,
time_pad_frac: float = 0.0
time_pad_frac: float = 0.0,
time_taper: bool = False
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,
Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
"""Scalar Born wave propagation (functional interface).
Expand Down Expand Up @@ -248,7 +249,7 @@ def scalar_born(
accuracy, pml_width, pml_freq, max_vel,
survey_pad,
origin, nt, model_gradient_sampling_interval,
freq_taper_frac, time_pad_frac)
freq_taper_frac, time_pad_frac, time_taper)
v, scatter = models
(wfc, wfp, psiy, psix, zetay, zetax, wfcsc, wfpsc, psiysc, psixsc, zetaysc,
zetaxsc) = wavefields
Expand All @@ -268,10 +269,10 @@ def scalar_born(

receiver_amplitudes = downsample_and_movedim(receiver_amplitudes,
step_ratio, freq_taper_frac,
time_pad_frac)
time_pad_frac, time_taper)
receiver_amplitudessc = downsample_and_movedim(receiver_amplitudessc,
step_ratio, freq_taper_frac,
time_pad_frac)
time_pad_frac, time_taper)

return (wfc, wfp, psiy, psix, zetay, zetax, wfcsc, wfpsc, psiysc, psixsc,
zetaysc, zetaxsc, receiver_amplitudes, receiver_amplitudessc)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,9 +377,10 @@ def test_gradcheck_2d_cfl():

def test_gradcheck_2d_cfl_gradgrad():
"""Test gradcheck with a timestep greater than the CFL limit."""
run_gradcheck_2d(propagator=scalarprop, dt=0.002, atol=1.5e-6,
run_gradcheck_2d(propagator=scalarprop, dt=0.002, atol=1e-7,
prop_kwargs={'time_taper': True},
gradgrad=True,
source_requires_grad=False,
source_requires_grad=True,
wavefield_0_requires_grad=False,
wavefield_m1_requires_grad=False,
psiy_m1_requires_grad=False,
Expand Down Expand Up @@ -475,9 +476,9 @@ def test_gradcheck_2d_big_gradgrad():
"""Test gradcheck with a big model."""
run_gradcheck_2d(propagator=scalarprop,
nx=(5 + 2 * (3 + 3 * 2), 4 + 2 * (3 + 3 * 2)),
atol=2e-8,
prop_kwargs={'time_taper': True},
gradgrad=True,
source_requires_grad=False,
source_requires_grad=True,
wavefield_0_requires_grad=False,
wavefield_m1_requires_grad=False,
psiy_m1_requires_grad=False,
Expand Down Expand Up @@ -1016,7 +1017,6 @@ def run_gradcheck(c,
rtol=rtol)

if gradgrad:
#if gradgradatol is None:
gradgradatol = math.sqrt(atol)
torch.autograd.gradgradcheck(
propagator,
Expand Down

0 comments on commit 79a8349

Please sign in to comment.