In [None]:
import pickle
import mne
import numpy as np
from collections import Counter
import networkx as nx
from collections import defaultdict
from sklearn.manifold import SpectralEmbedding
import matplotlib.pyplot as plt
import matplotlib
import pickle
import numpy as np
from scipy.spatial.distance import pdist, squareform
from scipy.stats import entropy
from sklearn.decomposition import PCA
import seaborn as sns
import pathlib
from tqdm import tqdm
from mne.preprocessing import ICA
from scipy.signal import butter, filtfilt
from pathlib import Path

In [None]:
from itertools import product

def all_binary_states(n):
    """
    Enumerate every bipolar configuration of ``n`` spins.

    Rows are ordered lexicographically with -1 before +1 (itertools.product order).
    Output shape: (2**n, n)
    """
    configurations = product([-1, 1], repeat=n)
    return np.array([cfg for cfg in configurations])


In [None]:
# Load the binarized EEG data (task / rest) and the fitted task Ising models.
# NOTE: unpickling can execute arbitrary code — only load these files from a trusted source.
binary_task = pickle.loads(Path("binarized/binary_task.pkl").read_bytes())
binary_rest = pickle.loads(Path("binarized/binary_rest.pkl").read_bytes())
ising_task = pickle.loads(Path("ising_models/ising_task.pkl").read_bytes())

In [None]:
def extract_all_patterns(binary_data, bands=('theta', 'alpha', 'beta', 'gamma', 'broadband'), max_patterns=100, rng=None):
    """
    Sample binary EEG state patterns for every subject and band and map them to {-1, +1}.

    Parameters
    ----------
    binary_data : dict
        ``binary_data[subject][band]`` is a 2-D array. It is transposed here so that
        rows become timepoints (input assumed to be (channels, timepoints)).
    bands : iterable of str
        Bands to extract. Missing bands are reported and skipped.
    max_patterns : int
        At most this many timepoints (rows) are drawn without replacement.
    rng : numpy.random.Generator, optional
        Random source for the subsampling. Defaults to the legacy global
        ``np.random`` state, so ``np.random.seed`` still controls it.

    Returns
    -------
    dict
        ``patterns_dict[subject][band]`` -> array (n_patterns, n_channels) with values in {-1, +1}.
    """
    choose = np.random.choice if rng is None else rng.choice
    patterns_dict = {}

    for subj_id in binary_data:
        patterns_dict[subj_id] = {}
        for band in bands:
            try:
                X = np.asarray(binary_data[subj_id][band]).T  # (timepoints, channels)

                # Sample max_patterns timepoints
                if X.shape[0] > max_patterns:
                    idx = choose(X.shape[0], size=max_patterns, replace=False)
                    patterns = X[idx]
                else:
                    patterns = X

                # Bipolar conversion that works for both {0, 1} and {-1, +1} inputs:
                # positive -> +1, everything else -> -1. (Mapping only `== 0` to -1
                # would turn an existing -1 into +1.)
                patterns = np.where(patterns > 0, 1, -1)

                patterns_dict[subj_id][band] = patterns

            except Exception as e:
                # Best-effort: a missing/malformed band should not abort the whole dataset.
                print(f"⚠️ Skipping subject {subj_id}, band {band}: {e}")

    return patterns_dict

# Up to 100 bipolar patterns per subject and band, for each condition
# (rest is sampled first, then task, consuming the global RNG in that order).
patterns_rest, patterns_task = (
    extract_all_patterns(binary_rest, bands=['theta', 'alpha', 'beta', 'gamma', 'broadband'], max_patterns=100),
    extract_all_patterns(binary_task, bands=['theta', 'alpha', 'beta', 'gamma', 'broadband'], max_patterns=100),
)


In [None]:
# Quick shape check for one subject/band in each condition, e.g. (100, 19).
for condition_patterns in (patterns_rest, patterns_task):
    print(condition_patterns['22']['beta'].shape)

In [None]:

class HopfieldNetwork:
    """
    Binary Hopfield network with Hebbian couplings and asynchronous sign updates.

    States and stored patterns are vectors in {-1, +1}^n_units. After training,
    ``J`` is the pattern-averaged outer product with a zero diagonal.
    """

    def __init__(self, n_units):
        self.n_units = n_units
        self.J = np.zeros((n_units, n_units))  # connectivity matrix

    def train_hebbian(self, patterns):
        """
        Train the network using Hebbian learning: J = X^T X / P, diag(J) = 0.

        patterns: (n_patterns, n_units) binary patterns in {-1, +1}

        Raises
        ------
        ValueError
            If ``patterns`` has no rows (the 1/P normalisation is undefined).
        """
        assert patterns.shape[1] == self.n_units
        assert np.all(np.isin(patterns, [-1, 1])), "Patterns must be in {-1, +1}"

        P = patterns.shape[0]
        if P == 0:
            raise ValueError("Cannot train a Hopfield network on zero patterns")
        self.J = np.dot(patterns.T, patterns) / P
        np.fill_diagonal(self.J, 0)  # Remove self-connections

    def energy(self, state):
        """
        Compute energy of a given state: E = -1/2 * s.T @ J @ s
        """
        assert state.shape[0] == self.n_units
        return -0.5 * np.dot(state, np.dot(self.J, state))

    def update_async(self, state, n_steps=100, tol=1e-5):
        """
        Asynchronously update units (random order per sweep) until a sweep
        changes nothing or ``n_steps`` sweeps have run.

        Returns:
            - final state (1D array of shape [n_units])
            - list of energy values: the initial energy followed by one entry per sweep
        """
        state = state.copy()
        energy_trace = [self.energy(state)]

        for step in range(n_steps):
            old_state = state.copy()

            # Random order update
            for i in np.random.permutation(self.n_units):
                h_i = np.dot(self.J[i], state)
                state[i] = 1 if h_i >= 0 else -1

            energy_trace.append(self.energy(state))

            # Convergence check: state doesn't change
            if np.allclose(state, old_state, atol=tol):
                break

        return state, energy_trace

    def run_multiple(self, initial_states, n_steps=100, tol=1e-5):
        """
        Run update_async from multiple initial states.

        Returns:
            - list of final stable states
            - list of sweep counts: number of sweeps performed, including the final
              sweep that confirmed no change (equals ``n_steps`` if never converged)
            - list of energy traces
        """
        stable_states = []
        convergence_steps = []
        energy_trajectories = []

        for state in initial_states:
            final, energy_path = self.update_async(state, n_steps=n_steps, tol=tol)
            stable_states.append(final)
            # The trace starts with the initial energy, so it holds one more entry
            # than the number of sweeps actually performed.
            convergence_steps.append(len(energy_path) - 1)
            energy_trajectories.append(energy_path)

        return stable_states, convergence_steps, energy_trajectories


In [None]:
def hopfield_features(patterns_dict, n_init=20):
    """
    Fit one Hebbian Hopfield network per (subject, band) and summarise its recall dynamics.

    Up to ``n_init`` of the training patterns (drawn without replacement) are used
    as starting states. Bands with fewer than two patterns are skipped; bands whose
    fit raises are reported and skipped.

    Returns:
        features_dict[subject][band] = {
            'n_attractors', 'avg_steps', 'avg_energy'
        }
    """
    features_dict = {}

    for subj_id, band_patterns in patterns_dict.items():
        subject_features = {}
        features_dict[subj_id] = subject_features

        for band, patterns in band_patterns.items():
            if patterns.shape[0] < 2:
                continue

            try:
                net = HopfieldNetwork(n_units=patterns.shape[1])
                net.train_hebbian(patterns)

                n_starts = min(n_init, len(patterns))
                start_rows = np.random.choice(patterns.shape[0], size=n_starts, replace=False)
                finals, steps, _ = net.run_multiple(patterns[start_rows])

                finals = np.array(finals)
                distinct_finals = np.unique(finals, axis=0)

                subject_features[band] = {
                    'n_attractors': len(distinct_finals),
                    'avg_steps': np.mean(steps),
                    'avg_energy': np.mean([net.energy(s) for s in finals]),
                }

            except Exception as e:
                print(f"❌ {subj_id} - {band} failed: {e}")

    return features_dict


In [None]:
def extract_all_patterns(binary_data, bands=('theta', 'alpha', 'beta', 'gamma'), max_patterns=100, rng=None):
    """
    Sample binary EEG state patterns for every subject and band and map them to {-1, +1}.

    Parameters
    ----------
    binary_data : dict
        ``binary_data[subject][band]`` is a 2-D array. It is transposed here so that
        rows become timepoints (input assumed to be (channels, timepoints)).
    bands : iterable of str
        Bands to extract (no 'broadband' by default). Missing bands are reported and skipped.
    max_patterns : int
        At most this many timepoints (rows) are drawn without replacement.
    rng : numpy.random.Generator, optional
        Random source for the subsampling. Defaults to the legacy global
        ``np.random`` state, so ``np.random.seed`` still controls it.

    Returns
    -------
    dict
        ``patterns_dict[subject][band]`` -> array (n_patterns, n_channels) with values in {-1, +1}.
    """
    choose = np.random.choice if rng is None else rng.choice
    patterns_dict = {}

    for subj_id in binary_data:
        patterns_dict[subj_id] = {}
        for band in bands:
            try:
                X = np.asarray(binary_data[subj_id][band]).T  # (timepoints, channels)

                # Sample max_patterns timepoints
                if X.shape[0] > max_patterns:
                    idx = choose(X.shape[0], size=max_patterns, replace=False)
                    patterns = X[idx]
                else:
                    patterns = X

                # Bipolar conversion that works for both {0, 1} and {-1, +1} inputs:
                # positive -> +1, everything else -> -1. (Mapping only `== 0` to -1
                # would turn an existing -1 into +1.)
                patterns = np.where(patterns > 0, 1, -1)

                patterns_dict[subj_id][band] = patterns

            except Exception as e:
                # Best-effort: a missing/malformed band should not abort the whole dataset.
                print(f"⚠️ Skipping subject {subj_id}, band {band}: {e}")

    return patterns_dict


# should be deleted

