In [11]:
#dependencies

import torch
import torch.nn as nn
from monai.networks.nets import DenseNet121
from monai.transforms import Compose, LoadImage, ScaleIntensity, EnsureChannelFirst, Resize 
import pandas as pd
import pydicom
import numpy as np
from monai.transforms import Compose, EnsureChannelFirst, ScaleIntensity, Resize
from torch.utils.data import Dataset, DataLoader
import os
import matplotlib.pyplot as plt
import keras
from torch.utils.data import DataLoader, random_split
from keras import layers


In [None]:
# load the dataset

class CTDataset(Dataset):
    """Dataset of CT DICOM series with a binary stage label, indexed by a CSV file.

    Each CSV row must provide a ``ct_folder_path`` column (a directory holding
    the ``.dcm`` files of one series) and a ``stage`` column. Stage 2 maps to
    label 0 and stage 4 maps to label 1; any other stage is rejected with a
    ``ValueError`` instead of failing later with an unbound local variable.
    """

    def __init__(self, csv_file, transform=None):
        """Read the CSV index.

        Args:
            csv_file: Path (or any input ``pandas.read_csv`` accepts) of the index CSV.
            transform: Optional callable applied to each loaded volume before
                it is returned from ``__getitem__``.
        """
        self.dataframe = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return len(self.dataframe)

    @staticmethod
    def _sort_slices(dicom_files):
        """Return the slices ordered by the first usable ordering key.

        Keys are tried in this order: ``InstanceNumber``, the z component of
        ``ImagePositionPatient``, then ``filename``. A key is abandoned as soon
        as it is missing or unparsable on any slice, so the whole series is
        always ordered by one consistent criterion.
        """
        preferred_keys = (
            lambda ds: int(ds.InstanceNumber),
            lambda ds: float(ds.ImagePositionPatient[2]),
        )
        for key in preferred_keys:
            try:
                return sorted(dicom_files, key=key)
            except (AttributeError, TypeError, ValueError, IndexError, KeyError):
                # Tag absent or malformed on at least one slice: try the next key.
                continue
        return sorted(dicom_files, key=lambda ds: ds.filename)

    def _load_dicom_series(self, directory_path):
        """Load every DICOM file in a directory and stack the slices into one array.

        Args:
            directory_path: Directory containing the ``.dcm`` files of one series.

        Returns:
            ``numpy.ndarray`` with the slices stacked along axis 0
            (depth, height, width for 2D slices). Rescale slope/intercept are
            applied when both tags are present.

        Raises:
            ValueError: If no DICOM file in the directory could be read.
        """
        dicom_files = []

        for filename in os.listdir(directory_path):
            # Case-insensitive match so ".DCM" exports are not silently skipped.
            if not filename.lower().endswith('.dcm'):
                continue
            filepath = os.path.join(directory_path, filename)
            try:
                dicom_files.append(pydicom.dcmread(filepath, force=True))
            except Exception as e:
                # Best effort: a single unreadable file must not discard the series.
                print(f"Error reading {filepath}: {e}")

        if not dicom_files:
            raise ValueError(f"No DICOM files found in {directory_path}")

        slices = []
        for dicom in self._sort_slices(dicom_files):
            pixel_array = dicom.pixel_array

            # Apply the linear rescale only when both tags are available.
            if hasattr(dicom, 'RescaleSlope') and hasattr(dicom, 'RescaleIntercept'):
                pixel_array = (
                    pixel_array * float(dicom.RescaleSlope)
                    + float(dicom.RescaleIntercept)
                )

            slices.append(pixel_array)

        # Stack slices to create the 3D volume: (depth, height, width).
        return np.stack(slices, axis=0)

    def __getitem__(self, idx):
        """Return ``(volume, label)`` for row ``idx``.

        Returns:
            The (optionally transformed) volume and a scalar ``torch.long``
            tensor holding 0 for stage 2 or 1 for stage 4.

        Raises:
            ValueError: If the row's ``stage`` is neither 2 nor 4.
        """
        row = self.dataframe.iloc[idx]
        dicom_dir = row['ct_folder_path']
        original_label = row['stage']

        # Validate the label before the expensive DICOM load.
        if original_label == 2:
            label = 0
        elif original_label == 4:
            label = 1
        else:
            raise ValueError(
                f"Unsupported stage {original_label!r} at index {idx}; expected 2 or 4"
            )

        volume = self._load_dicom_series(dicom_dir)

        if self.transform:
            volume = self.transform(volume)

        return volume, torch.tensor(label, dtype=torch.long)

# Build the dataset from the patient index CSV (no transform attached here).
csv_file_path = "colorectal_ct_patients.csv"
dataset = CTDataset(csv_file=csv_file_path)

# Flip to True to load patients one by one and print their volume shape/label.
visualize = False

