Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

save time snapshots #29

Closed
pavane opened this issue Aug 19, 2021 · 10 comments
Closed

save time snapshots #29

pavane opened this issue Aug 19, 2021 · 10 comments

Comments

@pavane
Copy link

pavane commented Aug 19, 2021

Is there a way to store time snapshots of the 3D wavefield at given time intervals for visualization?

@ar4
Copy link
Owner

ar4 commented Aug 19, 2021 via email

@pavane
Copy link
Author

pavane commented Aug 19, 2021

Thank you so much. I would like to help with testing 3D boundary conditions and maybe improving that piece of this code.
Please help me get started with the compiled propagator.

@ar4
Copy link
Owner

ar4 commented Aug 20, 2021

I have written some code to call the compiled propagator directly so that we can access the wavefields at arbitrary time steps:

import torch
import numpy as np
import scipy.signal
import deepwave
import deepwave.base.propagator
from deepwave.scalar import scalar

class SteppingPropagator(deepwave.base.propagator.Propagator):
    """Scalar wave propagator that can be advanced a chunk of steps at a time.

    The constructor runs the usual scalar-propagator setup (model
    preparation, source resampling, wavefield and PML allocation) and stores
    every argument the compiled propagator needs. :meth:`step` then calls the
    compiled code directly, so the wavefield can be inspected between calls.

    See deepwave.base.propagator.Propagator for a description of the
    constructor arguments.
    """

    def __init__(self, model, dx,
                 source_amplitudes, source_locations, receiver_locations, dt,
                 pml_width=None, survey_pad=None, vpmax=None):
        if list(model.keys()) != ["vp"]:
            raise RuntimeError(
                "Model must only contain vp, but contains {}".format(
                    list(model.keys())
                )
            )
        super(SteppingPropagator, self).__init__(
            SteppingPropagatorFunction,
            model,
            dx,
            fd_width=4,  # also in Pml
            pml_width=pml_width,
            survey_pad=survey_pad,
        )
        self.model.extra_info["vpmax"] = vpmax
        if model["vp"].min() <= 0.0:
            raise RuntimeError(
                "vp must be > 0, but min is {}".format(model["vp"].min())
            )

        # Run the parent's forward() setup and unpack the seven values that
        # SteppingPropagatorFunction returns (it echoes its arguments), which
        # yields the prepared model and inputs.
        # NOTE: these values include non-tensors (dt, model, property_names).
        # Some PyTorch versions reject that with "expected Tensor or tuple of
        # Tensor (got float)".
        (source_amplitudes,
         source_locations,
         receiver_locations,
         dt,
         model,
         property_names,
         vp) = self.forward(source_amplitudes, source_locations, receiver_locations, dt)

        if property_names != ["vp"]:
            raise RuntimeError(
                "Model must only contain vp, but contains {}".format(
                    property_names
                )
            )
        if vp.min() <= 0.0:
            raise RuntimeError(
                "vp must be > 0, but min is {}".format(vp.min())
            )
        device = model.device
        dtype = model.dtype
        num_steps, num_shots, num_sources_per_shot = source_amplitudes.shape
        num_receivers_per_shot = receiver_locations.shape[1]

        if model.extra_info["vpmax"] is None:
            max_vel = vp.max().item()
        else:
            max_vel = model.extra_info["vpmax"]
        # Each dt step is run as timestep.step_ratio inner steps of length
        # timestep.inner_dt. This is why the source is resampled to
        # num_steps * step_ratio samples below.
        timestep = scalar.Timestep(dt, model.dx, max_vel)
        model.add_properties(
            {
                "vp2dt2": vp ** 2 * timestep.inner_dt ** 2,
                "scaling": 2 / vp ** 3,
            }
        )
        source_model_locations = model.get_locations(source_locations)
        receiver_model_locations = model.get_locations(receiver_locations)
        scalar_wrapper = scalar._select_propagator(model.ndim, vp.dtype, vp.is_cuda)
        wavefield_save_strategy = scalar._set_wavefield_save_strategy(
            False, dt, timestep.inner_dt, scalar_wrapper
        )
        fd1, fd2 = scalar._set_finite_diff_coeffs(model.ndim, model.dx, device, dtype)
        wavefield, saved_wavefields = scalar._allocate_wavefields(
            wavefield_save_strategy,
            scalar_wrapper,
            model,
            num_steps,
            num_shots,
        )
        receiver_amplitudes = torch.zeros(
            num_steps,
            num_shots,
            num_receivers_per_shot,
            device=device,
            dtype=dtype,
        )
        inner_dt = torch.tensor([timestep.inner_dt]).to(dtype)
        pml = scalar.Pml(model, num_shots, max_vel)
        source_amplitudes_resampled = scipy.signal.resample(
            source_amplitudes.detach().cpu().numpy(),
            num_steps * timestep.step_ratio,
        )
        source_amplitudes_resampled = (
            torch.tensor(source_amplitudes_resampled)
            .to(dtype)
            .to(source_amplitudes.device)
        )
        source_amplitudes_resampled.requires_grad = (
            source_amplitudes.requires_grad
        )

        # Convert the buffers once here rather than on every step. If
        # .to()/.contiguous() ever had to copy inside step(), the compiled
        # code would update a temporary and the stored state would be lost.
        self.dtype = dtype
        self.scalar_wrapper = scalar_wrapper
        self.wavefield = wavefield.to(dtype).contiguous()
        self.pml = pml
        self.pml.aux = self.pml.aux.to(dtype).contiguous()
        self.pml.sigma = self.pml.sigma.to(dtype).contiguous()
        self.pml.pml_width = self.pml.pml_width.long().contiguous()
        self.receiver_amplitudes = receiver_amplitudes.to(dtype).contiguous()
        self.saved_wavefields = saved_wavefields.to(dtype).contiguous()
        self.model = model
        self.model.properties["vp2dt2"] = (
            self.model.properties["vp2dt2"].to(dtype).contiguous()
        )
        self.fd1 = fd1.to(dtype).contiguous()
        self.fd2 = fd2.to(dtype).contiguous()
        self.source_amplitudes_resampled = source_amplitudes_resampled
        self.source_model_locations = source_model_locations.long().contiguous()
        self.receiver_model_locations = receiver_model_locations.long().contiguous()
        self.inner_dt = inner_dt
        self.timestep = timestep
        self.num_shots = num_shots
        self.num_sources_per_shot = num_sources_per_shot
        self.num_receivers_per_shot = num_receivers_per_shot
        self.wavefield_save_strategy = wavefield_save_strategy

        # Total number of dt steps available (first dim of source_amplitudes)
        # and the number already propagated.
        self.total_num_steps = num_steps
        self.current_step = 0

    def step(self, num_steps):
        """Propagate forward by ``num_steps`` time steps of size ``dt``.

        Args:
            num_steps: Number of dt steps to run. Each one is run as
                ``self.timestep.step_ratio`` inner steps.

        Returns:
            ``self.wavefield[1]`` after the steps. This is a view into the
            propagator's internal buffer and is overwritten by later calls,
            so copy it to keep a snapshot.

        Raises:
            AssertionError: If ``num_steps`` is negative or would run past
                ``self.total_num_steps``. The original exception type is kept,
                but this is raised explicitly so it survives ``python -O``.
        """
        if num_steps < 0 or self.current_step + num_steps > self.total_num_steps:
            raise AssertionError(
                "num_steps must be in [0, {}], but got {}".format(
                    self.total_num_steps - self.current_step, num_steps
                )
            )

        step_ratio = self.timestep.step_ratio
        num_inner_steps = num_steps * step_ratio
        start = self.current_step * step_ratio
        source_amplitudes_resampled_steps = \
            self.source_amplitudes_resampled[start:start + num_inner_steps]

        # Call compiled C code to do forward modeling.
        # NOTE(review): the same receiver_amplitudes / saved_wavefields
        # buffers (sized for total_num_steps) are passed on every call. Whether
        # the compiled code offsets each chunk is not visible here, so
        # receiver data may be overwritten by later calls. Confirm before
        # relying on it.
        self.scalar_wrapper.forward(
            self.wavefield,
            self.pml.aux,
            self.receiver_amplitudes,
            self.saved_wavefields,
            self.pml.sigma,
            self.model.properties["vp2dt2"],
            self.fd1,
            self.fd2,
            source_amplitudes_resampled_steps.to(self.dtype).contiguous(),
            self.source_model_locations,
            self.receiver_model_locations,
            self.model.shape.contiguous(),
            self.pml.pml_width,
            self.inner_dt,
            num_steps,
            step_ratio,
            self.num_shots,
            self.num_sources_per_shot,
            self.num_receivers_per_shot,
            self.wavefield_save_strategy,
        )

        self.current_step += num_steps

        self._reorder_wavefields(num_inner_steps)
        self._reorder_aux(num_inner_steps)

        return self.wavefield[1]

    def _reorder_wavefields(self, num_inner_steps):
        """Put the three wavefield buffers back in the order step() expects."""
        shift = num_inner_steps % 3
        if shift == 0:
            return
        # Same result as applying [a, b, c] -> [c, a, b] num_inner_steps times.
        wf_idxs = [(i - shift) % 3 for i in range(3)]
        # Indexing with a list returns a copy, so the source is not modified
        # while it is written back. A tuple assignment of tensor views would
        # alias and duplicate one of the buffers.
        self.wavefield[:3] = self.wavefield[wf_idxs]

    def _reorder_aux(self, num_inner_steps):
        """Swap the two halves of the PML aux buffer after an odd step count."""
        if num_inner_steps % 2 == 0:
            return
        aux_size = {1: 1, 2: 2}.get(self.model.ndim, 4)
        assert len(self.pml.aux) == 2 * aux_size
        # torch.cat allocates a new tensor, so copying it back in place is
        # safe (unlike swapping two views of the same tensor).
        self.pml.aux[:] = torch.cat(
            (self.pml.aux[aux_size:], self.pml.aux[:aux_size])
        )

        
