## Imports

In [None]:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

import csv
import math
from pathlib import Path
from typing import List, Tuple, Optional

import numpy as np
import pandas as pd
import cv2 as cv
import matplotlib.pyplot as plt
import tensorflow as tf

import keras_hub

from keras_hub.layers import ViTImageConverter
from keras_hub.models import ViTImageClassifierPreprocessor, ViTBackbone

# File name of the annotation table that lives inside every split directory.
labels = "annotations.csv"

# Directories of the raw train / validation / test splits.
train_path = "deep_learning/working/Diabetic Retinopathy/train"
valid_path = "deep_learning/working/Diabetic Retinopathy/valid"
test_path = "deep_learning/working/Diabetic Retinopathy/test"

# Full path of the annotation CSV of each split.
train_labels, valid_labels, test_labels = (
    os.path.join(split_dir, labels)
    for split_dir in (train_path, valid_path, test_path)
)

# Echo the resolved CSV locations, one per line.
print(train_labels, valid_labels, test_labels, sep="\n")

# CSV column holding the binary target used throughout the notebook.
BINARY_LABEL_COL = "Risk of macular edema"

# print(tf.keras.__version__)
# print(tf.__version__)

## 0. Dataset Class Balance

In [None]:
def counts_from_dataset_split(split_annotations_path: str):
    """Count the negative and positive labels in one split's annotation CSV.

    Only the ``BINARY_LABEL_COL`` column is read; its values are cast to ``int``.

    Args:
        split_annotations_path: Path of the split's annotations CSV file.

    Returns:
        ``(negatives, positives)`` — the number of rows labelled 0 and 1.
        A class that does not occur in the file is reported as 0.
    """
    annotations_csv = Path(split_annotations_path)
    frame = pd.read_csv(annotations_csv, usecols=[BINARY_LABEL_COL])
    per_class = frame[BINARY_LABEL_COL].astype(int).value_counts()
    negatives = int(per_class.get(0, 0))
    positives = int(per_class.get(1, 0))
    return negatives, positives

def show_dataset_split_summary(name: str, negatives: int, positives: int):
    """Print a one-line size and class-balance summary for a dataset split.

    Args:
        name: Label of the split (left-aligned, padded to 6 characters).
        negatives: Number of samples with label 0.
        positives: Number of samples with label 1.
    """
    total = negatives + positives
    if total:
        positive_pct = positives / total * 100
    else:
        # Empty split: report 0% instead of dividing by zero.
        positive_pct = 0.0
    summary = (
        f"{name:<6}  N={total:4d} | neg={negatives:4d}  "
        f"pos={positives:4d}  pos%={positive_pct:6.2f}"
    )
    print(summary)

def show_class_balance(train_labels, valid_labels, test_labels, save_png="class_balance.png"):
    """Report the class balance of every split and derive training class weights.

    Prints a per-split summary, saves a three-panel bar chart to ``save_png`` and
    computes inverse-frequency class weights from the training split only.

    Args:
        train_labels: Path of the training annotations CSV.
        valid_labels: Path of the validation annotations CSV.
        test_labels: Path of the test annotations CSV.
        save_png: Output path of the bar chart.

    Returns:
        ``{0: w0, 1: w1}`` with ``w_c = N_train / (2 * N_train_c)``. If the
        training split contains no sample of one class, uniform weights
        ``{0: 1.0, 1: 1.0}`` are returned instead (inverse weighting is
        undefined in that case).
    """
    train_neg, train_pos = counts_from_dataset_split(train_labels)
    valid_neg, valid_pos = counts_from_dataset_split(valid_labels)
    test_neg, test_pos = counts_from_dataset_split(test_labels)

    print("\n=== Class balance per split ===")
    show_dataset_split_summary("train", train_neg, train_pos)
    show_dataset_split_summary("valid", valid_neg, valid_pos)
    show_dataset_split_summary("test ",  test_neg,  test_pos)

    # Visualization: one bar chart per split
    split_titles = ["Train", "Valid", "Test"]
    negative_counts = [train_neg, valid_neg, test_neg]
    positive_counts = [train_pos, valid_pos, test_pos]

    fig, axes = plt.subplots(1, 3, figsize=(12, 4), constrained_layout=True)
    for ax, title, n0, n1 in zip(axes, split_titles, negative_counts, positive_counts):
        N = n0 + n1
        ax.bar(["Negative (0)", "Positive (1)"], [n0, n1])
        pos_pct = (n1 / N * 100) if N else 0.0
        ax.set_title(f"{title}\nN={N}, pos={pos_pct:.1f}%")
        ax.set_ylabel("Count")
        ax.grid(axis="y", alpha=0.2)

    fig.suptitle(f"Class balance by split — label: '{BINARY_LABEL_COL}'")
    fig.savefig(save_png, dpi=200, bbox_inches="tight")
    print(f"\nSaved chart → {save_png}")

    # Inverse-frequency weights for Keras (train split only)
    train_total = train_neg + train_pos
    if train_neg and train_pos:
        class_weight = {
            0: train_total / (2 * train_neg),
            1: train_total / (2 * train_pos),
        }
    else:
        # A missing class would divide by zero; fall back to uniform weights.
        print(
            "Warning: training split has no samples for at least one class "
            f"(neg={train_neg}, pos={train_pos}); falling back to uniform class weights."
        )
        class_weight = {0: 1.0, 1: 1.0}

    cw0, cw1 = class_weight[0], class_weight[1]
    print(f"Inverse weighting for imbalance: class_weight={{0: {cw0:.3f}, 1: {cw1:.3f}}}")
    return class_weight

## 1. Data Preprocessing

### 1.1. Crop square and resize images to (384x384)

In [None]:
def crop_square_resize_image(
    img_bgr: np.ndarray,
    out_size: int = 384,
    black_delta: int = 5,
    margin_frac: float = 0.03,
) -> np.ndarray:
    """Crop away the dark background, pad to a square and resize.

    The image is thresholded against a level derived from its own outer pixel
    frame, the largest bright connected component is kept (with a small
    margin), the crop is padded to a square using the median colour of the
    crop's border, and the result is resized to ``out_size`` x ``out_size``.

    Args:
        img_bgr: Colour image as loaded by OpenCV (3 channels expected).
        out_size: Side length of the square output in pixels.
        black_delta: Offset added to the border median to form the threshold.
        margin_frac: Fraction of the bounding-box size added as margin per side.

    Returns:
        The processed image. If no content region is detected, the input is
        simply resized to ``out_size`` x ``out_size``.
    """
    img_h, img_w = img_bgr.shape[:2]
    target = (out_size, out_size)

    # Only intensity is needed to separate the dark background from the eye
    # image, so a single grayscale channel is cheaper and sufficient.
    gray = cv.cvtColor(img_bgr, cv.COLOR_BGR2GRAY)

    # Image-specific threshold: median of the outer pixel frame plus a small
    # delta, but never below 10.
    frame = np.concatenate((gray[0, :], gray[-1, :], gray[:, 0], gray[:, -1]))
    threshold = int(max(10, np.median(frame) + black_delta))

    # Everything brighter than the surroundings is treated as content (255).
    content_mask = (gray > threshold).astype(np.uint8) * 255

    # Close thin gaps along the rim before looking for components.
    closing_kernel = cv.getStructuringElement(cv.MORPH_ELLIPSE, (7, 7))
    content_mask = cv.morphologyEx(
        content_mask, cv.MORPH_CLOSE, closing_kernel, iterations=2
    )

    component_count, _, stats, _ = cv.connectedComponentsWithStats(
        content_mask, connectivity=8
    )
    if component_count <= 1:
        # Only the background label exists: nothing to crop, just resize.
        print("Error when trying to process image, no values recieved!")
        return cv.resize(img_bgr, target, interpolation=cv.INTER_AREA)

    # Label 0 is the background, so search the remaining rows for the largest area.
    largest_id = 1 + np.argmax(stats[1:, cv.CC_STAT_AREA])
    box_x, box_y, box_w, box_h = stats[largest_id][:4]

    # Add a small margin so the rim of the eye image is not clipped,
    # while staying inside the image bounds.
    margin_y = int(box_h * margin_frac)
    margin_x = int(box_w * margin_frac)
    x0 = max(0, box_x - margin_x)
    y0 = max(0, box_y - margin_y)
    x1 = min(img_w, box_x + box_w + margin_x)
    y1 = min(img_h, box_y + box_h + margin_y)

    crop = img_bgr[y0:y1, x0:x1, :]
    crop_h, crop_w = crop.shape[:2]

    # Distribute the missing rows/columns on both sides (no stretching).
    side = max(crop_h, crop_w)
    pad_top = (side - crop_h) // 2
    pad_bottom = side - crop_h - pad_top
    pad_left = (side - crop_w) // 2
    pad_right = side - crop_w - pad_left

    # Neutral padding colour: per-channel median of the crop's border pixels.
    edge_pixels = np.vstack(
        [crop[0, :, :], crop[-1, :, :], crop[:, 0, :], crop[:, -1, :]]
    )
    pad_color = np.median(edge_pixels.reshape(-1, 3), axis=0).astype(np.uint8).tolist()

    squared = cv.copyMakeBorder(
        crop,
        pad_top,
        pad_bottom,
        pad_left,
        pad_right,
        borderType=cv.BORDER_CONSTANT,
        value=pad_color,
    )
    return cv.resize(squared, target, interpolation=cv.INTER_AREA)

