In [1]:
import os
# Select TensorFlow's CUDA async GPU allocator. This has to happen before
# `import tensorflow` (next cell); presumably TF reads it when it initialises the GPUs.
os.environ.update(TF_GPU_ALLOCATOR='cuda_malloc_async')

In [2]:
from dataclasses import dataclass
import tensorflow as tf
from typing import Any, Optional
from pathlib import Path
import tempfile
import time
import json
import hashlib

2025-12-26 00:41:48.208212: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-12-26 00:41:48.225644: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-12-26 00:41:48.231227: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-12-26 00:41:48.244102: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
from mobilenetv2ssd.core.config import load_config

In [4]:
# Relative paths of the YAML files merged by load_config (train, model, data, eval).
main_cfg_path, model_cfg_path, data_cfg_path, eval_cfg_path = (
    "configs/train/default.yaml",
    "configs/model/mobilenetv2_ssd_voc.yaml",
    "configs/data/voc_224.yaml",
    "configs/eval/default.yaml",
)

In [5]:
# Load the four config files; the result is indexed by section below
# (e.g. config['checkpoint'], config['data']).
config = load_config(
    main_cfg_path,
    model_cfg_path,
    data_cfg_path,
    eval_cfg_path,
)

In [6]:
# Inspect the checkpoint section consumed by CheckpointManager.
config["checkpoint"]

{'dir': '/mnt/d/dev/MobileNetV2-SSD/checkpoints',
 'keep_last_k': 5,
 'save_every_steps': 500,
 'save_every_epochs': 1,
 'save_last': True,
 'save_best': True,
 'monitor': 'val_map',
 'mode': 'max'}

