In [1]:
from mgkit.io import gff
from mgkit import kegg
import mgkit
import mgkit.plots
from collections import Counter
from glob import glob
import seaborn as sns
import pandas as pd
from mgkit.utils import dictionary
import itertools
import networkx as nx
from mgkit import graphs
import os, sys
from collections import Counter
from collections import defaultdict
import scipy as sp
import numpy as np
import scipy.stats
import datetime
import timeit
import re
import platform
import getpass
import argparse

mgkit.logger.config_log()

#On Server
#input_dir=""
#output_dir=""
user=getpass.getuser()
if "windows" in platform.platform().lower():
    windows=True
else:
    windows=False
if windows:
    core=os.path.join("C:\Users",user)
    g_drive="Google Drive\Honours"
else:
    core=os.path.join("/home",user)
    g_drive="grive/Honours"
print core

#Made some OS agnostic changes
#gff_dir=os.path.join(*[core,"Documents","Hons","Seaquence","francesco_data","gff_bins-2016-06-14"])
gff_dir=os.path.join(*[core,"Documents","Hons","Seaquence","francesco_data","gff_bins-2016-06-14b"])
#gff_dir=core+"/Documents/Hons/Seaquence/francesco_data/gff_bins-2016-06-14"

tax_file=os.path.join(*[core,g_drive,"metabolic_analysis","ID_TAX_BINS_TEMP.txt"])
#tax_file=core+"/grive/metabolic_analysis/ID_TAX_BINS_TEMP.txt"

output_dir=os.path.join(*[core,g_drive,"metabolic_analysis"])
#output_dir=core+"/grive/metabolic_analysis/"

coral_kegg=os.path.join(*[output_dir,"KO_hits","plut.pathways.txt"])
#coral_kegg=output_dir+"KO_hits/plut.pathways.txt"

cmn_cpds=os.path.join(*[output_dir,"Automated_Network_Analyses","boring_cps.txt"])
#cmn_cpds=output_dir+"Automated_Network_Analyses/boring_cps.txt"


symbiodinium_kegg=os.path.join(*[output_dir,"KO_hits","SymbC15_firstpass_ko_mapping_protID_distinct_KO.txt"])
#symbiodinium_kegg=output_dir+"KO_hits/SymbC15_firstpass_ko_mapping_protID_distinct_KO.txt"

#microbial_kegg=os.path.join(*[output_dir,"KO_hits","Microbial_KO_mapping_protID.txt"])
microbial_kegg=os.path.join(*[output_dir,"KO_hits","TREMBL_SWISSPROT_Microbial_KO_mapping_protID.txt"])
#microbial_kegg=output_dir+"KO_hits/Microbial_KO_mapping_protID.txt"
all_kegg=os.path.join(*[output_dir,"KO_hits","all_kos.txt"])

hmm_dir=os.path.join(*[core,g_drive,"HMM_searches","Symbioses_test","euk_repeat_results"])

database_dir=os.path.join(output_dir,"Databases")

abundance_file=os.path.join(*[output_dir,"Misc_files","id_trimmed_relative_enriched_bin_abundance.tsv"])

raw_coverage_contig=os.path.join(*[output_dir,"Misc_files","bin_contig_sep_coverages.tsv"])

completeness_contamination=os.path.join(*[output_dir,"Misc_files","Completeness_Contamination_data.txt"])

gene_dir=os.path.join(*[core,g_drive,"eukaryote_like_repeats","gene_hits"])#"/home/baker/Documents/MountedDrive/seaquence/data/eukaryote_like_repeats/gene_hits"

plots_dir=os.path.join(output_dir,'Plots')

def load_cmn_cpds(cmn_cpds):
    cpds=set([])
    with open(cmn_cpds,'r') as cpd_list:
        for line in cpd_list:
            cpd=line.strip().split("\t")[0]
            cpds.add(cpd)
    return cpds
            
common_cpds=load_cmn_cpds(cmn_cpds)

# General Data loading

def store_local_kegg_item_keys(kegg_items,database_dir):
    kc=kegg.KeggClientRest()
    all_item_names={}
    illegal_pairs=[("compound","orthology"),("orthology","compound")]
    for kegg_item in kegg_items:
        key_name=os.path.join(database_dir,"{0}_readable_names.tsv".format(kegg_item))
        if not os.path.isfile(key_name):
            item_names=kc.get_ids_names(kegg_item)
            save_readable_key(key_name,item_names,kegg_item)
            all_item_names[kegg_item]=item_names.keys()
        else:
            all_item_names[kegg_item]=load_readable_names(database_dir,[kegg_item],False)[kegg_item].keys()
    for kegg_item_1, kegg_item_2 in itertools.permutations(all_item_names.iterkeys(),2):
        print "considering the pair: {0}, {1}".format(kegg_item_1,kegg_item_2)
        if (kegg_item_1,kegg_item_2) not in illegal_pairs:
            shared_key_name=os.path.join(database_dir,"{0}_linked_{1}_database.tsv").format(kegg_item_1, kegg_item_2)
            if not os.path.isfile(shared_key_name):
                print "The processing of pair: {0},{1} has begun.".format(kegg_item_1,kegg_item_2)
                kc=kegg.KeggClientRest()
                linked_ids=kc.link_ids(kegg_item_2,all_item_names[kegg_item_1])
                save_key_pairings(shared_key_name,linked_ids,(kegg_item_1,kegg_item_2))
        else:
            pass
    return

def load_local_kegg_database_pairings(database_dir,kegg_item_pairs, process_all):
    '''Loads the local databases of kegg_item_1, kegg_item_2 pairings and return a dictionary of
    these pairings in the form kegg_item_1:kegg_items_2 (There can be more than one linked item). This
    loading is based on the earlier use of mgkits kc.link_ids to store all of the pairings needed.
    
    Input:
        database_dir   - The directory with the databases
        kegg_item_pairs- A list of kegg item pairs to load
        process_all    - A boolean decision as whether to load all existing pairs.
        
    Output: A dictionary linking either all existing kegg item pairs or just those specified. It has the form
    dict[item_1,item_2]={kegg_item_1:kegg_2_items}'''
    linking_dictionary={}
    if process_all:
        for file_name in glob(os.path.join(database_dir,'*database.tsv')):
            db_file=os.basename(file_name)
            kegg_1=db_file.split("_linked_")[0]
            kegg_2=db_file.split("_linked_")[1].split("_database")[0]
            linking_dictionary[(kegg_1,kegg_2)]={}             
            with open(file_name) as kegg_links:
                next(kegg_links)#Skip the header
                for line in kegg_links:
                    item_1,item_2=line.strip().split("\t")
                    item_2=item_2.split(";")
                    linking_dictionary[kegg_item_pair][item_1]=item_2
        return linking_dictionary
    
    for kegg_item_pair in kegg_item_pairs:
        file_name=os.path.join(database_dir,"{0}_linked_{1}_database.tsv").format(kegg_item_pair[0], kegg_item_pair[1])
        if os.path.isfile(file_name):
            linking_dictionary[kegg_item_pair]={}
            with open(file_name) as kegg_links:
                next(kegg_links) #skip the header
                for line in kegg_links:
                    item_1,item_2=line.strip().split("\t")
                    item_2=item_2.split(";")
                    linking_dictionary[kegg_item_pair][item_1]=item_2
    return linking_dictionary
                    
def load_readable_names(database_dir,kegg_items,process_all):
    '''Loads in the readable names for a specified kegg item from a list of databases.
    Input:
        database_dir        -  The directory with the databases.
        kegg_items          -  The kegg items to get the readable mapping for.
        process_all         -  Boolean - Should the function retrieve all available databases.
    Output:
        readable_item_dict  -  A dictionary of KEGG_ID: Readable name pairs'''
    readable_item_dict={}
    if process_all:
        for file_name in glob(os.path.join(database_dir,'*_readable_names.tsv')):
            desc_file=os.basename(file_name)
            kegg_item=desv_file.split("_readable_names.tsv")[0]
            readable_item_dict[kegg_item]={}
            with open(file_name) as kegg_descriptions:
                next(kegg_descriptions)
                for line in kegg_descriptions:
                    item_1,item_2=line.strip().split("\t")
                    readable_item_dict[kegg_item][item_1]=item_2
        return readable_item_dict
    
    for kegg_item in kegg_items:
        file_name=os.path.join(database_dir,'{0}_readable_names.tsv'.format(kegg_item))
        readable_item_dict[kegg_item]={}
        with open(file_name) as kegg_descriptions:
            next(kegg_descriptions)
            for line in kegg_descriptions:
                item_1,item_2=line.strip().split("\t")
                readable_item_dict[kegg_item][item_1]=item_2
    return readable_item_dict
            
            
def save_readable_key(key_name,item_names,kegg_item):
    df=pd.DataFrame([
    [col1,col2] for col1,col2 in item_names.iteritems()
                   ])
    df.columns=[kegg_item,"Description"]
    df.to_csv(key_name,sep="\t",index=None)
    return None

def save_key_pairings(shared_key_name,item_links,kegg_item_tuple):
    df=pd.DataFrame([
            [col1,";".join(col2)] for col1, col2 in item_links.iteritems()
        ])
    df.columns=[kegg_item_tuple[0],kegg_item_tuple[1]]
    df.to_csv(shared_key_name,sep="\t",index=None)
    return None

def remove_ko_pth_hits(file_path):
    
    with open(file_path,'r') as KO_PTH_pairs:
        out_dir=os.path.dirname(file_path)
        temp_file=open(os.path.join(out_dir,"temp.tsv"),'w')
        for line in KO_PTH_pairs:
            KO,pathways=line.strip().split("\t")
            pathways=pathways.split(";")
            pathways=[pathway for pathway in pathways if not pathway.startswith("ko")]
            pathways=";".join(pathways)
            new_line="{0}\t{1}\n".format(KO,pathways)
            temp_file.write(new_line)
    temp_file.close()
    
def make_local_rcn_eqn_database(database_dir):
    kc=kegg.KeggClientRest()
    all_reactions=load_readable_names(database_dir,["reaction"],False)["reaction"].keys()
    rcn_eqns=kc.get_reaction_equations(all_reactions,max_len=10)
    file_name=os.path.join(database_dir, "reaction_equation_links.tsv")
    df=rcn_eqn_pd_df(rcn_eqns)
    df.to_csv(file_name,sep="\t",index=False)
        
    return

def rcn_eqn_pd_df(rcn_eqn_dict):
    df=pd.DataFrame([
            [rcn, ";".join(in_cpds),";".join(out_cpds)] for rcn, cpds in rcn_eqn_dict.iteritems() for in_cpds,out_cpds in [cpds.values()]
    if in_cpds!=[] or out_cpds!=[]    
        ])
    #df.replace('','NA')
    df.columns=["Kegg_rcn_ID","side_1_cpds","side_2_cpds"]
    return df

def load_local_rcn_eqn_database(database_dir):
    file_name=os.path.join(database_dir,"reaction_equation_links.tsv")
    rcn_eqn_dict={}
    n_df=pd.read_csv(file_name,sep="\t")
    n_df.fillna('',inplace=True)
    return n_df.set_index("Kegg_rcn_ID").T.to_dict(orient='dict')

def load_local_rcn_eqn_database_set(database_dir):
    rcn_eqn_pairs=load_local_rcn_eqn_database(database_dir)
    return {rcn:{side:set(cpds.split(";")) for side,cpds in pairs.iteritems()} for rcn, pairs in rcn_eqn_pairs.iteritems()}
            
#Load in the coral data
def load_bin_names(tax_file):
    #Load bin_ids and bins_taxonomy from file.
    bin_names={}
    bin_pair=[]
    with open(tax_file,'r') as bin_tax_pair:
        bin_tax_pair.readline()
        for line in bin_tax_pair:
            bin_pair.append(tuple(line.strip().split("\t")))

    bin_names={bin_id:taxonomy for taxonomy, bin_id in bin_pair}
    return bin_names

def make_local_complete_module_info_db(database_dir):
    '''Creates a local database of the module definitions.'''
    all_modules=load_readable_names(database_dir,["module"],False)["module"].keys()
    kc=kegg.KeggClientRest()
    entries={}
    max_len=10
    post_processed_defs={}
    print "There are a total of {0} modules to parse".format(len(all_modules))
    N_modules=len(all_modules)
    for i in xrange(0,N_modules,max_len):
        if N_modules-i<max_len:
            n_entries=N_modules-i
        else:
            n_entries=max_len      
        query="+".join(all_modules[i:i+n_entries])
        kegg_entries=kc.get_entry(query)
        hits=re.findall("\nDEFINITION(.*)\n",kegg_entries)
        print i, i+max_len-1,"n_hits:{0}".format(len(hits))
        #if len(hits)!=max_len:
        #    print m
        for module, definition in itertools.izip(all_modules[i:i+10],hits):
            new_def=definition.strip().replace(" --"," ").replace("-- "," ").replace("  "," ").strip()
            entries[module]=new_def
            if "M" in definition:
                print module,definition
                post_processed_defs[module]=new_def
        #These post_processed modules should be modules defined in terms of other modules.
    for module, definition in post_processed_defs.iteritems():
        new_def=definition
        print "This is the definition being considered.", new_def
        new_defs=re.split("[, +-]",definition)
        for item in new_defs:
            simp_item=item.strip(")").strip("(")
            if simp_item.startswith("M"):
                print "This is the current item",item
                new_def=new_def.replace(simp_item,"("+entries[simp_item]+")")
        print new_def
        entries[module]=new_def
        
    temp_entries=entries
    protein_complexes='(.)?([K][0-9]+[+]){1,}[K][0-9]+(.)?'
    for module, definition in temp_entries.iteritems():
        for match in re.finditer(protein_complexes, definition):
            match_str=match.group()
            if match_str.startswith("(") and match_str.endswith(")"):
                pass
            elif match_str[-1].isdigit() and match_str[0]=="K":
                match_str=match_str[:] #Trim random end characters
                new_match="("+match_str+")"
                entries[module]=entries[module].replace(match_str,new_match)
            elif match_str[0]=="K":
                match_str=match_str[0:-1] #Trim random end characters
                new_match="("+match_str+")"
                entries[module]=entries[module].replace(match_str,new_match)
            elif match_str[-1].isdigit():
                match_str=match_str[1:] #Trim random end characters
                new_match="("+match_str+")"
                entries[module]=entries[module].replace(match_str,new_match)
            else:
                match_str=match_str[1:-1] #Trim random end characters
                new_match="("+match_str+")"
                entries[module]=entries[module].replace(match_str,new_match)
              
        if i%100==0:
            kc=kegg.KeggClientRest()
            print module
    print "{0} modules were parsed".format(len(entries))
    df=pd.DataFrame([
            [module,entry] for module,entry in entries.iteritems()
        ])
    df.columns=["Kegg_id","Kegg_definition"]
    df.to_csv(os.path.join(database_dir,"Module_definitions_pairs_db.tsv"),sep="\t",index=None)
    
    return

def load_local_complete_module_info_db(database_dir):
    '''Loads a local database of kegg definitions'''
    def_dict={}
    with open(os.path.join(database_dir,"Module_definitions_pairs_db.tsv")) as definitions:
        next(definitions)
        for line in definitions:
            module,kegg_def=line.strip().split("\t")
            def_dict[module]=kegg_def
    
    return def_dict

def make_new_trusted_database(database_dir):
    '''
    Definition:
        This function will take an entire kegg module definition file and will create
        a new local database with the expressions written so that they can simple be
        evaluated when loading the files.
    Input: 
        database_dir: str
            A directory containing the database of kegg module definitions.
    Output:
        None
    Calls:
        replacement: Turns kegg definitions in logical nested tuples of sets.
    '''
    old_module_def=load_local_complete_module_info_db(database_dir)
    #print "This is the old module information",old_module_def
    new_pd_df=[""]*len(old_module_def)
    i=0
    for module,definition in old_module_def.iteritems():
        try:
            logical_evaluation=replacement(definition,False)[1]
            new_pd_df[i]=[module,logical_evaluation]
            i+=1
                
        except TypeError:
            print "TypeError 2:",module, definition
        except SyntaxError:
            print "SyntaxError 2:",module, definition
        except NameError:
            print "NameError 2:",module ,definition
    #print new_pd_df      
    new_pd_df=pd.DataFrame(new_pd_df)
    #print "The second checkpoint."
    new_pd_df.columns=["ModuleID","KEGG_log_expr"]
    
    new_pd_df.to_csv(os.path.join(database_dir, "module_kegg_log_expr.tsv"),header=True,sep="\t",index=False)
    return None

def fix_module_orthology_pairs(database_dir):
    '''Replaces the occurences of modules in the module-orthology links to their corresponding KOs'''
    all_pairs=load_local_kegg_database_pairings(database_dir,[["module","orthology"]], False)["module","orthology"]
    new_pairings=[]
    for module,kos in all_pairs.iteritems():
        new_items=[]
        rand_module=False
        for item in kos:
            if item.lower().startswith("m"):
                new_items.extend(all_pairs[item])
                rand_module=True
            else:
                new_items.append(item)
        if rand_module:
            new_pairings.append((module, list(set(new_items))))
    for (module, new_items) in new_pairings:
        all_pairs[module]=new_items
        
    df=pd.DataFrame([
            [col1,";".join(col2)] for col1, col2 in all_pairs.iteritems()
        ])
    df.columns=[kegg_item_tuple[0],kegg_item_tuple[1]]
    fixed_name=os.path.join(database_dir,"")
    df.to_csv(fixed_name,sep="\t",index=None)
    return
        
def load_local_cleaned_definition_db(extra_def_file):
    cleaned_db={}
    with open(os.path.join(extra_def_file)) as paired_exprs:
        next(paired_exprs) #Skip header
        for line in paired_exprs:
            module,expr=line.split("\t")
            kegg_log=eval(expr)
            if not isinstance(kegg_log,tuple):
                kegg_log=tuple([kegg_log])
            cleaned_db[module]=kegg_log
    return cleaned_db

# For remaking graphs

pathways = {
    'carbon': ['map01200'],
    'nitrogen-sulfur-fatty_acid-photosynthesis': ['map00910', 'map00920', 'map01212', 'map00195'],
    'oxidative_phosphorylation': ['map00190'],
    'two-component': ['map02020'],
    'amino-acids':['map01230'],
    #'thiamine-metabolism':'map00730',
    #'riboflavin-metabolism':'map00740',
    #'Vitamin-B6-metabolism':'map00750',
    #'Nicotinate&Nicotinamide-metabolism':'map00760',
    #'PantoThenate and CoA Biosynthesis':'map00770',
    #'Biotin-Metabolism':'map00780',
    #'Lipoic-Acid-Metabolism':'map00785',
    #'Folate-Biosynthesis':'map00790',
    #'OneCarbonPoolByFolate':'map00670',
    #'retinol-metabolism-animals':'map00830',
    #'porphyrin&ChlorophyllMetabolism':'map00860',
    #'Ubiquinone&OtherTerpenoid-QuinoneBiosynthesis':'map00130',
    'vitamins&cofactors':['map00730','map00740','map00750','map00760','map00770','map00780','map00785','map00790','map00670','map00830','map00860','map00130'],
    #"Alanine, aspartate and glutamate metabolism":'map00250',
    #"Cysteine and methionine metabolism":'map00270',
    #"Glycine, serine and threonine metabolism":'map00260',
    #"Valine, leucine and isoleucine degradation":'map00280',
    #"Valine, leucine and isoleucine biosynthesis":'map00290',
    #"Lysine biosynthesis":'map00300',
    #"Lysine degradation":'map00310',
    #"Arginine biosynthesis":'map00220',
    #"Arginine and proline metabolism":'map00330',
    #"Histidine metabolism":'map00340',
    #"Tyrosine metabolism":'map00350',
    #"Phenylalanine metabolism":'map00360',
    #"Tryptophan metabolism":'map00380',
    #"Phenylalanine, tyrosine and tryptophan biosynthesis":'map00400',
    "AminoAcidMetabolism":['map00250','map00270','map00260','map00280','map00290','map00300','map00310','map00220',\
                           'map00330','map00340','map00350','map00360','map00380','map00400']
    #"beta-Alanine metabolism":'map00410',
    #"Taurine and hypotaurine metabolism":'map00430',
    #"Phosphonate and phosphinate metabolism":'map00440',
    #"Selenocompound metabolism":'map00450',
    #"Cyanoamino acid metabolism":'map00460',
    #"D-Glutamine and D-glutamate metabolism":'map00471',
    #"D-Arginine and D-ornithine metabolism":'map00472',
    #"D-Alanine metabolism":'map00473',
    #"Glutathione metabolism":'map00480',
   # "Metabolisms of other amino acids":['map00410','map00430','map00440','map00450','map00460','map00471','map00472','map00473','map00480']
    ,"Glycosaminoglycan degradation & Synthesis":["map00531","map00532","map00534"] ,
    "Bacterial Secretion Systems":["ko03070"],
    "phosphotransferase system (PTS)":["ko02060"],
    "ABC transporters":["ko02010"],
    "N-Glycan biosynthesis": ["map00510"],
    "CationicAntiomicrobialPeptide_CAMP_resistance":["map01503"],
    "Vancomycin_Beta-lactamResistance":["map01502","map01501"]

}

pathways ={key:[item.replace("ko","map") for item in items] for key,items in pathways.iteritems()}

bin_names=load_bin_names(tax_file)

def pathway_to_modules(pathway_dict,database_dir):
    links=load_local_kegg_database_pairings(database_dir,[("pathway","module")], False)["pathway","module"]
    pathways={path:list(set(itertools.chain(*[links[egx] for egx in pathway if egx in links]))) for path, pathway in pathway_dict.iteritems()}
    return pathways

MO_pathways=pathway_to_modules(pathways,database_dir)

####################################
# Abundance parsing
####################################
def load_relative_abundance(file_name):
    abundance_data=pd.DataFrame.from_csv(file_name,sep="\t")
    new_index=np.array(pd.Series(abundance_data.index.values).str.strip("_genomic"))
    abundance_data.set_index(new_index,inplace=True)
    return abundance_data


def rel_abundance_to_dict(df):
    abundance_dict={}
    column_names=df.columns.values
    for genome, rel_abund in df.iterrows():
        abundance_dict[genome]={}
        for i,column_name in enumerate(column_names):
            abundance_dict[genome][column_name]=round(rel_abund[i],9)
    return abundance_dict

def get_abundance(file_name):
    return rel_abundance_to_dict(load_relative_abundance(file_name))

def only_key_abundance(abundance_dict,unique_key):
    reduced_dict={}
    for genome, rel_abund_dict in abundance_dict.iteritems():
        for sample, abundance in rel_abund_dict.iteritems():
            if unique_key in sample:
                reduced_dict[genome]=abundance
    return reduced_dict

def reduced_abundance(file_name,unique_key):
    
    return only_key_abundance(get_abundance(file_name),unique_key)
    

def load_coverages(file_name):
    '''Loads in the coverage file with a separate coverage for each contig. '''
    with open(file_name) as coverage_file:
        header=coverage.readline()
    header=tuple(header.strip().split("\t"))
    coverage_file=np.genfromtxt(fule_name,delimiter="\t",names=True)
    return

def normalised_coverages(read_counts=True):
    return

def load_completeness_contamination(file_name):
    genome_data=pd.DataFrame.from_csv(file_name,sep="\t")
    genome_data_dict={}
    data_names=["completeness","contamination"]
    for genome, row in genome_data.iterrows():
        genome_data_dict[genome]={}
        for item_id,item in itertools.izip(data_names,row):
            genome_data_dict[item_id]=item
    return genome_data

#def 

def estimate_binning_completeness(coverage_file,completeness_file, total_counts_file):
    
    return

#################################
# KEGG Completeness
#################################

def replacement(definition,return_string=False):
    '''
    Description:
        Turns an irregular definition string into a set of KOs in nested tuples to indicate their relationship and to
        prepare them for processing.
    Input:
        definition: string
            Module definition as defined in KEGG    
    output:
        definition: Same as above
        new_nesting: tuple of tuples of sets
            The new logical form of the definition to use in evaluating completeness.
        
    Notes: 
        This function uses eval which a security risk. Caution should be taken in using this function
    '''
    
    logical_chars="[+ ,-]"
    pattern="K[0-9]{5}"
    new_expression=definition
    logical_groups="([K][0-9]+,){1,}[K][0-9]+" #Find any group of KOs (1 or more) separated by commas
    #non_extended_groups='(K[0-9]{5}[^,K0-9\n\]]*){1,}(K[0-9]{5}[^,])' #Get any none comma separated chunk of KOs
    non_extended_groups='([^,]|^)(K[0-9]{5}[^,K0-9\n\]\)\[]*){1,}([\n \+\"\]\[\(\)]|$)'
    end_KOs='(K[0-9]{5})$'
    repeated_set='(set\(\[){2,}(["][K][0-9]+["][^^ )(+.\"-]?).*?(\]\)){1,2}'
    set_in_set='set\(\[[K0-9",]*(set\(\[)"[K][0-9]+"\]\)'
    rear_match="([^0-9][,-]K[0-9]{5})"
    forward_match='(K[0-9]{5})[,-][\(]'
    new_expression=new_expression.replace(" -","-")
    new_expression=new_expression.replace(", ",",")
    ko_set_matches=[]
        
    for match in re.finditer(logical_groups,new_expression):
        #print match.group()
        ko_set_matches.append(match.group())
        #new_expression=new_expression.replace(match.group(),"set(["+match.group()+"])",1)
    set_matches=['']*len(ko_set_matches)
    cleaned_matches=[match.strip(",") for match in ko_set_matches]
    ko_set_matches=set(cleaned_matches)
    
    
    for i,match in enumerate(ko_set_matches):
        set_matches[i]=match
        new_expression=new_expression.replace(match,"set(["+match+"])")
        
    step_0_1=new_expression
    for match in re.finditer(non_extended_groups,new_expression):
        if match:
            new_match=match.group()
            #print type(new_match), new_match
            for sub_match in re.findall(pattern,new_match):
                new_match=new_match.replace(sub_match,"set(["+sub_match+"])")
            new_expression=new_expression.replace(match.group(),new_match)
            
    for match in re.finditer(rear_match,new_expression):
        if match:
            new_match=match.group()
            #print type(new_match), new_match
            for sub_match in re.findall(pattern,new_match):
                new_match=new_match.replace(sub_match,"set(["+sub_match+"])")
            new_expression=new_expression.replace(match.group(),new_match)
            
    for match in re.finditer(forward_match,new_expression):
        if match:
            new_match=match.group()
            #print type(new_match), new_match
            for sub_match in re.findall(pattern,new_match):
                new_match=new_match.replace(sub_match,"set(["+sub_match+"])")
            #print new_match
            new_expression=new_expression.replace(match.group(),new_match)
                
    step_0_2=new_expression
    for match in set(re.findall(pattern,new_expression)):
        new_expression=new_expression.replace(match,"\""+match+"\"")
    step_1=new_expression
    #print new_expression
    
    new_expression=new_expression.replace(",",",\",\",")
    #non_set_comma='.{6}[^\"],[^\"].{6}' #Extends sides to try and ensure uniqueness
    #for match in set(re.findall(non_set_comma,new_expression)):
    #    new_expression=new_expression.replace(match,",\",\",".join(match.split(",")))
        
    new_expression=new_expression.replace(" ",",\" \",")
    #print new_expression
    new_expression=new_expression.replace("-",",\"-\",")
    new_expression=new_expression.replace("+",",\"+\",")
    step_2=new_expression
    #print new_expression
    new_expression=new_expression.replace("\"\"","\"")
    new_expression=new_expression.replace(",,",",")
    new_expression=new_expression.replace(",]","]")
    new_expression=new_expression.replace(",)",")")
    new_expression=new_expression.replace("\"\"","\"")
    step_3=new_expression
    #No Longer needed due to fix in tests.

    for match in re.finditer(repeated_set,new_expression):
        #print "This is the match", match.group()
        new_match=match.group().strip("set([")
        new_match=new_match.strip("])")
        new_match="set(["+new_match+"])"
        #print
        #print "This is the new match", new_match
        new_expression=new_expression.replace(match.group(),new_match)
    
    for match in re.finditer(set_in_set,new_expression):
        new_match=match.group()
        #print new_match.split("set([")
        blank, section_1,section_2=new_match.split("set([")
        section_2=section_2.strip("])")
        #print section_2
        new_match="set(["+section_1+section_2
        new_expression=new_expression.replace(match.group(),new_match)
    step_4=new_expression
    isolated_start="^\"K[0-9]{5}\""
    for match in re.findall(isolated_start,new_expression):
        new_expression=new_expression.replace(match,"("+match+",)")
        
    new_expression="("+new_expression+")"
    
    
        
    
    if return_string:
        return new_expression
    try:
        new_nesting=eval(new_expression)
        return definition,new_nesting 
    

        
    except TypeError:
        print definition
        print "0_1",step_0_1
        print "0_2",step_0_2
        print "Step 1:", step_1
        print "Step 2:", step_2
        print "Step 3:", step_3
        print "Step 4:", step_4
        print ";".join(set_matches)
        print "Type error", new_expression
        raise
        
    except SyntaxError:
        print definition
        print "0_1",step_0_1
        print "0_2",step_0_2
        print "Step 1:", step_1
        print "Step 2:", step_2
        print "Step 3:", step_3
        print "Step 4:", step_4
        print ";".join(set_matches)
        print "Syntax error", new_expression
        raise
        
    except NameError:
        print definition
        print "0_1",step_0_1
        print "0_2",step_0_2
        print "Step 1:", step_1
        print "Step 2:", step_2
        print "Step 3:", step_3
        print "Step 4:", step_4
        print ";".join(set_matches)
        print "NameError", new_expression
        raise
        
    return
        
     