def sample_patterns_from_ising(ising_dict, bands=('theta', 'alpha', 'beta', 'gamma'), n_samples=100,
                               burn_in=None, thin=1, rng=None):
    """
    Sample {-1, +1} patterns from fitted Ising parameters using single-site Gibbs sampling.

    The local field of spin i is ``h[i] + sum_{j != i} J[i, j] * s[j]``. Each spin is
    redrawn from its exact conditional ``P(s_i = +1 | rest) = 1 / (1 + exp(-2 * field))``.

    Parameters
    ----------
    ising_dict : dict
        ``ising_dict[subject][band]`` has keys 'h' (length n) and 'J' (n x n).
    bands : iterable of str
    n_samples : int
        Number of recorded samples per subject/band.
    burn_in : int, optional
        Sweeps discarded before recording. Defaults to ``9 * n_samples``, which keeps
        the total work at ``10 * n_samples`` sweeps, as before.
    thin : int
        Record one sample every ``thin`` sweeps after burn-in (>= 1).
    rng : numpy.random.Generator, optional
        Random source. Defaults to the legacy global ``np.random`` state.

    Returns:
        patterns_dict[subject][band] = (n_samples, n_channels)
    """
    if thin < 1:
        raise ValueError("thin must be >= 1")
    if rng is None:
        rng = np.random  # legacy global state (supports .choice and .random)

    def gibbs_sample(h, J, n_samples):
        h = np.asarray(h, dtype=float)
        J = np.asarray(J, dtype=float)
        n = len(h)
        n_burn = 9 * n_samples if burn_in is None else burn_in
        state = rng.choice([-1, 1], size=n)
        samples = []

        for sweep in range(n_burn + n_samples * thin):
            for i in range(n):
                field = h[i] + np.dot(J[i], state) - J[i, i] * state[i]
                # (1 + tanh(f)) / 2 == 1 / (1 + exp(-2 f)), but without overflow for large |f|.
                p_up = 0.5 * (1.0 + np.tanh(field))
                state[i] = 1 if rng.random() < p_up else -1
            # Record only after burn-in, every `thin`-th sweep.
            if sweep >= n_burn and (sweep - n_burn) % thin == 0:
                samples.append(state.copy())

        return np.array(samples)

    patterns_dict = {}

    for subj_id in ising_dict:
        patterns_dict[subj_id] = {}
        for band in bands:
            try:
                h = ising_dict[subj_id][band]['h']
                J = ising_dict[subj_id][band]['J']
                patterns = gibbs_sample(h, J, n_samples=n_samples)
                patterns_dict[subj_id][band] = patterns
            except Exception as e:
                print(f"⚠️ Skipping subject {subj_id}, band {band}: {e}")

    return patterns_dict

# Synthetic patterns drawn from the *task* Ising fits. They are stored under their own
# name so task-derived samples are not mislabelled as resting-state data (only the task
# Ising model is loaded in this notebook).
patterns_task_ising = sample_patterns_from_ising(ising_task)

In [None]:
# Re-extract empirical patterns for both conditions with the function's default bands
# (task first, then rest — same RNG consumption order as before).
patterns_task, patterns_rest = (
    extract_all_patterns(binary_task),
    extract_all_patterns(binary_rest),
)

In [None]:
# Sanity check: after bipolar conversion only -1 and +1 may remain, in both conditions.
# (Previously only the last expression was displayed, so the rest check was never shown.)
assert set(np.unique(patterns_rest['12']['alpha'])) <= {-1, 1}, "rest patterns are not bipolar"
assert set(np.unique(patterns_task['12']['alpha'])) <= {-1, 1}, "task patterns are not bipolar"
np.unique(patterns_rest['12']['alpha']), np.unique(patterns_task['12']['alpha'])


In [None]:
# First-pass Hopfield features per condition (task evaluated first, then rest).
hopfield_feats_task, hopfield_feats_rest = (
    hopfield_features(patterns_task),
    hopfield_features(patterns_rest),
)

In [None]:
from collections import defaultdict

def groupwise_hopfield_features(hopfield_feats, good_ids, bad_ids, bands=('theta', 'alpha', 'beta', 'gamma', 'broadband')):
    """
    Organize Hopfield features for group comparison.

    Parameters
    ----------
    hopfield_feats : dict
        ``hopfield_feats[subject_id_str][band][metric] -> value``.
    good_ids, bad_ids : iterable
        Subject IDs; converted with ``str`` to match the feature-dict keys.
    bands : iterable of str

    Returns:
        features[group][band][metric] = list of values across subjects

    Only real-valued scalars are kept. This includes numpy scalars such as
    ``np.int64`` / ``np.float32``, which are registered as ``numbers.Real``.
    Arrays and other objects are ignored, and NaNs are dropped so they cannot
    poison the downstream t-tests.
    """
    import numbers

    features = {
        'good': defaultdict(lambda: defaultdict(list)),
        'bad': defaultdict(lambda: defaultdict(list))
    }

    for group, ids in (('good', good_ids), ('bad', bad_ids)):
        for subj_id in ids:
            subj_str = str(subj_id)  # subject IDs are strings in hopfield_feats
            if subj_str not in hopfield_feats:
                continue
            for band in bands:
                if band not in hopfield_feats[subj_str]:
                    continue

                for metric, value in hopfield_feats[subj_str][band].items():
                    if not isinstance(value, numbers.Real):
                        continue
                    if np.isnan(value):
                        continue
                    features[group][band][metric].append(value)

    return features


In [None]:
# Subject groups. IDs are strings to match the keys of the feature dictionaries.
bad_counters = [str(i) for i in (0, 4, 6, 9, 10, 14, 19, 21, 22, 30)]
good_counters = [str(i) for i in (1, 2, 3, 5, 7, 8, 11, 12, 13, 15, 16, 17,
                                  18, 20, 23, 24, 25, 26, 27, 28, 29, 31, 32,
                                  33, 34, 35)]

In [None]:
# Group the first-pass resting-state features into good / bad performers.
features_by_group = groupwise_hopfield_features(hopfield_feats_rest, good_counters, bad_counters)


In [None]:
# Per-subject attractor counts of the good group in the alpha band.
good_alpha = features_by_group['good']['alpha']
good_alpha['n_attractors']


In [None]:
from scipy.stats import ttest_ind

def compare_group_feature(features_by_group, band='alpha', metric='n_attractors'):
    """
    Welch's t-test of one metric in one band between the 'good' and 'bad' groups.

    Prints a short report and returns the statistics as a dict, so results can be
    collected into a table instead of copied by hand. Returns None (after printing
    a notice) when either group has fewer than two values.
    """
    good = features_by_group['good'][band][metric]
    bad = features_by_group['bad'][band][metric]

    if len(good) < 2 or len(bad) < 2:
        print("Not enough data to compare.")
        return None

    t, p = ttest_ind(good, bad, equal_var=False)
    good_mean, bad_mean = np.mean(good), np.mean(bad)
    print(f"📊 {metric.upper()} | Band: {band.upper()}")
    print(f"   ➤ Good mean: {good_mean:.2f}")
    print(f"   ➤ Bad  mean: {bad_mean:.2f}")
    print(f"   ➤ t = {t:.2f}, p = {p:.4f}")

    return {
        'band': band, 'metric': metric,
        'good_mean': float(good_mean), 'bad_mean': float(bad_mean),
        't': float(t), 'p': float(p),
        'n_good': len(good), 'n_bad': len(bad),
    }


In [None]:
# A few hand-picked (band, metric) comparisons.
for band, metric in [('alpha', 'n_attractors'), ('beta', 'avg_steps'), ('theta', 'avg_energy')]:
    compare_group_feature(features_by_group, band=band, metric=metric)


In [None]:
hopfield_metrics = ['n_attractors', 'avg_steps', 'avg_energy']

# Welch t-test for every (band, metric) pair; a failing comparison is reported, not raised.
for band, metric in product(['theta', 'alpha', 'beta', 'gamma', 'broadband'], hopfield_metrics):
    try:
        compare_group_feature(features_by_group, band, metric)
    except Exception as e:
        print(f"❌ Failed comparison for {metric} in {band}: {e}")


### More Hopfield features: basin sizes, basin entropy and attractor energy gaps

In [None]:
from collections import defaultdict
from scipy.stats import entropy
from itertools import combinations

def hopfield_features(patterns_all_subjects, n_init=100):
    """
    Attractor-landscape features of a Hebbian Hopfield network, per subject and band.

    Input:
        patterns_all_subjects: {subject_id: {band: binary array [time, channels]}}
        A {0, 1} encoding is mapped to {-1, +1} (any non-zero value becomes +1).
    Returns:
        defaultdict(dict) of feature dicts per subject per band; bands whose
        computation raises are reported and left out.
    """
    features = defaultdict(dict)

    for subj_id, band_dict in patterns_all_subjects.items():
        for band, patterns in band_dict.items():
            try:
                if np.any(patterns == 0):
                    patterns = np.where(patterns == 0, -1, 1)

                n_patterns, n_units = patterns.shape
                net = HopfieldNetwork(n_units)
                net.train_hebbian(patterns)

                # Relax from n_init uniformly random bipolar starting states.
                starts = np.random.choice([-1, 1], size=(n_init, n_units))
                finals, steps_list, _ = net.run_multiple(starts)

                # Basin size = number of starts that ended in each distinct final state.
                basins = {}
                for final in finals:
                    basins[tuple(final)] = basins.get(tuple(final), 0) + 1
                energies = [net.energy(final) for final in finals]

                sizes = list(basins.values())
                n_total = sum(sizes)

                gaps = [
                    abs(net.energy(np.array(a)) - net.energy(np.array(b)))
                    for a, b in combinations(basins.keys(), 2)
                ]

                features[subj_id][band] = {
                    'n_attractors': len(basins),
                    'avg_steps': np.mean(steps_list) if steps_list else np.nan,
                    'avg_energy': np.mean(energies) if energies else np.nan,
                    'mean_basin_size': np.mean(sizes) if sizes else np.nan,
                    'basin_entropy': entropy(np.array(sizes) / n_total) if n_total > 0 else np.nan,
                    'min_energy_gap': np.min(gaps) if gaps else np.nan,
                    'avg_energy_gap': np.mean(gaps) if gaps else np.nan,
                    'load_ratio': n_patterns / n_units
                }

            except Exception as e:
                print(f"❌ {subj_id} - {band} failed: {e}")

    return features


In [None]:
from collections import defaultdict
from scipy.stats import entropy
from itertools import combinations
from tqdm import tqdm  # ✅ Import tqdm

