Importing Images From Kaggle

In [None]:
# from pickleshare import Path
from pathlib import Path
import kagglehub

# Download the latest version of the dataset through kagglehub. The returned location
# is wrapped in Path so later cells can join sub-folders onto it with `/`.
root_img_path = Path(kagglehub.dataset_download("uraninjo/augmented-alzheimer-mri-dataset"))

# Debug helper (disabled): print("Path to dataset files:", root_img_path)

Imports

In [None]:
import numpy as np
import random
import tensorflow as tf
from tensorflow.keras.applications.inception_resnet_v2 import preprocess_input

Seed Reproducibility

In [None]:
# One seed reused by the split, the shuffle buffer and the augmentation layers below.
SEED = 67
# Seed the framework, then Python's and NumPy's global generators.
# NOTE(review): per the Keras docs, set_random_seed already seeds `random` and NumPy
# as well, so the two explicit calls look redundant (harmless) — confirm before removing.
tf.keras.utils.set_random_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

Global Variables

In [None]:
# Image folder inside the download; class names are later taken from each file's
# parent folder name. Presumably the un-augmented images — TODO confirm.
IMAGES_PATH = root_img_path / "OriginalDataset"
IMG_SIZE = (299,299) # size every image is resized to; the input size InceptionResNetV2 uses
BATCH_SIZE = 32  # images per batch in every tf.data pipeline
AUTOTUNE = tf.data.AUTOTUNE  # let tf.data choose parallelism / prefetch depth

NUM_CLASSES = 4  # checked against the class folders found on disk
EPOCHS_HEAD = 70  # epochs for training the classification head (base frozen)
# EPOCHS_FT = 20  (fine-tuning setting; not used by any cell visible here)

LR_HEAD = 5e-5  # Adam learning rate for the head-training stage
# LR_FT = 2e-5  (fine-tuning setting; not used by any cell visible here)
# FINE_TUNE_TOP_LAYERS = 20  (fine-tuning setting; not used by any cell visible here)

BEST_WEIGHTS_PATH = "best_weights.weights.h5"
# file the checkpoint callback writes to whenever an epoch achieves a new best val_loss

Split Data

In [None]:
# Enumerate every image file and the class folders before any splitting happens.
# The dataset object is only used for its file listing, hence shuffle=False.
all_ds = tf.keras.utils.image_dataset_from_directory(
    directory=IMAGES_PATH,
    color_mode="grayscale",   # images are read as grayscale here; RGB conversion happens later
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    shuffle=False,
)
class_names = all_ds.class_names
assert len(class_names) == NUM_CLASSES

# Integer label of a file = position of its parent folder name in class_names.
paths = np.array(all_ds.file_paths)
labels = np.fromiter(
    (class_names.index(Path(file_path).parent.name) for file_path in paths),
    dtype=np.int32,
    count=len(paths),
)

print("Class counts total:", np.bincount(labels, minlength=NUM_CLASSES))

Found 6400 files belonging to 4 classes.
Class counts total: [ 896   64 3200 2240]


In [None]:
# stratified split (70/15/15 with the default fractions)
def stratified_split(paths, y, val_frac=0.15, test_frac=0.15, seed=SEED, num_classes=None):
    """Split samples into train / validation / test subsets, class by class.

    Each class is shuffled and cut on its own, so every subset keeps roughly the
    same class proportions as the full data set.

    Args:
        paths: sequence of sample identifiers (file paths), one per sample.
        y: sequence of non-negative integer class labels, same length as `paths`.
        val_frac: fraction of each class assigned to the validation subset.
        test_frac: fraction of each class assigned to the test subset.
        seed: seed for the NumPy generator that drives every shuffle.
        num_classes: number of classes to iterate over. Defaults to
            `max(y) + 1`, so no labelled sample is silently left out.

    Returns:
        Tuple `(train_paths, train_y, val_paths, val_y, test_paths, test_y)`
        of NumPy arrays; whatever is not taken for val/test goes to train.

    Raises:
        ValueError: if `paths` and `y` differ in length, if a label is negative,
            or if the fractions are not in [0, 1) with a sum below 1.
    """
    paths = np.asarray(paths)
    y = np.asarray(y)
    if len(paths) != len(y):
        raise ValueError(
            f"paths and y must have the same length, got {len(paths)} and {len(y)}"
        )
    if not (0.0 <= val_frac < 1.0 and 0.0 <= test_frac < 1.0) or val_frac + test_frac >= 1.0:
        raise ValueError(
            "val_frac and test_frac must each lie in [0, 1) and sum to less than 1, "
            f"got val_frac={val_frac}, test_frac={test_frac}"
        )
    if y.size and y.min() < 0:
        raise ValueError("class labels must be non-negative integers")
    if num_classes is None:
        # Derive the class count from the labels instead of a notebook global;
        # an empty label array has no max(), so it maps to zero classes.
        num_classes = int(y.max()) + 1 if y.size else 0

    rng = np.random.default_rng(seed)
    idx = np.arange(len(paths))
    tr, va, te = [], [], []
    # For each class: shuffle its sample indices, then slice off the test part,
    # the validation part, and leave the remainder for training.
    for c in range(num_classes):
        c_idx = idx[y == c]
        rng.shuffle(c_idx)
        n = len(c_idx)
        n_te = int(round(n * test_frac))
        n_va = int(round(n * val_frac))
        te.append(c_idx[:n_te])
        va.append(c_idx[n_te:n_te + n_va])
        tr.append(c_idx[n_te + n_va:])

    def _merge(chunks):
        # np.concatenate rejects an empty list (the zero-class case).
        merged = np.concatenate(chunks) if chunks else idx[:0]
        # Mix the classes together so a subset is not ordered class by class.
        rng.shuffle(merged)
        return merged

    # Order matters for reproducibility: train, then val, then test.
    tr, va, te = _merge(tr), _merge(va), _merge(te)
    return paths[tr], y[tr], paths[va], y[va], paths[te], y[te]

