# Notebook to analyze and display scRNAseq data


### Load important modules

In [None]:
# Standard modules
import numpy as np
import os
import pandas as pd 
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import logging
from numba import njit
from sklearn import linear_model


# Move to root directory for easier module handling. The check makes the cell safe to re-run:
# once the working directory contains the `modules` package imported below, it is left as is
# instead of moving two more levels up.
if not os.path.isdir("modules"):
    os.chdir("../..")
print(os.listdir("."))

#LBAE imports
from modules.maldi_data import MaldiData
from modules.figures import Figures
from modules.atlas import Atlas
from modules.storage import Storage


# multithreading/multiprocessing
from multiprocessing import Pool
import multiprocessing
from threadpoolctl import threadpool_limits

# Cap the native thread pools managed by threadpoolctl to 16 threads
threadpool_limits(limits=16)


#### Load LBAE objects

In [None]:
# Location of the MALDI data, of its annotations, and of the shelve database
path_data, path_annotations, path_db = (
    "data/whole_dataset/",
    "data/annotations/",
    "data/app_data/data.db",
)

# Shelve database, shared by all the objects below
storage = Storage(path_db)

# MALDI data together with its annotations
data = MaldiData(path_data, path_annotations)

# Atlas and Figures objects. At first launch, many objects are precomputed and shelved in
# these classes.
atlas = Atlas(data, storage, resolution=25, sample=True)
figures = Figures(data, storage, atlas, sample=True)

### Load data

In [None]:
# Spot metadata: only columns 0, 2, 3, 4 and 7 are read, the first one being used as index
table_meta = pd.read_csv(
    'notebooks/scRNAseq/data/meta_table.tsv',
    sep='\t',
    index_col=0,
    usecols=[0, 2, 3, 4, 7],
)
# Normalized gene expression, genes as columns.
# NOTE(review): its row index is matched against the spot index of table_meta below — confirm
# that pandas infers the spot names as index for this file.
table_exp_genes = pd.read_csv('notebooks/scRNAseq/data/expr_normalized_table.tsv', sep='\t')


In [None]:
# Zero expression is treated as missing: replace 0 by nan
table_exp_genes = table_exp_genes.replace(0, np.nan)
table_exp_genes

In [None]:
# Reorganize columns
table_meta = table_meta.loc[:, ['stereo_AP', 'stereo_DV', 'stereo_ML', 'ABA_acronym']]  # coordinates first, then acronym

In [None]:
# Keep only the spots that are present in both tables
# Keep only the spots that are present in both tables. The expression table must be filtered
# against the (already filtered) metadata index, not against its own index, which would keep
# every row.
table_meta = table_meta.loc[table_meta.index.isin(table_exp_genes.index)]
table_exp_genes = table_exp_genes.loc[table_exp_genes.index.isin(table_meta.index)]

In [None]:
# Ensure that the table are sorted by increasing spots
table_meta.sort_index(inplace = True)
table_exp_genes.sort_index(inplace = True)

# Get spots and genes names
l_spots = list(table_meta.index)
l_genes = list(table_exp_genes.columns)

In [None]:
counts = table_meta.ABA_acronym.value_counts()  # number of spots annotated with each acronym

In [None]:
# Do a linear regression structure-wise to get coordinates of the molecular atlas in the ccfv3.
# Landmarks are the structures annotated on exactly one spot: the spot coordinates (molecular
# atlas system) are paired with the centroid of the structure in the reference annotation volume
# (voxel indices * 25 / 1000). Along each axis, a landmark is kept only if the structure is compact
# along that axis (std below 0.12), so that its centroid is a reliable reference point.
l_x, l_y, l_z = [], [], []
l_xs, l_ys, l_zs = [], [], []
for acronym in counts[counts == 1].index.to_list():
    try:
        id_structure = atlas.bg_atlas.structures[acronym]['id']
    except KeyError:
        # Acronym unknown to the reference atlas: it cannot be used as a landmark
        logging.warning("Structure %s not found in the reference atlas, skipped", acronym)
        continue
    array_coor = np.where(atlas.bg_atlas.annotation == id_structure)
    if array_coor[0].size == 0:
        # Structure absent from the annotation volume: no centroid can be computed
        continue
    xs, ys, zs = np.mean(array_coor, axis=1) * 25 / 1000
    std_xs, std_ys, std_zs = np.std(array_coor, axis=1) * 25 / 1000
    x, y, z = table_meta[table_meta['ABA_acronym'] == acronym].iloc[0, 0:3].to_numpy()
    if std_xs < 0.12:
        l_x.append(x)
        l_xs.append(xs)
    if std_ys < 0.12:
        l_y.append(y)
        l_ys.append(ys)
    if std_zs < 0.12:
        l_z.append(z)
        l_zs.append(zs)

