In [None]:
import os

from aavomics import database
from aavomics import aavomics
import pandas

import numpy
import anndata
import scipy.stats
from scipy import stats
from statsmodels.stats import proportion
from statsmodels.stats import multitest

import plotly.graph_objects as graph_objects
from plotly import offline as plotly

In [None]:
# Cell types with fewer cells than this (summed over all cell sets) are not tested or plotted.
CELL_COUNT_THRESHOLD = 50

# Cell sets whose transduction rates are analysed.
CELL_SET_NAMES = [
    "20200720_BC4_1",
    "20200720_BC4_2",
]

# Transduction-rate table (one row per cell set and virus), read from database.DATA_PATH.
TRANSDUCTION_RATE_FILE_NAME = "aavomics_cell_type_transduction_rates.csv"

In [None]:
# Three-level cell-type hierarchy: top-level class -> subtype -> leaf cell type.
# Leaf names are used verbatim as column prefixes ("<leaf> Transduction Rate",
# "<leaf> Num Cells") when reading the transduction-rate table, so they must match it.
CELL_TYPE_HIERARCHY = {
    "Neurons": {
        "Glutamatergic": {
            "L2": {},
            "L2/3": {},
            "L3": {},
            "L4/5": {},
            "L5": {},
            "L5/6": {},
            "L6": {}
        },
        "GABAergic": {
            "Pvalb": {},
            "Sst": {},
            "Vip": {},
            "Sncg": {},
            "Pax6": {},
            "Lamp5": {}
        }
    },
    "Non-Neuronal Cells": {
        "Astrocytes": {
            "Myoc- Astrocytes": {},
            "Myoc+ Astrocytes": {}
        },
        "Vascular Cells": {
            "Endothelial Cells": {},
            "Pericytes": {},
            "Red Blood Cells": {},
            "Vascular SMCs": {},
            "VLMCs": {}
        },
        "Immune Cells": {
            "Microglia": {},
            "Perivascular Macrophages": {},
            "Leukocytes": {}
        },
        "Oligodendrocytes": {
            # NOTE(review): "OPCs" used to be listed twice here; duplicate dict keys collapse
            # silently, so only one entry is kept. Confirm no other subtype was intended.
            "OPCs": {},
            "Committed Oligodendrocytes": {},
            "Mature Oligodendrocytes": {}
        }
    }
}

# Viruses compared in the analysis; also the order of the x axis in the plots.
VIRUS_NAMES = [
    "AAV9",
    "PHP.B",
    "PHP.eB",
    "CAP-B10",
    "PHP.V1",
    "PHP.C1",
    "PHP.C2"
]

In [None]:
def get_leaf_nodes(node, key, current_level, max_level):
    """Return the keys of the leaves of a nested-dict tree, in depth-first order.

    Args:
        node: the (sub)tree to walk; a dict mapping child keys to child subtrees.
        key: the key under which ``node`` is stored in its parent (``None`` for the root).
        current_level: depth of ``node`` (0 for the root).
        max_level: depth at which to stop descending; nodes at this depth are
            reported as leaves even if they have children.

    Returns:
        A list of keys: ``[key]`` if ``node`` is empty or at ``max_level``,
        otherwise the concatenated leaf keys of all children.
    """
    # Stop at childless nodes or at the depth limit and report this node's own key.
    if len(node) == 0 or current_level == max_level:
        return [key]

    leaf_keys = []
    for child_key, child_subtree in node.items():
        leaf_keys.extend(get_leaf_nodes(child_subtree, child_key, current_level + 1, max_level))
    return leaf_keys
    
# Flat list of every leaf cell type in hierarchy order; 99 is only a depth limit well
# above the hierarchy's actual depth of three.
cell_type_names = get_leaf_nodes(
    node=CELL_TYPE_HIERARCHY,
    key=None,
    current_level=0,
    max_level=99,
)

