# Notebook to analyze and display scRNAseq data


### Load required modules

In [None]:
# Standard modules
import numpy as np
import os
import pandas as pd 
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from scipy import linalg

# Move to the repository root (this notebook lives two levels down, in notebooks/scRNAseq/) so that
# the local `modules` package and the relative data/ paths resolve. The check keeps the cell safe
# to re-run: without it, every re-execution would climb two more directories.
if not os.path.isdir("modules"):
    os.chdir("../..")
print(os.listdir("."))

#LBAE imports
from modules.maldi_data import MaldiData
from modules.figures import Figures
from modules.atlas import Atlas
from modules.launch import Launch
from modules.storage import Storage


# multithreading/multiprocessing
from multiprocessing import Pool
from threadpoolctl import threadpool_limits

# Cap the native thread pools (BLAS/OpenMP) at 16 threads. Called as a plain function rather than
# in a `with` block, so the limit stays in effect for the rest of the session.
threadpool_limits(limits=16)


#### Load LBAE objects

In [None]:
# Paths are relative to the repository root (the working directory is changed in the first cell).
path_data = "data/whole_dataset/"
path_annotations = "data/annotations/"
path_db = "data/app_data/data.db"

# Open the shelve database in which precomputed LBAE objects are cached
storage = Storage(path_db)

# Load the MALDI data and its annotations
data = MaldiData(path_data, path_annotations)

# If True, only a small portion of the figures is precomputed (if precomputation has not already
# been done). Used for debugging purposes.
# NOTE(review): `sample` is not passed to Atlas or Figures below, so as written it has no effect
# in this notebook — confirm whether it should be forwarded.
sample = False

# Load Atlas and Figures objects. At first launch, many objects will be precomputed and shelved in
# the classes Atlas and Figures.
# NOTE(review): resolution=25 presumably selects the 25 um atlas volume — TODO confirm in Atlas.
atlas = Atlas(data, storage, resolution=25)
figures = Figures(data, storage, atlas)

### Load scRNAseq data

In [None]:
# Normalized expression table, plus the per-spot metadata restricted to columns 0 (unnamed),
# 2, 3, 4 and 7; read_table is read_csv with a tab separator by default.
expr_table = pd.read_table("notebooks/scRNAseq/data/expr_normalized_table.tsv")
meta_table = pd.read_table(
    "notebooks/scRNAseq/data/meta_table.tsv",
    usecols=[0, 2, 3, 4, 7],
)

In [None]:
# Reorder columns to: identifier, AP, DV, ML coordinates, ABA acronym. The registration cell
# relies on positions 1:4 being (stereo_AP, stereo_DV, stereo_ML).
# The unnamed first column is renamed to "spot_id" through `columns=`: passing the mapping as
# `index=` would relabel row 0 instead of the column. Chaining `.rename` (rather than an in-place
# rename on the column selection) also avoids mutating a slice of the original frame.
meta_table = meta_table[
    ["Unnamed: 0", "stereo_AP", "stereo_DV", "stereo_ML", "ABA_acronym"]
].rename(columns={"Unnamed: 0": "spot_id"})

In [None]:
# Number of spots per ABA acronym (value_counts: sorted by decreasing frequency, NaN excluded).
counts = meta_table.ABA_acronym.value_counts()

In [None]:
# Register the molecular-atlas (scRNAseq) coordinates into the CCFv3 frame with one scale factor
# shared by all axes and one offset per axis:
#     ccf_coord = a * molecular_coord + offset,   offset = b (AP), c (DV), d (ML).
# Anchor points come from ABA acronyms that occur in exactly one spot: that spot's coordinates are
# paired with the centroid of the annotation voxels carrying the structure's id. An axis is used
# as an anchor only when the structure is compact along it (voxel spread below
# MAX_STRUCTURE_STD), so that the centroid is a meaningful reference for that coordinate.
ATLAS_VOXEL_SIZE = 25  # matches Atlas(resolution=25); presumably um per voxel (-> mm via /1000)
MAX_STRUCTURE_STD = 0.12  # max per-axis std of a structure's rescaled voxel coordinates

