In [6]:
"""
Super-clean pipeline: EfficientNetV2 + ResNet50V2 + SimpleCNN ensemble
Includes:
- INPUT_SIZE = (224,224)
- Optional CLAHE offline preprocessing (guarded)
- tf.data pipeline with correct cache/shuffle/prefetch order
- Data augmentation
- Warmup + CosineDecay LR schedule
- AdamW optimizer
- EMA (Exponential Moving Average) callback
- Per-model checkpoint folders (.keras SavedModel)
- TTA (with augmentation forced during TTA)
- PSO for ensemble weight search (configurable particles/iters)
- Logging to CSV and TensorBoard

How to use:
- Fill DATA_DIR with train/val subfolders
- Set USE_CLAHE_PREPROCESS=True if you want CLAHE applied (will write to dataset_clahe once)
- Run: python super_pipeline.py
"""

import os
import time
import math
import json
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.metrics import accuracy_score, classification_report
import cv2

# ---------------------- USER CONFIG ----------------------
DATA_DIR = "RetinalOCT_Dataset"  # root dir containing 'train' and 'val' subfolders
PREPROCESSED_DIR = "dataset_clahe"
INPUT_SIZE = (224, 224)
BATCH_SIZE = 32
NUM_CLASSES = 8
SEED = 42
EPOCHS = 50
MODEL_DIR = "models"
LOG_DIR = "logs"
USE_CLAHE_PREPROCESS = True  # if True, CLAHE will be applied once and saved to PREPROCESSED_DIR
USE_EMA = True
AUTOTUNE = tf.data.AUTOTUNE
PSO_PARTICLES = 20
PSO_ITERS = 150
TTA_STEPS = 5
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)

# ---------------------- UTILITIES ----------------------

def list_classes(path):
    return sorted([d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))])

# ---------------------- CLAHE PREPROCESS (offline, guarded) ----------------------

def apply_clahe_to_folder(src_root, dst_root, size=INPUT_SIZE):
    if os.path.exists(dst_root) and any(os.scandir(dst_root)):
        print(f"CLAHE destination '{dst_root}' exists and not empty. Skipping CLAHE.")
        return
    print("Running CLAHE preprocessing (this may take a while)...")
    os.makedirs(dst_root, exist_ok=True)
    for split in ["train", "val"]:
        src_split = os.path.join(src_root, split)
        dst_split = os.path.join(dst_root, split)
        if not os.path.exists(src_split):
            print(f"Warning: {src_split} not found. Skipping.")
            continue
        for class_name in os.listdir(src_split):
            src_cls = os.path.join(src_split, class_name)
            dst_cls = os.path.join(dst_split, class_name)
            os.makedirs(dst_cls, exist_ok=True)
            for fname in os.listdir(src_cls):
                src_path = os.path.join(src_cls, fname)
                dst_path = os.path.join(dst_cls, fname)
                img = cv2.imread(src_path)
                if img is None:
                    continue
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img = cv2.resize(img, size)
                lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
                l, a, b = cv2.split(lab)
                clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
                cl = clahe.apply(l)
                limg = cv2.merge((cl, a, b))
                out = cv2.cvtColor(limg, cv2.COLOR_LAB2RGB)
                cv2.imwrite(dst_path, cv2.cvtColor(out, cv2.COLOR_RGB2BGR))
    print("CLAHE preprocessing finished.")

# ---------------------- DATA PIPELINE ----------------------
if USE_CLAHE_PREPROCESS:
    apply_clahe_to_folder(DATA_DIR, PREPROCESSED_DIR)
    DATA_DIR_USED = PREPROCESSED_DIR
else:
    DATA_DIR_USED = DATA_DIR

# sanity check classes
train_root = os.path.join(DATA_DIR_USED, 'train')
if not os.path.exists(train_root):
    raise FileNotFoundError(f"Train directory not found: {train_root}")

classes = list_classes(train_root)
print(f"Detected classes: {classes}")

# dataset loaders
train_ds = keras.preprocessing.image_dataset_from_directory(
    os.path.join(DATA_DIR_USED, 'train'),
    image_size=INPUT_SIZE,
    batch_size=BATCH_SIZE,
    label_mode='int',
    seed=SEED
)
val_ds = keras.preprocessing.image_dataset_from_directory(
    os.path.join(DATA_DIR_USED, 'val'),
    image_size=INPUT_SIZE,
    batch_size=BATCH_SIZE,
    label_mode='int',
    shuffle=False
)

# recommended order: shuffle -> map -> cache -> prefetch
train_ds = train_ds.shuffle(1000, seed=SEED).cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