# Carve the file list into train / validation / test subsets, stratified by class.
(train_paths, train_y,
 val_paths, val_y,
 test_paths, test_y) = stratified_split(paths, labels, seed=SEED)

# Per-class image counts of each subset, as a quick check on the stratification.
for header, subset_labels in (
    ("Train counts:", train_y),
    ("Val counts:  ", val_y),
    ("Test counts: ", test_y),
):
    print(header, np.bincount(subset_labels, minlength=NUM_CLASSES))

Train counts: [ 628   44 2240 1568]
Val counts:   [134  10 480 336]
Test counts:  [134  10 480 336]


Data Augmentation

In [None]:
# Deliberately light augmentation: small rotations, shifts and zooms only.
# Horizontal flips are left out because they do not make sense anatomically.
aug = tf.keras.Sequential(
    name="aug",
    layers=[
        tf.keras.layers.RandomRotation(factor=0.05, seed=SEED),
        tf.keras.layers.RandomTranslation(height_factor=0.05, width_factor=0.05, seed=SEED),
        tf.keras.layers.RandomZoom(height_factor=0.05, seed=SEED),
    ],
)

In [None]:
def decode_and_preprocess(path, label, training=False):
    """Load one image file and turn it into a model-ready tensor.

    Args:
        path: path of the image file to read.
        label: class label, passed through unchanged.
        training: when true, the `aug` pipeline is applied before preprocessing.

    Returns:
        `(image, label)` where `image` is the single-channel file decoded,
        resized to `IMG_SIZE`, expanded to three channels, cast to float32 and
        passed through InceptionResNetV2's `preprocess_input`.
    """
    raw = tf.io.read_file(path)
    # Force a single channel and a static rank so the ops below know the layout.
    gray = tf.io.decode_image(raw, channels=1, expand_animations=False)
    gray.set_shape([None, None, 1])
    resized = tf.image.resize(gray, IMG_SIZE)
    # The pretrained network expects three channels.
    rgb = tf.cast(tf.image.grayscale_to_rgb(resized), tf.float32)
    if training:
        # Augment the training set only.
        rgb = aug(rgb, training=True)
    return preprocess_input(rgb), label

