In [None]:
import matplotlib.pyplot as plt
import mne
from pathlib import Path
import json
import sys

def calculate_and_plot_erf(epochs, event_id=None, title=None, filename=None):
    """
    Calculate the Event-Related Field (ERF) and plot it as two butterfly
    figures: one for magnetometers and one for gradiometers.

    Parameters
    ----------
    epochs : mne.Epochs
        The epoched data.
    event_id : str, optional
        Selector passed to ``epochs[event_id]`` before averaging (typically an
        event name such as ``"img/self/positive"``). Note that MNE interprets
        a bare integer here as an epoch index, not an event code.
        If None (default), all epochs are averaged.
    title : str, optional
        Figure suptitle added to both plots.
        If None (default), no suptitle is added.
    filename : str or Path, optional
        Base path for the saved figures. ``_mag`` / ``_grad`` is appended to
        the stem, e.g. ``erf.png`` -> ``erf_mag.png`` and ``erf_grad.png``
        (300 dpi). If None or empty (default), the plots are only displayed.

    Returns
    -------
    mne.Evoked
        The averaged ERF containing all channel types.
    """
    erf = epochs[event_id].average() if event_id is not None else epochs.average()

    # Accept plain strings as documented: .with_name/.stem require a Path.
    filename = Path(filename) if filename else None

    for meg_type, label in (("mag", "Magnetometers"), ("grad", "Gradiometers")):
        evoked = erf.copy().pick_types(meg=meg_type)
        fig = evoked.plot(spatial_colors=True, titles=f"ERF - {label}", show=False)

        if title is not None:
            fig.suptitle(title, fontsize=16, y=1.02)
        fig.set_size_inches((10, 6))

        if filename is not None:
            out_path = filename.with_name(f"{filename.stem}_{meg_type}{filename.suffix}")
            fig.tight_layout()
            # bbox_inches='tight' keeps the suptitle placed above the axes (y=1.02) in the file.
            fig.savefig(out_path, dpi=300, bbox_inches='tight', pad_inches=0.5)
        plt.show()

    return erf


def get_event_name(event_dict, event_id_to_find):
    """
    Get the event name corresponding to a provided event id.

    Parameters
    ----------
    event_dict : dict
        Dictionary mapping event names to event ids.
    event_id_to_find : int
        The event id for which to find the corresponding name.

    Returns
    -------
    str or None
        The first name (in dict iteration order) whose id equals
        ``event_id_to_find``, or None if no entry matches.
    """
    # `code` rather than `id` to avoid shadowing the builtin.
    return next(
        (name for name, code in event_dict.items() if code == event_id_to_find),
        None,
    )


In [None]:
# PLOTTING SPECIFIC TIMEPOINTS TOPOGRAPHY - 0.2 and 0.8
import matplotlib.pyplot as plt
import mne
from pathlib import Path

def calculate_and_plot_erf_topo(epochs, event_id=None, title=None, filename=None, times=(0.2, 0.8)):
    """
    Calculate and plot the ERF along with topoplots at specified time points.

    Shows, in order: the magnetometer ERF, the gradiometer ERF, one topomap
    per time point for magnetometers, and one per time point for gradiometers.

    Parameters
    ----------
    epochs : mne.Epochs
        The epoched data.
    event_id : str, optional
        Selector passed to ``epochs[event_id]`` before averaging (typically an
        event name; MNE interprets a bare integer as an epoch index).
        If None (default), all epochs are averaged.
    title : str, optional
        Prefix for the ERF figure suptitles ("<title> - Magnetometers" etc.).
        If None (default), no suptitle is added.
    filename : str or Path, optional
        Base path for the saved topomaps, written as
        ``<stem>_<mag|grad>_topo_<ms>ms<suffix>`` (300 dpi). The ERF butterfly
        plots are only displayed, not saved. If None (default), nothing is saved.
    times : sequence of float, optional
        The time points (in seconds) for which topoplots are generated.
        Defaults to (0.2, 0.8).

    Returns
    -------
    mne.Evoked
        The averaged ERF containing all channel types.
    """
    erf = epochs[event_id].average() if event_id is not None else epochs.average()

    # Accept plain strings as documented: .with_name/.stem require a Path.
    filename = Path(filename) if filename else None

    sensor_sets = (
        ("mag", "Magnetometers", erf.copy().pick_types(meg='mag')),
        ("grad", "Gradiometers", erf.copy().pick_types(meg='grad')),
    )

    # Butterfly ERF plot per sensor type (display only).
    for _, label, evoked in sensor_sets:
        fig, ax = plt.subplots()
        evoked.plot(axes=ax, spatial_colors=True, titles=f'ERF - {label}', show=False)
        if title is not None:
            fig.suptitle(f"{title} - {label}", fontsize=16, y=1.02)
        plt.show()

    # Topomaps: all magnetometer time points first, then all gradiometer ones.
    for ch_type, _, evoked in sensor_sets:
        for time_point in times:
            # Two axes: the map itself plus the colorbar MNE draws by default.
            fig, ax = plt.subplots(1, 2, figsize=(10, 4))
            evoked.plot_topomap(times=time_point, average=0.05, axes=ax, show=False)
            ax[0].set_title(f"Topomap ({ch_type}) - {time_point}s", fontsize=16)
            if filename is not None:
                # Round (not truncate) so float error like 569.999... gives 570 ms.
                ms = int(round(time_point * 1000))
                out_path = filename.with_name(
                    f"{filename.stem}_{ch_type}_topo_{ms}ms{filename.suffix}"
                )
                fig.savefig(out_path, dpi=300)
            plt.show()

    return erf


