## 1. 安装和导入依赖

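如果没有安装 diffusers 及本笔记本用到的其他依赖（后续单元还会导入 torchvision、scikit-learn、tqdm 和 matplotlib），先运行：
```
pip install diffusers transformers accelerate torchvision scikit-learn tqdm matplotlib
```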

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Report GPU availability. Device details are only queried when CUDA is
# present, because get_device_name / get_device_properties raise on a
# machine without a CUDA device and would abort the cell.
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("GPU: not available (the cells below move models to \"cuda\" and need a CUDA device)")

## 2. 加载预训练的 Stable Diffusion 模型

In [None]:
from diffusers import StableDiffusionPipeline

# Load the SD 1.5 checkpoint (about 4 GB; the first download takes a few minutes).
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,  # half precision to save GPU memory
    safety_checker=None,  # disable the safety checker to speed things up
    requires_safety_checker=False
)
# The pipeline is moved to the GPU unconditionally; this cell needs a CUDA device.
pipe = pipe.to("cuda")

# Optimization settings
pipe.enable_attention_slicing()  # reduces GPU memory use

print("模型加载完成！")

## 3. 快速生成测试（单张图像）

In [None]:
import time

# Generate a single image from one text prompt.
prompt = "a cute cat sitting on grass, realistic photo"

# Wall-clock timing around the whole pipeline call.
start_time = time.time()
image = pipe(
    prompt,
    num_inference_steps=20,  # 20 steps is enough for roughly 70-80% of the final quality
    guidance_scale=7.5,
    height=512,
    width=512
).images[0]
elapsed = time.time() - start_time

print(f"生成时间: {elapsed:.2f} 秒")

# Display the image with the prompt as the title.
plt.figure(figsize=(6, 6))
plt.imshow(image)
plt.title(f"Prompt: {prompt}")
plt.axis('off')
plt.show()

## 4. 生成 CIFAR-10 类别的图像

CIFAR-10 包含 10 个类别：airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck

In [None]:
# CIFAR-10 class names; the list index is used as the label id by later cells.
cifar10_classes = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck"
]

# Generate one image per class; generated_images[i] corresponds to cifar10_classes[i].
generated_images = []

for cls in cifar10_classes:
    prompt = f"a photo of a {cls}, high quality, realistic"
    print(f"生成: {cls}...", end=" ")

    # Time each pipeline call separately.
    start_time = time.time()
    image = pipe(
        prompt,
        num_inference_steps=20,
        guidance_scale=7.5,
        height=512,
        width=512
    ).images[0]
    elapsed = time.time() - start_time

    generated_images.append(image)
    print(f"完成 ({elapsed:.2f}s)")

print("\n所有类别生成完成！")

In [None]:
# Show all generated images in a 2 x 5 grid, one panel per CIFAR-10 class.
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
axes = axes.flatten()  # flat indexing so the grid can be walked with a single index

for idx, (img, cls) in enumerate(zip(generated_images, cifar10_classes)):
    axes[idx].imshow(img)
    axes[idx].set_title(cls)
    axes[idx].axis('off')

plt.suptitle("Generated CIFAR-10 Categories (Stable Diffusion 1.5)", fontsize=14)
plt.tight_layout()
plt.show()

## 5. 批量生成（每个类别多张）

In [None]:
# Batch generation: produce n_per_class images for every CIFAR-10 class.
n_per_class = 5  # images generated per class
batch_images = {cls: [] for cls in cifar10_classes}

total = n_per_class * len(cifar10_classes)
count = 0

start_total = time.time()

for cls in cifar10_classes:
    # The prompt depends only on the class, so build it once per class
    # instead of once per generated image.
    prompt = f"a photo of a {cls}, high quality, realistic"

    for _ in range(n_per_class):
        image = pipe(
            prompt,
            num_inference_steps=20,
            guidance_scale=7.5,
            height=512,
            width=512
        ).images[0]

        batch_images[cls].append(image)
        count += 1
        # "\r" rewrites the same line; flush so the progress is visible
        # immediately rather than whenever the output buffer fills up.
        print(f"\r进度: {count}/{total}", end="", flush=True)

elapsed_total = time.time() - start_total
print(f"\n\n总计生成 {total} 张图像")
print(f"总耗时: {elapsed_total:.1f} 秒")
# max(total, 1) keeps the average well-defined when n_per_class is set to 0.
print(f"平均每张: {elapsed_total/max(total, 1):.2f} 秒")

## 6. 提取 Latent 表示（用于后续分析）

In [None]:
from diffusers import AutoencoderKL
from torchvision import transforms

# Load only the VAE sub-model of SD 1.5 (it can be used on its own to extract latents).
vae = AutoencoderKL.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="vae",
    torch_dtype=torch.float16
).to("cuda")

print("VAE 加载完成！")
# The printed latent size is a fixed message, not a value measured from the model.
print(f"Latent 空间维度: 4 x 64 x 64 (对于 512x512 图像)")

In [None]:
# Image preprocessing applied before VAE encoding.
preprocess = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # (x - 0.5) / 0.5: rescales ToTensor output to [-1, 1]
])

# Encode an image into the VAE latent space.
def encode_to_latent(image, vae):
    """Encode a PIL image into the scaled latent space of ``vae``.

    The input tensor is placed on the same device and cast to the same dtype
    as the VAE weights, so the function works for any VAE passed in (not only
    a float16 model on "cuda"). The latent scaling factor is read from the
    VAE config, falling back to 0.18215 when the config does not define one.

    Args:
        image: PIL image. Images exposing a ``mode`` other than "RGB"
            (e.g. RGBA or grayscale) are converted to RGB first, since the
            encoder is fed 3-channel input.
        vae: Autoencoder exposing ``encode(...).latent_dist.sample()``,
            ``parameters()`` and ``config``.

    Returns:
        Latent tensor sampled from the encoder posterior, multiplied by the
        VAE scaling factor. No gradients are recorded.
    """
    # getattr keeps non-PIL inputs (no ``mode`` attribute) on the original path.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    # Follow the VAE's own placement instead of assuming cuda/float16.
    reference_param = next(vae.parameters())
    img_tensor = preprocess(image).unsqueeze(0).to(
        device=reference_param.device, dtype=reference_param.dtype
    )

    scaling = float(getattr(vae.config, "scaling_factor", 0.18215))
    with torch.no_grad():
        latent = vae.encode(img_tensor).latent_dist.sample()
        latent = latent * scaling  # scaling factor
    return latent

# Smoke test: encode the first generated image and report the latent's shape and dtype.
test_latent = encode_to_latent(generated_images[0], vae)
print(f"Latent shape: {test_latent.shape}")
print(f"Latent dtype: {test_latent.dtype}")

## 7. 加载本地 CIFAR-10 并编码到 Latent 空间

In [None]:
import pickle

# Load a local CIFAR-10 batch file.
def load_cifar10_batch(filepath):
    """Read one pickled CIFAR-10 batch and return its images and labels.

    Args:
        filepath: Path to a pickled batch whose dict has the byte-string keys
            ``b'data'`` and ``b'labels'``.

    Returns:
        Tuple ``(images, labels)``: ``images`` is the ``b'data'`` array
        reshaped to (N, 3, 32, 32) and transposed to channel-last
        (N, 32, 32, 3); ``labels`` is the stored ``b'labels'`` object as-is.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only use this on
    # batch files from a trusted source.
    with open(filepath, 'rb') as handle:
        batch = pickle.load(handle, encoding='bytes')

    flat_pixels = batch[b'data']
    channel_first = flat_pixels.reshape(-1, 3, 32, 32)
    channel_last = channel_first.transpose(0, 2, 3, 1)
    return channel_last, batch[b'labels']

# Load the first training batch from the local CIFAR-10 directory.
cifar_path = "data/cifar-10-batches-py/data_batch_1"
images, labels = load_cifar10_batch(cifar_path)

print(f"CIFAR-10 图像形状: {images.shape}")
print(f"标签数量: {len(labels)}")

# Show the first five original CIFAR-10 images; each label id is mapped to
# its class name through cifar10_classes.
fig, axes = plt.subplots(1, 5, figsize=(12, 3))
for i in range(5):
    axes[i].imshow(images[i])
    axes[i].set_title(cifar10_classes[labels[i]])
    axes[i].axis('off')
plt.suptitle("Original CIFAR-10 Images (32x32)")
plt.show()

## 8. KSWGD 生成方法（替代 Stable Diffusion 的扩散过程）

使用 KSWGD（Kernel Stein Wasserstein Gradient Descent）在 latent space 中进行粒子传输生成。

**核心思想：**
- 保留 VAE 的 encoder/decoder（预训练的 latent space）
- 用 KSWGD 替代 UNet 去噪过程
- 从随机噪声出发，通过核梯度流传输到目标分布

In [None]:
# Libraries and custom kernel functions needed by KSWGD.
from sklearn.metrics import pairwise_distances
from tqdm.auto import trange

# Project-local kernel functions (used by the sampler when the GPU backend is unavailable).
from grad_ker1 import grad_ker1
from K_tar_eval import K_tar_eval

# Try to import the GPU (CuPy) versions. This is a best-effort probe: any
# failure falls back to the CPU functions and the reason is printed.
try:
    import cupy as cp
    from grad_ker1_gpu import grad_ker1 as grad_ker1_gpu
    from K_tar_eval_gpu import K_tar_eval as K_tar_eval_gpu
    GPU_KSWGD = True
    print("✓ GPU KSWGD backend available (CuPy)")
except Exception as e:
    # Define the GPU names as None so later references do not raise NameError.
    cp = None
    grad_ker1_gpu = None
    K_tar_eval_gpu = None
    GPU_KSWGD = False
    print(f"✗ GPU KSWGD backend not available: {e}")
    print("  Using CPU backend instead")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

In [None]:
# Reload the VAE for KSWGD generation; this rebinds the global `vae` that
# earlier cells loaded in float16.
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "stabilityai/sd-vae-ft-mse",  # fine-tuned VAE checkpoint
    torch_dtype=torch.float32  # KSWGD needs float32 precision
).to(device)

# Latent scaling factor from the model config; falls back to 0.18215 if the config has none.
vae_scaling = float(getattr(vae.config, "scaling_factor", 0.18215))
print(f"VAE scaling factor: {vae_scaling}")

# VAE 辅助函数
def _to_vae_range(x):
    """[0,1] → [-1,1]"""
    return (x * 2.0) - 1.0

def _from_vae_range(x):
    """[-1,1] → [0,1]"""
    return torch.clamp((x + 1.0) * 0.5, 0.0, 1.0)

In [None]:
# 将 CIFAR-10 图像编码到 VAE latent space 作为 KSWGD 目标分布
from torchvision import datasets, transforms as T
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# Load the CIFAR-10 training set.
transform_cifar = T.Compose([
    T.Resize((256, 256)),  # upscale: the VAE needs a larger input than 32x32
    T.ToTensor(),
])

# Single source of truth for the batch size. It is used both by the
# DataLoader and by the progress-bar batch count below, so the two can no
# longer drift apart when the value is changed.
encode_batch_size = 64

train_dataset = datasets.CIFAR10(root="./data", train=True, transform=transform_cifar, download=False)
train_loader = DataLoader(train_dataset, batch_size=encode_batch_size, shuffle=True, num_workers=0)

print(f"CIFAR-10 训练集大小: {len(train_dataset)}")

# Encode training images into the latent space.
max_samples = 2000  # cap on the number of encoded samples, to keep this fast (tunable)
all_latents = []
all_labels = []

vae.eval()
print(f"正在编码 {max_samples} 张图像到 latent space...")

# Number of batches needed to reach max_samples (ceiling division).
n_batches_needed = (max_samples + encode_batch_size - 1) // encode_batch_size

with torch.no_grad():
    count = 0
    pbar = tqdm(train_loader, total=n_batches_needed, desc="编码进度", unit="batch")

    for batch_imgs, batch_labels in pbar:
        batch_imgs = batch_imgs.to(device)
        # Encode with the posterior mode (deterministic) and apply the VAE scaling.
        latents = vae.encode(_to_vae_range(batch_imgs)).latent_dist.mode()
        latents = latents * vae_scaling

        # Flatten each latent to one row; reshape also handles non-contiguous tensors.
        latents_flat = latents.reshape(latents.size(0), -1).cpu().numpy()
        all_latents.append(latents_flat)
        all_labels.append(batch_labels.numpy())

        count += latents_flat.shape[0]
        pbar.set_postfix({"已编码": f"{min(count, max_samples)}/{max_samples}"})

        # Stop as soon as enough samples are collected. Checking here, not at
        # the top of the loop, avoids loading one extra batch that would be
        # discarded immediately.
        if count >= max_samples:
            break

    pbar.close()

# The last batch may overshoot max_samples, so trim to exactly max_samples rows.
Z_all = np.concatenate(all_latents, axis=0)[:max_samples]
y_labels = np.concatenate(all_labels, axis=0)[:max_samples]

print(f"\nLatent codes shape: {Z_all.shape}")
print(f"Latent dim: {Z_all.shape[1]}")

# Record the un-flattened latent shape so decoding can restore it later.
with torch.no_grad():
    dummy = torch.zeros(1, 3, 256, 256, device=device)
    dummy_latent = vae.encode(_to_vae_range(dummy)).latent_dist.mode()
    latent_shape = dummy_latent.shape[1:]  # (C, H, W)

print(f"Latent shape (C, H, W): {latent_shape}")

In [None]:
# Standardize the latent codes and build the KSWGD kernel operator.
# Per-dimension standardization; the 1e-8 keeps the division safe for constant dimensions.
Z_mean = np.mean(Z_all, axis=0, keepdims=True)
Z_std = np.std(Z_all, axis=0, keepdims=True) + 1e-8
Z_std = Z_std.astype(np.float64)
Z_mean = Z_mean.astype(np.float64)
X_tar = ((Z_all - Z_mean) / Z_std).astype(np.float64)

print(f"标准化后: mean={X_tar.mean():.4f}, std={X_tar.std():.4f}")

# Squared norm of every target sample (passed to the kernel functions later).
sq_tar = np.sum(X_tar ** 2, axis=1)

# Pairwise Euclidean distances and the bandwidth epsilon:
# median squared distance / (2 * log(n + 1)), floored at 1e-6.
dists = pairwise_distances(X_tar, metric="euclidean")
eps_kswgd = np.median(dists**2) / (2.0 * np.log(X_tar.shape[0] + 1))
eps_kswgd = float(max(eps_kswgd, 1e-6))

print(f"KSWGD epsilon: {eps_kswgd:.6f}")
# dists > 0 excludes the zero diagonal (and any exact duplicates) from the minimum.
print(f"距离统计: min={dists[dists>0].min():.4f}, median={np.median(dists):.4f}, max={dists.max():.4f}")

# Gaussian data kernel: exp(-d^2 / (2 * eps)).
data_kernel = np.exp(-dists**2 / (2.0 * eps_kswgd))

# Normalization: divide by sqrt(row sum) on both sides, then average the
# column-normalized and row-normalized versions of the result.
p_x = np.sqrt(np.sum(data_kernel, axis=1))
data_kernel_norm = data_kernel / (p_x[:, None] * p_x[None, :] + 1e-12)
D_y = np.sum(data_kernel_norm, axis=0)
rw_kernel = 0.5 * (data_kernel_norm / (D_y + 1e-12) + data_kernel_norm / (D_y[:, None] + 1e-12))
rw_kernel = np.nan_to_num(rw_kernel)  # replace any NaN/inf produced by the divisions

print(f"核矩阵构建完成，shape: {rw_kernel.shape}")

In [None]:
# 计算特征分解和 KSWGD 权重
import time

print("正在计算特征分解（可能需要 1-2 分钟）...")
start_time = time.time()

# Symmetric eigendecomposition of rw_kernel. eigh returns eigenvalues in
# ascending order, so both branches reverse them (and the matching
# eigenvector columns) to descending order.
if torch.cuda.is_available():
    print("  使用 GPU 加速...")
    rw_kernel_torch = torch.from_numpy(rw_kernel).to(device)
    lambda_ns_torch, phi_torch = torch.linalg.eigh(rw_kernel_torch)
    lambda_ns = lambda_ns_torch.cpu().numpy()[::-1].copy()
    phi = phi_torch.cpu().numpy()[:, ::-1].copy()
    # Release the GPU copies right away; the matrix is n x n.
    del rw_kernel_torch, lambda_ns_torch, phi_torch
    torch.cuda.empty_cache()
else:
    print("  使用 CPU...")
    lambda_ns, phi = np.linalg.eigh(rw_kernel)
    phi = phi[:, ::-1]
    lambda_ns = lambda_ns[::-1]

elapsed = time.time() - start_time
print(f"特征分解完成，耗时: {elapsed:.1f}s")

# Regularization parameters.
tol = 1e-6
reg = 1e-3
latent_dim = X_tar.shape[1]

# Inverse of the shifted spectrum (lambda - 1), scaled by epsilon. The leading
# mode (index 0) is left at zero.
#
# The shifted values are negative for every eigenvalue below 1, so the
# previous np.clip(lambda_[1:], 1e-12, None) replaced all of them with 1e-12
# and inv_lambda collapsed to a constant eps * 1e12, losing both the sign and
# the spectrum. Only a division by (near) zero is guarded now; every other
# value keeps its sign and magnitude. Near-zero entries take the negative
# sign, matching the values expected for eigenvalues just below 1.
lambda_ = lambda_ns - 1.0
inv_lambda = np.zeros_like(lambda_)
inv_lambda[1:] = 1.0 / np.where(np.abs(lambda_[1:]) < 1e-12, -1e-12, lambda_[1:])
inv_lambda *= eps_kswgd

# Truncate small eigenvalues: modes below tol are dropped, the rest get a
# regularized inverse eps / (lambda + reg).
lambda_ns_inv = np.zeros_like(lambda_ns)
mask = lambda_ns >= tol
lambda_ns_inv[mask] = eps_kswgd / (lambda_ns[mask] + reg)
above_tol = int(np.sum(mask))
# Eigenvalues are in descending order, so the kept modes are the leading columns.
phi_trunc = phi[:, :above_tol]
lambda_ns_s_ns = (lambda_ns_inv * inv_lambda * lambda_ns_inv)[:above_tol]

# Target-distribution weights derived from the unnormalized data kernel.
p_tar = np.sum(data_kernel, axis=0)
sqrt_p = np.sqrt(p_tar + 1e-12)
D_vec = np.sum(data_kernel / sqrt_p[:, None] / sqrt_p[None, :], axis=1)

print(f"保留的特征向量数量: {above_tol}")
print(f"前 10 个特征值: {lambda_ns[:10]}")

In [None]:
# KSWGD sampler
def run_kswgd_sampler(num_particles=16, num_iters=200, step_size=0.05, rng_seed=42):
    """Transport particles from a standard normal towards the target latents.

    Particles are initialized from N(0, I) and moved with the kernel gradient
    flow built from the module-level target data (``X_tar``, ``p_tar``,
    ``sq_tar``, ``D_vec``, ``phi_trunc``, ``lambda_ns_s_ns``, ``eps_kswgd``).
    The CuPy backend is used when it was imported successfully and CUDA is
    available; otherwise the NumPy backend is used.

    Only the current particle positions are kept in memory; the full
    trajectory is not stored because only the final positions are returned.

    Args:
        num_particles: Number of samples to generate.
        num_iters: Number of trajectory points including the initial one, so
            ``num_iters - 1`` update steps are performed.
        step_size: Step size of each update.
        rng_seed: Seed for the particle initialization.

    Returns:
        ``np.ndarray`` of shape (num_particles, latent_dim), dtype float64,
        holding the generated latent vectors in standardized coordinates.

    Raises:
        ValueError: If ``num_particles`` or ``num_iters`` is smaller than 1.
    """
    if num_particles < 1 or num_iters < 1:
        raise ValueError("num_particles and num_iters must both be >= 1")

    rng = np.random.default_rng(rng_seed)

    # Backend selection
    use_gpu = GPU_KSWGD and torch.cuda.is_available()
    xp = cp if use_gpu else np
    grad_fn = grad_ker1_gpu if use_gpu else grad_ker1
    K_eval_fn = K_tar_eval_gpu if use_gpu else K_tar_eval

    print(f"KSWGD 后端: {'GPU (CuPy)' if use_gpu else 'CPU (NumPy)'}")

    # Initial particles ~ N(0, I), always drawn with NumPy so that the seed
    # gives the same start on both backends.
    init_particles = rng.normal(0.0, 1.0, size=(num_particles, latent_dim))
    current = xp.asarray(init_particles, dtype=xp.float64)

    # Target data on the active backend: a host-to-device copy for CuPy and
    # the same array for NumPy.
    X_tar_dev = xp.asarray(X_tar)
    p_tar_dev = xp.asarray(p_tar)
    sq_tar_dev = xp.asarray(sq_tar)
    D_vec_dev = xp.asarray(D_vec)
    phi_trunc_dev = xp.asarray(phi_trunc)
    lambda_weights = xp.asarray(lambda_ns_s_ns)

    # KSWGD iterations
    iterator = trange(num_iters - 1, desc="KSWGD 传输", unit="step")
    for t in iterator:
        # Kernel gradient and cross kernel between targets and particles.
        # grad_matrix is indexed as (particle, target, latent dimension).
        grad_matrix = grad_fn(current, X_tar_dev, p_tar_dev, sq_tar_dev, D_vec_dev, eps_kswgd)
        cross_matrix = K_eval_fn(X_tar_dev, current, p_tar_dev, sq_tar_dev, D_vec_dev, eps_kswgd)

        # Apply the truncated spectral operator: phi diag(weights) phi^T cross_matrix.
        tmp = phi_trunc_dev.T @ cross_matrix
        tmp = lambda_weights[:, None] * tmp
        kswgd_push = phi_trunc_dev @ tmp

        # For every latent dimension d the update is
        #     sum_j (grad_matrix[:, :, d] @ kswgd_push)[:, j].
        # By linearity this equals grad_matrix[:, :, d] @ kswgd_push.sum(axis=1),
        # so one contraction over the target axis replaces the Python loop
        # over all latent dimensions.
        push_per_target = xp.sum(kswgd_push, axis=1)
        drift = xp.einsum("ind,n->id", grad_matrix, push_per_target)
        updated = current - (step_size / num_particles) * drift

        # Progress display: mean displacement of the particles in this step.
        if (t + 1) % 50 == 0:
            mean_disp = float(xp.mean(xp.linalg.norm(updated - current, axis=1)))
            iterator.set_postfix({"mean_step": f"{mean_disp:.3e}"})

        current = updated

    # Final samples, always returned as a host NumPy array.
    if use_gpu:
        current = cp.asnumpy(current)

    return np.asarray(current, dtype=np.float64)


def decode_latents_to_images(flat_latents_std):
    """Decode standardized, flattened latent vectors back into images.

    Args:
        flat_latents_std: Standardized latent vectors, one row per sample,
            shape (N, latent_dim).

    Returns:
        CPU tensor of decoded images, shape (N, 3, H, W), with values
        clamped to [0, 1].
    """
    # Undo the per-dimension standardization.
    raw_flat = flat_latents_std * Z_std + Z_mean

    # Restore the (N, C, H, W) layout recorded in latent_shape.
    latent_grid = raw_flat.reshape(-1, *latent_shape)
    decoder_input = torch.from_numpy(latent_grid).float().to(device)

    vae.eval()
    with torch.no_grad():
        # Divide by vae_scaling to get back to the decoder's native latent scale.
        vae_output = vae.decode(decoder_input / vae_scaling).sample
        images = _from_vae_range(vae_output)

    return images.cpu()

print("KSWGD 采样器定义完成！")

### 8.1 运行 KSWGD 生成

In [None]:
# Run the KSWGD sampler to generate new latent vectors.
kswgd_config = {
    "num_particles": 16,   # generate 16 images
    "num_iters": 300,      # number of iterations
    "step_size": 0.03,     # step size (kept small for stability)
    "rng_seed": 42
}

print("=" * 50)
print("KSWGD 生成配置:")
for k, v in kswgd_config.items():
    print(f"  {k}: {v}")
print("=" * 50)

# Run KSWGD; the config keys match the sampler's keyword arguments.
start_time = time.time()
Z_kswgd_std = run_kswgd_sampler(**kswgd_config)
kswgd_time = time.time() - start_time

print(f"\nKSWGD 完成！")
print(f"生成样本 shape: {Z_kswgd_std.shape}")
print(f"总耗时: {kswgd_time:.1f} 秒")

In [None]:
# Decode the KSWGD-generated latent vectors into images.
print("正在解码 latent vectors 为图像...")

kswgd_images = decode_latents_to_images(Z_kswgd_std)
kswgd_images_np = kswgd_images.numpy()  # NumPy copy used by the plotting cells

print(f"生成图像 shape: {kswgd_images_np.shape}")
print(f"像素值范围: [{kswgd_images_np.min():.3f}, {kswgd_images_np.max():.3f}]")

In [None]:
# Visualize the KSWGD-generated images (at most 16) in a 4-column grid.
n_show = min(16, kswgd_images_np.shape[0])
n_cols = 4
n_rows = (n_show + n_cols - 1) // n_cols  # ceiling division

fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows))
axes = np.asarray(axes).reshape(-1)  # flat indexing regardless of the grid shape

for idx in range(n_show):
    img = np.transpose(kswgd_images_np[idx], (1, 2, 0))  # (C,H,W) -> (H,W,C)
    axes[idx].imshow(np.clip(img, 0.0, 1.0))
    axes[idx].set_title(f"KSWGD #{idx+1}")
    axes[idx].axis('off')

# Hide the unused panels.
for idx in range(n_show, len(axes)):
    axes[idx].axis('off')

plt.suptitle("KSWGD 生成的 CIFAR-10 图像\n(VAE Latent Space + 核梯度流)", fontsize=14)
plt.tight_layout()
plt.show()

# Textual summary of the two generation paths.
print(f"\n对比 Stable Diffusion:")
print(f"  - SD: 随机噪声 → UNet 去噪 (20步) → VAE Decode")
print(f"  - KSWGD: 随机噪声 → 核梯度流传输 ({kswgd_config['num_iters']}步) → VAE Decode")

### 8.2 对比：原始 CIFAR-10 vs KSWGD 生成

In [None]:
# 并排对比原始图像和 KSWGD 生成图像
from torchvision.utils import make_grid

# Pick a few random original images from the training set. The count is
# capped by the number of KSWGD images actually generated (and by the dataset
# size), so a run with fewer than 8 particles no longer raises IndexError.
n_compare = min(8, kswgd_images_np.shape[0], len(train_dataset))
random_indices = np.random.choice(len(train_dataset), n_compare, replace=False)

original_imgs = [train_dataset[idx][0] for idx in random_indices]
original_batch = torch.stack(original_imgs)

# VAE reconstruction of the originals (as a reference). Scaling is skipped on
# both encode and decode here, so the round trip stays consistent.
vae.eval()
with torch.no_grad():
    original_latents = vae.encode(_to_vae_range(original_batch.to(device))).latent_dist.mode()
    reconstructed = vae.decode(original_latents).sample
    reconstructed_rgb = _from_vae_range(reconstructed).cpu()

# KSWGD-generated images
kswgd_batch = torch.from_numpy(kswgd_images_np[:n_compare])

# Comparison figure: one row per image source. squeeze=False keeps `axes`
# two-dimensional even when n_compare == 1.
fig, axes = plt.subplots(3, n_compare, figsize=(2 * n_compare, 6), squeeze=False)

comparison_rows = (
    (original_batch, 'Original\n(CIFAR-10)'),
    (reconstructed_rgb, 'VAE\nReconstruct'),
    (kswgd_batch, 'KSWGD\nGenerated'),
)

for row, (row_images, row_label) in enumerate(comparison_rows):
    for col in range(n_compare):
        panel = row_images[col].permute(1, 2, 0).numpy()  # (C,H,W) -> (H,W,C)
        axes[row, col].imshow(np.clip(panel, 0, 1))
        axes[row, col].axis('off')

    # Row label drawn as text to the left of the first panel. axis('off')
    # hides axis labels, so set_ylabel would have no visible effect here.
    axes[row, 0].text(-0.3, 0.5, row_label, transform=axes[row, 0].transAxes,
                      fontsize=10, va='center', ha='right')

plt.suptitle("CIFAR-10: Original vs VAE Reconstruction vs KSWGD Generation", fontsize=14)
plt.tight_layout()
plt.show()

## 9. 清理 GPU 显存

In [None]:
# 清理显存
import gc

# Drop the references to the two models, then run the garbage collector and
# release PyTorch's cached GPU memory.
# NOTE(review): re-running this cell raises NameError because `pipe` and `vae` are already deleted.
del pipe
del vae
gc.collect()
torch.cuda.empty_cache()

# Report what PyTorch still has allocated and reserved after the cleanup.
print(f"当前显存使用: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"显存缓存: {torch.cuda.memory_reserved() / 1e9:.2f} GB")