In [None]:
import os
import numpy 

from aavomics import database
import scanpy
import anndata
import pandas

from aavomics import aavomics

from plotly import offline as plotly
from plotly import graph_objects
from plotly.subplots import make_subplots

In [None]:
# Alignment read for the "AAV" feature counts and the per-virus obs columns.
ALIGNMENT_NAME = "cellranger_5.0.1_gex_mm10_2020_A_AAVomics"
# Alignment read for the taxonomy annotations; its cells/genes form the exported matrix.
TAXONOMY_ALIGNMENT_NAME = "cellranger_5.0.1_gex_mm10_2020_A"

# Taxonomy whose labels become the "Cell Type" column; every taxonomy below is also
# used to decide which cells are kept.
CELL_TYPE_TAXONOMY_NAME = "CCN202105060"
TAXONOMY_NAMES = [
    CELL_TYPE_TAXONOMY_NAME,
    "CCN202105041",
    "CCN202105050",
    "CCN202105051",
    "CCN202105070",
]

# Labels treated as "no usable annotation". Missing (NaN) labels are additionally
# excluded by a separate missing-value check when cells are filtered.
CELL_TYPES_TO_FILTER = [
    "Debris",
    "Multiplets",
    "Unknown",
    "",
    "nan",
    numpy.nan,
    "NaN",
]

# Prefixes of per-taxonomy metadata columns, named "<prefix>_<taxonomy>", copied when present.
TAXONOMY_METADATA_PREFIXES = ("p", "c", "X", "Y")

# Output file names (joined onto database.DATA_PATH when written).
ANNDATA_FILE_NAME = "aavomics_mouse_cortex_2021.h5ad"
METADATA_FILE_NAME = "aavomics_cell_metadata.csv"

In [None]:
def _select_annotated_cells(taxonomy_adata):
    """Return a boolean numpy mask of cells with a usable label in at least one taxonomy.

    A label is usable when it is not missing and not listed in CELL_TYPES_TO_FILTER.
    """
    # Builtin `bool`: the `numpy.bool` alias was removed in NumPy 1.24.
    cell_mask = numpy.zeros((taxonomy_adata.shape[0], ), dtype=bool)

    for taxonomy_name in TAXONOMY_NAMES:
        labels = taxonomy_adata.obs[taxonomy_name]
        # notna() drops real missing values; isin() drops the placeholder labels.
        is_usable = labels.notna() & ~labels.isin(CELL_TYPES_TO_FILTER)
        cell_mask |= is_usable.to_numpy(dtype=bool)

    return cell_mask


def _has_virus_reads(cell_set):
    """Return True if any "Virus Transcripts" sequencing library of cell_set has read sets."""
    read_sets = set()
    for sequencing_library in cell_set.sequencing_libraries:
        if sequencing_library.type == "Virus Transcripts":
            read_sets.update(sequencing_library.read_sets)
    return len(read_sets) > 0


def _get_virus_and_vector_names(animal):
    """Return the set of virus/vector names whose obs columns describe transduction.

    Every virus (vector.delivery_vehicle) injected into `animal` contributes its name;
    when a virus was injected with more than one vector, each vector name is added too.
    """
    vector_names_by_virus = {}
    for injection in animal.injections:
        for vector in injection.vector_pool.vectors:
            vector_names_by_virus.setdefault(vector.delivery_vehicle.name, set()).add(vector.name)

    column_names = set()
    for virus_name, vector_names in vector_names_by_virus.items():
        column_names.add(virus_name)
        if len(vector_names) > 1:
            column_names.update(vector_names)
    return column_names


def _copy_obs_column(target_adata, source_adata, source_column, target_column=None):
    """Copy source_adata.obs[source_column] into target_adata.obs for target's cells.

    Values are aligned by barcode (the obs index); target_column defaults to source_column.
    """
    barcodes = target_adata.obs.index.values
    if target_column is None:
        target_column = source_column
    target_adata.obs.loc[barcodes, target_column] = source_adata.obs.loc[barcodes, source_column]


# Build one trimmed AnnData per cell set: keep only cells annotated in at least one
# taxonomy, and carry over cell type, AAV transduction and taxonomy metadata columns.
adatas = []
cell_set_names = []

