<a href="https://colab.research.google.com/github/abelowska/mlNeuro/blob/main/2025/wavelets_examples.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Wavelets

In [None]:
# Install the third-party packages used below (IPython shell magics, not Python).
# NOTE(review): scipy is pinned to 1.12.0 — presumably for compatibility with
# ssqueezepy; confirm before changing the pin.
!pip install scipy==1.12.0
!pip install mne
!pip install ssqueezepy

Imports

In [None]:
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import mne
from mne.datasets import eegbci
from mne.datasets import sample
from mne.decoding import UnsupervisedSpatialFilter, CSP, Vectorizer
from mne.datasets import sample
from mne import Epochs, pick_types
from mne.channels import make_standard_montage
from mne.datasets import eegbci
from mne.decoding import CSP
from mne.io import concatenate_raws, read_raw_edf

from sklearn.model_selection import StratifiedKFold, cross_val_score, train_test_split, ShuffleSplit
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, classification_report
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA, FastICA
from sklearn.pipeline import make_pipeline
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from mne.time_frequency import tfr_morlet
from ssqueezepy import Wavelet, cwt, icwt, issq_cwt
from ssqueezepy.experimental import scale_to_freq, freq_to_scale

import matplotlib.pyplot as plt

## Load Motor Imagery dataset

Read the data and apply basic preprocessing

In [None]:
def read_motor_imagery_epochs(
    subject=1,
    runs=None,
    tmin=-1.0,
    tmax=4.0,
    event_ids=None,
    l_freq=1.0,
    h_freq=30.0,
    power_freq=60,
    n_jobs=10,
):
    """Load, pre-process and epoch EEG motor imagery data for one subject.

    The EDF runs are fetched from PhysioNet with ``eegbci.load_data`` and
    concatenated. The data are then given standard channel names and a
    ``standard_1005`` montage, and an average-reference projector is added.
    Next they are band-pass filtered and notch filtered at the power-line
    frequency and its harmonics below Nyquist. Finally they are cut into
    epochs around the ``T1``/``T2`` annotations.

    Parameters
    ----------
    subject : int
        PhysioNet subject number.
    runs : list of int | None
        Runs to load. ``None`` means ``[4, 8, 12]`` (left vs right motor imagery).
    tmin, tmax : float
        Epoch start/end in seconds relative to each event.
    event_ids : list of str | None
        Exactly two names that replace the ``T1`` and ``T2`` annotation labels,
        in that order; they also select the events to epoch.
        ``None`` means ``['left', 'right']``.
    l_freq, h_freq : float
        Band-pass edges in Hz.
    power_freq : float
        Power-line frequency in Hz; it and its harmonics below Nyquist are notched.
    n_jobs : int
        Number of parallel jobs used by the notch filter.

    Returns
    -------
    mne.Epochs
        Preloaded EEG-only epochs (bad channels excluded), with the
        average-reference projector applied and no baseline correction.

    Raises
    ------
    ValueError
        If ``event_ids`` does not contain exactly two names.
    """
    # fresh lists per call instead of shared mutable defaults
    if runs is None:
        runs = [4, 8, 12]
    if event_ids is None:
        event_ids = ['left', 'right']
    if len(event_ids) != 2:
        raise ValueError(
            f"event_ids must contain exactly two names (for T1 and T2), got {event_ids!r}"
        )

    # load data from PhysioNet
    raw_fnames = eegbci.load_data(subject, runs)
    raw = concatenate_raws([read_raw_edf(f, preload=True) for f in raw_fnames])
    eegbci.standardize(raw)  # set channel names

    # set channel locations
    raw.set_montage(make_standard_montage("standard_1005"))

    # map the generic T1/T2 annotation labels onto the task names
    raw.annotations.rename(dict(T1=event_ids[0], T2=event_ids[1]))

    # average re-reference, added as a projector (applied when epoching, proj=True)
    raw.set_eeg_reference(projection=True)

    # band-pass filter (MNE filters in place and returns the same object)
    raw_filtered = raw.filter(
        l_freq,
        h_freq,
        fir_design="firwin",
        skip_by_annotation="edge",
    )

    # Notch at the line frequency and its harmonics strictly below Nyquist.
    # With the default 30 Hz low-pass this is largely redundant, but it matters
    # when h_freq is raised. If the line frequency is at or above Nyquist there
    # is nothing to remove, so the call is skipped instead of passing an empty
    # frequency list.
    nyquist_freq = raw_filtered.info['sfreq'] / 2
    notch_freqs = np.arange(power_freq, nyquist_freq, power_freq)
    if notch_freqs.size:
        raw_filtered = raw_filtered.notch_filter(
            picks=['eeg', 'eog'],
            freqs=notch_freqs,
            n_jobs=n_jobs,
        )

    # pick only EEG channels
    picks = pick_types(
        raw_filtered.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads"
    )

    # create epochs; no explicit events are passed, so they come from the
    # (renamed) annotations selected by event_ids
    epochs = Epochs(
        raw_filtered,
        event_id=event_ids,
        tmin=tmin,
        tmax=tmax,
        proj=True,
        picks=picks,
        baseline=None,
        preload=True,
    )

    return epochs

