<a href="https://colab.research.google.com/github/abarrie2/cs598-dlh-project/blob/main/DL4H_Team_24.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## LICENSING NOTICE

All users of VitalDB, an open biosignal dataset, must agree to the Data Use Agreement below. If you do not agree, please close this window. The Data Use Agreement is available here:
https://vitaldb.net/dataset/#h.vcpgs1yemdb5

# This is the development version of the project code

For the Project Draft submission see the DL4H_Team_24_Project_Draft.ipynb notebook in the project repository.

## Project repository

The project repository can be found at: https://github.com/abarrie2/cs598-dlh-project

## Project video

The project video can be found at: 

# Introduction

This project aims to reproduce findings from the paper titled "Predicting intraoperative hypotension using deep learning with waveforms of arterial blood pressure, electroencephalogram, and electrocardiogram: Retrospective study" by Jo Y-Y et al. (2022) [1]. This study introduces a deep learning model that predicts intraoperative hypotension (IOH) events before they occur, utilizing a combination of arterial blood pressure (ABP), electroencephalogram (EEG), and electrocardiogram (ECG) signals.


## Background of the Problem

Intraoperative hypotension (IOH) is a common and significant surgical complication defined by a mean arterial pressure drop below 65 mmHg. It is associated with increased risks of myocardial infarction, acute kidney injury, and heightened postoperative mortality. Effective prediction and timely intervention can substantially enhance patient outcomes.

### Evolution of IOH Prediction

Initial attempts to predict IOH primarily used arterial blood pressure (ABP) waveforms. A foundational study by Hatib F et al. (2018) titled "Machine-learning Algorithm to Predict Hypotension Based on High-fidelity Arterial Pressure Waveform Analysis" [2] showed that machine learning could forecast IOH events using ABP with reasonable accuracy. This finding spurred further research into utilizing various physiological signals for IOH prediction.

Subsequent advancements included the development of the Acumen™ hypotension prediction index, which was studied in "Acumen™ hypotension prediction index guidance for prevention and treatment of hypotension in noncardiac surgery: a prospective, single-arm, multicenter trial" by Bao X et al. (2024) [3]. This trial integrated a hypotension prediction index into blood pressure monitoring equipment, demonstrating its effectiveness in reducing the number and duration of IOH events during surgeries. Further study is needed to determine whether this reduction in IOH events translates into improved postoperative patient outcomes.


### Current Study

Building on these advancements, the paper by Jo Y-Y et al. (2022) proposes a deep learning approach that enhances prediction accuracy by incorporating EEG and ECG signals along with ABP. This multi-modal method, evaluated over prediction windows of 3, 5, 10, and 15 minutes, aims to provide a comprehensive physiological profile that could predict IOH more accurately and earlier. Their results indicate that the combination of ABP and EEG significantly improves performance metrics such as AUROC and AUPRC, outperforming models that use fewer signals or different combinations.

Our project seeks to reproduce and verify Jo Y-Y et al.'s results to assess whether this integrated approach can indeed improve IOH prediction accuracy, thereby potentially enhancing surgical safety and patient outcomes.

## Scope of Reproducibility

The original paper investigated the following hypotheses:

1.   Hypothesis 1: A model using ABP and ECG will outperform a model using ABP alone in predicting IOH.
2.   Hypothesis 2: A model using ABP and EEG will outperform a model using ABP alone in predicting IOH.
3.   Hypothesis 3: A model using ABP, EEG, and ECG will outperform a model using ABP alone in predicting IOH.

Results were compared using AUROC and AUPRC scores. Based on the results described in the original paper, we expect that Hypothesis 2 will be confirmed, and that Hypotheses 1 and 3 will not be confirmed.
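
As a minimal sketch of how these two metrics can be computed, the following uses scikit-learn's `roc_auc_score` and `average_precision_score` on a small made-up set of labels and scores. The labels and scores are illustrative only and are not results from the paper or from this project. Average precision is used here as one common estimate of AUPRC, which is an assumption about the exact estimator.

```python
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score

# Illustrative labels (1 = IOH event, 0 = non-event) and model scores; not real data.
y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])

auroc = roc_auc_score(y_true, y_score)            # area under the ROC curve
auprc = average_precision_score(y_true, y_score)  # average precision, an AUPRC estimate

print(f"AUROC: {auroc:.3f}, AUPRC: {auprc:.3f}")
```

Because IOH events are much rarer than non-events, AUPRC is typically the more sensitive of the two metrics to performance on the positive class.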

In order to perform the corresponding experiments, we will implement a CNN-based model that can be configured to train and infer using the following four model variations:

1.   ABP data alone
2.   ABP and ECG data
3.   ABP and EEG data
4.   ABP, ECG, and EEG data

We will measure the performance of these configurations using the same AUROC and AUPRC metrics as the original paper. To test Hypothesis 1, we will compare model variation 1 against model variation 2; to test Hypothesis 2, model variation 1 against model variation 3; and to test Hypothesis 3, model variation 1 against model variation 4. Each comparison will be repeated for the following time-to-IOH prediction windows:

1. 3 minutes before event
2. 5 minutes before event
3. 10 minutes before event
4. 15 minutes before event
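
The resulting experiment grid can be enumerated as in the sketch below. The names `SIGNAL_COMBINATIONS`, `PREDICTION_WINDOWS_MIN`, and `experiments` are hypothetical and chosen for illustration; they are not identifiers from the project code.

```python
from itertools import product

# Hypothetical labels for the four model variations described above.
SIGNAL_COMBINATIONS = [
    ("ABP",),
    ("ABP", "ECG"),
    ("ABP", "EEG"),
    ("ABP", "ECG", "EEG"),
]
# Prediction windows in minutes before the event.
PREDICTION_WINDOWS_MIN = [3, 5, 10, 15]

# One experiment per (signal combination, prediction window) pair.
experiments = [
    {"signals": signals, "window_min": window}
    for signals, window in product(SIGNAL_COMBINATIONS, PREDICTION_WINDOWS_MIN)
]

print(len(experiments))  # 4 combinations x 4 windows = 16 experiments
```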

From the original paper, the predictive power of ABP, ECG and ABP + ECG models at 3-, 5-, 10- and 15-minute prediction windows:
![Predictive power of ABP, ECG and ABP + ECG models at 3-, 5-, 10- and 15-minute prediction windows](https://journals.plos.org/plosone/article/figure/image?download&size=large&id=10.1371/journal.pone.0272055.g004)

### Modifications made for demo mode

To demonstrate the code within a short runtime (i.e., under the 8-minute limit), the following options and modifications were used:

1. `MAX_CASES` was set to 20. The full training set uses 3296 cases, but the smaller number allows each section of the pipeline to be demonstrated.
2. `vitaldb_cache` is prepopulated in Google Colab. The cache file is approximately 800 MB, contains the raw and minified copies of the source dataset, and is downloaded from Google Drive. This is much faster than using the `vitaldb` API, but it again covers only a fraction of the data. The full dataset can be downloaded with the API or prepopulated by following the instructions in the "Bulk Data Download" section below.
3. `max_epochs` is set to 6. With the small dataset, training is fast and shows the decreasing training and validation losses. In the full model run, `max_epochs` will be set to 100. In both cases early stopping is enabled and will stop training if the validation losses stop decreasing for five consecutive epochs.
4. Only the "ABP + EEG" combination will be run. In the final report, additional combinations will be run, as discussed later.
5. Only the 3-minute prediction window will be run. In the final report, additional prediction windows (5, 10 and 15 minutes) will be run, as discussed later.
6. No ablations are run in the demo. These will be completed for the final report.
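
The early-stopping rule mentioned in item 3 can be sketched as follows. This is an illustrative helper under stated assumptions, not the project's training loop: the class name `EarlyStopping`, its `should_stop` method, and the validation losses are all made up for the example.

```python
class EarlyStopping:
    """Hypothetical helper: stop once validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=5):
        self.patience = patience
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def should_stop(self, val_loss):
        if val_loss < self.best_loss:
            # New best loss: remember it and reset the counter.
            self.best_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Made-up validation losses: improvement for three epochs, then a plateau.
val_losses = [0.9, 0.7, 0.6, 0.61, 0.62, 0.605, 0.63, 0.64, 0.65]

stopper = EarlyStopping(patience=5)
stopped_at = None
for epoch, loss in enumerate(val_losses, start=1):
    if stopper.should_stop(loss):
        stopped_at = epoch
        break

print(f"Stopped after epoch {stopped_at}")
```

With a patience of five, training halts at the fifth consecutive epoch that fails to beat the best validation loss, regardless of the configured `max_epochs`.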

# Methodology

The methodology section is composed of the following subsections: Environment, Data and Model.

- **Environment**: This section describes the setup of the environment, including the installation of necessary libraries and the configuration of the runtime environment.
- **Data**: This section describes the dataset used in the study, including its collection and preprocessing.
    - **Data Collection**: This section describes the process of downloading the dataset from VitalDB and populating the local data cache.
    - **Data Preprocessing**: This section describes the preprocessing steps applied to the dataset, including data selection, data cleaning, and feature extraction.
- **Model**: This section describes the deep learning model used in the study, including its implementation, training, and evaluation.
    - **Model Implementation**: This section describes the implementation of the deep learning model, including the architecture, loss function, and optimization algorithm.
    - **Model Training**: This section describes the training process, including the training loop, hyperparameters, and training strategy.
    - **Model Evaluation**: This section describes the evaluation process, including the metrics used, the evaluation strategy, and the results obtained.

## Environment

### Create environment

The environment setup differs based on whether you are running the code on a local machine or on Google Colab. The following sections provide instructions for setting up the environment in each case.

#### Local machine

Create a `conda` environment for the project using the `environment.yml` file:

```bash
conda env create --prefix .envs/dlh-team24 -f environment.yml
```

Activate the environment with:
```bash
conda activate .envs/dlh-team24
```

This environment specifies Python 3.12.2.

#### Google Colab

The following code snippet installs the required packages and downloads the necessary files in a Google Colab environment:

In [None]:
# Detect Google Colab by inspecting the active IPython shell. Use this flag to guard Colab-only code
COLAB_ENV = "google.colab" in str(get_ipython())
if COLAB_ENV:
    # Install vitaldb
    %pip install vitaldb

    # Executing in Colab therefore download cached preprocessed data.
    # TODO: Integrate this with the setup local cache data section below.
    # Check for file existence before overwriting.
    import gdown
    gdown.download(id="15b5Nfhgj3McSO2GmkVUKkhSSxQXX14hJ", output="vitaldb_cache.tgz")
    !tar -zxf vitaldb_cache.tgz

    # Download sqi_filter.csv from github repo
    !wget https://raw.githubusercontent.com/abarrie2/cs598-dlh-project/main/sqi_filter.csv

All other required packages are already installed in the Google Colab environment. As of May 5, 2024, Google Colab uses Python 3.10.12.

### Load environment

In [None]:
# Import packages
import os
import random
import sys
import uuid
import copy
from collections import defaultdict
from glob import glob

from timeit import default_timer as timer

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.signal import butter, lfilter, spectrogram
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_auc_score, precision_recall_curve, auc, confusion_matrix
from sklearn.metrics import RocCurveDisplay, PrecisionRecallDisplay, average_precision_score
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
import torch
from torch.utils.data import Dataset
import vitaldb
import h5py

import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from datetime import datetime

Start a timer to measure notebook runtime:

In [None]:
global_time_start = timer()

Set random seeds to generate consistent results:

In [None]:
RANDOM_SEED = 42

def reset_random_state():
    random.seed(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)
    torch.manual_seed(RANDOM_SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(RANDOM_SEED)
        torch.cuda.manual_seed_all(RANDOM_SEED)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    os.environ["PYTHONHASHSEED"] = str(RANDOM_SEED)
    
reset_random_state()

Set device to GPU or MPS if available:

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if (torch.backends.mps.is_available() and torch.backends.mps.is_built()) else "cpu")
print(f"Using device: {device}")

Define class to print to console and simultaneously save to file:

In [None]:
class ForkedStdout:
    def __init__(self, file_path):
        self.file = open(file_path, 'w')
        self.stdout = sys.stdout

    def write(self, message):
        self.stdout.write(message)
        self.file.write(message)

    def flush(self):
        self.stdout.flush()
        self.file.flush()

    def __enter__(self):
        sys.stdout = self

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout = self.stdout
        self.file.close()

##  Data

### Data Description

#### Source

Data for this project is sourced from the open biosignal VitalDB dataset as described in "VitalDB, a high-fidelity multi-parameter vital signs database in surgical patients" by Lee H-C et al. (2022) [4], which contains perioperative vital signs and numerical data from 6,388 cases of non-cardiac (general, thoracic, urological, and gynecological) surgery patients who underwent routine or emergency surgery at Seoul National University Hospital between 2016 and 2017. The dataset includes ABP, ECG, and EEG signals, as well as other physiological data. The dataset is available through an [API](https://vitaldb.net/dataset/?query=api) and [Python library](https://vitaldb.net/dataset/?query=lib), and at PhysioNet: https://physionet.org/content/vitaldb/1.0.0/

#### Statistics

Characteristics of the dataset:
| Characteristic        | Value                       | Details                |
|-----------------------|-----------------------------|------------------------|
| Total number of cases | 6,388                       |                        |
| Sex (male)            | 3,243 (50.8%)               |                        |
| Age (years)           | 59                          | Range: 48-68           |
| Height (cm)           | 162                         | Range: 156-169         |
| Weight (kg)           | 61                          | Range: 53-69           |
| Tram-Rac 4A tracks    | 6,355 (99.5%)               | Sampling rate: 500Hz   |
| BIS Vista tracks      | 5,566 (87.1%)               | Sampling rate: 128Hz   |
| Case duration (min)   | 189                         | Range: 27-1041         |

Labels are only known after processing the data. In the original paper, there were an average of 1.6 IOH events and 5.7 non-events per case, so we expect approximately 10,221 IOH events (6,388 × 1.6) and 36,412 non-events (6,388 × 5.7) in the dataset.

#### Data Processing

Data will be processed as follows:
1. Load the dataset from VitalDB, or from a local cache if previously downloaded.
2. Apply the inclusion and exclusion selection criteria to filter the dataset according to surgery metadata.
3. Generate a minified dataset by discarding all tracks except ABP, ECG, and EEG.
4. Preprocess the data by applying band-pass filters to the ECG and EEG signals, applying z-score normalization to the ECG signal, and filtering out ABP signals below a Signal Quality Index (SQI) threshold.
5. Generate event and non-event samples by extracting 60-second segments around IOH events and non-events.
6. Split the dataset into training, validation, and test sets with a 6:1:3 ratio, ensuring that samples from a single case are not split across different sets to avoid data leakage.
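
The case-level split in step 6 can be sketched as follows. This is a minimal illustration, assuming each sample carries the ID of the case it came from; the function `split_case_ids` and the example data are hypothetical and are not taken from the project code.

```python
import numpy as np

def split_case_ids(case_ids, ratios=(0.6, 0.1, 0.3), seed=42):
    """Shuffle unique case IDs and cut them into train/validation/test groups."""
    rng = np.random.default_rng(seed)
    unique_ids = rng.permutation(np.unique(case_ids))
    n_train = int(round(ratios[0] * len(unique_ids)))
    n_val = int(round(ratios[1] * len(unique_ids)))
    train_ids = unique_ids[:n_train]
    val_ids = unique_ids[n_train:n_train + n_val]
    test_ids = unique_ids[n_train + n_val:]
    return set(train_ids.tolist()), set(val_ids.tolist()), set(test_ids.tolist())

# Made-up example: 20 cases with 3 samples each.
sample_case_ids = np.repeat(np.arange(1, 21), 3)
train_ids, val_ids, test_ids = split_case_ids(sample_case_ids)

# Each sample follows its case, so no case appears in more than one set.
train_mask = np.isin(sample_case_ids, list(train_ids))
print(len(train_ids), len(val_ids), len(test_ids))
```

Splitting on case IDs rather than on individual samples is what prevents segments from the same patient from leaking between the training and evaluation sets.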

### Set Up Local Data Caches

VitalDB data is static, so local copies can be stored and reused to avoid expensive downloads and to speed up data processing.

The default directory defined below is in the project `.gitignore` file. If this is modified, the new directory should also be added to the project `.gitignore`.

In [None]:
VITALDB_CACHE = './vitaldb_cache'
VITAL_ALL = f"{VITALDB_CACHE}/vital_all"
VITAL_MINI = f"{VITALDB_CACHE}/vital_mini"
VITAL_METADATA = f"{VITALDB_CACHE}/metadata"
VITAL_MODELS = f"{VITALDB_CACHE}/models"
VITAL_RUNS = f"{VITALDB_CACHE}/runs"
VITAL_PREPROCESS_SCRATCH = f"{VITALDB_CACHE}/data_scratch"
VITAL_EXTRACTED_SEGMENTS = f"{VITALDB_CACHE}/segments"

In [None]:
TRACK_CACHE = None
SEGMENT_CACHE = None

# when USE_MEMORY_CACHING is enabled, track data will be persisted in an in-memory cache. Not useful once we have already pre-extracted all event segments
# DON'T USE: Stores items in memory that are later not used. Causes OOM on segment extraction.
USE_MEMORY_CACHING = False

# When RESET_CACHE is set to True, it will ensure the TRACK_CACHE is disposed and recreated when we do dataset initialization.
# Use as a shortcut to wiping cache rather than restarting kernel
RESET_CACHE = False

PREDICTION_WINDOW = 3
#PREDICTION_WINDOW = 'ALL'

ALL_PREDICTION_WINDOWS = [3, 5, 10, 15]

# Maximum number of cases of interest for which to download data.
# Set to a small value (e.g., 20) for demo purposes, or set to None to download and process all cases.
MAX_CASES = None
#MAX_CASES = 300

# Preloading Cases: when true, all matched cases will have the _mini tracks extracted and put into in-mem dict
PRELOADING_CASES = False
PRELOADING_SEGMENTS = True
# Perform Data Preprocessing: do we want to take the raw vital file and extract segments of interest for training?
PERFORM_DATA_PREPROCESSING = False

In [None]:
# Create each cache directory if it does not already exist.
for cache_dir in [
    VITALDB_CACHE,
    VITAL_ALL,
    VITAL_MINI,
    VITAL_METADATA,
    VITAL_MODELS,
    VITAL_RUNS,
    VITAL_PREPROCESS_SCRATCH,
    VITAL_EXTRACTED_SEGMENTS,
]:
    os.makedirs(cache_dir, exist_ok=True)

print(os.listdir(VITALDB_CACHE))


### Bulk Data Download

**This step is not required, but will significantly speed up downstream processing and avoid a high volume of API requests to the VitalDB web site.**

**Note:** The dataset differs slightly depending on whether it is downloaded from the API or from PhysioNet. In almost all cases, the relevant tracks are identical between the two, but this is not always true.

The cache population code checks whether the `.vital` files are available locally. The cache can be populated by calling the `vitaldb` API or by prepopulating it manually (recommended):

- Manually download the dataset from the following site: https://physionet.org/content/vitaldb/1.0.0/
    - Download the [zip file](https://physionet.org/static/published-projects/vitaldb/vitaldb-a-high-fidelity-multi-parameter-vital-signs-database-in-surgical-patients-1.0.0.zip) in a browser, or
    - Use `wget -r -N -c -np https://physionet.org/files/vitaldb/1.0.0/` to download the files in a terminal
- Move the contents of `vital_files` into the `${VITAL_ALL}` directory.

In [None]:
# Returns the Pandas DataFrame for the specified dataset.
#   One of 'cases', 'labs', or 'trks'
# If the file exists locally, create and return the DataFrame.
# Else, download and cache the csv first, then return the DataFrame.
def vitaldb_dataframe_loader(dataset_name):
    if dataset_name not in ['cases', 'labs', 'trks']:
        raise ValueError(f'Invalid dataset name: {dataset_name}')
    file_path = f'{VITAL_METADATA}/{dataset_name}.csv'
    if os.path.isfile(file_path):
        print(f'{dataset_name}.csv exists locally.')
        df = pd.read_csv(file_path)
        return df
    else:
        print(f'downloading {dataset_name} and storing in the local cache for future reuse.')
        df = pd.read_csv(f'https://api.vitaldb.net/{dataset_name}')
        df.to_csv(file_path, index=False)
        return df

## Exploratory Data Analysis

#### Cases

In [None]:
cases = vitaldb_dataframe_loader('cases')
cases = cases.set_index('caseid')
cases.shape

In [None]:
cases.index.nunique()

In [None]:
cases.head()

In [None]:
cases['sex'].value_counts()

#### Tracks

In [None]:
trks = vitaldb_dataframe_loader('trks')
trks = trks.set_index('caseid')
trks.shape

In [None]:
trks.index.nunique()

In [None]:
trks.groupby('caseid')[['tid']].count().plot();

In [None]:
trks.groupby('caseid')[['tid']].count().hist();

In [None]:
trks.groupby('tname').count().sort_values(by='tid', ascending=False)

### Parameters of Interest

### Hemodynamic Parameters Reference
https://vitaldb.net/dataset/?query=overview#h.f7d712ycdpk2

**SNUADC/ART**

arterial blood pressure waveform

| Parameter  | Description            | Type/Hz | Unit |
|------------|------------------------|---------|------|
| SNUADC/ART | Arterial pressure wave | W/500   | mmHg |

In [None]:
trks[trks['tname'].str.contains('SNUADC/ART')].shape

**SNUADC/ECG_II**

electrocardiogram waveform

| Parameter     | Description      | Type/Hz | Unit |
|---------------|------------------|---------|------|
| SNUADC/ECG_II | ECG lead II wave | W/500   | mV   |

In [None]:
trks[trks['tname'].str.contains('SNUADC/ECG_II')].shape

**BIS/EEG1_WAV**

electroencephalogram waveform

| Parameter    | Description             | Type/Hz | Unit |
|--------------|-------------------------|---------|------|
| BIS/EEG1_WAV | EEG wave from channel 1 | W/128   | uV   |

In [None]:
trks[trks['tname'].str.contains('BIS/EEG1_WAV')].shape

### Cases of Interest

This is the subset of case IDs for which modelling and analysis will be performed, based on the inclusion criteria and waveform data availability.

In [None]:
# TRACK NAMES is used for metadata analysis via API
TRACK_NAMES = ['SNUADC/ART', 'SNUADC/ECG_II', 'BIS/EEG1_WAV']
TRACK_SRATES = [500, 500, 128]
# EXTRACTION TRACK NAMES adds the EVENT track which is only used when doing actual file i/o
EXTRACTION_TRACK_NAMES = ['SNUADC/ART', 'SNUADC/ECG_II', 'BIS/EEG1_WAV', 'EVENT']
EXTRACTION_TRACK_SRATES = [500, 500, 128, 1]

As in the paper, select cases which meet the following criteria:

For patients, the inclusion criteria were as follows:
1. adults (age >= 18)
2. administered general anaesthesia
3. undergone non-cardiac surgery. 

For waveform data, the inclusion criteria were as follows:
1. no missing monitoring for ABP, ECG, and EEG waveforms
2. no cases containing false events or non-events due to poor signal quality (checked in second stage of data preprocessing)

In [None]:
# Adult
inclusion_1 = cases.loc[cases['age'] >= 18].index
print(f'{len(cases)-len(inclusion_1)} cases excluded, {len(inclusion_1)} remaining due to age criteria')

# General Anesthesia
inclusion_2 = cases.loc[cases['ane_type'] == 'General'].index
print(f'{len(cases)-len(inclusion_2)} cases excluded, {len(inclusion_2)} remaining due to anesthesia criteria')

# Non-cardiac surgery
inclusion_3 = cases.loc[
    ~cases['opname'].str.contains("cardiac", case=False)
    & ~cases['opname'].str.contains("aneurysmal", case=False)
].index
print(f'{len(cases)-len(inclusion_3)} cases excluded, {len(inclusion_3)} remaining due to non-cardiac surgery criteria')

# ABP, ECG, EEG waveforms
inclusion_4 = trks.loc[trks['tname'].isin(TRACK_NAMES)].index.value_counts()
inclusion_4 = inclusion_4[inclusion_4 == len(TRACK_NAMES)].index
print(f'{len(cases)-len(inclusion_4)} cases excluded, {len(inclusion_4)} remaining due to missing waveform data')

# SQI filter
# NOTE: this depends on a sqi_filter.csv generated by external processing
inclusion_5 = pd.read_csv('sqi_filter.csv', header=None, names=['caseid','sqi']).set_index('caseid').index
print(f'{len(cases)-len(inclusion_5)} cases excluded, {len(inclusion_5)} remaining due to SQI threshold not being met')

# Only include cases with known good waveforms.
exclusion_6 = pd.read_csv('malformed_tracks_filter.csv', header=None, names=['caseid']).set_index('caseid').index
inclusion_6 = cases.index.difference(exclusion_6)
print(f'{len(cases)-len(inclusion_6)} cases excluded, {len(inclusion_6)} remaining due to malformed waveforms')

cases_of_interest_idx = inclusion_1 \
    .intersection(inclusion_2) \
    .intersection(inclusion_3) \
    .intersection(inclusion_4) \
    .intersection(inclusion_5) \
    .intersection(inclusion_6)

print()
print(f'{cases_of_interest_idx.shape[0]} out of {cases.shape[0]} total cases remaining after exclusions applied')

# Trim cases of interest to MAX_CASES
if MAX_CASES:
    cases_of_interest_idx = cases_of_interest_idx[:MAX_CASES]
print(f'{cases_of_interest_idx.shape[0]} cases of interest selected')

# Build the DataFrame after trimming so that it stays consistent with cases_of_interest_idx
cases_of_interest = cases.loc[cases_of_interest_idx]

In [None]:
cases_of_interest.head(n=5)

**Note:** In the original paper, the authors used an SQI measure that they called jSQI, but which appears to be a combination of jSQI and wSQI. We were not able to implement the same filter, so `sqi_filter.csv` is included to simulate it. Because cases whose SQI falls below the authors' threshold are not excluded, our dataset is noisier than the one used by the original authors, which will impact performance.

### Tracks of Interest

This is the subset of tracks (waveforms) for the cases of interest identified above.

In [None]:
# A single case maps to one or more waveform tracks. Select only the tracks required for analysis.
trks_of_interest = trks.loc[cases_of_interest_idx][trks.loc[cases_of_interest_idx]['tname'].isin(TRACK_NAMES)]
trks_of_interest.shape

In [None]:
trks_of_interest.head(n=5)

In [None]:
trks_of_interest_idx = trks_of_interest.set_index('tid').index
trks_of_interest_idx.shape

### Build Tracks Cache for Local Processing

Track data are large and therefore expensive to download each time they are used.
By default, the `.vital` file format stores all tracks for each case internally. Since only select tracks per case are required, each `.vital` file can be further reduced by discarding the unused tracks.

In [None]:
# Ensure the full vital file dataset is available for cases of interest.
count_downloaded = 0
count_present = 0

#for i, idx in enumerate(cases.index):
for idx in cases_of_interest_idx:
    full_path = f'{VITAL_ALL}/{idx:04d}.vital'
    if not os.path.isfile(full_path):
        print(f'Missing vital file: {full_path}')
        # Download and save the file.
        vf = vitaldb.VitalFile(idx)
        vf.to_vital(full_path)
        count_downloaded += 1
    else:
        count_present += 1

print()
print(f'Count of cases of interest:           {cases_of_interest_idx.shape[0]}')
print(f'Count of vital files downloaded:      {count_downloaded}')
print(f'Count of vital files already present: {count_present}')

### Validate Mini Files

Validate the minified `.vital` files and check that all of the required data tracks are present. The Vital API does not throw an error when you request a track that does not exist.

In [None]:
# Convert vital files to "mini" versions including only the subset of tracks defined in TRACK_NAMES above.
# Only perform conversion for the cases of interest.
# NOTE: If this cell is interrupted, it can be restarted and will continue where it left off.
count_minified = 0
count_present = 0
count_missing_tracks = 0
count_not_fixable = 0

# If set to true, local mini files are checked for all tracks even if the mini file is already present.
FORCE_VALIDATE = False

for idx in cases_of_interest_idx:
    full_path = f'{VITAL_ALL}/{idx:04d}.vital'
    mini_path = f'{VITAL_MINI}/{idx:04d}_mini.vital'

    if FORCE_VALIDATE or not os.path.isfile(mini_path):
        print(f'Creating mini vital file: {idx}')
        vf = vitaldb.VitalFile(full_path, EXTRACTION_TRACK_NAMES)
        
        if len(vf.get_track_names()) != 4:
            print(f'Missing track in vital file: {idx}, {set(EXTRACTION_TRACK_NAMES).difference(set(vf.get_track_names()))}')
            count_missing_tracks += 1
            
            # Attempt to download from VitalDB directly and see if missing tracks are present.
            vf = vitaldb.VitalFile(idx, EXTRACTION_TRACK_NAMES)
            
            if len(vf.get_track_names()) != 4:
                print(f'Unable to fix missing tracks: {idx}')
                count_not_fixable += 1
                continue
                
            # Reject the re-downloaded file if any required track is empty.
            # Tracks are checked in order and only the first empty one is reported.
            empty_track = next(
                (name for name, rate in zip(EXTRACTION_TRACK_NAMES, EXTRACTION_TRACK_SRATES)
                 if vf.get_track_samples(name, 1/rate).shape[0] == 0),
                None,
            )
            if empty_track is not None:
                print(f'Empty track: {idx}, {empty_track}')
                count_not_fixable += 1
                continue

        vf.to_vital(mini_path)
        count_minified += 1
    else:
        count_present += 1

print()
print(f'Count of cases of interest:           {cases_of_interest_idx.shape[0]}')
print(f'Count of vital files minified:        {count_minified}')
print(f'Count of vital files already present: {count_present}')
print(f'Count of vital files missing tracks:  {count_missing_tracks}')
print(f'Count of vital files not fixable:     {count_not_fixable}')

#### Filtering

As in the original paper, preprocessing characteristics are different for each of the three signal categories:
 * ABP: no preprocessing, use as-is
 * ECG: apply a 1-40Hz bandpass filter, then perform Z-score normalization
 * EEG: apply a 0.5-50Hz bandpass filter


`apply_bandpass_filter()` implements the bandpass filter using scipy.signal


In [None]:
def apply_bandpass_filter(data, lowcut, highcut, fs, order=5):
    b, a = butter(order, [lowcut, highcut], fs=fs, btype='band')
    y = lfilter(b, a, np.nan_to_num(data))
    return y
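
As an aside, Butterworth band-pass filters with low cutoff frequencies relative to the sampling rate can be numerically sensitive when expressed as `b, a` coefficients. The sketch below shows an alternative formulation using second-order sections via `butter(..., output="sos")` and `sosfilt`. It is not the filter used in this notebook: the function name `apply_bandpass_filter_sos` and the synthetic test signals are assumptions made for illustration.

```python
import numpy as np
from scipy.signal import butter, sosfilt

def apply_bandpass_filter_sos(data, lowcut, highcut, fs, order=2):
    """Hypothetical variant of the band-pass filter using second-order sections."""
    sos = butter(order, [lowcut, highcut], btype="band", fs=fs, output="sos")
    return sosfilt(sos, np.nan_to_num(data))

# Synthetic 10-second signals at 500 Hz: a 10 Hz tone (in band) and a 200 Hz tone (out of band).
fs = 500
t = np.arange(0, 10, 1 / fs)
in_band = np.sin(2 * np.pi * 10 * t)
out_of_band = np.sin(2 * np.pi * 200 * t)

# Apply the ECG pass band (1-40 Hz) to both tones.
filtered_in = apply_bandpass_filter_sos(in_band, 1, 40, fs)
filtered_out = apply_bandpass_filter_sos(out_of_band, 1, 40, fs)

print(f"10 Hz std:  {in_band.std():.3f} -> {filtered_in.std():.3f}")
print(f"200 Hz std: {out_of_band.std():.3f} -> {filtered_out.std():.3f}")
```

The in-band tone should pass largely unchanged while the out-of-band tone is strongly attenuated, which is a quick sanity check for any band-pass implementation.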

`apply_zscore_normalization()` implements the Z-score normalization using numpy

In [None]:

def apply_zscore_normalization(signal):
    mean = np.nanmean(signal)
    std = np.nanstd(signal)
    return (signal - mean) / std
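
Note that z-score normalization divides by the standard deviation, which is zero for a perfectly flat signal. The following sketch shows one way to guard against that case. The function `apply_zscore_normalization_safe` and its `eps` parameter are hypothetical additions for illustration and are not part of the project code.

```python
import numpy as np

def apply_zscore_normalization_safe(signal, eps=1e-8):
    """Hypothetical variant: z-score normalization that tolerates a constant signal."""
    signal = np.asarray(signal, dtype=float)
    mean = np.nanmean(signal)
    std = np.nanstd(signal)
    if std < eps:
        # A flat signal has no variation to scale, so only centre it at zero.
        return signal - mean
    return (signal - mean) / std

normalized = apply_zscore_normalization_safe([1.0, 2.0, 3.0, 4.0, 5.0])
flat = apply_zscore_normalization_safe([2.0, 2.0, 2.0])

print(normalized.mean(), normalized.std(), flat)
```

For a signal with variation, the output has approximately zero mean and unit standard deviation; for a flat signal, the output is all zeros instead of NaN or infinity.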

### Filtering demonstration

Demonstrate effects of the filters with pre/post filtering waveforms on a sample case:

In [None]:
caseidx = 1
file_path = f"{VITAL_MINI}/{caseidx:04d}_mini.vital"
vf = vitaldb.VitalFile(file_path, TRACK_NAMES)

originalAbp = None
filteredAbp = None
originalEcg = None
filteredEcg = None
originalEeg = None
filteredEeg = None

ABP_TRACK_NAME = "SNUADC/ART"
ECG_TRACK_NAME = "SNUADC/ECG_II"
EEG_TRACK_NAME = "BIS/EEG1_WAV"

for i, (track_name, rate) in enumerate(zip(TRACK_NAMES, TRACK_SRATES)):
    # Get samples for this track
    track_samples = vf.get_track_samples(track_name, 1/rate)
    #track_samples, _ = vf.get_samples(track_name, 1/rate)
    print(f"Track {track_name} @ {rate}Hz shape {len(track_samples)}")

    if track_name == ABP_TRACK_NAME:
        # ABP waveforms are used without further pre-processing
        originalAbp = track_samples
        filteredAbp = track_samples
    elif track_name == ECG_TRACK_NAME:
        originalEcg = track_samples
        # ECG waveforms are band-pass filtered between 1 and 40 Hz, and Z-score normalized
        # first apply bandpass filter (order 2, matching the preprocessing step)
        filteredEcg = apply_bandpass_filter(track_samples, 1, 40, rate, 2)
        # then do z-score normalization
        filteredEcg = apply_zscore_normalization(filteredEcg)
    elif track_name == EEG_TRACK_NAME:
        # EEG waveforms are band-pass filtered between 0.5 and 50 Hz
        originalEeg = track_samples
        filteredEeg = apply_bandpass_filter(track_samples, 0.5, 50, rate, 2)

def plotSignal(data, title):
    plt.figure(figsize=(20, 5))
    plt.plot(data)
    plt.title(title)
    plt.show()

plotSignal(originalAbp, "Original ABP")
plotSignal(filteredAbp, "Unfiltered ABP")
plotSignal(originalEcg, "Original ECG")
plotSignal(filteredEcg, "Filtered ECG")
plotSignal(originalEeg, "Original EEG")
plotSignal(filteredEeg, "Filtered EEG")


### Perform data preprocessing

This section performs the actual data preprocessing laid out earlier:

In [None]:
# Preprocess data tracks
ABP_TRACK_NAME = "SNUADC/ART"
ECG_TRACK_NAME = "SNUADC/ECG_II"
EEG_TRACK_NAME = "BIS/EEG1_WAV"
EVENT_TRACK_NAME = "EVENT"
MINI_FILE_FOLDER = VITAL_MINI
CACHE_FILE_FOLDER = VITAL_PREPROCESS_SCRATCH

if RESET_CACHE:
    TRACK_CACHE = None
    SEGMENT_CACHE = None

if TRACK_CACHE is None:
    TRACK_CACHE = {}
    SEGMENT_CACHE = {}

def get_track_data(case, print_when_file_loaded = False):
    parsedFile = None
    abp = None
    eeg = None
    ecg = None
    events = None

    for i, (track_name, rate) in enumerate(zip(EXTRACTION_TRACK_NAMES, EXTRACTION_TRACK_SRATES)):
        # use integer case id and track name, delimited by pipe, as cache key
        cache_label = f"{case}|{track_name}"
        
        if cache_label not in TRACK_CACHE:
            if parsedFile is None:
                file_path = f"{MINI_FILE_FOLDER}/{case:04d}_mini.vital"
                if print_when_file_loaded:
                    print(f"[{datetime.now()}] Loading vital file {file_path}")
                parsedFile = vitaldb.VitalFile(file_path, EXTRACTION_TRACK_NAMES)
            
            dataset = np.array(parsedFile.get_track_samples(track_name, 1/rate))
            
            if track_name == ABP_TRACK_NAME:
                # no filtering for ABP
                abp = dataset
                abp = pd.DataFrame(abp).ffill(axis=0).bfill(axis=0)[0].values
                if USE_MEMORY_CACHING:
                    TRACK_CACHE[cache_label] = abp
            elif track_name == ECG_TRACK_NAME:
                ecg = dataset
                # apply ECG filtering: first bandpass then do z-score normalization
                ecg = pd.DataFrame(ecg).ffill(axis=0).bfill(axis=0)[0].values
                ecg = apply_bandpass_filter(ecg, 1, 40, rate, 2)
                ecg = apply_zscore_normalization(ecg)
                
                if USE_MEMORY_CACHING:
                    TRACK_CACHE[cache_label] = ecg
            elif track_name == EEG_TRACK_NAME:
                eeg = dataset
                eeg = pd.DataFrame(eeg).ffill(axis=0).bfill(axis=0)[0].values
                # apply EEG filtering: bandpass only
                eeg = apply_bandpass_filter(eeg, 0.5, 50, rate, 2)
                if USE_MEMORY_CACHING:
                    TRACK_CACHE[cache_label] = eeg
            elif track_name == EVENT_TRACK_NAME:
                events = dataset
                if USE_MEMORY_CACHING:
                    TRACK_CACHE[cache_label] = events
        else:
            # cache hit, pull from cache
            if track_name == ABP_TRACK_NAME:
                abp = TRACK_CACHE[cache_label]
            elif track_name == ECG_TRACK_NAME:
                ecg = TRACK_CACHE[cache_label]
            elif track_name == EEG_TRACK_NAME:
                eeg = TRACK_CACHE[cache_label]
            elif track_name == EVENT_TRACK_NAME:
                events = TRACK_CACHE[cache_label]

    return (abp, ecg, eeg, events)

# ABP waveforms are used without further pre-processing
# ECG waveforms are band-pass filtered between 1 and 40 Hz, and Z-score normalized
# EEG waveforms are band-pass filtered between 0.5 and 50 Hz
if PRELOADING_CASES:
    # determine disk cache file label
    maxlabel = "ALL"
    if MAX_CASES is not None:
        maxlabel = str(MAX_CASES)
    picklefile = f"{CACHE_FILE_FOLDER}/{PREDICTION_WINDOW}_minutes_MAX{maxlabel}.trackcache"

    for track in tqdm(cases_of_interest_idx):
        # getting track data will cause a cache-check and fill when missing
        # will also apply appropriate filtering per track
        get_track_data(track, False)
    
    print(f"Generated track cache, {len(TRACK_CACHE)} records generated")
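
Each track above is gap-filled with `ffill` followed by `bfill` before any filtering. As a minimal, plain-Python sketch of that same forward-then-backward fill (the helper name `fill_gaps` and the sample values are hypothetical, chosen only for illustration):

```python
import math

def fill_gaps(samples):
    """Forward-fill, then backward-fill, NaN values in a list of floats."""
    filled = list(samples)
    # forward pass: carry the last valid value into the gaps that follow it
    last = math.nan
    for idx, value in enumerate(filled):
        if math.isnan(value):
            filled[idx] = last
        else:
            last = value
    # backward pass: only leading gaps are still NaN at this point
    nxt = math.nan
    for idx in range(len(filled) - 1, -1, -1):
        if math.isnan(filled[idx]):
            filled[idx] = nxt
        else:
            nxt = filled[idx]
    return filled

raw = [math.nan, 80.0, math.nan, math.nan, 82.0, math.nan]
print(fill_gaps(raw))  # [80.0, 80.0, 80.0, 80.0, 82.0, 82.0]
```

A track that contains no valid samples at all stays entirely NaN under this scheme, so downstream code should not assume the fill always succeeds.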

Processed data is stored in `.h5` files. Define a loader to read this data and return a tuple with the waveform data:

In [None]:
def get_segment_data(file_path):
    abp = None
    eeg = None
    ecg = None

    if USE_MEMORY_CACHING:
        if file_path in SEGMENT_CACHE:
            (abp, ecg, eeg) = SEGMENT_CACHE[file_path]
            return (abp, ecg, eeg)

    try:
        with h5py.File(file_path, 'r') as f:
            abp = np.array(f['abp'])
            ecg = np.array(f['ecg'])
            eeg = np.array(f['eeg'])
        
        abp = np.array(abp)
        eeg = np.array(eeg)
        ecg = np.array(ecg)

        if len(abp) > 30000:
            abp = abp[:30000]
        elif len(abp) < 30000:
            abp = np.resize(abp, (30000))

        if len(ecg) > 30000:
            ecg = ecg[:30000]
        elif len(ecg) < 30000:
            ecg = np.resize(ecg, (30000))

        if len(eeg) > 7680:
            eeg = eeg[:7680]
        elif len(eeg) < 7680:
            eeg = np.resize(eeg, (7680))

        if USE_MEMORY_CACHING:
            SEGMENT_CACHE[file_path] = (abp, ecg, eeg)
    except Exception:
        # treat unreadable or malformed files as missing segments
        abp = None
        ecg = None
        eeg = None

    return (abp, ecg, eeg)
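
The target lengths in the loader correspond to 60 seconds of signal: 30000 samples at 500 Hz for ABP and ECG, and 7680 samples at 128 Hz for EEG. Note that `np.resize` lengthens a short array by repeating it from the start rather than by zero-padding. The sketch below illustrates this on a toy array; the NaN-padding variant is a hypothetical alternative and is not what the loader does:

```python
import numpy as np

short = np.array([1.0, 2.0, 3.0])

# np.resize fills the extra slots with repeated copies of the input
tiled = np.resize(short, 5)
print(tiled)  # [1. 2. 3. 1. 2.]

# Hypothetical alternative: pad with NaN so the missing tail stays visible
target_len = 5
padded = np.pad(short, (0, target_len - len(short)), constant_values=np.nan)
print(np.isnan(padded).sum())  # 2
```

Whether repeating the waveform is acceptable depends on how often short segments occur; if they are rare, the choice probably has little effect on training.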


The `.vital` files contain timeseries information from before and after the surgery, and include an event track where significant events can be labeled. Define a function to read from this track and extract the surgery start and end times so that data can be extracted from this period:

In [None]:
def getSurgeryBoundariesInSeconds(event, debug=False):
    eventIndices = np.argwhere(event==event)
    # event == event is False only for NaN, so this keeps the indices of labeled entries
    # we are looking for the last index where the string contains 'started'
    lastStart = 0
    firstFinish = len(event)-1
    
    # find last start
    for idx in eventIndices:
        if 'started' in event[idx[0]]:
            if debug:
                print(event[idx[0]])
                print(idx[0])
            lastStart = idx[0]
    
    # find first finish
    for idx in eventIndices:
        if 'finish' in event[idx[0]]:
            if debug:
                print(event[idx[0]])
                print(idx[0])

            firstFinish = idx[0]
            break
    
    if debug:
        print(f'lastStart, firstFinish: {lastStart}, {firstFinish}')
    return (lastStart, firstFinish)
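
The boundary search above relies on the fact that NaN is the only value not equal to itself, so `event == event` selects the labeled entries. A plain-Python sketch of the same idea follows; the function name `surgery_bounds` and the event labels are hypothetical, and the returned indices only correspond to seconds if the event track is sampled at 1 Hz:

```python
import math

def surgery_bounds(events):
    """Return (last index containing 'started', first index containing 'finish')."""
    # e == e is False only for NaN, so this keeps the labeled entries
    labeled = [(i, e) for i, e in enumerate(events) if e == e]
    last_start = 0
    first_finish = len(events) - 1
    for i, label in labeled:
        if 'started' in label:
            last_start = i
    for i, label in labeled:
        if 'finish' in label:
            first_finish = i
            break
    return last_start, first_finish

# Hypothetical event track; NaN marks "no event at this sample"
events = [math.nan, 'Case started', math.nan, 'Surgery started',
          math.nan, math.nan, 'Surgery finished', 'Case finished']
print(surgery_bounds(events))  # (3, 6)
```

When no matching label is present, the defaults fall back to the full extent of the track.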

Define a function to check whether extracted segments already exist for a case. If there are none, they will need to be generated:

In [None]:
def areCaseSegmentsCached(caseid):
    seg_folder = f"{VITAL_EXTRACTED_SEGMENTS}/{caseid:04d}"
    return os.path.exists(seg_folder) and len(os.listdir(seg_folder)) > 0

Define a basic signal quality check function for ABP data:

In [None]:
def isAbpSegmentValidNumpy(samples, debug=False):
    valid = True
    if np.isnan(samples).mean() > 0.1:
        valid = False
        if debug:
            print(f">10% NaN")
    elif (samples > 200).any():
        valid = False
        if debug:
            print(f"Presence of BP > 200")
    elif (samples < 30).any():
        valid = False
        if debug:
            print(f"Presence of BP < 30")
    elif np.max(samples) - np.min(samples) < 30:
        if debug:
            print(f"Max - Min test < 30")
        valid = False
    elif (np.abs(np.diff(samples)) > 30).any():  # abrupt change -> noise
        if debug:
            print(f"Abrupt change (noise)")
        valid = False
    
    return valid
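
To make the order of the rules easier to follow, here is a simplified plain-Python sketch that returns the first rule a segment fails. The helper name `abp_rejection_reason` and the synthetic waveforms are hypothetical, and unlike the NumPy version this sketch ignores NaN values in the range test:

```python
import math

def abp_rejection_reason(samples):
    """Return the first failed rule, or None when the segment passes."""
    nan_fraction = sum(math.isnan(s) for s in samples) / len(samples)
    if nan_fraction > 0.1:
        return '>10% NaN'
    valid = [s for s in samples if not math.isnan(s)]
    if any(s > 200 for s in valid):
        return 'BP > 200'
    if any(s < 30 for s in valid):
        return 'BP < 30'
    if max(valid) - min(valid) < 30:
        return 'range < 30'
    # jump of more than 30 mmHg between consecutive samples
    if any(abs(b - a) > 30 for a, b in zip(samples, samples[1:])):
        return 'abrupt change'
    return None

# Hypothetical synthetic pulse: 60 to 120 mmHg triangle wave in 5 mmHg steps
beat = list(range(60, 121, 5)) + list(range(115, 60, -5))
clean = [float(v) for v in beat * 4]
print(abp_rejection_reason(clean))  # None

flat = [90.0] * 100
print(abp_rejection_reason(flat))  # range < 30

spike = clean[:10] + [190.0] + clean[11:]
print(abp_rejection_reason(spike))  # abrupt change
```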

Check if the ABP data extracted for a case is valid:

In [None]:
def isAbpSegmentValid(vf, debug=False):
    ABP_ECG_SRATE_HZ = 500
    ABP_TRACK_NAME = "SNUADC/ART"

    samples = np.array(vf.get_track_samples(ABP_TRACK_NAME, 1/ABP_ECG_SRATE_HZ))
    return isAbpSegmentValidNumpy(samples, debug)

Save extracted segments to disk. Use an `.h5` format for efficient packing and playback.

In [None]:
def saveCaseSegments(caseid, positiveSegments, negativeSegments, compresslevel=9, debug=False, forceWrite=False):
    if len(positiveSegments) == 0 and len(negativeSegments) == 0:
        # exit early if no events found
        print(f'{caseid}: exit early, no segments to save')
        return

    # event composition
    # (predictiveSegmentStart in seconds, predictiveSegmentEnd in seconds, predWindow (0 for negative), abp, ecg, eeg)
    # 0start, 1end, 2predwindow, 3abp, 4ecg, 5eeg

    seg_folder = f"{VITAL_EXTRACTED_SEGMENTS}/{caseid:04d}"
    if not os.path.exists(seg_folder):
        # if directory needs to be created, then there are no cached segments
        os.mkdir(seg_folder)
    else:
        if not forceWrite:
            # exit early if folder already exists, case already produced
            return

    # prior to writing files out, clear existing files
    for filename in os.listdir(seg_folder):
        file_path = os.path.join(seg_folder, filename)
        if debug:
            print(f'deleting: {file_path}')
        try:
            if os.path.isfile(file_path):
                os.unlink(file_path)
        except Exception as e:
            print('Failed to delete %s. Reason: %s' % (file_path, e))
    
    count_pos_saved = 0
    for i in range(0, len(positiveSegments)):
        event = positiveSegments[i]
        startIndex = event[0]
        endIndex = event[1]
        predWindow = event[2]
        abp = event[3]
        #ecg = event[4]
        #eeg = event[5]

        seg_filename = f"{caseid:04d}_{startIndex}_{predWindow:02d}_True.h5"
        seg_fullpath = f"{seg_folder}/{seg_filename}"
        if isAbpSegmentValidNumpy(abp, debug):
            count_pos_saved += 1

            abp = abp.tolist()
            ecg = event[4].tolist()
            eeg = event[5].tolist()
        
            f = h5py.File(seg_fullpath, "w")
            f.create_dataset('abp', data=abp, compression="gzip", compression_opts=compresslevel)
            f.create_dataset('ecg', data=ecg, compression="gzip", compression_opts=compresslevel)
            f.create_dataset('eeg', data=eeg, compression="gzip", compression_opts=compresslevel)
            
            f.flush()
            f.close()
            f = None

            abp = None
            ecg = None
            eeg = None

            # f.create_dataset('label', data=[1], compression="gzip", compression_opts=compresslevel)
            # f.create_dataset('pred_window', data=[event[2]], compression="gzip", compression_opts=compresslevel)
            # f.create_dataset('caseid', data=[caseid], compression="gzip", compression_opts=compresslevel)
        elif debug:
            print(f"{caseid:04d} {predWindow:02d}min {startIndex} starttime = ignored, segment validity issues")

    count_neg_saved = 0
    for i in range(0, len(negativeSegments)):
        event = negativeSegments[i]
        startIndex = event[0]
        endIndex = event[1]
        predWindow = event[2]
        abp = event[3]
        #ecg = event[4]
        #eeg = event[5]

        seg_filename = f"{caseid:04d}_{startIndex}_0_False.h5"
        seg_fullpath = f"{seg_folder}/{seg_filename}"
        if isAbpSegmentValidNumpy(abp, debug):
            count_neg_saved += 1

            abp = abp.tolist()
            ecg = event[4].tolist()
            eeg = event[5].tolist()
            
            f = h5py.File(seg_fullpath, "w")
            f.create_dataset('abp', data=abp, compression="gzip", compression_opts=compresslevel)
            f.create_dataset('ecg', data=ecg, compression="gzip", compression_opts=compresslevel)
            f.create_dataset('eeg', data=eeg, compression="gzip", compression_opts=compresslevel)
            
            f.flush()
            f.close()
            f = None

            abp = None
            ecg = None
            eeg = None

            # f.create_dataset('label', data=[0], compression="gzip", compression_opts=compresslevel)
            # f.create_dataset('pred_window', data=[0], compression="gzip", compression_opts=compresslevel)
            # f.create_dataset('caseid', data=[caseid], compression="gzip", compression_opts=compresslevel)
        elif debug:
            print(f"{caseid:04d} CleanWindow {startIndex} starttime = ignored, segment validity issues")
            
    if count_neg_saved == 0 and count_pos_saved == 0:
        print(f'{caseid}: nothing saved, all segments filtered')
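
Because the label and prediction window are not written into the `.h5` file, they can only be recovered from the file name (`<caseid>_<start>_<predWindow>_<label>.h5`). A minimal sketch of such a parser is shown below; `parse_segment_filename` and the example path are hypothetical and not part of the notebook:

```python
from pathlib import Path

def parse_segment_filename(path):
    """Split '<caseid>_<start>_<predWindow>_<label>.h5' into typed fields."""
    caseid, start, pred_window, label = Path(path).stem.split('_')
    return {
        'caseid': int(caseid),
        'start_seconds': int(start),
        'pred_window_minutes': int(pred_window),
        'label': label == 'True',
    }

print(parse_segment_filename('0001/0001_1234_05_True.h5'))
# {'caseid': 1, 'start_seconds': 1234, 'pred_window_minutes': 5, 'label': True}
```

Negative segments use a prediction window of `0` and the label `False`, so the same parser handles both kinds of file.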

The following method is adapted from the preprocessing block of reference [6] (https://github.com/vitaldb/examples/blob/master/hypotension_art.ipynb).

The approach first finds an intraoperative hypotensive event in the ABP waveform. It then backtracks to earlier in the waveform to extract a 60-second segment representing the waveform feature to use as model input. The figure below shows an example of this approach and is reproduced from the VitalDB example notebook referenced above.

![Feature segment extraction](<https://github.com/abarrie2/cs598-dlh-project/blob/main/img/segment_extraction.png?raw=true>)
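
The backtracking arithmetic can be sketched as follows. The sampling rates match the values used elsewhere in this notebook, while the helper name `predictive_window` and the example event time are hypothetical:

```python
ABP_ECG_SRATE_HZ = 500  # ABP and ECG sampling rate used in this notebook
EEG_SRATE_HZ = 128      # EEG sampling rate used in this notebook

def predictive_window(ioh_start_s, pred_window_min, segment_s=60):
    """Return (start, end) in seconds of the input segment for one IOH event."""
    end = ioh_start_s - pred_window_min * 60
    start = end - segment_s
    return start, end

# Hypothetical event starting one hour into the recording, 5-minute lead time
start, end = predictive_window(3600, 5)
print(start, end)                        # 3240 3300
print((end - start) * ABP_ECG_SRATE_HZ)  # 30000
print((end - start) * EEG_SRATE_HZ)      # 7680
```

The two sample counts are the segment lengths expected for ABP/ECG and EEG respectively.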

**Generate hypotensive events**

Hypotensive events are defined as a 1-minute interval with a sustained ABP of less than 65 mmHg.

**Note:** Hypotensive events should be at least 20 minutes apart to minimize potential residual effects from previous events.

**Generate hypotension non-events**

To sample non-events, 30-minute segments where the ABP was above 75 mmHg were selected, and then three one-minute samples of each waveform were obtained from the middle of each segment. Both event and non-event extraction occur in `extract_segments`.
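
As a small illustration of the non-event sampling, the three one-minute samples are taken at 10, 15, and 20 minutes into each 30-minute window. The helper name `negative_sample_windows` and the example start time are hypothetical:

```python
def negative_sample_windows(clean_start_s, offsets_min=(10, 15, 20), segment_s=60):
    """Return (start, end) second pairs for the samples inside one clean window."""
    return [(clean_start_s + m * 60, clean_start_s + m * 60 + segment_s)
            for m in offsets_min]

# Hypothetical clean window starting 1000 seconds into the recording
print(negative_sample_windows(1000))
# [(1600, 1660), (1900, 1960), (2200, 2260)]
```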

In [None]:
def extract_segments(
    cases_of_interest_idx,
    debug=False,
    checkCache=True,
    forceWrite=False,
    returnSegments=False,
    skipInvalidCleanEvents=False,
    skipInvalidIohEvents=False,
):
    # Sampling rate for ABP and ECG, Hz. These rates should be the same. Default = 500
    ABP_ECG_SRATE_HZ = 500

    # Sampling rate for EEG. Default = 128
    EEG_SRATE_HZ = 128

    # Final dataset for training and testing the model.
    positiveSegmentsMap = {}
    negativeSegmentsMap = {}
    iohEventsMap = {}
    cleanEventsMap = {}

    # Process each case and extract segments. For each segment identify presence of an event in the label zone.
    count_cases = len(cases_of_interest_idx)

    #for case_count, caseid in tqdm(enumerate(cases_of_interest_idx), total=count_cases):
    for case_count, caseid in enumerate(cases_of_interest_idx):
        if debug:
            print(f'Loading case: {caseid:04d}, ({case_count + 1} of {count_cases})')

        if checkCache and areCaseSegmentsCached(caseid):
            if debug:
                print(f'Skipping case: {caseid:04d}, already cached')
            # skip records we've already cached
            continue

        # read the arterial waveform
        (abp, ecg, eeg, event) = get_track_data(caseid)
        if debug:
            print(f'Length of {TRACK_NAMES[0]}:       {abp.shape[0]}')
            print(f'Length of {TRACK_NAMES[1]}:    {ecg.shape[0]}')
            print(f'Length of {TRACK_NAMES[2]}:     {eeg.shape[0]}')

        (startInSeconds, endInSeconds) = getSurgeryBoundariesInSeconds(event)
        if debug:
            print(f"Event markers indicate that surgery begins at {startInSeconds}s and ends at {endInSeconds}s.")

        #track_length_seconds = int(len(abp) / ABP_ECG_SRATE_HZ)
        track_length_seconds = endInSeconds
        
        if debug:
            print(f"Processing case {caseid} with length {track_length_seconds}s")

        
        # check if the ABP segment in the surgery window is valid
        if debug:
            isSurgerySegmentValid = \
                isAbpSegmentValidNumpy(abp[startInSeconds * ABP_ECG_SRATE_HZ:endInSeconds * ABP_ECG_SRATE_HZ])
            print(f'{caseid}: surgery segment valid: {isSurgerySegmentValid}')
        
        iohEvents = []
        cleanEvents = []
        i = 0
        started = False
        eofReached = False
        trackStartIndex = None

        # set i pointer (which operates in seconds) to start marker for surgery
        i = startInSeconds

        # FIRST PASS
        # in the first forward pass, we are going to identify the start/end boundaries of all IOH events within the case
        ioh_events_valid = []
        
        while i < track_length_seconds - 60 and i < endInSeconds:
            segmentStart = None
            segmentEnd = None
            segFound = False

            # look forward one minute
            abpSeg = abp[i * ABP_ECG_SRATE_HZ:(i + 60) * ABP_ECG_SRATE_HZ]

            # roll forward until we hit a one minute window where mean ABP >= 65 so we know leads are connected and it's tracking
            if not started:
                if np.nanmean(abpSeg) >= 65:
                    started = True
                    trackStartIndex = i
            # if we're started and mean abp for the window is <65, we are starting a new IOH event
            elif np.nanmean(abpSeg) < 65:
                segmentStart = i
                # now seek forward to find end of event, perpetually checking the last minute of the IOH event
                for j in range(i + 60, track_length_seconds):
                    # look backward one minute
                    abpSegForward = abp[(j - 60) * ABP_ECG_SRATE_HZ:j * ABP_ECG_SRATE_HZ]
                    if np.nanmean(abpSegForward) >= 65:
                        segmentEnd = j - 1
                        break
                if segmentEnd is None:
                    eofReached = True
                else:
                    # otherwise, end of the IOH segment has been reached, record it
                    iohEvents.append((segmentStart, segmentEnd))
                    segFound = True
                    
                    if skipInvalidIohEvents:
                        isIohSegmentValid = isAbpSegmentValidNumpy(abpSeg)
                        ioh_events_valid.append(isIohSegmentValid)
                        if debug:
                            print(f'{caseid}: ioh segment valid: {isIohSegmentValid}, {segmentStart}, {segmentEnd}, {abpSeg.shape}')
                    else:
                        ioh_events_valid.append(True)

            i += 1
            if not started:
                continue
            elif eofReached:
                break
            elif segFound:
                i = segmentEnd + 1

        # SECOND PASS
        # in the second forward pass, we are going to identify the start/end boundaries of all non-overlapping 30 minute "clean" windows
        # reuse the 'start of signal' index from our first pass
        if trackStartIndex is None:
            trackStartIndex = startInSeconds
        i = trackStartIndex
        eofReached = False

        clean_events_valid = []
        
        while i < track_length_seconds - 1800 and i < endInSeconds:
            segmentStart = None
            segmentEnd = None
            segFound = False

            startIndex = i
            endIndex = i + 1800

            # check to see if this 30 minute window overlaps any IOH events, if so ffwd to end of latest overlapping IOH
            overlapFound = False
            latestEnd = None
            for event in iohEvents:
                # case 1: starts during an event
                if startIndex >= event[0] and startIndex < event[1]:
                    latestEnd = event[1]
                    overlapFound = True
                # case 2: ends during an event
                elif endIndex >= event[0] and endIndex < event[1]:
                    latestEnd = event[1]
                    overlapFound = True
                # case 3: event occurs entirely inside of the window
                elif startIndex < event[0] and endIndex > event[1]:
                    latestEnd = event[1]
                    overlapFound = True

            # FFWD if we found an overlap
            if overlapFound:
                i = latestEnd + 1
                continue

            # look forward 30 minutes
            abpSeg = abp[startIndex * ABP_ECG_SRATE_HZ:endIndex * ABP_ECG_SRATE_HZ]

            # if we're started and mean abp for the window is >= 75, we are starting a new clean event
            if np.nanmean(abpSeg) >= 75:
                overlapFound = False
                latestEnd = None
                for event in iohEvents:
                    # case 1: starts during an event
                    if startIndex >= event[0] and startIndex < event[1]:
                        latestEnd = event[1]
                        overlapFound = True
                    # case 2: ends during an event
                    elif endIndex >= event[0] and endIndex < event[1]:
                        latestEnd = event[1]
                        overlapFound = True
                    # case 3: event occurs entirely inside of the window
                    elif startIndex < event[0] and endIndex > event[1]:
                        latestEnd = event[1]
                        overlapFound = True

                if not overlapFound:
                    segFound = True
                    segmentEnd = endIndex
                    cleanEvents.append((startIndex, endIndex))
                    
                    if skipInvalidCleanEvents:
                        isCleanSegmentValid = isAbpSegmentValidNumpy(abpSeg)
                        clean_events_valid.append(isCleanSegmentValid)
                        if debug:
                            print(f'{caseid}: clean segment valid: {isCleanSegmentValid}, {startIndex}, {endIndex}, {abpSeg.shape}')
                    else:
                        clean_events_valid.append(True)

            i += 10
            if segFound:
                i = segmentEnd + 1

        if debug:
            print(f"IOH Events for case {caseid}: {iohEvents}")
            print(f"Clean Events for case {caseid}: {cleanEvents}")

        positiveSegments = []
        negativeSegments = []

        # THIRD PASS
        # in the third pass, we will use the collections of ioh event windows to generate our actual extracted segments based on our prediction window (positive labels)
        for i in range(0, len(iohEvents)):
            # Don't extract segments from invalid IOH event windows.
            if not ioh_events_valid[i]:
                continue

            if debug:
                print(f"Checking event {iohEvents[i]}")
            # we want to review current event boundaries, as well as previous event boundaries if available
            event = iohEvents[i]
            previousEvent = None
            if i > 0:
                previousEvent = iohEvents[i - 1]

            for predWindow in ALL_PREDICTION_WINDOWS:
                if debug:
                    print(f"Checking event {iohEvents[i]} for pred {predWindow}")
                iohEventStart = event[0]
                predictiveSegmentEnd = event[0] - (predWindow*60)
                predictiveSegmentStart = predictiveSegmentEnd - 60

                if (predictiveSegmentStart < 0):
                    # don't rewind before the beginning of the track
                    if debug:
                        print(f"Checking event {iohEvents[i]} for pred {predWindow} - exit, before beginning")
                    continue
                elif (predictiveSegmentStart < trackStartIndex):
                    # don't rewind before the beginning of signal in track
                    if debug:
                        print(f"Checking event {iohEvents[i]} for pred {predWindow} - exit, before track start")
                    continue
                elif previousEvent is not None:
                    # does this event window come before or during the previous event?
                    overlapFound = False
                    # case 1: starts during an event
                    if predictiveSegmentStart >= previousEvent[0] and predictiveSegmentStart < previousEvent[1]:
                        overlapFound = True
                    # case 2: ends during an event
                    elif iohEventStart >= previousEvent[0] and iohEventStart < previousEvent[1]:
                        overlapFound = True
                    # case 3: event occurs entirely inside of the window
                    elif predictiveSegmentStart < previousEvent[0] and iohEventStart > previousEvent[1]:
                        overlapFound = True
                    # do not extract a case if we overlap with another IOH
                    if overlapFound:
                        if debug:
                            print(f"Checking event {iohEvents[i]} for pred {predWindow} - exit, overlap with earlier segment")
                        continue

                # track the positive segment
                positiveSegments.append((predictiveSegmentStart, predictiveSegmentEnd, predWindow,
                    abp[predictiveSegmentStart*ABP_ECG_SRATE_HZ:predictiveSegmentEnd*ABP_ECG_SRATE_HZ],
                    ecg[predictiveSegmentStart*ABP_ECG_SRATE_HZ:predictiveSegmentEnd*ABP_ECG_SRATE_HZ],
                    eeg[predictiveSegmentStart*EEG_SRATE_HZ:predictiveSegmentEnd*EEG_SRATE_HZ]))

        # FOURTH PASS
        # in the fourth and final pass, we will use the collections of clean event windows to generate our actual extracted segments (negative labels)
        for i in range(0, len(cleanEvents)):
            # Don't extract segments from invalid clean event windows.
            if not clean_events_valid[i]:
                continue
            
            # everything will be 30 minutes long at least
            event = cleanEvents[i]
            # choose sample 1 @ 10 minutes
            # choose sample 2 @ 15 minutes
            # choose sample 3 @ 20 minutes
            timeAtTen = event[0] + 600
            timeAtFifteen = event[0] + 900
            timeAtTwenty = event[0] + 1200

            negativeSegments.append((timeAtTen, timeAtTen + 60, 0,
                                   abp[timeAtTen*ABP_ECG_SRATE_HZ:(timeAtTen + 60)*ABP_ECG_SRATE_HZ],
                                   ecg[timeAtTen*ABP_ECG_SRATE_HZ:(timeAtTen + 60)*ABP_ECG_SRATE_HZ],
                                   eeg[timeAtTen*EEG_SRATE_HZ:(timeAtTen + 60)*EEG_SRATE_HZ]))
            negativeSegments.append((timeAtFifteen, timeAtFifteen + 60, 0,
                                   abp[timeAtFifteen*ABP_ECG_SRATE_HZ:(timeAtFifteen + 60)*ABP_ECG_SRATE_HZ],
                                   ecg[timeAtFifteen*ABP_ECG_SRATE_HZ:(timeAtFifteen + 60)*ABP_ECG_SRATE_HZ],
                                   eeg[timeAtFifteen*EEG_SRATE_HZ:(timeAtFifteen + 60)*EEG_SRATE_HZ]))
            negativeSegments.append((timeAtTwenty, timeAtTwenty + 60, 0,
                                   abp[timeAtTwenty*ABP_ECG_SRATE_HZ:(timeAtTwenty + 60)*ABP_ECG_SRATE_HZ],
                                   ecg[timeAtTwenty*ABP_ECG_SRATE_HZ:(timeAtTwenty + 60)*ABP_ECG_SRATE_HZ],
                                   eeg[timeAtTwenty*EEG_SRATE_HZ:(timeAtTwenty + 60)*EEG_SRATE_HZ]))

        if returnSegments:
            positiveSegmentsMap[caseid] = positiveSegments
            negativeSegmentsMap[caseid] = negativeSegments
            iohEventsMap[caseid] = iohEvents
            cleanEventsMap[caseid] = cleanEvents
        
        saveCaseSegments(caseid, positiveSegments, negativeSegments, 9, debug=debug, forceWrite=forceWrite)

        #if debug:
        print(f'{caseid}: positiveSegments: {len(positiveSegments)}, negativeSegments: {len(negativeSegments)}')

    return positiveSegmentsMap, negativeSegmentsMap, iohEventsMap, cleanEventsMap
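
The three-case overlap test appears several times in `extract_segments`. The sketch below mirrors those three cases in a single hypothetical helper, `overlaps_event`, which may make the boundary behavior easier to inspect; it is an illustration and not a drop-in replacement:

```python
def overlaps_event(start, end, event):
    """Mirror of the three overlap cases used in the second pass."""
    ev_start, ev_end = event
    starts_inside = ev_start <= start < ev_end       # case 1
    ends_inside = ev_start <= end < ev_end           # case 2
    contains_event = start < ev_start and end > ev_end  # case 3
    return starts_inside or ends_inside or contains_event

event = (1000, 1300)
print(overlaps_event(900, 1100, event))    # True, window ends during the event
print(overlaps_event(1200, 3000, event))   # True, window starts during the event
print(overlaps_event(0, 1800, event))      # True, event lies inside the window
print(overlaps_event(1301, 3101, event))   # False, window starts after the event

# Boundary case: the event ends exactly where the window ends
print(overlaps_event(0, 1800, (1500, 1800)))  # False
```

The last call suggests that an event ending on exactly the same second as the window is not flagged by these comparisons. Whether that matters in practice depends on how often such alignments occur, so it may be worth checking against real cases.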

### Case Extraction - Generate Segments Needed For Training

Ensure that all needed segments are in place for the cases that are being used. If data is already stored on disk this method returns immediately.

In [None]:
MANUAL_EXTRACT=True
SKIP_INVALID_CLEAN_EVENTS=True
SKIP_INVALID_IOH_EVENTS=True

if MANUAL_EXTRACT:
    mycoi = cases_of_interest_idx
    #mycoi = cases_of_interest_idx[:2800]
    #mycoi = [1]

    cnt = 0
    mod = 0
    for ci in mycoi:
        cnt += 1
        if mod % 100 == 0:
            print(f'count processed: {mod}, current case index: {ci}')
        try:
            p, n, i, c = extract_segments([ci], debug=False, checkCache=True, 
                                          forceWrite=True, returnSegments=False, 
                                          skipInvalidCleanEvents=SKIP_INVALID_CLEAN_EVENTS,
                                          skipInvalidIohEvents=SKIP_INVALID_IOH_EVENTS)
            p = None
            n = None
            i = None
            c = None
        except Exception as e:
            print(f'error on extract segment: {ci}: {e}')
        mod += 1
    print(f'extracted: {cnt}')

### Track and Segment Validity Checks

In [None]:
def printAbp(case_id_to_check, plot_invalid_only=False):
        vf_path = f'{VITAL_MINI}/{case_id_to_check:04d}_mini.vital'
        
        if not os.path.isfile(vf_path):
              return
        
        vf = vitaldb.VitalFile(vf_path)
        abp = vf.to_numpy(TRACK_NAMES[0], 1/500)
        
        print(f'Case {case_id_to_check}')
        print(f'ABP Shape: {abp.shape}')

        print(f'nanmin: {np.nanmin(abp)}')
        print(f'nanmean: {np.nanmean(abp)}')
        print(f'nanmax: {np.nanmax(abp)}')
        
        is_valid = isAbpSegmentValidNumpy(abp, debug=True)
        print(f'valid: {is_valid}')

        if plot_invalid_only and is_valid:
            return
            
        plt.figure(figsize=(20, 5))
        plt_color = 'C0' if is_valid else 'red'
        plt.plot(abp, plt_color)
        plt.title(f'ABP - Entire Track - Case {case_id_to_check} - {abp.shape[0] / 500} seconds')
        plt.axhline(y = 65, color = 'maroon', linestyle = '--')
        plt.show()

In [None]:
def printSegments(segmentsMap, case_id_to_check, print_label, normalize=False):
    for (x1, x2, r, abp, ecg, eeg) in segmentsMap[case_id_to_check]:
        print(f'{print_label}: Case {case_id_to_check}')
        print(f'lookback window: {r} min')
        print(f'start time: {x1}')
        print(f'end time: {x2}')
        print(f'length: {x2 - x1} sec')
        
        print(f'ABP Shape: {abp.shape}')
        print(f'ECG Shape: {ecg.shape}')
        print(f'EEG Shape: {eeg.shape}')

        print(f'nanmin: {np.nanmin(abp)}')
        print(f'nanmean: {np.nanmean(abp)}')
        print(f'nanmax: {np.nanmax(abp)}')
        
        is_valid = isAbpSegmentValidNumpy(abp, debug=True)
        print(f'valid: {is_valid}')

        # ABP normalization
        x_abp = np.copy(abp)
        if normalize:
            x_abp -= 65
            x_abp /= 65

        plt.figure(figsize=(20, 5))
        plt_color = 'C0' if is_valid else 'red'
        plt.plot(x_abp, plt_color)
        plt.title('ABP')
        plt.axhline(y = 65, color = 'maroon', linestyle = '--')
        plt.show()

        plt.figure(figsize=(20, 5))
        plt.plot(ecg, 'teal')
        plt.title('ECG')
        plt.show()

        plt.figure(figsize=(20, 5))
        plt.plot(eeg, 'indigo')
        plt.title('EEG')
        plt.show()

        print()

In [None]:
def printEvents(abp_raw, eventsMap, case_id_to_check, print_label, normalize=False):
    for (x1, x2) in eventsMap[case_id_to_check]:
        print(f'{print_label}: Case {case_id_to_check}')
        print(f'start time: {x1}')
        print(f'end time: {x2}')
        print(f'length: {x2 - x1} sec')

        abp = abp_raw[x1*500:x2*500]
        print(f'ABP Shape: {abp.shape}')

        print(f'nanmin: {np.nanmin(abp)}')
        print(f'nanmean: {np.nanmean(abp)}')
        print(f'nanmax: {np.nanmax(abp)}')
        
        is_valid = isAbpSegmentValidNumpy(abp, debug=True)
        print(f'valid: {is_valid}')

        # ABP normalization
        x_abp = np.copy(abp)
        if normalize:
            x_abp -= 65
            x_abp /= 65

        plt.figure(figsize=(20, 5))
        plt_color = 'C0' if is_valid else 'red'
        plt.plot(x_abp, plt_color)
        plt.title('ABP')
        plt.axhline(y = 65, color = 'maroon', linestyle = '--')
        plt.show()

        print()

In [None]:
def moving_average(x, seconds=60):
    w = seconds * 500
    return np.convolve(np.squeeze(x), np.ones(w), 'valid') / w
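
The `'valid'` mode of `np.convolve` only returns points where the window fully overlaps the signal, so the smoothed output is `w - 1` samples shorter than the input. A toy sketch follows, using a hypothetical variant `moving_average_hz` that takes the sampling rate as a parameter instead of assuming 500 Hz:

```python
import numpy as np

def moving_average_hz(x, seconds, srate_hz):
    """Moving average over `seconds` of signal sampled at `srate_hz`."""
    w = int(seconds * srate_hz)
    return np.convolve(np.squeeze(x), np.ones(w), 'valid') / w

signal = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
smoothed = moving_average_hz(signal, seconds=1, srate_hz=2)
print(smoothed)                     # [1.5 2.5 3.5 4.5]
print(len(signal) - len(smoothed))  # 1
```

When overlaying the smoothed trace on the raw waveform, this length difference means the x-axes are offset unless the smoothed series is shifted or padded.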

In [None]:
def printAbpOverlay(
    case_id_to_check,
    positiveSegmentsMap,
    negativeSegmentsMap,
    iohEventsMap,
    cleanEventsMap,
    movingAverage=False
):
    def overlay_segments(plt, segmentsMap, color, linestyle, positive=False):
        for (x1, x2, r, abp, ecg, eeg) in segmentsMap:
            sx1 = x1*500
            sx2 = x2*500
            mycolor = color
            if positive:
                if r == 3:
                    mycolor = 'red'
                elif r == 5:
                    mycolor = 'crimson'
                elif r == 10:
                    mycolor = 'tomato'
                else:
                    mycolor = 'salmon'
            plt.axvline(x = sx1, color = mycolor, linestyle = linestyle)
            plt.axvline(x = sx2, color = mycolor, linestyle = linestyle)
            plt.axvspan(sx1, sx2, facecolor = mycolor, alpha = 0.1)

    def overlay_events(plt, abp, eventsMap, opstart, opend, color, linestyle):
        for (x1, x2) in eventsMap:
            sx1 = x1*500
            sx2 = x2*500
            # only plot valid events
            if isAbpSegmentValidNumpy(abp[sx1:sx2]):
                # that are within the operating start and end times
                if sx1 >= opstart and sx2 <= opend:
                    plt.axvline(x = sx1, color = color, linestyle = linestyle)
                    plt.axvline(x = sx2, color = color, linestyle = linestyle)
                    plt.axvspan(sx1, sx2, facecolor = color, alpha = 0.1)

    vf_path = f'{VITAL_MINI}/{case_id_to_check:04d}_mini.vital'

    if not os.path.isfile(vf_path):
        return

    vf = vitaldb.VitalFile(vf_path)
    abp = vf.to_numpy(TRACK_NAMES[0], 1/500)

    print(f'Case {case_id_to_check}')
    print(f'ABP Shape: {abp.shape}')

    print(f'nanmin: {np.nanmin(abp)}')
    print(f'nanmean: {np.nanmean(abp)}')
    print(f'nanmax: {np.nanmax(abp)}')

    #is_valid = isAbpSegmentValidNumpy(abp, debug=True)
    #print(f'valid: {is_valid}')

    plt.figure(figsize=(24, 8))
    plt_color = 'C0' #if is_valid else 'red'
    plt.plot(abp, plt_color)
    plt.title(f'ABP - Entire Track - Case {case_id_to_check} - {abp.shape[0] / 500} seconds')
    plt.axhline(y = 65, color = 'maroon', linestyle = '--')

    # https://matplotlib.org/stable/gallery/lines_bars_and_markers/linestyles.html#linestyles
    
    opstart = cases.loc[case_id_to_check]['opstart'].item() * 500
    plt.axvline(x = opstart, color = 'black', linestyle = '--', linewidth=2)
    plt.text(opstart - 600000, -200, f'Operation Start', fontsize=15)
    
    opend = cases.loc[case_id_to_check]['opend'].item() * 500
    plt.axvline(x = opend, color = 'black', linestyle = '--', linewidth=2)
    plt.text(opend + 50000, -200, r'Operation End', fontsize=15)
    
    overlay_segments(plt, positiveSegmentsMap[case_id_to_check], 'crimson', (0, (1, 1)), positive=True)
    
    overlay_segments(plt, negativeSegmentsMap[case_id_to_check], 'teal', (0, (1, 1)))

    overlay_events(plt, abp, iohEventsMap[case_id_to_check], opstart, opend, 'brown', '-')
    
    overlay_events(plt, abp, cleanEventsMap[case_id_to_check], opstart, opend, 'teal', '-')
    
    abp_mov_avg = None
    if movingAverage:
        abp_mov_avg = moving_average(abp[opstart:(opend + 60*500)])
        myx = np.arange(opstart, opstart + len(abp_mov_avg), 1)
        plt.plot(myx, abp_mov_avg, 'red')

    plt.show()

### Reality Check All Cases

In [None]:
# Global flag to control creating track and segment plots.
# These plots are expensive to create, but very interesting.
# Disable when training in bulk to speed up notebook processing.
PERFORM_TRACK_VALIDITY_CHECKS = False

In [None]:
# Check if all ABPs are well formed. Fast load and scan of the raw track data for ABP.
DISPLAY_REALITY_CHECK_ABP=True
DISPLAY_REALITY_CHECK_ABP_FIRST_ONLY=True

if PERFORM_TRACK_VALIDITY_CHECKS and DISPLAY_REALITY_CHECK_ABP:
    for case_id_to_check in cases_of_interest_idx:
        printAbp(case_id_to_check, plot_invalid_only=False)
        
        if DISPLAY_REALITY_CHECK_ABP_FIRST_ONLY:
            break

### Validate Malformed Vital Files - Missing One Or More Tracks

Cases which were found to be missing one or more data tracks are stored in `malformed_tracks_filter.csv`. These can be analyzed below:

In [None]:
# These are Vital Files removed because of malformed ABP waveforms.
DISPLAY_MALFORMED_ABP=True
DISPLAY_MALFORMED_ABP_FIRST_ONLY=True

if PERFORM_TRACK_VALIDITY_CHECKS and DISPLAY_MALFORMED_ABP:
    malformed_case_ids = pd.read_csv('malformed_tracks_filter.csv', header=None, names=['caseid']).set_index('caseid').index

    for case_id_to_check in malformed_case_ids:
        printAbp(case_id_to_check)
        
        if DISPLAY_MALFORMED_ABP_FIRST_ONLY:
            break

### Validate Cases With No Segments Saved

Cases which were found to not result in any extracted segments can be analyzed below to better understand why:

In [None]:
DISPLAY_NO_SEGMENTS_CASES=True
DISPLAY_NO_SEGMENTS_CASES_FIRST_ONLY=True

if PERFORM_TRACK_VALIDITY_CHECKS and DISPLAY_NO_SEGMENTS_CASES:
    no_segments_case_ids = [3413, 3476, 3533, 3992, 4328, 4648, 4703, 4733, 5130, 5501, 5693, 5908]

    for case_id_to_check in no_segments_case_ids:
        printAbp(case_id_to_check)
        
        if DISPLAY_NO_SEGMENTS_CASES_FIRST_ONLY:
            break

### Select Case For Segment Extraction Validation

Generate segment data for one or more cases. Perform a deep analysis of event and segment quality.

In [None]:
# NOTE: This is always set so that if this section of checks is skipped, the model prediction plots will match.
my_cases_of_interest_idx = [84, 198, 60, 16, 27]

# Note: By default, these settings match the segment extraction block above.
# However, the data is regenerated on the fly (no cache read, no forced write)
# so that the impact of changes to segment extraction is visible immediately.
# This is why both checkCache and forceWrite are False by default.
positiveSegmentsMap, negativeSegmentsMap, iohEventsMap, cleanEventsMap = None, None, None, None

if PERFORM_TRACK_VALIDITY_CHECKS:
    positiveSegmentsMap, negativeSegmentsMap, iohEventsMap, cleanEventsMap = \
        extract_segments(my_cases_of_interest_idx, debug=False,
                         checkCache=False, forceWrite=False, returnSegments=True,
                         skipInvalidCleanEvents=SKIP_INVALID_CLEAN_EVENTS,
                         skipInvalidIohEvents=SKIP_INVALID_IOH_EVENTS)

Select a specific case to perform detailed low level analysis.

In [None]:
case_id_to_check = my_cases_of_interest_idx[0]
print(case_id_to_check)
print()

if PERFORM_TRACK_VALIDITY_CHECKS:
    print((
        len(positiveSegmentsMap[case_id_to_check]),
        len(negativeSegmentsMap[case_id_to_check]),
        len(iohEventsMap[case_id_to_check]),
        len(cleanEventsMap[case_id_to_check])
    ))

In [None]:
if PERFORM_TRACK_VALIDITY_CHECKS:
    printAbp(case_id_to_check)

### Positive Events for Case - IOH Events
These events define the regions ahead of which positive segments are extracted. Positive samples occur before each IOH event.

In [None]:
tmp_abp = None

if PERFORM_TRACK_VALIDITY_CHECKS:
    tmp_vf_path = f'{VITAL_MINI}/{case_id_to_check:04d}_mini.vital'
    tmp_vf = vitaldb.VitalFile(tmp_vf_path)
    tmp_abp = tmp_vf.to_numpy(TRACK_NAMES[0], 1/500)

In [None]:
if PERFORM_TRACK_VALIDITY_CHECKS:
    printEvents(tmp_abp, iohEventsMap, case_id_to_check, 'IOH Event Segment', normalize=False)

### Negative Events for Case - Non-IOH Events
These events define the regions within which negative segments are extracted. Negative samples occur inside each clean event.

In [None]:
if PERFORM_TRACK_VALIDITY_CHECKS:
    printEvents(tmp_abp, cleanEventsMap, case_id_to_check, 'Clean Event Segment', normalize=False)

### Positive Segments for Case - IOH Events Predicted Using These
One minute regions sampled and used for training the model for "positive" events.

In [None]:
if PERFORM_TRACK_VALIDITY_CHECKS:
    printSegments(positiveSegmentsMap, case_id_to_check, 'Positive Segment - IOH Event', normalize=False)

### Negative Segments for Case - Non-IOH Events Predicted Using These
One minute regions sampled and used for training the model for "negative" events.

In [None]:
if PERFORM_TRACK_VALIDITY_CHECKS:
    printSegments(negativeSegmentsMap, case_id_to_check, 'Negative Segment - Non-Event', normalize=False)

### Overlay Plot of All Events and Segments Extracted
For each of the cases in `my_cases_of_interest_idx` overlay the results of event and segment extraction.

In [None]:
DISPLAY_OVERLAY_CHECK_ABP=True
DISPLAY_OVERLAY_CHECK_ABP_FIRST_ONLY=True

if PERFORM_TRACK_VALIDITY_CHECKS and DISPLAY_OVERLAY_CHECK_ABP:
    for case_id_to_check in my_cases_of_interest_idx:
        printAbpOverlay(case_id_to_check, positiveSegmentsMap, 
                        negativeSegmentsMap, iohEventsMap, cleanEventsMap, movingAverage=False)
        
        if DISPLAY_OVERLAY_CHECK_ABP_FIRST_ONLY:
            break

In [None]:
# Memory cleanup
del tmp_abp

## Generate Train/Val/Test Splits

When case segments are stored to disk, the filename is intentionally constructed so that its metadata can be easily reconstructed. The format is `{case}_{startX}_{predWindow}_{label}.h5`, where `{case}` is the case ID, `{startX}` is the start index of the segment, in seconds, from the start of the `.vital` track, `{predWindow}` is the prediction window in minutes (3, 5, 10, or 15; the loading code below also accepts 0, which it includes regardless of the selected window), and `{label}` indicates whether the segment is associated with a hypotensive event (`True`) or not (`False`).

In [None]:
def get_segment_attributes_from_filename(file_path):
    pieces = os.path.basename(file_path).split('_')
    case = int(pieces[0])
    startX = int(pieces[1])
    predWindow = int(pieces[2])
    label = pieces[3].replace('.h5', '')
    return (case, startX, predWindow, label)
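
As an illustration, the sketch below parses a filename in this format. The directory, case ID, and start index are made-up values, and `parse_segment_filename` is a standalone stand-in for the function above that strips the extension with `os.path.splitext`:

```python
import os


def parse_segment_filename(file_path):
    """Recover (case, startX, predWindow, label) from a segment filename."""
    stem, _ext = os.path.splitext(os.path.basename(file_path))
    case, start_x, pred_window, label = stem.split('_')
    return int(case), int(start_x), int(pred_window), label


# Hypothetical segment: case 84, starting 1200 s into the track, 3-minute window, positive
example_path = os.path.join('segments', '0084', '84_1200_3_True.h5')
parsed = parse_segment_filename(example_path)
print(parsed)
```

Note that the label comes back as the string `'True'` or `'False'`, not a boolean, so downstream code compares against those strings.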

In [None]:
count_negative_samples = 0
count_positive_samples = 0

samples = []

seg_folder = f"{VITAL_EXTRACTED_SEGMENTS}"
filenames = [y for x in os.walk(seg_folder) for y in glob(os.path.join(x[0], '*.h5'))]

for filename in filenames:
    (case, start_x, pred_window, label) = get_segment_attributes_from_filename(filename)
    
    # only load segments for cases of interest; this folder could have segments for hundreds of cases
    if case not in cases_of_interest_idx:
        continue

    if pred_window == 0 or pred_window == PREDICTION_WINDOW or PREDICTION_WINDOW == 'ALL':
        #print((case, start_x, pred_window, label))
        if label == 'True':
            count_positive_samples += 1
        else:
            count_negative_samples += 1
        sample = (filename, label)
        samples.append(sample)

print()
print(f"samples loaded:         {len(samples):5} ")
print(f'count negative samples: {count_negative_samples:5}')
print(f'count positive samples: {count_positive_samples:5}')

In [None]:
# Divide by cases
sample_cases = defaultdict(list)

for fn, _ in samples:
    (case, start_x, pred_window, label) = get_segment_attributes_from_filename(fn)
    sample_cases[case].append((fn, label))

# understand any missing cases of interest
sample_cases_idx = pd.Index(sample_cases.keys())
missing_case_ids = cases_of_interest_idx.difference(sample_cases_idx)
print(f'cases with no samples: {missing_case_ids.shape[0]}')
print(f'    {missing_case_ids}')

### Split data into training, validation, and test sets

Use a 6:1:3 ratio and keep all samples from a single case in the same set, so that no case is split across sets.

**Note:** The number of samples at each time point is not the same, because the first event can occur before the 3/5/10/15 minute mark.


In [None]:
# Set target sizes
train_ratio = 0.6
val_ratio = 0.1
test_ratio = 1 - train_ratio - val_ratio # ensure ratios sum to 1

# Split samples into train and other
sample_cases_train, sample_cases_other = train_test_split(list(sample_cases.keys()), test_size=(1 - train_ratio), random_state=RANDOM_SEED)

# Split other into val and test
sample_cases_val, sample_cases_test = train_test_split(sample_cases_other, test_size=(test_ratio / (1 - train_ratio)), random_state=RANDOM_SEED)

# Check how many cases are in each set
print(f'Train/Val/Test Summary by Cases')
print(f"Train cases:  {len(sample_cases_train):5}, ({len(sample_cases_train) / len(sample_cases):.2%})")
print(f"Val cases:    {len(sample_cases_val):5}, ({len(sample_cases_val) / len(sample_cases):.2%})")
print(f"Test cases:   {len(sample_cases_test):5}, ({len(sample_cases_test) / len(sample_cases):.2%})")
print(f"Total cases:  {(len(sample_cases_train) + len(sample_cases_val) + len(sample_cases_test)):5}")
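
The second `test_size` above is `test_ratio / (1 - train_ratio)` because the second split only sees the 40% of cases left over from the first one, so 0.3 / 0.4 = 0.75 of that remainder goes to the test set. A dependency-free sketch of a case-level split, using a seeded shuffle in place of `train_test_split` and 100 made-up case IDs (both are assumptions for illustration):

```python
import random


def split_cases(case_ids, train_ratio=0.6, val_ratio=0.1, seed=0):
    """Shuffle case IDs once, then cut them into train/val/test groups."""
    shuffled = list(case_ids)
    random.Random(seed).shuffle(shuffled)
    n_train = round(len(shuffled) * train_ratio)
    n_val = round(len(shuffled) * val_ratio)
    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]  # whatever remains (about 30%)
    return train, val, test


train_ids, val_ids, test_ids = split_cases(range(1, 101))
print(len(train_ids), len(val_ids), len(test_ids))

# Splitting by case ID means no case can appear in more than one set
assert not (set(train_ids) & set(val_ids))
assert not (set(train_ids) & set(test_ids))
assert not (set(val_ids) & set(test_ids))
```

Because the split is made on case IDs and segments are assigned afterwards, the resulting segment ratio may differ from the case ratio when cases contribute different numbers of segments.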

Now that the cases have been split according to the desired ratio, assign all of the segments for each case into the target (train, validation, test) set:

In [None]:
sample_cases_train = set(sample_cases_train)
sample_cases_val = set(sample_cases_val)
sample_cases_test = set(sample_cases_test)

samples_train = []
samples_val = []
samples_test = []

for cid, segs in sample_cases.items():
    if cid in sample_cases_train:
        for seg in segs:
            samples_train.append(seg)
    if cid in sample_cases_val:
        for seg in segs:
            samples_val.append(seg)
    if cid in sample_cases_test:
        for seg in segs:
            samples_test.append(seg)
            
# Check how many samples are in each set
print(f'Train/Val/Test Summary by Events')
print(f"Train events:  {len(samples_train):5}, ({len(samples_train) / len(samples):.2%})")
print(f"Val events:    {len(samples_val):5}, ({len(samples_val) / len(samples):.2%})")
print(f"Test events:   {len(samples_test):5}, ({len(samples_test) / len(samples):.2%})")
print(f"Total events:  {(len(samples_train) + len(samples_val) + len(samples_test)):5}")

### Validate train/val/test Splits

Verify the label distribution in each set:

In [None]:
PRINT_ALL_CASE_SPLIT_DETAILS = False

case_to_sample_distribution = defaultdict(lambda: {'train': [0, 0], 'val': [0, 0], 'test': [0, 0]})

def populate_case_to_sample_distribution(mysamples, idx):
    neg = 0
    pos = 0
    
    for fn, _ in mysamples:
        (case, start_x, pred_window, label) = get_segment_attributes_from_filename(fn)
        slot = 0 if label == 'False' else 1
        case_to_sample_distribution[case][idx][slot] += 1
        if slot == 0:
            neg += 1
        else:
            pos += 1
                
    return (neg, pos)

train_neg, train_pos = populate_case_to_sample_distribution(samples_train, 'train')
val_neg, val_pos     = populate_case_to_sample_distribution(samples_val,   'val')
test_neg, test_pos   = populate_case_to_sample_distribution(samples_test,  'test')

print(f'Total Cases Present: {len(case_to_sample_distribution):5}')
print()

train_tot = train_pos + train_neg
val_tot = val_pos + val_neg
test_tot = test_pos + test_neg
print(f'Train: P: {train_pos:5} ({(train_pos/train_tot):.2}), N: {train_neg:5} ({(train_neg/train_tot):.2})')
print(f'Val:   P: {val_pos:5} ({(val_pos/val_tot):.2}), N: {val_neg:5} ({(val_neg/val_tot):.2})')
print(f'Test:  P: {test_pos:5} ({(test_pos/test_tot):.2}), N: {test_neg:5}  ({(test_neg/test_tot):.2})')
print()

total_pos = train_pos + val_pos + test_pos
total_neg = train_neg + val_neg + test_neg
total = total_pos + total_neg
print(f'P/N Ratio: {(total_pos)}:{(total_neg)}')
print(f'P Percent: {(total_pos/total):.2}')
print(f'N Percent: {(total_neg/total):.2}')
print()

if PRINT_ALL_CASE_SPLIT_DETAILS:
    for ci in sorted(case_to_sample_distribution.keys()):
        print(f'{ci}: {case_to_sample_distribution[ci]}')

Verify that no cases have leaked between the train, validation, and test sets:

In [None]:
def check_data_leakage(full_data, train_data, val_data, test_data):
    # Convert to sets for easier operations
    full_data_set = set(full_data)
    train_data_set = set(train_data)
    val_data_set = set(val_data)
    test_data_set = set(test_data)

    # Check that train, val, test only contain items from full_data
    if not train_data_set.issubset(full_data_set):
        return "Train data contains items not present in the full data"
    if not val_data_set.issubset(full_data_set):
        return "Validation data contains items not present in the full data"
    if not test_data_set.issubset(full_data_set):
        return "Test data contains items not present in the full data"

    # Check if train, val, test are disjoint
    if train_data_set & val_data_set:
        return "Train and validation data are not disjoint"
    if train_data_set & test_data_set:
        return "Train and test data are not disjoint"
    if val_data_set & test_data_set:
        return "Validation and test data are not disjoint"

    return "No data leakage detected"

print(check_data_leakage(list(sample_cases.keys()), sample_cases_train, sample_cases_val, sample_cases_test))

Create a custom `vitalDataset` class derived from `Dataset` to be used by the data loaders:

In [None]:
# Create vitalDataset class
class vitalDataset(Dataset):
    def __init__(self, samples, normalize_abp=False):
        self.samples = samples
        self.normalize_abp = normalize_abp

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Get metadata for this event
        segment = self.samples[idx]

        file_path = segment[0]
        label = (segment[1] == "True" or segment[1] == "True.vital")

        (abp, ecg, eeg) = get_segment_data(file_path)

        if abp is None or eeg is None or ecg is None:
            return (np.zeros(30000), np.zeros(30000), np.zeros(7680), 0)
        
        if self.normalize_abp:
            abp -= 65
            abp /= 65

        return abp, ecg, eeg, label

NORMALIZE_ABP = False

train_dataset = vitalDataset(samples_train, NORMALIZE_ABP)
val_dataset = vitalDataset(samples_val, NORMALIZE_ABP)
test_dataset = vitalDataset(samples_test, NORMALIZE_ABP)
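
When `normalize_abp` is enabled, the ABP waveform is shifted and scaled around the 65 mmHg hypotension threshold, so 65 mmHg maps to 0 and 130 mmHg maps to 1. A small sketch on plain Python lists (the sample values are made up):

```python
IOH_THRESHOLD_MMHG = 65  # same constant used by vitalDataset above


def normalize_abp(values, threshold=IOH_THRESHOLD_MMHG):
    """Apply (x - threshold) / threshold to each ABP sample."""
    return [(v - threshold) / threshold for v in values]


sample_abp = [32.5, 65, 97.5, 130]  # illustrative pressures in mmHg
normalized = normalize_abp(sample_abp)
print(normalized)
```

Values below the threshold become negative after normalization, so any reference line drawn at the threshold sits at 0 rather than 65 on normalized data.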

### Train/val/test Splits Summary Statistics

Analyze the distribution of mean ABP values in each dataset to verify that the splits have similar characteristics:

In [None]:
def generate_nan_means(mydataset):
    xs = np.zeros(len(mydataset))
    ys = np.zeros(len(mydataset), dtype=int)

    for i, (abp, ecg, eeg, y) in enumerate(iter(mydataset)):
        xs[i] = np.nanmean(abp)
        ys[i] = int(y)

    return pd.DataFrame({'abp_nanmean': xs, 'label': ys})

In [None]:
def generate_nan_means_summaries(tr, va, te, group='all'):
    if group == 'all':
        return pd.DataFrame({
            'train': tr.describe()['abp_nanmean'],
            'validation': va.describe()['abp_nanmean'],
            'test': te.describe()['abp_nanmean']
        })
    
    mytr = tr.reset_index()
    myva = va.reset_index()
    myte = te.reset_index()
    
    label_flag = (group == 'positive')
    
    return pd.DataFrame({
        'train':      mytr[mytr['label'] == label_flag].describe()['abp_nanmean'],
        'validation': myva[myva['label'] == label_flag].describe()['abp_nanmean'],
        'test':       myte[myte['label'] == label_flag].describe()['abp_nanmean']
    })

In [None]:
def plot_nan_means(df, plot_label):
    mydf = df.reset_index()

    maxCases = 'ALL' if MAX_CASES is None else MAX_CASES
    plot_title = f'{plot_label} - ABP nanmean Values, {PREDICTION_WINDOW} Minutes, {maxCases} Cases'
    
    ax = mydf[mydf['label'] == False].plot.scatter(
        x='index', y='abp_nanmean', color='DarkBlue', label='Negative', 
        title=plot_title, figsize=(16,9))

    negative_median = mydf[mydf['label'] == False]['abp_nanmean'].median()
    ax.axhline(y=negative_median, color='DarkBlue', linestyle='--', label='Negative Median')
    
    mydf[mydf['label'] == True].plot.scatter(
        x='index', y='abp_nanmean', color='DarkOrange', label='Positive', ax=ax);
    
    positive_median = mydf[mydf['label'] == True]['abp_nanmean'].median()
    ax.axhline(y=positive_median, color='DarkOrange', linestyle='--', label='Positive Median')
    
    ax.legend(loc='upper right')

In [None]:
def plot_nan_means_hist(df):
    df.plot.hist(column=['abp_nanmean'], by='label', bins=50, figsize=(10, 8));

In [None]:
train_abp_nanmeans = generate_nan_means(train_dataset)
val_abp_nanmeans = generate_nan_means(val_dataset)
test_abp_nanmeans = generate_nan_means(test_dataset)

#### ABP Nanmean Summaries

In [None]:
generate_nan_means_summaries(train_abp_nanmeans, val_abp_nanmeans, test_abp_nanmeans)

In [None]:
generate_nan_means_summaries(train_abp_nanmeans, val_abp_nanmeans, test_abp_nanmeans, group='positive')

In [None]:
generate_nan_means_summaries(train_abp_nanmeans, val_abp_nanmeans, test_abp_nanmeans, group='negative')

#### ABP Nanmean Histograms

In [None]:
plot_nan_means_hist(train_abp_nanmeans)

In [None]:
plot_nan_means_hist(val_abp_nanmeans)

In [None]:
plot_nan_means_hist(test_abp_nanmeans)

#### ABP Nanmean Scatter Plots

In [None]:
plot_nan_means(train_abp_nanmeans, 'Train')

In [None]:
plot_nan_means(val_abp_nanmeans, 'Validation')

In [None]:
plot_nan_means(test_abp_nanmeans, 'Test')

In [None]:
# Memory cleanup
del train_abp_nanmeans
del val_abp_nanmeans
del test_abp_nanmeans

## Classification Studies

Check if data can be easily classified using non-deep learning methods. Create a balanced sample of IOH and non-IOH events and use a simple classifier to see if the data can be easily separated. Datasets which can be easily separated by non-deep learning methods should also be easily classified by deep learning models.

In [None]:
MAX_CLASSIFICATION_SAMPLES = 250
MAX_SAMPLE_SIZE = 1600
classification_sample_size = MAX_SAMPLE_SIZE if len(samples) >= MAX_SAMPLE_SIZE else len(samples)

classification_samples = random.sample(samples, classification_sample_size)

positive_samples = []
negative_samples = []

for sample in classification_samples:
    (sampleAbp, sampleEcg, sampleEeg) = get_segment_data(sample[0])
    
    if sample[1] == "True":
        positive_samples.append([sample[0], True, sampleAbp, sampleEcg, sampleEeg])
    else:
        negative_samples.append([sample[0], False, sampleAbp, sampleEcg, sampleEeg])

positive_samples = pd.DataFrame(positive_samples, columns=["file_path", "segment_label", "segment_abp", "segment_ecg", "segment_eeg"])
negative_samples = pd.DataFrame(negative_samples, columns=["file_path", "segment_label", "segment_abp", "segment_ecg", "segment_eeg"])

total_to_sample_pos = MAX_CLASSIFICATION_SAMPLES if len(positive_samples) >= MAX_CLASSIFICATION_SAMPLES else len(positive_samples)
total_to_sample_neg = MAX_CLASSIFICATION_SAMPLES if len(negative_samples) >= MAX_CLASSIFICATION_SAMPLES else len(negative_samples)

# Select up to MAX_CLASSIFICATION_SAMPLES random samples where segment_label is True
positive_samples = positive_samples.sample(total_to_sample_pos, random_state=RANDOM_SEED)
# Select up to MAX_CLASSIFICATION_SAMPLES random samples where segment_label is False
negative_samples = negative_samples.sample(total_to_sample_neg, random_state=RANDOM_SEED)

print(f'positive_samples: {len(positive_samples)}')
print(f'negative_samples: {len(negative_samples)}')

# Combine the positive and negative samples
samples_balanced = pd.concat([positive_samples, negative_samples])

Define function to build data for study. Each waveform field can be enabled or disabled:

In [None]:
def get_x_y(samples, use_abp, use_ecg, use_eeg):
    # Create X and y, using data from `samples_balanced` and the `use_abp`, `use_ecg`, and `use_eeg` variables
    X = []
    y = []
    for i in range(len(samples)):
        row = samples.iloc[i]
        sample = np.array([])
        if use_abp:
            if len(row['segment_abp']) != 30000:
                print(len(row['segment_abp']))
            sample = np.append(sample, row['segment_abp'])
        if use_ecg:
            if len(row['segment_ecg']) != 30000:
                print(len(row['segment_ecg']))
            sample = np.append(sample, row['segment_ecg'])
        if use_eeg:
            if len(row['segment_eeg']) != 7680:
                print(len(row['segment_eeg']))
            sample = np.append(sample, row['segment_eeg'])
        X.append(sample)
        # Convert the label from boolean to 0 or 1
        y.append(int(row['segment_label']))
    return X, y
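
Because the enabled waveforms are concatenated end to end, the width of each row in `X` depends on which signals are selected. The sketch below tabulates the resulting feature counts from the segment lengths checked in `get_x_y` (30,000 samples for ABP and ECG, 7,680 for EEG); the helper name is illustrative:

```python
SEGMENT_LENGTHS = {'abp': 30000, 'ecg': 30000, 'eeg': 7680}


def feature_length(use_abp, use_ecg, use_eeg):
    """Number of features per sample for a given combination of waveforms."""
    enabled = {'abp': use_abp, 'ecg': use_ecg, 'eeg': use_eeg}
    return sum(SEGMENT_LENGTHS[name] for name, used in enabled.items() if used)


for combo in [(True, False, False), (True, False, True), (True, True, True)]:
    print(combo, feature_length(*combo))
```

With at most a few hundred balanced samples, each row has far more features than there are samples, which is worth keeping in mind when interpreting distance-based methods such as KNN.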

#### KNN

Define KNN run. This is configurable to enable or disable different data channels so that we can study them individually or together:

In [None]:
N_NEIGHBORS = 20

def run_knn(samples, use_abp, use_ecg, use_eeg):
    # Get samples
    X,y = get_x_y(samples, use_abp, use_ecg, use_eeg)

    # Split samples into train and test
    knn_X_train, knn_X_test, knn_y_train, knn_y_test = train_test_split(X, y, test_size=0.2, random_state=RANDOM_SEED)

    # Normalize the data
    scaler = StandardScaler()
    scaler.fit(knn_X_train)

    knn_X_train = scaler.transform(knn_X_train)
    knn_X_test = scaler.transform(knn_X_test)

    # Initialize the KNN classifier
    knn = KNeighborsClassifier(n_neighbors=N_NEIGHBORS)

    # Train the KNN classifier
    knn.fit(knn_X_train, knn_y_train)

    # Make predictions on the test set
    knn_y_pred = knn.predict(knn_X_test)

    # Evaluate the KNN classifier
    print(f"ABP: {use_abp}, ECG: {use_ecg}, EEG: {use_eeg}")
    print(f"Confusion matrix:\n{confusion_matrix(knn_y_test, knn_y_pred)}")
    print(f"Classification report:\n{classification_report(knn_y_test, knn_y_pred)}")

Study each waveform independently, then ABP+EEG (which had the best results in the paper), and ABP+ECG+EEG:

In [None]:
run_knn(samples_balanced, use_abp=True, use_ecg=False, use_eeg=False)
run_knn(samples_balanced, use_abp=False, use_ecg=True, use_eeg=False)
run_knn(samples_balanced, use_abp=False, use_ecg=False, use_eeg=True)
run_knn(samples_balanced, use_abp=True, use_ecg=False, use_eeg=True)
run_knn(samples_balanced, use_abp=True, use_ecg=True, use_eeg=True)

Based on the results above, the ABP and ABP+EEG data are somewhat predictive according to the macro-average F1-score, the ECG and EEG data are weakly predictive, and the ABP+ECG+EEG data is somewhat less predictive than either ABP or ABP+EEG.

Models based on ABP data alone, or ABP+EEG data are expected to train well with good performance. The other signals appear to mostly add noise and are not strongly predictive. This agrees with the results from the paper.
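
For reference, the macro-average F1-score reported by `classification_report` is the unweighted mean of the per-class F1-scores. A minimal sketch computing it from a 2x2 confusion matrix, laid out as in scikit-learn with rows as true labels and columns as predictions; the counts are made up and are not results from this notebook:

```python
def macro_f1(confusion):
    """Unweighted mean of per-class F1 from a square confusion matrix."""
    n_classes = len(confusion)
    scores = []
    for k in range(n_classes):
        tp = confusion[k][k]
        fn = sum(confusion[k]) - tp                               # true k, predicted other
        fp = sum(confusion[r][k] for r in range(n_classes)) - tp  # predicted k, true other
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / n_classes


# Hypothetical counts: rows = true class (0, 1), columns = predicted class (0, 1)
example_cm = [[40, 10],
              [20, 30]]
print(round(macro_f1(example_cm), 4))
```

Since the classification samples are balanced, the macro average weights IOH and non-IOH performance equally.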

#### t-SNE

Define t-SNE run. This is configurable to enable or disable different data channels so that we can study them individually or together:

In [None]:
def run_tsne(samples, use_abp, use_ecg, use_eeg):
    # Get samples
    X,y = get_x_y(samples, use_abp, use_ecg, use_eeg)
    
    # Convert X and y to numpy arrays
    X = np.array(X)
    y = np.array(y)

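    # Run t-SNE on the samples, embedding into 2 dimensions for the scatter plot
    # (n_components is the embedding dimension, not the number of classes)
    tsne = TSNE(n_components=2, random_state=RANDOM_SEED)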
    X_tsne = tsne.fit_transform(X)
    
    # Create a scatter plot of the t-SNE representation
    plt.figure(figsize=(16, 9))
    plt.title(f"use_abp={use_abp}, use_ecg={use_ecg}, use_eeg={use_eeg}")
    for i, label in enumerate(set(y)):
        plt.scatter(X_tsne[y == label, 0], X_tsne[y == label, 1], label=label)
    plt.legend()
    plt.show()

Study each waveform independently, then ABP+EEG (which had the best results in the paper), and ABP+ECG+EEG:

In [None]:
run_tsne(samples_balanced, use_abp=True, use_ecg=False, use_eeg=False)
run_tsne(samples_balanced, use_abp=False, use_ecg=True, use_eeg=False)
run_tsne(samples_balanced, use_abp=False, use_ecg=False, use_eeg=True)
run_tsne(samples_balanced, use_abp=True, use_ecg=False, use_eeg=True)
run_tsne(samples_balanced, use_abp=True, use_ecg=True, use_eeg=True)

Based on the plots above, it appears that ABP alone, ABP+EEG, and ABP+ECG+EEG are somewhat separable, though with outliers, and should be trainable by our model. ECG alone and EEG alone are not readily separable. This agrees with the results from the paper.

In [None]:
# Memory cleanup
del samples_balanced

## Model

The model implementation is based on the CNN architecture described in Jo Y-Y et al. (2022). It is designed to handle 1, 2, or 3 biosignal waveforms simultaneously, allowing for flexible model configurations based on different combinations of physiological data:
 * ABP alone
 * EEG alone
 * ECG alone
 * ABP + EEG
 * ABP + ECG
 * EEG + ECG
 * ABP + EEG + ECG

### Model Architecture

The architecture, as depicted in Figure 2 from the original paper, utilizes a ResNet-based approach tailored for time-series data from different physiological signals. The model architecture is adapted to handle varying input signal frequencies, with specific hyperparameters for each signal type, particularly EEG, due to its distinct characteristics compared to ABP and ECG. A diagram of the model architecture is shown below:

![Architecture of the hypotension risk prediction model using multiple waveforms](https://journals.plos.org/plosone/article/figure/image?download&size=large&id=10.1371/journal.pone.0272055.g002)

Each input signal is processed through a sequence of 12 seven-layer residual blocks, followed by a flattening step and a linear transformation that produces a 32-dimensional feature vector per signal type. These vectors are then concatenated (if multiple signals are used) and passed through two additional linear layers to produce a single output representing the IOH index. A threshold, determined experimentally to minimize the difference between sensitivity and specificity, is applied to this index to perform binary classification for predicting IOH events.
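
One way to pick such a threshold is to scan candidate cut-offs and keep the one where sensitivity and specificity are closest. The sketch below does this on made-up scores and labels; the function name and the choice of the observed scores as candidates are assumptions, not the paper's procedure:

```python
def balanced_threshold(scores, labels):
    """Return the threshold minimizing |sensitivity - specificity|.

    Assumes labels are 0/1 and that both classes are present.
    """
    positives = sum(labels)
    negatives = len(labels) - positives
    best_threshold, best_gap = None, float('inf')
    for threshold in sorted(set(scores)):
        # Predict positive when the score is at or above the threshold
        tp = sum(1 for s, y in zip(scores, labels) if y == 1 and s >= threshold)
        tn = sum(1 for s, y in zip(scores, labels) if y == 0 and s < threshold)
        sensitivity = tp / positives
        specificity = tn / negatives
        gap = abs(sensitivity - specificity)
        if gap < best_gap:
            best_threshold, best_gap = threshold, gap
    return best_threshold, best_gap


example_scores = [0.1, 0.2, 0.3, 0.4, 0.6, 0.7, 0.8, 0.9]
example_labels = [0, 0, 0, 1, 0, 1, 1, 1]
threshold, gap = balanced_threshold(example_scores, example_labels)
print(threshold, gap)
```

In this toy example a cut-off of 0.6 gives a sensitivity and specificity of 0.75 each, so the gap is zero.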

The hyperparameters for the residual blocks are specified in Supplemental Table 1 of the original paper and vary by signal type.

A forward pass through the model traverses 85 layers per signal path before concatenation (12 residual blocks of 7 layers each, plus one linear layer), followed by two more linear layers and finally a sigmoid activation layer to produce the prediction measure.

### Residual Block Definition

Each residual block consists of the following seven layers:
 
 * Batch normalization
 * ReLU
 * Dropout (0.5)
 * 1D convolution
 * Batch normalization
 * ReLU
 * 1D convolution

Skip connections are included to aid in gradient flow during training, with optional 1D convolution in the skip connection to align dimensions.

#### Residual Block Hyperparameters

The hyperparameters are detailed in Supplemental Table 1 of the original paper. A screenshot of these hyperparameters is provided for reference below:

![Supplemental Table 1 from original paper](<https://github.com/abarrie2/cs598-dlh-project/blob/main/img/table_1_hyperparameters.png?raw=true>)

**Note**: Please be aware of a transcription error in the original paper's Supplemental Table 1 for the ECG+ABP configuration in Residual Blocks 11 and 12, where the output size should be 469 * 6 instead of the reported 496 * 6.
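
The corrected value can be checked by following the sequence length through the 12 residual blocks, assuming (as in the implementation below) that every other block halves the length with a kernel-2, stride-2 max pool that pads odd-length inputs by one. The standard pooling output formula, floor((L + 2p - k) / s) + 1, then gives:

```python
def pooled_length(length, kernel=2, stride=2):
    """Output length of a 1D max pool that pads odd-length inputs by one sample."""
    padding = 1 if length % 2 else 0
    return (length + 2 * padding - kernel) // stride + 1


sizes = [30000]  # 60 s of ABP or ECG at 500 Hz
for block in range(12):
    size_down = block % 2 == 0  # blocks 1, 3, 5, ... halve the sequence
    sizes.append(pooled_length(sizes[-1]) if size_down else sizes[-1])

print(sizes)
```

The final length is 469, consistent with an output size of 469 * 6 rather than 496 * 6.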

In [None]:
# Define the residual block which is implemented for each biosignal path
class ResidualBlock(nn.Module):
    def __init__(self, in_features: int, out_features: int, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, size_down: bool = False, ignoreSkipConnection: bool = False) -> None:
        super(ResidualBlock, self).__init__()
        
        self.ignoreSkipConnection = ignoreSkipConnection

        # calculate the appropriate padding required to ensure expected sequence lengths out of each residual block
        padding = int((((stride-1)*in_features)-stride+kernel_size)/2)

        self.size_down = size_down
        self.bn1 = nn.BatchNorm1d(in_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, bias=False)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, bias=False)
        
        self.residualConv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, bias=False)

        # The paper does not state where in the block downsampling takes place;
        # Supplemental Table 1 only indicates which blocks size down.
        if self.size_down:
            pool_padding = (1 if (in_features % 2 > 0) else 0)
            self.downsample = nn.MaxPool1d(kernel_size=2, stride=2, padding = pool_padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x

        out = self.bn1(x)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.conv1(out)

        if self.size_down:
            out = self.downsample(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        if not self.ignoreSkipConnection:
            if out.shape != identity.shape:
                # run the residual through a convolution when necessary
                identity = self.residualConv(identity)

                outlen = np.prod(out.shape)
                idlen = np.prod(identity.shape)
                # downsample when required
                if idlen > outlen:
                    identity = self.downsample(identity)
                # match dimensions
                identity = identity.reshape(out.shape)

            # add the residual
            out += identity

        return out

# Define the parameterizable model
class HypotensionCNN(nn.Module):
    def __init__(self, useAbp: bool = True, useEeg: bool = False, useEcg: bool = False, device: str = "cpu", nResiduals: int = 12, ignoreSkipConnection: bool = False, useSigmoid: bool = True) -> None:
        assert useAbp or useEeg or useEcg, "At least one data track must be used"
        assert nResiduals > 0 and nResiduals <= 12, "Number of residual blocks must be between 1 and 12"
        super(HypotensionCNN, self).__init__()

        self.device = device

        self.useAbp = useAbp
        self.useEeg = useEeg
        self.useEcg = useEcg
        self.nResiduals = nResiduals
        self.useSigmoid = useSigmoid

        # Size of the concatenated output from the residual blocks
        concatSize = 0

        if useAbp:
          self.abpBlocks = []
          self.abpMultipliers = [1, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 6, 6]
          self.abpSizes = [30000, 15000, 15000, 7500, 7500, 3750, 3750, 1875, 1875, 938, 938, 469, 469]
          for i in range(self.nResiduals):
            downsample = i % 2 == 0
            self.abpBlocks.append(ResidualBlock(self.abpSizes[i], self.abpSizes[i+1], self.abpMultipliers[i], self.abpMultipliers[i+1], 15 if i < 6 else 7, 1, downsample, ignoreSkipConnection))
          self.abpResiduals = nn.Sequential(*self.abpBlocks)
          self.abpFc = nn.Linear(self.abpMultipliers[self.nResiduals] * self.abpSizes[self.nResiduals], 32)
          concatSize += 32
        
        if useEcg:
          self.ecgBlocks = []
          self.ecgMultipliers = [1, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 6, 6]
          self.ecgSizes = [30000, 15000, 15000, 7500, 7500, 3750, 3750, 1875, 1875, 938, 938, 469, 469]

          for i in range(self.nResiduals):
            downsample = i % 2 == 0
            self.ecgBlocks.append(ResidualBlock(self.ecgSizes[i], self.ecgSizes[i+1], self.ecgMultipliers[i], self.ecgMultipliers[i+1], 15 if i < 6 else 7, 1, downsample, ignoreSkipConnection))
          self.ecgResiduals = nn.Sequential(*self.ecgBlocks)
          self.ecgFc = nn.Linear(self.ecgMultipliers[self.nResiduals] * self.ecgSizes[self.nResiduals], 32)
          concatSize += 32

        if useEeg:
          self.eegBlocks = []
          self.eegMultipliers = [1, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 6, 6]
          self.eegSizes = [7680, 3840, 3840, 1920, 1920, 960, 960, 480, 480, 240, 240, 120, 120]

          for i in range(self.nResiduals):
            downsample = i % 2 == 0
            self.eegBlocks.append(ResidualBlock(self.eegSizes[i], self.eegSizes[i+1], self.eegMultipliers[i], self.eegMultipliers[i+1], 7 if i < 6 else 3, 1, downsample, ignoreSkipConnection))
          self.eegResiduals = nn.Sequential(*self.eegBlocks)
          self.eegFc = nn.Linear(self.eegMultipliers[self.nResiduals] * self.eegSizes[self.nResiduals], 32)
          concatSize += 32

        # The fullLinear1 layer accepts the outputs of the concatenation of the ResidualBlocks from each biosignal path
        self.fullLinear1 = nn.Linear(concatSize, 16)
        self.fullLinear2 = nn.Linear(16, 1)
        self.sigmoid = nn.Sigmoid()


    def forward(self, abp: torch.Tensor, eeg: torch.Tensor, ecg: torch.Tensor) -> torch.Tensor:
        batchSize = len(abp)

        # conditionally operate ABP, EEG, and ECG networks
        tensors = []
        if self.useAbp:
          self.abpResiduals.to(self.device)
          abp = self.abpResiduals(abp)
          totalLen = np.prod(abp.shape)
          abp = torch.reshape(abp, (batchSize, int(totalLen / batchSize)))
          abp = self.abpFc(abp)
          tensors.append(abp)

        if self.useEeg:
          self.eegResiduals.to(self.device)
          eeg = self.eegResiduals(eeg)
          totalLen = np.prod(eeg.shape)
          eeg = torch.reshape(eeg, (batchSize, int(totalLen / batchSize)))
          eeg = self.eegFc(eeg)
          tensors.append(eeg)
        
        if self.useEcg:
          self.ecgResiduals.to(self.device)
          ecg = self.ecgResiduals(ecg)
          totalLen = np.prod(ecg.shape)
          ecg = torch.reshape(ecg, (batchSize, int(totalLen / batchSize)))
          ecg = self.ecgFc(ecg)
          tensors.append(ecg)

        # concatenate the tensors along dimension 1 if there's more than one, otherwise use the single tensor
        merged = torch.cat(tensors, dim=1) if len(tensors) > 1 else tensors[0]

        totalLen = np.prod(merged.shape)
        merged = torch.reshape(merged, (batchSize, int(totalLen / batchSize)))
        out = self.fullLinear1(merged)
        out = self.fullLinear2(out)
        # Skip the final model sigmoid when using BCEWithLogitsLoss loss function
        if self.useSigmoid:
            out = self.sigmoid(out)

        return out
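
The hard-coded `abpSizes`, `ecgSizes`, and `eegSizes` lists can be cross-checked against the layer arithmetic. The following is a minimal sketch, assuming the standard output-length formulas for `nn.Conv1d` and `nn.MaxPool1d` with dilation 1, and assuming pooling happens only on even-numbered blocks as in the loops above. The helper names (`conv1d_out_len`, `maxpool1d_out_len`, `track_sizes`) are hypothetical and not part of the model code.

```python
def conv1d_out_len(length: int, kernel_size: int, padding: int, stride: int = 1) -> int:
    # Output length of a 1-D convolution with dilation 1
    return (length + 2 * padding - kernel_size) // stride + 1


def maxpool1d_out_len(length: int) -> int:
    # MaxPool1d(kernel_size=2, stride=2), padded by 1 only when the input length is odd
    pool_padding = 1 if length % 2 > 0 else 0
    return (length + 2 * pool_padding - 2) // 2 + 1


def track_sizes(input_len: int, n_blocks: int = 12) -> list:
    # Sequence length entering the first block, then leaving each block in turn
    sizes = [input_len]
    for i in range(n_blocks):
        length = sizes[-1]
        if i % 2 == 0:  # even-numbered blocks downsample
            length = maxpool1d_out_len(length)
        sizes.append(length)
    return sizes


abp_sizes = track_sizes(30000)  # 60 s at 500 Hz
eeg_sizes = track_sizes(7680)   # 60 s at 128 Hz
print(abp_sizes)
print(eeg_sizes)

# With stride 1 and an odd kernel, padding of (k - 1) // 2 keeps the length unchanged
for k in (15, 7, 3):
    assert conv1d_out_len(1875, k, (k - 1) // 2) == 1875
```

The odd length 1875 is the only place where the pool padding matters: it rounds the halved length up to 938 instead of down to 937. With all 12 blocks enabled, the ABP path ends with 6 channels of length 469, so the flattened input to `abpFc` has 2814 features.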

### Training

The training loop is parameterized so that the loss function, optimizer settings, batch size, epoch limit, and early-stopping patience can all be configured. The original paper uses binary cross-entropy as the loss function and Adam as the optimizer, with a learning rate of 0.0001, training for up to 100 epochs, and early stopping if the loss does not improve for five consecutive epochs. Our models were run with the same parameters but with longer patience values, to account for the noisier and smaller dataset that we had access to.
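
To illustrate why the patience value matters on a noisy loss curve, here is a minimal, framework-free sketch of patience-based early stopping. The function name `early_stopping_epochs` and the loss values are made up for illustration and are not results from our runs.

```python
def early_stopping_epochs(val_losses, patience):
    """Return (best_epoch, last_epoch_run) for a sequence of per-epoch validation losses."""
    best_loss = float("inf")
    best_epoch = 0
    no_improve_epochs = 0
    last_epoch = -1
    for epoch, val_loss in enumerate(val_losses):
        last_epoch = epoch
        if val_loss < best_loss:
            # New best: remember it and reset the counter
            best_loss = val_loss
            best_epoch = epoch
            no_improve_epochs = 0
        else:
            no_improve_epochs += 1
        # Stop once `patience` consecutive epochs have failed to improve
        if no_improve_epochs >= patience:
            break
    return best_epoch, last_epoch


# Illustrative (made-up) losses with a late improvement at epoch 7
losses = [0.70, 0.65, 0.60, 0.61, 0.62, 0.63, 0.64, 0.59]
print(early_stopping_epochs(losses, patience=3))
print(early_stopping_epochs(losses, patience=5))
```

With a patience of 3 the run stops at epoch 5 and keeps the epoch 2 model, while a patience of 5 survives the plateau and reaches the lower loss at epoch 7.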

Define a function to train the model for one epoch. Collect the losses so the mean can be reported.

In [None]:
def train_model_one_iter(model, device, loss_func, optimizer, train_loader):
    model.train()
    train_losses = []
    
    for abp, ecg, eeg, label in tqdm(train_loader):
        batch = len(abp)
        abp = abp.reshape(batch, 1, -1).type(torch.FloatTensor).to(device)
        ecg = ecg.reshape(batch, 1, -1).type(torch.FloatTensor).to(device)
        eeg = eeg.reshape(batch, 1, -1).type(torch.FloatTensor).to(device)
        label = label.type(torch.float).reshape(batch, 1).to(device)

        optimizer.zero_grad()
        mdl = model(abp, eeg, ecg)
        loss = loss_func(torch.nan_to_num(mdl), label)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.cpu().data.numpy())
    return np.mean(train_losses)

Evaluate the model using the provided loss function. This is typically called on the validation dataset at each epoch:

In [None]:
def evaluate_model(model, loss_func, val_loader):
    # NOTE: relies on the notebook-level `device` variable
    model.eval()
    val_losses = []
    # Gradients are not needed when computing the validation loss
    with torch.no_grad():
        for abp, ecg, eeg, label in tqdm(val_loader):
            batch = len(abp)

            abp = abp.reshape(batch, 1, -1).type(torch.FloatTensor).to(device)
            ecg = ecg.reshape(batch, 1, -1).type(torch.FloatTensor).to(device)
            eeg = eeg.reshape(batch, 1, -1).type(torch.FloatTensor).to(device)
            label = label.type(torch.float).reshape(batch, 1).to(device)

            mdl = model(abp, eeg, ecg)
            loss = loss_func(torch.nan_to_num(mdl), label)
            val_losses.append(loss.cpu().data.numpy())
    return np.mean(val_losses)

Define a function to plot the training and validation losses from the entire training run and indicate at which epoch the validation loss was minimized. This is typically `patience` epochs before the end of training:

In [None]:
def plot_losses(train_losses, val_losses, best_epoch, experimentName):
    print()
    print(f'Plot Validation and Loss Values from Training')
    print(f'  Epoch with best Validation Loss:  {best_epoch:3}, {val_losses[best_epoch]:.4}')

    # Create x-axis values for epochs
    epochs = range(0, len(train_losses))

    plt.figure(figsize=(16, 9))

    # Plot the training and validation losses
    plt.plot(epochs, train_losses, 'b', label='Training Loss')
    plt.plot(epochs, val_losses, 'r', label='Validation Loss')

    # Add a vertical bar at the best_epoch
    plt.axvline(x=best_epoch, color='g', linestyle='--', label='Best Epoch')

    # Shade everything to the right of the best_epoch a light red
    plt.axvspan(best_epoch, max(epochs), facecolor='r', alpha=0.1)

    # Add labels and title
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title(experimentName)

    # Add legend
    plt.legend(loc='upper right')

    # Save plot to disk
    plt.savefig(os.path.join(VITAL_RUNS, f'{experimentName}_losses.png'))

    # Show the plot
    plt.show()


Define a function to calculate the complete performance metric profile of a model. As in the original paper, the decision threshold is the one that minimizes the absolute difference between sensitivity and specificity, i.e. argmin |sensitivity - specificity|:

In [None]:
def eval_model(model, device, dataloader, loss_func, print_detailed: bool = False):
    model.eval()
    model = model.to(device)
    total_loss = 0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for abp, ecg, eeg, label in tqdm(dataloader):
            batch = len(abp)
    
            abp = torch.nan_to_num(abp.reshape(batch, 1, -1)).type(torch.FloatTensor).to(device)
            ecg = torch.nan_to_num(ecg.reshape(batch, 1, -1)).type(torch.FloatTensor).to(device)
            eeg = torch.nan_to_num(eeg.reshape(batch, 1, -1)).type(torch.FloatTensor).to(device)
            label = label.type(torch.float).reshape(batch, 1).to(device)
   
            pred = model(abp, eeg, ecg)
            loss = loss_func(pred, label)
            total_loss += loss.item()

            all_predictions.append(pred.detach().cpu().numpy())
            all_labels.append(label.detach().cpu().numpy())

    # Flatten the lists
    all_predictions = np.concatenate(all_predictions).flatten()
    all_labels = np.concatenate(all_labels).flatten()

    # Calculate AUROC and AUPRC
    # y_true, y_pred
    auroc = roc_auc_score(all_labels, all_predictions)
    precision, recall, _ = precision_recall_curve(all_labels, all_predictions)
    auprc = auc(recall, precision)

    # Determine the optimal threshold, which is argmin(abs(sensitivity - specificity)) per the paper
    thresholds = np.linspace(0, 1, 101) # 0 to 1 in 0.01 steps
    min_diff = float('inf')
    optimal_sensitivity = None
    optimal_specificity = None
    optimal_threshold = None

    for threshold in thresholds:
        all_predictions_binary = (all_predictions > threshold).astype(int)

        tn, fp, fn, tp = confusion_matrix(all_labels, all_predictions_binary).ravel()
        sensitivity = tp / (tp + fn)
        specificity = tn / (tn + fp)
        diff = abs(sensitivity - specificity)

        if diff < min_diff:
            min_diff = diff
            optimal_threshold = threshold
            optimal_sensitivity = sensitivity
            optimal_specificity = specificity

    avg_loss = total_loss / len(dataloader)
    
    # accuracy
    predictions_binary = (all_predictions > optimal_threshold).astype(int)
    accuracy = np.mean(predictions_binary == all_labels)

    if print_detailed:
        print(f"Predictions: {all_predictions}")
        print(f"Labels: {all_labels}")
    print(f"Loss: {avg_loss}")
    print(f"AUROC: {auroc}")
    print(f"AUPRC: {auprc}")
    print(f"Sensitivity: {optimal_sensitivity}")
    print(f"Specificity: {optimal_specificity}")
    print(f"Threshold: {optimal_threshold}")
    print(f"Accuracy:  {accuracy}")

    return all_predictions, all_labels, avg_loss, auroc, auprc, \
        optimal_sensitivity, optimal_specificity, optimal_threshold, accuracy
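
The threshold search inside `eval_model` can be seen in isolation with a small example. This is a minimal sketch in plain Python, assuming both classes are present in the labels. The function name `balanced_threshold` and the labels and scores are made up for illustration.

```python
def balanced_threshold(labels, scores, thresholds):
    """Return (threshold, sensitivity, specificity) minimising |sensitivity - specificity|."""
    best = None
    for threshold in thresholds:
        predicted = [1 if s > threshold else 0 for s in scores]
        pairs = list(zip(labels, predicted))
        tp = sum(1 for y, p in pairs if y == 1 and p == 1)
        fn = sum(1 for y, p in pairs if y == 1 and p == 0)
        tn = sum(1 for y, p in pairs if y == 0 and p == 0)
        fp = sum(1 for y, p in pairs if y == 0 and p == 1)
        sensitivity = tp / (tp + fn)
        specificity = tn / (tn + fp)
        diff = abs(sensitivity - specificity)
        # Strict comparison keeps the first threshold that reaches the minimum
        if best is None or diff < best[0]:
            best = (diff, threshold, sensitivity, specificity)
    return best[1:]


# Made-up labels and scores with one overlapping pair (0.6 negative, 0.4 positive)
labels = [0, 0, 0, 0, 1, 1, 1, 1]
scores = [0.1, 0.2, 0.3, 0.6, 0.4, 0.7, 0.8, 0.9]
thresholds = [i / 100 for i in range(101)]  # 0 to 1 in 0.01 steps

threshold, sensitivity, specificity = balanced_threshold(labels, scores, thresholds)
print(threshold, sensitivity, specificity)
```

In this example the first threshold at which the two rates meet is 0.4, where three of four positives and three of four negatives are classified correctly.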

Define a function to calculate and print the AUROC and AUPRC values for each epoch of a training run:

In [None]:
def print_all_evals(model, models, device, val_loader, test_loader, loss_func, print_detailed: bool = False):
    print()
    print(f'Generate AUROC/AUPRC for Each Intermediate Model')
    print()
    val_aurocs = []
    val_auprcs = []
    val_accs   = []

    test_aurocs = []
    test_auprcs = []
    test_accs   = []

    for mod in models:
        model.load_state_dict(torch.load(mod))
        #model.train(False)
        model.eval()
        print(f'Intermediate Model:')
        print(f'  {mod}')
    
        # validation loop
        print("AUROC/AUPRC on Validation Data")
        all_predictions, all_labels, avg_loss, valid_auroc, valid_auprc, \
        optimal_sensitivity, optimal_specificity, optimal_threshold, valid_accuracy = \
            eval_model(model, device, val_loader, loss_func, print_detailed)

        val_aurocs.append(valid_auroc)
        val_auprcs.append(valid_auprc)
        val_accs.append(valid_accuracy)
        print()
    
        # test loop
        print("AUROC/AUPRC on Test Data")
        all_predictions, all_labels, avg_loss, test_auroc, test_auprc, \
        optimal_sensitivity, optimal_specificity, optimal_threshold, test_accuracy = \
            eval_model(model, device, test_loader, loss_func, print_detailed)

        test_aurocs.append(test_auroc)
        test_auprcs.append(test_auprc)
        test_accs.append(test_accuracy)
        print()
    
    return val_aurocs, val_auprcs, val_accs, test_aurocs, test_auprcs, test_accs

Define a function to plot the AUROC, AUPRC and accuracy at each epoch and print the parameters for the best epoch on validation loss, AUROC and accuracy:

In [None]:
def plot_auroc_auprc(val_losses, val_aurocs, val_auprcs, val_accs, 
                                      test_aurocs, test_auprcs, test_accs, all_models, best_epoch, experimentName):
    print()
    print(f'Plot AUROC/AUPRC for Each Intermediate Model')
    
    # Create x-axis values for epochs
    epochs = range(0, len(val_aurocs))

    # Find model with highest AUROC
    np_test_aurocs = np.array(test_aurocs)
    test_auroc_idx = np.argmax(np_test_aurocs)
    test_accs_idx  = np.argmax(test_accs)

    print(f'  Epoch with best Validation Loss:     {best_epoch:3}, {val_losses[best_epoch]:.4}')
    print(f'  Epoch with best model Test AUROC:    {test_auroc_idx:3}, {np_test_aurocs[test_auroc_idx]:.4}')
    print(f'  Epoch with best model Test Accuracy: {test_accs_idx:3}, {test_accs[test_accs_idx]:.4}')
    print()

    plt.figure(figsize=(16, 9))

    # Plots
    plt.plot(epochs, val_aurocs, 'C0', label='AUROC - Validation')
    plt.plot(epochs, test_aurocs, 'C1', label='AUROC - Test')

    plt.plot(epochs, val_auprcs, 'C2', label='AUPRC - Validation')
    plt.plot(epochs, test_auprcs, 'C3', label='AUPRC - Test')
    
    plt.plot(epochs, val_accs, 'C4', label='Accuracy - Validation')
    plt.plot(epochs, test_accs, 'C5', label='Accuracy - Test')

    # Add vertical bars
    plt.axvline(x=best_epoch, color='g', linestyle='--', label='Best Epoch - Validation Loss')
    plt.axvline(x=test_auroc_idx, color='maroon', linestyle='--', label='Best Epoch - Test AUROC')
    plt.axvline(x=test_accs_idx, color='violet', linestyle='--', label='Best Epoch - Test Accuracy')

    # Shade everything to the right of the best_model a light red
    plt.axvspan(test_auroc_idx, max(epochs), facecolor='r', alpha=0.1)

    # Add labels and title
    plt.xlabel('Epochs')
    plt.ylabel('AUROC / AUPRC')
    plt.title('Validation and Test AUROC and AUPRC by Model Iteration Across Training')

    # Add legend
    plt.legend(loc='right')

    # Save plot to disk
    plt.savefig(os.path.join(VITAL_RUNS, f'{experimentName}_all_stats.png'))
    
    # Show the plot
    plt.show()

    return np_test_aurocs, test_auroc_idx

Define a function to make predictions on a given case:

In [None]:
# applies the model to a given real case to generate predictions
def predictionsForModel(case_id_to_check, my_model, my_model_state, device, ready_model=None):
    (abp, ecg, eeg, event) = get_track_data(case_id_to_check)
    
    opstart = cases.loc[case_id_to_check]['opstart'].item()
    opend = cases.loc[case_id_to_check]['opend'].item()

    abp = abp[opstart*500:opend*500]
    ecg = ecg[opstart*500:opend*500]
    eeg = eeg[opstart*128:opend*128]
    
    # number of one minute segments in each track
    splits_abp = abp.shape[0] // (60 * 500)
    splits_ecg = ecg.shape[0] // (60 * 500)
    splits_eeg = eeg.shape[0] // (60 * 128)
    
    # predict as long as each track has data in the prediction window
    splits = np.min([splits_abp, splits_ecg, splits_eeg])
    
    preds = []
    
    the_model = None
    
    if ready_model is None:
        my_model.load_state_dict(torch.load(my_model_state))
        my_model.eval()
        my_model = my_model.to(device)
        the_model = my_model
    else:
        ready_model.eval()
        ready_model = ready_model.to(device)
        the_model = ready_model
    
    for i in range(splits):
        t_abp = abp[i*60*500:(i + 1)*60*500]
        t_ecg = ecg[i*60*500:(i + 1)*60*500]
        t_eeg = eeg[i*60*128:(i + 1)*60*128]
    
        if len(t_abp) < 30000:
            t_abp = np.resize(t_abp, (30000))
            
        if len(t_ecg) < 30000:
            t_ecg = np.resize(t_ecg, (30000))
            
        if len(t_eeg) < 7680:
            t_eeg = np.resize(t_eeg, (7680))
            
        t_abp = torch.from_numpy(t_abp)
        t_ecg = torch.from_numpy(t_ecg)
        t_eeg = torch.from_numpy(t_eeg)
        
        t_abp = torch.nan_to_num(t_abp.reshape(1, 1, -1)).type(torch.FloatTensor).to(device)
        t_ecg = torch.nan_to_num(t_ecg.reshape(1, 1, -1)).type(torch.FloatTensor).to(device)
        t_eeg = torch.nan_to_num(t_eeg.reshape(1, 1, -1)).type(torch.FloatTensor).to(device)

        pred = the_model(t_abp, t_eeg, t_ecg)
        preds.append(pred.detach().cpu().numpy())
    
    return np.concatenate(preds).flatten()
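
Because ABP and ECG are sampled at 500 Hz and EEG at 128 Hz, the same one-minute window maps to different sample ranges in each track, and the shortest track limits the number of predictions. The sketch below shows only the index arithmetic. The helper `window_bounds` and the track durations are hypothetical.

```python
ABP_HZ, ECG_HZ, EEG_HZ = 500, 500, 128  # sampling rates used in the slicing above
WINDOW_SECONDS = 60


def window_bounds(n_abp: int, n_ecg: int, n_eeg: int):
    """Yield per-window (start, end) sample indices for each track."""
    # Only predict while every track has a complete one-minute window
    n_windows = min(
        n_abp // (WINDOW_SECONDS * ABP_HZ),
        n_ecg // (WINDOW_SECONDS * ECG_HZ),
        n_eeg // (WINDOW_SECONDS * EEG_HZ),
    )
    for i in range(n_windows):
        yield {
            "abp": (i * WINDOW_SECONDS * ABP_HZ, (i + 1) * WINDOW_SECONDS * ABP_HZ),
            "ecg": (i * WINDOW_SECONDS * ECG_HZ, (i + 1) * WINDOW_SECONDS * ECG_HZ),
            "eeg": (i * WINDOW_SECONDS * EEG_HZ, (i + 1) * WINDOW_SECONDS * EEG_HZ),
        }


# Hypothetical case: 150 s of ABP and ECG, but only 100 s of EEG
windows = list(window_bounds(150 * ABP_HZ, 150 * ECG_HZ, 100 * EEG_HZ))
print(len(windows))
print(windows[0])
```

Here ABP and ECG each hold two full minutes, but EEG holds only one, so a single window is produced.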

Define a function to plot the mean ABP and predictions for a case:

In [None]:
def printModelPrediction(case_id_to_check, preds, experimentName):  
    (abp, ecg, eeg, event) = get_track_data(case_id_to_check)
    
    opstart = cases.loc[case_id_to_check]['opstart'].item()
    opend = cases.loc[case_id_to_check]['opend'].item()
    minutes = (opend - opstart) / 60
    
    plt.figure(figsize=(24, 8))
    plt.margins(0)
    plt.title(f'ABP - Mean Arterial Pressure - Case: {case_id_to_check} - Operating Time: {minutes} minutes')
    plt.axhline(y = 65, color = 'maroon', linestyle = '--')
    
    opstart = opstart * 500
    opend = opend * 500
    
    minute_step = 5
    
    abp_mov_avg = moving_average(abp[opstart:(opend + 60*500)])
    myx = np.arange(opstart, opstart + len(abp_mov_avg), 1)
    plt.plot(myx, abp_mov_avg, 'purple')
    x_ticks = np.arange(opstart, opend, step=minute_step*30000)
    x_labels = [str(i*minute_step) for i in range(len(x_ticks))]
    plt.xticks(x_ticks, labels=x_labels)
    if experimentName is not None:
        plt.savefig(os.path.join(VITAL_RUNS, f'{experimentName}_{case_id_to_check:04d}_surgery_map.png'))
    plt.show()
    
    plt.figure(figsize=(24, 8))
    plt.margins(0)
    plt.title(f'Model Predictions for One Minute Intervals Using {PREDICTION_WINDOW} Minute Prediction Window')
    plt.plot(preds)
    x_ticks = np.arange(0, len(preds), step=minute_step)
    x_labels = [str(i*minute_step) for i in range(len(x_ticks))]
    plt.xticks(x_ticks, labels=x_labels)
    if experimentName is not None:
        plt.savefig(os.path.join(VITAL_RUNS, f'{experimentName}_{case_id_to_check:04d}_surgery_predictions.png'))
    plt.show()
    
    return preds

Define a function to run an experiment, which includes training a model and evaluating it.

In [None]:
def run_experiment(
    experimentNamePrefix: str = None,
    useAbp: bool = True, 
    useEeg: bool = False, 
    useEcg: bool = False, 
    nResiduals: int = 12, 
    skip_connection: bool = False, 
    batch_size: int = 64, 
    learning_rate: float = 1e-4, 
    weight_decay: float = 0.0, 
    pos_weight: float = None,
    max_epochs: int = 100, 
    patience: int = 25, 
    device: str = "cpu"
):
    reset_random_state()

    time_start = timer()

    experimentName = ""

    experimentOptions = [experimentNamePrefix, 'ABP', 'EEG', 'ECG', 'SKIPCONNECTION']
    experimentValues = [experimentNamePrefix is not None, useAbp, useEeg, useEcg, skip_connection]
    experimentFlags = [name for name, value in zip(experimentOptions, experimentValues) if value]
    if experimentFlags:
        experimentName = "_".join(experimentFlags)

    experimentName = f"{experimentName}_{nResiduals}_RESIDUAL_BLOCKS_{batch_size}_BATCH_SIZE_{learning_rate:.0e}_LEARNING_RATE"

    if weight_decay is not None and weight_decay != 0.0:
        experimentName = f"{experimentName}_{weight_decay:.0e}_WEIGHT_DECAY"

    predictionWindow = 'ALL' if PREDICTION_WINDOW == 'ALL' else f'{PREDICTION_WINDOW:03}'
    experimentName = f"{experimentName}_{predictionWindow}_MINS"

    maxCases = '_ALL' if MAX_CASES is None else f'{MAX_CASES:04}'
    experimentName = f"{experimentName}_{maxCases}_MAX_CASES"
    
    # Add a unique 8-character suffix (first 8 hex digits of a uuid4) to the experiment name
    experimentName = f"{experimentName}_{uuid.uuid4().hex[:8]}"

    # Fork stdout to file and console
    with ForkedStdout(os.path.join(VITAL_RUNS, f'{experimentName}.log')):
        print(f"Experiment Setup")
        print(f'  name:              {experimentName}')
        print(f'  prediction_window: {predictionWindow}')
        print(f'  max_cases:         {maxCases}')
        print(f'  use_abp:           {useAbp}')
        print(f'  use_eeg:           {useEeg}')
        print(f'  use_ecg:           {useEcg}')
        print(f'  n_residuals:       {nResiduals}')
        print(f'  skip_connection:   {skip_connection}')
        print(f'  batch_size:        {batch_size}')
        print(f'  learning_rate:     {learning_rate}')
        print(f'  weight_decay:      {weight_decay}')
        if pos_weight is not None:
            print(f'  pos_weight:        {pos_weight}')
        print(f'  max_epochs:        {max_epochs}')
        print(f'  patience:          {patience}')
        print(f'  device:            {device}')
        print()

        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        # Disable final sigmoid activation for BCEWithLogitsLoss
        # NOTE: skip_connection is passed positionally as HypotensionCNN's ignoreSkipConnection argument
        model = HypotensionCNN(useAbp, useEeg, useEcg, device, nResiduals, skip_connection, useSigmoid=(pos_weight is None))
        model = model.to(device)
    
        if pos_weight is not None:
            # Apply weights to positive class
            loss_func = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight]).to(device))
        else:
            loss_func = nn.BCELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    
        print(f'Model Architecture')
        print(model)
        print()

        print(f'Training Loop')
        # Training loop
        best_epoch = 0
        train_losses = []
        val_losses = []
        best_loss = float('inf')
        no_improve_epochs = 0
        model_path = os.path.join(VITAL_MODELS, f"{experimentName}.model")

        all_models = []

        for i in range(max_epochs):
            # Train the model and get the training loss
            train_loss = train_model_one_iter(model, device, loss_func, optimizer, train_loader)
            train_losses.append(train_loss)
            # Calculate the validation loss
            val_loss = evaluate_model(model, loss_func, val_loader)
            val_losses.append(val_loss)
            print(f"[{datetime.now()}] Completed epoch {i} with training loss {train_loss:.8f}, validation loss {val_loss:.8f}")

            # Save all intermediary models.
            tmp_model_path = os.path.join(VITAL_MODELS, f"{experimentName}_{i:04d}.model")
            torch.save(model.state_dict(), tmp_model_path)
            all_models.append(tmp_model_path)
  
            # Check if validation loss has improved
            if val_loss < best_loss:
                best_epoch = i
                best_loss = val_loss
                no_improve_epochs = 0
                torch.save(model.state_dict(), model_path)
                print(f"Validation loss improved to {val_loss:.8f}. Model saved.")
            else:
                no_improve_epochs += 1
                print(f"No improvement in validation loss. {no_improve_epochs} epochs without improvement.")

            # exit early if no improvement in loss over last 'patience' epochs
            if no_improve_epochs >= patience:
                print("Early stopping due to no improvement in validation loss.")
                break

        # Load best model from disk
        #print()
        #if os.path.exists(model_path):
        #    model.load_state_dict(torch.load(model_path))
        #    print(f"Loaded best model from disk from epoch {best_epoch}.")
        #else:
        #    print("No saved model found for f{experimentName}.")

        #model.train(False)

        # Plot the training and validation losses across all training epochs.
        plot_losses(train_losses, val_losses, best_epoch, experimentName)

        # Generate AUROC/AUPRC for each intermediate model generated across training epochs.
        val_aurocs, val_auprcs, val_accs, test_aurocs, test_auprcs, test_accs = \
            print_all_evals(model, all_models, device, val_loader, test_loader, loss_func, print_detailed=False)

        # Find model with highest AUROC. Plot AUROC/AUPRC across all epochs.
        np_test_aurocs, test_auroc_idx = plot_auroc_auprc(val_losses, val_aurocs, val_auprcs, val_accs, \
                                        test_aurocs, test_auprcs, test_accs, all_models, best_epoch, experimentName)

        ## AUROC / AUPRC - Model with Best Validation Loss
        best_model_val_loss = all_models[best_epoch]
    
        print(f'AUROC/AUPRC Plots - Best Model Based on Validation Loss')
        print(f'  Epoch with best Validation Loss:  {best_epoch:3}, {val_losses[best_epoch]:.4}')
        print(f'  Best Model Based on Validation Loss:')
        print(f'    {best_model_val_loss}')
        print()
        print(f'Generate Stats Based on Test Data')
        model.load_state_dict(torch.load(best_model_val_loss))
        #model.train(False)
        model.eval()
    
        best_model_val_test_predictions, best_model_val_test_labels, test_loss, \
            best_model_val_test_auroc, best_model_val_test_auprc, test_sensitivity, test_specificity, \
            best_model_val_test_threshold, best_model_val_accuracy = \
                eval_model(model, device, test_loader, loss_func, print_detailed=False)

        # y_test, y_pred
        display = RocCurveDisplay.from_predictions(
            best_model_val_test_labels,
            best_model_val_test_predictions,
            plot_chance_level=True
        )
        # Save plot to disk and show
        plt.savefig(os.path.join(VITAL_RUNS, f'{experimentName}_val_auroc.png'))
        plt.show()

        print(f'best_model_val_test_auroc: {best_model_val_test_auroc}')

        # Save best model in its entirety
        torch.save(model, os.path.join(VITAL_MODELS, f'{experimentName}_full.model'))

        best_model_val_test_predictions_binary = \
        (best_model_val_test_predictions > best_model_val_test_threshold).astype(int)

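        # y_test, y_score: pass the continuous model outputs so the curve spans all thresholds
        # and is consistent with the AUPRC printed below
        display = PrecisionRecallDisplay.from_predictions(
            best_model_val_test_labels, 
            best_model_val_test_predictions,
            plot_chance_level=True
        )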
        # Save plot to disk and show
        plt.savefig(os.path.join(VITAL_RUNS, f'{experimentName}_val_auprc.png'))
        plt.show()

        print(f'best_model_val_test_auprc: {best_model_val_test_auprc}')
        print()

        ## AUROC / AUPRC - Model with Best AUROC
        # Find model with highest AUROC
        best_model_auroc = all_models[test_auroc_idx]

        print(f'AUROC/AUPRC Plots - Best Model Based on Model AUROC')
        print(f'  Epoch with best model Test AUROC: {test_auroc_idx:3}, {np_test_aurocs[test_auroc_idx]:.4}')
        print(f'  Best Model Based on Model AUROC:')
        print(f'    {best_model_auroc}')
        print()
        print(f'Generate Stats Based on Test Data')
        model.load_state_dict(torch.load(best_model_auroc))
        #model.train(False)
        model.eval()
    
        best_model_auroc_test_predictions, best_model_auroc_test_labels, test_loss, \
            best_model_auroc_test_auroc, best_model_auroc_test_auprc, test_sensitivity, test_specificity, \
            best_model_auroc_test_threshold, best_model_auroc_accuracy = \
                eval_model(model, device, test_loader, loss_func, print_detailed=False)

        # y_test, y_pred
        display = RocCurveDisplay.from_predictions(
            best_model_auroc_test_labels,
            best_model_auroc_test_predictions,
            plot_chance_level=True
        )
        # Save plot to disk and show
        plt.savefig(os.path.join(VITAL_RUNS, f'{experimentName}_auroc_auroc.png'))
        plt.show()

        print(f'best_model_auroc_test_auroc: {best_model_auroc_test_auroc}')

        best_model_auroc_test_predictions_binary = \
            (best_model_auroc_test_predictions > best_model_auroc_test_threshold).astype(int)

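        # y_test, y_score: pass the continuous model outputs so the curve spans all thresholds
        # and is consistent with the AUPRC printed below
        display = PrecisionRecallDisplay.from_predictions(
            best_model_auroc_test_labels, 
            best_model_auroc_test_predictions,
            plot_chance_level=True
        )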
        # Save plot to disk and show
        plt.savefig(os.path.join(VITAL_RUNS, f'{experimentName}_auroc_auprc.png'))
        plt.show()

        print(f"best_model_auroc_test_auprc: {best_model_auroc_test_auprc}")
        print()
        
        time_delta = np.round(timer() - time_start, 3)
        print(f'Total Processing Time: {time_delta:.4f} sec')
        
    return (model, best_model_val_loss, best_model_auroc, experimentName)
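
`run_experiment` drops the model's final sigmoid whenever `pos_weight` is set, because `nn.BCEWithLogitsLoss` expects raw logits. The sketch below checks the underlying per-sample arithmetic in plain Python. It is an illustration of the textbook formulas under that assumption, not PyTorch's implementation, which uses a numerically more stable formulation. All function names here are hypothetical.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def bce(p: float, y: float) -> float:
    # Binary cross-entropy on a probability for a single sample
    return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))


def bce_with_logits(z: float, y: float, pos_weight: float = 1.0) -> float:
    # Binary cross-entropy on a raw logit; pos_weight scales only the positive-class term
    p = sigmoid(z)
    return -(pos_weight * y * math.log(p) + (1.0 - y) * math.log(1.0 - p))


logit = 0.8
for label in (0.0, 1.0):
    # Sigmoid followed by BCE matches the unweighted logits formulation
    assert math.isclose(bce(sigmoid(logit), label), bce_with_logits(logit, label))

unweighted = bce_with_logits(logit, 1.0)
weighted = bce_with_logits(logit, 1.0, pos_weight=3.0)
print(unweighted, weighted)
```

A positive sample's loss scales linearly with `pos_weight`, while a negative sample's loss is unaffected, which is how the weighting compensates for a scarce positive class.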

## Experiments

In [None]:
# When false, run only the first experiment below and then stop
SWEEP_ALL = True

### Data tracks

Run experiments across the biosignal data track combinations:
- ABP
- ECG
- EEG
- ABP+ECG
- ABP+EEG
- ECG+EEG
- ABP+ECG+EEG

The first experiment acts as a baseline.
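
The seven combinations listed above are the non-empty subsets of the three tracks. As a cross-check, they can be enumerated programmatically. This is a minimal sketch using `itertools.combinations`; the `TRACKS` and `track_combinations` names are illustrative and not used elsewhere in the notebook.

```python
from itertools import combinations

TRACKS = ("ABP", "ECG", "EEG")

# All non-empty subsets, ordered by size and then by track order
track_combinations = [
    combo
    for size in range(1, len(TRACKS) + 1)
    for combo in combinations(TRACKS, size)
]

for combo in track_combinations:
    # Map each subset onto the useAbp / useEcg / useEeg style of boolean flags
    flags = {f"use{name.capitalize()}": name in combo for name in TRACKS}
    print("+".join(combo), flags)
```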

In [None]:
ENABLE_EXPERIMENT = True
DISPLAY_MODEL_PREDICTION=True
DISPLAY_MODEL_PREDICTION_FIRST_ONLY=True

MAX_EPOCHS=200
PATIENCE=20

data_tracks = [
    # useAbp, useEeg, useEcg, experiment enable
    [True, False, False, True], # ABP only
    [False, False, True, SWEEP_ALL], # ECG only
    [False, True, False, SWEEP_ALL], # EEG only
    [True, False, True, SWEEP_ALL], # ABP + ECG
    [True, True, False, SWEEP_ALL], # ABP + EEG
    [False, True, True, SWEEP_ALL], # ECG + EEG
    [True, True, True, SWEEP_ALL] # ABP + ECG + EEG
]

if ENABLE_EXPERIMENT:
    for (useAbp, useEeg, useEcg, enable) in data_tracks:
        if enable:
            (model, best_model_val_loss, best_model_auroc, experimentName) = run_experiment(
                experimentNamePrefix=None, 
                useAbp=useAbp, 
                useEeg=useEeg, 
                useEcg=useEcg,
                nResiduals=12, 
                skip_connection=False,
                batch_size=128,
                learning_rate=1e-4,
                weight_decay=1e-1,
                pos_weight=None,
                max_epochs=MAX_EPOCHS,
                patience=PATIENCE,
                device=device
            )

            if DISPLAY_MODEL_PREDICTION:
                for case_id_to_check in my_cases_of_interest_idx:
                    preds = predictionsForModel(case_id_to_check, model, best_model_val_loss, device)
                    printModelPrediction(case_id_to_check, preds, experimentName)

                    if DISPLAY_MODEL_PREDICTION_FIRST_ONLY:
                        break

### Hyperparameter search

#### Batch size

Holding all other parameters fixed, sweep the batch sizes from 16 to 256:

In [None]:
ENABLE_EXPERIMENT = False
DISPLAY_MODEL_PREDICTION=True
DISPLAY_MODEL_PREDICTION_FIRST_ONLY=True

batch_sizes = [
    16, 32, 64, 128, 256
]

if ENABLE_EXPERIMENT:
    for batch_size in batch_sizes:
        (model, best_model_val_loss, best_model_auroc, experimentName) = run_experiment(
            experimentNamePrefix=None, 
            useAbp=True, 
            useEeg=False, 
            useEcg=False,
            nResiduals=12, 
            skip_connection=False,
            batch_size=batch_size,
            learning_rate=1e-4,
            weight_decay=0.0,
            pos_weight=None,
            max_epochs=MAX_EPOCHS,
            patience=PATIENCE,
            device=device
        )

        if DISPLAY_MODEL_PREDICTION:
            for case_id_to_check in my_cases_of_interest_idx:
                preds = predictionsForModel(case_id_to_check, model, best_model_val_loss, device)
                printModelPrediction(case_id_to_check, preds, experimentName)

                if DISPLAY_MODEL_PREDICTION_FIRST_ONLY:
                    break

#### Learning Rate

Holding all other parameters fixed, sweep the learning rate from `1e-4` to `1e-2`:

In [None]:
ENABLE_EXPERIMENT = False
DISPLAY_MODEL_PREDICTION=True
DISPLAY_MODEL_PREDICTION_FIRST_ONLY=True

learning_rates = [
    1e-4, 1e-3, 1e-2
]

if ENABLE_EXPERIMENT:
    for learning_rate in learning_rates:
        (model, best_model_val_loss, best_model_auroc, experimentName) = run_experiment(
            experimentNamePrefix=None, 
            useAbp=True, 
            useEeg=False, 
            useEcg=False,
            nResiduals=12, 
            skip_connection=False,
            batch_size=128,
            learning_rate=learning_rate,
            weight_decay=0.0,
            pos_weight=None,
            max_epochs=MAX_EPOCHS,
            patience=PATIENCE,
            device=device
        )
    
        if DISPLAY_MODEL_PREDICTION:
            for case_id_to_check in my_cases_of_interest_idx:
                preds = predictionsForModel(case_id_to_check, model, best_model_val_loss, device)
                printModelPrediction(case_id_to_check, preds, experimentName)

                if DISPLAY_MODEL_PREDICTION_FIRST_ONLY:
                    break

#### Weight decay

Holding all other parameters fixed, sweep the weight decay from `1e-3` to `1e0`:

In [None]:
ENABLE_EXPERIMENT = False
DISPLAY_MODEL_PREDICTION=True
DISPLAY_MODEL_PREDICTION_FIRST_ONLY=True

weight_decays = [
    1e-3, 1e-2, 1e-1, 1e0
]

if ENABLE_EXPERIMENT:
    for weight_decay in weight_decays:
        (model, best_model_val_loss, best_model_auroc, experimentName) = run_experiment(
            experimentNamePrefix=None, 
            useAbp=True, 
            useEeg=False, 
            useEcg=False,
            nResiduals=12, 
            skip_connection=False,
            batch_size=128,
            learning_rate=1e-4,
            weight_decay=weight_decay,
            pos_weight=None,
            max_epochs=MAX_EPOCHS,
            patience=PATIENCE,
            device=device
        )
    
        if DISPLAY_MODEL_PREDICTION:
            for case_id_to_check in my_cases_of_interest_idx:
                preds = predictionsForModel(case_id_to_check, model, best_model_val_loss, device)
                printModelPrediction(case_id_to_check, preds, experimentName)

                if DISPLAY_MODEL_PREDICTION_FIRST_ONLY:
                    break

#### Label balance

Holding all other parameters fixed, sweep the `pos_weight` in `BCEWithLogitsLoss` from `2` to `4`:

In [None]:
ENABLE_EXPERIMENT = False
DISPLAY_MODEL_PREDICTION=True
DISPLAY_MODEL_PREDICTION_FIRST_ONLY=True

pos_weights = [
    2.0, 4.0
]

if ENABLE_EXPERIMENT:
    for pos_weight in pos_weights:
        (model, best_model_val_loss, best_model_auroc, experimentName) = run_experiment(
            experimentNamePrefix=None, 
            useAbp=True, 
            useEeg=False, 
            useEcg=False,
            nResiduals=12, 
            skip_connection=False,
            batch_size=128,
            learning_rate=1e-4,
            weight_decay=0.0,
            pos_weight=pos_weight,
            max_epochs=MAX_EPOCHS,
            patience=PATIENCE,
            device=device
        )
    
        if DISPLAY_MODEL_PREDICTION:
            for case_id_to_check in my_cases_of_interest_idx:
                preds = predictionsForModel(case_id_to_check, model, best_model_val_loss, device)
                printModelPrediction(case_id_to_check, preds, experimentName)

                if DISPLAY_MODEL_PREDICTION_FIRST_ONLY:
                    break

### Ablations

Holding all other parameters fixed, perform ablations on the following parameters:
- Number of Residual Blocks (6, 1)
- Skip Connection

In [None]:
ENABLE_EXPERIMENT = False
DISPLAY_MODEL_PREDICTION=True
DISPLAY_MODEL_PREDICTION_FIRST_ONLY=True

ablations = [
    # nResiduals, skip_connection
    [6, False],
    [1, False],
    [12, True]
]

if ENABLE_EXPERIMENT:
    for (nResiduals, skip_connection) in ablations:
        (model, best_model_val_loss, best_model_auroc, experimentName) = run_experiment(
            experimentNamePrefix=None, 
            useAbp=True, 
            useEeg=False, 
            useEcg=False,
            nResiduals=nResiduals, 
            skip_connection=skip_connection,
            batch_size=128,
            learning_rate=1e-4,
            weight_decay=0.0,
            pos_weight=None,
            max_epochs=MAX_EPOCHS,
            patience=PATIENCE,
            device=device
        )
    
        if DISPLAY_MODEL_PREDICTION:
            for case_id_to_check in my_cases_of_interest_idx:
                preds = predictionsForModel(case_id_to_check, model, best_model_val_loss, device)
                printModelPrediction(case_id_to_check, preds, experimentName)

                if DISPLAY_MODEL_PREDICTION_FIRST_ONLY:
                    break

## Evaluation

### Metric description

As in the original paper, model performance will be evaluated on the following metrics:

- **AUROC**: Area Under the Receiver Operating Characteristic Curve. This is a measure of the model's ability to distinguish between positive and negative classes. The curve is generated by plotting the true positive rate (sensitivity) against the false positive rate (1-specificity) at various threshold settings, and the area under this curve is calculated. Higher values are indicative of better model performance.
- **AUPRC**: Area Under the Precision-Recall Curve. This is a measure of the model's ability to balance precision and recall. The curve is generated by plotting precision against recall at various threshold settings, and the area under this curve is calculated. Higher values are indicative of better model performance.
- **Sensitivity**: The proportion of actual positive cases that the model correctly identifies, i.e. true positives as a fraction of all positive cases. Higher values are indicative of better model performance.
- **Specificity**: The proportion of actual negative cases that the model correctly identifies, i.e. true negatives as a fraction of all negative cases. Higher values are indicative of better model performance.
- **Threshold**: This is not strictly an evaluation metric, but is reported as the threshold value which minimizes the difference between the sensitivity and specificity.
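
The metric definitions above can be illustrated with a small, dependency-free sketch. This is not the notebook's `eval_model` implementation: the helper names, the toy labels and probabilities, and the 0.01 threshold grid are assumptions made purely for illustration.

```python
def sensitivity_specificity(labels, probs, threshold):
    """Return (sensitivity, specificity), predicting positive when prob >= threshold."""
    tp = sum(1 for y, p in zip(labels, probs) if y == 1 and p >= threshold)
    fn = sum(1 for y, p in zip(labels, probs) if y == 1 and p < threshold)
    tn = sum(1 for y, p in zip(labels, probs) if y == 0 and p < threshold)
    fp = sum(1 for y, p in zip(labels, probs) if y == 0 and p >= threshold)
    return tp / (tp + fn), tn / (tn + fp)


def balanced_threshold(labels, probs, steps=100):
    """Scan an evenly spaced grid (assumed 0.01 steps) and return the first
    threshold that minimizes |sensitivity - specificity|."""
    best_t, best_gap = 0.0, float("inf")
    for i in range(steps + 1):
        t = i / steps
        sens, spec = sensitivity_specificity(labels, probs, t)
        gap = abs(sens - spec)
        if gap < best_gap:
            best_t, best_gap = t, gap
    return best_t


def auroc(labels, probs):
    """AUROC as the probability that a random positive outscores a random
    negative, counting ties as half a win."""
    pos = [p for y, p in zip(labels, probs) if y == 1]
    neg = [p for y, p in zip(labels, probs) if y == 0]
    wins = sum(1.0 if pp > pn else 0.5 if pp == pn else 0.0
               for pp in pos for pn in neg)
    return wins / (len(pos) * len(neg))


# Toy data (hypothetical): 3 positive and 5 negative samples
labels = [0, 0, 0, 0, 1, 1, 0, 1]
probs = [0.05, 0.10, 0.20, 0.30, 0.25, 0.60, 0.15, 0.80]

threshold = balanced_threshold(labels, probs)
sens, spec = sensitivity_specificity(labels, probs, threshold)
area = auroc(labels, probs)
print(f"threshold={threshold:.2f} sensitivity={sens:.3f} specificity={spec:.3f} auroc={area:.3f}")
```

On real data a library routine would normally be preferred for the curve-based metrics. The pairwise AUROC above is quadratic in the number of samples and is shown only because it makes the definition explicit.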

### Model evaluation

Calculate performance metrics on pre-trained models:

In [None]:
ENABLE_VALIDATION = True

validate_models = [
    # prediction window (minutes), model path
    # 3-minute models
    [3, os.path.join('pretrained', 'abp_3min_f386500f.model')],
    [3, os.path.join('pretrained', 'ecg_3min_9888ba74.model')],
    [3, os.path.join('pretrained', 'eeg_3min_6e41ecbf.model')],
    [3, os.path.join('pretrained', 'abp_ecg_3min_4c033450.model')],
    [3, os.path.join('pretrained', 'abp_eeg_3min_a25c1edf.model')],
    [3, os.path.join('pretrained', 'eeg_ecg_3min_24df69ca.model')],
    [3, os.path.join('pretrained', 'abp_eeg_ecg_3min_bea05a31.model')],
    # 5-minute models
    [5, os.path.join('pretrained', 'abp_5min_f4919819.model')],
    [5, os.path.join('pretrained', 'ecg_5min_f5345149.model')],
    [5, os.path.join('pretrained', 'eeg_5min_8970a5eb.model')],
    [5, os.path.join('pretrained', 'abp_ecg_5min_6306c305.model')],
    [5, os.path.join('pretrained', 'abp_eeg_5min_482fd843.model')],
    [5, os.path.join('pretrained', 'eeg_ecg_5min_3885bb9f.model')],
    [5, os.path.join('pretrained', 'abp_eeg_ecg_5min_5ab3f8eb.model')],
    # 10-minute models
    [10, os.path.join('pretrained', 'abp_10min_7661baf5.model')],
    [10, os.path.join('pretrained', 'ecg_10min_49dc88bd.model')],
    [10, os.path.join('pretrained', 'eeg_10min_90d4cdb5.model')],
    [10, os.path.join('pretrained', 'abp_ecg_10min_009ed9f2.model')],
    [10, os.path.join('pretrained', 'abp_eeg_10min_ff7c129d.model')],
    [10, os.path.join('pretrained', 'eeg_ecg_10min_e34ef1f5.model')],
    [10, os.path.join('pretrained', 'abp_eeg_ecg_10min_198d1d84.model')],
    # 15-minute models
    [15, os.path.join('pretrained', 'abp_15min_61321b51.model')],
    [15, os.path.join('pretrained', 'ecg_15min_3ac4acf1.model')],
    [15, os.path.join('pretrained', 'eeg_15min_acd313eb.model')],
    [15, os.path.join('pretrained', 'abp_ecg_15min_ad0d8b9b.model')],
    [15, os.path.join('pretrained', 'abp_eeg_15min_4c527f9b.model')],
    [15, os.path.join('pretrained', 'eeg_ecg_15min_2bb1d44d.model')],
    [15, os.path.join('pretrained', 'abp_eeg_ecg_15min_10e6e48b.model')],
]

if ENABLE_VALIDATION:
    #test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64)
    #loss_func = nn.BCELoss()
    for pred_window, model_path in validate_models:
        if pred_window == PREDICTION_WINDOW:
            print()
            print(f"Prediction Window: {pred_window}, Model: {model_path}")
            test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64)
            loss_func = nn.BCELoss()
            model = torch.load(model_path)
            eval_model(model, device, test_loader, loss_func, print_detailed = False)

### Model prediction

Use the model to predict the chance of an IOH event for real cases.

In [None]:
PERFORM_PREDICTIONS = False
PERFORM_PREDICTION_FIRST_ONLY = True

# NOTE: This is always set so that if earlier checks were enabled, the earlier data will be reused.
my_cases_of_interest_idx = [84, 198, 60, 16, 27]

if PERFORM_PREDICTIONS:
    positiveSegmentsMap, negativeSegmentsMap, iohEventsMap, cleanEventsMap = \
        extract_segments(my_cases_of_interest_idx, debug=False,
                         checkCache=False, forceWrite=False, returnSegments=True,
                         skipInvalidCleanEvents=SKIP_INVALID_CLEAN_EVENTS,
                         skipInvalidIohEvents=SKIP_INVALID_IOH_EVENTS)

    for pred_window, model_path in validate_models:
        if pred_window == PREDICTION_WINDOW:
            for case_id_to_check in my_cases_of_interest_idx:
                print()
                print(f'Model Predictions - Case {case_id_to_check} for {pred_window} Minute Prediction Window')
                print(f'Model: {model_path}')
                printAbpOverlay(case_id_to_check, positiveSegmentsMap, 
                            negativeSegmentsMap, iohEventsMap, cleanEventsMap, movingAverage=False)

                ready_model = torch.load(model_path)
                preds = predictionsForModel(case_id_to_check, None, None, device, ready_model=ready_model)

                printModelPrediction(case_id_to_check, preds, None)

                if PERFORM_PREDICTION_FIRST_ONLY:
                    break

# Results


We were able to run all of the same experiments as the authors of the original paper, though we were not able to fully replicate their results. In addition, we were able to run a few ablation studies to quantify the impact of the number of residual blocks and the skip connection path in the model.

Our complete table of experimental results is shown below in Table 1:

| Waveform              | AUROC        | AUPRC        | Sensitivity  | Specificity  | Threshold |
| --------------------- | ------------ | ------------ | ------------ | ------------ | --------- |
| **Time to event: 3 min**  |              |              |              |              |           |
| ABP                   | 0.8348665138 | 0.6827477064 | 0.7586206897 | 0.7586916743 | 0.24      |
| ECG                   | 0.5172308631 | 0.2777976524 | 0.6864623244 | 0.3392040256 | 0.34      |
| EEG                   | 0.5765847539 | 0.3108817974 | 0.5504469987 | 0.5562671546 | 0.28      |
| ABP + ECG             | **0.8350885234** | 0.6822426973 | **0.7618135377** | 0.7463403477 | 0.23      |
| ABP + EEG             | 0.8333041215 | **0.6830001845** | 0.7573435504 | 0.7602927722 | 0.26      |
| ECG + EEG             | 0.5634548602 | 0.3061094980 | 0.5932311622 | 0.5166971638 | 0.34      |
| ABP + ECG + EEG       | 0.8327256552 | 0.6812673236 | 0.7541507024 | **0.7653247941** | 0.29      |
| **Time to event: 5 min**  |              |              |              |              |           |
| ABP                   | **0.8001353845** | **0.6089914018** | 0.7172413793 | 0.7268053283 | 0.19      |
| ECG                   | 0.5408307613 | 0.2789858266 | **0.9089655172** | 0.0983874737 | 0.30      |
| EEG                   | 0.5889406162 | 0.3209396685 | 0.5606896552 | 0.5814442627 | 0.26      |
| ABP + ECG             | 0.7980903530 | 0.6008434537 | **0.7234482759** | 0.7120822622 | 0.17      |
| ABP + EEG             | 0.7932250526 | 0.6046273056 | 0.7020689655 | **0.7289086235** | 0.22      |
| ECG + EEG             | 0.5959877026 | 0.3278444283 | 0.6020689655 | 0.5461556438 | 0.30      |
| ABP + ECG + EEG       | 0.7912796254 | 0.6016009768 | 0.7151724138 | 0.7193269455 | 0.25      |
| **Time to event: 10 min** |              |              |              |              |           |
| ABP                   | 0.7417550791 | **0.4515207123** | 0.6550802139 | **0.7046703297** | 0.18      |
| ECG                   | 0.4859654235 | 0.1971933307 | **0.8609625668** | 0.1252289377 | 0.22      |
| EEG                   | 0.5929758558 | 0.2583727183 | 0.5695187166 | 0.5592948718 | 0.24      |
| ABP + ECG             | 0.7434999641 | 0.4485223920 | **0.7058823529** | 0.6572802198 | 0.17      |
| ABP + EEG             | **0.7456936446** | 0.4482909599 | 0.6773618538 | 0.6893315018 | 0.17      |
| ECG + EEG             | 0.5900167031 | 0.2531979287 | 0.5989304813 | 0.5283882784 | 0.23      |
| ABP + ECG + EEG       | 0.7433881478 | 0.4513591475 | 0.6951871658 | 0.6668956044 | 0.16      |
| **Time to event: 15 min** |              |              |              |              |           |
| ABP                   | 0.7350525214 | 0.3629148943 | 0.6534772182 | **0.6929577465** | 0.14      |
| ECG                   | 0.4958383997 | 0.1685094385 | 0.3944844125 | 0.6157276995 | 0.22      |
| EEG                   | 0.5681875626 | 0.1976809641 | 0.5935251799 | 0.5009389671 | 0.21      |
| ABP + ECG             | **0.7377326308** | **0.3649217753** | 0.6642685851 | 0.6786384977 | 0.15      |
| ABP + EEG             | 0.7364418324 | 0.3626843809 | 0.6678657074 | 0.6732394366 | 0.15      |
| ECG + EEG             | 0.5763017755 | 0.2000958414 | 0.5071942446 | 0.6140845070 | 0.18      |
| ABP + ECG + EEG       | 0.7344424460 | 0.3624089403 | **0.6906474820** | 0.6593896714 | 0.15      |

Table 1: Area under the Receiver-operating Characteristic Curve, Area under the Precision-Recall Curve, Sensitivity, and Specificity of the model in predicting intraoperative hypotension

For comparison, the results in the original paper are shown below as Table 2:
![Area under the Receiver-operating Characteristic Curve, Area under the Precision-Recall Curve, Sensitivity, and Specificity of the model in predicting intraoperative hypotension](https://journals.plos.org/plosone/article/figure/image?download&size=large&id=10.1371/journal.pone.0272055.t003)
Table 2: Area under the Receiver-operating Characteristic Curve, Area under the Precision-Recall Curve, Sensitivity, and Specificity of the model in predicting intraoperative hypotension in the original paper

### Hypotheses

The original paper investigated the following hypotheses:

1.   Hypothesis 1: A model using ABP and ECG will outperform a model using ABP alone in predicting IOH.
2.   Hypothesis 2: A model using ABP and EEG will outperform a model using ABP alone in predicting IOH.
3.   Hypothesis 3: A model using ABP, EEG, and ECG will outperform a model using ABP alone in predicting IOH.

As seen in Table 1, our results were too noisy to support or refute any of the three hypotheses. The results were not consistent across prediction windows or metrics, and performance was lower than that reported in the original paper.
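
To make the inconsistency concrete, the following sketch computes the AUROC difference between each ABP-based multi-signal model and the ABP-only model for every prediction window. The AUROC values are copied from Table 1. The dictionary layout and the helper name `auroc_gain_over_abp` are assumptions introduced only for this illustration.

```python
# AUROC values copied from Table 1, keyed by prediction window in minutes
auroc_table = {
    3: {"ABP": 0.8348665138, "ABP + ECG": 0.8350885234,
        "ABP + EEG": 0.8333041215, "ABP + ECG + EEG": 0.8327256552},
    5: {"ABP": 0.8001353845, "ABP + ECG": 0.7980903530,
        "ABP + EEG": 0.7932250526, "ABP + ECG + EEG": 0.7912796254},
    10: {"ABP": 0.7417550791, "ABP + ECG": 0.7434999641,
         "ABP + EEG": 0.7456936446, "ABP + ECG + EEG": 0.7433881478},
    15: {"ABP": 0.7350525214, "ABP + ECG": 0.7377326308,
         "ABP + EEG": 0.7364418324, "ABP + ECG + EEG": 0.7344424460},
}


def auroc_gain_over_abp(table):
    """Return, per prediction window, the AUROC of each combined model minus ABP alone."""
    gains = {}
    for window, row in table.items():
        baseline = row["ABP"]
        gains[window] = {
            name: value - baseline for name, value in row.items() if name != "ABP"
        }
    return gains


gains = auroc_gain_over_abp(auroc_table)
for window, row in gains.items():
    formatted = ", ".join(f"{name}: {delta:+.4f}" for name, delta in row.items())
    print(f"{window:>2} min -> {formatted}")
```

Every difference is smaller than 0.01 in magnitude, and the sign changes between windows. For example, adding ECG to ABP raises AUROC slightly at 3, 10, and 15 minutes but lowers it at 5 minutes. Without repeated runs we cannot separate differences of this size from run-to-run variation.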

### Hyperparameters

We performed a hyperparameter search across batch size, learning rate, weight decay, and label balancing. The results are shown below in Table 3:

| Hyperparameter | Search range          | Optimum value  |
| -------------- | --------------------- | -------------- |
| Batch size     | 16, 32, 64, 128, 256  | 128            |
| Learning rate  | 1e-4, 1e-3, 1e-2      | 1e-4           |
| Weight decay   | 1e-3, 1e-2, 1e-1, 1e0 | 1e-1           |
| Label balance  | 1.0, 2.0, 4.0         | 1.0 (disabled) |

Table 3: Hyperparameter search ranges and optimum values

The experimental data supporting these results can be found in [Supplemental Table 1 - Hyperparameter exploration](https://raw.githubusercontent.com/abarrie2/cs598-dlh-project/main/Supplemental%20Table%201%20Hyperparameter%20exploration.xlsx).

### Ablation Study

We performed an ablation study of the number of residual blocks and presence of skip connections. The original model configuration was found to have better performance than the ablated models, showing that the original model features contribute positively to the model performance and their removal results in a qualitatively worse result.

The experimental data supporting these results can be found in [Supplemental Table 2 - Ablation study](https://raw.githubusercontent.com/abarrie2/cs598-dlh-project/main/Supplemental%20Table%202%20Ablation%20study.xlsx)

### Computational requirements

Training and evaluation were run on three different machines as shown in Table 4 below:

| Machine     | Processor         | System RAM | GPU            | GPU RAM            | Device |
| ----------- | ----------------- | ---------- | -------------- | ------------------ | ------ |
| MacBook Pro | M1                | 32 GB      | Integrated     | Shared with system | mps    |
| MacBook Pro | M3 Pro            | 36 GB      | Integrated     | Shared with system | mps    |
| Desktop PC  | AMD Ryzen 7 3800X | 32 GB      | RTX 2070 Super | 8 GB               | cuda   |

Table 4: Specifications of machines used in training and evaluation

Typical runtime on a MacBook Pro is on the order of 2-3 minutes per epoch, and a typical experiment runs for 20-60 epochs. Including post-training evaluations, a typical experiment takes 90-360 minutes.

We estimate that we spent a total of 350 GPU hours training across all experiments.

# Discussion

## Implications of experimental results

### Reproducibility of original paper

Although we were not able to achieve the same performance level as the original paper, we were able to perform all of the same experiments and test the hypotheses. The results of our experiments were ultimately not consistent with the original paper, nor were they consistent from prediction window to prediction window or from metric to metric, so we were unable to prove or disprove any of the three hypotheses. The hyperparameter search and ablation study results were consistent with the original paper.

Using our best models, we were able to generate predictions for cases, plot the predicted values (probability of an IOH event), and compare them to the mean arterial pressure (MAP) of the raw ABP waveform. The predictions would increase before the MAP began to drop and decrease as the MAP recovered. For sustained periods of MAP below 65 mmHg the model predictions would become high and remain high for the duration of the IOH event. These results were consistent with similar prediction plots presented in `Figure 3` of the original paper.

We believe that the discrepancy in performance between our results and the original paper is mostly due to the differences in datasets and data preprocessing. The original paper used a dataset with 39,600 cases, of which 14,140 met the inclusion and exclusion criteria and were used for training. The authors publicly released a much smaller dataset with only 6,388 cases, of which 2,763 met the inclusion and exclusion criteria. This smaller dataset provided less data for training and validation, which likely impacted the model's performance.

The original paper also used a signal quality index to filter out low-quality data, which we were not able to implement. This likely left noisy segments in our dataset, which in turn may have reduced the model's performance.

The authors of the original paper also used a different data preprocessing pipeline than we did, and did not precisely document it. This led us to try various data preprocessing methods in an attempt to match the original paper's results, but we were not able to achieve the same performance levels.

## Factors affecting reproducibility

### Low difficulty

The most straightforward aspects of this project were the data download and the model implementation. Specific areas where we encountered low difficulty:
- The data download was straightforward, as the dataset was available through a convenient API in a Python library as well as being published to PhysioNet for download.
- The model architecture was generally clearly defined in the original paper with an included architecture diagram. The hyperparameters were provided in a supplemental table. These features made it easy to implement the model in PyTorch.

### High difficulty

The most difficult aspects of this project all involved the data preprocessing stage. This is the most impactful part of the data pipeline and it was not fully documented in the original paper. Some areas where we encountered difficulty:
- The source data is unlabelled, so our team was responsible for implementing analysis methods for identifying positive (IOH event occurred) and negative (IOH event did not occur) segments by running a lookahead analysis of our input training set. The original paper was not precise on how this was done, so we had to make some assumptions.
- The volume of raw data is in excess of 90GB. A non-trivial amount of compute was required to minify the input data to only include the data tracks of interest to our experiments (i.e., ABP, ECG, and EEG tracks).
- We found it difficult to trace back to the definition of the "jSQI" signal quality index referenced in the paper. Multiple references through multiple papers needed to be traversed to understand which variant of the quality index was used.
   - The only available source code related to the signal quality index is the code associated with reference [5]. The source code was not linked directly from that paper, but the GitHub repository of the corresponding author's lab contains MATLAB source code for the signal quality index described there. That code is available at https://github.com/cliffordlab/PhysioNet-Cardiovascular-Signal-Toolbox/tree/master/Tools/BP_Tools [7]
   - Our team had insufficient time to port this signal quality index to Python for use in our investigation, or to set up a MATLAB environment in which to assess our source data using the above MATLAB functions. This is a potential source of noise in our dataset, as the signal quality index was used to filter out low-quality data in the original paper, and we were unable to do the same.
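
The lookahead labelling mentioned in the first bullet above can be sketched as follows. This is a simplified illustration and not the notebook's `extract_segments` implementation. It assumes a MAP series sampled at a fixed interval, treats an IOH event as a run of consecutive readings below 65 mmHg lasting at least one minute (30 samples at an assumed 2-second interval), and uses hypothetical helper names and window lengths.

```python
def has_sustained_drop(map_values, threshold=65.0, min_samples=30):
    """True if map_values contains at least min_samples consecutive readings below threshold."""
    run = 0
    for value in map_values:
        run = run + 1 if value < threshold else 0
        if run >= min_samples:
            return True
    return False


def label_segment(map_values, segment_end, lookahead_samples, event_samples,
                  threshold=65.0, min_samples=30):
    """Label the input segment that ends at index segment_end (exclusive).

    Returns 1 if a sustained drop occurs inside the event window that starts
    lookahead_samples after the segment ends, 0 if it does not, and None if
    the recording is too short to inspect the whole event window.
    """
    start = segment_end + lookahead_samples
    stop = start + event_samples
    if stop > len(map_values):
        return None
    return 1 if has_sustained_drop(map_values[start:stop], threshold, min_samples) else 0


# Toy MAP trace (hypothetical): stable at 80 mmHg, then 40 samples at 60 mmHg, then recovery
series = [80.0] * 300 + [60.0] * 40 + [80.0] * 60

# 90 samples of lookahead (3 minutes at the assumed 2-second interval), 200-sample event window
positive = label_segment(series, segment_end=60, lookahead_samples=90, event_samples=200)
negative = label_segment(series, segment_end=0, lookahead_samples=90, event_samples=200)
too_short = label_segment(series, segment_end=200, lookahead_samples=90, event_samples=200)
print(positive, negative, too_short)
```

Choices such as the minimum event duration and the length of the event window change how many positive and negative segments are produced. Because the original paper did not specify them precisely, they are a plausible source of differences between our dataset and the authors'.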

### Unknowns

One aspect of our results that we were unable to explain was why our threshold values were lower than expected. In the original paper, the threshold is chosen in order to minimize the difference between sensitivity and specificity, and we applied an algorithm to achieve this goal. However, in the original paper, the thresholds were between 0.30 and 0.62 while our results were between 0.14 and 0.34. We posited that this was due to the label imbalance (4 positive labels : 1 negative label) and performed experiments comparing different `pos_weight` values in the `BCEWithLogitsLoss` loss function against the `BCELoss` loss function from the original paper. However, this did not yield better, or indeed any usable, results.
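
For reference, the effect of `pos_weight` can be shown with a plain-Python sketch of the two loss formulas. It follows the documented definition of `BCEWithLogitsLoss`, in which the positive term of the binary cross-entropy is multiplied by `pos_weight`. It is not the training code used in this notebook, and the logits and labels are made up for illustration.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def bce(probs, labels):
    """Mean binary cross-entropy on probabilities (the quantity BCELoss computes)."""
    total = 0.0
    for p, y in zip(probs, labels):
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(labels)


def bce_with_logits(logits, labels, pos_weight=1.0):
    """Mean binary cross-entropy on logits, with the positive term scaled by pos_weight."""
    total = 0.0
    for x, y in zip(logits, labels):
        p = sigmoid(x)
        total += -(pos_weight * y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(labels)


# Toy batch (hypothetical): three negative and two positive samples
logits = [-2.0, -1.0, 0.5, -0.5, 1.5]
labels = [0, 0, 0, 1, 1]

plain = bce([sigmoid(x) for x in logits], labels)
unweighted = bce_with_logits(logits, labels)
weighted = bce_with_logits(logits, labels, pos_weight=4.0)
print(f"BCELoss-style: {plain:.4f}, pos_weight=1: {unweighted:.4f}, pos_weight=4: {weighted:.4f}")
```

With `pos_weight=1.0` the two formulations agree. A larger value increases the penalty for under-predicting positive samples, which in principle should push predicted probabilities, and therefore the balanced threshold, upward. We did not observe a usable improvement in practice.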

## Suggestions to original authors

Our main suggestion to the original authors would be to release their code publicly and to provide more detailed documentation of the data preprocessing pipeline. This would help future researchers to reproduce the results more accurately. Specifically, the authors should provide more information on how the signal quality index was calculated and used to filter out low-quality data.

We would also suggest correcting the hyperparameters published in Supplemental Table 1. Specifically, the output size for residual blocks 11 and 12 for the ECG and ABP data sets is listed as 496x6. This is a typo, and should read 469x6. The typo became apparent when we implemented the size-down operation within Residual Block 11 and found that the tensor dimensions were misaligned.

Additionally, more explicit references to the signal quality index assessment tools should be added. Our team could not find a reference to the MATLAB source code for the index described in reference [5], and had to manually discover the GitHub profile for the lab of the corresponding author of reference [5] in order to find MATLAB source that corresponded to the metrics described therein.

## Future work

In future work, we would like to implement the signal quality index and use it to filter out low-quality data. We would also like to experiment with additional data preprocessing techniques and pre-filtered datasets such as PulseDB: a cleaned dataset based on MIMIC-III and VitalDB. Further, we would like to experiment with different model architectures and hyperparameters to see if we can improve the model's performance. Finally, we would like to run the models with different seeds to create a model ensemble in order to smooth some of the noise in our results.

# References

1. Jo Y-Y, Jang J-H, Kwon J-m, Lee H-C, Jung C-W, Byun S, et al. “Predicting intraoperative hypotension using deep learning with waveforms of arterial blood pressure, electroencephalogram, and electrocardiogram: Retrospective study.” PLoS ONE, (2022) 17(8): e0272055 https://doi.org/10.1371/journal.pone.0272055
2. Hatib, Feras, Zhongping J, Buddi S, Lee C, Settels J, Sibert K, Rhinehart J, Cannesson M “Machine-learning Algorithm to Predict Hypotension Based on High-fidelity Arterial Pressure Waveform Analysis” Anesthesiology (2018) 129:4 https://doi.org/10.1097/ALN.0000000000002300
3. Bao, X., Kumar, S.S., Shah, N.J. et al. "Acumen™ hypotension prediction index guidance for prevention and treatment of hypotension in noncardiac surgery: a prospective, single-arm, multicenter trial." Perioperative Medicine (2024) 13:13 https://doi.org/10.1186/s13741-024-00369-9
4. Lee, HC., Park, Y., Yoon, S.B. et al. VitalDB, a high-fidelity multi-parameter vital signs database in surgical patients. Sci Data 9, 279 (2022). https://doi.org/10.1038/s41597-022-01411-5
5. Li Q., Mark R.G. & Clifford G.D. "Artificial arterial blood pressure artifact models and an evaluation of a robust blood pressure and heart rate estimator." BioMed Eng OnLine. (2009) 8:13. pmid:19586547 https://doi.org/10.1186/1475-925X-8-13
6. Park H-J, "VitalDB Python Example Notebooks" GitHub Repository https://github.com/vitaldb/examples/blob/master/hypotension_art.ipynb
7. Vest A, Da Poian G, Li Q, Liu C, Nemati S, Shah A, Clifford GD, "An Open Source Benchmarked Toolbox for Cardiovascular Waveform and Interval Analysis", Physiological Measurement 39, no. 10 (2018): 105004. DOI:10.5281/zenodo.1243111

In [None]:
time_delta = np.round(timer() - global_time_start, 3)
print(f'Total Notebook Processing Time: {time_delta:.4f} sec')