In [None]:
import re, glob, os, math, re, scipy, ecopy, random
from concurrent.futures import ProcessPoolExecutor
from collections import defaultdict
from skbio.stats.ordination import pcoa
import numpy as np
import pandas as pd
from Bio import SeqIO, SearchIO
import matplotlib.pyplot as plt
import subprocess as sp
from ete3 import Tree
import seaborn as sns
sns.set('notebook')
%matplotlib inline 
from IPython.display import Markdown, display
import warnings
warnings.filterwarnings('ignore')

In [None]:
def cmdir(path):
    """Create directory `path` if it does not exist yet.

    Missing parent directories are created as well, and an existing
    directory is left untouched, so the call is safe to repeat.

    Raises FileExistsError if `path` exists but is not a directory.
    """
    # makedirs(exist_ok=True) avoids the check-then-create race of
    # isdir() + mkdir() and handles nested paths in a single call
    os.makedirs(path, exist_ok=True)

In [None]:
def scaffold(gene):
    """Return the scaffold part of a gene id of the form "<scaffold>_<number>".

    Returns None for the placeholder string "None". Ids that do not end in
    "_<digits>" (or are not strings) are printed and also yield None.
    """
    if gene == "None":
        return None
    # explicit match handling instead of a bare except around .group()
    match = re.search(r"(.+?)_[0-9]+$", gene) if isinstance(gene, str) else None
    if match is None:
        print(gene)
        return None
    return match.group(1)

In [None]:
rootdir = "TO_FILL_IN"

# gather curated genome data

### read in curated genome metadata

In [None]:
# assigning environment/study
# matching files are sorted by name and the last one is used
# (presumably the newest curated table — confirm the file naming scheme)
current = sorted(glob.glob(rootdir + "metadata/filtered_genome_metadata_curated*"))[-1]
mc = pd.read_csv(current, sep=",")
# drop genomes listed in to_remove
# NOTE(review): `to_remove` is not defined in any of the cells above; it has to
# exist in the kernel already, otherwise this line raises NameError
mc = mc[~mc["newname"].isin(to_remove)]

In [None]:
# define color palette
env2color = {
    "freshwater": "#5B7DAD",
    "sediment": "#312C29",
    "marine": "#10b5a7",
    "soil": "#7a5d1f",
    "engineered": "#B4B5B4",
    "animal-associated": "#A84726",
    "hypersaline": "#d6960b",
    "plant-associated": "#647d37"
}

In [None]:
# explore: number of genomes per lineage (taxcat) and broad environment
# assign() returns a new frame, so the helper column is not written onto a
# slice of mc (the item assignment used before triggers SettingWithCopyWarning)
mcs = mc[["taxcat", "env_broad"]].assign(count=1)
mcg = mcs.groupby(["taxcat", "env_broad"], as_index=False).count()

sns.set_style("ticks", {"axes.edgecolor": "0.8"})
kws = dict(linewidth=.5, edgecolor="black")
# x and y are passed by keyword: current seaborn releases reject positional
# data variables, while keywords work on old and new versions alike
g = sns.relplot(x="env_broad", y="taxcat", data=mcg[mcg["taxcat"]!="None"], size="count", hue="env_broad",
    palette=env2color, alpha=1, height=3, aspect=2, sizes=(50,500), legend="brief", **kws)
plt.xticks(rotation=45, ha="right")
plt.show()

# annotate the genomes

In [None]:
cmdir(rootdir + "protein")
cmdir(rootdir + "protein/prodigal")

### generate protein files

In [None]:
# do this again if tax corrected
# lineages whose proteins are predicted with translation table 25 instead of 11
recoded_lineages = ["Gracilibacteria", "Absconditabacteria"]
prodigal_dir = rootdir + "protein/prodigal/"
calls = []

for _, entry in mc.iterrows():

    genome = glob.glob(rootdir + "genomes/" + entry["newname"] + ".fna")[0]
    fasta_name = os.path.basename(genome)
    code = "25" if entry["taxcat"] in recoded_lineages else "11"

    # the three prodigal outputs share the "<genome>.fa.genes" stem
    proteins = prodigal_dir + fasta_name.replace(".fna", ".fa.genes.faa")
    coords = prodigal_dir + fasta_name.replace(".fna", ".fa.genes")
    genes_nt = prodigal_dir + fasta_name.replace(".fna", ".fa.genes.fna")

    calls.append("prodigal -i " + genome + " -m -g " + code +
        " -a " + proteins + " -o " + coords + " -d " + genes_nt)

In [None]:
def run(call):
    """Run one shell command line and return its exit status."""
    return sp.call(call, shell=True)

with ProcessPoolExecutor(20) as executor:
    # map() only raises worker exceptions when its results are consumed,
    # so collect them instead of discarding the iterator
    exit_codes = list(executor.map(run, calls))

# report commands that exited non-zero instead of failing silently
failed_calls = [call for call, code in zip(calls, exit_codes) if code != 0]
if failed_calls:
    print("%d of %d commands failed:" % (len(failed_calls), len(calls)))
    for failed in failed_calls:
        print(failed)

In [None]:
# concatenate all predicted proteomes into one fasta file
# mode "w" truncates a previous ALL.faa, so no separate remove step is needed
with open(rootdir + "/protein/ALL.faa", "w") as catfile:
    # sorted() makes the record order independent of the file system listing
    for proteome in sorted(glob.glob(rootdir + "protein/prodigal/*faa")):
        # handing SeqIO the path lets it open and close the file itself
        for record in SeqIO.parse(proteome, "fasta"):
            catfile.write(">" + str(record.description) + "\n" + str(record.seq) + "\n")

### run kofam

In [None]:
# launch kofamscan
# builds the sbatch command line that runs exec_annotation on the concatenated
# proteome; the command is only printed here, not executed
# NOTE(review): the -p profile path has no leading "/" unlike the other
# placeholder paths — confirm before use
kocall = "sbatch -J kofamscan --wrap '/path/to/exec_annotation -o " + rootdir + "/protein/kofamscan.latest.txt " + \
    rootdir + "protein/ALL.faa " + "-p path/to/kofam/profiles/prokaryote.hal " + \
    "-k /path/to/kofam/metadata/ko_list --cpu 48 -f detail'"
print(kocall)

In [None]:
# read in output
# one result row: optional "*" flag, gene, KO, threshold, score, e-value, definition
# - raw strings keep \S and \s from being treated as (invalid) string escapes
# - the e-value class lists its characters explicitly; "+-e" inside a class is
#   a range from "+" to "e" and matched far more than intended
kofam_row = re.compile(r"[* ]*(\S+)\s+(\S+)\s+([0-9.-]+)\s+" +
    r"([0-9.-]+)\s+([0-9.eE+-]+)\s(.+?$)")

buffer = []
with open(rootdir + "protein/kofamscan.latest.txt") as kofam_out:
    for line in kofam_out:
        # only skip lines that start with "#"; a "#" inside a KO definition
        # must not drop the row
        if line.startswith("#"):
            continue
        m = kofam_row.search(line.strip())
        if m is None:
            # show rows that do not fit the expected layout
            print(line)
        else:
            buffer.append(m.groups())

kodf = pd.DataFrame.from_records(buffer, columns =["gene", "ko", "threshold", "score", "eval", "def"])
# release the parsed rows
buffer=[]

In [None]:
# filter for significance
# e-values and scores were parsed as text; convert before comparing/sorting
for numeric_col in ("eval", "score"):
    kodf[numeric_col] = kodf[numeric_col].astype(float)
kodf = kodf.loc[kodf["eval"] < 1e-6]
# get best hit per gene based on score: highest score first, first row per gene kept
kfilt = kodf.sort_values(by="score", ascending=False).drop_duplicates(subset="gene")

### run de-novo protein clustering

In [None]:
cmdir(rootdir + "/scripts/")

In [None]:
#scripts available from
# https://github.com/raphael-upmc/proteinClusteringPipeline

