# In [None]:
# AIML Chest X‑ray Classifier: Normal vs Pneumonia vs Tuberculosis
# Jupyter Notebook (drop-in) — Keras/TensorFlow 2.x
# -------------------------------------------------------------------
# ✅ What this gives you
# - 70/15/15 split (train/val/test)
# - Preprocessing: resize, normalization, on-the-fly augmentation
# - Strong model via transfer learning (EfficientNetV2B0 by default)
# - Regularization: Dropout, BatchNorm, Weight Decay (AdamW)
# - Metrics: accuracy, precision, recall, F1, AUC (micro/macro), confusion matrix
# - Plots: training curves, ROC curves, confusion matrix heatmap
# - Explainability: Grad-CAM heatmaps
# - Class imbalance handling: class weights (auto-computed)
# - Reproducible + clean structure; works with any folder‑structured dataset
# -------------------------------------------------------------------

# ========================= 1) SETUP =========================
import os, random, math, json, itertools, shutil
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import mixed_precision

# TensorFlow Addons is an optional extra (used for AdamW when present).
# Any failure while importing it simply flips the flag so that later code
# can pick a fallback instead of crashing here.
try:
    import tensorflow_addons as tfa
except Exception:
    HAS_TFA = False
else:
    HAS_TFA = True

print(tf.__version__, '— TF version')
print('TFA available:', HAS_TFA)

# Reproducibility: seed Python, NumPy and TensorFlow RNGs with the same value.
SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Mixed precision: float16 compute only pays off on a GPU. On CPU-only
# machines the policy is accepted without error but makes training slower,
# so it is enabled only when TensorFlow can actually see a GPU.
MP = False
if tf.config.list_physical_devices('GPU'):
    try:
        mixed_precision.set_global_policy('mixed_float16')
        MP = True
    except Exception as exc:
        # Best effort: stay on float32 but say why instead of failing silently.
        print('Mixed precision unavailable:', exc)
print('Mixed precision:', MP)

# ========================= 2) CONFIG =========================
# Point DATA_DIR to a folder that contains *images* inside subfolders named by class, e.g.:
# DATA_DIR/
#   Normal/
#     img1.jpg, img2.jpg, ...
#   Pneumonia/
#   Tuberculosis/
# If you downloaded separate datasets (e.g., one for TB, one for Pneumonia),
# place/merge images into these three folders accordingly.

# --- Filesystem locations ---
DATA_DIR = Path('data/chestxray_3class')  # <-- change to your dataset root folder
SAVE_DIR = Path('artifacts')  # checkpoints, reports and the label map are written here
SAVE_DIR.mkdir(parents=True, exist_ok=True)

# --- Input pipeline ---
IMG_SIZE = (256, 256)
BATCH_SIZE = 32
AUTOTUNE = tf.data.AUTOTUNE

# --- Split fractions (train receives whatever val and test leave over) ---
VAL_SPLIT = 0.15
TEST_SPLIT = 0.15
TRAIN_SPLIT = 1.0 - VAL_SPLIT - TEST_SPLIT  # 0.70

# --- Model / optimisation ---
MODEL_NAME = 'EfficientNetV2B0'  # or 'ResNet50', 'EfficientNetB0', etc.
EPOCHS = 20
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 1e-4  # AdamW weight decay

# ========================= 3) OPTIONAL: KAGGLE DOWNLOAD =========================
# If you want to download via Kaggle API, place kaggle.json in ~/.kaggle/ and uncomment
# Example: a placeholder (replace DATASET with an actual Kaggle dataset slug)
# !pip install -q kaggle
# !kaggle datasets download -d <OWNER/DATASET> -p downloads/ -o
# !unzip -q downloads/DATASET.zip -d data/

# ========================= 4) COLLECT FILES & SPLIT 70/15/15 =========================
# Validate with explicit exceptions: `assert` statements vanish under `python -O`.
if not DATA_DIR.is_dir():
    raise FileNotFoundError(f"DATA_DIR not found: {DATA_DIR}")

