# Working Models

We can attempt, purely for fun, to use the newly developed models to predict the diagnoses for certain baseline visits where the diagnosis is actually missing.

## Patients with a null diagnosis at baseline visit

In [1]:
import os
import joblib
import numpy as np
import pandas as pd 

from CogniPredictAD.preprocessing import ADNICleaner, ADNITransformator, knn_impute_group
from sklearn.impute import KNNImputer

pd.set_option("display.max_rows", 116)
pd.set_option("display.max_columns", 116)
pd.set_option("display.max_info_columns", 116) 

# Open the evaluation_dataset with pandas
evaluation_dataset = pd.read_csv("../data/ADNIMERGE.csv")
# Filter baseline ("bl") visits with missing diagnoses
evaluation_dataset = evaluation_dataset[(evaluation_dataset['VISCODE'] == 'bl') & (evaluation_dataset['DX_bl'].isna())]

evaluation_dataset.reset_index(drop=True, inplace=True)

cleaner = ADNICleaner(evaluation_dataset)
transformer = ADNITransformator(evaluation_dataset)

display(evaluation_dataset)

  evaluation_dataset = pd.read_csv("../data/ADNIMERGE.csv")


Unnamed: 0,RID,COLPROT,ORIGPROT,PTID,SITE,VISCODE,EXAMDATE,DX_bl,AGE,PTGENDER,PTEDUCAT,PTETHCAT,PTRACCAT,PTMARRY,APOE4,FDG,PIB,AV45,FBB,ABETA,TAU,PTAU,CDRSB,ADAS11,ADAS13,ADASQ4,MMSE,RAVLT_immediate,RAVLT_learning,RAVLT_forgetting,RAVLT_perc_forgetting,LDELTOTAL,DIGITSCOR,TRABSCOR,FAQ,MOCA,EcogPtMem,EcogPtLang,EcogPtVisspat,EcogPtPlan,EcogPtOrgan,EcogPtDivatt,EcogPtTotal,EcogSPMem,EcogSPLang,EcogSPVisspat,EcogSPPlan,EcogSPOrgan,EcogSPDivatt,EcogSPTotal,FLDSTRENG,FSVERSION,IMAGEUID,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp,ICV,DX,mPACCdigit,mPACCtrailsB,EXAMDATE_bl,CDRSB_bl,ADAS11_bl,ADAS13_bl,ADASQ4_bl,MMSE_bl,RAVLT_immediate_bl,RAVLT_learning_bl,RAVLT_forgetting_bl,RAVLT_perc_forgetting_bl,LDELTOTAL_BL,DIGITSCOR_bl,TRABSCOR_bl,FAQ_bl,mPACCdigit_bl,mPACCtrailsB_bl,FLDSTRENG_bl,FSVERSION_bl,IMAGEUID_bl,Ventricles_bl,Hippocampus_bl,WholeBrain_bl,Entorhinal_bl,Fusiform_bl,MidTemp_bl,ICV_bl,MOCA_bl,EcogPtMem_bl,EcogPtLang_bl,EcogPtVisspat_bl,EcogPtPlan_bl,EcogPtOrgan_bl,EcogPtDivatt_bl,EcogPtTotal_bl,EcogSPMem_bl,EcogSPLang_bl,EcogSPVisspat_bl,EcogSPPlan_bl,EcogSPOrgan_bl,EcogSPDivatt_bl,EcogSPTotal_bl,ABETA_bl,TAU_bl,PTAU_bl,FDG_bl,PIB_bl,AV45_bl,FBB_bl,Years_bl,Month_bl,Month,M,update_stamp
0,6712,ADNI3,ADNI3,019_S_6712,19,bl,2019-04-24,,90.6,Male,17,Not Hisp/Latino,White,Married,,0.735309,,,1.5838,,,,7.0,29.33,41.33,8.0,20.0,8.0,0.0,2.0,100.0,0.0,,300.0,16.0,17.0,,,,,,,,2.375,1.66667,1.5,2.0,2.5,2.5,2.05263,,,,,,,,,,,,-19.7479,-18.1154,2019-04-24,7.0,29.33,41.33,8.0,20.0,8.0,0.0,2.0,100.0,0.0,,300.0,16.0,-19.7479,-18.1154,,,,,,,,,,,17.0,,,,,,,,2.375,1.66667,1.5,2.0,2.5,2.5,2.05263,,,,0.735309,,,1.5838,0.0,0.0,0,0,2023-07-07 05:00:04.0
1,6880,ADNI3,ADNI3,137_S_6880,137,bl,2020-12-09,,66.3,Male,16,Not Hisp/Latino,White,Married,0.0,,,,1.0646,,,,0.5,3.67,7.67,4.0,28.0,46.0,6.0,4.0,33.3333,8.0,,61.0,0.0,28.0,2.0,1.22222,1.0,1.4,1.16667,1.75,1.42105,1.0,1.0,1.0,1.0,1.0,1.0,1.0,,Cross-Sectional FreeSurfer (6.0),1360078.0,17827.4,7590.5,1000750.0,4368.0,16228.0,21865.0,1407420.0,,-4.26884,-2.71821,2020-12-09,0.5,3.67,7.67,4.0,28.0,46.0,6.0,4.0,33.3333,8.0,,61.0,0.0,-4.26884,-2.71821,,Cross-Sectional FreeSurfer (6.0),1360078.0,17827.4,7590.5,1000750.0,4368.0,16228.0,21865.0,1407420.0,28.0,2.0,1.22222,1.0,1.4,1.16667,1.75,1.42105,1.0,1.0,1.0,1.0,1.0,1.0,1.0,,,,,,,1.0646,0.0,0.0,0,0,2023-07-09 05:25:37.0
2,6883,ADNI3,ADNI3,137_S_6883,137,bl,2020-12-21,,64.5,Female,18,Not Hisp/Latino,White,Divorced,0.0,,,,,,,,0.5,3.33,5.33,2.0,29.0,53.0,2.0,3.0,25.0,11.0,,68.0,0.0,,1.375,2.11111,1.33333,1.6,1.4,1.25,1.56757,1.25,1.11111,1.0,1.0,1.16667,1.0,1.10526,,,,,,,,,,,,-0.42049,-0.088403,2020-12-21,0.5,3.33,5.33,2.0,29.0,53.0,2.0,3.0,25.0,11.0,,68.0,0.0,-0.42049,-0.088403,,,,,,,,,,,,1.375,2.11111,1.33333,1.6,1.4,1.25,1.56757,1.25,1.11111,1.0,1.0,1.16667,1.0,1.10526,,,,,,,,0.0,0.0,0,0,2023-07-07 05:00:05.0
3,6912,ADNI3,ADNI3,067_S_6912,67,bl,2021-07-26,,60.3,Female,16,Not Hisp/Latino,More than one,Divorced,,,,,,,,,2.0,8.33,15.33,7.0,28.0,35.0,3.0,7.0,87.5,8.0,,119.0,,24.0,3.25,2.22222,1.16667,2.0,3.33333,3.5,2.55263,,,,,,,,,Cross-Sectional FreeSurfer (6.0),1433667.0,29464.5,7214.5,1153330.0,3661.0,20394.0,21953.0,1645580.0,,-6.49645,-5.97208,2021-07-26,2.0,8.33,15.33,7.0,28.0,35.0,3.0,7.0,87.5,8.0,,119.0,,-6.49645,-5.97208,,Cross-Sectional FreeSurfer (6.0),1433667.0,29464.5,7214.5,1153330.0,3661.0,20394.0,21953.0,1645580.0,24.0,3.25,2.22222,1.16667,2.0,3.33333,3.5,2.55263,,,,,,,,,,,,,,,0.0,0.0,0,0,2023-07-07 05:00:05.0
4,6906,ADNI3,ADNI3,137_S_6906,137,bl,2021-07-21,,55.5,Female,16,Not Hisp/Latino,White,Never married,,,,,,,,,1.0,5.33,6.33,1.0,27.0,62.0,8.0,1.0,6.66667,15.0,,71.0,5.0,26.0,2.625,2.33333,1.5,1.8,1.83333,3.25,2.21053,1.875,1.33333,1.66667,1.4,1.5,2.0,1.60526,,Cross-Sectional FreeSurfer (6.0),1454373.0,18665.5,7222.5,954688.0,4131.0,15253.0,19195.0,1283140.0,,-0.40621,-0.179726,2021-07-21,1.0,5.33,6.33,1.0,27.0,62.0,8.0,1.0,6.66667,15.0,,71.0,5.0,-0.40621,-0.179726,,Cross-Sectional FreeSurfer (6.0),1454373.0,18665.5,7222.5,954688.0,4131.0,15253.0,19195.0,1283140.0,26.0,2.625,2.33333,1.5,1.8,1.83333,3.25,2.21053,1.875,1.33333,1.66667,1.4,1.5,2.0,1.60526,,,,,,,,0.0,0.0,0,0,2023-07-07 05:00:05.0
5,6390,ADNI3,ADNI3,130_S_6390,130,bl,2018-08-06,,68.9,Female,15,Not Hisp/Latino,White,Married,,,,,,,,,0.0,1.67,5.67,4.0,30.0,46.0,4.0,8.0,72.7273,12.0,,60.0,0.0,24.0,1.375,1.33333,1.33333,1.4,1.0,2.5,1.42105,1.75,1.22222,1.33333,1.0,1.0,1.75,1.34211,,Cross-Sectional FreeSurfer (6.0),1010018.0,41685.6,6796.2,996598.0,3984.0,15119.0,20880.0,1455600.0,,-0.341944,0.265948,2018-08-06,0.0,1.67,5.67,4.0,30.0,46.0,4.0,8.0,72.7273,12.0,,60.0,0.0,-0.341944,0.265948,,Cross-Sectional FreeSurfer (6.0),1010018.0,41685.6,6796.2,996598.0,3984.0,15119.0,20880.0,1455600.0,24.0,1.375,1.33333,1.33333,1.4,1.0,2.5,1.42105,1.75,1.22222,1.33333,1.0,1.0,1.75,1.34211,,,,,,,,0.0,0.0,0,0,2023-07-07 05:00:03.0
6,6627,ADNI3,ADNI3,037_S_6627,37,bl,2019-02-27,,67.7,Female,14,Not Hisp/Latino,White,Divorced,2.0,0.948225,,,,,,,2.5,,,,25.0,,,,,1.0,,,15.0,,,,,,,,,3.375,3.11111,3.0,4.0,3.66667,4.0,3.44444,,Cross-Sectional FreeSurfer (6.0),1080124.0,79008.3,6951.7,1137970.0,4197.0,22067.0,19711.0,1706490.0,,-14.4432,-14.4432,2019-02-27,2.5,,,,25.0,,,,,1.0,,,15.0,-14.4432,-14.4432,,Cross-Sectional FreeSurfer (6.0),1080124.0,79008.3,6951.7,1137970.0,4197.0,22067.0,19711.0,1706490.0,,,,,,,,,3.375,3.11111,3.0,4.0,3.66667,4.0,3.44444,,,,0.948225,,,,0.0,0.0,0,0,2023-07-07 05:00:04.0
7,6701,ADNI3,ADNI3,032_S_6701,32,bl,2019-08-22,,71.1,Female,14,Not Hisp/Latino,Black,Married,,,,1.4936,,,,,0.0,,,,27.0,,,,,8.0,,,,,,,,,,,,,,,,,,,,Cross-Sectional FreeSurfer (6.0),1154769.0,29883.2,6816.7,999323.0,3874.0,16455.0,19162.0,1416240.0,,-6.75361,-6.75361,2019-08-22,0.0,,,,27.0,,,,,8.0,,,,-6.75361,-6.75361,,Cross-Sectional FreeSurfer (6.0),1154769.0,29883.2,6816.7,999323.0,3874.0,16455.0,19162.0,1416240.0,,,,,,,,,,,,,,,,,,,,,1.4936,,0.0,0.0,0,0,2023-07-07 05:00:04.0
8,6794,ADNI3,ADNI3,137_S_6794,137,bl,2020-02-28,,67.0,Female,16,Not Hisp/Latino,Black,Married,,,,,,,,,0.0,4.33,6.33,2.0,27.0,40.0,9.0,8.0,61.5385,10.0,,75.0,,22.0,1.625,1.44444,1.0,1.6,1.66667,2.0,1.52632,,,,,,,,,Cross-Sectional FreeSurfer (6.0),1222159.0,60062.5,7596.6,1064780.0,4805.0,20048.0,20206.0,1573900.0,,-3.1479,-2.36562,2020-02-28,0.0,4.33,6.33,2.0,27.0,40.0,9.0,8.0,61.5385,10.0,,75.0,,-3.1479,-2.36562,,Cross-Sectional FreeSurfer (6.0),1222159.0,60062.5,7596.6,1064780.0,4805.0,20048.0,20206.0,1573900.0,22.0,1.625,1.44444,1.0,1.6,1.66667,2.0,1.52632,,,,,,,,,,,,,,,0.0,0.0,0,0,2023-07-07 05:00:05.0
9,6919,ADNI3,ADNI3,137_S_6919,137,bl,2021-07-23,,78.3,Male,18,Not Hisp/Latino,White,Married,,,,,,,,,0.5,6.67,11.67,5.0,29.0,33.0,3.0,3.0,37.5,3.0,,116.0,0.0,24.0,2.0,2.0,1.16667,1.2,1.33333,1.75,1.63158,1.25,1.0,1.0,1.0,1.16667,1.5,1.13158,,,,,,,,,,,,-5.84674,-5.4241,2021-07-23,0.5,6.67,11.67,5.0,29.0,33.0,3.0,3.0,37.5,3.0,,116.0,0.0,-5.84674,-5.4241,,,,,,,,,,,24.0,2.0,2.0,1.16667,1.2,1.33333,1.75,1.63158,1.25,1.0,1.0,1.0,1.16667,1.5,1.13158,,,,,,,,0.0,0.0,0,0,2023-07-07 05:00:05.0


