<a href="https://colab.research.google.com/github/HyeonhoonLee/KOHI_advance_2021/blob/main/%5Bopen%5D_04_svpred_preprocessing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 의료인공지능 전문가 양성과정 2021
## VitalDB Tutorial <br> Stroke volume prediction using arterial wave - preprocessing
- Date : Sep. 04, 2021
- Author : **Hyun-Lim Yang, Ph.D.**<br>
Research Assistant Professor @
Seoul National University Hospital <br>
Department of Anesthesiology and Pain Medicine
- E-mail : hly{_at_}snuh{_dot_}org
***

In [1]:
from IPython.display import HTML

# Inject a small CSS rule so that <div class="warn"> blocks used in the markdown
# cells are drawn as highlighted boxes with a coloured left border.
style_warn = (
    "<style>div.warn { "
    "background-color: #fcf2f2;"
    "border-color: #dFb5b4; "
    "border-left: 5px solid #dfb5b4; "
    "padding: 0.5em;"
    "}</style>"
)
# Leaving the HTML object as the last expression makes the notebook render it.
HTML(style_warn)

### Import packages

<div class="warn"><b>Warning!</b> Set <code>download_directory</code> to your own directory.</div>


> **Warning** <br>
> colab 환경을 위한 google drive import 코드가 포함되어 있습니다. <br>
> 로컬 환경에서 실행 시 colab을 위한 import function들을 comment out 한 뒤 실행하세요. 

In [2]:
# Colab only: mount Google Drive under /content/gdrive/ so the course files are reachable.
# Comment out both lines when running in a local environment.
from google.colab import drive  # for colab
drive.mount('/content/gdrive/')  # for colab

Mounted at /content/gdrive/


In [3]:
import os
# Course folder inside the mounted Google Drive (contains kohi_preprocessor.py and the data folder).
cloud_directory = '/content/gdrive/My Drive/KOHI2021/KOHI_advanced_2021/'
# Listing the folder confirms that the mount succeeded and the expected files are visible.
os.listdir(cloud_directory) # for colab, check cloud directory mount

['kohi_preprocessor.py',
 'kohi_CNN_model_archi.png',
 'data',
 '__pycache__',
 'new_cnn_weight_balanced.h5',
 'simple_cnn_weight_sv.h5',
 'simple_cnn_weight_sv_transfer.h5',
 'simple_cnn_weight_lowbp.h5']

In [4]:
import os
import sys
# download_directory = os.getcwd() # for local environments
# sys.path.append(download_directory) # for local environments

# Make the course folder importable (needed for `import kohi_preprocessor`).
# The membership check keeps the cell idempotent: re-running it no longer
# appends the same directory to sys.path again and again.
if cloud_directory not in sys.path:
    sys.path.append(cloud_directory) # for colab


In [5]:
# Pinned to the vitaldb version this notebook was originally run with, so a fresh
# run gets the same API. %pip (instead of !pip) installs into the running kernel's environment.
%pip install vitaldb==0.0.6

Collecting vitaldb
  Downloading vitaldb-0.0.6-py3-none-any.whl (8.9 kB)
Installing collected packages: vitaldb
Successfully installed vitaldb-0.0.6


In [6]:
import numpy as np
import pandas as pd
import glob
import kohi_preprocessor as pre
import vitaldb
from tqdm import tqdm
import warnings
# NOTE(review): this silences every warning category for the rest of the session,
# including deprecation warnings from the libraries used below — narrow it if a warning matters.
warnings.filterwarnings(action='ignore')

# Working directory for the tutorial; point this at your own folder when running outside Colab.
download_directory = cloud_directory # for colab


### Data loading
샘플 파일을 vitaldb 서버로 부터 직접 다운받아 로드

> **TODO:** `00019.vital`을 100 Hz로 변환하여 `chart_pd_01`에 DataFrame으로 저장하기

In [7]:
# Tracks to load: the EV1000 stroke volume track and the SNUADC arterial waveform.
track_names = ["EV1000/SV", "SNUADC/ART"]
### =========== Your code here ====================

