In [1]:
%load_ext autoreload
%autoreload 1
%aimport ECGDataset
%aimport Models
%aimport train_test_validat
%aimport self_attention
%aimport ECGplot
%aimport Net

import ECGDataset 
import Models 
import Net
from train_test_validat import *
from self_attention import *
import matplotlib.pyplot as plt
import ecg_plot
import cam
import ECGplot
import ECGHandle
import torch
import torch.utils.data as Data
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import StratifiedKFold
import random
import pandas as pd
from tqdm import tqdm

import time
import math
import os
import gc
from torch.utils.tensorboard import SummaryWriter


def seed_torch(seed=2023):
    """Seed every random number generator used in this notebook.

    Sets the ``PYTHONHASHSEED`` environment variable, Python's ``random``
    module, NumPy's global RNG and PyTorch's CPU / CUDA generators, then puts
    cuDNN into deterministic, non-benchmarking mode.

    Args:
        seed: integer seed applied to every generator (default 2023).
    """
    # Python-level randomness. Setting PYTHONHASHSEED at runtime only affects
    # interpreters started afterwards, not the hash seed of this process.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU, the current CUDA device, and all CUDA devices.
    for seeder in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [2]:
# Everything in this notebook (model and tensors) runs on the CPU.
DEVICE = "cpu"
# Signal geometry constants; the names suggest leads per ECG and samples per lead.
EcgChannles_num = 12
EcgLength_num = 5000
seed_torch(seed=2023)

In [3]:
# Supplementary inpatient diagnoses: concatenate every diagnosis string per
# patient ID, then keep the IDs whose merged text mentions hypertension (高血压).
supplement_diagnose = pd.read_csv('./补充诊断.csv', encoding='utf-8-sig')
supplement_diagnose = (
    supplement_diagnose
    .groupby('ID')['住院所有诊断']
    .agg(lambda diagnoses: ','.join(diagnoses.astype(str)))
    .reset_index()
)
filtered_df = supplement_diagnose.loc[supplement_diagnose['住院所有诊断'].str.contains('高血压')]
# IDs are converted to str; the ECG table's ID column is concatenated with
# strings later in this notebook, so it is presumably str as well.
id_list = filtered_df['ID'].astype(str).tolist()

In [4]:
# Load the preprocessed ECG index table and apply the cohort filters in order.
# Each ECGHandle step prints a before/after summary (see the cell output).
data_root = '/workspace/data/Preprocess_HTN/datas_/'
ALL_data = pd.read_csv(os.path.join(data_root, 'All_data_handled_ID_range_age_IDimputate.csv'), low_memory=False)
ALL_data = ECGHandle.change_label(ALL_data)
ALL_data = ECGHandle.filter_ID(ALL_data)
ALL_data = ECGHandle.filter_ages(ALL_data, 18)  # age threshold 18 (presumably adults only; see ECGHandle)
ALL_data = ECGHandle.filter_QC(ALL_data)
# id_list: IDs whose supplementary inpatient diagnoses mention hypertension.
ALL_data = ECGHandle.correct_label(ALL_data, reset_list=id_list)
ALL_data = ECGHandle.correct_age(ALL_data)
# Diagnosis keywords passed to filter_diagnose (presumably excluded as confounders):
# pacemaker/pacing (起搏), atrial fibrillation (房颤), left bundle branch block
# (左束支传导阻滞), left anterior fascicular block (左前分支阻滞).
for excluded_diagnosis in ('起搏', '房颤', '左束支传导阻滞', '左前分支阻滞'):
    ALL_data = ECGHandle.filter_diagnose(ALL_data, excluded_diagnosis)
ALL_data = ALL_data.rename(columns={'住院号': 'ID', '年龄': 'age', '性别': 'gender', '姓名': 'name'})
ALL_data_buffer = ALL_data.copy()



            orginal   removed diagnose NaN
   nums      200082          199997       
              HTN             NHTN        
   nums       3273           196724       


            orginal      removed ID NaN   
   nums      199997          199995       
              HTN             NHTN        
   nums       3273           196722       


            orginal      filtered ages    
   nums      199995          183068       
              HTN             NHTN        
   nums       3220           179848       


            orginal            QC         
   nums      183068          69819        
              HTN             NHTN        
   nums       1477           68342        


     reset num:      18448
  ERR labels num:     24  
            orginal      correct label    
   nums      69819           69819        
              HTN             NHTN        
   nums      18690           51129        


   ERR ages num:     15189
            orginal       correct age     
   n

