In [None]:
!pip install opendatasets tensorflow matplotlib scikit-learn seaborn plotly
!pip install opencv-python pillow pandas numpy

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import DenseNet121, ResNet50, EfficientNetB0
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from sklearn.metrics import classification_report, confusion_matrix
import cv2
from pathlib import Path
import shutil
import glob
import warnings
warnings.filterwarnings('ignore')

In [None]:
def _binary_disease_config(dataset_url, classes, base_model):
    """Build the config entry for a two-class dataset (sigmoid output, binary cross-entropy).

    Every call creates a new AUC metric object, so each entry owns its own instance.
    """
    return {
        'dataset_url': dataset_url,
        'classes': classes,
        'img_size': (224, 224),
        'class_mode': 'binary',
        'loss': 'binary_crossentropy',
        'metrics': ['accuracy', tf.keras.metrics.AUC(name='auc')],
        'base_model': base_model,
    }


# Per-disease settings: Kaggle dataset URL, class folder names (for binary tasks
# MultiDiseasePredictor reports classes[1] when the sigmoid output exceeds 0.5),
# input size, label mode / loss / metrics, and the ImageNet backbone class.
DISEASE_CONFIG = {
    'pneumonia': _binary_disease_config(
        dataset_url='https://www.kaggle.com/datasets/paultimothymooney/chest-xray-pneumonia',
        classes=['NORMAL', 'PNEUMONIA'],
        base_model=DenseNet121,
    ),
    'skin_cancer': _binary_disease_config(
        dataset_url='https://www.kaggle.com/datasets/fanconic/skin-cancer-malignant-vs-benign',
        classes=['benign', 'malignant'],
        base_model=EfficientNetB0,
    ),
    'brain_tumor': _binary_disease_config(
        dataset_url='https://www.kaggle.com/datasets/ahmedhamada0/brain-tumor-detection',
        classes=['no', 'yes'],
        base_model=ResNet50,
    ),
    # Four-class radiography task: one output per class, categorical cross-entropy.
    'covid19': dict(
        dataset_url='https://www.kaggle.com/datasets/tawsifurrahman/covid19-radiography-database',
        classes=['COVID', 'Lung_Opacity', 'Normal', 'Viral Pneumonia'],
        img_size=(224, 224),
        class_mode='categorical',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
        base_model=DenseNet121,
    ),
}


