# Imports

In [6]:
from pprint import pprint as pp
from pycl.pycl import head, bash

# Test package

In [3]:
from NanoCount.NanoCount import NanoCount_main

In [4]:
n = NanoCount_main (alignment_file="./data/cDNA_aligned_reads.bam", verbose=True)

Parse Bam file and filter low quality hits
	Mapped hits:21
	Unmapped hits:2
	Valid best hit:20
	Invalid secondary hit:1
Generate initial read/transcript compatibility index
Start EM abundance estimate
..
Convergence target reached after 2 rounds
Convergence value = 0.0


In [5]:
n = NanoCount_main (alignment_file="/home/aleg/Analyses/Aligned_dRNASeq_example/transcriptome_day5.bam", verbose=True)
display(n.count_df)

Parse Bam file and filter low quality hits
	Mapped hits:118237
	Wrong strand hits:1139
	Unmapped hits:38325
	Valid best hit:72427
	Valid secondary hit:31298
	Invalid secondary hit:12695
	Best hit with low query fraction aligned:1066
Generate initial read/transcript compatibility index
Start EM abundance estimate
........
Convergence target reached after 8 rounds
Convergence value = 0.004801809595549253


Unnamed: 0_level_0,raw,est_count,tpm
transcript_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
ENSDART00000010144|ENSDARG00000002768,3.193560e-02,2.313000e+03,2.313000e+09
ENSDART00000055062|ENSDARG00000037789,2.688224e-02,1.947000e+03,1.947000e+09
ENSDART00000055136|ENSDARG00000037840,1.905367e-02,1.380000e+03,1.380000e+09
ENSDART00000093609|ENSDARG00000063908,1.818383e-02,1.317000e+03,1.317000e+09
ENSDART00000023156|ENSDARG00000020850,1.186022e-02,8.590000e+02,8.590000e+08
ENSDART00000093612|ENSDARG00000063911,8.146133e-03,5.900000e+02,5.900000e+08
ENSDART00000123003|ENSDARG00000067990,7.874959e-03,5.703596e+02,5.703596e+08
ENSDART00000012644|ENSDARG00000017624,7.097575e-03,5.140561e+02,5.140561e+08
ENSDART00000093613|ENSDARG00000063912,6.862082e-03,4.970000e+02,4.970000e+08
ENSDART00000093606|ENSDARG00000063905,6.295995e-03,4.560000e+02,4.560000e+08


# Prototypes

In [182]:
#~~~~~~~~~~~~~~CLASS~~~~~~~~~~~~~~#
class Read (object):
    """
    """

    #~~~~~~~~~~~~~~MAGIC METHODS~~~~~~~~~~~~~~#
    def __init__(self):
        self.hit_list = []

    def add_pysam_hit (self, pysam_aligned_segment, **kwargs):
        self.hit_list.append (Hit (pysam_aligned_segment))
        
    def add_hit (self, hit, **kwargs):
        self.hit_list.append (hit)
    
    @property
    def n_hit (self):
        return len(self.hit_list)
    
    @property
    def primary_hit (self):
        for hit in self.hit_list:
            if not hit.secondary:
                return hit
    
    @property
    def secondary_hit_list (self):
        hit_list = []
        for hit in self.hit_list:
            if hit.secondary:
                hit_list.append (hit)
        return hit_list

    def __repr__(self):
        m = ""
        for r in self.hit_list:
            m +="\t\t{}\n".format(r)
        return (m)

class Hit ():
    """
    Extract relevant fields from the a pysam alignedSegment object
    """
    #~~~~~~~~~~~~~~MAGIC METHODS~~~~~~~~~~~~~~#

    def __init__(self, pysam_aligned_segment):
        """
        """
        self.qname = pysam_aligned_segment.query_name
        self.rname = pysam_aligned_segment.reference_name
        self.qlen = int (pysam_aligned_segment.query_length)
        self.align_len = int (pysam_aligned_segment.query_alignment_length)
        self.align_score = int (pysam_aligned_segment.get_tag("AS"))
        self.secondary = pysam_aligned_segment.is_secondary or pysam_aligned_segment.is_supplementary
        #self.rstart = int (pysam_aligned_segment.reference_start)
        #self.rend = int (pysam_aligned_segment.reference_end)
        self.strand = "-" if pysam_aligned_segment.is_reverse else "+"
        self.rlen = int (pysam_aligned_segment.reference_length)
        #self.qstart = int (pysam_aligned_segment.query_alignment_start)
        #self.qend = int (pysam_aligned_segment.query_alignment_end)
    
    @property
    def fraction_aligned (self):
        if not self.secondary and self.qlen:
            return self.align_len/self.qlen
    
    def __repr__(self):
        return "Query:{} | Reference:{} | Strand:{} | Query len:{} | Alignment len:{} | Align Score:{} | Secondary:{}".format(
            self.qname, self.rname, self.strand, self.qlen, self.align_len, self.align_score, self.secondary)
