In [None]:
# 扩散模型MNIST图片生成实验

# 1. 导入库
import sys
sys.path.append('..')

import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display, HTML

# 2. 加载配置和模型
from config import Config
from data_loader import MNISTDataLoader
from diffusion_process import DiffusionProcess
from unet_model import ConditionalUNet
from generator import ImageGenerator

# 3. Setup: configuration and reproducibility.
def set_seed(seed: int) -> None:
    """Seed the random number generators used by this experiment.

    Args:
        seed: Seed applied to Python's ``random`` module, NumPy's global RNG
            and PyTorch (CPU plus every CUDA device).
    """
    import random  # function-scope import keeps this cell self-contained

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Silently ignored when CUDA is unavailable, so this is safe on CPU-only machines.
    torch.cuda.manual_seed_all(seed)


config = Config()
set_seed(42)

# 4. Visualize a sample of the training data.
dataloader = MNISTDataLoader(config)
train_loader, _ = dataloader.get_dataloaders()
dataloader.visualize_dataset(num_samples=25)

# 5. Visualize the diffusion process on the first 8 images of one training batch.
diffusion = DiffusionProcess(config)
x, _ = next(iter(train_loader))
x = x[:8].to(config.device)
diffusion.visualize_diffusion_process(x)

# 6. Model training (simplified).
from trainer import DiffusionTrainer

model = ConditionalUNet(config)
trainer = DiffusionTrainer(model, diffusion, config)

# Train the model. The number of epochs is not set here; it presumably comes
# from `config` inside DiffusionTrainer (TODO confirm).
trainer.train(train_loader)

# 7. Generate images.
generator = ImageGenerator(model, diffusion, config)

# Generate a 4x4 grid of the digit 5.
generator.generate_digit_grid(digit=5, num_rows=4, num_cols=4)

# Generate all digits.
generator.generate_digit_grid(digit=None)

# 8. Visualize the sampling process for the digit 5.
generator.visualize_sampling_process(digit=5)

# 9. Create an interpolation video from digit 0 to digit 9.
generator.create_interpolation_video(digit1=0, digit2=9, num_frames=30)