In [7]:
class CheckpointManager:
    """Save and restore training state with two TensorFlow checkpoint managers.

    All state lives in one ``tf.train.Checkpoint``: model, optimizer, optional
    EMA, the epoch / global-step counters and the best-metric bookkeeping. That
    checkpoint is written by two ``tf.train.CheckpointManager`` instances:

    * ``<dir>/last`` keeps the ``keep_last_k`` most recent checkpoints
      (``save_last``), numbered by global step;
    * ``<dir>/best`` keeps a single checkpoint with the best monitored metric
      (``save_best``). It only exists when ``save_best`` is enabled.

    Only the main node writes (the ``save_*`` methods are no-ops elsewhere).
    Every node can restore.

    Args:
        checkpoint_config: Options dict. ``dir`` is required. Optional keys:
            ``keep_last_k`` (default 5), ``save_every_steps`` (default None = off),
            ``save_every_epochs`` (default 1, None = off), ``save_last`` (default
            True), ``save_best`` (default True), ``monitor`` (default "val_map",
            stored only) and ``mode`` ("max" or "min", default "max").
        model: Keras model to checkpoint.
        optimizer: Optimizer whose state is checkpointed with the model; may be None.
        ema: Optional object checkpointed under the key ``ema``. It must be
            trackable by ``tf.train.Checkpoint``.
        is_main_node: Whether this process may write checkpoints.

    Raises:
        ValueError: If ``save_every_steps`` or ``save_every_epochs`` is <= 0, or if
            ``mode`` is not "max" / "min".
    """

    def __init__(self, checkpoint_config: dict[str,Any], model: tf.keras.Model, optimizer: tf.keras.optimizers.Optimizer, ema: Optional[Any], is_main_node: bool = True):
        self._model = model
        self._optimizer = optimizer

        self._checkpoint_directory = Path(checkpoint_config['dir'])
        self._checkpoint_directory.mkdir(parents = True, exist_ok = True)

        self._keep_last_k = checkpoint_config.get('keep_last_k', 5)

        # None disables step-based saving; otherwise the interval must be positive.
        self._save_every_steps = checkpoint_config.get('save_every_steps', None)
        if self._save_every_steps is not None and self._save_every_steps <= 0:
            raise ValueError("The value has to be greater than 0")

        # None disables epoch-based saving; 0 would divide by zero in should_save_epoch.
        self._save_every_epochs = checkpoint_config.get('save_every_epochs', 1)
        if self._save_every_epochs is not None and self._save_every_epochs <= 0:
            raise ValueError("save_every_epochs has to be greater than 0")

        self._save_last = checkpoint_config.get('save_last', True)
        self._save_best = checkpoint_config.get('save_best', True)
        self._monitor = checkpoint_config.get('monitor', "val_map")
        self._mode = checkpoint_config.get('mode', "max")

        if self._mode not in {"max","min"}:
            raise ValueError("The value for the mode is wrong and should be either 'max' or 'min'")

        # Worst possible metric for the chosen mode: any real metric beats it.
        self._initial_best_metric = float("-inf") if self._mode == "max" else float("inf")

        self._ema = ema

        # Only the main process (rank 0) writes, so parallel workers never race on
        # the same checkpoint files.
        self._is_main = is_main_node

        # Progress counters stored inside the checkpoint so a resume can continue
        # from the saved epoch/step and keep the best-so-far metric.
        self._epoch_var = tf.Variable(0, dtype = tf.int64, trainable = False)
        self._global_step_var = tf.Variable(0, dtype = tf.int64, trainable = False)
        self._best_epoch_var = tf.Variable(-1, dtype = tf.int64, trainable = False)
        self._best_metric_var = tf.Variable(self._initial_best_metric, dtype = tf.float32, trainable = False)

        checkpoint_dict = {
            'model': self._model,
            'epoch': self._epoch_var,
            'global_step': self._global_step_var,
            'best_epoch': self._best_epoch_var,
            'best_metric': self._best_metric_var
        }

        # The optimizer state (e.g. iteration count and slot variables such as Adam
        # moments) has to be saved too; otherwise a resume restarts it from scratch.
        # Older checkpoints without it should still restore, because restores use
        # expect_partial(), and the optimizer then keeps its initial state.
        if self._optimizer is not None:
            checkpoint_dict['optimizer'] = self._optimizer

        if self._ema is not None:
            checkpoint_dict['ema'] = self._ema

        self._checkpoint = tf.train.Checkpoint(**checkpoint_dict)

        # <dir>/last: rolling window of the keep_last_k most recent checkpoints.
        self._last_directory = self._checkpoint_directory / "last"
        self._last_directory.mkdir(parents = True, exist_ok = True)
        self._last_manager = tf.train.CheckpointManager(checkpoint = self._checkpoint, directory = str(self._last_directory), max_to_keep = self._keep_last_k)

        # <dir>/best: a single checkpoint, replaced whenever the metric improves.
        self._best_manager = None
        if self._save_best:
            self._best_directory = self._checkpoint_directory / "best"
            self._best_directory.mkdir(parents = True, exist_ok = True)
            self._best_manager = tf.train.CheckpointManager(checkpoint = self._checkpoint, directory = str(self._best_directory), max_to_keep = 1)

    def restore_latest(self):
        """Restore the newest checkpoint from ``<dir>/last``.

        Returns:
            dict with ``restored`` (bool), ``epoch``, ``global_step``, ``best_metric``
            and ``best_epoch``. When no checkpoint is found, ``restored`` is False
            and the other fields hold the initial values (0, 0, worst metric, -1).
        """
        latest_path = self._find_newest_checkpoint(self._last_manager)
        if latest_path is None:
            return self._initial_state()
        return self._restore_from(latest_path)

    def restore_best(self):
        """Restore the checkpoint from ``<dir>/best``.

        This overwrites the model/optimizer/EMA objects shared with ``restore_latest``.

        Returns:
            Same dict shape as ``restore_latest``. ``restored`` is False when
            ``save_best`` is disabled or no best checkpoint exists yet.
        """
        # The best manager is only created when save_best is enabled.
        if self._best_manager is None:
            return self._initial_state()

        best_path = self._find_newest_checkpoint(self._best_manager)
        if best_path is None:
            return self._initial_state()
        return self._restore_from(best_path)

    def save_last(self, epoch: int, global_step: int):
        """Write a checkpoint to ``<dir>/last`` numbered by ``global_step``.

        Returns:
            The saved checkpoint prefix, or None on non-main nodes.
        """
        if not self._is_main:
            return None

        self._epoch_var.assign(epoch)
        self._global_step_var.assign(global_step)

        return self._last_manager.save(checkpoint_number = global_step)

    def save_best(self, epoch: int, global_step: int, metric: float):
        """Write ``<dir>/best`` if ``metric`` improves on the stored best.

        Returns:
            ``{'is_best': bool, 'path': str | None}``. The result is always not-best
            on non-main nodes or when ``save_best`` is disabled.
        """
        if not self._is_main:
            return {'is_best': False, 'path': None}

        if not self._save_best or self._best_manager is None:
            return {'is_best': False, 'path': None}

        if self._compare_metrics(metric):
            # Record the new best and the progress counters it belongs to.
            self._best_metric_var.assign(metric)
            self._best_epoch_var.assign(epoch)
            self._epoch_var.assign(epoch)
            self._global_step_var.assign(global_step)

            best_path = self._best_manager.save(checkpoint_number = global_step)

            return {'is_best': True, 'path': best_path}

        return {'is_best': False, 'path': None}

    def _compare_metrics(self, metric: float):
        """Return True when ``metric`` strictly improves on the stored best (per ``mode``)."""
        # The best metric is stored as float32. Round the candidate the same way so
        # an unchanged metric (e.g. 0.3 vs float32(0.3) = 0.30000001) is never
        # mistaken for an improvement in "min" mode. NaN never compares as better.
        candidate = float(tf.constant(float(metric), dtype = tf.float32).numpy())
        best = float(self._best_metric_var.numpy())

        if self._mode == "max":
            return candidate > best
        return candidate < best

    def _select_checkpoint(self, p: Path):
        """Extract N from ``ckpt-N.index``; return None if the name is not of that form."""
        stem = p.name.split(".")[0]          # "ckpt-123"
        _, _, number = stem.partition("-")   # "123"
        try:
            return int(number)
        except ValueError:
            return None

    def _find_newest_checkpoint(self, manager):
        """Return the prefix of the newest checkpoint in ``manager``'s directory, or None.

        ``manager.latest_checkpoint`` is captured when the manager is constructed.
        It goes stale when another instance (e.g. the writer on the main node)
        saves later. So the on-disk state file is re-read first, and
        ``ckpt-N.index`` files are scanned only as a last resort.
        """
        directory = str(manager.directory)

        state_path = tf.train.latest_checkpoint(directory)
        if state_path is not None:
            return state_path

        numbered = []
        for index_file in Path(directory).glob("ckpt-*.index"):
            number = self._select_checkpoint(index_file)
            if number is not None:
                numbered.append((number, index_file))

        if not numbered:
            return None

        _, newest = max(numbered, key = lambda item: item[0])
        # "ckpt-12.index" -> "ckpt-12": restore() expects the prefix, not a file.
        return str(newest.with_suffix(""))

    def _initial_state(self):
        """State dict reported when nothing could be restored."""
        return {'restored': False, 'epoch': 0, 'global_step': 0, 'best_metric': self._initial_best_metric, 'best_epoch': -1}

    def _restore_from(self, checkpoint_path: str):
        """Restore ``checkpoint_path`` into the tracked objects and report the restored counters."""
        # expect_partial() silences warnings about values in the file that have no
        # matching object here (e.g. an 'ema' entry when this instance has none).
        self._checkpoint.restore(checkpoint_path).expect_partial()

        return {
            'restored': True,
            'epoch': int(self._epoch_var.numpy()),
            'global_step': int(self._global_step_var.numpy()),
            'best_metric': float(self._best_metric_var.numpy()),
            'best_epoch': int(self._best_epoch_var.numpy()),
        }

    def should_save_step(self, global_step: int):
        """Return True when a "last" checkpoint is due at ``global_step``.

        That is when ``save_last`` is on, ``save_every_steps`` is set and
        ``global_step`` is a multiple of it.
        """
        if not self._save_last or self._save_every_steps is None:
            return False
        return global_step % self._save_every_steps == 0

    def should_save_epoch(self,epoch: int):
        """Return True when a "last" checkpoint is due for ``epoch``.

        That is when ``(epoch + 1)`` is a multiple of ``save_every_epochs``, so
        ``epoch`` is treated as a 0-based index. Returns False when ``save_last``
        is off or ``save_every_epochs`` is None.
        """
        if not self._save_last or self._save_every_epochs is None:
            return False
        return (epoch + 1) % self._save_every_epochs == 0

