In [1]:
import os
import pandas as pd
import numpy as np
import pydicom

def read_image_df(data_dir):
    """
    Recursively scan ``data_dir`` and build a DataFrame of DICOM file paths and labels.

    A file is included when its name ends with ``.dcm`` (case-insensitive) and
    its immediate parent folder is named ``'0'`` or ``'1'``; that folder name,
    converted to ``int``, becomes the label.

    Args:
        data_dir (str): Root directory to walk.

    Returns:
        pd.DataFrame: One row per matching file with columns ``file_path`` and
        ``label``. Both columns are present even when nothing matches, so
        downstream code can always access ``df['label']``.
    """
    records = []
    for root, dirs, files in os.walk(data_dir):
        # os.walk yields entries in arbitrary filesystem order. Sorting `dirs`
        # in place (and `files` below) makes the row order, and therefore any
        # seeded split derived from it, reproducible across machines.
        dirs.sort()
        # The folder name (last part of root) is used as the label
        label = os.path.basename(root)
        # Only folders named exactly '0' or '1' hold labelled images
        if label not in ['0', '1']:
            continue
        for file in sorted(files):
            if file.lower().endswith('.dcm'):
                records.append({
                    "file_path": os.path.join(root, file),
                    "label": int(label),
                })
    # Passing `columns` keeps the schema stable when `records` is empty.
    return pd.DataFrame(records, columns=["file_path", "label"])

# Directory that contains the labeled folders ('0' and '1').
# Set the DICOM_DATA_DIR environment variable to run the notebook on another
# machine; the original local path is kept as the fallback so existing runs
# behave exactly as before.
data_dir = os.environ.get('DICOM_DATA_DIR', '/Users/yavuzalpdemirci/Desktop/data_for_testing')

# 1. Create the Image DataFrame: file_path and label (extracted from folder names)
image_df = read_image_df(data_dir)
print("Image DataFrame shape:", image_df.shape)

Image DataFrame shape: (84, 2)


In [2]:
# Preview the first rows (file_path, label) of the image table.
image_df.head()

Unnamed: 0,file_path,label
0,/Users/yavuzalpdemirci/Desktop/data_for_testin...,0
1,/Users/yavuzalpdemirci/Desktop/data_for_testin...,0
2,/Users/yavuzalpdemirci/Desktop/data_for_testin...,0
3,/Users/yavuzalpdemirci/Desktop/data_for_testin...,0
4,/Users/yavuzalpdemirci/Desktop/data_for_testin...,0


In [3]:
# Column dtypes, non-null counts and memory footprint of the image table.
image_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 84 entries, 0 to 83
Data columns (total 2 columns):
 #   Column     Non-Null Count  Dtype 
---  ------     --------------  ----- 
 0   file_path  84 non-null     object
 1   label      84 non-null     int64 
dtypes: int64(1), object(1)
memory usage: 1.4+ KB


In [4]:
# Path stored at index label 1 of the image table.
image_df.loc[1, "file_path"]

'/Users/yavuzalpdemirci/Desktop/data_for_testing/0/10023.dcm'

In [5]:
def process_window_value(window_value):
    """
    Normalise a DICOM numeric element value to a list of plain Python floats.

    Despite its name, this helper is used for WindowCenter, WindowWidth and
    PixelSpacing alike. Every element is cast with ``float`` so single-valued
    and multi-valued fields end up with the same element type (previously the
    multi-valued branch kept the raw pydicom elements, giving cells such as
    ``[35, 700]`` next to ``[35.0]`` in the same column).

    Args:
        window_value: A single number (or numeric string), or an iterable of
            them such as a pydicom ``MultiValue``, a list or a tuple.

    Returns:
        list[float]: One float per value; a single-element list for scalars.

    Raises:
        ValueError / TypeError: Propagated from ``float`` when a value is not numeric.
    """
    from collections.abc import Iterable

    # Strings/bytes are iterable but represent one number, so treat them as scalars.
    is_multi = isinstance(window_value, Iterable) and not isinstance(window_value, (str, bytes))
    if is_multi:
        return [float(item) for item in window_value]
    return [float(window_value)]