def alt_eval_kegg_bool(kegg_expr,ko_set):
    '''
    Description:
        Evaluates a list of boolean expressions blocks to get a list of T, F results summarising the module completeness.
    Input: 
        kegg_expr: List of sets
            A kegg expression consting of KOs in nested tuples. eg, (KO1 ((KO2,KO3-KO4),KO5).
            The separators represent the kegg boolean separators.

        ko_set: set
            The set of KOs to be evaluated for compelteness in this particular kegg expression.
    Calls:
        eval_kegg_bool: function
            The workhorse of  this function - recursively evaluates each element in kegg_expr to
            decide if it is actually true or false.
    '''
    n_elements=len(kegg_expr)
    results_vec=["na"]*n_elements
    for i in xrange(0,n_elements,2):
        current_element=kegg_expr[i]
        if isinstance(current_element,tuple):
            side_1_result=eval_kegg_bool(current_element,ko_set)
        else:
            side_1_result= len(ko_set & current_element)>0

        full_result=side_1_result
        #print full_result
        results_vec[i]=full_result
    #print results_vec
    for i,element in enumerate(results_vec):
        if not isinstance(element,bool):
            results_vec[i]=kegg_expr[i]
    
    return (n_elements+1)/2,results_vec
    
    
def eval_kegg_bool(kegg_expr,ko_set):
    '''
    Description:
        A recursive implementation of the kegg boolean logic for evaluating based on a set of KOs if a module is complete.
        If given a tuple it will recursively search down for more tuples and evaluate them at the lowest level to move up and
        finally finish evaluating the complete block. 
    Input: 
        kegg_expr: List of sets
            A kegg expression consting of KOs in nested tuples. eg, (KO1 ((KO2,KO3-KO4),KO5).
            The separators represent the kegg boolean separators.
        ko_set: set
            The set of KOs to be evaluated for compelteness in this particular kegg expression.
    Calls:
        eval_kegg_bool: function
            Evalutes logical KEGG blocks.
    '''
    n_elements=len(kegg_expr)
    #vector=np.array(["na"]*((n_element+1)/2)-1)
    for i in xrange(0,n_elements-1,2):
        #print "THe current kegg expression getting evaluated", kegg_expr[i:i+3]
        side_1,log_op,side_2=kegg_expr[i:i+3]
        #print "This is the logical operater being used",log_op
        if log_op==" " or log_op=="+":
            #print "Entering +  recursion"
            if isinstance(side_1,tuple):
                side_1_result=eval_kegg_bool(side_1,ko_set)
            else:
                side_1_result= len(ko_set & side_1)>0
            if isinstance(side_2,tuple):
                side_2_result=eval_kegg_bool(side_2,ko_set)
            else:
                side_2_result=len(ko_set & side_2)>0
            full_result=side_1_result and side_2_result
            
        elif "," in log_op:
            #print "Entering , recursion"
            if isinstance(side_1,tuple):
                side_1_result=eval_kegg_bool(side_1,ko_set)
            else:
                side_1_result= len(ko_set & side_1)>0
            if isinstance(side_2,tuple):
                side_2_result=eval_kegg_bool(side_2,ko_set)
            else:
                side_2_result=len(ko_set & side_2)>0     
            full_result=side_1_result or side_2_result
        elif "-" in log_op:
            #print "Entering - recursion"
            if isinstance(side_1,tuple):
                side_1_result=eval_kegg_bool(side_1,ko_set)
            else:
                side_1_result= len(ko_set & side_1)>0
            if isinstance(side_2,tuple):
                side_2_result=eval_kegg_bool(side_2,ko_set)
            else:
                side_2_result=len(ko_set & side_2)>0
            full_result=side_1_result
            
        else:
            print log_op, "There seems to have been an error:"
        #print "The result for side 1", side_1_result, side_1
        #print "The result for side 2", side_2_result, side_2
    #print "The final results being returned", full_result
    return full_result

def block_level_completeness(results_vector,correct_partial,nested_descr,ko_set):
    '''
    Description:
        Calculates the percent completness of a KEGG module in one of two ways. It either looks at the number of
        logical blocks complete (block level completeness) or also adds a percentage adjustment for how complete the
        incomplete blocks are.
    Input:
        results_vector: List of Booleans
            A list containing the results of evaluating a kegg module KO hits as a boolean expression.
        correct_partial: Boolean
            Indicating whether to try and account for the partial completeness of some logical blocks.
        nested_descr: nested tuple of sets
            A logical grouping of KEGG blocks into tuples with sets of KOs as the lowermost elements.
    Output:
        completeness_perc:  float in [0,1]
            Percent module completeness according to one of two methods.'''
    if isinstance(nested_descr,set):
        return len(nested_desc & ko_set) > 0
    
    keep_indices=True
    log_blocks=make_logical_blocks(results_vector,keep_indices)
    position_mapping=make_position_mapping(log_blocks)
    if keep_indices:
        log_blocks=extract_logical_values(log_blocks)
    else:
        pass
    n_tot=len(log_blocks)
    filled_blocks=[any(block) for block in log_blocks]
    n_filled_blocks=sum(filled_blocks)
    adjustment=["na"]*len(log_blocks)
    
    if not correct_partial:
        completeness_perc=float(n_filled_blocks)/n_tot
        return completeness_perc
    else:
        for i,block in enumerate(log_blocks):
            if not any(block):
                n_max_hits=len(block)
                running_total=0
                for j,item in enumerate(block):
                    if item:
                        running_total+=1
                    else:
                        bool_index=position_mapping[i][j]
                        #print "The boolean index:", bool_index
                        #print position_mapping
                        kegg_bool=nested_descr[bool_index]
                        #print kegg_bool
                        running_total+=module_completeness_proportion(kegg_bool,ko_set,correct_partial)
                adjustment[i]=float(running_total)/n_max_hits
            else:
                adjustment[i]=1    
        n_filled_blocks=sum(adjustment)
        completeness_perc=float(n_filled_blocks)/n_tot
        return completeness_perc
    
def extract_logical_values(logical_blocks):
    return [[item[1] for item in block] for block in logical_blocks]

def make_position_mapping(log_blocks):
    mapping={}
    for i, block in enumerate(log_blocks):
        mapping[i]={j:item[0] for j,item in enumerate(block)}
    #print mapping
    return mapping

def module_completeness_proportion(kegg_bool,ko_set,correct_partial):
    '''Returns the completeness of the current kegg_boolean.
    Input:
        
    Output:
        
    Calls:
        block_level_completeness: Calculate the % of kegg blocks which are complete.'''
    if isinstance(kegg_bool,set):
        return len(kegg_bool & ko_set) > 0
    #print "Kegg bool:",kegg_bool
    #print "ko_set:",ko_set
    n_el,results_vector=alt_eval_kegg_bool(kegg_bool,ko_set)
    #print "This is the result vector:",results_vector
    completeness_perc=block_level_completeness(results_vector,correct_partial,kegg_bool,ko_set)
    
    return completeness_perc
    

def make_logical_blocks(results_vector,keep_indices):
    '''
    Description:
        Turn the uppermost level of results from a KEGG boolean into a series of logical blocks. I.e if I had 
        a vector [T and F and T or F or T] then the blocks formed will be [[T],[F],[T,F,T]]. 
    Input:
        results_vector: List of Bools
            A list containing the results of evaluating a kegg module KO hits as a boolean expression.
    Output:
        log_block: list of lists of bools
            A list composed of the logical blocks needed to decide if a boolean is "complete".'''
    
    operator_set=set([" ","-",",","+"])
    log_blocks=[]
    current_block=[]
    previous_logical=""
    log_operators=[" ",",","+","-"]
    for i,item in enumerate(results_vector):
        if item not in log_operators:
            if not current_block:
                if keep_indices:
                    current_block.append((i,item))
                else:
                    current_block.append(item)
#            elif i==(len(results_vector)-1):
#                log_blocks.append(current_block)
            else:
                if previous_logical==" " or previous_logical=="+":
                    log_blocks.append(current_block)
                    if keep_indices:
                        current_block=[(i,item)]
                    else:
                        current_block=[item]
                elif previous_logical==",":
                    #print item
                    if keep_indices:
                        current_block.append((i,item))
                    else:
                        current_block.append(item)
                    #print current_block
                elif previous_logical=="-":
                    pass
        else:
            previous_logical=item
    log_blocks.append(current_block)
            
    return log_blocks

def test_all_local_modules(database_dir):
    #758 comparisons are to be made.
    completeness_dict={}
    #Load all possible KOs
    MO_KO_pairs=load_local_kegg_database_pairings(database_dir,[("Module","orthology")], False)[("Module","orthology")]
    MO_KO_pairs={MO:set(KOs) for MO, KOs in MO_KO_pairs.iteritems()}
    #Use this as the comparison set.
    log_kegg_exprs=load_local_cleaned_definition_db(database_dir)
    #Screen every single module for compelteness (should all be 1.0)
    for module,expression in log_kegg_exprs.iteritems():
        completeness_dict[module]=module_completeness_proportion(expression,MO_KO_pairs[module],True)
        
    
    failures={Module:completeness for Module,completeness in completeness_dict.iteritems() if completeness<1}
    print failures
    
    return completeness_dict

genome_taxonomy=load_bin_names(tax_file)


C:\Users\Alex


# Kegg Item Enrichment analyses

In [2]:
def bonferroni_correction(p,n):
    '''Returns a new Bonferroni corrected p-value threshold for significance.'''
    return p/n

def flat_bonf_correction(ps,n,p):
    '''Applies a simple bonferonni correccted signifiance threshold to all of the observed data
    to determine signifances.'''
    p_val=p/n
    sig_values={(genome_1,genome_2):{PTH:(p_value,exp_val,obs_val) for PTH, (p_value,exp_val,obs_val,completeness) in pathway_scores.iteritems() if p_value <=p_val} for (genome_1,genome_2),pathway_scores in ps.iteritems()}
    return sig_values

def flat_sidak_correction(ps,n,p):
    '''Applies a simple sidak corrected significance threshold to all of the observed data
    to determine significance.'''
    p_val=1-(1-p)**(float(1)/n)
    sig_values={(genome_1,genome_2):{PTH:(p_value,exp_val,obs_val,completeness) for PTH, (p_value,exp_val,obs_val,completeness) in pathway_scores.iteritems() if p_value <=p_val} for (genome_1,genome_2),pathway_scores in ps.iteritems()}
    return sig_values

def Holm_bonferonni_correction(ps,n,p):
    '''Apply a Holm_bonferonni_corrections to a list of p-values with specific
    familywise error rate and total sample size. This an application of Holm's
    method for familywise error rate control Holm et al. ().'''
    list_data=ps.items()
    expanded_list_data=[(genome_1,genome_2,PTH,p_value,exp_val,obs_val,completeness) for (genome_1,genome_2),path_dict in ps.iteritems() for PTH, (p_value,exp_val,obs_val,completeness) in path_dict.iteritems()]
    p_values=sorted(expanded_list_data,key=lambda x: x[3])
    N=n
    significant_values=[]
    for (genome_1,genome_2,pathway,p_value,exp_val,obs_val,completeness) in p_values:
        if p_value<=p/N:
            significant_values.append((genome_1,genome_2,pathway,p_value,exp_val,obs_val,completeness))
            N=N-1
        else:
            break
    sig_vals=defaultdict(dict)
    #print significant_values
    for (genome_1,genome_2, pathway,p_value,exp_val,obs_val,completeness) in significant_values:
        sig_vals[(genome_1,genome_2)][pathway]=(p_value,exp_val,obs_val,completeness)
    return sig_vals

def Holm_sidak_correction(ps,n,p):
    '''Apply a Holm_sidak_corrections to a list of p-values with specific
    familywise error rate and total sample size. This an application of Holm's
    method for familywise error rate control Holm et al. ().'''
    list_data=ps.items()
    expanded_list_data=[(genome_1,genome_2,PTH,p_value,exp_val,obs_val,completeness) for (genome_1,genome_2),path_dict in ps.iteritems() for PTH, (p_value,exp_val,obs_val,completeness) in path_dict.iteritems()]
    p_values=sorted(expanded_list_data,key=lambda x: x[3])
    N=n
    significant_values=[]
    for (genome_1,genome_2,pathway,p_value,exp_val,obs_val,completeness) in p_values:
        if p_value<=1-(1-p)**(float(1)/N):
            significant_values.append((genome_1,genome_2,pathway,p_value,exp_val,obs_val,completeness))
            N=N-1
        else:
            break
    sig_vals=defaultdict(dict)
    for (genome_1,genome_2, pathway,p_value,exp_val,obs_val,completeness) in significant_values:
        sig_vals[(genome_1,genome_2)][pathway]=(p_value,exp_val,obs_val,completeness)
    return sig_vals        
def sidak_correction(p,n):
    '''Sidak correction as described by Sidak et al. ----'''
    return 1-(1-p)**(float(1)/n)

def False_Discovery_Rate_correction(ps,n,p):
    '''Control the false discovery rate for a list of p-values with specific
    familywise error rate and total sample size as described by _____ et al. (19__)'''
    list_data=ps.items()
    expanded_list_data=[(genome_1,genome_2,PTH,p_value,exp_val,obs_val,completeness) for (genome_1,genome_2),path_dict in ps.iteritems() for PTH, (p_value,exp_val,obs_val,completeness) in path_dict.iteritems()]
    p_values=sorted(expanded_list_data,key=lambda x: x[3])
    i=1
    N=n
    significant_values=[]
    for (genome_1,genome_2,pathway,p_value,exp_val,obs_val,completeness) in p_values:
        if p_value<=float(i)*p/N:
            significant_values.append((genome_1,genome_2,pathway,p_value,exp_val,obs_val,completeness))
            i+=1
        else:
            break
    sig_vals=defaultdict(dict)
    for (genome_1,genome_2, pathway,p_value,exp_val,obs_val,completeness) in significant_values:
        sig_vals[(genome_1,genome_2)][pathway]=(p_value,exp_val,obs_val,completeness)
    return sig_vals

def multiple_test_correction(genome_data,p,n,correction_type="bonferonni"):
    '''Perform multiple test correction to determine the significantly enriched modules.'''
    unrepped_mods={(genome_1,genome_2):{PTH:results for PTH, results in pathway_data.iteritems() if results[1]==0} for (genome_1,genome_2),pathway_data in genome_data.iteritems()}
    repped_mods={(genome_1,genome_2):{PTH:results for PTH, results in pathway_data.iteritems() if results[1]!=0} for (genome_1,genome_2),pathway_data in genome_data.iteritems()}
    sig_values=defaultdict(dict)
    if correction_type=="bonferonni":
        sig_values=flat_bonf_correction(repped_mods,n,p)
    elif correction_type=="Holm-bonferonni":
        sig_values=Holm_bonferonni_correction(repped_mods,n,p)
    elif correction_type=="Holm-Sidak":
        sig_values=Holm_sidak_correction(repped_mods,n,p)
    elif correction_type=="Sidak":
        sig_values=flat_sidak_correction(repped_mods,n,p)
    elif correction_type=="FDR" or correction_type=="False_Discovery_Rate":
        sig_values=False_Discovery_Rate_correction(repped_mods,n,p)
    else:
        print "None of the multiple test corrections were chosen. Please choose \
        one of bonferonni, Holm-bonferonni, Holm-Sidak, Sidak, FDR or False Discovery Rate"
        
    for (genome_1,genome_2), pathway_data in unrepped_mods.iteritems():
        for PTH,results in pathway_data.iteritems():
            sig_values[(genome_1,genome_2)][PTH]=results
        
    return sig_values

def measure_completeness(PTH,target_KOs,KO_PTH_structure):
    '''Ensures that completenss is only calculated for modules.'''
    if not PTH.startswith("M"):
        return 0
    else:
        if KO_PTH_structure==():#empty tuple
            return 0
        else:
            perc_completeness=module_completeness_proportion(KO_PTH_structure,target_KOs,True)
            return perc_completeness

def new_measure_completeness(PTH,target_KOs,KO_PTH_structure,is_module):
    if is_module:
        if KO_PTH_structure==():#empty tuple
            return 0
        else:
            perc_completeness=module_completeness_proportion(KO_PTH_structure,target_KOs,True)
            return perc_completeness
    else:
        return 0
        
    
        
def import_KO_hits(file_name):  
    #Needs to detect counts file and rewrite it.
    KO_hits={}
    with open(file_name) as KO_hits_file:
        header_line=next(KO_hits_file)
        if len(header_line.split("\t"))>2:
            dataframe=True
        else:
            dataframe=False
    if not dataframe:
        with open(file_name) as KO_hits_file:
            for line in KO_hits_file:
                genome,KOs=line.strip().split("\t")
                KOs=KOs.split(";")
                #print genome,KOs
                KO_hits[genome]=KOs
        #print "THE KO HITS", KO_hits
        KO_hits_dup=KO_hits
        for genome, KOs in KO_hits.iteritems():
            KO_hits[genome]=Counter(KOs)
        KO_hits=pd.DataFrame.from_dict(KO_hits,orient='columns')
        KO_hits=KO_hits.fillna(0)
        #print "THe Ko hits", KO_hits
        return KO_hits
    elif dataframe:
        KO_hits=convert_kegg_hits_to_occurrences(file_name)
        #print "THE KO HITS", KO_hits
        return KO_hits
    else:
        print "Error"
        return
    
def convert_kegg_hits_to_occurrences(file_name):
    counts_df=pd.read_csv(file_name,sep="\t",index_col=[0])
    #KO_dict={}
    #for genome in counts_df.columns:
    #    KO_dict[genome]=[]#*n_kos

    #for index,row in counts_df.iterrows():
    #    for genome, count in row.iteritems():
    #        #print index, count
    #        KO_dict[genome].extend((index,)*count)
    #
    counts_df=counts_df.fillna(0)
    return counts_df

def load_groupings(file_name,all_members):
    '''Loads in a tab separated files of the form: group_name\tgenome_1|genome_2|genome_3|genome_4|etc
    Input:
        file_name: str
            Name of groupings file
    Output:
        groupings: dict
            Dictionary of group name and the list of component genomes'''
    groupings={}
    if not isinstance(file_name,type(None)):
        if not os.path.isfile(file_name):
            raise IOError('The specified groupings file does not exist.')

        with open(file_name) as groupings_file:

            next(groupings_file) #skip header

            for line in groupings_file:
                group_name,members=line.strip().split("\t")
                list_of_members=members.split("|")

                groupings[group_name]=list_of_members
    else:
        groupings["all_genomes_grouped"]=list(all_members)
    #print 'The groupings', groupings   
    return groupings

def load_comparisons(file_name):
    '''Loads in a tab separated files of the form: source_group\ttarget_1|target_2|target_3|etc
    
    Input:
        file_name: str
            Name of comparisons file
    Output:
        comparisons: dict
            Dictionary of baseline gorup name and the list of target group/genome_names to look
            for enrichment in'''
    comparisons={}
    if not os.path.isfile(file_name):
        raise IOError('The specified comparisons file does not exist.')
    with open(file_name) as comparisons_file:
        
        next(comparisons_file) #skip header
        
        for line in comparisons_file:
            source,targets=line.strip().split("\t")
            list_of_targets=targets.split("|")
            comparisons[source]=list_of_targets
    #print 'The comparisons', comparisons
    return comparisons

def make_comparisons_dict(KO_hits):
    '''Makes a default comparison of each individual against the grouping of all.
    Input: 
        KO_hits: dict
            dictionary of all KO_hits for each genome
    returns:
        comparison_dict: dict
            Dictionary with one key (all_genomes_grouped:list of genomes)'''
    return {"all_genomes_grouped":KO_hits.keys()}

def make_groupings_dict(KO_hits,groups, all_grouped=False):
    '''Make a list of KOs for each group based on the KOs of group members.
    Input:
    
    Output:
    '''
    group_KOs={}
    if all_grouped:
        KO_hits["all_genomes_grouped"]=KO_hits.sum(axis=1)
        #group_KOs["all_genomes_grouped"]=list(itertools.chain(*KO_hits.itervalues()))
        print "KO_HITS",KO_hits
        return KO_hits
    else:
        for group, members in groups.iteritems():
            KO_hits[group]=KO_hits.loc[:,members].sum(axis=1)
            #group_KOs[group]=list(itertools.chain(*[KO_hits[member] for member in members]))
        #print "KO_HITS",KO_hits
        return KO_hits          
    

def do_enrichment_comparisons(KO_hits, threshold, database_dir, comparisons,extras_dict,excluded_set,definition_file,overlap_dict,make_mo_comp):
    '''Makes all of the enrichment comparisons specified in comparisons.
    Input:
    
    Output:
    '''
    #print "Inside do_enrichment_comparison this is the extras_dict", extras_dict
    results_dict={}
    N_comparisons={}
    
    if make_mo_comp:
        #Load pairings
        KO_KG_ITEM_PAIRS=kegg_pairs_wrapper(["orthology","Module"],excluded_set, extras_dict,database_dir)
        #KO_KG_ITEM_PAIRS=load_local_kegg_database_pairings(database_dir,[("orthology",kegg_item)], False)[("orthology",kegg_item)]
        KO_KG_ITEM_PAIRS=defaultdict(set, KO_KG_ITEM_PAIRS)
        #KG_ITEM_KO_PAIRS=load_local_kegg_database_pairings(database_dir,[(kegg_item,"orthology")], False)[(kegg_item,"orthology")]
        KG_ITEM_KO_PAIRS=kegg_pairs_wrapper(["Module","orthology"],excluded_set, extras_dict,database_dir)    
        KG_ITEM_KO_PAIRS={MO:list(set(KOs)) for MO,KOs in KG_ITEM_KO_PAIRS.iteritems()}
        #Must be MODULES
        KO_PTH_structure=logical_loading_wrapper(database_dir, definition_file, extras_dict)
        #KO_PTH_structure=load_local_cleaned_definition_db(database_dir)
        KO_PTH_structure=defaultdict(tuple,KO_PTH_structure)

        total_KO_hits=KO_hits.sum(axis=0)
        #print total_KO_hits
        new_MO_hits_matrix=construct_MO_matrix(KO_hits, KG_ITEM_KO_PAIRS)
        #print "NEw extreme values", new_MO_hits.ix[(new_MO_hits>total_KO_hits).any(axis=1),:]
        #print "weird value", new_MO_hits_matrix.loc['Msuper_duper',:]
    else:
        KO_KG_ITEM_PAIRS=None
        KG_ITEM_KO_PAIRS=None  
        KO_PTH_structure=None
        total_KO_hits=KO_hits.sum(axis=0)
        #print total_KO_hits
        new_MO_hits_matrix=KO_hits
        #print "NEw extreme values", new_MO_hits.ix[(new_MO_hits>total_KO_hits).any(axis=1),:]
        #print "weird value", new_MO_hits_matrix.loc['Msuper_duper',:]

    for source, targets in comparisons.iteritems():
        for target in targets:
            result,n_comparisons=adjusted_enrichment_test(target,source,threshold, new_MO_hits_matrix,database_dir,KO_KG_ITEM_PAIRS,KG_ITEM_KO_PAIRS,KO_PTH_structure,total_KO_hits,KO_hits,overlap_dict)
            if len(result)>0:
                results_dict[(target,source)]=result
                N_comparisons[(target,source)]=n_comparisons
    
    #Turn results_dict into a dataframe
    #results_dict={(target,source,pathway) for (target,source),results in}
    return results_dict,N_comparisons

def construct_MO_matrix(KO_hits, KG_ITEM_KO_PAIRS):
    new_module_counts={}
    for Module, KOs in KG_ITEM_KO_PAIRS.iteritems():
        #print KOs
        hits_per_module=KO_hits.ix[KOs,:].sum(axis=0)
        #print hits_per_module
        #print "Module", Module, hits_per_module['coral']
        new_module_counts[Module]=hits_per_module
    #print "weird_coral_stuff",sum(KO_hits['coral']), KO_hits['coral']
    new_module_df=pd.DataFrame.from_dict(new_module_counts,orient='index')
    #print new_module_df
    return new_module_df


def load_bin_taxa(bin_file):
    bin_names={}
    
    with open(bin_file) as taxonomy:
        next(taxonomy) #skip header
        for line in taxonomy:
            
            taxa,ID=line.strip().split("\t")
            bin_names[ID]=taxa
            
    bin_taxa=defaultdict(lambda:"No_Associated_Taxonomy" ,bin_names)
    
    return bin_taxa

def load_excluded_items(excluded_file):
    excluded_items=[]
    with open(excluded_file) as ignored:
        for line in ignored:
            excluded_items.append(line.strip())
    return set(excluded_items)

def load_extra_items(extras_file):
    extra_items={}
    with open(extras_file) as extras:
        for line in extras:
            ID, KOs=line.strip().split("\t")
            extra_items[ID]=KOs.split(";")
    return extra_items

def make_extras_logical(extra_def_file,extras_dict,database_dir):
    definition_dict={}
    if extra_def_file==None:
        rename_logical=os.path.join(database_dir,"module_extra_kegg_log_expr.tsv")
        for key,KOs in extras_dict.iteritems():
            n_KOs=len(KOs)
            new_def=[' ']*(2*n_KOs-1)
            for i,KO in enumerate(KOs):
                new_def[2*i]=set([KO])
                #if i==(n_KOs-1):
                    #pass
                #else:
                    #new_def[2*i+1]=' '
            definition_dict[key]=tuple(new_def)
        pd_df=pd.DataFrame([
                [module,definition] for module, definition in definition_dict.iteritems()
            ])
        pd_df.columns=["Module","KEGG_Boolean"]
        pd_df.to_csv(rename_logical,header=True,sep="\t",index=False)
        extra_def_file=rename_logical
        
    else:
        pd_df=[]
        core,file_ending=os.path.splitext(extra_def_file)
        rename_logical=core+"KEGG_bool"+file_ending
        with open(extra_def_file) as kegg_defs:
            for line in kegg_defs:
                module, definition=line.strip().split()
                new_def=replacement(definition)
                pd_df.append([module,new_def])
        pd_df=pd.DataFrame(pd_df)
        pd_df.columns=["Module_Name","Logical_definition"]
        pd_df.to_csv(rename_logical,header=True,sep="\t")
        extra_def_file=rename_logical
        
    return extra_def_file

   
    
def kegg_pairs_wrapper(kegg_items,excluded_items, extra_items,database_dir):
    '''
    ********************************************************************************************************
    This function acts as a wrapper to load in the local pairs databases.
    It also excludes those items mentioned in the excluded items list and adds the user defined extra items.
    ********************************************************************************************************
    Input:
    
    Output:
    
    Calls:
    
    '''
    #print "THese are the excluded items", excluded_items
    kegg_item_1,kegg_item_2=kegg_items
    if kegg_item_1=="orthology":
        if isinstance(extra_items,type(None)):
            KO_KG_ITEM_PAIRS=load_local_kegg_database_pairings(database_dir,[("orthology",kegg_item_2)], False)[("orthology",kegg_item_2)]
            if excluded_items!=None:
                #valid_PAIRS=set(KO_KG_ITEM_PAIRS.keys())-set(excluded_items)
                #print "These are the excluded items", excluded_items
                KO_KG_ITEM_PAIRS={key:set(value)-set(excluded_items) for key,value in KO_KG_ITEM_PAIRS.iteritems()}
            #KO_KG_ITEM_PAIRS=add_extra_items_values(KO_KG_ITEM_PAIRS,extra_items)
        else:
            #print "The extra items",extra_items
            KO_KG_ITEM_PAIRS=load_local_kegg_database_pairings(database_dir,[("orthology",kegg_item_2)], False)[("orthology",kegg_item_2)]
            if excluded_items!=None:
                #valid_PAIRS=set(KO_KG_ITEM_PAIRS.keys())-set(excluded_items)
                #print "These are the excluded items", excluded_items
                KO_KG_ITEM_PAIRS={key:set(value)-set(excluded_items) for key,value in KO_KG_ITEM_PAIRS.iteritems()}
            KO_KG_ITEM_PAIRS=add_extra_items_values(KO_KG_ITEM_PAIRS,extra_items)
        return KO_KG_ITEM_PAIRS
    
    elif kegg_item_2=="orthology":
        if isinstance(extra_items,type(None)):
            KG_ITEM_KO_PAIRS=load_local_kegg_database_pairings(database_dir,[(kegg_item_1,"orthology")], False)[(kegg_item_1,"orthology")]
            if excluded_items!=None:
                valid_PAIRS=set(KG_ITEM_KO_PAIRS.keys())-set(excluded_items)
                KG_ITEM_KO_PAIRS={key:value for key,value in KG_ITEM_KO_PAIRS.iteritems() if key in valid_PAIRS}
        else:
            KG_ITEM_KO_PAIRS=load_local_kegg_database_pairings(database_dir,[(kegg_item_1,"orthology")], False)[(kegg_item_1,"orthology")]
            if excluded_items!=None:
                valid_PAIRS=set(KG_ITEM_KO_PAIRS.keys())-set(excluded_items)
                KG_ITEM_KO_PAIRS={key:value for key,value in KG_ITEM_KO_PAIRS.iteritems() if key in valid_PAIRS}
            for name, KOs in extra_items.iteritems():
                KG_ITEM_KO_PAIRS[name]=KOs
            
        return KG_ITEM_KO_PAIRS
    else:
        print "The input kegg item pair must have at least one occurence of orthology."
        
    return