# Hidden folders (e.g. Jupyter's .ipynb_checkpoints) are not classes.
classes = sorted(
    d.name for d in DATA_DIR.iterdir()
    if d.is_dir() and not d.name.startswith('.')
)
if set(classes) != {'Normal', 'Pneumonia', 'Tuberculosis'}:
    raise ValueError(
        f"Expected classes Normal, Pneumonia, Tuberculosis. Found: {classes}"
    )
print('Classes:', classes)

# Gather filepaths.
# tf.io.decode_image (used by the input pipeline) handles BMP/GIF/JPEG/PNG
# only, so TIFF files are reported and skipped here instead of crashing the
# pipeline in the middle of an epoch.
decodable_exts = {'.png', '.jpg', '.jpeg', '.bmp'}
undecodable_exts = {'.tif', '.tiff'}
all_files = []
skipped_files = []
for cls in classes:
    # Sorting makes the listing (and therefore the seeded split below)
    # independent of filesystem enumeration order.
    for p in sorted((DATA_DIR / cls).glob('**/*')):
        if not p.is_file():
            continue
        ext = p.suffix.lower()
        if ext in decodable_exts:
            all_files.append((str(p), cls))
        elif ext in undecodable_exts:
            skipped_files.append(str(p))

if skipped_files:
    print(
        f"Skipped {len(skipped_files)} TIFF file(s): convert them to PNG/JPEG "
        "to include them in the dataset."
    )

df = pd.DataFrame(all_files, columns=['path', 'label'])
print('Total images:', len(df))
if df.empty:
    raise ValueError(f"No usable images found under {DATA_DIR}")

# Stratified split
from sklearn.model_selection import train_test_split

# First carve off val+test together, then divide that remainder between them,
# stratifying on the label both times so each split keeps the class mix.
train_df, temp_df = train_test_split(
    df, test_size=(1.0 - TRAIN_SPLIT), stratify=df['label'], random_state=SEED
)
val_df, test_df = train_test_split(
    temp_df,
    test_size=TEST_SPLIT / (TEST_SPLIT + VAL_SPLIT),
    stratify=temp_df['label'],
    random_state=SEED,
)

print(f"Train: {len(train_df)}  Val: {len(val_df)}  Test: {len(test_df)}")

# ========================= 5) DATASET PIPELINES =========================
# Augmentation
# Training-time augmentation: resize first, then light random perturbations.
_augment_steps = [
    layers.Resizing(IMG_SIZE[0], IMG_SIZE[1]),
    layers.RandomFlip('horizontal'),
    layers.RandomRotation(0.05),
    layers.RandomZoom(0.1),
    layers.RandomContrast(0.1),
]
data_augment = keras.Sequential(_augment_steps)

# Normalization: multiply pixel values by 1/255.
norm_layer = layers.Rescaling(scale=1. / 255)

# Helper to build tf.data from filepaths
def make_ds(paths, labels, training=False):
    """Build a batched, prefetched ``tf.data`` pipeline of ``(image, label)``.

    Args:
        paths: Sequence of image file paths (strings).
        labels: Sequence of labels aligned with ``paths``; passed through
            unchanged.
        training: When true, the path/label pairs are shuffled every epoch
            and ``data_augment`` (resize + random augmentation) is applied.
            Otherwise images are only resized, and the element order follows
            ``paths`` so predictions can be aligned with ``labels``.

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``BATCH_SIZE`` images that
        have been passed through ``norm_layer``, together with their labels.
    """
    # Create the evaluation resizer once instead of inside the mapped function.
    resize = layers.Resizing(IMG_SIZE[0], IMG_SIZE[1])

    def _load(path, y):
        img = tf.io.read_file(path)
        img = tf.io.decode_image(img, channels=3, expand_animations=False)
        # decode_image cannot infer a static shape; later layers need the rank.
        img.set_shape([None, None, 3])
        if training:
            # Explicit training=True so the random layers are active no matter
            # how the default training flag is resolved outside of fit().
            img = data_augment(img, training=True)
        else:
            img = resize(img)
        img = norm_layer(img)
        return img, y

    ds = tf.data.Dataset.from_tensor_slices((paths, labels))
    if training:
        # Shuffle the lightweight (path, label) pairs BEFORE decoding: the
        # buffer can then span the whole dataset (a uniform shuffle) without
        # holding thousands of decoded images in memory.
        ds = ds.shuffle(max(len(paths), 1), seed=SEED, reshuffle_each_iteration=True)
    # shuffle (if training), map, batch, prefetch
    ds = ds.map(_load, num_parallel_calls=AUTOTUNE)
    return ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)