# Load case 19 directly from the VitalDB server with interval=1/100
# (one row per 1/100 s, i.e. 100 Hz), then wrap the result in a DataFrame
# whose columns are named after the requested tracks.
vitaldata = vitaldb.load_case(caseid=19, tnames=track_names, interval=1/100)
chart_pd_01 = pd.DataFrame(vitaldata, columns=track_names)

### ===============================================

# Confirm that both requested tracks are present as columns.
print(chart_pd_01.columns)

Index(['EV1000/SV', 'SNUADC/ART'], dtype='object')


In [13]:
# Peek at 10 random rows; the fixed seed makes the displayed sample identical on every re-run.
chart_pd_01.sample(10, random_state=0)

Unnamed: 0,EV1000/SV,SNUADC/ART
679399,,92.8984
2136844,,101.786
805657,,76.1116
1232802,,51.4252
1902787,,31.6761
618273,,64.2621
1782405,,64.2621
1957344,,68.212
279135,,59.3249
2373581,,110.673


In [8]:
# Column names of the two tracks: stroke volume (the label) and arterial waveform (the input).
col_svs, col_art = 'EV1000/SV', 'SNUADC/ART'

### Stroke volume 데이터 추출

In [14]:
# Extract the stroke volume samples
### =========== Your code here ====================

# Most rows of the stroke volume column are NaN at this sampling interval: keep only
# the rows that carry a value, and remember their row labels so the arterial
# window belonging to each sample can be located later.
svs_data_pd = chart_pd_01[col_svs].dropna()
svs_index = svs_data_pd.index.to_numpy()

### ===============================================
print(svs_data_pd.head())

53845    61.0
54008    61.0
54208    61.0
54408    61.0
54607    61.0
Name: EV1000/SV, dtype: float64


### arterial wave 데이터 추출 및 nan 채우기

In [15]:
# Take the whole arterial waveform and fill its NaN samples
### =========== Your code here ====================

# Missing samples become 0. Windows containing such samples fall below the 25.0
# lower bound of the pressure range filter applied later, so they are discarded there.
art_full_pd = chart_pd_01[col_art].fillna(0)

### ===============================================

print(art_full_pd.head())

0    0.0
1    0.0
2    0.0
3    0.0
4    0.0
Name: SNUADC/ART, dtype: float64


### 데이터셋 만들기
필요한 파라미터들 정의

In [16]:
# 필요한 파라미터들 정의
srate = 100
length = 20
max_limit_svs = 200 # svs max : 200
min_limit_svs = 20  # svs min : 20

In [17]:
# Quick look at the row indices at which a stroke volume sample exists.
svs_index

array([  53845,   54008,   54208, ..., 2700200, 2700400, 2700600])

In [20]:
# Exploratory check: keep only the strictly positive indices and compare with the full array.
svs_index[svs_index > 0]

array([  53845,   54008,   54208, ..., 2700200, 2700400, 2700600])

In [18]:
# Keep only the stroke volume samples that have a full `length`-second arterial window before them
### =========== Your code here ====================

# A sample at index i uses the window [i - length*srate, i). That window is complete
# as soon as i >= length*srate, so the boundary index itself is kept (`>=`, not `>`).
svs_points = svs_index[svs_index >= (length*srate)]

### ===============================================

print(svs_points)

[  53845   54008   54208 ... 2700200 2700400 2700600]


### Arterial wave segment 추출

In [21]:
# Extract the arterial segment that precedes each stroke volume sample
### =========== Your code here ====================
# Number of waveform samples per segment (`length` seconds at `srate` Hz).
seg_len = length * srate

# Pull the waveform out of pandas once. Slicing a NumPy array is explicitly positional
# and avoids creating a new Series for every window. Positions match the labels in
# svs_points because chart_pd_01 was built with the default 0..N-1 index.
art_full_np = art_full_pd.to_numpy()

