In [2]:
import mne
import matplotlib 
import matplotlib.pyplot as plt
matplotlib.use('Qt5Agg')
from mne.preprocessing import ICA
from mne.preprocessing import read_ica
# Raw strings keep the Windows backslashes literal. In a normal string literal
# "\s", "\p" and "\i" are invalid escape sequences, which Python reports with a
# warning today and is slated to reject outright in a future version.
# The resulting path text is unchanged.
raw = mne.io.read_raw_fif(r'rEEG\sub-052\preICA052_raw.fif', preload=True)
# No f-string prefix: the path has no placeholders to interpolate.
ica = read_ica(r'rEEG\sub-052\ica-052-variance_compica.fif')

Opening raw data file rEEG\sub-052\preICA052_raw.fif...
    Range : 0 ... 32734 =      0.000 ...   130.936 secs
Ready.
Reading 0 ... 32734  =      0.000 ...   130.936 secs...
Reading c:\Users\User\Documents\EEG_Project\rEEG\sub-052\ica-052-variance_compica.fif ...
Now restoring ICA solution ...
Ready.


  ica = read_ica(f'rEEG\sub-052\ica-052-variance_compica.fif')


In [3]:
# ---------------------------------------------------------
# CONSERVATIVE MUSCLE ARTIFACT DETECTOR FOR ICA (plug-and-play)
# ---------------------------------------------------------
import numpy as np
import pandas as pd
from scipy import signal, stats


def detect_muscle_artifacts(ica, raw, emg_band=(20, 45), low_band=(1, 15),
                            score_threshold=0.75):
    """
    Conservative automatic muscle artifact detection.

    Three features are computed per independent component (IC), each is
    capped at 1 and they are combined with weights 0.45 / 0.25 / 0.30, so the
    score lies in [0, 1]:
        EMG_ratio : Welch PSD power integrated over `emg_band`, divided by
                    the power integrated over `low_band`
        bursts    : number of SAMPLES whose |z-score| exceeds 3 (contiguous
                    samples are not merged into one event)
        focality  : max / median of the absolute mixing weights of the IC

    INPUTS:
        ica : your fitted ICA object (get_sources / get_components are used)
        raw : the Raw you used for ICA (raw.info['sfreq'] is read)
        emg_band : (low, high) edges of the muscle band, in the frequency
                   units returned by scipy.signal.welch
        low_band : (low, high) edges of the reference low-frequency band
        score_threshold : ICs with score strictly above this are flagged
    OUTPUTS:
        df  : table ranking muscle-like ICs (highest score first)
        bad : list of IC indices to remove
    """

    eps = 1e-12   # guards every division against a zero denominator
    columns = ['IC', 'EMG_ratio', 'bursts', 'focality', 'score']

    sfreq = raw.info['sfreq']
    sources = ica.get_sources(raw).get_data()     # (n_ics, n_times)
    mixing = np.abs(ica.get_components())         # (n_channels, n_ics)

    n_ics, n_times = sources.shape
    if n_ics == 0:
        # Nothing to rank; return an empty table instead of failing in sort_values.
        return pd.DataFrame(columns=columns), []

    # np.trapz is deprecated since NumPy 2.0; prefer np.trapezoid when it exists.
    integrate = getattr(np, 'trapezoid', None) or np.trapz

    # --- PSD (one Welch call for all ICs) ---
    # Clamp the segment length to the recording length; scipy would do the
    # same, but with a warning.
    nperseg = min(int(2 * sfreq), n_times)
    freqs, pxx = signal.welch(sources, sfreq, nperseg=nperseg, axis=-1)
    emg_mask = (freqs >= emg_band[0]) & (freqs <= emg_band[1])
    low_mask = (freqs >= low_band[0]) & (freqs <= low_band[1])

    emg_pow = integrate(pxx[:, emg_mask], freqs[emg_mask], axis=-1)
    low_pow = integrate(pxx[:, low_mask], freqs[low_mask], axis=-1)
    emg_ratio = emg_pow / (low_pow + eps)

    # --- Spike / burst detection ---
    centered = sources - sources.mean(axis=1, keepdims=True)
    zsig = centered / (sources.std(axis=1, keepdims=True) + eps)
    burst_count = (np.abs(zsig) > 3).sum(axis=1)

    # --- Spatial focality ---
    focality = mixing.max(axis=0) / (np.median(mixing, axis=0) + eps)

    # --- Conservative scoring ---
    score = (
        0.45 * np.minimum(emg_ratio / 3, 1) +      # harder to score high
        0.25 * np.minimum(burst_count / 80, 1) +   # requires more bursts
        0.30 * np.minimum(focality / 8, 1)         # requires tighter focality
    )

    df = pd.DataFrame({
        'IC': np.arange(n_ics),
        'EMG_ratio': emg_ratio,
        'bursts': burst_count.astype(int),
        'focality': focality,
        'score': score,
    })
    df = df.sort_values('score', ascending=False).reset_index(drop=True)

    # Conservative threshold: score must be strictly above score_threshold
    bad = df.loc[df['score'] > score_threshold, 'IC'].tolist()

    return df, bad


