In [1]:
import os
# Set the TF_GPU_ALLOCATOR environment variable in the cell that runs before
# the TensorFlow import cell.
# NOTE(review): presumably TensorFlow only honours this if it is set before the
# import / first GPU initialisation - confirm before moving this cell.
os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async'

In [2]:
import tensorflow as tf
from typing import Any, Optional
from contextlib import contextmanager

2025-12-17 20:04:57.016692: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-12-17 20:04:57.036066: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-12-17 20:04:57.041524: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-12-17 20:04:57.056202: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
from mobilenetv2ssd.core.config import load_config

In [4]:
# YAML config files handed to load_config() in the next cell, in this order:
# training, model, data, evaluation.
# NOTE(review): paths are relative - presumably resolved against the notebook's
# working directory; confirm against load_config.
main_cfg_path = "configs/train/default.yaml"
model_cfg_path = "configs/model/mobilenetv2_ssd_voc.yaml"
data_cfg_path = "configs/data/voc_224.yaml"
eval_cfg_path = "configs/eval/default.yaml"

In [5]:
config = load_config(main_cfg_path,model_cfg_path,data_cfg_path,eval_cfg_path)

In [6]:
config['train']['ema']

{'enabled': True,
 'decay': 0.9,
 'warmup_steps': 0,
 'update_every': 1,
 'eval_use_ema': True}

