In [1]:
import torch
import torch.nn as nn
import torchio as tio 
import numpy as np
import pandas as pd
import nibabel as nib
import json

import os

# Global path settings: every dataset directory below lives under <base_path>/data.
base_path = '/mnt/data_lab513/vqtran_data'
data_dir = os.path.join(base_path, "data")

# Raw scans and the successive preprocessing stages.
root_data = os.path.join(data_dir, "raw_data", "ADNI_NIfTI")
root_bias_correction = os.path.join(data_dir, "clean_data", "mri_bias_correction")
root_bet = os.path.join(data_dir, "clean_data", "mri_brain_extraction")
root_reg = os.path.join(data_dir, "clean_data", "mri_registration")

# Metadata and training-data locations.
root_meta = os.path.join(data_dir, "meta_data")
root_train = os.path.join(data_dir, "train_data")
root_train_dec = os.path.join(data_dir, "data_train_dec", "origin")
root_train_unique = os.path.join(data_dir, "data_train_dec", "unique")
root_train_unique_tensor = os.path.join(data_dir, "data_train_dec", "tensor")



# Work

0. Read the data in (pandas / JSON + NIfTI filenames)
1. Compute the mean image of each class (CN, EMCI)
2. Compute the mean squared distance (MSE) between every image and each class-mean image
3. Save the distances as a dataframe in unique_subject_with_filter_distance_CN_EMCI.csv
4. Filter out images that are far from their own class mean
5. Create torch tensors for the dataset


# 1. Read data in & compute the class-mean images

In [2]:
# Load the subject mapping. Each value is indexed in this notebook as
# [image id, label, NIfTI filename]. The `with` block closes the file handle.
with open('../investigate/unique_dataset_dict.json', 'r') as subject_dict_file:
    subject_dict = json.load(subject_dict_file)

# Preprocessing shared by every volume: resample to 110^3, z-score, then min-max rescale to [0, 1].
# Built once here instead of once per subject inside the loop.
image_transformation_tio = tio.transforms.Compose(
    [
        tio.transforms.Resize((110, 110, 110)),
        tio.ZNormalization(),
        tio.RescaleIntensity(out_min_max=(0, 1)),
    ]
)

# Per-class running voxel sums, shaped like the channel-first (1, 110, 110, 110) transform output.
# Zeros of shape (110, 110, 110) would be broadcast to this shape on the first addition anyway.
mean_image_CN = torch.zeros(1, 110, 110, 110)
count_image_CN = 0
mean_image_EMCI = torch.zeros(1, 110, 110, 110)
count_image_EMCI = 0

for key, subject_info in subject_dict.items():
    label = subject_info[1]
    # Only the two classes of the CN-vs-EMCI task contribute to the class means.
    if label not in ("CN", "EMCI"):
        continue

    filename = subject_info[2]
    image_absolute_path = os.path.join(root_train_unique, filename)
    image_array = nib.load(image_absolute_path).get_fdata()
    # Add a leading channel axis before applying the transforms.
    image_tensor = image_transformation_tio(torch.Tensor(image_array).unsqueeze(0))

    if label == "CN":
        mean_image_CN = mean_image_CN + image_tensor
        count_image_CN += 1
    else:
        mean_image_EMCI = mean_image_EMCI + image_tensor
        count_image_EMCI += 1

# Statistics of the voxel sums before averaging (mean, max, min per class), then the class counts.
for reduce_fn in (torch.mean, torch.max, torch.min):
    print(reduce_fn(mean_image_CN))
    print(reduce_fn(mean_image_EMCI))
print(count_image_CN)
print(count_image_EMCI)

# An empty class would otherwise silently produce an all-NaN mean image (0 / 0).
if count_image_CN == 0 or count_image_EMCI == 0:
    raise ValueError(
        f"Cannot compute class means: CN count = {count_image_CN}, EMCI count = {count_image_EMCI}"
    )

mean_image_CN /= count_image_CN
mean_image_EMCI /= count_image_EMCI

# Statistics of the averaged class-mean images.
for reduce_fn in (torch.mean, torch.max, torch.min):
    print(reduce_fn(mean_image_CN))
    print(reduce_fn(mean_image_EMCI))


tensor(45.3884)
tensor(27.1659)
tensor(274.1535)
tensor(170.8169)
tensor(0.)
tensor(0.)
349
238
tensor(0.1301)
tensor(0.1141)
tensor(0.7855)
tensor(0.7177)
tensor(0.)
tensor(0.)