In [None]:
# Per-cell-set, per-virus transduction rates; the first CSV column becomes the index.
transduction_rate_path = os.path.join(database.DATA_PATH, TRANSDUCTION_RATE_FILE_NAME)
transduction_rate_df = pandas.read_csv(transduction_rate_path, index_col=0)

In [None]:
# Build two nested count tables from the transduction-rate table:
#   cell_set_cell_type_counts[cell_set][cell_type] -> number of cells of that type
#   virus_cell_set_cell_type_transduction_counts[virus][cell_set][cell_type]
#       -> estimated transduced cells (transduction rate x number of cells)
# Leaf cell types are read from the table; their subtype and top-level ancestors receive the
# sums over their leaves, and the transduction table also gets an "All Cells" total.

# Ancestors of every third-level (leaf) cell type, computed once instead of re-searching the
# hierarchy for every cell set, virus and cell type. setdefault keeps the first match.
leaf_cell_type_parents = {}
for top_level_cell_type, subtypes in CELL_TYPE_HIERARCHY.items():
    for subtype, subsubtypes in subtypes.items():
        for subsubtype in subsubtypes:
            leaf_cell_type_parents.setdefault(subsubtype, (top_level_cell_type, subtype))

virus_cell_set_cell_type_transduction_counts = {x: {} for x in VIRUS_NAMES}
cell_set_cell_type_counts = {}

for cell_set_name in CELL_SET_NAMES:

    transduction_rate_cell_set_mask = transduction_rate_df["Cell Set"] == cell_set_name

    cell_set_cell_type_counts[cell_set_name] = {}

    for virus_name in VIRUS_NAMES:
        virus_cell_set_cell_type_transduction_counts[virus_name][cell_set_name] = {}

    for virus_name in VIRUS_NAMES:

        virus_cell_set_cell_type_transduction_counts[virus_name][cell_set_name]["All Cells"] = 0

        transduction_rate_virus_name_mask = transduction_rate_df["Virus"] == virus_name

        filtered_transduction_rate_df = transduction_rate_df[transduction_rate_cell_set_mask & transduction_rate_virus_name_mask]

        if filtered_transduction_rate_df.empty:
            raise ValueError("No transduction rates found for cell set %s and virus %s" % (cell_set_name, virus_name))

        for cell_type_name in cell_type_names:

            if cell_type_name not in leaf_cell_type_parents:
                # Without this check the ancestors would silently be whatever was seen last.
                raise KeyError("Cell type %s is not a third-level entry of CELL_TYPE_HIERARCHY" % cell_type_name)

            top_level_cell_type, subtype = leaf_cell_type_parents[cell_type_name]

            if top_level_cell_type not in cell_set_cell_type_counts[cell_set_name]:
                cell_set_cell_type_counts[cell_set_name][top_level_cell_type] = 0

                for virus_to_set in VIRUS_NAMES:
                    virus_cell_set_cell_type_transduction_counts[virus_to_set][cell_set_name][top_level_cell_type] = 0

            if subtype not in cell_set_cell_type_counts[cell_set_name]:
                cell_set_cell_type_counts[cell_set_name][subtype] = 0

                for virus_to_set in VIRUS_NAMES:
                    virus_cell_set_cell_type_transduction_counts[virus_to_set][cell_set_name][subtype] = 0

            # .iloc[0] takes the first matching row by position. Series[0] is a label lookup
            # when the CSV's index column is integer-valued (failing for every row not labelled
            # 0) and otherwise relies on pandas' deprecated positional fallback.
            virus_rate = filtered_transduction_rate_df["%s Transduction Rate" % cell_type_name].iloc[0]
            num_cell_type_cells = filtered_transduction_rate_df["%s Num Cells" % cell_type_name].iloc[0]

            # Cell counts are taken from the first virus's row only, so each cell type is
            # counted once per cell set (and added to its ancestors once).
            if cell_type_name not in cell_set_cell_type_counts[cell_set_name]:
                cell_set_cell_type_counts[cell_set_name][cell_type_name] = num_cell_type_cells
                cell_set_cell_type_counts[cell_set_name][top_level_cell_type] += num_cell_type_cells
                cell_set_cell_type_counts[cell_set_name][subtype] += num_cell_type_cells

            num_transduced = virus_rate*num_cell_type_cells

            virus_cell_set_cell_type_transduction_counts[virus_name][cell_set_name][cell_type_name] = num_transduced
            virus_cell_set_cell_type_transduction_counts[virus_name][cell_set_name][subtype] += num_transduced
            virus_cell_set_cell_type_transduction_counts[virus_name][cell_set_name][top_level_cell_type] += num_transduced
            virus_cell_set_cell_type_transduction_counts[virus_name][cell_set_name]["All Cells"] += num_transduced

