# Chest X‑Ray Classification (Pneumonia vs Normal) — Flexible Backbone

This notebook classifies **Pneumonia vs Normal chest X‑rays** using transfer learning.  
It supports **multiple pretrained backbones** so you can balance accuracy vs speed:

- `MobileNetV2` → lightweight, fastest on CPU.  
- `ResNet50` → stronger, slower on CPU.  
- `EfficientNetB3` → good balance of accuracy and efficiency.

Change the `BACKBONE` variable to switch between them.

> ⚠️ Educational use only — not for clinical diagnosis.

## 1. Imports

In [1]:
import os, pathlib
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score, roc_curve

print("TensorFlow:", tf.__version__)
print("GPU available:", tf.config.list_physical_devices('GPU'))


TensorFlow: 2.18.0
GPU available: []


## 2. Configuration

In [2]:
# DATASET_ROOT = "/content/drive/MyDrive/datasets/chest_xrays_v4"
DATASET_ROOT = r"C:\\CAS AML\\project M1 and M2\\Chest X Rays v4"  # change if needed

BACKBONE = "EfficientNetB3"  # or "MobileNetV2", "ResNet50"

if BACKBONE == "MobileNetV2":
    IMG_SIZE = (224, 224)
elif BACKBONE == "ResNet50":
    IMG_SIZE = (224, 224)
elif BACKBONE == "EfficientNetB3":
    IMG_SIZE = (300, 300)
else:
    raise ValueError("Unsupported backbone")

BATCH_SIZE = 32
SEED = 123
NUM_CLASSES = 2

MODEL_PATH = os.path.join(DATASET_ROOT, f"pneumonia_xray_{BACKBONE}.keras")

print("Backbone:", BACKBONE, "| Input size:", IMG_SIZE)
print("Model path:", MODEL_PATH)


Backbone: EfficientNetB3 | Input size: (300, 300)
Model path: C:\\CAS AML\\project M1 and M2\\Chest X Rays v4\pneumonia_xray_EfficientNetB3.keras


## 3. Data Loading

In [3]:
train_dir = os.path.join(DATASET_ROOT, "train")
val_dir   = os.path.join(DATASET_ROOT, "valid")
test_dir  = os.path.join(DATASET_ROOT, "test")

train_ds = tf.keras.utils.image_dataset_from_directory(train_dir, image_size=IMG_SIZE, batch_size=BATCH_SIZE, seed=SEED)
val_ds   = tf.keras.utils.image_dataset_from_directory(val_dir,   image_size=IMG_SIZE, batch_size=BATCH_SIZE, seed=SEED)
test_ds  = tf.keras.utils.image_dataset_from_directory(test_dir,  image_size=IMG_SIZE, batch_size=BATCH_SIZE, seed=SEED)

class_names = train_ds.class_names
print("Classes:", class_names)


Found 12229 files belonging to 2 classes.
Found 1165 files belonging to 2 classes.
Found 582 files belonging to 2 classes.
Classes: ['NORMAL', 'PNEUMONIA']


## 4. Preprocessing & Augmentation

In [4]:
normalization_layer = layers.Rescaling(1./255)
data_augmentation = keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.15),
    layers.RandomContrast(0.1),
])

def prep(ds, training=False):
    ds = ds.map(lambda x,y:(normalization_layer(x), y), num_parallel_calls=tf.data.AUTOTUNE)
    if training:
        ds = ds.map(lambda x,y:(data_augmentation(x, training=True), y), num_parallel_calls=tf.data.AUTOTUNE)
    return ds.prefetch(tf.data.AUTOTUNE)

train_ds_prep = prep(train_ds, training=True)
val_ds_prep   = prep(val_ds, training=False)
test_ds_prep  = prep(test_ds, training=False)


## 5. Model Architecture