# Rank every IC by its muscle score, then show the table and the flagged ICs.
df, bad_ics = detect_muscle_artifacts(ica, raw)

muscle_ranking = df.head(50)
print(muscle_ranking.to_string(index=False))
print("ICs to remove:", bad_ics)


  emg_pow = np.trapz(pxx[emg_mask], freqs[emg_mask])
  low_pow = np.trapz(pxx[(freqs>=1)&(freqs<=15)], freqs[(freqs>=1)&(freqs<=15)])


 IC  EMG_ratio  bursts  focality    score
 47   1.228212     132  6.476695 0.677108
 38   1.194358     158  6.533891 0.674175
 51   1.416548     136  5.366261 0.663717
 55   1.277158     128  5.660724 0.653851
 30   0.832872     207  7.338166 0.650112
 32   0.626524     180  8.285732 0.643979
 53   1.177350     176  5.455442 0.631182
 12   0.654855     168  7.438729 0.627181
 39   1.193087     154  5.238663 0.625413
 52   1.425105     128  4.087144 0.617034
 48   1.265212     114  4.358562 0.603228
 54   1.300886     146  3.932604 0.592605
 40   1.366787     134  3.666701 0.592519
 49   1.348324     158  3.439356 0.581224
 46   1.317509     129  3.487420 0.578405
 41   0.934131     183  4.645907 0.564341
 36   1.090858     151  3.965215 0.562324
 16   1.163206     113  3.463570 0.554365
 42   0.501629     216  5.897805 0.546412
 37   1.235267     132  2.800608 0.540313
 43   0.774452     241  4.527230 0.535939
 13   1.169398     154  2.735104 0.527976
 19   0.459029     314  5.538425 0

In [4]:
# ---------------------------------------------------------
# EASY EOG/BLINK ARTIFACT DETECTOR FOR ICA (plug-and-play)
# ---------------------------------------------------------
import numpy as np
import pandas as pd
from scipy import signal