def hopfield_features(patterns_all_subjects, n_init=100):
    """
    Computes Hopfield features across all subjects and bands (with progress bars).

    Input:
        patterns_all_subjects: {subject_id: {band: binary array [time, channels]}}
        A {0, 1} encoding is mapped to {-1, +1} (any non-zero value becomes +1).
    Returns:
        defaultdict(dict): features[subject][band] = {
            'n_attractors', 'avg_steps', 'avg_energy', 'mean_basin_size',
            'basin_entropy', 'min_energy_gap', 'avg_energy_gap', 'load_ratio'
        }
        Bands whose computation raises are reported and left out.
    """
    features = defaultdict(dict)

    for subj_id, band_dict in tqdm(patterns_all_subjects.items(), desc="🧠 Subjects"):
        for band, patterns in tqdm(band_dict.items(), desc=f"  🎧 Bands ({subj_id})", leave=False):
            try:
                # Convert {0,1} → {-1,+1} if needed
                if np.any(patterns == 0):
                    patterns = np.where(patterns == 0, -1, 1)

                n_patterns, n_units = patterns.shape
                hopfield = HopfieldNetwork(n_units)
                hopfield.train_hebbian(patterns)

                # Run dynamics from n_init random states
                initial_states = np.random.choice([-1, 1], size=(n_init, n_units))
                stable_states, steps_list, _ = hopfield.run_multiple(initial_states)

                # Basin size per distinct attractor (tuples are hashable; insertion order kept).
                basin_counts = {}
                for s in stable_states:
                    key = tuple(s)
                    basin_counts[key] = basin_counts.get(key, 0) + 1

                # Evaluate each attractor's energy exactly once and reuse it for both the
                # per-start energies and the pairwise gaps (previously recomputed 2x per pair).
                attractor_energies = [hopfield.energy(np.array(key)) for key in basin_counts]
                energy_by_key = dict(zip(basin_counts, attractor_energies))
                energies = [energy_by_key[tuple(s)] for s in stable_states]

                basin_sizes = list(basin_counts.values())
                total_basins = sum(basin_sizes)

                energy_gaps = [abs(e1 - e2) for e1, e2 in combinations(attractor_energies, 2)]

                features[subj_id][band] = {
                    'n_attractors': len(basin_counts),
                    'avg_steps': np.mean(steps_list) if steps_list else np.nan,
                    'avg_energy': np.mean(energies) if energies else np.nan,
                    'mean_basin_size': np.mean(basin_sizes) if basin_sizes else np.nan,
                    'basin_entropy': entropy(np.array(basin_sizes) / total_basins) if total_basins > 0 else np.nan,
                    'min_energy_gap': np.min(energy_gaps) if energy_gaps else np.nan,
                    'avg_energy_gap': np.mean(energy_gaps) if energy_gaps else np.nan,
                    'load_ratio': n_patterns / n_units
                }

            except Exception as e:
                print(f"❌ {subj_id} - {band} failed: {e}")

    return features


In [None]:
# Extended features. Each condition is sampled and then modelled, rest before task,
# so the global RNG is consumed in the same order as before.
patterns_rest = extract_all_patterns(binary_rest)
hopfield_feats_rest = hopfield_features(patterns_rest)

patterns_task = extract_all_patterns(binary_task)
hopfield_feats_task = hopfield_features(patterns_task)


In [None]:
from collections import defaultdict

def groupwise_hopfield_features(hopfield_feats, good_ids, bad_ids, bands=['theta', 'alpha', 'beta', 'gamma', 'broadband']):
    """
    Collect per-subject Hopfield metrics into good/bad group lists, with diagnostics.

    Returns features[group][band][metric] -> list of floats. Missing subjects or
    bands, None/NaN values and non-numeric types are reported and skipped.
    """
    numeric_types = (float, int, np.float64, np.int64)
    features = {
        'good': defaultdict(lambda: defaultdict(list)),
        'bad': defaultdict(lambda: defaultdict(list))
    }

    for group, ids in [('good', good_ids), ('bad', bad_ids)]:
        for sid in map(str, ids):
            subject_feats = hopfield_feats.get(sid) if sid in hopfield_feats else None
            if subject_feats is None and sid not in hopfield_feats:
                print(f"⚠️ Missing subject: {sid}")
                continue
            subject_feats = hopfield_feats[sid]

            for band in bands:
                if band not in subject_feats:
                    print(f"⚠️ Missing band: {band} for subject {sid}")
                    continue

                for metric, value in subject_feats[band].items():
                    if value is None:
                        print(f"⚠️ Skipping {metric} for subject {sid} (None)")
                    elif not isinstance(value, numeric_types):
                        print(f"⚠️ Invalid type for {metric} in subject {sid}: {type(value)}")
                    elif np.isnan(value):
                        print(f"⚠️ Skipping {metric} for subject {sid} (NaN)")
                    else:
                        features[group][band][metric].append(float(value))

    return features


In [None]:
# Group the extended features of both conditions into good / bad performers.
features_by_group_rest, features_by_group_task = (
    groupwise_hopfield_features(hopfield_feats_rest, good_counters, bad_counters),
    groupwise_hopfield_features(hopfield_feats_task, good_counters, bad_counters),
)

In [None]:
with open("binarized/binary_task.pkl", "rb") as f:
    hopfield_binaries_task = pickle.load(f)
with open("binarized/binary_rest.pkl", "rb") as f:
    hopfield_binaries_rest = pickle.load(f)

# The stored arrays are assumed to be (channels, timepoints): extract_all_patterns
# transposes them to (timepoints, channels). The Hopfield model needs rows = patterns
# (timepoints) and columns = units (channels), so transpose every band here too.
hopfield_inputs_task = {
    subj_id: {band: np.asarray(arr).T for band, arr in band_dict.items()}
    for subj_id, band_dict in hopfield_binaries_task.items()
}
hopfield_inputs_rest = {
    subj_id: {band: np.asarray(arr).T for band, arr in band_dict.items()}
    for subj_id, band_dict in hopfield_binaries_rest.items()
}

# Per-subject features on the full (unsubsampled) time series; rest first, as before.
hopfield_feats_rest_full = hopfield_features(hopfield_inputs_rest, n_init=100)
hopfield_feats_task_full = hopfield_features(hopfield_inputs_task, n_init=100)

# compare_group_feature indexes ['good'] / ['bad'], so the per-subject dicts must be
# grouped before being stored under the features_by_group_* names.
features_by_group_rest = groupwise_hopfield_features(hopfield_feats_rest_full, good_counters, bad_counters)
features_by_group_task = groupwise_hopfield_features(hopfield_feats_task_full, good_counters, bad_counters)

In [None]:
all_hopfield_metrics = [
    'n_attractors', 'avg_steps', 'avg_energy',
    'mean_basin_size', 'basin_entropy',
    'min_energy_gap', 'avg_energy_gap', 'load_ratio'
]

# Every (band, metric) comparison, first for task, then for rest.
for grouped in (features_by_group_task, features_by_group_rest):
    for band, metric in product(['theta', 'alpha', 'beta', 'gamma'], all_hopfield_metrics):
        compare_group_feature(grouped, band, metric)


In [None]:
import numpy as np
import pickle
from collections import defaultdict
from scipy.stats import entropy
from itertools import combinations
from tqdm import tqdm
import random

# === Hopfield network implementation ===
class HopfieldNetwork:
    """
    Binary Hopfield network with Hebbian weights and in-order (sequential) sign updates.

    Attributes
    ----------
    n_units : int
    weights : ndarray (n_units, n_units)
        Un-normalised Hebbian sum ``X.T @ X`` with a zero diagonal. Keeping it
        integer-valued makes the sign updates exact (no rounding at h == 0).
    n_patterns : int
        Number of patterns in the last training call. ``energy`` divides by it, so
        energies are comparable across recordings with different sample counts.
    """

    def __init__(self, n_units):
        self.n_units = n_units
        self.weights = np.zeros((n_units, n_units))
        self.n_patterns = 1

    def train_hebbian(self, patterns):
        """Store ``patterns`` (n_patterns, n_units), values in {-1, +1}, with the Hebbian rule."""
        patterns = np.asarray(patterns)
        self.weights = np.dot(patterns.T, patterns)
        np.fill_diagonal(self.weights, 0)
        self.n_patterns = max(patterns.shape[0], 1)

    def energy(self, state):
        """Energy per stored pattern: E(s) = -1/2 * s^T W s / n_patterns."""
        return -0.5 * np.dot(state, np.dot(self.weights, state)) / self.n_patterns

    def _sweep(self, state):
        """One in-order pass over all units; updates ``state`` in place, returns True if any unit flipped."""
        changed = False
        for i in range(self.n_units):
            new_value = 1 if np.dot(self.weights[i], state) >= 0 else -1
            if new_value != state[i]:
                state[i] = new_value
                changed = True
        return changed

    def run(self, state, max_iter=100):
        """
        Sequentially update a copy of ``state`` for up to ``max_iter`` sweeps.

        Stops early once a sweep changes nothing. Such a state is a fixed point, and
        further deterministic sweeps would leave it unchanged.
        """
        current = np.array(state, copy=True)
        for _ in range(max_iter):
            if not self._sweep(current):
                break
        return current

    def run_multiple(self, initial_states, max_iter=100):
        """
        Relax each initial state to a fixed point.

        Returns
        -------
        stable_states : list of arrays
            Final state per start.
        steps_list : list of int
            Sweeps performed per start, including the final sweep that confirmed no
            change (equals ``max_iter`` if no fixed point was reached).
        energy_traces : list of list of float
            Initial energy followed by the energy after every sweep.
        """
        stable_states = []
        steps_list = []
        energy_traces = []

        for init in initial_states:
            state = np.array(init, copy=True)
            trace = [self.energy(state)]
            sweeps = 0

            for _ in range(max_iter):
                changed = self._sweep(state)
                sweeps += 1
                trace.append(self.energy(state))
                if not changed:
                    break

            stable_states.append(state)
            steps_list.append(sweeps)
            energy_traces.append(trace)

        return stable_states, steps_list, energy_traces

