In [None]:
# tests/test_fabricate_vectors.py
import numpy as np
import pytest
import tensorflow as tf

from IEModules.Helper_Functions import fabricate_vectors_for_f1
from IEModules.Custom_Metrics import CustomNoBackgroundF1Score

@pytest.mark.parametrize(
    "target, with_bg",
    [
        # layouts that include a background column
        (0.30, True),
        (0.55, True),
        (0.87, True),
        # layouts without a background column
        (0.20, False),
        (0.77, False),
        (1.00, False),
    ],
)
def test_fabricated_f1(target, with_bg):
    """The no-background F1 of fabricated vectors must equal `target` (atol 1e-6)."""
    truth, preds = fabricate_vectors_for_f1(
        target, length=200, include_background=with_bg, seed=42
    )
    # num_classes stays 5 for both layouts; the metric is said to strip the
    # background column when present.  NOTE(review): confirm the metric also
    # accepts the (presumably 4-column) input produced when with_bg is False.
    f1_metric = CustomNoBackgroundF1Score(num_classes=5, name="nb_f1")
    f1_metric.update_state(tf.constant(truth), tf.constant(preds))
    observed = f1_metric.result().numpy()
    # tolerate tiny float error in the metric's reduction
    assert np.isclose(observed, target, atol=1e-6), f"wanted {target}, got {observed}"


In [None]:
from IEModules.Helper_Functions import fabricate_vectors_for_f1
from IEModules.Custom_Metrics   import CustomNoBackgroundF1Score
import tensorflow as tf

# Spot check: fabricate vectors aimed at F1 = 0.77 and report the realised score.
yt, yp = fabricate_vectors_for_f1(
    0.77, 200, include_background=True, seed=1
)
metric = CustomNoBackgroundF1Score(num_classes=5)
metric.update_state(tf.constant(yt), tf.constant(yp))
realised_f1 = metric.result().numpy()
print("Realised F1 =", round(realised_f1, 4))


In [None]:
# Helper_Functions.py  (add below the previous helper)
import numpy as np
from typing import Tuple