# Bidirectional class-name <-> integer-index lookup (index = position in `classes`).
label2idx = dict(zip(classes, range(len(classes))))
idx2label = dict(enumerate(classes))


def _paths_and_labels(frame):
    """Return (list of paths, int32 label-index array) for one split frame."""
    encoded = frame['label'].map(label2idx).astype(np.int32).values
    return frame['path'].tolist(), encoded


train_paths, train_labels = _paths_and_labels(train_df)
val_paths, val_labels = _paths_and_labels(val_df)
test_paths, test_labels = _paths_and_labels(test_df)

# Only the training pipeline shuffles and augments.
train_ds = make_ds(train_paths, train_labels, training=True)
val_ds = make_ds(val_paths, val_labels, training=False)
test_ds = make_ds(test_paths, test_labels, training=False)

# ========================= 6) CLASS WEIGHTS (IMBALANCE) =========================
from collections import Counter

# Per-class sample counts in index order, then "balanced" weights:
# total / (n_classes * count_for_class).
ctr = Counter(train_labels)
class_count = np.array([ctr[k] for k in range(len(classes))])
_balanced = class_count.sum() / (len(classes) * class_count)
class_weights = dict(enumerate(_balanced))
print('Class counts:', dict(ctr))
print('Class weights:', class_weights)

# ========================= 7) BUILD MODEL (TRANSFER LEARNING) =========================
# Switchable backbones
BACKBONES = {
    'EfficientNetV2B0': keras.applications.efficientnet_v2.EfficientNetV2B0,
    'EfficientNetB0': keras.applications.efficientnet.EfficientNetB0,
    'ResNet50': keras.applications.resnet50.ResNet50,
}
# Extra preprocessing each backbone needs when it is handed [0, 255] RGB pixels.
# The Keras EfficientNet / EfficientNetV2 applications rescale internally, so
# they need nothing; ResNet50 needs its own `preprocess_input`.
BACKBONE_PREPROCESS = {
    'EfficientNetV2B0': None,
    'EfficientNetB0': None,
    'ResNet50': keras.applications.resnet50.preprocess_input,
}
if MODEL_NAME not in BACKBONES:
    raise ValueError(f"Unsupported MODEL_NAME: {MODEL_NAME}")

NUM_CLASSES = len(classes)
FINE_TUNE_LAYERS = 50  # number of trailing backbone layers unfrozen for fine-tuning

base = BACKBONES[MODEL_NAME](
    include_top=False,
    weights='imagenet',
    input_shape=(IMG_SIZE[0], IMG_SIZE[1], 3)
)
base.trainable = False  # freeze backbone for warmup

inputs = layers.Input(shape=(IMG_SIZE[0], IMG_SIZE[1], 3))
# The tf.data pipeline and the inference helpers already scale pixels to
# [0, 1]. Scaling by 1/255 a second time here would hand the pretrained
# backbone a nearly constant image, so instead map back to the [0, 255]
# range the ImageNet weights were trained on.
x = layers.Rescaling(255.0, name='to_pixel_range')(inputs)
_preprocess = BACKBONE_PREPROCESS[MODEL_NAME]
if _preprocess is not None:
    # Wrapped in a Lambda so the step is a single ordinary layer in the graph.
    x = layers.Lambda(_preprocess, name='backbone_preprocess')(x)