def read_metadata_df(data_dir):
    """
    Recursively scan ``data_dir`` and collect header metadata for every ``.dcm`` file.

    Files whose header cannot be read, or that lack any of the extracted
    attributes, are reported with ``print`` and left out of the result.

    Args:
        data_dir (str): Root directory to walk.

    Returns:
        pd.DataFrame: One row per successfully read file. The columns listed
        in ``columns`` below are always present, even when no file was read.
    """
    # Fixed schema so an empty scan still yields a frame with 'file_path' etc.
    columns = [
        "file_path",
        "SliceThickness",
        "SamplesPerPixel",
        "Photometric Interpretation",
        "PixelSpacing",
        "BitsAllocated",
        "BitsStored",
        "HighBit",
        "PixelRepresentation",
        "WindowCenter",
        "WindowWidth",
        "RescaleIntercept",
        "RescaleSlope",
        "RescaleType",
    ]
    metadata_list = []
    for root, dirs, files in os.walk(data_dir):
        # Sort in place so the traversal order (and row order) is deterministic.
        dirs.sort()
        for file in sorted(files):
            if not file.lower().endswith('.dcm'):
                continue
            file_path = os.path.join(root, file)
            try:
                # Only header attributes are read below, so skip the pixel data.
                ds = pydicom.dcmread(file_path, stop_before_pixels=True)
                # Extract all the relevant metadata fields
                metadata = {
                    "file_path": file_path,
                    "SliceThickness": float(ds.SliceThickness),
                    "SamplesPerPixel": float(ds.SamplesPerPixel),
                    "Photometric Interpretation": ds.PhotometricInterpretation,
                    "PixelSpacing": process_window_value(ds.PixelSpacing),
                    "BitsAllocated": float(ds.BitsAllocated),
                    "BitsStored": float(ds.BitsStored),
                    "HighBit": float(ds.HighBit),
                    "PixelRepresentation": float(ds.PixelRepresentation),
                    "WindowCenter":  process_window_value(ds.WindowCenter),
                    "WindowWidth": process_window_value(ds.WindowWidth),
                    "RescaleIntercept": float(ds.RescaleIntercept),
                    "RescaleSlope": float(ds.RescaleSlope),
                    "RescaleType": ds.RescaleType
                }
                metadata_list.append(metadata)
            except Exception as e:
                # Best effort: report the failing file and keep scanning.
                print(f"Error reading metadata from {file_path}: {e}")
    return pd.DataFrame(metadata_list, columns=columns)


# 2. Create the Metadata DataFrame: one row of DICOM header fields per .dcm file under data_dir.
metadata_df = read_metadata_df(data_dir)
print("Metadata DataFrame shape:", metadata_df.shape)

Metadata DataFrame shape: (84, 14)


In [6]:
# Preview the extracted DICOM header fields.
metadata_df.head()

Unnamed: 0,file_path,SliceThickness,SamplesPerPixel,Photometric Interpretation,PixelSpacing,BitsAllocated,BitsStored,HighBit,PixelRepresentation,WindowCenter,WindowWidth,RescaleIntercept,RescaleSlope,RescaleType
0,/Users/yavuzalpdemirci/Desktop/data_for_testin...,4.0,1.0,MONOCHROME2,"[0.42485546875, 0.42485546875]",16.0,16.0,15.0,0.0,[35.0],[80.0],-8192.0,1.0,HU
1,/Users/yavuzalpdemirci/Desktop/data_for_testin...,5.0,1.0,MONOCHROME2,"[0.44140625, 0.44140625]",16.0,12.0,11.0,0.0,"[35, 700]","[80, 3200]",-1024.0,1.0,US
2,/Users/yavuzalpdemirci/Desktop/data_for_testin...,2.5,1.0,MONOCHROME2,"[0.488281, 0.488281]",16.0,16.0,15.0,1.0,[40.0],[100.0],-1024.0,1.0,HU
3,/Users/yavuzalpdemirci/Desktop/data_for_testin...,5.0,1.0,MONOCHROME2,"[0.449, 0.449]",16.0,16.0,15.0,1.0,[40.0],[120.0],0.0,1.0,US
4,/Users/yavuzalpdemirci/Desktop/data_for_testin...,5.0,1.0,MONOCHROME2,"[0.455078125, 0.455078125]",16.0,12.0,11.0,0.0,"[35, 700]","[80, 3200]",-1024.0,1.0,US