def detect_eog_artifacts(ica, raw,
                         frontal_chs=('Fp1', 'Fp2', 'AF7', 'AF8', 'AF3', 'AF4', 'Fz'),
                         eog_band=(0.5, 6), score_threshold=1.1):
    """
    Simple automatic EOG/blink artifact detection.

    Three features are computed per independent component (IC), each is
    capped at 1 and they are combined with weights 0.8 / 0.3 / 0.1, so the
    maximum possible score is 1.2:
        frontal_ratio : mean absolute mixing weight over the frontal channels
                        divided by the mean over all channels
        eog_ratio     : Welch PSD power integrated over `eog_band`, divided by
                        the power integrated over the whole spectrum
        bursts        : number of SAMPLES whose |z-score| exceeds 3

    INPUTS:
        ica : your fitted ICA object (get_sources / get_components are used;
              ica.ch_names, when present, names the mixing-matrix rows)
        raw : the Raw you used for ICA (raw.info['sfreq'] is read; its channel
              list is only a fallback when the ICA exposes no channel names)
        frontal_chs : channel names treated as frontal (matched ignoring case)
        eog_band : (low, high) edges of the blink / slow eye movement band
        score_threshold : ICs with score strictly above this are flagged
    OUTPUTS:
        df  : table ranking EOG-like ICs (highest score first)
        bad : list of IC indices to remove
    RAISES:
        ValueError : if none of `frontal_chs` is among the channels, because
                     the frontal ratio (the main feature) would be NaN
    """

    eps = 1e-12   # guards every division against a zero denominator
    columns = ['IC', 'frontal_ratio', 'eog_ratio', 'bursts', 'score']

    # The rows of the mixing matrix follow the channels the ICA was fitted on,
    # which can differ from raw's channel list (e.g. extra EOG/stim channels).
    ica_names = getattr(ica, 'ch_names', None)
    if ica_names:
        ch_names = list(ica_names)
    else:
        ch_names = [ch['ch_name'] for ch in raw.info['chs']]

    # frontal electrodes typically showing EOG
    wanted = {str(name).lower() for name in frontal_chs}
    frontal_idx = [i for i, name in enumerate(ch_names) if str(name).lower() in wanted]
    if not frontal_idx:
        raise ValueError(
            "None of the frontal channels %s were found among the channels; "
            "cannot compute the frontal ratio." % (list(frontal_chs),)
        )

    sfreq = raw.info['sfreq']
    sources = ica.get_sources(raw).get_data()     # shape: (n_ics, n_times)
    mixing = np.abs(ica.get_components())         # shape: (n_channels, n_ics)

    n_ics, n_times = sources.shape
    if n_ics == 0:
        # Nothing to rank; return an empty table instead of failing in sort_values.
        return pd.DataFrame(columns=columns), []

    # np.trapz is deprecated since NumPy 2.0; prefer np.trapezoid when it exists.
    integrate = getattr(np, 'trapezoid', None) or np.trapz

    # --- PSD in low-frequency range (one Welch call for all ICs) ---
    nperseg = min(int(2 * sfreq), n_times)   # scipy clamps too, but warns
    freqs, pxx = signal.welch(sources, sfreq, nperseg=nperseg, axis=-1)
    eog_mask = (freqs >= eog_band[0]) & (freqs <= eog_band[1])
    eog_pow = integrate(pxx[:, eog_mask], freqs[eog_mask], axis=-1)
    total_pow = integrate(pxx, freqs, axis=-1)
    eog_ratio = eog_pow / (total_pow + eps)

    # --- Spike / burst detection ---
    centered = sources - sources.mean(axis=1, keepdims=True)
    zsig = centered / (sources.std(axis=1, keepdims=True) + eps)
    burst_count = (np.abs(zsig) > 3).sum(axis=1)

    # --- Frontal dominance ---
    frontal_weight = mixing[frontal_idx, :].mean(axis=0)
    overall_weight = mixing.mean(axis=0)
    frontal_ratio = frontal_weight / (overall_weight + eps)

    # --- Score ---
    # Weighting: frontal_ratio is primary, then low-freq power, then bursts
    score = (
        0.8 * np.minimum(frontal_ratio / 2, 1) +
        0.3 * np.minimum(eog_ratio / 0.5, 1) +
        0.1 * np.minimum(burst_count / 50, 1)
    )

    df = pd.DataFrame({
        'IC': np.arange(n_ics),
        'frontal_ratio': frontal_ratio,
        'eog_ratio': eog_ratio,
        'bursts': burst_count.astype(int),
        'score': score,
    })
    df = df.sort_values('score', ascending=False).reset_index(drop=True)

    # Threshold: score strictly above score_threshold (default 1.1 of a
    # possible 1.2) -> probably EOG
    bad = df.loc[df['score'] > score_threshold, 'IC'].tolist()

    return df, bad

# ------------------------------
# Run the EOG detector and report the result
# ------------------------------
df, bad_ics = detect_eog_artifacts(ica, raw)

eog_ranking = df.head(50)
print(eog_ranking.to_string(index=False))
print("ICs to remove:", bad_ics)


 IC  frontal_ratio  eog_ratio  bursts    score
  0       3.604755   0.834495     765 1.200000
  3       2.000989   0.565166     416 1.200000
 17       2.012916   0.562171     375 1.200000
 15       1.978982   0.557327     277 1.191593
  4       1.778224   0.532327     331 1.111289
 56       1.722584   0.514550     422 1.089034
 13       2.087504   0.235750     154 1.041450
 28       1.567546   0.481122     220 1.015692
  9       1.652135   0.351890     168 0.971988
 19       1.423014   0.491143     314 0.963891
  8       1.379135   0.538312     401 0.951654
 23       1.717190   0.234455     143 0.927549
 27       1.310587   0.492265     218 0.919594
 44       1.433419   0.409376     163 0.918993
 32       1.449632   0.386001     180 0.911454
 25       1.195651   0.546202     400 0.878260
 22       1.198150   0.468313     249 0.860248
  6       1.551984   0.231427     151 0.859650
 11       1.104949   0.536116     329 0.841980
  1       1.314511   0.355301     204 0.838985
 41       1.3

  eog_pow = np.trapz(pxx[eog_mask], freqs[eog_mask])
  total_pow = np.trapz(pxx, freqs)


