In [1]:
import bz2
import pickle
import sys

import numpy as np
from metaphlan.metaphlan import map2bbh

# Load the bzip2-compressed, pickled marker database into memory.
# NOTE(review): pickle.load executes arbitrary code from the file; only use trusted databases.
with bz2.open('/ssd1/wy/workspace/MetaPhlAn/test_data/databases/mpa_vJan21_TOY_CHOCOPhlAnSGB_202103.pkl', 'rb') as db_handle:
    mpa_pkl = pickle.load(db_handle)

In [2]:
# Global switch read by TaxClade and TaxTree: when True the 't__' level is treated
# as the terminal level (instead of 's__') and up to 8 taxonomy ids are kept per lineage.
SGB_ANALYSIS = True
# One-letter prefixes of the taxonomic levels, in lineage order; the code moves to the
# parent/child level by stepping one position left/right in this string.
tax_units = "kpcofgst"

In [8]:

def plain_read_and_split(ofn):
    """Lazily yield the tab-separated fields of every line in ``ofn``.

    Each line is stripped of surrounding whitespace before being split on tabs.
    ``ofn`` is any iterable of strings (for example an open text file).
    """
    return (
        raw_line.strip().split('\t')
        for raw_line in ofn
    )
# Parse the bowtie2out file: '#'-prefixed header lines carry the read count and the
# average read length, every other line maps a read name to a marker.
reads2markers = {}
n_metagenome_reads = None
avg_read_length = 1 #Set to 1 if it is not calculated from read_fastx
# The context manager guarantees the file handle is closed once parsing is done.
with open("/ssd1/wy/workspace/MetaPhlAn/test_data/01.fasta.bowtie2out.txt") as bowtie2out:
    # Each line is expected to have exactly two tab-separated fields.
    for r, c in plain_read_and_split(bowtie2out):
        if r.startswith('#') and 'nreads' in r:
            n_metagenome_reads = int(c)
        # elif (not a second `if`): otherwise the '#nreads' header line would fall
        # through to the else branch and be stored as if it were a read.
        elif r.startswith('#') and 'avg_read_length' in r:
            avg_read_length = float(c)
        else:
            reads2markers[r] = c

In [9]:
# Display the average read length parsed from the bowtie2out header.
avg_read_length

93.71439067912249

