In [None]:
import os 

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
import fwi
# import fwiLBFGS as fwi
import torch
import time 
import matplotlib.pylab  as plt
from util import * 
import deepwave
from scipy import signal
from skimage.transform import resize, rescale 
from  scipy.ndimage import gaussian_filter


# =================================================================== #

In [None]:
# ========================== Main  ============================== #

# Use the (single) visible GPU when present; fall back to CPU so the notebook still runs without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# ============================ setting parameters =============================#

# Model and acquisition parameters
par = {
    # model grid
    'nx': 1685, 'dx': 0.02, 'ox': 0,
    'nz': 201,  'dz': 0.02, 'oz': 0,
    # sources (alternatives tried: ns=400/ds=0.0825/osou=0, ns=200/ds=0.165/osou=0, ns=1/ds=1/osou=16.85)
    'ns': 360, 'ds': 0.055, 'osou': 7, 'sz': 0.06,
    # receivers
    'nr': 842, 'dr': 0.04, 'orec': 0, 'rz': 0.06,
    # time axis (alternatives tried: nt=1250/dt=0.004, nt=625/dt=0.008)
    'nt': 4000, 'dt': 0.002, 'ot': 0,
    'freq': 10,      # dominant frequency, used for the Ricker source wavelet
    'FWI_itr': 500,
    'num_dims': 2,
}

# Presumably the maximum source-receiver offset used by fwi (TODO confirm); the commented lines
# derive a receiver line from it instead of the fixed nr/orec above.
par['mxoffset'] = 6
# par['mxoffset']= 'full'
# par['nr'] = int((2 * par['mxoffset'])//(par['dr'])) + 1
# par['orec'] = par['osou'] - par['mxoffset']

fs = 1 / par['dt']  # sampling frequency (Hz)

par['batch_size'] = 1   # shots per batch (4 was also tried)
# Floor division below would silently drop the trailing shots if ns were not a multiple of batch_size.
assert par['ns'] % par['batch_size'] == 0, "ns must be a multiple of batch_size"
par['num_batches'] = par['ns'] // par['batch_size']

# Don't change: the notebook assumes a single source per shot.
num_sources_per_shot = 1

# Expose every par entry as a notebook-level variable (nx, ns, dt, ...). Updating globals() is the
# supported way to do this; writing into locals() only happens to work at module scope.
globals().update(par)

In [None]:
# ============================ I/O =============================#

# --- Hyper-parameters of this FWI pass (all of them are encoded in the output file name) ---
fwi_pass = 1
minF, maxF = 3, 7            # frequency band (Hz) recorded in the output name as MinFreq/MaxFre
TV_FLAG = True
TV_ALPHA = 0.1               # weight handed to run_inversion as alphatv
smth1, smth2 = 1e-7, 1e-7    # smoothing values handed to run_inversion as smth
opt = 'Adam'
# A previous pass used: TV_ALPHA=0.01, smth1=2, smth2=5 (minF=3, maxF=7, fwi_pass=1).

# --- Locations, relative to the notebook's working directory ---
path = './'
velocity_path = './velocity/'
data_path = './data/'

vel_true = f'{velocity_path}bp_full_fixed.npy'   # true velocity model
data_file = f'{data_path}data{ns}.npy'
wavel_file = f'{data_path}wavel.npy'

inv_file=f"BPfull_1stinv_TV{TV_ALPHA}_offs{par['mxoffset']}_DomFreq{par['freq']}_MinFreq{minF}_MaxFre{maxF}_fwi{fwi_pass}_smth{smth1}-{smth2}_nt{par['nt']}-ns{par['ns']}-opt{opt}-1Dtemp"
output_file = f'{velocity_path}{inv_file}'

# Load the true model
mtrue = np.load(vel_true)




In [None]:
# ============================ Forward modelling =============================#
# Convert the true model to a float32 tensor. as_tensor also accepts an existing tensor, so
# re-running the cell does not trigger torch's copy-construct warning.
mtrue = torch.as_tensor(mtrue, dtype=torch.float32)

# Make the model laterally homogeneous ("1D"): replicate the velocity profile at x-index
# `profile_ix` across all nx columns. Torch ops keep mtrue an explicit tensor; np.repeat on a
# tensor only worked through NumPy's fallback path, because Tensor.repeat rejects `axis=`.
profile_ix = 150
mtrue = mtrue[:, profile_ix].reshape(nz, 1).repeat(1, nx)

# initiate the fwi class
inversion = fwi.fwi(par, 2)

