In [None]:
import os, sys
sys.path.insert(1, './codes/codes')  # insert at 1, 0 is the script path (or '' in REPL)

print("----------------")
!python --version
!nvidia-smi
print("----------------")
print("System Version: ", sys.version)

## ======================================================== ##
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
from plot_functions import add_colorbar, imagesc

print("PyTorch Version: ", torch.__version__)
print("----------------")
print("torch.cuda.is_available: ",torch.cuda.is_available())
print("----------------")
print(torch.__version__, torch.version.cuda, torch.cuda.get_device_name(0))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# deterministic behavior
torch.manual_seed(3)
torch.cuda.manual_seed_all(3)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(3)
# random.seed(3)
os.environ['PYTHONHASHSEED'] = str(3)

In [None]:
# deterministic behavior
def seed_everything(seed):
    """Seed every random generator used in this notebook and pin cuDNN behaviour.

    Seeds, in this order: Python ``random``, NumPy's global generator, the
    PyTorch CPU generator, the current CUDA device and all CUDA devices. It then
    disables cuDNN benchmark mode, requests deterministic cuDNN kernels, and
    exports ``PYTHONHASHSEED``.

    NOTE(review): per the Python docs, PYTHONHASHSEED is read when the
    interpreter starts, so setting it here presumably only affects child
    processes.

    Args:
        seed: integer seed applied to every generator.
    """
    import os
    import random
    import numpy as np
    import torch

    seeders = (
        random.seed,                 # Python random module
        np.random.seed,              # NumPy global generator
        torch.manual_seed,           # CPU
        torch.cuda.manual_seed,      # current GPU
        torch.cuda.manual_seed_all,  # all GPUs
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Turn off cuDNN autotuning and ask for deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    os.environ['PYTHONHASHSEED'] = str(seed)
    

# Set a seed (re-seeds everything; supersedes the seed 3 used in the setup cell).
seed = 37
# Call the function to apply the seed to all RNGs.
seed_everything(seed)


In [None]:
# Exploratory load of SEGoverthrust.npz, scaled by 1/1000 (km/s if the file is in m/s).
# NOTE: the model-building cell below reloads the file WITHOUT the /1000.
# The `with` block closes the .npz archive once the array has been read.
with np.load("./SEGoverthrust.npz") as file:
    vel3d = file['vmodel'].transpose(1, 0, 2)/1000  #z, x, y


In [None]:
# Shape of the loaded volume, ordered (z, x, y) after the transpose.
vel3d.shape

In [None]:
import pandas as pd
from scipy.ndimage import gaussian_filter
# Exploratory: load the Marmousi model from CSV and show its shape.
# `vmodel` is overwritten by the next cell (a slice of the overthrust volume).
# NOTE(review): pd.read_csv uses the first row as a header by default; if this CSV has
# no header row, the array has 375 rows instead of the 376 in the file name — confirm.
vmodel = np.array(pd.read_csv("./vel_marmousi_376x1151.csv")) 
vmodel.shape

In [None]:
import pandas as pd
from scipy.ndimage import gaussian_filter

# True model: the slice at index 100 of axis 1 of the 3-D overthrust volume.
# The `with` block closes the .npz archive once the array has been read.
with np.load("./SEGoverthrust.npz") as file:
    vel3d = file['vmodel'].transpose(1, 0, 2)
vmodel = vel3d[:, 100, :]
v_init = gaussian_filter(vmodel, sigma=15)   # smoothed starting model

# NOTE(review): the grid interval is reported as 8 m here but 25 m for the resampled model
# below, although resampling keeps every depth sample (and takes every 3rd x sample) —
# confirm the true spacing of SEGoverthrust.npz.
dz = 8
nz, nx = vmodel.shape
print("Original Model Shape: {}, Grid Interval: {}m".format(vmodel.shape, dz))

################# Plot true & initial velocity model #################
fig = plt.figure(figsize=(12, 2))
gs = fig.add_gridspec(1, 2)
ax = fig.add_subplot(gs[0, 0])
ax.set(ylabel="Depth z[km]", xlabel="Distance x[km]")
im = ax.imshow(vmodel/1000, extent=[0, nx*dz/1000, nz*dz/1000, 0], aspect=1.3, cmap='RdBu_r')
cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='km/s')

ax = fig.add_subplot(gs[0, 1])
ax.set( xlabel="Distance x[km]")
im = ax.imshow(v_init/1000, extent=[0, nx*dz/1000, nz*dz/1000, 0], aspect=1.3, cmap='RdBu_r')
cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='km/s')
################# Plot true & initial velocity model #################

dz = 25
vp_tensor_init = torch.from_numpy(vmodel).type(dtype=torch.float32).to(device)
vi_tensor_init = torch.from_numpy(v_init).type(dtype=torch.float32).to(device)
print(vp_tensor_init.shape)
# Keep the top 128 depth samples and every 3rd of the first 768 x samples -> (1, 128, 256).
vp_tensor = vp_tensor_init[None, 0:128, 0:768:3]
vi_tensor = vi_tensor_init[None, 0:128, 0:768:3]

# Smooth the resampled initial model a second time.
vi_tensor = torch.from_numpy(gaussian_filter(vi_tensor.cpu().numpy(), sigma=13)).type(dtype=torch.float32).to(device)

nv, nz, nx = vp_tensor.shape
print("Resampled Model Shape: {}, Grid Interval: {}m".format((nz, nx), dz))

