In [1]:
from pathlib import Path
import json
import re
import numpy as np
import mne
import logging

# Set up logging - must be at the top
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


In [2]:

# 1. Input/Output directories
# ----------------------------------------------------------------------
BIDS_ROOT = Path("/Volumes/cmvm/scs/groups/HELIOS-BD/Part B/helios_bids_vep")        
DERIV_DIR  = BIDS_ROOT / "derivatives" / "mne-bids-pipeline-vep"
DECODE_ROOT = Path("/Users/farjam/OneDrive - University of Edinburgh/wellcome/Amir/Decoding_VEP")
DECODE_DIR = DECODE_ROOT / "Decoding_VEP" / "Luminance_NPZ"
DECODE_DIR.mkdir(parents=True, exist_ok=True)

logger.info(f"BIDS root: {BIDS_ROOT}")
logger.info(f"Derivatives directory: {DERIV_DIR}")
logger.info(f"Decoding root: {DECODE_ROOT}")
logger.info(f"Decoding directory: {DECODE_DIR}")


2025-08-30 02:30:24,579 - INFO - BIDS root: /Volumes/cmvm/scs/groups/HELIOS-BD/Part B/helios_bids_vep
2025-08-30 02:30:24,580 - INFO - Derivatives directory: /Volumes/cmvm/scs/groups/HELIOS-BD/Part B/helios_bids_vep/derivatives/mne-bids-pipeline-vep
2025-08-30 02:30:24,581 - INFO - Decoding root: /Users/farjam/OneDrive - University of Edinburgh/wellcome/Amir/Decoding_VEP
2025-08-30 02:30:24,583 - INFO - Decoding directory: /Users/farjam/OneDrive - University of Edinburgh/wellcome/Amir/Decoding_VEP/Decoding_VEP/Luminance_NPZ


In [4]:
# 2. Automatically list subjects
# ----------------------------------------------------------------------
SUBJECTS = sorted(
    d.name.replace("sub-", "")
    for d in DERIV_DIR.glob("sub-*")
    if d.is_dir() and re.fullmatch(r"sub-\d{4}", d.name)
)

if not SUBJECTS:
    raise RuntimeError(f"No subject folders found in {DERIV_DIR}")

logger.info(f"Found {len(SUBJECTS)} subjects: {SUBJECTS}")


2025-08-30 02:30:58,902 - INFO - Found 97 subjects: ['1001', '1002', '1004', '1005', '1007', '1008', '1009', '1010', '1011', '1014', '1015', '1016', '1017', '1018', '1020', '1021', '1022', '1023', '1024', '1026', '1027', '1028', '1029', '1031', '1033', '1034', '1037', '1038', '1039', '1041', '1042', '1043', '1044', '1046', '1049', '1050', '1052', '1057', '1059', '1061', '1064', '2002', '2006', '2007', '2009', '2012', '2017', '2018', '2019', '2020', '2023', '2025', '2026', '2028', '2029', '2035', '2037', '2041', '2042', '2044', '2049', '3001', '3003', '3004', '3005', '3006', '3007', '3008', '3011', '3012', '3014', '3016', '3017', '3026', '3027', '3030', '3031', '3032', '3034', '3035', '3038', '3039', '3041', '3043', '3046', '3048', '3050', '3052', '3054', '3056', '3057', '3058', '3059', '3060', '3062', '3063', '3065']


In [13]:
# 3. Electrode-set menu 
# ----------------------------------------------------------------------
ELECTRODE_SETS = {
    "1": {
        "name": "64", 
        "channels": list(range(1, 65)),
        "description": "All 64 electrodes"
    },
    "2": {
        "name": "PO", 
        "channels": [29, 27, 64, 25, 26, 30, 63, 62, 28, 24,
                     23, 22, 21, 20, 31, 57, 58, 59, 60, 61],
        "description": "Posterior electrodes only"
    },
    "3": {
        "name": "32", 
        "channels": [1, 34, 3, 36, 7, 5, 38, 40, 42, 9, 11, 46,
                     44, 15, 13, 48, 50, 52, 17, 19, 56, 54, 23, 21, 31, 58, 60, 26, 63, 27, 29, 64],
        "description": "32-electrode grid"
    },
}