In [3]:
# Persist both class-mean volumes to the current working directory.
for mean_filename, mean_tensor in (
    ("mean_image_CN.pt", mean_image_CN),
    ("mean_image_EMCI.pt", mean_image_EMCI),
):
    torch.save(mean_tensor, mean_filename)

# 2. Calculate the mean squared distance between each image and the class-mean images

In [4]:
# For every CN/EMCI image, compute the voxel-wise mean squared error to each class-mean image
# and collect one record per subject, keyed by subject id.
subject_dict_with_filter_distance = {}
loss = nn.MSELoss()

# Same preprocessing as used for the class means, built once rather than per subject.
image_transformation_tio = tio.transforms.Compose(
    [
        tio.transforms.Resize((110, 110, 110)),
        tio.ZNormalization(),
        tio.RescaleIntensity(out_min_max=(0, 1)),
    ]
)

for key, subject_info in subject_dict.items():
    label = subject_info[1]
    # Only the two classes of the CN-vs-EMCI problem are scored; extend here to generalize.
    if label not in ("CN", "EMCI"):
        continue

    image_absolute_path = os.path.join(root_train_unique, subject_info[2])
    image_array = nib.load(image_absolute_path).get_fdata()
    # Add a leading channel axis before applying the transforms.
    image_tensor = image_transformation_tio(torch.Tensor(image_array).unsqueeze(0))

    distance_CN = loss(image_tensor, mean_image_CN)
    distance_EMCI = loss(image_tensor, mean_image_EMCI)

    mean_square_distance = {"DISTANCE CN": distance_CN.item(), "DISTANCE EMCI": distance_EMCI.item()}
    subject_dict_with_filter_distance[key] = {
        "Subject ID": key,
        "Image ID": subject_info[0],
        "Image Path": image_absolute_path,
        "Image Target": label,
        **mean_square_distance,
    }

print(len(subject_dict_with_filter_distance))

587


In [6]:
# Build the distance table in one shot from the per-subject records and write it to CSV.
# DataFrame.append (used row by row before) is quadratic in a loop and was removed in pandas 2.0.
filter_distance_dataframe = pd.DataFrame(list(subject_dict_with_filter_distance.values()))

filter_distance_dataframe.to_csv("../investigate/unique_subject_with_filter_distance_CN_EMCI.csv", index=False)

In [7]:
# Reload the distance table from the CSV written above and preview the first rows.
distance_csv_path = "../investigate/unique_subject_with_filter_distance_CN_EMCI.csv"
filter_distance_dataframe = pd.read_csv(distance_csv_path)
filter_distance_dataframe.head(10)

Unnamed: 0,Subject ID,Image ID,Image Path,Image Target,DISTANCE CN,DISTANCE EMCI
0,002_S_0295,I13722,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.003989,0.008199
1,002_S_0413,I14437,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.007575,0.013994
2,002_S_0559,I15948,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.005903,0.002665
3,002_S_0685,I18211,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.006689,0.012382
4,006_S_0484,I17377,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.007942,0.014387
5,006_S_0498,I17505,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002531,0.004697
6,006_S_0681,I23677,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.006057,0.011387
7,006_S_0731,I23468,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.00378,0.007872
8,009_S_4958,I338115,/mnt/data_lab513/vqtran_data/data/data_train_d...,EMCI,0.007011,0.003206
9,009_S_5000,I342850,/mnt/data_lab513/vqtran_data/data/data_train_d...,EMCI,0.00496,0.003604


In [8]:
# Intra-class distance: MSE of an image to the mean image of its own class.
# Inter-class distance: (MSE to the other class mean) - (MSE to its own class mean);
# a positive value means the image is closer to its own class mean.
image_targets = filter_distance_dataframe["Image Target"]
if not image_targets.isin(["CN", "EMCI"]).all():
    raise ValueError("This dataframe only allow 2 labels: EMCI and CN")

is_cn = image_targets.eq("CN")
distance_to_cn = filter_distance_dataframe["DISTANCE CN"]
distance_to_emci = filter_distance_dataframe["DISTANCE EMCI"]

# Vectorized equivalent of branching on each row's label.
intra_class_distance = distance_to_cn.where(is_cn, distance_to_emci)
other_class_distance = distance_to_emci.where(is_cn, distance_to_cn)
inter_class_distance = other_class_distance - intra_class_distance