## Testing the Checkpoint Manager

In [8]:
# Regression test: a CheckpointManager created BEFORE another instance writes to the
# same directory must still find and restore the newest "best" checkpoint.

# 1) Fresh temporary checkpoint root and the config shared by writer and reader.
# NOTE(review): the mkdtemp() directory is never cleaned up.
run_dir = Path(tempfile.mkdtemp()) / "checkpoints"
cfg = {
    "dir": str(run_dir),
    "keep_last_k": 2,
    "save_last": True,
    "save_best": True,
    "save_every_steps": None,
    "save_every_epochs": 1,
    "monitor": "val_map",
    "mode": "max",
}

# 2) Objects for ckpt (writer)
model = tf.keras.Sequential([tf.keras.layers.Input((4,)), tf.keras.layers.Dense(1)])
opt = tf.keras.optimizers.Adam(1e-3)

_ = model(tf.zeros([1,4]))  # build vars

ckpt = CheckpointManager(cfg, model, opt, ema=None, is_main_node=True)

# 3) Objects for ckpt2 (reader) -- created BEFORE best exists (this is the stale case)
model2 = tf.keras.Sequential([tf.keras.layers.Input((4,)), tf.keras.layers.Dense(1)])
opt2 = tf.keras.optimizers.Adam(1e-3)
_ = model2(tf.zeros([1,4]))  # build vars

