In [1]:
import torch
import torch.nn as nn
import torchio as tio 
import numpy as np
import pandas as pd
import nibabel as nib
import json

import os

# Global path settings: every dataset location hangs off <base_path>/data.
base_path = '/mnt/data_lab513/vqtran_data'
_data_root = os.path.join(base_path, "data")
_clean_root = os.path.join(_data_root, "clean_data")
_train_dec_root = os.path.join(_data_root, "data_train_dec")

# Raw scans and the output of each preprocessing stage.
root_data = os.path.join(_data_root, "raw_data", "ADNI_NIfTI")
root_bias_correction = os.path.join(_clean_root, "mri_bias_correction")
root_bet = os.path.join(_clean_root, "mri_brain_extraction")
root_reg = os.path.join(_clean_root, "mri_registration")

# Metadata root (the "Pre-Thesis_metadata"/"ADNI" sub-folders are intentionally not appended).
root_meta = os.path.join(_data_root, "meta_data")

# Training data locations.
root_train = os.path.join(_data_root, "train_data")
root_train_dec = os.path.join(_train_dec_root, "origin")
root_train_unique = os.path.join(_train_dec_root, "unique")
root_train_unique_tensor = os.path.join(_train_dec_root, "tensor")



# Work

0. Read data in (pandas / JSON + nibabel file names)
1. Compute the mean image of each class
2. Compute the mean square distance between each image and the class mean images
3. Save the data as a dataframe in unique_subject_with_filter_distance_CN_AD.csv
4. Cut off images that have a high mean square distance
5. Create torch tensors for the dataset


# 1. Read data in & compute_mean

In [2]:
# Load the subject metadata. Each entry is indexed positionally below:
# [1] is the diagnosis label and [2] is the NIfTI file name.
# A context manager closes the file handle as soon as the JSON is parsed.
with open('../investigate/unique_dataset_dict.json', 'r') as subject_file:
    subject_dict = json.load(subject_file)

# The preprocessing pipeline does not depend on the image, so build it once
# instead of once per subject.
image_transformation_tio = tio.transforms.Compose(
        [
            tio.transforms.Resize((110,110,110)),
            tio.ZNormalization(),
            tio.RescaleIntensity(out_min_max=(0, 1)) #), in_min_max=(0., 8957.8574))
        ]
)

# Running per-class sums (turned into means further down). The leading channel
# axis is made explicit: each transformed image is (1, 110, 110, 110), so a
# (110, 110, 110) accumulator would be broadcast to that shape on the first add.
mean_image_CN = torch.zeros(1, 110, 110, 110)
count_image_CN = 0
mean_image_AD = torch.zeros(1, 110, 110, 110)
count_image_AD = 0

for subject_entry in subject_dict.values():
    label = subject_entry[1]
    if label not in ("CN", "AD"):
        # Only the two-class CN/AD problem is handled; other labels are skipped.
        continue

    filename = subject_entry[2]
    image_absolute_path = os.path.join(root_train_unique, filename)

    image_array = nib.load(image_absolute_path).get_fdata()
    # Add a channel axis in front of the three spatial axes before transforming.
    image_tensor = torch.unsqueeze(torch.Tensor(image_array), 0)
    image_tensor = image_transformation_tio(image_tensor)

    if label == "CN":
        mean_image_CN = torch.add(mean_image_CN, image_tensor)
        count_image_CN += 1
    else:
        mean_image_AD = torch.add(mean_image_AD, image_tensor)
        count_image_AD += 1

# Statistics of the raw sums (before dividing by the counts), CN then AD.
for statistic in (torch.mean, torch.max, torch.min):
    print(statistic(mean_image_CN))
    print(statistic(mean_image_AD))
print(count_image_CN)
print(count_image_AD)

# Dividing by a zero count would silently produce an all-NaN "mean" image that
# poisons every distance computed from it, so fail loudly instead.
if count_image_CN == 0 or count_image_AD == 0:
    raise ValueError("need at least one CN image and one AD image to compute the class means")

mean_image_CN /= count_image_CN
mean_image_AD /= count_image_AD

# Statistics of the final mean images, CN then AD.
for statistic in (torch.mean, torch.max, torch.min):
    print(statistic(mean_image_CN))
    print(statistic(mean_image_AD))


tensor(45.3884)
tensor(35.3058)
tensor(274.1535)
tensor(216.0457)
tensor(0.)
tensor(0.)
349
278
tensor(0.1301)
tensor(0.1270)
tensor(0.7855)
tensor(0.7771)
tensor(0.)
tensor(0.)


