# In [None]:
# ==========================================================
# Full end-to-end clean mini-training script (copy-paste)
# ==========================================================

# Imports
import os
import math
import random
import hashlib
from pathlib import Path
from collections import defaultdict
import numpy as np
import cv2
from PIL import Image
import imagehash
import albumentations as A
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications import EfficientNetB4

# -------------------------
# Config
# -------------------------
DATA_DIR = "/kaggle/input/eye-diseases-classification/dataset"
OUTPUT_DIR = "/kaggle/working"
# os.makedirs creates missing parent directories, so this one call also
# creates OUTPUT_DIR itself.
os.makedirs(os.path.join(OUTPUT_DIR, "artifacts"), exist_ok=True)

# Seed the Python, NumPy and TensorFlow RNGs (in that order) with one value
SEED = 42
for _seed_rng in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_rng(SEED)

# Sizes / training control
IMG_SIZE = (380, 380)   # (height, width) handed to the resize transform
BATCH_SIZE = 16
WARMUP_EPOCHS = 5       # epochs of the frozen-backbone phase
TOTAL_EPOCHS = 80       # overall epoch budget across both phases
INITIAL_LR = 3e-4
WEIGHT_DECAY = 1e-5
PATIENCE_ES = 12        # NOTE(review): not referenced elsewhere in the visible script
PATIENCE_RLR = 6        # patience of ReduceLROnPlateau

# Feature toggles
USE_IMAGENET = True     # start the backbone from ImageNet weights
USE_CLAHE = True
USE_MIXUP = True
MIXUP_ALPHA = 0.2
USE_CUTMIX = True
CUTMIX_ALPHA = 1.0
LABEL_SMOOTHING = 0.05
USE_FOCAL_LOSS = False  # False selects categorical cross-entropy
TTA_ROUNDS = 3

# -------------------------
# Optional: Focal loss (kept same)
# -------------------------
def categorical_focal_loss(gamma=2.0, alpha=0.25):
    """Build a focal-loss function for one-hot / soft categorical targets.

    Args:
        gamma: exponent applied to ``1 - p``.
        alpha: constant factor applied to every class term.

    Returns:
        A callable ``loss_fn(y_true, y_pred)`` that returns the per-sample
        loss, summed over the last axis.
    """
    def loss_fn(y_true, y_pred):
        eps = keras.backend.epsilon()
        # Clip so that log() never sees exactly 0 or 1.
        probs = tf.clip_by_value(y_pred, eps, 1.0 - eps)
        cross_entropy = -y_true * tf.math.log(probs)
        modulating = alpha * tf.pow(1 - probs, gamma)
        return tf.reduce_sum(modulating * cross_entropy, axis=-1)

    return loss_fn

# -------------------------
# SEBlock (custom layer used in saved model / training)
# -------------------------
class SEBlock(layers.Layer):
    """Channel-gating block: pool, two Dense layers, rescale the input.

    The input is global-average-pooled over its spatial axes, passed through
    a bottleneck Dense layer (ReLU) and a Dense layer with sigmoid output of
    size ``channels``, reshaped to ``(1, 1, channels)`` and multiplied with
    the input.

    Args:
        se_ratio: bottleneck width as a fraction of the channel count
            (at least one unit is always used).
        **kwargs: forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, se_ratio=0.25, **kwargs):
        super().__init__(**kwargs)
        self.se_ratio = se_ratio

    def build(self, input_shape):
        channels = int(input_shape[-1])
        reduced = max(1, int(channels * self.se_ratio))
        self.gap = layers.GlobalAveragePooling2D()
        self.fc1 = layers.Dense(reduced, activation="relu", kernel_initializer="he_normal")
        self.fc2 = layers.Dense(channels, activation="sigmoid", kernel_initializer="he_normal")
        self.reshape = layers.Reshape((1, 1, channels))
        # Create the Dense weights here, while the shapes are known, so the
        # layer owns all of its variables as soon as it is built instead of
        # only after the first call.
        self.fc1.build((None, channels))
        self.fc2.build((None, reduced))
        super().build(input_shape)

    def call(self, x):
        se = self.gap(x)
        se = self.fc1(se)
        se = self.fc2(se)
        se = self.reshape(se)
        return x * se

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer.

        Without this, saving a model that contains the layer cannot record
        ``se_ratio`` (and, depending on the Keras version, may fail outright).
        Reload with ``custom_objects={"SEBlock": SEBlock}``.
        """
        config = super().get_config()
        config.update({"se_ratio": self.se_ratio})
        return config