ckpt2 = CheckpointManager(cfg, model2, opt2, ema=None, is_main_node=True)

# 4) Save a couple best checkpoints (writer only)
r1 = ckpt.save_best(epoch=0, global_step=0, metric=0.2)
r2 = ckpt.save_best(epoch=1, global_step=1, metric=0.1)  # worse -> should not save
r3 = ckpt.save_best(epoch=2, global_step=2, metric=0.3)  # better -> should save

print("save_best returns:")
print(" r1 =", r1)
print(" r2 =", r2)
print(" r3 =", r3)

# 5) Debug: show stale latest_checkpoint state (expected ckpt sees it, ckpt2 might not)
#    ckpt2's tf manager captured its state at construction, so it reports None here.
print("\nlatest_checkpoint (writer ckpt):", ckpt._best_manager.latest_checkpoint)
print("latest_checkpoint (stale ckpt2):", ckpt2._best_manager.latest_checkpoint)

# 6) Restore best using stale ckpt2 (this is what failed before)
state_best = ckpt2.restore_best()
print("\nrestore_best() state:", state_best)

# 7) Assertions
assert r2["is_best"] is False
assert r3["is_best"] is True

assert state_best["restored"] is True
assert abs(state_best["best_metric"] - 0.3) < 1e-6  # best_metric is stored as float32
assert state_best["best_epoch"] == 2

I0000 00:00:1766727710.396967    3054 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1766727710.474114    3054 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1766727710.474178    3054 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1766727710.475750    3054 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1766727710.475822    3054 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:0

save_best returns:
 r1 = {'is_best': True, 'path': '/tmp/tmpfrpuiuo7/checkpoints/best/ckpt-0'}
 r2 = {'is_best': False, 'path': None}
 r3 = {'is_best': True, 'path': '/tmp/tmpfrpuiuo7/checkpoints/best/ckpt-2'}

latest_checkpoint (writer ckpt): /tmp/tmpfrpuiuo7/checkpoints/best/ckpt-2
latest_checkpoint (stale ckpt2): None

restore_best() state: {'restored': True, 'epoch': 2, 'global_step': 2, 'best_metric': 0.30000001192092896, 'best_epoch': 2}