#         return "Query:{}-{}:{} ({} pb) / Reference:{}-{}:{}({}) ({} pb) / Alignment len:{} / Align Score:{} | Secondary:{}".format(
#             self.qname, self.qstart, self.qend, self.qlen,
#             self.rname, self.rstart, self.rend, self.strand, self.rlen,
#             self.align_len, self.align_score, self.secondary)

In [187]:
#~~~~~~~~~~~~~~IMPORTS~~~~~~~~~~~~~~#
# Standard library imports
from sys import stderr
from collections import Counter, defaultdict, namedtuple
import argparse

# Third party imports
import pysam
import pandas as pd

# Local imports

#~~~~~~~~~~~~~~MAIN FUNCTION~~~~~~~~~~~~~~#
class NanoCount_main ():

    #~~~~~~~~~~~~~~MAGIC METHODS~~~~~~~~~~~~~~#
    def __init__ (self,
        alignment_file,
        min_read_length = 50,
        min_fraction_aligned = 0.5,
        equivalent_threshold = 0.85,
        scoring_value = "alignment_score",
        convergence_target = 0.005,
        verbose = False):
        """
        """
        # Save args in self variables
        self.alignment_file = alignment_file
        self.min_read_length = min_read_length
        self.min_fraction_aligned = min_fraction_aligned
        self.equivalent_threshold = equivalent_threshold
        self.scoring_value = scoring_value
        self.convergence_target = convergence_target
        self.verbose = verbose
        
        # Collect all hits grouped by read name
        if self.verbose:
            stderr.write ("Parse Bam file and filter low quality hits\n")
        self.read_dict = self._parse_bam ()
        
        # Generate compatibility dict grouped by reads
        if self.verbose:
            stderr.write ("Generate initial read/transcript compatibility index\n")
        self.compatibility_dict = self._get_compatibility ()
        
        # Run EM to calculate abundance and update read-transcript compatibility
        if self.verbose:
            stderr.write ("Start EM abundance estimate\n")
        em_round = 0
        #all_abundance = []
        
        while True:
            em_round += 1
            if self.verbose:
                stderr.write (".")
                stderr.flush()
            
            # Calculate abundance from compatibility assignments
            self.abundance_dict = self._calculate_abundance (em_round)
            
            # Save abundance values to track evolution along rounds
            #s = pd.Series (self.abundance_dict, name=em_round)
            #s.sort_index (inplace=True)
            #all_abundance.append (s)
            
            # Trigger stop if convergence reached
            if self.convergence <= self.convergence_target:
                if self.verbose: 
                    stderr.write ("\nConvergence target reached after {} rounds\n".format(em_round))
                    stderr.write ("Convergence value = {}\n".format(self.convergence))
                break
            
            # Failsafe if convergence not reached
            if em_round == 100:
                if self.verbose: 
                    stderr.write ("\nCannot reach convergence after 100 rounds\n")
                    stderr.write ("Convergence value = {}\n".format(self.convergence))
                break
            
            # Update compatibility assignments
            self.compatibility_dict = self._update_compatibility ()
        
        # Collect all abundance values
        #self.all_abundance_df = pd.concat(all_abundance, axis=1) 
        #self.all_abundance_df.sort_values (by=em_round, inplace=True, ascending=False)
    
    #~~~~~~~~~~~~~~PROPERTY METHODS~~~~~~~~~~~~~~#
    @property
    def count_df (self):
        """ Transform abundance dict to a Dataframe contaning the estimated count and TPM per transcripts """
        df = pd.DataFrame (self.abundance_dict.most_common(), columns=["transcript_name","raw"])
        df.set_index("transcript_name", inplace=True, drop=True)
        df["est_count"] = df["raw"]*len(self.read_dict)
        df["tpm"] = df["est_count"] * 1000000
        return df
    
    #~~~~~~~~~~~~~~PUBLIC METHODS~~~~~~~~~~~~~~#
    def write_count_file (self, count_file):
        """
        """
        self.count_df.to_csv (count_file, sep="\t")
    
    #~~~~~~~~~~~~~~PRIVATE METHODS~~~~~~~~~~~~~~#
    def _parse_bam (self):
        """
        """
        c = Counter()
        
       # Parse bam files 
        read_dict = defaultdict (Read)
        with pysam.AlignmentFile (self.alignment_file) as bam:
            for hit in bam:
                if hit.is_unmapped:
                    c["Unmapped hits"] +=1
                elif hit.is_reverse:
                    c["Wrong strand hits"] +=1
                else:
                    c["Mapped hits"] +=1
                    read_dict [hit.query_name].add_pysam_hit (hit)
        
        # Filter hits
        filtered_read_dict = defaultdict (Read)        
        for query_name, read in read_dict.items ():
            # Check if best hit is valid
            best_hit = read.primary_hit
            # In case the primary hit was removed by filters
            if best_hit:          
                if self.min_read_length and best_hit.qlen < self.min_read_length:
                    c["Best hit too short"] +=1
                elif self.min_fraction_aligned and best_hit.fraction_aligned < self.min_fraction_aligned:
                    c["Low aligned fraction"] +=1
                else:                
                    filtered_read_dict [query_name].add_hit (best_hit)
                    c["Valid best hit"] +=1
                    for hit in read.secondary_hit_list:

                        # Filter out secondary hits based on minimap alignment score
                        if self.scoring_value == "alignment_score" and hit.align_score/best_hit.align_score < self.equivalent_threshold:
                            c["Invalid secondary hit"] += 1

                        # Filter out secondary hits based on minimap alignment length
                        elif self.scoring_value == "alignment_length" and hit.align_len/best_hit.align_len < self.equivalent_threshold:
                            c["Invalid secondary hit"] += 1

                        # Select valid secondary hits
                        else:
                            c["Valid secondary hit"] += 1
                            filtered_read_dict [query_name].add_hit (hit)
        
        # Write filtered reads counters
        if self.verbose:              
            for i, j in c.items():
                stderr.write ("\t{}:{}\n".format(i,j))
        
        return filtered_read_dict
    
    def _get_compatibility (self):
        """
        """
        compatibility_dict = defaultdict(dict)
        for read_name, read in self.read_dict.items ():
            for hit in read.hit_list:
                compatibility_dict[read_name][hit.rname] = score=1.0/read.n_hit
            
        return compatibility_dict
    
    def _calculate_abundance (self, em_round):
        """
        Calculate the abundance of the transcript set based on read-transcript compatibilities
        """
        abundance_dict = Counter()
        total = 0
        convergence = 0

        for read_name, comp in self.compatibility_dict.items ():
            for ref_name, score in comp.items():
                abundance_dict [ref_name] += score
                total += score

        for ref_name in abundance_dict.keys():
            abundance_dict [ref_name] = abundance_dict[ref_name] / total
            
            if em_round > 1:
                convergence += abs (self.abundance_dict [ref_name] - abundance_dict [ref_name])
        
        if em_round == 1:
            self.convergence = 1
        else:
            self.convergence = convergence

        return abundance_dict

    def _update_compatibility (self):
        """
        Update read-transcript compatibility based on transcript abundances
        """
        compatibility_dict = defaultdict (dict)

        for read_name, comp in self.compatibility_dict.items ():
            total=0
            for ref_name in comp.keys ():
                total += self.abundance_dict [ref_name]
        
            for ref_name in comp.keys ():
                compatibility_dict[read_name][ref_name] = self.abundance_dict [ref_name] / total
        
        return compatibility_dict