In [None]:
import numpy as np
from mne.time_frequency import psd_array_welch

# ICs to inspect: components 1 through 47 inclusive (range() excludes its stop
# value, and IC 0 is not part of this range).
ic_range = range(1,48)

sources = ica.get_sources(raw).get_data()  # shape: n_components x n_times

# Keep only indices this ICA solution actually has, so a range longer than the
# number of components cannot abort the report with an IndexError.
valid_ics = [comp for comp in ic_range if comp < sources.shape[0]]

print(f"{'IC':>4} | {'HF/LF Ratio (30-40/1-30 Hz)':>25}")
print("-"*32)

if valid_ics:
    # One Welch call for all requested ICs (the PSD is taken along the last
    # axis) instead of one call per component.
    psds, freqs = psd_array_welch(sources[valid_ics], sfreq=raw.info['sfreq'],
                                  fmin=1, fmax=40, n_fft=2048)
    low_mask = (freqs >= 1) & (freqs < 30)
    high_mask = (freqs >= 30) & (freqs <= 40)

    for comp, psd in zip(valid_ics, psds):
        low = psd[low_mask].mean()
        high = psd[high_mask].mean()
        # Avoid dividing by zero when the low band carries no power.
        ratio = high / low if low > 0 else np.nan

        print(f"{comp:>4} | {ratio:>25.4f}")


In [None]:
import numpy as np
from mne.time_frequency import psd_array_welch

# 1) channel HF power (you already ran this)
data = raw.get_data()
sfreq = raw.info['sfreq']
psds_ch, freqs = psd_array_welch(data, sfreq=sfreq, fmin=1, fmax=80, n_fft=2048, n_overlap=0)
hf_idx = (freqs >= 30) & (freqs <= 45)
ch_hf = psds_ch[:, hf_idx].mean(axis=1)         # V^2/Hz per channel (assumes every channel in raw is EEG in volts)
ch_hf_uv = ch_hf * 1e12                         # µV^2/Hz (readable)

# 2) component HF power (show why ICA components look 'hot')
sources = ica.get_sources(raw).get_data()       # n_ics x n_times
psds_ic, freqs_ic = psd_array_welch(sources, sfreq=sfreq, fmin=1, fmax=80, n_fft=2048, n_overlap=0)
# Build the band mask from the IC frequency grid itself rather than reusing the
# channel mask, so the two PSD calls may later use different settings safely.
hf_idx_ic = (freqs_ic >= 30) & (freqs_ic <= 45)
# NOTE(review): IC activations are not in volts; ICA sources are scaled in
# arbitrary units, so this is a.u.^2/Hz, not V^2/Hz. Confirm against the MNE docs.
ic_hf = psds_ic[:, hf_idx_ic].mean(axis=1)      # a.u.^2/Hz per IC
ic_hf_uv = ic_hf * 1e12                         # name kept for later cells; the 1e12 factor does NOT turn a.u. into µV

# 3) relative contribution: sum HF power of ICs that you flagged vs total channel HF
total_ch_hf = ch_hf.sum()
total_ic_hf = ic_hf.sum()                       # different units from total_ch_hf: the ratio below is only a scale factor
print("Total channel HF (V^2/Hz):", total_ch_hf)
print("Total IC HF (a.u.^2/Hz):  ", total_ic_hf)
print("Ratio IC/Channel total:", total_ic_hf / (total_ch_hf + 1e-30))

# 4) how much HF power is removed if you drop candidate muscle ICs (example list)
candidates = [20,16,47]   # replace with your auto_reject list
# project data without those ICs
# NOTE: assigning ica.exclude persists on the ICA object after this cell, so any
# later ica.apply(...) will keep dropping these components until it is reset.
ica.exclude = candidates
# raw.copy() keeps the loaded recording untouched; only the copy is cleaned.
recon = ica.apply(raw.copy(), exclude=candidates, start=None, stop=None).get_data()
psds_recon, _ = psd_array_welch(recon, sfreq=sfreq, fmin=1, fmax=80, n_fft=2048, n_overlap=0)
ch_hf_recon = psds_recon[:, hf_idx].mean(axis=1)
print("Median channel HF before (µV^2/Hz):", np.median(ch_hf_uv))
print("Median channel HF after  (µV^2/Hz):", np.median(ch_hf_recon * 1e12))
