
# **Tutorial for phylodynamics model selection**
Based on the method developed in Perez M.F. and Gascuel O. PhyloCNN: Improving tree representation and neural network architecture for deep learning from trees in phylodynamics and diversification studies. https://www.biorxiv.org/content/10.1101/2024.12.13.628187v1

## **1. Introduction**
This tutorial shows how to train a CNN that classifies phylogenetic trees of viruses according to three competing epidemiological (phylodynamics) models.

Phylodynamics relies on phylogenetic trees, which are built from aligned genetic sequences of the pathogen sampled from infected individuals. The trees are calibrated (dated) to reflect transmission times.
<img src="https://drive.google.com/uc?export=view&id=1sQ4hClFJs9xZoSo_ScmtuxYMoa05RSXh" width="600" height="300">

Figure 1. from [Guinat et al., 2021](https://www.sciencedirect.com/science/article/pii/S0169534721001300).

We will compare three competing epidemiological (phylodynamics) models: Birth-Death (BD), Birth-Death Exposed Infectious (BDEI) and Birth-Death with Superspreaders (BDSS).

<img src="https://drive.google.com/uc?export=view&id=1FxkO0Qisu6m1_Znc76MMbd6ZVUSPWuAl" width="500" height="300">

The simulated trees were encoded by describing, for every node and leaf of the phylogeny, its neighborhood (e.g., the lengths of outgoing branches) and its main measurements (e.g., date, number of descendants).

<img src="https://drive.google.com/uc?export=view&id=1FysAnN2H8C312yQFtAWeFW7OSRAAMdrv" width="750" height="750">

## **2. Libraries and Data Loading**
We import the required Python libraries and then load the phylogenetic trees simulated under each of the three models (BD, BDEI, BDSS).


In [1]:
# First, download the data.
!gdown --id 1GHLYw3EezrtrMkJDBXY8FNZ4FjyV3Vnn

Downloading...
From (original): https://drive.google.com/uc?id=1GHLYw3EezrtrMkJDBXY8FNZ4FjyV3Vnn
From (redirected): https://drive.google.com/uc?id=1GHLYw3EezrtrMkJDBXY8FNZ4FjyV3Vnn&confirm=t&uuid=f17779ba-4cfa-4be3-8390-997cfe194625
To: /content/PhyloDyn.zip
100% 70.1M/70.1M [00:00<00:00, 128MB/s]


In [2]:
# Unzip the simulations
!unzip "/content/PhyloDyn.zip"

Archive:  /content/PhyloDyn.zip
   creating: PhyloDyn/
  inflating: __MACOSX/._PhyloDyn     
  inflating: PhyloDyn/.DS_Store 2    
  inflating: __MACOSX/PhyloDyn/._.DS_Store 2  
  inflating: PhyloDyn/.DS_Store      
  inflating: __MACOSX/PhyloDyn/._.DS_Store  
  inflating: PhyloDyn/Encoded_Zurich.csv  
  inflating: __MACOSX/PhyloDyn/._Encoded_Zurich.csv  
  inflating: PhyloDyn/BDSS_large_100K.csv  
  inflating: __MACOSX/PhyloDyn/._BDSS_large_100K.csv  
  inflating: PhyloDyn/Encoded_Zurich.npy  
  inflating: __MACOSX/PhyloDyn/._Encoded_Zurich.npy  
   creating: PhyloDyn/testset/
  inflating: __MACOSX/PhyloDyn/._testset  
  inflating: PhyloDyn/Encoded_trees_BDSS.npy  
  inflating: __MACOSX/PhyloDyn/._Encoded_trees_BDSS.npy  
  inflating: PhyloDyn/Encoded_trees_BDEI.npy  
  inflating: __MACOSX/PhyloDyn/._Encoded_trees_BDEI.npy  
  inflating: PhyloDyn/Encoded_trees_BD.npy  
  inflating: __MACOSX/PhyloDyn/._Encoded_trees_BD.npy  
  inflating: PhyloDyn/testset/.DS_Store  
  inflating: __MACO

In [13]:
import pandas as pd
import tensorflow as tf
import keras
import numpy as np

from keras.models import Sequential, Model
from keras.layers import Conv2D, GlobalAveragePooling2D, BatchNormalization
from keras.layers import Dense, Dropout, Activation, Flatten

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

# 1) Load the tree encodings for the BD, BDEI and BDSS models. For each model we load
# the encodings of the training trees (1,000 trees per model) and of the test trees
# (testset, 100 trees per model).
# Each encoding is split into two channels (one for the internal nodes, one for the
# leaves of the tree). The trees have at most 500 tips (leaves).
encoding_BD = np.load('/content/PhyloDyn/Encoded_trees_BD.npy')
encoding_test_BD = np.load('/content/PhyloDyn/testset/Encoded_trees_BD.npy')
encoding_BDEI = np.load('/content/PhyloDyn/Encoded_trees_BDEI.npy')
encoding_test_BDEI = np.load('/content/PhyloDyn/testset/Encoded_trees_BDEI.npy')
encoding_BDSS = np.load('/content/PhyloDyn/Encoded_trees_BDSS.npy')
encoding_test_BDSS = np.load('/content/PhyloDyn/testset/Encoded_trees_BDSS.npy')

#### Exercise 1: Data Visualization

**Question 1a (add code below and copy it into your assessment form at the end):** Fill in the code cell below to check the shapes of the loaded inputs for the training trees and the test trees.

In [16]:
# Print the shapes of the training and test encodings for each model
print(f"Birth-Death (BD) training trees: {encoding_BD.shape}")
print(f"Birth-Death (BD) test trees: {encoding_test_BD.shape}")
print(f"Birth-Death Exposed Infectious (BDEI) training trees: {encoding_BDEI.shape}")
print(f"Birth-Death Exposed Infectious (BDEI) test trees: {encoding_test_BDEI.shape}")
print(f"Birth-Death with Superspreaders (BDSS) training trees: {encoding_BDSS.shape}")
print(f"Birth-Death with Superspreaders (BDSS) test trees: {encoding_test_BDSS.shape}")


Birth-Death (BD) training trees: (1000, 500, 19, 2)
Birth-Death (BD) test trees: (100, 500, 19, 2)
Birth-Death Exposed Infectious (BDEI) training trees: (1000, 500, 19, 2)
Birth-Death Exposed Infectious (BDEI) test trees: (100, 500, 19, 2)
Birth-Death with Superspreaders (BDSS) training trees: (1000, 500, 19, 2)
Birth-Death with Superspreaders (BDSS) test trees: (100, 500, 19, 2)


**Question 1b (add the answer to your assessment form):** What does the second dimension (value 500) in the arrays represent?
- A. The number of simulations  
- B. The number of features per node  
- C. The maximum number of leaves per tree  
- D. The number of epidemiological models  

**Answer:** C. The maximum number of leaves per tree.
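Because every tree is stored in a fixed block of 500 rows, smaller trees must be padded, so 500 is an upper bound rather than the size of each tree. The minimal NumPy sketch below counts how many rows of each channel a tree actually uses. It assumes, and this is an assumption not stated in the tutorial, that padding rows are filled with zeros; `count_filled_rows` is a hypothetical helper, demonstrated on a toy array rather than the real encodings.

```python
import numpy as np


def count_filled_rows(encodings: np.ndarray) -> np.ndarray:
    """Count the non-empty rows per tree and per channel (hypothetical helper).

    encodings: array of shape (n_trees, max_rows, n_features, n_channels),
    e.g. (1000, 500, 19, 2). Assumes padding rows are all zeros, so a row
    counts as filled if any of its features is non-zero.
    Returns an integer array of shape (n_trees, n_channels).
    """
    row_is_filled = np.any(encodings != 0, axis=2)  # (n_trees, max_rows, n_channels)
    return row_is_filled.sum(axis=1)                # (n_trees, n_channels)


# Toy example: 2 trees, up to 5 rows, 3 features, 2 channels
toy = np.zeros((2, 5, 3, 2))
toy[0, :3, :, 0] = 1.0  # tree 0 uses 3 rows in channel 0
toy[0, :2, :, 1] = 1.0  # and 2 rows in channel 1
toy[1, :, 0, 0] = 0.5   # tree 1 fills all 5 rows of channel 0 (first feature only)
print(count_filled_rows(toy))
```

On the tutorial data, `count_filled_rows(encoding_BD)` would return a `(1000, 2)` array; whether zero really is the padding value should be checked against the code that produced the encodings.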

## 3. Data Preprocessing
We format the inputs before feeding them to the neural network, in the following steps:

### Label Assignment
We create a label array **Y** for the training and test set, with:
- `0` for BD
- `1` for BDEI
- `2` for BDSS

In [17]:

# Add an integer label to every simulated tree: 0 = BD, 1 = BDEI, 2 = BDSS
Y = np.array([0] * len(encoding_BD) + [1] * len(encoding_BDEI) + [2] * len(encoding_BDSS))

Y_test = np.array([0] * len(encoding_test_BD) + [1] * len(encoding_test_BDEI) + [2] * len(encoding_test_BDSS))

# One-hot encode Y, since this is a 3-class classification with a softmax output.
# Y_test is kept as integer labels, which is what confusion_matrix expects later.
Y = np.eye(3)[Y]
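The one-hot step works by indexing rows of an identity matrix: label `k` selects row `k` of `np.eye(3)`, which has a single 1 in column `k`. A small standalone check with toy labels (not the tutorial data) also shows that `argmax` reverses the encoding, which is the same operation used later to turn the network's softmax outputs back into class indices.

```python
import numpy as np

# Toy integer labels using the tutorial's convention: 0 = BD, 1 = BDEI, 2 = BDSS
labels = np.array([0, 1, 2, 1])

# Row k of the 3x3 identity matrix is the one-hot vector for class k
one_hot = np.eye(3)[labels]
print(one_hot)

# argmax along the class axis recovers the integer labels
recovered = one_hot.argmax(axis=1)
print(recovered)  # [0 1 2 1]
```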

In [18]:
# Combine the encodings of the 3 models, in the same order (BD, BDEI, BDSS) as the labels in Y
encoding = np.concatenate((encoding_BD,encoding_BDEI,encoding_BDSS),axis=0)
encoding_test = np.concatenate((encoding_test_BD,encoding_test_BDEI,encoding_test_BDSS),axis=0)

In [19]:
# Split the data into training and validation sets:
# 30% of the trees are held out for validation; labels and trees are shuffled together
Y, Y_valid, encoding, encoding_valid = train_test_split(Y, encoding, test_size=0.3, shuffle=True, stratify=Y)

#### Exercise 2: Data split and stratification

**Question 2 (add explanation to the assessment form):** a) What is the validation set and why is it useful? b) Why do we need to shuffle the order of labels and trees? c) What is the advantage to using 'stratify=Y' in our example?

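**Answers:**

a) The validation set is a portion of the training data (here 30%) that is held out from the gradient updates and used to evaluate the model during training. It is useful for tuning hyperparameters, for detecting overfitting (training accuracy keeps rising while validation accuracy stalls or drops) and for early stopping (the `EarlyStopping` callback used in this tutorial monitors `val_accuracy`). Because these training decisions depend on it, the validation score is slightly optimistic; the separate test set provides the unbiased estimate of generalization.

b) The arrays are ordered by model: all BD trees first, then BDEI, then BDSS. Without shuffling, `train_test_split` would take the last 30% of rows as the validation set, i.e. BDSS trees only, and the training set would contain only 100 BDSS trees. Shuffling applies the same random permutation to labels and trees, so each tree keeps its label while both sets become representative of all three models. (scikit-learn also requires `shuffle=True` whenever `stratify` is set.)

c) Stratification preserves the class proportions in both sets. With 1,000 trees per model, `stratify=Y` guarantees exactly 700 training and 300 validation trees per model, instead of proportions that vary from one random split to another, which makes the validation accuracy a more reliable and consistent measure.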
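Answers b) and c) can be checked with `train_test_split` itself on synthetic labels that are sorted by class, like the tutorial's arrays. The toy sizes (100 samples per class) and the placeholder features are assumptions for illustration only.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic labels sorted by class, 100 samples per class
y = np.repeat([0, 1, 2], 100)
X = np.arange(len(y)).reshape(-1, 1)  # placeholder features

# Without shuffling, the validation set is simply the last 30% of rows
_, y_valid_unshuffled = train_test_split(y, test_size=0.3, shuffle=False)
print(np.bincount(y_valid_unshuffled, minlength=3))

# Shuffled and stratified split, as in the tutorial
X_train, X_valid, y_train, y_valid = train_test_split(
    X, y, test_size=0.3, shuffle=True, stratify=y, random_state=0)
print(np.bincount(y_train, minlength=3), np.bincount(y_valid, minlength=3))
```

The first print shows that every unshuffled validation sample comes from class 2; the second shows exactly 70 training and 30 validation samples per class.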

## 4. Building & Training the CNN (2-Generation Context)

### Model Definition
We define a CNN that processes inputs of shape `(500, 19, 2)`:
- 500 = rows, one per leaf or internal node (trees have at most 500 leaves)
- 19 = number of features per row
- 2 = channels (leaves and internal nodes)

This architecture is motivated by the fact that internal nodes and leaves contribute differently to the tree likelihood of multi-type birth-death models (MTBD, which include BD, BDEI and BDSS; see Equation 8 in [Zhukova et al., 2023](https://academic.oup.com/sysbio/article/72/6/1387/7273092)).

<img src="https://drive.google.com/uc?export=view&id=1FvkaeBLF42DuYYgePIj3NhKetzK3Abj6" width="1000" height="500">

<img src="https://drive.google.com/uc?export=view&id=1Fzol42i8u8hvSC6DEMDM3ScsoyeW4TTx" width="500" height="340">



### Build the Neural Network Model <p id="build"> </p>

In [20]:
# Creation of the Network Model: model definition
def build_model():
    # Initialize the Sequential model
    model = Sequential()

    # Input: (500, 19, 2), where 500 is the number of rows (one per leaf or internal
    # node), 19 is the number of features per row, and 2 is the number of channels
    # (leaves and internal nodes). Declaring it with keras.Input avoids the Keras 3
    # warning raised when `input_shape` is passed to a layer.
    model.add(keras.Input(shape=(500, 19, 2)))

    # First convolutional layer:
    # - Filters: 32
    # - Kernel size: (1, 19), which covers all 19 features of one row; the kernel
    #   slides along the 500 rows, so each leaf/node is summarized independently
    # - Groups: 2, so the two channels (leaves and internal nodes) are convolved
    #   separately, with 16 filters per channel
    # - Activation function: ReLU
    model.add(Conv2D(filters=32, use_bias=False, kernel_size=(1, 19), activation='relu', groups=2))

    # Apply batch normalization to stabilize and speed up training
    model.add(BatchNormalization())

    # Second convolutional layer:
    # - Filters: 32
    # - Kernel size: (1, 1), which combines the 32 feature maps within each row
    #   while still processing every row independently
    # - Activation function: ReLU
    model.add(Conv2D(filters=32, use_bias=False, kernel_size=(1, 1), activation='relu'))

    # Apply batch normalization again
    model.add(BatchNormalization())

    # Third convolutional layer:
    # - Filters: 32
    # - Kernel size: (1, 1) for further per-row feature processing
    # - Activation function: ReLU
    model.add(Conv2D(filters=32, use_bias=False, kernel_size=(1, 1), activation='relu'))

    # Apply batch normalization a final time before pooling
    model.add(BatchNormalization())

    # Average each of the 32 feature maps over all 500 rows (including any padding
    # rows), producing a 32-dimensional vector that does not depend on row order;
    # this vector is passed to the fully connected (dense) layers
    model.add(GlobalAveragePooling2D())

    # Fully connected (FFNN) part:
    # Dense layers with decreasing numbers of units, all using ReLU activation
    model.add(Dense(64, activation='relu'))   # First dense layer with 64 units
    model.add(Dense(32, activation='relu'))   # Second dense layer with 32 units
    model.add(Dense(16, activation='relu'))   # Third dense layer with 16 units
    model.add(Dense(8, activation='relu'))    # Fourth dense layer with 8 units

    # Output layer:
    # - 3 output neurons, one per epidemiological model (BD, BDEI, BDSS)
    # - Activation function: softmax, so the outputs are class probabilities summing to 1
    model.add(Dense(3, activation='softmax'))

    # Show the model summary (layers, output shapes, number of parameters)
    model.summary()

    # Return the constructed model
    return model
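To make the shapes concrete, here is a plain NumPy sketch of the convolutional part up to the pooling layer. It is a simplified re-implementation for illustration, not how Keras computes these layers: random weights stand in for trained ones (so the values are meaningless), BatchNormalization and the dense head are omitted, and the grouping is modeled as "16 filters that see only channel 0, then 16 that see only channel 1" without reproducing Keras's internal weight layout. It shows that the grouped (1, 19) kernel summarizes each row from one channel only, that the (1, 1) kernels mix feature maps within a row, and that global average pooling makes the result independent of row order.

```python
import numpy as np

rng = np.random.default_rng(0)

# One encoded tree with the tutorial's layout: 500 rows, 19 features, 2 channels.
# Random values stand in for a real encoding.
x = rng.normal(size=(500, 19, 2))

# Hypothetical random weights (no bias, as in build_model):
# layer 1 has 2 groups x 16 filters, each filter spanning the 19 features of one channel
w1 = rng.normal(size=(2, 19, 16))
# layers 2 and 3 are (1, 1) convolutions: 32 input maps -> 32 output maps per row
w2 = rng.normal(size=(32, 32))
w3 = rng.normal(size=(32, 32))


def relu(z):
    return np.maximum(z, 0.0)


def forward(tree):
    """Simplified forward pass up to global average pooling (no BatchNormalization)."""
    # Grouped (1, 19) convolution: group g only sees channel g. The width axis of
    # size 1 that Keras would keep is dropped, so the result is (rows, 32).
    h1 = relu(np.concatenate([tree[:, :, g] @ w1[g] for g in range(2)], axis=1))
    # (1, 1) convolutions are per-row matrix products across the 32 feature maps
    h2 = relu(h1 @ w2)
    h3 = relu(h2 @ w3)
    # Global average pooling: mean over all rows -> one 32-dimensional vector
    return h1, h3, h3.mean(axis=0)


h1, h3, pooled = forward(x)
print(h1.shape, h3.shape, pooled.shape)  # (500, 32) (500, 32) (32,)

# Shuffling the rows leaves the pooled vector unchanged (up to rounding)
perm = rng.permutation(500)
print(np.allclose(forward(x[perm])[2], pooled))  # True
```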

Now we compile and fit the model.

In [21]:
from keras import losses

# Initialize the model using the build_model function defined above
estimator = build_model()

# Compile the model:
# - Loss: categorical cross-entropy, which compares the predicted class probabilities
#   with the one-hot encoded labels (multi-class classification)
# - Optimizer: Adam
# - Metrics: accuracy, tracked on the training and validation sets
estimator.compile(loss=losses.categorical_crossentropy, optimizer='Adam', metrics=['accuracy'])

# Early stopping callback to prevent overfitting:
# - monitor: validation accuracy
# - mode='max': an epoch counts as an improvement when validation accuracy increases
# - patience: stop after 100 consecutive epochs without improvement
# - restore_best_weights: restore the weights from the epoch with the highest validation accuracy
early_stop = keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=100, mode='max', restore_best_weights=True)

# Custom callback to display training progress:
# print a dot at the end of every epoch and start a new line every 100 epochs
class PrintD(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        if epoch % 100 == 0:  # Start a new line every 100 epochs
            print('')
        print('.', end='')  # One dot per completed epoch

# Maximum number of epochs (passes over the whole training set)
EPOCHS = 1000

# Train the model with `fit`:
# - encoding, Y: training inputs and one-hot labels
# - verbose=1: print the Keras progress bar (in addition to the PrintD dots)
# - epochs: maximum number of epochs; early stopping usually ends training sooner
# - validation_data: the held-out validation set, monitored by early stopping
# - batch_size: number of samples per gradient update
# - callbacks: early stopping and progress display
history = estimator.fit(encoding, Y, verbose=1, epochs=EPOCHS, validation_data=(encoding_valid, Y_valid), batch_size=32, callbacks=[early_stop, PrintD()])



Epoch 1/1000
66/66 ━━━━━━━━━━━━━━━━━━━━ 0s 37ms/step - accuracy: 0.3457 - loss: 1.0920
66/66 ━━━━━━━━━━━━━━━━━━━━ 11s 61ms/step - accuracy: 0.3464 - loss: 1.0916 - val_accuracy: 0.4444 - val_loss: 1.0723
Epoch 2/1000
66/66 ━━━━━━━━━━━━━━━━━━━━ 3s 11ms/step - accuracy: 0.5962 - loss: 0.8341 - val_accuracy: 0.5989 - val_loss: 0.7598
Epoch 3/1000
66/66 ━━━━━━━━━━━━━━━━━━━━ 1s 9ms/step - accuracy: 0.6309 - loss: 0.5898 - val_accuracy: 0.6511 - val_loss: 0.7612
Epoch 4/1000
66/66 ━━━━━━━━━━━━━━━━━━━━ 1s 9ms/step - accuracy: 0.8128 - loss: 0.4774 - val_accuracy: 0.6500 - val_loss: 0.6574
Epoch 5/1000
66/66 ━━━━━━━━━━━━━━━━━━━━ 1s 13ms/step - accuracy: 0.8636 - loss: 0.3515 - val_accuracy: 0.5578 - val_loss: 0.8360
Epoch 6/1000
66/66 ━━━━━━━━━━━━━━━━━━━━ 1s 1
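The stopping rule configured above can be traced by hand. The function below is a simplified, hypothetical re-implementation of the patience logic for `mode='max'` (it ignores other `EarlyStopping` options such as `min_delta` and `baseline`, and is not the Keras source). It reports when training would stop and which epoch's weights `restore_best_weights` would keep. Epoch indices are zero-based, whereas the Keras log above counts from 1.

```python
def early_stopping_trace(val_scores, patience):
    """Simplified patience rule for mode='max' (hypothetical helper).

    Returns (stopped_epoch, best_epoch, best_score) with zero-based epoch
    indices; stopped_epoch is None if the patience is never exhausted.
    """
    best_score = float("-inf")
    best_epoch = None
    wait = 0  # epochs since the last improvement
    for epoch, score in enumerate(val_scores):
        if score > best_score:  # a strict increase counts as an improvement
            best_score, best_epoch, wait = score, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch, best_score
    return None, best_epoch, best_score


# Validation accuracies of epochs 1-5 from the training log above
val_acc = [0.4444, 0.5989, 0.6511, 0.6500, 0.5578]
print(early_stopping_trace(val_acc, patience=2))    # (4, 2, 0.6511)
print(early_stopping_trace(val_acc, patience=100))  # (None, 2, 0.6511)
```

With a (hypothetical) patience of 2, training would stop after the fifth epoch and restore the weights of the third; with the tutorial's patience of 100, these few epochs are far from triggering a stop.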

### Evaluate the trained model
We evaluate the classifier on the test set, which the network did not see during training (it was used neither for gradient updates nor for early stopping). We summarize the results as a confusion matrix.

In [22]:
# Predict class probabilities for the test set
predicted_test = estimator.predict(encoding_test)
# The predicted class is the one with the highest probability
pred_cat = predicted_test.argmax(axis=1)

# Confusion matrix: rows = true model, columns = predicted model (order BD, BDEI, BDSS)
print(confusion_matrix(Y_test, pred_cat))

10/10 ━━━━━━━━━━━━━━━━━━━━ 1s 80ms/step
[[93  6  1]
 [15 81  4]
 [ 5  1 94]]
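The summary figures relevant to Exercise 3 can be computed directly from this matrix. The sketch below hard-codes the matrix printed above; in the notebook one could pass `confusion_matrix(Y_test, pred_cat)` instead. Recall is read along rows (true classes), precision along columns (predicted classes), and the column totals show how often each model is predicted.

```python
import numpy as np

# Confusion matrix printed above: rows = true model, columns = predicted model
cm = np.array([[93, 6, 1],
               [15, 81, 4],
               [5, 1, 94]])
classes = ["BD", "BDEI", "BDSS"]

accuracy = np.trace(cm) / cm.sum()        # correct predictions / all test trees
recall = np.diag(cm) / cm.sum(axis=1)     # per true class (row-wise)
precision = np.diag(cm) / cm.sum(axis=0)  # per predicted class (column-wise)
predicted_counts = cm.sum(axis=0)         # how often each class is predicted

print(f"overall accuracy: {accuracy:.3f}")
for name, r, p, n in zip(classes, recall, precision, predicted_counts):
    print(f"{name}: recall={r:.2f} precision={p:.2f} predicted {n} times")
```

For this matrix the overall accuracy is 268/300 (about 0.893), and BD is predicted 113 times although only 100 test trees are BD.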


#### Exercise 3: Confusion matrix

**Question 3 (write the answer at the assessment form):** Examine the confusion matrix produced after evaluating the model on the test set. a) What does the confusion matrix reveal about the model’s performance? b) What indications in the matrix would suggest that the model is biased toward one particular class?

**Answers:**

a) The confusion matrix shows, for each true model (rows), how many of its 100 test trees were assigned to each predicted model (columns), so correct predictions lie on the diagonal and errors off it. Most trees are classified correctly: 93% of BD, 81% of BDEI and 94% of BDSS trees (per-class recall), for an overall accuracy of 268/300 ≈ 89.3%. BDEI is the hardest model to recognize, with 15 of its trees misclassified as BD.

b) Bias toward a class shows up in its column rather than on the diagonal: the model predicts that class much more often than it occurs (a column sum well above the row sum of 100), and the errors of the other classes concentrate in that column. Here the BD column sums to 113, and 15 of the 19 misclassified BDEI trees were predicted as BD, so there is a mild tendency to label BDEI trees as BD. The imbalance is small, so the model does not appear strongly biased toward any single class.

You can now compare the obtained accuracy with other state-of-the-art approaches:

<img src="https://drive.google.com/uc?export=view&id=1mamPD_VCI74Y8LzhnHHNyLzfIFqZA8cO" width="300" height="500">

Note that we use only 1,000 trees per model for training, compared with the 4 million trees per model used to train the FFNN-SS and CNN-CBLV of [Voznica et al. (2022)](https://www.nature.com/articles/s41467-022-31511-0#Sec29).



## 5. Predicting empirical (real) data.
Our trained network can now be used to predict the most likely epidemiological model for real datasets.
We will use the phylogenetic tree from [Rasmussen et al. (2017)](https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1005448) with 200 HIV-1 sequences collected as part of the [Swiss HIV Cohort Study (2010)](https://academic.oup.com/ije/article/39/5/1179/799735).

<img src="https://drive.google.com/uc?export=view&id=1Fzc9naQ8ACbL9i6_ZDWh1o8GhIvJ1ql4" width="500" height="340">

In [23]:
# Load the encoded empirical tree (Zurich HIV-1 dataset)
encoding_Zurich = np.load('/content/PhyloDyn/Encoded_Zurich.npy')

# Predict the model probabilities for the empirical tree
predicted_emp = estimator.predict(encoding_Zurich)

# Print the probabilities (columns in the order BD, BDEI, BDSS)
print("  BD            BDEI          BDSS")
print(predicted_emp)

1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 640ms/step
  BD            BDEI          BDSS
[[1.3270288e-25 1.0000000e+00 1.9609776e-22]]
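To read the selected model programmatically, take the column with the highest probability. The values below are copied from the printed prediction. Keep in mind that a softmax output of 1.0 expresses the network's confidence and should not automatically be read as a calibrated posterior probability of the model.

```python
import numpy as np

classes = ["BD", "BDEI", "BDSS"]
# Softmax output printed above for the Zurich tree (columns: BD, BDEI, BDSS)
predicted_emp = np.array([[1.3270288e-25, 1.0000000e+00, 1.9609776e-22]])

best = predicted_emp.argmax(axis=1)[0]
print(f"selected model: {classes[best]} (probability {predicted_emp[0, best]:.3f})")
# selected model: BDEI (probability 1.000)
```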


#### Exercise 4: Analysis of Empirical Data Predictions

**Question 4 (write answer at the assessment form):** The trained model predicts epidemiological models for the HIV data (Zurich dataset). a) Which model was selected, and how does this compare to the results reported in the paper (BDSS with superspreaders)? b) If your prediction differs from BDSS, what factors might explain the discrepancy?