In [None]:
# Build the held-out test split deterministically.
seed_torch(2023)  # re-seed so the test split is identical on every run
ALL_data_buffer = ALL_data.copy()
# Shuffle rows. NOTE(review): the shuffled buffer is NOT passed to Pair_ID below
# (ALL_data is), so its only effect is advancing NumPy's global RNG
# (DataFrame.sample without random_state draws from it). If Pair_ID also uses
# the global RNG (TODO confirm), deleting this line would change the test split.
ALL_data_buffer = ALL_data_buffer.sample(frac=1).reset_index(drop=True)
# ---- randomly select the test set ----
# Disabled alternative: split the shuffled buffer instead.
# test_df,tv_df = Pair_ID(ALL_data_buffer,0.2,Range_max=15,pair_num=1)
# ####################################################################
# Pair_ID returns (test_df, tv_df); presumably 0.2 is the test fraction and
# Range_max the age-matching window for HTN/NHTN pairing -- confirm in train_test_validat.
test_df,tv_df = Pair_ID(ALL_data,0.2,Range_max=15,pair_num=1)
test_dataset = ECGHandle.ECG_Dataset(data_root,test_df,preprocess = True)

# test_df =all_dataset

In [None]:
# Load a trained checkpoint and evaluate it on the held-out test set.
Models_path = '/workspace/data/Interpretable_HTN/model/20230320_011459/20230320_061146/BestF1_2.pt'
save_root = Models_path[:-3]+'/'  # checkpoint path without the '.pt' suffix
layervalue_root = save_root+'/layervalue/'
NET = [Net.MLBFNet_GUR_o(True,True,True,2,Dropout_rate=0.3), ] # type: ignore
criterion = torch.nn.CrossEntropyLoss()
testmodel = NET[0].to(DEVICE)
# map_location lets a checkpoint that was saved on a GPU load on this CPU DEVICE.
testmodel.load_state_dict(torch.load(Models_path, map_location=DEVICE))
# Inference mode: disables dropout (rate 0.3) even if eval_model does not switch modes itself.
testmodel.eval()
# shuffle=False keeps predictions in test_dataset order.
test_dataloader = Data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
y_true,y_pred,y_out,test_loss,test_acc = eval_model(test_dataloader,criterion,testmodel,DEVICE)

In [None]:
# Attach column 1 of the model output (used as the HTN score elsewhere in this notebook).
test_dataset.infos['out'] = np.asarray(y_out)[:, 1]

In [None]:
# Hand-picked patient IDs whose test-set records are exported for inspection
# (the meaning of "p_double" is not documented here -- TODO describe).
# IDs are strings; rows whose ID is stored as a number would not match.
p_double_list = [
    '748939', '764590', '763667', '579114', '753320', '795543', '788055', '757960',
    '772546', '786546', '780224', '813202', '814775', '544375', '828347', '840799',
    '757558', '803369', '822443', '801013', '327715', '812001', '378927', '839205',
    '824053', '809810', '633009', '826823', '825537', '794560',
]

# Hash-based membership test instead of scanning the list once per row.
p_double_df = test_dataset.infos[test_dataset.infos['ID'].isin(p_double_list)]
p_double_df.to_csv('p_double_df.csv',encoding='utf-8-sig')

