In [28]:
import transformers
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
import os
from nilearn.image import load_img
import pandas as pd
import gc

data_dir = 'data'
pet_dir = "data/AD_images"

In [29]:
files = {}
for file in os.listdir(pet_dir):
    if file.endswith(".nii"):
        img = load_img(os.path.join(pet_dir, file))
        patient_id = file.split(".")[0].removeprefix('AD_normalised_')
        print(f'Processing patient: {patient_id}')
        # Convert the image to a PyTorch tensor
        torch_img = torch.tensor(img.get_fdata(), dtype=torch.float32)
        files[patient_id] =torch_img

Processing patient: 002_S_5018
Processing patient: 003_S_4136
Processing patient: 003_S_4152
Processing patient: 003_S_4373
Processing patient: 003_S_4892
Processing patient: 003_S_5165
Processing patient: 003_S_5187
Processing patient: 005_S_4707
Processing patient: 005_S_4910
Processing patient: 005_S_5038
Processing patient: 005_S_5119
Processing patient: 006_S_4153
Processing patient: 006_S_4192
Processing patient: 006_S_4546
Processing patient: 006_S_4867
Processing patient: 007_S_4568
Processing patient: 007_S_4637
Processing patient: 007_S_4911
Processing patient: 007_S_5196
Processing patient: 009_S_5027
Processing patient: 009_S_5037
Processing patient: 009_S_5224
Processing patient: 009_S_5252
Processing patient: 011_S_4827
Processing patient: 011_S_4845
Processing patient: 011_S_4906
Processing patient: 011_S_4912
Processing patient: 011_S_4949
Processing patient: 013_S_5071
Processing patient: 014_S_4039
Processing patient: 014_S_4615
Processing patient: 016_S_4009
Processi

In [30]:
df = pd.read_csv(os.path.join(data_dir, 'ADNIMERGE_19Jun2025.csv'))
sex_df = df.filter(['PTID', 'PTGENDER'])
sex_map = {'Male': 0, 'Female': 1}
sex_labels = {0: 'Male', 1: 'Female'}
sex_df['PTGENDER'] = sex_df['PTGENDER'].map(sex_map)

  df = pd.read_csv(os.path.join(data_dir, 'ADNIMERGE_19Jun2025.csv'))


In [31]:
del df
gc.collect()

30

In [32]:
# Compute the number of common patients between the PET files and the sex_df
missing_patients = sex_df[~sex_df['PTID'].isin(files.keys())]
print(f'Missing patients: {len(missing_patients)}')

common_patients = sex_df['PTID'].isin(files.keys())
print(f'Common patients: {common_patients.sum()}')

print(f'Total patients {len(files)}')

Missing patients: 15684
Common patients: 737
Total patients 149


In [33]:
# Update the sex DataFrame to include a new column for the PET image data matched on PTID
for patient in files:
    img = files.get(patient)
    sex_df['PET_IMAGE'] = sex_df['PTID'].map(files) # Insert the img data ino the 'PET_IMAGE' column in sex_df for the corresponding PTID field

In [34]:
# Print the PTID for the columns for which PET_IMAGE is not None
print(f'Number of patients: {len(sex_df)}')
sex_df.dropna(subset=['PET_IMAGE'], inplace=True)
sex_df.drop_duplicates(subset=['PTID'], inplace=True)
print(f'Number of patients with PET images: {len(sex_df)}')

Number of patients: 16421
Number of patients with PET images: 149


In [35]:
# Print the shape of the PET image for the first patient
first_patient = sex_df.iloc[0]
print(f'First patient PTID: {first_patient["PTID"]}')
print(f'PET image shape: {first_patient["PET_IMAGE"].shape} with data type {type(first_patient["PET_IMAGE"])}')
print(f'Sex: {first_patient["PTGENDER"]}')

First patient PTID: 135_S_5275
PET image shape: torch.Size([101, 116, 96]) with data type <class 'torch.Tensor'>
Sex: 1


In [36]:
# Select PET_IMAGE and PTGENDER from sex_df
X = sex_df['PET_IMAGE'].tolist()  # This will be a list of torch.Tensor objects
y = sex_df['PTGENDER'].values     # This will be a numpy array of labels

print(f'X length: {len(X)}, PET image shape: {X[0].shape}, y shape: {y.shape}')

X length: 149, PET image shape: torch.Size([101, 116, 96]), y shape: (149,)


In [37]:
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, random_state=42)
print(f'Train X length: {len(X_train)}, Test X length: {len(X_test)} with shapes {X_train[0].shape}, {X_test[0].shape}')
print(f'Train y shape: {y_train.shape}, Test y shape: {y_test.shape}')

Train X length: 119, Test X length: 30 with shapes torch.Size([101, 116, 96]), torch.Size([101, 116, 96])
Train y shape: (119,), Test y shape: (30,)


In [38]:
# Define a 3D CNN to deal with images of shape (101, 116, 96)
class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = torch.nn.Conv3d(1, 16, kernel_size=5, padding='valid')
        self.conv2 = torch.nn.Conv3d(16, 32, kernel_size=5, padding=1)
        self.fc1 = torch.nn.Linear(32 * 101 * 116 * 96 // 4 // 4 // 4, 128)
        # self.fc2 = torch.nn.Linear(128, 2)  # Binary classification
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.conv1(x)
        x = torch.nn.functional.relu(x)
        x = self.conv2(x)
        x = torch.nn.functional.relu(x)
        x = torch.nn.functional.max_pool3d(x, kernel_size=2, stride=2)
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = self.fc1(x)
        x = torch.nn.functional.relu(x)
        # x = self.fc2(x)
        x = self.sigmoid(x)
        return x

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = CNN().to(device)
criterion = torch.nn.BCELoss()
inputs = torch.stack([torch.unsqueeze(img, 0) for img in X_train], dim=0)  # Add channel dimension
labels = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)  # Convert labels to float and add a channel dimension
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epochs = 100
train_losses = []

for epoch in range(1, epochs + 1):
    model.train()
    optimizer.zero_grad()

    # Convert the training data to PyTorch tensors
    # inputs = torch.tensor(X_train, dtype=torch.float32).unsqueeze(1).to(device)  # Add channel dimension
    # labels = torch.tensor(y_train.values, dtype=torch.float32).to(device)
    inputs, labels = inputs.to(device), labels.to(device)

    outputs = model(inputs)
    loss = criterion(outputs, labels)

    loss.backward()
    optimizer.step()

    train_losses.append(loss.item())

    if epoch % 10 == 0:
        print(f'Epoch [{epoch}/{epochs}], Loss: {loss.item():.4f}')