def add_extra_items_values(dict_1, extra_items_dict):
    ''' Swaps the keys and values of the dictionary. Then, add these to the orthology - kg item pair dict'''
    reversed_dict=defaultdict(list)
    dict_1={key:set(values) for key,values in dict_1.iteritems()}
    dict_1=defaultdict(set,dict_1)
    #print extra_items_dict
    for key,values in extra_items_dict.iteritems():
        for value in values:
            reversed_dict[value].append(key)
    reversed_dict={key:list(set(values)) for key,values in reversed_dict.iteritems()}
    #print "THe reversed_dictionary",reversed_dict
    
    for key,item in reversed_dict.iteritems():
        #print item
        dict_1[key].update(item)

    dict_1={key:list(set(values)) for key,values in dict_1.iteritems()}
    return dict_1

def readable_kegg_wrapper(kegg_item,extra_items,database_dir):
    '''Adds the user specified extras to the existing readable database after it is loaded into python.'''
    
    readable_names=load_readable_names(database_dir,[kegg_item],False)[kegg_item]
    if extra_items==None:
        pass
    else:
        for key in extra_items.iterkeys():
            readable_names[key]=key    
    return readable_names

def standard_enrichment_wf(KO_file, groupings_file,comparisons_file,threshold, database_dir, output_dir, bin_file, mult_adjust_type,extras_dict,excluded_set,extra_defs_file,account_for_overlap,make_mo_comps):
    '''A workflow for calculating the enrichments of modules in target genomes/groups against other groups as based on 
    the comparisons and groupings file.
    Input:
    
    Output:
    
    '''
    
    bin_taxa=load_bin_taxa(bin_file)
    
    #print bin_taxa

    KO_hits=import_KO_hits(KO_file)
    #print groupings_file
    #print comparisons_file
    if isinstance(groupings_file,type(None)):
        print "No groupings file specified."
        groupings=load_groupings(groupings_file,KO_hits.columns)
        groupings_dict=make_groupings_dict(KO_hits,None,all_grouped=True)
    else:
        groupings=load_groupings(groupings_file,KO_hits.columns)
        groupings_dict=make_groupings_dict(KO_hits,groupings,all_grouped=False)
        
    if isinstance(comparisons_file,type(None)):
        print "No comparisons file specified."
        comparisons_dict=make_comparisons_dict(KO_hits)
    else:
        comparisons=load_comparisons(comparisons_file)
        comparisons_dict=comparisons
    #print "The comparisons to be made",comparisons_dict
    #print "The groups made", groupings_dict
    #for source, targets in comparisons.iteritems():
        #print "These are the comparison to be made: [{0}]X[{1}]".format(source,targets)
    #KO_hits_matrix=construct_KO_matrix(KO_hits)
    
    #for group, KOs in groupings_dict.iteritems():
    #    KO_hits[group]=KOs
    if account_for_overlap:
        overlap_file=construct_overlap_file(comparisons_dict, groupings)
    else:
        overlap_file=None
    #print "This is the overlap for each comparison being made", overlap_file
    if not isinstance(extras_dict,type(None)):
        print "Processing the extras files"
        extras_dict=load_extra_items(extras_dict)
        #print "THe extra items", extras_dict
    if not isinstance(excluded_set,type(None)):
        print "Processing the excluded files"
        excluded_set=load_excluded_items(excluded_set)
    
        
    results_dict,N_comparisons=do_enrichment_comparisons(KO_hits,threshold, database_dir, comparisons_dict,extras_dict,excluded_set,extra_defs_file,overlap_file,make_mo_comps)
    
    results_dict=post_hoc_significance_correction(results_dict,N_comparisons,threshold,mult_adjust_type)

    dfs=write_enrichment_data(results_dict, database_dir,output_dir,bin_taxa,mult_adjust_type,extras_dict)
    
    return dfs

def construct_overlap_file(comparison_dict, groupings):
    groupings={key: set(members) for key, members in groupings.iteritems()}
    groupings=defaultdict(set,groupings)
    #print groupings
    overlap_dict={}
    for source, targets in comparison_dict.iteritems():
        for target in targets:
            if target in groupings:
                if source in groupings:
                    overlap_dict[(target,source)]=list(groupings[source] & groupings[target])
                else:
                    overlap_dict[(target,source)]=list(set([source]) & groupings[target])
            else:
                if source in groupings:
                    overlap_dict[(target,source)]=list(groupings[source] & set([target]))
                else:
                    overlap_dict[(target,source)]=list(set([source]) & set([target]))
                    
        
    return overlap_dict
            
            


def write_enrichment_data(enrichment_results_dict,database_dir,output_dir,bin_taxa,mult_adjust,extras_items):
    '''Writes the Results_dict from the all_comparisons function after turning them into a pandas
    dataframe. '''
    pandas_dataframes={}
    out_frame={}
    #readable_orthology=readable_kegg_wrapper("orthology",extra_items,database_dir)
    readable_orthology=load_readable_names(database_dir,["orthology"],False)["orthology"]
    readable_kegg_items=readable_kegg_wrapper("module",extras_items,database_dir)
    #readable_kegg_items=load_readable_names(database_dir,["Module"],False)["Module"]
    
    #pandas_dataframes["Module"]={}
    

    df=pd.DataFrame([
            ["Module",target_genome,source_genome,bin_taxa[target_genome],enriched_item,readable_kegg_items[enriched_item],p_value, expected_count,observed_count,completeness]
            for (target_genome,source_genome), enriched_kegg_items in enrichment_results_dict.iteritems() for enriched_item,(p_value,expected_count,observed_count,completeness) in enriched_kegg_items.iteritems()
        ])
    df.columns=["Kegg_item","Target_genome","BaselineGenome","Target_taxonomy","{0}_ID".format("Module"),"Readable_{0}".format("Module"),"p_value","Expected_Count","Observed_count","Completeness(MO_only)"]
    split_frames={}

    #Split the dataframe into those ovserved in both target and source and only target
    split_frames["Disjoint"]=df[:][(df.Expected_Count==0) & (df.Observed_count>0)]
    split_frames["Overlapped"]=df[:][df.Expected_Count>0]
        
    pandas_dataframes["Module"]=split_frames
    out_frame["Module"]=df
    
    for kegg_item, dfs in pandas_dataframes.iteritems():
        for rel,df in dfs.iteritems():
            df.to_csv(os.path.join(output_dir,"enriched_{0}_{2}_corrected_{1}_comparisons.tsv".format(kegg_item,rel,mult_adjust)),sep="\t",index=False)
    return out_frame


def post_hoc_significance_correction(results_dict,N_comparisons,threshold,mult_adjust_type):
    '''Apply the multiple test correction after analysis.'''
    
    n=sum(N_comparisons.itervalues())
    print "THe number of genome pair comparison",len(N_comparisons)
    print "The number of comparisons made",n
    
    significant_results=multiple_test_correction(results_dict,threshold,n,mult_adjust_type)

    new_results={(target,source):{PTH:results for PTH,results in pathway_dict.iteritems()} for ((target,source),pathway_dict) in significant_results.iteritems()}
    
    return new_results



def adjusted_enrichment_test(target, source,threshold,MO_hits,database_dir,KO_KG_ITEM_PAIRS,KG_ITEM_KO_PAIRS,KO_PTH_structure,total_KO_hits,KO_hits,overlap_dict,kegg_item="Module"):
    '''A simple test for enrichment of some kegg item in the target genome based on the baseline probabilities calculated using
    the source genome. A simple binomial test is used to work out the probably of seeing as many genes as was present
    in the target genome given that it has the same chance of occuring as in the source genome. All module present in the
    target but not at all in the source are also included in the output as (0,0,observed count of kegg item,0).f
    
    Input: 
        target               -  Genome to investigate for enrichment
        source               -  The genome used to calculate the baseline chance of a kegg_item occuring.
        KO_hits              -  Dictionary of genome: KO pairs, requires a collated entry for all microorganisms
        threshold            -  The significance threshold before corrections.
        kegg_item            -  The kegg_item to look for enrichment of. Normally run as [module, pathway]
        database_dir         -  The location of the linking files and file descriptions
        abundance_adjust     -  Account for the relative abundance of organism's in the metagenome.
        abundance_file       -  A file with the relative abundance information of the microbes in each sample
    Output:
        Significant_scores   -  All enriched kegg_items passing the threshold after bonferroni correction.'''
    

    shared_pathways=MO_hits.ix[(MO_hits.loc[:,source]>0) & (MO_hits.loc[:,target]>0),:].index


    all_target_KOs=make_KO_set_from_series(KO_hits[target])
    #print "Number of unique KOs present for the target.", target, len(all_target_KOs)
    target_KOs={PTH:set(items) & all_target_KOs  for PTH, items in KG_ITEM_KO_PAIRS.iteritems()}
    target_KOs=defaultdict(set,target_KOs)
    
    #print "This is the source and target", source, target
    
    N_source_KOs=total_KO_hits[source]
    
    N_target_KOs=total_KO_hits[target]
    
    enrichment_scores=defaultdict(lambda:(0,0,0,0))
    
    #print "These were the overlapping organisms", overlap_dict[(target,source)]

    account_for_overlap=isinstance(overlap_dict,dict)
    if account_for_overlap:
        overlapped_hits=MO_hits[overlap_dict[(target,source)]].sum(axis=1)
        overlap_free_source=MO_hits[source]-overlapped_hits
    else:
        overlap_free_source=MO_hits[source]
    #print overlap_free_source
    #print "weird value", MO_hits[source]['Msuper_duper']
    #print "weird value", MO_hits[target]['Msuper_duper']
    #print "The negative values",overlap_free_source[overlap_free_source<0]

    #print "The total number of KOs for source and target,", N_source_KOs, N_target_KOs
    for PTH,count in MO_hits[target].iteritems():
        if PTH in shared_pathways:
            cont_tab_test=scipy.stats.fisher_exact
            N_source_PTH_hits=overlap_free_source[PTH]
            p_PTH_source=float(N_source_PTH_hits)/N_source_KOs
            data_table=[[count,N_source_PTH_hits],[N_target_KOs-count,N_source_KOs-N_source_PTH_hits]]
            #print PTH, data_table,N_source_KOs, N_source_PTH_hits
            p_val=cont_tab_test(data_table)[1]
        else:
            cont_tab_test=lambda x: 0
            N_source_PTH_hits=0
            p_PTH_source=0
            data_table=[[count,N_source_PTH_hits],[N_target_KOs-count,N_source_KOs-N_source_PTH_hits]]
            #print data_table
            p_val=cont_tab_test(data_table)
            
        enrichment_scores[PTH]=(p_val,p_PTH_source*N_target_KOs,count,measure_completeness(PTH,target_KOs[PTH],KO_PTH_structure[PTH]))

    return enrichment_scores,len(shared_pathways)
    
def make_KO_set_from_series(KO_series):
    
    return set(KO_series[KO_series>0].index)
def enrichm_wf(args):
    
    KO_file=args.KO_file
    groupings_file=args.groupings_file
    comparisons_file=args.comparisons_file
    threshold= args.threshold
    database_dir=args.database_dir
    output_dir=args.output_dir
    bin_file=args.bin_file
    mult_adjust_type=args.mult_test_correction
    excluded_items=args.exclude
    extra_file=args.extra
    extra_defs=args.extra_defs
    dfs=standard_enrichment_wf(KO_file, groupings_file,comparisons_file,threshold, database_dir, output_dir, bin_file, mult_adjust_type,extra_file, excluded_items,extra_defs)
    return dfs

def completem_wf(args):
    #Load in variables
    KO_file=args.KO_file
    #groupings_file=args.groupings_file
    #comparisons_file=args.comparisons_file
    #threshold= args.threshold
    database_dir=args.database_dir
    output_dir=args.output_dir
    bin_file=args.bin_file
    #mult_adjust_type=args.mult_test_correction
    excluded_items=args.exclude
    extra_items=args.extra
    extra_defs=args.extra_defs
    extract_core=args.extract_core
    #process all of the genomes one by one.
    dfs=standard_completeness_wf(KO_file,database_dir,output_dir,bin_file,excluded_items,extra_items,extra_defs,extract_core)
    return dfs

def standard_completeness_wf(KO_file,database_dir,output_dir,bin_file,excluded_set,extras_dict,extra_defs,extract_core):
    bin_taxa=load_bin_taxa(bin_file)
    
    KO_hits=import_KO_hits(KO_file)

    if not isinstance(extras_dict,type(None)):
        print "Processing the extras files"
        extras_dict=load_extra_items(extras_dict)
        #print "THe extra items", extras_dict
    if not isinstance(excluded_set,type(None)):
        print "Processing the excluded files"
        excluded_set=load_excluded_items(excluded_set)
    
    all_modules=get_module_completeness(KO_hits,database_dir,excluded_set,extras_dict,extra_defs)
    
    #print "Number of genomes:", len(all_modules)
    
    #print "The results", all_modules
    
    output_file=os.path.join(output_dir, "Genome_module_completeness_matrix.tsv")
    
    readable_names=readable_kegg_wrapper("module",extras_dict,database_dir)
    
    #print "Readable_names:",readable_names
    
    df=write_module_completeness_data(all_modules,output_file,bin_taxa,extract_core,readable_names)
    
    return df
    
def get_module_completeness(KO_hits,database_dir,excluded_items,extra_items,extra_defs_file):
    #Load pairings
    KO_KG_ITEM_PAIRS=kegg_pairs_wrapper(["orthology","module"],excluded_items, extra_items,database_dir)
    
    KO_KG_ITEM_PAIRS=defaultdict(set, KO_KG_ITEM_PAIRS)
    
    KG_ITEM_KO_PAIRS=kegg_pairs_wrapper(["module","orthology"],excluded_items, extra_items,database_dir)
    
    KO_sets=make_sets_from_df(KO_hits)
    
    new_MO_hits_matrix=construct_MO_matrix(KO_hits, KG_ITEM_KO_PAIRS)
    
    observed_pathways=set(new_MO_hits_matrix.ix[new_MO_hits_matrix.apply(any_hits,axis=1),].index)
    print "The number of observed modules", len(observed_pathways)
    #genome_KO_PTH={genome:set(itertools.chain(*[itertools.chain(*[KO_KG_ITEM_PAIRS[KO] for KO in KOs])])) for genome,KOs in KO_hits.iteritems()}
    is_module=True
    #observed_pathways=set(itertools.chain(*[KO_KG_ITEM_PAIRS[KO] for KO in itertools.chain(*[KOs for ID, KOs in KO_hits.iteritems()])]))

    #Must be MODULES
    KO_PTH_structure=logical_loading_wrapper(database_dir, extra_defs_file, extra_items)            
    KO_PTH_structure=defaultdict(tuple,KO_PTH_structure)
   
    complete_data=defaultdict(lambda: defaultdict(float))
    for genome, KOs in KO_sets.iteritems():

        for module in observed_pathways:
            completeness=new_measure_completeness(module,KOs,KO_PTH_structure[module],is_module)
            #print "module", KO_PTH_structure[module]
            complete_data[genome][module]=completeness
    
    return complete_data

def make_sets_from_df(df):
    ko_set={}
    for genome, KOs in df.iteritems():
        ko_set[genome]=make_KO_set_from_series(KOs)
    
    return ko_set

def write_module_completeness_data(genome_module_data,output_file,bin_taxa,extract_core_MOs,readable_names):
    '''Creates a dataframe from the genome[module]=completeness dictionary.
    Input:
    
    Output:
    
    '''
    df=pd.DataFrame.from_dict(genome_module_data,orient='columns')
    print "The matrix dimensions", df.shape
    #taxonomy
    #if len(bin_taxa)>0:
        #df['Taxonomy']=df.index.map(bin_taxa)
    if extract_core_MOs:
        df=df.ix[df.apply(all_hits,axis=1),]
    df['module_desc']=pd.Series(df.index,index=df.index).map(readable_names)
    #print readable_names
    #print pd.Series(df.index).map(readable_names)
    print df['module_desc']
    cols=df.columns.tolist()
    cols=cols[-1:]+cols[:-1]
    df=df[cols]
    df.to_csv(output_file,index=True, header=True, sep="\t")
    
    return df

def all_hits(row):
    return all(row>0)

def any_hits(row):
    return any(row>0)

def load_definition_file(def_file):
    def_dict={}
    with open(def_file) as definitions:
        for line in definitions:
            module,definition=line.strip().split("\t")
            def_dict[module]=definition
    return def_dict

def make_extra_definition_database(definitions_file,database_dir):
    '''
    Definition:
        This function will take an entire kegg module definition file and will create
        a new local database with the expressions written so that they can simple be
        evaluated when loading the files.
    Input: 
        database_dir: str
            A directory containing the database of kegg module definitions.
    Output:
        None
    Calls:
        replacement: Turns kegg definitions in logical nested tuples of sets.
    '''
    old_module_def=load_definition_file(definitions_file)
    #print "This is the old module information",old_module_def
    new_pd_df=[""]*len(old_module_def)
    i=0
    for module,definition in old_module_def.iteritems():
        try:
            logical_evaluation=replacement(definition,False)[1]
            new_pd_df[i]=[module,logical_evaluation]
            i+=1
                
        except TypeError:
            print "TypeError 2:",module, definition
        except SyntaxError:
            print "SyntaxError 2:",module, definition
        except NameError:
            print "NameError 2:",module ,definition
    #print new_pd_df      
    new_pd_df=pd.DataFrame(new_pd_df)
    #print "The second checkpoint."
    new_pd_df.columns=["ModuleID","KEGG_log_expr"]
    
    new_pd_df.to_csv(os.path.join(database_dir, "module_extra_kegg_log_expr.tsv"),header=True,sep="\t",index=False)
    
    return None

def load_local_cleaned_definition_db(extra_def_file):
    cleaned_db={}
    with open(os.path.join(extra_def_file)) as paired_exprs:
        next(paired_exprs) #Skip header
        for line in paired_exprs:
            module,expr=line.split("\t")
            kegg_log=eval(expr)
            if not isinstance(kegg_log,tuple):
                kegg_log=tuple([kegg_log])
            cleaned_db[module]=kegg_log
    return cleaned_db

def logical_loading_wrapper(database_dir, definitions_file, extras):
    '''
    Loads the kegg booleans used to determine the completeness of a module.
    It can also handle creating a database for the user added definitions and loading them at the same time
    as the complete kegg directory.
    
    Input:
    databse_dir: str
        location of database files
    definitions_file: str or None
        location of user definition definitions
    extras: dict or None
        The extra modules added by the user with no definition structure.
    Output:
        original_kegg_log : dict
            module,kegg_boolean pairs to be used for completeness evaluations
    
    '''
    if definitions_file!=None:
        print "Using predefined kegg booleans"
        make_extra_definition_database(definitions_file,database_dir)
        original_kegg_log=load_local_cleaned_definition_db(os.path.join(database_dir,"module_kegg_log_expr.tsv"))
        new_kegg_log=load_local_cleaned_definition_db(os.path.join(database_dir, "module_extra_kegg_log_expr.tsv"))
        for module, kegg_bool in new_kegg_log.iteritems():
            original_kegg_log[module]=kegg_bool
            #print kegg_bool
    elif definitions_file==None and extras!=None:
        print "Automatically creating kegg boolean from module components"
        make_extras_logical(definitions_file,extras,database_dir)
        original_kegg_log=load_local_cleaned_definition_db(os.path.join(database_dir,"module_kegg_log_expr.tsv"))
        new_kegg_log=load_local_cleaned_definition_db(os.path.join(database_dir, "module_extra_kegg_log_expr.tsv"))
        for module, kegg_bool in new_kegg_log.iteritems():
            original_kegg_log[module]=kegg_bool
            #print kegg_bool
    else:
        print "No extra definitions have been added."
        original_kegg_log=load_local_cleaned_definition_db(os.path.join(database_dir,"module_kegg_log_expr.tsv"))
        
    return original_kegg_log
    
    

In [4]:
KO_file=all_kegg
groupings_file=os.path.join(output_dir,"enriched_hits","groupings.tsv.txt")
#print groupings_file
comparisons_file=os.path.join(output_dir,"enriched_hits","comparisons.tsv.txt")
threshold= 0.05
#database_dir=args.database_dir
enrich_output_dir=os.path.join(output_dir,"enriched_hits")
bin_taxa=tax_file
mult_adjust_type="FDR"
test_new=True
extract_core=False
account_for_overlap=False
make_mo_comparisons=True
if test_new:
    excluded_items=os.path.join(output_dir,"enriched_hits","exclusion_file.txt")
    extra_items=os.path.join(output_dir,"enriched_hits","Extra_module_creation","extras_file.txt")#None#os.path.join(output_dir,"enriched_hits","extras_file.txt")
    extra_defs=os.path.join(output_dir,"enriched_hits","Extra_Module_creation","extra_definitions.txt")#None#os.path.join(output_dir,"enriched_hits","")
else:
    excluded_items=None
    extra_items=None
    extra_defs=None

#print extra_defs
dfs=standard_enrichment_wf(KO_file, groupings_file,comparisons_file,threshold, database_dir, enrich_output_dir, bin_taxa, mult_adjust_type,extra_items,excluded_items,extra_defs,account_for_overlap,make_mo_comparisons)
#print import_KO_hits(KO_file).shape

#comp_dfs=standard_completeness_wf(KO_file,database_dir,enrich_output_dir,bin_taxa,excluded_items,extra_items,extra_defs,extract_core)

Processing the extras files
Processing the excluded files
Using predefined kegg booleans
Traceback (most recent call last):
  File "C:\Anaconda2\lib\site-packages\IPython\core\ultratb.py", line 1118, in get_records


ERROR: Internal Python error in the inspect module.
Below is the traceback from this internal error.


KeyboardInterrupt


In [3]:
KO_file=all_kegg
groupings_file=os.path.join(output_dir,"enriched_hits","groupings.tsv.txt")
#print groupings_file
comparisons_file=os.path.join(output_dir,"enriched_hits","comparisons.tsv.txt")
threshold= 0.05
#database_dir=args.database_dir
enrich_output_dir=os.path.join(output_dir,"enriched_hits")
bin_taxa=tax_file
mult_adjust_type="FDR"
extract_core=False
account_for_overlap=False
make_mo_comparisons=True
excluded_items=None
extra_items=os.path.join(output_dir,"enriched_hits","Extra_module_creation","extras_file.txt")#None#os.path.join(output_dir,"enriched_hits","extras_file.txt")
extra_defs=os.path.join(output_dir,"enriched_hits","Extra_Module_creation","extra_definitions.txt")#None#os.path.join(output_dir,"enriched_hits","")


#print extra_defs
#dfs=standard_enrichment_wf(KO_file, groupings_file,comparisons_file,threshold, database_dir, enrich_output_dir, bin_taxa, mult_adjust_type,extra_items,excluded_items,extra_defs,account_for_overlap,make_mo_comparisons)
#print import_KO_hits(KO_file).shape

comp_dfs=standard_completeness_wf(KO_file,database_dir,enrich_output_dir,bin_taxa,excluded_items,extra_items,extra_defs,extract_core)

Processing the extras files
The number of observed modules 612
Using predefined kegg booleans
The matrix dimensions (612, 54)
2-Oxoglutarate=>GlutamicAcid                                                  2-Oxoglutarate=>GlutamicAcid
Aspartate=>Asparagine                                                                Aspartate=>Asparagine
DMS=>DMSO                                                                                        DMS=>DMSO
DMSO=>DMS                                                                                        DMSO=>DMS
DMSP=>3-(Methylthio)-propanoate(dmdA)                                DMSP=>3-(Methylthio)-propanoate(dmdA)
Dimethyl-benzimidazole=>VitaminB12Coenzyme                      Dimethyl-benzimidazole=>VitaminB12Coenzyme
Glutamate=>Glutamine                                                                  Glutamate=>Glutamine
Glutamate=>Proline_v2                                                                Glutamate=>Proline_v2
Glycine=>Serine   

# Sammy's CPR genomes

In [23]:
person_dir=os.path.join(*[core,g_drive,"Other_Peoples_Data","Sammy","CPR_genomes"])
KO_file=os.path.join(*[person_dir,"CPR_genomes.txt"])
groupings_file=os.path.join(*[person_dir,"groupings.txt"])
#print groupings_file
comparisons_file=os.path.join(*[person_dir,"comparisons.txt"])
threshold= 0.05
#database_dir=args.database_dir
enrich_output_dir=person_dir
#print person_dir
bin_taxa=tax_file
mult_adjust_type="FDR"
test_new=False
extract_core=False
account_for_overlap=False
make_mo_comparisons=True
excluded_items=None
extra_items=None
extra_defs=None
                                
dfs=standard_enrichment_wf(KO_file, groupings_file,comparisons_file,threshold, database_dir, enrich_output_dir, bin_taxa, mult_adjust_type,extra_items,excluded_items,extra_defs,account_for_overlap,make_mo_comparisons)

No extra definitions have been added.
THe number of genome pair comparison 27
The number of comparisons made 1799


# Nitrospira Genomes

In [22]:
#Changed Non-comammox to Non_comammox in the comparisons file 
person_dir=os.path.join(*[core,g_drive,"Other_Peoples_Data","Caitlin"])
KO_file=os.path.join(*[person_dir,"Nitrospira_KO_kegg_matrix.tsv"])
groupings_file=os.path.join(*[person_dir,"nitrospira_groups.tsv"])
#print groupings_file
comparisons_file=os.path.join(*[person_dir,"nitrospira_comparisons.txt"])
threshold= 0.05
#database_dir=args.database_dir
enrich_output_dir=person_dir
#print person_dir
bin_taxa=tax_file
mult_adjust_type="FDR"
test_new=False
extract_core=False
account_for_overlap=False
make_mo_comparisons=True
excluded_items=None
extra_items=None
extra_defs=None
                                
dfs=standard_enrichment_wf(KO_file, groupings_file,comparisons_file,threshold, database_dir, enrich_output_dir, bin_taxa, mult_adjust_type,extra_items,excluded_items,extra_defs,account_for_overlap,make_mo_comparisons)

No extra definitions have been added.
THe number of genome pair comparison 3
The number of comparisons made 754


In [None]:
import argparse
import pandas as pd
import itertools
import os, sys
from collections import Counter
from collections import defaultdict
import scipy as sp
import numpy as np
import scipy.stats
import re
#standard_enrichment_wf(KO_file,\
#        groupings_file,comparisons_file,threshold, database_dir, output_dir, bin_taxa, mult_adjust_type)

