#!/usr/bin/env python3
import argparse
import copy
import logging
from multiprocessing import Pool
import os
import pandas as pd
import shutil
import sys
import time
import uuid
from ann_linkage_clustering.lib import make_cags_with_ann
from ann_linkage_clustering.lib import iteratively_refine_cags
from ann_linkage_clustering.lib import make_nmslib_index
from ann_linkage_clustering.helpers import return_results
from ann_linkage_clustering.helpers import make_summary_abund_df
from ann_linkage_clustering.helpers import make_abundance_dataframe
from ann_linkage_clustering.helpers import normalize_abundance_dataframe
from ann_linkage_clustering.helpers import exit_and_clean_up
from ann_linkage_clustering.helpers import read_json


def find_cags(
sample_sheet=None,
output_prefix=None,
output_folder=None,
normalization=None,
max_dist=0.1,
temp_folder="/scratch",
results_key="results",
abundance_key="depth",
gene_id_key="id",
threads=1,
min_samples=1,
test=False,
clr_floor=None,
distance_metric="cosine",
linkage_type="average"
):
"""
Wrapper function to be called directly by the user.
Reads in a set of samples with gene abundance information in JSON format,
and returns the final set of CAGs to a set of files as specified by the user.
"""
# Make sure the temporary folder exists
assert os.path.exists(temp_folder)
# Make a new temp folder
temp_folder = os.path.join(temp_folder, str(uuid.uuid4())[:8])
os.mkdir(temp_folder)
# Set up logging
log_fp = os.path.join(temp_folder, "log.txt")
logFormatter = logging.Formatter(
'%(asctime)s %(levelname)-8s [find-cags] %(message)s'
)
rootLogger = logging.getLogger()
rootLogger.setLevel(logging.INFO)
# Write to file
fileHandler = logging.FileHandler(log_fp)
fileHandler.setFormatter(logFormatter)
rootLogger.addHandler(fileHandler)
# Also write to STDOUT
consoleHandler = logging.StreamHandler()
consoleHandler.setFormatter(logFormatter)
rootLogger.addHandler(consoleHandler)
# Set up the multiprocessing pool
pool = Pool(threads)
# READING IN DATA
# Read in the sample_sheet
logging.info("Reading in the sample sheet from " + sample_sheet)
try:
sample_sheet = read_json(sample_sheet)
    except Exception:
exit_and_clean_up(temp_folder)
# Make the abundance DataFrame
logging.info("Making the abundance DataFrame")
try:
df = make_abundance_dataframe(
sample_sheet,
results_key,
abundance_key,
gene_id_key
)
    except Exception:
exit_and_clean_up(temp_folder)
# NORMALIZING RAW ABUNDANCES
# Normalize the raw depth abundances
try:
df = normalize_abundance_dataframe(df, normalization)
    except Exception:
exit_and_clean_up(temp_folder)
# Make a copy of this abundance table, to be saved at the end
unfiltered_abundance_df = copy.deepcopy(df)
# Apply the clr_floor parameter, if applicable
if clr_floor is not None and normalization == "clr":
if clr_floor == "auto":
clr_floor = df.min().max()
logging.info("Automatically set the minimum CLR as {:,}".format(clr_floor))
else:
logging.info("User passed in {} as the minimum CLR value".format(clr_floor))
            try:
                clr_floor = float(clr_floor)
            except ValueError:
                logging.info("{} could not be evaluated as a float".format(clr_floor))
                exit_and_clean_up(temp_folder)
logging.info("Applying the CLR floor: {}".format(clr_floor))
df = df.applymap(lambda v: v if v > clr_floor else clr_floor)
    # Subset to genes found in at least min_samples samples
logging.info(
"Subsetting to genes found in at least {} samples".format(min_samples))
# Keep track of the number of genes filtered, and the time elapsed
n_before_filtering = df.shape[0]
start_time = time.time()
# Filter
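    # A gene counts as "found" in a sample if its value is above the global
    # minimum df.min().min() (presumably the value assigned to genes absent
    # from a sample); summing across columns (axis=1) counts those samples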
df = df.loc[(df > df.min().min()).sum(axis=1) >= min_samples]
logging.info("{:,} / {:,} genes found in >= {:,} samples ({:,} seconds elapsed)".format(
df.shape[0],
n_before_filtering,
min_samples,
round(time.time() - start_time, 2)
))
# If this is being run in testing mode, subset to 2,000 genes
if test:
logging.info("Running in testing mode, subset to 2,000 genes")
df = df.head(2000)
# Make sure that the lowest abundance is 0 (for clustering)
logging.info("Shifting the lowest abundance to 0 (for the purpose of calculating distance metrics)")
df = df - df.min().min()
# CLUSTERING
# Make the nmslib index
logging.info("Making the nmslib index")
index = make_nmslib_index(df)
logging.info("Finding CAGs")
# Make CAGs using the approximate nearest neighbor
try:
cags = make_cags_with_ann(
index,
max_dist,
df.copy(),
pool,
threads=threads,
distance_metric=distance_metric,
linkage_type=linkage_type
)
    except Exception:
exit_and_clean_up(temp_folder)
logging.info("Closing the process pool")
pool.close()
# Refine the CAGs by merging overlapping groups
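    # (no return value is captured, so iteratively_refine_cags presumably
    # modifies the `cags` dict in place)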
try:
iteratively_refine_cags(
cags,
df.copy(),
max_dist,
distance_metric=distance_metric,
linkage_type=linkage_type,
threads=threads
)
    except Exception:
exit_and_clean_up(temp_folder)
    # Rename the CAGs so they are numbered 0..N-1 in order of decreasing size
cags = {
ix: list_of_genes
for ix, list_of_genes in enumerate(sorted(list(cags.values()), key=len, reverse=True))
}
# Print the number of total CAGs, number of singletons, etc.
logging.info("Number of CAGs = {:,}".format(
len(cags)
))
    # Print the number of singletons
logging.info("Number of singletons = {:,} / {:,}".format(
sum([len(v) == 1 for v in cags.values()]),
sum(map(len, cags.values()))
))
    # Print the size of the largest CAG
logging.info("Largest CAG = {:,}".format(
max(map(len, cags.values()))
))
logging.info("Size distribution of CAGs:")
logging.info(pd.Series(list(map(len, cags.values()))).describe())
logging.info("Largest CAGs:")
largest_cags = pd.Series(dict(zip(
cags.keys(),
map(len, cags.values())
)))
largest_cags.sort_values(ascending=False, inplace=True)
for cag_id, cag_size in largest_cags.head(10).items():
logging.info("{}: {:,}".format(cag_id, cag_size))
# RETURN RESULTS
# Now make a summary DF with the mean value for each combined CAG
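    # (assumed behavior of make_summary_abund_df: one row per CAG, holding
    # the mean abundance of its member genes in each sample)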
summary_df = make_summary_abund_df(unfiltered_abundance_df, cags)
    # Read in the logs (readlines() keeps trailing newlines, so read the
    # file directly instead of joining lines with an extra "\n")
    with open(log_fp, "rt") as f:
        logs = f.read()
# Return the results
logging.info("Returning results to " + output_folder)
try:
return_results(
unfiltered_abundance_df,
summary_df,
cags,
logs,
output_prefix,
output_folder,
temp_folder
)
    except Exception:
exit_and_clean_up(temp_folder)
# Delete any files that were created for this sample
logging.info("Removing temporary folder: " + temp_folder)
shutil.rmtree(temp_folder)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="""
        Find a set of co-abundant genes (CAGs).
        Read in a set of gene abundance information in JSON format,
        find CAGs using ANN (approximate nearest neighbors) to rapidly
        identify gene neighborhoods, and then return the final CAGs
        to a JSON file, with logs.
"""
)
parser.add_argument("--sample-sheet",
type=str,
required=True,
help="""Location for sample sheet (.json[.gz]).""")
parser.add_argument("--output-prefix",
type=str,
required=True,
help="""Prefix for output files.""")
parser.add_argument("--output-folder",
type=str,
required=True,
help="""Folder to place results.
(Supported: s3://, or local path).""")
parser.add_argument("--normalization",
type=str,
default=None,
help="Normalization factor per-sample (median, sum, or clr).")
parser.add_argument("--max-dist",
type=float,
default=0.01,
help="Maximum cosine distance for clustering.")
parser.add_argument("--temp-folder",
type=str,
default="/scratch",
help="Folder for temporary files.")
parser.add_argument("--results-key",
type=str,
default="results",
help="Key identifying the list of gene abundances for each sample JSON.")
parser.add_argument("--abundance-key",
type=str,
default="depth",
help="Key identifying the abundance value for each element in the results list.")
parser.add_argument("--gene-id-key",
type=str,
default="id",
help="Key identifying the gene ID for each element in the results list.")
parser.add_argument("--threads",
type=int,
default=1,
help="Number of threads to use.")
parser.add_argument("--min-samples",
type=int,
default=1,
help="Filter genes by the number of samples they are found in.")
parser.add_argument("--clr-floor",
type=str,
default=None,
help="Set a minimum CLR value, 'auto' will use the largest minimum value.")
parser.add_argument("--test",
action="store_true",
help="Run in testing mode and only process a subset of 2,000 genes.")
args = parser.parse_args(sys.argv[1:])
# Sample sheet is in JSON format
assert args.sample_sheet.endswith((".json", ".json.gz"))
    # Normalization factor is absent, 'median', 'sum', or 'clr'
assert args.normalization in [None, "median", "sum", "clr"]
# max-dist is >=0
assert args.max_dist >= 0
assert args.threads >= 1
assert args.min_samples >= 1
# Make sure the temporary folder exists
assert os.path.exists(args.temp_folder), args.temp_folder
find_cags(
**args.__dict__
)
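
# Example invocation (hypothetical paths; the script filename is assumed):
#
#   python3 find_cags.py \
#       --sample-sheet /data/samples.json.gz \
#       --output-prefix my-project \
#       --output-folder s3://bucket/results/ \
#       --normalization clr \
#       --threads 8 \
#       --min-samples 2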