In [None]:
# write the three clustering commands into a shell wrapper, one per line
with open(rootdir + "/scripts/runProteinClustering.sh", "w") as wrapper:
    # start by subfamily clustering
    call1 = "subfamilies.py --output-directory " + \
        rootdir + "/protein/protein_clustering/ --cpu 48 " + rootdir + "protein/ALL.faa"
    # then do hmm-hmm comparison to generate families
    call2 = "hhblits.py --cpu 48 " + \
       rootdir + "protein/protein_clustering/config.json"
    # then run runningMclClustering.py on the same config.json
    call3 = "runningMclClustering.py --force --coverage 0.70 " + \
        "--fasta --cpu 48 " + rootdir + "protein/protein_clustering/config.json"
    # commands are joined with newlines; no shebang line and no trailing newline are written
    wrapper.write("\n".join([call1,call2, call3]))

Make the wrapper executable (`chmod +x runProteinClustering.sh`), then submit it with `sbatch`.

### trim

In [None]:
# build scaf2bin: scaffold id (header text before the first space) -> genome name
scaf2bin = {}

for genome in glob.glob(rootdir + "genomes/*"):
    name = os.path.basename(genome).replace(".fna", "")
    # pass the path so SeqIO closes each genome file after parsing
    # (open(genome) left one handle open per genome)
    for record in SeqIO.parse(genome, "fasta"):
        scaf2bin[record.description.split(" ")[0]] = name

In [None]:
# apply scaf2bin then filter original df
# scaffold() strips the trailing "_<number>" from each gene id
kfilt["scaffold"] = kfilt["gene"].apply(scaffold)
# scaffolds missing from scaf2bin map to NaN
kfilt["bin"] = kfilt["scaffold"].map(scaf2bin)
# NOTE(review): `to_remove` is not defined in the cells above — confirm where it comes from
kfilt = kfilt[~kfilt["bin"].isin(to_remove)]

### reconcile

In [None]:
# reproduce consistency plot from Meheust et al.
# orf2family: ORF id (column 1) -> protein family (column 2)
orf2family = {}

with open(rootdir + "/protein/protein_clustering/orf2family.tsv") as handle:
    # the first line is a header
    next(handle, None)
    for line in handle:
        fields = line.split("\t")
        orf2family[fields[0]] = fields[1].strip()

In [None]:
# per KO: percentage of its genes that fall into the KO's most common protein family
cons = []

# one grouped pass instead of re-filtering the whole table for every KO;
# sort=False keeps the order of first appearance
for ko, ko_genes in kfilt.groupby("ko", sort=False)["gene"]:

    genes = ko_genes.to_list()
    fams = [orf2family[gene] for gene in genes if gene in orf2family]

    # require more than 5 annotated genes and at least one clustered gene
    # (an empty family list has no mode)
    if len(genes) > 5 and fams:
        # Series.mode() works on string labels; scipy.stats.mode(...).mode[0]
        # fails on them in recent SciPy releases
        mode = pd.Series(fams).mode().iloc[0]
        p = fams.count(mode)/float(len(genes)) * 100
        cons.append(p)

In [None]:
# rank KOs by consistency (highest first) and plot rank against consistency
cdf = pd.DataFrame(sorted(cons, reverse=True)).reset_index()
cdf.columns = ["rank", "consistency"]
sns.set_style("ticks")
kws = {'s':20, 'alpha':0.2}
# x and y by keyword: current seaborn releases reject positional data variables
sns.regplot(x="rank", y="consistency", data=cdf, color="blue", scatter_kws=kws, fit_reg=False)
plt.show()

In [None]:
# reconcile with kofams? remember sub threshold hits too
# fam2ko: family -> most common KO among its annotated members ("ko") and the
# fraction of all family members carrying that KO ("percent")

fam2ko = {}
# gene -> KO lookup built once, replacing an isin() scan of the whole table per
# family. kfilt was deduplicated on "gene" when it was built, so each gene has
# a single KO here — NOTE(review): re-check if kfilt is ever rebuilt differently
gene2ko = dict(zip(kfilt["gene"], kfilt["ko"]))

for fam in glob.glob(rootdir + "protein/protein_clustering/familiesFasta/*"):

    # pass the path so SeqIO closes the file after parsing
    orfs = [record.description.split(" ")[0] for record in SeqIO.parse(fam, "fasta")]
    # set(): each annotated ORF counts once, as with the isin() filter
    kos = [gene2ko[orf] for orf in set(orfs) if orf in gene2ko]

    if len(kos) > 0:
        mode = pd.Series(kos).mode()[0]
        p = kos.count(mode)/float(len(orfs))
        fam2ko[os.path.basename(fam).replace(".fa","")] = {"ko": mode, "percent": round(p,3)}

In [None]:
# rank families by the fraction of members sharing the dominant KO and plot it
kps = [entry["percent"] for entry in fam2ko.values()]
kpf = pd.DataFrame(sorted(kps, reverse=True)).reset_index()
kpf.columns = ["rank", "consistency"]
sns.set_style("ticks")
kws = {'s':20, 'alpha':0.2}
# x and y by keyword: current seaborn releases reject positional data variables
sns.regplot(x="rank", y="consistency", data=kpf, color="blue", scatter_kws=kws, fit_reg=False)
plt.show()

### supp table 5

In [None]:
# per-family summary: family name, number of sequences, median protein length
family_fastas = glob.glob(rootdir + "protein/protein_clustering/familiesFasta/*")
pfam_info = defaultdict(list)

for fasta_path in family_fastas:

    seq_lengths = []
    for rec in SeqIO.parse(open(fasta_path), "fasta"):
        seq_lengths.append(len(rec.seq))

    pfam_info["family"].append(os.path.basename(fasta_path).replace(".fa",""))
    pfam_info["num_seqs"].append(len(seq_lengths))
    pfam_info["median_protein_len"].append(np.median(seq_lengths))

pfdf = pd.DataFrame(pfam_info)
# subset by size: keep families with at least 5 sequences
pfsub = pfdf.loc[pfdf["num_seqs"] >= 5]

In [None]:
# attach the dominant KO and its fraction to each family; no `on=` is given, so
# pandas joins on the columns both frames share (here "family")
pfmerge = pfsub.merge(pd.DataFrame.from_dict(fam2ko, orient="index").reset_index().rename(columns={"index":"family"}), how="left")
# median score and e-value per KO across all filtered hits
ko_medians = kfilt.groupby("ko", as_index=False).aggregate({"score":"median", "eval":"median"}).rename(columns={"score":"median_score", "eval": "median_eval"})
# add the KO medians, threshold and definition (shared column "ko"); families
# without a KO get the string "None" in every missing field
pfmerge = pfmerge.merge(ko_medians, how="left").merge(kfilt[["ko", "threshold", "def"]].drop_duplicates(), how="left").fillna("None").sort_values("family")
# final column order and names for the supplementary table
pfmerge = pfmerge[["family", "num_seqs", "median_protein_len", "ko", "def", "percent", "threshold", "median_score", "median_eval"]].rename(columns={"ko":"kegg_orthology", "def": "kegg_definition", "percent": "fraction_seqs_annotated", "threshold":"kegg_threshold"})

In [None]:
pfmerge.to_csv(rootdir + "protein/supp_table_5.csv", index=False)

# making the 16RP tree

In [None]:
# extract 16 RP results from kofam results
# crappy filter for now - generate terms
rp16 = ["S8","L5","L18","S3","L22","S10","S19","L14","L15","L24","L16","L2","L3","S17","L6","L4"]
# the "$" anchor makes each term match only at the end of the definition,
# so e.g. "L2" does not also match "L22"
terms = ["subunit ribosomal protein " + term + "$" for term in rp16]
# filter; copy() gives k16 its own data so that columns added to it later are
# not written onto a slice of kfilt
k16 = kfilt[kfilt["def"].str.contains('|'.join(terms))].copy()
# check results: number of distinct definitions that matched
len(k16["def"].unique())

In [None]:
# selection step - one per genome, same contig?
# work on an independent copy and drop a mode column left by an earlier run of
# this cell, so the cell can be re-run without piling up suffixed columns
k16 = k16.copy().drop(columns=["mode_scaf"], errors="ignore")
k16["scaffold"] = k16["gene"].apply(scaffold)
k16["bin"] = k16["scaffold"].map(scaf2bin)
# get mode scaffold per bin
# Series.mode() handles string labels (ties resolve to the smallest value);
# scipy.stats.mode(x).mode[0] fails on them in recent SciPy releases
modes = k16.groupby("bin", as_index=False).aggregate({"scaffold": lambda x: x.mode().iloc[0]})
modes.columns = ["bin", "mode_scaf"]
k16 = k16.merge(modes, how="left", on="bin")
# is hmm on the mode scaf?
k16["scafscore"] = k16["scaffold"] == k16["mode_scaf"]
# sort and dereplicate preferencing those on mode scaf, then the higher score
rpfilt = k16.sort_values(["bin", "ko", "scafscore", "score"], ascending=[False,False,False,False]).drop_duplicates(["bin", "ko"])

