# predict

This notebook runs model prediction (Seismo-Performer, GPD) on selected waveform(s).
Supported data sources:
- .h5

In [1]:
import h5py as h5
import numpy as np
from matplotlib import pyplot as plt

# Modifying sys.path to be able to load project packages
import sys
sys.path.append('../')

# Modifying sys.path to be able to load seismo-performer modules
import sys
sys.path.append('../../seismo-performer/')
sys.path.append('../../seismo-performer/utils')

import seismo_load
import gpd_loader

## Utils

In [52]:
class DataSource:
    """Common interface for a collection of waveforms with optional labels.

    This base class is an empty source: it reports zero items and both
    accessors yield None. Concrete sources override these methods.
    """

    def __len__(self):
        """Number of items in the source; the base class has none."""
        return 0

    def get(self, i):
        """Waveform(s) selected by index `i`; the base class yields None."""

    def label(self, i):
        """Label(s) selected by index `i`; the base class yields None."""

    
class H5Source(DataSource):
    """DataSource that reads waveforms (and optional labels) from an HDF5 file.

    The file is opened read-only on every access and closed again before the
    method returns, so no file handle is kept between calls.

    Args:
        data_source: configuration dict with keys
            - ``'path'``: path to the HDF5 file;
            - ``'data_key_stack'``: sequence of keys walked from the file root
              to reach the waveform array (e.g. ``['false-positives']``);
            - ``'label_key_stack'`` (optional): sequence of keys leading to the
              label array, or ``None`` / absent when there are no labels.
    """

    def __init__(self, data_source):
        self._path = data_source['path']
        self._data_key_stack = data_source['data_key_stack']
        # Labels are optional: an absent key behaves like an explicit None.
        self._label_key_stack = data_source.get('label_key_stack')

    @staticmethod
    def _walk(node, key_stack):
        """Follow `key_stack` from `node`, indexing one key at a time."""
        for key in key_stack:
            node = node[key]
        return node

    def _data(self, file):
        """Resolve the waveform array inside an open HDF5 file."""
        return self._walk(file, self._data_key_stack)

    def _labels(self, file):
        """Resolve the label array inside an open HDF5 file."""
        return self._walk(file, self._label_key_stack)

    @staticmethod
    def _is_range(i):
        """True for a 2-element ``(start, stop)`` tuple or list."""
        return isinstance(i, (tuple, list)) and len(i) == 2

    @staticmethod
    def _is_index(i):
        """True for Python or NumPy integers; bool is deliberately rejected."""
        return isinstance(i, (int, np.integer)) and not isinstance(i, bool)

    def _read(self, i, resolve, add_batch_axis):
        """Open the file, locate an array with `resolve`, and index it by `i`.

        A range ``(start, stop)`` returns ``array[start:stop]``. An integer
        returns ``array[i]``, with a leading length-1 axis added when
        `add_batch_axis` is true.

        Raises:
            AttributeError: for any other index type. NOTE(review): TypeError
                would be more conventional; AttributeError is kept so existing
                callers' handlers still match.
        """
        if self._is_range(i):
            start, stop = i
            with h5.File(self._path, 'r') as dataset:
                return resolve(dataset)[start:stop]
        if self._is_index(i):
            with h5.File(self._path, 'r') as dataset:
                item = resolve(dataset)[i]
                if add_batch_axis:
                    # (n_samples, n_channels) -> (1, n_samples, n_channels), so a
                    # single item has the same rank as a sliced batch.
                    item = item.reshape((1, *item.shape))
                return item
        raise AttributeError(f'Unsupported type of index {type(i)}!')

    def get(self, i):
        """Read waveform(s).

        Args:
            i: an integer index (Python or NumPy), or a 2-element
               ``(start, stop)`` tuple/list selecting ``[start:stop]``.

        Returns:
            The selected data. A single integer index yields a leading axis of
            length 1.
        """
        return self._read(i, self._data, add_batch_axis=True)

    def label(self, i):
        """Read label(s) for the same index forms as `get`.

        Returns None when no label key stack was configured. Otherwise it
        returns the selected label(s), with no extra leading axis for an
        integer index.
        """
        if self._label_key_stack is None:
            return None
        return self._read(i, self._labels, add_batch_axis=False)

    def __len__(self):
        """Length of the first axis of the waveform array."""
        with h5.File(self._path, 'r') as dataset:
            return self._data(dataset).shape[0]

