# Plotting The Regional Interpolation Map
This notebook demonstrates how to recreate the figure that shows how EEG channels are grouped into regions for regional channel interpolation.

In [None]:
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.patches as mpatches
import mne as mne
import config.interpolation_maps

### Loading The Raw EEG Data
First, we load the sample EEG recording for the eyes-closed condition.

In [None]:
# Location of the example eyes-closed (EC) recording, relative to the notebook's working directory
file_path_eyes_closed = "./data/example_eeg/example_subj_EC_raw.fif.gz"

# Read the recording with MNE: preload=True loads the samples into memory,
# verbose=False suppresses MNE's informational output
raw_eeg_eyes_closed = mne.io.read_raw(
    file_path_eyes_closed,
    preload=True,
    verbose=False,
)

### Defining The Helper Functions
Next, we define a few helper functions: one projects the 3D channel positions onto a 2D plane, and the others compute the region centroids and draw the map.

In [None]:
def get_2d_coordinates(ch_pos, radius=1.1, projection_scale=1):
    """Project 3D channel positions onto a 2D plane for plotting.

    Each position is first normalized onto the unit sphere. It is then
    projected perspectively from the point ``(0, 0, -radius)`` onto the plane
    ``z = 0``::

        x_2d = projection_scale * radius * x / (z + radius)
        y_2d = projection_scale * radius * y / (z + radius)

    This is a perspective projection (stereographic when ``radius == 1``), not
    an azimuthal equidistant one. Assumes the positions are expressed in a
    frame whose +z axis points to the top of the head. TODO confirm for the
    montage in use.

    Parameters
    ----------
    ch_pos : dict
        Mapping of channel name -> 3D position ``(x, y, z)``. Row ``i`` of the
        result corresponds to the ``i``-th key in dict iteration order.
    radius : float
        Distance of the projection point below the sphere center. With
        ``radius > 1`` the denominator ``z + radius`` is always positive for
        unit-sphere points, so every channel maps to a finite location.
    projection_scale : float
        Extra uniform scale factor applied to the projected coordinates.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_channels, 2)`` with the projected ``(x, y)``
        coordinates. The array is empty, with shape ``(0, 2)``, when
        ``ch_pos`` is empty.
    """
    if not ch_pos:
        # np.linalg.norm(..., axis=1) would fail on a 1-D empty array.
        return np.empty((0, 2))
    cord_3d = np.asarray(list(ch_pos.values()), dtype=float)
    points_norm = cord_3d / np.linalg.norm(cord_3d, axis=1, keepdims=True)
    # Per-point perspective factor r / (z + r), applied to both x and y at once.
    scale = projection_scale * radius / (points_norm[:, 2] + radius)
    return points_norm[:, :2] * scale[:, np.newaxis]

def compute_region_centroids(plot_cord):
    """Compute the 2D centroid (mean position) of each region.

    Parameters
    ----------
    plot_cord : dict
        Mapping of region name -> array-like of shape ``(n_channels, 2)``
        holding the projected coordinates of that region's channels.

    Returns
    -------
    dict
        Mapping of region name -> centroid as a length-2 array
        ``(mean x, mean y)``. Regions without any channels are omitted. Their
        mean would be NaN (with a RuntimeWarning), and a NaN centroid cannot
        be drawn or labelled.
    """
    centroids = {}
    for reg, coords in plot_cord.items():
        coords = np.asarray(coords, dtype=float)
        if coords.size == 0:
            continue
        centroids[reg] = coords.mean(axis=0)
    return centroids

def plot_connections(ax, plot_cord, region_cord, region_colors):
    """Draw every channel and a faint line linking it to its region centroid.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    plot_cord : dict
        Region name -> ``(n_channels, 2)`` array of projected channel coordinates.
    region_cord : dict
        Region name -> ``(x, y)`` centroid. Only regions with at least one
        channel in ``plot_cord`` need an entry.
    region_colors : dict
        Region name -> matplotlib color.
    """
    for reg, cords in plot_cord.items():
        cords = np.asarray(cords, dtype=float).reshape(-1, 2)
        if len(cords) == 0:
            # Nothing to draw, and no centroid is required for this region.
            continue
        color = region_colors[reg]
        cx, cy = region_cord[reg][0], region_cord[reg][1]
        # Separate line artists so the translucent lines still darken where
        # they overlap near the centroid.
        for x, y in cords:
            ax.plot([cx, x], [cy, y], color=color, zorder=1, alpha=0.2)
        # One scatter artist per region instead of one per channel. Use
        # `color=` rather than `c=` so an RGB(A) tuple is never interpreted as
        # per-point values when the region has 3 or 4 channels.
        ax.scatter(cords[:, 0], cords[:, 1], color=color, marker='o', alpha=0.9, zorder=2)