In [None]:
# number of distinct scaffolds carrying the selected markers, per genome
scaf_counts = [
    len(set(rpfilt[rpfilt["bin"]==genome_bin]["scaffold"]))
    for genome_bin in rpfilt["bin"].unique()
]

sns.distplot(scaf_counts, kde=False, bins=10)
plt.show()

In [None]:
# genome x KO table of selected gene ids; missing markers become the string "None"
# (pivot arguments by keyword: they are keyword-only in current pandas)
rpiv = rpfilt.pivot(index="bin", columns="ko", values="gene").fillna("None")

### build outgroup

In [None]:
outdir= rootdir + "/protein/rp16/"
cmdir(outdir)

In [None]:
# read in tables
supp = pd.read_csv(rootdir + "protein/bmc_supp.tsv", sep="\t")
tree = pd.read_csv(rootdir + "metadata/crossenv_phy.txt", 
    header=None, names=["phylum", "group"]).fillna("None")

# subset
references = []
# immediate context
references += supp[supp["revised_tax"].isin(["Kazanbacteria", "Peregrinibacteria", "Peribacteria", 
    "Berkelbacteria", "Howlettbacteria", "Abawacabacteria"])]["name"].to_list()
#superphyla
# dedicated, seeded generator: the same outgroup genomes are drawn on every run
# and the global random state used elsewhere is left alone
outgroup_rng = random.Random(42)
for supergroup in ["Microgenomates", "Parcubacteria"]:
    subtree = tree[tree["group"].str.contains(supergroup)]
    subtable = supp[supp["revised_tax"].isin(subtree["phylum"].to_list())]
    candidates = subtable["name"].to_list()
    # never ask for more genomes than are available (sample() would raise)
    references += outgroup_rng.sample(candidates, min(25, len(candidates)))
    
print(len(references))

In [None]:
# scaffold id -> genome name, built like scaf2bin except that the genome name
# is the file basename cut at its first "." (scaf2bin only removes ".fna")
scaf2cpr = {}
for genome in glob.glob(rootdir + "genomes/*"):
    for record in SeqIO.parse(open(genome), "fasta"):
        # key: header text before the first space
        scaf2cpr[record.description.split(" ")[0]] = os.path.basename(genome).split(".")[0]

In [None]:
# generate outgroup files
# HMM accession (first dot-separated field of the reference file name) -> marker name
mapping = {"PF00410": "rpS8","PF00281":"rpL5","TIGR00060":"rpL18","TIGR01009":"rpS3",
            "TIGR01044":"rpL22","TIGR01049":"rpS10","TIGR01050":"rpS19","TIGR01067":"rpL14",
            "TIGR01071":"rpL15","TIGR01079":"rpL24","TIGR01164":"rpL16","TIGR01171":"rpL2",
            "TIGR03625":"rpL3","TIGR03635":"rpS17","TIGR03654":"rpL6","TIGR03953":"rpL4"}

# set membership instead of scanning the references list for every record
reference_set = set(references)

for reffile in glob.glob(rootdir + "reference_genomes/bac175/*.BAC175.concat.faa"):
    # these two accessions have no entry in mapping and are skipped
    if ("TIGR02013" not in reffile) and ("TIGR02386" not in reffile):
        marker = mapping[os.path.basename(reffile).split(".")[0]]
        with open(rootdir + "protein/rp16/" + marker + ".CPR.faa", "w") as out:
            # pass the path so SeqIO closes the reference file after parsing
            for record in SeqIO.parse(reffile, "fasta"):
                if "@" not in record.description:
                    # genes whose scaffold is unknown fall back to "None",
                    # which is never a reference genome
                    genome_bin = scaf2cpr.get(scaffold(record.description.split(" ")[0]), "None")
                    if genome_bin in reference_set:
                        out.write(">" + record.description + "\n" + str(record.seq) + "\n")

In [None]:
# KO <-> marker name ("rp" + last word of the KO definition), in both directions
unique_kos = k16.drop_duplicates(["ko", "def"])

forward_kegg_mapping = {
    ko: "rp" + definition.split(" ")[-1]
    for ko, definition in zip(unique_kos["ko"], unique_kos["def"])
}
reverse_kegg_mapping = {
    "rp" + definition.split(" ")[-1]: ko
    for ko, definition in zip(unique_kos["ko"], unique_kos["def"])
}

In [None]:
# generate corresponding rpiv
# outgroups: genome -> {KO column: gene id}, collected from the outgroup fasta files
outgroups = {}

for reffile in glob.glob(rootdir + "protein/rp16/rp*CPR*"):

    # the file name starts with the marker name; translate it back to its KO
    col = reverse_kegg_mapping[os.path.basename(reffile).split(".")[0]]

    for record in SeqIO.parse(open(reffile), "fasta"):

        gene_id = record.description.split(" ")[0]
        try: taxon = scaf2cpr[scaffold(gene_id)]
        except: taxon="None"

        if taxon == "None":
            continue
        outgroups.setdefault(taxon, {})[col] = gene_id

outdf = pd.DataFrame.from_dict(outgroups, orient="index")

In [None]:
len(outdf)

### pull + align individual genes

In [None]:
# write one shell wrapper with four commands per marker KO:
# pull sequences, add outgroup sequences, align, trim
with open(outdir + "wrapper.sh", "w") as wrapper:
    
    for col in rpiv.columns:
        # only columns whose name contains "K0" are processed
        if "K0" in col:# write names
            file = outdir + col
            # list of gene ids selected for this marker, one per line
            with open(file + ".names.txt", "w") as names:
                for key, row in rpiv.iterrows():
                    if row[col] != "None":
                        names.write(row[col] + "\n")
            # pullseq
            call1 = "pullseq -n " + file + ".names.txt " + \
                "-i " + rootdir + "protein/ALL.faa > " + file + ".faa"
            wrapper.write(call1 + "\n")
            
            # merge with CPR references for now
            # first file matching "<marker>.CPR*"; raises IndexError if none exists
            reffile = glob.glob(rootdir + "protein/rp16/" + forward_kegg_mapping[col] + ".CPR*")[0]
            call2 = "cat " + file + ".faa " + reffile + \
                " > " + file + ".concat.faa"
            wrapper.write(call2 + "\n")
            
            # mafft
            call3 = "mafft --thread 16 --retree 2 --reorder " + file + \
                ".concat.faa > " + file + ".mafft"
            wrapper.write(call3 + "\n")
            
            # bmge
            call4 = "java -jar BMGE.jar -i " + file + ".mafft" + \
                " -t AA -m BLOSUM30 -of " + file + ".bmge.mafft"
            wrapper.write(call4 + "\n")       

In [None]:
# merge in outgroups
rpiv = pd.concat([rpiv, outdf]).fillna("None")

In [None]:
# NOTE: this is an alias, not a copy — columns added to merged_seq also appear on rpiv
merged_seq = rpiv
# marker name -> alignment length
aln_lens = {}

def get_sequence(gene, seq_dict):
    """Look up the sequence stored for `gene`.

    gene     -- gene id, or the placeholder string "None" for a missing marker
    seq_dict -- mapping of gene id -> sequence string

    Returns the sequence, or the string "None" when `gene` is the placeholder
    or is absent from `seq_dict` (a message is printed in the latter case).
    """
    if gene == "None":
        return "None"
    try:
        return seq_dict[gene]
    # only a missing key is expected here; any other error should surface
    except KeyError:
        print("%s not found!" %(gene))
        return "None"
            