In [2]:
evaluation_dataset.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 11 entries, 0 to 10
Data columns (total 116 columns):
 #    Column                    Non-Null Count  Dtype  
---   ------                    --------------  -----  
 0    RID                       11 non-null     int64  
 1    COLPROT                   11 non-null     object 
 2    ORIGPROT                  11 non-null     object 
 3    PTID                      11 non-null     object 
 4    SITE                      11 non-null     int64  
 5    VISCODE                   11 non-null     object 
 6    EXAMDATE                  11 non-null     object 
 7    DX_bl                     0 non-null      object 
 8    AGE                       11 non-null     float64
 9    PTGENDER                  11 non-null     object 
 10   PTEDUCAT                  11 non-null     int64  
 11   PTETHCAT                  11 non-null     object 
 12   PTRACCAT                  11 non-null     object 
 13   PTMARRY                   11 non-null     object 


In [3]:
train = pd.read_csv("../data/train.csv")
print(train.columns)

Index(['DX', 'AGE', 'PTGENDER', 'PTEDUCAT', 'APOE4', 'MMSE', 'CDRSB', 'ADAS13',
       'LDELTOTAL', 'FAQ', 'MOCA', 'TRABSCOR', 'RAVLT_immediate',
       'RAVLT_learning', 'RAVLT_perc_forgetting', 'mPACCdigit', 'EcogPtMem',
       'EcogPtLang', 'EcogPtVisspat', 'EcogPtPlan', 'EcogPtOrgan',
       'EcogPtDivatt', 'EcogSPMem', 'EcogSPLang', 'EcogSPVisspat',
       'EcogSPPlan', 'EcogSPOrgan', 'EcogSPDivatt', 'FDG', 'PTAU/ABETA',
       'Hippocampus/ICV', 'Entorhinal/ICV', 'Fusiform/ICV', 'MidTemp/ICV',
       'Ventricles/ICV', 'WholeBrain/ICV'],
      dtype='object')


