# 运行环境
`diffusers == 0.121`

# VAE 的作用
VAE 可以从潜在空间（latent space）中生成新的样本。通过学习数据的潜在表示，VAE 可以生成与训练数据相似的新数据。例如，可以用 VAE 生成新
的图像、音频、文本等

In [1]:
import torch


# 定义残差连接层的类
class Resnet(torch.nn.Module):

    def __init__(self, dim_in, dim_out):
        super().__init__()

        # 定义一个序列化的神经网络模块
        self.s = torch.nn.Sequential(
            torch.nn.GroupNorm(num_groups=32,  # 归一化层，分成32组
                               num_channels=dim_in,  # 输入通道数
                               eps=1e-6,  # 防止除零
                               affine=True),  # 允许缩放和平移
            torch.nn.SiLU(),  # SiLU激活函数
            torch.nn.Conv2d(dim_in,  # 输入通道数
                            dim_out,  # 输出通道数
                            kernel_size=3,  # 卷积核大小
                            stride=1,  # 步幅
                            padding=1),  # 填充，使得输出大小与输入相同
            torch.nn.GroupNorm(num_groups=32,  # 归一化层，分成32组
                               num_channels=dim_out,  # 输出通道数
                               eps=1e-6,  # 防止除零
                               affine=True),  # 允许缩放和平移
            torch.nn.SiLU(),  # SiLU激活函数
            torch.nn.Conv2d(dim_out,  # 输入通道数
                            dim_out,  # 输出通道数
                            kernel_size=3,  # 卷积核大小
                            stride=1,  # 步幅
                            padding=1),  # 填充，使得输出大小与输入相同
        )

        # 如果输入和输出的通道数不一致，则使用1x1卷积调整输入的通道数
        self.res = None
        if dim_in != dim_out:
            self.res = torch.nn.Conv2d(dim_in,  # 输入通道数
                                       dim_out,  # 输出通道数
                                       kernel_size=1,  # 1x1卷积核
                                       stride=1,  # 步幅
                                       padding=0)  # 无填充

    def forward(self, x):
        # x的形状为 [batch_size, channels, height, width]，这里为 [1, 128, 10, 10]
        res = x  # 保存输入，作为残差项
        if self.res:
            # 如果需要调整通道数，使用1x1卷积调整残差的形状为 [1, 256, 10, 10]
            res = self.res(x)

        # 计算主分支的输出，并与残差相加
        # 主分支的输出形状为 [1, 256, 10, 10]
        return res + self.s(x)  # 残差相加后的输出

# 测试 Resnet 类，输入形状为 [1, 128, 10, 10]，期望输出形状为 [1, 256, 10, 10]
Resnet(128, 256)(torch.randn(1, 128, 10, 10)).shape


torch.Size([1, 256, 10, 10])

In [2]:
# VAE 自注意力层类(单头自注意力)
class Atten(torch.nn.Module):

    def __init__(self):
        super().__init__()
        # norm 层
        self.norm = torch.nn.GroupNorm(num_channels=512,
                                       num_groups=32,
                                       eps=1e-6,
                                       affine=True)

        """
        在注意力机制中，我们对每个输入元素（例如句子中的一个单词）生成三个向量：
            查询向量 (Query Vector, Q)
            键向量 (Key Vector, K)
            值向量 (Value Vector, V)
        """
        self.q = torch.nn.Linear(512, 512)
        self.k = torch.nn.Linear(512, 512)
        self.v = torch.nn.Linear(512, 512)
        self.out = torch.nn.Linear(512, 512)

    # 单头, 无 mask
    def forward(self, x):
        #x -> [1, 512, 64, 64]
        res = x

        #norm,维度不变
        #[1, 512, 64, 64]
        x = self.norm(x)

        #[1, 512, 64, 64] -> [1, 512, 4096] -> [1, 4096, 512]
        x = x.flatten(start_dim=2).transpose(1, 2)

        #线性运算,维度不变
        #[1, 4096, 512]
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)

        #[1, 4096, 512] -> [1, 512, 4096]
        k = k.transpose(1, 2)

        #[1, 4096, 512] * [1, 512, 4096] -> [1, 4096, 4096]
        #0.044194173824159216 = 1 / 512**0.5
        

        # 照理来说应该是等价的,但是却有很小的误差
        # atten = q.bmm(k) * 0.044194173824159216
        atten = torch.baddbmm(torch.empty(1, 4096, 4096, device=q.device),
                              q,
                              k,
                              beta=0,
                              alpha=0.044194173824159216)

        atten = torch.softmax(atten, dim=2)

        #[1, 4096, 4096] * [1, 4096, 512] -> [1, 4096, 512]
        atten = atten.bmm(v)

        #线性运算,维度不变
        #[1, 4096, 512]
        atten = self.out(atten)

        #[1, 4096, 512] -> [1, 512, 4096] -> [1, 512, 64, 64]
        atten = atten.transpose(1, 2).reshape(-1, 512, 64, 64)

        # 残差连接,维度不变
        # [1, 512, 64, 64]
        atten = atten + res

        return atten