if visualize:
    # Upper bound is clamped to the dataset size so a CSV with fewer than 350
    # rows no longer raises IndexError.
    # NOTE(review): the loop starts at index 1, so patient 0 is skipped, as in
    # the original cell — confirm that is intended.
    for i in range(1, min(350, len(dataset))):
        volume, label = dataset[i]
        print(f"Patient {i}: Volume shape {volume.shape}, Label: {label}")

    
# # Define transforms for the 3D volume
# transforms = Compose([
#     EnsureChannelFirst(channel_dim='no_channel'),  # Adds channel dimension: (C, D, H, W)
#     ScaleIntensity(minv=0.0, maxv=1.0),
#     Resize(spatial_size=(64, 224, 224))  # Resize to consistent spatial size
# ])

In [13]:
# display images and show information

def plot_ct_slices_with_windowing(volume, num_rows=4, num_cols=10, window_center = 40, window_width = 400, figsize=(15, 6)):
    """
    Plot a grid of consecutive slices taken around the middle of a CT volume,
    with an intensity window applied to each slice.

    Args:
        volume: Array of shape (depth, height, width), or 4D with a leading
            singleton axis that is squeezed away.
        num_rows: Number of grid rows.
        num_cols: Number of grid columns.
        window_center: Centre of the intensity window.
        window_width: Full width of the intensity window; must be positive.
        figsize: Figure size forwarded to ``plt.subplots``.

    Raises:
        ValueError: If ``window_width`` is not positive or the volume is not
            3D after squeezing.
    """
    if window_width <= 0:
        raise ValueError(f"window_width must be positive, got {window_width}")

    def apply_ct_window(image, window_center, window_width):
        """Clip ``image`` to the window and rescale it to the range [0, 1]."""
        # True division keeps odd widths exact and never collapses the window
        # to zero width (integer halving did for window_width < 2).
        half_width = window_width / 2
        window_min = window_center - half_width
        window_max = window_center + half_width
        windowed = np.clip(image, window_min, window_max)
        return (windowed - window_min) / (window_max - window_min)

    # Handle different input shapes
    if len(volume.shape) == 4:
        volume = volume.squeeze(0)

    if len(volume.shape) != 3:
        raise ValueError(
            f"Expected a 3D volume (depth, height, width), got shape {tuple(volume.shape)}"
        )
    depth = volume.shape[0]

    # Centre the displayed range on the middle slice, clamped to the volume.
    total_slices = num_rows * num_cols
    start_slice = max(0, depth // 2 - total_slices // 2)
    end_slice = min(depth, start_slice + total_slices)

    # squeeze=False keeps `axes` 2D even for a single row or column.
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)

    # axes.flat iterates row-major, matching slice order i * num_cols + j.
    for offset, ax in enumerate(axes.flat):
        slice_idx = start_slice + offset

        if slice_idx < end_slice:
            windowed_slice = apply_ct_window(volume[slice_idx], window_center, window_width)
            ax.imshow(windowed_slice, cmap='gray')
            ax.set_title(f'S{slice_idx}', fontsize=8, pad=2)

        # Cells beyond the last slice are left blank but still hidden.
        ax.axis('off')

    plt.suptitle(f'CT Slices with Window [C:{window_center}, W:{window_width}]', y=0.95)
    plt.tight_layout()
    plt.show()

def quick_view(dataset, patient_idx, rows=4, cols=8):
    """Fetch one patient from ``dataset``, print a summary line and plot its slices.

    Args:
        dataset: Indexable object returning ``(volume, label)`` pairs.
        patient_idx: Index of the patient to display.
        rows: Number of grid rows passed to the plotting helper.
        cols: Number of grid columns passed to the plotting helper.
    """
    sample_volume, label = dataset[patient_idx]

    # Objects exposing .numpy() (e.g. tensors) are converted to arrays first.
    vol = sample_volume.numpy() if hasattr(sample_volume, 'numpy') else sample_volume

    # Drop a leading singleton axis so the volume is (depth, height, width).
    if len(vol.shape) == 4:
        vol = vol.squeeze(0)

    summary = f"Patient {patient_idx}, Label: {label}, Shape: {vol.shape}"
    print(summary)
    plot_ct_slices_with_windowing(vol, rows, cols)


def check_slice_details(dataset, patient_idx):
    """Print slice count, in-plane resolution and value range for one patient.

    Args:
        dataset: Indexable object returning ``(volume, label)`` pairs.
        patient_idx: Index of the patient to inspect.
    """
    vol, _label = dataset[patient_idx]

    # Objects exposing .numpy() (e.g. tensors) are converted to arrays first.
    if hasattr(vol, 'numpy'):
        vol = vol.numpy()
    # Drop a leading singleton axis so the volume is (depth, height, width).
    if len(vol.shape) == 4:
        vol = vol.squeeze(0)

    n_slices, n_rows, n_cols = vol.shape
    header_lines = (
        f"Patient {patient_idx}:",
        f" - Total slices: {n_slices}",
        f" - Slice resolution: {n_rows} × {n_cols} pixels",
    )
    for line in header_lines:
        print(line)

    lowest, highest = vol.min(), vol.max()
    print(f" - HU value range: [{lowest:.0f}, {highest:.0f}]")
    

