# Machine learning prediction of trophic modes
In this notebook, we'll predict whether *Phaeocystis* is phototrophic, heterotrophic or mixotrophic in a given sample.

### Setup
First, we'll import the necessary scripts and load the training data from [GitHub](https://github.com/armbrustlab/trophic-mode-ml) and [Zenodo](https://zenodo.org/record/4425690#.ZHWzKC9ByLf).
These are stored in the `data` folder.
Then, we'll run the model.

In [81]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import requests
import numpy as np
import re
import sys
import warnings
from pandas.errors import PerformanceWarning
warnings.filterwarnings('ignore', category=PerformanceWarning)

In [57]:
# Load the genus-level Phaeocystis transcriptome bin (output of the mtx_taxonomy.ipynb notebook)
phaeocystis_bin = pd.read_csv("../data/annotation/taxonomy_eukprot/130/genus_bins/Phaeocystis_transcriptome_bin.csv")
print(len(phaeocystis_bin))

# Keep (not discard) the hits assigned to Phaeocystis globosa whose p_ident exceeds 0.8
# NOTE(review): presumably p_ident is a 0-1 fraction, i.e. 0.8 == 80 % identity -- confirm against the bin file
is_globosa = phaeocystis_bin["Name_to_Use"] == "Phaeocystis globosa"
is_confident_hit = phaeocystis_bin["p_ident"] > 0.8
phaeocystis_bin = phaeocystis_bin[is_confident_hit & is_globosa]
print(len(phaeocystis_bin))

# Later cells only need the transcript identifiers, as a plain Python list
phaeocystis_transcripts = phaeocystis_bin["query_id"].tolist()

18720
7612


In [70]:
# Load the emapper functional annotation table
annotation_table = pd.read_table('../data/annotation/functional/130/functional_annotation.emapper.annotations')

# Keep only the part of each query name that precedes the first "."
annotation_table['#query'] = annotation_table['#query'].str.split(".", n=1, expand=True)[0]

# Give the query column the same name the other tables use for transcripts
annotation_table = annotation_table.rename(columns={'#query': 'transcript_id'})
print(len(annotation_table))

# Restrict the annotations to transcripts that belong to the Phaeocystis bin
in_phaeocystis_bin = annotation_table['transcript_id'].isin(phaeocystis_transcripts)
functional_annotation = annotation_table[in_phaeocystis_bin]
print(len(functional_annotation))

165198
5771


In [71]:
# Load the wide TPM table (one column per sample) and align the transcript column name
tpm_wide = (
    pd.read_csv('../data/quantification/130_tpm.csv')
    .rename(columns={'target_id': 'transcript_id'})
)

# Reshape to long format: one row per (transcript, sample) pair
tpm_long = tpm_wide.melt(id_vars=['transcript_id'], var_name='sample', value_name='TPM')

# Restrict to transcripts that belong to the Phaeocystis bin
tpm = tpm_long[tpm_long['transcript_id'].isin(phaeocystis_transcripts)]

# Total Phaeocystis TPM per sample, used later to standardize the per-transcript values
phaeo_tpm_sums = (
    tpm.groupby('sample')['TPM']
    .sum()
    .reset_index(name='Phaeocystis_total_tpm')
)

In [72]:
# Load the per-sample environmental measurements
station_env = pd.read_csv('../data/samples_env.csv', index_col=0)

# Load the primary production (LabSTAF) measurements
PP_data = pd.read_csv('../data/raw/LabSTAF/labstaf_combined_data.csv')

# Overwrite Station with the combined "<Station>_<Sample>" key so it can be joined on
station_part = PP_data['Station'].astype(str)
sample_part = PP_data['Sample'].astype(str)
PP_data['Station'] = station_part + '_' + sample_part

# Attach PP to the environmental data and expose the join key under the name 'sample'
env_data = (
    station_env
    .merge(PP_data[['Station', 'PP']], on=['Station'], how='inner')
    .rename(columns={'Station': 'sample'})
)

print(env_data.columns)

Index(['sample', 'StationPrefix', 'StationSuffix', 'Latitude', 'Longitude',
       'Date', 'day_moment', 'day_length', 'Temperature', 'Salinity',
       'Conductivity', 'Depth', 'Oxygen', 'Fluorescence', 'NH4', 'NO2', 'NO3',
       'NOX', 'PO4', 'Si', 'TEP', 'sea_surface_height_above_sea_level',
       'surface_baroclinic_sea_water_velocity', 'PP'],
      dtype='object')