# compute steps_per_epoch safely
try:
    steps_per_epoch = int(tf.data.experimental.cardinality(train_ds).numpy())
    if steps_per_epoch <= 0:
        steps_per_epoch = None
except Exception:
    steps_per_epoch = None

print(f"steps_per_epoch (estimate): {steps_per_epoch}")

# augmentation pipeline
data_augmentation = keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.08),
    layers.RandomZoom(0.08),
    layers.RandomContrast(0.12),
], name='data_augmentation')

# ---------------------- MODEL BUILDERS ----------------------
from tensorflow.keras.applications import efficientnet_v2, resnet_v2


def build_effnet(input_shape=(*INPUT_SIZE,3), num_classes=NUM_CLASSES, fine_tune_at=120):
    try:
        base = efficientnet_v2.EfficientNetV2B0(include_top=False, weights='imagenet', input_shape=input_shape)
        preprocess = efficientnet_v2.preprocess_input
    except Exception:
        base = keras.applications.EfficientNetB0(include_top=False, weights='imagenet', input_shape=input_shape)
        preprocess = keras.applications.efficientnet.preprocess_input

    base.trainable = True
    for layer in base.layers[:-fine_tune_at]:
        layer.trainable = False

    inputs = keras.Input(shape=input_shape)
    x = data_augmentation(inputs)
    x = preprocess(x)
    x = base(x, training=False)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dense(512, activation='relu')(x)
    x = layers.Dropout(0.4)(x)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = keras.Model(inputs, outputs, name='EffNetV2')
    return model


def build_resnet(input_shape=(*INPUT_SIZE,3), num_classes=NUM_CLASSES, fine_tune_at=80):
    base = resnet_v2.ResNet50V2(include_top=False, weights='imagenet', input_shape=input_shape)
    preprocess = resnet_v2.preprocess_input

    base.trainable = True
    for layer in base.layers[:-fine_tune_at]:
        layer.trainable = False

    inputs = keras.Input(shape=input_shape)
    x = data_augmentation(inputs)
    x = preprocess(x)
    x = base(x, training=False)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dense(512, activation='relu')(x)
    x = layers.Dropout(0.4)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = keras.Model(inputs, outputs, name='ResNet50V2')
    return model


def build_simple_cnn(input_shape=(*INPUT_SIZE,3), num_classes=NUM_CLASSES):
    inputs = keras.Input(shape=input_shape)
    x = data_augmentation(inputs)
    x = layers.Rescaling(1./255)(x)
    x = layers.Conv2D(32, 3, padding='same', activation='relu')(x)
    x = layers.MaxPool2D()(x)
    x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)
    x = layers.MaxPool2D()(x)
    x = layers.Conv2D(128, 3, padding='same', activation='relu')(x)
    x = layers.MaxPool2D()(x)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.Dropout(0.4)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = keras.Model(inputs, outputs, name='SimpleCNN')
    return model

# ---------------------- SCHEDULERS / OPTIMIZER / EMA ----------------------

class WarmUpCosineDecay:
    def __init__(self, base_lr, total_steps, warmup_steps=0, alpha=0.0):
        self.base_lr = base_lr
        self.total_steps = max(1, int(total_steps))
        self.warmup_steps = int(warmup_steps)
        self.alpha = alpha

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        if self.warmup_steps > 0:
            warmup_pct = tf.minimum(1.0, step / tf.cast(self.warmup_steps, tf.float32))
            warmup_lr = self.base_lr * warmup_pct
        else:
            warmup_lr = self.base_lr

        # cosine after warmup
        progress = (step - self.warmup_steps) / tf.maximum(1.0, (self.total_steps - self.warmup_steps))
        cosine_lr = self.alpha + 0.5 * (1.0 - self.alpha) * (1.0 + tf.cos(math.pi * tf.clip_by_value(progress, 0.0, 1.0)))
        cosine_lr = self.base_lr * cosine_lr

        return tf.where(step < self.warmup_steps, warmup_lr, cosine_lr)

class ExponentialMovingAverageCallback(keras.callbacks.Callback):
    def __init__(self, ema_decay=0.9999):
        super().__init__()
        self.ema_decay = ema_decay
        self.ema_weights = None

    def set_model(self, model):
        super().set_model(model)
        # initialize shadow weights
        self.ema_weights = [tf.identity(w) for w in model.get_weights()]

    def on_train_batch_end(self, batch, logs=None):
        weights = self.model.get_weights()
        for i in range(len(weights)):
            self.ema_weights[i] = self.ema_decay * self.ema_weights[i] + (1.0 - self.ema_decay) * weights[i]

    def on_train_end(self, logs=None):
        # save current weights and set ema weights for final evaluation / saving
        self._backup = self.model.get_weights()
        self.model.set_weights(self.ema_weights)

    def restore(self):
        if hasattr(self, '_backup'):
            self.model.set_weights(self._backup)

