<a href="https://colab.research.google.com/github/TAUforPython/BioMedAI/blob/main/CNN%20DWT%20image%20diabet%20retinopathy%20classification%20.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Okay, here is the Google Colab code based on your requirements. It includes the CNN class with DWT preprocessing, image-by-image loading, visualization, training, evaluation, and confusion matrix computation.

Diabetic Retinopathy Classification using CNN and Discrete Wavelet Transform (DWT)

This notebook demonstrates a pipeline for classifying diabetic retinopathy stages
using a Convolutional Neural Network (CNN) with Discrete Wavelet Transform (DWT)
preprocessing. The model is trained on the APTOS 2019 dataset.
The code processes images one-by-one to manage memory efficiently.

https://www.kaggle.com/competitions/aptos2019-blindness-detection/data?select=train.csv

In [1]:
# Install PyWavelets for DWT
!pip install PyWavelets -q

In [2]:
!pip install kaggle -q

In [21]:
from google.colab import userdata

# Get Hugging Face and PubMed token from environment
KG_TOKEN_NAME = userdata.get("KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME")
KAGGLE_KEY_TOKEN = userdata.get("KAGGLE_KEY")

if KAGGLE_KEY_TOKEN is None:
  print("Kaggle API need to be resolved.")
else:
  print("Kaggle API configured successfully.")

Kaggle API configured successfully.


In [22]:
import numpy as np
import pandas as pd
import tensorflow as tf
import cv2
import os
import matplotlib.pyplot as plt
import seaborn as sns
import pywt
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
import zipfile

# Check for GPU availability
print("Num GPUs Available:", len(tf.config.experimental.list_physical_devices('GPU')))
if tf.config.experimental.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(tf.config.experimental.list_physical_devices('GPU')[0], True)


Num GPUs Available: 0


In [None]:
# Define the dataset name and download path
DATASET_NAME = "aptos2019-blindness-detection"
DOWNLOAD_DIR = "/content/kaggle_dataset"

# Create download directory
os.makedirs(DOWNLOAD_DIR, exist_ok=True)

# Download the dataset using Kaggle API
!kaggle competitions download -c {DATASET_NAME} -p {DOWNLOAD_DIR}

# Extract the downloaded zip file
zip_file_path = os.path.join(DOWNLOAD_DIR, f"{DATASET_NAME}.zip")
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
    zip_ref.extractall(DOWNLOAD_DIR)

print(f"Dataset extracted to: {DOWNLOAD_DIR}")