# === Feature extraction from Hopfield dynamics ===
def hopfield_features(patterns_all_subjects, n_init=100, max_pairs=5000):
    """
    Hopfield attractor-landscape features per subject and band.

    Parameters
    ----------
    patterns_all_subjects : dict
        {subject_id: {band: array (n_patterns, n_units)}}. Rows are patterns; a
        {0, 1} encoding is mapped to {-1, +1} (non-zero -> +1).
    n_init : int
        Number of uniformly random bipolar starting states.
    max_pairs : int
        If there are more attractor pairs than this, a random subset (``random.sample``)
        is used for the energy-gap statistics.

    Returns
    -------
    defaultdict(dict)
        features[subject][band] = {'n_attractors', 'avg_steps', 'avg_energy',
        'mean_basin_size', 'basin_entropy', 'min_energy_gap', 'avg_energy_gap',
        'load_ratio'}. Bands whose computation raises are reported and left out.
    """
    features = defaultdict(dict)

    for subj_id, band_dict in tqdm(patterns_all_subjects.items(), desc="🧠 Subjects"):
        for band, patterns in tqdm(band_dict.items(), desc=f"  🎧 Bands ({subj_id})", leave=False):
            try:
                if np.any(patterns == 0):
                    patterns = np.where(patterns == 0, -1, 1)

                n_patterns, n_units = patterns.shape
                hopfield = HopfieldNetwork(n_units)
                hopfield.train_hebbian(patterns)

                initial_states = np.random.choice([-1, 1], size=(n_init, n_units))
                stable_states, steps_list, _ = hopfield.run_multiple(initial_states)

                # Basin size per distinct attractor (insertion order = first visit).
                basin_counts = {}
                for s in stable_states:
                    key = tuple(s)
                    basin_counts[key] = basin_counts.get(key, 0) + 1

                # Each attractor's energy is evaluated once and reused below.
                attractor_energies = [hopfield.energy(np.array(key)) for key in basin_counts]
                energy_by_key = dict(zip(basin_counts, attractor_energies))
                energies = [energy_by_key[tuple(s)] for s in stable_states]

                basin_sizes = list(basin_counts.values())
                total_basins = sum(basin_sizes)

                # Sample energy gaps over index pairs. random.sample selects by position,
                # so this picks the same pairs as sampling the attractor tuples directly.
                pair_indices = list(combinations(range(len(attractor_energies)), 2))
                if len(pair_indices) > max_pairs:
                    pair_indices = random.sample(pair_indices, max_pairs)

                energy_gaps = [
                    abs(attractor_energies[a] - attractor_energies[b])
                    for a, b in pair_indices
                ]

                features[subj_id][band] = {
                    'n_attractors': len(basin_counts),
                    'avg_steps': np.mean(steps_list) if steps_list else np.nan,
                    'avg_energy': np.mean(energies) if energies else np.nan,
                    'mean_basin_size': np.mean(basin_sizes) if basin_sizes else np.nan,
                    'basin_entropy': entropy(np.array(basin_sizes) / total_basins) if total_basins > 0 else np.nan,
                    'min_energy_gap': np.min(energy_gaps) if energy_gaps else np.nan,
                    'avg_energy_gap': np.mean(energy_gaps) if energy_gaps else np.nan,
                    'load_ratio': n_patterns / n_units
                }

            except Exception as e:
                print(f"❌ {subj_id} - {band} failed: {e}")

    return features

# === Groupwise summarization ===
def groupwise_hopfield_features(hopfield_feats, good_ids, bad_ids, bands=('theta', 'alpha', 'beta', 'gamma', 'broadband')):
    """
    Collect per-subject Hopfield metrics into good/bad group lists.

    Returns features[group][band][metric] -> list of floats. Missing subjects or
    bands are reported. Entries that are None, NaN (any float width), strings, or
    not convertible to a scalar float are skipped instead of aborting the grouping.
    """
    features = {
        'good': defaultdict(lambda: defaultdict(list)),
        'bad': defaultdict(lambda: defaultdict(list))
    }

    for group, ids in (('good', good_ids), ('bad', bad_ids)):
        for subj_id in ids:
            sid = str(subj_id)
            if sid not in hopfield_feats:
                print(f"⚠️ Missing subject: {sid}")
                continue

            for band in bands:
                if band not in hopfield_feats[sid]:
                    print(f"⚠️ Missing band: {band} for subject {sid}")
                    continue

                for metric, value in hopfield_feats[sid][band].items():
                    if value is None or isinstance(value, (str, bytes)):
                        continue
                    try:
                        numeric = float(value)
                    except (TypeError, ValueError):
                        continue  # arrays / objects cannot be grouped as scalars
                    if np.isnan(numeric):
                        continue
                    features[group][band][metric].append(numeric)

    return features

# === Feature comparison ===
def compare_group_feature(grouped_feats, band, metric):
    """
    Welch's t-test of ``metric`` in ``band`` between the 'good' and 'bad' groups.

    Prints a short report and returns a dict with the group means, t statistic,
    p-value and group sizes. Returns None (after a notice) when either group has
    fewer than two values, because the test is undefined there (NaN) with n = 1.
    """
    from scipy.stats import ttest_ind

    good = grouped_feats['good'][band][metric]
    bad = grouped_feats['bad'][band][metric]
    if len(good) < 2 or len(bad) < 2:
        print(f"⚠️ Insufficient data for {band} - {metric}")
        return None

    tstat, pval = ttest_ind(good, bad, equal_var=False)
    good_mean, bad_mean = np.mean(good), np.mean(bad)
    print(f"📊 {metric.upper()} ({band}):")
    print(f"   Good mean: {good_mean:.3f} | Bad mean: {bad_mean:.3f} | p = {pval:.4f}\n")

    return {
        'band': band, 'metric': metric,
        'good_mean': float(good_mean), 'bad_mean': float(bad_mean),
        't': float(tstat), 'p': float(pval),
        'n_good': len(good), 'n_bad': len(bad),
    }

# === Main execution ===
if __name__ == "__main__":
    # Load binarized data
    with open("binarized/binary_rest.pkl", "rb") as f:
        binary_rest = pickle.load(f)
    with open("binarized/binary_task.pkl", "rb") as f:
        binary_task = pickle.load(f)

    # The stored arrays are assumed to be (channels, timepoints): extract_all_patterns
    # transposes them to (timepoints, channels). Without the transpose the network would
    # treat timepoints as units and channels as patterns, so transpose every band here.
    patterns_rest = {
        subj_id: {band: np.asarray(arr).T for band, arr in band_dict.items()}
        for subj_id, band_dict in binary_rest.items()
    }
    patterns_task = {
        subj_id: {band: np.asarray(arr).T for band, arr in band_dict.items()}
        for subj_id, band_dict in binary_task.items()
    }

    # Extract features
    hopfield_feats_rest = hopfield_features(patterns_rest, n_init=100)
    hopfield_feats_task = hopfield_features(patterns_task, n_init=100)

    # Group labels: this relies on `good_counters` / `bad_counters` defined in an earlier
    # notebook cell; define them here when running this cell as a standalone script.
    #good_counters = [...]  # Fill with list of good subject IDs
    #bad_counters = [...]   # Fill with list of bad subject IDs

    # Groupwise aggregation
    features_by_group_rest = groupwise_hopfield_features(hopfield_feats_rest, good_counters, bad_counters)
    features_by_group_task = groupwise_hopfield_features(hopfield_feats_task, good_counters, bad_counters)

    # Metrics to evaluate
    all_hopfield_metrics = [
        'n_attractors', 'avg_steps', 'avg_energy',
        'mean_basin_size', 'basin_entropy',
        'min_energy_gap', 'avg_energy_gap', 'load_ratio'
    ]

    print("\n==== TASK COMPARISON ====")
    for band in ['theta', 'alpha', 'beta', 'gamma']:
        for metric in all_hopfield_metrics:
            compare_group_feature(features_by_group_task, band, metric)

    print("\n==== REST COMPARISON ====")
    for band in ['theta', 'alpha', 'beta', 'gamma']:
        for metric in all_hopfield_metrics:
            compare_group_feature(features_by_group_rest, band, metric)


In [None]:

# Save the Hopfield features for ML (one pickle per condition).
# Create the output directory first so the dump does not fail on a fresh checkout.
Path("all_features").mkdir(parents=True, exist_ok=True)

for condition, condition_feats in (("task", hopfield_feats_task), ("rest", hopfield_feats_rest)):
    out_path = f"all_features/hopf_features_{condition}.pkl"
    with open(out_path, "wb") as f:
        pickle.dump(condition_feats, f)
    print(f"✅ Saved Hopfield features to {out_path}")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# === Manually entered from your table ===
# Each entry: [good-group mean, bad-group mean, p-value], transcribed by hand.
data = {
    "Alpha - Attractor Count": [2.77, 2.20, 0.1197],
    "Beta - Attractor Count": [3.38, 2.70, 0.1772],
    "Beta - Basin Entropy": [1.02, 0.84, 0.1148],
    "Gamma - Average Steps to Convergence": [4.24, 4.04, 0.0845],
    "Gamma - Mean Basin Size": [34.21, 25.42, 0.1222],
}

labels = list(data)
good_means, bad_means, pvals = (list(column) for column in zip(*data.values()))

x = np.arange(len(labels))
width = 0.45

fig, ax = plt.subplots(figsize=(12, 6))
bars1 = ax.bar(x - width / 2, good_means, width, label='Good', color='green')
bars2 = ax.bar(x + width / 2, bad_means, width, label='Bad', color='red')

# p-value label a fixed offset above the taller bar of each pair
for i, (g, b, p) in enumerate(zip(good_means, bad_means, pvals)):
    ax.text(i, max(g, b) + 0.6, f"p = {p:.3f}", ha='center', fontsize=9, fontweight='bold')

ax.set_ylabel('Feature Value')
#ax.set_title('Hopfield Network Features (Good vs Bad)')
ax.set_xticks(x)
ax.set_xticklabels(labels, rotation=45, ha='right')
ax.legend()
plt.grid(axis='y', linestyle='--', alpha=0.4)
plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

from matplotlib.transforms import ScaledTranslation

# === Data with clearer formatting ===
# Each entry: [good-group mean, bad-group mean, p-value], transcribed by hand.
data = {
    r"$\alpha$ - Attractor Count": [2.77, 2.20, 0.1197],
    r"$\beta$ - Attractor Count": [3.38, 2.70, 0.1772],
    r"$\beta$ - Basin Entropy": [1.02, 0.84, 0.1148],
    r"$\gamma$ - Avg Steps to Convergence": [4.24, 4.04, 0.0845],
    r"$\gamma$ - Mean Basin Size": [34.21, 25.42, 0.1222],
}

labels = list(data.keys())
good_means = [v[0] for v in data.values()]
bad_means = [v[1] for v in data.values()]
pvals = [v[2] for v in data.values()]

x = np.arange(len(labels))
width = 0.4

fig, ax = plt.subplots(figsize=(12, 6))
bars1 = ax.bar(x - width/2, good_means, width, label='Good', color='green')
bars2 = ax.bar(x + width/2, bad_means, width, label='Bad', color='crimson')

# Annotate p-values above the taller bars
for i, (g, b, p) in enumerate(zip(good_means, bad_means, pvals)):
    y = max(g, b) + 1
    ax.text(i, y, f"$p$ = {p:.3f}", ha='center', fontsize=9, fontweight='bold')

# Format axis
ax.set_ylabel('Feature Value', fontsize=12)
ax.set_xticks(x)
ax.set_xticklabels(labels, rotation=20, ha='right', fontsize=10)

