In [4]:
import h5py
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
import cv2
import os
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.model_selection import train_test_split
import shutil
import glob
import SimpleITK as sitk

In [5]:
# read metadata
path = "/media/yesindeed/WD5T/data/ADIN/1.5T/"

demo_data = pd.read_csv(path + "ADNI1_Screening_1.5T_2_20_2024.csv")
# demo_data = pd.read_csv(path + "ADNI1_Complete_1Yr_1.5T_5_02_2024.csv")
demo_data

Unnamed: 0,Image Data ID,Subject,Group,Sex,Age,Visit,Modality,Description,Type,Acq Date,Format,Downloaded
0,I63897,941_S_1363,MCI,F,70,sc,MRI,MPR; GradWarp; B1 Correction; N3; Scaled,Processed,3/12/2007,NiFTI,1/11/2024
1,I97327,941_S_1311,MCI,M,69,sc,MRI,MPR; GradWarp; B1 Correction; N3; Scaled,Processed,3/02/2007,NiFTI,1/11/2024
2,I63888,941_S_1295,MCI,M,77,sc,MRI,MPR; GradWarp; B1 Correction; N3; Scaled,Processed,2/09/2007,NiFTI,1/11/2024
3,I63879,941_S_1203,CN,M,83,sc,MRI,MPR; GradWarp; B1 Correction; N3; Scaled,Processed,1/29/2007,NiFTI,1/11/2024
4,I63874,941_S_1202,CN,M,78,sc,MRI,MPR-R; GradWarp; B1 Correction; N3; Scaled,Processed,1/30/2007,NiFTI,1/11/2024
...,...,...,...,...,...,...,...,...,...,...,...,...
1070,I118676,002_S_0559,CN,M,79,sc,MRI,MPR; GradWarp; B1 Correction; N3; Scaled_2,Processed,5/23/2006,NiFTI,2/20/2024
1071,I45117,002_S_0413,CN,F,76,sc,MRI,MPR; GradWarp; B1 Correction; N3; Scaled,Processed,5/02/2006,NiFTI,2/20/2024
1072,I118673,002_S_0413,CN,F,76,sc,MRI,MPR; GradWarp; B1 Correction; N3; Scaled_2,Processed,5/02/2006,NiFTI,2/20/2024
1073,I118671,002_S_0295,CN,M,85,sc,MRI,MPR; GradWarp; B1 Correction; N3; Scaled_2,Processed,4/18/2006,NiFTI,2/20/2024


In [6]:
# only keep screening subset
demo_data = demo_data.loc[demo_data["Visit"] == "sc"]
# remove MCI
demo_data = demo_data.loc[demo_data["Group"] != "MCI"]

demo_data = demo_data.reset_index(drop=True)
demo_data

Unnamed: 0,Image Data ID,Subject,Group,Sex,Age,Visit,Modality,Description,Type,Acq Date,Format,Downloaded
0,I63879,941_S_1203,CN,M,83,sc,MRI,MPR; GradWarp; B1 Correction; N3; Scaled,Processed,1/29/2007,NiFTI,1/11/2024
1,I63874,941_S_1202,CN,M,78,sc,MRI,MPR-R; GradWarp; B1 Correction; N3; Scaled,Processed,1/30/2007,NiFTI,1/11/2024
2,I66462,941_S_1197,CN,F,82,sc,MRI,MPR; GradWarp; B1 Correction; N3; Scaled,Processed,1/20/2007,NiFTI,1/11/2024
3,I63865,941_S_1195,CN,M,77,sc,MRI,MPR-R; GradWarp; B1 Correction; N3; Scaled,Processed,2/08/2007,NiFTI,1/11/2024
4,I63847,941_S_1194,CN,M,85,sc,MRI,MPR; GradWarp; B1 Correction; N3; Scaled,Processed,1/20/2007,NiFTI,1/11/2024
...,...,...,...,...,...,...,...,...,...,...,...,...
545,I118676,002_S_0559,CN,M,79,sc,MRI,MPR; GradWarp; B1 Correction; N3; Scaled_2,Processed,5/23/2006,NiFTI,2/20/2024
546,I45117,002_S_0413,CN,F,76,sc,MRI,MPR; GradWarp; B1 Correction; N3; Scaled,Processed,5/02/2006,NiFTI,2/20/2024
547,I118673,002_S_0413,CN,F,76,sc,MRI,MPR; GradWarp; B1 Correction; N3; Scaled_2,Processed,5/02/2006,NiFTI,2/20/2024
548,I118671,002_S_0295,CN,M,85,sc,MRI,MPR; GradWarp; B1 Correction; N3; Scaled_2,Processed,4/18/2006,NiFTI,2/20/2024