Traceback (most recent call last):
  File "/usr/local/bin/kaggle", line 4, in <module>
    from kaggle.cli import main
  File "/usr/local/lib/python3.12/dist-packages/kaggle/__init__.py", line 6, in <module>
    api.authenticate()
  File "/usr/local/lib/python3.12/dist-packages/kaggle/api/kaggle_api_extended.py", line 434, in authenticate
    raise IOError('Could not find {}. Make sure it\'s located in'
OSError: Could not find kaggle.json. Make sure it's located in /root/.config/kaggle. Or use the environment method. See setup instructions at https://github.com/Kaggle/kaggle-api/


FileNotFoundError: [Errno 2] No such file or directory: '/content/kaggle_dataset/aptos2019-blindness-detection.zip'

In [None]:
# ## 2. Data Preparation and Loading
# This section handles downloading and preparing the dataset paths.
# We assume you have uploaded the train.csv file to your Colab environment.

# %%
# Define paths
# You need to upload your train.csv file to the Colab environment first
# You will need to download the train.zip from Kaggle and upload it here too
# For demonstration, we assume the train images are in a folder named 'train_images'
# which you need to upload or mount from your Drive containing the Kaggle data.
TRAIN_CSV_PATH = "train.csv" # Update if necessary
TRAIN_IMAGES_DIR = "train_images" # Update this path to where your images are located
# Example: TRAIN_IMAGES_DIR = "/content/drive/MyDrive/kaggle_data/aptos2019-blindness-detection/train_images"

# Load the training labels
df_train = pd.read_csv(TRAIN_CSV_PATH)
print(f"Loaded training CSV with {len(df_train)} samples.")
print(df_train.head())

# Ensure the image paths are correctly formed
df_train["image_path"] = df_train["id_code"].apply(lambda x: os.path.join(TRAIN_IMAGES_DIR, x + ".png"))

# Check if all images exist
missing_images = df_train[~df_train["image_path"].apply(os.path.exists)]
if not missing_images.empty:
    print(f"Warning: {len(missing_images)} images listed in CSV were not found in {TRAIN_IMAGES_DIR}.")
    print("Example missing images:")
    print(missing_images.head())
    # Remove missing images
    df_train = df_train[df_train["image_path"].apply(os.path.exists)]

print(f"Final training DataFrame shape after checking for files: {df_train.shape}")

In [None]:
# ## 3. DWT Preprocessing and Visualization
# This section defines functions for DWT preprocessing and visualizes the process.

# %%
def apply_dwt_and_visualize(image_path, wavelet='db1', levels=2, img_size=(224, 224)):
    """
    Applies Discrete Wavelet Transform (DWT) to an image and visualizes the decomposition.

    Args:
        image_path (str): Path to the input image.
        wavelet (str): Wavelet type to use (e.g., 'db1', 'haar').
        levels (int): Number of DWT levels to perform.
        img_size (tuple): Target size to resize the image.

    Returns:
        np.ndarray: The reconstructed image from the final level approximation and details.
    """
    # Read and resize the image
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        raise ValueError(f"Unable to read image at path: {image_path}")
    image = cv2.resize(image, img_size)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

    # Convert to grayscale for DWT visualization (optional, can apply to each channel)
    gray_image = cv2.cvtColor((image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY).astype(np.float32) / 255.0

    # Perform DWT
    coeffs = pywt.wavedec2(gray_image, wavelet, level=levels)
    cA_n, (cH_n, cV_n, cD_n) = coeffs[0], coeffs[-1] # Approximation and last level details

    # Visualize the original image and the DWT components
    fig, axes = plt.subplots(2, 2, figsize=(10, 10))
    fig.suptitle(f'DWT Decomposition (Levels={levels}, Wavelet={wavelet}) - {os.path.basename(image_path)}', fontsize=14)

    axes[0, 0].imshow(gray_image, cmap='gray')
    axes[0, 0].set_title('Original Grayscale Image')
    axes[0, 0].axis('off')

    axes[0, 1].imshow(cA_n, cmap='gray')
    axes[0, 1].set_title(f'Approximation Coefficients (Level {levels})')
    axes[0, 1].axis('off')

    axes[1, 0].imshow(cH_n, cmap='gray')
    axes[1, 0].set_title(f'Horizontal Details (Level {levels})')
    axes[1, 0].axis('off')

    axes[1, 1].imshow(cV_n, cmap='gray')
    axes[1, 1].set_title(f'Vertical Details (Level {levels})')
    axes[1, 1].axis('off')

    # Optionally, show diagonal details
    # axes[1, 2].imshow(cD_n, cmap='gray')
    # axes[1, 2].set_title(f'Diagonal Details (Level {levels})')
    # axes[1, 2].axis('off')

    plt.tight_layout()
    plt.show()

    # For the actual preprocessing, we might just use the approximation coefficients
    # or reconstruct a lower-resolution version. Here, we reconstruct the full image
    # from the final level coefficients for demonstration.
    reconstructed_gray = pywt.waverec2(coeffs, wavelet)
    # Ensure the reconstructed image matches the original size
    reconstructed_gray = cv2.resize(reconstructed_gray, img_size)

    return reconstructed_gray

def apply_dwt_and_save(image_path, save_dir, wavelet='db1', levels=2, img_size=(224, 224)):
    """
    Applies DWT to an image and saves the processed version.

    Args:
        image_path (str): Path to the input image.
        save_dir (str): Directory to save the processed image.
        wavelet (str): Wavelet type to use.
        levels (int): Number of DWT levels.
        img_size (tuple): Target size for the image.

    Returns:
        str: Path to the saved processed image.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Read and resize the image
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        raise ValueError(f"Unable to read image at path: {image_path}")
    image = cv2.resize(image, img_size)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

    # Apply DWT to each channel (RGB)
    processed_channels = []
    for i in range(3): # Iterate over R, G, B channels
        channel = image[:, :, i]
        coeffs = pywt.wavedec2(channel, wavelet, level=levels)
        # Reconstruct from coefficients (can modify coefficients here if needed)
        reconstructed_channel = pywt.waverec2(coeffs, wavelet)
        # Ensure size matches
        reconstructed_channel = cv2.resize(reconstructed_channel, img_size)
        processed_channels.append(reconstructed_channel)

    # Stack the processed channels back
    processed_image = np.stack(processed_channels, axis=-1)

    # Clip values to [0, 1] range after reconstruction
    processed_image = np.clip(processed_image, 0.0, 1.0)

    # Save the processed image
    base_name = os.path.basename(image_path)
    save_path = os.path.join(save_dir, base_name)
    # Convert back to BGR for OpenCV saving
    processed_image_bgr = cv2.cvtColor((processed_image * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
    cv2.imwrite(save_path, processed_image_bgr)

    return save_path


# %%
# Example visualization of DWT on a sample image
sample_image_path = df_train.iloc[0]["image_path"]
print(f"Sample image path: {sample_image_path}")
if os.path.exists(sample_image_path):
    dwt_result = apply_dwt_and_visualize(sample_image_path, wavelet='db1', levels=2)
    print("DWT visualization complete.")
else:
    print(f"Sample image does not exist: {sample_image_path}")


# %% [markdown]
# ## 4. Custom Dataset Generator (Image-by-Image)
# This generator loads and preprocesses images on-the-fly during training.

# %%
class DWTImageGenerator(tf.keras.utils.Sequence):
    """
    Custom Keras Sequence for loading and preprocessing images using DWT on-the-fly.
    """
    def __init__(self, image_paths, labels, batch_size=32, img_size=(224, 224), wavelet='db1', levels=2, shuffle=True):
        self.image_paths = image_paths
        self.labels = labels
        self.batch_size = batch_size
        self.img_size = img_size
        self.wavelet = wavelet
        self.levels = levels
        self.shuffle = shuffle
        self.indices = np.arange(len(self.image_paths))
        self.on_epoch_end()

    def __len__(self):
        """Denotes the number of batches per epoch."""
        return int(np.ceil(len(self.image_paths) / self.batch_size))

    def __getitem__(self, index):
        """Generate one batch of data."""
        # Generate indices for the batch
        start_idx = index * self.batch_size
        end_idx = min((index + 1) * self.batch_size, len(self.indices))
        batch_indices = self.indices[start_idx:end_idx]

        # Generate data
        batch_images = np.empty((len(batch_indices), *self.img_size, 3))
        batch_labels = np.empty((len(batch_indices)), dtype=int)

        for i, idx in enumerate(batch_indices):
            img_path = self.image_paths.iloc[idx]
            label = self.labels.iloc[idx]

            # Load and preprocess image using DWT
            image = cv2.imread(img_path, cv2.IMREAD_COLOR)
            if image is None:
                 # Handle potential loading errors, maybe load a default or skip
                 # For simplicity, we'll load a blank image here, but robust error handling is better
                 image = np.zeros((*self.img_size, 3), dtype=np.uint8)
                 print(f"Warning: Could not load image {img_path}. Using blank image.")
            else:
                image = cv2.resize(image, self.img_size)
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

                # Apply DWT to each channel
                processed_channels = []
                for ch in range(3):
                    channel = image[:, :, ch]
                    coeffs = pywt.wavedec2(channel, self.wavelet, level=self.levels)
                    reconstructed_channel = pywt.waverec2(coeffs, self.wavelet)
                    reconstructed_channel = cv2.resize(reconstructed_channel, self.img_size)
                    processed_channels.append(reconstructed_channel)
                image = np.stack(processed_channels, axis=-1)
                image = np.clip(image, 0.0, 1.0) # Clip values after reconstruction

            batch_images[i] = image
            batch_labels[i] = label

        return batch_images, batch_labels

    def on_epoch_end(self):
        """Updates indices after each epoch."""
        if self.shuffle:
            np.random.shuffle(self.indices)


# %% [markdown]
# ## 5. Define the CNN Model

# %%
def build_cnn_model(input_shape=(224, 224, 3), num_classes=5):
    """
    Builds a simple CNN model for classification.
    """
    inputs = tf.keras.Input(shape=input_shape)

    # Convolutional Block 1
    x = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same', name='conv1_1')(inputs)
    x = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same', name='conv1_2')(x)
    x = tf.keras.layers.MaxPooling2D((2, 2), name='pool1')(x)

    # Convolutional Block 2
    x = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same', name='conv2_1')(x)
    x = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same', name='conv2_2')(x)
    x = tf.keras.layers.MaxPooling2D((2, 2), name='pool2')(x)

    # Convolutional Block 3
    x = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same', name='conv3_1')(x)
    x = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same', name='conv3_2')(x)
    x = tf.keras.layers.MaxPooling2D((2, 2), name='pool3')(x)

    # Global Average Pooling instead of Flatten + Dense for reduction
    x = tf.keras.layers.GlobalAveragePooling2D(name='global_avg_pool')(x)
    x = tf.keras.layers.Dropout(0.5, name='dropout1')(x)

    # Dense Layer
    x = tf.keras.layers.Dense(256, activation='relu', name='dense1')(x)
    x = tf.keras.layers.Dropout(0.5, name='dropout2')(x)

    # Output Layer
    outputs = tf.keras.layers.Dense(num_classes, activation='softmax', name='predictions')(x)

    model = tf.keras.Model(inputs, outputs, name='DWT_CNN_Model')
    return model


# %% [markdown]
# ## 6. Split Data and Create Generators

# %%
# Split the data
X_paths = df_train['image_path']
y_labels = df_train['diagnosis']

# Use stratify to maintain class distribution
X_train_paths, X_temp_paths, y_train_labels, y_temp_labels = train_test_split(
    X_paths, y_labels, test_size=0.3, random_state=42, stratify=y_labels
)
X_val_paths, X_test_paths, y_val_labels, y_test_labels = train_test_split(
    X_temp_paths, y_temp_labels, test_size=0.5, random_state=42, stratify=y_temp_labels # 0.5 x 0.3 = 0.15 for val, 0.15 for test
)

print(f"Training samples: {len(X_train_paths)}")
print(f"Validation samples: {len(X_val_paths)}")
print(f"Test samples: {len(X_test_paths)}")

# Compute class weights to handle imbalance
class_weights_raw = compute_class_weight(class_weight='balanced', classes=np.unique(y_train_labels), y=y_train_labels)
class_weights_dict = dict(enumerate(class_weights_raw))
print(f"Computed class weights: {class_weights_dict}")

# Create generators
BATCH_SIZE = 16 # Reduced batch size due to potential memory constraints in Colab
IMG_SIZE = (224, 224)
WAVELET_TYPE = 'db1'
DWT_LEVELS = 2

train_gen = DWTImageGenerator(
    X_train_paths, y_train_labels,
    batch_size=BATCH_SIZE,
    img_size=IMG_SIZE,
    wavelet=WAVELET_TYPE,
    levels=DWT_LEVELS,
    shuffle=True
)
val_gen = DWTImageGenerator(
    X_val_paths, y_val_labels,
    batch_size=BATCH_SIZE,
    img_size=IMG_SIZE,
    wavelet=WAVELET_TYPE,
    levels=DWT_LEVELS,
    shuffle=False # Don't shuffle validation data
)
test_gen = DWTImageGenerator(
    X_test_paths, y_test_labels,
    batch_size=BATCH_SIZE,
    img_size=IMG_SIZE,
    wavelet=WAVELET_TYPE,
    levels=DWT_LEVELS,
    shuffle=False # Don't shuffle test data
)


# %% [markdown]
# ## 7. Train the Model

# %%
# Build the model
model = build_cnn_model(input_shape=(*IMG_SIZE, 3), num_classes=5)
model.summary()

# Compile the model
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss='sparse_categorical_crossentropy', # Use sparse for integer labels
    metrics=['accuracy']
)

# Define callbacks
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
    verbose=1
)
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=3,
    min_lr=1e-7,
    verbose=1
)

# Train the model
EPOCHS = 20 # Adjust based on your resources and time
history = model.fit(
    train_gen,
    epochs=EPOCHS,
    validation_data=val_gen,
    class_weight=class_weights_dict,
    callbacks=[early_stopping, reduce_lr],
    verbose=1
)

# %% [markdown]
# ## 8. Evaluate the Model

# %%
# Evaluate on the test set
test_loss, test_acc = model.evaluate(test_gen, verbose=1)
print(f"Test Accuracy: {test_acc * 100:.2f}%")
print(f"Test Loss: {test_loss:.4f}")

# %% [markdown]
# ## 9. Plot Training History

# %%
# Plot training history
def plot_training_history(history):
    """Plots the training and validation accuracy and loss."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # Plot accuracy
    ax1.plot(history.history['accuracy'], label='Training Accuracy')
    ax1.plot(history.history['val_accuracy'], label='Validation Accuracy')
    ax1.set_title('Model Accuracy')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Accuracy')
    ax1.legend()
    ax1.grid(True)

    # Plot loss
    ax2.plot(history.history['loss'], label='Training Loss')
    ax2.plot(history.history['val_loss'], label='Validation Loss')
    ax2.set_title('Model Loss')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Loss')
    ax2.legend()
    ax2.grid(True)

    plt.tight_layout()
    plt.show()

plot_training_history(history)

# %% [markdown]
# ## 10. Compute and Plot Confusion Matrix

# %%
# Get predictions for the test set
print("Generating predictions for confusion matrix...")
test_predictions = model.predict(test_gen, verbose=1)
predicted_classes = np.argmax(test_predictions, axis=1)
true_classes = y_test_labels.values # Get the actual labels from the series

# Compute confusion matrix
cm = confusion_matrix(true_classes, predicted_classes)

# Plot confusion matrix
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative DR'],
            yticklabels=['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative DR'])
plt.title('Confusion Matrix - Test Set')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.show()

# %% [markdown]
# ## 11. Example Prediction on a Single Image

# %%
def predict_single_image(model, image_path, wavelet='db1', levels=2, img_size=(224, 224)):
    """
    Predicts the class for a single image.
    """
    # Load and preprocess
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        raise ValueError(f"Could not load image: {image_path}")
    image = cv2.resize(image, img_size)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

    # Apply DWT
    processed_channels = []
    for ch in range(3):
        channel = image[:, :, ch]
        coeffs = pywt.wavedec2(channel, wavelet, level=levels)
        reconstructed_channel = pywt.waverec2(coeffs, wavelet)
        reconstructed_channel = cv2.resize(reconstructed_channel, img_size)
        processed_channels.append(reconstructed_channel)
    image = np.stack(processed_channels, axis=-1)
    image = np.clip(image, 0.0, 1.0)

    # Expand dimensions for batch prediction
    image_batch = np.expand_dims(image, axis=0)

    # Predict
    prediction_probs = model.predict(image_batch, verbose=0)[0]
    predicted_class = np.argmax(prediction_probs)
    confidence = prediction_probs[predicted_class]

    class_names = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative DR']
    print(f"Predicted Class: {class_names[predicted_class]} (Index: {predicted_class})")
    print(f"Confidence: {confidence*100:.2f}%")
    print(f"All Probabilities: {prediction_probs}")

    # Visualize the original and processed image
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    original_img = cv2.imread(image_path, cv2.IMREAD_COLOR)
    original_img = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB)
    ax1.imshow(original_img)
    ax1.set_title('Original Image')
    ax1.axis('off')

    ax2.imshow(image)
    ax2.set_title('DWT Preprocessed Image')
    ax2.axis('off')
    plt.show()

    return predicted_class, confidence, prediction_probs

# Example prediction (use a path from your test set or another image you have)
example_image_path = X_test_paths.iloc[0] # Use the first image from the test set
print(f"Predicting for image: {example_image_path}")
predict_single_image(model, example_image_path, wavelet=WAVELET_TYPE, levels=DWT_LEVELS, img_size=IMG_SIZE)