In [3]:
# Persist both class mean images next to the notebook.
for mean_image, mean_image_file in (
    (mean_image_CN, "mean_image_CN.pt"),
    (mean_image_AD, "mean_image_AD.pt"),
):
    torch.save(mean_image, mean_image_file)

# 2. Calculate the mean square distance between each image and the class mean images

In [4]:
# Maps subject ID -> record holding the image metadata and its mean squared
# distance to the CN mean image and to the AD mean image.
subject_dict_with_filter_distance = {}
loss = nn.MSELoss()

# The preprocessing pipeline is identical for every image, so build it once.
image_transformation_tio = tio.transforms.Compose(
        [
            tio.transforms.Resize((110,110,110)),
            tio.ZNormalization(),
            tio.RescaleIntensity(out_min_max=(0, 1)) #), in_min_max=(0., 8957.8574))
        ]
)

for key, subject_entry in subject_dict.items():
    label = subject_entry[1]

    # Two-class CN/AD problem only; this needs to change to generalise.
    if label not in ("CN", "AD"):
        continue

    image_absolute_path = os.path.join(root_train_unique, subject_entry[2])

    image_array = nib.load(image_absolute_path).get_fdata()
    # Add a channel axis in front of the three spatial axes before transforming.
    image_tensor = torch.unsqueeze(torch.Tensor(image_array), 0)
    image_tensor = image_transformation_tio(image_tensor)

    # MSELoss averages the squared voxel differences into one scalar per image.
    distance_CN = loss(image_tensor, mean_image_CN)
    distance_AD = loss(image_tensor, mean_image_AD)

    mean_square_distance = {"DISTANCE CN": distance_CN.item(), "DISTANCE AD": distance_AD.item()}

    subject_dict_with_filter_distance[key] = {"Subject ID": key,
                                              "Image ID": subject_entry[0],
                                              "Image Path": image_absolute_path,
                                              "Image Target": label,
                                              **mean_square_distance}

print(len(subject_dict_with_filter_distance))

627


In [5]:
# Build the frame from all records in a single call. Growing it row by row with
# DataFrame.append re-copied the whole frame on every iteration, and that method
# was removed in pandas 2.0. Each dict becomes one row on a fresh 0..n-1 index.
filter_distance_dataframe = pd.DataFrame(list(subject_dict_with_filter_distance.values()))

filter_distance_dataframe.to_csv("../investigate/unique_subject_with_filter_distance_CN_AD.csv", index=False)

In [6]:
# Read the distance table back from its CSV file and preview the first ten rows.
filter_distance_dataframe = pd.read_csv(
    "../investigate/unique_subject_with_filter_distance_CN_AD.csv"
)
filter_distance_dataframe.iloc[:10]

Unnamed: 0,Subject ID,Image ID,Image Path,Image Target,DISTANCE CN,DISTANCE AD
0,002_S_0295,I13722,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.003989,0.004507
1,002_S_0413,I14437,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.007575,0.008511
2,002_S_0559,I15948,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.005903,0.005204
3,002_S_0619,I16392,/mnt/data_lab513/vqtran_data/data/data_train_d...,AD,0.012839,0.011432
4,002_S_0685,I18211,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.006689,0.007408
5,005_S_0814,I23573,/mnt/data_lab513/vqtran_data/data/data_train_d...,AD,0.004964,0.005108
6,005_S_0929,I25645,/mnt/data_lab513/vqtran_data/data/data_train_d...,AD,0.006088,0.006548
7,005_S_1341,I43188,/mnt/data_lab513/vqtran_data/data/data_train_d...,AD,0.004963,0.005419
8,006_S_0484,I17377,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.007942,0.008971
9,006_S_0498,I17505,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002531,0.002767


In [7]:
# Only the two-class CN/AD table is supported; reject anything else up front.
image_targets = filter_distance_dataframe["Image Target"]
if not image_targets.isin(["CN", "AD"]).all():
    raise ValueError("This dataframe only allow 2 labels: AD and CN")

# Work on plain arrays so the result does not depend on the frame's index labels.
is_CN = (image_targets == "CN").to_numpy()
all_distance_CN = filter_distance_dataframe["DISTANCE CN"].to_numpy()
all_distance_AD = filter_distance_dataframe["DISTANCE AD"].to_numpy()

# Intra-class distance: distance of each image to the mean image of its OWN class.
intra_class_distance = np.where(is_CN, all_distance_CN, all_distance_AD)

