In [None]:
import os, sys
sys.path.insert(1, './codes')  # insert at 1, 0 is the script path (or '' in REPL)

print("----------------")
!python --version
!nvidia-smi
print("----------------")
print("System Version: ", sys.version)

## ======================================================== ##
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
from plot_functions import add_colorbar, imagesc

print("PyTorch Version: ", torch.__version__)
print("----------------")
print("torch.cuda.is_available: ",torch.cuda.is_available())
print("----------------")
print(torch.__version__, torch.version.cuda, torch.cuda.get_device_name(0))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# deterministic behavior
torch.manual_seed(3)
torch.cuda.manual_seed_all(3)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(3)
# random.seed(3)
os.environ['PYTHONHASHSEED'] = str(3)

In [None]:
# Deterministic behaviour
def seed_everything(seed):
    """Seed every random number generator used by this notebook.

    Seeds Python's `random`, NumPy's global RNG and PyTorch (CPU, the current GPU
    and all GPUs). Switches cuDNN to deterministic, non-benchmarking mode, and
    stores the seed in the PYTHONHASHSEED environment variable.
    """
    import os
    import random
    import numpy as np
    import torch

    # Same call order as before: Python, NumPy, torch CPU, current GPU, all GPUs.
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    # Disable cuDNN auto-tuning and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # NOTE(review): presumably this only influences child processes; the running
    # interpreter's hash seed is fixed at start-up — confirm it is still wanted.
    os.environ['PYTHONHASHSEED'] = str(seed)


# Seed used for the rest of the notebook.
seed = 37
seed_everything(seed)


In [None]:
import pandas as pd
from scipy.ndimage import gaussian_filter

# Load the true velocity model (m/s) once and derive the smoothed initial model from it.
# gaussian_filter returns a new array, so `vmodel` itself is left untouched.
# NOTE(review): pd.read_csv treats the first CSV row as a column header, so that row is
# not part of `vmodel`; confirm this is intended (the printed shape shows the result).
vmodel = np.array(pd.read_csv("./vel_marmousi_376x1151.csv"))
v_init = gaussian_filter(vmodel, sigma=15)

dz = 8  # grid spacing of the original model in metres
nz, nx = vmodel.shape
print("Original Model Shape: {}, Grid Interval: {}m".format(vmodel.shape, dz))

################# Plot true & initial velocity model #################
fig = plt.figure(figsize=(12, 2))
gs = fig.add_gridspec(1, 2)
for col, model_ms in enumerate((vmodel, v_init)):
    ax = fig.add_subplot(gs[0, col])
    # Only the left panel carries the depth label.
    if col == 0:
        ax.set(ylabel="Depth z[km]", xlabel="Distance x[km]")
    else:
        ax.set(xlabel="Distance x[km]")
    im = ax.imshow(model_ms/1000, extent=[0, nx*dz/1000, nz*dz/1000, 0], aspect=1, cmap='RdBu_r')
    cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='km/s')
################# Plot true & initial velocity model #################


# Grid spacing used by the simulator from here on.
# NOTE(review): decimating the 8 m grid by ::3 (z) and ::4 (x) would give 24 m / 32 m;
# confirm that 15 m is the intended modelling spacing.
dz = 15
vp_tensor_init = torch.from_numpy(vmodel).type(dtype=torch.float32).to(device)
# NOTE(review): not used below — vi_tensor is built from the true model, not from v_init.
vi_tensor_init = torch.from_numpy(v_init).type(dtype=torch.float32).to(device)
# 9 rows of 5500 m/s appended after the deepest row of the model.
# Presumably this makes the row count decimate evenly — confirm.
sea = (torch.ones([9, vp_tensor_init.shape[1]])*5500).type(dtype=torch.float32).to(device)
vp_tensor = torch.cat([vp_tensor_init, sea], dim=0)[None, ::3, 60:1084:4]
# Initial model: the resampled true model smoothed with sigma=10 grid points.
vi_tensor = torch.from_numpy(gaussian_filter(vp_tensor.cpu().numpy(), sigma=10)).type(dtype=torch.float32).to(device)

nv, nz, nx = vp_tensor.shape
print("Resampled Model Shape: {}, Grid Interval: {}m".format((nz, nx), dz))

# Setting locations of sources and receivers (grid indices)
xs = torch.arange(20, nx-10, 20, dtype=torch.long).repeat([nv, 1])      # x-coordinate for sources
ns = xs.shape[1]                                                        # number of shots
xr = torch.arange(0, nx, 1, dtype=torch.long).repeat([nv, ns, 1])       # x-coordinate for receivers
zs = torch.full((nv, ns), 0, dtype=torch.long)                          # depth of sources (surface)
zr = torch.full((nv, ns, nx), 0, dtype=torch.long)                      # depth of receivers (surface)
# .item() so the message shows plain numbers rather than tensor(...) reprs
print("Number of shots: {}, with interval: {}m, in depth: {}m".format(
    ns, (xs[0, 1]-xs[0, 0]).item()*dz, zs[0, 0].item()*dz))
print(xs)

################# Plot resampled true & initial velocity model #################
fig = plt.figure(figsize=(12, 2))
gs = fig.add_gridspec(1, 2)
for col, model_tensor in enumerate((vp_tensor, vi_tensor)):
    ax = fig.add_subplot(gs[0, col])
    # Only the left panel carries the depth label.
    if col == 0:
        ax.set(ylabel="Depth z[km]", xlabel="Distance x[km]")
    else:
        ax.set(xlabel="Distance x[km]")
    im = ax.imshow(model_tensor[0].cpu().numpy()/1000, extent=[0, nx*dz/1000, nz*dz/1000, 0], aspect=0.8, cmap='RdBu_r')
    cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='km/s')
# ################# Plot true & initial velocity model #################
# ################# Plot true & initial velocity model #################

In [None]:
torch.max(vp_tensor)  # peak velocity of the resampled true model (m/s)

In [None]:
torch.min(vp_tensor)  # lowest velocity of the resampled true model (m/s)

In [None]:
from rnn_fd import rnn2D
from generator_old import wGenerator  # single-frequency Ricker wavelet generator

freeSurface = True      # free surface option for forward modeling
npad = 15               # velocity padding in grid points
freq = 8                # dominant frequency of wavelet in Hz
dt = 0.0019             # time sampling interval, fixed for all shot gathers
nt = 1024               # number of samples in time
t = dt * torch.arange(0, nt, dtype=torch.float32)   # time vector
wavelet = wGenerator(t, freq).ricker().to(device)   # source wavelet
nx_pad = nx + 2 * npad
nz_pad = nz + npad if freeSurface else nz + 2 * npad
# Frequency axis matching torch.fft.rfft(wavelet) (length nt//2 + 1 for even or odd nt).
f = np.fft.rfftfreq(nt, d=dt)

fig = plt.figure(figsize=(10, 2.5))
gs = fig.add_gridspec(1, 2)
ax = fig.add_subplot(gs[0, 0])
ax.set(xlabel="$Time$", ylabel="$Amp$", title="$Ricker$", xlim=[0, 1])
ax.plot(t, wavelet.cpu().numpy(), color='red', linestyle='-', linewidth=1.5)
ax.grid(True, which='both', linestyle='--', color='grey', linewidth=.8, alpha=1.0)
ax.minorticks_on()

ax = fig.add_subplot(gs[0, 1])
ax.set(xlabel="$Frequency [Hz]$", ylabel="$Amp$", title="$Amp Spectrum$", xlim=[0, 40])
ax.plot(f, np.abs(torch.fft.rfft(wavelet).cpu().numpy()), color='red', linestyle='-', linewidth=1.5)
ax.grid(True, which='both', linestyle='--', color='grey', linewidth=.8, alpha=1.0)
ax.minorticks_on()

################## Check the stability condition #################
# Time-step check: vmax * dt * sqrt(2) / dz should be < 1.
print(vp_tensor.max().item()*dt/dz/np.sqrt(1/2), "< 1")
# Sampling check: vmin / (10 * freq * dz) > 1, i.e. the dominant wavelength spans > 10 cells.
print(vp_tensor.min().item()/10/freq/dz, "> 1")

forward_rnn = rnn2D(nz, nx, zs, xs, zr, xr, dz, dt,
                    npad=npad, order=2, vmax=vp_tensor.max(),
                    log_para=1e-6,
                    freeSurface=freeSurface,  # keep the operator consistent with nz_pad above
                    dtype=torch.float32,
                    device=device).to(device)

# # forward modeling
# _, _, shots, _ = forward_rnn(vmodel=vp_tensor.to(device), segment_wavelet=wavelet)
# _, _, shots_init, _ = forward_rnn(vmodel=vi_tensor.to(device), segment_wavelet=wavelet)

In [None]:
# Observed data. On a fresh kernel `shots` has not been computed yet, so load the cached
# gathers when the file exists. Otherwise, model them with the true velocity and cache them.
import os

shots_file = "./shots_marmousi_I_2048.npz"  # written with torch.save despite the .npz suffix
if os.path.exists(shots_file):
    shots = torch.load(shots_file)
else:
    _, _, shots, _ = forward_rnn(vmodel=vp_tensor.to(device), segment_wavelet=wavelet)
    shots = shots.detach()
    torch.save(shots, shots_file)

In [None]:
shots.size()  # shape of the observed shot-gather tensor

