# Resume Training: Saving and Loading Models


## What is Resume Training?

Resume training means continuing to train a machine learning model from a previously saved state, usually called a checkpoint. Resuming restores the model weights and the optimizer state, and sometimes the epoch counter or the position in a learning rate schedule.

## Why is Resume Training Important?

- **Time Efficiency**: You don't have to start training from scratch if it's interrupted.
- **Fault Tolerance**: Saves work in progress in case of system crashes or restarts.
- **Experimentation**: Allows tuning or debugging from a specific training point.
- **Resource Optimization**: Ideal for cloud or shared resource environments with limited usage time.

## How Does it Work?

1. **Saving**: Periodically save the model weights and optimizer state during training.
2. **Loading**: Restore the saved weights and optimizer to continue training from the last checkpoint.
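The two steps above can be sketched without any framework. The toy example below is only an illustration under stated assumptions: the objective `(w - 3)^2`, the momentum optimizer, the JSON file layout, and the helper names `train`, `save_checkpoint`, and `load_checkpoint` are all hypothetical and not part of TensorFlow. It shows why a checkpoint benefits from holding the optimizer state and epoch counter in addition to the weights.

```python
import json
import os
import tempfile


def train(w, velocity, start_epoch, end_epoch, lr=0.1, momentum=0.9):
    """Minimise (w - 3)^2 with SGD + momentum for epochs [start_epoch, end_epoch)."""
    for _ in range(start_epoch, end_epoch):
        grad = 2.0 * (w - 3.0)
        velocity = momentum * velocity - lr * grad  # optimizer state
        w = w + velocity                            # model weight
    return w, velocity


def save_checkpoint(path, w, velocity, epoch):
    """Store weight, optimizer state, and epoch counter in one file."""
    with open(path, "w") as f:
        json.dump({"w": w, "velocity": velocity, "epoch": epoch}, f)


def load_checkpoint(path):
    """Read back the three pieces of state written by save_checkpoint."""
    with open(path) as f:
        state = json.load(f)
    return state["w"], state["velocity"], state["epoch"]


ckpt_path = os.path.join(tempfile.mkdtemp(), "toy_checkpoint.json")

# Reference run: 10 epochs without interruption.
w_full, _ = train(0.0, 0.0, 0, 10)

# Interrupted run: stop after 5 epochs, checkpoint, reload, and continue.
w, v = train(0.0, 0.0, 0, 5)
save_checkpoint(ckpt_path, w, v, epoch=5)
w, v, epoch = load_checkpoint(ckpt_path)
w_resumed, _ = train(w, v, epoch, 10)

# Weights-only resume: the weight is restored but the optimizer state is reset.
w_weights_only, _ = train(w, 0.0, epoch, 10)

print("uninterrupted:      ", w_full)
print("full resume:        ", w_resumed)
print("weights-only resume:", w_weights_only)
```

In this sketch the full resume reproduces the uninterrupted run, while the weights-only resume follows a different trajectory because the momentum term starts again from zero.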


## Example: Save and Resume Training in TensorFlow

```python
import tensorflow as tf
import numpy as np
import os

# Load the dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Normalize pixel values to [0, 1]
x_train = x_train / 255.0
x_test = x_test / 255.0

# Flatten the images for dense input
x_train = x_train.reshape((-1, 28 * 28))
x_test = x_test.reshape((-1, 28 * 28))

# Hold out the last 10,000 training samples as a validation set
# (not passed to fit() in this example)
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Define a simple model
def create_model():
    model = tf.keras.Sequential([
        # Declare the input shape with an Input object rather than the
        # input_shape argument on the first Dense layer
        tf.keras.Input(shape=(784,)),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

model = create_model()

# Define a checkpoint callback
checkpoint_path = "training_checkpoints/cp.weights.h5"
checkpoint_dir = os.path.dirname(checkpoint_path)
os.makedirs(checkpoint_dir, exist_ok=True)  # make sure the target directory exists

# The same file is overwritten at the end of every epoch
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

# Train the model with the checkpoint callback
model.fit(x_train, y_train, epochs=5, callbacks=[cp_callback])
```


```
Epoch 1/5
Epoch 1: saving model to training_checkpoints/cp.weights.h5
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 4s 2ms/step - accuracy: 0.8475 - loss: 0.5503
Epoch 2/5
Epoch 2: saving model to training_checkpoints/cp.weights.h5
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 6s 4ms/step - accuracy: 0.9490 - loss: 0.1719
Epoch 3/5
Epoch 3: saving model to training_checkpoints/cp.weights.h5
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 9s 6ms/step - accuracy: 0.9657 - loss: 0.1172
Epoch 4/5
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - accuracy: 0.9729 - loss: 0.0907
[output truncated]
```

### Resume Training from Checkpoint

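```python
# Create a new model instance with the same architecture
resumed_model = create_model()

# Load weights from the checkpoint.
# Note: the new model's optimizer has not been built yet, so optimizer state
# (e.g. Adam's moment estimates) is typically not restored by this call.
resumed_model.load_weights(checkpoint_path)

# Continue training for 5 more epochs.
# The epoch counter restarts at 1 because fit() is not told how many
# epochs were already completed.
resumed_model.fit(x_train, y_train, epochs=5)
```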


```
Epoch 1/5
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 7s 4ms/step - accuracy: 0.9812 - loss: 0.0627
Epoch 2/5
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 6s 4ms/step - accuracy: 0.9852 - loss: 0.0495
Epoch 3/5
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9884 - loss: 0.0414
Epoch 4/5
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 6s 4ms/step - accuracy: 0.9897 - loss: 0.0373
Epoch 5/5
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 6s 4ms/step - accuracy: 0.9914 - loss: 0.0292
```
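
The example above overwrites a single checkpoint file and restarts the epoch counter when training resumes. One possible refinement, sketched here under assumptions, is to write one file per epoch and look up the most recent one before resuming. The naming scheme `cp-0003.weights.h5` and the helper `latest_checkpoint` are hypothetical; the sketch uses only the standard library and creates empty placeholder files rather than real checkpoints.

```python
import os
import re
import tempfile

# Hypothetical naming scheme: one file per epoch, e.g. "cp-0003.weights.h5".
CHECKPOINT_PATTERN = re.compile(r"^cp-(\d+)\.weights\.h5$")


def latest_checkpoint(checkpoint_dir):
    """Return (path, epoch) of the highest-numbered checkpoint, or (None, 0)."""
    best_path, best_epoch = None, 0
    for name in os.listdir(checkpoint_dir):
        match = CHECKPOINT_PATTERN.match(name)
        if match and int(match.group(1)) > best_epoch:
            best_epoch = int(match.group(1))
            best_path = os.path.join(checkpoint_dir, name)
    return best_path, best_epoch


# Demonstration with empty placeholder files standing in for checkpoints.
demo_dir = tempfile.mkdtemp()
for epoch in (1, 2, 3):
    open(os.path.join(demo_dir, f"cp-{epoch:04d}.weights.h5"), "w").close()
open(os.path.join(demo_dir, "notes.txt"), "w").close()  # ignored by the pattern

path, last_epoch = latest_checkpoint(demo_dir)
print(os.path.basename(path), last_epoch)
```

With Keras, files like these could be produced by giving `ModelCheckpoint` a `filepath` that contains an `{epoch:04d}` placeholder. The returned epoch could then be passed to `fit` as `initial_epoch`, with `epochs` set to the total number of epochs, so that the epoch numbering continues from where the earlier run stopped.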