In [13]:
%load_ext lab_black

In [14]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
import seaborn as sns

# Kolmogorov-Smirnov Test
# The Kolmogorov-Smirnov test compares the data to a normal distribution.
from scipy.stats import kstest
from itertools import combinations
from scipy.stats import wilcoxon

from mne_connectivity import spectral_connectivity_epochs

In [15]:
process_epoch_path = "../processed_epochs.pkl"
with open(process_epoch_path, "rb") as f:
    epochs = pickle.load(f)

for key, value in epochs.items():
    print(key, value.shape)

pre (7, 49, 126001)
during (5, 49, 126001)
post (5, 49, 126001)
combine (17, 49, 126001)


In [16]:
# wpli_connectivity_permutation_analysis.pkl
# wpli_connectivity.pkl -> The difference is statistically significant
# dpli_connectivity.pkl -> -


features_path = "../wpli_connectivity.pkl"
with open(features_path, "rb") as f:
    features = pickle.load(f)

for key, value in features.items():
    print(key, value.shape)

pre (2401, 1)
during (2401, 1)
post (2401, 1)
combine (2401, 1)


In [17]:
def extract_lower_triangular(matrix):
    return matrix[np.tril_indices(matrix.shape[0], k=-1)]


def extract_stage_lower_triangular(connectivity):
    return extract_lower_triangular(connectivity.get_data("dense").squeeze())


# Step 3: Visualize the fit
def plot_fit(data, mean, std_dev, title: str = "pre"):
    # Plot the histogram of the data with Seaborn
    sns.histplot(
        data,
        bins=30,
        kde=False,
        stat="density",
        color="blue",
        edgecolor="black",
        alpha=0.6,
    )

    # Plot the PDF of the fitted normal distribution
    xmin, xmax = plt.xlim()
    x = np.linspace(xmin, xmax, 100)
    p = stats.norm.pdf(x, mean, std_dev)
    plt.plot(x, p, "k", linewidth=2)
    title = f"{title} - Fit results: mean = %.2f,  std dev = %.2f" % (mean, std_dev)
    plt.title(title)
    plt.show()


def ks_test(data):
    """
    Kolmogorov-Smirnov Test
    The Kolmogorov-Smirnov test compares the data to a normal
    distribution.

    Parameters
    ----------
    data : array-like
        The data to test for normality.
    """
    stat, p = kstest(data, "norm")
    print("Statistics=%.3f, p=%.3f" % (stat, p))
    if p > 0.05:
        print("Data is normally distributed (fail to reject H0)")
    else:
        print("Data is not normally distributed (reject H0)")


def perform_paired_ttest(sample1, sample2, alpha=0.05):
    t_statistic, p_value = stats.ttest_rel(sample1, sample2)

    return t_statistic, p_value

To determine if the distributions shown in the histograms are normally distributed, we can visually inspect the histograms and the overlaid density curves, as well as consider the mean and standard deviation provided.

First Image:

Mean = 0.47
Standard Deviation = 0.02
Visual Inspection: The histogram appears to have a skew to the right, indicating it might not be perfectly normally distributed. The density curve does not fit perfectly over the histogram, showing some deviation from normality.
Second Image:

Mean = 0.51
Standard Deviation = 0.02
Visual Inspection: The histogram looks more symmetric compared to the first one, and the density curve fits better over the histogram. This suggests that the distribution might be closer to a normal distribution.
Third Image:

Mean = 0.50
Standard Deviation = 0.02
Visual Inspection: Similar to the second image, this histogram appears quite symmetric, and the density curve fits well over the histogram. This indicates that the distribution is likely to be normally distributed.
Based on visual inspection:

The second and third images seem to be normally distributed.
The first image appears to have some skew and may not be normally distributed.
For a more accurate assessment, statistical tests like the Shapiro-Wilk test or the Kolmogorov-Smirnov test could be performed on the data to determine normality.

In [None]:
# check whether features are distributed normally or not
for key, value in features.items():
    data = extract_stage_lower_triangular(value)
    # filter non-zero values
    print(key, data.shape)
    mean, std = data.mean(), data.std()

    print(f"Estimated mean: {mean}")
    print(f"Estimated standard deviation: {std}")

    # print min and max
    print(f"Min: {data.min()}")
    print(f"Max: {data.max()}")

    ks_test(data)

    # Plot the fitted distribution
    plot_fit(data, mean, std, title=key)

In [None]:
# Calculate t-test between classes

