In [None]:
import math
import matplotlib.pyplot as plt
import heartpy as hp
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from random import shuffle
from scipy.interpolate import CubicSpline
from numba import jit

import numpy as np
import pandas as pd

# pd.set_option('display.height', 1000)
pd.set_option("display.max_rows", 500)
pd.set_option("display.max_columns", 500)
pd.set_option("display.width", 1000)
pd.set_option("max_colwidth", 200)
from IPython.core.display import HTML, display

display(HTML("<style>.container { width:100% !important; }</style>"))

def df_stats(df):
    from tabulate import tabulate

    print("\n***** Shape: ", df.shape, " *****\n")

    columns_list = df.columns.values.tolist()
    isnull_list = df.isnull().sum().values.tolist()
    isunique_list = df.nunique().values.tolist()
    dtypes_list = df.dtypes.tolist()

    list_stat_val = list(zip(columns_list, isnull_list, isunique_list, dtypes_list))
    df_stat_val = pd.DataFrame(
        list_stat_val, columns=["Name", "Null", "Unique", "Dtypes"]
    )
    print(tabulate(df_stat_val, headers="keys", tablefmt="psql"))
    return df.head()


### Load predictions and remove columns that are not needed

In [None]:
import pandas as pd
from tqdm import tqdm
def process_data(data, extend_to_remove_labels=True):
    data = data[~data['npy_path'].str.contains('Error')]
    data = data[data['npy_path'] != '/media/data1/anolin/temp_new_dataset/ecg_npy/0161727_03-25-2020_10-06-16.npy']
    to_remove_labels = ['ST depression (posterior - V7-V8-V9)', 'Tall >2.5 mm', 'J wave', 'Auricular bigeminy', 
                        'Ventricular bigeminy', 'Sinus Pause', 'Dextrocardia', 
                        'Hyperacute T wave (lateral, V5-V6)', 'Hyperacute T wave (septal, V1-V2)', 
                        'Hyperacute T wave (anterior, V3-V4)', 'Bifid', 'RaVL + SV3 > 28 mm (H) or 20 mm (F)', 
                        'Large >0.08 s', 'Biphasic', 'ST depression et T inversion in V5 or V6']

    if (extend_to_remove_labels):
        # Extend to_remove_labels to include both suffixes
        extended_to_remove_labels = [label + suffix for label in to_remove_labels for suffix in ['_CARDIOLOGIST', '_MUSE']]
    else:
        extended_to_remove_labels = to_remove_labels

    # Drop the extended list of columns from the DataFrame
    data = data.drop(columns=extended_to_remove_labels, errors='ignore')
    return data

df_train = pd.read_parquet('/media/data1/muse_ge/train_trial_v1.1.parquet')
df_val = pd.read_parquet('/media/data1/muse_ge/val_trial_v1.1.parquet')
df_test = pd.read_parquet('/media/data1/muse_ge/test_trial_v1.1.parquet')

# Process data using tqdm loop
data_frames = [df_train, df_val, df_test]
for i in tqdm(range(len(data_frames))):
    data_frames[i] = process_data(data_frames[i])

df_train, df_val, df_test = data_frames

In [None]:
display(df_train.diagnosis.head(n=5))

In [None]:
df_ecg = pd.read_parquet('/media/data1/muse_ge/ECG_ad202207_1453937_cat_labels_MUSE_vs_CARDIOLOGIST_v1.2.parquet')
import numpy as np
conditions = [
    df_ecg['npy_path'].isin(df_train['npy_path']),
    df_ecg['npy_path'].isin(df_val['npy_path']),
    df_ecg['npy_path'].isin(df_test['npy_path'])
]

choices = ['train', 'val', 'test']

df_ecg['dataset'] = np.select(conditions, choices, default='unknown')
df_ecg = process_data(df_ecg)


In [None]:
y_label_names = ['Sinusal','Regular','Monomorph','QS complex in V1-V2-V3','R complex in V5-V6','T wave inversion (inferior - II, III, aVF)','Left bundle branch block','RaVL > 11 mm','SV1 + RV5 or RV6 > 35 mm','T wave inversion (lateral -I, aVL, V5-V6)','T wave inversion (anterior - V3-V4)','Left axis deviation','Left ventricular hypertrophy','Bradycardia','Q wave (inferior - II, III, aVF)','Afib','Irregularly irregular','Atrial tachycardia (>= 100 BPM)','Nonspecific intraventricular conduction delay','Premature ventricular complex','Polymorph','T wave inversion (septal- V1-V2)','Right bundle branch block','Ventricular paced','ST elevation (anterior - V3-V4)','ST elevation (septal - V1-V2)','1st degree AV block','Premature atrial complex','Atrial flutter',"rSR' in V1-V2",'qRS in V5-V6-I, aVL','Left anterior fascicular block','Right axis deviation','2nd degree AV block - mobitz 1','ST depression (inferior - II, III, aVF)','Acute pericarditis','ST elevation (inferior - II, III, aVF)','Low voltage','Regularly irregular','Bifid','Junctional rhythm','Left atrial enlargement','ST elevation (lateral - I, aVL, V5-V6)','Atrial paced','Right ventricular hypertrophy','Delta wave','Wolff-Parkinson-White (Pre-excitation syndrome)','Prolonged QT','ST depression (anterior - V3-V4)','QRS complex negative in III','RaVL + SV3 > 28 mm (H) or 20 mm (F)','Q wave (lateral- I, aVL, V5-V6)','Hyperacute T wave (lateral, V5-V6)','Hyperacute T wave (septal, V1-V2)','Supraventricular tachycardia','ST downslopping','ST depression (lateral - I, avL, V5-V6)','2nd degree AV block - mobitz 2','U wave','ST depression et T inversion in V5 or V6','Large >0.08 s','R/S ratio in V1-V2 >1','RV1 + SV6\xa0> 11 mm','Left posterior fascicular block','Right atrial enlargement','ST depression (septal- V1-V2)','Q wave (septal- V1-V2)','Q wave (anterior - V3-V4)','Hyperacute T wave (anterior, V3-V4)','ST upslopping','Right superior axis','Auricular bigeminy','Ventricular tachycardia','ST elevation (posterior - V7-V8-V9)','Ectopic atrial rhythm (< 100 BPM)','Lead misplacement','Biphasic','Ventricular bigeminy','J wave','Tall >2.5 mm','Third Degree AV Block','Sinus Pause','Acute MI','Early repolarization','Q wave (posterior - V7-V9)','Bi-atrial enlargement','LV pacing','Dextrocardia','Brugada','Ventricular Rhythm','ST depression (posterior - V7-V8-V9)','no_qrs']
labels_to_remove = ['ST depression (posterior - V7-V8-V9)','Tall >2.5 mm', 'J wave', 'Auricular bigeminy', 'Ventricular bigeminy', 'Sinus Pause', 'Dextrocardia', 'Hyperacute T wave (lateral, V5-V6)', 'Hyperacute T wave (septal, V1-V2)', 'Hyperacute T wave (anterior, V3-V4)', 'Bifid', 'RaVL + SV3 > 28 mm (H) or 20 mm (F)', 'Large >0.08 s', 'Biphasic', 'ST depression et T inversion in V5 or V6']