# add sequences to merged df
for trimmed_alignment in glob.glob(outdir + "*bmge*"):
    # first read in trimmed sequences, keyed by the header text before the first whitespace
    temp_dict = {}
    # tracked explicitly instead of reading the loop variable after the loop,
    # which is undefined (or stale from the previous file) for an empty alignment
    aln_len = 0
    # pass the path so SeqIO closes the file after parsing
    for record in SeqIO.parse(trimmed_alignment, "fasta"):
        # pull clean headers
        m = re.search(r"(\S+).*", record.description)
        temp_dict[m.group(1)] = str(record.seq)
        aln_len = len(record.seq)
    if not temp_dict:
        print("%s contains no sequences!" %(trimmed_alignment))
    # now add to the dataframe using apply
    hmm = os.path.basename(trimmed_alignment).split(".")[0]
    col_name = hmm + "_seq"
    merged_seq[col_name] = merged_seq[hmm].apply(lambda x: get_sequence(x, temp_dict))
    # get aln len to use later
    aln_lens[hmm] = aln_len

In [None]:
# add in counts to use below
kos = list(rpfilt["ko"].unique())

def count_markers(row):
    """Return how many of the marker columns in `kos` hold a gene for this row.

    A marker counts as present when its value differs from the string "None".
    """
    return sum(1 for hmm in kos if row[hmm] != "None")

merged_seq["rp16_count"] = merged_seq.apply(count_markers, axis=1)

In [None]:
# minimum number of markers a genome needs to be written out, per dataset
count_mins = {"rp16": 8}

# finally write out the concatenated alignment, pruning previously identified taxa as before
for dataset in ["rp16"]:
    
    filename = outdir + dataset + "_crossenv.mafft"
    with open(filename, "w") as outfile:

        # for each genome meeting criteria
        # reset_index() exposes the genome name as the "index" column
        for key, row in merged_seq.reset_index().iterrows():
            # if genome meets min gene count threshes
            # NOTE(review): `to_remove` is not defined in the cells above — confirm its source
            if (row[dataset + "_count"] >= count_mins[dataset]) and (row["index"] not in to_remove):
                outfile.write(">" + row["index"] + "\n")
                # now write out sequences
                # markers are concatenated in the order of `kos`, identical for every genome
                for hmm in kos:
                    col_name = hmm + "_seq"
                    # if missing gene, just add gaps
                    if row[col_name] == "None":
                        outfile.write("-"*aln_lens[hmm])
                    # if gene present
                    else:
                        outfile.write(row[col_name])
                outfile.write("\n")

In [None]:
# and run the trees
# the sbatch lines are only printed; submit them by hand
for concat_align in glob.glob(outdir + "rp16*mafft*"):

    # output prefix = alignment path without its extension(s). Only the file
    # name is split on ".": splitting the whole path truncates it at the first
    # dot anywhere, e.g. in "./data/" or a directory called "v1.2"
    basename = os.path.basename(concat_align).split(".")[0]
    prefix = os.path.join(os.path.dirname(concat_align), basename)
    call = "sbatch -J iqtree --wrap 'iqtree -s " + concat_align + " -m MFP -st AA -bb 1000 -nt AUTO -pre " + prefix + "'"
    print(call)

In [None]:
# decorate the tree
# writes a TREE_COLORS annotation file (one "range" line per leaf, coloured by
# broad environment) and a copy of the tree whose leaf names carry the narrow
# environment as suffix
info_dict = {row["newname"]:{"name": row["name"], "env_broad": row["env_broad"], "env_narrow": row["env_narrow"]} for key, row in mc.iterrows()}
t = Tree(rootdir + "protein/rp16/rp16_crossenv.treefile")

# the with-block closes the annotation file even if a leaf fails
with open(rootdir + "/protein/rp16/rp16_crossenv.envbroad.txt", "w") as itol:
    itol.write("TREE_COLORS\nSEPARATOR TAB\nDATA\n")

    for leaf in t:
        info = info_dict.get(leaf.name)
        # leaves without metadata, or with an environment that has no colour,
        # get the neutral defaults (explicit checks instead of a bare except)
        if info is not None and info["env_broad"] in env2color:
            color = env2color[info["env_broad"]]
            label = info["env_broad"]
            env_narrow = info["env_narrow"]
        else:
            color = "white"
            label = "None"
            env_narrow = "None"
        leaf.name = leaf.name + "_" + env_narrow
        itol.write(leaf.name + "\trange\t" + color + "\t" + label + "\n")

t.write(outfile=rootdir + "protein/rp16/rp16_crossenv.renamed.treefile", format=2)

In [None]:
# parse phy groups + tree order
# two-column file without header: leaf label, phylogenetic group
phy = pd.read_csv(rootdir + "metadata/crossenv_phy.csv", header=None)
phy.columns = ["leaf", "phy"]

def scrub(leaf):
    """Turn a tree leaf label back into a genome name.

    For every env_narrow value in mc (plus "soda"), the text from the last
    occurrence of that value onwards is cut off and the remainder is stripped;
    remaining spaces are then replaced by underscores. The categories are
    applied one after another, so the result depends on their order.
    """
    for cat in list(mc["env_narrow"].unique()) + ["soda"]:
        leaf = leaf.rsplit(str(cat), 1)[0].strip()
    return leaf.replace(" ", "_")

phy["newname"] = phy["leaf"].apply(scrub)
# NOTE: mc is overwritten by its own merge result, so running this cell twice
# adds suffixed duplicate columns; genomes without a group get "None"
mc = mc.merge(phy, on="newname", how="left").fillna("None")

# gene content analyses

### overall similarity

In [None]:
# reformat protein dict
# one row per ORF; the protein family is stored in a column called "ko" so the
# frame has the same column names as kfilt
pdf = pd.DataFrame.from_dict(orf2family, orient="index").reset_index()
pdf.columns = ["gene", "ko"]
# raises KeyError for a gene whose scaffold is missing from scaf2bin
pdf["bin"] = pdf["gene"].apply(lambda x: scaf2bin[scaffold(x)])
# "bin" is only used for filtering here and dropped again afterwards
# NOTE(review): `to_remove` is not defined in the cells above — confirm its source
pdf = pdf[~pdf["bin"].isin(to_remove)].drop("bin", axis=1)
pdf.head()

In [None]:
def run_pcoa(df, fam_size, phy):
    """Run a PCoA on Jaccard distances between genomes (family presence/absence).

    df       -- DataFrame with "gene" and "ko" columns, one row per gene
    fam_size -- minimum summed gene count a "ko" column needs to be kept
    phy      -- "taxcat" value to restrict the genomes to, or "all"

    Returns (results, pcresults): the object returned by pcoa() and its
    `samples` table as a DataFrame with an added "newname" column holding the
    genome names.

    Side effect: adds "bin" and "taxcat" columns to `df` in place. Later cells
    read those columns from the same frame, so this must not be removed
    without updating them. Relies on the notebook globals scaf2bin, scaffold and mc.
    """

    # reconfigure df (in place, see docstring)
    df["bin"] = df["gene"].apply(lambda x: scaf2bin[scaffold(x)])
    # add tax names
    bin2tax = dict(zip(mc["newname"], mc["taxcat"]))
    df["taxcat"] = df["bin"].map(bin2tax)

    # define dataframe
    if phy != "all":
        sub = df[df["taxcat"]==phy][["gene", "ko", "bin"]]
    else:
        sub = df[["gene", "ko", "bin"]]

    # genes per genome and family
    gb = sub.groupby(["bin", "ko"], as_index=False).count()
    # pivot arguments by keyword: they are keyword-only in current pandas
    piv = gb.pivot(index="bin", columns="ko", values="gene").fillna(0)
    # filter out low count annotations
    piv = piv[piv.columns[piv.sum(axis=0) >= fam_size]]
    # presence/absence
    pivb = piv > 0
    #calculate distance matrix
    jac = ecopy.distance(pivb, method='jaccard', transform='1')
    # then use skbio to do pcoA
    results = pcoa(jac)
    pcresults = pd.DataFrame(results.samples)
    pcresults["newname"] = piv.index

    return results, pcresults

In [None]:
# collect PC1/PC2 coordinates with one grouping variable per table, for plotting
buffer = []

for dataset in ["pclust"]: #"kofam"
    
    df = kfilt if dataset == "kofam" else pdf
    # families with at least 5 genes, Saccharibacteria genomes only
    # (run_pcoa also adds "bin" and "taxcat" columns to df in place)
    full, pcresult = run_pcoa(df, 5, "Saccharibacteria")
    
    for contrast in ["env_broad", "phy"]:
        
        rm = pcresult.merge(mc[["newname", contrast]], 
            on="newname", how="left")[["PC1", "PC2", contrast]]
        rm.columns = ["PC1", "PC2", "contrast"]
        rm["dataset"] = dataset
        rm["contrast_type"] = contrast
        buffer.append(rm)
        # NOTE(review): this break leaves the contrast loop after "env_broad",
        # so the "phy" contrast is never added — confirm this is intended
        break

