In [17]:
"""
Exercise 02: Checkpoint Strategy - Starter Code

Implement optimal checkpointing for long training runs.

Prerequisites:
- Reading: 02-model-checkpoints.md
- Demo: demo_02_checkpoint_callback.py (KEY REFERENCE)
"""

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import os
import shutil

# Setup
# CKPT_DIR is the directory the os.listdir() calls in the tasks below inspect.
# It is wiped and recreated whenever this cell runs, so anything already in it
# (including checkpoints from a previous run) is deleted.
CKPT_DIR = 'checkpoints'
if os.path.exists(CKPT_DIR):
    shutil.rmtree(CKPT_DIR)  # destructive: removes the whole tree
os.makedirs(CKPT_DIR)

In [18]:
# ============================================================================
# SETUP (PROVIDED)
# ============================================================================

def load_data(n_train=10000):
    """Load MNIST with flattened, rescaled images and a truncated training split.

    Args:
        n_train: Number of leading training samples to keep. Defaults to
            10000. Pass None to keep the whole training split.

    Returns:
        ((x_train, y_train), (x_test, y_test)). Images are reshaped to
        (-1, 784), cast to float32 and divided by 255.0. Labels are returned
        as loaded. The test split is never truncated.
    """
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    def _flatten_and_scale(images):
        # 28x28 images -> 784-long float32 vectors, matching the model input.
        return images.reshape(-1, 784).astype('float32') / 255.0

    # A single slice bound is applied to images and labels together so the
    # two can never end up with different lengths.
    x_train, y_train = x_train[:n_train], y_train[:n_train]
    return (_flatten_and_scale(x_train), y_train), (_flatten_and_scale(x_test), y_test)

def create_model():
    """Build and compile the small dense classifier used by every task.

    Architecture: 784 inputs -> Dense(128, relu) -> Dense(64, relu)
    -> Dense(10, softmax), compiled with Adam, sparse categorical
    cross-entropy and an accuracy metric.

    Returns:
        A compiled keras.Sequential model.
    """
    model = keras.Sequential([
        # Declare the input with an explicit Input object rather than an
        # `input_shape=` argument on the first layer (discouraged in Keras 3,
        # where it triggers a warning).
        keras.Input(shape=(784,)),
        layers.Dense(128, activation='relu'),
        layers.Dense(64, activation='relu'),
        layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model

In [19]:
# ============================================================================
# TASK 2.1: Basic Checkpointing
# ============================================================================

def basic_checkpointing():
    """
    Save a weights-only checkpoint after every epoch.

    ModelCheckpoint PARAMETERS:
    - filepath: where to save (use {epoch:02d} for epoch number)
    - save_weights_only: True for smaller files
    - save_freq: 'epoch' or integer (batches)
    - verbose: 1 to print save messages

    FILEPATH PATTERN:
    '<CKPT_DIR>/ckpt_epoch-{epoch:02d}.weights.h5'
    -> Creates: ckpt_epoch-01.weights.h5, ckpt_epoch-02.weights.h5, etc.

    NOTE: with save_weights_only=True, Keras 3 requires the filename to end
    in '.weights.h5'; the '.keras' suffix is for full-model checkpoints.
    """
    print("Task 2.1: Basic Checkpointing")

    (x_train, y_train), (x_test, y_test) = load_data()
    model = create_model()

    # Build the path from CKPT_DIR so the files land in the same directory
    # that is listed at the end of this function.
    checkpoint_cb = keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(CKPT_DIR, 'ckpt_epoch-{epoch:02d}.weights.h5'),
        save_weights_only=True,
        save_freq='epoch',
        verbose=1
    )

    model.fit(
        x_train, y_train,
        epochs=5,
        validation_data=(x_test, y_test),
        callbacks=[checkpoint_cb]
    )

    # os.listdir() returns entries in arbitrary order; sort for stable output.
    print(f"\nCheckpoints saved: {sorted(os.listdir(CKPT_DIR))}")

In [20]:
# ============================================================================
# TASK 2.2: Save Best Model Only
# ============================================================================

def save_best_only():
    """
    Only save when validation loss improves.

    KEY PARAMETERS:
    - monitor='val_loss': Watch validation loss
    - mode='min': Lower is better
    - save_best_only=True: Only save improvements

    BENEFIT: Saves disk space, keeps only best model

    The checkpoint is written to a single fixed filename inside CKPT_DIR, so
    each improvement overwrites the previous best model.
    """
    print("Task 2.2: Save Best Only")

    (x_train, y_train), (x_test, y_test) = load_data()
    model = create_model()

    # Build the path from CKPT_DIR so the file lands in the same directory
    # that is listed at the end of this function.
    checkpoint_cb = keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(CKPT_DIR, 'best_model.keras'),
        monitor='val_loss',
        mode='min',
        save_best_only=True,
        verbose=1
    )

    model.fit(
        x_train, y_train,
        epochs=10,
        validation_data=(x_test, y_test),
        callbacks=[checkpoint_cb]
    )

    # os.listdir() returns entries in arbitrary order; sort for stable output.
    print(f"\nCheckpoints saved: {sorted(os.listdir(CKPT_DIR))}")