# Remove specified labels from y_label_names
y_label_names = [label for label in y_label_names if label not in labels_to_remove]
print(len(y_label_names))
print(y_label_names)

### Compute performance

In [None]:
X_val = np.load('/media/data1/muse_ge/X_test_v1.2.npy').astype(np.float16)
Y_val = np.load('/media/data1/muse_ge/Y_test_v1.2.npy').astype(np.float16)

og_labels =  ['Sinusal','Regular','Monomorph','QS complex in V1-V2-V3','R complex in V5-V6','T wave inversion (inferior - II, III, aVF)','Left bundle branch block','RaVL > 11 mm','SV1 + RV5 or RV6 > 35 mm','T wave inversion (lateral -I, aVL, V5-V6)','T wave inversion (anterior - V3-V4)','Left axis deviation','Left ventricular hypertrophy','Bradycardia','Q wave (inferior - II, III, aVF)','Afib','Irregularly irregular','Atrial tachycardia (>= 100 BPM)','Nonspecific intraventricular conduction delay','Premature ventricular complex','Polymorph','T wave inversion (septal- V1-V2)','Right bundle branch block','Ventricular paced','ST elevation (anterior - V3-V4)','ST elevation (septal - V1-V2)','1st degree AV block','Premature atrial complex','Atrial flutter',"rSR' in V1-V2",'qRS in V5-V6-I, aVL','Left anterior fascicular block','Right axis deviation','2nd degree AV block - mobitz 1','ST depression (inferior - II, III, aVF)','Acute pericarditis','ST elevation (inferior - II, III, aVF)','Low voltage','Regularly irregular','Bifid','Junctional rhythm','Left atrial enlargement','ST elevation (lateral - I, aVL, V5-V6)','Atrial paced','Right ventricular hypertrophy','Delta wave','Wolff-Parkinson-White (Pre-excitation syndrome)','Prolonged QT','ST depression (anterior - V3-V4)','QRS complex negative in III','RaVL + SV3 > 28 mm (H) or 20 mm (F)','Q wave (lateral- I, aVL, V5-V6)','Hyperacute T wave (lateral, V5-V6)','Hyperacute T wave (septal, V1-V2)','Supraventricular tachycardia','ST downslopping','ST depression (lateral - I, avL, V5-V6)','2nd degree AV block - mobitz 2','U wave','ST depression et T inversion in V5 or V6','Large >0.08 s','R/S ratio in V1-V2 >1','RV1 + SV6\xa0> 11 mm','Left posterior fascicular block','Right atrial enlargement','ST depression (septal- V1-V2)','Q wave (septal- V1-V2)','Q wave (anterior - V3-V4)','Hyperacute T wave (anterior, V3-V4)','ST upslopping','Right superior axis','Auricular bigeminy','Ventricular tachycardia','ST elevation (posterior - V7-V8-V9)','Ectopic atrial rhythm (< 100 BPM)','Lead misplacement','Biphasic','Ventricular bigeminy','J wave','Tall >2.5 mm','Third Degree AV Block','Sinus Pause','Acute MI','Early repolarization','Q wave (posterior - V7-V9)','Bi-atrial enlargement','LV pacing','Dextrocardia','Brugada','Ventricular Rhythm','ST depression (posterior - V7-V8-V9)','no_qrs']
to_remove_labels = ['ST depression (posterior - V7-V8-V9)','Tall >2.5 mm', 'J wave', 'Auricular bigeminy', 'Ventricular bigeminy', 'Sinus Pause', 'Dextrocardia', 'Hyperacute T wave (lateral, V5-V6)', 'Hyperacute T wave (septal, V1-V2)', 'Hyperacute T wave (anterior, V3-V4)', 'Bifid', 'RaVL + SV3 > 28 mm (H) or 20 mm (F)', 'Large >0.08 s', 'Biphasic', 'ST depression et T inversion in V5 or V6']

pos_to_drop = list()
new_label_names = list()
for pos, item in enumerate(og_labels):
    if item in to_remove_labels:
        pos_to_drop.append(pos)
    else:
        new_label_names.append(item)
Y_val = np.delete(Y_val, pos_to_drop, axis=1)

In [None]:
# Assuming df_ecg is your DataFrame
# Identify columns ending with _CARDIOLOGIST and _MUSE
cardiologist_cols = [col for col in df_ecg.columns if col.endswith('_CARDIOLOGIST')]
muse_cols = [col for col in df_ecg.columns if col.endswith('_MUSE')]
# Create subsets for each
df_cardiologist = df_ecg[cardiologist_cols]
df_muse = df_ecg[muse_cols]
# Keep only columns between Sinusal_CARDIOLOGIST and no_qrs_CARDIOLOGIST
start_col = "Sinusal_CARDIOLOGIST"
end_col = "no_qrs_CARDIOLOGIST"

df_cardiologist = df_cardiologist.loc[:, start_col:end_col]
start_col = "Sinusal_MUSE"
end_col = "no_qrs_MUSE"

df_muse = df_muse.loc[:, start_col:end_col]

df_cardiologist = df_cardiologist.apply(lambda x: x.apply(lambda y: 1 if y >= 1 else 0))
df_muse = df_muse.apply(lambda x: x.apply(lambda y: 1 if y >= 1 else 0))

In [None]:
df_cardiologist = pd.merge(df_cardiologist, df_ecg[['npy_path', 'dataset', 'validated by MD', 'diagnosis', 'original_diagnosis']], left_index=True, right_index=True)
df_muse = pd.merge(df_muse, df_ecg[['npy_path', 'dataset', 'validated by MD', 'diagnosis', 'original_diagnosis']], left_index=True, right_index=True)

one_hot_encoded_cardiologist_val = df_cardiologist[df_cardiologist['dataset'] == 'val']
one_hot_encoded_muse_val = df_muse[df_muse['dataset'] == 'val']

In [None]:
def sigmoid(x):
  return 1 / (1 + math.exp(-x))

sigmoid_v = np.vectorize(sigmoid)

#### Load Y_PRED : TO DO - LOAD FINAL MODEL OUTPUT

In [None]:
Y_pred = np.load('/media/data1/anolin/results_benchmarkv2/resnet50_notscaled_1997_v1.1/output_1.npy')

def sigmoid(x):
  return 1 / (1 + math.exp(-x))