In [None]:
# configure palette
# build the Set2 hex list once instead of once per group; the index wraps
# around, so more groups than palette entries reuse colours instead of
# raising IndexError
set2_hex = sns.color_palette("Set2").as_hex()
phy2color = {phy: set2_hex[i % len(set2_hex)] for i, phy in enumerate(mc["phy"].unique()) if phy!= "None"}
phy2color["None"] = "lightgrey"
# one lookup for environments and groups (group colours win on a name clash)
merged_palette = {**env2color, **phy2color}
# add taxcats
merged_palette["Saccharibacteria"] = "#4c72b0"
merged_palette["Absconditabacteria"] = "#dd8452"
merged_palette["Gracilibacteria"] = "#55a868"

In [None]:
# plot
# one facet row per contrast_type, points coloured by the contrast value
sns.set_style("ticks")
kws = {'s':50, 'alpha':1, "edgecolor":"black", "linewidth":0.25}
#p = {True: "blue", False: "lightgrey"}
# missing contrast values are shown as the "None" category
g = sns.FacetGrid(pd.concat(buffer).fillna("None"), hue="contrast", palette=merged_palette,
    row="contrast_type",height=4, sharex=False, sharey=False, aspect=1.5)
# NOTE(review): x_jitter / y_jitter are forwarded to scatterplot — confirm they have an effect there
g = g.map(sns.scatterplot, "PC1", "PC2", x_jitter=.01, y_jitter=.01, **kws)
g.set_titles('{row_name}').add_legend()
plt.savefig(rootdir + "figures/sac_pcoas.svg", format="svg", bbox_inches="tight")
plt.show()

### high-level

In [None]:
# sac proteome size by environment/quality
quality = pd.read_csv(rootdir + "metadata/genomeInformation.csv")
quality["newname"] = quality["genome"].apply(lambda x: x.replace(".fna", ""))
# start from the genome metadata table: `m` is never created as a DataFrame in
# the cells above (it last held a regex match), so `m.merge(...)` cannot run on
# a fresh kernel
# NOTE(review): assumes mc carries the "genome_size" and "orf#" columns that
# are read from `m` further down — confirm
m = mc.merge(quality[["newname", "completeness", "contamination"]], on="newname", how="left")

In [None]:
# long table with one copy of every genome per completeness cutoff it passes
# (70%, 75%, ..., 95%), labelled with that cutoff
comp = defaultdict(list)
kept_fields = ["newname", "taxcat", "env_broad", "genome_size", "orf#"]

for cutoff in range(70, 100, 5):

    passing = m[m["completeness"]>=cutoff]
    for field in kept_fields:
        comp[field].extend(passing[field].tolist())
    comp["threshold"].extend([str(cutoff)+"%"] * len(passing))

compdf = pd.DataFrame(comp)

In [None]:
# proteome size per environment for one lineage, split by completeness cutoff
taxcat = "Saccharibacteria"
# environments ordered by median ORF count at the 95% cutoff, largest first
order = compdf[(compdf["taxcat"]==taxcat)].query("threshold=='95%'").groupby("env_broad", as_index=False).aggregate({"orf#":"median"}).sort_values("orf#", ascending=False)["env_broad"].to_list()

sns.set_style("ticks")
plt.figure(figsize=(4,7))
# x and y by keyword: current seaborn releases reject positional data variables
sns.boxplot(x="orf#", y="env_broad", hue="threshold", order=order, palette="Blues", linewidth=0.5, fliersize=0, 
    data=compdf[compdf["taxcat"]==taxcat])
sns.stripplot(x="orf#", y="env_broad", hue="threshold", order=order, color="grey",data=compdf[compdf["taxcat"]==taxcat], dodge=True, size=3)
plt.xlabel("proteome size (orfs)")
plt.ylabel("")
plt.legend(bbox_to_anchor=(1, 1), loc='upper left')
# a boolean switches the grid on; the string 'on' is not a supported value in current matplotlib
plt.grid(True, which='major', axis='x')
plt.savefig(rootdir + "figures/" + taxcat.lower() + "_orfcount.svg", format="svg", bbox_inches="tight")
plt.show()

### heatmap

In [None]:
# generate color mappings - bin 2 cols for each contrast
tax2phy = {"Saccharibacteria": "#4c72b0", "Absconditabacteria": "#dd8452", "Gracilibacteria": "#55a868", "None":"white"}
# genome -> colour of its broad environment / of its phylogenetic group
bin2eb = {name: merged_palette[env] for name, env in zip(mc["newname"], mc["env_broad"])}
bin2group = {name: merged_palette[group] for name, group in zip(mc["newname"], mc["phy"])}
# read in tree order: row position of every genome in the tree-ordered table
tree_order = phy["newname"].to_list()
tdf = pd.DataFrame(tree_order).reset_index()
tdf.columns = ["position", "bin"]
# genome -> lineage
bin2tax = dict(zip(mc["newname"], mc["taxcat"]))

In [None]:
def clustermap(df, min_fam_size, binary):
    """Draw the genome x family heatmap with clustered columns.

    df           -- DataFrame with "gene", "ko" and "bin" columns
    min_fam_size -- minimum summed gene count a family needs to be shown
    binary       -- True for presence/absence, otherwise log10 of the counts

    Rows follow the order stored in the global `tdf`; columns are clustered
    (average linkage, Jaccard). The figure is saved to figures/heatmap.png.

    Returns (fpiv, reordered): the plotted matrix and the clustered column
    order as indices into its columns. Relies on the notebook globals tdf,
    bin2eb, bin2group and rootdir.
    """

    # reconfigure df: genes per genome and family
    sub = df[["gene", "ko", "bin"]]
    gb = sub.groupby(["bin", "ko"], as_index=False).count()
    # pivot arguments by keyword: they are keyword-only in current pandas
    piv = gb.pivot(index="bin", columns="ko", values="gene").fillna(0)
    # filter out low count annotations
    piv = piv[piv.columns[piv.sum(axis=0) >= min_fam_size]]
    if binary == True:
        piv = piv > 0
        piv = piv.replace(True, 1).reset_index()
    else:
        # zeros become -99 so that log10 yields NaN, which is then set to 0
        piv = piv.replace(0, -99).apply(np.log10).fillna(0)

    # reorder rows to follow the tree
    piv = piv.merge(tdf, on="bin", how="right").sort_values("position")
    # set row colors
    env_colors = piv["bin"].map(bin2eb)
    group_colors = piv["bin"].map(bin2group)
    # genomes present in tdf but absent from the matrix are all-NaN rows and are dropped
    fpiv = piv.drop(["bin", "position"], axis=1).dropna()

    #plot clustergram
    g = sns.clustermap(fpiv, figsize=(15,7), row_cluster=False, method='average', 
        metric='jaccard', row_colors=[group_colors, env_colors], cmap="Blues",
                   cbar_pos = None, dendrogram_ratio=0.07)
    for a in g.ax_col_dendrogram.collections:
        a.set_linewidth(0.25)
    plt.axis("off")
    plt.savefig(rootdir + "figures/heatmap.png", format="png", dpi=300)

    # return original/clustered column names
    return fpiv, g.dendrogram_col.reordered_ind

In [None]:
matrix, column_indices = clustermap(pdf, 5, True)

### find differentially distributed families

In [None]:
import scipy.stats as spstats
from statsmodels.stats.multitest import multipletests

In [None]:
# per lineage, family and environment: how many genomes inside vs outside the
# environment encode the family, with Fisher's exact test and FDR correction
diffs = defaultdict(list)

