In [1]:
import os
import sys
import argparse
import numpy as np
import matplotlib.pyplot as plt
import itertools
import cv2
from PIL import Image
import io
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score, roc_curve, auc
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.models import load_model
from flask import Flask, request, jsonify, send_file

In [2]:
IMG_SIZE = 224          # images are resized to IMG_SIZE x IMG_SIZE (square) everywhere
BATCH_SIZE = 16
SEED = 42
NUM_EPOCHS = 12         # max epochs for the frozen-backbone phase in train()
LEARNING_RATE = 1e-4    # Adam LR for the head; build_model() uses LEARNING_RATE / 10 when fine-tuning
MODEL_PATH = "covid_effnetb0.h5"
# Binary classification with a single sigmoid output. The code assumes COVID = 1, Normal = 0.
# NOTE(review): flow_from_directory assigns labels in alphanumeric folder order unless
# `classes=` is given, so confirm the folder names produce that mapping.
CLASS_MODE = "binary"
DATA_DIR = "dataset"  # change if your dataset folder is different

# Seed Python's `random`, NumPy and TensorFlow so that head initialisation, dropout and
# augmentation are repeatable on Restart & Run All. Previously SEED was only passed to
# the training generator, which left the model's own randomness unseeded.
tf.keras.utils.set_random_seed(SEED)

# -------------

In [5]:
# Data generators
# -------------------------
def get_generators(train_dir=os.path.join(DATA_DIR,"train"),
                   val_dir=os.path.join(DATA_DIR,"val"),
                   test_dir=os.path.join(DATA_DIR,"test"),
                   img_size=IMG_SIZE, batch_size=BATCH_SIZE,
                   rescale=1./255, classes=None):
    """Build train/val/test image iterators from class-per-subfolder directories.

    Args:
        train_dir, val_dir, test_dir: directories that each contain one subfolder per class.
        img_size: images are resized to (img_size, img_size).
        batch_size: samples per batch for all three iterators.
        rescale: pixel multiplier applied by every generator. The default 1/255 maps
            pixels to [0, 1], as before. NOTE(review): Keras' EfficientNet models
            include their own input rescaling and document inputs in [0, 255], so
            `rescale=None` is likely correct for build_model(). If you change it,
            retrain and make the inference preprocessing (Grad-CAM, Flask route) match.
        classes: optional explicit list of class subfolder names. Its order fixes
            the label indices, e.g. ["Normal", "COVID"] makes COVID label 1. With
            None, Keras infers the classes from subfolder names in alphanumeric order.

    Returns:
        (train_gen, val_gen, test_gen). Only the training iterator is augmented and
        seeded with SEED. Val/test use shuffle=False, so their predictions line up
        with `.classes`.
    """
    # Light geometric augmentation, for training only.
    train_datagen = ImageDataGenerator(
        rescale=rescale,
        rotation_range=12,
        width_shift_range=0.08,
        height_shift_range=0.08,
        shear_range=0.08,
        zoom_range=0.08,
        horizontal_flip=True,
        fill_mode="nearest"
    )
    # Validation and test only rescale. A single stateless generator serves both.
    eval_datagen = ImageDataGenerator(rescale=rescale)

    def _flow(datagen, directory, **kwargs):
        # Settings shared by all three splits; kwargs carry the per-split options.
        return datagen.flow_from_directory(
            directory,
            target_size=(img_size, img_size),
            batch_size=batch_size,
            class_mode=CLASS_MODE,
            classes=classes,
            **kwargs
        )

    train_gen = _flow(train_datagen, train_dir, seed=SEED)
    val_gen = _flow(eval_datagen, val_dir, shuffle=False)
    test_gen = _flow(eval_datagen, test_dir, shuffle=False)

    return train_gen, val_gen, test_gen