In [None]:
# PLOTTING LOCAL MAXIMA FOUND PROGRAMATICALLY
import matplotlib.pyplot as plt
import mne
from scipy.signal import find_peaks
from pathlib import Path
import numpy as np

def calculate_and_plot_erf_topo_peaks(epochs, event_id=None, title=None, filename=None,
                                      peak_picking=False, times=None, peak_distance=50):
    """
    Plot the ERF per MEG sensor type and topomaps at selected (or auto-picked) times.

    For each sensor type (magnetometers, then gradiometers) this shows the
    butterfly ERF followed by one topomap per selected time point.

    Parameters
    ----------
    epochs : mne.Epochs
        The epoched data.
    event_id : str, optional
        Selector passed to ``epochs[event_id]`` before averaging (typically an
        event name; MNE interprets a bare integer as an epoch index).
        If None (default), all epochs are averaged.
    title : str, optional
        Prefix for the ERF suptitles ("<title> - ERF - Mag").
    filename : str or Path, optional
        Base path for the saved topomaps, written next to it as
        ``<stem>_<mag|grad>_topo_<ms>ms.png`` (300 dpi). The ERF butterfly
        plots are only displayed. If None (default), nothing is saved.
    peak_picking : bool, optional
        If True, topomap times are the peaks of the magnetometer Global Field
        Power (RMS across magnetometers) found by ``scipy.signal.find_peaks``;
        the same times are also used for the gradiometers.
        If False (default), topomaps are shown at ``times``.
    times : sequence of float, optional
        Topomap times in seconds, used when ``peak_picking`` is False.
        Defaults to (0.1, 0.2).
    peak_distance : int, optional
        Minimum number of samples between picked GFP peaks (the ``distance``
        argument of ``find_peaks``). Default 50.
    """
    erf = epochs[event_id].average() if event_id is not None else epochs.average()
    erf_mag = erf.copy().pick_types(meg='mag')
    erf_grad = erf.copy().pick_types(meg='grad')

    # Accept plain strings as documented: .with_name/.stem require a Path.
    filename = Path(filename) if filename else None

    if peak_picking:
        gfp_mag = np.sqrt((erf_mag.data ** 2).mean(axis=0))
        peaks_mag, _ = find_peaks(gfp_mag, distance=peak_distance)
        topo_times = erf_mag.times[peaks_mag]  # sample indices -> seconds
    else:
        topo_times = (0.1, 0.2) if times is None else times

    for data_type, erf_type in (('mag', erf_mag), ('grad', erf_grad)):
        # Butterfly ERF (display only)
        fig, ax = plt.subplots()
        erf_type.plot(axes=ax, spatial_colors=True, show=False)
        full_title = f'ERF - {data_type.capitalize()}'
        if title:
            full_title = f'{title} - {full_title}'
        fig.suptitle(full_title, fontsize=16, y=1.02)
        fig.tight_layout()
        fig.subplots_adjust(top=0.85)  # leave room for the suptitle
        plt.show()

        # Topomaps; two axes = map + the colorbar MNE draws by default.
        for time_point in topo_times:
            fig, ax = plt.subplots(1, 2, figsize=(10, 4))
            erf_type.plot_topomap(times=time_point, average=0.05, axes=ax, show=False)
            ax[0].set_title(f'Topomap ({data_type}) - {time_point*1000:.0f} ms', fontsize=16)
            fig.tight_layout()
            fig.subplots_adjust(top=0.85)
            if filename is not None:
                out_path = filename.with_name(
                    f'{filename.stem}_{data_type}_topo_{time_point*1000:.0f}ms.png'
                )
                fig.savefig(out_path, dpi=300)
            plt.show()



In [None]:
# GFP: PLOTTING LOCAL MAXIMA FOUND PROGRAMATICALLY

import matplotlib.pyplot as plt
import mne
from scipy.signal import find_peaks
import numpy as np

def calculate_and_plot_GFP(epochs, event_id=None, title=None, filename=None, peak_picking=False,
                           peak_distance=50):
    """
    Plot the magnetometer Global Field Power (GFP) of the ERF.

    The GFP is the root-mean-square across magnetometer channels at each
    time sample of the averaged response.

    Parameters
    ----------
    epochs : mne.Epochs
        The epoched data.
    event_id : str, optional
        Selector passed to ``epochs[event_id]`` before averaging (typically an
        event name; MNE interprets a bare integer as an epoch index).
        If None (default), all epochs are averaged.
    title : str, optional
        Prefix for the axes title ("<title> - Global Field Power (Mag)").
    filename : str or Path, optional
        Currently unused: saving is disabled and the figure is only displayed.
        Kept so the signature matches the other plotting helpers.
    peak_picking : bool, optional
        If True, local maxima of the GFP found by ``scipy.signal.find_peaks``
        are marked with red dots. If False (default), only the curve is drawn.
    peak_distance : int, optional
        Minimum number of samples between marked peaks (the ``distance``
        argument of ``find_peaks``). Default 50.
    """
    erf = epochs[event_id].average() if event_id is not None else epochs.average()
    erf_mag = erf.copy().pick_types(meg='mag')

    # RMS over channels (axis 0 of Evoked.data) -> one value per time sample.
    gfp_mag = np.sqrt((erf_mag.data ** 2).mean(axis=0))

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(erf_mag.times, gfp_mag, label='GFP', color='black')

    if peak_picking:
        peaks_mag, _ = find_peaks(gfp_mag, distance=peak_distance)
        ax.plot(erf_mag.times[peaks_mag], gfp_mag[peaks_mag], 'o', color='red', label='Peaks')

    ax.set_xlabel("Time (s)")
    ax.set_ylabel("GFP")
    full_title = 'Global Field Power (Mag)'
    if title:
        full_title = f'{title} - {full_title}'
    ax.set_title(full_title, fontsize=16)
    ax.legend()

    fig.tight_layout()
    fig.subplots_adjust(top=0.85)  # extra headroom above the title
    plt.show()