# Prototype test

In [188]:
n = NanoCount_main (alignment_file="./data/cDNA_aligned_reads.bam", min_read_length=200, min_fraction_aligned=0.5, equivalent_threshold=0.80, verbose=True)
n.write_count_file ("./data/count_file.tsv")
head  ("./data/count_file.tsv")

Parse Bam file and filter low quality hits
	Mapped hits:21
	Unmapped hits:2
	Valid best hit:20
	Valid secondary hit:1
Generate initial read/transcript compatibility index
Start EM abundance estimate
..

transcript_name	raw	est_count	tpm
YHR174W	0.15	3.0	3000000.0
YGR192C	0.1	2.0	2000000.0
YGR240C	0.05	1.0	1000000.0
YCR030C	0.05	1.0	1000000.0
YLR441C	0.05	1.0	1000000.0
YDR500C	0.05	1.0	1000000.0
YDR224C	0.05	1.0	1000000.0
YIL117C	0.05	1.0	1000000.0
YDL145C	0.05	1.0	1000000.0



Convergence target reached after 2 rounds
Convergence value = 0.0


In [189]:
n = NanoCount_main (alignment_file="/home/aleg/Analyses/Nanopore_yeast/minimap/cDNA_aligned_reads.bam", min_read_length=300, min_fraction_aligned=0.5, equivalent_threshold=0.80, verbose=True)
display(n.count_df)