In [None]:
# Show all shot gathers side by side, symmetric colour clip at 1/30 of the peak amplitude.
fig = plt.figure(figsize=(ns*1.5, 8))
clip = shots.max() / 30
gathers = shots.cpu().numpy().reshape(-1, ns, nt, nx)
t_max = t.numpy().max()
imagesc(fig,
        gathers,
        vmin=-clip,
        vmax=clip,
        extent=[0, nx*dz/1000, t_max, 0],
        aspect=6,
        nRows_nCols=(1, ns),
        cmap='RdBu_r',  # seismic
        ylabel="Time (s)",
        xlabel="Position (km)",
        clabel="",
        xticks=np.arange(0., int(nx*dz/1000), 2),
        yticks=np.arange(0., t_max, .5),
        fontsize=8,
        cbar_width="7%",
        cbar_height="100%",
        cbar_loc='lower left')
fig.tight_layout(pad=-0.85)

In [None]:
import torch
from torch import nn


class ConvBlock(nn.Module):
    """Encoder stage: Conv2d -> MaxPool2d -> Dropout(0.1) -> BatchNorm2d -> LeakyReLU.

    The convolution always runs with stride 1. The `stride` argument is used as the
    MaxPool2d kernel size instead, so any down-sampling happens in the pooling layer
    (e.g. stride=(2, 1) halves only the first spatial axis).
    """

    def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=k_size, stride=1, padding=padding),
            nn.MaxPool2d(kernel_size=stride),
            nn.Dropout(p=0.1, inplace=False),
            # No running statistics are tracked, so batch statistics are always used.
            nn.BatchNorm2d(out_channels, track_running_stats=False),
            nn.LeakyReLU(),
        ]
        # The attribute name and layer order fix the state_dict keys (conv.0.*, conv.3.*)
        # that saved checkpoints rely on.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
    
    
class TransConvBlock(nn.Module):
    """Decoder stage: ConvTranspose2d -> Conv2d(3x3) -> Dropout(0.1) -> BatchNorm2d -> LeakyReLU.

    The transposed convolution performs the up-sampling. The following 3x3 convolution
    keeps the spatial size.
    """

    def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1):
        super().__init__()
        upsample = nn.ConvTranspose2d(in_channels, out_channels,
                                      kernel_size=k_size, stride=stride, padding=padding)
        refine = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # Layer order defines the state_dict keys (transconv.0/1/3.*) used by checkpoints.
        self.transconv = nn.Sequential(
            upsample,
            refine,
            nn.Dropout(p=0.1, inplace=False),
            nn.BatchNorm2d(out_channels, track_running_stats=False),
            nn.LeakyReLU(),
        )

    def forward(self, x):
        return self.transconv(x)
    
    
class TransConvBlock_last(nn.Module):
    """Final decoder stage: ConvTranspose2d -> BatchNorm2d, with no activation."""

    def __init__(self, in_channels, out_channels, k_size=3, stride=2, padding=1):
        super().__init__()
        deconv = nn.ConvTranspose2d(in_channels, out_channels,
                                    kernel_size=k_size, stride=stride, padding=padding)
        norm = nn.BatchNorm2d(out_channels, track_running_stats=False)
        # Keys transconv.0.* / transconv.1.* must stay stable for saved checkpoints.
        self.transconv = nn.Sequential(deconv, norm)

    def forward(self, x):
        return self.transconv(x)
        

class MyModel(nn.Module):
    """Encoder-decoder CNN mapping 12 stacked shot gathers to one velocity image.

    Input (N, 12, H, W) gives output (N, 1, H/8, W) when H is divisible by 64 and W by 16,
    e.g. (1, 12, 1024, 256) -> (1, 1, 128, 256). The output is a normalised velocity;
    the training code converts it back with v = 1000 * out + 2900.
    """

    def __init__(self):
        super().__init__()

        # Encoder. Stages 1-2 use (3, 1) kernels with (2, 1) pooling, so they shrink only
        # the first spatial axis. Stages 3-6 pool by 2 along both axes.
        # Attribute names and registration order define the state_dict / optimizer layout
        # of saved checkpoints and must not change.
        self.conv_block_1 = ConvBlock(12, 32, k_size=(3, 1), stride=(2, 1), padding=(1, 0))
        self.conv_block_2 = ConvBlock(32, 64, k_size=(3, 1), stride=(2, 1), padding=(1, 0))
        self.conv_block_3 = ConvBlock(64, 128, k_size=3, stride=2, padding=1)
        self.conv_block_4 = ConvBlock(128, 256, k_size=3, stride=2, padding=1)
        self.conv_block_5 = ConvBlock(256, 512, k_size=3, stride=2, padding=1)
        self.conv_block_6 = ConvBlock(512, 512, k_size=3, stride=2, padding=1)

        # Decoder. Stage 1 doubles only the second spatial axis, stages 2-4 double both
        # axes, and the last stage keeps the size and maps to a single channel.
        self.trans_conv_block_1 = TransConvBlock(512, 256, k_size=(1, 4), stride=(1, 2), padding=(0, 1))
        self.trans_conv_block_2 = TransConvBlock(256, 128, k_size=4, stride=2, padding=1)
        self.trans_conv_block_3 = TransConvBlock(128, 64, k_size=4, stride=2, padding=1)
        self.trans_conv_block_4 = TransConvBlock(64, 32, k_size=4, stride=2, padding=1)
        self.trans_conv_block_5 = TransConvBlock_last(32, 1, k_size=3, stride=1, padding=1)

    def encode(self, x):
        """Apply the six encoder stages in order."""
        stages = (self.conv_block_1, self.conv_block_2, self.conv_block_3,
                  self.conv_block_4, self.conv_block_5, self.conv_block_6)
        for stage in stages:
            x = stage(x)
        return x

    def decode(self, x):
        """Apply the five decoder stages in order."""
        stages = (self.trans_conv_block_1, self.trans_conv_block_2, self.trans_conv_block_3,
                  self.trans_conv_block_4, self.trans_conv_block_5)
        for stage in stages:
            x = stage(x)
        return x

    def forward(self, x):
        return self.decode(self.encode(x))
    


In [None]:
from torchinfo import summary

# Layer-by-layer summary for one sample laid out as (batch, shots, time samples, receivers).
model = MyModel()
inputs = (1, 12, 1024, 256)
summary(model, input_size=inputs)

In [None]:
vp_tensor_norm = vp_tensor.sub(2900).div(1000)  # inverse of v = 1000 * out + 2900 used in training

In [None]:
# Compare the true model (km/s) with its normalised version (v - 2900) / 1000.
fig = plt.figure(figsize=(20, 7))
gs = fig.add_gridspec(2, 2)
panels = (
    ("vp_tensor", (vp_tensor[0].cpu().detach().numpy())/1000, vp_tensor[0]/1000),
    ("vp_tensor_norm", vp_tensor_norm[0].cpu().detach().numpy(), vp_tensor_norm[0]),
)
for col, (title, image, reference) in enumerate(panels):
    ax = fig.add_subplot(gs[0, col])
    ax.set(ylabel="Depth z[km]", xlabel="Distance x[km]", title=title)
    im = ax.imshow(image, extent=[0, nx*dz/1000, nz*dz/1000, 0], aspect=1, cmap='RdBu_r')
    cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='km/s')
    # Colour range = min/max of the displayed field.
    vmin = reference.min()
    vmax = reference.max()
    im.set_clim(vmin, vmax)

In [None]:
import torchvision
# import torchvision.transforms as transforms
# torchvision.models.vgg16().features

def train(model_Jin, data, wavelet, vi_tensor, optimizer, device='cpu'):
    """Run one supervised pre-training step.

    The network maps standardised shot gathers to a normalised velocity. That output is
    mapped back to m/s (v = 1000 * out + 2900) and fitted to `vi_tensor` with an MSE loss.

    Args:
        model_Jin: network being trained.
        data: shot gathers (indexed as a 4-D tensor).
        wavelet: not used by this function.
        vi_tensor: target velocity (m/s), compared with the first batch element of the output.
        optimizer: optimiser stepping `model_Jin`'s parameters.
        device: device the normalised input and the prediction are moved to.

    Returns:
        ([loss as a Python float], predicted velocity of the first batch element, data)
    """
    # Standardise with the global mean / std of the gathers.
    data_norm = ((data[:, :, :, :] - data.mean()) / data.std().to(device)).to(device)

    optimizer.zero_grad()
    v_recon = (model_Jin(data_norm) * 1000 + 2900).to(device).type(dtype=torch.float32)
    loss = nn.MSELoss()(v_recon[0], vi_tensor)

    loss_value = loss.detach().cpu().item()
    loss.backward()
    optimizer.step()
    return [loss_value], v_recon[0], data

def save_state(epoch, model, optimizer, train_loss_history):
    """Write the pre-training checkpoint for `epoch` and delete the previous epoch's file.

    The file is named ...-{epoch + 1}.pth and stores the next epoch index, the loss
    history, and the model and optimiser state dicts. Only the newest checkpoint is kept.
    """
    import os
    ckpt_dir = "./checkpoints-Drop-marmousi-G-New-p0.1"
    template = "./checkpoints-Drop-marmousi-G-New-p0.1/checkpoint-Drop-pre-marmousi-G-2-{}.pth"  # 2  track_running_stats=False
    # torch.save does not create missing directories.
    os.makedirs(ckpt_dir, exist_ok=True)
    state = {'epoch': epoch + 1,
             'train_loss_history': train_loss_history,
             'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict()
             }
    torch.save(state, template.format(epoch + 1))
    previous = template.format(epoch)
    if os.path.exists(previous):
        os.remove(previous)
        