In [None]:
# import matplotlib.pyplot as plt
# import mne
# import numpy as np
# from scipy.signal import find_peaks

# def calculate_and_plot_erf_gfp_with_topomaps(epochs, event_id=None, title=None, filename=None):
#     """
#     Calculate and plot ERF, GFP, and topomaps at GFP peaks.
#     """
#     # Separate epochs for mag and grad
#     epochs_mag = epochs.copy().pick_types(meg='mag')
#     epochs_grad = epochs.copy().pick_types(meg='grad')
    
#     for erf, ch_type in zip([epochs_mag.average(), epochs_grad.average()], ['mag', 'grad']):
#         # Compute GFP
#         gfp = np.sqrt((erf.data ** 2).mean(axis=0))
        
#         # Find peaks in the GFP
#         peaks, _ = find_peaks(gfp, distance=20)  # Adjust distance as per your needs
#         peak_times = erf.times[peaks]
        
#         # Plot ERF and GFP
#         fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)
#         erf.plot(axes=axes[0], spatial_colors=True, show=False)
#         axes[1].plot(erf.times, gfp, label='GFP')
#         axes[1].plot(peak_times, gfp[peaks], 'ro')  # Mark peaks with red dots
#         axes[1].set(xlabel='Time (s)', ylabel='GFP', title=f'Global Field Power with Peaks ({ch_type})')
#         plt.tight_layout()
        
#         if title:
#             fig.suptitle(f"{title} - {ch_type}", fontsize=16, y=1.02)
#             plt.tight_layout(rect=[0, 0, 1, 0.97])  # Leave space for the suptitle
        
#         #if filename:
#         #    plt.savefig(f"{filename}_ERF_GFP_{ch_type}.png", dpi=300)

#         # Plot topomaps at peaks
#         for idx, peak_time in enumerate(peak_times):
#             fig, ax_topo = plt.subplots(1, 1, figsize=(3, 3))
#             erf.plot_topomap(times=peak_time, size=3, show=False, axes=ax_topo, colorbar=False)
    
#             ax_topo.set_title(f"Topomap @ {peak_time*1000:.1f} ms ({ch_type})")
    
#             #if filename:
#             #    plt.savefig(f"{filename}_Topomap_{ch_type}_{idx}.png", dpi=300)
    
#         plt.show()
# import matplotlib.pyplot as plt
# import mne
# import numpy as np
# from scipy.signal import find_peaks

# def calculate_and_plot_erf_gfp_with_topomaps(epochs, event_id=None, title=None, filename=None):
#     """
#     Calculate and plot ERF, GFP, and topomaps at GFP peaks.
#     """
#     # Separate epochs for mag and grad
#     epochs_mag = epochs.copy().pick_types(meg='mag')
#     epochs_grad = epochs.copy().pick_types(meg='grad')
    
#     for erf, ch_type in zip([epochs_mag.average(), epochs_grad.average()], ['mag', 'grad']):
#         # Compute GFP
#         gfp = np.sqrt((erf.data ** 2).mean(axis=0))
        
#         # Find peaks in the GFP
#         peaks, _ = find_peaks(gfp, distance=20)  # Adjust distance as per your needs
#         peak_times = erf.times[peaks]
        
#         # Plot ERF and GFP
#         fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)
#         erf.plot(axes=axes[0], spatial_colors=True, show=False)
#         axes[1].plot(erf.times, gfp, label='GFP')
#         axes[1].plot(peak_times, gfp[peaks], 'ro')  # Mark peaks with red dots
#         axes[1].set(xlabel='Time (s)', ylabel='GFP', title=f'Global Field Power with Peaks ({ch_type})')
#         plt.tight_layout()
        
#         if title:
#             fig.suptitle(f"{title} - {ch_type}", fontsize=16, y=1.02)
#             plt.tight_layout(rect=[0, 0, 1, 0.97])  # Leave space for the suptitle
        
#         #if filename:
#         #    plt.savefig(f"{filename}_ERF_GFP_{ch_type}.png", dpi=300)

#         # Plot topomaps at peaks
#         for idx, peak_time in enumerate(peak_times):
#             fig, ax_topo = plt.subplots(1, 1, figsize=(3, 3))
#             erf.plot_topomap(times=peak_time, size=3, show=False, axes=ax_topo, colorbar=False)
    
