In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt

## Schedule 버전

In [None]:
class LRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Per-step learning rate: linear warmup followed by a decay schedule.

    For ``step < warmup_step`` the learning rate is
    ``init_lr * step / warmup_step``; afterwards it is
    ``decay_fn(step - warmup_step)``. Step 0 is treated as step 1 so the first
    update is not taken with a learning rate of 0.

    Args:
        init_lr: learning rate reached at the end of warmup.
        warmup_step: number of warmup steps.
        decay_fn: callable mapping "steps since warmup ended" (a float32 tensor)
            to a learning rate, e.g. a ``PolynomialDecay`` schedule.
    """

    def __init__(self, init_lr, warmup_step, decay_fn):
        super().__init__()
        self.init_lr = init_lr
        self.warmup_step = warmup_step
        self.decay_fn = decay_fn

    def __call__(self, step):
        # The optimizer passes the step as a tensor, and may do so while tracing
        # a tf.function; a Python `if step == 0` would then try to convert a
        # symbolic tensor to bool and fail. Clamp with tf.maximum instead.
        step_float = tf.maximum(tf.cast(step, tf.float32), 1.0)
        warmup_step_float = tf.cast(self.warmup_step, tf.float32)

        return tf.cond(
            step_float < warmup_step_float,
            # Use the instance's init_lr, not a notebook-level global.
            lambda: self.init_lr * (step_float / warmup_step_float),
            lambda: self.decay_fn(step_float - warmup_step_float),
        )

# data_size: number of examples in the training set.
data_size = 100000
batch_size = 512
# Number of epochs the warmup + decay schedule should span. PolynomialDecay
# holds end_learning_rate once decay_steps is exceeded, so if training runs
# longer than num_epochs the learning rate simply stays at min_lr.
num_epochs = 1
steps_per_epoch = data_size // batch_size
# Total number of optimizer steps covered by the schedule.
global_step = steps_per_epoch * num_epochs
warmup_ratio = 0.6  # fraction of global_step spent in linear warmup
warmup_step = int(global_step * warmup_ratio)
init_lr = 0.1
min_lr = 1e-6
power = 1.

lr_scheduler = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=init_lr,
    decay_steps=global_step - warmup_step,
    end_learning_rate=min_lr,
    power=power,
)

lr_schedule = LRSchedule(init_lr, warmup_step, lr_scheduler)

# Usage example
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

## Callback 버전

In [None]:
# NOTE(review): this reuses the name `LRSchedule` and shadows the
# LearningRateSchedule subclass defined in the earlier cell; re-running that
# earlier config cell after this one would pick up this callback instead.
class LRSchedule(tf.keras.callbacks.Callback):
    """Per-epoch learning rate: linear warmup followed by a decay schedule.

    At the start of each epoch the learning rate is
    ``init_lr * epoch / warmup_epoch`` while ``epoch < warmup_epoch`` and
    ``decay_fn(epoch - warmup_epoch)`` afterwards. Epoch 0 is treated as epoch 1
    so training never starts with a learning rate of 0 (epochs 0 and 1 therefore
    share the same value).

    Every computed value is appended to ``self.lrs`` as a Python float so the
    schedule can be plotted. When the callback is attached to a model, the value
    is also written to ``model.optimizer.learning_rate``; when no model is
    attached (e.g. calling ``on_epoch_begin`` by hand to preview the schedule),
    the optimizer update is skipped.

    Args:
        init_lr: learning rate reached at the end of warmup.
        warmup_epoch: number of warmup epochs.
        decay_fn: callable mapping "epochs since warmup ended" to a learning
            rate, e.g. a ``PolynomialDecay`` schedule.
        verbose: if True, print the learning rate at the start of every epoch.
    """

    def __init__(self, init_lr, warmup_epoch, decay_fn, verbose=True):
        # The base-class initializer sets up attributes Keras relies on
        # (including `model`, which stays None until a model is attached).
        super().__init__()
        self.init_lr = init_lr
        self.decay_fn = decay_fn
        self.warmup_epoch = warmup_epoch
        self.verbose = verbose
        self.lrs = []

    def on_epoch_begin(self, epoch, logs=None):
        # Avoid lr = 0 on the very first warmup epoch.
        epoch = max(epoch, 1)

        if epoch < self.warmup_epoch:
            lr = self.init_lr * (epoch / self.warmup_epoch)
        else:
            lr = self.decay_fn(float(epoch - self.warmup_epoch))
        # decay_fn returns a tensor; store/assign a plain float.
        lr = float(lr)

        if self.verbose:
            tf.print('learning rate: ', lr)
        if self.model is not None:
            # `optimizer.lr` is a deprecated alias; assign through learning_rate.
            self.model.optimizer.learning_rate = lr

        self.lrs.append(lr)


epochs = 1000
warmup_epoch = int(epochs * 0.4)
init_lr = 0.1
min_lr = 1e-6
power = 1.

lr_scheduler = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=init_lr,
    decay_steps=epochs - warmup_epoch,
    end_learning_rate=min_lr,
    power=power,
)

# Preview the schedule without training: no model is attached, so only
# `lr_schedule.lrs` is filled (it is what the plotting cell draws).
lr_schedule = LRSchedule(init_lr=init_lr,
                         warmup_epoch=warmup_epoch,
                         decay_fn=lr_scheduler,
                         verbose=False)
for epoch in range(epochs):
    lr_schedule.on_epoch_begin(epoch)

# Usage example -- requires a compiled `model` and training data, which this
# notebook does not define, so it is left commented out. A fresh callback is
# used so its `lrs` only holds the values seen during training.
# model.fit(..., callbacks=[LRSchedule(init_lr=init_lr,
#                                      warmup_epoch=warmup_epoch,
#                                      decay_fn=lr_scheduler)])

In [None]:
# Plot the per-epoch learning rates recorded in `lr_schedule.lrs`.
# NOTE(review): assumes `lr_schedule` is the Callback-version instance (the
# LearningRateSchedule-version object has no `.lrs` attribute).
fig, ax = plt.subplots(figsize=(10, 10))
ax.plot(lr_schedule.lrs)
ax.set_title('Learning-rate schedule (warmup + decay)', fontsize=16)
ax.set_xlabel('epochs', fontsize=16)
ax.set_ylabel('learning rate', fontsize=16)
ax.grid()
plt.show()