<a href="https://colab.research.google.com/github/ALLEE16481/COVID-19-X-ray-Forgery-Detection-Model/blob/main/COVID_19_X_ray_Forgery_Detection_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# Basic scientific stack (already installed usually, but just in case)
!pip install numpy pandas matplotlib seaborn scikit-learn

# TensorFlow (pre-installed in Colab, but if you want a specific version)
!pip install tensorflow

# OpenCV
!pip install opencv-python-headless

# Requests (usually installed, but just in case)
!pip install requests

# pathlib is part of Python standard library — no install needed.




In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.preprocessing import LabelEncoder
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D, Flatten, Dropout, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical
import cv2
import os
import random
from pathlib import Path
import zipfile
import requests
from io import BytesIO

In [2]:
# Fix the seed of every random number generator used by this notebook
# (Python's `random`, NumPy's global generator, TensorFlow) so reruns are repeatable.
random.seed(42)
np.random.seed(42)
tf.random.set_seed(42)

print("📦 All libraries imported successfully!")

📦 All libraries imported successfully!


In [19]:
# Fix the seed of every random number generator used by this notebook
# (Python's `random`, NumPy's global generator, TensorFlow) so reruns are repeatable.
random.seed(42)
np.random.seed(42)
tf.random.set_seed(42)

# Detect Google Colab by probing for its runtime package; IN_COLAB is read
# later by the dataset-download helper.
try:
    import google.colab
except ImportError:
    IN_COLAB = False
else:
    IN_COLAB = True

print("📦 All libraries imported successfully!")

📦 All libraries imported successfully!


In [52]:
class ColabForgeryDetector:
    """
    Pipeline state holder; this cell defines construction and synthetic data generation.

    Note: re-declaring a class with the same name in another notebook cell replaces
    this definition entirely, so only the methods written in the most recently
    executed declaration are available.
    """

    def __init__(self, img_size=(224, 224), batch_size=16):
        """
        Store configuration and create empty placeholders for pipeline state.

        Args:
            img_size: Spatial size of every image; used as the two leading
                dimensions (rows, columns) of each image array.
            batch_size: Batch size stored for later training/prediction steps.
        """
        self.img_size = img_size
        self.batch_size = batch_size
        self.label_encoder = LabelEncoder()
        # Full dataset: images, string labels, integer labels, one-hot labels.
        self.X = None
        self.y_raw = None
        self.y_encoded = None
        self.y_categorical = None
        # Train / validation / test partitions.
        self.X_train = None
        self.X_val = None
        self.X_test = None
        self.y_train = None
        self.y_val = None
        self.y_test = None
        # Keras model and the History object of the last training run.
        self.model = None
        self.history = None

    def create_synthetic_dataset(self, samples_per_class=200):
        """
        Create synthetic dataset for demonstration (replace with real data loading).

        Every image starts as Gaussian noise (mean 0.3, std 0.1). 'COVID-19'
        images get five filled circles, 'Viral_Pneumonia' images get three filled
        40x40 squares, and 'Normal' images are left as plain noise. Pixel values
        are clipped to [0, 1].

        The result is stored on the instance (`X`, `y_raw`, `y_encoded`,
        `y_categorical`) and the label encoder is fitted on the class names.

        Args:
            samples_per_class: Number of images generated for each of the three classes.

        Returns:
            Tuple `(X, y_categorical)`: a float32 image array of shape
            `(3 * samples_per_class, rows, cols, 3)` and the one-hot labels.

        Raises:
            ValueError: If `img_size` is too small to place the patterns with
                their margins (both dimensions must exceed 100 pixels).
        """
        class_names = ['COVID-19', 'Viral_Pneumonia', 'Normal']
        circle_margin = 50  # circle centres stay this far from every border
        rect_margin = 30    # rectangle origins stay this far from every border

        # Image arrays are (rows, cols, channels); OpenCV points are (x=col, y=row).
        height, width = self.img_size
        min_side = 2 * max(circle_margin, rect_margin)
        if min(height, width) <= min_side:
            raise ValueError(
                f"img_size must be larger than {min_side} pixels in both dimensions "
                f"to place the synthetic patterns, got {self.img_size}"
            )

        print("🎲 Creating synthetic dataset for demonstration...")

        images = []
        labels = []

        for class_name in class_names:
            print(f"   Generating {samples_per_class} samples for {class_name}")

            for _ in range(samples_per_class):
                # Noise background shared by all classes
                base_img = np.random.normal(0.3, 0.1, (height, width, 3))

                # Add class-specific patterns
                if class_name == 'COVID-19':
                    # Bright filled circles
                    for _ in range(5):
                        # x is bounded by the width and y by the height so that
                        # non-square images are handled correctly; plain ints are
                        # passed because OpenCV expects Python integers.
                        x = int(np.random.randint(circle_margin, width - circle_margin))
                        y = int(np.random.randint(circle_margin, height - circle_margin))
                        radius = int(np.random.randint(10, 30))
                        cv2.circle(base_img, (x, y), radius, (0.6, 0.6, 0.6), -1)

                elif class_name == 'Viral_Pneumonia':
                    # Bright filled squares
                    for _ in range(3):
                        x = int(np.random.randint(rect_margin, width - rect_margin))
                        y = int(np.random.randint(rect_margin, height - rect_margin))
                        cv2.rectangle(base_img, (x, y),
                                      (x + 40, y + 40), (0.7, 0.7, 0.7), -1)

                # Clip to [0, 1] and store as float32 right away so the list never
                # holds the whole dataset in float64.
                images.append(np.clip(base_img, 0, 1).astype(np.float32))
                labels.append(class_name)

        self.X = np.array(images, dtype=np.float32)
        self.y_raw = np.array(labels)

        # Encode labels: strings -> integer ids -> one-hot vectors
        self.y_encoded = self.label_encoder.fit_transform(self.y_raw)
        self.y_categorical = to_categorical(self.y_encoded, num_classes=len(class_names))

        print(f"✅ Dataset created: {len(self.X)} images")
        print(f"   Classes: {list(self.label_encoder.classes_)}")

        # The data is also kept on the instance (self.X / self.y_categorical)
        return self.X, self.y_categorical

