In [None]:
from aavomics import database
import os
import pandas
import numpy


from plotly import offline as plotly
from plotly import graph_objects

In [None]:
# File name of the CSV that is joined onto database.DATA_PATH and loaded into
# `transduction_rate_df` further down.
TRANSDUCTION_RATE_FILE_NAME = "aavomics_cell_type_transduction_rates.csv"

In [None]:
# Nested cell type taxonomy used to build the sunburst plots.
#
# Each key is a cell type name and each value is a dict of its child cell
# types; an empty dict marks a leaf. Leaf names are used to build the column
# names "<name> Transduction Rate" and "<name> Num Cells" that are looked up in
# the transduction rate table, so they must match those columns exactly.
CELL_TYPE_HIERARCHY = {
    "Neurons": {
        "Glutamatergic": {
            "L2": {},
            "L2/3": {},
            "L3": {},
            "L4/5": {},
            "L5": {},
            "L5/6": {},
            "L6": {}
        },
        "GABAergic": {
            "Pvalb": {},
            "Sst": {},
            "Vip": {},
            "Sncg": {},
            "Pax6": {},
            "Lamp5": {}
        }
    },
    "Non-Neuronal Cells": {
        "Astrocytes": {
            "Myoc- Astrocytes": {},
            "Myoc+ Astrocytes": {}
        },
        "Vascular Cells": {
            "Endothelial Cells": {},
            "Pericytes": {},
            "Red Blood Cells": {},
            "Vascular SMCs": {},
            "VLMCs": {}
        },
        "Immune Cells": {
            "Microglia": {},
            "Perivascular Macrophages": {},
            "Leukocytes": {}
        },
        "Oligodendrocytes": {
            # NOTE(review): "OPCs" used to be listed twice here. A repeated key
            # in a dict literal is silently collapsed into one entry, so the
            # duplicate has been dropped; confirm that no other oligodendrocyte
            # subtype was meant to take its place.
            "OPCs": {},
            "Committed Oligodendrocytes": {},
            "Mature Oligodendrocytes": {}
        }
    }
}

# Cell sets to average over. Each name is combined with a virus name as
# "<cell set>-<virus>" to form a row ID in the transduction rate table; row IDs
# that are absent from the table are skipped.
CELL_SET_NAMES = [
    "20190711_TC4",
    "20190712_TC5",
    "20190713_TC6",
    "20190713_TC7",
]

# Viruses to plot; one sunburst figure is produced per entry.
VIRUS_NAMES = [
    "PHP.eB",
    "CAP-B10",
]

In [None]:
# Load the transduction rate table. The first CSV column becomes the index,
# which is later queried with "<cell set>-<virus>" row IDs.
transduction_rate_df = pandas.read_csv(
    os.path.join(database.DATA_PATH, TRANSDUCTION_RATE_FILE_NAME),
    index_col=0,
)

