# Modules to download

In [None]:
# !pip install matplotlib nibabel numpy opencv-python scipy scikit-image antsx SimpleITK



In [1]:
# Importation for modules.
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import cv2
from scipy.ndimage import binary_fill_holes
from skimage.morphology import remove_small_objects, convex_hull_image
from skimage.segmentation import active_contour
from skimage.filters import gaussian
import ants
import SimpleITK as sitk
from helpers import *
from antspynet.utilities import brain_extraction
import os
import csv
import shutil

### **Cleaning Demographic Data**
This process cleans the demographic data from the OASIS-2 dataset and stores the cleaned data as a CSV file in the current working directory (CWD).

**Note:** Replace the Dataset_folder variable (line 1) with the path to your demographic data excel file or use the cleaned_demographic_data.csv file aldredy present in this folder

In [None]:
Dataset_folder = "C:/Users/moksh/OneDrive/Desktop/Alzeimers/oasis_db/OAS2_RAW_PART2/OAS2_RAW_PART2" #replace with path to your dataset folder
if not os.path.exists(Dataset_folder):
        print([])  

nifti_files = []
for item in os.listdir(Dataset_folder):
    item_path = os.path.join(Dataset_folder, item)
    if os.path.isdir(item_path):  # Check if it's a directory
        nifti_files.append(item)
Dataset2_folder = "C:/Users/moksh/OneDrive/Desktop/Alzeimers/oasis_db/OAS2_RAW_PART1"
for item in os.listdir(Dataset2_folder):
    item_path = os.path.join(Dataset2_folder, item)
    if os.path.isdir(item_path):  # Check if it's a directory
        nifti_files.append(item)



header = None #storing header to write to new cleaned csv file
demographic_mri_ids = [] #stores mri mri_ids in demographic data
demographic_data = "oasis_longitudinal_demographics-8d83e569fa2e2d30.csv" #file path of uncleaned data
with open(demographic_data, 'r', newline='', encoding='utf-8') as csvfile:
    reader = csv.reader(csvfile)
    header = next(reader)  # Read header
    column_index = 1  # index of "mri id" column

    for row in reader:
        try:
            demographic_mri_ids.append(row[column_index])
        except IndexError:
            print(f"Warning: Row has fewer columns than expected.")  # Handle missing data
print(nifti_files)
print(len(demographic_mri_ids))
print(len(nifti_files))



# finding intersection
demographic_mri_ids_set = set(demographic_mri_ids)
nifti_files_set = set(nifti_files)
print(nifti_files_set-demographic_mri_ids_set)
final_data = demographic_mri_ids_set.intersection(nifti_files_set)
print(len(final_data) == len(nifti_files)) # len same as nifti files therfore all demographics data for nifti files available


#creating new csv with cleaned data
filtered_rows = []
with open(demographic_data, 'r', newline='', encoding='utf-8') as csvfile:
    reader = csv.reader(csvfile)
    header = next(reader)  # Read header (if present)
    column_index = 1  # index of "mri id" column

    for row in reader:
        try:
            if row[column_index] in final_data:
                filtered_rows.append(row)
        except IndexError:
            print(f"Warning: Row has fewer columns than expected.")  # Handle missing data

output_filepath = "cleaned_demographic_data"
with open(output_filepath, 'w', newline='', encoding='utf-8') as outfile:
    writer = csv.writer(outfile)
    writer.writerow(header)
    writer.writerows(filtered_rows)
    print(f"Filtered data written to: {output_filepath}")


['OAS2_0100_MR1', 'OAS2_0100_MR2', 'OAS2_0100_MR3', 'OAS2_0101_MR1', 'OAS2_0101_MR2', 'OAS2_0101_MR3', 'OAS2_0102_MR1', 'OAS2_0102_MR2', 'OAS2_0102_MR3', 'OAS2_0103_MR1', 'OAS2_0103_MR2', 'OAS2_0103_MR3', 'OAS2_0104_MR1', 'OAS2_0104_MR2', 'OAS2_0105_MR1', 'OAS2_0105_MR2', 'OAS2_0106_MR1', 'OAS2_0106_MR2', 'OAS2_0108_MR1', 'OAS2_0108_MR2', 'OAS2_0109_MR1', 'OAS2_0109_MR2', 'OAS2_0111_MR1', 'OAS2_0111_MR2', 'OAS2_0112_MR1', 'OAS2_0112_MR2', 'OAS2_0113_MR1', 'OAS2_0113_MR2', 'OAS2_0114_MR1', 'OAS2_0114_MR2', 'OAS2_0116_MR1', 'OAS2_0116_MR2', 'OAS2_0117_MR1', 'OAS2_0117_MR2', 'OAS2_0117_MR3', 'OAS2_0117_MR4', 'OAS2_0118_MR1', 'OAS2_0118_MR2', 'OAS2_0119_MR1', 'OAS2_0119_MR2', 'OAS2_0119_MR3', 'OAS2_0120_MR1', 'OAS2_0120_MR2', 'OAS2_0121_MR1', 'OAS2_0121_MR2', 'OAS2_0122_MR1', 'OAS2_0122_MR2', 'OAS2_0124_MR1', 'OAS2_0124_MR2', 'OAS2_0126_MR1', 'OAS2_0126_MR2', 'OAS2_0126_MR3', 'OAS2_0127_MR1', 'OAS2_0127_MR2', 'OAS2_0127_MR3', 'OAS2_0127_MR4', 'OAS2_0127_MR5', 'OAS2_0128_MR1', 'OAS2_0128_MR

### **Functions for NIfTI Display and Processing**

This section includes functions to interactively explore 3D arrays, rescale array values, add suffixes to filenames, and overlay mask contours on 3D images for visualization.


In [19]:
# Function to display nifti files
from ipywidgets import interact
def explore_3D_array(arr: np.ndarray, cmap: str='gray'):
    def fn(SLICE):
        plt.figure(figsize=(7,7))
        plt.axis('off')
        plt.imshow(arr[SLICE,:,:], cmap=cmap)

    interact(fn, SLICE=(0, arr.shape[0]-1))

def add_suffix_to_filename(filename: str, suffix: str) ->str:
    if filename.endswith('.nifti.hdr'):
        result = filename.replace('.nifti.hdr', f'_{suffix}.nifti.hdr')
        return result
    else:
        raise RuntimeError('filename with unknown ext')

def rescale_linear(array: np.ndarray, new_min: int, new_max: int):
  minimum, maximum = np.min(array), np.max(array)
  m = (new_max - new_min) / (maximum - minimum)
  b = new_min - m * minimum
  return m * array + b

def explore_3D_array_with_mask_contour(arr: np.ndarray, mask: np.ndarray, thickness: int = 1):
  _arr = rescale_linear(arr,0,1)
  _mask = rescale_linear(mask,0,1)
  _mask = _mask.astype(np.uint8)

  def fn(SLICE):
    arr_rgb = cv2.cvtColor(_arr[SLICE, :, :], cv2.COLOR_GRAY2RGB)
    contours, _ = cv2.findContours(_mask[SLICE, :, :], cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    arr_with_contours = cv2.drawContours(arr_rgb, contours, -1, (0,1,0), thickness)

    plt.figure(figsize=(7,7))
    plt.imshow(arr_with_contours)

  interact(fn, SLICE=(0, arr.shape[0]-1))

### **Brain Extraction Function**

This function reads a NIfTI image, performs brain extraction using ANTs, applies the brain mask, and returns the masked brain image as a NumPy array, preserving metadata.


In [5]:
def extracting_brain(img_path):
    ant_img = ants.image_read(img_path, reorient='RAS')
    # explore_3D_array(arr=ant_img.numpy(), cmap='nipy_spectral')
    prob_brain_mask = brain_extraction(ant_img, modality="t1")
    brain_mask = ants.get_mask(prob_brain_mask, low_thresh=0.5)
    # explore_3D_array_with_mask_contour(ant_img.numpy(), brain_mask.numpy())
    # Apply the mask
    masked = ants.mask_image(ant_img, brain_mask)

    # Convert to NumPy array
    final_mask = masked.numpy()


    # Convert back to ANTs image (to preserve metadata)
    rotated_masked_ant = ants.from_numpy(final_mask, origin=masked.origin, spacing=masked.spacing, direction=masked.direction)

    # Visualize the rotated image
    # explore_3D_array(final_mask, cmap='gray')
    final_mask.shape
    return final_mask


In [6]:
extracting_brain("C:/Users/moksh/OneDrive/Desktop/Alzeimers/oasis_db/disc1/OAS1_0001_MR1/RAW/OAS1_0001_MR1_mpr-1_anon.hdr")

array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       ...,

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0.

### **MRI Image Preprocessing**

This function processes MRI images through a series of steps, including:
1. Grayscale conversion (if needed)
2. CLAHE (Contrast Limited Adaptive Histogram Equalization) for contrast enhancement
3. Gaussian and Median blurring for noise reduction
4. Sharpening using a custom kernel
5. Non-Local Means Denoising
6. Adding salt-and-pepper noise
7. Applying PCA to reduce dimensionality of the image slices

It processes both 2D slices and 3D volumes of MRI data.


In [6]:
from sklearn.decomposition import PCA
def preprocess_mri_image(img_array):
    """
    Preprocess an MRI image with the following steps:
    1. Convert to grayscale (if needed)
    2. Apply CLAHE for contrast enhancement
    3. Apply Gaussian blur
    4. Apply Median blur for noise reduction
    5. Apply sharpening
    6. Apply Non-Local Means Denoising
    7. Add salt-and-pepper noise at the end

    :param img_array: NumPy array of the MRI image
    :return: Processed NumPy array
    """
    img_array
    if len(img_array.shape) == 3:  # Check if it's a 3D array (e.g., 128, 256, 256)
        processed_slices = []
        for i in range((img_array.shape[0]//2)-10,(img_array.shape[0]//2)+10):  # Iterate through each slice
            slice_img = img_array[i]
            processed_slices.append(process_single_slice(slice_img))
        return np.array(processed_slices)

    else:
        return process_single_slice(img_array)

def process_single_slice(slice_img):
    """
    Process a single 2D slice of the MRI image.
    """
    slice_img = np.rot90(slice_img, k=3)
    # Ensure image is single-channel grayscale
    if len(slice_img.shape) == 3 and slice_img.shape[-1] == 3:  # Check if RGB
        slice_img = cv2.cvtColor(slice_img, cv2.COLOR_BGR2GRAY)

    # Normalize and convert to uint8
    slice_img = cv2.normalize(slice_img, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

    # Apply CLAHE (Adaptive Histogram Equalization)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    slice_img = clahe.apply(slice_img)

    # Apply Gaussian Blur
    slice_img = cv2.GaussianBlur(slice_img, (5, 5), 0)

    # Apply sharpening filter
    sharpen_kernel = np.array([[0, -1, 0],
                               [-1, 5, -1],
                               [0, -1, 0]])
    slice_img = cv2.filter2D(slice_img, -1, sharpen_kernel)

    # Apply Non-Local Means Denoising
    slice_img = cv2.fastNlMeansDenoising(slice_img, None, h=10, templateWindowSize=7, searchWindowSize=21)

    # Add Salt-and-Pepper Noise
    slice_img = add_salt_and_pepper(slice_img)

    # Apply Median Blur
    slice_img = cv2.medianBlur(slice_img, 5)

    return slice_img

def add_salt_and_pepper(img, salt_prob=0.02, pepper_prob=0.02):
    """
    Function to add salt-and-pepper noise to an image.
    """
    noisy_img = img.copy()
    num_salt = int(salt_prob * img.size)
    num_pepper = int(pepper_prob * img.size)

    # Add salt (white pixels)
    coords = [np.random.randint(0, i - 1, num_salt) for i in img.shape]
    noisy_img[tuple(coords)] = 255

    # Add pepper (black pixels)
    coords = [np.random.randint(0, i - 1, num_pepper) for i in img.shape]
    noisy_img[tuple(coords)] = 0

    return noisy_img


# def apply_pca_to_slice(slice_img, n_components=115):
#     """
#     Apply PCA to a single 2D slice and return the transformed image (without reconstruction).

#     :param slice_img: 2D slice of the MRI image
#     :param n_components: Number of principal components to keep
#     :return: PCA transformed slice (without reconstruction)
#     """
#     # Step 1: Reshape the slice to a 2D array (flatten it)
#     # h, w = slice_img.shape
#     # reshaped_slice = slice_img.reshape(h, w)

#     # Step 2: Apply PCA to the reshaped slice
#     pca = PCA(n_components=n_components)
#     transformed_image = pca.fit_transform(slice_img)

#     # reconstructed_image = pca.inverse_transform(transformed_image)
#     # reconstructed_image = reconstructed_image.reshape(h, w)

#     return transformed_image

In [6]:
final=extracting_brain("OAS2_RAW_PART2/OAS2_0101_MR1/RAW/mpr-1.nifti.hdr")
final = preprocess_mri_image(final)
explore_3D_array(final)

interactive(children=(IntSlider(value=9, description='SLICE', max=19), Output()), _dom_classes=('widget-intera…

### **Preprocess and Store MRI Data**

This function walks through the dataset directory, processes MRI images by extracting the brain, and performs various preprocessing steps (such as denoising, contrast enhancement, and PCA). The processed images are saved as `.npy` files in corresponding directories which will then be fed into the CNN model.

This step takes approximately 30 minutes to run on a Ryzen 5 processor. Please run it only once to obtain the preprocessed data. Do not run it again after the initial execution.

**Note:** Replace the root_dir variable with the path to your oasis 2 dataset


In [28]:
def preprocess_and_store():
    root_dir = 'C:/Users/moksh/OneDrive/Desktop/Alzeimers/oasis_db/OAS2_RAW_PART1' #replace with path to your dataset folder
    for root, dirs, files in os.walk(root_dir):
        # print(f"Current Directory: {root}")
        
        # print(f"Subdirectories: {dirs}")
        
        # print(f"Files: {files}")
        
        if "OLD" not in root:
            for file in files:
                
                file_path = os.path.join(root, file)
                if file_path.endswith("hdr"): 
                    new = file_path.split(os.sep)
                    folder = new[1]                
                    print("Processing " + file + " in folder " + folder)                  
                    final_mask = extracting_brain(file_path)
                    processed_img = preprocess_mri_image(final_mask)                   
                    print("Done processing " + file + " in folder " + folder)
                    print(file_path)
                    processed_file_name = file_path.split(os.sep)[-3]
                    np.save("preProcessed"+"/"+processed_file_name,processed_img)
                    break

preprocess_and_store()

# reconstructed_image = pca.inverse_transform(transformed_image)
# reconstructed_image = reconstructed_image.reshape(h, w)
# explore_3D_array(np.load("C:/Users/moksh/OneDrive/Desktop/Alzeimers/Alzeimers-detection/OAS2_0100_MR1/mpr-1.np.npy"))
    

Processing mpr-1.nifti.hdr in folder OAS2_0001_MR1
Done processing mpr-1.nifti.hdr in folder OAS2_0001_MR1
C:/Users/moksh/OneDrive/Desktop/Alzeimers/oasis_db/OAS2_RAW_PART1\OAS2_0001_MR1\RAW\mpr-1.nifti.hdr
Processing mpr-1.nifti.hdr in folder OAS2_0001_MR2
Done processing mpr-1.nifti.hdr in folder OAS2_0001_MR2
C:/Users/moksh/OneDrive/Desktop/Alzeimers/oasis_db/OAS2_RAW_PART1\OAS2_0001_MR2\RAW\mpr-1.nifti.hdr
Processing mpr-1.nifti.hdr in folder OAS2_0002_MR1
Done processing mpr-1.nifti.hdr in folder OAS2_0002_MR1
C:/Users/moksh/OneDrive/Desktop/Alzeimers/oasis_db/OAS2_RAW_PART1\OAS2_0002_MR1\RAW\mpr-1.nifti.hdr
Processing mpr-1.nifti.hdr in folder OAS2_0002_MR2
Done processing mpr-1.nifti.hdr in folder OAS2_0002_MR2
C:/Users/moksh/OneDrive/Desktop/Alzeimers/oasis_db/OAS2_RAW_PART1\OAS2_0002_MR2\RAW\mpr-1.nifti.hdr
Processing mpr-1.nifti.hdr in folder OAS2_0002_MR3
Done processing mpr-1.nifti.hdr in folder OAS2_0002_MR3
C:/Users/moksh/OneDrive/Desktop/Alzeimers/oasis_db/OAS2_RAW_PART1

# mapping demographic data to the preprocessed MRI data

In [None]:
import os
import pandas as pd

def mapDemographic():
    root_dir = os.getcwd() + "/preProcessed"
    demographic_file = "C:/Users/moksh/OneDrive/Desktop/Alzeimers/Alzeimers-detection/oasis_longitudinal_demographics-8d83e569fa2e2d30.csv"

    os.makedirs("demented", exist_ok=True)
    os.makedirs("non-demented", exist_ok=True)

    df = pd.read_csv(demographic_file)

    # ses_mean = df["SES"].mean()  
    # mmse_mean = df["MMSE"].mean()  
    
    # print(int(ses_mean), int(mmse_mean))

    # df["SES"]= df["SES"].fillna(ses_mean)
    # df["MMSE"] = df["MMSE"].fillna(mmse_mean)

    df["Group"] = df["Group"].apply(lambda x: 0 if x == "Nondemented" else 1)
    # df["M/F"] = df["M/F"].apply(lambda x: 0 if x == "F" else 1)

    # df = df.drop(columns=["Subject ID", "Hand"])
    print(df.head())

    mri_data = {}
    for _, row in df.iterrows():
        mri_id = row["MRI ID"]
        data = row.values.tolist() 
        mri_data[mri_id] = data
    print(mri_data) 
    for root, dirs, files in os.walk(root_dir):        
        for file in files:
            file_path = os.path.join(root, file)
            mri_id = os.path.basename(root)
            if file_path.endswith(".npy"):
                destination="demented" if mri_data[file.strip(".npy")][2] == 1 else "non-demented"
                destination_folder = os.path.join("C:/Users/moksh/OneDrive/Desktop/Alzeimers/Alzeimers-detection", destination,file)
                try:
                    print(file_path)
                    print(destination_folder)
                    shutil.copy2(file_path, destination_folder)
                    print(f"Copied {file} to {'demented' if mri_data[file.strip(".npy")][2] == 1 else 'non-demented'}")

                except FileExistsError:
                    print(f"Folder {mri_id} already exists in destination")

# Run the function 
mapDemographic()




  Subject ID         MRI ID  Group  Visit  MR Delay M/F Hand  Age  EDUC  SES  \
0  OAS2_0001  OAS2_0001_MR1      0      1         0   M    R   87    14  2.0   
1  OAS2_0001  OAS2_0001_MR2      0      2       457   M    R   88    14  2.0   
2  OAS2_0002  OAS2_0002_MR1      1      1         0   M    R   75    12  NaN   
3  OAS2_0002  OAS2_0002_MR2      1      2       560   M    R   76    12  NaN   
4  OAS2_0002  OAS2_0002_MR3      1      3      1895   M    R   80    12  NaN   

   MMSE  CDR  eTIV   nWBV    ASF  
0  27.0  0.0  1987  0.696  0.883  
1  30.0  0.0  2004  0.681  0.876  
2  23.0  0.5  1678  0.736  1.046  
3  28.0  0.5  1738  0.713  1.010  
4  22.0  0.5  1698  0.701  1.034  
{'OAS2_0001_MR1': ['OAS2_0001', 'OAS2_0001_MR1', 0, 1, 0, 'M', 'R', 87, 14, 2.0, 27.0, 0.0, 1987, 0.696, 0.883], 'OAS2_0001_MR2': ['OAS2_0001', 'OAS2_0001_MR2', 0, 2, 457, 'M', 'R', 88, 14, 2.0, 30.0, 0.0, 2004, 0.681, 0.876], 'OAS2_0002_MR1': ['OAS2_0002', 'OAS2_0002_MR1', 1, 1, 0, 'M', 'R', 75, 12, nan, 23

In [55]:
import os
import shutil
import random

def split_data(class_name, split_ratio=0.8):
    src_dir = f'./{class_name}'
    if not os.path.exists(src_dir):
        print(f"Source folder {src_dir} does not exist!")
        return

    # Include all files (ignore extensions)
    files = [f for f in os.listdir(src_dir)]
    print(f"Found {len(files)} files in {class_name}")
    random.shuffle(files)
    split_idx = int(len(files) * split_ratio)
    train_files = files[:split_idx]
    val_files = files[split_idx:]

    for split, split_files in [('train', train_files), ('val', val_files)]:
        dst_dir = os.path.join('dataset', split, class_name)
        os.makedirs(dst_dir, exist_ok=True)
        for file in split_files:
            src_file = os.path.join(src_dir, file)
            dst_file = os.path.join(dst_dir, file)
            print(f"Copying {src_file} to {dst_file}")
            shutil.copy(src_file, dst_file)

split_data('demented')
split_data('non-demented')


Found 183 files in demented
Copying ./demented\OAS2_0071_MR2.npy to dataset\train\demented\OAS2_0071_MR2.npy
Copying ./demented\OAS2_0058_MR1.npy to dataset\train\demented\OAS2_0058_MR1.npy
Copying ./demented\OAS2_0080_MR2.npy to dataset\train\demented\OAS2_0080_MR2.npy
Copying ./demented\OAS2_0103_MR3.npy to dataset\train\demented\OAS2_0103_MR3.npy
Copying ./demented\OAS2_0176_MR3.npy to dataset\train\demented\OAS2_0176_MR3.npy
Copying ./demented\OAS2_0176_MR2.npy to dataset\train\demented\OAS2_0176_MR2.npy
Copying ./demented\OAS2_0098_MR2.npy to dataset\train\demented\OAS2_0098_MR2.npy
Copying ./demented\OAS2_0134_MR1.npy to dataset\train\demented\OAS2_0134_MR1.npy
Copying ./demented\OAS2_0031_MR2.npy to dataset\train\demented\OAS2_0031_MR2.npy
Copying ./demented\OAS2_0089_MR1.npy to dataset\train\demented\OAS2_0089_MR1.npy
Copying ./demented\OAS2_0079_MR2.npy to dataset\train\demented\OAS2_0079_MR2.npy
Copying ./demented\OAS2_0002_MR1.npy to dataset\train\demented\OAS2_0002_MR1.npy


In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from efficientnet_pytorch import EfficientNet
import numpy as np
import time
from tqdm import tqdm

class NpyDataset(Dataset):
    def __init__(self, directory, transform=None):
        self.directory = directory
        self.transform = transform
        self.data = []
        
        print(f"Loading data from {directory}...")
        start_time = time.time()
        
        for label in os.listdir(directory):
            label_path = os.path.join(directory, label)
            if os.path.isdir(label_path):
                for npy_file in os.listdir(label_path):
                    npy_path = os.path.join(label_path, npy_file)                   
                    volume = np.load(npy_path)  # Shape (20, 264, 115)
                    self.data.append((volume, label))

        print(f"Data loaded in {time.time() - start_time:.2f} seconds.")
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        volume, label = self.data[idx]
        
        label_idx = 0 if label == "demented" else 1
        
        processed_slices = []
        for i in range(volume.shape[0]):
            slice_2d = volume[i, :, :]
            # Convert to RGB for compatibility with standard image transforms
            slice_rgb = np.repeat(slice_2d[:, :, np.newaxis], 3, axis=2).astype(np.float32)
            
            if self.transform:
                slice_rgb = self.transform(slice_rgb)
                
            processed_slices.append(slice_rgb)
        
        if not isinstance(processed_slices[0], torch.Tensor):
            processed_slices = [torch.tensor(slice_arr) for slice_arr in processed_slices]
            
        volume_tensor = torch.stack(processed_slices)
        
        return volume_tensor, label_idx


class SliceBasedClassifier(nn.Module):
    def __init__(self, num_classes=2):
        super(SliceBasedClassifier, self).__init__()
        
        self.efficientnet = EfficientNet.from_pretrained('efficientnet-b2')
        
        self.features = nn.Sequential(*list(self.efficientnet.children())[:-1])
        

        feature_dim = self.efficientnet._fc.in_features
        
        self.aggregator = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), 
            nn.Flatten()
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        batch_size, num_slices = x.shape[0], x.shape[1]
        
        slice_features = []
        for i in range(num_slices):
            slice_batch = x[:, i] 
            
            features = self.efficientnet.extract_features(slice_batch)
            pooled = self.aggregator(features)  
            slice_features.append(pooled)
        
        combined_features = torch.stack(slice_features, dim=1)  
        aggregated_features = torch.mean(combined_features, dim=1)  
        
        output = self.classifier(aggregated_features)
        return output


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((260, 260)), 
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet stats
])

train_dir = 'dataset/train'
val_dir = 'dataset/val'  

train_dataset = NpyDataset(train_dir, transform=transform)
val_dataset = NpyDataset(val_dir, transform=transform)

batch_size = 2  
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

num_classes = 2  
model = SliceBasedClassifier(num_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10):
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        # Training loop
        start_epoch_time = time.time()
        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch"):
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()

            outputs = model(inputs)
            
            loss = criterion(outputs, labels)
            loss.backward()

            optimizer.step()

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            running_loss += loss.item()

        epoch_loss = running_loss / len(train_loader)
        epoch_accuracy = 100 * correct / total
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%")
        print(f"Epoch {epoch+1} took {time.time() - start_epoch_time:.2f} seconds.")

        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                
                outputs = model(inputs)
                
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss /= len(val_loader)
        val_accuracy = 100 * val_correct / val_total
        print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")


train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10)

Using device: cuda
Loading data from dataset/train...
Data loaded in 34.16 seconds.
Loading data from dataset/val...
Data loaded in 4.39 seconds.
Loaded pretrained weights for efficientnet-b2


Epoch 1/10:  20%|██        | 65/325 [24:07<1:36:28, 22.26s/batch]


KeyboardInterrupt: 

: 

MAKING TEST SET FROM OASIS 1

In [81]:
root_dir="C:/Users/moksh/OneDrive/Desktop/Alzeimers/oasis_db/disc2"
final=[]
for root, dirs, files in os.walk(root_dir):        
    for file in files:
        if file.endswith("anon.hdr"):
            path=root+"/"+file
            img = nib.load(path)
            arr = img.get_fdata()
            arr=arr.squeeze()
            arr.shape
            reoriented = np.transpose(arr, (2, 0, 1))  # Now shape is [Z, X, Y]
            # Rotate the reoriented array by 90 degrees along the Z-axis (axis=0)
            rotated = np.rot90(reoriented, k=1, axes=(1, 2))  # Rotate on axes (1, 2) to rotate the image in the Z-Y plane
            np.save("C:/Users/moksh/OneDrive/Desktop/Alzeimers/Alzeimers-detection/testRaw/"+file,rotated)
            break
print(len(final))
print(final)


0
[]


In [None]:
for root,_,files in os.walk("C:/Users/moksh/OneDrive/Desktop/Alzeimers/Alzeimers-detection/testRaw"):
    for file in files:
        ls=file.split("_")
        name=ls[0]+"_"+ls[1]+"_"+ls[2]+".npy"
        print(name)
        
        os.rename(root+"/"+file,root+"/"+name)


OAS1_0004_MR1.npy.npy.npy
OAS1_0005_MR1.npy.npy.npy
OAS1_0006_MR1.npy.npy.npy
OAS1_0007_MR1.npy.npy.npy
OAS1_0009_MR1.npy.npy.npy
OAS1_0012_MR1.npy.npy.npy
OAS1_0014_MR1.npy.npy.npy
OAS1_0017_MR1.npy.npy.npy
OAS1_0025_MR1.npy.npy.npy
OAS1_0027_MR1.npy.npy.npy
OAS1_0029_MR1.npy.npy.npy
OAS1_0037_MR1.npy.npy.npy
OAS1_0038_MR1.npy.npy.npy
OAS1_0040_MR1.npy.npy.npy
OAS1_0043_MR1.npy.npy
OAS1_0044_MR1.npy.npy
OAS1_0045_MR1.npy.npy
OAS1_0046_MR1.npy.npy
OAS1_0047_MR1.npy.npy
OAS1_0049_MR1.npy.npy
OAS1_0050_MR1.npy.npy
OAS1_0051_MR1.npy.npy
OAS1_0052_MR1.npy.npy
OAS1_0053_MR1.npy.npy
OAS1_0054_MR1.npy.npy
OAS1_0055_MR1.npy.npy
OAS1_0056_MR1.npy.npy
OAS1_0057_MR1.npy.npy
OAS1_0058_MR1.npy.npy
OAS1_0059_MR1.npy.npy
OAS1_0060_MR1.npy.npy
OAS1_0061_MR1.npy.npy
OAS1_0061_MR2.npy.npy
OAS1_0062_MR1.npy.npy
OAS1_0063_MR1.npy.npy
OAS1_0064_MR1.npy.npy
OAS1_0065_MR1.npy.npy
OAS1_0066_MR1.npy.npy
OAS1_0067_MR1.npy.npy
OAS1_0068_MR1.npy.npy
OAS1_0069_MR1.npy.npy
OAS1_0070_MR1.npy.npy
OAS1_0071_MR1.npy.np

In [91]:
import pandas as pd
df=pd.read_csv("C:/Users/moksh/OneDrive/Desktop/Alzeimers/Alzeimers-detection/oasis1demographic.csv")
df.head()
df=df.drop(["M/F","Hand","Age","Educ","SES","MMSE","eTIV","nWBV","ASF","Delay"],axis=1)
df.head()
df=df.dropna()
df.head()
dict_df = df.set_index('ID')['CDR'].to_dict()
dict_df
for root,dirs,files in os.walk("C:/Users/moksh/OneDrive/Desktop/Alzeimers/Alzeimers-detection/testRaw"):
    for file in files:
        if file.strip(".npy") in dict_df.keys():
            if dict_df[file.strip(".npy")] == 0.0:
                new="non-demented"
            else:
                new="demented"
            os.rename(root+"/"+file,root+"/"+new+"/"+file)
    break

In [None]:
explore_3D_array(np.load("C:/Users/moksh/OneDrive/Desktop/Alzeimers/Alzeimers-detection/trainRaw/non-demented/OAS1_0030_MR1.npy"))

interactive(children=(IntSlider(value=63, description='SLICE', max=127), Output()), _dom_classes=('widget-inte…