# Optional: clamp receiver x-coordinates into the model
# xr_corr = inversion.r_cor[:,:,1]
# xr_corr [xr_corr < 0 ] = 0
# xr_corr [xr_corr > (nx-1)*dx ] = (nx-1)*dx  # last point in the model
# inversion.r_cor[:,:,1] =  xr_corr




In [None]:
# Ricker source wavelet at the dominant frequency, replicated for every shot, then forward-model the
# observed data and bring it back to host memory. (No pre-allocation: the result is returned directly.)
wavel = inversion.Ricker(freq)
data = inversion.forward_modelling(mtrue, wavel.repeat(1, ns, num_sources_per_shot), device).cpu()

# Release cached GPU memory left over from modelling (no-op when CUDA was never used)
torch.cuda.empty_cache()


In [None]:
# import pickle
# np.save(data_file,data)
# np.save(wavel_file,wavel)
# with open(f'{data_file}.pickle', 'wb') as f:
#     pickle.dump(par, f, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:

## Load the saved data 
# data = np.load(data_file)
# wavel = np.load(wavel_file)


In [None]:
# Quick look at one modelled shot gather and at the source wavelet.
shot = 0
plt.figure(figsize=(5,10))
# Symmetric colour clip at the magnitude of the 2nd percentile of this shot's amplitudes
clip = abs(np.percentile(data[:, shot], 2))
plt.imshow(data[:, shot, :], cmap='seismic', vmin=-clip, vmax=clip,
           extent=[inversion.r_cor[shot, 0, 1].numpy(), inversion.r_cor[shot, nr-1, 1].numpy(),
                   par['nt']*par['dt'], par['ot']])
plt.axis('tight')
plt.xlabel('Position (Km)', weight='heavy')
plt.ylabel('Time (s)', weight='heavy')
# float() turns the 0-d coordinate into a plain number so the title does not read "tensor(...)"
plt.title(f'shot at {float(inversion.s_cor[shot, 0, 1])}')
plt.gca().axes.get_xaxis().set_label_position('top')
plt.gca().axes.get_xaxis().tick_top()

# One time value per wavelet sample: ot is a start *time*, not a start index.
t_axis = par['ot'] + np.arange(par['nt']) * par['dt']
plt.figure(figsize=(8,2))
plt.plot(t_axis, wavel[:, 0, 0], linewidth=2, color='r')


In [None]:


# Band-pass wavelet and data to [minF, maxF] Hz, the band configured in the I/O cell. Using the
# variables (not literals) keeps the filter in sync with the band written into `inv_file`.
wavel_f = freq_filter(freq=[minF, maxF], wavelet=wavel, btype='bp', fs=fs)
data_f = freq_filter(freq=[minF, maxF], wavelet=data, btype='bp', fs=fs)



In [None]:
# ======================= Plotting shot gathers and frequency spectra ============================= #
os.makedirs('./Fig', exist_ok=True)  # plt.savefig does not create missing directories


def _plot_gather(gather, fig_name):
    """Display one (time x receiver) shot gather and save it as ./Fig/<fig_name>.

    The colour scale is symmetric and clipped at the magnitude of the 2nd percentile of *this*
    gather, so raw and band-passed panels are each scaled to their own amplitudes.
    """
    clip = abs(np.percentile(gather, 2))
    plt.figure(figsize=(5,10))
    plt.imshow(gather, cmap='seismic', vmin=-clip, vmax=clip,
               extent=[0, par['dr']*par['nr'], par['nt']*par['dt'], par['ot']])
    plt.axis('tight')
    plt.xlabel('Offset (Km)', weight='heavy')
    plt.ylabel('Time (s)', weight='heavy')
    plt.gca().axes.get_xaxis().set_label_position('top')
    plt.gca().axes.get_xaxis().tick_top()
    plt.savefig(f'./Fig/{fig_name}', bbox_inches='tight')