# Setting locations of sources and receivers
xs = torch.arange(15, nx-10, 13, dtype=torch.long).repeat([nv, 1])      # x-coordinate for sources
ns = xs.shape[1]                                                        # number of shots 
xr = torch.arange(0, nx, 1, dtype=torch.long).repeat([nv, ns, 1])       # x-coordinate for receivers
zs = torch.full((nv, ns), 0, dtype=torch.long)                          # depth of sources (grid index 0)
zr = torch.full((nv, ns, nx), 0, dtype=torch.long)                      # depth of receivers (grid index 0, at the surface)
print("Number of shots: {}, with interval: {}m, in depth: {}m".format(ns, (xs[0, 1]-xs[0, 0])*dz, zs[0, 0]*dz))
print(xs)

################# Plot true & initial velocity model #################
fig = plt.figure(figsize=(12, 2))
gs = fig.add_gridspec(1, 2)
ax = fig.add_subplot(gs[0, 0])
ax.set(ylabel="Depth z[km]", xlabel="Distance x[km]")
im = ax.imshow(vp_tensor[0].cpu().numpy()/1000, extent=[0, nx*dz/1000, nz*dz/1000, 0], aspect=0.8, cmap='RdBu_r')
cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='km/s')

ax = fig.add_subplot(gs[0, 1])
ax.set( xlabel="Distance x[km]")
im = ax.imshow(vi_tensor[0].cpu().numpy()/1000, extent=[0, nx*dz/1000, nz*dz/1000, 0], aspect=0.8, cmap='RdBu_r')
cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='km/s')
# ################# Plot true & initial velocity model #################

In [None]:
# Quick check: maximum of the resampled true velocity model.
vp_tensor.max()

In [None]:
# Quick check: minimum of the resampled true velocity model.
vp_tensor.min()

In [None]:
from rnn_fd import rnn2D
from generator_old import wGenerator  # single-frequency Ricker wavelet generator

freeSurface = True                                                      # free surface option for forward modeling
npad = 15                                                               # velocity padding in grid points
freq = 6                                                                # dominant frequency of wavelet in Hz
dt = 0.003                                                              # time sampling interval, fixed for all shot gathers
nt = 1024                                                               # number of samples in time

t = dt * torch.arange(0, nt, dtype=torch.float32)                       # create time vector
wavelet = wGenerator(t, freq).ricker().to(device)                       # generate wavelet
nx_pad = nx + 2 * npad
nz_pad = nz + npad if freeSurface else nz + 2 * npad
f = np.arange(0, nt/2+1) / (nt*dt)                                      # frequency axis matching rfft (nt//2 + 1 bins)

fig = plt.figure(figsize=(10, 2.5))
gs = fig.add_gridspec(1, 2)
ax = fig.add_subplot(gs[0, 0])
ax.set(xlabel="$Time$", ylabel="$Amp$", title="$Ricker$", xlim=[0, 1])
ax.plot(t, wavelet.cpu().numpy(), color='red', linestyle='-', linewidth=1.5)
ax.grid(True, which='both', linestyle='--', color='grey', linewidth=.8, alpha=1.0)
ax.minorticks_on()

ax = fig.add_subplot(gs[0, 1])
# Plain text: inside $...$ mathtext drops the spaces ("Frequency[Hz]", "AmpSpectrum").
ax.set(xlabel="Frequency [Hz]", ylabel="$Amp$", title="Amp Spectrum", xlim=[0, 40])
ax.plot(f, np.abs(torch.fft.rfft(wavelet).cpu().numpy()), color='red', linestyle='-', linewidth=1.5)
ax.grid(True, which='both', linestyle='--', color='grey', linewidth=.8, alpha=1.0)
ax.minorticks_on()

################## Check the stability condition #################
# CFL: vmax * dt / dz * sqrt(2) must stay below 1.
print(vp_tensor.max()*dt/dz/np.sqrt(1/2),"< 1") # should <1
# Dispersion: vmin / (10 * freq * dz) > 1 means at least 10 grid points per dominant wavelength.
print(vp_tensor.min()/10/freq/dz,"> 1") # should >1

forward_rnn = rnn2D(nz, nx, zs, xs, zr, xr, dz, dt, 
                    npad=npad, order=2, vmax=vp_tensor.max(), 
                    log_para=1e-6,
                    freeSurface=freeSurface,   # honour the option above (was hard-coded True)
                    dtype=torch.float32, 
                    device=device).to(device)

# forward modeling
# _, _, shots, _ = forward_rnn(vmodel=vp_tensor.to(device), segment_wavelet=wavelet)
# _, _, shots_init, _ = forward_rnn(vmodel=vi_tensor.to(device), segment_wavelet=wavelet)

In [None]:
# Observed data: reuse the cached shot gathers if present; otherwise model them with the
# true velocity and cache them. (Previously this cell saved an undefined `shots` on a fresh kernel.)
# Despite the .npz suffix, the file is written with torch.save and is not a NumPy archive.
shots_file = "./shots_overthrust_scale1024_big4Hz.npz"
if os.path.exists(shots_file):
    shots = torch.load(shots_file, weights_only=True)
else:
    _, _, shots, _ = forward_rnn(vmodel=vp_tensor.to(device), segment_wavelet=wavelet)
    shots = shots.detach()
    torch.save(shots, shots_file)

In [None]:
# Display the observed shot gathers with the colour scale clipped to +/- max/50.
fig=plt.figure(figsize=(ns*1.5, 8))
imagesc(fig,
        shots.cpu().numpy().reshape(-1, ns, nt, nx),
        vmin=-shots.max()/50,
        vmax=shots.max()/50,
        extent=[0, nx*dz/1000, t.numpy().max(), 0],
        aspect=4,
        nRows_nCols=(1, 6),   # NOTE(review): a 1x6 grid for ns shots; presumably imagesc shows only the first 6 — confirm
        cmap='RdBu_r', # alternative: 'seismic'
        ylabel="Time[s]",
        xlabel="Position[km]",
        clabel="",
        xticks=np.arange(0., int(nx*dz/1000), 2),
        yticks=np.arange(0., t.numpy().max(), .5),
        fontsize=15,
        cbar_width="7%",
        cbar_height="100%",
        cbar_loc='lower left')
