# In [None]:
import os
import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim import Optimizer
from tqdm import tqdm
import matplotlib.pyplot as plt
from typing import Optional, Callable
from src.gaussian_diffusion import GaussianDiffusion

class DiffusionTrainingSystem:
    """Training harness for a diffusion model.

    Bundles data loading, the optimisation step, cosine learning-rate
    scheduling, checkpointing and loss-curve plotting around a model and a
    diffusion-process object.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        diffusion: "GaussianDiffusion",
        dataset: Dataset,
        device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        batch_size: int = 64,
        optimizer_class: type = torch.optim.AdamW,
        lr: float = 3e-4,
        grad_clip: float = 1.0,
        save_dir: str = "saved_models",
        save_interval: int = 50,
        data_preprocess_fn: Optional[Callable] = None,
        scheduler_t_max: int = 1000,
    ):
        """
        Set up the model, data loader, optimizer, scheduler and save directory.

        Args:
            model: Network to train (moved to ``device``).
            diffusion: Diffusion-process object. It must provide ``to(device)``,
                a ``num_steps`` attribute and
                ``training_losses(model=, x_start=, t=, batch_labels=)``
                returning a dict of tensors that contains the key ``'loss'``.
            dataset: Iterable of samples exposing ``.x`` (features) and ``.y``
                (labels) tensor attributes.
            device: Device used for the model and the preloaded data.
            batch_size: Mini-batch size of the training data loader.
            optimizer_class: Optimizer class, called as
                ``optimizer_class(params, lr=lr)``.
            lr: Initial learning rate.
            grad_clip: Maximum gradient norm used for clipping.
            save_dir: Directory for checkpoints and the progress plot
                (created if missing).
            save_interval: Write a periodic checkpoint and plot every this
                many epochs.
            data_preprocess_fn: Optional ``fn(features, labels)`` returning the
                transformed ``(features, labels)`` pair.
            scheduler_t_max: ``T_max`` of the cosine annealing schedule.
        """
        self.device = device
        self.model = model.to(device)
        self.diffusion = diffusion.to(device)
        self.batch_size = batch_size
        self.grad_clip = grad_clip
        self.save_dir = save_dir
        self.save_interval = save_interval

        # Data preprocessing pipeline (needs self.device and self.batch_size).
        self.processed_data = self._prepare_data(dataset, data_preprocess_fn)

        # Optimizer and learning-rate schedule.
        self.optimizer = optimizer_class(self.model.parameters(), lr=lr)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=scheduler_t_max
        )

        # Training state.
        self.epoch = 0
        self.best_loss = float('inf')
        self.loss_history = {'train': []}

        # Make sure the output directory exists.
        os.makedirs(save_dir, exist_ok=True)

    def _prepare_data(self, dataset: Dataset, preprocess_fn: Optional[Callable]) -> DataLoader:
        """Collect ``.x``/``.y`` from every sample and wrap them in a shuffled DataLoader."""
        # Gather features and labels from the raw samples.
        features = [sample.x for sample in dataset]
        labels = [sample.y for sample in dataset]

        # Stack into tensors; features get an extra dimension at index 1
        # (a channel axis, e.g. [N, 1, H, W] for 2-D samples).
        features = torch.stack(features).unsqueeze(1)
        labels = torch.stack(labels)

        # Optional user-supplied preprocessing.
        if preprocess_fn is not None:
            features, labels = preprocess_fn(features, labels)

        # The whole dataset is preloaded onto the target device, so page-locked
        # host memory is pointless here (and pinning fails for tensors that
        # already live on a CUDA device) -> pin_memory stays off.
        tensor_dataset = torch.utils.data.TensorDataset(
            features.to(self.device), labels.to(self.device)
        )
        return DataLoader(
            tensor_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=False,
        )

    def _training_step(self, batch: tuple) -> dict:
        """Run one optimisation step on ``batch`` and return the mean of every loss term as floats."""
        x_batch, label_batch = batch
        B = x_batch.size(0)

        # Sample one diffusion timestep per example.
        t = torch.randint(0, self.diffusion.num_steps, (B,), device=self.device).long()

        # Diffusion loss terms.
        losses = self.diffusion.training_losses(
            model=self.model,
            x_start=x_batch,
            t=t,
            batch_labels=label_batch
        )

        # Backward pass with gradient-norm clipping.
        self.optimizer.zero_grad()
        losses['loss'].mean().backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
        self.optimizer.step()

        return {k: v.mean().item() for k, v in losses.items()}

    def _save_checkpoint(self, is_best: bool = False):
        """Write the training state to ``save_dir``.

        The file is ``best_model.pth`` when ``is_best`` is true, otherwise
        ``checkpoint_epoch_<epoch>.pth``. ``'epoch'`` holds the value of
        ``self.epoch`` at save time; inside ``train`` that is the zero-based
        index of the epoch that has just finished.
        """
        train_history = self.loss_history['train']
        state = {
            'epoch': self.epoch,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            # Saved so that a resumed run continues the LR schedule and the
            # loss curve instead of restarting them.
            'scheduler': self.scheduler.state_dict(),
            'loss_history': self.loss_history,
            # NaN when no epoch has been recorded yet (avoids an IndexError).
            'loss': train_history[-1] if train_history else float('nan'),
            'best_loss': self.best_loss
        }

        filename = 'best_model.pth' if is_best else f'checkpoint_epoch_{self.epoch}.pth'
        torch.save(state, os.path.join(self.save_dir, filename))

    def _visualize_progress(self):
        """Plot the recorded loss curves and save them as ``training_progress.png`` in ``save_dir``."""
        plt.figure(figsize=(12, 4))

        # Left panel: training loss per epoch.
        plt.subplot(121)
        plt.plot(self.loss_history['train'], label='Training Loss')
        plt.title("Loss Curve")
        plt.xlabel("Epoch")
        plt.legend()

        # Right panel: validation loss, empty unless a 'val' history exists.
        plt.subplot(122)
        plt.plot(self.loss_history.get('val', []), label='Validation Loss', color='orange')
        plt.title("Validation Metrics")
        plt.legend()

        plt.tight_layout()
        plt.savefig(os.path.join(self.save_dir, 'training_progress.png'))
        plt.close()

    def train(self, num_epochs: int, enable_progress_bar: bool = True):
        """Train for ``num_epochs`` epochs.

        After every epoch the mean batch loss is appended to
        ``loss_history['train']``, the scheduler is stepped, a new best model
        is saved when the loss improved, and every ``save_interval`` epochs a
        periodic checkpoint plus the progress plot are written.

        Args:
            num_epochs: Number of additional epochs to run.
            enable_progress_bar: Wrap the data loader in a ``tqdm`` bar.
        """
        for _ in range(num_epochs):
            self.model.train()
            epoch_loss = 0.0
            progress_bar = tqdm(self.processed_data, desc=f"Epoch {self.epoch+1}") if enable_progress_bar else self.processed_data

            for batch in progress_bar:
                step_metrics = self._training_step(batch)
                epoch_loss += step_metrics['loss']

                if enable_progress_bar:
                    progress_bar.set_postfix({k: f"{v:.4f}" for k, v in step_metrics.items()})

            # Record the epoch average and advance the LR schedule.
            avg_loss = epoch_loss / len(self.processed_data)
            self.loss_history['train'].append(avg_loss)
            self.scheduler.step()

            # Checkpointing.
            if avg_loss < self.best_loss:
                self.best_loss = avg_loss
                self._save_checkpoint(is_best=True)

            if (self.epoch + 1) % self.save_interval == 0:
                self._save_checkpoint()
                self._visualize_progress()

            self.epoch += 1

    def load_checkpoint(self, checkpoint_path: str):
        """Restore model, optimizer, scheduler and bookkeeping from a checkpoint.

        Checkpoints written before the scheduler and loss history were stored
        are still accepted; those two pieces are then left untouched.

        Args:
            checkpoint_path: Path of a file written by ``_save_checkpoint``.
        """
        state = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(state['model'])
        self.optimizer.load_state_dict(state['optimizer'])
        if 'scheduler' in state:
            self.scheduler.load_state_dict(state['scheduler'])
        if 'loss_history' in state:
            self.loss_history = state['loss_history']
        # state['epoch'] is the index of the epoch that had just finished when
        # the checkpoint was written, so training resumes with the next one
        # instead of repeating (and overwriting) the same epoch number.
        self.epoch = state['epoch'] + 1
        self.best_loss = state['best_loss']
        print(f"Loaded checkpoint from epoch {state['epoch']} with loss {state['loss']:.4f}")