l_x, l_y, l_z = [], [], []
l_xs, l_ys, l_zs = [], [], []
for acronym in counts[counts == 1].index.to_list():
    try:
        id_structure = atlas.bg_atlas.structures[acronym]["id"]
    except KeyError:
        # Acronym not found among the reference atlas structures: it cannot serve as an anchor.
        continue
    array_coor = np.where(atlas.bg_atlas.annotation == id_structure)
    if array_coor[0].size == 0:
        # No voxel carries exactly this id, so the structure has no centroid to anchor to.
        continue
    xs, ys, zs = np.mean(array_coor, axis=1) * ATLAS_VOXEL_SIZE / 1000
    std_xs, std_ys, std_zs = np.std(array_coor, axis=1) * ATLAS_VOXEL_SIZE / 1000
    # Positional columns 1:4 are stereo_AP, stereo_DV, stereo_ML (order fixed when reorganizing).
    x, y, z = meta_table[meta_table["ABA_acronym"] == acronym].iloc[0, 1:4].to_numpy()
    if std_xs < MAX_STRUCTURE_STD:
        l_x.append(x)
        l_xs.append(xs)
    if std_ys < MAX_STRUCTURE_STD:
        l_y.append(y)
        l_ys.append(ys)
    if std_zs < MAX_STRUCTURE_STD:
        l_z.append(z)
        l_zs.append(zs)

# Every axis needs at least one anchor, otherwise its offset is undetermined.
if not (l_x and l_y and l_z):
    raise ValueError(
        "Too few anchor structures to fit the registration "
        f"(AP: {len(l_x)}, DV: {len(l_y)}, ML: {len(l_z)} anchors)"
    )

# Design matrix: one row per anchor coordinate, columns [coordinate, is_AP, is_DV, is_ML], so that
# M @ (a, b, c, d) == a * coordinate + offset of that coordinate's axis.
X = np.vstack([np.array(l_x), np.ones(len(l_x)), np.zeros(len(l_x)), np.zeros(len(l_x))]).T
Y = np.vstack([np.array(l_y), np.zeros(len(l_y)), np.ones(len(l_y)), np.zeros(len(l_y))]).T
Z = np.vstack([np.array(l_z), np.zeros(len(l_z)), np.zeros(len(l_z)), np.ones(len(l_z))]).T
M = np.vstack((X, Y, Z))
# Targets: the matching CCFv3 centroid coordinates, in the same row order as M.
y = np.array(l_xs + l_ys + l_zs)

solution, _, rank, _ = np.linalg.lstsq(M, y, rcond=None)
if rank < M.shape[1]:
    # e.g. every anchor shares the same coordinate value: the scale cannot be separated from the
    # offsets, and lstsq would silently return an arbitrary minimum-norm solution.
    raise ValueError(f"Registration design matrix is rank-deficient (rank {rank} < {M.shape[1]})")
a, b, c, d = solution


In [None]:
# Visual check of the registration. For each axis, plot the anchor points (molecular-atlas
# coordinate vs. CCFv3 structure centroid) and the fitted line a * coordinate + axis offset.
# Each axis gets its own colour and legend entries so the three fits can be told apart.
for coords, centroids, offset, axis_name in (
    (np.array(l_x), np.array(l_xs), b, "AP"),
    (np.array(l_y), np.array(l_ys), c, "DV"),
    (np.array(l_z), np.array(l_zs), d, "ML"),
):
    (points,) = plt.plot(coords, centroids, "o", markersize=10, label=f"{axis_name} anchors")
    line_x = np.sort(coords)
    plt.plot(line_x, a * line_x + offset, "-", color=points.get_color(), label=f"{axis_name} fit")
plt.xlabel("Molecular atlas coordinate")
plt.ylabel("CCFv3 structure centroid")
plt.legend()
plt.show()

In [None]:
# Map every spot into our coordinate system with the fitted registration: shared scale `a`, and
# per-axis offsets b (AP), c (DV), d (ML). All three new columns are computed from the current
# values before being written back.
meta_table = meta_table.assign(
    stereo_AP=a * meta_table["stereo_AP"] + b,
    stereo_DV=a * meta_table["stereo_DV"] + c,
    stereo_ML=a * meta_table["stereo_ML"] + d,
)
meta_table

