In [None]:
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import tensorflow.keras as tfkeras
from tensorflow.keras.optimizers.schedules import LearningRateSchedule

%matplotlib inline


def visualize_lr_schedule(opt: LearningRateSchedule, n_steps: int, step_s: int) -> Tuple[List[int], List[float]]:
    """Evaluate a learning-rate schedule on a grid of steps and plot it.

    Args:
        opt: The learning-rate schedule to evaluate (a schedule, not an
            optimizer, despite the parameter name). It is called as
            ``opt(step)`` with a Python ``int`` step.
        n_steps: Exclusive upper bound of the sampled steps.
        step_s: Stride between sampled steps; must be positive.

    Returns:
        ``(steps, lr)``. ``steps`` is ``list(range(0, n_steps, step_s))``, and
        ``lr[i]`` is the schedule's value at ``steps[i]`` as a Python float.

    Raises:
        ValueError: If ``step_s`` is not positive. (A zero stride would make
            ``range`` fail, and a negative one would silently plot nothing.)
    """
    if step_s <= 0:
        raise ValueError(f"step_s must be a positive integer, got {step_s!r}")

    steps = list(range(0, n_steps, step_s))
    # float() converts any scalar tensor/array (TF eager tensor, NumPy scalar,
    # ...) to a plain Python float, matching the List[float] annotation. It does
    # not rely on the tensor type exposing `.numpy()`.
    lr = [float(opt(step)) for step in steps]

    # Explicit figure/axes instead of the implicit pyplot state machine.
    fig, ax = plt.subplots(figsize=(8, 5))
    fig.suptitle(f'Learning Rate Schedule: {type(opt).__name__}')
    ax.plot(steps, lr, label="Learning Rate")
    ax.set_xlabel("Training Steps")
    ax.set_ylabel("Learning Rate")
    ax.legend()
    ax.grid(True)

    return steps, lr


In [None]:
# Cosine decay with warm restarts (see the Keras CosineDecayRestarts docs for
# the exact meaning of t_mul / m_mul / alpha).
# Uses the public Keras API through the `tfkeras` alias imported in the first
# cell. This is the same namespace as the LearningRateSchedule type hint, and it
# avoids the private `keras.src` package, whose internal layout is not a stable
# interface.
cos_dec = tfkeras.optimizers.schedules.CosineDecayRestarts(
    initial_learning_rate=0.01,
    first_decay_steps=200,
    t_mul=2.0,
    m_mul=0.5,
    alpha=1e-5
)
steps, lr_values = visualize_lr_schedule(cos_dec, n_steps=1000, step_s=3)