In [None]:
# Which recording to load: runs 4, 8 and 12 are the left vs right motor imagery tasks.
subject = 1
runs = [4, 8, 12]
event_ids = ['left', 'right']

# Epoch window in seconds relative to each event.
tmin = -1.0
tmax = 2.0

epochs = read_motor_imagery_epochs(
    subject=subject,
    runs=runs,
    tmin=tmin,
    tmax=tmax,
    event_ids=event_ids,
)

Inspect the power spectral density (PSD)

In [None]:
# PSD between 2 and 40 Hz over the whole epoch window. The upper time bound is
# tied to the epoch end (tmax) instead of a hard-coded 3.0 s, which lay beyond
# the data and would silently diverge from the window if tmax changed.
fig = epochs.compute_psd(fmin=2.0, fmax=40.0, tmax=tmax, n_jobs=None).plot()

## Time-frequency decomposition using the Morlet wavelet – MNE implementation

Define the parameters and perform the decomposition

In [None]:
# 30 logarithmically spaced analysis frequencies from 1 Hz to 30 Hz
log_edges = np.log10([1, 30])
freqs = np.logspace(log_edges[0], log_edges[1], num=30)
print(freqs)

# One cycle per Hz: n_cycles grows with frequency, so every wavelet has the
# same nominal duration (n_cycles / freq = 1 s).
n_cycles = freqs / 1

# Morlet TFR averaged over epochs; power and inter-trial coherence are
# returned, decimated by 3 in time.
tfr_settings = dict(
    method="morlet",
    freqs=freqs,
    n_cycles=n_cycles,
    average=True,
    return_itc=True,
    decim=3,
)
power, itc = epochs.compute_tfr(**tfr_settings)

Inspect the format of the results

In [None]:
# Show the epoch-averaged TFR object; as the cell's last expression, the notebook renders its summary.
power

Plot results

In [None]:
# Baseline window shared by all plots: from the epoch start up to the event.
# It must lie inside the epoch window; a start earlier than tmin (e.g. -2 s
# for tmin = -1 s) is outside the data.
tfr_baseline = (tmin, 0)

# topographic overview of time-frequency power across all channels
power.plot_topo(baseline=tfr_baseline, mode="logratio", title="Average power")

# time-frequency representation for a single channel
ch_name = 'C3'
ch_index = epochs.info.ch_names.index(ch_name)
fig = power.plot(
    picks=[ch_index], baseline=tfr_baseline, mode="logratio", title=ch_name, yscale='log'
)

# joint plot: the channel's TFR plus topomaps at the chosen (time, frequency) points
fig = power.plot_joint(
    baseline=tfr_baseline,
    mode="logratio",
    tmin=tmin,
    tmax=tmax,
    timefreqs=[(0, 12), (1.25, 12)],
    picks=ch_name,
    yscale='log',
)

## Time-frequency decomposition using the Morlet wavelet – ssqueezepy implementation

This is a more manual implementation, but it lets you set many parameters directly and thus control the wavelet to a greater extent.

In [None]:
participant_epochs_reconstructed = []
participant_epochs_original = []
participant_epochs_Wx = []
participant_epochs_normalized_Wx = []

# define data - for now, only one channel
channel = 'C3'
epochs_data = epochs.copy().pick(channel)
signal_freq = epochs.info['sfreq']

# Morlet wavelet (mu=6) evaluated at 200 geometrically spaced scales from 5 to 500
wavelet = Wavelet(('morlet', {'mu': 6}))
scales = np.geomspace(5, 500, 200)