In [None]:
def calculate_area(ECG:np.array,start_index:np.array,end_index:np.array,n_leads:int=12): # type: ignore
    """Beat-averaged, length-normalized wave area summed over leads.

    For every lead and every beat ``b`` the segment
    ``ECG[lead, start_b:end_b]`` is shifted so its minimum is zero, summed,
    and divided by the nominal segment length ``end_b - start_b`` (i.e. the
    mean height above the segment minimum). These values are summed over
    leads and beats and divided by the number of beats.

    Args:
        ECG: array indexed as ``ECG[lead, sample]``.
        start_index: per-beat start sample (inclusive); values are cast to int.
        end_index: per-beat end sample (exclusive); values are cast to int.
        n_leads: number of leading rows of ``ECG`` to use (default 12).

    Returns:
        The summed, beat-averaged area (a NumPy float for NumPy input).

    Raises:
        ValueError: if there are no beats, if ``start_index`` and ``end_index``
            have different lengths, or if a beat selects an empty segment
            (``end <= start`` or start beyond the signal).
    """
    starts = [int(s) for s in start_index]
    ends = [int(e) for e in end_index]
    if not starts:
        raise ValueError('calculate_area: no beats (start_index is empty)')
    if len(starts) != len(ends):
        raise ValueError(
            f'calculate_area: {len(starts)} start indices but {len(ends)} end indices')

    p_area = 0
    for lead in range(n_leads):
        lead_signal = ECG[lead]
        for start, end in zip(starts, ends):
            segment = lead_signal[start:end]
            if segment.size == 0:
                raise ValueError(
                    f'calculate_area: empty segment [{start}:{end}] in lead {lead}')
            # Shifted values are >= 0, so the sum needs no abs().
            p_area = p_area + (segment - segment.min()).sum() / (end - start)
    return p_area / len(starts)

In [None]:
# Per-record ECG features for the test set, read from precomputed wave files stored at
# <waves_location_file_root>/{FPT,Timing_feature_sync,Amplitude_feature_12leads}/<ECGFilename>.npy
waves_location_file_root = '/workspace/data/Preprocess_HTN/datas_/Wave/'
lead_index = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
# Lead-major, then P/Q/R/S/T: must match the (beat, lead, 5) -> (beat, lead*5) flattening below.
Amplitude_name_list = [f'{wave}_Amplitude_{lead}' for lead in lead_index for wave in 'PQRST']
timing_columns = ['P_wave_duration', 'QRS_duration', 'T_wave_duration', 'PQ_interval',
                  'PR_interval', 'QT_interval', 'QTc_interval', 'RR_interval']
fpt_columns = ['P_start', 'P_peak', 'P_end', 'Q_start', 'Q_peak', 'R_peak', 'S_peak',
               'S_end', 'res0', 'T_start', 'T_peak', 'T_end', 'res0']

# DataFrame.append was removed in pandas 2.0 and was quadratic before that:
# collect rows in lists and build each frame once after the loop.
amplitude_rows, timing_rows, info_rows = [], [], []
for index in tqdm(range(len(test_dataset)), desc='ECG features'):
    info = test_dataset.infos.iloc[index]
    ECGfile_name = info['ECGFilename']
    try:
        ECG, _ = test_dataset[index]
        raw_fpt = np.load(os.path.join(waves_location_file_root, 'FPT', ECGfile_name + '.npy'))
        raw_timing = np.load(os.path.join(waves_location_file_root, 'Timing_feature_sync', ECGfile_name + '.npy'))
        raw_amplitude = np.load(os.path.join(waves_location_file_root, 'Amplitude_feature_12leads', ECGfile_name + '.npy'))

        # Drop the first and last beat ([1:-1]);
        # amplitudes: (lead, beat, 5) -> (beat, lead, 5) -> (beat, lead*5).
        amplitude_per_beat = pd.DataFrame(
            raw_amplitude.transpose(1, 0, 2)[1:-1].reshape(-1, len(Amplitude_name_list)),
            columns=Amplitude_name_list)
        fpt_per_beat = pd.DataFrame(raw_fpt[1:-1], columns=fpt_columns)
        timing_per_beat = pd.DataFrame(raw_timing[1:-1], columns=timing_columns)

        # NOTE(review): timing rows are trimmed by [1:-1] a second time here, so timing
        # means skip two beats at each end while P_area uses every row of fpt_per_beat.
        Timing_feature = timing_per_beat[1:-1].mean()
        Timing_feature['P_area'] = calculate_area(
            np.array(ECG), np.array(fpt_per_beat['P_start']), np.array(fpt_per_beat['P_end']))
        Amplitude_feature = amplitude_per_beat.mean()
    except Exception as err:  # best effort: report and skip records with missing/malformed files
        print('Err :', ECGfile_name, repr(err))
        continue
    # Append only after every feature succeeded so the three row lists stay aligned.
    amplitude_rows.append(Amplitude_feature)
    timing_rows.append(Timing_feature)
    info_rows.append(info)