# ---------------------- TRAIN UTIL ----------------------

def compile_and_train(model, train_ds, val_ds, model_name, epochs=EPOCHS, base_lr=1e-4, warmup_epochs=3):
    # per-model folder
    ckpt_dir = os.path.join(MODEL_DIR, model_name)
    os.makedirs(ckpt_dir, exist_ok=True)
    log_dir = os.path.join(LOG_DIR, model_name)
    os.makedirs(log_dir, exist_ok=True)

    # compute total steps
    if steps_per_epoch is None:
        # fallback: iterate once to count
        cnt = 0
        for _ in train_ds:
            cnt += 1
        sperep = max(1, cnt)
    else:
        sperep = steps_per_epoch
    total_steps = sperep * epochs
    warmup_steps = sperep * warmup_epochs

    # lr schedule
    schedules = WarmUpCosineDecay(base_lr, total_steps, warmup_steps=warmup_steps, alpha=0.0)
    lr_fn = lambda step: schedules(step)
    lr_schedule = keras.optimizers.schedules.LearningRateSchedule()
    # wrapcustom as a tf.function-compatible schedule
    class _LR(keras.optimizers.schedules.LearningRateSchedule):
        def __call__(self, step):
            return schedules(step)
    lr_schedule = _LR()

    optimizer = keras.optimizers.AdamW(learning_rate=lr_schedule, weight_decay=1e-5)

    model.compile(
        optimizer=optimizer,
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    # callbacks
    callbacks = []
    callbacks.append(keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(ckpt_dir, f"{model_name}.keras"),
        save_best_only=True, save_weights_only=False, monitor='val_accuracy'
    ))
    callbacks.append(keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-7))
    callbacks.append(keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=10, restore_best_weights=True))
    callbacks.append(keras.callbacks.CSVLogger(os.path.join(log_dir, 'history.csv')))
    callbacks.append(keras.callbacks.TensorBoard(log_dir=log_dir))
    if USE_EMA:
        ema_cb = ExponentialMovingAverageCallback(ema_decay=0.9999)
        callbacks.append(ema_cb)
    else:
        ema_cb = None

    history = model.fit(train_ds, validation_data=val_ds, epochs=epochs, callbacks=callbacks)

    # if EMA used, ensure final weights are ema for evaluation and saving
    if ema_cb is not None:
        print("Applying EMA weights for final evaluation and saving...")
        ema_cb.on_train_end()

    # save final model copy (best already saved by checkpoint)
    final_path = os.path.join(ckpt_dir, f"{model_name}_final.keras")
    model.save(final_path)
    print(f"Saved final model to: {final_path}")

    # restore model if EMA backup exists (so model object remains original for further training)
    if ema_cb is not None:
        ema_cb.restore()

    return history

# ---------------------- TTA (with augmentation enabled) ----------------------

def predict_with_tta(model, dataset, tta_steps=TTA_STEPS):
    probs = []
    for x_batch, _ in dataset:
        x0 = tf.cast(x_batch, tf.float32)
        batch_probs = np.zeros((x0.shape[0], NUM_CLASSES), dtype=np.float32)
        for t in range(tta_steps):
            aug = data_augmentation(x0, training=True)
            preds = model.predict(aug, verbose=0)
            batch_probs += preds
        batch_probs /= float(tta_steps)
        probs.append(batch_probs)
    probs = np.vstack(probs)
    return probs

# ---------------------- ENSEMBLE / PSO ----------------------

def ensemble_average(probs_list):
    return np.mean(np.stack(probs_list, axis=0), axis=0)

def ensemble_weighted(probs_list, weights):
    w = np.array(weights).reshape(-1,1,1)
    stacked = np.stack(probs_list, axis=0)
    combined = np.sum(w * stacked, axis=0)
    return combined