In [None]:
class MultiDiseaseDatasetManager:
    """Download Kaggle datasets and arrange them as per-split class folders.

    Non-pneumonia datasets are copied into ``<data_dir>/{train,val,test}/<class>/``
    (the layout DataGeneratorFactory reads); the pneumonia dataset keeps its
    upstream ``<data_dir>/chest_xray/{train,val,test}`` layout.
    """

    # File suffixes (compared lower-case) treated as images when splitting.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.gif', '.webp')

    def __init__(self):
        self.setup_kaggle()

    def setup_kaggle(self):
        """Make Kaggle API credentials (``kaggle.json``) available for downloads.

        Looks for ``kaggle.json`` in the current working directory. Inside Google
        Colab the user is asked to upload it when missing. Outside Colab the missing
        file is only reported; NOTE(review): opendatasets is expected to prompt for
        the Kaggle username/key itself — confirm for the installed version.
        """
        kaggle_json = Path('kaggle.json').resolve()
        if not kaggle_json.exists():
            print("Please upload your kaggle.json file.")
            try:
                from google.colab import files
            except ImportError:
                print("google.colab is unavailable: put kaggle.json in the working "
                      "directory or enter Kaggle credentials when prompted.")
                return
            # Presumably stores the upload in the working directory (Colab default) — verify.
            files.upload()

        if not kaggle_json.exists():
            print("kaggle.json not found after upload; Kaggle credentials may be requested on download.")
            return

        # Point the Kaggle client at the directory that really holds the file (the
        # previous hard-coded /content only matched Colab's cwd) and restrict its
        # permissions; os.chmod replaces the IPython-only `!chmod` shell escape.
        os.environ['KAGGLE_CONFIG_DIR'] = str(kaggle_json.parent)
        os.chmod(kaggle_json, 0o600)

    def download_dataset(self, disease_name, dataset_url):
        """Download ``dataset_url`` into ``data/<disease_name>`` and organise it.

        The pneumonia download folder is deleted after organising, so (assuming
        opendatasets only skips downloads whose target folder exists) every later
        call would re-fetch the archive; it is skipped once
        ``data/pneumonia/chest_xray`` exists.
        """
        data_dir = f"data/{disease_name}"
        if disease_name == 'pneumonia' and os.path.isdir(os.path.join(data_dir, "chest_xray")):
            print(f"{disease_name} dataset already organized, skipping download.")
            return

        print(f"Downloading {disease_name} dataset...")
        import opendatasets as od

        od.download(dataset_url, data_dir=data_dir)

        # Organize dataset structure based on disease type
        self.organize_dataset(disease_name, data_dir)

    def organize_dataset(self, disease_name, data_dir):
        """Dispatch to the disease-specific organiser for ``data_dir``."""
        if disease_name == 'pneumonia':
            self._organize_pneumonia_dataset(data_dir)
        elif disease_name == 'skin_cancer':
            self._organize_skin_cancer_dataset(data_dir)
        elif disease_name == 'brain_tumor':
            self._organize_brain_tumor_dataset(data_dir)
        elif disease_name == 'covid19':
            self._organize_covid19_dataset(data_dir)

        print(f"Dataset {disease_name} organized successfully!")

    def _organize_pneumonia_dataset(self, data_dir):
        """Move ``chest-xray-pneumonia/chest_xray`` up to ``<data_dir>/chest_xray``."""
        downloaded = os.path.join(data_dir, "chest-xray-pneumonia")
        target = os.path.join(data_dir, "chest_xray")
        if os.path.exists(downloaded):
            # shutil.move into an existing directory nests the source inside it
            # (chest_xray/chest_xray), so only move when the target is absent.
            if not os.path.exists(target):
                shutil.move(os.path.join(downloaded, "chest_xray"), target)
            shutil.rmtree(downloaded)

    def _organize_skin_cancer_dataset(self, data_dir):
        """Split the skin-cancer images into ``<data_dir>/{train,val,test}``.

        Handles flat ``benign/`` + ``malignant/`` class folders (70/15/15 split)
        and an upstream ``train/`` + ``test/`` layout, where the provided test set
        is kept and the validation set is carved out of ``train/``.
        """
        class_names = ['benign', 'malignant']
        root = self._find_dataset_root(
            [
                f"{data_dir}/skin-cancer-malignant-vs-benign",
                f"{data_dir}/data",
                f"{data_dir}/train",
                f"{data_dir}",
            ],
            class_names + ['train', 'test'],
        )
        if root is None:
            return

        entries = os.listdir(root)
        if all(name in entries for name in class_names):
            self._create_train_val_test_split(root, class_names, output_dir=data_dir)
        elif 'train' in entries and 'test' in entries:
            self._create_train_val_test_split(
                os.path.join(root, 'train'), class_names,
                train_ratio=0.85, val_ratio=0.15, test_ratio=0.0, output_dir=data_dir,
            )
            self._copy_class_folders(os.path.join(root, 'test'), class_names,
                                     os.path.join(data_dir, 'test'))

    def _organize_brain_tumor_dataset(self, data_dir):
        """Split the ``no``/``yes`` brain-tumor folders into ``<data_dir>/{train,val,test}``."""
        root = self._find_dataset_root(
            [f"{data_dir}/brain-tumor-detection", f"{data_dir}/data", f"{data_dir}"],
            ['no', 'yes', 'train', 'test'],
        )
        if root and 'no' in os.listdir(root):
            self._create_train_val_test_split(root, ['no', 'yes'], output_dir=data_dir)

    def _organize_covid19_dataset(self, data_dir):
        """Split the four COVID-19 radiography classes into ``<data_dir>/{train,val,test}``."""
        covid_classes = ['COVID', 'Lung_Opacity', 'Normal', 'Viral Pneumonia']
        root = self._find_dataset_root(
            [
                f"{data_dir}/covid19-radiography-database/COVID-19_Radiography_Dataset",
                f"{data_dir}/covid19-radiography-database",
                f"{data_dir}/COVID-19_Radiography_Dataset",
                f"{data_dir}",
            ],
            covid_classes,
        )
        if root:
            self._create_train_val_test_split(root, covid_classes, output_dir=data_dir)

    @staticmethod
    def _find_dataset_root(candidates, markers):
        """Return the first existing candidate directory with a sub-folder named in ``markers``."""
        wanted = set(markers)
        for path in candidates:
            if not os.path.isdir(path):
                continue
            subdirs = {d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))}
            if subdirs & wanted:
                return path
        return None

    @classmethod
    def _list_images(cls, class_path):
        """Sorted image paths in ``class_path`` (or in its ``images/`` sub-folder).

        Sorting makes the seeded shuffle give the same split on every filesystem,
        and the case-insensitive suffix check cannot list a file twice.
        Hidden files are skipped, as the previous glob patterns did.
        """
        def images_in(directory):
            return sorted(
                os.path.join(directory, name)
                for name in os.listdir(directory)
                if not name.startswith('.')
                and os.path.splitext(name)[1].lower() in cls.IMAGE_EXTENSIONS
                and os.path.isfile(os.path.join(directory, name))
            )

        files = images_in(class_path)
        nested = os.path.join(class_path, 'images')
        if not files and os.path.isdir(nested):
            # NOTE(review): some Kaggle releases (apparently the COVID-19 radiography
            # database) keep each class's images in <class>/images next to <class>/masks.
            files = images_in(nested)
        return files

    def _copy_class_folders(self, src_root, class_names, dest_root):
        """Copy the images of ``src_root/<class>`` into ``dest_root/<class>``; returns the copy count.

        Files already present at the destination are left untouched.
        """
        copied = 0
        for class_name in class_names:
            src = os.path.join(src_root, class_name)
            if not os.path.isdir(src):
                print(f"Warning: Class folder {class_name} not found, skipping...")
                continue
            dest = os.path.join(dest_root, class_name)
            os.makedirs(dest, exist_ok=True)
            for file_path in self._list_images(src):
                target = os.path.join(dest, os.path.basename(file_path))
                if not os.path.exists(target):
                    shutil.copy2(file_path, target)
                    copied += 1
        return copied

    @staticmethod
    def _copy_into(file_list, dest_dir, skip_names):
        """Copy ``file_list`` into ``dest_dir``, skipping basenames in ``skip_names``; returns the copy count."""
        copied = 0
        for file_path in file_list:
            filename = os.path.basename(file_path)
            if filename in skip_names:
                continue
            try:
                shutil.copy2(file_path, os.path.join(dest_dir, filename))
                copied += 1
            except OSError as e:
                print(f"Error copying {file_path}: {str(e)}")
        return copied

    def _create_train_val_test_split(self, data_path, class_names, train_ratio=0.7, val_ratio=0.15,
                                     test_ratio=0.15, output_dir=None):
        """Copy the images of ``data_path/<class>`` into train/val/test class folders.

        Args:
            data_path: directory holding one sub-folder per class.
            class_names: class sub-folders to split; missing ones are skipped with a warning.
            train_ratio, val_ratio: fraction of each class copied to train/ and val/
                (counts are floored).
            test_ratio: when > 0 the remaining files go to test/; when 0 they go to val/.
            output_dir: directory that receives train/, val/ and test/. Defaults to the
                parent of ``data_path`` (the previous behaviour).

        Returns:
            (train_dir, val_dir, test_dir)

        Raises:
            ValueError: if a ratio is negative or the ratios sum to more than 1.
        """
        import random

        ratios = (train_ratio, val_ratio, test_ratio)
        if min(ratios) < 0 or sum(ratios) > 1 + 1e-9:
            raise ValueError(f"Invalid split ratios: {ratios}")

        # Private RNG: reproducible split without resetting the global `random` state.
        rng = random.Random(42)

        base_dir = output_dir if output_dir is not None else os.path.dirname(data_path)
        split_dirs = {name: os.path.join(base_dir, name) for name in ('train', 'val', 'test')}
        for split_dir in split_dirs.values():
            os.makedirs(split_dir, exist_ok=True)

        print(f"Creating train/val/test split with ratios: {train_ratio:.1%}/{val_ratio:.1%}/{test_ratio:.1%}")

        # Splits this call writes to; test/ only participates when it receives files.
        guarded_splits = ['train', 'val'] + (['test'] if test_ratio > 0 else [])

        total_files = 0
        for class_name in class_names:
            class_path = os.path.join(data_path, class_name)
            if not os.path.exists(class_path):
                print(f"Warning: Class folder {class_name} not found, skipping...")
                continue

            for split_dir in split_dirs.values():
                os.makedirs(os.path.join(split_dir, class_name), exist_ok=True)

            image_files = self._list_images(class_path)
            if not image_files:
                print(f"Warning: No image files found in {class_path}")
                continue

            rng.shuffle(image_files)

            total_images = len(image_files)
            train_end = int(total_images * train_ratio)
            val_end = train_end + int(total_images * val_ratio) if test_ratio > 0 else total_images
            assignments = {
                'train': image_files[:train_end],
                'val': image_files[train_end:val_end],
                'test': image_files[val_end:],
            }

            # Never copy a file whose name is already in one of the destination splits:
            # re-runs (or a source folder that is itself an output split) would otherwise
            # place the same image in two splits and leak it across train/val/test.
            already_placed = set()
            for split_name in guarded_splits:
                already_placed.update(os.listdir(os.path.join(split_dirs[split_name], class_name)))

            copied = {
                split_name: self._copy_into(files, os.path.join(split_dirs[split_name], class_name),
                                            already_placed)
                for split_name, files in assignments.items()
            }

            total_files += total_images
            print(f"  {class_name}: {copied['train']} train, {copied['val']} val, {copied['test']} test images")

        print(f"Total images processed: {total_files}")
        return split_dirs['train'], split_dirs['val'], split_dirs['test']


