# 4) Analyze and evaluate optimization output

This final notebook loads the `runs.pkl` file created in notebook 2 and, for each feature set used in the optimization, measures how far the fitted models are from the ground-truth models:

- the distance in the parameter space
- the distance in the feature space
- the distance in the extracellular signals

In [None]:
import pickle
import pandas as pd
import seaborn as sns
import sys
from scipy.spatial import distance

import bluepyopt as bpopt
import bluepyopt.ephys as ephys

import matplotlib.pyplot as plt
import MEAutility as mu
import json
import time
import numpy as np
import LFPy
from pathlib import Path

import model
import evaluator
import plotting 
import utils

%matplotlib notebook

### Load ground-truth (GT) params and optimization output

In [None]:
# Relative folders holding the optimization results and the configuration files.
result_folder, config_folder = (Path(name) for name in ('results', 'config'))

In [None]:
# Read the ground-truth parameter sets; the CSV column 'index' becomes the row index.
random_params = pd.read_csv(
    config_folder.joinpath('params', 'smart_random.csv'),
    index_col='index',
)
random_params

In [None]:
# Parameter names taken from the CSV columns. This name is re-assigned further
# down from the cell model's non-frozen parameters.
param_names = random_params.columns

In [None]:
# Display the parameter names.
param_names

In [None]:
# Load the pickled optimization runs and tabulate them in a DataFrame.
# NOTE(review): unpickling can execute arbitrary code; only load a runs.pkl
# that you produced yourself.
with open(result_folder / 'runs.pkl', 'rb') as f:  # close the file handle deterministically
    data = pickle.load(f)
df = pd.DataFrame(data)

### Compute complete set of features for all samples

In [None]:
# Probe layout to use; the trailing comment records 'linear' as the alternative value.
probe_type = 'planar' # 'linear'
# Build the electrode; it is passed to the protocols and to the simulator further down.
electrode = model.define_electrode(probe_type=probe_type)
probe = electrode.probe
# Draw the probe with MEAutility.
ax = mu.plot_probe(probe)

In [None]:
l5pc_cell = model.create()

# Names of the parameters that are not frozen in the cell model.
param_names = [p.name for p in l5pc_cell.params.values() if not p.frozen]

feature_set = 'all'
print(f'Feature set {feature_set}')

gt_responses = []

# The 'extra' and 'all' feature sets get the electrode in both the protocols
# and the simulator; any other feature set is built without it.
if feature_set in ["extra", "all"]:
    fitness_protocols = evaluator.define_protocols(
        electrode=electrode,
        protocols_with_lfp=["Step1"],
    )
    sim = ephys.simulators.LFPySimulator(
        LFPyCellModel=l5pc_cell,
        cvode_active=True,
        electrode=electrode,
    )
else:
    fitness_protocols = evaluator.define_protocols()
    sim = ephys.simulators.LFPySimulator(
        LFPyCellModel=l5pc_cell,
        cvode_active=True,
    )

In [None]:
# Display the protocols that were just defined.
fitness_protocols

In [None]:
# Re-use cached features/responses when the cache file exists; otherwise flag
# them for (re)computation in the following cells.
# NOTE(review): unpickling can execute arbitrary code; only load a cache file
# that you produced yourself.
feats_responses_file = result_folder / 'feats_responses.pkl'
compute_feats_responses = not feats_responses_file.is_file()

if not compute_feats_responses:
    with feats_responses_file.open('rb') as f:  # ensure the handle is closed after loading
        feats_responses = pickle.load(f)
    feats = feats_responses['feats']
    responses = feats_responses['responses']
    gt_features_v = feats['gt']
    gt_responses = responses['gt']
    fitted_features_v = feats['fitted']
    fitted_responses = responses['fitted']

In [None]:
# Compute GT features and responses
channels = 'map'

if compute_feats_responses:
    gt_features = []
    gt_responses = []
    # One simulation per ground-truth parameter set (one row of random_params).
    for i, (index, params) in enumerate(random_params.iterrows()):
        print(f'{i+1} / {len(random_params)}, {index}')

        response, feature_dict = utils.compute_feature_values(params, l5pc_cell, fitness_protocols, sim,
                                                              feature_set=feature_set, probe=probe,
                                                              channels=channels)
        gt_features.append(feature_dict)
        gt_responses.append(response)
    gt_features_v = utils.vectorize_features(gt_features)
