In [None]:
import os
import pandas as pd
from ete3 import Tree
from collections import Counter
import numpy as np
import itertools

In [1]:
def is_HR(sampleNames,sample_type_dict):
    """Classify a group of samples as host-restricted, mixed or captive-only.

    Every sample name is looked up in ``sample_type_dict``. The
    ``non_industrialized_`` and ``industrialized_`` prefixes are stripped from
    the description (so e.g. ``industrialized_human`` becomes ``human``).
    The four captive-ape sample types are treated as neutral: they are tallied
    separately and never count toward the host-restriction call.

    Parameters
    ----------
    sampleNames : list of str
        Sample identifiers; each must be a key of ``sample_type_dict``.
    sample_type_dict : dict
        Maps sample identifier -> sample type description string.

    Returns
    -------
    tuple
        ``(HR_sampleTypes, HR_sampleNum, HR_cat, HR_type, CP_pres,
        CP_sampleNum, CP_sampleTypes, captiveNames)``:

        * ``HR_sampleTypes`` - sorted list of distinct non-captive sample types
        * ``HR_sampleNum``   - number of samples with a non-captive type
        * ``HR_cat``         - ``'Unique_CP'`` (no non-captive type),
          ``'HR'`` (exactly one) or ``'MX'`` (more than one)
        * ``HR_type``        - finer label, e.g. ``'HR_<type>'``,
          ``'MX_human_single_wild_ape'``, ``'MX_<n>_hominids'``
        * ``CP_pres``        - True when at least one captive type is present
        * ``CP_sampleNum``   - number of samples with a captive type
        * ``CP_sampleTypes`` - sorted list of distinct captive sample types
        * ``captiveNames``   - sample names whose raw description contains
          the substring ``'captive'``
    """
    neutral_sampleTypes = ['captive_gorilla','captive_bonobo','captive_chimp','captive_orangutan']
    rawTypes = [sample_type_dict[name] for name in sampleNames]
    sampleTypes = [t.replace('non_industrialized_','').replace('industrialized_','') for t in rawTypes]
    captiveNames = [name for name, raw in zip(sampleNames, rawTypes) if 'captive' in raw]

    # sorted() makes the returned lists (and HR_type for single-type groups)
    # independent of set iteration order.
    HR_sampleTypes = sorted(set(sampleTypes) - set(neutral_sampleTypes))
    HR_sampleNum = sum(1 for t in sampleTypes if t not in neutral_sampleTypes)
    CP_sampleTypes = sorted(set(sampleTypes) & set(neutral_sampleTypes))
    CP_sampleNum = len(sampleTypes) - HR_sampleNum
    CP_pres = len(CP_sampleTypes) > 0

    n_types = len(HR_sampleTypes)
    if n_types == 0:
        HR_cat,HR_type = 'Unique_CP','Unique_CP'
    elif n_types == 1: #identifies host-restricted clades
        HR_cat,HR_type = 'HR','HR_'+HR_sampleTypes[0]
    else:
        HR_cat = 'MX'
        if n_types == 2:
            HR_type = 'MX_human_single_wild_ape' if 'human' in HR_sampleTypes else 'MX_2_wild_apes'
        elif n_types == 3:
            HR_type = 'MX_human_2_wild_apes' if 'human' in HR_sampleTypes else 'MX_3_wild_apes'
        else:
            # 4 keeps the historical 'MX_4_hominids' label; larger counts used
            # to leave HR_type unassigned and crash at the return statement.
            HR_type = 'MX_%d_hominids' % n_types

    return(HR_sampleTypes,HR_sampleNum,HR_cat,HR_type,CP_pres,CP_sampleNum,CP_sampleTypes,captiveNames)

