In [None]:
import os
from pathlib import Path
import re

# Avoid numba caching/JIT issues in some environments
os.environ.setdefault("NUMBA_DISABLE_JIT", "1")
# Avoid Matplotlib cache issues if ~/.matplotlib isn't writable.
# The project-local config dir is only created (and exported) when MPLCONFIGDIR is
# not already configured, so an existing setting is respected and no stray
# directory is created. Must run before matplotlib is imported below.
mpl_dir = (Path.cwd() / "New_EEG" / ".mplconfig")
if "MPLCONFIGDIR" not in os.environ:
    mpl_dir.mkdir(parents=True, exist_ok=True)
    os.environ["MPLCONFIGDIR"] = str(mpl_dir.resolve())

import numpy as np
import mne
import matplotlib.pyplot as plt

# ----------- File setup (NEW data) -----------
# DATA_DIR may be overridden with the EEG_DATA_DIR environment variable;
# the default is the original local path.
DATA_DIR = Path(os.environ.get("EEG_DATA_DIR", r"g:\ChristianMusaeus\New_EEG\Processed"))
SUBJECT = "001"  # "001".."100"
RATER = "a"      # "a" or "b" (two doctors)
epo_path = DATA_DIR / f"Ros_Sub{SUBJECT}{RATER}_epo.fif"
if not epo_path.exists():
    # Fallback: load the first available epoch file. Announce it explicitly,
    # otherwise everything below silently describes a different subject/rater
    # than the one configured above.
    candidates = sorted(DATA_DIR.glob("*_epo.fif"))
    if not candidates:
        raise FileNotFoundError(f"No '*_epo.fif' files found in {DATA_DIR}")
    print(f"WARNING: {epo_path.name} not found; falling back to {candidates[0].name}")
    epo_path = candidates[0]

print("Loading:", epo_path)
epochs = mne.read_epochs(str(epo_path), preload=True, verbose=False)
print(epochs)
print("event_id:", epochs.event_id)
# Map event code -> label name, then count how many epochs carry each code.
inv = {v: k for k, v in epochs.event_id.items()}
counts = {inv[v]: int((epochs.events[:, 2] == v).sum()) for v in sorted(inv)}
print("label counts:", counts)

# For plotting/topomaps, create a copy with cleaned 10-20 channel names + standard montage
def _clean_ch_name(name: str) -> str:
    name = str(name)
    name = name.replace("EEG ", "")
    name = name.replace("-REF", "")
    return name.strip()

montage = mne.channels.make_standard_montage("standard_1020")

# Work on a copy so `epochs` keeps its stored channel names.
epochs_plot = epochs.copy()
clean_names = {ch: _clean_ch_name(ch) for ch in epochs_plot.ch_names}
epochs_plot.rename_channels(clean_names)
# match_case=False lets labels that differ only in case match the montage;
# on_missing="ignore" tolerates channels the montage does not contain.
epochs_plot.set_montage(montage, match_case=False, on_missing="ignore")

# Quick browse
epochs.plot()
plt.show()


In [None]:
# Display the measurement info of the loaded epochs (e.g. sampling rate, digitization) as the cell output.
epochs.info

### Power Spectral Density

In [None]:
# Spectrum of every epoch/channel using MNE's default PSD settings.
psd = epochs.compute_psd()

# Raw PSD values as a NumPy array (indexed below as [epoch, channel, freq]).
psd_data = psd.get_data()
print(psd_data.shape)

psd.plot()
plt.show()


### Alpha Power

In [None]:
# Alpha band power per epoch and channel, derived from the PSD computed above.
# psd_data is indexed as [epoch, channel, frequency].
assert psd_data.ndim == 3, "expected PSD data shaped (n_epochs, n_channels, n_freqs)"
freq = psd.freqs

fmin, fmax = 8, 13  # alpha band limits in Hz

freq_indices = (freq >= fmin) & (freq <= fmax)

# PSD values are densities (power per Hz): summing bins must be scaled by the
# bin width to obtain power. Relative power below is unaffected by this factor.
freq_res = float(np.mean(np.diff(freq))) if freq.size > 1 else 1.0

# absolute alpha power
psd_alpha = psd_data[:, :, freq_indices].sum(axis=2) * freq_res

# total power (all frequencies)
psd_total = psd_data.sum(axis=2) * freq_res

# relative alpha power; channels with zero total power (e.g. flat) yield NaN
# instead of a divide-by-zero warning / inf.
psd_relative_alpha = np.divide(
    psd_alpha,
    psd_total,
    out=np.full_like(psd_alpha, np.nan, dtype=float),
    where=psd_total > 0,
)

print(f"Total Power Shape: {psd_total.shape}")
print(f"Alpha Power Shape: {psd_alpha.shape}")

In [None]:
# Per-channel alpha power of the first five epochs.
print("Alpha Power for first 5 epochs:", psd_alpha[:5])

In [None]:
# Collapse the epoch axis so each channel gets a single averaged value.
alpha_power_avg = np.mean(psd_alpha, axis=0)  # shape: (n_channels,)
relative_alpha_power_avg = np.mean(psd_relative_alpha, axis=0)  # shape: (n_channels,)

In [None]:
regions = {
    'frontal':['Fz', 'Fp1', 'Fp2', 'F3', 'F4', 'F7', 'F8'],
    'central':['C3', 'C4', 'Cz'],
    'parietal':['P3', 'P4', 'P7', 'P8', 'Pz'],
    'occipital':['O1', 'O2', 'Oz']
}

# Use cleaned channel names for region selection.
# Match case-insensitively: the montage was attached with match_case=False, so
# stored labels such as "FP1" or "CZ" must still map to "Fp1"/"Cz" here.
channels = list(epochs_plot.ch_names)
channels_lower = [ch.lower() for ch in channels]
channels_lower_set = set(channels_lower)

region_indices = {}
for region_name, region_chs in regions.items():
    wanted = {c.lower() for c in region_chs}
    region_indices[region_name] = [i for i, ch in enumerate(channels_lower) if ch in wanted]
print("Region channel counts:", {k: len(v) for k, v in region_indices.items()})

# Report region channels not present in this recording, so empty/partial regions are visible.
region_missing = {
    region_name: [c for c in region_chs if c.lower() not in channels_lower_set]
    for region_name, region_chs in regions.items()
}
region_missing = {k: v for k, v in region_missing.items() if v}
if region_missing:
    print("Region channels not found:", region_missing)

### Topographical plot of alpha activity on the scalp

In [None]:
fig, ax = plt.subplots()

# Topomap requires channel locations; epochs_plot has a standard montage attached.
# Plotted value: per-channel alpha power averaged over epochs.
im, _ = mne.viz.plot_topomap(psd_alpha.mean(axis=0), epochs_plot.info, show=False, axes=ax, cmap='coolwarm')
# The colorbar makes the blue/white/red scale interpretable.
fig.colorbar(im, ax=ax, label="Alpha power")

# "\u2013" is an en dash; spelled as an escape so the title survives re-encoding.
ax.set_title(f"Alpha Band Power ({fmin}\u2013{fmax} Hz)")

plt.show()

Blue represents the lowest values on the color scale, i.e. the regions with the weakest alpha power relative to the other channels (this does not necessarily mean alpha activity is absent there).

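White represents intermediate values: alpha activity in these regions is lower than in the red regions but higher than in the blue ones. Because the map shows power values rather than a statistical contrast, white does not indicate a "non-significant" difference or a neutral baseline.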

Red indicates the highest values on the color scale, meaning alpha activity is strongest in these regions relative to the other channels.

In [None]:
# Epoch data as a NumPy array; per MNE this is (n_epochs, n_channels, n_times) — confirm from the printed shape.
eeg_array = epochs.get_data()
eeg_shape = eeg_array.shape
print(eeg_shape)

### Rejected epochs

In [None]:
# The NEW pipeline stores data as FIF rather than EEGLAB .set/.mat files;
# the equivalent of the old "labels" is epochs.events together with epochs.event_id.
for caption, *values in (
    ("File:", epo_path.name),
    ("sfreq:", epochs.info['sfreq']),
    ("tmin/tmax:", epochs.tmin, epochs.tmax),
    ("event_id:", epochs.event_id),
    ("first 10 events (sample index, prev, code):",),
):
    print(caption, *values)
print(epochs.events[:10])

In [None]:
# Label counts (EO/EC/OTHER) from epochs.events
inv = {v: k for k, v in epochs.event_id.items()}
# Codes without a name in event_id are reported as 'OTHER'.
labels = np.array([inv.get(int(code), 'OTHER') for code in epochs.events[:, 2]], dtype=str)
unique, counts = np.unique(labels, return_counts=True)
# Convert NumPy scalars to plain Python types so the dict prints cleanly
# (NumPy >= 2 would otherwise show np.str_(...) / np.int64(...) reprs).
print({str(label): int(n) for label, n in zip(unique, counts)})


### Channels

In [None]:
# Compare the channel labels as stored with the cleaned labels used for montage/topomaps.
for heading, names in (
    ("Stored channel names:", epochs.ch_names),
    ("\nCleaned channel names (epochs_plot):", epochs_plot.ch_names),
):
    print(heading)
    for channel in names:
        print(" ", channel)


In [None]:
# Digitization points are what the topomap needs; report whether any exist.
dig = epochs_plot.info.get('dig')
has_dig = bool(dig)
print("Has montage/dig points:", has_dig)
if has_dig:
    print("n_dig:", len(dig))


In [None]:
# Show whether each channel has a position after setting the montage
montage_plot = epochs_plot.get_montage()
ch_pos = montage_plot.get_positions()['ch_pos'] if montage_plot is not None else {}
# Keep only finite coordinates: channels not matched by set_montage(on_missing="ignore")
# may carry NaN locations rather than being absent — NOTE(review): depends on MNE version.
pos = {ch: xyz for ch, xyz in ch_pos.items() if np.all(np.isfinite(xyz))}
missing = [ch for ch in epochs_plot.ch_names if ch not in pos]
print(f"Channels with known positions: {len(pos)} / {len(epochs_plot.ch_names)}")
if missing:
    print("Missing positions (first 10):", missing[:10])


In [None]:
# Dropped epochs summary (reject_by_annotation=True during creation)
drop_log = getattr(epochs, 'drop_log', None)
if drop_log is None:
    print('No drop_log available')
else:
    # Entries marked 'IGNORED' belong to events that were never selected as epochs,
    # so they are not rejections; exclude them (mirrors Epochs.drop_log_stats()).
    considered = [reasons for reasons in drop_log if 'IGNORED' not in reasons]
    dropped = sum(len(reasons) > 0 for reasons in considered)
    print(f"Dropped epochs: {dropped} / {len(considered)}")


In [None]:
# Print the index and drop reasons of any dropped entry among the first 20 drop_log records.
drop_log = getattr(epochs, 'drop_log', None)
if drop_log is not None:
    for idx, reasons in enumerate(drop_log[:20]):
        if not reasons:
            continue
        print(idx, reasons)


In [None]:
# Subject + doctor inferred from filename, e.g. "Ros_Sub001a_epo" -> subject 1, doctor "a".
# Case-insensitive: file names built above use "Sub", which a lowercase-only
# pattern would never match.
m = re.search(r"sub(\d+)([ab])", epo_path.stem, flags=re.IGNORECASE)
if m:
    print("subject:", int(m.group(1)), "doctor:", m.group(2).lower())
else:
    print("Could not parse subject/doctor from:", epo_path.name)


In [None]:
# Total number of epochs in this file (as the cell's last expression, it is shown as the output)
len(epochs)