In [1]:
class TaxClade:
    """One node (clade) of the taxonomic tree.

    Every clade stores its children, its parent (``father``), the number of reads
    mapped to each of its markers and the abundance estimated from them.

    Settings shared by all clades live in class attributes. They are filled in by
    ``TaxTree`` (``markers2lens``, ``markers2exts``, ``taxa2clades``),
    ``TaxTree.set_min_cu_len`` and ``TaxTree.set_stat``; abundances cannot be
    computed before ``stat``, ``quantile`` and ``perc_nonzero`` have been set.
    """
    # Clades with children whose summed marker length is below this value take the
    # summed abundance of their children (see compute_abundance).
    min_cu_len = -1
    # marker name -> marker length
    markers2lens = None
    # marker name -> external taxa for that marker; read by compute_abundance.
    # Declared here so the attribute always exists, TaxTree replaces it with a dict.
    markers2exts = None
    # taxon name -> TaxClade; read by compute_abundance. Replaced by TaxTree as well.
    taxa2clades = None
    # Name of the statistic selected in compute_abundance ('avg_g', 'tavg_g', 'med', ...)
    stat = None
    # Fraction of non-zero markers above which an external clade disqualifies a marker
    perc_nonzero = None
    # Fraction of markers trimmed/winsorized on each side by the t*/w* statistics
    quantile = None
    # When True the external-clade marker disqualification is skipped
    avoid_disqm = False
    avg_read_length = 1

    def __init__( self, name, tax_id, uncl = False):
        """Create a detached clade called ``name`` with taxonomy id ``tax_id``."""
        self.children, self.markers2nreads = {}, {}
        self.name, self.father = name, None
        self.uncl, self.subcl_uncl = uncl, False
        # abundance stays None until compute_abundance has run (used as memo flag)
        self.abundance, self.uncl_abundance = None, 0
        self.nreads, self.uncl_nreads = 0, 0
        self.tax_id = tax_id

    def add_child( self, name, tax_id ):
        """Create a child clade, attach it under ``name`` and return it."""
        new_clade = TaxClade( name, tax_id )
        self.children[name] = new_clade
        new_clade.father = self
        return new_clade


    def get_terminals( self ):
        """Return the list of leaf clades below this clade (itself if it is a leaf)."""
        terms = []
        if not self.children:
            return [self]
        for c in self.children.values():
            terms += c.get_terminals()
        return terms

    def get_full_taxids( self ):
        """Return the '|'-joined taxonomy ids from the top-level ancestor to this clade.

        The root of the tree is excluded. A missing id (``None``, or a falsy id on
        this clade) contributes an empty field instead of making the join fail.
        """
        lineage = [self.tax_id if self.tax_id else '']
        cl = self.father
        while cl:
            # Ancestors built from taxonomy entries without ids carry tax_id=None.
            lineage.append('' if cl.tax_id is None else cl.tax_id)
            cl = cl.father
        lineage.reverse()
        # lineage[0] is the outermost ancestor (the tree root), which is not reported.
        return "|".join(str(t) for t in lineage[1:])

    def get_full_name( self ):
        """Return the '|'-joined clade names from the top-level ancestor to this clade.

        The first element of the lineage (the tree root) is excluded.
        """
        fullname = [self.name]
        cl = self.father
        while cl:
            fullname = [cl.name] + fullname
            cl = cl.father
        return "|".join(fullname[1:])

    def get_normalized_counts( self ):
        """Return ``(marker, reads * 1000 / (|marker_len - avg_read_length| + 1))`` pairs."""
        return [(m,float(n)*1000.0/(np.absolute(self.markers2lens[m] - self.avg_read_length) +1) )
                    for m,n in self.markers2nreads.items()]

    def compute_mapped_reads( self ):
        """Return the read count of this clade, summing it up from the children if needed.

        A count that is already non-zero, or any clade at the terminal level
        ('t__' when SGB_ANALYSIS is set, 's__' otherwise), is returned as is.
        Otherwise the children's counts are accumulated into ``self.nreads``.
        """
        tax_level = 't__' if SGB_ANALYSIS else 's__'
        if self.nreads != 0 or self.name.startswith(tax_level):
            return self.nreads
        for c in self.children.values():
            self.nreads += c.compute_mapped_reads()
        return self.nreads

    def compute_abundance( self ):
        """Estimate, store and return the abundance of this clade (memoized).

        Children are evaluated first. The clade's own estimate is derived from its
        marker read counts with the statistic named by ``self.stat``; the final
        value is never lower than the summed abundance of the children.
        Requires ``stat``, ``quantile`` and (unless ``avoid_disqm``) ``perc_nonzero``,
        ``markers2exts`` and ``taxa2clades`` to have been configured.
        """
        if self.abundance is not None: return self.abundance

        sum_ab = sum([c.compute_abundance() for c in self.children.values()])

        # rat_nreads = sorted([(self.markers2lens[marker], n_reads)
        #                             for marker,n_reads in self.markers2nreads.items()],
        #                                     key = lambda x: x[1])

        # rat_nreads: (marker length, reads) of the markers kept for the estimate;
        # removed: the same pairs for markers disqualified through an external clade.
        rat_nreads, removed = [], []
        for marker, n_reads in sorted(self.markers2nreads.items(),key=lambda x:x[0]):
            misidentified = False

            if not self.avoid_disqm:
                for ext in self.markers2exts[marker]:
                    ext_clade = self.taxa2clades[ext]
                    m2nr = ext_clade.markers2nreads

                    # Follow single-child chains down and use the markers of the last clade.
                    tocladetmp = ext_clade
                    while len(tocladetmp.children) == 1:
                        tocladetmp = list(tocladetmp.children.values())[0]
                        m2nr = tocladetmp.markers2nreads

                    # The marker is disqualified when the share of that clade's markers
                    # with at least one read exceeds perc_nonzero.
                    nonzeros = sum([v>0 for v in m2nr.values()])
                    if len(m2nr):
                        if float(nonzeros) / len(m2nr) > self.perc_nonzero:
                            misidentified = True
                            removed.append( (self.markers2lens[marker],n_reads) )
                            break
            if not misidentified:
                rat_nreads.append( (self.markers2lens[marker],n_reads) )

        if not self.avoid_disqm and len(removed):
            n_rat_nreads = float(len(rat_nreads))
            n_removed = float(len(removed))
            n_tot = n_rat_nreads + n_removed
            # Minimum number of markers to keep: disqualified markers are added back
            # until this many are available (disabled for single-leaf clades and viruses).
            n_ripr = 10

            if len(self.get_terminals()) < 2:
                n_ripr = 0

            if "k__Viruses" in self.get_full_name():
                n_ripr = 0

            if n_rat_nreads < n_ripr and n_tot > n_rat_nreads:
                rat_nreads += removed[:n_ripr-int(n_rat_nreads)]


        rat_nreads = sorted(rat_nreads, key = lambda x: x[1])

        rat_v,nreads_v = zip(*rat_nreads) if rat_nreads else ([],[])
        # rat: total marker length (-1.0 when it is zero); nrawreads: total reads.
        rat, nrawreads, loc_ab = float(sum(rat_v)) or -1.0, sum(nreads_v), 0.0
        # Number of markers cut from each end; the slice [ql:qr] is the kept window
        # and (None, None) keeps everything when nothing is cut.
        quant = int(self.quantile*len(rat_nreads))
        ql,qr,qn = (quant,-quant,quant) if quant else (None,None,0)

        # Non-SGB mode only: a 't' level clade is set to zero abundance when fewer than
        # 70% of its markers have reads (or when it has no markers at all).
        if not SGB_ANALYSIS and self.name[0] == 't' and (len(self.father.children) > 1 or "_sp" in self.father.name or "k__Viruses" in self.get_full_name()):
            non_zeros = float(len([n for r,n in rat_nreads if n > 0]))
            nreads = float(len(rat_nreads))
            if nreads == 0.0 or non_zeros / nreads < 0.7:
                self.abundance = 0.0
                return 0.0

        # '*_g' statistics divide summed reads by summed length, '*_l' statistics
        # average the per-marker ratios; 't' trims and 'w' winsorizes by the quantile.
        # Without any marker to cut (qn == 0) the t/w variants fall back to the plain average.
        if rat < 0.0:
            pass
        elif self.stat == 'avg_g' or (not qn and self.stat in ['wavg_g','tavg_g']):
            loc_ab = nrawreads / rat if rat >= 0 else 0.0
        elif self.stat == 'avg_l' or (not qn and self.stat in ['wavg_l','tavg_l']):
            loc_ab = np.mean([float(n)/(np.absolute(r - self.avg_read_length) + 1) for r,n in rat_nreads])
        elif self.stat == 'tavg_g':
            wnreads = sorted([(float(n)/(np.absolute(r-self.avg_read_length)+1),(np.absolute(r - self.avg_read_length)+1) ,n) for r,n in rat_nreads], key=lambda x:x[0])
            den,num = zip(*[v[1:] for v in wnreads[ql:qr]])
            loc_ab = float(sum(num))/float(sum(den)) if any(den) else 0.0
        elif self.stat == 'tavg_l':
            loc_ab = np.mean(sorted([float(n)/(np.absolute(r - self.avg_read_length) + 1) for r,n in rat_nreads])[ql:qr])
        elif self.stat == 'wavg_g':
            vmin, vmax = nreads_v[ql], nreads_v[qr]
            wnreads = [vmin]*qn+list(nreads_v[ql:qr])+[vmax]*qn
            loc_ab = float(sum(wnreads)) / rat
        elif self.stat == 'wavg_l':
            wnreads = sorted([float(n)/(np.absolute(r - self.avg_read_length) + 1) for r,n in rat_nreads])
            vmin, vmax = wnreads[ql], wnreads[qr]
            wnreads = [vmin]*qn+list(wnreads[ql:qr])+[vmax]*qn
            loc_ab = np.mean(wnreads)
        elif self.stat == 'med':
            loc_ab = np.median(sorted([float(n)/(np.absolute(r - self.avg_read_length) +1) for r,n in rat_nreads])[ql:qr])

        # The children's sum wins when this clade has too little marker sequence
        # or when its own estimate is lower than that sum.
        self.abundance = loc_ab
        if rat < self.min_cu_len and self.children:
            self.abundance = sum_ab
        elif loc_ab < sum_ab:
            self.abundance = sum_ab

        # Abundance not explained by the children is tracked as "unclassified".
        if self.abundance > sum_ab and self.children: # *1.1??
            self.uncl_abundance = self.abundance - sum_ab
        # A leaf that is not at one of the last two levels of tax_units.
        self.subcl_uncl = not self.children and self.name[0] not in tax_units[-2:]

        return self.abundance

    def get_all_abundances( self ):
        """Return ``(name, tax_id, abundance)`` tuples for this clade and its whole subtree.

        Extra ``*_unclassified`` entries (with an empty tax id) are added for the
        abundance not covered by the children and for leaves flagged ``subcl_uncl``.
        ``compute_abundance`` is expected to have been run beforehand.
        """
        ret = [(self.name, self.tax_id, self.abundance)]
        if self.uncl_abundance > 0.0:
            # Reuse the level prefix (first three characters) of the first child.
            lchild = list(self.children.values())[0].name[:3]
            ret += [(lchild+self.name[3:]+"_unclassified", "", self.uncl_abundance)]
        if self.subcl_uncl and self.name[0] != tax_units[-2]:
            # Label the entry with the next level in tax_units.
            cind = tax_units.index( self.name[0] )
            ret += [(   tax_units[cind+1]+self.name[1:]+"_unclassified","",
                        self.abundance)]
        for c in self.children.values():
            ret += c.get_all_abundances()
        return ret