## Same Preprocessing

In [4]:
evaluation_dataset = cleaner.clean_limit_values(columns=["TAU", "PTAU", "ABETA"], dataset=evaluation_dataset)

# Ensure evaluation_dataset has all columns in feature_cols.
# If a column is missing:
#  - if it is a ratio like "A/B" -> try to create it via create_ratio_column("A","B", new_col_name="A/B")
#  - otherwise create the column filled with NaN
for col in train.columns:
    if col in evaluation_dataset.columns:
        continue  # already present
    # If column looks like a ratio "X/Y", attempt to create it using the transformator
    if "/" in col:
        num, den = col.split("/", 1)
        # create_ratio_column will print a warning and skip if numerator/denominator don't exist
        evaluation_dataset = transformer.create_ratio_column(numerator_col=num, denominator_col=den, new_col_name=col)
    else:
        # create empty column with NaNs so column set matches train
        evaluation_dataset.loc[:, col] = pd.NA

evaluation_dataset["PTGENDER"] = evaluation_dataset["PTGENDER"].map({"Male": 1, "Female": 0})
evaluation_dataset = evaluation_dataset[train.columns]

Column 'TAU' cleaned: '<' values scaled by 0.99, '>' values by 1.01.
Column 'PTAU' cleaned: '<' values scaled by 0.99, '>' values by 1.01.
Column 'ABETA' cleaned: '<' values scaled by 0.99, '>' values by 1.01.
Ratio column 'PTAU/ABETA' created from 'PTAU' / 'ABETA'.
NaNs -> PTAU: 11, ABETA: 11, PTAU/ABETA: 11
'PTAU/ABETA' has 0 more NaN(s) than 'PTAU' and 0 more NaN(s) than 'ABETA'.
Ratio column 'Hippocampus/ICV' created from 'Hippocampus' / 'ICV'.
NaNs -> Hippocampus: 4, ICV: 4, Hippocampus/ICV: 4
'Hippocampus/ICV' has 0 more NaN(s) than 'Hippocampus' and 0 more NaN(s) than 'ICV'.
Ratio column 'Entorhinal/ICV' created from 'Entorhinal' / 'ICV'.
NaNs -> Entorhinal: 4, ICV: 4, Entorhinal/ICV: 4
'Entorhinal/ICV' has 0 more NaN(s) than 'Entorhinal' and 0 more NaN(s) than 'ICV'.
Ratio column 'Fusiform/ICV' created from 'Fusiform' / 'ICV'.
NaNs -> Fusiform: 4, ICV: 4, Fusiform/ICV: 4
'Fusiform/ICV' has 0 more NaN(s) than 'Fusiform' and 0 more NaN(s) than 'ICV'.
Ratio column 'MidTemp/ICV' cr

