In [1]:
import tensorflow as tf
from tensorflow.keras.applications import DenseNet121, ResNet50, Xception
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
import logging
import pickle

# Logging setup: every record is written both to a log file and to the console.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_log_file_handler = logging.FileHandler("cnn_models.log")
_log_console_handler = logging.StreamHandler()

logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    handlers=[_log_file_handler, _log_console_handler],
)

# Module-level logger shared by the functions in this file.
logger = logging.getLogger(__name__)

def create_model(base_model_name, input_shape=(256, 256, 3), classes=2):
    """
    Build and compile a classifier on top of an ImageNet-pretrained backbone

    Parameters:
    -----------
    base_model_name : str
        Backbone to use, matched case-insensitively against
        'densenet121', 'resnet50' and 'xception'
    input_shape : tuple
        Shape of the model input, channels included
    classes : int
        Size of the final softmax layer

    Returns:
    --------
    Model
        Compiled Keras model (Adam, categorical cross-entropy, accuracy)

    Raises:
    -------
    ValueError
        If base_model_name is not one of the supported backbones
    """
    backbone_key = base_model_name.lower()

    # Reject unsupported names before touching any Keras application.
    if backbone_key not in ('densenet121', 'resnet50', 'xception'):
        raise ValueError(f"Unknown base model: {base_model_name}")

    backbone_factories = {
        'densenet121': DenseNet121,
        'resnet50': ResNet50,
        'xception': Xception,
    }
    base_model = backbone_factories[backbone_key](
        weights='imagenet',
        include_top=False,
        input_shape=input_shape,
    )

    # Classification head: pooled features -> dense -> dropout -> softmax
    features = GlobalAveragePooling2D()(base_model.output)
    features = Dense(1024, activation='relu')(features)
    features = Dropout(0.5)(features)
    class_probabilities = Dense(classes, activation='softmax')(features)

    model = Model(inputs=base_model.input, outputs=class_probabilities)

    adam = tf.keras.optimizers.Adam(learning_rate=0.0001)
    model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])

    return model