fig.tight_layout(pad=-0.85)

In [None]:
#VAE
import torch
from torch import nn

class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU (default negative slope).

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        k_size: ``kernel_size`` of the convolution (int or tuple).
        stride: convolution stride (int or tuple).
        padding: convolution padding (int or tuple).
    """

    def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=k_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
        ]
        # The attribute name and layer order define the state_dict keys
        # (conv.0.*, conv.1.*) that saved checkpoints rely on.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
    
    
class TransConvBlock(nn.Module):
    """ConvTranspose2d -> BatchNorm2d -> LeakyReLU (default negative slope).

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the transposed convolution.
        k_size: ``kernel_size`` of the transposed convolution (int or tuple).
        stride: upsampling stride (int or tuple).
        padding: transposed-convolution padding (int or tuple).
    """

    def __init__(self, in_channels, out_channels, k_size=3, stride=2, padding=1):
        super().__init__()
        upsample = nn.ConvTranspose2d(in_channels, out_channels,
                                      kernel_size=k_size, stride=stride, padding=padding)
        norm = nn.BatchNorm2d(out_channels)
        act = nn.LeakyReLU()
        # Keep the name `transconv` and the order: state_dict keys are transconv.0.*, transconv.1.*
        self.transconv = nn.Sequential(upsample, norm, act)

    def forward(self, x):
        return self.transconv(x)
    
    
class TransConvBlock_last(nn.Module):
    """ConvTranspose2d -> BatchNorm2d with no activation (final decoder stage).

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the transposed convolution.
        k_size: ``kernel_size`` of the transposed convolution (int or tuple).
        stride: upsampling stride (int or tuple).
        padding: transposed-convolution padding (int or tuple).
    """

    def __init__(self, in_channels, out_channels, k_size=3, stride=2, padding=1):
        super().__init__()
        stages = (
            nn.ConvTranspose2d(in_channels, out_channels,
                               kernel_size=k_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
        )
        # Same `transconv` name/indices as TransConvBlock so checkpoints line up.
        self.transconv = nn.Sequential(*stages)

    def forward(self, x):
        return self.transconv(x)
        

class MyModel(nn.Module):
    """Convolutional VAE that maps shot gathers to a single-channel velocity image.

    For an input of shape (B, in_channels, 1024, 256):
      * the encoder reduces it to a 512 x 16 x 16 feature map, flattened and
        projected to the latent mean ``mu`` and log-variance ``log_var``;
      * the decoder projects a latent sample back to 512 x 16 x 16 and
        upsamples it to (B, 1, 128, 256).

    Args:
        latent_dim: size of the latent vector.
        in_channels: number of input channels. Defaults to 18, the channel
            count used for the summary input (1, 18, 1024, 256) in this notebook.
    """

    def __init__(self, latent_dim=128, in_channels=18):
        super(MyModel, self).__init__()
        flat_dim = 512 * 16 * 16   # flattened size of the 512 x 16 x 16 bottleneck

        #encoder--------------------------------------------
        # Blocks 1-2 halve only the first spatial axis; blocks 3-6 halve both axes.
        self.conv_block_1 = ConvBlock(in_channels, 32, k_size=(4,1), stride=(2,1), padding=(1,0))  # -> 32 x 512 x 256
        self.conv_block_2 = ConvBlock(32, 64, k_size=(4,1), stride=(2,1), padding=(1,0))           # -> 64 x 256 x 256
        self.conv_block_3 = ConvBlock(64, 128, k_size=4, stride=2, padding=1)                       # -> 128 x 128 x 128
        self.conv_block_4 = ConvBlock(128, 256, k_size=4, stride=2, padding=1)                      # -> 256 x 64 x 64
        self.conv_block_5 = ConvBlock(256, 512, k_size=4, stride=2, padding=1)                      # -> 512 x 32 x 32
        self.conv_block_6 = ConvBlock(512, 512, k_size=4, stride=2, padding=1)                      # -> 512 x 16 x 16

        #decoder--------------------------------------------
        self.trans_conv_block_1 = TransConvBlock(512, 256, k_size=(1,4), stride=(1,2), padding=(0,1))  # -> 256 x 16 x 32
        self.trans_conv_block_2 = TransConvBlock(256, 128, k_size=4, stride=2, padding=1)              # -> 128 x 32 x 64
        self.trans_conv_block_3 = TransConvBlock(128, 64, k_size=4, stride=2, padding=1)               # -> 64 x 64 x 128
        self.trans_conv_block_4 = TransConvBlock(64, 32, k_size=4, stride=2, padding=1)                # -> 32 x 128 x 256
        self.trans_conv_block_5 = TransConvBlock_last(32, 1, k_size=3, stride=1, padding=1)            # -> 1 x 128 x 256

        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_var = nn.Linear(flat_dim, latent_dim)
        self.decoder_input = nn.Linear(latent_dim, flat_dim)

    def encode(self, x):
        """Return ``(mu, log_var)``, each (B, latent_dim), for an input batch ``x``."""
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        x = self.conv_block_3(x)
        x = self.conv_block_4(x)
        x = self.conv_block_5(x)
        result = self.conv_block_6(x)
        result = torch.flatten(result, start_dim=1)
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return mu, log_var

    def decode(self, x):
        """Map latent vectors (B, latent_dim) to a (B, 1, 128, 256) output."""
        x = self.decoder_input(x)
        # -1 keeps the batch dimension (the original hard-coded a batch of 1,
        # which failed for B > 1).
        x = x.view(-1, 512, 16, 16)
        x = self.trans_conv_block_1(x)
        x = self.trans_conv_block_2(x)
        x = self.trans_conv_block_3(x)
        x = self.trans_conv_block_4(x)
        x = self.trans_conv_block_5(x)

        return x

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick: sample z ~ N(mu, exp(logvar)) as mu + eps * std
        with eps ~ N(0, I).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]; fresh noise is drawn on every call, also in eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, x):
        """Encode, sample a latent vector and decode; returns ``(v_pred, mu, log_var)``."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        v_pred = self.decode(z)

        return v_pred, mu, log_var
    