Amplitude_features_12leads = pd.DataFrame(amplitude_rows, columns=Amplitude_name_list)
Timing_features = pd.DataFrame(timing_rows, columns=timing_columns + ['P_area'])
# Rows keep their original test_dataset.infos index labels.
infos = pd.DataFrame(info_rows, columns=test_dataset.infos.columns)

In [None]:
# Join record metadata with the per-record amplitude and timing features.
# The three frames must hold one row per processed record, in the same order.
assert len(Timing_features) == len(Amplitude_features_12leads) == len(infos), \
    'feature frames and infos are misaligned'
Timing_features['ECGFilename'] = infos['ECGFilename'].values
Amplitude_features_12leads['ECGFilename'] = infos['ECGFilename'].values
select_info = ['name','ID','gender','age','检查时间','临床诊断', '诊断','ECGFilename', 'xmlPath', 'Q', 'label', 'out']
test_df_with_ECG_features = pd.merge(
    infos[select_info],
    pd.merge(Amplitude_features_12leads, Timing_features, on='ECGFilename', how='inner'),
    on='ECGFilename', how='inner')
# utf-8-sig (with BOM) like the notebook's other CSV exports, so the Chinese
# column names display correctly when opened in Excel.
test_df_with_ECG_features.to_csv('test_2023_infos_and_features.csv', encoding='utf-8-sig')

## 常规操作

In [None]:
# Rebuild the test split with the same seed and RNG sequence as the earlier split cell.
# NOTE(review): the columns were already renamed when the data was loaded, so this
# rename is a no-op on a fresh run (rename ignores missing labels by default).
ALL_data = ALL_data.rename(columns={'住院号':'ID','年龄':'age','性别':'gender','姓名':'name'})
ALL_data_buffer = ALL_data.copy()
seed_torch(2023)
# Shuffle rows. The shuffled buffer is not passed to Pair_ID below; its effect is to
# advance NumPy's global RNG, which the split may depend on (TODO confirm in Pair_ID).
ALL_data_buffer = ALL_data_buffer.sample(frac=1).reset_index(drop=True)
# all_dataset = ECGHandle.ECG_Dataset(data_root,ALL_data_buffer,preprocess = True)
# ---- randomly select the test set ----
test_df,tv_df = Pair_ID(ALL_data,0.2,Range_max=15,pair_num=1)
test_dataset = ECGHandle.ECG_Dataset(data_root,test_df,preprocess = True)

In [None]:
# Load the 2023-03-22 checkpoint and evaluate it on the held-out test set.
Models_path = '/workspace/data/Interpretable_HTN/model/20230322_030450/20230322_030450/BestF1_0.pt'
save_root = Models_path[:-3]+'/'  # checkpoint path without the '.pt' suffix
layervalue_root = save_root+'/layervalue/'
NET = [Net.MLBFNet_GUR_o(True,True,True,2,Dropout_rate=0.3), ] # type: ignore
criterion = torch.nn.CrossEntropyLoss()
testmodel = NET[0].to(DEVICE)
# map_location lets a checkpoint that was saved on a GPU load on this CPU DEVICE.
testmodel.load_state_dict(torch.load(Models_path, map_location=DEVICE))
# Inference mode: disables dropout (rate 0.3) even if eval_model does not switch modes itself.
testmodel.eval()
# shuffle=False keeps predictions in test_dataset (and hence test_df) order.
test_dataloader = Data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
y_true,y_pred,y_out,test_loss,test_acc = eval_model(test_dataloader,criterion,testmodel,DEVICE)

In [None]:
# Shape of the model outputs; column 1 is used as the HTN score below.
np.shape(y_out)

In [None]:
# Copy of the test split with the model's HTN score (column 1 of y_out) attached.
# Assumes test_dataset preserves test_df's row order and drops no rows.
test_df_save = test_df.copy()
test_df_save['predict HTN possibility'] = np.array(y_out)[:, 1]
# Call head() (the bare attribute would show the bound method, not the rows).
test_df_save.head()

In [None]:
# Number of HTN (label == 1) ECG records in the test split.
len(test_df_save[test_df_save['label'] == 1])

In [None]:
# Export test records with predicted HTN scores (BOM-prefixed UTF-8 for Excel).
test_df_save.to_csv(path_or_buf='./test.csv', encoding='utf_8_sig')