sigmoid_v = np.vectorize(sigmoid)
Y_pred = sigmoid_v(Y_pred)
#display(Y_pred)
sigmoid_v_bin =  np.where(Y_pred > 0.5, 1, 0)
#sigmoid_v_bin

In [None]:
# Assuming sigmoid_v_bin, one_hot_encoded_cardiologist_val, and one_hot_encoded_muse_val have aligned indices

# Create a dataframe for sigmoid_v_bin with labels
df_sigmoid_v_bin = pd.DataFrame(sigmoid_v_bin, columns=new_label_names)

# Reset index if necessary and ensure it aligns across all DataFrames
df_sigmoid_v_bin.reset_index(drop=True, inplace=True)
one_hot_encoded_cardiologist_val.reset_index(drop=True, inplace=True)
one_hot_encoded_muse_val.reset_index(drop=True, inplace=True)
try:
    one_hot_encoded_muse_val = one_hot_encoded_muse_val.drop(['npy_path', 'dataset', 'validated by MD', 'diagnosis', 'original_diagnosis'], axis=1)
except:
    pass

# Merge the dataframes on index
df_merged = pd.concat([df_sigmoid_v_bin, one_hot_encoded_cardiologist_val, one_hot_encoded_muse_val], axis=1).reset_index()
df_merged.reset_index(inplace=True)
condition = (df_merged['validated by MD'] == 1) & (df_merged['diagnosis'] != df_merged['original_diagnosis'])
df_modified_by_cardiologist = df_merged.loc[condition]
df_validated_by_cardiologist = df_merged.loc[df_merged['validated by MD'] == 1]
df_not_validated_by_cardiologist = df_merged.loc[df_merged['validated by MD'] == 0]

In [None]:
#df_modified_by_cardiologist.to_csv('data/core_model_performance/df_modified_by_cardiologist.csv')
#df_validated_by_cardiologist.to_csv('data/core_model_performance/df_validated_by_cardiologist.csv')
#df_not_validated_by_cardiologist.to_csv('data/core_model_performance/df_not_validated_by_cardiologist.csv')

#### SELECT A DATAFRAME FOR SUBSEQUENT ANALYSES

In [None]:


df_SELECTED = df_validated_by_cardiologist
# Remove unnecessary columns
unnecessary_columns = ['Unnamed: 0', 'level_0', 'index']
df_SELECTED.drop(columns=unnecessary_columns, inplace=True, errors='ignore')


### Compute per 'category' metrics and 'overall' metrics'

In [None]:
from utils import plot_micro_macro_statistics

import json

# Define the file path
file_path = "utils/categories.json"

# Load the JSON file
with open(file_path, "r") as file:
    categories = json.load(file)

# Print the loaded categories
print(categories)

In [None]:
# Compute micro and macro AUC/PR for CARDIOLOGIST
micro_ROC_cardiologist, micro_PR_cardiologist, Sensitivity_micro_avg_cardiologist, Specificity_micro_avg_cardiologist = plot_micro_macro_statistics.compute_macro_metrics_all_with_youden(df_SELECTED, '_CARDIOLOGIST')
macro_ROC_cardiologist, macro_PR_cardiologist, Sensitivity_macro_avg_cardiologist, Specificity_macro_avg_cardiologist = plot_micro_macro_statistics.compute_macro_metrics_all_with_youden(df_SELECTED, '_CARDIOLOGIST')

# Compute micro and macro AUC/PR for MUSE
micro_ROC_muse, micro_PR_muse, Sensitivity_micro_avg_muse, Specificity_micro_avg_muse = plot_micro_macro_statistics.compute_micro_metrics_all_with_youden(df_SELECTED, '_MUSE')
macro_ROC_muse, macro_PR_muse, Sensitivity_macro_avg_muse, Specificity_macro_avg_muse = plot_micro_macro_statistics.compute_macro_metrics_all_with_youden(df_SELECTED, '_MUSE')
# Create a dictionary with metrics
metrics = {
    'Sensitivity_micro': [Sensitivity_micro_avg_cardiologist, Sensitivity_micro_avg_muse],
    'Specificity_micro': [Specificity_micro_avg_cardiologist, Specificity_micro_avg_muse],
    'ROC_micro': [micro_ROC_cardiologist, micro_ROC_muse],
    'PR_micro': [micro_PR_cardiologist, micro_PR_muse],
    'Sensitivity_macro': [Sensitivity_macro_avg_cardiologist, Sensitivity_macro_avg_muse],
    'Specificity_macro': [Specificity_macro_avg_cardiologist, Specificity_macro_avg_muse],
    'ROC_macro': [macro_ROC_cardiologist, macro_ROC_muse],
    'PR_macro': [macro_PR_cardiologist, macro_PR_muse]
}

# Convert the dictionary into a DataFrame
df_simple = pd.DataFrame(metrics, index=['Cardiologist', 'MUSE'])

# Display the DataFrame
df_simple = df_simple.applymap(lambda x: round(x * 100, 1))


In [None]:
df_simple.to_csv('data/core_model_performance/overall_performance.csv', index=True)

In [None]:

# Compute the metrics for the provided categories
metrics_results = plot_micro_macro_statistics.compute_metrics_for_categories(df_SELECTED, categories)
# Converting the nested dictionary to DataFrame
metrics_df = pd.DataFrame.from_dict(metrics_results, orient='index').reset_index()


# Flattening the nested dictionaries and ensuring unique column names
for column in ['CARDIOLOGIST', 'MUSE']:
    flattened = pd.json_normalize(metrics_df[column])
    # Adding a prefix to the column names to ensure uniqueness
    flattened.columns = [f"{column}_{subcol}" for subcol in flattened.columns]
    # Dropping the original column to avoid name conflict
    metrics_df = metrics_df.drop(columns=[column])
    # Joining the flattened DataFrame
    metrics_df = metrics_df.join(flattened)
    

# Display the DataFrame
for col in metrics_df.columns:
    if metrics_df[col].dtype.kind in 'bifc':  # checks if the column is numerical
        metrics_df[col] = metrics_df[col].apply(lambda x: round(x * 100, 1))

In [None]:
display(metrics_df)
metrics_df.to_csv('data/core_model_performance/categorical_performance.csv', index=True)

### Compute for each label

In [None]:
# Usage
# Assuming df_merged is the merged DataFrame containing sigmoid_v_bin, one_hot_encoded_cardiologist_val, and one_hot_encoded_muse_val
metrics_results = plot_micro_macro_statistics.compute_individual_metrics(df_SELECTED, new_label_names)

# Create an empty DataFrame
df_metrics = pd.DataFrame()

# Building the DataFrame
data = []
for label, suffixes in metrics_results.items():
    row = {'Label': label}
    for suffix, metrics in suffixes.items():
        for metric_name, value in metrics.items():
            col_name = metric_name + suffix
            row[col_name] = value
    data.append(row)

