<a href="https://colab.research.google.com/github/ShinyaKatoh/PoViT-UQ/blob/main/pred_for_googlecolab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/ShinyaKatoh/PoViT-UQ
!pip install einops
!pip install torchinfo

In [None]:
import os
import sys
import math
import glob
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from scipy.signal import find_peaks

import seaborn as sns

In [None]:
# Model Definition
# Use the GPU when one is attached, otherwise fall back to CPU. A hard-coded 'cuda'
# device makes torch.load(map_location=device) fail on a CPU-only runtime.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

REPO_DIR = '/content/PoViT-UQ'

# Make the repository's model module importable only for the duration of the import.
_povit_path = f'{REPO_DIR}/PoViT/'
sys.path.append(_povit_path)
try:
  from model_PoViT import Model
finally:
  sys.path.remove(_povit_path)

# Architecture hyperparameters (shared by the 100 Hz and 250 Hz models)
emb_dim = 64
kernel_size = 16
ds_kernel_size= 9
ff_kernel_size= 9
seg_kernel_size= 9
stride = 1
block_num = 7
head_num = 4
dropout_ratio = 0.3

# Select which pretrained model / test set to use: 100 or 250 (Hz).
# Only the selected model is loaded.
SAMPLING_HZ = 250

# Per-rate settings:
#   length - model input length in samples
#   sf     - sampling interval in seconds (1 / sampling rate)
MODEL_CONFIGS = {
    100: {'length': 256, 'sf': 0.01},
    250: {'length': 512, 'sf': 0.004},
}
if SAMPLING_HZ not in MODEL_CONFIGS:
  raise ValueError(f'SAMPLING_HZ must be one of {sorted(MODEL_CONFIGS)}, got {SAMPLING_HZ}')

length = MODEL_CONFIGS[SAMPLING_HZ]['length']
sf = MODEL_CONFIGS[SAMPLING_HZ]['sf']  # sampling interval (s), used to build the time axis

# Load the test set. The context manager closes the .npz file; the arrays are
# fully read into memory by the indexing below.
with np.load(f'{REPO_DIR}/data/test_data_{SAMPLING_HZ}Hz.npz') as data:
  input_data = data['data']
  label_polality = data['label1']  # polarity class labels (name kept: used by the prediction cell)
  label_arrival = data['label2']

model = Model(in_length=length, kernel_size=kernel_size, ds_kernel_size=ds_kernel_size, ff_kernel_size=ff_kernel_size, seg_kernel_size=seg_kernel_size, stride=stride, head_num=head_num, emb_dim=emb_dim, num_blocks=block_num, dropout_ratio=dropout_ratio)
model.load_state_dict(torch.load(f'{REPO_DIR}/PoViT/model_{SAMPLING_HZ}Hz.pth', map_location=device))
# NOTE(review): eval() disables standard nn.Dropout layers. The prediction cell relies on
# Monte Carlo dropout, so this presumably depends on model_PoViT keeping dropout active
# in eval mode (e.g. F.dropout(..., training=True)) — confirm. Otherwise all MC samples
# are identical and the reported IQR is always 0.
model.eval()
model.to(device)

print("Input_data shape:", input_data.shape)

In [None]:
# Monte Carlo dropout inference. Each waveform is replicated N_MC_ITER times and run
# through the model as one batch. The spread of the outputs across copies is the
# uncertainty estimate.
# NOTE(review): the spread is only non-zero if dropout stays active while the model is
# in eval() mode (set in the previous cell) — confirm in model_PoViT.
N_MC_ITER = 100
POLARITY_NAMES = ('Up', 'Down', 'Noise')  # class index -> display name


def polarity_name(index):
  """Return the display name for a polarity class index (0=Up, 1=Down, 2=Noise).

  Any other index yields 'Unknown (<index>)' instead of leaving the name undefined
  or reusing the name from a previous trace.
  """
  idx = int(np.asarray(index).item())
  if 0 <= idx < len(POLARITY_NAMES):
    return POLARITY_NAMES[idx]
  return f'Unknown ({idx})'


# Time axis in seconds; every trace has the same length, so build it once.
lapse_time = np.arange(input_data.shape[-1]) * sf

for i in range(input_data.shape[0]):
  # Data preparation: N_MC_ITER identical copies of trace i in a single batch
  mcd_input = torch.from_numpy(np.tile(input_data[i:i+1, :, :], (N_MC_ITER, 1, 1)).astype('float32'))

  # Prediction
  with torch.no_grad():
    output1, output2 = model(mcd_input.to(device))
  pred_polarity = output1.cpu().numpy()  # P-wave first-motion polarity, one row per MC copy
  pred_arrival = output2.cpu().numpy()   # P-wave arrival time (per-sample trace), one per MC copy

  true_pola = polarity_name(label_polality[i])
  true_arrival = np.argmax(label_arrival[i, 0])

  # Predicted class = argmax of the per-class median over the MC copies
  pola_median = np.median(pred_polarity, axis=0)
  max_index = np.argmax(pola_median)
  pred_pola = polarity_name(max_index)

  # IQR of the winning class's output across MC copies: the amount of uncertainty.
  Q1, Q3 = np.percentile(pred_polarity[:, max_index], [25, 75])
  IQR = Q3 - Q1

  # Predicted arrival = median over MC copies of each copy's peak sample index
  pred_arrival_index = int(np.median(np.argmax(pred_arrival[:, 0, :], axis=1)))

  # Plot result: waveform on top, all MC arrival traces below
  fig, axs = plt.subplots(2, 1, figsize=(8, 4))
  axs[0].plot(lapse_time, input_data[i, 0, :], c='black', lw=0.6)
  for mc_trace in pred_arrival[:, 0, :]:
    axs[1].plot(lapse_time, mc_trace, c='red', lw=0.6, alpha=0.5)

  axs[0].set_title(f'True Pola = {true_pola}  Pred. Pola = {pred_pola}  IQR = {IQR:.3f}')
  axs[0].set_ylabel('Amp.')
  axs[1].set_ylim(0, 1)
  axs[1].set_xlabel('Time (s)')
  axs[1].set_ylabel('Probability')

  for ax in axs:
    ax.axvline(x=lapse_time[true_arrival], c='blue', label='True Arrival')
    ax.axvline(x=lapse_time[pred_arrival_index], c='red', label='Pred Arrival')

  # Legend placed below the lower panel
  axs[1].legend(loc='lower center', bbox_to_anchor=(0.5, -0.7), frameon=False, ncol=2, fontsize=12)

  plt.show()
  plt.close(fig)  # release the figure; this loop creates one per trace

  print('------------------------------------------------------------------')
