# Pneumonia Detection with ResNet and CheXpert

In this notebook, we use TensorFlow and Keras to build a pneumonia classifier by fine-tuning an ImageNet-pretrained ResNet50 on the Kaggle Chest X-ray dataset, which is used here as a stand-in for a subset of the CheXpert dataset.

In [None]:
!pip install -q gdown
!gdown 100vev6txE-xbjWxP8k7nFqouTU8gAyyp
!unzip -qq data.zip

In [None]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

## Load and Preprocess the Dataset

### Add augmentation to the training images for better results

In [None]:
# Define paths
train_dir = 'data/train'
val_dir = 'data/val'

# Settings shared by both generators, named once so they cannot drift apart.
IMG_SIZE = (224, 224)
BATCH_SIZE = 16
SEED = 42  # makes the training shuffle order repeatable across runs

# Training generator: rescale to [0, 1] and apply mild geometric augmentation.
# NOTE(review): horizontal_flip mirrors left/right anatomy - confirm this is acceptable here.
# NOTE(review): the ImageNet ResNet50 weights used later were trained with
# `preprocess_input`, not 1/255 scaling - worth confirming this choice.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.01,
    zoom_range=0.1,
    horizontal_flip=True,
    fill_mode='nearest'
)

# Validation generator: rescaling only, no augmentation.
val_datagen = ImageDataGenerator(rescale=1./255)

# Load data
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    seed=SEED
)

# shuffle=False keeps predictions aligned with `val_generator.classes` during evaluation.
val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False
)

# Both splits must map class folders to the same indices, otherwise the
# validation labels would silently refer to different classes.
if train_generator.class_indices != val_generator.class_indices:
    raise ValueError(
        "Train and validation class folders differ: "
        f"{train_generator.class_indices} vs {val_generator.class_indices}"
    )

class_names = list(train_generator.class_indices.keys())
print(f"Classes: {class_names}")

### Visualization of the augmented images

In [None]:
# Pull one augmented batch from the training generator.
images, labels = next(train_generator)

# Plot up to 8 images; a batch can hold fewer than 8 when the dataset is small.
n_preview = min(8, len(images))
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
axes = axes.ravel()

for ax, image, label in zip(axes, images[:n_preview], labels[:n_preview]):
    # Images are 3-channel arrays already scaled to [0, 1], so no colormap is needed.
    ax.imshow(image)
    # Labels are one-hot encoded; argmax recovers the class index.
    ax.set_title(class_names[np.argmax(label)])

# Hide the axes frames everywhere, including any unused panels.
for ax in axes:
    ax.axis('off')

fig.tight_layout()
plt.show()

In [None]:
from sklearn.utils import class_weight

# Integer class index of every training sample, and the distinct indices present.
train_labels = train_generator.classes
present_classes = np.unique(train_labels)

# 'balanced' weights are inversely proportional to class frequency.
class_weights = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=present_classes,
    y=train_labels
)

# Key each weight by its actual class index rather than by position, so the
# mapping stays correct even if some class index has no training samples.
class_weight_dict = {
    int(cls): float(weight) for cls, weight in zip(present_classes, class_weights)
}
print("Class Weights:", class_weight_dict)

## Build and Compile the Model

In [None]:
# Load base model: ImageNet-pretrained ResNet50 without its classification head.
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
base_model.trainable = True

# Fine-tune only the last 40 layers; everything earlier stays frozen.
# BatchNormalization layers are kept frozen even inside the fine-tuned range,
# so their pretrained statistics are not disturbed by small batches.
n_frozen = len(base_model.layers) - 40
for position, layer in enumerate(base_model.layers):
    in_frozen_range = position < n_frozen
    if in_frozen_range or isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.trainable = False

# Add custom head: global pooling followed by a softmax over the dataset's classes.
x = base_model.output
x = GlobalAveragePooling2D()(x)
# Output width follows the classes found on disk instead of a hard-coded 2.
predictions = Dense(len(class_names), activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)