def parse_args():
    parser=argparse.ArgumentParser()
    subparsers = parser.add_subparsers(help='Please select one of either: enrichm or completem')
    parser.add_argument("-c","--cpus",help="**UNIMPLEMENTED** - Break each of the pair of comparisons onto a subprocess")
    parser.add_argument("-v","--verbose",help="**UNIMPLEMENTED** - Decided whether to have more descriptive output of current steps.")
    enrichm_parser=subparsers.add_parser('enrichm',description="A simple tool for looking for enrichment of KO hits to kegg modules\
    between groups of genomes.")
    enrichm_parser.add_argument('KO_file',help="A file KO counts with KO as row and source as column or genome\tKO1;KO2;.... pairs")
    enrichm_parser.add_argument('-gf','--groupings_file',help="A file containing the genomes that should be grouped with a specific name")
    enrichm_parser.add_argument('-cf','--comparisons_file',help="A file containing the comparisons between groups to be made")
    enrichm_parser.add_argument('threshold',help="The threshold to control for in false discovery rate of familywise error rate")
    enrichm_parser.add_argument('database_dir',help="The directory containing the local kegg databases.")
    enrichm_parser.add_argument('output_dir',help="The directory to write output information.")
    enrichm_parser.add_argument('-b','--bin_taxa',help="A file containing a taxonomic pairing with some genomes")
    enrichm_parser.add_argument('-m','--mult_test_correction',default='FDR',help='The form of mutiple test correction to use. There are 5\
    options: ')
    enrichm_parser.add_argument('--exclude',help='A list of kegg items to exclude from the analysis.')
    enrichm_parser.add_argument('--extra',help='A file of extra kegg_item: KO pairs defined by the user.')
    enrichm_parser.add_argument('--extra_defs',help='A file of extra kegg_item\tKEGG definition pairs defined by the user.')
    
    completem_parser=subparsers.add_parser('completem',description="A small library of tools for evaluting kegg booleans offline.")
    
    completem_parser.add_argument('KO_file',help="A file KO counts with KO as row and source as column or genome\tKO1;KO2;.... pairs")
    completem_parser.add_argument('database_dir',help="The directory containing the local kegg databases.")
    completem_parser.add_argument('output_dir',help="The directory to write output information.")
    completem_parser.add_argument('-b','--bin_taxa',help="A file containing a taxonomic pairing with some genomes")
    completem_parser.add_argument('--exclude',help='A list of kegg items to exclude from the analysis.')
    completem_parser.add_argument('--extra',help='A file of extra kegg_item: KO pairs defined by the user.')
    completem_parser.add_argument('--extra_defs',help='A file of extra kegg_item\tKEGG definition pairs defined by the user.')
    args=parser.parse_args()
    
    return args
    
def main():
    
    args=parse_args()
    
    if args.subparser_name=='enrichm':
        dfs=enrichm_wf(args)
    elif args.subparser_name=='completem':
        dfs=completem_wd(args)
        
    return dfs

if __name__ == '__main__':
    
    main()
    

# Kegg based hmm extractor

In [8]:
import os, sys, glob,subprocess,shutil
from multiprocessing.dummy import Pool as ThreadPool
from multiprocessing import Pool
import multiprocessing
import os,sys,re
import pandas as pd
import numpy as np
import scipy as sp
import glob as gb
from collections import defaultdict

def kegg_hmms_wf(hmm_dir, output_dir, group_title, opt_item_file, extras_file,database_dir,seq_dir,cpus,tax_file):
    '''This workflow relies on having a local directory with hmms from kegg or a separate directory
    with hmms of interest. It simply creates a directory, a hmm database, runs the hmm search and then reads the results
    to product two counts files. One for gene with a hit per genome/searched object and one 
    with hits per contig in each searched object. It is a tool to improve my speed of using kegg hmms.
    The modules file lets me extract all KO hmms relevant to a module for quicker searching. 
    
    Input:
    
    
    
    
    Output:
    
    
    
    
    '''
    if not isinstance(extras_file,type(None)):
        print "Processing the extras files"
        extras_file=load_extra_items(extras_file)
        
    if not isinstance(opt_item_file,type(None)):
        opt_item_file=load_optional_items(opt_item_file,extras_file, database_dir)

    new_directory=extract_local_hmms(hmm_dir,output_dir,group_title,opt_item_file)
    hmm_database=make_hmm_database(new_directory)
    
    table_dir=parallel_search_hmm_database(new_directory,hmm_database,seq_dir,cpus)
    
    gene_hits_dir=os.path.join(new_directory, "gene_hits")
    os.mkdir(gene_hits_dir)
    N_gene_hits,motif_hits_in_gene=hmm_hits_wf(table_dir, gene_hits_dir,tax_file, 1,"#",None)
    
    return N_gene_hits,motif_hits_in_gene

def load_optional_items(opt_item_file,extra_items,database_dir):
    '''Loads the local file of KOs and modules to extract hmms for.'''
    new_items=[]
    with open(opt_item_file) as optional_items:
        for line in optional_items:
            item=line.strip()
            new_items.append(item)
    module_ko_pairs=kegg_pairs_wrapper(["module",'orthology'],None, extra_items,database_dir)
    for item in new_items:
        if item not in module_ko_pairs:
            module_ko_pairs[item]=item
    old_len=len(new_items)
    new_items=[module_ko_pairs[item] for item in new_items]
    new_len=len(new_items)
    if new_len>old_len:
        print "Some modules were processed."
    else:
        print "No modules were processed."
    new_items=set(new_items)
    
    print "{0} kos will be retrieved.".format(len(new_items))
    
    return new_items

def try_new_filename(file_name):
    '''Checks if a path points to an existing object. If it does then a new file name is map
    with a different number appended to the end.'''
    if os.path.exists(file_name):
        path,ext=os.path.splitext(file_name)
        if re.match('[0-9]+$',path):
            path=path[:-1]+str(int(path[-1])+1)
            file_name=path+ext
            return try_new_filename(file_name)
        else:
            path=path+"0"
            file_name=path+ext
            return try_new_filename(file_name)
    else:
        return file_name

def extract_local_hmms(orig_hmm_dir, output_dir,group_title, opt_items):
    '''Copies the desired hmm's into the new output directory.'''
    new_dir_path=os.path.join(output_dir,group_title)
    os.mkdir(new_dir_path)
    hmm_dir=os.path.join(new_dir_path,"hmms")
    os.mkdir(hmm_dir)
    if not isinstance(opt_items,type(None)):
        for KO in opt_items:
            #print "The KO", KO
            ko_file=os.path.join(orig_hmm_dir,KO+".hmm")
            hmm_name=KO+".hmm"
            new_loc=os.path.join(hmm_dir,hmm_name)
            shutil.copyfile(ko_file,new_loc)
    else:
        for file_name in glob.glob(os.path.join(orig_hmm_dir,'*.hmm')):
            hmm_name=os.path.basename(file_name)
            new_loc=os.path.join(hmm_dir,hmm_name)
            shutil.copyfile(file_name,new_loc)
        
    return new_dir_path

def create_description_file(hmm_dir,database_dir):
    '''Creates a new KO:Readable name key from the hmms found in the directory.'''
    readable_KOs=load_readable_names_wrappers('orthology',database_dir)
    ko_list=[]
    for file_name in glob.glob(hmm_dir,'*.hmm'):
        hmm_name=os.path.basename(file_name)
        hmm,ext=os.path.splitext(hmm_name)
        ko_list.append(hmm)
    ko_list=set(ko_list)
    ko_tup=["\t".join((ko,readable_KOs[ko])) for ko in ko_list]
    ko_file="\n".join(ko_tup)
    with open(os.path.join(hmm_dir,"hmm_descriptions.tsv")) as desc_file:
        desc_file.write(ko_file)    
    return None

def make_hmm_database(working_dir):
    '''Creates new directory for the hmm database, a concatenated database file
    and auxilliary hmmer files needed to process the scan quickly.'''
    hmm_dir=os.path.join(working_dir,"hmms")
    hmm_database_dir=os.path.join(working_dir,"hmm_database")
    os.mkdir(hmm_database_dir)
    #concatenate all of the hmm files.
    key_name=os.path.split(working_dir.strip(os.sep))[-1]
    hmm_database=os.path.join(hmm_database_dir,key_name+"_hmm_database.hmm")
    with open(hmm_database,'wb+') as database:
        for hmm_file in glob.glob(os.path.join(hmm_dir,'*.hmm')):
            if hmm_file==hmm_database:
                continue
            with open(hmm_file, 'rb') as readfile:
                shutil.copyfileobj(readfile, database)
    #set up the local database properly - hmmer auxilliary db files
    #Making it work on windows
    if platform.system()=="Windows":
        #needs to call Cygwin for it work.
        subprocess.call(["hmmpress",hmm_database])
    else:
        subprocess.call(["hmmpress",hmm_database])

    return hmm_database


def parallel_search_hmm_database(output_working_dir,hmm_database,seq_dir,cpus):
    '''Performs a parallel search against the hmm database using hmmer.
    Input:
    
    Output:
    
    '''
    table_dir=os.path.join(output_working_dir,"results_table")
    os.mkdir(table_dir)
    def search_hmm_database(seq_file):
        '''Workhorse which calls hmmer.'''
        file_name=os.path.basename(seq_file)
        file_name,ext=os.path.splitext(file_name)
        domtbl=os.path.join(table_dir,file_name)
        subprocess.call(['hmmscan','--domtblout',domtbl,hmm_database,seq_file])
        return
    #Only at most as many processes as cpus
    N_cpus=min(cpus, multiprocessing.cpu_count())
    
    #Open a pool of workers for running hmmscan
    pool=ThreadPool(N_cpus)
    
    #All relevant files for scanning against the hmm database
    seq_file_iterator=glob.glob(os.path.join(seq_dir,"*.fna"))
    print "Beginning multithreaded use of hmmscan to make domtbls"
    results=pool.map(search_hmm_database,seq_file_iterator)
    print "The domtbls have finished being made."
    
    return table_dir

def process_domtblout_data(table_dir,taxonomy_file):
    output_dir=os.path.join(os.path.split(table_dir)[0],"hmms_results")
    os.mkdir(output_dir)
    #I still need to work out the default comment character and the header line number. THen,
    #I have enough to runs this.
    N_gene_hits,motif_hits_in_gene=hmm_hits_wf(table_dir, output_dir,taxonomy_file, 1,"#",None)
    
    return N_gene_hits,motif_hits_in_gene

def parse_args():
    '''Interface for making this hmm tool command line useable.'''

    parser=argparse.ArgumentParser()
    parser.add_argument("-c","--cpus",help="Cpus to use for hmm search",default=1)
    parser.add_argument("-v","--verbose",help="**UNIMPLEMENTED** - Decided whether to have more descriptive output of current steps.")
    parser.add_argument('hmm_dir',help="A directory of hmms to scan")
    parser.add_argument('seq_dir',help="A directory of sequence files")
    parser.add_argument('output_dir',help="THe directory to make a project dir in")
    parser.add_argument('group_title',help="The name for the project directory.")
    parser.add_argument('database_dir',help="The directory containing the local kegg databases.")
    parser.add_argument('-o','--opt_item_file',help="A optional file indicating which hmms to extract from the hmm_dir")
    parser.add_argument('-b','--bin_taxa',help="A file containing a taxonomic pairing with some genomes")
    parser.add_argument('-x','--extra_items',help="The extra items to consider when reading the optional file.")
    args=parser.parse_args()
    return args

def main():
    args=parse_args()
    
    hmm_processing_wf(args)
    
def hmm_processing_wf(args):
    cpus=args.cpus
    hmm_dir=args.hmm_dir
    seq_dir=args.seq_dir
    output_dir=args.output_dir
    group_title=args.group_title
    database_dir=args.database_dir
    opt_item_file=args.opt_item_file
    tax_file=args.bin_taxa
    extras_file=args.extra_items
    
    kegg_hmms_wf(hmm_dir, output_dir, group_title, opt_item_file, extras_file,database_dir,seq_dir,cpus,tax_file)
    return

def hmmer_domtblout_parser(file_path,header_line,comment_char):
    i=0
    header_one_space_sep=[2,20]
    line_one_space_sep=[2,-1]
    regex_clean=re.compile("\s{2,}")
    with open (file_path) as hmmer_hits:
        for line in hmmer_hits:
            if i==header_line:
                yield list(itertools.chain(*[fields.split(" ", 1) if i in header_one_space_sep else [fields] for i,fields in enumerate(regex_clean.split(line.strip())) ]))
            elif not line.startswith(comment_char):
                #print line[0:3]
                #30 was chosen since there are 23 columns and hence if 1 char + sep per column then 46 chars
                part_proc=regex_clean.split(line.strip())
                for i in line_one_space_sep:
                    part_proc[i]=part_proc[i].split(None,2)
                
                yield list(itertools.chain(*[ [item] if not isinstance(item,list) else item for item in part_proc]))
            else:
                pass
            i+=1
    
def create_hmmer_domtblout_df(file_path,header_line,comment_char):
    hmmer_file=hmmer_domtblout_parser(file_path,header_line,comment_char)
    header=next(hmmer_file)
    df=pd.DataFrame([
            line for line in hmmer_file 
        ])
    df.columns=header
    return df

def all_domtblout_df(file_dir,header_line,comment_char,optional_reg_cut):
    repeat_dfs={}
    i=0
    for file_path in gb.glob(os.path.join(file_dir, "*")):
        file_name=os.path.basename(file_path)
        #print file_name
        if isinstance(optional_reg_cut,type(None)):
            file_id=file_name.replace(".tsv","")
        else:
            file_id=file_name.replace(".tsv","")
            file_id=re.sub(optional_reg_cut,"",file_id)
        repeat_dfs[file_id]=create_hmmer_domtblout_df(file_path,header_line,comment_char)
        i+=1
        if i%10==0:
            print "{0} files have been processed. The last was {1}".format(i,file_id)
    if len(repeat_dfs)==0:
        print "No files were loaded"
    return repeat_dfs

def merge_repeat_dfs(df_dict):
    df_list=[None]*len(df_dict)
    i=0
    for genome_id,df in df_dict.iteritems():
        new_df=df
        new_df['Genome_id']=genome_id
        #new_df.rename(columns={0:'Gene_Name'},inplace=True)
        df_list[i]=new_df
        i+=1
    merged_dict=pd.concat(df_list,axis=0)
    merged_dict.index.names=["Gene_name"]
    cols = merged_dict.columns.tolist()
    cols=cols[-1:]+cols[0:-1]
    merged_dict=merged_dict[cols]
    #merged_dict['Genome_id']=merged_dict['Genome_id'].str.strip("_genomic").str.strip("aa_genes_unfiltered_").str.strip("aa_genes_filtered_")
    #print merged_dict
#    print merged_dict.ix[:,0]
    return merged_dict

def merge_repeat_dfs_wf(df_dict,output_file,taxonomy_file):
    gene_level_hits=merge_repeat_dfs(df_dict)
    gene_level_hits['Taxonomy']=gene_level_hits['Genome_id'].map(taxonomy_file)
    cols = gene_level_hits.columns.tolist()
    cols=[cols[0]]+cols[-1:]+cols[1:-1]
    gene_level_hits=gene_level_hits[cols]
    gene_level_hits.to_csv(output_file,sep="\t",index=True)
    
    return gene_level_hits
    
def construct_n_genes_hits(df_dict,all_columns):
    genome_ids=df_dict.keys()
    column_names=pd.unique(itertools.chain(*df_dict.itervalues()))
    gene_count_df=pd.DataFrame(index=genome_ids,columns=all_columns)
    gene_count_df.index.name="Genome_id"
    
    for file_id, df in df_dict.iteritems():
        for repeat_motif,count in df.iteritems():
            gene_count_df.set_value(file_id,repeat_motif,count)
            
    return gene_count_df

def process_gene_hits(complete_df_dict,output_file,all_motifs,taxonomy_file):
    gene_hits={file_id: make_hmm_hits_to_gene_hits(df) for file_id, df in complete_df_dict.iteritems()}
    gene_df=construct_n_genes_hits(gene_hits,all_motifs)
    gene_df.reset_index(level=0,inplace=True)
    #gene_df['Genome_id']=gene_df['Genome_id'].str.strip("_genomic").str.strip("aa_genes_unfiltered_").str.strip("aa_genes_filtered_")
    #print gene_df.ix[:,0]
    #print taxonomy_file
    gene_df['Taxonomy']=gene_df['Genome_id'].map(taxonomy_file)
    cols = gene_df.columns.tolist()
    cols=[cols[0]]+cols[-1:]+cols[1:-1]
    gene_df=gene_df[cols]
    #Add taxonomy information
    gene_df.to_csv(output_file, sep="\t",index=False)
    
    return gene_df

def hmm_hits_in_gene(df,all_columns):
    genes_hit=pd.unique(df['query name'][False==pd.isnull(df['query name'])])
    repeats_found=pd.unique(df['# target name'])
    hit_df=pd.DataFrame(index=genes_hit,columns=all_columns)
    hit_df=hit_df.fillna(0)
    all_counts=df.groupby(['query name','# target name']).count()['of']
    for ((contig,repeat_motif),count) in all_counts.iteritems():
        hit_df.set_value(contig,repeat_motif,count)
    print hit_df
        
    return hit_df

def make_hmm_hits_to_gene_hits(df):
    #Check for any columns with at least one hit, add all true values
    #to get number of genes with that motif.
    return (df>0).sum(axis=0)

def get_all_target_values(df_dict):
    target_values=[]
    for genome_id, df in df_dict.iteritems():
        target_values.append(df['# target name'])
        
    return pd.unique(pd.concat(target_values,axis=0))

def hmm_hits_wf(input_dir, output_dir,taxonomy_file, header_line,comment_char,optional_reg_cut):
    all_dfs=all_domtblout_df(input_dir,header_line,comment_char,optional_reg_cut)
    #print "These are all of the domtbl dfs", all_dfs
    all_target_ids=get_all_target_values(all_dfs)
    #print "These are all of the target ids,", all_target_ids
    motif_hits_in_gene={file_id: hmm_hits_in_gene(df,all_target_ids) for file_id,df in all_dfs.iteritems()}
    genome_taxonomy=load_bin_names(taxonomy_file)
    genome_taxonomy=defaultdict(lambda: "No_predefined_Taxonomy",genome_taxonomy)
    #print "This is the output_dir", output_dir
    N_gene_hits_file=os.path.join(output_dir,"N_genes_with_hits.tsv")
    N_gene_hits=process_gene_hits(motif_hits_in_gene,N_gene_hits_file,all_target_ids,genome_taxonomy)
    hits_per_gene_file=os.path.join(output_dir,"hmm_hits_per_gene_per_genome.tsv")
    merged_hmm_hits=merge_repeat_dfs_wf(motif_hits_in_gene,hits_per_gene_file,genome_taxonomy)
    
    return N_gene_hits,motif_hits_in_gene
    
def load_bin_names(tax_file):
    #Load bin_ids and bins_taxonomy from file.
    bin_names={}
    bin_pair=[]
    with open(tax_file,'r') as bin_tax_pair:
        bin_tax_pair.readline()
        for line in bin_tax_pair:
            bin_pair.append(tuple(line.strip().split("\t")))

    bin_names={bin_id:taxonomy for taxonomy, bin_id in bin_pair}
    return bin_names

def load_fasta_file():
    
    
def flatten_fasta_file():
    
    return

def extract_regions(contig, fasta, positions):
    
    return



IndentationError: expected an indented block (<ipython-input-8-24cd53c03709>, line 386)

In [12]:
#opt_item_file="C:\Users\Baker\Google Drive\Honours\HMM_searches\HmmToolTest\optional_items.txt"
opt_item_file=None
extras_file=None

hmm_dir=os.path.join(*[core,g_drive,"HMM_searches","Symbioses_test","Hmms"])#"C:\Users\Baker\Google Drive\Honours\HMM_searches\Symbioses_test\Hmms"
output_dir=os.path.join(*[core,g_drive,"HMM_searches"])#"C:\Users\Baker\Google Drive\Honours\HMM_searches"
if not isinstance(opt_item_file,type(None)):
    opt_item_file=load_optional_items(opt_item_file,extras_file, database_dir)
print opt_item_file
project_name="symbioses_test_3"
new_directory=extract_local_hmms(hmm_dir,output_dir,project_name,opt_item_file)
make_hmm_database(new_directory)
table_dir=parallel_search_hmm_database(output_working_dir,hmm_database,seq_dir,cpus)

gene_hits_dir=os.path.join(output_dir,"gene_hits")

hmm_hits_wf(table_dir, gene_hits_dir,tax_file, 1,"#",None)

None


WindowsError: [Error 2] The system cannot find the file specified

# Classification of all tips in decorated tree.


In [2]:
import dendropy

tree_dir=os.path.join(*[output_dir, "enriched_hits","Test_trees"])
tree_file=os.path.join(*[tree_dir,"gtdb.decorated.tree"])

tree=dendropy.Tree.get(path=tree_file,schema="newick")

def taxonomic_ranks():
    return ['d__','p__','c__','o__','f__','g__','s__'] #temporarily removed 'k__'

def taxonomic_rank(tax_id):
    
    return 
    
def is_higher_rank(r1,r2):
    return rank_relation(r1,r2,"higher")

def rank_relation(r1,r2,rel):
    tax_levels=taxonomic_ranks()
    #print r1
    #print r2
    r1_level=None
    r2_level=None
    for i, rank in enumerate(tax_levels):
        search_pat=rank+"[^;]"
        #print rank
        match_1=re.search(search_pat,r1)
        match_2=re.search(search_pat,r2)
        if match_1 and r1_level==None:
            r1_level=i
        if match_2 and r2_level==None:
            r2_level=i
    #print "levels:",r1_level,r2_level
    if rel=="higher":
        if r1_level<r2_level:
            return True
        else:
            return False
        
    elif rel=="lower":
        if r1_level>r2_level:
            return True
        else:
            return False        
    elif rel=="same":
        if r1_level==r2_level:
            return True
        else:
            return False
    else:
        print "Not a valid relation option."
        return
    
def is_lower_rank(r1,r2):
    return rank_relation(r1,r2,"lower")

def is_same_rank(r1,r2):
    return rank_relation(r1,r2,"same")
    
def classify_all_tips(tree):
    tax_string_dict={}
    #cur_taxon=[]
    cur_label=''
    end_core=''
    for node in tree.preorder_node_iter():
        cur_tax=node.taxon
        test_label=node.label
        if isinstance(cur_tax,type(None)) and isinstance(test_label,type(None)):
            pass
        elif not isinstance(cur_tax,type(None)):
            cur_tax=cur_tax.label
            tax_string_dict[cur_tax]=tax_string(cur_tax)
            if cur_label!='':
                #print cur_tax
                #print cur_label
                tax_string_dict[cur_tax].add(cur_label)
        elif not isinstance(test_label,type(None)):
            if is_higher_rank(cur_label,test_label):
                #print "Considering higher rank", test_label
                cur_label=replace_tax_def(cur_label,test_label,'higher')
                #cur_label.strip().strip(";").strip()
            elif is_same_rank(cur_label,test_label):
                #print "Considering same rank", test_label
                cur_label=replace_tax_def(cur_label,test_label,'same')
                #cur_label.strip().strip(";").strip()
            elif is_lower_rank(cur_label,test_label):
                #print "Considering lower rank", test_label
                cur_label=test_label
                #cur_label.strip().strip(";").strip()
            else:
                print "Ah, an error. What can I do?", cur_label, test_label
                
    return tax_string_dict

def replace_tax_def(full_tax_string, ending_core, rel):
    #print "The taxa being considered", full_tax_string
    tax_levels=taxonomic_ranks()
    tax_dict={rank:'' for rank in tax_levels}
    
    if full_tax_string=="":
        obs_level=[]
    else:
        obs_level=full_tax_string.strip().strip(";").split(';')
        #print obs_level
        obs_level=[(det_rank(tax),tax) for tax in obs_level]
    #obs_level=zip(obs_level[0::2],obs_level[1::2])
    tent_level=[el.strip() for el in ending_core.strip().split(';')]
    #print "tentative level:", tent_level
    tent_level=[(det_rank(tax),tax) for tax in tent_level]
    #tent_level=zip(tent_level[0::2],tent_level[1::2])
    if rel=="higher" or rel=="same":
        last_rank=None
        for (rank, name) in obs_level:
            #key=rank.strip(";").strip()+"__"
            tax_dict[rank]=name.strip().strip(rank)
        for (rank, name) in tent_level:
            #key=rank.strip(";").strip()+"__"
            tax_dict[rank]=name.strip().strip(rank)
            last_rank=rank
        for rank in lower_ranks(last_rank):
            tax_dict[rank]=''

    new_tax_string=[]
    for rank in tax_levels:
        new_tax_string.append(rank+tax_dict[rank])
    return ";".join(new_tax_string)

def lower_ranks(rank):
    ranks=taxonomic_ranks()
    if rank not in ranks or isinstance(rank,type(None)):
        return []
    level=ranks.index(rank)
    return ranks[(level+1):]

def det_rank(one_tax):
    ranks=taxonomic_ranks()
    #print "The is the tax which made it here,", one_tax, type(one_tax)
    cur_rank=[rank for rank in ranks if one_tax.startswith(rank)][0]
    return cur_rank

def fill_taxonomic_gaps(tax_file, completed_tax_strings):
    
    return

class tax_string(object):
    def __init__(self,taxon):
        self.ranks=['d__','p__','c__','o__','f__','g__','s__'] #temporarily removed 'k__'
        self.tax_string={rank:'' for rank in self.ranks}
        self.name=taxon
        self.tax=";".join(self.ranks)
        
    def add(self,gg_tax):
        if not isinstance(gg_tax,type(None)) or gg_tax!='':
            levels=[tax_rank.strip() for tax_rank in gg_tax.strip().split(";")]
            for level in levels:
                if level!='':
                    try:
                        cur_rank=self.det_rank(level)
                    except:
                        print "The broken level:", level, gg_tax
                        raise
                    if self.tax_string[cur_rank]!='':
                        print self.name, self.tax
                        print "The current item:", self.tax_string[cur_rank]
                        print "The new item:", gg_tax
                        raise TaxOverlapError('This rank has already been defined.')
                    else:
                        self.tax_string[cur_rank]=level

                        tax_str=[]
                        for rank in self.ranks:
                            tax_str.append(self.tax_string[rank])
                        self.tax=";".join(tax_str)
                else:
                    pass

        else:
            return

    def det_rank(self,one_tax):
        #print "The is the tax which made it here,", one_tax, type(one_tax)
        cur_rank=[rank for rank in self.ranks if one_tax.startswith(rank)][0]
        return cur_rank
    
    def __str__(self):
        
        return self.tax
    
    def ret_pair(self):
        return (self.name,self.tax)
        
class TaxOverlapError(Exception):
    '''Base class error - indicates a taxonomic rank was already defined before trying to add it'''
    pass

def write_genome_taxonomies(tax_string_dict,output_file):
    
    df=make_tax_df(tax_string_dict)
    df.to_csv(output_file,header=True,sep="\t",index=False)
    
    return df

def make_tax_df(tax_string_dict):
    df=pd.DataFrame([
        [name, tax_string.tax] for name,tax_string in tax_string_dict.iteritems()
    ])
    df.columns=["Genome_name","Tax_string"]
    return df
    
def fill_missing_taxonomies(ref_tax_file, tax_df):
    ranks=taxonomic_ranks()
    with open(ref_tax_file,'r') as tax_file:
        tax="\n".join(tax_file.readlines())
    split_df=split_taxonomy(tax_df)
    for rank in ranks:
        observed_ranks=pd.unique(split_df[rank])
        for obs_rank in observed_ranks:
            if obs_rank!=rank:
                search_reg="(?:\t).*?{0}".format(obs_rank)
                match=re.search(search_reg,tax)
                if match:
                    upper_tax=match.group()
                    levels=[el.strip() for el in upper_tax.split(";")]
                    interesting_levels=levels[:-1] #Exclude the level searched for.
                    for tax_match in interesting_levels:
                        match_rank=tax_match.split("__")[0]+"__"
                        split_df.ix[split_df[rank]==obs_rank,match_rank]=tax_match

    return split_df

def split_taxonomy(tax_df):
    col_names=taxonomic_ranks()
    new_df=tax_df.ix[:,1].str.split(";",expand=True)
    #print new_df
    new_df.columns=col_names
    new_df=pd.concat([new_tax_df['Genome_name'],m],axis=1)
    return new_df
    
def taxonomy_reclassify_wf(tree_file,output_core,reference_tax_file):
    ranks=taxonomic_ranks()
    ttree=dendropy.Tree.get(path=tree_file,schema="newick")
    all_taxa=classify_all_tips(tree)
    tax_df=write_genome_taxonomies(all_taxa, output_core+"unfilled_taxonomies.tsv")
    split_df=fill_missing_taxonomies(reference_tax_file, tax_df)
    joined_df=pd.concat([split_df.ix[:,0],split_df[ranks].apply(lambda x: ';'.join(x),axis=1)],axis=1)
    joined_df.columns=['Genome_name','Phylogeny_string']
    
    joined_df.to_csv(output_core+"backfilled_taxonomies.tsv",index=False,header=True, sep="\t")
    
    return joined_df



ImportError: No module named dendropy

In [241]:
taxonomy_reclassify_wf(tree_file,os.path.join(*[core,g_drive,"tools","EnrichM","Taxonomies","new_20160821_pluteadata_"]) ,os.path.join(*[core,g_drive,"tools","EnrichM","Taxonomies","20160821_gtdb_taxonomy_file.tsv"]))