#             ax_topo.set_title(f"Topomap @ {peak_time*1000:.1f} ms ({ch_type})")
    
#             #if filename:
#             #    plt.savefig(f"{filename}_Topomap_{ch_type}_{idx}.png", dpi=300)
    
#         plt.show()
# import matplotlib.pyplot as plt
# import mne
# import numpy as np
# from scipy.signal import find_peaks

# def calculate_and_plot_erf_gfp_with_topomaps(epochs, event_id=None, title=None, filename=None):
#     """
#     Calculate and plot ERF, GFP, and topomaps at GFP peaks.
#     """
#     # Separate epochs for mag and grad
#     epochs_mag = epochs.copy().pick_types(meg='mag')
#     epochs_grad = epochs.copy().pick_types(meg='grad')
    
#     for erf, ch_type in zip([epochs_mag.average(), epochs_grad.average()], ['mag', 'grad']):
#         # Compute GFP
#         gfp = np.sqrt((erf.data ** 2).mean(axis=0))
        
#         # Find peaks in the GFP
#         peaks, _ = find_peaks(gfp, distance=20)  # Adjust distance as per your needs
#         peak_times = erf.times[peaks]
        
#         # Plot ERF and GFP
#         fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)
#         erf.plot(axes=axes[0], spatial_colors=True, show=False)
#         axes[1].plot(erf.times, gfp, label='GFP')
#         axes[1].plot(peak_times, gfp[peaks], 'ro')  # Mark peaks with red dots
#         axes[1].set(xlabel='Time (s)', ylabel='GFP', title=f'Global Field Power with Peaks ({ch_type})')
#         plt.tight_layout()
        
#         if title:
#             fig.suptitle(f"{title} - {ch_type}", fontsize=16, y=1.02)
#             plt.tight_layout(rect=[0, 0, 1, 0.97])  # Leave space for the suptitle
        
#         #if filename:
#         #    plt.savefig(f"{filename}_ERF_GFP_{ch_type}.png", dpi=300)

#         # Plot topomaps at peaks
#         for idx, peak_time in enumerate(peak_times):
#             fig, ax_topo = plt.subplots(1, 1, figsize=(3, 3))
#             erf.plot_topomap(times=peak_time, size=3, show=False, axes=ax_topo, colorbar=False)
    
#             ax_topo.set_title(f"Topomap @ {peak_time*1000:.1f} ms ({ch_type})")
    
#             #if filename:
#             #    plt.savefig(f"{filename}_Topomap_{ch_type}_{idx}.png", dpi=300)
    
#         plt.show()
# import matplotlib.pyplot as plt
# import mne
# import numpy as np
# from scipy.signal import find_peaks

# def calculate_and_plot_erf_gfp_with_topomaps(epochs, event_id=None, title=None, filename=None):
#     """
#     Calculate and plot ERF, GFP, and topomaps at GFP peaks.
#     """
#     # Separate epochs for mag and grad
#     epochs_mag = epochs.copy().pick_types(meg='mag')
#     epochs_grad = epochs.copy().pick_types(meg='grad')
    
#     for erf, ch_type in zip([epochs_mag.average(), epochs_grad.average()], ['mag', 'grad']):
#         # Compute GFP
#         gfp = np.sqrt((erf.data ** 2).mean(axis=0))
        
#         # Find peaks in the GFP
#         peaks, _ = find_peaks(gfp, distance=20)  # Adjust distance as per your needs
#         peak_times = erf.times[peaks]
        
#         # Plot ERF and GFP
#         fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)
#         erf.plot(axes=axes[0], spatial_colors=True, show=False)
#         axes[1].plot(erf.times, gfp, label='GFP')
#         axes[1].plot(peak_times, gfp[peaks], 'ro')  # Mark peaks with red dots
#         axes[1].set(xlabel='Time (s)', ylabel='GFP', title=f'Global Field Power with Peaks ({ch_type})')
#         plt.tight_layout()
        
#         if title:
#             fig.suptitle(f"{title} - {ch_type}", fontsize=16, y=1.02)
#             plt.tight_layout(rect=[0, 0, 1, 0.97])  # Leave space for the suptitle
        
#         #if filename:
#         #    plt.savefig(f"{filename}_ERF_GFP_{ch_type}.png", dpi=300)

#         # Plot topomaps at peaks
#         for idx, peak_time in enumerate(peak_times):
#             fig, ax_topo = plt.subplots(1, 1, figsize=(3, 3))
#             erf.plot_topomap(times=peak_time, size=3, show=False, axes=ax_topo, colorbar=False)
    
#             ax_topo.set_title(f"Topomap @ {peak_time*1000:.1f} ms ({ch_type})")
    
#             #if filename:
#             #    plt.savefig(f"{filename}_Topomap_{ch_type}_{idx}.png", dpi=300)
    
#         plt.show()
# import matplotlib.pyplot as plt
# import mne
# import numpy as np
# from scipy.signal import find_peaks

# def calculate_and_plot_erf_gfp_with_topomaps(epochs, event_id=None, title=None, filename=None):
#     """
#     Calculate and plot ERF, GFP, and topomaps at GFP peaks.
#     """
#     # Separate epochs for mag and grad
#     epochs_mag = epochs.copy().pick_types(meg='mag')
#     epochs_grad = epochs.copy().pick_types(meg='grad')
    