# Inter-class distance: signed margin (distance to the OTHER class mean minus
# distance to the own class mean). A positive value means the image lies closer
# to its own class mean than to the other one.
inter_class_distance = np.where(
    is_CN,
    all_distance_AD - all_distance_CN,
    all_distance_CN - all_distance_AD,
)

intra_inter_distance_dataframe = filter_distance_dataframe.assign(
    INTRA_CLASS_DISTANCE=intra_class_distance,
    INTER_CLASS_DISTANCE=inter_class_distance,
)
intra_inter_distance_dataframe.head(10)

Unnamed: 0,Subject ID,Image ID,Image Path,Image Target,DISTANCE CN,DISTANCE AD,INTRA_CLASS_DISTANCE,INTER_CLASS_DISTANCE
0,002_S_0295,I13722,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.003989,0.004507,0.003989,0.000518
1,002_S_0413,I14437,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.007575,0.008511,0.007575,0.000935
2,002_S_0559,I15948,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.005903,0.005204,0.005903,-0.0007
3,002_S_0619,I16392,/mnt/data_lab513/vqtran_data/data/data_train_d...,AD,0.012839,0.011432,0.011432,0.001407
4,002_S_0685,I18211,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.006689,0.007408,0.006689,0.000718
5,005_S_0814,I23573,/mnt/data_lab513/vqtran_data/data/data_train_d...,AD,0.004964,0.005108,0.005108,-0.000144
6,005_S_0929,I25645,/mnt/data_lab513/vqtran_data/data/data_train_d...,AD,0.006088,0.006548,0.006548,-0.00046
7,005_S_1341,I43188,/mnt/data_lab513/vqtran_data/data/data_train_d...,AD,0.004963,0.005419,0.005419,-0.000456
8,006_S_0484,I17377,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.007942,0.008971,0.007942,0.001029
9,006_S_0498,I17505,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002531,0.002767,0.002531,0.000236


# 4. Cut off images that have a high mean square distance

In [8]:
# Boolean mask: rows whose inter-class distance is strictly positive.
positive_inter_class_distance_filter = intra_inter_distance_dataframe["INTER_CLASS_DISTANCE"].gt(0)

In [9]:
# Keep only the rows selected by the positive inter-class distance mask.
positive_inter_class_distance_dataframe = intra_inter_distance_dataframe.loc[
    positive_inter_class_distance_filter
]
print(positive_inter_class_distance_dataframe.shape[0])
positive_inter_class_distance_dataframe.iloc[:10]

359


Unnamed: 0,Subject ID,Image ID,Image Path,Image Target,DISTANCE CN,DISTANCE AD,INTRA_CLASS_DISTANCE,INTER_CLASS_DISTANCE
0,002_S_0295,I13722,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.003989,0.004507,0.003989,0.000518
1,002_S_0413,I14437,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.007575,0.008511,0.007575,0.000935
3,002_S_0619,I16392,/mnt/data_lab513/vqtran_data/data/data_train_d...,AD,0.012839,0.011432,0.011432,0.001407
4,002_S_0685,I18211,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.006689,0.007408,0.006689,0.000718
8,006_S_0484,I17377,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.007942,0.008971,0.007942,0.001029
9,006_S_0498,I17505,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002531,0.002767,0.002531,0.000236
12,006_S_0681,I23677,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.006057,0.006762,0.006057,0.000704
13,006_S_0731,I23468,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.00378,0.004405,0.00378,0.000626
14,009_S_5027,I351495,/mnt/data_lab513/vqtran_data/data/data_train_d...,AD,0.010501,0.009351,0.009351,0.001151
20,002_S_1261,I286516,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002441,0.002689,0.002441,0.000248


In [10]:
# Number of remaining images per class after the positive inter-class distance filter.
print(positive_inter_class_distance_dataframe.loc[:, "Image Target"].value_counts())

CN    211
AD    148
Name: Image Target, dtype: int64


In [11]:
# Split the filtered table into one frame per class.
CN_filter = positive_inter_class_distance_dataframe["Image Target"].eq("CN")
AD_filter = positive_inter_class_distance_dataframe["Image Target"].eq("AD")

positive_inter_class_distance_CN_dataframe = positive_inter_class_distance_dataframe.loc[CN_filter]
positive_inter_class_distance_AD_dataframe = positive_inter_class_distance_dataframe.loc[AD_filter]

# Preview the AD subset.
positive_inter_class_distance_AD_dataframe.head(5)