In [None]:
from torchinfo import summary
model = MyModel()
# Dummy input size (batch, channels, height, width) = (1, 18, 1024, 256);
# matches ns = 18 shots, nt = 1024 time samples and nx = 256 receivers built above.
inputs = (1, 18, 1024, 256)
summary(model, inputs)

In [None]:
# Mean of the true model; compare with the 3800 shift used for normalisation below.
vp_tensor.mean()

In [None]:
# Standard deviation of the true model; compare with the 800 scale used for normalisation below.
vp_tensor.std()

In [None]:
# Standardise the true model with rounded constants (shift 3800, scale 800) — presumably
# chosen from the mean/std printed above; TODO confirm. The fine-tuning train/test
# functions invert this mapping as v = out * 800 + 3800.
vp_tensor_norm = (vp_tensor - 3800)/800

In [None]:
fig = plt.figure(figsize=(20, 7))
gs = fig.add_gridspec(2, 2)
ax = fig.add_subplot(gs[0, 0])
ax.set(ylabel="Depth z[km]", xlabel="Distance x[km]", title="vp_tensor")
im = ax.imshow((vp_tensor[0].cpu().detach().numpy())/1000, extent=[0, nx*dz/1000, nz*dz/1000, 0], aspect=1, cmap='RdBu_r')
cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='km/s')

# Colour range: full range of the true model (plain floats for matplotlib).
vmin = (vp_tensor[0]/1000).min().item()
vmax = (vp_tensor[0]/1000).max().item()
im.set_clim(vmin, vmax)


ax = fig.add_subplot(gs[0, 1])
ax.set(ylabel="Depth z[km]", xlabel="Distance x[km]", title="vp_tensor_norm")
im = ax.imshow((vp_tensor_norm[0].cpu().detach().numpy()), extent=[0, nx*dz/1000, nz*dz/1000, 0], aspect=1, cmap='RdBu_r')
# The normalised model is dimensionless, so its colourbar carries no km/s unit.
cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='')

# Colour range: full range of the normalised model.
vmin = (vp_tensor_norm[0]).min().item()
vmax = (vp_tensor_norm[0]).max().item()
im.set_clim(vmin, vmax)

In [None]:
import torchvision
# import torchvision.transforms as transforms
# torchvision.models.vgg16().features

def train(model_Jin, data, wavelet, vi_tensor, optimizer, device='cpu',
          v_scale=800.0, v_shift=3800.0):
    """One supervised pre-training step: fit the network's velocity to ``vi_tensor``.

    The shot gathers are standardised with their own mean/std and passed through
    the model. The normalised output is mapped back to velocity as
    ``v = out * v_scale + v_shift``, and the loss is the MSE between that
    velocity and ``vi_tensor``.

    Args:
        model_Jin: model returning ``(v_norm, mu, logvar)``.
        data: observed shot gathers (assumed (1, n_shots, nt, nx) here).
        wavelet: unused; kept so the call matches the fine-tuning ``train``.
        vi_tensor: target velocity with the same shape as ``v_recon[0]``.
        optimizer: optimizer over the model parameters; stepped once.
        device: device the normalised data and prediction are moved to.
        v_scale, v_shift: de-normalisation constants. The defaults (800, 3800)
            match the fine-tuning ``train``/``test`` and ``vp_tensor_norm``, so
            the network loaded for fine-tuning starts at ``vi_tensor``. The
            original code used 1000/2900 here, so the pre-trained output was
            re-interpreted on a different scale after loading.

    Returns:
        ``([loss], v_recon[0], data)`` with ``loss`` as a Python float.
    """
    model_Jin.train()
    data_norm = ((data - data.mean()) / data.std()).to(device)

    optimizer.zero_grad()
    v_recon_norm, mu, logvar = model_Jin(data_norm)
    v_recon = (v_recon_norm * v_scale + v_shift).to(device).type(dtype=torch.float32)

    # The target may live on a different GPU than the model; compare on the prediction's device.
    loss = nn.MSELoss()(v_recon[0], vi_tensor.to(v_recon.device))
    loss.backward()
    optimizer.step()

    train_loss = loss.detach().cpu().item()
    return [train_loss], v_recon[0], data

def save_state(epoch, model, optimizer, train_loss_history,
               ckpt_dir="./checkpoints-VAE-overthrust-G-scale",
               prefix="checkpoint-VAE-pre-overthrust-1-"):
    """Save a checkpoint for ``epoch`` and delete the previous epoch's file.

    Writes ``{prefix}{epoch + 1}.pth`` inside ``ckpt_dir``; the directory is
    created if missing. After saving, ``{prefix}{epoch}.pth`` is removed if it
    exists, so calling this every epoch keeps only the latest checkpoint.

    Args:
        epoch: zero-based epoch index; stored as ``epoch + 1``.
        model: its ``state_dict()`` is saved under ``'state_dict'``.
        optimizer: its ``state_dict()`` is saved under ``'optimizer'``.
        train_loss_history: stored as-is.
        ckpt_dir: output directory.
        prefix: checkpoint file-name prefix.
    """
    import os
    os.makedirs(ckpt_dir, exist_ok=True)   # torch.save does not create directories
    state = {'epoch': epoch + 1,
             'train_loss_history': train_loss_history,
             'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict()
             }
    torch.save(state, os.path.join(ckpt_dir, "{}{}.pth".format(prefix, epoch + 1)))
    previous = os.path.join(ckpt_dir, "{}{}.pth".format(prefix, epoch))
    if os.path.exists(previous):
        os.remove(previous)

