In [6]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import pywt
import os
from scipy.signal import spectrogram
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, RobustScaler
from torch.utils.data import DataLoader, Dataset
from torchvision.models import resnet50
from tqdm.notebook import tqdm
import io
import zipfile
from google.oauth2 import service_account
from googleapiclient.discovery import build
from googleapiclient.http import MediaIoBaseDownload
import h5py


%matplotlib inline

In [2]:
def download_dataset(repo_location, file_id='1uMejT3pEJVFKM20bzpsbS35NJSuG85DL'):
    """Download a zip archive from Google Drive and extract it under the repo.

    The archive is fetched with a service-account credential read from
    ``<repo_location>/ppg-ml-427113-ecc7e656b7f5.json`` (read-only Drive
    scope), written to a temporary ``MIMIC CSVs.zip`` in the current working
    directory, extracted into ``<repo_location>/CSVs`` and then deleted.

    Args:
        repo_location: Directory holding the service-account key file; the
            ``CSVs`` folder is created beneath it.
        file_id: Google Drive file id of the zip archive. Defaults to the id
            this notebook has always used.

    Raises:
        Whatever the Drive client, ``zipfile`` or the filesystem raise. The
        temporary zip is removed even when one of those steps fails.
    """
    # Path to the service account key file
    SERVICE_ACCOUNT_FILE = os.path.join(repo_location, 'ppg-ml-427113-ecc7e656b7f5.json')

    # Define the required scopes
    SCOPES = ['https://www.googleapis.com/auth/drive.readonly']

    # Authenticate using the service account
    credentials = service_account.Credentials.from_service_account_file(
        SERVICE_ACCOUNT_FILE, scopes=SCOPES)

    # Build the Drive API client
    service = build('drive', 'v3', credentials=credentials)

    # Request to download the .zip file
    request = service.files().get_media(fileId=file_id)
    zip_file_path = 'MIMIC CSVs.zip'

    try:
        with io.FileIO(zip_file_path, 'wb') as fh:
            downloader = MediaIoBaseDownload(fh, request)

            done = False
            while not done:
                status, done = downloader.next_chunk()
                print(f"Download {int(status.progress() * 100)}%.")

        # Extract the downloaded .zip file
        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
            zip_ref.extractall(os.path.join(repo_location, "CSVs"))
    finally:
        # Never leave a partial or already-extracted archive behind, even if
        # the download or the extraction failed part-way through.
        if os.path.exists(zip_file_path):
            os.remove(zip_file_path)

    print("File downloaded and extracted successfully.")

In [5]:
def load_h5_file(file_path):
    """Read an entire HDF5 file into nested Python dictionaries.

    Groups become ``dict`` objects keyed by member name; datasets are read
    fully into memory. A ``bytes`` scalar is decoded to ``str`` and a NumPy
    array whose dtype is ``np.bytes_`` is converted with ``astype(str)``;
    every other dataset value is returned as read.

    Args:
        file_path: Path of the HDF5 file, opened read-only.

    Returns:
        A dict mirroring the group hierarchy of the file.

    Raises:
        TypeError: If a member is neither an ``h5py.Group`` nor an
            ``h5py.Dataset``.
    """
    def _read_node(node):
        # Groups (including the File object itself) recurse into their members.
        if isinstance(node, h5py.Group):
            return {name: _read_node(child) for name, child in node.items()}

        if not isinstance(node, h5py.Dataset):
            raise TypeError(f"Unsupported type: {type(node)}")

        value = node[()]
        if isinstance(value, bytes):
            # Scalar byte string
            return value.decode()
        if isinstance(value, np.ndarray) and value.dtype.type is np.bytes_:
            # Array of fixed-width byte strings
            return value.astype(str)
        return value

    with h5py.File(file_path, 'r') as handle:
        return _read_node(handle)

In [None]:
def safe_to_datetime(date_str):
    """Parse a ``YYYY-MM-DD-HH-MM-SS`` string, falling back to ``pd.NaT``.

    Any exception raised by ``pd.to_datetime`` is reported on stdout and
    swallowed, so a single malformed value cannot abort a column-wide
    ``apply``.

    Args:
        date_str: Value to parse with the format ``%Y-%m-%d-%H-%M-%S``.

    Returns:
        The result of ``pd.to_datetime`` on success, ``pd.NaT`` on failure.
    """
    try:
        parsed = pd.to_datetime(date_str, format='%Y-%m-%d-%H-%M-%S')
    except Exception as e:
        print(f"Error parsing date '{date_str}': {e}")
        parsed = pd.NaT
    return parsed