# Convert the list of dictionaries to a DataFrame
df_metrics = pd.DataFrame(data)

# Reorder columns to put ROC columns next to each other and PR columns next to each other
df_metrics = df_metrics[['Label', 'ROC_CARDIOLOGIST', 'Sensitivity_CARDIOLOGIST', 'Specificity_CARDIOLOGIST', 'PR_CARDIOLOGIST', 'ROC_MUSE',  'Sensitivity_MUSE', 'Specificity_MUSE', 'PR_MUSE']]

df_metrics['PR_diff'] = df_metrics['PR_CARDIOLOGIST'] - df_metrics['PR_MUSE']
df_metrics['ROC_diff'] = df_metrics['ROC_CARDIOLOGIST'] - df_metrics['ROC_MUSE']
df_metrics['Sensitivity_diff'] = df_metrics['Sensitivity_CARDIOLOGIST'] - df_metrics['Sensitivity_MUSE']
df_metrics['Specificity_diff'] = df_metrics['Specificity_CARDIOLOGIST'] - df_metrics['Specificity_MUSE']


In [None]:

# Revised function to group the df_metrics DataFrame by the provided categories
def group_metrics_by_categories(df, categories):
    """
    Groups the df_metrics DataFrame by specified categories based on the 'Label' column.

    Args:
    df (pd.DataFrame): The DataFrame to be grouped (df_metrics).
    categories (dict): A dictionary where keys are category names and values are lists of labels belonging to each category.

    Returns:
    dict: A dictionary of DataFrames, where each key is a category and each value is a DataFrame of rows belonging to that category.
    """

    # Creating a label to category mapping
    label_to_category = {label: category for category, labels in categories.items() for label in labels}

    # Mapping each row to its category based on the 'Label' column
    df['Category'] = df['Label'].map(label_to_category)

    # Grouping the DataFrame by the 'Category' column
    grouped_data_dict = {category: group for category, group in df.groupby('Category')}

    return grouped_data_dict

# Applying the revised function to group df_metrics by the provided categories
grouped_metrics_data = group_metrics_by_categories(df_metrics, categories)



# Sorting this DataFrame by 'Category'
sorted_category_label_df = df_metrics.sort_values(by=['Category','ROC_CARDIOLOGIST'], ascending=False)

# Display the DataFrame
for col in sorted_category_label_df.columns:
    if sorted_category_label_df[col].dtype.kind in 'bifc':  # checks if the column is numerical
        sorted_category_label_df[col] = sorted_category_label_df[col].apply(lambda x: round(x * 100, 1))

sorted_category_label_df.to_csv('data/core_model_performance/individual_performance.csv', index=True)

In [None]:
### EchoNext

In [None]:
df = pd.read_csv('data/EchoNext/echo_results.per_ecg.csv')
display(df.head(n=5))

# this is to get the label frequency from the train set

In [None]:
Y_val_ = np.load('/media/data1/muse_ge/Y_train_v1.1.npy').astype(np.int64)

pos_to_drop = list()
new_label_names = list()
for pos, item in enumerate(og_labels):
    if item in to_remove_labels:
        pos_to_drop.append(pos)
    else:
        new_label_names.append(item)

#print(Y_train.shape)
Y_val_ = np.delete(Y_val_, pos_to_drop, axis=1)


label_counts = np.sum(Y_val_, axis=0)
label_counts/Y_val_.shape[0]

In [None]:
#add the prevalence to the df
df_out['Prevalence'] = label_counts/Y_val_.shape[0]

In [None]:
#create a label for ease of use in matplotlib
df_out_ = df_out[['ROC','Prevalence']]
df_out_.index = [f'{i} ({"{:.3f}".format(j)})' for i,j in zip(df_out_.index,df_out_.ROC)]

#### Robert final approach (seems finicky with figsize)

In [None]:
import seaborn as sns
from matplotlib.pyplot import figure
sns.set_style("whitegrid")

figure(figsize=(12, 2), dpi=80)
df_out_.sort_values('ROC').plot( kind= 'bar' , secondary_y= 'Prevalence' )
plt.xticks(rotation=90, fontsize=8)
plt.title('Distribution of ROC avg 3 seed')
#plt.savefig('/volume/core_model/ROC.jpg', dpi=600, bbox_inches="tight")


### General performance histogram

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
sns.histplot(df_out['ROC'])
plt.title("Score distribution")

# first approach without frenquency

In [None]:
from sklearn.metrics import roc_auc_score, auc, accuracy_score
from sklearn.metrics import precision_recall_curve
import seaborn as sns
import matplotlib.pyplot as plt

og_labels =  ['Sinusal','Regular','Monomorph','QS complex in V1-V2-V3','R complex in V5-V6','T wave inversion (inferior - II, III, aVF)','Left bundle branch block','RaVL > 11 mm','SV1 + RV5 or RV6 > 35 mm','T wave inversion (lateral -I, aVL, V5-V6)','T wave inversion (anterior - V3-V4)','Left axis deviation','Left ventricular hypertrophy','Bradycardia','Q wave (inferior - II, III, aVF)','Afib','Irregularly irregular','Atrial tachycardia (>= 100 BPM)','Nonspecific intraventricular conduction delay','Premature ventricular complex','Polymorph','T wave inversion (septal- V1-V2)','Right bundle branch block','Ventricular paced','ST elevation (anterior - V3-V4)','ST elevation (septal - V1-V2)','1st degree AV block','Premature atrial complex','Atrial flutter',"rSR' in V1-V2",'qRS in V5-V6-I, aVL','Left anterior fascicular block','Right axis deviation','2nd degree AV block - mobitz 1','ST depression (inferior - II, III, aVF)','Acute pericarditis','ST elevation (inferior - II, III, aVF)','Low voltage','Regularly irregular','Bifid','Junctional rhythm','Left atrial enlargement','ST elevation (lateral - I, aVL, V5-V6)','Atrial paced','Right ventricular hypertrophy','Delta wave','Wolff-Parkinson-White (Pre-excitation syndrome)','Prolonged QT','ST depression (anterior - V3-V4)','QRS complex negative in III','RaVL + SV3 > 28 mm (H) or 20 mm (F)','Q wave (lateral- I, aVL, V5-V6)','Hyperacute T wave (lateral, V5-V6)','Hyperacute T wave (septal, V1-V2)','Supraventricular tachycardia','ST downslopping','ST depression (lateral - I, avL, V5-V6)','2nd degree AV block - mobitz 2','U wave','ST depression et T inversion in V5 or V6','Large >0.08 s','R/S ratio in V1-V2 >1','RV1 + SV6\xa0> 11 mm','Left posterior fascicular block','Right atrial enlargement','ST depression (septal- V1-V2)','Q wave (septal- V1-V2)','Q wave (anterior - V3-V4)','Hyperacute T wave (anterior, V3-V4)','ST upslopping','Right superior axis','Auricular bigeminy','Ventricular tachycardia','ST elevation (posterior - V7-V8-V9)','Ectopic atrial rhythm (< 100 BPM)','Lead misplacement','Biphasic','Ventricular bigeminy','J wave','Tall >2.5 mm','Third Degree AV Block','Sinus Pause','Acute MI','Early repolarization','Q wave (posterior - V7-V9)','Bi-atrial enlargement','LV pacing','Dextrocardia','Brugada','Ventricular Rhythm','ST depression (posterior - V7-V8-V9)','no_qrs']
to_remove_labels = ['ST depression (posterior - V7-V8-V9)','Tall >2.5 mm', 'J wave', 'Auricular bigeminy', 'Ventricular bigeminy', 'Sinus Pause', 'Dextrocardia', 'Hyperacute T wave (lateral, V5-V6)', 'Hyperacute T wave (septal, V1-V2)', 'Hyperacute T wave (anterior, V3-V4)', 'Bifid', 'RaVL + SV3 > 28 mm (H) or 20 mm (F)', 'Large >0.08 s', 'Biphasic', 'ST depression et T inversion in V5 or V6']

Y_val = np.load('/media/data1/anolin/Y_val_v1.1.npy').astype(np.int64)

pos_to_drop = list()
new_label_names = list()
for pos, item in enumerate(og_labels):
    if item in to_remove_labels:
        pos_to_drop.append(pos)
    else:
        new_label_names.append(item)

#print(Y_train.shape)
Y_val = np.delete(Y_val, pos_to_drop, axis=1)

dict_results = dict(zip(new_label_names,[[0,0,0] for _ in range(len(new_label_names))]))
for i in [2023,1997,42]:
    Y_pred = np.load(f'/media/data1/anolin/results_benchmarkv2/resnet50_notscaled_{i}_v1.1/output_1.npy')
    Y_pred = sigmoid_v(Y_pred)
    sigmoid_v_bin =  np.where(Y_pred > 0.5, 1, 0)
    

    for pos, label in enumerate(new_label_names):
        try:
            ROC = roc_auc_score(Y_val[:,pos], sigmoid_v_bin[:,pos], average=None)
        except:
            ROC = 0.5
        precision, recall, thresholds = precision_recall_curve(Y_val[:,pos], sigmoid_v_bin[:,pos])
        PR = auc(recall, precision)

        acc = accuracy_score(Y_val[:,pos], sigmoid_v_bin[:,pos])

        dict_results[label][0] += ROC
        dict_results[label][1] += PR
        dict_results[label][2] += acc

for k,v in dict_results.items():
    v[0] = v[0]/3
    v[1] = v[1]/3
    v[2] = v[2]/3
df_out = pd.DataFrame.from_dict(dict_results).T
df_out.columns = ['ROC','PR','ACC']
df_out

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
sns.set(rc={'figure.figsize':(15.7,2)})

sns.barplot(y=df_out.sort_values('ACC', ascending=True)['ACC'], x=df_out.sort_values('ACC', ascending=True).index)
plt.xticks(rotation=90, fontsize=8)
plt.title('Distribution of ACC avg 3 seed')

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
sns.histplot(df_out['ROC'])
plt.title("Score distribution")

# test the scaling's impace on the results

In [None]:
Y_val = np.load('/media/data1/anolin/Y_val_v1.1.npy').astype(np.int64)
og_labels =  ['Sinusal','Regular','Monomorph','QS complex in V1-V2-V3','R complex in V5-V6','T wave inversion (inferior - II, III, aVF)','Left bundle branch block','RaVL > 11 mm','SV1 + RV5 or RV6 > 35 mm','T wave inversion (lateral -I, aVL, V5-V6)','T wave inversion (anterior - V3-V4)','Left axis deviation','Left ventricular hypertrophy','Bradycardia','Q wave (inferior - II, III, aVF)','Afib','Irregularly irregular','Atrial tachycardia (>= 100 BPM)','Nonspecific intraventricular conduction delay','Premature ventricular complex','Polymorph','T wave inversion (septal- V1-V2)','Right bundle branch block','Ventricular paced','ST elevation (anterior - V3-V4)','ST elevation (septal - V1-V2)','1st degree AV block','Premature atrial complex','Atrial flutter',"rSR' in V1-V2",'qRS in V5-V6-I, aVL','Left anterior fascicular block','Right axis deviation','2nd degree AV block - mobitz 1','ST depression (inferior - II, III, aVF)','Acute pericarditis','ST elevation (inferior - II, III, aVF)','Low voltage','Regularly irregular','Bifid','Junctional rhythm','Left atrial enlargement','ST elevation (lateral - I, aVL, V5-V6)','Atrial paced','Right ventricular hypertrophy','Delta wave','Wolff-Parkinson-White (Pre-excitation syndrome)','Prolonged QT','ST depression (anterior - V3-V4)','QRS complex negative in III','RaVL + SV3 > 28 mm (H) or 20 mm (F)','Q wave (lateral- I, aVL, V5-V6)','Hyperacute T wave (lateral, V5-V6)','Hyperacute T wave (septal, V1-V2)','Supraventricular tachycardia','ST downslopping','ST depression (lateral - I, avL, V5-V6)','2nd degree AV block - mobitz 2','U wave','ST depression et T inversion in V5 or V6','Large >0.08 s','R/S ratio in V1-V2 >1','RV1 + SV6\xa0> 11 mm','Left posterior fascicular block','Right atrial enlargement','ST depression (septal- V1-V2)','Q wave (septal- V1-V2)','Q wave (anterior - V3-V4)','Hyperacute T wave (anterior, V3-V4)','ST upslopping','Right superior axis','Auricular bigeminy','Ventricular tachycardia','ST elevation (posterior - V7-V8-V9)','Ectopic atrial rhythm (< 100 BPM)','Lead misplacement','Biphasic','Ventricular bigeminy','J wave','Tall >2.5 mm','Third Degree AV Block','Sinus Pause','Acute MI','Early repolarization','Q wave (posterior - V7-V9)','Bi-atrial enlargement','LV pacing','Dextrocardia','Brugada','Ventricular Rhythm','ST depression (posterior - V7-V8-V9)','no_qrs']
to_remove_labels = ['ST depression (posterior - V7-V8-V9)','Tall >2.5 mm', 'J wave', 'Auricular bigeminy', 'Ventricular bigeminy', 'Sinus Pause', 'Dextrocardia', 'Hyperacute T wave (lateral, V5-V6)', 'Hyperacute T wave (septal, V1-V2)', 'Hyperacute T wave (anterior, V3-V4)', 'Bifid', 'RaVL + SV3 > 28 mm (H) or 20 mm (F)', 'Large >0.08 s', 'Biphasic', 'ST depression et T inversion in V5 or V6']

Y_val = np.load('/media/data1/anolin/Y_val_v1.1.npy').astype(np.int64)

pos_to_drop = list()
new_label_names = list()
for pos, item in enumerate(og_labels):
    if item in to_remove_labels:
        pos_to_drop.append(pos)
    else:
        new_label_names.append(item)
Y_val = np.delete(Y_val, pos_to_drop, axis=1)
Y_val

In [None]:
Y_pred = np.load('/media/data1/anolin/results_benchmarkv2/resnet50_robustscaler_leads_1997_v1.1/output_1.npy')

def sigmoid(x):
  return 1 / (1 + math.exp(-x))

sigmoid_v = np.vectorize(sigmoid)
Y_pred = sigmoid_v(Y_pred)
sigmoid_v_bin =  np.where(Y_pred > 0.5, 1, 0)
sigmoid_v_bin

In [None]:
from sklearn.metrics import accuracy_score
from sklearn.metrics import average_precision_score, precision_recall_curve
from sklearn.metrics import roc_auc_score
import scikit_posthocs as sp
from statistics import mean

In [None]:
def hamming_score(y_true: np.ndarray, y_pred: np.ndarray):
    numerator = (y_true & y_pred).sum(axis=1)
    denominator = (y_true | y_pred).sum(axis=1)

    return np.divide(numerator, denominator, out=np.ones_like(numerator, dtype=np.float_),
                        where=denominator != 0).mean()

In [None]:
list_condition = list()
list_inter = list()
list_score = list()

metric = 'avg_acc' #['cat_accuracy', 'hamming', 'avg_acc', 'roc_macro', 'roc_micro', 'pr_macro', 'pr_micro']

Y_val = np.load('/media/data1/anolin/Y_val_v1.1.npy').astype(np.int64)
og_labels =  ['Sinusal','Regular','Monomorph','QS complex in V1-V2-V3','R complex in V5-V6','T wave inversion (inferior - II, III, aVF)','Left bundle branch block','RaVL > 11 mm','SV1 + RV5 or RV6 > 35 mm','T wave inversion (lateral -I, aVL, V5-V6)','T wave inversion (anterior - V3-V4)','Left axis deviation','Left ventricular hypertrophy','Bradycardia','Q wave (inferior - II, III, aVF)','Afib','Irregularly irregular','Atrial tachycardia (>= 100 BPM)','Nonspecific intraventricular conduction delay','Premature ventricular complex','Polymorph','T wave inversion (septal- V1-V2)','Right bundle branch block','Ventricular paced','ST elevation (anterior - V3-V4)','ST elevation (septal - V1-V2)','1st degree AV block','Premature atrial complex','Atrial flutter',"rSR' in V1-V2",'qRS in V5-V6-I, aVL','Left anterior fascicular block','Right axis deviation','2nd degree AV block - mobitz 1','ST depression (inferior - II, III, aVF)','Acute pericarditis','ST elevation (inferior - II, III, aVF)','Low voltage','Regularly irregular','Bifid','Junctional rhythm','Left atrial enlargement','ST elevation (lateral - I, aVL, V5-V6)','Atrial paced','Right ventricular hypertrophy','Delta wave','Wolff-Parkinson-White (Pre-excitation syndrome)','Prolonged QT','ST depression (anterior - V3-V4)','QRS complex negative in III','RaVL + SV3 > 28 mm (H) or 20 mm (F)','Q wave (lateral- I, aVL, V5-V6)','Hyperacute T wave (lateral, V5-V6)','Hyperacute T wave (septal, V1-V2)','Supraventricular tachycardia','ST downslopping','ST depression (lateral - I, avL, V5-V6)','2nd degree AV block - mobitz 2','U wave','ST depression et T inversion in V5 or V6','Large >0.08 s','R/S ratio in V1-V2 >1','RV1 + SV6\xa0> 11 mm','Left posterior fascicular block','Right atrial enlargement','ST depression (septal- V1-V2)','Q wave (septal- V1-V2)','Q wave (anterior - V3-V4)','Hyperacute T wave (anterior, V3-V4)','ST upslopping','Right superior axis','Auricular bigeminy','Ventricular tachycardia','ST elevation (posterior - V7-V8-V9)','Ectopic atrial rhythm (< 100 BPM)','Lead misplacement','Biphasic','Ventricular bigeminy','J wave','Tall >2.5 mm','Third Degree AV Block','Sinus Pause','Acute MI','Early repolarization','Q wave (posterior - V7-V9)','Bi-atrial enlargement','LV pacing','Dextrocardia','Brugada','Ventricular Rhythm','ST depression (posterior - V7-V8-V9)','no_qrs']
to_remove_labels = ['ST depression (posterior - V7-V8-V9)','Tall >2.5 mm', 'J wave', 'Auricular bigeminy', 'Ventricular bigeminy', 'Sinus Pause', 'Dextrocardia', 'Hyperacute T wave (lateral, V5-V6)', 'Hyperacute T wave (septal, V1-V2)', 'Hyperacute T wave (anterior, V3-V4)', 'Bifid', 'RaVL + SV3 > 28 mm (H) or 20 mm (F)', 'Large >0.08 s', 'Biphasic', 'ST depression et T inversion in V5 or V6']

Y_val = np.load('/media/data1/anolin/Y_val_v1.1.npy').astype(np.int64)

pos_to_drop = list()
new_label_names = list()
for pos, item in enumerate(og_labels):
    if item in to_remove_labels:
        pos_to_drop.append(pos)
    else:
        new_label_names.append(item)
Y_val = np.delete(Y_val, pos_to_drop, axis=1)

dict_eq = {42:0,1997:1,2023:2}

for condition in ['notscaled','standardscaler','minmaxscaler','maxabsscaler','robustscaler','quantiletransformeruniform','quantiletransformernormal']:
    for approach in [None,'leads']:
        if approach == 'leads' and condition == 'notscaled':
            continue
        for seed in [42,1997,2023]:
            if approach == None:
                output_matrix = np.load(f'/media/data1/anolin/results_benchmarkv2/resnet50_{condition}_{seed}_v1.1/output_1.npy')
            else:
                if condition != 'notscaled':
                    output_matrix = np.load(f'/media/data1/anolin/results_benchmarkv2/resnet50_{condition}_leads_{seed}_v1.1/output_1.npy')

            def sigmoid(x):
                return 1 / (1 + math.exp(-x))

            sigmoid_v = np.vectorize(sigmoid)
            Y_pred = sigmoid_v(output_matrix)
            sigmoid_v_bin =  np.where(Y_pred > 0.5, 1, 0)
            sigmoid_v_bin


            if approach != None:
                name = f'{condition}_{approach}'

            else:
                name = condition

            if metric == 'cat_accuracy':
                list_condition.append(name)
                list_inter.append(dict_eq[seed])
                list_score.append(accuracy_score(Y_val, sigmoid_v_bin))

            if metric == 'hamming':
                list_condition.append(name)
                list_inter.append(dict_eq[seed])
                list_score.append(accuracy_score(Y_val, sigmoid_v_bin))
            
            if metric == 'avg_acc':
                list_condition.append(name)
                list_inter.append(dict_eq[seed])
                list_score.append(mean([accuracy_score(Y_val[:,i], sigmoid_v_bin[:,i]) for i in range(Y_val.shape[1])]))

            if metric == 'roc_macro':
                list_condition.append(name)
                list_inter.append(dict_eq[seed])
                list_score.append(roc_auc_score(Y_val, sigmoid_v_bin, average='macro'))            
                                  
            if metric == 'roc_micro':
                list_condition.append(name)
                list_inter.append(dict_eq[seed])
                list_score.append(roc_auc_score(Y_val, sigmoid_v_bin, average='micro'))          


            if metric == 'pr_macro':
                list_condition.append(name)
                list_inter.append(dict_eq[seed])
                list_score.append(average_precision_score(Y_val, sigmoid_v_bin, average='macro'))          

            if metric == 'pr_micro':
                list_condition.append(name)
                list_inter.append(dict_eq[seed])
                list_score.append(average_precision_score(Y_val, sigmoid_v_bin, average='micro'))     


df_ = pd.DataFrame(zip(list_condition,list_inter,list_score), columns=['method','fold','score'])
avg_rank = df_.groupby('fold').score.rank(pct=True).groupby(df_.method).mean()
test_results = sp.posthoc_conover_friedman(
    df_,
    melted=True,
    block_col='fold',
    group_col='method',
    y_col='score',
)
#sp.sign_plot(test_results)
#plt.title("Conover test PR Macro")
sp.critical_difference_diagram(avg_rank, test_results)
plt.title("CDD for Acc")

# effect of filter filter

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import butter, filtfilt
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
import heartpy as hp
import gc


In [None]:
# generate the resutls
from sklearn.metrics import accuracy_score
from sklearn.metrics import average_precision_score, precision_recall_curve
from sklearn.metrics import roc_auc_score
import scikit_posthocs as sp

list_condition = list()
list_inter = list()
list_score = list()

metric = 'pr_micro' #['cat_accuracy', 'hamming', 'avg_acc', 'roc_macro', 'roc_micro', 'pr_macro', 'pr_micro']

Y_val = np.load('/media/data1/anolin/Y_val_v1.1.npy').astype(np.int64)
og_labels =  ['Sinusal','Regular','Monomorph','QS complex in V1-V2-V3','R complex in V5-V6','T wave inversion (inferior - II, III, aVF)','Left bundle branch block','RaVL > 11 mm','SV1 + RV5 or RV6 > 35 mm','T wave inversion (lateral -I, aVL, V5-V6)','T wave inversion (anterior - V3-V4)','Left axis deviation','Left ventricular hypertrophy','Bradycardia','Q wave (inferior - II, III, aVF)','Afib','Irregularly irregular','Atrial tachycardia (>= 100 BPM)','Nonspecific intraventricular conduction delay','Premature ventricular complex','Polymorph','T wave inversion (septal- V1-V2)','Right bundle branch block','Ventricular paced','ST elevation (anterior - V3-V4)','ST elevation (septal - V1-V2)','1st degree AV block','Premature atrial complex','Atrial flutter',"rSR' in V1-V2",'qRS in V5-V6-I, aVL','Left anterior fascicular block','Right axis deviation','2nd degree AV block - mobitz 1','ST depression (inferior - II, III, aVF)','Acute pericarditis','ST elevation (inferior - II, III, aVF)','Low voltage','Regularly irregular','Bifid','Junctional rhythm','Left atrial enlargement','ST elevation (lateral - I, aVL, V5-V6)','Atrial paced','Right ventricular hypertrophy','Delta wave','Wolff-Parkinson-White (Pre-excitation syndrome)','Prolonged QT','ST depression (anterior - V3-V4)','QRS complex negative in III','RaVL + SV3 > 28 mm (H) or 20 mm (F)','Q wave (lateral- I, aVL, V5-V6)','Hyperacute T wave (lateral, V5-V6)','Hyperacute T wave (septal, V1-V2)','Supraventricular tachycardia','ST downslopping','ST depression (lateral - I, avL, V5-V6)','2nd degree AV block - mobitz 2','U wave','ST depression et T inversion in V5 or V6','Large >0.08 s','R/S ratio in V1-V2 >1','RV1 + SV6\xa0> 11 mm','Left posterior fascicular block','Right atrial enlargement','ST depression (septal- V1-V2)','Q wave (septal- V1-V2)','Q wave (anterior - V3-V4)','Hyperacute T wave (anterior, V3-V4)','ST upslopping','Right superior axis','Auricular bigeminy','Ventricular tachycardia','ST elevation (posterior - V7-V8-V9)','Ectopic atrial rhythm (< 100 BPM)','Lead misplacement','Biphasic','Ventricular bigeminy','J wave','Tall >2.5 mm','Third Degree AV Block','Sinus Pause','Acute MI','Early repolarization','Q wave (posterior - V7-V9)','Bi-atrial enlargement','LV pacing','Dextrocardia','Brugada','Ventricular Rhythm','ST depression (posterior - V7-V8-V9)','no_qrs']
to_remove_labels = ['ST depression (posterior - V7-V8-V9)','Tall >2.5 mm', 'J wave', 'Auricular bigeminy', 'Ventricular bigeminy', 'Sinus Pause', 'Dextrocardia', 'Hyperacute T wave (lateral, V5-V6)', 'Hyperacute T wave (septal, V1-V2)', 'Hyperacute T wave (anterior, V3-V4)', 'Bifid', 'RaVL + SV3 > 28 mm (H) or 20 mm (F)', 'Large >0.08 s', 'Biphasic', 'ST depression et T inversion in V5 or V6']

Y_val = np.load('/media/data1/anolin/Y_val_v1.1.npy').astype(np.int64)

pos_to_drop = list()
new_label_names = list()
for pos, item in enumerate(og_labels):
    if item in to_remove_labels:
        pos_to_drop.append(pos)
    else:
        new_label_names.append(item)
Y_val = np.delete(Y_val, pos_to_drop, axis=1)

dict_eq = {42:0,1997:1,2023:2}

for low_cut in [1, 0.1, 0.01, -1]: 
    for high_cut in [100, 75, 50, -1]:
        if low_cut == high_cut == -1:
            continue

        for seed in [42, 1997, 2023]:

            output_matrix = np.load(f"/media/data1/anolin/results_benchmarkv2/resnet50_filtered_{low_cut}_{high_cut}_{seed}_v1.1/output_1.npy")


            sigmoid_v = np.vectorize(sigmoid)
            Y_pred = sigmoid_v(output_matrix)
            sigmoid_v_bin =  np.where(Y_pred > 0.5, 1, 0)
            sigmoid_v_bin

            if low_cut != -1 and high_cut != -1:
                name_condition = f'BP_{low_cut}_{high_cut}Hz'

            elif low_cut == -1 and high_cut != -1:
                name_condition = f'LP_{high_cut}Hz'

            elif low_cut != -1 and high_cut == -1:
                name_condition = f'HP_{low_cut}Hz'

            else:
                pass


            if metric == 'cat_accuracy':
                list_condition.append(name_condition)
                list_inter.append(dict_eq[seed])
                list_score.append(accuracy_score(Y_val, sigmoid_v_bin))

            if metric == 'hamming':
                list_condition.append(name_condition)
                list_inter.append(dict_eq[seed])
                list_score.append(accuracy_score(Y_val, sigmoid_v_bin))
            
            if metric == 'avg_acc':
                list_condition.append(name_condition)
                list_inter.append(dict_eq[seed])
                list_score.append(mean([accuracy_score(Y_val[:,i], sigmoid_v_bin[:,i]) for i in range(Y_val.shape[1])]))

            if metric == 'roc_macro':
                list_condition.append(name_condition)
                list_inter.append(dict_eq[seed])
                list_score.append(roc_auc_score(Y_val, sigmoid_v_bin, average='macro'))            
                                    
            if metric == 'roc_micro':
                list_condition.append(name_condition)
                list_inter.append(dict_eq[seed])
                list_score.append(roc_auc_score(Y_val, sigmoid_v_bin, average='micro'))          


            if metric == 'pr_macro':
                list_condition.append(name_condition)
                list_inter.append(dict_eq[seed])
                list_score.append(average_precision_score(Y_val, sigmoid_v_bin, average='macro'))          

            if metric == 'pr_micro':
                list_condition.append(name_condition)
                list_inter.append(dict_eq[seed])
                list_score.append(average_precision_score(Y_val, sigmoid_v_bin, average='micro'))     
    
df_ = pd.DataFrame(zip(list_condition,list_inter,list_score), columns=['method','fold','score'])
df_ = pd.concat([df_, df_first[df_first.method == 'notscaled']])
avg_rank = df_.groupby('fold').score.rank(pct=True).groupby(df_.method).mean()
test_results = sp.posthoc_conover_friedman(
    df_,
    melted=True,
    block_col='fold',
    group_col='method',
    y_col='score',
)
#sp.sign_plot(test_results)
#plt.title("Conover test PR Macro")
sp.critical_difference_diagram(avg_rank, test_results)
plt.title("CDD for PR Micro")

In [None]:
avg_rank = df_.groupby('fold').score.rank(pct=True).groupby(df_.method).mean()
test_results = sp.posthoc_conover_friedman(
    df_,
    melted=True,
    block_col='fold',
    group_col='method',
    y_col='score',
)
#sp.sign_plot(test_results)
#plt.title("Conover test PR Macro")
sp.critical_difference_diagram(avg_rank, test_results)
plt.title("CDD for Acc")

In [None]:
#generate the datasets
list_condition = list()
list_inter = list()
list_score = list()



for file in tqdm(['X_val'], desc='levels'):

    X_val = np.load(f'/media/data1/anolin/{file}_v1.1.npy').astype(np.float16)

    for low_cut in [1, 0.1, 0.01, -1]: 
        for high_cut in [100, 75, 50, -1]:

            print(f'low_cut: {low_cut}Hz')
            print(f'high_cut: {high_cut}Hz')

            if low_cut == -1 and high_cut == -1:
                continue

            # Bandpass filter parameters
            lowcut = low_cut  # Define your own lowcut frequency
            highcut = high_cut  # Define your own highcut frequency
            fs = 250  # Sampling frequency
            order = 2  # Filter order
            N = X_val.shape[0] # Replace with your actual N

            # Create bandpass filter coefficients
            if lowcut == -1:
                b, a = butter(order, highcut, btype='low', fs=fs)

            elif highcut == -1:
                b, a = butter(order, lowcut, btype='high', fs=fs)

            else:
                b, a = butter(order, [lowcut,highcut], btype='bandpass', fs=fs)

            def apply_filter(data_slice):
                """Applies the bandpass filter to a slice of the data."""
                filtered_slice = np.empty_like(data_slice).astype(np.float16)
                for i in range(data_slice.shape[0]):
                    for j in range(data_slice.shape[-1]):
                        filtered_slice[i, :, j] = filtfilt(b, a, data_slice[i, :, j])
                return filtered_slice

            # Divide data into chunks for parallel processing
            num_processes = cpu_count()
            chunk_size = N // num_processes
            chunks = [X_val[i:i + chunk_size].astype(np.float16) for i in range(0, N, chunk_size)]

            # Perform parallel processing with progress tracking
            with Pool(num_processes) as pool:
                results = list(tqdm(pool.imap(apply_filter, chunks), total=len(chunks)))

            # Reassemble the results
            filtered_data = np.concatenate(results, axis=0).astype(np.float16)

            np.save(f'/media/data1/anolin/{file}_filtered_{low_cut}_{high_cut}_v1.1.npy', filtered_data.astype(np.float16))
            gc.collect()
                        

In [None]:
import numpy as np
from scipy.signal import butter, filtfilt

# Bandpass filter parameters
lowcut = 0.01  # Define your own lowcut frequency
highcut = 100  # Define your own highcut frequency
fs = 250  # Sampling frequency
order = 2  # Filter order
N = X_val.shape[0] # Replace with your actual N

# Create bandpass filter coefficients
b, a = butter(order, lowcut, btype='highpass', fs=fs)

def apply_filter(data_slice):
    """Applies the bandpass filter to a slice of the data."""
    filtered_slice = np.empty_like(data_slice).astype(np.float16)
    for i in range(data_slice.shape[0]):
        for j in range(data_slice.shape[-1]):
            filtered_slice[i, :, j] = filtfilt(b, a, data_slice[i, :, j])
    return filtered_slice

# Divide data into chunks for parallel processing
num_processes = cpu_count()
chunk_size = N // num_processes
chunks = [X_val[i:i + chunk_size].astype(np.float16) for i in range(0, N, chunk_size)]

# Perform parallel processing with progress tracking
with Pool(num_processes) as pool:
    results = list(tqdm(pool.imap(apply_filter, chunks), total=len(chunks)))

# Reassemble the results
filtered_data = np.concatenate(results, axis=0)