# PLS-DA analysis tutorial

### Import the required packages

In [None]:
import pandas as pds
import numpy as np
from sklearn.cross_decomposition import PLSRegression
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, f1_score, roc_curve, RocCurveDisplay, r2_score
from sklearn.pipeline import Pipeline
import plotly.express as px

### Load the example dataset using pandas

In [None]:
# Load the example LC-MS dataset from CSV into a pandas DataFrame.
# Later cells read study variables (e.g. 'Gender') by name and treat the
# columns from position 11 onwards as MS features.
lcMSData = pds.read_csv('./Data/Dementia_RPOS_XCMS.csv')

In [None]:
# Each feature name is split on '_': the first field is parsed as the retention
# time (divided by 60; the plots label it in minutes) and the second field,
# minus its last three characters, is parsed as the m/z value.
featuresData = pds.DataFrame({
    'Rt': [float(name.split('_')[0]) / 60 for name in lcMSData.columns[11:]],
    'mz': [float(name.split('_')[1][:-3]) for name in lcMSData.columns[11:]],
})

# Median intensity of every feature across all samples
medianSpectrum = np.median(lcMSData.iloc[:, 11:].values, axis=0)

# log(median + 1) is the intensity value used for the scatterplot
featuresData['Median'] = np.log(medianSpectrum + 1)
#featuresData['Median'] = medianSpectrum 

## PLS-DA model

The first step in a PLS-DA model is to fit a PLS regression model with a dummy vector/matrix encoding class membership as Y.

In [None]:
# MS feature intensities: every column from position 11 onwards
XDataMatrix = lcMSData.iloc[:, 11:].values

# log(x + 1) transform of the intensities
logXDataMatrix = np.log(1 + XDataMatrix)

# Gender as a pandas Categorical; its integer codes are the dummy Y vector
YGender = pds.Categorical(lcMSData['Gender'].values)
YGenderDummy = YGender.codes

In [None]:
# A pandas Categorical associates each category's text name with an integer code.
# The codes follow the order of the .categories index: the first category shown
# below has code 0, the second has code 1, and so on.
YGender.categories

The YGenderDummy vector is now a vector of 0s and 1s, where 0=Female and 1=Male

In [None]:
# Display the integer class codes that are used as the Y vector
YGenderDummy

In [None]:
# Fit an ordinary PLS regression with the dummy class vector as Y.
# Scaling is handled by the StandardScaler step, so PLSRegression's own
# scaling is switched off.
plsModel = Pipeline(
    steps=[
        ('uv_scale', StandardScaler()),
        ('PLS', PLSRegression(n_components=2, scale=False)),
    ]
)

plsModel.fit(logXDataMatrix, YGenderDummy)

### PLS-DA model prediction

The prediction of a PLS regression model is a continuous value. To convert this number into a class prediction, we need an extra classification rule or algorithm. The simplest procedure is to assign the class whose dummy code is closest to the predicted value. For example, the class will be 0 (Female) if the prediction is < 0.5, or 1 (Male) if it is > 0.5.

In [None]:
# Continuous PLS predictions side by side with the true class codes
predictFrame = pds.DataFrame(
    np.c_[plsModel.predict(logXDataMatrix), YGenderDummy],
    columns=['Predicted', 'Gender'],
)

fig = px.scatter(
    predictFrame,
    x="Predicted",
    y="Gender",
    labels={"Predicted": "PLS predicted Gender", "Gender": "Gender"},
    template='plotly_white',
    render_mode='webgl',
)

# Dashed line at the 0.5 decision threshold
fig.add_vline(x=0.5, line_dash="dash")

fig.show()

We will instead convert the PLS outputs into a class prediction using a logistic regression model. The class will be predicted with the logistic regression model using the PLS T-scores (class ~ PLS scores). 

In [None]:
daModel = LogisticRegression()

# PLS scores of the samples the PLS model was fitted on
plsScores = plsModel.transform(X=logXDataMatrix)

# Fit the logistic regression model with the scores
daModel.fit(plsScores, YGenderDummy)

# Class predictions for the same samples used for fitting
# (there is no held-out test set at this point)
plsDaClassification = daModel.predict(plsScores)

# ROC curve 
RocCurveDisplay.from_estimator(daModel, X=plsScores, y=YGenderDummy)

