In [None]:
# 网格排列
import torch
from torchvision.utils import save_image
from denoising_diffusion_pytorch import Unet, GaussianDiffusion
import os
import numpy as np
import matplotlib.pyplot as plt

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# -------- 1. Rebuild the model with the same definition as the training cell, then load weights --------
# NOTE(review): these hyper-parameters (U-Net dim / dim_mults / channels, image_size,
# timesteps) presumably must match the ones the checkpoint was trained with — confirm
# against the training cell. load_state_dict is strict by default and rejects
# missing/unexpected keys or shape mismatches.
unet = Unet(dim=64, dim_mults=(1, 2, 4, 8), channels=3)

diffusion = GaussianDiffusion(unet, image_size=96, timesteps=800)
diffusion = diffusion.to(device)  # nn.Module.to moves parameters/buffers and returns the same module

# The checkpoint is a dict whose "model" entry holds the GaussianDiffusion state dict.
# map_location remaps stored tensors onto `device` (e.g. lets a GPU checkpoint load on CPU).
ckpt = torch.load("./results_stl10/model-10.pt", map_location=device)
diffusion.load_state_dict(ckpt["model"])

# Switch to evaluation mode for any train/eval-dependent layers before sampling.
diffusion.eval()

# -------- 2. Sample once and keep every intermediate frame of the reverse process --------
with torch.no_grad():
    # With return_all_timesteps=True the sampler returns [B, num_frames, C, H, W],
    # already mapped back to (roughly) [0, 1].
    # NOTE(review): in denoising_diffusion_pytorch the frames are ordered noise-first:
    # frame 0 is the initial Gaussian noise x_T and the last frame is the final sample x_0.
    # Confirm if the library version changes.
    imgs = diffusion.sample(batch_size=1, return_all_timesteps=True)

imgs = imgs[0]  # [num_frames, C, H, W]; num_frames == T + 1 for the default (non-DDIM) sampler
num_frames = imgs.shape[0]

# -------- 3. Pick evenly spaced frames, always including the first (noise) and last (clean) --------
n_rows, n_cols = 5, 5
# Derive the range from the real number of frames instead of a hard-coded 799: the last
# valid index is num_frames - 1 (800 for T=800), which holds the finished image.
timesteps_to_show = np.linspace(0, num_frames - 1, n_rows * n_cols).round().astype(int)  # frame indices

chosen = torch.stack([imgs[i] for i in timesteps_to_show], dim=0)  # [n_rows * n_cols, C, H, W]

# -------- 4. Show the frames in a grid (noise -> clean, left-to-right, top-to-bottom) --------
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 12))

# C,H,W -> H,W,C for imshow; clip because noisy intermediate frames may leave [0, 1]
# (imshow would otherwise clip with a warning).
panels = np.clip(chosen.cpu().numpy().transpose(0, 2, 3, 1), 0, 1)

for ax, idx, img in zip(axes.flat, timesteps_to_show, panels):
    # Frame idx is x_t with t = (num_frames - 1) - idx: larger t = noisier, t = 0 is the
    # final sample. For a DDIM sampler this is the number of remaining sampling steps.
    t = num_frames - 1 - idx
    ax.imshow(img)
    ax.set_title(f"t={t}", fontsize=10)
    ax.axis("off")

plt.tight_layout()
plt.show()