for cell_set_index, cell_set in enumerate(database.CELL_SETS):

    print("Adding %s" % cell_set.name)

    animal = cell_set.source_tissue.animal

    # NOTE: `adata` is deliberately module-level; a later cell reads its var["Gene Name"].
    adata = anndata.read_h5ad(cell_set.get_anndata_file_path(alignment_name=ALIGNMENT_NAME))
    taxonomy_adata = anndata.read_h5ad(cell_set.get_anndata_file_path(alignment_name=TAXONOMY_ALIGNMENT_NAME))

    fresh_adata = taxonomy_adata[_select_annotated_cells(taxonomy_adata)].copy()
    # Start from an empty obs table; only the columns copied below are exported.
    fresh_adata.obs.drop(columns=fresh_adata.obs.columns, inplace=True)

    disambiguate_viruses = _has_virus_reads(cell_set)
    if not disambiguate_viruses:
        print("No amplified reads to disambiguate. Only including overall transduction")

    # Sorted so the obs column order does not depend on the string hash seed.
    virus_and_vector_names = sorted(_get_virus_and_vector_names(animal))

    _copy_obs_column(fresh_adata, taxonomy_adata, CELL_TYPE_TAXONOMY_NAME, "Cell Type")

    # Original barcodes; must be captured before the index is re-suffixed below.
    cell_barcodes = fresh_adata.obs.index.values

    if len(virus_and_vector_names) > 0:
        aav_counts = adata[cell_barcodes, "AAV"].X
        # X may be sparse or dense; flatten the single-feature column to 1-D either way.
        if hasattr(aav_counts, "toarray"):
            aav_counts = aav_counts.toarray()
        fresh_adata.obs.loc[cell_barcodes, "AAV"] = numpy.asarray(aav_counts).ravel()

    if disambiguate_viruses:
        for column_name in virus_and_vector_names:
            _copy_obs_column(fresh_adata, adata, column_name)

    for taxonomy_name in TAXONOMY_NAMES:

        # The primary taxonomy is already stored as "Cell Type".
        if taxonomy_name != CELL_TYPE_TAXONOMY_NAME:
            _copy_obs_column(fresh_adata, taxonomy_adata, taxonomy_name)

        for prefix in TAXONOMY_METADATA_PREFIXES:
            column_name = "%s_%s" % (prefix, taxonomy_name)
            if column_name in taxonomy_adata.obs.columns:
                _copy_obs_column(fresh_adata, taxonomy_adata, column_name)

    # Replace the barcode suffix with the 1-based cell set number so barcodes stay
    # unique after concatenation.
    fresh_adata.obs.index = ["%s-%i" % (barcode.split("-")[0], cell_set_index + 1) for barcode in fresh_adata.obs.index.values]

    # Replace all NaN-like placeholders with numpy NaN
    fresh_adata.obs.replace(["nan", "", "Unknown"], numpy.nan, inplace=True)
    display(fresh_adata.obs.head())

    # Append both together so the two lists stay in step even if an iteration fails.
    cell_set_names.append(cell_set.name)
    adatas.append(fresh_adata)

In [None]:
# Stack the per-cell-set objects (outer join on genes); the "Cell Set" obs column
# records which entry of cell_set_names each cell came from.
combined_adata = anndata.concat(
    adatas,
    join="outer",
    label="Cell Set",
    keys=cell_set_names,
)
# NOTE(review): `adata` here is whatever the loading loop left behind from its final
# iteration; gene names are looked up there by gene index — confirm every cell set
# shares the same gene annotation.
combined_gene_ids = combined_adata.var.index.values
combined_adata.var["Gene Name"] = adata.var["Gene Name"].loc[combined_gene_ids]

In [None]:
# Normalise NaN-like placeholder strings in the combined cell metadata to numpy NaN.
combined_adata.obs.replace(["nan", "", "Unknown"], numpy.nan, inplace=True)

In [None]:
# Persist the combined expression matrix and metadata as an .h5ad file.
h5ad_output_path = os.path.join(database.DATA_PATH, ANNDATA_FILE_NAME)
combined_adata.write_h5ad(h5ad_output_path)

In [None]:
# Export the per-cell metadata table on its own as CSV.
metadata_output_path = os.path.join(database.DATA_PATH, METADATA_FILE_NAME)
combined_adata.obs.to_csv(metadata_output_path)