In [73]:
# Create a dataframe containing the sample names, functional annotation categories of interest and TPM,
# then attach the environmental data and the per-sample Phaeocystis TPM totals
functional_columns = ['transcript_id', 'PFAMs']
data = (
    functional_annotation[functional_columns]
    .merge(tpm, on='transcript_id', how='inner')
    .merge(env_data, on='sample', how='inner')
    .merge(phaeo_tpm_sums, on='sample', how='inner')
)

# Standardize the TPM values per sample by dividing by the total TPM of Phaeocystis in that sample
data['TPM_standardized'] = data['TPM'] / data['Phaeocystis_total_tpm']

# Phaeocystis is absent from samples 130_15, 130_16, 130_20, 130_24, 130_25; their standardized values
# are set to 0 because a ratio is meaningless there. Any other sample whose Phaeocystis total is 0 gets
# the same treatment, since 0/0 would otherwise leave NaN in TPM_standardized.
absent_samples = ['130_15', '130_16', '130_20', '130_24', '130_25']
no_phaeocystis = data['sample'].isin(absent_samples) | data['Phaeocystis_total_tpm'].eq(0)
data.loc[no_phaeocystis, 'TPM_standardized'] = 0

# Inspect the data
data.head()

Unnamed: 0,transcript_id,PFAMs,sample,TPM,StationPrefix,StationSuffix,Latitude,Longitude,Date,day_moment,...,NO3,NOX,PO4,Si,TEP,sea_surface_height_above_sea_level,surface_baroclinic_sea_water_velocity,PP,Phaeocystis_total_tpm,TPM_standardized
0,c_000002143804,adh_short,51_21,0.0,51,21,51.531851,3.182475,2023-04-19 09:03:00,Day,...,9.13,9.35,0.17,6.98,639.653333,-0.751302,0.899872,48.3,274.394819,0.0
1,c_000002143804,adh_short,51_7,0.0,51,7,51.532258,3.182319,2023-04-18 19:03:00,Civil twilight,...,7.89,8.07,0.16,7.4,201.302667,-1.046195,1.005741,0.0,1612.400269,0.0
2,c_000002143804,adh_short,51_11,3.78494,51,11,51.531646,3.182801,2023-04-18 23:07:00,Night,...,7.22,7.4,0.13,7.4,144.862667,1.264855,0.485325,34.152,1789.126177,0.002116
3,c_000002143804,adh_short,51_8,0.0,51,8,51.531966,3.182484,2023-04-18 20:04:00,Nautical twilight,...,7.89,8.08,0.15,7.34,161.794667,-0.643491,0.946217,51.972,5576.978714,0.0
4,c_000002143804,adh_short,51_10,0.0,51,10,51.531665,3.182699,2023-04-18 22:05:00,Night,...,7.27,7.44,0.13,7.21,212.590667,0.313989,0.422929,31.584,2078.523171,0.0


I've downloaded the Pfam annotation data from [InterPro](https://www.ebi.ac.uk/interpro/download/Pfam/) on 30/05/2023 and stored it in `data/annotation/functional/`.
We'll load that file and parse it to map the Pfam short names to their accession IDs and descriptions.

In [74]:
# Lookup tables built from the Pfam-A.hmm.dat flat file:
#   pfam_mappping: Pfam short name (ID field)  -> Pfam accession (AC field), e.g. 'adh_short' -> 'PF00106.28'
#   pfam_data:     Pfam accession (AC field)   -> description (DE field)
pfam_mappping = {}
pfam_data = {}

# Read the PFAM data file
with open("../data/annotation/functional/Pfam-A.hmm.dat", "r") as file:
    file_content = file.read()

# Split the file content into individual PFAM entries (each entry ends with a '//' line)
pfam_entries = re.split(r"//\n", file_content.strip())