In [21]:
# ============================================================================
# TASK 2.3: Resume Training from Checkpoint
# ============================================================================

def resume_training():
    """
    Simulate crash recovery: train 5 epochs, "crash", resume.

    SCENARIO:
    1. Train for 5 epochs, save a full-model checkpoint every epoch
    2. Clear model (simulate crash)
    3. Load checkpoint
    4. Continue training for 5 more epochs (epochs 6-10)

    KEY: After loading, model should continue improving!

    NOTE: a full-model '.keras' checkpoint is expected to restore the compile
    configuration and optimizer state. Calling compile() again replaces the
    optimizer with a fresh one and throws that state away, so the model is
    only recompiled as a fallback when loading did not restore an optimizer.
    """
    print("Task 2.3: Resume Training")

    (x_train, y_train), (x_test, y_test) = load_data()

    epochs_per_phase = 5
    # Build the path from CKPT_DIR instead of repeating the directory name.
    checkpoint_path = os.path.join(CKPT_DIR, 'resume_model.keras')

    # Phase 1: Train for the first block of epochs
    print(f"\n--- Phase 1: Initial Training ({epochs_per_phase} epochs) ---")
    model = create_model()

    # Same filename every epoch: the file always holds the latest state.
    checkpoint_cb = keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_path,
        save_best_only=False,
        verbose=1
    )

    history1 = model.fit(
        x_train, y_train,
        epochs=epochs_per_phase,
        validation_data=(x_test, y_test),
        callbacks=[checkpoint_cb]
    )

    final_loss_phase1 = history1.history['val_loss'][-1]
    print(f"\nPhase 1 final val_loss: {final_loss_phase1:.4f}")

    # Phase 2: Simulate crash
    print("\n--- Simulating Crash (deleting model) ---")
    del model

    # Phase 3: Load checkpoint and resume
    print(f"\n--- Phase 2: Resuming Training ({epochs_per_phase} more epochs) ---")
    model = keras.models.load_model(checkpoint_path)
    if getattr(model, 'optimizer', None) is None:
        # Fallback only: the checkpoint came back without compile state.
        model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    # initial_epoch makes the epoch counter (and the checkpoint log messages)
    # continue after phase 1 instead of restarting at 1; `epochs` is the index
    # of the final epoch, so this still trains exactly epochs_per_phase epochs.
    history2 = model.fit(
        x_train, y_train,
        initial_epoch=epochs_per_phase,
        epochs=2 * epochs_per_phase,
        validation_data=(x_test, y_test),
        callbacks=[checkpoint_cb]
    )

    final_loss_phase2 = history2.history['val_loss'][-1]
    print(f"\nPhase 2 final val_loss: {final_loss_phase2:.4f}")
    print(f"Improvement: {final_loss_phase1 - final_loss_phase2:.4f}")

In [22]:
# ============================================================================
# TASK 2.4: Cleanup Old Checkpoints
# ============================================================================

