# Pneumonia Detection – ResNet50 Transfer Learning

This notebook builds and trains a pneumonia detector using ResNet50 as a frozen feature extractor with a custom classification head. It augments the training data, trains for up to 10 epochs (early stopping on validation loss may end training sooner), evaluates on the test set, saves figures and metrics to `results/`, and exports the trained model to `models/resnet50_final.h5`.


In [None]:
import os
from pathlib import Path
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# Plot styling shared by every figure in the notebook.
sns.set_theme(style="whitegrid")
plt.rcParams.update({"figure.dpi": 120})

# Seed NumPy and TensorFlow with the same value so runs are repeatable.
SEED = 42
for seed_fn in (np.random.seed, tf.random.set_seed):
    seed_fn(SEED)
print(tf.__version__)


In [None]:
# Training hyperparameters (tune here; later cells only read these constants).
IMG_SIZE = (224, 224)
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 1e-4

# Project layout, resolved relative to the directory the kernel was started in.
PROJECT_ROOT = Path.cwd()
DATA_DIR = PROJECT_ROOT.joinpath("data", "chest_xray")
RESULTS_DIR = PROJECT_ROOT.joinpath("results")
MODELS_DIR = PROJECT_ROOT.joinpath("models")

# Output folders are created up front; the dataset folder must already exist.
for out_dir in (RESULTS_DIR, MODELS_DIR):
    out_dir.mkdir(parents=True, exist_ok=True)

print(f"Project root: {PROJECT_ROOT}")
print(f"Data dir: {DATA_DIR}")
print(f"Results dir: {RESULTS_DIR}")
print(f"Models dir: {MODELS_DIR}")

# Fail fast with a clear message rather than deep inside the data generators.
dataset_present = DATA_DIR.exists()
if not dataset_present:
    raise FileNotFoundError("Dataset not found at data/chest_xray. Place it before running.")


In [None]:
# Data generators with augmentation for training.
#
# The backbone built later is a frozen ResNet50 with ImageNet weights. Those
# weights expect inputs produced by `resnet50.preprocess_input` (RGB -> BGR and
# per-channel ImageNet mean subtraction, no scaling). Plain 1/255 rescaling does
# not match that input distribution, so the frozen features would be degraded.
# ImageDataGenerator applies `preprocessing_function` to each image after
# resizing and augmentation.
resnet_preprocess = tf.keras.applications.resnet50.preprocess_input

train_datagen = ImageDataGenerator(
    preprocessing_function=resnet_preprocess,
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
)
# Validation/test images get the same preprocessing but no augmentation.
val_test_datagen = ImageDataGenerator(preprocessing_function=resnet_preprocess)

train_dir = DATA_DIR / "train"
val_dir = DATA_DIR / "val"
test_dir = DATA_DIR / "test"


def _flow_from(datagen, directory, shuffle):
    """Build a binary-label directory iterator using the notebook-wide settings.

    Args:
        datagen: The ImageDataGenerator that loads and transforms the images.
        directory: Folder containing one sub-folder per class.
        shuffle: Whether the iterator shuffles the sample order.

    Returns:
        The iterator returned by `datagen.flow_from_directory`.
    """
    return datagen.flow_from_directory(
        directory,
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode="binary",
        shuffle=shuffle,
        seed=SEED,
    )


train_generator = _flow_from(train_datagen, train_dir, shuffle=True)
# shuffle=False keeps file order fixed so predictions line up with `.classes`.
val_generator = _flow_from(val_test_datagen, val_dir, shuffle=False)
test_generator = _flow_from(val_test_datagen, test_dir, shuffle=False)

print("Class indices:", train_generator.class_indices)


In [None]:
def build_model(input_shape=(224, 224, 3)):
    """Assemble and compile the ResNet50-based binary classifier.

    The backbone is ResNet50 with ImageNet weights and no top layers, and it is
    frozen. On top of it sits a head of global average pooling, a 256-unit ReLU
    dense layer, dropout of 0.5, and a single sigmoid output unit.

    Args:
        input_shape: Shape of one input image as (height, width, channels).

    Returns:
        A Keras Model compiled with Adam (learning rate `LEARNING_RATE`),
        binary cross-entropy loss, and accuracy/precision/recall metrics.
    """
    backbone = ResNet50(
        include_top=False,
        weights="imagenet",
        input_tensor=Input(shape=input_shape),
    )
    # Only the classification head below is trainable.
    backbone.trainable = False

    head_layers = (
        GlobalAveragePooling2D(),
        Dense(256, activation="relu"),
        Dropout(0.5),
        Dense(1, activation="sigmoid"),
    )
    tensor = backbone.output
    for layer in head_layers:
        tensor = layer(tensor)

    classifier = Model(inputs=backbone.input, outputs=tensor)

    tracked_metrics = [
        "accuracy",
        tf.keras.metrics.Precision(name="precision"),
        tf.keras.metrics.Recall(name="recall"),
    ]
    classifier.compile(
        optimizer=Adam(learning_rate=LEARNING_RATE),
        loss="binary_crossentropy",
        metrics=tracked_metrics,
    )
    return classifier

