In [None]:
# | default_exp schedulers/sigmoid

# Imports

In [None]:
# | export


import math

# Scheduler

In [None]:
# | export


class SigmoidScheduler:
    """Schedule a value along a sigmoid curve over a fixed number of steps.

    An internal position ``x`` walks from ``min_x`` to ``max_x`` in
    ``num_steps`` equal increments. ``get()`` returns ``sigmoid(x)`` linearly
    rescaled from ``[0, 1]`` into ``[min_y, max_y]``.

    The number of steps is supplied after construction through
    ``set_num_steps``; until then ``num_steps`` and ``x_step_size`` hold the
    ``...`` (Ellipsis) sentinel and ``get()`` / ``step()`` raise ``ValueError``.

    Args:
        min_y: Lower bound of the output range.
        max_y: Upper bound of the output range (must be greater than ``min_y``).
        min_x: Starting position on the sigmoid's x axis.
        max_x: Final position on the sigmoid's x axis (must be greater than
            ``min_x``).
    """

    def __init__(self, min_y=0.0, max_y=1.0, min_x=-7, max_x=7):
        assert min_x < max_x, "min_x must be less than max_x"
        assert min_y < max_y, "min_y must be less than max_y"

        self.min_y = min_y
        self.max_y = max_y
        self.min_x = min_x
        self.max_x = max_x
        # Ellipsis marks "not configured yet"; see set_num_steps().
        self.num_steps = ...
        self.x_step_size = ...

        self.x = min_x
        # Number of successful step() calls. The position is derived from this
        # counter instead of being accumulated, so rounding error cannot build up.
        self._steps_taken = 0

    @staticmethod
    def _sigmoid(x):
        """Return the logistic function ``1 / (1 + e**-x)``."""
        try:
            return 1 / (1 + math.exp(-x))
        except OverflowError:
            # math.exp(-x) only overflows for very negative x, where the
            # sigmoid is indistinguishable from its limit of 0.
            return 0.0

    def set_num_steps(self, num_steps):
        """Configure how many steps it takes to travel from ``min_x`` to ``max_x``.

        Only the first call has an effect; later calls are silently ignored.

        Raises:
            ValueError: If ``num_steps`` is not positive.
        """
        if self.num_steps is not ...:
            return
        if num_steps <= 0:
            # Zero would divide by zero below; a negative count would walk away
            # from max_x and never complete.
            raise ValueError("num_steps must be a positive number")
        self.num_steps = num_steps
        self.x_step_size = (self.max_x - self.min_x) / self.num_steps

    def is_ready(self):
        """Return True once ``set_num_steps`` has configured the scheduler."""
        return self.num_steps is not ...

    def is_completed(self):
        """Return True once the position has reached ``max_x``."""
        return self.x >= self.max_x

    def get(self):
        """Return the scheduled value for the current position.

        Raises:
            ValueError: If ``set_num_steps`` has not been called yet.
        """
        if not self.is_ready():
            raise ValueError("Call set_num_steps first")
        y = self._sigmoid(self.x)
        scaled_y = self._scale(y)
        return scaled_y

    def step(self):
        """Advance the position by one step; a no-op once completed.

        Raises:
            ValueError: If ``set_num_steps`` has not been called yet.
        """
        if not self.is_ready():
            raise ValueError("Call set_num_steps first")
        if self.is_completed():
            return
        self._steps_taken += 1
        if self._steps_taken >= self.num_steps:
            # Land exactly on max_x so is_completed() flips after precisely
            # num_steps steps, regardless of floating-point rounding.
            self.x = self.max_x
        else:
            # Recompute from the start instead of adding step sizes repeatedly:
            # e.g. ten additions of 0.1 give 0.9999999999999999, not 1.0.
            self.x = min(self.min_x + self._steps_taken * self.x_step_size, self.max_x)

    def _scale(self, y):
        """Map ``y`` from ``[0, 1]`` linearly onto ``[min_y, max_y]``."""
        scaled_y = self.min_y + y * (self.max_y - self.min_y)
        return scaled_y

In [None]:
# Demo: a freshly built scheduler is not ready until set_num_steps() is called.
scheduler = SigmoidScheduler()
print(f"Is ready: {scheduler.is_ready()}")
scheduler.set_num_steps(10)
print(f"Is ready: {scheduler.is_ready()}")

# Iterate 12 times against the 10 configured steps: the extra iterations show
# that once the scheduler is completed, step() returns early and get() keeps
# reporting the same final value.
for _ in range(12):
    print(f"Value: {scheduler.get()}\tIs completed: {scheduler.is_completed()}")
    scheduler.step()

Is ready: False
Is ready: True
Value: 0.0009110511944006454	Is completed: False
Value: 0.003684239899435989	Is completed: False
Value: 0.014774031693273067	Is completed: False
Value: 0.057324175898868776	Is completed: False
Value: 0.19781611144141834	Is completed: False
Value: 0.5000000000000001	Is completed: False
Value: 0.8021838885585818	Is completed: False
Value: 0.9426758241011313	Is completed: False
Value: 0.9852259683067269	Is completed: False
Value: 0.9963157601005641	Is completed: False
Value: 0.9990889488055994	Is completed: True
Value: 0.9990889488055994	Is completed: True


# nbdev

In [None]:
# Shell out to nbdev's export command. Presumably this writes the cells marked
# `# | export` to the module named by `# | default_exp` at the top of this
# notebook — confirm against the nbdev documentation.
!nbdev_export