In [1]:
import mne
import matplotlib 
import matplotlib.pyplot as plt
matplotlib.use('Qt5Agg')
from mne.preprocessing import ICA
from mne.preprocessing import read_ica
# Build the input paths with pathlib rather than backslash-separated string
# literals: "\s", "\p" and "\i" are invalid escape sequences in a normal
# Python string (they raise a SyntaxWarning and are slated to become an
# error), and hard-coded backslashes only resolve on Windows.
from pathlib import Path

subject_dir = Path('rEEG') / 'sub-143'

# Continuous recording; preload=True pulls the samples into memory right away.
raw = mne.io.read_raw_fif(subject_dir / 'preICA143_raw.fif', preload=True)

# ICA solution previously saved to disk for the same subject folder.
ica = read_ica(subject_dir / 'ica-143-variance_compica.fif')

Opening raw data file rEEG\sub-143\preICA143_raw.fif...
    Range : 0 ... 30554 =      0.000 ...   122.216 secs
Ready.
Reading 0 ... 30554  =      0.000 ...   122.216 secs...
Reading c:\Users\User\Documents\EEG_Project\rEEG\sub-143\ica-143-variance_compica.fif ...
Now restoring ICA solution ...
Ready.


  ica = read_ica(f'rEEG\sub-143\ica-143-variance_compica.fif')


In [2]:
# ---------------------------------------------------------
# CONSERVATIVE MUSCLE ARTIFACT DETECTOR FOR ICA (plug-and-play)
# ---------------------------------------------------------
import numpy as np
import pandas as pd
from scipy import signal, stats


def detect_muscle_artifacts(ica, raw, emg_band=(20, 45), threshold=0.75):
    """
    Conservative automatic muscle artifact detection.

    Every independent component (IC) receives a score in [0, 1] that is the
    weighted sum of three capped features:
        EMG_ratio : Welch PSD power inside ``emg_band`` divided by the power
                    in 1-15 Hz (weight 0.45, saturates at a ratio of 3)
        bursts    : number of samples whose |z-score| exceeds 3
                    (weight 0.25, saturates at 80 samples)
        focality  : max / median of the IC's absolute mixing weights
                    (weight 0.30, saturates at 8)

    INPUTS:
        ica       : your fitted ICA object (must provide get_sources() and
                    get_components())
        raw       : the Raw you used for ICA
        emg_band  : (low, high) band in Hz treated as muscle activity
        threshold : ICs whose score is strictly greater than this are flagged
    OUTPUTS:
        df  : table ranking muscle-like ICs, highest score first
        bad : list of IC indices to remove
    """
    # scipy's trapezoid is used instead of np.trapz, which NumPy 2.x deprecates.
    from scipy.integrate import trapezoid

    eps = 1e-12            # keeps the ratios finite when a denominator is zero
    low_band = (1, 15)     # reference band for the EMG power ratio
    columns = ['IC', 'EMG_ratio', 'bursts', 'focality', 'score']

    sfreq = raw.info['sfreq']
    sources = ica.get_sources(raw).get_data()     # (n_ics, n_times)
    mixing = np.abs(ica.get_components())         # (n_channels, n_ics)

    # 2-second Welch segments, clamped to the recording length so a short
    # recording does not make scipy warn and silently shrink the segment.
    nperseg = min(int(2 * sfreq), sources.shape[1])

    rows = []
    for ic, sig in enumerate(sources):
        # --- PSD ---
        freqs, pxx = signal.welch(sig, sfreq, nperseg=nperseg)
        emg_mask = (freqs >= emg_band[0]) & (freqs <= emg_band[1])
        low_mask = (freqs >= low_band[0]) & (freqs <= low_band[1])

        emg_pow = trapezoid(pxx[emg_mask], freqs[emg_mask])
        low_pow = trapezoid(pxx[low_mask], freqs[low_mask])
        emg_ratio = emg_pow / (low_pow + eps)

        # --- Spike / burst detection ---
        zsig = (sig - sig.mean()) / (sig.std() + eps)
        burst_count = int(np.sum(np.abs(zsig) > 3))

        # --- Spatial focality ---
        weights = mixing[:, ic]
        focality = weights.max() / (np.median(weights) + eps)

        # --- Conservative scoring ---
        score = (
            0.45 * min(emg_ratio / 3, 1) +      # harder to score high
            0.25 * min(burst_count / 80, 1) +   # requires more bursts
            0.30 * min(focality / 8, 1)         # requires tighter focality
        )

        rows.append({
            'IC': ic,
            'EMG_ratio': emg_ratio,
            'bursts': burst_count,
            'focality': focality,
            'score': score
        })

    # Passing the column list keeps the table well-formed (and sortable) even
    # when the ICA holds no components at all.
    df = (
        pd.DataFrame(rows, columns=columns)
        .sort_values('score', ascending=False)
        .reset_index(drop=True)
    )

    # Conservative threshold: only scores strictly above `threshold` are flagged
    bad = df.loc[df['score'] > threshold, 'IC'].tolist()

    return df, bad