In [7]:
class EMA(tf.Module):
    """Exponential moving average (EMA) of a Keras model's trainable variables.

    For every entry of ``model.trainable_variables`` a non-trainable shadow
    ``tf.Variable`` is created, initialised to the model's value at
    construction time. ``update()`` blends the model's current values into the
    shadow copies, ``apply_to()`` swaps the shadow values into a model (keeping
    a backup of the raw values) and ``restore()`` swaps the raw values back.
    ``eval_context()`` wraps the apply/restore pair in a context manager.

    Recognised ``ema_config`` keys (all optional):
        decay (float, default 0.999): upper bound on the decay used in ``update()``.
        enabled (bool, default True): master switch for ``update()`` and eval swapping.
        warmup_steps (int, default 0): steps below this value never update.
        update_every (int, default 1): when > 1, only steps divisible by it update.
        eval_use_ema (bool, default True): whether ``eval_context()`` swaps weights.
        use_num_updates (bool, default False): stored on the instance but not
            read by any method of this class.

    Raises:
        ValueError: if ``model`` is None or has not been built yet.
    """

    def __init__(self, model: tf.keras.Model, ema_config: dict[str, Any]):
        super().__init__(name="EMA")

        if model is None:
            raise ValueError("EMA requires a built model instance (model cannot be None).")
        if not model.built:
            raise ValueError("Build/call the model before creating EMA, otherwise trainable_variables is empty.")

        self._decay = float(ema_config.get('decay', 0.999))
        self._enabled = bool(ema_config.get('enabled', True))

        self._warmup_steps = int(ema_config.get('warmup_steps', 0))
        self._update_every = int(ema_config.get('update_every', 1))
        self._eval_use_ema = bool(ema_config.get('eval_use_ema', True))

        # True between a successful apply_to() and the matching restore().
        self._is_applied = False

        # Number of EMA updates performed so far; drives the decay ramp in update().
        self._num_updates = tf.Variable(0, dtype=tf.int64, trainable=False, name="num_updates")
        # NOTE(review): read from the config but never consulted below - the
        # ramp in update() is applied unconditionally.
        self._use_num_updates = bool(ema_config.get('use_num_updates', False))

        # The variables being averaged, captured once at construction time.
        self._model_vars = list(model.trainable_variables)

        # One shadow variable per model variable, same dtype, starting at the
        # model's current value.
        self._ema_vars = [
            tf.Variable(
                tf.convert_to_tensor(variable),
                dtype=variable.dtype,
                trainable=False,
                name=f"{variable.name.replace(':', '_')}_ema",
            )
            for variable in self._model_vars
        ]

        # Raw model values saved by apply_to(); None when nothing is applied.
        self._backup = None

    def _target_variables(self, model: tf.keras.Model | None) -> list:
        """Return the variables to swap: ``model``'s trainable variables, or the
        ones captured at construction time when ``model`` is None."""
        if model is None:
            return self._model_vars
        return list(model.trainable_variables)

    def reset(self):
        """Re-initialise the EMA from the tracked model's current weights.

        Copies every tracked model variable into its shadow variable, zeroes
        the update counter and discards any backup held by ``apply_to()``.
        The applied flag is cleared together with the backup; otherwise a
        reset issued while EMA weights are applied would leave the instance in
        a state where both ``apply_to()`` and ``restore()`` raise.

        Raises:
            ValueError: if shadow and model variable lists differ in length.
        """
        if len(self._ema_vars) != len(self._model_vars):
            raise ValueError("EMA values are not 1:1 check the length of the variables passed to the EMA.")

        for ema_var, model_var in zip(self._ema_vars, self._model_vars):
            ema_var.assign(model_var)

        # The shadow copy now equals the model, so the ramp starts over.
        self._num_updates.assign(0)

        # Drop the backup and the flag together so the two never disagree.
        self._backup = None
        self._is_applied = False

    def should_update(self, step: int):
        """Return True when ``update(step)`` should actually modify the EMA.

        False while ``step`` is below ``warmup_steps``, when ``update_every`` is
        greater than 1 and ``step`` is not a multiple of it, or when EMA is
        disabled.
        """
        if step < self._warmup_steps:
            return False

        # The `> 1` guard also keeps a configured 0 from reaching the modulo.
        if self._update_every > 1 and step % self._update_every != 0:
            return False

        if not self._enabled:
            return False

        return True

    def should_apply_during_eval(self):
        """Return True when evaluation should run with the EMA weights."""
        return self._enabled and self._eval_use_ema and len(self._ema_vars) > 0

    def update(self, step: int):
        """Blend the tracked model's current weights into the shadow variables.

        Does nothing unless ``should_update(step)`` is True. The decay actually
        used is ``min(decay, (1 + n) / (10 + n))`` where ``n`` is the number of
        updates performed so far, so early updates follow the model closely.
        Each shadow variable becomes ``d * ema + (1 - d) * model`` and the
        update counter is incremented by one.
        """
        if not self.should_update(step):
            return

        num_updates = tf.cast(self._num_updates, tf.float32)

        # Ramp that grows towards 1 with the number of updates performed.
        ramp = (1 + num_updates) / (10 + num_updates)
        decay_rate = tf.minimum(tf.constant(self._decay, tf.float32), ramp)

        for ema_var, model_var in zip(self._ema_vars, self._model_vars, strict=True):
            # Cast per variable so the arithmetic happens in the variable's dtype.
            decay_factor = tf.cast(decay_rate, ema_var.dtype)
            inverse_decay_factor = tf.cast(1.0, ema_var.dtype) - decay_factor
            ema_var.assign(decay_factor * ema_var + inverse_decay_factor * model_var)

        self._num_updates.assign_add(1)

    def apply_to(self, model: tf.keras.Model | None = None):
        """Swap the EMA weights into a model, backing up its current weights.

        Args:
            model: target model; when None the variables captured at
                construction time are used.

        Raises:
            ValueError: if the target's variable count, dtypes or shapes do not
                match the shadow variables, or if EMA weights are already
                applied (``restore()`` must be called first).
        """
        model_variables = self._target_variables(model)

        if len(self._ema_vars) != len(model_variables):
            raise ValueError("EMA vars and model vars mismatch")

        # All checks run before anything is written. The message only mentions
        # dtypes but the check covers shapes as well.
        for ema_var, model_var in zip(self._ema_vars, model_variables):
            if ema_var.dtype != model_var.dtype or ema_var.shape != model_var.shape:
                raise ValueError("Dtypes not same for target model and EMA saved copy")

        # A second apply would overwrite the backup with EMA values and lose
        # the raw weights, so refuse it.
        if (self._backup is not None) or (self._is_applied):
            raise ValueError("Cannot apply since backup exists, restore() needs to be called")

        self._backup = [tf.convert_to_tensor(var) for var in model_variables]

        for ema_var, model_var in zip(self._ema_vars, model_variables):
            model_var.assign(ema_var)

        self._is_applied = True

    def restore(self, model: tf.keras.Model | None = None):
        """Write the backup taken by ``apply_to()`` back into a model.

        Args:
            model: target model; when None the variables captured at
                construction time are used. The backup is not tied to a model
                object, so pass the same model that was given to ``apply_to()``.

        Raises:
            ValueError: if there is nothing to restore, or if the target's
                variable count, dtypes or shapes do not match the backup.
        """
        model_variables = self._target_variables(model)

        if (self._backup is None) or (not self._is_applied):
            raise ValueError("Cannot restore since backup doesnt exist, apply_to() needs to be called")

        if len(self._backup) != len(model_variables):
            raise ValueError("Backup and model vars mismatch")

        # Validate every pair before assigning any, so a mismatch cannot leave
        # the model with a mixture of restored and EMA weights.
        for backup_var, model_var in zip(self._backup, model_variables):
            if backup_var.dtype != model_var.dtype or backup_var.shape != model_var.shape:
                raise ValueError("Dtypes not same for target model and Backup copy")

        for backup_var, model_var in zip(self._backup, model_variables):
            model_var.assign(backup_var)

        self._backup = None
        self._is_applied = False

    @contextmanager
    def eval_context(self, model: tf.keras.Model | None = None):
        """Context manager that runs its body with EMA weights when configured.

        When ``should_apply_during_eval()`` is True the EMA weights are applied
        on entry and the raw weights restored on exit, even if the body raises.
        Otherwise the body runs with the model untouched.
        """
        use_ema = self.should_apply_during_eval()

        if use_ema:
            # Outside the try block: if applying fails there is nothing to restore.
            self.apply_to(model)

        try:
            yield
        finally:
            if use_ema:
                self.restore(model)