In [None]:

# Inspect the columns available after cohort filtering and renaming.
ALL_data.columns

In [None]:
# HTN records (label == 1) with a recorded exam time whose patient ID occurs more
# than once in ALL_data. .copy() makes this an independent frame so the column
# assignments in the following cells don't trigger SettingWithCopyWarning.
duplicated_HTN_df = ALL_data[
    ALL_data.duplicated(subset=['ID'], keep=False)
    & (ALL_data['label'] == 1)
    & ALL_data['检查时间'].notna()
].copy()

In [None]:
# Number of HTN records from patients with repeated ECGs.
len(duplicated_HTN_df)

In [None]:
# Parse exam timestamps so per-patient gaps between ECGs can be computed.
duplicated_HTN_df['检查时间'] = pd.to_datetime(arg=duplicated_HTN_df['检查时间'])

In [None]:
# Absolute gap in seconds between consecutive ECGs of the same patient, in the
# frame's current row order (NaN for each patient's first row).
# GroupBy.diff returns a Series aligned to the original index. groupby().apply()
# with a transform-like lambda gains an extra 'ID' index level under pandas >= 2.0,
# and that result no longer aligns when assigned as a column.
duplicated_HTN_df['date_diff'] = (
    duplicated_HTN_df.groupby('ID')['检查时间'].diff().abs().dt.total_seconds()
)

In [None]:
# IDs with at least one pair of consecutive ECGs more than half a year apart
# (31536000 s = 365 days). An ID can appear several times; the list is only
# used for membership tests below.
duplicated_HTN_ID_list_buffer = list(duplicated_HTN_df[duplicated_HTN_df['date_diff'] > (31536000/2)]['ID'])
print(len(duplicated_HTN_ID_list_buffer))

In [None]:
# Every record (any label) of the patients selected above, wrapped as a dataset.
duplicated_HTN_df = ALL_data.loc[ALL_data['ID'].isin(duplicated_HTN_ID_list_buffer)].copy()
duplicated_HTN_dataset = ECGHandle.ECG_Dataset(data_root, duplicated_HTN_df, preprocess=True)

In [None]:
# Evaluate the 2023-03-22 checkpoint on the repeated-ECG HTN patients.
Models_path = '/workspace/data/Interpretable_HTN/model/20230322_030450/20230322_030450/BestF1_0.pt'
save_root = Models_path[:-3]+'/'  # checkpoint path without the '.pt' suffix
layervalue_root = save_root+'/layervalue/'
NET = [Net.MLBFNet_GUR_o(True,True,True,2,Dropout_rate=0.3), ] # type: ignore
criterion = torch.nn.CrossEntropyLoss()
testmodel = NET[0].to(DEVICE)
# map_location lets a checkpoint that was saved on a GPU load on this CPU DEVICE.
testmodel.load_state_dict(torch.load(Models_path, map_location=DEVICE))
# Inference mode: disables dropout (rate 0.3) even if eval_model does not switch modes itself.
testmodel.eval()
# shuffle=False keeps predictions in duplicated_HTN_df row order.
duplicated_HTN_dataloader = Data.DataLoader(dataset=duplicated_HTN_dataset, batch_size=1, shuffle=False)
y_true,y_pred,y_out,test_loss,test_acc = eval_model(duplicated_HTN_dataloader,criterion,testmodel,DEVICE)

In [None]:
# Attach column 1 of the model output (the HTN score) to each record.
duplicated_HTN_df['predict HTN possibility'] = np.asarray(y_out)[:, 1]

In [None]:
# Patients with at least one record scored below 0.5 (i.e. predicted non-HTN).
duplicated_HTN_df.loc[duplicated_HTN_df['predict HTN possibility'] < 0.5, 'ID'].unique().tolist()

In [None]:
# Export records of patients whose ECGs are more than half a year apart (间隔半年以上).
duplicated_HTN_df.to_csv(path_or_buf='./duplicated_HTN_间隔半年以上.csv', encoding='utf_8_sig')

