In [5]:
import mne
import numpy as np
import os

data_dir = "data_raw"                    # directory scanned for input .npz files
save_dir = "data_trials_preprocessed"    # output directory for per-subject results
os.makedirs(save_dir, exist_ok=True)

fs = 250            # sampling rate in Hz (used as the MNE sfreq)
tmin, tmax = 0, 4   # trial window in seconds, relative to each event position
# round() instead of a bare int(): for fractional windows the float product can
# land just below the intended integer (e.g. 0.29 * 100 == 28.999999999999996),
# and int() would truncate it, silently dropping one sample per trial.
n_samples = int(round((tmax - tmin) * fs))
# Event codes that are cut out as trials; downstream they are mapped to class
# labels by subtracting 768.
valid_events = [769, 770, 771, 772]

def remove_artifacts(X, y, threshold=100):
    """Drop trials whose peak absolute amplitude exceeds ``threshold``.

    A trial is also dropped when its peak is NaN (i.e. it contains NaN):
    ``NaN > threshold`` evaluates to False, so a plain comparison would keep
    such trials and let the NaN contaminate any statistics computed later.

    Args:
        X: Array whose first axis indexes trials (in this script,
            ``(n_trials, n_channels, n_samples)``).
        y: Labels aligned with the first axis of ``X``.
        threshold: Largest allowed absolute amplitude, in the units of ``X``.
            A peak exactly equal to the threshold is kept.

    Returns:
        ``(X, y)`` with rejected trials removed. If no trial is rejected,
        the inputs are returned unchanged.
    """
    # Per-trial peak |amplitude|, reduced over every axis except the trial axis.
    # np.max propagates NaN, so a trial containing NaN gets a NaN peak.
    peak = np.abs(X).max(axis=tuple(range(1, np.ndim(X))))
    bad = (peak > threshold) | np.isnan(peak)
    if not bad.any():
        return X, y
    keep = ~bad
    return X[keep], np.asarray(y)[keep]

def standardize_data(X):
    """Z-score each channel using statistics pooled over axes 0 and 2.

    For an ``(n_trials, n_channels, n_samples)`` array this pools every
    trial and every time sample of a channel, so each channel ends up with
    approximately zero mean and unit variance across the whole set.
    A small epsilon is added to the standard deviation so that a flat
    (constant) channel maps to zeros instead of dividing by zero.
    """
    pooled_axes = (0, 2)
    center = np.mean(X, axis=pooled_axes, keepdims=True)
    scale = np.std(X, axis=pooled_axes, keepdims=True) + 1e-8
    centered = X - center
    return centered / scale

# For every file in data_dir ending in "T.npz": cut fixed-length trials at the
# selected events, band-pass filter them with MNE, reject high-amplitude trials,
# standardize per channel, and save one compressed .npz per subject.

# The channel layout does not depend on the input file, so build it once.
ch_names = [f"EEG{i:03d}" for i in range(22)]
info = mne.create_info(ch_names=ch_names, sfreq=fs, ch_types=["eeg"]*22)

# Sample offset of the trial window relative to each event position, so that
# changing tmin actually shifts the extracted window (it is 0 for tmin == 0).
start_offset = int(round(tmin * fs))

# os.listdir order is arbitrary; sorting makes the run order (and log) reproducible.
for fname in sorted(os.listdir(data_dir)):
    if not fname.endswith("T.npz"):
        continue
    subj = fname.replace(".npz", "")

    # np.load on an .npz keeps the archive open; the context manager closes it
    # once the needed arrays have been read into memory.
    with np.load(os.path.join(data_dir, fname)) as data:
        s = data["s"].T[:22, :]          # transpose to (channels, samples), keep first 22 channels
        etyp = data["etyp"][:, 0]        # event codes
        epos = data["epos"][:, 0]        # event positions (sample indices)

    trials, labels = [], []
    for label, pos in zip(etyp, epos):
        if label not in valid_events:
            continue
        start = int(pos) + start_offset
        end = start + n_samples
        # Only keep windows that lie entirely inside the recording.
        if start >= 0 and end <= s.shape[1]:
            trials.append(s[:, start:end])
            labels.append(label - 768)   # event codes 769..772 -> labels 1..4
    if len(trials) == 0:
        print(f"⚠️ {subj} has no valid trials")
        continue

    X = np.array(trials)
    y = np.array(labels)
    print(f"{subj}: Extracted {X.shape[0]} trials")

    # MNE band-pass filtering (8–30 Hz)
    epochs = mne.EpochsArray(X, info)
    epochs_filtered = epochs.copy().filter(8, 30, method="fir", phase="zero-double")
    X_filtered = epochs_filtered.get_data()
    print(f"  → Filtered {X_filtered.shape}")

    # Artifact removal
    X_clean, y_clean = remove_artifacts(X_filtered, y, threshold=100)
    print(f"  → Clean trials: {X_clean.shape[0]}/{X_filtered.shape[0]}")

    # Standardization (same as your SL.ipynb)
    X_norm = standardize_data(X_clean)
    print(f"  → Normalized range: [{X_norm.min():.2f}, {X_norm.max():.2f}]")

    # Save processed data
    np.savez_compressed(
        os.path.join(save_dir, f"{subj}.npz"),
        X=X_norm, y=y_clean
    )
    print(f" Saved {subj}.npz ({X_norm.shape})")

print("\nAll subjects processed successfully!")


A01T: Extracted 288 trials
Not setting metadata
288 matching events found
No baseline correction applied
0 projection items activated
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a two-pass forward and reverse, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-12 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-12 dB cutoff frequency: 33.75 Hz)
- Filter length: 413 samples (1.652 s)

  → Filtered (288, 22, 1000)
  → Clean trials: 288/288
  → Normalized range: [-7.32, 7.28]
 Saved A01T.npz ((288, 22, 1000))
A02T: Extracted 288 trials
Not setting metadata
288 matching events found
No baseline correction applied
0 projection items activated
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
--