intra_inter_distance_dataframe = filter_distance_dataframe.assign(
    INTRA_CLASS_DISTANCE=intra_class_distance,
    INTER_CLASS_DISTANCE=inter_class_distance,
)
intra_inter_distance_dataframe.head(10)

Unnamed: 0,Subject ID,Image ID,Image Path,Image Target,DISTANCE CN,DISTANCE EMCI,INTRA_CLASS_DISTANCE,INTER_CLASS_DISTANCE
0,002_S_0295,I13722,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.003989,0.008199,0.003989,0.00421
1,002_S_0413,I14437,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.007575,0.013994,0.007575,0.006419
2,002_S_0559,I15948,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.005903,0.002665,0.005903,-0.003238
3,002_S_0685,I18211,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.006689,0.012382,0.006689,0.005692
4,006_S_0484,I17377,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.007942,0.014387,0.007942,0.006444
5,006_S_0498,I17505,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002531,0.004697,0.002531,0.002166
6,006_S_0681,I23677,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.006057,0.011387,0.006057,0.005329
7,006_S_0731,I23468,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.00378,0.007872,0.00378,0.004092
8,009_S_4958,I338115,/mnt/data_lab513/vqtran_data/data/data_train_d...,EMCI,0.007011,0.003206,0.003206,0.003806
9,009_S_5000,I342850,/mnt/data_lab513/vqtran_data/data/data_train_d...,EMCI,0.00496,0.003604,0.003604,0.001357


# 4. Filter out images that are closer to the other class mean, then keep those closest to their own class mean

In [9]:
# True where an image is closer to its own class mean than to the other class mean.
positive_inter_class_distance_filter = intra_inter_distance_dataframe["INTER_CLASS_DISTANCE"].gt(0)

In [10]:
# Keep only the rows selected by the positive inter-class distance mask.
positive_inter_class_distance_dataframe = intra_inter_distance_dataframe.loc[positive_inter_class_distance_filter]
print(positive_inter_class_distance_dataframe.shape[0])
positive_inter_class_distance_dataframe.head(10)

375


Unnamed: 0,Subject ID,Image ID,Image Path,Image Target,DISTANCE CN,DISTANCE EMCI,INTRA_CLASS_DISTANCE,INTER_CLASS_DISTANCE
0,002_S_0295,I13722,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.003989,0.008199,0.003989,0.00421
1,002_S_0413,I14437,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.007575,0.013994,0.007575,0.006419
3,002_S_0685,I18211,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.006689,0.012382,0.006689,0.005692
4,006_S_0484,I17377,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.007942,0.014387,0.007942,0.006444
5,006_S_0498,I17505,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002531,0.004697,0.002531,0.002166
6,006_S_0681,I23677,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.006057,0.011387,0.006057,0.005329
7,006_S_0731,I23468,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.00378,0.007872,0.00378,0.004092
8,009_S_4958,I338115,/mnt/data_lab513/vqtran_data/data/data_train_d...,EMCI,0.007011,0.003206,0.003206,0.003806
9,009_S_5000,I342850,/mnt/data_lab513/vqtran_data/data/data_train_d...,EMCI,0.00496,0.003604,0.003604,0.001357
11,073_S_4986,I338196,/mnt/data_lab513/vqtran_data/data/data_train_d...,EMCI,0.024165,0.015493,0.015493,0.008672


In [11]:
# Class balance after keeping only images with a positive inter-class distance.
class_counts = positive_inter_class_distance_dataframe["Image Target"].value_counts()
print(class_counts)

CN      225
EMCI    150
Name: Image Target, dtype: int64


In [12]:
# Split the filtered frame into one frame per class.
image_targets = positive_inter_class_distance_dataframe["Image Target"]
CN_filter = image_targets.eq("CN")
EMCI_filter = image_targets.eq("EMCI")

positive_inter_class_distance_CN_dataframe = positive_inter_class_distance_dataframe.loc[CN_filter]
positive_inter_class_distance_EMCI_dataframe = positive_inter_class_distance_dataframe.loc[EMCI_filter]

positive_inter_class_distance_EMCI_dataframe.head()