In [None]:
# Two-proportion z-tests of each virus's tropism for each cell type, corrected jointly with
# Benjamini-Hochberg. The tests cover every top-level cell type, every cell type under it with
# at least CELL_COUNT_THRESHOLD cells (summed over CELL_SET_NAMES), and every virus. For each
# cell set, the virus's transduced cells of that type, as a fraction of its transduced cells in
# the parent category, are compared with the same fraction for all other viruses pooled.
#
# Outputs used by later cells:
#   p_values / p_value_texts   raw two-sided p-values and their "<cell type> <virus>" labels
#   color_values               mean over cell sets of the difference in fractions, in %
#   p_val_lengths              number of tests per top-level cell type
#   results                    statsmodels multipletests output (fdr_bh, alpha=0.05)
#   significance_mask_by_type / adjusted_p_values_by_type   `results` split per top-level type

# Parent category for the fractions: "All Cells" (USE_ALL_CELLS), the direct parent in the
# hierarchy (USE_RELATIVE_RATE), or otherwise the top-level cell type.
USE_ALL_CELLS = True
USE_RELATIVE_RATE = False
# True: combine per-cell-set z-scores (sum / sqrt(k), Stouffer); False: one test on pooled counts.
COMBINE_Z_SCORES = True
# Minimum pooled transduced-cell count on each side before running the pooled test.
MIN_POOLED_COUNT = 10

p_values = []
p_value_texts = []
p_val_lengths = []
color_values = []