# Score every IC for muscle contamination, then print the ranking (at most
# the first 50 rows) followed by the indices that crossed the threshold.
df, bad_ics = detect_muscle_artifacts(ica, raw)

muscle_ranking_text = df.head(50).to_string(index=False)
print(muscle_ranking_text)
print("ICs to remove:", bad_ics)


 IC  EMG_ratio  bursts  focality    score
 10   1.788592     131 14.084479 0.818289
 29   1.677521     141  9.693997 0.801628
 44   1.608052     119  7.584708 0.775634
 31   1.495752     119  7.902678 0.770713
 21   1.444441     111  8.972401 0.766666
  8   1.452882     150  7.469909 0.748054
 35   1.499138     109  7.166493 0.743614
 28   1.401498     127  7.484927 0.740909
 43   1.606766     106  6.105364 0.719966
  6   1.113571     134  7.971367 0.715962
 33   1.504659     113  6.240788 0.709728
 12   1.337699     129  6.693377 0.701656
 38   1.557265     114  5.516687 0.690465
 40   0.884924     174 12.040598 0.682739
 20   1.402493     121  5.534609 0.667922
 18   1.658332     131  4.441897 0.665321
  9   1.443501     135  5.137597 0.659185
 30   1.388828     120  5.310965 0.657485
 45   1.354592     124  5.296988 0.651826
 15   1.496998     122  4.707099 0.651066
 32   1.337182     133  5.269645 0.648189
 14   1.382767     126  4.896791 0.641045
 17   0.413917     222 10.922404 0

  emg_pow = np.trapz(pxx[emg_mask], freqs[emg_mask])
  low_pow = np.trapz(pxx[(freqs>=1)&(freqs<=15)], freqs[(freqs>=1)&(freqs<=15)])


In [14]:
# ---------------------------------------------------------
# EASY EOG/BLINK ARTIFACT DETECTOR FOR ICA (plug-and-play)
# ---------------------------------------------------------
import numpy as np
import pandas as pd
from scipy import signal