class SteppingPropagatorFunction(torch.autograd.Function):
    """Pass-through autograd Function used only while constructing the propagator.

    SteppingPropagator passes this class to the Propagator base constructor.
    Its __init__ then calls ``self.forward(...)`` and unpacks seven values.
    Returning every argument unchanged is what lets it recover the prepared
    model and inputs. No modeling is done here and no backward is defined.

    Note: several echoed values (dt, model, property_names) are not tensors.
    Some PyTorch versions reject such outputs with a TypeError.
    """

    @staticmethod
    def forward(ctx, source_amplitudes, source_locations, receiver_locations,
                dt, model, property_names, vp):
        echoed = (source_amplitudes, source_locations, receiver_locations,
                  dt, model, property_names, vp)
        return echoed

It is a bit hacky. It runs the setup for a regular propagator and then extracts the variables that are passed to the propagator's forward method. It then uses these to run all of the code in the usual forward propagator up to the point where the compiled propagator is called, and saves the arguments to that call so that they can be reused whenever you want to run forward time steps. The benefit of doing all of this setup up front is that the stepping part is then quite simple: we take the portion of the source wavelet for the requested steps, run the compiled propagator, and then swap some memory around if necessary so that the wavefield arrays are in the expected order.

Here is an example of how to use it:

import matplotlib.pyplot as plt