# Compile with a small learning rate, appropriate for fine-tuning pretrained weights.
model.compile(optimizer=Adam(learning_rate=1e-5), loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()

## Train the Model

In [None]:
# Upper bound on training length; early stopping normally ends training sooner.
EPOCHS = 20
# Epochs without val_loss improvement that are tolerated before stopping.
# This must be smaller than EPOCHS, otherwise the callback can never trigger
# (the previous epochs=3 / patience=3 combination made it a no-op).
PATIENCE = 3

early_stop = EarlyStopping(monitor='val_loss', patience=PATIENCE, restore_best_weights=True)

# class_weight compensates for class imbalance in the training set.
history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=EPOCHS,
    class_weight=class_weight_dict,
    callbacks=[early_stop]
)

## Evaluate the Model

In [None]:
# Predict on validation set (reset so iteration starts at the first sample and
# predictions line up with `val_generator.classes`).
val_generator.reset()
preds = model.predict(val_generator, verbose=0)
y_pred = np.argmax(preds, axis=1)
y_true = val_generator.classes

# Normalize by row: every row sums to 1, so the diagonal is the recall of each true class.
# Passing `labels` fixes the matrix size to the number of classes, so it always
# matches `display_labels` even if a class is missing from y_true and y_pred.
cm = confusion_matrix(
    y_true,
    y_pred,
    labels=np.arange(len(class_names)),
    normalize='true'
)

disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
disp.plot(cmap='Blues', values_format=".2f")  # Shows numbers with 2 decimal places
plt.title("Normalized Confusion Matrix (Row-wise)")
plt.show()

## Show Sample Predictions

In [None]:
import random

def show_sample_predictions(generator, model, class_names, num_images=10, rows=2, cols=5):
    """Plot randomly chosen images from `generator` with predicted and true labels.

    Args:
        generator: Directory iterator exposing `filenames`, `directory` and `classes`.
        model: Model whose `predict` returns per-class scores for a batch of images.
        class_names: Class names, indexed by class index.
        num_images: Number of images requested. It is capped by the grid size
            (`rows * cols`) and by the number of files in the generator.
        rows: Number of subplot rows.
        cols: Number of subplot columns.

    Images are drawn without replacement, so no image appears twice.
    """
    # squeeze=False keeps `axes` two-dimensional even for a 1x1 grid, so ravel() always works.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 3, rows * 3), squeeze=False)
    axes = axes.ravel()

    n_files = len(generator.filenames)
    n_show = max(0, min(num_images, len(axes), n_files))
    chosen = random.sample(range(n_files), n_show)

    # Load the selected images at the model's input resolution.
    loaded = []
    for index in chosen:
        img_path = os.path.join(generator.directory, generator.filenames[index])
        loaded.append(tf.keras.preprocessing.image.load_img(img_path, target_size=(224, 224)))

    # Predict once on the whole batch instead of once per image.
    if loaded:
        batch = np.stack(
            [tf.keras.preprocessing.image.img_to_array(img) / 255.0 for img in loaded]
        )
        predicted = np.argmax(model.predict(batch, verbose=0), axis=1)
    else:
        predicted = []

    for ax, img, index, pred_index in zip(axes, loaded, chosen, predicted):
        pred_class = class_names[pred_index]
        true_class = class_names[generator.classes[index]]

        ax.imshow(img)
        ax.set_title(f"Pred: {pred_class}\nTrue: {true_class}", fontsize=10)
        ax.axis("off")

    # Hide any unused axes if fewer images than grid cells are shown
    for ax in axes[n_show:]:
        ax.axis("off")

    plt.tight_layout()
    plt.show()

# Display a 2x5 grid of randomly picked validation images with predicted vs. true class.
show_sample_predictions(
    generator=val_generator,
    model=model,
    class_names=class_names,
    num_images=10,
    rows=2,
    cols=5,
)