In [38]:
# Build model (transfer learning)
# -------------------------
def build_model(img_size=IMG_SIZE, lr=LEARNING_RATE, fine_tune_at=None):
    """Build and compile an EfficientNetB0 transfer-learning binary classifier.

    Args:
        img_size: side length of the square RGB input.
        lr: Adam learning rate when the backbone is frozen. When fine-tuning, the
            model is compiled with lr / 10 instead.
        fine_tune_at: None keeps the whole ImageNet backbone frozen. Otherwise the
            backbone layers from index `fine_tune_at` onward are unfrozen; a negative
            value such as -20 unfreezes the last 20 layers. BatchNormalization layers
            always stay frozen.

    Returns:
        A compiled Keras Model with a single sigmoid output. It is trained with
        binary cross-entropy and tracks accuracy and AUC.
    """
    base = EfficientNetB0(weights="imagenet", include_top=False, input_shape=(img_size, img_size, 3))

    if fine_tune_at is None:
        base.trainable = False
        learning_rate = lr
    else:
        base.trainable = True
        for layer in base.layers[:fine_tune_at]:
            layer.trainable = False
        # Keep BatchNormalization frozen. A BN layer with trainable=False runs in
        # inference mode, so its pretrained moving statistics are not overwritten by
        # small fine-tuning batches. This graph reuses `base.output` and cannot call
        # the base with training=False, so freezing BN is the way to get that effect.
        for layer in base.layers:
            if isinstance(layer, layers.BatchNormalization):
                layer.trainable = False
        # Lower LR so fine-tuning does not wreck the pretrained features.
        learning_rate = lr / 10

    x = layers.GlobalAveragePooling2D()(base.output)
    x = layers.Dropout(0.4)(x)
    x = layers.Dense(128, activation="relu")(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(1, activation="sigmoid")(x)

    model = models.Model(inputs=base.input, outputs=outputs)

    # Compile once, after the trainable flags are final.
    model.compile(optimizer=optimizers.Adam(learning_rate=learning_rate),
                  loss="binary_crossentropy",
                  metrics=["accuracy", tf.keras.metrics.AUC(name="auc")])

    return model

In [34]:
# Training: frozen backbone first, then fine-tuning
# -------------------------
def train(save_path=MODEL_PATH, epochs=NUM_EPOCHS, fine_tune_epochs=5, fine_tune_at=-20):
    """Train in two phases and keep the best model (by val_accuracy) at `save_path`.

    Phase 1 trains only the classification head on a frozen EfficientNetB0.
    Phase 2 rebuilds the model via build_model(fine_tune_at=...) and reloads the best
    phase-1 checkpoint. It then continues training with part of the backbone
    unfrozen, at a lower learning rate.

    Args:
        save_path: file the best model is checkpointed to.
        epochs: max epochs for phase 1.
        fine_tune_epochs: max epochs for phase 2. 0 skips fine-tuning.
        fine_tune_at: backbone layer index passed to build_model(); -20 unfreezes
            the last 20 layers.

    Returns:
        The phase-1 History object. The phase-2 metrics are appended to each list in
        `history.history`, so plot_history() shows the whole run.
    """
    train_gen, val_gen, _ = get_generators()
    model = build_model()

    # These callbacks are reused by both phases. ModelCheckpoint keeps its best score
    # on the instance, so phase 2 only overwrites save_path if it beats phase 1.
    callbacks = [
        ModelCheckpoint(save_path, monitor="val_accuracy", save_best_only=True, verbose=1),
        EarlyStopping(monitor="val_accuracy", patience=5, restore_best_weights=True),
        ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=3, verbose=1)
    ]

    history = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=epochs,
        callbacks=callbacks
    )

    if fine_tune_epochs > 0:
        print("\n--- Starting fine-tuning ---\n")
        model = build_model(fine_tune_at=fine_tune_at)
        # Start from the best phase-1 weights, not from fresh ImageNet weights plus a random head.
        if os.path.exists(save_path):
            model.load_weights(save_path)
        ft_history = model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=fine_tune_epochs,
            callbacks=callbacks
        )
        # Append the phase-2 curves so the returned history covers both phases.
        for key, values in ft_history.history.items():
            history.history.setdefault(key, []).extend(values)

    # save_path already holds the best checkpoint from either phase. Saving the
    # current (last-epoch) model unconditionally would overwrite it with a possibly
    # worse model, so only write it when no checkpoint was ever produced.
    if not os.path.exists(save_path):
        model.save(save_path)
    print(f"Model saved to {save_path}")
    return history

In [14]:
# Plot training curves
# -------------------------
def plot_history(history):
    """Plot training vs. validation accuracy and loss side by side.

    Args:
        history: a Keras History object (as returned by `model.fit` or `train()`),
            or its plain `history.history` dict. A missing metric key plots as an
            empty line.
    """
    # Accept either the History object or the raw dict it wraps.
    metrics = getattr(history, "history", history)

    fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))

    ax_acc.plot(metrics.get("accuracy", []), label="train_acc")
    ax_acc.plot(metrics.get("val_accuracy", []), label="val_acc")
    ax_acc.set(title="Accuracy", xlabel="Epoch", ylabel="Accuracy")
    ax_acc.legend()

    ax_loss.plot(metrics.get("loss", []), label="train_loss")
    ax_loss.plot(metrics.get("val_loss", []), label="val_loss")
    ax_loss.set(title="Loss", xlabel="Epoch", ylabel="Loss")
    ax_loss.legend()

    fig.tight_layout()
    plt.show()


