In [None]:
import os
import pandas

from aavomics import database
import anndata
import pysam

from skbio.alignment import StripedSmithWaterman as SSW

In [None]:
# Name of the transcriptome alignment run. It selects both the per-cell-set
# anndata file (cell_set.get_anndata_file_path) and the directory
# transcriptome/transcripts/<ALIGNMENT_NAME>/ holding possorted_genome_bam.bam.
ALIGNMENT_NAME = "cellranger_5.0.1_gex_mm10_2020_A_AAVomics"

In [None]:
# For every cortex cell set with injected vectors:
#   1. parse the per-template viral-transcript alignment CSVs and assign each
#      amplified (cell barcode + UMI) read group to a vector,
#   2. walk the transcriptome BAM reads on the "AAV" reference and keep the
#      barcode+UMIs that also appear in the amplified data,
#   3. write per-cell counts for each vector and each virus into adata.obs and
#      save the anndata file.
# stats_dict collects per-cell-set BAM read-filtering counts.
stats_dict = {}

for cell_set in database.CELL_SETS:

    tissue_sample = cell_set.source_tissue
    dissociation_run = tissue_sample.dissociation_run
    animal = tissue_sample.animal

    if tissue_sample.region != "Cortex":
        raise ValueError("Should only be cortex samples!")

    if animal is None:
        continue

    # NOTE(review): not used further in this cell; the cell barcode is always
    # taken as the first 16 characters of the barcode+UMI key below.
    if dissociation_run.protocol_version == 2:
        cell_barcode_UMI_length = 26
    else:
        cell_barcode_UMI_length = 28

    if animal.injections is None:
        print("%s has no injections, skipping." % cell_set.name)
        continue

    injections = animal.injections

    # Group the injected vectors by the template (cargo) they carry. A vector
    # delivered by several injections is listed only once. Otherwise a
    # single-vector template would look barcoded, and its reads would be
    # counted twice in the virus totals.
    template_vectors = {}

    for injection in injections:
        for vector in injection.vector_pool.vectors:
            cargo_vectors = template_vectors.setdefault(vector.cargo, [])
            if vector not in cargo_vectors:
                cargo_vectors.append(vector)

    if len(template_vectors) == 0:
        print("%s has no delivered genetic cargo, skipping." % cell_set.name)
        continue

    viruses = {vector.delivery_vehicle for vectors in template_vectors.values() for vector in vectors}

    adata_file_path = cell_set.get_anndata_file_path(alignment_name=ALIGNMENT_NAME)

    if not os.path.exists(adata_file_path):
        print("%s is missing anndata file expected at %s" % (cell_set.name, adata_file_path))
        continue

    print("Processing %s now" % cell_set.name)

    adata = anndata.read_h5ad(adata_file_path)

    alignment_directory = os.path.join(database.DATA_PATH, "cell_sets", cell_set.name, "virus", "alignments")

    read_sets = set()

    for sequencing_library in cell_set.sequencing_libraries:
        if sequencing_library.type == "Virus Transcripts":
            read_sets.update(sequencing_library.read_sets)

    # template -> {cell barcode + UMI: [(score, max_score, count, vector), ...]}
    template_cell_barcode_alignments = {}
    # template -> summed read count of the alignments passing the score filter
    template_num_reads = {}

    for template in template_vectors:

        vectors = template_vectors[template]

        # Vectors sharing a template are told apart by their barcode insertion.
        has_barcodes = len(vectors) > 1
        barcode_length = 0
        vector_aligners = []

        if has_barcodes:

            for vector in vectors:

                if not vector.barcode:
                    continue

                barcode = vector.barcode
                barcode_location = int(vector.barcode_location)
                # Length of the last barcoded vector (never of a barcode-less one).
                barcode_length = len(barcode)

                vector_aligners.append([vector, SSW(barcode, suppress_sequences=False, score_only=True)])

                # Bases next to the insertion site may repeat the ends of the
                # barcode. The same insertion can then be reported at a shifted
                # position, i.e. as a rotation of the barcode. Add one aligner
                # per such rotation.
                for i in range(1, len(barcode) + 1):
                    if template.sequence[barcode_location - i] != barcode[-i]:
                        break
                    # Rotate right by i.
                    permuted_barcode = barcode[-i:] + barcode[:-i]
                    vector_aligners.append([vector, SSW(permuted_barcode, suppress_sequences=False, score_only=True)])

                for i in range(len(barcode)):
                    if template.sequence[barcode_location + i] != barcode[i]:
                        break
                    # Rotate left by i + 1.
                    permuted_barcode = barcode[i + 1:] + barcode[:i + 1]
                    vector_aligners.append([vector, SSW(permuted_barcode, suppress_sequences=False, score_only=True)])

        template_cell_barcode_alignments[template] = {}
        template_num_reads[template] = 0

        # insertion sequence -> (best vector, best score), computed once per
        # template. (None, None) marks insertions not assignable to one vector.
        insertion_best_vectors = {}
        insertion_counts = {}

        for read_set in read_sets:

            alignment_file_name = "%s_%s.csv" % (read_set.name, template.name)
            alignment_file_path = os.path.join(alignment_directory, alignment_file_name)

            print("Parsing alignment file %s" % read_set.name)

            with open(alignment_file_path, "r") as alignment_file:

                for i, line in enumerate(alignment_file, start=1):

                    if i % 100000 == 0:
                        print(i)

                    line_values = line.split(",")
                    cell_barcode_UMI = line_values[0]
                    score = int(line_values[2])
                    max_score = int(line_values[3])
                    insertion = line_values[-2]
                    count = int(line_values[-1])

                    # NOTE(review): together with the "+ 3 + 2 * best_score"
                    # below, this presumably swaps the insertion's contribution
                    # to the template score for the barcode match score; confirm.
                    if has_barcodes:
                        max_score -= 3 + 4*len(insertion)

                    if score/max_score < 0.9:
                        continue

                    template_num_reads[template] += count

                    if has_barcodes:

                        if len(insertion) == 0:
                            continue

                        if insertion not in insertion_best_vectors:

                            insertion_counts[insertion] = 0

                            # Best score of each vector over all its barcode
                            # rotations. Rotations of one barcode scoring
                            # equally are therefore not treated as a tie
                            # between vectors.
                            vector_best_scores = {}
                            for vector, vector_aligner in vector_aligners:
                                vector_score = vector_aligner(insertion)["optimal_alignment_score"]
                                vector_best_scores[vector] = max(vector_score, vector_best_scores.get(vector, vector_score))

                            sorted_scores = sorted(vector_best_scores.items(), key=lambda x: x[1], reverse=True)

                            if len(sorted_scores) == 0:
                                # No vector of this template has a barcode.
                                insertion_best_vectors[insertion] = (None, None)
                            elif len(sorted_scores) > 1 and sorted_scores[0][1] == sorted_scores[1][1]:
                                # Equally good match to two different vectors.
                                insertion_best_vectors[insertion] = (None, None)
                            elif sorted_scores[0][1]/(barcode_length*2) < 0.9:
                                # NOTE(review): barcode_length*2 presumably assumes
                                # SSW's default match score of 2 per base.
                                insertion_best_vectors[insertion] = (None, None)
                            else:
                                insertion_best_vectors[insertion] = (sorted_scores[0][0], sorted_scores[0][1])

                        best_vector, best_score = insertion_best_vectors[insertion]
                        insertion_counts[insertion] += count

                        if best_vector is None:
                            continue

                        score += 3 + 2 * best_score
                    else:
                        best_vector = vectors[0]

                    template_cell_barcode_alignments[template].setdefault(cell_barcode_UMI, []).append(
                        (score, max_score, count, best_vector))

    # Read totals per vector and per template, and the read count of every
    # (barcode+UMI, vector) pair, over all templates. Not used further in this
    # cell.
    total_vector_counts = {vector: 0 for vectors in template_vectors.values() for vector in vectors}
    total_template_counts = {template: 0 for template in template_vectors}
    vector_counts_array = []

    for template, cell_barcode_UMI_alignments in template_cell_barcode_alignments.items():

        for score_tuples in cell_barcode_UMI_alignments.values():

            cell_barcode_UMI_vector_counts = {}

            for score_tuple in score_tuples:

                vector = score_tuple[3]
                read_count = score_tuple[2]

                total_vector_counts[vector] += read_count
                total_template_counts[template] += read_count
                cell_barcode_UMI_vector_counts[vector] = cell_barcode_UMI_vector_counts.get(vector, 0) + read_count

            vector_counts_array.extend(cell_barcode_UMI_vector_counts.values())

    # Reads per vector (same quantity as total_vector_counts).
    vector_total_counts = total_vector_counts

    print("Done parsing alignment files")

    BAM_file_path = os.path.join(
            database.DATA_PATH, "cell_sets", cell_set.name, "transcriptome", "transcripts", ALIGNMENT_NAME, "possorted_genome_bam.bam")

    # template -> {barcode+UMI: number of BAM reads kept}
    template_cell_barcode_UMI_counts = {template: {} for template in template_vectors}
    # template -> {barcode+UMI: vector assigned from the amplified reads}
    template_cell_barcode_UMI_vectors = {template: {} for template in template_vectors}

    num_reads = 0
    num_reads_passing_filter = 0
    num_reads_matching_template = 0
    num_reads_with_cell_barcode_UMI_in_amplified = 0
    num_reads_unique_to_template = 0
    num_reads_unique_vector = 0

    with pysam.AlignmentFile(BAM_file_path, "rb") as samfile:

        for read in samfile.fetch("AAV"):

            num_reads += 1

            if read.mapping_quality != 255:
                continue

            if not read.has_tag("CB") or not read.has_tag("UB"):
                continue

            # NOTE(review): "xf" bit flags from the aligner; bit 1 is required
            # and bits 2 and 32 are rejected. Confirm meaning in the Cell
            # Ranger BAM documentation.
            filter_bits = read.get_tag("xf")

            if not (filter_bits & 1) or (filter_bits & 2) or (filter_bits & 32):
                continue

            if not read.has_tag("TX"):
                continue

            num_reads_passing_filter += 1

            matching_templates = {x.split(",")[0] for x in read.get_tag("TX").split(";")}

            if not any(template.name in matching_templates for template in template_vectors):
                continue

            num_reads_matching_template += 1

            cell_barcode_UMI = read.get_tag("CB")[0:16] + read.get_tag("UB")

            # For each template whose amplified data contains this barcode+UMI:
            # the fraction of that template's amplified reads it carries.
            template_counts = {
                template: sum(x[2] for x in template_cell_barcode_alignments[template][cell_barcode_UMI])
                / template_num_reads[template]
                for template in template_vectors
                if cell_barcode_UMI in template_cell_barcode_alignments[template]
            }

            if not template_counts:
                continue

            num_reads_with_cell_barcode_UMI_in_amplified += 1

            sorted_template_counts = sorted(template_counts.items(), key=lambda x: x[1], reverse=True)

            # Equally supported by two templates: skip the read.
            if len(sorted_template_counts) > 1 and sorted_template_counts[0][1] == sorted_template_counts[1][1]:
                continue

            template = sorted_template_counts[0][0]

            num_reads_unique_to_template += 1

            UMI_counts = template_cell_barcode_UMI_counts[template]

            if cell_barcode_UMI in UMI_counts:
                UMI_counts[cell_barcode_UMI] += 1
                num_reads_unique_vector += 1
                continue

            # First kept read for this barcode+UMI: assign the vector with the
            # most amplified reads.
            vector_counts = {}

            for score_tuple in template_cell_barcode_alignments[template][cell_barcode_UMI]:
                vector_counts[score_tuple[3]] = vector_counts.get(score_tuple[3], 0) + score_tuple[2]

            sorted_vector_counts = sorted(vector_counts.items(), key=lambda x: x[1], reverse=True)

            # Tied between vectors: leave unassigned. The next read with this
            # barcode+UMI is evaluated again.
            if len(sorted_vector_counts) > 1 and sorted_vector_counts[0][1] == sorted_vector_counts[1][1]:
                continue

            template_cell_barcode_UMI_vectors[template][cell_barcode_UMI] = sorted_vector_counts[0][0]
            UMI_counts[cell_barcode_UMI] = 1
            num_reads_unique_vector += 1

    stats_dict[cell_set.name] = {
        "num_reads": num_reads,
        "num_reads_passing_filter": num_reads_passing_filter,
        "num_reads_matching_template": num_reads_matching_template,
        "num_reads_with_cell_barcode_UMI_in_amplified": num_reads_with_cell_barcode_UMI_in_amplified,
        "num_reads_unique_to_template": num_reads_unique_to_template,
        "num_reads_unique_vector": num_reads_unique_vector
    }

    # vector -> {cell barcode in adata.obs: number of barcode+UMIs assigned to it}
    vector_cell_counts = {}

    for template, vectors in template_vectors.items():

        for vector in vectors:

            cell_barcode_counts = {x: 0 for x in adata.obs.index.values}

            for cell_barcode_UMI in template_cell_barcode_UMI_counts[template]:

                if vector != template_cell_barcode_UMI_vectors[template][cell_barcode_UMI]:
                    continue

                # Key format of adata.obs: 16-base cell barcode plus "-1".
                cell_barcode = cell_barcode_UMI[0:16] + "-1"

                if cell_barcode in cell_barcode_counts:
                    cell_barcode_counts[cell_barcode] += 1

            vector_cell_counts[vector] = cell_barcode_counts

            if vector.name in adata.obs.columns:
                adata.obs.drop(vector.name, axis=1, inplace=True)

            adata.obs[vector.name] = pandas.Series(cell_barcode_counts)

    # Per-virus counts are the sum of the counts of the vectors it delivers.
    for virus in viruses:

        cell_barcode_counts = {x: 0 for x in adata.obs.index.values}

        for vector, counts_by_cell in vector_cell_counts.items():

            if vector.delivery_vehicle != virus:
                continue

            for cell_barcode, cell_count in counts_by_cell.items():
                cell_barcode_counts[cell_barcode] += cell_count

        if virus.name in adata.obs.columns:
            adata.obs.drop(virus.name, axis=1, inplace=True)

        adata.obs[virus.name] = pandas.Series(cell_barcode_counts)

    adata.write_h5ad(adata_file_path)

    display(adata.obs.sum())