In [None]:
# Mount Google Drive (optional, to save model)
from google.colab import drive
drive.mount('/content/drive')

# Install kagglehub (if not already installed)
!pip install kagglehub gradio

# Authenticate for Kaggle
import os
from google.colab import files

# Upload your kaggle.json if you have it
try:
    uploaded = files.upload()  # Upload your kaggle.json
    !mkdir -p ~/.kaggle
    !cp kaggle.json ~/.kaggle/ #{"username":"","key":""}
    !chmod 600 ~/.kaggle/kaggle.json
except:
    # Or generate from credentials
    os.environ['KAGGLE_USERNAME'] = "ithree"
    os.environ['KAGGLE_KEY'] = "329d53d4f60bc3dd2f9de1fec6d5b363"

In [None]:
from tensorflow.keras.layers import Lambda, MaxPooling2D, Concatenate
import tensorflow.keras.backend as K

In [None]:
# PCOS ultrasound classifier with Grad-CAM heatmap explanations
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import kagglehub
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Conv2D, GlobalAveragePooling2D, Dense, Dropout, BatchNormalization
from tensorflow.keras.layers import Lambda, MaxPooling2D, Concatenate
from tensorflow.keras.applications import EfficientNetB3
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
import tensorflow.keras.backend as K
import seaborn as sns
import gradio as gr
import tensorflow as tf
import matplotlib.cm as cm
import json

# Parameters

# Conditions the classifier distinguishes. The position of a name in this list
# is its integer class index (only PCOS and normal are available in the
# Kaggle dataset).
CONDITIONS = ["normal", "pcos"]

INPUT_SHAPE = (224, 224, 4)  # RGB + edge detection channel
# Derived from CONDITIONS so the size of the softmax layer can never drift
# from the label encoding.
NUM_CLASSES = len(CONDITIONS)
LEARNING_RATE = 0.0001
BATCH_SIZE = 32
EPOCHS = 50

# Module 1: Dataset Acquisition
def download_datasets():
    """Fetch the PCOS ultrasound dataset via kagglehub and describe its folders.

    Returns:
        dict: One entry per condition, inserted in the order "pcos" then
        "normal". Each value is a dict with a "type" of "image" and a "path"
        pointing at `<download dir>/dataset/<condition>`.
    """
    print("Downloading PCOS dataset...")
    pcos_path = kagglehub.dataset_download("shnotweta/2000-images-of-ultrasound-for-pcos")

    # Both classes live side by side inside the same download.
    datasets = {
        condition: {
            "type": "image",
            "path": os.path.join(pcos_path, "dataset", condition),
        }
        for condition in ("pcos", "normal")
    }

    print(f"PCOS dataset downloaded to: {pcos_path}")
    return datasets