# -------------------------
# Preprocessor (CLAHE + albumentations)
# -------------------------
class Preprocessor:
    """Load an image, optionally apply CLAHE, then resize/augment it.

    By default the output is float32 with pixel values in the 0-255 range.
    The Keras EfficientNet applications already contain their own rescaling
    and normalisation layers and expect 0-255 inputs, so applying ImageNet
    mean/std normalisation here as well would normalise the data twice.

    Args:
        img_size: ``(height, width)`` passed to the resize transform.
        use_clahe: apply CLAHE to the L channel (LAB space) before augmenting.
        imagenet_normalize: set to True only for a backbone WITHOUT built-in
            preprocessing; it appends ImageNet mean/std normalisation to both
            pipelines.
    """

    def __init__(self, img_size=(380, 380), use_clahe=True, imagenet_normalize=False):
        self.img_size = img_size
        self.use_clahe = use_clahe
        self.imagenet_normalize = imagenet_normalize
        # Always defined so apply_clahe() is safe even when CLAHE is disabled.
        self.clahe = None
        if use_clahe:
            try:
                self.clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            except Exception as err:
                # Best effort: keep going without CLAHE, but say so.
                print("CLAHE unavailable; continuing without it. Error:", err)

        normalize = []
        if imagenet_normalize:
            normalize = [A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))]

        self.train_aug = A.Compose([
            A.Resize(img_size[0], img_size[1]),
            A.RandomRotate90(p=0.15),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.OneOf([
                A.RandomBrightnessContrast(p=1.0),
                A.HueSaturationValue(p=1.0)
            ], p=0.6),
            A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.15, rotate_limit=15, p=0.6),
            A.OneOf([A.GaussNoise(), A.ISONoise()], p=0.2),
            A.OneOf([A.Blur(3), A.GaussianBlur(3)], p=0.2),
        ] + normalize)

        self.val_aug = A.Compose([
            A.Resize(img_size[0], img_size[1]),
        ] + normalize)

    def apply_clahe(self, img):
        """Return `img` (RGB) with CLAHE applied to its LAB lightness channel.

        Returns the input unchanged when no CLAHE object is available.
        """
        if self.clahe is None:
            return img
        lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
        lab[:, :, 0] = self.clahe.apply(lab[:, :, 0])
        return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

    def preprocess(self, path_or_img, training=True):
        """Turn a file path or an RGB array into a float32 model input.

        Args:
            path_or_img: a path (``str`` or ``os.PathLike``) read with OpenCV,
                or an already-loaded RGB image array.
            training: True selects the augmenting pipeline, False the
                resize-only pipeline.

        Returns:
            The transformed image as ``np.float32``. A file that OpenCV cannot
            read is replaced by an all-black image of ``img_size``.
        """
        if isinstance(path_or_img, (str, os.PathLike)):
            img = cv2.imread(os.fspath(path_or_img))
            if img is None:
                img = np.zeros((self.img_size[0], self.img_size[1], 3), dtype=np.uint8)
            else:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        else:
            img = path_or_img

        if self.use_clahe and img is not None:
            try:
                img = self.apply_clahe(img)
            except cv2.error:
                # Input not suitable for CLAHE (e.g. unexpected channel layout
                # or dtype): fall back to the untouched image.
                pass

        aug = self.train_aug if training else self.val_aug
        out = aug(image=img)['image']
        return out.astype(np.float32)