In [53]:
class ColabForgeryDetector:
    """
    Pipeline state holder; this cell defines construction and the Kaggle download helper.

    Note: re-declaring a class with the same name in another notebook cell replaces
    this definition entirely, so only the methods written in the most recently
    executed declaration are available.
    """

    def __init__(self, img_size=(224, 224), batch_size=16):
        """
        Store configuration and create empty placeholders for pipeline state.

        Args:
            img_size: Spatial size of every image (two leading array dimensions).
            batch_size: Batch size stored for later training/prediction steps.
        """
        self.img_size = img_size
        self.batch_size = batch_size
        self.label_encoder = LabelEncoder()
        # Full dataset: images, string labels, integer labels, one-hot labels.
        self.X = None
        self.y_raw = None
        self.y_encoded = None
        self.y_categorical = None
        # Train / validation / test partitions.
        self.X_train = None
        self.X_val = None
        self.X_test = None
        self.y_train = None
        self.y_val = None
        self.y_test = None
        # Keras model and the History object of the last training run.
        self.model = None
        self.history = None

    def download_kaggle_dataset(self):
        """
        Download the actual Kaggle dataset (requires Kaggle API setup).

        Only runs when the module-level flag `IN_COLAB` is true. Mounts Google
        Drive, installs the `kaggle` CLI, points `KAGGLE_CONFIG_DIR` at
        `/content/drive/MyDrive/kaggle`, downloads the dataset archive into the
        current working directory and extracts it to `/content/dataset`.

        Returns:
            '/content/dataset' on success; None when not running in Colab or
            when any step fails. The method itself does not generate synthetic
            data on failure — the caller has to fall back to it.
        """
        if not IN_COLAB:
            print("⚠️  This function is designed for Google Colab")
            return None

        try:
            # Mount Google Drive to access Kaggle credentials
            from google.colab import drive
            drive.mount('/content/drive')

            # Install Kaggle; a non-zero exit status is turned into an error so the
            # failure is reported here rather than as a missing zip file later on.
            if os.system('pip install kaggle') != 0:
                raise RuntimeError("'pip install kaggle' exited with a non-zero status")

            # Setup Kaggle credentials
            os.environ['KAGGLE_CONFIG_DIR'] = '/content/drive/MyDrive/kaggle'

            # Download dataset
            if os.system('kaggle datasets download -d nourmahmoud/covid19-digital-xrays-forgery-dataset') != 0:
                raise RuntimeError("'kaggle datasets download' exited with a non-zero status")

            # Extract dataset
            with zipfile.ZipFile('covid19-digital-xrays-forgery-dataset.zip', 'r') as zip_ref:
                zip_ref.extractall('/content/dataset')

            print("✅ Kaggle dataset downloaded and extracted!")
            return '/content/dataset'

        except Exception as e:
            # Deliberate best-effort: report the problem and signal failure with None.
            print(f"❌ Error downloading dataset: {e}")
            print("💡 Using synthetic data instead...")
            return None