# NOTE: pdf must already carry the "bin" and "taxcat" columns (run_pcoa adds them in place)
for phylum in pdf["taxcat"].unique():

    pvals = []
    if phylum !="None":

        table = pdf[pdf["taxcat"]==phylum]
        meta = mc[mc["taxcat"]==phylum]
        # merge in env data
        table = table.merge(mc[["newname", "env_broad"]], 
            left_on="bin", right_on="newname", how="left")
        # these do not change inside the loops, so compute them once
        fams = table["ko"].unique()
        envs = table["env_broad"].unique()

        for i, fam in enumerate(fams):

            subtable = table[table["ko"]==fam]

            # match with clustergram output
            if fam in matrix.columns:

                # presence/absence: one row per genome
                subtable = subtable.drop_duplicates(["newname", "env_broad"])

                for env in envs:
                    diffs["taxcat"].append(phylum)
                    diffs["fam"].append(fam)
                    diffs["env"].append(env)

                    # genomes encoding the family / all genomes, inside and outside env
                    in_num = len(subtable[subtable["env_broad"]==env])
                    in_total = len(meta[meta["env_broad"]==env])
                    out_num = len(subtable[subtable["env_broad"]!=env])
                    out_total = len(meta[meta["env_broad"]!=env])
                    diffs["in_num"].append(in_num)
                    # NOTE(review): raises ZeroDivisionError if a lineage has
                    # no genome inside (or outside) this environment
                    diffs["in_perc"].append(in_num/in_total)
                    diffs["out_perc"].append(out_num/out_total)

                    if out_num ==0: # if exclusive
                        diffs["ratio"].append("None")
                        diffs["exclusive"].append(True)
                    else: # if not exclusive
                        diffs["ratio"].append((in_num/in_total)/(out_num/out_total))
                        diffs["exclusive"].append(False)

                    # compute fisher's exact statistic
                    contable = [[in_num, out_num], [in_total-in_num, out_total-out_num]]
                    oddsratio, pvalue = spstats.fisher_exact(contable, alternative='two-sided')
                    diffs["fisher_exact"].append(pvalue)
                    pvals.append(pvalue)

            print('processed %d of %d for %s\r'%(i, len(fams), phylum), end="")

    #fdr correction, per lineage
    # skipped when nothing was tested (e.g. the "None" lineage): there is
    # nothing to correct, and no rows were added for this lineage either
    if pvals:
        diffs["fisher_fdr"] += list(multipletests(pvals, method="fdr_bh")[1])

In [None]:
diffdf = pd.DataFrame(diffs)
diffdf = diffdf.merge(pd.DataFrame.from_dict(fam2ko, orient="index").reset_index().rename(columns={"index":"fam"}), how="left")
diffdf = diffdf.merge(kfilt[["ko", "def"]].drop_duplicates(), how="left").fillna("None")

In [None]:
# collect families with a significant (FDR <= 0.05) habitat pattern;
# a family may be added more than once, hence the set() in the final count
families_tokeep=[]
# exclusive pfams in at least x% of ingroup genomes
families_tokeep += diffdf[(diffdf["exclusive"]==True) & (diffdf["fisher_fdr"]<=0.05)]["fam"].to_list()
# enriched pfams - x% ingroup but at least yfold enrichment
# rows with the placeholder ratio "None" (exclusive) or ratio 0 are removed
# first so that the numeric comparison below works
both = diffdf[(diffdf["ratio"]!='None') & (diffdf["ratio"]!=0)]
families_tokeep += both[(both["ratio"]>=5) & (both["fisher_fdr"]<=0.05)]["fam"].to_list()
# finally, depleted - majority outgroup + <x% ingroup
families_tokeep += diffdf[(diffdf["in_perc"]<=0.10) & (diffdf["out_perc"]>=0.5) & (diffdf["fisher_fdr"]<=0.05)]["fam"].to_list()
print(len(set(families_tokeep)))

### supp table 6

In [None]:
# same three selections as for families_tokeep, labelled by distribution type.
# copy() gives every subset its own data, so adding the "type" column does not
# write onto a slice of diffdf
exclusive = diffdf[(diffdf["exclusive"]==True) & (diffdf["fisher_fdr"]<=0.05)].copy()
# exclusive families are reported as "enriched"; the separate "exclusive"
# column tells them apart
exclusive["type"] = "enriched"
both = diffdf[(diffdf["ratio"]!='None') & (diffdf["ratio"]!=0)]
enriched = both[(both["ratio"]>=5) & (both["fisher_fdr"]<=0.05)].copy()
enriched["type"] = "enriched"
depleted = diffdf[(diffdf["in_perc"]<=0.10) & (diffdf["out_perc"]>=0.5) & (diffdf["fisher_fdr"]<=0.05)].copy()
depleted["type"] = "depleted"
s6 = pd.concat([exclusive, enriched, depleted])
# round numeric columns to 4 decimals; "None" placeholders are kept as they are
s6["in_perc"] = s6["in_perc"].apply(lambda x: round(x,4))
s6["out_perc"] = s6["out_perc"].apply(lambda x: round(x,4))
s6["ratio"] = s6["ratio"].apply(lambda x: round(x,4) if x != "None" else "None")
s6["fisher_exact"] = s6["fisher_exact"].apply(lambda x: round(x,4))
s6["fisher_fdr"] = s6["fisher_fdr"].apply(lambda x: round(x,4))
s6["percent"] = s6["percent"].apply(lambda x: round(x,4) if x!= "None" else "None")
s6.head()

In [None]:
s6 = s6[["taxcat", "fam", "env", "type", "in_num", "in_perc", "out_perc", "ratio", "exclusive", "fisher_exact", "fisher_fdr", "ko", "percent", "def"]]
s6.columns = ["lineage", "protein_family", "habitat_broad", "distribution_type", "ingroup_num_encoding", "ingroup_percent_encoding", "outgroup_percent_encoding",
             "ratio", "exclusive", "fisher_exact", "fisher_fdr", "kegg_orthology", "fraction_seqs_annotated", "kegg_definition"]
s6.sort_values(["lineage", "habitat_broad", "distribution_type", "ingroup_num_encoding"], ascending=[True,True, False, False]).to_csv(rootdir + "protein/supp_table_6.csv", index=False)

### plot ticker

In [None]:
#process re-ordering info
# orig: column position in the heatmap matrix ("index") -> family name
orig = pd.DataFrame(matrix.columns).reset_index()
orig.columns = ["index", "fam"]
# new: position after clustering ("new_index") -> original column position ("index")
new = pd.DataFrame(column_indices).reset_index()
new.columns = ["new_index", "index"]
# inds: family name together with its position in the clustered heatmap
inds = orig.merge(new, on="index", how="left")
inds.head()

In [None]:
# prep df
# per family: share of its encoding genomes coming from each lineage/environment
table = diffdf.merge(inds[["fam", "new_index"]], how="left", on="fam")
totaldf = table.groupby("fam", as_index=False).aggregate({"in_num":"sum"}).rename(columns={"in_num":"fam_total"})
table = table.merge(totaldf, how="left", on="fam")
table["total_perc"] = table.apply(lambda x: x["in_num"]/x["fam_total"], axis=1)

# families in the column order of the clustered heatmap
fam_order = table.drop_duplicates("new_index").sort_values("new_index")["fam"].to_list()

for contrast in ["taxcat", "env"]:

    if contrast == "taxcat":
        color_map = ["lightgrey", "grey", "darkgrey"]
    else: color_map = [merged_palette[item] for item in table["env"].unique()]

    grouped = table.groupby(["fam",contrast], as_index=False).aggregate({"total_perc":"sum"})

    if contrast=="env":
        order = table["env"].unique()
    else: order = ["Absconditabacteria", "Gracilibacteria", "Saccharibacteria"]

    # pivot arguments by keyword: they are keyword-only in current pandas
    shares = grouped.pivot(index="fam", columns=contrast, values="total_perc").fillna(0)
    shares.loc[fam_order, order].plot.area(color=color_map, figsize=(15,1), legend=False)
    sns.despine(left=False, bottom=True)
    plt.xlabel("")
    plt.axis('off')
    plt.savefig(rootdir + "figures/" + contrast + "_ticker.png", format="png", bbox_inches="tight", dpi=300)
    plt.show()

# ale

In [None]:
cmdir(rootdir + "ale")
cmdir(rootdir + "ale/alignments")
cmdir(rootdir + "ale/gene_trees")
cmdir(rootdir + "ale/results")

### alignment and gene trees

In [None]:
# fraction of sequences removed by the length filter, one value per kept family
filtfracs = []
# set membership instead of scanning the (repeating) list for every family
keep = set(families_tokeep)

