<a href="https://colab.research.google.com/github/aradeyal/machine_learning/blob/main/callback_%D7%AA%D7%A8%D7%92%D7%99%D7%9C_%D7%AA%D7%99%D7%90%D7%95%D7%A8%D7%98%D7%99%E2%80%8E.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# -*- coding: utf-8 -*-
import numpy as np
from tensorflow import keras

# ========================= Callbacks =========================
class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    """Stop training on a train-loss plateau or on suspected overfitting.

    Two stop rules are evaluated at the end of every epoch, in this order:

    * Plateau: ``loss`` has failed to beat its best value for ``patience``
      consecutive epochs. The weights from the best-``loss`` epoch are
      restored. ``patience=0`` disables this rule.
    * Overfitting: after the first ``warmup`` epochs, the gap
      ``val_loss - loss`` has reached ``min_abs_delta`` (absolute) or
      ``min_rel_delta`` (relative to ``loss``) for ``overfit_patience``
      consecutive epochs. When ``restore_best_weights`` is true, the
      weights from the best-``val_loss`` epoch are restored.

    The best ``val_loss`` checkpoint is tracked from the first epoch on;
    ``warmup`` only delays the gap test, so a minimum reached during
    warm-up is still the one that gets restored.

    The overfitting settings are plain attributes and can be changed after
    construction.

    Args:
        patience: Number of consecutive epochs without a new best ``loss``
            that triggers the plateau stop; coerced with ``int()``.
    """

    def __init__(self, patience=0):
        super().__init__()
        self.patience = int(patience)

        # Overfitting logic (can be tuned after init)
        self.overfit_patience = 2     # consecutive overfit epochs that trigger a stop
        self.min_abs_delta = 0.05     # trigger if (val_loss - loss) >= this
        self.min_rel_delta = 0.20     # OR if (val_loss - loss)/loss >= this
        self.warmup = 5               # skip overfit checks for first N epochs
        self.restore_best_weights = True
        self.verbose = 1

        # Trackers (train loss)
        self.best = np.inf
        self.best_weights = None
        self.wait = 0

        # Trackers (overfitting)
        self.best_val = np.inf
        self.best_val_weights = None
        self.wait_overfit = 0
        self.stopped_epoch = 0

    def on_train_begin(self, logs=None):
        """Reset every tracker at the start of a training run."""
        self.wait = 0
        self.best = np.inf
        self.best_weights = None
        self.best_val = np.inf
        self.best_val_weights = None
        self.wait_overfit = 0
        self.stopped_epoch = 0

    def on_epoch_end(self, epoch, logs=None):
        """Apply the plateau rule, then the overfitting rule, for this epoch.

        Args:
            epoch: Zero-based epoch index; messages report ``epoch + 1``.
            logs: Metrics dict; the ``"loss"`` and ``"val_loss"`` entries
                are read, and a missing entry skips the rule that needs it.
        """
        logs = logs or {}
        loss = logs.get("loss")
        val_loss = logs.get("val_loss")

        # ---- rule 1: stop if train loss stops improving ----
        if loss is not None:
            if loss < self.best:
                self.best = loss
                self.wait = 0
                self.best_weights = self.model.get_weights()
            else:
                self.wait += 1
                if self.verbose and self.patience:
                    print(f"[Train no-improve] wait {self.wait}/{self.patience}")
                if self.patience and self.wait >= self.patience:
                    self.stopped_epoch = epoch + 1
                    if self.verbose:
                        print("Stopping (train loss plateau). Restoring best train-loss weights.")
                    self.model.stop_training = True
                    if self.best_weights is not None:
                        self.model.set_weights(self.best_weights)
                    return  # already stopping

        # ---- rule 2: suspected overfitting stop ----
        if loss is None or val_loss is None:
            return

        # Checkpoint the best val_loss on every epoch, warm-up included, so
        # the weights restored later really are the best ones observed.
        if val_loss < self.best_val:
            self.best_val = val_loss
            if self.restore_best_weights:
                self.best_val_weights = self.model.get_weights()

        # Only the gap test is delayed by the warm-up period.
        if (epoch + 1) <= self.warmup:
            return

        gap_abs = val_loss - loss
        # Floor the denominator so a near-zero loss cannot divide by zero.
        gap_rel = gap_abs / max(loss, 1e-8)
        overfit_now = (gap_abs >= self.min_abs_delta) or (gap_rel >= self.min_rel_delta)

        if not overfit_now:
            # The streak must be consecutive: any healthy epoch resets it.
            self.wait_overfit = 0
            return

        self.wait_overfit += 1
        if self.verbose:
            print(f"[Overfit?] epoch {epoch+1}: loss={loss:.4f}, val_loss={val_loss:.4f}, "
                  f"gap_abs={gap_abs:.4f}, gap_rel={gap_rel:.2%} "
                  f"(wait {self.wait_overfit}/{self.overfit_patience})")
        if self.wait_overfit >= self.overfit_patience:
            self.stopped_epoch = epoch + 1
            if self.verbose:
                print("Overfitting detected — stopping and restoring best val_loss weights.")
            self.model.stop_training = True
            if self.restore_best_weights and self.best_val_weights is not None:
                self.model.set_weights(self.best_val_weights)

    def on_train_end(self, logs=None):
        """Report the stopping epoch if one of the rules ended training."""
        if self.stopped_epoch > 0 and self.verbose:
            print(f"Epoch {self.stopped_epoch}: early stopping")