In [36]:
# Evaluation utilities
# -------------------------
def evaluate_model(model_path=MODEL_PATH, threshold=0.5):
    """Evaluate a saved model on the val/test splits and report test-set metrics.

    Prints Keras evaluate() results for both splits. For the test split it also
    prints a classification report and a confusion matrix, plots the normalized
    confusion matrix, and computes and plots ROC-AUC when possible.

    Args:
        model_path: saved Keras model to load.
        threshold: sigmoid probability at or above which a sample is predicted as
            label 1.

    Returns:
        dict with "val_results" and "test_results" (the Keras evaluate() outputs)
        and "cm" (the raw test confusion matrix).

    Raises:
        FileNotFoundError: if `model_path` does not exist.
    """
    # Fail fast, before scanning the dataset directories.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model weights not found at {model_path}")
    _, val_gen, test_gen = get_generators()

    model = load_model(model_path, compile=False)
    # Recompile so that evaluate() reports loss, accuracy and AUC.
    model.compile(optimizer=optimizers.Adam(LEARNING_RATE), loss="binary_crossentropy",
                  metrics=["accuracy", tf.keras.metrics.AUC(name="auc")])

    # Evaluate on validation and test
    print("Evaluating on validation set:")
    val_results = model.evaluate(val_gen, verbose=1)
    print("Evaluating on test set:")
    test_results = model.evaluate(test_gen, verbose=1)

    # test_gen uses shuffle=False, so predictions come back in test_gen.classes order.
    class_names = sorted(test_gen.class_indices, key=test_gen.class_indices.get)
    label_ids = list(range(len(class_names)))
    y_true = test_gen.classes
    y_pred_probs = model.predict(test_gen, verbose=1).ravel()
    y_pred = (y_pred_probs >= threshold).astype(int)

    # Explicit labels keep the report and matrix well-formed even when one class
    # is missing from y_true or y_pred. Without them, target_names raises a length mismatch.
    print("\nClassification Report (test set):")
    print(classification_report(y_true, y_pred, labels=label_ids, target_names=class_names))

    # Confusion matrix
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)
    print("Confusion Matrix:\n", cm)
    plot_confusion_matrix(cm, classes=class_names, normalize=True,
                          title="Normalized Confusion Matrix (Test)")

    # ROC-AUC is best-effort: roc_auc_score raises if y_true contains only one class.
    try:
        roc_auc = roc_auc_score(y_true, y_pred_probs)
        print(f"Test ROC-AUC: {roc_auc:.4f}")
        plot_roc(y_true, y_pred_probs)
    except Exception as e:
        print("ROC-AUC couldn't be computed:", e)

    return {"val_results": val_results, "test_results": test_results, "cm": cm}

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=None):
    """
    Print which variant is shown, then draw the confusion matrix as an annotated heatmap.

    Args:
        cm: square confusion-matrix array (rows = true label, columns = predicted label).
        classes: tick labels, one per row/column, in label-index order.
        normalize: if True, divide each row by its sum, so each row holds fractions
            of that true class. A 1e-9 term keeps empty rows from dividing by zero.
        title: figure title.
        cmap: matplotlib colormap; plt.cm.Blues when None.
    """
    if normalize:
        matrix = cm.astype('float') / (cm.sum(axis=1)[:, np.newaxis] + 1e-9)
        print("Normalized confusion matrix")
        cell_fmt = '.2f'
    else:
        matrix = cm
        print('Confusion matrix, without normalization')
        cell_fmt = 'd'

    colormap = plt.cm.Blues if cmap is None else cmap
    fig, ax = plt.subplots(figsize=(6, 5))
    ax.imshow(matrix, interpolation='nearest', cmap=colormap)
    ax.set_title(title)

    positions = np.arange(len(classes))
    ax.set_xticks(positions)
    ax.set_xticklabels(classes, rotation=45)
    ax.set_yticks(positions)
    ax.set_yticklabels(classes)

    # Write each count in its cell, switching to white text on the darker half of the scale.
    midpoint = matrix.max() / 2.
    n_rows, n_cols = matrix.shape
    for row in range(n_rows):
        for col in range(n_cols):
            value = matrix[row, col]
            ax.text(col, row, format(value, cell_fmt),
                    horizontalalignment="center",
                    color="white" if value > midpoint else "black")

    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    fig.tight_layout()
    plt.show()

def plot_roc(y_true, y_score):
    """Plot the ROC curve for binary labels `y_true` and scores `y_score`.

    The legend shows the area under the curve. A dashed diagonal marks
    chance-level performance.
    """
    false_pos, true_pos, _thresholds = roc_curve(y_true, y_score)
    area = auc(false_pos, true_pos)

    _, ax = plt.subplots()
    ax.plot(false_pos, true_pos, lw=2, label=f'ROC curve (area = {area:.2f})')
    ax.plot([0, 1], [0, 1], linestyle='--', lw=2)  # chance diagonal
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC')
    ax.legend(loc="lower right")
    plt.show()