# Module 2: Image Processing
def apply_sobel_filter(image):
    """Return the Sobel gradient magnitude of an image as a uint8 array.

    Args:
        image: Image array. A 3-dimensional array is converted to grayscale
            with `cv2.COLOR_RGB2GRAY`; any other array is used as given.

    Returns:
        numpy.ndarray: Gradient magnitude, min-max normalised to 0..255 and
        cast to uint8.
    """
    is_multichannel = len(image.shape) == 3
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) if is_multichannel else image

    # Clamp into the 8-bit range before the cast.
    gray = np.clip(gray, 0, 255).astype(np.uint8)

    # First-order derivatives along x and y, computed in float64.
    grad_x, grad_y = (
        cv2.Sobel(gray, cv2.CV_64F, dx, dy, ksize=3)
        for dx, dy in ((1, 0), (0, 1))
    )
    magnitude = np.sqrt(grad_x**2 + grad_y**2)

    # Stretch to the full 8-bit range.
    return cv2.normalize(magnitude, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

def load_images_from_directory(directory, target_size=(224, 224), max_images=None):
    """Load and resize every image file found directly inside a directory.

    Files are processed in sorted path order. `os.listdir` returns entries in
    arbitrary order, so without sorting both the `max_images` subset and the
    order of the returned array (and therefore any seeded split made from it)
    could change from run to run.

    Args:
        directory: Folder to scan (not recursive).
        target_size: Size passed to `cv2.resize` for every image.
        max_images: Optional cap on the number of files to read; a falsy
            value means no limit.

    Returns:
        numpy.ndarray: The successfully loaded, resized images as returned by
        `cv2.imread`. Files that cannot be read are reported and skipped.
    """
    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')
    file_paths = sorted(
        os.path.join(directory, filename)
        for filename in os.listdir(directory)
        if filename.lower().endswith(image_extensions)
    )

    if max_images:
        file_paths = file_paths[:max_images]

    print(f"Loading {len(file_paths)} images from {directory}")

    images = []
    for filepath in file_paths:
        try:
            img = cv2.imread(filepath)
            if img is None:
                # cv2.imread signals failure by returning None, not by raising.
                print(f"Skipping unreadable image: {filepath}")
                continue
            images.append(cv2.resize(img, target_size))
        except Exception as e:
            print(f"Error loading {filepath}: {e}")

    return np.array(images)

def preprocess_images(images):
    """Turn a batch of images into 4-channel arrays (RGB plus an edge map).

    Args:
        images: Iterable of images; each one is passed to
            `apply_sobel_filter` and to `cv2.cvtColor(..., COLOR_BGR2RGB)`.

    Returns:
        numpy.ndarray: One array per input image with the edge map stacked
        behind the colour-converted channels.
    """
    def with_edge_channel(source):
        edge_map = apply_sobel_filter(source)
        rgb = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
        return np.dstack((rgb, edge_map))

    return np.array([with_edge_channel(img) for img in images])

def preprocess_single_image(image_path, target_size=(224, 224)):
    """Read one image file and build the 4-channel batch used for prediction.

    Args:
        image_path: Path of the image to read with `cv2.imread`.
        target_size: Size passed to `cv2.resize`.

    Returns:
        numpy.ndarray: The colour-converted image with its Sobel edge map
        stacked as an extra channel, with a leading batch axis of length 1.

    Raises:
        ValueError: If `cv2.imread` returns None for `image_path`.
    """
    loaded = cv2.imread(image_path)
    if loaded is None:
        raise ValueError(f"Could not load image from {image_path}")

    resized = cv2.resize(loaded, target_size)
    edge_map = apply_sobel_filter(resized)
    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)

    # Leading axis of length 1 = batch dimension.
    return np.dstack((rgb, edge_map))[np.newaxis, ...]

def preprocess_direct_image(img, target_size=(224, 224)):
    """Build the 4-channel prediction batch from an image already in memory.

    Args:
        img: Image array. A 3-dimensional array is converted with
            `cv2.COLOR_BGR2RGB`; any other array is replicated into three
            identical channels.
        target_size: Size passed to `cv2.resize`.

    Returns:
        numpy.ndarray: Colour channels plus the Sobel edge map, with a leading
        batch axis of length 1.

    Raises:
        ValueError: If `img` is None.
    """
    if img is None:
        raise ValueError("No image provided")

    resized = cv2.resize(img, target_size)
    edge_map = apply_sobel_filter(resized)

    if len(resized.shape) == 3:
        rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    else:
        rgb = np.stack([resized, resized, resized], axis=-1)

    # Leading axis of length 1 = batch dimension.
    return np.dstack((rgb, edge_map))[np.newaxis, ...]

