<a href="https://colab.research.google.com/github/Alpharages/CardioXNet/blob/main/cardiac_failure_detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AI-Based Detection of Cardiac Failure Using Chest X-Ray Images

This notebook implements a custom Convolutional Neural Network (CNN) for detecting cardiac failure using chest X-ray images. The model is designed to provide early, accessible, and cost-effective diagnosis of cardiac failure, potentially improving patient outcomes through timely intervention.

## Table of Contents
1. [Setup](#setup)
2. [Data Preparation](#data-preparation)
3. [Data Preprocessing](#data-preprocessing)
4. [Model Building](#model-building)
5. [Training](#training)
6. [Evaluation](#evaluation)
7. [Visualization](#visualization)
8. [Inference](#inference)

## Setup <a name="setup"></a>

First, let's set up our environment by installing the required dependencies and cloning the repository.

In [None]:
# Check if running in Colab
import sys
IN_COLAB = 'google.colab' in sys.modules
print(f"Running in Colab: {IN_COLAB}")

In [None]:
# Install dependencies
!pip install -q tensorflow numpy pandas matplotlib scikit-learn opencv-python pillow albumentations tqdm tensorboard python-dotenv seaborn

In [None]:
# Clone the repository if in Colab
if IN_COLAB:
    !git clone https://github.com/Alpharages/CardioXNet.git
    %cd CardioXNet
    
    # Create necessary directories
    !mkdir -p data/images
    !mkdir -p data/processed
    !mkdir -p results/logs

In [None]:
# Import necessary libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from pathlib import Path
from datetime import datetime
from dotenv import load_dotenv
# google.colab only exists inside Colab; `files` is used for uploads below
if IN_COLAB:
    from google.colab import files

# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

## Data Preparation <a name="data-preparation"></a>

For this project, we'll use the ChestX-ray8 dataset from NIH. In Colab, we have two options:
1. Upload a small sample dataset directly to Colab
2. Mount Google Drive and access a larger dataset stored there

Let's implement both options:

In [None]:
# Option 1: Upload a small sample dataset directly to Colab
def upload_sample_data():
    print("Please upload your metadata file (filename_label.csv):")
    uploaded = files.upload()
    for filename in uploaded.keys():
        !mv "{filename}" data/filename_label.csv
    
    print("\nPlease upload your X-ray images (you can select multiple files):")
    uploaded = files.upload()
    for filename in uploaded.keys():
        !mv "{filename}" data/images/
    
    print(f"Uploaded {len(uploaded)} images to data/images/")
    return Path('data/filename_label.csv'), Path('data/images')

In [None]:
# Option 2: Mount Google Drive and access data stored there
def mount_drive_data():
    from google.colab import drive
    drive.mount('/content/drive')
    
    # Specify the path to your data in Google Drive
    drive_metadata_path = input("Enter the path to your metadata file in Google Drive: ")
    drive_images_path = input("Enter the path to your images folder in Google Drive: ")
    
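    # Create symbolic links to the data.
    # data/images already exists as an empty directory (created during setup),
    # so remove it first; otherwise the link would be created inside it.
    !ln -sf "{drive_metadata_path}" data/filename_label.csv
    !rmdir data/images 2>/dev/null
    !ln -sfn "{drive_images_path}" data/images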
    
    return Path('data/filename_label.csv'), Path('data/images')

In [None]:
# Choose your data source
data_source = input("Choose your data source (1 for upload, 2 for Google Drive): ")

if data_source == "1":
    metadata_path, image_dir = upload_sample_data()
elif data_source == "2":
    metadata_path, image_dir = mount_drive_data()
else:
    print("Invalid choice. Using upload option by default.")
    metadata_path, image_dir = upload_sample_data()

In [None]:
# Create a .env file with the paths
with open('.env', 'w') as f:
    f.write("DATA_DIR=data\n")
    f.write("PROCESSED_DATA_DIR=data/processed\n")
    f.write(f"IMAGE_DIR={image_dir}\n")
    f.write(f"METADATA_FILE={metadata_path}\n")
    f.write("RESULTS_DIR=results\n")

# Load environment variables
load_dotenv()
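
Before preprocessing, it can help to confirm that every row of the metadata file points to an image that actually exists. The sketch below is a minimal check, not part of the repository; the helper name `find_missing_images` and the `filename` column name are assumptions, so adjust them to match your CSV.

```python
import csv
from pathlib import Path


def find_missing_images(metadata_file, images_dir, filename_column="filename"):
    """Return metadata filenames that have no matching file in images_dir.

    `filename_column` is an assumed column name; change it to match the CSV.
    """
    images_dir = Path(images_dir)
    missing = []
    with open(metadata_file, newline="") as f:
        for row in csv.DictReader(f):
            name = row[filename_column]
            if not (images_dir / name).is_file():
                missing.append(name)
    return missing


# Example call with the paths chosen earlier in the notebook:
# missing = find_missing_images(metadata_path, image_dir)
# print(f"{len(missing)} metadata rows have no image file")
```

An empty result means every listed image was found; anything else is worth fixing before the preprocessing step runs.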

## Data Preprocessing <a name="data-preprocessing"></a>

Next, we run the preprocessing pipeline through the repository's `ChestXRayPreprocessor` class. The pipeline applies several augmentations (vertical flips, affine transformations, CLAHE, coarse dropout, and Gaussian noise) to improve model generalization.
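
To make three of these augmentations concrete, here is a small NumPy sketch of a vertical flip, Gaussian noise, and coarse dropout. It is an illustration only: the function name `augment` and every parameter value are assumptions, it is not the code inside `ChestXRayPreprocessor`, and CLAHE and affine transforms are left out because they need OpenCV.

```python
import numpy as np


def augment(image, rng, noise_std=0.02, holes=4, hole_size=16):
    """Apply a random vertical flip, Gaussian noise, and coarse dropout.

    `image` is a float array in [0, 1] with shape (H, W) or (H, W, C).
    All parameter values are illustrative, not the preprocessor's settings.
    """
    out = image.copy()

    # Vertical flip with probability 0.5
    if rng.random() < 0.5:
        out = out[::-1].copy()

    # Additive Gaussian noise, clipped back into the valid range
    out = np.clip(out + rng.normal(0.0, noise_std, size=out.shape), 0.0, 1.0)

    # Coarse dropout: zero out a few square patches at random positions
    h, w = out.shape[:2]
    for _ in range(holes):
        top = rng.integers(0, h - hole_size + 1)
        left = rng.integers(0, w - hole_size + 1)
        out[top:top + hole_size, left:left + hole_size] = 0.0

    return out


rng = np.random.default_rng(42)
image = rng.random((224, 224))
augmented = augment(image, rng)
print(augmented.shape, augmented.min(), augmented.max())
```

Augmentations like these are normally applied only to the training split, so that validation and test images stay unmodified.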

In [None]:
# Import the preprocessor
from src.data.preprocess import ChestXRayPreprocessor

# Initialize the preprocessor
preprocessor = ChestXRayPreprocessor()

# Preprocess the dataset
print("Starting dataset preprocessing...")
processed_metadata = preprocessor.preprocess_dataset(metadata_path, image_dir)
print(f"Preprocessing complete. Processed {len(processed_metadata)} images.")

In [None]:
# Create TensorFlow datasets
from src.training.train import create_dataset

# Path to the processed metadata
processed_metadata_path = os.path.join(os.getenv('PROCESSED_DATA_DIR', 'data/processed'), 'processed_metadata.csv')

# Create datasets
print("Creating datasets...")
train_dataset = create_dataset(
    processed_metadata_path,
    batch_size=32,
    split='train'
)

val_dataset = create_dataset(
    processed_metadata_path,
    batch_size=32,
    split='val'
)

test_dataset = create_dataset(
    processed_metadata_path,
    batch_size=32,
    split='test'
)

## Model Building <a name="model-building"></a>

Next, we build the custom CNN for cardiac failure detection. The architecture uses two convolutional layers per block, higher filter counts, and larger dense layers to increase its capacity for feature extraction.
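
To see what "increased capacity" means in numbers, the sketch below counts the trainable parameters of a block with one versus two stacked 3x3 convolutions. The channel and filter counts are hypothetical values chosen for illustration; the actual layer sizes are whatever `model.model.summary()` reports.

```python
def conv2d_params(in_channels, out_channels, kernel_size=3):
    """Trainable parameters of a 2D convolution with bias."""
    return (kernel_size * kernel_size * in_channels + 1) * out_channels


def block_params(in_channels, filters, convs_per_block):
    """Parameters of a block made of stacked 3x3 convolutions.

    The first convolution maps in_channels -> filters; the rest map
    filters -> filters.
    """
    total = conv2d_params(in_channels, filters)
    total += (convs_per_block - 1) * conv2d_params(filters, filters)
    return total


# Hypothetical sizes, used only to show how the count scales
single = block_params(in_channels=32, filters=64, convs_per_block=1)
double = block_params(in_channels=32, filters=64, convs_per_block=2)
print(f"single conv block: {single} parameters")
print(f"double conv block: {double} parameters")
```

With these example sizes, the second convolution adds more parameters than the first because it reads 64 input channels instead of 32.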

In [None]:
# Import the model
from src.models.cardiac_model import CardiacFailureModel

# Initialize the model
print("Initializing model...")
model = CardiacFailureModel()

# Print model summary
model.model.summary()

In [None]:
# Alternatively, you can use a pre-trained model
use_pretrained = input("Do you want to use a pre-trained model? (y/n): ")

if use_pretrained.lower() == 'y':
    # Choose a pre-trained model
    print("Available pre-trained models:")
    print("1. EfficientNetB0")
    print("2. ResNet50")
    print("3. DenseNet121")
    print("4. MobileNetV2")
    
    model_choice = input("Choose a model (1-4): ")
    
    if model_choice == "1":
        base_model_name = 'efficientnetb0'
    elif model_choice == "2":
        base_model_name = 'resnet50'
    elif model_choice == "3":
        base_model_name = 'densenet121'
    elif model_choice == "4":
        base_model_name = 'mobilenetv2'
    else:
        print("Invalid choice. Using EfficientNetB0 by default.")
        base_model_name = 'efficientnetb0'
    
    # Create the pre-trained model
    model = CardiacFailureModel()
    model.create_pretrained_model(base_model_name=base_model_name, fine_tune_layers=10)
    
    # Print model summary
    model.model.summary()

## Training <a name="training"></a>

Next, we train the model on the preprocessed data. The training setup combines class-imbalance handling, cosine-decay learning-rate scheduling, checkpointing and early-stopping callbacks, and TensorBoard logging.
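
Two of these techniques reduce to short formulas. The sketch below shows one common way to compute balanced class weights and a cosine-decay learning rate in plain Python. Both functions are illustrative assumptions about how such techniques are typically defined, not the code inside `CardiacFailureModel`, and the label counts and learning rates are made up.

```python
import math
from collections import Counter


def balanced_class_weights(labels):
    """Weight each class by n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * count) for cls, count in counts.items()}


def cosine_decay(step, total_steps, initial_lr, min_lr=0.0):
    """Learning rate that follows half a cosine wave from initial_lr to min_lr."""
    progress = min(step, total_steps) / total_steps
    return min_lr + 0.5 * (initial_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Toy label list: 80 negatives and 20 positives
weights = balanced_class_weights([0] * 80 + [1] * 20)
print(weights)

# Learning rate at the start, middle, and end of a 100-step schedule
print([cosine_decay(step, 100, 1e-3) for step in (0, 50, 100)])
```

A dictionary shaped like `weights` is what Keras `fit` accepts through its `class_weight` argument, so the rarer class contributes more to the loss.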

In [None]:
# Set up TensorBoard
# Load the TensorBoard extension
%load_ext tensorboard

# Create TensorBoard log directory with timestamp
current_time = datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = os.path.join(os.getenv('RESULTS_DIR', 'results'), 'logs', current_time)
os.makedirs(log_dir, exist_ok=True)

# Create TensorBoard callback with enhanced configuration
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,  # Log weight histograms every epoch
    write_graph=True,  # Log model graph
    write_images=True,  # Log weight images
    update_freq='epoch',  # Log metrics every epoch
    profile_batch='500,520'  # Profile a few batches
)

# Launch TensorBoard
%tensorboard --logdir={log_dir}

In [None]:
# Set up callbacks
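results_dir = Path(os.getenv('RESULTS_DIR', 'results'))
results_dir.mkdir(parents=True, exist_ok=True)

callbacks = [
    # Keep the checkpoint with the highest validation AUC.
    # Keras 3 (bundled with TensorFlow 2.16+) requires a `.keras` or `.h5`
    # suffix here and no longer accepts a `save_format` argument.
    tf.keras.callbacks.ModelCheckpoint(
        filepath=str(results_dir / 'best_model.keras'),
        save_best_only=True,
        monitor='val_auc',
        mode='max'
    ),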
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True
    ),
    tensorboard_callback,
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.2,
        patience=3,
        min_lr=1e-6
    )
]
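
To build intuition for how the `patience` settings above interact, the sketch below replays a list of validation losses through simplified versions of the early-stopping and reduce-on-plateau rules. It is an approximation written for this explanation: the function `simulate_callbacks` is hypothetical, it ignores Keras details such as `min_delta` and `cooldown`, and the loss values are invented.

```python
def simulate_callbacks(val_losses, lr=1e-3, stop_patience=5,
                       lr_patience=3, factor=0.2, min_lr=1e-6):
    """Replay validation losses through simplified callback rules.

    Returns (stopped_epoch, lr_history), where stopped_epoch is the
    zero-based epoch at which training would stop (or None), and
    lr_history holds the learning rate in effect after each epoch.
    """
    best = float("inf")
    stop_wait = 0  # epochs without improvement, for early stopping
    lr_wait = 0    # epochs without improvement, for the LR reduction
    lr_history = []

    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            stop_wait = 0
            lr_wait = 0
        else:
            stop_wait += 1
            lr_wait += 1
            if lr_wait >= lr_patience:
                lr = max(lr * factor, min_lr)
                lr_wait = 0
        lr_history.append(lr)
        if stop_wait >= stop_patience:
            return epoch, lr_history

    return None, lr_history


losses = [1.0, 0.8, 0.9, 0.9, 0.9, 0.9, 0.9, 0.7]
stopped_epoch, lrs = simulate_callbacks(losses)
print(stopped_epoch, lrs)
```

In this toy run the learning rate is cut after three stagnant epochs and training stops after five, so the final improvement in the list is never reached. That trade-off is the reason to keep the early-stopping patience larger than the learning-rate patience.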

In [None]:
# Train the model
print("Starting training...")

if use_pretrained.lower() == 'y':
    # Fine-tune the pre-trained model
    history1, history2 = model.fine_tune(
        train_dataset,
        val_dataset,
        epochs=20,  # Initial training epochs
        callbacks=callbacks,
        fine_tune_epochs=10  # Fine-tuning epochs
    )
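    # Combine both phases into one history for evaluation.
    # `history` aliases `history2`; the merged lists replace its per-phase lists.
    history = history2
    history.history.update({
        k: history1.history[k] + history2.history.get(k, [])
        for k in history1.history
    })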
else:
    # Train the custom model
    history = model.train(
        train_dataset,
        val_dataset,
        epochs=50,
        callbacks=callbacks
    )

# Save the final model
model.save(results_dir / 'final_model')
print("Training complete!")

## Evaluation <a name="evaluation"></a>

Next, we evaluate the trained model on the held-out test dataset. The evaluation reports accuracy, precision, recall, F1 score, ROC AUC, and PR AUC, and produces plots for assessing model performance.
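
As a reference for how several of these metrics are defined, here is a dependency-free sketch that computes accuracy, precision, recall, F1, and ROC AUC from labels and predicted scores. The function names and the toy data are assumptions made for illustration; the repository's `evaluate_model` may compute its metrics differently.

```python
def binary_metrics(y_true, y_score, threshold=0.5):
    """Accuracy, precision, recall, and F1 for binary labels and scores."""
    y_pred = [1 if s > threshold else 0 for s in y_score]
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    accuracy = (tp + tn) / len(y_true)
    return {"accuracy": accuracy, "precision": precision,
            "recall": recall, "f1": f1}


def roc_auc(y_true, y_score):
    """ROC AUC as the chance that a positive outranks a negative (ties count half)."""
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Toy labels and scores, invented for this example
y_true = [0, 0, 1, 1, 0, 1]
y_score = [0.1, 0.6, 0.8, 0.4, 0.2, 0.9]
print(binary_metrics(y_true, y_score))
print(roc_auc(y_true, y_score))
```

Precision, recall, and F1 depend on the 0.5 threshold, while ROC AUC depends only on how the scores rank positives against negatives.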

In [None]:
# Import the evaluation module
from src.utils.evaluation import evaluate_model

# Evaluate the model
print("Evaluating model...")
metrics_summary = evaluate_model(
    model.model,
    test_dataset,
    history=history,
    results_dir=results_dir
)

# Print summary metrics
print("\nModel Performance Summary:")
for metric, value in metrics_summary.items():
    print(f"{metric}: {value:.4f}")

## Visualization <a name="visualization"></a>

Let's visualize the evaluation results.

In [None]:
# Display the evaluation plots
def display_evaluation_plots():
    plot_files = [
        'confusion_matrix.png',
        'roc_curve.png',
        'precision_recall_curve.png',
        'prediction_distribution.png',
        'training_loss.png',
        'training_accuracy.png',
        'training_auc.png'
    ]
    
    for plot_file in plot_files:
        plot_path = results_dir / plot_file
        if plot_path.exists():
            plt.figure(figsize=(10, 8))
            img = plt.imread(plot_path)
            plt.imshow(img)
            plt.axis('off')
            plt.title(plot_file.replace('.png', '').replace('_', ' ').title())
            plt.show()
        else:
            print(f"Plot file not found: {plot_path}")

# Display the plots
display_evaluation_plots()

## Inference <a name="inference"></a>

Finally, let's implement a function to make predictions on new X-ray images.

In [None]:
# Function to preprocess a single image for prediction
def preprocess_image_for_prediction(image_path):
    import cv2
    import albumentations as A
    
    # Define the transformation
    transform = A.Compose([
        A.Resize(224, 224),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
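    # Read and preprocess the image
    image = cv2.imread(str(image_path))
    if image is None:
        # cv2.imread returns None instead of raising when the file cannot be read
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)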
    transformed = transform(image=image)
    processed_image = transformed['image']
    
    # Add batch dimension
    return np.expand_dims(processed_image, axis=0)
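
The normalization step above is easy to check by hand. Assuming `A.Normalize` keeps its default `max_pixel_value` of 255, it scales pixels to [0, 1] and then standardizes each channel with the given mean and standard deviation. The NumPy sketch below reproduces that arithmetic; the helper `normalize_imagenet` is a hypothetical name used only here.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def normalize_imagenet(image_uint8):
    """Scale an (H, W, 3) uint8 RGB image to [0, 1], then standardize per channel."""
    scaled = image_uint8.astype(np.float32) / 255.0
    return ((scaled - IMAGENET_MEAN) / IMAGENET_STD).astype(np.float32)


# A mid-gray 2x2 image: every channel value is 128
gray = np.full((2, 2, 3), 128, dtype=np.uint8)
batch = np.expand_dims(normalize_imagenet(gray), axis=0)
print(batch.shape)
print(batch[0, 0, 0])
```

Whatever normalization is used at inference time should match the one used during training; otherwise the model sees inputs on a different scale than it learned from.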

In [None]:
# Function to make predictions on new images
def predict_cardiac_failure(model, image_path):
    # Preprocess the image
    processed_image = preprocess_image_for_prediction(image_path)
    
    # Make prediction
    prediction = model.predict(processed_image)[0][0]
    
    # Display the image and prediction
    plt.figure(figsize=(8, 8))
    img = plt.imread(image_path)
    plt.imshow(img, cmap='gray')
    plt.title(f"Prediction: {'Cardiac Failure' if prediction > 0.5 else 'Normal'} ({prediction:.4f})")
    plt.axis('off')
    plt.show()
    
    return prediction

In [None]:
# Upload a new image for prediction
def predict_on_uploaded_image():
    print("Please upload an X-ray image for prediction:")
    uploaded = files.upload()
    
    for filename in uploaded.keys():
        print(f"\nPredicting on {filename}...")
        prediction = predict_cardiac_failure(model.model, filename)
        print(f"Prediction probability: {prediction:.4f}")
        print(f"Diagnosis: {'Cardiac Failure' if prediction > 0.5 else 'Normal'}")

# Make predictions on uploaded images
predict_on_uploaded_image()

## Conclusion

In this notebook, we've implemented a complete pipeline for cardiac failure detection using chest X-ray images:

1. Set up the environment and data
2. Preprocessed the X-ray images with enhanced augmentation techniques (vertical flips, affine transformations, CLAHE, coarse dropout, and Gaussian noise)
3. Built and trained an enhanced custom CNN model with increased capacity (double convolutional layers, increased filters, larger dense layers)
4. Implemented advanced training techniques (class imbalance handling, cosine decay learning rate scheduling, enhanced callbacks)
5. Evaluated the model using comprehensive metrics (accuracy, precision, recall, F1 score, ROC AUC, PR AUC)
6. Visualized the results with detailed plots
7. Implemented inference on new images

This model can serve as a tool to assist healthcare professionals in the early detection of cardiac failure, potentially improving patient outcomes through timely intervention. The augmentation, class-imbalance handling, and learning-rate scheduling used in this notebook are intended to improve performance and generalization over simpler approaches.