# samples per epoch, needed to convert scales to frequencies in Hz
N = len(epochs_data.times)
freq = scale_to_freq(scales, wavelet, N=N, fs=signal_freq)

for epoch in epochs_data:
    x = epoch.flatten()  # single picked channel -> 1-D signal

    # `scales` is deliberately not rebound to cwt's return value, so every
    # epoch (and `freq` above) refers to the same scale array.
    Wx, epoch_scales = cwt(
        x, wavelet, fs=signal_freq, scales=scales, padtype='wrap', l1_norm=True, nv=None
    )
    # inverse transform with the scales cwt actually used
    reconstructed = icwt(
        Wx, wavelet, scales=epoch_scales, nv=None, padtype='wrap', l1_norm=True,
        x_mean=np.mean(x),
    )

    participant_epochs_reconstructed.append(reconstructed)
    participant_epochs_original.append(x)
    participant_epochs_Wx.append(Wx)

    # Power of the CWT coefficients, normalised to sum to 1 over the whole
    # scale x time plane. An all-zero epoch is kept as-is instead of
    # producing NaNs from 0/0.
    power_spectrum = np.abs(Wx) ** 2
    total_power = np.sum(power_spectrum)
    if total_power > 0:
        normalized_power_spectrum = power_spectrum / total_power
    else:
        normalized_power_spectrum = power_spectrum

    participant_epochs_normalized_Wx.append(normalized_power_spectrum)

Now look at the results (aggregated over epochs):

In [None]:
# Aggregate the per-epoch results by averaging over epochs (axis 0).
# Note: averaging the complex coefficients cancels components whose phase
# varies from epoch to epoch, so mean_Wx keeps only phase-consistent activity.
evoked_original, evoked_reconstructed, mean_Wx, mean_normalized_Wx = (
    np.mean(per_epoch, axis=0)
    for per_epoch in (
        participant_epochs_original,
        participant_epochs_reconstructed,
        participant_epochs_Wx,
        participant_epochs_normalized_Wx,
    )
)

print(
    f"Original evoked shape: {evoked_original.shape}\n"
    f"Reconstructed evoked shape: {evoked_reconstructed.shape}\n"
    f"Wx shape: {mean_Wx.shape}"
)

In [None]:
# Plot the epoch-averaged CWT power of the selected channel.

# Time axis (s, relative to the event). It is defined here so this cell does not
# depend on a variable that is only created in a later cell.
t = epochs_data.times

# Average *power* over epochs. Averaging the complex coefficients first and
# taking abs() afterwards would cancel activity whose phase varies across
# epochs. Averaging |Wx|**2 per epoch keeps it, as MNE's averaged TFR power does.
mean_Wx_power = np.mean(np.abs(np.asarray(participant_epochs_Wx)) ** 2, axis=0)

# Baseline = all samples before the event (t < 0). The result is expressed as
# log10(power / baseline), the same "logratio" mode used in the MNE plots.
baseline_mask = t < 0
baseline_mean = np.mean(mean_Wx_power[:, baseline_mask], axis=1, keepdims=True)
mean_Wx_baselined = np.log10(mean_Wx_power / baseline_mean)

# Sort rows by frequency so the y axis is monotonic. Then use pcolormesh so the
# log-spaced rows sit at their true frequencies on the log axis; imshow would
# lay them out as evenly spaced and mislabel them.
plot_freqs = np.ravel(freq)
order = np.argsort(plot_freqs)

fig, ax = plt.subplots()
im = ax.pcolormesh(t, plot_freqs[order], mean_Wx_baselined[order], shading='auto')
ax.set_yscale('log')

# Add colorbar and labels
plt.colorbar(im, ax=ax, label="log10(power / baseline)")
ax.set_xlabel("Time (s)")
ax.set_ylabel("Frequency (Hz)")
ax.set_title(channel)

plt.show()

Finally, check the reconstruction quality:

In [None]:
# Overlay the epoch-averaged original signal and its CWT -> iCWT reconstruction.
t = np.linspace(tmin, tmax, N)
for trace, trace_label in (
    (evoked_original, 'original signal'),
    (evoked_reconstructed, 'reconstructed signal'),
):
    plt.plot(t, trace, label=trace_label)

plt.legend()