In [3]:
import os
import gc
import glob
import h5py
import torch
import scipy
import numpy as np
import torch.nn as nn
from das_util import try_gpu
from matplotlib import pyplot as plt
from numpy.random import default_rng
from scipy.signal import filtfilt, butter
from scipy.signal.windows import tukey
from torch.utils.data import DataLoader
from das_denoise_models import unet, dataflow, datalabel
from das_denoise_training import train_augmentation
from sklearn.model_selection import train_test_split

""" Read a H5 data """
class read_one_h5(torch.utils.data.Dataset):
    """Map-style dataset that yields one random masked DAS patch per HDF5 file.

    Each item is read from ``Acquisition/Raw[0]/RawData`` of one file, cropped
    to ``Nx_sub`` consecutive channels, band-pass filtered, scaled to unit
    standard deviation and paired with a random per-channel mask.

    The class is consumed by ``DataLoader``, so it derives from
    ``torch.utils.data.Dataset`` (it holds no parameters or sub-modules).
    """

    def __init__(self, h5_directory, Nx_sub=1280, Nt=1280, max_ch=5000, mask_ratio=0.5,
                 sample_rate=None):
        """
        Args:
            h5_directory (string): Path to the folder containing all the h5 files.
            Nx_sub: number of consecutive rows (DAS channels) to extract as a sub-image
            Nt: number of columns (time points) to extract
            max_ch: upper bound of channels when extracting the sub-image
            mask_ratio: fraction (0..1) of the Nx_sub channels to mask
            sample_rate: sampling rate handed to the band-pass filter as ``fs``.
                When None (default) the module-level ``sample_rate`` is looked
                up at item-access time, which is what the original code did.
        """
        # Sort so that index -> file is reproducible; glob order is
        # filesystem-dependent.
        self.h5_filepaths = sorted(glob.glob(os.path.join(h5_directory, '*.h5')))
        self.Nx_sub = Nx_sub
        self.Nt = Nt
        self.max_ch = max_ch
        self.mask_ratio = mask_ratio
        self.sample_rate = sample_rate

    def __len__(self):
        """Return the number of ``*.h5`` files found in the directory."""
        return len(self.h5_filepaths)

    def __getitem__(self, idx):
        """Build one training item from the ``idx``-th file.

        Returns:
            ``((sample, mask), target)`` where all three are float32 arrays of
            shape ``(Nx_sub, Nt)``. ``mask`` is 0 on masked channels and 1
            elsewhere; ``target`` is ``sample * (1 - mask)``, i.e. the sample
            on the masked channels and zero on the kept ones.

        Raises:
            ValueError: if ``Nx_sub > max_ch`` or the file provides fewer than
                ``max_ch`` channels after the first 100.
            IndexError: if ``idx`` is outside the file list.
        """
        Nx_sub = self.Nx_sub
        Nt = self.Nt
        max_ch = self.max_ch

        # Explicit raises instead of assert: asserts vanish under `python -O`.
        # This check needs no data, so do it before any file I/O.
        if Nx_sub > max_ch:
            raise ValueError("max_ch is smaller than the #ch needed for the sample, reduce Nx_sub!")

        # Read up to Nt time rows and max_ch channel columns, skipping the
        # first 100 channels of the file.
        with h5py.File(self.h5_filepaths[idx], 'r') as f:
            time_data = f['Acquisition']['Raw[0]']['RawData'][:Nt, 100:(100 + max_ch)]
        if time_data.shape[1] != max_ch:
            raise ValueError("Not enough #ch to read, reduce max_ch!")

        # randint's upper bound is exclusive: the +1 makes the last valid
        # offset reachable and keeps Nx_sub == max_ch from raising.
        st_ch = np.random.randint(low=0, high=max_ch - Nx_sub + 1)

        # Channels become rows; files shorter than Nt stay zero-padded in time.
        tmp = np.zeros((Nx_sub, Nt), dtype=np.float32)
        tmp[:, :time_data.shape[0]] = time_data.T[st_ch:(st_ch + Nx_sub), :]

        # Zero-phase 4th-order Butterworth band-pass along the time axis.
        fs = self.sample_rate if self.sample_rate is not None else sample_rate
        b, a = butter(4, (0.5, 12), fs=fs, btype='bandpass')
        filt = filtfilt(b, a, tmp, axis=-1)

        # Normalise the whole patch by its global standard deviation.
        sample = (filt / np.std(filt, axis=(0, 1), keepdims=True)).astype(np.float32)

        # create a mask for channels: 0 means to mask, 1 means to keep
        mask = np.ones((Nx_sub, Nt), dtype=np.float32)
        rng = np.random.default_rng()
        trace_masked = rng.choice(Nx_sub, size=int(self.mask_ratio * Nx_sub), replace=False)
        mask[trace_masked, :] = 0

        return (sample, mask), sample * (1 - mask)
    

