# Notebook to analyze and display cell type data


### Load important modules

In [None]:
# Standard modules
import numpy as np
import os
import pandas as pd 
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import logging
from numba import njit
from sklearn import linear_model
import scipy.cluster.hierarchy as sch
from plotly.subplots import make_subplots


# Move to root directory for easier module handling. The root is recognised by the "modules"
# package imported below; the guard keeps a re-run of this cell from climbing two more levels.
if not os.path.isdir("modules"):
    os.chdir("../..")
print(os.listdir("."))

#LBAE imports
from modules.maldi_data import MaldiData
from modules.figures import Figures
from modules.atlas import Atlas
from modules.storage import Storage
from modules.scRNAseq import ScRNAseq

# multithreading/multiprocessing
from multiprocessing import Pool
import multiprocessing
from threadpoolctl import threadpool_limits

# set thread limit: cap the native thread pools controlled by threadpoolctl at 16 threads.
# NOTE(review): the returned limiter object is not kept, so presumably the limit stays active for
# the rest of the session (threadpoolctl restores limits only on context exit /
# restore_original_limits) — confirm if this matters.
threadpool_limits(16)


#### Load LBAE objects

In [None]:
# Paths (relative to the working directory set in the first cell) to the MALDI data, its
# annotations and the shelve database file.
path_data = "data/whole_dataset/"
path_annotations = "data/annotations/"
path_db = "data/app_data/data.db"

# Load shelve database
storage = Storage(path_db)

# Load data
data = MaldiData(path_data, path_annotations)

# Load Atlas and Figures objects. At first launch, many objects will be precomputed and shelved in
# the classes Atlas and Figures.
# NOTE(review): the meaning of resolution=25 and sample=True is defined in modules.atlas /
# modules.figures (presumably atlas resolution in µm and a sampled dataset) — confirm there.
# Order matters: Figures receives the data, storage, atlas and scRNAseq objects built above.
atlas = Atlas(data, storage, resolution=25, sample = True)
scRNAseq = ScRNAseq()
figures = Figures(data, storage, atlas, scRNAseq, sample = True)

### Load cell type data

In [None]:
# Cell-type data per brain region: column 0 (region names, used as 'Regions' below) plus
# columns 4 to 9 of the data sheet.
table_cells = pd.read_csv(
    'notebooks/cell_type_atlas/data/Data_Sheet_2_A Cell Atlas for the Mouse Brain.csv',
    usecols=[0, *range(4, 10)],
)
table_cells

In [None]:
# Collect the leaves of the structure hierarchy, i.e. the structures whose collection of child
# ids holds a single id. A structure with no id at all is an error.
l_leafs_ids = []
l_leafs_acronyms = []
l_leafs_names = []

for acronym, children_ids in atlas.dic_acronym_children_id.items():
    n_children = len(children_ids)
    if n_children == 0:
        raise ValueError("No leaf found for structure: " + acronym)
    if n_children == 1:
        l_leafs_ids.extend(children_ids)
        l_leafs_acronyms.append(acronym)
        l_leafs_names.append(atlas.dic_acronym_name[acronym])

# Map each leaf name to its (single) region id
dic_names_id = dict(zip(l_leafs_names, l_leafs_ids))

In [None]:
# Restrict the cell table to regions that are leaves of the hierarchy, so that no two rows
# describe nested (overlapping) structures.
table_cells = table_cells.loc[table_cells['Regions'].isin(set(l_leafs_names))]
table_cells

#### Get an array of lipid expression for each region

In [None]:

def _lipid_positions_per_slice(df_annotations, name, structure, cation, l_slices):
    """Return, for each slice in ``l_slices``, the row position of a lipid in ``df_annotations``.

    The lipid is identified by its ``name``, ``structure`` and ``cation`` columns. When several
    rows match in one slice, a warning is logged and the last one is kept; when none matches,
    -1 is used. Row positions (not index labels) are returned so they can be used with ``iloc``
    whatever the index of ``df_annotations`` is.
    """
    mask_lipid = (
        (df_annotations["name"] == name)
        & (df_annotations["structure"] == structure)
        & (df_annotations["cation"] == cation)
    ).to_numpy()
    l_positions = []
    for slice_index in l_slices:
        positions = np.flatnonzero(mask_lipid & (df_annotations["slice"] == slice_index).to_numpy())
        # If several lipids correspond to the selection, we have a problem...
        if len(positions) > 1:
            logging.warning("More than one lipid corresponds to the selection")
        l_positions.append(int(positions[-1]) if len(positions) > 0 else -1)
    return l_positions


def _mean_non_negative(array):
    """Return the mean of the non-negative entries of ``array``, or NaN if there are none.

    NaN entries are ignored as well, since ``NaN >= 0`` is False.
    """
    values = np.asarray(array).ravel()
    values = values[values >= 0]
    return values.mean() if values.size > 0 else np.nan


def compute_array_exp_lipids(l_id_regions, brain_1=False, decrease_resolution_factor=5):
    """Compute the average expression of every annotated lipid in each of the given regions.

    Every (name, structure, cation) combination of the MAIA-transformed lipid annotations of the
    selected brain is processed. The lipid is located in each slice, and its 3D interpolated
    volume is computed with ``figures.compute_3D_volume_figure``, restricted to one region at a
    time. The mean of the non-negative voxels is then kept. Lipids absent from every slice are
    skipped.

    Args:
        l_id_regions (list): Region ids, passed one at a time as ``set_id_regions``.
        brain_1 (bool): If True use the slices of brain 1, otherwise those of brain 2.
        decrease_resolution_factor (int): Forwarded to ``compute_3D_volume_figure`` as
            ``decrease_dimensionality_factor``.

    Returns:
        np.ndarray, list(str): Array of shape (len(l_id_regions), n_lipids) holding the mean
            expression of each lipid in each region (NaN when a region has no non-negative
            voxel), and the lipid names formatted as "name structure cation".
    """
    maldi_data = figures._data
    # Fetch the annotation tables and slice list once instead of at every loop iteration.
    # NOTE(review): assumes these getters return the same tables on every call — confirm.
    df_transformed = maldi_data.get_annotations_MAIA_transformed_lipids(brain_1=brain_1)
    df_annotations = maldi_data.get_annotations()
    l_slices = maldi_data.get_slice_list(indices="brain_1" if brain_1 else "brain_2")

    ll_exp_lipids = []
    l_name_lipids = []
    # Simulate a click on all lipid names
    for name in sorted(df_transformed.name.unique()):
        df_name = df_transformed[df_transformed["name"] == name]
        for structure in sorted(df_name.structure.unique()):
            df_structure = df_name[df_name["structure"] == structure]
            for cation in sorted(df_structure.cation.unique()):
                l_selected_lipids = _lipid_positions_per_slice(
                    df_annotations, name, structure, cation, l_slices
                )

                # Skip lipids that are not present in any slice
                if all(index == -1 for index in l_selected_lipids):
                    continue

                # Get final lipid name
                lipid_string = name + " " + structure + " " + cation

                # mz boundaries for each slice: the bounds of the lipid (None when absent from
                # the slice), followed by two unused (None) slots
                lll_lipid_bounds = [
                    [
                        [
                            (
                                float(df_annotations.iloc[index]["min"]),
                                float(df_annotations.iloc[index]["max"]),
                            )
                        ]
                        if index != -1
                        else None
                        for index in [lipid_1_index, -1, -1]
                    ]
                    for lipid_1_index in l_selected_lipids
                ]
                print("getting data for lipid ", lipid_string)
                l_name_lipids.append(lipid_string)

                l_expr = []
                for id_region in l_id_regions:
                    interpolated_array = figures.compute_3D_volume_figure(
                        ll_t_bounds=lll_lipid_bounds,
                        name_lipid_1=lipid_string,
                        decrease_dimensionality_factor=decrease_resolution_factor,
                        return_interpolated_array=True,
                        structure_guided_interpolation=False,
                        set_id_regions={id_region},
                    )
                    l_expr.append(_mean_non_negative(interpolated_array))
                ll_exp_lipids.append(l_expr)

    return np.array(ll_exp_lipids).T, l_name_lipids