class SimplePSO:
    def __init__(self, n_particles, dim, probs_list, y_true, iters=100, w=0.72, c1=1.49, c2=1.49):
        self.n_particles = n_particles
        self.dim = dim
        self.probs_list = probs_list
        self.y_true = y_true
        self.iters = iters
        self.w = w
        self.c1 = c1
        self.c2 = c2
        self.pos = np.random.rand(n_particles, dim)
        self.pos = self.pos / np.sum(self.pos, axis=1, keepdims=True)
        self.vel = np.zeros_like(self.pos)
        self.pbest_pos = self.pos.copy()
        self.pbest_val = np.array([self._fitness(p) for p in self.pos])
        self.gbest_idx = np.argmin(self.pbest_val)
        self.gbest_pos = self.pbest_pos[self.gbest_idx].copy()
        self.gbest_val = self.pbest_val[self.gbest_idx]

    def _fitness(self, weights):
        combined = np.zeros_like(self.probs_list[0])
        for w, probs in zip(weights, self.probs_list):
            combined += w * probs
        preds = np.argmax(combined, axis=1)
        acc = accuracy_score(self.y_true, preds)
        return -acc

    def optimize(self):
        for it in range(self.iters):
            r1 = np.random.rand(self.n_particles, self.dim)
            r2 = np.random.rand(self.n_particles, self.dim)
            cognitive = self.c1 * r1 * (self.pbest_pos - self.pos)
            social = self.c2 * r2 * (self.gbest_pos - self.pos)
            self.vel = self.w * self.vel + cognitive + social
            self.pos = self.pos + self.vel
            self.pos = np.clip(self.pos, 1e-6, None)
            self.pos = self.pos / np.sum(self.pos, axis=1, keepdims=True)
            vals = np.array([self._fitness(p) for p in self.pos])
            improved = vals < self.pbest_val
            self.pbest_val[improved] = vals[improved]
            self.pbest_pos[improved] = self.pos[improved]
            gidx = np.argmin(self.pbest_val)
            if self.pbest_val[gidx] < self.gbest_val:
                self.gbest_val = self.pbest_val[gidx]
                self.gbest_pos = self.pbest_pos[gidx].copy()
            if it % max(1, self.iters//10) == 0 or it == self.iters - 1:
                print(f"PSO iter {it+1}/{self.iters}, best_acc = {-self.gbest_val:.4f}")
        return self.gbest_pos, -self.gbest_val

# ---------------------- MAIN WORKFLOW ----------------------
if __name__ == '__main__':
    print("Building models (EffNetV2, ResNet50V2, SimpleCNN)")
    effnet = build_effnet(fine_tune_at=120)
    resnet = build_resnet(fine_tune_at=80)
    cnn = build_simple_cnn()

    print(effnet.summary())
    print(resnet.summary())
    print(cnn.summary())

    # Train models
    compile_and_train(effnet, train_ds, val_ds, model_name='effnetv2_finetuned', epochs=EPOCHS, base_lr=1e-4, warmup_epochs=3)
    compile_and_train(resnet, train_ds, val_ds, model_name='resnet50v2_finetuned', epochs=EPOCHS, base_lr=1e-4, warmup_epochs=3)
    compile_and_train(cnn, train_ds, val_ds, model_name='simplecnn', epochs=EPOCHS, base_lr=1e-4, warmup_epochs=2)

    # Evaluate with TTA
    print("Computing TTA predictions (this may take a while)")
    eff_probs = predict_with_tta(effnet, val_ds, tta_steps=TTA_STEPS)
    res_probs = predict_with_tta(resnet, val_ds, tta_steps=TTA_STEPS)
    cnn_probs = predict_with_tta(cnn, val_ds, tta_steps=TTA_STEPS)

    y_true = np.concatenate([y.numpy() for _, y in val_ds], axis=0)

    # Simple average
    avg_probs = ensemble_average([eff_probs, res_probs, cnn_probs])
    avg_pred = np.argmax(avg_probs, axis=1)
    avg_acc = accuracy_score(y_true, avg_pred)
    print(f"Average ensemble accuracy: {avg_acc:.4f}")
    print(classification_report(y_true, avg_pred))

    # PSO optimize
    print("Running PSO to find best ensemble weights...")
    pso = SimplePSO(n_particles=PSO_PARTICLES, dim=3, probs_list=[eff_probs, res_probs, cnn_probs], y_true=y_true, iters=PSO_ITERS)
    best_w, best_acc = pso.optimize()
    print(f"Best weights: {best_w}, best_acc: {best_acc:.4f}")

    combined = ensemble_weighted([eff_probs, res_probs, cnn_probs], best_w)
    comb_pred = np.argmax(combined, axis=1)
    print(classification_report(y_true, comb_pred))

    np.save(os.path.join(MODEL_DIR, 'ensemble_weights_3models.npy'), best_w)
    print("All done. Models and ensemble weights saved.")


CLAHE destination 'dataset_clahe' exists and not empty. Skipping CLAHE.
Detected classes: ['AMD', 'CNV', 'CSR', 'DME', 'DR', 'DRUSEN', 'MH', 'NORMAL']
Found 18400 files belonging to 8 classes.
Found 2800 files belonging to 8 classes.
steps_per_epoch (estimate): 575
Building models (EffNetV2, ResNet50V2, SimpleCNN)


None


None


None
Epoch 1/50




[1m575/575[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 956ms/step - accuracy: 0.1737 - loss: 2.7181

NotImplementedError: Learning rate schedule '_LR' must override `get_config()` in order to be serializable.

In [10]:
"""
Super-clean pipeline (final, error-free):
EfficientNetV2 + ResNet50V2 + SimpleCNN ensemble with PSO weight search,
optional CLAHE preprocessing, EMA, TTA, AdamW optimizer.
"""

import os
import math
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.metrics import accuracy_score, classification_report
import cv2

# ---------------------- USER CONFIG ----------------------
DATA_DIR = "RetinalOCT_Dataset"      # must contain train/val subfolders
PREPROCESSED_DIR = "dataset_clahe"
INPUT_SIZE = (224, 224)
BATCH_SIZE = 32
NUM_CLASSES = 8
SEED = 42
EPOCHS = 50
MODEL_DIR = "models"
LOG_DIR = "logs"
USE_CLAHE_PREPROCESS = True
USE_EMA = True
AUTOTUNE = tf.data.AUTOTUNE
PSO_PARTICLES = 20
PSO_ITERS = 150
TTA_STEPS = 5

os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)

# ---------------------- UTILITIES ----------------------
def list_classes(path):
    return sorted([d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))])