Unnamed: 0,Subject ID,Image ID,Image Path,Image Target,DISTANCE CN,DISTANCE AD,INTRA_CLASS_DISTANCE,INTER_CLASS_DISTANCE
3,002_S_0619,I16392,/mnt/data_lab513/vqtran_data/data/data_train_d...,AD,0.012839,0.011432,0.011432,0.001407
14,009_S_5027,I351495,/mnt/data_lab513/vqtran_data/data/data_train_d...,AD,0.010501,0.009351,0.009351,0.001151
29,098_S_0149,I10146,/mnt/data_lab513/vqtran_data/data/data_train_d...,AD,0.00544,0.004803,0.004803,0.000637
36,099_S_4994,I348826,/mnt/data_lab513/vqtran_data/data/data_train_d...,AD,0.005427,0.005424,0.005424,3e-06
40,029_S_0999,I31239,/mnt/data_lab513/vqtran_data/data/data_train_d...,AD,0.003201,0.00287,0.00287,0.000331


In [12]:
# Order each class from the smallest to the largest intra-class distance
# (sort_values sorts in ascending order by default).
positive_inter_class_distance_CN_sorted_dataframe = (
    positive_inter_class_distance_CN_dataframe.sort_values("INTRA_CLASS_DISTANCE")
)
positive_inter_class_distance_AD_sorted_dataframe = (
    positive_inter_class_distance_AD_dataframe.sort_values("INTRA_CLASS_DISTANCE")
)

In [13]:
# Ten CN images with the smallest intra-class distance.
positive_inter_class_distance_CN_sorted_dataframe.iloc[:10]

Unnamed: 0,Subject ID,Image ID,Image Path,Image Target,DISTANCE CN,DISTANCE AD,INTRA_CLASS_DISTANCE,INTER_CLASS_DISTANCE
369,036_S_0672,I19462,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.001755,0.00182,0.001755,6.460061e-05
539,052_S_1251,I38955,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.001854,0.002007,0.001854,0.0001539403
208,012_S_4642,I296878,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.001871,0.002161,0.001871,0.0002896349
104,036_S_4878,I321504,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002011,0.002112,0.002011,0.0001005826
96,023_S_0058,I9329,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002038,0.002206,0.002038,0.0001681906
296,153_S_4139,I250181,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002098,0.00211,0.002098,1.213956e-05
230,036_S_0576,I16408,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002113,0.002295,0.002113,0.0001820349
297,153_S_4151,I251754,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002174,0.002453,0.002174,0.0002793837
153,012_S_4545,I290413,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002213,0.002667,0.002213,0.0004544151
295,941_S_1203,I37688,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002314,0.002314,0.002314,5.634502e-08


Pipeline: keep rows with a positive inter-class distance -> divide the filtered dataframe into a CN dataframe and an AD dataframe -> sort each of the two dataframes by intra-class distance (ascending) -> choose the first 100 subjects in each dataframe -> combine the CN and AD dataframes -> convert to a PyTorch tensor

In [14]:
# Take the 100 rows with the smallest intra-class distance from each class
# (CN first, then AD) and stack them into one frame.
frames_to_concate = [
    sorted_class_frame.iloc[:100]
    for sorted_class_frame in (
        positive_inter_class_distance_CN_sorted_dataframe,
        positive_inter_class_distance_AD_sorted_dataframe,
    )
]
positive_inter_class_distance_sorted_dataframe = pd.concat(frames_to_concate)

positive_inter_class_distance_sorted_dataframe.head(5)

Unnamed: 0,Subject ID,Image ID,Image Path,Image Target,DISTANCE CN,DISTANCE AD,INTRA_CLASS_DISTANCE,INTER_CLASS_DISTANCE
369,036_S_0672,I19462,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.001755,0.00182,0.001755,6.5e-05
539,052_S_1251,I38955,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.001854,0.002007,0.001854,0.000154
208,012_S_4642,I296878,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.001871,0.002161,0.001871,0.00029
104,036_S_4878,I321504,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002011,0.002112,0.002011,0.000101
96,023_S_0058,I9329,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002038,0.002206,0.002038,0.000168


In [15]:
# Row count of the combined CN + AD selection.
print(positive_inter_class_distance_sorted_dataframe.shape[0])


200


In [16]:
# Replace the original row labels with a fresh 0..n-1 index so that rows can be
# addressed by position-like labels later on.
positive_inter_class_distance_sorted_dataframe = (
    positive_inter_class_distance_sorted_dataframe.reset_index(drop=True)
)
positive_inter_class_distance_sorted_dataframe.head(5)