def asv_hr_table(asv_table_file,metadata_file,tax_table_file):
    """Build a per-ASV table with sample membership, host restriction and taxonomy.

    Parameters
    ----------
    asv_table_file : path
        Tab-separated table read with the first column as index; the column
        labels are used as sample names and a value > 0 marks presence.
    metadata_file : path
        Tab-separated table with ``X.SampleID`` and ``Description`` columns.
    tax_table_file : path
        Tab-separated table with ``ASV``, ``Phylum``, ``Order``, ``Family``
        and ``Genus`` columns.

    Returns
    -------
    pandas.DataFrame
        One row per row of the ASV table with columns ``ASV``, ``sampleNames``,
        ``sampleNum``, the eight fields returned by ``is_HR`` and the
        taxonomy columns (left-joined on ``ASV``, so rows without a taxonomy
        entry keep NaN there).
    """
    asv_table = pd.read_csv(asv_table_file,sep='\t',index_col=0)
    # for each row, the column labels (samples) where the value is > 0
    sampleNames = asv_table.apply(lambda row: list(row.index[row>0]),axis=1)
    asv_df = sampleNames.reset_index()
    asv_df.columns = ['ASV','sampleNames']
    asv_df['sampleNum'] = asv_df['sampleNames'].apply(len)

    #add host restriction info
    metadata = pd.read_csv(metadata_file,sep='\t',index_col=None)
    sample_type_dict = dict(zip(metadata['X.SampleID'], metadata['Description']))
    hr_columns = ['HR_sampleTypes','HR_sampleNum','HR_cat','HR_type',
                  'CP_pres','CP_sampleNum','CP_sampleTypes','captiveNames']
    hr = asv_df['sampleNames'].apply(
        lambda x: pd.Series(is_HR(x,sample_type_dict), index=hr_columns))
    asv_hr_df = asv_df.merge(hr,left_index=True, right_index=True)

    #add taxonomic info
    tax_table = pd.read_csv(tax_table_file,sep='\t',index_col=None)
    # explicit copy: assigning into a column-subset slice is a chained
    # assignment and triggers SettingWithCopyWarning
    tax_table = tax_table[['ASV','Phylum','Order','Family','Genus']].copy()
    for rank in ('Family','Genus'):
        # missing labels and any label mentioning 'unclassified' collapse to
        # the single token 'unclassified'
        tax_table[rank] = tax_table[rank].fillna('unclassified').apply(
            lambda x: 'unclassified' if 'unclassified' in x else x)
    asv_full = asv_hr_df.merge(tax_table,on='ASV',how='left')

    return(asv_full)

In [2]:
def get_consensus_taxonomy(listASVs,tax_fam_dict,tax_gen_dict):
    """Return a ``'<family>_<genus>'`` consensus label for a group of ASVs.

    For each rank, the labels of all ASVs are collected and ``'unclassified'``
    entries are ignored. If exactly one distinct label remains, the clade gets
    that label (with the text before the first ``'__'`` removed, e.g.
    ``'f__Lachnospiraceae'`` -> ``'Lachnospiraceae'``); otherwise the rank is
    ``'unclassified'``. When the family is unclassified the genus is forced to
    ``'unclassified'`` as well.

    ASVs missing from a dictionary and non-string labels (e.g. NaN) are
    treated as ``'unclassified'``; a label without a ``'__'`` prefix is used
    as-is instead of raising ``IndexError``.

    Parameters
    ----------
    listASVs : list of str
        ASV identifiers belonging to the clade.
    tax_fam_dict, tax_gen_dict : dict
        Map ASV identifier -> family / genus label.

    Returns
    -------
    str
        ``family + '_' + genus``.
    """
    def _consensus(tax_dict):
        # distinct informative labels for this rank
        labels = {tax_dict.get(ASV, 'unclassified') for ASV in listASVs}
        labels = {label for label in labels if isinstance(label, str)}
        labels.discard('unclassified')
        if len(labels) != 1:
            return 'unclassified'
        parts = labels.pop().split('__')
        return parts[1] if len(parts) > 1 else parts[0]

    fam = _consensus(tax_fam_dict)
    gen = 'unclassified' if fam == 'unclassified' else _consensus(tax_gen_dict)
    return(fam+'_'+gen)