Unnamed: 0,Genome_name,Phylogeny_string
0,U_35948,d__Bacteria;p__Limnochordaeota;c__SHA-98;o__;f...
1,RS_GCF_001476715.1,d__Bacteria;p__Proteobacteria;c__Gammaproteoba...
2,U_35949,d__Bacteria;p__Desulfovibrionaeota;c__Syntroph...
3,RS_GCF_000293245.1,d__Bacteria;p__Proteobacteria;c__Gammaproteoba...
4,U_49765,d__Bacteria;p__Hydrogenedentes;c__;o__;f__;g__...
5,U_49764,d__Bacteria;p__Firmicutes;c__Clostridia;o__Clo...
6,U_49767,d__Bacteria;p__Desulfovibrionaeota;c__Desulfob...
7,U_49761,d__Bacteria;p__MBNT15aeota;c__MBNT15ia;o__MBNT...
8,U_49760,d__Bacteria;p__Bacteroidetes;c__Bacteroidia;o_...
9,U_49763,d__Bacteria;p__Chloroflexi;c__Anaerolineae;o__...


# Graph analysis functions

In [9]:
common_cpds

{'C00001',
 'C00002',
 'C00003',
 'C00004',
 'C00005',
 'C00006',
 'C00007',
 'C00008',
 'C00009',
 'C00010',
 'C00011',
 'C00013',
 'C00014',
 'C00015',
 'C00020',
 'C00022',
 'C00023',
 'C00027',
 'C00034',
 'C00035',
 'C00038',
 'C00044',
 'C00050',
 'C00054',
 'C00055',
 'C00060',
 'C00063',
 'C00068',
 'C00070',
 'C00075',
 'C00076',
 'C00080',
 'C00081',
 'C00104',
 'C00105',
 'C00112',
 'C00144',
 'C01330',
 'C01342',
 'C03028',
 'C14818',
 'C14819',
 'C19610'}

In [40]:
def process_db_entry(db_entry,entry_type):
    if entry_type=="rn:rp":
        db_entry=db_entry.strip()
        values=db_entry.split()
        RP_ID=values[0]
        CPD_PAIRS=values[1]
        #processed_entry="\t".join([RP_ID,CPD_PAIRS])
        return RP_ID,CPD_PAIRS
    elif entry_type=="mo:rn":
        pass
    elif entry_type=="rn:rp:cpd":
        pass
    return processed_entry

def make_local_complete_reaction_rpair_db(database_dir):
    '''Creates a local database of the module definitions.'''
    all_modules=load_readable_names(database_dir,["reaction"],False)["reaction"].keys()
    kc=kegg.KeggClientRest()
    entries=defaultdict(dict)
    max_len=10
    print "There are a total of {0} reactions to parse".format(len(all_modules))
    N_modules=len(all_modules)
    for i in xrange(0,N_modules,max_len):
        if N_modules-i<max_len:
            n_entries=N_modules-i
        else:
            n_entries=max_len      
        query="+".join(all_modules[i:i+n_entries])
        kegg_entries=kc.get_entry(query)

        hits=re.finditer("(?:RPAIR\s)((.|\n)*?)(?:[A-Z/]{3,})",kegg_entries)
        hits=[hit.group(1).strip() for hit in hits] #Remove captured newline character and ensure that it is not a tuple
        print i, i+max_len-1,"n_hits:{0}".format(len(hits))
        #if len(hits)!=max_len:
        #    print m
        processed_hits=[]
        for module, db_entry in itertools.izip(all_modules[i:i+10],hits):
            #print "The database entry", db_entry
            for entry in db_entry.split("\n"):
                #print "An individual RP entry", entry
                rp_id,cpd_links=process_db_entry(entry,"rn:rp")
                processed_hits.append(rp_id)
                #definition.strip().replace(" --"," ").replace("-- "," ").replace("  "," ").strip()
                entries[module][rp_id]=cpd_links
        #These post_processed modules should be modules defined in terms of other modules.
        if i%1000==0:
            kc=kegg.KeggClientRest()
            #print module
            print hits
            print processed_hits
            #print kegg_entries
    temp_entries=entries
    print "{0} reaction rpair info pairs were recovered".format(len(entries))
                
    #df=pd.DataFrame([
    #        [module,entry] for module,entry in entries.iteritems()
    #    ])
    try:
    df=pd.DataFrame([
            [reaction, rp_id, cpds] for reaction, rpairs in rp_pair_data.iteritems() for rp_id, cpds in entries.iteritems()
        ])
        df.columns=["Reaction_id","reaction_definition",'cpd_pairs']
        df.to_csv(os.path.join(database_dir,"Module_rn_rp_full_reaction_pairs.tsv"),sep="\t",index=None)
    except:
        return entries
    return df

def make_local_complete_rpair_cpd_db(database_dir):
    
    
    
    return

def load_local_rn_rp_cpd_db(database_dir):
    infile=os.path.join(database_dir,"Module_rn_rp_cpd_reaction_pairs.tsv")
    old_df=pd.read_csv(infile, sep="\t",index_col=[0,1])
    return old_df

In [47]:
df=pd.DataFrame([
        [reaction, rp_id, cpds] for reaction, rpairs in rp_pair_data.iteritems() for rp_id, cpds in rpairs.iteritems()
    ])

In [50]:
df.columns=["Reaction_id","reaction_definition","compound_pair"]
df.to_csv(os.path.join(database_dir,"Module_rn_rp_full_reaction_pairs.tsv"),sep="\t",index=None)

In [62]:
infile=os.path.join(database_dir,"Module_rn_rp_cpd_reaction_pairs.tsv")
old_df=pd.read_csv(infile, sep="\t",index_col=[0,1])

In [78]:
for i in old_df.loc["R06208"].loc["C00029"].values:
    print i

C00015


In [61]:
red_df=df.iloc[:,[0,2]]
red_df["cpd1"],red_df['cpd2']=zip(*df["compound_pair"].str.split("_").tolist())
#print red_df
del red_df['compound_pair']
new_df=red_df.iloc[:,[0,2,1]]
new_df.columns=red_df.columns
print new_df.ix[new_df.iloc[:,0]=='R06208',]
#print red_df.ix[red_df.iloc[:,0]=='R06208',]
#whole_df=pd.concat([new_df,red_df])
whole_df.to_csv(os.path.join(database_dir,"Module_rn_rp_cpd_reaction_pairs.tsv"),sep="\t",index=None)

  Reaction_id    cpd1    cpd2
0      R06208  C00029  C00015
  Reaction_id    cpd1    cpd2
0      R06208  C00015  C00029


In [41]:
rp_pair_data=make_local_complete_reaction_rpair_db(database_dir)

There are a total of 10235 reactions to parse
0 9 n_hits:2
['RP00012  C00015_C00029 main [RC:RC00005]', 'RP00144  C00010_C00223 main [RC:RC00004]\n            RP12154  C16312_C16351 main [RC:RC00041]']
['RP00012', 'RP00144', 'RP12154']
10 19 n_hits:9
20 29 n_hits:10
30 39 n_hits:9
40 49 n_hits:10
50 59 n_hits:10
60 69 n_hits:9
70 79 n_hits:9
80 89 n_hits:10
90 99 n_hits:10
100 109 n_hits:10
110 119 n_hits:10
120 129 n_hits:10
130 139 n_hits:10
140 149 n_hits:10
150 159 n_hits:10
160 169 n_hits:10
170 179 n_hits:10
180 189 n_hits:9
190 199 n_hits:9
200 209 n_hits:10
210 219 n_hits:10
220 229 n_hits:10
230 239 n_hits:0
240 249 n_hits:9
250 259 n_hits:2
260 269 n_hits:10
270 279 n_hits:10
280 289 n_hits:10
290 299 n_hits:10
300 309 n_hits:9
310 319 n_hits:8
320 329 n_hits:7
330 339 n_hits:9
340 349 n_hits:8
350 359 n_hits:9
360 369 n_hits:3
370 379 n_hits:10
380 389 n_hits:10
390 399 n_hits:10
400 409 n_hits:10
410 419 n_hits:10
420 429 n_hits:9
430 439 n_hits:10
440 449 n_hits:8
450 459 

In [11]:
def generate_links(Genome_ko_hits,output_dir):
    KO_RN_dict = load_local_kegg_database_pairings(os.path.join(output_dir,"Databases"),[("orthology","reaction")], False)[("orthology","reaction")]
    all_linkable_KOs = set(itertools.chain(*Genome_ko_hits.itervalues())) & set(KO_RN_dict.keys())
    cpd_links = {KO:set(KO_RN_dict[KO]) for KO in all_linkable_KOs}
    RN_CPD_dict = load_local_kegg_database_pairings(os.path.join(output_dir,"Databases"),[("reaction","compound")], False)[("reaction","compound")]
    all_linkable_RNs=set(itertools.chain(*cpd_links.itervalues())) & set(RN_CPD_dict.keys())
    RN_list={RN:set(RN_CPD_dict[RN]) for RN in all_linkable_RNs}
    rn_cpd_gen_links={}
    ko_rn_gen_links={}
    for genome, kos in Genome_ko_hits.iteritems():
        cur_kos=set(kos) & all_linkable_KOs
        ko_rn_gen_links[genome]={ko:cpd_links[ko] for ko in cur_kos}
        cur_rns=set(itertools.chain(*ko_rn_gen_links[genome].itervalues())) & all_linkable_RNs
        rn_cpd_gen_links[genome]={rn: RN_list[rn] for rn in cur_rns}
    return rn_cpd_gen_links, ko_rn_gen_links

def store_rcn_graphs(Genome_ko_hits,common_cpds,reaction_dir,output_dir):
    rn_cpd_gen_links,ko_rn_gen_links=generate_links(Genome_ko_hits,output_dir)
    cpd_gen_graph={}
    for genome, rn_dict in rn_cpd_gen_links.iteritems():
        cpd_gen_graph[genome]=reaction_graph(rn_dict,common_cpds)
    for genome, rn_graph in cpd_gen_graph.iteritems():
        save_reaction_graph(rn_graph,os.path.join(reaction_dir,"{0}_graph_single_genome_reactions.xml".format(genome)))
        
    return cpd_gen_graph

def reaction_graph(CPD_rn_links,common_cpds):
    connected_edges={}
    for reac_1, reac_2 in itertools.combinations(CPD_rn_links.iterkeys(),2):
        cpds_1=CPD_rn_links[reac_1]
        cpds_2=CPD_rn_links[reac_2]
        link=(cpds_1 & cpds_2) - common_cpds
        if len(link)>0:
            connected_edges[(reac_1,reac_2)]=link

    rcn_graph=nx.Graph()
    for node,cpds in CPD_rn_links.iteritems():
        rcn_graph.add_node(node,cpds=cpds)
    
    rcn_graph.add_edges_from([edge+({"overlap":overlap,"weight":0.5},) for edge, overlap in connected_edges.iteritems()])
    rcn_graph=remove_self_loops(rcn_graph)
    return rcn_graph

def remove_self_loops(networkx_graph):
    for edge in networkx_graph.edges():
        if edge[0]==edge[1]: #self -loop
            networkx_graph.remove_edge(*edge)
    return networkx_graph

def remove_set_data_edges(reaction_graph,item_name):
    new_graph=reaction_graph.copy()
    new_items={edge:";".join(cpd_overlap) for edge, cpd_overlap in nx.get_edge_attributes(reaction_graph,item_name).iteritems()}
    nx.set_edge_attributes(new_graph,item_name,new_items)
    return new_graph

def remove_set_data_nodes(rcn_graph,item_name):
    new_graph=rcn_graph.copy()
    new_cpds={node:";".join(cpd_list) for node, cpd_list in nx.get_node_attributes(rcn_graph,item_name).iteritems()}
    nx.set_node_attributes(new_graph,item_name,new_cpds)
    return new_graph

def save_reaction_graph(rcn_graph,outfile):
    rcn_graph=remove_set_data_edges(rcn_graph,"overlap")
    rcn_graph=remove_set_data_nodes(rcn_graph,"cpds")
    nx.write_graphml(rcn_graph, outfile)
    return

#nx.write_graphml(rcn_graph, os.path.join(output_dir,'graphs{0}{1}_reaction_graph_single_genome.xml'.format(os.sep,"coral")))
def top_k_hits(list_items,k):
    from collections import Counter
    return sorted(Counter(list_items).iteritems(),key=lambda x: x[1], reverse=True)[:k]

def load_graphs(input_folder,glob_name):
    graph_dict={}
    for fname in glob(os.path.join(input_folder,glob_name)):
        base=os.path.basename(fname)
        genome_id=base.split('_graph_')[0] #Get Genome id
        graph_=nx.read_graphml(fname)
        string_to_set(graph_,"overlap","cpds")
        graph_.name=genome_id
        graph_dict[genome_id]=graph_
    return graph_dict

def string_to_set(rcn_graph,edge_item,node_item):
    new_edge_data={edge:set(cpds.split(";")) for edge, cpds in nx.get_edge_attributes(rcn_graph,edge_item).iteritems()}
    nx.set_edge_attributes(rcn_graph,edge_item,new_edge_data)
    new_node_data={node:set(cpds.split(";")) for node,cpds in nx.get_node_attributes(rcn_graph,node_item).iteritems()}
    nx.set_node_attributes(rcn_graph,node_item,new_node_data)  
    
    return None

def merge_graphs(rcn_graph1,rcn_graph2,commond_cpds):
    '''Creates a merged graph from a pair of graphs. This is done by finding the reactions
    unique to each genome and then scanning all of the reactions in the other genome to create new links.
    Input:
        rcn_graph1    - The first reaction graph.
        rcn_graph2    - The second reaction graph.
        common_cpds   - The set of uninteresting or hyperconnected compounds.
    Output:
        merged_graph  - A graph created from merging rcn_graph1 and rcn_graph2
    '''
    
    graph_1_node_compounds=nx.get_node_attributes(rcn_graph1,"cpds")
    graph_2_node_compounds=nx.get_node_attributes(rcn_graph2,"cpds")
    graph_1_nodes=set(graph_1_node_compounds.keys())
    graph_2_nodes=set(graph_2_node_compounds.keys())
    graph_1_unique_nodes=graph_1_nodes-graph_2_nodes
    graph_2_unique_nodes=graph_2_nodes-graph_1_nodes
    shared_nodes=graph_1_nodes & graph_2_nodes
    
    graph_1_uniq_links=connect_reactions(graph_2_node_compounds,graph_1_unique_nodes,graph_1_node_compounds,common_cpds)
    
    graph_2_uniq_links=connect_reactions(graph_1_node_compounds,graph_2_unique_nodes,graph_2_node_compounds,common_cpds)
    
    merged_graph=rcn_graph1.copy()
    merged_graph.name=rcn_graph1.name+":"+rcn_graph2.name
    merged_graph.add_nodes_from([(node,{"cpds":graph_2_node_compounds[node]})for node in graph_2_unique_nodes],genome=rcn_graph2.name)
    nx.set_node_attributes(merged_graph,"genome",{node:rcn_graph1.name+":"+rcn_graph2.name for node in shared_nodes})
    nx.set_node_attributes(merged_graph,"genome",{node:rcn_graph1.name for node in graph_1_unique_nodes})
    
    merged_graph.add_edges_from([edge+({"overlap":overlap,"weight":0.5},) for edge, overlap in graph_1_uniq_links.iteritems()])
    merged_graph.add_edges_from([edge+({"overlap":overlap,"weight":0.5},) for edge, overlap in graph_2_uniq_links.iteritems()])
    
    return merged_graph

def connect_reactions(node_cpd_pairs_to_join,new_nodes,new_node_cpd_pairs,common_cpds):
    '''Find the connections between reactions from two different sets of reactions or between the pairings
    of new nodes.
    
    Input:
    Output:
    '''
    new_rn_cpd_pairs={rn:cpds for rn,cpds in new_node_cpd_pairs.iteritems() if rn in new_nodes}
    connected_edges={}
    for reac_1, reac_2 in itertools.product(new_rn_cpd_pairs.iterkeys(),node_cpd_pairs_to_join.iterkeys()):
        cpds_1=new_rn_cpd_pairs[reac_1]
        cpds_2=node_cpd_pairs_to_join[reac_2]
        link=(cpds_1 & cpds_2) - common_cpds
        if len(link)>0:
            connected_edges[(reac_1,reac_2)]=link
            
    for reac_1, reac_2 in itertools.combinations(new_rn_cpd_pairs.iterkeys(),2):
        cpds_1=new_rn_cpd_pairs[reac_1]
        cpds_2=new_rn_cpd_pairs[reac_2]
        link=(cpds_1 & cpds_2) - common_cpds
        if len(link)>0:
            connected_edges[(reac_1,reac_2)]=link
            
    return connected_edges
    
def pairwise_merged_graphs(genome_reaction_graphs,to_analyse):
    if not to_analyse:
        pairwise_merged_genomes={}
    else:
        pass
    for genome_1,genome_2 in itertools.combinations(genome_reaction_graphs.iterkeys(),2):
        merged_genome=merge_graphs(genome_reaction_graphs[genome_1],genome_reaction_graphs[genome_2],common_cpds)
        if to_analyse:
            analyse_merged_graph(merged_genome)
        else:
            pairwise_merged_genomes[(genome_1,genome_2)]=merged_genome
    if not to_analyse:
        return pairwise_merged_genomes
    else:
        return None
    
def analyse_merged_graph(merged_graph):
    return None

def extract_starting_ending_reactions(reaction_graph,start_cpd,end_cpd,item_name):
    start_rns=set([])
    end_rns=set([])
    for reaction_data in reaction_graph.nodes_iter(data=True):
        cpds=reaction_data[1][item_name]
        start=cpds & set([start_cpd])
        end=cpds & set([end_cpd])
        if len(start)>0:
            start_rns.add(reaction_data[0])
        if len(end)>0:
            end_rns.add(reaction_data[0])
    return start_rns, end_rns

def extract_all_starting_ending_reacitons(reaction_graph, start_end_pairs,item_name):
    start_rns=defaultdict(set)
    end_rns=defaultdict(set)
    paired_reactions=[]
    for reaction_data in reaction_graph.nodes_iter(data=True):
        cpds=reaction_data[1][item_name]
        for start_cpd, end_cpd in start_end_pairs:
            start=cpds & set([start_cpd])
            end=cpds & set([end_cpd])
            if len(start)>0:
                start_rns[start_cpd].add(reaction_data[0])
            if len(end)>0:
                end_rns[end_cpd].add(reaction_data[0])
                
    for start_cpd,end_cpd in start_end_pairs:
        yield start_rns[start_cpd],end_rns[end_cpd]
        
def all_reaction_pairs(start_end_iter):
    for start_rns, end_rns in start_end_iter:
        for start_rn,end_rn in itertools.product(start_rns,end_rns):
            yield start_rn, end_rn
            
def _unique_compounds(merged_graph,genome_id):
    uniq_rns=set([])
    for node_data in merged_graph.nodes_iter(data=True):
        if node_data[1]['genome']==genome_id or genome_id not in node_data[1]['genome']:
            uniq_rns.add(node_data[0])
    return unique_rns

def _unique_compound_pairs(merged_graph,genome_id_pair):
    gen_1_uniq_rns=set([])

    gen_2_uniq_rns=set([])
    genome_id_1,genome_id_2=genome_id_pair
    for node_data in merged_graph.nodes_iter(data=True):
        if node_data[1]['genome']==genome_id_1:
            gen_1_uniq_rns.add(node_data[0])
        if node_data[1]['genome']==genome_id_2:
            gen_2_uniq_rns.add(node_data[0])
    return gen_1_uniq_rns,gen_2_uniq_rns

def automatic_difference_check(pairwise_merged_graph):
    latest_genome=pairwise_merged_graph.name.split(":")
    old_name=":".join(latest_genome[:-1])
    latest_genome=latest_genome[-1]
    gen_1_reactions, gen_2_reactions=_unique_compound_pairs(pairwise_merged_graph,(old_name,latest_genome))
    pairwise_paths={}
    for rcn1, rcn2 in itertools.product(gen_1_reactions,gen_2_reactions):
        pairwise_paths[rcn1,rcn2]=find_paths(pairwise_merged_graph,rcn1,rcn2)
    return pairwise_paths

def write_graph_dict(output_dir, graph_dict,pairs):
    if pairs:
        for genome_pair,genome in graph_dict.iteritems():
            save_reaction_graph(genome,os.path.join(output_dir,"{0}_graph_merged_genome_reactions.xml".format("|".join(genome_pair))))
    else:
        for genome, graph in graph_dict.iteritems():
            save_reaction_graph(graph, os.path.join("{0}_graph_single_genome_reactions.xml".format(genome)))
        
def store_local_kegg_item_keys(kegg_items,database_dir):
    kc=kegg.KeggClientRest()
    all_item_names={}
    illegal_pairs=[("compound","orthology"),("orthology","compound")]
    for kegg_item in kegg_items:
        key_name=os.path.join(database_dir,"{0}_readable_names.tsv".format(kegg_item))
        if not os.path.isfile(key_name):
            item_names=kc.get_ids_names(kegg_item)
            save_readable_key(key_name,item_names,kegg_item)
            all_item_names[kegg_item]=item_names.keys()
        else:
            all_item_names[kegg_item]=load_readable_names(database_dir,[kegg_item],False)[kegg_item].keys()
    for kegg_item_1, kegg_item_2 in itertools.permutations(all_item_names.iterkeys(),2):
        print "considering the pair: {0}, {1}".format(kegg_item_1,kegg_item_2)
        if (kegg_item_1,kegg_item_2) not in illegal_pairs:
            shared_key_name=os.path.join(database_dir,"{0}_linked_{1}_database.tsv").format(kegg_item_1, kegg_item_2)
            if not os.path.isfile(shared_key_name):
                print "The processing of pair: {0},{1} has begun.".format(kegg_item_1,kegg_item_2)
                kc=kegg.KeggClientRest()
                linked_ids=kc.link_ids(kegg_item_2,all_item_names[kegg_item_1])
                save_key_pairings(shared_key_name,linked_ids,(kegg_item_1,kegg_item_2))
        else:
            pass
    return

def load_local_kegg_database_pairings(database_dir,kegg_item_pairs, process_all):
    '''Loads the local databases of kegg_item_1, kegg_item_2 pairings and return a dictionary of
    these pairings in the form kegg_item_1:kegg_items_2 (There can be more than one linked item). This
    loading is based on the earlier use of mgkits kc.link_ids to store all of the pairings needed.
    
    Input:
        database_dir   - The directory with the databases
        kegg_item_pairs- A list of kegg item pairs to load
        process_all    - A boolean decision as whether to load all existing pairs.
        
    Output: A dictionary linking either all existing kegg item pairs or just those specified. It has the form
    dict[item_1,item_2]={kegg_item_1:kegg_2_items}'''
    linking_dictionary={}
    if process_all:
        for file_name in glob(os.path.join(database_dir,'*database.tsv')):
            db_file=os.basename(file_name)
            kegg_1=db_file.split("_linked_")[0]
            kegg_2=db_file.split("_linked_")[1].split("_database")[0]
            linking_dictionary[(kegg_1,kegg_2)]={}             
            with open(file_name) as kegg_links:
                next(kegg_links)#Skip the header
                for line in kegg_links:
                    item_1,item_2=line.strip().split("\t")
                    item_2=item_2.split(";")
                    linking_dictionary[kegg_item_pair][item_1]=item_2
        return linking_dictionary
    
    for kegg_item_pair in kegg_item_pairs:
        file_name=os.path.join(database_dir,"{0}_linked_{1}_database.tsv").format(kegg_item_pair[0], kegg_item_pair[1])
        if os.path.isfile(file_name):
            linking_dictionary[kegg_item_pair]={}
            with open(file_name) as kegg_links:
                next(kegg_links) #skip the header
                for line in kegg_links:
                    item_1,item_2=line.strip().split("\t")
                    item_2=item_2.split(";")
                    linking_dictionary[kegg_item_pair][item_1]=item_2
    return linking_dictionary
                    
def load_readable_names(database_dir,kegg_items,process_all):
    '''Loads in the readable names for a specified kegg item from a list of databases.
    Input:
        database_dir        -  The directory with the databases.
        kegg_items          -  The kegg items to get the readable mapping for.
        process_all         -  Boolean - Should the function retrieve all available databases.
    Output:
        readable_item_dict  -  A dictionary of KEGG_ID: Readable name pairs'''
    readable_item_dict={}
    if process_all:
        for file_name in glob(os.path.join(database_dir,'*_readable_names.tsv')):
            desc_file=os.basename(file_name)
            kegg_item=desv_file.split("_readable_names.tsv")[0]
            readable_item_dict[kegg_item]={}
            with open(file_name) as kegg_descriptions:
                next(kegg_descriptions)
                for line in kegg_descriptions:
                    item_1,item_2=line.strip().split("\t")
                    readable_item_dict[kegg_item][item_1]=item_2
        return readable_item_dict
    
    for kegg_item in kegg_items:
        file_name=os.path.join(database_dir,'{0}_readable_names.tsv'.format(kegg_item))
        readable_item_dict[kegg_item]={}
        with open(file_name) as kegg_descriptions:
            next(kegg_descriptions)
            for line in kegg_descriptions:
                item_1,item_2=line.strip().split("\t")
                readable_item_dict[kegg_item][item_1]=item_2
    return readable_item_dict
            
            
def save_readable_key(key_name,item_names,kegg_item):
    df=pd.DataFrame([
    [col1,col2] for col1,col2 in item_names.iteritems()
                   ])
    df.columns=[kegg_item,"Description"]
    df.to_csv(key_name,sep="\t",index=None)
    return None

def save_key_pairings(shared_key_name,item_links,kegg_item_tuple):
    df=pd.DataFrame([
            [col1,";".join(col2)] for col1, col2 in item_links.iteritems()
        ])
    df.columns=[kegg_item_tuple[0],kegg_item_tuple[1]]
    df.to_csv(shared_key_name,sep="\t",index=None)
    return None

def remove_ko_pth_hits(file_path):
    
    with open(file_path,'r') as KO_PTH_pairs:
        out_dir=os.path.dirname(file_path)
        temp_file=open(os.path.join(out_dir,"temp.tsv"),'w')
        for line in KO_PTH_pairs:
            KO,pathways=line.strip().split("\t")
            pathways=pathways.split(";")
            pathways=[pathway for pathway in pathways if not pathway.startswith("ko")]
            pathways=";".join(pathways)
            new_line="{0}\t{1}\n".format(KO,pathways)
            temp_file.write(new_line)
    temp_file.close()
    
#FROM NETWORKX DOCUMENTATION
def k_shortest_paths(graph,start,end,k,weight=None):
    '''Returns the k shorest paths going from the start node to end node base on the weight values.'''
    return list(itertools.islice(nx.shortest_simple_paths(graph,start,end,weight=weight),k))

def extract_successor_neighbours(G,node, rcn_eqn_data,eqn_pair,cpd=None):
    if isinstance(cpd,type(None)):
        neighbours=G.neighbors(node)
        for neighbour in neighbours:
            yield neighbour
    else:
        out_compounds=set([])
        current_node=rcn_eqn_data[node]
        for side_1,side_2 in eqn_pair.iteritems():
            init_cpds=cpd & current_node[side_1]
            if len(init_cpds)>0:
                out_compounds.update(current_node[side_2])                
        for edge in G.edges_iter(node,data=True):
            new_connection=edge[2]["overlap"] & out_compounds
            if len(new_connection)>0:
                #print edge
                yield edge[1],edge[2]["overlap"]
            
def extract_predecessor_neighbours(G,node,rcn_eqn_data,eqn_pair,cpd=None):
    if isinstance(cpd,type(None)):
        neighbours=G.neighbors(node)
        for neighbour in neighbours:
            yield neighbour
    else:
        out_compounds=set([])
        current_node=rcn_eqn_data[node]
        for side_1,side_2 in eqn_pair.iteritems():
            init_cpds=cpd & current_node[side_1]
            if len(init_cpds)>0:
                out_compounds.update(current_node[side_1])            
        for edge in G.edges_iter(node,data=True):
            new_connection=edge[2]["overlap"] & out_compounds
            if len(new_connection)>0:
                #print edge
                yield edge[1],edge[2]["overlap"]

def predecessor_neighbours(G,node,rcn_eqn_data,eqn_pair,cpd):
    return list(extract_predecessor_neighbours(G,node,rcn_eqn_data,eqn_pair,cpd))

def successor_neighbours(G,node,rcn_eqn_data,eqn_pair,cpd):
    return list(extract_successor_neighbours(G,node,rcn_eqn_data,eqn_pair,cpd))
        
def get_eqn_pair():
    return {'side_2_cpds':'side_1_cpds','side_1_cpds':'side_2_cpds'}