dx = 5.0  # 5 m grid spacing in each dimension
dt = 0.004  # 4 ms
nz = 200
ny = 400
nt = round(5 / dt)  # 5 s of source/recording time (round avoids float truncation)
peak_freq = 4
peak_source_time = 1 / peak_freq

# constant 1500 m/s model
model = torch.ones(nz, ny) * 1500

# one source and receiver at the same location
x_s = torch.Tensor([[[0, 20 * dx]]])
x_r = x_s.clone()

source_amplitudes = deepwave.wavelets.ricker(peak_freq, nt, dt,
                                             peak_source_time).reshape(-1, 1, 1)

prop = SteppingPropagator({'vp': model}, dx, source_amplitudes, x_s, x_r, dt)

# Save a snapshot every `snapshot_interval` time steps. step() returns a view
# of the propagator's internal buffer that later steps overwrite, so each
# snapshot is copied.
snapshot_interval = 100
num_snapshots = 3
snapshots = [prop.step(snapshot_interval).detach().numpy().copy()
             for _ in range(num_snapshots)]

# squeeze=False keeps `ax` 2-D even when num_snapshots == 1
_, ax = plt.subplots(1, num_snapshots, sharex=True, sharey=True, squeeze=False)
for axis, snapshot in zip(ax[0], snapshots):
    axis.imshow(snapshot[0, :, :, 0], aspect='auto')
plt.show()

The CPU implementation of propagation in 3D is here (the link was not preserved in this copy). If I remember correctly, I used the same PML as PySIT.

@pavane
Copy link
Author

pavane commented Aug 26, 2021

Thank you so much for getting me started. I will update you on my progress.

@pavane
Copy link
Author

pavane commented Aug 26, 2021

The code fails with the following error
"TypeError: SteppingPropagatorFunctionBackward.forward: expected Tensor or tuple of Tensor (got float) for return value 3"