# ---------------------- CLAHE PREPROCESS (offline, guarded) ----------------------
def apply_clahe_to_folder(src_root, dst_root, size=INPUT_SIZE):
    # Guard: run only if destination not exists or empty
    if os.path.exists(dst_root) and any(os.scandir(dst_root)):
        print(f"CLAHE destination '{dst_root}' exists and not empty. Skipping CLAHE.")
        return
    print("Running CLAHE preprocessing (this may take a while)...")
    os.makedirs(dst_root, exist_ok=True)
    for split in ["train", "val"]:
        src_split = os.path.join(src_root, split)
        dst_split = os.path.join(dst_root, split)
        if not os.path.exists(src_split):
            print(f"Warning: {src_split} not found. Skipping.")
            continue
        for class_name in os.listdir(src_split):
            src_cls = os.path.join(src_split, class_name)
            dst_cls = os.path.join(dst_split, class_name)
            os.makedirs(dst_cls, exist_ok=True)
            for fname in os.listdir(src_cls):
                src_path = os.path.join(src_cls, fname)
                dst_path = os.path.join(dst_cls, fname)
                img = cv2.imread(src_path)
                if img is None:
                    continue
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img = cv2.resize(img, size)
                lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
                l, a, b = cv2.split(lab)
                clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
                cl = clahe.apply(l)
                limg = cv2.merge((cl, a, b))
                out = cv2.cvtColor(limg, cv2.COLOR_LAB2RGB)
                cv2.imwrite(dst_path, cv2.cvtColor(out, cv2.COLOR_RGB2BGR))
    print("CLAHE preprocessing finished.")

# ---------------------- DATA PIPELINE ----------------------
if USE_CLAHE_PREPROCESS:
    apply_clahe_to_folder(DATA_DIR, PREPROCESSED_DIR)
    DATA_DIR_USED = PREPROCESSED_DIR
else:
    DATA_DIR_USED = DATA_DIR

train_root = os.path.join(DATA_DIR_USED, 'train')
if not os.path.exists(train_root):
    raise FileNotFoundError(f"Train directory not found: {train_root}")

classes = list_classes(train_root)
print(f"Detected classes: {classes}")

train_ds = keras.preprocessing.image_dataset_from_directory(
    os.path.join(DATA_DIR_USED, 'train'),
    image_size=INPUT_SIZE,
    batch_size=BATCH_SIZE,
    label_mode='int',
    seed=SEED
)
val_ds = keras.preprocessing.image_dataset_from_directory(
    os.path.join(DATA_DIR_USED, 'val'),
    image_size=INPUT_SIZE,
    batch_size=BATCH_SIZE,
    label_mode='int',
    shuffle=False
)

# Order: shuffle -> cache -> prefetch (cache after shuffle to avoid caching shuffled order issue)
train_ds = train_ds.shuffle(1000, seed=SEED).cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

# compute steps_per_epoch safely (may be None)
try:
    steps_per_epoch = int(tf.data.experimental.cardinality(train_ds).numpy())
    if steps_per_epoch <= 0:
        steps_per_epoch = None