In [None]:
# ==============================
# 5. Data Generators
# ==============================
class DataGeneratorFactory:
    """Create Keras directory iterators (train / val / test) for one disease dataset."""

    @staticmethod
    def create_generators(disease_name, base_dir, config, batch_size=32):
        """Create the train, validation and test generators for ``disease_name``.

        Expected layouts (as produced by MultiDiseaseDatasetManager):
          * pneumonia: ``<base_dir>/chest_xray/{train,val,test}/<class>/``
          * others:    ``<base_dir>/{train,val,test}/<class>/``
        For non-pneumonia datasets without that split, a single class-folder
        directory (``<base_dir>/train`` or ``<base_dir>``) is divided with
        ``validation_split=0.2``; its validation subset then doubles as the test set.

        Args:
            disease_name: disease key (selects the pneumonia layout).
            base_dir: dataset root, e.g. ``data/<disease_name>``.
            config: config entry providing 'img_size', 'class_mode' and 'classes'.
            batch_size: images per batch (32 was the previous hard-coded value).

        Returns:
            (train_gen, val_gen, test_gen). Only train_gen is augmented; test_gen never
            shuffles, so its order matches ``test_gen.classes``.

        Raises:
            FileNotFoundError: if no usable directory layout is found.
        """
        flow = DataGeneratorFactory._flow
        split_root = os.path.join(base_dir, "chest_xray") if disease_name == 'pneumonia' else base_dir
        train_dir, val_dir, test_dir = (os.path.join(split_root, name) for name in ('train', 'val', 'test'))

        if all(os.path.isdir(d) for d in (train_dir, val_dir, test_dir)):
            plain = ImageDataGenerator(rescale=1./255)
            # Training data is augmented for every dataset (previously only pneumonia).
            train_gen = flow(DataGeneratorFactory._augmenting_datagen(), train_dir, config, batch_size)
            val_gen = flow(plain, val_dir, config, batch_size)
            test_gen = flow(plain, test_dir, config, batch_size, shuffle=False)
            return train_gen, val_gen, test_gen

        if disease_name == 'pneumonia':
            raise FileNotFoundError(f"Pneumonia dataset structure not found in {base_dir}")

        data_directory = DataGeneratorFactory._find_class_directory(
            [os.path.join(base_dir, "train"), base_dir], config['classes'])
        if data_directory is None:
            raise FileNotFoundError(f"Could not find organized data directory for {disease_name}")

        print(f"Using fallback validation_split for {disease_name}")
        holdout = ImageDataGenerator(rescale=1./255, validation_split=0.2)
        train_gen = flow(DataGeneratorFactory._augmenting_datagen(validation_split=0.2),
                         data_directory, config, batch_size, subset='training')
        val_gen = flow(holdout, data_directory, config, batch_size, subset='validation')
        # Same validation subset, but unshuffled: evaluation pairs predictions with
        # labels in iteration order, which a shuffled iterator would scramble.
        test_gen = flow(holdout, data_directory, config, batch_size, subset='validation', shuffle=False)
        return train_gen, val_gen, test_gen

    @staticmethod
    def _augmenting_datagen(validation_split=0.0):
        """ImageDataGenerator with rescaling plus random geometric augmentation."""
        return ImageDataGenerator(
            rescale=1./255,
            rotation_range=20,
            width_shift_range=0.2,
            height_shift_range=0.2,
            shear_range=0.2,
            zoom_range=0.2,
            horizontal_flip=True,
            fill_mode='nearest',
            validation_split=validation_split,
        )

    @staticmethod
    def _flow(datagen, directory, config, batch_size, **kwargs):
        """``flow_from_directory`` using the config's image size, label mode and class order."""
        return datagen.flow_from_directory(
            directory,
            target_size=config['img_size'],
            batch_size=batch_size,
            class_mode=config['class_mode'],
            # Pin label indices to the config order and ignore stray sub-folders.
            classes=list(config['classes']),
            **kwargs,
        )

    @staticmethod
    def _find_class_directory(candidates, classes):
        """First candidate directory containing a sub-folder for every class, else None."""
        required = set(classes)
        for dir_path in candidates:
            try:
                subdirs = {d for d in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path, d))}
            except OSError:
                continue
            if required <= subdirs:
                return dir_path
        return None