# -------------------------
# Data generator (MixUp + CutMix)
# -------------------------
class DataGenerator(keras.utils.Sequence):
    """Batch generator yielding ``(X, Y)`` with optional CutMix / MixUp.

    Args:
        filepaths: image paths.
        labels: integer class index per path.
        batch_size: number of samples per batch (the last batch may be smaller).
        preprocessor: object exposing ``img_size`` and
            ``preprocess(path, training=...)``.
        num_classes: width of the one-hot label vectors.
        shuffle: reshuffle the sample order at the end of every epoch.
        mixup_prob: probability of MixUp for a sample that was not CutMixed.
        cutmix_prob: probability of CutMix for a sample.
        mixup_alpha: Beta parameter for MixUp.
        cutmix_alpha: Beta parameter for CutMix.
        training: whether to use the augmenting pipeline and sample mixing.
            ``None`` (default) follows ``shuffle``, so a generator built with
            ``shuffle=False`` yields deterministic, un-augmented batches.

    The module-level ``USE_MIXUP`` / ``USE_CUTMIX`` switches still act as
    global on/off toggles on top of the per-instance probabilities.
    """

    def __init__(self, filepaths, labels, batch_size, preprocessor, num_classes,
                 shuffle=True, mixup_prob=0.5, cutmix_prob=0.5, mixup_alpha=0.2, cutmix_alpha=1.0,
                 training=None):
        super().__init__()
        self.filepaths = np.array(filepaths)
        self.labels = np.array(labels)
        self.batch_size = batch_size
        self.prep = preprocessor
        self.num_classes = num_classes
        self.shuffle = shuffle
        self.training = shuffle if training is None else bool(training)
        self.mixup_prob = mixup_prob
        self.cutmix_prob = cutmix_prob
        self.mixup_alpha = mixup_alpha
        self.cutmix_alpha = cutmix_alpha
        self.indexes = np.arange(len(self.filepaths))
        self.on_epoch_end()

    def __len__(self):
        return int(np.ceil(len(self.filepaths) / self.batch_size))

    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def _one_hot(self, idx):
        lab = np.zeros(self.num_classes, dtype=np.float32)
        lab[idx] = 1.0
        return lab

    def _mixup(self, x1, y1, x2, y2, alpha):
        """Blend two samples and their labels with a Beta(alpha, alpha) weight."""
        lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
        x = lam * x1 + (1 - lam) * x2
        y = lam * y1 + (1 - lam) * y2
        return x, y

    def _cutmix(self, x1, y1, x2, y2, alpha):
        """Paste a random rectangle of `x2` into a copy of `x1`.

        The label weight is computed from the area that was actually pasted,
        i.e. after the rectangle has been clipped to the image bounds.
        """
        H, W = x1.shape[:2]
        lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
        cut_rat = math.sqrt(max(0.0, 1.0 - lam))
        cut_w = int(W * cut_rat)
        cut_h = int(H * cut_rat)
        cx = np.random.randint(0, W)
        cy = np.random.randint(0, H)
        top = max(0, cy - cut_h // 2)
        left = max(0, cx - cut_w // 2)
        bottom = min(H, top + cut_h)
        right = min(W, left + cut_w)
        mixed = x1.copy()
        mixed[top:bottom, left:right, :] = x2[top:bottom, left:right, :]
        pasted_area = (bottom - top) * (right - left)
        keep = 1.0 - pasted_area / float(W * H) if (W * H) > 0 else 1.0
        y = keep * y1 + (1.0 - keep) * y2
        return mixed, y

    def __getitem__(self, idx):
        start = idx * self.batch_size
        batch_inds = self.indexes[start:start + self.batch_size]
        bsize = len(batch_inds)
        height, width = self.prep.img_size[0], self.prep.img_size[1]
        X = np.zeros((bsize, height, width, 3), dtype=np.float32)
        Y = np.zeros((bsize, self.num_classes), dtype=np.float32)

        for i, ind in enumerate(batch_inds):
            X[i] = self.prep.preprocess(self.filepaths[ind], training=self.training)
            Y[i] = self._one_hot(self.labels[ind])

        use_cutmix = USE_CUTMIX and self.cutmix_prob > 0
        use_mixup = USE_MIXUP and self.mixup_prob > 0
        if not self.training or not (use_cutmix or use_mixup):
            return X, Y

        n_total = len(self.filepaths)
        for i in range(bsize):
            # Independent draws: with a single shared draw the MixUp branch
            # can never fire when mixup_prob <= cutmix_prob.
            if use_cutmix and np.random.rand() < self.cutmix_prob:
                mode = "cutmix"
            elif use_mixup and np.random.rand() < self.mixup_prob:
                mode = "mixup"
            else:
                continue

            # Only decode a partner image once we know it will be used.
            j = np.random.randint(0, n_total) if n_total > 1 else batch_inds[i]
            x2 = self.prep.preprocess(self.filepaths[j], training=True)
            y2 = self._one_hot(self.labels[j])

            if mode == "cutmix":
                try:
                    X[i], Y[i] = self._cutmix(X[i], Y[i], x2, y2, self.cutmix_alpha)
                except (ValueError, IndexError):
                    # Same fallback as before: blend instead of paste.
                    X[i], Y[i] = self._mixup(X[i], Y[i], x2, y2, self.mixup_alpha)
            else:
                X[i], Y[i] = self._mixup(X[i], Y[i], x2, y2, self.mixup_alpha)
        return X, Y
# -------------------------
# Load dataset (class folders)
# -------------------------
def load_filepaths_labels(root):
    """Collect image paths and integer labels from a class-per-folder tree.

    Every immediate sub-directory of `root` is one class; class indices follow
    the sorted directory order. Files inside each class directory are also
    visited in sorted order so that the returned lists (and therefore any
    seeded split made from them) do not depend on filesystem listing order.

    Args:
        root: dataset root directory (``str`` or ``Path``).

    Returns:
        ``(filepaths, labels, classes)``: path strings, the matching class
        indices, and the class names in index order.

    Raises:
        FileNotFoundError: if `root` does not exist.
    """
    root = Path(root)
    if not root.exists():
        raise FileNotFoundError(f"DATA_DIR '{root}' does not exist.")
    image_suffixes = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
    class_dirs = sorted(d for d in root.iterdir() if d.is_dir())
    classes = [d.name for d in class_dirs]
    filepaths, labels = [], []
    for idx, class_dir in enumerate(class_dirs):
        for p in sorted(class_dir.glob("*")):
            # is_file() keeps out directories that merely carry an image suffix.
            if p.is_file() and p.suffix.lower() in image_suffixes:
                filepaths.append(str(p))
                labels.append(idx)
    return filepaths, labels, classes

# Scan DATA_DIR once; NUM_CLASSES is the number of class folders found.
filepaths, labels, classes = load_filepaths_labels(DATA_DIR)
NUM_CLASSES = len(classes)
print(f"Classes: {classes}")
print(f"Total images: {len(filepaths)}")

# -------------------------
# Duplicate detection (exact md5 + perceptual phash)
# -------------------------
def md5_file(path):
    """Return the hexadecimal MD5 digest of the file at `path`.

    The file is read in 4096-byte pieces so it is never loaded whole.
    """
    digest = hashlib.md5()
    with open(path, "rb") as stream:
        while True:
            piece = stream.read(4096)
            if not piece:
                break
            digest.update(piece)
    return digest.hexdigest()

# Exact duplicates: group paths by the MD5 of their bytes.
md5_map = defaultdict(list)
for fp in filepaths:
    try:
        md5_map[md5_file(fp)].append(fp)
    except OSError as err:
        # An unreadable file cannot be grouped; report it instead of hiding it.
        print(f"Skipping md5 for unreadable file {fp}: {err}")
exact_duplicates = [v for v in md5_map.values() if len(v) > 1]
print(f"Exact duplicate groups found: {len(exact_duplicates)}")
for g in exact_duplicates[:5]:
    print("EXACT DUP GROUP (example):", g)

# Perceptual near-duplicates: group paths whose phash strings are identical.
phash_map = defaultdict(list)
for fp in filepaths:
    try:
        # The context manager closes the file handle that Image.open keeps
        # open; without it one descriptor per image stays around until GC.
        with Image.open(fp) as im:
            h = str(imagehash.phash(im))
    except Exception as err:
        # Still best effort (Pillow raises several unrelated error types),
        # but the skipped file is now reported.
        print(f"Skipping phash for {fp}: {err}")
        continue
    phash_map[h].append(fp)
near_duplicates = [v for v in phash_map.values() if len(v) > 1]
print(f"Perceptual (phash) duplicate groups found: {len(near_duplicates)}")
for g in near_duplicates[:5]:
    print("NEAR DUP GROUP (example):", g)

# -------------------------
# Remove all duplicates
# -------------------------
# Every member of every duplicate group is dropped (no representative is kept).
dup_paths = {path for group in exact_duplicates + near_duplicates for path in group}

_kept_pairs = [(fp, lab) for fp, lab in zip(filepaths, labels) if fp not in dup_paths]
clean_filepaths = [fp for fp, _ in _kept_pairs]
clean_labels = [lab for _, lab in _kept_pairs]
print(f"After removing duplicates: {len(clean_filepaths)} images (removed {len(filepaths)-len(clean_filepaths)})")

# -------------------------
# Split: train / val / test
# -------------------------
# 70 % train, then the remaining 30 % is halved into validation and test.
# Both splits are stratified and share the same seed.
_split_seed = {"random_state": SEED}
train_paths, temp_paths, train_labels, temp_labels = train_test_split(
    clean_filepaths, clean_labels, test_size=0.30, stratify=clean_labels, **_split_seed
)
val_paths, test_paths, val_labels, test_labels = train_test_split(
    temp_paths, temp_labels, test_size=0.50, stratify=temp_labels, **_split_seed
)
print("Split sizes -> train:", len(train_paths), "val:", len(val_paths), "test:", len(test_paths))

# -------------------------
# Compute class weights
# -------------------------
# compute_class_weight returns one weight per entry of `classes`, in that
# order. Key the dict by the actual label id rather than by position so the
# mapping stays correct even if some label id is missing from the train split.
_present_classes = np.unique(train_labels)
cw = compute_class_weight("balanced", classes=_present_classes, y=np.array(train_labels))
class_weights = {int(c): float(w) for c, w in zip(_present_classes, cw)}
print("class_weights:", class_weights)

# -------------------------
# Create generators
# -------------------------
# One shared preprocessor feeds both generators.
preproc = Preprocessor(img_size=IMG_SIZE, use_clahe=USE_CLAHE)

# Training: shuffled, mixing probabilities follow the feature toggles.
_train_mixing = dict(
    mixup_prob=0.5 if USE_MIXUP else 0.0,
    cutmix_prob=0.5 if USE_CUTMIX else 0.0,
    mixup_alpha=MIXUP_ALPHA,
    cutmix_alpha=CUTMIX_ALPHA,
)
train_gen = DataGenerator(train_paths, train_labels, BATCH_SIZE, preproc, NUM_CLASSES,
                          shuffle=True, **_train_mixing)

# Validation: fixed order, mixing switched off.
val_gen = DataGenerator(val_paths, val_labels, BATCH_SIZE, preproc, NUM_CLASSES,
                        shuffle=False, mixup_prob=0.0, cutmix_prob=0.0)

# -------------------------
# Model builder (EffNetB4 + SE + head)
# -------------------------
def build_model(input_shape=(380, 380, 3), num_classes=4, dropout_rate=0.3):
    """Assemble EfficientNetB4 (no top) + SEBlock + a dense softmax head.

    Args:
        input_shape: shape of one input image.
        num_classes: size of the softmax output.
        dropout_rate: rate of the first Dropout; the second uses half of it.

    Returns:
        ``(model, base)``: the full model and the EfficientNetB4 backbone.
        If the requested weights cannot be loaded, the backbone falls back to
        random initialisation and a message is printed.
    """
    inputs = keras.Input(shape=input_shape)
    try:
        base = EfficientNetB4(
            include_top=False,
            weights="imagenet" if USE_IMAGENET else None,
            input_tensor=inputs,
        )
    except Exception as e:
        print("EfficientNetB4 weight load failed; using random init. Error:", e)
        base = EfficientNetB4(include_top=False, weights=None, input_tensor=inputs)

    features = SEBlock(se_ratio=0.25)(base.output)
    pooled = layers.GlobalAveragePooling2D()(features)
    pooled = layers.Dropout(dropout_rate)(pooled)
    hidden = layers.Dense(512, activation="relu", kernel_initializer="he_normal")(pooled)
    hidden = layers.BatchNormalization()(hidden)
    hidden = layers.Dropout(dropout_rate * 0.5)(hidden)
    # The output layer is pinned to float32.
    outputs = layers.Dense(num_classes, activation="softmax", dtype="float32")(hidden)
    return keras.Model(inputs, outputs), base

model, base = build_model(input_shape=IMG_SIZE + (3,), num_classes=NUM_CLASSES)
print("Model params:", model.count_params())

# -------------------------
# Loss and optimizer
# -------------------------
# Focal loss when the toggle is on, otherwise label-smoothed cross-entropy.
loss_fn = (
    categorical_focal_loss(gamma=2.0, alpha=0.25)
    if USE_FOCAL_LOSS
    else keras.losses.CategoricalCrossentropy(label_smoothing=LABEL_SMOOTHING)
)

def make_optimizer(lr):
    """Return a fresh AdamW optimizer with learning rate `lr` and WEIGHT_DECAY."""
    learning_rate = float(lr)
    return keras.optimizers.AdamW(weight_decay=WEIGHT_DECAY, learning_rate=learning_rate)

# Step counts for the LR schedule: batches per epoch, and the number of
# batches in the epochs that follow the warm-up (never less than 1).
steps_per_epoch = len(train_gen)
total_steps = max(1, steps_per_epoch * (TOTAL_EPOCHS - WARMUP_EPOCHS))

# -------------------------
# Callbacks
# -------------------------
_artifacts_dir = os.path.join(OUTPUT_DIR, "artifacts")

# Keep only the weights with the lowest validation loss seen so far.
_checkpoint_cb = keras.callbacks.ModelCheckpoint(
    os.path.join(_artifacts_dir, "best_model.h5"),
    monitor="val_loss",
    save_best_only=True,
    verbose=1,
)
# NOTE(review): the WarmupCosine callback added to this list later sets the
# learning rate at the start of every batch, so a reduction made here
# presumably does not survive past the next batch -- confirm intent.
_plateau_cb = keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.5,
    patience=PATIENCE_RLR,
    verbose=1,
    min_lr=1e-7,
)
# NOTE(review): CSVLogger is created with its default append setting and the
# same list is passed to two fit() calls; verify the first log is not lost.
_csv_cb = keras.callbacks.CSVLogger(os.path.join(_artifacts_dir, "train_log.csv"))

callbacks = [_checkpoint_cb, _plateau_cb, _csv_cb]

class WarmupCosine(tf.keras.callbacks.Callback):
    """Per-batch learning-rate schedule: linear warm-up, then cosine decay.

    For the first ``warmup_epochs * steps_per_epoch`` batches the rate rises
    linearly from 0 to ``initial_lr``; afterwards it follows half a cosine
    from ``initial_lr`` to 0 over ``total_steps`` batches. A computed rate of
    0 is replaced by 1e-8.

    The batch counter is kept across ``fit()`` calls by default, so one
    instance shared by several training phases runs ONE continuous schedule.
    Resetting it on every ``fit()`` would repeat the warm-up in a later phase
    and cut the cosine segment short.

    Args:
        warmup_epochs: number of warm-up epochs.
        initial_lr: peak learning rate.
        total_steps: number of batches in the cosine segment (at least 1).
        steps_per_epoch: batches per epoch.
        reset_on_train_begin: set to True to restart the schedule at every
            ``fit()`` call.
    """

    def __init__(self, warmup_epochs, initial_lr, total_steps, steps_per_epoch,
                 reset_on_train_begin=False):
        super().__init__()
        self.warmup_epochs = warmup_epochs
        self.initial_lr = float(initial_lr)
        self.total_steps = max(1, total_steps)
        self.steps_per_epoch = steps_per_epoch
        self.reset_on_train_begin = reset_on_train_begin
        self.step = 0
        self._warned_lr_failure = False

    def on_train_begin(self, logs=None):
        if self.reset_on_train_begin:
            self.step = 0

    def _scheduled_lr(self):
        """Return the learning rate for the current batch counter."""
        warmup_steps = self.warmup_epochs * self.steps_per_epoch
        if self.step < warmup_steps:
            lr = self.initial_lr * max(0.0, self.step / float(warmup_steps))
        else:
            t = (self.step - warmup_steps) / float(self.total_steps)
            t = min(1.0, max(0.0, t))
            lr = 0.5 * self.initial_lr * (1 + math.cos(math.pi * t))
        return lr if lr > 0 else 1e-8

    def on_batch_begin(self, batch, logs=None):
        lr = self._scheduled_lr()
        try:
            tf.keras.backend.set_value(self.model.optimizer.learning_rate, lr)
        except Exception:
            # Fallback for Keras versions where backend.set_value is missing
            # or does not accept this object.
            try:
                self.model.optimizer.learning_rate.assign(lr)
            except Exception as err:
                # Do not abort training, but do not hide the problem either.
                if not self._warned_lr_failure:
                    print("WarmupCosine: could not set the learning rate:", err)
                    self._warned_lr_failure = True
        self.step += 1

    def on_epoch_end(self, epoch, logs=None):
        try:
            current_lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))
        except Exception:
            current_lr = float(self.initial_lr)
        print(f"Epoch {epoch+1} lr={current_lr:.6e}")