def train_model(model, dataset_dir, input_shape=(256, 256), batch_size=32, epochs=20):
    """
    Train a CNN model on the dataset

    The images under dataset_dir are split 80/20 into a training and a
    validation subset. Training images are randomly augmented; validation
    images are only rescaled, so validation metrics are measured on
    undistorted data.

    Parameters:
    -----------
    model : Model
        The Keras model to train
    dataset_dir : str
        Path to the dataset directory (one sub-directory per class)
    input_shape : tuple
        Input shape for images (without channels)
    batch_size : int
        Batch size for training
    epochs : int
        Number of epochs to train

    Returns:
    --------
    Model
        The trained Keras model
    History
        Training history

    Raises:
    -------
    ValueError
        If the training or the validation subset contains no images
    """
    validation_fraction = 0.2

    # Augmentation is applied to the training subset only.
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        validation_split=validation_fraction
    )

    # Same rescaling and the same split fraction, but no random transforms:
    # validation through the augmenting generator would score distorted images.
    validation_datagen = ImageDataGenerator(
        rescale=1./255,
        validation_split=validation_fraction
    )

    training_generator = train_datagen.flow_from_directory(
        dataset_dir,
        target_size=input_shape,
        batch_size=batch_size,
        class_mode='categorical',
        subset='training'
    )

    validation_generator = validation_datagen.flow_from_directory(
        dataset_dir,
        target_size=input_shape,
        batch_size=batch_size,
        class_mode='categorical',
        subset='validation'
    )

    if training_generator.samples == 0 or validation_generator.samples == 0:
        raise ValueError(
            f"Not enough images in {dataset_dir}: "
            f"{training_generator.samples} training / "
            f"{validation_generator.samples} validation samples"
        )

    # Ceiling division: include the final partial batch and never ask for
    # zero steps when a subset holds fewer images than one batch.
    steps_per_epoch = -(-training_generator.samples // batch_size)
    validation_steps = -(-validation_generator.samples // batch_size)

    # Train the model
    history = model.fit(
        training_generator,
        steps_per_epoch=steps_per_epoch,
        validation_data=validation_generator,
        validation_steps=validation_steps,
        epochs=epochs
    )

    return model, history

def evaluate_model(model, dataset_dir, input_shape=(256, 256), batch_size=32):
    """
    Evaluate a CNN model on the dataset

    Parameters:
    -----------
    model : Model
        The Keras model to evaluate
    dataset_dir : str
        Path to the dataset directory (one sub-directory per class)
    input_shape : tuple
        Input shape for images (without channels)
    batch_size : int
        Batch size for evaluation

    Returns:
    --------
    dict
        The metrics reported by model.evaluate (keyed by metric name) plus
        'confusion_matrix' and 'classification_report'

    Raises:
    -------
    ValueError
        If no images are found under dataset_dir
    """
    from sklearn.metrics import confusion_matrix, classification_report

    # Create data generator (rescaling only, fixed order so that predictions
    # line up with test_generator.classes)
    test_datagen = ImageDataGenerator(rescale=1./255)

    test_generator = test_datagen.flow_from_directory(
        dataset_dir,
        target_size=input_shape,
        batch_size=batch_size,
        class_mode='categorical',
        shuffle=False
    )

    if test_generator.samples == 0:
        raise ValueError(f"No images found for evaluation in {dataset_dir}")

    # Exactly one pass over the data. `samples // batch_size + 1` would request
    # an extra batch whenever samples is a multiple of batch_size, which
    # double-counts images and can leave y_pred longer than y_true.
    steps = -(-test_generator.samples // batch_size)

    # return_dict avoids depending on model.metrics_names, whose contents
    # differ between Keras versions.
    result = dict(model.evaluate(test_generator, steps=steps, return_dict=True))

    # Restart from the first batch so predictions are in file order.
    test_generator.reset()
    y_pred = model.predict(test_generator, steps=steps)
    y_pred_classes = np.argmax(y_pred, axis=1)

    # Get true labels
    y_true = test_generator.classes

    # Class names ordered by their integer index
    class_indices = test_generator.class_indices
    class_names = sorted(class_indices, key=class_indices.get)
    label_ids = list(range(len(class_names)))

    # Passing labels explicitly keeps the matrix/report shape stable even if a
    # class never appears in y_true or y_pred.
    cm = confusion_matrix(y_true, y_pred_classes, labels=label_ids)
    report = classification_report(
        y_true,
        y_pred_classes,
        labels=label_ids,
        target_names=class_names
    )

    result['confusion_matrix'] = cm
    result['classification_report'] = report

    return result

def create_ensemble_model(models):
    """
    Wrap several models into a single averaging predictor

    Parameters:
    -----------
    models : list
        Models whose predictions are averaged

    Returns:
    --------
    function
        Callable taking one image and returning the averaged prediction
    """
    def ensemble_predict(img):
        """
        Predict on one image with every model and average the outputs

        Parameters:
        -----------
        img : array
            Input image

        Returns:
        --------
        array
            Element-wise mean of the individual model predictions
        """
        # Resize to 256x256, scale to [0, 1] and add the batch axis
        scaled = cv2.resize(img, (256, 256)) / 255.0
        batch = np.expand_dims(scaled, axis=0)

        # One prediction per ensemble member
        member_outputs = [member.predict(batch, verbose=0) for member in models]

        return np.mean(member_outputs, axis=0)

    return ensemble_predict

def generate_gradcam(img, model, layer_name=None):
    """
    Generate a GradCAM heatmap for the image

    Parameters:
    -----------
    img : array
        Input image (normalized), without batch dimension
    model : Model
        The Keras model to use
    layer_name : str or None
        Name of the layer to use for GradCAM. When None, the last layer whose
        name contains 'conv' is used.

    Returns:
    --------
    array
        GradCAM heatmap scaled to [0, 1] and resized to the height and width
        of img

    Raises:
    -------
    ValueError
        If no layer can be selected, or if the prediction has no gradient with
        respect to the selected layer
    """
    # Add the batch dimension and make sure the model receives float32
    img_tensor = tf.cast(np.expand_dims(img, axis=0), tf.float32)

    # If layer_name is not provided, fall back to the last layer whose name
    # contains 'conv'
    if layer_name is None:
        layer_name = next(
            (layer.name for layer in reversed(model.layers) if 'conv' in layer.name),
            None
        )

    if layer_name is None:
        raise ValueError("Could not find a convolutional layer in the model")

    # Model that exposes both the chosen activations and the final prediction.
    # model.inputs is already a list; wrapping it again would nest the inputs.
    grad_model = tf.keras.models.Model(
        inputs=model.inputs,
        outputs=[model.get_layer(layer_name).output, model.output]
    )

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_tensor)
        class_idx = tf.argmax(predictions[0])
        top_class = predictions[:, class_idx]

    # Gradient of the top predicted class w.r.t. the selected activations
    grads = tape.gradient(top_class, conv_outputs)
    if grads is None:
        raise ValueError(
            f"Could not compute gradients with respect to layer '{layer_name}'"
        )

    # One importance weight per channel: mean gradient over batch and space
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    # Weight every channel by its importance (broadcast over the last axis)
    # and average across channels.
    heatmap = tf.reduce_mean(conv_outputs[0] * pooled_grads, axis=-1).numpy()

    # Keep positive evidence only, then scale to [0, 1]. The peak is taken
    # after the ReLU so an all-negative map yields plain zeros.
    heatmap = np.maximum(heatmap, 0)
    peak = heatmap.max()
    if peak > 0:
        heatmap = heatmap / peak

    # Resize to original image size (cv2 expects (width, height))
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))

    return heatmap