# training=False keeps the backbone's BatchNorm layers in inference mode,
# including later when some of its layers are unfrozen.
x = base(x, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.BatchNormalization()(x)
x = layers.Dropout(0.3)(x)
# float32 output keeps the softmax numerically stable under mixed precision.
outputs = layers.Dense(NUM_CLASSES, activation='softmax', dtype='float32')(x)
model = keras.Model(inputs, outputs)


def _make_optimizer(learning_rate):
    """Return an AdamW optimizer when one is available, otherwise plain Adam."""
    builtin_adamw = getattr(keras.optimizers, 'AdamW', None)
    if builtin_adamw is not None:
        return builtin_adamw(learning_rate=learning_rate, weight_decay=WEIGHT_DECAY)
    if HAS_TFA:
        return tfa.optimizers.AdamW(learning_rate=learning_rate, weight_decay=WEIGHT_DECAY)
    print('AdamW not available; falling back to Adam without weight decay.')
    return keras.optimizers.Adam(learning_rate=learning_rate)


def _make_metrics():
    """Return fresh metric objects (one set per compile call)."""
    return [
        'accuracy',
        keras.metrics.Precision(name='precision'),
        keras.metrics.Recall(name='recall'),
        # multi_label=True: one AUC per class, then averaged (macro).
        keras.metrics.AUC(name='auc_macro', multi_label=True, num_labels=NUM_CLASSES),
        # multi_label=False: all class scores pooled into one curve (micro).
        keras.metrics.AUC(name='auc_micro', multi_label=False),
    ]


def _with_one_hot(ds):
    """Convert integer labels of a batched dataset to one-hot vectors."""
    return ds.map(
        lambda img, y: (img, tf.one_hot(y, NUM_CLASSES)),
        num_parallel_calls=AUTOTUNE,
    ).prefetch(AUTOTUNE)


# Precision / Recall / AUC compare y_true and y_pred element-wise and need
# matching shapes, so training uses one-hot targets with the categorical loss
# rather than sparse integer labels.
train_fit_ds = _with_one_hot(train_ds)
val_fit_ds = _with_one_hot(val_ds)

# Optimizer with weight decay
optimizer = _make_optimizer(LEARNING_RATE)

model.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',
    metrics=_make_metrics()
)

model.summary()

# Callbacks
ckpt_path = str(SAVE_DIR/'best_model.keras')
callbacks = [
    keras.callbacks.ModelCheckpoint(ckpt_path, monitor='val_accuracy', save_best_only=True, mode='max'),
    keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2)
]