Unnamed: 0,Subject ID,Image ID,Image Path,Image Target,DISTANCE CN,DISTANCE EMCI,INTRA_CLASS_DISTANCE,INTER_CLASS_DISTANCE
8,009_S_4958,I338115,/mnt/data_lab513/vqtran_data/data/data_train_d...,EMCI,0.007011,0.003206,0.003206,0.003806
9,009_S_5000,I342850,/mnt/data_lab513/vqtran_data/data/data_train_d...,EMCI,0.00496,0.003604,0.003604,0.001357
11,073_S_4986,I338196,/mnt/data_lab513/vqtran_data/data/data_train_d...,EMCI,0.024165,0.015493,0.015493,0.008672
13,002_S_4447,I278815,/mnt/data_lab513/vqtran_data/data/data_train_d...,EMCI,0.003232,0.001451,0.001451,0.001781
28,100_S_4512,I298265,/mnt/data_lab513/vqtran_data/data/data_train_d...,EMCI,0.002631,0.001854,0.001854,0.000777


In [13]:
# Order each class from the image closest to its class mean to the farthest.
positive_inter_class_distance_CN_sorted_dataframe = positive_inter_class_distance_CN_dataframe.sort_values(
    "INTRA_CLASS_DISTANCE"
)
positive_inter_class_distance_EMCI_sorted_dataframe = positive_inter_class_distance_EMCI_dataframe.sort_values(
    "INTRA_CLASS_DISTANCE"
)

In [14]:
# Preview the ten CN images closest to the CN mean image.
positive_inter_class_distance_CN_sorted_dataframe.iloc[:10]

Unnamed: 0,Subject ID,Image ID,Image Path,Image Target,DISTANCE CN,DISTANCE EMCI,INTRA_CLASS_DISTANCE,INTER_CLASS_DISTANCE
336,036_S_0672,I19462,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.001755,0.002241,0.001755,0.000485
501,052_S_1251,I38955,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.001854,0.0033,0.001854,0.001446
182,012_S_4642,I296878,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.001871,0.003392,0.001871,0.001521
73,036_S_4878,I321504,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002011,0.00265,0.002011,0.000639
70,023_S_0058,I9329,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002038,0.00366,0.002038,0.001623
333,109_S_0967,I27640,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002069,0.002817,0.002069,0.000748
260,153_S_4139,I250181,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002098,0.002423,0.002098,0.000325
205,036_S_0576,I16408,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002113,0.004042,0.002113,0.001929
39,116_S_1232,I37848,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002113,0.002406,0.002113,0.000293
261,153_S_4151,I251754,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002174,0.003996,0.002174,0.001822


Selection procedure: keep rows with a positive inter-class distance -> divide the filtered dataframe into a CN dataframe and an EMCI dataframe -> sort each by intra-class distance (ascending) -> keep the first 100 subjects of each -> concatenate the CN and EMCI dataframes.

In [15]:
# Keep the subjects closest to their own class mean in each class, CN rows first, then EMCI.
n_subjects_per_class = 100
positive_inter_class_distance_sorted_dataframe = pd.concat(
    [
        positive_inter_class_distance_CN_sorted_dataframe.head(n_subjects_per_class),
        positive_inter_class_distance_EMCI_sorted_dataframe.head(n_subjects_per_class),
    ]
)

positive_inter_class_distance_sorted_dataframe.head()

Unnamed: 0,Subject ID,Image ID,Image Path,Image Target,DISTANCE CN,DISTANCE EMCI,INTRA_CLASS_DISTANCE,INTER_CLASS_DISTANCE
336,036_S_0672,I19462,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.001755,0.002241,0.001755,0.000485
501,052_S_1251,I38955,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.001854,0.0033,0.001854,0.001446
182,012_S_4642,I296878,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.001871,0.003392,0.001871,0.001521
73,036_S_4878,I321504,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002011,0.00265,0.002011,0.000639
70,023_S_0058,I9329,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002038,0.00366,0.002038,0.001623


In [16]:
print(positive_inter_class_distance_sorted_dataframe.shape[0])


200


In [17]:
# Replace the original row labels with a fresh 0..N-1 index for positional iteration below.
positive_inter_class_distance_sorted_dataframe = (
    positive_inter_class_distance_sorted_dataframe
    .reset_index(drop=True)
)
positive_inter_class_distance_sorted_dataframe.head()

