In [4]:
import numpy as np
import mne
from mne.time_frequency import tfr_multitaper
from sklearn.svm import SVC
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
from pyriemann.classification import MDM  # Import the correct classifier
from scipy.optimize import minimize

# Step 1: Load EEG Data
def load_eeg_data(file_path, exclude=None):
    """Read an EDF recording fully into memory and keep only EEG-typed channels.

    Parameters
    ----------
    file_path : str or path-like
        Path to the ``.edf`` file, passed to ``mne.io.read_raw_edf``.
    exclude : iterable of str, optional
        Channel names to drop after the EEG pick. Names that are not present
        are ignored. In the recording inspected in this notebook, the channels
        ``HEOL``/``HEOR`` (apparently horizontal-EOG leads) and an unnamed
        ``''`` channel are typed as EEG, so an EEG-type pick alone keeps them.
        Default ``None`` drops nothing.

    Returns
    -------
    mne.io.Raw
        The loaded (preloaded) recording, restricted to EEG channels.
    """
    raw = mne.io.read_raw_edf(file_path, preload=True)
    # pick_types() is a legacy API (MNE prints a note saying so); pick() with a
    # channel-type string replaces it. exclude="bads" reproduces pick_types'
    # default of dropping channels marked bad.
    raw.pick("eeg", exclude="bads")
    if exclude:
        to_drop = [name for name in exclude if name in raw.ch_names]
        if to_drop:
            raw.drop_channels(to_drop)
    return raw

# Step 2: Divide Time Windows and Apply Filter Bank
def apply_filter_bands(raw, time_windows, freq_bands):
    """Band-pass filter ``raw`` for each band and cut out each time window.

    ``raw`` itself is never modified. The previous version called
    ``raw.crop`` / ``raw.filter`` directly, and both modify the object in
    place. That caused two problems:

    * every later window was cropped out of the previous window (so windows
      beyond the first came out empty);
    * every later band was filtered on top of the previous band's output.

    Parameters
    ----------
    raw : mne.io.Raw
        Continuous recording; left unmodified.
    time_windows : sequence of (tmin, tmax)
        Window bounds in seconds, passed to ``crop``.
    freq_bands : sequence of (l_freq, h_freq)
        Band edges in Hz, passed to ``filter``.

    Returns
    -------
    numpy.ndarray
        One ``get_data()`` array per (window, band) pair, stacked along a new
        first axis. The order is window-major: all bands for the first window,
        then all bands for the second window, and so on. Stacking requires
        every window to yield the same number of samples.
    """
    # Filter the full-length signal once per band, then cut windows from it.
    # Filtering each short window on its own would make the filter's edge
    # transients a large fraction of every window.
    filtered_by_band = [
        raw.copy().filter(l_freq=l_freq, h_freq=h_freq)
        for l_freq, h_freq in freq_bands
    ]

    segments = []
    for tmin, tmax in time_windows:
        for filtered in filtered_by_band:
            # copy() so cropping one window does not shrink the band's signal
            # for the next window.
            segments.append(filtered.copy().crop(tmin=tmin, tmax=tmax).get_data())
    return np.array(segments)

# Step 3: Compute Covariance Matrices
def compute_covariance_matrix(epochs):
    """Return one covariance matrix per epoch, stacked into a single array.

    Each element of ``epochs`` is passed unchanged to ``np.cov``. Its rows are
    therefore treated as variables and its columns as observations.
    """
    return np.array([np.cov(single_epoch) for single_epoch in epochs])

# Step 4: Backtracking Search Optimization (for selecting optimal parameters)
def backtracking_search_optimization(time_windows, freq_bands, score_fn=None):
    """Pick the (time window, frequency band) pair with the highest score.

    The previous version ran scipy's Nelder-Mead on the two list indices,
    starting at ``[0, 0]`` and truncating candidates with ``int()``. That makes
    the objective piecewise constant. Nelder-Mead's starting simplex only
    nudges the start point by a tiny amount, so every vertex truncates back to
    index 0, the objective looks flat, and the search never leaves the first
    candidate. The candidate set is small and discrete, so this version scores
    every pair instead.

    Parameters
    ----------
    time_windows : sequence
        Candidate time windows.
    freq_bands : sequence
        Candidate frequency bands.
    score_fn : callable, optional
        ``score_fn(time_window, freq_band) -> float``; higher is better.
        Defaults to this notebook's ``cross_validation``.

    Returns
    -------
    numpy.ndarray
        ``[time_window_index, freq_band_index]`` as floats, the same form as
        ``OptimizeResult.x``, so callers can keep using ``int()`` on each
        entry. Ties go to the lowest indices.

    Raises
    ------
    ValueError
        If either candidate list is empty.
    """
    if score_fn is None:
        score_fn = cross_validation
    if len(time_windows) == 0 or len(freq_bands) == 0:
        raise ValueError("time_windows and freq_bands must be non-empty")

    candidates = (
        (tw_idx, fb_idx)
        for tw_idx in range(len(time_windows))
        for fb_idx in range(len(freq_bands))
    )
    # max() returns the first maximal candidate, i.e. the lowest indices on ties.
    best_tw, best_fb = max(
        candidates,
        key=lambda idx: score_fn(time_windows[idx[0]], freq_bands[idx[1]]),
    )
    return np.array([best_tw, best_fb], dtype=float)

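# Step 5: Dimensionality reduction of the covariance matrices (named `ltsa`, but implemented with PCA rather than true Local Tangent Space Alignment)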
def ltsa(cov_matrices, n_components=2):
    """Reduce channel covariance matrices to ``n_components x n_components`` matrices.

    Despite the name, this is not Local Tangent Space Alignment. It is a
    PCA-style spatial projection.

    The previous version ran ``PCA.fit_transform`` on each matrix separately.
    That returned a non-square ``(n_channels, n_components)`` array per matrix
    (not a covariance matrix, so ``MDM`` cannot consume it), with a different
    basis for every matrix.

    This version shares one projection ``W`` across all inputs: the leading
    eigenvectors of the arithmetic mean of the inputs. It applies ``W`` on both
    sides, ``W.T @ C @ W``, which keeps each output symmetric positive-definite
    whenever ``C`` is.

    Parameters
    ----------
    cov_matrices : array-like, shape (n_matrices, n_channels, n_channels)
        Square matrices to reduce.
    n_components : int, default 2
        Size of each reduced matrix; must be between 1 and ``n_channels``.

    Returns
    -------
    numpy.ndarray, shape (n_matrices, n_components, n_components)

    Raises
    ------
    ValueError
        If the input is not a stack of square matrices, or ``n_components`` is
        out of range.
    """
    covs = np.asarray(cov_matrices, dtype=float)
    if covs.size == 0:
        return np.empty((0, n_components, n_components))
    if covs.ndim != 3 or covs.shape[1] != covs.shape[2]:
        raise ValueError(
            f"expected a stack of square matrices, got shape {covs.shape}"
        )
    n_channels = covs.shape[1]
    if not 1 <= n_components <= n_channels:
        raise ValueError(
            f"n_components must be in [1, {n_channels}], got {n_components}"
        )

    # eigh returns eigenvalues in ascending order: reverse the columns so the
    # leading (largest-variance) directions come first.
    _, eigvecs = np.linalg.eigh(covs.mean(axis=0))
    projection = eigvecs[:, ::-1][:, :n_components]
    # (k, c) @ (n, c, c) @ (c, k) broadcasts to (n, k, k).
    return projection.T @ covs @ projection

# Step 6: Riemannian Minimum Distance (for classification)
def riemannian_minimum_distance(train_data, train_labels, test_data):
    """Fit pyriemann's ``MDM`` (``metric='riemann'``) and label ``test_data``.

    A fresh classifier is built on every call. It is fitted on
    ``train_data``/``train_labels``, and its predictions for ``test_data`` are
    returned.
    """
    # MDM.fit returns the fitted estimator, so fit and predict chain directly.
    return MDM(metric='riemann').fit(train_data, train_labels).predict(test_data)

# Step 7: Cross-validation and Model Evaluation
def cross_validation(time_window, freq_band, data=None, targets=None, n_splits=5):
    """Mean K-fold accuracy of the covariance -> ``ltsa`` -> MDM pipeline.

    Parameters
    ----------
    time_window : tuple
        Candidate (tmin, tmax) window. NOTE(review): currently unused. The
        function scores whatever ``data`` it is given, so every candidate gets
        the same score unless the caller passes data already restricted to
        this window.
    freq_band : tuple
        Candidate (l_freq, h_freq) band; unused for the same reason.
    data : array-like, optional
        Epochs indexed along the first axis, each accepted by
        ``compute_covariance_matrix``. Defaults to a module-level ``epochs``,
        which is what the previous version read implicitly.
    targets : array-like, optional
        One label per epoch. Defaults to a module-level ``labels``.
    n_splits : int, default 5
        Number of ``KFold`` folds. The mean is taken over the folds actually
        produced, not a hard-coded 5.

    Returns
    -------
    float
        Mean per-fold accuracy.

    Raises
    ------
    ValueError
        If ``data`` and ``targets`` differ in length.
    """
    if data is None:
        data = epochs
    if targets is None:
        targets = labels
    # np.asarray so KFold's index arrays work even when plain lists are passed.
    data = np.asarray(data)
    targets = np.asarray(targets)
    if len(data) != len(targets):
        raise ValueError(
            f"got {len(data)} epochs but {len(targets)} labels"
        )

    kf = KFold(n_splits=n_splits)
    fold_accuracies = []
    for train_index, test_index in kf.split(data):
        X_train, X_test = data[train_index], data[test_index]
        y_train, y_test = targets[train_index], targets[test_index]

        # Compute covariance matrices
        cov_train = compute_covariance_matrix(X_train)
        cov_test = compute_covariance_matrix(X_test)

        # NOTE(review): ltsa() is called separately on the train and test
        # matrices, so anything it fits is fitted independently on each.
        reduced_train = ltsa(cov_train)
        reduced_test = ltsa(cov_test)

        # Classification using Riemannian Geometry
        preds = riemannian_minimum_distance(reduced_train, y_train, reduced_test)
        fold_accuracies.append(np.mean(preds == y_test))

    return np.mean(fold_accuracies)

# Main Function: Implement the entire flow
def main(file_path, time_windows, freq_bands, labels=None, test_size=0.2, random_state=None):
    """Run the load -> filter bank -> covariance -> selection -> reduction -> MDM pipeline.

    Parameters
    ----------
    file_path : str or path-like
        EDF file passed to ``load_eeg_data``.
    time_windows, freq_bands : sequences of (start, stop) pairs
        Search space passed to ``apply_filter_bands`` and
        ``backtracking_search_optimization``.
    labels : array-like, optional
        One class label per entry of the array returned by
        ``apply_filter_bands``. When omitted, a notebook-level ``labels``
        variable is used, as the previous version did implicitly.
    test_size : float or int, default 0.2
        Hold-out size forwarded to ``train_test_split``.
    random_state : int, optional
        Seed forwarded to ``train_test_split`` for a reproducible split.

    Returns
    -------
    numpy.ndarray
        Predicted labels for the held-out matrices.

    Raises
    ------
    ValueError
        If no labels are available, or their count does not match the epochs.
    """
    if labels is None:
        # The previous body read a global `labels` that nothing in this file
        # defines. Keep that lookup as a fallback so existing notebook usage
        # still works.
        labels = globals().get("labels")
    if labels is None:
        raise ValueError(
            "main() needs class labels: pass labels=... or define a global `labels`."
        )
    labels = np.asarray(labels)

    # Step 1: Load EEG data
    raw = load_eeg_data(file_path)

    # Step 2: Apply Filter Bank and Time Window Division
    epochs = apply_filter_bands(raw, time_windows, freq_bands)
    if len(labels) != len(epochs):
        raise ValueError(f"got {len(labels)} labels for {len(epochs)} epochs")

    # Step 3: Compute Covariance Matrices
    cov_matrices = compute_covariance_matrix(epochs)

    # Step 4: Search for Optimal Time Window and Frequency Band
    # NOTE(review): the default objective, cross_validation(), reads
    # module-level `epochs`/`labels` rather than the local variables here, and
    # does not use the candidate window/band to select data. The selection
    # printed below is therefore not yet driven by this recording.
    optimal_params = backtracking_search_optimization(time_windows, freq_bands)
    optimal_time_window = time_windows[int(optimal_params[0])]
    optimal_freq_band = freq_bands[int(optimal_params[1])]
    print(f"Optimal Time Window: {optimal_time_window}, Optimal Frequency Band: {optimal_freq_band}")

    # Step 5: Dimensionality Reduction
    # NOTE(review): ltsa() sees all matrices before the split; if it fits
    # anything across matrices, the held-out set influences that fit.
    reduced_matrices = ltsa(cov_matrices)

    # Step 6: Train on one part and predict the held-out part. The previous
    # version predicted its own training set, which only measures memorisation.
    X_train, X_test, y_train, y_test = train_test_split(
        reduced_matrices, labels, test_size=test_size, random_state=random_state
    )
    predictions = riemannian_minimum_distance(X_train, y_train, X_test)

    # Evaluate predictions
    accuracy = np.mean(predictions == y_test)
    print(f"Classification Predictions: {predictions}")
    print(f"Hold-out accuracy: {accuracy:.3f}")
    return predictions

In [5]:
# Search space for the (time window, frequency band) selection.
# Each time window is (tmin, tmax) in seconds, as handed to Raw.crop.
time_windows = [
    (0, 1),
    (0.5, 1.5),
    (1, 2),
    (2, 3),
    (3, 4),
]
# Each frequency band is (l_freq, h_freq) in Hz, as handed to Raw.filter.
# NOTE(review): the last three bands overlap heavily (19-25 Hz); confirm this is intended.
freq_bands = [
    (8, 12),
    (13, 17),
    (19, 23),
    (20, 24),
    (21, 25),
]

# One subject's EDF recording, relative to the notebook's working directory.
file_path = "./subdataset/edffile/sub-45/eeg/sub-45_task-motor-imagery_eeg.edf"

# Full pipeline run (left disabled; uncomment to execute end to end):
# main(file_path, time_windows, freq_bands)

In [6]:
raw = load_eeg_data(file_path)
# Channel names remaining after the EEG pick.
print(raw.ch_names)
# Recording metadata after the pick.
print(raw.info)


Extracting EDF parameters from /root/autodl-tmp/.autodl/kinlaw/mi_swin/subdataset/edffile/sub-45/eeg/sub-45_task-motor-imagery_eeg.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 159999  =      0.000 ...   319.998 secs...
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
['FP1', 'FP2', 'Fz', 'F3', 'F4', 'F7', 'F8', 'FCz', 'FC3', 'FC4', 'FT7', 'FT8', 'Cz', 'C3', 'C4', 'T3', 'T4', 'CPz', 'CP3', 'CP4', 'TP7', 'TP8', 'Pz', 'P3', 'P4', 'T5', 'T6', 'Oz', 'O1', 'O2', 'HEOL', 'HEOR', '']
<Info | 8 non-empty values
 bads: []
 ch_names: FP1, FP2, Fz, F3, F4, F7, F8, FCz, FC3, FC4, FT7, FT8, Cz, C3, ...
 chs: 33 EEG
 custom_ref_applied: False
 highpass: 0.0 Hz
 lowpass: 250.0 Hz
 meas_date: 2022-12-01 19:00:09 UTC
 nchan: 33
 projs: []
 sfreq: 500.0 Hz
 subject_info: 4 items (dict)
>


CPz
  raw = mne.io.read_raw_edf(file_path, preload=True)
