# NIH Chest X-ray — Full Dataset Training (DenseNet vs ResNet)

This notebook is adapted to the **official NIH Chest X-ray dataset structure**:
- `Data_Entry_2017.csv` — labels
- `train_val_list.txt` and `test_list.txt` — file splits
- `images_001/ ... images_012/` — subfolders of `IMAGES_DIR` (by default `BASE_DIR/images`) containing the PNG images

### Workflow
1. Load metadata and split files
2. Build full filepaths from subfolders
3. Encode multilabel classes
4. Create efficient `tf.data` pipelines with augmentation
5. Train **DenseNet121** and **ResNet50** with transfer learning (warmup + fine-tune)
6. Use callbacks: EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
7. Evaluate models with ROC AUC (macro, micro, per-class)
8. Plot learning curves and per-class ROC curves

In [None]:
import os
from pathlib import Path
import numpy as np
import pandas as pd
import tensorflow as tf

from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

import matplotlib.pyplot as plt

gpus = tf.config.list_physical_devices('GPU')
print("TensorFlow:", tf.__version__)
print("GPUs:", gpus)

# Enable mixed precision only when a GPU is present: mixed_float16 speeds up
# GPUs with float16 support but makes CPU-only training markedly slower.
if gpus:
    try:
        from tensorflow.keras import mixed_precision
        mixed_precision.set_global_policy('mixed_float16')
        print("Mixed precision policy:", mixed_precision.global_policy())
    except Exception as e:  # best-effort: keep the default float32 policy on any failure
        print("Mixed precision not set:", e)
else:
    print("No GPU detected; keeping the default float32 policy.")

In [None]:
# ---- Dataset locations ------------------------------------------------------
BASE_DIR = Path("/path/to/NIH-ChestXray")   # <-- CHANGE THIS (dataset root)
CSV_FILE = BASE_DIR / "Data_Entry_2017.csv"        # per-image label table
TRAINVAL_TXT = BASE_DIR / "train_val_list.txt"     # official train/val file list
TEST_TXT = BASE_DIR / "test_list.txt"              # official test file list
IMAGES_DIR = BASE_DIR / "images"                   # parent of the image subfolders

# ---- Input size, batching and reproducibility --------------------------------
IMG_SIZE, BATCH_SIZE, SEED = 224, 32, 42

# ---- Two-phase schedule: warmup epochs/LR, then fine-tune epochs/LR ----------
WARMUP_EPOCHS, INIT_LR = 3, 1e-3
FINETUNE_EPOCHS, FT_LR = 10, 1e-5

In [None]:
# Official split files: each line names one image, matched later against
# the CSV's 'Image Index' column.
trainval_list, test_list = (
    set(split_file.read_text().splitlines())
    for split_file in (TRAINVAL_TXT, TEST_TXT)
)
print("Train/Val list:", len(trainval_list))
print("Test list:", len(test_list))

In [None]:
# Load metadata; 'Finding Labels' is split on '|' into a list of findings.
df = pd.read_csv(CSV_FILE)
df['Finding Labels'] = df['Finding Labels'].str.split('|')

# Index every PNG under IMAGES_DIR once (filename -> full path). Probing each
# subfolder for each CSV row costs one filesystem check per (image, subfolder)
# pair; a single directory walk is far cheaper. rglob also finds images nested
# more than one level below IMAGES_DIR.
IMAGE_INDEX = {p.name: str(p) for p in sorted(IMAGES_DIR.rglob("*.png"))}


def find_path(img_name):
    """Return the full path of ``img_name`` under IMAGES_DIR, or None if it was not found."""
    return IMAGE_INDEX.get(img_name)


df['filepath'] = df['Image Index'].map(IMAGE_INDEX)

# Rows without an image on disk would reach tf.data as missing paths and fail
# there. Stop early if nothing was found (most likely a wrong BASE_DIR /
# IMAGES_DIR); otherwise drop the missing rows and report how many.
missing = df['filepath'].isna()
if missing.all():
    raise FileNotFoundError(
        f"None of the images listed in {CSV_FILE} were found under {IMAGES_DIR}; "
        "check BASE_DIR / IMAGES_DIR."
    )
