### This Notebook Evaluates Replicate-Out Cross-Validation to Better Measure How the Model Generalises

GroupKFold cross-validation prevents spectra from the same Surface from appearing in both the training and test folds.

This gives a better indication of the model's ability to generalise, as it stops replicate spectra from the same Surface leaking between the training and test folds.
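
As a minimal sketch of the property described above, the following toy example checks that `GroupKFold` never places the same group in both the training and test indices of a fold. The array contents and the surface names `S1` to `S4` are made up for illustration and are not taken from the notebook's data.

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# Hypothetical toy data: 12 spectra measured on 4 surfaces (3 replicates each).
X = np.arange(24, dtype=float).reshape(12, 2)
y = np.array([0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1])
groups = np.repeat(["S1", "S2", "S3", "S4"], 3)

overlaps = []
for train_idx, test_idx in GroupKFold(n_splits=4).split(X, y, groups):
    # Surfaces that appear on both sides of the split (should be none).
    shared = set(groups[train_idx]) & set(groups[test_idx])
    overlaps.append(shared)
    print(sorted(set(groups[test_idx])), "held out; shared surfaces:", len(shared))
```

With a plain `KFold` over the same rows, replicates of one surface could land on both sides of a split, which is the leakage the grouped splitter avoids.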

Import Libraries

In [27]:
import pandas as pd
import seaborn as sns
import numpy as np
from scipy.signal import savgol_filter
from scipy import sparse
from scipy.sparse.linalg import spsolve
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import StratifiedKFold, KFold, GroupKFold
from sklearn.model_selection import LeavePGroupsOut

Read the spectral data

In [28]:
#df = pd.read_csv("../../data/exosomes.raw_spectrum_1.csv")

In [29]:
df = pd.read_csv("../../data/current_clean_spectrum.csv")

In [30]:
df['SpecID'].nunique()

3045

In [31]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6239205 entries, 0 to 6239204
Data columns (total 6 columns):
 #   Column      Dtype  
---  ------      -----  
 0   SpecID      object 
 1   Seq         int64  
 2   WaveNumber  float64
 3   SurID       object 
 4   Status      object 
 5   Absorbance  float64
dtypes: float64(2), int64(1), object(3)
memory usage: 285.6+ MB


#### Train an Extra Trees Classifier on the full spectrum.

In [32]:
def prepare_wavelength_df(df, absorbance_col, status_col='Status'):

    # Step 1: Group by 'SurID' and 'WaveNumber' and calculate the median absorbance
    # across all spectra measured on the same Surface
    median_absorbance = df.groupby(['SurID', 'WaveNumber'])[absorbance_col].median().reset_index()

    # Step 2: Pivot the table to get 'WaveNumber' as columns, 'SurID' as index, and median absorbance as values
    wavelength_df = median_absorbance.pivot(index='SurID', columns='WaveNumber', values=absorbance_col)

    # Step 3: Merge with the statuses based on SurID
    # (one row per Surface, provided each Surface has a single Status)
    statuses_and_surface = df[['SurID', status_col]].drop_duplicates()
    wavelength_df = pd.merge(wavelength_df, statuses_and_surface, on='SurID')

    # Set SurID as the index
    wavelength_df = wavelength_df.set_index('SurID')

    return wavelength_df
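
To make the reshaping concrete, here is a small sketch that applies the same group, pivot and merge steps to a made-up long-format table. The column names follow the notebook, but the `toy` frame and its values are purely illustrative.

```python
import pandas as pd

# Hypothetical long-format data: 2 surfaces, 2 spectra each, 2 wavenumbers.
toy = pd.DataFrame({
    "SpecID": ["a", "a", "b", "b", "c", "c", "d", "d"],
    "SurID": ["S1"] * 4 + ["S2"] * 4,
    "WaveNumber": [400.0, 401.0] * 4,
    "Status": ["Normal"] * 4 + ["Hyperglycemia"] * 4,
    "Absorbance": [1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0],
})

# Median absorbance per Surface and WaveNumber.
median_abs = toy.groupby(["SurID", "WaveNumber"])["Absorbance"].median().reset_index()

# One row per Surface, one column per WaveNumber.
wide = median_abs.pivot(index="SurID", columns="WaveNumber", values="Absorbance")

# Attach the Status label of each Surface.
labels = toy[["SurID", "Status"]].drop_duplicates()
wide = pd.merge(wide, labels, on="SurID").set_index("SurID")
print(wide)
```

Each Surface collapses to a single median spectrum, so the resulting table has one row per Surface rather than one row per spectrum.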

In [33]:
wavelength_df = prepare_wavelength_df(df, 'Absorbance')
wavelength_df

SurID,400.22778,400.91116,401.59454,402.27789,402.96127,403.64465,404.32803,405.01138,405.69476,406.37814,...,1794.3053,1794.9886,1795.672,1796.3553,1797.0387,1797.722,1798.4055,1799.0889,1799.7722,Status
201210-1,-1.261773,-1.487437,-1.493051,-1.505201,-1.523022,-1.545650,-1.572220,-1.601867,-1.633727,-1.666934,...,-1.077548,-1.061438,-1.061623,-1.053362,-1.045736,-1.028861,-1.024108,-1.049046,-1.174049,Normal
201210-2,-0.894710,-0.921115,-1.038404,-1.084028,-1.113839,-1.164607,-1.191006,-1.231913,-1.250772,-1.241953,...,-1.883608,-1.884308,-1.885187,-1.870041,-1.856400,-1.844880,-1.846040,-1.798168,-1.732828,Normal
210114-1,0.386497,0.421640,0.386049,0.395677,0.410873,0.434603,0.437952,0.448156,0.449485,0.440940,...,-1.595796,-1.597691,-1.592841,-1.585158,-1.566326,-1.567910,-1.560230,-1.539894,-1.520302,Normal
210114-2,0.342039,0.340826,0.308961,0.252493,0.310173,0.313019,0.305763,0.298428,0.326233,0.343486,...,-1.556205,-1.554188,-1.542800,-1.528400,-1.507333,-1.491158,-1.489146,-1.434492,-1.429440,Normal
210120-1,0.746662,0.707318,0.662128,0.625805,0.597256,0.555120,0.485850,0.445487,0.423553,0.387052,...,-1.078026,-1.078357,-1.070442,-1.068339,-1.066159,-1.075257,-1.059876,-1.058883,-1.053197,Hyperglycemia
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
210519-3,0.261464,0.229795,0.205900,0.206404,0.216478,0.231820,0.269926,0.308085,0.346080,0.375113,...,-1.496688,-1.489000,-1.480562,-1.491120,-1.500803,-1.500677,-1.510344,-1.516037,-1.513067,Hyperglycemia
210524-1,-1.004154,-1.035637,-1.045777,-1.069023,-1.043572,-1.059675,-1.081360,-1.100711,-1.116490,-1.128900,...,-1.711449,-1.726570,-1.728099,-1.724006,-1.711004,-1.693228,-1.651810,-1.598485,-1.641007,Hypoglycemia
210526-1,-1.590005,-1.661008,-1.709788,-1.716623,-1.712634,-1.733227,-1.743411,-1.735913,-1.737246,-1.734658,...,-1.852672,-1.842897,-1.831211,-1.816104,-1.815327,-1.798290,-1.792396,-1.782048,-1.754277,Hyperglycemia
210526-2,-2.318242,-2.341874,-2.364037,-2.351647,-2.352406,-2.409253,-2.390095,-2.400872,-2.418729,-2.416327,...,-1.884517,-1.873499,-1.866623,-1.867632,-1.863514,-1.890710,-1.925062,-1.900814,-1.926573,Hyperglycemia


>**The number of distinct Surface IDs**

In [34]:
len(wavelength_df.groupby(['SurID']))

63

It looks like each Surface is associated with exactly one Status.
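
One way to check this assumption directly is to count the distinct Status values per Surface. The sketch below does so on a made-up frame in which the hypothetical surface `S2` deliberately carries two labels; on the real data the same `groupby` would be applied to `df`.

```python
import pandas as pd

# Hypothetical labels: surface S2 is (deliberately) given two different statuses.
toy = pd.DataFrame({
    "SurID": ["S1", "S1", "S2", "S2", "S3"],
    "Status": ["Normal", "Normal", "Hypoglycemia", "Hyperglycemia", "Normal"],
})

# Number of distinct Status values recorded for each Surface.
statuses_per_surface = toy.groupby("SurID")["Status"].nunique()

# Surfaces with more than one Status would break the one-label-per-Surface assumption.
ambiguous = statuses_per_surface[statuses_per_surface > 1]
print(ambiguous)
```

An empty `ambiguous` result on the real data would confirm that every Surface maps to a single Status.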

>**The Count of Spectra with each Status**

In [35]:
df.groupby('Status')['SpecID'].nunique().reset_index(name='Spectra Count')

,Status,Spectra Count
0,Hyperglycemia,915
1,Hypoglycemia,1065
2,Normal,1065


>**The Count of Surfaces with each Status**

In [36]:
df.groupby('Status')['SurID'].nunique().reset_index(name='Surface Count')

,Status,Surface Count
0,Hyperglycemia,19
1,Hypoglycemia,22
2,Normal,22


>#### **Train an Extra Trees Classifier on the Raw Spectrum and evaluate it with GroupKFold cross-validation.**

In [37]:
wavelength_df

SurID,400.22778,400.91116,401.59454,402.27789,402.96127,403.64465,404.32803,405.01138,405.69476,406.37814,...,1794.3053,1794.9886,1795.672,1796.3553,1797.0387,1797.722,1798.4055,1799.0889,1799.7722,Status
201210-1,-1.261773,-1.487437,-1.493051,-1.505201,-1.523022,-1.545650,-1.572220,-1.601867,-1.633727,-1.666934,...,-1.077548,-1.061438,-1.061623,-1.053362,-1.045736,-1.028861,-1.024108,-1.049046,-1.174049,Normal
201210-2,-0.894710,-0.921115,-1.038404,-1.084028,-1.113839,-1.164607,-1.191006,-1.231913,-1.250772,-1.241953,...,-1.883608,-1.884308,-1.885187,-1.870041,-1.856400,-1.844880,-1.846040,-1.798168,-1.732828,Normal
210114-1,0.386497,0.421640,0.386049,0.395677,0.410873,0.434603,0.437952,0.448156,0.449485,0.440940,...,-1.595796,-1.597691,-1.592841,-1.585158,-1.566326,-1.567910,-1.560230,-1.539894,-1.520302,Normal
210114-2,0.342039,0.340826,0.308961,0.252493,0.310173,0.313019,0.305763,0.298428,0.326233,0.343486,...,-1.556205,-1.554188,-1.542800,-1.528400,-1.507333,-1.491158,-1.489146,-1.434492,-1.429440,Normal
210120-1,0.746662,0.707318,0.662128,0.625805,0.597256,0.555120,0.485850,0.445487,0.423553,0.387052,...,-1.078026,-1.078357,-1.070442,-1.068339,-1.066159,-1.075257,-1.059876,-1.058883,-1.053197,Hyperglycemia
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
210519-3,0.261464,0.229795,0.205900,0.206404,0.216478,0.231820,0.269926,0.308085,0.346080,0.375113,...,-1.496688,-1.489000,-1.480562,-1.491120,-1.500803,-1.500677,-1.510344,-1.516037,-1.513067,Hyperglycemia
210524-1,-1.004154,-1.035637,-1.045777,-1.069023,-1.043572,-1.059675,-1.081360,-1.100711,-1.116490,-1.128900,...,-1.711449,-1.726570,-1.728099,-1.724006,-1.711004,-1.693228,-1.651810,-1.598485,-1.641007,Hypoglycemia
210526-1,-1.590005,-1.661008,-1.709788,-1.716623,-1.712634,-1.733227,-1.743411,-1.735913,-1.737246,-1.734658,...,-1.852672,-1.842897,-1.831211,-1.816104,-1.815327,-1.798290,-1.792396,-1.782048,-1.754277,Hyperglycemia
210526-2,-2.318242,-2.341874,-2.364037,-2.351647,-2.352406,-2.409253,-2.390095,-2.400872,-2.418729,-2.416327,...,-1.884517,-1.873499,-1.866623,-1.867632,-1.863514,-1.890710,-1.925062,-1.900814,-1.926573,Hyperglycemia


Leave certain SurIDs out to evaluate the model's ability to generalise.

In [38]:
def evaluate_extra_trees(df):

    # Separate the features (absorbance per WaveNumber) from the Status labels
    X = df.drop(['Status'], axis=1)
    y = df['Status']
    
    # Creating the Extra Trees classifier
    et = ExtraTreesClassifier(random_state=1234)
    
    # Each row of df is expected to be one Surface (its median spectrum), so a
    # shuffled KFold over rows keeps every Surface in exactly one fold
    #cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=1234)
    cv = KFold(n_splits=10, shuffle=True, random_state=1234)

    # Getting cross-validation scores
    scores = cross_val_score(et, X, y, cv=cv, scoring='accuracy', n_jobs=-1)
    
    # Displaying the results
    print(f'{et.__class__.__name__} Cross-Validation Accuracy: {np.mean(scores):.4f} +/- {np.std(scores):.4f}')

In [39]:
evaluate_extra_trees(wavelength_df)

ExtraTreesClassifier Cross-Validation Accuracy: 0.5643 +/- 0.1934


Result on unscaled data:
ExtraTreesClassifier Cross-Validation Accuracy: 0.5857 +/- 0.1102

Result on scaled data:
ExtraTreesClassifier Cross-Validation Accuracy: 0.5643 +/- 0.1934
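
The evaluation above works on one median row per Surface. If the classifier were instead trained on individual spectra, the Surface IDs would need to be passed as groups. The sketch below shows one way to do that with `GroupKFold` and the `groups` argument of `cross_val_score`. The synthetic data, the number of splits and `n_estimators=50` are assumptions made for illustration, not the notebook's settings.

```python
import numpy as np
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.model_selection import GroupKFold, cross_val_score

# Synthetic stand-in for spectrum-level data: 12 surfaces, 5 replicate spectra each.
rng = np.random.default_rng(1234)
n_surfaces, reps, n_features = 12, 5, 20
groups = np.repeat(np.arange(n_surfaces), reps)   # Surface ID of each spectrum
y = np.repeat(np.arange(n_surfaces) % 3, reps)    # one status per Surface
X = rng.normal(size=(n_surfaces * reps, n_features))

et = ExtraTreesClassifier(n_estimators=50, random_state=1234)

# GroupKFold keeps all spectra of a Surface inside a single fold.
cv = GroupKFold(n_splits=4)
scores = cross_val_score(et, X, y, groups=groups, cv=cv, scoring="accuracy")

print(f"GroupKFold accuracy: {np.mean(scores):.4f} +/- {np.std(scores):.4f}")
```

Because the features here are random noise, the printed accuracy is not meaningful; the point is only how the groups are wired into the splitter.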