def make_ds(paths, y, training=False):
    """Build a batched, prefetched tf.data pipeline from file paths and labels.

    Args:
        paths: image file paths.
        y: labels aligned with `paths`.
        training: when true, samples are shuffled (reshuffled every epoch) and
            augmented; otherwise the input order is kept and no augmentation runs.

    Returns:
        A `tf.data.Dataset` yielding `(image_batch, label_batch)` pairs of up to
        `BATCH_SIZE` samples.
    """
    def _load(path, label):
        # Bind the `training` flag so map() only has to pass (path, label).
        return decode_and_preprocess(path, label, training)

    ds = tf.data.Dataset.from_tensor_slices((paths, y))
    if training:
        # The shuffle buffer is capped at 4000 entries to bound memory use.
        buffer_size = min(len(paths), 4000)
        ds = ds.shuffle(buffer_size, seed=SEED, reshuffle_each_iteration=True)
    # Decode in parallel, then batch and prefetch so input preparation overlaps
    # with training.
    return (
        ds.map(_load, num_parallel_calls=AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(AUTOTUNE)
    )

# Build one input pipeline per subset; only the training pipeline shuffles and augments.
train_ds = make_ds(paths=train_paths, y=train_y, training=True)
val_ds = make_ds(paths=val_paths, y=val_y, training=False)
test_ds = make_ds(paths=test_paths, y=test_y, training=False)

In [None]:
# The classes are heavily imbalanced, so each class gets a loss weight.
# Starting point: inverse-frequency weights, total / (NUM_CLASSES * count).
counts = np.bincount(train_y, minlength=NUM_CLASSES).astype(np.float32)
total = counts.sum()
base_w = total / (NUM_CLASSES * counts)

# The plain inverse-frequency weights did not give good enough accuracy,
# precision or recall, so classes 0, 2 and 3 are nudged by hand-tuned factors.
# Class 1 keeps its base weight, and the result is not re-normalised.
m = np.ones(NUM_CLASSES, dtype=np.float32)
m[[0, 2, 3]] *= [1.10, 0.94, 1.27]
w = base_w * m

class_weight = {cls: float(weight) for cls, weight in enumerate(w)}
print("class_weight:", class_weight)

class_weight: {0: 1.961783528327942, 1: 25.454545974731445, 2: 0.4699999988079071, 3: 0.9071428775787354}


Build Model

In [None]:
# ImageNet-pretrained InceptionResNetV2 without its classifier, used as a frozen
# feature extractor for three-channel images of size IMG_SIZE.
base = tf.keras.applications.InceptionResNetV2(
    weights="imagenet",
    include_top=False,
    pooling=None,
    input_shape=(*IMG_SIZE, 3),
)
base.trainable = False

# New head: global average pooling -> dropout -> softmax over the classes.
inputs = tf.keras.Input(shape=(*IMG_SIZE, 3))
x = base(inputs, training=False)  # the base is always called in inference mode
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.3)(x)  # regularisation against overfitting
outputs = tf.keras.layers.Dense(NUM_CLASSES, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)

Train Head

In [None]:
# Save weights only, and only when an epoch reaches a new lowest validation loss.
# (To track accuracy instead, monitor "val_acc" with mode="max".)
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    BEST_WEIGHTS_PATH,
    save_weights_only=True,
    save_best_only=True,
    monitor="val_loss",
    mode="min",
    verbose=1,
)

In [None]:
# Train only the new classification head; the pretrained base stays frozen.
model.compile(
    # Sparse loss/metric because the labels are integer ids, not one-hot vectors.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="acc")],
    optimizer=tf.keras.optimizers.Adam(clipnorm=1.0, learning_rate=LR_HEAD),
)

history = model.fit(
    x=train_ds,
    epochs=EPOCHS_HEAD,
    validation_data=val_ds,
    callbacks=[checkpoint_cb],  # saves the weights whenever val_loss improves
    class_weight=class_weight,
)