In [None]:
def append_labels_parent_labels_values(labels, parent_labels, values, cell_type_hierarchy, row_id, parent_name, name, transduction_rates=None):
    """Flatten a cell type hierarchy into the parallel lists a sunburst needs.

    Walks ``cell_type_hierarchy`` depth-first and appends one entry per node to
    ``labels``, ``parent_labels`` and ``values`` (all three are mutated in
    place and stay index-aligned). The entries of a node's children are
    appended before the entry of the node itself.

    Args:
        labels: List that receives each node's name.
        parent_labels: List that receives the name of each node's parent.
        values: List that receives each node's value. Leaves get
            ``transduction rate * num cells``; every non-leaf node gets 0.
        cell_type_hierarchy: Dict mapping child cell type names to their own
            (possibly empty) child dicts. An empty dict marks a leaf.
        row_id: Index label of the row to read from the transduction table.
        parent_name: Name of the parent of the node being processed.
        name: Name of the node being processed.
        transduction_rates: Optional DataFrame holding the
            "<cell type> Transduction Rate" and "<cell type> Num Cells"
            columns. Defaults to the notebook-level ``transduction_rate_df``.

    Returns:
        None if ``cell_type_hierarchy`` is empty (the node is a leaf and
        nothing is appended); otherwise the sum of the leaf values of every
        leaf below this node.

    Raises:
        KeyError: If ``row_id`` or a required leaf column is missing from the
            transduction table.
    """

    if len(cell_type_hierarchy) == 0:
        return None

    if transduction_rates is None:
        # Fall back to the table loaded at notebook level.
        transduction_rates = transduction_rate_df

    num_total_transduced = 0

    for cell_type_name, child_cell_type_names in cell_type_hierarchy.items():

        num_child_transduced = append_labels_parent_labels_values(
            labels,
            parent_labels,
            values,
            child_cell_type_names,
            row_id,
            name,
            cell_type_name,
            transduction_rates
        )

        if num_child_transduced is None:

            # The child has no children of its own, so it is a leaf: read its
            # numbers from the table and emit its entry here.
            transduction_rate = transduction_rates.loc[row_id, "%s Transduction Rate" % cell_type_name]
            num_cells = transduction_rates.loc[row_id, "%s Num Cells" % cell_type_name]

            num_child_transduced = transduction_rate * num_cells

            parent_labels.append(name)
            labels.append(cell_type_name)
            values.append(num_child_transduced)

        # Count leaves and whole subtrees alike. Previously subtree totals
        # were dropped (`+= 0`), so the returned total only covered direct
        # leaf children.
        num_total_transduced += num_child_transduced

    # Non-leaf nodes carry a value of 0: with Plotly's default
    # branchvalues="remainder" a parent's value is the part not already
    # accounted for by its children.
    parent_labels.append(parent_name)
    labels.append(name)
    values.append(0)

    return num_total_transduced

# Make sure the export directory exists before any figure is written.
os.makedirs("out", exist_ok=True)

# Build one sunburst per virus: flatten the hierarchy for every available cell
# set, normalize each cell set to fractions of its own total, then average the
# fractions across cell sets.
for virus_name in VIRUS_NAMES:

    labels = []
    parent_labels = []
    values = []

    for cell_set_name in CELL_SET_NAMES:

        row_id = "%s-%s" % (cell_set_name, virus_name)

        # Not every cell set has a row for every virus.
        if row_id not in transduction_rate_df.index:
            continue

        cell_set_labels = []
        cell_set_parent_labels = []
        cell_set_values = []

        append_labels_parent_labels_values(cell_set_labels, cell_set_parent_labels, cell_set_values, CELL_TYPE_HIERARCHY, row_id, "", "All Cells")

        # Values are averaged element-wise across cell sets below, so every
        # cell set must yield the same nodes in the same order.
        if len(labels) > 0 and labels != cell_set_labels:
            raise ValueError("Hierarchies should match perfectly!")

        if len(parent_labels) > 0 and parent_labels != cell_set_parent_labels:
            raise ValueError("Hierarchies should match perfectly!")

        labels = cell_set_labels
        parent_labels = cell_set_parent_labels

        cell_set_values = numpy.array(cell_set_values, dtype=float)
        cell_set_total = cell_set_values.sum()

        # A zero (or NaN) total cannot be normalized; dividing by it would
        # turn the whole cross-cell-set mean into NaN.
        if not cell_set_total > 0:
            print("Skipping %s: total is %s, cannot normalize" % (row_id, cell_set_total))
            continue

        values.append(cell_set_values / cell_set_total)

    # Without any usable cell set there is nothing to average or plot.
    if len(values) == 0:
        print("Skipping %s: no usable cell sets found" % virus_name)
        continue

    values = numpy.array(values).mean(axis=0)

    sunburst = graph_objects.Sunburst(
        labels=labels,
        parents=parent_labels,
        values=values
    )

    layout = graph_objects.Layout(
        title="%s Cell Type Distribution" % virus_name
    )

    figure = graph_objects.Figure([sunburst], layout=layout)

    figure.write_image(os.path.join("out", "%s_tropism_sunburst.svg" % virus_name))

    plotly.iplot(figure)