In [None]:
class MultiDiseaseModelBuilder:
    """Build compiled transfer-learning classifiers from disease config entries."""

    @staticmethod
    def build_model(disease_name, config):
        """Build and compile a classifier on a frozen ImageNet backbone.

        The model takes RGB images scaled to [0, 1] (what the data generators
        produce with ``rescale=1./255``). It converts them internally to the input
        convention its backbone was pre-trained with.

        Args:
            disease_name: disease key; kept for interface compatibility, not used here.
            config: entry providing 'img_size', 'base_model', 'class_mode',
                'classes', 'loss' and 'metrics'.

        Returns:
            A compiled ``tf.keras`` Sequential model with one sigmoid unit (binary)
            or a softmax over ``config['classes']`` (otherwise).
        """
        input_shape = tuple(config['img_size']) + (3,)
        base_model_class = config['base_model']

        # Load pre-trained base model; feature extraction only, weights frozen.
        base_model = base_model_class(
            weights='imagenet',
            include_top=False,
            input_shape=input_shape
        )
        base_model.trainable = False

        if config['class_mode'] == 'binary':
            output_neurons, activation = 1, 'sigmoid'
        else:
            output_neurons, activation = len(config['classes']), 'softmax'

        model = models.Sequential(
            [layers.Input(shape=input_shape)]
            + MultiDiseaseModelBuilder._backbone_preprocessing(base_model_class, input_shape)
            + [
                base_model,
                layers.GlobalAveragePooling2D(),
                layers.Dropout(0.5),
                layers.Dense(512, activation='relu'),
                layers.BatchNormalization(),
                layers.Dropout(0.3),
                layers.Dense(256, activation='relu'),
                layers.BatchNormalization(),
                layers.Dropout(0.2),
                layers.Dense(output_neurons, activation=activation),
            ]
        )

        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
            loss=config['loss'],
            metrics=MultiDiseaseModelBuilder._fresh_metrics(config['metrics']),
        )
        return model

    @staticmethod
    def _backbone_preprocessing(base_model_class, input_shape):
        """Serialisable layers mapping [0, 1] RGB input to the backbone's expected input.

        Built-in layers (no Lambda) keep the model loadable from the .h5 checkpoint.
        """
        if base_model_class is EfficientNetB0:
            # keras.applications EfficientNet rescales 0-255 input internally.
            return [layers.Rescaling(255.0)]
        if base_model_class is DenseNet121:
            # 'torch' mode: per-channel ImageNet mean/std on [0, 1] RGB input.
            return [layers.Normalization(mean=[0.485, 0.456, 0.406],
                                         variance=[0.229 ** 2, 0.224 ** 2, 0.225 ** 2])]
        if base_model_class is ResNet50:
            # 'caffe' mode: RGB -> BGR, back to 0-255, ImageNet mean subtracted,
            # expressed as a frozen 1x1 convolution.
            kernel = np.zeros((1, 1, 3, 3), dtype='float32')
            for rgb_channel, bgr_channel in ((0, 2), (1, 1), (2, 0)):
                kernel[0, 0, rgb_channel, bgr_channel] = 255.0
            bias = -np.array([103.939, 116.779, 123.68], dtype='float32')
            to_caffe = layers.Conv2D(3, 1, trainable=False, name='resnet50_caffe_preprocess')
            to_caffe.build((None,) + tuple(input_shape))
            to_caffe.set_weights([kernel, bias])
            return [to_caffe]
        # NOTE(review): unknown backbone — assumes it accepts [0, 1] input; verify.
        return []

    @staticmethod
    def _fresh_metrics(metrics):
        """Copy metric objects so every compiled model gets its own instance.

        The config holds metric instances created once at import time; reusing one
        instance across several compiled models (e.g. rebuilding in the same kernel)
        can share its state between them.
        """
        return [m if isinstance(m, str) else m.__class__.from_config(m.get_config()) for m in metrics]