def search_clades(tree, samples_cutoff, BS_support,ASV_sampleName_dict):
    """Collect well-supported nodes of a tree together with their ASVs and samples.

    Every node yielded by ``tree.traverse()`` whose ``support`` is strictly
    greater than ``BS_support`` is numbered (``clade_1``, ``clade_2``, ... in
    traversal order). For such a node the leaves whose name contains ``'ASV'``
    are gathered, and the samples of those ASVs are pooled via
    ``ASV_sampleName_dict``. The node is kept only when the number of distinct
    samples is strictly greater than ``samples_cutoff``; numbering still
    advances for nodes that are not kept. No host-restriction filtering
    happens here.

    Returns a DataFrame with columns ``cladeName``, ``ASVs``, ``sampleNames``,
    ``sampleNum`` and ``ASVsNum``.
    """
    min_support = float(BS_support)
    kept_rows = []
    node_number = 0
    for node in tree.traverse():
        if not node.support > min_support:
            # bootstrap support at or below the threshold
            continue
        node_number += 1
        asv_names = [leaf.name for leaf in node.iter_leaves() if 'ASV' in leaf.name]
        per_asv_samples = [ASV_sampleName_dict[asv] for asv in asv_names]
        samples = list(set(itertools.chain.from_iterable(per_asv_samples)))
        if len(samples) > samples_cutoff:
            kept_rows.append(['clade_' + str(node_number), asv_names, samples])

    result = pd.DataFrame(kept_rows, columns=['cladeName', 'ASVs', 'sampleNames'])
    result['sampleNum'] = result['sampleNames'].apply(len)
    result['ASVsNum'] = result['ASVs'].apply(len)
    return(result)

def eliminate_redundant_clades(clades_df,offlimits_ASVs):
    """Greedily keep the largest clades that share no ASV with one another.

    Clades are visited in order of decreasing ``ASVsNum``. A clade is kept
    when none of its ASVs is in ``offlimits_ASVs`` or in a clade already
    kept; its ASVs then become off limits for all later (smaller) clades.

    Parameters
    ----------
    clades_df : pandas.DataFrame
        Needs the columns ``cladeName``, ``ASVs`` (list of ASV names) and
        ``ASVsNum``.
    offlimits_ASVs : iterable of str
        ASVs that must not appear in any returned clade. Not modified.

    Returns
    -------
    pandas.DataFrame
        The rows of ``clades_df`` (sorted by decreasing ``ASVsNum``) whose
        ``cladeName`` was kept.
    """
    df = clades_df.sort_values('ASVsNum',ascending=False) #start with the largest clades first
    # one running set instead of rebuilding set(list) and re-concatenating the
    # list for every row, which was quadratic in the number of ASVs
    claimed = set(offlimits_ASVs)
    NRclades = []
    for cladeName, ASVs in zip(df['cladeName'], df['ASVs']):
        if claimed.isdisjoint(ASVs):
            claimed.update(ASVs)
            NRclades.append(cladeName)
    res = df[df['cladeName'].isin(NRclades)]
    return(res)

def expand_clades(clades_df):
    """Expand a clade table to one row per (clade, ASV) pair.

    The list stored in the ``ASVs`` column is exploded so that each ASV gets
    its own row; all other columns are repeated and the original index label
    is kept on every expanded row. ``ASVs`` is placed as the last column.
    A clade whose ASV list is empty yields a single row with NaN in ``ASVs``.

    ``DataFrame.explode`` is used instead of ``apply(pd.Series)`` + ``stack``:
    the latter builds a NaN-padded wide frame and relies on ``stack`` dropping
    the padding, which the reimplemented ``stack`` in newer pandas no longer
    does.
    """
    other_columns = [col for col in clades_df.columns if col != 'ASVs']
    clades_ASVs_df = clades_df.explode('ASVs')[other_columns + ['ASVs']]
    return(clades_ASVs_df)
                                 