In [5]:
# Quick diagnostics: how many missing values per column before imputation
print("Missing values BEFORE imputation (evaluation):")
print(evaluation_dataset.isna().sum())

Missing values BEFORE imputation (evaluation):
DX                       11
AGE                       0
PTGENDER                  0
PTEDUCAT                  0
APOE4                     8
MMSE                      0
CDRSB                     0
ADAS13                    3
LDELTOTAL                 1
FAQ                       4
MOCA                      4
TRABSCOR                  3
RAVLT_immediate           3
RAVLT_learning            3
RAVLT_perc_forgetting     3
mPACCdigit                1
EcogPtMem                 4
EcogPtLang                4
EcogPtVisspat             4
EcogPtPlan                4
EcogPtOrgan               4
EcogPtDivatt              4
EcogSPMem                 4
EcogSPLang                4
EcogSPVisspat             4
EcogSPPlan                4
EcogSPOrgan               4
EcogSPDivatt              4
FDG                       9
PTAU/ABETA               11
Hippocampus/ICV           4
Entorhinal/ICV            4
Fusiform/ICV              4
MidTemp/ICV               4
V

We apply **KNN** (the same one used in the Data Preprocessing Notebook), using the one to be evaluated by the models as a test.

In [6]:
train = train.astype(float)
evaluation_dataset = evaluation_dataset.astype(float)
cols_to_impute = [c for c in train.columns if c != "DX"]
knn_obj, means_used, stds_used = knn_impute_group(train, evaluation_dataset, cols_to_impute)

K-Nearest Neighbors imputation applied...


In [7]:
# Diagnostics after imputation
print("Missing values AFTER imputation (evaluation_imputed):")
print(evaluation_dataset.isna().sum().sum(), "total NaNs (should be 0)")

Missing values AFTER imputation (evaluation_imputed):
11 total NaNs (should be 0)


In [8]:
round_columns = ['APOE4', 'MMSE', 'RAVLT_immediate', 'RAVLT_learning', 'LDELTOTAL', 'TRABSCOR', 'FAQ', 'MOCA', 'PTGENDER', 'PTEDUCAT']
for col in round_columns:
    evaluation_dataset = cleaner.convert_float_to_int(column=col, method="round", dataset=evaluation_dataset)

evaluation_dataset = cleaner.convert_float_to_int("AGE", method="floor", dataset=evaluation_dataset)
evaluation_dataset = cleaner.convert_float_to_int("ADAS13", method="floor", dataset=evaluation_dataset)