alpha = 0.05

# calculate the correlation between the lower triangular matrices
correlations = {}
for stage1, stage2 in combinations(features.keys(), 2):
    print(f"\nComparing {stage1} and {stage2}:")
    t_statistic, p_value = perform_paired_ttest(
        extract_stage_lower_triangular(features[stage1]),
        extract_stage_lower_triangular(features[stage2]),
    )

    print("\nPaired Samples T-Test Results:")
    print(f"T-statistic: {t_statistic}")
    print(f"P-value: {p_value}")

    if p_value < alpha:
        print(f"The difference is statistically significant (p < {alpha})")
    else:
        print(f"The difference is not statistically significant (p >= {alpha})")

In [None]:
# Calculate t-test between classes

alpha = 0.05

# calculate the correlation between the lower triangular matrices
correlations = {}
for stage1, stage2 in combinations(features.keys(), 2):
    print(f"\nComparing {stage1} and {stage2}:")
    t_statistic, p_value = wilcoxon(
        extract_stage_lower_triangular(features[stage1]),
        extract_stage_lower_triangular(features[stage2]),
    )

    print("\nPaired Samples T-Test Results:")
    print(f"T-statistic: {t_statistic}")
    print(f"P-value: {p_value}")

    if p_value < alpha:
        print(f"The difference is statistically significant (p < {alpha})")
    else:
        print(f"The difference is not statistically significant (p >= {alpha})")

In [None]:
# performing permutation test


def calculate_connectivity(epochs, sfreq, fmin, fmax, tmin, method: str = "wpli"):

    return spectral_connectivity_epochs(
        epochs,
        method=method,
        mode="multitaper",
        sfreq=sfreq,
        fmin=fmin,
        fmax=fmax,
        faverage=True,  # Average connectivity scores for each frequency band. If True, the output freqs will be a list with arrays of the frequencies that were averaged.
        tmin=tmin,
        mt_adaptive=False,  # Use adaptive weights for multitaper
        n_jobs=4,
    )


def permutation_test(epochs, n_permutations=1000, alpha=0.05, *args, **kwargs):
    # Combine all data
    combined_data = epochs["combine"].copy()

    perm_stats_pre_post = []
    perm_stats_pre_during = []
    perm_stats_during_post = []

    for _ in range(n_permutations):
        np.random.shuffle(combined_data)
        shuffled_pre = combined_data[:7]
        shuffled_during = combined_data[7:12]
        shuffled_post = combined_data[12:]

        perm_wPLI_pre = extract_stage_lower_triangular(
            calculate_connectivity(shuffled_pre, *args, **kwargs)
        )
        perm_wPLI_during = extract_stage_lower_triangular(
            calculate_connectivity(shuffled_during, *args, **kwargs)
        )
        perm_wPLI_post = extract_stage_lower_triangular(
            calculate_connectivity(shuffled_post, *args, **kwargs)
        )

        perm_stats_pre_post.append(
            perform_paired_ttest(perm_wPLI_post.flatten(), perm_wPLI_pre.flatten())
        )
        perm_stats_pre_during.append(
            perform_paired_ttest(perm_wPLI_during.flatten(), perm_wPLI_pre.flatten())
        )
        perm_stats_during_post.append(
            perform_paired_ttest(perm_wPLI_post.flatten(), perm_wPLI_during.flatten())
        )

    # create a dictionary to store the results
    results = {
        "pre_post": perm_stats_pre_post,
        "pre_during": perm_stats_pre_during,
        "during_post": perm_stats_during_post,
    }

    return results

In [None]:
sfreq = 500
fmin = 8
fmax = 50
tmin = 0
method = "wpli"
n_permutation = 1

permutation_test(
    epochs=epochs,
    n_permutations=n_permutation,
    alpha=0.05,
    sfreq=sfreq,
    fmin=fmin,
    fmax=fmax,
    tmin=tmin,
    method=method,
)

In [None]:
import os
import pickle
import pandas as pd
from glob import glob

def load_pickle_file(filename):
    with open(filename, 'rb') as f:
        return pickle.load(f)
    
    
def process_pickle_files(directory, pattern='permutation_results_*.pkl'):
    all_data = []
    
    # Get all matching pickle files in the directory
    pickle_files = glob(os.path.join(directory, pattern))
    
    for file in pickle_files:
        data = load_pickle_file(file)
        for item in data:
            flattened_item = {}
            for key, value in item.items():
                if isinstance(value, tuple):
                    flattened_item[f'{key}_statistic'] = value[0]
                    flattened_item[f'{key}_pvalue'] = value[1]
                else:
                    flattened_item[key] = value
            flattened_item['source_file'] = os.path.basename(file)
            all_data.append(flattened_item)
    
    return pd.DataFrame(all_data)

