In [1]:
cd ..

/home/dmitriishubin/Desktop/physionet-challenge-2020


In [2]:
import pandas as pd
import numpy as np
import seaborn as sns
import json
import os
import gc
from tqdm import tqdm
import matplotlib.pyplot as plt

from kardioml.data.p_t_wave_detector import PTWaveDetection

In [3]:
%matplotlib qt

In [62]:
DATASETS = ['A','B','C','D','E','F']
#DATASETS = ['A','B','D','E'] #current for training
#DATASETS = ['D']
datalist = []

for dataset in DATASETS:
    files = [i[:-4] for i in os.listdir(f'./data/{dataset}/formatted/') if i.find('.npy')!=-1]
    for file in files:
        datalist.append(f'./data/{dataset}/formatted/'+file)



In [63]:
len(datalist)

43101

# Check the length fits json and numpy

In [64]:
for data in tqdm(datalist):
    signal = np.load(data+'.npy')
    meta = json.load(open(data+'.json'))
    
    if meta['shape'][0] != signal.shape[0]:
        break
    

100%|██████████| 43101/43101 [47:57<00:00, 14.98it/s]      


# Check the length distribution, all datasets

In [None]:
length_list = []
exclusions = []
exclusions_labels = []
exclusions_digits = []

for data in tqdm(datalist):
    meta = json.load(open(data+'.json'))
    if meta['labels_training_merged'] is None:
        continue
    if meta['shape'][0] > 38000:
        file_name = data.split('/')     
        exclusions.append(file_name[-1])
        exclusions_labels.append(meta['labels_full'])
        exclusions_digits.append(meta['labels_training_merged'])
    length_list.append(meta['shape'][0])

    
sns.distplot(length_list)

In [None]:
hist = np.histogram(length_list,2000)
plt.plot(hist[1][:-1][:100],hist[0])

In [None]:
np.percentile(length_list,99)

# Check distribution of classes, calculate weights

In [18]:
labels = []

for data in tqdm(datalist):
    meta = json.load(open(data+'.json'))
    if meta['shape'][0] > 38000:
        continue
    elif meta['labels_training_merged'] is None:
        continue
    else:
        labels.append(meta['labels_training_merged'])
    

    
labels = np.array(labels)

100%|██████████| 43027/43027 [03:14<00:00, 221.46it/s]


In [16]:
weights = np.sum(labels,axis=0)
weights = weights/labels.shape[0]
weights[np.where(weights == 0.)] = 1
weights = 1/weights
weights = weights.tolist()
weights

[17.330229671011793,
 9.83409651285664,
 228.84426229508196,
 109.05859375,
 11.344575375863469,
 23.304674457429048,
 17.170356703567037,
 5.42537893509522,
 34.768368617683684,
 153.4010989010989,
 35.206809583858764,
 93.37458193979933,
 22.940838126540672,
 157.73446327683618,
 82.11470588235294,
 228.84426229508196,
 50.854280510018214,
 81.15988372093022,
 1.0,
 35.70204603580562,
 41.11782032400589,
 1.4694983946523503,
 24.994628469113696,
 1.0,
 11.795099281791297,
 93.37458193979933,
 1.0]

# Checking PT detector perfomance

In [38]:
data_range = np.arange(len(datalist))[0:11].tolist()

datalist_sub = [i for index,i in enumerate(datalist) if index in data_range]


for data in tqdm(datalist_sub):
    meta = json.load(open(data+'.json'))
    if meta['shape'][0] > 38000:
        continue
    elif meta['labels_training_merged'] is None:
        continue
    else:
        signal = np.load(data+'.npy')
        
        
        fig = plt.figure(figsize=(20,10))
        for i in range(12):
            channel = signal[:,i]
            plt.plot(channel+i*500)
            plt.plot(meta['p_waves'][i],channel[meta['p_waves'][i]]+i*500,'*')
            plt.plot(meta['t_waves'][i],channel[meta['t_waves'][i]]+i*500,'*')
        plt.show()
        plt.title(str(meta['labels_full'])+' | '+data)
        print('===================')
        print(data)
        print(meta['labels_full'])
        print('===================')
        if data == '../data/A/formatted/A2091':
            break

 27%|██▋       | 3/11 [00:00<00:00, 15.29it/s]

./data/F/formatted/E06964
['ventricular premature beats', 'prolonged qt interval']
./data/F/formatted/E00805
['sinus rhythm']


 82%|████████▏ | 9/11 [00:00<00:00, 19.47it/s]

./data/F/formatted/E09754
['left axis deviation']
./data/F/formatted/E09218
['left axis deviation']
./data/F/formatted/E07023
['1st degree av block', 'sinus bradycardia']
./data/F/formatted/E08959
['premature atrial contraction', 'premature atrial contraction']


100%|██████████| 11/11 [00:00<00:00, 20.48it/s]

./data/F/formatted/E07619
['t wave abnormal', 'bradycardia']
./data/F/formatted/E05312
['prolonged qt interval', 'atrial fibrillation', 't wave abnormal']





# Custom metric

In [57]:
#pred = np.array([1,1,0.2,0.3,0]).reshape(1,-1)
pred =  np.array([1,1,0,1,1]).reshape(1,-1)
label = np.array([1,1,0,1,0]).reshape(1,-1)

pred = (pred - 0.5)*2
label = (label - 0.5)*2

metric = pred*label
# metric = abs(metric)
metric

array([[ 1.,  1.,  1.,  1., -1.]])

In [50]:
np.sum(metric)

55.0