<a href="https://colab.research.google.com/github/arushi-lu/deep_learning/blob/main/CNN_PPG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Data preprocessing -> creating 5 folds

From the original dataset (127260) ->

Used 50000 samples for training and validation

and  10000 samples for testing

In [1]:
import gdown

In [2]:
file_id = '1IxN2sX2TX0uK6CFDh8eudb8haz3RlF7X'
download_url = f'https://drive.google.com/uc?id={file_id}'
output_file = 'data.hdf5'

gdown.download(download_url, output_file, quiet=False)

Downloading...
From (original): https://drive.google.com/uc?id=1IxN2sX2TX0uK6CFDh8eudb8haz3RlF7X
From (redirected): https://drive.google.com/uc?id=1IxN2sX2TX0uK6CFDh8eudb8haz3RlF7X&confirm=t&uuid=43f4e7b4-a17f-4309-a91f-53b41753975c
To: /content/data.hdf5
100%|██████████| 2.55G/2.55G [00:25<00:00, 101MB/s]


'data.hdf5'

In [3]:
import h5py

# Load the HDF5 file
file_path = 'data.hdf5'
with h5py.File(file_path, 'r') as f:
    # List all groups and datasets within the file
    print("Keys in 'data.hdf5':")
    print(list(f.keys()))

    # Assuming there's a dataset named 'data' containing all samples
    dataset = f['data']

    # Check the shape of the dataset to understand its size
    print("Shape of 'data' dataset:", dataset.shape)


Keys in 'data.hdf5':
['data']
Shape of 'data' dataset: (127260, 2, 1250)


In [4]:
!mkdir data

Data handling: Main function to create 5 folds for cross-validation

In [5]:
import h5py
import numpy as np
import os
from tqdm import tqdm
import pickle

def fold_data():
    length = 1250  # length of the signals

    # Starting points of validation data for 5 folds
    validation_data_start = {
        0: 40000,
        1: 0,
        2: 10000,
        3: 20000,
        4: 30000,
    }

    # Load the episode data once
    fl = h5py.File('data.hdf5', 'r')
    data = fl['data'][:50000]

    for fold_id in tqdm(range(5), desc='Folding Data'):  # Iterate for 5 folds
        X_train = []  # Initialize train data
        Y_train = []

        X_val = []  # Initialize validation data
        Y_val = []

        max_ppg = -10000  # Initialize metadata for min-max of abp, ppg signals
        min_ppg = 10000
        max_abp = -10000
        min_abp = 10000

        val_start = validation_data_start[fold_id]  # Validation data start
        val_end = val_start + 10000  # Validation data end

        # Process training data before validation samples
        for i in tqdm(range(0, val_start), desc='Training Data Part 1'):
            sample = data[i]
            X_train.append(sample[1][:length].reshape(length, 1))  # ppg signal
            Y_train.append(sample[0][:length].reshape(length, 1))  # abp signal
            max_ppg = max(np.max(sample[1]), max_ppg)
            min_ppg = min(np.min(sample[1]), min_ppg)
            max_abp = max(np.max(sample[0]), max_abp)
            min_abp = min(np.min(sample[0]), min_abp)

        # Process training data after validation samples
        for i in tqdm(range(val_end, 50000), desc='Training Data Part 2'):
            sample = data[i]
            X_train.append(sample[1][:length].reshape(length, 1))  # ppg signal
            Y_train.append(sample[0][:length].reshape(length, 1))  # abp signal
            max_ppg = max(np.max(sample[1]), max_ppg)
            min_ppg = min(np.min(sample[1]), min_ppg)
            max_abp = max(np.max(sample[0]), max_abp)
            min_abp = min(np.min(sample[0]), min_abp)

        # Process validation data
        for i in tqdm(range(val_start, val_end), desc='Validation Data'):
            sample = data[i]
            X_val.append(sample[1][:length].reshape(length, 1))  # ppg signal
            Y_val.append(sample[0][:length].reshape(length, 1))  # abp signal
            max_ppg = max(np.max(sample[1]), max_ppg)
            min_ppg = min(np.min(sample[1]), min_ppg)
            max_abp = max(np.max(sample[0]), max_abp)
            min_abp = min(np.min(sample[0]), min_abp)

        # Convert lists to numpy arrays for efficiency
        X_train = np.array(X_train)
        Y_train = np.array(Y_train)
        X_val = np.array(X_val)
        Y_val = np.array(Y_val)

        # Normalize training and validation data
        X_train = (X_train - min_ppg) / (max_ppg - min_ppg)
        Y_train = (Y_train - min_abp) / (max_abp - min_abp)
        X_val = (X_val - min_ppg) / (max_ppg - min_ppg)
        Y_val = (Y_val - min_abp) / (max_abp - min_abp)

        # Save training and validation data splits
        os.makedirs('data', exist_ok=True)
        with open(f'data/train{fold_id}.p', 'wb') as f:
            pickle.dump({'X_train': X_train, 'Y_train': Y_train}, f)
        with open(f'data/val{fold_id}.p', 'wb') as f:
            pickle.dump({'X_val': X_val, 'Y_val': Y_val}, f)

        # Save metadata
        with open(f'data/meta{fold_id}.p', 'wb') as f:
            pickle.dump({'max_ppg': max_ppg, 'min_ppg': min_ppg,
                         'max_abp': max_abp, 'min_abp': min_abp}, f)

    # Process test data
    fl = h5py.File('data.hdf5', 'r')
    test_data = fl['data'][50000:60000]
    X_test = []
    Y_test = []

    for sample in tqdm(test_data, desc='Test Data'):
        X_test.append(sample[1][:length].reshape(length, 1))  # ppg signal
        Y_test.append(sample[0][:length].reshape(length, 1))  # abp signal

    fl.close()  # Close the HDF5 file

    # Convert test data to numpy arrays
    X_test = np.array(X_test)
    Y_test = np.array(Y_test)

    # Normalize test data
    X_test = (X_test - min_ppg) / (max_ppg - min_ppg)
    Y_test = (Y_test - min_abp) / (max_abp - min_abp)

    # Save test data split
    with open('data/test.p', 'wb') as f:
        pickle.dump({'X_test': X_test, 'Y_test': Y_test}, f)

def main():
    fold_data()  # Split the data for 5-fold cross-validation

if __name__ == '__main__':
    main()