#     for erf, ch_type in zip([epochs_mag.average(), epochs_grad.average()], ['mag', 'grad']):
#         # Compute GFP
#         gfp = np.sqrt((erf.data ** 2).mean(axis=0))
        
#         # Find peaks in the GFP
#         peaks, _ = find_peaks(gfp, distance=20)  # Adjust distance as per your needs
#         peak_times = erf.times[peaks]
        
#         # Plot ERF and GFP
#         fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)
#         erf.plot(axes=axes[0], spatial_colors=True, show=False)
#         axes[1].plot(erf.times, gfp, label='GFP')
#         axes[1].plot(peak_times, gfp[peaks], 'ro')  # Mark peaks with red dots
#         axes[1].set(xlabel='Time (s)', ylabel='GFP', title=f'Global Field Power with Peaks ({ch_type})')
#         plt.tight_layout()
        
#         if title:
#             fig.suptitle(f"{title} - {ch_type}", fontsize=16, y=1.02)
#             plt.tight_layout(rect=[0, 0, 1, 0.97])  # Leave space for the suptitle
        
#         #if filename:
#         #    plt.savefig(f"{filename}_ERF_GFP_{ch_type}.png", dpi=300)

#         # Plot topomaps at peaks
#         for idx, peak_time in enumerate(peak_times):
#             fig, ax_topo = plt.subplots(1, 1, figsize=(3, 3))
#             erf.plot_topomap(times=peak_time, size=3, show=False, axes=ax_topo, colorbar=False)
    
#             ax_topo.set_title(f"Topomap @ {peak_time*1000:.1f} ms ({ch_type})")
    
#             #if filename:
#             #    plt.savefig(f"{filename}_Topomap_{ch_type}_{idx}.png", dpi=300)
    
#         plt.show()
# import matplotlib.pyplot as plt
# import mne
# import numpy as np
# from scipy.signal import find_peaks

# def calculate_and_plot_erf_gfp_with_topomaps(epochs, event_id=None, title=None, filename=None):
#     """
#     Calculate and plot ERF, GFP, and topomaps at GFP peaks.
#     """
#     # Separate epochs for mag and grad
#     epochs_mag = epochs.copy().pick_types(meg='mag')
#     epochs_grad = epochs.copy().pick_types(meg='grad')
    
#     for erf, ch_type in zip([epochs_mag.average(), epochs_grad.average()], ['mag', 'grad']):
#         # Compute GFP
#         gfp = np.sqrt((erf.data ** 2).mean(axis=0))
        
#         # Find peaks in the GFP
#         peaks, _ = find_peaks(gfp, distance=20)  # Adjust distance as per your needs
#         peak_times = erf.times[peaks]
        
#         # Plot ERF and GFP
#         fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)
#         erf.plot(axes=axes[0], spatial_colors=True, show=False)
#         axes[1].plot(erf.times, gfp, label='GFP')
#         axes[1].plot(peak_times, gfp[peaks], 'ro')  # Mark peaks with red dots
#         axes[1].set(xlabel='Time (s)', ylabel='GFP', title=f'Global Field Power with Peaks ({ch_type})')
#         plt.tight_layout()
        
#         if title:
#             fig.suptitle(f"{title} - {ch_type}", fontsize=16, y=1.02)
#             plt.tight_layout(rect=[0, 0, 1, 0.97])  # Leave space for the suptitle
        
#         #if filename:
#         #    plt.savefig(f"{filename}_ERF_GFP_{ch_type}.png", dpi=300)

#         # Plot topomaps at peaks
#         for idx, peak_time in enumerate(peak_times):
#             fig, ax_topo = plt.subplots(1, 1, figsize=(3, 3))
#             erf.plot_topomap(times=peak_time, size=3, show=False, axes=ax_topo, colorbar=False)
    
#             ax_topo.set_title(f"Topomap @ {peak_time*1000:.1f} ms ({ch_type})")
    
#             #if filename:
#             #    plt.savefig(f"{filename}_Topomap_{ch_type}_{idx}.png", dpi=300)
    
#         plt.show()
# import matplotlib.pyplot as plt
# import mne
# import numpy as np
# from scipy.signal import find_peaks

# def calculate_and_plot_erf_gfp_with_topomaps(epochs, event_id=None, title=None, filename=None):
#     """
#     Calculate and plot ERF, GFP, and topomaps at GFP peaks.
#     """
#     # Separate epochs for mag and grad
#     epochs_mag = epochs.copy().pick_types(meg='mag')
#     epochs_grad = epochs.copy().pick_types(meg='grad')
    
#     for erf, ch_type in zip([epochs_mag.average(), epochs_grad.average()], ['mag', 'grad']):
#         # Compute GFP
#         gfp = np.sqrt((erf.data ** 2).mean(axis=0))
        
#         # Find peaks in the GFP
#         peaks, _ = find_peaks(gfp, distance=20)  # Adjust distance as per your needs
#         peak_times = erf.times[peaks]
        