else:
    print("Using loaded GT features and responses")

In [None]:
# Plot the ground-truth responses.
plotting.plot_multiple_responses(gt_responses, max_rows=5)

In [None]:
# Compute fitted features and responses
if not compute_feats_responses:
    print("Using loaded FITTED features and responses")
else:
    fitted_features, fitted_responses = [], []
    # One simulation per optimization run (one row of df), using its best parameters.
    for i, (index, fit) in enumerate(df.iterrows()):
        params = pd.Series(data=fit['best_params'], index=param_names)
        print(f'{i+1} / {len(df)}')

        response, feature_dict = utils.compute_feature_values(
            params, l5pc_cell, fitness_protocols, sim,
            feature_set=feature_set,
            probe=probe,
            channels=channels,
        )
        fitted_features.append(feature_dict)
        fitted_responses.append(response)
    fitted_features_v = utils.vectorize_features(fitted_features)

In [None]:
# Plot the responses of the fitted models.
plotting.plot_multiple_responses(fitted_responses, max_rows=5)

In [None]:
# Save features
save_features = True
if not compute_feats_responses:
    # Features/responses were loaded from the cache file, so nothing new to write.
    print("Responses and features already saved!")
elif save_features:
    feats = {'gt': gt_features_v, 'fitted': fitted_features_v}
    responses = {'gt': gt_responses, 'fitted': fitted_responses}
    dump_dict = {'feats': feats, 'responses': responses}
    with (result_folder / 'feats_responses.pkl').open('wb') as f:
        pickle.dump(dump_dict, f)
else:
    # Freshly computed but saving is switched off: do not claim they were saved.
    print("Saving disabled: computed responses and features were not written to disk")

In [None]:
# Double check all GT params produce responses with all BAP features (5)
# Keep the position of every GT feature dict that has exactly five keys containing 'bAP'.
complete_bap = [
    position
    for position, gt_feats in enumerate(gt_features_v)
    if sum(1 for key in gt_feats.keys() if 'bAP' in key) == 5
]
print(len(complete_bap))

## Compute distances in parameters and feature space (no extracellular)

In [None]:
def _cosine_distance(u, v):
    """Return the cosine distance between `u` and `v`, or NaN if either is empty."""
    if len(u) == 0 or len(v) == 0:
        return np.nan
    return distance.cosine(u, v)


def _trailing_int(key):
    """Return the integer spelled by the trailing digits of `key`.

    Raises ValueError when `key` does not end with a digit.
    """
    return int(key[len(key.rstrip('0123456789')):])


param_distances = []
param_distances_apical = []
param_distances_somatic = []
param_distances_axonal = []

feature_distances = []
feat_soma_dist = []
feat_dend_dist = []
feat_mea_dist = []

# Optional subset of MEA channels to keep; None keeps every MEA feature.
# channels = [4,5,6,10,15]
channels = None
for i, (index, fit) in enumerate(df.iterrows()):
    sample_id = int(fit.sample_id)

    # --- Parameter space: fitted vs. ground-truth, aligned by sorted parameter name.
    gt_params = random_params.iloc[sample_id].sort_index()
    fit_params = pd.Series(fit.best_params, param_names).sort_index()

    axonal_idxs = gt_params.index.str.contains('axonal')
    somatic_idxs = gt_params.index.str.contains('somatic')
    apical_idxs = gt_params.index.str.contains('apical')

    param_distances.append(_cosine_distance(fit_params.values, gt_params.values))
    param_distances_axonal.append(
        _cosine_distance(fit_params[axonal_idxs].values, gt_params[axonal_idxs].values))
    param_distances_somatic.append(
        _cosine_distance(fit_params[somatic_idxs].values, gt_params[somatic_idxs].values))
    param_distances_apical.append(
        _cosine_distance(fit_params[apical_idxs].values, gt_params[apical_idxs].values))

    # --- Feature space.
    gt_sample_feats = gt_features_v[sample_id]
    fitted_sample_feats = fitted_features_v[i]

    # Non-MEA features are always kept. MEA features are kept when no channel
    # subset is requested, or when their channel is in the subset.
    # NOTE(review): assumes MEA feature keys end with the channel number - confirm
    # against utils.vectorize_features.
    selected_keys = [
        k for k in gt_sample_feats.keys()
        if 'MEA' not in k or channels is None or _trailing_int(k) in channels
    ]

    gt_feat, fitted_feat = [], []
    gt_by_group = {'soma': [], 'dend': [], 'MEA': []}
    fitted_by_group = {'soma': [], 'dend': [], 'MEA': []}
    for k in selected_keys:
        # Only compare features that exist for both the GT and the fitted model.
        if k not in fitted_sample_feats:
            continue
        gt_value = gt_sample_feats[k]
        fitted_value = fitted_sample_feats[k]
        gt_feat.append(gt_value)
        fitted_feat.append(fitted_value)
        for group in gt_by_group:
            if group in k:
                gt_by_group[group].append(gt_value)
                fitted_by_group[group].append(fitted_value)

    feature_distances.append(_cosine_distance(fitted_feat, gt_feat))
    feat_soma_dist.append(_cosine_distance(fitted_by_group['soma'], gt_by_group['soma']))
    feat_dend_dist.append(_cosine_distance(fitted_by_group['dend'], gt_by_group['dend']))
    feat_mea_dist.append(_cosine_distance(fitted_by_group['MEA'], gt_by_group['MEA']))

