<a href="https://colab.research.google.com/github/abarrie2/cs598-dlh-project/blob/main/DL4H_Team_24.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## LICENSING NOTICE

All users of VitalDB, an open biosignal dataset, must agree to the Data Use Agreement linked below. If you do not agree, please close this window. The Data Use Agreement is available here:
https://vitaldb.net/dataset/#h.vcpgs1yemdb5

# Introduction

This project aims to reproduce findings from the paper titled "Predicting intraoperative hypotension using deep learning with waveforms of arterial blood pressure, electroencephalogram, and electrocardiogram: Retrospective study" by Jo Y-Y et al. (2022) [1]. This study introduces a deep learning model that predicts intraoperative hypotension (IOH) events before they occur, utilizing a combination of arterial blood pressure (ABP), electroencephalogram (EEG), and electrocardiogram (ECG) signals.


## Background of the Problem

Intraoperative hypotension (IOH) is a common and significant surgical complication defined by a mean arterial pressure drop below 65 mmHg. It is associated with increased risks of myocardial infarction, acute kidney injury, and heightened postoperative mortality. Effective prediction and timely intervention can substantially enhance patient outcomes.

### Evolution of IOH Prediction

Initial attempts to predict IOH primarily used arterial blood pressure (ABP) waveforms. A foundational study by Hatib F et al. (2018) titled "Machine-learning Algorithm to Predict Hypotension Based on High-fidelity Arterial Pressure Waveform Analysis" [2] showed that machine learning could forecast IOH events using ABP with reasonable accuracy. This finding spurred further research into utilizing various physiological signals for IOH prediction.

Subsequent advancements included the development of the Acumen™ hypotension prediction index, which was studied in "Acumen™ hypotension prediction index guidance for prevention and treatment of hypotension in noncardiac surgery: a prospective, single-arm, multicenter trial" by Bao X et al. (2024) [3]. This trial integrated a hypotension prediction index into blood pressure monitoring equipment, demonstrating its effectiveness in reducing the number and duration of IOH events during surgeries. Further study is needed to determine whether this reduction in IOH events translates into improved postoperative patient outcomes.


### Current Study

Building on these advancements, the paper by Jo Y-Y et al. (2022) proposes a deep learning approach that enhances prediction accuracy by incorporating EEG and ECG signals along with ABP. This multi-modal method, evaluated over prediction windows of 3, 5, 10, and 15 minutes, aims to provide a comprehensive physiological profile that could predict IOH more accurately and earlier. Their results indicate that the combination of ABP and EEG significantly improves performance metrics such as AUROC and AUPRC, outperforming models that use fewer signals or different combinations.

Our project seeks to reproduce and verify Jo Y-Y et al.'s results to assess whether this integrated approach can indeed improve IOH prediction accuracy, thereby potentially enhancing surgical safety and patient outcomes.

# Scope of Reproducibility

The original paper investigated the following hypotheses:

1.   Hypothesis 1: A model using ABP and ECG will outperform a model using ABP alone in predicting IOH.
2.   Hypothesis 2: A model using ABP and EEG will outperform a model using ABP alone in predicting IOH.
3.   Hypothesis 3: A model using ABP, EEG, and ECG will outperform a model using ABP alone in predicting IOH.

Results were compared using AUROC and AUPRC scores. Based on the results described in the original paper, we expect that Hypothesis 2 will be confirmed, and that Hypotheses 1 and 3 will not be confirmed.

In order to perform the corresponding experiments, we will implement a CNN-based model that can be configured to train and infer using the following four model variations:

1.   ABP data alone
2.   ABP and ECG data
3.   ABP and EEG data
4.   ABP, ECG, and EEG data

We will measure the performance of these configurations using the same AUROC and AUPRC metrics as the original paper. Each hypothesis is tested by comparing AUROC and AUPRC between model variation 1 (ABP alone) and one other variation: variation 2 for Hypothesis 1, variation 3 for Hypothesis 2, and variation 4 for Hypothesis 3. We will repeat every comparison for each of the following prediction windows:

1. 3 minutes before event
2. 5 minutes before event
3. 10 minutes before event
4. 15 minutes before event

In the event that we are compute-bound, we will prioritize the 3-minute prediction window experiments as they are the most relevant to the original paper's findings.
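The experiment plan above can be sketched as a grid of signal combinations and prediction windows, each scored with AUROC and AUPRC. This is only an illustration: the names `SIGNAL_SETS`, `PREDICTION_WINDOWS_MIN`, `experiment_grid`, and `score_predictions` are hypothetical and not part of the notebook, and the labels and scores are toy values rather than model output.

```python
from itertools import product

import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score

# Signal combinations and prediction windows taken from the lists above.
SIGNAL_SETS = [("ABP",), ("ABP", "ECG"), ("ABP", "EEG"), ("ABP", "ECG", "EEG")]
PREDICTION_WINDOWS_MIN = [3, 5, 10, 15]

# One experiment per (signal set, prediction window) pair.
experiment_grid = list(product(SIGNAL_SETS, PREDICTION_WINDOWS_MIN))


def score_predictions(y_true, y_score):
    """Return the two metrics used to compare model variations."""
    return {
        "auroc": roc_auc_score(y_true, y_score),
        "auprc": average_precision_score(y_true, y_score),
    }


# Toy labels (1 = IOH event) and predicted probabilities, for illustration only.
y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])
toy_metrics = score_predictions(y_true, y_score)

print(f"{len(experiment_grid)} experiments planned")
print(toy_metrics)
```

Prioritizing the 3-minute window under a compute budget would then amount to running only the grid entries whose window equals 3.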

The predictive power of ABP, ECG and ABP + ECG models at 3-, 5-, 10- and 15-minute prediction windows:
![Predictive power of ABP, ECG and ABP + ECG models at 3-, 5-, 10- and 15-minute prediction windows](https://journals.plos.org/plosone/article/figure/image?download&size=large&id=10.1371/journal.pone.0272055.g004)

# Methodology

The methodology section is composed of the following subsections: Environment, Data and Model.

- **Environment**: This section describes the setup of the environment, including the installation of necessary libraries and the configuration of the runtime environment.
- **Data**: This section describes the dataset used in the study, including its collection and preprocessing.
    - **Data Collection**: This section describes the process of downloading the dataset from VitalDB and populating the local data cache.
    - **Data Preprocessing**: This section describes the preprocessing steps applied to the dataset, including data selection, data cleaning, and feature extraction.
- **Model**: This section describes the deep learning model used in the study, including its implementation, training, and evaluation.
    - **Model Implementation**: This section describes the implementation of the deep learning model, including the architecture, loss function, and optimization algorithm.
    - **Model Training**: This section describes the training process, including the training loop, hyperparameters, and training strategy.
    - **Model Evaluation**: This section describes the evaluation process, including the metrics used, the evaluation strategy, and the results obtained.

## Environment

### Create environment

The environment setup differs based on whether you are running the code on a local machine or on Google Colab. The following sections provide instructions for setting up the environment in each case.

#### Local machine

Create `conda` environment for the project using the `environment.yml` file:

```bash
conda env create --prefix .envs/dlh-team24 -f environment.yml
```

Activate the environment with:
```bash
conda activate .envs/dlh-team24
```

#### Google Colab

The following code snippet installs the required packages in a Google Colab environment:

In [None]:
#install vitaldb
%pip install vitaldb

All other required packages are already installed in the Google Colab environment.

### Load environment

In [None]:
# Import packages
import os
import random

from timeit import default_timer as timer

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset
import vitaldb
import pickle
import _pickle as cPickle

import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from datetime import datetime

#from google.colab import drive


Set random seeds to generate consistent results:

In [None]:
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
#torch.use_deterministic_algorithms(False)
os.environ['PYTHONHASHSEED'] = str(RANDOM_SEED)

## Data

### Data Description

#### Source

Data for this project is sourced from the open biosignal VitalDB dataset as described in "VitalDB, a high-fidelity multi-parameter vital signs database in surgical patients" by Lee H-C et al. (2022) [4], which contains perioperative vital signs and numerical data from 6,388 cases of non-cardiac (general, thoracic, urological, and gynecological) surgery patients who underwent routine or emergency surgery at Seoul National University Hospital between 2016 and 2017. The dataset includes ABP, ECG, and EEG signals, as well as other physiological data. The dataset is available through an [API](https://vitaldb.net/dataset/?query=api) and [Python library](https://vitaldb.net/dataset/?query=lib), and at PhysioNet: https://physionet.org/content/vitaldb/1.0.0/

#### Statistics

Characteristics of the dataset:
| Characteristic        | Value                       | Details                |
|-----------------------|-----------------------------|------------------------|
| Total number of cases | 6,388                       |                        |
| Sex (male)            | 3,243 (50.8%)               |                        |
| Age (years)           | 59                          | Range: 48-68           |
| Height (cm)           | 162                         | Range: 156-169         |
| Weight (kg)           | 61                          | Range: 53-69           |
| Tram-Rac 4A tracks    | 6,355 (99.5%)               | Sampling rate: 500Hz   |
| BIS Vista tracks      | 5,566 (87.1%)               | Sampling rate: 128Hz   |
| Case duration (min)   | 189                         | Range: 27-1041         |

Labels are only known after processing the data. The original paper reported an average of 1.6 IOH events and 5.7 non-events per case, so across 6,388 cases we expect approximately 10,221 IOH events (6,388 × 1.6) and 36,412 non-events (6,388 × 5.7) in the dataset.

#### Data Processing

Data will be processed as follows:
1. Load the dataset from VitalDB, or from a local cache if previously downloaded.
2. Apply the inclusion and exclusion selection criteria to filter the dataset according to surgery metadata.
3. Generate a minified dataset by discarding all tracks except ABP, ECG, and EEG.
4. Preprocess the data by band-pass filtering the ECG and EEG signals, z-score normalizing the ECG signal, and filtering out ABP signals below a Signal Quality Index (SQI) threshold.
5. Generate event and non-event samples by extracting 60-second segments around IOH events and non-events.
6. Split the dataset into training, validation, and test sets with a 6:1:3 ratio, ensuring that samples from a single case are not split across different sets to avoid data leakage.
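The case-level split in step 6 can be sketched as follows. This is a minimal illustration under stated assumptions, not the notebook's final implementation: `case_ids` is a hypothetical list of 20 case IDs, and the 6:1:3 ratio is obtained by first holding out 40% of the cases and then dividing that holdout 1:3 between validation and test.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical case IDs; in the notebook these would come from the cases of interest.
case_ids = np.arange(1, 21)

# Split the case IDs (not the individual samples) so that every sample from a
# given case lands in exactly one of the three sets.
train_ids, holdout_ids = train_test_split(case_ids, test_size=0.4, random_state=42)
val_ids, test_ids = train_test_split(holdout_ids, test_size=0.75, random_state=42)

print(f"train={len(train_ids)} val={len(val_ids)} test={len(test_ids)}")
```

Samples would then be routed to a set by looking up the case ID stored with each sample, which is what prevents leakage between sets.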

### Set Up Local Data Caches

VitalDB data is static, so local copies can be stored and reused to avoid expensive downloads and to speed up data processing.

The default directory defined below is in the project `.gitignore` file. If this is modified, the new directory should also be added to the project `.gitignore`.

In [None]:
VITALDB_CACHE = './vitaldb_cache'
VITAL_ALL = f"{VITALDB_CACHE}/vital_all"
VITAL_MINI = f"{VITALDB_CACHE}/vital_mini"
VITAL_METADATA = f"{VITALDB_CACHE}/metadata"
VITAL_MODELS = f"{VITALDB_CACHE}/models"
VITAL_PREPROCESS_SCRATCH = f"{VITALDB_CACHE}/data_scratch"

In [None]:
TRACK_CACHE = None
# when USE_DISK_CACHING is enabled, track and segment data will be flushed to disk
USE_DISK_CACHING = False

# When RESET_CACHE is set to True, it will ensure the TRACK_CACHE is disposed and recreated when we do dataset initialization. Use as a shortcut to wiping cache rather than restarting kernel
RESET_CACHE = False
EXPERIMENT_EVENT_HORIZON = 3

# Maximum number of cases of interest for which to download data.
# Set to a small value for demo purposes, else set to None to disable and download all.
#MAX_CASES = None
MAX_CASES = 20

# Preloading Cases: when true, all matched cases will have the _mini tracks extracted and put into in-mem dict
PRELOADING_CASES = True

In [None]:
# Create the cache directories if they do not already exist.
for cache_dir in (VITALDB_CACHE, VITAL_ALL, VITAL_MINI, VITAL_METADATA,
                  VITAL_MODELS, VITAL_PREPROCESS_SCRATCH):
    os.makedirs(cache_dir, exist_ok=True)

print(os.listdir(VITALDB_CACHE))


### Bulk Data Download

**This step is not required, but will significantly speed up downstream processing and avoid a high volume of API requests to the VitalDB web site.**

The cache population code checks whether the `.vital` files are available locally. The cache can be populated by calling the VitalDB API or by manually prepopulating it (recommended):

- Manually download the dataset from the following site: https://physionet.org/content/vitaldb/1.0.0/
    - Download the [zip file](https://physionet.org/static/published-projects/vitaldb/vitaldb-a-high-fidelity-multi-parameter-vital-signs-database-in-surgical-patients-1.0.0.zip) in a browser, or
    - Use `wget -r -N -c -np https://physionet.org/files/vitaldb/1.0.0/` to download the files in a terminal
- Move the contents of `vital_files` into the `${VITAL_ALL}` directory.

In [None]:
# Returns the Pandas DataFrame for the specified dataset.
#   One of 'cases', 'labs', or 'trks'
# If the file exists locally, create and return the DataFrame.
# Else, download and cache the csv first, then return the DataFrame.
def vitaldb_dataframe_loader(dataset_name):
    if dataset_name not in ['cases', 'labs', 'trks']:
        raise ValueError(f'Invalid dataset name: {dataset_name}')
    file_path = f'{VITAL_METADATA}/{dataset_name}.csv'
    if os.path.isfile(file_path):
        print(f'{dataset_name}.csv exists locally.')
        df = pd.read_csv(file_path)
        return df
    else:
        print(f'downloading {dataset_name} and storing in the local cache for future reuse.')
        df = pd.read_csv(f'https://api.vitaldb.net/{dataset_name}')
        df.to_csv(file_path, index=False)
        return df

## Exploratory Data Analysis

#### Cases

In [None]:
cases = vitaldb_dataframe_loader('cases')
cases = cases.set_index('caseid')
cases.shape

In [None]:
cases.index.nunique()

In [None]:
cases.head()

In [None]:
cases['sex'].value_counts()

#### Tracks

In [None]:
trks = vitaldb_dataframe_loader('trks')
trks = trks.set_index('caseid')
trks.shape

In [None]:
trks.index.nunique()

In [None]:
trks.groupby('caseid')[['tid']].count().plot();

In [None]:
trks.groupby('caseid')[['tid']].count().hist();

In [None]:
trks.groupby('tname').count().sort_values(by='tid', ascending=False)

## Parameters of Interest

### Hemodynamic Parameters Reference
https://vitaldb.net/dataset/?query=overview#h.f7d712ycdpk2

**SNUADC/ART**

arterial blood pressure waveform

Parameter, Description, Type/Hz, Unit

SNUADC/ART, Arterial pressure wave, W/500, mmHg

In [None]:
trks[trks['tname'].str.contains('SNUADC/ART')].shape

**SNUADC/ECG_II**

electrocardiogram waveform

Parameter, Description, Type/Hz, Unit

SNUADC/ECG_II, ECG lead II wave, W/500, mV

In [None]:
trks[trks['tname'].str.contains('SNUADC/ECG_II')].shape

**BIS/EEG1_WAV**

electroencephalogram waveform

Parameter, Description, Type/Hz, Unit

BIS/EEG1_WAV, EEG wave from channel 1, W/128, uV

In [None]:
trks[trks['tname'].str.contains('BIS/EEG1_WAV')].shape

## Cases of Interest

This is the subset of case IDs for which modelling and analysis will be performed, based on the inclusion criteria and waveform data availability.

In [None]:
TRACK_NAMES = ['SNUADC/ART', 'SNUADC/ECG_II', 'BIS/EEG1_WAV']
TRACK_SRATES = [500, 500, 128]

In [None]:
# As in the paper, select cases which meet the following criteria:
#
# For patients, the inclusion criteria were as follows:
# (1) adults (age >= 18)
# (2) administered general anaesthesia
# (3) undergone non-cardiac surgery. 
#
# For waveform data, the inclusion criteria were as follows:
# (1) no missing monitoring for ABP, ECG, and EEG waveforms
# (2) no cases containing false events or non-events due to poor signal quality
#     (checked in second stage of data preprocessing)

# Adult
inclusion_1 = cases.loc[cases['age'] >= 18].index
print(f'{len(cases)-len(inclusion_1)} cases excluded, {len(inclusion_1)} remaining due to age criteria')

# General Anesthesia
inclusion_2 = cases.loc[cases['ane_type'] == 'General'].index
print(f'{len(cases)-len(inclusion_2)} cases excluded, {len(inclusion_2)} remaining due to anesthesia criteria')

# Non-cardiac surgery
inclusion_3 = cases.loc[
    ~cases['opname'].str.contains("cardiac", case=False)
    & ~cases['opname'].str.contains("aneurysmal", case=False)
].index
print(f'{len(cases)-len(inclusion_3)} cases excluded, {len(inclusion_3)} remaining due to non-cardiac surgery criteria')

# ABP, ECG, EEG waveforms
inclusion_4 = trks.loc[trks['tname'].isin(TRACK_NAMES)].index.value_counts()
inclusion_4 = inclusion_4[inclusion_4 == len(TRACK_NAMES)].index
print(f'{len(cases)-len(inclusion_4)} cases excluded, {len(inclusion_4)} remaining due to missing waveform data')

cases_of_interest_idx = inclusion_1 \
    .intersection(inclusion_2) \
    .intersection(inclusion_3) \
    .intersection(inclusion_4)

cases_of_interest = cases.loc[cases_of_interest_idx]

print()
print(f'{cases_of_interest_idx.shape[0]} out of {cases.shape[0]} total cases remaining after exclusions applied')

# Trim cases of interest to MAX_CASES
if MAX_CASES:
    cases_of_interest_idx = cases_of_interest_idx[:MAX_CASES]
print(f'{cases_of_interest_idx.shape[0]} cases of interest selected')

In [None]:
cases_of_interest.head(n=5)

## Tracks of Interest

This is the subset of tracks (waveforms) for the cases of interest identified above.

In [None]:
# A single case maps to one or more waveform tracks. Select only the tracks required for analysis.
trks_of_interest = trks.loc[cases_of_interest_idx][trks.loc[cases_of_interest_idx]['tname'].isin(TRACK_NAMES)]
trks_of_interest.shape

In [None]:
trks_of_interest.head(n=5)

In [None]:
trks_of_interest_idx = trks_of_interest.set_index('tid').index
trks_of_interest_idx.shape

### Build Tracks Cache for Local Processing

Track data are large and therefore expensive to download every time they are used.
By default, the `.vital` file format stores all tracks for each case internally. Since only select tracks per case are required, each `.vital` file can be further reduced by discarding the unused tracks.

In [None]:
# Ensure the full vital file dataset is available for cases of interest.
count_downloaded = 0
count_present = 0

#for i, idx in enumerate(cases.index):
for i, idx in enumerate(cases_of_interest_idx):
    full_path = f'{VITAL_ALL}/{idx:04d}.vital'
    if not os.path.isfile(full_path):
        print(f'Missing vital file: {full_path}')
        # Download and save the file.
        vf = vitaldb.VitalFile(idx)
        vf.to_vital(full_path)
        count_downloaded += 1
    else:
        count_present += 1

print()
print(f'Count of cases of interest:           {cases_of_interest_idx.shape[0]}')
print(f'Count of vital files downloaded:      {count_downloaded}')
print(f'Count of vital files already present: {count_present}')

In [None]:
# Convert vital files to "mini" versions including only the subset of tracks defined in TRACK_NAMES above.
# Only perform conversion for the cases of interest.
# NOTE: If this cell is interrupted, it can be restarted and will continue where it left off.
count_minified = 0
count_present = 0

for i, idx in enumerate(cases_of_interest_idx):
    full_path = f'{VITAL_ALL}/{idx:04d}.vital'
    mini_path = f'{VITAL_MINI}/{idx:04d}_mini.vital'
    if not os.path.isfile(mini_path):
        print(f'Creating mini vital file: {idx}')
        vf = vitaldb.VitalFile(full_path, TRACK_NAMES)
        vf.to_vital(mini_path)
        count_minified += 1
    else:
        count_present += 1

print()
print(f'Count of cases of interest:           {cases_of_interest_idx.shape[0]}')
print(f'Count of vital files minified:        {count_minified}')
print(f'Count of vital files already present: {count_present}')

In [None]:
# Exclude cases where ABP j signal quality (jSQI) < 0.8
# TODO: Implement jSQI function
# TODO: Filter cases with jSQI < 0.8

#### Filtering

Preprocessing characteristics are different for each of the three signal categories:
 * ABP: no preprocessing, use as-is
 * ECG: apply a 1-40Hz bandpass filter, then perform Z-score normalization
 * EEG: apply a 0.5-50Hz bandpass filter

`apply_bandpass_filter()` implements the bandpass filter using scipy.signal

`apply_zscore_normalization()` implements the Z-score normalization using numpy
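As a quick sanity check of the ECG preprocessing described above, the sketch below applies a 1-40 Hz Butterworth band-pass filter followed by z-score normalization to a synthetic signal. The signal components (a 10 Hz tone, slow baseline drift, and 100 Hz noise) and the filter order of 2 are assumptions chosen for illustration and are not taken from the dataset.

```python
import numpy as np
from scipy.signal import butter, lfilter

FS_HZ = 500  # ECG sampling rate used in this notebook
t = np.arange(10 * FS_HZ) / FS_HZ  # 10 seconds of sample times

# Synthetic test signal: an in-band tone plus out-of-band drift and noise.
in_band = np.sin(2 * np.pi * 10 * t)          # 10 Hz, inside the 1-40 Hz band
drift = 5 * np.sin(2 * np.pi * 0.1 * t)       # 0.1 Hz baseline wander
hf_noise = 0.5 * np.sin(2 * np.pi * 100 * t)  # 100 Hz interference
raw = in_band + drift + hf_noise

# Band-pass filter, then z-score normalize, in the same order as the ECG pipeline.
b, a = butter(2, [1, 40], fs=FS_HZ, btype='band')
filtered = lfilter(b, a, raw)
normalized = (filtered - filtered.mean()) / filtered.std()

# Locate the strongest remaining frequency component.
spectrum = np.abs(np.fft.rfft(normalized))
freqs = np.fft.rfftfreq(len(normalized), d=1 / FS_HZ)
dominant_hz = freqs[np.argmax(spectrum)]

print(f"dominant frequency after filtering: {dominant_hz:.1f} Hz")
```

Although the drift is five times larger than the tone in the raw signal, the in-band component should dominate after filtering, which is the behavior the band-pass step relies on.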

In [None]:
from scipy.signal import butter, lfilter, spectrogram

# define two methods for data preprocessing

def apply_bandpass_filter(data, lowcut, highcut, fs, order=5):
    b, a = butter(order, [lowcut, highcut], fs=fs, btype='band')
    y = lfilter(b, a, np.nan_to_num(data))
    return y

def apply_zscore_normalization(signal):
    mean = np.nanmean(signal)
    std = np.nanstd(signal)
    return (signal - mean) / std


In [None]:
# Filtering Demonstration

# temp experimental, code to be incorporated into overall preloader process
# for now it's just dumping example plots of the before/after filtered signal data
caseidx = 1
file_path = f"{VITAL_MINI}/{caseidx:04d}_mini.vital"
vf = vitaldb.VitalFile(file_path, TRACK_NAMES)

originalAbp = None
filteredAbp = None
originalEcg = None
filteredEcg = None
originalEeg = None
filteredEeg = None

ABP_TRACK_NAME = "SNUADC/ART"
ECG_TRACK_NAME = "SNUADC/ECG_II"
EEG_TRACK_NAME = "BIS/EEG1_WAV"

for i, (track_name, rate) in enumerate(zip(TRACK_NAMES, TRACK_SRATES)):
    # Get samples for this track
    track_samples = vf.get_track_samples(track_name, 1/rate)
    #track_samples, _ = vf.get_samples(track_name, 1/rate)
    print(f"Track {track_name} @ {rate}Hz shape {len(track_samples)}")

    if track_name == ABP_TRACK_NAME:
        # ABP waveforms are used without further pre-processing
        originalAbp = track_samples
        filteredAbp = track_samples
    elif track_name == ECG_TRACK_NAME:
        originalEcg = track_samples
        # ECG waveforms are band-pass filtered between 1 and 40 Hz, and Z-score normalized
        # first apply bandpass filter
        filteredEcg = apply_bandpass_filter(track_samples, 1, 40, rate)
        # then do z-score normalization
        filteredEcg = apply_zscore_normalization(filteredEcg)
    elif track_name == EEG_TRACK_NAME:
        # EEG waveforms are band-pass filtered between 0.5 and 50 Hz
        originalEeg = track_samples
        filteredEeg = apply_bandpass_filter(track_samples, 0.5, 50, rate, 2)

def plotSignal(data, title):
    plt.figure(figsize=(20, 5))
    plt.plot(data)
    plt.title(title)
    plt.show()

plotSignal(originalAbp, "Original ABP")
plotSignal(filteredAbp, "Unfiltered ABP")
plotSignal(originalEcg, "Original ECG")
plotSignal(filteredEcg, "Filtered ECG")
plotSignal(originalEeg, "Original EEG")
plotSignal(filteredEeg, "Filtered EEG")


In [None]:
def load_pickled_data(path):
    # Return the unpickled object at `path`, or None if disk caching is
    # disabled or the file does not exist.
    if USE_DISK_CACHING and os.path.isfile(path):
        with open(path, 'rb') as disk_cache_file:
            return pickle.load(disk_cache_file)
    return None

def save_pickled_data(path, data):
    # Pickle `data` to `path` when disk caching is enabled.
    if USE_DISK_CACHING:
        with open(path, 'wb') as disk_cache_file:
            pickle.dump(data, disk_cache_file)


In [None]:
# Preprocess data tracks
ABP_TRACK_NAME = "SNUADC/ART"
ECG_TRACK_NAME = "SNUADC/ECG_II"
EEG_TRACK_NAME = "BIS/EEG1_WAV"
MINI_FILE_FOLDER = VITAL_MINI
CACHE_FILE_FOLDER = VITAL_PREPROCESS_SCRATCH

if RESET_CACHE:
    TRACK_CACHE = None

if TRACK_CACHE is None:
    TRACK_CACHE = {}

def get_track_data(case, print_when_file_loaded = False):
    parsedFile = None
    abp = None
    eeg = None
    ecg = None
    for i, (track_name, rate) in enumerate(zip(TRACK_NAMES, TRACK_SRATES)):
        # use integer case id and track name, delimited by pipe, as cache key
        cache_label = f"{case}|{track_name}"
        if cache_label not in TRACK_CACHE:
            if parsedFile is None:
                file_path = f"{MINI_FILE_FOLDER}/{case:04d}_mini.vital"
                if print_when_file_loaded:
                    print(f"[{datetime.now()}] Loading vital file {file_path}")
                parsedFile = vitaldb.VitalFile(file_path, TRACK_NAMES)
            dataset = np.array(parsedFile.get_track_samples(track_name, 1/rate))
            if track_name == ABP_TRACK_NAME:
                # no filtering for ABP
                abp = dataset
                TRACK_CACHE[cache_label] = abp
            elif track_name == ECG_TRACK_NAME:
                ecg = dataset
                # apply ECG filtering: first bandpass then do z-score normalization
                ecg = apply_bandpass_filter(ecg, 1, 40, rate, 2)
                ecg = apply_zscore_normalization(ecg)
                TRACK_CACHE[cache_label] = ecg
            elif track_name == EEG_TRACK_NAME:
                eeg = dataset
                # apply EEG filtering: bandpass only
                eeg = apply_bandpass_filter(eeg, 0.5, 50, rate, 2)
                TRACK_CACHE[cache_label] = eeg
        else:
            # cache hit, pull from cache
            if track_name == ABP_TRACK_NAME:
                abp = TRACK_CACHE[cache_label]
            elif track_name == ECG_TRACK_NAME:
                ecg = TRACK_CACHE[cache_label]
            elif track_name == EEG_TRACK_NAME:
                eeg = TRACK_CACHE[cache_label]

    return (abp, ecg, eeg)

# ABP waveforms are used without further pre-processing
# ECG waveforms are band-pass filtered between 1 and 40 Hz, and Z-score normalized
# EEG waveforms are band-pass filtered between 0.5 and 50 Hz
if PRELOADING_CASES:
    # determine disk cache file label
    maxlabel = "ALL"
    if MAX_CASES is not None:
        maxlabel = str(MAX_CASES)
    picklefile = f"{CACHE_FILE_FOLDER}/{EXPERIMENT_EVENT_HORIZON}_minutes_MAX{maxlabel}.trackcache"

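    loaded = False
    cached_tracks = load_pickled_data(picklefile)
    if cached_tracks is not None:
        TRACK_CACHE = cached_tracks
        # The cache holds one entry per (case, track) pair, so only treat it as
        # complete when every case of interest is fully represented.
        expected_entries = len(cases_of_interest_idx) * len(TRACK_NAMES)
        if len(TRACK_CACHE) >= expected_entries:
            loaded = True
            print(f"Loaded track cache from {picklefile}, {len(TRACK_CACHE)} records loaded")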

    if not loaded:
        print(f"At beginning of process, the track cache has {len(TRACK_CACHE)} entries present.")
        for track in tqdm(cases_of_interest_idx):
            # getting track data will cause a cache-check and fill when missing
            # will also apply appropriate filtering per track
            get_track_data(track, False)
        
        save_pickled_data(picklefile, TRACK_CACHE)
        print(f"Saved track cache to {picklefile}, {len(TRACK_CACHE)} records saved")


The following method is adapted from the preprocessing block of reference [6] (https://github.com/vitaldb/examples/blob/master/hypotension_art.ipynb)
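The window arithmetic used by this method can be illustrated on a synthetic trace: a 60-second feature segment, a gap equal to the prediction window, and a 60-second label segment that counts as an event when its 2-second moving average stays below 65 mmHg throughout. The helper `label_window` and the constant-pressure traces are hypothetical and NaN-free, and the moving average uses `np.convolve` rather than the cumulative-sum form in the method below.

```python
import numpy as np


def label_window(abp, start, srate=500, feature_sec=60, gap_min=3,
                 label_sec=60, avg_sec=2, threshold_mmhg=65.0):
    """Return True if the label segment following `start` is an IOH event."""
    feature_end = start + srate * feature_sec
    label_start = feature_end + srate * 60 * gap_min
    label_end = label_start + srate * label_sec
    if label_end > len(abp):
        raise ValueError("trace too short for this window")
    label_segment = abp[label_start:label_end]
    # Moving average over `avg_sec` seconds of the label segment.
    n = avg_sec * srate
    moving_avg = np.convolve(label_segment, np.ones(n) / n, mode="valid")
    # Event only if the averaged pressure never reaches the threshold.
    return bool(moving_avg.max() < threshold_mmhg)


SRATE_HZ = 500
total_samples = SRATE_HZ * 60 * 5  # 1 min feature + 3 min gap + 1 min label

# Normotensive trace: constant 80 mmHg.
stable_abp = np.full(total_samples, 80.0)

# Hypotensive trace: pressure falls to 60 mmHg during the final minute.
hypotensive_abp = stable_abp.copy()
hypotensive_abp[-SRATE_HZ * 60:] = 60.0

stable_label = label_window(stable_abp, start=0)
hypotensive_label = label_window(hypotensive_abp, start=0)
print(stable_label, hypotensive_label)
```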

In [None]:
# Generate hypotensive events
# Hypotensive events are defined as a 1-minute interval with sustained ABP of less than 65 mmHg
# Note: Hypotensive events should be at least 20 minutes apart to minimize potential residual effects from previous events
# Generate hypotension non-events
# To sample non-events, 30-minute segments where the ABP was above 75 mmHg were selected, and then
# three one-minute samples of each waveform were obtained from the middle of the segment
# both occur in extract_segments

def extract_segments(cases_of_interest_idx, min_before_event=3, debug=False):
    # Sampling rate for ABP and ECG, Hz. These rates should be the same. Default = 500
    ABP_ECG_SRATE_HZ = 500

    # Sampling rate for EEG. Default = 128
    EEG_SRATE_HZ = 128

    # Length of feature segment, seconds.
    FEATURE_LENGTH_SEC = 60
    # Look ahead to predict hypotension, seconds.
    MIDDLE_LENGTH_SEC  = 60 * min_before_event
    # Length of label segment, seconds.
    LABEL_LENGTH_SEC   = 60

    # Length to move down the ABP track for starting a new analysis segment, seconds.
    NEW_SEGMENT_OFFSET_SEC = 10

    # Final dataset for training and testing the model.
    # inputs with shape of (segments, timepoints)
    samples = []
    invalid_samples = []

    # Process each case and extract segments. For each segment identify presence of an event in the label zone.
    time_start = timer()
    
    count_cases = len(cases_of_interest_idx)

    for case_count, caseid in tqdm(enumerate(cases_of_interest_idx), total=count_cases):
        if debug:
            print(f'Loading case: {caseid:04d}, ({case_count + 1} of {count_cases})')
        
        segment_key = []
        segment_abp = []
        segment_ecg = []
        segment_eeg = []
        segment_label = []
        segment_validity = []
        segment_caseid = caseid

        # read the arterial waveform
        (abp, ecg, eeg) = get_track_data(caseid)


        # EEG - Different sample rate, process alone
        if debug:
            print(f'Length of {TRACK_NAMES[2]}:     {eeg.shape[0]}')

        print_first_segment = True

        last_sample_start_index = len(eeg) - EEG_SRATE_HZ * (FEATURE_LENGTH_SEC + MIDDLE_LENGTH_SEC + LABEL_LENGTH_SEC)
        sample_index_offset = NEW_SEGMENT_OFFSET_SEC * EEG_SRATE_HZ

        for i in range(0, last_sample_start_index, sample_index_offset):
            segx_start = i
            segx_end   = i + EEG_SRATE_HZ * FEATURE_LENGTH_SEC
            segx = eeg[segx_start:segx_end]

            if debug and print_first_segment:
                print(f'  Feature Segment Length:   {segx.shape[0]} pts, {segx.shape[0] / EEG_SRATE_HZ} sec')
                print_first_segment = False

            # handle eeg, only care about extracting data from the same time interval used for abp
            segment_eeg.append(eeg[segx_start:segx_end])

        # ABP and ECG - Shared sample rate, process together    
        if debug:
            print(f'Length of {TRACK_NAMES[0]}:       {abp.shape[0]}')
            print(f'Length of {TRACK_NAMES[1]}:    {ecg.shape[0]}')

        segment_count = 0
        segment_valid = 0
        segment_event = 0
        print_first_segment = True

        last_sample_start_index = len(abp) - ABP_ECG_SRATE_HZ * (FEATURE_LENGTH_SEC + MIDDLE_LENGTH_SEC + LABEL_LENGTH_SEC)
        sample_index_offset = NEW_SEGMENT_OFFSET_SEC * ABP_ECG_SRATE_HZ

        for i in range(0, last_sample_start_index, sample_index_offset):
            segment_count += 1

            segx_start = i
            segx_end   = i + ABP_ECG_SRATE_HZ * FEATURE_LENGTH_SEC
            segx = abp[segx_start:segx_end]

            segy_start = i + ABP_ECG_SRATE_HZ * (FEATURE_LENGTH_SEC + MIDDLE_LENGTH_SEC)
            segy_end   = i + ABP_ECG_SRATE_HZ * (FEATURE_LENGTH_SEC + MIDDLE_LENGTH_SEC + LABEL_LENGTH_SEC)
            segy = abp[segy_start:segy_end]

            if debug and print_first_segment:
                print(f'  Feature Segment Length:   {segx.shape[0]} pts, {segx.shape[0] / ABP_ECG_SRATE_HZ} sec')
                print(f'  Middle Segment Length:    {segy_start - segx_end} pts, {(segy_start - segx_end) / ABP_ECG_SRATE_HZ} sec')
                print(f'  Label Segment Length:     {segy.shape[0]} pts, {segy.shape[0] / ABP_ECG_SRATE_HZ} sec')
                print_first_segment = False

            # check the validity of this segment
            valid = True
            if np.isnan(segx).mean() > 0.1:
                valid = False
            elif np.isnan(segy).mean() > 0.1:
                valid = False
            elif (segx > 200).any():
                valid = False
            elif (segy > 200).any():
                valid = False
            elif (segx < 30).any():
                valid = False
            elif (segy < 30).any():
                valid = False
            elif np.max(segx) - np.min(segx) < 30:
                valid = False
            elif np.max(segy) - np.min(segy) < 30:
                valid = False
            elif (np.abs(np.diff(segx)) > 30).any():  # abrupt change -> noise
                valid = False
            elif (np.abs(np.diff(segy)) > 30).any():  # abrupt change -> noise
                valid = False

            # 2 sec moving avg
            n = 2 * ABP_ECG_SRATE_HZ  
            segy = np.nancumsum(segy, dtype=np.float32)
            segy[n:] = segy[n:] - segy[:-n]
            segy = segy[n - 1:] / n

            # forward filling - do this per case to avoid massive resource utilization at the end.
            segx = pd.DataFrame(segx).ffill(axis=0).bfill(axis=0)[0].values

            # Identify an IOH event: the 2-sec moving average stays below 65 mmHg for the entire label window
            evt = np.nanmax(segy) < 65

            segment_abp.append(segx)
            segment_label.append(evt)
            segment_validity.append(valid)

            # handle ecg, only care about extracting the same segment used for abp.
            # data is already time aligned and has same sample rate.
            segment_ecg.append(ecg[segx_start:segx_end])
            segment_key.append(f"{caseid}|{segx_start}")

            

            if valid:
                segment_valid += 1
                if evt:
                    segment_event += 1

        for i in range(0, len(segment_abp)):
            if (segment_validity[i]):
                samples.append((segment_abp[i], segment_ecg[i], segment_eeg[i], segment_label[i], segment_validity[i], segment_caseid, segment_key[i]))
            else:
                invalid_samples.append((segment_abp[i], segment_ecg[i], segment_eeg[i], segment_label[i], segment_validity[i], segment_caseid, segment_key[i]))

        # if debug:
        #     print(f'Total Segments Evaluated:   {segment_count}')
        #     segment_valid_percent = 0 if segment_count == 0 else 100 * segment_valid / segment_count 
        #     print(f'  Segments Valid:           {segment_valid}, {segment_valid_percent:.1f}%')
        #     segment_event_percent = 0 if segment_valid == 0 else 100 * segment_event / segment_valid
        #     print(f'  Segments with Event:      {segment_event}, {segment_event_percent:.1f}%')
        #     time_delta = np.round(timer() - time_start, 3)
        #     print(f'Total Processing Time:      {time_delta:.4f} sec')
        #     print()


    # total processing time
    time_end = timer()
    time_delta = np.round(time_end - time_start, 3)

    # if debug:
    #     print('OVERALL SUMMARY')
    #     print(f'Total Processing Time:      {time_delta:.4f} sec')
    #     print(f'Total Cases Processed:      {caseids.shape[0]}')
    #     print(f'Total Segments Evaluated:   {x_abp.shape[0]}')

    #     segment_valid_count = np.sum(valid_mask)
    #     segment_valid_percent = 0 if x_abp.shape[0] == 0 else 100 * segment_valid_count / x_abp.shape[0] 
    #     print(f'  Segments Valid:           {segment_valid_count}, {segment_valid_percent:.1f}%')
    #     segment_event_count = np.sum(y & valid_mask)
    #     segment_event_percent = 0 if y.shape[0] == 0 else 100 * segment_event_count / y.shape[0]
    #     print(f'  Segments with Event:      {segment_event_count}, {segment_event_percent:.1f}%')


    #     print(f'Valid Samples Generated:    {(100 * np.mean(valid_mask)):.1f}%')
    #     print()
    #     print(f'Valid Mask Shape:           {valid_mask.shape}')
    #     print(f'X_ABP Shape:                {x_abp.shape}')
    #     print(f'X_ECG Shape:                {x_ecg.shape}')
    #     print(f'X_EEG Shape:                {x_eeg.shape}')
    #     print(f'Y Shape:                    {y.shape}')
    #     print(f'CIPS Shape:                 {case_id_per_segment.shape}')
    
    return pd.DataFrame(samples, columns=['segment_abp', 'segment_ecg', 'segment_eeg', 'segment_label', 'segment_valid', 'caseidx', 'segment_key'])


In [None]:
cutoff = MAX_CASES
# x_abp, x_ecg, x_eeg, y, valid_mask, case_id_per_segment = \
#     extract_segments(cases_of_interest_idx[:cutoff], min_before_event=3, debug=True)



if PRELOADING_CASES:
    # determine disk cache file label
    maxlabel = "ALL"
    if MAX_CASES is not None:
        maxlabel = str(MAX_CASES)
    picklefile = f"{CACHE_FILE_FOLDER}/{EXPERIMENT_EVENT_HORIZON}_minutes_MAX{maxlabel}.segmentcache"

    loaded = False
    samples = load_pickled_data(picklefile)
    if samples is not None:
        loaded = True
        print(f"Loaded segment cache from {picklefile}, {len(samples)} samples loaded")

    if not loaded:
        samples = \
            extract_segments(cases_of_interest_idx[:cutoff], min_before_event=EXPERIMENT_EVENT_HORIZON, debug=False)
        
        save_pickled_data(picklefile, samples)
        print(f"Saved segment cache to {picklefile}, {len(samples)} samples saved")


In [None]:
# Split data into training, validation, and test sets
# Target a 6:1:3 ratio. The goal is to keep all samples from a single case in the same set;
# note that train_test_split below splits at the sample level and does not enforce this on its own.
# Note: number of samples at each time point is not the same, because the first event can occur before the 3/5/10/15 minute mark

# Set target sizes
train_ratio = 0.6
val_ratio = 0.1
test_ratio = 1 - train_ratio - val_ratio # ensure ratios sum to 1

# Split samples into train and other
samples_train, samples_other = train_test_split(samples, test_size=(1 - train_ratio), random_state=RANDOM_SEED)
# Split other into val and test
samples_val, samples_test = train_test_split(samples_other, test_size=(test_ratio / (1 - train_ratio)), random_state=RANDOM_SEED)

# Check how many samples are in each set
print(f"Train samples: {len(samples_train)}, ({len(samples_train) / len(samples):.2%})")
print(f"Val samples: {len(samples_val)}, ({len(samples_val) / len(samples):.2%})")
print(f"Test samples: {len(samples_test)}, ({len(samples_test) / len(samples):.2%})")
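
The comment at the top of the split cell states the goal of keeping every sample from one case in the same set. A minimal sketch of one way to enforce that is shown here, assuming scikit-learn's `GroupShuffleSplit` and an array of per-sample case identifiers (for example the `caseidx` column). `split_by_case` is a hypothetical helper and is not part of this notebook. Because the ratios apply to cases rather than samples, the sample-level proportions are only approximate.

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

def split_by_case(case_ids, train_ratio=0.6, val_ratio=0.1, seed=42):
    """Return (train_idx, val_idx, test_idx) with no case shared between sets."""
    case_ids = np.asarray(case_ids)
    all_idx = np.arange(len(case_ids))

    # First split: train vs. everything else, grouped by case
    outer = GroupShuffleSplit(n_splits=1, train_size=train_ratio, random_state=seed)
    train_idx, other_idx = next(outer.split(all_idx, groups=case_ids))

    # Second split: divide the remainder into val and test, again grouped by case
    val_share = val_ratio / (1 - train_ratio)
    inner = GroupShuffleSplit(n_splits=1, train_size=val_share, random_state=seed)
    val_rel, test_rel = next(inner.split(other_idx, groups=case_ids[other_idx]))

    return train_idx, other_idx[val_rel], other_idx[test_rel]

# Toy example (assumed data): 20 cases with 5 samples each
toy_case_ids = np.repeat(np.arange(20), 5)
train_idx, val_idx, test_idx = split_by_case(toy_case_ids)
```

With a DataFrame, the returned index arrays could be applied with `samples.iloc[train_idx]` and so on.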

In [None]:
# Create vitalDataset class
class vitalDataset(Dataset):
    def __init__(self, file_dir, samples, track_names, track_srates_hz):
        # samples is the DataFrame returned by extract_segments
        # (columns include segment_abp, segment_ecg, segment_eeg, segment_label)
        self.file_dir = file_dir
        self.samples = samples
        self.track_names = track_names
        self.track_srates_hz = track_srates_hz
        self.vf_dict = {}

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Get metadata for this event
        segment = self.samples.iloc[idx]

        abp = segment['segment_abp']
        ecg = segment['segment_ecg']
        eeg = segment['segment_eeg']
        label = segment['segment_label']

        # all segment data now materialized prior to dataloader instantiation, simply read out the waveform data + label
        #(fullAbp, fullEcg, fullEeg) = get_track_data(caseidx)
        # for i, (track_name, rate) in enumerate(zip(self.track_names, self.track_srates_hz)):
        #     # Convert to tensor and store in samples
        #     start = int((endtime-starttime)*rate)
        #     end = start + int((endtime-starttime)*rate)

        #     if track_name == ABP_TRACK_NAME:
        #         abp = torch.tensor(np.array(fullAbp[start:end]))
        #         #abp = torch.tensor(np.array(fullAbp[0:end-start]))
        #     elif track_name == ECG_TRACK_NAME:
        #         ecg = torch.tensor(np.array(fullEcg[start:end]))
        #         #ecg = torch.tensor(np.array(fullEcg[0:end-start]))
        #     elif track_name == EEG_TRACK_NAME:
        #         eeg = torch.tensor(np.array(fullEeg[start:end]))
        #         #eeg = torch.tensor(np.array(fullEeg[0:end-start]))

        return abp, ecg, eeg, label

In [None]:


sample = samples.iloc[0]
(fullAbp, fullEcg, fullEeg) = get_track_data(sample['caseidx'])

plt.plot(fullAbp.ravel())
plt.plot(fullEcg.ravel())
plt.plot(fullEeg.ravel())
plt.show()
plt.plot(sample['segment_abp'])
plt.plot(sample['segment_ecg'])
plt.plot(sample['segment_eeg'])
plt.show()

In [None]:
train_dataset = vitalDataset(f'{VITALDB_CACHE}/{VITAL_MINI}/', samples_train, TRACK_NAMES, TRACK_SRATES)
val_dataset = vitalDataset(f'{VITALDB_CACHE}/{VITAL_MINI}/', samples_val, TRACK_NAMES, TRACK_SRATES)
test_dataset = vitalDataset(f'{VITALDB_CACHE}/{VITAL_MINI}/', samples_test, TRACK_NAMES, TRACK_SRATES)

BATCH_SIZE = 32
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

##   Model

The model implementation is based on the CNN architecture described in Jo Y-Y et al. (2022). It is designed to handle 1, 2, or 3 signal categories simultaneously, allowing for flexible model configurations based on different combinations of physiological signals:
 * ABP alone
 * EEG alone
 * ECG alone
 * ABP + EEG
 * ABP + ECG
 * EEG + ECG
 * ABP + EEG + ECG

### Model Architecture

The architecture, as depicted in Figure 2 from the original paper, utilizes a ResNet-based approach tailored for time-series data from different physiological signals. The model architecture is adapted to handle varying input signal frequencies, with specific hyperparameters for each signal type, particularly EEG, due to its distinct characteristics compared to ABP and ECG. A diagram of the model architecture is shown below:

![Architecture of the hypotension risk prediction model using multiple waveforms](https://journals.plos.org/plosone/article/figure/image?download&size=large&id=10.1371/journal.pone.0272055.g002)

Each input signal is processed through a sequence of 12 residual blocks of seven layers each, followed by a flattening step and a linear transformation that produces a 32-dimensional feature vector per signal type. These vectors are concatenated (if multiple signals are used) and passed through two additional linear layers to produce a single output value, the IOH index. A threshold, determined experimentally to minimize the difference between sensitivity and specificity, is then applied to this index to perform binary classification of IOH events.
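
The threshold selection described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the original authors' procedure: it scans each distinct predicted score as a candidate threshold and keeps the one with the smallest gap between sensitivity and specificity. The function name `balanced_threshold` and the toy data are hypothetical, and both classes must be present in `y_true`.

```python
import numpy as np

def balanced_threshold(y_true, y_score):
    """Return the candidate threshold with the smallest |sensitivity - specificity|."""
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score, dtype=float)

    best = None
    for thr in np.unique(y_score):
        y_hat = y_score >= thr  # predict an event when the index reaches the threshold
        sensitivity = np.sum(y_hat & y_true) / np.sum(y_true)
        specificity = np.sum(~y_hat & ~y_true) / np.sum(~y_true)
        gap = abs(sensitivity - specificity)
        if best is None or gap < best[0]:
            best = (gap, thr, sensitivity, specificity)

    _, thr, sensitivity, specificity = best
    return float(thr), float(sensitivity), float(specificity)

# Toy data (assumed): four negatives and four positives
labels = [0, 0, 0, 1, 0, 1, 1, 1]
scores = [0.05, 0.20, 0.30, 0.35, 0.60, 0.70, 0.80, 0.90]
threshold, sens, spec = balanced_threshold(labels, scores)
```

In practice the threshold would be chosen on validation predictions and then applied unchanged to the test set.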

The hyperparameters for the residual blocks are specified in Supplemental Table 1 of the original paper and vary by signal type.

A forward pass through the model traverses 85 layers per signal before concatenation (12 residual blocks of 7 layers each, plus one linear layer), followed by two more linear layers and finally a sigmoid activation that produces the prediction measure.

### Residual Block Definition

Each residual block consists of the following seven layers:
 
 * Batch normalization
 * ReLU
 * Dropout (0.5)
 * 1D convolution
 * Batch normalization
 * ReLU
 * 1D convolution

Skip connections are included to aid in gradient flow during training, with optional 1D convolution in the skip connection to align dimensions.

#### Residual Block Hyperparameters

The hyperparameters are detailed in Supplemental Table 1 of the original paper. A screenshot of these hyperparameters is provided for reference below:

![Supplemental Table 1 from original paper](<https://github.com/abarrie2/cs598-dlh-project/blob/main/img/table_1_hyperparameters.png?raw=true>)

**Note**: Please be aware of a transcription error in the original paper's Supplemental Table 1 for the ECG+ABP configuration in Residual Blocks 11 and 12, where the output size should be 469 * 6 instead of the reported 496 * 6.
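
The corrected value can be checked with a short calculation. The sketch below assumes that each of the six size-down blocks halves the sequence length and rounds up, and it uses the input lengths passed to the residual blocks in this notebook (30000 samples for ABP and ECG, 7680 for EEG). The helper name `size_down_lengths` is hypothetical.

```python
import math

def size_down_lengths(input_length, num_size_downs=6):
    """Sequence length after each size-down block, halving and rounding up."""
    lengths = [input_length]
    for _ in range(num_size_downs):
        lengths.append(math.ceil(lengths[-1] / 2))
    return lengths

abp_ecg_lengths = size_down_lengths(30000)  # ABP and ECG input length used in this notebook
eeg_lengths = size_down_lengths(7680)       # EEG input length used in this notebook

print(abp_ecg_lengths)  # [30000, 15000, 7500, 3750, 1875, 938, 469]
print(eeg_lengths)      # [7680, 3840, 1920, 960, 480, 240, 120]
```

Under these assumptions the final ABP/ECG length is 469, so a flattened output of 469 * 6 is consistent with the block sizes, while 496 * 6 is not.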

### Training Objectives

Our model uses binary cross entropy as the loss function and Adam as the optimizer, consistent with the original study. The learning rate is set at 0.0001, and training is configured to run for up to 100 epochs, with early stopping implemented if no improvement in loss is observed over five consecutive epochs.

In [None]:
# First define the residual block which is reused 12x for each data track for each sample.
# Second define the primary model.
class ResidualBlock(nn.Module):
    def __init__(self, in_features: int, out_features: int, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, size_down: bool = False) -> None:
        super(ResidualBlock, self).__init__()
        
        # calculate the appropriate padding required to ensure expected sequence lengths out of each residual block
        padding = int((((stride-1)*in_features)-stride+kernel_size)/2)

        self.size_down = size_down
        self.bn1 = nn.BatchNorm1d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.5)
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, bias=False)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, bias=False)
        
        self.residualConv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, bias=False)

        # unclear where in the sequence this should take place. Size down is expressed in Supplemental Table S1
        if self.size_down:
            pool_padding = (1 if (in_features % 2 > 0) else 0)
            self.downsample = nn.MaxPool1d(kernel_size=2, stride=2, padding = pool_padding)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x
        
        out = self.bn1(x)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.conv1(out)

        if self.size_down:
            out = self.downsample(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)
        
        if out.shape != identity.shape:
            # run the residual through a convolution when necessary
            identity = self.residualConv(identity)
            
            outlen = np.prod(out.shape)
            idlen = np.prod(identity.shape)
            # downsample when required
            if idlen > outlen:
                identity = self.downsample(identity)
            # match dimensions
            identity = identity.reshape(out.shape)
       
        # add the residual       
        out += identity

        return  out

class HypotensionCNN(nn.Module):
    def __init__(self, useAbp: bool = True, useEeg: bool = False, useEcg: bool = False) -> None:
        super(HypotensionCNN, self).__init__()

        self.useAbp = useAbp
        self.useEeg = useEeg
        self.useEcg = useEcg

        if useAbp:
            self.abpBlock1 = ResidualBlock(30000, 15000, 1, 2, 15, 1, True)
            self.abpBlock2 = ResidualBlock(15000, 15000, 2, 2, 15, 1, False)
            self.abpBlock3 = ResidualBlock(15000, 7500, 2, 2, 15, 1, True)
            self.abpBlock4 = ResidualBlock(7500, 7500, 2, 2, 15, 1, False)
            self.abpBlock5 = ResidualBlock(7500, 3750, 2, 2, 15, 1, True)
            self.abpBlock6 = ResidualBlock(3750, 3750, 2, 4, 15, 1, False)
            self.abpBlock7 = ResidualBlock(3750, 1875, 4, 4, 7, 1, True)
            self.abpBlock8 = ResidualBlock(1875, 1875, 4, 4, 7, 1, False)
            self.abpBlock9 = ResidualBlock(1875, 938, 4, 4, 7, 1, True)
            self.abpBlock10 = ResidualBlock(938, 938, 4, 4, 7, 1, False)
            self.abpBlock11 = ResidualBlock(938, 469, 4, 6, 7, 1, True)
            self.abpBlock12 = ResidualBlock(469, 469, 6, 6, 7, 1, False)
            self.abpFc = nn.Linear(6*469, 32)
        
        if useEcg:
            self.ecgBlock1 = ResidualBlock(30000, 15000, 1, 2, 15, 1, True)
            self.ecgBlock2 = ResidualBlock(15000, 15000, 2, 2, 15, 1, False)
            self.ecgBlock3 = ResidualBlock(15000, 7500, 2, 2, 15, 1, True)
            self.ecgBlock4 = ResidualBlock(7500, 7500, 2, 2, 15, 1, False)
            self.ecgBlock5 = ResidualBlock(7500, 3750, 2, 2, 15, 1, True)
            self.ecgBlock6 = ResidualBlock(3750, 3750, 2, 4, 15, 1, False)
            self.ecgBlock7 = ResidualBlock(3750, 1875, 4, 4, 7, 1, True)
            self.ecgBlock8 = ResidualBlock(1875, 1875, 4, 4, 7, 1, False)
            self.ecgBlock9 = ResidualBlock(1875, 938, 4, 4, 7, 1, True)
            self.ecgBlock10 = ResidualBlock(938, 938, 4, 4, 7, 1, False)
            self.ecgBlock11 = ResidualBlock(938, 469, 4, 6, 7, 1, True)
            self.ecgBlock12 = ResidualBlock(469, 469, 6, 6, 7, 1, False)
            self.ecgFc = nn.Linear(6 * 469, 32)
        
        if useEeg:
            self.eegBlock1 = ResidualBlock(7680, 3840, 1, 2, 7, 1, True)
            self.eegBlock2 = ResidualBlock(3840, 3840, 2, 2, 7, 1, False)
            self.eegBlock3 = ResidualBlock(3840, 1920, 2, 2, 7, 1, True)
            self.eegBlock4 = ResidualBlock(1920, 1920, 2, 2, 7, 1, False)
            self.eegBlock5 = ResidualBlock(1920, 960, 2, 2, 7, 1, True)
            self.eegBlock6 = ResidualBlock(960, 960, 2, 4, 7, 1, False)
            self.eegBlock7 = ResidualBlock(960, 480, 4, 4, 3, 1, True)
            self.eegBlock8 = ResidualBlock(480, 480, 4, 4, 3, 1, False)
            self.eegBlock9 = ResidualBlock(480, 240, 4, 4, 3, 1, True)
            self.eegBlock10 = ResidualBlock(240, 240, 4, 4, 3, 1, False)
            self.eegBlock11 = ResidualBlock(240, 120, 4, 6, 3, 1, True)
            self.eegBlock12 = ResidualBlock(120, 120, 6, 6, 3, 1, False)
            self.eegFc = nn.Linear(6 * 120, 32)

        concatSize = 0
        if useAbp:
            concatSize += 32
        if useEeg:
            concatSize += 32
        if useEcg:
            concatSize += 32

        self.fullLinear1 = nn.Linear(concatSize, 16)
        self.fullLinear2 = nn.Linear(16, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, abp: torch.Tensor, eeg: torch.Tensor, ecg: torch.Tensor) -> torch.Tensor:

        batchSize = len(abp)

        # conditionally operate ABP, EEG, and ECG networks
        if self.useAbp:
            abp = self.abpBlock1(abp)
            abp = self.abpBlock2(abp)
            abp = self.abpBlock3(abp)
            abp = self.abpBlock4(abp)
            abp = self.abpBlock5(abp)
            abp = self.abpBlock6(abp)
            abp = self.abpBlock7(abp)
            abp = self.abpBlock8(abp)
            abp = self.abpBlock9(abp)
            abp = self.abpBlock10(abp)
            abp = self.abpBlock11(abp)
            abp = self.abpBlock12(abp)
            totalLen = np.prod(abp.shape)
            abp = torch.reshape(abp, (batchSize, int(totalLen / batchSize)))
            abp = self.abpFc(abp)

        if self.useEeg:
            eeg = self.eegBlock1(eeg)
            eeg = self.eegBlock2(eeg)
            eeg = self.eegBlock3(eeg)
            eeg = self.eegBlock4(eeg)
            eeg = self.eegBlock5(eeg)
            eeg = self.eegBlock6(eeg)
            eeg = self.eegBlock7(eeg)
            eeg = self.eegBlock8(eeg)
            eeg = self.eegBlock9(eeg)
            eeg = self.eegBlock10(eeg)
            eeg = self.eegBlock11(eeg)
            eeg = self.eegBlock12(eeg)
            totalLen = np.prod(eeg.shape)
            eeg = torch.reshape(eeg, (batchSize, int(totalLen / batchSize)))
            eeg = self.eegFc(eeg)
        
        if self.useEcg:
            ecg = self.ecgBlock1(ecg)
            ecg = self.ecgBlock2(ecg)
            ecg = self.ecgBlock3(ecg)
            ecg = self.ecgBlock4(ecg)
            ecg = self.ecgBlock5(ecg)
            ecg = self.ecgBlock6(ecg)
            ecg = self.ecgBlock7(ecg)
            ecg = self.ecgBlock8(ecg)
            ecg = self.ecgBlock9(ecg)
            ecg = self.ecgBlock10(ecg)
            ecg = self.ecgBlock11(ecg)
            ecg = self.ecgBlock12(ecg)
            #ecg = torch.flatten(ecg)
            totalLen = np.prod(ecg.shape)
            ecg = torch.reshape(ecg, (batchSize, int(totalLen / batchSize)))
            ecg = self.ecgFc(ecg)
        
        # concatenation
        merged = None
        if self.useAbp and self.useEeg and self.useEcg:
            merged = torch.cat((abp, eeg, ecg), dim=1)
        elif self.useAbp and self.useEeg:
            merged = torch.cat((abp, eeg), dim=1)
        elif self.useAbp and self.useEcg:
            merged = torch.cat((abp, ecg), dim=1)
        elif self.useEeg and self.useEcg:
            merged = torch.cat((eeg, ecg), dim=1)
        elif self.useAbp:
            merged = abp
        elif self.useEeg:
            merged = eeg
        elif self.useEcg:
            merged = ecg

        totalLen = np.prod(merged.shape)
        merged = torch.reshape(merged, (batchSize, int(totalLen / batchSize)))
        out = self.fullLinear1(merged)
        out = self.fullLinear2(out)
        out = self.sigmoid(out)

        out = torch.nan_to_num(out)
        return out

### Training

As discussed earlier, our model uses binary cross entropy as the loss function and Adam as the optimizer, consistent with the original study. The learning rate is set at 0.0001, and training is configured to run for up to 100 epochs, with early stopping implemented if no improvement in loss is observed over five consecutive epochs.
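
The early stopping rule can be isolated into a small helper. The following is a minimal sketch with a hypothetical `EarlyStopping` class and a made-up loss curve. Which loss is monitored is a design choice: passing the validation loss instead of the training loss is a common alternative.

```python
class EarlyStopping:
    """Track a loss value and signal when it has stopped improving."""

    def __init__(self, patience=5):
        self.patience = patience
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def step(self, loss):
        """Record one epoch's loss; return True when training should stop."""
        if loss < self.best_loss:
            self.best_loss = loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience

# Toy loss curve (assumed): improves for three epochs, then plateaus
stopper = EarlyStopping(patience=5)
toy_losses = [0.70, 0.60, 0.55, 0.56, 0.55, 0.57, 0.58, 0.56, 0.60]
stopped_at = None
for epoch, loss in enumerate(toy_losses):
    if stopper.step(loss):
        stopped_at = epoch
        break
```

Here the best loss is reached at epoch 2, and training stops at epoch 7 after five consecutive epochs without a strictly lower loss.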

In [None]:
experimentName = "ABP_EEG_DEFAULT"  # name matches the signals enabled below
useAbp = True
useEeg = True
useEcg = False

model = HypotensionCNN(useAbp, useEeg, useEcg)
loss_func = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

device = torch.device("cuda" if torch.cuda.is_available() else "mps" if (torch.backends.mps.is_available() and torch.backends.mps.is_built()) else "cpu")
print(f"Using device: {device}")
model = model.to(device)

def train_model_one_iter(model, loss_func, optimizer, dataloader):
  curr_epoch_loss = []
  for abp, ecg, eeg, label in tqdm(dataloader):
    batch = len(abp)
    
    abpSampleCount = int(np.prod(abp.shape)/batch)
    ecgSampleCount = int(np.prod(ecg.shape)/batch)
    eegSampleCount = int(np.prod(eeg.shape)/batch)

    abp = torch.nan_to_num(abp.reshape(batch, 1, abpSampleCount)).type(torch.FloatTensor)
    ecg = torch.nan_to_num(ecg.reshape(batch, 1, ecgSampleCount)).type(torch.FloatTensor)
    eeg = torch.nan_to_num(eeg.reshape(batch, 1, eegSampleCount)).type(torch.FloatTensor)
    label = label.type(torch.float).reshape(batch, 1)
   
    abp = abp.to(device)
    eeg = eeg.to(device)
    ecg = ecg.to(device)
    label = label.to(device)

    optimizer.zero_grad()
    mdl = model(abp, eeg, ecg)
    loss = loss_func(torch.nan_to_num(mdl), label)
    loss.backward()
    optimizer.step()
    curr_epoch_loss.append(loss.cpu().data.numpy())
  return np.mean(curr_epoch_loss)



num_epoch = 100

# model training loop: it is better to print the training/validation losses during the training
model.train(True)
losses = []
best_loss = float('inf')
patience = 5
no_improve_epochs = 0
for i in range(num_epoch):
  train_loss = train_model_one_iter(model, loss_func, optimizer, train_loader)
  losses.append(train_loss)
  print(f"[{datetime.now()}] Completed epoch {i} with train loss {train_loss}")

  # check if loss has improved
  if train_loss < best_loss:
    best_loss = train_loss
    no_improve_epochs = 0
  else:
    no_improve_epochs += 1

  # exit early if no improvement in loss over last 'patience' epochs
  if no_improve_epochs >= patience:
    print("Exiting early due to stable training loss")
    break

model.train(False)
torch.save(model.state_dict(), f"{VITAL_MODELS}/{experimentName}.model")

In [None]:
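def eval_model(model, dataloader):
    model.eval()
    model = model.to(device)
    total_loss = 0
    y_pred = []
    y_true = []

    with torch.no_grad():
        for abp, ecg, eeg, label in tqdm(dataloader):
            batch = len(abp)

            # apply the same preprocessing as train_model_one_iter:
            # replace NaNs, cast to float32, and reshape to (batch, channels=1, samples)
            abp = torch.nan_to_num(abp.reshape(batch, 1, -1)).float().to(device)
            ecg = torch.nan_to_num(ecg.reshape(batch, 1, -1)).float().to(device)
            eeg = torch.nan_to_num(eeg.reshape(batch, 1, -1)).float().to(device)
            label = label.float().reshape(batch, 1).to(device)

            output = model(abp, eeg, ecg)
            loss = loss_func(output, label)
            total_loss += loss.item()

            # keep predictions and labels on the CPU for later metric computation
            y_pred.append(output.detach().cpu())
            y_true.append(label.detach().cpu())

    avg_loss = total_loss / len(dataloader)
    return y_pred, y_true, avg_loss

# validation loop (eval_model returns predictions, labels, and the average loss)
val_pred, val_true, valid_loss = eval_model(model, val_loader)

# test loop
test_pred, test_true, test_loss = eval_model(model, test_loader)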

# Results (Planned results for Draft submission)

When we complete our experiments, we will build tables that compare a set of measures across the experiments performed. The full set of experiments and measures is listed below.

## Experiments

 * ABP only
 * ECG only
 * EEG only
 * ABP + ECG
 * ABP + EEG
 * ECG + EEG
 * ABP + ECG + EEG

Note: each experiment will be repeated with the following time-to-IOH-event durations:
 * 3 minutes
 * 5 minutes
 * 10 minutes
 * 15 minutes

Note: the full list of experiments above will be performed if there is sufficient time and GPU capacity to complete it before the submission deadline. Should we be constrained on either front, we will reduce our experimental coverage to the following 4 core experiments, which are necessary to evaluate the hypotheses stated at the head of this report:
 * ABP only @ 3 minutes
 * ABP + ECG @ 3 minutes
 * ABP + EEG @ 3 minutes
 * ABP + ECG + EEG @ 3 minutes
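
The experiment grid above can be enumerated with a few lines of Python. This is only an illustrative sketch; the variable names are hypothetical and nothing here is used elsewhere in the notebook.

```python
from itertools import combinations, product

signals = ["ABP", "ECG", "EEG"]
horizons_min = [3, 5, 10, 15]

# Every non-empty combination of the three signals (7 in total)
signal_sets = [combo
               for r in range(1, len(signals) + 1)
               for combo in combinations(signals, r)]

# Full grid: each signal combination at each prediction horizon
experiments = [("+".join(combo), horizon)
               for combo, horizon in product(signal_sets, horizons_min)]

# Core subset: combinations that include ABP, at the 3 minute horizon only
core = [(name, horizon) for name, horizon in experiments
        if name.startswith("ABP") and horizon == 3]

print(len(signal_sets), len(experiments), len(core))  # 7 28 4
```

The full grid has 28 entries, and the core subset contains the 4 experiments listed above.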

For additional details please review the "Planned Actions" in the Discussion section of this report.

## Measures

 * AUROC
 * AUPRC
 * Sensitivity
 * Specificity
 * Threshold
 * Loss Shrinkage

[ TODO for final report - collect data for all measures listed above. ]

[ TODO for final report - generate ROC and PRC plots for each experiment ]

We are collecting a broad set of measures for each experiment so that we can compare them comprehensively against all comparable experiments in the original paper. However, our key experimental results will focus on the subset that addresses the main experiments defined at the beginning of this notebook.

The key experimental result measures will be as follows:

* For 3 minutes ahead of the predicted IOH event:
  * compare AUROC and AUPRC for ABP only vs ABP+ECG
  * compare AUROC and AUPRC for ABP only vs ABP+EEG
  * compare AUROC and AUPRC for ABP only vs ABP+ECG+EEG


In [None]:
# calculate AUROC, AUPRC, sensitivity, specificity, threshold
def getMeasures(model):
    auroc = None
    auprc = None
    sensitivity = None
    specificity = None
    threshold = None
    loss_shrinkage = None
    
    return (auroc, auprc, sensitivity, specificity, threshold, loss_shrinkage)

abp3 = getMeasures("abp 3 minute")
abpEcg3 = getMeasures("abp+Ecg 3 minute")
abpEeg3 = getMeasures("abp+Eeg 3 minute")
abpEcgEeg3 = getMeasures("abp+Ecg+Eeg 3 minute")


# TODO for final report - generate plots


## Model comparison

The following table is Table 3 from the original paper which presents the measured values for each signal combination across each of the four temporal predictive categories:

![Area under the Receiver-operating Characteristic Curve, Area under the Precision-Recall Curve, Sensitivity, and Specificity of the model in predicting intraoperative hypotension](https://journals.plos.org/plosone/article/figure/image?download&size=large&id=10.1371/journal.pone.0272055.t003)

We have not yet completed the experiments needed to measure our reproduced model's performance and determine whether our results match those of the original paper. These details are expected to be included in the final report.

# Discussion

### Feasibility of reproduction
Our assessment is that this paper will be reproducible. The outstanding risk is that each experiment can take up to 7 hours to run on hardware available within the team (i.e., 7h to run ~70 epochs on a desktop with an AMD Ryzen 7 3800X 8-core CPU, an RTX 2070 SUPER GPU, and 32GB RAM). There are a total of 28 experiments (7 combinations of signal inputs, each at 4 time horizons). Should our team be unable to complete all of the experiments represented in Table 3 of our selected paper, we will reduce the number of experiments to focus solely on those directly related to the hypotheses described at the beginning of this notebook (i.e., reduce the number of signal combinations to 4: ABP alone, ABP+EEG, ABP+ECG, ABP+ECG+EEG). This will result in a new total of 16 experiments to run.

### Planned ablations
Our proposal included a collection of potential ablations to be investigated:

* Remove ResNet skip connection
* Reduce # of residual blocks from 12 to 6
* Reduce # of residual blocks from 12 to 1
* Eliminate dropout from residual block
* Max pooling configuration
  * smaller size/stride
  * eliminate max pooling

Given the amount of time required to conduct each experiment, our team intends to choose only a small number of ablations from this set. Further, we only intend to perform ablation analysis against the best performing signal combination and time horizon from the reproduction experiments. In other words, we intend to perform ablation analysis against the following training combinations, and only against the models trained with data measured 3 minutes prior to an IOH event:
  * ABP alone
  * ABP + ECG
  * ABP + EEG
  * ABP + ECG + EEG

Time and GPU resource permitting, we will complete a broader range of experiments. For additional details, please see the section below titled "Plans for next phase".

### Nature of reproduced results
Our team intends to address how our experimental results align with the results published in the paper in the final submission of this report. The time available while preparing the Draft notebook was not sufficient to complete model training and result analysis for a large number of experiments.

### What was easy? What was difficult?
The difficult aspect of the preparation of this draft involved the data preprocessing.
 * First, the source data is unlabelled, so our team was responsible for implementing analysis methods to identify positive (IOH event occurred) and negative (IOH event did not occur) samples by running a lookahead analysis of our input training set.
 * Second, the volume of raw data is in excess of 90GB. A non-trivial amount of compute was required to minify the input data to include only the data tracks of interest to our experiments (i.e., ABP, ECG, and EEG tracks).
 * Third, our team found it difficult to trace back to the definition of the jSQI signal quality index referenced in the paper. Multiple references through multiple papers needed to be traversed to understand which variant of the quality index was used.
   * The only available source code related to the signal quality index was found via the work referenced by our paper in [5]. The source code was not directly linked from the paper, but the GitHub repository of the corresponding author of reference [5] led us to MATLAB source code for the signal quality index described in that reference. That code is available here: https://github.com/cliffordlab/PhysioNet-Cardiovascular-Signal-Toolbox/tree/master/Tools/BP_Tools
   * Our team had insufficient time to port this signal quality index to Python for use in our investigation, or to set up a MATLAB environment in which to assess our source data using the above MATLAB functions, but we expect to complete this as part of our final report.
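
The lookahead labelling mentioned in the first point can be illustrated with a simplified sketch. It assumes a NaN-free ABP label window and a moving average computed with `np.convolve`; the extraction code in this notebook uses a cumulative-sum formulation that also tolerates NaNs. The function name `is_ioh_event` and the 10 Hz toy signals are hypothetical.

```python
import numpy as np

def is_ioh_event(label_window, srate_hz, window_sec=2, threshold_mmhg=65.0):
    """True when the moving average of the window never reaches the threshold."""
    n = int(window_sec * srate_hz)
    kernel = np.ones(n) / n
    # mode="valid" keeps only positions where the averaging window fits entirely
    smoothed = np.convolve(np.asarray(label_window, dtype=float), kernel, mode="valid")
    return bool(np.max(smoothed) < threshold_mmhg)

srate = 10  # toy sample rate (assumed), not the rate used in this notebook
low = np.full(60 * srate, 60.0)            # pressure stays at 60 mmHg for one minute
high = np.full(60 * srate, 80.0)           # pressure stays at 80 mmHg
recovered = low.copy()
recovered[300:340] = 90.0                  # four seconds well above the threshold
spike = low.copy()
spike[300] = 100.0                         # single-sample artifact, averaged away

events = [is_ioh_event(x, srate) for x in (low, high, recovered, spike)]
```

The moving average is what prevents a single noisy sample from cancelling an otherwise sustained hypotensive window, while a sustained recovery does cancel it.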

### Suggestions to paper author
The most notable suggestion would be to correct the hyperparameters published in Supplemental Table 1. Specifically, the output size for residual blocks 11 and 12 for the ECG and ABP data sets is listed as 496x6. This is a typo and should read 469x6. The typo became apparent when applying the size-down operation within Residual Block 11 and finding that the tensor dimensions were misaligned.

Additionally, more explicit references to the signal quality index assessment tools should be added. Our team could not find a reference to the MATLAB source code as described in reference [3], and had to manually discover the GitHub profile for the lab of the corresponding author of reference [3] in order to find MATLAB source that corresponded to the metrics described therein.

### Plans for next phase
Our team plans to accomplish the following goals in service of preparing the Final Report:
 * Implement the jSQI filter to remove any training data with aberrant signal quality per the threshold defined in the original paper.
 * Execute the following experiments:
   * Measure predictive quality of the model trained solely with ABP data at 3 minutes prior to IOH events.
   * Measure predictive quality of the model trained with ABP+ECG data at 3 minutes prior to IOH events.
   * Measure predictive quality of the model trained with ABP+EEG data at 3 minutes prior to IOH events.
   * Measure predictive quality of the model trained with ABP+ECG+EEG data at 3 minutes prior to IOH events.
 * Gather our measures for these experiments, compare them against the published results from our selected paper, and determine whether we are successfully reproducing the results outlined in the paper.
 * Ablation analysis:
   * Execute the following ablation experiments:
     * Repeat the four experiments described above while reducing the number of residual blocks in the model from 12 to 6.
 * Time- and/or GPU-resource permitting, we will complete the remaining 24 experiments as described in the paper:
   * Measure predictive quality of the model trained solely with ABP data at 5, 10, and 15 minutes prior to IOH events.
   * Measure predictive quality of the model trained with ABP+ECG data at 5, 10, and 15 minutes prior to IOH events.
   * Measure predictive quality of the model trained with ABP+EEG data at 5, 10, and 15 minutes prior to IOH events.
   * Measure predictive quality of the model trained with ABP+ECG+EEG data at 5, 10, and 15 minutes prior to IOH events.
   * Measure predictive quality of the model trained solely with ECG data at 3, 5, 10, and 15 minutes prior to IOH events.
   * Measure predictive quality of the model trained solely with EEG data at 3, 5, 10, and 15 minutes prior to IOH events.
   * Measure predictive quality of the model trained with ECG+EEG data at 3, 5, 10, and 15 minutes prior to IOH events.
   * Additional ablation experiments:
     * For the four core experiments (ABP, ABP+ECG, ABP+EEG, ABP+ECG+EEG each trained on event data occurring 3 minutes prior to IOH events), perform the following ablations:
       * Repeat experiment while eliminating dropout from every residual block
       * Repeat experiment while removing the skip connection from every residual block
       * Repeat experiment while reducing the number of residual blocks in the model from 12 to 1.
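
The experiment plan above amounts to a grid of every non-empty combination of the three signals crossed with the four lead times. A minimal sketch that enumerates this grid, and checks that the four core experiments plus the 24 remaining ones cover it, might look like the following; the variable names are illustrative and are not taken from our training code.

```python
from itertools import combinations, product

SIGNALS = ("ABP", "ECG", "EEG")
LEAD_TIMES_MIN = (3, 5, 10, 15)

# All non-empty signal combinations: 3 singles, 3 pairs, 1 triple.
signal_sets = [
    combo
    for size in range(1, len(SIGNALS) + 1)
    for combo in combinations(SIGNALS, size)
]

# Full grid: every signal combination at every lead time.
experiments = list(product(signal_sets, LEAD_TIMES_MIN))

# Core experiments: combinations that include ABP, at 3 minutes before IOH.
core = [(signals, minutes) for signals, minutes in experiments
        if "ABP" in signals and minutes == 3]
remaining = [exp for exp in experiments if exp not in core]

print(len(experiments), len(core), len(remaining))  # 28 4 24
for signals, minutes in core:
    print("+".join(signals), f"at {minutes} minutes prior to IOH")
```

Generating the grid programmatically makes it easy to confirm that no signal combination or lead time is skipped when the experiments are scheduled.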

# References

1. Jo Y-Y, Jang J-H, Kwon J-m, Lee H-C, Jung C-W, Byun S, et al. “Predicting intraoperative hypotension using deep learning with waveforms of arterial blood pressure, electroencephalogram, and electrocardiogram: Retrospective study.” PLoS ONE, (2022) 17(8): e0272055 https://doi.org/10.1371/journal.pone.0272055
2. Hatib F, Jian Z, Buddi S, Lee C, Settels J, Sibert K, Rinehart J, Cannesson M. “Machine-learning Algorithm to Predict Hypotension Based on High-fidelity Arterial Pressure Waveform Analysis.” Anesthesiology (2018) 129:4 https://doi.org/10.1097/ALN.0000000000002300
3. Bao, X., Kumar, S.S., Shah, N.J. et al. "Acumen™ hypotension prediction index guidance for prevention and treatment of hypotension in noncardiac surgery: a prospective, single-arm, multicenter trial." Perioperative Medicine (2024) 13:13 https://doi.org/10.1186/s13741-024-00369-9
4. Lee, HC., Park, Y., Yoon, S.B. et al. VitalDB, a high-fidelity multi-parameter vital signs database in surgical patients. Sci Data 9, 279 (2022). https://doi.org/10.1038/s41597-022-01411-5
5. Li Q., Mark R.G. & Clifford G.D. "Artificial arterial blood pressure artifact models and an evaluation of a robust blood pressure and heart rate estimator." BioMed Eng OnLine. (2009) 8:13. pmid:19586547 https://doi.org/10.1186/1475-925X-8-13
6. Park H-J, "VitalDB Python Example Notebooks" GitHub Repository https://github.com/vitaldb/examples/blob/master/hypotension_art.ipynb