# ==============================
# 7. Training Pipeline
# ==============================
class DiseaseTrainer:
    """Train, checkpoint and evaluate one transfer-learning model per disease."""

    def __init__(self):
        self.models = {}     # disease name -> trained or loaded keras model
        self.histories = {}  # disease name -> keras History (only for models trained here)

    def train_disease_model(self, disease_name, train_gen, val_gen, config, epochs=20):
        """Build, train and checkpoint the model for one disease.

        The best weights by validation loss are saved to ``best_<disease_name>_model.h5``.

        Args:
            disease_name: disease key, also used in the checkpoint file name.
            train_gen, val_gen: Keras directory iterators.
            config: the disease's config entry.
            epochs: maximum number of epochs (EarlyStopping may stop earlier).

        Returns:
            (model, history); both are also stored on the trainer.
        """
        print(f"\n{'='*50}")
        print(f"Training {disease_name.upper()} Detection Model")
        print(f"{'='*50}")

        model = MultiDiseaseModelBuilder.build_model(disease_name, config)
        print(f"Model architecture for {disease_name}:")
        model.summary()

        checkpoint_path = f"best_{disease_name}_model.h5"
        callbacks = [
            ModelCheckpoint(checkpoint_path, monitor='val_loss', mode='min',
                            save_best_only=True, verbose=1),
            EarlyStopping(monitor='val_loss', patience=7, mode='min',
                          restore_best_weights=True),
            ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3,
                              min_lr=1e-7, verbose=1),
        ]

        history = model.fit(
            train_gen,
            # len() of a directory iterator is ceil(samples / batch_size): the last
            # partial batch is used instead of being dropped every epoch.
            steps_per_epoch=max(1, len(train_gen)),
            validation_data=val_gen,
            validation_steps=max(1, len(val_gen)),
            epochs=epochs,
            callbacks=callbacks,
            verbose=1
        )

        # Reload the best checkpoint when one exists; if none was written (e.g. the
        # validation loss was NaN every epoch) keep the in-memory model.
        if os.path.exists(checkpoint_path):
            model = tf.keras.models.load_model(checkpoint_path)

        self.models[disease_name] = model
        self.histories[disease_name] = history
        return model, history

    def train_all_diseases(self, diseases_to_train, epochs=20):
        """Load or train a model for each disease, continuing past per-disease failures.

        An existing ``best_<disease>_model.h5`` is loaded instead of retraining.

        Returns:
            (successful_trainings, failed_trainings): names, and (name, error) pairs.
        """
        # Created lazily: construction may prompt for Kaggle credentials, which is
        # unnecessary when every model is already on disk.
        dataset_manager = None

        successful_trainings = []
        failed_trainings = []

        for disease_name in diseases_to_train:
            print(f"\n{'='*60}")
            print(f"Processing {disease_name.upper()}")
            print(f"{'='*60}")

            if disease_name not in DISEASE_CONFIG:
                print(f"Warning: {disease_name} not found in config. Skipping...")
                failed_trainings.append((disease_name, "Not in config"))
                continue

            model_path = f"best_{disease_name}_model.h5"
            if os.path.exists(model_path):
                print(f"Model for {disease_name} already exists! Loading existing model...")
                try:
                    self.models[disease_name] = tf.keras.models.load_model(model_path)
                    successful_trainings.append(disease_name)
                    print(f"✅ Successfully loaded existing {disease_name} model!")
                    continue
                except Exception as e:
                    print(f"Failed to load existing model: {e}")
                    print("Will retrain the model...")

            try:
                config = DISEASE_CONFIG[disease_name]
                if dataset_manager is None:
                    dataset_manager = MultiDiseaseDatasetManager()

                print(f"📥 Downloading {disease_name} dataset...")
                dataset_manager.download_dataset(disease_name, config['dataset_url'])

                print(f"🔄 Creating data generators for {disease_name}...")
                base_dir = f"data/{disease_name}"
                train_gen, val_gen, test_gen = DataGeneratorFactory.create_generators(
                    disease_name, base_dir, config
                )

                print(f"🚀 Training {disease_name} model...")
                model, history = self.train_disease_model(
                    disease_name, train_gen, val_gen, config, epochs
                )

                print(f"📊 Evaluating {disease_name} model...")
                self.evaluate_model(disease_name, model, test_gen, config)

                successful_trainings.append(disease_name)
                print(f"✅ Successfully completed training for {disease_name}!")

            except Exception as e:
                print(f"❌ Error training {disease_name}: {str(e)}")
                failed_trainings.append((disease_name, str(e)))
                print(f"Skipping {disease_name} and continuing with next disease...")

        self._print_summary(successful_trainings, failed_trainings)
        return successful_trainings, failed_trainings

    @staticmethod
    def _print_summary(successful_trainings, failed_trainings):
        """Print which diseases were trained/loaded and which failed (with reasons)."""
        print(f"\n{'='*60}")
        print("TRAINING SUMMARY")
        print(f"{'='*60}")
        print(f"✅ Successfully trained: {len(successful_trainings)} models")
        for disease in successful_trainings:
            print(f"   - {disease}")

        if failed_trainings:
            print(f"\n❌ Failed to train: {len(failed_trainings)} models")
            for disease, error in failed_trainings:
                print(f"   - {disease}: {error}")

        if successful_trainings:
            print(f"\n🎉 Training completed with {len(successful_trainings)} successful models!")
        else:
            print(f"\n⚠️  No models were successfully trained.")

    def evaluate_model(self, disease_name, model, test_gen, config):
        """Print test metrics, a classification report and plot a confusion matrix.

        Errors are reported rather than raised; a plain ``model.evaluate`` is then
        attempted as a fallback.
        """
        print(f"\n{'='*30}")
        print(f"Evaluating {disease_name.upper()} Model")
        print(f"{'='*30}")

        try:
            print("Generating predictions...")
            y_true, y_pred = self._collect_predictions(model, test_gen, config)
            print(f"Evaluation samples: {len(y_true)}")

            self._print_test_metrics(model, test_gen)

            class_names = list(test_gen.class_indices.keys())
            label_ids = list(range(len(class_names)))
            print("\nClassification Report:")
            # Explicit labels keep the report valid when a class never occurs.
            print(classification_report(y_true, y_pred, labels=label_ids,
                                        target_names=class_names, zero_division=0))

            cm = confusion_matrix(y_true, y_pred, labels=label_ids)
            fig, ax = plt.subplots(figsize=(8, 6))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                        xticklabels=class_names, yticklabels=class_names, ax=ax)
            ax.set_title(f'{disease_name.title()} - Confusion Matrix')
            ax.set_ylabel('True Label')
            ax.set_xlabel('Predicted Label')
            plt.show()

        except Exception as e:
            print(f"Error during evaluation: {e}")
            print("Attempting basic evaluation...")
            try:
                self._print_test_metrics(model, test_gen)
            except Exception as e2:
                print(f"Basic evaluation also failed: {e2}")

    @staticmethod
    def _collect_predictions(model, test_gen, config):
        """Predict batch by batch, pairing each prediction with the label of the same batch.

        Reading labels from the same batches as the images keeps them aligned even
        for a shuffling iterator. ``test_gen.classes`` is in directory order, which
        only matches prediction order when shuffling is off.

        Returns:
            (y_true, y_pred) integer class-index arrays.
        """
        y_true, y_pred = [], []
        for batch_idx in range(len(test_gen)):
            images, labels = test_gen[batch_idx]
            probs = np.asarray(model.predict_on_batch(images))
            if config['class_mode'] == 'binary':
                y_pred.append((probs.reshape(-1) > 0.5).astype(int))
                y_true.append(np.asarray(labels).reshape(-1).astype(int))
            else:
                y_pred.append(np.argmax(probs, axis=1))
                y_true.append(np.argmax(labels, axis=1))
        if not y_true:
            raise ValueError("Test generator yielded no batches")
        return np.concatenate(y_true), np.concatenate(y_pred)

    @staticmethod
    def _print_test_metrics(model, test_gen):
        """Run ``model.evaluate`` on the test generator and print every compiled metric."""
        results = model.evaluate(test_gen, verbose=0, return_dict=True)
        for name, value in results.items():
            print(f"Test {name.replace('_', ' ').title()}: {value:.4f}")


