# The case of diatoms, dinoflagellates, and Prymnesiophytes

In [None]:
import pandas as pd
from sklearn.preprocessing import StandardScaler
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.model_selection import cross_val_score, cross_validate
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import make_pipeline
import matplotlib.colors as mcolors
import plotly.graph_objects as go
from scipy import stats
from scipy.optimize import curve_fit
from sklearn.metrics import r2_score, mean_squared_error
import statsmodels.api as sm
from statsmodels.formula.api import ols
from statsmodels.stats.multicomp import pairwise_tukeyhsd

# Apply seaborn's 'white' style (affects matplotlib/seaborn figures, not the plotly ones)
sns.set_style('white')

By correlating functional modules with taxonomic classes, we see that on an ecosystem level, phenology drives the metabolic functions that are found. This pattern potentially obscures the relationships between species occurrence and functional diversity. 
However, there are two groups that are particularly interesting: diatoms and dinoflagellates. Both are present year-round; however, a turnover in relative abundance and a succession of different assemblages typify the diatom community.
Our metatranscriptomic data thus allows us to answer key ecological questions as a case study:
1. What is the relation between the number of species inhabiting the same niche and functional diversity?
2. Is there a link between evenness and functional diversity?
3. Do different assemblages have different functional profiles, and if so, what are the differences?


## 1. Load data

### 1.1 Gather taxonomic data

First, we'll look at the taxonomically annotated data. In order to capture all transcripts remotely identified as diatoms or dinoflagellates, all transcripts that are aligned to either of these groups with a percent identity of 60% or higher are included. This is a very loose threshold, but it allows us to capture all transcripts that are remotely related to these groups. 

In [None]:
## Load PhyloDB alignments (mmseqs m8 output, one first hit per query)
data_tax = pd.read_table('../../data/annotation/taxonomy_phyloDB_extended/phylodb_extended.firsthit.90plus_alnscore.m8',  engine='pyarrow', header=None)
print(f'The mmseqs alignment has {len(data_tax)} hits with a >60 % identity')
data_tax.columns = ['query_id', 'target_id', 'p_ident', 'alnlen', 'mismatch', 'gapopen', 'qstart', 'qend', 'tstart', 'tend', 'evalue', 'bits']
print(f'The mmseqs alignment contains {len(data_tax["query_id"].unique())} unique query sequences')

# Fix transcript names so that they equal the transcript identifiers in the count files.
# Only the trailing ".<suffix>" is removed; working on the Series (rather than assigning an
# expanded DataFrame) keeps this valid for IDs with no dot or with more than one dot.
data_tax['query_id'] = data_tax['query_id'].str.rsplit('.', n=1).str[0]
print(f'The mmseqs alignment contains {len(data_tax["query_id"].unique())} unique query sequences')

# Load the taxonomy and annotation tables
taxonomy = pd.read_table('../../data/annotation/taxonomy_phyloDB_extended/phylodb_1.076.taxonomy_extended.txt')
print(f'The taxonomy file contains {len(taxonomy)} rows')
annotation = pd.read_table('../../data/annotation/taxonomy_phyloDB_extended/phylodb_1.076.annotations_extended.txt', engine='pyarrow', header=None)
print(f'The annotation file contains {len(annotation)} rows, linking transcript IDs to their taxonomy and function')
annotation.columns = ['target_id', 'code', 'strain_name', 'function']

# Attach annotation, then taxonomy. Both are inner joins (pandas default), so hits without a
# matching annotation or strain are dropped; pass how='left' to keep them instead.
data_tax = data_tax.merge(annotation, on='target_id')
print(f'The alignment merged with annotation file contains {len(data_tax)} rows')

data_tax = data_tax.merge(taxonomy, on='strain_name')
print(f'The data_tax merged with the taxonomy  file contains {len(data_tax)} rows')

# Drop alignment details and bookkeeping columns that are not used further
data_tax = data_tax.drop(columns=['code', 'peptide_count', 'mismatch', 'gapopen', 'qstart', 'qend', 'tstart', 'tend', 'evalue', 'bits'])
# Expand the ';'-separated taxonomy path into one column per rank
# (raises if a path does not split into exactly these eight levels)
data_tax[['kingdom', 'superphylum', 'phylum', 'class', 'order', 'family', 'genus', 'species']] = data_tax['taxonomy'].str.split(';', expand = True)

# View data_tax
data_tax.head()

In [None]:
# List the columns available after merging alignment, annotation and taxonomy
data_tax.columns

### 1.1.2 Alternative taxonomy: EukProt
Alternatively, do this analysis with the EukProt taxonomy.

In [None]:
## Load eukprot annotations (mmseqs m8 output, one first hit per query)
data_tax_eukprot = pd.read_table('../../data/annotation/taxonomy_eukprot/eukprot_DB.firsthit.60plus_alnscore.m8', header=None)
print(f'The eukprot annotation file contains {len(data_tax_eukprot)} rows')
data_tax_eukprot.columns = ['query_id', 'target_id', 'p_ident', 'alnlen', 'mismatch', 'gapopen', 'qstart', 'qend', 'tstart', 'tend', 'evalue', 'bits']

# Fix transcript names so that they equal the transcript identifiers in the count files.
# Only the trailing ".<suffix>" is removed; working on the Series (rather than assigning an
# expanded DataFrame) keeps this valid for IDs with no dot or with more than one dot.
data_tax_eukprot['query_id'] = data_tax_eukprot['query_id'].str.rsplit('.', n=1).str[0]

## Keep only the EukProt ID, i.e. the part of the target name before the first underscore
eukprot_ID = data_tax_eukprot['target_id'].str.split('_').str[0]
data_tax_eukprot['target_id'] = eukprot_ID

## Add taxonomic information
eukprot_taxonomy = pd.read_table('../../data/annotation/taxonomy_eukprot/EukProt_included_data_sets.v03.2021_11_22.txt')
print(f'The eukprot taxonomy file contains {len(eukprot_taxonomy)} rows')

# Drop the metadata columns that are not needed
eukprot_taxonomy = eukprot_taxonomy.drop(columns=['Previous_Names', 'Replaces_EukProt_ID', 'Data_Source_URL', 'Data_Source_Name', 'Paper_DOI', 'Actions_Prior_to_Use',
       'Data_Source_Type', 'Notes', 'Columns_Modified_Since_Previous_Version', 'Merged_Strains',
       'Alternative_Strain_Names', '18S_Sequence_GenBank_ID', '18S_Sequence',
       '18S_Sequence_Source', '18S_Sequence_Other_Strain_GenBank_ID',
       '18S_Sequence_Other_Strain_Name', '18S_and_Taxonomy_Notes'])

# Swap the _ to a space in the Name_to_Use column
eukprot_taxonomy['Name_to_Use'] = eukprot_taxonomy['Name_to_Use'].str.replace('_', ' ')

# Merge the alignment hits with the taxonomy; a left join keeps hits whose EukProt ID is not in the table
data_tax_eukprot = data_tax_eukprot.merge(eukprot_taxonomy, left_on='target_id', right_on='EukProt_ID', how='left')
print(f'The eukprot annotation merged with taxonomy file contains {len(data_tax_eukprot)} rows')

# Drop the alignment columns that are not needed (p_ident, alnlen and mismatch are kept)
data_tax_eukprot = data_tax_eukprot.drop(columns=['target_id', 'gapopen', 'qstart', 'qend', 'tstart', 'tend', 'evalue', 'bits'])

### 1.2 Gather functional data

We'll also look at the functional annotations, which are based on the KEGG database.

In [None]:
# Load the eggNOG-mapper annotation table
data_func = pd.read_csv('../../data/annotation/functional_eggnog/functional_annotation.emapper.annotations', sep = '\t', engine = 'pyarrow')

# Fix transcript names in the first column so that they equal the transcript identifiers in the count files.
# This is necessary because TransDecoder adds .p2 or .p1 to the sequence identifiers.
# Only the trailing ".<suffix>" is removed; working on the Series (rather than assigning an
# expanded DataFrame) keeps this valid for IDs with no dot or with more than one dot.
query_column = data_func.columns[0]
data_func[query_column] = data_func[query_column].str.rsplit('.', n=1).str[0]

# Check the annotation data, how many more hits are there compared to unique query sequences?
print(f'number of rows in eggNOG annotation: {len(data_func)}')
print(f'number of unique query sequences in eggNOG annotation: {len(data_func[query_column].unique())}')

# Check all columns in the annotation data
data_func.columns

In [None]:
# Preview the first rows of the eggNOG annotation table
data_func.head()

In [None]:
# Keep the query identifier plus a subset of the annotation fields (no renaming is done here)
functional_columns = [
    '#query',
    'Description',
    'GOs',
    'KEGG_ko',
    'KEGG_Pathway',
    'KEGG_Module',
    'KEGG_Reaction',
    'PFAMs',
]
data_func = data_func[functional_columns]

### 1.3 Gather count data

In [None]:
# Read the kallisto TPM table and name its identifier column transcript_id.
## This is the transcript read mapping, prior to protein prediction!
counts = (
    pd.read_csv('../../data/kallisto/tpm.csv', engine='pyarrow')
    .rename(columns={'target_id': 'transcript_id'})
)
print(f'{len(counts)} transcripts were quantified (on the nucleotide level) in the kallisto run')

# Reshape from wide (one column per sample) to long (one row per transcript/sample pair)
counts = counts.melt(id_vars=['transcript_id'], var_name='sample', value_name='TPM')
counts.head()

### 1.3.2 Gather TPL data

In [None]:
# Read the transcripts-per-litre table and name its identifier column transcript_id.
## This is the transcript read mapping, prior to protein prediction!
tpl = (
    pd.read_csv('../../data/kallisto/transcripts_per_L.csv', engine='pyarrow')
    .rename(columns={'target_id': 'transcript_id'})
)

# Reshape from wide (one column per sample) to long (one row per transcript/sample pair)
tpl = tpl.melt(id_vars=['transcript_id'], var_name='sample', value_name='TPL')
tpl.head()

### 1.4 Combine data

In [None]:
# Combine count and annotation data together
## Prefilter: keep only the taxonomic groups this case study is about.
# PhyloDB equivalent of this filter:
#data = data[(data['class'] == 'Bacillariophyta') | (data['class'] == 'Dinophyceae') | (data['class'] == 'Prymnesiophyceae')]
# NOTE(review): confirm that 'Prymnesiophyceae' really occurs in Taxogroup1_UniEuk; if it is
# recorded at another rank, that group is silently filtered out here.
groups_of_interest = ['Ochrophyta', 'Dinoflagellata', 'Prymnesiophyceae']
data_tax_eukprot = data_tax_eukprot[data_tax_eukprot['Taxogroup1_UniEuk'].isin(groups_of_interest)]

# Attach BOTH abundance measures. The cells below read 'TPM' as well as 'TPL', so merging only
# one of the two tables makes them fail with a KeyError on a fresh run.
# To use the PhyloDB taxonomy instead, start from data_tax rather than data_tax_eukprot.
data = data_tax_eukprot.merge(tpl, left_on='query_id', right_on='transcript_id', how='left')
del tpl
# TPM is joined per transcript/sample pair, so the row set stays the one defined by the TPL merge
data = data.merge(counts, on=['transcript_id', 'sample'], how='left')

data = data.drop(columns='query_id')

# Add sample metadata
meta = pd.read_csv('../../samples.csv', sep=';', index_col=0)
data = data.merge(meta, left_on= 'sample', right_on= 'sample', how = 'left')

# View data
print(f'{len(data)} rows are in the data, which should be the amount of annotated proteins times the number of samples')
data.head()

In [None]:
# Total TPM per Taxogroup1_UniEuk group, largest first
# (PhyloDB equivalent: group on 'class' instead)
tpm_per_group = data.groupby('Taxogroup1_UniEuk')[['TPM']].sum()
tpm_per_group.sort_values(by='TPM', ascending=False)

In [None]:
# Total TPL per Taxogroup1_UniEuk group, largest first
# (PhyloDB equivalent: group on 'class' instead).
# The grouped frame only holds the 'TPL' column, so it has to be sorted on 'TPL' (sorting on 'TPM' raises a KeyError).
data[['Taxogroup1_UniEuk', 'TPL']].groupby('Taxogroup1_UniEuk').sum().sort_values(by='TPL', ascending=False)

In [None]:
# Total TPM per Taxogroup2_UniEuk group, largest first
# (PhyloDB equivalent: group on 'class' instead)
tpm_per_subgroup = data.groupby('Taxogroup2_UniEuk')[['TPM']].sum()
tpm_per_subgroup.sort_values(by='TPM', ascending=False)

In [None]:
# Attach the functional annotation. This is an inner join (pandas default), so only transcripts
# that have an eggNOG annotation remain; the duplicated key column '#query' is dropped afterwards.
data = (
    data
    .merge(data_func, left_on='transcript_id', right_on='#query')
    .drop(columns=['#query'])
)
print(len(data))

## 2. Diatoms

In [None]:
# List the distinct diatom names supported by near-identical hits (p_ident >= 0.98)
# PhyloDB equivalent:
#data[(data['class'] == 'Bacillariophyta') & (data['p_ident'] >= 0.98)]['species'].unique()
is_diatom = data['Taxogroup2_UniEuk'] == 'Diatomeae'
is_near_identical = data['p_ident'] >= 0.98
data.loc[is_diatom & is_near_identical, 'Name_to_Use'].unique()

### 2.1 Diatom Abundance
First, we'll make the graph showing relative diatom abundances per month throughout the year. Afterwards, we'll show absolute diatom abundances per month.

In [None]:
# Select the diatom rows used for the abundance plots
tax_level = 'Genus_UniEuk'  # use 'genus' when working with the PhyloDB taxonomy
aggregation_level = 'month'

## Since we'll be looking at the relative abundance of different diatom genera, we can only include
## reads that are annotated to a genus with a sufficient % sequence identity.
# PhyloDB equivalent of the row filter: data['class'] == 'Bacillariophyta'
is_diatom = data['Taxogroup2_UniEuk'] == 'Diatomeae'
is_confident_hit = data['p_ident'] >= 0.9
data_diatoms = data.loc[is_diatom & is_confident_hit, [aggregation_level, tax_level, 'TPM']]
data_diatoms.head()

In [None]:
# Chronological order of the sampled months
month_order = ["July_2020", "August_2020", "September_2020",
               "November_2020", "December_2020", "January_2021",
               "February_2021", "April_2021", "May_2021",
               "June_2021", "July_2021"]

# Discard rows at or below the TPM threshold of 1
data_diatoms = data_diatoms.loc[data_diatoms['TPM'] > 1]

# Sum TPM per month and genus
data_diatoms = data_diatoms.groupby([aggregation_level, tax_level]).sum().reset_index()

# Make month an ordered categorical so that sorting follows the sampling calendar
data_diatoms['month'] = pd.Categorical(data_diatoms['month'], month_order)

# Share of each genus in the summed diatom TPM of its month (a fraction between 0 and 1)
monthly_total = data_diatoms.groupby('month')['TPM'].transform('sum')
data_diatoms["rel_expression_per_month"] = data_diatoms['TPM'] / monthly_total

# Relabel genera contributing 2 % or less of a month's total as 'Rare'
rare_groups = data_diatoms['rel_expression_per_month'] <= 0.02
data_diatoms.loc[rare_groups, tax_level] = 'Rare'

# Print the genus labels that remain after relabelling
print(data_diatoms[tax_level].unique())
# Inspect data
data_diatoms.head()

In [None]:
# Stacked horizontal bars: share of each diatom genus in the monthly diatom TPM.

## Conversion factor from centimetres to pixels (96 px per inch / 2.54 cm per inch)
pixels_per_cm = 37.79527559055118

## Fixed colour per genus, so that a genus keeps its colour across figures
diatom_colour_map = {
    'Rare': '#808080',
    'Asterionellopsis': '#ff8c00',
    'Astrosyne': '#87cefa',
    'Chaetoceros': '#E69F00',
    'Corethron': '#ffd700',
    'Coscinodiscus': '#B5B5B5',
    'Craspedostauros': '#2196F3',
    'Dactyliosolen': '#ba55d3',
    'Ditylum': '#56B4E9',
    'Entomoneis': '#795548',
    'Eucampia': '#89CE00',
    'Guinardia': '#009E73',
    'Helicotheca': '#04686E',
    'Leptocylindrus': '#046E0A',
    'Odontella': '#0072B2',
    'Pseudo-nitzschia': '#A42324',
    'Rhizosolenia': '#CC79A7',
    'Skeletonema': '#fa8072',
    'Stephanopyxis': '#dda0dd',
    'Thalassiosira': '#D55E00',
    'Triceratium': '#CC7DFF',
    'Amphiprora': '#c2ccc4',
}
# Additional entries used by the earlier PhyloDB palette (same colours as above for shared genera):
#   'Cyclotella': '#8BC34A', 'Pseudictyota': '#673AB7', 'Trieres': '#F8EF3A', 'Extubocellulus': '#607D8B',
#   'Gedaniella': '#3F51B5', 'Phaeodactylum': '#9E9E9E', 'Licmophora': '#98fb98', 'Proboscia': '#0000ff',
#   'Synedropsis': '#673AB7', 'Aulacoseira': '#00BCD4', 'Fragilariopsis': '#FFEB3B', 'Attheya': '#795548',
#   'Thalassionema': '#9C27B0'

## Legend/stacking order of the genera.
# NOTE(review): the key is "genus", while tax_level is set to 'Genus_UniEuk' in the cells above, so this
# order presumably only takes effect for the PhyloDB run - confirm whether that is intended.
# Alternative order for the EukProt column:
#   "Genus_UniEuk": ['Rare', 'Ditylum', 'Odontella', 'Thalassiosira', 'Chaetoceros', 'Rhizosolenia',
#                    'Asterionellopsis', 'Coscinodiscus', 'Dactyliosolen', 'Helicotheca', 'Leptocylindrus',
#                    'Proboscia', 'Pseudo-nitzschia', 'Corethron', 'Skeletonema', 'Stephanopyxis', 'Astrosyne', 'Licmophora']
diatom_category_orders = {
    "genus": ['Rare', 'Ditylum', 'Odontella', 'Thalassiosira', 'Chaetoceros', 'Rhizosolenia',
              'Asterionellopsis', 'Coscinodiscus', 'Dactyliosolen', 'Helicotheca', 'Leptocylindrus',
              'Proboscia', 'Pseudo-nitzschia', 'Corethron', 'Skeletonema', 'Stephanopyxis', 'Astrosyne',
              'Entomoneis', 'Eucampia', 'Craspedostauros', 'Amphiprora', 'Guinardia', 'Triceratium'],
}

fig = px.histogram(
    data_diatoms.sort_values("month", ascending=False),
    x="rel_expression_per_month",
    y="month",
    color=tax_level,
    orientation='h',
    color_discrete_map=diatom_colour_map,
    category_orders=diatom_category_orders,
    # text_auto='.2f'
)

## Publication layout: 8.5 cm x 7.5 cm, Times New Roman 8 pt, black text
figure_font = dict(family="Times New Roman, serif", size=8, color="#000000")
figure_margin = dict(l=0, r=25, b=25, t=25)
fig.update_layout(
    font=figure_font,
    autosize=False,
    width=8.5 * pixels_per_cm,
    height=7.5 * pixels_per_cm,
    margin=figure_margin,
    xaxis_title_text='% TPM of total sum',
    yaxis_title_text=None,
)

fig.show()

# Save the figure as png and svg (replace the '_eukprot' suffix by '_phylodb' for the PhyloDB run)
figure_stem = "../../figures/diatoms_vs_dinoflagellates/diatoms_relative_expression_per_month_{}_eukprot".format(tax_level)
for extension in ("png", "svg"):
    fig.write_image(f"{figure_stem}.{extension}", scale=1)

In [None]:
# Store the order of the genera in the legend
diatom_legend_order = []
for i in range(len(fig.data)):
    diatom_legend_order.append(fig.data[i].name)
print(diatom_legend_order)

#### Spatial distribution

In [None]:
# One relative-abundance figure per station: share of each diatom genus in the monthly diatom TPM.
tax_level = 'Genus_UniEuk'  # use 'genus' when working with the PhyloDB taxonomy
aggregation_level = 'month'
aggregation_level2 = 'station'

# Chronological order of the sampled months
month_order = ["July_2020", "August_2020", "September_2020",
               "November_2020", "December_2020", "January_2021",
               "February_2021", "April_2021", "May_2021",
               "June_2021", "July_2021"]

# Fixed colour per genus, so that a genus keeps its colour across stations and figures
diatom_colour_map = {
    'Rare': '#808080',
    'Asterionellopsis': '#ff8c00',
    'Astrosyne': '#87cefa',
    'Chaetoceros': '#E69F00',
    'Corethron': '#ffd700',
    'Coscinodiscus': '#B5B5B5',
    'Craspedostauros': '#2196F3',
    'Dactyliosolen': '#ba55d3',
    'Ditylum': '#56B4E9',
    'Entomoneis': '#795548',
    'Eucampia': '#89CE00',
    'Guinardia': '#009E73',
    'Helicotheca': '#04686E',
    'Leptocylindrus': '#046E0A',
    'Odontella': '#0072B2',
    'Pseudo-nitzschia': '#A42324',
    'Rhizosolenia': '#CC79A7',
    'Skeletonema': '#fa8072',
    'Stephanopyxis': '#dda0dd',
    'Thalassiosira': '#D55E00',
    'Triceratium': '#CC7DFF',
    'Amphiprora': '#c2ccc4',
}
# Additional entries used by the earlier PhyloDB palette (same colours as above for shared genera):
#   'Cyclotella': '#8BC34A', 'Pseudictyota': '#673AB7', 'Trieres': '#F8EF3A', 'Extubocellulus': '#607D8B',
#   'Gedaniella': '#3F51B5', 'Phaeodactylum': '#9E9E9E', 'Licmophora': '#98fb98', 'Proboscia': '#0000ff',
#   'Synedropsis': '#673AB7', 'Aulacoseira': '#00BCD4', 'Fragilariopsis': '#FFEB3B', 'Attheya': '#795548',
#   'Thalassionema': '#9C27B0'

# Axis and legend ordering. All months are listed so that the month axis follows the sampling calendar.
# NOTE(review): the genus order is keyed on "genus", while tax_level is 'Genus_UniEuk', so it presumably
# only takes effect for the PhyloDB run - confirm whether that is intended.
# Alternative order for the EukProt column:
#   "Genus_UniEuk": ['Rare', 'Ditylum', 'Odontella', 'Thalassiosira', 'Chaetoceros', 'Rhizosolenia',
#                    'Asterionellopsis', 'Coscinodiscus', 'Dactyliosolen', 'Helicotheca', 'Leptocylindrus',
#                    'Proboscia', 'Pseudo-nitzschia', 'Corethron', 'Skeletonema', 'Stephanopyxis', 'Astrosyne', 'Licmophora']
diatom_category_orders = {
    "month": month_order,
    "genus": ['Rare', 'Ditylum', 'Odontella', 'Thalassiosira', 'Chaetoceros', 'Rhizosolenia',
              'Asterionellopsis', 'Coscinodiscus', 'Dactyliosolen', 'Helicotheca', 'Leptocylindrus',
              'Proboscia', 'Pseudo-nitzschia', 'Corethron', 'Skeletonema', 'Stephanopyxis', 'Astrosyne',
              'Entomoneis', 'Eucampia', 'Craspedostauros', 'Amphiprora', 'Guinardia', 'Triceratium'],
}

# Diatom rows with a sufficiently confident genus assignment
# PhyloDB equivalent of the row filter: data['class'] == 'Bacillariophyta'
data_diatoms = data[(data['Taxogroup2_UniEuk'] == 'Diatomeae') & (data['p_ident'] >= 0.9)][[aggregation_level, aggregation_level2, tax_level, 'TPM']]

# Remove transcripts below a certain TPM threshold
data_diatoms = data_diatoms[data_diatoms['TPM'] > 1]

# Group by month, station and genus, sum TPM
data_diatoms = data_diatoms.groupby([aggregation_level, aggregation_level2, tax_level]).sum().reset_index()

data_diatoms['month'] = pd.Categorical(data_diatoms['month'], month_order)

# Genus labels encountered over all stations (used afterwards to check the colour assignment)
diatom_genera = []

# Output path shared by the png and svg export (replace '_eukprot' by '_phylodb' for the PhyloDB run)
figure_stem_template = "../../figures/diatoms_vs_dinoflagellates/diatoms_relative_expression_per_month_{}_at_{}_eukprot"

# Plot per station
for station in data_diatoms[aggregation_level2].unique():
    # Work on an explicit copy: the columns added below must not be written to a slice of data_diatoms
    plot_data = data_diatoms[data_diatoms[aggregation_level2] == station].copy()

    # Share of each genus in this station's summed diatom TPM of the month (a fraction between 0 and 1)
    plot_data["rel_expression_per_month"] = plot_data.TPM / plot_data.groupby('month').TPM.transform('sum')

    # Combine low-abundant groups
    rare_groups = plot_data['rel_expression_per_month'] <= 0.02
    plot_data.loc[rare_groups, tax_level] = 'Rare'

    # Remember every genus label that was not seen at an earlier station
    for genus in plot_data[tax_level].unique():
        if genus not in diatom_genera:
            diatom_genera.append(genus)

    # Plot
    fig = px.histogram(plot_data.sort_values("month", ascending=False),
                    x="rel_expression_per_month",
                    y="month",
                    color=tax_level,
                    orientation='h',
                    # text_auto='.2f',
                    color_discrete_map=diatom_colour_map,
                    category_orders=diatom_category_orders,
                    )

    fig.update_layout(
        font=dict(
            family="Times New Roman, serif",  # Set the font family to Times New Roman
            size=8,  # Set the font size
            color="#000000"  # Set the font color
        ),
        autosize=False,
        width= 8.5 * pixels_per_cm,
        height= 7.5 * pixels_per_cm,
        margin=dict( # Set the margins
            l=0,  # Left margin
            r=25,  # Right margin
            b=25,  # Bottom margin
            t=25  # Top margin
        ),
        xaxis_title_text='% TPM of total sum',
        yaxis_title_text=None,
        ## Add station name to title
        title_text=station
    )

    fig.show()

    # Save the figure; png and svg now share the same stem
    # (the png used to carry a stray '_phylodb' in its name)
    figure_stem = figure_stem_template.format(tax_level, station)
    fig.write_image(figure_stem + ".png", scale=1)
    fig.write_image(figure_stem + ".svg", scale=1)

In [None]:
# Show every genus label (including 'Rare') that was collected while plotting the stations
diatom_genera
# Now create colours for all of these in the above code, set the order, and rerun the above code

#### Total TPM of diatoms per month

In [None]:
# Stacked horizontal bars: total diatom TPM per month, coloured by genus.
# data_diatoms is used as the preceding cell left it (TPM summed per month, station and genus).

# Share of every row in the summed diatom TPM of its month
monthly_total = data_diatoms.groupby('month').TPM.transform('sum')
data_diatoms["rel_expression_per_month"] = data_diatoms.TPM / monthly_total

# Relabel rows contributing 2 % or less of a month's total as 'Rare'
rare_groups = data_diatoms['rel_expression_per_month'] <= 0.02
data_diatoms.loc[rare_groups, tax_level] = 'Rare'

# Fixed colour per genus, so that a genus keeps its colour across figures
diatom_colour_map = {
    'Rare': '#808080',
    'Asterionellopsis': '#ff8c00',
    'Astrosyne': '#87cefa',
    'Chaetoceros': '#E69F00',
    'Corethron': '#ffd700',
    'Coscinodiscus': '#B5B5B5',
    'Craspedostauros': '#2196F3',
    'Dactyliosolen': '#ba55d3',
    'Ditylum': '#56B4E9',
    'Entomoneis': '#795548',
    'Eucampia': '#89CE00',
    'Guinardia': '#009E73',
    'Helicotheca': '#04686E',
    'Leptocylindrus': '#046E0A',
    'Odontella': '#0072B2',
    'Pseudo-nitzschia': '#A42324',
    'Rhizosolenia': '#CC79A7',
    'Skeletonema': '#fa8072',
    'Stephanopyxis': '#dda0dd',
    'Thalassiosira': '#D55E00',
    'Triceratium': '#CC7DFF',
    'Amphiprora': '#c2ccc4',
}
# Additional entries used by the earlier PhyloDB palette (same colours as above for shared genera):
#   'Cyclotella': '#8BC34A', 'Pseudictyota': '#673AB7', 'Trieres': '#F8EF3A', 'Extubocellulus': '#607D8B',
#   'Gedaniella': '#3F51B5', 'Phaeodactylum': '#9E9E9E', 'Licmophora': '#98fb98', 'Proboscia': '#0000ff',
#   'Synedropsis': '#673AB7', 'Aulacoseira': '#00BCD4', 'Fragilariopsis': '#FFEB3B', 'Attheya': '#795548',
#   'Thalassionema': '#9C27B0'

# Axis and legend ordering: months in calendar order, plus one genus order per taxonomy column name
diatom_category_orders = {
    "month": ["July_2020", "August_2020", "September_2020",
              "November_2020", "December_2020", "January_2021",
              "February_2021", "April_2021", "May_2021",
              "June_2021", "July_2021"],
    "genus": ['Rare', 'Ditylum', 'Odontella', 'Thalassiosira', 'Chaetoceros', 'Rhizosolenia',
              'Asterionellopsis', 'Coscinodiscus', 'Dactyliosolen', 'Helicotheca', 'Leptocylindrus',
              'Proboscia', 'Pseudo-nitzschia', 'Corethron', 'Skeletonema', 'Stephanopyxis', 'Astrosyne',
              'Entomoneis', 'Eucampia', 'Craspedostauros', 'Amphiprora', 'Guinardia', 'Triceratium'],
    'Genus_UniEuk': ['Rare', 'Helicotheca', 'Rhizosolenia', 'Leptocylindrus', 'Odontella', 'Pseudictyota',
                     'Trieres', 'Pseudo-nitzschia', 'Thalassiosira', 'Licmophora', 'Asterionellopsis', 'Astrosyne',
                     'Coscinodiscus', 'Chaetoceros', 'Entomoneis', 'Extubocellulus', 'Dactyliosolen', 'Corethron',
                     'Stephanopyxis', 'Proboscia', 'Ditylum', 'Skeletonema', 'Eucampia'],
}

fig = px.histogram(
    data_diatoms.sort_values("month", ascending=False),
    x="TPM",
    y="month",
    color=tax_level,
    orientation='h',
    color_discrete_map=diatom_colour_map,  # colours per genus
    category_orders=diatom_category_orders,
)

# Publication layout: 8.5 cm x 7.5 cm, Times New Roman 8 pt, black text
figure_font = dict(family="Times New Roman, serif", size=8, color="#000000")
figure_margin = dict(l=0, r=25, b=25, t=25)
fig.update_layout(
    font=figure_font,
    autosize=False,
    width=8.5 * pixels_per_cm,
    height=7.5 * pixels_per_cm,
    margin=figure_margin,
    xaxis_title_text='Total diatom TPM per month',
    yaxis_title_text=None,
    # Fixed x-axis range; bars longer than 650000 TPM would be cut off
    xaxis_range=[0, 650000],
)
fig.show()

# Save figure as svg (replace the '_eukprot' suffix by '_phylodb' for the PhyloDB run)
total_tpm_figure_path = f"../../figures/diatoms_vs_dinoflagellates/diatoms_total_TPM_per_month_{tax_level}_eukprot.svg"
fig.write_image(total_tpm_figure_path, scale=1)

In [None]:
# Horizontal histogram of diatom TPM per month, coloured by sampling station
fig = px.histogram(
    data_diatoms.sort_values(by="month", ascending=False),
    x="TPM",
    y="month",
    color='station',
    orientation='h',
    # Fixed colour for every station
    color_discrete_map={
        "ZG02": "#8c613c",
        "120": "#956cb4",
        "330": "#ee854a",
        "130": "#4878d0",
        "780": "#d65f5f",
        "700": "#6acc64",
    },
    # Chronological order of the sampled months
    category_orders={
        "month": [
            "July_2020", "August_2020", "September_2020",
            "November_2020", "December_2020", "January_2021",
            "February_2021", "April_2021", "May_2021",
            "June_2021", "July_2021",
        ],
    },
)

# Publication styling: 8 pt black Times New Roman on a fixed-size canvas
fig.update_layout(
    autosize=False,
    width=8.5 * pixels_per_cm,
    height=7.5 * pixels_per_cm,
    font={
        "family": "Times New Roman, serif",
        "size": 8,
        "color": "#000000",
    },
    # No left margin; 25 px on the right, bottom and top
    margin={"l": 0, "r": 25, "b": 25, "t": 25},
    xaxis_title_text='Total diatom TPM per month',
    yaxis_title_text=None,
)
fig.show()

# Export the figure as SVG (the commented line is the PhyloDB variant of the file name)
#fig.write_image("../../figures/diatoms_vs_dinoflagellates/diatoms_total_TPM_per_month_station_phylodb.svg", scale=1)
fig.write_image("../../figures/diatoms_vs_dinoflagellates/diatoms_total_TPM_per_month_station_eukprot.svg", scale=1)

#### Spatial distribution

In [None]:
#tax_level = 'genus'
tax_level = 'Genus_UniEuk'
aggregation_level = 'month'
aggregation_level2 = 'station'

#data_diatoms = data[(data['class'] == 'Bacillariophyta') & (data['p_ident'] >= 0.9)][[aggregation_level, aggregation_level2, tax_level, 'TPM']]
data_diatoms = data[(data['Taxogroup2_UniEuk'] == 'Diatomeae') & (data['p_ident'] >= 0.9)][[aggregation_level, aggregation_level2, tax_level, 'TPM']]

# Remove transcripts below a certain TPM threshold
data_diatoms = data_diatoms[data_diatoms['TPM'] > 1]

# Sum TPM per month, station and genus
data_diatoms = data_diatoms.groupby([aggregation_level, aggregation_level2, tax_level]).sum().reset_index()

# Specify the desired order of months
month_order = ["July_2020", "August_2020", "September_2020", "November_2020", "December_2020", "January_2021",
               "February_2021", "April_2021", "May_2021", "June_2021", "July_2021"]

# Stacking/legend order of the genera, keyed by taxonomy column. plotly only
# applies the entry whose key matches the plotted colour column, so both the
# 'genus' and the 'Genus_UniEuk' ordering are provided.
genus_order = {
    "genus": ['Rare', 'Ditylum', 'Odontella', 'Thalassiosira', 'Chaetoceros', 'Rhizosolenia',
              'Asterionellopsis', 'Coscinodiscus', 'Dactyliosolen', 'Helicotheca', 'Leptocylindrus',
              'Proboscia', 'Pseudo-nitzschia', 'Corethron', 'Skeletonema', 'Stephanopyxis', 'Astrosyne',
              'Entomoneis', 'Eucampia', 'Craspedostauros', 'Amphiprora', 'Guinardia', 'Triceratium'],
    "Genus_UniEuk": ['Rare', 'Helicotheca', 'Rhizosolenia', 'Leptocylindrus', 'Odontella', 'Pseudictyota',
                     'Trieres', 'Pseudo-nitzschia', 'Thalassiosira', 'Licmophora', 'Asterionellopsis', 'Astrosyne',
                     'Coscinodiscus', 'Chaetoceros', 'Entomoneis', 'Extubocellulus', 'Dactyliosolen', 'Corethron',
                     'Stephanopyxis', 'Proboscia', 'Ditylum', 'Skeletonema', 'Eucampia'],
}

# Colour per genus (defined once, outside the loop)
genus_colors = {
    'Rare': '#808080',
    'Asterionellopsis': '#ff8c00',
    'Astrosyne': '#87cefa',
    'Chaetoceros': '#E69F00',
    'Corethron': '#ffd700',
    'Coscinodiscus': '#B5B5B5',
    'Craspedostauros': '#2196F3',
    'Dactyliosolen': '#ba55d3',
    'Ditylum': '#56B4E9',
    'Entomoneis': '#795548',
    'Eucampia': '#89CE00',
    'Guinardia': '#009E73',
    'Helicotheca': '#04686E',
    'Leptocylindrus': '#046E0A',
    'Odontella': '#0072B2',
    'Pseudo-nitzschia': '#A42324',
    'Rhizosolenia': '#CC79A7',
    'Skeletonema': '#fa8072',
    'Stephanopyxis': '#dda0dd',
    'Thalassiosira': '#D55E00',
    'Triceratium': '#CC7DFF',
    'Amphiprora': '#c2ccc4'
}
# Alternative colour map (kept for switching palettes):
#genus_colors = {
#        'Rare': '#808080',
#        'Asterionellopsis': '#ff8c00',
#        'Chaetoceros': '#E69F00',
#        'Cyclotella': '#8BC34A',
#        'Dactyliosolen': '#ba55d3',
#        'Ditylum': '#56B4E9',
#        'Helicotheca': '#04686E',
#        'Pseudictyota': '#673AB7',
#        'Pseudo-nitzschia': '#A42324',
#        'Rhizosolenia': '#CC79A7',
#        'Thalassiosira': '#D55E00',
#        'Trieres': '#F8EF3A',
#        'Astrosyne': '#87cefa',
#        'Coscinodiscus': '#B5B5B5',
#        'Entomoneis': '#795548',
#        'Extubocellulus': '#607D8B',
#        'Gedaniella': '#3F51B5',
#        'Leptocylindrus': '#046E0A',
#        'Phaeodactylum': '#9E9E9E',
#        'Odontella': '#0072B2',
#        'Corethron': '#ffd700',
#        'Licmophora': '#98fb98',
#        'Proboscia': '#0000ff',
#        'Stephanopyxis': '#dda0dd',
#        'Synedropsis': '#673AB7',
#        'Aulacoseira': '#00BCD4',
#        'Skeletonema': '#fa8072',
#        'Eucampia': '#89CE00',
#        'Fragilariopsis': '#FFEB3B',
#        'Attheya': '#795548',
#        'Thalassionema': '#9C27B0',
#        'Craspedostauros': '#2196F3'
#    }

# Plot per station
for station in data_diatoms[aggregation_level2].unique():
    # Take an explicit copy: the columns added/overwritten below must not be
    # written onto a slice of data_diatoms (SettingWithCopyWarning)
    plot_data = data_diatoms[data_diatoms[aggregation_level2] == station].copy()

    # Share of each genus in the station's total diatom TPM of that month
    plot_data["rel_expression_per_month"] = plot_data.TPM / plot_data.groupby(aggregation_level).TPM.transform('sum')

    # Combine low-abundant groups (<= 2 % of the monthly total) into 'Rare'
    rare_groups = plot_data['rel_expression_per_month'] <= 0.02
    plot_data.loc[rare_groups, tax_level] = 'Rare'

    # Horizontal stacked bars: summed TPM on x, month on y, one colour per genus
    fig = px.histogram(plot_data, x='TPM', y=aggregation_level, color=tax_level,
                       color_discrete_map=genus_colors,
                       category_orders={aggregation_level: month_order, **genus_order},
                       title=station)

    fig.update_layout(
        font=dict(
            family="Times New Roman, serif",
            size=8,
            color="#000000"
        ),
        autosize=False,
        width=8.5 * pixels_per_cm,
        height=7.5 * pixels_per_cm,
        margin=dict(
            l=0,
            r=25,
            b=25,
            t=25
        ),
        # x carries the summed TPM and y the months (horizontal bars)
        xaxis_title_text='Total diatom TPM',
        yaxis_title_text='Month',
        legend_title_text=tax_level,
        # Set range of x-axis
        xaxis_range=[0, 300000]
    )

    fig.show()

    # Save figure as png and svg
    #fig.write_image("../../figures/diatoms_vs_dinoflagellates/diatoms_total_TPM_per_month_{}_at_{}_phylodb.png".format(tax_level, station), scale=1)
    #fig.write_image("../../figures/diatoms_vs_dinoflagellates/diatoms_total_TPM_per_month_{}_at_{}_phylodb.svg".format(tax_level, station), scale=1)
    fig.write_image("../../figures/diatoms_vs_dinoflagellates/diatoms_total_TPM_per_month_{}_at_{}_eukprot.png".format(tax_level, station), scale=1)
    fig.write_image("../../figures/diatoms_vs_dinoflagellates/diatoms_total_TPM_per_month_{}_at_{}_eukprot.svg".format(tax_level, station), scale=1)

### TPL

In [None]:
#tax_level = 'genus'
tax_level = 'Genus_UniEuk'
aggregation_level = 'month'
aggregation_level2 = 'station'

#data_diatoms = data[(data['class'] == 'Bacillariophyta') & (data['p_ident'] >= 0.9)][[aggregation_level, aggregation_level2, tax_level, 'TPL']]
data_diatoms = data[(data['Taxogroup2_UniEuk'] == 'Diatomeae') & (data['p_ident'] >= 0.9)][[aggregation_level, aggregation_level2, tax_level, 'TPL']]

# Remove transcripts below a certain TPL threshold
data_diatoms = data_diatoms[data_diatoms['TPL'] > 1]

# Sum TPL per month, station and genus
data_diatoms = data_diatoms.groupby([aggregation_level, aggregation_level2, tax_level]).sum().reset_index()

# Specify the desired order of months
month_order = ["July_2020", "August_2020", "September_2020", "November_2020", "December_2020", "January_2021",
               "February_2021", "April_2021", "May_2021", "June_2021", "July_2021"]

## First, set conversion factor to transform pixels to cm
pixels_per_cm = 37.79527559055118

# Stacking/legend order of the genera, keyed by taxonomy column. plotly only
# applies the entry whose key matches the plotted colour column, so both the
# 'genus' and the 'Genus_UniEuk' ordering are provided.
genus_order = {
    "genus": ['Rare', 'Ditylum', 'Odontella', 'Thalassiosira', 'Chaetoceros', 'Rhizosolenia',
              'Asterionellopsis', 'Coscinodiscus', 'Dactyliosolen', 'Helicotheca', 'Leptocylindrus',
              'Proboscia', 'Pseudo-nitzschia', 'Corethron', 'Skeletonema', 'Stephanopyxis', 'Astrosyne',
              'Entomoneis', 'Eucampia', 'Craspedostauros', 'Amphiprora', 'Guinardia', 'Triceratium'],
    "Genus_UniEuk": ['Rare', 'Helicotheca', 'Rhizosolenia', 'Leptocylindrus', 'Odontella', 'Pseudictyota',
                     'Trieres', 'Pseudo-nitzschia', 'Thalassiosira', 'Licmophora', 'Asterionellopsis', 'Astrosyne',
                     'Coscinodiscus', 'Chaetoceros', 'Entomoneis', 'Extubocellulus', 'Dactyliosolen', 'Corethron',
                     'Stephanopyxis', 'Proboscia', 'Ditylum', 'Skeletonema', 'Eucampia'],
}

# Colour per genus (defined once, outside the loop)
genus_colors = {
    'Rare': '#808080',
    'Asterionellopsis': '#ff8c00',
    'Chaetoceros': '#E69F00',
    'Cyclotella': '#8BC34A',
    'Dactyliosolen': '#ba55d3',
    'Ditylum': '#56B4E9',
    'Helicotheca': '#04686E',
    'Pseudictyota': '#673AB7',
    'Pseudo-nitzschia': '#A42324',
    'Rhizosolenia': '#CC79A7',
    'Thalassiosira': '#D55E00',
    'Trieres': '#F8EF3A',
    'Astrosyne': '#87cefa',
    'Coscinodiscus': '#B5B5B5',
    'Entomoneis': '#795548',
    'Extubocellulus': '#607D8B',
    'Gedaniella': '#3F51B5',
    'Leptocylindrus': '#046E0A',
    'Phaeodactylum': '#9E9E9E',
    'Odontella': '#0072B2',
    'Corethron': '#ffd700',
    'Licmophora': '#98fb98',
    'Proboscia': '#0000ff',
    'Stephanopyxis': '#dda0dd',
    'Synedropsis': '#673AB7',
    'Aulacoseira': '#00BCD4',
    'Skeletonema': '#fa8072',
    'Eucampia': '#89CE00',
    'Fragilariopsis': '#FFEB3B',
    'Attheya': '#795548',
    'Thalassionema': '#9C27B0',
    'Craspedostauros': '#2196F3'
}
# Alternative colour map (kept for switching palettes):
#genus_colors = {
#    'Rare': '#808080',
#    'Asterionellopsis': '#ff8c00',
#    'Astrosyne': '#87cefa',
#    'Chaetoceros': '#E69F00',
#    'Corethron': '#ffd700',
#    'Coscinodiscus': '#B5B5B5',
#    'Craspedostauros': '#2196F3',
#    'Dactyliosolen': '#ba55d3',
#    'Ditylum': '#56B4E9',
#    'Entomoneis': '#795548',
#    'Eucampia': '#89CE00',
#    'Guinardia': '#009E73',
#    'Helicotheca': '#04686E',
#    'Leptocylindrus': '#046E0A',
#    'Odontella': '#0072B2',
#    'Pseudo-nitzschia': '#A42324',
#    'Rhizosolenia': '#CC79A7',
#    'Skeletonema': '#fa8072',
#    'Stephanopyxis': '#dda0dd',
#    'Thalassiosira': '#D55E00',
#    'Triceratium': '#CC7DFF',
#    'Amphiprora': '#c2ccc4'
#}

# Plot per station
for station in data_diatoms[aggregation_level2].unique():
    # Take an explicit copy: the columns added/overwritten below must not be
    # written onto a slice of data_diatoms (SettingWithCopyWarning)
    plot_data = data_diatoms[data_diatoms[aggregation_level2] == station].copy()

    # Share of each genus in the station's total diatom TPL of that month
    plot_data["rel_expression_per_month"] = plot_data.TPL / plot_data.groupby(aggregation_level).TPL.transform('sum')

    # Combine low-abundant groups (<= 2 % of the monthly total) into 'Rare'
    rare_groups = plot_data['rel_expression_per_month'] <= 0.02
    plot_data.loc[rare_groups, tax_level] = 'Rare'

    # Horizontal stacked bars: relative share on x, month on y, one colour per genus
    fig = px.histogram(plot_data, x='rel_expression_per_month', y=aggregation_level, color=tax_level,
                       color_discrete_map=genus_colors,
                       category_orders={aggregation_level: month_order, **genus_order},
                       title=station)

    fig.update_layout(
        font=dict(
            family="Times New Roman, serif",
            size=8,
            color="#000000"
        ),
        autosize=False,
        width=8.5 * pixels_per_cm,
        height=7.5 * pixels_per_cm,
        margin=dict(
            l=0,
            r=25,
            b=25,
            t=25
        ),
        # x carries the relative share and y the months (horizontal bars)
        xaxis_title_text='Relative diatom TPL',
        yaxis_title_text='Month',
        legend_title_text=tax_level,
    )

    fig.show()

    # Save figure as png and svg
    #fig.write_image("../../figures/diatoms_vs_dinoflagellates/diatoms_relative_TPL_per_month_{}_at_{}_phylodb.png".format(tax_level, station), scale=1)
    #fig.write_image("../../figures/diatoms_vs_dinoflagellates/diatoms_relative_TPL_per_month_{}_at_{}_phylodb.svg".format(tax_level, station), scale=1)
    fig.write_image("../../figures/diatoms_vs_dinoflagellates/diatoms_relative_TPL_per_month_{}_at_{}_eukprot.png".format(tax_level, station), scale=1)
    fig.write_image("../../figures/diatoms_vs_dinoflagellates/diatoms_relative_TPL_per_month_{}_at_{}_eukprot.svg".format(tax_level, station), scale=1)

In [None]:
#tax_level = 'genus'
tax_level = 'Genus_UniEuk'
aggregation_level = 'month'
aggregation_level2 = 'station'

#data_diatoms = data[(data['class'] == 'Bacillariophyta') & (data['p_ident'] >= 0.9)][[aggregation_level, aggregation_level2, tax_level, 'TPL']]
data_diatoms = data[(data['Taxogroup2_UniEuk'] == 'Diatomeae') & (data['p_ident'] >= 0.9)][[aggregation_level, aggregation_level2, tax_level, 'TPL']]

# Remove transcripts below a certain TPL threshold
data_diatoms = data_diatoms[data_diatoms['TPL'] > 1]

# Sum TPL per month, station and genus
data_diatoms = data_diatoms.groupby([aggregation_level, aggregation_level2, tax_level]).sum().reset_index()

# Specify the desired order of months
month_order = ["July_2020", "August_2020", "September_2020", "November_2020", "December_2020", "January_2021",
               "February_2021", "April_2021", "May_2021", "June_2021", "July_2021"]

## First, set conversion factor to transform pixels to cm
pixels_per_cm = 37.79527559055118

# Stacking/legend order of the genera, keyed by taxonomy column. plotly only
# applies the entry whose key matches the plotted colour column, so both the
# 'genus' and the 'Genus_UniEuk' ordering are provided.
genus_order = {
    "genus": ['Rare', 'Ditylum', 'Odontella', 'Thalassiosira', 'Chaetoceros', 'Rhizosolenia',
              'Asterionellopsis', 'Coscinodiscus', 'Dactyliosolen', 'Helicotheca', 'Leptocylindrus',
              'Proboscia', 'Pseudo-nitzschia', 'Corethron', 'Skeletonema', 'Stephanopyxis', 'Astrosyne',
              'Entomoneis', 'Eucampia', 'Craspedostauros', 'Amphiprora', 'Guinardia', 'Triceratium'],
    "Genus_UniEuk": ['Rare', 'Helicotheca', 'Rhizosolenia', 'Leptocylindrus', 'Odontella', 'Pseudictyota',
                     'Trieres', 'Pseudo-nitzschia', 'Thalassiosira', 'Licmophora', 'Asterionellopsis', 'Astrosyne',
                     'Coscinodiscus', 'Chaetoceros', 'Entomoneis', 'Extubocellulus', 'Dactyliosolen', 'Corethron',
                     'Stephanopyxis', 'Proboscia', 'Ditylum', 'Skeletonema', 'Eucampia'],
}

# Colour per genus (defined once, outside the loop)
genus_colors = {
    'Rare': '#808080',
    'Asterionellopsis': '#ff8c00',
    'Chaetoceros': '#E69F00',
    'Cyclotella': '#8BC34A',
    'Dactyliosolen': '#ba55d3',
    'Ditylum': '#56B4E9',
    'Helicotheca': '#04686E',
    'Pseudictyota': '#673AB7',
    'Pseudo-nitzschia': '#A42324',
    'Rhizosolenia': '#CC79A7',
    'Thalassiosira': '#D55E00',
    'Trieres': '#F8EF3A',
    'Astrosyne': '#87cefa',
    'Coscinodiscus': '#B5B5B5',
    'Entomoneis': '#795548',
    'Extubocellulus': '#607D8B',
    'Gedaniella': '#3F51B5',
    'Leptocylindrus': '#046E0A',
    'Phaeodactylum': '#9E9E9E',
    'Odontella': '#0072B2',
    'Corethron': '#ffd700',
    'Licmophora': '#98fb98',
    'Proboscia': '#0000ff',
    'Stephanopyxis': '#dda0dd',
    'Synedropsis': '#673AB7',
    'Aulacoseira': '#00BCD4',
    'Skeletonema': '#fa8072',
    'Eucampia': '#89CE00',
    'Fragilariopsis': '#FFEB3B',
    'Attheya': '#795548',
    'Thalassionema': '#9C27B0',
    'Craspedostauros': '#2196F3'
}
# Alternative colour map (kept for switching palettes):
#genus_colors = {
#    'Rare': '#808080',
#    'Asterionellopsis': '#ff8c00',
#    'Astrosyne': '#87cefa',
#    'Chaetoceros': '#E69F00',
#    'Corethron': '#ffd700',
#    'Coscinodiscus': '#B5B5B5',
#    'Craspedostauros': '#2196F3',
#    'Dactyliosolen': '#ba55d3',
#    'Ditylum': '#56B4E9',
#    'Entomoneis': '#795548',
#    'Eucampia': '#89CE00',
#    'Guinardia': '#009E73',
#    'Helicotheca': '#04686E',
#    'Leptocylindrus': '#046E0A',
#    'Odontella': '#0072B2',
#    'Pseudo-nitzschia': '#A42324',
#    'Rhizosolenia': '#CC79A7',
#    'Skeletonema': '#fa8072',
#    'Stephanopyxis': '#dda0dd',
#    'Thalassiosira': '#D55E00',
#    'Triceratium': '#CC7DFF',
#    'Amphiprora': '#c2ccc4'
#}

# Plot per station
for station in data_diatoms[aggregation_level2].unique():
    # Take an explicit copy: the columns added/overwritten below must not be
    # written onto a slice of data_diatoms (SettingWithCopyWarning)
    plot_data = data_diatoms[data_diatoms[aggregation_level2] == station].copy()

    # Share of each genus in the station's total diatom TPL of that month
    plot_data["rel_expression_per_month"] = plot_data.TPL / plot_data.groupby(aggregation_level).TPL.transform('sum')

    # Combine low-abundant groups (<= 2 % of the monthly total) into 'Rare'
    rare_groups = plot_data['rel_expression_per_month'] <= 0.02
    plot_data.loc[rare_groups, tax_level] = 'Rare'

    # Horizontal stacked bars: summed TPL on x, month on y, one colour per genus
    fig = px.histogram(plot_data, x='TPL', y=aggregation_level, color=tax_level,
                       color_discrete_map=genus_colors,
                       category_orders={aggregation_level: month_order, **genus_order},
                       title=station)

    fig.update_layout(
        font=dict(
            family="Times New Roman, serif",
            size=8,
            color="#000000"
        ),
        autosize=False,
        width=8.5 * pixels_per_cm,
        height=7.5 * pixels_per_cm,
        margin=dict(
            l=0,
            r=25,
            b=25,
            t=25
        ),
        # x carries the summed TPL and y the months (horizontal bars)
        xaxis_title_text='Total diatom TPL',
        yaxis_title_text='Month',
        legend_title_text=tax_level,
        # Set range of x-axis
        xaxis_range=[0, 650000000]
    )

    fig.show()

    # Save figure as png and svg
    #fig.write_image("../../figures/diatoms_vs_dinoflagellates/diatoms_total_TPL_per_month_{}_at_{}_phylodb.png".format(tax_level, station), scale=1)
    #fig.write_image("../../figures/diatoms_vs_dinoflagellates/diatoms_total_TPL_per_month_{}_at_{}_phylodb.svg".format(tax_level, station), scale=1)
    fig.write_image("../../figures/diatoms_vs_dinoflagellates/diatoms_total_TPL_per_month_{}_at_{}_eukprot.png".format(tax_level, station), scale=1)
    fig.write_image("../../figures/diatoms_vs_dinoflagellates/diatoms_total_TPL_per_month_{}_at_{}_eukprot.svg".format(tax_level, station), scale=1)

### 2.2 Biodiversity indices
Now, we'll generate samples x species tables and use those to generate biodiversity indices per sample. Then, we'll visualize patterns in said indices over time (in months).

In [None]:
# Taxonomy column and grouping used for the taxon x sample table
# tax_level = 'genus'
tax_level = 'Name_to_Use'
aggregation_level = 'sample'

# Build a taxon x sample matrix of transcript counts:
#  1. keep diatom transcripts annotated with >= 90 % sequence identity
#  2. keep only transcripts with TPM > 1
#  3. count the remaining transcripts per sample and taxon
#  4. pivot to wide format (taxa as rows, samples as columns)
#  5. taxon/sample combinations without any transcript become 0
data_diatoms = (
    data.loc[
        (data['Taxogroup2_UniEuk'] == 'Diatomeae') & (data['p_ident'] >= 0.9),
        [aggregation_level, tax_level, 'TPM'],
    ]
    .loc[lambda transcripts: transcripts['TPM'] > 1]
    .groupby([aggregation_level, tax_level])
    .count()
    .reset_index()
    .pivot(index=tax_level, columns=aggregation_level, values='TPM')
    .fillna(0)
)

# Preview the count matrix
data_diatoms.head()

In [None]:
# Per sample (column): number of taxa whose value in the matrix exceeds 100
data_diatoms.where(data_diatoms > 100).count()

In [None]:
# Check how many diatom taxa are active per month.
## Per sample, count the taxa whose value in data_diatoms exceeds 100
## (presumably the per-taxon transcript count built in the cell above — verify)
taxonomic_bin_abundance = data_diatoms[data_diatoms > 100].count(axis=0).reset_index()
taxonomic_bin_abundance.columns = ['sample', 'num_species']

## Add metadata (month and station of every sample; `meta` is indexed by sample)
taxonomic_bin_abundance = taxonomic_bin_abundance.merge(meta[['month', 'station']], left_on='sample', right_index=True, how='left')

# Box plot of the number of active taxa per sample, grouped by month
fig = px.box(taxonomic_bin_abundance, x='month', y='num_species',
             category_orders={'month': ['July_2020', "August_2020", "September_2020", "November_2020", 
               "December_2020", "January_2021", "February_2021", "April_2021", 
               "May_2021", "June_2021", "July_2021"]})

fig.show()

# Save the figure.
# write_image interprets width/height as pixels, so the intended 3.5 x 2 size
# has to be converted first; passing 3.5 and 2 directly yields a 3.5 x 2 px image.
# NOTE(review): the size is assumed to be in inches (96 px per inch) — confirm.
pixels_per_inch = 96
fig.write_image("../../figures/diatoms_vs_dinoflagellates/num_eukprot_diatom_species_per_month.svg",
                width=3.5 * pixels_per_inch, height=2 * pixels_per_inch, scale=1, format='svg')

In [None]:
# One-way OLS models of the number of active species, once with month and
# once with station as the explanatory factor
model_month = ols(formula='num_species ~ month', data=taxonomic_bin_abundance).fit()
model_stations = ols(formula='num_species ~ station', data=taxonomic_bin_abundance).fit()

# Type II ANOVA tables for both models
anova_table_month = sm.stats.anova_lm(model_month, typ=2)
anova_table_stations = sm.stats.anova_lm(model_stations, typ=2)

print("ANOVA Table for the difference in active species between months:\n", anova_table_month)
print("\nANOVA Table for the difference in active species between stations:\n", anova_table_stations)

In [None]:
# Normality of the month-model residuals: Shapiro-Wilk test plus Q-Q plot
print("\nChecking Normality of Residuals for Months...")
print(stats.shapiro(model_month.resid))
sm.qqplot(model_month.resid, line='s')
plt.title("Q-Q Plot of Residuals for Months")
plt.show()

# Homogeneity of variances: Levene's test with one group of counts per month
print("\nChecking Homogeneity of Variances for the Months...")
print(
    stats.levene(
        *(
            taxonomic_bin_abundance.loc[taxonomic_bin_abundance['month'] == month_label, 'num_species']
            for month_label in taxonomic_bin_abundance['month'].unique()
        )
    )
)

In [None]:
# ANOVA assumptions were violated, so compare the months with the
# non-parametric Kruskal-Wallis test (one group of counts per month)
kw_table_month = stats.kruskal(
    *(
        taxonomic_bin_abundance.loc[taxonomic_bin_abundance['month'] == month_label, 'num_species']
        for month_label in taxonomic_bin_abundance['month'].unique()
    )
)

print("\nKruskal-Wallis Table for the difference in active species between months:\n", kw_table_month)

In [None]:
# Per taxon (row): number of samples in which its value exceeds 100
taxonomic_bin_abundance = data_diatoms.where(data_diatoms > 100).count(axis=1)
taxonomic_bin_abundance.head()

In [None]:
# Bar chart of the per-taxon sample counts, largest first
taxonomic_bin_abundance.sort_values(ascending=False).plot.bar(figsize=(10, 5));

In [None]:
# Long table: for every taxon, the number of samples in which its value exceeds 100
taxonomic_bin_abundance = (
    data_diatoms.where(data_diatoms > 100)
    .count(axis=1)
    .rename_axis('taxonomic_bin')
    .reset_index(name='abundant_in_samples')
)
taxonomic_bin_abundance.head()

In [None]:
# Taxa that exceed the threshold in at least one sample
abundant_taxa = taxonomic_bin_abundance.loc[
    taxonomic_bin_abundance['abundant_in_samples'] >= 1, 'taxonomic_bin'
].tolist()

# Rebuild the taxon x sample matrix of transcript counts:
# diatom transcripts with >= 90 % sequence identity and TPM > 1, counted per
# sample and taxon, pivoted to wide format, missing combinations set to 0
data_diatoms = (
    data.loc[
        (data['Taxogroup2_UniEuk'] == 'Diatomeae') & (data['p_ident'] >= 0.9),
        [aggregation_level, tax_level, 'TPM'],
    ]
    .loc[lambda transcripts: transcripts['TPM'] > 1]
    .groupby([aggregation_level, tax_level])
    .count()
    .reset_index()
    .pivot(index=tax_level, columns=aggregation_level, values='TPM')
    .fillna(0)
)

# Keep only the rows of the abundant taxa
data_diatoms = data_diatoms.loc[data_diatoms.index.isin(abundant_taxa)]

# Write the filtered matrix to csv
data_diatoms.to_csv(f'../../data/analysis/diatom_{tax_level}_{aggregation_level}.csv')

In [None]:
# Per sample, the number of retained taxa whose value exceeds 100
taxonomic_bin_abundance = (
    data_diatoms.where(data_diatoms > 100)
    .count(axis=0)
    .rename_axis('sample')
    .reset_index(name='num_species')
)

In [None]:
# Preview the first rows of taxonomic_bin_abundance
taxonomic_bin_abundance.head()

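Now run the count table exported above (`diatom_{tax_level}_{aggregation_level}.csv`) through the R [diversity](diversity.r) script. Then, we'll load the resulting biodiversity estimates and visualize them.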

In [None]:
# Read the biodiversity estimates and the sample metadata (both indexed by their first column)
diatom_biodiversity = pd.read_csv('../../data/analysis/diatom_biodiversity_estimates.csv', index_col=0)
meta = pd.read_csv('../../samples.csv', sep=';', index_col=0)

# Attach the metadata (index-on-index join), then the number of diatom
# species per sample (estimate index joined on the 'sample' column)
diatom_biodiversity = (
    diatom_biodiversity
    .merge(meta, left_index=True, right_index=True)
    .merge(taxonomic_bin_abundance, left_index=True, right_on='sample')
)

In [None]:
# Box plot of the number of active diatom species per month

# Chronological order of the sampled months
month_order = [
    "July_2020", "August_2020", "September_2020", "November_2020",
    "December_2020", "January_2021", "February_2021", "April_2021",
    "May_2021", "June_2021", "July_2021",
]

# 5 x 5 cm figure (matplotlib expects inches)
cm = 1/2.54
plt.figure(figsize=(5*cm, 5*cm))

sns.boxplot(
    data=diatom_biodiversity,
    x="month",
    y="num_species",
    order=month_order,
    color='#63C5DA',
)

# Vertical, right-aligned month labels and a grid on both axes
plt.xticks(rotation=90, ha='right')
plt.grid(axis='both')

# Export as svg
plt.savefig(f'../../figures/diatoms_vs_dinoflagellates/diatom_species_per_month.svg', dpi=600, bbox_inches='tight')

In [None]:
# Boxen plot of diatom community evenness per month

# Chronological order of the sampled months
month_order = [
    "July_2020", "August_2020", "September_2020", "November_2020",
    "December_2020", "January_2021", "February_2021", "April_2021",
    "May_2021", "June_2021", "July_2021",
]

# 5 x 5 cm figure (matplotlib expects inches)
cm = 1/2.54
plt.figure(figsize=(5*cm, 5*cm))

sns.boxenplot(
    data=diatom_biodiversity,
    x="month",
    y="evenness",
    order=month_order,
    scale="linear",
    color='#63C5DA',
)

# Vertical, right-aligned month labels
plt.xticks(rotation=90, ha='right')

# Fixed y-range with a small padding around [0, 1]
plt.ylim(-0.05, 1.05)

# Grid on both axes
plt.grid(axis='both')

# Export as svg
plt.savefig(f'../../figures/diatoms_vs_dinoflagellates/diatom_evenness_per_month.svg', dpi=600, bbox_inches='tight')

### 2.3 Calculate Functional Richness

#### Data preparation

In [None]:
# Functional richness: number of distinct KEGG KO ids detected per sample

# Columns of interest
functional_info = 'KEGG_ko'
tax_level = 'Name_to_Use'
# tax_level = 'genus'
aggregation_level = 'sample'

# Restrict to the abundant taxa determined above
data_diatoms = data.loc[data[tax_level].isin(abundant_taxa), [aggregation_level, functional_info, 'TPM']]

# KEGG preprocessing: a transcript can carry several comma-separated ids.
# Split them into lists and give every id its own row
# (skip this step when only a single value is assigned per transcript).
data_diatoms = (
    data_diatoms
    .assign(**{functional_info: data_diatoms[functional_info].str.split(',')})
    .explode(functional_info)
)
# Keep only the part of the identifier after the ':' separator
data_diatoms[functional_info] = data_diatoms[functional_info].str.split(":", expand=True).drop(columns=0)

# Keep rows that have a functional annotation (neither missing nor '-')
# and whose TPM exceeds the threshold of 1
data_diatoms = data_diatoms[
    data_diatoms[functional_info].notna()
    & (data_diatoms[functional_info] != '-')
    & (data_diatoms['TPM'] > 1)
]

# Sum TPM per KO and sample, reshape to a KO x sample matrix and
# set combinations that were not observed to 0
functional_diversity = (
    data_diatoms
    .groupby([functional_info, aggregation_level])
    .sum()
    .reset_index()
    .pivot(index=functional_info, columns=aggregation_level, values='TPM')
    .fillna(0)
)

# Binarise into presence (1) / absence (0)
functional_diversity[functional_diversity >= 1] = 1
functional_diversity[functional_diversity < 1] = 0

# Richness = number of KOs present per sample, stored as a one-column frame indexed by sample
functional_richness = (
    functional_diversity
    .sum(axis=0)
    .rename_axis('sample')
    .to_frame(name=functional_info)
)


In [None]:
# Add the functional richness to the dataframe.
# A richness column left behind by an earlier run of this cell is dropped first;
# otherwise re-running would produce suffixed duplicates (KEGG_ko_x / KEGG_ko_y)
# and break the plots that read the unsuffixed column.
diatom_biodiversity = (
    diatom_biodiversity
    .drop(columns=functional_richness.columns, errors='ignore')
    .merge(functional_richness, left_on='sample', right_index=True)
)

In [None]:
# Preview the first rows of diatom_biodiversity
diatom_biodiversity.head()

In [None]:
## This piece of code is now deprecated, but might still be useful in the future
## Swap identifier for KO name
## Load the KO names
#KO_names = pd.read_csv('../../data/analysis/kegg_info.txt', sep='\t', engine='pyarrow', header=None)
#KO_names.columns = ['KO', 'Name']
#
## Remove the part in the name before the ;
#KO_names['Name'] = KO_names['Name'].str.split(';', expand=True).drop(columns=[0,2])
## Remove the part in the name between the []
#KO_names['Name'] = KO_names['Name'].str.split('[', expand=True).drop(columns=[1,2,3])
## Add the KO identifier to the name
#KO_names['Name'] = KO_names['Name'] + ' (' + KO_names['KO'] + ')'
#
## Merge the KO names with the data
#data_diatoms = data_diatoms.merge(KO_names, left_on=functional_info, right_on='KO', how='left')
#
## Remove the KO identifier
#data_diatoms = data_diatoms.drop(columns=['KEGG_ko', 'KO'])
#
## Regroup to sum similar KO names
#data_diatoms = data_diatoms.groupby(['Name', aggregation_level]).sum().reset_index()



#### Abundance of functions

In [None]:
# Letter-value (boxen) plot of the functional richness (number of KOs) per month
cm = 1/2.54  # centimetres -> inches
ko_month_fig, ko_month_ax = plt.subplots(figsize=(5*cm, 5*cm))

sns.boxenplot(data=diatom_biodiversity,
              x="month", y="KEGG_ko",
              order=month_order,
              scale="linear",
              color='#63C5DA',
              ax=ko_month_ax)

# Rotate the month labels so they do not overlap
plt.setp(ko_month_ax.get_xticklabels(), rotation=90, ha='right')

# Fixed y-range and a grid on both axes
ko_month_ax.set_ylim(-50, 5500)
ko_month_ax.grid(axis='both')

# Save the figure as svg
ko_month_fig.savefig('../../figures/diatoms_vs_dinoflagellates/diatom_KO_abundance_per_month.svg', dpi=600, bbox_inches='tight')

#### Active species richness vs abundance of functional modules

In [None]:
# Treat the number of active species as a categorical x-axis and show the
# distribution of functional richness (number of KOs) for every species count
cm = 1/2.54  # centimetres -> inches
ko_species_fig, ko_species_ax = plt.subplots(figsize=(10*cm, 5*cm))

sns.boxenplot(data=diatom_biodiversity,
              x="num_species", y="KEGG_ko",
              scale="linear",
              color='#63C5DA',
              ax=ko_species_ax)

# Fixed y-range and a grid on both axes
ko_species_ax.set_ylim(-50, 5500)
ko_species_ax.grid(axis='both')

# Save the figure as svg
ko_species_fig.savefig('../../figures/diatoms_vs_dinoflagellates/diatoms_KO_abundance_num_species.svg', dpi=600, bbox_inches='tight')

In [None]:
# How is functional richness related to the expression assigned to diatoms?
## Sum the TPM of all transcripts assigned to diatoms in every sample
## (the variable keeps its historical name, but it holds summed TPM, not read counts)
number_of_reads = (
    data.loc[data['Taxogroup2_UniEuk'] == 'Diatomeae', [aggregation_level, 'TPM']]
    .groupby(aggregation_level)
    .sum()
    .reset_index()
)

# Add it to the diatom_biodiversity data.
# A 'TPM' column left behind by an earlier run of this cell is dropped first, so that
# re-running the cell does not produce TPM_x / TPM_y and break the plots below.
diatom_biodiversity = (
    diatom_biodiversity
    .drop(columns='TPM', errors='ignore')
    .merge(number_of_reads, on='sample')
)

In [None]:
# Display the biodiversity table
diatom_biodiversity

In [None]:
# Scatter plot of functional richness (number of KOs) against the summed diatom TPM per sample

# Conversion factor from cm to pixels. Defined here so the cell does not depend on a
# later cell having been executed first.
pixels_per_cm = 37.79527559055118

fig = px.scatter(diatom_biodiversity, x="TPM", y="KEGG_ko")

fig.update_traces(marker=dict(size=5,
                              color='#63C5DA',
                              line=dict(width=1,
                                        color='DarkSlateGrey')),
                  selector=dict(mode='markers'))

fig.update_layout(
    font=dict(
        family="Times New Roman, serif",  # Set the font family to Times New Roman
        size=8,  # Set the font size
        color="#000000"  # Set the font color
    ),
    autosize=False,
    width= 5.5 * pixels_per_cm,
    height= 5.5 * pixels_per_cm,
    margin=dict( # Set the margins
        l=0,  # Left margin
        r=25,  # Right margin
        b=25,  # Bottom margin
        t=25  # Top margin
    ),
    plot_bgcolor='white',
    showlegend=False
)

# Both axes share the same frame, tick and grid styling
axis_style = dict(
    mirror=True,
    ticks='outside',
    showline=True,
    linecolor='black',
    linewidth=0.5,
    gridcolor='lightgrey'
)
fig.update_xaxes(**axis_style)
fig.update_yaxes(range=[-50, 3000], **axis_style)

fig.show()


In [None]:
plot_data = diatom_biodiversity.copy()

# Candidate models: a straight line and a second-degree polynomial
linear_model = LinearRegression()
poly_model = make_pipeline(PolynomialFeatures(degree=2), LinearRegression())

# Both models are scored on negative mean squared error and R2
scoring_metrics = ['neg_mean_squared_error', 'r2']

# 5-fold cross-validation of functional richness against the number of active species
linear_scores = cross_validate(linear_model, plot_data[['num_species']], plot_data['KEGG_ko'],
                               cv=5, scoring=scoring_metrics)
poly_scores = cross_validate(poly_model, plot_data[['num_species']], plot_data['KEGG_ko'],
                             cv=5, scoring=scoring_metrics)

# Report mean and standard deviation of every metric across the folds
for cv_header, cv_scores in (("Linear model:", linear_scores), ("\nPolynomial model:", poly_scores)):
    fold_neg_mse = cv_scores['test_neg_mean_squared_error']
    fold_r2 = cv_scores['test_r2']
    print(cv_header)
    print(f"  Negative MSE: {fold_neg_mse.mean():.2f} with a standard deviation of {fold_neg_mse.std():.2f}")
    print(f"  R2: {fold_r2.mean():.4f} with a standard deviation of {fold_r2.std():.4f}")

# Paired t-test on the per-fold MSE of the two models
t_stat, p_val = stats.ttest_rel(linear_scores['test_neg_mean_squared_error'],
                                poly_scores['test_neg_mean_squared_error'])

print(f"\nT-statistic for MSE: {t_stat}")
print(f"P-value for MSE: {p_val}")
print("The difference in MSE between the two models is statistically significant."
      if p_val < 0.05
      else "There is no statistically significant difference in MSE between the two models.")

In [None]:
# Linear model is the better fit
## Fit the linear model on all samples
linear_model.fit(plot_data[['num_species']], plot_data['KEGG_ko'])

## Evaluate the fitted line on an evenly spaced grid spanning the observed species counts
x_values = np.linspace(plot_data['num_species'].min(), plot_data['num_species'].max(), 100)

# Predict from a DataFrame that carries the training column name: the model was fitted on a
# DataFrame, and predicting from a bare array triggers scikit-learn's feature-name warning
model_predictions = linear_model.predict(pd.DataFrame({'num_species': x_values}))

# Set conversion factor to transform cm to pixels
pixels_per_cm = 37.79527559055118

fig = px.scatter(plot_data, x="num_species", y="KEGG_ko")

# Add the linear fit
fig.add_trace(go.Scatter(
    x=x_values,
    y=model_predictions,
    mode='lines',
    line=dict(color='black', width=1),
    name='Linear Fit'
))

fig.update_traces(marker=dict(size=5,
                              color='#63C5DA',
                              line=dict(width=1,
                                        color='DarkSlateGrey')),
                  selector=dict(mode='markers'))

fig.update_layout(
    font=dict(
        family="Times New Roman, serif",  # Set the font family to Times New Roman
        size=8,  # Set the font size
        color="#000000"  # Set the font color
    ),
    autosize=False,
    width= 5.5 * pixels_per_cm,
    height= 5.5 * pixels_per_cm,
    margin=dict( # Set the margins
        l=0,  # Left margin
        r=25,  # Right margin
        b=25,  # Bottom margin
        t=25  # Top margin
    ),
    plot_bgcolor='white',
    showlegend=False
)

fig.update_xaxes(
    mirror=True,
    ticks='outside',
    showline=True,
    linecolor='black',
    linewidth=0.5,
    gridcolor='lightgrey'
)
fig.update_yaxes(
    mirror=True,
    ticks='outside',
    showline=True,
    linecolor='black',
    linewidth=0.5,
    gridcolor='lightgrey',
    range = [-50, 3000]
)

fig.show()

In [None]:
# Fit the data with scipy
def model(x, a, b, c):
    """Saturating (negative) exponential: equals a at x = 0 and approaches b for large x when c > 0."""
    return b - (b - a) * np.exp(-c * x)


# Least-squares estimate of (a, b, c); curve_fit starts from its default initial guess
popt, pcov = curve_fit(model, plot_data['num_species'], plot_data['KEGG_ko'])

# Generate x-values for the exponential fit
x_values = np.linspace(plot_data['num_species'].min(), plot_data['num_species'].max(), 100)
# Predict y-values using the exponential fit
model_predictions = model(x_values, *popt)

fig = px.scatter(plot_data, x="num_species", y="KEGG_ko")

# Add the exponential fit
fig.add_trace(go.Scatter(
    x=x_values,
    y=model_predictions,
    mode='lines',
    line=dict(color='black', width=1),
    name='Exponential Fit'
))

fig.update_traces(marker=dict(size=5,
                              color='#63C5DA',
                              line=dict(width=1,
                                        color='DarkSlateGrey')),
                  selector=dict(mode='markers'))

fig.update_layout(
    font=dict(
        family="Times New Roman, serif",  # Set the font family to Times New Roman
        size=8,  # Set the font size
        color="#000000"  # Set the font color
    ),
    autosize=False,
    width= 5.5 * pixels_per_cm,
    height= 5.5 * pixels_per_cm,
    margin=dict( # Set the margins
        l=0,  # Left margin
        r=25,  # Right margin
        b=25,  # Bottom margin
        t=25  # Top margin
    ),
    plot_bgcolor='white',
    showlegend=False
)

fig.update_xaxes(
    mirror=True,
    ticks='outside',
    showline=True,
    linecolor='black',
    linewidth=0.5,
    gridcolor='lightgrey'
)
fig.update_yaxes(
    mirror=True,
    ticks='outside',
    showline=True,
    linecolor='black',
    linewidth=0.5,
    gridcolor='lightgrey',
    range = [-50, 3000]
)

fig.show()

# Save figure as svg
fig.write_image('../../figures/diatoms_vs_dinoflagellates/diatom_KO_abundance_vs_species.svg')

In [None]:
# The exponential curve seems a better fit; compare both models on the data they were fitted to

# Predictions of the linear and the exponential model for every sample
linear_predictions = linear_model.predict(plot_data[['num_species']])
exp_predictions = model(plot_data['num_species'], *popt)

# Mean squared error and R-squared of the linear model
linear_mse, linear_r2 = (
    metric(plot_data['KEGG_ko'], linear_predictions) for metric in (mean_squared_error, r2_score)
)

# Mean squared error and R-squared of the exponential model
exp_mse, exp_r2 = (
    metric(plot_data['KEGG_ko'], exp_predictions) for metric in (mean_squared_error, r2_score)
)

# Report both models side by side
print(f'Linear model: MSE = {linear_mse}, R2 = {linear_r2}')
print(f'Exponential model: MSE = {exp_mse}, R2 = {exp_r2}')

In [None]:
# Same relation, coloured by month, with one OLS trendline per month
fig = px.scatter(
    plot_data,
    x="num_species",
    y="KEGG_ko",
    color='month',
    trendline="ols",
    color_discrete_sequence=px.colors.qualitative.Set1,
)

# Figure size, font and margins
fig.update_layout(
    font={"family": "Times New Roman, serif", "size": 14, "color": "#000000"},
    autosize=False,
    width=14 * pixels_per_cm,
    height=10 * pixels_per_cm,
    margin={"l": 0, "r": 25, "b": 25, "t": 25},
    plot_bgcolor='white',
)

# Both axes share the same frame, tick and grid styling
month_axis_style = {
    "mirror": True,
    "ticks": 'outside',
    "showline": True,
    "linecolor": 'black',
    "linewidth": 0.5,
    "gridcolor": 'lightgrey',
}
fig.update_xaxes(**month_axis_style)
fig.update_yaxes(range=[-50, 3000], **month_axis_style)

fig.show()

# Save figure as svg
fig.write_image('../../figures/diatoms_vs_dinoflagellates/diatom_KO_abundance_vs_species_per_month.svg')

#### Trophic feeding mode PFAMs
In this part we'll use the list of PFAMs from the study by [Lambert & Groussman, 2022](https://www.pnas.org/doi/suppl/10.1073/pnas.2100916119#data-availability). In dataset S6, they provide a list of common PFAMs that capture functional diversity between different feeding modes. 
We'll mine our dataset for the expression of these PFAMs, and see if we can find any patterns. From the above we can hypothesize that the diatom community is photoautotrophic, and the dinoflagellate community is mixotrophic. We'll see if this is reflected in the expression of feeding-mode related PFAMs.

In [None]:
# Load dataset S6 of the paper (the first 3 rows of the sheet are skipped)
# and drop every row that has a missing value
relevant_pfams = (
    pd.read_excel('../../data/external/pnas.2100916119.sd06.xlsx', skiprows=3)
    .dropna()
)

In [None]:
# Preview the first rows of the PFAM list
relevant_pfams.head()

In [None]:
# Number of distinct PFAM identifiers in the list
print(len(relevant_pfams['PFAM_ID'].unique()))

In [None]:
# Now, create a new data_diatoms_PFAM dataframe with PFAM expression data
functional_info = 'PFAMs'
aggregation_level = 'month'
# NOTE(review): diatoms are selected here via 'class' == 'Bacillariophyta', whereas other cells of
# this notebook select them via 'Taxogroup2_UniEuk' == 'Diatomeae' — confirm this difference is intended.
data_diatoms_PFAM = data[data['class'] == 'Bacillariophyta'][[aggregation_level, functional_info, 'TPM']]

In [None]:
# Fraction of the rows in data_diatoms_PFAM whose PFAM annotation is one of the relevant PFAMs
# (True = annotation found in relevant_pfams['short_name'], False = not found)
data_diatoms_PFAM[functional_info].isin(relevant_pfams['short_name']).value_counts(normalize=True)

In [None]:
# Count, per clade, the relevant PFAMs that are detected in the diatom data
relevant_pfams.loc[
    relevant_pfams['short_name'].isin(data_diatoms_PFAM[functional_info]),
    'clade',
].value_counts()

Looks promising: most PFAMs occur in the clade that is linked to photoautotrophs. Let's check which of the relevant PFAMs are most expressed.

In [None]:
# Keep only the rows annotated with one of the relevant PFAMs
data_diatoms_PFAM = data_diatoms_PFAM.loc[
    data_diatoms_PFAM[functional_info].isin(relevant_pfams['short_name'])
]

In [None]:
# What are the most abundant PFAMs?
# Total TPM per PFAM over all months, ten largest totals first
data_diatoms_PFAM.groupby(functional_info)['TPM'].sum().sort_values(ascending=False).head(10)

In [None]:
# Preview the filtered PFAM expression table
data_diatoms_PFAM.head()

In [None]:

# Select the trophic-mode PFAMs whose relative expression varies most across months
number_of_PFAMs = 15

# Sum the TPM of every PFAM per month
plot_data = data_diatoms_PFAM.groupby([aggregation_level, functional_info])['TPM'].sum().reset_index()

# Add number of stations visited per month
samples_per_month = {
    'July_2020': 6,
    'August_2020': 6,
    'September_2020': 6,
    'November_2020': 6,
    'December_2020': 6,
    'January_2021': 5,
    'February_2021': 5,
    'April_2021': 4,
    'May_2021': 6,
    'June_2021': 6,
    'July_2021': 6
}

# Divide the TPM sum by number of samples per month to obtain the average TPM per month
# (this per-month factor cancels again in the normalisation on the next line)
plot_data['TPM'] = plot_data['TPM'] / plot_data[aggregation_level].map(samples_per_month)

# Express every PFAM as a fraction of the month's total over the PFAMs retained in plot_data
plot_data['TPM'] = plot_data['TPM'] / plot_data.groupby(aggregation_level)['TPM'].transform('sum')

# Transform the data to the wide format; a PFAM that is absent in a month has zero expression
plot_data_wide = plot_data.pivot(index=aggregation_level, columns=functional_info, values='TPM').fillna(0)
# log2 transform the data
plot_data_log = np.log2(plot_data_wide + 1)

# Using non-log2 transformed data gives more weight to more expressed genes; these have higher variances
# Using log2 transformed data gives more weight to less expressed genes; and counterselects for highly expressed genes

# Transpose the data so that PFAMs are rows and months are columns
print(plot_data_log.shape)
plot_data_log = plot_data_log.transpose()

# Re-order the columns according to month
month_order = ["July_2020", "August_2020", "September_2020", "November_2020", 
                "December_2020", "January_2021", "February_2021", "April_2021", 
                "May_2021", "June_2021", "July_2021"]

plot_data_log = plot_data_log.reindex(month_order, axis=1)

# Only retain the 15 most variable PFAMs.
# The ranking has to use the unscaled values: after the row-wise z-scoring below every row
# has a variance of exactly 1, so ranking the scaled data would select arbitrary PFAMs.
most_variable_PFAMs = plot_data_log.var(axis=1).sort_values(ascending=False).head(number_of_PFAMs).index

# Scale the selected rows (z-score per PFAM across months)
plot_data_scaled = plot_data_log.loc[most_variable_PFAMs].apply(lambda x: (x - x.mean()) / x.std(), axis=1)

# In the index, include the PFAM clade and functional descriptions to the short names between brackets
# (the first entry of relevant_pfams is used when a short name occurs more than once)
pfam_annotation = relevant_pfams.drop_duplicates('short_name').set_index('short_name')
plot_data_scaled.index = plot_data_scaled.index.map(
    lambda x: "{} ({} - {})".format(x, pfam_annotation.at[x, 'clade'], pfam_annotation.at[x, 'function'])
)

# Keep the selection available under its annotated labels
most_variable_PFAMs = plot_data_scaled.index

In [None]:
# 
# Heatmap of the selected PFAMs; the columns are not clustered and keep their month order
g = sns.clustermap(
    plot_data_scaled,
    figsize=(8, 6),
    cmap='mako',
    col_cluster=False,
    # Legend bar should be on the right
    cbar_kws={'orientation': 'vertical'},
    )

g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), fontsize=8)
g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_yticklabels(), fontsize=8)
g.ax_heatmap.set_facecolor('#f7f7f7')
g.ax_heatmap.set_xlabel(None)
g.ax_heatmap.set_ylabel(None)

# Save figure as svg. The file name carries number_of_PFAMs, the count that was actually used
# for the selection (number_of_genes is not defined in this section).
g.savefig("../../figures/diatoms_vs_dinoflagellates/diatoms_heatmap_{}_most_variable_trophic_{}_per_{}.svg".format(number_of_PFAMs, functional_info, aggregation_level), format='svg', dpi=600)

plt.show()

Now we'll do the same but with the 15 most expressed PFAMs.

In [None]:

# Select the trophic-mode PFAMs with the highest relative expression over all months
number_of_PFAMs = 15

# Sum the TPM of every PFAM per month
plot_data = data_diatoms_PFAM.groupby([aggregation_level, functional_info])['TPM'].sum().reset_index()

# Add number of stations visited per month
samples_per_month = {
    'July_2020': 6,
    'August_2020': 6,
    'September_2020': 6,
    'November_2020': 6,
    'December_2020': 6,
    'January_2021': 5,
    'February_2021': 5,
    'April_2021': 4,
    'May_2021': 6,
    'June_2021': 6,
    'July_2021': 6
}

# Divide the TPM sum by number of samples per month to obtain the average TPM per month
# (this per-month factor cancels again in the normalisation on the next line)
plot_data['TPM'] = plot_data['TPM'] / plot_data[aggregation_level].map(samples_per_month)

# Express every PFAM as a fraction of the month's total over the PFAMs retained in plot_data
plot_data['TPM'] = plot_data['TPM'] / plot_data.groupby(aggregation_level)['TPM'].transform('sum')

# Transform the data to the wide format; a PFAM that is absent in a month has zero expression
plot_data_wide = plot_data.pivot(index=aggregation_level, columns=functional_info, values='TPM').fillna(0)
# log2 transform the data
plot_data_log = np.log2(plot_data_wide + 1)

# Using non-log2 transformed data gives more weight to more expressed genes; these have higher variances
# Using log2 transformed data gives more weight to less expressed genes; and counterselects for highly expressed genes

# Transpose the data so that PFAMs are rows and months are columns
print(plot_data_log.shape)
plot_data_log = plot_data_log.transpose()

# Re-order the columns according to month
month_order = ["July_2020", "August_2020", "September_2020", "November_2020", 
                "December_2020", "January_2021", "February_2021", "April_2021", 
                "May_2021", "June_2021", "July_2021"]

plot_data_log = plot_data_log.reindex(month_order, axis=1)

# Only retain the 15 most expressed PFAMs.
# The ranking has to use the unscaled values: after the row-wise z-scoring below every row
# sums to (numerically) zero, so ranking the scaled data would select arbitrary PFAMs.
# (The variable keeps its historical name although it holds the most expressed PFAMs.)
most_variable_PFAMs = plot_data_log.sum(axis=1).sort_values(ascending=False).head(number_of_PFAMs).index

# Scale the selected rows (z-score per PFAM across months)
plot_data_scaled = plot_data_log.loc[most_variable_PFAMs].apply(lambda x: (x - x.mean()) / x.std(), axis=1)

# In the index, include the PFAM clade and functional descriptions to the short names between brackets
# (the first entry of relevant_pfams is used when a short name occurs more than once)
pfam_annotation = relevant_pfams.drop_duplicates('short_name').set_index('short_name')
plot_data_scaled.index = plot_data_scaled.index.map(
    lambda x: "{} ({} - {})".format(x, pfam_annotation.at[x, 'clade'], pfam_annotation.at[x, 'function'])
)

# Keep the selection available under its annotated labels
most_variable_PFAMs = plot_data_scaled.index

# Heatmap of the selected PFAMs; the columns are not clustered and keep their month order
g = sns.clustermap(
    plot_data_scaled,
    figsize=(8, 6),
    cmap='mako',
    col_cluster=False,
    # Legend bar should be on the right
    cbar_kws={'orientation': 'vertical'},
    )

g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), fontsize=8)
g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_yticklabels(), fontsize=8)
g.ax_heatmap.set_facecolor('#f7f7f7')
g.ax_heatmap.set_xlabel(None)
g.ax_heatmap.set_ylabel(None)

# Save figure as svg. The file name carries number_of_PFAMs, the count that was actually used
# for the selection (number_of_genes is not defined in this section).
g.savefig("../../figures/diatoms_vs_dinoflagellates/diatoms_heatmap_{}_most_expressed_trophic_{}_per_{}.svg".format(number_of_PFAMs, functional_info, aggregation_level), format='svg', dpi=600)

plt.show()

Now, we'll do the same with the full feature set (all 1046 gene families used in the model described in the paper).

In [None]:
# Load dataset S3 of the paper and drop every row that has a missing value
PFAM_features = (
    pd.read_excel('../../data/external/pnas.2100916119.sd03.xlsx', skiprows=1, header=1)
    .dropna()
)

PFAM_features.head()

In [None]:
# Of the PFAM features, which are most expressed?
functional_info = 'PFAMs'
aggregation_level = 'month'

## Diatom transcripts with their month, PFAM annotation and TPM
data_diatoms_PFAM = data.loc[data['class'] == 'Bacillariophyta', [aggregation_level, functional_info, 'TPM']]

## Keep the PFAMs that are part of the feature set and sum their TPM per month
diatom_features = (
    data_diatoms_PFAM.loc[data_diatoms_PFAM[functional_info].isin(PFAM_features['Name'])]
    .groupby([aggregation_level, functional_info])['TPM']
    .sum()
    .reset_index()
)

# What are the most abundant PFAMs?
diatom_features.groupby(functional_info)['TPM'].sum().sort_values(ascending=False).head(10)

In [None]:
# Select the feature-set PFAMs whose relative expression varies most across months
number_of_PFAMs = 15

# Sum the TPM of every PFAM per month
plot_data = diatom_features.groupby([aggregation_level, functional_info])['TPM'].sum().reset_index()

# Divide the TPM sum by number of samples per month to obtain the average TPM per month
# (this per-month factor cancels again in the normalisation on the next line)
plot_data['TPM'] = plot_data['TPM'] / plot_data[aggregation_level].map(samples_per_month)

# Express every PFAM as a fraction of the month's total over the PFAMs retained in plot_data
plot_data['TPM'] = plot_data['TPM'] / plot_data.groupby(aggregation_level)['TPM'].transform('sum')

# Transform the data to the wide format; a PFAM that is absent in a month has zero expression
plot_data_wide = plot_data.pivot(index=aggregation_level, columns=functional_info, values='TPM').fillna(0)
# log2 transform the data
plot_data_log = np.log2(plot_data_wide + 1)

# Using non-log2 transformed data gives more weight to more expressed genes; these have higher variances
# Using log2 transformed data gives more weight to less expressed genes; and counterselects for highly expressed genes

# Transpose the data so that PFAMs are rows and months are columns
print(plot_data_log.shape)
plot_data_log = plot_data_log.transpose()

plot_data_log = plot_data_log.reindex(month_order, axis=1)

# Only retain the 15 most variable PFAMs.
# The ranking has to use the unscaled values: after the row-wise z-scoring below every row
# has a variance of exactly 1, so ranking the scaled data would select arbitrary PFAMs.
most_variable_PFAMs = plot_data_log.var(axis=1).sort_values(ascending=False).head(number_of_PFAMs).index

# Scale the selected rows (z-score per PFAM across months)
plot_data_scaled = plot_data_log.loc[most_variable_PFAMs].apply(lambda x: (x - x.mean()) / x.std(), axis=1)

In [None]:
# 
# Heatmap of the selected PFAMs; the columns are not clustered (col_cluster=False)
g = sns.clustermap(
    plot_data_scaled,
    figsize=(8, 6),
    cmap='mako',
    col_cluster=False,
    # Legend bar should be on the right
    cbar_kws={'orientation': 'vertical'},
    )

# Small tick labels, light background, no axis titles
g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), fontsize=8)
g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_yticklabels(), fontsize=8)
g.ax_heatmap.set_facecolor('#f7f7f7')
g.ax_heatmap.set_xlabel(None)
g.ax_heatmap.set_ylabel(None)

plt.show()

# Saving is currently disabled for this figure
# g.savefig("../../figures/diatoms_vs_dinoflagellates/diatoms_heatmap_{}_most_variable_trophic_{}_per_{}.svg".format(number_of_genes, functional_info, aggregation_level), format='svg', dpi=600)

In [None]:
# Now we'll do the same but with the 15 most highly expressed PFAMs
number_of_PFAMs = 15

# Sum the TPM of every PFAM per month
plot_data = diatom_features.groupby([aggregation_level, functional_info])['TPM'].sum().reset_index()

# Divide the TPM sum by number of samples per month to obtain the average TPM per month
# (this per-month factor cancels again in the normalisation on the next line)
plot_data['TPM'] = plot_data['TPM'] / plot_data[aggregation_level].map(samples_per_month)

# Express every PFAM as a fraction of the month's total over the PFAMs retained in plot_data
plot_data['TPM'] = plot_data['TPM'] / plot_data.groupby(aggregation_level)['TPM'].transform('sum')

# Transform the data to the wide format; a PFAM that is absent in a month has zero expression
plot_data_wide = plot_data.pivot(index=aggregation_level, columns=functional_info, values='TPM').fillna(0)
# log2 transform the data
plot_data_log = np.log2(plot_data_wide + 1)

# Using non-log2 transformed data gives more weight to more expressed genes; these have higher variances
# Using log2 transformed data gives more weight to less expressed genes; and counterselects for highly expressed genes

# Transpose the data so that PFAMs are rows and months are columns
print(plot_data_log.shape)
plot_data_log = plot_data_log.transpose()

plot_data_log = plot_data_log.reindex(month_order, axis=1)

# Only retain the 15 most expressed PFAMs.
# The ranking has to use the unscaled values: after the row-wise z-scoring below every row
# sums to (numerically) zero, so ranking the scaled data would select arbitrary PFAMs.
# (The variable keeps its historical name although it holds the most expressed PFAMs.)
most_variable_PFAMs = plot_data_log.sum(axis=1).sort_values(ascending=False).head(number_of_PFAMs).index

# Scale the selected rows (z-score per PFAM across months)
plot_data_scaled = plot_data_log.loc[most_variable_PFAMs].apply(lambda x: (x - x.mean()) / x.std(), axis=1)

# Heatmap of the selected PFAMs; the columns are not clustered and keep their month order
g = sns.clustermap(
    plot_data_scaled,
    figsize=(8, 6),
    cmap='mako',
    col_cluster=False,
    # Legend bar should be on the right
    cbar_kws={'orientation': 'vertical'},
    )

g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), fontsize=8)
g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_yticklabels(), fontsize=8)
g.ax_heatmap.set_facecolor('#f7f7f7')
g.ax_heatmap.set_xlabel(None)
g.ax_heatmap.set_ylabel(None)

plt.show()

# Saving is currently disabled for this figure
# g.savefig("../../figures/diatoms_vs_dinoflagellates/diatoms_heatmap_{}_most_expressed_trophic_{}_per_{}.svg".format(number_of_PFAMs, functional_info, aggregation_level), format='svg', dpi=600)

#### Trophic feeding mode of diatoms
I've predicted the trophic feeding mode of several species using the machine learning model of the paper by [Lambert & Groussman, 2022](https://doi.org/10.1073/pnas.2100916119), see [here](./trophic_mode_prediction.ipynb). We'll use the data to see how the trophic feeding mode of diatoms changes over time.

In [None]:
# Load the trophic mode predictions (EukProt-based file; the PhyloDB-based file is the alternative)
#trophic_predictions = pd.read_csv('../../data/analysis/phylodb_trophic_mode_predictions.csv')
trophic_predictions = pd.read_csv('../../data/analysis/eukprot_trophic_mode_predictions.csv')

# Keep the diatom rows only
#trophic_predictions = trophic_predictions[trophic_predictions['class'] == 'Bacillariophyta']
trophic_predictions = trophic_predictions.loc[trophic_predictions['Taxogroup2_UniEuk'] == 'Diatomeae']

# Number of retained predictions, followed by a preview
print(trophic_predictions.shape[0])
trophic_predictions.head()

In [None]:
# Plot the relative abundance of the three trophic modes per month
# Count the predictions of every trophic mode per month
prediction_counts = trophic_predictions.groupby(['month', 'prediction']).size()

# Convert the counts to percentages of the month's total.
# transform('sum') keeps the (month, prediction) index unchanged; groupby(...).apply(...) can
# prepend the group key as an extra 'month' level, which makes reset_index() fail.
rel_abundance = (
    (100 * prediction_counts / prediction_counts.groupby(level='month').transform('sum'))
    .reset_index(name='relative abundance')
)

# Set the color palette
color_map = {
    'Phot': 'green',
    'Mix': 'black',
    'Het': 'red'
}
fig = px.histogram(rel_abundance.sort_values("month", ascending=False),
             x = "relative abundance",
             y = "month",
             color = "prediction",
             orientation='h',
             color_discrete_map=color_map,
             category_orders={"month": ["July_2020", "August_2020", "September_2020",
                "November_2020", "December_2020", "January_2021",
                "February_2021", "April_2021", "May_2021",
                "June_2021", "July_2021"]})

# Conversion factor from cm to pixels
pixels_per_cm = 37.7952755906

fig.update_layout(
    font=dict(
        family="Times New Roman, serif",  # Set the font family to Times New Roman
        size=8  # Set the font size
    ),
    autosize=False,
    width= 8.5 * pixels_per_cm,
    height= 7.5 * pixels_per_cm,
    margin=dict( # Set the margins
        l=0,  # Left margin
        r=25,  # Right margin
        b=25,  # Bottom margin
        t=25  # Top margin
    ),
    xaxis_title_text='Ratio of predicted trophic modes (%)',
    yaxis_title_text=None,
)

fig.show()

# save figure as svg
#fig.write_image("../../figures/diatoms_vs_dinoflagellates/diatom_trophic_mode_relative_abundance_phylodb.svg")
fig.write_image("../../figures/diatoms_vs_dinoflagellates/diatom_trophic_mode_relative_abundance_eukprot.svg")

In [None]:
# Extract the relevant columns
# tax_level='genus'
tax_level='Genus_UniEuk'

# Work on an explicit copy so that the recoding below does not write into a slice of
# trophic_predictions (avoids pandas' SettingWithCopyWarning)
df = trophic_predictions[[tax_level, 'month', 'prediction']].copy()
# Convert 'prediction' into numerical categories
df['prediction'] = df['prediction'].map({'Phot':1, 'Mix':2, 'Het':3})


def _consensus_prediction(codes):
    """Return the most frequent non-zero prediction code of a group, or 0 when there is none.

    Ties are resolved in favour of the smallest code. pandas' Series.mode is used instead of
    indexing the result of scipy.stats.mode, whose return shape differs between SciPy versions.
    """
    codes = pd.Series(codes)
    modes = codes[codes != 0].dropna().mode()
    return int(modes.iloc[0]) if len(modes) else 0


# Pivot the DataFrame to use for heatmap: one consensus prediction per genus and month
df_pivot = df.pivot_table(index=tax_level, columns='month', values='prediction',
                          aggfunc=_consensus_prediction)

# Fill NA with a specific category (e.g., 0 for 'no prediction')
df_pivot = df_pivot.fillna(0)

# Ensure the values are integer
df_pivot = df_pivot.astype(int)

# Set the order of the months; months without any prediction get category 0
month_order = ["July_2020", "August_2020", "September_2020", "November_2020", 
                "December_2020", "January_2021", "February_2021", "April_2021", 
                "May_2021", "June_2021", "July_2021"]
df_pivot = df_pivot.reindex(month_order, axis=1, fill_value=0)

# Set the order of the genera to be the same as in the above relative abundance plot;
# genera without any prediction get category 0
df_pivot = df_pivot.reindex(diatom_legend_order, axis=0, fill_value=0)

# Create a color map (0 = no prediction, 1 = Phot, 2 = Mix, 3 = Het)
cmap = mcolors.ListedColormap(['lightgrey', 'green', 'yellow', 'red'])

# Font size should be 8, and font family Times New Roman.
# This is set before the figure is created so that it applies to this figure.
plt.rcParams.update({'font.size': 8, 'font.family': 'Times New Roman'})

# Create the heatmap with the color map
plt.figure(figsize=(4.5, 4))
sns.heatmap(df_pivot, cmap=cmap, annot=False, cbar=False, vmin=0, vmax=3)

# Remove the x and  y-axis label
plt.xlabel(None)
plt.ylabel(None)

# Change layout to match the relative abundance plot
plt.tight_layout()

# Save figure as svg
#plt.savefig("../../figures/diatoms_vs_dinoflagellates/diatom_trophic_mode_genus_consensus_heatmap_phylodb.svg", format='svg', dpi=600)
plt.savefig("../../figures/diatoms_vs_dinoflagellates/diatom_trophic_mode_genus_consensus_heatmap_eukprot.svg", format='svg', dpi=600)

plt.show()


#### PCA

For a PCA, we need a new matrix samples x genes. We'll also need a new metadata dataframe with the samples, months, stations, and environmental variables. 

In [None]:
# Specify how the matrix will look (in the end it is aggregation_level x functional_info)
functional_info = 'KEGG_ko'
aggregation_level = 'sample'

# Prepare gene expression data: keep only the sample, annotation and TPM columns of diatom transcripts
#data_diatoms = data[data['class'] == 'Bacillariophyta'][[aggregation_level, functional_info, 'TPM']]
data_diatoms = data[data['Taxogroup2_UniEuk'] == 'Diatomeae'][[aggregation_level, functional_info, 'TPM']]

# If processing KEGG data, extra preprocessing is required:
# a transcript can carry several comma-separated identifiers, so split them
# and give every identifier its own row (the TPM value is repeated on each row)
data_diatoms = data_diatoms.assign(**{functional_info: data_diatoms[functional_info].str.split(',')})
data_diatoms = data_diatoms.explode(functional_info)
# Keep only the part after the last ':' (drops a database prefix if present).
# Taking the last element works whether or not an identifier contains a ':',
# whereas expanding into columns and dropping column 0 breaks when the number
# of ':'-separated parts is not exactly two.
data_diatoms[functional_info] = data_diatoms[functional_info].str.split(":").str[-1]

# Group by functional information and sample values, sum TPM
data_diatoms = data_diatoms.groupby([functional_info, aggregation_level]).sum().reset_index()

# Remove the rows with no functional annotation
data_diatoms = data_diatoms[data_diatoms[functional_info].notna()]
data_diatoms = data_diatoms[data_diatoms[functional_info] != '-']

# Transform the data to the wide format (rows: samples, columns: functional annotations)
data_diatoms = data_diatoms.pivot(index=aggregation_level, columns=functional_info, values='TPM')

# Log2 transform the data (pseudocount of 1 so that zero expression maps to 0)
data_diatoms_log = np.log2(data_diatoms + 1)

# Scale the features (columns, TPM values of every prediction)
data_diatoms_scaled = pd.DataFrame(StandardScaler().fit_transform(data_diatoms_log),
                                   index=data_diatoms.index,
                                   columns=data_diatoms.columns)
## Scaling removes the mean and scales to unit variance, the resulting values are z-scores
# NOTE(review): the PCA below is fitted on data_diatoms_log, not on these z-scores — confirm this is intended

# Load the environmental variables, indexed by sample
env_variables = pd.read_csv("../../data/environmental/samples_env.csv", sep=";")
env_variables = env_variables.set_index('sample')

# Merge environmental variables with gene expression data (both frames are indexed by sample)
pca_data = data_diatoms_log.merge(env_variables, left_index=True, right_index=True, how='left')

# Perform PCA on gene expression data, ignore columns that are also in the environmental variables.
# An annotation that was not observed in a sample is NaN after the pivot; PCA does not
# accept NaN, so treat it as zero expression (log2(0 + 1) = 0).
pca = PCA(n_components=2)
expression_matrix = pca_data.drop(columns=env_variables.columns).fillna(0)
principalComponents = pca.fit_transform(expression_matrix)

# Add principal components to the data
pca_data['PC1'] = principalComponents[:, 0]
pca_data['PC2'] = principalComponents[:, 1]

# Create a color dictionary for the months
# NOTE(review): 'tab10' has 10 colours; with more than 10 months seaborn cycles the palette — confirm
month_color_dict = dict(zip(data['month'].unique(), sns.color_palette('tab10', n_colors=len(data['month'].unique()))))
month_order = ["July_2020", "August_2020", "September_2020", "November_2020", 
                "December_2020", "January_2021", "February_2021", "April_2021", 
                "May_2021", "June_2021", "July_2021"]

# Set figure size (given in cm, converted to inches) and font scale
cm = 1/2.54
plt.figure(figsize=(12*cm, 14*cm))
sns.set(style='white', font_scale=1)

# Plot PCA biplot with the color of the month corresponding to the month of the sample
# and the shape of the point corresponding to the station of the sample
sns.scatterplot(data=pca_data,
                x='PC1',
                y='PC2',
                hue='month',
                hue_order=month_order,
                style='station',
                palette=month_color_dict,
                s=40, edgecolor='black', linewidth=0.5)

plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0)
# Plot labels with the variance explained by each principal component
plt.xlabel('Principal Component 1 ({}%)'.format(round(pca.explained_variance_ratio_[0] * 100, 2)))
plt.ylabel('Principal Component 2 ({}%)'.format(round(pca.explained_variance_ratio_[1] * 100, 2)))

# Plot arrows indicating the correlation between the principal components and the environmental variables
columns_of_interest = ['NO3', 'PO4', 'Si', 'SPM', 'salinity', 'Temperature']

# Calculate the correlation matrix between the principal components and the environmental variables
# (samples with a missing value in any of these columns are left out)
corr_matrix = np.corrcoef(pca_data[['PC1', 'PC2'] + columns_of_interest].dropna().T)

# Get the correlation between the first two principal components and each environmental parameter
# (rows 0 and 1 are PC1 and PC2, columns 2 onwards are the environmental variables)
corr_PC1_env = corr_matrix[0, 2:]
corr_PC2_env = corr_matrix[1, 2:]

# Plot the arrows representing the correlations.
# Each correlation is multiplied by the largest score on that component, so the arrows
# are scaled to the extent of the point cloud for visualization.
pc1_scale = max(pca_data['PC1'])
pc2_scale = max(pca_data['PC2'])
for i, env_param in enumerate(columns_of_interest):
    arrow_x = corr_PC1_env[i] * pc1_scale
    arrow_y = corr_PC2_env[i] * pc2_scale
    plt.arrow(0, 0, arrow_x, arrow_y, head_width=0.05, color='gray')
    # The label is placed 5% further away from the origin than the arrow tip
    plt.text(arrow_x * 1.05, arrow_y * 1.05, env_param, fontsize=12, color='gray')

# Save figure as svg
# plt.savefig("../../figures/diatoms_vs_dinoflagellates/diatom_PCA_{}_per_{}_phylodb.svg".format(functional_info, aggregation_level), format='svg', dpi=600)
plt.savefig("../../figures/diatoms_vs_dinoflagellates/diatom_PCA_{}_per_{}_eukprot.svg".format(functional_info, aggregation_level), format='svg', dpi=600)

plt.show()

In [None]:
# Write the diatom expression table to disk as csv
# (PhyloDB variant of the file: '../../data/analysis/kegg_expression_diatoms_phylodb.csv')
diatom_expression_csv = '../../data/analysis/kegg_expression_diatoms_eukprot.csv'
data_diatoms.to_csv(diatom_expression_csv)

## 3. Dinoflagellates

In [None]:
# List the unique names of all dinoflagellate hits with a p_ident of at least 0.98
# (PhyloDB variant: data['class'] == 'Dinophyceae', reading the 'species' column)
is_dinoflagellate = data['Taxogroup2_UniEuk'].isin(['Dinophyceae', 'core-Noctilucales'])
is_high_identity = data['p_ident'] >= 0.98
data.loc[is_dinoflagellate & is_high_identity, 'Name_to_Use'].unique()

### 3.1 Dinoflagellate abundance
For dinoflagellates, we'll do the same analysis as for diatoms.

In [None]:
# Taxonomic rank to summarise and the unit to aggregate over
# (the PhyloDB annotation used tax_level = 'genus')
tax_level = 'Genus_UniEuk'
aggregation_level = 'month'

# Select dinoflagellate transcripts with a p_ident of at least 0.9
# (the PhyloDB annotation selected data['class'] == 'Dinophyceae' instead)
is_dinoflagellate = data['Taxogroup2_UniEuk'].isin(['Dinophyceae', 'core-Noctilucales'])
is_confident_hit = data['p_ident'] >= 0.9
data_dinoflagellates = data.loc[is_dinoflagellate & is_confident_hit, [aggregation_level, tax_level, 'TPM']]

# Drop transcripts at or below the TPM threshold, then sum TPM per month and taxon
data_dinoflagellates = (
    data_dinoflagellates[data_dinoflagellates['TPM'] > 1]
    .groupby([aggregation_level, tax_level])
    .sum()
    .reset_index()
)

# Make month an ordered categorical so that it sorts chronologically
sampling_months = ["July_2020", "August_2020", "September_2020",
                   "November_2020", "December_2020", "January_2021",
                   "February_2021", "April_2021", "May_2021",
                   "June_2021", "July_2021"]
data_dinoflagellates['month'] = pd.Categorical(data_dinoflagellates['month'], sampling_months)

# Express each taxon's TPM as a fraction of the summed TPM of its month
monthly_total_tpm = data_dinoflagellates.groupby('month')['TPM'].transform('sum')
data_dinoflagellates["rel_expression_per_month"] = data_dinoflagellates['TPM'] / monthly_total_tpm

# Relabel taxa that make up at most 2% of their month as 'Rare'
rare_groups = data_dinoflagellates['rel_expression_per_month'] <= 0.02
data_dinoflagellates.loc[rare_groups, tax_level] = 'Rare'

# Show which taxon labels remain, then preview the table
print(data_dinoflagellates[tax_level].unique())
data_dinoflagellates.head()

In [None]:
# Conversion factor used to express the figure size in cm (96 px per inch / 2.54 cm per inch)
pixels_per_cm = 37.79527559055118

# Fixed colour per genus, so that a genus keeps its colour across figures
genus_colours = {
    "Rare": "#545454",
    "Alexandrium": "#ebac23",
    "Amphidinium": "#b80058",
    "Noctiluca": "#008cf9",
    "Tripos": "#006e00",
    "Kryptoperidinium": "#00bbad",
    "Oxyrrhis": "#d163e6",
    "Karenia": "#b24502",
    "Symbiodinium": "#ff9287",
    "Ceratium": "#5954d6",
    "Durinskia": "#00c6f8",
    "Heterocapsa": "#878500",
    "Scripsiella": "#00a76c",
    'Polarella': "#FC8D62",
    'Gonyaulax': "#8DA0CB",
    'Azadinium': "#E78AC3",
    'Lingulodinium': "#A6D854",
    'Karlodinium': "#B3B3CC",
}

# Chronological order of the months on the axis
month_axis_order = ["July_2020", "August_2020", "September_2020",
                    "November_2020", "December_2020", "January_2021",
                    "February_2021", "April_2021", "May_2021",
                    "June_2021", "July_2021"]

# Horizontal stacked bars: one bar per month, one segment per genus
fig = px.histogram(
    data_dinoflagellates.sort_values("month", ascending=False),
    x="rel_expression_per_month",
    y="month",
    color=tax_level,
    orientation='h',
    color_discrete_map=genus_colours,
    category_orders={"month": month_axis_order,
                     "Genus_UniEuk": ["Rare", "Tripos", "Noctiluca"]},
)

fig.update_layout(
    # Times New Roman, 8 pt, black
    font=dict(family="Times New Roman, serif", size=8, color="#000000"),
    # Fixed figure size of 8.5 cm x 7.5 cm
    autosize=False,
    width=8.5 * pixels_per_cm,
    height=7.5 * pixels_per_cm,
    # Left, right, bottom and top margins
    margin=dict(l=0, r=25, b=25, t=25),
    xaxis_title_text='% TPM of total sum',
    yaxis_title_text=None,
)

fig.show()

# Save the figure as png and svg
# (the PhyloDB variant of these files ends in '_phylodb' instead of '_eukprot')
for extension in ("png", "svg"):
    fig.write_image(f"../../figures/diatoms_vs_dinoflagellates/dinoflagellates_relative_expression_per_month_{tax_level}_eukprot.{extension}", scale=1)

In [None]:
# Remember the trace names of the figure, in the order in which they were drawn,
# so that the legend order can be reused for other figures
dinoflagellate_legend_order = [trace.name for trace in fig.data]
print(dinoflagellate_legend_order)

#### Spatial distribution

In [None]:
# Relative dinoflagellate expression per month, plotted separately for every station
#tax_level = 'genus'
tax_level = 'Genus_UniEuk'
aggregation_level = 'month'
aggregation_level2 = 'station'

# Conversion factor used to express the figure size in cm (96 px per inch / 2.54 cm per inch).
# Defined here as well so this cell does not depend on an earlier plotting cell.
pixels_per_cm = 37.79527559055118

#data_dinoflagellates = data[(data['class'] == 'Dinophyceae') & (data['p_ident'] >= 0.9)][[aggregation_level, aggregation_level2, tax_level, 'TPM']]
data_dinoflagellates = data[(data['Taxogroup2_UniEuk'].isin(['Dinophyceae', 'core-Noctilucales'])) & (data['p_ident'] >= 0.9)][[aggregation_level, aggregation_level2, tax_level, 'TPM']]

# Remove transcripts below a certain TPM threshold
data_dinoflagellates = data_dinoflagellates[data_dinoflagellates['TPM'] > 1]

# Group by month, station and species, sum TPM
data_dinoflagellates = data_dinoflagellates.groupby([aggregation_level, aggregation_level2, tax_level]).sum().reset_index()

month_order = ["July_2020", "August_2020", "September_2020",
               "November_2020", "December_2020", "January_2021",
               "February_2021", "April_2021", "May_2021",
               "June_2021", "July_2021"]
data_dinoflagellates['month'] = pd.Categorical(data_dinoflagellates['month'], month_order)

# Fixed colour per genus, so that a genus keeps its colour across figures
genus_colours = {
    "Rare": "#545454",
    "Alexandrium": "#ebac23",
    "Amphidinium": "#b80058",
    "Noctiluca": "#008cf9",
    "Tripos": "#006e00",
    "Kryptoperidinium": "#00bbad",
    "Oxyrrhis": "#d163e6",
    "Karenia": "#b24502",
    "Symbiodinium": "#ff9287",
    "Ceratium": "#5954d6",
    "Durinskia": "#00c6f8",
    "Heterocapsa": "#878500",
    "Scripsiella": "#00a76c",
    'Polarella': "#FC8D62",
    'Gonyaulax': "#8DA0CB",
    'Azadinium': "#E78AC3",
    'Lingulodinium': "#A6D854",
    'Karlodinium': "#B3B3CC",
}

# Collects every genus label that ends up in at least one station plot
dinoflagellate_genera = []
# Plot per station
for station in data_dinoflagellates[aggregation_level2].unique():
    # Work on a copy: the rows are modified below, and modifying a slice of
    # data_dinoflagellates would trigger pandas' SettingWithCopyWarning
    plot_data = data_dinoflagellates[data_dinoflagellates[aggregation_level2] == station].copy()
    # Express each taxon's TPM as a fraction of the summed TPM of its month at this station
    plot_data["rel_expression_per_month"] = plot_data.TPM / plot_data.groupby('month').TPM.transform('sum')

    # Combine low-abundant groups (at most 2% of the month at this station)
    rare_groups = plot_data['rel_expression_per_month'] <= 0.02
    plot_data.loc[rare_groups, tax_level] = 'Rare'
    # Add all unique genera to the list if not already present
    for genus in plot_data[tax_level].unique():
        if genus not in dinoflagellate_genera:
            dinoflagellate_genera.append(genus)
    fig = px.histogram(plot_data.sort_values("month", ascending=False), 
                    x="rel_expression_per_month", 
                    y="month", 
                    color=tax_level,
                    orientation='h',
                    # text_auto='.2f',
                    color_discrete_map=genus_colours,
                    # Specify all the months that need to be included, 
                    # even if no sample has been taken
                    # NOTE(review): the "genus" key is the PhyloDB column name; it does not match
                    # tax_level = 'Genus_UniEuk', so it presumably has no effect here — confirm
                    category_orders={"month": month_order,
                                     "genus": ["Rare", "Alexandrium", "Amphidinium", "Noctiluca", 
                                               "Symbiodinium", "Ceratium", "Tripos", "Kryptoperidinium", 
                                               "Oxyrrhis", "Karenia", "Durinskia", "Heterocapsa", "Scripsiella"],
                                     #"Genus_UniEuk": ["Rare", "Noctiluca", "Tripos", "Paragymnodinium", "Pelagodinium", "Polykrikos"]
                                     }
                    )

    fig.update_layout(
        font=dict(
            family="Times New Roman, serif",  # Set the font family to Times New Roman
            size=8,  # Set the font size
            color="#000000"  # Set the font color
        ),
        autosize=False,
        width= 8.5 * pixels_per_cm,
        height= 7.5 * pixels_per_cm,
        margin=dict( # Set the margins
            l=0,  # Left margin
            r=25,  # Right margin
            b=25,  # Bottom margin
            t=25  # Top margin
        ),
        xaxis_title_text='% TPM of total sum',
        yaxis_title_text=None,
        title_text=station
    )

    fig.show()

    # Save figure as png and svg
    #fig.write_image("../../figures/diatoms_vs_dinoflagellates/dinoflagellates_relative_expression_per_month_{}_at_{}_phylodb.png".format(tax_level, station), scale=1)
    #fig.write_image("../../figures/diatoms_vs_dinoflagellates/dinoflagellates_relative_expression_per_month_{}_at_{}_phylodb.svg".format(tax_level, station), scale=1)
    fig.write_image("../../figures/diatoms_vs_dinoflagellates/dinoflagellates_relative_expression_per_month_{}_at_{}_eukprot.png".format(tax_level, station), scale=1)
    fig.write_image("../../figures/diatoms_vs_dinoflagellates/dinoflagellates_relative_expression_per_month_{}_at_{}_eukprot.svg".format(tax_level, station), scale=1)

In [None]:
# Display every genus label (including 'Rare') that was collected while plotting the stations
dinoflagellate_genera
# Now create colours for all of these genera (add them to the colour map) and rerun the above code

#### Total TPM of dinoflagellates per month

In [None]:
# Total dinoflagellate TPM per month, stacked by genus
#tax_level = 'genus'
tax_level = 'Genus_UniEuk'
aggregation_level = 'month'

#data_dinoflagellates = data[(data['class'] == 'Dinophyceae') & (data['p_ident'] >= 0.9)][[aggregation_level, tax_level, 'TPM']]
data_dinoflagellates = data[(data['Taxogroup2_UniEuk'].isin(['Dinophyceae', 'core-Noctilucales'])) & (data['p_ident'] >= 0.9)][[aggregation_level, tax_level, 'TPM']]

# Remove transcripts below a certain TPM threshold
data_dinoflagellates = data_dinoflagellates[data_dinoflagellates['TPM'] > 1]

# Group by month and species, sum TPM
data_dinoflagellates = data_dinoflagellates.groupby([aggregation_level, tax_level]).sum().reset_index()

month_order = ["July_2020", "August_2020", "September_2020",
               "November_2020", "December_2020", "January_2021",
               "February_2021", "April_2021", "May_2021",
               "June_2021", "July_2021"]
data_dinoflagellates['month'] = pd.Categorical(data_dinoflagellates['month'], month_order)

# Express each taxon's TPM as a fraction of the summed TPM of its month
data_dinoflagellates["rel_expression_per_month"] = data_dinoflagellates.TPM / data_dinoflagellates.groupby('month').TPM.transform('sum')

# Combine low-abundant groups (at most 2% of their month)
rare_groups = data_dinoflagellates['rel_expression_per_month'] <= 0.02
data_dinoflagellates.loc[rare_groups, tax_level] = 'Rare'

# Fixed colour per genus, so that a genus keeps its colour across figures
genus_colours = {
    "Rare": "#545454",
    "Alexandrium": "#ebac23",
    "Amphidinium": "#b80058",
    "Noctiluca": "#008cf9",
    "Tripos": "#006e00",
    "Kryptoperidinium": "#00bbad",
    "Oxyrrhis": "#d163e6",
    "Karenia": "#b24502",
    "Symbiodinium": "#ff9287",
    "Ceratium": "#5954d6",
    "Durinskia": "#00c6f8",
    "Heterocapsa": "#878500",
    "Scripsiella": "#00a76c",
    'Polarella': "#FC8D62",
    'Gonyaulax': "#8DA0CB",
    'Azadinium': "#E78AC3",
    'Lingulodinium': "#A6D854",
    'Karlodinium': "#B3B3CC",
}
#color_discrete_map={"Rare": "#545454", "Noctiluca": "#008cf9", "Tripos": "#006e00",
#                         "Paragymnodinium": "#9932CC", "Pelagodinium": "#40E0D0",
#                         "Polykrikos": "#FFA500"},

# Plot the total TPM per month
pixels_per_cm = 37.79527559055118
fig = px.histogram(data_dinoflagellates.sort_values("month", ascending=False),
            x = "TPM",
            y = "month",
            color = tax_level,
            color_discrete_map=genus_colours,
            orientation='h',
            # The genus order must be keyed by the column that is actually used for the colour
            # (tax_level); the former "genus" key did not match 'Genus_UniEuk', so the stored
            # legend order of the relative abundance plot was never applied.
            category_orders={"month": month_order,
                             tax_level: dinoflagellate_legend_order}
                #"Genus_UniEuk": ["Rare", "Tripos", "Noctiluca"]},
            )

fig.update_layout(
    font=dict(
        family="Times New Roman, serif",  # Set the font family to Times New Roman
        size=8,  # Set the font size
        color="#000000"  # Set the font color
    ),
    autosize=False,
    width= 8.5 * pixels_per_cm,
    height= 7.5 * pixels_per_cm,
    margin=dict( # Set the margins
        l=0,  # Left margin
        r=25,  # Right margin
        b=25,  # Bottom margin
        t=25  # Top margin
    ),
    xaxis_title_text='Total dinoflagellate TPM per month',
    yaxis_title_text=None,
    # Set range of x-axis
    xaxis_range=[0, 550000],
)

fig.show()

# save figure as svg
#fig.write_image(f"../../figures/diatoms_vs_dinoflagellates/dinoflagellates_total_TPM_per_month_{tax_level}_phylodb.svg", scale=1)
fig.write_image(f"../../figures/diatoms_vs_dinoflagellates/dinoflagellates_total_TPM_per_month_{tax_level}_eukprot.svg", scale=1)

In [None]:
# Total dinoflagellate TPM per month, with each bar split by station
# (the PhyloDB annotation used tax_level = 'genus')
tax_level = 'Genus_UniEuk'
aggregation_level = 'month'
aggregation_level2 = 'station'

# Select dinoflagellate transcripts with a p_ident of at least 0.9
# (the PhyloDB annotation selected data['class'] == 'Dinophyceae' instead)
is_dinoflagellate = data['Taxogroup2_UniEuk'].isin(['Dinophyceae', 'core-Noctilucales'])
is_confident_hit = data['p_ident'] >= 0.9
data_dinoflagellates = data.loc[is_dinoflagellate & is_confident_hit,
                                [aggregation_level, aggregation_level2, tax_level, 'TPM']]

# Drop transcripts at or below the TPM threshold, then sum TPM per month, station and taxon
data_dinoflagellates = (
    data_dinoflagellates[data_dinoflagellates['TPM'] > 1]
    .groupby([aggregation_level, aggregation_level2, tax_level])
    .sum()
    .reset_index()
)

# Make month an ordered categorical so that it sorts chronologically
sampling_months = ["July_2020", "August_2020", "September_2020",
                   "November_2020", "December_2020", "January_2021",
                   "February_2021", "April_2021", "May_2021",
                   "June_2021", "July_2021"]
data_dinoflagellates['month'] = pd.Categorical(data_dinoflagellates['month'], sampling_months)

# Express each row's TPM as a fraction of the summed TPM of its month (all stations together)
monthly_total_tpm = data_dinoflagellates.groupby('month')['TPM'].transform('sum')
data_dinoflagellates["rel_expression_per_month"] = data_dinoflagellates['TPM'] / monthly_total_tpm

# Relabel rows that make up at most 2% of their month as 'Rare'
# NOTE(review): the plot below is coloured by station, so this relabelling does not
# show up in the figure; it does change data_dinoflagellates for later cells
rare_groups = data_dinoflagellates['rel_expression_per_month'] <= 0.02
data_dinoflagellates.loc[rare_groups, tax_level] = 'Rare'

# Fixed colour per station
station_colours = {
    "ZG02": "#8c613c",
    "120": "#956cb4",
    "330": "#ee854a",
    "130": "#4878d0",
    "780": "#d65f5f",
    "700": "#6acc64",
}

fig = px.histogram(
    data_dinoflagellates.sort_values("month", ascending=False),
    x="TPM",
    y="month",
    color='station',
    color_discrete_map=station_colours,
    orientation='h',
    category_orders={"month": sampling_months},
)

fig.update_layout(
    # Times New Roman, 8 pt, black
    font=dict(family="Times New Roman, serif", size=8, color="#000000"),
    # Fixed figure size of 8.5 cm x 7.5 cm
    autosize=False,
    width=8.5 * pixels_per_cm,
    height=7.5 * pixels_per_cm,
    # Left, right, bottom and top margins
    margin=dict(l=0, r=25, b=25, t=25),
    xaxis_title_text='Total dinoflagellate TPM per month',
    yaxis_title_text=None,
)

fig.show()

# Save the figure as svg
# (the PhyloDB variant of this file ends in '_phylodb.svg')
fig.write_image("../../figures/diatoms_vs_dinoflagellates/dinoflagellates_total_TPM_per_month_station_eukprot.svg", scale=1)

#### Spatial distribution of total TPM

In [None]:
# Total dinoflagellate TPM per month, plotted separately for every station
# Define the taxonomic level and aggregation levels
#tax_level = 'genus'
tax_level = 'Genus_UniEuk'
aggregation_level = 'month'
aggregation_level2 = 'station'

# Filter the data for dinoflagellates.
# The selection is done here (rather than reusing data_dinoflagellates from the previous cell)
# because that frame is already aggregated and already has its low-abundant taxa relabelled
# as 'Rare'; reusing it would make this cell's result depend on which cell ran before it.
#data_dinoflagellates = data[(data['class'] == 'Dinophyceae') & (data['p_ident'] >= 0.9)][[aggregation_level, aggregation_level2, tax_level, 'TPM']]
data_dinoflagellates = data[(data['Taxogroup2_UniEuk'].isin(['Dinophyceae', 'core-Noctilucales'])) & (data['p_ident'] >= 0.9)][[aggregation_level, aggregation_level2, tax_level, 'TPM']]

# Remove transcripts below a certain TPM threshold
data_dinoflagellates = data_dinoflagellates[data_dinoflagellates['TPM'] > 1]

# Group by month, station, and species, sum TPM
data_dinoflagellates = data_dinoflagellates.groupby([aggregation_level, aggregation_level2, tax_level]).sum().reset_index()

# Specify the desired order of months
month_order = ["July_2020", "August_2020", "September_2020", "November_2020", "December_2020", "January_2021",
               "February_2021", "April_2021", "May_2021", "June_2021", "July_2021"]

# Fixed colour per genus, so that a genus keeps its colour across figures
genus_colours = {
    "Rare": "#545454",
    "Alexandrium": "#ebac23",
    "Amphidinium": "#b80058",
    "Noctiluca": "#008cf9",
    "Tripos": "#006e00",
    "Kryptoperidinium": "#00bbad",
    "Oxyrrhis": "#d163e6",
    "Karenia": "#b24502",
    "Symbiodinium": "#ff9287",
    "Ceratium": "#5954d6",
    "Durinskia": "#00c6f8",
    "Heterocapsa": "#878500",
    "Scripsiella": "#00a76c",
    'Polarella': "#FC8D62",
    'Gonyaulax': "#8DA0CB",
    'Azadinium': "#E78AC3",
    'Lingulodinium': "#A6D854",
    'Karlodinium': "#B3B3CC",
}

# Plot total TPM per station
pixels_per_cm = 37.79527559055118
for station in data_dinoflagellates[aggregation_level2].unique():
    # Work on a copy: the rows are modified below, and modifying a slice of
    # data_dinoflagellates would trigger pandas' SettingWithCopyWarning
    plot_data = data_dinoflagellates[data_dinoflagellates[aggregation_level2] == station].copy()
    # Express each taxon's TPM as a fraction of the summed TPM of its month at this station
    plot_data["rel_expression_per_month"] = plot_data.TPM / plot_data.groupby(aggregation_level).TPM.transform('sum')

    # Combine low-abundant groups (at most 2% of the month at this station)
    rare_groups = plot_data['rel_expression_per_month'] <= 0.02
    plot_data.loc[rare_groups, tax_level] = 'Rare'
    
    # Horizontal stacked bars: TPM on the x-axis, one bar per month on the y-axis
    fig = px.histogram(plot_data, x='TPM', y=aggregation_level, color=tax_level,
            orientation='h',
            color_discrete_map=genus_colours,
            category_orders={aggregation_level: month_order},
            title=station)
    
    fig.update_layout(
        font=dict(
            family="Times New Roman, serif",
            size=8,
            color="#000000"
        ),
        autosize=False,
        width=8.5 * pixels_per_cm,
        height=7.5 * pixels_per_cm,
        margin=dict(
            l=0,
            r=25,
            b=25,
            t=25
        ),
        # TPM is on the x-axis and the months are on the y-axis
        xaxis_title_text='Total dinoflagellate TPM',
        yaxis_title_text='Month',
        legend_title_text=tax_level,
        # Set range of x-axis
        xaxis_range=[0, 250000]
    )
    
    fig.show()

    # Save figure as PNG and SVG
    #fig.write_image("../../figures/diatoms_vs_dinoflagellates/dinoflagellates_total_TPM_{}_at_{}_phylodb.png".format(tax_level, station), scale=1)
    #fig.write_image("../../figures/diatoms_vs_dinoflagellates/dinoflagellates_total_TPM_{}_at_{}_phylodb.svg".format(tax_level, station), scale=1)
    fig.write_image("../../figures/diatoms_vs_dinoflagellates/dinoflagellates_total_TPM_{}_at_{}_eukprot.png".format(tax_level, station), scale=1)
    fig.write_image("../../figures/diatoms_vs_dinoflagellates/dinoflagellates_total_TPM_{}_at_{}_eukprot.svg".format(tax_level, station), scale=1)

### TPL

In [None]:
# Relative dinoflagellate TPL per month, plotted separately for every station
# Define the taxonomic level and aggregation levels
#tax_level = 'genus'
tax_level = 'Genus_UniEuk'
aggregation_level = 'month'
aggregation_level2 = 'station'

# Filter the data for dinoflagellates
#data_dinoflagellates = data[(data['class'] == 'Dinophyceae') & (data['p_ident'] >= 0.9)][[aggregation_level, aggregation_level2, tax_level, 'TPL']]
data_dinoflagellates = data[(data['Taxogroup2_UniEuk'].isin(['Dinophyceae', 'core-Noctilucales'])) & (data['p_ident'] >= 0.9)][[aggregation_level, aggregation_level2, tax_level, 'TPL']]

# Remove transcripts below a certain TPL threshold
data_dinoflagellates = data_dinoflagellates[data_dinoflagellates['TPL'] > 1]

# Group by month, station, and species, sum TPL
data_dinoflagellates = data_dinoflagellates.groupby([aggregation_level, aggregation_level2, tax_level]).sum().reset_index()

# Specify the desired order of months
month_order = ["July_2020", "August_2020", "September_2020", "November_2020", "December_2020", "January_2021",
               "February_2021", "April_2021", "May_2021", "June_2021", "July_2021"]

# Fixed colour per genus, so that a genus keeps its colour across figures
genus_colours = {
    "Rare": "#545454",
    "Alexandrium": "#ebac23",
    "Amphidinium": "#b80058",
    "Noctiluca": "#008cf9",
    "Tripos": "#006e00",
    "Kryptoperidinium": "#00bbad",
    "Oxyrrhis": "#d163e6",
    "Karenia": "#b24502",
    "Symbiodinium": "#ff9287",
    "Ceratium": "#5954d6",
    "Durinskia": "#00c6f8",
    "Heterocapsa": "#878500",
    "Scripsiella": "#00a76c",
    'Polarella': "#FC8D62",
    'Gonyaulax': "#8DA0CB",
    'Azadinium': "#E78AC3",
    'Lingulodinium': "#A6D854",
    'Karlodinium': "#B3B3CC",
}

# Plot relative TPL per station
pixels_per_cm = 37.79527559055118
for station in data_dinoflagellates[aggregation_level2].unique():
    # Work on a copy: the rows are modified below, and modifying a slice of
    # data_dinoflagellates would trigger pandas' SettingWithCopyWarning
    plot_data = data_dinoflagellates[data_dinoflagellates[aggregation_level2] == station].copy()
    # Express each taxon's TPL as a fraction of the summed TPL of its month at this station
    plot_data["rel_expression_per_month"] = plot_data.TPL / plot_data.groupby(aggregation_level).TPL.transform('sum')

    # Combine low-abundant groups (at most 2% of the month at this station)
    rare_groups = plot_data['rel_expression_per_month'] <= 0.02
    plot_data.loc[rare_groups, tax_level] = 'Rare'
    
    # Horizontal stacked bars: relative TPL on the x-axis, one bar per month on the y-axis
    fig = px.histogram(plot_data, x='rel_expression_per_month', y=aggregation_level, color=tax_level,
            orientation='h',
            color_discrete_map=genus_colours,
            category_orders={aggregation_level: month_order,
                             tax_level: ['Rare', 'Noctiluca', 'Alexandrium', 'Durinskia', 'Kryptoperidinium', 'Gymnodinium', 'Pelagodinium', 'Polykrikos', 'Tripos']},
            title=station)
    
    fig.update_layout(
        font=dict(
            family="Times New Roman, serif",
            size=8,
            color="#000000"
        ),
        autosize=False,
        width=8.5 * pixels_per_cm,
        height=7.5 * pixels_per_cm,
        margin=dict(
            l=0,
            r=25,
            b=25,
            t=25
        ),
        # Relative TPL is on the x-axis and the months are on the y-axis
        xaxis_title_text='Relative dinoflagellate TPL',
        yaxis_title_text='Month',
        legend_title_text=tax_level
    )
    
    fig.show()

    # Save figure as PNG and SVG
    #fig.write_image("../../figures/diatoms_vs_dinoflagellates/dinoflagellates_relative_TPL_{}_at_{}_phylodb.png".format(tax_level, station), scale=1)
    #fig.write_image("../../figures/diatoms_vs_dinoflagellates/dinoflagellates_relative_TPL_{}_at_{}_phylodb.svg".format(tax_level, station), scale=1)
    fig.write_image("../../figures/diatoms_vs_dinoflagellates/dinoflagellates_relative_TPL_{}_at_{}_eukprot.png".format(tax_level, station), scale=1)
    fig.write_image("../../figures/diatoms_vs_dinoflagellates/dinoflagellates_relative_TPL_{}_at_{}_eukprot.svg".format(tax_level, station), scale=1)

In [None]:
# Plot the summed dinoflagellate TPL per genus and month, one stacked
# horizontal bar chart per station.

# Define the taxonomic level and aggregation levels
#tax_level = 'genus'
tax_level = 'Genus_UniEuk'
aggregation_level = 'month'
aggregation_level2 = 'station'

# Filter the data for dinoflagellates (>= 90 % identity hits only)
#data_dinoflagellates = data[(data['class'] == 'Dinophyceae') & (data['p_ident'] >= 0.9)][[aggregation_level, aggregation_level2, tax_level, 'TPL']]
data_dinoflagellates = data[(data['Taxogroup2_UniEuk'].isin(['Dinophyceae', 'core-Noctilucales'])) & (data['p_ident'] >= 0.9)][[aggregation_level, aggregation_level2, tax_level, 'TPL']]

# Remove transcripts below a certain TPL threshold
data_dinoflagellates = data_dinoflagellates[data_dinoflagellates['TPL'] > 1]

# Group by month, station, and genus; sum TPL
data_dinoflagellates = data_dinoflagellates.groupby([aggregation_level, aggregation_level2, tax_level]).sum().reset_index()

# Specify the desired order of months
month_order = ["July_2020", "August_2020", "September_2020", "November_2020", "December_2020", "January_2021",
               "February_2021", "April_2021", "May_2021", "June_2021", "July_2021"]

# Conversion factor from cm to pixels (96 px per inch / 2.54 cm per inch)
pixels_per_cm = 37.79527559055118
for station in data_dinoflagellates[aggregation_level2].unique():
    # Work on an explicit copy: columns are added/modified below, which would
    # otherwise trigger SettingWithCopyWarning on a slice of the grouped frame.
    plot_data = data_dinoflagellates[data_dinoflagellates[aggregation_level2] == station].copy()
    # Share of each genus in the station's total TPL of that month
    plot_data["rel_expression_per_month"] = plot_data.TPL / plot_data.groupby(aggregation_level).TPL.transform('sum')

    # Combine low-abundant groups (<= 2 % of the monthly total) into 'Rare'
    rare_groups = plot_data['rel_expression_per_month'] <= 0.02
    plot_data.loc[rare_groups, tax_level] = 'Rare'

    # x is numeric (TPL) and y is categorical (month), so plotly express draws
    # horizontal bars: the x-axis carries the summed TPL, the y-axis the months.
    fig = px.histogram(plot_data, x='TPL', y=aggregation_level, color=tax_level,
            color_discrete_map={
                "Rare": "#545454",
                "Alexandrium": "#ebac23",
                "Amphidinium": "#b80058",
                "Noctiluca": "#008cf9",
                "Tripos": "#006e00",
                "Kryptoperidinium": "#00bbad",
                "Oxyrrhis": "#d163e6",
                "Karenia": "#b24502",
                "Symbiodinium": "#ff9287",
                "Ceratium": "#5954d6",
                "Durinskia": "#00c6f8",
                "Heterocapsa": "#878500",
                # NOTE(review): this genus is usually spelled "Scrippsiella" — confirm against the annotation
                "Scripsiella": "#00a76c",
                'Polarella': "#FC8D62",
                'Gonyaulax': "#8DA0CB",
                'Azadinium': "#E78AC3",
                'Lingulodinium': "#A6D854",
                'Karlodinium': "#B3B3CC"
            },
            category_orders={aggregation_level: month_order,
                            tax_level: ['Rare', 'Noctiluca', 'Alexandrium', 'Durinskia', 'Kryptoperidinium', 'Gymnodinium', 'Pelagodinium', 'Polykrikos', 'Tripos']},
            title=station)

    fig.update_layout(
        font=dict(
            family="Times New Roman, serif",
            size=8,
            color="#000000"
        ),
        autosize=False,
        width=8.5 * pixels_per_cm,
        height=7.5 * pixels_per_cm,
        margin=dict(
            l=0,
            r=25,
            b=25,
            t=25
        ),
        # Axis titles follow the horizontal orientation (values on x, months on y)
        xaxis_title_text='Total dinoflagellate TPL',
        yaxis_title_text='Month',
        legend_title_text=tax_level,
        # Fixed x-axis range so that the stations can be compared visually
        xaxis_range=[0, 700000000]
    )

    fig.show()

    # Save figure as PNG and SVG
    #fig.write_image("../../figures/diatoms_vs_dinoflagellates/dinoflagellates_total_TPL_{}_at_{}_phylodb.png".format(tax_level, station), scale=1)
    #fig.write_image("../../figures/diatoms_vs_dinoflagellates/dinoflagellates_total_TPL_{}_at_{}_phylodb.svg".format(tax_level, station), scale=1)
    fig.write_image("../../figures/diatoms_vs_dinoflagellates/dinoflagellates_total_TPL_{}_at_{}_eukprot.png".format(tax_level, station), scale=1)
    fig.write_image("../../figures/diatoms_vs_dinoflagellates/dinoflagellates_total_TPL_{}_at_{}_eukprot.svg".format(tax_level, station), scale=1)

### 3.2 Biodiversity indices
Next, we'll build sample × species tables and use them to compute biodiversity indices per sample. We'll then visualize how these indices change over time (per month).

In [None]:
# Build a taxon x sample table holding, per cell, the number of
# dinoflagellate transcripts expressed above 1 TPM.
# tax_level = 'genus'
tax_level = 'Name_to_Use'
aggregation_level = 'sample'

# Only hits with >= 90 % sequence identity are trusted at this taxonomic resolution
is_dinoflagellate = data['Taxogroup2_UniEuk'].isin(['Dinophyceae', 'core-Noctilucales'])
is_confident_hit = data['p_ident'] >= 0.9

data_dinoflagellates = (
    data.loc[is_dinoflagellate & is_confident_hit, [aggregation_level, tax_level, 'TPM']]
    # Drop transcripts at or below the TPM threshold
    .loc[lambda df: df['TPM'] > 1]
    # Count (not sum) the remaining transcripts per sample and taxonomic bin
    .groupby([aggregation_level, tax_level])
    .count()
    .reset_index()
    # Wide format: one row per taxonomic bin, one column per sample
    .pivot(index=tax_level, columns=aggregation_level, values='TPM')
    # A bin that is absent from a sample has zero expressed transcripts
    .fillna(0)
)

# Visualize the data
data_dinoflagellates.head()

In [None]:
# Per sample: number of taxonomic bins with more than 100 expressed transcripts
(data_dinoflagellates > 100).sum()

In [None]:
# Count, per sample, the dinoflagellate taxonomic bins that have more than
# 100 expressed transcripts, and show the distribution per month.
taxonomic_bin_abundance = data_dinoflagellates[data_dinoflagellates > 100].count(axis=0).reset_index()
taxonomic_bin_abundance.columns = ['sample', 'num_species']

## Add metadata (month of each sample)
taxonomic_bin_abundance = taxonomic_bin_abundance.merge(meta[['month']], left_on='sample', right_index=True, how='left')

# Box plot of the number of abundant dinoflagellate bins per month
fig = px.box(taxonomic_bin_abundance, x='month', y='num_species',
             category_orders={'month': ['July_2020', "August_2020", "September_2020", "November_2020",
               "December_2020", "January_2021", "February_2021", "April_2021",
               "May_2021", "June_2021", "July_2021"]})

fig.show()

# Save the figure.
# The original call wrote SVG data to a '.png' path and passed width=3.5,
# height=2, which plotly interprets as pixels. The extension now matches the
# requested SVG format and the size is converted to pixels.
# NOTE(review): 3.5 x 2 was presumably meant in inches — confirm the intended figure size.
pixels_per_inch = 96
fig.write_image("../../figures/diatoms_vs_dinoflagellates/num_dinoflagellate_species_per_month.svg",
                width=3.5 * pixels_per_inch, height=2 * pixels_per_inch, scale=1, format='svg')

In [None]:
# Per taxonomic bin: number of samples in which it has more than 100 expressed transcripts
exceeds_threshold = data_dinoflagellates > 100
taxonomic_bin_abundance = exceeds_threshold.sum(axis=1)
taxonomic_bin_abundance.head()

In [None]:
# Bar chart of the bins, ordered from most to least frequently abundant
ranked_bins = taxonomic_bin_abundance.sort_values(ascending=False)
ranked_bins.plot(kind='bar', figsize=(10, 5))

In [None]:
# Table of taxonomic bins with the number of samples in which each bin
# exceeds 100 expressed transcripts
taxonomic_bin_abundance = (
    (data_dinoflagellates > 100)
    .sum(axis=1)
    .rename_axis('taxonomic_bin')
    .reset_index(name='abundant_in_samples')
)
taxonomic_bin_abundance.head()

In [None]:
# Bins that exceed 100 expressed transcripts in at least one sample
abundant_taxa = taxonomic_bin_abundance.loc[
    taxonomic_bin_abundance['abundant_in_samples'] >= 1, 'taxonomic_bin'
].tolist()

# Rebuild the taxon x sample transcript-count table from the confident
# (>= 90 % identity) dinoflagellate hits and keep only the abundant bins
dinoflagellate_hits = data['Taxogroup2_UniEuk'].isin(['Dinophyceae', 'core-Noctilucales']) & (data['p_ident'] >= 0.9)
data_dinoflagellates = (
    data.loc[dinoflagellate_hits, [aggregation_level, tax_level, 'TPM']]
    # Remove transcripts below the TPM threshold
    .loc[lambda df: df['TPM'] > 1]
    # Number of expressed transcripts per sample and bin
    .groupby([aggregation_level, tax_level])
    .count()
    .reset_index()
    # Wide format, absent combinations count as 0
    .pivot(index=tax_level, columns=aggregation_level, values='TPM')
    .fillna(0)
    # Restrict to the abundant bins
    .loc[lambda df: df.index.isin(abundant_taxa)]
)

# Save to csv (input for the R diversity script)
data_dinoflagellates.to_csv(f'../../data/analysis/dinoflagellate_{tax_level}_{aggregation_level}.csv')

In [None]:
# Number of abundant bins (> 100 expressed transcripts) per sample
taxonomic_bin_abundance = (
    (data_dinoflagellates > 100)
    .sum(axis=0)
    .rename_axis('sample')
    .reset_index(name='num_species')
)

In [None]:
# Preview the per-sample counts
taxonomic_bin_abundance.head(n=5)

Now run the above through the R [diversity](diversity.r) script. Then, we'll visualize the results.

In [None]:
# Sample metadata, indexed by the first column of the file
meta = pd.read_csv('../../samples.csv', sep=';', index_col=0)

# Biodiversity estimates produced by the R script, enriched with the sample
# metadata and the number of abundant dinoflagellate bins per sample
dinoflagellate_biodiversity = (
    pd.read_csv('../../data/analysis/dinoflagellates_biodiversity_estimates.csv', index_col=0)
    .merge(meta, left_index=True, right_index=True)
    .merge(taxonomic_bin_abundance, left_index=True, right_on='sample')
)

In [None]:
# Box plot of the number of abundant dinoflagellate bins per month
month_order = [
    "July_2020", "August_2020", "September_2020", "November_2020",
    "December_2020", "January_2021", "February_2021", "April_2021",
    "May_2021", "June_2021", "July_2021",
]

# Figure size is given in cm and converted to inches for matplotlib
cm = 1 / 2.54
panel_size = (5 * cm, 5 * cm)
plt.figure(figsize=panel_size)

ax = sns.boxplot(
    data=dinoflagellate_biodiversity,
    x="month",
    y="num_species",
    order=month_order,
    color='#63C5DA',
)

# Rotate the month labels so they stay readable in the narrow panel
plt.xticks(rotation=90, ha='right')

# Grid lines on both axes
ax.grid(axis='both')

# Save figure as svg
plt.savefig('../../figures/diatoms_vs_dinoflagellates/dinoflagellate_genera_per_month.svg', dpi=600, bbox_inches='tight')

In [None]:
# Boxen plot of the community evenness per month
month_order = [
    "July_2020", "August_2020", "September_2020", "November_2020",
    "December_2020", "January_2021", "February_2021", "April_2021",
    "May_2021", "June_2021", "July_2021",
]

# Figure size is given in cm and converted to inches for matplotlib
cm = 1 / 2.54
panel_size = (5 * cm, 5 * cm)
plt.figure(figsize=panel_size)

ax = sns.boxenplot(
    data=dinoflagellate_biodiversity,
    x="month",
    y="evenness",
    order=month_order,
    scale="linear",
    color='#63C5DA',
)

# Rotate the month labels so they stay readable in the narrow panel
plt.xticks(rotation=90, ha='right')

# Evenness lies between 0 and 1; leave a small margin on both sides
plt.ylim(-0.05, 1.05)

# Grid lines on both axes
ax.grid(axis='both')

# Save figure as svg
plt.savefig('../../figures/diatoms_vs_dinoflagellates/dinoflagellate_evenness_per_month.svg', dpi=600, bbox_inches='tight')

### 3.3 Calculate Functional Richness

#### Data preparation

In [None]:
# Functional richness = number of distinct KEGG orthologs (KO) expressed per
# sample by the abundant dinoflagellate bins selected above.
functional_info = 'KEGG_ko'
tax_level = 'Name_to_Use'
# tax_level = 'genus'
aggregation_level = 'sample'

# Keep the transcripts of the abundant bins. A transcript can carry several
# comma-separated KO identifiers: split them and give each its own row.
data_dinoflagellates = (
    data.loc[data[tax_level].isin(abundant_taxa), [aggregation_level, functional_info, 'TPM']]
    .assign(**{functional_info: lambda df: df[functional_info].str.split(',')})
    .explode(functional_info)
)
# Drop the prefix in front of the ':' of each identifier; entries without a
# ':' end up as missing values
data_dinoflagellates[functional_info] = data_dinoflagellates[functional_info].str.split(":", expand=True).drop(columns=0)

# Keep annotated rows that are expressed above the TPM threshold
has_annotation = data_dinoflagellates[functional_info].notna() & (data_dinoflagellates[functional_info] != '-')
is_expressed = data_dinoflagellates['TPM'] > 1
data_dinoflagellates = data_dinoflagellates[has_annotation & is_expressed]

# KO x sample table of summed TPM; absent combinations count as 0
functional_diversity = (
    data_dinoflagellates
    .groupby([functional_info, aggregation_level])
    .sum()
    .reset_index()
    .pivot(index=functional_info, columns=aggregation_level, values='TPM')
    .fillna(0)
)

# Presence/absence: summed TPM below 1 becomes 0, everything else 1
functional_diversity = functional_diversity.mask(functional_diversity < 1, 0)
functional_diversity = functional_diversity.mask(functional_diversity >= 1, 1)

# Richness per sample = number of KOs present, indexed by sample
functional_richness = (
    functional_diversity
    .sum(axis=0)
    .rename_axis('sample')
    .to_frame(name=functional_info)
)

In [None]:
# Attach the per-sample functional richness to the biodiversity table
dinoflagellate_biodiversity = dinoflagellate_biodiversity.merge(
    functional_richness,
    left_on='sample',
    right_index=True,
)

In [None]:
# Preview the combined biodiversity table
dinoflagellate_biodiversity.head(n=5)

#### Abundance of functions

In [None]:
# Boxen plot of the functional richness (number of KOs) per month
# Figure size is given in cm and converted to inches for matplotlib
cm = 1 / 2.54
panel_size = (5 * cm, 5 * cm)
plt.figure(figsize=panel_size)

ax = sns.boxenplot(
    data=dinoflagellate_biodiversity,
    x="month",
    y="KEGG_ko",
    order=month_order,
    scale="linear",
    color='#63C5DA',
)

# Rotate the month labels so they stay readable in the narrow panel
plt.xticks(rotation=90, ha='right')

# Fixed y-range shared by the KO-richness figures of this section
plt.ylim(-50, 5500)

# Grid lines on both axes
ax.grid(axis='both')

# Save figure as svg
plt.savefig('../../figures/diatoms_vs_dinoflagellates/dinoflagellates_KO_abundance_per_month.svg', dpi=600, bbox_inches='tight')

In [None]:
# How is functional richness correlated to the amount of dinoflagellate expression?
## Sum the TPM assigned to dinoflagellates in each sample. Use the same two
## taxonomic groups as the rest of this section (the original filter left out
## 'core-Noctilucales', so those transcripts were missing from the total).
number_of_reads = data[data['Taxogroup2_UniEuk'].isin(['Dinophyceae', 'core-Noctilucales'])][[aggregation_level, 'TPM']]
number_of_reads = number_of_reads.groupby(aggregation_level).sum().reset_index()

# Add it to the dinoflagellate biodiversity data. Any 'TPM' column left by a
# previous run of this cell is dropped first, so that re-running does not
# produce 'TPM_x' / 'TPM_y' columns.
dinoflagellate_biodiversity = (
    dinoflagellate_biodiversity
    .drop(columns='TPM', errors='ignore')
    .merge(number_of_reads, on='sample')
)

In [None]:
# Scatter plot: functional richness (KO count) against summed dinoflagellate TPM
fig = px.scatter(dinoflagellate_biodiversity, x="TPM", y="KEGG_ko")

# Marker appearance
marker_style = dict(
    size=5,
    color='#63C5DA',
    line=dict(width=1, color='DarkSlateGrey'),
)
fig.update_traces(marker=marker_style, selector=dict(mode='markers'))

# Page layout: font, fixed size in cm and margins in pixels
fig.update_layout(
    font=dict(family="Times New Roman, serif", size=8, color="#000000"),
    autosize=False,
    width=5.5 * pixels_per_cm,
    height=5.5 * pixels_per_cm,
    margin=dict(l=0, r=25, b=25, t=25),
    plot_bgcolor='white',
    showlegend=False,
)

# Both axes share the same frame and grid styling
axis_style = dict(
    mirror=True,
    ticks='outside',
    showline=True,
    linecolor='black',
    linewidth=0.5,
    gridcolor='lightgrey',
)
fig.update_xaxes(**axis_style)
fig.update_yaxes(**axis_style)

fig.show()

#### Active species richness vs abundance of functional modules

In [None]:
# Compare a linear and a quadratic model of functional richness (KEGG_ko)
# as a function of the number of abundant species (num_species).
plot_data = dinoflagellate_biodiversity.copy()

# Create a linear model
linear_model = LinearRegression()
# Create a polynomial (degree 2) model
poly_model = make_pipeline(PolynomialFeatures(degree=2), LinearRegression())

# 5-fold cross-validated scores. No `scoring` argument is given, so
# cross_val_score uses the estimators' default score, which for these
# regressors is the coefficient of determination (R^2) — not MSE or accuracy.
linear_scores = cross_val_score(linear_model, plot_data[['num_species']], plot_data['KEGG_ko'], cv=5)
poly_scores = cross_val_score(poly_model, plot_data[['num_species']], plot_data['KEGG_ko'], cv=5)

# Compare the average scores of both models
print("Linear model: mean R^2 of %0.2f with a standard deviation of %0.2f" % (linear_scores.mean(), linear_scores.std()))
print("Polynomial model: mean R^2 of %0.2f with a standard deviation of %0.2f" % (poly_scores.mean(), poly_scores.std()))

In [None]:
# Fit the polynomial model on all samples and draw it on top of the scatter
# of functional richness against the number of abundant species.
poly_model.fit(plot_data[['num_species']], plot_data['KEGG_ko'])

# Evenly spaced x-values spanning the observed range of num_species
x_values = np.linspace(plot_data['num_species'].min(), plot_data['num_species'].max(), 100)

# Predict y-values using the polynomial fit. The model was fitted on a
# DataFrame, so the prediction input gets the same column name; passing a bare
# array makes scikit-learn warn about missing feature names.
model_predictions = poly_model.predict(pd.DataFrame({'num_species': x_values}))

# Set conversion factor to transform cm to pixels
pixels_per_cm = 37.79527559055118

fig = px.scatter(plot_data, x="num_species", y="KEGG_ko")

# Add the polynomial fit
fig.add_trace(go.Scatter(
    x=x_values,
    y=model_predictions,
    mode='lines',
    line=dict(color='black', width=1),
    name='Polynomial Fit'
))

# Style only the scatter markers (the fit line has mode='lines')
fig.update_traces(marker=dict(size=5,
                              color='#63C5DA',
                              line=dict(width=1,
                                        color='DarkSlateGrey')),
                  selector=dict(mode='markers'))

fig.update_layout(
    font=dict(
        family="Times New Roman, serif",  # Set the font family to Times New Roman
        size=8,  # Set the font size
        color="#000000"  # Set the font color
    ),
    autosize=False,
    width= 5.5 * pixels_per_cm,
    height= 5.5 * pixels_per_cm,
    margin=dict( # Set the margins
        l=0,  # Left margin
        r=25,  # Right margin
        b=25,  # Bottom margin
        t=25  # Top margin
    ),
    plot_bgcolor='white',
    showlegend=False
)

fig.update_xaxes(
    mirror=True,
    ticks='outside',
    showline=True,
    linecolor='black',
    linewidth=0.5,
    gridcolor='lightgrey'
)
fig.update_yaxes(
    mirror=True,
    ticks='outside',
    showline=True,
    linecolor='black',
    linewidth=0.5,
    gridcolor='lightgrey',
    range = [-50, 5500]
)

fig.show()

In [None]:
# Dinoflagellates have few active species, so the relation between the number
# of active species and functional richness is shown as a boxen plot with one
# box per species count.
# Figure size is given in cm and converted to inches for matplotlib
cm = 1 / 2.54
panel_size = (5 * cm, 5 * cm)
plt.figure(figsize=panel_size)

ax = sns.boxenplot(
    data=dinoflagellate_biodiversity,
    x="num_species",
    y="KEGG_ko",
    scale="linear",
    color='#63C5DA',
)

# Fixed y-range shared by the KO-richness figures of this section
plt.ylim(-50, 5500)

# Grid lines on both axes
ax.grid(axis='both')

# Save figure as svg
plt.savefig('../../figures/diatoms_vs_dinoflagellates/dinoflagellates_KO_abundance_num_species.svg', dpi=600, bbox_inches='tight')

In [None]:
# Same relation, coloured by month and with an OLS trendline per month
fig = px.scatter(
    plot_data,
    x="num_species",
    y="KEGG_ko",
    trendline="ols",
    color='month',
    color_discrete_sequence=px.colors.qualitative.Set1,
)

# Page layout: font, fixed size in cm and margins in pixels
fig.update_layout(
    font=dict(family="Times New Roman, serif", size=14, color="#000000"),
    autosize=False,
    width=14 * pixels_per_cm,
    height=10 * pixels_per_cm,
    margin=dict(l=0, r=25, b=25, t=25),
    plot_bgcolor='white',
)

# Frame and grid styling shared by both axes; the y-axis also gets a fixed range
axis_style = dict(
    mirror=True,
    ticks='outside',
    showline=True,
    linecolor='black',
    gridcolor='lightgrey',
)
fig.update_xaxes(**axis_style)
fig.update_yaxes(range=[-50, 5500], **axis_style)

fig.show()

# Save figure as svg
fig.write_image('../../figures/diatoms_vs_dinoflagellates/dinoflagellate_KO_abundance_vs_species_per_month.svg')

#### Trophic feeding mode PFAMs
From the above we can hypothesize that the dinoflagellate community is mixotrophic. We'll check whether this is reflected in the expression of feeding-mode-related PFAMs.

In [None]:
# Create a dinoflagellate dataframe with PFAM expression data per month.
functional_info = 'PFAMs'
aggregation_level = 'month'
# Select dinoflagellates with the same Taxogroup2_UniEuk groups that the rest
# of this section uses on `data`; the 'class' column is the PhyloDB-era filter
# that was replaced elsewhere in this section.
data_dinoflagellates_PFAM = data[data['Taxogroup2_UniEuk'].isin(['Dinophyceae', 'core-Noctilucales'])][[aggregation_level, functional_info, 'TPM']]

In [None]:
# Fraction of the dinoflagellate transcripts whose PFAM annotation is one of
# the feeding-mode related PFAMs
is_relevant_pfam = data_dinoflagellates_PFAM[functional_info].isin(relevant_pfams['short_name'])
is_relevant_pfam.value_counts(normalize=True)

In [None]:
# Clade distribution of the relevant PFAMs that are detected in the data
detected_in_data = relevant_pfams['short_name'].isin(data_dinoflagellates_PFAM[functional_info])
relevant_pfams.loc[detected_in_data, 'clade'].value_counts()

In [None]:
# Restrict the expression data to the feeding-mode related PFAMs
relevant_rows = data_dinoflagellates_PFAM[functional_info].isin(relevant_pfams['short_name'])
data_dinoflagellates_PFAM = data_dinoflagellates_PFAM.loc[relevant_rows]

In [None]:
# Ten PFAMs with the highest summed TPM
tpm_per_pfam = data_dinoflagellates_PFAM.groupby(functional_info)['TPM'].sum()
tpm_per_pfam.sort_values(ascending=False).head(10)

In [None]:
# Preview the filtered PFAM expression data
data_dinoflagellates_PFAM.head(n=5)

In [None]:
# Build the heatmap input for the feeding-mode related PFAMs whose relative
# expression varies most across months.
number_of_PFAMs = 15
# Sum TPM per month and PFAM
plot_data = data_dinoflagellates_PFAM.groupby([aggregation_level, functional_info])['TPM'].sum().reset_index()

# Number of stations visited per month
samples_per_month = {
    'July_2020': 6,
    'August_2020': 6,
    'September_2020': 6,
    'November_2020': 6,
    'December_2020': 6,
    'January_2021': 5,
    'February_2021': 5,
    'April_2021': 4,
    'May_2021': 6,
    'June_2021': 6,
    'July_2021': 6
}

# Divide the TPM sum by number of samples per month to obtain the average TPM per month
plot_data['TPM'] = plot_data['TPM'] / plot_data[aggregation_level].map(samples_per_month)

# Express each PFAM as a fraction of the month's total over the PFAMs in this
# table (the per-month constant of the previous step cancels out here)
plot_data['TPM'] = plot_data['TPM'] / plot_data.groupby(aggregation_level)['TPM'].transform('sum')

# Transform the data to the wide format (months x PFAMs)
plot_data_wide = plot_data.pivot(index=aggregation_level, columns=functional_info, values='TPM')
# log2 transform the data
plot_data_log = np.log2(plot_data_wide + 1)
# Using non-log2 transformed data gives more weight to more expressed genes; these have higher variances
# Using log2 transformed data gives more weight to less expressed genes; and counterselects for highly expressed genes

# Transpose the data to PFAMs x months
print(plot_data_log.shape)
plot_data_log = plot_data_log.transpose()

# Re-order the columns according to month
month_order = ["July_2020", "August_2020", "September_2020", "November_2020",
                "December_2020", "January_2021", "February_2021", "April_2021",
                "May_2021", "June_2021", "July_2021"]

plot_data_log = plot_data_log.reindex(month_order, axis=1)

# Rank the PFAMs by variance BEFORE scaling. After the row-wise z-scoring below
# every row has unit variance, so ranking the scaled rows (as was done
# originally) cannot distinguish variable from stable PFAMs.
top_variable_pfams = plot_data_log.var(axis=1).sort_values(ascending=False).head(number_of_PFAMs).index

# Scale the data: z-score each PFAM across the months
plot_data_scaled = plot_data_log.apply(lambda x: (x - x.mean()) / x.std(), axis=1)

# Only retain the most variable PFAMs, ordered from most to least variable
plot_data_scaled = plot_data_scaled.loc[top_variable_pfams]

# In the index, include the PFAM clade and functional descriptions to the short names between brackets
plot_data_scaled.index = plot_data_scaled.index.map(lambda x: "{} ({} - {})".format(x, relevant_pfams[relevant_pfams['short_name'] == x]['clade'].values[0], relevant_pfams[relevant_pfams['short_name'] == x]['function'].values[0]))

# Labelled names of the retained PFAMs (same content as the original variable)
most_variable_PFAMs = plot_data_scaled.index

In [None]:
# Preview the scaled heatmap input
plot_data_scaled.head(n=5)

In [None]:
# Clustered heatmap of the most variable feeding-mode related PFAMs
# (rows are clustered, months stay in chronological order)
g = sns.clustermap(
    plot_data_scaled,
    figsize=(8, 6),
    cmap='mako',
    col_cluster=False,
    # Legend bar should be on the right
    cbar_kws={'orientation': 'vertical'},
    )

g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), fontsize=8)
g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_yticklabels(), fontsize=8)
g.ax_heatmap.set_facecolor('#f7f7f7')
g.ax_heatmap.set_xlabel(None)
g.ax_heatmap.set_ylabel(None)

plt.show()

# Save figure as svg. The file name uses number_of_PFAMs, the count that was
# actually applied when selecting the rows; the original referenced
# `number_of_genes`, which this section never defines.
g.savefig("../../figures/diatoms_vs_dinoflagellates/dinoflagellates_heatmap_{}_most_variable_trophic_{}_per_{}.svg".format(number_of_PFAMs, functional_info, aggregation_level), format='svg', dpi=600)

Now we'll do the same but with the 15 most expressed PFAMs.

In [None]:

# Heatmap of the feeding-mode related PFAMs with the highest relative
# expression summed over all months.
number_of_PFAMs = 15
# Sum TPM per month and PFAM
plot_data = data_dinoflagellates_PFAM.groupby([aggregation_level, functional_info])['TPM'].sum().reset_index()

# Number of stations visited per month
samples_per_month = {
    'July_2020': 6,
    'August_2020': 6,
    'September_2020': 6,
    'November_2020': 6,
    'December_2020': 6,
    'January_2021': 5,
    'February_2021': 5,
    'April_2021': 4,
    'May_2021': 6,
    'June_2021': 6,
    'July_2021': 6
}

# Divide the TPM sum by number of samples per month to obtain the average TPM per month
plot_data['TPM'] = plot_data['TPM'] / plot_data[aggregation_level].map(samples_per_month)

# Express each PFAM as a fraction of the month's total over the PFAMs in this
# table (the per-month constant of the previous step cancels out here)
plot_data['TPM'] = plot_data['TPM'] / plot_data.groupby(aggregation_level)['TPM'].transform('sum')

# Transform the data to the wide format (months x PFAMs)
plot_data_wide = plot_data.pivot(index=aggregation_level, columns=functional_info, values='TPM')
# log2 transform the data
plot_data_log = np.log2(plot_data_wide + 1)

# Using non-log2 transformed data gives more weight to more expressed genes; these have higher variances
# Using log2 transformed data gives more weight to less expressed genes; and counterselects for highly expressed genes

# Transpose the data to PFAMs x months
print(plot_data_log.shape)
plot_data_log = plot_data_log.transpose()

# Re-order the columns according to month
month_order = ["July_2020", "August_2020", "September_2020", "November_2020",
                "December_2020", "January_2021", "February_2021", "April_2021",
                "May_2021", "June_2021", "July_2021"]

plot_data_log = plot_data_log.reindex(month_order, axis=1)

# Rank the PFAMs by expression BEFORE scaling. The row-wise z-scores computed
# below sum to (numerically) zero for every PFAM, so summing the scaled rows
# (as was done originally) does not rank by expression level.
top_expressed_pfams = plot_data_log.sum(axis=1).sort_values(ascending=False).head(number_of_PFAMs).index

# Scale the data: z-score each PFAM across the months
plot_data_scaled = plot_data_log.apply(lambda x: (x - x.mean()) / x.std(), axis=1)

# Only retain the most expressed PFAMs, ordered from most to least expressed
plot_data_scaled = plot_data_scaled.loc[top_expressed_pfams]

# In the index, include the PFAM clade and functional descriptions to the short names between brackets
plot_data_scaled.index = plot_data_scaled.index.map(lambda x: "{} ({} - {})".format(x, relevant_pfams[relevant_pfams['short_name'] == x]['clade'].values[0], relevant_pfams[relevant_pfams['short_name'] == x]['function'].values[0]))

# Labelled names of the retained PFAMs (variable name kept from the original cell)
most_variable_PFAMs = plot_data_scaled.index

# Clustered heatmap (rows are clustered, months stay in chronological order)
g = sns.clustermap(
    plot_data_scaled,
    figsize=(8, 6),
    cmap='mako',
    col_cluster=False,
    # Legend bar should be on the right
    cbar_kws={'orientation': 'vertical'},
    )

g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), fontsize=8)
g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_yticklabels(), fontsize=8)
g.ax_heatmap.set_facecolor('#f7f7f7')
g.ax_heatmap.set_xlabel(None)
g.ax_heatmap.set_ylabel(None)

plt.show()

# Save figure as svg. The file name uses number_of_PFAMs, the count that was
# actually applied above; the original referenced `number_of_genes`, which
# this section never defines.
g.savefig("../../figures/diatoms_vs_dinoflagellates/dinoflagellates_heatmap_{}_most_expressed_trophic_{}_per_{}.svg".format(number_of_PFAMs, functional_info, aggregation_level), format='svg', dpi=600)

Now we'll do the same with the full feature set (all 1046 gene families used in the model described in the paper).

In [None]:
# Load the supplementary feature table of the trophic-mode model and drop
# every row that contains a missing value
PFAM_features = pd.read_excel(
    '../../data/external/pnas.2100916119.sd03.xlsx',
    skiprows=1,
    header=1,
).dropna()

PFAM_features.head()

In [None]:
# Of the PFAM features used by the trophic-mode model, which are most
# expressed by dinoflagellates?
functional_info = 'PFAMs'
aggregation_level = 'month'
# Select dinoflagellates with the same Taxogroup2_UniEuk groups that the rest
# of this section uses on `data`; the 'class' column is the PhyloDB-era filter
# that was replaced elsewhere in this section.
data_dinoflagellates_PFAM = data[data['Taxogroup2_UniEuk'].isin(['Dinophyceae', 'core-Noctilucales'])][[aggregation_level, functional_info, 'TPM']]
## Subset the data to only include the PFAMs listed as model features
dinoflagellate_features = data_dinoflagellates_PFAM[data_dinoflagellates_PFAM[functional_info].isin(PFAM_features['Name'])]

# Group by month, PFAM and sum the TPMs
dinoflagellate_features = dinoflagellate_features.groupby([aggregation_level, functional_info])['TPM'].sum().reset_index()

# Show the top 10 most expressed PFAMs
dinoflagellate_features.groupby(functional_info)['TPM'].sum().sort_values(ascending=False).head(10)

In [None]:
# Build the heatmap input for the model-feature PFAMs whose relative
# expression varies most across months.
number_of_PFAMs = 15
# Sum TPM per month and PFAM
plot_data = dinoflagellate_features.groupby([aggregation_level, functional_info])['TPM'].sum().reset_index()

# Divide the TPM sum by number of samples per month to obtain the average TPM per month
plot_data['TPM'] = plot_data['TPM'] / plot_data[aggregation_level].map(samples_per_month)

# Express each PFAM as a fraction of the month's total over the PFAMs in this
# table (the per-month constant of the previous step cancels out here)
plot_data['TPM'] = plot_data['TPM'] / plot_data.groupby(aggregation_level)['TPM'].transform('sum')

# Transform the data to the wide format (months x PFAMs)
plot_data_wide = plot_data.pivot(index=aggregation_level, columns=functional_info, values='TPM')
# log2 transform the data
plot_data_log = np.log2(plot_data_wide + 1)

# Using non-log2 transformed data gives more weight to more expressed genes; these have higher variances
# Using log2 transformed data gives more weight to less expressed genes; and counterselects for highly expressed genes

# Transpose the data to PFAMs x months and put the months in chronological order
print(plot_data_log.shape)
plot_data_log = plot_data_log.transpose()

plot_data_log = plot_data_log.reindex(month_order, axis=1)

# Rank the PFAMs by variance BEFORE scaling. The original ranked by the median
# of the row-wise z-scores, which measures neither variability nor expression
# level; the z-scored rows themselves all have unit variance.
most_variable_PFAMs = plot_data_log.var(axis=1).sort_values(ascending=False).head(number_of_PFAMs).index

# Scale the data: z-score each PFAM across the months
plot_data_scaled = plot_data_log.apply(lambda x: (x - x.mean()) / x.std(), axis=1)

# Only retain the most variable PFAMs, ordered from most to least variable
plot_data_scaled = plot_data_scaled.loc[most_variable_PFAMs]

In [None]:
# Clustered heatmap of the selected feature PFAMs, drawn with seaborn's
# default theme (rows are clustered, months stay in chronological order)
sns.set()

g = sns.clustermap(
    plot_data_scaled,
    cmap='mako',
    figsize=(8, 6),
    col_cluster=False,
)

# Smaller tick labels, light background and no axis titles on the heatmap panel
heatmap_ax = g.ax_heatmap
heatmap_ax.set_xticklabels(heatmap_ax.get_xticklabels(), fontsize=8)
heatmap_ax.set_yticklabels(heatmap_ax.get_yticklabels(), fontsize=8)
heatmap_ax.set_facecolor('#f7f7f7')
heatmap_ax.set_xlabel(None)
heatmap_ax.set_ylabel(None)

plt.show()

In [None]:
# Now we'll do the same but with the 15 most highly expressed PFAMs
number_of_PFAMs = 15
# Sum TPM per month and PFAM
plot_data = dinoflagellate_features.groupby([aggregation_level, functional_info])['TPM'].sum().reset_index()

# Divide the TPM sum by number of samples per month to obtain the average TPM per month
plot_data['TPM'] = plot_data['TPM'] / plot_data[aggregation_level].map(samples_per_month)

# Express each PFAM as a fraction of the month's total over the PFAMs in this
# table (the per-month constant of the previous step cancels out here)
plot_data['TPM'] = plot_data['TPM'] / plot_data.groupby(aggregation_level)['TPM'].transform('sum')

# Transform the data to the wide format (months x PFAMs)
plot_data_wide = plot_data.pivot(index=aggregation_level, columns=functional_info, values='TPM')
# log2 transform the data
plot_data_log = np.log2(plot_data_wide + 1)

# Using non-log2 transformed data gives more weight to more expressed genes; these have higher variances
# Using log2 transformed data gives more weight to less expressed genes; and counterselects for highly expressed genes

# Transpose the data to PFAMs x months and put the months in chronological order
print(plot_data_log.shape)
plot_data_log = plot_data_log.transpose()

plot_data_log = plot_data_log.reindex(month_order, axis=1)

# Rank the PFAMs by expression BEFORE scaling. The row-wise z-scores computed
# below sum to (numerically) zero for every PFAM, so summing the scaled rows
# (as was done originally) does not rank by expression level.
most_variable_PFAMs = plot_data_log.sum(axis=1).sort_values(ascending=False).head(number_of_PFAMs).index

# Scale the data: z-score each PFAM across the months
plot_data_scaled = plot_data_log.apply(lambda x: (x - x.mean()) / x.std(), axis=1)

# Only retain the most expressed PFAMs, ordered from most to least expressed
plot_data_scaled = plot_data_scaled.loc[most_variable_PFAMs]

# Clustered heatmap (rows are clustered, months stay in chronological order)
g = sns.clustermap(
    plot_data_scaled,
    figsize=(8, 6),
    cmap='mako',
    col_cluster=False,
    # Legend bar should be on the right
    cbar_kws={'orientation': 'vertical'},
    )

g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), fontsize=8)
g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_yticklabels(), fontsize=8)
g.ax_heatmap.set_facecolor('#f7f7f7')
g.ax_heatmap.set_xlabel(None)
g.ax_heatmap.set_ylabel(None)

plt.show()

# Save figure as svg (disabled)
# g.savefig("../../figures/diatoms_vs_dinoflagellates/dinoflagellates_heatmap_{}_most_expressed_{}_per_{}.svg".format(number_of_PFAMs, functional_info, aggregation_level), format='svg', dpi=600)

#### Trophic feeding mode of dinoflagellates
I've predicted the trophic feeding mode of several species using the machine learning model of the paper by [Lambert & Groussman, 2022](https://doi.org/10.1073/pnas.2100916119), see [here](./trophic_mode_prediction.ipynb). We'll use the data to see how the trophic feeding mode of dinoflagellates changes over time, same as for the diatoms.

In [None]:
# Load the EukProt-based trophic mode predictions and keep only the
# dinoflagellates. (The PhyloDB-based alternative is
# '../../data/analysis/phylodb_trophic_mode_predictions.csv', filtered on
# class == 'Dinophyceae'.)
dinoflagellate_groups = ['Dinophyceae', 'core-Noctilucales']
trophic_predictions = (
    pd.read_csv('../../data/analysis/eukprot_trophic_mode_predictions.csv')
    .loc[lambda df: df['Taxogroup2_UniEuk'].isin(dinoflagellate_groups)]
)

# Number of dinoflagellate predictions, followed by a preview
print(len(trophic_predictions))
trophic_predictions.head()

In [None]:
# Plot the relative abundance of the three trophic modes per month
# Number of predictions per (month, trophic mode) combination
prediction_counts = trophic_predictions.groupby(['month', 'prediction']).size()

# Express every count as a percentage of all predictions of that month.
# `transform('sum')` leaves the (month, prediction) index untouched. The former
# `groupby(level=0).apply(...)` prepends the group key as an extra 'month' index level on
# pandas >= 2.0, after which `reset_index` fails because 'month' would be inserted twice.
rel_abundance = (
    (100 * prediction_counts / prediction_counts.groupby(level='month').transform('sum'))
    .reset_index(name='relative abundance')
)

# Set the color palette
color_map = {
    'Phot': 'green',
    'Mix': 'black',
    'Het': 'red'
}
fig = px.histogram(rel_abundance.sort_values("month", ascending=False),
             x = "relative abundance",
             y = "month",
             color = "prediction",
             orientation='h',
             color_discrete_map=color_map,
             category_orders={"month": ["July_2020", "August_2020", "September_2020",
                "November_2020", "December_2020", "January_2021",
                "February_2021", "April_2021", "May_2021",
                "June_2021", "July_2021"]})

# Conversion factor used below to express the figure width and height in centimetres
pixels_per_cm = 37.7952755906

fig.update_layout(
    font=dict(
        family="Times New Roman, serif",  # Set the font family to Times New Roman
        size=8  # Set the font size
    ),
    autosize=False,
    width= 8.5 * pixels_per_cm,
    height= 7.5 * pixels_per_cm,
    margin=dict( # Set the margins
        l=0,  # Left margin
        r=25,  # Right margin
        b=25,  # Bottom margin
        t=25  # Top margin
    ),
    xaxis_title_text='Ratio of predicted trophic modes (%)',
    yaxis_title_text=None,
)

fig.show()

# save figure as svg
#fig.write_image("../../figures/diatoms_vs_dinoflagellates/dinoflagellate_trophic_mode_relative_abundance.svg")
fig.write_image("../../figures/diatoms_vs_dinoflagellates/dinoflagellate_trophic_mode_relative_abundance_eukprot.svg")

In [None]:
# Extract the relevant columns
# (PhyloDB-era equivalent of the rank column: 'genus')
tax_level = 'Genus_UniEuk'
# .copy() gives this cell its own frame, so the column assignment below does not write
# into a slice of trophic_predictions (SettingWithCopyWarning)
df = trophic_predictions[[tax_level, 'month', 'prediction']].copy()
# Convert 'prediction' into numerical categories
df['prediction'] = df['prediction'].map({'Phot':1, 'Mix':2, 'Het':3})


def consensus_prediction(codes):
    """Return the most frequent non-zero prediction code of one genus/month group.

    Ties resolve to the smallest code (``Series.mode`` returns its modes sorted).
    Returns 0 ("no prediction") when the group holds no usable code, e.g. when
    every label was unmapped and therefore NaN.
    """
    observed = codes[codes != 0].dropna()
    if observed.empty:
        return 0
    # Series.mode replaces `stats.mode(x)[0][0]`, which fails on SciPy >= 1.11 where the
    # mode is returned as a scalar instead of an array
    return int(observed.mode().iloc[0])


# Pivot the DataFrame to use for heatmap: one row per genus, one column per month
df_pivot = df.pivot_table(index=tax_level, columns='month', values='prediction',
                          aggfunc=consensus_prediction)

# Set the order of the months
df_pivot = df_pivot.reindex(month_order, axis=1)

# Set the order of the genera to be the same as in the above relative abundance plot
df_pivot = df_pivot.reindex(dinoflagellate_legend_order, axis=0)

# Fill NA with a specific category (0 for 'no prediction') and ensure the values are integer.
# This is done AFTER both reindex calls, because reindexing adds NaN cells for every month
# or genus that has no prediction at all.
df_pivot = df_pivot.fillna(0).astype(int)

# Create a color map: 0 = no prediction, 1 = Phot, 2 = Mix, 3 = Het
cmap = mcolors.ListedColormap(['lightgrey', 'green', 'yellow', 'red'])

# Font size should be 8, and font family Times New Roman.
# Set before the figure is created: matplotlib resolves font settings when the text
# objects are built, so updating rcParams after plotting would not affect this figure.
plt.rcParams.update({'font.size': 8, 'font.family': 'Times New Roman'})

# Create the heatmap with the color map
plt.figure(figsize=(4.5, 4))
sns.heatmap(df_pivot, cmap=cmap, annot=False, cbar=False, vmin=0, vmax=3)

# Change layout to match the relative abundance plot
plt.tight_layout()

# Remove the x and  y-axis label
plt.xlabel(None)
plt.ylabel(None)

# Keep the y-axis tick labels horizontal
plt.yticks(rotation=0)

# Save figure as svg
#plt.savefig("../../figures/diatoms_vs_dinoflagellates/dinoflagellate_trophic_mode_genus_consensus_heatmap.svg", format='svg', dpi=600)
plt.savefig("../../figures/diatoms_vs_dinoflagellates/dinoflagellate_trophic_mode_genus_consensus_heatmap_eukprot.svg", format='svg', dpi=600)


plt.show()


In [None]:
# Inspect the numerically coded predictions of the genus Noctiluca
df.loc[df[tax_level].eq('Noctiluca')]

#### PCA

In [None]:
# Specify how the matrix will look like (in the end its aggregation_level x functional_info)
functional_info = 'KEGG_ko'
aggregation_level = 'sample'

# Prepare gene expression data and load environmental variables
#data_dinoflagellates = data[data['class'] == 'Dinophyceae'][[aggregation_level, functional_info, 'TPM']]
data_dinoflagellates = data[data['Taxogroup2_UniEuk'].isin(['Dinophyceae', 'core-Noctilucales'])][[aggregation_level, functional_info, 'TPM']]

# if processing kegg data, extra preprocessing is required:
# Comment or uncomment the following line if multiple values are assigned to a transcript!
## split KEGG identifiers up
data_dinoflagellates = data_dinoflagellates.assign(**{functional_info:data_dinoflagellates[functional_info].str.split(',')})
# Now we can explode the functional column (one row per identifier; the TPM of a transcript is repeated for each)
data_dinoflagellates = data_dinoflagellates.explode(functional_info)
# Cut off the part of the identifier in front of the ":", if necessary!
data_dinoflagellates[functional_info] = data_dinoflagellates[functional_info].str.split(":", expand=True).drop(columns=0)

# Group by functional information and sample values, sum TPM
data_dinoflagellates = data_dinoflagellates.groupby([functional_info, aggregation_level]).sum().reset_index()

# Remove the rows with no functional annotation
data_dinoflagellates = data_dinoflagellates[data_dinoflagellates[functional_info].notna()]
data_dinoflagellates = data_dinoflagellates[data_dinoflagellates[functional_info] != '-']

# Transform the data to the wide format
data_dinoflagellates = data_dinoflagellates.pivot(index=aggregation_level, columns=functional_info, values='TPM')

# A sample/identifier combination that is absent from the long table has no summed TPM;
# the pivot leaves it as NaN. Set it to 0, because PCA does not accept missing values.
data_dinoflagellates = data_dinoflagellates.fillna(0)

# Log2 transform the data
data_dinoflagellates_log = np.log2(data_dinoflagellates + 1)

# Scale the features (columns, TPM values of every prediction).
# The result is stored as `data_dinoflagellates_scaled`, the name the merge below reads;
# it used to be written to `data_dinoflagellates_log`, which left the merge with an
# undefined (or stale) variable.
data_dinoflagellates_scaled = pd.DataFrame(StandardScaler().fit_transform(data_dinoflagellates_log),
                                   index=data_dinoflagellates.index,
                                   columns=data_dinoflagellates.columns)
## Scaling removes the mean and scales to unit variance, the resulting values are z-scores

# Load the environmental variables
env_variables = pd.read_csv("../../data/environmental/samples_env.csv", sep=";")
env_variables = env_variables.set_index('sample')

# Merge environmental variables with gene expression data
pca_data = data_dinoflagellates_scaled.merge(env_variables, left_index=True, right_on='sample', how='left')

# Perform PCA on gene expression data, ignore columns that are also in the environmental variables
pca = PCA(n_components=2)
principalComponents = pca.fit_transform(pca_data.drop(columns=env_variables.columns))

# Add principal components to the data
pca_data['PC1'] = principalComponents[:, 0]
pca_data['PC2'] = principalComponents[:, 1]

# Create a color dictionary for the months
month_color_dict = dict(zip(data['month'].unique(), sns.color_palette('tab10', n_colors=len(data['month'].unique()))))

# Set figure size and font scale
cm = 1/2.54
plt.figure(figsize=(12*cm, 14*cm))
sns.set(style='white', font_scale=1)

# Plot PCA biplot with the color of the month corrresponding to the month of the sample
# and the shape of the point corresponding to the station of the sample
sns.scatterplot(data=pca_data,
                x='PC1',
                y='PC2',
                hue='month',
                hue_order=month_order,
                style='station',
                palette=month_color_dict,
                s=40, edgecolor='black', linewidth=0.5)

plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0)
# Plot labels with the variance explained by each principal component
plt.xlabel('Principal Component 1 ({}%)'.format(round(pca.explained_variance_ratio_[0] * 100, 2)))
plt.ylabel('Principal Component 2 ({}%)'.format(round(pca.explained_variance_ratio_[1] * 100, 2)))

# Plot arrows indicating the correlation between the principal components and the environmental variables
columns_of_interest = ['NO3', 'PO4', 'Si', 'SPM', 'salinity', 'Temperature']

# Calculate the correlation matrix between the principal components and the environmental variables
# (rows with a missing value in any of these columns are left out)
corr_matrix = np.corrcoef(pca_data[['PC1', 'PC2'] + columns_of_interest].dropna().T)

# Get the correlation between the first two principal components and each environmental parameter
corr_PC1_env = corr_matrix[0, 2:]
corr_PC2_env = corr_matrix[1, 2:]

# Plot the arrows representing the correlations
for i, env_param in enumerate(columns_of_interest):
    plt.arrow(0, 0, corr_PC1_env[i]*max(pca_data['PC1']), corr_PC2_env[i]*max(pca_data['PC2']), head_width=0.05, color='gray')
    # corr_PC1_env[i] specifies the end coordinate of the arrow in the x direction
    # This coordinate is multiplied with max(pca_data['PC1']) to scale the arrow to the length of the principal component for visualization
    plt.text(corr_PC1_env[i]*max(pca_data['PC1'])*1.05, corr_PC2_env[i]*max(pca_data['PC2'])*1.05, env_param, fontsize=12, color='gray')
    # Same as above, but the text is placed 5% further away from the arrow

# Save figure as svg
#plt.savefig("../../figures/diatoms_vs_dinoflagellates/dinoflagellate_PCA_{}_per_{}.svg".format(functional_info, aggregation_level), format='svg', dpi=600)
plt.savefig("../../figures/diatoms_vs_dinoflagellates/dinoflagellate_PCA_{}_per_{}_eukprot.svg".format(functional_info, aggregation_level), format='svg', dpi=600)

plt.show()

### 4. Prymnesiophytes

In [None]:
# First, list the unique prymnesiophyte taxa ('Name_to_Use') among the transcripts
# annotated with a p_ident of at least 0.98
# (PhyloDB-era equivalent of the filter column: 'class' == 'Prymnesiophyceae')
is_prymnesiophyte = data['Taxogroup2_UniEuk'] == 'Prymnesiophyceae'
data.loc[is_prymnesiophyte & (data['p_ident'] >= 0.98), 'Name_to_Use'].unique()

#### 4.1 Prymnesiophyte abundance

In [None]:
# Taxonomic rank to summarise (PhyloDB-era equivalent: 'species') and the grouping variable
tax_level = 'Genus_UniEuk'
aggregation_level = 'month'

# Order of the month categories
sampling_months = ["July_2020", "August_2020", "September_2020",
                   "November_2020", "December_2020", "January_2021",
                   "February_2021", "April_2021", "May_2021",
                   "June_2021", "July_2021"]

# Prymnesiophyceae transcripts with p_ident >= 0.9 and a TPM above 1,
# summed per month and taxon
is_prymnesiophyte = data['Taxogroup2_UniEuk'] == 'Prymnesiophyceae'
has_sufficient_identity = data['p_ident'] >= 0.9
data_prymnesiophytes = (
    data.loc[is_prymnesiophyte & has_sufficient_identity, [aggregation_level, tax_level, 'TPM']]
    .loc[lambda transcripts: transcripts['TPM'] > 1]
    .groupby([aggregation_level, tax_level])
    .sum()
    .reset_index()
)

# Store the months as an ordered set of categories
data_prymnesiophytes['month'] = pd.Categorical(data_prymnesiophytes['month'], categories=sampling_months)

# Share of every taxon in the summed TPM of its month (a fraction between 0 and 1)
monthly_total_tpm = data_prymnesiophytes.groupby('month')['TPM'].transform('sum')
data_prymnesiophytes["rel_expression_per_month"] = data_prymnesiophytes['TPM'] / monthly_total_tpm

# Relabel taxa that make up at most 2% of a month as 'Rare'
rare_groups = data_prymnesiophytes['rel_expression_per_month'].le(0.02)
data_prymnesiophytes.loc[rare_groups, tax_level] = 'Rare'

# Print the taxa that remain after relabelling
print(data_prymnesiophytes[tax_level].unique())
# Inspect data
data_prymnesiophytes.head()

In [None]:
# Conversion factor used below to express the figure width and height in centimetres
pixels_per_cm = 37.79527559055118

# Order of the month categories on the y-axis
month_categories = ["July_2020", "August_2020", "September_2020",
                    "November_2020", "December_2020", "January_2021",
                    "February_2021", "April_2021", "May_2021",
                    "June_2021", "July_2021"]

# Horizontal histogram of the relative expression per month, coloured by taxon
fig = px.histogram(
    data_prymnesiophytes.sort_values("month", ascending=False),
    x="rel_expression_per_month",
    y="month",
    color=tax_level,
    orientation='h',
    color_discrete_map={"Rare": "#545454", "Phaeocystis": "#F26B38"},
    # Always use the same order for the taxa and the months
    category_orders={tax_level: ["Rare", "Phaeocystis"], 'month': month_categories},
)

# Font, size (8.5 cm x 7.5 cm), margins and axis titles
fig.update_layout(
    font={"family": "Times New Roman, serif", "size": 8, "color": "#000000"},
    autosize=False,
    width=8.5 * pixels_per_cm,
    height=7.5 * pixels_per_cm,
    margin={"l": 0, "r": 25, "b": 25, "t": 25},
    xaxis_title_text='% TPM of total sum',
    yaxis_title_text=None,
)

fig.show()

# Save the figure as png and svg
fig.write_image("../../figures/diatoms_vs_dinoflagellates/prymnesiophytes_relative_expression_per_month_{}_eukprot.png".format(tax_level), scale=1)
fig.write_image("../../figures/diatoms_vs_dinoflagellates/prymnesiophytes_relative_expression_per_month_{}_eukprot.svg".format(tax_level), scale=1)

In [None]:
# Collect the trace names of the figure, i.e. the order of the taxa in its legend
prymnesiophyte_legend_order = [trace.name for trace in fig.data]
print(prymnesiophyte_legend_order)

Cool! Now let's only look at the Phaeocystis species.

In [None]:
# Taxonomic rank to summarise (PhyloDB-era equivalent: 'species') and the grouping variable
tax_level = 'Name_to_Use'
aggregation_level = 'month'

# Order of the month categories
sampling_months = ["July_2020", "August_2020", "September_2020",
                   "November_2020", "December_2020", "January_2021",
                   "February_2021", "April_2021", "May_2021",
                   "June_2021", "July_2021"]

# Phaeocystis transcripts (class Prymnesiophyceae) with p_ident >= 0.98 and a TPM above 1,
# summed per month and taxon
is_phaeocystis = (data['Taxogroup2_UniEuk'] == 'Prymnesiophyceae') & (data['Genus_UniEuk'] == 'Phaeocystis')
has_sufficient_identity = data['p_ident'] >= 0.98
data_prymnesiophytes = (
    data.loc[is_phaeocystis & has_sufficient_identity, [aggregation_level, tax_level, 'TPM']]
    .loc[lambda transcripts: transcripts['TPM'] > 1]
    .groupby([aggregation_level, tax_level])
    .sum()
    .reset_index()
)

# Store the months as an ordered set of categories
data_prymnesiophytes['month'] = pd.Categorical(data_prymnesiophytes['month'], categories=sampling_months)

# Share of every taxon in the summed TPM of its month (a fraction between 0 and 1)
monthly_total_tpm = data_prymnesiophytes.groupby('month')['TPM'].transform('sum')
data_prymnesiophytes["rel_expression_per_month"] = data_prymnesiophytes['TPM'] / monthly_total_tpm

# Relabel taxa that make up at most 2% of a month as 'Rare'
rare_groups = data_prymnesiophytes['rel_expression_per_month'].le(0.02)
data_prymnesiophytes.loc[rare_groups, tax_level] = 'Rare'

# Print the taxa that remain after relabelling
print(data_prymnesiophytes[tax_level].unique())
# Inspect data
data_prymnesiophytes.head()

In [None]:
# Conversion factor used below to express the figure width and height in centimetres
pixels_per_cm = 37.79527559055118

# Order of the month categories on the y-axis
month_categories = ["July_2020", "August_2020", "September_2020",
                    "November_2020", "December_2020", "January_2021",
                    "February_2021", "April_2021", "May_2021",
                    "June_2021", "July_2021"]

# Horizontal histogram of the relative expression per month, coloured by taxon
fig = px.histogram(
    data_prymnesiophytes.sort_values("month", ascending=False),
    x="rel_expression_per_month",
    y="month",
    color=tax_level,
    orientation='h',
    color_discrete_map={"Rare": "#545454", "Phaeocystis": "#F26B38"},
    # Always use the same order for the taxa and the months
    category_orders={tax_level: ["Rare", "Phaeocystis"], 'month': month_categories},
)

# Font, size (8.5 cm x 7.5 cm), margins and axis titles
fig.update_layout(
    font={"family": "Times New Roman, serif", "size": 8, "color": "#000000"},
    autosize=False,
    width=8.5 * pixels_per_cm,
    height=7.5 * pixels_per_cm,
    margin={"l": 0, "r": 25, "b": 25, "t": 25},
    xaxis_title_text='% TPM of total sum',
    yaxis_title_text=None,
)

fig.show()

# Save the figure as png and svg
fig.write_image("../../figures/diatoms_vs_dinoflagellates/prymnesiophytes_phaeocystis_relative_expression_per_month_{}_eukprot.png".format(tax_level), scale=1)
fig.write_image("../../figures/diatoms_vs_dinoflagellates/prymnesiophytes_phaeocystis_relative_expression_per_month_{}_eukprot.svg".format(tax_level), scale=1)

#### Spatial distribution

In [None]:
# Taxonomic rank (PhyloDB-era equivalent: 'genus') and the two grouping variables
tax_level = 'Genus_UniEuk'
aggregation_level = 'month'
aggregation_level2 = 'station'

#data_prymnesiophytes = data[(data['class'] == 'Prymnesiophyceae') & (data['p_ident'] >= 0.9)][[aggregation_level, aggregation_level2, tax_level, 'TPM']]
data_prymnesiophytes = data[(data['Taxogroup2_UniEuk'] == 'Prymnesiophyceae') & (data['p_ident'] >= 0.9)][[aggregation_level, aggregation_level2, tax_level, 'TPM']]

# Remove transcripts below a certain TPM threshold
data_prymnesiophytes = data_prymnesiophytes[data_prymnesiophytes['TPM'] > 1]

# Group by month, station and taxon, sum TPM
data_prymnesiophytes = data_prymnesiophytes.groupby([aggregation_level, aggregation_level2, tax_level]).sum().reset_index()

data_prymnesiophytes['month'] = pd.Categorical(data_prymnesiophytes['month'], ["July_2020", "August_2020", "September_2020",
                                                        "November_2020", "December_2020", "January_2021",
                                                        "February_2021", "April_2021", "May_2021",
                                                        "June_2021", "July_2021"])

# Plot per station
for station in data_prymnesiophytes[aggregation_level2].unique():
    # .copy(): the columns added and relabelled below must end up in a frame of their own,
    # not in a slice of data_prymnesiophytes (SettingWithCopyWarning; under Copy-on-Write
    # the slice is detached anyway)
    plot_data = data_prymnesiophytes[data_prymnesiophytes[aggregation_level2] == station].copy()
    # Share of every taxon in the summed TPM of its month at this station (a fraction between 0 and 1)
    plot_data["rel_expression_per_month"] = plot_data.TPM / plot_data.groupby('month').TPM.transform('sum')

    # Combine low-abundant groups: taxa that make up at most 2% of a month become 'Rare'
    rare_groups = plot_data['rel_expression_per_month'] <= 0.02
    plot_data.loc[rare_groups, tax_level] = 'Rare'
    fig = px.histogram(plot_data.sort_values("month", ascending=False),
                    x="rel_expression_per_month",
                    y="month",
                    color=tax_level,
                    orientation='h',
                    # text_auto='.2f',
                    color_discrete_map={
                        "Rare": "#545454",
                        "Phaeocystis": "#F26B38",
                    },
                    # Specify all the months that need to be included,
                    # even if no sample has been taken
                    category_orders={"month": ["July_2020", "August_2020", "September_2020",
                                        "November_2020", "December_2020", "January_2021",
                                        "February_2021", "April_2021", "May_2021",
                                        "June_2021", "July_2021"],
                                    "Genus_UniEuk": ['Rare', 'Phaeocystis', 'Calcidiscus', 'Chrysochromulina',
                                                     'Prymnesium', 'Scyphosphaera', 'Isochrysis', 'Emiliania',
                                                     'Dicrateria', 'Haptolina', 'Coccolithus', 'Gephyrocapsa']}
                    )

    fig.update_layout(
        font=dict(
            family="Times New Roman, serif",  # Set the font family to Times New Roman
            size=8,  # Set the font size
            color="#000000"  # Set the font color
        ),
        autosize=False,
        width= 8.5 * pixels_per_cm,
        height= 7.5 * pixels_per_cm,
        margin=dict( # Set the margins
            l=0,  # Left margin
            r=25,  # Right margin
            b=25,  # Bottom margin
            t=25  # Top margin
        ),
        xaxis_title_text='% TPM of total sum',
        yaxis_title_text=None,
        title_text=station
    )

    fig.show()

    # Save figure as png and svg
    #fig.write_image("../../figures/diatoms_vs_dinoflagellates/prymnesiophyte_relative_expression_per_month_{}_at_{}.png".format(tax_level, station), scale=1)
    #fig.write_image("../../figures/diatoms_vs_dinoflagellates/prymnesiophyte_relative_expression_per_month_{}_at_{}.svg".format(tax_level, station), scale=1)
    fig.write_image("../../figures/diatoms_vs_dinoflagellates/prymnesiophyte_relative_expression_per_month_{}_at_{}_eukprot.png".format(tax_level, station), scale=1)
    fig.write_image("../../figures/diatoms_vs_dinoflagellates/prymnesiophyte_relative_expression_per_month_{}_at_{}_eukprot.svg".format(tax_level, station), scale=1)

#### Total TPM of prymnesiophytes per month

In [None]:
# Taxonomic rank used for the colouring (PhyloDB-era equivalent: 'species') and the grouping variable
tax_level = 'Genus_UniEuk'
aggregation_level = 'month'

#data_prymnesiophytes = data[(data['class'] == 'Prymnesiophyceae') & (data['genus'] == 'Phaeocystis') & (data['p_ident'] >= 0.9)][[aggregation_level, tax_level, 'TPM']]
data_prymnesiophytes = data[(data['Taxogroup2_UniEuk'] == 'Prymnesiophyceae') & (data['p_ident'] >= 0.9)][[aggregation_level, tax_level, 'TPM']]
# Remove data without a genus or with NaN values.
# Written as a chain that returns a new frame: `df[col].replace(..., inplace=True)` is a
# chained in-place update that pandas 2.x warns about and that no longer reaches the
# frame once Copy-on-Write is active.
data_prymnesiophytes = (
    data_prymnesiophytes
    .assign(**{tax_level: lambda transcripts: transcripts[tax_level].replace('', np.nan)})
    .dropna(subset=[tax_level])
)

# Plot the total TPM per month
fig = px.histogram(data_prymnesiophytes.sort_values("month", ascending=False),
             x = "TPM",
             y = "month",
             color = tax_level,
            color_discrete_map={
                    "Rare": "#545454",
                    "Phaeocystis": "#F26B38",
                },
             orientation='h',
            category_orders={"month": ["July_2020", "August_2020", "September_2020",
                "November_2020", "December_2020", "January_2021",
                "February_2021", "April_2021", "May_2021",
                "June_2021", "July_2021"],
                "Genus_UniEuk": ["Rare", "Phaeocystis"]},)

fig.update_layout(
    font=dict(
        family="Times New Roman, serif",  # Set the font family to Times New Roman
        size=8,  # Set the font size
        color="#000000"  # Set the font color
    ),
    autosize=False,
    width= 11.5 * pixels_per_cm,
    height= 7.5 * pixels_per_cm,
    margin=dict( # Set the margins
        l=0,  # Left margin
        r=25,  # Right margin
        b=25,  # Bottom margin
        t=25  # Top margin
    ),
    xaxis_title_text='Total prymnesiophyte TPM per month',
    yaxis_title_text=None,
    # Set range of x-axis
    xaxis_range=[0, 550000],
)

fig.show()

# save figure as svg
#fig.write_image(f"../../figures/diatoms_vs_dinoflagellates/prymnesiophytes_total_TPM_per_month_{tax_level}.svg", scale=1)
fig.write_image(f"../../figures/diatoms_vs_dinoflagellates/prymnesiophytes_total_TPM_per_month_{tax_level}_eukprot.svg", scale=1)

In [None]:
# Log version for clarity: every row's TPM is replaced by log10(TPM + 1) in a copy of the data
# NOTE(review): the transform is applied per row, before the histogram aggregates the rows
# of a month -- confirm that this is the intended "log of the total"
data_prymnesiophytes_log = data_prymnesiophytes.assign(
    TPM=lambda transcripts: np.log10(transcripts["TPM"] + 1)
)

# Order of the month categories on the y-axis
month_categories = ["July_2020", "August_2020", "September_2020",
                    "November_2020", "December_2020", "January_2021",
                    "February_2021", "April_2021", "May_2021",
                    "June_2021", "July_2021"]

# Horizontal histogram of the log-transformed TPM per month, coloured by taxon
fig = px.histogram(
    data_prymnesiophytes_log.sort_values("month", ascending=False),
    x="TPM",
    y="month",
    color=tax_level,
    orientation='h',
    color_discrete_map={"Rare": "#545454", "Phaeocystis": "#F26B38"},
    category_orders={"month": month_categories, "Genus_UniEuk": ["Rare", "Phaeocystis"]},
)

# Font, size (8.5 cm x 7.5 cm), margins and axis titles; the x-axis range is left automatic
fig.update_layout(
    font={"family": "Times New Roman, serif", "size": 8, "color": "#000000"},
    autosize=False,
    width=8.5 * pixels_per_cm,
    height=7.5 * pixels_per_cm,
    margin={"l": 0, "r": 25, "b": 25, "t": 25},
    xaxis_title_text='log(Total prymnesiophyte TPM per month)',
    yaxis_title_text=None,
)

fig.show()

# Save the figure as svg
fig.write_image(f"../../figures/diatoms_vs_dinoflagellates/prymnesiophytes_total_TPM_per_month_{tax_level}_eukprot_log.svg", scale=1)

In [None]:
# Plot the total TPM per month per station
# `data_prymnesiophytes` from the genus-level total-TPM cell only holds the month, the genus
# and the TPM, so it has no 'station' column to colour by. The rows are therefore selected
# again from `data` with the same filter (Prymnesiophyceae, p_ident >= 0.9, genus present),
# this time keeping the station.
genus_names = data['Genus_UniEuk'].replace('', np.nan)
is_selected = (
    (data['Taxogroup2_UniEuk'] == 'Prymnesiophyceae')
    & (data['p_ident'] >= 0.9)
    & genus_names.notna()
)
data_prymnesiophytes_station = data.loc[is_selected, ['month', 'station', 'TPM']]

fig = px.histogram(data_prymnesiophytes_station.sort_values("month", ascending=False),
             x = "TPM",
             y = "month",
             color = 'station',
            # NOTE(review): the keys are strings; they only match if the station column holds strings too
            color_discrete_map={
                "ZG02": "#8c613c",
                "120": "#956cb4",
                "330": "#ee854a",
                "130": "#4878d0",
                "780": "#d65f5f",
                "700": "#6acc64"},
             orientation='h',
            category_orders={"month": ["July_2020", "August_2020", "September_2020",
                "November_2020", "December_2020", "January_2021",
                "February_2021", "April_2021", "May_2021",
                "June_2021", "July_2021"]},)

fig.update_layout(
    font=dict(
        family="Times New Roman, serif",  # Set the font family to Times New Roman
        size=8,  # Set the font size
        color="#000000"  # Set the font color
    ),
    autosize=False,
    width= 8.5 * pixels_per_cm,
    height= 7.5 * pixels_per_cm,
    margin=dict( # Set the margins
        l=0,  # Left margin
        r=25,  # Right margin
        b=25,  # Bottom margin
        t=25  # Top margin
    ),
    xaxis_title_text='Total prymnesiophyte TPM per month',
    yaxis_title_text=None,
)

fig.show()

# save figure as svg
fig.write_image("../../figures/diatoms_vs_dinoflagellates/prymnesiophytes_total_TPM_per_month_station.svg", scale=1)

In [None]:
# Log version for clarity
# `data_prymnesiophytes` from the genus-level total-TPM cell has no 'station' column, so the
# rows are selected again from `data` with the same filter (Prymnesiophyceae, p_ident >= 0.9,
# genus present), this time keeping the station.
genus_names = data['Genus_UniEuk'].replace('', np.nan)
is_selected = (
    (data['Taxogroup2_UniEuk'] == 'Prymnesiophyceae')
    & (data['p_ident'] >= 0.9)
    & genus_names.notna()
)
data_prymnesiophytes_log = data.loc[is_selected, ['month', 'station', 'TPM']].copy()
# Same pseudocount as the genus-level log plot; log10(0) would give -inf
# NOTE(review): the transform is applied per row, before the histogram aggregates the rows
# of a month -- confirm that this is the intended "log of the total"
data_prymnesiophytes_log['TPM'] = np.log10(data_prymnesiophytes_log['TPM'] + 1)

# Plot the log-transformed TPM per month per station
fig = px.histogram(data_prymnesiophytes_log.sort_values("month", ascending=False),
             x = "TPM",
             y = "month",
             color = 'station',
            # NOTE(review): the keys are strings; they only match if the station column holds strings too
            color_discrete_map={
                "ZG02": "#8c613c",
                "120": "#956cb4",
                "330": "#ee854a",
                "130": "#4878d0",
                "780": "#d65f5f",
                "700": "#6acc64"},
             orientation='h',
            category_orders={"month": ["July_2020", "August_2020", "September_2020",
                "November_2020", "December_2020", "January_2021",
                "February_2021", "April_2021", "May_2021",
                "June_2021", "July_2021"]},)

fig.update_layout(
    font=dict(
        family="Times New Roman, serif",  # Set the font family to Times New Roman
        size=8,  # Set the font size
        color="#000000"  # Set the font color
    ),
    autosize=False,
    width= 8.5 * pixels_per_cm,
    height= 7.5 * pixels_per_cm,
    margin=dict( # Set the margins
        l=0,  # Left margin
        r=25,  # Right margin
        b=25,  # Bottom margin
        t=25  # Top margin
    ),
    # The plotted values are log-transformed, so the axis title says so
    xaxis_title_text='log(Total prymnesiophyte TPM per month)',
    yaxis_title_text=None,
)

fig.show()

# save figure as svg
fig.write_image("../../figures/diatoms_vs_dinoflagellates/prymnesiophytes_total_TPM_per_month_station_log.svg", scale=1)

Now let's repeat these last plots but for Phaeocystis species

In [None]:
# Taxonomic rank used for the colouring (PhyloDB-era equivalent: 'species') and the grouping variable
tax_level = 'Name_to_Use'
aggregation_level = 'month'

# Phaeocystis transcripts (class Prymnesiophyceae) with p_ident >= 0.95
is_phaeocystis = (data['Taxogroup2_UniEuk'] == 'Prymnesiophyceae') & (data['Genus_UniEuk'] == 'Phaeocystis')
has_sufficient_identity = data['p_ident'] >= 0.95
data_prymnesiophytes = data.loc[is_phaeocystis & has_sufficient_identity, [aggregation_level, tax_level, 'TPM']]

# Order of the month categories on the y-axis
month_categories = ["July_2020", "August_2020", "September_2020",
                    "November_2020", "December_2020", "January_2021",
                    "February_2021", "April_2021", "May_2021",
                    "June_2021", "July_2021"]

# Horizontal histogram of the TPM per month, coloured by taxon
fig = px.histogram(
    data_prymnesiophytes.sort_values("month", ascending=False),
    x="TPM",
    y="month",
    color=tax_level,
    orientation='h',
    color_discrete_map={"Rare": "#545454", "Phaeocystis": "#F26B38"},
    category_orders={"month": month_categories, "Genus_UniEuk": ["Rare", "Phaeocystis"]},
)

# Font, size (11.5 cm x 7.5 cm), margins, axis titles and a fixed x-axis range
fig.update_layout(
    font={"family": "Times New Roman, serif", "size": 8, "color": "#000000"},
    autosize=False,
    width=11.5 * pixels_per_cm,
    height=7.5 * pixels_per_cm,
    margin={"l": 0, "r": 25, "b": 25, "t": 25},
    xaxis_title_text='Total phaeocystis TPM per month',
    yaxis_title_text=None,
    xaxis_range=[0, 550000],
)

fig.show()

# Save the figure as svg
fig.write_image(f"../../figures/diatoms_vs_dinoflagellates/prymnesiophytes_phaeocystis_total_TPM_per_month_{tax_level}_eukprot.svg", scale=1)

### Biodiversity

In [None]:
# Build a taxon x sample matrix holding the NUMBER of expressed prymnesiophyte
# transcripts (not their summed TPM) per taxonomic bin.
# PhyloDB variant: tax_level = 'genus'
tax_level = 'Name_to_Use'
aggregation_level = 'sample'

# Only reads annotated with a sufficient sequence identity (p_ident >= 0.9) are
# kept, since the bins are compared against each other below.
prymnesiophyte_rows = (data['Taxogroup2_UniEuk'] == 'Prymnesiophyceae') & (data['p_ident'] >= 0.9)

data_prymnesiophytes = (
    data.loc[prymnesiophyte_rows, [aggregation_level, tax_level, 'TPM']]
    # Discard transcripts at or below the expression threshold of 1 TPM
    .loc[lambda frame: frame['TPM'] > 1]
    # Count the remaining transcripts per (sample, taxon) pair
    .groupby([aggregation_level, tax_level])
    .count()
    .reset_index()
    # Wide format: one row per taxon, one column per sample
    .pivot(index=tax_level, columns=aggregation_level, values='TPM')
    # A taxon that is absent from a sample has 0 expressed transcripts
    .fillna(0)
)

# Inspect the first rows of the count matrix
data_prymnesiophytes.head()

In [None]:
# Per sample (column): number of taxonomic bins with more than 100 counted transcripts
data_prymnesiophytes.where(data_prymnesiophytes > 100).count()

In [None]:
# How many prymnesiophyte taxonomic bins per month exceed 100 counted transcripts?
## Per sample (column), count the bins whose transcript count is > 100
taxonomic_bin_abundance = data_prymnesiophytes[data_prymnesiophytes > 100].count(axis=0).reset_index()
taxonomic_bin_abundance.columns = ['sample', 'num_species']

## Add metadata: attach the sampling month of every sample (meta is indexed by sample)
taxonomic_bin_abundance = taxonomic_bin_abundance.merge(meta[['month']], left_on='sample', right_index=True, how='left')

# Box plot of the number of bins per sample, grouped by month in chronological order
fig = px.box(taxonomic_bin_abundance, x='month', y='num_species',
             category_orders={'month': ['July_2020', "August_2020", "September_2020", "November_2020", 
               "December_2020", "January_2021", "February_2021", "April_2021", 
               "May_2021", "June_2021", "July_2021"]})

# plotly measures figure width/height in pixels, so the intended physical size
# (assumed to be 3.5 x 2 inches -- TODO confirm) is converted at 96 px per inch,
# the same density as the pixels_per_cm constant used for the other figures.
px_per_inch = 96
fig.update_layout(
    font=dict(
        family="Times New Roman, serif",  # Same font as the other exported figures
        size=8,
    ),
    autosize=False,
    width=round(3.5 * px_per_inch),
    height=round(2 * px_per_inch),
    # Compact margins (as in the sibling figures); the default margins would
    # leave almost no room for the plot at this size
    margin=dict(l=0, r=25, b=25, t=25),
)

fig.show()

# Save the figure. The content is SVG, so the file gets an .svg extension
# (previously an SVG was written into a file named *.png).
fig.write_image("../../figures/diatoms_vs_dinoflagellates/num_prymnesiophyte_species_per_month.svg", scale=1, format='svg')

In [None]:
# For every taxonomic bin (row), count the samples in which it has more than
# 100 counted transcripts; bins that never pass the threshold get 0.
taxonomic_bin_abundance = data_prymnesiophytes.where(data_prymnesiophytes > 100).count(axis=1)
taxonomic_bin_abundance.head()

In [None]:
# Bar chart of the prymnesiophyte bins, ordered from the most to the least
# frequently detected (number of samples above the transcript threshold)
taxonomic_bin_abundance.sort_values(ascending=False).plot.bar(figsize=(10, 5))

#### Trophic feeding mode of prymnesiophytes
I've predicted the trophic feeding mode of several species using the machine-learning model from [Lambert & Groussman, 2022](https://doi.org/10.1073/pnas.2100916119); see [this notebook](./trophic_mode_prediction.ipynb). We'll use these predictions to see how the trophic feeding mode of prymnesiophytes changes over time, as we did for the diatoms and dinoflagellates.

In [None]:
# Read the trophic-mode predictions made on the EukProt annotation
# (PhyloDB variant: 'phylodb_trophic_mode_predictions.csv', filtered on the 'class' column)
# and keep only the prymnesiophyte rows.
trophic_predictions = (
    pd.read_csv('../../data/analysis/eukprot_trophic_mode_predictions.csv')
    .loc[lambda frame: frame['Taxogroup2_UniEuk'] == 'Prymnesiophyceae']
)

# Number of prymnesiophyte predictions, followed by a preview of the table
print(len(trophic_predictions))
trophic_predictions.head()

In [None]:
# Plot the relative abundance of the three trophic modes per month.

# Number of predictions per (month, trophic mode)
prediction_counts = (trophic_predictions[['month', 'station', 'prediction']]
                     .groupby(['month', 'prediction'])
                     .size())

# Convert the counts to percentages within each month.
# transform('sum') keeps the original (month, prediction) index, so this gives
# the same result on every pandas version. The previous groupby(...).apply(...)
# prepends the group key on pandas >= 2.0, which duplicates the 'month' index
# level and makes reset_index() fail.
monthly_totals = prediction_counts.groupby(level='month').transform('sum')
rel_abundance = (100 * prediction_counts / monthly_totals).reset_index(name='relative abundance')

# Set the color palette
color_map = {
    'Phot': 'green',
    'Mix': 'black',
    'Het': 'red'
}
fig = px.histogram(rel_abundance.sort_values("month", ascending=False),
             x = "relative abundance",
             y = "month",
             color = "prediction",
             orientation='h',
             color_discrete_map=color_map,
             # Chronological order of the sampling months
             category_orders={"month": ["July_2020", "August_2020", "September_2020",
                "November_2020", "December_2020", "January_2021",
                "February_2021", "April_2021", "May_2021",
                "June_2021", "July_2021"]})

# Conversion factor from centimetres to pixels (96 px per inch / 2.54 cm per inch)
pixels_per_cm = 37.7952755906

fig.update_layout(
    font=dict(
        family="Times New Roman, serif",  # Set the font family to Times New Roman
        size=8  # Set the font size
    ),
    autosize=False,
    width= 8.5 * pixels_per_cm,
    height= 7.5 * pixels_per_cm,
    margin=dict( # Set the margins
        l=0,  # Left margin
        r=25,  # Right margin
        b=25,  # Bottom margin
        t=25  # Top margin
    ),
    xaxis_title_text='Ratio of predicted trophic modes (%)',
    yaxis_title_text=None,
)

fig.show()

# save figure as svg
#fig.write_image("../../figures/diatoms_vs_dinoflagellates/prymnesiophyte_trophic_mode_relative_abundance_phylodb.svg")
fig.write_image("../../figures/diatoms_vs_dinoflagellates/prymnesiophyte_trophic_mode_relative_abundance_eukprot.svg")

In [None]:
# Heatmap of the consensus (most frequent) predicted trophic mode per
# prymnesiophyte genus and month.
#tax_level = 'genus'
tax_level = 'Genus_UniEuk'

# Numeric codes used for colouring; 0 is reserved for "no prediction"
trophic_mode_codes = {'Phot': 1, 'Mix': 2, 'Het': 3}

# Extract the relevant columns and convert 'prediction' into numerical categories.
# .assign() returns a new frame, so nothing is written into a slice of
# trophic_predictions (the previous column assignment raised SettingWithCopyWarning).
df = trophic_predictions[[tax_level, 'month', 'prediction']].assign(
    prediction=lambda frame: frame['prediction'].map(trophic_mode_codes)
)

# Pivot the DataFrame to use for heatmap: one consensus code per (genus, month).
# Series.mode() returns the modes in ascending order, so ties resolve to the
# lowest code (the same tie-break scipy.stats.mode uses). It replaces
# stats.mode(...)[0][0], which raises on SciPy >= 1.11 where the result is a
# scalar. Groups without any mapped prediction get 0; the codes are 1-3, so a 0
# never has to be filtered out of the group itself.
df_pivot = df.pivot_table(index=tax_level, columns='month', values='prediction',
                          aggfunc=lambda x: int(x.mode().iloc[0]) if x.notna().any() else 0)

# Fill NA with a specific category (0 for 'no prediction') and ensure integer values
df_pivot = df_pivot.fillna(0).astype(int)

# Set the order of the months (chronological)
month_order = ["July_2020", "August_2020", "September_2020", "November_2020", 
                "December_2020", "January_2021", "February_2021", "April_2021", 
                "May_2021", "June_2021", "July_2021"]
df_pivot = df_pivot.reindex(month_order, axis=1)

# Set the order of the genera to be the same as in the above relative abundance plot
df_pivot = df_pivot.reindex(prymnesiophyte_legend_order, axis=0)

# Create a color map: 0 = no prediction, 1 = Phot, 2 = Mix, 3 = Het
cmap = mcolors.ListedColormap(['lightgrey', 'green', 'yellow', 'red'])

# Font size should be 8, and font family Times New Roman.
# This has to be set before the figure is drawn: rcParams only affect text
# that is created afterwards.
plt.rcParams.update({'font.size': 8, 'font.family': 'Times New Roman'})

# Create the heatmap with the color map
plt.figure(figsize=(4.5, 4))
sns.heatmap(df_pivot, cmap=cmap, annot=False, cbar=False, vmin=0, vmax=3)

# Change layout to match the relative abundance plot
plt.tight_layout()

# Remove the x and  y-axis label
plt.xlabel(None)
plt.ylabel(None)

# Save figure as svg
#plt.savefig("../../figures/diatoms_vs_dinoflagellates/prymnesiophytes_trophic_mode_genus_consensus_heatmap.svg", format='svg', dpi=600)
plt.savefig("../../figures/diatoms_vs_dinoflagellates/prymnesiophytes_trophic_mode_genus_consensus_heatmap_eukprot.svg", format='svg', dpi=600)

plt.show()

In [None]:
# Repeat for species level: heatmap of the consensus (most frequent) predicted
# trophic mode per prymnesiophyte species and month.
# NOTE(review): the other EukProt cells use 'Name_to_Use' as the species-level
# column; confirm that the EukProt prediction table still has a 'species' column.

# Numeric codes used for colouring; 0 is reserved for "no prediction"
trophic_mode_codes = {'Phot': 1, 'Mix': 2, 'Het': 3}

# Extract the relevant columns and convert 'prediction' into numerical categories.
# .assign() returns a new frame, so nothing is written into a slice of
# trophic_predictions (the previous column assignment raised SettingWithCopyWarning).
df = trophic_predictions[['species', 'month', 'prediction']].assign(
    prediction=lambda frame: frame['prediction'].map(trophic_mode_codes)
)

# Pivot the DataFrame to use for heatmap: one consensus code per (species, month).
# Series.mode() returns the modes in ascending order, so ties resolve to the
# lowest code (the same tie-break scipy.stats.mode uses). It replaces
# stats.mode(...)[0][0], which raises on SciPy >= 1.11 where the result is a
# scalar. Groups without any mapped prediction get 0; the codes are 1-3, so a 0
# never has to be filtered out of the group itself.
df_pivot = df.pivot_table(index='species', columns='month', values='prediction',
                          aggfunc=lambda x: int(x.mode().iloc[0]) if x.notna().any() else 0)

# Fill NA with a specific category (0 for 'no prediction') and ensure integer values
df_pivot = df_pivot.fillna(0).astype(int)

# Set the order of the months (month_order is defined in the genus-level cell)
df_pivot = df_pivot.reindex(month_order, axis=1)

# Create a color map: 0 = no prediction, 1 = Phot, 2 = Mix, 3 = Het
cmap = mcolors.ListedColormap(['lightgrey', 'green', 'yellow', 'red'])

# Font size should be 8, and font family Times New Roman.
# This has to be set before the figure is drawn: rcParams only affect text
# that is created afterwards.
plt.rcParams.update({'font.size': 8, 'font.family': 'Times New Roman'})

# Create the heatmap with the color map
plt.figure(figsize=(4.5, 4))
sns.heatmap(df_pivot, cmap=cmap, annot=False, cbar=False, vmin=0, vmax=3)

# Change layout to match the relative abundance plot
plt.tight_layout()

# Remove the x and  y-axis label
plt.xlabel(None)
plt.ylabel(None)

# Save figure as svg
#plt.savefig("../../figures/diatoms_vs_dinoflagellates/prymnesiophytes_trophic_mode_species_consensus_heatmap.svg", format='svg', dpi=600)
plt.savefig("../../figures/diatoms_vs_dinoflagellates/prymnesiophytes_trophic_mode_species_consensus_heatmap_eukprot.svg", format='svg', dpi=600)


plt.show()