In [None]:
import os
from pathlib import Path

# Route Hugging Face downloads through the hf-mirror.com endpoint and keep the
# HF cache in a ./cache folder under the current working directory. These are
# set in the first cell so they are already in place when diffusers is imported
# in the next cell.
os.environ.update(
    HF_ENDPOINT="https://hf-mirror.com",
    HF_HOME=str(Path.cwd() / "cache"),
)

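# Test Models and Noise Schedulers: Rectified Flow vs. EDM

For each formulation, load its pretrained CIFAR UNet checkpoint, sample a fixed-seed image grid with the matching noise scheduler, and display the result.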

In [None]:
from pathlib import Path
import sys

from diffusers.utils import make_image_grid
from IPython import get_ipython
from IPython.display import display
from PIL import Image
import torch

# Make the project root (the parent of the notebook's working directory)
# importable, presumably so the local `diffusion` / `image` packages below resolve.
# Path.cwd() is the same directory the %pwd magic reports, but unlike
# get_ipython() it also works outside IPython; the membership check keeps
# re-runs of this cell from appending duplicate sys.path entries.
_project_root = Path.cwd().resolve().parent.as_posix()
if _project_root not in sys.path:
    sys.path.append(_project_root)

from coach_pl.configuration import CfgNode
from coach_pl.model import build_model, load_pretrained

from diffusion.model import EDMNoiseScheduler, RectifiedFlowNoiseScheduler
from image import UnconditionalGenerationPipeline

In [None]:
def draw_inference_result(
    pipeline: UnconditionalGenerationPipeline,
    grid_size: int,
    num_inference_steps: int,
    seed: int
) -> None:
    """Sample a ``grid_size`` x ``grid_size`` grid of images from ``pipeline`` and display it.

    The torch RNG is reseeded before sampling so repeated calls with the same
    ``seed`` are reproducible (assumes the pipeline draws its noise from the
    global torch generator — TODO confirm).

    Args:
        pipeline: Generation pipeline, called as
            ``pipeline(batch_size=..., num_inference_steps=...)``. Its output is
            treated as a (B, C, H, W) tensor; the 127.5 scaling below assumes
            values in roughly [-1, 1].
        grid_size: Number of rows and columns; ``grid_size ** 2`` samples are drawn.
        num_inference_steps: Number of sampling steps forwarded to the pipeline.
        seed: Seed applied to the torch RNG before sampling.

    Raises:
        ValueError: If ``grid_size`` is smaller than 1.
    """
    if grid_size < 1:
        raise ValueError(f"grid_size must be a positive integer, got {grid_size}")

    # torch.manual_seed seeds the CPU generator and every CUDA device, so a
    # separate torch.cuda.manual_seed call is unnecessary.
    torch.manual_seed(seed)

    batch_size = grid_size ** 2
    # Sampling needs no gradients; no_grad prevents an autograd graph from being
    # built across all solver steps if the pipeline does not disable it itself.
    with torch.no_grad():
        samples = pipeline(batch_size=batch_size, num_inference_steps=num_inference_steps)

    # Map [-1, 1] -> [0, 255]. Adding 128 (= 127.5 + 0.5) makes the truncating
    # uint8 cast behave as round-to-nearest.
    samples = (samples * 127.5 + 128).detach().clip(0, 255).byte().cpu().numpy()
    images = samples.transpose(0, 2, 3, 1)  # (B, C, H, W) -> (B, H, W, C) for PIL
    if images.shape[-1] == 1:
        # PIL cannot build an image from an (H, W, 1) array; drop the channel
        # axis so single-channel samples render as grayscale.
        images = images[..., 0]

    pil_images = [Image.fromarray(image) for image in images]
    display(make_image_grid(pil_images, grid_size, grid_size))

In [None]:
# Sampling settings shared by both formulations below:
#   seed                -> RNG seed reset before every sampling run
#   grid_size           -> rows/columns of the displayed grid (grid_size**2 samples)
#   num_inference_steps -> solver steps passed to the pipeline
seed, grid_size, num_inference_steps = 0, 8, 32

In [None]:
# Load the shared experiment config ("image.yaml", resolved relative to the
# notebook's working directory). Both pipelines below build their model from it
# and read cfg.MODEL.SIGMA_DATA for their noise scheduler.
cfg = CfgNode.load_yaml_with_base("image.yaml")
# Mark the config read-only so later cells cannot modify it by accident
# (NOTE(review): assumes set_readonly blocks attribute assignment — confirm).
CfgNode.set_readonly(cfg, True)

## Test Rectified Flow's Formulation

In [None]:
# Rectified Flow: velocity-prediction model sampled with the ODE algorithm.
pipeline = UnconditionalGenerationPipeline(
    build_model(cfg).eval(),
    RectifiedFlowNoiseScheduler(
        t_min=0.0001,
        t_max=0.9999,
        sigma_data=cfg.MODEL.SIGMA_DATA,
        prediction_type="velocity",
        algorithm_type="ode",
        timestep_schedule="linear_lognsr"
    )
)
# Switch to eval mode again *after* loading the checkpoint, so inference does
# not depend on load_pretrained preserving the mode set above (train-mode
# dropout/normalization would corrupt samples).
pipeline.model = load_pretrained(pipeline.model, "../output/rf_velocity_unet_cifar/regular_ckpts/last.ckpt").eval()
if torch.cuda.is_available():
    pipeline = pipeline.to(device=torch.device("cuda"))

In [None]:
# Render the Rectified Flow sample grid using the shared sampling settings.
draw_inference_result(
    pipeline=pipeline,
    grid_size=grid_size,
    num_inference_steps=num_inference_steps,
    seed=seed,
)

## Test EDM's Formulation

In [None]:
# EDM: sample-prediction (denoiser) model sampled with the ODE algorithm over
# t in [0.002, 80].
pipeline = UnconditionalGenerationPipeline(
    build_model(cfg).eval(),
    EDMNoiseScheduler(
        t_min=0.002,
        t_max=80.0,
        sigma_data=cfg.MODEL.SIGMA_DATA,
        prediction_type="sample",
        algorithm_type="ode",
        timestep_schedule="linear_lognsr"
    )
)
# Switch to eval mode again *after* loading the checkpoint, so inference does
# not depend on load_pretrained preserving the mode set above (train-mode
# dropout/normalization would corrupt samples).
pipeline.model = load_pretrained(pipeline.model, "../output/edm_sample_unet_cifar/regular_ckpts/last.ckpt").eval()
if torch.cuda.is_available():
    pipeline = pipeline.to(device=torch.device("cuda"))

In [None]:
# Render the EDM sample grid with the same seed and settings as the Rectified Flow run.
draw_inference_result(
    pipeline=pipeline,
    grid_size=grid_size,
    num_inference_steps=num_inference_steps,
    seed=seed,
)