In [None]:
import os
from pathlib import Path

# Route Hugging Face Hub traffic through a mirror and keep the HF cache in a `cache/` folder
# under the notebook's working directory. `setdefault` respects values the user has already
# exported (e.g. direct access to huggingface.co, or a shared cache) instead of overwriting them.
# NOTE(review): huggingface_hub reads these variables when it is imported, which is presumably
# why this cell runs before the project imports below — keep it first.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
os.environ.setdefault("HF_HOME", str(Path.cwd().joinpath("cache")))

# Test the Model and Scheduler with an EDM Model

In [None]:
from pathlib import Path
import sys

from IPython import get_ipython
import matplotlib.pyplot as plt
import torch

# Make the parent of the notebook's working directory importable so the local packages imported
# below (e.g. `diffusion`, `image`, `sampler`) can be found — assumes the notebook sits one level
# below the project root. `Path.cwd()` is what IPython's `%pwd` magic reports, but it does not
# need a running IPython kernel; the membership check keeps re-runs of this cell from appending
# duplicate entries to `sys.path`.
project_root = Path.cwd().resolve().parent.as_posix()
if project_root not in sys.path:
    sys.path.append(project_root)

from coach_pl.configuration import CfgNode
from coach_pl.model import load_pretrained
from diffusion.model import GeneralContinuousTimeDiffusion, build_backbone
from image import UnconditionalGenerationPipeline
from sampler import SAMPLER_FORMULATION_TABLE, GeneralContinuousTimeDiffusionScheduler

In [None]:
# Experiment configuration; the path is resolved relative to the notebook's working directory.
cfg = CfgNode.load_yaml_with_base("edm.yaml")

# The EDM entry of the sampler's formulation table; the scheduler below reads its
# "scale_fn", "sigma_fn" and "nsr_inv_fn" entries.
FORMULATION = SAMPLER_FORMULATION_TABLE["EDM"]

In [None]:
# Checkpoint produced by the EDM "sample"-prediction CIFAR training run.
CHECKPOINT_PATH = "../output/edm_sample_cifar/regular_ckpts/last.ckpt"

# Build the sampling pipeline: the config's backbone wrapped in GeneralContinuousTimeDiffusion,
# paired with a continuous-time ODE scheduler that uses the EDM formulation functions.
pipeline = UnconditionalGenerationPipeline(
    GeneralContinuousTimeDiffusion(
        build_backbone(cfg),
        cfg.MODEL.PREDICTION_TYPE,
        cfg.MODEL.SIGMA_DATA,
        cfg.MODEL.REQUIRE_PRE_AND_POST_PROCESSING,
    ).eval(),
    GeneralContinuousTimeDiffusionScheduler(
        # NOTE(review): these look like the EDM paper's default noise range
        # (sigma_min=0.002, sigma_max=80) — confirm they match training.
        t_min=0.002,
        t_max=80.0,
        # Take sigma_data from the same config value the model uses, so the scheduler and the
        # model cannot silently disagree if the config changes (previously hard-coded to 0.5).
        sigma_data=cfg.MODEL.SIGMA_DATA,
        scale_fn=FORMULATION["scale_fn"],
        sigma_fn=FORMULATION["sigma_fn"],
        nsr_inv_fn=FORMULATION["nsr_inv_fn"],
        # NOTE(review): hard-coded while the model is built with cfg.MODEL.PREDICTION_TYPE —
        # confirm the model always hands the scheduler sample-space predictions.
        prediction_type="sample",
        algorithm_type="ode",
        timestep_schedule="power_lognsr",
    ),
)

# TODO: Load the model using from_pretrained
pipeline.model = load_pretrained(pipeline.model, CHECKPOINT_PATH)
# load_pretrained may hand back a different module object; re-apply eval mode defensively so
# that any train-only layers stay disabled during sampling.
pipeline.model.eval()

if torch.cuda.is_available():
    pipeline = pipeline.to(device=torch.device("cuda"))

In [None]:
# Sampling hyper-parameters: a fixed RNG seed, a `grid_size` x `grid_size` grid of images
# (one image per grid cell), and the number of sampling steps the scheduler takes.
seed, grid_size, num_inference_steps = 0, 8, 32
batch_size = grid_size ** 2

In [None]:
# Seed for reproducibility. `torch.manual_seed` seeds the CPU and every CUDA device, so a separate
# `torch.cuda.manual_seed` call is unnecessary.
torch.manual_seed(seed)

# Nothing here needs gradients. Running the sampler under no_grad avoids building an autograd
# graph across every sampling step, which would otherwise hold onto activation memory.
with torch.no_grad():
    samples = pipeline(batch_size=batch_size, num_inference_steps=num_inference_steps)

# Convert the model output to uint8 pixels. NOTE(review): assumes the samples lie roughly in
# [-1, 1]. Adding 128 (rather than 127.5) before the truncating `.byte()` cast on the clipped,
# non-negative values rounds `(x + 1) * 127.5` to the nearest integer.
samples = (samples * 127.5 + 128).detach().clip(0, 255).byte().cpu().numpy()

In [None]:
# Tile the (N, C, H, W) uint8 batch into one grid_size x grid_size mosaic. The image geometry is
# read from the array instead of being hard-coded to 32x32x3.
num_images, channels, height, width = samples.shape
assert num_images == grid_size * grid_size, "batch_size must equal grid_size ** 2 to fill the grid"

image = (
    samples.transpose(0, 2, 3, 1)                              # (N, H, W, C)
    .reshape(grid_size, grid_size, height, width, channels)    # (row, col, H, W, C)
    .transpose(0, 2, 1, 3, 4)                                  # (row, H, col, W, C)
    .reshape(grid_size * height, grid_size * width, channels)  # (row*H, col*W, C)
)
if channels == 1:
    # imshow expects a 2-D array for single-channel images.
    image = image[..., 0]

fig, ax = plt.subplots(figsize=(grid_size, grid_size))
ax.imshow(image, cmap="gray" if channels == 1 else None)
ax.axis("off")
plt.show()