print("Electrode options:")
print("  1- All 64 electrodes")
print("  2- Posterior only")
print("  3- 32-electrode grid")

while True:
    sel = input("Choose electrode set [1/2/3]: ").strip()
    if sel in ELECTRODE_SETS:
        ELEC_CFG = ELECTRODE_SETS[sel]
        break
    print(f"Invalid selection: {sel}. Please choose 1, 2, or 3.")

logger.info(f"Selected electrode set: {ELEC_CFG['name']} - {ELEC_CFG['description']}")
logger.info(f"Channels: {len(ELEC_CFG['channels'])}")

Electrode options:
  1- All 64 electrodes
  2- Posterior only
  3- 32-electrode grid


2025-08-20 18:04:27,851 - INFO - Selected electrode set: 64 - All 64 electrodes
2025-08-20 18:04:27,851 - INFO - Channels: 64


In [14]:




# 4. Stimulus-dimension menu 
# ----------------------------------------------------------------------
ANALYSES = {
    "1": {  # luminance
        "name": "luminance",
        "event_types": [1, 2, 3, 4],
        "labels": ["lum1", "lum2", "lum3", "lum4"],
        "plot_colors": ["#d9d9d9", "#999999", "#666666", "#262626"],
        "description": "Luminance analysis"
    },
    "2": {  # L–M
        "name": "L-M",
        "event_types": [5, 6, 7, 8],
        "labels": ["LM1", "LM2", "LM3", "LM4"],
        "plot_colors": ["#d98c8c", "#b23333", "#800000", "#260000"],
        "description": "L-M (red-green) analysis"
    },
    "3": {  # S-cone
        "name": "S-cone",
        "event_types": [9, 10, 11, 12],
        "labels": ["S1", "S2", "S3", "S4"],
        "plot_colors": ["#9cbfff", "#6666cc", "#333399", "#000026"],
        "description": "S-cone (blue-yellow) analysis"
    },
}

print("\nDecode which stimulus dimension?")
print("  1- Luminance")
print("  2- L minus M")
print("  3- S-cone")

while True:
    analysis_choice = input("Enter your choice [1/2/3]: ").strip()
    if analysis_choice in ANALYSES:
        ANALYSIS_CFG = ANALYSES[analysis_choice]
        break
    print(f"Invalid selection: {analysis_choice}. Please choose 1, 2, or 3.")

logger.info(f"Selected analysis: {ANALYSIS_CFG['name']} - {ANALYSIS_CFG['description']}")
logger.info(f"Event types: {ANALYSIS_CFG['event_types']}")
logger.info(f"Labels: {ANALYSIS_CFG['labels']}")



Decode which stimulus dimension?
  1- Luminance
  2- L minus M
  3- S-cone


2025-08-20 18:04:31,795 - INFO - Selected analysis: S-cone - S-cone (blue-yellow) analysis
2025-08-20 18:04:31,796 - INFO - Event types: [9, 10, 11, 12]
2025-08-20 18:04:31,797 - INFO - Labels: ['S1', 'S2', 'S3', 'S4']


In [17]:

# 5. Save configuration
# ----------------------------------------------------------------------
CONFIG_PATH = DECODE_DIR / "decode_config.json"
config_data = {
    "bids_root": str(BIDS_ROOT),
    "deriv_dir": str(DERIV_DIR),
    "decode_dir": str(DECODE_DIR),
    "subjects": SUBJECTS,
    "electrode_set": ELEC_CFG,
    "analysis": ANALYSIS_CFG,
}

with open(CONFIG_PATH, 'w') as f:
    json.dump(config_data, f, indent=2)
    
logger.info(f"Configuration saved to {CONFIG_PATH}")