def detect_eog_artifacts(ica, raw, frontal_chs=None, eog_band=(0.5, 6), threshold=1.1):
    """
    Simple automatic EOG/blink artifact detection.

    Every independent component (IC) receives a score that is the weighted
    sum of three capped features. The weights add up to 1.2, so scores lie
    in [0, 1.2]:
        frontal_ratio : mean absolute mixing weight over the frontal channels
                        divided by the mean over all channels
                        (weight 0.8, saturates at a ratio of 2)
        eog_ratio     : Welch PSD power inside ``eog_band`` divided by the
                        total power (weight 0.3, saturates at 0.5)
        bursts        : number of samples whose |z-score| exceeds 3
                        (weight 0.1, saturates at 50 samples)

    INPUTS:
        ica         : your fitted ICA object (must provide get_sources() and
                      get_components())
        raw         : the Raw you used for ICA
        frontal_chs : channel names expected to pick up eye activity;
                      defaults to Fp1, Fp2, AF7, AF8, AF3, AF4, Fz
        eog_band    : (low, high) band in Hz for blinks / slow eye movements
        threshold   : ICs whose score is strictly greater than this are flagged
    OUTPUTS:
        df  : table ranking EOG-like ICs, highest score first
        bad : list of IC indices to remove
    """
    import warnings

    # scipy's trapezoid is used instead of np.trapz, which NumPy 2.x deprecates.
    from scipy.integrate import trapezoid

    if frontal_chs is None:
        # frontal electrodes typically showing EOG
        frontal_chs = ['Fp1', 'Fp2', 'AF7', 'AF8', 'AF3', 'AF4', 'Fz']

    eps = 1e-12   # keeps the ratios finite when a denominator is zero
    columns = ['IC', 'frontal_ratio', 'eog_ratio', 'bursts', 'score']

    sfreq = raw.info['sfreq']
    sources = ica.get_sources(raw).get_data()     # shape: (n_ics, n_times)
    mixing = np.abs(ica.get_components())         # shape: (n_channels, n_ics)

    # The rows of the mixing matrix follow the channels the ICA was fitted on
    # (exposed by MNE as ``ica.ch_names``). Raw may carry extra channels, so
    # its channel order is only used when the ICA object offers no list.
    ch_names = getattr(ica, 'ch_names', None)
    if not ch_names:
        ch_names = [ch['ch_name'] for ch in raw.info['chs']]
    ch_names = list(ch_names)

    frontal_idx = [ch_names.index(ch) for ch in frontal_chs if ch in ch_names]
    if not frontal_idx:
        # Averaging an empty selection would yield NaN scores for every IC.
        warnings.warn(
            "None of the requested frontal channels are present; "
            "frontal_ratio is set to 0 for every IC.",
            RuntimeWarning,
        )

    # 2-second Welch segments, clamped to the recording length so a short
    # recording does not make scipy warn and silently shrink the segment.
    nperseg = min(int(2 * sfreq), sources.shape[1])

    rows = []
    for ic, sig in enumerate(sources):
        # --- PSD in low-frequency range ---
        freqs, pxx = signal.welch(sig, sfreq, nperseg=nperseg)
        eog_mask = (freqs >= eog_band[0]) & (freqs <= eog_band[1])
        eog_pow = trapezoid(pxx[eog_mask], freqs[eog_mask])
        total_pow = trapezoid(pxx, freqs)
        eog_ratio = eog_pow / (total_pow + eps)

        # --- Spike / burst detection ---
        zsig = (sig - sig.mean()) / (sig.std() + eps)
        burst_count = int(np.sum(np.abs(zsig) > 3))

        # --- Frontal dominance ---
        if frontal_idx:
            frontal_weight = mixing[frontal_idx, ic].mean()
            overall_weight = mixing[:, ic].mean()
            frontal_ratio = frontal_weight / (overall_weight + eps)
        else:
            frontal_ratio = 0.0

        # --- Score ---
        # Weighting: frontal_ratio is primary, then low-freq power, then bursts
        score = (
            0.8 * min(frontal_ratio / 2, 1) +
            0.3 * min(eog_ratio / 0.5, 1) +
            0.1 * min(burst_count / 50, 1)
        )

        rows.append({
            'IC': ic,
            'frontal_ratio': frontal_ratio,
            'eog_ratio': eog_ratio,
            'bursts': burst_count,
            'score': score
        })

    # Passing the column list keeps the table well-formed (and sortable) even
    # when the ICA holds no components at all.
    df = (
        pd.DataFrame(rows, columns=columns)
        .sort_values('score', ascending=False)
        .reset_index(drop=True)
    )

    # Threshold: score > `threshold` (1.1 by default, out of a maximum of 1.2)
    # -> probably EOG
    bad = df.loc[df['score'] > threshold, 'IC'].tolist()

    return df, bad

# ------------------------------
# Example usage / printing
# ------------------------------
# Rank the ICs by EOG score; note that this rebinds `df` and `bad_ics`.
df, bad_ics = detect_eog_artifacts(ica, raw)

eog_top_rows = df.head(50)
print(eog_top_rows.to_string(index=False))
print("ICs to remove:", bad_ics)


  eog_pow = np.trapz(pxx[eog_mask], freqs[eog_mask])
  total_pow = np.trapz(pxx, freqs)


 IC  frontal_ratio  eog_ratio  bursts    score
  5       3.119301   0.570313     495 1.200000
 39       1.946864   0.437459     231 1.141221
 49       1.833010   0.546167     355 1.133204
  4       1.490668   0.499631     402 0.996046
 22       1.969570   0.171453     144 0.990699
 25       1.544308   0.391033     471 0.952343
 21       1.751147   0.246381     145 0.948287
 29       1.701055   0.228745     138 0.917669
 16       1.382548   0.382184     192 0.882329
 47       1.757098   0.127982     141 0.879628
 28       1.422328   0.339455     189 0.872604
  0       1.480389   0.289051     206 0.865586
 40       1.593684   0.187400     147 0.849913
 13       1.532116   0.220154     159 0.844939
  1       1.588587   0.176501     265 0.841335
 24       1.322536   0.339187     444 0.832527
 15       1.506132   0.164385     142 0.801084
 18       1.271119   0.319178     200 0.799955
  6       1.473617   0.154242     204 0.781992
 11       1.404095   0.192314     125 0.777027
 43       0.9