In [None]:
# Render every record of the selected patients as a 12-lead ECG image.
jpg_path = './jpg/duplicated_HTN/'
os.makedirs(jpg_path, exist_ok=True)  # make sure the output folder exists
for index in range(len(duplicated_HTN_dataset)):
    info = duplicated_HTN_dataset.infos.iloc[index]
    file_name = str(info['ID']) + '_' + info['ECGFilename']
    ID = info['ID']
    date = info['检查时间']
    age = info['age']
    label = info['label']
    ECG, labels = duplicated_HTN_dataset[index]
    ECG = ECG*5000  # undo the dataset scaling (presumably preprocess divides by 5000 -- confirm)
    # 4.88/1000: presumably ADC counts -> mV for ecg_plot; sample_rate=500 Hz.
    ecg_plot.plot(ECG*4.88/1000, sample_rate=500,
                  title='ID:'+str(ID)+' '+'label: '+str(label)+' '+'Date: '+str(date)+' '+'age: '+str(age),
                  row_height=10, show_grid=True, show_separate_line=True)
    ecg_plot.save_as_jpg(file_name, jpg_path)
    # Release the figure after saving; otherwise one figure per record accumulates in memory.
    plt.close('all')

## 测试集上分类错误的NHTN

In [None]:
# Test-set NHTN records (label 0) that the model scored as HTN (score >= 0.5).
ERR_NHTN_df = test_df_save.loc[
    (test_df_save['label'] == 0) & (test_df_save['predict HTN possibility'] >= 0.5)
]
print(ERR_NHTN_df)

In [None]:
# Export the misclassified NHTN test records.
ERR_NHTN_df.to_csv(path_or_buf='./ERR_NHTN.csv', encoding='utf_8_sig')

## std & mean check

In [None]:
# Shuffle the train/validation pool reproducibly before building the CV folds.
# NOTE: not idempotent -- re-running this cell reshuffles the already-shuffled frame.
seed_torch(2023)
FOLDS = 5
tv_df = tv_df.sample(frac=1).reset_index(drop=True)

