In [4]:
import os
import sys
sys.path.append('..')
import tqdm
import json
import torch
import pandas as pd
from collections import defaultdict
from sklearn.linear_model import LogisticRegression
from libs.helper import load_eer_thresholds
from libs.dataloader import MultiModalDataLoader
from libs.model import ECG_Inception, PPG_Inception, EDA_LSTM

In [5]:
# Run configuration: the subject to evaluate and the version tag, which selects
# the metadata entry, the checkpoint directory and the threshold index below.
subject_id, version = "c1s01", "v1"

# Input signals are read from data_dir; checkpoints are read from out_dir.
data_dir, out_dir = "../data", "../results"

In [6]:
# Every input file for this subject lives in the same directory.
subject_dir = os.path.join(data_dir, subject_id)
ecg_data_path, ppg_data_path, eda_data_path, temp_data_path = (
    os.path.join(subject_dir, file_name)
    for file_name in ('ecg.pkl', 'ppg.pkl', 'eda.pkl', 'temp.pkl')
)
metadata_path = os.path.join(subject_dir, 'metadata.json')

# NOTE(review): read_pickle unpickles arbitrary objects, so only point this at
# trusted, locally produced files.
ecg_df, ppg_df, eda_df, temp_df = map(
    pd.read_pickle,
    (ecg_data_path, ppg_data_path, eda_data_path, temp_data_path),
)

# The metadata is later indexed by `version` when the data loader is built.
with open(metadata_path, 'r') as metadata_file:
    metadata = json.load(metadata_file)

In [7]:
# One classifier per modality; each is restored from the "best.pth" checkpoint
# stored under <out_dir>/<modality>/<subject_id>/<version>/.
ecg_model = ECG_Inception()
ppg_model = PPG_Inception()
eda_model = EDA_LSTM()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

modality_models = {'ecg': ecg_model, 'ppg': ppg_model, 'eda': eda_model}
for data_type, modality_model in modality_models.items():
    out_ver_dir = os.path.join(out_dir, data_type, subject_id, version)
    ckpt_path = os.path.join(out_ver_dir, "best.pth")
    assert os.path.exists(ckpt_path), f"Checkpoint not found at {ckpt_path}"

    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint written from a CUDA session still loads on a CPU-only machine.
    modality_model.load_state_dict(torch.load(ckpt_path, map_location=device))
    # Inference only: switch to eval mode before moving the model to the device.
    modality_model.eval()
    modality_model.to(device)

In [8]:
# The metadata entry for the selected version is passed alongside the four
# signal frames (presumably it describes this version's data split -- confirm
# in libs.dataloader).
version_metadata = metadata[version]
loader = MultiModalDataLoader(
    ecg_df,
    ppg_df,
    eda_df,
    temp_df,
    version_metadata,
    verbose=True,
)

Temperature mean: 34.40, std: 0.93


In [9]:
# Both splits are iterated one item at a time and in a fixed order: the cells
# below call squeeze(0) on the inputs and .item() on the label, which relies on
# batch_size=1.
train_loader, val_loader = (
    loader.get_loader(split_name, batch_size=1, shuffle=False)
    for split_name in ("train", "val")
)

In [10]:
# Decision thresholds for this subject, keyed by modality ('ecg', 'ppg', 'eda').
# Each value is a list that the majority-vote cell indexes with the numeric part
# of `version` minus one. "EER" presumably stands for equal error rate -- confirm
# in libs.helper.
eer_thresholds = load_eer_thresholds(subject_id, out_dir)
print(eer_thresholds)

{'ecg': [0.53606116771698, 0.3416595757007599, 0.4535430073738098, 0.5926393866539001, 0.3898150622844696], 'ppg': [0.5296138525009155, 0.4670945107936859, 0.48016107082366943, 0.4767318665981293, 0.5148026347160339], 'eda': [0.46475496888160706, 0.5323653817176819, 0.5472270250320435, 0.5018326640129089, 0.5097512602806091]}


## Evaluate the accuracy

### Majority Vote