# Warmup training (frozen backbone): between 3 and 6 epochs depending on EPOCHS.
history1 = model.fit(
    train_fit_ds,
    validation_data=val_fit_ds,
    epochs=max(3, min(6, EPOCHS//4)),
    class_weight=class_weights,
    callbacks=callbacks,
    verbose=1
)

# Fine-tune: unfreeze only the last FINE_TUNE_LAYERS layers of the backbone.
base.trainable = True
for layer in base.layers[:-FINE_TUNE_LAYERS]:
    layer.trainable = False

# Lower learning rate so the pretrained weights are only nudged.
optimizer_ft = _make_optimizer(LEARNING_RATE/5)

# Re-compile so the changed `trainable` flags take effect.
model.compile(
    optimizer=optimizer_ft,
    loss='categorical_crossentropy',
    metrics=_make_metrics()
)

history2 = model.fit(
    train_fit_ds,
    validation_data=val_fit_ds,
    epochs=EPOCHS,
    class_weight=class_weights,
    callbacks=callbacks,
    verbose=1
)

# ========================= 8) TRAINING CURVES =========================
def plot_history(hlist, key, title=None):
    """Plot one metric across several consecutive training runs.

    Args:
        hlist: Iterable of objects exposing a ``history`` dict (as returned
            by ``model.fit``); their series are concatenated in order.
        key: Name of the metric to look up in each ``history`` dict. Runs
            that lack the key contribute no points.
        title: Figure title; ``key`` is used when this is empty or ``None``.
    """
    series = list(
        itertools.chain.from_iterable(run.history.get(key, []) for run in hlist)
    )
    caption = title if title else key

    plt.figure()
    plt.plot(series)
    plt.title(caption)
    plt.xlabel('epoch')
    plt.ylabel(key)
    plt.grid(True)
    plt.show()

# One figure per metric, warm-up and fine-tuning epochs shown back to back.
for _metric_key, _figure_title in (
    ('loss', 'Training Loss'),
    ('val_loss', 'Validation Loss'),
    ('accuracy', 'Training Accuracy'),
    ('val_accuracy', 'Validation Accuracy'),
):
    plot_history([history1, history2], _metric_key, _figure_title)

# ========================= 9) EVALUATION: METRICS & CONFUSION MATRIX =========================
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
from sklearn.preprocessing import label_binarize

# Predict: class probabilities for the test set, then the arg-max class.
y_true = test_labels
probs = model.predict(test_ds, verbose=1)
y_pred = probs.argmax(axis=1)

# Classification report (precision, recall, f1 per class), printed and saved.
report = classification_report(y_true, y_pred, target_names=classes, output_dict=True)
_report_table = pd.DataFrame(report).T
print(_report_table)
_report_table.to_csv(SAVE_DIR/'classification_report.csv')

# Confusion matrix rendered as an annotated image.
cm = confusion_matrix(y_true, y_pred)
fig = plt.figure()
plt.imshow(cm, interpolation='nearest')
plt.title('Confusion Matrix')
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)
# Cells above half of the largest count get white text, the rest black.
thresh = cm.max() / 2.
for (row, col), count in np.ndenumerate(cm):
    text_color = "white" if count > thresh else "black"
    plt.text(col, row, format(count, 'd'),
             horizontalalignment="center",
             color=text_color)
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()
plt.show()

# ROC & AUC (one-vs-rest): one curve per class from its probability column.
y_true_bin = label_binarize(y_true, classes=list(range(len(classes))))
fig = plt.figure()
for cls_name, truth_col, score_col in zip(classes, y_true_bin.T, probs.T):
    fpr, tpr, _ = roc_curve(truth_col, score_col)
    roc_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr, label=f"{cls_name} (AUC = {roc_auc:.3f})")
plt.plot([0,1],[0,1],'--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curves (OvR)')
plt.legend(loc='lower right')
plt.grid(True)
plt.show()

# ========================= 10) SAVE MODEL & LABEL MAP =========================
# Persist the trained model and the index -> class-name map next to each other.
model.save(SAVE_DIR/'final_model.keras')
_labels_file = SAVE_DIR/'labels.json'
with open(_labels_file, 'w') as f:
    f.write(json.dumps(idx2label))
print('Saved to', SAVE_DIR)

# ========================= 11) GRAD-CAM FOR EXPLAINABILITY =========================
# Utility functions adapted for Keras models

def get_last_conv_layer(model):
    """Return the name of the last convolution-like layer of ``model``.

    A layer qualifies when it is a ``Conv2D`` instance or its name contains
    ``'conv'`` (case-insensitive). Top-level layers are scanned from last to
    first; if none qualifies, the layers of each nested container (anything
    with a ``layers`` attribute) are scanned the same way.

    Raises:
        ValueError: If no qualifying layer is found at either level.
    """
    def _is_conv_like(candidate):
        return isinstance(candidate, layers.Conv2D) or 'conv' in candidate.name.lower()

    outer = list(reversed(model.layers))

    # Pass 1: layers that sit directly in the model.
    for candidate in outer:
        if _is_conv_like(candidate):
            return candidate.name

    # Pass 2: one level down, inside nested models such as a backbone.
    for container in outer:
        if not hasattr(container, 'layers'):
            continue
        for inner in reversed(container.layers):
            if _is_conv_like(inner):
                return inner.name

    raise ValueError('No conv layer found for Grad-CAM')

