# In [None]:
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
from src.gaussian_diffusion import GaussianDiffusion

class DiffusionTrainer:
    """Training-loop driver for a diffusion model.

    Bundles a model, a diffusion object, a data loader and an optimizer, and
    provides per-epoch training, checkpointing and loss-curve plotting.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        diffusion: "GaussianDiffusion",
        train_loader: DataLoader,
        optimizer: Optimizer,
        scheduler: _LRScheduler = None,
        device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        save_dir: str = "./checkpoints",
        grad_clip: float = 1.0,
        save_interval: int = 50
    ):
        """
        Set up the trainer and create the checkpoint directory.

        Args:
            model: Network to train (described by the original author as a
                UNet). It is moved to ``device``.
            diffusion: Configured diffusion object. It must provide
                ``to(device)``, ``num_steps`` and ``training_losses(...)``.
            train_loader: Iterable yielding ``(x_batch, label_batch)`` pairs.
            optimizer: Optimizer that updates ``model``'s parameters.
            scheduler: Optional learning-rate scheduler, stepped once per
                epoch.
            device: Device to train on. The default is resolved once, when
                the class is defined: CUDA if available, otherwise CPU.
            save_dir: Directory for checkpoints and the loss-curve image;
                created if it does not exist.
            grad_clip: Maximum gradient norm. ``None`` disables clipping.
            save_interval: Save a periodic checkpoint every this many epochs.
                ``0`` or ``None`` disables periodic checkpoints.
        """
        self.model = model.to(device)
        # NOTE(review): relies on ``diffusion.to`` returning the moved object
        # (as ``nn.Module.to`` does) — confirm for GaussianDiffusion.
        self.diffusion = diffusion.to(device)
        self.train_loader = train_loader
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.save_dir = save_dir
        self.grad_clip = grad_clip
        self.save_interval = save_interval

        # Training-state tracking.
        self.epoch = 0
        self.best_loss = float('inf')
        self.train_losses = []

        # Make sure the output directory exists.
        os.makedirs(save_dir, exist_ok=True)

    def train_epoch(self) -> float:
        """Run one pass over ``train_loader`` and return the mean batch loss.

        Raises:
            ValueError: If the loader yields no batches, since a mean loss
                is undefined in that case.
        """
        self.model.train()
        epoch_loss = 0.0
        # Count batches ourselves rather than calling len(loader): this also
        # works for loaders without __len__ and makes the empty case explicit.
        num_batches = 0

        with tqdm(self.train_loader, desc=f"Epoch {self.epoch+1}") as pbar:
            for x_batch, label_batch in pbar:
                # Move the batch to the training device.
                x_batch = x_batch.to(self.device)
                label_batch = label_batch.to(self.device)

                # Sample one integer timestep per example, uniformly from
                # [0, num_steps).
                B = x_batch.size(0)
                t = torch.randint(
                    0, self.diffusion.num_steps, (B,),
                    device=self.device, dtype=torch.long,
                )

                # Forward pass: per-example loss terms from the diffusion object.
                losses = self.diffusion.training_losses(
                    model=self.model,
                    x_start=x_batch,
                    t=t,
                    batch_labels=label_batch
                )
                loss = losses["loss"].mean()

                # Backward pass with optional gradient-norm clipping.
                self.optimizer.zero_grad()
                loss.backward()
                if self.grad_clip is not None:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
                self.optimizer.step()

                # Read the scalar once; each .item() call copies to the host.
                loss_value = loss.item()
                epoch_loss += loss_value
                num_batches += 1

                # Only show the auxiliary terms the diffusion object actually
                # returned, so a missing 'mse'/'vb' entry does not abort training.
                postfix = {'Loss': loss_value}
                for display_name, key in (('MSE', 'mse'), ('VB', 'vb')):
                    if key in losses:
                        postfix[display_name] = losses[key].mean().item()
                pbar.set_postfix(postfix)

        if num_batches == 0:
            raise ValueError("train_loader produced no batches; cannot compute an epoch loss")

        return epoch_loss / num_batches

    def save_checkpoint(self, is_best: bool = False):
        """Write the current training state to ``save_dir``.

        The state holds the epoch counter, model/optimizer/scheduler state
        dicts, the latest epoch loss, the best loss so far and the full loss
        history. ``loss`` is ``None`` when no epoch has completed yet, and
        ``scheduler_state`` is ``None`` when no scheduler is configured.

        Args:
            is_best: If true, write ``best_model.pth``; otherwise write
                ``checkpoint_epoch_<epoch+1>.pth``.
        """
        state = {
            'epoch': self.epoch,
            'model_state': self.model.state_dict(),
            'optimizer_state': self.optimizer.state_dict(),
            # Guard against being called before any epoch has finished.
            'loss': self.train_losses[-1] if self.train_losses else None,
            'best_loss': self.best_loss,
            # Needed to resume the learning-rate schedule where it stopped.
            'scheduler_state': (
                self.scheduler.state_dict() if self.scheduler is not None else None
            ),
            'train_losses': list(self.train_losses),
        }

        if is_best:
            torch.save(state, os.path.join(self.save_dir, 'best_model.pth'))
        else:
            torch.save(state, os.path.join(self.save_dir, f'checkpoint_epoch_{self.epoch+1}.pth'))

    def plot_loss_curve(self):
        """Plot the per-epoch training loss and save it as ``loss_curve.png``."""
        plt.figure(figsize=(10, 5))
        plt.plot(self.train_losses, label='Training Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training Loss Curve')
        plt.legend()
        plt.savefig(os.path.join(self.save_dir, 'loss_curve.png'))
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close()

    def train(self, num_epochs: int):
        """Train for ``num_epochs`` epochs, checkpointing along the way.

        After each epoch the scheduler (if any) is stepped, the best model is
        saved when the mean loss improves, and every ``save_interval`` epochs
        a periodic checkpoint and the loss curve are written.

        Args:
            num_epochs: Number of additional epochs to run.
        """
        for _ in range(num_epochs):
            avg_loss = self.train_epoch()
            self.train_losses.append(avg_loss)

            # Step the scheduler before saving, so checkpoints hold the
            # post-step schedule state.
            if self.scheduler is not None:
                self.scheduler.step()

            # Keep a copy of the best model seen so far.
            if avg_loss < self.best_loss:
                self.best_loss = avg_loss
                self.save_checkpoint(is_best=True)
                print(f"New best loss: {self.best_loss:.4f}")

            # Periodic checkpoint; a falsy interval (0/None) disables it
            # instead of failing on a modulo by zero.
            if self.save_interval and (self.epoch + 1) % self.save_interval == 0:
                self.save_checkpoint()
                self.plot_loss_curve()

            self.epoch += 1