2025-08-20 18:04:51,764 - INFO - Configuration saved to new_helios_bd_eeg_vep/derivatives/mne-bids-pipeline-vep/decoding/npz3/decode_config.json


In [18]:

# 6. Get epochs file path for a subject
# ----------------------------------------------------------------------
def epochs_path_for(subject):
    eeg_dir = DERIV_DIR / f"sub-{subject}" / "eeg"
    fname   = f"sub-{subject}_task-vep_proc-clean_epo.fif"
    path    = eeg_dir / fname
    
    if not path.exists():
        raise FileNotFoundError(f"Epochs file not found: {path}")
    
    return path


In [19]:

# 7. Process each subject and save NPZ files
# ----------------------------------------------------------------------
for subject in SUBJECTS:
    logger.info(f"\n{'='*60}")
    logger.info(f"Processing subject: {subject}")
    logger.info(f"Electrode set: {ELEC_CFG['name']}")
    logger.info(f"Analysis: {ANALYSIS_CFG['name']}")
    
    try:
        # Load epochs
        epo_path = epochs_path_for(subject)
        logger.info(f"Loading epochs from: {epo_path}")
        epochs = mne.read_epochs(epo_path, preload=True, verbose=False)
        
        # Log initial info
        logger.info(f"Original epochs info: {len(epochs)} trials, {len(epochs.ch_names)} channels, {len(epochs.times)} time points")
        logger.info(f"Event codes present: {np.unique(epochs.events[:, 2])}")
        
        # Select channels - convert 1-based indices to 0-based
        ch_indices = [idx - 1 for idx in ELEC_CFG["channels"]]
        
        # Verify indices are within range
        if max(ch_indices) >= len(epochs.ch_names):
            logger.error(f"Channel index {max(ch_indices)+1} out of range for subject {subject}")
            raise IndexError("Channel index out of range")
        
        epochs.pick(ch_indices)
        ch_names = epochs.ch_names
        
        logger.info(f"Selected {len(ch_names)} channels: {ch_names[:3]}...{ch_names[-3:]}")
        
        # Filter events
        keep_codes = ANALYSIS_CFG["event_types"]
        mask = np.isin(epochs.events[:, 2], keep_codes)
        epochs = epochs[mask]
        
        if len(epochs) == 0:
            logger.error(f"No epochs found for event codes {keep_codes}")
            raise ValueError("No epochs after filtering")
        
        # Verify we have all event types
        found_codes = np.unique(epochs.events[:, 2])
        missing_codes = set(keep_codes) - set(found_codes)
        
        if missing_codes:
            logger.warning(f"Missing event codes: {missing_codes}")
        
        logger.info(f"After filtering: {len(epochs)} trials")
        logger.info(f"Found event codes: {found_codes}")
        
        # Create labels
        code_to_label = {code: i for i, code in enumerate(keep_codes)}
        y = np.array([code_to_label[code] for code in epochs.events[:, 2]], dtype=np.int32)
        
        # Log class distribution
        unique_labels, counts = np.unique(y, return_counts=True)
        for label, count in zip(unique_labels, counts):
            logger.info(f"Class {label}: {count} trials")
        
        # Prepare data for saving
        X = epochs.get_data()  # Shape: (n_trials, n_channels, n_times)
        times = epochs.times.astype(np.float32)
        
        # Verify data dimensions
        assert X.shape[0] == len(y), "Trial count mismatch between X and y"
        assert X.shape[1] == len(ch_names), "Channel count mismatch"
        assert X.shape[2] == len(times), "Time point mismatch"
        
        # Create output filename
        out_file = DECODE_DIR / f"sub-{subject}_{ANALYSIS_CFG['name']}_{ELEC_CFG['name']}_data.npz"
        
        # Save data with explicit array names and proper data types
        np.savez_compressed(
            out_file,
            X=X.astype(np.float32),       # EEG data (float32 for efficiency)
            y=y,                          # Class labels (int32)
            times=times,                  # Time points (float32)
            ch_names=np.array(ch_names),  # Channel names (1D array)
            event_codes=epochs.events[:, 2],  # Original event codes
            subject=subject,
            electrode_set=ELEC_CFG["name"],
            analysis=ANALYSIS_CFG["name"]
        )
        
        logger.info(f"Saved NPZ file: {out_file}")
        logger.info(f"Data shape: {X.shape} (trials × channels × time)")
        logger.info(f"Time range: {times[0]:.3f}s to {times[-1]:.3f}s")
        
    except Exception as e:
        logger.error(f"Error processing subject {subject}: {str(e)}", exc_info=True)
        continue