In [5]:
def build_model(backbone):
    if backbone == "MobileNetV2":
        base = keras.applications.MobileNetV2(input_shape=IMG_SIZE+(3,), include_top=False, weights="imagenet")
        preprocess_input = keras.applications.mobilenet_v2.preprocess_input
    elif backbone == "ResNet50":
        base = keras.applications.ResNet50(input_shape=IMG_SIZE+(3,), include_top=False, weights="imagenet")
        preprocess_input = keras.applications.resnet50.preprocess_input
    elif backbone == "EfficientNetB3":
        base = keras.applications.EfficientNetB3(input_shape=IMG_SIZE+(3,), include_top=False, weights="imagenet")
        preprocess_input = keras.applications.efficientnet.preprocess_input
    else:
        raise ValueError("Unsupported backbone")
    base.trainable = False
    inputs = keras.Input(shape=IMG_SIZE+(3,))
    x = preprocess_input(inputs)
    x = base(x, training=False)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(NUM_CLASSES, activation="softmax")(x)
    return keras.Model(inputs, outputs)

model = build_model(BACKBONE)
model.summary()


Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb3_notop.h5
[1m43941136/43941136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m229s[0m 5us/step


## 6. Compile & Train/Load

In [6]:
model.compile(optimizer=keras.optimizers.Adam(1e-3), loss="sparse_categorical_crossentropy", metrics=["accuracy"])

if pathlib.Path(MODEL_PATH).exists():
    print("Loading pretrained:", MODEL_PATH)
    model = tf.keras.models.load_model(MODEL_PATH)
else:
    callbacks = [
        keras.callbacks.ModelCheckpoint(MODEL_PATH, save_best_only=True, monitor="val_accuracy"),
        keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True)
    ]
    history = model.fit(train_ds_prep, validation_data=val_ds_prep, epochs=15, callbacks=callbacks)
    model.save(MODEL_PATH)


Epoch 1/15
[1m383/383[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4505s[0m 11s/step - accuracy: 0.7207 - loss: 0.5961 - val_accuracy: 0.7391 - val_loss: 0.5523
Epoch 2/15
[1m383/383[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3716s[0m 10s/step - accuracy: 0.7242 - loss: 0.5778 - val_accuracy: 0.7391 - val_loss: 0.5501
Epoch 3/15
[1m220/383[0m [32m━━━━━━━━━━━[0m[37m━━━━━━━━━[0m [1m25:53[0m 10s/step - accuracy: 0.7201 - loss: 0.5789

KeyboardInterrupt: 

## 7. Fine-tuning

In [None]:
for layer in model.layers[-50:]:
    if not isinstance(layer, layers.BatchNormalization):
        layer.trainable = True

model.compile(optimizer=keras.optimizers.Adam(1e-5), loss="sparse_categorical_crossentropy", metrics=["accuracy"])
history_ft = model.fit(train_ds_prep, validation_data=val_ds_prep, epochs=5)
model.save(MODEL_PATH)


## 8. Evaluation

In [None]:
test_results = model.evaluate(test_ds_prep, verbose=0)
print("Test metrics:", dict(zip(model.metrics_names, test_results)))

y_true, y_prob = [], []
for images, labels in test_ds_prep:
    probs = model.predict(images, verbose=0)
    y_prob.extend(probs); y_true.extend(labels.numpy())
y_true = np.array(y_true); y_prob = np.array(y_prob)
y_pred = y_prob.argmax(axis=1)

print("Confusion Matrix:\n", confusion_matrix(y_true,y_pred))
print(classification_report(y_true,y_pred, target_names=class_names, digits=4))

pos_label = 1 if "pneumonia" in class_names[1].lower() else 0
auc = roc_auc_score((y_true==pos_label).astype(int), y_prob[:,pos_label])
print("ROC-AUC:", auc)


## 9. Sample Predictions

In [None]:
for images, labels in test_ds.take(1):
    probs = model.predict(images, verbose=0)
    preds = probs.argmax(axis=1)
    plt.figure(figsize=(12,8))
    for i in range(min(8, images.shape[0])):
        ax = plt.subplot(2,4,i+1)
        plt.imshow(images[i].numpy().astype("uint8"))
        p, t = class_names[preds[i]], class_names[labels[i].numpy()]
        plt.title(f"Pred: {p}\nTrue: {t}"); plt.axis("off")
    plt.show()