def load_state(model, optimizer, resume_file):
    """Restore model and optimiser state from a checkpoint written by save_state.

    Tensors are loaded onto the CPU first, so a checkpoint written on any GPU can be read.
    load_state_dict then copies the values onto the devices of `model` / `optimizer`.

    Returns:
        (stored epoch, model, optimizer, train_loss_history)
    """
    checkpoint = torch.load(resume_file, map_location="cpu")
    resume_epoch = checkpoint['epoch']
    train_loss_history = checkpoint['train_loss_history']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return resume_epoch, model, optimizer, train_loss_history


In [None]:
from torch.optim.lr_scheduler import StepLR
import os

# Pre-train the network to reproduce the smoothed initial model vi_tensor.
# Use the same device expression as the first cell: vp_tensor, vi_tensor, wavelet and
# forward_rnn were moved to that device. A different GPU here makes the loss fail with
# a cross-device error (and "cuda:1" does not exist on single-GPU machines).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data = torch.load("./shots_marmousi_I_2048.npz", map_location=device).to(device).type(dtype=torch.float32)
target = vi_tensor.to(device)
# torch.save does not create the checkpoint directory.
os.makedirs("./checkpoints-Drop-marmousi-G-New-p0.1", exist_ok=True)

model = MyModel().to(device)
optimizer = torch.optim.Adam([{'params': model.parameters()}],
                             lr=0.001, weight_decay=0.0001)

max_epoch = 1000
train_loss_history = []
print_interval = 5
for epoch in range(0, max_epoch):
    train_loss, v_pred, s_pred = train(model, data, wavelet, target, optimizer, device)
    train_loss_history.append(train_loss)
    save_state(epoch, model, optimizer, train_loss_history)

    if (epoch + 1) % print_interval == 0:
        print("Epoch: {}, Training Loss: {:.8f}".format(epoch, train_loss[0]))

train_loss_history = torch.tensor(train_loss_history)


In [None]:
# Pre-training check: prediction vs. the smoothed target vi_tensor, and their difference (km/s).
fig = plt.figure(figsize=(20, 7))
gs = fig.add_gridspec(2, 2)

# Prediction, coloured on the range of the target model.
ax = fig.add_subplot(gs[0, 0])
ax.set(ylabel="Depth z[km]", xlabel="Distance x[km]", title="v_pred_full")
im = ax.imshow((v_pred[0].cpu().detach().numpy())/1000, extent=[0, nx*dz/1000, nz*dz/1000, 0], aspect=1, cmap='RdBu_r')
cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='km/s')
target_kms = vi_tensor[0] / 1000
vmin, vmax = target_kms.min(), target_kms.max()
im.set_clim(vmin, vmax)

# Target model.
ax = fig.add_subplot(gs[1, 0])
ax.set(ylabel="Depth z[km]", xlabel="Distance x[km]", title="vi")
im = ax.imshow(((vi_tensor[0]).cpu().detach().numpy())/1000, extent=[0, nx*dz/1000, nz*dz/1000, 0], aspect=1, cmap='RdBu_r')
cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='km/s')

# Prediction minus target; colour range is +/- the largest positive difference.
ax = fig.add_subplot(gs[1, 1])
ax.set(ylabel="Depth z[km]", xlabel="Distance x[km]", title="loss")
im = ax.imshow(((v_pred[0]-vi_tensor[0]).cpu().detach().numpy())/1000, extent=[0, nx*dz/1000, nz*dz/1000, 0], aspect=1, cmap='RdBu_r')
cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='km/s')
vmax = ((v_pred[0] - vi_tensor[0]) / 1000).max()
vmin = -vmax
im.set_clim(vmin, vmax)

In [None]:
def train(model, data, wavelet, vp_tensor, optimizer, device='cpu'):
    """Run one FWI step: network -> velocity -> forward modelling -> data misfit.

    Uses the notebook globals `forward_rnn` (forward modelling operator) and `model_vgg`
    (feature extractor for the perceptual loss).

    Returns:
        [total_loss, pixel_loss, perceptual_loss, velocity_mse] as Python floats,
        the predicted velocity `v_recon[0]`, the simulated shots, and [data.min(), data.max()].
        velocity_mse compares the prediction with `vp_tensor` and is for monitoring only.
    """
    lamda1 = lamda2 = lamda3 = lamda4 = 1  # weights of the four loss terms
    model.train()
    data_resample = data[:, :, :, :]
    data_norm = (data_resample - data.mean()) / (data.std()).to(device)

    optimizer.zero_grad()

    # The network predicts a normalised velocity; map it back to m/s.
    v_recon_norm = model(data_norm)
    v_recon = v_recon_norm * 1000 + 2900
    _, _, shots, _ = forward_rnn(vmodel=v_recon[0].to(device), segment_wavelet=wavelet)
    shots_resample = shots[:, :, :, :]

    # Perceptual loss: gathers of the first batch element, each repeated into 3 channels,
    # both normalised with the observed-data statistics.
    data_rgb = data_resample[0][:, None].repeat(1, 3, 1, 1)
    shots_rgb = shots_resample[0][:, None].repeat(1, 3, 1, 1)
    rgb_mean, rgb_std = data_rgb.mean(), data_rgb.std()
    data_vgg = model_vgg((data_rgb - rgb_mean) / rgb_std)
    shots_vgg = model_vgg((shots_rgb - rgb_mean) / rgb_std)

    loss_fn1 = nn.L1Loss()
    loss_fn2 = nn.MSELoss()
    L_recon_pixel = (lamda1 * loss_fn1(data_resample, shots_resample)
                     + lamda2 * loss_fn2(data_resample, shots_resample))
    L_recon_perceptual = (lamda3 * loss_fn1(data_vgg, shots_vgg)
                          + lamda4 * loss_fn2(data_vgg, shots_vgg))
    loss = L_recon_pixel + L_recon_perceptual

    train_loss = loss.detach().cpu().item()
    loss.backward()
    optimizer.step()

    # Monitoring metric only: plain MSE against the true model. It is computed once on a
    # detached prediction and returned as a float, so the history keeps no autograd graphs.
    with torch.no_grad():
        loss_val = loss_fn2(v_recon[0].detach(), vp_tensor).item()

    return ([train_loss, L_recon_pixel.item(), L_recon_perceptual.item(), loss_val],
            v_recon[0], shots, [data.min(), data.max()])
    
    
def test(model, data, device='cpu'):
    model.train()
    data_resample = data[:, :, :, :] 
    data_norm = (data_resample-data.mean())/(data.std()).to(device)

    v_recon_norm = model(data_norm) 
    v_recon = (v_recon_norm)*(1000) + 2900  
    return v_recon[0]
    
def save_state(epoch, model, optimizer, train_loss_history):
    """Write an FWI checkpoint named ...-{epoch + 1}.pth. Earlier checkpoints are kept.

    The file stores the next epoch index, the loss history, and the model and optimiser
    state dicts.
    """
    import os
    ckpt_dir = "./checkpoints-Drop-marmousi-G-New-p0.1"
    # torch.save does not create missing directories.
    os.makedirs(ckpt_dir, exist_ok=True)
    state = {'epoch': epoch + 1,
             'train_loss_history': train_loss_history,
             'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict()
             }
    torch.save(state, "./checkpoints-Drop-marmousi-G-New-p0.1/checkpoint-Drop-marmousi-G-2-{}.pth".format(epoch+1))
        
def load_state(model, optimizer, resume_file):
    """Restore model and optimiser state from a checkpoint written by save_state.

    Tensors are loaded onto the CPU first, so a checkpoint written on any GPU can be read.
    load_state_dict then copies the values onto the devices of `model` / `optimizer`.

    Returns:
        (stored epoch, model, optimizer, train_loss_history)
    """
    checkpoint = torch.load(resume_file, map_location="cpu")
    resume_epoch = checkpoint['epoch']
    train_loss_history = checkpoint['train_loss_history']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return resume_epoch, model, optimizer, train_loss_history


In [None]:
from torch.optim.lr_scheduler import StepLR
import torchvision
import os

# FWI training: start from the pre-trained network and minimise the data misfit.
# Use the same device expression as the first cell, so the network, the observed data
# and the forward operator `forward_rnn` (built on that device) all live together.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data = torch.load("./shots_marmousi_I_2048.npz", map_location=device).to(device).type(dtype=torch.float32)
# torch.save does not create the checkpoint directory.
os.makedirs("./checkpoints-Drop-marmousi-G-New-p0.1", exist_ok=True)

model = MyModel().to(device)
optimizer = torch.optim.Adam([{'params': model.parameters()}],
                             lr=0.001, weight_decay=0.0001)
_, model, _, _ = load_state(model, optimizer, "./checkpoints-Drop-marmousi-G-New-p0.1/checkpoint-Drop-pre-marmousi-G-2-1000.pth")
optimizer = torch.optim.AdamW([{'params': model.parameters()}],
                              lr=0.0001, betas=(0.9, 0.999), weight_decay=0.0001)

# Perceptual-loss features: the first 13 modules of VGG16.features, i.e. the output of
# its fifth convolution. VGG is never optimised, so freeze it. Gradients still flow
# through it to the simulated shots, but none are accumulated for its weights.
model_vgg = torchvision.models.vgg16(pretrained=True).features[:13].to(device)
model_vgg.requires_grad_(False)
model_vgg.eval()