Unnamed: 0,Subject ID,Image ID,Image Path,Image Target,DISTANCE CN,DISTANCE EMCI,INTRA_CLASS_DISTANCE,INTER_CLASS_DISTANCE
0,036_S_0672,I19462,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.001755,0.002241,0.001755,0.000485
1,052_S_1251,I38955,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.001854,0.0033,0.001854,0.001446
2,012_S_4642,I296878,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.001871,0.003392,0.001871,0.001521
3,036_S_4878,I321504,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002011,0.00265,0.002011,0.000639
4,023_S_0058,I9329,/mnt/data_lab513/vqtran_data/data/data_train_d...,CN,0.002038,0.00366,0.002038,0.001623


# 5. Build tensors from the filtered dataframe

In [18]:
selected_columns = positive_inter_class_distance_sorted_dataframe.columns
print(selected_columns)

Index(['Subject ID', 'Image ID', 'Image Path', 'Image Target', 'DISTANCE CN',
       'DISTANCE EMCI', 'INTRA_CLASS_DISTANCE', 'INTER_CLASS_DISTANCE'],
      dtype='object')


In [19]:
print(positive_inter_class_distance_sorted_dataframe.at[0, 'Image Path'])

/mnt/data_lab513/vqtran_data/data/data_train_dec/unique/ADNI_036_S_0672_MR_MPRAGE_br_raw_20060724092333840_1_S17131_I19462.nii.gz


In [20]:
# Load every selected image, apply the shared preprocessing, and stack the results into
# X_tensor_cross_val (one channel-first volume per row) and Y_tensor_cross_val
# (1-D float labels: CN -> 0, EMCI -> 1).
label_to_index = {"CN": 0, "EMCI": 1}

# Same preprocessing as the earlier cells, built once rather than per image.
image_transformation_tio = tio.transforms.Compose(
    [
        tio.transforms.Resize((110, 110, 110)),
        tio.ZNormalization(),
        tio.RescaleIntensity(out_min_max=(0, 1)),
    ]
)

X_tensor_cross_val = []
Y_tensor_cross_val = []

for image_absolute_path, label in zip(
    positive_inter_class_distance_sorted_dataframe["Image Path"],
    positive_inter_class_distance_sorted_dataframe["Image Target"],
):
    # Labels outside the binary task are skipped (the filtered frame should contain none).
    if label not in label_to_index:
        continue

    image_array = nib.load(image_absolute_path).get_fdata()
    # Add a leading channel axis before applying the transforms.
    image_tensor = image_transformation_tio(torch.Tensor(image_array).unsqueeze(0))

    X_tensor_cross_val.append(image_tensor)
    Y_tensor_cross_val.append(label_to_index[label])

# torch.stack on an empty list fails with an unhelpful message; fail clearly instead.
if not X_tensor_cross_val:
    raise ValueError("No CN or EMCI images found to build the cross-validation tensors")

X_tensor_cross_val = torch.stack(X_tensor_cross_val)
# Default float dtype matches what torch.Tensor(...) produced for the labels before.
Y_tensor_cross_val = torch.tensor(Y_tensor_cross_val, dtype=torch.get_default_dtype())

print(X_tensor_cross_val.shape)
print(Y_tensor_cross_val.shape)

torch.Size([200, 1, 110, 110, 110])
torch.Size([200])


In [21]:
# Persist the cross-validation tensors. Create the output directory first so a fresh
# environment does not fail with FileNotFoundError inside torch.save.
os.makedirs(root_train_unique_tensor, exist_ok=True)
torch.save(X_tensor_cross_val, os.path.join(root_train_unique_tensor, "x_tensor_NC_EMCI_cv_data_filter.pt"))
torch.save(Y_tensor_cross_val, os.path.join(root_train_unique_tensor, "y_tensor_NC_EMCI_cv_data_filter.pt"))

In [22]:
# Global intensity statistics over all stacked images, plus the per-class label counts.
tensor_mean, tensor_std = X_tensor_cross_val.mean(), X_tensor_cross_val.std()
tensor_max, tensor_min = X_tensor_cross_val.max(), X_tensor_cross_val.min()
tensor_unique, tensor_count = Y_tensor_cross_val.unique(return_counts=True)

In [23]:
for statistic in (tensor_mean, tensor_std, tensor_max, tensor_min, tensor_unique, tensor_count):
    print(statistic)

tensor(0.1213)
tensor(0.2260)
tensor(1.)
tensor(0.)
tensor([0., 1.])
tensor([100, 100])
