In [2]:
import pandas as pd
from ete3 import NCBITaxa
from sklearn.model_selection import train_test_split
import os

In [6]:
# Relative path of the input CSV; it is read with pd.read_csv in the next cell.
dataframe = './test_data/ab-taxid-targetacc.csv'

In [7]:
# Read the CSV and convert the taxid column to strings in one chained step
df = pd.read_csv(dataframe).astype({'taxid': str})

In [8]:
# Checking out dataframe: per-column count / unique / top / freq summary
df.describe()

Unnamed: 0,target_acc,taxid
count,2057543,2057543
unique,2057543,18070
top,PDT000113968.2,28901
freq,1,401268


In [23]:
# Number of distinct taxids - there are way too many
df['taxid'].nunique(dropna=False)

18070

In [17]:
# Example TaxID list (swap in your own TaxIDs)
taxid_list = [1284325]

# NCBITaxa handle used for the lineage / rank / name lookups below
ncbi = NCBITaxa()

# Function to get species-level TaxID for each input TaxID
def get_species_taxid(taxid):
    """Resolve a TaxID to the species-rank entry of its lineage.

    Args:
        taxid: TaxID handed to ``ncbi.get_lineage``. This notebook passes
            both ints (``taxid_list``) and strings (the ``taxid`` column).

    Returns:
        ``(species_taxid, species_name)`` for the first lineage entry whose
        rank is ``'species'``. ``(None, None)`` when the TaxID cannot be
        resolved or its lineage contains no species-rank entry.
    """
    try:
        lineage = ncbi.get_lineage(taxid)
    except ValueError:
        # ete3 raises ValueError for a TaxID that is missing from its database
        # or is not numeric (e.g. the string 'nan'). Report "no species"
        # instead of aborting a whole DataFrame.apply run on one bad value.
        return None, None

    # get_lineage hands back None for a falsy TaxID (0, '', None)
    if not lineage:
        return None, None

    # Rank of every TaxID in the lineage
    ranks = ncbi.get_rank(lineage)

    # Find the species-level TaxID
    for tid in lineage:
        if ranks.get(tid) == 'species':
            # Only the species name is returned, so only that name is looked up
            names = ncbi.get_taxid_translator([tid])
            return tid, names[tid]

    # If species not found, return None
    return None, None

Species TaxID: 28901, Species Name: Salmonella enterica


In [21]:
# Resolve each distinct taxid once, then broadcast the result to every row.
# The lookup depends only on the taxid, so calling it per row (and building a
# pd.Series per row) repeats identical taxonomy queries for every duplicate;
# the describe() output above shows 18070 unique taxids across 2057543 rows.
species_by_taxid = {tid: get_species_taxid(tid) for tid in df['taxid'].unique()}

# Split the (species_taxid, species_name) pairs into the two new columns
df['Base_Species_TaxID'] = df['taxid'].map({tid: hit[0] for tid, hit in species_by_taxid.items()})
df['Species_Name'] = df['taxid'].map({tid: hit[1] for tid, hit in species_by_taxid.items()})

In [24]:
# Store the species-level taxid as text so it can be compared with string taxids
base_species_ids = df['Base_Species_TaxID']
df['Base_Species_TaxID'] = base_species_ids.astype(str)

In [26]:
# Checking out new DF
# Looks good: every row has now been characterized into one of 754 species
df.describe()

Unnamed: 0,target_acc,taxid,Base_Species_TaxID,Species_Name
count,2057543,2057543,2057543,2057543
unique,2057543,18070,754,754
top,PDT000113968.2,28901,28901,Salmonella enterica
freq,1,401268,671635,671635


In [33]:
# Target species and their species-level TaxIDs (kept as strings)

ef_taxid = '1352'  # Enterococcus faecium
sa_taxid = '1280'  # Staphylococcus aureus
kp_taxid = '573'   # Klebsiella pneumoniae
ab_taxid = '470'   # Acinetobacter baumannii
pa_taxid = '287'   # Pseudomonas aeruginosa

# Species name -> TaxID, in the same order as the definitions above
target_species = {
    'Enterococcus faecium': ef_taxid,
    'Staphylococcus aureus': sa_taxid,
    'Klebsiella pneumoniae': kp_taxid,
    'Acinetobacter baumannii': ab_taxid,
    'Pseudomonas aeruginosa': pa_taxid,
}

In [34]:
# Report how many rows fall under each target species
for species, taxid in target_species.items():
    n_assemblies = (df['Base_Species_TaxID'] == taxid).sum()
    print(f'There are {n_assemblies} assemblies from {species}')

There are 40280 assemblies from Enterococcus faecium
There are 120954 assemblies from Staphylococcus aureus
There are 99500 assemblies from Klebsiella pneumoniae
There are 34872 assemblies from Acinetobacter baumannii
There are 38843 assemblies from Pseudomonas aeruginosa


In [41]:
# Write an 80/20 train/test split for every target species into ./datasets
output_dir = './datasets'

# to_csv does not create missing parent directories, so make sure the
# output folder exists before the first file is written
os.makedirs(output_dir, exist_ok=True)

for species, taxid in target_species.items():

    # Rows whose species-level taxid matches this target species
    subset_df = df.loc[df['Base_Species_TaxID'] == taxid]

    # Fixed random_state so the same split is produced on every run
    train_df, test_df = train_test_split(subset_df, test_size=0.2, random_state=42)

    file_name_test = species + '_test.csv'
    file_name_train = species + '_train.csv'

    train_df.to_csv(os.path.join(output_dir, file_name_train), index=False)
    test_df.to_csv(os.path.join(output_dir, file_name_test), index=False)