with open(rootdir + "scripts/alignTrimGeneTrees.sh", "w") as wrapper:

    for family in glob.glob(rootdir + "protein/protein_clustering/familiesFasta/*"):

        # subset to diff distributed families; checked before any parsing so
        # that unselected family files are not read at all
        basename = os.path.basename(family).split(".")[0]
        if basename not in keep:
            continue

        aln_path = rootdir + "ale/alignments/" + basename + ".mafft"

        # get sequence sizes (SeqIO opens and closes the file itself)
        seq_lens = [len(record) for record in SeqIO.parse(family, "fasta")]
        #for now, filter at > 2 STD below the mean len
        thresh = np.mean(seq_lens) - 2*np.std(seq_lens)
        # capture how much being lost
        filtfracs.append(len([s for s in seq_lens if s<thresh])/float(len(seq_lens)))

        # the cutoff can fall below zero for very variable families; a
        # negative minimum length is never meaningful, so clamp it at 0
        min_len = max(0, math.floor(thresh))

        # pullseq call
        call = "pullseq -m " + str(min_len) + " -i " + family + " > " + \
            rootdir + "ale/alignments/" + basename + ".filtered.faa"
        # mafft call
        call1 = "mafft --thread 16 --retree 2 --reorder " + \
            rootdir + "ale/alignments/" + basename + ".filtered.faa > " + aln_path
        # trimal call
        call2 = "trimal -in " + aln_path + " -out " + aln_path.replace(".mafft",".trimal.mafft") + \
            " -gt 0.1"
        wrapper.write(call + "\n" + call1 + "\n" + call2 + "\n")

In [None]:
# how many sequences are filtered per fam?
sns.distplot(filtfracs, kde=False)
plt.xlabel("filtered fraction")
plt.show()

In [None]:
# generate calls
# one iqtree command per trimmed alignment that has no tree file yet
calls = []

for trimal in glob.glob(rootdir + "ale/alignments/*trimal*"):

    basename = os.path.basename(trimal).split(".")[0]

    #already done?
    if glob.glob(rootdir + "ale/gene_trees/" + basename + ".treefile") != []:
        continue

    seqs = [str(record.seq) for record in SeqIO.parse(open(trimal), "fasta")]
    # get around iqtree bootstrap limitation: bootstraps are only requested
    # when the alignment holds at least 4 distinct sequences
    bootstrap = " -bb 1000" if len(set(seqs)) >= 4 else ""

    calls.append("iqtree -s " + trimal + " -bnni -m TEST -st AA" + bootstrap +
        " -nt AUTO -pre " + rootdir + "/ale/gene_trees/" + basename)

In [None]:
# write to multiple wrappers: about 30 files with n calls each
# max(1, ...) keeps the range() step valid when there are no calls left to
# write (a step of 0 raises ValueError); in that case no file is written
n = max(1, math.ceil(len(calls)/30))
for i in range(0, len(calls),n):
    with open(rootdir + "ale/gene_trees/wrapper" + \
        str(int(i/n)+1) + ".sh", "w") as wrapper:
        for call in calls[i:i + n]:
            wrapper.write(call + "\n")

Then make the wrappers executable (`chmod +x`) and submit each one with `sbatch` from a terminal: ` for item in $(ls | grep wrapper); do sbatch -J $item --wrap "$(pwd)/$item"; done`

### clean up species tree

In [None]:
# remove outgroups
t = Tree(rootdir + "protein/rp16/rp16_crossenv.treefile")
# keep only the genomes listed in mc
# NOTE(review): `to_remove` is not defined in the cells above — confirm its source
to_include = [genome for genome in mc["newname"] if genome not in to_remove]
t.prune(to_include, preserve_branch_length=True)
# reroot using arbitrary abs + gra
# the common ancestor of these two hard-coded genomes becomes the outgroup
ancestor = t.get_common_ancestor("Shaiber2020_ORALPCFBin00011_Absconditabacteria_36","AR_2015_2-01_BD1-5_23_23_curated")
t.set_outgroup(ancestor)
#print(t)
#print(t)

In [None]:
# write the pruned, rerooted tree next to the original one (ete3 newick format=2)
cpr_tree_path = rootdir + "protein/rp16/rp16_crossenv.cpronly.treefile"
t.write(outfile=cpr_tree_path, format=2)

### run ale

In [None]:
# directory holding the ALE helper script (ale/scripts/runAle.py is referenced below)
cmdir(rootdir + "ale/scripts")
# NOTE(review): random is already imported in the first cell of the notebook;
# shutil is first imported here and is used to delete the temp directory below
import random
import shutil

In [None]:
# lookup table keyed by the genomes in to_remove (every value is True)
trdict = dict.fromkeys(to_remove, True)

In [None]:
# For every UFBoot tree set: sample 100 bootstrap trees, drop the leaves whose
# genome is listed in trdict, rename the remaining leaves to "<genome>$<gene>"
# (species prefix needed by ALE) and write them to <name>.pruned.ufboot.
#
# Changes compared with the previous version of this cell:
#   * trees are collected in memory instead of ale/temp/temp<N>.tre files, so
#     leftovers from an interrupted run can no longer end up in the output
#   * the loop uses its own variable (gene_tree) and no longer overwrites the
#     species tree `t` that the export cell above writes to disk
#   * input files are closed, the per-leaf debug print is gone, and the progress
#     counter is 1-based and no longer re-globs the directory on every iteration
# NOTE(review): random.sample is not seeded, so the sampled trees differ per run.
tree_sets = glob.glob(rootdir + "ale/gene_trees/*[0-9].ufboot")

for i, tree_set in enumerate(tree_sets):

    # pre-sample 100 trees (one newick string per line of the ufboot file)
    with open(tree_set) as handle:
        all_trees = handle.readlines()
    sample = random.sample(all_trees, 100)

    pruned_newicks = []
    for newick in sample:

        gene_tree = Tree(newick)

        # remove contaminant genes: keep leaves whose genome is not in trdict
        to_keep = [leaf.name for leaf in gene_tree
                   if scaf2bin[scaffold(leaf.name)] not in trdict]
        gene_tree.prune(to_keep, preserve_branch_length=True)

        # then modify leaf names to include species for ALE
        for leaf in gene_tree:
            leaf.name = scaf2bin[scaffold(leaf.name)] + "$" + leaf.name

        # without outfile, write() returns the newick string
        pruned_newicks.append(gene_tree.write(format=2))

    # concatenate bootstraps, one tree per line
    with open(tree_set.replace("ufboot", "pruned.ufboot"), "w") as out:
        for newick in pruned_newicks:
            out.write(newick + "\n")

    print('%d of %d trees processed.\r'%(i + 1, len(tree_sets)), end="")