# Single slope `a` shared by the three axes, one intercept per axis (b: AP, c: DV, d: ML).
# Each row of the design matrix is [coordinate, is_AP, is_DV, is_ML].
X = np.vstack([np.array(l_x), np.ones(len(l_x)), np.zeros(len(l_x)), np.zeros(len(l_x))]).T
Y = np.vstack([np.array(l_y), np.zeros(len(l_y)), np.ones(len(l_y)), np.zeros(len(l_y))]).T
Z = np.vstack([np.array(l_z), np.zeros(len(l_z)), np.zeros(len(l_z)), np.ones(len(l_z))]).T
M = np.vstack((X, Y, Z))
y = np.array(l_xs + l_ys + l_zs)

a, b, c, d = np.linalg.lstsq(M, y, rcond=None)[0]


In [None]:
# Check the quality of the fit: landmarks (points) and fitted lines, one color per axis, with
# distinct legend entries so that the three axes can be told apart
for axis_name, l_coor_mol, l_coor_ccfv3, intercept in (
    ("AP", l_x, l_xs, b),
    ("DV", l_y, l_ys, c),
    ("ML", l_z, l_zs, d),
):
    (points,) = plt.plot(
        np.array(l_coor_mol), np.array(l_coor_ccfv3), 'o', markersize=10, label=axis_name + ' landmarks'
    )
    # Sort x so that the fitted line is drawn once from left to right
    coor_mol_sorted = np.sort(np.array(l_coor_mol))
    plt.plot(
        coor_mol_sorted, a * coor_mol_sorted + intercept, color=points.get_color(), label=axis_name + ' fit'
    )
plt.xlabel('Molecular atlas coordinate')
plt.ylabel('ccfv3 coordinate')
plt.legend()
plt.show()

In [None]:
# Map each molecular atlas axis to the ccfv3 with the shared slope and its own intercept
for column, intercept in (("stereo_AP", b), ("stereo_DV", c), ("stereo_ML", d)):
    table_meta[column] = a * table_meta[column] + intercept
table_meta

#### Plot the points from the molecular atlas in 3D in our reference brain

In [None]:
# Coordinates of the molecular atlas spots (AP, DV, ML)
x_mol, y_mol, z_mol = (
    table_meta[column].to_numpy() for column in ('stereo_AP', 'stereo_DV', 'stereo_ML')
)

# 3D scatter trace of the scRNAseq spots
spot_marker = dict(size=2, opacity=0.8)
scatter = go.Scatter3d(x=x_mol, y=y_mol, z=z_mol, mode='markers', marker=spot_marker)

In [None]:
# Root volume trace, retrieved from the shelve storage (computed with
# figures.compute_3D_root_volume when not available, presumably)
root_data = figures._storage.return_shelved_object(
    "figures/3D_page", "volume_root", force_update=False, compute_function=figures.compute_3D_root_volume
)

# Only the scRNAseq spots display hover information
root_data['hoverinfo'] = 'skip'
scatter['hoverinfo'] = "all"

In [None]:
# Root volume and scRNAseq spots in a single 3D figure
fig = go.Figure(data=[root_data, scatter])

# Dark template with transparent paper, plot and 3D-axes backgrounds (hides the grey background)
transparent = "rgba(0,0,0,0)"
fig.update_layout(
    template="plotly_dark",
    margin=dict(t=0, r=0, b=0, l=0),
    scene={axis: dict(backgroundcolor=transparent) for axis in ("xaxis", "yaxis", "zaxis")},
    plot_bgcolor=transparent,
    paper_bgcolor=transparent,
)