In [None]:
# Attach every per-run distance list to the runs table as its own column.
distance_columns = {
    'param_dist': param_distances,
    'param_dist_apical': param_distances_apical,
    'param_dist_axonal': param_distances_axonal,
    'param_dist_somatic': param_distances_somatic,
    'feat_dist': feature_distances,
    'feat_dist_soma': feat_soma_dist,
    'feat_dist_dend': feat_dend_dist,
    'feat_dist_mea': feat_mea_dist,
}
for column_name, column_values in distance_columns.items():
    df[column_name] = column_values

In [None]:
# Parameter-space distance per feature set, one bar colour per sample.
fig1, ax1 = plt.subplots()
sns.barplot(
    data=df,
    x='feature_set',
    y='param_dist',
    hue='sample_id',
    ax=ax1,
    alpha=0.5,
)
ax1.set_title("All params")

In [None]:
# Parameter-space distance split by parameter group, one panel per group.
fig2, (ax21, ax22, ax23) = plt.subplots(1, 3, figsize=(9, 5))
panels = (
    (ax21, 'param_dist_somatic', "Somatic params"),
    (ax22, 'param_dist_axonal', "Axonal params"),
    (ax23, 'param_dist_apical', "Apical params"),
)
for panel_ax, column_name, panel_title in panels:
    sns.barplot(data=df, x='feature_set', y=column_name, hue='sample_id',
                ax=panel_ax, alpha=0.5, ci=None)
    panel_ax.set_title(panel_title)
fig2.tight_layout()

In [None]:
# Feature-space distance (all selected features) per feature set.
fig3, ax3 = plt.subplots()
sns.boxenplot(
    data=df,
    x='feature_set',
    y='feat_dist',
    hue='sample_id',
    ax=ax3,
)
ax3.set_title('All features')

In [None]:
# Feature-space distance split into somatic (left) and dendritic (right) features.
fig4, (ax41, ax42) = plt.subplots(1, 2, figsize=(10, 5))
for panel_ax, column_name, panel_title in (
    (ax41, 'feat_dist_soma', "Somatic features"),
    (ax42, 'feat_dist_dend', "Dend features"),
):
    sns.boxenplot(data=df, x='feature_set', y=column_name, hue='sample_id', ax=panel_ax)
    panel_ax.set_title(panel_title)

# Strip the y ticks, tick labels and axis label from the right-hand panel.
ax42.set_yticks([])
ax42.set_yticklabels([])
ax42.set_ylabel('')

## Compute distances in extracellular feature space