if missing.any():
    print(f"Warning: dropping {int(missing.sum())} rows whose image was not found under {IMAGES_DIR}.")
    df = df[~missing].reset_index(drop=True)

# Split according to txt lists
trainval_df = df[df['Image Index'].isin(trainval_list)].reset_index(drop=True)
test_df     = df[df['Image Index'].isin(test_list)].reset_index(drop=True)

print("Train/Val:", len(trainval_df), " Test:", len(test_df))

In [None]:
# Encode multilabels. The binarizer is fitted on the full metadata so the
# train/val/test label matrices share one column order (CLASSES).
mlb = MultiLabelBinarizer()
mlb.fit(df['Finding Labels'])
CLASSES = list(mlb.classes_)
print("Classes:", CLASSES)

# Apply encoding
Y_trainval = mlb.transform(trainval_df['Finding Labels'])
Y_test     = mlb.transform(test_df['Finding Labels'])

# Split train/val internally (90/10) at the *patient* level.
# - Stratifying on the exact label combination does not work: many
#   combinations occur only once, and train_test_split refuses to stratify
#   when a class has fewer than 2 members.
# - If a patient has several images, an image-level split can put that patient
#   in both train and val. This inflates the val AUC that EarlyStopping,
#   ReduceLROnPlateau and ModelCheckpoint all monitor.
if 'Patient ID' in trainval_df.columns:
    patients = trainval_df['Patient ID'].unique()
    _, val_patients = train_test_split(patients, test_size=0.1, random_state=SEED)
    val_mask = trainval_df['Patient ID'].isin(val_patients).to_numpy()
else:
    # No patient column available: fall back to a plain image-level split.
    positions = np.arange(len(trainval_df))
    _, val_positions = train_test_split(positions, test_size=0.1, random_state=SEED)
    val_mask = np.zeros(len(trainval_df), dtype=bool)
    val_mask[val_positions] = True

train_df = trainval_df[~val_mask].reset_index(drop=True)
val_df   = trainval_df[val_mask].reset_index(drop=True)
Y_train, Y_val = Y_trainval[~val_mask], Y_trainval[val_mask]

print("Train:", len(train_df), "Val:", len(val_df), "Test:", len(test_df))

In [None]:
# TF Dataset pipeline
AUTOTUNE = tf.data.AUTOTUNE

def decode_resize(path, label):
    """Read one PNG and return an (IMG_SIZE, IMG_SIZE, 3) float image in [0, 1].

    The image is decoded as a single grey channel and then replicated to three
    channels to match the 3-channel input of the ImageNet backbones.
    """
    img = tf.io.read_file(path)
    img = tf.image.decode_png(img, channels=1)
    img = tf.image.resize(img, [IMG_SIZE, IMG_SIZE])
    img = img / 255.0
    img = tf.image.grayscale_to_rgb(img)
    return img, label

def augment(img, label):
    """Apply light brightness/contrast jitter and a random horizontal flip.

    The jitter can push pixels slightly outside [0, 1]. Clipping keeps training
    inputs in the same range as the un-augmented val/test inputs.
    """
    img = tf.image.random_brightness(img, 0.05)
    img = tf.image.random_contrast(img, 0.95, 1.05)
    img = tf.image.random_flip_left_right(img)
    img = tf.clip_by_value(img, 0.0, 1.0)
    return img, label

def make_ds(paths, labels, training=True):
    """Build a batched, prefetched dataset of (image, label row) pairs.

    Args:
        paths: image file paths, one per row of ``labels``.
        labels: binary label matrix (one column per class).
        training: if True, shuffle (reshuffled every epoch) and augment;
            otherwise keep row order so predictions line up with ``labels``.
    """
    # float32 labels match the float32 sigmoid output used by the loss/metric.
    labels = np.asarray(labels, dtype=np.float32)
    ds = tf.data.Dataset.from_tensor_slices((paths, labels))
    if training:
        # Shuffle paths (cheap strings) before decoding, over the full list.
        ds = ds.shuffle(buffer_size=len(paths), seed=SEED, reshuffle_each_iteration=True)
    ds = ds.map(decode_resize, num_parallel_calls=AUTOTUNE)
    if training:
        ds = ds.map(augment, num_parallel_calls=AUTOTUNE)
    ds = ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)
    return ds