# Score ROC AUC: roc_auc_score expects (y_true, y_score), and the score should be
# continuous, so use the predicted probability of the class coded as 1
"ROC AUC: {0}".format(roc_auc_score(YGenderDummy, daModel.predict_proba(plsScores)[:, 1]))

_The model seems to perform very well..._

### PLS-DA scores plot
Let's now examine the PLS scores plot...

In [None]:
# PLS scores of every sample
T_scores = plsModel.transform(logXDataMatrix)

# Study variables on the left, one "PLS T<k>" score column per component on the right
plsResultsDFrame = pds.concat(
    [
        lcMSData.loc[:, ['Subject ID', 'Sample ID', 'Age', 'Gender', 'Run Order', 'Acquisition batch']],
        pds.DataFrame(T_scores, columns=[f'PLS T{k + 1}' for k in range(T_scores.shape[1])]),
    ],
    axis=1,
)

In [None]:
# Scores plot of the first two PLS components, coloured by the recorded gender
fig = px.scatter(
    plsResultsDFrame,
    x="PLS T1",
    y="PLS T2",
    color="Gender",
    template='plotly_white',
    render_mode='webgl',
)

fig.show()

The PLS parameters are exactly the same as those in the PLS regression - see the PLS tutorial for more information.

In [None]:
# Feature map (retention time vs m/z) coloured by the first-component PLS weights,
# on a diverging colour scale centred at zero
fig = px.scatter(
    featuresData,
    x="Rt",
    y="mz",
    color=plsModel['PLS'].x_weights_[:, 0],
    color_continuous_scale='RdBu',
    color_continuous_midpoint=0,
    labels={"Rt": "Retention time (min)", "mz": "m/z"},
    template='plotly_white',
    render_mode='webgl',
)

fig.show()

### Model validation and overfitting

The model ROC curve and ROC AUC values we obtained were very good (ROC AUC > 0.95)!!
But can we trust the discrimination results we obtained? Is PLS that prone to overfitting and over-optimism?

Let's do a simple test: refit a model with a random Y vector...

In [None]:
# Build a random Y vector by drawing len(Y) values from the original Y
# (np.random.choice samples with replacement by default)
YGenderFake = np.random.choice(YGenderDummy, size=len(YGenderDummy))

# Same pipeline as before, now fitted against the random labels
plsModel = Pipeline(
    steps=[
        ('uv_scale', StandardScaler()),
        ('PLS', PLSRegression(n_components=2, scale=False)),
    ]
)

plsModel.fit(logXDataMatrix, YGenderFake)

In [None]:
# Continuous PLS predictions side by side with the random class codes
predictFrame = pds.DataFrame(
    np.c_[plsModel.predict(logXDataMatrix), YGenderFake],
    columns=['Predicted', 'Gender'],
)

fig = px.scatter(
    predictFrame,
    x="Predicted",
    y="Gender",
    labels={"Predicted": "PLS predicted Gender", "Gender": "Gender"},
    template='plotly_white',
    render_mode='webgl',
)

# Dashed line at the 0.5 decision threshold
fig.add_vline(x=0.5, line_dash="dash")

fig.show()

In [None]:
# Scores from the model fitted on the random Y vector
T_scores = plsModel.transform(logXDataMatrix)

# Text labels for the random class codes
GenderFakeColumn = pds.Series(YGenderFake).map({0:'Female', 1:'Male'})

# Study variables, then the fake gender label, then one "PLS<k>" score column per component
plsResultsDFrame = pds.concat(
    [
        lcMSData.loc[:, ['Subject ID', 'Sample ID', 'Age', 'Gender', 'Run Order', 'Acquisition batch']],
        pds.Series(GenderFakeColumn, name='GenderFake'),
        pds.DataFrame(T_scores, columns=[f'PLS{k + 1}' for k in range(T_scores.shape[1])]),
    ],
    axis=1,
)

In [None]:
# Scores plot of the first two PLS components, coloured by the random gender label
fig = px.scatter(
    plsResultsDFrame,
    x="PLS1",
    y="PLS2",
    color="GenderFake",
    template='plotly_white',
    render_mode='webgl',
)

fig.show()

... and this is why PLS scores plots cannot be trusted to assess model quality. Separation in a PLS scores plot is **NOT** a good measure of model quality.