In [4]:
class TaxTree:
    """Taxonomic tree built from a marker database, used to turn marker hits into abundances."""

    def __init__( self, mpa, markers_to_ignore = None ): #, min_cu_len ):
        """Build the clade tree and register every marker with a zero read count.

        mpa: mapping with a 'taxonomy' entry ('|'-separated lineage -> either a
            (taxids string, length) tuple or an int length) and a 'markers' entry
            (marker name -> dict with 'len', 'clade' and 'ext').
        markers_to_ignore: container of marker names to skip; ``None`` skips nothing.
        """
        self.root = TaxClade( "root", 0)
        self.all_clades, self.markers2lens, self.markers2clades, self.taxa2clades, self.markers2exts = {}, {}, {}, {}, {}
        # Share the lookup tables with every clade through TaxClade class attributes.
        TaxClade.markers2lens = self.markers2lens
        TaxClade.markers2exts = self.markers2exts
        TaxClade.taxa2clades = self.taxa2clades
        self.avg_read_length = 1

        # The default of None used to make the `k in markers_to_ignore` test below
        # raise TypeError; treat it as "ignore nothing".
        if markers_to_ignore is None:
            markers_to_ignore = ()

        for clade, value in mpa['taxonomy'].items():
            clade = clade.strip().split("|")
            if isinstance(value,tuple):
                taxids, lenc = value
                taxids = taxids.strip().split("|")
            if isinstance(value,int):
                lenc = value
                taxids = None

            father = self.root
            for i in range(len(clade)):
                clade_lev = clade[i]
                # Taxonomy ids are only taken for the first 8 (SGB) or 7 levels.
                if SGB_ANALYSIS:
                    clade_taxid = taxids[i] if i < 8 and taxids is not None else None
                else:
                    clade_taxid = taxids[i] if i < 7 and taxids is not None else None
                if not clade_lev in father.children:
                    father.add_child(clade_lev, tax_id=clade_taxid)
                    self.all_clades[clade_lev] = father.children[clade_lev]
                # The point at which `father` descends decides what taxa2clades stores
                # for a 't' level: the 't' clade itself in SGB mode, its parent otherwise.
                if SGB_ANALYSIS: father = father.children[clade_lev]
                if clade_lev[0] == "t":
                    self.taxa2clades[clade_lev[3:]] = father
                if not SGB_ANALYSIS: father = father.children[clade_lev]
                if clade_lev[0] == "t":
                    father.glen = lenc

        def add_lens( node ):
            """Set glen of every internal node to min(mean, median) of its children's glen."""
            if not node.children:
                return node.glen
            lens = []
            for c in node.children.values():
                lens.append( add_lens( c ) )
            node.glen = min(np.mean(lens), np.median(lens))
            return node.glen

        add_lens(self.root)

        # for k,p in mpa_pkl['markers'].items():
        for k, p in mpa['markers'].items():
            if k in markers_to_ignore:
                continue
            self.markers2lens[k] = p['len']
            self.markers2clades[k] = p['clade']
            # Register the marker on its clade with zero reads.
            self.add_reads(k, 0)
            self.markers2exts[k] = p['ext']

    def set_min_cu_len( self, min_cu_len ):
        """Set the minimum summed marker length shared by all clades."""
        TaxClade.min_cu_len = min_cu_len

    def set_stat( self, stat, quantile, perc_nonzero, avg_read_length, avoid_disqm = False):
        """Set the abundance statistic and its parameters on the TaxClade class."""
        TaxClade.stat = stat
        TaxClade.perc_nonzero = perc_nonzero
        TaxClade.quantile = quantile
        TaxClade.avoid_disqm = avoid_disqm
        TaxClade.avg_read_length = avg_read_length

    def add_reads(  self, marker, n,
                    add_viruses = False,
                    ignore_eukaryotes = False,
                    ignore_bacteria = False, ignore_archaea = False,
                    ignore_ksgbs = False, ignore_usgbs = False  ):
        """Record ``n`` reads for ``marker`` on the clade the marker belongs to.

        Returns ``(full clade name, full taxonomy ids)``, or ``(None, None)`` when
        the clade is excluded by one of the ignore flags (nothing is recorded then).
        Raises KeyError for a marker that was not registered by the constructor.
        """
        clade = self.markers2clades[marker]
        cl = self.all_clades[clade]
        if ignore_bacteria or ignore_archaea or ignore_eukaryotes:
            cn = cl.get_full_name()
            if ignore_archaea and cn.startswith("k__Archaea"):
                return (None, None)
            if ignore_bacteria and cn.startswith("k__Bacteria"):
                return (None, None)
            if ignore_eukaryotes and cn.startswith("k__Eukaryota"):
                return (None, None)
        # Viruses are only filtered in non-SGB mode.
        if not SGB_ANALYSIS and not add_viruses:
            cn = cl.get_full_name()
            if not add_viruses and cn.startswith("k__Vir"):
                return (None, None)
        # In SGB mode the second-to-last lineage element is checked for '_SGB'.
        if SGB_ANALYSIS and (ignore_ksgbs or ignore_usgbs):
            cn = cl.get_full_name()
            if ignore_ksgbs and not '_SGB' in cn.split('|')[-2]:
                return (None, None)
            if ignore_usgbs and '_SGB' in cn.split('|')[-2]:
                return (None, None)
        # while len(cl.children) == 1:
            # cl = list(cl.children.values())[0]
        cl.markers2nreads[marker] = n
        return (cl.get_full_name(), cl.get_full_taxids(), )


    def markers2counts( self ):
        """Return a single ``marker -> read count`` dict merged over all clades."""
        m2c = {}
        for _ ,v in self.all_clades.items():
            for m,c in v.markers2nreads.items():
                m2c[m] = c
        return m2c

    def clade_profiles( self, tax_lev, get_all = False  ):
        """Return ``full clade name -> normalized marker counts``.

        Only clades whose name starts with ``tax_lev`` are considered (all clades
        when ``tax_lev`` is falsy). Unless ``get_all`` is set, clades without
        markers or without any positive count are left out.
        """
        cl2pr = {}
        for k,v in self.all_clades.items():
            if tax_lev and not k.startswith(tax_lev):
                continue
            prof = v.get_normalized_counts()
            if not get_all and ( len(prof) < 1 or not sum([p[1] for p in prof]) > 0.0 ):
                continue
            cl2pr[v.get_full_name()] = prof
        return cl2pr

    def relative_abundances( self, tax_lev  ):
        """Compute relative abundances and estimated read counts per clade.

        tax_lev: level prefix (e.g. 's__') to restrict the report to, or a falsy
            value to report every level with full lineage names.

        Returns ``(ret_d, ret_r, tot_reads)`` where ``ret_d`` maps
        ``(clade label, tax id)`` to the abundance divided by the total over all
        'k__' clades, ``ret_r`` maps the same keys to ``(abundance, estimated reads)``
        and ``tot_reads`` is the sum of the mapped reads of the 'k__' clades.
        With a ``tax_lev`` an ``("UNCLASSIFIED", '-1')`` entry holding the remainder
        to 1.0 is added to ``ret_d``.
        """
        # Top-level ('k__') clades are the entry points for all computations below.
        clade2abundance_n = dict([(tax_label, clade) for tax_label, clade in self.all_clades.items()
                    if tax_label.startswith("k__") and not clade.uncl])

        clade2abundance, clade2est_nreads, tot_ab, tot_reads = {}, {}, 0.0, 0

        for tax_label, clade in clade2abundance_n.items():
            tot_ab += clade.compute_abundance()

        for tax_label, clade in clade2abundance_n.items():
            for clade_label, tax_id, abundance in sorted(clade.get_all_abundances(), key=lambda pars:pars[0]):
                if SGB_ANALYSIS or clade_label[:3] != 't__':
                    if not tax_lev:
                        if clade_label not in self.all_clades:
                            # "<level>__<name>_unclassified" entry: rebuild the parent's
                            # name (previous level in tax_units) and append the
                            # unclassified label to the parent's full lineage.
                            to = tax_units.index(clade_label[0])
                            t = tax_units[to-1]
                            clade_label = t + clade_label.split("_unclassified")[0][1:]
                            tax_id = self.all_clades[clade_label].get_full_taxids()
                            clade_label = self.all_clades[clade_label].get_full_name()
                            spl = clade_label.split("|")
                            clade_label = "|".join(spl+[tax_units[to]+spl[-1][1:]+"_unclassified"])
                            glen = self.all_clades[spl[-1]].glen
                        else:
                            glen = self.all_clades[clade_label].glen
                            tax_id = self.all_clades[clade_label].get_full_taxids()
                            tax_level = 't__' if SGB_ANALYSIS else 's__'
                            # Estimated reads of a terminal-level clade: abundance * glen.
                            if tax_level in clade_label and abundance > 0:
                                self.all_clades[clade_label].nreads = int(np.floor(abundance*glen))

                            clade_label = self.all_clades[clade_label].get_full_name()
                    elif not clade_label.startswith(tax_lev):
                        # Entries of other levels are skipped (glen is not used afterwards).
                        if clade_label in self.all_clades:
                            glen = self.all_clades[clade_label].glen
                        else:
                            glen = 1.0
                        continue
                    clade2abundance[(clade_label, tax_id)] = abundance

        for tax_label, clade in clade2abundance_n.items():
            tot_reads += clade.compute_mapped_reads()

        for clade_label, clade in self.all_clades.items():
            if SGB_ANALYSIS or clade.name[:3] != 't__':
                nreads = clade.nreads
                clade_label = clade.get_full_name()
                tax_id = clade.get_full_taxids()
                clade2est_nreads[(clade_label, tax_id)] = nreads

        ret_d = dict([( tax, float(abundance) / tot_ab if tot_ab else 0.0) for tax, abundance in clade2abundance.items()])

        ret_r = dict([( tax, (abundance, clade2est_nreads[tax] )) for tax, abundance in clade2abundance.items() if tax in clade2est_nreads])

        if tax_lev:
            ret_d[("UNCLASSIFIED", '-1')] = 1.0 - sum(ret_d.values())
        return ret_d, ret_r, tot_reads