max_epoch = 3000
train_loss_history = [[], [], [], []]  # total, pixel, perceptual, velocity-model MSE
print_interval = 1
# save_interval = 100
for epoch in range(0, max_epoch):

    train_loss, v_pred, s_pred, norm_list = train(model, data, wavelet, vp_tensor, optimizer, device)

    # Store plain floats so the history holds no tensors or autograd graphs.
    for i in range(4):
        train_loss_history[i].append(float(train_loss[i]))

    # Checkpoint at epoch 1, every 10 epochs up to 100, then every 100 epochs.
    if epoch == 0 or (epoch < 100 and (epoch + 1) % 10 == 0) or (epoch >= 100 and (epoch + 1) % 100 == 0):
        save_state(epoch, model, optimizer, train_loss_history)

    if (epoch + 1) % print_interval == 0:
        print("Epoch: {}, Training Loss: {:.8f}, recon_pixel Loss: {:.8f}, recon_perceptual Loss: {:.8f}, recon_vmodel Loss: {:.8f}".format(epoch, train_loss[0], train_loss[1], train_loss[2], train_loss[3]))

train_loss_history = torch.tensor(train_loss_history)

In [None]:
def train(model, data, device='cpu'):
    """Run a forward pass in training mode (dropout active) and return the velocity in m/s.

    No loss is computed and no optimiser step is taken. Returns the first batch element
    of 1000 * model(standardised data) + 2900.
    """
    model.train()
    mu, sigma = data.mean(), data.std()
    data_norm = (data[:, :, :, :] - mu) / sigma.to(device)
    return (model(data_norm) * 1000 + 2900)[0]

def test(model, data, device='cpu'):
    model.eval()
    data_resample = data[:, :, :, :] 
    data_norm = (data_resample-data.mean())/(data.std()).to(device)
    
    v_recon_norm = model(data_norm) 
    v_recon = (v_recon_norm)*(1000) + 2900  
   
    return v_recon[0]

def load_state(model, optimizer, resume_file):
    """Restore model and optimiser state from a checkpoint written by save_state.

    Tensors are loaded onto the CPU first, so a checkpoint written on any GPU can be read.
    load_state_dict then copies the values onto the devices of `model` / `optimizer`.

    Returns:
        (stored epoch, model, optimizer, train_loss_history)
    """
    checkpoint = torch.load(resume_file, map_location="cpu")
    resume_epoch = checkpoint['epoch']
    train_loss_history = checkpoint['train_loss_history']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return resume_epoch, model, optimizer, train_loss_history


In [None]:
# Show v_pred during the inversion: evaluate the checkpoint saved at epoch 1 and every
# 100 epochs, save each prediction, and plot it.
import os

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data = torch.load("./shots_marmousi_I_2048.npz", map_location=device).to(device).type(dtype=torch.float32)
model = MyModel().to(device)
optimizer = torch.optim.AdamW([{'params': model.parameters()}],
                              lr=0.00001, betas=(0.9, 0.999), weight_decay=0.0001)
os.makedirs('./Figures/Dropout-G-inversion-results/p0.1/v_pred_test', exist_ok=True)

max_epoch = 3000
for epoch in range(0, max_epoch):
    if epoch == 0 or (epoch + 1) % 100 == 0:
        _, model, _, _ = load_state(model, optimizer, "./checkpoints-Drop-marmousi-G-New-p0.1/checkpoint-Drop-marmousi-G-2-{}.pth".format(epoch+1))
        model = model.to(device)
        # Inference only: no autograd graph needed.
        with torch.no_grad():
            v_pred = test(model, data, device)

        # Save a CPU copy so the file loads on machines without this GPU.
        torch.save(v_pred.cpu(), './Figures/Dropout-G-inversion-results/p0.1/v_pred_test/v_pred_{}.pth'.format(epoch+1))

        cmap = 'RdBu_r'
        v_img = v_pred.cpu().numpy().squeeze()
        fig = plt.figure(figsize=(26, 10))
        gs = fig.add_gridspec(3, 4)
        ax = fig.add_subplot(gs[0, 0])
        ax.set_ylabel('Depth (km)', fontsize=16)
        ax.set_title('(a) {}'.format(epoch+1), loc='left', fontsize=20)
        im = ax.imshow(v_img, extent=[0, nx*dz/1000, nz*dz/1000, 0], aspect=1, cmap=cmap,
                       vmin=v_img.min(), vmax=v_img.max())
        ax.set_xticks([])
        cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='$m/s$')
        # Render now and release the figure, instead of keeping ~31 large figures open.
        plt.show()

In [None]:
# Summary of v_pred during the inversion (eval-mode predictions saved by the loop above).
epochs = [1, 1000, 1500, 2000, 3000]
pred_dir = './Figures/Dropout-G-inversion-results/p0.1/v_pred_test'
# Kept under its own name so the global `v_pred` (latest prediction) is not overwritten.
v_pred_stack = torch.stack([
    torch.load('{}/v_pred_{}.pth'.format(pred_dir, ep), map_location='cpu').detach().squeeze(0)
    for ep in epochs
])

vmin_kms = vp_tensor.min().item() / 1000  # shared colour range: the true model's
vmax_kms = vp_tensor.max().item() / 1000

cmap = 'RdBu_r'
fig = plt.figure(figsize=(30, 3))
gs = fig.add_gridspec(1, len(epochs))
for i in range(len(epochs)):
    ax = fig.add_subplot(gs[0, i])
    ax.set_xlabel('Distance [km]', fontsize=16)
    ax.set_title('({})'.format(chr(ord('a') + i)), loc='center', fontsize=20)
    im = ax.imshow(v_pred_stack[i].numpy()/1000, extent=[0, nx*dz/1000, nz*dz/1000, 0], aspect=1, cmap=cmap,
                   vmin=vmin_kms, vmax=vmax_kms)
    ax.tick_params(axis='x', labelsize=12)
    if i == 0:
        # Only the first panel carries the depth axis.
        ax.set_ylabel('Depth [km]', fontsize=16)
        ax.tick_params(axis='y', labelsize=12)
    else:
        ax.set_yticks([])

# A single colour bar, attached to the last panel.
cbar = add_colorbar(ax, im, ax.transAxes, width="3%")
cbar.ax.tick_params(labelsize=12)
cbar.set_label('Km/s', size=14)

plt.subplots_adjust(left=0.125,
                    bottom=0.1,
                    right=0.9,
                    top=0.9,
                    wspace=0.05,
                    hspace=0.4)

plt.savefig('./Figures/Dropout-0.1-test-G-inversion-results.png', dpi=300, bbox_inches='tight')


In [None]:
# Summary of v_pred during the inversion (train-mode predictions).
# NOTE(review): no cell in this notebook writes v_pred_train/*.pth; these files must
# come from an earlier run — confirm they exist and match the current checkpoints.
epochs = [1, 1000, 1500, 2000, 3000]
pred_dir = './Figures/Dropout-G-inversion-results/p0.1/v_pred_train'
# Kept under its own name so the global `v_pred` (latest prediction) is not overwritten.
v_pred_train_stack = torch.stack([
    torch.load('{}/v_pred_{}.pth'.format(pred_dir, ep), map_location='cpu').detach().squeeze(0)
    for ep in epochs
])

vmin_kms = vp_tensor.min().item() / 1000  # shared colour range: the true model's
vmax_kms = vp_tensor.max().item() / 1000

cmap = 'RdBu_r'
fig = plt.figure(figsize=(30, 3))
gs = fig.add_gridspec(1, len(epochs))
for i in range(len(epochs)):
    ax = fig.add_subplot(gs[0, i])
    ax.set_xlabel('Distance [km]', fontsize=16)
    ax.set_title('({})'.format(chr(ord('a') + i)), loc='center', fontsize=20)
    im = ax.imshow(v_pred_train_stack[i].numpy()/1000, extent=[0, nx*dz/1000, nz*dz/1000, 0], aspect=1, cmap=cmap,
                   vmin=vmin_kms, vmax=vmax_kms)
    ax.tick_params(axis='x', labelsize=12)
    if i == 0:
        # Only the first panel carries the depth axis.
        ax.set_ylabel('Depth [km]', fontsize=16)
        ax.tick_params(axis='y', labelsize=12)
    else:
        ax.set_yticks([])

# A single colour bar, attached to the last panel.
cbar = add_colorbar(ax, im, ax.transAxes, width="3%")
cbar.ax.tick_params(labelsize=12)
cbar.set_label('Km/s', size=14)

plt.subplots_adjust(left=0.125,
                    bottom=0.1,
                    right=0.9,
                    top=0.9,
                    wspace=0.05,
                    hspace=0.4)

plt.savefig('./Figures/Dropout-0.1-train-G-inversion-results.png', dpi=300, bbox_inches='tight')


In [None]:
# Reload the final FWI checkpoint (epoch 3000) to recover the loss history of the run.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data = torch.load("./shots_marmousi_I_2048.npz").to(device).type(dtype=torch.float32)
model = MyModel().to(device)
optimizer = torch.optim.AdamW(
    [{'params': model.parameters()}],
    lr=0.00001,
    betas=(0.9, 0.999),
    weight_decay=0.0001,
)