# Shift x-tick labels slightly to the right. Calling set_x on a tick label has no
# effect, because matplotlib re-positions tick labels at draw time. An added offset
# transform (in points) survives drawing.
label_offset = ScaledTranslation(6 / 72, 0, fig.dpi_scale_trans)  # 6 pt to the right
for tick_label in ax.get_xticklabels():
    tick_label.set_transform(tick_label.get_transform() + label_offset)

# Optional: adjust spacing between ticks and axis
ax.tick_params(axis='x', pad=5)

ax.legend(fontsize=10, loc='best')
ax.set_title("Comparison of Good vs Bad Performers Across EEG-Derived Features")
ax.grid(axis='y', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.show()


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

def plot_group_feature_distribution(features_by_group, metric, save_fig=False):
    """
    Plot boxplots of a given Hopfield metric across bands and groups.

    Parameters
    ----------
    features_by_group : dict
        ``features_by_group[group][band]`` maps metric name -> list of per-subject
        values, for groups 'good'/'bad' and bands theta/alpha/beta/gamma.
    metric : str
    save_fig : bool
        If True, also write ``plots/<metric>.png`` at 300 dpi. The directory is
        created when missing.

    Returns None. If no values exist for ``metric``, a warning is printed and
    nothing is drawn (an empty frame has no 'Band'/'Value' columns to plot).
    """
    records = [
        {'Group': group, 'Band': band, 'Value': v}
        for group in ['good', 'bad']
        for band in ['theta', 'alpha', 'beta', 'gamma']
        for v in features_by_group[group][band].get(metric, [])
    ]
    df = pd.DataFrame(records)

    if df.empty:
        print(f"⚠️ No data found for metric: {metric}")
        return None

    fig, ax = plt.subplots(figsize=(10, 5))
    sns.boxplot(x='Band', y='Value', hue='Group', data=df, palette='Set2', ax=ax)
    ax.set_title(f"Group Comparison for {metric.replace('_', ' ').title()}")
    ax.grid(True)
    fig.tight_layout()

    if save_fig:
        Path("plots").mkdir(parents=True, exist_ok=True)
        fig.savefig(f"plots/{metric}.png", dpi=300)
    plt.show()
    return None


In [None]:
# `features_by_group` is the first-pass grouping and only holds n_attractors / avg_steps /
# avg_energy, so the other metrics produced empty frames. Plot the extended per-condition
# groupings instead, which carry every metric in all_hopfield_metrics.
for condition, grouped in (("rest", features_by_group_rest), ("task", features_by_group_task)):
    print(f"==== {condition.upper()} ====")
    for metric in all_hopfield_metrics:
        plot_group_feature_distribution(grouped, metric)


In [None]:
import os

def plot_group_feature_distribution(features_by_group, metric, save_fig=False, out_dir="plots",
                                    bands=('theta', 'alpha', 'beta', 'gamma')):
    """
    Plot boxplots comparing a given Hopfield feature across groups and EEG bands.

    Parameters
    ----------
    features_by_group : dict
        ``features_by_group[group][band][metric]`` -> list of values, for
        ``group`` in ``('good', 'bad')``.
    metric : str
        Name of the metric to plot.
    save_fig : bool
        If True, also save the figure as ``<out_dir>/<metric>.png`` at 300 dpi.
    out_dir : str
        Directory for saved figures; created if it does not exist.
    bands : sequence of str
        Bands shown on the x-axis, in order.

    Returns None. When no values are found, a warning is printed and nothing
    is drawn.
    """
    records = [
        {'Group': group, 'Band': band, 'Value': v}
        for group in ['good', 'bad']
        for band in bands
        for v in features_by_group[group][band].get(metric, [])
    ]
    df = pd.DataFrame(records)

    if df.empty:
        print(f"⚠️ No data found for metric: {metric}")
        return

    plt.figure(figsize=(10, 5))
    sns.boxplot(x='Band', y='Value', hue='Group', data=df, palette='Set2')
    # Plain-text title: matplotlib's default font (DejaVu Sans) has no emoji
    # glyphs, so an emoji here renders as an empty box plus a missing-glyph warning.
    plt.title(f"{metric.replace('_', ' ').title()} across EEG Bands")
    plt.grid(True)
    plt.tight_layout()

    if save_fig:
        os.makedirs(out_dir, exist_ok=True)
        out_path = os.path.join(out_dir, f"{metric}.png")
        plt.savefig(out_path, dpi=300)
        print(f"✅ Saved: {out_path}")
    plt.show()


In [None]:
# Hopfield-network summary metrics compared between the two groups.
all_hopfield_metrics = [
    'n_attractors',
    'avg_steps',
    'avg_energy',
    'mean_basin_size',
    'basin_entropy',
    'min_energy_gap',
    'avg_energy_gap',
    'load_ratio',
]

# One boxplot figure per metric.
for metric in all_hopfield_metrics:
    plot_group_feature_distribution(features_by_group, metric)


## Classification using Hopfield Features

In [None]:
import numpy as np

def prepare_classification_data(features_by_group, bands=('theta', 'alpha', 'beta', 'gamma'), metrics=None):
    """
    Create a long-format table of Hopfield features for classification.

    Each row is one subject x band, with columns ``label`` (0 = good, 1 = bad),
    ``band`` and one column per metric. Values missing for a subject are NaN.

    Parameters
    ----------
    features_by_group : dict
        ``features_by_group[group][band][metric]`` -> per-subject list of values,
        for ``group`` in ``('good', 'bad')``. Subject ``i`` is read from
        position ``i`` of every list (assumed aligned across bands/metrics —
        TODO confirm against the code that builds this dict).
    bands : sequence of str
        Bands to include, in row order.
    metrics : sequence of str, optional
        Metric columns to extract; defaults to the notebook-level
        ``all_hopfield_metrics``.

    Returns
    -------
    pandas.DataFrame
    """
    if metrics is None:
        metrics = all_hopfield_metrics

    rows = []
    for group_label, group in enumerate(['good', 'bad']):  # good → 0, bad → 1
        group_feats = features_by_group[group]
        # Subject count = longest per-subject list in this group. (Reading it from
        # theta/n_attractors only raised KeyError when 'theta' was not present and
        # ignored subjects that appear only in longer lists.)
        n_subjects = max(
            (len(group_feats.get(band, {}).get(metric, [])) for band in bands for metric in metrics),
            default=0,
        )
        for subj_idx in range(n_subjects):
            for band in bands:
                try:
                    band_feats = group_feats[band]
                    row = {'label': group_label, 'band': band}
                    for metric in metrics:
                        values = band_feats.get(metric, [])
                        row[metric] = values[subj_idx] if subj_idx < len(values) else np.nan
                    rows.append(row)
                except Exception as e:
                    print(f"⚠️ Error processing subject {subj_idx} in {group}-{band}: {e}")
    return pd.DataFrame(rows)

# Long-format table (one row per subject x band), keeping only complete rows.
df = prepare_classification_data(features_by_group).dropna()


In [None]:
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import StandardScaler

# Encode band as one-hot
df_encoded = pd.get_dummies(df, columns=['band'])

# get_dummies produces bool columns; cast so X is a float matrix rather than an
# object array mixing floats and bools.
X = df_encoded.drop('label', axis=1).values.astype(float)
y = df_encoded['label'].values

# Split BEFORE scaling so test rows do not influence the scaler's mean/std, and
# stratify so both splits keep the (possibly imbalanced) class proportions.
# NOTE(review): each row is one subject x band, so different bands of the same
# subject can land in both train and test. A subject-grouped split
# (e.g. GroupShuffleSplit) would need a subject-id column, which `df` lacks.
X_train_raw, X_test_raw, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=42
)

# Standardize features (fit on training rows only)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train_raw)
X_test = scaler.transform(X_test_raw)
X_scaled = scaler.transform(X)  # full matrix with the train-fitted scaler, for later cells

# Train classifier
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

# Evaluate
y_pred = clf.predict(X_test)
print("📊 Classification Report:")
print(classification_report(y_test, y_pred))
print("📉 Confusion Matrix:")
print(confusion_matrix(y_test, y_pred))


In [None]:
import matplotlib.pyplot as plt

# Random-forest importance of each encoded feature column, sorted so the most
# important features appear at the top of the chart.
feat_names = df_encoded.drop('label', axis=1).columns
importances = clf.feature_importances_

order = np.argsort(importances)  # ascending; barh draws the first bar at the bottom
fig, ax = plt.subplots(figsize=(12, 6))
ax.barh(feat_names[order], importances[order])
# Plain-text title: matplotlib's default font has no emoji glyphs.
ax.set_title("Feature Importance")
ax.set_xlabel("Importance Score")
fig.tight_layout()
plt.show()


In [None]:
from imblearn.over_sampling import SMOTE

# Oversample the minority class in the training split only; the test split is
# left untouched.
sm = SMOTE(random_state=42)
X_res, y_res = sm.fit_resample(X_train, y_train)

# Refit the forest on the resampled training data (fit returns the estimator).
clf = RandomForestClassifier(n_estimators=100, random_state=42).fit(X_res, y_res)

y_pred = clf.predict(X_test)
for report in (classification_report(y_test, y_pred), confusion_matrix(y_test, y_pred)):
    print(report)


In [None]:
# Refit a forest on the (non-resampled) training split, using
# class_weight='balanced' to reweight classes instead of SMOTE oversampling.
# The trailing fit() call is the cell's displayed output (the estimator repr).
clf = RandomForestClassifier(n_estimators=100, class_weight='balanced', random_state=42)
clf.fit(X_train, y_train)


In [None]:
# Test-split predictions of the class-weighted forest fitted in the previous cell.
y_pred = clf.predict(X_test)

In [None]:
from sklearn.metrics import classification_report, confusion_matrix

# Text classification report followed by the raw confusion matrix.
report_text = classification_report(y_test, y_pred)
print(f"📊 Classification Report:\n{report_text}")

print(f"\n📉 Confusion Matrix:\n{confusion_matrix(y_test, y_pred)}")

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay

# Labels follow the encoding used above: 0 = good, 1 = bad.
cm_display = ConfusionMatrixDisplay.from_predictions(y_test, y_pred, display_labels=['Good', 'Bad'])
cm_display.ax_.set_title("Confusion Matrix")
plt.show()


In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import pandas as pd
import matplotlib.pyplot as plt

# Re-split the current feature matrix `X` and labels `y` (0 = good, 1 = bad)
# with a stratified 75/25 split; this replaces the train/test variables of the
# earlier cells.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=42
)

# Class-weighted random forest
clf = RandomForestClassifier(n_estimators=100, class_weight='balanced', random_state=42)
clf.fit(X_train, y_train)