Atten()(torch.randn(1, 512, 64, 64)).shape

torch.Size([1, 512, 64, 64])

In [3]:
# 工具层
class Pad(torch.nn.Module):

    def forward(self, x):
        # 给数据的最右边跟最下边增加一行跟一列的 0
        return torch.nn.functional.pad(x, (0, 1, 0, 1),
                                       mode='constant',
                                       value=0)


Pad()(torch.ones(1, 2, 5, 5))

tensor([[[[1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0.],
          [0., 0., 0., 0., 0., 0.]],

         [[1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0.],
          [0., 0., 0., 0., 0., 0.]]]])

In [4]:
# VAE 模型类
# 分为 encode 跟 decode 两个参数

import torch

# 定义变分自编码器（VAE）模型类，包含编码器和解码器
class VAE(torch.nn.Module):

    def __init__(self):
        super().__init__()

        # 定义编码器部分
        self.encoder = torch.nn.Sequential(
            # 输入层：卷积操作，将输入图像的通道数从3变为128
            torch.nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1),

            # 下采样和特征提取部分（多层 ResNet 模块和卷积下采样）
            torch.nn.Sequential(
                Resnet(128, 128),  # ResNet 残差模块，输入和输出通道均为128
                Resnet(128, 128),  # 再次通过 ResNet 模块
                torch.nn.Sequential(
                    Pad(),  # 填充操作
                    torch.nn.Conv2d(128, 128, 3, stride=2, padding=0),  # 卷积下采样，输出特征图尺寸减半
                ),
            ),
            torch.nn.Sequential(
                Resnet(128, 256),  # ResNet 模块，通道数从128增加到256
                Resnet(256, 256),  # 再次通过 ResNet 模块
                torch.nn.Sequential(
                    Pad(),
                    torch.nn.Conv2d(256, 256, 3, stride=2, padding=0),  # 再次下采样
                ),
            ),
            torch.nn.Sequential(
                Resnet(256, 512),  # ResNet 模块，通道数从256增加到512
                Resnet(512, 512),  # 再次通过 ResNet 模块
                torch.nn.Sequential(
                    Pad(),
                    torch.nn.Conv2d(512, 512, 3, stride=2, padding=0),  # 再次下采样
                ),
            ),
            torch.nn.Sequential(
                Resnet(512, 512),  # 保持通道数为512的 ResNet 模块
                Resnet(512, 512),  # 再次通过 ResNet 模块
            ),

            # 中间层：增加注意力机制
            torch.nn.Sequential(
                Resnet(512, 512),  # ResNet 模块
                Atten(),  # 注意力机制模块
                Resnet(512, 512),  # 再次通过 ResNet 模块
            ),

            # 输出层：标准化和卷积操作
            torch.nn.Sequential(
                torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6),  # 组归一化
                torch.nn.SiLU(),  # SiLU 激活函数
                torch.nn.Conv2d(512, 8, 3, padding=1),  # 将特征图通道数降为8
            ),

            # 最后一层，用于生成正态分布的均值和方差
            torch.nn.Conv2d(8, 8, 1),  # 1x1 卷积，保持通道数为8
        )

        # 定义解码器部分
        self.decoder = torch.nn.Sequential(
            # 输入层：从正态分布中采样得到的4个通道
            torch.nn.Conv2d(4, 4, 1),  # 1x1 卷积，保持通道数为4

            # 初始卷积操作，将通道数从4变为512
            torch.nn.Conv2d(4, 512, kernel_size=3, stride=1, padding=1),

            # 中间部分：ResNet 模块和注意力机制
            torch.nn.Sequential(Resnet(512, 512), Atten(), Resnet(512, 512)),

            # 上采样部分（多层 ResNet 模块和上采样操作）
            torch.nn.Sequential(
                Resnet(512, 512),  # ResNet 模块，保持通道数为512
                Resnet(512, 512),  # 再次通过 ResNet 模块
                Resnet(512, 512),  # 再次通过 ResNet 模块
                torch.nn.Upsample(scale_factor=2.0, mode='nearest'),  # 上采样，特征图尺寸加倍
                torch.nn.Conv2d(512, 512, kernel_size=3, padding=1),  # 卷积，保持通道数为512
            ),
            torch.nn.Sequential(
                Resnet(512, 512),  # ResNet 模块，保持通道数为512
                Resnet(512, 512),  # 再次通过 ResNet 模块
                Resnet(512, 512),  # 再次通过 ResNet 模块
                torch.nn.Upsample(scale_factor=2.0, mode='nearest'),  # 上采样，特征图尺寸加倍
                torch.nn.Conv2d(512, 512, kernel_size=3, padding=1),  # 卷积，保持通道数为512
            ),
            torch.nn.Sequential(
                Resnet(512, 256),  # ResNet 模块，将通道数从512减为256
                Resnet(256, 256),  # 再次通过 ResNet 模块
                Resnet(256, 256),  # 再次通过 ResNet 模块
                torch.nn.Upsample(scale_factor=2.0, mode='nearest'),  # 上采样，特征图尺寸加倍
                torch.nn.Conv2d(256, 256, kernel_size=3, padding=1),  # 卷积，保持通道数为256
            ),
            torch.nn.Sequential(
                Resnet(256, 128),  # ResNet 模块，将通道数从256减为128
                Resnet(128, 128),  # 再次通过 ResNet 模块
                Resnet(128, 128),  # 再次通过 ResNet 模块
            ),

            # 输出层：标准化和卷积操作
            torch.nn.Sequential(
                torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6),  # 组归一化
                torch.nn.SiLU(),  # SiLU 激活函数
                torch.nn.Conv2d(128, 3, 3, padding=1),  # 将通道数降为3，即恢复为RGB图像
            ),
        )

    # 采样函数，从正态分布中生成新的数据
    def sample(self, h):
        # h -> [1, 8, 64, 64]

        # 将特征图的前4个通道作为均值，后4个通道作为方差
        mean = h[:, :4]
        logvar = h[:, 4:]
        std = logvar.exp() ** 0.5  # 计算标准差

        # 从正态分布中采样，并使用重参数技巧生成新的特征图
        h = torch.randn(mean.shape, device=mean.device)
        h = mean + std * h

        # 返回符合正态分布的 h
        return h

    # 前向传播函数
    def forward(self, x):
        # 👇 1 表示一张图片
        # x -> [1, 3, 512, 512]

        # 编码：将输入图像编码为隐变量
        h = self.encoder(x)  # h -> [1, 8, 64, 64]

        # 从隐变量中采样
        h = self.sample(h)  # h -> [1, 4, 64, 64]

        # 解码：将隐变量解码为图像
        h = self.decoder(h)  # h -> [1, 3, 512, 512]

        return h  # 返回重建的图像