def load_state(model, optimizer, resume_file, map_location="cpu"):
    """Restore model and optimizer state from a checkpoint written by ``save_state``.

    Args:
        model: module whose weights are overwritten in place.
        optimizer: optimizer whose state is overwritten in place.
        resume_file: path of the checkpoint.
        map_location: forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint written on one GPU (e.g. cuda:1) also loads on machines
            without that device. ``load_state_dict`` then copies the tensors
            onto the devices of the model/optimizer parameters.

    Returns:
        ``(resume_epoch, model, optimizer, train_loss_history)``.
    """
    checkpoint = torch.load(resume_file, map_location=map_location)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['epoch'], model, optimizer, checkpoint['train_loss_history']


In [None]:
from torch.optim.lr_scheduler import StepLR
# Pre-training: fit the network output to the smoothed initial model vi_tensor.
# Use the same device expression as the setup cell. vi_tensor and wavelet were created there,
# and comparing tensors that live on two different GPUs (the old "cuda:1") raises an error.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data = torch.load("./shots_overthrust_scale1024_big4Hz.npz").to(device).type(dtype=torch.float32)  # presumably (1, ns, nt, nx) = (1, 18, 1024, 256)

model = MyModel().to(device)
optimizer = torch.optim.Adam([{'params': model.parameters()}], 
                               lr=0.0001, weight_decay=0.0001)

max_epoch = 1000
train_loss_history = []
print_interval = 5
for epoch in range(0, max_epoch):

    train_loss, v_pred, s_pred  = train(model, data, wavelet, vi_tensor, optimizer, device)
    train_loss_history.append(train_loss)
    # Writes checkpoint epoch+1 and deletes checkpoint epoch, so only the latest survives.
    save_state(epoch, model, optimizer, train_loss_history)

    if (epoch + 1) % print_interval == 0:  # print every print_interval epochs
        print("Epoch: {}, Training Loss: {:.8f}".format(epoch, train_loss[0]))

train_loss_history = torch.tensor(train_loss_history)


In [None]:
fig = plt.figure(figsize=(20, 7))
gs = fig.add_gridspec(2, 2)

# Bring both models to the CPU so the difference below works even if they live on different devices.
v_pred_cpu = v_pred[0].detach().cpu()
vi_cpu = vi_tensor[0].detach().cpu()

ax = fig.add_subplot(gs[0, 0])
ax.set(ylabel="Depth z[km]", xlabel="Distance x[km]", title="v_pred_full")
im = ax.imshow(v_pred_cpu.numpy()/1000, extent=[0, nx*dz/1000, nz*dz/1000, 0], aspect=1, cmap='RdBu_r')
cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='km/s')
# Colour range: same as the target vi so the panels are comparable.
vmin = (vi_cpu/1000).min().item()
vmax = (vi_cpu/1000).max().item()
im.set_clim(vmin, vmax)


ax = fig.add_subplot(gs[1, 0])
ax.set(ylabel="Depth z[km]", xlabel="Distance x[km]", title="vi")
im = ax.imshow(vi_cpu.numpy()/1000, extent=[0, nx*dz/1000, nz*dz/1000, 0], aspect=1, cmap='RdBu_r')
cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='km/s')


ax = fig.add_subplot(gs[1, 1])
ax.set(ylabel="Depth z[km]", xlabel="Distance x[km]", title="loss")
diff_kms = (v_pred_cpu - vi_cpu)/1000
im = ax.imshow(diff_kms.numpy(), extent=[0, nx*dz/1000, nz*dz/1000, 0], aspect=1, cmap='RdBu_r')
cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='km/s')

# Symmetric colour range around zero, scaled by the largest |difference|.
# The signed maximum used before clipped large negative differences.
err_max = diff_kms.abs().max().item()
vmin, vmax = -err_max, err_max
im.set_clim(vmin, vmax)

