In [1]:
import os

# helper function
def exists(path):
    """Report whether ``path`` is already present on disk.

    Used as a cache guard: when the path exists, a notice is printed so the
    user knows the cached artifact is reused and how to force regeneration.

    Args:
        path: File or directory path, passed to ``os.path.exists``.

    Returns:
        bool: True if ``path`` exists, False otherwise.
    """
    val = os.path.exists(path)
    if val:
        # Message typos fixed ("exits" -> "exists", "recieve" -> "receive").
        print(f'{path} already exists. Using cached. Delete it manually to receive it again!')
    return val

In [2]:
# Import
import os
import torch
import torch.nn as nn
import numpy as np
from tqdm.notebook import trange, tqdm
import h5py
from torch.utils.data import TensorDataset, random_split, DataLoader
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, f1_score
from utils import pgd_attack
import ecg_plot
from models import ResNet1d
import ast
%matplotlib inline

In [3]:
# Choose the compute device: the GPU when CUDA is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Report the selection through tqdm so it does not clash with progress bars
tqdm.write("Use device: {device:}\n".format(device=device))

Use device: cuda



In [9]:
# Dataset location and its metadata table
dataset_path = '/local_storage/users/arveri/ptb-xl'
path_to_csv = dataset_path + '/ptbxl_database.csv'

# Read the metadata table, indexed by ecg_id
df = pd.read_csv(path_to_csv, index_col='ecg_id')
df_size = df.shape[0]
# Drop every row whose age equals 300 and report how many rows were removed
df = df[df['age'].ne(300)]
df_size_filtered = df.shape[0]
print(f"Filtered out {df_size - df_size_filtered} rows where age == 300")

# Folds held out for validation and testing; every other fold is training data
validation_fold = 9
test_fold = 10

# Split the metadata by stratification fold
in_val_fold = df.strat_fold == validation_fold
in_test_fold = df.strat_fold == test_fold
train = df[~(in_val_fold | in_test_fold)]
val = df[in_val_fold]
test = df[in_test_fold]

# Record names (filename_hr column) of each split
f_names_train = train['filename_hr']
f_names_val = val['filename_hr']
f_names_test = test['filename_hr']

# Write the record names of each split, one per line, to RECORDS_train_age.txt,
# RECORDS_val_age.txt and RECORDS_test_age.txt inside the dataset directory
for records_suffix, f_names in (
    ('/RECORDS_train_age.txt', f_names_train),
    ('/RECORDS_val_age.txt', f_names_val),
    ('/RECORDS_test_age.txt', f_names_test),
):
    with open(dataset_path + records_suffix, 'w') as records_file:
        records_file.writelines(name + '\n' for name in f_names.values)

In [10]:
if not exists(dataset_path + '/train_age.h5'):
    !python ecg-preprocessing/generate_h5.py --new_freq 400 --new_len 4096 --remove_baseline --use_all_leads --remove_powerline 60 ptb-xl/RECORDS_train.txt /local_storage/users/arveri/ptb-xl/train_age.h5
    
if not exists(dataset_path + '/val_age.h5'):
    !python ecg-preprocessing/generate_h5.py --new_freq 400 --new_len 4096 --remove_baseline --use_all_leads --remove_powerline 60 ptb-xl/RECORDS_val.txt /local_storage/users/arveri/ptb-xl/val_age.h5
    
if not exists(dataset_path + '/test_age.h5'):
    !python ecg-preprocessing/generate_h5.py --new_freq 400 --new_len 4096 --remove_baseline --use_all_leads --remove_powerline 60 ptb-xl/RECORDS_test.txt /local_storage/users/arveri/ptb-xl/test_age.h5

Namespace(input_file='ptb-xl/RECORDS_train.txt', out_file='ptb-xl/train.h5', root_dir=None, new_freq=400.0, new_len=4096, scale=1, use_all_leads=True, remove_baseline=True, remove_powerline=60.0, fmt='wfdb')
100%|█████████████████████████████████████| 17418/17418 [03:01<00:00, 95.79it/s]
Namespace(input_file='ptb-xl/RECORDS_val.txt', out_file='ptb-xl/val.h5', root_dir=None, new_freq=400.0, new_len=4096, scale=1, use_all_leads=True, remove_baseline=True, remove_powerline=60.0, fmt='wfdb')
100%|██████████████████████████████████████| 2183/2183 [00:21<00:00, 103.25it/s]
Namespace(input_file='ptb-xl/RECORDS_test.txt', out_file='ptb-xl/test.h5', root_dir=None, new_freq=400.0, new_len=4096, scale=1, use_all_leads=True, remove_baseline=True, remove_powerline=60.0, fmt='wfdb')
100%|██████████████████████████████████████| 2198/2198 [00:21<00:00, 104.03it/s]