for TOP_LEVEL_CELL_TYPE in CELL_TYPE_HIERARCHY:

    z_scores = []

    cell_type_names_to_plot = []

    parent_cell_type_map = {}

    for top_level_cell_type in CELL_TYPE_HIERARCHY:

        parent_cell_type_map[top_level_cell_type] = "All Cells"

        if top_level_cell_type != TOP_LEVEL_CELL_TYPE:
            continue

        cell_type_names_to_plot.append(top_level_cell_type)

        for subtype in CELL_TYPE_HIERARCHY[top_level_cell_type]:

            parent_cell_type_map[subtype] = top_level_cell_type
            cell_type_names_to_plot.append(subtype)

            for subsubtype in CELL_TYPE_HIERARCHY[top_level_cell_type][subtype]:
                cell_type_names_to_plot.append(subsubtype)
                parent_cell_type_map[subsubtype] = subtype

    for cell_type_name in cell_type_names_to_plot[::-1]:

        num_cells = sum([cell_set_cell_type_counts[x][cell_type_name] for x in CELL_SET_NAMES])

        if num_cells < CELL_COUNT_THRESHOLD:
            continue

        if USE_ALL_CELLS:
            parent_cell_type = "All Cells"
        elif USE_RELATIVE_RATE:
            parent_cell_type = parent_cell_type_map[cell_type_name]
        else:
            parent_cell_type = TOP_LEVEL_CELL_TYPE

        for virus in VIRUS_NAMES:

            cell_set_color_values = []
            cell_set_z_scores = []

            total_virus_count = 0
            total_virus_parent_count = 0

            total_other_virus_count = 0
            total_other_virus_parent_count = 0

            for cell_set_name in CELL_SET_NAMES:

                other_virus_count = 0
                other_virus_parent_count = 0

                for other_virus in VIRUS_NAMES:

                    if virus == other_virus:
                        continue

                    other_virus_count += virus_cell_set_cell_type_transduction_counts[other_virus][cell_set_name][cell_type_name]
                    other_virus_parent_count += virus_cell_set_cell_type_transduction_counts[other_virus][cell_set_name][parent_cell_type]

                virus_count = virus_cell_set_cell_type_transduction_counts[virus][cell_set_name][cell_type_name]
                virus_parent_count = virus_cell_set_cell_type_transduction_counts[virus][cell_set_name][parent_cell_type]

                if virus_parent_count == 0:
                    virus_rate = 0
                else:
                    virus_rate = virus_count/virus_parent_count

                if other_virus_parent_count == 0:
                    other_virus_rate = 0
                else:
                    other_virus_rate = other_virus_count/other_virus_parent_count

                cell_set_color_values.append(virus_rate - other_virus_rate)

                total_virus_count += virus_count
                total_other_virus_count += other_virus_count
                total_virus_parent_count += virus_parent_count
                total_other_virus_parent_count += other_virus_parent_count

                # Transduced counts are rate x cell-count estimates; round to whole cells for the test.
                virus_count = numpy.round(virus_count)
                other_virus_count = numpy.round(other_virus_count)
                virus_parent_count = numpy.round(virus_parent_count)
                other_virus_parent_count = numpy.round(other_virus_parent_count)

                # Report any cell type whose rounded count exceeds its parent category's.
                if virus_count > virus_parent_count:
                    print(virus, virus_count, cell_type_name)

                if virus_count+other_virus_count == 0 or virus_parent_count == 0 or other_virus_parent_count == 0:
                    z = 0
                else:
                    z, _ = proportion.proportions_ztest([virus_count, other_virus_count], [virus_parent_count, other_virus_parent_count])

                cell_set_z_scores.append(z)

            if COMBINE_Z_SCORES:
                z_score = sum(cell_set_z_scores)/numpy.sqrt(len(cell_set_z_scores))
            else:
                # Threshold the pooled totals, not the counts of whichever cell set ran last.
                if total_virus_count < MIN_POOLED_COUNT or total_other_virus_count < MIN_POOLED_COUNT:
                    z_score = 0
                else:
                    z_score, _ = proportion.proportions_ztest([total_virus_count, total_other_virus_count], [total_virus_parent_count, total_other_virus_parent_count])
            z_scores.append(z_score)
            color_values.append(numpy.mean(cell_set_color_values)*100)
            p_value_texts.append("%s %s" % (cell_type_name, virus))

    # Two-sided p-values from the (combined) z-scores.
    cell_type_p_values = stats.norm.sf(numpy.abs(z_scores))*2
    p_val_lengths.append(cell_type_p_values.shape[0])

    p_values.extend(cell_type_p_values)

results = multitest.multipletests(p_values, method="fdr_bh", alpha=0.05)

# Split the jointly corrected results back into one slice per top-level cell type, in the
# order the p-values were appended above.
significance_mask_by_type = {}
adjusted_p_values_by_type = {}
p_value_offset = 0
for top_level_cell_type, num_type_p_values in zip(CELL_TYPE_HIERARCHY, p_val_lengths):
    significance_mask_by_type[top_level_cell_type] = results[0][p_value_offset:p_value_offset + num_type_p_values]
    adjusted_p_values_by_type[top_level_cell_type] = results[1][p_value_offset:p_value_offset + num_type_p_values]
    p_value_offset += num_type_p_values

In [None]:
# Raw (un-adjusted) p-value of the PHP.eB vs. other-viruses test on Astrocytes.
astrocytes_php_eb_test_index = p_value_texts.index("Astrocytes PHP.eB")
p_values[astrocytes_php_eb_test_index]

In [None]:
# Raw (un-adjusted) two-sided p-values for the last top-level cell type in the testing loop
# only (the variable is overwritten on every iteration); the BH-adjusted values for all
# types are in adjusted_p_values_by_type.
cell_type_p_values