# Process each PFAM entry
for entry in pfam_entries:
    # Collect the "#=GF <TAG> <value>" lines of this entry, keyed by tag.
    # Splitting on whitespace (at most twice) keeps the value intact even when it contains the
    # letters of a tag itself: splitting on "ID"/"DE" would mangle short names such as "IDO" or
    # descriptions such as "DEAD/DEAH box helicase".
    fields = {}
    for line in entry.splitlines():
        parts = line.split(None, 2)
        if len(parts) == 3 and parts[0] == "#=GF":
            # Keep the first occurrence of each tag
            fields.setdefault(parts[1], parts[2].strip())

    # Chunks without any #=GF line (e.g. stray blank text) carry no entry
    if not fields:
        continue

    # An entry that lacks one of the three fields cannot be mapped; fail loudly rather than guess
    missing_tags = [tag for tag in ("ID", "AC", "DE") if tag not in fields]
    if missing_tags:
        raise ValueError(f"Pfam entry is missing field(s) {missing_tags}: {entry[:80]!r}")

    pfam_short_name = fields["ID"]
    pfam_id = fields["AC"]
    pfam_description = fields["DE"]

    # Store the PFAM ID and short name in the dictionary
    pfam_mappping[pfam_short_name] = pfam_id

    # Store the PFAM ID and description in the dictionary
    pfam_data[pfam_id] = pfam_description

In [75]:
def _translate_pfam_list(joined_values, lookup):
    """Translate every comma-separated token through `lookup`, using "-" for unknown tokens."""
    translated = []
    for token in joined_values.split(","):
        translated.append(lookup.get(token.strip(), "-"))
    return ",".join(translated)


# Short names -> accession IDs, then accession IDs -> descriptions (both stay comma-joined strings)
data['PFAMs_ID'] = data['PFAMs'].apply(lambda names: _translate_pfam_list(names, pfam_mappping))
data['PFAMs_description'] = data['PFAMs_ID'].apply(lambda ids: _translate_pfam_list(ids, pfam_data))

In [76]:
# Preview the first rows of `data` (bare last expression, so the notebook renders it as a table)
data.head()

Unnamed: 0,transcript_id,PFAMs,sample,TPM,StationPrefix,StationSuffix,Latitude,Longitude,Date,day_moment,...,PO4,Si,TEP,sea_surface_height_above_sea_level,surface_baroclinic_sea_water_velocity,PP,Phaeocystis_total_tpm,TPM_standardized,PFAMs_ID,PFAMs_description
0,c_000002143804,adh_short,51_21,0.0,51,21,51.531851,3.182475,2023-04-19 09:03:00,Day,...,0.17,6.98,639.653333,-0.751302,0.899872,48.3,274.394819,0.0,PF00106.28,short chain dehydrogenase
1,c_000002143804,adh_short,51_7,0.0,51,7,51.532258,3.182319,2023-04-18 19:03:00,Civil twilight,...,0.16,7.4,201.302667,-1.046195,1.005741,0.0,1612.400269,0.0,PF00106.28,short chain dehydrogenase
2,c_000002143804,adh_short,51_11,3.78494,51,11,51.531646,3.182801,2023-04-18 23:07:00,Night,...,0.13,7.4,144.862667,1.264855,0.485325,34.152,1789.126177,0.002116,PF00106.28,short chain dehydrogenase
3,c_000002143804,adh_short,51_8,0.0,51,8,51.531966,3.182484,2023-04-18 20:04:00,Nautical twilight,...,0.15,7.34,161.794667,-0.643491,0.946217,51.972,5576.978714,0.0,PF00106.28,short chain dehydrogenase
4,c_000002143804,adh_short,51_10,0.0,51,10,51.531665,3.182699,2023-04-18 22:05:00,Night,...,0.13,7.21,212.590667,0.313989,0.422929,31.584,2078.523171,0.0,PF00106.28,short chain dehydrogenase


## Predict trophic mode

In [77]:
# Indicate where to look for the machine learning modules so we can import them into the notebook.
# The membership check keeps the cell idempotent: re-running it does not append duplicate entries.
trophic_mode_ml_model_dir = '../data/trophic-mode-ml/model/'
if trophic_mode_ml_model_dir not in sys.path:
    sys.path.append(trophic_mode_ml_model_dir)

In [78]:
# Create the data for the machine learning model
# (.copy() so the column assignments below act on an independent frame, not a view of `data`)
df = data[['sample', 'PFAMs_ID', 'TPM_standardized']].copy()