# Flip to True to print details and plot slices for two sample patients.
show_info = False

if show_info:
    for patient_idx in (111, 400):
        check_slice_details(dataset, patient_idx)
        quick_view(dataset, patient_idx)

In [None]:
def get_model(width = 224, height = 224, depth = 64):
    """Build a 3D convolutional network with a single sigmoid output unit.

    Args:
        width: Input width in voxels.
        height: Input height in voxels.
        depth: Number of slices per input volume.

    Returns:
        An uncompiled ``keras.Model`` named ``"3dcnn"`` taking inputs of shape
        ``(depth, height, width, 1)``.
    """
    inputs = keras.Input((depth, height, width, 1))

    # Four identical stages (conv -> max-pool -> batch-norm), differing only
    # in the number of convolution filters.
    x = inputs
    for n_filters in (64, 64, 128, 256):
        x = layers.Conv3D(filters=n_filters, kernel_size=3, activation="relu")(x)
        x = layers.MaxPool3D(pool_size=2)(x)
        x = layers.BatchNormalization()(x)

    # Classification head.
    x = layers.GlobalAveragePooling3D()(x)
    x = layers.Dense(units=512, activation="relu")(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(units=1, activation="sigmoid")(x)

    return keras.Model(inputs, outputs, name="3dcnn")


# Instantiate the network for 64-slice, 224 x 224 volumes and print its layers.
model = get_model(depth=64, height=224, width=224)
model.summary()


In [None]:
# Compile model
# The network ends in a single sigmoid unit, so the loss must be binary
# cross-entropy on the 0/1 labels. Sparse categorical cross-entropy treats the
# label as a class index into the output vector, and index 1 does not exist
# for a one-unit output.
model.compile(
    loss="binary_crossentropy",
    optimizer=keras.optimizers.Adam(learning_rate=0.0001),
    metrics=["accuracy"]
)

# Define callbacks
# Keep only the checkpoint with the best validation accuracy.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    "best_3d_ct_model.keras", 
    save_best_only=True,
    monitor='val_accuracy'
)

# Stop after 15 epochs without validation-accuracy improvement and roll back
# to the best weights seen.
early_stopping_cb = keras.callbacks.EarlyStopping(
    monitor="val_accuracy", 
    patience=15,
    restore_best_weights=True
)


In [None]:
# Convert PyTorch DataLoader to Keras-compatible format
def pytorch_to_keras_generator(pytorch_loader):
    """Endlessly yield ``(images, labels)`` NumPy batches from a PyTorch loader.

    The loader is re-iterated forever, as Keras expects when training with
    ``steps_per_epoch``. Image batches are converted from channels-first
    ``(batch, channels, depth, height, width)`` to channels-last
    ``(batch, depth, height, width, channels)``.

    Args:
        pytorch_loader: Re-iterable object yielding ``(images, labels)`` pairs
            of tensors; images must be 5D.

    Yields:
        Tuple ``(images_np, labels_np)`` of NumPy arrays.

    Raises:
        ValueError: If a full pass over the loader produces no batches
            (previously this spun forever without yielding).
    """
    while True:
        produced_batch = False

        for images, labels in pytorch_loader:
            # detach()/cpu() make the conversion work for tensors that track
            # gradients or live on an accelerator; both are no-ops otherwise.
            images_np = images.detach().cpu().numpy()
            labels_np = labels.detach().cpu().numpy()

            # Keras expects channels_last: (batch, depth, height, width, channels)
            images_np = np.transpose(images_np, (0, 2, 3, 4, 1))

            produced_batch = True
            yield images_np, labels_np

        if not produced_batch:
            raise ValueError("pytorch_loader produced no batches; cannot generate data")

# Wrap both loaders as endless Keras-compatible generators.
# NOTE(review): train_loader and val_loader are not defined in any cell visible
# in this notebook — confirm they are created before this cell runs.
train_gen, val_gen = (
    pytorch_to_keras_generator(loader) for loader in (train_loader, val_loader)
)

# One epoch corresponds to one full pass over the respective loader.
train_steps, val_steps = len(train_loader), len(val_loader)

# Train from the generators; checkpointing and early stopping are handled by
# the callbacks.
history = model.fit(
    train_gen,
    epochs=100,
    steps_per_epoch=train_steps,
    validation_data=val_gen,
    validation_steps=val_steps,
    callbacks=[checkpoint_cb, early_stopping_cb],
    verbose=1,
)