Column 'APOE4' converted using round-half-up rounding.
Column 'MMSE' converted using round-half-up rounding.
Column 'RAVLT_immediate' converted using round-half-up rounding.
Column 'RAVLT_learning' converted using round-half-up rounding.
Column 'LDELTOTAL' converted using round-half-up rounding.
Column 'TRABSCOR' converted using round-half-up rounding.
Column 'FAQ' converted using round-half-up rounding.
Column 'MOCA' converted using round-half-up rounding.
Column 'PTGENDER' converted using round-half-up rounding.
Column 'PTEDUCAT' converted using round-half-up rounding.
Column 'AGE' converted using floor rounding.
Column 'ADAS13' converted using floor rounding.


In [9]:
evaluation_dataset.reindex(columns=["DX", "AGE", "PTGENDER", "PTEDUCAT", "APOE4", "MMSE", "CDRSB", "ADAS13", "LDELTOTAL", "FAQ", "MOCA", "TRABSCOR", "RAVLT_immediate", "RAVLT_learning", "RAVLT_perc_forgetting", "mPACCdigit", "EcogPtMem", "EcogPtLang", "EcogPtVisspat", "EcogPtPlan", "EcogPtOrgan", "EcogPtDivatt", "EcogSPMem", "EcogSPLang", "EcogSPVisspat", "EcogSPPlan", "EcogSPOrgan", "EcogSPDivatt", "FDG", "PTAU/ABETA", "Hippocampus/ICV", "Entorhinal/ICV", "Fusiform/ICV", "MidTemp/ICV", "Ventricles/ICV", "WholeBrain/ICV"])
display(evaluation_dataset)

Unnamed: 0,DX,AGE,PTGENDER,PTEDUCAT,APOE4,MMSE,CDRSB,ADAS13,LDELTOTAL,FAQ,MOCA,TRABSCOR,RAVLT_immediate,RAVLT_learning,RAVLT_perc_forgetting,mPACCdigit,EcogPtMem,EcogPtLang,EcogPtVisspat,EcogPtPlan,EcogPtOrgan,EcogPtDivatt,EcogSPMem,EcogSPLang,EcogSPVisspat,EcogSPPlan,EcogSPOrgan,EcogSPDivatt,FDG,PTAU/ABETA,Hippocampus/ICV,Entorhinal/ICV,Fusiform/ICV,MidTemp/ICV,Ventricles/ICV,WholeBrain/ICV
0,,90,1,17,1,20,7.0,41,0,16,17,300,8,0,100.0,-19.7479,2.451428,2.074443,1.661904,1.776,1.891999,2.16,2.375,1.66667,1.5,2.0,2.5,2.5,0.735309,0.054991,0.003296,0.00159,0.008719,0.009541,0.038619,0.597339
1,,66,1,16,0,28,0.5,7,8,0,28,61,46,6,33.3333,-4.26884,2.0,1.22222,1.0,1.4,1.16667,1.75,1.0,1.0,1.0,1.0,1.0,1.0,1.302911,0.0188,0.005393,0.003104,0.01153,0.015536,0.012667,0.711053
2,,64,0,18,0,29,0.5,5,11,0,26,68,53,2,25.0,-0.42049,1.375,2.11111,1.33333,1.6,1.4,1.25,1.25,1.11111,1.0,1.0,1.16667,1.0,1.342462,0.015544,0.005541,0.003074,0.012929,0.014192,0.01865,0.716139
3,,60,0,16,1,28,2.0,15,8,1,24,119,35,3,87.5,-6.49645,3.25,2.22222,1.16667,2.0,3.33333,3.5,2.255,1.797779,1.529525,1.578667,1.913332,2.123334,1.305944,0.024087,0.004384,0.002225,0.012393,0.013341,0.017905,0.700865
4,,55,0,16,1,27,1.0,6,15,5,26,71,62,8,6.66667,-0.40621,2.625,2.33333,1.5,1.8,1.83333,3.25,1.875,1.33333,1.66667,1.4,1.5,2.0,1.347033,0.019281,0.005629,0.003219,0.011887,0.014959,0.014547,0.744025
5,,68,0,15,1,30,0.0,5,12,0,24,60,46,4,72.7273,-0.341944,1.375,1.33333,1.33333,1.4,1.0,2.5,1.75,1.22222,1.33333,1.0,1.0,1.75,1.257902,0.03249,0.004669,0.002737,0.010387,0.014345,0.028638,0.684665
6,,67,0,14,2,25,2.5,32,1,15,17,215,23,2,100.0,-14.4432,2.205,1.866665,1.744762,1.632,1.766667,1.8,3.375,3.11111,3.0,4.0,3.66667,4.0,0.948225,0.068712,0.004074,0.002459,0.012931,0.011551,0.046299,0.666848
7,,71,0,14,0,27,0.0,11,8,1,23,113,39,6,42.15384,-6.75361,2.0,1.63111,1.429524,1.208,1.166667,1.51,1.527858,1.19611,1.13,1.06,1.238667,1.37,1.304248,0.018641,0.004813,0.002735,0.011619,0.01353,0.0211,0.705617
8,,67,0,16,0,27,0.0,6,10,0,22,75,40,9,61.5385,-3.1479,1.625,1.44444,1.0,1.6,1.66667,2.0,1.45,1.337776,1.32,1.25,1.36,1.45,1.264108,0.023994,0.004827,0.003053,0.012738,0.012838,0.038162,0.676523
9,,78,1,18,1,29,0.5,11,3,0,24,116,33,3,37.5,-5.84674,2.0,2.0,1.16667,1.2,1.33333,1.75,1.25,1.0,1.0,1.0,1.16667,1.5,1.171227,0.043181,0.004804,0.002727,0.011828,0.013114,0.020212,0.664319