In [54]:
class ColabForgeryDetector:
    """
    Pipeline state holder; this cell defines construction and sample visualisation.

    Note: re-declaring a class with the same name in another notebook cell replaces
    this definition entirely, so only the methods written in the most recently
    executed declaration are available.
    """

    def __init__(self, img_size=(224, 224), batch_size=16):
        """
        Store configuration and create empty placeholders for pipeline state.

        Args:
            img_size: Spatial size of every image (two leading array dimensions).
            batch_size: Batch size stored for later training/prediction steps.
        """
        self.img_size = img_size
        self.batch_size = batch_size
        self.label_encoder = LabelEncoder()
        # Full dataset: images, string labels, integer labels, one-hot labels.
        self.X = None
        self.y_raw = None
        self.y_encoded = None
        self.y_categorical = None
        # Train / validation / test partitions.
        self.X_train = None
        self.X_val = None
        self.X_test = None
        self.y_train = None
        self.y_val = None
        self.y_test = None
        # Keras model and the History object of the last training run.
        self.model = None
        self.history = None

    def visualize_samples(self, num_samples=8):
        """
        Visualize sample images.

        Draws up to `num_samples` randomly chosen images from `self.X` in a grid
        with four columns, each titled with its label from `self.y_raw`. The grid
        has two rows for up to eight images and grows by one row per four
        additional images.

        Args:
            num_samples: Maximum number of images to show.

        Returns:
            None. Prints a warning and does nothing if no dataset is loaded.
        """
        if self.X is None or self.y_raw is None:
            print("⚠️ Dataset not loaded. Please create or download the dataset first.")
            return

        n_shown = min(num_samples, len(self.X))
        n_cols = 4
        # Ceiling division; never fewer than two rows so the default layout
        # (2 x 4 on a 15 x 8 inch figure) is unchanged.
        n_rows = max(2, -(-n_shown // n_cols))

        plt.figure(figsize=(15, 4 * n_rows))
        random_indices = random.sample(range(len(self.X)), n_shown)

        for i, idx in enumerate(random_indices):
            plt.subplot(n_rows, n_cols, i + 1)
            plt.imshow(self.X[idx])
            plt.title(f'{self.y_raw[idx]}')
            plt.axis('off')

        plt.suptitle('Sample X-ray Images', fontsize=16)
        plt.tight_layout()
        plt.show()

In [55]:
class ColabForgeryDetector:
    """
    Pipeline state holder; this cell defines construction and the data split.

    Note: re-declaring a class with the same name in another notebook cell replaces
    this definition entirely, so only the methods written in the most recently
    executed declaration are available.
    """

    def __init__(self, img_size=(224, 224), batch_size=16):
        """
        Store configuration and create empty placeholders for pipeline state.

        Args:
            img_size: Spatial size of every image (two leading array dimensions).
            batch_size: Batch size stored for later training/prediction steps.
        """
        self.img_size = img_size
        self.batch_size = batch_size
        self.label_encoder = LabelEncoder()
        # Full dataset: images, string labels, integer labels, one-hot labels.
        self.X = None
        self.y_raw = None
        self.y_encoded = None
        self.y_categorical = None
        # Train / validation / test partitions.
        self.X_train = None
        self.X_val = None
        self.X_test = None
        self.y_train = None
        self.y_val = None
        self.y_test = None
        # Keras model and the History object of the last training run.
        self.model = None
        self.history = None

    def split_data(self, test_size=0.2, val_size=0.1):
        """
        Split data with stratification.

        Splits `self.X` / `self.y_categorical` into test, validation and training
        partitions (both splits use `random_state=42`) and stores them on the
        instance as `X_train`, `X_val`, `X_test`, `y_train`, `y_val`, `y_test`.

        Args:
            test_size: Fraction of the full dataset reserved for testing.
            val_size: Fraction of the full dataset reserved for validation.

        Returns:
            None. Prints a warning and does nothing if no dataset is loaded.
        """
        if self.X is None or self.y_categorical is None:
            print("⚠️ Dataset not loaded. Please create or download the dataset first.")
            return

        print("📊 Splitting data...")

        # First split: carve the test set out of the full dataset
        X_temp, self.X_test, y_temp, self.y_test = train_test_split(
            self.X, self.y_categorical, test_size=test_size,
            random_state=42, stratify=self.y_categorical
        )

        # Second split: val_size is expressed relative to the full dataset, so it
        # is rescaled to a fraction of what remains after removing the test set.
        val_size_adj = val_size / (1 - test_size)
        self.X_train, self.X_val, self.y_train, self.y_val = train_test_split(
            X_temp, y_temp, test_size=val_size_adj,
            random_state=42, stratify=y_temp
        )

        print(f"   Training: {self.X_train.shape[0]} samples")
        print(f"   Validation: {self.X_val.shape[0]} samples")
        print(f"   Test: {self.X_test.shape[0]} samples")

In [56]:
class ColabForgeryDetector:
    """Pipeline state holder; this cell defines construction and the CNN builder."""

    def __init__(self, img_size=(224, 224), batch_size=16):
        """
        Store configuration and create empty placeholders for pipeline state.

        Args:
            img_size: Spatial size of every image (two leading array dimensions).
            batch_size: Batch size stored for later training/prediction steps.
        """
        self.img_size = img_size
        self.batch_size = batch_size
        self.label_encoder = LabelEncoder()
        # Full dataset followed by the train / validation / test partitions.
        self.X = self.y_raw = self.y_encoded = self.y_categorical = None
        self.X_train = self.X_val = self.X_test = None
        self.y_train = self.y_val = self.y_test = None
        # Keras model and the History object of the last training run.
        self.model = None
        self.history = None

    def build_optimized_model(self):
        """
        Build Colab-optimized CNN model

        Three Conv2D(3x3, relu) -> BatchNormalization -> MaxPooling2D(2x2) stages
        with 32, 64 and 128 filters, global average pooling, five
        Dense(relu) + Dropout stages and a 3-way float32 softmax output. The
        model is compiled with Adam (learning rate 0.001), categorical
        cross-entropy and the accuracy metric, stored in `self.model`, its
        summary is printed, and it is returned.
        """
        print("🏗️  Building optimized CNN model...")

        # Convolutional base; only the first stage declares the input shape.
        stack = [
            Conv2D(32, (3, 3), activation='relu', input_shape=(*self.img_size, 3)),
            BatchNormalization(),
            MaxPooling2D((2, 2)),
        ]
        for n_filters in (64, 128):
            stack.append(Conv2D(n_filters, (3, 3), activation='relu'))
            stack.append(BatchNormalization())
            stack.append(MaxPooling2D((2, 2)))

        # Collapse each feature map to a single value instead of flattening.
        stack.append(tf.keras.layers.GlobalAveragePooling2D())

        # Fully connected head: (units, dropout rate) per stage.
        for n_units, drop_rate in ((300, 0.5), (150, 0.4), (75, 0.3), (50, 0.2), (25, 0.1)):
            stack.append(Dense(n_units, activation='relu'))
            stack.append(Dropout(drop_rate))

        # Class probabilities, explicitly kept in float32.
        stack.append(Dense(3, activation='softmax', dtype='float32'))

        network = Sequential(stack)
        network.compile(
            optimizer=Adam(learning_rate=0.001),
            loss='categorical_crossentropy',
            metrics=['accuracy']
        )

        self.model = network
        print("✅ Model built successfully!")

        # Print the layer-by-layer summary before handing the model back.
        network.summary()
        return network

In [57]:
class ColabForgeryDetector:
    """
    Pipeline state holder; this cell defines construction and the data generators.

    Note: re-declaring a class with the same name in another notebook cell replaces
    this definition entirely, so only the methods written in the most recently
    executed declaration are available.
    """

    def __init__(self, img_size=(224, 224), batch_size=16):
        """
        Store configuration and create empty placeholders for pipeline state.

        Args:
            img_size: Spatial size of every image (two leading array dimensions).
            batch_size: Batch size used by the data generators.
        """
        self.img_size = img_size
        self.batch_size = batch_size
        self.label_encoder = LabelEncoder()
        # Full dataset: images, string labels, integer labels, one-hot labels.
        self.X = None
        self.y_raw = None
        self.y_encoded = None
        self.y_categorical = None
        # Train / validation / test partitions.
        self.X_train = None
        self.X_val = None
        self.X_test = None
        self.y_train = None
        self.y_val = None
        self.y_test = None
        # Keras model and the History object of the last training run.
        self.model = None
        self.history = None
        # Batch iterators, populated by create_data_generators().
        self.train_gen = None
        self.val_gen = None

    def create_data_generators(self):
        """
        Create memory-efficient data generators.

        The training iterator applies light augmentation (rotation, shifts,
        horizontal flip, zoom) and shuffles; the validation iterator yields the
        data unchanged and in order. Both are stored on the instance as
        `train_gen` / `val_gen`.

        Returns:
            Tuple `(train_gen, val_gen)`, or `(None, None)` with a printed
            warning when the data has not been split yet.
        """
        if self.X_train is None or self.y_train is None or self.X_val is None or self.y_val is None:
            print("⚠️ Data not split. Please split the data first.")
            return None, None

        # Conservative augmentation for Colab
        train_datagen = ImageDataGenerator(
            rotation_range=10,
            width_shift_range=0.1,
            height_shift_range=0.1,
            horizontal_flip=True,
            zoom_range=0.1,
            fill_mode='nearest'
        )

        # No augmentation for validation data
        val_datagen = ImageDataGenerator()

        self.train_gen = train_datagen.flow(
            self.X_train, self.y_train,
            batch_size=self.batch_size,
            shuffle=True
        )

        self.val_gen = val_datagen.flow(
            self.X_val, self.y_val,
            batch_size=self.batch_size,
            shuffle=False
        )

        return self.train_gen, self.val_gen

In [58]:
class ColabForgeryDetector:
    """
    Pipeline state holder; this cell defines construction and the training loop.

    Note: re-declaring a class with the same name in another notebook cell replaces
    this definition entirely, so only the methods written in the most recently
    executed declaration are available.
    """

    def __init__(self, img_size=(224, 224), batch_size=16):
        """
        Store configuration and create empty placeholders for pipeline state.

        Args:
            img_size: Spatial size of every image (two leading array dimensions).
            batch_size: Batch size used to derive the number of steps per epoch.
        """
        self.img_size = img_size
        self.batch_size = batch_size
        self.label_encoder = LabelEncoder()
        # Full dataset: images, string labels, integer labels, one-hot labels.
        self.X = None
        self.y_raw = None
        self.y_encoded = None
        self.y_categorical = None
        # Train / validation / test partitions.
        self.X_train = None
        self.X_val = None
        self.X_test = None
        self.y_train = None
        self.y_val = None
        self.y_test = None
        # Keras model and the History object of the last training run.
        self.model = None
        self.history = None
        # Batch iterators consumed by train_with_callbacks().
        self.train_gen = None
        self.val_gen = None

    @staticmethod
    def _batches_per_epoch(num_samples, batch_size):
        """Return ceil(num_samples / batch_size), i.e. batches needed to cover every sample once."""
        return -(-num_samples // batch_size)

    def train_with_callbacks(self, epochs=30):
        """
        Train with Colab-optimized callbacks.

        Fits `self.model` on `self.train_gen`, validating on `self.val_gen`, with:
        early stopping on `val_accuracy` (patience 8, best weights restored),
        learning-rate reduction on `val_loss` plateaus (factor 0.3, patience 4)
        and a checkpoint of the best model by `val_accuracy` in 'best_model.h5'.

        Args:
            epochs: Maximum number of training epochs.

        Returns:
            The Keras History object (also stored in `self.history`), or None
            with a printed warning if the model or the generators are missing.
        """
        if self.model is None:
            print("⚠️ Model not built. Please build the model first.")
            return None
        if self.train_gen is None or self.val_gen is None:
            print("⚠️ Data generators not created. Please create data generators first.")
            return None

        print(f"🚂 Training model for {epochs} epochs...")

        callbacks = [
            tf.keras.callbacks.EarlyStopping(
                monitor='val_accuracy',
                patience=8,
                restore_best_weights=True,
                verbose=1
            ),
            tf.keras.callbacks.ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.3,
                patience=4,
                min_lr=1e-7,
                verbose=1
            ),
            tf.keras.callbacks.ModelCheckpoint(
                'best_model.h5',
                monitor='val_accuracy',
                save_best_only=True,
                verbose=1
            )
        ]

        # Round the step counts up: floor division would skip the final partial
        # batch (so some validation samples would never be evaluated) and would
        # yield 0 steps for a set smaller than one batch.
        steps_per_epoch = self._batches_per_epoch(len(self.X_train), self.batch_size)
        validation_steps = self._batches_per_epoch(len(self.X_val), self.batch_size)

        self.history = self.model.fit(
            self.train_gen,
            steps_per_epoch=steps_per_epoch,
            epochs=epochs,
            validation_data=self.val_gen,
            validation_steps=validation_steps,
            callbacks=callbacks,
            verbose=1
        )

        print("✅ Training completed!")
        return self.history

In [59]:
class ColabForgeryDetector:
    """
    Pipeline state holder; this cell defines construction and model evaluation.

    Note: re-declaring a class with the same name in another notebook cell replaces
    this definition entirely, so only the methods written in the most recently
    executed declaration are available.
    """

    def __init__(self, img_size=(224, 224), batch_size=16):
        """
        Store configuration and create empty placeholders for pipeline state.

        Args:
            img_size: Spatial size of every image (two leading array dimensions).
            batch_size: Batch size used when predicting on the test set.
        """
        self.img_size = img_size
        self.batch_size = batch_size
        self.label_encoder = LabelEncoder()
        # Full dataset: images, string labels, integer labels, one-hot labels.
        self.X = None
        self.y_raw = None
        self.y_encoded = None
        self.y_categorical = None
        # Train / validation / test partitions.
        self.X_train = None
        self.X_val = None
        self.X_test = None
        self.y_train = None
        self.y_val = None
        self.y_test = None
        # Keras model and the History object of the last training run.
        self.model = None
        self.history = None

    def evaluate_and_visualize(self):
        """
        Comprehensive evaluation with visualizations.

        Predicts on the test set, prints the accuracy and a classification
        report, and shows a 2x2 figure: accuracy history, loss history,
        confusion matrix and the distribution of predicted classes.

        Returns:
            Test accuracy as returned by `accuracy_score`, or None with a
            printed warning if the model, test data or training history is missing.
        """
        if self.model is None:
            print("⚠️ Model not built. Please build the model first.")
            return None
        if self.X_test is None or self.y_test is None:
            print("⚠️ Test data not available. Please split the data and train the model first.")
            return None
        if self.history is None:
            print("⚠️ Model not trained. Please train the model first.")
            return None

        print("📈 Evaluating model performance...")

        # Predictions: most probable class per sample vs. the one-hot ground truth
        test_pred = self.model.predict(self.X_test, batch_size=self.batch_size)
        test_pred_classes = np.argmax(test_pred, axis=1)
        test_true_classes = np.argmax(self.y_test, axis=1)

        # Accuracy
        accuracy = accuracy_score(test_true_classes, test_pred_classes)
        print(f"🎯 Test Accuracy: {accuracy:.4f}")

        # Passing the full list of class ids keeps matrix rows/columns and report
        # rows aligned with the class names even if a class is absent from both
        # the true and the predicted labels.
        class_names = self.label_encoder.classes_
        class_ids = np.arange(len(class_names))

        # Create comprehensive visualization
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))

        # 1. Training History
        axes[0, 0].plot(self.history.history['accuracy'], label='Training')
        axes[0, 0].plot(self.history.history['val_accuracy'], label='Validation')
        axes[0, 0].set_title('Model Accuracy')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Accuracy')
        axes[0, 0].legend()
        axes[0, 0].grid(True)

        # 2. Loss History
        axes[0, 1].plot(self.history.history['loss'], label='Training')
        axes[0, 1].plot(self.history.history['val_loss'], label='Validation')
        axes[0, 1].set_title('Model Loss')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Loss')
        axes[0, 1].legend()
        axes[0, 1].grid(True)

        # 3. Confusion Matrix
        cm = confusion_matrix(test_true_classes, test_pred_classes, labels=class_ids)
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[1, 0],
                    xticklabels=class_names,
                    yticklabels=class_names)
        axes[1, 0].set_title('Confusion Matrix')
        axes[1, 0].set_ylabel('True Label')
        axes[1, 0].set_xlabel('Predicted Label')

        # 4. Class Distribution (classes that were never predicted show a zero bar)
        counts = np.bincount(test_pred_classes, minlength=len(class_names))
        axes[1, 1].bar(class_names, counts)
        axes[1, 1].set_title('Prediction Distribution')
        axes[1, 1].set_ylabel('Count')
        axes[1, 1].tick_params(axis='x', rotation=45)

        plt.tight_layout()
        plt.show()

        # Classification Report
        print("\n📊 Detailed Classification Report:")
        print(classification_report(test_true_classes, test_pred_classes,
                                    labels=class_ids,
                                    target_names=class_names))

        return accuracy

