# Modules to download

In [None]:
# !pip install matplotlib nibabel numpy opencv-python scipy scikit-image antsx SimpleITK



In [1]:
# Importation for modules.
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import cv2
from scipy.ndimage import binary_fill_holes
from skimage.morphology import remove_small_objects, convex_hull_image
from skimage.segmentation import active_contour
from skimage.filters import gaussian
import ants
import SimpleITK as sitk
from helpers import *
from antspynet.utilities import brain_extraction
import os
import csv
import shutil

### **Cleaning Demographic Data**
This process cleans the demographic data from the OASIS-2 dataset and stores the cleaned data as a CSV file in the current working directory (CWD).

**Note:** Replace the Dataset_folder variable (line 1) with the path to your demographic data excel file or use the cleaned_demographic_data.csv file aldredy present in this folder

In [None]:
Dataset_folder = "OAS2_RAW_PART2" #replace with path to your dataset folder
if not os.path.exists(Dataset_folder):
        print([])  

nifti_files = []
for item in os.listdir(Dataset_folder):
    item_path = os.path.join(Dataset_folder, item)
    if os.path.isdir(item_path):  # Check if it's a directory
        nifti_files.append(item)



header = None #storing header to write to new cleaned csv file
demographic_mri_ids = [] #stores mri mri_ids in demographic data
demographic_data = "oasis_longitudinal_demographics-8d83e569fa2e2d30.csv" #file path of uncleaned data
with open(demographic_data, 'r', newline='', encoding='utf-8') as csvfile:
    reader = csv.reader(csvfile)
    header = next(reader)  # Read header
    column_index = 1  # index of "mri id" column

    for row in reader:
        try:
            demographic_mri_ids.append(row[column_index])
        except IndexError:
            print(f"Warning: Row has fewer columns than expected.")  # Handle missing data
print()
print(len(demographic_mri_ids))
print(len(nifti_files))

#finding intersection
demographic_mri_ids_set = set(demographic_mri_ids)
nifti_files_set = set(nifti_files)
final_data = demographic_mri_ids_set.intersection(nifti_files_set)
print(len(final_data) == len(nifti_files)) # len same as nifti files therfore all demographics data for nifti files available


#creating new csv with cleaned data
filtered_rows = []
with open(demographic_data, 'r', newline='', encoding='utf-8') as csvfile:
    reader = csv.reader(csvfile)
    header = next(reader)  # Read header (if present)
    column_index = 1  # index of "mri id" column

    for row in reader:
        try:
            if row[column_index] in final_data:
                filtered_rows.append(row)
        except IndexError:
            print(f"Warning: Row has fewer columns than expected.")  # Handle missing data

output_filepath = "cleaned_demographic_data"
with open(output_filepath, 'w', newline='', encoding='utf-8') as outfile:
    writer = csv.writer(outfile)
    writer.writerow(header)
    writer.writerows(filtered_rows)
    print(f"Filtered data written to: {output_filepath}")



373
164
True
Filtered data written to: cleaned_demographic_data


### **Functions for NIfTI Display and Processing**

This section includes functions to interactively explore 3D arrays, rescale array values, add suffixes to filenames, and overlay mask contours on 3D images for visualization.


In [5]:
# Function to display nifti files
from ipywidgets import interact
def explore_3D_array(arr: np.ndarray, cmap: str='gray'):
    def fn(SLICE):
        plt.figure(figsize=(7,7))
        plt.axis('off')
        plt.imshow(arr[SLICE,:,:], cmap=cmap)

    interact(fn, SLICE=(0, arr.shape[0]-1))

def add_suffix_to_filename(filename: str, suffix: str) ->str:
    if filename.endswith('.nifti.hdr'):
        result = filename.replace('.nifti.hdr', f'_{suffix}.nifti.hdr')
        return result
    else:
        raise RuntimeError('filename with unknown ext')

def rescale_linear(array: np.ndarray, new_min: int, new_max: int):
  minimum, maximum = np.min(array), np.max(array)
  m = (new_max - new_min) / (maximum - minimum)
  b = new_min - m * minimum
  return m * array + b

def explore_3D_array_with_mask_contour(arr: np.ndarray, mask: np.ndarray, thickness: int = 1):
  _arr = rescale_linear(arr,0,1)
  _mask = rescale_linear(mask,0,1)
  _mask = _mask.astype(np.uint8)

  def fn(SLICE):
    arr_rgb = cv2.cvtColor(_arr[SLICE, :, :], cv2.COLOR_GRAY2RGB)
    contours, _ = cv2.findContours(_mask[SLICE, :, :], cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    arr_with_contours = cv2.drawContours(arr_rgb, contours, -1, (0,1,0), thickness)

    plt.figure(figsize=(7,7))
    plt.imshow(arr_with_contours)

  interact(fn, SLICE=(0, arr.shape[0]-1))

### **Brain Extraction Function**

This function reads a NIfTI image, performs brain extraction using ANTs, applies the brain mask, and returns the masked brain image as a NumPy array, preserving metadata.


In [2]:
def extracting_brain(img_path):
    ant_img = ants.image_read(img_path, reorient='RAS')
    # explore_3D_array(arr=ant_img.numpy(), cmap='nipy_spectral')
    prob_brain_mask = brain_extraction(ant_img, modality="t1")
    brain_mask = ants.get_mask(prob_brain_mask, low_thresh=0.5)
    # explore_3D_array_with_mask_contour(ant_img.numpy(), brain_mask.numpy())
    # Apply the mask
    masked = ants.mask_image(ant_img, brain_mask)

    # Convert to NumPy array
    final_mask = masked.numpy()


    # Convert back to ANTs image (to preserve metadata)
    rotated_masked_ant = ants.from_numpy(final_mask, origin=masked.origin, spacing=masked.spacing, direction=masked.direction)

    # Visualize the rotated image
    # explore_3D_array(final_mask, cmap='gray')
    final_mask.shape
    return final_mask


### **MRI Image Preprocessing**

This function processes MRI images through a series of steps, including:
1. Grayscale conversion (if needed)
2. CLAHE (Contrast Limited Adaptive Histogram Equalization) for contrast enhancement
3. Gaussian and Median blurring for noise reduction
4. Sharpening using a custom kernel
5. Non-Local Means Denoising
6. Adding salt-and-pepper noise
7. Applying PCA to reduce dimensionality of the image slices

It processes both 2D slices and 3D volumes of MRI data.


In [3]:
from sklearn.decomposition import PCA
def preprocess_mri_image(img_array):
    """
    Preprocess an MRI image with the following steps:
    1. Convert to grayscale (if needed)
    2. Apply CLAHE for contrast enhancement
    3. Apply Gaussian blur
    4. Apply Median blur for noise reduction
    5. Apply sharpening
    6. Apply Non-Local Means Denoising
    7. Add salt-and-pepper noise at the end

    :param img_array: NumPy array of the MRI image
    :return: Processed NumPy array
    """
    img_array
    if len(img_array.shape) == 3:  # Check if it's a 3D array (e.g., 128, 256, 256)
        processed_slices = []
        for i in range((img_array.shape[0]//2)-10,(img_array.shape[0]//2)+10):  # Iterate through each slice
            slice_img = img_array[i]
            processed_slices.append(process_single_slice(slice_img))
        return np.array(processed_slices)

    else:
        return process_single_slice(img_array)

def process_single_slice(slice_img):
    """
    Process a single 2D slice of the MRI image.
    """
    slice_img = np.rot90(slice_img, k=3)
    # Ensure image is single-channel grayscale
    if len(slice_img.shape) == 3 and slice_img.shape[-1] == 3:  # Check if RGB
        slice_img = cv2.cvtColor(slice_img, cv2.COLOR_BGR2GRAY)

    # Normalize and convert to uint8
    slice_img = cv2.normalize(slice_img, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

    # Apply CLAHE (Adaptive Histogram Equalization)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    slice_img = clahe.apply(slice_img)

    # Apply Gaussian Blur
    slice_img = cv2.GaussianBlur(slice_img, (5, 5), 0)

    # Apply sharpening filter
    sharpen_kernel = np.array([[0, -1, 0],
                               [-1, 5, -1],
                               [0, -1, 0]])
    slice_img = cv2.filter2D(slice_img, -1, sharpen_kernel)

    # Apply Non-Local Means Denoising
    slice_img = cv2.fastNlMeansDenoising(slice_img, None, h=10, templateWindowSize=7, searchWindowSize=21)

    # Add Salt-and-Pepper Noise
    slice_img = add_salt_and_pepper(slice_img)

    # Apply Median Blur
    slice_img = cv2.medianBlur(slice_img, 5)

    return slice_img

def add_salt_and_pepper(img, salt_prob=0.02, pepper_prob=0.02):
    """
    Function to add salt-and-pepper noise to an image.
    """
    noisy_img = img.copy()
    num_salt = int(salt_prob * img.size)
    num_pepper = int(pepper_prob * img.size)

    # Add salt (white pixels)
    coords = [np.random.randint(0, i - 1, num_salt) for i in img.shape]
    noisy_img[tuple(coords)] = 255

    # Add pepper (black pixels)
    coords = [np.random.randint(0, i - 1, num_pepper) for i in img.shape]
    noisy_img[tuple(coords)] = 0

    return noisy_img


# def apply_pca_to_slice(slice_img, n_components=115):
#     """
#     Apply PCA to a single 2D slice and return the transformed image (without reconstruction).

#     :param slice_img: 2D slice of the MRI image
#     :param n_components: Number of principal components to keep
#     :return: PCA transformed slice (without reconstruction)
#     """
#     # Step 1: Reshape the slice to a 2D array (flatten it)
#     # h, w = slice_img.shape
#     # reshaped_slice = slice_img.reshape(h, w)

#     # Step 2: Apply PCA to the reshaped slice
#     pca = PCA(n_components=n_components)
#     transformed_image = pca.fit_transform(slice_img)

#     # reconstructed_image = pca.inverse_transform(transformed_image)
#     # reconstructed_image = reconstructed_image.reshape(h, w)

#     return transformed_image

In [6]:
final=extracting_brain("OAS2_RAW_PART2/OAS2_0101_MR1/RAW/mpr-1.nifti.hdr")
final = preprocess_mri_image(final)
explore_3D_array(final)

interactive(children=(IntSlider(value=9, description='SLICE', max=19), Output()), _dom_classes=('widget-intera…

### **Preprocess and Store MRI Data**

This function walks through the dataset directory, processes MRI images by extracting the brain, and performs various preprocessing steps (such as denoising, contrast enhancement, and PCA). The processed images are saved as `.npy` files in corresponding directories which will then be fed into the CNN model.

This step takes approximately 30 minutes to run on a Ryzen 5 processor. Please run it only once to obtain the preprocessed data. Do not run it again after the initial execution.

**Note:** Replace the root_dir variable with the path to your oasis 2 dataset


In [7]:
def preprocess_and_store():
    root_dir = 'OAS2_RAW_PART2' #replace with path to your dataset folder
    for root, dirs, files in os.walk(root_dir):
        # print(f"Current Directory: {root}")
        
        # print(f"Subdirectories: {dirs}")
        
        # print(f"Files: {files}")
        
        if "OLD" not in root:
            for file in files:
                
                file_path = os.path.join(root, file)
                if file_path.endswith("hdr"): 
                    new = file_path.split(os.sep)
                    folder = new[1]                
                    os.makedirs(folder,exist_ok=True)
                    print("Processing " + file + " in folder " + folder)                  
                    final_mask = extracting_brain(file_path)
                    processed_img = preprocess_mri_image(final_mask)                   
                    print("Done processing " + file + " in folder " + folder)
                    processed_file_name = file_path.split(os.sep)[-1]
                    np.save(folder+"/"+processed_file_name,processed_img)
                    break

preprocess_and_store()
# reconstructed_image = pca.inverse_transform(transformed_image)
# reconstructed_image = reconstructed_image.reshape(h, w)
# explore_3D_array(np.load("C:/Users/moksh/OneDrive/Desktop/Alzeimers/Alzeimers-detection/OAS2_0100_MR1/mpr-1.np.npy"))
    

Processing mpr-1.nifti.hdr in folder OAS2_0100_MR1
Done processing mpr-1.nifti.hdr in folder OAS2_0100_MR1
Processing mpr-1.nifti.hdr in folder OAS2_0100_MR2
Done processing mpr-1.nifti.hdr in folder OAS2_0100_MR2
Processing mpr-1.nifti.hdr in folder OAS2_0100_MR3
Done processing mpr-1.nifti.hdr in folder OAS2_0100_MR3
Processing mpr-1.nifti.hdr in folder OAS2_0101_MR1
Done processing mpr-1.nifti.hdr in folder OAS2_0101_MR1
Processing mpr-1.nifti.hdr in folder OAS2_0101_MR2
Done processing mpr-1.nifti.hdr in folder OAS2_0101_MR2
Processing mpr-1.nifti.hdr in folder OAS2_0101_MR3
Done processing mpr-1.nifti.hdr in folder OAS2_0101_MR3
Processing mpr-1.nifti.hdr in folder OAS2_0102_MR1
Done processing mpr-1.nifti.hdr in folder OAS2_0102_MR1
Processing mpr-1.nifti.hdr in folder OAS2_0102_MR2
Done processing mpr-1.nifti.hdr in folder OAS2_0102_MR2
Processing mpr-1.nifti.hdr in folder OAS2_0102_MR3
Done processing mpr-1.nifti.hdr in folder OAS2_0102_MR3
Processing mpr-1.nifti.hdr in folder 

# mapping demographic data to the preprocessed MRI data

In [None]:
import os
import pandas as pd

def mapDemographic():
    root_dir = os.getcwd() + "/processed"
    demographic_file = "cleaned_demographic_data"

    df = pd.read_csv(demographic_file)

    # ses_mean = df["SES"].mean()  
    # mmse_mean = df["MMSE"].mean()  
    
    # print(int(ses_mean), int(mmse_mean))

    # df["SES"]= df["SES"].fillna(ses_mean)
    # df["MMSE"] = df["MMSE"].fillna(mmse_mean)

    df["Group"] = df["Group"].apply(lambda x: 0 if x == "Nondemented" else 1)
    # df["M/F"] = df["M/F"].apply(lambda x: 0 if x == "F" else 1)

    # df = df.drop(columns=["Subject ID", "Hand"])
    # print(df.head())

    mri_data = {}
    for _, row in df.iterrows():
        mri_id = row["MRI ID"]
        data = row.values.tolist() 
        mri_data[mri_id] = data
    print(mri_data)

    for root, dirs, files in os.walk(root_dir):
        mri_id = os.path.basename(root)
        
        for file in files:
            file_path = os.path.join(root, file)
            mri_id = os.path.basename(root)
            if file_path.endswith(".npy"):
                destination="demented" if mri_data[mri_id][2] == 1 else "non-demented"
                destination_folder = os.path.join("C:/Users/moksh/OneDrive/Desktop/Alzeimers/Alzeimers-detection", destination,mri_id)
                try:
                    print(file_path)
                    print(destination_folder)
                    shutil.copy2(file_path, destination_folder)
                    print(f"Copied {mri_id} to {'demented' if mri_data[mri_id][2] == 1 else 'non-demented'}")

                except FileExistsError:
                    print(f"Folder {mri_id} already exists in destination")

    
        

# Run the function 
mapDemographic()


{'OAS2_0100_MR1': ['OAS2_0100', 'OAS2_0100_MR1', 0, 1, 0, 'F', 'R', 77, 11, 4.0, 29.0, 0.0, 1583, 0.777, 1.108], 'OAS2_0100_MR2': ['OAS2_0100', 'OAS2_0100_MR2', 0, 2, 1218, 'F', 'R', 80, 11, 4.0, 30.0, 0.0, 1586, 0.757, 1.107], 'OAS2_0100_MR3': ['OAS2_0100', 'OAS2_0100_MR3', 0, 3, 1752, 'F', 'R', 82, 11, 4.0, 30.0, 0.0, 1590, 0.76, 1.104], 'OAS2_0101_MR1': ['OAS2_0101', 'OAS2_0101_MR1', 0, 1, 0, 'F', 'R', 71, 18, 2.0, 30.0, 0.0, 1371, 0.769, 1.28], 'OAS2_0101_MR2': ['OAS2_0101', 'OAS2_0101_MR2', 0, 2, 952, 'F', 'R', 74, 18, 2.0, 30.0, 0.0, 1400, 0.752, 1.254], 'OAS2_0101_MR3': ['OAS2_0101', 'OAS2_0101_MR3', 0, 3, 1631, 'F', 'R', 76, 18, 2.0, 30.0, 0.0, 1379, 0.757, 1.273], 'OAS2_0102_MR1': ['OAS2_0102', 'OAS2_0102_MR1', 1, 1, 0, 'M', 'R', 82, 15, 3.0, 29.0, 0.5, 1499, 0.689, 1.171], 'OAS2_0102_MR2': ['OAS2_0102', 'OAS2_0102_MR2', 1, 2, 610, 'M', 'R', 84, 15, 3.0, 29.0, 0.5, 1497, 0.686, 1.172], 'OAS2_0102_MR3': ['OAS2_0102', 'OAS2_0102_MR3', 1, 3, 1387, 'M', 'R', 86, 15, 3.0, 30.0, 0.5

FileNotFoundError: [WinError 3] The system cannot find the path specified

In [22]:
import os
import shutil

def flatten_to_new_dirs(src_root, dest_root):
    for group in ["demented", "non-demented"]:
        group_src = os.path.join(src_root, group)
        group_dest = os.path.join(dest_root, group)
        os.makedirs(group_dest, exist_ok=True)

        for folder_name in os.listdir(group_src):
            folder_path = os.path.join(group_src, folder_name)
            if not os.path.isdir(folder_path):
                continue

            # Find .npy file inside
            npy_files = [f for f in os.listdir(folder_path) if f.endswith(".npy")]
            if not npy_files:
                print(f"❌ No .npy in {folder_path}")
                continue

            npy_file_path = os.path.join(folder_path, npy_files[0])
            new_path = os.path.join(group_dest, folder_name)  # no extension

            # Copy with new name (no extension)
            shutil.copy2(npy_file_path, new_path)
            print(f"✅ Copied to: {new_path}")

# 🔁 Call with actual paths
src = r"D:\Alzeimers-detection"
dest = r"D:\Alzeimers-detection_flattened"
flatten_to_new_dirs(src, dest)


✅ Copied to: D:\Alzeimers-detection_flattened\demented\OAS2_0103_MR2
✅ Copied to: D:\Alzeimers-detection_flattened\demented\OAS2_0103_MR3
✅ Copied to: D:\Alzeimers-detection_flattened\demented\OAS2_0104_MR1
✅ Copied to: D:\Alzeimers-detection_flattened\demented\OAS2_0104_MR2
✅ Copied to: D:\Alzeimers-detection_flattened\demented\OAS2_0106_MR1
✅ Copied to: D:\Alzeimers-detection_flattened\demented\OAS2_0106_MR2
✅ Copied to: D:\Alzeimers-detection_flattened\demented\OAS2_0108_MR1
✅ Copied to: D:\Alzeimers-detection_flattened\demented\OAS2_0108_MR2
✅ Copied to: D:\Alzeimers-detection_flattened\demented\OAS2_0111_MR1
✅ Copied to: D:\Alzeimers-detection_flattened\demented\OAS2_0111_MR2
✅ Copied to: D:\Alzeimers-detection_flattened\demented\OAS2_0112_MR1
✅ Copied to: D:\Alzeimers-detection_flattened\demented\OAS2_0112_MR2
✅ Copied to: D:\Alzeimers-detection_flattened\demented\OAS2_0113_MR1
✅ Copied to: D:\Alzeimers-detection_flattened\demented\OAS2_0113_MR2
✅ Copied to: D:\Alzeimers-detectio

In [2]:
import os
import shutil
import random

def split_data(class_name, split_ratio=0.8):
    src_dir = f'./{class_name}'
    if not os.path.exists(src_dir):
        print(f"Source folder {src_dir} does not exist!")
        return

    # Include all files (ignore extensions)
    files = [f for f in os.listdir(src_dir) if os.path.isfile(os.path.join(src_dir, f))]
    print(f"Found {len(files)} files in {class_name}")

    if len(files) == 0:
        print(f"No files found in {src_dir}")
        return

    random.shuffle(files)
    split_idx = int(len(files) * split_ratio)
    train_files = files[:split_idx]
    val_files = files[split_idx:]

    for split, split_files in [('train', train_files), ('val', val_files)]:
        dst_dir = os.path.join('dataset', split, class_name)
        os.makedirs(dst_dir, exist_ok=True)
        for file in split_files:
            src_file = os.path.join(src_dir, file)
            dst_file = os.path.join(dst_dir, file)
            print(f"Copying {src_file} to {dst_file}")
            shutil.copy(src_file, dst_file)

split_data('demented')
split_data('non-demented')


Found 86 files in demented
Copying ./demented\OAS2_0145_MR2 to dataset\train\demented\OAS2_0145_MR2
Copying ./demented\OAS2_0185_MR1 to dataset\train\demented\OAS2_0185_MR1
Copying ./demented\OAS2_0140_MR2 to dataset\train\demented\OAS2_0140_MR2
Copying ./demented\OAS2_0165_MR1 to dataset\train\demented\OAS2_0165_MR1
Copying ./demented\OAS2_0120_MR2 to dataset\train\demented\OAS2_0120_MR2
Copying ./demented\OAS2_0182_MR2 to dataset\train\demented\OAS2_0182_MR2
Copying ./demented\OAS2_0139_MR1 to dataset\train\demented\OAS2_0139_MR1
Copying ./demented\OAS2_0176_MR2 to dataset\train\demented\OAS2_0176_MR2
Copying ./demented\OAS2_0140_MR3 to dataset\train\demented\OAS2_0140_MR3
Copying ./demented\OAS2_0127_MR1 to dataset\train\demented\OAS2_0127_MR1
Copying ./demented\OAS2_0103_MR1 to dataset\train\demented\OAS2_0103_MR1
Copying ./demented\OAS2_0124_MR1 to dataset\train\demented\OAS2_0124_MR1
Copying ./demented\OAS2_0137_MR1 to dataset\train\demented\OAS2_0137_MR1
Copying ./demented\OAS2_

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from efficientnet_pytorch import EfficientNet
import numpy as np
import time
from tqdm import tqdm

class NpyDataset(Dataset):
    def __init__(self, directory, transform=None):
        self.directory = directory
        self.transform = transform
        self.data = []
        
        print(f"Loading data from {directory}...")
        start_time = time.time()
        
        for label in os.listdir(directory):
            label_path = os.path.join(directory, label)
            if os.path.isdir(label_path):
                for npy_file in os.listdir(label_path):
                    npy_path = os.path.join(label_path, npy_file)                   
                    volume = np.load(npy_path)  # Shape (20, 264, 115)
                    self.data.append((volume, label))

        print(f"Data loaded in {time.time() - start_time:.2f} seconds.")
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        volume, label = self.data[idx]
        
        label_idx = 0 if label == "demented" else 1
        
        processed_slices = []
        for i in range(volume.shape[0]):
            slice_2d = volume[i, :, :]
            # Convert to RGB for compatibility with standard image transforms
            slice_rgb = np.repeat(slice_2d[:, :, np.newaxis], 3, axis=2).astype(np.float32)
            
            if self.transform:
                slice_rgb = self.transform(slice_rgb)
                
            processed_slices.append(slice_rgb)
        
        if not isinstance(processed_slices[0], torch.Tensor):
            processed_slices = [torch.tensor(slice_arr) for slice_arr in processed_slices]
            
        volume_tensor = torch.stack(processed_slices)
        
        return volume_tensor, label_idx


class SliceBasedClassifier(nn.Module):
    def __init__(self, num_classes=2):
        super(SliceBasedClassifier, self).__init__()
        
        self.efficientnet = EfficientNet.from_pretrained('efficientnet-b2')
        
        self.features = nn.Sequential(*list(self.efficientnet.children())[:-1])
        

        feature_dim = self.efficientnet._fc.in_features
        
        self.aggregator = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), 
            nn.Flatten()
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        batch_size, num_slices = x.shape[0], x.shape[1]
        
        slice_features = []
        for i in range(num_slices):
            slice_batch = x[:, i] 
            
            features = self.efficientnet.extract_features(slice_batch)
            pooled = self.aggregator(features)  
            slice_features.append(pooled)
        
        combined_features = torch.stack(slice_features, dim=1)  
        aggregated_features = torch.mean(combined_features, dim=1)  
        
        output = self.classifier(aggregated_features)
        return output


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((260, 260)), 
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet stats
])

train_dir = 'dataset/train'
val_dir = 'dataset/val'  

train_dataset = NpyDataset(train_dir, transform=transform)
val_dataset = NpyDataset(val_dir, transform=transform)

batch_size = 2  
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

num_classes = 2  
model = SliceBasedClassifier(num_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10):
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        # Training loop
        start_epoch_time = time.time()
        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch"):
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()

            outputs = model(inputs)
            
            loss = criterion(outputs, labels)
            loss.backward()

            optimizer.step()

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            running_loss += loss.item()

        epoch_loss = running_loss / len(train_loader)
        epoch_accuracy = 100 * correct / total
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%")
        print(f"Epoch {epoch+1} took {time.time() - start_epoch_time:.2f} seconds.")

        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                
                outputs = model(inputs)
                
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss /= len(val_loader)
        val_accuracy = 100 * val_correct / val_total
        print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")


train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10)

Using device: cuda
Loading data from dataset/train...
Data loaded in 34.16 seconds.
Loading data from dataset/val...
Data loaded in 4.39 seconds.
Loaded pretrained weights for efficientnet-b2


Epoch 1/10:  20%|██        | 65/325 [24:07<1:36:28, 22.26s/batch]


KeyboardInterrupt: 

: 