In [None]:
# Print the BH-adjusted p-value and colour value (delta fraction transduced, %) of every tested
# (virus, cell type) pair. The iteration order must mirror the testing cell exactly, so that the
# running index lines up with `results` and `color_values`.
index = 0

for TOP_LEVEL_CELL_TYPE, subtype_tree in CELL_TYPE_HIERARCHY.items():

    # Depth-first listing: the top-level type, then each subtype followed by its leaf types.
    cell_type_names_to_plot = [TOP_LEVEL_CELL_TYPE]
    for subtype, subsubtype_tree in subtype_tree.items():
        cell_type_names_to_plot.append(subtype)
        cell_type_names_to_plot.extend(subsubtype_tree)

    for cell_type_name in reversed(cell_type_names_to_plot):

        num_cells = sum(
            cell_set_cell_type_counts[cell_set_name][cell_type_name]
            for cell_set_name in CELL_SET_NAMES
        )
        if num_cells < CELL_COUNT_THRESHOLD:
            continue

        for virus_name in VIRUS_NAMES:
            p_value = results[1][index]
            print(virus_name, cell_type_name, p_value, color_values[index])
            index += 1

In [None]:
# Dot plot per top-level cell type: x = virus, y = cell type (with its pooled cell count).
# Dot colour: mean over cell sets of (the virus's transduced cells of this type as a fraction of
# its transduced cells in the parent category, minus the same fraction for all other viruses),
# in %. Non-significant dots (BH mask from the testing cell) get colour 0 plus a grey overlay.
# Dot diameter grows with the square root of the mean per-cell-set transduction rate.
# Each figure is shown inline and written to out/7_variant_pool_tropism_<type>_<suffix>.svg.
all_texts = []
all_values = []

# Plot settings (loop-invariant).
USE_L2FC = False
MAX_DOT_SIZE = 30
MIN_DOT_SIZE = 0
USE_ALL_CELLS = True
USE_RELATIVE_RATE = False
COMBINE_Z_SCORES = True