In [None]:
class VisualizationTools:
    """Plotting helpers for training curves and sample images."""

    @staticmethod
    def plot_training_history(histories):
        """Plot accuracy (top row) and loss (bottom row), one column per disease.

        Args:
            histories: mapping disease name -> keras History object.
        """
        if not histories:
            print("No training histories to plot.")
            return

        # squeeze=False keeps `axes` 2-D, so axes[row, col] also works for one disease.
        fig, axes = plt.subplots(2, len(histories), figsize=(20, 10), squeeze=False)

        panels = (('accuracy', 'Accuracy'), ('loss', 'Loss'))
        for col, (disease_name, history) in enumerate(histories.items()):
            for row, (metric, label) in enumerate(panels):
                ax = axes[row, col]
                ax.plot(history.history.get(metric, []), label=f'Training {label}')
                ax.plot(history.history.get(f'val_{metric}', []), label=f'Validation {label}')
                ax.set_title(f'{disease_name.title()} - {label}')
                ax.set_xlabel('Epoch')
                ax.set_ylabel(label)
                ax.legend()

        fig.tight_layout()
        plt.show()

    @staticmethod
    def plot_sample_images(generators_dict):
        """Show up to three images from the next training batch of each dataset.

        Args:
            generators_dict: mapping disease name -> (train_gen, val_gen, test_gen).
                ``next(train_gen)`` is called once per disease, advancing that iterator.
        """
        if not generators_dict:
            print("No generators to sample from.")
            return

        n_rows = len(generators_dict)
        fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)

        for row, (disease_name, (train_gen, _, _)) in enumerate(generators_dict.items()):
            images, labels = next(train_gen)
            class_names = list(train_gen.class_indices.keys())

            for col in range(3):
                ax = axes[row, col]
                ax.axis('off')
                if col >= len(images):
                    continue  # a final batch may hold fewer than three images
                ax.imshow(images[col])
                label = labels[col]
                # One-hot vector for categorical labels, scalar for binary ones.
                label_idx = int(np.argmax(label)) if np.ndim(label) > 0 else int(label)
                ax.set_title(f'{disease_name.title()}: {class_names[label_idx]}')

        fig.tight_layout()
        plt.show()


