In [None]:
import logging
import warnings

import mne
import xarray as xr
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pylab as plt

from megspikes.visualization.visualization import (DetectionsViewer,
                                                   ClusterSlopeViewer)

from megspikes.scoring.scoring import distance_to_resection_hull

from utils.utils import setup_case_manager
from utils.plot_paper_images import fig4_one_row

# Hide every DeprecationWarning so cell output stays readable.
warnings.filterwarnings("ignore", category=DeprecationWarning)
# Re-import edited modules (e.g. megspikes, utils) before each cell runs.
%load_ext autoreload
%autoreload 2

### Kurtosis example: ICA components of subject 4

In [None]:
# Subject used for the kurtosis / ICA example in this section.
subj = 4
# Case manager for that subject; it provides the dataset path used below.
case = setup_case_manager(subj)

In [None]:
# Open (lazily) the detection dataset of the selected case; it stays open
# because the viewer in the next cell reads from it.
dataset = xr.open_dataset(case.dataset)
# Display the ICA component properties restricted to the 'mag' sensors.
dataset.ica_component_properties.sel(sensors='mag')

In [None]:
pp = DetectionsViewer(dataset, case)
# Time window shown by the viewer (units presumably seconds — TODO confirm).
pp.time = 700, 720

In [None]:
# Build and display the ICA view of DetectionsViewer for the current case/window.
app = pp.view_ica()
app.show()

In [None]:
# Build and display the "ICA sources and peaks" view (per the method name,
# source time courses with their detected peaks — confirm in megspikes).
app = pp.view_ica_sources_and_peaks()
app.show()

### More slope points: IZ estimation at time points from cluster baseline through slope to peak

In [None]:
# 'Results' folder next to the case root; the Figure 4 table and image go here.
save_path = case.root.parent.joinpath('Results')

In [None]:
# Build the Figure 4 table. For every subject, the resection is compared with
# (a) the manually delineated IZ ('manual' row) and (b) IZ predictions
# re-computed with the cluster slope time moved to 11 points between the
# cluster baseline and the peak.
fig4_table = pd.DataFrame(columns=[
    'case', 'case_name', 'detection_type', 'n_sources_resection', 'n_sources',
    'n_sources_to_n_sources_resection', 'distance_resection'])

n = 0  # running row counter, threaded through fig4_one_row

for subj in range(1, 8):
    case = setup_case_manager(subj)

    resection_stc = np.load(
        case.basic_folders['resection mask'].with_name('resection_stc.npy'))

    # load_dataset reads the file into memory and closes it, so no file
    # handle is left open for every subject.
    clusters_auto = xr.load_dataset(
        case.cluster_dataset.with_name(f'{case.case}_clusters_manually_checked.nc'))

    fwd_mni = clusters_auto.fwd_mni_coordinates.values
    resection_mni = fwd_mni[resection_stc > 0]

    # Reference row: manually delineated IZ.
    manual_stc = np.load(case.basic_folders['MANUAL'] / 'manual_stc.npy')
    dist_resection_to_manual = distance_to_resection_hull(
        resection_mni, fwd_mni[manual_stc > 0])
    fig4_table, n = fig4_one_row(n, case, fig4_table, subj, 'manual', fwd_mni,
                                 resection_stc, manual_stc, dist_resection_to_manual)

    pc = ClusterSlopeViewer(clusters_auto, case)

    props = clusters_auto.cluster_properties
    baseline = props.loc[dict(cluster_property="time_baseline")].values
    slope_manual = props.loc[dict(cluster_property="time_slope")].values
    peak = props.loc[dict(cluster_property="time_peak")].values

    # 11 time points per cluster:
    #   'slope 1'..'slope 5' at k/6 of the way from baseline to the manual slope,
    #   'slope'              the manual slope time itself,
    #   'slope 6'..'slope 9' at k/5 of the way from the slope to the peak,
    #   'peak'               the peak time (k = 5).
    time_points = [
        (f'slope {k}', baseline + (slope_manual - baseline) / 6 * k)
        for k in range(1, 6)]
    time_points.append(('slope', slope_manual))
    time_points += [
        (f'slope {k + 5}' if k < 5 else 'peak',
         slope_manual + (peak - slope_manual) / 5 * k)
        for k in range(1, 6)]

    for label, time_slope in time_points:
        if subj == 1:  # print the sampled times once as a sanity check
            print(label, time_slope)

        # Re-run the IZ prediction with the slope time moved to this point.
        pc.data.clusters_properties.time_slope = time_slope
        pc._rerun_iz_prediction()
        slope_stc = pc.data.ds.iz_prediction.loc[:, 'slope'].values

        if np.any(slope_stc > 0):
            dist_resection_to_slope = distance_to_resection_hull(
                resection_mni, fwd_mni[slope_stc > 0])
        else:
            # No active source, so the distance is undefined.
            # (np.nan, not np.NAN: the uppercase alias was removed in NumPy 2.0.)
            dist_resection_to_slope = np.nan

        fig4_table, n = fig4_one_row(n, case, fig4_table, subj, label, fwd_mni,
                                     resection_stc, slope_stc, dist_resection_to_slope)