# Module 3: Data Preparation
def prepare_data(datasets):
    """Load, preprocess and label the images of every condition.

    Args:
        datasets: Mapping of condition name to a dict whose "path" entry is
            the folder holding that condition's images.

    Returns:
        tuple: `(X_images, y_categorical, label_to_index)` where `X_images`
        is the concatenation of the preprocessed images of all conditions,
        `y_categorical` the matching one-hot labels produced by
        `to_categorical`, and `label_to_index` maps each name in `CONDITIONS`
        to its class index.

    Raises:
        KeyError: If images were found for a condition that is not listed in
            `CONDITIONS`.
        ValueError: If no images could be loaded for any condition.
    """
    label_to_index = {condition: idx for idx, condition in enumerate(CONDITIONS)}

    image_batches = []
    label_batches = []

    for condition, dataset_info in datasets.items():
        directory = dataset_info["path"]
        if not os.path.exists(directory):
            print(f"Path not found: {directory}")
            continue

        print(f"Processing {condition} images...")
        images = load_images_from_directory(directory)
        if len(images) == 0:
            print(f"No images found for {condition}")
            continue

        # Reject unknown labels before spending time on preprocessing.
        if condition not in label_to_index:
            raise KeyError(
                f"Unknown condition {condition!r}; expected one of {CONDITIONS}"
            )

        processed_images = preprocess_images(images)
        image_batches.append(processed_images)
        # Store integer class indices directly; no string round-trip needed.
        label_batches.append(
            np.full(len(processed_images), label_to_index[condition], dtype=np.int64)
        )

    # Fail fast here instead of returning empty arrays that only break later
    # (with a less helpful message) when the data is split.
    if not image_batches:
        raise ValueError(
            "No images were loaded for any condition; check the dataset paths."
        )

    X_images = np.concatenate(image_batches)
    y = np.concatenate(label_batches)
    y_categorical = to_categorical(y, num_classes=len(CONDITIONS))

    print(f"Loaded {len(X_images)} images across {len(label_to_index)} classes")
    return X_images, y_categorical, label_to_index

# Module 4: Model Building
def build_model(input_shape, num_classes):
    """Build a two-branch classifier around an EfficientNetB3 backbone.

    The first three input channels are fed to EfficientNetB3 (ImageNet
    weights, no top); the fourth channel goes through a small convolutional
    branch. Both branches are globally average-pooled, projected by a dense
    layer, concatenated and classified with a softmax head.

    Args:
        input_shape: Shape of one input image; channel 3 is treated as the
            edge channel, channels 0-2 as the backbone input.
        num_classes: Number of units in the softmax output layer.

    Returns:
        The uncompiled Keras `Model`. All backbone layers except the last 20
        are frozen.
    """
    # Image input
    image_input = Input(shape=input_shape, name='image_input')

    # Split the stacked input into its colour channels and its edge channel
    rgb_channels = Lambda(lambda x: x[:, :, :, :3])(image_input)
    edge_channel = Lambda(lambda x: x[:, :, :, 3:4])(image_input)

    # Use EfficientNet with RGB channels
    base_model = EfficientNetB3(weights='imagenet', include_top=False, input_tensor=rgb_channels)

    # Process edge channel separately
    edge_x = Conv2D(16, (3, 3), padding='same', activation='relu')(edge_channel)
    edge_x = Conv2D(16, (3, 3), padding='same', activation='relu')(edge_x)
    edge_x = MaxPooling2D(pool_size=(2, 2))(edge_x)

    # Get features from base model
    x = base_model.output

    # No spatial resizing is required to merge the branches: each one is
    # collapsed to a vector by global average pooling before concatenation.
    edge_x = Conv2D(16, (3, 3), padding='same', activation='relu')(edge_x)
    edge_x = GlobalAveragePooling2D()(edge_x)
    edge_features = Dense(64, activation='relu')(edge_x)

    # Main feature extraction from RGB path
    x = GlobalAveragePooling2D()(x)
    x = BatchNormalization()(x)
    rgb_features = Dense(512, activation='relu')(x)

    # Combine RGB and edge features
    combined = Concatenate()([rgb_features, edge_features])

    # Classification layers
    x = Dropout(0.5)(combined)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.3)(x)
    outputs = Dense(num_classes, activation='softmax', name='predictions')(x)

    model = Model(inputs=image_input, outputs=outputs)

    # Freeze early layers for transfer learning
    for layer in base_model.layers[:-20]:
        layer.trainable = False

    return model