fig.show()

#### Get the expression of each lipid at each scRNAseq spot

In [None]:
@njit
def find_index_coordinate(x, y, z, array_x, array_y, array_z, array_c):
    """Return, for each query point, the value of its nearest reference point.

    Brute-force nearest-neighbour search using the squared Euclidean distance. On ties, the
    first reference point encountered wins. The query coordinates are taken from the x, y, z
    arguments: the previous version ignored them and silently used the globals x_mol, y_mol,
    z_mol, which numba freezes at compile time.

    Args:
        x, y, z: coordinates of the query points (same length).
        array_x, array_y, array_z: coordinates of the reference points (same length).
        array_c: value associated with each reference point.

    Returns:
        np.ndarray: uint8 array of len(x) values, array_c of the nearest reference point.
    """
    array_expression = np.zeros((len(x),), dtype=np.uint8)
    for idx_query in range(len(x)):
        min_idx = -1
        min_dist = np.inf
        for idx_ref in range(len(array_x)):
            dist = (
                (x[idx_query] - array_x[idx_ref]) ** 2
                + (y[idx_query] - array_y[idx_ref]) ** 2
                + (z[idx_query] - array_z[idx_ref]) ** 2
            )
            if dist < min_dist:
                min_dist = dist
                min_idx = idx_ref
        array_expression[idx_query] = array_c[min_idx]
    return array_expression
        
@njit
def fill_array_coordinates(array_coordinates, array_x, array_y, array_z, array_c, scaling):
    """Write each value of array_c at its rescaled, rounded voxel position in array_coordinates.

    Points are processed in order, so a later point overwrites an earlier one that falls in the
    same voxel. array_coordinates is modified in place and also returned.
    """
    n_points = min(len(array_x), len(array_y), len(array_z), len(array_c))
    for i in range(n_points):
        voxel_x = int(round(array_x[i] * scaling))
        voxel_y = int(round(array_y[i] * scaling))
        voxel_z = int(round(array_z[i] * scaling))
        array_coordinates[voxel_x, voxel_y, voxel_z] = array_c[i]
    return array_coordinates
                
    
    
@njit
def return_lipid_expr(x_mol, y_mol, z_mol, array_coordinates, scaling, radius = 0.1):
    """Return, for each query point, the closest non-NaN value of a voxelized volume.

    Each query point is mapped to the voxel grid by multiplying its coordinates by `scaling`
    and rounding. The cube of half-width int(scaling * radius) voxels around it is then searched
    for the closest (Euclidean distance in voxels) non-NaN voxel of `array_coordinates`. On ties,
    the first voxel encountered wins.

    Args:
        x_mol, y_mol, z_mol: coordinates of the query points (same length).
        array_coordinates: 3D array of values, NaN marking voxels without a value.
        scaling: factor converting coordinates into voxel indices.
        radius: half-width of the search cube, in coordinate units (converted to voxels).

    Returns:
        np.ndarray: uint8 array with one value per query point. Points with no non-NaN voxel in
            their cube keep the value 0, and their index is printed. Values are assumed
            non-negative: a negative voxel value is treated like a missing one.
    """
    # Half-width of the search cube, in voxels
    radius = int(scaling * radius)
    array_expression = np.zeros((len(x_mol),), dtype = np.uint8)
    idx = 0
    for x, y, z in zip(x_mol, y_mol, z_mol):
        # Voxel containing the query point
        x = int(round(x * scaling))
        y = int(round(y * scaling))
        z = int(round(z * scaling))
        # Look for the closest non-NaN voxel in a cube of half-width `radius` voxels
        range_cube = range(-radius,radius+1, 1)
        min_dist = 100000000
        min_expr = -1
        for delta_x in range_cube:
            coor_x = x + delta_x
            # Skip the parts of the cube that fall outside of the volume
            if coor_x>=0 and coor_x < array_coordinates.shape[0]:
                for delta_y in range_cube:
                    coor_y = y + delta_y
                    if coor_y>=0 and coor_y < array_coordinates.shape[1]:
                        for delta_z in range_cube:
                            coor_z = z + delta_z
                            if coor_z >=0 and coor_z < array_coordinates.shape[2]:
                                expr = array_coordinates[coor_x, coor_y, coor_z]
                                if not np.isnan(expr):
                                    dist = np.sqrt(delta_x**2 + delta_y**2 + delta_z**2)
                                    if dist < min_dist:
                                        min_dist = dist
                                        min_expr = int(expr)
        if min_expr>=0:
            array_expression[idx] = min_expr
        else:
            # No value found around this point: it keeps 0, report its index
            print(idx)
        idx+=1
    return array_expression
                            
                    