@ar4
Copy link
Owner

ar4 commented Aug 27, 2021 via email

@pavane
Copy link
Author

pavane commented Aug 27, 2021 via email

@ar4
Copy link
Owner

ar4 commented Aug 27, 2021 via email

@ar4
Copy link
Owner

ar4 commented Aug 28, 2021

From the message that you got, it sounds like your version of PyTorch is complaining about some of the return values from the forward function in SteppingPropagatorFunction not being Tensors. Perhaps you could try this version of the code instead:

import torch
import numpy as np
import scipy.signal
import deepwave
import deepwave.base.propagator
from deepwave.scalar import scalar
from deepwave.base.propagator import _check_locations_with_model

class SteppingPropagator(deepwave.base.propagator.Propagator):
    """Scalar wave propagator that can be advanced a chunk of steps at a time.

    The constructor validates the inputs, extracts and pads the model, and
    runs the usual scalar-propagator setup (source resampling, wavefield and
    PML allocation). It then stores every argument the compiled propagator
    needs. :meth:`step` calls the compiled code directly, so the wavefield can
    be inspected between calls.

    See deepwave.base.propagator.Propagator for a description of the
    constructor arguments.
    """

    def __init__(self, model, dx,
                 source_amplitudes, source_locations, receiver_locations, dt,
                 pml_width=None, survey_pad=None, vpmax=None):
        if list(model.keys()) != ["vp"]:
            raise RuntimeError(
                "Model must only contain vp, but contains {}".format(
                    list(model.keys())
                )
            )
        super(SteppingPropagator, self).__init__(
            SteppingPropagatorFunction,
            model,
            dx,
            fd_width=4,  # also in Pml
            pml_width=pml_width,
            survey_pad=survey_pad,
        )
        self.model.extra_info["vpmax"] = vpmax
        if model["vp"].min() <= 0.0:
            raise RuntimeError(
                "vp must be > 0, but min is {}".format(model["vp"].min())
            )

        # Validate inputs and prepare the model here instead of calling
        # self.forward(). That way SteppingPropagatorFunction never has to
        # return non-tensor values.

        # Check dt
        if not isinstance(dt, float):
            raise RuntimeError('dt must be a float, but has type {}'
                               .format(type(dt)))
        if dt <= 0.0:
            raise RuntimeError('dt must be > 0, but is {}'.format(dt))

        # Check same device as model
        if not (self.model.device == source_amplitudes.device ==
                source_locations.device == receiver_locations.device):
            raise RuntimeError('model, source amplitudes, source_locations, '
                               'and receiver_locations must all have the same '
                               'device, but got {} {} {} {}'
                               .format(self.model.device,
                                       source_amplitudes.device,
                                       source_locations.device,
                                       receiver_locations.device))

        # Check shapes
        if source_amplitudes.dim() != 3:
            raise RuntimeError('source_amplitude must have shape '
                               '[nt, num_shots, num_sources_per_shot]')

        if source_locations.dim() != 3:
            raise RuntimeError('source_locations must have shape '
                               '[num_shots, num_sources_per_shot, num_dims]')

        if receiver_locations.dim() != 3:
            raise RuntimeError('receiver_locations must have shape '
                               '[num_shots, num_receivers_per_shot, num_dims]')

        if not (source_amplitudes.shape[1] == source_locations.shape[0] ==
                receiver_locations.shape[0]):
            raise RuntimeError('Shape mismatch, expected '
                               'source_amplitudes.shape[1] '
                               '== source_locations.shape[0] '
                               '== receiver_locations.shape[0], but got '
                               '{} {} {}'.format(source_amplitudes.shape[1],
                                                 source_locations.shape[0],
                                                 receiver_locations.shape[0]))

        if not (source_amplitudes.shape[2] == source_locations.shape[1]):
            raise RuntimeError('Shape mismatch, expected '
                               'source_amplitudes.shape[2] '
                               '== source_locations.shape[1], but got '
                               '{} {}'.format(source_amplitudes.shape[2],
                                              source_locations.shape[1]))

        if not (self.model.ndim == source_locations.shape[2] ==
                receiver_locations.shape[2]):
            raise RuntimeError('Shape mismatch, expected '
                               'model num dims == source_locations.shape[2] '
                               '== receiver_locations.shape[2], but got '
                               '{} {} {}'.format(self.model.ndim,
                                                 source_locations.shape[2],
                                                 receiver_locations.shape[2]))

        # Check src/rec locations within model
        _check_locations_with_model(self.model, source_locations, 'source')
        _check_locations_with_model(self.model, receiver_locations, 'receiver')

        # Extract a region of the model around the sources/receivers
        model = self.extract(self.model, source_locations, receiver_locations)

        # Apply padding for the spatial finite difference and for the PML
        model = self.pad(model)

        property_names = list(model.properties.keys())
        vp = model.properties["vp"]

        if property_names != ["vp"]:
            raise RuntimeError(
                "Model must only contain vp, but contains {}".format(
                    property_names
                )
            )
        if vp.min() <= 0.0:
            raise RuntimeError(
                "vp must be > 0, but min is {}".format(vp.min())
            )
        device = model.device
        dtype = model.dtype
        num_steps, num_shots, num_sources_per_shot = source_amplitudes.shape
        num_receivers_per_shot = receiver_locations.shape[1]

        if model.extra_info["vpmax"] is None:
            max_vel = vp.max().item()
        else:
            max_vel = model.extra_info["vpmax"]
        # Each dt step is run as timestep.step_ratio inner steps of length
        # timestep.inner_dt. This is why the source is resampled to
        # num_steps * step_ratio samples below.
        timestep = scalar.Timestep(dt, model.dx, max_vel)
        model.add_properties(
            {
                "vp2dt2": vp ** 2 * timestep.inner_dt ** 2,
                "scaling": 2 / vp ** 3,
            }
        )
        source_model_locations = model.get_locations(source_locations)
        receiver_model_locations = model.get_locations(receiver_locations)
        scalar_wrapper = scalar._select_propagator(model.ndim, vp.dtype, vp.is_cuda)
        wavefield_save_strategy = scalar._set_wavefield_save_strategy(
            False, dt, timestep.inner_dt, scalar_wrapper
        )
        fd1, fd2 = scalar._set_finite_diff_coeffs(model.ndim, model.dx, device, dtype)
        wavefield, saved_wavefields = scalar._allocate_wavefields(
            wavefield_save_strategy,
            scalar_wrapper,
            model,
            num_steps,
            num_shots,
        )
        receiver_amplitudes = torch.zeros(
            num_steps,
            num_shots,
            num_receivers_per_shot,
            device=device,
            dtype=dtype,
        )
        inner_dt = torch.tensor([timestep.inner_dt]).to(dtype)
        pml = scalar.Pml(model, num_shots, max_vel)
        source_amplitudes_resampled = scipy.signal.resample(
            source_amplitudes.detach().cpu().numpy(),
            num_steps * timestep.step_ratio,
        )
        source_amplitudes_resampled = (
            torch.tensor(source_amplitudes_resampled)
            .to(dtype)
            .to(source_amplitudes.device)
        )
        source_amplitudes_resampled.requires_grad = (
            source_amplitudes.requires_grad
        )

        # Convert the buffers once here so that step() can pass them straight
        # to the compiled code. Any copy made by .to()/.contiguous() at call
        # time would mean the compiled code updated a temporary.
        self.dtype = dtype
        self.scalar_wrapper = scalar_wrapper
        self.wavefield = wavefield.to(self.dtype).contiguous()
        self.pml = pml
        self.pml.aux = self.pml.aux.to(self.dtype).contiguous()
        self.pml.sigma = self.pml.sigma.to(self.dtype).contiguous()
        self.pml.pml_width = self.pml.pml_width.long().contiguous()
        self.receiver_amplitudes = receiver_amplitudes.to(self.dtype).contiguous()
        self.saved_wavefields = saved_wavefields.to(self.dtype).contiguous()
        self.model = model
        self.model.properties["vp2dt2"] = (
            self.model.properties["vp2dt2"].to(self.dtype).contiguous()
        )
        self.fd1 = fd1.to(self.dtype).contiguous()
        self.fd2 = fd2.to(self.dtype).contiguous()
        self.source_amplitudes_resampled = source_amplitudes_resampled
        self.source_model_locations = source_model_locations.long().contiguous()
        self.receiver_model_locations = receiver_model_locations.long().contiguous()
        self.inner_dt = inner_dt
        self.timestep = timestep
        self.num_shots = num_shots
        self.num_sources_per_shot = num_sources_per_shot
        self.num_receivers_per_shot = num_receivers_per_shot
        self.wavefield_save_strategy = wavefield_save_strategy

        # Total number of dt steps available (first dim of source_amplitudes)
        # and the number already propagated.
        self.total_num_steps = num_steps
        self.current_step = 0

    def step(self, num_steps):
        """Propagate forward by ``num_steps`` time steps of size ``dt``.

        Args:
            num_steps: Number of dt steps to run. Each one is run as
                ``self.timestep.step_ratio`` inner steps.

        Returns:
            ``self.wavefield[1]`` after the steps. This is a view into the
            propagator's internal buffer and is overwritten by later calls,
            so copy it to keep a snapshot.

        Raises:
            AssertionError: If ``num_steps`` is negative or would run past
                ``self.total_num_steps``. The original exception type is kept,
                but this is raised explicitly so it survives ``python -O``.
        """
        if num_steps < 0 or self.current_step + num_steps > self.total_num_steps:
            raise AssertionError(
                "num_steps must be in [0, {}], but got {}".format(
                    self.total_num_steps - self.current_step, num_steps
                )
            )

        step_ratio = self.timestep.step_ratio
        num_inner_steps = num_steps * step_ratio
        start = self.current_step * step_ratio
        source_amplitudes_resampled_steps = \
            self.source_amplitudes_resampled[start:start + num_inner_steps]

        # Call compiled C code to do forward modeling.
        # NOTE(review): the same receiver_amplitudes / saved_wavefields
        # buffers (sized for total_num_steps) are passed on every call. Whether
        # the compiled code offsets each chunk is not visible here, so
        # receiver data may be overwritten by later calls. Confirm before
        # relying on it.
        self.scalar_wrapper.forward(
            self.wavefield,
            self.pml.aux,
            self.receiver_amplitudes,
            self.saved_wavefields,
            self.pml.sigma,
            self.model.properties["vp2dt2"],
            self.fd1,
            self.fd2,
            source_amplitudes_resampled_steps.to(self.dtype).contiguous(),
            self.source_model_locations,
            self.receiver_model_locations,
            self.model.shape.contiguous(),
            self.pml.pml_width,
            self.inner_dt,
            num_steps,
            step_ratio,
            self.num_shots,
            self.num_sources_per_shot,
            self.num_receivers_per_shot,
            self.wavefield_save_strategy,
        )

        self.current_step += num_steps

        self._reorder_wavefields(num_inner_steps)
        self._reorder_aux(num_inner_steps)

        return self.wavefield[1]

    def _reorder_wavefields(self, num_inner_steps):
        """Put the three wavefield buffers back in the order step() expects."""
        shift = num_inner_steps % 3
        if shift == 0:
            return
        # Same result as applying [a, b, c] -> [c, a, b] num_inner_steps times.
        wf_idxs = [(i - shift) % 3 for i in range(3)]
        # Indexing with a list returns a copy, so the source is not modified
        # while it is written back. A tuple assignment of tensor views would
        # alias and duplicate one of the buffers.
        self.wavefield[:3] = self.wavefield[wf_idxs]

    def _reorder_aux(self, num_inner_steps):
        """Swap the two halves of the PML aux buffer after an odd step count."""
        if num_inner_steps % 2 == 0:
            return
        aux_size = {1: 1, 2: 2}.get(self.model.ndim, 4)
        assert len(self.pml.aux) == 2 * aux_size
        # torch.cat allocates a new tensor, so copying it back in place is
        # safe (unlike swapping two views of the same tensor).
        self.pml.aux[:] = torch.cat(
            (self.pml.aux[aux_size:], self.pml.aux[:aux_size])
        )

        
class SteppingPropagatorFunction(torch.autograd.Function):
    """Placeholder autograd Function handed to the Propagator base class.

    SteppingPropagator does its own setup in __init__ and drives the compiled
    propagator from step(), so this Function does no modeling. forward()
    ignores every argument except ``vp`` and returns it unchanged. No backward
    is defined.
    """

    @staticmethod
    def forward(ctx, source_amplitudes, source_locations, receiver_locations,
                dt, model, property_names, vp):
        # Return a single value rather than echoing every input. Echoing
        # non-tensors such as dt triggered a TypeError in some PyTorch
        # versions.
        unchanged_vp = vp
        return unchanged_vp

The example code to run the propagator should be the same.

@pavane
Copy link
Author

pavane commented Aug 28, 2021

This code works. Thank you so much. I will keep you posted.

@ar4 ar4 closed this as completed Jun 5, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants