In [34]:
import pickle
from itertools import product

import numpy as np
import pandas as pd
from scipy import stats

from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix

In [29]:
CLASS_DICT = {
    0: "W",
    1: "N1",
    2: "N2",
    3: "N3",
    4: "REM"
}

In [30]:
class_names = CLASS_DICT.values()

results = []
classwise_results = []
for model, split in product(["SESM","EEGNet"],[1,2,3]):
    path = f"/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri1/Results/{model}/split{split}/"
    
    with open(f'{path}/y_true_test.pkl','rb') as f:
            y_true = pickle.load(f)
        
    with open(f'{path}/y_pred_test.pkl','rb') as f:
            y_pred = pickle.load(f)
    
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average="macro")
    recall = recall_score(y_true, y_pred, average="macro")

    row = {
        "split": split,
        "model": model,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
    }
    results.append(row)

    # Create a confusion matrix
    confusion = confusion_matrix(y_true, y_pred)
    class_accuracy = confusion.diagonal()/confusion.sum(axis=1)

    classwise_row = { class_name: class_accuracy[i] for i, class_name in enumerate(class_names) }
    classwise_row["split"] = split
    classwise_row["model"] = model
    classwise_results.append(classwise_row)

results_df = pd.DataFrame(results)
classwise_df = pd.DataFrame(classwise_results)

In [25]:
grouped = results_df.groupby("model")

# Separate the data for the two models
sesm_data = grouped.get_group('SESM')
eegnet_data = grouped.get_group('EEGNet')

# Perform t-tests for each metric
accuracy_ttest = stats.ttest_ind(sesm_data['accuracy'], eegnet_data['accuracy'])
precision_ttest = stats.ttest_ind(sesm_data['precision'], eegnet_data['precision'])
recall_ttest = stats.ttest_ind(sesm_data['recall'], eegnet_data['recall'])

# Create the resulting DataFrame
t_test_results = pd.DataFrame({
    'metric': ['accuracy', 'precision', 'recall'],
    'statistic': [accuracy_ttest.statistic, precision_ttest.statistic, recall_ttest.statistic],
    'p_value': [accuracy_ttest.pvalue, precision_ttest.pvalue, recall_ttest.pvalue]
})


In [26]:
t_test_results

Unnamed: 0,metric,statistic,p_value
0,accuracy,5.277954,0.006179
1,precision,1.020844,0.365045
2,recall,-2.208778,0.091751


In [27]:
stats.false_discovery_control(t_test_results["p_value"])

array([0.01853612, 0.36504495, 0.13762632])

# Class wise accuracies

In [31]:
grouped = classwise_df.groupby("model")

# Separate the data for the two models
sesm_data = grouped.get_group('SESM')
eegnet_data = grouped.get_group('EEGNet')

# Perform t-tests for each metric
classwise_test_results = []
for class_name in class_names:
    t_test = stats.ttest_ind(sesm_data[class_name], eegnet_data[class_name])
    classwise_test_results.append(t_test)

# Create the resulting DataFrame
classwise_t_test_results = pd.DataFrame({
    'metric': class_names,
    'statistic': [test.statistic for test in classwise_test_results],
    'p_value': [test.pvalue for test in classwise_test_results]
})

In [32]:
classwise_t_test_results

Unnamed: 0,metric,statistic,p_value
0,W,0.709489,0.517189
1,N1,-10.160507,0.000528
2,N2,2.616859,0.058994
3,N3,-1.960218,0.121524
4,REM,0.35958,0.737341


In [33]:
stats.false_discovery_control(classwise_t_test_results["p_value"])

array([0.64648622, 0.00264193, 0.14748384, 0.20254033, 0.73734121])

Predicted Class: W
1.4668883
16
0.5166144
TtestResult(statistic=4.0880449087745285, pvalue=0.00013340474855422295, df=59.0)


Predicted Class: N2
1.4572722
109
0.43939048
TtestResult(statistic=10.255296820669766, pvalue=8.447303667041605e-22, df=357.0)


Predicted Class: N3
1.4774722
7
0.43798834
TtestResult(statistic=2.3111056661108242, pvalue=0.032880124854366775, df=18.0)


Predicted Class: REM
1.7485868
66
0.4953538
TtestResult(statistic=10.495021118314252, pvalue=1.4186766905306205e-21, df=249.0)

In [36]:
stats.false_discovery_control([0.00013340474855422295,8.447303667041605e-22,0.032880124854366775,1.4186766905306205e-21])

array([1.77872998e-04, 2.83735338e-21, 3.28801249e-02, 2.83735338e-21])