def parse_data_source(data_source):
    """Build a data-source object from a configuration dict.

    Args:
        data_source: dict with a ``'type'`` key. For ``'h5'``, the whole dict
            is passed to ``H5Source``.

    Returns:
        The constructed data source.

    Raises:
        KeyError: if ``'type'`` is missing.
        ValueError: if ``'type'`` is not a supported source type.
    """
    source_type = data_source['type']
    if source_type == 'h5':
        return H5Source(data_source)
    # Fail fast: returning None here would only surface later as an unrelated
    # AttributeError when the caller does `data.get(...)`.
    raise ValueError(
        f"Unsupported data source type {source_type!r}; supported types: 'h5'"
    )

Available **model names**:
- `sp` *- Seismo-Performer*
- `sp-cnn` *- Seismo-Performer CNN*
- `gpd` *- ConvNet (GPD)*

### Setup model

In [53]:
# Which model to run: one of the names listed above.
model_name = 'sp'

# Explicit weights file; None means "use default_weights[model_name]".
weights = None

# Default weights file for each model name (paths relative to the notebook).
default_weights = dict([
    ('sp', '../../seismo-performer/WEIGHTS/w_model_performer_with_spec.hd5'),
    ('sp-cnn', '../../seismo-performer/WEIGHTS/weights_model_cnn_spec.hd5'),
    ('gpd', '../../seismo-performer/WEIGHTS/w_gpd_scsn_2000_2017.h5'),
])

### Setup data source

In [54]:
import os

# HDF5 input. The path keeps its previous default but can be overridden with
# the PREDICT_H5_PATH environment variable instead of editing the notebook.
data_source = {
    'type': 'h5',
    'path': os.environ.get('PREDICT_H5_PATH', 'C:/data/false.h5'),
    # Keys walked from the file root to reach the waveform array.
    'data_key_stack': ['false-positives'],
    # None: no labels are read from this file.
    'label_key_stack': None,
}

## Load model

In [55]:
# Loader function for each supported model name.
model_loaders = {
    'sp': seismo_load.load_performer,
    'sp-cnn': seismo_load.load_cnn,
    'gpd': gpd_loader.load_model,
}
# Fail loudly on an unknown name. Otherwise `model` would be left unassigned
# (or silently keep a model loaded by an earlier run).
if model_name not in model_loaders:
    raise ValueError(
        f'Unknown model_name {model_name!r}; expected one of {sorted(model_loaders)}'
    )

# Resolve into a separate name rather than overwriting `weights`, so re-running
# this cell after changing `model_name` picks up that model's default weights.
weights_path = default_weights[model_name] if weights is None else weights
model = model_loaders[model_name](weights_path)

## Get data

In [56]:
# Turn the configuration dict above into a data-source object.
data = parse_data_source(data_source=data_source)

In [59]:
# Half-open range [start, stop) of waveforms to read. If the dataset is
# shorter, fewer rows may come back (the saved output below has 15 rows).
SAMPLE_RANGE = (0, 16)
# Slice along axis 1 (samples within each waveform): 200:600 keeps 400 steps.
# NOTE(review): presumably the input length the model expects — confirm.
WINDOW = slice(200, 600)

waveforms = data.get(SAMPLE_RANGE)
result = model.predict(waveforms[:, WINDOW])

In [60]:
# Raw model output, one row per input waveform.
# NOTE(review): columns look like per-class probabilities (each row sums to ~1
# in the saved output) — confirm the class order for the selected model.
result

array([[9.9956292e-01, 3.8234369e-04, 5.4685039e-05],
       [4.9051305e-04, 9.9853802e-01, 9.7142108e-04],
       [8.5164076e-01, 2.5060115e-02, 1.2329903e-01],
       [9.9909413e-01, 6.1123417e-04, 2.9468132e-04],
       [9.9452150e-01, 3.0807280e-03, 2.3978071e-03],
       [9.9798954e-01, 1.4979603e-03, 5.1249983e-04],
       [9.9812001e-01, 6.1287504e-04, 1.2670311e-03],
       [2.7983024e-04, 9.8806828e-01, 1.1651807e-02],
       [9.9741179e-01, 2.0234962e-03, 5.6481815e-04],
       [9.9506086e-01, 1.1359086e-03, 3.8031884e-03],
       [2.2804782e-04, 9.9715412e-01, 2.6179322e-03],
       [9.9820387e-01, 8.9471455e-04, 9.0144604e-04],
       [9.6929181e-01, 1.6144028e-02, 1.4564155e-02],
       [9.6975058e-01, 9.9243606e-03, 2.0325080e-02],
       [9.8759973e-01, 4.6991869e-03, 7.7009718e-03]], dtype=float32)

In [62]:
# Index of the highest-scoring column for each waveform, one per line.
# ndarray.argmax returns the first index on ties, so the output is unchanged
# from a max()-over-indices loop.
for max_class in result.argmax(axis=1):
    print(max_class)

0
1
0
0
0
0
0
1
0
0
1
0
0
0
0