In [None]:
import numpy as np
from mne.time_frequency import psd_array_welch

sources = ica.get_sources(raw).get_data()  # shape: n_components x n_times

# Components to list: 0-based indices 1 through 47 inclusive, clipped so the
# loop can never index past the last available component.
# NOTE(review): IC 0 is not part of this range -- confirm that is intended.
ic_range = range(1, min(48, sources.shape[0]))

print(f"{'IC':>4} | {'HF/LF Ratio (30-40/1-30 Hz)':>25}")
print("-"*32)

for comp in ic_range:
    comp_ts = sources[comp]  # 1D array of this IC's time series
    psd, freqs = psd_array_welch(comp_ts, sfreq=raw.info['sfreq'], fmin=1, fmax=40, n_fft=2048)

    # Mean PSD in the low (1 <= f < 30 Hz) and high (30 <= f <= 40 Hz) bands
    low = psd[(freqs >= 1) & (freqs < 30)].mean()
    high = psd[(freqs >= 30) & (freqs <= 40)].mean()
    # Report NaN rather than dividing by a non-positive low-band power
    ratio = high / low if low > 0 else np.nan

    print(f"{comp:>4} | {ratio:>25.4f}")


In [None]:
import numpy as np
from mne.time_frequency import psd_array_welch

# One set of Welch settings for every PSD below, so the channel, IC and
# reconstructed spectra are all estimated with the same parameters.
sfreq = raw.info['sfreq']
welch_kwargs = dict(sfreq=sfreq, fmin=1, fmax=80, n_fft=2048, n_overlap=0)

# 1) Mean high-frequency (30-45 Hz) PSD of every channel
data = raw.get_data()
psds_ch, freqs = psd_array_welch(data, **welch_kwargs)
hf_idx = (freqs >= 30) & (freqs <= 45)
ch_hf = psds_ch[:, hf_idx].mean(axis=1)         # one value per channel
ch_hf_uv = 1e12 * ch_hf                         # rescaled by 1e12 for readability

# 2) The same summary for every IC time course. The channel-derived band mask
#    is reused here on the IC spectra.
# NOTE(review): IC activations are presumably in arbitrary units rather than
# volts -- confirm before reading the 1e12-scaled values as µV^2/Hz.
sources = ica.get_sources(raw).get_data()       # n_ics x n_times
psds_ic, freqs_ic = psd_array_welch(sources, **welch_kwargs)
ic_hf = psds_ic[:, hf_idx].mean(axis=1)
ic_hf_uv = 1e12 * ic_hf

# 3) Summed HF power across channels versus summed HF power across ICs
total_ch_hf = ch_hf.sum()
total_ic_hf = ic_hf.sum()
print("Total channel HF (V^2/Hz):", total_ch_hf)
print("Total IC HF (V^2/Hz):     ", total_ic_hf)
print("Ratio IC/Channel total:", total_ic_hf / (total_ch_hf + 1e-30))

# 4) HF power remaining once the candidate muscle ICs are projected out
candidates = [20,16,47]   # replace with your auto_reject list
# NOTE(review): this assignment stays on the `ica` object after the cell has
# run; the same list is also handed to apply() explicitly below.
ica.exclude = candidates
raw_without_candidates = ica.apply(raw.copy(), exclude=candidates, start=None, stop=None)
recon = raw_without_candidates.get_data()
psds_recon, _ = psd_array_welch(recon, **welch_kwargs)
ch_hf_recon = psds_recon[:, hf_idx].mean(axis=1)
print("Median channel HF before (µV^2/Hz):", np.median(ch_hf_uv))
print("Median channel HF after  (µV^2/Hz):", np.median(ch_hf_recon * 1e12))