In [8]:
model = tf.keras.Sequential([tf.keras.layers.Input((4,)), tf.keras.layers.Dense(1)])

I0000 00:00:1766019900.484724    9216 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1766019900.578930    9216 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1766019900.579033    9216 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1766019900.580669    9216 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1766019900.580753    9216 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:0

In [9]:
ema = EMA(model = model, ema_config = config['train']['ema'])

In [10]:
ema.should_update(500)

True

In [11]:
tf.convert_to_tensor(model.trainable_variables[0])

<tf.Tensor: shape=(4, 1), dtype=float32, numpy=
array([[ 0.41609168],
       [ 1.0812318 ],
       [-0.3979873 ],
       [-1.0829793 ]], dtype=float32)>

## Factory Pattern For EMA

In [12]:
def get_ema_config(config: dict[str,Any]) -> dict[str, Any]:
    """Extract the EMA options from the full config, filling in defaults.

    Reads ``config['train']['ema']``. A missing section, or one that is
    explicitly null (``None``), is treated as empty so that every key falls
    back to its default instead of raising ``AttributeError``.

    Args:
        config: the loaded configuration mapping.

    Returns:
        A new dict with the keys ``enabled``, ``decay``, ``warmup_steps``,
        ``update_every``, ``eval_use_ema`` and ``use_num_updates``. The last
        one is forwarded because ``EMA.__init__`` reads it from its config.
    """
    # `or {}` covers both an absent key and a section that is present but null.
    train_config = config.get('train') or {}
    ema_options = train_config.get('ema') or {}

    return {
        'enabled': ema_options.get('enabled', True),
        'decay': ema_options.get('decay', 0.9),
        'warmup_steps': ema_options.get('warmup_steps', 0),
        'update_every': ema_options.get('update_every', 1),
        'eval_use_ema': ema_options.get('eval_use_ema', True),
        'use_num_updates': ema_options.get('use_num_updates', False),
    }

In [13]:
def build_ema(config: dict[str,Any], model: tf.keras.Model):
    """Factory: create an EMA for `model` from the full configuration.

    The EMA options are pulled out of `config` with get_ema_config() and
    passed, together with `model`, to the EMA constructor.
    """
    return EMA(model = model, ema_config = get_ema_config(config))

In [14]:
build_ema(config,model)._ema_vars

ListWrapper([<tf.Variable 'kernel_ema:0' shape=(4, 1) dtype=float32, numpy=
array([[ 0.41609168],
       [ 1.0812318 ],
       [-0.3979873 ],
       [-1.0829793 ]], dtype=float32)>, <tf.Variable 'bias_ema:0' shape=(1,) dtype=float32, numpy=array([0.], dtype=float32)>])