def process_file(path, file, subject_id, interval, labs_filtered):
    """Build a labelled table of PPG windows for one subject file.

    The HDF5 file is loaded with ``load_h5_file``; each row of
    ``Subj_Wins/PPG_Raw`` becomes one row of the result, prefixed by the
    columns ``Segment_Time``, ``SubjectID``, ``Age`` (divided by 130),
    ``Gender`` and ``lab_flag``.

    A window is kept only if at least one of the subject's lab rows has a
    ``CHARTTIME`` inside ``[Segment_Time - interval, Segment_Time + interval]``
    (both ends inclusive). Its ``lab_flag`` is derived from the sum of
    ``FLAG`` over those lab rows: ``1`` if the sum is positive, ``0`` if it is
    zero, ``NaN`` otherwise. Rows containing any ``NaN``/``NaT`` are dropped
    at the end.

    Args:
        path: Directory containing ``file``.
        file: HDF5 file name.
        subject_id: Value matched against ``labs_filtered.SUBJECT_ID`` and
            written to the ``SubjectID`` column.
        interval: Half-width of the matching window; must be subtractable
            from / addable to the parsed ``Segment_Time`` values
            (presumably a ``pd.Timedelta`` -- confirm with the caller).
        labs_filtered: Lab table with at least ``SUBJECT_ID``, ``CHARTTIME``
            and ``FLAG`` columns. It is not modified.

    Returns:
        A DataFrame of the matched windows, or ``None`` when the file has no
        ``Segment_Time`` entry under ``Subj_Wins``.

    Notes:
        * Matched windows are returned in their original file order. The
          previous implementation selected them through ``list(set(...))``,
          whose order is arbitrary.
        * Earlier code contained statements after the flag assignment that
          looked like an attempt to cap each class at 100 samples. They only
          rebound a loop-local variable and had no effect, so they were
          removed; no cap is applied.
    """
    data = load_h5_file(os.path.join(path, file))
    subj_wins = data['Subj_Wins']

    if 'Segment_Time' not in subj_wins.keys():
        return None

    # Convert the data file into a dataframe
    df_data = pd.DataFrame(subj_wins['PPG_Raw'].squeeze())

    # Metadata columns. The insert positions reproduce the column order
    # Segment_Time, SubjectID, Age, Gender, lab_flag, <samples...>.
    df_data.insert(0, 'Segment_Time', subj_wins['Segment_Time'])
    df_data.insert(1, 'lab_flag', np.nan)
    df_data.insert(1, 'SubjectID', subject_id)
    df_data.insert(2, 'Age', subj_wins['Age'])
    df_data.insert(3, 'Gender', subj_wins['Gender'])

    # Normalize the age variable by 130 to get a value between 0 and 1
    df_data['Age'] = df_data['Age'] / 130

    # Convert numpy.str_ to native Python strings and convert to datetime
    df_data['Segment_Time'] = df_data['Segment_Time'].apply(
        lambda x: x[0] if isinstance(x, np.ndarray) else x).astype(str)
    df_data['Segment_Time'] = df_data['Segment_Time'].apply(safe_to_datetime)

    # Labs belonging to this subject only
    subject_labs = labs_filtered[labs_filtered.SUBJECT_ID == subject_id].reset_index(drop=True)
    chart_times = subject_labs.CHARTTIME
    lab_flags = subject_labs.FLAG

    window_start = df_data.Segment_Time - interval
    window_end = df_data.Segment_Time + interval

    # One pass over the windows: "lab within the window of a segment" and
    # "segment within the window of a lab" are the same relation, so a single
    # scan both selects the rows to keep and computes their flag.
    kept_rows = []
    kept_flags = []
    for i in df_data.index:
        # Comparisons against NaT are False, so unparsable times never match.
        in_window = (chart_times >= window_start[i]) & (chart_times <= window_end[i])
        if not in_window.any():
            continue
        flag_sum = lab_flags[in_window].sum()
        kept_rows.append(i)
        kept_flags.append(1 if flag_sum > 0 else 0 if flag_sum == 0 else np.nan)

    df_data = df_data.loc[kept_rows].reset_index(drop=True)
    # Explicit float dtype keeps the column numeric even when nothing matched.
    df_data['lab_flag'] = np.asarray(kept_flags, dtype=float)

    return df_data.dropna()