# Evaluation on the held-out rows
y_pred = clf.predict(X_test)
print("\n📊 Classification Report:")
print(classification_report(y_test, y_pred))
print("\n📉 Confusion Matrix:")
print(confusion_matrix(y_test, y_pred))

# Importances, labelled by column name when X is a DataFrame, else by position
if isinstance(X, pd.DataFrame):
    feature_names = X.columns
else:
    feature_names = [f'feature_{i}' for i in range(X.shape[1])]
importances = clf.feature_importances_
feat_imp = pd.Series(importances, index=feature_names).sort_values(ascending=False)

top10 = feat_imp.head(10)
print("\n🔎 Top 10 Important Features:")
print(top10)

ax = top10.plot(kind='barh', title='Top 10 Feature Importances', figsize=(8, 5))
ax.set_xlabel("Importance")
ax.invert_yaxis()  # largest importance at the top
plt.tight_layout()
plt.show()


In [None]:
def flatten_hopfield_features(hopfield_feats, subject_ids=None, good_ids=None, bands=None, features=None):
    """
    Flatten nested per-subject Hopfield features into one row per subject.

    Parameters
    ----------
    hopfield_feats : dict
        ``hopfield_feats[subject_id][band][feature]`` -> scalar value.
    subject_ids : iterable, optional
        Subjects to include (default: every key of ``hopfield_feats``). IDs not
        present in ``hopfield_feats`` are skipped.
    good_ids : container, optional
        Subjects labelled good (1); every other subject is bad (0). Defaults to
        the notebook-level ``good_counters``.
    bands, features : sequence of str, optional
        Bands and feature names to flatten (defaults: theta/alpha/beta/gamma and
        the eight Hopfield metrics). Columns are named ``f"{band}_{feature}"``.

    Returns
    -------
    X : pandas.DataFrame
        Indexed by ``subject_id``, one column per band/feature pair.
    y : pandas.Series
        Labels (1 = good, 0 = bad). NOTE: this is the opposite encoding of
        ``prepare_classification_data`` / ``build_X_y_from_features`` (0 = good).

    Subjects with any missing band or feature are dropped.
    """
    if good_ids is None:
        good_ids = good_counters
    if bands is None:
        bands = ['theta', 'alpha', 'beta', 'gamma']
    if features is None:
        features = [
            'n_attractors', 'avg_steps', 'avg_energy',
            'mean_basin_size', 'basin_entropy',
            'min_energy_gap', 'avg_energy_gap', 'load_ratio'
        ]
    all_subjects = list(hopfield_feats.keys()) if subject_ids is None else subject_ids

    all_rows = []
    for sid in all_subjects:
        if sid not in hopfield_feats:
            continue
        row = {}
        for band in bands:
            if band not in hopfield_feats[sid]:
                continue  # its columns become NaN, so the subject is dropped below
            for feat in features:
                row[f"{band}_{feat}"] = hopfield_feats[sid][band].get(feat, np.nan)
        row["subject_id"] = sid
        row["label"] = 1 if sid in good_ids else 0  # good = 1, bad = 0
        all_rows.append(row)

    if not all_rows:
        # pd.DataFrame([]) has no "subject_id" column, so set_index would raise.
        empty_index = pd.Index([], name="subject_id")
        return pd.DataFrame(index=empty_index), pd.Series([], index=empty_index, name="label", dtype=int)

    df = pd.DataFrame(all_rows).set_index("subject_id")
    df = df.dropna(axis=0)  # Drop subjects with missing features
    return df.drop(columns="label"), df["label"]


In [None]:
def build_X_y_from_features(features_by_group, metrics, bands):
    """
    Convert nested per-group features into a subject-level feature matrix.

    Parameters
    ----------
    features_by_group : dict
        ``features_by_group[group][band][metric]`` -> per-subject list of values,
        for ``group`` in ``('good', 'bad')``; subject ``i`` is read from
        position ``i`` of each list.
    metrics, bands : sequence of str
        Columns are ordered band-major: ``f"{band}_{metric}"`` for each band,
        then each metric within it.

    Returns
    -------
    X : numpy.ndarray of float, shape (n_subjects, len(bands) * len(metrics))
        Missing values (short list, absent band or absent metric) are NaN.
    y : numpy.ndarray of int
        0 = good, 1 = bad.
    feature_names : list of str
    """
    feature_names = [f"{band}_{metric}" for band in bands for metric in metrics]

    def _value(group_feats, band, metric, subj_idx):
        # NaN for anything missing instead of raising, so one incomplete
        # band/metric does not abort the whole matrix.
        values = group_feats.get(band, {}).get(metric, [])
        return values[subj_idx] if subj_idx < len(values) else np.nan

    X, y = [], []
    for group_label, group in enumerate(['good', 'bad']):
        group_feats = features_by_group[group]
        # Subject count = longest list in the group. (Using only bands[0]/metrics[0]
        # raised KeyError when that metric was absent and ignored longer lists.)
        n_subjects = max(
            (len(group_feats.get(band, {}).get(metric, [])) for band in bands for metric in metrics),
            default=0,
        )
        for subj_idx in range(n_subjects):
            X.append([_value(group_feats, band, metric, subj_idx) for band in bands for metric in metrics])
            y.append(group_label)

    # Explicit reshape keeps X 2-D, (0, n_features), when there are no subjects.
    X_arr = np.array(X, dtype=float).reshape(len(X), len(feature_names))
    return X_arr, np.array(y, dtype=int), feature_names


In [None]:
# Bands and Hopfield metrics that make up the subject-level feature matrix
band_list = ['theta', 'alpha', 'beta', 'gamma']
metric_list = [
    'n_attractors',
    'avg_steps',
    'avg_energy',
    'mean_basin_size',
    'basin_entropy',
    'min_energy_gap',
    'avg_energy_gap',
    'load_ratio',
]

# One row per subject; columns are named "<band>_<metric>"
X, y, feature_names = build_X_y_from_features(features_by_group, metric_list, band_list)
X_df = pd.DataFrame(X, columns=feature_names)


In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import pandas as pd

# build_X_y_from_features pads missing values with NaN; keep only complete
# subjects, since RandomForestClassifier rejects NaN inputs in scikit-learn
# versions before 1.4.
complete = X_df.notna().all(axis=1).to_numpy()
if not complete.all():
    print(f"⚠️ Dropping {int((~complete).sum())} subject(s) with missing features")
X_model, y_model = X_df[complete], np.asarray(y)[complete]

# Train-test split (stratified on the good/bad label)
X_train, X_test, y_train, y_test = train_test_split(
    X_model, y_model, stratify=y_model, test_size=0.25, random_state=42
)

# Train model
clf = RandomForestClassifier(n_estimators=100, class_weight='balanced', random_state=42)
clf.fit(X_train, y_train)

# Feature importance
importances = clf.feature_importances_
feat_imp = pd.Series(importances, index=X_df.columns).sort_values(ascending=False)

# Plot
feat_imp.head(10).plot(kind='barh', title='Top 10 Feature Importances', figsize=(8, 5))
plt.xlabel("Importance")
plt.gca().invert_yaxis()
plt.tight_layout()
plt.show()


In [None]:
from sklearn.decomposition import PCA
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter

def plot_attractors_pca(hopfield_feats, patterns_all, subject_id='12', band='alpha', n_init=100):
    """
    Run Hopfield dynamics from random starts and plot the attractors in PCA space.

    Parameters
    ----------
    hopfield_feats : dict
        Unused; kept for call compatibility.
    patterns_all : dict
        ``patterns_all[subject_id][band]`` -> 2-D (n_patterns, n_units) array
        used to train the network.
    subject_id, band : str
        Which pattern set to use.
    n_init : int
        Number of random +/-1 initial states.

    Each distinct final state is drawn as one point. Marker size and colour show
    how many of the ``n_init`` runs ended there (its basin size in this sample).
    """
    patterns = patterns_all[subject_id][band]
    n_units = patterns.shape[1]

    # Train the Hopfield model
    model = HopfieldNetwork(n_units)
    model.train_hebbian(patterns)

    # Run dynamics from random initial states
    init_states = np.random.choice([-1, 1], size=(n_init, n_units))
    stable_states, _, _ = model.run_multiple(init_states)

    # Count how many runs ended in each distinct final state
    counts = Counter(tuple(s) for s in np.asarray(stable_states))
    unique_states = np.array(list(counts.keys()))
    sizes = np.array(list(counts.values()))

    # PCA(n_components=2) needs at least 2 samples and 2 features. With a single
    # attractor (or a 1-unit network), place the points at the origin instead of raising.
    if unique_states.ndim == 2 and min(unique_states.shape) >= 2:
        proj = PCA(n_components=2).fit_transform(unique_states)
    else:
        proj = np.zeros((len(unique_states), 2))

    fig, ax = plt.subplots(figsize=(6, 5))
    scatter = ax.scatter(proj[:, 0], proj[:, 1], s=sizes * 10, c=sizes, cmap='viridis', edgecolor='k')
    fig.colorbar(scatter, ax=ax, label="Basin Size")
    ax.set_title(f"PCA of Attractors - Subject {subject_id} | {band.upper()}")
    ax.set_xlabel("PC1")
    ax.set_ylabel("PC2")
    ax.grid(True)
    fig.tight_layout()
    plt.show()


In [None]:
from collections import Counter

def plot_basin_histogram(hopfield_feats, patterns_dict, subject_id, band):
    """
    Histogram of basin sizes when each training pattern is used as a start state.

    Parameters
    ----------
    hopfield_feats : dict
        Unused; kept for call compatibility.
    patterns_dict : dict
        ``patterns_dict[subject_id][band]`` -> 2-D (n_patterns, n_units) array.
    subject_id, band : str
        Which pattern set to use.

    The network is trained on the patterns and then run from each of them. A
    basin's size is the number of patterns that ended in the same final state.
    Prints a warning and draws nothing when there are no patterns.
    """
    patterns = patterns_dict[subject_id][band]
    model = HopfieldNetwork(n_units=patterns.shape[1])
    model.train_hebbian(patterns)

    stable_states, _, _ = model.run_multiple(patterns)
    basin_sizes = list(Counter(tuple(s) for s in stable_states).values())
    if not basin_sizes:
        # max() below would raise on an empty sequence
        print(f"⚠️ No patterns for subject {subject_id}, band {band}")
        return

    plt.figure(figsize=(6, 4))
    # Integer-width bins covering sizes 1..max
    plt.hist(basin_sizes, bins=range(1, max(basin_sizes) + 2), edgecolor='black')
    plt.title(f"Basin Sizes | Subject {subject_id}, Band: {band}")
    plt.xlabel("Size of Basin")
    plt.ylabel("Frequency")
    plt.grid(True)
    plt.show()