## Testing EMA

In [15]:
# Small model shared by the EMA tests below: a Dense layer with bias followed
# by a bias-free Dense layer.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(4,)),
    tf.keras.layers.Dense(3, use_bias=True),
    tf.keras.layers.Dense(2, use_bias=False),
])
# Call the model once on a dummy batch before handing it to EMA, whose
# constructor rejects models that are not built.
_ = model(tf.zeros((1, 4)))

In [16]:
def set_all_trainable_vars_to(model, value: float):
    """Overwrite every trainable variable of `model` with the constant `value`."""
    for variable in model.trainable_variables:
        constant_fill = tf.ones_like(variable) * value
        variable.assign(constant_fill)

def snapshot_vars(model):
    """Return a list with a tf.identity copy of each trainable variable of `model`."""
    copies = []
    for variable in model.trainable_variables:
        copies.append(tf.identity(variable))
    return copies

def max_abs_diff(vars_a, vars_b):
    """Return the largest element-wise absolute difference between paired tensors.

    Args:
        vars_a: iterable of tensors/variables.
        vars_b: iterable of tensors/variables, paired positionally with `vars_a`.

    Returns:
        The maximum of ``|a - b|`` over all pairs as a Python float, or 0.0
        when both inputs are empty.

    Raises:
        ValueError: if the two inputs have different lengths. Truncating to the
            shorter one would silently skip variables and under-report the
            difference.
    """
    return max(
        (float(tf.reduce_max(tf.abs(a - b))) for a, b in zip(vars_a, vars_b, strict=True)),
        default=0.0,
    )

### Test 1: `reset()` copies the model's current weights into the EMA

In [17]:
ema = EMA(model = model, ema_config = config['train']['ema'])

In [18]:
set_all_trainable_vars_to(model, 5.0)

In [19]:
ema.reset()

In [20]:
set_all_trainable_vars_to(model, 7.0)

In [21]:
ema._ema_vars

ListWrapper([<tf.Variable 'kernel_ema:0' shape=(4, 3) dtype=float32, numpy=
array([[5., 5., 5.],
       [5., 5., 5.],
       [5., 5., 5.],
       [5., 5., 5.]], dtype=float32)>, <tf.Variable 'bias_ema:0' shape=(3,) dtype=float32, numpy=array([5., 5., 5.], dtype=float32)>, <tf.Variable 'kernel_ema:0' shape=(3, 2) dtype=float32, numpy=
array([[5., 5.],
       [5., 5.],
       [5., 5.]], dtype=float32)>])

In [22]:
[tf.identity(var) for var in model.trainable_variables]

[<tf.Tensor: shape=(4, 3), dtype=float32, numpy=
 array([[7., 7., 7.],
        [7., 7., 7.],
        [7., 7., 7.],
        [7., 7., 7.]], dtype=float32)>,
 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([7., 7., 7.], dtype=float32)>,
 <tf.Tensor: shape=(3, 2), dtype=float32, numpy=
 array([[7., 7.],
        [7., 7.],
        [7., 7.]], dtype=float32)>]

### Test 2: a single `update()` moves the EMA toward the model weights

In [23]:
ema = EMA(model = model, ema_config = config['train']['ema'])

In [24]:
set_all_trainable_vars_to(model, 0.0)

In [25]:
ema.reset()

In [26]:
ema._ema_vars