def plot_region_centroids(ax, region_cord, region_colors):
    """Mark each region centroid with a large disc labelled with the region name.

    ``region_cord`` maps region name -> ``(x, y)`` centroid and
    ``region_colors`` maps region name -> matplotlib color. Discs are drawn at
    zorder 3, and the bold white labels are drawn on top of them at zorder 4.
    """
    marker_style = dict(s=600, zorder=3, alpha=0.75)
    label_style = dict(fontsize=11, fontweight='bold', color='white', zorder=4, ha='center', va='center')
    for name in region_cord:
        x, y = region_cord[name][0], region_cord[name][1]
        ax.scatter(x, y, c=region_colors[name], **marker_style)
        ax.text(x, y, name, **label_style)

def plot_regional_interpolation_map(region_map, region_colors, ch_names=None, cord_2d=None, title=None):
    """Plot the channels of each region, linked to their region centroid.

    Parameters
    ----------
    region_map : dict
        Region name -> iterable of channel names belonging to that region.
        Channels not present in ``ch_names`` are silently skipped.
    region_colors : dict
        Region name -> matplotlib color. Must have an entry for every region in
        ``region_map``.
    ch_names : sequence of str, optional
        Channel names in the same order as the rows of ``cord_2d``. Defaults
        to the notebook-level ``ch_names`` variable, for backward
        compatibility. Passing it explicitly avoids hidden notebook state.
    cord_2d : array-like, optional
        ``(n_channels, 2)`` projected channel coordinates. Defaults to the
        notebook-level ``cord_2d`` variable.
    title : str, optional
        Axes title. Defaults to one stating the number of regions.

    Raises
    ------
    ValueError
        If ``ch_names``/``cord_2d`` are neither passed nor defined at notebook
        level, or if their lengths differ.
    """
    # Fall back to the notebook globals that earlier versions relied on implicitly.
    if ch_names is None:
        ch_names = globals().get("ch_names")
    if cord_2d is None:
        cord_2d = globals().get("cord_2d")
    if ch_names is None or cord_2d is None:
        raise ValueError("ch_names and cord_2d must be passed explicitly or defined at notebook level.")
    cord_2d = np.asarray(cord_2d)
    if len(ch_names) != len(cord_2d):
        raise ValueError("ch_names and cord_2d must have the same length.")

    # Name -> row lookup. setdefault keeps the first occurrence, matching list.index.
    row_of = {}
    for i, name in enumerate(ch_names):
        row_of.setdefault(name, i)
    plot_cord = {
        reg: cord_2d[[row_of[ch] for ch in channels if ch in row_of]]
        for reg, channels in region_map.items()
    }
    region_cord = compute_region_centroids(plot_cord)

    fig, ax = plt.subplots(figsize=(6, 5))
    plot_connections(ax, plot_cord, region_cord, region_colors)
    plot_region_centroids(ax, region_cord, region_colors)

    # Equal axis scaling so the projected scalp layout is not stretched.
    ax.set_aspect('equal')
    ax.set_title(title if title is not None else f"Regional interpolation map ({len(region_map)} regions)")

    # Legend entries come from the actual region names instead of assuming 'R1'..'Rn'.
    legend_handles = [mpatches.Patch(color=region_colors[reg], label=reg) for reg in region_map]
    ax.legend(handles=legend_handles, bbox_to_anchor=(1.1, 1), title="Regions")

    plt.show()


### Interpolation Maps With 5, 12, and 20 Regions
Here we plot the regional interpolation map for each of the custom region maps (5, 12, and 20 regions) defined in [`interpolation_maps.py`](https://github.com/Arsu-Lab/Different-Algorithms-Uncover-Different-Patterns-BrainAge-Prediction/blob/main/config/interpolation_maps.py).

In [None]:
# Load the montage (channel positions) from the raw EEG data
montage = raw_eeg_eyes_closed.get_montage()
if montage is None:
    # get_montage() returns None when the recording carries no channel positions;
    # fail with a clear message instead of an AttributeError on the next line.
    raise ValueError("The raw EEG recording has no montage (channel positions); "
                     "set one with raw.set_montage(...) before plotting.")

# Load the 3D position coordinates from the montage, keyed by channel name
ch_pos = montage.get_positions()['ch_pos']
ch_names = list(ch_pos.keys())

# Project the 3D coordinates onto a 2D plane
cord_2d = get_2d_coordinates(ch_pos)

# Define the colors to use for each region
colors = config.interpolation_maps.chan_map_colors

# Load the 5-region interpolation map
chan_map = config.interpolation_maps.chan_map_R5

# Plot the regional interpolation map
plot_regional_interpolation_map(chan_map, colors)

In [None]:
# Switch to the 12-region channel grouping and draw it with the same region colors
chan_map = config.interpolation_maps.chan_map_R12
plot_regional_interpolation_map(region_map=chan_map, region_colors=colors)

In [None]:
# Switch to the 20-region channel grouping and draw it with the same region colors
chan_map = config.interpolation_maps.chan_map_R20
plot_regional_interpolation_map(region_map=chan_map, region_colors=colors)