### Plot flatmaps of distinct features

In [1]:
import sys

# Make the project root importable (presumably where the ``config`` and
# ``GDa`` packages imported below live). Inserting at index 1 keeps whatever
# is at sys.path[0] ahead of it.
sys.path.insert(1, "/home/vinicius/storage1/projects/GrayData-Analysis")

In [2]:
import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from tqdm import tqdm

from config import sessions
from GDa.flatmap.flatmap import flatmap
from GDa.util import create_stages_time_grid

In [3]:
# Folder in which the figures are saved (read as a global by the plotting
# functions below)
results = "/home/vinicius/storage1/projects/GrayData-Analysis/figures/features_flatmaps"

In [4]:
# Connectivity metric; it is part of the network-feature file names built by
# get_file_name and of the saved figure names.
metric = "coh"

#### Helper functions to load data

In [5]:
def return_file_path(_ROOT, _FILE_NAME, s_id):
    """Return the path of ``_FILE_NAME`` inside the results folder of session ``s_id``."""
    session_dir = f"Results/lucy/{s_id}/session01"
    return os.path.join(_ROOT, session_dir, _FILE_NAME)


def average_stages(data, stats):
    """
    Summarise ``data`` over the observations of each task stage.

    A boolean mask per task stage is built by ``create_stages_time_grid``
    from the event times stored in ``data.attrs``. The ``trials`` and
    ``times`` dimensions are stacked into a single ``observations``
    dimension, and the observations selected by each mask are reduced with
    the statistic ``stats`` (NaNs skipped).

    Parameters
    ----------
    data : xr.DataArray
        Array with ``trials`` and ``times`` dimensions whose ``attrs`` hold
        ``t_cue_on``, ``t_cue_off``, ``t_match_on`` and ``fsample``.
    stats : {"mean", "95p", "cv"}
        Reduction applied per stage: the mean, the 95th percentile, or the
        coefficient of variation (std / mean).

    Returns
    -------
    xr.DataArray
        The per-stage summaries concatenated along a new ``times`` dimension
        (one entry per stage, in the order of the masks), carrying the
        original ``data.attrs``.

    Raises
    ------
    ValueError
        If ``stats`` is not one of "mean", "95p" or "cv".
    """
    reducers = {
        "mean": lambda x: x.mean("observations", skipna=True),
        "95p": lambda x: x.quantile(0.95, "observations", skipna=True),
        "cv": lambda x: (
            x.std("observations", skipna=True) / x.mean("observations", skipna=True)
        ),
    }
    # Fail early with a clear message instead of concatenating an empty list
    if stats not in reducers:
        raise ValueError(f"stats must be 'mean', '95p' or 'cv', got {stats!r}")
    reduce = reducers[stats]

    attrs = data.attrs
    # Stage masks over the flattened (trials, times) grid. The cue-onset time
    # passed in is shifted by -0.2 (units of attrs["t_cue_on"]).
    # NOTE(review): purpose of the -0.2 shift is not visible here; confirm.
    mask = create_stages_time_grid(
        attrs["t_cue_on"] - 0.2,
        attrs["t_cue_off"],
        attrs["t_match_on"],
        attrs["fsample"],
        data.times.data,
        data.sizes["trials"],
        early_delay=0.3,
        align_to="cue",
        flatten=True,
    )

    # NOTE(review): assumes the flattened masks are trial-major, matching the
    # ("trials", "times") stacking order below; confirm.
    stacked = data.stack(observations=("trials", "times"))

    out = [
        reduce(stacked.isel(observations=xr.DataArray(m, dims=("observations",))))
        for m in mask.values()
    ]
    out = xr.concat(out, "times")
    out.attrs = attrs
    return out

In [6]:
def get_file_name(feature):
    """
    Return the name of the file holding ``feature``.

    Power has a fixed file name; every network feature is stored per
    connectivity metric, so its name is built from the global ``metric``.
    """
    if feature != "power":
        return f"{metric}_{feature}_thr_1_at_cue.nc"
    return "power_tt_1_br_1_at_cue.nc"

