### This Notebook Evaluates Replicate-Out Cross-Validation to Better Measure How the Model Generalises

Group KFold Cross-Validation prevents spectra from the same Surface appearing within both the training and test folds.

This gives a better indication of the model's ability to generalise, because it stops information leaking between the training and test folds through spectra measured on the same Surface.
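The grouping guarantee can be illustrated with a minimal sketch on synthetic data. The arrays and the surface labels `S1` to `S4` are illustrative assumptions, not the notebook's spectra; only `GroupKFold` itself is the scikit-learn splitter named above.

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# Synthetic stand-in: 12 spectra measured on 4 surfaces (3 spectra each).
rng = np.random.default_rng(0)
X = rng.normal(size=(12, 5))
y = np.repeat([0, 1, 0, 1], 3)
groups = np.repeat(["S1", "S2", "S3", "S4"], 3)

# For every split, record which surfaces appear on both sides.
overlaps = []
for train_idx, test_idx in GroupKFold(n_splits=4).split(X, y, groups):
    shared = set(groups[train_idx]) & set(groups[test_idx])
    overlaps.append(shared)
    print(sorted(set(groups[test_idx])), "held out; shared surfaces:", shared)
```

Every split should report an empty set of shared surfaces, which is the property that prevents leakage.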

Import Libraries

In [1]:
import pandas as pd
import seaborn as sns
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import StratifiedKFold, KFold, GroupKFold
from sklearn.model_selection import LeavePGroupsOut

Read the spectral data

In [2]:
#df = pd.read_csv("../../data/exosomes.raw_spectrum_1.csv")

In [3]:
df = pd.read_csv("../../optuna_cleaning_spectra.csv")

In [4]:
df['SpecID'].nunique()

3045

In [5]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6239205 entries, 0 to 6239204
Data columns (total 6 columns):
 #   Column      Dtype  
---  ------      -----  
 0   SpecID      object 
 1   Seq         int64  
 2   WaveNumber  float64
 3   Absorbance  float64
 4   SurID       object 
 5   Status      object 
dtypes: float64(2), int64(1), object(3)
memory usage: 285.6+ MB


#### Train an Extra Trees Classifier on the full spectrum.

In [6]:
def prepare_wavelength_df(df, absorbance_col, status_col='Status'):

    # Step 1: Group by 'SurID' and 'WaveNumber' and calculate median absorbance
    median_absorbance = df.groupby(['SurID', 'WaveNumber'])[absorbance_col].median().reset_index()

    # Step 2: Pivot the table to get 'WaveNumber' as columns, 'SurID' as index, and median absorbance as values
    wavelength_df = median_absorbance.pivot(index='SurID', columns='WaveNumber', values=absorbance_col)

    # Merge with the statuses based on SurID
    # Each row is one Surface, so SurID also identifies the group for cross-validation
    statuses_and_surface = df[['SurID', status_col]].drop_duplicates()
    wavelength_df = pd.merge(wavelength_df, statuses_and_surface, on='SurID')

    # Set SurID as the index
    wavelength_df = wavelength_df.set_index('SurID')

    return wavelength_df
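
The same long-to-wide reshaping can be traced on a toy table. This is a sketch under assumptions: the `toy` frame and its values are invented for illustration, and `reset_index()` is called before the merge so that `SurID` is an ordinary column on both sides.

```python
import pandas as pd

# Toy long-format table: 2 surfaces, 2 spectra each, 2 wavenumbers.
toy = pd.DataFrame({
    "SpecID": ["a", "a", "b", "b", "c", "c", "d", "d"],
    "SurID": ["S1"] * 4 + ["S2"] * 4,
    "WaveNumber": [400.0, 401.0] * 4,
    "Absorbance": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
    "Status": ["Normal"] * 4 + ["Hyperglycemia"] * 4,
})

# Median absorbance per surface and wavenumber, then one row per surface.
median = toy.groupby(["SurID", "WaveNumber"])["Absorbance"].median().reset_index()
wide = median.pivot(index="SurID", columns="WaveNumber", values="Absorbance")

# Attach the single Status of each surface and restore SurID as the index.
labels = toy[["SurID", "Status"]].drop_duplicates()
wide = wide.reset_index().merge(labels, on="SurID").set_index("SurID")
print(wide)
```

The result has one row per surface, one column per wavenumber, and a trailing `Status` column, which is the layout `prepare_wavelength_df` returns.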

In [7]:
wavelength_df = prepare_wavelength_df(df, 'Absorbance')
wavelength_df

SurID,400.22778,400.91116,401.59454,402.27789,402.96127,403.64465,404.32803,405.01138,405.69476,406.37814,...,1794.3053,1794.9886,1795.672,1796.3553,1797.0387,1797.722,1798.4055,1799.0889,1799.7722,Status
201210-1,-0.006066,-0.009978,-0.013899,-0.008371,-0.012366,-0.010708,-0.013499,-0.011808,-0.007892,-0.008800,...,-0.005555,-0.005570,-0.006277,-0.004788,-0.005204,-0.004975,-0.004912,-0.006252,-0.006024,Normal
201210-2,-0.006985,-0.007944,-0.007097,-0.007175,-0.007719,-0.008383,-0.008279,-0.009176,-0.008554,-0.008365,...,-0.013399,-0.012174,-0.012807,-0.013541,-0.013215,-0.012425,-0.012813,-0.013469,-0.013100,Normal
210114-1,0.021078,0.019592,0.019564,0.018430,0.017379,0.017885,0.018595,0.019918,0.018479,0.019082,...,-0.012637,-0.012500,-0.013439,-0.012607,-0.011671,-0.011619,-0.010893,-0.012163,-0.011828,Normal
210114-2,0.018762,0.019651,0.019739,0.020253,0.018586,0.017376,0.018982,0.018869,0.019467,0.019746,...,-0.010143,-0.010286,-0.010107,-0.010266,-0.009509,-0.008970,-0.011015,-0.010777,-0.011668,Normal
210120-1,0.021287,0.022054,0.022066,0.021095,0.020544,0.019706,0.018084,0.018457,0.018406,0.016659,...,-0.008766,-0.008179,-0.008886,-0.009008,-0.008671,-0.008036,-0.008497,-0.009162,-0.008883,Hyperglycemia
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
210519-3,0.011480,0.014745,0.013208,0.011891,0.012752,0.012015,0.014077,0.013743,0.014106,0.013077,...,-0.010961,-0.010564,-0.011066,-0.011739,-0.011967,-0.011466,-0.011792,-0.012558,-0.011886,Hyperglycemia
210524-1,-0.007456,-0.008781,-0.009012,-0.008729,-0.009619,-0.008884,-0.008973,-0.006905,-0.007498,-0.009671,...,-0.014520,-0.015094,-0.015846,-0.015719,-0.015626,-0.015337,-0.014385,-0.014483,-0.015197,Hypoglycemia
210526-1,-0.010439,-0.011936,-0.012041,-0.012316,-0.012245,-0.012185,-0.012437,-0.012184,-0.012844,-0.012551,...,-0.012116,-0.012192,-0.013408,-0.012921,-0.012134,-0.012407,-0.013021,-0.013686,-0.013345,Hyperglycemia
210526-2,-0.015722,-0.016298,-0.015422,-0.016109,-0.016126,-0.014102,-0.017215,-0.015002,-0.015937,-0.015372,...,-0.010926,-0.013288,-0.012359,-0.012021,-0.011010,-0.011754,-0.012337,-0.010246,-0.012557,Hyperglycemia


>**The Count of Surface IDs and the number of associated samples**

In [8]:
len(wavelength_df.groupby(['SurID']))

63

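Each Surface appears to be associated with exactly one Status: there are 63 Surfaces, and the per-Status Surface counts further down (19 + 22 + 22) also sum to 63, so no Surface is counted under two Statuses.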
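
A direct way to check this is to count the distinct Status values per Surface. The sketch below uses an invented `toy` frame; on the notebook's data the equivalent expression would be `df.groupby('SurID')['Status'].nunique()`.

```python
import pandas as pd

# Toy stand-in for the long-format table (values are illustrative only).
toy = pd.DataFrame({
    "SurID": ["S1", "S1", "S2", "S2", "S3"],
    "Status": ["Normal", "Normal", "Hypoglycemia", "Hypoglycemia", "Hyperglycemia"],
})

# Number of distinct Status labels seen on each surface.
statuses_per_surface = toy.groupby("SurID")["Status"].nunique()

# True only if every surface carries exactly one Status.
one_status_each = bool((statuses_per_surface == 1).all())
print(one_status_each)
```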

>**The Count of Spectra with each Status**

In [9]:
df.groupby('Status')['SpecID'].nunique().reset_index(name='Spectra Count')

,Status,Spectra Count
0,Hyperglycemia,915
1,Hypoglycemia,1065
2,Normal,1065


>**The Count of Surfaces with each Status**

In [10]:
df.groupby('Status')['SurID'].nunique().reset_index(name='Surface Count')

,Status,Surface Count
0,Hyperglycemia,19
1,Hypoglycemia,22
2,Normal,22


>#### **Train an Extra Trees Classifier on the Raw Spectrum and evaluate it with K-Fold cross-validation over Surfaces.**

In [11]:
wavelength_df

SurID,400.22778,400.91116,401.59454,402.27789,402.96127,403.64465,404.32803,405.01138,405.69476,406.37814,...,1794.3053,1794.9886,1795.672,1796.3553,1797.0387,1797.722,1798.4055,1799.0889,1799.7722,Status
201210-1,-0.006066,-0.009978,-0.013899,-0.008371,-0.012366,-0.010708,-0.013499,-0.011808,-0.007892,-0.008800,...,-0.005555,-0.005570,-0.006277,-0.004788,-0.005204,-0.004975,-0.004912,-0.006252,-0.006024,Normal
201210-2,-0.006985,-0.007944,-0.007097,-0.007175,-0.007719,-0.008383,-0.008279,-0.009176,-0.008554,-0.008365,...,-0.013399,-0.012174,-0.012807,-0.013541,-0.013215,-0.012425,-0.012813,-0.013469,-0.013100,Normal
210114-1,0.021078,0.019592,0.019564,0.018430,0.017379,0.017885,0.018595,0.019918,0.018479,0.019082,...,-0.012637,-0.012500,-0.013439,-0.012607,-0.011671,-0.011619,-0.010893,-0.012163,-0.011828,Normal
210114-2,0.018762,0.019651,0.019739,0.020253,0.018586,0.017376,0.018982,0.018869,0.019467,0.019746,...,-0.010143,-0.010286,-0.010107,-0.010266,-0.009509,-0.008970,-0.011015,-0.010777,-0.011668,Normal
210120-1,0.021287,0.022054,0.022066,0.021095,0.020544,0.019706,0.018084,0.018457,0.018406,0.016659,...,-0.008766,-0.008179,-0.008886,-0.009008,-0.008671,-0.008036,-0.008497,-0.009162,-0.008883,Hyperglycemia
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
210519-3,0.011480,0.014745,0.013208,0.011891,0.012752,0.012015,0.014077,0.013743,0.014106,0.013077,...,-0.010961,-0.010564,-0.011066,-0.011739,-0.011967,-0.011466,-0.011792,-0.012558,-0.011886,Hyperglycemia
210524-1,-0.007456,-0.008781,-0.009012,-0.008729,-0.009619,-0.008884,-0.008973,-0.006905,-0.007498,-0.009671,...,-0.014520,-0.015094,-0.015846,-0.015719,-0.015626,-0.015337,-0.014385,-0.014483,-0.015197,Hypoglycemia
210526-1,-0.010439,-0.011936,-0.012041,-0.012316,-0.012245,-0.012185,-0.012437,-0.012184,-0.012844,-0.012551,...,-0.012116,-0.012192,-0.013408,-0.012921,-0.012134,-0.012407,-0.013021,-0.013686,-0.013345,Hyperglycemia
210526-2,-0.015722,-0.016298,-0.015422,-0.016109,-0.016126,-0.014102,-0.017215,-0.015002,-0.015937,-0.015372,...,-0.010926,-0.013288,-0.012359,-0.012021,-0.011010,-0.011754,-0.012337,-0.010246,-0.012557,Hyperglycemia


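Each fold leaves some SurIDs out of training to evaluate the model's ability to generalise to unseen Surfaces. Because `wavelength_df` holds one median spectrum per SurID, a Surface can never appear in both the training and test rows of a fold.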

In [12]:
def evaluate_extra_trees(df):

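    # Features are the per-Surface median absorbances; the target is the Status
    X = df.drop(['Status'], axis=1)
    y = df['Status']
    
    # Creating the Extra Trees classifier
    et = ExtraTreesClassifier(random_state=1234)
    
    # Shuffled KFold: each row is already one Surface, so no Surface is split across folds.
    # The commented-out StratifiedKFold would also balance the Status classes in each fold.
    #cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=1234)
    cv = KFold(n_splits=10, shuffle=True, random_state=1234)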

    # Getting cross-validation scores
    scores = cross_val_score(et, X, y, cv=cv, scoring='accuracy', n_jobs=-1)
    
    # Displaying the results
    print(f'{et.__class__.__name__} Cross-Validation Accuracy: {np.mean(scores):.4f} +/- {np.std(scores):.4f}')

In [13]:
evaluate_extra_trees(wavelength_df)

ExtraTreesClassifier Cross-Validation Accuracy: 0.4476 +/- 0.1659


Results on unscaled data:
ExtraTreesClassifier Cross-Validation Accuracy: 0.5857 +/- 0.1102

Results on scaled data:
ExtraTreesClassifier Cross-Validation Accuracy: 0.5643 +/- 0.1934
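
The evaluation above works on one median spectrum per Surface, so plain `KFold` already keeps Surfaces apart. If the classifier were instead trained on individual spectra, the Surface labels would need to be passed as `groups`. The following is a minimal sketch of that variant on synthetic data; the array shapes, `n_estimators=50`, and `n_splits=4` are assumptions chosen for illustration, not settings from this notebook.

```python
import numpy as np
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.model_selection import GroupKFold, cross_val_score

# Synthetic per-spectrum data: 12 surfaces, 5 spectra each, 20 features.
rng = np.random.default_rng(1234)
n_surfaces, spectra_per_surface, n_features = 12, 5, 20
groups = np.repeat(np.arange(n_surfaces), spectra_per_surface)
y = np.repeat(np.arange(n_surfaces) % 3, spectra_per_surface)  # one class per surface
X = rng.normal(size=(groups.size, n_features))

et = ExtraTreesClassifier(n_estimators=50, random_state=1234)

# GroupKFold keeps all spectra of a surface inside a single fold.
cv = GroupKFold(n_splits=4)
scores = cross_val_score(et, X, y, groups=groups, cv=cv, scoring="accuracy")
print(f"{et.__class__.__name__} Group CV Accuracy: {np.mean(scores):.4f} +/- {np.std(scores):.4f}")
```

With random features the accuracy carries no meaning; the point is only how `groups` is passed through `cross_val_score` to the splitter.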