In [None]:
# For every ground-truth model, compare its EAP with the EAP of each fitted
# model, grouped by the feature set the fit was optimized on.
sample_ids_array = []
feature_set_array = []
fitted_ids_array = []
distances_array = []
for gt_id in range(len(gt_responses)):
    print(f"Test model {gt_id + 1}")
    # NOTE(review): sample_id is compared as a string here - confirm it is stored
    # as str in runs.pkl.
    df_fit = df[df.sample_id == str(gt_id)]
    fitted = np.array(fitted_responses)[np.array(df_fit.index)]
    fit_feature_sets = df_fit['feature_set'].to_numpy()

    eap_gt = utils.calculate_eap(responses=gt_responses[gt_id], protocols=fitness_protocols,
                                 protocol_name="Step1")

    # Only these three feature sets are compared, in this order; fits from any
    # other feature set are skipped.
    for fs_name in ("soma", "multiple", "extra"):
        fitted_subset = fitted[np.where(fit_feature_sets == fs_name)]
        for fitted_id, fitted_response in enumerate(fitted_subset):
            eap = utils.calculate_eap(responses=fitted_response, protocols=fitness_protocols,
                                      protocol_name="Step1")
            dist = distance.cosine(eap_gt.ravel(), eap.ravel())
            sample_ids_array.append(gt_id)
            fitted_ids_array.append(fitted_id)
            feature_set_array.append(fs_name)
            distances_array.append(dist)

df_eaps = pd.DataFrame({"sample_id": sample_ids_array, "fitted_id": fitted_ids_array,
                        "feature_set": feature_set_array, "distance": distances_array})

In [None]:
# Distribution of EAP distances per feature set.
# NOTE: this re-binds the name fig4 used by the somatic/dendritic feature figure above.
fig4 = plt.figure()
ax4 = fig4.add_subplot(111)
sns.boxenplot(data=df_eaps, x='feature_set', y='distance', hue='sample_id', ax=ax4)
# Title the axes of THIS figure (ax4), not the earlier 'All features' axes (ax3).
ax4.set_title('Extracellular distance')

# Plot responses

In [None]:
# Plot colour used for each feature set.
colors = dict(soma='C0', multiple='C1', extra='C2')

In [None]:
# For every ground-truth model and every feature set, draw the fitted responses
# together with the ground truth (black): one intracellular figure and one
# extracellular (EAP) figure. The figures are collected in the lists below.
figures_soma_intra = []
figures_soma_extra = []
figures_multi_intra = []
figures_multi_extra = []
figures_extra_intra = []
figures_extra_extra = []

# (feature set, title font size, intracellular figure list, extracellular figure list)
plot_specs = [
    ("soma", 25, figures_soma_intra, figures_soma_extra),
    ("multiple", 25, figures_multi_intra, figures_multi_extra),
    ("extra", 30, figures_extra_intra, figures_extra_extra),
]

for gt_id in range(len(gt_responses)):
    print(f"Test model {gt_id + 1}")
    # NOTE(review): sample_id is compared as a string here - confirm it is stored
    # as str in runs.pkl.
    df_fit = df[df.sample_id == str(gt_id)]
    fitted = np.array(fitted_responses)[np.array(df_fit.index)]
    fit_feature_sets = df_fit['feature_set'].to_numpy()
    gt_response = gt_responses[gt_id]

    for fs_name, title_size, intra_figs, extra_figs in plot_specs:
        fitted_subset = fitted[np.where(fit_feature_sets == fs_name)]
        fs_color = colors[fs_name]

        # Plot intracellular responses (ground truth appended last, in black)
        fig_intra = plotting.plot_multiple_responses(
            responses_list=np.concatenate((fitted_subset, [gt_response])),
            max_rows=5,
            colors=[fs_color] * len(fitted_subset) + ['k'],
            return_fig=True)
        fig_intra.suptitle(f"Test model {gt_id + 1} - '{fs_name}' feature set\nIntracellular",
                           fontsize=title_size, y=0.98)
        fig_intra.subplots_adjust(top=0.8)
        # Store this feature set's own figure in its own list.
        intra_figs.append(fig_intra)

        # Plot EAPs (ground truth overlaid in black on the same axes)
        ax_eap = plotting.plot_multiple_eaps(responses_list=fitted_subset, protocols=fitness_protocols,
                                             protocol_name="Step1", probe=probe, colors=fs_color,
                                             norm=True)
        ax_eap = plotting.plot_eap(responses=gt_response, protocols=fitness_protocols,
                                   protocol_name="Step1", probe=probe, color="k", norm=True,
                                   ax=ax_eap)
        ax_eap.set_title(f"Test model {gt_id + 1} - '{fs_name}' feature set\nExtracellular",
                         fontsize=title_size)
        extra_figs.append(ax_eap.get_figure())

    print("\n\n\n\n\n")
    
    