In [None]:
def process_file_optimized(path, file, subject_id, interval, labs_filtered):
    """Label PPG windows of one subject file using the nearest lab result.

    The HDF5 file is loaded with ``load_h5_file``; each row of
    ``Subj_Wins/PPG_Raw`` becomes one row, prefixed by ``Segment_Time``,
    ``SubjectID``, ``Age`` (divided by 130), ``Gender`` and ``lab_flag``.
    Every window is joined with ``pd.merge_asof`` to the subject's lab row
    whose ``CHARTTIME`` is nearest to ``Segment_Time`` and at most
    ``interval`` away; ``lab_flag`` is that row's ``FLAG``. Windows without
    a lab in range, and rows with any other missing value, are dropped.

    Note that this differs from ``process_file``, which derives the flag
    from the sum of ``FLAG`` over all labs inside the window.

    Args:
        path: Directory containing ``file``.
        file: HDF5 file name.
        subject_id: Value matched against ``labs_filtered.SUBJECT_ID`` and
            written to the ``SubjectID`` column.
        interval: Maximum distance between ``Segment_Time`` and ``CHARTTIME``,
            passed as ``tolerance`` to ``merge_asof`` (presumably a
            ``pd.Timedelta`` -- confirm with the caller).
        labs_filtered: Lab table with at least ``SUBJECT_ID``, ``CHARTTIME``
            and ``FLAG`` columns. It is not modified.

    Returns:
        A DataFrame sorted by ``Segment_Time`` with a fresh integer index, or
        ``None`` when the file has no ``Segment_Time`` entry under
        ``Subj_Wins``.
    """
    data = load_h5_file(os.path.join(path, file))
    subj_wins = data['Subj_Wins']

    if 'Segment_Time' not in subj_wins.keys():
        return None

    # Convert the data file into a dataframe
    df_data = pd.DataFrame(subj_wins['PPG_Raw'].squeeze())

    # Metadata columns. The insert positions reproduce the column order
    # Segment_Time, SubjectID, Age, Gender, lab_flag, <samples...>.
    df_data.insert(0, 'Segment_Time', subj_wins['Segment_Time'])
    df_data.insert(1, 'lab_flag', np.nan)
    df_data.insert(1, 'SubjectID', subject_id)
    df_data.insert(2, 'Age', subj_wins['Age'])
    df_data.insert(3, 'Gender', subj_wins['Gender'])

    # Normalize the age variable by 130 to get a value between 0 and 1
    df_data['Age'] = df_data['Age'] / 130

    # Convert numpy.str_ to native Python strings and convert to datetime
    df_data['Segment_Time'] = df_data['Segment_Time'].apply(
        lambda x: x[0] if isinstance(x, np.ndarray) else x).astype(str)
    df_data['Segment_Time'] = df_data['Segment_Time'].apply(safe_to_datetime)

    # merge_asof rejects null merge keys, so rows whose time could not be
    # parsed (NaT) are removed up front. The final dropna() would discard
    # them anyway; doing it here stops one bad timestamp aborting the file.
    windows = df_data.dropna(subset=['Segment_Time']).sort_values('Segment_Time')
    subject_labs = (
        labs_filtered[labs_filtered.SUBJECT_ID == subject_id]
        .dropna(subset=['CHARTTIME'])
        .sort_values('CHARTTIME')
    )

    # Nearest lab per window, no further away than `interval`
    merged = pd.merge_asof(windows, subject_labs,
                           left_on='Segment_Time', right_on='CHARTTIME',
                           direction='nearest', tolerance=interval)

    # Assign lab_flag based on FLAG (NaN where no lab was in range)
    merged['lab_flag'] = np.where(merged['FLAG'].notna(), merged['FLAG'], np.nan)

    # Drop the lab-side columns so that their missing values (unmatched
    # windows, empty lab fields) do not influence the dropna() below.
    # errors='ignore' tolerates lab tables that lack some of these columns.
    lab_columns = ['CHARTTIME', 'SUBJECT_ID', 'FLAG', 'HADM_ID', 'ROW_ID',
                   'ITEMID', 'VALUE', 'VALUENUM', 'VALUEUOM', 'Panel', 'Subpanel']
    merged = merged.drop(columns=lab_columns, errors='ignore')

    return merged.dropna().reset_index(drop=True)