### 1.2. Process and save resized images

In [None]:
def save_resized_images_from_dataset_split(
    input_dir,
    output_dir,
    out_size: int = 384,
    workers: int | None = None,
):
    """
    Read all images in `input_dir`, trim borders + square-pad + resize,
    and write PNGs to `output_dir`.
    - Accepts: .jpg/.jpeg/.png (case-insensitive)
    - Output: lossless PNG with moderate compression (level 6), named <stem>.png
    - Parallelism: ThreadPoolExecutor (set workers=0 or 1 to run sequentially;
      workers=None picks min(8, cpu_count))
    - Files that cannot be read or written are skipped and reported in the
      printed summary; the function itself returns None.
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # 1) Collect inputs (deterministic order helps reproducibility)
    exts = {".jpg", ".jpeg", ".png"}
    paths = sorted(p for p in input_dir.iterdir() if p.is_file() and p.suffix.lower() in exts)
    if not paths:
        print(f"[process_dir] No images found in {input_dir}")
        return

    # 2) Decide worker count (simple, CPU-friendly default)
    if workers is None:
        workers = min(8, os.cpu_count() or 2)

    def _process_one(p: Path) -> bool:
        """Process one file; return True only if the PNG was actually written."""
        # Read as 3-channel color for consistency with Keras preprocessors
        img = cv.imread(str(p), cv.IMREAD_COLOR)
        if img is None:
            # Unreadable file: skip it, but let the caller count it
            return False
        out = crop_square_resize_image(img, out_size=out_size)

        out_path = output_dir / f"{p.stem}.png"
        # PNG is lossless; compression=6 gives good size without being slow
        return bool(cv.imwrite(str(out_path), out, [int(cv.IMWRITE_PNG_COMPRESSION), 6]))

    # 3) Process (parallel if workers>1)
    if workers and workers > 1:
        from concurrent.futures import ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=workers) as ex:
            results = list(ex.map(_process_one, paths))
    else:
        results = [_process_one(p) for p in paths]

    # 4) Report what was really written instead of the number of inputs
    written = sum(results)
    failed = [p.name for p, ok in zip(paths, results) if not ok]
    print(f"[process_dir] Wrote {written} images → {output_dir}  (size={out_size}×{out_size})")
    if failed:
        print(f"[process_dir] Skipped {len(failed)} unreadable/unwritable file(s): {failed}")

## 2. Data augmentation

### 2.1. Reading file paths and labels

In [None]:
# Set OpenCV's thread count to 0.
# NOTE(review): per the OpenCV docs this disables OpenCV's own threading;
# presumably done to avoid oversubscription with the notebook's thread pool /
# tf.data parallelism — confirm the intent.
cv.setNumThreads(0)

# Utility for removing extensions to read label for the given image filename
def normalize_fname(s: str) -> str:
    """Return ``s`` as a stripped string without its last ``.suffix``.

    The value is converted with ``str()`` and surrounding whitespace is removed.
    Everything from the final dot onwards is dropped; a value without a dot is
    returned unchanged.
    """
    cleaned = str(s).strip()
    head, dot, _ = cleaned.rpartition(".")
    # rpartition yields an empty separator when there is no dot at all.
    return head if dot else cleaned
    
def load_samples_from_csv(
    csv_path: str | Path,
    images_dir: str | Path,
    image_col: str = "Image name",
    label_col: str = "Risk of macular edema",
    seed: Optional[int] = 123,
) -> List[Tuple[str, int]]:
    """
    Return a pre-shuffled list of (png_path, label) pairs, matching CSV names by stem.
    Assumes preprocessed images are stored as <stem>.png in images_dir.

    Args:
        csv_path: Annotation CSV with at least `image_col` and `label_col`.
        images_dir: Directory that holds the preprocessed PNG files.
        image_col: Column with the image file name (extension optional).
        label_col: Column with the integer label.
        seed: Seed of the one-off shuffle; `None` keeps the CSV order.

    Returns:
        List of `(path, label)` tuples for every CSV row whose PNG exists.
        Rows without a matching PNG are skipped and their count is printed.

    Raises:
        FileNotFoundError: if no CSV row has a matching PNG.
    """
    images_dir = Path(images_dir)
    samples: List[Tuple[str, int]] = []
    missing = 0

    # utf-8-sig reads plain UTF-8 and strips a BOM if present, so the first
    # header name is not prefixed with "\ufeff".
    with open(csv_path, newline="", encoding="utf-8-sig") as f:
        for row in csv.DictReader(f):
            stem = normalize_fname(row[image_col])
            png = images_dir / f"{stem}.png"
            if png.exists():
                samples.append((str(png), int(row[label_col])))
            else:
                missing += 1

    if not samples:
        raise FileNotFoundError(f"No matching .png files for {csv_path} in {images_dir}")

    if missing:
        # Make dropped rows visible instead of silently shrinking the split
        print(f"[load_samples_from_csv] Skipped {missing} CSV row(s) with no matching .png in {images_dir}")

    # Pre-shuffle the (path, label) pairs once, reproducibly
    if seed is not None:
        rng = np.random.default_rng(seed)
        samples = [samples[i] for i in rng.permutation(len(samples))]

    return samples

### 2.2. Data Generator

In [None]:
def get_img_data_gen(fill_value=0.03) -> tf.keras.preprocessing.image.ImageDataGenerator:
    """Build the Keras ``ImageDataGenerator`` used for training-time augmentation.

    Args:
        fill_value: Constant used to fill pixels exposed by a transform
            (passed as ``cval`` together with ``fill_mode="constant"``).

    Returns:
        A generator configured with rotations up to 10 degrees, horizontal and
        vertical flips and a zoom range of [1.0, 1.15], producing float32 data.
    """
    augmentation_config = {
        "rotation_range": 10,
        "horizontal_flip": True,
        "vertical_flip": True,
        "zoom_range": [1.0, 1.15],
        "fill_mode": "constant",
        "cval": fill_value,
        "dtype": np.float32,
    }
    return tf.keras.preprocessing.image.ImageDataGenerator(**augmentation_config)

# Module-level augmentation generator (default fill value); used by `idg_augment_tf`.
img_datagen = get_img_data_gen()

def load_image(path, label):
    """Read a PNG file and return it as a float32 image tensor with its label.

    The decoded image keeps its original value range (no rescaling is done
    here). The static shape is set to ``(image_size, image_size, 3)`` using the
    module-level ``image_size``.
    """
    raw_bytes = tf.io.read_file(path)
    pixels = tf.cast(tf.image.decode_png(raw_bytes, channels=3), tf.float32)
    pixels.set_shape([image_size, image_size, 3])
    return pixels, label

def idg_augment_tf(img, label):
    """Apply one random ``img_datagen`` transform to ``img`` inside a tf.data map.

    The NumPy-based transform is wrapped with ``tf.numpy_function``; because
    that drops static shape information, the shape is restored afterwards.
    The label is passed through unchanged.
    """
    def _random_transform(np_img):
        return img_datagen.random_transform(np_img).astype(np.float32)

    augmented = tf.numpy_function(_random_transform, [img], tf.float32)
    augmented.set_shape([image_size, image_size, 3])
    return augmented, label

def create_dataset_from_split(samples, augment, shuffle):
    """Build a batched, prefetched ``tf.data.Dataset`` from (path, label) pairs.

    Args:
        samples: Sequence of ``(png_path, label)`` tuples.
        augment: If true, apply ``idg_augment_tf`` after loading.
        shuffle: If true, shuffle the whole split (fixed seed 42, reshuffled
            on every iteration) before loading.

    Returns:
        Dataset yielding ``(image_batch, label_batch)`` with the module-level
        ``batch_size``.
    """
    file_paths = tf.constant([sample_path for sample_path, _ in samples])
    targets = tf.constant([sample_label for _, sample_label in samples], dtype=tf.int32)

    dataset = tf.data.Dataset.from_tensor_slices((file_paths, targets))
    if shuffle:
        dataset = dataset.shuffle(len(samples), seed=42, reshuffle_each_iteration=True)

    # Decode first, then optionally augment, each as its own parallel map.
    map_steps = [load_image]
    if augment:
        map_steps.append(idg_augment_tf)
    for step in map_steps:
        dataset = dataset.map(step, num_parallel_calls=tf.data.AUTOTUNE)

    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

## 3. Model setup

### 3.1. Hyperparameters and Train Settings

#### 3.1.1. Inverse weight and class balance

In [None]:
# Print/plot the class balance of all three splits and keep the returned
# per-class weights; they are later passed to `model.fit(class_weight=...)`.
class_weight = show_class_balance(
    train_labels=train_labels,
    valid_labels=valid_labels,
    test_labels=test_labels,
)

#### 3.1.2. Experiment settings

In [None]:
# Which encoder to train and the running number of this experiment.
architecture = "ViTBase32"
experiment_no = 6

# Input resolution (square) and mini-batch size used by the data pipeline.
image_size = 384
batch_size = 16

# Architectures built from tf.keras.applications; everything else is treated
# as a keras_hub ViT backbone by `create_classifier_model`.
cnn_architectures = {'ResNet50', 'EfficientNetV2S', 'ConvNeXtTiny', 'EfficientNetB4'}

# Per-experiment output folder. The number is zero-padded to three digits
# (6 -> "006", 12 -> "012"); makedirs also creates missing parent folders and
# is a no-op when the folder already exists.
curr_folder = f"deep_learning/working/experiments/{architecture}_{experiment_no:03d} - AUC"
os.makedirs(curr_folder, exist_ok=True)

# Shared results table plus the artefacts written into the experiment folder.
experiments_csv_path = os.path.join("deep_learning/working/experiments", "all_experiments.csv")
model_path = os.path.join(curr_folder, 'model.keras')
metadata_path = os.path.join(curr_folder, 'model.json')
log_path = os.path.join(curr_folder, 'log.csv')
log_warmup_path = os.path.join(curr_folder, 'log_warmup.csv')

#### 3.1.3. Hyperparameter settings

In [None]:
def select_optimizer(name: str, lr, wd):
    """Create the Keras optimizer selected by ``name``.

    Args:
        name: ``"Adam"`` or ``"AdamW"``.
        lr: Learning rate.
        wd: Weight decay; only used for ``"AdamW"``.

    Returns:
        A new ``tf.keras.optimizers.Adam`` or ``tf.keras.optimizers.AdamW``.

    Raises:
        ValueError: if ``name`` is not a supported optimizer (previously the
            function silently returned ``None``).
    """
    if name == "Adam":
        return tf.keras.optimizers.Adam(learning_rate=lr)
    if name == "AdamW":
        return tf.keras.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
    raise ValueError(f"Unsupported optimizer {name!r}; expected 'Adam' or 'AdamW'")

# Warm-up phase settings: this optimizer instance is the one passed to the
# first `model.compile` call (weight decay is only forwarded for "AdamW").
init_lr = 1e-3
weight_decay = 0 
optimizer_name = "AdamW"
optimizer = select_optimizer(name=optimizer_name, lr=init_lr, wd=weight_decay)
# val_monitor = ('val_binary_accuracy', 'max')
# (metric, mode) pair monitored by the LR scheduler, early stopping and checkpoint callbacks.
val_monitor = ('val_auc', 'max')
# Epoch budget of the full-training and the warm-up `fit` calls.
epochs = 30
epochs_warmup = 5
# Patience (in epochs) of ReduceLROnPlateau and EarlyStopping respectively.
reduce_lr_patience = 5
early_stopping_patience = 10

#### 3.1.4. Model fit callbacks

In [None]:
# Per-epoch CSV logs for the full-training and warm-up phases
# (append=False: an existing log file is overwritten).
csv_logger = tf.keras.callbacks.CSVLogger(log_path, separator=',', append=False)
csv_logger_warmup = tf.keras.callbacks.CSVLogger(log_warmup_path, separator=',', append=False)

# Multiply the learning rate by 0.1 when the monitored validation metric
# has not improved for `reduce_lr_patience` epochs.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    factor=0.1, 
    patience=reduce_lr_patience,
    monitor=val_monitor[0], 
    mode=val_monitor[1]
)

# Stop after `early_stopping_patience` epochs without improvement and
# restore the weights of the best epoch.
early_stopping = tf.keras.callbacks.EarlyStopping(
    patience=early_stopping_patience, 
    verbose=1,                               
    restore_best_weights=True,
    monitor=val_monitor[0], 
    mode=val_monitor[1]
)

# Save the model to `model_path` only when the monitored metric improves.
# NOTE(review): the same callback instance is passed to both the warm-up and
# the full-training `fit` calls — confirm that sharing it is intended.
checkpoint_best = tf.keras.callbacks.ModelCheckpoint(
    filepath=model_path, 
    monitor=val_monitor[0], 
    mode=val_monitor[1], 
    verbose=1, 
    save_best_only=True
)

### 3.2. Architecture selection

In [None]:
# Resolve the building blocks for the selected architecture.
# CNNs define `preprocess_input` + `encoder` (a tf.keras.applications
# constructor); ViTs define `backbone` + `preprocessor` loaded from a
# keras_hub preset. `create_classifier_model` reads these module-level names.
if architecture == 'ResNet50':
    preprocess_input = tf.keras.applications.resnet.preprocess_input
    encoder = tf.keras.applications.ResNet50
elif architecture == 'EfficientNetV2S':
    preprocess_input = tf.keras.applications.efficientnet_v2.preprocess_input
    encoder = tf.keras.applications.EfficientNetV2S
elif architecture == 'ConvNeXtTiny':
    preprocess_input = tf.keras.applications.convnext.preprocess_input
    encoder = tf.keras.applications.ConvNeXtTiny
elif architecture == 'EfficientNetB4':
    preprocess_input = tf.keras.applications.efficientnet.preprocess_input
    encoder = tf.keras.applications.EfficientNetB4
elif architecture == 'ViTBase16':
    backbone = keras_hub.models.Backbone.from_preset(
        "vit_base_patch16_384_imagenet"
    )
    preprocessor = keras_hub.models.ViTImageClassifierPreprocessor.from_preset(
        "vit_base_patch16_384_imagenet"
    )
elif architecture == 'ViTBase32':
    backbone = keras_hub.models.Backbone.from_preset(
        "vit_base_patch32_384_imagenet"
    )
    preprocessor = keras_hub.models.ViTImageClassifierPreprocessor.from_preset(
        "vit_base_patch32_384_imagenet"
    )
# elif architecture == 'DeiTBase16':
#     backbone = keras_hub.models.DeiTBackbone.from_preset(
#         "deit_base_distilled_patch16_384_imagenet"
#     )
#     preprocessor = keras_hub.models.DeiTImageClassifierPreprocessor.from_preset(
#         "deit_base_distilled_patch16_384_imagenet"
#     )
else:
    # Fail here with a clear message instead of a NameError later on.
    raise ValueError(
        f"Unknown architecture {architecture!r}; expected one of "
        "'ResNet50', 'EfficientNetV2S', 'ConvNeXtTiny', 'EfficientNetB4', "
        "'ViTBase16', 'ViTBase32'"
    )

### 3.3. Create Model

In [None]:
def create_classifier_model(trainable_encoder):
    """Assemble the binary classifier for the module-level ``architecture``.

    CNN architectures (members of ``cnn_architectures``) are built from
    ``encoder`` with ImageNet weights and global average pooling, preceded by
    a ``preprocess_input`` Lambda layer. Any other architecture uses the
    keras_hub ``preprocessor`` and ``backbone`` and takes token 0 of the
    backbone output as the feature vector. A single sigmoid unit named
    ``output`` is placed on top.

    Args:
        trainable_encoder: Whether the encoder weights are trainable. For CNNs
            ``False`` freezes every encoder layer; for the other architectures
            the value is assigned to ``backbone.trainable``.

    Returns:
        An uncompiled ``tf.keras.Model`` named ``"<architecture>_binclass"``.
    """
    input_shape = (image_size, image_size, 3)
    inputs = tf.keras.layers.Input(shape=input_shape, name='input')

    if architecture in cnn_architectures:
        preprocessed = tf.keras.layers.Lambda(preprocess_input, name="preproc")(inputs)
        cnn_backbone = encoder(
            include_top=False,
            weights='imagenet',
            input_shape=input_shape,
            input_tensor=preprocessed,
            pooling='avg',
            classes=2,
        )
        if not trainable_encoder:
            for encoder_layer in cnn_backbone.layers:
                encoder_layer.trainable = False
        features = cnn_backbone.output
    else:
        # Freeze or unfreeze the whole backbone in one go.
        backbone.trainable = trainable_encoder
        tokens = backbone(preprocessor(inputs))
        # Keep only the first token of the sequence as the image representation.
        features = tf.keras.layers.Lambda(lambda t: t[:, 0, :], name="cls_token")(tokens)

    probability = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(features)
    return tf.keras.models.Model(
        inputs=inputs, outputs=probability, name=f"{architecture}_binclass"
    )

## 4. Training

### 4.0. Preprocessing and loading samples

In [None]:
# Run the three calls below only once (on first start-up) to write the resized PNGs;
# keep them commented out afterwards to avoid unnecessary processing.
# NOTE(review): these calls write to "working/preprocessed/...", while the loaders below
# read from "deep_learning/working/preprocessed/..." — confirm the paths/working directory agree.
# save_resized_images_from_dataset_split(os.path.join(train_path, "images"), "working/preprocessed/train/images_resized", out_size=image_size, workers=8)
# save_resized_images_from_dataset_split(os.path.join(valid_path, "images"), "working/preprocessed/valid/images_resized", out_size=image_size, workers=8)
# save_resized_images_from_dataset_split(os.path.join(test_path, "images"),  "working/preprocessed/test/images_resized",  out_size=image_size, workers=8)

# Pre-shuffled (png_path, label) lists used to build the training and evaluation datasets
train_samples = load_samples_from_csv(train_labels, "deep_learning/working/preprocessed/train/images_resized")
valid_samples = load_samples_from_csv(valid_labels, "deep_learning/working/preprocessed/valid/images_resized")
test_samples = load_samples_from_csv(test_labels, "deep_learning/working/preprocessed/test/images_resized")

### 4.1. Model summary

In [None]:
# Build the model with a non-trainable encoder for the warm-up phase.
model = create_classifier_model(False)

# Binary cross-entropy with accuracy and ROC-AUC metrics; the AUC metric is
# named "auc" so that the monitored "val_auc" exists.
# NOTE(review): run_eagerly=True disables graph execution of the train step —
# confirm it is still required, since it typically slows training.
model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.binary_crossentropy,
    metrics=[tf.keras.metrics.binary_accuracy, tf.keras.metrics.AUC(name="auc", curve="ROC", num_thresholds=200)],
    jit_compile=False,
    run_eagerly=True
)

model.summary()

### 4.2. Warmup

In [None]:
# (path, label)
train_dataset = create_dataset_from_split(train_samples, augment=True, shuffle=True)
valid_dataset = create_dataset_from_split(valid_samples, augment=False, shuffle=False)

history_warmup = model.fit(
    train_dataset,
    validation_data=valid_dataset,
    epochs=epochs_warmup,
    callbacks=[csv_logger_warmup, reduce_lr, checkpoint_best],
    class_weight=class_weight,
    verbose=1,
)

### 4.3. Full training

In [None]:
# Unfreeze every top-level layer of the model except BatchNormalization layers,
# whose `trainable` flag is left as it was.
for layer in model.layers:
    if not isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.trainable = True

# Recompile with a much smaller learning rate (and weight decay for AdamW) so
# the changed `trainable` flags take effect.
# NOTE(review): unlike the first compile call, this one does not pass
# jit_compile=False / run_eagerly=True — confirm the difference is intended.
fine_tune_lr = 1e-5
weight_decay = 1e-4 if optimizer_name == 'AdamW' else 0
model.compile(optimizer=select_optimizer(name=optimizer_name, lr=fine_tune_lr, wd=weight_decay),
              loss=tf.keras.losses.binary_crossentropy,
              metrics=[tf.keras.metrics.binary_accuracy, tf.keras.metrics.AUC(name="auc", curve="ROC", num_thresholds=200)])

# Full training for up to `epochs` epochs, now with early stopping.
history_full = model.fit(
    train_dataset,
    validation_data=valid_dataset,
    epochs=epochs,
    callbacks=[csv_logger, reduce_lr, early_stopping, checkpoint_best],
    class_weight=class_weight,
    verbose=1,
)

### 4.4. Save training results

In [None]:
def plot_hist(hist, metric, title, outfile):
    """Plot the training and validation curves of one metric and save them.

    Args:
        hist: Object with a ``history`` mapping of metric name to per-epoch values.
        metric: Base metric name; ``val_<metric>`` is plotted too when present.
        title: Figure title.
        outfile: File name of the image, created inside ``curr_folder``.
    """
    plt.clf()
    recorded = hist.history
    for key, split in ((metric, "train"), (f"val_{metric}", "val")):
        if key in recorded:
            plt.plot(recorded[key], label=f"{split} {metric}")
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(metric)
    plt.grid(True, alpha=.3)
    plt.legend()
    target_file = os.path.join(curr_folder, outfile)
    plt.savefig(target_file, dpi=300, bbox_inches="tight")

# Warm-up phase curves (accuracy, loss, AUC), saved into the experiment folder
plot_hist(history_warmup, "binary_accuracy", "Binary Accuracy (warmup)", "warmup_accuracy.png")
plot_hist(history_warmup, "loss", "Loss (warmup)", "warmup_loss.png")
plot_hist(history_warmup, "auc", "AUC (warmup)", "warmup_auc.png")

# Full-training phase curves (accuracy, loss, AUC)
plot_hist(history_full, "binary_accuracy", "Binary Accuracy (full)", "full_accuracy.png")
plot_hist(history_full, "loss", "Loss (full)", "full_loss.png")
plot_hist(history_full, "auc", "AUC (full)", "full_auc.png")

## 5. Model Evaluation

### 5.0. Evaluation Imports

In [None]:
from sklearn.metrics import (
    roc_curve, roc_auc_score, 
    precision_recall_curve, average_precision_score, 
    confusion_matrix, ConfusionMatrixDisplay, 
    accuracy_score, precision_score, recall_score, f1_score
)

### 5.1. Predict and Evaluate split

In [None]:
def predict_labels_and_scores(eval_dataset):
    """Collect ground-truth labels and model scores for a dataset.

    The dataset is iterated once to gather the labels and once more by
    ``model.predict``, so it must yield its elements in the same order on both
    passes (i.e. be built without shuffling).

    Returns:
        ``(y_true, y_pred)``: a 1-D int32 label array and a 1-D float32 array
        of the model outputs with the trailing axis squeezed away.
    """
    label_batches = [batch_labels.numpy().ravel() for _, batch_labels in eval_dataset]
    y_true = np.concatenate(label_batches).astype(np.int32)

    raw_scores = model.predict(eval_dataset, verbose=0)
    y_pred = np.squeeze(raw_scores, axis=-1).astype(np.float32)
    return y_true, y_pred

def evaluate_split(y_true, y_pred, tau=0.5):
    """Compute classification metrics for one split at decision threshold ``tau``.

    Args:
        y_true: Ground-truth binary labels.
        y_pred: Predicted scores; a sample is predicted positive when its
            score is ``>= tau``.
        tau: Decision threshold.

    Returns:
        Dict with ``accuracy``, ``precision``, ``recall``, ``f1`` (at ``tau``),
        the threshold-free ``auroc`` and ``auprc``, and the 2x2 confusion
        matrix ``cm`` with label order ``[0, 1]``.
    """
    y_true = np.asarray(y_true).astype(int).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float32).ravel()

    pred_labels = (y_pred >= float(tau)).astype(int)

    cm = confusion_matrix(y_true, pred_labels, labels=[0,1])
    # zero_division=0 on all three ratio metrics: an undefined ratio is
    # reported as 0 without raising an UndefinedMetricWarning.
    metrics = {
        "accuracy":  accuracy_score(y_true, pred_labels),
        "precision": precision_score(y_true, pred_labels, zero_division=0),
        "recall":    recall_score(y_true, pred_labels, zero_division=0),
        "f1":        f1_score(y_true, pred_labels, zero_division=0),
        "auroc":     roc_auc_score(y_true, y_pred),
        "auprc":     average_precision_score(y_true, y_pred),
        "cm":        cm,
    }
    return metrics

### 5.2. Save Experiment and Evaluation results

In [None]:
def as_pct(x, nd=2):
    """Render a fraction as a percentage string with ``nd`` decimals.

    Example: ``0.9123`` becomes ``'91.23%'``. Magnitudes below 1e-12 are
    snapped to zero so that a tiny negative value does not print as ``-0.00%``.
    """
    value = float(np.asarray(x))
    value = 0.0 if abs(value) < 1e-12 else value
    return format(value * 100, f".{nd}f") + "%"

def rnd(x, nd=4):
    """Round ``x`` to ``nd`` decimals using float64 arithmetic.

    Scalars come back as a plain Python ``float`` (convenient for CSV output);
    array-like input comes back as a (nested) list.
    """
    rounded = np.round(np.asarray(x, dtype=np.float64), nd)
    if np.ndim(rounded) == 0:
        return float(rounded)
    return rounded.tolist()

def save_roc_curve(y_true, y_pred, title, out_path_png):
    """Draw the ROC curve for the given labels/scores and save it as an image.

    The legend shows the AUROC; a dashed grey diagonal marks chance level.
    The current pyplot figure is cleared before drawing and closed afterwards.
    """
    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, y_pred)
    area = roc_auc_score(y_true, y_pred)

    plt.clf()
    plt.plot(false_pos_rate, true_pos_rate, lw=2, label=f"AUROC={area:.4f}")
    plt.plot([0, 1], [0, 1], "--", lw=1, color="gray")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(title)
    plt.grid(True, alpha=.3)
    plt.legend(loc="lower right")
    plt.savefig(out_path_png, dpi=300, bbox_inches="tight")
    plt.close()

def _save_cm_heatmap(values, cell_text, title, cbar_label, class_names, out_png, **imshow_kwargs):
    """Draw one annotated 2x2 confusion-matrix heatmap and save it to ``out_png``."""
    fig, ax = plt.subplots(figsize=(10, 8.5))
    im = ax.imshow(values, cmap="Blues", **imshow_kwargs)
    ax.set_xticks([0, 1])
    ax.set_yticks([0, 1])
    ax.set_xticklabels(list(class_names))
    ax.set_yticklabels(list(class_names))
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    ax.set_title(title)
    for (i, j), v in np.ndenumerate(values):
        # Row = true class, column = predicted class
        tag = ("TN", "FP", "FN", "TP")[(i * 2) + j]
        ax.text(j, i, f"{cell_text(v)}\n{tag}", va="center", ha="center", color="black")
    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.ax.set_ylabel(cbar_label, rotation=90, va="bottom")
    fig.tight_layout()
    fig.savefig(out_png, dpi=300, bbox_inches="tight")
    plt.close(fig)


def save_confusion_matrix_results(
    cm, 
    class_names=("negative","positive"),
    out_counts_png=None, 
    out_rowperc_png=None
):
    """
    Save up to two renderings of a 2x2 confusion matrix:
    1) counts, titled 'Misclassified X out of N case studies'
    2) row-normalized %, titled 'Classification accuracy Y%'

    Args:
        cm: 2x2 confusion matrix with rows = true class, columns = predicted
            class, in label order [negative, positive].
        class_names: Tick labels for the two classes (now actually applied;
            the default reproduces the previously hard-coded labels).
        out_counts_png: Output path of the counts figure; skipped when falsy.
        out_rowperc_png: Output path of the percentage figure; skipped when falsy.
    """
    tn, fp, fn, tp = cm.ravel()
    total = cm.sum()
    # Guard the empty matrix so the title does not become 'nan%'
    acc = (tn + tp) / total if total else 0.0
    miscls = total - (tn + tp)

    if out_counts_png:
        _save_cm_heatmap(
            cm,
            cell_text=lambda v: f"{v}",
            title=f"Misclassified {miscls} out of {total} case studies",
            cbar_label="count",
            class_names=class_names,
            out_png=out_counts_png,
        )

    if out_rowperc_png:
        # clip(min=1) keeps an empty true-class row at 0% instead of dividing by zero
        row_sum = cm.sum(axis=1, keepdims=True).clip(min=1)
        rowperc = (cm / row_sum) * 100.0
        _save_cm_heatmap(
            rowperc,
            cell_text=lambda v: f"{v:.1f}%",
            title=f"Classification accuracy {as_pct(acc)}",
            cbar_label="% within true class",
            class_names=class_names,
            out_png=out_rowperc_png,
            vmin=0,
            vmax=100,
        )

def append_row_to_experiments_csv(csv_path, row_dict, column_order):
    """
    Append one experiment row to CSV. If the file doesn't exist yet, or exists
    but is empty, the header is written first.
    `row_dict` keys must match `column_order` (csv.DictWriter raises ValueError
    on unexpected keys).
    """
    # An existing zero-byte file still needs its header
    needs_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
    with open(csv_path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=column_order)
        if needs_header:
            writer.writeheader()
        writer.writerow(row_dict)

### 5.3. Create datasets for evaluation and load model

In [None]:
# Optional: evaluate the checkpoint saved at `model_path` instead of the in-memory model.
# model = tf.keras.models.load_model(filepath=model_path, compile=False, safe_mode=False)

# Evaluation datasets: no augmentation and no shuffling, so labels and
# predictions stay aligned.
eval_ds_train = create_dataset_from_split(train_samples, augment=False, shuffle=False)
eval_ds_valid = create_dataset_from_split(valid_samples, augment=False, shuffle=False)
eval_ds_test  = create_dataset_from_split(test_samples,  augment=False, shuffle=False)

# Ground-truth labels and predicted scores per split
y_true_valid, y_pred_valid = predict_labels_and_scores(eval_ds_valid)
y_true_test, y_pred_test = predict_labels_and_scores(eval_ds_test)
y_true_train, y_pred_train = predict_labels_and_scores(eval_ds_train)

### 5.4. Get evaluation metrics

In [None]:
# Metrics per split at the default decision threshold of `evaluate_split` (tau=0.5)
valid_metrics = evaluate_split(y_true=y_true_valid, y_pred=y_pred_valid)
test_metrics = evaluate_split(y_true=y_true_test, y_pred=y_pred_test)
train_metrics = evaluate_split(y_true=y_true_train, y_pred=y_pred_train)

def brief(m):
    """Format the scalar metrics of a metrics dict as one line (4 decimals each)."""
    shown = (
        ("acc", "accuracy"),
        ("auroc", "auroc"),
        ("auprc", "auprc"),
        ("prec", "precision"),
        ("rec", "recall"),
        ("f1", "f1"),
    )
    return "  ".join(f"{tag}={m[key]:.4f}" for tag, key in shown)

# One-line metric summary per split
print("\nVALID @ τ=0.5:", brief(valid_metrics))
print("TEST @ τ=0.5:", brief(test_metrics))
print("TRAIN @ τ=0.5:", brief(train_metrics))

### 5.5. Save evaluation results

In [None]:
# ROC curves for the validation and test splits, saved into the experiment folder
save_roc_curve(y_true_valid, y_pred_valid, "Validation ROC", os.path.join(curr_folder, "roc_val.png"))
save_roc_curve(y_true_test, y_pred_test, "Test ROC", os.path.join(curr_folder, "roc_test.png"))

# Confusion matrices (raw counts and row-normalized percentages) — validation
save_confusion_matrix_results(
    valid_metrics["cm"],
    out_counts_png=os.path.join(curr_folder, "cm_val_counts.png"),
    out_rowperc_png=os.path.join(curr_folder, "cm_val_rowperc.png"),
)

# Confusion matrices (raw counts and row-normalized percentages) — test
save_confusion_matrix_results(
    test_metrics["cm"],
    out_counts_png=os.path.join(curr_folder, "cm_test_counts.png"),
    out_rowperc_png=os.path.join(curr_folder, "cm_test_rowperc.png"),
)

### 5.6. Save experiment run

In [None]:
# warmup_freeze_policy   = "only output trainable"
# finetune_freeze_policy = "all layers trainable, BN frozen"
# Flag recorded in the results table: architectures listed here are marked as
# containing BatchNormalization layers.
has_bn = architecture in {"ResNet50", "EfficientNetV2S", "EfficientNetB4"}
# bn_frozen_warmup   = "yes" if has_bn else "n/a"
# bn_frozen_finetune = "yes"

# One row of the shared experiments table. `weight_decay` holds whatever the
# module-level variable contains at this point (it is reassigned before the
# fine-tuning compile call).
csv_row = {
    # run context
    "experiment_no": experiment_no,
    "architecture": architecture,
    "image_size": image_size,
    "batch_size": batch_size,
    "init_lr": init_lr,
    "finetune_lr": fine_tune_lr,
    "weight_decay": weight_decay,
    "optimizer": optimizer_name,
    # "warmup_freeze": warmup_freeze_policy,
    # "finetune_freeze": finetune_freeze_policy,
    "has_bn": "yes" if has_bn else "no",
    # "bn_frozen_warmup": bn_frozen_warmup,
    # "bn_frozen_finetune": bn_frozen_finetune,

    # train/val (only AUC + accuracy for the sheet)
    "train_acc": as_pct(train_metrics["accuracy"]),
    "train_auroc": rnd(train_metrics["auroc"], 4),
    "val_acc": as_pct(valid_metrics["accuracy"]),
    "val_auroc": rnd(valid_metrics["auroc"], 4),

    # test (full set; rounded/percent where it helps)
    "test_acc":  as_pct(test_metrics["accuracy"]),
    "test_auroc": rnd(test_metrics["auroc"], 4),
    "test_auprc": rnd(test_metrics["auprc"], 4),
    "test_prec": as_pct(test_metrics["precision"]),
    "test_rec":  as_pct(test_metrics["recall"]),
    "test_f1":   as_pct(test_metrics["f1"]),
    "tau_used": 0.5,
}

# Column order of the CSV; must list exactly the keys of `csv_row`.
csv_columns = [
    "experiment_no","architecture","image_size","batch_size","init_lr","finetune_lr","weight_decay","optimizer","has_bn",
    "train_acc","train_auroc","val_acc","val_auroc",
    "test_acc","test_auroc","test_auprc","test_prec","test_rec","test_f1","tau_used",
]

append_row_to_experiments_csv(experiments_csv_path, csv_row, csv_columns)
print(f"\nAppended metrics → {experiments_csv_path}")

In [None]:
from sklearn.metrics import (
    roc_curve, roc_auc_score, precision_recall_curve, average_precision_score,
    confusion_matrix, ConfusionMatrixDisplay,
    accuracy_score, precision_score, recall_score, f1_score
)

#### 5.2. Predict, threshold and evaluation on datasets

In [None]:
def predict_labels_and_scores(eval_dataset):
    """Return ground-truth labels and model scores for ``eval_dataset``.

    Labels are gathered by iterating the dataset, scores by a separate
    ``model.predict`` pass, so the dataset must keep a stable element order
    (no shuffling).

    Returns:
        ``(y_true, y_pred)``: a 1-D int32 label array and a 1-D float32 array
        of the model outputs with the trailing axis squeezed away.
    """
    gathered = [targets.numpy().ravel() for _, targets in eval_dataset]
    y_true = np.concatenate(gathered).astype(np.int32)

    outputs = model.predict(eval_dataset, verbose=0)
    y_pred = np.squeeze(outputs, axis=-1).astype(np.float32)
    return y_true, y_pred

def get_threshold_for_sensitivity(y_true, y_pred, target_sens=0.90):
    """Pick the first ROC threshold whose true-positive rate reaches ``target_sens``.

    Returns:
        ``(tau, (fpr, tpr, thresholds))`` where ``tau`` is the threshold at the
        first ROC point with ``tpr >= target_sens`` (1.0 if no point qualifies)
        and the tuple holds the arrays returned by ``roc_curve``.
    """
    fp_rate, tp_rate, thresholds = roc_curve(y_true, y_pred)
    qualifying = np.flatnonzero(tp_rate >= target_sens)
    if qualifying.size:
        tau_sens = float(thresholds[qualifying[0]])
    else:
        tau_sens = 1.0
    return tau_sens, (fp_rate, tp_rate, thresholds)

def evaluate_split_at_threshold(y_true, y_pred, tau):
    """
    Compute standard metrics at a fixed threshold τ:
    - accuracy, precision, recall, F1
    - AUROC (threshold-free), AUPRC (threshold-free)
    - confusion_matrix (2x2 for [neg,pos])

    A sample is predicted positive when its score is >= τ. Undefined
    precision/recall/F1 ratios are reported as 0 (zero_division=0) without
    raising an UndefinedMetricWarning.

    Returns a metrics dict with keys
    accuracy, precision, recall, f1, auroc, auprc, cm.
    """
    y_true = np.asarray(y_true).astype(int).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float32).ravel()

    pred_labels = (y_pred >= float(tau)).astype(int)
    cm = confusion_matrix(y_true, pred_labels, labels=[0, 1])
    metrics = {
        "accuracy":  accuracy_score(y_true, pred_labels),
        "precision": precision_score(y_true, pred_labels, zero_division=0),
        "recall":    recall_score(y_true, pred_labels, zero_division=0),
        "f1":        f1_score(y_true, pred_labels, zero_division=0),
        "auroc":     roc_auc_score(y_true, y_pred),
        "auprc":     average_precision_score(y_true, y_pred),
        "cm":        cm,
    }
    return metrics

#### 5.3. Helper methods for plotting and saving results

In [None]:
def save_confusion_matrix(cm, title, out_path_png):
    """Plot a confusion matrix with integer cell values and write it to a PNG."""
    fig, ax = plt.subplots(figsize=(5, 5))
    display = ConfusionMatrixDisplay(cm, display_labels=[0, 1])
    display.plot(cmap="Blues", values_format="d", ax=ax, colorbar=False)
    ax.set_title(title)
    fig.savefig(out_path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)

def save_roc_curve(true_labels, pred_scores, title, out_path_png):
    """Plot the ROC curve (legend shows the AUROC) and write it to a PNG.

    A dashed grey diagonal marks chance level. The current pyplot figure is
    cleared before drawing and closed afterwards.
    """
    false_pos_rate, true_pos_rate, _ = roc_curve(true_labels, pred_scores)
    area = roc_auc_score(true_labels, pred_scores)

    plt.clf()
    plt.plot(false_pos_rate, true_pos_rate, lw=2, label=f"AUROC={area:.4f}")
    plt.plot([0, 1], [0, 1], "--", lw=1, color="gray")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(title)
    plt.grid(True, alpha=.3)
    plt.legend(loc="lower right")
    plt.savefig(out_path_png, dpi=300, bbox_inches="tight")
    plt.close()

def append_row_to_experiments_csv(csv_path, row_dict, column_order):
    """
    Append one experiment row to CSV. If the file doesn't exist yet, or exists
    but is empty, the header is written first.
    `row_dict` keys must match `column_order` (csv.DictWriter raises ValueError
    on unexpected keys).
    """
    # An existing zero-byte file still needs its header
    needs_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
    with open(csv_path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=column_order)
        if needs_header:
            writer.writeheader()
        writer.writerow(row_dict)

#### 5.4. Full Evaluation 

In [None]:
# Load the test samples from the same preprocessed tree ("deep_learning/working/...")
# that the train/valid samples are loaded from; the path previously lacked that prefix.
test_samples = load_samples_from_csv(test_labels, "deep_learning/working/preprocessed/test/images_resized")

# Evaluation datasets: no augmentation, no shuffling (keeps labels and scores aligned)
eval_ds_train = create_dataset_from_split(train_samples, augment=False, shuffle=False)
eval_ds_valid = create_dataset_from_split(valid_samples, augment=False, shuffle=False)
eval_ds_test  = create_dataset_from_split(test_samples,  augment=False, shuffle=False)

labels_valid, scores_valid = predict_labels_and_scores(eval_ds_valid)
labels_test, scores_test  = predict_labels_and_scores(eval_ds_test)
labels_train, scores_train = predict_labels_and_scores(eval_ds_train)

# Operating threshold: chosen on the VALIDATION split only, at the first ROC
# point whose sensitivity reaches the target.
target_sensitivity = 0.90
tau_target_sensitivity, (fpr_v, tpr_v, thr_v) = get_threshold_for_sensitivity(
    labels_valid, scores_valid, target_sensitivity
)
print(f"[Operating τ] First TPR ≥ {target_sensitivity:.2f} → τ = {tau_target_sensitivity:.6f}")

In [None]:
# Metrics of every split at the operating threshold selected on the validation split
metrics_valid = evaluate_split_at_threshold(labels_valid, scores_valid, tau_target_sensitivity)
metrics_test  = evaluate_split_at_threshold(labels_test,  scores_test,  tau_target_sensitivity)
metrics_train = evaluate_split_at_threshold(labels_train, scores_train, tau_target_sensitivity)

# ROC curves (validation and test), saved into the experiment folder
save_roc_curve(labels_valid, scores_valid, "Validation ROC", os.path.join(curr_folder, "roc_val.png"))
save_roc_curve(labels_test, scores_test, "Test ROC", os.path.join(curr_folder, "roc_test.png"))

# Confusion matrices at the operating threshold (validation and test)
save_confusion_matrix(metrics_valid["cm"], f"Valid - Risk of Macular Edema @ τ={tau_target_sensitivity:.4f} (sens≥{int(target_sensitivity*100)}%)", os.path.join(curr_folder, "cm_val_sens.png"))
save_confusion_matrix(metrics_test["cm"], f"Test - Risk of Macular Edema @ τ={tau_target_sensitivity:.4f} (sens≥{int(target_sensitivity*100)}%)", os.path.join(curr_folder, "cm_test_sens.png"))

def fmt(m):
    """Render the headline metrics of one split as a single summary line.

    Args:
        m: Mapping that provides numeric values under the keys "accuracy",
            "auroc", "auprc", "precision", "recall" and "f1". Any other keys
            are ignored.

    Returns:
        A string of ``label=value`` pairs, each value fixed to four decimals
        and the pairs separated by two spaces.

    Raises:
        KeyError: If one of the expected keys is missing from ``m``.
    """
    # (label shown in the output, key looked up in the metrics mapping)
    shown = (
        ("acc", "accuracy"),
        ("auroc", "auroc"),
        ("auprc", "auprc"),
        ("prec", "precision"),
        ("rec", "recall"),
        ("f1", "f1"),
    )
    return "  ".join(f"{label}={m[key]:.4f}" for label, key in shown)
# Console summary: one line of headline metrics per split.
print("\nVALID @ τ_sens:", fmt(metrics_valid))
print("TEST  @ τ_sens:", fmt(metrics_test))
print("TRAIN @ τ_sens:", fmt(metrics_train))

# One experiments.csv row: run context first, then the same six metrics for
# each split in the order train -> val -> test.
csv_row = {
    # Run context
    "architecture": architecture,
    "image_size": image_size,
    "batch_size": batch_size,
    "init_lr": init_lr,
    "optimizer": "Adam",
    "target_sensitivity": target_sensitivity,
    # τ chosen on VALIDATION to hit the sensitivity target
    "operating_tau": tau_target_sensitivity,
    # Per-split metrics, column names of the form "<split>_<short metric name>":
    #   train - evaluated without augmentation
    #   val   - the threshold-selection split
    #   test  - held-out, honest report
    **{
        f"{split_name}_{short_name}": split_metrics[metric_key]
        for split_name, split_metrics in (
            ("train", metrics_train),
            ("val", metrics_valid),
            ("test", metrics_test),
        )
        for short_name, metric_key in (
            ("acc", "accuracy"),
            ("auroc", "auroc"),
            ("auprc", "auprc"),
            ("prec", "precision"),
            ("rec", "recall"),
            ("f1", "f1"),
        )
    },
}

# Column order is the row's insertion order: context columns, then train_*,
# val_* and test_* metrics.
csv_columns = list(csv_row)

append_row_to_experiments_csv(experiments_csv_path, csv_row, csv_columns)
print(f"\nAppended metrics → {experiments_csv_path}")

### OLD TESTING

In [None]:
# Evaluate model
# NOTE(review): legacy cell -- DataGenerator, train_images_orig, val_images,
# test_images, eval_path and join_test_with_train are not defined in the
# visible part of the file; confirm they still exist before running this.
# Accuracies default to 0 so the bookkeeping after this block always has a
# value, including for the test split when it is skipped below.
train_acc = 0
val_acc = 0
test_acc = 0
with open(eval_path, 'w') as f:
    # For each split: evaluate without augmentation, echo the raw result and
    # record res[0] as the loss and res[1] as the accuracy.
    # NOTE(review): this assumes accuracy is the first compiled metric -- confirm.
    eval_gen_train = DataGenerator(images=train_images_orig, image_classes=train_classes_orig, use_augmentation=False)
    res = model.evaluate(eval_gen_train)
    print('EVAL TRAIN:', res)
    f.write('Train loss: {}\r\n'.format(res[0]))
    f.write('Train acc: {}\r\n'.format(res[1]))
    train_acc = res[1]

    eval_gen_valid = DataGenerator(images=val_images, image_classes=val_classes, use_augmentation=False)
    res = model.evaluate(eval_gen_valid)
    print('EVAL VAL:', res)
    f.write('Val loss: {}\r\n'.format(res[0]))
    f.write('Val acc: {}\r\n'.format(res[1]))
    val_acc = res[1]

    # The test split is evaluated only when join_test_with_train is falsy;
    # otherwise test_acc keeps its default of 0.
    if not join_test_with_train:
        eval_gen_test = DataGenerator(images=test_images, image_classes=test_classes, use_augmentation=False)
        res = model.evaluate(eval_gen_test)
        print('EVAL TEST:', res)
        f.write('Test loss: {}\r\n'.format(res[0]))
        f.write('Test acc: {}\r\n'.format(res[1]))
        test_acc = res[1]

# Write entry in experiments.csv
# The previous format string declared 15 placeholders (two of them fused
# without a separating comma) for only 11 values, so str.format raised
# IndexError before anything was written. csv.writer emits exactly one field
# per value and quotes any value that itself contains a comma.
# NOTE(review): confirm these 11 columns match the header of experiments_file.
experiment_row = [
    experiment_no,
    image_size,
    architecture,
    data_augmentation,
    init_lr,
    optimizer_name,
    monitor_loss,
    batch_size,
    train_acc,
    val_acc,
    test_acc,
]
# newline='' stops the text layer from translating the explicit '\r\n'
# terminator a second time.
with open(experiments_file, 'a', newline='') as f:
    csv.writer(f, lineterminator='\r\n').writerow(experiment_row)

# Rename current folder so that its name carries the test accuracy, and keep
# curr_folder pointing at the renamed directory for any later output.
curr_folder_new = curr_folder + ' ({})'.format(test_acc)
os.rename(curr_folder, curr_folder_new)
curr_folder = curr_folder_new

In [None]:
# Fine-tuning stage: build a fresh classifier, restore the warmup weights and
# continue training with a new optimizer.
# NOTE(review): the meaning of the positional True is defined by
# create_classifier_model (presumably it unfreezes the backbone) -- confirm.
model_ft = create_classifier_model(True, augmenter)

# Load warmup weights
model_ft.load_weights(warmup_weights_path)

# New AdamW optimizer for fine-tuning; weight decay is taken from init_decay.
# NOTE(review): intended to use a smaller LR than warmup -- confirm against init_lr.
finetune_lr = 3e-4
optimizer_ft = tf.keras.optimizers.AdamW(learning_rate=finetune_lr, weight_decay=init_decay)

# Re-compile (same loss/metrics for fairness)
# The function-style metric is what produces the 'binary_accuracy' /
# 'val_binary_accuracy' history keys plotted in the next cell.
model_ft.compile(
    optimizer=optimizer_ft,
    loss=tf.keras.losses.binary_crossentropy,
    metrics=[tf.keras.metrics.binary_accuracy]
    # Tip: add AUCs when you’re ready:
    # metrics=[tf.keras.metrics.BinaryAccuracy(name="binary_accuracy"),
    #          tf.keras.metrics.AUC(name="auc"),
    #          tf.keras.metrics.AUC(curve="PR", name="auprc")]
)

# Continue training to total epochs (warmup + finetune = epochs)
# remaining_epochs is clamped at 0, so no fine-tuning epochs are requested
# when epochs <= epochs_warmup.
remaining_epochs = max(0, epochs - epochs_warmup)
history_ft = model_ft.fit(
    train_ds,
    validation_data=valid_ds,
    initial_epoch=epochs_warmup,    # resume epoch numbering after warmup; `epochs` below is the final epoch index, not a count
    epochs=epochs_warmup + remaining_epochs,
    callbacks=[csv_logger, reduce_lr, early_stopping],
    class_weight=class_weight,
    verbose=1
)

# Keep the fine-tuned model as your “current” model
model = model_ft

In [None]:
# Plot training vs. validation binary accuracy from the fine-tuning history
# and save the figure into the current run folder. The curves carry no labels
# or legend: the first line drawn is training, the second is validation.
plt.clf()
plt.plot(history_ft.history['binary_accuracy'])
plt.plot(history_ft.history['val_binary_accuracy'])
plt.savefig(os.path.join(curr_folder, 'accuracy.png'), dpi=300)

In [None]:
from sklearn.metrics import (
    roc_auc_score, average_precision_score, precision_recall_curve,
    roc_curve, f1_score, confusion_matrix, classification_report
)

In [None]:
def collect_probs_and_labels(model, ds):
    """Gather ground-truth labels and predicted scores over a batched dataset.

    Args:
        model: Object whose ``predict(batch, verbose=0)`` returns an array of
            scores for one batch (assumed to be sigmoid probabilities --
            confirm against the model's output layer).
        ds: Iterable of ``(features, labels)`` batches; each labels object
            must expose ``.numpy()``.

    Returns:
        Tuple ``(y_true, y_prob)`` of flat arrays: labels cast to ``int`` and
        the corresponding predicted scores, concatenated in iteration order.

    Raises:
        ValueError: If ``ds`` yields no batches (nothing to concatenate).
    """
    # NOTE(review): Keras documents predict() as not intended for per-batch
    # loops; calling the model directly may be faster -- left as is here.
    per_batch = [
        (
            model.predict(features, verbose=0).reshape(-1),   # scores first, as before
            targets.numpy().reshape(-1),
        )
        for features, targets in ds
    ]
    labels = np.concatenate([batch_labels for _, batch_labels in per_batch]).astype(int)
    scores = np.concatenate([batch_scores for batch_scores, _ in per_batch])
    return labels, scores

# Score the validation split and derive two candidate decision thresholds:
# the F1-maximising one and the one meeting a sensitivity (recall) target.
# NOTE(review): other cells name the validation dataset valid_ds /
# eval_ds_valid -- confirm that valid_dataset is actually defined.
y_val, p_val = collect_probs_and_labels(model, valid_dataset)

auroc  = roc_auc_score(y_val, p_val)
auprc  = average_precision_score(y_val, p_val)   # PR-AUC
print(f"Val AUROC: {auroc:.4f} | Val AUPRC: {auprc:.4f}")

# F1-max threshold
prec, rec, th = precision_recall_curve(y_val, p_val)
f1s = 2*prec*rec / (prec+rec + 1e-12)   # epsilon avoids 0/0 where prec = rec = 0
ix_f1 = np.nanargmax(f1s)
# prec[i] / rec[i] describe predictions with score >= th[i]; only the final
# (precision=1, recall=0) point has no threshold. Index th directly and clamp,
# rather than shifting by one, so tau_f1 is the threshold that actually
# produces the F1 / P / R values printed below.
tau_f1 = th[min(ix_f1, th.size - 1)]
print(f"Tau_F1 (val): {tau_f1:.4f}, F1={f1s[ix_f1]:.4f}, P={prec[ix_f1]:.4f}, R={rec[ix_f1]:.4f}")

# Sensitivity (recall) target, e.g., ≥0.90
target_recall = 0.90
# Recall only falls as the threshold rises, so the useful operating point is
# the LARGEST threshold that still reaches the target. The smallest one always
# qualifies trivially, by flagging nearly everything as positive.
candidates = th[rec[:-1] >= target_recall]   # align lengths
# Fallback: the lowest threshold has the highest recall available.
tau_rec = candidates.max() if candidates.size else th[0]
print(f"Tau_sens≥{target_recall:.0%} (val): {tau_rec:.4f}")

def evaluate_at_threshold(y, p, tau, title=""):
    """Print confusion-matrix based metrics for scores thresholded at ``tau``.

    Args:
        y: Array of ground-truth binary labels (0/1).
        p: Array of predicted scores, same length as ``y``; must support
            ``p >= tau`` element-wise (e.g. a NumPy array).
        tau: Decision threshold; a sample is predicted positive when its
            score is ``>= tau``.
        title: Optional heading printed before the metrics. When empty, a
            heading showing the threshold is printed instead.

    Returns:
        None. All results are printed.
    """
    y_hat = (p >= tau).astype(int)
    # Pin the label set so the matrix is always 2x2. Without it, a split where
    # labels and predictions contain a single class yields a 1x1 matrix and
    # the four-way unpacking below raises ValueError.
    cm = confusion_matrix(y, y_hat, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    print(title or f"@tau={tau:.4f}")
    print("Confusion matrix [[TN, FP],[FN, TP]]:\n", cm)
    print(f"Accuracy={((tp+tn)/cm.sum()):.4f}")
    # The 1e-12 terms keep the ratios defined when a denominator is zero.
    print(f"Sensitivity/Recall={tp/(tp+fn+1e-12):.4f}  |  Specificity={tn/(tn+fp+1e-12):.4f}")
    print(f"Precision={tp/(tp+fp+1e-12):.4f}  |  F1={f1_score(y, y_hat):.4f}")
    print(classification_report(y, y_hat, digits=4))

# Report validation metrics at the F1-maximising threshold (tau_f1).
evaluate_at_threshold(y_val, p_val, tau_f1, title="Validation @ F1-max τ")