In [None]:
# 3D scatter trace with one small, semi-transparent marker per scRNAseq spot, placed at its
# registered (AP, DV, ML) coordinates.
scatter = go.Scatter3d(
    mode="markers",
    x=meta_table["stereo_AP"].to_numpy(),
    y=meta_table["stereo_DV"].to_numpy(),
    z=meta_table["stereo_ML"].to_numpy(),
    marker={"size": 1, "opacity": 0.8},
)

#fig = go.Figure(data=scatter)
#fig.show()

In [None]:
# Retrieve the shelved "volume_root" object (presumably the 3D brain-outline trace) under the
# "figures/3D_page" key. With force_update=False a cached copy is expected to be reused, and
# compute_function called only when none exists — return_shelved_object is not shown here.
root_data = figures._storage.return_shelved_object(
    "figures/3D_page",
    "volume_root",
    compute_function=figures.compute_3D_root_volume,
    force_update=False,
)

In [None]:
# Overlay the scRNAseq spots on the root volume.
fig = go.Figure(data=[root_data, scatter])

# Dark theme, no margins, and fully transparent plot/paper backgrounds and 3D axis panes
# (hides the default grey background).
fig.update_layout(
    template="plotly_dark",
    margin={"t": 0, "r": 0, "b": 0, "l": 0},
    scene={
        "xaxis": {"backgroundcolor": "rgba(0,0,0,0)"},
        "yaxis": {"backgroundcolor": "rgba(0,0,0,0)"},
        "zaxis": {"backgroundcolor": "rgba(0,0,0,0)"},
    },
    plot_bgcolor="rgba(0,0,0,0)",
    paper_bgcolor="rgba(0,0,0,0)",
)

fig.show()

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
np.random.seed(1)

# Toy example: two side-by-side 2D panels with one trace each.
fig = make_subplots(rows=1, cols=2)
fig.add_trace(go.Scatter(x=[1, 2, 3], y=[4, 5, 6]), row=1, col=1)
fig.add_trace(go.Scatter(x=[20, 30, 40], y=[50, 60, 70]), row=1, col=2)

# Wrap in a FigureWidget, which supports Python-side callbacks such as on_click, and keep handles
# to the left-hand (p1) and right-hand (p2) traces, in insertion order.
f = go.FigureWidget(fig)
p1, p2 = f.data
f.update_layout(hovermode="closest")
# Click callback, registered on the left-hand trace `p1` below.
def update(trace, points, selector):
    """Increment the right-hand trace's y value at every index clicked on the left-hand trace.

    The signature follows plotly's ``on_click`` callback convention ``(trace, points, selector)``.
    Only ``points.point_inds`` is used. ``p2`` (the trace to modify) and ``f`` (the widget) are
    read from the notebook's globals.
    """
    if not points.point_inds:
        return
    new_y = list(p2.y)
    for i in points.point_inds:
        new_y[i] += 1
    # Assign once, after all increments, so the widget receives a single batched update per click.
    with f.batch_update():
        p2.y = new_y

# Register the callback: clicking a point of the left-hand trace bumps the right-hand trace.
p1.on_click(update)

# Last expression of the cell: Jupyter renders the interactive widget.
f
# x = np.random.rand(100)
# y = np.random.rand(100)

# f = go.FigureWidget([go.Scatter(x=x, y=y, mode='markers'), go.Scatter(x=x, y=y, mode='lines')])

# scatter = f.data[0]
# line = f.data[1]
# colors = ['#a3a7e4'] * 100
# scatter.marker.color = colors
# scatter.marker.size = [10] * 100
# f.layout.hovermode = 'closest'



# # create our callback function
# def update_point(trace, points, selector):
#     c = list(scatter.marker.color)
#     s = list(scatter.marker.size)
#     y = list(line.y)
#     for i in points.point_inds:
#         c[i] = '#bae2be'
#         s[i] = 20
#         y[i] = y[i] + 1
#         with f.batch_update():
#             scatter.marker.color = c
#             scatter.marker.size = s
#             line.y = y

            


# scatter.on_click(update_point)

# f