# Raw and band-passed versions of the middle shot
_plot_gather(data[:, ns//2, :], 'shot')
_plot_gather(data_f[:, ns//2, :], 'shot_filter')

# Amplitude spectrum of every receiver trace of the first band-passed shot, all drawn on one
# scratch figure, followed by their average on a clean figure.
plt.figure()
rec_spectra = []
for irec in range(data_f.shape[2]):
    spec, freqs = plt.magnitude_spectrum(data_f[:, 0, irec].numpy().flatten(), Fs=fs, color='r')[:2]
    rec_spectra.append(spec)
plt.xlim([0,30])
mean_spec = np.mean(rec_spectra, axis=0)

plt.figure(figsize=(10,3))
plt.plot(freqs, mean_spec, color='r')
plt.fill_between(freqs, mean_spec, color='r')
plt.xlim([0,30])
plt.xlabel('Frequency (Hz)', weight='heavy')
plt.ylabel('Amplitude', weight='heavy')
plt.grid(which='both')
plt.savefig('./Fig/shot_spectrum', bbox_inches='tight')

# Band-passed source wavelet and its amplitude spectrum
wav = wavel_f.numpy()[:, -1]
plt.figure(figsize=(10,3))
plt.plot(np.arange(wav.shape[0]) * par['dt'], wav, color='k')
plt.xlabel('Time (s)', weight='heavy')
plt.ylabel('Amplitude', weight='heavy')
plt.savefig('./Fig/wavelet', bbox_inches='tight')

plt.figure()
wav_spec, wav_freqs = plt.magnitude_spectrum(wav.flatten(), Fs=fs)[:2]

plt.figure(figsize=(10,3))
plt.plot(wav_freqs, wav_spec, color='r')
plt.fill_between(wav_freqs, wav_spec, color='r')
plt.xlim([0,30])
plt.xlabel('Frequency (Hz)', weight='heavy')
plt.ylabel('Amplitude', weight='heavy')
plt.grid(which='both')
plt.savefig('./Fig/wavelet_frequency', bbox_inches='tight')

## Resample


In [None]:
# wavel_res = wavel_f[::4,]
# data_res = data_f [::4,]
# assert data_res.shape[0] ==wavel_res.shape[0], "shape mismatch in the nt for data and wavelet"

# # par['nt-old'] = par['nt']
# par['nt-old'] = 2500
# par['nt'] = data_res.shape[0]
# par['dt'] = (nt*dt)/par['nt']

# # update the variables
# # Mapping the par dictionary to variables 
# for k in par:
#     locals()[k] = par[k]
    
# inversion.nt = nt 
# inversion.dt = dt 
    
# print(wavel_res.shape, data_res.shape)

In [None]:
# fig, ax = plt.subplots(1,2,figsize=(8,4))

# ax[0].imshow(data_f[:,int(ns//2),:],cmap='seismic',vmin=vmin,vmax=-vmin,extent=[0,par['dr']*par['nr'],par['nt']*par['dt'],par['ot']])
# ax[0].axis('tight')
# ax[0].set_xlabel('Offset (Km)',weight='heavy')
# ax[0].set_ylabel('Time (s)',weight='heavy')
# ax[1].imshow(data_res[:,int(ns//2),:],cmap='seismic',vmin=vmin,vmax=-vmin,extent=[0,par['dr']*par['nr'],par['nt']*par['dt'],par['ot']])
# ax[1].axis('tight')
# ax[1].set_xlabel('Offset (Km)',weight='heavy')
# ax[1].set_ylabel('Time (s)',weight='heavy')


# fig, ax = plt.subplots(1,2,figsize=(13,1))
# ax[0].plot(wavel_f[:,0,0])
# ax[1].plot(wavel_res[:,0,0])




In [None]:
# ================ Plotting =============== #
# True model, then the band-passed middle shot
print(mtrue.shape)
Plot_model(mtrue, par)

print(data.shape)
gather_f = data_f[:, ns//2, :]
# Clip from the filtered gather itself: raw-data percentiles mis-scale the band-passed amplitudes.
clip = abs(np.percentile(gather_f, 2))
plt.figure(figsize=(5,10))
plt.imshow(gather_f, cmap='seismic', vmin=-clip, vmax=clip,
           extent=[0, par['dr']*par['nr'], par['nt']*par['dt'], par['ot']])
plt.axis('tight')
plt.xlabel('Offset (Km)', weight='heavy')
plt.ylabel('Time (s)', weight='heavy')
plt.gca().axes.get_xaxis().set_label_position('top')
plt.gca().axes.get_xaxis().tick_top()




In [None]:
# ========================= Create initial model =================== #
# Mask from util.mask with threshold 1.5 (presumably >0 below the water layer, 0 inside it;
# TODO confirm against util.mask).
msk = mask(mtrue.numpy(), 1.5)

# Starting-model choice:
#   'constant' - in each column, copy the true velocity found at the first masked (>0) cell all the
#                way down; cells above it (and fully unmasked columns) get 1.5.
#   'mean'     - laterally averaged true model times the mask; zeros replaced by 1.5.
init_mode = 'constant'

mtrue_np = mtrue.numpy()
if init_mode == 'mean':
    bp_mean = np.nanmean(mtrue_np, axis=1).reshape(-1, 1)
    minit = np.repeat(bp_mean, nx, axis=1) * msk
elif init_mode == 'constant':
    minit = msk.astype(np.float32)  # astype returns a copy, so msk is untouched
    for ix in range(nx):
        below = np.flatnonzero(minit[:, ix] > 0)
        if below.size == 0:
            continue  # no masked cell in this column: leave it to the 1.5 fill below
        iz = below[0]
        minit[iz:, ix] = mtrue_np[iz, ix]
else:
    raise ValueError(f"unknown init_mode: {init_mode!r}")
minit[minit == 0] = 1.5

# Alternatives tried earlier: gaussian_filter(mtrue, sigma=10), or a previous inversion result
# loaded from ./velocity/ (e.g. np.load(<result>.npy)[-1]).

Plot_model(minit, par)

In [None]:

# Convert the inversion inputs to float32 tensors. data_f and wavel_f are already tensors (.numpy() is
# called on them when plotting spectra); torch.tensor(<tensor>) would warn and copy, while
# torch.as_tensor accepts both NumPy arrays and tensors.
minit = torch.as_tensor(minit, dtype=torch.float32)
data_f = torch.as_tensor(data_f, dtype=torch.float32)
wavel_f = torch.as_tensor(wavel_f, dtype=torch.float32)


In [None]:
# Quick-look of band-passed shot 20 on a fixed, symmetric grey scale
qc_shot = 20
qc_clip = 1e-7
plt.imshow(data_f[:, qc_shot, :], cmap='gray', vmin=-qc_clip, vmax=qc_clip)
plt.colorbar()
plt.axis('tight')

In [None]:
# ======================= Saving data and wavelet ============================= #
# # %%
# data_save = data.clone().permute(1,2,0).numpy()
# save_3ddata(data_save,par,'./data_mod.rsf')
# save_1drsf(wavel.clone().numpy(),par,'./wavl.rsf')

In [None]:


# Run FWI from `minit` against the band-passed data. The band-passed wavelet is replicated per shot
# with num_sources_per_shot sources, the same shaping used for forward modelling.
# vmin/vmax presumably bound the velocity during the update (TODO confirm in fwi.run_inversion).
minv, loss = inversion.run_inversion(
    minit, data_f, wavel_f.repeat(1, ns, num_sources_per_shot),
    msk, FWI_itr, device,
    smth_flag=True, smth=[smth1, smth2],
    vmin=1.5, vmax=4.5,
    tv_flag=TV_FLAG, alphatv=TV_ALPHA,
    plot_flag=True, Method="")

In [None]:
# Difference between the first stored inversion iterate and the starting model, with receivers (red
# dots) and sources (blue crosses) overlaid. Assumes minv is indexed (iterate, z, x); TODO confirm.
dm = minv[0,] - minit.numpy()
plt.figure(figsize=(10,3))
plt.imshow(dm, cmap='seismic',
           extent=[par['ox'], par['dx']*par['nx'], par['nz']*par['dz'], par['oz']])
plt.colorbar()
plt.scatter(inversion.r_cor[:,:,1], inversion.r_cor[:,:,0], marker='.', s=0.3, c='r')
plt.scatter(inversion.s_cor[:,:,1], inversion.s_cor[:,:,0], marker='x', s=50, c='b')
plt.axis('tight')
plt.xlabel('Position (Km)', weight='heavy')
plt.ylabel('Depth (Km)', weight='heavy')
plt.title('minv[0] - minit')

# Vertical profile of the same difference at one x-index
prof_ix = 800
plt.figure()
plt.plot(dm[:, prof_ix])
plt.xlabel('Depth sample')
plt.title(f'minv[0] - minit at x-index {prof_ix}')


In [None]:
# Objective-function history returned by run_inversion (assumed one value per iteration; TODO confirm).
# Saved next to the other figures instead of as a bare 'loss' file in the working directory.
os.makedirs('./Fig', exist_ok=True)
plt.figure()
plt.plot(loss)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.savefig('./Fig/loss', bbox_inches='tight')

In [None]:
# Final inverted model on a fixed 1.5-3 colour range
Plot_model(minv[-1,], par, vmin=1.5, vmax=3)
# Plot_model(minv[-1,],par)

# Inverted vs true vertical profile at one x-index. A new figure keeps the lines off whatever
# figure Plot_model may have left current; the legend tells the two curves apart.
prof_ix = 1000
plt.figure()
plt.plot(minv[-1, :, prof_ix], label='inverted')
plt.plot(mtrue[:, prof_ix], label='true')
plt.xlabel('Depth sample')
plt.legend()


In [None]:
# save_2drsf(minv[-1,].T,par,output_file)
# np.save(output_file,(minv))

In [None]:
# Independent copy of the last inverted model (presumably to seed a further FWI pass)
minit2 = np.copy(minv[-1])