In [39]:
# REPORT_MERGED = mpa_pkl.get('merged_taxon',False)
# Build the taxonomy tree from the loaded database without ignoring any marker.
tree = TaxTree( mpa_pkl, {} )
# minimum total nucleotide length for the markers in a clade for estimating the abundance without considering sub-clade abundances
tree.set_min_cu_len( 2000 )
markers2reads, n_metagenome_reads, avg_read_length = map2bbh(
    "/ssd1/wy/workspace/MetaPhlAn/test_data/01.fasta.bowtie2out.txt",
    5, "bowtie2out", None, None, None, '1992')
tree.set_stat( 'tavg_g', 0.2, 0.33, avg_read_length, False )

# Feed the per-marker read counts into the tree, visiting markers in name order,
# and collect one "read<TAB>lineage<TAB>taxids" line per read that was accepted.
map_out = []
for marker in sorted(markers2reads):
    if marker not in tree.markers2lens:
        continue
    reads = markers2reads[marker]
    tax_seq, ids_seq = tree.add_reads(
        marker, len(reads),
        add_viruses = False,
        ignore_eukaryotes = False,
        ignore_bacteria = False,
        ignore_archaea = False,
        ignore_ksgbs = False,
        ignore_usgbs = False,
    )
    if tax_seq:
        map_out.extend("\t".join([r, tax_seq, ids_seq]) for r in sorted(reads))