def fabricate_vectors_for_f1_and_acc(target_f1: float,
                                     target_acc: float,
                                     length: int = 200,
                                     include_background: bool = True,
                                     focus_class: int = 1,
                                     seed: int | None = None,
                                     tol: float = 1e-12
                                     ) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build y_true / y_pred whose **weighted** CustomNoBackgroundF1Score equals
    `target_f1` *and* whose overall accuracy equals `target_acc`.

    Parameters
    ----------
    target_f1, target_acc : floats in (0,1]
    length                : number of rows (>=200 recommended)
    include_background    : 5‑column if True else 4‑column
    focus_class           : which target column (1…4) carries all positives
    seed                  : RNG seed
    tol                   : floating‑point tolerance for the final checks

    Returns
    -------
    y_true, y_pred : float32 arrays shaped [length, num_classes]

    Raises
    ------
    ValueError if the pair (F1,acc) is infeasible with the chosen `length`.
    """
    if not (0 < target_f1 <= 1) or not (0 < target_acc <= 1):
        raise ValueError("F1 and accuracy must lie in (0,1].")
    rng = np.random.default_rng(seed)
    n_classes = 5 if include_background else 4
    pos_col   = focus_class if include_background else focus_class - 1

    # ----------  solve for integer TP  ----------
    denom = 2 * (1 - target_f1)
    TP_est = length * target_f1 * (1 - target_acc) / denom
    TP_int = round(TP_est)

    # slide ±1 around the rounded value until everything fits
    found = False
    for TP in range(max(0, TP_int-2), TP_int+3):
        S  = 2 * TP * (1/target_f1 - 1)             # FP + FN
        TN = int(round(target_acc * length - TP))   # from accuracy

        if (abs(target_acc - (TP+TN)/length) < tol and
            abs(target_f1  - (2*TP)/(2*TP+S)) < tol and
            0 <= TN <= length and
            S.is_integer()):
            FP = int(S // 2)
            FN = int(S - FP)
            if TP + FP + FN + TN == length:
                found = True
                break
    if not found:
        raise ValueError("Chosen F1/accuracy not realisable with length {}"
                         .format(length))

    # ----------  materialise the rows ----------
    y_true = np.zeros((length, n_classes), dtype=np.float32)
    y_pred = np.zeros_like(y_true)
    rows = rng.permutation(length)
    idx = 0
    for _ in range(TP):
        r = rows[idx]; idx += 1
        y_true[r, pos_col] = y_pred[r, pos_col] = 1
    for _ in range(FN):
        r = rows[idx]; idx += 1
        y_true[r, pos_col] = 1
    for _ in range(FP):
        r = rows[idx]; idx += 1
        if include_background:
            y_true[r, 0] = 1
        y_pred[r, pos_col] = 1
    for _ in range(TN):
        r = rows[idx]; idx += 1
        if include_background:
            y_true[r, 0] = 1

    return y_true, y_pred


In [None]:
from IEModules.Custom_Metrics import CustomNoBackgroundF1Score
import tensorflow as tf
from IEModules.Helper_Functions import fabricate_vectors_for_f1_and_acc

# Demo: fabricate vectors for a *feasible* (F1, accuracy) pair and check both.
# (0.78, 0.60) is not realisable at length=200: it would need TP ~ 141.8
# while TP + TN is only 120, so the generator raised ValueError for it.
F1_target, acc_target = 0.80, 0.80
FOCUS_CLASS = 1
yt, yp = fabricate_vectors_for_f1_and_acc(F1_target, acc_target,
                                          length=200,
                                          include_background=True,
                                          focus_class=FOCUS_CLASS,
                                          seed=123)

nb_f1 = CustomNoBackgroundF1Score(num_classes=5)
nb_f1.update_state(tf.constant(yt), tf.constant(yp))
print("F1 :", nb_f1.result().numpy())

# The generator targets (TP + TN) / length on the focus-class column.
# Comparing every cell of the matrix instead also counts the background
# column and the untouched classes, which does not measure that quantity.
focus_col = FOCUS_CLASS  # include_background=True -> column index == class id
accuracy = ((yp[:, focus_col] >= 0.5) == (yt[:, focus_col] == 1)).mean()
print("acc:", accuracy)


In [None]:
class BatchModelCheckpoint(callbacks.ModelCheckpoint):
    """
    A ModelCheckpoint that also saves immediately after the final training
    batch of each epoch (i.e. before validation begins), using the same
    filepath template and arguments as the standard checkpoint.

    The overridden hooks delegate to ``super()`` first, so the parent's own
    behaviour (e.g. batch saving with an integer ``save_freq``) is kept.
    """
    def __init__(self, filepath, steps_per_epoch, logs=None, **kwargs):
        """
        Args:
            filepath: same template you'd pass to ModelCheckpoint
                      (e.g. ".../epoch-{epoch:03d}.keras")
            steps_per_epoch: number of train batches per epoch; the extra
                      save fires when the 0-based batch index equals
                      ``steps_per_epoch - 1``
            logs: unused; kept only so existing positional calls still work
            **kwargs: all the same keyword args you'd pass to ModelCheckpoint
        """
        super().__init__(filepath, **kwargs)
        self.steps_per_epoch = steps_per_epoch
        # Dedicated epoch counter for the extra save.  Do not assign
        # `_current_epoch` here: the parent class may use that name itself.
        self._cur_epoch = 0

    def _implements_train_batch_hooks(self):
        # tf.keras 2.x only dispatches batch hooks to ModelCheckpoint when
        # save_freq is an int; force them on so the end-of-epoch save fires.
        # NOTE(review): presumably unused (and harmless) under Keras 3.
        return True

    def on_epoch_begin(self, epoch, logs=None):
        super().on_epoch_begin(epoch, logs)
        self._cur_epoch = epoch

    def on_train_batch_end(self, batch, logs=None):
        # Keep the parent's batch-frequency saving intact.
        super().on_train_batch_end(batch, logs)
        if batch + 1 == self.steps_per_epoch:
            # Keras 3 signature: epoch, batch, logs
            self._save_model(epoch=self._cur_epoch,
                             batch=batch,
                             logs=logs or {})