In [None]:

def compute_array_exp_lipids(brain_1 = False, decrease_resolution_factor = 5, method =  'full'):
    """Sample the expression of every annotated lipid at the scRNAseq spots.

    For every (name, structure, cation) lipid annotated in the selected brain, the m/z
    boundaries of the lipid are gathered for each slice. The corresponding 3D expression is
    computed with ``figures.compute_3D_volume_figure`` and read at the coordinates of the
    molecular atlas spots (module-level ``x_mol``, ``y_mol``, ``z_mol``).

    Args:
        brain_1 (bool, optional): If True, use the annotations and slices of brain 1, otherwise
            those of brain 2. Defaults to False.
        decrease_resolution_factor (int, optional): Downsampling factor of the interpolated
            volume. Only used when ``method`` is not 'NN'. Defaults to 5.
        method (str, optional): 'NN' to take, for each spot, the value of the closest acquired
            pixel. Any other value (default 'full') samples the interpolated 3D volume at the
            voxel containing each spot.

    Returns:
        np.ndarray, list(str): Array of shape (n_spots, n_lipids) with the expression of each
            lipid at each spot, and the lipid names (one per column). With the interpolated
            method, lipids whose volume cannot be computed are skipped with a warning, so names
            and columns always stay aligned.
    """
    ll_exp_lipids = []
    l_name_lipids = []

    # These tables do not depend on the lipid: fetch them once rather than at every comparison
    df_lipids = figures._data.get_annotations_MAIA_transformed_lipids(brain_1=brain_1)
    df_annotations = figures._data.get_annotations()
    l_slices = figures._data.get_slice_list(indices="brain_1" if brain_1 else "brain_2")

    # Simulate a click on all lipid names
    for name in sorted(df_lipids.name.unique()):
        df_name = df_lipids[df_lipids["name"] == name]
        for structure in sorted(df_name.structure.unique()):
            cations = df_name[df_name["structure"] == structure].cation.unique()
            for cation in sorted(cations):
                # Annotation index of the lipid in each slice (-1 when absent from the slice)
                l_selected_lipids = []
                for slice_index in l_slices:
                    l_lipid_loc = df_annotations.index[
                        (df_annotations["name"] == name)
                        & (df_annotations["structure"] == structure)
                        & (df_annotations["slice"] == slice_index)
                        & (df_annotations["cation"] == cation)
                    ].tolist()

                    # If several lipids correspond to the selection, we have a problem...
                    if len(l_lipid_loc) > 1:
                        logging.warning("More than one lipid corresponds to the selection")
                    # Keep the last match, or -1 if no lipid corresponds to the selection
                    l_selected_lipids.append(l_lipid_loc[-1] if l_lipid_loc else -1)

                # Get final lipid name
                lipid_string = name + " " + structure + " " + cation

                # Skip lipids that are absent from every slice
                if all(index == -1 for index in l_selected_lipids):
                    continue

                # Build the list of mz boundaries for each slice: only the first of the three
                # lipid channels is used, the two others are left empty (None).
                # NOTE(review): `index` is an index label used positionally through `iloc`; this
                # assumes the annotation table has a default RangeIndex — TODO confirm.
                lll_lipid_bounds = [
                    [
                        [
                            (
                                float(df_annotations.iloc[index]["min"]),
                                float(df_annotations.iloc[index]["max"]),
                            )
                        ]
                        if index != -1
                        else None
                        for index in [lipid_1_index, -1, -1]
                    ]
                    for lipid_1_index in l_selected_lipids
                ]
                print("getting data for lipid ", lipid_string)

                if method == 'NN':
                    array_x, array_y, array_z, array_c = figures.compute_3D_volume_figure(
                        ll_t_bounds=lll_lipid_bounds,
                        name_lipid_1=lipid_string,
                        return_individual_slice_data=True,
                    )
                    # Switch and rescale coordinates system to match both systems
                    array_x, array_y, array_z = array_y * 1000, array_z * 1000, array_x * 1000
                    l_expr = find_index_coordinate(
                        x_mol, y_mol, z_mol, array_x, array_y, array_z, np.array(array_c)
                    )
                else:
                    try:
                        interpolated_array = figures.compute_3D_volume_figure(
                            ll_t_bounds=lll_lipid_bounds,
                            name_lipid_1=lipid_string,
                            decrease_dimensionality_factor=decrease_resolution_factor,
                            return_interpolated_array=True,
                            structure_guided_interpolation=False,
                        )
                    except Exception as e:
                        # Skip the lipid rather than reading the stale volume of a previous lipid
                        # (or failing with a NameError on the first one)
                        logging.warning("Skipping lipid %s: %s", lipid_string, e)
                        continue

                    # Convert the spots of the molecular atlas into voxel indices of the
                    # interpolated array (coordinates * 1000 / 25, then downsampled)
                    array_voxels = np.round(
                        np.array([x_mol, y_mol, z_mol]) * 1000 / 25 / decrease_resolution_factor
                    ).astype(np.int32)
                    l_expr = interpolated_array[array_voxels[0], array_voxels[1], array_voxels[2]]

                # Record the name only once its expression is available, to keep both aligned
                l_name_lipids.append(lipid_string)
                ll_exp_lipids.append(l_expr)

    return np.array(ll_exp_lipids).T, l_name_lipids