def superimpose_heatmap(img, heatmap, alpha=0.6):
    """
    Superimpose a heatmap on the original image

    Parameters:
    -----------
    img : array
        Original image in RGB channel order, either uint8 or float in [0, 1]
    heatmap : array
        GradCAM heatmap with values in [0, 1]
    alpha : float
        Transparency factor (weight of the heatmap in the blend)

    Returns:
    --------
    array
        Superimposed uint8 image in RGB channel order
    """
    # Convert original image to uint8 if it's not already; clip first so
    # out-of-range values saturate instead of wrapping around.
    if img.dtype != np.uint8:
        img = np.uint8(255 * np.clip(img, 0.0, 1.0))

    heatmap = np.uint8(255 * np.clip(heatmap, 0.0, 1.0))

    # addWeighted needs both inputs to have the same height and width
    if heatmap.shape[:2] != img.shape[:2]:
        heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))

    # applyColorMap produces a BGR image; convert it to RGB so it matches the
    # channel order of img, otherwise red and blue are swapped in the overlay.
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Superimpose the heatmap on original image
    superimposed_img = cv2.addWeighted(img, 1.0 - alpha, heatmap, alpha, 0)

    return superimposed_img

def identify_affected_areas(heatmap, threshold=0.7):
    """
    Map the hot region of a heatmap to coarse, position-based region labels

    The centroid of all pixels above the threshold is located and described
    by which half / third of the image it falls in.

    Parameters:
    -----------
    heatmap : array
        2-D GradCAM heatmap
    threshold : float
        Pixels strictly above this value count as affected

    Returns:
    --------
    list
        Region labels followed by a disclaimer, or a single
        "nothing detected" entry when no pixel exceeds the threshold
    """
    rows, cols = np.where(heatmap > threshold)
    if rows.size == 0:
        return ["No specific affected areas detected"]

    # Centroid of the above-threshold pixels
    center_row = np.mean(rows)
    center_col = np.mean(cols)

    height, width = heatmap.shape

    # Halves of the image: left/right and top/bottom
    side = "Left hemisphere" if center_col < width/2 else "Right hemisphere"
    depth = "Anterior region" if center_row < height/2 else "Posterior region"

    # Thirds of the image: both outer columns share the same vertical mapping,
    # the middle column is split at half height.
    in_outer_column = center_col < width/3 or center_col > 2*width/3
    if in_outer_column:
        if center_row < height/3:
            lobe = "Possibly frontal lobe"
        elif center_row > 2*height/3:
            lobe = "Possibly occipital lobe"
        else:
            lobe = "Possibly temporal lobe"
    elif center_row < height/2:
        lobe = "Possibly parietal lobe"
    else:
        lobe = "Possibly cerebellum or brain stem"

    disclaimer = "Note: This is a preliminary estimate and should be confirmed by a medical professional"

    return [side, depth, lobe, disclaimer]

