In [1]:
import numpy as np
import pandas as pd
import os
from metadata import *
from utils_cnn import *
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import numpy as np

# Load and visualize data

### Load the MNI template

In [2]:
# Read the MNI template volume and compute its global intensity statistics,
# which parameterize the normalization step of the transform pipeline below.
MNI_PATH = '../data/datasets/_MNI_template/mni_icbm152_nlin_asym_09c_nifti/mni_icbm152_nlin_asym_09c/mni_icbm152_t1_tal_nlin_asym_09c.nii'
mni = nib.load(MNI_PATH).get_fdata()
mni_mean = mni.mean()
mni_std = mni.std()

# NOTE(review): center_crop is not referenced by any transform in this cell.
center_crop = 200

# Convert the input to a tensor, then normalize it with the template mean/std.
data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mni_mean, std=mni_std, inplace=False),
])

### Load the image metadata for the given subjects

In [16]:
# Read the semicolon-separated metadata table that holds the image paths,
# then preview its first rows.
mtd_csv = '../data/metadata/metadata.csv'
mtd_df = pd.read_csv(filepath_or_buffer=mtd_csv, sep=';')
mtd_df.head(5)

Unnamed: 0,dataset,ID,score,sex,T1_b_exists,preprocessed,csf,gm,wm,background,subc,T1_b_shape,T1_exists,T1_path,T1_shape
0,ImaGenoma,001065-00,13.0,1,True,../data/datasets/ImaGenoma/T1_b/001065-00/rigi...,../data/datasets/ImaGenoma/T1_b/001065-00/csf....,../data/datasets/ImaGenoma/T1_b/001065-00/gm.n...,../data/datasets/ImaGenoma/T1_b/001065-00/wm.n...,../data/datasets/ImaGenoma/T1_b/001065-00/back...,../data/datasets/ImaGenoma/T1_b/001065-00/subc...,"(193, 229, 193)",True,../data/datasets/ImaGenoma/T1/001065-00/3D_T1W...,"(112, 176, 180)"
1,ImaGenoma,000295-01,7.0,0,True,../data/datasets/ImaGenoma/T1_b/000295-01/rigi...,../data/datasets/ImaGenoma/T1_b/000295-01/csf....,../data/datasets/ImaGenoma/T1_b/000295-01/gm.n...,../data/datasets/ImaGenoma/T1_b/000295-01/wm.n...,../data/datasets/ImaGenoma/T1_b/000295-01/back...,../data/datasets/ImaGenoma/T1_b/000295-01/subc...,"(193, 229, 193)",True,../data/datasets/ImaGenoma/T1/000295-01/3D_T1W...,"(112, 176, 180)"
2,ImaGenoma,801081-02,9.0,0,True,../data/datasets/ImaGenoma/T1_b/801081-02/rigi...,../data/datasets/ImaGenoma/T1_b/801081-02/csf....,../data/datasets/ImaGenoma/T1_b/801081-02/gm.n...,../data/datasets/ImaGenoma/T1_b/801081-02/wm.n...,../data/datasets/ImaGenoma/T1_b/801081-02/back...,../data/datasets/ImaGenoma/T1_b/801081-02/subc...,"(193, 229, 193)",True,../data/datasets/ImaGenoma/T1/801081-02/3D_T1W...,"(112, 176, 180)"
3,ImaGenoma,800905-02,8.0,0,True,../data/datasets/ImaGenoma/T1_b/800905-02/rigi...,../data/datasets/ImaGenoma/T1_b/800905-02/csf....,../data/datasets/ImaGenoma/T1_b/800905-02/gm.n...,../data/datasets/ImaGenoma/T1_b/800905-02/wm.n...,../data/datasets/ImaGenoma/T1_b/800905-02/back...,../data/datasets/ImaGenoma/T1_b/800905-02/subc...,"(193, 229, 193)",True,../data/datasets/ImaGenoma/T1/800905-02/3D_T1W...,"(112, 176, 180)"
4,ImaGenoma,000501-01,7.0,0,True,../data/datasets/ImaGenoma/T1_b/000501-01/rigi...,../data/datasets/ImaGenoma/T1_b/000501-01/csf....,../data/datasets/ImaGenoma/T1_b/000501-01/gm.n...,../data/datasets/ImaGenoma/T1_b/000501-01/wm.n...,../data/datasets/ImaGenoma/T1_b/000501-01/back...,../data/datasets/ImaGenoma/T1_b/000501-01/subc...,"(193, 229, 193)",True,../data/datasets/ImaGenoma/T1/000501-01/3D_T1W...,"(112, 176, 180)"