_, model, _, train_loss_history = load_state(
    model,
    optimizer,
    "./checkpoints-Drop-marmousi-G-New-p0.1/checkpoint-Drop-marmousi-G-2-{}.pth".format(3000),
)
# Nested per-term lists of per-epoch losses -> tensor indexed [term, epoch].
train_loss_history = torch.tensor(train_loss_history)

In [None]:
# Loss curves of the FWI run. Cap at the recorded length so a shorter history still plots.
n = min(3000, train_loss_history.shape[1])
epoch_axis = np.arange(0, n)
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(epoch_axis, train_loss_history[0, 0:n].numpy(), label="all")
# The pixel term is scaled by 1000 for visibility; say so in the legend.
ax.plot(epoch_axis, train_loss_history[1, 0:n].numpy()*1000, label="recon_pixel (x1000)")
ax.plot(epoch_axis, train_loss_history[2, 0:n].numpy(), label="recon_percep")
ax.set(xlabel="Epoch", ylabel="Loss")
ax.legend();


In [None]:
torch.save(obj=train_loss_history, f='./Figures/Dropout_0.1_train_loss_history.pth')  # persist the [term, epoch] loss tensor

In [None]:
# dropout = 0.1: predicted (v_pred), true (vp) and initial (vi) models plus the
# prediction errors, all in km/s.
fig = plt.figure(figsize=(20, 14))
gs = fig.add_gridspec(3, 2)
extent = [0, nx*dz/1000, nz*dz/1000, 0]

vel_clim = ((vp_tensor[0]/1000).min(), (vp_tensor[0]/1000).max())
err_vp = (v_pred[0] - vp_tensor[0]).detach().cpu().numpy() / 1000
err_vi = (v_pred[0] - vi_tensor[0]).detach().cpu().numpy() / 1000
# Error panels use +/- max|error|; +/- max(error) clipped negative errors whose
# magnitude exceeded the largest positive one.
lim_vp = np.abs(err_vp).max()
lim_vi = np.abs(err_vi).max()

panels = [
    (gs[0, 0], "v_pred", v_pred[0].detach().cpu().numpy() / 1000, vel_clim),
    (gs[1, 1], "vp", vp_tensor[0].detach().cpu().numpy() / 1000, vel_clim),
    (gs[1, 0], "vi", vi_tensor[0].detach().cpu().numpy() / 1000, vel_clim),
    (gs[2, 1], "v_pred-vp", err_vp, (-lim_vp, lim_vp)),
    (gs[2, 0], "v_pred-vi", err_vi, (-lim_vi, lim_vi)),
]
for spec, title, image, (lo, hi) in panels:
    ax = fig.add_subplot(spec)
    ax.set(ylabel="Depth z[km]", xlabel="Distance x[km]", title=title)
    im = ax.imshow(image, extent=extent, aspect=1, cmap='RdBu_r')
    cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='km/s')
    im.set_clim(lo, hi)



In [None]:
# First 2000 epochs of the loss curves; n is clamped to the recorded length.
n = min(2000, train_loss_history.shape[1])
epochs = np.arange(n)
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(epochs, train_loss_history[0, :n].numpy(), label="all")
ax.plot(epochs, train_loss_history[1, :n].numpy()*1000, label="recon_pixel (x1000)")
ax.plot(epochs, train_loss_history[2, :n].numpy(), label="recon_percep")
ax.set(xlabel="Epoch", ylabel="Loss")
ax.legend();


In [None]:
# Velocity-model misfit (row 3 of train_loss_history, "recon_vp"); n clamped to the recorded length.
n = min(2000, train_loss_history.shape[1])
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(np.arange(n), train_loss_history[3, :n].numpy(), label="recon_vp")
ax.set(xlabel="Epoch", ylabel="Loss")
ax.legend();


In [None]:
# One trace (shot index 6, trace index 10): predicted vs observed and their difference.
shot, trace = 6, 10
nt = data.shape[2]  # number of time samples, previously hard-coded as 1024
t = np.arange(nt)
observed = data[0, shot, :, trace].detach().cpu().numpy()
predicted = s_pred[0, shot, :, trace].detach().cpu().numpy()

fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(t, predicted, label="s_pred")
ax.plot(t, observed, label="data")
ax.plot(t, observed - predicted, label="loss")
ax.legend();


In [None]:
# Shot gather index 6: predicted, observed and residual.
# Image rows are time samples and columns are traces (data is loaded as [1, 12, 1024, 256]),
# so depth/distance axis labels and km/s colour-bar titles do not apply here.
shot = 6
observed = data[0, shot].detach().cpu().numpy()
predicted = s_pred[0, shot].detach().cpu().numpy()
residual = observed - predicted

# Clip at 1/40 of the observed extremes so weak events remain visible.
amp_clim = (observed.min()/40, observed.max()/40)
# Symmetric residual range from the largest |residual| (max alone clipped negative values).
res_lim = np.abs(residual).max()

fig = plt.figure(figsize=(24, 10))
gs = fig.add_gridspec(1, 3)
panels = [
    (gs[0, 0], "s_pred", predicted, amp_clim),
    (gs[0, 1], "data", observed, amp_clim),
    (gs[0, 2], "loss", residual, (-res_lim, res_lim)),
]
for spec, title, image, (lo, hi) in panels:
    ax = fig.add_subplot(spec)
    ax.set(ylabel="Time sample", xlabel="Trace index", title=title)
    im = ax.imshow(image, aspect=0.25, cmap='RdBu_r')
    cbar = add_colorbar(ax, im, ax.transAxes, width="3%")
    im.set_clim(lo, hi)

## 添加噪音 (Adding random noise to the shot records)

In [None]:
import torch

def add_noise_SNR(data, noise_ratio):
    """
    Add white Gaussian noise to a signal at a prescribed signal-to-noise ratio.

    The noise is rescaled so its mean power equals ``mean(data**2) / noise_ratio``.
    ``noise_ratio`` is therefore a *linear* power ratio, not decibels
    (10 -> 10 dB, 100 -> 20 dB).

    :param data: Input data (e.g. a seismic recording), a PyTorch tensor of any shape.
    :param noise_ratio: Linear SNR; must be strictly positive.
    :return: New tensor ``data + noise`` with the same shape, device and dtype as ``data``.
    :raises ValueError: if ``noise_ratio`` is not > 0 (0 would mean infinite noise).
    """
    if noise_ratio <= 0:
        raise ValueError("noise_ratio must be > 0, got {}".format(noise_ratio))

    # randn_like follows data's device and dtype; torch.randn(size) always used the
    # CPU default dtype and failed for GPU tensors.
    noise = torch.randn_like(data)

    signal_power = data.pow(2).mean()
    noise_power = noise.pow(2).mean()

    # Scale so that scaled-noise power == signal_power / noise_ratio.
    noise_scaling = ((signal_power / noise_power) / noise_ratio).sqrt()
    return data + noise_scaling * noise

# Clean shot gathers, shape [1, 12, 1024, 256]; moved to the CPU for the noise helper.
seismic_record = torch.load("./shots_marmousi_I_2048.npz").cpu()

# Target signal-to-noise ratio as a linear power ratio (10 -> 10 dB).
SNR = 10

seismic_record_noised = add_noise_SNR(seismic_record, SNR)
# Build the file tag from SNR so it cannot drift from the value actually used (SNR=10 -> "_SNR10").
torch.save(seismic_record_noised, "./shots_marmousi_I_2048_SNR{}.npz".format(SNR))


In [None]:
# Shape check: add_noise_SNR returns data + noise, so this should match seismic_record.
seismic_record_noised.shape

In [None]:
# First clean shot gather, amplitudes clipped at +/- max/40 to reveal weak events.
plt.imshow(seismic_record[0,0], vmin=-seismic_record.max()/40, vmax=seismic_record.max()/40, aspect=0.3, cmap='RdBu_r')
plt.colorbar();  # must be called; a bare `plt.colorbar` only references the function

In [None]:
# Same gather after add_noise_SNR, on the clean record's colour scale for comparison.
plt.imshow(seismic_record_noised[0,0], vmin=-seismic_record.max()/40, vmax=seismic_record.max()/40, aspect=0.3, cmap='RdBu_r')
plt.colorbar();  # must be called; a bare `plt.colorbar` only references the function

In [None]:
import torch

def add_noise_std(data, ratio):
    """
    Add zero-mean white Gaussian noise whose standard deviation is ``ratio``
    times the (global) standard deviation of ``data``.

    :param data: PyTorch tensor holding the signal (any shape).
    :param ratio: Noise std relative to the signal std (0 returns an unchanged copy).
    :return: New tensor ``data + noise`` with the same shape, device and dtype as ``data``.
    """
    signal_std = torch.std(data)
    # randn_like follows data's device/dtype; torch.randn(size) was always a CPU
    # float tensor and broke for GPU inputs.
    noise = torch.randn_like(data) * signal_std * ratio
    return data + noise

# Clean shot gathers of the true model and of the "vi10" record, both [1, 12, 1024, 256].
seismic_record = torch.load("./shots_marmousi_I_2048.npz").cpu()
seismic_vi_record = torch.load("./shots_vi10_marmousi_I_2048.npz").cpu()

# Noise std as a multiple of each record's own std.
ratio = 1

seismic_record_noised = add_noise_std(seismic_record, ratio)
seismic_vi_record_noised = add_noise_std(seismic_vi_record, ratio)

# Tag the files with the ratio actually used (ratio=1 -> "_STD1").
torch.save(seismic_record_noised, "./shots_marmousi_I_2048_STD{}.npz".format(ratio))
torch.save(seismic_vi_record_noised, "./shots_vi10_marmousi_I_2048_STD{}.npz".format(ratio))


