In [3]:
# ====== Load Module =========
import numpy as np
import os
import torch.utils.data as data
from scipy.signal import find_peaks
import sys

sys.path.append('.')
sys.path.append('..')
sys.path.append('../..')

# Trained picker checkpoint (run tag: MSFSegNet-31+51-32-tanh-01111-bs8-lr1.0e-04-BCE-SGD-S2).
model_path = 'data_model/model_checkpoint.pth'
# Directory holding the gather .npy files. Set the GTH_DATA_DIR environment variable
# to point elsewhere instead of editing this machine-specific absolute path.
DATA_DIR = os.environ.get('GTH_DATA_DIR', '/data/wht/seismic/RCP/yz/gth')
sample_list = [os.path.join(DATA_DIR, 'gth_Line1650_Cdp1382.npy')]

In [None]:
# ====== Load Sample ========


In [12]:
# ====== process data ========
from models.cascade_method.seg_nets import MSFSegNet
from models.cascade_method.get_curve import split_group, interp_curves
from models.cascade_method.gauss_reg import posterior_regression
from models.cascade_method.clu_curves import clustering_main

import torch

def _pad_with_peak_map(gth, m):
    """Build the 2-channel network input for one gather.

    Channel 0 is ``gth`` zero-padded so both dimensions become multiples of
    ``m``. Channel 1 is 1 at every sample that
    ``scipy.signal.find_peaks(trace, height=0)`` reports for the corresponding
    column of ``gth``, and 0 elsewhere (including the padding).

    Returns:
        float32 array of shape (2, h_new, w_new).
    """
    h, w = gth.shape
    # Round each dimension up to the next multiple of m (unchanged if already aligned).
    h_new = -(-h // m) * m
    w_new = -(-w // m) * m
    feats = np.zeros((2, h_new, w_new), dtype=np.float32)
    feats[0, :h, :w] = gth
    for j, trace in enumerate(gth.T):
        peaks, _ = find_peaks(trace, height=0)
        feats[1, peaks, j] = 1
    return feats


def _drop_short_entries(curves, min_len):
    """Remove, in place, every entry of ``curves`` whose ``len`` is below ``min_len``.

    Iterates over a snapshot of the keys so popping inside the loop is safe.
    Returns the same mapping for convenience.
    """
    for name in list(curves):
        if len(curves[name]) < min_len:
            curves.pop(name)
    return curves


def predict_single(gth, model_path, agc_list=None, m=32, device=0, pred_hyper_dict=None):
    """Run the four-stage curve-picking pipeline on a single gather.

    Stage 1 runs MSFSegNet on the padded gather plus its positive-peak map.
    Stage 2 splits the segmentation map into curves and interpolates them.
    Stage 3 fits the prior field and clusters the curves.
    Stage 4 refines the clustered curves with posterior inference.

    Args:
        gth: 2-D array (h, w). Each column is treated as one trace for peak
            detection.
        model_path: checkpoint path. Its ``'weights'`` entry holds the
            MSFSegNet state dict. Loaded with ``torch.load``, which unpickles,
            so only use trusted checkpoints.
        agc_list: AGC window list passed to ``picker.change_agc``. Defaults to
            ``[11, 15]``; a fresh list is built on every call.
        m: both spatial dimensions are zero-padded up to a multiple of ``m``
            (presumably the network's total down-sampling factor; confirm).
        device: CUDA device index for the model and the input tensor.
        pred_hyper_dict: optional overrides for the post-processing
            hyper-parameters (``win_k``, ``bw_data``, ``bw_para``,
            ``valid_range``, ``min_len``, ``clu_eps``).

    Returns:
        The mapping returned by ``infer_opt.infer_posterior``, with entries
        shorter than ``min_len`` removed.
    """
    if agc_list is None:
        agc_list = [11, 15]
    hp = {'win_k': 3, 'bw_data': 5, 'bw_para': 50, 'valid_range': 50, 'min_len': 5, 'clu_eps': 4}
    if pred_hyper_dict:
        hp.update(pred_hyper_dict)

    # prepare input: (1, 2, h_new, w_new) float32 tensor on the target GPU
    h, w = gth.shape
    gth_feat = _pad_with_peak_map(gth, m)
    gth_feat_torch = torch.from_numpy(gth_feat).unsqueeze(0).cuda(device)

    # load model. It is constructed with agc_list=[15, 31] (presumably the training
    # configuration; confirm) so the checkpoint loads, then switched to `agc_list`.
    picker = MSFSegNet(agc_list=[15, 31], CBAM_reduction=32, basic_act='tanh', dcn_use=0,
                       cbam_use=1, add_peak=1, add_bp=1, device=device)
    picker.cuda(device)
    checkpoint = torch.load(model_path, map_location='cuda:%s' % device)
    picker.load_state_dict(checkpoint['weights'])
    picker.change_agc(agc_list, device)
    # Inference mode for dropout/batch-norm style layers, applied after change_agc
    # so that any modules it creates are covered too.
    picker.eval()

    # * stage 1: MSFSegNet inference (no autograd graph needed)
    with torch.no_grad():
        seg_map, _ = picker(gth_feat_torch)
    # Crop the padding from the last two (spatial) axes whatever the output rank is.
    # A plain [:h, :w] would slice the batch/channel axes of a 4-D output instead.
    seg_map = seg_map.cpu().numpy()[..., :h, :w]

    # * stage 2: split into curves, interpolate, drop curves shorter than the window
    curve_auto = split_group(seg_map)
    curve_auto_interp = _drop_short_entries(interp_curves(curve_auto), hp['win_k'])

    # * stage 3: prior field estimation + clustering
    infer_opt = posterior_regression(win_k=hp['win_k'], bw_data=hp['bw_data'], bw_para=hp['bw_para'])
    field_auto, slope_dict = infer_opt.est_prior(curve_auto_interp, gth.shape, valid_range=hp['valid_range'])
    curve_concat, _labels = clustering_main(curve_auto_interp, slope_dict, eps=hp['clu_eps'])

    # * stage 4: posterior smoothing, then drop curves shorter than min_len
    curve_dict_smooth = infer_opt.infer_posterior(curve_concat, field_auto)
    return _drop_short_entries(curve_dict_smooth, hp['min_len'])
    

In [13]:
# test: run the full pipeline on the first listed gather.
# The stored array is transposed before use. NOTE(review): presumably it is saved
# with traces along axis 0, while predict_single treats columns as traces; confirm.
gth_sample = np.load(sample_list[0]).T
curves_pred = predict_single(gth_sample, model_path, device=0)