In [17]:
# Collect the subject IDs listed in the metadata table and report how many there are.
IDs = mtd_df.loc[:, 'ID'].values
print('Number of subjects with images: ', len(IDs))

Number of subjects with images:  1021


### Load and clean up the labels

In [3]:
# Load the labels file (SPSS .sav).
# The location defaults to the original path, but it can be overridden through
# the IMAGENOMA_LABELS_SAV environment variable so the notebook is not tied to
# one user's home directory.
data_sav = os.environ.get(
    'IMAGENOMA_LABELS_SAV',
    '/home/alex/data/datasets/ImaGenoma/IMAGENOMA_AGING_PHENOTYPES_16022021.sav',
)
data = pd.read_spss(data_sav)
print('Columns: ', data.columns)

data.head()

Columns:  Index(['ID_IMAGENOMA', 'ID_ORIGEN', 'AGE', 'AGE_GROUPS_5G', 'AGE_GROUPS_4G',
       'AGE_GROUPS_3G', 'AGE_GROUPS_2G', 'GENDER', 'EDUCATION_LEVEL_5G',
       'EDUCATION_LEVEL_4G', 'EDUCATION_LEVEL_3G', 'EDUCATION_LEVEL_2G',
       'DIET_1', 'DIET_2', 'DIET_3', 'DIET_4', 'DIET_5', 'DIET_6', 'DIET_7',
       'DIET_8', 'DIET_9', 'DIET_10', 'DIET_11', 'DIET_12', 'DIET_13',
       'DIET_14', 'DIET_SCORE', 'DIET_RK', 'TPR', 'TFR', 'TDFR', 'TDPR', 'FVF',
       'FVF_NORM', 'FVS', 'FVS_NORM', 'DST_F', 'DST_F_NORM', 'DST_B',
       'DST_B_NORM', 'SDT', 'SDT_NORM', 'SCWT_WC', 'SCWT_WC_NORM', 'COGNITION',
       'PHQ_1', 'PHQ_2', 'PHQ_3', 'PHQ_4', 'PHQ_5', 'PHQ_6', 'PHQ_7', 'PHQ_8',
       'PHQ_9', 'PHQ_SCORE', 'PHQ_SCORE_RK', 'BFI_EXTR', 'BFI_AGRE',
       'BFI_CONS', 'BFI_NEUR', 'BFI_OPEN', 'VIG_METS_WEEK', 'MOD_METS_WEEK',
       'WALK_METS_WEEK', 'METS_MIN_WEEK', 'PHYSICAL_ACT'],
      dtype='object')


Unnamed: 0,ID_IMAGENOMA,ID_ORIGEN,AGE,AGE_GROUPS_5G,AGE_GROUPS_4G,AGE_GROUPS_3G,AGE_GROUPS_2G,GENDER,EDUCATION_LEVEL_5G,EDUCATION_LEVEL_4G,...,BFI_EXTR,BFI_AGRE,BFI_CONS,BFI_NEUR,BFI_OPEN,VIG_METS_WEEK,MOD_METS_WEEK,WALK_METS_WEEK,METS_MIN_WEEK,PHYSICAL_ACT
0,800006-01,SHARE,77.670089,70-79,70-79,75+,65+,Male,Primary,Primary,...,7.0,8.0,8.0,4.0,7.0,0.0,0.0,2772.0,2772.0,Moderate
1,800011-02,SHARE,60.87885,60-69,60-69,50-64,50-64,Male,Secondary,Secondary,...,,,,,,,,,,
2,800019-01,SHARE,66.231348,60-69,60-69,65-74,65+,Male,University,University,...,5.0,6.0,10.0,9.0,9.0,3840.0,0.0,6930.0,10770.0,High
3,800027-01,SHARE,67.214237,60-69,60-69,65-74,65+,Female,Primary,Primary,...,5.0,6.0,6.0,6.0,5.0,0.0,0.0,1386.0,1386.0,Moderate
4,800027-02,SHARE,69.322382,60-69,60-69,65-74,65+,Male,Primary,Primary,...,5.0,5.0,5.0,6.0,4.0,0.0,1680.0,11088.0,12768.0,High