model = build_model(input_shape=(*IMG_SIZE, 3))
model.summary()


In [None]:
# Train the classification head, keeping the best checkpoint (lowest val_loss)
# on disk and stopping early once val_loss has not improved for 3 epochs.
checkpoint_path = MODELS_DIR / "resnet50_best.h5"
callbacks = [
    ModelCheckpoint(
        # Passed as a plain string, which every Keras version accepts.
        filepath=str(checkpoint_path),
        monitor="val_loss",
        mode="min",
        save_best_only=True,
        verbose=1,
    ),
    EarlyStopping(
        monitor="val_loss",
        patience=3,
        restore_best_weights=True,
        verbose=1,
    ),
]

# len(generator) is ceil(samples / batch_size). Floor division would drop the
# final partial batch and, worse, evaluate to 0 whenever a split holds fewer
# images than BATCH_SIZE. With zero validation steps no val_loss is produced
# and both callbacks above have nothing to monitor.
steps_per_epoch = len(train_generator)
validation_steps = len(val_generator)
if steps_per_epoch == 0 or validation_steps == 0:
    raise ValueError(
        "Training and validation directories must each contain at least one image "
        f"(train={train_generator.samples}, val={val_generator.samples})."
    )

history = model.fit(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    epochs=EPOCHS,
    validation_data=val_generator,
    validation_steps=validation_steps,
    callbacks=callbacks,
    verbose=1,
)


In [None]:
# Training curves: one panel per metric, train vs. validation, saved as PNG,
# with the raw per-epoch history also written to CSV.
hist_df = pd.DataFrame(history.history)
fig, axes = plt.subplots(1, 2, figsize=(10, 4))

# (train column, validation column, panel title) for each subplot, left to right.
panel_specs = (
    ("loss", "val_loss", "Loss"),
    ("accuracy", "val_accuracy", "Accuracy"),
)
for panel, (train_col, val_col, panel_title) in zip(axes, panel_specs):
    panel.plot(hist_df[train_col], label="train")
    panel.plot(hist_df[val_col], label="val")
    panel.set_title(panel_title)
    panel.legend()

fig.tight_layout()
train_plot_path = RESULTS_DIR / "resnet50_training_curves.png"
fig.savefig(train_plot_path, dpi=300)
print(f"Saved training curves to {train_plot_path}")
plt.show()

hist_df.to_csv(RESULTS_DIR / "resnet50_history.csv", index=False)


In [None]:
# Evaluate on test set
print("Evaluating on test set...")

# Start from the first batch so predictions line up with `test_generator.classes`
# even if this cell is re-run or the generator was partially consumed.
test_generator.reset()

# len(generator) is ceil(samples / batch_size), i.e. exactly one full pass.
# `samples // BATCH_SIZE + 1` would request one batch too many whenever the
# test set size is an exact multiple of BATCH_SIZE.
test_steps = len(test_generator)
pred_probs = model.predict(test_generator, steps=test_steps, verbose=1).ravel()
pred_labels = (pred_probs >= 0.5).astype(int)
true_labels = test_generator.classes

if len(pred_labels) != len(true_labels):
    raise ValueError(
        f"Got {len(pred_labels)} predictions for {len(true_labels)} test images; "
        "predictions and labels are misaligned."
    )

acc = accuracy_score(true_labels, pred_labels)
# average="binary" scores the class with index 1 (see the printed class indices).
precision, recall, f1, _ = precision_recall_fscore_support(
    true_labels, pred_labels, average="binary", zero_division=0
)
# Fixing the label order keeps the matrix 2x2 with rows/columns in index order.
cm = confusion_matrix(true_labels, pred_labels, labels=[0, 1])

# Plain Python floats so the dict serialises identically to CSV and JSON.
metrics = {
    "accuracy": float(acc),
    "precision": float(precision),
    "recall": float(recall),
    "f1_score": float(f1),
}
print("Metrics:", metrics)

metrics_df = pd.DataFrame([metrics])
metrics_csv = RESULTS_DIR / "resnet50_metrics.csv"
metrics_df.to_csv(metrics_csv, index=False)
print(f"Saved metrics to {metrics_csv}")

# Confusion matrix plot, with ticks named after the class folders (index order).
class_names = sorted(test_generator.class_indices, key=test_generator.class_indices.get)
fig, ax = plt.subplots(figsize=(4, 4))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    cbar=False,
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
)
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
ax.set_title("ResNet50 Confusion Matrix")
cm_path = RESULTS_DIR / "resnet50_confusion_matrix.png"
plt.tight_layout()
plt.savefig(cm_path, dpi=300)
print(f"Saved confusion matrix to {cm_path}")
plt.show()

# Save final model
final_model_path = MODELS_DIR / "resnet50_final.h5"
model.save(final_model_path)
print(f"Saved model to {final_model_path}")

# Persist metrics JSON for comparison notebook
with open(RESULTS_DIR / "resnet50_metrics.json", "w") as f:
    json.dump(metrics, f, indent=2)
print("Stored metrics JSON for comparison.")