In [60]:
class ColabForgeryDetector:
    """
    Pipeline state holder; this cell defines construction and single-sample prediction.

    Note: re-declaring a class with the same name in another notebook cell replaces
    this definition entirely, so only the methods written in the most recently
    executed declaration are available.
    """

    def __init__(self, img_size=(224, 224), batch_size=16):
        """
        Store configuration and create empty placeholders for pipeline state.

        Args:
            img_size: Spatial size of every image (two leading array dimensions).
            batch_size: Batch size stored for training/prediction steps.
        """
        self.img_size = img_size
        self.batch_size = batch_size
        self.label_encoder = LabelEncoder()
        # Full dataset: images, string labels, integer labels, one-hot labels.
        self.X = None
        self.y_raw = None
        self.y_encoded = None
        self.y_categorical = None
        # Train / validation / test partitions.
        self.X_train = None
        self.X_val = None
        self.X_test = None
        self.y_train = None
        self.y_val = None
        self.y_test = None
        # Keras model and the History object of the last training run.
        self.model = None
        self.history = None

    def predict_sample(self, image_idx=None):
        """
        Predict and visualize a sample.

        Shows the chosen test image (titled with true label, predicted label and
        confidence) next to a bar chart of the predicted class probabilities.

        Args:
            image_idx: Index into the test set; a random index is drawn when None.
                Negative indices count from the end, as with regular indexing.

        Returns:
            Tuple `(pred_label, confidence)`, or `(None, None)` with a printed
            warning if the model or the test data is missing.

        Raises:
            IndexError: If `image_idx` is outside the test set.
        """
        if self.model is None:
            print("⚠️ Model not built. Please build and train the model first.")
            return None, None
        if self.X_test is None or self.y_test is None:
            print("⚠️ Test data not available. Please split the data and train the model first.")
            return None, None

        if image_idx is None:
            image_idx = random.randint(0, len(self.X_test) - 1)

        # Index the single image and add a batch axis. Slicing with
        # [image_idx:image_idx+1] would silently produce an empty batch for -1
        # or for an out-of-range index instead of failing clearly.
        sample = self.X_test[image_idx]
        img = sample[np.newaxis, ...]
        true_label = self.label_encoder.classes_[np.argmax(self.y_test[image_idx])]

        pred = self.model.predict(img, verbose=0)
        pred_label = self.label_encoder.classes_[np.argmax(pred[0])]
        confidence = np.max(pred[0])

        plt.figure(figsize=(12, 4))

        # Show image
        plt.subplot(1, 2, 1)
        plt.imshow(sample)
        plt.title(f'True: {true_label}\nPredicted: {pred_label}\nConfidence: {confidence:.2f}')
        plt.axis('off')

        # Show prediction probabilities
        plt.subplot(1, 2, 2)
        plt.bar(self.label_encoder.classes_, pred[0])
        plt.title('Prediction Probabilities')
        plt.ylabel('Probability')
        plt.xticks(rotation=45)

        plt.tight_layout()
        plt.show()

        return pred_label, confidence