In [None]:
# Either load the per-region lipid expression of brain 2 from disk, or recompute it with
# compute_array_exp_lipids and save it. The "False" suffix in the file names is the value of
# brain_1.
load_from_save = True
l_name_regions = list(table_cells['Regions'])
l_name_cells = table_cells.columns[1:]
if load_from_save:
    array_exp_lipids_brain_2 = np.load('notebooks/cell_type_atlas/data/array_exp_lipids_False.npy')
    array_name_lipids_False = np.load('notebooks/cell_type_atlas/data/array_name_lipids_False.npy')
    l_name_lipids_brain_2 = array_name_lipids_False.tolist()
else:
    array_exp_lipids_brain_2, l_name_lipids_brain_2 = compute_array_exp_lipids(
        l_id_regions=[dic_names_id[region] for region in l_name_regions],
        brain_1=False,
    )
    # Save the expression array and the corresponding lipid names for brain 2
    np.save('notebooks/cell_type_atlas/data/array_exp_lipids_False.npy', array_exp_lipids_brain_2)
    np.save(
        'notebooks/cell_type_atlas/data/array_name_lipids_False.npy',
        np.array(l_name_lipids_brain_2),
    )

### Get cell type data as an array

In [None]:
# Cell-type values as an array with one row per region of table_cells and one column per
# cell type (every column except the first, which presumably holds the 'Regions' names).
array_exp_cells = table_cells.iloc[:,1:].to_numpy()
array_exp_cells

In [None]:
# NaN entries (e.g. a region with no non-negative voxel for a lipid) are replaced by 0, and
# infinities by large finite values (np.nan_to_num defaults), so the regressions below can run.
array_exp_lipids_brain_2 = np.nan_to_num(array_exp_lipids_brain_2)

##### Fit regressions (currently random forests) between lipid expression and cell-type data, in both directions

In [None]:
from sklearn.ensemble import RandomForestRegressor

In [None]:
def compute_regression_all_lipids(array_exp_lipids, array_exp_cells, regressor_factory=None):
    """Fit one regression per target column and collect its coefficients and score.

    Each column of ``array_exp_lipids`` is regressed on all columns of ``array_exp_cells``.
    Despite the parameter names, the function is symmetric: the notebook also calls it with
    the two arrays swapped to explain cell types with lipids.

    Args:
        array_exp_lipids (np.ndarray): Targets, one sample per row (regions in this notebook);
            one regression is fitted per column.
        array_exp_cells (np.ndarray): Predictors, with the same rows as ``array_exp_lipids``.
        regressor_factory (callable, optional): Zero-argument callable returning an unfitted
            scikit-learn-style regressor exposing ``fit`` and ``score``. Defaults to
            ``RandomForestRegressor`` with default settings.

    Returns:
        np.ndarray, list(float): One row of coefficients per target. This is the fitted model's
            ``coef_`` if it has one, otherwise its ``feature_importances_``, or None if it has
            neither. Also returns the list of R² scores. The scores are evaluated on the
            training data itself (in-sample, not cross-validated), so they are optimistic for
            flexible models such as random forests.
    """
    if regressor_factory is None:
        # Resolved at call time; RandomForestRegressor is imported in an earlier cell
        regressor_factory = RandomForestRegressor

    def compute_regression(index_target):
        y = array_exp_lipids[:, index_target]
        clf = regressor_factory()
        clf.fit(array_exp_cells, y)
        # Linear models expose coef_; tree ensembles only expose feature_importances_
        coef = getattr(clf, "coef_", None)
        if coef is None:
            coef = getattr(clf, "feature_importances_", None)
        return coef, clf.score(array_exp_cells, y)

    # Compute regression for all targets (kept as a function for potential parallelization)
    l_res = [compute_regression(index) for index in range(array_exp_lipids.shape[1])]

    ll_coef = [coef for coef, _ in l_res]
    l_score = [score for _, score in l_res]
    return np.array(ll_coef), l_score