# No level restriction: report every taxonomic level.
cl2ab, _, tot_nreads = tree.relative_abundances( None )



In [25]:
# Print lineage, taxonomy ids and relative abundance of every clade with a positive abundance.
for (taxstr, taxid), relab in cl2ab.items():
    if not relab > 0.0:
        continue
    print(taxstr, taxid, relab)

k__Bacteria|p__Actinobacteria|c__Actinobacteria 2|201174|1760 1.0
k__Bacteria|p__Actinobacteria|c__Actinobacteria|o__Corynebacteriales|f__Corynebacteriaceae 2|201174|1760|85007|1653 0.6086077978332424
k__Bacteria|p__Actinobacteria|c__Actinobacteria|o__Micrococcales|f__Micrococcaceae 2|201174|1760|85006|1268 0.3913922021667576
k__Bacteria|p__Actinobacteria|c__Actinobacteria|o__Corynebacteriales|f__Corynebacteriaceae|g__Corynebacterium 2|201174|1760|85007|1653|1716 0.6086077978332424
k__Bacteria|p__Actinobacteria|c__Actinobacteria|o__Micrococcales|f__Micrococcaceae|g__Rothia 2|201174|1760|85006|1268|32207 0.3913922021667576
k__Bacteria 2 1.0
k__Bacteria|p__Actinobacteria|c__Actinobacteria|o__Corynebacteriales 2|201174|1760|85007 0.6086077978332424
k__Bacteria|p__Actinobacteria|c__Actinobacteria|o__Micrococcales 2|201174|1760|85006 0.3913922021667576
k__Bacteria|p__Actinobacteria 2|201174 1.0
k__Bacteria|p__Actinobacteria|c__Actinobacteria|o__Corynebacteriales|f__Corynebacteriaceae|g__Cor