def predict_stroke(image_path, models=None, models_dir='models'):
    """
    Predict stroke from a brain CT scan image

    Parameters:
    -----------
    image_path : str
        Path to the input image
    models : dict or None
        Dictionary of pre-loaded models keyed by model name. When None, the
        models are loaded from models_dir via load_models.
    models_dir : str
        Directory containing saved models

    Returns:
    --------
    dict
        'predicted_class', 'confidence', 'model_predictions' (per-model
        probabilities), 'heatmap_image' (overlay), 'heatmap' and
        'affected_areas'

    Raises:
    -------
    ValueError
        If the image cannot be read or no model is available. Any other
        exception raised while predicting is logged and re-raised.
    """
    try:
        # Load image
        img = cv2.imread(image_path)
        if img is None:
            raise ValueError(f"Failed to load image: {image_path}")

        # Convert BGR to RGB
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Resize and normalize to [0, 1]
        img_resized = cv2.resize(img, (256, 256))
        img_normalized = img_resized / 255.0

        # Expand dimensions to match model input shape
        img_batch = np.expand_dims(img_normalized, axis=0)

        # Load models if not provided. Delegating to load_models keeps a
        # single loading policy for the whole file; a model file that fails
        # to load is logged and skipped there instead of aborting here.
        if models is None:
            models = load_models(models_dir)

        if not models:
            raise ValueError("No models available for prediction")

        # Make predictions with each model.
        # Output index 0 is read as "normal" and index 1 as "stroke".
        predictions = {}
        for model_name, model in models.items():
            pred = model.predict(img_batch, verbose=0)[0]
            predictions[model_name] = {
                'normal_prob': float(pred[0]),
                'stroke_prob': float(pred[1])
            }

        # Calculate ensemble prediction (average of all models)
        ensemble_normal_prob = np.mean([pred['normal_prob'] for pred in predictions.values()])
        ensemble_stroke_prob = np.mean([pred['stroke_prob'] for pred in predictions.values()])

        # Determine the predicted class
        is_stroke = ensemble_stroke_prob > 0.5
        predicted_class = 'Stroke' if is_stroke else 'Normal'
        winning_key = 'stroke_prob' if is_stroke else 'normal_prob'

        # Pick the model that is most confident about the ensemble's class
        most_confident_model_name = max(
            predictions,
            key=lambda name: predictions[name][winning_key]
        )
        most_confident_model = models[most_confident_model_name]

        # Create GradCAM for visualization
        heatmap = generate_gradcam(img_normalized, most_confident_model)

        # Create superimposed image
        superimposed_img = superimpose_heatmap(img_resized, heatmap)

        # Identify affected areas
        affected_areas = identify_affected_areas(heatmap)

        return {
            'predicted_class': predicted_class,
            'confidence': float(ensemble_stroke_prob if is_stroke else ensemble_normal_prob),
            'model_predictions': predictions,
            'heatmap_image': superimposed_img,
            'heatmap': heatmap,
            'affected_areas': affected_areas
        }

    except Exception as e:
        # logger.exception keeps the traceback in the log before re-raising
        logger.exception(f"Error in predict_stroke: {str(e)}")
        raise

def load_models(models_dir='models'):
    """
    Load trained CNN models

    For every known model name the pickle file is tried first; if it is
    missing or cannot be loaded, the .keras file is used instead. Models that
    cannot be loaded from either file are logged and left out of the result.

    Parameters:
    -----------
    models_dir : str
        Directory containing the saved models

    Returns:
    --------
    dict
        Dictionary of loaded models keyed by model name (possibly empty)
    """
    models = {}
    model_names = ['densenet121', 'resnet50', 'xception']

    for model_name in model_names:
        pickle_path = os.path.join(models_dir, f"stroke_{model_name}.pkl")
        keras_path = os.path.join(models_dir, f"stroke_{model_name}.keras")
        model = None

        # Try loading from pickle first
        if os.path.exists(pickle_path):
            logger.info(f"Loading {model_name} model from pickle: {pickle_path}")
            try:
                # SECURITY: unpickling executes arbitrary code from the file.
                # Only load pickle files this project wrote itself.
                with open(pickle_path, 'rb') as f:
                    model = pickle.load(f)
                logger.info(f"Successfully loaded {model_name} model from pickle")
            except Exception as e:
                logger.error(f"Error loading {model_name} model from pickle: {str(e)}")
                model = None

        # Fall back to the keras file when there is no usable pickle
        if model is None and os.path.exists(keras_path):
            logger.info(f"Loading {model_name} model from keras: {keras_path}")
            try:
                model = tf.keras.models.load_model(keras_path)
                logger.info(f"Successfully loaded {model_name} model from keras")
            except Exception as e:
                logger.error(f"Error loading {model_name} model: {str(e)}")
                model = None

        if model is not None:
            models[model_name] = model

    if not models:
        logger.warning("No models were loaded. Make sure the model files exist.")

    return models

