# Vision Transformer learning DAS

In [None]:
import gc
import time
import h5py
import torch
import numpy as np
import torch.nn as nn

from functools import partial
from matplotlib import pyplot as plt
from numpy.random import default_rng
from scipy.signal import filtfilt, butter
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

## modules in the package
from das_util import try_gpu
from das_denoise_models import dataflow_nomask
from maevit_model_train import train_augmentation
from maevit_model_train import MaskedAutoencoderViT


## Prepare data for training

In [None]:
# %% Constants reused by the later cells
sample_rate = 25   # handed to butter() below as the sampling frequency `fs`
delta_space = 10   # channel spacing; the plotting cell divides by 1000 to label km
ch_num = 7500      # number of leading entries kept along axis 1 of each array

# %% Read the full 'quake' dataset of both HDF5 files into memory
data_terra = '/fd1/QibinShi_data/akdas/qibin_data/TERRAtill2023_07_29.hdf5'
data_kkfls = '/fd1/QibinShi_data/akdas/qibin_data/KKFLStill2023_07_29.hdf5'
with h5py.File(data_terra, 'r') as f:
    quake1 = f['quake'][:]
with h5py.File(data_kkfls, 'r') as f:
    quake2 = f['quake'][:]

# Stack along axis 0 (KKFLS first, then TERRA), truncated to ch_num on axis 1
tmp = np.concatenate((quake2[:, :ch_num, :], quake1[:, :ch_num, :]), axis=0)
print(tmp.shape)

# %% divide time into 3 windows for smaller image size
# The reshape only succeeds when the last axis holds exactly 3 * 500 samples.
# Moving the window axis next to axis 0 turns every input row into three
# consecutive (ch_num, 500) images.
data = (
    tmp.reshape(tmp.shape[0], ch_num, 3, 500)
    .swapaxes(1, 2)
    .reshape(-1, ch_num, 500)
)
print(data.shape)

# %% Pre-filter to suppress strong long-period vibration
# 4th-order Butterworth band-pass (0.5-12, same unit as sample_rate), applied
# forward and backward along the last axis by filtfilt.
b, a = butter(4, (0.5, 12), btype='bandpass', fs=sample_rate)
filt = filtfilt(b, a, data, axis=2)

# Scale every image by its own standard deviation
rawdata = filt / filt.std(axis=(1, 2), keepdims=True)

In [None]:
# %% visualize data
# Draw one pre-processed window as an image (rows on y, samples on x).
time_data = rawdata[5]
n_rows, n_samples = time_data.shape

# Colour scale is clipped symmetrically at the median absolute amplitude.
max_amp = np.median(np.fabs(time_data))

# plt.get_cmap is available on old and new Matplotlib alike, whereas
# plt.cm.get_cmap was deprecated in 3.7 and removed in 3.9.
cmap = plt.get_cmap('RdBu')

plt.figure(figsize=(20, 6))
x, y = np.arange(n_samples), np.arange(n_rows)
plt.pcolormesh(x, y, time_data, shading='auto', vmin=-max_amp, vmax=max_amp, cmap=cmap)

# One tick every 5 s, derived from the actual window length. Hard-coding the
# ticks for 1500 samples would stretch the x axis far beyond the data, because
# setting ticks expands the view limits to include them.
xtick_pos = np.arange(0, n_samples + 1, 5 * sample_rate)
plt.xticks(xtick_pos, (xtick_pos / sample_rate).astype(int))

# One tick every 1000 channels, labelled with channel index * spacing / 1000.
ytick_pos = np.arange(0, ch_num + 1, 1000)
plt.yticks(ytick_pos, (ytick_pos * delta_space / 1000).astype(int))

plt.xlabel("Time (s)", fontsize=20)
plt.ylabel("X (km)", fontsize=20)
cbr = plt.colorbar()
cbr.set_label('amplitude', fontsize=20)

In [None]:
# %% Shuffle and split dataset
# 70 % goes to training; the remaining 30 % is halved into validation and test.
# Only one array is split: passing `rawdata` twice would materialise every
# subset a second time, and those duplicates were never used. With the same
# random_state the partition of the single array is identical.
X_tr, X = train_test_split(rawdata, train_size=0.7, random_state=111)
X_va, X_te = train_test_split(X, train_size=0.5, random_state=121)

# Wrap each subset in the project's dataset class (same windowing arguments)
training_data = dataflow_nomask(X_tr, Nx_sub=500, stride=250)
validation_data = dataflow_nomask(X_va, Nx_sub=500, stride=250)
test_data = dataflow_nomask(X_te, Nx_sub=500, stride=250)

# Drop the notebook-level references to the large intermediate arrays
del data, filt, rawdata, tmp, X, X_tr, X_va, X_te
gc.collect()

## Initialize the vision transformer

In [None]:
""" Initialize the masked-autoencoder vision transformer (MaskedAutoencoderViT) """
# Encoder and decoder share the same width (2048), depth (16) and head count (16);
# 500-sample-wide single-channel images are cut into 25-sample patches.
model = MaskedAutoencoderViT(
    img_size=500, patch_size=25, in_chans=1,
    embed_dim=2048, depth=16, num_heads=16,
    decoder_embed_dim=2048, decoder_depth=16, decoder_num_heads=16,
    mlp_ratio=4., norm_layer=partial(nn.LayerNorm, eps=1e-6))

# NOTE(review): try_gpu(i=1) presumably returns the device for GPU index 1,
# which must match the first entry of gpu_ids below -- confirm in das_util.
devc = try_gpu(i=1)

# Replicate over GPUs 1-3 only when those devices exist, instead of relying on
# the user to comment the line out on smaller machines. On hosts with fewer
# GPUs the plain model is simply moved to `devc`.
gpu_ids = [1, 2, 3]
if torch.cuda.device_count() > max(gpu_ids):
    model = nn.DataParallel(model, device_ids=gpu_ids)
model.to(devc)

In [None]:
# %% Hyper-parameters for training
batch_size = 64
lr = 1e-4

# Hand AdamW only the parameters that still require gradients
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=lr)

# NOTE(review): both loaders keep the dataset order (shuffle=False); confirm
# that a fixed batch order is intended for the training loader as well.
loader_kwargs = dict(batch_size=batch_size, shuffle=False)
train_iter = DataLoader(training_data, **loader_kwargs)
validate_iter = DataLoader(validation_data, **loader_kwargs)

# %% Training
# Returns the model together with the training and validation loss histories.
# epochs / patience / minimum_epochs are forwarded as-is (presumably an
# early-stopping schedule -- see train_augmentation for the exact semantics).
model, avg_train_losses, avg_valid_losses = train_augmentation(
    train_iter,
    validate_iter,
    model,
    optimizer=optimizer,
    epochs=250,
    patience=20,
    device=devc,
    minimum_epochs=50,
)