In [12]:
# Keep the clades with a positive abundance and express it as a percentage rounded to 5 decimals.
outpred = [
    (clade_name, ncbi_ids, round(fraction * 100.0, 5))
    for (clade_name, ncbi_ids), fraction in cl2ab.items()
    if fraction > 0.0
]

In [14]:
# The sort key adds 100 per missing '|' separator to the percentage, so (for
# percentages in (0, 100]) shallower lineages come first and, within the same
# depth, larger abundances come first.
for clade, taxid, relab in sorted(
        outpred,
        key=lambda x: x[2] + (100.0 * (8 - (x[0].count("|")))),
        reverse=True):
    add_repr = ''
    print("\t".join((clade, taxid, str(relab * 1.0), add_repr)) + "\n")

k__Bacteria	2	100.0	

k__Bacteria|p__Actinobacteria	2|201174	100.0	

k__Bacteria|p__Actinobacteria|c__Actinobacteria	2|201174|1760	100.0	

k__Bacteria|p__Actinobacteria|c__Actinobacteria|o__Corynebacteriales	2|201174|1760|85007	60.86078	

k__Bacteria|p__Actinobacteria|c__Actinobacteria|o__Micrococcales	2|201174|1760|85006	39.13922	

k__Bacteria|p__Actinobacteria|c__Actinobacteria|o__Corynebacteriales|f__Corynebacteriaceae	2|201174|1760|85007|1653	60.86078	