In [7]:
demo_data["Sex"].value_counts()

Sex
M    289
F    261
Name: count, dtype: int64

In [8]:
demo_data["Age"].value_counts()

Age
71    53
78    44
73    42
72    39
77    36
76    32
80    32
75    31
74    28
85    26
70    22
83    19
79    19
81    14
82    14
86    12
84    12
66    11
69     9
64     8
65     6
87     5
88     5
68     5
63     4
62     4
60     4
57     3
89     3
56     2
90     2
59     2
55     1
91     1
Name: count, dtype: int64

In [64]:
len(demo_data["Subject"].unique())

417

In [66]:
import tqdm
import SimpleITK as sitk

# put all MRI into a single folder
for i in tqdm.tqdm(range(len(demo_data))):
    item = demo_data.iloc[i]

    id = item["Image Data ID"]
    subject = item["Subject"]

    org_path = glob.glob(os.path.join(
        path, "ADNI", subject, "*/*", id, "*/*.nii"))
    # print(os.path.join(subject, "*/*", id, "*.nii"))
    assert len(org_path) == 1, f"{id} of {subject}"

    array = sitk.GetArrayFromImage(sitk.ReadImage(org_path))
    dtype = array.dtype
    array = torch.from_numpy(array).float()

    array = array.permute(0, 3, 1, 2)

    if array.shape[2] != 224 or array.shape[3] != 224:
        array = F.interpolate(array, size=(224, 224), mode="bicubic")

    # resize to 34 since we need to stack nearest two slices for fake RGB input
    # the actual input to NN is 32 x 512 x 512
    array = F.interpolate(array.unsqueeze(0), size=(
        24, 224, 224), mode="trilinear").squeeze()
    # array = array.permute(1, 2, 0)
    array = array.numpy().astype(dtype)
    # print(array.shape)
    image = sitk.GetImageFromArray(array)

    sitk.WriteImage(image, os.path.join(path, "preprocessed", f"{id}.nii.gz"))

    # break

100%|██████████| 550/550 [13:23<00:00,  1.46s/it]  


In [36]:
def split_82(all_meta, max_test_size_per_group):
    males = all_meta[all_meta["Sex"] == "M"]
    females = all_meta[all_meta["Sex"] == "F"]

    males_train, males_test = train_test_split(
        np.unique(males["Image Data ID"]), test_size=max_test_size_per_group, random_state=0
    )
    females_train, females_test = train_test_split(
        np.unique(females["Image Data ID"]), test_size=max_test_size_per_group, random_state=0
    )

    sub_train = np.concatenate([males_train, females_train])
    sub_test = np.concatenate([males_test, females_test])

    train_meta = all_meta[all_meta["Image Data ID"].isin(sub_train)]
    test_meta = all_meta[all_meta["Image Data ID"].isin(sub_test)]

    return train_meta, test_meta


unique_males = demo_data[demo_data["Sex"] == "M"]["Image Data ID"].nunique()
unique_females = demo_data[demo_data["Sex"] == "F"]["Image Data ID"].nunique()
max_test_size_per_group = min(
    int(0.2 * unique_females), int(0.2 * unique_males))

# sub_train_meta, sub_test_meta = split_82(demo_data, max_test_size_per_group)


sub_train, sub_test = split_82(demo_data, max_test_size_per_group)

sub_train.to_csv(os.path.join(path, "train.csv"), index=False)
sub_test.to_csv(os.path.join(path, "test.csv"), index=False)

In [3]:
# age

df_test = pd.read_csv(os.path.join(path, "test.csv"))

df_test = df_test[~df_test["Age"].isnull()]

df_test["age_binary"] = df_test["Age"].values.astype("float")
df_test["age_binary"] = np.where(df_test["age_binary"].between(-1, 60), 0, df_test["age_binary"])
df_test["age_binary"] = np.where(df_test["age_binary"] >= 60, 1, df_test["age_binary"])

class_counts = df_test["age_binary"].value_counts()
print(class_counts)
min_count = class_counts.min()
balanced_test_meta = df_test.groupby("age_binary").apply(lambda x: x.sample(min_count)).reset_index(drop=True)


balanced_test_meta.to_csv(os.path.join(path, "test_age.csv"), index=False)

balanced_test_meta["age_binary"].value_counts()

age_binary
1.0    104
Name: count, dtype: int64


  balanced_test_meta = df_test.groupby("age_binary").apply(lambda x: x.sample(min_count)).reset_index(drop=True)


age_binary
1.0    104
Name: count, dtype: int64