ListWrapper([<tf.Variable 'kernel_ema:0' shape=(4, 3) dtype=float32, numpy=
array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)>, <tf.Variable 'bias_ema:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>, <tf.Variable 'kernel_ema:0' shape=(3, 2) dtype=float32, numpy=
array([[0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)>])

In [27]:
set_all_trainable_vars_to(model, 10.0)

In [28]:
ema.update(step = 0)

In [29]:
ema._ema_vars

ListWrapper([<tf.Variable 'kernel_ema:0' shape=(4, 3) dtype=float32, numpy=
array([[9., 9., 9.],
       [9., 9., 9.],
       [9., 9., 9.],
       [9., 9., 9.]], dtype=float32)>, <tf.Variable 'bias_ema:0' shape=(3,) dtype=float32, numpy=array([9., 9., 9.], dtype=float32)>, <tf.Variable 'kernel_ema:0' shape=(3, 2) dtype=float32, numpy=
array([[9., 9.],
       [9., 9.],
       [9., 9.]], dtype=float32)>])

### Test 3: `apply_to()` / `restore()` round trip

In [30]:
ema = EMA(model = model, ema_config = config['train']['ema'])

In [31]:
ema._ema_vars

ListWrapper([<tf.Variable 'kernel_ema:0' shape=(4, 3) dtype=float32, numpy=
array([[10., 10., 10.],
       [10., 10., 10.],
       [10., 10., 10.],
       [10., 10., 10.]], dtype=float32)>, <tf.Variable 'bias_ema:0' shape=(3,) dtype=float32, numpy=array([10., 10., 10.], dtype=float32)>, <tf.Variable 'kernel_ema:0' shape=(3, 2) dtype=float32, numpy=
array([[10., 10.],
       [10., 10.],
       [10., 10.]], dtype=float32)>])

In [32]:
set_all_trainable_vars_to(model, 2.0)

In [33]:
[tf.identity(var) for var in model.trainable_variables]

[<tf.Tensor: shape=(4, 3), dtype=float32, numpy=
 array([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]], dtype=float32)>,
 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([2., 2., 2.], dtype=float32)>,
 <tf.Tensor: shape=(3, 2), dtype=float32, numpy=
 array([[2., 2.],
        [2., 2.],
        [2., 2.]], dtype=float32)>]

In [34]:
ema.reset()

In [35]:
ema._ema_vars

ListWrapper([<tf.Variable 'kernel_ema:0' shape=(4, 3) dtype=float32, numpy=
array([[2., 2., 2.],
       [2., 2., 2.],
       [2., 2., 2.],
       [2., 2., 2.]], dtype=float32)>, <tf.Variable 'bias_ema:0' shape=(3,) dtype=float32, numpy=array([2., 2., 2.], dtype=float32)>, <tf.Variable 'kernel_ema:0' shape=(3, 2) dtype=float32, numpy=
array([[2., 2.],
       [2., 2.],
       [2., 2.]], dtype=float32)>])

In [36]:
set_all_trainable_vars_to(model, 9.0)

In [37]:
[tf.identity(var) for var in model.trainable_variables]

[<tf.Tensor: shape=(4, 3), dtype=float32, numpy=
 array([[9., 9., 9.],
        [9., 9., 9.],
        [9., 9., 9.],
        [9., 9., 9.]], dtype=float32)>,
 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([9., 9., 9.], dtype=float32)>,
 <tf.Tensor: shape=(3, 2), dtype=float32, numpy=
 array([[9., 9.],
        [9., 9.],
        [9., 9.]], dtype=float32)>]

In [38]:
ema.apply_to()

In [39]:
[tf.identity(var) for var in model.trainable_variables]

[<tf.Tensor: shape=(4, 3), dtype=float32, numpy=
 array([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]], dtype=float32)>,
 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([2., 2., 2.], dtype=float32)>,
 <tf.Tensor: shape=(3, 2), dtype=float32, numpy=
 array([[2., 2.],
        [2., 2.],
        [2., 2.]], dtype=float32)>]

In [40]:
ema._ema_vars

ListWrapper([<tf.Variable 'kernel_ema:0' shape=(4, 3) dtype=float32, numpy=
array([[2., 2., 2.],
       [2., 2., 2.],
       [2., 2., 2.],
       [2., 2., 2.]], dtype=float32)>, <tf.Variable 'bias_ema:0' shape=(3,) dtype=float32, numpy=array([2., 2., 2.], dtype=float32)>, <tf.Variable 'kernel_ema:0' shape=(3, 2) dtype=float32, numpy=
array([[2., 2.],
       [2., 2.],
       [2., 2.]], dtype=float32)>])

In [41]:
ema.restore()

In [42]:
[tf.identity(var) for var in model.trainable_variables]

[<tf.Tensor: shape=(4, 3), dtype=float32, numpy=
 array([[9., 9., 9.],
        [9., 9., 9.],
        [9., 9., 9.],
        [9., 9., 9.]], dtype=float32)>,
 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([9., 9., 9.], dtype=float32)>,
 <tf.Tensor: shape=(3, 2), dtype=float32, numpy=
 array([[9., 9.],
        [9., 9.],
        [9., 9.]], dtype=float32)>]