In [None]:
# Sample all lipids at the scRNAseq spots, for brain 1 first and then for brain 2
results_brain_1, results_brain_2 = [
    compute_array_exp_lipids(brain_1=is_brain_1) for is_brain_1 in (True, False)
]
array_exp_lipids_brain_1, l_name_lipids_brain_1 = results_brain_1
array_exp_lipids_brain_2, l_name_lipids_brain_2 = results_brain_2

### Save lipid data in numpy arrays

In [None]:
# The output directory is not guaranteed to exist: create it if needed
os.makedirs('data/scRNAseq', exist_ok=True)

# Save the array of lipid expression and the corresponding names, for brain 1 then brain 2
for path, array_to_save in (
    ('data/scRNAseq/array_exp_lipids_True.npy', array_exp_lipids_brain_1),
    ('data/scRNAseq/array_name_lipids_True.npy', np.array(l_name_lipids_brain_1)),
    ('data/scRNAseq/array_exp_lipids_False.npy', array_exp_lipids_brain_2),
    ('data/scRNAseq/array_name_lipids_False.npy', np.array(l_name_lipids_brain_2)),
):
    with open(path, 'wb') as f:
        np.save(f, array_to_save)

##### Fit an ElasticNet (L1 + L2) regression to explain lipid expression in terms of gene expression (the L1 penalty removes as many genes as possible)

In [None]:
def compute_regression_all_lipids(array_exp_lipids, array_exp_genes, alpha=0.2, l1_ratio=0.5):
    """Fit, for each lipid, an ElasticNet regression explaining its expression from gene expression.

    The model is an ElasticNet (mix of L1 and L2 penalties). The L1 part sets the coefficients
    of uninformative genes to zero; with ``l1_ratio=1`` the model reduces to a LASSO.

    Args:
        array_exp_lipids (np.ndarray): Lipid expression, one column per lipid, one row per spot.
        array_exp_genes (np.ndarray): Gene expression, one column per gene, same rows.
        alpha (float, optional): Overall regularization strength. Defaults to 0.2.
        l1_ratio (float, optional): ElasticNet mixing parameter (1 = pure L1). Defaults to 0.5.

    Returns:
        np.ndarray, list(float): Coefficients of shape (n_lipids, n_genes), and the R^2 score of
            each regression computed on the data it was fitted on.
    """
    # Regression for a single lipid, kept as a function for potential parallelization
    def compute_regression(index_lipid):
        clf = linear_model.ElasticNet(
            fit_intercept=True, alpha=alpha, l1_ratio=l1_ratio, positive=False
        )
        target = array_exp_lipids[:, index_lipid]
        clf.fit(array_exp_genes, target)
        return clf.coef_, clf.score(array_exp_genes, target)

    # Compute regression for all lipids
    l_res = [compute_regression(index_lipid) for index_lipid in range(array_exp_lipids.shape[1])]

    ll_coef = [coef for coef, _ in l_res]
    l_score = [score for _, score in l_res]
    return np.array(ll_coef), l_score