train_ds = make_ds(train_df['filepath'].values, Y_train, training=True)
val_ds   = make_ds(val_df['filepath'].values,   Y_val, training=False)
test_ds  = make_ds(test_df['filepath'].values,  Y_test, training=False)

In [None]:
# Model builders
from tensorflow.keras.applications import DenseNet121, ResNet50
from tensorflow.keras import layers, Model

# Input statistics the ImageNet weights were trained with. The tf.data pipeline
# feeds images scaled to [0, 1], so each builder prepends layers that map that
# range onto what its backbone expects (cf. keras.applications.*.preprocess_input).
# DenseNet121 ("torch" mode): [0, 1] pixels, then per-channel mean / std.
_DENSENET_MEAN = [0.485, 0.456, 0.406]
_DENSENET_VARIANCE = [0.229 ** 2, 0.224 ** 2, 0.225 ** 2]
# ResNet50 ("caffe" mode): 0-255 pixels in BGR order minus a per-channel mean.
_RESNET_MEAN_BGR = [103.939, 116.779, 123.68]


def _build_classifier(backbone_fn, preprocess, input_shape, num_classes):
    """Attach a multi-label sigmoid head to an ImageNet-pretrained backbone.

    Args:
        backbone_fn: keras.applications constructor (e.g. DenseNet121).
        preprocess: layers applied in order to the [0, 1] input before the backbone.
        input_shape: (height, width, channels) of the input images.
        num_classes: number of sigmoid outputs; None means len(CLASSES).

    Returns:
        A functional Model: preprocessing -> backbone -> GAP -> Dropout -> Dense.
        The backbone is built on the preprocessing output via ``input_tensor``,
        so its layers are part of this model's own graph.
    """
    if num_classes is None:
        num_classes = len(CLASSES)
    inputs = layers.Input(shape=input_shape)
    x = inputs
    for layer in preprocess:
        x = layer(x)
    base = backbone_fn(include_top=False, weights='imagenet',
                       input_tensor=x, input_shape=input_shape)
    x = layers.GlobalAveragePooling2D()(base.output)
    x = layers.Dropout(0.3)(x)
    # float32 output keeps the sigmoid/loss stable under the mixed_float16 policy.
    out = layers.Dense(num_classes, activation='sigmoid', dtype='float32')(x)
    return Model(inputs, out)


def build_densenet(input_shape=(IMG_SIZE, IMG_SIZE, 3), num_classes=None):
    """Return DenseNet121 (ImageNet weights) with a multi-label sigmoid head.

    Expects inputs in [0, 1] and applies DenseNet's ImageNet mean/std normalization.
    """
    preprocess = [layers.Normalization(mean=_DENSENET_MEAN, variance=_DENSENET_VARIANCE)]
    return _build_classifier(DenseNet121, preprocess, input_shape, num_classes)


def build_resnet(input_shape=(IMG_SIZE, IMG_SIZE, 3), num_classes=None):
    """Return ResNet50 (ImageNet weights) with a multi-label sigmoid head.

    Expects inputs in [0, 1]; rescales them to 0-255 and subtracts ResNet50's
    per-channel ImageNet means. Assumes the three input channels are identical
    (decode_resize replicates one grey channel), so the RGB->BGR swap of
    caffe-mode preprocessing is a no-op and is omitted.
    """
    preprocess = [
        layers.Rescaling(255.0),
        layers.Normalization(mean=_RESNET_MEAN_BGR, variance=[1.0, 1.0, 1.0]),
    ]
    return _build_classifier(ResNet50, preprocess, input_shape, num_classes)

In [None]:
# Callbacks
def get_callbacks(name):
    """Build the callback list used for one training phase.

    Every callback watches 'val_auc' (the validation value of the metric that
    compile_model names 'auc') and treats larger as better:
      * EarlyStopping  - stop after 3 epochs without improvement, restoring
                         the best weights;
      * ReduceLROnPlateau - multiply the LR by 0.2 after 2 flat epochs,
                         never going below 1e-7;
      * ModelCheckpoint - keep only the best model in best_<name>.h5.

    NOTE(review): newer Keras releases may require a `.keras` suffix for
    full-model checkpoints — confirm against the installed version.
    """
    watch = dict(monitor='val_auc', mode='max')
    early_stop = tf.keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True, **watch)
    lr_schedule = tf.keras.callbacks.ReduceLROnPlateau(factor=0.2, patience=2, min_lr=1e-7, **watch)
    checkpoint = tf.keras.callbacks.ModelCheckpoint(f"best_{name}.h5", save_best_only=True, **watch)
    return [early_stop, lr_schedule, checkpoint]