In [None]:
# Fit the regressions in both directions with compute_regression_all_lipids (a random forest as
# currently configured there, not a LASSO): lipids explained by cell types, then cell types
# explained by lipids.
array_coef_brain_2, l_score_brain_2 = compute_regression_all_lipids(array_exp_lipids_brain_2, array_exp_cells)
array_coef_brain_2_reversed, l_score_brain_2_reversed = compute_regression_all_lipids(array_exp_cells, array_exp_lipids_brain_2)

#### Save the regression coefficients and scores

In [None]:

# Persist coefficients and scores of both regressions: lipids explained by cell types, then the
# reversed direction (cell types explained by lipids).
np.save('notebooks/cell_type_atlas/data/array_coef.npy', array_coef_brain_2)
np.save('notebooks/cell_type_atlas/data/array_score.npy', np.array(l_score_brain_2))

np.save('notebooks/cell_type_atlas/data/array_coef_reversed.npy', array_coef_brain_2_reversed)
np.save(
    'notebooks/cell_type_atlas/data/array_score_reversed.npy',
    np.array(l_score_brain_2_reversed),
)

In [None]:
# R² scores (computed on the training data) of the regressions explaining cell types with lipids
print(l_score_brain_2_reversed)

In [None]:
# R² scores (computed on the training data) of the regressions explaining lipids with cell types
print(l_score_brain_2)

#### Plot the regression results

In [None]:
from scipy.spatial.distance import pdist


def _cluster_order(matrix):
    """Return an order of the rows of ``matrix`` that groups them by hierarchical cluster.

    Rows are clustered with Ward linkage on their pairwise Euclidean distances. The tree is cut
    at half the largest pairwise distance, and the rows are sorted by cluster label. The sort is
    stable, so rows of the same cluster keep their original relative order.

    Raises:
        ValueError: If ``matrix`` is not 2-dimensional (e.g. an array of None coefficients).
    """
    matrix = np.asarray(matrix)
    if matrix.ndim != 2:
        raise ValueError(
            "Expected a 2D array of regression factors, got shape " + str(matrix.shape)
        )
    if matrix.shape[0] < 2:
        # linkage needs at least two observations; there is nothing to reorder
        return np.arange(matrix.shape[0])
    pairwise_distances = pdist(matrix)
    linkage = sch.linkage(pairwise_distances, method='ward')
    cluster_distance_threshold = pairwise_distances.max() / 2
    idx_to_cluster_array = sch.fcluster(linkage, cluster_distance_threshold, criterion='distance')
    return np.argsort(idx_to_cluster_array, kind='stable')


# NOTE: this cell reorders l_name_cells, l_score_brain_2_reversed and l_name_lipids_brain_2 in
# place. Re-run the cells that create them before re-running this one, otherwise the labels get
# out of sync with the factors.

# Cluster rows (cell types)
idx_rows = _cluster_order(array_coef_brain_2_reversed)
corr_array_clustered = np.asarray(array_coef_brain_2_reversed)[idx_rows, :]
l_name_cells = np.array(l_name_cells)[idx_rows]
l_score_brain_2_reversed = np.array(l_score_brain_2_reversed)[idx_rows]

# Cluster columns (lipids)
idx_cols = _cluster_order(corr_array_clustered.T)
corr_array_clustered = corr_array_clustered[:, idx_cols]
l_name_lipids_brain_2 = np.array(l_name_lipids_brain_2)[idx_cols]



In [None]:
# Left: clustered regression factors (rows: cell types, columns: lipids).
# Right: R² score of the regression of each cell type.
fig = make_subplots(rows=1, cols=2, shared_yaxes=True, horizontal_spacing=0.0)