# PFAMs_ID holds a comma-joined string per transcript. explode() only expands list-likes and
# leaves strings untouched, so split into lists first to get one row per (transcript, domain).
df['PFAMs_ID'] = df['PFAMs_ID'].str.split(',')
df = df.explode('PFAMs_ID')
df['PFAMs_ID'] = df['PFAMs_ID'].str.strip()

# Remove all rows with no PFAM ID
df = df[df['PFAMs_ID'] != "-"].copy()

# Remove the version suffix of the PFAM accession (e.g. PF00106.28 -> PF00106)
df['PFAMs_ID'] = df['PFAMs_ID'].str.split(".", n=1).str[0]

# Group by sample and PFAM ID and sum the TPM values
df = df.groupby(['sample', 'PFAMs_ID'])['TPM_standardized'].sum().reset_index()

In [79]:
from predict_tm_lib import run_prediction

def run_prediction_per_sample(df, train_data_path, feature_path, labels_path,
                              output_dir='../data/analysis/phaeocystis_trophic_modes/'):
    """Run the trophic mode prediction once per sample and collect the predictions.

    For every unique value in df['sample'], the rows of that sample are reshaped into a
    single-row profile (one column per PFAMs_ID, values summed), written to
    `<output_dir>/<sample>.csv`, and passed to `run_prediction`.

    Args:
        df (pandas dataframe): Must contain the columns 'sample', 'PFAMs_ID' and 'TPM_standardized'.
        train_data_path (str): Path of the training data csv, forwarded as `train` to run_prediction.
        feature_path (str): Path of the feature csv, forwarded as `feats` to run_prediction.
        labels_path (str): Path of the labels csv, forwarded as `labels` to run_prediction.
        output_dir (str): Directory for the per-sample profile and prediction files. It is created
            if it does not exist. Defaults to the location that was previously hard-coded.

    Returns:
        dataframe: Columns 'sample' and 'prediction', with the rows for each sample built from the
        value returned by run_prediction (presumably a sequence of predicted labels -- confirm
        against predict_tm_lib). Empty, with those two columns, when df has no samples.
    """
    # Get unique samples (in order of first appearance)
    samples = df['sample'].unique()
    if len(samples) == 0:
        # Nothing to predict: return a typed empty frame and do not touch the file system
        return pd.DataFrame(columns=['sample', 'prediction'])

    # Create the output directory once, up front
    os.makedirs(output_dir, exist_ok=True)

    # Per-sample result frames; concatenated once at the end instead of growing a frame in the loop
    per_sample_results = []
    for sample in samples:
        # Get the profile for this sample
        profile = df.loc[df['sample'] == sample, ['PFAMs_ID', 'TPM_standardized']]
        # Reshape profile: one column per PFAM ID
        profile = profile.pivot_table(columns='PFAMs_ID', values='TPM_standardized', aggfunc='sum').fillna(0)
        # Save into a csv file, because run_prediction takes a path as input
        profile_path = os.path.join(output_dir, f'{sample}.csv')
        profile.to_csv(profile_path)
        # Run the prediction model
        # NOTE(review): rf=False presumably selects the non-random-forest model -- confirm in predict_tm_lib
        predictions = run_prediction(
            data=profile_path,
            train=train_data_path,
            feats=feature_path,
            labels=labels_path,
            out=os.path.join(output_dir, f'{sample}_preds.csv'),
            rf=False,
        )
        # Collect predictions into a dataframe (the sample name is broadcast over all predictions)
        per_sample_results.append(pd.DataFrame(data={'sample': sample, 'prediction': predictions}))

    # Return the final results
    return pd.concat(per_sample_results, ignore_index=True)

In [80]:
# Run the model for station 130 only; per the original note, the kernel crashes otherwise
prediction_inputs = {
    'train_data_path': '../data/trophic-mode-ml/data/Field_training_data.csv',
    'feature_path': '../data/trophic-mode-ml/data/Extracted_Pfams.csv',
    'labels_path': '../data/trophic-mode-ml/data/Field_training_labels.csv',
}
results = run_prediction_per_sample(df=df, **prediction_inputs)

# Persist the combined per-sample predictions
results.to_csv('../data/analysis/phaeocystis_trophic_modes/130_preds.csv', index=False)



All predictions indicate Phaeocystis globosa was phototrophic.