In [9]:
# A non-main node (is_main_node=False) must not write: save_last returns None and
# save_best reports not-best, even for a metric that would otherwise be the best.
# NOTE(review): the mkdtemp() directory is never cleaned up; the next cell reuses `cfg`.
run_dir = Path(tempfile.mkdtemp()) / "checkpoints"
cfg = {"dir": str(run_dir), "keep_last_k": 2, "save_last": True, "save_best": True,
       "save_every_steps": None, "save_every_epochs": 1, "monitor": "val_map", "mode": "max"}

model = tf.keras.Sequential([tf.keras.layers.Input((4,)), tf.keras.layers.Dense(1)])
opt = tf.keras.optimizers.Adam(1e-3)
_ = model(tf.zeros([1,4]))

ckpt_worker = CheckpointManager(cfg, model, opt, ema=None, is_main_node=False)

p_last = ckpt_worker.save_last(epoch=0, global_step=10)
p_best = ckpt_worker.save_best(epoch=0, global_step=10, metric=0.9)

print("worker save_last returned:", p_last)
print("worker save_best returned:", p_best)

# Expectations:
assert p_last is None
assert p_best["is_best"] is False and p_best["path"] is None

# Also expect no checkpoint files exist
# (the last/ and best/ folders themselves are created by __init__ on every node;
#  their contents are only printed here, not asserted)
last_dir = Path(cfg["dir"]) / "last"
best_dir = Path(cfg["dir"]) / "best"
print("last_dir files:", list(last_dir.glob("*")))
print("best_dir files:", list(best_dir.glob("*")))

worker save_last returned: None
worker save_best returned: {'is_best': False, 'path': None}
last_dir files: []
best_dir files: []


In [10]:
# The main rank writes one "last" and one "best" checkpoint; a separate non-main
# instance then restores both from disk.
# NOTE: reuses `cfg` (and its temp directory) from the previous cell; run that first.
modelM = tf.keras.Sequential([tf.keras.layers.Input((4,)), tf.keras.layers.Dense(1)])
optM = tf.keras.optimizers.Adam(1e-3)
_ = modelM(tf.zeros([1,4]))
ckpt_main = CheckpointManager(cfg, modelM, optM, ema=None, is_main_node=True)

# save_last runs before save_best, so the "last" checkpoint still holds the
# initial best_metric/best_epoch (-inf / -1), as the printed restore_latest shows.
p = ckpt_main.save_last(epoch=5, global_step=777)
b = ckpt_main.save_best(epoch=5, global_step=777, metric=0.42)

print("main saved last:", p)
print("main saved best:", b)

# Reader (worker rank) - should NOT write, but CAN restore
modelW = tf.keras.Sequential([tf.keras.layers.Input((4,)), tf.keras.layers.Dense(1)])
optW = tf.keras.optimizers.Adam(1e-3)
_ = modelW(tf.zeros([1,4]))
ckpt_worker = CheckpointManager(cfg, modelW, optW, ema=None, is_main_node=False)

# Both restores write into the same modelW/optW objects; the second one wins.
state_last = ckpt_worker.restore_latest()
state_best = ckpt_worker.restore_best()

print("worker restore_latest:", state_last)
print("worker restore_best:", state_best)

# Expectations:
assert state_last["restored"] is True
assert state_last["epoch"] == 5
assert state_last["global_step"] == 777

assert state_best["restored"] is True
assert abs(state_best["best_metric"] - 0.42) < 1e-6  # best_metric is stored as float32
assert state_best["best_epoch"] == 5

main saved last: /tmp/tmpfi6xs4ti/checkpoints/last/ckpt-777
main saved best: {'is_best': True, 'path': '/tmp/tmpfi6xs4ti/checkpoints/best/ckpt-777'}
worker restore_latest: {'restored': True, 'epoch': 5, 'global_step': 777, 'best_metric': -inf, 'best_epoch': -1}
worker restore_best: {'restored': True, 'epoch': 5, 'global_step': 777, 'best_metric': 0.41999998688697815, 'best_epoch': 5}


In [11]:
# Inspect the data section; its fields feed the run-directory fingerprint below.
config["data"]