Folding Data:   0%|          | 0/5 [00:00<?, ?it/s]
Training Data Part 1:   0%|          | 0/40000 [00:00<?, ?it/s][A
Training Data Part 1:   7%|▋         | 2961/40000 [00:00<00:01, 29607.77it/s][A
Training Data Part 1:  15%|█▍        | 5964/40000 [00:00<00:01, 29855.31it/s][A
Training Data Part 1:  22%|██▏       | 8950/40000 [00:00<00:01, 29447.48it/s][A
Training Data Part 1:  30%|██▉       | 11896/40000 [00:00<00:00, 29397.93it/s][A
Training Data Part 1:  37%|███▋      | 14895/40000 [00:00<00:00, 29607.21it/s][A
Training Data Part 1:  45%|████▍     | 17857/40000 [00:00<00:00, 28117.86it/s][A
Training Data Part 1:  52%|█████▏    | 20796/40000 [00:00<00:00, 28515.35it/s][A
Training Data Part 1:  59%|█████▉    | 23658/40000 [00:00<00:00, 28368.10it/s][A
Training Data Part 1:  66%|██████▋   | 26560/40000 [00:00<00:00, 28565.29it/s][A
Training Data Part 1:  74%|███████▎  | 29422/40000 [00:01<00:00, 27937.68it/s][A
Training Data Part 1:  81%|████████  | 32345/40000 [00:01<00:00,

Methods to check the folds

In [6]:
import pickle
import os

def load_pickle(file_path):
    with open(file_path, 'rb') as f:
        data = pickle.load(f)
    return data

def check_files():
    data_dir = 'data'
    files = os.listdir(data_dir)

    for file_name in files:
        file_path = os.path.join(data_dir, file_name)
        data = load_pickle(file_path)

        print(f"Checking {file_name}...")
        if 'train' in file_name or 'val' in file_name:
            print(f"X shape: {data['X_train'].shape if 'train' in file_name else data['X_val'].shape}")
            print(f"Y shape: {data['Y_train'].shape if 'train' in file_name else data['Y_val'].shape}")
        elif 'test' in file_name:
            print(f"X shape: {data['X_test'].shape}")
            print(f"Y shape: {data['Y_test'].shape}")
        elif 'meta' in file_name:
            print(f"Metadata: {data}")
        print()

def main():
    check_files()

if __name__ == '__main__':
    main()


Checking test.p...
X shape: (10000, 1250, 1)
Y shape: (10000, 1250, 1)

Checking val4.p...
X shape: (10000, 1250, 1)
Y shape: (10000, 1250, 1)

Checking train4.p...
X shape: (40000, 1250, 1)
Y shape: (40000, 1250, 1)

Checking meta2.p...
Metadata: {'max_ppg': 4.001955034213099, 'min_ppg': 0.0, 'max_abp': 199.9479008990615, 'min_abp': 50.0}

Checking meta1.p...
Metadata: {'max_ppg': 4.001955034213099, 'min_ppg': 0.0, 'max_abp': 199.9479008990615, 'min_abp': 50.0}

Checking train3.p...
X shape: (40000, 1250, 1)
Y shape: (40000, 1250, 1)

Checking val1.p...
X shape: (10000, 1250, 1)
Y shape: (10000, 1250, 1)

Checking meta0.p...
Metadata: {'max_ppg': 4.001955034213099, 'min_ppg': 0.0, 'max_abp': 199.9479008990615, 'min_abp': 50.0}

Checking val0.p...
X shape: (10000, 1250, 1)
Y shape: (10000, 1250, 1)

Checking train2.p...
X shape: (40000, 1250, 1)
Y shape: (40000, 1250, 1)

Checking meta3.p...
Metadata: {'max_ppg': 4.001955034213099, 'min_ppg': 0.0, 'max_abp': 199.9479008990615, 'min_abp

In [37]:
file_id = '1HhcOzLbZgxOS5byOJF43r4veGWsQVb8M'
download_url = f'https://drive.google.com/uc?id={file_id}'
output_file = 'meta.p'

gdown.download(download_url, output_file, quiet=False)

Downloading...
From: https://drive.google.com/uc?id=1HhcOzLbZgxOS5byOJF43r4veGWsQVb8M
To: /content/meta.p
100%|██████████| 54.0/54.0 [00:00<00:00, 188kB/s]


'meta.p'

In [38]:
data = load_pickle('meta.p')
data

{'max_abp': 178.8, 'min_abp': 60.2}

BP-Net architecture

In [7]:
import torch
import torch.nn as nn


class IncBlock(nn.Module):
    def __init__(self, in_channels, out_channels, size = 15, stride = 1, padding = 7):
        super(IncBlock,self).__init__()

        self.conv1x1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias = False)

        self.conv1 = nn.Sequential(nn.Conv1d(in_channels, out_channels//4, kernel_size = size, stride = stride, padding = padding ),
                                   nn.BatchNorm1d(out_channels//4))

        self.conv2 = nn.Sequential(nn.Conv1d(in_channels, out_channels//4, kernel_size = 1, bias = False),
                                   nn.BatchNorm1d(out_channels//4),
                                   nn.LeakyReLU(0.2),
                                   nn.Conv1d(out_channels//4, out_channels//4, kernel_size = size +2 , stride = stride, padding = padding + 1),
                                   nn.BatchNorm1d(out_channels//4))

        self.conv3 = nn.Sequential(nn.Conv1d(in_channels, out_channels//4, kernel_size = 1, bias = False),
                                   nn.BatchNorm1d(out_channels//4),
                                   nn.LeakyReLU(0.2),
                                   nn.Conv1d(out_channels//4, out_channels//4, kernel_size = size + 4 , stride = stride, padding = padding + 2),
                                   nn.BatchNorm1d(out_channels//4))


        self.conv4 = nn.Sequential(nn.Conv1d(in_channels, out_channels//4, kernel_size = 1, bias = False),
                                   nn.BatchNorm1d(out_channels//4),
                                   nn.LeakyReLU(0.2),
                                   nn.Conv1d(out_channels//4, out_channels//4, kernel_size = size + 6 , stride = stride, padding = padding + 3),
                                   nn.BatchNorm1d(out_channels//4))
        self.relu = nn.ReLU()
    def forward(self,x):
        res = self.conv1x1(x)
#         print (res.size())


        c1 = self.conv1(x)
#         print (c1.size())

        c2 = self.conv2(x)
#         print (c2.size())

        c3 = self.conv3(x)
#         print (c3.size())

        c4 = self.conv4(x)
#         print (c4.size())

        concat = torch.cat((c1,c2,c3,c4),dim = 1)

        concat+=res
#         print (concat.shape)
        return self.relu(concat)




class InterAxialBlock(nn.Module):
        #3
  def __init__(self,in_channels = 1, out_channels = 1):

    super(InterAxialBlock, self).__init__()

    self.conv1 = nn.Conv1d(in_channels,8,3)
    self.bn1 = nn.BatchNorm1d(8)

    self.conv2 = nn.Conv1d(8,16,3)
    self.bn2 =nn.BatchNorm1d(16)

    self.conv3 = nn.Conv2d(1,1,(3,3), 2)
    self.bn3 = nn.BatchNorm2d(1)

    self.conv4 = nn.Conv2d(1, 1, (3,15), padding = (0,7))
    self.bn4 = nn.BatchNorm2d(1)

    self.conv5 = nn.Conv1d(1,out_channels,3, padding = 1)
    self.bn5 = nn.BatchNorm1d(out_channels)
    self.relu1 = nn.LeakyReLU(0.2)


    self.mp1 = nn.MaxPool1d(2)
    self.mp2 = nn.MaxPool2d((2,2))




  def forward(self, x):

#     print("in Inter",x.shape)
    x = self.relu1(self.bn1(self.conv1(x)))

    x = self.relu1(self.bn2(self.conv2(x)))
#3d -> 4d
    x = x.view(x.shape[0],1,x.shape[1],x.shape[2])

    x = self.relu1(self.bn3(self.conv3(x)))

    x = self.mp2(x)


    x = self.relu1(self.bn4(self.conv4(x)))


    x = torch.squeeze(x, dim = 1)
    x = self.relu1(self.bn5(self.conv5(x)))


    return x

class Unet(nn.Module):
    def __init__(self, shape):
        super(Unet, self).__init__()
        #1
        in_channels = 1

        self.inter = nn.Sequential(InterAxialBlock())

        self.en1 = nn.Sequential(nn.Conv1d(in_channels, 32, 3, padding = 1),
                                nn.BatchNorm1d(32),
                                nn.LeakyReLU(0.2),
                                nn.Conv1d(32, 32, 5, stride = 2, padding = 2),
                                IncBlock(32,32))

        self.en2 = nn.Sequential(nn.Conv1d(32, 64, 3, padding = 1),
                                nn.BatchNorm1d(64),
                                nn.LeakyReLU(0.2),
                                 nn.Conv1d(64, 64, 5, stride = 2, padding = 2),
                                IncBlock(64,64))


        self.en3 = nn.Sequential(nn.Conv1d(64,128, 3, padding = 1),
                                 nn.BatchNorm1d(128),
                                 nn.LeakyReLU(0.2),
                                 nn.Conv1d(128, 128, 3, stride = 2, padding = 1),
                                IncBlock(128,128))

        self.en4 = nn.Sequential(nn.Conv1d(128,256, 3,padding = 1),
                                 nn.BatchNorm1d(256),
                                 nn.LeakyReLU(0.2),
                                 nn.Conv1d(256, 256, 5, stride = 2, padding = 1),
                                IncBlock(256,256))


        self.en5 = nn.Sequential(nn.Conv1d(256,512, 3, padding = 1),
                                 nn.BatchNorm1d(512),
                                 nn.LeakyReLU(0.2),
                                 IncBlock(512,512))


        self.de1 = nn.Sequential(nn.ConvTranspose1d(512,256,1),
                               nn.BatchNorm1d(256),
                               nn.LeakyReLU(0.2),
                                IncBlock(256,256))

        self.de2 =  nn.Sequential(nn.Conv1d(512,256,3, padding = 1),
                               nn.BatchNorm1d(256),
                               nn.LeakyReLU(0.2),
                                  nn.ConvTranspose1d(256,128,3, stride = 2),
                                IncBlock(128,128))

        self.de3 =  nn.Sequential(nn.Conv1d(256,128,3, stride = 1, padding = 1),
                               nn.BatchNorm1d(128),
                               nn.LeakyReLU(0.2),
                                nn.ConvTranspose1d(128,64,3, stride = 2),
                                IncBlock(64,64))

        self.de4 =  nn.Sequential(nn.Conv1d(128,64,3, stride = 1, padding = 1),
                               nn.BatchNorm1d(64),
                               nn.LeakyReLU(0.2),
                                nn.ConvTranspose1d(64,32,3, stride = 2),
                                IncBlock(32,32))

        self.de5 = nn.Sequential(nn.Conv1d(64,32,3, stride = 1, padding = 1),
                               nn.BatchNorm1d(32),
                               nn.LeakyReLU(0.2),
                                nn.ConvTranspose1d(32,16,3, stride = 2),
                                IncBlock(16,16))

        self.de6 = nn.Sequential(nn.ConvTranspose1d(16,8,2,stride =2),
                                nn.BatchNorm1d(8),
                                nn.LeakyReLU(0.2))

        self.de7 = nn.Sequential(nn.ConvTranspose1d(8,4,2,stride =2),
                                nn.BatchNorm1d(4),
                                nn.LeakyReLU(0.2))

        self.de8 = nn.Sequential(nn.ConvTranspose1d(4,2,1,stride =1),
                                nn.BatchNorm1d(2),
                                nn.LeakyReLU(0.2))

        self.de9 = nn.Sequential(nn.ConvTranspose1d(2,1,1,stride =1),
                                nn.BatchNorm1d(1),
                                nn.LeakyReLU(0.2))


    def forward(self,x):

#         print("Before inter ",x.shape)
        x = self.inter(x)
#         print(" After Inter",x.shape)

        x = nn.ConstantPad1d((1,1),0)(x)
#         print ("After ConstantPad1d",x.shape)
        e1 = self.en1(x)
#         print ("After e1 ",e1.shape)

        e2 = self.en2(e1)
#         print ("After e2 ",e2.shape)

        e3 = self.en3(e2)
#         print ("After e3 ",e3.shape)

        e4 = self.en4(e3)
#         print ("After e4  ",e4.shape)

        e5 = self.en5(e4)
#         print ("After e5 ",e5.shape)
#         print ("-----------------------------------------------------------------------------")
        d1 = self.de1(e5)
#         print ("After d1", d1.shape)

#         print("Before cat d1 e4 {} {}".format(d1.shape,e4.shape))
        cat = torch.cat([d1,e4],1)
#         print("After cat d1 e4 {}".format(cat.shape))

        d2 = self.de2(cat)
#         print ("After d2 ",d2.shape)

#         print ("Before cat d2 e3 {} {}  ".format(d2.shape,e3.shape))
        cat = torch.cat([d2,e3[:,:,:-1]],1)
#         print("After cat d2 e3 {}".format(cat.shape))



        d3 = self.de3(cat)

#         print ("After d3 ",d3.shape)
#         print ("Before cat d3 e2 {} {}  ".format(d3.shape,e2.shape))
#         print("-1 being done on d3")
        cat = torch.cat([d3,e2[:,:,:]],1) #MADE A CHANGE HERE, ADDED -1
#         print("After cat d3 e2 {}".format(cat.shape))

        d4 = self.de4(cat)
#         print ("After d4 ",d4.shape)

#         print ("Before cat d4 e1 {} {}  ".format(d4.shape,e1.shape))
        cat = torch.cat([d4[:,:,:-2],e1],1) #MADE A CHANGE HERE, ([d4[:,:,:-2],e1],1) this is the original one
#         print("After cat d4 e1 {}".format(cat.shape))

        d5 = self.de5(cat)[:,:,:-2]
#         print ("After d5 ", d5.shape)

        d6 = self.de6(d5)[:,:,:-1]

#         print(d6.shape)

        d7 = self.de7(d6)
#         print("d7 ", d7.shape)
        d8 = self.de8(d7)
#         print(d8.shape)
        d9 = self.de9(d8)
#         print(d9.shape)
        return d9

Data loading:

BPdatasetv1 used for SSL

BPdatasetv2 used for training

In [8]:
import pickle
import os
import numpy as np
from torch.utils.data import Dataset

class BPdatasetv1(Dataset):
    def __init__(self, fold_num, train=False, val=False):
        if train:
            dt = pickle.load(open(os.path.join('data', f'train{fold_num}.p'), 'rb'))
            self.input = np.swapaxes(dt['X_train'], 1, 2).astype('float32')
            self.output = np.swapaxes(dt['X_train'], 1, 2).astype('float32')
        elif val:
            dt = pickle.load(open(os.path.join('data', f'val{fold_num}.p'), 'rb'))
            self.input = np.swapaxes(dt['X_val'], 1, 2).astype('float32')
            self.output = np.swapaxes(dt['X_val'], 1, 2).astype('float32')

    def __len__(self):
        return len(self.input)

    def __getitem__(self, idx):
        inp = self.input[idx]
        out = self.output[idx]
        return inp, out

class BPdatasetv2(Dataset):
    def __init__(self, fold_num, train=False, val=False, test=False):
        if train:
            dt = pickle.load(open(os.path.join('data', f'train{fold_num}.p'), 'rb'))
            self.input = np.swapaxes(dt['X_train'], 1, 2).astype('float32')
            self.output = np.swapaxes(dt['Y_train'], 1, 2).astype('float32')
        elif val:
            dt = pickle.load(open(os.path.join('data', f'val{fold_num}.p'), 'rb'))
            self.input = np.swapaxes(dt['X_val'], 1, 2).astype('float32')
            self.output = np.swapaxes(dt['Y_val'], 1, 2).astype('float32')
        elif test:
            dt = pickle.load(open(os.path.join('data', 'test.p'), 'rb'))
            self.input = np.swapaxes(dt['X_test'], 1, 2).astype('float32')
            self.output = np.swapaxes(dt['Y_test'], 1, 2).astype('float32')

    def __len__(self):
        return len(self.input)

    def __getitem__(self, idx):
        inp = self.input[idx]
        out = self.output[idx]
        return inp, out


EarlyStopper class to improve performance and discard unnecessary computations

In [9]:
class EarlyStopper:
    def __init__(self, patience=3, min_delta=10):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

In [10]:
!mkdir model

Self-supervision (supervision)

In [12]:
import os
import pandas as pd
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader


bs = 256
length = 1250
epochs = 20
folds = 5

model = Unet((bs, 1, length)).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = torch.nn.MSELoss()
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 200], gamma=0.1)
scaler = torch.cuda.amp.GradScaler()
early_stopper = EarlyStopper(patience=7, min_delta=0)

best_loss = 1000

for epoch in range(epochs):
    model.train()
    print('Epoch {}/{}'.format(epoch + 1, epochs))

    running_loss = 0.0
    running_loss_v = 0.0

    for fold in range(1, folds + 1):
        train_loader = DataLoader(BPdatasetv1(fold-1, train=True), batch_size=bs, shuffle=True)
        val_loader = DataLoader(BPdatasetv1(fold-1, val=True), batch_size=bs, shuffle=False)

        for idx, (inputs, output) in tqdm(enumerate(train_loader), total=len(train_loader)):
            inputs = inputs.cuda()
            output = output.cuda()

            optimizer.zero_grad()

            with torch.cuda.amp.autocast():
                pred = model(inputs)
                loss = criterion(pred, output)

            scaler.scale(loss).backward()
            running_loss += loss.item() * inputs.size(0)
            scaler.step(optimizer)
            scaler.update()

        scheduler.step()

        # VALIDATION
        model.eval()
        with torch.no_grad():
            for idx, (inputs_v, labels_v) in tqdm(enumerate(val_loader), total=len(val_loader)):
                inputs_v = inputs_v.cuda()
                labels_v = labels_v.cuda()
                outputs_v = model(inputs_v).cuda()
                loss_v = criterion(outputs_v, labels_v)
                running_loss_v += loss_v.item() * inputs_v.size(0)

    avg_train_loss = running_loss / (len(train_loader.dataset) * folds)
    avg_val_loss = running_loss_v / (len(val_loader.dataset) * folds)

    path = 'model/ssl.pt'

    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        torch.save({
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_dev_loss': best_loss,
            'exp_dir': 'model'
        }, f=path)
    print('Loss: {:.4f}   Val_loss: {:.4f}'.format(avg_train_loss, avg_val_loss))

    if early_stopper.early_stop(avg_val_loss):
        print("Early stopping")
        break


Epoch 1/20


100%|██████████| 157/157 [00:14<00:00, 11.09it/s]
100%|██████████| 40/40 [00:01<00:00, 22.65it/s]
100%|██████████| 157/157 [00:13<00:00, 11.33it/s]
100%|██████████| 40/40 [00:01<00:00, 22.15it/s]
100%|██████████| 157/157 [00:14<00:00, 11.15it/s]
100%|██████████| 40/40 [00:01<00:00, 22.31it/s]
100%|██████████| 157/157 [00:13<00:00, 11.33it/s]
100%|██████████| 40/40 [00:01<00:00, 22.30it/s]
100%|██████████| 157/157 [00:13<00:00, 11.40it/s]
100%|██████████| 40/40 [00:01<00:00, 22.50it/s]


Loss: 0.1175   Val_loss: 0.0591
Epoch 2/20


100%|██████████| 157/157 [00:14<00:00, 10.78it/s]
100%|██████████| 40/40 [00:01<00:00, 22.29it/s]
100%|██████████| 157/157 [00:13<00:00, 11.32it/s]
100%|██████████| 40/40 [00:01<00:00, 22.08it/s]
100%|██████████| 157/157 [00:13<00:00, 11.26it/s]
100%|██████████| 40/40 [00:01<00:00, 22.57it/s]
100%|██████████| 157/157 [00:13<00:00, 11.51it/s]
100%|██████████| 40/40 [00:01<00:00, 22.56it/s]
100%|██████████| 157/157 [00:13<00:00, 11.45it/s]
100%|██████████| 40/40 [00:01<00:00, 22.50it/s]


Loss: 0.0376   Val_loss: 0.0249
Epoch 3/20


100%|██████████| 157/157 [00:14<00:00, 11.12it/s]
100%|██████████| 40/40 [00:01<00:00, 22.41it/s]
100%|██████████| 157/157 [00:13<00:00, 11.40it/s]
100%|██████████| 40/40 [00:01<00:00, 22.50it/s]
100%|██████████| 157/157 [00:13<00:00, 11.22it/s]
100%|██████████| 40/40 [00:01<00:00, 21.73it/s]
100%|██████████| 157/157 [00:13<00:00, 11.35it/s]
100%|██████████| 40/40 [00:01<00:00, 22.46it/s]
100%|██████████| 157/157 [00:13<00:00, 11.42it/s]
100%|██████████| 40/40 [00:01<00:00, 22.35it/s]


Loss: 0.0189   Val_loss: 0.0149
Epoch 4/20


100%|██████████| 157/157 [00:14<00:00, 10.87it/s]
100%|██████████| 40/40 [00:01<00:00, 22.50it/s]
100%|██████████| 157/157 [00:13<00:00, 11.34it/s]
100%|██████████| 40/40 [00:01<00:00, 21.80it/s]
100%|██████████| 157/157 [00:13<00:00, 11.45it/s]
100%|██████████| 40/40 [00:01<00:00, 22.52it/s]
100%|██████████| 157/157 [00:13<00:00, 11.58it/s]
100%|██████████| 40/40 [00:01<00:00, 22.55it/s]
100%|██████████| 157/157 [00:13<00:00, 11.56it/s]
100%|██████████| 40/40 [00:01<00:00, 22.59it/s]


Loss: 0.0134   Val_loss: 0.0123
Epoch 5/20


100%|██████████| 157/157 [00:14<00:00, 11.14it/s]
100%|██████████| 40/40 [00:02<00:00, 18.84it/s]
100%|██████████| 157/157 [00:13<00:00, 11.51it/s]
100%|██████████| 40/40 [00:01<00:00, 22.63it/s]
100%|██████████| 157/157 [00:13<00:00, 11.65it/s]
100%|██████████| 40/40 [00:01<00:00, 22.58it/s]
100%|██████████| 157/157 [00:13<00:00, 11.29it/s]
100%|██████████| 40/40 [00:01<00:00, 22.25it/s]
100%|██████████| 157/157 [00:13<00:00, 11.56it/s]
100%|██████████| 40/40 [00:01<00:00, 22.33it/s]


Loss: 0.0117   Val_loss: 0.0102
Epoch 6/20


100%|██████████| 157/157 [00:13<00:00, 11.34it/s]
100%|██████████| 40/40 [00:01<00:00, 22.66it/s]
100%|██████████| 157/157 [00:13<00:00, 11.57it/s]
100%|██████████| 40/40 [00:01<00:00, 22.61it/s]
100%|██████████| 157/157 [00:13<00:00, 11.67it/s]
100%|██████████| 40/40 [00:01<00:00, 22.33it/s]
100%|██████████| 157/157 [00:13<00:00, 11.54it/s]
100%|██████████| 40/40 [00:01<00:00, 22.41it/s]
100%|██████████| 157/157 [00:13<00:00, 11.61it/s]
100%|██████████| 40/40 [00:01<00:00, 22.73it/s]


Loss: 0.0112   Val_loss: 0.0089
Epoch 7/20


100%|██████████| 157/157 [00:13<00:00, 11.32it/s]
100%|██████████| 40/40 [00:01<00:00, 22.68it/s]
100%|██████████| 157/157 [00:13<00:00, 11.52it/s]
100%|██████████| 40/40 [00:01<00:00, 22.06it/s]
100%|██████████| 157/157 [00:13<00:00, 11.51it/s]
100%|██████████| 40/40 [00:01<00:00, 22.59it/s]
100%|██████████| 157/157 [00:13<00:00, 11.55it/s]
100%|██████████| 40/40 [00:01<00:00, 22.49it/s]
100%|██████████| 157/157 [00:13<00:00, 11.65it/s]
100%|██████████| 40/40 [00:01<00:00, 22.68it/s]


Loss: 0.0101   Val_loss: 0.0094
Epoch 8/20


100%|██████████| 157/157 [00:13<00:00, 11.33it/s]
100%|██████████| 40/40 [00:01<00:00, 22.19it/s]
100%|██████████| 157/157 [00:13<00:00, 11.52it/s]
100%|██████████| 40/40 [00:01<00:00, 22.45it/s]
100%|██████████| 157/157 [00:13<00:00, 11.54it/s]
100%|██████████| 40/40 [00:01<00:00, 22.66it/s]
100%|██████████| 157/157 [00:13<00:00, 11.61it/s]
100%|██████████| 40/40 [00:01<00:00, 22.67it/s]
100%|██████████| 157/157 [00:13<00:00, 11.63it/s]
100%|██████████| 40/40 [00:01<00:00, 22.33it/s]


Loss: 0.0093   Val_loss: 0.0091
Epoch 9/20


100%|██████████| 157/157 [00:14<00:00, 11.02it/s]
100%|██████████| 40/40 [00:01<00:00, 22.49it/s]
100%|██████████| 157/157 [00:13<00:00, 11.57it/s]
100%|██████████| 40/40 [00:01<00:00, 22.62it/s]
100%|██████████| 157/157 [00:13<00:00, 11.72it/s]
100%|██████████| 40/40 [00:01<00:00, 22.73it/s]
100%|██████████| 157/157 [00:13<00:00, 11.62it/s]
100%|██████████| 40/40 [00:01<00:00, 22.45it/s]
100%|██████████| 157/157 [00:13<00:00, 11.58it/s]
100%|██████████| 40/40 [00:01<00:00, 22.14it/s]


Loss: 0.0083   Val_loss: 0.0085
Epoch 10/20


100%|██████████| 157/157 [00:13<00:00, 11.26it/s]
100%|██████████| 40/40 [00:01<00:00, 22.44it/s]
100%|██████████| 157/157 [00:13<00:00, 11.58it/s]
100%|██████████| 40/40 [00:01<00:00, 22.59it/s]
100%|██████████| 157/157 [00:13<00:00, 11.53it/s]
100%|██████████| 40/40 [00:01<00:00, 22.43it/s]
100%|██████████| 157/157 [00:13<00:00, 11.43it/s]
100%|██████████| 40/40 [00:01<00:00, 22.05it/s]
100%|██████████| 157/157 [00:13<00:00, 11.62it/s]
100%|██████████| 40/40 [00:01<00:00, 22.68it/s]


Loss: 0.0076   Val_loss: 0.0081
Epoch 11/20


100%|██████████| 157/157 [00:13<00:00, 11.24it/s]
100%|██████████| 40/40 [00:01<00:00, 22.67it/s]
100%|██████████| 157/157 [00:13<00:00, 11.65it/s]
100%|██████████| 40/40 [00:01<00:00, 22.51it/s]
100%|██████████| 157/157 [00:13<00:00, 11.50it/s]
100%|██████████| 40/40 [00:01<00:00, 22.19it/s]
100%|██████████| 157/157 [00:13<00:00, 11.64it/s]
100%|██████████| 40/40 [00:01<00:00, 22.60it/s]
100%|██████████| 157/157 [00:13<00:00, 11.60it/s]
100%|██████████| 40/40 [00:01<00:00, 22.55it/s]


Loss: 0.0069   Val_loss: 0.0004
Epoch 12/20


100%|██████████| 157/157 [00:13<00:00, 11.25it/s]
100%|██████████| 40/40 [00:01<00:00, 22.43it/s]
100%|██████████| 157/157 [00:13<00:00, 11.58it/s]
100%|██████████| 40/40 [00:01<00:00, 22.18it/s]
100%|██████████| 157/157 [00:13<00:00, 11.58it/s]
100%|██████████| 40/40 [00:01<00:00, 22.62it/s]
100%|██████████| 157/157 [00:13<00:00, 11.56it/s]
100%|██████████| 40/40 [00:01<00:00, 22.71it/s]
100%|██████████| 157/157 [00:13<00:00, 11.61it/s]
100%|██████████| 40/40 [00:01<00:00, 22.62it/s]


Loss: 0.0067   Val_loss: 0.0163
Epoch 13/20


100%|██████████| 157/157 [00:14<00:00, 11.18it/s]
100%|██████████| 40/40 [00:01<00:00, 21.91it/s]
100%|██████████| 157/157 [00:13<00:00, 11.68it/s]
100%|██████████| 40/40 [00:01<00:00, 22.57it/s]
100%|██████████| 157/157 [00:13<00:00, 11.57it/s]
100%|██████████| 40/40 [00:01<00:00, 22.67it/s]
100%|██████████| 157/157 [00:13<00:00, 11.58it/s]
100%|██████████| 40/40 [00:01<00:00, 22.71it/s]
100%|██████████| 157/157 [00:13<00:00, 11.50it/s]
100%|██████████| 40/40 [00:01<00:00, 21.81it/s]


Loss: 0.0066   Val_loss: 0.0046
Epoch 14/20


100%|██████████| 157/157 [00:13<00:00, 11.32it/s]
100%|██████████| 40/40 [00:01<00:00, 22.65it/s]
100%|██████████| 157/157 [00:13<00:00, 11.59it/s]
100%|██████████| 40/40 [00:01<00:00, 22.62it/s]
100%|██████████| 157/157 [00:13<00:00, 11.58it/s]
100%|██████████| 40/40 [00:01<00:00, 22.64it/s]
100%|██████████| 157/157 [00:13<00:00, 11.57it/s]
100%|██████████| 40/40 [00:01<00:00, 21.87it/s]
100%|██████████| 157/157 [00:13<00:00, 11.58it/s]
100%|██████████| 40/40 [00:01<00:00, 22.71it/s]


Loss: 0.0055   Val_loss: 0.0042
Epoch 15/20


100%|██████████| 157/157 [00:14<00:00, 11.18it/s]
100%|██████████| 40/40 [00:01<00:00, 22.64it/s]
100%|██████████| 157/157 [00:13<00:00, 11.65it/s]
100%|██████████| 40/40 [00:01<00:00, 22.58it/s]
100%|██████████| 157/157 [00:13<00:00, 11.63it/s]
100%|██████████| 40/40 [00:01<00:00, 22.15it/s]
100%|██████████| 157/157 [00:13<00:00, 11.43it/s]
100%|██████████| 40/40 [00:01<00:00, 22.75it/s]
100%|██████████| 157/157 [00:13<00:00, 11.45it/s]
100%|██████████| 40/40 [00:01<00:00, 22.62it/s]


Loss: 0.0049   Val_loss: 0.0035
Epoch 16/20


100%|██████████| 157/157 [00:14<00:00, 11.19it/s]
100%|██████████| 40/40 [00:01<00:00, 22.48it/s]
100%|██████████| 157/157 [00:13<00:00, 11.37it/s]
100%|██████████| 40/40 [00:01<00:00, 22.15it/s]
100%|██████████| 157/157 [00:13<00:00, 11.46it/s]
100%|██████████| 40/40 [00:01<00:00, 22.62it/s]
100%|██████████| 157/157 [00:13<00:00, 11.49it/s]
100%|██████████| 40/40 [00:01<00:00, 22.59it/s]
100%|██████████| 157/157 [00:13<00:00, 11.44it/s]
100%|██████████| 40/40 [00:01<00:00, 22.63it/s]


Loss: 0.0046   Val_loss: 0.0034
Epoch 17/20


100%|██████████| 157/157 [00:13<00:00, 11.22it/s]
100%|██████████| 40/40 [00:01<00:00, 21.84it/s]
100%|██████████| 157/157 [00:13<00:00, 11.41it/s]
100%|██████████| 40/40 [00:01<00:00, 22.50it/s]
100%|██████████| 157/157 [00:13<00:00, 11.47it/s]
100%|██████████| 40/40 [00:01<00:00, 22.60it/s]
100%|██████████| 157/157 [00:13<00:00, 11.49it/s]
100%|██████████| 40/40 [00:01<00:00, 22.64it/s]
100%|██████████| 157/157 [00:13<00:00, 11.58it/s]
100%|██████████| 40/40 [00:01<00:00, 22.09it/s]


Loss: 0.0043   Val_loss: 0.0042
Epoch 18/20


100%|██████████| 157/157 [00:14<00:00, 11.17it/s]
100%|██████████| 40/40 [00:01<00:00, 22.58it/s]
100%|██████████| 157/157 [00:13<00:00, 11.62it/s]
100%|██████████| 40/40 [00:01<00:00, 22.63it/s]
100%|██████████| 157/157 [00:13<00:00, 11.57it/s]
100%|██████████| 40/40 [00:01<00:00, 22.66it/s]
100%|██████████| 157/157 [00:13<00:00, 11.62it/s]
100%|██████████| 40/40 [00:01<00:00, 21.88it/s]
100%|██████████| 157/157 [00:13<00:00, 11.53it/s]
100%|██████████| 40/40 [00:01<00:00, 22.73it/s]

Loss: 0.0043   Val_loss: 0.0081
Early stopping





To check the model after SSL

In [13]:
import torch

# Define the path to your checkpoint file
path = 'model/ssl.pt'

# Load the checkpoint
checkpoint = torch.load(path)

# Extract components from the checkpoint
epoch = checkpoint['epoch']  # Epoch number when checkpoint was saved
model_state_dict = checkpoint['model']  # State dictionary of the model
optimizer_state_dict = checkpoint['optimizer']  # State dictionary of the optimizer
best_dev_loss = checkpoint['best_dev_loss']  # Best validation loss recorded
exp_dir = checkpoint['exp_dir']  # Directory or other metadata related to the experiment

# Print or inspect these components as needed
print(f'Epoch: {epoch}')
print(f'Best Validation Loss: {best_dev_loss}')
print(f'Experiment Directory: {exp_dir}')

# Example of inspecting the model state dictionary keys
print('Model State Dictionary Keys:')
for key in model_state_dict.keys():
    print(key)

# Example of inspecting optimizer state dictionary keys
print('Optimizer State Dictionary Keys:')
for key in optimizer_state_dict.keys():
    print(key)


Epoch: 10
Best Validation Loss: 0.00039534185132943095
Experiment Directory: model
Model State Dictionary Keys:
inter.0.conv1.weight
inter.0.conv1.bias
inter.0.bn1.weight
inter.0.bn1.bias
inter.0.bn1.running_mean
inter.0.bn1.running_var
inter.0.bn1.num_batches_tracked
inter.0.conv2.weight
inter.0.conv2.bias
inter.0.bn2.weight
inter.0.bn2.bias
inter.0.bn2.running_mean
inter.0.bn2.running_var
inter.0.bn2.num_batches_tracked
inter.0.conv3.weight
inter.0.conv3.bias
inter.0.bn3.weight
inter.0.bn3.bias
inter.0.bn3.running_mean
inter.0.bn3.running_var
inter.0.bn3.num_batches_tracked
inter.0.conv4.weight
inter.0.conv4.bias
inter.0.bn4.weight
inter.0.bn4.bias
inter.0.bn4.running_mean
inter.0.bn4.running_var
inter.0.bn4.num_batches_tracked
inter.0.conv5.weight
inter.0.conv5.bias
inter.0.bn5.weight
inter.0.bn5.bias
inter.0.bn5.running_mean
inter.0.bn5.running_var
inter.0.bn5.num_batches_tracked
en1.0.weight
en1.0.bias
en1.1.weight
en1.1.bias
en1.1.running_mean
en1.1.running_var
en1.1.num_batches_

Actual Training Part

In [15]:
import os
import re
import pandas as pd
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader

bs = 256
length = 1250
epochs = 20
folds = 5


# Initialize the model and load pretrained weights if available
model = Unet((bs, 1, length)).cuda()
path = 'model/ssl.pt'  # Path to your pretrained model checkpoint
checkpoint = torch.load(path)
pretrained_dict = {k: v for k, v in checkpoint['model'].items() if re.search('^e|^i', k)}

model_dict = model.state_dict()
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = torch.nn.SmoothL1Loss()
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 200], gamma=0.1)
scaler = torch.cuda.amp.GradScaler()
early_stopper = EarlyStopper(patience=100)  # Adjust patience according to your needs

best_loss = 1000

for epoch in range(epochs):
    model.train()
    print('Epoch {}/{}'.format(epoch + 1, epochs))

    running_loss = 0.0
    running_loss_v = 0.0

    for fold in range(1, folds + 1):
        train_loader = DataLoader(BPdatasetv2(fold-1, train=True), batch_size=bs, shuffle=True)
        val_loader = DataLoader(BPdatasetv2(fold-1, val=True), batch_size=bs, shuffle=False)

        for idx, (inputs, output) in tqdm(enumerate(train_loader), total=len(train_loader)):
            inputs = inputs.cuda()
            output = output.cuda()

            optimizer.zero_grad()

            with torch.cuda.amp.autocast():
                pred = model(inputs)
                loss = criterion(pred, output)

            scaler.scale(loss).backward()
            running_loss += loss.item() * inputs.size(0)
            scaler.step(optimizer)
            scaler.update()

        scheduler.step()

        # VALIDATION
        model.eval()
        with torch.no_grad():
            for idx, (inputs_v, labels_v) in tqdm(enumerate(val_loader), total=len(val_loader)):
                inputs_v = inputs_v.cuda()
                labels_v = labels_v.cuda()
                outputs_v = model(inputs_v).cuda()
                loss_v = criterion(outputs_v, labels_v)
                running_loss_v += loss_v.item() * inputs_v.size(0)

    avg_train_loss = running_loss / (len(train_loader.dataset) * folds)
    avg_val_loss = running_loss_v / (len(val_loader.dataset) * folds)

    path = 'final.pt'

    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        torch.save({
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_dev_loss': best_loss,
            'exp_dir': 'model'
        }, f=path)
    print('Loss: {:.4f}   Val_loss: {:.4f}'.format(avg_train_loss, avg_val_loss))

    if early_stopper.early_stop(avg_val_loss):
        print("Early stopping")
        break


Epoch 1/20


100%|██████████| 157/157 [00:14<00:00, 11.14it/s]
100%|██████████| 40/40 [00:01<00:00, 22.61it/s]
100%|██████████| 157/157 [00:13<00:00, 11.55it/s]
100%|██████████| 40/40 [00:01<00:00, 22.44it/s]
100%|██████████| 157/157 [00:13<00:00, 11.54it/s]
100%|██████████| 40/40 [00:01<00:00, 22.34it/s]
100%|██████████| 157/157 [00:13<00:00, 11.27it/s]
100%|██████████| 40/40 [00:01<00:00, 22.14it/s]
100%|██████████| 157/157 [00:13<00:00, 11.40it/s]
100%|██████████| 40/40 [00:01<00:00, 22.82it/s]


Loss: 0.0177   Val_loss: 0.0259
Epoch 2/20


100%|██████████| 157/157 [00:14<00:00, 11.20it/s]
100%|██████████| 40/40 [00:01<00:00, 22.73it/s]
100%|██████████| 157/157 [00:13<00:00, 11.48it/s]
100%|██████████| 40/40 [00:01<00:00, 22.54it/s]
100%|██████████| 157/157 [00:13<00:00, 11.48it/s]
100%|██████████| 40/40 [00:01<00:00, 22.03it/s]
100%|██████████| 157/157 [00:13<00:00, 11.54it/s]
100%|██████████| 40/40 [00:01<00:00, 22.66it/s]
100%|██████████| 157/157 [00:13<00:00, 11.54it/s]
100%|██████████| 40/40 [00:01<00:00, 22.66it/s]


Loss: 0.0137   Val_loss: 0.0138
Epoch 3/20


100%|██████████| 157/157 [00:14<00:00, 11.20it/s]
100%|██████████| 40/40 [00:01<00:00, 22.71it/s]
100%|██████████| 157/157 [00:13<00:00, 11.55it/s]
100%|██████████| 40/40 [00:01<00:00, 21.85it/s]
100%|██████████| 157/157 [00:13<00:00, 11.44it/s]
100%|██████████| 40/40 [00:01<00:00, 22.56it/s]
100%|██████████| 157/157 [00:13<00:00, 11.54it/s]
100%|██████████| 40/40 [00:01<00:00, 22.70it/s]
100%|██████████| 157/157 [00:13<00:00, 11.55it/s]
100%|██████████| 40/40 [00:01<00:00, 22.66it/s]


Loss: 0.0112   Val_loss: 0.0092
Epoch 4/20


100%|██████████| 157/157 [00:14<00:00, 11.00it/s]
100%|██████████| 40/40 [00:01<00:00, 22.03it/s]
100%|██████████| 157/157 [00:13<00:00, 11.60it/s]
100%|██████████| 40/40 [00:01<00:00, 22.68it/s]
100%|██████████| 157/157 [00:13<00:00, 11.49it/s]
100%|██████████| 40/40 [00:01<00:00, 22.67it/s]
100%|██████████| 157/157 [00:13<00:00, 11.59it/s]
100%|██████████| 40/40 [00:01<00:00, 22.68it/s]
100%|██████████| 157/157 [00:13<00:00, 11.50it/s]
100%|██████████| 40/40 [00:01<00:00, 22.06it/s]


Loss: 0.0090   Val_loss: 0.0088
Epoch 5/20


100%|██████████| 157/157 [00:13<00:00, 11.23it/s]
100%|██████████| 40/40 [00:01<00:00, 22.69it/s]
100%|██████████| 157/157 [00:13<00:00, 11.58it/s]
100%|██████████| 40/40 [00:01<00:00, 22.66it/s]
100%|██████████| 157/157 [00:13<00:00, 11.57it/s]
100%|██████████| 40/40 [00:01<00:00, 22.70it/s]
100%|██████████| 157/157 [00:13<00:00, 11.47it/s]
100%|██████████| 40/40 [00:01<00:00, 21.86it/s]
100%|██████████| 157/157 [00:13<00:00, 11.59it/s]
100%|██████████| 40/40 [00:01<00:00, 22.59it/s]


Loss: 0.0088   Val_loss: 0.0080
Epoch 6/20


100%|██████████| 157/157 [00:13<00:00, 11.21it/s]
100%|██████████| 40/40 [00:01<00:00, 22.61it/s]
100%|██████████| 157/157 [00:13<00:00, 11.55it/s]
100%|██████████| 40/40 [00:01<00:00, 22.52it/s]
100%|██████████| 157/157 [00:13<00:00, 11.54it/s]
100%|██████████| 40/40 [00:01<00:00, 22.32it/s]
100%|██████████| 157/157 [00:13<00:00, 11.55it/s]
100%|██████████| 40/40 [00:01<00:00, 22.65it/s]
100%|██████████| 157/157 [00:13<00:00, 11.40it/s]
100%|██████████| 40/40 [00:01<00:00, 22.68it/s]


Loss: 0.0076   Val_loss: 0.0071
Epoch 7/20


100%|██████████| 157/157 [00:13<00:00, 11.27it/s]
100%|██████████| 40/40 [00:01<00:00, 22.49it/s]
100%|██████████| 157/157 [00:13<00:00, 11.45it/s]
100%|██████████| 40/40 [00:01<00:00, 22.43it/s]
100%|██████████| 157/157 [00:13<00:00, 11.54it/s]
100%|██████████| 40/40 [00:01<00:00, 22.61it/s]
100%|██████████| 157/157 [00:13<00:00, 11.53it/s]
100%|██████████| 40/40 [00:01<00:00, 22.62it/s]
100%|██████████| 157/157 [00:13<00:00, 11.52it/s]
100%|██████████| 40/40 [00:01<00:00, 22.47it/s]


Loss: 0.0071   Val_loss: 0.0067
Epoch 8/20


100%|██████████| 157/157 [00:14<00:00, 11.05it/s]
100%|██████████| 40/40 [00:01<00:00, 22.61it/s]
100%|██████████| 157/157 [00:13<00:00, 11.57it/s]
100%|██████████| 40/40 [00:01<00:00, 22.73it/s]
100%|██████████| 157/157 [00:13<00:00, 11.34it/s]
100%|██████████| 40/40 [00:01<00:00, 22.61it/s]
100%|██████████| 157/157 [00:13<00:00, 11.55it/s]
100%|██████████| 40/40 [00:01<00:00, 21.96it/s]
100%|██████████| 157/157 [00:13<00:00, 11.40it/s]
100%|██████████| 40/40 [00:01<00:00, 22.59it/s]


Loss: 0.0064   Val_loss: 0.0060
Epoch 9/20


100%|██████████| 157/157 [00:13<00:00, 11.24it/s]
100%|██████████| 40/40 [00:01<00:00, 22.64it/s]
100%|██████████| 157/157 [00:13<00:00, 11.49it/s]
100%|██████████| 40/40 [00:01<00:00, 22.65it/s]
100%|██████████| 157/157 [00:13<00:00, 11.43it/s]
100%|██████████| 40/40 [00:01<00:00, 21.99it/s]
100%|██████████| 157/157 [00:13<00:00, 11.57it/s]
100%|██████████| 40/40 [00:01<00:00, 22.61it/s]
100%|██████████| 157/157 [00:13<00:00, 11.58it/s]
100%|██████████| 40/40 [00:01<00:00, 22.61it/s]


Loss: 0.0061   Val_loss: 0.0059
Epoch 10/20


100%|██████████| 157/157 [00:14<00:00, 11.13it/s]
100%|██████████| 40/40 [00:01<00:00, 22.72it/s]
100%|██████████| 157/157 [00:13<00:00, 11.50it/s]
100%|██████████| 40/40 [00:01<00:00, 21.83it/s]
100%|██████████| 157/157 [00:13<00:00, 11.53it/s]
100%|██████████| 40/40 [00:01<00:00, 22.73it/s]
100%|██████████| 157/157 [00:13<00:00, 11.56it/s]
100%|██████████| 40/40 [00:01<00:00, 22.72it/s]
100%|██████████| 157/157 [00:13<00:00, 11.54it/s]
100%|██████████| 40/40 [00:01<00:00, 22.62it/s]


Loss: 0.0058   Val_loss: 0.0055
Epoch 11/20


100%|██████████| 157/157 [00:14<00:00, 11.01it/s]
100%|██████████| 40/40 [00:01<00:00, 22.16it/s]
100%|██████████| 157/157 [00:13<00:00, 11.55it/s]
100%|██████████| 40/40 [00:01<00:00, 22.68it/s]
100%|██████████| 157/157 [00:13<00:00, 11.51it/s]
100%|██████████| 40/40 [00:01<00:00, 22.68it/s]
100%|██████████| 157/157 [00:13<00:00, 11.41it/s]
100%|██████████| 40/40 [00:01<00:00, 22.72it/s]
100%|██████████| 157/157 [00:13<00:00, 11.47it/s]
100%|██████████| 40/40 [00:01<00:00, 22.21it/s]


Loss: 0.0056   Val_loss: 0.0053
Epoch 12/20


100%|██████████| 157/157 [00:13<00:00, 11.27it/s]
100%|██████████| 40/40 [00:01<00:00, 22.81it/s]
100%|██████████| 157/157 [00:13<00:00, 11.56it/s]
100%|██████████| 40/40 [00:01<00:00, 22.66it/s]
100%|██████████| 157/157 [00:13<00:00, 11.57it/s]
100%|██████████| 40/40 [00:01<00:00, 22.35it/s]
100%|██████████| 157/157 [00:13<00:00, 11.53it/s]
100%|██████████| 40/40 [00:01<00:00, 22.29it/s]
100%|██████████| 157/157 [00:13<00:00, 11.52it/s]
100%|██████████| 40/40 [00:01<00:00, 22.62it/s]


Loss: 0.0056   Val_loss: 0.0071
Epoch 13/20


100%|██████████| 157/157 [00:13<00:00, 11.26it/s]
100%|██████████| 40/40 [00:01<00:00, 22.65it/s]
100%|██████████| 157/157 [00:13<00:00, 11.45it/s]
100%|██████████| 40/40 [00:01<00:00, 22.44it/s]
100%|██████████| 157/157 [00:13<00:00, 11.52it/s]
100%|██████████| 40/40 [00:01<00:00, 22.21it/s]
100%|██████████| 157/157 [00:13<00:00, 11.53it/s]
100%|██████████| 40/40 [00:01<00:00, 22.66it/s]
100%|██████████| 157/157 [00:13<00:00, 11.56it/s]
100%|██████████| 40/40 [00:01<00:00, 22.59it/s]


Loss: 0.0059   Val_loss: 0.0058
Epoch 14/20


100%|██████████| 157/157 [00:14<00:00, 11.20it/s]
100%|██████████| 40/40 [00:01<00:00, 22.38it/s]
100%|██████████| 157/157 [00:13<00:00, 11.48it/s]
100%|██████████| 40/40 [00:01<00:00, 22.33it/s]
100%|██████████| 157/157 [00:13<00:00, 11.58it/s]
100%|██████████| 40/40 [00:01<00:00, 22.53it/s]
100%|██████████| 157/157 [00:13<00:00, 11.59it/s]
100%|██████████| 40/40 [00:01<00:00, 22.64it/s]
100%|██████████| 157/157 [00:13<00:00, 11.52it/s]
100%|██████████| 40/40 [00:01<00:00, 22.52it/s]


Loss: 0.0054   Val_loss: 0.0054
Epoch 15/20


100%|██████████| 157/157 [00:14<00:00, 10.99it/s]
100%|██████████| 40/40 [00:01<00:00, 22.35it/s]
100%|██████████| 157/157 [00:13<00:00, 11.55it/s]
100%|██████████| 40/40 [00:01<00:00, 22.68it/s]
100%|██████████| 157/157 [00:13<00:00, 11.56it/s]
100%|██████████| 40/40 [00:01<00:00, 22.66it/s]
100%|██████████| 157/157 [00:13<00:00, 11.62it/s]
100%|██████████| 40/40 [00:01<00:00, 22.34it/s]
100%|██████████| 157/157 [00:13<00:00, 11.49it/s]
100%|██████████| 40/40 [00:01<00:00, 22.11it/s]


Loss: 0.0052   Val_loss: 0.0050
Epoch 16/20


100%|██████████| 157/157 [00:13<00:00, 11.29it/s]
100%|██████████| 40/40 [00:01<00:00, 22.74it/s]
100%|██████████| 157/157 [00:13<00:00, 11.58it/s]
100%|██████████| 40/40 [00:01<00:00, 22.75it/s]
100%|██████████| 157/157 [00:13<00:00, 11.60it/s]
100%|██████████| 40/40 [00:01<00:00, 22.46it/s]
100%|██████████| 157/157 [00:13<00:00, 11.55it/s]
100%|██████████| 40/40 [00:01<00:00, 22.13it/s]
100%|██████████| 157/157 [00:13<00:00, 11.47it/s]
100%|██████████| 40/40 [00:01<00:00, 22.66it/s]


Loss: 0.0052   Val_loss: 0.0057
Epoch 17/20


100%|██████████| 157/157 [00:13<00:00, 11.30it/s]
100%|██████████| 40/40 [00:01<00:00, 22.73it/s]
100%|██████████| 157/157 [00:13<00:00, 11.61it/s]
100%|██████████| 40/40 [00:01<00:00, 22.48it/s]
100%|██████████| 157/157 [00:13<00:00, 11.50it/s]
100%|██████████| 40/40 [00:01<00:00, 22.17it/s]
100%|██████████| 157/157 [00:13<00:00, 11.50it/s]
100%|██████████| 40/40 [00:01<00:00, 22.62it/s]
100%|██████████| 157/157 [00:13<00:00, 11.60it/s]
100%|██████████| 40/40 [00:01<00:00, 22.66it/s]


Loss: 0.0052   Val_loss: 0.0051
Epoch 18/20


100%|██████████| 157/157 [00:13<00:00, 11.30it/s]
100%|██████████| 40/40 [00:01<00:00, 22.66it/s]
100%|██████████| 157/157 [00:13<00:00, 11.60it/s]
100%|██████████| 40/40 [00:02<00:00, 14.66it/s]
100%|██████████| 157/157 [00:15<00:00,  9.94it/s]
100%|██████████| 40/40 [00:01<00:00, 22.53it/s]
100%|██████████| 157/157 [00:13<00:00, 11.45it/s]
100%|██████████| 40/40 [00:01<00:00, 22.39it/s]
100%|██████████| 157/157 [00:13<00:00, 11.45it/s]
100%|██████████| 40/40 [00:01<00:00, 22.07it/s]


Loss: 0.0053   Val_loss: 0.0050
Epoch 19/20


100%|██████████| 157/157 [00:13<00:00, 11.24it/s]
100%|██████████| 40/40 [00:01<00:00, 22.80it/s]
100%|██████████| 157/157 [00:13<00:00, 11.56it/s]
100%|██████████| 40/40 [00:01<00:00, 22.78it/s]
100%|██████████| 157/157 [00:13<00:00, 11.62it/s]
100%|██████████| 40/40 [00:01<00:00, 22.38it/s]
100%|██████████| 157/157 [00:13<00:00, 11.50it/s]
100%|██████████| 40/40 [00:01<00:00, 22.21it/s]
100%|██████████| 157/157 [00:13<00:00, 11.58it/s]
100%|██████████| 40/40 [00:01<00:00, 22.73it/s]


Loss: 0.0051   Val_loss: 0.0049
Epoch 20/20


100%|██████████| 157/157 [00:14<00:00, 11.21it/s]
100%|██████████| 40/40 [00:01<00:00, 22.63it/s]
100%|██████████| 157/157 [00:13<00:00, 11.48it/s]
100%|██████████| 40/40 [00:01<00:00, 22.32it/s]
100%|██████████| 157/157 [00:13<00:00, 11.47it/s]
100%|██████████| 40/40 [00:01<00:00, 22.16it/s]
100%|██████████| 157/157 [00:13<00:00, 11.63it/s]
100%|██████████| 40/40 [00:01<00:00, 22.68it/s]
100%|██████████| 157/157 [00:13<00:00, 11.46it/s]
100%|██████████| 40/40 [00:01<00:00, 22.68it/s]

Loss: 0.0049   Val_loss: 0.0052





To check final trained model

In [16]:
import torch

# Define the path to your checkpoint file
path = 'final.pt'

# Load the checkpoint
checkpoint = torch.load(path)

# Extract components from the checkpoint
epoch = checkpoint['epoch']  # Epoch number when checkpoint was saved
model_state_dict = checkpoint['model']  # State dictionary of the model
optimizer_state_dict = checkpoint['optimizer']  # State dictionary of the optimizer
best_dev_loss = checkpoint['best_dev_loss']  # Best validation loss recorded
exp_dir = checkpoint['exp_dir']  # Directory or other metadata related to the experiment

# Print or inspect these components as needed
print(f'Epoch: {epoch}')
print(f'Best Validation Loss: {best_dev_loss}')
print(f'Experiment Directory: {exp_dir}')

# Example of inspecting the model state dictionary keys
print('Model State Dictionary Keys:')
for key in model_state_dict.keys():
    print(key)

# Example of inspecting optimizer state dictionary keys
print('Optimizer State Dictionary Keys:')
for key in optimizer_state_dict.keys():
    print(key)


Epoch: 18
Best Validation Loss: 0.00494062293574214
Experiment Directory: model
Model State Dictionary Keys:
inter.0.conv1.weight
inter.0.conv1.bias
inter.0.bn1.weight
inter.0.bn1.bias
inter.0.bn1.running_mean
inter.0.bn1.running_var
inter.0.bn1.num_batches_tracked
inter.0.conv2.weight
inter.0.conv2.bias
inter.0.bn2.weight
inter.0.bn2.bias
inter.0.bn2.running_mean
inter.0.bn2.running_var
inter.0.bn2.num_batches_tracked
inter.0.conv3.weight
inter.0.conv3.bias
inter.0.bn3.weight
inter.0.bn3.bias
inter.0.bn3.running_mean
inter.0.bn3.running_var
inter.0.bn3.num_batches_tracked
inter.0.conv4.weight
inter.0.conv4.bias
inter.0.bn4.weight
inter.0.bn4.bias
inter.0.bn4.running_mean
inter.0.bn4.running_var
inter.0.bn4.num_batches_tracked
inter.0.conv5.weight
inter.0.conv5.bias
inter.0.bn5.weight
inter.0.bn5.bias
inter.0.bn5.running_mean
inter.0.bn5.running_var
inter.0.bn5.num_batches_tracked
en1.0.weight
en1.0.bias
en1.1.weight
en1.1.bias
en1.1.running_mean
en1.1.running_var
en1.1.num_batches_tra

Test trained model on remaining (independent) testing set

In [17]:
import pickle
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader

import re

bs = 256

model = Unet((256,1,1250)).cuda()
path = 'final.pt'
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model'])

pick_path = 'output.p'

test = torch.utils.data.DataLoader(BPdatasetv2(0, train = False, val = False,  test = True), batch_size=bs)

temp1 = []
model.eval()
with torch.no_grad():
    for idx,(inputs,labels) in tqdm(enumerate(test),total=len(test),  disable=True):
        inputs = inputs.cuda()
        labels = labels.cuda()
        outputs_v= model(inputs).cuda()

        temp1.extend(outputs_v)

temp1 = torch.stack(temp1)
with open(pick_path,'wb') as f:
    pickle.dump(temp1.cpu().detach().numpy(), f)

Check the final predictions

In [18]:
import pickle
import numpy as np
import torch

# Define the path to your output file
pick_path = 'output.p'

# Load the pickle file
with open(pick_path, 'rb') as f:
    output_data = pickle.load(f)

# Check the type of the loaded data
print(f'Type of output data: {type(output_data)}')

# If it's a numpy array, inspect its shape and some sample data
if isinstance(output_data, np.ndarray):
    print(f'Shape of output data: {output_data.shape}')
    print('Sample data:')
    print(output_data[:5])  # Print the first 5 entries

# If it's a list of tensors, convert to a numpy array and inspect
elif isinstance(output_data, list) and isinstance(output_data[0], torch.Tensor):
    output_data = torch.stack(output_data).cpu().numpy()
    print(f'Shape of output data: {output_data.shape}')
    print('Sample data:')
    print(output_data[:5])  # Print the first 5 entries

# If it's another data type, provide appropriate inspection
else:
    print(f'Unexpected data type: {type(output_data)}')
    print('Sample data:')
    print(output_data[:5])  # Print the first 5 entries


Type of output data: <class 'numpy.ndarray'>
Shape of output data: (10000, 1, 1250)
Sample data:
[[[0.46805903 0.50473255 0.5322795  ... 0.31806147 0.35078776 0.33633965]]

 [[0.19793008 0.18881768 0.20094647 ... 0.28478456 0.26610154 0.25896198]]

 [[0.16194268 0.15142511 0.14950547 ... 0.37587768 0.35324025 0.34329766]]

 [[0.5091641  0.50857586 0.4942793  ... 0.29647446 0.31255406 0.30774057]]

 [[0.18779568 0.20288466 0.20725879 ... 0.52020985 0.5465441  0.53120536]]]


In [19]:
import pickle
import numpy as np

# Define the path to the output file
pick_path = 'output.p'

# Load the predictions from the pickle file
with open(pick_path, 'rb') as f:
    Y_pred = pickle.load(f)

# Check the type and shape of the loaded predictions
print(f'Type of Y_pred: {type(Y_pred)}')
print(f'Shape of Y_pred: {Y_pred.shape}')


Type of Y_pred: <class 'numpy.ndarray'>
Shape of Y_pred: (10000, 1, 1250)


Evaluations standarts

In [39]:
import pickle
import os
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score


def evaluate_BHS_Standard(filename):
    """
        Evaluates PPG2ABP based on
        BHS Standard Metric
    """

    def BHS_metric(err):
        """
        Computes the BHS Standard metric

        Arguments:
            err {array} -- array of absolute error

        Returns:
            tuple -- tuple of percentage of samples with <=5 mmHg, <=10 mmHg and <=15 mmHg error
        """

        leq5 = 0
        leq10 = 0
        leq15 = 0

        for i in range(len(err)):
            if abs(err[i]) <= 5:
                leq5 += 1
                leq10 += 1
                leq15 += 1
            elif abs(err[i]) <= 10:
                leq10 += 1
                leq15 += 1
            elif abs(err[i]) <= 15:
                leq15 += 1

        return (leq5 * 100.0 / len(err), leq10 * 100.0 / len(err), leq15 * 100.0 / len(err))

    def calcError(Ytrue, Ypred, max_abp, min_abp):
        """
        Calculates the absolute error of sbp, dbp, map etc.

        Arguments:
            Ytrue {array} -- ground truth
            Ypred {array} -- predicted
            max_abp {float} -- max value of abp signal
            min_abp {float} -- min value of abp signal

        Returns:
            tuple -- tuple of abs. errors of sbp, dbp and map calculation
        """

        sbps = []
        dbps = []
        maps = []

        for i in range(len(Ytrue)):
            y_t = Ytrue[i].ravel()
            y_p = Ypred[i].ravel()

            y_t = y_t * (max_abp - min_abp)
            y_p = y_p * (max_abp - min_abp)

            dbps.append(abs(min(y_t) - min(y_p)))
            sbps.append(abs(max(y_t) - max(y_p)))
            maps.append(abs(np.mean(y_t) - np.mean(y_p)))

        return (sbps, dbps, maps)

    dt = pickle.load(open(os.path.join('data', 'test.p'), 'rb'))  # loading test data
    Y_test = dt['Y_test']

    dt = pickle.load(open('meta.p', 'rb'))  # loading meta data
    max_abp = dt['max_abp']
    min_abp = dt['min_abp']

    Y_pred = pickle.load(open(filename, 'rb'))  # loading prediction

    (sbps, dbps, maps) = calcError(Y_test, Y_pred, max_abp, min_abp)  # compute errors

    sbp_percent = BHS_metric(sbps)  # compute BHS metric for sbp
    dbp_percent = BHS_metric(dbps)  # compute BHS metric for dbp
    map_percent = BHS_metric(maps)  # compute BHS metric for map

    print('----------------------------')
    print('|        BHS-Metric        |')
    print('----------------------------')

    print('----------------------------------------')
    print('|     | <= 5mmHg | <=10mmHg | <=15mmHg |')
    print('----------------------------------------')
    print('| DBP |  {} %  |  {} %  |  {} %  |'.format(round(dbp_percent[0], 2), round(dbp_percent[1], 2), round(dbp_percent[2], 2)))
    print('| MAP |  {} %  |  {} %  |  {} %  |'.format(round(map_percent[0], 2), round(map_percent[1], 2), round(map_percent[2], 2)))
    print('| SBP |  {} %  |  {} %  |  {} %  |'.format(round(sbp_percent[0], 2), round(sbp_percent[1], 2), round(sbp_percent[2], 2)))
    print('----------------------------------------')


def evaluate_AAMI_Standard(filename):
    """
        Evaluate PPG2ABP using AAMI Standard metric
    """

    def calcErrorAAMI(Ypred, Ytrue, max_abp, min_abp):
        """
        Calculates error of sbp, dbp, map for AAMI standard computation

        Arguments:
            Ytrue {array} -- ground truth
            Ypred {array} -- predicted
            max_abp {float} -- max value of abp signal
            min_abp {float} -- min value of abp signal

        Returns:
            tuple -- tuple of errors of sbp, dbp and map calculation
        """

        sbps = []
        dbps = []
        maps = []

        for i in range(len(Ytrue)):
            y_t = Ytrue[i].ravel()
            y_p = Ypred[i].ravel()

            y_t = y_t * (max_abp - min_abp)
            y_p = y_p * (max_abp - min_abp)

            dbps.append(min(y_p) - min(y_t))
            sbps.append(max(y_p) - max(y_t))
            maps.append(np.mean(y_p) - np.mean(y_t))

        return (sbps, dbps, maps)

    dt = pickle.load(open(os.path.join('data', 'test.p'), 'rb'))  # loading test data
    Y_test = dt['Y_test']

    dt = pickle.load(open(os.path.join('data', 'meta.p'), 'rb'))  # loading metadata
    max_abp = dt['max_abp']
    min_abp = dt['min_abp']

    Y_pred = pickle.load(open(filename, 'rb'))  # loading prediction

    (sbps, dbps, maps) = calcErrorAAMI(Y_test, Y_pred, max_abp, min_abp)  # compute error

    print('---------------------')
    print('|   AAMI Standard   |')
    print('---------------------')

    print('-----------------------')
    print('|     |  ME   |  STD  |')
    print('-----------------------')
    print('| DBP | {} | {} |'.format(round(np.mean(dbps), 3), round(np.std(dbps), 3)))
    print('| MAP | {} | {} |'.format(round(np.mean(maps), 3), round(np.std(maps), 3)))
    print('| SBP | {} | {} |'.format(round(np.mean(sbps), 3), round(np.std(sbps), 3)))
    print('-----------------------')


def evaluate_metrics(filename):
    def calcError(Ytrue, Ypred, max_abp, min_abp):
        sbp_t = []
        sbp_p = []
        dbp_t = []
        dbp_p = []
        map_t = []
        map_p = []

        for i in range(len(Ytrue)):
            y_t = Ytrue[i].ravel()
            y_p = Ypred[i].ravel()

            y_t = y_t * (max_abp - min_abp)
            y_p = y_p * (max_abp - min_abp)

            sbp_p.append(abs(max(y_p)))
            dbp_p.append(abs(min(y_p)))
            map_p.append(abs(np.mean(y_p)))
            sbp_t.append(abs(max(y_t)))
            dbp_t.append(abs(min(y_t)))
            map_t.append(abs(np.mean(y_t)))

        print("SBP")
        print("Mean Absolute Error : ", round(mean_absolute_error(sbp_t, sbp_p), 3))
        print("Root Mean Squared Error : ", round(mean_squared_error(sbp_t, sbp_p, squared=False), 3))
        print("R2 : ", r2_score(sbp_t, sbp_p))

        print("")

        print("DBP")
        print("Mean Absolute Error : ", round(mean_absolute_error(dbp_t, dbp_p), 3))
        print("Root Mean Squared Error : ", round(mean_squared_error(dbp_t, dbp_p, squared=False), 3))
        print("R2 : ", r2_score(dbp_t, dbp_p))

        print("")

        print("MAP")
        print("Mean Absolute Error : ", mean_absolute_error(map_t, map_p))
        print("Root Mean Squared Error : ", round(mean_squared_error(map_t, map_p, squared=False), 2))
        print("R2 : ", r2_score(map_t, map_p))

        print("------------------------------------------------------------------------")

    dt = pickle.load(open(os.path.join('data', 'test.p'), 'rb'))  # loading test data
    Y_test = dt['Y_test']

    dt = pickle.load(open('meta.p', 'rb'))  # loading meta data
    max_abp = dt['max_abp']
    min_abp = dt['min_abp']

    Y_pred = pickle.load(open(filename, 'rb'))  # loading prediction
    calcError(Y_test, Y_pred, max_abp, min_abp)


evaluate_BHS_Standard('output.p')


----------------------------
|        BHS-Metric        |
----------------------------
----------------------------------------
|     | <= 5mmHg | <=10mmHg | <=15mmHg |
----------------------------------------
| DBP |  49.05 %  |  82.52 %  |  95.32 %  |
| MAP |  48.29 %  |  78.28 %  |  92.55 %  |
| SBP |  28.14 %  |  53.55 %  |  71.57 %  |
----------------------------------------


In [40]:
evaluate_AAMI_Standard('output.p')

---------------------
|   AAMI Standard   |
---------------------
-----------------------
|     |  ME   |  STD  |
-----------------------
| DBP | -1.312 | 7.827 |
| MAP | -0.193 | 8.636 |
| SBP | -0.029 | 14.945 |
-----------------------


In [41]:
evaluate_metrics('output.p')

SBP
Mean Absolute Error :  11.584
Root Mean Squared Error :  14.945
R2 :  0.30646496557721925

DBP
Mean Absolute Error :  6.113
Root Mean Squared Error :  7.935
R2 :  0.21471609588969098

MAP
Mean Absolute Error :  6.554343162021774
Root Mean Squared Error :  8.64
R2 :  0.3982593020463836
------------------------------------------------------------------------