# 测试 VAE 模型
VAE()(torch.randn(1, 3, 512, 512)).shape


torch.Size([1, 3, 512, 512])

In [5]:
# 加载【预训练模型】来初始化参数
from diffusers import AutoencoderKL

#加载预训练模型的参数
params = AutoencoderKL.from_pretrained(
    'model/diffsion_from_scratch.params', subfolder='vae')

vae = VAE()

# 加载参数用的函数
def load_res(model, param):
    model.s[0].load_state_dict(param.norm1.state_dict())
    model.s[2].load_state_dict(param.conv1.state_dict())
    model.s[3].load_state_dict(param.norm2.state_dict())
    model.s[5].load_state_dict(param.conv2.state_dict())

    if isinstance(model.res, torch.nn.Module):
        model.res.load_state_dict(param.conv_shortcut.state_dict())

# 加载参数用的函数
def load_atten(model, param):
    print(dir(param)) 
    model.norm.load_state_dict(param.group_norm.state_dict())
    model.q.load_state_dict(param.query.state_dict())
    model.k.load_state_dict(param.key.state_dict())
    model.v.load_state_dict(param.value.state_dict())
    model.out.load_state_dict(param.proj_attn.state_dict())


#encoder.in
vae.encoder[0].load_state_dict(params.encoder.conv_in.state_dict())