{'dataset_name': 'voc',
 'root': '/mnt/d/dev/MobileNetV2-SSD/datasets/VOCdevkit',
 'train_split': 'trainval',
 'val_split': 'test',
 'input_size': [224, 224],
 'num_workers': 4,
 'shuffle_buffer': 1000,
 'prefetch_batches': 2,
 'augment': {'random_flip': True,
  'random_flip_prob': 0.5,
  'random_crop': True,
  'min_crop_iou_choices': [0.1, 0.3, 0.5, 0.7, 0.9],
  'min_crop_scale': 0.3,
  'max_crop_scale': 1.0,
  'photometric_distort': True,
  'photometric_distort_prob': 0.5},
 'normalization': {'mean': [0.485, 0.456, 0.406],
  'std': [0.229, 0.224, 0.225]},
 'classes_file': '/mnt/d/dev/MobileNetV2-SSD/configs/data/voc_classes.txt'}

In [12]:
def _create_checkpoint_directory_fingerprint(config: dict[str,Any]):
    model_config = config['model']
    dataset_config = config['data']
    train_config = config['train']

    fingerprint_config = {
        'model_backbone': model_config.get('backbone',''),
        'num_classes': model_config.get('num_classes',0),
        'priors': model_config.get('priors',{}),
        'dataset_name': dataset_config.get('dataset_name', ''),
        'dataset_augmentation': dataset_config.get('augment', {}),
        'dataset_normalization': dataset_config.get('normalization', {}),
        'training_batch_size': train_config.get('batch_size', 0),
        'training_epochs': train_config.get('epochs', 0),
        'training_optimizer_name': train_config['optimizer'].get('name', ''),
        'training_optimizer_lr': train_config['optimizer'].get('lr', 0.0),
        'training_optimizer_weight_decay': train_config['optimizer'].get('weight_decay', 0.0),
        'training_scheduler_params':train_config.get('scheduler',{})
    }

    config_json = json.dumps(fingerprint_config,sort_keys = True, separators=(",",":"))
    hash_object = hashlib.sha256(config_json.encode('utf-8'))
    hex_digest = hash_object.hexdigest()[:10]

    file_slug = f"{model_config['name']}_{dataset_config['dataset_name']}_img{dataset_config['input_size'][0]}_bs{fingerprint_config['training_batch_size']}_lr{fingerprint_config['training_optimizer_lr']:.2e}_{train_config['scheduler']['name']}"
    file_name = f"{file_slug}_{hex_digest}"

    root_dir =  config['checkpoint']['dir']
    run_dir = Path(root_dir) / file_name
    
    return {
        'dir': str(run_dir),
        'keep_last_k': config['checkpoint'].get('keep_last_k', 1),
        'save_every_steps': config['checkpoint'].get('save_every_steps', 200),
        'save_every_epochs': config['checkpoint'].get('save_every_epochs', 1),
        'save_last': config['checkpoint'].get('save_last', True),
        'save_best': config['checkpoint'].get('save_best', True),
        'monitor': config['checkpoint'].get('monitor', 'val_map'),
        'mode': config['checkpoint'].get('mode', 'max')
    }

def build_checkpoint_manager(config: dict[str,Any], model: tf.keras.Model, optimizer: tf.keras.optimizers.Optimizer, ema: Optional[Any], is_main_node: bool = True):
    """Create a CheckpointManager rooted in this run's fingerprinted directory.

    The directory and checkpoint options come from
    ``_create_checkpoint_directory_fingerprint(config)``. The remaining arguments
    are forwarded unchanged to ``CheckpointManager``.
    """
    return CheckpointManager(
        _create_checkpoint_directory_fingerprint(config),
        model,
        optimizer,
        ema=ema,
        is_main_node=is_main_node,
    )
    

In [13]:
# NOTE: reuses modelM/optM from an earlier cell, and creates the fingerprinted run
# directory (plus last/, and best/ when save_best is on) under config['checkpoint']['dir'].
build_checkpoint_manager(config, modelM, optM, ema=None, is_main_node=True)

<__main__.CheckpointManager at 0x7cbc87fe4340>