callbacks.append(
    WarmupCosine(
        warmup_epochs=WARMUP_EPOCHS,
        initial_lr=INITIAL_LR,
        total_steps=total_steps,
        steps_per_epoch=steps_per_epoch,
    )
)

# -------------------------
# Phase 1: freeze base, train head
# -------------------------
# Freeze every backbone layer so that only the layers added on top train.
for backbone_layer in base.layers:
    backbone_layer.trainable = False

optimizer = make_optimizer(INITIAL_LR)
model.compile(optimizer=optimizer, loss=loss_fn, metrics=["accuracy"])

print("Phase 1 (warmup head only) training...")
history1 = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=WARMUP_EPOCHS,
    class_weight=class_weights,
    callbacks=callbacks,
    verbose=1,
)

# -------------------------
# Phase 2: unfreeze and fine-tune
# -------------------------
# Unfreeze the backbone, but leave its BatchNormalization layers frozen: the
# Keras fine-tuning guidance for EfficientNet is to keep them frozen, because
# making them trainable tends to wreck accuracy right after unfreezing.
for layer in base.layers:
    layer.trainable = not isinstance(layer, layers.BatchNormalization)

# Re-compile so the changed `trainable` flags take effect.
optimizer2 = make_optimizer(INITIAL_LR * 0.5)
model.compile(optimizer=optimizer2, loss=loss_fn, metrics=["accuracy"])

