In [2]:
import math


# (title, field width) for every printed column. The header and the data rows
# are both rendered from this one spec so their "|" separators always line up.
_ADAM_TABLE_COLUMNS = (
    ("t", 4),
    ("mt", 10),
    ("vt", 10),
    ("sqrt(vt)", 10),
    ("mt/sqrt(vt)", 14),
    ("mt/(sqrt(vt)+eps)", 20),
    ("adam update", 14),
    ("sgd update", 14),
)


def _adam_rows(gradients, beta1, beta2, alpha, epsilon, bias_correction):
    """Run the Adam moment recurrences over `gradients`; return one raw (unrounded) row per step."""
    m = 0.0
    v = 0.0
    rows = []

    for t, g in enumerate(gradients, start=1):
        # Exponential moving averages of the gradient and the squared gradient.
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * (g**2)

        # Optional bias correction: divide by (1 - beta**t) to compensate for
        # the zero initialisation of m and v.
        if bias_correction:
            m_hat = m / (1 - beta1**t)
            v_hat = v / (1 - beta2**t)
        else:
            m_hat = m
            v_hat = v

        sqrt_v = math.sqrt(v_hat)

        # Ratio without epsilon. When sqrt_v is exactly 0 the division is not
        # defined: report a signed infinity that follows m_hat, or NaN for the
        # 0/0 case (e.g. an all-zero gradient history), instead of always +inf.
        if sqrt_v != 0:
            m_over_sqrt_v = m_hat / sqrt_v
        elif m_hat == 0:
            m_over_sqrt_v = float("nan")
        else:
            m_over_sqrt_v = math.copysign(float("inf"), m_hat)

        m_over_sqrt_v_eps = m_hat / (sqrt_v + epsilon)

        rows.append(
            [
                t,
                m_hat,
                v_hat,
                sqrt_v,
                m_over_sqrt_v,
                m_over_sqrt_v_eps,
                alpha * m_over_sqrt_v_eps,  # Adam step size
                alpha * g,  # plain SGD step size, for comparison
            ]
        )

    return rows


def _format_adam_table(rows):
    """Render rule, header and one aligned line per row; return the lines as a list of strings."""
    header = " | ".join(f"{title:>{width}}" for title, width in _ADAM_TABLE_COLUMNS)
    rule = "-" * len(header)

    lines = [rule, header, rule]
    for row in rows:
        # First column is the integer step index; the rest are floats shown
        # with four decimals.
        cells = [f"{row[0]:>{_ADAM_TABLE_COLUMNS[0][1]}}"]
        cells.extend(
            f"{value:>{width}.4f}"
            for value, (_, width) in zip(row[1:], _ADAM_TABLE_COLUMNS[1:])
        )
        lines.append(" | ".join(cells))
    lines.append(rule)
    return lines


def adam_step_table(
    gradients, beta1=0.9, beta2=0.99, alpha=0.01, epsilon=1e-6, bias_correction=False
):
    """Print a step-by-step table of scalar Adam updates for a gradient sequence.

    For every gradient ``g`` (steps are numbered from 1) the first and second
    moment estimates are updated, optionally bias-corrected, and the following
    columns are printed with four decimals: ``mt``, ``vt``, ``sqrt(vt)``,
    ``mt/sqrt(vt)``, ``mt/(sqrt(vt)+eps)``, the Adam update
    ``alpha * mt/(sqrt(vt)+eps)`` and the plain SGD update ``alpha * g``.

    Args:
        gradients: Iterable of numeric gradients, one per step. It is also
            interpolated as-is into the "Gradients = ..." line of the output.
        beta1: Decay rate of the first-moment moving average.
        beta2: Decay rate of the second-moment moving average.
        alpha: Step size applied to both the Adam and the SGD column.
        epsilon: Constant added to ``sqrt(vt)`` in the Adam denominator.
        bias_correction: If True, ``mt`` and ``vt`` are divided by
            ``1 - beta1**t`` and ``1 - beta2**t`` before being used/displayed.

    Returns:
        None. The table is written to standard output only, so call this
        function directly rather than wrapping it in ``print(...)``.
    """
    rows = _adam_rows(gradients, beta1, beta2, alpha, epsilon, bias_correction)

    # ---------- DISPLAY STEP ----------
    print("\nAdam Optimizer Step-by-Step")
    print("Bias correction:", bias_correction)
    print(f"Gradients = {gradients}")
    for line in _format_adam_table(rows):
        print(line)

In [None]:
# adam_step_table prints its table itself and returns None, so it is called
# directly; wrapping it in print() would emit a stray "None" after each table.

# Alternating-sign gradients with slowly shrinking magnitude, starting positive.
grads = [10, -10, 9, -9, 8.5, -8.5]
adam_step_table(grads, bias_correction=True)

# The same sequence shifted by one step (starts negative), with 8 appended.
grads = [-10, 9, -9, 8.5, -8.5, 8]
adam_step_table(grads, bias_correction=True)