In [None]:
import matplotlib.pyplot as plt

def plot_energy_trajectories(hopfield_feats, patterns_all, subject_id='12', band='alpha', n_init=10):
    """
    Plot Hopfield energy trajectories for one subject and band.

    Parameters
    ----------
    hopfield_feats : dict
        Unused; kept for call compatibility.
    patterns_all : dict
        ``patterns_all[subject_id][band]`` -> 2-D (n_patterns, n_units) array
        used to train the network.
    subject_id, band : str
        Which pattern set to use.
    n_init : int
        Number of training patterns, taken from the start of the array, used as
        initial states. Defaults to 10, the previously hard-coded value.

    One line is drawn per run: the energy trace returned by
    ``HopfieldNetwork.run_multiple`` for that initial state.
    """
    patterns = patterns_all[subject_id][band]
    n_units = patterns.shape[1]

    # Re-train the Hopfield network on this subject/band
    model = HopfieldNetwork(n_units)
    model.train_hebbian(patterns)

    # Run dynamics from the first n_init patterns (fewer if there are fewer)
    _, _, energy_traces = model.run_multiple(patterns[:n_init])

    fig, ax = plt.subplots(figsize=(8, 5))
    for traj in energy_traces:
        ax.plot(traj, alpha=0.6)
    ax.set_xlabel("Step")
    ax.set_ylabel("Energy")
    ax.set_title(f"Energy Trajectories - Subject {subject_id} | {band.upper()}")
    ax.grid(True)
    fig.tight_layout()
    plt.show()


In [None]:
def compute_local_curvature(hopfield_model, n_units):
    """
    Compute discrete curvature for each binary state in {-1, +1}^n.

    Curvature of a state s is the sum, over its n single-unit-flip (Hamming
    distance 1) neighbours s', of E(s') - E(s).

    Returns
    -------
    state_energy : numpy.ndarray, shape (2**n_units,)
        Energy of each state, in the order of ``all_binary_states(n_units)``.
    curvature : numpy.ndarray, shape (2**n_units,)
        Curvature values in the same order.

    Each energy is evaluated once and neighbours are looked up by index.
    ``all_binary_states`` enumerates ``itertools.product([-1, 1], repeat=n)``,
    so unit j of state i is +1 exactly when bit ``n_units - 1 - j`` of i is set.
    Flipping unit j therefore gives the state at index
    ``i ^ (1 << (n_units - 1 - j))``.

    NOTE(review): a later cell redefines this name with a different return
    (curvature only); the later definition wins after Run All.
    """
    all_states = all_binary_states(n_units)  # shape: (2^n, n)
    state_energy = np.array([hopfield_model.energy(s) for s in all_states])

    # XOR masks that flip unit j, for j = 0..n_units-1
    flip_masks = [1 << (n_units - 1 - j) for j in range(n_units)]
    curvature = [
        sum(state_energy[i ^ mask] - e_s for mask in flip_masks)
        for i, e_s in enumerate(state_energy)
    ]
    return state_energy, np.array(curvature)


In [None]:
# Load patterns for subject '12' and band 'alpha'
# NOTE(review): `patterns` is reassigned to subject '29' / theta two cells below
# before anything reads it, so this selection is currently overwritten.
patterns = patterns_rest['12']['alpha']

In [None]:
# Shape of every (subject, band) pattern array.
for subj, subj_bands in patterns_rest.items():
    for band, band_patterns in subj_bands.items():
        print(subj, band, band_patterns.shape)


In [None]:
# Subject '29', theta band: the pattern set used by the energy-landscape cells below.
patterns = patterns_rest['29']['theta']


In [None]:
# Hebbian Hopfield network fitted to the selected patterns
n_channels = patterns.shape[1]
model = HopfieldNetwork(n_units=n_channels)
model.train_hebbian(patterns)

# Every state in {-1, +1}^n_channels: 2**n_channels rows (524288 for 19 channels)
states = all_binary_states(n_channels)

# Energy of each enumerated state (curvature is not computed here)
state_energy = np.array([model.energy(state) for state in states])

from scipy.spatial import distance_matrix
from scipy.sparse import csgraph

def compute_local_curvature(model, n_units):
    """
    Discrete Laplacian of the Hopfield energy over the hypercube {-1, +1}^n.

    For each state s, in ``all_binary_states(n_units)`` order, returns the sum of
    E(s') over its n single-unit-flip neighbours s', minus n * E(s).

    Neighbours are found by flipping a bit of the state's *index*:
    ``all_binary_states`` enumerates ``itertools.product([-1, 1], repeat=n)``,
    so unit j of state i is +1 exactly when bit ``n_units - 1 - j`` of i is set.
    (XOR-ing the +/-1 state vector itself, as before, does not flip a unit, and
    the resulting arrays are not valid indices of single neighbours.)

    Returns
    -------
    numpy.ndarray, shape (2**n_units,)
        Curvature per state. Unlike the earlier function of the same name, the
        energies are not returned.
    """
    states = all_binary_states(n_units)
    energy = np.array([model.energy(s) for s in states])

    state_idx = np.arange(len(states))
    flip_masks = 1 << (n_units - 1 - np.arange(n_units))  # one mask per unit
    # neighbor_energy[i, j] = energy of state i with unit j flipped
    neighbor_energy = energy[state_idx[:, None] ^ flip_masks[None, :]]
    return neighbor_energy.sum(axis=1) - n_units * energy

# Optional: downsample for faster processing. Never request more states than
# exist: sampling without replacement raises when size exceeds the population
# (e.g. networks with fewer than 12 units, where 2**n < 3000).
n_sample = min(3000, len(states))
subset_idx = np.random.choice(len(states), size=n_sample, replace=False)
states_sample = states[subset_idx]
energy_sample = state_energy[subset_idx]


In [None]:
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

# Project the sampled states to 2-D and colour each point by its Hopfield energy.
pca = PCA(n_components=2)
xy = pca.fit_transform(states_sample)

fig, ax = plt.subplots(figsize=(8, 6))
points = ax.scatter(xy[:, 0], xy[:, 1], c=energy_sample, cmap='viridis', s=10)
fig.colorbar(points, ax=ax, label="Energy")
ax.set_title("Hopfield Energy Landscape (PCA view)")
ax.set_xlabel("PC1")
ax.set_ylabel("PC2")
ax.grid(True)
plt.show()


In [None]:
from sklearn.decomposition import PCA
from tqdm import tqdm

def energy_landscape_features(patterns_dict, max_states=2**19):
    """
    Summarise the exhaustive Hopfield energy landscape for every subject and band.

    For each ``patterns_dict[subj_id][band]`` array of shape (n_patterns, n_units),
    a Hebbian Hopfield network is trained on the patterns. The energy of every
    one of the ``2**n_units`` binary states is then evaluated.

    Parameters
    ----------
    patterns_dict : dict
        ``patterns_dict[subj_id][band]`` -> 2-D pattern array.
    max_states : int
        Upper bound on ``2**n_units``. Larger state spaces are skipped with a
        message, since exhaustive enumeration grows exponentially. This argument
        was previously accepted but ignored.

    Returns
    -------
    dict
        ``features[subj_id][band]`` -> {'mean_energy', 'min_energy', 'max_energy',
        'energy_std', 'energy_range'} as floats. Skipped or failed
        (subject, band) pairs are absent.
    """
    features = {}

    for subj_id in tqdm(patterns_dict):
        features[subj_id] = {}
        for band in patterns_dict[subj_id]:
            patterns = patterns_dict[subj_id][band]
            n_units = patterns.shape[1]

            if 2 ** n_units > max_states:
                print(f"⚠️ {subj_id} {band} skipped: 2**{n_units} states exceeds max_states={max_states}")
                continue

            try:
                # Train Hopfield
                model = HopfieldNetwork(n_units)
                model.train_hebbian(patterns)

                # Energy of every state (the former PCA of all states was never
                # used and has been dropped — it was the costliest step)
                states = all_binary_states(n_units)
                energies = np.array([model.energy(s) for s in states])

                features[subj_id][band] = {
                    'mean_energy': float(np.mean(energies)),
                    'min_energy': float(np.min(energies)),
                    'max_energy': float(np.max(energies)),
                    'energy_std': float(np.std(energies)),
                    'energy_range': float(np.ptp(energies))
                }

            except Exception as e:
                print(f"❌ {subj_id} {band} failed: {e}")
                continue

    return features


In [None]:
# Energy-landscape summary statistics for every subject/band in `patterns_rest`
# (exhaustive over 2**n_units states per band, so this cell can be slow).
landscape_feats_rest = energy_landscape_features(patterns_rest)


In [None]:
from collections import defaultdict
from scipy.stats import ttest_ind

def groupwise_energy_features(feats, good_ids, bad_ids, bands, metrics):
    """
    Pool per-subject landscape features into 'good' and 'bad' groups.

    Returns ``{'good': defaultdict(list), 'bad': defaultdict(list)}``. Keys are
    ``f"{band}_{metric}"`` and each value lists that feature across the group's
    subjects. IDs are compared as strings. Subjects or bands absent from
    ``feats``, and values that are None or NaN, are skipped.
    """
    grouped = {'good': defaultdict(list), 'bad': defaultdict(list)}
    membership = (('good', good_ids), ('bad', bad_ids))

    for group_name, id_list in membership:
        for raw_id in id_list:
            key = str(raw_id)
            if key not in feats:
                continue
            subject_feats = feats[key]
            for band in (b for b in bands if b in subject_feats):
                band_feats = subject_feats[band]
                for metric in metrics:
                    value = band_feats.get(metric, None)
                    if value is None or np.isnan(value):
                        continue
                    grouped[group_name][f"{band}_{metric}"].append(value)
    return grouped

# Performance groups (subject IDs as strings)
good_ids = ('1 2 3 5 7 8 11 12 13 15 16 17 18 20 23 24 25 26 27 28 29 '
            '31 32 33 34 35').split()
bad_ids = '0 4 6 9 10 14 19 21 22 30'.split()

metrics = ['mean_energy', 'min_energy', 'max_energy', 'energy_std', 'energy_range']
bands = ['theta', 'alpha', 'beta', 'gamma']

grouped_energy_feats = groupwise_energy_features(
    landscape_feats_rest, good_ids, bad_ids, bands, metrics
)