#         # Plot ERF and GFP
#         fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)
#         erf.plot(axes=axes[0], spatial_colors=True, show=False)
#         axes[1].plot(erf.times, gfp, label='GFP')
#         axes[1].plot(peak_times, gfp[peaks], 'ro')  # Mark peaks with red dots
#         axes[1].set(xlabel='Time (s)', ylabel='GFP', title=f'Global Field Power with Peaks ({ch_type})')
#         plt.tight_layout()
        
#         if title:
#             fig.suptitle(f"{title} - {ch_type}", fontsize=16, y=1.02)
#             plt.tight_layout(rect=[0, 0, 1, 0.97])  # Leave space for the suptitle
        
#         #if filename:
#         #    plt.savefig(f"{filename}_ERF_GFP_{ch_type}.png", dpi=300)

#         # Plot topomaps at peaks
#         for idx, peak_time in enumerate(peak_times):
#             fig, ax_topo = plt.subplots(1, 1, figsize=(3, 3))
#             erf.plot_topomap(times=peak_time, size=3, show=False, axes=ax_topo, colorbar=False)
    
#             ax_topo.set_title(f"Topomap @ {peak_time*1000:.1f} ms ({ch_type})")
    
#             #if filename:
#             #    plt.savefig(f"{filename}_Topomap_{ch_type}_{idx}.png", dpi=300)
    
#         plt.show()
# import matplotlib.pyplot as plt
# import mne
# import numpy as np
# from scipy.signal import find_peaks

# def calculate_and_plot_erf_gfp_with_topomaps(epochs, event_id=None, title=None, filename=None):
#     """
#     Calculate and plot ERF, GFP, and topomaps at GFP peaks.
#     """
#     # Separate epochs for mag and grad
#     epochs_mag = epochs.copy().pick_types(meg='mag')
#     epochs_grad = epochs.copy().pick_types(meg='grad')
    
#     for erf, ch_type in zip([epochs_mag.average(), epochs_grad.average()], ['mag', 'grad']):
#         # Compute GFP
#         gfp = np.sqrt((erf.data ** 2).mean(axis=0))
        
#         # Find peaks in the GFP
#         peaks, _ = find_peaks(gfp, distance=20)  # Adjust distance as per your needs
#         peak_times = erf.times[peaks]
        
#         # Plot ERF and GFP
#         fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)
#         erf.plot(axes=axes[0], spatial_colors=True, show=False)
#         axes[1].plot(erf.times, gfp, label='GFP')
#         axes[1].plot(peak_times, gfp[peaks], 'ro')  # Mark peaks with red dots
#         axes[1].set(xlabel='Time (s)', ylabel='GFP', title=f'Global Field Power with Peaks ({ch_type})')
#         plt.tight_layout()
        
#         if title:
#             fig.suptitle(f"{title} - {ch_type}", fontsize=16, y=1.02)
#             plt.tight_layout(rect=[0, 0, 1, 0.97])  # Leave space for the suptitle
        
#         #if filename:
#         #    plt.savefig(f"{filename}_ERF_GFP_{ch_type}.png", dpi=300)

#         # Plot topomaps at peaks
#         for idx, peak_time in enumerate(peak_times):
#             fig, ax_topo = plt.subplots(1, 1, figsize=(3, 3))
#             erf.plot_topomap(times=peak_time, size=3, show=False, axes=ax_topo, colorbar=False)
    
#             ax_topo.set_title(f"Topomap @ {peak_time*1000:.1f} ms ({ch_type})")
    
#             #if filename:
#             #    plt.savefig(f"{filename}_Topomap_{ch_type}_{idx}.png", dpi=300)
    
#         plt.show()
import matplotlib.pyplot as plt
import mne
from scipy.signal import find_peaks
import numpy as np