In [24]:
# -------------------------
# Grad-CAM
# -------------------------
def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    """Return a Grad-CAM heatmap for one output unit of `model`.

    Args:
        img_array: batch of images, preprocessed exactly as during training. Only the
            first image's heatmap is returned.
        model: Keras model that contains a layer named `last_conv_layer_name`.
        last_conv_layer_name: layer whose activations are weighted by the gradients.
        pred_index: output unit to explain. None means 0, the single sigmoid unit of
            this notebook's binary model.

    Returns:
        numpy array scaled to [0, 1] at the conv layer's spatial resolution
        (presumably (H, W) after the squeeze; confirm for non-square layers).
    """
    # A model that yields (conv activations, predictions) in one forward pass.
    # Pass model.inputs directly; wrapping it in another list nests the input structure.
    grad_model = tf.keras.models.Model(
        model.inputs, [model.get_layer(last_conv_layer_name).output, model.output]
    )
    if pred_index is None:
        pred_index = 0

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        class_channel = predictions[:, pred_index]
    grads = tape.gradient(class_channel, conv_outputs)

    # Average the gradients over batch and spatial axes: one importance weight per channel.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    heatmap = conv_outputs[0] @ pooled_grads[..., tf.newaxis]
    # Apply ReLU before normalizing. Dividing by the max of the non-negative map plus
    # an epsilon yields an all-zero map, instead of NaN, when nothing contributes positively.
    heatmap = tf.maximum(tf.squeeze(heatmap), 0)
    heatmap = heatmap / (tf.math.reduce_max(heatmap) + 1e-9)
    return heatmap.numpy()

def grad_cam_and_overlay(img_path, model_path=MODEL_PATH, last_conv_layer_name="top_conv"):
    """
    Predict on one image and display it next to its Grad-CAM overlay.

    The image is converted to RGB, resized to IMG_SIZE x IMG_SIZE and scaled by 1/255,
    which matches get_generators()' default rescaling. The sigmoid output is printed.

    Args:
        img_path: path to an image file that PIL can open.
        model_path: saved Keras model; it is loaded on every call.
        last_conv_layer_name: conv layer used for Grad-CAM. The default "top_conv"
            is expected to exist in Keras' EfficientNetB0; check model.summary()
            if the backbone changes.
    """
    model = load_model(model_path, compile=False)

    # Preprocess the image. Keep the uint8 RGB pixels for the overlay.
    orig = Image.open(img_path).convert("RGB").resize((IMG_SIZE, IMG_SIZE))
    orig_rgb = np.array(orig)
    input_arr = np.expand_dims(orig_rgb / 255.0, axis=0)

    prob = model.predict(input_arr)[0][0]
    # NOTE(review): label 1 follows flow_from_directory's class order; confirm that it is COVID.
    print(f"Predicted probability (COVID=1): {prob:.4f}")

    heatmap = make_gradcam_heatmap(input_arr, model, last_conv_layer_name)
    heatmap = np.uint8(255 * cv2.resize(heatmap, (IMG_SIZE, IMG_SIZE)))
    # cv2.applyColorMap returns BGR. Convert it to RGB so it blends with the RGB PIL
    # pixels and displays correctly in matplotlib; otherwise red and blue are swapped.
    heatmap_color = cv2.cvtColor(cv2.applyColorMap(heatmap, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
    overlay = cv2.addWeighted(orig_rgb, 0.6, heatmap_color, 0.4, 0)

    plt.figure(figsize=(8,4))
    plt.subplot(1,2,1); plt.title("Original"); plt.axis('off'); plt.imshow(orig)
    plt.subplot(1,2,2); plt.title("Grad-CAM overlay"); plt.axis('off'); plt.imshow(overlay)
    plt.show()


In [26]:
# -------------------------
# Simple Flask app for inference
# -------------------------
def create_app(model_path=MODEL_PATH):
    app = Flask(__name__)
    model = None

    @app.before_first_request
    def load():
        nonlocal model
        if not os.path.exists(model_path):
            raise RuntimeError(f"Model file not found at {model_path}. Train the model first.")
        model = load_model(model_path, compile=False)
        print("Model loaded for inference.")

    @app.route("/predict", methods=["POST"])
    def predict_route():
        """
        Accepts multipart/form-data with 'file' = image
        Returns JSON: {'pred_prob': float, 'pred_class': 'COVID' or 'Normal'}
        """
        if 'file' not in request.files:
            return jsonify({"error": "no file part"}), 400
        file = request.files['file']
        # ... rest of the prediction logic