In [11]:
import os
import math
import seaborn as sns
import numpy as np
import pandas as pd
from einops import rearrange
from tqdm import tqdm
import SimpleITK as sitk
import ipywidgets as widgets
import matplotlib.pyplot as plt
from torchsummary import summary
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import Accuracy, AUROC, Precision, Recall

In [12]:
import torch
import numpy as np
import nibabel as nib
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, dataframe, img_shape=(128, 128, 128)):
        self.df = dataframe
        self.image_paths = self.df['ADNI_path'].values
        self.labels = self.df['Group'].values

        # Binary classification mapping
        self.label_names = {'CN': 0, 'AD': 1, "MCI": 2, "EMCI": 3, "LMCI": 4}
        self.num_classes = len(self.label_names)
        self.labels_binary = self.df['Group'].map(self.label_names).values
        
        print(f"{len(self.image_paths)} Images found with classification.")
        print(f"Samples per class: {dict(zip(*np.unique(self.labels, return_counts=True)))}")
        
        self.img_shape = img_shape

    def __len__(self):
        return len(self.image_paths)

    def resize_image(self, image_np):
        """ Resize the numpy image using scipy.ndimage.zoom """
        from scipy.ndimage import zoom
        scale_factors = [n/o for n, o in zip(self.img_shape, image_np.shape)]
        resized_image_np = zoom(image_np, scale_factors, order=1)  # Linear interpolation
        return resized_image_np
    
    def crop_image_to_mask(self, image_volume):

        # Find non-zero indices in the mask
        non_zero_indices = np.argwhere(image_volume > 0)
        if non_zero_indices.size == 0:
            raise ValueError("Mask volume contains no non-zero values.")

        # Calculate bounding box
        min_indices = np.min(non_zero_indices, axis=0)
        max_indices = np.max(non_zero_indices, axis=0)

        # Crop image
        cropped_image = image_volume[
            min_indices[0]:max_indices[0] + 1,
            min_indices[1]:max_indices[1] + 1,
            min_indices[2]:max_indices[2] + 1
        ]

        return cropped_image
    
    def process_image_only(self, image_path):
        try:
            
            image_nib = nib.load(image_path)
            image_np = image_nib.get_fdata()

            image_np = self.crop_image_to_mask(image_np)

            # Resize image
            resized_image_np = self.resize_image(image_np)

            # Normalize image
            mean = np.mean(resized_image_np)
            std = np.std(resized_image_np)
            std = std if std != 0 else 1e-6
            resized_image_np = (resized_image_np - mean) / std

            # Convert to PyTorch tensor
            image_tensor = torch.tensor(resized_image_np, dtype=torch.float32).unsqueeze(0)

            return image_tensor
        
        except Exception as e:
                return None

    def process_image(self, idx):
        try:
            image_path = self.image_paths[idx]
            label = self.labels_binary[idx]

            # Load image using nibabel
            image_nib = nib.load(image_path)
            image_np = image_nib.get_fdata()

            image_np = self.crop_image_to_mask(image_np)

            # Resize image
            resized_image_np = self.resize_image(image_np)

            # Normalize image
            mean = np.mean(resized_image_np)
            std = np.std(resized_image_np)
            std = std if std != 0 else 1e-6
            resized_image_np = (resized_image_np - mean) / std

            # Convert to PyTorch tensor
            image_tensor = torch.tensor(resized_image_np, dtype=torch.float32).unsqueeze(0)

            return image_tensor, label, image_path
        
        except Exception as e:
            print(f"Exception in processing image {image_path}: {e}")
            return None, None, None

    def __getitem__(self, idx):
        img, lbl, image_path = self.process_image(idx)
        return img, lbl, image_path
    

def custom_collate(batch):
    # Remove samples that are None
    batch = [item for item in batch if item[0] is not None]
    if len(batch) == 0:
        return None  # or raise an error
    return torch.utils.data.dataloader.default_collate(batch)


In [13]:
directory_path = 'C:/ADNI_2/other_dataset/ADNI_processed/'
label_names = {'CN': 0, 'AD': 1, "MCI": 2, "EMCI": 3, "LMCI": 4}

# fix this to be specific about data
IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH = 128, 128, 128

################################################################################################################
def load_nii_paths(directory):
    image_files = []
    IDs = []

    for root, dirs, files in os.walk(directory):
        for file in files:
            if file.endswith('.nii.gz'):
                image_files.append(os.path.join(root, file))
                IDs.append(os.path.join(root, file).split("\\")[-2])
                
    return np.array(image_files), np.array(IDs)

img_file_ori_2, IDs_2 = load_nii_paths(directory_path)

################################################################################################################
#  Load the CSV file
df_2 = pd.read_csv("C:/ADNI_2/other_dataset/ADNI-2025.csv")

print(df_2["Group"].value_counts())

################################################################################################################
# Add the paths to the csv file and create filtered dataframe
main_imgs_2 = []
valid_indices_2 = []

for i in range(len(df_2)):
    img_id = df_2["Image Data ID"][i]
    
    try:
        # Check if img_id exists in IDs array
        matches = np.where(IDs_2 == img_id)[0]
        if len(matches) == 0:
            main_imgs_2.append(None)
            continue
            
        index = matches[0]
        main_imgs_2.append(img_file_ori_2[index])
        valid_indices_2.append(i)
        
    except Exception as e:
        print(f"Error processing image ID {img_id}: {str(e)}")
        main_imgs_2.append(None)

# Create new dataframe with only valid image data
df_2["ADNI_path"] = main_imgs_2
df_valid = df_2.iloc[valid_indices_2].copy()

