Elastic wave gradient calculation #75
Hello again, I am glad to hear that Deepwave continues to be useful to you. The error that you encountered is caused by your parameters containing NaNs after several steps of optimization. The elastic propagator is unfortunately not very stable, so it is quite common for this to occur. I suggest three changes that might help. The first is to apply a little bit of smoothing to your model parameters before they are passed to the propagator. This helps to avoid sharp changes in the model that can cause instability. An example of this can be seen in the Joint Migration-Inversion example. The second is to restrict parameters to be within an expected range, such as by storing them in logit space and applying a sigmoid when they are used. The third is to use the 'strong_wolfe' line search with LBFGS, which seems to improve stability. The combination of these changes would look something like this (untested):

```python
import torch
import torchvision
import deepwave
from deepwave import elastic

# vp, vs, rho, device, n_epochs, dx, dt, freq, loss_fn, observed_data,
# and the source/receiver tensors are assumed to be defined as in your
# script.


class Model(torch.nn.Module):
    # Stores the parameter in logit space so that the value returned by
    # forward() always lies within [min_val, max_val]
    def __init__(self, initial, min_val, max_val):
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val
        self.model = torch.nn.Parameter(
            torch.logit((initial - min_val) /
                        (max_val - min_val))
        )

    def forward(self):
        return (torch.sigmoid(self.model) *
                (self.max_val - self.min_val) +
                self.min_val)


# You should ensure that the initial and expected parameter
# values are within these ranges
model_vp = Model(vp, 1400, 5000).to(device)
model_vs = Model(vs, 900, 3000).to(device)
model_rho = Model(rho, 1000, 3000).to(device)

# Using 'strong_wolfe' seems to improve stability, but may not
# work if running on multiple GPUs
optimiser = torch.optim.LBFGS(list(model_vp.parameters()) +
                              list(model_vs.parameters()) +
                              list(model_rho.parameters()),
                              line_search_fn='strong_wolfe')

for epoch in range(n_epochs):
    def closure():
        optimiser.zero_grad()
        # Smooth each constrained parameter before it is passed to the
        # propagator, to avoid sharp model changes that cause instability
        vp_smooth = (
            torchvision.transforms.functional.gaussian_blur(
                model_vp()[None], [5, 5]
            ).squeeze()
        )
        vs_smooth = (
            torchvision.transforms.functional.gaussian_blur(
                model_vs()[None], [5, 5]
            ).squeeze()
        )
        rho_smooth = (
            torchvision.transforms.functional.gaussian_blur(
                model_rho()[None], [5, 5]
            ).squeeze()
        )
        # Convert vp/vs/rho to lambda, mu, buoyancy and propagate;
        # [-2] selects the y-component receiver amplitudes
        out = elastic(
            *deepwave.common.vpvsrho_to_lambmubuoyancy(
                vp_smooth, vs_smooth, rho_smooth
            ),
            dx, dt,
            source_amplitudes_y=source_amplitudes,
            source_locations_y=source_locations,
            receiver_locations_y=receiver_locations,
            pml_freq=freq,
            pml_width=30
        )[-2]
        loss = 1e17 * loss_fn(out, observed_data)
        loss.backward()
        return loss

    # LBFGS evaluates the closure (possibly several times) inside step()
    optimiser.step(closure)
```

It may also simply be that using a weight of 1e17 causes the update step size to be too big, which will also lead to instability. A step size that is too large can cause bad updates that may, for example, result in vs becoming larger than vp at some point in your model, in which case the elastic propagator will become unstable and start producing NaNs. I hope this helps. If you continue to have difficulty then perhaps you might have more success with a different optimizer, such as Adam.
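For reference, switching to Adam might look something like the following untested sketch. It assumes the same `closure` function as above is in scope (the closure already calls `zero_grad` and `backward`), and the learning rate is only an illustrative placeholder, not a tuned value:

```python
# Untested sketch: Adam in place of LBFGS. No line search is needed,
# and lr=1e-3 is a placeholder guess, not a recommended value.
optimiser = torch.optim.Adam(list(model_vp.parameters()) +
                             list(model_vs.parameters()) +
                             list(model_rho.parameters()),
                             lr=1e-3)

for epoch in range(n_epochs):
    loss = closure()   # the closure computes the loss and its gradients
    optimiser.step()   # apply one Adam update from those gradients
```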
I should mention that if you use the `Model` class, then you may need to adjust your scaling factor of 1e17, since the model parameters will now be stored in a different range.
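One untested way to judge whether the scaling factor is still reasonable is to inspect the gradient magnitudes on the logit-space parameters after a backward pass (a hypothetical sketch; it assumes the `closure` above has just run, so gradients are populated):

```python
# Untested sketch: if the gradient magnitudes are far from O(1), the
# 1e17 loss weight likely needs rescaling for the new parameterization.
for name, m in [('vp', model_vp), ('vs', model_vs), ('rho', model_rho)]:
    g = m.model.grad
    if g is not None:
        print(f'{name}: max |grad| = {g.abs().max().item():.3e}')
```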
Thank you for your answer. I have been very busy lately, so I could not reply sooner. The issue has been resolved. Once again, thank you for your response.
That's excellent news. Thank you for letting me know. May I ask if there were any particular changes that resolved the problem (if it is not too complicated to explain)?
The previous error was caused by using a linear initial velocity model. When I replaced it with a smoothed initial velocity model, the problem was resolved.
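(For anyone who lands here with the same symptom: a minimal, untested sketch of one way to smooth an initial model, assuming `vp_init` is a hypothetical 2D `torch.Tensor`; the kernel size is illustrative.)

```python
import torchvision

# Untested sketch: blur a 2D initial velocity model before inversion.
vp_init_smooth = torchvision.transforms.functional.gaussian_blur(
    vp_init[None], [11, 11]  # add a leading channel dim for the blur
).squeeze()                  # then drop it again
```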
Ah, I see. Thank you for the explanation. I will close this issue now, but please feel free to reopen it or to create another if you have further questions.
Hello,
Sorry to bother you again. I recently encountered an elastic wave problem. Here is my code. When I set the coefficient in front of the loss function to 1e16, the model updates automatically, but the updates are very small. When I change the coefficient to 1e17, the gradient updates a few times and then an error occurs. I don't know whether the cause is my coefficient setting or something else.

```python
ny = 320
nx = 128
dx = 20
n_shots = 1
nt = 6000
dt = 0.002
n_epochs = 10
for m in range(n_epochs):
    ...  # loop body omitted
```
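(A small, untested sketch that can help diagnose this kind of failure: check the parameters and their gradients for NaNs after each backward pass. `vp`, `vs`, and `rho` are assumed here to be the tensors being optimized.)

```python
import torch

# Untested sketch: stop with a clear message as soon as NaNs appear,
# instead of letting the optimiser consume them.
for name, p in [('vp', vp), ('vs', vs), ('rho', rho)]:
    if torch.isnan(p).any():
        raise RuntimeError(f'{name} contains NaNs')
    if p.grad is not None and torch.isnan(p.grad).any():
        raise RuntimeError(f'{name}.grad contains NaNs')
```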