# Module 5: Training and Evaluation
def train_model(model, X_train, y_train, X_val, y_val):
    """Compile and fit the model, returning it together with the fit history.

    Training uses Adam with `LEARNING_RATE`, categorical cross-entropy and
    the accuracy metric, for up to `EPOCHS` epochs in batches of `BATCH_SIZE`.
    Three callbacks are attached: early stopping on `val_loss`, learning-rate
    reduction on `val_loss`, and a checkpoint of the best `val_accuracy`
    written to 'best_pcos_model.h5'.
    """
    model.compile(
        optimizer=Adam(learning_rate=LEARNING_RATE),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    # Stop once validation loss stalls and roll back to the best weights.
    stop_when_stalled = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
    # Halve the learning rate after a shorter stall.
    shrink_learning_rate = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6)
    # Keep the most accurate model seen so far on disk.
    keep_best = ModelCheckpoint('best_pcos_model.h5', monitor='val_accuracy', save_best_only=True, mode='max')

    history = model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        callbacks=[stop_when_stalled, shrink_learning_rate, keep_best]
    )

    return model, history

def evaluate_model(model, X_test, y_test, label_to_index):
    """Evaluate the model on the test set and save diagnostic plots.

    Prints the test loss/accuracy and a classification report, and writes
    "confusion_matrix.png" and "roc_curves.png" to the working directory.

    Args:
        model: Compiled model exposing `evaluate` and `predict`.
        X_test: Test inputs.
        y_test: One-hot encoded test labels (one column per class).
        label_to_index: Mapping of class name to class index.

    Returns:
        The test accuracy reported by `model.evaluate`.
    """
    # Evaluate on test set
    print("Evaluating model...")
    test_loss, test_accuracy = model.evaluate(X_test, y_test)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_accuracy:.4f}")

    # Predictions
    y_pred_prob = model.predict(X_test)
    y_pred_classes = np.argmax(y_pred_prob, axis=1)
    y_true_classes = np.argmax(y_test, axis=1)

    # Class names in index order
    index_to_label = {v: k for k, v in label_to_index.items()}
    class_indices = list(range(len(CONDITIONS)))
    target_names = [index_to_label[i] for i in class_indices]

    # Passing `labels` explicitly keeps the report rows and the matrix axes
    # aligned with `target_names` even when a class does not occur in the
    # test labels or predictions.
    print("\nClassification Report:")
    print(classification_report(
        y_true_classes, y_pred_classes,
        labels=class_indices, target_names=target_names
    ))

    # Confusion matrix
    conf_matrix = confusion_matrix(y_true_classes, y_pred_classes, labels=class_indices)
    plt.figure(figsize=(8, 6))
    sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues",
                xticklabels=target_names, yticklabels=target_names)
    plt.xlabel("Predicted Labels")
    plt.ylabel("True Labels")
    plt.title("Confusion Matrix")
    plt.tight_layout()
    plt.savefig("confusion_matrix.png")
    plt.close()

    # ROC curves (one-vs-rest, one curve per class)
    plt.figure(figsize=(8, 6))
    for i in class_indices:
        fpr, tpr, _ = roc_curve(y_test[:, i], y_pred_prob[:, i])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f'{target_names[i]} (AUC = {roc_auc:.2f})')

    plt.plot([0, 1], [0, 1], 'k--')  # chance diagonal
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curves')
    plt.legend(loc="lower right")
    plt.savefig("roc_curves.png")
    plt.close()

    return test_accuracy

# Module 6: Prediction Functions
def predict_image(model, image_path, label_to_index):
    """Classify one image file.

    Args:
        model: Object whose `predict` returns one row of class scores per
            input image.
        image_path: Path handed to `preprocess_single_image`.
        label_to_index: Mapping of class name to class index.

    Returns:
        tuple: `(results, predicted_class, confidence)` where `results` maps
        every class name to its score as a float, `predicted_class` is the
        name with the highest score and `confidence` is that score.
    """
    batch = preprocess_single_image(image_path)
    scores = model.predict(batch)[0]

    # Invert the mapping so scores can be reported by class name.
    index_to_label = {index: label for label, index in label_to_index.items()}
    results = {index_to_label[i]: float(score) for i, score in enumerate(scores)}

    best = np.argmax(scores)
    return results, index_to_label[best], float(scores[best])