In [None]:
# Undo the log transform of the expression. The missing values (NaN, the zeros replaced
# earlier) become 0 again.
array_exp_genes = np.nan_to_num(np.exp(table_exp_genes.to_numpy()), nan=0)

# Regression of every lipid on the genes, for each brain
array_coef_brain_1, l_score_brain_1 = compute_regression_all_lipids(
    array_exp_lipids_brain_1, array_exp_genes
)
array_coef_brain_2, l_score_brain_2 = compute_regression_all_lipids(
    array_exp_lipids_brain_2, array_exp_genes
)

In [None]:
# Keep only the genes used by many lipids (nonzero coefficient for more than `threshold` lipids)
def filter_genes(array_coef, array_exp_genes, threshold = 15, gene_names = None):
    """Keep the genes whose coefficient is nonzero for more than `threshold` lipids.

    Args:
        array_coef (np.ndarray): Regression coefficients, one row per lipid, one column per gene.
        array_exp_genes (np.ndarray): Gene expression, one column per gene.
        threshold (int, optional): A gene is kept if strictly more than `threshold` lipids have a
            nonzero coefficient for it. Defaults to 15.
        gene_names (list(str), optional): Name of each gene (column). Defaults to the notebook's
            global `l_genes`.

    Returns:
        np.ndarray, np.ndarray, np.ndarray: The filtered expression, coefficients and gene names.
    """
    if gene_names is None:
        gene_names = l_genes
    # Number of lipids for which each gene has a nonzero coefficient
    n_lipids_per_gene = np.sum(np.abs(array_coef) > 0, axis=0)
    l_to_keep = np.flatnonzero(n_lipids_per_gene > threshold)
    return array_exp_genes[:, l_to_keep], array_coef[:, l_to_keep], np.asarray(gene_names)[l_to_keep]


# Filter the genes for each brain separately (genes used by more than 15 lipids)
(array_exp_genes_brain_1, array_coef_brain_1, array_name_genes_brain_1) = filter_genes(
    array_coef_brain_1, array_exp_genes, threshold=15
)
(array_exp_genes_brain_2, array_coef_brain_2, array_name_genes_brain_2) = filter_genes(
    array_coef_brain_2, array_exp_genes, threshold=15
)

#### Save the filtered data from the molecular atlas

In [None]:
# For each brain, save the filtered expression, gene names, coefficients and regression scores
for brain_1, arrays_to_save in (
    (True, (array_exp_genes_brain_1, array_name_genes_brain_1, array_coef_brain_1, np.array(l_score_brain_1))),
    (False, (array_exp_genes_brain_2, array_name_genes_brain_2, array_coef_brain_2, np.array(l_score_brain_2))),
):
    path_prefixes = (
        'data/scRNAseq/array_exp_genes_',
        'data/scRNAseq/array_name_genes_',
        'data/scRNAseq/array_coef_',
        'data/scRNAseq/array_score_',
    )
    for path_prefix, array_to_save in zip(path_prefixes, arrays_to_save):
        with open(path_prefix + str(brain_1) + '.npy', 'wb') as f:
            np.save(f, array_to_save)

### Save the coordinates of the scRNAseq spots

In [None]:
# Coordinates of the spots stacked as rows (x_mol, y_mol, z_mol): shape (3, n_spots)
array_coordinates_spots = np.array([x_mol, y_mol, z_mol])
with open('data/scRNAseq/array_coordinates.npy', 'wb') as f:
    np.save(f, array_coordinates_spots)