### Model cross-validation

Instead, we will use cross-validation to obtain reliable model performance estimates. 

The following code uses a stratified (preserving the % prevalence of each class in the test set) K-Fold cross-validation routine to obtain ROC AUC, f1-score, and r-squared values which were calculated on external test set data (data not used to fit the model).

In [None]:
# Define a function to fit and cross-validate a PLS-DA model
def crossValidate_PLSDA(x, y, n_components=2, scale=True, cv=StratifiedKFold(7)):
    """Cross-validate a PLS-DA model (PLS regression followed by logistic regression).

    For every train/test split produced by ``cv``, a PLS regression is fitted on
    the training samples and a logistic regression is fitted on the training PLS
    scores. Both are then evaluated on the held-out samples only.

    Parameters
    ----------
    x : numpy.ndarray
        Data matrix with samples in rows (indexed as ``x[rows, :]``).
    y : numpy.ndarray
        Class label vector (indexed as ``y[rows]``). ROC AUC is computed from the
        predicted probability of the second class in ``classes_``, and F1 uses
        scikit-learn's default positive label, so binary 0/1 labels are expected.
    n_components : int, default 2
        Number of PLS components.
    scale : bool, default True
        If truthy, a ``StandardScaler`` step is placed before the PLS regression.
    cv : cross-validator, default ``StratifiedKFold(7)``
        Object whose ``split(x, y)`` yields ``(train_indices, test_indices)`` pairs.

    Returns
    -------
    pandas.DataFrame
        One row per cross-validation split, with columns ``roc_auc``, ``f1`` and
        ``r2``, all computed on the held-out samples of that split.
    """
    # PLSRegression's own scaling is always off; scaling, when requested, is done
    # by a separate pipeline step that is re-fitted on each training set.
    steps = [('PLS', PLSRegression(n_components=n_components, scale=False))]
    if scale:
        steps.insert(0, ('uv_scale', StandardScaler()))
    plsModel = Pipeline(steps=steps)

    daModel = LogisticRegression()

    cvResults = {'roc_auc': [], 'f1': [], 'r2': []}

    # Iterate through CV rounds
    for trainIdx, testIdx in cv.split(x, y):
        xTrain, yTrain = x[trainIdx, :], y[trainIdx]
        xTest, yTest = x[testIdx, :], y[testIdx]

        # Fit the PLS model on the training set only
        plsModel.fit(xTrain, yTrain)

        # R2 of the continuous PLS prediction against the test-set class codes
        cvResults['r2'].append(r2_score(yTest, plsModel.predict(xTest)))

        # Fit the logistic regression on the training-set scores
        daModel.fit(plsModel.transform(X=xTrain), yTrain)

        # Project the test set and predict class labels and class probabilities
        plsTestScores = plsModel.transform(X=xTest)
        testPredicted = daModel.predict(plsTestScores)
        testProbability = daModel.predict_proba(plsTestScores)[:, 1]

        # scikit-learn metrics take (y_true, y_pred / y_score) in that order;
        # ROC AUC needs a continuous score, not the thresholded class labels.
        cvResults['roc_auc'].append(roc_auc_score(yTest, testProbability))
        cvResults['f1'].append(f1_score(yTest, testPredicted))

    return pds.DataFrame(cvResults)

In [None]:
# 7-fold stratified cross-validation of the 2-component PLS-DA model with scaling
cvResults = crossValidate_PLSDA(logXDataMatrix, YGenderDummy, n_components=2, scale=True, cv=StratifiedKFold(7))

The result of the 7-Fold CV process is 7 instances of the classifier performance metrics chosen (roc_auc, f1, r2). 

In [None]:
# Display the per-fold metrics (one row per cross-validation round)
cvResults

### Selecting the optimal number of components with cross-validation

Cross-validation should also be used to select the optimal number of components. The CV procedure should be applied to models with a varying number of components, to generate a "scree plot" with cross-validated measures.

In [None]:
maxNComponents = 10

# Cross-validate one model per number of components, from 1 up to maxNComponents;
# screePLSDA[i] holds the results for i + 1 components
screePLSDA = [
    crossValidate_PLSDA(
        logXDataMatrix,
        YGenderDummy,
        n_components=nComponents,
        scale=True,
        cv=StratifiedKFold(7),
    )
    for nComponents in range(1, maxNComponents + 1)
]