## Load the Models

In [10]:
models_dir = "../results"
model_files = {
    "Model1": "Model1.pkl",
    "XAIModel1": "XAIModel1.pkl",
    "Model2": "Model2.pkl",
    "XAIModel2": "XAIModel2.pkl"
}

loaded_models = {}
for name, fname in model_files.items():
    path = os.path.join(models_dir, fname)
    mdl = joblib.load(path)
    if mdl is not None:
        loaded_models[name] = mdl


## Predict Diagnosis

In [11]:
X_eval_1 = evaluation_dataset.drop(['DX'], axis=1)
X_eval_2 = X_eval_1.drop(["CDRSB", "LDELTOTAL", "mPACCdigit"], axis=1)

# Prepare containers for predictions and probabilities (keep them OUTSIDE evaluation_dataset)
model_names = list(loaded_models.keys())
preds_df = pd.DataFrame(index=evaluation_dataset.index, columns=model_names)
probs_df = pd.DataFrame(index=evaluation_dataset.index, columns=model_names, dtype=float)

# Predict with each model, but DO NOT write Pred_* columns into evaluation_dataset ---
for name, mdl in loaded_models.items():
    # choose the correct input matrix for this model
    if name in ["Model1", "XAIModel1"]:
        X_eval = X_eval_1.copy()
    elif name in ["Model2", "XAIModel2"]:
        X_eval = X_eval_2.copy()
    else:
        X_eval = X_eval_1.copy()

    # Predict (DataFrame first, fallback to numpy array)
    try:
        preds = mdl.predict(X_eval)
    except Exception:
        preds = mdl.predict(X_eval.values)

    # Store predictions in external DataFrame (no changes to evaluation_dataset)
    preds_df[name] = pd.Series(preds, index=evaluation_dataset.index)

    # Try to get a top-probability per sample, store in probs_df (NaN if unavailable)
    top_probs = None
    if hasattr(mdl, "predict_proba"):
        try:
            probs = mdl.predict_proba(X_eval)
        except Exception:
            probs = mdl.predict_proba(X_eval.values)
        top_probs = np.max(probs, axis=1)
    elif hasattr(mdl, "decision_function"):
        try:
            dec = mdl.decision_function(X_eval)
        except Exception:
            dec = mdl.decision_function(X_eval.values)
        # convert decision scores to softmax as proxy
        try:
            exp = np.exp(dec - np.max(dec, axis=1, keepdims=True))
            softmax = exp / np.sum(exp, axis=1, keepdims=True)
            top_probs = np.max(softmax, axis=1)
        except Exception:
            top_probs = None

    if top_probs is not None:
        probs_df[name] = pd.Series(top_probs, index=evaluation_dataset.index)
    else:
        probs_df[name] = np.nan  # explicit float NaN

    # Print summary counts (optional)
    print(f"Model '{name}': predictions computed (not stored in evaluation_dataset).")
    try:
        vc = preds_df[name].value_counts(dropna=False)
        print(f" -> Prediction counts for {name}:\n{vc}\n")
    except Exception as e:
        print(f" -> Could not compute value_counts for {name}: {e}\n")

# Majority vote and apply to DX (no new Pred_* columns in evaluation_dataset) ---
vote_order = ["Model1", "Model2", "XAIModel1", "XAIModel2"]  # tie-breaker order

def majority_from_preds(row):
    # row is a Series containing predictions from each model (indexed by model name)
    preds = [row[m] for m in vote_order]
    # count frequencies, ignore NaN
    counts = pd.Series([p for p in preds if pd.notna(p)]).value_counts()
    if len(counts) == 0:
        return pd.NA  # no valid prediction
    if counts.iloc[0] > 1:
        return counts.idxmax()  # majority exists
    else:
        # no majority -> return the first model's prediction (if present), else first non-NaN
        first = preds[0]
        if pd.isna(first):
            # fallback to first non-NaN prediction
            for p in preds:
                if pd.notna(p):
                    return p
            return pd.NA
        return first

# Compute majority series from preds_df rows
majority_series = preds_df.apply(majority_from_preds, axis=1)

# Replace only existing NaNs in evaluation_dataset["DX"] (so we don't duplicate or erase non-NaN values)
if "DX" in evaluation_dataset.columns:
    evaluation_dataset.loc[:, "DX"] = evaluation_dataset["DX"].fillna(majority_series)
else:
    # If somehow DX did not exist, create it from majority
    evaluation_dataset.loc[:, "DX"] = majority_series

# Print detailed per-index results (predictions from models and final DX) ---
for idx in evaluation_dataset.index:
    # gather preds for this index
    preds_for_idx = preds_df.loc[idx, vote_order].tolist()
    probs_for_idx = probs_df.loc[idx, vote_order].tolist()
    print(f"Index {idx}:")
    for m, p, pr in zip(vote_order, preds_for_idx, probs_for_idx):
        # show probability only if not NaN
        if pd.notna(pr):
            print(f"  {m}: {p}  (p={pr:.3f})")
        else:
            print(f"  {m}: {p}")
    print(f"  => Final DX: {evaluation_dataset.loc[idx, 'DX']}\n")