In [None]:
class MyDataset(Dataset):
    """Dataset that serves each signal row as a spectrogram.

    Signal samples are taken from the columns of ``df`` named ``'0'`` to
    ``'1249'`` (as strings) and labels from ``df.lab_flag``. When ``df`` has
    both an ``Age`` and a ``Gender`` column, items also carry those values,
    with gender encoded as ``'M' -> 0`` and ``'F' -> 1``.
    """

    def __init__(self, df):
        sample_columns = [str(i) for i in range(0, 1250)]

        self.df = df
        self.y_data = df.lab_flag.values
        self.X_data = df.loc[:, df.columns.intersection(sample_columns)]

        if self._has_demographics():
            self.age_data = df.Age.values
            # Change gender values from 'M' and 'F' to 0 and 1
            self.gender_data = df.Gender.map({'M': 0, 'F': 1}).values

    def _has_demographics(self):
        # True when both demographic columns are available on the frame.
        return ('Age' in self.df.columns) and ('Gender' in self.df.columns)

    def __len__(self):
        return len(self.X_data)

    def __getitem__(self, idx):
        """Return ``(X, age, gender, y)`` or, without demographics, ``(X, y)``.

        ``X`` is the spectrogram (``fs=125``, ``nperseg=256``) of row ``idx``
        as a float32 tensor. With demographics a leading channel dimension is
        added; without them the 2-D spectrogram is returned as is.
        """
        signal = self.X_data.iloc[idx].values.astype(np.float32)
        spec = spectrogram(signal, fs=125, nperseg=256)[2]
        label = self.y_data[idx]

        if not self._has_demographics():
            return torch.tensor(spec, dtype=torch.float32), label

        age = torch.tensor(self.age_data[idx], dtype=torch.float32)
        gender = torch.tensor(self.gender_data[idx], dtype=torch.float32)
        # Add channel dimension for grayscale image
        image = torch.tensor(spec, dtype=torch.float32).unsqueeze(0)
        return image, age, gender, label

In [None]:
class MyWaveletDataset(Dataset):
    """Dataset that serves each signal row as concatenated wavelet coefficients.

    Signal samples are taken from the columns of ``df`` named ``'0'`` to
    ``'1249'`` (as strings) and labels from ``df.lab_flag``. When ``df`` has
    both an ``Age`` and a ``Gender`` column, items also carry those values,
    with gender encoded as ``'M' -> 0`` and ``'F' -> 1``.
    """

    def __init__(self, df):
        sample_columns = [str(i) for i in range(0, 1250)]

        self.df = df
        self.y_data = df.lab_flag.values
        self.X_data = df.loc[:, df.columns.intersection(sample_columns)]

        if self._has_demographics():
            self.age_data = df.Age.values
            # Change gender values from 'M' and 'F' to 0 and 1
            self.gender_data = df.Gender.map({'M': 0, 'F': 1}).values

    def _has_demographics(self):
        # True when both demographic columns are available on the frame.
        return ('Age' in self.df.columns) and ('Gender' in self.df.columns)

    def __len__(self):
        return len(self.X_data)

    def __getitem__(self, idx):
        """Return ``(X, age, gender, y)`` or, without demographics, ``(X, y)``.

        ``X`` holds the level-4 ``db4`` wavelet decomposition of row ``idx``
        with all coefficient arrays concatenated into one float32 vector.
        Two leading singleton dimensions are added when demographics are
        present, one otherwise.
        """
        signal = self.X_data.iloc[idx].values.astype(np.float32)

        # Perform Discrete Wavelet Transform and flatten the coefficient list
        flat_coeffs = np.concatenate(pywt.wavedec(signal, 'db4', level=4))
        features = torch.tensor(flat_coeffs, dtype=torch.float32)

        label = self.y_data[idx]

        if not self._has_demographics():
            return features.unsqueeze(0), label

        age = torch.tensor(self.age_data[idx], dtype=torch.float32)
        gender = torch.tensor(self.gender_data[idx], dtype=torch.float32)
        # Add channel dimension if needed
        return features.unsqueeze(0).unsqueeze(0), age, gender, label