except Exception:
    steps_per_epoch = None
print(f"steps_per_epoch (estimate): {steps_per_epoch}")

# augmentation pipeline (used in both training and TTA)
data_augmentation = keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.08),
    layers.RandomZoom(0.08),
    layers.RandomContrast(0.12),
], name='data_augmentation')

# ---------------------- MODEL BUILDERS ----------------------
from tensorflow.keras.applications import efficientnet_v2, resnet_v2

def build_effnet(input_shape=(*INPUT_SIZE,3), num_classes=NUM_CLASSES, fine_tune_at=120):
    try:
        base = efficientnet_v2.EfficientNetV2B0(include_top=False, weights='imagenet', input_shape=input_shape)
        preprocess = efficientnet_v2.preprocess_input
    except Exception:
        base = keras.applications.EfficientNetB0(include_top=False, weights='imagenet', input_shape=input_shape)
        preprocess = keras.applications.efficientnet.preprocess_input

    base.trainable = True
    for layer in base.layers[:-fine_tune_at]:
        layer.trainable = False

    inputs = keras.Input(shape=input_shape)
    x = data_augmentation(inputs)
    x = preprocess(x)
    x = base(x, training=False)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dense(512, activation='relu')(x)
    x = layers.Dropout(0.4)(x)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = keras.Model(inputs, outputs, name='EffNetV2')
    return model

def build_resnet(input_shape=(*INPUT_SIZE,3), num_classes=NUM_CLASSES, fine_tune_at=80):
    base = resnet_v2.ResNet50V2(include_top=False, weights='imagenet', input_shape=input_shape)
    preprocess = resnet_v2.preprocess_input

    base.trainable = True
    for layer in base.layers[:-fine_tune_at]:
        layer.trainable = False

    inputs = keras.Input(shape=input_shape)
    x = data_augmentation(inputs)
    x = preprocess(x)
    x = base(x, training=False)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dense(512, activation='relu')(x)
    x = layers.Dropout(0.4)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = keras.Model(inputs, outputs, name='ResNet50V2')
    return model

def build_simple_cnn(input_shape=(*INPUT_SIZE,3), num_classes=NUM_CLASSES):
    inputs = keras.Input(shape=input_shape)
    x = data_augmentation(inputs)
    x = layers.Rescaling(1./255)(x)
    x = layers.Conv2D(32, 3, padding='same', activation='relu')(x)
    x = layers.MaxPool2D()(x)
    x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)
    x = layers.MaxPool2D()(x)
    x = layers.Conv2D(128, 3, padding='same', activation='relu')(x)
    x = layers.MaxPool2D()(x)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.Dropout(0.4)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = keras.Model(inputs, outputs, name='SimpleCNN')
    return model

# ---------------------- EMA (fixed) ----------------------
class ExponentialMovingAverageCallback(keras.callbacks.Callback):
    def __init__(self, ema_decay=0.9999):
        super().__init__()
        self.ema_decay = float(ema_decay)
        self.ema_weights = None
        self._backup = None

    def set_model(self, model):
        # Use parent to set internal model reference
        super().set_model(model)
        # store shadow weights as numpy arrays (copy)
        self.ema_weights = [w.copy() for w in self.model.get_weights()]

    def on_train_batch_end(self, batch, logs=None):
        # update shadow weights (numpy operations)
        weights = self.model.get_weights()
        for i in range(len(weights)):
            # ensure ema_weights initialized
            if self.ema_weights[i] is None:
                self.ema_weights[i] = weights[i].copy()
            else:
                self.ema_weights[i] = self.ema_decay * self.ema_weights[i] + (1.0 - self.ema_decay) * weights[i]

    def on_train_end(self, logs=None):
        # backup current weights and set ema weights
        self._backup = self.model.get_weights()
        try:
            self.model.set_weights(self.ema_weights)
        except Exception as e:
            print("Warning: failed to set EMA weights on model:", e)

    def restore(self):
        if self._backup is not None:
            self.model.set_weights(self._backup)