fig4_table.to_excel(save_path / "figure_4_table.xlsx", index=False)

#### Plot figure 4

In [None]:
# Reload the saved table so the figure can be regenerated without the slow loop.
fig4_table = pd.read_excel(save_path.joinpath("figure_4_table.xlsx"))

In [None]:
# Figure 4: distance from each IZ estimation to the resection margin, one row
# per detection type. Points are coloured by case and sized by the ratio of
# detected sources to resection sources; a line links each case's points.
sns.set(style="whitegrid", font_scale=2)

dy = 'detection_type'
dx = 'distance_resection'
ort = "h"
pal = "Set2"
size = 'N sources\ndetection to\nresection'
REFERENCE_DISTANCE_MM = 20  # position of the red vertical reference line

# Non-mutating rename: re-running this cell leaves `fig4_table` untouched.
plot_table = fig4_table.rename(
    columns={'n_sources_to_n_sources_resection': size})

f, ax = plt.subplots(figsize=(15, 15))

ax = sns.scatterplot(x=dx, y=dy, data=plot_table, hue='case', size=size,
                     zorder=10, palette=pal, sizes=(1, 1500), alpha=.6, ax=ax)

# ax.legend() replaces the current legend, so style the new legend *after*
# creating it; styling the old one first would be discarded.
legend = ax.legend(loc='upper right', ncol=1, bbox_to_anchor=(1.25, 1))
plt.setp(legend.get_texts(), fontsize=12)   # legend entries
plt.setp(legend.get_title(), fontsize=32)   # legend title

ax = sns.boxplot(x=dx, y=dy, data=plot_table, color="black",
                 width=.15, zorder=0, showcaps=True,
                 boxprops={'facecolor': 'none', "zorder": 10}, showfliers=True,
                 whiskerprops={'linewidth': 2, "zorder": 10}, whis=[0.01, 99.99],
                 saturation=1, orient=ort, ax=ax)

# Categorical y positions follow first-appearance order of the detection
# types. Map each row to its position instead of assuming exactly 12 rows
# per case. Colours use the same qualitative palette, in sorted case order,
# that the hue mapping applies.
y_position = {label: pos for pos, label in enumerate(plot_table[dy].unique())}
case_ids = sorted(plot_table['case'].unique())
for case_id, color in zip(case_ids, sns.color_palette(pal, n_colors=len(case_ids))):
    case_rows = plot_table[plot_table['case'] == case_id]
    ax.plot(case_rows[dx], case_rows[dy].map(y_position),
            lw=3, c=color, alpha=0.4)

ax.set(ylabel='IZ estimation', xlabel='distance to resection margin [mm]')
ax.set(xlim=(None, 60))
plt.tight_layout()
ax.axvline(REFERENCE_DISTANCE_MM, 0, 1, c='r')
f.savefig(save_path / 'Figure_4_more_slope_points.png', dpi=300)