

Below is a **complete plan for adding label conditioning to a diffusion model**, with the key implementation steps and code examples:

---

### **1. Conditional Embedding Architecture**
```mermaid
graph TD
    T[Timestep t] --> TE[Time embedding]
    Y[Label y] --> LE[Label embedding]
    TE --> FC[Fusion layer]
    LE --> FC
    FC --> CB[Condition vector]
    CB --> UNet[All UNet layers]
```

---

### **2. Core Code Changes**

#### **(1) Conditional Embedding Layer**
```python
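import math

import torch
import torch.nn as nn


def sinusoidal_embedding(t, dim):
    """Sinusoidal timestep embedding, t: [B] -> [B, dim] (dim assumed even).
    Not defined in the original answer; this is the standard Transformer/DDPM-style version."""
    half = dim // 2
    freqs = torch.exp(
        -math.log(10000.0) * torch.arange(half, device=t.device, dtype=torch.float32) / half
    )
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=1)


class ConditionalEmbedding(nn.Module):
    def __init__(self, num_classes, time_dim=256, label_dim=128):
        super().__init__()
        self.time_embed = nn.Sequential(
            nn.Linear(time_dim, time_dim*4),
            nn.SiLU(),
            nn.Linear(time_dim*4, time_dim)
        )
        self.label_embed = nn.Embedding(num_classes, label_dim)
        self.fusion = nn.Sequential(
            nn.Linear(time_dim + label_dim, time_dim*2),
            nn.SiLU(),
            nn.Linear(time_dim*2, time_dim)
        )

    def forward(self, t, y):
        # t: [B] integer timesteps
        # y: [B] integer class labels
        t_emb = sinusoidal_embedding(t, self.time_embed[0].in_features)
        t_emb = self.time_embed(t_emb)  # [B, time_dim]

        l_emb = self.label_embed(y)     # [B, label_dim]

        # Fuse time and label information
        combined = torch.cat([t_emb, l_emb], dim=1)
        return self.fusion(combined)    # [B, time_dim]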
```

#### **(2) Modifying the UNet Input**
```python
class ConditionalUNet(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.cond_embed = ConditionalEmbedding(num_classes)

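        # Keep the original UNet structure, but every MyBlock must receive the condition vector
        self.down = nn.ModuleList([
            MyBlock(..., cond_dim=256)  # new argument: condition dimension (= time_dim of cond_embed)
        ])

    def forward(self, x, t, y):
        cond = self.cond_embed(t, y)

        # Pass cond to every MyBlock
        for block in self.down:
            x = block(x, cond)
        # ... remaining processing ...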
```

#### **(3) Adjusting the Residual Block**
```python
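import torch.nn as nn


class MyBlock(nn.Module):
    def __init__(self, ..., cond_dim):
        super().__init__()
        # Inject the condition after the convolution.
        # AdaGN projects cond to (scale, shift) itself, so no separate cond_proj is needed.
        # AdaGN must not sit inside nn.Sequential: Sequential forwards a single argument,
        # so AdaGN would never receive cond. Call it explicitly in forward() instead.
        self.conv = nn.Conv2d(...)
        self.norm = AdaGN(out_ch, cond_dim)  # adaptive group normalization

    def forward(self, x, cond):
        return self.norm(self.conv(x), cond)


class AdaGN(nn.Module):
    """Adaptive group normalization: GroupNorm with a condition-dependent scale and shift."""
    def __init__(self, channels, cond_dim):
        super().__init__()
        self.norm = nn.GroupNorm(8, channels)  # channels must be divisible by 8
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_dim, channels*2)
        )

    def forward(self, x, cond):
        scale, shift = self.mlp(cond).chunk(2, dim=1)  # each [B, channels]
        x = self.norm(x)
        return x * (1 + scale[:,:,None,None]) + shift[:,:,None,None]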
```
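A common refinement (used, for example, as "adaLN-Zero" in the DiT architecture) is to zero-initialize the last linear layer of the AdaGN MLP, so that each block starts out as a plain GroupNorm and learns to use the condition gradually. The sketch below repeats the `AdaGN` class so it runs on its own; the `zero_init` flag is an illustrative addition, not part of the design above.

```python
import torch
import torch.nn as nn


class AdaGN(nn.Module):
    """Adaptive GroupNorm: GroupNorm followed by a condition-dependent scale and shift."""
    def __init__(self, channels, cond_dim, zero_init=True):
        super().__init__()
        self.norm = nn.GroupNorm(8, channels)
        self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, channels * 2))
        if zero_init:
            # Hypothetical option: scale = shift = 0 at initialization,
            # so the block behaves exactly like plain GroupNorm at the start of training
            nn.init.zeros_(self.mlp[-1].weight)
            nn.init.zeros_(self.mlp[-1].bias)

    def forward(self, x, cond):
        scale, shift = self.mlp(cond).chunk(2, dim=1)
        x = self.norm(x)
        return x * (1 + scale[:, :, None, None]) + shift[:, :, None, None]


x = torch.randn(4, 64, 24, 50)   # [B, C, H, W] feature map
cond = torch.randn(4, 256)       # [B, cond_dim] fused condition vector
adagn = AdaGN(64, 256)
out = adagn(x, cond)
print(out.shape)                           # torch.Size([4, 64, 24, 50])
print(torch.allclose(out, adagn.norm(x)))  # True: identical to GroupNorm at init
```

The trade-off is that the condition has no effect at step 0; gradients still reach the zeroed layer, so the modulation grows during training.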

---

### **3. Adapting the Training Loop**

#### **(1) Data Loader**
```python
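from torch.utils.data import DataLoader

# Assumes the dataset returns (data, label) pairs; YourDataset is a placeholder for your own Dataset
dataset = YourDataset(data_tensor, labels)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)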
```
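If the data already sits in tensors, PyTorch's built-in `TensorDataset` returns exactly the `(data, label)` pairs the training loop expects. A minimal sketch; the shapes (`[N, 24, 50]` samples, 3 classes) and the random data are illustrative assumptions:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-ins for the real data: N samples of shape [24, 50] and integer class labels
N, num_classes = 256, 3
data_tensor = torch.randn(N, 24, 50)
labels = torch.randint(0, num_classes, (N,))  # int64, as nn.Embedding expects

dataset = TensorDataset(data_tensor, labels)  # dataset[i] -> (data_tensor[i], labels[i])
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

x0, y = next(iter(dataloader))
print(x0.shape, y.shape, y.dtype)  # torch.Size([64, 24, 50]) torch.Size([64]) torch.int64
```

Keep the labels as integer (`long`) tensors: `nn.Embedding` rejects float indices.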

#### **(2) Training Loop Changes**
```python
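import torch
import torch.nn.functional as F

# model, optimizer, dataloader, device, num_steps and q_sample are assumed to be defined
for batch in dataloader:
    x0, y = batch  # x0: [B,24,50], y: [B]
    x0 = x0.to(device)
    y = y.to(device)

    # Forward (noising) process
    t = torch.randint(0, num_steps, (x0.size(0),), device=device)
    noise = torch.randn_like(x0)
    xt = q_sample(x0, t, noise)

    # Conditional noise prediction
    pred_noise = model(xt, t, y)  # pass the label y

    # Loss and parameter update
    loss = F.mse_loss(pred_noise, noise)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()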
```
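The loop relies on `num_steps` and `q_sample`, which are not defined here. A minimal sketch, assuming the linear β schedule from the original DDPM paper (1e-4 to 0.02 over 1000 steps); `make_linear_schedule` is a hypothetical helper name:

```python
import torch


def make_linear_schedule(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Linear beta schedule and the cumulative products used by q_sample."""
    betas = torch.linspace(beta_start, beta_end, num_steps)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    return betas, alphas_cumprod


num_steps = 1000
betas, alphas_cumprod = make_linear_schedule(num_steps)


def q_sample(x0, t, noise):
    """Forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise."""
    abar = alphas_cumprod.to(x0.device)[t]         # [B]
    abar = abar.view(-1, *([1] * (x0.dim() - 1)))  # [B, 1, 1], broadcasts over [B, 24, 50]
    return abar.sqrt() * x0 + (1.0 - abar).sqrt() * noise


x0 = torch.randn(8, 24, 50)
t = torch.randint(0, num_steps, (8,))
noise = torch.randn_like(x0)
xt = q_sample(x0, t, noise)
print(xt.shape)  # torch.Size([8, 24, 50])
```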

---

### **4. Conditional Sampling**
```python
import torch

@torch.no_grad()
def p_sample_loop(model, shape, labels, device):
    """Label-conditioned sampling (num_steps and p_sample are assumed to be defined)."""
    img = torch.randn(shape, device=device)
    for i in reversed(range(num_steps)):
        t = torch.full((shape[0],), i, device=device, dtype=torch.long)
        img = p_sample(model, img, t, labels)  # pass the labels
    return img

def generate_class_samples(model, num_samples, label, device):
    labels = torch.full((num_samples,), label, device=device, dtype=torch.long)
    samples = p_sample_loop(model, (num_samples,24,50), labels, device)
    return samples.cpu().numpy()
```
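`p_sample` is not defined either. One standard choice is the DDPM ancestral sampling step with variance σ_t² = β_t; the sketch below assumes a linear β schedule (1e-4 to 0.02, 1000 steps) and a model with the `model(x, t, y)` signature. `DummyModel` is a hypothetical zero-output placeholder that only makes the example runnable.

```python
import torch
import torch.nn as nn

num_steps = 1000
betas = torch.linspace(1e-4, 0.02, num_steps)  # assumed linear schedule
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)


@torch.no_grad()
def p_sample(model, x, t, labels):
    """One DDPM reverse step x_t -> x_{t-1}, conditioned on labels (sigma_t^2 = beta_t)."""
    i = int(t[0])                     # all entries of t share the same step
    eps = model(x, t, labels)         # predicted noise
    coef = betas[i] / (1.0 - alphas_cumprod[i]).sqrt()
    mean = (x - coef * eps) / alphas[i].sqrt()
    if i == 0:
        return mean                   # no noise is added on the final step
    return mean + betas[i].sqrt() * torch.randn_like(x)


class DummyModel(nn.Module):
    """Placeholder noise predictor with the model(x, t, y) signature; always predicts zero."""
    def forward(self, x, t, y):
        return torch.zeros_like(x)


model = DummyModel()
labels = torch.full((4,), 2, dtype=torch.long)
x = torch.randn(4, 24, 50)
for i in reversed(range(num_steps)):
    t = torch.full((4,), i, dtype=torch.long)
    x = p_sample(model, x, t, labels)
print(x.shape)  # torch.Size([4, 24, 50])
```

The labels stay fixed across all steps; only `t` changes, which is why `p_sample_loop` builds `labels` once outside the loop.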

---

### **5. Comparison of Conditioning Strategies**

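| Method | Pros | Cons | Best suited for |
|------|------|------|---------|
| **Embedding concatenation** | Simple to implement | The condition signal can get diluted | Low-dimensional conditions |
| **Adaptive normalization** | Fine-grained control | Slightly more compute | High-quality generation |
| **Cross-attention** | Explicit alignment | Requires designing attention layers | Complex conditions such as text |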
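For the cross-attention row, a minimal sketch using PyTorch's built-in `nn.MultiheadAttention`: the feature map's spatial positions act as queries, and a sequence of condition tokens (for example, text-encoder outputs) supplies keys and values. The `CrossAttention2d` name, the pre-norm and the residual wiring are illustrative assumptions, not part of the design above.

```python
import torch
import torch.nn as nn


class CrossAttention2d(nn.Module):
    """Feature-map positions (queries) attend to condition tokens (keys/values)."""
    def __init__(self, channels, cond_dim, num_heads=4):
        super().__init__()
        self.norm = nn.GroupNorm(8, channels)
        self.attn = nn.MultiheadAttention(
            embed_dim=channels, num_heads=num_heads,
            kdim=cond_dim, vdim=cond_dim, batch_first=True,
        )

    def forward(self, x, cond_tokens):
        # x: [B, C, H, W]; cond_tokens: [B, L, cond_dim]
        b, c, h, w = x.shape
        q = self.norm(x).flatten(2).transpose(1, 2)      # [B, H*W, C]
        out, _ = self.attn(q, cond_tokens, cond_tokens)  # [B, H*W, C]
        out = out.transpose(1, 2).reshape(b, c, h, w)
        return x + out                                    # residual connection


layer = CrossAttention2d(channels=64, cond_dim=128)
x = torch.randn(2, 64, 12, 50)
cond_tokens = torch.randn(2, 7, 128)  # e.g. 7 condition tokens per sample
y = layer(x, cond_tokens)
print(y.shape)  # torch.Size([2, 64, 12, 50])
```

For a single class label, `cond_tokens` would be one token (`L = 1`), in which case cross-attention adds little over adaptive normalization; it pays off when the condition is a sequence.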

---

### **6. Validating the Effect**

#### **(1) Class Control Test**
```python
# Generate samples for each class
class_labels = [0, 1, 2]
for label in class_labels:
    samples = generate_class_samples(model, 10, label, device)
    plot_features(samples)  # plot_features: your own visualization of the feature distribution
```

#### **(2) Condition Consistency Evaluation**
```python
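import numpy as np
from sklearn.metrics import accuracy_score

# Use a pretrained classifier to check that generated samples match their conditioning label
# (load_pretrained_classifier is a placeholder for your own classifier)
classifier = load_pretrained_classifier()
samples = np.concatenate([generate_class_samples(model, 10, label, device) for label in class_labels])
true_labels = np.repeat(class_labels, 10)  # the label each sample was conditioned on
gen_labels = classifier.predict(samples)
acc = accuracy_score(true_labels, gen_labels)
print(f"Condition-consistency accuracy: {acc:.2%}")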
```
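Overall accuracy can hide a class that the model never generates correctly. A per-class breakdown from scikit-learn's `confusion_matrix` makes this visible; the two label arrays below are made-up placeholders standing in for `true_labels` and `gen_labels`.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Placeholder labels: 4 samples generated per class for classes 0, 1, 2
true_labels = np.array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
gen_labels = np.array([0, 0, 0, 1, 1, 1, 1, 1, 0, 2, 0, 0])

cm = confusion_matrix(true_labels, gen_labels, labels=[0, 1, 2])
per_class_acc = cm.diagonal() / cm.sum(axis=1)  # row i: samples conditioned on class i
for cls, acc in enumerate(per_class_acc):
    print(f"class {cls}: {acc:.0%}")
# class 0: 75%
# class 1: 100%
# class 2: 25%
```

Here the overall accuracy is 8/12 ≈ 67%, but class 2 is mostly generated as class 0, which points at weak conditioning for that class specifically.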

---

### **Troubleshooting**

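**Problem 1: The condition has no effect**
- Check that the label embedding actually receives gradient updates
- Verify that the norm of the condition vector is comparable to that of the time embedding
- Try increasing the label embedding dimension (e.g. 128 → 256)

**Problem 2: Sample diversity drops**
- Add Dropout after the condition embedding (keep probability 0.9, i.e. `nn.Dropout(p=0.1)` in PyTorch)
- Use label smoothing
- Add random noise to the condition vector

**Problem 3: Unstable training**
- L2-normalize the label embedding
- Use gradient clipping (max_norm=1.0)
- Lower the learning rate of the condition fusion layer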
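Several of these checks are easy to script. The sketch below uses a toy conditioning module (`ToyCond` and its sizes are hypothetical, not the classes above) to check that the label embedding receives gradients, compare the time and label embedding norms, apply `nn.Dropout(p=0.1)` to the fused condition, and clip gradients with `max_norm=1.0`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class ToyCond(nn.Module):
    """Toy stand-in for a conditional embedding: time features + label embedding -> fused vector."""
    def __init__(self, num_classes=3, time_dim=32, label_dim=16):
        super().__init__()
        self.time_embed = nn.Linear(time_dim, time_dim)
        self.label_embed = nn.Embedding(num_classes, label_dim)
        self.fusion = nn.Linear(time_dim + label_dim, time_dim)
        self.dropout = nn.Dropout(p=0.1)  # PyTorch's p is the drop probability: keep probability 0.9

    def forward(self, t_feat, y):
        t_emb = self.time_embed(t_feat)
        l_emb = self.label_embed(y)
        cond = self.dropout(self.fusion(torch.cat([t_emb, l_emb], dim=1)))
        return cond, t_emb, l_emb


model = ToyCond()
t_feat = torch.randn(8, 32)       # stand-in for sinusoidal time features
y = torch.randint(0, 3, (8,))
cond, t_emb, l_emb = model(t_feat, y)
cond.pow(2).mean().backward()     # any scalar loss works for the checks below

# 1) Does the label embedding receive gradients?
label_grad = model.label_embed.weight.grad
print("label-embedding grad norm:", label_grad.norm().item())

# 2) Are the time and label embeddings on a comparable scale?
print("mean time-emb norm: ", t_emb.norm(dim=1).mean().item())
print("mean label-emb norm:", l_emb.norm(dim=1).mean().item())

# 3) Clip the total gradient norm before optimizer.step()
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
norm_after = torch.sqrt(sum(p.grad.norm() ** 2 for p in model.parameters() if p.grad is not None))
print("grad norm before/after clipping:", norm_before.item(), norm_after.item())
```

`clip_grad_norm_` returns the norm measured before clipping, so logging it over training is also a cheap way to spot instability.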

---

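If you need a detailed implementation of a specific module (such as adaptive normalization), or want to know how to balance conditional control against sample diversity, tell me which direction you'd like to explore!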



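Below are the **complete code changes for turning the existing U-Net into a conditional U-Net**, with the key modifications and explanatory comments: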

---

### **1. The Conditional U-Net Class**
```python
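import torch
import torch.nn as nn
import torch.nn.functional as F

# MySequential, Residual, PreNorm, LinearAttention and Attention come from the existing U-Net code.
# MySequential must pass cond only to MyBlock layers (the attention wrappers take a single input).
class ConditionalDiffusionUNet(nn.Module):
    def __init__(self, num_classes, time_dim=128, label_dim=64):
        super().__init__()
        chs = [1, 64, 128, 256]

        # Replace the time embedding with the conditional (time + label) embedding
        self.cond_embed = ConditionalEmbedding(
            num_classes=num_classes,
            time_dim=time_dim,
            label_dim=label_dim
        )
        cond_dim = time_dim  # dimension of the fused condition vector

        # Downsampling path (every MyBlock receives cond_dim)
        self.down = nn.ModuleList([
            MySequential(
                MyBlock(chs[i], chs[i+1], cond_dim=cond_dim),
                MyBlock(chs[i+1], chs[i+1], cond_dim=cond_dim),
                Residual(PreNorm(chs[i+1], LinearAttention(chs[i+1])))
            ) for i in range(len(chs)-1)
        ])

        # Bottleneck
        self.mid = MySequential(
            MyBlock(chs[-1], chs[-1], cond_dim=cond_dim),
            Residual(PreNorm(chs[-1], Attention(chs[-1]))),
            MyBlock(chs[-1], chs[-1], cond_dim=cond_dim)
        )

        # Upsampling path. The last level outputs chs[1] channels instead of chs[0] = 1,
        # because GroupNorm(8, 1) and a 1-channel attention layer would fail.
        self.up = nn.ModuleList([
            MySequential(
                MyBlock(chs[i+1]*2, chs[max(i, 1)], cond_dim=cond_dim),
                MyBlock(chs[max(i, 1)], chs[max(i, 1)], cond_dim=cond_dim),
                Residual(PreNorm(chs[max(i, 1)], LinearAttention(chs[max(i, 1)])))
            ) for i in reversed(range(len(chs)-1))
        ])

        self.final = nn.Conv2d(chs[1], 1, 1)  # 1x1 conv back to a single channel

    def forward(self, x, t, y):
        """Takes the label y as an additional input."""
        x = x.unsqueeze(1)  # [B,24,50] -> [B,1,24,50]
        cond_emb = self.cond_embed(t, y)  # fused time + label condition vector
        skips = []

        # Encoder (pass cond_emb to every level)
        for block in self.down:
            x = block(x, cond_emb)
            skips.append(x)
            x = F.max_pool2d(x, kernel_size=(2,1))  # halve the height, keep the width

        # Bottleneck
        x = self.mid(x, cond_emb)

        # Decoder
        for i, block in enumerate(self.up):
            x = F.interpolate(x, scale_factor=(2,1), mode='nearest')
            x = torch.cat([x, skips[-(i+1)]], dim=1)
            x = block(x, cond_emb)

        return self.final(x).squeeze(1)  # [B,24,50]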
```
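`ConditionalDiffusionUNet` relies on `MySequential` from the existing code, which is not shown. It has to route `cond` only to the blocks that take it, because the attention wrappers accept a single input. A minimal sketch, assuming a hypothetical convention in which conditional blocks set a class attribute `takes_cond = True`; `ToyCondBlock` is a stand-in for `MyBlock`:

```python
import torch
import torch.nn as nn


class MySequential(nn.Sequential):
    """Sequential container that passes cond only to modules flagged with takes_cond = True."""
    def forward(self, x, cond):
        for module in self:
            if getattr(module, "takes_cond", False):
                x = module(x, cond)
            else:
                x = module(x)
        return x


class ToyCondBlock(nn.Module):
    """Stand-in for MyBlock: adds a per-channel bias projected from cond."""
    takes_cond = True

    def __init__(self, channels, cond_dim):
        super().__init__()
        self.proj = nn.Linear(cond_dim, channels)

    def forward(self, x, cond):
        return x + self.proj(cond)[:, :, None, None]


seq = MySequential(ToyCondBlock(8, 16), nn.GELU(), ToyCondBlock(8, 16))
x = torch.randn(2, 8, 6, 50)
cond = torch.randn(2, 16)
print(seq(x, cond).shape)  # torch.Size([2, 8, 6, 50])
```

With this convention, `MyBlock` only needs `takes_cond = True` added; an `isinstance(module, MyBlock)` check works equally well if the classes live in the same module.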

---

### **2. Modifying MyBlock to Accept the Condition**
```python
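import torch.nn as nn


class MyBlock(nn.Module):
    def __init__(self, in_ch, out_ch, cond_dim):
        super().__init__()
        # Condition projection: cond -> per-channel (scale, shift).
        # Note: GELU on the output bounds scale and shift below by about -0.17;
        # the more common ordering is activation -> Linear.
        self.cond_proj = nn.Sequential(
            nn.Linear(cond_dim, out_ch * 2),
            nn.GELU()
        )

        self.ds_conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)

        self.conv = nn.Sequential(
            nn.GroupNorm(8, out_ch),  # out_ch must be divisible by 8
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(8, out_ch),
            nn.Conv2d(out_ch, out_ch, 3, padding=1)
        )

        self.res_conv = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, cond):
        """Takes the additional cond argument."""
        h = self.ds_conv(x)

        # Condition injection: FiLM-style per-channel scale and shift
        scale, shift = self.cond_proj(cond).chunk(2, dim=1)
        h = h * (1 + scale[:, :, None, None]) + shift[:, :, None, None]

        h = self.conv(h)
        return h + self.res_conv(x)  # residual connection (1x1 conv when channels change)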
```

---

### **3. Matching Changes to the Training Loop**
```python
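import torch
import torch.nn.functional as F

# Modified training loop (model, optimizer, device, num_steps and q_sample assumed defined)
for batch in dataloader:
    x0, labels = batch  # assumes the dataset returns (data, label)
    x0 = x0.to(device)
    labels = labels.to(device)

    # Forward (noising) process
    t = torch.randint(0, num_steps, (x0.size(0),), device=device)
    noise = torch.randn_like(x0)
    xt = q_sample(x0, t, noise)

    # Conditional noise prediction
    pred_noise = model(xt, t, labels)  # pass the labels

    # Loss and parameter update
    loss = F.mse_loss(pred_noise, noise)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()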
```

---

### **4. Conditional Sampling Function**
```python
import torch

@torch.no_grad()
def p_sample_loop(model, shape, labels, device):
    """Label-conditioned sampling (num_steps and p_sample are assumed to be defined)."""
    img = torch.randn(shape, device=device)
    for i in reversed(range(num_steps)):
        t = torch.full((shape[0],), i, device=device, dtype=torch.long)
        img = p_sample(model, img, t, labels)
    return img

def generate_class_samples(model, num_samples, label, device):
    labels = torch.full((num_samples,), label, device=device, dtype=torch.long)
    samples = p_sample_loop(model, (num_samples,24,50), labels, device)
    return samples.cpu().numpy()
```

---

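### **Summary of Key Changes**
| Module | Change | Purpose |
|------|----------|------|
| **U-Net `__init__`** | Add the `num_classes` argument | Enables label input |
| **Condition embedding** | Replace `time_embed` with `cond_embed` | Fuses time and label information |
| **MyBlock** | Add `cond_proj` and a condition-dependent scale/shift | Conditional feature modulation |
| **Forward pass** | Pass `cond_emb` to every block | The condition reaches every level of the network |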

---

### **Smoke Test**
```python
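import torch

def test_conditional_unet():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    num_classes = 10
    model = ConditionalDiffusionUNet(num_classes=num_classes).to(device)

    # Test inputs
    x = torch.randn(4, 24, 50, device=device)
    t = torch.randint(0, 1000, (4,), device=device)
    labels = torch.randint(0, num_classes, (4,), device=device)

    # Forward pass
    output = model(x, t, labels)
    assert output.shape == (4, 24, 50), "unexpected output shape"

    # Gradient check: the label embedding must receive gradients
    loss = output.mean()
    loss.backward()
    assert model.cond_embed.label_embed.weight.grad is not None, "label embedding got no gradient"
    print("Gradient test passed")

test_conditional_unet()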
```

---

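### **Possible Issues and Fixes**
1. **Dimension mismatch errors**
   - Check that the output dimension of `cond_embed` matches the `cond_dim` of `MyBlock`
   - Make sure the input dimension of every `cond_proj` layer matches

2. **Label mismatch (label leakage)**
   - Verify that the correct labels are passed in at sampling time
   - Check that training labels are correctly aligned with their data

3. **Insufficient conditional control**
   - Increase the label embedding dimension (e.g. `label_dim` 64 → 128)
   - Add a nonlinear activation after the condition projection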
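Many dimension-mismatch errors in this U-Net come from two constraints: the height must survive three `(2, 1)` max-pools and be restored exactly by upsampling, so it must be divisible by 2³ = 8, and every channel count passed to `GroupNorm(8, ·)` must be divisible by 8. A small pure-Python pre-flight check can catch both before the model is built; `check_unet_config` is a hypothetical helper, not part of the code above.

```python
def check_unet_config(height, chs=(1, 64, 128, 256), groups=8):
    """Return the feature-map height at each encoder level; raise if the config cannot work."""
    num_levels = len(chs) - 1
    if height % (2 ** num_levels) != 0:
        raise ValueError(
            f"height {height} is not divisible by {2 ** num_levels}; "
            "the upsampled map will not match the skip connection"
        )
    for c in chs[1:]:
        if c % groups != 0:
            raise ValueError(f"{c} channels cannot be split into {groups} GroupNorm groups")
    return [height // 2 ** i for i in range(num_levels + 1)]


print(check_unet_config(24))  # [24, 12, 6, 3]
try:
    check_unet_config(25)
except ValueError as err:
    print("error:", err)
```

For example, a height of 25 pools to 12, 6 and 3, then upsamples back to 24, so the final `torch.cat` with the 25-row skip connection fails.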

Would you like me to explain the details of the condition injection, or show how to visualize the effect of the conditional control?