In [None]:
# First clean shot gather, clipped symmetrically at 1/40 of the record maximum.
fig, ax = plt.subplots()
clip = seismic_record.max() / 40
im = ax.imshow(seismic_record[0, 0], vmin=-clip, vmax=clip, aspect=0.3, cmap='RdBu_r')
fig.colorbar(im, ax=ax)

In [None]:
# Gather after add_noise_std, on the clean record's colour scale.
plt.imshow(seismic_record_noised[0,0], vmin=-seismic_record.max()/40, vmax=seismic_record.max()/40, aspect=0.3, cmap='RdBu_r')
plt.colorbar();  # must be called; a bare `plt.colorbar` only references the function

In [None]:
# Noisy "vi10" gather, deliberately on the clean true-model record's colour scale.
plt.imshow(seismic_vi_record_noised[0,0], vmin=-seismic_record.max()/40, vmax=seismic_record.max()/40, aspect=0.3, cmap='RdBu_r')
plt.colorbar();  # must be called; a bare `plt.colorbar` only references the function

## 缺失低频的反演 (Inversion of data without low frequencies)

In [None]:
import torch
from torch import nn


class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU -> Dropout, stored in ``self.conv``.

    The layer order and the ``conv`` attribute name fix the state_dict keys
    (``conv.0.*`` convolution, ``conv.1.*`` batch-norm), so existing
    checkpoints keep loading.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param k_size: kernel size forwarded to ``nn.Conv2d`` (int or tuple).
    :param stride: stride forwarded to ``nn.Conv2d``.
    :param padding: padding forwarded to ``nn.Conv2d``.
    :param dropout: dropout probability; defaults to the previously hard-coded 0.1.
    """

    def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1, dropout=0.1):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=k_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
            nn.Dropout(p=dropout, inplace=False),
        )

    def forward(self, x):
        return self.conv(x)
    
    
class TransConvBlock(nn.Module):
    """ConvTranspose2d -> BatchNorm2d -> LeakyReLU -> Dropout, stored in ``self.transconv``.

    The layer order and attribute name fix the state_dict keys
    (``transconv.0.*``, ``transconv.1.*``) used by saved checkpoints.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param k_size: kernel size forwarded to ``nn.ConvTranspose2d`` (int or tuple).
    :param stride: stride forwarded to ``nn.ConvTranspose2d``.
    :param padding: padding forwarded to ``nn.ConvTranspose2d``.
    :param dropout: dropout probability; defaults to the previously hard-coded 0.1.
    """

    def __init__(self, in_channels, out_channels, k_size=3, stride=2, padding=1, dropout=0.1):
        super().__init__()
        self.transconv = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=k_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
            nn.Dropout(p=dropout, inplace=False),
        )

    def forward(self, x):
        return self.transconv(x)
    
    
class TransConvBlock_last(nn.Module):
    """Final decoder stage: ConvTranspose2d followed by BatchNorm2d only.

    Unlike the inner decoder blocks there is no activation and no dropout.
    Sub-modules live in ``self.transconv`` (keys ``transconv.0.*`` / ``transconv.1.*``).
    """

    def __init__(self, in_channels, out_channels, k_size=3, stride=2, padding=1):
        super().__init__()
        upsample = nn.ConvTranspose2d(in_channels, out_channels,
                                      kernel_size=k_size, stride=stride, padding=padding)
        norm = nn.BatchNorm2d(out_channels)
        self.transconv = nn.Sequential(upsample, norm)

    def forward(self, x):
        return self.transconv(x)
        

class MyModel(nn.Module):
    """Encoder-decoder CNN mapping stacked shot gathers to a single-channel velocity image.

    Shape comments below are derived from each block's kernel/stride/padding for the
    input used in this notebook, [batch, 12, 1024, 256]; the output is then
    [batch, 1, 128, 256]. Every ConvBlock/TransConvBlock contains Dropout(p=0.1).

    NOTE: attribute names and their creation order define the state_dict keys and the
    parameter order stored in optimizer checkpoints - do not rename or reorder them.
    """
    def __init__(self):
        super(MyModel, self).__init__()    
        
        # encoder -------------------------------------------- input: [batch,12,1024,256]
        self.conv_block_1 = ConvBlock(12, 32, k_size=(4,1), stride=(2,1), padding=(1,0))       # -> [batch,32,512,256] (first axis halved only)
        
        self.conv_block_2 = ConvBlock(32, 64, k_size=(4,1), stride=(2,1), padding=(1,0))       # -> [batch,64,256,256]
        
        self.conv_block_3 = ConvBlock(64, 128, k_size=4, stride=2, padding=1)                  # -> [batch,128,128,128]
        
        self.conv_block_4 = ConvBlock(128, 256, k_size=4, stride=2, padding=1)                 # -> [batch,256,64,64]
        
        self.conv_block_5 = ConvBlock(256, 512, k_size=4, stride=2, padding=1)                 # -> [batch,512,32,32]
        
        self.conv_block_6 = ConvBlock(512, 512, k_size=4, stride=2, padding=1)                 # -> [batch,512,16,16]
        
        
        # decoder ("article 3" variant) --------------------------------------------
        self.trans_conv_block_1 = TransConvBlock(512, 256, k_size=(1,4), stride=(1,2), padding=(0,1))  # -> [batch,256,16,32] (second axis doubled only)
        
        self.trans_conv_block_2 = TransConvBlock(256, 128, k_size=4, stride=2, padding=1)     # -> [batch,128,32,64]
        
        self.trans_conv_block_3 = TransConvBlock(128, 64, k_size=4, stride=2, padding=1)      # -> [batch,64,64,128]
        
        self.trans_conv_block_4 = TransConvBlock(64, 32, k_size=4, stride=2, padding=1)       # -> [batch,32,128,256]
        
        self.trans_conv_block_5 = TransConvBlock_last(32, 1, k_size=3, stride=1, padding=1)   # -> [batch,1,128,256]
        
        # Disabled output activations (kept for reference):
        # self.tanh = nn.Tanh()
        # self.relu = nn.ReLU()

       

    def encode(self, x):
        """Run the six encoder blocks; [batch,12,1024,256] -> [batch,512,16,16]."""
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        x = self.conv_block_3(x)
        x = self.conv_block_4(x)
        x = self.conv_block_5(x)    
        out = self.conv_block_6(x)     # [batch,512,16,16] for the notebook's input size
        
        return out
        
        
        
    def decode(self, x):
        """Run the five decoder blocks; [batch,512,16,16] -> [batch,1,128,256].

        The last block ends with BatchNorm and no activation, so the output is unbounded.
        """
        x = self.trans_conv_block_1(x)
        x = self.trans_conv_block_2(x)
        x = self.trans_conv_block_3(x)
        x = self.trans_conv_block_4(x)
        x = self.trans_conv_block_5(x)
        
        # x = self.tanh(x)
        
        return x
    
    
    def forward(self, x):
        """Encode then decode; returns the raw (normalised) velocity image."""
        
        x = self.encode(x)
        v_pred = self.decode(x)  
        # Disabled output transforms (kept for reference):
        # v_pred = (torch.exp(v_pred)-0.8)/2
        # v_pred_clamped = torch.clamp(v_pred, 0, 1)  # restrict output values to [0, 1]
        # v_pred_clamped = torch.sqrt(v_pred_clamped)
        return v_pred
    


In [None]:
import torchvision
# import torchvision.transforms as transforms
# torchvision.models.vgg16().features

def train(model_Jin, data, wavelet, vi_tensor, optimizer, device='cpu'):
    """One supervised pre-training step: fit the network output to ``vi_tensor``.

    The shots are standardised with their global mean/std and passed through
    ``model_Jin``. The output is mapped to velocity as ``out * 1000 + 2900``, and
    its first batch item is compared with ``vi_tensor`` by mean-squared error.

    :param model_Jin: network being trained.
    :param data: shot gathers used (whole) as the network input.
    :param wavelet: unused in this step; kept for signature parity with the FWI ``train``.
    :param vi_tensor: target velocity model.
    :param optimizer: optimizer over ``model_Jin``'s parameters.
    :param device: device the normalised input is moved to.
    :return: ``([loss], v_recon[0], data)`` with ``loss`` as a Python float.
    """
    scale = data.std().to(device)
    net_input = ((data - data.mean()) / scale).to(device)

    optimizer.zero_grad()
    velocity = (model_Jin(net_input) * 1000 + 2900).to(device).type(dtype=torch.float32)

    mse = nn.MSELoss()(velocity[0], vi_tensor)
    loss_value = 0 + mse.detach().cpu().item()
    mse.backward()
    optimizer.step()

    return [loss_value], velocity[0], data

def save_state(epoch, model, optimizer, train_loss_history):
    """Write a rolling pre-training checkpoint for ``epoch`` (0-based).

    Saves ``{'epoch': epoch + 1, 'train_loss_history', 'state_dict', 'optimizer'}`` to
    ``./checkpoints-Drop-marmousi-woLF/checkpoint-Drop-pre-marmousi-without-low-fre-G-1-{epoch+1}.pth``.
    It then deletes the file for the previous epoch, so only the latest checkpoint remains.
    The checkpoint directory is created if it does not exist (torch.save would fail otherwise).
    """
    import os
    ckpt_dir = "./checkpoints-Drop-marmousi-woLF"
    template = ckpt_dir + "/checkpoint-Drop-pre-marmousi-without-low-fre-G-1-{}.pth"

    state = {'epoch': epoch + 1,
             'train_loss_history': train_loss_history,
             'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict()
             }
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(state, template.format(epoch + 1))

    previous = template.format(epoch)
    if os.path.exists(previous):
        os.remove(previous)
        
def load_state(model, optimizer, resume_file, map_location="cpu"):
    """Restore model and optimizer from a checkpoint written by ``save_state``.

    :param model: module receiving ``checkpoint['state_dict']``.
    :param optimizer: optimizer receiving ``checkpoint['optimizer']``.
    :param resume_file: path of the ``.pth`` checkpoint.
    :param map_location: forwarded to ``torch.load``. Defaults to ``"cpu"`` so a checkpoint
        written on a specific GPU (e.g. ``cuda:1``) also loads where that device is absent.
        ``load_state_dict`` then copies the tensors onto the model's and optimizer's
        parameter devices.
    :return: ``(resume_epoch, model, optimizer, train_loss_history)``.
    """
    checkpoint = torch.load(resume_file, map_location=map_location)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['epoch'], model, optimizer, checkpoint['train_loss_history']


In [None]:
from torch.optim.lr_scheduler import StepLR
# Pre-training: teach the network to output vi_tensor from the shots without low frequencies.
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
# Alternative input: "./shots_vi10_marmousi_I_2048.npz"
data = torch.load("./shots_wo_low6_marmousi_I_1024.npz").to(device, dtype=torch.float32)  # [1,12,1024,256]

model = MyModel().to(device)
optimizer = torch.optim.Adam(
    [{'params': model.parameters()}],
    lr=0.001, weight_decay=0.0001,
)

max_epoch, print_interval = 1000, 5
train_loss_history = []
for epoch in range(max_epoch):
    train_loss, v_pred, s_pred = train(model, data, wavelet, vi_tensor, optimizer, device)
    train_loss_history.append(train_loss)
    save_state(epoch, model, optimizer, train_loss_history)  # checkpoint after every epoch
    if (epoch + 1) % print_interval == 0:  # report every print_interval epochs
        print("Epoch: {}, Training Loss: {:.8f}".format(epoch, train_loss[0]))

train_loss_history = torch.tensor(train_loss_history)


#只有第一项loss是准确的，参与训练。 (Only the first loss entry is accurate and takes part in training.)

In [None]:
# Pre-training result: network output vs target vi and their difference (km/s).
fig = plt.figure(figsize=(20, 7))
gs = fig.add_gridspec(2, 2)
extent = [0, nx*dz/1000, nz*dz/1000, 0]

vi_km = vi_tensor[0].detach().cpu().numpy() / 1000
pred_km = v_pred[0].detach().cpu().numpy() / 1000
diff_km = (v_pred[0] - vi_tensor[0]).detach().cpu().numpy() / 1000

ax = fig.add_subplot(gs[0, 0])
ax.set(ylabel="Depth z[km]", xlabel="Distance x[km]", title="v_pred_full")
im = ax.imshow(pred_km, extent=extent, aspect=1, cmap='RdBu_r')
cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='km/s')
im.set_clim(vi_km.min(), vi_km.max())  # same colour range as the target model

ax = fig.add_subplot(gs[1, 0])
ax.set(ylabel="Depth z[km]", xlabel="Distance x[km]", title="vi")
im = ax.imshow(vi_km, extent=extent, aspect=1, cmap='RdBu_r')
cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='km/s')

ax = fig.add_subplot(gs[1, 1])
ax.set(ylabel="Depth z[km]", xlabel="Distance x[km]", title="loss")
im = ax.imshow(diff_km, extent=extent, aspect=1, cmap='RdBu_r')
cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='km/s')
# Symmetric range from the largest |difference|; using max() alone clipped negative errors.
diff_lim = np.abs(diff_km).max()
im.set_clim(-diff_lim, diff_lim)

In [None]:
# Loss-term weights used in train(): L_pixel/(2*s0^2) + log(s0) + L_percep/(2*s1^2) + log(s1).
# requires_grad is False, so even though sigmas is passed to the optimizer it receives no
# gradient and is never updated: both terms keep weight 1/2.
sigmas = torch.ones(2, dtype=torch.float32)



def train(model, data, wavelet, vp_tensor, optimizer, device='cpu'):
    """One physics-guided training step of the inversion network.

    The network maps the standardised shots ``data`` to a velocity model
    (``out * 1000 + 2900``). ``forward_rnn`` (defined elsewhere; presumably the
    forward modelling) returns synthetic ``shots`` for it. The loss combines:

    * a pixel misfit (L1 + MSE between observed and synthetic shots), and
    * a perceptual misfit (L1 + MSE between ``model_vgg`` features).

    Both terms are weighted through the global ``sigmas``. ``vp_tensor`` is used
    only for a monitoring MSE that does not influence the update.

    Relies on the globals ``forward_rnn``, ``model_vgg`` and ``sigmas``.

    :return: ``([total, pixel, perceptual, vp_mse], v_recon[0], shots, [data.min(), data.max()])``.
        The four loss entries are Python floats, so the history does not hold graph-attached tensors.
    """
    lamda1 = lamda2 = lamda3 = lamda4 = 1
    model.train()
    data_resample = data[:, :, :, :]
    data_norm = (data_resample - data.mean()) / (data.std()).to(device)

    optimizer.zero_grad()

    # The last decoder block ends in BatchNorm without tanh, so this output is not bounded to [-1, 1].
    v_recon_norm = model(data_norm)
    v_recon = v_recon_norm * 1000 + 2900
    _, _, shots, _ = forward_rnn(vmodel=v_recon[0].to(device), segment_wavelet=wavelet)
    shots_resample = shots[:, :, :, :]

    # Each shot of the first batch item becomes a 3-channel image for VGG ([n_shots, 3, H, W]);
    # both image stacks are standardised with the observed data's statistics.
    data_rgb = data_resample[0][:, None].repeat(1, 3, 1, 1)
    shots_rgb = shots_resample[0][:, None].repeat(1, 3, 1, 1)
    rgb_mean, rgb_std = data_rgb.mean(), data_rgb.std()
    data_vgg = model_vgg((data_rgb - rgb_mean) / rgb_std)
    shots_vgg = model_vgg((shots_rgb - rgb_mean) / rgb_std)

    loss_fn1 = nn.L1Loss()
    loss_fn2 = nn.MSELoss()
    L_recon_pixel = lamda1 * loss_fn1(data_resample, shots_resample) + lamda2 * loss_fn2(data_resample, shots_resample)
    L_recon_perceptual = lamda3 * loss_fn1(data_vgg, shots_vgg) + lamda4 * loss_fn2(data_vgg, shots_vgg)

    # Per-term weighting: term / (2 * sigma^2) + log(sigma).
    loss = (L_recon_pixel / (2 * sigmas[0] * sigmas[0]) + torch.log(sigmas[0])
            + L_recon_perceptual / (2 * sigmas[1] * sigmas[1]) + torch.log(sigmas[1]))

    train_loss = loss.detach().cpu().item()
    loss.backward()
    optimizer.step()

    # Monitoring only: misfit to the true model. Computed without a graph; the former
    # `Loss_val += Loss_val.item()` doubled this value.
    with torch.no_grad():
        Loss_val = loss_fn2(v_recon[0], vp_tensor).item()

    return ([train_loss, L_recon_pixel.item(), L_recon_perceptual.item(), Loss_val],
            v_recon[0], shots, [data.min(), data.max()])

    
    
    
def test(model, data, device='cpu'):
    model.train()
    data_resample = data[:, :, :, :]   #输入的是resample后的
    data_norm = (data_resample-data.mean())/(data.std()).to(device)

    v_recon_norm = model(data_norm)  #输出范围[-1,1]
    v_recon = (v_recon_norm)*(1000) + 2900  
    
    return v_recon[0]
    
  
    
def save_state(epoch, model, optimizer, train_loss_history):
    """Save a checkpoint of the physics-guided training run for ``epoch`` (0-based).

    Writes ``./checkpoints-Drop-marmousi-woLF/checkpoint-Drop-marmousi-without-low-fre-G-1-{epoch+1}.pth``
    containing ``epoch`` (= epoch + 1), ``train_loss_history`` and the model/optimizer
    state dicts. Earlier checkpoints are left on disk. The checkpoint directory is
    created when missing (torch.save would fail otherwise).
    """
    import os
    ckpt_dir = "./checkpoints-Drop-marmousi-woLF"
    state = {'epoch': epoch + 1,
             'train_loss_history': train_loss_history,
             'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict()
             }
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(state, ckpt_dir + "/checkpoint-Drop-marmousi-without-low-fre-G-1-{}.pth".format(epoch + 1))
        
def load_state(model, optimizer, resume_file, map_location="cpu"):
    """Restore model and optimizer from a checkpoint written by ``save_state``.

    :param model: module receiving ``checkpoint['state_dict']``.
    :param optimizer: optimizer receiving ``checkpoint['optimizer']``.
    :param resume_file: path of the ``.pth`` checkpoint.
    :param map_location: forwarded to ``torch.load``. Defaults to ``"cpu"`` so a checkpoint
        written on a specific GPU (e.g. ``cuda:1``) also loads where that device is absent.
        ``load_state_dict`` then copies the tensors onto the model's and optimizer's
        parameter devices.
    :return: ``(resume_epoch, model, optimizer, train_loss_history)``.
    """
    checkpoint = torch.load(resume_file, map_location=map_location)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['epoch'], model, optimizer, checkpoint['train_loss_history']



In [None]:
from torch.optim.lr_scheduler import StepLR
import torchvision

# Physics-guided training, starting from the pre-trained weights.
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
data = torch.load("./shots_wo_low6_marmousi_I_1024.npz").to(device).type(dtype=torch.float32)  # [1,12,1024,256]

model = MyModel().to(device)

# This Adam instance only exists so load_state can restore the pre-training checkpoint;
# training below uses a fresh AdamW.
optimizer = torch.optim.Adam([{'params': model.parameters()}],
                             lr=0.001, weight_decay=0.0001)
_, model, _, _ = load_state(model, optimizer, "./checkpoints-Drop-marmousi-woLF/checkpoint-Drop-pre-marmousi-without-low-fre-G-1-1000.pth")

optimizer = torch.optim.AdamW([{'params': model.parameters()}, {'params': sigmas}],
                              lr=0.0001, betas=(0.9, 0.999), weight_decay=0.0001)
# (A StepLR scheduler, e.g. step_size=1000, gamma=0.1, was tried and is disabled.)

# Perceptual-loss feature extractor: VGG16 features[:13] ends at conv3_2 (before its ReLU).
# It is fixed, so freeze it. Gradients still flow through it to the synthetic shots, but
# unused gradients no longer accumulate in its own weights every step.
model_vgg = torchvision.models.vgg16(pretrained=True).features[:13].to(device)
model_vgg.eval()
model_vgg.requires_grad_(False)

max_epoch = 2000
print_interval = 10
save_interval = 100
train_loss_history = [[], [], [], []]  # rows: total, pixel, perceptual, vp misfit
for epoch in range(max_epoch):
    train_loss, v_pred, s_pred, norm_list = train(model, data, wavelet, vp_tensor, optimizer, device)
    for row, value in zip(train_loss_history, train_loss):
        row.append(value)

    # Checkpoint after the first epoch and then every save_interval epochs.
    if epoch == 0 or (epoch + 1) % save_interval == 0:
        save_state(epoch, model, optimizer, train_loss_history)

    if (epoch + 1) % print_interval == 0:
        print("Epoch: {}, Training Loss: {:.8f}, recon_pixel Loss: {:.8f}, recon_perceptual Loss: {:.8f}, recon_vmodel Loss: {:.8f}, sigmas0: {:.8f}, sigmas1: {:.8f}".format(epoch, train_loss[0], train_loss[1], train_loss[2], train_loss[3], sigmas[0], sigmas[1]))

train_loss_history = torch.tensor(train_loss_history)


In [None]:
# Expected (4, max_epoch): rows are the total, pixel, perceptual and vp-misfit losses.
train_loss_history.shape

In [None]:
# First 1000 epochs of the total, pixel and perceptual losses; n is clamped to the recorded length.
n = min(1000, train_loss_history.shape[1])
epochs = np.arange(n)
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(epochs, train_loss_history[0, :n].numpy(), label="all")
ax.plot(epochs, train_loss_history[1, :n].numpy()*100, label="recon_pixel (x100)")
ax.plot(epochs, train_loss_history[2, :n].numpy(), label="recon_percep")
ax.set(xlabel="Epoch", ylabel="Loss")
ax.legend();


In [None]:
# NOTE(review): `train_loss_his_data` is not defined anywhere in this section, so this raises
# NameError on a fresh kernel unless it came from a deleted/out-of-order cell.
# Confirm the intended variable (train_loss_history?).
train_loss_his_data.shape

In [None]:
# Velocity-model misfit ("recon_vp"), i.e. row 3 of train_loss_history
# (rows: total, pixel, perceptual, vp misfit). Row 0 was plotted here before, which is
# the total loss under the wrong label.
n = min(1000, train_loss_history.shape[1])
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(np.arange(n), train_loss_history[3, :n].numpy(), label="recon_vp")
ax.set(xlabel="Epoch", ylabel="Loss")
ax.legend();


In [None]:
# dropout = 0.1, shots without low frequencies (wo-low-fre6): predicted, true and initial
# models plus the prediction errors, all in km/s.
fig = plt.figure(figsize=(20, 14))
gs = fig.add_gridspec(3, 2)
extent = [0, nx*dz/1000, nz*dz/1000, 0]

vel_clim = ((vp_tensor[0]/1000).min(), (vp_tensor[0]/1000).max())
err_vp = (v_pred[0] - vp_tensor[0]).detach().cpu().numpy() / 1000
err_vi = (v_pred[0] - vi_tensor[0]).detach().cpu().numpy() / 1000
# Error panels use +/- max|error|; +/- max(error) clipped negative errors whose
# magnitude exceeded the largest positive one.
lim_vp = np.abs(err_vp).max()
lim_vi = np.abs(err_vi).max()

panels = [
    (gs[0, 0], "v_pred", v_pred[0].detach().cpu().numpy() / 1000, vel_clim),
    (gs[1, 1], "vp", vp_tensor[0].detach().cpu().numpy() / 1000, vel_clim),
    (gs[1, 0], "vi", vi_tensor[0].detach().cpu().numpy() / 1000, vel_clim),
    (gs[2, 1], "v_pred-vp", err_vp, (-lim_vp, lim_vp)),
    (gs[2, 0], "v_pred-vi", err_vi, (-lim_vi, lim_vi)),
]
for spec, title, image, (lo, hi) in panels:
    ax = fig.add_subplot(spec)
    ax.set(ylabel="Depth z[km]", xlabel="Distance x[km]", title=title)
    im = ax.imshow(image, extent=extent, aspect=1, cmap='RdBu_r')
    cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='km/s')
    im.set_clim(lo, hi)



In [None]:
# Reload the trained without-low-frequency network (epoch 2000) for MC-dropout sampling.
# The optimizer only exists so load_state can restore its state.
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
data = torch.load("./shots_wo_low6_marmousi_I_1024.npz").to(device, dtype=torch.float32)  # [1,12,1024,256]

model = MyModel().to(device)
optimizer = torch.optim.AdamW(
    [{'params': model.parameters()}, {'params': sigmas}],
    lr=0.0001, betas=(0.9, 0.999), weight_decay=0.0001,
)

checkpoint_path = "./checkpoints-Drop-marmousi-woLF/checkpoint-Drop-marmousi-without-low-fre-G-1-2000.pth"
_, model, _, _ = load_state(model, optimizer, checkpoint_path)
model = model.to(device)


In [None]:
def test(model, data, device='cpu'):
    model.train()
    data_resample = data[:, :, :, :]   #输入的是resample后的
    data_norm = (data_resample-data.mean())/(data.std()).to(device)

    v_recon_norm = model(data_norm)  #输出范围[-1,1]
    v_recon = (v_recon_norm)*(1000) + 2900  
    
    return v_recon[0]


In [None]:
# Monte-Carlo dropout: draw `samples` stochastic predictions of the velocity model.
samples = 100
draws = []
with torch.no_grad():  # sampling needs no autograd graph (the old loop kept all 100 alive)
    for s in range(samples):
        v_pred = test(model, data, device)
        # test() returns [1, H, W]; keep [H, W] per sample.
        draws.append(v_pred.detach().cpu().reshape(v_pred.shape[-2:]))
# [samples, H, W]; the grid now follows the model output instead of a hard-coded (128, 256).
v_samp = torch.stack(draws)

In [None]:
# p = 0.1, shots without low frequencies: MC-dropout mean and std, true model, and the
# error of the mean prediction (all in km/s).
fig = plt.figure(figsize=(21, 9.5))
gs = fig.add_gridspec(2, 2)
fontsize = 16
extent = [0, nx*dz/1000, nz*dz/1000, 0]

v_mean = torch.mean(v_samp, dim=0).cpu().detach()
v_std_km = torch.std(v_samp/1000, dim=0).cpu().detach()
vel_clim = ((vp_tensor[0]/1000).min(), (vp_tensor[0]/1000).max())

panels = [
    (gs[0, 0], 'v_pred_mean', v_mean.numpy()/1000, vel_clim),
    (gs[0, 1], 'v_pred_std', v_std_km.numpy(), (0, v_std_km.max())),
    (gs[1, 0], 'vp', vp_tensor[0].cpu().detach().numpy()/1000, vel_clim),
    # Fixed symmetric range of +/- 2.8412 km/s (hard-coded; presumably chosen to match other figures).
    (gs[1, 1], 'vp-v_pred_mean', (vp_tensor[0].cpu() - v_mean).detach().numpy()/1000, (-2.8412, 2.8412)),
]
for spec, title, image, (lo, hi) in panels:
    ax = fig.add_subplot(spec)
    im = ax.imshow(image, extent=extent, aspect=0.8, cmap='RdBu_r')
    cbar = add_colorbar(ax, im, ax.transAxes, width="3%")
    cbar.ax.set_title('km/s', size=fontsize)
    cbar.ax.tick_params(labelsize=fontsize)
    im.set_clim(lo, hi)
    ax.set_xlabel("Distance x[km]", fontsize=fontsize)
    ax.set_title(title, fontsize=fontsize)
    ax.set_ylabel("Depth z[km]", fontsize=fontsize)
    ax.tick_params(axis='both', which='major', labelsize=fontsize)