In [1]:
# in src --> utils --> external_sources_utils.py add a function for retrieving the plantae (plant) tax ids

# NCBI taxonomy id used as the "plant" ancestor: a host tax id counts as a plant
# when this id appears in its full lineage. Kept as a string because it is
# compared against the ";"-split FullLineageTaxIDs values from pytaxonkit.
# NOTE(review): in the NCBI taxonomy 3193 appears to be Embryophyta (land
# plants) rather than the whole green-plant clade Viridiplantae (33090) -
# confirm which clade "kingdom plantae" is meant to cover.
PLANTAE_TAX_ID: str = "3193"

# define a function to filter plants
def get_plantae_tax_ids(tax_ids):
    """Return the subset of ``tax_ids`` whose full NCBI lineage contains PLANTAE_TAX_ID.

    Each id is resolved with its own ``pytaxonkit.lineage`` call, so a lookup
    failure only drops that single id instead of the whole chunk.

    Args:
        tax_ids: iterable of taxonomy ids (the caller passes numpy array
            chunks produced by ``np.array_split``).

    Returns:
        list: the ids from ``tax_ids``, in input order, whose lineage includes
        PLANTAE_TAX_ID. Ids whose lineage cannot be resolved are reported on
        stdout and skipped.
    """
    plantae_tax_ids = []
    for tax_id in tax_ids:
        # Keep the try body to the lookup itself; catch Exception (not a bare
        # except) so KeyboardInterrupt / SystemExit still propagate.
        try:
            full_lineage = pytaxonkit.lineage([tax_id])["FullLineageTaxIDs"].iloc[0]
        except Exception as err:
            print(f"ERROR in lineage for tax_id = {tax_id}: {err}")
            continue
        # A missing lineage is presumably read back by pandas as NaN (a float),
        # which has no .split: treat any non-string value as unresolvable.
        if not isinstance(full_lineage, str):
            print(f"ERROR in lineage for tax_id = {tax_id}")
            continue
        if PLANTAE_TAX_ID in full_lineage.split(";"):
            plantae_tax_ids.append(tax_id)
    return plantae_tax_ids

In [2]:
# in src --> data_preprocessing --> dataset_filter.py add a function that keeps only the sequences whose virus host belongs to the plant kingdom

# Name of the environment variable that is set to the taxonomy directory before
# the taxonkit lookups run (presumably read by taxonkit to locate its database).
TAXONKIT_DB: str = "TAXONKIT_DB"
# Column of the input CSV that holds the virus host's taxonomy id.
VIRUS_HOST_TAX_ID: str = "virus_host_tax_id"

# Number of worker processes, and of id chunks, used for the taxonomy lookups.
N_CPU: int = 4

def get_sequences_from_plantae_hosts(input_file_path, taxon_metadata_dir_path, output_file_path):
    """Keep only the records whose virus host is classified as a plant.

    Reads ``input_file_path`` as CSV and collects the unique values of the
    ``VIRUS_HOST_TAX_ID`` column. These ids are classified with
    ``external_sources_utils.get_plantae_tax_ids`` across ``N_CPU`` worker
    processes. Rows whose host id is a plant are written to
    ``output_file_path`` as CSV without the index.

    Args:
        input_file_path: CSV to filter; must contain the ``VIRUS_HOST_TAX_ID`` column.
        taxon_metadata_dir_path: taxonomy directory, exported through the
            environment variable named by ``TAXONKIT_DB``.
        output_file_path: destination CSV path.
    """
    print("START: Filter records with virus hosts not belonging to 'plantae' kingdom.")

    # Set before the pool is created so the worker processes inherit it.
    os.environ[TAXONKIT_DB] = taxon_metadata_dir_path

    df = pd.read_csv(input_file_path)

    host_tax_ids = df[VIRUS_HOST_TAX_ID].unique()
    print(f"Number of unique host tax ids = {len(host_tax_ids)}")

    # One chunk of ids per worker process.
    host_tax_ids_sublists = np.array_split(np.asarray(host_tax_ids), N_CPU)
    for i, sublist in enumerate(host_tax_ids_sublists):
        print(f"Size of host_tax_ids_sublists[{i}] = {sublist.shape}")

    # The context manager tears the workers down even if map() raises; on the
    # normal path map() has already returned every chunk's result.
    with Pool(N_CPU) as cpu_pool:
        plantae_tax_ids_sublists = cpu_pool.map(
            external_sources_utils.get_plantae_tax_ids, host_tax_ids_sublists
        )
    # Flatten the per-chunk lists into one list of plant host ids.
    plantae_tax_ids = list(itertools.chain.from_iterable(plantae_tax_ids_sublists))
    print(f"Number of unique plant tax ids = {len(plantae_tax_ids)}")

    print(f"Dataset size before filtering for plants: {df.shape}")
    df = df[df[VIRUS_HOST_TAX_ID].isin(plantae_tax_ids)]
    print(f"Dataset size after filtering for plants: {df.shape}")

    print(f"Writing to file {output_file_path}")
    df.to_csv(output_file_path, index=False)
    print("END: Filter records with virus hosts not belonging to 'plantae' kingdom.")

In [None]:
# in data_preprocessor.py add an argument within the function parse_args

parser.add_argument(
    "--filter_plants",
    action="store_true",
    help="Filter for virus hosts belonging to Plantae kingdom using the absolute path to the NCBI taxon directory provided in --taxon_dir.",
)

# inside process(config): run the plant-host filter when --filter_plants is given
if config.filter_plants:
    plants_file_name = f"{Path(input_file_path).stem}_plants.csv"
    filtered_dataset_file_path = os.path.join(output_dir, plants_file_name)
    dataset_filter.get_sequences_from_plantae_hosts(
        input_file_path=input_file_path,
        taxon_metadata_dir_path=config.taxon_dir,
        output_file_path=filtered_dataset_file_path,
    )