last_conv_name = get_last_conv_layer(model)
print('Grad-CAM last conv layer:', last_conv_name)


def _find_layer_owner(root, layer_name):
    """Return ``root`` or the direct sub-model of ``root`` that owns ``layer_name``."""
    if any(lyr.name == layer_name for lyr in root.layers):
        return root
    for child in root.layers:
        if hasattr(child, 'layers') and any(lyr.name == layer_name for lyr in child.layers):
            return child
    raise ValueError(f'Layer {layer_name!r} not found in model or its sub-models')


def grad_cam(img_tensor, class_index=None):
    """Compute a Grad-CAM heatmap for one preprocessed image.

    Runs eagerly (no ``tf.function``) so the result can be returned as a
    NumPy array, and supports the case where ``last_conv_name`` lives inside
    a nested backbone model rather than directly in ``model``.

    Args:
        img_tensor: Batch of one image, preprocessed the same way as the
            inputs passed to ``model.predict``.
        class_index: Class whose score is explained; defaults to the
            predicted (arg-max) class of the first image.

    Returns:
        2-D ``numpy`` array of size ``IMG_SIZE`` with values scaled to [0, 1].
    """
    img_tensor = tf.convert_to_tensor(img_tensor)
    owner = _find_layer_owner(model, last_conv_name)

    with tf.GradientTape() as tape:
        if img_tensor.dtype.is_floating:
            # Watching the input guarantees every downstream op is recorded,
            # even when the layers before the conv are frozen.
            tape.watch(img_tensor)
        if owner is model:
            grad_model = keras.Model(
                model.inputs,
                [model.get_layer(last_conv_name).output, model.output],
            )
            conv_outputs, predictions = grad_model(img_tensor, training=False)
        else:
            # The conv layer's output belongs to the sub-model's own graph, so
            # tap it there and replay the outer layers around the sub-model.
            # NOTE(review): assumes the outer model is a linear chain of
            # single-input layers, as built in this notebook.
            tap = keras.Model(
                owner.inputs,
                [owner.get_layer(last_conv_name).output, owner.output],
            )
            x = img_tensor
            conv_outputs = None
            for layer in model.layers:
                if isinstance(layer, layers.InputLayer):
                    continue
                if layer is owner:
                    conv_outputs, x = tap(x, training=False)
                else:
                    x = layer(x, training=False)
            predictions = x
        if class_index is None:
            class_index = tf.argmax(predictions[0])
        loss = predictions[:, class_index]

    grads = tape.gradient(loss, conv_outputs)
    # Under mixed precision the activations are float16; do the map math in float32.
    grads = tf.cast(grads, tf.float32)
    feature_maps = tf.cast(conv_outputs[0], tf.float32)

    # Channel importance = gradient averaged over batch and spatial positions.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    heatmap = tf.reduce_sum(tf.multiply(pooled_grads, feature_maps), axis=-1)
    # Rectify first, then normalise by the rectified maximum so the divisor
    # can never be negative or cancel against the epsilon.
    heatmap = tf.maximum(heatmap, 0)
    heatmap = heatmap / (tf.reduce_max(heatmap) + 1e-8)
    heatmap = tf.image.resize(heatmap[..., None], IMG_SIZE)
    heatmap = tf.squeeze(heatmap)
    return heatmap.numpy()

# Visualize a few test images
def _show_panel(slot, caption, picture, overlay=None):
    """Draw ``picture`` (optionally with a translucent overlay) in a 1x2 grid slot."""
    plt.subplot(1, 2, slot)
    plt.title(caption)
    plt.imshow(picture)
    if overlay is not None:
        plt.imshow(overlay, alpha=0.35)
    plt.axis('off')