""" Model """
model = unet(1, 16, 1024, factors=(5, 3, 2, 2), use_att=False)
devc = try_gpu(i=0)
# Wrap in DataParallel only when several GPUs are actually visible. The
# previous hard-coded device_ids=[0,1,2,3] failed on machines with 1-3 GPUs
# and had to be commented out by hand. The cap of 4 keeps the behaviour on
# larger machines identical to the original.
n_gpus = torch.cuda.device_count()
if n_gpus > 1:
    model = nn.DataParallel(model, device_ids=list(range(min(n_gpus, 4))))
model.to(devc)

# %% Hyper-parameters for training
batch_size = 8
lr = 1e-4
loss_fn = nn.MSELoss()

# Hand AdamW only the parameters that still require gradients.
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=lr,
)

# Plateau scheduler: multiplies the learning rate by `factor` once the
# monitored value has not improved (relative threshold) for `patience`
# scheduler steps, never going below `min_lr`.
lr_schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.8,
    patience=5,
    threshold=0.001,
    threshold_mode='rel',
    cooldown=0,
    min_lr=1e-6,
    eps=1e-08,
    verbose=True,
)

""" Data """
# NOTE(review): both roots below point at the SAME directory, so the
# "validation" loss is computed on the files used for training and cannot
# detect over-fitting. Point data_terra at the real training set (see the
# Alaska-server paths below) or split the file list -- TODO confirm.
data_terra = '/fd1/QibinShi_data/akdas/qibin_data/kkfln/homer-kkfln'
data_kkfln = '/fd1/QibinShi_data/akdas/qibin_data/kkfln/homer-kkfln'
### for data on Alaska server
# data_terra = '/mnt/qnap/TERRA_FiberA_25Hz'
# data_kkfln = '/mnt/qnap/KKFL-S_FIberA_25Hz'
sample_rate = 25  # sampling rate; read by read_one_h5 as the band-pass filter's `fs`
dchan = 10        # presumably the channel spacing; not used in this cell -- TODO confirm

training_data = read_one_h5(data_terra, Nx_sub=1500, Nt=1500, max_ch=5000, mask_ratio=0.5)
validation_data = read_one_h5(data_kkfln, Nx_sub=1500, Nt=1500, max_ch=5000, mask_ratio=0.5)

# Shuffle the training files every epoch so batches are not always composed of
# the same files in the same order; validation order is left fixed.
train_iter = DataLoader(training_data, batch_size=batch_size, shuffle=True)
validate_iter = DataLoader(validation_data, batch_size=batch_size, shuffle=False)

""" Training """
# Run the training loop. The two returned values are unpacked as the
# per-epoch average training and validation losses (names as used here;
# the exact contents are defined by das_denoise_training.train_augmentation).
avg_train_losses, avg_valid_losses = train_augmentation(
    train_iter,
    validate_iter,
    model,
    loss_fn,
    optimizer,
    lr_schedule=lr_schedule,
    epochs=350,
    patience=20,
    device=devc,
    minimum_epochs=50,
)

[  1/350] train_loss: 12.75088 valid_loss: 0.53913 time per epoch: 3.224 s
[  2/350] train_loss: 0.53616 valid_loss: 0.50098 time per epoch: 3.163 s
[  3/350] train_loss: 0.50147 valid_loss: 0.50079 time per epoch: 3.171 s
[  4/350] train_loss: 0.50408 valid_loss: 0.50148 time per epoch: 3.215 s
[  5/350] train_loss: 0.49753 valid_loss: 0.49497 time per epoch: 3.108 s
[  6/350] train_loss: 0.50133 valid_loss: 0.50310 time per epoch: 3.133 s
[  7/350] train_loss: 0.49751 valid_loss: 0.49605 time per epoch: 3.107 s
[  8/350] train_loss: 0.49643 valid_loss: 0.49311 time per epoch: 3.104 s
[  9/350] train_loss: 0.50369 valid_loss: 0.49214 time per epoch: 3.072 s
[ 10/350] train_loss: 0.49018 valid_loss: 0.49909 time per epoch: 3.092 s
[ 11/350] train_loss: 0.49700 valid_loss: 0.49583 time per epoch: 3.073 s
[ 12/350] train_loss: 0.49887 valid_loss: 0.49547 time per epoch: 3.058 s
[ 13/350] train_loss: 0.48195 valid_loss: 0.49087 time per epoch: 3.056 s
[ 14/350] train_loss: 0.48095 valid_l

KeyboardInterrupt: 