In [4]:
# Keep only the identifier, age and gender columns plus every column whose
# name starts with 'DIET'.
our_columns = [
    col
    for col in data.columns
    if col.startswith('DIET') or col in ('ID_IMAGENOMA', 'AGE', 'GENDER')
]
our_data = data.filter(items=our_columns)

our_data.head()

Unnamed: 0,ID_IMAGENOMA,AGE,GENDER,DIET_1,DIET_2,DIET_3,DIET_4,DIET_5,DIET_6,DIET_7,DIET_8,DIET_9,DIET_10,DIET_11,DIET_12,DIET_13,DIET_14,DIET_SCORE,DIET_RK
0,800006-01,77.670089,Male,yes,no,no,no,no,yes,yes,no,yes,yes,yes,no,no,yes,7.0,Moderate
1,800011-02,60.87885,Male,,,,,,,,,,,,,,,,
2,800019-01,66.231348,Male,yes,yes,no,no,no,yes,no,yes,yes,yes,no,yes,no,no,7.0,Moderate
3,800027-01,67.214237,Female,yes,yes,yes,yes,no,yes,yes,yes,yes,no,yes,yes,yes,yes,12.0,High
4,800027-02,69.322382,Male,yes,yes,yes,no,yes,yes,yes,yes,yes,yes,yes,yes,yes,no,12.0,High


In [9]:
# Remove all rows with missing values first, so that the subject count and the
# dtype summary below describe the cleaned table rather than the table before
# cleaning (the previous statement order reported the pre-dropna state).
our_data = our_data.dropna()

# check Dtype and report the final number of subjects
print('Final number of subjects: ', len(our_data))
our_data.info()


Final number of subjects:  434
<class 'pandas.core.frame.DataFrame'>
Int64Index: 434 entries, 0 to 463
Data columns (total 19 columns):
 #   Column        Non-Null Count  Dtype   
---  ------        --------------  -----   
 0   ID_IMAGENOMA  434 non-null    object  
 1   AGE           434 non-null    float64 
 2   GENDER        434 non-null    category
 3   DIET_1        434 non-null    category
 4   DIET_2        434 non-null    category
 5   DIET_3        434 non-null    category
 6   DIET_4        434 non-null    category
 7   DIET_5        434 non-null    category
 8   DIET_6        434 non-null    category
 9   DIET_7        434 non-null    category
 10  DIET_8        434 non-null    category
 11  DIET_9        434 non-null    category
 12  DIET_10       434 non-null    category
 13  DIET_11       434 non-null    category
 14  DIET_12       434 non-null    category
 15  DIET_13       434 non-null    category
 16  DIET_14       434 non-null    category
 17  DIET_SCORE    434 non-n

    ID_IMAGENOMA        AGE  GENDER DIET_1 DIET_2 DIET_3 DIET_4 DIET_5 DIET_6  \
0      800006-01  77.670089    Male    yes     no     no     no     no    yes   
2      800019-01  66.231348    Male    yes    yes     no     no     no    yes   
3      800027-01  67.214237  Female    yes    yes    yes    yes     no    yes   
4      800027-02  69.322382    Male    yes    yes    yes     no    yes    yes   
5      800037-01  67.126626  Female    yes    yes     no     no    yes    yes   
..           ...        ...     ...    ...    ...    ...    ...    ...    ...   
459    805688-02  70.395619  Female    yes     no     no    yes    yes    yes   
460    805690-01  65.886379  Female    yes     no     no     no    yes    yes   
461    805884-01  72.032854    Male    yes     no     no     no     no    yes   
462    805899-01  72.714579    Male    yes    yes     no    yes    yes    yes   
463    805899-02  69.223819  Female    yes     no     no     no    yes    yes   

    DIET_7 DIET_8 DIET_9 DI