# Symmetric range for the diverging RdBu scale, based on the largest absolute factor so that
# strongly negative factors are treated like strongly positive ones
zmax_coef = np.nanmax(np.abs(corr_array_clustered)) / 5
g1 = go.Heatmap(
    z=corr_array_clustered,
    colorscale='RdBu',
    x=l_name_lipids_brain_2,
    y=l_name_cells,
    zmin=-zmax_coef,
    zmax=zmax_coef,
    colorbar=dict(title="Coef value", x=1.1),
)
g2 = go.Heatmap(
    z=np.array([l_score_brain_2_reversed]).T,
    colorscale='mint',
    y=l_name_cells,
    zmin=0.0,
    zmax=0.6,
    colorbar=dict(title="R2 score", x=1.05),
)

# add_trace(row=..., col=...) replaces the deprecated append_trace
fig.add_trace(g1, row=1, col=1)
fig.add_trace(g2, row=1, col=2)

# Axis labels match the data: x holds lipid names and y holds cell-type names
fig.update_xaxes(title_text="Lipids", domain=[0, 0.98], row=1, col=1)
fig.update_yaxes(title_text="Cell types", scaleanchor='x', row=1, col=1)

# The narrow score column has no tick labels of its own
fig.update_xaxes(side='top', domain=[0.99, 1], showticklabels=False, row=1, col=2)
fig.update_yaxes(scaleanchor='x2', showticklabels=False, row=1, col=2)
fig.update_xaxes(tickangle=45)

fig.update_layout(
    width=1520,
    height=350,
    font_size=7,
    title_font_size=12,
    title={
        'text': 'Regression factors (Cell types explained with lipids as predictors)',
        'y': 0.92,
        'x': 0.5,
        'xanchor': 'center',
        'yanchor': 'top'},
    plot_bgcolor='rgba(0,0,0,0)',
)

fig.show()
# fig.write_image("notebooks/cell_type_atlas/output/heatmap.pdf")

In [None]:
# Same figure as the previous cell, but without the last cell type (in clustered order), and
# with the colour range rescaled to the remaining rows.
# NOTE(review): presumably the last row is dropped because it dominates the colour scale; confirm.
fig = make_subplots(rows=1, cols=2, shared_yaxes=True, horizontal_spacing=0.0)

# Symmetric range for the diverging RdBu scale, based on the largest absolute factor
maxi = np.nanmax(np.abs(corr_array_clustered[:-1, :])) * 0.9
g1 = go.Heatmap(
    z=corr_array_clustered[:-1, :],
    colorscale='RdBu',
    x=l_name_lipids_brain_2,
    y=l_name_cells[:-1],
    zmin=-maxi,
    zmax=maxi,
    colorbar=dict(title="Coef value", x=1.1),
)
g2 = go.Heatmap(
    z=np.array([l_score_brain_2_reversed[:-1]]).T,
    colorscale='mint',
    y=l_name_cells[:-1],
    zmin=0.0,
    zmax=0.6,
    colorbar=dict(title="R2 score", x=1.05),
)

# add_trace(row=..., col=...) replaces the deprecated append_trace
fig.add_trace(g1, row=1, col=1)
fig.add_trace(g2, row=1, col=2)

# Axis labels match the data: x holds lipid names and y holds cell-type names
fig.update_xaxes(title_text="Lipids", domain=[0, 0.98], row=1, col=1)
fig.update_yaxes(title_text="Cell types", scaleanchor='x', row=1, col=1)

# The narrow score column has no tick labels of its own
fig.update_xaxes(side='top', domain=[0.99, 1], showticklabels=False, row=1, col=2)
fig.update_yaxes(scaleanchor='x2', showticklabels=False, row=1, col=2)
fig.update_xaxes(tickangle=45)

fig.update_layout(
    width=1520,
    height=350,
    font_size=7,
    title_font_size=12,
    title={
        'text': 'Regression factors (Cell types explained with lipids as predictors)',
        'y': 0.92,
        'x': 0.5,
        'xanchor': 'center',
        'yanchor': 'top'},
    plot_bgcolor='rgba(0,0,0,0)',
)

fig.show()

### Try alternative regressions to explain lipids with cell types