In [None]:
from analysis.utils.utils import setup_case_manager
from megspikes.scoring.scoring import distance_to_resection_hull
from analysis.utils.plot_paper_images import fig4_one_row
import matplotlib.pylab as plt
import seaborn as sns
import xarray as xr
import pandas as pd
import numpy as np

In [None]:
# Case 1 is only used here to locate the shared output folder: results are
# written to a `Results` directory that sits next to the case root.
case = setup_case_manager(1)
project_dir = case.root.parent
save_path = project_dir / 'Results'
save_path.mkdir(exist_ok=True)

## Figure 2

## Figure 3

## Figure 4

In [None]:
fig4_table = pd.DataFrame(columns=[
    'case', 'case_name', 'detection_type', 'n_sources_resection', 'n_sources',
    'n_sources_to_n_sources_resection', 'distance_resection'])

# Detection methods compared in Figure 4. The tuple order fixes the row order
# within each case (and hence the category order on the plot's y axis).
DETECTION_TYPES = ('manual', 'slope', 'peak')


def _distance_or_nan(resection_coords, detection_coords):
    """Return the distance from detected sources to the resection hull, or NaN.

    Parameters
    ----------
    resection_coords : np.ndarray
        Rows of ``fwd_mni`` selected by the resection mask.
    detection_coords : np.ndarray
        Rows of ``fwd_mni`` selected by one detection method.

    Returns
    -------
    float
        ``distance_to_resection_hull(resection_coords, detection_coords)``,
        or ``np.nan`` when the detection method selected no sources.
    """
    if len(detection_coords) == 0:
        # np.nan (not np.NAN, which was removed in NumPy 2.0).
        return np.nan
    return distance_to_resection_hull(resection_coords, detection_coords)


n = 0  # running row counter, threaded through fig4_one_row

for subj in range(1, 8):
    case = setup_case_manager(subj)

    # NOTE(review): `resection` is not used in this cell; kept in case a
    # later cell relies on it.
    resection = np.load(
        case.basic_folders['resection mask'].with_name('resection.npy'))
    resection_stc = np.load(
        case.basic_folders['resection mask'].with_name('resection_stc.npy'))

    # load_dataset reads the file into memory and closes it, so no netCDF
    # handle is left open per subject while the dataset stays usable.
    clusters_auto = xr.load_dataset(
        case.cluster_dataset.with_name(f'{case.case}_clusters_manually_checked.nc'))

    fwd_mni = clusters_auto.fwd_mni_coordinates.values

    manual_stc = np.load(case.basic_folders['MANUAL'] / 'manual_stc.npy')
    peak_stc = clusters_auto.iz_prediction.loc[:, 'peak'].values
    slope_stc = clusters_auto.iz_prediction.loc[:, 'slope'].values

    resection_coords = fwd_mni[resection_stc > 0]
    detections_stc = {'manual': manual_stc, 'slope': slope_stc, 'peak': peak_stc}

    for detection_type in DETECTION_TYPES:
        detection_stc = detections_stc[detection_type]
        # Every method gets the same guard: no selected sources -> NaN distance.
        distance = _distance_or_nan(resection_coords, fwd_mni[detection_stc > 0])
        fig4_table, n = fig4_one_row(n, case, fig4_table, subj, detection_type,
                                     fwd_mni, resection_stc, detection_stc,
                                     distance)



#### Save the results table

In [None]:
# Write the Figure 4 table (without the DataFrame index) to the results folder.
figure_4_table_path = save_path / "figure_4_table.xlsx"
fig4_table.to_excel(figure_4_table_path, index=False)


#### Plot Figure 4

In [None]:
sns.set(style="whitegrid", font_scale=2)

# Plot from a renamed copy so `fig4_table` keeps the column names saved to
# Excel and this cell stays idempotent on re-run.
size = 'N sources\ndetection to\nresection'
fig4_plot_table = fig4_table.rename(
    columns={'n_sources_to_n_sources_resection': size})

dy = 'detection_type'
dx = 'distance_resection'
ort = "h"
pal = "Set2"
# Red reference line on the distance axis.
# NOTE(review): units follow distance_to_resection_hull (presumably mm) — confirm.
distance_threshold = 20

# One colour per case, shared by the scatter points and the per-case lines,
# so the two always match regardless of how many cases there are.
case_ids = sorted(fig4_plot_table['case'].unique())
case_colors = dict(zip(case_ids, sns.color_palette(pal, n_colors=len(case_ids))))

# Detection types in order of appearance; the index is the category's
# position on the y axis.
det_order = list(fig4_plot_table[dy].unique())
det_pos = {name: i for i, name in enumerate(det_order)}

fig, ax = plt.subplots(figsize=(15, 15))

sns.scatterplot(x=dx, y=dy, data=fig4_plot_table, hue='case', size=size,
                zorder=10, palette=case_colors, sizes=(1, 1500), alpha=.6,
                ax=ax)

# Reposition with seaborn: a bare ax.legend() rebuilds the legend (dropping
# any styling, and losing seaborn's handles on newer versions). Style the
# legend that is actually displayed afterwards.
sns.move_legend(ax, 'upper right', ncol=1, bbox_to_anchor=(1.25, 1))
legend = ax.get_legend()
plt.setp(legend.get_texts(), fontsize='12')  # legend entries
plt.setp(legend.get_title(), fontsize='32')  # legend title

sns.boxplot(x=dx, y=dy, data=fig4_plot_table, order=det_order, color="black",
            width=.15, zorder=0, showcaps=True,
            boxprops={'facecolor': 'none', "zorder": 10}, showfliers=True,
            whiskerprops={'linewidth': 2, "zorder": 10},
            saturation=1, orient=ort, ax=ax)

# Connect each case's distances across detection types; NaN distances
# leave a gap in the line.
for case_id, color in case_colors.items():
    case_rows = fig4_plot_table[fig4_plot_table['case'] == case_id]
    ax.plot(case_rows[dx], case_rows[dy].map(det_pos),
            lw=3, c=color, alpha=0.4)

fig.tight_layout()
ax.axvline(distance_threshold, 0, 1, c='r')
fig.savefig(save_path / 'Figure_4_results.png', dpi=300)

## Figure 5