def save_models_as_pickle(models, models_dir='models'):
    """
    Save models in pickle format for use in the web app

    Saving is best-effort: a model that cannot be pickled is logged and
    skipped. Each file is written to a temporary path and renamed only after
    the dump succeeded, so a failed save never leaves a truncated .pkl file
    behind.

    Parameters:
    -----------
    models : dict
        Dictionary of models to save, keyed by model name
    models_dir : str
        Directory to save the models (created if missing)
    """
    os.makedirs(models_dir, exist_ok=True)

    for model_name, model in models.items():
        pickle_path = os.path.join(models_dir, f"stroke_{model_name}.pkl")
        tmp_path = pickle_path + ".tmp"
        logger.info(f"Saving {model_name} model to pickle: {pickle_path}")

        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(model, f)
            # Publish the finished file in one step
            os.replace(tmp_path, pickle_path)
            logger.info(f"Successfully saved {model_name} model to pickle")
        except Exception as e:
            logger.error(f"Error saving {model_name} model to pickle: {str(e)}")
            # Remove the partial temporary file, if one was created
            try:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temporary file {tmp_path}: {str(cleanup_error)}")

def _plot_training_history(history, model_name, models_dir):
    """Save the accuracy and loss curves of one training run as a PNG."""
    plt.figure(figsize=(12, 4))

    panels = (
        (1, 'accuracy', 'val_accuracy', 'Accuracy'),
        (2, 'loss', 'val_loss', 'Loss'),
    )
    for position, train_key, val_key, label in panels:
        plt.subplot(1, 2, position)
        plt.plot(history.history[train_key], label='Train')
        plt.plot(history.history[val_key], label='Validation')
        plt.title(f'{model_name.upper()} - {label}')
        plt.xlabel('Epoch')
        plt.ylabel(label)
        plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(models_dir, f"{model_name}_history.png"))
    plt.close()


def train_and_save_all_models(dataset_dir, models_dir='models'):
    """
    Train and save all three models (DenseNet121, ResNet50, Xception)

    For each model: train, save in keras format, save as pickle
    (best-effort), save the training curves and log evaluation metrics.

    Parameters:
    -----------
    dataset_dir : str
        Path to the dataset directory
    models_dir : str
        Directory to save the models

    Returns:
    --------
    dict
        Dictionary containing trained models
    """
    # Create models directory if it doesn't exist
    os.makedirs(models_dir, exist_ok=True)

    # Models to train
    model_names = ['densenet121', 'resnet50', 'xception']
    trained_models = {}

    for model_name in model_names:
        logger.info(f"Training {model_name.upper()} model...")

        # Create and train the model
        model = create_model(model_name)
        model, history = train_model(model, dataset_dir)

        # Save the model in keras format
        keras_path = os.path.join(models_dir, f"stroke_{model_name}.keras")
        model.save(keras_path)
        logger.info(f"Model saved to {keras_path}")

        # Also save as pickle for web app. The shared helper logs failures
        # instead of raising, so a model that cannot be pickled does not
        # abort the remaining training runs.
        save_models_as_pickle({model_name: model}, models_dir)

        # Plot and save training history
        _plot_training_history(history, model_name, models_dir)

        # Store the trained model
        trained_models[model_name] = model

        # Evaluate the model.
        # NOTE(review): this evaluates on the same directory used for
        # training, so the numbers are not a held-out test score.
        logger.info(f"Evaluating {model_name.upper()} model...")
        result = evaluate_model(model, dataset_dir)

        # Log scalar metrics only when present in the result
        for metric_key, metric_label in (('accuracy', 'Accuracy'), ('loss', 'Loss')):
            if metric_key in result:
                logger.info(f"{metric_label}: {result[metric_key]:.4f}")
            else:
                logger.warning(f"Metric '{metric_key}' missing from evaluation result")
        logger.info(f"Confusion Matrix:\n{result['confusion_matrix']}")
        logger.info(f"Classification Report:\n{result['classification_report']}")

    return trained_models

