# In [None]:
import torch
import torch.nn as nn
from transformer_core import CrossAttention # 假设你把上节课的代码存为了这个名字，或者直接复制过来

class ResNetBlock(nn.Module):
    """
    Basic convolutional residual block for processing image features.

    Computes ``conv2(SiLU(conv1(x))) + skip(x)``. Both convolutions are 3x3
    with padding 1, so the spatial size (H, W) is preserved.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels of the output feature map.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.act = nn.SiLU()  # activation commonly used in diffusion models
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        # The residual add requires matching channel counts. When they differ,
        # project the input with a 1x1 conv; otherwise use a parameter-free
        # identity so the state_dict for the in == out case is unchanged.
        if in_channels == out_channels:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        h = self.act(self.conv1(x))
        return self.conv2(h) + self.skip(x)  # residual connection (eases gradient flow)

class SpatialTransformer(nn.Module):
    """
    Key module: the place where text information is injected into the image.

    The feature map is group-normalized, mixed with a 1x1 conv, flattened
    into a sequence of H*W tokens, passed to ``CrossAttention`` together with
    the text ``context``, reshaped back to (B, C, H, W), mixed again with a
    1x1 conv and added to the input (residual connection).

    NOTE(review): only ``CrossAttention`` is called here; whether it also
    performs self-attention depends on its implementation in
    ``transformer_core`` -- confirm there.

    Args:
        channels: channel count C of the feature map; also the attention
            model dimension. Must be divisible by ``num_groups`` (GroupNorm
            requirement).
        context_dim: feature size of the text context tokens.
        n_head: number of attention heads passed to ``CrossAttention``.
        num_groups: number of groups used by the GroupNorm.
    """
    def __init__(self, channels, context_dim, n_head=4, num_groups=32):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups, channels)
        # 1x1 conv only mixes channels; the (B, C, H, W) -> (B, H*W, C)
        # reshaping into a token sequence happens in forward().
        self.proj_in = nn.Conv2d(channels, channels, kernel_size=1)

        # Cross-attention from image tokens (queries) to text context.
        self.transformer = CrossAttention(d_model=channels, d_context=context_dim, n_head=n_head)

        self.proj_out = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x, context):
        """
        Args:
            x: image features of shape (B, C, H, W).
            context: text features; presumably (B, T, context_dim) -- TODO
                confirm against CrossAttention's expected input.

        Returns:
            Tensor of shape (B, C, H, W): attention output plus the input.
        """
        b, c, h, w = x.shape
        x_in = x

        # 1. Image -> sequence
        x = self.norm(x)
        x = self.proj_in(x)
        x = x.flatten(2).transpose(1, 2)  # (B, H*W, C)

        # 2. Attention: image tokens query the text features in `context`
        x = self.transformer(x, context)

        # 3. Sequence -> image
        x = x.transpose(1, 2).reshape(b, c, h, w)
        return self.proj_out(x) + x_in

class SimpleUNet(nn.Module):
    """
    Minimal U-Net-shaped network showing where text conditioning enters.

    Layout: two ResNet blocks ("down" path) -> ResNet block ->
    SpatialTransformer (text injected via cross-attention) -> ResNet block ->
    two ResNet blocks ("up" path). Every stage keeps ``in_channels`` channels.
    No stage changes the spatial resolution, and there are no
    encoder-decoder skip connections, so the output has the same shape as
    the input.

    Args:
        in_channels: channel count used by every stage.
        context_dim: feature size of the text context tokens.
    """

    def __init__(self, in_channels=320, context_dim=768):
        super().__init__()
        ch = in_channels

        # Encoder path (named "down", but the resolution is not reduced here).
        self.down_blocks = nn.ModuleList([ResNetBlock(ch, ch) for _ in range(2)])

        # Bottleneck: text conditioning is injected between two ResNet blocks.
        self.mid_block1 = ResNetBlock(ch, ch)
        self.mid_attn = SpatialTransformer(ch, context_dim)
        self.mid_block2 = ResNetBlock(ch, ch)

        # Decoder path (named "up", but the resolution is not increased here).
        self.up_blocks = nn.ModuleList([ResNetBlock(ch, ch) for _ in range(2)])

    def forward(self, x, context):
        """
        Args:
            x: noisy image features, (B, in_channels, H, W).
            context: text prompt embeddings forwarded to the SpatialTransformer.

        Returns:
            Tensor with the same shape as ``x``.
        """
        feats = x
        for stage in self.down_blocks:
            feats = stage(feats)

        feats = self.mid_block1(feats)
        feats = self.mid_attn(feats, context)  # text conditioning via cross-attention
        feats = self.mid_block2(feats)

        for stage in self.up_blocks:
            feats = stage(feats)
        return feats

# --- Test code ---
if __name__ == "__main__":
    # Dummy inputs
    dummy_img = torch.randn(1, 320, 32, 32)  # Batch, Channel, H, W
    dummy_txt = torch.randn(1, 77, 768)      # Batch, Token, Dim (shaped like CLIP text output)

    model = SimpleUNet()
    with torch.no_grad():  # smoke test only; no gradients needed
        output = model(dummy_img, dummy_txt)

    print(f"Input Image: {dummy_img.shape}")
    print(f"Output Image: {output.shape}")
    # Check the claim before printing it (raise rather than assert: asserts vanish under -O).
    if output.shape != dummy_img.shape:
        raise RuntimeError(
            f"Output shape {tuple(output.shape)} does not match input shape {tuple(dummy_img.shape)}"
        )
    print("U-Net structure verified. Text injection successful.")