In [None]:
# Training function (warmup + fine-tune)
def compile_model(model, lr):
    """Compile ``model`` in place with Adam(lr), binary cross-entropy and a
    multi-label ROC AUC metric named 'auc' (monitored as 'val_auc').

    Returns the same model object.
    """
    auc_roc = tf.keras.metrics.AUC(curve='ROC', multi_label=True, num_labels=len(CLASSES), name='auc')
    model.compile(optimizer=tf.keras.optimizers.Adam(lr),
                  loss=tf.keras.losses.BinaryCrossentropy(),
                  metrics=[auc_roc])
    return model


def _leaf_layers(layer):
    """Return [layer], or (recursively) the layers inside it if it is a nested Keras model."""
    if isinstance(layer, tf.keras.Model):
        return [leaf for sub in layer.layers for leaf in _leaf_layers(sub)]
    return [layer]


def _split_backbone_head(model):
    """Split ``model``'s layers into (backbone leaf layers, head layers).

    The head starts at the last GlobalAveragePooling2D layer (the builders put
    GAP -> Dropout -> Dense after the pretrained base). If there is no such
    layer, only the final layer is treated as the head.
    """
    pool_positions = [i for i, layer in enumerate(model.layers)
                      if isinstance(layer, tf.keras.layers.GlobalAveragePooling2D)]
    start = pool_positions[-1] if pool_positions else len(model.layers) - 1
    backbone = [leaf for layer in model.layers[:start] for leaf in _leaf_layers(layer)]
    return backbone, list(model.layers[start:])


def warmup_and_finetune(model_builder, name_prefix, unfreeze_last=50):
    """Train a freshly built model in two phases and return (model, fine-tune History).

    1. Warmup: the backbone is frozen and only the head trains, at INIT_LR,
       for WARMUP_EPOCHS.
    2. Fine-tune: the last ``unfreeze_last`` layers become trainable (BatchNorm
       stays frozen) and the model trains at FT_LR for FINETUNE_EPOCHS.

    The warmup checkpoint goes to best_<name_prefix>_warmup.h5 and the
    fine-tune checkpoint to best_<name_prefix>.h5, so the weaker warmup model
    cannot occupy the final checkpoint file.
    """
    model = model_builder()
    # Freeze layer by layer. The builders inline the pretrained base into the
    # functional graph, so no entry of model.layers is the base as a nested
    # tf.keras.Model. _split_backbone_head also handles nested bases.
    backbone, head = _split_backbone_head(model)

    for layer in backbone:
        layer.trainable = False
    compile_model(model, lr=INIT_LR)
    print(f"\n[Warmup] Training {name_prefix}...")
    model.fit(train_ds, validation_data=val_ds, epochs=WARMUP_EPOCHS,
              callbacks=get_callbacks(f"{name_prefix}_warmup"))

    all_layers = backbone + head
    for layer in all_layers[max(len(all_layers) - unfreeze_last, 0):]:
        # Keep BatchNormalization frozen (inference mode) so small fine-tuning
        # batches do not overwrite the pretrained statistics.
        if not isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.trainable = True
    compile_model(model, lr=FT_LR)  # recompile so the new trainable flags take effect
    print(f"\n[Fine-tune] Training {name_prefix}...")
    history = model.fit(train_ds, validation_data=val_ds, epochs=FINETUNE_EPOCHS,
                        callbacks=get_callbacks(name_prefix))
    return model, history

In [None]:
# Train both backbones in turn (DenseNet first, then ResNet) with the same
# two-phase schedule; keep each model and its fine-tune History.
densenet_model, dn_hist = warmup_and_finetune(model_builder=build_densenet, name_prefix="DenseNet")
resnet_model, rn_hist = warmup_and_finetune(model_builder=build_resnet, name_prefix="ResNet")