def store_unique_cpd_rn_node_graphs(Genome_ko_hits,common_cpds,reaction_dir,output_dir,database_dir,double_links):
    rn_cpd_gen_links,ko_rn_gen_links=generate_links(Genome_ko_hits,output_dir)
    rcn_eqn_links=load_local_rcn_eqn_database_set(database_dir)
    
    side_pairs={'side_2_cpds':'side_1_cpds','side_1_cpds':'side_2_cpds'}
    
    cpd_gen_graph={}
    
    against_side_spec_rn_links=make_side_specific_cpd_links(rn_cpd_gen_links,rcn_eqn_links,side_pairs,common_cpds)
    print "The graphs are being made."
    
    start_time=datetime.datetime.now()
    for genome, rn_dict in against_side_spec_rn_links.iteritems():
        current_time=datetime.datetime.now()
        print "{1} has started being processed at: {0}".format(current_time,genome)
        print "Duration since start: {0}".format(current_time-start_time)
        file_name=os.path.join(reaction_dir,"{0}_graph_side_specific_single_genome_reactions.xml".format(genome))
        if not os.path.isfile(file_name):
            current_graph=create_unique_reaction_graph(rn_dict,common_cpds,double_links)
            save_reaction_graph(current_graph,file_name)
            print "Duration of making graph {0}".format(datetime.datetime.now()-current_time)
        
    return

def make_side_specific_cpd_links(rn_cpd_links,rcn_eqn_links,opposing_sides,common_cpds):
    all_viable_rns=set(rcn_eqn_links.keys())
    side_spec_gen_rns={}

    for genome, rn_cpds in rn_cpd_links.iteritems():
        side_spec_gen_rns[genome]={}
        linkable_RNs=all_viable_rns & set(rn_cpds.keys())
        for RN in linkable_RNs:
            for side, cpds in rcn_eqn_links[RN].iteritems():
                cleaned_cpds=cpds-common_cpds
                if len(cleaned_cpds)>0:
                    for cpd in cleaned_cpds:
                        side_spec_gen_rns[genome][";".join([RN,cpd])]=rcn_eqn_links[RN][opposing_sides[side]]
    return side_spec_gen_rns


def create_unique_reaction_graph(against_rn_links,common_cpds,double_links):
    connected_edges={}
    N_double_links=0
    for reac_1, reac_2 in itertools.combinations(against_rn_links.iterkeys(),2):
        
        both=0
        node_name_1=reac_1.split(";")
        node_name_2=reac_2.split(";")
        
        if node_name_1[0]!=node_name_2[0]:
            cpds_1=against_rn_links[reac_1]
            cpd_2=set([node_name_2[1]])
            link=(cpds_1 & cpd_2)
            if len(link)>0:
                connected_edges[(reac_1,reac_2)]=link
                both+=1

            cpds_1=against_rn_links[reac_2]
            cpd_2=set([node_name_1[1]])
            link=(cpds_1 & cpd_2)-common_cpds
            if len(link)>0:
                connected_edges[(reac_2,reac_1)]=link
                both+=1

            if both==2:
                if not double_links:
                    connected_edges.pop((reac_1,reac_2))
                    connected_edges.pop((reac_2,reac_1))
                else:
                    connected_edges[(reac_1,reac_2)]=connected_edges[(reac_1,reac_2)]|connected_edges.pop((reac_2,reac_1))
                N_double_links+=1
                
    print "There were {0} doubly linked reactions in this set of comparisons.".format(N_double_links)
    
    rcn_graph=nx.DiGraph()
    for node,cpds in against_rn_links.iteritems():
        rcn_graph.add_node(node,cpds=cpds)
    
    rcn_graph.add_edges_from([edge+({"overlap":overlap,"weight":0.5},) for edge, overlap in connected_edges.iteritems()])
    rcn_graph=remove_self_loops(rcn_graph)
    return rcn_graph


def merge_directed_input_specific_rn_graphs(rcn_graph1, rcn_graph2, common_cpds):
    '''Creates a merged graph from a pair of graphs. This is done by finding the reactions
    unique to each genome and then scanning all of the reactions in the other genome to create new links.
    Input:
        rcn_graph1    - The first reaction graph.
        rcn_graph2    - The second reaction graph.
        common_cpds   - The set of uninteresting or hyperconnected compounds.
    Output:
        merged_graph  - A graph created from merging rcn_graph1 and rcn_graph2
    '''
    
    
    graph_1_node_compounds=nx.get_node_attributes(rcn_graph1,"cpds")
    graph_2_node_compounds=nx.get_node_attributes(rcn_graph2,"cpds")
    graph_1_nodes=set(graph_1_node_compounds.keys())
    graph_2_nodes=set(graph_2_node_compounds.keys())
    graph_1_unique_nodes=graph_1_nodes-graph_2_nodes
    graph_2_unique_nodes=graph_2_nodes-graph_1_nodes
    shared_nodes=graph_1_nodes & graph_2_nodes
    
    graph_1_uniq_links=connect_reactions(graph_2_node_compounds,graph_1_unique_nodes,graph_1_node_compounds,common_cpds)
    
    graph_2_uniq_links=connect_reactions(graph_1_node_compounds,graph_2_unique_nodes,graph_2_node_compounds,common_cpds)
    
    merged_graph=rcn_graph1.copy()
    merged_graph.name=rcn_graph1.name+":"+rcn_graph2.name
    merged_graph.add_nodes_from([(node,{"cpds":graph_2_node_compounds[node]})for node in graph_2_unique_nodes],genome=rcn_graph2.name)
    nx.set_node_attributes(merged_graph,"genome",{node:rcn_graph1.name+":"+rcn_graph2.name for node in shared_nodes})
    nx.set_node_attributes(merged_graph,"genome",{node:rcn_graph1.name for node in graph_1_unique_nodes})
    
    merged_graph.add_edges_from([edge+({"overlap":overlap,"weight":0.5},) for edge, overlap in graph_1_uniq_links.iteritems()])
    merged_graph.add_edges_from([edge+({"overlap":overlap,"weight":0.5},) for edge, overlap in graph_2_uniq_links.iteritems()])
    
    return merged_graph

def connect_directed_reactions(graph_1_node_compounds,unique_nodes,graph_2_node_compounds ,common_cpds):
    
    return

def unique_cpds_graph_pair(graph_1,graph_2):
    
    all_cpds_g1=set(itertools.chain(*nx.get_node_attributes(graph_1,"cpds").itervalues()))
    all_cpds_g2=set(itertools.chain(*nx.get_node_attributes(graph_2,"cpds").itervalues()))
    uniq_g1=all_cpds_g1- all_cpds_g2
    uniq_g2=all_cpds_g2- all_cpds_g1
    
    unique_cpds={graph_1.name:uniq_g1,graph_2.name:uniq_g2}
    
    return unique_cpds

def pairwise_unique_cpds(genome_graph_dict):
    unique_cpds={}
    for genome_1, genome_2 in itertools.combinations(genome_graph_dict.iterkeys(),2):
        unique_cpds[(genome_1,genome_2)]=unique_cpds_graph_pair(genome_graph_dict[genome_1],genome_graph_dict[genome_2])
    
    return unique_cpds

def write_uniq_compounds(output_file,uniq_cpds):
    return

def get_starting_nodes(rn_graph,cpd):
    '''Extract all nodes matching the specific compound in the graph. This extracts that compound as the starting
    compound for the reaction.
    '''
    node_names=np.array(rn_graph.nodes())
    matching_nodes=[]
    for element in node_names:
        if element.endswith(cpd):
            matching_nodes.append(element)
    return matching_nodes

def get_ending_nodes(rn_graph,cpd):
    '''Extract all nodes where the cpd is an end product.'''
    cpd_data=nx.get_node_attributes(rn_graph,"cpds")
    matching_nodes=[]
    for node, cpds in cpd_data.iteritems():
        if cpd in cpds:
            matching_nodes.append(node)
    return  matching_nodes

def extract_node_pairs(rn_graph,starting_cpds,ending_cpds):
    start_cpd_nodes={}
    end_cpd_nodes={}
    failures={}
    failures["starts"]=[]
    failures["ends"]=[]
    for cpd_1 in starting_cpds:
        cpd_1_nodes=get_starting_nodes(rn_graph,cpd_1)
        if len(cpd_1_nodes)>0:
            start_cpd_nodes[cpd_1]=cpd_1_nodes
        else:
            failures["starts"].append(cpd_1)
    for cpd_2 in ending_cpds:
        cpd_2_nodes=get_ending_nodes(rn_graph,cpd_2)
        if len(cpd_2_nodes)>0:
            end_cpd_nodes[cpd_2]=cpd_2_nodes
        else:
            failures["ends"].append(cpd_2)
    return start_cpd_nodes,end_cpd_nodes,failures

#FROM NETWORKX DOCUMENTATION
def k_shortest_paths(graph,start,end,k,weight=None):
    return list(itertools.islice(nx.shortest_simple_paths(graph,start,end,weight=weight),k))
##############################
def extract_pathways(graph, cpd_pairs,output_file):
    start_cpds,end_cpds=zip(*cpd_pairs)
    start_cpd_nodes,end_cpd_nodes, failures=extract_node_pairs(graph, start_cpds,end_cpds)
    path_data={}
    for (cpd_1,cpd_2) in cpd_pairs:
        path_data[(cpd_1,cpd_2)]={}
        start_nodes=start_cpd_nodes[cpd_1]
        end_nodes=end_cpd_nodes[cpd_2]
        for node_1,node_2 in itertools.product(start_nodes,end_nodes):
            path_data[(cpd_1,cpd_2)][(node_1,node_2)]=k_shortest_paths(graph,node_1,node_2,k)
    return path_data

def make_pd_dataframe(genome_path_data):
    pd.df([
            []
        ])
    return

def save_path_data(path_data,outfile):
        
    return

def networkx_k_shortest_paths(graph,cpds1,cpds2,k):
    if len(cpds1)!=len(cpds2):
        print "The cpds must came as start, end pairs. There must be the same number of starting nodes as finishing nodes"
        return None
    cpd_pair_k_shortest_paths={}
    graph_cpds=graph.nodes()
    for (start,end) in itertools.izip(cpds1,cpds2):
        if start in graph_cpds and end in graph_cpds:
            cpd_pair_k_shortest_paths[(start,end)]=k_shortest_paths(graph,start,end,k)
        else:
            cpd_pair_k_shortest_paths[(start,end)]=[[]]
    return cpd_pair_k_shortest_paths

def shortest_path_to_pd_dataframe(shortest_paths,database_dir):
    cpd_names =load_readable_names(database_dir,["compound"],False)["compound"]
    rn_names  = load_readable_names(database_dir,["reaction"],False)["reaction"]
    
    df=pd.DataFrame([
            [cpd1, cpd2, cpd_names[cpd1],cpd_names[cpd2],"||".join(clean_path(path,cpd_names,rn_names,readable=False)), "||".join(clean_path(path,cpd_names,rn_names,readable=True))] for (cpd1, cpd2), path in shortest_paths.iteritems() 
        ])
    df.columns=["Start_CPD_ID","End_CPD_ID",'Start_CPD',"End_compound","CPD_PATH","READABLE_CPD_PATH"]
    return df

def clean_path(path,readable=False,cpd_names=None,rn_names=None):
    #cpd_names =load_readable_names(database_dir,["compound"],False)["compound"]
    #rn_names  = load_readable_names(database_dir,["reaction"],False)["reaction"]
    cleaned_path=[]
    for node in path:
        rn,cpd=node.split(";")
        if readable:
            rn=rn_names[rn]
            cpd=cpd_names[cpd]
        cleaned_node="{0}({1})".format(rn,cpd)
        cleaned_path.append(cleaned_node)
    return cleaned_path

def save_output(pd_path_dataframe,output_file):
    pd_path_dataframe.to_csv(output_file,sep="\t",header=True,index=False)
    return None

def shortest_paths_pd_dataframe(k_shortest_paths):
    kc = kegg.KeggClientRest()
    cpd_names = kc.get_ids_names('compound')
    df=pd.DataFrame([
            [cpd1, cpd2, cpd_names[cpd1],cpd_names[cpd2],"||".join(path), "||".join([cpd_names[cpd] for cpd in path])] for (cpd1, cpd2), paths in k_shortest_paths.iteritems() for path in paths 
        ])
    df.columns=["KEGGCpd1ID","KEGGCpd2ID",'CPD1_readable',"CPD2_readable","CPD_PATH","READABLE_CPD_PATH"]
    return df

def is_biological(path,rcn_eqn_pairs):
    '''Does a simple screen for whether a path makes biological sense.
    One of this key criteria is that if a compound was used as an auxilliary input
    then it should not later be a key ongoing compound. Using the same compound to 
    produce itself later is inefficient and doesn't make sense.
    Input:
        path    -    list
            List of nodes representing a path through the graph.
        rcn_eqn_pairs    -    dict
            A dictionary of reactions paired with there equation information, i.e
            the compounds which are products and reactants separately based on 
            context.
    Output:
        isbiological    -    Boolean
            Does the path make biological sense.'''
    
    in_cpds=[]
    extra_in_cpds=[]
    for node in path:
        rn,cpd=node.split(";")
        in_cpds.append(cpd)
        extras=rcn_eqn_pairs[rn]
        in_side=which_side_rcn(cpd,extras)
        extra_in_cpds.extend(extras[in_side]-set([cpd]))
    in_cpds=set(in_cpds)
    extra_in_cpds=set(extra_in_cpds)
    shared_cpds= in_cpds & extra_in_cpds
    if len(shared_cpds)>0:
        return False
    else:
        return True

def which_side_rcn(cpd, eqn_rcn_pair):
    '''Decides which side of the reaction you are on based on the cpd.'''
    sides=""
    for side,cpds in eqn_rcn_pair.iteritems():
        if cpd in cpds:
            sides=side
    return sides

def k_shortest_biological_paths(graph,start,end,k,rcn_eqn_pairs,weight=None):
    '''Ignore paths which don't make at least a little biological sense.'''
    biological_paths=[]
    N_biol_paths=0
    for path in nx.shortest_simple_paths(graph,start,end,weight=weight):
        if is_biological(path,rcn_eqn_pairs):
            biological_paths.append(path)
            N_biol_paths+=1
            if N_biol_paths>=k:
                return biological_paths
    return biological_paths


def network_graph_prep_wd():

    
    return

def generate_biological_links(Genome_ko_hits,output_dir):
    KO_RN_dict = load_local_kegg_database_pairings(os.path.join(output_dir,"Databases"),[("orthology","reaction")], False)[("orthology","reaction")]
    all_linkable_KOs = set(itertools.chain(*Genome_ko_hits.itervalues())) & set(KO_RN_dict.keys())
    cpd_links = {KO:set(KO_RN_dict[KO]) for KO in all_linkable_KOs}
    return

#def network_analysis_wf(,):
    
#    return

In [12]:
rcn_eqn_pairs=load_local_rcn_eqn_database_set(database_dir)

G=reaction_graphs["U_52098"]


NameError: name 'reaction_graphs' is not defined

In [8]:
G.nodes()[0:20]

NameError: name 'G' is not defined

In [59]:
m=k_shortest_paths(G,'R00475;C00160','R04199;C03972',10,weight=None)

NetworkXNoPath: No path between R00475;C00160 and R04199;C03972.

In [22]:
timeit.repeat("nx.has_path(G,'R00475;C00160','R04199;C03972')",'from __main__ import nx, G',number=1000)

[0.24488496780395508, 0.2435619831085205, 0.17296099662780762]

In [23]:
m