# Stroke volume labels: one vectorized label-based lookup instead of one per sample.
svs_values_list = svs_data_pd.loc[svs_points].tolist()
# Each segment holds the seg_len samples that end right before the stroke volume sample
# (the sample at idx itself is excluded).
art_seg_list = [art_full_np[idx - seg_len:idx] for idx in svs_points]

### ===============================================

svs_values_np = np.array(svs_values_list)
art_seg_np = np.array(art_seg_list)

print(svs_values_np)
print(art_seg_np)

[61. 61. 61. ... 63. 63. 63.]
[[ 62.2872    62.2872    60.3123   ...  61.2998    65.2496    66.2371  ]
 [ 54.3876    56.3625    60.3123   ...  53.4001    51.4252    50.4377  ]
 [ 65.2496    64.2621    64.2621   ...  77.0991    71.1743    67.2245  ]
 ...
 [ 30.6886    -4.85986   -0.910027 ...  28.7137   -26.5839     8.96454 ]
 [ 20.814    -21.6466    34.6384   ...  22.7889   -11.7721    -2.88494 ]
 [ 19.8266   -21.6466    29.7011   ...  24.7639   -15.7219     3.0398  ]]


### 조건에 따라 filter들 정의

In [23]:
# Declare the filters
# Stroke volume range filter: keep labels strictly inside (min_limit_svs, max_limit_svs)

### =========== Your code here ====================

svs_max_filter = np.less(svs_values_np, max_limit_svs)
svs_min_filter = np.greater(svs_values_np, min_limit_svs)
# A label passes only when it satisfies both bounds.
svs_filter = np.logical_and(svs_max_filter, svs_min_filter)

### ===============================================


# Arterial pressure range filter

### =========== Your code here ====================

# Accepted pressure range (both bounds exclusive), named like the stroke volume limits.
min_limit_art = 25.0
max_limit_art = 250.0

# A segment passes only if every one of its samples lies strictly inside the range.
# The comparison runs on the whole 2-D segment array at once (one row per segment)
# instead of looping over the rows in Python.
art_in_range = (art_seg_np > min_limit_art) & (art_seg_np < max_limit_art)
art_filter = art_in_range.all(axis=1)
# List form kept so that any code still reading `art_filter_list` keeps working.
art_filter_list = art_filter.tolist()

### ===============================================


# mstds filter

### =========== Your code here ====================

mstds_values_list = []
for seg in tqdm(art_seg_np):
    if (seg < 0.).any():
        # A segment with any negative sample is scored 0 up front and is
        # never handed to the beat processor.
        mstd_val = 0.
    else:
        # Only the first value returned by process_beat is used; the second is discarded.
        mstd_val, _ = pre.process_beat(seg)
    mstds_values_list.append(mstd_val)
# A segment passes when its value is strictly positive.
mstds_filter = np.array(mstds_values_list) > 0.

### ===============================================


100%|██████████| 13217/13217 [03:07<00:00, 70.34it/s]


전체 필터 하나로 만들기

In [27]:
### =========== Your code here ====================

# A sample is kept only if it passes all three filters (element-wise AND of the masks).
all_filters = np.logical_and.reduce((svs_filter, art_filter, mstds_filter))

### ===============================================


### 필터 적용하여 데이터 추출

In [28]:
# Apply the combined filter and extract the surviving samples

### =========== Your code here ====================

# The same boolean mask is applied to both arrays, so labels and segments stay aligned.
svs_filtered = svs_values_np[all_filters]
art_filtered = art_seg_np[all_filters]

### ===============================================

for filtered in (svs_filtered, art_filtered):
    print(filtered.shape)

(12385,)
(12385, 2000)


데이터셋 정의

In [29]:
# Inputs are the filtered arterial segments; labels are the matching stroke volumes.
x_data, y_label = art_filtered, svs_filtered

In [30]:
# Show both shapes to confirm that inputs and labels have the same number of samples.
for arr in (x_data, y_label):
    print(arr.shape)

(12385, 2000)
(12385,)