directory = "/Users/soroush/Documents/Code/freelance-project/vielight/vielight_close_loop/results/permutation_conditions"
df = process_pickle_files(directory)

In [11]:
df.head()

Unnamed: 0,pre_post_statistic,pre_post_pvalue,pre_during_statistic,pre_during_pvalue,during_post_statistic,during_post_pvalue,source_file
0,97.507643,0.0,204.292899,0.0,-102.240614,0.0,permutation_results_37.pkl
1,48.083282,8.174306e-280,46.893947,2.081649e-271,5.892808,4.953931e-09,permutation_results_37.pkl
2,98.663983,0.0,90.099151,0.0,13.116101,8.890592e-37,permutation_results_37.pkl
3,174.117323,0.0,96.456856,0.0,86.588883,0.0,permutation_results_37.pkl
4,119.879075,0.0,207.020505,0.0,-77.797524,0.0,permutation_results_37.pkl


In [10]:
df["source_file"].sort_values().unique()

array(['permutation_results_0.pkl', 'permutation_results_1.pkl',
       'permutation_results_10.pkl', 'permutation_results_11.pkl',
       'permutation_results_12.pkl', 'permutation_results_13.pkl',
       'permutation_results_14.pkl', 'permutation_results_15.pkl',
       'permutation_results_16.pkl', 'permutation_results_17.pkl',
       'permutation_results_18.pkl', 'permutation_results_19.pkl',
       'permutation_results_2.pkl', 'permutation_results_20.pkl',
       'permutation_results_21.pkl', 'permutation_results_22.pkl',
       'permutation_results_23.pkl', 'permutation_results_24.pkl',
       'permutation_results_25.pkl', 'permutation_results_26.pkl',
       'permutation_results_27.pkl', 'permutation_results_28.pkl',
       'permutation_results_29.pkl', 'permutation_results_3.pkl',
       'permutation_results_30.pkl', 'permutation_results_31.pkl',
       'permutation_results_32.pkl', 'permutation_results_33.pkl',
       'permutation_results_34.pkl', 'permutation_results_35.pkl',

In [20]:
with open(
    "/Users/soroush/Documents/Code/freelance-project/vielight/vielight_close_loop/wpli_connectivity.pkl",
    "rb",
) as f:
    wpli_connectivity = pickle.load(f)

triangular_matrice = {
    stage: extract_stage_lower_triangular(wpli_connectivity[stage])
    for stage in ["pre", "during", "post"]
}

from itertools import combinations

# calculate the correlation between the lower triangular matrices
correlations = {}
for stage1, stage2 in combinations(triangular_matrice.keys(), 2):
    correlations[f"{stage1}_{stage2}"] = wilcoxon(
        triangular_matrice[stage1], triangular_matrice[stage2]
    )

correlations

{'pre_during': WilcoxonResult(statistic=18831.0, pvalue=1.3616590055863041e-173),
 'pre_post': WilcoxonResult(statistic=41988.0, pvalue=3.6044230110910713e-150),
 'during_post': WilcoxonResult(statistic=110922.0, pvalue=1.382048093713617e-90)}

In [23]:
correlations["pre_during"]

WilcoxonResult(statistic=18831.0, pvalue=1.3616590055863041e-173)

In [25]:
def calculate_p_value(correlation, statistic):
    return np.mean([1 if x >= correlation else 0 for x in statistic])


condition = "pre_during"
p_value_pre_during = calculate_p_value(
    correlations[f"{condition}"][0], df[f"{condition}_statistic"]
)

condition = "pre_post"
p_value_pre_post = calculate_p_value(
    correlations[f"{condition}"][0], df[f"{condition}_statistic"]
)

condition = "during_post"
p_value_during_post = calculate_p_value(
    correlations[f"{condition}"][0], df[f"{condition}_statistic"]
)

print(f"p-value for pre-during: {p_value_pre_during}")
print(f"p-value for pre-post: {p_value_pre_post}")
print(f"p-value for during-post: {p_value_during_post}")

p-value for pre-during: 0.0
p-value for pre-post: 0.0
p-value for during-post: 0.0