print("Phase 2 (fine-tune full model) training...")
initial_epoch = history1.epoch[-1] + 1 if hasattr(history1, "epoch") and len(history1.epoch) else 0
# With `initial_epoch`, fit() treats `epochs` as the index of the FINAL epoch,
# not as a count. Add the offset so phase 2 really runs
# TOTAL_EPOCHS - WARMUP_EPOCHS epochs.
final_epoch = initial_epoch + (TOTAL_EPOCHS - WARMUP_EPOCHS)
history2 = model.fit(train_gen, validation_data=val_gen,
                     epochs=final_epoch, initial_epoch=initial_epoch,
                     class_weight=class_weights, callbacks=callbacks, verbose=1)

# Merge histories: append the phase-2 metric lists to the phase-1 History
# object (history1 is modified in place).
history = history1
for k, v in history2.history.items():
    history.history.setdefault(k, []).extend(v)

# -------------------------
# Save final model
# -------------------------
final_path = os.path.join(OUTPUT_DIR, "artifacts", "efnb4_se_final_clean.h5")
try:
    model.save(final_path)
except Exception as e:
    # Full-model export failed: fall back to a weights-only file next to it.
    print("Model.save failed, saving weights only. Error:", e)
    model.save_weights(final_path + ".weights.h5")
else:
    print("Saved final model to:", final_path)