def test_prediction_on_samples(trained_models, dataset_dir, models_dir='models'):
    """
    Test the prediction on sample images

    Runs predict_stroke on two fixed sample files, logs the results and
    writes the heatmap overlay to '<models_dir>/visualizations'. Missing
    samples and prediction errors are logged; nothing is raised.

    Parameters:
    -----------
    trained_models : dict
        Dictionary of trained models
    dataset_dir : str
        Path to the dataset directory
    models_dir : str
        Directory containing saved models
    """
    sample_paths = [
        os.path.join(dataset_dir, 'Normal', '1 (1).jpg'),
        os.path.join(dataset_dir, 'Stroke', '2 (1).jpg')
    ]

    for sample_path in sample_paths:
        if not os.path.exists(sample_path):
            logger.warning(f"Sample image not found: {sample_path}")
            continue

        logger.info(f"Predicting on {sample_path}...")
        try:
            result = predict_stroke(sample_path, trained_models)

            logger.info(f"Predicted class: {result['predicted_class']}")
            logger.info(f"Confidence: {result['confidence']:.4f}")

            for model_name, pred in result['model_predictions'].items():
                logger.info(f"{model_name}: Normal={pred['normal_prob']:.4f}, Stroke={pred['stroke_prob']:.4f}")

            # Log affected areas
            logger.info("Potentially affected brain areas:")
            for area in result['affected_areas']:
                logger.info(f"- {area}")

            # Save visualization
            output_dir = os.path.join(models_dir, 'visualizations')
            os.makedirs(output_dir, exist_ok=True)

            # Get filename without path and extension
            filename = os.path.splitext(os.path.basename(sample_path))[0]
            heatmap_path = os.path.join(output_dir, f"{filename}_heatmap.png")

            # The overlay is RGB; imwrite expects BGR
            heatmap_bgr = cv2.cvtColor(result['heatmap_image'], cv2.COLOR_RGB2BGR)

            # imwrite reports failure through its return value, not by
            # raising, so check it before claiming success.
            if cv2.imwrite(heatmap_path, heatmap_bgr):
                logger.info(f"Heatmap saved to {heatmap_path}")
            else:
                logger.error(f"Failed to write heatmap to {heatmap_path}")

        except Exception as e:
            logger.error(f"Error predicting on {sample_path}: {str(e)}")

def main():
    """
    Entry point: obtain the three models, then run the sample predictions

    Models are loaded from pickle files when all of them exist, otherwise
    loaded from keras files (and re-saved as pickle) when all of those exist,
    otherwise trained from scratch.
    """
    dataset_dir = 'Dataset'
    models_dir = 'models'

    # Make sure the models directory is there before probing it
    os.makedirs(models_dir, exist_ok=True)

    architectures = ['densenet121', 'resnet50', 'xception']

    def every_file_present(extension):
        # True only if a saved file with this extension exists for each model
        filenames = [f"stroke_{name}.{extension}" for name in architectures]
        return all(os.path.exists(os.path.join(models_dir, fname)) for fname in filenames)

    pickles_exist = every_file_present('pkl')
    keras_exist = every_file_present('keras')

    if pickles_exist:
        logger.info("Models already exist in pickle format, loading them...")
        trained_models = load_models(models_dir)
    elif keras_exist:
        logger.info("Models exist in keras format, loading and converting to pickle...")
        trained_models = load_models(models_dir)
        save_models_as_pickle(trained_models, models_dir)
    else:
        logger.info("Training new models...")
        trained_models = train_and_save_all_models(dataset_dir, models_dir)

    # Run the prediction smoke test on the sample images
    test_prediction_on_samples(trained_models, dataset_dir, models_dir)

    logger.info("Process completed successfully!")

# Run the pipeline only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()




2025-04-11 19:06:35,876 - __main__ - INFO - Models exist in keras format, loading and converting to pickle...
2025-04-11 19:06:35,877 - __main__ - INFO - Loading densenet121 model from keras: models\stroke_densenet121.keras












2025-04-11 19:07:56,094 - __main__ - INFO - Successfully loaded densenet121 model from keras
2025-04-11 19:07:56,094 - __main__ - INFO - Loading resnet50 model from keras: models\stroke_resnet50.keras
2025-04-11 19:09:52,685 - __main__ - INFO - Successfully loaded resnet50 model from keras
2025-04-11 19:09:52,685 - __main__ - INFO - Loading xception model from keras: models\stroke_xception.keras
2025-04-11 19:11:10,257 - __main__ - INFO - Successfully loaded xception model from keras
2025-04-11 19:11:10,257 - __main__ - INFO - Saving densenet121 model to pickle: models\stroke_densenet121.pkl
2025-04-11 19:11:13,862 - __main__ - INFO - Successfully saved densenet121 model to pickle
2025-04-11 19:11:13,862 - __main__ - INFO - Saving resnet50 model to pickle: models\stroke_resnet50.pkl
2025-04-11 19:11:19,593 - __main__ - INFO - Successfully saved resnet50 model to pickle
2025-04-11 19:11:19,593 - __main__ - INFO - Saving xception model to pickle: models\stroke_xception.pkl
2025-04-11 19