logger.info("\nProcessing complete!")
logger.info(f"NPZ files saved to: {DECODE_DIR}")

2025-08-20 18:04:57,306 - INFO - 
2025-08-20 18:04:57,307 - INFO - Processing subject: 800
2025-08-20 18:04:57,308 - INFO - Electrode set: 64
2025-08-20 18:04:57,309 - INFO - Analysis: S-cone
2025-08-20 18:04:57,310 - INFO - Loading epochs from: new_helios_bd_eeg_vep/derivatives/mne-bids-pipeline-vep/sub-800/eeg/sub-800_task-vep_proc-clean_epo.fif
2025-08-20 18:04:57,601 - INFO - Original epochs info: 682 trials, 73 channels, 256 time points
2025-08-20 18:04:57,602 - INFO - Event codes present: [ 1  2  3  4  5  6  7  8  9 10 11 12]
2025-08-20 18:04:57,613 - INFO - Selected 64 channels: ['Fp1', 'AF7', 'AF3']...['PO8', 'PO4', 'O2']
2025-08-20 18:04:57,681 - INFO - After filtering: 227 trials
2025-08-20 18:04:57,681 - INFO - Found event codes: [ 9 10 11 12]
2025-08-20 18:04:57,682 - INFO - Class 0: 57 trials
2025-08-20 18:04:57,682 - INFO - Class 1: 59 trials
2025-08-20 18:04:57,682 - INFO - Class 2: 57 trials
2025-08-20 18:04:57,683 - INFO - Class 3: 54 trials
2025-08-20 18:04:58,149 - I

In [20]:
import numpy as np

# Load one of the generated files
data = np.load("new_helios_bd_eeg_vep/derivatives/mne-bids-pipeline-vep/decoding/npz1/sub-800_luminance_64_data.npz", allow_pickle=True)

# Check keys and shapes
print("Keys:", list(data.keys()))
print("X shape:", data['X'].shape)
print("y shape:", data['y'].shape)
print("Times shape:", data['times'].shape)
print("Channel names:", data['ch_names'])
print("Event codes:", np.unique(data['event_codes']))
print("Subject:", data['subject'])
print("Electrode set:", data['electrode_set'])
print("Analysis:", data['analysis'])

Keys: ['X', 'y', 'times', 'ch_names', 'event_codes', 'subject', 'electrode_set', 'analysis']
X shape: (224, 64, 256)
y shape: (224,)
Times shape: (256,)
Channel names: ['Fp1' 'AF7' 'AF3' 'F1' 'F3' 'F5' 'F7' 'FT7' 'FC5' 'FC3' 'FC1' 'C1' 'C3'
 'C5' 'T7' 'TP7' 'CP5' 'CP3' 'CP1' 'P1' 'P3' 'P5' 'P7' 'P9' 'PO7' 'PO3'
 'O1' 'Iz' 'Oz' 'POz' 'Pz' 'CPz' 'Fpz' 'Fp2' 'AF8' 'AF4' 'AFz' 'Fz' 'F2'
 'F4' 'F6' 'F8' 'FT8' 'FC6' 'FC4' 'FC2' 'FCz' 'Cz' 'C2' 'C4' 'C6' 'T8'
 'TP8' 'CP6' 'CP4' 'CP2' 'P2' 'P4' 'P6' 'P8' 'P10' 'PO8' 'PO4' 'O2']
Event codes: [1 2 3 4]
Subject: 800
Electrode set: 64
Analysis: luminance