Epoch 1/70
[1m140/140[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 225ms/step - acc: 0.3892 - loss: 1.4424
Epoch 1: val_loss improved from None to 0.98547, saving model to best_weights.weights.h5

Epoch 1: finished saving model to best_weights.weights.h5
[1m140/140[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m150s[0m 321ms/step - acc: 0.4817 - loss: 1.1654 - val_acc: 0.5427 - val_loss: 0.9855
Epoch 2/70
[1m140/140[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 230ms/step - acc: 0.6372 - loss: 0.7436
Epoch 2: val_loss improved from 0.98547 to 0.80310, saving model to best_weights.weights.h5

Epoch 2: finished saving model to best_weights.weights.h5
[1m140/140[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m38s[0m 273ms/step - acc: 0.6594 - loss: 0.6662 - val_acc: 0.6500 - val_loss: 0.8031
Epoch 3/70
[1m140/140[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 232ms/step - acc: 0.7721 - loss: 0.4015
Epoch 3: val_loss improved from 0.80310 to 0.57881, saving mo

In [None]:
# Restore the checkpointed weights (the epoch with the lowest val_loss) so the
# test score reflects the best epoch rather than the last one trained.
model.load_weights(BEST_WEIGHTS_PATH) # uses the best weights for evaluation
# Left as the cell's last expression so the [loss, acc] result is displayed.
model.evaluate(test_ds)

[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 65ms/step - acc: 0.9875 - loss: 0.0351


[0.035124555230140686, 0.987500011920929]

Evaluation: Confusion Matrix and Per-Class Metrics

In [None]:
def predict_labels(ds):
    """Run the notebook's `model` over a dataset.

    Args:
        ds: input accepted by `model.predict`.

    Returns:
        `(labels, probs)`: `probs` is the raw `model.predict` output and
        `labels` is the index of the largest value along axis 1 of `probs`.
    """
    class_probs = model.predict(ds, verbose=0)
    hard_labels = np.argmax(class_probs, axis=1)
    return hard_labels, class_probs

# Hard class predictions for the validation and test subsets.
# The returned tuple is indexed instead of unpacked into `_`, because assigning
# to `_` rebinds IPython's "last output" variable in the notebook namespace.
val_pred = predict_labels(val_ds)[0]
test_pred = predict_labels(test_ds)[0]

def confusion_matrix_np(y_true, y_pred, num_classes):
    """Count (true label, predicted label) pairs into a square matrix.

    Args:
        y_true: sequence of ground-truth class indices.
        y_pred: sequence of predicted class indices, same length as `y_true`.
        num_classes: number of classes; the result is `num_classes x num_classes`.

    Returns:
        int64 array whose entry `[t, p]` is the number of samples with true
        class `t` and predicted class `p` (rows = true, columns = predicted).

    Raises:
        ValueError: if the two label sequences differ in length (previously the
            longer one was silently truncated).
        IndexError: if any label lies outside `[0, num_classes)` (previously a
            negative label silently wrapped around to the last classes).
    """
    true_idx = np.asarray(y_true).astype(np.int64).ravel()
    pred_idx = np.asarray(y_pred).astype(np.int64).ravel()
    if true_idx.size != pred_idx.size:
        raise ValueError(
            f"y_true and y_pred must have the same length, "
            f"got {true_idx.size} and {pred_idx.size}"
        )

    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    if true_idx.size == 0:
        return cm

    lowest = min(true_idx.min(), pred_idx.min())
    highest = max(true_idx.max(), pred_idx.max())
    if lowest < 0 or highest >= num_classes:
        raise IndexError(
            f"labels must lie in [0, {num_classes}), got values in [{lowest}, {highest}]"
        )

    # Unbuffered add, so repeated (true, pred) pairs are each counted.
    np.add.at(cm, (true_idx, pred_idx), 1)
    return cm

# Confusion matrices for the validation and test subsets.
cm_val = confusion_matrix_np(y_true=val_y, y_pred=val_pred, num_classes=NUM_CLASSES)
cm_test = confusion_matrix_np(y_true=test_y, y_pred=test_pred, num_classes=NUM_CLASSES)

for split, matrix in (("VAL", cm_val), ("TEST", cm_test)):
    print(f"\n{split} confusion matrix (rows=true, cols=pred):\n", matrix)

def per_class_report(cm, class_names):
    """Derive precision, recall, F1 and support for each class.

    Args:
        cm: 2-D confusion matrix (NumPy array), rows = true class,
            columns = predicted class.
        class_names: class names; entry `k` describes row/column `k` of `cm`.

    Returns:
        `(report, bal_acc)` where `report` is a list of
        `(name, precision, recall, f1, support)` tuples in `class_names` order
        and `bal_acc` is the mean of the per-class recalls (macro recall).
    """
    eps = 1e-9  # added to every denominator so empty classes do not divide by zero
    predicted_totals = cm.sum(axis=0)  # column sums: times each class was predicted
    actual_totals = cm.sum(axis=1)     # row sums: true samples per class (support)

    rows, recalls = [], []
    for k, label in enumerate(class_names):
        hits = cm[k, k]
        false_pos = predicted_totals[k] - hits
        false_neg = actual_totals[k] - hits
        precision = hits / (hits + false_pos + eps)
        recall = hits / (hits + false_neg + eps)
        f1 = 2 * precision * recall / (precision + recall + eps)
        rows.append((label, precision, recall, f1, actual_totals[k]))
        recalls.append(recall)
    return rows, np.mean(recalls)

# Per-class metrics and balanced accuracy for the validation and test subsets.
rep_val, bal_val = per_class_report(cm_val, class_names)
rep_test, bal_test = per_class_report(cm_test, class_names)

for split, rows, balanced in (("VAL", rep_val, bal_val), ("TEST", rep_test, bal_test)):
    print(f"\n{split} per-class:")
    for name, p, r, f1, sup in rows:
        print(f"  {name:>12s}  P={p:.3f}  R={r:.3f}  F1={f1:.3f}  n={sup}")
    print(f"  {split} Balanced Acc (macro recall):", round(balanced, 3))


VAL confusion matrix (rows=true, cols=pred):
 [[134   0   0   0]
 [  0  10   0   0]
 [  1   0 469  10]
 [  0   0   6 330]]

TEST confusion matrix (rows=true, cols=pred):
 [[134   0   0   0]
 [  0  10   0   0]
 [  0   0 475   5]
 [  1   0   6 329]]

VAL per-class:
  MildDemented  P=0.993  R=1.000  F1=0.996  n=134
  ModerateDemented  P=1.000  R=1.000  F1=1.000  n=10
   NonDemented  P=0.987  R=0.977  F1=0.982  n=480
  VeryMildDemented  P=0.971  R=0.982  F1=0.976  n=336
  VAL Balanced Acc (macro recall): 0.99

TEST per-class:
  MildDemented  P=0.993  R=1.000  F1=0.996  n=134
  ModerateDemented  P=1.000  R=1.000  F1=1.000  n=10
   NonDemented  P=0.988  R=0.990  F1=0.989  n=480
  VeryMildDemented  P=0.985  R=0.979  F1=0.982  n=336
  TEST Balanced Acc (macro recall): 0.992