In [65]:
class ColabForgeryDetector:
    """
    End-to-end demo pipeline for a three-class chest X-ray image classifier.

    Typical call order: `create_synthetic_dataset()` (or load real data) ->
    `split_data()` -> `build_optimized_model()` -> `create_data_generators()` ->
    `train_with_callbacks()` -> `evaluate_and_visualize()` / `predict_sample()`.
    Methods that depend on an earlier step print a warning and return early
    when that step has not been run.
    """

    def __init__(self, img_size=(224, 224), batch_size=16):
        """
        Store configuration and create empty placeholders for pipeline state.

        Args:
            img_size: Spatial size of every image; used as the two leading
                dimensions (rows, columns) of each image array and of the
                model input.
            batch_size: Batch size for the data generators and for prediction.
        """
        self.img_size = img_size
        self.batch_size = batch_size
        self.label_encoder = LabelEncoder()
        # Full dataset: images, string labels, integer labels, one-hot labels.
        self.X = None
        self.y_raw = None
        self.y_encoded = None
        self.y_categorical = None
        # Train / validation / test partitions.
        self.X_train = None
        self.X_val = None
        self.X_test = None
        self.y_train = None
        self.y_val = None
        self.y_test = None
        # Keras model and the History object of the last training run.
        self.model = None
        self.history = None
        # Batch iterators, populated by create_data_generators().
        self.train_gen = None
        self.val_gen = None

    def create_synthetic_dataset(self, samples_per_class=200):
        """
        Create synthetic dataset for demonstration (replace with real data loading).

        Every image starts as Gaussian noise (mean 0.3, std 0.1). 'COVID-19'
        images get five filled circles, 'Viral_Pneumonia' images get three filled
        40x40 squares, and 'Normal' images are left as plain noise. Pixel values
        are clipped to [0, 1].

        The result is stored on the instance (`X`, `y_raw`, `y_encoded`,
        `y_categorical`) and the label encoder is fitted on the class names.

        Args:
            samples_per_class: Number of images generated for each of the three classes.

        Returns:
            Tuple `(X, y_categorical)`: a float32 image array of shape
            `(3 * samples_per_class, rows, cols, 3)` and the one-hot labels.

        Raises:
            ValueError: If `img_size` is too small to place the patterns with
                their margins (both dimensions must exceed 100 pixels).
        """
        class_names = ['COVID-19', 'Viral_Pneumonia', 'Normal']
        circle_margin = 50  # circle centres stay this far from every border
        rect_margin = 30    # rectangle origins stay this far from every border

        # Image arrays are (rows, cols, channels); OpenCV points are (x=col, y=row).
        height, width = self.img_size
        min_side = 2 * max(circle_margin, rect_margin)
        if min(height, width) <= min_side:
            raise ValueError(
                f"img_size must be larger than {min_side} pixels in both dimensions "
                f"to place the synthetic patterns, got {self.img_size}"
            )

        print("🎲 Creating synthetic dataset for demonstration...")

        images = []
        labels = []

        for class_name in class_names:
            print(f"   Generating {samples_per_class} samples for {class_name}")

            for _ in range(samples_per_class):
                # Noise background shared by all classes
                base_img = np.random.normal(0.3, 0.1, (height, width, 3))

                # Add class-specific patterns
                if class_name == 'COVID-19':
                    # Bright filled circles
                    for _ in range(5):
                        # x is bounded by the width and y by the height so that
                        # non-square images are handled correctly; plain ints are
                        # passed because OpenCV expects Python integers.
                        x = int(np.random.randint(circle_margin, width - circle_margin))
                        y = int(np.random.randint(circle_margin, height - circle_margin))
                        radius = int(np.random.randint(10, 30))
                        cv2.circle(base_img, (x, y), radius, (0.6, 0.6, 0.6), -1)

                elif class_name == 'Viral_Pneumonia':
                    # Bright filled squares
                    for _ in range(3):
                        x = int(np.random.randint(rect_margin, width - rect_margin))
                        y = int(np.random.randint(rect_margin, height - rect_margin))
                        cv2.rectangle(base_img, (x, y),
                                      (x + 40, y + 40), (0.7, 0.7, 0.7), -1)

                # Clip to [0, 1] and store as float32 right away so the list never
                # holds the whole dataset in float64.
                images.append(np.clip(base_img, 0, 1).astype(np.float32))
                labels.append(class_name)

        self.X = np.array(images, dtype=np.float32)
        self.y_raw = np.array(labels)

        # Encode labels: strings -> integer ids -> one-hot vectors
        self.y_encoded = self.label_encoder.fit_transform(self.y_raw)
        self.y_categorical = to_categorical(self.y_encoded, num_classes=len(class_names))

        print(f"✅ Dataset created: {len(self.X)} images")
        print(f"   Classes: {list(self.label_encoder.classes_)}")

        # The data is also kept on the instance (self.X / self.y_categorical)
        return self.X, self.y_categorical

    def download_kaggle_dataset(self):
        """
        Download the actual Kaggle dataset (requires Kaggle API setup).

        Only runs when the module-level flag `IN_COLAB` is true. Mounts Google
        Drive, installs the `kaggle` CLI, points `KAGGLE_CONFIG_DIR` at
        `/content/drive/MyDrive/kaggle`, downloads the dataset archive into the
        current working directory and extracts it to `/content/dataset`.

        Returns:
            '/content/dataset' on success; None when not running in Colab or
            when any step fails. The method itself does not generate synthetic
            data on failure — the caller has to fall back to it.
        """
        if not IN_COLAB:
            print("⚠️  This function is designed for Google Colab")
            return None

        try:
            # Mount Google Drive to access Kaggle credentials
            from google.colab import drive
            drive.mount('/content/drive')

            # Install Kaggle; a non-zero exit status is turned into an error so the
            # failure is reported here rather than as a missing zip file later on.
            if os.system('pip install kaggle') != 0:
                raise RuntimeError("'pip install kaggle' exited with a non-zero status")

            # Setup Kaggle credentials
            os.environ['KAGGLE_CONFIG_DIR'] = '/content/drive/MyDrive/kaggle'

            # Download dataset
            if os.system('kaggle datasets download -d nourmahmoud/covid19-digital-xrays-forgery-dataset') != 0:
                raise RuntimeError("'kaggle datasets download' exited with a non-zero status")

            # Extract dataset
            with zipfile.ZipFile('covid19-digital-xrays-forgery-dataset.zip', 'r') as zip_ref:
                zip_ref.extractall('/content/dataset')

            print("✅ Kaggle dataset downloaded and extracted!")
            return '/content/dataset'

        except Exception as e:
            # Deliberate best-effort: report the problem and signal failure with None.
            print(f"❌ Error downloading dataset: {e}")
            print("💡 Using synthetic data instead...")
            return None

    def visualize_samples(self, num_samples=8):
        """
        Visualize sample images.

        Draws up to `num_samples` randomly chosen images from `self.X` in a grid
        with four columns, each titled with its label from `self.y_raw`. The grid
        has two rows for up to eight images and grows by one row per four
        additional images.

        Args:
            num_samples: Maximum number of images to show.

        Returns:
            None. Prints a warning and does nothing if no dataset is loaded.
        """
        if self.X is None or self.y_raw is None:
            print("⚠️ Dataset not loaded. Please create or download the dataset first.")
            return

        n_shown = min(num_samples, len(self.X))
        n_cols = 4
        # Ceiling division; never fewer than two rows so the default layout
        # (2 x 4 on a 15 x 8 inch figure) is unchanged.
        n_rows = max(2, -(-n_shown // n_cols))

        plt.figure(figsize=(15, 4 * n_rows))
        random_indices = random.sample(range(len(self.X)), n_shown)

        for i, idx in enumerate(random_indices):
            plt.subplot(n_rows, n_cols, i + 1)
            plt.imshow(self.X[idx])
            plt.title(f'{self.y_raw[idx]}')
            plt.axis('off')

        plt.suptitle('Sample X-ray Images', fontsize=16)
        plt.tight_layout()
        plt.show()

    def split_data(self, test_size=0.2, val_size=0.1):
        """
        Split data with stratification.

        Splits `self.X` / `self.y_categorical` into test, validation and training
        partitions (both splits use `random_state=42`) and stores them on the
        instance as `X_train`, `X_val`, `X_test`, `y_train`, `y_val`, `y_test`.

        Args:
            test_size: Fraction of the full dataset reserved for testing.
            val_size: Fraction of the full dataset reserved for validation.

        Returns:
            None. Prints a warning and does nothing if no dataset is loaded.
        """
        if self.X is None or self.y_categorical is None:
            print("⚠️ Dataset not loaded. Please create or download the dataset first.")
            return

        print("📊 Splitting data...")

        # First split: carve the test set out of the full dataset
        X_temp, self.X_test, y_temp, self.y_test = train_test_split(
            self.X, self.y_categorical, test_size=test_size,
            random_state=42, stratify=self.y_categorical
        )

        # Second split: val_size is expressed relative to the full dataset, so it
        # is rescaled to a fraction of what remains after removing the test set.
        val_size_adj = val_size / (1 - test_size)
        self.X_train, self.X_val, self.y_train, self.y_val = train_test_split(
            X_temp, y_temp, test_size=val_size_adj,
            random_state=42, stratify=y_temp
        )

        print(f"   Training: {self.X_train.shape[0]} samples")
        print(f"   Validation: {self.X_val.shape[0]} samples")
        print(f"   Test: {self.X_test.shape[0]} samples")

    def build_optimized_model(self):
        """
        Build Colab-optimized CNN model.

        Three Conv2D(3x3, relu) -> BatchNormalization -> MaxPooling2D(2x2) stages
        with 32, 64 and 128 filters, global average pooling, five
        Dense(relu) + Dropout stages and a 3-way float32 softmax output. The
        model is compiled with Adam (learning rate 0.001), categorical
        cross-entropy and the accuracy metric.

        Returns:
            The compiled Keras model (also stored in `self.model`).
        """
        print("🏗️  Building optimized CNN model...")

        model = Sequential([
            # Convolutional base
            Conv2D(32, (3, 3), activation='relu', input_shape=(*self.img_size, 3)),
            BatchNormalization(),
            MaxPooling2D((2, 2)),

            Conv2D(64, (3, 3), activation='relu'),
            BatchNormalization(),
            MaxPooling2D((2, 2)),

            Conv2D(128, (3, 3), activation='relu'),
            BatchNormalization(),
            MaxPooling2D((2, 2)),

            # Global average pooling instead of flatten (one value per feature map)
            tf.keras.layers.GlobalAveragePooling2D(),

            # Fully connected head with decreasing width and dropout rate
            Dense(300, activation='relu'),
            Dropout(0.5),

            Dense(150, activation='relu'),
            Dropout(0.4),

            Dense(75, activation='relu'),
            Dropout(0.3),

            Dense(50, activation='relu'),
            Dropout(0.2),

            Dense(25, activation='relu'),
            Dropout(0.1),

            # Output layer with float32 for stability
            Dense(3, activation='softmax', dtype='float32')
        ])

        # Adam with a fixed initial learning rate; any reduction during training
        # is handled by the ReduceLROnPlateau callback in train_with_callbacks().
        optimizer = Adam(learning_rate=0.001)

        model.compile(
            optimizer=optimizer,
            loss='categorical_crossentropy',
            metrics=['accuracy']
        )

        self.model = model
        print("✅ Model built successfully!")

        # Display model summary
        model.summary()
        return model

    def create_data_generators(self):
        """
        Create memory-efficient data generators.

        The training iterator applies light augmentation (rotation, shifts,
        horizontal flip, zoom) and shuffles; the validation iterator yields the
        data unchanged and in order. Both are stored on the instance as
        `train_gen` / `val_gen`.

        Returns:
            Tuple `(train_gen, val_gen)`, or `(None, None)` with a printed
            warning when the data has not been split yet.
        """
        if self.X_train is None or self.y_train is None or self.X_val is None or self.y_val is None:
            print("⚠️ Data not split. Please split the data first.")
            return None, None

        # Conservative augmentation for Colab
        train_datagen = ImageDataGenerator(
            rotation_range=10,
            width_shift_range=0.1,
            height_shift_range=0.1,
            horizontal_flip=True,
            zoom_range=0.1,
            fill_mode='nearest'
        )

        # No augmentation for validation data
        val_datagen = ImageDataGenerator()

        self.train_gen = train_datagen.flow(
            self.X_train, self.y_train,
            batch_size=self.batch_size,
            shuffle=True
        )

        self.val_gen = val_datagen.flow(
            self.X_val, self.y_val,
            batch_size=self.batch_size,
            shuffle=False
        )

        return self.train_gen, self.val_gen

    @staticmethod
    def _batches_per_epoch(num_samples, batch_size):
        """Return ceil(num_samples / batch_size), i.e. batches needed to cover every sample once."""
        return -(-num_samples // batch_size)

    def train_with_callbacks(self, epochs=30):
        """
        Train with Colab-optimized callbacks.

        Fits `self.model` on `self.train_gen`, validating on `self.val_gen`, with:
        early stopping on `val_accuracy` (patience 8, best weights restored),
        learning-rate reduction on `val_loss` plateaus (factor 0.3, patience 4)
        and a checkpoint of the best model by `val_accuracy` in 'best_model.h5'.

        Args:
            epochs: Maximum number of training epochs.

        Returns:
            The Keras History object (also stored in `self.history`), or None
            with a printed warning if the model or the generators are missing.
        """
        if self.model is None:
            print("⚠️ Model not built. Please build the model first.")
            return None
        if self.train_gen is None or self.val_gen is None:
            print("⚠️ Data generators not created. Please create data generators first.")
            return None

        print(f"🚂 Training model for {epochs} epochs...")

        callbacks = [
            tf.keras.callbacks.EarlyStopping(
                monitor='val_accuracy',
                patience=8,
                restore_best_weights=True,
                verbose=1
            ),
            tf.keras.callbacks.ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.3,
                patience=4,
                min_lr=1e-7,
                verbose=1
            ),
            tf.keras.callbacks.ModelCheckpoint(
                'best_model.h5',
                monitor='val_accuracy',
                save_best_only=True,
                verbose=1
            )
        ]

        # Round the step counts up: floor division would skip the final partial
        # batch (so some validation samples would never be evaluated) and would
        # yield 0 steps for a set smaller than one batch.
        steps_per_epoch = self._batches_per_epoch(len(self.X_train), self.batch_size)
        validation_steps = self._batches_per_epoch(len(self.X_val), self.batch_size)

        self.history = self.model.fit(
            self.train_gen,
            steps_per_epoch=steps_per_epoch,
            epochs=epochs,
            validation_data=self.val_gen,
            validation_steps=validation_steps,
            callbacks=callbacks,
            verbose=1
        )

        print("✅ Training completed!")
        return self.history

    def evaluate_and_visualize(self):
        """
        Comprehensive evaluation with visualizations.

        Predicts on the test set, prints the accuracy and a classification
        report, and shows a 2x2 figure: accuracy history, loss history,
        confusion matrix and the distribution of predicted classes.

        Returns:
            Test accuracy as returned by `accuracy_score`, or None with a
            printed warning if the model, test data or training history is missing.
        """
        if self.model is None:
            print("⚠️ Model not built. Please build the model first.")
            return None
        if self.X_test is None or self.y_test is None:
            print("⚠️ Test data not available. Please split the data and train the model first.")
            return None
        if self.history is None:
            print("⚠️ Model not trained. Please train the model first.")
            return None

        print("📈 Evaluating model performance...")

        # Predictions: most probable class per sample vs. the one-hot ground truth
        test_pred = self.model.predict(self.X_test, batch_size=self.batch_size)
        test_pred_classes = np.argmax(test_pred, axis=1)
        test_true_classes = np.argmax(self.y_test, axis=1)

        # Accuracy
        accuracy = accuracy_score(test_true_classes, test_pred_classes)
        print(f"🎯 Test Accuracy: {accuracy:.4f}")

        # Passing the full list of class ids keeps matrix rows/columns and report
        # rows aligned with the class names even if a class is absent from both
        # the true and the predicted labels.
        class_names = self.label_encoder.classes_
        class_ids = np.arange(len(class_names))

        # Create comprehensive visualization
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))

        # 1. Training History
        axes[0, 0].plot(self.history.history['accuracy'], label='Training')
        axes[0, 0].plot(self.history.history['val_accuracy'], label='Validation')
        axes[0, 0].set_title('Model Accuracy')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Accuracy')
        axes[0, 0].legend()
        axes[0, 0].grid(True)

        # 2. Loss History
        axes[0, 1].plot(self.history.history['loss'], label='Training')
        axes[0, 1].plot(self.history.history['val_loss'], label='Validation')
        axes[0, 1].set_title('Model Loss')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Loss')
        axes[0, 1].legend()
        axes[0, 1].grid(True)

        # 3. Confusion Matrix
        cm = confusion_matrix(test_true_classes, test_pred_classes, labels=class_ids)
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[1, 0],
                    xticklabels=class_names,
                    yticklabels=class_names)
        axes[1, 0].set_title('Confusion Matrix')
        axes[1, 0].set_ylabel('True Label')
        axes[1, 0].set_xlabel('Predicted Label')

        # 4. Class Distribution (classes that were never predicted show a zero bar)
        counts = np.bincount(test_pred_classes, minlength=len(class_names))
        axes[1, 1].bar(class_names, counts)
        axes[1, 1].set_title('Prediction Distribution')
        axes[1, 1].set_ylabel('Count')
        axes[1, 1].tick_params(axis='x', rotation=45)

        plt.tight_layout()
        plt.show()

        # Classification Report
        print("\n📊 Detailed Classification Report:")
        print(classification_report(test_true_classes, test_pred_classes,
                                    labels=class_ids,
                                    target_names=class_names))

        return accuracy

    def predict_sample(self, image_idx=None):
        """
        Predict and visualize a sample.

        Shows the chosen test image (titled with true label, predicted label and
        confidence) next to a bar chart of the predicted class probabilities.

        Args:
            image_idx: Index into the test set; a random index is drawn when None.
                Negative indices count from the end, as with regular indexing.

        Returns:
            Tuple `(pred_label, confidence)`, or `(None, None)` with a printed
            warning if the model or the test data is missing.

        Raises:
            IndexError: If `image_idx` is outside the test set.
        """
        if self.model is None:
            print("⚠️ Model not built. Please build and train the model first.")
            return None, None
        if self.X_test is None or self.y_test is None:
            print("⚠️ Test data not available. Please split the data and train the model first.")
            return None, None

        if image_idx is None:
            image_idx = random.randint(0, len(self.X_test) - 1)

        # Index the single image and add a batch axis. Slicing with
        # [image_idx:image_idx+1] would silently produce an empty batch for -1
        # or for an out-of-range index instead of failing clearly.
        sample = self.X_test[image_idx]
        img = sample[np.newaxis, ...]
        true_label = self.label_encoder.classes_[np.argmax(self.y_test[image_idx])]

        pred = self.model.predict(img, verbose=0)
        pred_label = self.label_encoder.classes_[np.argmax(pred[0])]
        confidence = np.max(pred[0])

        plt.figure(figsize=(12, 4))

        # Show image
        plt.subplot(1, 2, 1)
        plt.imshow(sample)
        plt.title(f'True: {true_label}\nPredicted: {pred_label}\nConfidence: {confidence:.2f}')
        plt.axis('off')

        # Show prediction probabilities
        plt.subplot(1, 2, 2)
        plt.bar(self.label_encoder.classes_, pred[0])
        plt.title('Prediction Probabilities')
        plt.ylabel('Probability')
        plt.xticks(rotation=45)

        plt.tight_layout()
        plt.show()

        return pred_label, confidence