#encoder.down
for i in range(4):
    load_res(vae.encoder[i + 1][0], params.encoder.down_blocks[i].resnets[0])
    load_res(vae.encoder[i + 1][1], params.encoder.down_blocks[i].resnets[1])

    if i != 3:
        vae.encoder[i + 1][2][1].load_state_dict(
            params.encoder.down_blocks[i].downsamplers[0].conv.state_dict())

#encoder.mid
load_res(vae.encoder[5][0], params.encoder.mid_block.resnets[0])
load_res(vae.encoder[5][2], params.encoder.mid_block.resnets[1])
load_atten(vae.encoder[5][1], params.encoder.mid_block.attentions[0])

#encoder.out
vae.encoder[6][0].load_state_dict(params.encoder.conv_norm_out.state_dict())
vae.encoder[6][2].load_state_dict(params.encoder.conv_out.state_dict())

#encoder.正态分布层
vae.encoder[7].load_state_dict(params.quant_conv.state_dict())

#decoder.正态分布层
vae.decoder[0].load_state_dict(params.post_quant_conv.state_dict())

#decoder.in
vae.decoder[1].load_state_dict(params.decoder.conv_in.state_dict())

#decoder.mid
load_res(vae.decoder[2][0], params.decoder.mid_block.resnets[0])
load_res(vae.decoder[2][2], params.decoder.mid_block.resnets[1])
load_atten(vae.decoder[2][1], params.decoder.mid_block.attentions[0])

#decoder.up
for i in range(4):
    load_res(vae.decoder[i + 3][0], params.decoder.up_blocks[i].resnets[0])
    load_res(vae.decoder[i + 3][1], params.decoder.up_blocks[i].resnets[1])
    load_res(vae.decoder[i + 3][2], params.decoder.up_blocks[i].resnets[2])

    if i != 3:
        vae.decoder[i + 3][4].load_state_dict(
            params.decoder.up_blocks[i].upsamplers[0].conv.state_dict())

#decoder.out
vae.decoder[7][0].load_state_dict(params.decoder.conv_norm_out.state_dict())
vae.decoder[7][2].load_state_dict(params.decoder.conv_out.state_dict())

  from .autonotebook import tqdm as notebook_tqdm
Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.
The config attributes {'scaling_factor': 0.18215} were passed to AutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.
  return torch.load(checkpoint_file, map_location="cpu")


['T_destination', '__annotations__', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_attention_op', '_backward_hooks', '_backward_pre_hooks', '_buffers', '_call_impl', '_compiled_call_impl', '_forward_hooks', '_forward_hooks_always_called', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_get_backward_hooks', '_get_backward_pre_hooks', '_get_name', '_is_full_backward_hook', '_load_from_state_dict', '_load_state_dict_post_hooks', '_load_state_dict_pre_hooks', '_maybe_warn_non_full_backward_hook', '_modules', '_named_members', '_non_persistent_buffers_set', '_parameters', '_register_load_s

<All keys matched successfully>

In [6]:
# 测试 encode
data = torch.randn(1, 3, 512, 512)

a = vae.encoder(data)
b = params.encode(data).latent_dist.parameters

(a == b).all() # tensor(True)

tensor(True)

In [7]:
# 测试 decode
data = torch.randn(1, 4, 64, 64)

a = vae.decoder(data)
b = params.decode(data).sample

(a == b).all() # tensor(True)

tensor(True)