k__Bacteria|p__Actinobacteria|c__Actinobacteria|o__Micrococcales|f__Micrococcaceae	2|201174|1760|85006|1268	39.13922	

k__Bacteria|p__Actinobacteria|c__Actinobacteria|o__Corynebacteriales|f__Corynebacteriaceae|g__Corynebacterium	2|201174|1760|85007|1653|1716	60.86078	

k__Bacteria|p__Actinobacteria|c__Actinobacteria|o__Micrococcales|f__Micrococcaceae|g__Rothia	2|201174|1760|85006|1268|32207	39.13922	

k__Bacteria|p__Actinobacteria|c__Actinobacteria|o__Corynebacteriales|f__Corynebacteriaceae|g__Corynebacterium|s__Corynebacterium_matr

k__Bacteria|p__Actinobacteria|c__Actinobacteria 2|201174|1760 1.0
k__Bacteria|p__Actinobacteria|c__Actinobacteria|o__Corynebacteriales|f__Corynebacteriaceae 2|201174|1760|85007|1653 0.6086077978332424
k__Bacteria|p__Actinobacteria|c__Actinobacteria|o__Micrococcales|f__Micrococcaceae 2|201174|1760|85006|1268 0.3913922021667576
k__Bacteria|p__Actinobacteria|c__Actinobacteria|o__Corynebacteriales|f__Corynebacteriaceae|g__Corynebacterium 2|201174|1760|85007|1653|1716 0.6086077978332424
k__Bacteria|p__Actinobacteria|c__Actinobacteria|o__Micrococcales|f__Micrococcaceae|g__Rothia 2|201174|1760|85006|1268|32207 0.3913922021667576
k__Bacteria 2 1.0
k__Bacteria|p__Actinobacteria|c__Actinobacteria|o__Corynebacteriales 2|201174|1760|85007 0.6086077978332424
k__Bacteria|p__Actinobacteria|c__Actinobacteria|o__Micrococcales 2|201174|1760|85006 0.3913922021667576
k__Bacteria|p__Actinobacteria 2|201174 1.0
k__Bacteria|p__Actinobacteria|c__Actinobacteria|o__Corynebacteriales|f__Corynebacteriaceae|g__Cor