for TOP_LEVEL_CELL_TYPE in CELL_TYPE_HIERARCHY:

    x_values = []
    y_values = []
    # Separate from the testing cell's global `color_values`, which the p-value printing cell
    # indexes into; overwriting it here would break re-running that cell.
    plot_color_values = []
    z_scores = []
    size_values = []

    cell_type_names_to_plot = []

    parent_cell_type_map = {}

    for top_level_cell_type in CELL_TYPE_HIERARCHY:

        parent_cell_type_map[top_level_cell_type] = "All Cells"

        if top_level_cell_type != TOP_LEVEL_CELL_TYPE:
            continue

        cell_type_names_to_plot.append(top_level_cell_type)

        for subtype in CELL_TYPE_HIERARCHY[top_level_cell_type]:

            parent_cell_type_map[subtype] = top_level_cell_type
            cell_type_names_to_plot.append(subtype)

            for subsubtype in CELL_TYPE_HIERARCHY[top_level_cell_type][subtype]:
                cell_type_names_to_plot.append(subsubtype)
                parent_cell_type_map[subsubtype] = subtype

    y_axis_labels = []
    cell_type_cell_virus_z_scores = {}

    for cell_type_name in cell_type_names_to_plot[::-1]:

        num_cells = sum([cell_set_cell_type_counts[x][cell_type_name] for x in CELL_SET_NAMES])

        if num_cells < CELL_COUNT_THRESHOLD:
            continue

        cell_type_cell_virus_z_scores[cell_type_name] = {}

        y_value = "%s (%i cells)" % (cell_type_name, num_cells)
        y_axis_labels.append(y_value)

        if USE_ALL_CELLS:
            parent_cell_type = "All Cells"
        elif USE_RELATIVE_RATE:
            parent_cell_type = parent_cell_type_map[cell_type_name]
        else:
            parent_cell_type = TOP_LEVEL_CELL_TYPE

        for virus_name in VIRUS_NAMES:

            x_values.append(virus_name)
            y_values.append(y_value)

            cell_set_color_values = []
            cell_set_size_values = []
            cell_set_z_scores = []

            total_virus_count = 0
            total_virus_parent_count = 0

            total_other_virus_count = 0
            total_other_virus_parent_count = 0

            for cell_set_name in CELL_SET_NAMES:

                cell_set_num_cells = cell_set_cell_type_counts[cell_set_name][cell_type_name]

                other_virus_count = 0
                other_virus_parent_count = 0

                for other_virus in VIRUS_NAMES:

                    # Exclude the virus being plotted (previously compared against a stale
                    # global `virus`, which pooled each virus with itself).
                    if other_virus == virus_name:
                        continue

                    other_virus_count += virus_cell_set_cell_type_transduction_counts[other_virus][cell_set_name][cell_type_name]
                    other_virus_parent_count += virus_cell_set_cell_type_transduction_counts[other_virus][cell_set_name][parent_cell_type]

                virus_count = virus_cell_set_cell_type_transduction_counts[virus_name][cell_set_name][cell_type_name]
                virus_parent_count = virus_cell_set_cell_type_transduction_counts[virus_name][cell_set_name][parent_cell_type]

                if virus_parent_count == 0:
                    virus_rate = 0
                else:
                    virus_rate = virus_count/virus_parent_count

                if other_virus_parent_count == 0:
                    other_virus_rate = 0
                else:
                    other_virus_rate = other_virus_count/other_virus_parent_count

                cell_set_color_values.append(virus_rate - other_virus_rate)

                # A single cell set can have no cells of this type even when the pooled count
                # passes the threshold; avoid a 0/0 NaN dot size.
                if cell_set_num_cells == 0:
                    cell_set_size_values.append(0)
                else:
                    cell_set_size_values.append(virus_count/cell_set_num_cells)

                total_virus_count += virus_count
                total_other_virus_count += other_virus_count
                total_virus_parent_count += virus_parent_count
                total_other_virus_parent_count += other_virus_parent_count

                # Transduced counts are rate x cell-count estimates; round to whole cells for the test.
                virus_count = numpy.round(virus_count)
                other_virus_count = numpy.round(other_virus_count)
                virus_parent_count = numpy.round(virus_parent_count)
                other_virus_parent_count = numpy.round(other_virus_parent_count)

                if virus_count+other_virus_count == 0 or virus_parent_count == 0 or other_virus_parent_count == 0:
                    z = 0
                else:
                    z, _ = proportion.proportions_ztest([virus_count, other_virus_count], [virus_parent_count, other_virus_parent_count])

                cell_set_z_scores.append(z)

            size_values.append(numpy.mean(cell_set_size_values))

            if COMBINE_Z_SCORES:
                z_score = sum(cell_set_z_scores)/numpy.sqrt(len(cell_set_z_scores))
            else:
                # Guard on the pooled totals, not on the counts of the last cell set.
                if total_virus_count+total_other_virus_count == 0 or total_virus_parent_count == 0 or total_other_virus_parent_count == 0:
                    z_score = 0
                else:
                    z_score, _ = proportion.proportions_ztest([total_virus_count, total_other_virus_count], [total_virus_parent_count, total_other_virus_parent_count])
            z_scores.append(z_score)
            plot_color_values.append(numpy.mean(cell_set_color_values)*100)
            all_values.append(plot_color_values[-1])
            all_texts.append("%s %s" % (cell_type_name, virus_name))
            cell_type_cell_virus_z_scores[cell_type_name][virus_name] = z_score

    sizes = (MAX_DOT_SIZE-MIN_DOT_SIZE) * numpy.power(numpy.array(size_values)/max(size_values), 1/2) + MIN_DOT_SIZE

    significance_mask = significance_mask_by_type[TOP_LEVEL_CELL_TYPE]

    # The mask comes from the testing cell; it must line up one-to-one with the dots drawn here.
    assert len(significance_mask) == len(plot_color_values), (
        "significance mask (%i) and plotted values (%i) are misaligned for %s"
        % (len(significance_mask), len(plot_color_values), TOP_LEVEL_CELL_TYPE)
    )

    if USE_L2FC:
        title = "Num Cells L2FC"
    else:
        title = "Delta Fraction Transduced (%)"

    # Significant dots get a blue (positive) or red (negative) outline.
    line_colors = []
    line_widths = []

    for i in range(len(significance_mask)):
        if not significance_mask[i]:
            line_widths.append(0)
            line_colors.append("rgba(0, 0, 0, 0)")
        else:
            if plot_color_values[i] > 0:
                line_colors.append("rgba(0, 0, 255, 1)")
            else:
                line_colors.append("rgba(255, 0, 0, 1)")
            line_widths.append(2)

    trace = graph_objects.Scatter(
        x=x_values,
        y=y_values,
        marker = {
            "color": [x if significance_mask[i] else 0 for i, x in enumerate(plot_color_values)],
            "line": {
                "color": line_colors,
                "width": line_widths
            },
            "size": sizes,
            "colorscale": "rdbu",
            "colorbar": {
                "title": {
                    "text": title,
                    "side": "right"
                }
            },
            "cmid": 0
        },
        mode="markers",
        showlegend=False
    )

    # Grey overlay marking the non-significant dots.
    significance_overlay = graph_objects.Scatter(
        x=numpy.array(x_values)[~significance_mask],
        y=numpy.array(y_values)[~significance_mask],
        marker={
            "color": "rgba(0, 0, 0, 0.25)",
            "opacity": 1,
            "size": numpy.array(sizes)[~significance_mask]
        },
        mode="markers",
        showlegend=False
    )

    # Size legend: one grey dot per reference transduction percentage, using the same scaling.
    percent_traces = []

    skip_rows = 1

    cell_type_index = len(y_axis_labels) - 1

    percents = [0.1, 0.5, 1, 1.5, 2, 3, 4, 5, 10, 15, 20, 30]

    for percent in percents:

        size = (MAX_DOT_SIZE-MIN_DOT_SIZE) * numpy.power((percent/100)/max(size_values), 1/2) + MIN_DOT_SIZE

        if cell_type_index <= 0:
            cell_type_label = str(cell_type_index)
        else:
            cell_type_label = y_axis_labels[cell_type_index]

        scatter = graph_objects.Scatter(
            x=[skip_rows],
            y=[cell_type_label],
            marker={
                "size": size,
                "color": "rgba(0, 0, 0, 0.25)"
            },
            line={
                "color": "rgba(255, 255, 255, 0)"
            },
            name=percent)

        percent_traces.append(scatter)

        cell_type_index -= 1

    layout = graph_objects.Layout(
        width=750,
        height=900,
        xaxis={
            "side": "top",
        },
        yaxis={
        },
        paper_bgcolor="rgba(255, 255, 255, 0)",
        plot_bgcolor="rgba(255, 255, 255, 0)",
        legend={
            "orientation": "h",
            "itemsizing": "trace"
        }
    )

    figure = graph_objects.Figure(data=[trace, significance_overlay] + percent_traces, layout=layout)

    plotly.iplot(figure)

    if USE_L2FC:
        plot_suffix = "L2FC"
    else:
        plot_suffix = "z_score"

    os.makedirs("out", exist_ok=True)

    figure.write_image(os.path.join("out", "7_variant_pool_tropism_%s_%s.svg" % (TOP_LEVEL_CELL_TYPE.lower().replace(" ", "_").replace("-", "_"), plot_suffix)))

In [None]:
cell_type_virus = "Neurons PHP.C2"

# Raw (un-adjusted) p-value from the testing cell, then the plotted colour value
# (delta fraction transduced, %) for the same "<cell type> <virus>" label.
raw_p_value = p_values[p_value_texts.index(cell_type_virus)]
print(raw_p_value)

delta_fraction_transduced = all_values[all_texts.index(cell_type_virus)]
print(delta_fraction_transduced)