In [None]:
def train(model, data, wavelet, vp_tensor, optimizer, device='cpu', detect_anomaly=False):
    """One physics-guided VAE fine-tuning step.

    1. Standardise ``data`` and predict a normalised velocity with ``model``.
       Map it back as ``v = out * 800 + 3800``, clamp to [2300, 5500] and set
       NaNs to 0.
    2. Simulate shots for that velocity with the global ``forward_rnn``.
    3. Loss = pixel (L1 + MSE) between observed and simulated shots
       + perceptual (L1 + MSE) between their global ``model_vgg`` features
       + KL divergence of N(mu, exp(logvar)) from N(0, I).
    4. Backpropagate and step ``optimizer``.

    Args:
        model: VAE returning ``(v_norm, mu, logvar)``.
        data: observed shot gathers (assumed (1, n_shots, nt, nx), the same
            layout as the simulated shots).
        wavelet: source wavelet passed to ``forward_rnn``.
        vp_tensor: true velocity; used only for the monitoring MSE (no gradient).
        optimizer: optimizer over ``model``'s parameters.
        device: device the predicted velocity is moved to before modelling.
        detect_anomaly: if True, enable ``torch.autograd`` anomaly detection
            (a slow, process-wide debugging switch; the original turned it on
            every call).

    Returns:
        ``([total, pixel, perceptual, kld, vmodel_mse], v_recon[0], shots, mu, logvar)``.
        The loss entries are Python floats, so histories do not keep autograd
        graphs alive.
    """
    if detect_anomaly:
        torch.autograd.set_detect_anomaly(True)
    model.train()

    lamda1 = lamda2 = lamda3 = lamda4 = 1

    data_norm = (data - data.mean()) / data.std().to(device)
    optimizer.zero_grad()

    v_recon_norm, mu, logvar = model(data_norm)
    # De-normalise, clamp to the plausible range (this also maps +/-inf to the bounds)
    # and zero NaNs. This is out-of-place, so the autograd graph is not edited in place.
    v_recon = torch.clamp(v_recon_norm * 800 + 3800, min=2300, max=5500)
    v_recon = torch.nan_to_num(v_recon, nan=0.0)

    _, _, shots, _ = forward_rnn(vmodel=v_recon[0].to(device), segment_wavelet=wavelet)
    shots = torch.nan_to_num(shots, nan=0.0, posinf=0.0, neginf=0.0)

    # Perceptual branch: every shot becomes a 3-channel image. Simulated shots are
    # normalised with the observed data's statistics so both share one scale.
    data_rgb = data[0][:, None].repeat(1, 3, 1, 1)
    shots_rgb = shots[0][:, None].repeat(1, 3, 1, 1)
    rgb_mean, rgb_std = data_rgb.mean(), data_rgb.std()
    data_vgg = model_vgg((data_rgb - rgb_mean) / rgb_std)
    shots_vgg = model_vgg((shots_rgb - rgb_mean) / rgb_std)

    l1_fn, mse_fn = nn.L1Loss(), nn.MSELoss()
    L_recon_pixel = lamda1 * l1_fn(data, shots) + lamda2 * mse_fn(data, shots)
    L_recon_perceptual = lamda3 * l1_fn(data_vgg, shots_vgg) + lamda4 * mse_fn(data_vgg, shots_vgg)
    # KL(N(mu, exp(logvar)) || N(0, I)), summed over all elements. Written directly in
    # terms of logvar; log(exp(0.5*logvar)**2) underflows to -inf for very negative logvar.
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    loss = L_recon_pixel + L_recon_perceptual + KLD
    loss.backward()
    optimizer.step()

    # Monitoring only: velocity misfit to the true model, computed once and without a
    # graph. The original added the value to itself, which doubled it.
    with torch.no_grad():
        Loss_val = nn.MSELoss()(v_recon[0], vp_tensor.to(v_recon.device)).item()

    losses = [loss.detach().cpu().item(), L_recon_pixel.item(), L_recon_perceptual.item(), KLD.item(), Loss_val]
    return losses, v_recon[0], shots, mu, logvar


def test(model, data, device='cpu'):
    model.eval()
    data_resample = data[:, :, :, :]  
    data_norm = (data_resample-data.mean())/(data.std()).to(device)
    
    v_recon_norm, mu, logvar= model(data_norm)
    v_recon = (v_recon_norm)*(800) + 3800  
   
    return v_recon[0], mu, logvar 
    
def save_state(epoch, model, optimizer, train_loss_history,
               ckpt_dir="./checkpoints-VAE-overthrust-G-scale",
               prefix="checkpoint-VAE-overthrust-1-"):
    """Save a fine-tuning checkpoint for ``epoch`` and delete the previous epoch's file.

    Writes ``{prefix}{epoch + 1}.pth`` inside ``ckpt_dir``; the directory is
    created if missing. ``{prefix}{epoch}.pth`` is then removed if it exists.
    When the caller saves only every N epochs, earlier checkpoints therefore
    remain on disk.

    Args:
        epoch: zero-based epoch index; stored as ``epoch + 1``.
        model: its ``state_dict()`` is saved under ``'state_dict'``.
        optimizer: its ``state_dict()`` is saved under ``'optimizer'``.
        train_loss_history: stored as-is.
        ckpt_dir: output directory.
        prefix: checkpoint file-name prefix.
    """
    import os
    os.makedirs(ckpt_dir, exist_ok=True)   # torch.save does not create directories
    state = {'epoch': epoch + 1,
             'train_loss_history': train_loss_history,
             'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict()
             }
    torch.save(state, os.path.join(ckpt_dir, "{}{}.pth".format(prefix, epoch + 1)))
    previous = os.path.join(ckpt_dir, "{}{}.pth".format(prefix, epoch))
    if os.path.exists(previous):
        os.remove(previous)
        
def load_state(model, optimizer, resume_file, map_location="cpu"):
    """Restore model and optimizer state from a checkpoint written by ``save_state``.

    Args:
        model: module whose weights are overwritten in place.
        optimizer: optimizer whose state is overwritten in place.
        resume_file: path of the checkpoint.
        map_location: forwarded to ``torch.load``. Defaults to ``"cpu"`` so
            checkpoints written on another GPU (e.g. cuda:1) load anywhere.
            ``load_state_dict`` then copies the tensors onto the devices of
            the model/optimizer parameters.

    Returns:
        ``(resume_epoch, model, optimizer, train_loss_history)``.
    """
    checkpoint = torch.load(resume_file, map_location=map_location)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['epoch'], model, optimizer, checkpoint['train_loss_history']


In [None]:
from torch.optim.lr_scheduler import StepLR
import torchvision

# Fine-tune the pre-trained model with the physics-guided loss.
# Use the same device expression as the setup cell. forward_rnn, wavelet and vp_tensor live
# there, and mixing tensors from two GPUs (the old "cuda:1") raises an error.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data = torch.load("./shots_overthrust_scale1024_big4Hz.npz").to(device).type(dtype=torch.float32)

