In [None]:
import os

from aavomics import database
from aavomics import aavomics
import anndata
import numpy
import pandas

import plotly.graph_objects as graph_objects
from plotly import offline as plotly

In [None]:
# File name of the dataset, resolved against database.DATA_PATH when loading.
ANNDATA_FILE_NAME = "aavomics_mouse_cortex_2021.h5ad"

# Single cell set identifier; not referenced by the cells visible in this notebook.
CELL_SET_NAME = "20190712_TC5"

# Viruses whose "<virus> mNeonGreen NLS BC<n>" obs columns are summarised.
VIRUS_NAMES = [
    "PHP.eB",
    "PHP.V1",
]

# Values of the "Cell Set" obs column to analyse, in plotting order.
CELL_SET_NAMES = [
    "20190111_BC1",
    "20190321_BC2",
]

In [None]:
# Resolve the dataset location under the project data directory, then load it.
anndata_file_path = os.path.join(database.DATA_PATH, ANNDATA_FILE_NAME)
adata = anndata.read_h5ad(anndata_file_path)

In [None]:
# Two-level hierarchy: top-level cell type -> subtype labels. The subtype
# labels are matched against the "Cell Type" obs column further down.
CELL_TYPE_HIERARCHY = {
    "Astrocytes": {
        "Myoc- Astrocytes": {},
        "Myoc+ Astrocytes": {}
    },
    "Vascular Cells": {
        "Endothelial Cells": {},
        "Pericytes": {},
        "Red Blood Cells": {},
        "Vascular SMCs": {},
        "VLMCs": {}
    },
    "Immune Cells": {
        "Perivascular Macrophages": {},
        "Microglia": {},
        "Leukocytes": {}
    },
    "Oligodendrocytes": {
        "OPCs": {},
        "Committed Oligodendrocytes": {},
        "Mature Oligodendrocytes": {}
    },
    "Neurons": {
        "L2": {},
        "L2/3": {},
        "L3": {},
        "L4/5": {},
        "L5": {},
        "L5/6": {},
        "L6": {},
        "Lamp5": {},
        "Pax6": {},
        "Sncg": {},
        "Vip": {},
        "Sst": {},
        "Pvalb": {}
    }
}

# Flat list of every subtype label, in hierarchy order.
cell_types = [
    subtype_name
    for subtype_names in CELL_TYPE_HIERARCHY.values()
    for subtype_name in subtype_names
]

In [None]:
# Mean viral transcripts per cell, nested as
# cell set -> virus -> barcode -> top-level cell type.
cell_set_vector_cell_type_transcripts_per_cell = {}

# Candidate replicate barcodes. A barcode is skipped for a given cell set when
# its obs column is missing or sums to zero there.
barcodes = ["BC%i" % i for i in range(1, 5)]

for cell_set_name in CELL_SET_NAMES:

    # Only per-cell metadata is read below, so slice the obs table directly
    # instead of copying the whole AnnData object.
    cell_set_obs = adata.obs[adata.obs["Cell Set"] == cell_set_name]

    # The membership masks depend only on the cell set, so build them once
    # here rather than once per virus/barcode combination. Plain boolean
    # arrays are kept so masking below is purely positional.
    cell_type_masks = {
        cell_type: cell_set_obs["Cell Type"].isin(list(cell_subtypes.keys())).values
        for cell_type, cell_subtypes in CELL_TYPE_HIERARCHY.items()
    }

    cell_set_rates = {}
    cell_set_vector_cell_type_transcripts_per_cell[cell_set_name] = cell_set_rates

    for virus_name in VIRUS_NAMES:

        virus_rates = {}
        cell_set_rates[virus_name] = virus_rates

        for barcode in barcodes:

            column_name = "%s mNeonGreen NLS %s" % (virus_name, barcode)

            if column_name not in cell_set_obs.columns:
                continue

            barcode_counts = cell_set_obs[column_name].values

            # A barcode with no transcripts in this cell set is skipped.
            if barcode_counts.sum() == 0:
                continue

            barcode_rates = {}
            virus_rates[barcode] = barcode_rates

            for cell_type, cell_type_mask in cell_type_masks.items():

                num_cells = cell_type_mask.sum()

                if num_cells == 0:
                    # No cells of this type in the cell set: the rate is
                    # undefined. Store NaN explicitly instead of dividing by
                    # zero.
                    barcode_rates[cell_type] = numpy.nan
                    continue

                virus_transcript_counts = barcode_counts[cell_type_mask].sum()

                barcode_rates[cell_type] = virus_transcript_counts / num_cells

