In [None]:
import os
from pathlib import Path

# Hugging Face configuration: a mirror endpoint and a notebook-local cache directory.
# NOTE(review): huggingface_hub reads these variables when it is imported, so this cell
# must run before anything that (presumably) imports it, e.g. `gaussian` / `sampler`.
# `setdefault` keeps any endpoint / cache location the user has already exported.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
os.environ.setdefault("HF_HOME", str(Path.cwd().joinpath("cache")))

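# Testing the Scheduler with a Gaussian Model

This notebook draws samples from `GaussianModel(mu, cov)` with `GaussianModelScheduler` under the Rectified Flow, EDM and NCSN formulations. It covers every solver, prediction-type and timestep-schedule combination, and compares the resulting 2-D histograms against exact samples from the same Gaussian (last cell).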

In [None]:
import itertools
from pathlib import Path
import sys

from IPython import get_ipython
import matplotlib.pyplot as plt
import torch

# Make the parent directory of the notebook's working directory importable (for `gaussian`
# and `sampler`). `Path.cwd()` is what IPython's `%pwd` returns, but it also works outside
# IPython; the membership check keeps re-runs of this cell from growing `sys.path`.
PROJECT_ROOT = Path.cwd().resolve().parent.as_posix()
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from gaussian import (
    GaussianModelPipeline,
    GaussianModel,
    GaussianModelScheduler
)
from sampler import SAMPLER_FORMULATION_TABLE

In [None]:
def draw_inference_result(
    pipeline: GaussianModelPipeline,
    batch_size: int,
    num_inference_steps: int,
    seed: int,
    algorithm_types=("ode", "sde"),
    prediction_types=("epsilon", "sample", "velocity"),
    timestep_schedules=("linear_lognsr", "cosine_lognsr", "power_lognsr", "uniform"),
):
    """Sample from ``pipeline`` under every scheduler configuration and plot 2-D histograms.

    The figure has one row per (algorithm_type, prediction_type) pair and one column per
    timestep schedule. Before each run the scheduler config is switched and the global
    torch RNG is re-seeded with ``seed``, so every panel starts from the same noise. The
    per-dimension sample mean and std of each run are printed.

    The scheduler's ``algorithm_type``, ``prediction_type`` and ``timestep_schedule`` config
    values are restored afterwards, so the pipeline is left configured as it was passed in.

    Args:
        pipeline: pipeline whose ``scheduler.config`` is switched between runs; called as
            ``pipeline(batch_size=..., num_inference_steps=...)`` and expected to return a
            tensor of samples (indexed as ``[:, 0]`` / ``[:, 1]`` for the histogram).
        batch_size: number of samples drawn per configuration.
        num_inference_steps: number of solver steps per run.
        seed: seed applied before every run.
        algorithm_types: values tried for ``scheduler.config.algorithm_type``.
        prediction_types: values tried for ``scheduler.config.prediction_type``.
        timestep_schedules: values tried for ``scheduler.config.timestep_schedule``.
    """
    # SDE sampling with velocity prediction is never run.
    # NOTE(review): presumably unsupported by the scheduler — confirm.
    skipped = {("sde", "velocity")}
    # Filter before building the grid so skipped pairs neither leave a blank row nor
    # get written into the scheduler config.
    rows = [
        combo
        for combo in itertools.product(algorithm_types, prediction_types)
        if combo not in skipped
    ]
    timestep_schedules = list(timestep_schedules)
    if not rows or not timestep_schedules:
        return

    n_rows, n_cols = len(rows), len(timestep_schedules)
    # Same panel size as the original 16 x 22 inch figure for a 6 x 4 grid.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(4 * n_cols, 22 / 6 * n_rows), squeeze=False
    )

    config = pipeline.scheduler.config
    missing = object()
    saved = {
        key: getattr(config, key, missing)
        for key in ("algorithm_type", "prediction_type", "timestep_schedule")
    }
    try:
        for row_axes, (algorithm_type, prediction_type) in zip(axes, rows):
            config.algorithm_type = algorithm_type
            config.prediction_type = prediction_type
            for ax, timestep_schedule in zip(row_axes, timestep_schedules):
                config.timestep_schedule = timestep_schedule
                # torch.manual_seed seeds the CPU and all CUDA generators.
                torch.manual_seed(seed)

                samples = pipeline(batch_size=batch_size, num_inference_steps=num_inference_steps)
                samples_np = samples.detach().cpu().numpy()
                mean, std = samples_np.mean(axis=0), samples_np.std(axis=0)
                print(f"{algorithm_type} {prediction_type} {timestep_schedule}: mean={mean}, std={std}")

                ax.hist2d(
                    samples_np[:, 0],
                    samples_np[:, 1],
                    bins=[128, 128],
                    range=[[-4, 4], [-4, 4]],
                    cmap="plasma",
                    cmin=1,
                )
                ax.set_title(f"{algorithm_type} {prediction_type} {timestep_schedule}")
                ax.axis("equal")
    finally:
        # Undo the config sweep so later cells see the scheduler as it was configured.
        for key, value in saved.items():
            if value is not missing:
                setattr(config, key, value)

    plt.show()

In [None]:
# Sampling settings shared by every experiment in this notebook.
seed = 0  # RNG seed applied before each sampler run
batch_size = 16_384  # samples drawn per configuration (2**14)
num_inference_steps = 32  # solver steps per sampler run

In [None]:
# Target distribution for GaussianModel: a 2-D Gaussian with mean `mu` and covariance `cov`
# (here the standard normal: zero mean, identity covariance).
mu = torch.tensor([0.0, 0.0])
cov = torch.tensor([
    [1.0, 0.0],
    [0.0, 1.0],
])

# Fail here, not deep inside sampling, if mu/cov are edited into something invalid.
assert cov.shape == (mu.numel(), mu.numel()), "cov must be (d, d) for a d-dimensional mu"
assert torch.allclose(cov, cov.T), "cov must be symmetric"
assert torch.linalg.eigvalsh(cov).min() > 0, "cov must be positive definite"

## Test Rectified Flow's Formulation

In [None]:
FORMULATION = SAMPLER_FORMULATION_TABLE["Rectified Flow"]

# Coefficient functions that define the formulation, forwarded to the scheduler unchanged.
formulation_fns = {
    fn_name: FORMULATION[fn_name]
    for fn_name in ("scale_fn", "scale_deriv_fn", "sigma_fn", "sigma_deriv_fn", "nsr_inv_fn")
}

gaussian_model = GaussianModel(mu=mu, cov=cov)
# t is clipped to [1e-4, 1 - 1e-4] rather than the closed interval [0, 1].
rectified_flow_scheduler = GaussianModelScheduler(
    t_min=0 + 1e-4,
    t_max=1 - 1e-4,
    sigma_data=1.0,
    **formulation_fns,
    prediction_type="epsilon",
    algorithm_type="ode",
    timestep_schedule="cosine_lognsr",
)
pipeline = GaussianModelPipeline(gaussian_model, rectified_flow_scheduler)

if torch.cuda.is_available():
    pipeline = pipeline.to(device=torch.device("cuda"))

In [None]:
# Sweep all solver / prediction / timestep-schedule settings for the Rectified Flow scheduler.
draw_inference_result(
    pipeline,
    batch_size=batch_size,
    num_inference_steps=num_inference_steps,
    seed=seed,
)

## Test EDM's Formulation

In [None]:
FORMULATION = SAMPLER_FORMULATION_TABLE["EDM"]

# Coefficient functions that define the formulation, forwarded to the scheduler unchanged.
formulation_fns = {
    fn_name: FORMULATION[fn_name]
    for fn_name in ("scale_fn", "scale_deriv_fn", "sigma_fn", "sigma_deriv_fn", "nsr_inv_fn")
}

# Replace only the scheduler; the model inside `pipeline` is reused. t spans [1e-3, 1e3].
# NOTE(review): unlike the first pipeline, this scheduler is not passed through
# `pipeline.to(...)`; confirm it holds no device-bound tensors.
pipeline.scheduler = GaussianModelScheduler(
    t_min=1e-3,
    t_max=1e+3,
    sigma_data=1.0,
    **formulation_fns,
    prediction_type="epsilon",
    algorithm_type="ode",
    timestep_schedule="cosine_lognsr",
)

In [None]:
# Sweep all solver / prediction / timestep-schedule settings for the EDM scheduler.
draw_inference_result(
    pipeline,
    batch_size=batch_size,
    num_inference_steps=num_inference_steps,
    seed=seed,
)

## Test NCSN's Formulation

In [None]:
FORMULATION = SAMPLER_FORMULATION_TABLE["NCSN"]

# Coefficient functions that define the formulation, forwarded to the scheduler unchanged.
formulation_fns = {
    fn_name: FORMULATION[fn_name]
    for fn_name in ("scale_fn", "scale_deriv_fn", "sigma_fn", "sigma_deriv_fn", "nsr_inv_fn")
}

# Replace only the scheduler; the model inside `pipeline` is reused. t spans [1e-4, 1e4].
# NOTE(review): unlike the first pipeline, this scheduler is not passed through
# `pipeline.to(...)`; confirm it holds no device-bound tensors.
pipeline.scheduler = GaussianModelScheduler(
    t_min=1e-4,
    t_max=1e+4,
    sigma_data=1.0,
    **formulation_fns,
    prediction_type="epsilon",
    algorithm_type="ode",
    timestep_schedule="cosine_lognsr",
)

In [None]:
# Sweep all solver / prediction / timestep-schedule settings for the NCSN scheduler.
draw_inference_result(
    pipeline,
    batch_size=batch_size,
    num_inference_steps=num_inference_steps,
    seed=seed,
)

In [None]:
# Reference histogram: exact samples from N(mu, cov), using the same bins, range and aspect
# as the scheduler panels so they can be compared directly. Seeding here keeps this cell
# reproducible regardless of how much RNG state earlier cells consumed.
torch.manual_seed(seed)
ground_truth_samples = torch.distributions.MultivariateNormal(
    loc=mu,
    covariance_matrix=cov,
).sample((batch_size,)).cpu().numpy()
print(
    f"ground truth: mean={ground_truth_samples.mean(axis=0)}, "
    f"std={ground_truth_samples.std(axis=0)}"
)

fig, ax = plt.subplots(figsize=(5, 5))
ax.hist2d(
    ground_truth_samples[:, 0],
    ground_truth_samples[:, 1],
    bins=[128, 128],
    range=[[-4, 4], [-4, 4]],
    cmap="plasma",
    cmin=1,
)
ax.set_title("Ground Truth")
ax.axis("equal")  # same aspect as the scheduler panels
plt.show()