def calculate_and_plot_erf_gfp_with_topomaps(epochs, ch_type='mag', event_id=None, title=None,
                                             filename=None, peak_distance=20, tick_step=0.1):
    """
    Plot the ERF and GFP for one MEG sensor type, then a topomap at each GFP peak.

    The first figure stacks the butterfly ERF (top) over the GFP curve with
    its peaks marked in red (bottom); the shared time axis is labelled in ms.
    Each GFP peak then gets its own small topomap figure.

    Parameters
    ----------
    epochs : mne.Epochs
        The epoched data.
    ch_type : str, optional
        MEG sensor type passed to ``pick_types(meg=...)``, e.g. 'mag' (default)
        or 'grad'.
    event_id : str, optional
        Selector passed to ``epochs[event_id]`` before averaging (typically an
        event name; MNE interprets a bare integer as an epoch index).
        If None (default), all epochs are averaged.
    title : str, optional
        Suptitle for the ERF/GFP figure.
    filename : str or Path, optional
        Currently unused: saving is disabled and figures are only displayed.
    peak_distance : int, optional
        Minimum number of samples between GFP peaks (the ``distance``
        argument of ``scipy.signal.find_peaks``). Default 20.
    tick_step : float, optional
        Spacing of the x-axis ticks in seconds. Default 0.1 (100 ms).

    Raises
    ------
    ValueError
        If ``tick_step`` is not positive.
    """
    if tick_step <= 0:
        raise ValueError(f"tick_step must be positive, got {tick_step}")

    selected = epochs[event_id] if event_id is not None else epochs
    erf = selected.copy().pick_types(meg=ch_type).average()

    # GFP = RMS across channels at each time sample.
    gfp = np.sqrt((erf.data ** 2).mean(axis=0))
    peaks, _ = find_peaks(gfp, distance=peak_distance)
    peak_times = erf.times[peaks]

    fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)
    erf.plot(axes=axes[0], spatial_colors=True, show=False)
    axes[1].plot(erf.times, gfp, label='GFP')
    axes[1].plot(peak_times, gfp[peaks], 'ro')  # mark peaks with red dots
    # Tick labels below are in milliseconds, so the axis label must say ms.
    axes[1].set(xlabel='Time (ms)', ylabel='GFP', title='Global Field Power with Peaks')

    # Ticks on integer multiples of tick_step inside the data range, so 0 ms
    # is a tick whenever it is in range. The 1e-9 slack absorbs float error
    # (e.g. tmin=-0.2 -> -200 ms tick); staying inside the range matters
    # because set_xticks expands the view limits to show every tick.
    tmin, tmax = np.min(erf.times), np.max(erf.times)
    multiples = np.arange(np.ceil(tmin / tick_step - 1e-9),
                          np.floor(tmax / tick_step + 1e-9) + 1)
    x_ticks = multiples * tick_step
    axes[1].set_xticks(x_ticks)
    # int(round(...)) avoids labels like "-0" from tiny negative float noise.
    axes[1].set_xticklabels([f"{int(round(tick * 1000))}" for tick in x_ticks])

    fig.tight_layout()
    if title:
        fig.suptitle(title, fontsize=16, y=1.02)
        fig.tight_layout(rect=[0, 0, 1, 0.97])  # leave space for the suptitle

    # Show the ERF/GFP figure on its own so it is displayed even with no peaks.
    plt.show()

    for peak_time in peak_times:
        _, ax_topo = plt.subplots(1, 1, figsize=(3, 3))
        erf.plot_topomap(times=peak_time, size=3,
                         show=False, axes=ax_topo, colorbar=False)
        ax_topo.set_title(f"Topomap @ {peak_time*1000:.1f} ms")
        plt.show()

# Usage example:
# calculate_and_plot_erf_gfp_with_topomaps(epochs, ch_type='mag', event_id='YourEventID', title='YourTitle')




In [None]:
# ---- FUNC 1: GENERATING ERF PLOTS
# For every subject and block: preprocess the recording into epochs (image
# triggers only, button presses excluded) and plot/save the magnetometer and
# gradiometer ERFs with calculate_and_plot_erf.
from pathlib import Path
import mne
import numpy as np
import json
import sys

# Static location of this notebook:
# /work/<name>/notebooks_PHB/MEG_portfolio/sanity_checks/
notebook_path = Path("/work/PernilleHøjlundBrams#8577/notebooks_PHB/MEG_portfolio/sanity_checks")

# `utils` (which provides preprocess_data_sensorspace) lives in the parent
# directory. Append it only once so re-running the cell doesn't grow sys.path.
utils_dir = str(notebook_path.parents[0])
if utils_dir not in sys.path:
    sys.path.append(utils_dir)
from utils import preprocess_data_sensorspace

# Input data and output locations
MEG_data_path = Path("/work/834761")
ICA_path = Path("/work/study_group_8/ICA")  # Adjusting according to your folder structure
#plot_path = notebook_path / "plots"
#plot_path = notebook_path / "plots_button_removed"
plot_path = notebook_path / "plots_button_removed_topo"
plot_path.mkdir(parents=True, exist_ok=True)

# Per-subject settings (reject thresholds, per-block bad channels and noise
# components) stored as JSON.
with open('/work/PernilleHøjlundBrams#8577/notebooks_PHB/MEG_portfolio/session_info.txt',
          'r', encoding='utf-8') as f:
    session_info = json.load(f)

subjects = ["0108", "0109", "0110", "0111", "0112", "0113", "0114", "0115"]
recording_names = [
    '001.self_block1',  '002.other_block1', '003.self_block2',
    '004.other_block2', '005.self_block3',  '006.other_block3'
]

for subject in subjects:
    subject_info = session_info[subject]
    reject = subject_info["reject"]

    subject_path = MEG_data_path / subject
    # First match, same as list(...)[0], but with a readable error if missing.
    subject_meg_path = next(subject_path.glob("*_000000"), None)
    if subject_meg_path is None:
        raise FileNotFoundError(f"No '*_000000' folder found in {subject_path}")

    # Looping through blocks
    for recording_name in recording_names:

        subject_session_info = subject_info[recording_name]
        files_path = subject_meg_path / "MEG" / recording_name / "files"
        fif_file_path = next(files_path.glob("*.fif"), None)
        if fif_file_path is None:
            raise FileNotFoundError(f"No .fif file found in {files_path}")

        ICA_path_sub = ICA_path / subject / f"{recording_name}-ica.fif"

        # Keep only image-onset triggers: 11/12 in self blocks, 21/22 in other
        # blocks. Button-press (23) and response (202) triggers are excluded.
        if 'self' in recording_name:
            event_id = {
                "img/self/positive": 11,
                "img/self/negative": 12}
        elif 'other' in recording_name:
            event_id = {
                "img/assigned/positive": 21,
                "img/assigned/negative": 22}
        else:
            # Otherwise event_id would silently carry over from the previous
            # block (or be undefined on the first one).
            raise ValueError(f"Cannot infer block type (self/other) from {recording_name!r}")

        epochs = preprocess_data_sensorspace(
            fif_file_path,
            subject_session_info["bad_channels"],
            reject,
            ICA_path_sub,
            subject_session_info["noise_components"],
            event_ids = event_id
        )

        print(f"### \n ### EPOCH EVENT_ID after drop are: {epochs.event_id} ### \n ###")

        # Formulate a meaningful title for the plot
        title = f"ERF for Subject: {subject}, Session: {recording_name}"

        # Base filename; calculate_and_plot_erf appends _mag / _grad to the stem.
        filename = plot_path / f"{subject}_{recording_name}_ERF_plot.png"

        calculate_and_plot_erf(epochs, event_id=None, title=title, filename=filename)
    