In [None]:
def compare_grouped_features(grouped_feats):
    """
    Print a Welch t-test (unequal variances) of good vs bad for every feature key.

    Keys come from ``grouped_feats['good']``. A key is tested only when both
    groups have at least two values; otherwise a skip message is printed.
    """
    good_feats = grouped_feats['good']
    for metric in good_feats:
        good = good_feats[metric]
        bad = grouped_feats['bad'][metric]
        if len(good) <= 1 or len(bad) <= 1:
            print(f"⚠️ Skipping {metric} due to insufficient data.")
            continue
        t, p = ttest_ind(good, bad, equal_var=False)
        summary = (
            f"📊 {metric.upper()}\n"
            f"   ➤ Good mean: {np.mean(good):.2f}\n"
            f"   ➤ Bad  mean: {np.mean(bad):.2f}\n"
            f"   ➤ t = {t:.2f}, p = {p:.4f}\n"
        )
        print(summary)

# Welch t-test of every band_metric energy feature, good vs bad.
compare_grouped_features(grouped_energy_feats)


In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_grouped_feature_comparison(grouped_feats, figsize=(12, 6)):
    """
    Grouped bar chart of per-feature means, with standard-deviation error bars,
    for the 'good' and 'bad' groups.

    Features are the keys of ``grouped_feats['good']``, sorted alphabetically so
    that features of the same band sit next to each other.
    """
    feature_keys = sorted(grouped_feats['good'])
    positions = np.arange(len(feature_keys))
    bar_width = 0.35

    def _stats(group):
        # (means, standard deviations) for one group, in feature_keys order
        values = [grouped_feats[group][k] for k in feature_keys]
        return [np.mean(v) for v in values], [np.std(v) for v in values]

    good_means, good_stds = _stats('good')
    bad_means, bad_stds = _stats('bad')

    fig, ax = plt.subplots(figsize=figsize)
    ax.bar(positions - bar_width / 2, good_means, bar_width, yerr=good_stds, capsize=5,
           label='Good', color='skyblue')
    ax.bar(positions + bar_width / 2, bad_means, bar_width, yerr=bad_stds, capsize=5,
           label='Bad', color='salmon')

    ax.set_ylabel('Feature Value')
    ax.set_title('Energy Landscape Feature Comparison (Good vs Bad)')
    ax.set_xticks(positions)
    ax.set_xticklabels(feature_keys, rotation=45, ha='right')
    ax.legend()
    plt.tight_layout()
    ax.grid(True, axis='y', linestyle='--', alpha=0.6)
    plt.show()


In [None]:
# Grouped bar chart of the energy-landscape features (mean ± SD per group).
plot_grouped_feature_comparison(grouped_energy_feats)


In [None]:
# Persistent-homology summary features produced by extract_tda_features().
all_tda_metrics = [
    'tda_entropy',       # entropy of the finite H1 lifetimes
    'tda_betti0_max',    # number of points in the H0 diagram
    'tda_betti1_max',    # number of points in the H1 diagram
    'tda_pers1_mean',    # mean finite H1 lifetime
    'tda_pers1_max',     # largest finite H1 lifetime
    'tda_n_long_holes',  # finite H1 lifetimes above a threshold
]


In [None]:
from ripser import ripser
from scipy.stats import entropy
from scipy.spatial.distance import squareform, pdist
import numpy as np

# Step 4: Compute a distance matrix (Hamming or energy-weighted)
def compute_distance_matrix(states, energies=None, use_energy_weighted=False):
    """
    Pairwise distance matrix between binary states.

    The base distance is the normalised Hamming distance (fraction of differing
    units, from ``pdist(metric='hamming')``). When ``use_energy_weighted`` is
    true and ``energies`` is given, entry (i, j) is multiplied by
    ``exp(-|E_i - E_j|)``.

    NOTE(review): that factor *shrinks* the distance between states whose
    energies differ a lot, and tends to 0 for large energy gaps. Confirm this is
    the intended weighting.
    """
    distances = squareform(pdist(states, metric='hamming'))
    if not (use_energy_weighted and energies is not None):
        return distances
    energy_gap = np.abs(energies[:, np.newaxis] - energies[np.newaxis, :])
    return distances * np.exp(-energy_gap)

# Step 5: Run persistent homology with ripser
def run_persistent_homology(distance_matrix):
    """
    Persistence diagrams of a precomputed distance matrix, computed with ripser.

    Returns ripser's ``'dgms'`` list: one array of (birth, death) pairs per
    homology dimension. Downstream code reads H0 (index 0) and H1 (index 1).
    """
    result = ripser(distance_matrix, distance_matrix=True)
    return result['dgms']

# Step 6: Extract TDA features
def extract_tda_features(diagrams, long_hole_threshold=0.1):
    """
    Summary features from persistence diagrams.

    Parameters
    ----------
    diagrams : sequence
        ``diagrams[0]`` and ``diagrams[1]`` are the H0 and H1 diagrams. Each is an
        iterable of (birth, death) pairs, where death may be ``inf``.
    long_hole_threshold : float
        Lifetime above which an H1 feature counts as a "long" hole. Defaults to
        0.1, the previously hard-coded value.

    Returns
    -------
    dict, in this key order:
        'tda_entropy'      entropy of the normalised finite H1 lifetimes (0.0 if none)
        'tda_betti0_max'   number of H0 points, including infinite ones
        'tda_betti1_max'   number of H1 points
        'tda_pers1_mean'   mean finite H1 lifetime (0 if none)
        'tda_pers1_max'    largest finite H1 lifetime (0 if none)
        'tda_n_long_holes' number of finite H1 lifetimes > long_hole_threshold
    """
    # Finite H1 lifetimes (death - birth), computed once for all statistics
    lifetimes = [d[1] - d[0] for d in diagrams[1] if d[1] < np.inf]

    features = {}
    # Persistence entropy. Explicitly 0.0 for an empty diagram rather than
    # relying on how scipy.stats.entropy handles an empty input.
    features['tda_entropy'] = entropy(lifetimes) if lifetimes else 0.0
    features['tda_betti0_max'] = len(diagrams[0])  # points in the H0 diagram
    features['tda_betti1_max'] = len(diagrams[1])  # points in the H1 diagram
    features['tda_pers1_mean'] = np.mean(lifetimes) if lifetimes else 0
    features['tda_pers1_max'] = np.max(lifetimes) if lifetimes else 0
    features['tda_n_long_holes'] = sum(1 for life in lifetimes if life > long_hole_threshold)
    return features

# Step 7: Full TDA pipeline
def tda_pipeline(patterns_dict, energies_dict=None, use_energy_weighted=False):
    """
    Persistent-homology features for every subject and band.

    For each ``patterns_dict[subj_id][band]`` array of states, per-state energies
    are taken from ``energies_dict`` when present. Otherwise they are computed
    with a Hebbian HopfieldNetwork trained on those same states and cached in
    ``energies_dict``. The distance matrix of the states (optionally
    energy-weighted) goes through ripser and ``extract_tda_features``.

    NOTE: a caller-supplied ``energies_dict`` is updated in place.

    Returns ``(tda_features, energies_dict)``, where
    ``tda_features[subj_id][band]`` is the feature dict. (subject, band) pairs
    that raise are reported and omitted.
    """
    tda_features = {}
    energies_dict = {} if energies_dict is None else energies_dict

    for subj_id, bands in patterns_dict.items():
        subj_features = {}
        tda_features[subj_id] = subj_features
        subj_energies = energies_dict.get(subj_id, {})
        energies_dict[subj_id] = subj_energies

        for band, states in bands.items():
            try:
                if band in subj_energies:
                    energies = subj_energies[band]
                else:
                    model = HopfieldNetwork(n_units=states.shape[1])
                    model.train_hebbian(states)
                    energies = np.array([model.energy(state) for state in states])
                    subj_energies[band] = energies

                dist = compute_distance_matrix(states, energies, use_energy_weighted)
                subj_features[band] = extract_tda_features(run_persistent_homology(dist))
            except Exception as e:
                print(f"❌ Failed for subject {subj_id}, band {band}: {e}")
    return tda_features, energies_dict

# Example usage
# Assuming `patterns_rest` contains binary states
# Energy-weighted distances; no energies are passed in, so tda_pipeline computes
# them and returns them as `energies_rest`.
tda_features_rest, energies_rest = tda_pipeline(patterns_rest, use_energy_weighted=True)

# Compare features at the end
def compare_tda_features(tda_features, good_ids, bad_ids, by_band=False):
    """
    Pool TDA features into 'good' and 'bad' groups.

    Parameters
    ----------
    tda_features : dict
        ``tda_features[subj_id][band][metric]`` -> value.
    good_ids, bad_ids : iterable
        Subject IDs, compared as strings. IDs absent from ``tda_features`` are
        skipped.
    by_band : bool
        False (default): key by metric name, pooling every band of every subject
        into one list, so each subject contributes one value per band.
        NOTE(review): those values are not independent samples, which a t-test
        on the pooled lists assumes.
        True: key by ``f"{band}_{metric}"`` so that bands stay separate.

    Returns
    -------
    dict
        ``{'good': defaultdict(list), 'bad': defaultdict(list)}``
    """
    grouped = {'good': defaultdict(list), 'bad': defaultdict(list)}
    for group, ids in [('good', good_ids), ('bad', bad_ids)]:
        for subj_id in ids:
            subj_id = str(subj_id)
            if subj_id not in tda_features:
                continue
            for band, features in tda_features[subj_id].items():
                for metric, value in features.items():
                    key = f"{band}_{metric}" if by_band else metric
                    grouped[group][key].append(value)
    return grouped

# Define good and bad groups (same IDs as the energy-landscape comparison)
good_ids = ('1 2 3 5 7 8 11 12 13 15 16 17 18 20 23 24 25 26 27 28 29 '
            '31 32 33 34 35').split()
bad_ids = '0 4 6 9 10 14 19 21 22 30'.split()

# Compare TDA features
grouped_tda_features = compare_tda_features(tda_features_rest, good_ids, bad_ids)

# Welch t-test per TDA metric (needs at least two values in each group)
for metric in all_tda_metrics:
    good_values = grouped_tda_features['good'][metric]
    bad_values = grouped_tda_features['bad'][metric]
    if min(len(good_values), len(bad_values)) < 2:
        print(f"⚠️ Not enough data for {metric}.")
        continue
    t_stat, p_value = ttest_ind(good_values, bad_values, equal_var=False)
    print(
        f"📊 {metric.upper()}\n"
        f"   ➤ Good mean: {np.mean(good_values):.2f}\n"
        f"   ➤ Bad  mean: {np.mean(bad_values):.2f}\n"
        f"   ➤ t = {t_stat:.2f}, p = {p_value:.4f}\n"
    )