In [None]:
# | default_exp schedulers/cyclic

# Imports

In [None]:
# | export


import math

from torch.optim.lr_scheduler import LRScheduler

In [None]:
from matplotlib import pyplot as plt
from torch import nn
from torch.optim import Adam

# Scheduler

In [None]:
# | export


class SineScheduler:
    def __init__(self, start_value: float, max_value: float, decay: float = 0.0, wavelength: int | None = None):
        assert 0.0 <= decay < 1.0, "Decay must be between 0 and 1"

        self.start_value = start_value
        self.max_value = max_value
        self.decay_factor = 1 - decay
        self.wavelength = None

        self.pseudo_max_value = max_value / (self.decay_factor**0.5)

        self.x = 1

        if wavelength is not None:
            self.set_wavelength(wavelength)

    def set_wavelength(self, wavelength: int):
        assert wavelength > 0, "Wavelength must be greater than 0"
        self.wavelength = wavelength
        return self  # to allow chaining

    def is_ready(self):
        return self.wavelength is not None

    def get(self):
        if not self.is_ready():
            raise ValueError("Call set_wavelength first")

        # Calculate angle based on current step and wavelength and get sine value
        angle = (-0.5 + 2 * self.x / self.wavelength) * math.pi
        sine = math.sin(angle)

        # Scale it to the range of pseudo_max_lr and max_lr
        scaled = (self.pseudo_max_value - self.start_value) * (1 + sine) / 2

        # Apply decay to it
        decayed = scaled * self.decay_factor ** ((self.x + 1) / self.wavelength)

        # Increase it by the start_lr
        lr = decayed + self.start_value

        return lr

    def step(self):
        if not self.is_ready():
            raise ValueError("Call set_wavelength first")
        self.x = self.x + 1

In [None]:
# Heavily damped wave: 90% of the amplitude is lost per 5-step cycle.
scheduler = SineScheduler(start_value=0, max_value=1, decay=0.9, wavelength=5)

# Three full cycles; read the value first, then advance.
for _ in range(3 * 5):
    print(f"Value: {scheduler.get()}")
    scheduler.step()

Value: 0.4349480324496454
Value: 0.7184766378640389
Value: 0.4533281114977514
Value: 0.10925400611220523
Value: 0.0
Value: 0.043494803244964526
Value: 0.07184766378640382
Value: 0.04533281114977513
Value: 0.010925400611220529
Value: 0.0
Value: 0.004349480324496461
Value: 0.007184766378640381
Value: 0.0045332811149775155
Value: 0.001092540061122056
Value: 0.0


In [None]:
# Undamped wave (decay=0) with a 14-step cycle.
scheduler = SineScheduler(start_value=0, max_value=1, decay=0, wavelength=14)

# One full cycle plus the first step of the next; read the value, then advance.
for _ in range(14 + 1):
    print(f"Value: {scheduler.get()}")
    scheduler.step()

Value: 0.04951556604879043
Value: 0.18825509907063326
Value: 0.3887395330218428
Value: 0.6112604669781572
Value: 0.8117449009293667
Value: 0.9504844339512095
Value: 1.0
Value: 0.9504844339512095
Value: 0.8117449009293667
Value: 0.6112604669781573
Value: 0.38873953302184283
Value: 0.18825509907063348
Value: 0.04951556604879048
Value: 0.0
Value: 0.04951556604879037


In [None]:
# | export


class SineLR(LRScheduler):
    """Learning-rate scheduler that applies a ``SineScheduler`` wave to an optimizer.

    Every parameter group of the optimizer receives the same learning rate,
    taken from an internal ``SineScheduler`` that is advanced by one on each
    call to :meth:`step`.

    Args:
        optimizer: Optimizer to schedule; forwarded to ``LRScheduler``.
        start_lr: Forwarded to ``SineScheduler`` as ``start_value``.
        max_lr: Forwarded to ``SineScheduler`` as ``max_value``.
        wavelength: Number of ``step()`` calls per full sine cycle.
        decay: Forwarded to ``SineScheduler`` as ``decay``.
        last_epoch: Forwarded unchanged to ``LRScheduler``.
        verbose: Accepted only for backward compatibility and never forwarded
            to ``LRScheduler``. Any value other than the ``"deprecated"``
            sentinel triggers a ``UserWarning`` and is otherwise ignored.
    """

    def __init__(self, optimizer, start_lr, max_lr, wavelength, decay, last_epoch=-1, verbose="deprecated"):
        # Must exist before the base constructor runs, because get_lr()/step()
        # below read it.
        self.scheduler = SineScheduler(start_lr, max_lr, decay).set_wavelength(wavelength)
        # To match the output of the non-LR scheduler: the base-class
        # constructor is expected to call step() once during initialisation,
        # which advances the counter back to SineScheduler's own start.
        self.scheduler.x -= 1

        # `verbose` is deliberately NOT forwarded. Passing the truthy
        # "deprecated" sentinel positionally enables per-step printing on
        # PyTorch releases whose base default was False, and raises TypeError
        # on releases whose LRScheduler no longer takes the parameter
        # (NOTE(review): confirm against the installed torch version).
        if verbose != "deprecated":
            import warnings

            warnings.warn(
                "The verbose parameter of SineLR is ignored. "
                "Please use get_last_lr() to access the learning rate.",
                UserWarning,
                stacklevel=2,
            )
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the current sine value once per optimizer parameter group."""
        lr = self.scheduler.get()
        return [lr for _ in self.optimizer.param_groups]

    def step(self, epoch=None):
        """Advance the sine wave by one step, then delegate to ``LRScheduler.step``.

        ``epoch`` is passed through to the base class, but the wave position
        always moves forward by exactly one per call regardless of its value.
        """
        self.scheduler.step()
        return super().step(epoch)

In [None]:
# Same wave as the decaying SineScheduler demo, driven through the LR-scheduler API.
optimizer = Adam([nn.Parameter()])
scheduler = SineLR(optimizer, start_lr=0, max_lr=1, wavelength=5, decay=0.9)

# Three full cycles; each line shows one learning rate per param group.
for _ in range(3 * 5):
    print(f"Value: {scheduler.get_lr()}")
    scheduler.step()

Value: [0.4349480324496454]
Value: [0.7184766378640389]
Value: [0.4533281114977514]
Value: [0.10925400611220523]
Value: [0.0]
Value: [0.043494803244964526]
Value: [0.07184766378640382]
Value: [0.04533281114977513]
Value: [0.010925400611220529]
Value: [0.0]
Value: [0.004349480324496461]
Value: [0.007184766378640381]
Value: [0.0045332811149775155]
Value: [0.001092540061122056]
Value: [0.0]


# nbdev

In [None]:
# Shell escape: run the nbdev_export command (nbdev's export of the "# | export" cells; see nbdev docs).
!nbdev_export