In [None]:
# Per-lead std/mean of the test, train and validation signals for every CV fold,
# to check that the splits have comparable amplitude statistics.
# datas is indexed as [record, lead, sample]; the test statistics do not depend
# on the fold, so they are computed once.
test_lead_stats = [
    (test_dataset.datas[:, lead, :].std(), test_dataset.datas[:, lead, :].mean())
    for lead in range(EcgChannles_num)
]
for fold in range(FOLDS):
    print("Fold "+str(fold)+" of "+str(FOLDS) + ' :')
    tv_df_buffer = tv_df.copy()
    HTN_tv_df = tv_df[(tv_df['label']==1)].copy()
    HTN_tv_size = len(HTN_tv_df['ID'].unique())  # distinct HTN patients in the pool
    HTN_validate_size = int(HTN_tv_size//FOLDS)
    validate_start_index = HTN_validate_size*fold  # start index of this fold's validation HTN patients
    validate_df,tarin_df = Pair_ID(tv_df_buffer,0.2,star_index=validate_start_index,Range_max=15,pair_num=1)
    validate_dataset = ECGHandle.ECG_Dataset(data_root,validate_df,preprocess = True)

    train_pair_df,_ = Pair_ID(tarin_df,1,star_index=0,Range_max=15,pair_num=1,shuffle=True)
    train_dataset = ECGHandle.ECG_Dataset(data_root,train_pair_df ,preprocess = True)
    for i in range(EcgChannles_num):
        test_std, test_mean = test_lead_stats[i]
        print('lead:', i)
        print(f'test: std={test_std}, mean={test_mean}')
        print(f'train: std={train_dataset.datas[:, i, :].std()}, mean={train_dataset.datas[:, i, :].mean()}')
        print(f'validat: std={validate_dataset.datas[:, i, :].std()}, mean={validate_dataset.datas[:, i, :].mean()}')

In [None]:
# Number of NHTN (label == 0) ECG records in the test split (records, not unique patients).
len(test_df[test_df['label'] == 0]['ID'])

In [None]:
# Number of NHTN (label == 0) ECG records in the train/validation pool (records, not unique patients).
len(tv_df[tv_df['label'] == 0]['ID'])

## 查看训练集、测试集、验证集的年龄、性别分布是否有差别

In [None]:
# Columns available for the age/gender distribution checks below.
ALL_data.columns

In [None]:
# Age distribution per gender for the whole cohort.
ALL_data.hist(by='gender', column='age')

In [None]:
# Age distribution per gender for the test split.
test_df.hist(by='gender', column='age')

In [None]:
# Re-create the test / train-validation split and plot, for each CV fold, the
# age distribution (by gender) of the validation set and the unpaired training pool.
seed_torch(2023)
test_df,tv_df = Pair_ID(ALL_data,0.2,Range_max=15,pair_num=1)
FOLDS = 5
# NOTE(review): this shuffle is seeded with 2020, while the std/mean-check cell uses 2023 -- confirm intended.
seed_torch(2020)
tv_df = tv_df.sample(frac=1).reset_index(drop=True)
for fold in range(FOLDS):
    print("Fold "+str(fold)+" of "+str(FOLDS) + ' :')
    tv_df_buffer = tv_df.copy()
    HTN_tv_df = tv_df[(tv_df['label']==1)].copy()
    HTN_tv_size = len(HTN_tv_df['ID'].unique())  # distinct HTN patients in the pool
    HTN_validate_size = int(HTN_tv_size//FOLDS)
    validate_start_index = HTN_validate_size*fold  # start index of this fold's validation HTN patients
    validate_df,tarin_df = Pair_ID(tv_df_buffer,0.2,star_index=validate_start_index,Range_max=15,pair_num=1)

    # train_pair_df is not plotted, but the call is kept: with shuffle=True it may
    # consume the global RNG, which would change the following folds.
    train_pair_df,_ = Pair_ID(tarin_df,1,star_index=0,Range_max=15,pair_num=1,shuffle=True)
    # Each hist(by=...) opens a new figure; title it so the figures are identifiable.
    validate_df.hist(column='age', by='gender')
    plt.suptitle(f'Fold {fold}: validation set, age by gender')
    tarin_df.hist(column='age', by='gender')
    plt.suptitle(f'Fold {fold}: training pool (before pairing), age by gender')

#### 按照年龄分组（粗分组边界 [18, 40, 60, 80, 110]；细分组边界 [18, 30, 40, 50, 60, 70, 80, 110]）

In [None]:
# Group ages into the given intervals (coarse bins).
# pd.cut intervals are right-closed, so without include_lowest=True an age of
# exactly 18 (the minimum kept by filter_ages(..., 18)) would fall outside
# (18, 40] and get a NaN group.
bins = [18, 40, 60, 80, 110]
labels = ['18-40', '41-60', '61-80', '81-110']
ALL_data['agegroup'] = pd.cut(ALL_data['age'], bins=bins, labels=labels, include_lowest=True)
ALL_data.hist('agegroup',by = 'label',sharex=True)

In [None]:
# 按照指定的区间进行分组
bins = [18, 30, 40, 50, 60, 70, 80,110]
labels = ['18-30', '31-40', '41-50', '51-60', '61-70', '71-80','81-110']
ALL_data['agegroup'] = pd.cut(ALL_data['age'], bins=bins, labels=labels)
ALL_data.hist(column='agegroup',by = 'label',sharex=True)

In [None]:
# One (the first) record per HTN patient.
HTN_df = ALL_data.loc[ALL_data['label'] == 1].drop_duplicates(subset=['ID'], keep='first').copy()

In [None]:
from sklearn.model_selection import train_test_split
# Patient-level split stratified by (agegroup, gender): hold out 20% of HTN
# patients for test, pair them with NHTN patients via pair_HTN, then cut the
# remaining HTN patients into 5 validation folds paired the same way.
# NOTE(review): train_test_split is called without random_state, so the splits
# depend on the global NumPy RNG state left by earlier cells -- TODO pin a seed.

# One (first) record per patient.
NHTN_df = ALL_data[(ALL_data['label']==0)].drop_duplicates(subset=['ID'],keep='first').copy()
HTN_df = ALL_data[(ALL_data['label']==1)].drop_duplicates(subset=['ID'],keep='first').copy()
TV_HTN_df, test_HTN_df = train_test_split(HTN_df, test_size=0.2, stratify=HTN_df[['agegroup', 'gender']])
# Pair by age and gender per ID (IDs de-duplicated first); Range_max presumably the age window.
test_ID_list = pair_HTN(test_HTN_df.drop_duplicates(['ID'],keep='first'),
                        NHTN_df.drop_duplicates(['ID'],keep='first'),
                        Range_max=2,
                        pair_num=1,
                        shuffle=True)['ID'].tolist()
# All records of a test patient go to test; every other record forms the train/validation pool.
in_test = ALL_data['ID'].isin(test_ID_list)
test_df = ALL_data.loc[in_test].copy()
TV_df = ALL_data.loc[~in_test].copy()

TV_NHTN_df = TV_df[(TV_df['label']==0)].drop_duplicates(subset=['ID'],keep='first').copy()
TV_HTN_df = TV_df[(TV_df['label']==1)].drop_duplicates(subset=['ID'],keep='first').copy()
fold_len = float(len(TV_HTN_df)//5)  # HTN patients per validation fold
TV_HTN_buffer = TV_HTN_df.copy()

# Peel off 4 stratified folds of exactly int(fold_len) HTN patients; the remainder is fold 5.
# An absolute (int) test_size avoids sklearn's float path, where
# ceil(fraction * n) can round up by one patient due to floating-point error.
validat_HTN_df_subsets = []
for i in range(4):
    TV_HTN_buffer, subset = train_test_split(TV_HTN_buffer, test_size=int(fold_len),
                                             stratify=TV_HTN_buffer[['agegroup', 'gender']])
    print(len(subset))
    validat_HTN_df_subsets.append(subset)
print(len(TV_HTN_buffer))
validat_HTN_df_subsets.append(TV_HTN_buffer)

# Pair each fold's HTN patients with NHTN patients (a fresh de-duplicated NHTN frame
# per call, as before, in case pair_HTN modifies its inputs).
validat_ID_list_subsets = [
    pair_HTN(htn_subset.drop_duplicates(['ID'],keep='first'),
             TV_NHTN_df.drop_duplicates(['ID'],keep='first'),
             Range_max=2,
             pair_num=1,
             shuffle=True)['ID'].tolist()
    for htn_subset in validat_HTN_df_subsets
]


In [None]:
# For each fold: validation = every record of that fold's paired patients;
# training = the rest of the pool, re-paired with Pair_ID. Plot age by label for both.
n_folds = len(validat_ID_list_subsets)
for fold in range(n_folds):
    print(" "*10+ "Fold "+str(fold)+" of "+str(n_folds) + ' :')

    in_validate = TV_df['ID'].isin(validat_ID_list_subsets[fold])
    validate_df = TV_df.loc[in_validate].copy()
    train_df = TV_df.loc[~in_validate].copy()
    # NOTE(review): Range_max=5 here vs 15 in the Pair_ID calls above -- confirm intended.
    train_pair_df,_ = Pair_ID(train_df,1,star_index=0,Range_max=5,pair_num=1,shuffle=True)

    # Each hist(by=...) opens a new figure; title it so the figures are identifiable.
    validate_df.hist('age',by = 'label',sharex=True)
    plt.suptitle(f'Fold {fold}: validation set, age by label')
    train_pair_df.hist('age',by = 'label',sharex=True)
    plt.suptitle(f'Fold {fold}: paired training set, age by label')

In [None]:
# Age distribution per label for the stratified test split.
test_df.hist('age', sharex=True, by='label')

## plot

In [None]:
# Render every test record as a 12-lead ECG image named after its ECG file.
jpg_path = './jpg/'
os.makedirs(jpg_path, exist_ok=True)  # make sure the output folder exists
for index in range(len(test_dataset)):
    info = test_dataset.infos.iloc[index]
    file_name = info['ECGFilename']
    ID = info['ID']
    date = info['检查时间']
    age = info['age']
    label = info['label']
    ECG, labels = test_dataset[index]
    ECG = ECG*5000  # undo the dataset scaling (presumably preprocess divides by 5000 -- confirm)
    # 4.88/1000: presumably ADC counts -> mV for ecg_plot; sample_rate=500 Hz.
    ecg_plot.plot(ECG*4.88/1000, sample_rate=500,
                  title='ID:'+str(ID)+' '+'label: '+str(label)+' '+'Date: '+str(date)+' '+'age: '+str(age),
                  row_height=10, show_grid=True, show_separate_line=True)
    ecg_plot.save_as_jpg(file_name, jpg_path)
    # Release the figure after saving; otherwise one figure per record accumulates in memory.
    plt.close('all')