In [None]:
def wrapperize(calls, parts, out):
    """Split calls into at most `parts` consecutive chunks and write each to a file.

    calls -- list of strings, written one per line
    parts -- maximum number of files to create (positive integer)
    out   -- path prefix; chunk k (counting from 1) goes to out + str(k) + ".txt"

    Nothing is written when calls is empty.
    """
    # chunk size; at least 1 so an empty list does not hand range() a zero step
    n = max(1, math.ceil(len(calls)/parts))
    for i in range(0, len(calls), n):
        with open(out + str(i // n + 1) + ".txt", "w") as wrapper:
            for call in calls[i:i + n]:
                wrapper.write(call + "\n")

In [None]:
# collect the pruned bootstrap sets that have no *uml_rec file in ale/results yet
bootcalls = []

for fam in glob.glob(rootdir + "ale/gene_trees/*pruned.ufboot"):
    result_pattern = fam.replace("gene_trees", "results") + "*uml_rec"
    if not glob.glob(result_pattern):
        bootcalls.append(fam)

# spread the pending families over (at most) 25 list files
wrapperize(bootcalls, 25, rootdir + "ale/results/famlist")

In [None]:
# assemble one sbatch command per family list; the command is only built here,
# not executed (uncomment the print to show it)
for callist in glob.glob(rootdir + "ale/results/famlist*"):
    job_name = os.path.basename(callist).replace("fam","").split(".")[0]
    ale_script = rootdir + "ale/scripts/runAle.py"
    call = "sbatch -J " + job_name + " --wrap 'python " + ale_script + " " + callist + "'"
    #print(call)

In [None]:
# try running with fraction missing:
# write "<genome>:<1 - completeness/100>" for every genome present in the tree
gqual = pd.read_csv(rootdir + "metadata/genomeInformation.csv")
gqual["newname"] = [genome.replace(".fna", "") for genome in gqual["genome"]]

# get species in tree
cpr_tree = Tree(rootdir + "protein/rp16/rp16_crossenv.cpronly.treefile")
cpr_intree = [leaf.name for leaf in cpr_tree]
gsubset = gqual[gqual["newname"].isin(cpr_intree)]

with open(rootdir + "ale/fraction_missing.txt", "w") as outfile:
    for newname, completeness in zip(gsubset["newname"], gsubset["completeness"]):
        missing = 1 - float(completeness)/100
        outfile.write(newname + ":" + str(missing) + "\n")

### parse output

In [None]:
# Read the per-branch event table out of every ALE result file (*uml_rec) into
# ale_results[<family name>], one DataFrame per file.
ale_results = {}

key_cols = ["branch_type", "branch"]
event_cols = ["duplications", "transfers", "losses", "originations", "copies"]

for result in glob.glob(rootdir + "ale/results/*uml_rec"):

    name = os.path.basename(result).split(".")[0]

    # find tabular portion: rows whose first field is a branch-type tag
    with open(result) as handle:
        buffer = []
        for line in handle:
            elements = line.strip().split("\t")
            if elements[0] in ("S_terminal_branch", "S_internal_branch"):
                buffer.append(elements)

    # Convert only the event columns to float. Passing dtype=float for the whole
    # frame relied on pandas silently leaving the two string columns as objects;
    # newer pandas is expected to raise on that instead.
    table = pd.DataFrame(buffer, columns=key_cols + event_cols)
    table[event_cols] = table[event_cols].astype(float)
    ale_results[name] = table

In [None]:
# number of parsed result tables
len(ale_results)

In [None]:
# stack all per-family tables and total each event column per branch
cat_results = pd.concat(ale_results.values())
branch_keys = ["branch", "branch_type"]
sum_events = cat_results.groupby(branch_keys, as_index=False).sum()

# exclude terminal branches
is_terminal = sum_events["branch_type"] == "S_terminal_branch"
sum_events = sum_events[~is_terminal]

### itol

In [None]:
# output directory for the iTOL tree and dataset files written in the cells below
cmdir(rootdir + "ale/itol/")
# mpl is used below for the colour-bar normalisation (mpl.colors, mpl.cm)
import matplotlib as mpl

In [None]:
# make cladogram
# The tree is taken from the "S:" line of one ALE result file (whichever glob
# returns first). It is read before the output is opened, so a missing result
# no longer leaves an empty cladogram.tre behind.
with open(glob.glob(rootdir + "ale/results/*uml_rec")[0]) as rec:
    tree_list = [line for line in rec if "S:\t" in line]

# reformat internal node names for iTOL: ")<number>:1" becomes ")INT<number>:1"
tree_string = tree_list[0].strip().replace("S:\t", "")
# raw string: "\)" in a normal string literal is an invalid escape sequence
tree_mod = re.sub(r"\)([0-9]+):1", r")INT\1:1", tree_string)

with open(rootdir + "ale/itol/cladogram.tre", "w") as out:
    out.write(tree_mod.replace("'", ""))

In [None]:
from matplotlib import cm

def rgba2hex(col):
    """Return '#rrggbb' for a colour given as a sequence of 0-1 channel values.

    Only the first three entries (red, green, blue) are used; each is scaled
    by 255 and rounded to an integer before being written as two hex digits.
    """
    channels = [int(round(col[idx]*255)) for idx in range(3)]
    return '#' + ''.join('{:02x}'.format(channel) for channel in channels)

In [None]:
# make datasets: for every event column of sum_events write an iTOL symbol
# dataset, and for some of them also save a standalone colour bar

#define color scale (identical for every variable, so built once)
cmap = sns.light_palette("orange", as_cmap=True)

for var in sum_events.columns:

    # skip the "branch" / "branch_type" key columns
    if "branch" in var:
        continue

    # column maximum, computed once instead of once per row
    var_max = max(sum_events[var])

    # write out itol file (closed even if a row fails)
    with open(rootdir + "ale/itol/" + var + ".itol.txt", "w") as itol:
        # change SYMBOL TO STYLE for branch colors
        itol.write("DATASET_SYMBOL\nSEPARATOR COMMA\nDATASET_LABEL," + \
            var + "\nCOLOR,#d3d3d3\nDATA\n")
        for key, row in sum_events.iterrows():
            # colour and symbol size both scale with value / column maximum
            scalar = row[var]/var_max
            color = rgba2hex(cmap(scalar))
            # internal nodes get the INT prefix
            node = "INT" + row["branch"] if "internal" in row["branch_type"] else row["branch"]
            # only branches with at least one summed event get a symbol
            if row[var] >= 1:
                itol.write("%s,2,%d,%s,1,0.5\n" %(node,scalar*10,color))
                #itol.write("%s,branch,node,%s,1,normal\n" %(node,color))

    # generate standalone color bars
    if var not in ["duplications", "copies"]:
        print(var)
        title = "total " + var if var != "transfers" else "total within-CPR transfers"
        sns.set_style({"axes.linewidth":0.25, "axes.edgecolor":"black"})
        fig, ax = plt.subplots(figsize=(0.5,3))
        #fig.subplots_adjust(bottom=0.5)
        # NOTE(review): the symbols above are coloured on a 0..max scale, while this
        # bar starts at the smallest value >= 1; confirm the small offset is intended
        norm = mpl.colors.Normalize(vmin=min(sum_events[sum_events[var]>=1][var]),
            vmax=var_max)
        fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
            cax=ax, orientation='vertical')
        plt.savefig(rootdir + "ale/itol/" + var + ".bar.svg", format='svg')

In [None]:
# decorate the tree: one TREE_COLORS "range" line per genome, coloured and
# labelled by its env_broad value
# (with-block so the file is closed even if a palette lookup fails)
with open(rootdir + "ale/itol/cladogram_envbroad.itol.txt", "w") as itol:
    itol.write("TREE_COLORS\nSEPARATOR TAB\nDATA\n")

    for key, row in mc.iterrows():
        color = merged_palette[row["env_broad"]]
        label = row["env_broad"]
        itol.write(row["newname"] + "\trange\t" + color + "\t" + label + "\n")

# rhodopsin

In [None]:
# directory for the rhodopsin files read and written in the cells below
cmdir(rootdir + "protein/rhod")

In [None]:
# from https://www.nature.com/articles/s41586-018-0225-9#Sec1 supp
# strip the gap characters from the reference alignment to get plain sequences
# (both files are opened in the with-statement so the input is closed as well)
with open(rootdir + "protein/rhod/push_refs_clean.faa", "w") as outfile, \
        open(rootdir + "protein/rhod/Supp_Data2_AlignmentFileType1plusHelioRs.txt") as infile:
    for record in SeqIO.parse(infile, "fasta"):
        outfile.write(">" + record.description + "\n" + str(record.seq).replace("-", "") + "\n")

In [None]:
# write the rhodopsin alignment as an iTOL alignment dataset with a custom
# residue colour scheme; sequence ids are cut at the first space
# (the input is opened in the with-statement so it is closed, and its path no
# longer carries the extra "/" after rootdir)
with open(rootdir + "protein/rhod/rhod_aln.itol.txt", "w") as outfile, \
        open(rootdir + "protein/rhod/all_push_rhodopsin.stripped.mafft") as infile:
    outfile.write("DATASET_ALIGNMENT\nSEPARATOR COMMA\nDATASET_LABEL,rhod\nCOLOR,#ff0000\nCUSTOM_COLOR_SCHEME,MY_SCHEME_1,A=#d2d0c9,M=#d2d0c9,I=#d2d0c9,L=#d2d0c9,V=#d2d0c9,P=#746f69,G=#746f69,C=#746f69,F=#d0ad16,Y=#d0ad16,W=#d0ad16,S=#34acfb,T=#34acfb,N=#34acfb,-=#ffffff,Q=#34acfb,R=#34fb54,K=#34fb54,H=#34fb54,D=#fb4034,E=#fb4034\nDATA\n")
    for record in SeqIO.parse(infile, "fasta"):
        outfile.write(">" + record.description.split(" ")[0] + "\n" + str(record.seq) + "\n")