In [None]:
# ---- FUNC 2: GENERATING TOPO PLOTS
from pathlib import Path
import mne
import numpy as np
import json
import sys  # Ensure sys is imported

# Static path to the directory holding this notebook
# (.../MEG_portfolio/sanity_checks).
notebook_path = Path("/work/PernilleHøjlundBrams#8577/notebooks_PHB/MEG_portfolio/sanity_checks")

# Make the parent directory (MEG_portfolio/) importable so `utils` resolves.
# Guarded so that re-running this cell does not pile up duplicate entries.
utils_dir = str(notebook_path.parent)
if utils_dir not in sys.path:
    sys.path.append(utils_dir)
from utils import preprocess_data_sensorspace

# Input data locations
MEG_data_path = Path("/work/834761")
ICA_path = Path("/work/study_group_8/ICA")  # Adjusting according to your folder structure

# Output folder for the topomap figures (earlier runs of this analysis wrote
# to "plots" and "plots_button_removed" instead).
plot_path = notebook_path / "plots_button_removed_topo"
# exist_ok avoids the check-then-create race of an exists()/mkdir() pair.
plot_path.mkdir(parents=True, exist_ok=True)

# Per-subject session metadata stored as JSON. As used by the loop below, it is
# keyed by subject ID, then by recording name ("bad_channels",
# "noise_components"), with a subject-level "reject" entry.
# session_info.txt lives in MEG_portfolio/, i.e. the notebook's parent folder.
with open(notebook_path.parent / "session_info.txt", encoding="utf-8") as f:
    session_info = json.load(f)

subjects = ["0108", "0109", "0110", "0111", "0112", "0113", "0114", "0115"]
recording_names = [
    '001.self_block1',  '002.other_block1', '003.self_block2',
    '004.other_block2', '005.self_block3',  '006.other_block3'
]

def _first_glob_match(directory, pattern):
    """Return the first entry of ``directory.glob(pattern)``.

    Same selection as ``list(directory.glob(pattern))[0]``. When nothing
    matches, it raises a FileNotFoundError naming the directory and pattern
    instead of an opaque IndexError.
    """
    match = next(directory.glob(pattern), None)
    if match is None:
        raise FileNotFoundError(f"No match for '{pattern}' in {directory}")
    return match


def _stimulus_event_id(recording_name):
    """Return the image-onset triggers to epoch for ``recording_name``.

    Self-blocks keep triggers 11/12 and other-blocks keep 21/22. The
    button-press (23) and response (202) triggers are deliberately excluded.
    Raises ValueError when the name contains neither 'self' nor 'other', so a
    mis-named recording cannot silently reuse a previous block's event_id.
    """
    if 'self' in recording_name:
        return {
            "img/self/positive": 11,
            "img/self/negative": 12,
        }
    if 'other' in recording_name:
        return {
            "img/assigned/positive": 21,
            "img/assigned/negative": 22,
        }
    raise ValueError(
        f"Cannot infer block type ('self'/'other') from {recording_name!r}"
    )


for subject in subjects:
    subject_info = session_info[subject]
    reject = subject_info["reject"]

    subject_path = MEG_data_path / subject
    subject_meg_path = _first_glob_match(subject_path, "*_000000")

    # Loop through the experimental blocks of this subject
    for recording_name in recording_names:

        subject_session_info = subject_info[recording_name]
        fif_file_path = _first_glob_match(
            subject_meg_path / "MEG" / recording_name / "files", "*.fif"
        )

        ICA_path_sub = ICA_path / subject / f"{recording_name}-ica.fif"

        # Only the image triggers of this block type (no button presses)
        event_id = _stimulus_event_id(recording_name)

        epochs = preprocess_data_sensorspace(
            fif_file_path,
            subject_session_info["bad_channels"],
            reject,
            ICA_path_sub,
            subject_session_info["noise_components"],
            event_ids=event_id,
        )

        print(f"### \n ### EPOCH EVENT_ID after drop are: {epochs.event_id} ### \n ###")

        # Title shown on top of the figure
        title = f"GFP for Subject: {subject}, Session: {recording_name}"

        # NOTE(review): the "_ERF_plot" suffix is kept so existing output file
        # names do not change, although this cell produces the GFP/topomap figure.
        filename = plot_path / f"{subject}_{recording_name}_ERF_plot.png"

        # event_id=None presumably averages over all kept epochs, as it does
        # in calculate_and_plot_erf. TODO confirm for this function.
        calculate_and_plot_erf_gfp_with_topomaps(
            epochs, event_id=None, title=title, filename=filename
        )  # peak_picking=True