In [7]:
from sklearn.model_selection import train_test_split, StratifiedKFold

def split_dataset(df):
    """
    Partition ``df`` into stratified train / validation / test subsets
    (roughly 70% / 20% / 10% of the rows).

    Both splitting stages stratify on the 'label' column and use
    ``random_state=42``.

    Returns:
        tuple: ``(train, val, test)`` DataFrames.
    """
    # Stage 1: hold out 10% of all rows as the test subset.
    remainder, test_part = train_test_split(
        df,
        test_size=0.10,
        stratify=df['label'],
        random_state=42,
    )
    # Stage 2: the remainder is 90% of the data, so 22.22% of it is
    # approximately 20% of the full dataset for validation.
    train_part, val_part = train_test_split(
        remainder,
        test_size=0.2222,
        stratify=remainder['label'],
        random_state=42,
    )
    return train_part, val_part, test_part

# split_dataset returns (train, val, test): x = train, y = validation, z = test.
x, y, z=split_dataset(image_df)

In [8]:
# Report the (rows, columns) shape of each split.
print(f'Train Dataset Size: {x.shape}')
print(f'Val Dataset Size: {y.shape}')
print(f'Test Dataset Size: {z.shape}')

Train Dataset Size: (58, 2)
Val Dataset Size: (17, 2)
Test Dataset Size: (9, 2)


In [9]:
def get_stratified_kfold(df, n_splits=5):
    """
    Build shuffled, stratified k-fold splits of ``df`` based on its 'label' column.

    Args:
        df (pd.DataFrame): Frame with a 'label' column.
        n_splits (int): Number of folds.

    Returns:
        list: One ``(train_indices, test_indices)`` pair per fold, as produced
        by ``StratifiedKFold.split``.

    NOTE: per scikit-learn's documentation the yielded indices are positions
    into the arrays passed to ``split``, not ``df.index`` labels -- select
    rows with ``df.iloc[...]``.
    """
    splitter = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    labels = df['label'].values
    row_ids = df.index.values
    folds = [fold for fold in splitter.split(row_ids, labels)]
    return folds

# Stratified k-fold splits over x (the first frame returned by split_dataset); n_splits defaults to 5.
cross = get_stratified_kfold(x)

In [11]:
# Path stored at index label 0 of the metadata table.
metadata_df.loc[0, 'file_path']

'/Users/yavuzalpdemirci/Desktop/data_for_testing/0/10022.dcm'

In [14]:
# Select the metadata row(s) whose file_path matches one specific DICOM file.
metadata_row = metadata_df.loc[metadata_df['file_path'].eq('/Users/yavuzalpdemirci/Desktop/data_for_testing/0/10023.dcm')]

In [15]:
# Display the selected metadata row.
metadata_row

Unnamed: 0,file_path,SliceThickness,SamplesPerPixel,Photometric Interpretation,PixelSpacing,BitsAllocated,BitsStored,HighBit,PixelRepresentation,WindowCenter,WindowWidth,RescaleIntercept,RescaleSlope,RescaleType
1,/Users/yavuzalpdemirci/Desktop/data_for_testin...,5.0,1.0,MONOCHROME2,"[0.44140625, 0.44140625]",16.0,12.0,11.0,0.0,"[35, 700]","[80, 3200]",-1024.0,1.0,US


In [16]:
# Check the scalar type stored in the RescaleSlope column.
type(metadata_row['RescaleSlope'].values[0])

numpy.float64

In [17]:
# Column dtypes and non-null counts of the metadata table.
metadata_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 84 entries, 0 to 83
Data columns (total 14 columns):
 #   Column                      Non-Null Count  Dtype  
---  ------                      --------------  -----  
 0   file_path                   84 non-null     object 
 1   SliceThickness              84 non-null     float64
 2   SamplesPerPixel             84 non-null     float64
 3   Photometric Interpretation  84 non-null     object 
 4   PixelSpacing                84 non-null     object 
 5   BitsAllocated               84 non-null     float64
 6   BitsStored                  84 non-null     float64
 7   HighBit                     84 non-null     float64
 8   PixelRepresentation         84 non-null     float64
 9   WindowCenter                84 non-null     object 
 10  WindowWidth                 84 non-null     object 
 11  RescaleIntercept            84 non-null     float64
 12  RescaleSlope                84 non-null     float64
 13  RescaleType                 84 non-nu

In [18]:
# List the columns that select_dtypes treats as 'int' or 'float'.
metadata_df.select_dtypes(include=['int', 'float']).columns

Index(['SliceThickness', 'SamplesPerPixel', 'BitsAllocated', 'BitsStored',
       'HighBit', 'PixelRepresentation', 'RescaleIntercept', 'RescaleSlope'],
      dtype='object')

In [19]:
# Downcast every column selected as 'int' / 'float' to float32.
float32_candidates = metadata_df.select_dtypes(include=['int', 'float']).columns
for col in float32_candidates:
    downcast_column = metadata_df[col].astype('float32')
    metadata_df[col] = downcast_column

In [20]:
import numpy as np
import pandas as pd
import pydicom
import cv2 

def apply_window(hu_image, center, width):
    """
    Clip an image to an intensity window and rescale it linearly to [0, 1].

    The window spans ``[center - width / 2, center + width / 2]``. True
    division is used for the half-width so that odd or fractional widths
    cover the full ``width`` (floor division made the window narrower than
    the divisor, so the output could never reach 1.0).

    Args:
        hu_image (np.ndarray): Image values (Hounsfield Units at the call
            sites in this file).
        center (float): Window centre.
        width (float): Full window width; must be positive.

    Returns:
        np.ndarray: Array of the same shape as ``hu_image`` with values in [0, 1].

    Raises:
        ValueError: If ``width`` is not positive.
    """
    if width <= 0:
        raise ValueError(f"Window width must be positive, got {width}")
    half_width = width / 2
    min_value = center - half_width
    max_value = center + half_width
    windowed = np.clip(hu_image, min_value, max_value)
    # Normalize to 0-1 range
    return (windowed - min_value) / width

def preprocess_with_metadata_from_df(dicom_path, metadata_df, target_shape=(512, 512)):
    """
    Read a DICOM file and build a 4-channel image resized to ``target_shape``.

    Channels (last axis), in order:
        0. brain window     (center=40, width=80)
        1. subdural window  (center=80, width=200)
        2. stroke window    (center=50, width=50)
        3. constant map holding ``SliceThickness / PixelSpacing[0]``

    Args:
        dicom_path (str): Path to the DICOM file.
        metadata_df (pd.DataFrame): DataFrame containing metadata for the images;
            must have a row whose 'file_path' equals ``dicom_path``.
        target_shape (tuple): Desired (height, width) for the output image.

    Returns:
        np.ndarray: Processed image with shape (height, width, 4), i.e.
        (512, 512, 4) for the default ``target_shape``.

    Raises:
        ValueError: If ``metadata_df`` has no row for ``dicom_path``.
    """
    # Find the metadata row for this specific file
    metadata_row = metadata_df[metadata_df['file_path'] == dicom_path]

    if metadata_row.empty:
        raise ValueError(f"No metadata found for {dicom_path} in the DataFrame")

    # Extract metadata values from the DataFrame
    rescale_slope = metadata_row['RescaleSlope'].values[0]
    rescale_intercept = metadata_row['RescaleIntercept'].values[0]
    slice_thickness = metadata_row['SliceThickness'].values[0]
    pixel_spacing = metadata_row['PixelSpacing'].values[0][0]  # Extract first value if it's a list

    # Load DICOM and get pixel array
    ds = pydicom.dcmread(dicom_path)
    pixel_array = ds.pixel_array

    # Convert to Hounsfield Units using rescale values
    hu_image = pixel_array * float(rescale_slope) + float(rescale_intercept)

    # Apply windowing for different medical image views
    windows = [
        apply_window(hu_image, center=40, width=80),    # brain
        apply_window(hu_image, center=80, width=200),   # subdural
        apply_window(hu_image, center=50, width=50),    # stroke
    ]

    # cv2.resize takes dsize as (width, height), the reverse of the
    # (height, width) convention documented for target_shape.
    target_height, target_width = target_shape
    dsize = (target_width, target_height)
    channels = [
        cv2.resize(window, dsize, interpolation=cv2.INTER_AREA)
        for window in windows
    ]

    # The resolution map is constant, so build it directly at the target size
    # instead of allocating it at the source size and resizing it.
    resolution_map = np.full(
        (target_height, target_width),
        slice_thickness / pixel_spacing,
        dtype=channels[0].dtype,
    )
    channels.append(resolution_map)

    # Stack all channels together into a (height, width, 4) array
    return np.stack(channels, axis=-1)


def load_and_process_images(df, process_fn, metadata_df, target_shape=(512, 512), return_index=False):
    """
    Load and process the images listed in ``df`` with a custom processing function.

    Rows for which ``process_fn`` raises are reported with ``print`` and
    skipped, so the returned array can contain fewer images than ``df`` has
    rows. Pass ``return_index=True`` to also get the index labels of the rows
    that succeeded, so labels can be realigned with ``df.loc[kept, 'label']``.

    Args:
        df (pd.DataFrame): DataFrame with a 'file_path' column.
        process_fn (function): Called as ``process_fn(file_path, metadata_df, target_shape)``.
        metadata_df (pd.DataFrame): DataFrame containing metadata, passed through to ``process_fn``.
        target_shape (tuple): Desired image shape (height, width), passed through to ``process_fn``.
        return_index (bool): When True, return ``(images, kept_index)`` instead of just ``images``.

    Returns:
        np.ndarray: Processed images stacked along a new first axis, in ``df``
        row order. With ``return_index=True``, a tuple of that array and a
        list of the ``df`` index labels that were processed successfully.
    """
    processed_images = []
    kept_index = []

    # Only the 'file_path' column is needed, so iterate over it directly.
    for idx, file_path in df['file_path'].items():
        try:
            processed_image = process_fn(file_path, metadata_df, target_shape)
        except Exception as e:
            print(f"Error processing {file_path}: {e}")
        else:
            processed_images.append(processed_image)
            kept_index.append(idx)

    images = np.array(processed_images)
    if return_index:
        return images, kept_index
    return images


In [21]:
# Preprocess every image listed in image_df into one array.
# Rows whose processing raises are printed and skipped, so the array can be shorter than image_df.
processed_images = load_and_process_images(image_df, preprocess_with_metadata_from_df, metadata_df)

In [22]:
# Shape of the stacked result: (number of images, height, width, channels).
processed_images.shape

(84, 512, 512, 4)

In [23]:
# Shape of a single processed image.
processed_images[0].shape

(512, 512, 4)