In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist, squareform
import sys
sys.path.append('..')
import complexitylib as complex

In [2]:
def compare_6_groups(eeg_data, group_labels, conditions, m=10, tau=1):
    """
    Box-plot the per-subject mean correlation dimension for every
    (group, condition) cohort and return the underlying values.

    For each subject i, complex.compute_subject_dims(eeg_data[i], m=m, tau=tau)
    is called and its result is averaged into a single number with np.mean.
    The per-subject means are bucketed by (group_labels[i], conditions[i]) and
    drawn as one box per bucket, ordered by group and, within a group, with
    'pre' before 'post'.

    Parameters
    ----------
    eeg_data: np.ndarray of shape (n_subjects, n_channels, n_timepoints)
    group_labels: list of group IDs (e.g., [0, 1, 2, 0, 1, 2, ...]); the IDs
        must support `g + 1` because the tick labels are built that way
    conditions: list of 'pre' or 'post' strings (same length as group_labels)
    m, tau: forwarded unchanged to complex.compute_subject_dims
        (presumably embedding dimension and time delay -- TODO confirm
        against complexitylib)

    Returns
    -------
    defaultdict(list) mapping (group_id, condition) to the list of
    per-subject mean values collected for that cohort.

    Raises
    ------
    AssertionError
        If group_labels, conditions and eeg_data.shape[0] differ in length.
    """
    from collections import defaultdict

    assert len(group_labels) == len(conditions) == eeg_data.shape[0], "Mismatch in input sizes."

    # Group storage: (group_id, condition) -> list of per-subject D2 means
    group_data = defaultdict(list)

    for i in range(eeg_data.shape[0]):
        dims = complex.compute_subject_dims(eeg_data[i], m=m, tau=tau)
        mean_dim = np.mean(dims)
        key = (group_labels[i], conditions[i])
        group_data[key].append(mean_dim)

    # A plain sort would order conditions alphabetically and therefore draw
    # 'post' before 'pre'; rank them explicitly so each group reads
    # pre -> post. Unknown condition strings sort after the known ones.
    condition_rank = {'pre': 0, 'post': 1}

    def _plot_order(key):
        group_id, condition = key
        return (group_id, condition_rank.get(condition, len(condition_rank)), str(condition))

    all_groups = sorted(group_data, key=_plot_order)
    plot_data = [group_data[key] for key in all_groups]
    group_names = [f"Group {g+1} ({c})" for g, c in all_groups]

    # Plot. Tick labels are set on the axes instead of through boxplot's
    # `labels=` keyword, which newer Matplotlib releases deprecate in favour
    # of `tick_labels=`; this form works on both old and new versions.
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.boxplot(plot_data)
    ax.set_xticklabels(group_names, rotation=45)
    ax.set_ylabel("Mean Correlation Dimension")
    ax.set_title("Correlation Dimension: Pre vs Post per Group")
    ax.grid(True)
    fig.tight_layout()
    plt.show()

    return group_data


In [3]:
# Convert patients_clusters.csv to a numpy 2darray

file_path = '../output/patients_clusters.csv'

# Every field is read as a string and the first line of the file is skipped.
# ndmin=2 keeps the result two-dimensional even when the file contains a
# single data row; without it np.loadtxt would return a 1-D array and
# iterating over `clusters` would yield individual fields instead of rows.
clusters = np.loadtxt(file_path, delimiter=',', skiprows=1, dtype=str, ndmin=2)

In [4]:
# One Patient per row of `clusters`: column 0 is used as the name, column 1
# (converted to int) as the cluster, and the row index as patient_id.
patients = [
    complex.Patient(name=row[0], cluster=int(row[1]), patient_id=row_index)
    for row_index, row in enumerate(clusters)
]


In [5]:
# Patients excluded from the analysis:
#   DD09, DD45, DD53 - no second recording available
#   DD28             - second recording is corrupted
excluded_patient_names = ('DD09', 'DD28', 'DD45', 'DD53')
patients = [p for p in patients if p.name not in excluded_patient_names]

In [6]:
# aggregate_data receives two blocks per cluster, in this order: the cluster's
# pre-treatment signals, then its post-treatment signals. group_labels and
# conditions get one entry per appended signal, in the same order.
aggregate_data = []
group_labels = []
conditions = []


unique_clusters = set(p.cluster for p in patients)
# Iterate over the cluster IDs that actually occur instead of
# range(len(unique_clusters)): the latter silently skips patients whenever
# the IDs are not exactly 0..n-1 (e.g. 1-based IDs, or a cluster that became
# empty after patients were dropped).
for cluster_id in sorted(unique_clusters):
    patients_in_cluster = [p for p in patients if p.cluster == cluster_id]
    for condition in ('pre', 'post'):
        group_data = []
        for pat in patients_in_cluster:
            if condition == 'pre':
                # pre-treatment
                signal, channels = pat.get_eeg_before()
            else:
                # post-treatment
                signal, channels = pat.get_eeg_after()
            # signal as nested list
            group_data.append(signal.tolist())
            group_labels.append(pat.cluster)
            conditions.append(condition)
        aggregate_data.append(group_data)

In [11]:
# Sanity check of the collected data, one size per line:
# number of blocks, entries in the first block, length of the first entry of
# that block, then the number of labels and of conditions.
print(len(aggregate_data))
first_block = aggregate_data[0]
print(len(first_block))
print(len(first_block[0]))
for per_signal_list in (group_labels, conditions):
    print(len(per_signal_list))

6
24
32
92
92


In [None]:
# Rebuild the per-recording arrays from aggregate_data instead of reusing
# whatever `group_data` was left in the kernel: after the aggregation loop it
# only refers to the last block appended, and its entries are nested lists
# (no .shape attribute). Flattening block by block preserves the order in
# which group_labels and conditions were filled.
recordings = [np.asarray(signal) for block in aggregate_data for signal in block]
assert len(recordings) == len(group_labels) == len(conditions), "Mismatch in input sizes."

# Ensure all arrays in group_data have the same size along dimension 1
min_length = min(rec.shape[1] for rec in recordings)
group_data = [rec[:, :min_length] for rec in recordings]

# Merge all EEGs (rows of every recording concatenated along axis 0)
eeg_data_all = np.concatenate(group_data, axis=0)

print(eeg_data_all.shape)

# Run analysis
# compare_6_groups asserts eeg_data.shape[0] == len(group_labels), so it needs
# one entry per recording along axis 0 rather than the concatenated 2-D array.
# np.stack works here provided every recording has the same number of rows.
# results = compare_6_groups(np.stack(group_data), group_labels, conditions, m=8, tau=2)

        

(2944, 89600)