In [None]:
# Tag each cross-validation table with its number of components
# (screePLSDA[i] holds the results for i + 1 components) and stack them into
# one long data frame. DataFrame.assign returns a new frame, so the tables
# stored in screePLSDA are left untouched.
cvPLSDA_DFrame = pds.concat(
    [cvTable.assign(Ncomp=nComp) for nComp, cvTable in enumerate(screePLSDA, start=1)],
    axis=0,
)

In [None]:
# Box plot of the cross-validated ROC AUC for each number of components.
# The keys of `labels` must match the plotted column names for the axis titles to apply.
fig = px.box(cvPLSDA_DFrame, x='Ncomp', y='roc_auc', # points="all",
             labels={"Ncomp": "Number of components",
                     "roc_auc": "ROC AUC"}, template='plotly_white')

fig.show()

The gains in model performance after 4 components become marginal, and therefore we will select 4 as the optimal number of components.

### Fit model with optimal number of PLS components

In [None]:
# Final PLS-DA model with the selected number of components (4)
plsModel = Pipeline(
    steps=[
        ('uv_scale', StandardScaler()),
        ('PLS', PLSRegression(n_components=4, scale=False)),
    ]
)

daModel = LogisticRegression()

# Fit the PLS step on the full dataset, then the classifier on the resulting scores
plsModel.fit(logXDataMatrix, YGenderDummy)
daModel.fit(plsModel.transform(logXDataMatrix), YGenderDummy)

In [None]:
# 7-fold stratified cross-validation of the 4-component PLS-DA model with scaling
cvResults = crossValidate_PLSDA(logXDataMatrix, YGenderDummy, n_components=4, scale=True, cv=StratifiedKFold(7))

These cross-validated metrics are better estimates of the expected model performance.

In [None]:
# Mean and standard deviation of each metric across the cross-validation folds
pds.DataFrame({'Mean': cvResults.mean(), 'Stdev': cvResults.std()})

### Permutation randomisation test

A final and very important method for model validation is the permutation randomization test. In a permutation randomisation test, the model will be refitted and assessed multiple times, but each time with the Y randomly permuted to destroy any relationship between X & Y. This allows us to assess what sort of model we can get when there really is no relationship between the two data matrices, and calculate the likelihood of obtaining a model with predictive performance as good as the non-permuted model by chance alone.

During this test, the number of components, scaling, type of cross-validation employed, and any other modeling choice is kept constant. In each randomization, the model is refitted, and the performance and model validation metric recorded. This enables the generation of permuted null distributions for these metrics, which can be used to obtain an empirical p-value for their significance.

**Note**: Running the permutation test with a large number of permutation randomizations (for example, 1000) is expected to take a considerable amount of time (> 30 mins on a laptop).

In [None]:
nPermutations = 50

# Collects the mean cross-validated metrics of each permuted model
permResults = []

for permutation in range(nPermutations):
    # Shuffle Y to break any relationship between X and Y
    permutedY = np.random.permutation(YGenderDummy)
    # Same number of components, scaling and cross-validation scheme as the real model
    permcvResults = crossValidate_PLSDA(
        logXDataMatrix, permutedY, n_components=4, scale=True, cv=StratifiedKFold(7)
    )
    permResults.append(permcvResults.mean())

# One row per permutation, one column per metric
permResults = pds.DataFrame(permResults)
    

Histogram of results from permuted (null) models. The vertical line represents the ROC AUC value obtained in the "real" model. **Note**: The numerical precision of the p-value estimate is dependent on the number of permutations used.

In [None]:
# Mean cross-validated ROC AUC of the real (non-permuted) model
observedAuc = cvResults['roc_auc'].mean()

# Null distribution of the mean ROC AUC from the permuted models
fig = px.histogram(permResults, x='roc_auc', nbins=20)

fig.add_vline(x=observedAuc, line_dash="dash")
fig.show()

# Empirical p-value: the observed model counts as one of the nPermutations + 1
# arrangements, so 1 is added to both numerator and denominator. This keeps the
# estimate from ever being exactly zero.
"Permutation p-value ~ {0}".format((np.sum(permResults['roc_auc'] >= observedAuc) + 1)/(nPermutations + 1))