# NEW MODULE: Grad-CAM for Heatmap Generation
def make_gradcam_heatmap(model, img_array, pred_index=None):
    """
    Generate a Grad-CAM heatmap for the given image and model.

    Args:
        model: The model to use for prediction
        img_array: The preprocessed image batch as a numpy array
        pred_index: The index of the class to visualize (None means the predicted class)

    Returns:
        The heatmap as a 2-D numpy array scaled to the range 0..1 (all zeros
        when no location contributes positively), or None when no usable
        convolutional layer or gradient is available.
    """
    # Prefer the backbone's final convolution ("top_conv" in the Keras
    # EfficientNet implementation). A plain reverse scan of `model.layers`
    # can instead land on a Conv2D of the small edge branch, which sits
    # closer to the output than the backbone's last convolution.
    try:
        last_conv_layer = model.get_layer("top_conv")
    except ValueError:
        last_conv_layer = None

    if last_conv_layer is None:
        # Generic fallback: last Conv2D among the model's own layers
        for layer in reversed(model.layers):
            if isinstance(layer, tf.keras.layers.Conv2D):
                last_conv_layer = layer
                break

    if last_conv_layer is None:
        # If we can't find a conv layer, look inside a nested backbone model
        for layer in model.layers:
            if isinstance(layer, tf.keras.models.Model):
                for inner_layer in reversed(layer.layers):
                    if isinstance(inner_layer, tf.keras.layers.Conv2D):
                        last_conv_layer = inner_layer
                        break
                break

    if last_conv_layer is None:
        print("Could not find a convolutional layer for Grad-CAM")
        return None

    # Model mapping the input image to (conv activations, predictions).
    # Built outside the tape: only the forward pass needs to be recorded.
    grad_model = tf.keras.models.Model(
        inputs=model.inputs,
        outputs=[model.get_layer(last_conv_layer.name).output, model.output]
    )

    # Gradients are only defined for floating-point tensors.
    model_input = tf.cast(img_array, tf.float32)

    with tf.GradientTape() as tape:
        last_conv_layer_output, preds = grad_model(model_input)

        if pred_index is None:
            pred_index = tf.argmax(preds[0])

        # Score of the class being explained
        class_channel = preds[:, pred_index]

    # Gradient of the class score with respect to the conv feature map
    grads = tape.gradient(class_channel, last_conv_layer_output)
    if grads is None:
        print("Could not compute gradients for Grad-CAM")
        return None

    # Per-channel importance: mean gradient over batch and spatial axes
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    # Weight each channel of the feature map by its importance
    last_conv_layer_output = last_conv_layer_output[0]
    heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)

    # Keep positive evidence only, then scale to 0..1. The maximum is taken
    # AFTER clipping and checked for zero, so an all-non-positive map yields
    # zeros instead of NaN/inf or a sign-flipped map.
    heatmap = tf.maximum(heatmap, 0)
    peak = tf.math.reduce_max(heatmap)
    if peak > 0:
        heatmap = heatmap / peak

    return heatmap.numpy()

def generate_heatmap_overlay(image_path, model, label_to_index):
    """
    Predict the class of an image and blend a Grad-CAM heatmap over it.

    Args:
        image_path: Path to the image file
        model: The model to use for prediction
        label_to_index: Dictionary mapping label names to indices

    Returns:
        Tuple of (original RGB image, heatmap overlay, predicted class, confidence).
        When no heatmap can be produced the original image is returned in
        place of the overlay.

    Raises:
        ValueError: If the image cannot be read from `image_path`.
    """
    bgr_image = cv2.imread(image_path)
    if bgr_image is None:
        raise ValueError(f"Could not load image from {image_path}")

    # RGB copy used for display and blending
    display_rgb = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

    # Model input and prediction
    model_input = preprocess_single_image(image_path)
    class_scores = model.predict(model_input)[0]
    top_index = np.argmax(class_scores)
    confidence = np.max(class_scores)

    label_lookup = {index: label for label, index in label_to_index.items()}
    predicted_class = label_lookup[top_index]

    heatmap = make_gradcam_heatmap(model, model_input, top_index)
    if heatmap is None:
        return display_rgb, display_rgb, predicted_class, confidence

    # Bring the heatmap up to the resolution of the displayed image
    height, width = display_rgb.shape[:2]
    heatmap = cv2.resize(heatmap, (width, height))

    # Jet colour map, alpha channel dropped, scaled to 8-bit
    heatmap_colored = (cm.jet(heatmap)[:, :, :3] * 255).astype(np.uint8)

    # 50/50 blend of heatmap and image, written into a copy of the image
    alpha = 0.5
    overlay = display_rgb.copy()
    cv2.addWeighted(heatmap_colored, alpha, display_rgb, 1 - alpha, 0, overlay)

    return display_rgb, overlay, predicted_class, confidence