In [7]:
def load_sessions(_ROOT, _FILE_NAME, stats="mean", min_channels=10):
    """
    Load a feature for all sessions, summarise it per stage and average per ROI.

    For every session in ``sessions`` the file ``_FILE_NAME`` is loaded,
    reduced per task stage with ``average_stages(..., stats)`` and cast to
    float32. The channels of all sessions are concatenated along ``roi``;
    channels sharing the same ROI label are averaged (NaNs skipped) and only
    ROIs with at least ``min_channels`` channels are kept.

    Parameters
    ----------
    _ROOT : str
        Root folder of the results.
    _FILE_NAME : str
        Name of the per-session feature file.
    stats : {"mean", "95p", "cv"}
        Statistic passed to ``average_stages``.
    min_channels : int
        Minimum number of channels (over all sessions) an ROI needs to be kept.

    Returns
    -------
    xr.DataArray
        Feature averaged per ROI, restricted to the ROIs that pass the
        channel threshold.

    Raises
    ------
    ValueError
        If ``stats`` is not one of "mean", "95p" or "cv".
    """
    # Explicit check (an ``assert`` would be stripped under ``python -O``)
    if stats not in ("mean", "95p", "cv"):
        raise ValueError(f"stats must be 'mean', '95p' or 'cv', got {stats!r}")

    data = []
    for s_id in tqdm(sessions):
        path_metric = return_file_path(_ROOT, _FILE_NAME, s_id)
        out = average_stages(xr.load_dataarray(path_metric), stats)
        # Concatenating whole arrays along "roi" gives the same result as
        # concatenating one-channel slices, without creating one array per channel
        data.append(out.astype(np.float32))

    # Concatenate the channels of all sessions
    data = xr.concat(data, dim="roi")
    # Keep only ROIs that have at least ``min_channels`` channels
    urois, counts = np.unique(data.roi.data, return_counts=True)
    urois = urois[counts >= min_channels]
    # Average channels within the same ROI
    data = data.groupby("roi").mean("roi", skipna=True)
    return data.sel(roi=urois)

#### Plotting function

In [8]:
def plot_feat_flatmap(data, f, suptitle=None, fig_name=None):
    """
    Plot flatmaps of one feature for frequency index ``f`` and save them.

    The figure has one row per summary statistic and one column per task
    stage (entry of the ``times`` dimension). Row ``i`` shows
    ``data.isel(stats=i)`` and its colorbar is titled "mean", "95p" and "cv"
    in that order, so ``data`` must be stacked along ``stats`` in that order.
    Each row uses the min/max of its statistic at frequency ``f`` as colour
    limits and gets a single colorbar. Sub-cortical areas are removed before
    plotting.

    Parameters
    ----------
    data : xr.DataArray
        Array with ``roi``, ``times``, ``freqs`` and ``stats`` dimensions.
    f : int
        Positional index along ``freqs``.
    suptitle : str, optional
        Figure title; ``", band {f+1}"`` is appended to it.
    fig_name : str
        File name of the figure, written inside the global ``results`` folder.

    Raises
    ------
    ValueError
        If ``fig_name`` is not given.
    """
    # Fail before the (expensive) plotting rather than in os.path.join
    if fig_name is None:
        raise ValueError("fig_name is required to save the figure")

    # Sub-cortical areas, excluded from the flatmaps
    sca = np.array(["thal", "putamen", "claustrum", "caudate"])
    areas = [a.lower() for a in data.roi.values]
    index = np.where(np.isin(areas, sca))
    _areas_nosca = np.delete(areas, index)

    n_times = data.sizes["times"]
    # Row labels; must match the order of the ``stats`` dimension
    metrics = ["mean", "95p", "cv"]
    n_metrics = len(metrics)
    # Panel annotation of each row
    ann = ["A", "B", "C"]

    fig = plt.figure(figsize=(8, 4), dpi=600)
    try:
        gs1 = fig.add_gridspec(
            nrows=n_metrics, ncols=n_times, left=0.05, right=0.87, bottom=0.05, top=0.92
        )
        gs2 = fig.add_gridspec(
            nrows=n_metrics,
            ncols=1,
            left=0.90,
            right=0.92,
            bottom=0.08,
            top=0.90,
            height_ratios=[0.2] * n_metrics,
        )

        for i_f, feature in enumerate(metrics):
            ax_cbar = fig.add_subplot(gs2[i_f])  # Colorbar axis of this row
            # Colour limits shared by all stages of this statistic
            row = data.isel(freqs=f, stats=i_f)
            vmin, vmax = row.min(), row.max()
            for t in range(n_times):
                axis = fig.add_subplot(gs1[i_f, t])  # Flatmap axis
                # Values of this stage without the sub-cortical areas
                values = np.delete(row.isel(times=t).values, index)
                flatmap(values, _areas_nosca).plot(
                    axis,
                    # One colorbar per row, drawn by its last column
                    ax_colorbar=ax_cbar if t == n_times - 1 else None,
                    cbar_title=f"{feature}",
                    vmin=vmin,
                    vmax=vmax,
                    alpha=0.6,
                    colormap="hot_r",
                )
                if t == 0:
                    axis.annotate(ann[i_f], (0, 0.9))

        if isinstance(suptitle, str):
            fig.suptitle(suptitle + f", band {f+1}", fontsize=10)
        fig.savefig(os.path.join(results, fig_name))
    finally:
        # Always release the 600-dpi figure, even if plotting fails
        plt.close(fig)