In [43]:
ema._ema_vars

ListWrapper([<tf.Variable 'kernel_ema:0' shape=(4, 3) dtype=float32, numpy=
array([[2., 2., 2.],
       [2., 2., 2.],
       [2., 2., 2.],
       [2., 2., 2.]], dtype=float32)>, <tf.Variable 'bias_ema:0' shape=(3,) dtype=float32, numpy=array([2., 2., 2.], dtype=float32)>, <tf.Variable 'kernel_ema:0' shape=(3, 2) dtype=float32, numpy=
array([[2., 2.],
       [2., 2.],
       [2., 2.]], dtype=float32)>])

### Test 4: `restore()` without `apply_to()`, and a second `apply_to()`, both raise

In [44]:
ema = EMA(model = model, ema_config = config['train']['ema'])

In [45]:
set_all_trainable_vars_to(model, 1.0)

In [46]:
ema.reset()

In [47]:
# restore() is called without a preceding apply_to(), which the EMA class
# rejects; the exception is caught so its type and message are printed
# instead of a traceback.
try:
    ema.restore()
except Exception as e:
    print(f"Error Type: {type(e).__name__}, Message : {e}")

Error Type: ValueError, Message : Cannot restore since backup doesnt exist, apply_to() needs to be called


In [48]:
# Fresh EMA for the double-apply check: set every model weight to 1.0 and
# re-sync the EMA to those weights.
ema = EMA(model = model, ema_config = config['train']['ema'])
set_all_trainable_vars_to(model, 1.0)
ema.reset()

In [49]:
ema.apply_to()

In [50]:
# apply_to() was already called in the previous cell without a restore(), so
# this second call is rejected; the exception is caught so its type and
# message are printed instead of a traceback.
try:
    ema.apply_to()
except Exception as e:
    print(f"Error Type: {type(e).__name__}, Message : {e}")

Error Type: ValueError, Message : Cannot apply since backup exists, restore() needs to be called


### Test 5: `warmup_steps` and `update_every` gate which steps update the EMA

In [51]:
ema = EMA(model = model, ema_config = {'enabled': True,'decay': 0.9,'warmup_steps': 5,'update_every': 2,'eval_use_ema': True})

In [52]:
set_all_trainable_vars_to(model, 0.0)

In [53]:
ema.reset()

In [54]:
ema._ema_vars

ListWrapper([<tf.Variable 'kernel_ema:0' shape=(4, 3) dtype=float32, numpy=
array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)>, <tf.Variable 'bias_ema:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>, <tf.Variable 'kernel_ema:0' shape=(3, 2) dtype=float32, numpy=
array([[0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)>])

In [55]:
set_all_trainable_vars_to(model, 10.0)

In [56]:
# With warmup_steps=5 and update_every=2 (set when this EMA was created),
# should_update() only passes for steps 6 and 8, so the EMA is updated twice
# over these ten steps.
for step in range(10):
    ema.update(step)

In [57]:
[tf.identity(var) for var in model.trainable_variables]

[<tf.Tensor: shape=(4, 3), dtype=float32, numpy=
 array([[10., 10., 10.],
        [10., 10., 10.],
        [10., 10., 10.],
        [10., 10., 10.]], dtype=float32)>,
 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([10., 10., 10.], dtype=float32)>,
 <tf.Tensor: shape=(3, 2), dtype=float32, numpy=
 array([[10., 10.],
        [10., 10.],
        [10., 10.]], dtype=float32)>]

In [58]:
ema.apply_to()

In [59]:
[tf.identity(var) for var in model.trainable_variables]

[<tf.Tensor: shape=(4, 3), dtype=float32, numpy=
 array([[9.818182, 9.818182, 9.818182],
        [9.818182, 9.818182, 9.818182],
        [9.818182, 9.818182, 9.818182],
        [9.818182, 9.818182, 9.818182]], dtype=float32)>,
 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([9.818182, 9.818182, 9.818182], dtype=float32)>,
 <tf.Tensor: shape=(3, 2), dtype=float32, numpy=
 array([[9.818182, 9.818182],
        [9.818182, 9.818182],
        [9.818182, 9.818182]], dtype=float32)>]