[['R00475;C00160',
  'R00372;C00048',
  'R00258;C00026',
  'R00259;C00025',
  'R04364;C00010',
  'R04199;C03972'],
 ['R00475;C00160',
  'R00372;C00048',
  'R00093;C00026',
  'R00259;C00025',
  'R04364;C00010',
  'R04199;C03972'],
 ['R00475;C00160',
  'R00372;C00048',
  'R04475;C00026',
  'R00259;C00025',
  'R04364;C00010',
  'R04199;C03972'],
 ['R00475;C00160',
  'R00372;C00048',
  'R01214;C00026',
  'R01213;C00141',
  'R04364;C00010',
  'R04199;C03972'],
 ['R00475;C00160',
  'R00372;C00048',
  'R00248;C00026',
  'R00259;C00025',
  'R04364;C00010',
  'R04199;C03972'],
 ['R00475;C00160',
  'R00372;C00048',
  'R01214;C00026',
  'R00259;C00025',
  'R04364;C00010',
  'R04199;C03972'],
 ['R00475;C00160',
  'R00372;C00048',
  'R05085;C00026',
  'R00259;C00025',
  'R04364;C00010',
  'R04199;C03972'],
 ['R00475;C00160',
  'R00372;C00048',
  'R02199;C00026',
  'R00259;C00025',
  'R04364;C00010',
  'R04199;C03972'],
 ['R00475;C00160',
  'R00372;C00048',
  'R03243;C00026',
  'R00259;C00025',
  'R

In [40]:
timeit.repeat("k_shortest_biological_paths(G,'R00475;C00160','R04199;C03972',10,rcn_eqn_pairs)",'from __main__ import G, rcn_eqn_pairs,k_shortest_biological_paths',number=1)

[0.14593982696533203, 0.10733604431152344, 0.11613011360168457]

In [56]:
l=k_shortest_biological_paths(G,'R00475;C00160','R04199;C03972',10,rcn_eqn_pairs)

In [57]:
l

[['R00475;C00160',
  'R00372;C00048',
  'R01214;C00026',
  'R01213;C00141',
  'R04364;C00010',
  'R04199;C03972'],
 ['R00475;C00160',
  'R00369;C00048',
  'R00945;C00037',
  'R02287;C00101',
  'R00259;C00025',
  'R04364;C00010',
  'R04199;C03972'],
 ['R00475;C00160',
  'R00588;C00048',
  'R00945;C00037',
  'R02287;C00101',
  'R00259;C00025',
  'R04364;C00010',
  'R04199;C03972'],
 ['R00475;C00160',
  'R00369;C00048',
  'R00945;C00037',
  'R03189;C00101',
  'R00259;C00025',
  'R04364;C00010',
  'R04199;C03972'],
 ['R00475;C00160',
  'R00588;C00048',
  'R00945;C00037',
  'R03189;C00101',
  'R00259;C00025',
  'R04364;C00010',
  'R04199;C03972'],
 ['R00475;C00160',
  'R00369;C00048',
  'R08701;C00037',
  'R02287;C00101',
  'R00259;C00025',
  'R04364;C00010',
  'R04199;C03972'],
 ['R00475;C00160',
  'R00372;C00048',
  'R09099;C00037',
  'R00585;C00065',
  'R03210;C00041',
  'R04364;C00010',
  'R04199;C03972'],
 ['R00475;C00160',
  'R00588;C00048',
  'R08701;C00037',
  'R02287;C00101',
  'R0

In [16]:
timeit.repeat("k_shortest_paths(G,'R00475;C00160','R09773;C17556',10,weight=None)",'from __main__ import G,k_shortest_paths',number=100)

[2.270888090133667, 2.2284109592437744, 2.1088790893554688]

In [None]:
uniq_cpds=pairwise_unique_cpds(reaction_graphs)

In [53]:
store_unique_cpd_rn_node_graphs(KO_genome_hits,common_cpds,os.path.join(output_dir,"graphs/Reaction_links","Side_spec"),output_dir,database_dir,double_links=True)

The graphs are being made.
U_52270 has started being processed at: 2016-07-28 13:55:56.648621
Duration since start: 0:00:00.000074
There were 1644 doubly linked reactions in this set of comparisons.
Duration of making graph 0:00:25.297256
U_52098 has started being processed at: 2016-07-28 13:56:21.947452
Duration since start: 0:00:25.298905
There were 1209 doubly linked reactions in this set of comparisons.
Duration of making graph 0:00:19.205901
SymbC15 has started being processed at: 2016-07-28 13:56:41.153752
Duration since start: 0:00:44.505205
There were 14138 doubly linked reactions in this set of comparisons.
Duration of making graph 0:02:22.333221
U_52448 has started being processed at: 2016-07-28 13:59:03.487425
Duration since start: 0:03:06.838878
There were 1619 doubly linked reactions in this set of comparisons.
Duration of making graph 0:00:22.513843
U_52425 has started being processed at: 2016-07-28 13:59:26.001678
Duration since start: 0:03:29.353131
There were 1954 doub

In [54]:
store_unique_cpd_rn_node_graphs(KO_genome_hits,common_cpds,os.path.join(output_dir,"graphs/Reaction_links","Side_spec_no_double_links"),output_dir,database_dir,double_links=False)

The graphs are being made.
U_52270 has started being processed at: 2016-07-28 14:16:08.833647
Duration since start: 0:00:00.000046
There were 1644 doubly linked reactions in this set of comparisons.
Duration of making graph 0:00:19.718796
U_52098 has started being processed at: 2016-07-28 14:16:28.553064
Duration since start: 0:00:19.719463
There were 1209 doubly linked reactions in this set of comparisons.
Duration of making graph 0:00:18.304732
SymbC15 has started being processed at: 2016-07-28 14:16:46.858502
Duration since start: 0:00:38.024901
There were 14138 doubly linked reactions in this set of comparisons.
Duration of making graph 0:01:58.074109
U_52448 has started being processed at: 2016-07-28 14:18:44.933713
Duration since start: 0:02:36.100112
There were 1619 doubly linked reactions in this set of comparisons.
Duration of making graph 0:00:20.929159
U_52425 has started being processed at: 2016-07-28 14:19:05.863472
Duration since start: 0:02:57.029871
There were 1954 doub

In [55]:
reaction_graphs=load_graphs(os.path.join(output_dir,"graphs","Reaction_links","Side_spec_no_double_links"),'*.xml')

# Recreating Heatmaps

In [22]:
def plot_heatmap_proportion(path_id, path_name, KO_genome_files, cmap,output_dir,database_dir,bin_names):

    PTH_MO_pairs = load_local_kegg_database_pairings(database_dir,[("pathway","module")], False)[("pathway","module")]
    linkable_base_paths=set(PTH_MO_pairs.iterkeys()) & set(path_id)
    base_mod= set(itertools.chain(*[PTH_MO_pairs[pathway] for pathway in linkable_base_paths]))
    
    MO_KO_pairs  = load_local_kegg_database_pairings(database_dir,[("module","orthology")], False)[("module","orthology")]
    all_modules = {MO:MO_KO_pairs[MO] for MO in base_mod} #kc.link_ids('ko', base_mod)
    module_totals = pd.Series(
        {
            mod_id: len(ko_ids)
            for mod_id, ko_ids in all_modules.iteritems()
        }
    )
    rev_modules = dictionary.reverse_mapping(all_modules)
    
    module_names=load_readable_names(database_dir,["module"],False)["module"]
    
    mod_prop = {}
    for genome_id, KOs in KO_genome_files.iteritems():
        mod_prop[genome_id] = {}
        genome_ko_ids = KOs
        for mod_id, ko_ids in all_modules.iteritems():
            mod_prop[genome_id][mod_id] = len(set(ko_ids) & genome_ko_ids)
    
    mod_prop = pd.DataFrame(mod_prop).fillna(0)
    mod_prop = mod_prop[mod_prop.sum(axis=1) > 0].divide(module_totals, axis='index').dropna()
    mod_prop = mod_prop.rename(index=module_names, columns=bin_names)
    mod_prop=mod_prop.sort_index(axis='columns')
    
    h2 = sns.clustermap(mod_prop, col_cluster=False, method='complete', cmap=cmap)
    h2.ax_heatmap.set_title(path_name)
    for text in h2.ax_heatmap.get_yticklabels():
        text.set_rotation('horizontal')
    for text in h2.ax_heatmap.get_xticklabels():
        text.set_rotation('vertical')
    h2.savefig(os.path.join(output_dir,'{}-modules_proportion.pdf'.format(path_name)))

colors = ['Blues', 'Greens', 'Oranges', 'Purples', 'Reds', 'Greys','Blues','Blues']
for (path_name, path_id), palette in zip(pathways.iteritems(), colors):
    print path_name, path_id, palette
    plot_heatmap_proportion(path_id, path_name, KO_genome_hits, palette,output_dir,os.path.join(output_dir,"Databases"),bin_names)



def group_micro_plot_heatmap_proprtion(path_id, path_name, KO_genome_files, cmap,output_dir,database_dir,bin_names):
    PTH_MO_pairs = load_local_kegg_database_pairings(database_dir,[("pathway","module")], False)[("pathway","module")]
    linkable_base_paths=set(PTH_MO_pairs.iterkeys()) & set(path_id)
    base_mod= set(itertools.chain(*[PTH_MO_pairs[pathway] for pathway in linkable_base_paths]))
    
    MO_KO_pairs  = load_local_kegg_database_pairings(database_dir,[("module","orthology")], False)[("module","orthology")]
    all_modules = {MO:MO_KO_pairs[MO] for MO in base_mod} #kc.link_ids('ko', base_mod)
    module_totals = pd.Series(
        {
            mod_id: len(ko_ids)
            for mod_id, ko_ids in all_modules.iteritems()
        }
    )
    rev_modules = dictionary.reverse_mapping(all_modules)
    
    module_names=load_readable_names(database_dir,["module"],False)["module"]
    
    mod_prop = {}
    for genome_id, KOs in KO_genome_files.iteritems():
        mod_prop[genome_id] = {}
        genome_ko_ids = KOs
        for mod_id, ko_ids in all_modules.iteritems():
            mod_prop[genome_id][mod_id] = len(set(ko_ids) & genome_ko_ids)
            
    euk_names=["coral","SymbC15"]
    
    bacteria=set(mod_prop.keys()) - set(euk_names)
    
    new_mod_prop={}
    new_mod_prop["microorganisms"]={}
    for module in mod_prop["coral"].iterkeys():
        new_mod_prop["microorganisms"][module]=0
    for euk in euk_names:
        new_mod_prop[euk]=mod_prop[euk]
    for genome, module_counts in mod_prop.iteritems():
        if genome not in euk_names:
            for module, count in module_counts.iteritems():
                new_mod_prop["microorganisms"][module]=max(new_mod_prop["microorganisms"][module],count)
                
    new_bin_names={gen_id:gen_name for gen_id,gen_name in bin_names.iteritems() if gen_id in euk_names}
    new_bin_names["microorganisms"]="microorganisms"
    bin_names=new_bin_names
    
    mod_prop = pd.DataFrame(new_mod_prop).fillna(0)
    mod_prop = mod_prop[mod_prop.sum(axis=1) > 0].divide(module_totals, axis='index').dropna()
    mod_prop = mod_prop.rename(index=module_names, columns=bin_names)
    mod_prop=mod_prop.sort_index(axis='columns')
    
        
    h2 = sns.clustermap(mod_prop, col_cluster=False, method='complete', cmap=cmap)
    h2.ax_heatmap.set_title(path_name)
    for text in h2.ax_heatmap.get_yticklabels():
        text.set_rotation('horizontal')
    for text in h2.ax_heatmap.get_xticklabels():
        text.set_rotation('vertical')
    h2.savefig(os.path.join(output_dir,'{}_grouped_microbes_modules_proportion.pdf'.format(path_name)))
    
    
colors = ['Blues', 'Greens', 'Oranges', 'Purples', 'Reds', 'Greys','Blues','Blues']
for (path_name, path_id), palette in zip(pathways.iteritems(), colors):
    print path_name, path_id, palette
    group_micro_plot_heatmap_proprtion(path_id, path_name, KO_genome_hits, palette,output_dir,os.path.join(output_dir,"Databases"),bin_names)

oxidative_phosphorylation ['map00190'] Blues


  if self._edgecolors == 'face':


two-component ['map02020'] Greens
vitamins&cofactors ['map00730', 'map00740', 'map00750', 'map00760', 'map00770', 'map00780', 'map00785', 'map00790', 'map00670', 'map00830', 'map00860', 'map00130'] Oranges
AminoAcidMetabolism ['map00250', 'map00270', 'map00260', 'map00280', 'map00290', 'map00300', 'map00310', 'map00220', 'map00330', 'map00340', 'map00350', 'map00360', 'map00380', 'map00400'] Purples
nitrogen-sulfur-fatty_acid-photosynthesis ['map00910', 'map00920', 'map01212', 'map00195'] Reds
carbon ['map01200'] Greys
amino-acids ['map01230'] Blues
oxidative_phosphorylation ['map00190'] Blues
two-component ['map02020'] Greens
vitamins&cofactors ['map00730', 'map00740', 'map00750', 'map00760', 'map00770', 'map00780', 'map00785', 'map00790', 'map00670', 'map00830', 'map00860', 'map00130'] Oranges
AminoAcidMetabolism ['map00250', 'map00270', 'map00260', 'map00280', 'map00290', 'map00300', 'map00310', 'map00220', 'map00330', 'map00340', 'map00350', 'map00360', 'map00380', 'map00400'] Purp

## Path searching algorithms

In [96]:
def networkx_shortest_path(graph,cpds1,cpds2):
    if len(cpds1)!=len(cpds2):
        print "The cpds must came as start, end pairs. There must be the same number of starting nodes as finishing nodes"
        return None
    shortest_paths={}
    graph_cpds=graph.nodes()
    for (start,end) in itertools.izip(cpds1,cpds2):
        if start in graph_cpds and end in graph_cpds:
            shortest_paths[(start,end)]=nx.astar_path(graph,start,end)
        else:
            shortest_paths[(start,end)]=[]
    return shortest_paths

#FROM NETWORKX DOCUMENTATION
def k_shortest_paths(graph,start,end,k,weight=None):
    return list(itertools.islice(nx.shortest_simple_paths(graph,start,end,weight=weight),k))
##############################


def networkx_k_shortest_paths(graph,cpds1,cpds2,k):
    if len(cpds1)!=len(cpds2):
        print "The cpds must came as start, end pairs. There must be the same number of starting nodes as finishing nodes"
        return None
    cpd_pair_k_shortest_paths={}
    graph_cpds=graph.nodes()
    for (start,end) in itertools.izip(cpds1,cpds2):
        if start in graph_cpds and end in graph_cpds:
            cpd_pair_k_shortest_paths[(start,end)]=k_shortest_paths(graph,start,end,k)
        else:
            cpd_pair_k_shortest_paths[(start,end)]=[[]]
    return cpd_pair_k_shortest_paths

def shortest_path_to_pd_dataframe(shortest_paths):
    kc = kegg.KeggClientRest()
    cpd_names = kc.get_ids_names('compound')
    df=pd.DataFrame([
            [cpd1, cpd2, cpd_names[cpd1],cpd_names[cpd2],"||".join(path), "||".join([cpd_names[cpd] for cpd in path])] for (cpd1, cpd2), path in shortest_paths.iteritems() 
        ])
    df.columns=["KEGGCpd1ID","KEGGCpd2ID",'CPD1_readable',"CPD2_readable","CPD_PATH","READABLE_CPD_PATH"]
    return df

def save_output(pd_path_dataframe,output_file):
    pd_path_dataframe.to_csv(output_file,sep="\t",header=True,index=False)
    return None

def shortest_paths_pd_dataframe(k_shortest_paths):
    kc = kegg.KeggClientRest()
    cpd_names = kc.get_ids_names('compound')
    df=pd.DataFrame([
            [cpd1, cpd2, cpd_names[cpd1],cpd_names[cpd2],"||".join(path), "||".join([cpd_names[cpd] for cpd in path])] for (cpd1, cpd2), paths in k_shortest_paths.iteritems() for path in paths 
        ])
    df.columns=["KEGGCpd1ID","KEGGCpd2ID",'CPD1_readable',"CPD2_readable","CPD_PATH","READABLE_CPD_PATH"]
    return df

def genome_shortest_paths_pd_dataframe(genome_cpd_pairs):
    kc = kegg.KeggClientRest()
    cpd_names = kc.get_ids_names('compound')
    df=pd.DataFrame([
            [genome, cpd1, cpd2, cpd_names[cpd1],cpd_names[cpd2],"||".join(path), "||".join([cpd_names[cpd] for genome, k_shortest_paths in genome_cpd_pairs.iteritems() for cpd in path])] \
            for (cpd1, cpd2), paths in k_shortest_paths.iteritems() for path in paths 
        ])
    df.columns=["Genome","KEGGCpd1ID","KEGGCpd2ID",'CPD1_readable',"CPD2_readable","CPD_PATH","READABLE_CPD_PATH"]
    return df

def investigate_within_genome_paths(genome_graph_dict,output_dir,n_shortest_paths,n_pruned_nodes,cpds1,cpds2):
    pruned_graphs={}
    genome_cpd_paths={}
    for genome, graph in genome_graph_dict.iteritems():
        pruned_graphs[genome]=prune_graph_abundant_nodes(graph,n_pruned_nodes)
    for genome, pruned_graph in pruned_graphs.iteritems():
        genome_cpd_paths[genome]=networkx_k_shortest_paths(pruned_graph,cpds1,cpds2,n_shortest_paths)
    df=genome_shortest_paths_pd_dataframe(genome_cpd_paths)
    save_output(df,os.path.join(output_dir,"{0}paths_{1}pruned_nodes_genome_cpd_pair_data.tsv"))
    return df

def investigate_between_genome_paths(genome_graph_dict,output_dir,n_shortest_paths,n_pruned_nodes,cpds1,cpds2):
    merged_genomes={}
    genome_differences={}
    for genome_1,genome_2 in itertools.combination(genome_graph_dict.iterkeys(),2):
        genome_differences=analyse_genome_differences(genome_1,genome_2)
        merged_genomes[genome_1,genome_2]=merge_genomes(genome_graph_dict[genome_1],genome_graph_dict[genome_2])
    for (genome_1,genome_2),merged_graph in merged_genomes:
        pass
        
        
        
    return

def pairwise_merged_graphs_individually(genome_reaction_graphs,output_dir):
    for genome_1,genome_2 in itertools.combinations(genome_reaction_graphs.iterkeys(),2):
        merged_genome=merge_graphs(genome_reaction_graphs[genome_1],genome_reaction_graphs[genome_2],common_cpds)
        save_reaction_graph(merged_genome,os.path.join(output_dir,"{0}_paired_{1}_graph_merged_genome_reactions.xml".format(genome_1,genome_2)))

## Single degree comparisons

In [76]:
def get_one_degree_nodes(graph,thresh):
    one_degree={}
    for cpd, deg in graph.degree().iteritems():
        if deg<=thresh:
            one_degree[cpd]=deg
    return one_degree

def all_one_degree(graph_dict):
    one_degree_genome={}
    for genome, graph in graph_dict.iteritems():
        one_degree_genome[genome]=get_one_degree_nodes(graph,1)
    return one_degree_genome

def set_conversion(graph_dict):
    cpd_sets={}
    for genome, cpd_dict in graph_dict.iteritems():
        cpd_sets[genome]=set(cpd_dict.iterkeys())
    return cpd_sets        

def get_one_degree_overlap(one_degree_genomes):
    paired_compounds={}
    for genome_1, genome_2 in itertools.combinations(one_degree_genomes.iterkeys(),2):
        paired_compounds_init=one_degree_genomes[genome_1] & one_degree_genomes[genome_2]
        if len(paired_compounds_init)>0:
            paired_compounds[(genome_1,genome_2)]=paired_compounds_init
    return paired_compounds

def get_one_degree_joins(graph_dict,item_name,common_cpds):
    one_degree_nodes=all_one_degree(graph_dict)
    one_degree_nodes=set_conversion(one_degree_nodes)
    linked_cpds={}
    for genome_1,genome_2 in itertools.combinations(one_degree_nodes.iterkeys(),2):
        genome_1_iso_nodes=one_degree_nodes[genome_1]
        genome_2_iso_nodes=one_degree_nodes[genome_2]
        gen_1_n_attr=nx.get_node_attributes(graph_dict[genome_1],item_name)
        gen_2_n_attr=nx.get_node_attributes(graph_dict[genome_2],item_name)
        genome_1_connections=connect_reactions(gen_2_n_attr,genome_1_iso_nodes,gen_1_n_attr,common_cpds)
        genome_2_connections=connect_reactions(gen_1_n_attr,genome_2_iso_nodes,gen_2_n_attr,common_cpds)
        linked_cpds[genome_1,genome_2]=genome_1_connections
        linked_cpds[genome_2,genome_1]=genome_2_connections
    return linked_cpds
        
def make_one_node_rcn_pd_df(linked_cpds,database_dir):
    readable_names=load_readable_names(database_dir,["reaction","compound"],False)
    rn_pth_pairs=load_local_kegg_database_pairings(database_dir,[("reaction","compound")], False)
    df=pd.DataFrame([
            [col1,col2,col3,col4,readable_names["reaction"][col3],\
             readable_names["reaction"][col4],";".join(overlap),";".join([readable_names["compound"][cpd] for cpd in overlap]),\
             ";".join(rn_pth_pairs[col3]),";".join(rn_pth_pairs[col4])] \
             for (col1,col2),d in linked_cpds.iteritems() for (col3, col4),overlap in d.iteritems()
        ])
    df.columns=["Genome_1","Genome_2","KEGG_Reaction1_ID","KEGG_Reaction2_ID","KEGG_Reaction1_readable","KEGG_Reaction2_readable","CPD_Overlap","Readable_Overlap","PathwaysR1","PathwaysR2"]
    return df

def save_output(pd_df,output_file):
    pd_df.to_csv(output_file,sep="\t",index=False)
    return None

def one_degree_rcn_wf(genome_graphs,item_name,common_cpds,output_dir):
    shared_cpd_genome=get_one_degree_joins(genome_graphs,item_name,common_cpds)
    new_df=make_one_node_rcn_pd_df(shared_cpd_genome,os.path.join(output_dir,"Databases"))
    save_output(new_df,os.path.join(output_dir,"Reaction_graphs_isolated_node_comparisons.tsb"))
    return new_df

def one_degree_wf(genome_graphs):
    shared_cpd_genome=all_one_degree(genome_graphs)
    shared_cpd_genome=set_conversion(shared_cpd_genome)
    shared_cpd_genome=get_one_degree_overlap(shared_cpd_genome)
    return shared_cpd_genome

def cpd_shared_count(shared_cpds):
    cpd=[]
    for genome, cpds in shared_cpds.iteritems():
        cpd+=list(cpds)
    return cpd

def make_one_node_pd_dataframe(overlapping_cpds):
    kc=kegg.KeggClientRest()
    cpd_path_pairings=kc.link_ids('pathway',set(itertools.chain(*overlapping_cpds.itervalues())))
    cpd_names = kc.get_ids_names('compound')
    pathway_names = kc.get_ids_names('pathway')
    
    df=pd.DataFrame([
        [col1,col2,col3,cpd_names[col3],";".join(cpd_path_pairings[col3])] for (col1,col2), d in pairwise_comparison.iteritems() for col3 in d
    ])
    
    df.columns=["Genome_1","Genome_2","KEGGCompoundID","Readable_cpd_name","KEGGPathway_ID"]#,"Readable_pathway_commound"]
    
    return df

## Within module overlap

In [23]:
def missing_KO_in_modules(genome_KO_hits,output_dir):
    '''Return a dictionary with a dictionary of modules and missing KOs for each
    genome. For simplicity all modules with no hits are ignored.'''
    missing_KOs={}
    genome_module_KO_hits={}
    kc=kegg.KeggClientRest()
    all_kos=set(itertools.chain(*genome_KO_hits.itervalues()))
    
    #All modules with at least one hit across the entire dataset
    prelinked_module=load_local_kegg_database_pairings(os.path.join(output_dir,"Databases"),[("orthology","module")], False)[("orthology","module")]
    all_kegg_kos=set(itertools.chain(prelinked_module.iterkeys()))
    observed_linkable_kos=all_kegg_kos & all_kos
    observed_modules=set(itertools.chain(*[prelinked_module[ko] for ko in observed_linkable_kos]))
    #The module pairings with their associated KO ids.
    MO_KO_pair=load_local_kegg_database_pairings(os.path.join(output_dir,"Databases"),[("module","orthology")], False)[("module","orthology")]
    
    KO_with_modules = {MO:MO_KO_pair[MO] for MO in observed_modules} 
    
    #KO misses for each genome and module.
    for genome_id, ko_hits in genome_KO_hits.iteritems():
        missing_KOs[genome_id]={}
        genome_module_KO_hits[genome_id]={}
        for mod_id, ko_ids in KO_with_modules.iteritems():
            present_KOs=(set(ko_ids) & ko_hits)
            #exclusion of intersection from module set
            missing_KOs[genome_id][mod_id]=set(ko_ids) - present_KOs
            #Intersection of sets - KOs present in module
            genome_module_KO_hits[genome_id][mod_id]=present_KOs
        
    return missing_KOs,genome_module_KO_hits,KO_with_modules

def pairwise_compare(genome_KO_hits,genome_missing_KOs,KO_with_modules,min_diff,min_completeness):
    '''Return a dictionary of modules for each genome pair if the complementation of the module passes
    some criteria.
    '''
    pairwise_complement_dict={}
    for genome_1, genome_2, in itertools.combinations(genome_KO_hits.iterkeys(),2):
        pairwise_complement_dict[(genome_1,genome_2)]={}
        pairwise_complement_dict[(genome_2,genome_1)]={}
        for mod_id, KO_set in KO_with_modules.iteritems():
            
            mod_tot=len(KO_set)
            
            genome_1_module_hits=genome_KO_hits[genome_1][mod_id]
            genome_2_module_hits=genome_KO_hits[genome_2][mod_id]
            
            com_1=genome_missing_KOs[genome_2][mod_id] & genome_1_module_hits # (A-C) cap B
            com_2=genome_missing_KOs[genome_1][mod_id] & genome_2_module_hits #(A-B) cap C
            
            N_com1=len(com_1)
            N_com2=len(com_2)
            
            if N_com1>1 and N_com1/float(mod_tot)>=min_diff:
                complemented_set=com_1.union(genome_1_module_hits)
                new_completeness=len(complemented_set)/float(mod_tot)
                if new_completeness>=min_completeness:
                    pairwise_complement_dict[(genome_1,genome_2)][mod_id]=(new_completeness,N_com1/float(mod_tot),genome_2_module_hits,com_1)
            if N_com2>1 and N_com2/float(mod_tot)>=min_diff:
                complemented_set=com_2.union(genome_2_module_hits)
                new_completeness=len(complemented_set)/float(mod_tot)
                if new_completeness>=min_completeness:
                    pairwise_complement_dict[(genome_2,genome_1)][mod_id]=(new_completeness,N_com2/float(mod_tot),genome_1_module_hits,com_2)
    return pairwise_complement_dict

def clean_pairs(dict_of_dicts):
    cleaned_dict={}
    for genome_pair, module_dict in dict_of_dicts.iteritems():
        if len(module_dict)>0:
            cleaned_dict[genome_pair]=module_dict
        else:
            pass
            #print module_dict
    return cleaned_dict
            
def make_complement_dataframe(pairwise_comparison):
    ###############
    df=pd.DataFrame([
        [col1,col2,col3,col4,col5,col6,col7] for (col1,col2), d in pairwise_comparison.iteritems() for col3,(col4,col5,col6,col7) in d.iteritems()
    ])
    df.columns=["Genome_1(Giver)","Genome_2(Receiver)","KEGGModuleID","ComplementedCompleteness","GiverContribution","Orig_KOs","Given_KOs"]
    
    return df

def threshold_reduce(df,orig_contrib):
    return df[(df["ComplementedCompleteness"]-df["GiverContribution"])>=orig_contrib]

def save_output(complement_dataframe, orig_contrib,output_dir):
    full_file=os.path.join(output_dir,"entire_overlap_file.tsv")
    complement_dataframe=readable_names(complement_dataframe,'KEGGModuleID','module','ReadableModuleNames',output_dir)
    complement_dataframe=add_mapping(complement_dataframe,"KEGGModuleID",'pathway','KEGGPathwayID',output_dir)
    complement_dataframe.to_csv(full_file,sep="\t",index=None)
    reduced_file=os.path.join(output_dir,"orig_contrib_thresh{0}_red_comp.tsv".format(orig_contrib))
    less_dataframe=threshold_reduce(complement_dataframe,orig_contrib)
    less_dataframe.to_csv(reduced_file,sep="\t",index=None)
    return less_dataframe

def readable_names(dataframe,column_name,kegg_object,new_column_name,output_dir):
    kc = kegg.KeggClientRest()
    module_names = load_readable_names(os.path.join(output_dir,"Databases"),[kegg_object],False)[kegg_object]
    idx=dataframe.columns.get_loc(column_name)+1 #To be after module IDs
    new_df=dataframe.copy()
    new_df.insert(idx,new_column_name,dataframe[column_name].map(module_names))
    return new_df

def add_mapping(dataframe,column_name,kegg_object,new_column_name,output_dir,human_readable=True):
    #print [("module",kegg_object)]
    MO_PA_pairs=load_local_kegg_database_pairings(os.path.join(output_dir,"Databases"),[("module",kegg_object)], False)[("module",kegg_object)]
    linkable_modules=set(itertools.chain(MO_PA_pairs.iterkeys())) & set(dataframe[column_name])
    
    mod_path_pairings={module:MO_PA_pairs[module] for module in linkable_modules}
    #print len(mod_path_pairings)
    comp_pair={module:";".join(pathways) for module, pathways in mod_path_pairings.iteritems()}
    pathway_names=load_readable_names(os.path.join(output_dir,"Databases"),[kegg_object],False)[kegg_object]
    new_df=dataframe.copy()
    idx=dataframe.columns.get_loc(column_name)+2
    new_df.insert(idx,new_column_name,dataframe[column_name].map(comp_pair))
    if human_readable:
        readable={module:";".join(pathway_names[pathway] for pathway in pathways) for module, pathways in mod_path_pairings.iteritems()}
        new_df.insert(idx+1,new_column_name+"HuReadable",dataframe[column_name].map(readable))
    return new_df

def module_complement_workflow(KO_genome_hits,output_dir,orig_contrib,min_completeness, min_diff,return_all=False):
    
    missing_module_info,present_KOs_module,KO_module_pairs=missing_KO_in_modules(KO_genome_hits,output_dir)
    pairwise_complement_dict=pairwise_compare(present_KOs_module,missing_module_info,KO_module_pairs,min_diff,min_completeness)
    clean_pair_comp_subset=clean_pairs(pairwise_complement_dict)
    metab_comp_df=make_complement_dataframe(clean_pair_comp_subset)
    reduced_df=save_output(metab_comp_df,orig_contrib,output_dir)
    if return_all:
        return metab_comp_df
    else:
        return reduced_df

# Within Module Comparison

In [24]:
df=module_complement_workflow(KO_genome_hits,output_dir,0.4,0.8,0.4,return_all=False)

# Single Node Comparison

In [25]:
one_degree_rcn_wf(reaction_graphs,"cpds",common_cpds,output_dir)

NameError: name 'one_degree_rcn_wf' is not defined

# Loading in previously made graphs

In [8]:
reaction_graphs=load_graphs(os.path.join(output_dir,"graphs","Reaction_links","Side_ambiguous"),'*.xml')

In [11]:
G=reaction_graphs['U_52423']

# Testing graph merging and the new cpd to reaction path tracking.

In [9]:
test_pair=(reaction_graphs['coral'],reaction_graphs['SymbC15'])

In [11]:
merged_graph=merge_graphs(test_pair[0],test_pair[1],common_cpds)
#looks like the graph merging works. Yay.

In [12]:
import timeit
timeit.timeit("merge_graphs(test_pair[0],test_pair[1],common_cpds)","from __main__ import merge_graphs, common_cpds, test_pair",number=10)

74.15104389190674

In [104]:
split=merged_graph.name.split(":")
first=split[0]
last=split[1]
gen_1_reactions, gen_2_reactions=_unique_compound_pairs(merged_graph,(first,last))

In [105]:
print len(gen_1_reactions), len(gen_2_reactions)

279 314


In [104]:
new_temp={rcn:data for rcn, data in reaction_graphs['coral'].nodes(data=True)}
new_temp['R01878']

{'cpds': {u'C00001', u'C00014', u'C00299', u'C00475'}}

In [97]:
test_start,test_end=extract_starting_ending_reactions(reaction_graphs['coral'],"C00014","C01289","cpds")
print test_start,test_end

set(['R04893', 'R04930', 'R04907', 'R04890', 'R01168', 'R03909', 'R01001', 'R08348', 'R00277', 'R04908', 'R01676', 'R01230', 'R04674', 'R00131', 'R01579', 'R02540', 'R02529', 'R00485', 'R02408', 'R04025', 'R00220', 'R05590', 'R07700', 'R00348', 'R00269', 'R00248', 'R00689', 'R05551', 'R00084', 'R00748', 'R02613', 'R00243', 'R00571', 'R03096', 'R00253', 'R00648', 'R01663', 'R04666', 'R02382', 'R01878', 'R00729', 'R02485', 'R00181', 'R01134', 'R00359', 'R04300', 'R02908', 'R01560', 'R04125', 'R02302', 'R04770', 'R00677', 'R00996', 'R10949', 'R01221', 'R03180', 'R02532', 'R02150', 'R00765', 'R00256', 'R02556', 'R00357', 'R09366', 'R00905', 'R08221', 'R01710', 'R01151', 'R02173', 'R02197', 'R08228', 'R00782']) set([])


In [106]:
iter_pairs=extract_all_starting_ending_reacitons(reaction_graphs['coral'], [("C01289","C00014"),("C01289","C01289"),("C15547","C01179")],"cpds")
for pair,second in iter_pairs:
    print pair, second

set([]) set(['R04893', 'R04930', 'R04907', 'R04890', 'R01168', 'R03909', 'R01001', 'R08348', 'R00277', 'R04908', 'R01676', 'R01230', 'R04674', 'R00131', 'R01579', 'R02540', 'R02529', 'R00485', 'R02408', 'R04025', 'R00220', 'R05590', 'R07700', 'R00348', 'R00269', 'R00248', 'R00689', 'R05551', 'R00084', 'R00748', 'R02613', 'R00243', 'R00571', 'R03096', 'R00253', 'R00648', 'R01663', 'R04666', 'R02382', 'R01878', 'R00729', 'R02485', 'R00181', 'R01134', 'R00359', 'R04300', 'R02908', 'R01560', 'R04125', 'R02302', 'R04770', 'R00677', 'R00996', 'R10949', 'R01221', 'R03180', 'R02532', 'R02150', 'R00765', 'R00256', 'R02556', 'R00357', 'R09366', 'R00905', 'R08221', 'R01710', 'R01151', 'R02173', 'R02197', 'R08228', 'R00782'])
set([]) set([])
set([]) set(['R02521', 'R03342', 'R00734', 'R00729'])


In [108]:
iter_pairs=extract_all_starting_ending_reacitons(reaction_graphs['coral'], [("C00014","C01179")],"cpds")
for pair,second in iter_pairs:
    print pair, second

set(['R04893', 'R04930', 'R04907', 'R04890', 'R01168', 'R03909', 'R01001', 'R08348', 'R00277', 'R04908', 'R01676', 'R01230', 'R04674', 'R00131', 'R01579', 'R02540', 'R02529', 'R00485', 'R02408', 'R04025', 'R00220', 'R05590', 'R07700', 'R00348', 'R00269', 'R00248', 'R00689', 'R05551', 'R00084', 'R00748', 'R02613', 'R00243', 'R00571', 'R03096', 'R00253', 'R00648', 'R01663', 'R04666', 'R02382', 'R01878', 'R00729', 'R02485', 'R00181', 'R01134', 'R00359', 'R04300', 'R02908', 'R01560', 'R04125', 'R02302', 'R04770', 'R00677', 'R00996', 'R10949', 'R01221', 'R03180', 'R02532', 'R02150', 'R00765', 'R00256', 'R02556', 'R00357', 'R09366', 'R00905', 'R08221', 'R01710', 'R01151', 'R02173', 'R02197', 'R08228', 'R00782']) set(['R02521', 'R03342', 'R00734', 'R00729'])


In [115]:
print i

0


In [None]:

i=0
for path in nx.all_simple_paths(reaction_graphs['coral'],'R01878','R00734'):
    i+=1
    if i>=10:
        break
print i

In [None]:
test_pairwise_merging=pairwise_merged_graphs(reaction_graphs,False)
write_graph_dict(os.path.join(output_dir,"graphs/Pairwise_reaction_links"),test_pairwise_merging)

In [None]:
pairwise_merged_graphs_individually(reaction_graphs,os.path.join(output_dir,"graphs/Pairwise_reaction_links"))

# Deprecated

In [41]:
#I made a new KEGG KO file from all of the microbial gff so there was no need to constantly mount the work server or mess with gffs

def load_gffs(dir_name,glob_id):
    gff_files = {}
    for fname in glob(os.path.join(dir_name,glob_id)):
        key = '_'.join(os.path.basename(fname).split('_')[1:3])
        gff_files[key] = list(gff.parse_gff(fname))
    return gff_files
    
def get_KO_from_gff(gff_files):
    KO_ids={}
    for genome_id, annotations in gff_files.iteritems():
        genome_ko_ids = set()
        for annotation in annotations:
            for ko_id in annotation.get_mapping('KO'):
                genome_ko_ids.add(ko_id)
        KO_ids[genome_id]=genome_ko_ids
    return KO_ids

def get_KO_from_gff_not_set(gff_files):
    KO_ids={}
    for genome_id, annotations in gff_files.iteritems():
        genome_ko_ids = []
        for annotation in annotations:
            for ko_id in annotation.get_mapping('KO'):
                genome_ko_ids.append(ko_id)
        KO_ids[genome_id]=genome_ko_ids
    return KO_ids

def write_genome_KO_file(output_file,KO_Genome_hits):
    #print KO_Genome_hits
    with open(output_file,'w') as microb_kos:
        for genome, KO_hits in KO_Genome_hits.iteritems():
            microb_kos.write("{0}\t{1}\n".format(genome,";".join(KO_hits)))
    return

def remake_KOs_file(gff_dir, output_file, glob_ids):
    gff_files={}
    KO_hits={}
    for glob_id in glob_ids:
        gff_files[glob_id]=load_gffs(gff_dir,glob_id)
        KO_hits[glob_id]=get_KO_from_gff_not_set(gff_files[glob_id])
    if len(glob_ids)==1:
        #print KO_hits[glob_id][0]
        microbial_KO_hits=KO_hits.values()[0]
    else:
        microbial_KO_hits=merge_KO_hits(KO_hits)
    write_genome_KO_file(output_file,microbial_KO_hits)
    return microbial_KO_hits

def merge_KO_hits(mult_KO_hits):
    all_genomes=set(itertools.chain(key for genome_KO_dicts in mult_KO_hits.itervalues() for key in genome_KO_dicts.iterkeys()))
    new_dict={genome:[] for genome in all_genomes}
    for glob_id, genome_dicts in mult_KO_hits.iteritems():
        for genome, KOs in genome_dicts.iteritems():
            new_dict[genome]=new_dict[genome]+KOs
    return new_dict

microbial_KO_both_hits=remake_KOs_file(gff_dir,microbial_kegg,["*_genomic-both-final.gff"])

microbial_KO_sprot_trembl_hits=remake_KOs_file(gff_dir,microbial_kegg,["*_genomic-sprot-final.gff","*_genomic-trembl-final.gff"])

for genome, KOs in microbial_KO_both_hits.iteritems():
    print len(KOs), len(microbial_KO_sprot_trembl_hits[genome]), len(set(KOs)), len(set(microbial_KO_sprot_trembl_hits[genome]))

2016-07-22 09:44:19,706 -    INFO - mgkit.io.gff->parse_gff: Loading GFF from file (/home/alex/Documents/Hons/Seaquence/francesco_data/gff_bins-2016-06-14b/filtered_U_41432_genomic-both-final.gff)
2016-07-22 09:44:19,706 -    INFO - mgkit.io.gff->parse_gff: Loading GFF from file (/home/alex/Documents/Hons/Seaquence/francesco_data/gff_bins-2016-06-14b/filtered_U_41432_genomic-both-final.gff)
2016-07-22 09:44:20,299 -    INFO - mgkit.io.gff->parse_gff: Read 2762 line from file (/home/alex/Documents/Hons/Seaquence/francesco_data/gff_bins-2016-06-14b/filtered_U_41432_genomic-both-final.gff)
2016-07-22 09:44:20,299 -    INFO - mgkit.io.gff->parse_gff: Read 2762 line from file (/home/alex/Documents/Hons/Seaquence/francesco_data/gff_bins-2016-06-14b/filtered_U_41432_genomic-both-final.gff)
2016-07-22 09:44:20,332 -    INFO - mgkit.io.gff->parse_gff: Loading GFF from file (/home/alex/Documents/Hons/Seaquence/francesco_data/gff_bins-2016-06-14b/filtered_U_51863_genomic-both-final.gff)
2016-07-2

361 1638 341 1292
666 1532 516 998
533 1120 458 854
658 1550 550 1106
563 1143 508 939
669 1554 577 1141
878 1878 667 1164
619 1485 520 1044
202 708 191 629
552 1226 491 951
592 1396 481 951
563 1021 530 889
627 1385 517 992
545 1139 472 901
1079 1614 830 1090
584 1380 530 1117
217 948 204 789
437 753 403 659
778 1650 648 1211
549 1292 412 818
108 867 101 795
1037 2235 763 1309
355 853 316 650
414 802 365 642
896 1545 716 1029
580 1322 457 902
716 1913 579 1303
374 961 348 863
550 1270 461 891
960 1225 794 964
448 1121 381 772
555 1271 454 925
853 1103 750 934
506 1092 418 802
648 1507 528 1016
833 1010 738 867
691 1499 596 1154
682 1538 510 994
499 1125 409 771
985 1538 703 959
485 814 405 638
650 1437 517 1008
367 777 327 631
398 964 367 757
914 1625 696 1100
609 1424 487 952
724 1694 503 1052
594 1123 499 850
577 1379 487 999
509 1198 398 779
425 1097 370 753
461 1199 400 796


# Simple, euk-like repeats table.

In [37]:
#gene_dir="G:\data\eukaryote_like_repeats\gene_hits"
#gene_dir="/home/baker/Documents/MountedDrive/seaquence/data/eukaryote_like_repeats/gene_hits"
#load_euk_like_repeat_hits(gene_dir)

In [9]:
import os,sys,re
import pandas as pd
import numpy as np
import scipy as sp
import glob as gb
from collections import defaultdict

def hmmer_domtblout_parser(file_path,header_line,comment_char):
    i=0
    header_one_space_sep=[2,20]
    line_one_space_sep=[2,-1]
    regex_clean=re.compile("\s{2,}")
    with open (file_path) as hmmer_hits:
        for line in hmmer_hits:
            if i==header_line:
                yield list(itertools.chain(*[fields.split(" ", 1) if i in header_one_space_sep else [fields] for i,fields in enumerate(regex_clean.split(line.strip())) ]))
            elif not line.startswith(comment_char):
                part_proc=regex_clean.split(line.strip())
                for i in line_one_space_sep:
                    part_proc[i]=part_proc[i].split(None,2)
                
                yield list(itertools.chain(*[ [item] if not isinstance(item,list) else item for item in part_proc]))
            i+=1
    
def create_hmmer_domtblout_df(file_path,header_line,comment_char):
    hmmer_file=hmmer_domtblout_parser(file_path,header_line,comment_char)
    header=next(hmmer_file)
    df=pd.DataFrame([
            line for line in hmmer_file 
        ])
    df.columns=header
    return df

def all_domtblout_df(file_path,header_line,comment_char,optional_reg_cut):
    repeat_dfs={}
    i=0
    for file_path in gb.glob(os.path.join(file_path, "*.tsv")):
        file_name=os.path.basename(file_path)
        #print file_name
        if isinstance(optional_reg_cut,type(None)):
            file_id=file_name.replace(".tsv","")
        else:
            file_id=file_name.replace(".tsv","")
            file_id=re.sub(optional_reg_cut,"",file_id)
        repeat_dfs[file_id]=create_hmmer_domtblout_df(file_path,header_line,comment_char)
        i+=1
        if i%10==0:
            print "{0} files have been processed. The last was {1}".format(i,file_id)
    if len(repeat_dfs)==0:
        print "No files were loaded"
    return repeat_dfs

def merge_repeat_dfs(df_dict):
    df_list=[None]*len(df_dict)
    i=0
    for genome_id,df in df_dict.iteritems():
        new_df=df
        new_df['Genome_id']=genome_id
        #new_df.rename(columns={0:'Gene_Name'},inplace=True)
        df_list[i]=new_df
        i+=1
    merged_dict=pd.concat(df_list,axis=0)
    merged_dict.index.names=["Gene_name"]
    cols = merged_dict.columns.tolist()
    cols=cols[-1:]+cols[0:-1]
    merged_dict=merged_dict[cols]
    merged_dict['Genome_id']=merged_dict['Genome_id'].str.strip("_genomic").str.strip("aa_genes_unfiltered_").str.strip("aa_genes_filtered_")
    print merged_dict
#    print merged_dict.ix[:,0]
    return merged_dict

def merge_repeat_dfs_wf(df_dict,output_file,taxonomy_file):
    gene_level_hits=merge_repeat_dfs(df_dict)
    gene_level_hits['Taxonomy']=gene_level_hits['Genome_id'].map(taxonomy_file)
    cols = gene_level_hits.columns.tolist()
    cols=[cols[0]]+cols[-1:]+cols[1:-1]
    gene_level_hits=gene_level_hits[cols]
    gene_level_hits.to_csv(output_file,sep="\t",index=True)
    
    return gene_level_hits
    

def all_hmms_hit(df):
    
    return

def construct_n_genes_hits(df_dict,all_columns):
    genome_ids=df_dict.keys()
    column_names=pd.unique(itertools.chain(*df_dict.itervalues()))
    gene_count_df=pd.DataFrame(index=genome_ids,columns=all_columns)
    gene_count_df.index.name="Genome_id"
    
    for file_id, df in df_dict.iteritems():
        for repeat_motif,count in df.iteritems():
            gene_count_df.set_value(file_id,repeat_motif,count)
            
    return gene_count_df

def process_gene_hits(complete_df_dict,genome_taxonomy,output_file,all_motifs,taxonomy_file):
    gene_hits={file_id: make_hmm_hits_to_gene_hits(df) for file_id, df in complete_df_dict.iteritems()}
    gene_df=construct_n_genes_hits(gene_hits,all_motifs)
    gene_df.reset_index(level=0,inplace=True)
    gene_df['Genome_id']=gene_df['Genome_id'].str.strip("_genomic").str.strip("aa_genes_unfiltered_").str.strip("aa_genes_filtered_")
    #print gene_df.ix[:,0]
    #print taxonomy_file
    gene_df['Taxonomy']=gene_df['Genome_id'].map(taxonomy_file)
    cols = gene_df.columns.tolist()
    cols=[cols[0]]+cols[-1:]+cols[1:-1]
    gene_df=gene_df[cols]
    #Add taxonomy information
    gene_df.to_csv(output_file, sep="\t",index=False)
    
    return gene_df

def hmm_hits_in_gene(df,all_columns):
    genes_hit=pd.unique(df['query name'])
    repeats_found=pd.unique(df['# target name'])
    hit_df=pd.DataFrame(index=genes_hit,columns=all_columns)
    hit_df=hit_df.fillna(0)
    all_counts=df.groupby(['query name','# target name']).count()['of']
    for ((contig,repeat_motif),count) in all_counts.iteritems():
        hit_df.set_value(contig,repeat_motif,count)
        
    return hit_df

def make_hmm_hits_to_gene_hits(df):
    #Check for any columns with at least one hit, add all true values
    #to get number of genes with that motif.
    return (df>0).sum(axis=0)

def get_all_target_values(df_dict):
    target_values=[]
    for genome_id, df in df_dict.iteritems():
        target_values.append(df['# target name'])
        
    return pd.unique(pd.concat(target_values,axis=0))

def hmm_hits_wf(input_dir, output_dir,taxonomy_file, header_line,comment_char,optional_reg_cut):
    all_dfs=all_domtblout_df(input_dir,header_line,comment_char,optional_reg_cut)
    all_target_ids=get_all_target_values(all_dfs)
    motif_hits_in_gene={file_id: hmm_hits_in_gene(df,all_target_ids) for file_id,df in all_dfs.iteritems()}
    genome_taxonomy=load_bin_names(taxonomy_file)
    genome_taxonomy=defaultdict(lambda: "No predefined Taxonomy",genome_taxonomy)
    N_gene_hits=process_gene_hits(motif_hits_in_gene,genome_taxonomy,os.path.join(output_dir,"N_genes_with_hits.tsv"),all_target_ids,genome_taxonomy)
    merged_hmm_hits=merge_repeat_dfs_wf(motif_hits_in_gene,os.path.join(output_dir,"hmm_hits_per_gene_per_genome.tsv"),genome_taxonomy)
    
    return N_gene_hits
    
def load_bin_names(tax_file):
    #Load bin_ids and bins_taxonomy from file.
    bin_names={}
    bin_pair=[]
    with open(tax_file,'r') as bin_tax_pair:
        bin_tax_pair.readline()
        for line in bin_tax_pair:
            bin_pair.append(tuple(line.strip().split("\t")))

    bin_names={bin_id:taxonomy for taxonomy, bin_id in bin_pair}
    return bin_names



In [145]:
gene_dir

'/home/baker/grive/Honours/eukaryote_like_repeats/gene_hits'

In [146]:
hmm_dir

'/home/baker/grive/Honours/HMM_searches/Symbioses_test/euk_repeat_results'

In [10]:
hmm_hits_wf(gene_dir, hmm_dir,tax_file,1,"#",None)

10 files have been processed. The last was aa_genes_filtered_U_52338_genomic
20 files have been processed. The last was aa_genes_unfiltered_U_52105_genomic
30 files have been processed. The last was aa_genes_unfiltered_U_52439_genomic
40 files have been processed. The last was aa_genes_unfiltered_U_52524_genomic
50 files have been processed. The last was aa_genes_unfiltered_U_52615_genomic
                 Genome_id  T2SSE  TPR_16  TPR_19  TPR_2  TPR_14  TPR_6  \
Gene_name                                                                 
contig_446576_2    U_52529      1       0       0      0       0      0   
contig_446576_3    U_52529      0       3       2      4       5      4   
contig_446576_4    U_52529      0       0       0      0       0      0   
contig_639749_1    U_52529      1       0       0      0       0      0   
contig_658327_3    U_52529      0       9      10     13      12     10   
contig_658327_8    U_52529      0       3       4      6       5      0   
contig_

Unnamed: 0,Genome_id,Taxonomy,T2SSE,TPR_16,TPR_19,TPR_2,TPR_14,TPR_6,TPR_1,TPR_12,...,WD40_alt,WD40,YscO-like,Type_III_YscX,YopE,YscW,TAL_effector,WD40_3,WD40_4,T3SS_needle_reg
0,U_52529,d__Bacteria;p__Proteobacteria;c__Gammproteobac...,16,30,28,34,34,26,27,27,...,0,0,0,0,0,0,0,0,0,0
1,U_51963,d__Bacteria;p__Chloroflexi,12,15,10,13,14,10,10,10,...,1,1,0,0,0,0,0,0,0,0
2,U_52278,d__Bacteria;p__Gemmatimonadetes;c__Gemm-2,10,13,15,11,18,9,9,11,...,0,4,0,0,0,0,0,0,0,0
3,U_52536,d__Bacteria;p__Chloroflexi,12,14,12,15,12,10,13,14,...,1,5,0,0,0,0,0,0,0,0
4,U_52439,p__Poribacteria,24,103,98,106,100,95,99,99,...,1,88,1,0,0,0,0,0,0,0
5,U_52531,d__Bacteria;p__Chloroflexi;c__SAR202,15,10,10,11,10,6,10,10,...,1,1,1,0,0,0,0,0,0,0
6,U_51962,d__Bacteria;p__Bacteroidetes;c__Rhodothermia;o...,11,31,31,31,32,30,25,29,...,0,1,0,0,0,0,0,0,0,0
7,U_52520,d__Bacteria;p__Chloroflexi;c__SAR202,10,4,4,4,6,4,4,4,...,0,0,0,0,0,0,0,0,0,0
8,U_52478,d__Bacteria;p__Actinobacteria (sister to o__ko...,11,4,2,4,3,1,2,3,...,0,2,0,0,0,0,0,0,0,0
9,U_52271,d__Bacteria;p__Acidobacteria,17,53,47,54,53,48,47,51,...,0,4,0,1,0,0,0,0,0,0


In [None]:
genes_hit=pd.unique(new_df['query name'])
repeats_found=pd.unique(new_df['# target name'])
hit_df=pd.DataFrame(index=genes_hit,columns=repeats_found)
hit_df=hit_df.fillna(0)
print hit_df

In [8]:
file_name=os.path.join(gene_dir,"aa_genes_filtered_U_41432_genomic.tsv")
import re

In [22]:
for ((contig,repeat_motif),count) in grouped_counts.iteritems():
    #print contig,repeat_motif,count
    hit_df.set_value(contig,repeat_motif,count)
hit_df

Unnamed: 0,TPR_4,TPR_11,TPR_1,TPR_2,TPR_16,TPR_8,TPR_10,TPR_12,TPR_14,TPR_17,...,TPR_21,TPR_3,WD40,WD40_alt,Ank_2,Ank_3,Ank,Ank_5,Ank_4,TPR_18
L35C15UnmappedToSymbiodinium_L35C15UnmappedToPlutea_nesoni_seqprepped_13206_GBR_UNSW_H8P31ADXX_TAAGGCGA_combinedL1L2__merged_contig_76860_1,3,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
L35C15UnmappedToSymbiodinium_L35C15UnmappedToPlutea_nesoni_seqprepped_13206_GBR_UNSW_H8P31ADXX_TAAGGCGA_combinedL1L2__merged_contig_76860_5,8,7,8,9,7,9,8,7,8,7,...,3,0,0,0,0,0,0,0,0,0
L35C15UnmappedToSymbiodinium_L35C15UnmappedToPlutea_nesoni_seqprepped_13206_GBR_UNSW_H8P31ADXX_TAAGGCGA_combinedL1L2__merged_contig_76860_6,8,8,8,9,5,7,8,5,7,8,...,4,5,0,0,0,0,0,0,0,0
L35C15UnmappedToSymbiodinium_L35C15UnmappedToPlutea_nesoni_seqprepped_13206_GBR_UNSW_H8P31ADXX_TAAGGCGA_combinedL1L2__merged_contig_112401_17,0,4,4,5,5,0,0,0,5,2,...,2,0,0,0,0,0,0,0,0,0
L35C15UnmappedToSymbiodinium_L35C15UnmappedToPlutea_nesoni_seqprepped_13206_GBR_UNSW_H8P31ADXX_TAAGGCGA_combinedL1L2__merged_contig_103936_8,11,12,13,13,8,13,13,9,12,14,...,4,3,0,0,0,0,0,0,0,0
L35C15UnmappedToSymbiodinium_L35C15UnmappedToPlutea_nesoni_seqprepped_13206_GBR_UNSW_H8P31ADXX_TAAGGCGA_combinedL1L2__merged_contig_86701_5,0,2,1,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
L35C15UnmappedToSymbiodinium_L35C15UnmappedToPlutea_nesoni_seqprepped_13206_GBR_UNSW_H8P31ADXX_TAAGGCGA_combinedL1L2__merged_contig_96605_12,0,0,0,0,0,0,0,0,0,0,...,0,0,15,0,0,0,0,0,0,0
L35C15UnmappedToSymbiodinium_L35C15UnmappedToPlutea_nesoni_seqprepped_13206_GBR_UNSW_H8P31ADXX_TAAGGCGA_combinedL1L2__merged_contig_96701_10,0,0,0,7,2,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
L35C15UnmappedToSymbiodinium_L35C15UnmappedToPlutea_nesoni_seqprepped_13206_GBR_UNSW_H8P31ADXX_TAAGGCGA_combinedL1L2__merged_contig_141573_8,4,3,0,3,2,3,0,2,4,2,...,2,0,0,0,0,0,0,0,0,0
L35C15UnmappedToSymbiodinium_L35C15UnmappedToPlutea_nesoni_seqprepped_13206_GBR_UNSW_H8P31ADXX_TAAGGCGA_combinedL1L2__merged_contig_141573_12,0,0,0,0,0,0,0,0,0,0,...,0,0,0,7,0,0,0,0,0,0


In [41]:
from glob import glob
import os
    
    
def write_windows_colour_file(windows_file,Donovan_colour_file,output_file):
    windows_dict=new_windows_colours_files(windows_file,Donovan_colour_file)
    output_txt="\n".join(["\t".join((key,value)) for key,value in windows_dict.iteritems()])
    
    with open(output_file,'w') as new_colour_file:
        new_colour_file.write(output_txt)
        
    return None

def convert_colour_files_for_windows(windows_size, output_directory, donovan_colour_files,sequence_file):
    
    
    return
    
def load_Francesco_colour_file(directory_path):
    colour_files={}
    for file_name in glob(os.path.join(directory_path,"*.tab")):
        with open(file_name) as current_genome:
            genome_file_name=os.path.basename(file_name).split(".fasta")[0]+".fasta"
            colour_files[genome_file_name]={}
            next(current_genome)
            for line in current_genome:
                contig, Ancester_lineage,LCA_lineage,R,G,B=line.strip().split("\t")
                colour_files[genome_file_name][contig]=(R,G,B)
    return colour_files

def write_Donovan_colour_files(output_directory,F_colour_files):
    donovan_format={}
    for file_name,contig_values in F_colour_files.iteritems():
        print file_name
        donovan_format[file_name+".don.tab"]={}
        new_colour_file=open(os.path.join(output_directory,file_name+".don.tab"),'w')
        storage_list=["na"]*len(contig_values)
        i=0
        for contig, colour_tuple in contig_values.iteritems():
            colour_code=[str(int(float(colour_prop)*255)) for colour_prop in colour_tuple]
            #print colour_code
            donovan_format[file_name+".don.tab"][contig]=",".join(colour_code)
            storage_list[i]="\t".join([contig,",".join(colour_code)])
            i+=1
        new_colour_file.write("\n".join(storage_list))         
        new_colour_file.close()  
    return
        
def new_windows_colours_files(windows_file,Donovan_colour_file):
    colour_pairings={}
    with open(Donovan_colour_file) as Don_col:
        for line in Don_col:
            contig, col_tup=line.strip().split("\t")
            #col_tup=tuple()
            colour_pairings[contig]=col_tup
    #print "The Donovan style dictionary",colour_pairings
    
    reg_pat="(?:>)(.*)(?:\n)" #Extracts contig name
    #colour_pairings=defaultdict(lambda x: x,colour_pairings)
    #with open(windows_file,'r') as windows_seq:
    #    for line in windows_seq:
    #        if line.startswith(">"):
    windows=open(windows_file,'r')
    windows_txt="".join(windows.readlines()) #reasonable since single cells should not get big
    windows.close()
    windows_dict={}
    for match in re.finditer(reg_pat,windows_txt):
        contig_name=match.groups()[0]
        #print contig_name
        contig,position=contig_name.rsplit(":",1) #sep name and postion
        if contig in colour_pairings:
            windows_dict[contig_name]=colour_pairings[contig]
    #print "The windows dictionary", windows_dict
    
    return windows_dict
-


In [42]:
windows_file="C:\\Users\\Baker\\Google Drive\\Honours\\SingleCellM\\Experimenting\\ACE-SCG_MDA003-K22.scaffoldswindows.fasta"

francesco_colour_file="C:\\Users\\Baker\\Google Drive\\Honours\\SingleCellM\\Experimenting\\ACE-SCG_MDA003-K22.scaffolds.fasta.gz-nt_lca-colors.tab"

output_file="C:\\Users\\Baker\\Google Drive\\Honours\\SingleCellM\\Experimenting\\new_ACE-SCG_MDA003-K22.window_colours.tab"

windows_colour_file_wf(windows_file,francesco_colour_file,output_file)

There were 5588 contigs in the lca colour file. 5890 windows are to be coloured.


In [11]:
colour_data=load_Francesco_colour_file("C:\\Users\\Baker\\Google Drive\\Honours\\SingleCellM\\Francesco_Data\\nt_tab")
write_Donovan_colour_files("C:\\Users\\Baker\\Google Drive\\Honours\\SingleCellM\\Francesco_Data\\Donovan_colours",colour_data)

ACE-SCG_MDA003-M17.scaffolds.fasta
ACE-SCG_MDA005-L4.scaffolds.fasta
ACE-SCG_MDA003-O19.scaffolds.fasta
ACE-SCG_MDA004-K17.scaffolds.fasta
ACE-SCG_MDA004-K10.scaffolds.fasta
ACE-SCG_MDA004-G6.scaffolds.fasta
ACE-SCG_MDA003-O15.scaffolds.fasta
ACE-SCG_MDA006-J10.scaffolds.fasta
ACE-SCG_MDA006-O3.scaffolds.fasta
ACE-SCG_MDA005-J11.scaffolds.fasta
GCA_000986845.1_ASM98684v1_genomic.fasta
ACE-SCG_MDA003-N9.scaffolds.fasta
ACE-SCG_MDA004-N3.scaffolds.fasta
ACE-SCG_MDA004-M13.scaffolds.fasta
ACE-SCG_MDA004-N11.scaffolds.fasta
GCA_001549325.1_SCGC-AAA382N08_genomic.fasta
ACE-SCG_MDA006-P3.scaffolds.fasta
ACE-SCG_MDA003-K22.scaffolds.fasta
ACE-SCG_MDA006-P5.scaffolds.fasta
ACE-SCG_MDA003-M3.scaffolds.fasta
ACE-SCG_MDA004-L10.scaffolds.fasta
ACE-SCG_MDA005-N12.scaffolds.fasta
ACE-SCG_MDA003-N16.scaffolds.fasta


In [11]:
sns.set_style({'font.family':'sans-serif', 'font.sans-serif':'Bitstream Vera Sans'})

print sns.axes_style()

def fix_eps(fpath):
    """Fix carriage returns in EPS files caused by Arial font."""
    txt = b""
    with open(fpath, "rb") as f:
        for line in f:
            if b"\r\rHebrew" in line:
                line = line.replace(b"\r\rHebrew", b"Hebrew")
            txt += line
    with open(fpath, "wb") as f:
        f.write(txt)

def plot_heatmap_proportion(path_id, path_name, KO_genome_files, cmap,output_dir,database_dir,bin_names):
    
    sns.set_style({'font.family':'sans-serif', 'font.sans-serif':'Bitstream Vera Sans'})

    PTH_MO_pairs = load_local_kegg_database_pairings(database_dir,[("pathway","module")], False)[("pathway","module")]
    linkable_base_paths=set(PTH_MO_pairs.iterkeys()) & set(path_id)
    base_mod= set(itertools.chain(*[PTH_MO_pairs[pathway] for pathway in linkable_base_paths]))
    
    MO_KO_pairs  = load_local_kegg_database_pairings(database_dir,[("module","orthology")], False)[("module","orthology")]
    all_modules = {MO:MO_KO_pairs[MO] for MO in base_mod} #kc.link_ids('ko', base_mod)
    module_totals = pd.Series(
        {
            mod_id: len(ko_ids)
            for mod_id, ko_ids in all_modules.iteritems()
        }
    )
    rev_modules = dictionary.reverse_mapping(all_modules)
    
    module_names=load_readable_names(database_dir,["module"],False)["module"]
    
    mod_prop = {}
    for genome_id, KOs in KO_genome_files.iteritems():
        mod_prop[genome_id] = {}
        genome_ko_ids = KOs
        for mod_id, ko_ids in all_modules.iteritems():
            mod_prop[genome_id][mod_id] = len(set(ko_ids) & genome_ko_ids)
    
    mod_prop = pd.DataFrame(mod_prop).fillna(0)
    mod_prop = mod_prop[mod_prop.sum(axis=1) > 0].divide(module_totals, axis='index').dropna()
    mod_prop = mod_prop.rename(index=module_names, columns=bin_names)
    mod_prop=mod_prop.sort_index(axis='columns')
    
    h2 = sns.clustermap(mod_prop, col_cluster=False, method='complete', cmap=cmap)
    
    #font_0=h2.FontProperties()
    #font_0.set_family('sans-serif')
    #font_0.set_style('Bitstream Vera Sans')
    h2.ax_heatmap.set_title(path_name)
    for text in h2.ax_heatmap.get_yticklabels():
        text.set_rotation('horizontal')
    for text in h2.ax_heatmap.get_xticklabels():
        text.set_rotation('vertical')
    #h2.savefig(os.path.join(output_dir,'{}-modules_proportion.pdf'.format(path_name)))
    h2.savefig(os.path.join(output_dir,'{}-modules_proportion.eps'.format(path_name)),format="eps")
    h2.savefig(os.path.join(output_dir,'{}-modules_proportion.svg'.format(path_name)),format="svg")
    fix_eps(os.path.join(output_dir,'{}-modules_proportion.eps'.format(path_name)))
    
colors = ['Blues', 'Greens', 'Oranges', 'Purples', 'Reds', 'Greys','Blues','Blues']
for (path_name, path_id), palette in zip(pathways.iteritems(), colors):
    print path_name, path_id, palette
    plot_heatmap_proportion(path_id, path_name, KO_genome_hits, palette,output_dir,os.path.join(output_dir,"Databases"),bin_names)



def group_micro_plot_heatmap_proprtion(path_id, path_name, KO_genome_files, cmap,output_dir,database_dir,bin_names):
    
    sns.set_style({'font.family':'sans-serif', 'font.sans-serif':'Bitstream Vera Sans'})
    
    PTH_MO_pairs = load_local_kegg_database_pairings(database_dir,[("pathway","module")], False)[("pathway","module")]
    linkable_base_paths=set(PTH_MO_pairs.iterkeys()) & set(path_id)
    base_mod= set(itertools.chain(*[PTH_MO_pairs[pathway] for pathway in linkable_base_paths]))
    
    MO_KO_pairs  = load_local_kegg_database_pairings(database_dir,[("module","orthology")], False)[("module","orthology")]
    all_modules = {MO:MO_KO_pairs[MO] for MO in base_mod} #kc.link_ids('ko', base_mod)
    module_totals = pd.Series(
        {
            mod_id: len(ko_ids)
            for mod_id, ko_ids in all_modules.iteritems()
        }
    )
    rev_modules = dictionary.reverse_mapping(all_modules)
    
    module_names=load_readable_names(database_dir,["module"],False)["module"]
    
    mod_prop = {}
    for genome_id, KOs in KO_genome_files.iteritems():
        mod_prop[genome_id] = {}
        genome_ko_ids = KOs
        for mod_id, ko_ids in all_modules.iteritems():
            mod_prop[genome_id][mod_id] = len(set(ko_ids) & genome_ko_ids)
            
    euk_names=["coral","SymbC15"]
    
    bacteria=set(mod_prop.keys()) - set(euk_names)
    
    new_mod_prop={}
    new_mod_prop["microorganisms"]={}
    for module in mod_prop["coral"].iterkeys():
        new_mod_prop["microorganisms"][module]=0
    for euk in euk_names:
        new_mod_prop[euk]=mod_prop[euk]
    for genome, module_counts in mod_prop.iteritems():
        if genome not in euk_names:
            for module, count in module_counts.iteritems():
                new_mod_prop["microorganisms"][module]=max(new_mod_prop["microorganisms"][module],count)
                
    new_bin_names={gen_id:gen_name for gen_id,gen_name in bin_names.iteritems() if gen_id in euk_names}
    new_bin_names["microorganisms"]="microorganisms"
    bin_names=new_bin_names
    
    mod_prop = pd.DataFrame(new_mod_prop).fillna(0)
    mod_prop = mod_prop[mod_prop.sum(axis=1) > 0].divide(module_totals, axis='index').dropna()
    mod_prop = mod_prop.rename(index=module_names, columns=bin_names)
    mod_prop=mod_prop.sort_index(axis='columns')
    
        
    h2 = sns.clustermap(mod_prop, col_cluster=False, method='complete', cmap=cmap)
    #font_0=h2.FontProperties()
    #h2.set_style({'font.family':'sans-serif', 'font.sans-serif':'Bitstream Vera Sans'})
    #font_0.set_family('sans-serif')
    #font_0.set_style('Bitstream Vera Sans')
    h2.ax_heatmap.set_title(path_name)
    for text in h2.ax_heatmap.get_yticklabels():
        text.set_rotation('horizontal')
    for text in h2.ax_heatmap.get_xticklabels():
        text.set_rotation('vertical')
    #h2.savefig(os.path.join(output_dir,'{}_grouped_microbes_modules_proportion.pdf'.format(path_name)))
    h2.savefig(os.path.join(output_dir,'{}_grouped_microbes_modules_proportion.eps'.format(path_name)),format="eps")
    h2.savefig(os.path.join(output_dir,'{}_grouped_microbes_modules_proportion.svg'.format(path_name)),format="svg")
    fix_eps(os.path.join(output_dir,'{}_grouped_microbes_modules_proportion.eps'.format(path_name)))
    
colors = ['Blues', 'Greens', 'Oranges', 'Purples', 'Reds', 'Greys','Blues','Blues']
for (path_name, path_id), palette in zip(pathways.iteritems(), colors):
    print path_name, path_id, palette
    group_micro_plot_heatmap_proprtion(path_id, path_name, KO_genome_hits, palette,output_dir,os.path.join(output_dir,"Databases"),bin_names)

{'legend.numpoints': 1, 'axes.axisbelow': True, 'font.sans-serif': [u'Bitstream Vera Sans'], 'axes.labelcolor': '.15', 'ytick.major.size': 0.0, 'axes.grid': True, 'ytick.minor.size': 0.0, 'legend.scatterpoints': 1, 'axes.edgecolor': 'white', 'grid.color': 'white', 'legend.frameon': False, 'ytick.color': '.15', 'xtick.major.size': 0.0, 'figure.facecolor': 'white', 'xtick.color': '.15', 'xtick.minor.size': 0.0, 'font.family': [u'sans-serif'], 'xtick.direction': u'out', 'lines.solid_capstyle': u'round', 'grid.linestyle': u'-', 'image.cmap': u'Greys', 'axes.facecolor': '#EAEAF2', 'text.color': '.15', 'ytick.direction': u'out', 'axes.linewidth': 0.0}
oxidative_phosphorylation ['map00190'] Blues
two-component ['map02020'] Greens
vitamins&cofactors ['map00730', 'map00740', 'map00750', 'map00760', 'map00770', 'map00780', 'map00785', 'map00790', 'map00670', 'map00830', 'map00860', 'map00130'] Oranges
AminoAcidMetabolism ['map00250', 'map00270', 'map00260', 'map00280', 'map00290', 'map00300', 'm

In [None]:
print output_dir