def host_restricted_clades(asv_table_file,metadata_file,tax_table_file,tree_file):
    """Identify non-overlapping host-restricted, mixed and captive-only clades.

    Steps:
      1. Read the ASV table and record, per ASV, the samples with a value > 0.
      2. Read the metadata (``X.SampleID`` -> ``Description``) and the
         taxonomy table (``ASV`` -> ``Family`` / ``Genus``).
      3. Load the tree and collect candidate clades with ``search_clades``
         (support > 0.5, more than 5 samples), then annotate them with
         ``is_HR``.
      4. Select non-redundant clades in three passes: ``HR`` clades first,
         then ``MX`` clades sharing no ASV with a selected HR clade, then
         ``Unique_CP`` clades sharing no ASV with a selected HR or MX clade.
      5. Add a consensus taxonomy label and, per sample description, the
         number of the clade's samples with that description divided by the
         total number of samples with that description in the metadata.

    The per-description sample counts of the metadata are printed.

    Parameters
    ----------
    asv_table_file, metadata_file, tax_table_file : path
        Tab-separated input tables (see steps 1-2 for the columns used).
    tree_file : path or str
        Tree passed to ``ete3.Tree`` with ``format=0``.

    Returns
    -------
    (clades, clades_ASVs) : tuple of pandas.DataFrame
        ``clades`` has one row per selected clade; ``clades_ASVs`` has one
        row per (clade, ASV) pair as produced by ``expand_clades``.
    """
    asv_table = pd.read_csv(asv_table_file,sep='\t',index_col=0)
    sampleNames = asv_table.apply(lambda row: list(row.index[row>0]),axis=1)
    ASV_sampleName_dict = dict(zip(sampleNames.index,sampleNames))

    #sample to sample type category
    metadata = pd.read_csv(metadata_file,sep='\t',index_col=None)
    sample_type_dict = dict(zip(metadata['X.SampleID'], metadata['Description']))

    #taxonomic info, family and genus
    tax_table = pd.read_csv(tax_table_file,sep='\t',index_col=None)
    for rank in ('Family','Genus'):
        # fill missing labels first: "'unclassified' in x" raises TypeError
        # on NaN (same normalisation as in asv_hr_table)
        tax_table[rank] = tax_table[rank].fillna('unclassified').apply(
            lambda x: 'unclassified' if 'unclassified' in x else x)
    tax_fam_dict = dict(zip(tax_table['ASV'], tax_table['Family']))
    tax_gen_dict = dict(zip(tax_table['ASV'], tax_table['Genus']))

    #search tree for clades
    fulltree = Tree(tree_file, format=0)
    clades_prelim = search_clades(fulltree, 5, .5, ASV_sampleName_dict)
    hr =clades_prelim['sampleNames'].apply(lambda x:  pd.Series(is_HR(x,sample_type_dict),
                index=['HR_sampleTypes','HR_sampleNum','HR_cat','HR_type',
                'CP_pres','CP_sampleNum','CP_sampleTypes','captiveNames']))
    clades_prelim_hr = clades_prelim.merge(hr,right_index=True,left_index=True)

    #identify HR clades
    HR_clades = clades_prelim_hr[clades_prelim_hr['HR_cat']=='HR']
    # threshold applies to non-captive samples only (HR_sampleNum)
    HR_clades = HR_clades[HR_clades['HR_sampleNum']>=5]
    HR_clades = eliminate_redundant_clades(HR_clades,[])
    HR_clades_ASVs = expand_clades(HR_clades)

    #identify MX clades that don't contain any HR clades
    MX_clades = clades_prelim_hr[clades_prelim_hr['HR_cat']=='MX']
    MX_clades = MX_clades[MX_clades['sampleNum']>=5]
    MX_clades = eliminate_redundant_clades(MX_clades,offlimits_ASVs=list(HR_clades_ASVs['ASVs']))
    MX_clades_ASVs = expand_clades(MX_clades)

    #identify CP clades that don't contain any HR or MX clades
    CP_clades = clades_prelim_hr[clades_prelim_hr['HR_cat']=='Unique_CP']
    CP_clades = CP_clades[CP_clades['sampleNum']>=5]
    HR_MX_ASVs = list(HR_clades_ASVs['ASVs']) + list(MX_clades_ASVs['ASVs'])
    CP_clades = eliminate_redundant_clades(CP_clades,offlimits_ASVs=HR_MX_ASVs)
    CP_clades_ASVs = expand_clades(CP_clades)

    #merge dataframes
    clades = pd.concat([HR_clades,MX_clades,CP_clades])
    clades.reset_index(drop=True, inplace=True)
    clades_ASVs = pd.concat([HR_clades_ASVs,MX_clades_ASVs,CP_clades_ASVs])
    clades_ASVs.reset_index(drop=True, inplace=True)


    clades['cladeTax']=clades['ASVs'].apply(lambda x:
        get_consensus_taxonomy(x,tax_fam_dict,tax_gen_dict))

    sample_type_counts =  metadata['Description'].value_counts()
    print(sample_type_counts)
    # per clade: how many of its samples carry each description
    description_df = clades['sampleNames'].apply(lambda l: pd.Series(
    [sample_type_dict[name] for name in l]).value_counts())
    description_df = description_df.fillna(0)
    # DataFrame / Series aligns the Series index with the frame's columns, so
    # each description column is divided by its total in the metadata
    sample_type_percent = description_df/sample_type_counts
    clades=clades.merge(sample_type_percent,left_index=True,right_index=True)

    return(clades,clades_ASVs)