In [9]:
# Root folder of the results; session files are resolved relative to it by
# return_file_path.
_ROOT = os.path.expanduser("~/funcog/gda/")

### Power

In [10]:
# The power file name is fixed (it does not depend on ``metric``).
_FILE_NAME = get_file_name("power")

In [18]:
# Power does not depend on the connectivity metric (its file name is fixed),
# so the power figures are only produced when metric == "coh".
if metric == "coh":
    power_mean = load_sessions(_ROOT, _FILE_NAME, stats="mean")
    power_cv = load_sessions(_ROOT, _FILE_NAME, stats="cv")
    power_95p = load_sessions(_ROOT, _FILE_NAME, stats="95p")
    # Stack in the row order used by plot_feat_flatmap: mean, 95p, cv.
    data = xr.concat([power_mean, power_95p, power_cv], "stats")
    # One figure per frequency band (indices 0-9 along ``freqs``).
    for f in range(10):
        fig_name = f"power_{metric}_f_{f}.png"
        plot_feat_flatmap(data, f, suptitle="Power", fig_name=fig_name)

### Degree

In [19]:
# Degree file of the current ``metric``.
_FILE_NAME = get_file_name("degree")

In [20]:
# Per-stage mean of the degree, averaged per ROI over all sessions.
degree_mean = load_sessions(_ROOT, _FILE_NAME, stats="mean")

100%|███████████████████████████████████████████| 62/62 [03:09<00:00,  3.06s/it]


In [21]:
# Per-stage coefficient of variation of the degree, averaged per ROI.
degree_cv = load_sessions(_ROOT, _FILE_NAME, stats="cv")

100%|███████████████████████████████████████████| 62/62 [01:25<00:00,  1.37s/it]


In [22]:
# Per-stage 95th percentile of the degree, averaged per ROI.
degree_95p = load_sessions(_ROOT, _FILE_NAME, stats="95p")

100%|███████████████████████████████████████████| 62/62 [01:30<00:00,  1.45s/it]


In [23]:
# plot_feat_flatmap titles the rows of the ``stats`` dimension "mean", "95p",
# "cv" (in that order), so the arrays must be stacked in that order, as in
# the power, coreness and efficiency cells.
data = xr.concat([degree_mean, degree_95p, degree_cv], "stats")

In [24]:
# One figure per frequency band (indices 0-9 along ``freqs``).
# NOTE(review): assumes exactly 10 bands; data.sizes["freqs"] would avoid the
# hard-coded count.
for f in range(10):
    fig_name = f"degree_{metric}_f_{f}.png"
    plot_feat_flatmap(data, f, suptitle="Degree", fig_name=fig_name)

### Coreness

In [25]:
# Coreness file of the current ``metric``.
_FILE_NAME = get_file_name("coreness")

In [26]:
# Per-stage mean of the coreness, averaged per ROI over all sessions.
coreness_mean = load_sessions(_ROOT, _FILE_NAME, stats="mean")