In [None]:
# Show the nested results (cell set -> virus -> barcode -> cell type) as the cell output.
cell_set_vector_cell_type_transcripts_per_cell

In [None]:
# Order in which the top-level cell types are laid out along the x axis.
cell_types_list = ["Astrocytes", "Vascular Cells", "Oligodendrocytes", "Immune Cells", "Neurons"]

traces = []

# Largest finite rate seen across all traces; read again further down when the
# y axis ticks and divider lines are built.
y_max = 0

# Each cell type owns a contiguous run of x slots, one per (cell set, virus) pair.
num_viruses = len(VIRUS_NAMES)
slots_per_cell_type = len(CELL_SET_NAMES) * num_viruses

for virus_index, virus in enumerate(VIRUS_NAMES):

    x_values = []
    y_values = []

    for cell_type_index, cell_type in enumerate(cell_types_list):

        for cell_set_index, cell_set_name in enumerate(CELL_SET_NAMES):

            x_value = cell_type_index * slots_per_cell_type + cell_set_index * num_viruses + virus_index

            barcode_rates = cell_set_vector_cell_type_transcripts_per_cell[cell_set_name][virus]

            # One marker per replicate barcode, all sharing the same x slot.
            for vector in barcode_rates:
                rate = barcode_rates[vector][cell_type]

                y_values.append(rate)
                x_values.append(x_value)

    trace = graph_objects.Scatter(
        x=x_values,
        y=y_values,
        name=virus,
        mode="markers"
    )

    # A bare max() raises on an empty list and gives order-dependent results
    # when NaN is present, so only finite rates update the running maximum.
    finite_rates = [value for value in y_values if numpy.isfinite(value)]

    if finite_rates:
        y_max = max(y_max, max(finite_rates))

    traces.append(trace)
# Figure layout: 800x600 canvas with transparent backgrounds. The y ticks are
# whole numbers derived from y_max before it is scaled below.
layout = {
    "height": 600,
    "width": 800,
    "plot_bgcolor": "rgba(255, 255, 255, 0)",
    "paper_bgcolor": "rgba(255, 255, 255, 0)",
    "yaxis": {
        "title": {
            "text": "Viral Transcripts/Cell",
        },
        "tickvals": list(range(int(numpy.ceil(y_max)+1))),
        "gridcolor": "rgba(0, 0, 0, 0.25)",
    },
    "xaxis": {
        # No tick labels on the x axis; labels are drawn as annotations instead.
        "tickvals": [],
        "ticktext": [],
    },
    "title": {
        "text": "Replicate barcodes have variable transcript abundance",
    },
}

# Scale y_max by 10%; the scaled value is the top of the dashed divider lines.
y_max *= 1.1

figure = graph_objects.Figure(data=traces, layout=layout)

# Every cell type spans `group_width` consecutive x slots, split into one
# `virus_count`-wide run per cell set.
virus_count = len(VIRUS_NAMES)
group_width = len(CELL_SET_NAMES) * virus_count

for group_index, group_label in enumerate(cell_types_list):

    group_start = group_index * group_width

    # Cell type name, rotated and centred under its group of slots.
    figure.add_annotation(
        x=group_start + (group_width - 1) / 2,
        y=0,
        text=group_label,
        showarrow=False,
        yanchor="top",
        yshift=-50,
        textangle=90
    )

    # Dashed divider halfway between the previous group's last slot and this
    # group's first slot.
    if group_index > 0:
        divider_x = group_start - 0.5

        figure.add_shape(
            type="line",
            x0=divider_x, y0=0, x1=divider_x, y1=y_max,
            line=dict(
                width=2,
                dash="dash",
            )
        )

    for set_index, set_name in enumerate(CELL_SET_NAMES):

        set_start = group_start + set_index * virus_count

        # Short label taken from the last "_"-separated part of the cell set
        # name, with "BC" replaced by "S", centred under that cell set's slots.
        figure.add_annotation(
            x=set_start + (virus_count - 1) / 2,
            y=0,
            text=set_name.split("_")[-1].replace("BC", "S"),
            showarrow=False,
            yanchor="top",
            yshift=-15,
            textangle=0
        )


plotly.iplot(figure)

# Create the output directory first so the export also works on a fresh
# checkout where it does not exist yet.
output_directory = "out"
os.makedirs(output_directory, exist_ok=True)

figure.write_image(os.path.join(output_directory, "barcode_transcript_rate_variability.svg"))