class KeepNCheckpoints(keras.callbacks.Callback):
    """Custom callback to keep only the N most recent checkpoints.

    At the end of every epoch, files in `checkpoint_dir` whose names start
    with `pattern` are ordered by the numbers embedded in their names, and
    all but the last `keep_n` are deleted.

    Place this callback AFTER the ModelCheckpoint in the `callbacks` list so
    the current epoch's file already exists when the cleanup runs.

    Args:
        checkpoint_dir: Directory that holds the checkpoint files.
        keep_n: How many of the newest checkpoints to keep.
        pattern: Filename prefix identifying the checkpoints to manage;
            files without this prefix are never touched.
    """

    def __init__(self, checkpoint_dir, keep_n=3, pattern='smart_ckpt_epoch-'):
        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        self.keep_n = keep_n
        self.pattern = pattern

    @staticmethod
    def _natural_key(name):
        """Sort key that compares digit runs in `name` as integers.

        Plain string sorting puts 'epoch-100' before 'epoch-99', which would
        make the newest checkpoint look like the oldest once the epoch number
        outgrows its zero padding.
        """
        import re  # local import keeps this block self-contained

        # With a capturing group, re.split alternates text / digits / text...,
        # so every odd index is a run of digits.
        pieces = re.split(r'(\d+)', name)
        return [int(piece) if idx % 2 else piece for idx, piece in enumerate(pieces)]

    def on_epoch_end(self, epoch, logs=None):
        """Delete every matching checkpoint except the `keep_n` newest."""
        checkpoints = sorted(
            (f for f in os.listdir(self.checkpoint_dir) if f.startswith(self.pattern)),
            key=self._natural_key,
        )

        # Oldest first, so everything before the last keep_n entries is stale.
        excess = len(checkpoints) - self.keep_n
        for old_ckpt in checkpoints[:max(excess, 0)]:
            os.remove(os.path.join(self.checkpoint_dir, old_ckpt))
            print(f"  Deleted old checkpoint: {old_ckpt}")

def smart_checkpointing():
    """
    Keep only the N most recent checkpoints to save disk space.

    APPROACH:
    1. Save a full-model checkpoint every epoch with ModelCheckpoint
    2. After each epoch, a KeepNCheckpoints callback deletes checkpoints
       older than the last N

    The final report lists only this task's checkpoints (matched by filename
    prefix), not files that other tasks left in CKPT_DIR.

    ALTERNATIVE: Use keras.callbacks.BackupAndRestore for automatic resume
    """
    print("Task 2.4: Smart Checkpointing")

    (x_train, y_train), (x_test, y_test) = load_data()
    model = create_model()

    keep_n = 3
    # One prefix shared by the writer, the cleanup callback and the final
    # report, so the three cannot drift apart.
    prefix = 'smart_ckpt_epoch-'

    checkpoint_cb = keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(CKPT_DIR, prefix + '{epoch:02d}.keras'),
        verbose=1
    )

    cleanup_cb = KeepNCheckpoints(CKPT_DIR, keep_n=keep_n, pattern=prefix)

    model.fit(
        x_train, y_train,
        epochs=8,
        validation_data=(x_test, y_test),
        # Order matters: the checkpoint must be written before cleanup runs.
        callbacks=[checkpoint_cb, cleanup_cb]
    )

    remaining = sorted(f for f in os.listdir(CKPT_DIR) if f.startswith(prefix))
    print(f"\nFinal checkpoints (only last {keep_n}): {remaining}")

In [24]:
# ============================================================================
# MAIN
# ============================================================================

if __name__ == "__main__":
    banner = "=" * 60
    print(banner)
    print("Exercise 02: Checkpoint Strategy")
    print(banner)

    # Run every task in exercise order; each one trains its own fresh model.
    for run_task in (
        basic_checkpointing,
        save_best_only,
        resume_training,
        smart_checkpointing,
    ):
        run_task()

Exercise 02: Checkpoint Strategy
Task 2.1: Basic Checkpointing
Epoch 1/5
[1m300/313[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 675us/step - accuracy: 0.7369 - loss: 0.8928
Epoch 1: saving model to checkpoints/ckpt_epoch-01.weights.h5

Epoch 1: finished saving model to checkpoints/ckpt_epoch-01.weights.h5
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1ms/step - accuracy: 0.8557 - loss: 0.5058 - val_accuracy: 0.9241 - val_loss: 0.2691
Epoch 2/5
[1m290/313[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 700us/step - accuracy: 0.9376 - loss: 0.2158
Epoch 2: saving model to checkpoints/ckpt_epoch-02.weights.h5

Epoch 2: finished saving model to checkpoints/ckpt_epoch-02.weights.h5
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - accuracy: 0.9399 - loss: 0.2095 - val_accuracy: 0.9349 - val_loss: 0.2188
Epoch 3/5
[1m254/313[0m [32m━━━━━━━━━━━━━━━━[0m[37m━━━━[0m [1m0s[0m 596us/step - accuracy: 0.9573 - loss: 0.1520
Epoc