<a href="https://colab.research.google.com/github/ampnb/EEG-Mind-Wandering/blob/main/MW_PSD_FFT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import mne
import numpy as np
import matplotlib.pyplot as plt

def compute_psd(raw, fmin=1, fmax=40, n_fft=2048, n_overlap=0):
    """Estimate the power spectral density of every channel with Welch's method.

    ``n_fft`` and ``n_overlap`` are Welch-segmenting parameters, so the
    estimate is computed with :func:`mne.time_frequency.psd_array_welch` on
    the raw data array. (``psd_multitaper`` does not accept these arguments,
    and the ``psd_multitaper``/``psd_welch`` helpers were deprecated and later
    removed in MNE 1.x; ``psd_array_welch`` remains available.)

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to analyse; every channel returned by ``raw.get_data()``
        is included.
    fmin, fmax : float
        Frequency band (Hz) to keep in the output.
    n_fft : int
        FFT / segment length in samples. Clamped to the number of samples in
        the recording so short recordings still yield a spectrum instead of
        an error.
    n_overlap : int
        Number of samples shared by consecutive segments.

    Returns
    -------
    psds : ndarray, shape (n_channels, n_freqs)
        Power spectral density of each channel.
    freqs : ndarray, shape (n_freqs,)
        Frequencies (Hz) corresponding to the last axis of ``psds``.
    """
    data = raw.get_data()
    # psd_array_welch rejects n_fft larger than the signal, so cap it.
    n_fft = min(n_fft, data.shape[-1])
    psds, freqs = mne.time_frequency.psd_array_welch(
        data,
        sfreq=raw.info['sfreq'],
        fmin=fmin,
        fmax=fmax,
        n_fft=n_fft,
        n_overlap=n_overlap,
    )
    return psds, freqs

def compute_fft(raw, fmin=1, fmax=40):
    """Return the one-sided FFT of every channel, restricted to [fmin, fmax] Hz.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to transform; ``raw[:]`` supplies all channels and samples.
    fmin, fmax : float
        Inclusive frequency band (Hz) to keep.

    Returns
    -------
    fft : ndarray of complex
        Unnormalised ``numpy.fft.rfft`` coefficients along the last axis,
        keeping only the in-band bins (use ``np.abs`` for magnitudes).
    freqs : ndarray
        Frequencies (Hz) of the kept bins, aligned with the last axis of ``fft``.
    """
    data, _ = raw[:]  # the sample times are not needed here
    spectrum = np.fft.rfft(data, axis=-1)
    freqs = np.fft.rfftfreq(data.shape[-1], d=1.0 / raw.info['sfreq'])
    # Build the band mask once and apply it to both outputs so they stay aligned.
    in_band = (freqs >= fmin) & (freqs <= fmax)
    return spectrum[..., in_band], freqs[in_band]

def plot_psd(psd, freqs, channel_names):
    """Plot one power-spectrum curve per channel on a logarithmic y-axis.

    The curves are drawn with ``semilogy``: the axis is log-scaled but the
    plotted values are linear power, not decibels. The y label therefore
    states "log scale" rather than claiming dB.

    Parameters
    ----------
    psd : array-like
        Row ``i`` is the spectrum drawn for ``channel_names[i]``.
    freqs : array-like
        Frequencies (Hz) shared by every row of ``psd``.
    channel_names : sequence of str
        Legend labels; one curve is drawn per name.
    """
    plt.figure()
    for i, channel in enumerate(channel_names):
        plt.semilogy(freqs, psd[i], label=channel)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('PSD (log scale)')
    plt.legend()
    plt.show()

def plot_fft(fft, freqs, channel_names):
    """Draw the magnitude of each channel's spectrum on a new figure.

    Parameters
    ----------
    fft : array-like
        Complex spectrum; row ``k`` is drawn for ``channel_names[k]``.
    freqs : array-like
        Frequency axis (Hz) shared by all rows.
    channel_names : sequence of str
        Curve labels shown in the legend.
    """
    _, axes = plt.subplots()
    for row, label in enumerate(channel_names):
        magnitude = np.abs(fft[row])
        axes.plot(freqs, magnitude, label=label)
    axes.set_xlabel('Frequency (Hz)')
    axes.set_ylabel('Amplitude')
    axes.legend()
    plt.show()

if __name__ == "__main__":
    # Load the preprocessed recording fully into memory.
    fif_path = './Participant1_meditation_trial1_preprocessed.fif'
    recording = mne.io.read_raw_fif(fif_path, preload=True)

    # Spectral density estimate, one curve per channel.
    spectra, spectra_freqs = compute_psd(recording)
    names = recording.info['ch_names']
    plot_psd(spectra, spectra_freqs, names)

    # Plain Fourier spectrum over the same default band.
    coeffs, coeff_freqs = compute_fft(recording)
    plot_fft(coeffs, coeff_freqs, names)