# Module 7: Visualization Functions
def plot_training_history(history):
    """Save accuracy and loss curves side by side to "training_history.png".

    Args:
        history: Object whose `history` dict holds the 'accuracy',
            'val_accuracy', 'loss' and 'val_loss' series.
    """
    # (subplot position, train key, train label, val key, val label, title, y label)
    panels = (
        (1, 'accuracy', 'Training Accuracy', 'val_accuracy', 'Validation Accuracy',
         'Model Accuracy', 'Accuracy'),
        (2, 'loss', 'Training Loss', 'val_loss', 'Validation Loss',
         'Model Loss', 'Loss'),
    )

    plt.figure(figsize=(12, 5))

    for position, train_key, train_label, val_key, val_label, title, y_label in panels:
        plt.subplot(1, 2, position)
        plt.plot(history.history[train_key], label=train_label)
        plt.plot(history.history[val_key], label=val_label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()

    plt.tight_layout()
    plt.savefig("training_history.png")
    plt.close()

# Module 8: Gradio Frontend with Heatmap Visualization
def create_frontend(model_path, label_to_index):
    """Create a Gradio frontend for model inference with heatmap visualization.

    Args:
        model_path: Either a path to a saved model (loaded once, when the
            frontend is created) or an already loaded model object.
        label_to_index: Mapping of class name to class index.

    Returns:
        The `gr.Interface`; call `.launch()` on it to start the demo.
    """
    import tempfile  # only needed by the frontend

    # Load the model a single time instead of on every prediction request.
    model = load_model(model_path) if isinstance(model_path, str) else model_path

    def predict_fn(image):
        """Classify an uploaded RGB image and build all demo outputs."""
        if image is None:
            raise gr.Error("Please upload an ultrasound image first.")

        # Write the upload to a unique temporary file (a fixed file name would
        # be shared between concurrent requests) and always remove it again,
        # even when prediction fails.
        temp_fd, temp_path = tempfile.mkstemp(suffix=".jpg")
        os.close(temp_fd)
        try:
            cv2.imwrite(temp_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

            # Make prediction
            results, predicted_class, confidence = predict_image(model, temp_path, label_to_index)

            # Generate heatmap
            original_img, heatmap_overlay, _, _ = generate_heatmap_overlay(temp_path, model, label_to_index)
        finally:
            if os.path.exists(temp_path):
                os.remove(temp_path)

        # Save heatmap for download
        heatmap_path = "heatmap_explanation.png"
        plt.figure(figsize=(12, 6))

        plt.subplot(1, 2, 1)
        plt.imshow(original_img)
        plt.title("Original Image")
        plt.axis('off')

        plt.subplot(1, 2, 2)
        plt.imshow(heatmap_overlay)
        plt.title("Activation Heatmap")
        plt.axis('off')

        plt.suptitle(f"Prediction: {predicted_class} (Confidence: {confidence:.2f})")
        plt.tight_layout()
        plt.savefig(heatmap_path)
        plt.close()

        # Create probability bar chart; the most likely class is drawn in red
        top_probability = max(results.values())
        bar_colors = ['red' if value == top_probability else 'blue' for value in results.values()]
        fig = plt.figure(figsize=(8, 4))
        plt.bar(list(results.keys()), list(results.values()), color=bar_colors)
        plt.ylim(0, 1)
        plt.ylabel('Probability')
        plt.title(f'Prediction: {predicted_class} (Confidence: {confidence:.2f})')
        plt.tight_layout()

        # Create explanation of the heatmap
        explanation = f"""
        ## Detected condition: {predicted_class}
        Confidence: {confidence:.2%}

        ### Explanation:
        The heatmap shows which areas of the image influenced the model's prediction.
        Warmer colors (red/yellow) indicate areas that strongly contributed to the prediction,
        while cooler colors (blue) had less influence.

        For PCOS detection, the model typically focuses on:
        - Follicle distribution and appearance
        - Ovarian size and texture
        - Stromal echogenicity (brightness)

        ### Download Results:
        You can download the heatmap visualization for further analysis.
        """

        # Generate heatmap metadata
        heatmap_metadata = {
            "prediction": predicted_class,
            "confidence": float(confidence),
            "class_probabilities": {k: float(v) for k, v in results.items()},
            "timestamp": pd.Timestamp.now().isoformat()
        }

        # Save metadata
        with open("heatmap_metadata.json", "w") as f:
            json.dump(heatmap_metadata, f, indent=2)

        return fig, heatmap_overlay, explanation, heatmap_path

    # Only offer example images that are actually present on disk.
    example_files = [path for path in ("example1.jpg", "example2.jpg") if os.path.exists(path)]

    # Create Gradio interface
    demo = gr.Interface(
        fn=predict_fn,
        inputs=gr.Image(),
        outputs=[
            gr.Plot(label="Probability Distribution"),
            gr.Image(label="Heatmap Visualization"),
            gr.Markdown(label="Explanation"),
            gr.File(label="Download Heatmap")
        ],
        title="PCOS Detection with Explainable AI",
        description="Upload an ultrasound scan to detect if it shows PCOS or normal condition. The heatmap shows which image areas influenced the model's decision.",
        examples=example_files or None
    )

    return demo

# Main function
def main():
    """Run the whole pipeline: download, prepare, train, evaluate, save, demo."""
    print("Step 1: Downloading datasets...")
    datasets = download_datasets()

    print("Step 2: Preparing data...")
    X, y, label_to_index = prepare_data(datasets)

    # Two stratified 80/20 splits: first hold out the test set, then carve
    # the validation set out of the remaining training data.
    print("Step 3: Splitting data...")
    split_options = {"test_size": 0.2, "random_state": 42}
    X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, **split_options)
    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, stratify=y_train, **split_options)
    print(f"Training set: {X_train.shape}, Validation set: {X_val.shape}, Test set: {X_test.shape}")

    print("Step 4: Building model...")
    model = build_model(INPUT_SHAPE, NUM_CLASSES)
    model.summary()

    print("Step 5: Training model...")
    model, history = train_model(model, X_train, y_train, X_val, y_val)

    print("Step 6: Visualizing training history...")
    plot_training_history(history)

    print("Step 7: Evaluating model...")
    evaluate_model(model, X_test, y_test, label_to_index)

    # Persist the trained model so the demo can later be started without retraining
    saved_model_path = "pcos_detection_model.h5"
    model.save(saved_model_path)
    print(f"Model saved to {saved_model_path}")

    # Serve the in-memory model through the Gradio frontend
    print("Step 8: Creating demo frontend...")
    frontend = create_frontend(model, label_to_index)
    frontend.launch()

# Entry point: run the full pipeline (download, train, evaluate, launch demo)
# when this code is executed as the main program.
# NOTE(review): inside a notebook kernel __name__ is normally "__main__", so
# executing this cell presumably starts training immediately - confirm this is intended.
if __name__ == "__main__":
    main()

# Simplified demo for direct use
def run_demo(model_path="pcos_detection_model.h5"):
    """Launch only the Gradio frontend, using a previously saved model.

    Args:
        model_path: Path of the saved model file. When it does not exist a
            hint is printed and nothing is launched.
    """
    # Guard clause: nothing to serve without a trained model on disk.
    if not os.path.exists(model_path):
        print(f"Model file {model_path} not found. Please run the main function first to train the model.")
        return

    # Class names in index order, as used by the demo.
    label_to_index = {name: position for position, name in enumerate(("normal", "pcos"))}

    frontend = create_frontend(model_path, label_to_index)
    frontend.launch()