evaluation_dataset = cleaner.convert_float_to_int("DX", method="floor", dataset=evaluation_dataset)

Model 'Model1': predictions computed (not stored in evaluation_dataset).
 -> Prediction counts for Model1:
Model1
0    4
2    3
3    2
1    2
Name: count, dtype: int64

Model 'XAIModel1': predictions computed (not stored in evaluation_dataset).
 -> Prediction counts for XAIModel1:
XAIModel1
0    4
2    3
3    2
1    2
Name: count, dtype: int64

Model 'Model2': predictions computed (not stored in evaluation_dataset).
 -> Prediction counts for Model2:
Model2
0    6
1    3
3    2
Name: count, dtype: int64

Model 'XAIModel2': predictions computed (not stored in evaluation_dataset).
 -> Prediction counts for XAIModel2:
XAIModel2
0    6
1    3
3    2
Name: count, dtype: int64

Index 0:
  Model1: 3  (p=1.000)
  Model2: 3  (p=0.990)
  XAIModel1: 3  (p=1.000)
  XAIModel2: 3  (p=0.989)
  => Final DX: 3.0

Index 1:
  Model1: 2  (p=0.825)
  Model2: 0  (p=0.978)
  XAIModel1: 2  (p=0.646)
  XAIModel2: 0  (p=0.755)
  => Final DX: 2.0

Index 2:
  Model1: 1  (p=0.919)
  Model2: 0  (p=0.874)
  XAIModel1

In [12]:
display(evaluation_dataset)

Unnamed: 0,DX,AGE,PTGENDER,PTEDUCAT,APOE4,MMSE,CDRSB,ADAS13,LDELTOTAL,FAQ,MOCA,TRABSCOR,RAVLT_immediate,RAVLT_learning,RAVLT_perc_forgetting,mPACCdigit,EcogPtMem,EcogPtLang,EcogPtVisspat,EcogPtPlan,EcogPtOrgan,EcogPtDivatt,EcogSPMem,EcogSPLang,EcogSPVisspat,EcogSPPlan,EcogSPOrgan,EcogSPDivatt,FDG,PTAU/ABETA,Hippocampus/ICV,Entorhinal/ICV,Fusiform/ICV,MidTemp/ICV,Ventricles/ICV,WholeBrain/ICV
0,3,90,1,17,1,20,7.0,41,0,16,17,300,8,0,100.0,-19.7479,2.451428,2.074443,1.661904,1.776,1.891999,2.16,2.375,1.66667,1.5,2.0,2.5,2.5,0.735309,0.054991,0.003296,0.00159,0.008719,0.009541,0.038619,0.597339
1,2,66,1,16,0,28,0.5,7,8,0,28,61,46,6,33.3333,-4.26884,2.0,1.22222,1.0,1.4,1.16667,1.75,1.0,1.0,1.0,1.0,1.0,1.0,1.302911,0.0188,0.005393,0.003104,0.01153,0.015536,0.012667,0.711053
2,1,64,0,18,0,29,0.5,5,11,0,26,68,53,2,25.0,-0.42049,1.375,2.11111,1.33333,1.6,1.4,1.25,1.25,1.11111,1.0,1.0,1.16667,1.0,1.342462,0.015544,0.005541,0.003074,0.012929,0.014192,0.01865,0.716139
3,2,60,0,16,1,28,2.0,15,8,1,24,119,35,3,87.5,-6.49645,3.25,2.22222,1.16667,2.0,3.33333,3.5,2.255,1.797779,1.529525,1.578667,1.913332,2.123334,1.305944,0.024087,0.004384,0.002225,0.012393,0.013341,0.017905,0.700865
4,1,55,0,16,1,27,1.0,6,15,5,26,71,62,8,6.66667,-0.40621,2.625,2.33333,1.5,1.8,1.83333,3.25,1.875,1.33333,1.66667,1.4,1.5,2.0,1.347033,0.019281,0.005629,0.003219,0.011887,0.014959,0.014547,0.744025
5,0,68,0,15,1,30,0.0,5,12,0,24,60,46,4,72.7273,-0.341944,1.375,1.33333,1.33333,1.4,1.0,2.5,1.75,1.22222,1.33333,1.0,1.0,1.75,1.257902,0.03249,0.004669,0.002737,0.010387,0.014345,0.028638,0.684665
6,3,67,0,14,2,25,2.5,32,1,15,17,215,23,2,100.0,-14.4432,2.205,1.866665,1.744762,1.632,1.766667,1.8,3.375,3.11111,3.0,4.0,3.66667,4.0,0.948225,0.068712,0.004074,0.002459,0.012931,0.011551,0.046299,0.666848
7,0,71,0,14,0,27,0.0,11,8,1,23,113,39,6,42.15384,-6.75361,2.0,1.63111,1.429524,1.208,1.166667,1.51,1.527858,1.19611,1.13,1.06,1.238667,1.37,1.304248,0.018641,0.004813,0.002735,0.011619,0.01353,0.0211,0.705617
8,0,67,0,16,0,27,0.0,6,10,0,22,75,40,9,61.5385,-3.1479,1.625,1.44444,1.0,1.6,1.66667,2.0,1.45,1.337776,1.32,1.25,1.36,1.45,1.264108,0.023994,0.004827,0.003053,0.012738,0.012838,0.038162,0.676523
9,2,78,1,18,1,29,0.5,11,3,0,24,116,33,3,37.5,-5.84674,2.0,2.0,1.16667,1.2,1.33333,1.75,1.25,1.0,1.0,1.0,1.16667,1.5,1.171227,0.043181,0.004804,0.002727,0.011828,0.013114,0.020212,0.664319