In [2]:
def train(model, trainloader, optim, criterion, epoch, device, gender_age_used, scheduler=None):
    """Run one training epoch and report mean batch loss and accuracy.

    Args:
        model: Network to train; put into training mode here.
        trainloader: Iterable of batches supporting ``len()``. Batches are
            ``(inputs, age, gender, labels)`` when ``gender_age_used`` is
            ``"GA_used"`` and ``(inputs, labels)`` otherwise.
        optim: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(outputs, labels)`` with labels
            cast to ``long``.
        epoch: Not used by this function.
        device: Device every batch tensor is moved to.
        gender_age_used: ``"GA_used"`` calls ``model(inputs, age, gender)``;
            any other value calls ``model(inputs.unsqueeze(1))``.
        scheduler: Optional; when truthy, ``step()`` is called once after the
            epoch.

    Returns:
        ``(train_loss, train_acc)``: the sum of batch losses divided by the
        number of batches, and the fraction of correctly predicted samples.
    """
    model.train()
    uses_demographics = gender_age_used == "GA_used"

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for batch in trainloader:
        if uses_demographics:
            signal, age, gender, target = batch
            signal = signal.to(device)
            age = age.to(device)
            gender = gender.to(device)
            # Convert labels to Long type
            target = target.to(device).long()
            optim.zero_grad()
            logits = model(signal, age, gender)
        else:
            signal, target = batch
            signal = signal.to(device)
            target = target.to(device).long()
            optim.zero_grad()
            # The two-item batches carry no channel axis; add one here.
            logits = model(signal.unsqueeze(1))

        batch_loss = criterion(logits, target)
        batch_loss.backward()
        optim.step()

        loss_sum += batch_loss.item()

        predicted = torch.max(logits.data, 1)[1]
        n_seen += target.size(0)
        n_correct += (predicted == target).sum().item()

    train_loss = loss_sum / len(trainloader)
    train_acc = n_correct / n_seen
    if scheduler:
        scheduler.step()

    return train_loss, train_acc


def validate(model, testloader, criterion, best_val_acc, device, test, repo_location, gender_age_used):
    """Evaluate the model and checkpoint it when accuracy beats the best so far.

    Args:
        model: Network to evaluate; put into eval mode here.
        testloader: Iterable of batches supporting ``len()``. Batches are
            ``(inputs, age, gender, labels)`` when ``gender_age_used`` is
            ``"GA_used"`` and ``(inputs, labels)`` otherwise.
        criterion: Loss called as ``criterion(outputs, labels)`` with labels
            cast to ``long``.
        best_val_acc: Best validation accuracy seen so far.
        device: Device every batch tensor is moved to.
        test: Name used in the checkpoint file name.
        repo_location: Base directory; checkpoints are written to
            ``<repo_location>/models``, which is created if missing.
        gender_age_used: ``"GA_used"`` calls ``model(inputs, age, gender)``;
            any other value calls ``model(inputs.unsqueeze(1))``. Also part
            of the checkpoint file name.

    Returns:
        ``(val_loss, val_acc, best_val_acc)``: the sum of batch losses
        divided by the number of batches, the fraction of correctly
        predicted samples, and the (possibly updated) best accuracy.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for data in testloader:
            if gender_age_used == "GA_used":
                inputs, age, gender, labels = data
                inputs, age, gender, labels = inputs.to(device), age.to(device), gender.to(device), labels.to(device)
                labels = labels.long()  # Convert labels to Long type
                outputs = model(inputs, age, gender)
            else:
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                labels = labels.long()
                # The two-item batches carry no channel axis; add one here.
                outputs = model(inputs.unsqueeze(1))
            loss = criterion(outputs, labels)

            running_loss += loss.item()

            # Gradients are already disabled, so no detaching is needed.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss = running_loss / len(testloader)
    val_acc = correct / total

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # torch.save does not create missing parent directories; make sure
        # the checkpoint folder exists so a fresh clone does not fail here.
        model_dir = os.path.join(repo_location, "models")
        os.makedirs(model_dir, exist_ok=True)
        torch.save(model.state_dict(),
                   os.path.join(model_dir, f"bestmodel_{test}_{gender_age_used}.pth"))
    return val_loss, val_acc, best_val_acc

In [10]:
def plot_loss_accuracy(train_loss, train_acc, validation_loss, validation_acc, test, gender_age_used=""):
    epochs = len(train_loss)
    fig, (ax1, ax2) = plt.subplots(1, 2)
    ax1.plot(list(range(epochs)), train_loss, label='Training Loss')
    ax1.plot(list(range(epochs)), validation_loss, label='Validation Loss')
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss')
    ax1.set_title(f'Epoch vs Loss for {test}')
    ax1.legend()

    ax2.plot(list(range(epochs)), train_acc, label='Training Accuracy')
    ax2.plot(list(range(epochs)), validation_acc, label='Validation Accuracy')
    ax2.set_xlabel('Epochs')
    ax2.set_ylabel('Accuracy')
    ax2.set_title(f'Epoch vs Accuracy for {test}')
    ax2.legend()
    fig.set_size_inches(15.5, 5.5)
    fig.savefig(f"plots/{test}{gender_age_used}.png")