class LossAndErrorPrintingCallback(keras.callbacks.Callback):
    """Report the training and validation loss once per epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Print ``loss`` and ``val_loss`` with six decimals, or ``NA`` if absent.

        Args:
            epoch: Zero-based epoch index; the printed number is ``epoch + 1``.
            logs: Metrics dict; may be ``None`` or lack either key.
        """
        metrics = logs if logs else {}
        rendered = []
        for key in ("loss", "val_loss"):
            value = metrics.get(key)
            rendered.append("NA" if value is None else f"{value:.6f}")
        train_text, val_text = rendered
        print(f"Epoch {epoch+1} — loss: {train_text}, val_loss: {val_text}")

# ========================= Model =========================
def get_model(n_features: int):
    """Build and compile a one-hidden-layer regression network.

    Args:
        n_features: Width of each input row.

    Returns:
        A ``keras.Sequential`` model compiled with the Adam optimizer and
        mean-squared-error loss.
    """
    inputs = keras.layers.Input(shape=(n_features,))
    hidden = keras.layers.Dense(64, activation="relu")
    # Single linear unit: a regression head (a classifier would use softmax).
    head = keras.layers.Dense(1)

    network = keras.Sequential([inputs, hidden, head])
    network.compile(optimizer="adam", loss="mse")
    return network

# ========================= TRAIN =========================
# Expect x_train, y_train to be defined by you:
# x_train: shape (num_samples, num_features)
# y_train: shape (num_samples,)
# If you don't have them yet, load/prepare your data above.

# n_features from your data
# Input width is taken from the second axis of the training data.
n_features = x_train.shape[1]
model = get_model(n_features)

# Early stopping: loosen the overfitting rule relative to the class defaults
# so training can run longer before it is cut off.
cb = EarlyStoppingAtMinLoss(patience=5)
cb.warmup = 10              # start the overfit checks later (default 5)
cb.overfit_patience = 4     # need more consecutive overfit epochs (default 2)
cb.min_abs_delta = 0.07     # larger absolute gap required (default 0.05)
cb.min_rel_delta = 0.35     # larger relative gap required (default 0.20)

# Optional: halve the learning rate when val_loss stops improving.
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.5,
    patience=3,
    min_lr=1e-6,
    verbose=1,
)

# Keras' own progress output is silenced; the callbacks print instead.
history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=64,
    epochs=60,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback(), cb, reduce_lr],
    validation_split=0.2,      # or validation_data=(x_val, y_val)
)


Epoch 1 — loss: 7.702114, val_loss: 7.758357
Epoch 2 — loss: 6.601505, val_loss: 6.592540
Epoch 3 — loss: 5.558546, val_loss: 5.516394
Epoch 4 — loss: 4.583126, val_loss: 4.438117
Epoch 5 — loss: 3.607121, val_loss: 3.433881
Epoch 6 — loss: 2.708570, val_loss: 2.540847
Epoch 7 — loss: 1.949201, val_loss: 1.796958
Epoch 8 — loss: 1.318943, val_loss: 1.239800
Epoch 9 — loss: 0.901448, val_loss: 0.853298
Epoch 10 — loss: 0.639529, val_loss: 0.617327
Epoch 11 — loss: 0.483727, val_loss: 0.490222
Epoch 12 — loss: 0.403558, val_loss: 0.422356
Epoch 13 — loss: 0.365822, val_loss: 0.382997
Epoch 14 — loss: 0.343832, val_loss: 0.363323
Epoch 15 — loss: 0.330520, val_loss: 0.351166
Epoch 16 — loss: 0.321747, val_loss: 0.342800
Epoch 17 — loss: 0.316055, val_loss: 0.336048
Epoch 18 — loss: 0.309998, val_loss: 0.332338
Epoch 19 — loss: 0.304692, val_loss: 0.327520
Epoch 20 — loss: 0.300473, val_loss: 0.324472
Epoch 21 — loss: 0.296417, val_loss: 0.320107
Epoch 22 — loss: 0.292141, val_loss: 0.3174