model = MyModel().to(device)
optimizer = torch.optim.Adam([{'params': model.parameters()}], 
                               lr=0.0001, weight_decay=0.0001)
# Start from the pre-trained weights; this Adam optimizer only receives the saved state.
_, model, _, _ = load_state(model, optimizer, "./checkpoints-VAE-overthrust-G-scale/checkpoint-VAE-pre-overthrust-1-1000.pth") 
model = model.to(device)
optimizer = torch.optim.AdamW([{'params': model.parameters()}], 
                               lr=0.00001, betas=(0.9, 0.999), weight_decay=0.0001)

# Perceptual feature extractor: the first 13 modules of VGG16.features, i.e. conv1_1 ... conv3_2.
# The output is taken before the ReLU that follows conv3_2, not from conv5.
# `weights=` replaces the deprecated `pretrained=True` (same ImageNet weights).
# Frozen: gradients still flow through it to the simulated shots, but its own weights never accumulate .grad.
model_vgg = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1).features[:13].to(device)
model_vgg.requires_grad_(False)

max_epoch = 3000
train_loss_history = [[], [], [], [], []]   # total, pixel, perceptual, KLD, vmodel MSE
print_interval = 1
save_interval = 100
for epoch in range(0, max_epoch):

    train_loss, v_pred, s_pred, mu, logvar = train(model, data, wavelet, vp_tensor, optimizer, device)
    for i in range(5):
        train_loss_history[i].append(train_loss[i]) 

    # Checkpoint at epoch 1, every 10 epochs up to 100, then every 100 epochs.
    if epoch == 0 or (epoch < 100 and (epoch + 1) % 10 == 0) or (epoch >= 100 and (epoch + 1) % 100 == 0):
        save_state(epoch, model, optimizer, train_loss_history)

    if (epoch + 1) % print_interval == 0:
        print("Epoch: {}, Training Loss: {:.8f}, recon_pixel Loss: {:.8f}, recon_perceptual Loss: {:.8f}, KLD Loss: {:.8f}, recon_vmodel Loss: {:.8f}".format(epoch, train_loss[0], train_loss[1], train_loss[2], train_loss[3], train_loss[4]))

train_loss_history = torch.tensor(train_loss_history)

In [None]:
# torch.as_tensor accepts both the raw list of lists and an already-converted tensor,
# so this cell can be re-run. torch.tensor on a tensor warns and copies.
train_loss_history = torch.as_tensor(train_loss_history)
train_loss_history.shape

In [None]:
loss_hist = torch.as_tensor(train_loss_history)
# Plot up to 3000 epochs, but never more than were recorded (np.arange(0, n) must match the slice).
n = min(3000, loss_hist.shape[1])
plt.figure(figsize=(15, 5))
plt.subplot(1, 1, 1)
# plt.plot(np.arange(0,n), loss_hist[0,0:n].numpy(), label="all")
plt.plot(np.arange(0,n), loss_hist[1,0:n].numpy(), label="recon_pixel")
plt.plot(np.arange(0,n), loss_hist[2,0:n].numpy(), label="recon_percep")
plt.plot(np.arange(0,n), loss_hist[3,0:n].numpy(), label="KLD")
plt.legend()


In [None]:
# KLD history over the first 300 epochs.
n=300
plt.figure(figsize=(15, 5))

plt.subplot(1, 1, 1)
# NOTE(review): this plots log(KLD)/1000 while the legend says just "KLD"; the /1000
# scaling looks arbitrary — confirm intent.
plt.plot(np.arange(0,n), np.log(train_loss_history[3,0:n].numpy())/1000, label="KLD")
plt.legend()


In [None]:
loss_hist = torch.as_tensor(train_loss_history)
# Plot up to 3000 epochs, but never more than were recorded (np.arange(0, n) must match the slice).
n = min(3000, loss_hist.shape[1])
plt.figure(figsize=(15, 5))
plt.subplot(1, 1, 1)
plt.plot(np.arange(0,n), loss_hist[4,0:n].numpy(), label="recon_vp")

plt.legend()


In [None]:
fig = plt.figure(figsize=(20, 15))
gs = fig.add_gridspec(3, 2)

# CPU copies so the differences below work regardless of which device each tensor is on.
v_pred_cpu = v_pred[0].detach().cpu()
vp_cpu = vp_tensor[0].detach().cpu()
vi_cpu = vi_tensor[0].detach().cpu()
# Colour range of the true model, shared by the velocity panels.
vp_min = (vp_cpu/1000).min().item()
vp_max = (vp_cpu/1000).max().item()

ax = fig.add_subplot(gs[0, 0])
ax.set(ylabel="Depth z[km]", xlabel="Distance x[km]", title="v_pred")
im = ax.imshow(v_pred_cpu.numpy()/1000, extent=[0, nx*dz/1000, nz*dz/1000, 0], aspect=1, cmap='RdBu_r')
cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='km/s')
im.set_clim(vp_min, vp_max)

ax = fig.add_subplot(gs[1, 1])
ax.set(ylabel="Depth z[km]", xlabel="Distance x[km]", title="vp")
im = ax.imshow(vp_cpu.numpy()/1000, extent=[0, nx*dz/1000, nz*dz/1000, 0], aspect=1, cmap='RdBu_r')
cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='km/s')
im.set_clim(vp_min, vp_max)

ax = fig.add_subplot(gs[1, 0])
ax.set(ylabel="Depth z[km]", xlabel="Distance x[km]", title="vi")
im = ax.imshow(vi_cpu.numpy()/1000, extent=[0, nx*dz/1000, nz*dz/1000, 0], aspect=1, cmap='RdBu_r')
cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='km/s')
im.set_clim(vp_min, vp_max)