In [11]:
def _majority_vote(segment_scores, threshold):
    """Binarize scores against ``threshold`` and majority-vote along axis 0.

    A score votes positive when it is strictly greater than ``threshold``.
    The result is 1 only when the number of positive votes strictly exceeds
    ``n // 2`` (``n`` being the length of axis 0), so a tie resolves to 0.
    """
    votes = segment_scores > threshold
    return (votes.sum(axis=0) > votes.shape[0] // 2).astype(int)


# Thresholds are stored one per version ("v1" -> index 0). They do not change
# between samples, so look them up once instead of on every iteration.
threshold_idx = int(version.split('v')[-1]) - 1
ecg_threshold = eer_thresholds['ecg'][threshold_idx]
ppg_threshold = eer_thresholds['ppg'][threshold_idx]
eda_threshold = eer_thresholds['eda'][threshold_idx]

results = defaultdict(list)
with torch.no_grad():
    # temp_data, glucose and CGM_idx are part of each item but unused by the vote.
    for ecg_data, ppg_data, eda_data, temp_data, hypo_label, (glucose, CGM_idx) in tqdm.tqdm(val_loader):
        # Drop the batch axis added by the batch_size=1 loader.
        ecg_data = ecg_data.float().to(device).squeeze(0)
        ppg_data = ppg_data.float().to(device).squeeze(0)
        eda_data = eda_data.float().to(device).squeeze(0)

        ecg_output = ecg_model(ecg_data)
        ppg_output = ppg_model(ppg_data)
        eda_output = eda_model(eda_data[0].unsqueeze(0), eda_data[1].unsqueeze(0))

        # Restore a leading axis to vote over when only one segment was fed in
        # (presumably the models squeeze their output in that case -- confirm).
        if ecg_data.shape[0] == 1:
            ecg_output = ecg_output.unsqueeze(0)
        if ppg_data.shape[0] == 1:
            ppg_output = ppg_output.unsqueeze(0)
        eda_output = eda_output.unsqueeze(0)

        results['ecg'].append(_majority_vote(ecg_output.cpu().numpy(), ecg_threshold))
        results['ppg'].append(_majority_vote(ppg_output.cpu().numpy(), ppg_threshold))
        results['eda'].append(_majority_vote(eda_output.cpu().numpy(), eda_threshold))
        # .item() reads the scalar label directly; no device transfer is needed.
        results['label'].append(hypo_label.item())


  return F.conv1d(input, weight, bias, self.stride,
100%|██████████| 396/396 [00:19<00:00, 20.81it/s]


In [12]:
# Per-modality accuracy of the majority-vote decisions: the fraction of
# validation samples whose vote equals the label.
results = pd.DataFrame(results)
true_labels = results['label']

ecg_acc, ppg_acc, eda_acc = (
    (results[modality] == true_labels).mean()
    for modality in ('ecg', 'ppg', 'eda')
)

print(f"ECG Accuracy: {ecg_acc:.2f}")
print(f"PPG Accuracy: {ppg_acc:.2f}")
print(f"EDA Accuracy: {eda_acc:.2f}")

ECG Accuracy: 0.80
PPG Accuracy: 0.56
EDA Accuracy: 0.67


### Logistic Regression

In [13]:
# Build the training table for the fusion model: one row per training item,
# holding the mean output of each modality model, the mean of the temperature
# input, and the label.
train_feature_columns = defaultdict(list)
with torch.no_grad():
    for ecg_data, ppg_data, eda_data, temp_data, hypo_label, (glucose, CGM_idx) in tqdm.tqdm(train_loader):
        # Drop the batch axis added by the batch_size=1 loader.
        ecg_data = ecg_data.float().to(device).squeeze(0)
        ppg_data = ppg_data.float().to(device).squeeze(0)
        eda_data = eda_data.float().to(device).squeeze(0)

        ecg_output = ecg_model(ecg_data)
        ppg_output = ppg_model(ppg_data)
        eda_output = eda_model(eda_data[0].unsqueeze(0), eda_data[1].unsqueeze(0))

        # Each output is reduced with a mean over all of its elements, so no
        # reshaping of the model outputs is needed beforehand.
        train_feature_columns['ecg'].append(ecg_output.cpu().numpy().mean())
        train_feature_columns['ppg'].append(ppg_output.cpu().numpy().mean())
        train_feature_columns['eda'].append(eda_output.cpu().numpy().mean())
        # Temperature never passes through a model, so it is averaged without
        # being moved to the device first.
        train_feature_columns['temp'].append(temp_data.float().cpu().numpy().mean())
        train_feature_columns['label'].append(hypo_label.item())

train_cgm_df = pd.DataFrame(train_feature_columns)

100%|██████████| 1588/1588 [01:09<00:00, 22.90it/s]


In [14]:
# Fusion model: a logistic regression over the four per-item features, with
# class_weight='balanced' so both label classes carry equal total weight.
fusion_feature_columns = ['ecg', 'ppg', 'eda', 'temp']
MoE_model = LogisticRegression(class_weight='balanced')

# Fit on the training table; the fitted estimator is the cell's displayed value.
MoE_model.fit(train_cgm_df[fusion_feature_columns], train_cgm_df['label'])

In [15]:
# Build the validation table for the fusion model: one row per validation item,
# holding the mean output of each modality model, the mean of the temperature
# input, and the label.
val_feature_columns = defaultdict(list)
with torch.no_grad():
    for ecg_data, ppg_data, eda_data, temp_data, hypo_label, (glucose, CGM_idx) in tqdm.tqdm(val_loader):
        # Drop the batch axis added by the batch_size=1 loader.
        ecg_data = ecg_data.float().to(device).squeeze(0)
        ppg_data = ppg_data.float().to(device).squeeze(0)
        eda_data = eda_data.float().to(device).squeeze(0)

        ecg_output = ecg_model(ecg_data)
        ppg_output = ppg_model(ppg_data)
        eda_output = eda_model(eda_data[0].unsqueeze(0), eda_data[1].unsqueeze(0))

        # Each output is reduced with a mean over all of its elements, so no
        # reshaping of the model outputs is needed beforehand.
        val_feature_columns['ecg'].append(ecg_output.cpu().numpy().mean())
        val_feature_columns['ppg'].append(ppg_output.cpu().numpy().mean())
        val_feature_columns['eda'].append(eda_output.cpu().numpy().mean())
        # Temperature never passes through a model, so it is averaged without
        # being moved to the device first.
        val_feature_columns['temp'].append(temp_data.float().cpu().numpy().mean())
        val_feature_columns['label'].append(hypo_label.item())

val_cgm_df = pd.DataFrame(val_feature_columns)

100%|██████████| 396/396 [00:16<00:00, 23.78it/s]


In [16]:
# Score the validation table with the fitted fusion model, using the same four
# feature columns it was trained on.
val_features = val_cgm_df[['ecg', 'ppg', 'eda', 'temp']]
y_true = val_cgm_df['label'].values
y_pred = MoE_model.predict(val_features)

In [17]:
# Accuracy of the fusion model: the fraction of validation items whose
# predicted label equals the true label.
is_correct = y_pred == y_true
MoE_acc = is_correct.mean()
print(f"MoE Accuracy: {MoE_acc:.2f}")

MoE Accuracy: 0.78
