In [None]:
from pathlib import Path
import sys

sys.path.append(Path.cwd().resolve().parent.as_posix())

import matplotlib.pyplot as plt
import numpy as np

from sampler import SAMPLER_FORMULATION_TABLE

In [None]:
def draw_inference_timesteps(
    t_min: float,
    t_max: float,
    formulation: str
) -> None:
    """Plot candidate inference noise schedules for one sampler formulation.

    The endpoints ``t_min`` / ``t_max`` are converted to noise-to-signal
    ratios (NSR = sigma(t) / scale(t)) using the ``sigma_fn`` / ``scale_fn``
    entries of ``SAMPLER_FORMULATION_TABLE[formulation]``. Several schedules
    interpolating between those NSR endpoints are then drawn:

    * left panel: the NSR of each schedule, on a log y-axis;
    * right panel: the absolute timesteps obtained by mapping each NSR back
      through the formulation's ``nsr_inv_fn``, next to a time-uniform
      baseline that interpolates linearly between ``t_max`` and ``t_min``.

    Args:
        t_min: Timestep whose NSR is used as the ``nsr_min`` endpoint.
        t_max: Timestep whose NSR is used as the ``nsr_max`` endpoint.
        formulation: Key into ``SAMPLER_FORMULATION_TABLE`` (e.g. ``"EDM"``).

    Raises:
        KeyError: If ``formulation`` is not in ``SAMPLER_FORMULATION_TABLE`` or
            its entry lacks ``scale_fn``, ``sigma_fn`` or ``nsr_inv_fn``.
    """
    entry = SAMPLER_FORMULATION_TABLE[formulation]
    scale = entry["scale_fn"]
    sigma = entry["sigma_fn"]
    nsr_inv = entry["nsr_inv_fn"]

    # Interpolation parameter: u = 0 maps to the nsr_max endpoint, u = 1 to nsr_min.
    u = np.linspace(0, 1, 1000)
    u_inverted = 1 - u

    nsr_min = sigma(t_min) / scale(t_min)
    nsr_max = sigma(t_max) / scale(t_max)

    # Each schedule interpolates linearly in a transformed NSR space and maps back.
    log_min, log_max = np.log(nsr_min), np.log(nsr_max)
    atan_min, atan_max = np.arctan(nsr_min), np.arctan(nsr_max)
    atan_sqrt_min, atan_sqrt_max = np.arctan(np.sqrt(nsr_min)), np.arctan(np.sqrt(nsr_max))
    schedules = {
        "Linear LogNSR": np.exp(log_max + (log_min - log_max) * u),
        "Cosine LogNSR": np.tan(atan_max + (atan_min - atan_max) * u),
        "Cosine Square LogNSR": np.square(np.tan(atan_sqrt_max + (atan_sqrt_min - atan_sqrt_max) * u)),
    }
    for rho in range(1, 8):
        # EDM-style schedule: linear interpolation in NSR ** (1 / rho) space.
        lo, hi = nsr_min ** (1 / rho), nsr_max ** (1 / rho)
        schedules[rf"EDM $\rho = {rho}$"] = (hi + (lo - hi) * u) ** rho

    colors = plt.cm.jet(np.linspace(0, 1, len(schedules)))

    _, (ax_nsr, ax_steps) = plt.subplots(1, 2, figsize=(15, 6))

    for color, (label, nsr) in zip(colors, schedules.items()):
        ax_nsr.plot(u_inverted, nsr, label=label, color=color)
    ax_nsr.set_xlabel("Normalized Timesteps")
    # Raw string: "\s" is an invalid escape in a plain literal (SyntaxWarning on Python >= 3.12).
    ax_nsr.set_ylabel(r"Noise to Signal Ratio $\sigma(t) / s(t)$")
    ax_nsr.set_yscale("log")
    # Reversed limits: the nsr_max end (u_inverted = 1) is drawn on the left.
    ax_nsr.set_xlim(1.05, -0.05)
    ax_nsr.set_title(f"(Log) Inference Noise Level for {formulation}")
    ax_nsr.grid()
    ax_nsr.legend()

    # Dashed black so the baseline is not confused with the jet-coloured schedules.
    ax_steps.plot(u_inverted, t_max + (t_min - t_max) * u, "k--", label="Time Uniform")
    for color, (label, nsr) in zip(colors, schedules.items()):
        ax_steps.plot(u_inverted, nsr_inv(nsr), label=label, color=color)
    ax_steps.set_xlabel("Normalized Timesteps")
    ax_steps.set_ylabel("Absolute Timesteps")
    ax_steps.set_xlim(1.05, -0.05)
    ax_steps.set_title(f"Inference Timesteps for {formulation}")
    ax_steps.grid()
    ax_steps.legend()

    plt.show()


In [None]:
# Ornstein-Uhlenbeck endpoints of the form 0.5 * log(1 + nsr**2) — presumably the
# inverse of this formulation's NSR so the schedule spans NSR 1e-4 .. 1e4 (confirm in
# sampler.py). np.log1p avoids rounding 1 + 1e-8 before the log, keeping full precision.
draw_inference_timesteps(0.5 * np.log1p(1e-4 ** 2), 0.5 * np.log1p(1e4 ** 2), "Ornstein-Uhlenbeck")

In [None]:
# EDM: the endpoints are passed straight through — presumably t is the NSR itself for
# this formulation (confirm in sampler.py).
draw_inference_timesteps(
    t_min=1e-4,
    t_max=1e4,
    formulation="EDM",
)

In [None]:
# Rectified Flow endpoints of the form nsr / (1 + nsr) — presumably the inverse of this
# formulation's NSR so the schedule spans NSR 1e-4 .. 1e4 (confirm in sampler.py).
draw_inference_timesteps(
    t_min=1e-4 / (1 + 1e-4),
    t_max=1e4 / (1 + 1e4),
    formulation="Rectified Flow",
)