Parse Bam file and filter low quality hits
	Mapped hits:13441
	Wrong strand hits:479
	Unmapped hits:794
	Low aligned fraction:482
	Valid best hit:10691
	Valid secondary hit:1475
	Invalid secondary hit:133
	Best hit too short:268
Generate initial read/transcript compatibility index
Start EM abundance estimate
....
Convergence target reached after 4 rounds
Convergence value = 0.0029325191073464093


Unnamed: 0_level_0,raw,est_count,tpm
transcript_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
YHR174W,2.795644e-01,2988.822528,2.988823e+09
YGR192C,3.470628e-02,371.044802,3.710448e+08
YLR110C,1.786549e-02,191.000000,1.910000e+08
YOL086C,1.440464e-02,154.000000,1.540000e+08
YKL060C,1.122439e-02,120.000000,1.200000e+08
YPR080W,9.353662e-03,100.000000,1.000000e+08
YBR118W,9.353662e-03,100.000000,1.000000e+08
YLR044C,8.979515e-03,96.000000,9.600000e+07
YKL152C,8.044149e-03,86.000000,8.600000e+07
YCR012W,7.857076e-03,84.000000,8.400000e+07


In [190]:
n = NanoCount_main (alignment_file="/home/aleg/Analyses/Aligned_dRNASeq_example/transcriptome_day5.bam", verbose=True)
display(n.count_df)

Parse Bam file and filter low quality hits
	Mapped hits:118237
	Wrong strand hits:1139
	Unmapped hits:38325
	Valid best hit:72427
	Valid secondary hit:34931
	Invalid secondary hit:9062
	Low aligned fraction:1066
Generate initial read/transcript compatibility index
Start EM abundance estimate
.........
Convergence target reached after 9 rounds
Convergence value = 0.004838874671979949


Unnamed: 0_level_0,raw,est_count,tpm
transcript_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
ENSDART00000010144|ENSDARG00000002768,3.193560e-02,2.313000e+03,2.313000e+09
ENSDART00000055062|ENSDARG00000037789,2.688224e-02,1.947000e+03,1.947000e+09
ENSDART00000055136|ENSDARG00000037840,1.905367e-02,1.380000e+03,1.380000e+09
ENSDART00000093609|ENSDARG00000063908,1.818383e-02,1.317000e+03,1.317000e+09
ENSDART00000023156|ENSDARG00000020850,1.186022e-02,8.590000e+02,8.590000e+08
ENSDART00000093612|ENSDARG00000063911,8.146133e-03,5.900000e+02,5.900000e+08
ENSDART00000123003|ENSDARG00000067990,7.449826e-03,5.395685e+02,5.395685e+08
ENSDART00000093613|ENSDARG00000063912,6.862082e-03,4.970000e+02,4.970000e+08
ENSDART00000012644|ENSDARG00000017624,6.803734e-03,4.927741e+02,4.927741e+08
ENSDART00000093606|ENSDARG00000063905,6.295995e-03,4.560000e+02,4.560000e+08