100%|███████████████████████████████████████████| 62/62 [03:13<00:00,  3.12s/it]


In [27]:
# Per-stage coefficient of variation of the coreness, averaged per ROI.
coreness_cv = load_sessions(_ROOT, _FILE_NAME, stats="cv")

100%|███████████████████████████████████████████| 62/62 [01:51<00:00,  1.80s/it]


In [28]:
# Per-stage 95th percentile of the coreness, averaged per ROI.
coreness_95p = load_sessions(_ROOT, _FILE_NAME, stats="95p")

100%|███████████████████████████████████████████| 62/62 [01:23<00:00,  1.35s/it]


In [29]:
# Stack in the row order used by plot_feat_flatmap: mean, 95p, cv.
data = xr.concat([coreness_mean, coreness_95p, coreness_cv], "stats")

In [30]:
# One figure per frequency band (indices 0-9 along ``freqs``).
# NOTE(review): assumes exactly 10 bands.
for f in range(10):
    fig_name = f"coreness_{metric}_f_{f}.png"
    plot_feat_flatmap(data, f, suptitle="Coreness", fig_name=fig_name)

### Efficiency

In [31]:
# Efficiency file of the current ``metric``.
_FILE_NAME = get_file_name("efficiency")

In [32]:
# Per-stage mean of the efficiency, averaged per ROI over all sessions.
eff_mean = load_sessions(_ROOT, _FILE_NAME, stats="mean")

100%|███████████████████████████████████████████| 62/62 [03:10<00:00,  3.08s/it]


In [33]:
# Per-stage coefficient of variation of the efficiency, averaged per ROI.
eff_cv = load_sessions(_ROOT, _FILE_NAME, stats="cv")

100%|███████████████████████████████████████████| 62/62 [00:45<00:00,  1.35it/s]


In [34]:
# Per-stage 95th percentile of the efficiency, averaged per ROI.
eff_95p = load_sessions(_ROOT, _FILE_NAME, stats="95p")

100%|███████████████████████████████████████████| 62/62 [01:33<00:00,  1.51s/it]


In [35]:
# Stack in the row order used by plot_feat_flatmap: mean, 95p, cv.
data = xr.concat([eff_mean, eff_95p, eff_cv], "stats")

In [36]:
# One figure per frequency band (indices 0-9 along ``freqs``).
# NOTE(review): assumes exactly 10 bands.
for f in range(10):
    fig_name = f"efficiency_{metric}_f_{f}.png"
    plot_feat_flatmap(data, f, suptitle="Efficiency", fig_name=fig_name)

### Meta-connectivity

In [32]:
# Root folder and sub-folder holding the meta-connectivity files (read by
# get_mc_path).
_ROOT = os.path.expanduser("~/funcog/gda")
_RESULTS = "Results/lucy/meta_conn"


# NOTE: this rebinds the global ``metric`` also used by get_file_name and in
# the figure names of the cells above; re-running them afterwards uses "plv".
metric = "plv"


def get_mc_path(session):
    """Return the path of the meta-connectivity (TS) file of ``session``."""
    file_name = f"ts_{metric}_{session}.nc"
    return os.path.join(_ROOT, _RESULTS, file_name)


def load_mc(min_channels=10):
    """
    Load the meta-connectivity (TS) maps of all sessions and average per ROI.

    The file returned by ``get_mc_path`` is loaded for every session in
    ``sessions`` and cast to float32. The channels of all sessions are
    concatenated along ``roi``; channels sharing the same ROI label are
    averaged (NaNs skipped) and only ROIs with at least ``min_channels``
    channels are kept.

    Parameters
    ----------
    min_channels : int
        Minimum number of channels (over all sessions) an ROI needs to be kept.

    Returns
    -------
    xr.DataArray
        TS maps averaged per ROI, restricted to the ROIs that pass the
        channel threshold.
    """
    # Concatenating whole arrays along "roi" gives the same result as
    # concatenating one-channel slices, without creating one array per channel
    data = [
        xr.load_dataarray(get_mc_path(session)).astype(np.float32)
        for session in tqdm(sessions)
    ]

    # Concatenate the channels of all sessions
    data = xr.concat(data, dim="roi")
    # Keep only ROIs that have at least ``min_channels`` channels
    urois, counts = np.unique(data.roi.data, return_counts=True)
    urois = urois[counts >= min_channels]
    # Average channels within the same ROI
    data = data.groupby("roi").mean("roi", skipna=True)
    return data.sel(roi=urois)