# Print summary statistics
print(f"\nTotal records processed: {len(df_2)}")
print(f"Number of matched images: {len(df_valid)}")
print(f"Number of missing images: {len(df_2) - len(df_valid)}")

print("\nGroup distribution in filtered dataset:")
print(df_valid["Group"].value_counts())

df_valid

Group
CN      1765
EMCI    1394
LMCI     857
AD       776
MCI      717
Name: count, dtype: int64

Total records processed: 5509
Number of matched images: 5495
Number of missing images: 14

Group distribution in filtered dataset:
Group
CN      1761
EMCI    1390
LMCI     851
AD       776
MCI      717
Name: count, dtype: int64


Unnamed: 0,Image Data ID,Subject,Group,Sex,Age,Visit,Modality,Description,Type,Acq Date,Format,Downloaded,ADNI_path
0,I1293299,130_S_4294,LMCI,F,81,init,MRI,MT1; N3m,Processed,10/31/2017,NiFTI,3/04/2025,C:/ADNI_2/other_dataset/ADNI_processed/130_S_4...
1,I1293312,130_S_4294,LMCI,F,81,init,MRI,MT1; N3m,Processed,10/31/2017,NiFTI,3/04/2025,C:/ADNI_2/other_dataset/ADNI_processed/130_S_4...
2,I966298,014_S_2308,EMCI,M,81,init,MRI,MT1; GradWarp; N3m,Processed,9/01/2017,NiFTI,3/04/2025,C:/ADNI_2/other_dataset/ADNI_processed/014_S_2...
3,I966296,014_S_2308,EMCI,M,81,init,MRI,MT1; N3m,Processed,9/01/2017,NiFTI,3/04/2025,C:/ADNI_2/other_dataset/ADNI_processed/014_S_2...
4,I1293301,018_S_4399,CN,F,84,init,MRI,MT1; N3m,Processed,8/08/2017,NiFTI,3/04/2025,C:/ADNI_2/other_dataset/ADNI_processed/018_S_4...
...,...,...,...,...,...,...,...,...,...,...,...,...,...
5504,I387751,011_S_0016,CN,M,66,sc,MRI,HHP 6 DOF AC-PC registered MPRAGE,Processed,9/27/2005,NiFTI,3/04/2025,C:/ADNI_2/other_dataset/ADNI_processed/011_S_0...
5505,I412371,011_S_0005,CN,M,74,sc,MRI,HHP 6 DOF AC-PC registered MPRAGE,Processed,9/02/2005,NiFTI,3/04/2025,C:/ADNI_2/other_dataset/ADNI_processed/011_S_0...
5506,I474708,011_S_0005,CN,M,74,sc,MRI,HarP 135 final release 2015,Processed,9/02/2005,NiFTI,3/04/2025,C:/ADNI_2/other_dataset/ADNI_processed/011_S_0...
5507,I474758,011_S_0002,CN,M,74,sc,MRI,HarP 135 final release 2015,Processed,8/26/2005,NiFTI,3/04/2025,C:/ADNI_2/other_dataset/ADNI_processed/011_S_0...


In [14]:
img_size = (128, 128, 128, 1)

train_dataset = CustomDataset(df_valid, img_shape=img_size)

5495 Images found with classification.
Samples per class: {'AD': 776, 'CN': 1761, 'EMCI': 1390, 'LMCI': 851, 'MCI': 717}


In [16]:
img_path

'C:/ADNI_2/other_dataset/ADNI_processed/941_S_6052\\MT1__N3m\\2017-07-20_11_10_10.0\\I882756\\ADNI_941_S_6052_MR_MT1__N3m_Br_20170804183926749_S585807_I882756.nii.gz'

In [18]:
for i in tqdm(range(len(df_valid))):
    # Save the processed image tensor to a file
    try:
        img_path  = df_valid['ADNI_path'].iloc[i]
        img = train_dataset.process_image_only(img_path)
        if img is not None:
            output_path = img_path.replace("\\", "/").replace("ADNI_processed", "ADNI_tensors").replace('.nii.gz', '.pt')
            output_dir = os.path.dirname(output_path)
            os.makedirs(output_dir, exist_ok=True)

            torch.save(img.squeeze(0), output_path)
        else:
            print(f"Skipping image at {img_path}")
    except Exception as e:
        print(f"Error saving image at {output_path}: {str(e)}")

  3%|▎         | 171/5495 [00:54<23:48,  3.73it/s]

Skipping image at C


 13%|█▎        | 727/5495 [03:52<22:20,  3.56it/s]

Skipping image at C


 16%|█▌        | 857/5495 [04:36<26:14,  2.95it/s]

Skipping image at C


 16%|█▌        | 865/5495 [04:39<25:31,  3.02it/s]

Skipping image at C


 27%|██▋       | 1471/5495 [07:55<20:59,  3.20it/s]

Skipping image at C


 37%|███▋      | 2033/5495 [10:48<18:06,  3.19it/s]

Skipping image at C


 51%|█████▏    | 2825/5495 [15:11<15:27,  2.88it/s]

Skipping image at C


 56%|█████▋    | 3102/5495 [16:46<13:46,  2.90it/s]

Skipping image at C


 64%|██████▎   | 3503/5495 [18:54<10:54,  3.04it/s]

Skipping image at C


 68%|██████▊   | 3757/5495 [20:17<10:17,  2.81it/s]

Skipping image at C


 86%|████████▌ | 4711/5495 [25:23<04:07,  3.16it/s]

Skipping image at C


100%|██████████| 5495/5495 [29:27<00:00,  3.11it/s]