# ---------------------- TRAIN UTIL ----------------------
def compile_and_train(model, train_ds, val_ds, model_name, epochs=EPOCHS, base_lr=1e-4, warmup_epochs=0):
    ckpt_dir = os.path.join(MODEL_DIR, model_name)
    os.makedirs(ckpt_dir, exist_ok=True)
    log_dir = os.path.join(LOG_DIR, model_name)
    os.makedirs(log_dir, exist_ok=True)

    # compute steps_per_epoch fallback
    if steps_per_epoch is None:
        cnt = 0
        for _ in train_ds:
            cnt += 1
        sperep = max(1, cnt)
    else:
        sperep = steps_per_epoch

    # optimizer (simple AdamW with fixed lr)
    optimizer = keras.optimizers.AdamW(learning_rate=base_lr, weight_decay=1e-5)

    model.compile(
        optimizer=optimizer,
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    callbacks = [
        keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(ckpt_dir, f"{model_name}.keras"),
            save_best_only=True, save_weights_only=False, monitor='val_accuracy'
        ),
        keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-7),
        keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=10, restore_best_weights=True),
        keras.callbacks.CSVLogger(os.path.join(log_dir, 'history.csv')),
        keras.callbacks.TensorBoard(log_dir=log_dir),
    ]

    ema_cb = None
    if USE_EMA:
        ema_cb = ExponentialMovingAverageCallback(ema_decay=0.9999)
        callbacks.append(ema_cb)

    history = model.fit(train_ds, validation_data=val_ds, epochs=epochs, callbacks=callbacks)

    # if EMA used, ensure final weights are ema for saving/evaluation (on_train_end should have set them)
    if ema_cb is not None:
        print("EMA applied at training end (if available).")

    # save final model copy (best already saved by checkpoint)
    final_path = os.path.join(ckpt_dir, f"{model_name}_final.keras")
    model.save(final_path)
    print(f"Saved final model to: {final_path}")

    # restore model weights back to original (if EMA was applied and we want model object restored)
    if ema_cb is not None:
        ema_cb.restore()

    return history

# ---------------------- TTA (with augmentation enabled) ----------------------
def predict_with_tta(model, dataset, tta_steps=TTA_STEPS):
    probs = []
    for x_batch, _ in dataset:
        x0 = tf.cast(x_batch, tf.float32)
        batch_probs = np.zeros((x0.shape[0], NUM_CLASSES), dtype=np.float32)
        for _ in range(tta_steps):
            aug = data_augmentation(x0, training=True)   # force augmentation
            preds = model.predict(aug, verbose=0)
            batch_probs += preds
        batch_probs /= float(tta_steps)
        probs.append(batch_probs)
    probs = np.vstack(probs)
    return probs

# ---------------------- ENSEMBLE / PSO ----------------------
def ensemble_average(probs_list):
    return np.mean(np.stack(probs_list, axis=0), axis=0)

def ensemble_weighted(probs_list, weights):
    w = np.array(weights).reshape(-1,1,1)
    stacked = np.stack(probs_list, axis=0)
    combined = np.sum(w * stacked, axis=0)
    return combined