In [33]:
# TS maps of the current ``metric`` averaged per ROI over all sessions.
MC = load_mc()

100%|███████████████████████████████████████████| 62/62 [00:03<00:00, 15.78it/s]


In [34]:
def plot_mc_flatmap(data, f, suptitle=None, fig_name=None):
    """
    Plot the meta-connectivity (TS) flatmaps of frequency index ``f`` and save them.

    One flatmap per task stage (entry of the ``times`` dimension) is drawn on
    a single row with one shared colorbar. The colour limits are the minimum
    and maximum of the whole array (all frequencies, stages and ROIs), so all
    bands are drawn on the same scale. Sub-cortical areas are removed before
    plotting.

    Parameters
    ----------
    data : xr.DataArray
        Array with ``roi``, ``times`` and ``freqs`` dimensions; the ``freqs``
        coordinate is printed followed by "Hz" (presumably band frequencies
        in Hz).
    f : int
        Positional index along ``freqs``.
    suptitle : str, optional
        Figure title, drawn when given.
    fig_name : str
        File name of the figure, written inside the global ``results`` folder.

    Raises
    ------
    ValueError
        If ``fig_name`` is not given.
    """
    # Fail before the (expensive) plotting rather than in os.path.join
    if fig_name is None:
        raise ValueError("fig_name is required to save the figure")

    # Sub-cortical areas, excluded from the flatmaps
    sca = np.array(["thal", "putamen", "claustrum", "caudate"])
    areas = [a.lower() for a in data.roi.values]
    index = np.where(np.isin(areas, sca))
    _areas_nosca = np.delete(areas, index)

    n_times = data.sizes["times"]
    freqs = data.freqs.values
    # Name of each stage, used as column titles
    stage = [r"$P$", r"$S$", r"$D_1$", r"$D_2$", r"$D_M$"]

    # Colour limits over the whole array (all bands share the same scale)
    vmin = data.min()
    vmax = data.max()

    fig = plt.figure(figsize=(8, 4), dpi=600)
    try:
        gs1 = fig.add_gridspec(
            nrows=1, ncols=n_times, left=0.05, right=0.87, bottom=0.05, top=0.92
        )
        gs2 = fig.add_gridspec(
            nrows=1,
            ncols=1,
            left=0.90,
            right=0.92,
            bottom=0.35,
            top=0.5,
        )
        ax_cbar = fig.add_subplot(gs2[0])  # Colorbar axis

        for t in range(n_times):
            axis = fig.add_subplot(gs1[t])  # Flatmap axis
            # Values of this stage without the sub-cortical areas
            values = np.delete(data.isel(times=t, freqs=f).values, index)
            flatmap(values, _areas_nosca).plot(
                axis,
                # A single colorbar, drawn by the last column
                ax_colorbar=ax_cbar if t == n_times - 1 else None,
                cbar_title="TS",
                vmin=vmin,
                vmax=vmax,
                alpha=0.6,
                colormap="hot_r",
            )
            # Annotate/title the flatmap axis explicitly: after a colorbar is
            # drawn the "current" axes is not guaranteed to be this one
            if t == 0:
                axis.annotate(f"{freqs[f]} Hz", (0, 0.2), fontsize=5)
            axis.set_title(stage[t] if t < len(stage) else f"t = {t}", fontsize=10)

        if isinstance(suptitle, str):
            fig.suptitle(suptitle, fontsize=10)
        fig.savefig(os.path.join(results, fig_name), bbox_inches="tight")
    finally:
        # Always release the 600-dpi figure, even if plotting fails
        plt.close(fig)

In [35]:
# One meta-connectivity figure per frequency band (indices 0-9 along ``freqs``).
# NOTE(review): assumes exactly 10 bands; MC.sizes["freqs"] would avoid the
# hard-coded count.
for f in range(10):
    plot_mc_flatmap(MC, f, fig_name=f"meta_conn_{metric}_{f}_maps.png")