In [None]:
class GradCAMVisualizer:
    """Grad-CAM heatmaps for image classifiers."""

    @staticmethod
    def make_gradcam_heatmap(img_array, model, last_conv_layer_name):
        """Compute a Grad-CAM heatmap for the model's top-scoring output.

        Args:
            img_array: input batch in the format the model was trained on
                (display_gradcam passes one RGB image of shape (1, H, W, 3) in [0, 1]).
            model: a flat functional model, or a Sequential whose layers include a
                nested pre-trained backbone model.
            last_conv_layer_name: feature layer to explain (looked up inside the
                nested backbone when there is one), or None to use the last layer
                with a 4-D output.

        Returns:
            2-D numpy array with values in [0, 1] at the feature-map resolution.

        Raises:
            ValueError: if no suitable layer exists or no gradient reaches it.
        """
        backbone_idx = next(
            (i for i, layer in enumerate(model.layers) if isinstance(layer, tf.keras.Model)),
            None,
        )
        search_model = model if backbone_idx is None else model.layers[backbone_idx]
        # Only search when the caller did not name a layer (the argument used to be ignored).
        if last_conv_layer_name is None:
            last_conv_layer_name = GradCAMVisualizer._last_spatial_layer_name(search_model)

        if backbone_idx is None:
            feature_model = tf.keras.models.Model(
                model.inputs, [model.get_layer(last_conv_layer_name).output, model.output])
        else:
            # A nested backbone has its own graph: its inner tensors cannot be wired to
            # the outer model's inputs, so the outer layers are run around it by hand.
            backbone = model.layers[backbone_idx]
            feature_model = tf.keras.models.Model(
                backbone.inputs, [backbone.get_layer(last_conv_layer_name).output, backbone.output])

        inputs = tf.convert_to_tensor(img_array, dtype=tf.float32)
        with tf.GradientTape() as tape:
            # Frozen backbone weights are not watched automatically; watching the
            # input makes the tape record the whole forward pass.
            tape.watch(inputs)
            if backbone_idx is None:
                conv_outputs, predictions = feature_model(inputs, training=False)
            else:
                x = inputs
                for layer in model.layers[:backbone_idx]:
                    x = layer(x, training=False)
                conv_outputs, x = feature_model(x, training=False)
                for layer in model.layers[backbone_idx + 1:]:
                    x = layer(x, training=False)
                predictions = x
            pred_index = tf.argmax(predictions[0])
            class_channel = predictions[:, pred_index]

        grads = tape.gradient(class_channel, conv_outputs)
        if grads is None:
            raise ValueError(f"No gradient reaches layer '{last_conv_layer_name}'.")
        pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

        heatmap = tf.squeeze(conv_outputs[0] @ pooled_grads[..., tf.newaxis])
        heatmap = tf.maximum(heatmap, 0)
        max_value = float(tf.math.reduce_max(heatmap))
        # An all-zero map would otherwise divide by zero and become NaN.
        if max_value > 0:
            heatmap = heatmap / max_value
        return heatmap.numpy()

    @staticmethod
    def _last_spatial_layer_name(search_model):
        """Name of the last layer in ``search_model`` whose output is 4-D."""
        for layer in reversed(search_model.layers):
            try:
                rank = len(layer.output.shape)
            except (AttributeError, ValueError, TypeError):
                continue
            if rank == 4:
                return layer.name
        raise ValueError("Could not find a layer with a 4-D output for Grad-CAM.")

    @staticmethod
    def display_gradcam(img_path, model, img_size, last_conv_layer_name=None):
        """Show the image, its Grad-CAM heatmap and the blended overlay side by side.

        Args:
            img_path: path of the image file.
            model: classifier passed to make_gradcam_heatmap.
            img_size: (width, height) the model expects, as used by cv2.resize.
            last_conv_layer_name: optional layer to explain; None picks one automatically.

        Raises:
            FileNotFoundError: if ``img_path`` does not exist.
            ValueError: if the file cannot be decoded as an image.
        """
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")
        img = cv2.imread(img_path)
        if img is None:
            raise ValueError(f"Could not decode image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_resized = cv2.resize(img, img_size)
        img_array = np.expand_dims(img_resized / 255.0, axis=0).astype(np.float32)

        heatmap = GradCAMVisualizer.make_gradcam_heatmap(img_array, model, last_conv_layer_name)

        heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
        heatmap = np.uint8(255 * np.clip(heatmap, 0, 1))
        # applyColorMap returns BGR; convert so it blends correctly with the RGB image.
        heatmap = cv2.cvtColor(cv2.applyColorMap(heatmap, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)

        superimposed_img = cv2.addWeighted(img, 0.6, heatmap, 0.4, 0)

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        panels = ((img, "Original Image"), (heatmap, "Grad-CAM Heatmap"), (superimposed_img, "Overlay"))
        for ax, (image, title) in zip(axes, panels):
            ax.imshow(image)
            ax.set_title(title)
            ax.axis('off')

        fig.tight_layout()
        plt.show()


In [None]:
class MultiDiseasePredictor:
    """Load saved per-disease models and classify single images with them."""

    def __init__(self):
        self.models = {}   # disease name -> loaded keras model
        self.configs = {}  # disease name -> its config entry

    def load_models(self, disease_names):
        """Load ``best_<disease>_model.h5`` for each name; failures are reported and skipped.

        A model is registered only together with its config, so predict_image never
        finds one without the other.
        """
        for disease_name in disease_names:
            model_path = f"best_{disease_name}_model.h5"
            try:
                config = DISEASE_CONFIG[disease_name]
                model = tf.keras.models.load_model(model_path)
            except Exception as e:
                print(f"Could not load {disease_name} model. Make sure it's trained first. ({e})")
                continue
            self.models[disease_name] = model
            self.configs[disease_name] = config
            print(f"Loaded {disease_name} model successfully!")

    @staticmethod
    def _load_image(img_path, img_size):
        """Read an image as a (1, H, W, 3) float32 RGB batch scaled to [0, 1].

        Raises:
            FileNotFoundError: if OpenCV cannot read ``img_path``.
        """
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, img_size)
        return np.expand_dims(img / 255.0, axis=0).astype(np.float32)

    def predict_image(self, img_path, disease_name):
        """Classify one image with the model of ``disease_name``.

        Returns:
            dict with 'disease', 'predicted_class', 'confidence' (probability of the
            predicted class) and 'all_probabilities', or None if the model is not loaded.
        """
        if disease_name not in self.models:
            print(f"Model for {disease_name} not loaded!")
            return None

        model = self.models[disease_name]
        config = self.configs[disease_name]

        img_array = self._load_image(img_path, config['img_size'])
        prediction = model.predict(img_array, verbose=0)

        if config['class_mode'] == 'binary':
            # The single sigmoid unit is the probability of classes[1].
            positive_prob = float(prediction[0][0])
            is_positive = positive_prob > 0.5
            predicted_class = config['classes'][1] if is_positive else config['classes'][0]
            confidence = positive_prob if is_positive else 1 - positive_prob
        else:
            class_idx = int(np.argmax(prediction[0]))
            predicted_class = config['classes'][class_idx]
            confidence = float(prediction[0][class_idx])

        return {
            'disease': disease_name,
            'predicted_class': predicted_class,
            'confidence': confidence,
            'all_probabilities': prediction[0].tolist()
        }

    def predict_multiple_diseases(self, img_path):
        """Run predict_image with every loaded model; returns {disease: result}."""
        results = {}
        for disease_name in self.models:
            result = self.predict_image(img_path, disease_name)
            if result:
                results[disease_name] = result
        return results


In [None]:
def main(diseases_to_train=None, epochs=15):
    """Train (or load) the selected disease models and return ready-to-use objects.

    Args:
        diseases_to_train: disease keys to process; defaults to
            ['pneumonia', 'skin_cancer', 'brain_tumor'].
        epochs: maximum training epochs per model.

    Returns:
        (trainer, predictor) on success; (None, None) if no model was trained or
        loaded, or if the pipeline raised.
    """
    print("Multi-Disease Prediction System")
    print("="*50)

    if diseases_to_train is None:
        diseases_to_train = ['pneumonia', 'skin_cancer', 'brain_tumor']

    trainer = DiseaseTrainer()

    try:
        successful_trainings, _failed_trainings = trainer.train_all_diseases(diseases_to_train, epochs=epochs)

        if not successful_trainings:
            print("\n❌ No models were successfully trained. Please check the errors above.")
            return None, None

        # Histories exist only for models trained in this run (not for loaded ones).
        if trainer.histories:
            VisualizationTools.plot_training_history(trainer.histories)

        predictor = MultiDiseasePredictor()
        predictor.load_models(successful_trainings)

    except Exception as e:
        import traceback

        print(f"Critical error during training pipeline: {str(e)}")
        traceback.print_exc()
        return None, None

    print(f"\n🎉 System ready! Successfully trained {len(successful_trainings)} models.")

    print("\n" + "="*60)
    print("USAGE INSTRUCTIONS")
    print("="*60)
    print("To make predictions on new images:")
    print("1. Single disease prediction:")
    print("   result = predictor.predict_image('image_path.jpg', 'pneumonia')")
    print("2. Multi-disease prediction:")
    print("   results = predictor.predict_multiple_diseases('image_path.jpg')")

    return trainer, predictor

In [None]:
"""
# Example usage:

# 1. Train models for specific diseases
trainer = DiseaseTrainer()
trainer.train_all_diseases(['pneumonia', 'skin_cancer'], epochs=10)

# 2. Load trained models and make predictions
predictor = MultiDiseasePredictor()
predictor.load_models(['pneumonia', 'skin_cancer'])

# 3. Predict single disease
result = predictor.predict_image('test_image.jpg', 'pneumonia')
print(f"Prediction: {result['predicted_class']} (Confidence: {result['confidence']:.2f})")

# 4. Predict multiple diseases
results = predictor.predict_multiple_diseases('test_image.jpg')
for disease, result in results.items():
    print(f"{disease}: {result['predicted_class']} ({result['confidence']:.2f})")

# 5. Visualize Grad-CAM
gradcam = GradCAMVisualizer()
gradcam.display_gradcam('test_image.jpg', predictor.models['pneumonia'], (224, 224))
"""

if __name__ == "__main__":
    trainer, predictor = main()