class SimplePSO:
    def __init__(self, n_particles, dim, probs_list, y_true, iters=100, w=0.72, c1=1.49, c2=1.49):
        self.n_particles = n_particles
        self.dim = dim
        self.probs_list = probs_list
        self.y_true = y_true
        self.iters = iters
        self.w = w
        self.c1 = c1
        self.c2 = c2
        self.pos = np.random.rand(n_particles, dim)
        self.pos = self.pos / np.sum(self.pos, axis=1, keepdims=True)
        self.vel = np.zeros_like(self.pos)
        self.pbest_pos = self.pos.copy()
        self.pbest_val = np.array([self._fitness(p) for p in self.pos])
        self.gbest_idx = np.argmin(self.pbest_val)
        self.gbest_pos = self.pbest_pos[self.gbest_idx].copy()
        self.gbest_val = self.pbest_val[self.gbest_idx]

    def _fitness(self, weights):
        combined = np.zeros_like(self.probs_list[0])
        for w, probs in zip(weights, self.probs_list):
            combined += w * probs
        preds = np.argmax(combined, axis=1)
        acc = accuracy_score(self.y_true, preds)
        return -acc

    def optimize(self):
        for it in range(self.iters):
            r1 = np.random.rand(self.n_particles, self.dim)
            r2 = np.random.rand(self.n_particles, self.dim)
            cognitive = self.c1 * r1 * (self.pbest_pos - self.pos)
            social = self.c2 * r2 * (self.gbest_pos - self.pos)
            self.vel = self.w * self.vel + cognitive + social
            self.pos = self.pos + self.vel
            self.pos = np.clip(self.pos, 1e-6, None)
            self.pos = self.pos / np.sum(self.pos, axis=1, keepdims=True)
            vals = np.array([self._fitness(p) for p in self.pos])
            improved = vals < self.pbest_val
            self.pbest_val[improved] = vals[improved]
            self.pbest_pos[improved] = self.pos[improved]
            gidx = np.argmin(self.pbest_val)
            if self.pbest_val[gidx] < self.gbest_val:
                self.gbest_val = self.pbest_val[gidx]
                self.gbest_pos = self.pbest_pos[gidx].copy()
            if it % max(1, self.iters//10) == 0 or it == self.iters - 1:
                print(f"PSO iter {it+1}/{self.iters}, best_acc = {-self.gbest_val:.4f}")
        return self.gbest_pos, -self.gbest_val

# ---------------------- MAIN WORKFLOW ----------------------
if __name__ == '__main__':
    print("Building models (EffNetV2, ResNet50V2, SimpleCNN)")
    effnet = build_effnet(fine_tune_at=120)
    resnet = build_resnet(fine_tune_at=80)
    cnn = build_simple_cnn()

    print(effnet.summary())
    print(resnet.summary())
    print(cnn.summary())

    # Train models
    compile_and_train(effnet, train_ds, val_ds, model_name='effnetv2_finetuned', epochs=EPOCHS, base_lr=1e-4, warmup_epochs=3)
    compile_and_train(resnet, train_ds, val_ds, model_name='resnet50v2_finetuned', epochs=EPOCHS, base_lr=1e-4, warmup_epochs=3)
    compile_and_train(cnn, train_ds, val_ds, model_name='simplecnn', epochs=EPOCHS, base_lr=1e-4, warmup_epochs=2)

    # Evaluate with TTA
    print("Computing TTA predictions (this may take a while)")
    eff_probs = predict_with_tta(effnet, val_ds, tta_steps=TTA_STEPS)
    res_probs = predict_with_tta(resnet, val_ds, tta_steps=TTA_STEPS)
    cnn_probs = predict_with_tta(cnn, val_ds, tta_steps=TTA_STEPS)

    y_true = np.concatenate([y.numpy() for _, y in val_ds], axis=0)

    # Simple average
    avg_probs = ensemble_average([eff_probs, res_probs, cnn_probs])
    avg_pred = np.argmax(avg_probs, axis=1)
    avg_acc = accuracy_score(y_true, avg_pred)
    print(f"Average ensemble accuracy: {avg_acc:.4f}")
    print(classification_report(y_true, avg_pred))

    # PSO optimize
    print("Running PSO to find best ensemble weights...")
    pso = SimplePSO(n_particles=PSO_PARTICLES, dim=3, probs_list=[eff_probs, res_probs, cnn_probs], y_true=y_true, iters=PSO_ITERS)
    best_w, best_acc = pso.optimize()
    print(f"Best weights: {best_w}, best_acc: {best_acc:.4f}")

    combined = ensemble_weighted([eff_probs, res_probs, cnn_probs], best_w)
    comb_pred = np.argmax(combined, axis=1)
    print(classification_report(y_true, comb_pred))

    np.save(os.path.join(MODEL_DIR, 'ensemble_weights_3models.npy'), best_w)
    print("All done. Models and ensemble weights saved.")


CLAHE destination 'dataset_clahe' exists and not empty. Skipping CLAHE.
Detected classes: ['AMD', 'CNV', 'CSR', 'DME', 'DR', 'DRUSEN', 'MH', 'NORMAL']
Found 18400 files belonging to 8 classes.
Found 2800 files belonging to 8 classes.
steps_per_epoch (estimate): 575
Building models (EffNetV2, ResNet50V2, SimpleCNN)


None


None


None
Epoch 1/50
[1m575/575[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m654s[0m 1s/step - accuracy: 0.7198 - loss: 0.7860 - val_accuracy: 0.9182 - val_loss: 0.2171 - learning_rate: 1.0000e-04
Epoch 2/50
[1m575/575[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m542s[0m 942ms/step - accuracy: 0.8968 - loss: 0.2963 - val_accuracy: 0.9300 - val_loss: 0.1796 - learning_rate: 1.0000e-04
Epoch 3/50
[1m575/575[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m569s[0m 989ms/step - accuracy: 0.9230 - loss: 0.2259 - val_accuracy: 0.9425 - val_loss: 0.1547 - learning_rate: 1.0000e-04
Epoch 4/50
[1m575/575[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m562s[0m 977ms/step - accuracy: 0.9324 - loss: 0.1939 - val_accuracy: 0.9450 - val_loss: 0.1464 - learning_rate: 1.0000e-04
Epoch 5/50
[1m575/575[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m568s[0m 987ms/step - accuracy: 0.9411 - loss: 0.1677 - val_accuracy: 0.9511 - val_loss: 0.1342 - learning_rate: 1.0000e-04
Epoch 6/50
[1m575/575[0m 