In [None]:
import os
from pathlib import Path

# Point Hugging Face downloads at the hf-mirror.com endpoint and keep the cache
# in a `cache/` folder under the current working directory. This cell runs before
# diffusers is imported so that the values are already in place when the
# Hugging Face libraries read their configuration.
os.environ.update(
    HF_ENDPOINT="https://hf-mirror.com",
    HF_HOME=str(Path.cwd() / "cache"),
)

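# Test Models and Schedulers with the EDM U-Net

Samples image grids from rectified-flow, EDM, and Mean Flow checkpoints, all built from the `edm_unet_cifar.yaml` configuration.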

In [None]:
from diffusers import DiffusionPipeline
from pathlib import Path
import sys

from diffusers.utils import make_image_grid
from IPython import get_ipython
from IPython.display import display
from PIL import Image
import torch

# Make the repository root (the parent of the notebook's working directory) importable
# so the project packages imported below resolve. `Path.cwd()` is what IPython's `%pwd`
# magic returns, but it also works without an IPython kernel. The membership check keeps
# re-running this cell from piling up duplicate sys.path entries.
_repo_root = Path.cwd().resolve().parent.as_posix()
if _repo_root not in sys.path:
    sys.path.append(_repo_root)

from coach_pl.configuration import CfgNode
from coach_pl.model import build_model
from coach_pl.utils.checkpoint import load_pretrained

from diffusion.module.scheduler import (
    EDMNoiseScheduler,
    RectifiedFlowNoiseScheduler,
)
from image import (
    EDMPipeline,
    MeanFlowPipeline,
    RectifiedFlowPipeline,
)

# Repository root: the parent of the notebook's working directory. `Path.cwd()` is
# equivalent to IPython's `%pwd` magic but does not require an IPython kernel.
ROOT = Path.cwd().resolve().parent
CONFIGURATION_PATH = ROOT / "diffusion" / "configuration"  # YAML model configs
CHECKPOINT_PATH = ROOT / "output"                          # trained checkpoints

In [None]:
def _sample_and_display(
    pipeline: DiffusionPipeline,
    grid_size: int,
    seed: int,
    image_shape: tuple,
    **pipeline_kwargs,
) -> None:
    """Sample ``grid_size ** 2`` images with ``pipeline`` and display them as a square grid.

    Args:
        pipeline: Called as ``pipeline(output_shape=..., **pipeline_kwargs)``; assumed to
            return an ``(N, C, H, W)`` tensor whose values are meant to lie in ``[-1, 1]``
            (the pixel mapping below relies on that range).
        grid_size: Number of rows and columns in the displayed grid.
        seed: Passed to ``torch.manual_seed`` right before sampling, so repeated calls
            with the same seed start from the same noise.
        image_shape: ``(C, H, W)`` of a single sample.
        **pipeline_kwargs: Forwarded unchanged to ``pipeline``.
    """
    torch.manual_seed(seed)

    # Pure inference: make sure no autograd graph is kept across the sampling steps
    # (a no-op if the pipeline already disables gradients itself).
    with torch.no_grad():
        samples = pipeline(output_shape=(grid_size ** 2, *image_shape), **pipeline_kwargs)

    # [-1, 1] -> [0, 255]. Round before the uint8 cast: `.byte()` alone truncates,
    # which would bias every pixel down by half a level on average.
    pixels = ((samples + 1.0) * 127.5).round().clip(0, 255).byte().detach().cpu().numpy()
    # NCHW -> NHWC, the channel-last layout PIL expects.
    images = [Image.fromarray(image) for image in pixels.transpose(0, 2, 3, 1)]
    display(make_image_grid(images, grid_size, grid_size))


def draw_inference_result(
    pipeline: DiffusionPipeline,
    grid_size: int,
    num_inference_steps: int,
    seed: int,
    image_shape: tuple = (3, 32, 32),
) -> None:
    """Display a ``grid_size x grid_size`` grid of unconditional samples from ``pipeline``.

    Args:
        pipeline: Sampling pipeline, called with ``output_shape`` and ``num_inference_steps``.
        grid_size: Rows/columns of the grid (``grid_size ** 2`` samples are drawn).
        num_inference_steps: Forwarded to the pipeline.
        seed: Global torch seed set immediately before sampling.
        image_shape: Per-sample ``(C, H, W)``; defaults to the 3x32x32 shape used
            throughout this notebook.
    """
    _sample_and_display(
        pipeline,
        grid_size,
        seed,
        image_shape,
        num_inference_steps=num_inference_steps,
    )


def draw_inference_result_condition(
    pipeline: DiffusionPipeline,
    grid_size: int,
    num_inference_steps: int,
    condition: "int | torch.Tensor",
    seed: int,
    image_shape: tuple = (3, 32, 32),
) -> None:
    """Display a ``grid_size x grid_size`` grid of samples conditioned on ``condition``.

    Args:
        pipeline: Sampling pipeline, called with ``output_shape``, ``num_inference_steps``
            and ``condition``.
        grid_size: Rows/columns of the grid (``grid_size ** 2`` samples are drawn).
        num_inference_steps: Forwarded to the pipeline.
        condition: Forwarded to the pipeline as-is; the callers in this notebook pass
            0-d tensors taken from ``torch.arange(10, device=device)``.
        seed: Global torch seed set immediately before sampling.
        image_shape: Per-sample ``(C, H, W)``; defaults to the 3x32x32 shape used
            throughout this notebook.
    """
    _sample_and_display(
        pipeline,
        grid_size,
        seed,
        image_shape,
        num_inference_steps=num_inference_steps,
        condition=condition,
    )

In [None]:
# Sampling settings shared by every experiment below.
seed = 0                  # re-applied via torch.manual_seed before each sampling call
grid_size = 8             # figures are grid_size x grid_size, i.e. 64 samples each
num_inference_steps = 32  # sampler steps for the rectified-flow and EDM pipelines
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Both variants start from the same EDM U-Net YAML. The unconditional one has
# MODEL.NUM_CLASSES forced to 0 (presumably disabling class conditioning); the
# conditional one keeps the YAML's value. Both are frozen afterwards.
_edm_unet_config_path = CONFIGURATION_PATH / "edm_unet_cifar.yaml"
cfg_unconditional = CfgNode.load_yaml_with_base(_edm_unet_config_path)
cfg_conditional = CfgNode.load_yaml_with_base(_edm_unet_config_path)
CfgNode.merge_with_dotlist(cfg_unconditional, ["MODEL.NUM_CLASSES", 0])
for _frozen_cfg in (cfg_unconditional, cfg_conditional):
    CfgNode.set_readonly(_frozen_cfg, True)

## Test Rectified Flow's Formulation

In [None]:
# Rectified flow: velocity-prediction checkpoints, ODE sampling with the "uniform"
# timestep schedule on t in [0.0001, 0.9999].
model_unconditional = build_model(cfg_unconditional).eval()
load_pretrained(
    model_unconditional,
    CHECKPOINT_PATH / "rectified_flow" / "velocity_unet_cifar_unconditional" / "regular_ckpts" / "last.ckpt",
)
model_conditional = build_model(cfg_conditional).eval()
load_pretrained(
    model_conditional,
    CHECKPOINT_PATH / "rectified_flow" / "velocity_unet_cifar_conditional" / "regular_ckpts" / "last.ckpt",
)

# One scheduler instance is shared by both pipelines.
scheduler = RectifiedFlowNoiseScheduler(
    prediction_type="velocity",
    algorithm_type="ode",
    timestep_schedule="uniform",
    t_min=0.0001,
    t_max=0.9999,
    sigma_data=cfg_unconditional.MODULE.NOISE_SCHEDULER.SIGMA_DATA,
)

pipeline_unconditional = RectifiedFlowPipeline(model_unconditional, scheduler).to(device)
pipeline_conditional = RectifiedFlowPipeline(model_conditional, scheduler).to(device)

In [None]:
draw_inference_result(
    pipeline=pipeline_unconditional,
    grid_size=grid_size,
    num_inference_steps=num_inference_steps,
    seed=seed,
)

In [None]:
# One grid per label 0..9 (presumably the ten CIFAR-10 classes), all from the same seed.
for label in torch.arange(10, device=device):
    draw_inference_result_condition(
        pipeline=pipeline_conditional,
        grid_size=grid_size,
        num_inference_steps=num_inference_steps,
        condition=label,
        seed=seed,
    )

## Test EDM's Formulation (sample- and epsilon-prediction checkpoints)

In [None]:
# EDM with sample prediction: ODE sampling with the "linear_lognsr" timestep schedule
# on t in [0.002, 80.0].
model_unconditional = build_model(cfg_unconditional).eval()
load_pretrained(
    model_unconditional,
    CHECKPOINT_PATH / "edm" / "sample_unet_cifar_unconditional" / "regular_ckpts" / "last.ckpt",
)
model_conditional = build_model(cfg_conditional).eval()
load_pretrained(
    model_conditional,
    CHECKPOINT_PATH / "edm" / "sample_unet_cifar_conditional" / "regular_ckpts" / "last.ckpt",
)

# One scheduler instance is shared by both pipelines.
scheduler = EDMNoiseScheduler(
    prediction_type="sample",
    algorithm_type="ode",
    timestep_schedule="linear_lognsr",
    t_min=0.002,
    t_max=80.0,
    sigma_data=cfg_unconditional.MODULE.NOISE_SCHEDULER.SIGMA_DATA,
)

pipeline_unconditional = EDMPipeline(model_unconditional, scheduler).to(device)
pipeline_conditional = EDMPipeline(model_conditional, scheduler).to(device)

In [None]:
draw_inference_result(
    pipeline=pipeline_unconditional,
    grid_size=grid_size,
    num_inference_steps=num_inference_steps,
    seed=seed,
)

In [None]:
# One grid per label 0..9 (presumably the ten CIFAR-10 classes), all from the same seed.
for label in torch.arange(10, device=device):
    draw_inference_result_condition(
        pipeline=pipeline_conditional,
        grid_size=grid_size,
        num_inference_steps=num_inference_steps,
        condition=label,
        seed=seed,
    )

In [None]:
# EDM with epsilon prediction: same sampler settings as the sample-prediction run
# above ("linear_lognsr" schedule on t in [0.002, 80.0]), different checkpoints.
model_unconditional = build_model(cfg_unconditional).eval()
load_pretrained(
    model_unconditional,
    CHECKPOINT_PATH / "edm" / "epsilon_unet_cifar_unconditional" / "regular_ckpts" / "last.ckpt",
)
model_conditional = build_model(cfg_conditional).eval()
load_pretrained(
    model_conditional,
    CHECKPOINT_PATH / "edm" / "epsilon_unet_cifar_conditional" / "regular_ckpts" / "last.ckpt",
)

# One scheduler instance is shared by both pipelines.
scheduler = EDMNoiseScheduler(
    prediction_type="epsilon",
    algorithm_type="ode",
    timestep_schedule="linear_lognsr",
    t_min=0.002,
    t_max=80.0,
    sigma_data=cfg_unconditional.MODULE.NOISE_SCHEDULER.SIGMA_DATA,
)

pipeline_unconditional = EDMPipeline(model_unconditional, scheduler).to(device)
pipeline_conditional = EDMPipeline(model_conditional, scheduler).to(device)

In [None]:
draw_inference_result(
    pipeline=pipeline_unconditional,
    grid_size=grid_size,
    num_inference_steps=num_inference_steps,
    seed=seed,
)

In [None]:
# One grid per label 0..9 (presumably the ten CIFAR-10 classes), all from the same seed.
for label in torch.arange(10, device=device):
    draw_inference_result_condition(
        pipeline=pipeline_conditional,
        grid_size=grid_size,
        num_inference_steps=num_inference_steps,
        condition=label,
        seed=seed,
    )

## Test Mean Flow's Formulation

In [None]:
# Mean Flow: velocity-prediction checkpoints with a rectified-flow scheduler
# ("uniform" schedule) over the full interval t in [0.0, 1.0].
model_unconditional = build_model(cfg_unconditional).eval()
load_pretrained(
    model_unconditional,
    CHECKPOINT_PATH / "mean_flow" / "velocity_unet_cifar_unconditional" / "regular_ckpts" / "last.ckpt",
)
model_conditional = build_model(cfg_conditional).eval()
load_pretrained(
    model_conditional,
    CHECKPOINT_PATH / "mean_flow" / "velocity_unet_cifar_conditional" / "regular_ckpts" / "last.ckpt",
)

# One scheduler instance is shared by both pipelines.
scheduler = RectifiedFlowNoiseScheduler(
    prediction_type="velocity",
    algorithm_type="ode",
    timestep_schedule="uniform",
    t_min=0.0,
    t_max=1.0,
    sigma_data=cfg_unconditional.MODULE.NOISE_SCHEDULER.SIGMA_DATA,
)

pipeline_unconditional = MeanFlowPipeline(model_unconditional, scheduler).to(device)
pipeline_conditional = MeanFlowPipeline(model_conditional, scheduler).to(device)

In [None]:
# NOTE(review): Mean Flow is sampled with num_inference_steps = 4 + 1, presumably
# 4 sampling steps spanning 5 time points — TODO confirm against MeanFlowPipeline.
draw_inference_result(
    pipeline=pipeline_unconditional,
    grid_size=grid_size,
    num_inference_steps=4 + 1,
    seed=seed,
)

In [None]:
# One grid per label 0..9 (presumably the ten CIFAR-10 classes), all from the same seed.
# NOTE(review): 4 + 1 mirrors the unconditional Mean Flow cell — TODO confirm meaning.
for label in torch.arange(10, device=device):
    draw_inference_result_condition(
        pipeline=pipeline_conditional,
        grid_size=grid_size,
        num_inference_steps=4 + 1,
        condition=label,
        seed=seed,
    )