In [14]:
# NOTE(review): this opens (and truncates) a.txt for writing, but the handle is never
# closed or written to in the visible cells; a later cell reopens the same file itself.
# Confirm whether out_stream is still needed.
out_stream =  open("a.txt","w") #sys.stdout
# Output-format switches; they are not read by any of the visible cells.
MPA2_OUTPUT = False
CAMI_OUTPUT = False

In [37]:
# Write the '#'-prefixed header of the profile report to a.txt (the file is overwritten).
# The header holds the database name, the command line, the number of processed
# reads and the column names. sys.argv needs `sys` to be imported by the notebook.
with open("a.txt","w") as outf:
    header_lines = (
        '#{}\n'.format('mpa_vJan21_TOY_CHOCOPhlAnSGB_202103'),
        '#{}\n'.format(' '.join(sys.argv)),
        '#{} reads processed\n'.format(n_metagenome_reads),
        '#' + '\t'.join(('SampleID', 'Metaphlan_Analysis')) + '\n',
        '#clade_name\tNCBI_tax_id\trelative_abundance\tadditional_species\n',
    )
    for header_line in header_lines:
        outf.write(header_line)


TypeError: unsupported operand type(s) for *: 'NoneType' and 'int'

TypeError: unsupported operand type(s) for *: 'NoneType' and 'int'

In [None]:
# Scratch check: the condition compares the literal string 'tax_lev' (not a variable)
# with 'a', so it is always true and the expression always evaluates to 'a__'.
'a'+"__" if 'tax_lev' != 'a' else None 

'a__'