# -------------------------
# TTA evaluate on test set
# -------------------------
def tta_predict(model, file_list, preprocessor, tta_rounds=3, batch_size=None):
    """Average model predictions over several preprocessing rounds.

    Round 0 uses the deterministic pipeline (``training=False``); every later
    round uses the augmenting pipeline (``training=True``).

    Args:
        model: object with ``predict(X, verbose=0)`` returning an array of
            shape ``(len(X), num_classes)``.
        file_list: sequence of image paths.
        preprocessor: object with ``preprocess(path, training=...)`` returning
            one image array; all images must share one shape.
        tta_rounds: number of rounds to average.
        batch_size: images per ``predict`` call; defaults to the module-level
            ``BATCH_SIZE``.

    Returns:
        Array of shape ``(len(file_list), num_classes)`` with the mean
        prediction per file. All-zero (``NUM_CLASSES`` columns) when there are
        no files or no rounds.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    n_files = len(file_list)
    if tta_rounds <= 0 or n_files == 0:
        return np.zeros((n_files, NUM_CLASSES), dtype=np.float32)

    # Keep a running sum so only one round of predictions is held at a time.
    running_sum = None
    for round_idx in range(tta_rounds):
        augment = round_idx > 0
        chunks = []
        for start in range(0, n_files, batch_size):
            batch_files = file_list[start:start + batch_size]
            # The batch shape comes from the preprocessed images themselves,
            # not from a global image-size constant.
            X = np.stack(
                [preprocessor.preprocess(fp, training=augment) for fp in batch_files]
            ).astype(np.float32)
            chunks.append(np.asarray(model.predict(X, verbose=0)))
        round_probs = np.vstack(chunks)
        running_sum = round_probs if running_sum is None else running_sum + round_probs
    return running_sum / tta_rounds

y_true = np.array(test_labels)

# Single deterministic pass (one round, no augmentation).
y_prob_plain = tta_predict(model, test_paths, preproc, tta_rounds=1)
y_pred_plain = np.argmax(y_prob_plain, axis=1)
acc_plain = accuracy_score(y_true, y_pred_plain)

# Averaged over TTA_ROUNDS rounds.
y_prob_tta = tta_predict(model, test_paths, preproc, tta_rounds=TTA_ROUNDS)
y_pred_tta = np.argmax(y_prob_tta, axis=1)
acc_tta = accuracy_score(y_true, y_pred_tta)

# -------------------------
# Final summary output
# -------------------------
print("\n===== FINAL SUMMARY =====")
print("Dataset classes:", classes)
# Dataset bookkeeping, one line per figure.
for _summary_line in (
    f"Original images (before cleaning): {len(filepaths)}",
    f"Images removed (duplicates & near duplicates): {len(filepaths) - len(clean_filepaths)}",
    f"Images used (after cleaning): {len(clean_filepaths)}",
    f"Train samples: {len(train_paths)}",
    f"Validation samples: {len(val_paths)}",
    f"Test samples: {len(test_paths)}\n",
):
    print(_summary_line)

# Same three outputs (accuracy, per-class report, confusion matrix) for the
# plain and the TTA predictions.
for _accuracy_line, _matrix_title, _predictions in (
    (f"Test Accuracy (No TTA): {acc_plain*100:.2f}%", "Confusion Matrix (No TTA):\n", y_pred_plain),
    (f"\nTest Accuracy (TTA={TTA_ROUNDS}): {acc_tta*100:.2f}%", "Confusion Matrix (TTA):\n", y_pred_tta),
):
    print(_accuracy_line)
    print(classification_report(y_true, _predictions, target_names=classes, digits=4))
    print(_matrix_title, confusion_matrix(y_true, _predictions))
