In [None]:
import anndata

import os
import numpy
from aavomics import database

from plotly import graph_objects
from plotly import offline as plotly

In [None]:
# Name of the AnnData file, looked up under database.DATA_PATH.
ANNDATA_FILE_NAME = "aavomics_mouse_cortex_2021.h5ad"

# Value of the "Cell Set" obs column selecting the cells to analyse, and the
# obs column holding each cell's viral transcript count.
CELL_SET_NAME, VIRUS_NAME = "20190712_TC5", "PHP.eB"

In [None]:
# Two-level cell-type taxonomy: broad class -> {subtype -> {}}.
# Subtypes have no children of their own, so every leaf is an empty dict.
# The "Cell Type" obs column is matched against the subtype names.
CELL_TYPE_HIERARCHY = {
    class_name: {subtype: {} for subtype in subtypes}
    for class_name, subtypes in (
        ("Astrocytes", (
            "Myoc- Astrocytes",
            "Myoc+ Astrocytes",
        )),
        ("Vascular Cells", (
            "Endothelial Cells",
            "Pericytes",
            "Red Blood Cells",
            "Vascular SMCs",
            "VLMCs",
        )),
        ("Immune Cells", (
            "Perivascular Macrophages",
            "Microglia",
            "Leukocytes",
        )),
        ("Oligodendrocytes", (
            "OPCs",
            "Committed Oligodendrocytes",
            "Mature Oligodendrocytes",
        )),
        ("Neurons", (
            "L2", "L2/3", "L3", "L4/5", "L5", "L5/6", "L6",
            "Lamp5", "Pax6", "Sncg", "Vip", "Sst", "Pvalb",
        )),
    )
}

In [None]:
# Load the full AnnData dataset from the directory given by database.DATA_PATH.
adata = anndata.read_h5ad(
    filename=os.path.join(database.DATA_PATH, ANNDATA_FILE_NAME),
)

In [None]:
# Keep only the cells whose "Cell Set" equals CELL_SET_NAME; .copy() turns the
# subset into an independent AnnData rather than a view of `adata`.
cell_set_adata = adata[adata.obs["Cell Set"] == CELL_SET_NAME].copy()
# Fail early with a clear message instead of a 0/0 warning and an IndexError
# once the per-cell-type eCDFs are computed from an empty subset.
assert cell_set_adata.n_obs > 0, "No cells found for cell set %r" % CELL_SET_NAME

In [None]:
# Empirical CDFs of per-cell viral transcript counts (obs column VIRUS_NAME)
# for selected broad classes of CELL_TYPE_HIERARCHY within `cell_set_adata`.

# Plot options; None means "derive from the data" or "omit".
resolution = 50          # grid intervals used to downsample each eCDF curve
min_value = None         # lower x bound; widened if any data lie below it
max_value = None         # upper x bound; widened if any data lie above it
x_axis_title = "# Viral Transcripts"
data_name = None         # optional label appended to the cell count in y tick text
file_name = None         # output SVG path; derived from VIRUS_NAME/cell types if None
title = None
x_axis_log_scale = True
y_axis_log_scale = False

CELL_TYPE_NAMES = ["Astrocytes", "Neurons"]


def _ecdf_points(data, lower, upper, resolution):
    """Return downsampled (x, y) points of the empirical CDF of `data`.

    Distinct values are walked in ascending order while keeping the running
    cumulative count. A point is emitted for a value when it is at or past
    the current one of `resolution + 1` evenly spaced x grid points over
    [lower, upper], or when the cumulative count is at or past the current one
    of `resolution + 1` evenly spaced probability levels. Each grid pointer
    advances by one step per emitted point. Both sparse tails and dense
    regions of the curve therefore get points.

    Parameters
    ----------
    data : array-like of per-cell values; must be non-empty.
    lower, upper : shared x range. The curve is padded with points at these
        bounds when the data do not reach them.
    resolution : number of grid intervals on each axis.

    Returns
    -------
    (x, y) : lists of values and cumulative probabilities in [0, 1].
    """
    num_data_points = len(data)
    values, counts = numpy.unique(data, return_counts=True)

    value_grid = numpy.linspace(lower, upper, resolution + 1)
    count_grid = numpy.linspace(0, 1, resolution + 1) * num_data_points

    # Padding points carry a count of 0. The CDF is 0 at `lower` and stays at 1
    # at `upper`; it must never exceed 1.
    if lower < values[0]:
        values = numpy.insert(values, 0, lower)
        counts = numpy.insert(counts, 0, 0)
    if upper > values[-1]:
        values = numpy.append(values, upper)
        counts = numpy.append(counts, 0)

    x, y = [], []
    value_index = 0
    count_index = 0
    cumulative_count = 0

    for value, count in zip(values, counts):
        cumulative_count += count
        add_point = False
        # Emit a point when this value reaches the next x grid point ...
        if value_index <= resolution and value >= value_grid[value_index]:
            value_index += 1
            add_point = True
        # ... or when the cumulative count reaches the next probability level.
        if count_index <= resolution and cumulative_count >= count_grid[count_index]:
            count_index += 1
            add_point = True
        if add_point:
            x.append(value)
            y.append(cumulative_count / num_data_points)

    return x, y