# Difference panels use a symmetric colour range scaled by the largest |difference|.
# The signed maximum used before clipped large negative differences.
ax = fig.add_subplot(gs[2, 1])
ax.set(ylabel="Depth z[km]", xlabel="Distance x[km]", title="v_pred-vp")
diff_vp = (v_pred_cpu - vp_cpu)/1000
im = ax.imshow(diff_vp.numpy(), extent=[0, nx*dz/1000, nz*dz/1000, 0], aspect=1, cmap='RdBu_r')
cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='km/s')
err_max = diff_vp.abs().max().item()
im.set_clim(-err_max, err_max)


ax = fig.add_subplot(gs[2, 0])
ax.set(ylabel="Depth z[km]", xlabel="Distance x[km]", title="v_pred-vi")
diff_vi = (v_pred_cpu - vi_cpu)/1000
im = ax.imshow(diff_vi.numpy(), extent=[0, nx*dz/1000, nz*dz/1000, 0], aspect=1, cmap='RdBu_r')
cbar = add_colorbar(ax, im, ax.transAxes, width="3%", ctitle='km/s')
err_max = diff_vi.abs().max().item()
im.set_clim(-err_max, err_max)



In [None]:
from torch.optim.lr_scheduler import StepLR
import torchvision

# Reload the fine-tuned model for inference.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): training used "./shots_overthrust_scale1024_big4Hz.npz"; this cell loads a
# different file. Confirm that evaluating on this data set is intended.
data = torch.load("./shots_overthrust_scale1024.npz").to(device).type(dtype=torch.float32)  # presumably (1, ns, nt, nx)

model = MyModel().to(device)
optimizer = torch.optim.AdamW([{'params': model.parameters()}], 
                               lr=0.00001, betas=(0.9, 0.999), weight_decay=0.0001)

_, model, _, train_loss_history = load_state(model, optimizer, "./checkpoints-VAE-overthrust-G-scale/checkpoint-VAE-overthrust-1-3000.pth") 
# First 13 modules of VGG16.features (conv1_1 ... conv3_2, before its ReLU); only `train` uses it.
# `weights=` replaces the deprecated `pretrained=True` (same ImageNet weights).
model_vgg = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1).features[:13].to(device)
model = model.to(device)


In [None]:
def test(model, data, device='cpu'):
    model.eval()
    data_resample = data[:, :, :, :]   #输入的是resample后的
    data_norm = (data_resample-data.mean())/(data.std()).to(device)
    with torch.no_grad():
        v_recon_norm, mu, logvar= model(data_norm)  #输出范围[-1,1]
        v_recon = (v_recon_norm)*(800) + 3800  
    return v_recon[0], mu, logvar 


In [None]:
# Draw `samples` velocity predictions. Every call runs the full model, and
# MyModel.reparameterize draws fresh noise each time (also in eval mode), so the
# samples differ and their spread is used as an uncertainty estimate below.
samples = 500
v_samp = torch.zeros((samples, 128, 256))   # (samples, nz, nx), CPU
for s in range(samples):
    v_pred, _, _ = test(model, data, device)
    v_samp[s] = v_pred   # v_pred is (1, 128, 256); presumably its leading unit dim is dropped on assignment


In [None]:
fig = plt.figure(figsize=(21, 9.5))
gs = fig.add_gridspec(2, 2)
fontsize=16

# Sample statistics (m/s), computed once.
v_mean = torch.mean(v_samp, axis=0)
v_std = torch.std(v_samp, axis=0)
abs_err = torch.abs(vp_tensor[0].detach().cpu() - v_mean)
# Colour range of the true model, shared by the velocity panels.
vp_min = (vp_tensor[0]/1000).min().item()
vp_max = (vp_tensor[0]/1000).max().item()


def _show_panel(ax, img_kms, title, vmin, vmax):
    """Draw one km/s image with a titled colourbar, fixed colour limits and large labels."""
    im = ax.imshow(img_kms, extent=[0, nx*dz/1000, nz*dz/1000, 0], aspect=1, cmap='RdBu_r')
    cbar = add_colorbar(ax, im, ax.transAxes, width="3%")
    cbar.ax.set_title('km/s', size=fontsize)
    cbar.ax.tick_params(labelsize=fontsize)
    im.set_clim(vmin, vmax)
    ax.set_xlabel("Distance x[km]", fontsize=fontsize)
    ax.set_title(title, fontsize=fontsize)
    ax.set_ylabel("Depth z[km]",fontsize=fontsize)
    ax.tick_params(axis='both', which='major', labelsize=fontsize)
    return im, cbar


_show_panel(fig.add_subplot(gs[0, 0]), v_mean.numpy()/1000, 'v_pred_mean', vp_min, vp_max)
_show_panel(fig.add_subplot(gs[0, 1]), v_std.numpy()/1000, 'v_pred_std',
            (v_std/1000).min().item(), (v_std/1000).max().item())
_show_panel(fig.add_subplot(gs[1, 0]), vp_tensor[0].detach().cpu().numpy()/1000, 'vp', vp_min, vp_max)
# |vp - mean| is non-negative, so its colour range runs from 0 to its own maximum.
# The signed maximum used before clipped large negative errors.
_show_panel(fig.add_subplot(gs[1, 1]), abs_err.numpy()/1000, 'vp-v_pred_mean',
            0, (abs_err.max()/1000).item())


In [None]:
v_samp_file = './Figures/v_samp/v_samp_add/v_samp_VAE_over_scale4Hz.pth'
os.makedirs(os.path.dirname(v_samp_file), exist_ok=True)   # torch.save does not create directories
torch.save(v_samp, v_samp_file)
v_samp= torch.load(v_samp_file).cpu()