# Up to six distinct test images, drawn without replacement.
_n_samples = min(6, len(test_paths))
sample_idx = np.random.choice(len(test_paths), size=_n_samples, replace=False)
for idx in sample_idx:
    path = test_paths[idx]
    label = idx2label[int(test_labels[idx])]

    # Same decode / resize / rescale steps as the evaluation pipeline.
    raw = tf.io.read_file(path)
    img = tf.image.resize(
        tf.io.decode_image(raw, channels=3, expand_animations=False), IMG_SIZE
    )
    inp = norm_layer(tf.expand_dims(tf.cast(img, tf.float32), 0))

    prob = model.predict(inp, verbose=0)[0]
    pred_idx = int(np.argmax(prob))
    pred_label = idx2label[pred_idx]
    heatmap = grad_cam(inp, pred_idx)

    plt.figure()
    _show_panel(
        1,
        f"True: {label}\nPred: {pred_label} ({prob[pred_idx]:.2f})",
        tf.cast(img, tf.uint8),
    )
    _show_panel(2, 'Grad-CAM', tf.cast(img, tf.uint8), overlay=heatmap)
    plt.show()

# ========================= 12) ETHICS, BIAS & DEPLOYMENT NOTES =========================
# Reminder text printed at the end of a run; one line per note.
_ETHICS_NOTES = (
    '- Validate on an external test set from a different hospital to check generalization.',
    '- Ensure class balance when reporting metrics; show per-class sensitivity (recall).',
    '- Check performance across demographics (age/sex/region if available) to detect bias.',
    '- This model is a decision support tool, NOT a diagnostic device. Keep a clinician-in-the-loop.',
)
print('\nNOTES — Ethics & Fairness:')
for _note in _ETHICS_NOTES:
    print(_note)

# ========================= 13) INFERENCE SNIPPET (LOADING SAVED MODEL) =========================
# Example function to load and predict on a single image

def load_artifacts(model_path=SAVE_DIR/'final_model.keras', labels_path=SAVE_DIR/'labels.json'):
    """Load the saved Keras model and its index-to-label map.

    Args:
        model_path: Location of the saved model file.
        labels_path: Location of the JSON label map.

    Returns:
        Tuple ``(model, labels)`` where ``labels`` maps integer class index
        to class name (JSON object keys are strings, so they are converted
        back to ``int`` here).
    """
    mdl = keras.models.load_model(model_path)
    with open(labels_path, 'r') as handle:
        stored = json.load(handle)
    labels_by_index = {}
    for key, name in stored.items():
        labels_by_index[int(key)] = name
    return mdl, labels_by_index


def predict_image(image_path, *, artifacts=None):
    """Classify a single image file with the saved model.

    Args:
        image_path: Path of the image to classify (string or path-like).
        artifacts: Optional ``(model, index_to_label)`` pair as returned by
            ``load_artifacts()``. Pass it when predicting many images so the
            model is not reloaded from disk on every call; when omitted the
            artifacts are loaded from their default locations.

    Returns:
        Tuple ``(label, confidence)``: the predicted class name and the
        probability the model assigned to it, as a Python ``float``.
    """
    mdl, id2lbl = load_artifacts() if artifacts is None else artifacts

    # tf.io.read_file wants a string, so unwrap pathlib-style objects first.
    if isinstance(image_path, os.PathLike):
        image_path = os.fspath(image_path)

    raw = tf.io.read_file(image_path)
    img = tf.io.decode_image(raw, channels=3, expand_animations=False)
    img = tf.image.resize(img, IMG_SIZE)
    img = tf.expand_dims(img, 0)  # add the batch dimension
    img = norm_layer(img)
    prob = mdl.predict(img, verbose=0)[0]
    idx = int(np.argmax(prob))
    return id2lbl[idx], float(prob[idx])

# Example:
# pred, conf = predict_image('some_image.jpg')
# print('Prediction:', pred, 'Confidence:', conf)