## Creation of NEWADNIMERGE.csv to accumulate data

We then create a dataset called **NEWADNIMERGE.csv**, which contains all the data from *train*, *test*, and *evaluation_dataset*. This dataset will then store the values ​​from the main app, ***main.py***.

In [13]:
train = pd.read_csv("../data/train.csv")
test = pd.read_csv("../data/test.csv")
evaluation_dataset
new_dataset = pd.concat([train, test, evaluation_dataset], ignore_index=True)

# Salvataggio in nuovo CSV
new_dataset.to_csv("../data/NEWADNIMERGE.csv", index=False)

print("File saved as NEWADNIMERGE.csv")

File saved as NEWADNIMERGE.csv


In [14]:
display(new_dataset)

Unnamed: 0,DX,AGE,PTGENDER,PTEDUCAT,APOE4,MMSE,CDRSB,ADAS13,LDELTOTAL,FAQ,MOCA,TRABSCOR,RAVLT_immediate,RAVLT_learning,RAVLT_perc_forgetting,mPACCdigit,EcogPtMem,EcogPtLang,EcogPtVisspat,EcogPtPlan,EcogPtOrgan,EcogPtDivatt,EcogSPMem,EcogSPLang,EcogSPVisspat,EcogSPPlan,EcogSPOrgan,EcogSPDivatt,FDG,PTAU/ABETA,Hippocampus/ICV,Entorhinal/ICV,Fusiform/ICV,MidTemp/ICV,Ventricles/ICV,WholeBrain/ICV
0,2,77.0,0,16.0,1.0,28.0,2.5,5.0,1.0,0.0,24.0,108.0,47.0,5.0,63.636400,-4.840050,2.250000,2.111110,1.000000,1.000,1.333330,1.00,2.375000,2.111110,2.428570,2.60,2.833330,2.75,1.222830,0.040838,0.004524,0.001882,0.012107,0.011311,0.016977,0.706210
1,0,59.0,1,16.0,1.0,30.0,0.0,0.0,19.0,0.0,30.0,47.0,71.0,2.0,0.000000,5.427020,1.000000,1.000000,1.000000,1.000,1.000000,1.00,1.000000,1.000000,1.000000,1.00,1.000000,1.00,1.161970,0.020445,0.004452,0.002756,0.012935,0.014299,0.025614,0.752850
2,3,77.0,1,12.0,2.0,22.0,8.0,30.0,0.0,25.0,17.0,300.0,19.0,1.0,100.000000,-18.905400,2.300000,1.844446,1.248572,1.580,1.366668,1.75,3.841666,2.847620,3.033334,2.97,3.166668,3.80,0.924559,0.047131,0.002825,0.001348,0.010049,0.009701,0.053417,0.522572
3,2,82.0,1,20.0,0.0,26.0,1.5,21.0,4.0,0.0,24.0,63.0,35.0,1.0,85.714300,-7.957490,1.925000,1.269446,1.166668,1.200,1.466668,1.60,1.891666,1.272222,1.066668,1.16,1.733332,2.10,1.119130,0.020198,0.003736,0.002083,0.013038,0.013942,0.024176,0.637729
4,0,83.0,0,17.0,0.0,27.0,0.0,5.0,13.0,3.0,25.0,98.0,57.0,7.0,7.142860,-1.948410,1.250000,1.333330,1.000000,1.000,1.333330,1.00,1.375000,1.111110,1.666670,1.00,1.833330,1.25,1.279034,0.026879,0.004611,0.002170,0.011387,0.012975,0.052196,0.635279
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2424,3,67.0,0,14.0,2.0,25.0,2.5,32.0,1.0,15.0,17.0,215.0,23.0,2.0,100.000000,-14.443200,2.205000,1.866665,1.744762,1.632,1.766667,1.80,3.375000,3.111110,3.000000,4.00,3.666670,4.00,0.948225,0.068712,0.004074,0.002459,0.012931,0.011551,0.046299,0.666848
2425,0,71.0,0,14.0,0.0,27.0,0.0,11.0,8.0,1.0,23.0,113.0,39.0,6.0,42.153840,-6.753610,2.000000,1.631110,1.429524,1.208,1.166667,1.51,1.527858,1.196110,1.130000,1.06,1.238667,1.37,1.304248,0.018641,0.004813,0.002735,0.011619,0.013530,0.021100,0.705617
2426,0,67.0,0,16.0,0.0,27.0,0.0,6.0,10.0,0.0,22.0,75.0,40.0,9.0,61.538500,-3.147900,1.625000,1.444440,1.000000,1.600,1.666670,2.00,1.450000,1.337776,1.320000,1.25,1.360000,1.45,1.264108,0.023994,0.004827,0.003053,0.012738,0.012838,0.038162,0.676523
2427,2,78.0,1,18.0,1.0,29.0,0.5,11.0,3.0,0.0,24.0,116.0,33.0,3.0,37.500000,-5.846740,2.000000,2.000000,1.166670,1.200,1.333330,1.75,1.250000,1.000000,1.000000,1.00,1.166670,1.50,1.171227,0.043181,0.004804,0.002727,0.011828,0.013114,0.020212,0.664319