# Collect each class's values first so that all curves can share one x range.
cell_type_data = {}
for cell_type_name in CELL_TYPE_NAMES:
    cell_type_mask = cell_set_adata.obs["Cell Type"].isin(list(CELL_TYPE_HIERARCHY[cell_type_name].keys()))
    data = cell_set_adata.obs.loc[cell_type_mask, VIRUS_NAME].values
    if len(data) == 0:
        print("%s: no cells in this cell set, skipping" % cell_type_name)
        continue
    print(cell_type_name, len(data), data.sum() / len(data))
    cell_type_data[cell_type_name] = data

assert cell_type_data, "None of %s has cells in this cell set" % CELL_TYPE_NAMES

# Shared x range is the union of the configured bounds and the observed data.
x_min = min(values.min() for values in cell_type_data.values())
x_max = max(values.max() for values in cell_type_data.values())
if min_value is not None:
    x_min = min(min_value, x_min)
if max_value is not None:
    x_max = max(max_value, x_max)

traces = []
for cell_type_name, data in cell_type_data.items():
    x, y = _ecdf_points(data, x_min, x_max, resolution)
    traces.append(graph_objects.Scatter(
        x=x,
        y=y,
        name=" %s (%.1f transcripts/cell)" % (cell_type_name, data.sum() / len(data)),
    ))

layout = {}

if title is not None:
    layout["title"] = title

layout["xaxis"] = {
    "gridcolor": "rgba(1, 1, 1, 0.1)",
    "linecolor": "rgba(1, 1, 1, 0.1)",
}
if x_axis_log_scale:
    layout["xaxis"]["type"] = "log"
else:
    # Fixed linear window of 0-200 transcripts with a tick every 20.
    layout["xaxis"]["tickvals"] = numpy.linspace(0, 200, 11)
    layout["xaxis"]["range"] = [0, 200]
if x_axis_title is not None:
    layout["xaxis"]["title"] = x_axis_title

layout["height"] = 500
layout["width"] = 700

layout["yaxis"] = {
    "title": "Cumulative Probability",
    "gridcolor": "rgba(1, 1, 1, 0.1)",
    "linecolor": "rgba(1, 1, 1, 0.1)",
}
probability_ticks = numpy.linspace(0, 1, 6)
if y_axis_log_scale:
    # Plotly reads a log-axis range in log10 units, so [0, 1] would mean 1..10;
    # let Plotly choose the range instead.
    layout["yaxis"]["type"] = "log"
else:
    layout["yaxis"]["range"] = [0, 1]
    layout["yaxis"]["tickvals"] = probability_ticks
    if len(cell_type_data) == 1:
        # With a single curve, each probability maps to an absolute cell count.
        (only_data,) = cell_type_data.values()
        unit = "" if data_name is None else " " + data_name
        layout["yaxis"]["ticktext"] = [
            "%.2f (%.1e%s)" % (p, p * len(only_data), unit) for p in probability_ticks
        ]
    else:
        # Cell counts differ per curve, so label probabilities only.
        layout["yaxis"]["ticktext"] = ["%.2f" % p for p in probability_ticks]

layout["scene"] = {"aspectmode": "data"}
layout["plot_bgcolor"] = "rgba(255, 255, 255, 0)"
layout["paper_bgcolor"] = "rgba(255, 255, 255, 0)"

layout = graph_objects.Layout(layout)

figure = graph_objects.Figure(
    data=traces,
    layout=layout)

plotly.iplot(figure)

# Default: out/<virus>_<type>_vs_<type>_transcript_count_eCDF.svg,
# e.g. out/PHP-eB_astrocytes_vs_neurons_transcript_count_eCDF.svg.
if file_name is not None:
    output_path = file_name
else:
    output_path = os.path.join(
        "out",
        "%s_%s_transcript_count_eCDF.svg" % (
            VIRUS_NAME.replace(".", "-"),
            "_vs_".join(name.lower().replace(" ", "_") for name in CELL_TYPE_NAMES),
        ),
    )
output_dir = os.path.dirname(output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
figure.write_image(output_path)