Unnamed: 0,Subject ID,Image ID,Image Path,Image Target,DISTANCE CN,DISTANCE AD,INTRA_CLASS_DISTANCE,INTER_CLASS_DISTANCE
0,036_S_0672,I19462,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.001755,0.00182,0.001755,6.5e-05
1,052_S_1251,I38955,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.001854,0.002007,0.001854,0.000154
2,012_S_4642,I296878,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.001871,0.002161,0.001871,0.00029
3,036_S_4878,I321504,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002011,0.002112,0.002011,0.000101
4,023_S_0058,I9329,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002038,0.002206,0.002038,0.000168


# 5. Get tensor from dataframe

In [17]:
# Show the column labels of the combined CN + AD selection.
print(positive_inter_class_distance_sorted_dataframe.columns)

Index(['Subject ID', 'Image ID', 'Image Path', 'Image Target', 'DISTANCE CN',
       'DISTANCE AD', 'INTRA_CLASS_DISTANCE', 'INTER_CLASS_DISTANCE'],
      dtype='object')


In [18]:
# Sanity check: full (untruncated) image path stored in the row labelled 0.
print(positive_inter_class_distance_sorted_dataframe.at[0, 'Image Path'])

/mnt/data_lab513/vqtran_data/data/data_train_dec/unique/ADNI_036_S_0672_MR_MPRAGE_br_raw_20060724092333840_1_S17131_I19462.nii.gz


In [19]:
# Numeric class index stored for each diagnosis label.
label_to_index = {"CN": 0, "AD": 1}

# The preprocessing pipeline is identical for every image, so build it once.
image_transformation_tio = tio.transforms.Compose(
        [
            tio.transforms.Resize((110,110,110)),
            tio.ZNormalization(),
            tio.RescaleIntensity(out_min_max=(0, 1)) #), in_min_max=(0., 8957.8574))
        ]
)

# Collect per-image tensors in lists and stack once at the end, so the final
# names are only ever bound to tensors.
image_tensor_list = []
label_tensor_list = []

# Iterating over the column values directly avoids relying on the frame having
# a 0..n-1 index.
selected_rows = zip(
    positive_inter_class_distance_sorted_dataframe["Image Path"],
    positive_inter_class_distance_sorted_dataframe["Image Target"],
)
for image_absolute_path, label in selected_rows:
    if label not in label_to_index:
        # Labels other than CN/AD are skipped.
        continue

    image_array = nib.load(image_absolute_path).get_fdata()
    # Add a channel axis in front of the three spatial axes before transforming.
    image_tensor = torch.unsqueeze(torch.Tensor(image_array), 0)
    image_tensor = image_transformation_tio(image_tensor)

    # One-element float32 tensor per image: 0. for CN, 1. for AD.
    label_tensor = torch.tensor([label_to_index[label]], dtype=torch.float32)

    image_tensor_list.append(image_tensor)
    label_tensor_list.append(label_tensor)

X_tensor_cross_val = torch.stack(image_tensor_list)
# Stacking the one-element label tensors gives (n, 1); flatten it to (n,).
Y_tensor_cross_val = torch.stack(label_tensor_list).ravel()

print(X_tensor_cross_val.shape)
print(Y_tensor_cross_val.shape)

torch.Size([200, 1, 110, 110, 110])
torch.Size([200])


In [20]:
# Write the image tensor and the label tensor into the tensor output folder.
for tensor_to_save, tensor_file_name in (
    (X_tensor_cross_val, "x_tensor_NC_AD_cv_data_filter.pt"),
    (Y_tensor_cross_val, "y_tensor_NC_AD_cv_data_filter.pt"),
):
    torch.save(tensor_to_save, os.path.join(root_train_unique_tensor, tensor_file_name))

In [21]:
# Global intensity statistics of the image tensor.
tensor_mean = X_tensor_cross_val.mean()
tensor_std = X_tensor_cross_val.std()
tensor_max = X_tensor_cross_val.max()
tensor_min = X_tensor_cross_val.min()

# Distinct label values and how many samples carry each of them.
tensor_unique, tensor_count = Y_tensor_cross_val.unique(return_counts=True)

In [22]:
# Report the statistics one per line: mean, std, max, min, label values, label counts.
for reported_value in (
    tensor_mean,
    tensor_std,
    tensor_max,
    tensor_min,
    tensor_unique,
    tensor_count,
):
    print(reported_value)

tensor(0.1273)
tensor(0.2380)
tensor(1.)
tensor(0.)
tensor([0., 1.])
tensor([100, 100])