In [None]:
# Plot training curves
def plot_history(histories, metric='auc'):
    """Draw train and validation curves of ``metric`` for each named Keras History.

    Args:
        histories: mapping of display name -> History returned by model.fit.
        metric: key in History.history; its validation twin is 'val_' + metric.
    """
    heading = metric.upper()
    plt.figure(figsize=(8, 6))
    for model_name, run in histories.items():
        curves = run.history
        for key, split in ((metric, "train"), (f"val_{metric}", "val")):
            plt.plot(curves[key], label=f"{model_name} {split}")
    plt.title(heading)
    plt.xlabel('Epoch')
    plt.ylabel(heading)
    plt.legend()
    plt.show()

plot_history({'DenseNet': dn_hist, 'ResNet': rn_hist}, metric='auc')
plot_history({'DenseNet': dn_hist, 'ResNet': rn_hist}, metric='loss')

In [None]:
# Final evaluation on test set
def evaluate_model(model, dataset, y_true):
    """Return (macro AUC, micro AUC, {class: AUC}) for ``model`` on ``dataset``.

    ``y_true`` must be row-aligned with ``dataset`` (make_ds only shuffles when
    training=True). A class whose column in ``y_true`` holds a single value has
    no defined ROC AUC. It is reported as NaN and left out of the macro average,
    instead of making roc_auc_score raise.
    """
    y_true = np.asarray(y_true)
    y_pred = model.predict(dataset)

    per_class = {}
    for i, cls in enumerate(CLASSES):
        column = y_true[:, i]
        if np.unique(column).size < 2:
            per_class[cls] = float('nan')
        else:
            per_class[cls] = roc_auc_score(column, y_pred[:, i])

    defined = [v for v in per_class.values() if not np.isnan(v)]
    macro = float(np.mean(defined)) if defined else float('nan')
    if np.unique(y_true).size < 2:
        micro = float('nan')
    else:
        micro = roc_auc_score(y_true, y_pred, average='micro')
    return macro, micro, per_class

print("DenseNet Test Results:", evaluate_model(densenet_model, test_ds, Y_test))
print("ResNet Test Results:", evaluate_model(resnet_model, test_ds, Y_test))

In [None]:
from sklearn.metrics import roc_curve, auc
import itertools

def plot_per_class_roc(model1, model2, y_true, ds, name1="DenseNet", name2="ResNet", ncols=3):
    """Plot one ROC panel per class, comparing two models' predictions on ``ds``.

    Args:
        model1, model2: models whose outputs have one column per entry of CLASSES.
        y_true: binary label matrix, row-aligned with ``ds`` (ds must not be shuffled).
        ds: dataset to predict on.
        name1, name2: legend labels for the two models.
        ncols: panels per row; the row count follows from len(CLASSES).
    """
    y_pred1 = model1.predict(ds)
    y_pred2 = model2.predict(ds)
    y_true = np.asarray(y_true)

    n_classes = len(CLASSES)
    # Size the grid from the class count so any number of classes fits;
    # panels beyond the last class are hidden below.
    nrows = max(1, -(-n_classes // ncols))  # ceiling division
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                             figsize=(6 * ncols, 4.8 * nrows), squeeze=False)
    axes = axes.flatten()

    for i, cls in enumerate(CLASSES):
        fpr1, tpr1, _ = roc_curve(y_true[:, i], y_pred1[:, i])
        fpr2, tpr2, _ = roc_curve(y_true[:, i], y_pred2[:, i])
        auc1 = auc(fpr1, tpr1)
        auc2 = auc(fpr2, tpr2)

        ax = axes[i]
        ax.plot(fpr1, tpr1, label=f"{name1} (AUC={auc1:.2f})")
        ax.plot(fpr2, tpr2, label=f"{name2} (AUC={auc2:.2f})")
        ax.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
        ax.set_title(cls)
        ax.set_xlabel('FPR')
        ax.set_ylabel('TPR')
        ax.legend()

    for ax in axes[n_classes:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()

# Plot ROC per class on test set
plot_per_class_roc(densenet_model, resnet_model, Y_test, test_ds)