# UNet 模型的作用
结合文本给图像降噪

In [1]:
import torch


# 定义残差连接层
class Resnet(torch.nn.Module):
    """Residual convolution block conditioned on a timestep embedding.

    Two ``GroupNorm -> SiLU -> 3x3 Conv2d`` stages are applied to the input.
    A linear projection of the timestep embedding is added (broadcast over the
    spatial axes) after the first stage, and the block input is added back at
    the end. Every convolution uses stride 1 with matching padding, so height
    and width are preserved; only the channel count changes from ``dim_in`` to
    ``dim_out``.

    Args:
        dim_in: Number of input channels (GroupNorm uses 32 groups, so it must
            be divisible by 32).
        dim_out: Number of output channels (must be divisible by 32).
        dim_time: Width of the timestep embedding passed to ``forward``.
            Defaults to 1280, the value that was previously hard-coded.
    """

    def __init__(self, dim_in, dim_out, dim_time=1280):
        super().__init__()

        # NOTE: attribute names and Sequential indices are relied upon by the
        # weight-loading code (time[1], s0[0], s0[2], s1[0], s1[2], res), and
        # the construction order fixes the random-init order. Keep both.
        self.time = torch.nn.Sequential(
            torch.nn.SiLU(),
            torch.nn.Linear(dim_time, dim_out),
            # [B, dim_out] -> [B, dim_out, 1, 1] so it broadcasts over H x W.
            torch.nn.Unflatten(dim=1, unflattened_size=(dim_out, 1, 1)),
        )

        self.s0 = torch.nn.Sequential(
            torch.nn.GroupNorm(num_groups=32,
                               num_channels=dim_in,
                               eps=1e-05,
                               affine=True),
            torch.nn.SiLU(),
            torch.nn.Conv2d(dim_in,
                            dim_out,
                            kernel_size=3,
                            stride=1,
                            padding=1),
        )

        self.s1 = torch.nn.Sequential(
            torch.nn.GroupNorm(num_groups=32,
                               num_channels=dim_out,
                               eps=1e-05,
                               affine=True),
            torch.nn.SiLU(),
            torch.nn.Conv2d(dim_out,
                            dim_out,
                            kernel_size=3,
                            stride=1,
                            padding=1),
        )

        # A 1x1 convolution on the shortcut is only needed when the channel
        # count changes; otherwise the input is added back unchanged.
        self.res = None
        if dim_in != dim_out:
            self.res = torch.nn.Conv2d(dim_in,
                                       dim_out,
                                       kernel_size=1,
                                       stride=1,
                                       padding=0)

    def forward(self, x, time):
        """Apply the block.

        Args:
            x: Feature map of shape ``[B, dim_in, H, W]``.
            time: Timestep embedding of shape ``[B, dim_time]`` (a leading
                dimension of 1 broadcasts over the batch).

        Returns:
            Tensor of shape ``[B, dim_out, H, W]``.
        """
        res = x

        # [B, dim_time] -> [B, dim_out, 1, 1]
        time = self.time(time)

        # [B, dim_in, H, W] -> [B, dim_out, H, W], plus the timestep bias
        x = self.s0(x) + time

        # Shape unchanged: [B, dim_out, H, W]
        x = self.s1(x)

        # Project the shortcut to dim_out channels when required.
        if self.res is not None:
            res = self.res(res)

        return res + x


# Smoke test: a 320 -> 640 channel block applied to a 32x32 feature map.
Resnet(dim_in=320, dim_out=640)(
    x=torch.randn(1, 320, 32, 32),
    time=torch.randn(1, 1280),
).shape

torch.Size([1, 640, 32, 32])

In [2]:
# 定义 Unet 的注意力层 (计算 图 文 的注意力)
class CrossAttention(torch.nn.Module):
    """Multi-head scaled dot-product attention of ``q`` over ``kv``.

    Queries are projected from ``q``; keys and values are both projected from
    ``kv``. When the caller passes the same tensor for both arguments this is
    self-attention; when ``kv`` is a different sequence (e.g. text features)
    it is cross-attention.

    Args:
        dim_q: Feature width of the query sequence; also the output width.
        dim_kv: Feature width of the key/value sequence.
        num_heads: Number of attention heads. Defaults to 8, the value that
            was previously hard-coded. Must divide ``dim_q``.

    Raises:
        ValueError: If ``dim_q`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim_q, dim_kv, num_heads=8):
        super().__init__()

        if dim_q % num_heads != 0:
            raise ValueError(
                f'dim_q ({dim_q}) must be divisible by num_heads ({num_heads})')

        self.dim_q = dim_q
        self.num_heads = num_heads

        # Bias-free projections; attribute names are used by the weight loader.
        self.q = torch.nn.Linear(dim_q, dim_q, bias=False)
        self.k = torch.nn.Linear(dim_kv, dim_q, bias=False)
        self.v = torch.nn.Linear(dim_kv, dim_q, bias=False)

        self.out = torch.nn.Linear(dim_q, dim_q)

    def _split_heads(self, x):
        """[B, L, D] -> [B * heads, L, D // heads]."""
        b, lens, dim = x.shape
        heads = self.num_heads

        # [B, L, D] -> [B, L, heads, D/heads] -> [B, heads, L, D/heads]
        x = x.reshape(b, lens, heads, dim // heads).transpose(1, 2)

        # Fold the head axis into the batch axis so bmm can be used.
        return x.reshape(b * heads, lens, dim // heads)

    def _merge_heads(self, x):
        """[B * heads, L, D // heads] -> [B, L, D] (inverse of _split_heads)."""
        bh, lens, dim = x.shape
        heads = self.num_heads

        # [B*heads, L, d] -> [B, heads, L, d] -> [B, L, heads, d]
        x = x.reshape(bh // heads, heads, lens, dim).transpose(1, 2)

        return x.reshape(bh // heads, lens, dim * heads)

    def forward(self, q, kv):
        """Attend from ``q`` to ``kv``.

        Args:
            q: Query sequence, ``[B, Lq, dim_q]`` (e.g. ``[1, 4096, 320]``).
            kv: Key/value sequence, ``[B, Lkv, dim_kv]`` (e.g. ``[1, 77, 768]``).

        Returns:
            Tensor of shape ``[B, Lq, dim_q]``.
        """
        # Project, then split into heads:
        # [B, Lq, dim_q] -> [B*heads, Lq, d];  [B, Lkv, dim_kv] -> [B*heads, Lkv, d]
        q = self._split_heads(self.q(q))
        k = self._split_heads(self.k(kv))
        v = self._split_heads(self.v(kv))

        # Scaled scores: [B*heads, Lq, d] x [B*heads, d, Lkv] -> [B*heads, Lq, Lkv].
        # baddbmm is kept (instead of bmm followed by a multiply) because the
        # two are mathematically equal but differ by small rounding errors.
        # With beta=0 the first argument's contents are ignored, so an
        # uninitialised buffer is fine; it must share q's dtype and device,
        # which new_empty guarantees (a plain torch.empty defaults to float32
        # and breaks half/double precision inputs).
        atten = torch.baddbmm(
            q.new_empty(q.shape[0], q.shape[1], k.shape[1]),
            q,
            k.transpose(1, 2),
            beta=0,
            alpha=(self.dim_q // self.num_heads)**-0.5,
        )

        atten = atten.softmax(dim=-1)

        # [B*heads, Lq, Lkv] x [B*heads, Lkv, d] -> [B*heads, Lq, d]
        atten = atten.bmm(v)

        # Recombine heads: [B*heads, Lq, d] -> [B, Lq, dim_q]
        atten = self._merge_heads(atten)

        # Output projection, shape unchanged.
        return self.out(atten)


# Smoke test: 4096 image tokens (width 320) attending to 77 text tokens (width 768).
CrossAttention(dim_q=320, dim_kv=768)(
    q=torch.randn(1, 4096, 320),
    kv=torch.randn(1, 77, 768),
).shape

torch.Size([1, 4096, 320])

In [3]:
# Transformer 层类
class Transformer(torch.nn.Module):
    """Spatial transformer block operating on an image feature map.

    The feature map is normalised, passed through a 1x1 convolution and
    flattened into a token sequence. The tokens then go through
    self-attention, cross-attention against ``kv``, and a gated GELU
    feed-forward layer, each with a residual connection. Finally the tokens
    are reshaped back to a feature map, passed through a 1x1 convolution and
    added to the block input.

    Args:
        dim: Channel count of the feature map (divisible by 32 for GroupNorm).
        dim_kv: Feature width of the conditioning sequence ``kv``. Defaults to
            768, the value that was previously hard-coded.
    """

    def __init__(self, dim, dim_kv=768):
        super().__init__()

        self.dim = dim

        # in
        self.norm_in = torch.nn.GroupNorm(num_groups=32,
                                          num_channels=dim,
                                          eps=1e-6,
                                          affine=True)
        self.cnn_in = torch.nn.Conv2d(dim,
                                      dim,
                                      kernel_size=1,
                                      stride=1,
                                      padding=0)

        # attention: atten1 is used as self-attention, atten2 attends to kv
        self.norm_atten0 = torch.nn.LayerNorm(dim, elementwise_affine=True)
        self.atten1 = CrossAttention(dim, dim)
        self.norm_atten1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
        self.atten2 = CrossAttention(dim, dim_kv)

        # feed-forward: fc0 emits 2 * (4 * dim) features that forward() splits
        # into a value half and a gate half, so fc1 consumes 4 * dim.
        self.norm_act = torch.nn.LayerNorm(dim, elementwise_affine=True)
        self.fc0 = torch.nn.Linear(dim, dim * 8)
        self.act = torch.nn.GELU()
        self.fc1 = torch.nn.Linear(dim * 4, dim)

        # out
        self.cnn_out = torch.nn.Conv2d(dim,
                                       dim,
                                       kernel_size=1,
                                       stride=1,
                                       padding=0)

    def forward(self, q, kv):
        """Apply the block.

        Args:
            q: Image feature map, ``[B, dim, H, W]`` (e.g. ``[1, 320, 64, 64]``).
            kv: Conditioning sequence, ``[B, L, dim_kv]`` (e.g. ``[1, 77, 768]``).

        Returns:
            Tensor with the same shape as ``q``.
        """
        b, _, h, w = q.shape
        res1 = q

        # ---- in: normalise + 1x1 conv, shape unchanged ----
        q = self.cnn_in(self.norm_in(q))

        # [B, dim, H, W] -> [B, H, W, dim] -> [B, H*W, dim]
        q = q.permute(0, 2, 3, 1).reshape(b, h * w, self.dim)

        # ---- attention ----
        # Self-attention: the same normalised tokens act as query and
        # key/value, so the LayerNorm is evaluated once and reused.
        normed = self.norm_atten0(q)
        q = self.atten1(q=normed, kv=normed) + q
        # Cross-attention against the conditioning sequence.
        q = self.atten2(q=self.norm_atten1(q), kv=kv) + q

        # ---- gated GELU feed-forward ----
        res2 = q
        # [B, H*W, dim] -> [B, H*W, 8*dim]
        q = self.fc0(self.norm_act(q))

        # Split the last axis in half: first half is the value, second half
        # is the gate that goes through GELU. Result: [B, H*W, 4*dim]
        d = q.shape[2] // 2
        q = q[:, :, :d] * self.act(q[:, :, d:])

        # [B, H*W, 4*dim] -> [B, H*W, dim]
        q = self.fc1(q) + res2

        # ---- out ----
        # [B, H*W, dim] -> [B, H, W, dim] -> [B, dim, H, W]
        q = q.reshape(b, h, w, self.dim).permute(0, 3, 1, 2).contiguous()

        # 1x1 conv plus the outer residual, shape unchanged.
        q = self.cnn_out(q) + res1

        return q


# Smoke test: the block must return a tensor shaped like its image input.
Transformer(dim=320)(
    q=torch.randn(1, 320, 64, 64),
    kv=torch.randn(1, 77, 768),
).shape

torch.Size([1, 320, 64, 64])

In [4]:
# Down 层
# DownBlock 层类
class DownBlock(torch.nn.Module):
    """Encoder stage: two (Resnet, Transformer) pairs, then a stride-2 conv.

    Args:
        dim_in: Channel count of the incoming feature map.
        dim_out: Channel count produced by this stage.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()

        # Construction order is kept as-is because it determines the order in
        # which random initial weights are drawn.
        self.tf0 = Transformer(dim_out)
        self.res0 = Resnet(dim_in, dim_out)

        self.tf1 = Transformer(dim_out)
        self.res1 = Resnet(dim_out, dim_out)

        # 3x3 convolution with stride 2: halves height and width.
        self.out = torch.nn.Conv2d(
            dim_out, dim_out, kernel_size=3, stride=2, padding=1)

    def forward(self, out_vae, out_encoder, time):
        """Run the stage.

        Args:
            out_vae: Image feature map with ``dim_in`` channels.
            out_encoder: Conditioning sequence handed to each Transformer.
            time: Timestep embedding handed to each Resnet.

        Returns:
            A pair ``(hidden, skips)``: the downsampled feature map, and a
            list of three tensors (output of each Resnet/Transformer pair
            followed by the downsampled map) for later skip connections.
        """
        hidden = out_vae
        skips = []

        stages = ((self.res0, self.tf0), (self.res1, self.tf1))
        for resnet, attention in stages:
            hidden = attention(resnet(hidden, time), out_encoder)
            skips.append(hidden)

        hidden = self.out(hidden)
        skips.append(hidden)

        return hidden, skips


# Smoke test: shape of the downsampled output (first element of the result).
DownBlock(dim_in=320, dim_out=640)(
    out_vae=torch.randn(1, 320, 32, 32),
    out_encoder=torch.randn(1, 77, 768),
    time=torch.randn(1, 1280),
)[0].shape

torch.Size([1, 640, 16, 16])

In [5]:
# Up 层
class UpBlock(torch.nn.Module):
    """Decoder stage: three (Resnet, Transformer) pairs fed by skip tensors.

    Before each Resnet the running feature map is concatenated, along the
    channel axis, with one tensor popped from the end of ``out_down``.

    Args:
        dim_in: Channel count of the third skip tensor consumed.
        dim_out: Channel count produced by this stage; the first and second
            skip tensors consumed must also have this many channels.
        dim_prev: Channel count of the feature map entering the stage.
        add_up: When true, finish with 2x nearest-neighbour upsampling
            followed by a 3x3 convolution.
    """

    def __init__(self, dim_in, dim_out, dim_prev, add_up):
        super().__init__()

        # Each Resnet's input width is (running map channels + skip channels).
        self.res0 = Resnet(dim_out + dim_prev, dim_out)
        self.res1 = Resnet(dim_out + dim_out, dim_out)
        self.res2 = Resnet(dim_in + dim_out, dim_out)

        self.tf0 = Transformer(dim_out)
        self.tf1 = Transformer(dim_out)
        self.tf2 = Transformer(dim_out)

        # Optional upsampling head; index 1 (the conv) is used by the loader.
        self.out = None
        if add_up:
            self.out = torch.nn.Sequential(
                torch.nn.Upsample(scale_factor=2, mode='nearest'),
                torch.nn.Conv2d(dim_out, dim_out, kernel_size=3, padding=1),
            )

    def forward(self, out_vae, out_encoder, time, out_down):
        """Run the stage.

        Args:
            out_vae: Feature map with ``dim_prev`` channels.
            out_encoder: Conditioning sequence handed to each Transformer.
            time: Timestep embedding handed to each Resnet.
            out_down: List of skip tensors. It is MUTATED: three tensors are
                popped from its end, last element first.

        Returns:
            Feature map with ``dim_out`` channels, upsampled 2x when the stage
            was built with ``add_up=True``.
        """
        stages = (
            (self.res0, self.tf0),
            (self.res1, self.tf1),
            (self.res2, self.tf2),
        )
        for resnet, attention in stages:
            skip = out_down.pop()
            out_vae = resnet(torch.cat([out_vae, skip], dim=1), time)
            out_vae = attention(out_vae, out_encoder)

        # Explicit None check: the truthiness of a Sequential is its length,
        # which is not what is being asked here.
        if self.out is not None:
            out_vae = self.out(out_vae)

        return out_vae


# Smoke test. The skip list is consumed from its end, so the last entry is
# concatenated first.
UpBlock(dim_in=320, dim_out=640, dim_prev=1280, add_up=True)(
    out_vae=torch.randn(1, 1280, 32, 32),
    out_encoder=torch.randn(1, 77, 768),
    time=torch.randn(1, 1280),
    out_down=[
        torch.randn(1, 320, 32, 32),
        torch.randn(1, 640, 32, 32),
        torch.randn(1, 640, 32, 32),
    ],
).shape

torch.Size([1, 640, 64, 64])

In [6]:
# 🌟 Unet 模型的类 【核心】
class UNet(torch.nn.Module):
    """Text-conditioned denoising UNet over 4-channel latent feature maps.

    The network embeds the timestep, runs three down stages plus two extra
    Resnets (collecting skip tensors), a Resnet/Transformer/Resnet middle, and
    then mirrors the down path while consuming the skip tensors in reverse.
    Sub-module names are relied upon by the weight-loading code.
    """

    def __init__(self):
        super().__init__()

        # in
        self.in_vae = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)

        self.in_time = torch.nn.Sequential(
            torch.nn.Linear(320, 1280),
            torch.nn.SiLU(),
            torch.nn.Linear(1280, 1280),
        )

        # down
        self.down_block0 = DownBlock(320, 320)
        self.down_block1 = DownBlock(320, 640)
        self.down_block2 = DownBlock(640, 1280)

        self.down_res0 = Resnet(1280, 1280)
        self.down_res1 = Resnet(1280, 1280)

        # mid
        self.mid_res0 = Resnet(1280, 1280)
        self.mid_tf = Transformer(1280)
        self.mid_res1 = Resnet(1280, 1280)

        # up
        self.up_res0 = Resnet(2560, 1280)
        self.up_res1 = Resnet(2560, 1280)
        self.up_res2 = Resnet(2560, 1280)

        self.up_in = torch.nn.Sequential(
            torch.nn.Upsample(scale_factor=2, mode='nearest'),
            torch.nn.Conv2d(1280, 1280, kernel_size=3, padding=1),
        )

        self.up_block0 = UpBlock(640, 1280, 1280, True)
        self.up_block1 = UpBlock(320, 640, 1280, True)
        self.up_block2 = UpBlock(320, 320, 640, False)

        # out
        self.out = torch.nn.Sequential(
            torch.nn.GroupNorm(num_channels=320, num_groups=32, eps=1e-5),
            torch.nn.SiLU(),
            torch.nn.Conv2d(320, 4, kernel_size=3, padding=1),
        )

    @staticmethod
    def _time_embed(t):
        """Sinusoidal embedding of timesteps: ``[N]`` (or scalar) -> ``[N, 320]``.

        Uses 160 frequencies ``exp(-ln(10000) * i / 160)``; the cosines fill
        the first 160 columns and the sines the last 160. For a one-element
        ``t`` the values match the previous single-timestep implementation.
        """
        # -9.210340371976184 = -math.log(10000)
        freqs = torch.arange(160) * -9.210340371976184 / 160
        freqs = freqs.exp().to(t.device)

        # [N, 1] * [160] -> [N, 160]; one row per timestep.
        e = t.reshape(-1, 1) * freqs

        # [N, 160] + [N, 160] -> [N, 320]
        return torch.cat([e.cos(), e.sin()], dim=1)

    def forward(self, out_vae, out_encoder, time):
        """Predict the denoising output for a batch of latents.

        Args:
            out_vae: Latent feature maps, ``[B, 4, H, W]`` (e.g. ``[1, 4, 64, 64]``).
            out_encoder: Text features, ``[B, 77, 768]``.
            time: Timesteps, shape ``[1]`` (shared by the batch) or ``[B]``.

        Returns:
            Tensor with the same shape as ``out_vae``.
        """
        # -------- in --------
        # [B, 4, H, W] -> [B, 320, H, W]
        out_vae = self.in_vae(out_vae)

        # [N] -> [N, 320] -> [N, 1280]
        time = self.in_time(self._time_embed(time))

        # -------- down: collect 12 skip tensors --------
        # For a [1, 4, 64, 64] input the entries are, in order:
        #   320@64, 320@64, 320@64, 320@32,
        #   640@32, 640@32, 640@16,
        #   1280@16, 1280@16, 1280@8,
        #   1280@8, 1280@8
        out_down = [out_vae]

        for down_block in (self.down_block0, self.down_block1,
                           self.down_block2):
            out_vae, out = down_block(out_vae=out_vae,
                                      out_encoder=out_encoder,
                                      time=time)
            out_down.extend(out)

        for down_res in (self.down_res0, self.down_res1):
            out_vae = down_res(out_vae, time)
            out_down.append(out_vae)

        # -------- mid: shape unchanged, [B, 1280, H/8, W/8] --------
        out_vae = self.mid_res0(out_vae, time)
        out_vae = self.mid_tf(out_vae, out_encoder)
        out_vae = self.mid_res1(out_vae, time)

        # -------- up: consume the skip tensors from the end --------
        # [B, 1280+1280, h, w] -> [B, 1280, h, w], three times
        for up_res in (self.up_res0, self.up_res1, self.up_res2):
            out_vae = up_res(torch.cat([out_vae, out_down.pop()], dim=1),
                             time)

        # 2x upsample, e.g. [1, 1280, 8, 8] -> [1, 1280, 16, 16]
        out_vae = self.up_in(out_vae)

        # Each up block pops three more skips from out_down. For the 64x64
        # example: -> [1, 1280, 32, 32] -> [1, 640, 64, 64] -> [1, 320, 64, 64]
        for up_block in (self.up_block0, self.up_block1, self.up_block2):
            out_vae = up_block(out_vae=out_vae,
                               out_encoder=out_encoder,
                               time=time,
                               out_down=out_down)

        # -------- out --------
        # [B, 320, H, W] -> [B, 4, H, W]
        out_vae = self.out(out_vae)

        return out_vae


# Smoke test: a batch of two latents sharing a single timestep.
UNet()(
    out_vae=torch.randn(2, 4, 64, 64),
    out_encoder=torch.randn(2, 77, 768),
    time=torch.LongTensor([26]),
).shape

torch.Size([2, 4, 64, 64])

In [7]:
# 使用【预训练的模型】初始化数据
from diffusers import UNet2DConditionModel

# Load the pretrained diffusers UNet whose weights are copied into our model
params = UNet2DConditionModel.from_pretrained(
    'model/diffsion_from_scratch.params', subfolder='unet')

unet = UNet()

# Copy the input-stage weights: the latent conv and the two linear layers of
# the timestep MLP (index 1 of in_time is the parameter-free SiLU)
unet.in_vae.load_state_dict(params.conv_in.state_dict())
unet.in_time[0].load_state_dict(params.time_embedding.linear_1.state_dict())
unet.in_time[2].load_state_dict(params.time_embedding.linear_2.state_dict())


# 加载 down 的参数
def load_tf(model, param):
    """Copy the weights of one pretrained attention module into ``model``.

    ``model`` is a ``Transformer`` from this file; ``param`` is the matching
    pretrained module, of which only ``transformer_blocks[0]`` is read.
    """

    def copy(dst, src):
        dst.load_state_dict(src.state_dict())

    block = param.transformer_blocks[0]

    # Input normalisation and projection
    copy(model.norm_in, param.norm)
    copy(model.cnn_in, param.proj_in)

    # Both attention layers share the same q/k/v/out naming pattern
    for ours, theirs in ((model.atten1, block.attn1),
                         (model.atten2, block.attn2)):
        copy(ours.q, theirs.to_q)
        copy(ours.k, theirs.to_k)
        copy(ours.v, theirs.to_v)
        copy(ours.out, theirs.to_out[0])

    # Feed-forward linears
    copy(model.fc0, block.ff.net[0].proj)
    copy(model.fc1, block.ff.net[2])

    # Layer norms
    copy(model.norm_atten0, block.norm1)
    copy(model.norm_atten1, block.norm2)
    copy(model.norm_act, block.norm3)

    # Output projection
    copy(model.cnn_out, param.proj_out)


def load_res(model, param):
    """Copy the weights of one pretrained resnet block into ``model``.

    ``model`` is a ``Resnet`` from this file. The shortcut convolution is only
    copied when ``model.res`` is a module (it is ``None`` otherwise).
    """

    def transfer(target, source):
        target.load_state_dict(source.state_dict())

    # Timestep projection (index 0 of model.time is the parameter-free SiLU)
    transfer(model.time[1], param.time_emb_proj)

    # First and second norm/conv stages
    transfer(model.s0[0], param.norm1)
    transfer(model.s0[2], param.conv1)
    transfer(model.s1[0], param.norm2)
    transfer(model.s1[2], param.conv2)

    if not isinstance(model.res, torch.nn.Module):
        return
    transfer(model.res, param.conv_shortcut)


def load_down_block(model, param):
    """Copy one pretrained down block into a ``DownBlock`` from this file."""
    # Attention modules first, then resnets, then the downsampling conv
    for index, transformer in enumerate((model.tf0, model.tf1)):
        load_tf(transformer, param.attentions[index])

    for index, resnet in enumerate((model.res0, model.res1)):
        load_res(resnet, param.resnets[index])

    downsample_conv = param.downsamplers[0].conv
    model.out.load_state_dict(downsample_conv.state_dict())


# Copy the three attention-bearing down stages
load_down_block(unet.down_block0, params.down_blocks[0])
load_down_block(unet.down_block1, params.down_blocks[1])
load_down_block(unet.down_block2, params.down_blocks[2])

# Only the two resnets of the fourth pretrained down block are copied
load_res(unet.down_res0, params.down_blocks[3].resnets[0])
load_res(unet.down_res1, params.down_blocks[3].resnets[1])

# Load the mid-stage weights
load_tf(unet.mid_tf, params.mid_block.attentions[0])
load_res(unet.mid_res0, params.mid_block.resnets[0])
load_res(unet.mid_res1, params.mid_block.resnets[1])

# Load the up-stage weights: the first pretrained up block supplies three
# resnets plus the conv that follows the upsampling layer (up_in[1])
load_res(unet.up_res0, params.up_blocks[0].resnets[0])
load_res(unet.up_res1, params.up_blocks[0].resnets[1])
load_res(unet.up_res2, params.up_blocks[0].resnets[2])
unet.up_in[1].load_state_dict(
    params.up_blocks[0].upsamplers[0].conv.state_dict())


def load_up_block(model, param):
    """Copy one pretrained up block into an ``UpBlock`` from this file.

    The upsampling conv is copied only when ``model.out`` is a module, i.e.
    when the block was built with an upsampling head.
    """
    # Attention modules first, then resnets
    for index, transformer in enumerate((model.tf0, model.tf1, model.tf2)):
        load_tf(transformer, param.attentions[index])

    for index, resnet in enumerate((model.res0, model.res1, model.res2)):
        load_res(resnet, param.resnets[index])

    if not isinstance(model.out, torch.nn.Module):
        return
    upsample_conv = param.upsamplers[0].conv
    model.out[1].load_state_dict(upsample_conv.state_dict())


# Our up_block0..2 correspond to pretrained up_blocks[1..3]; up_blocks[0]
# was already copied into up_res0..2 and up_in
load_up_block(unet.up_block0, params.up_blocks[1])
load_up_block(unet.up_block1, params.up_blocks[2])
load_up_block(unet.up_block2, params.up_blocks[3])

# Load the output-stage weights (final norm and conv; out[1] is SiLU)
unet.out[0].load_state_dict(params.conv_norm_out.state_dict())
unet.out[2].load_state_dict(params.conv_out.state_dict())

  from .autonotebook import tqdm as notebook_tqdm
Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.
  return torch.load(checkpoint_file, map_location="cpu")


<All keys matched successfully>

In [8]:
# Feed one random batch through both models and compare their outputs
out_vae = torch.randn(1, 4, 64, 64)
out_encoder = torch.randn(1, 77, 768)
time = torch.LongTensor([26])

# This is an inference-only comparison: disabling autograd avoids keeping
# every intermediate activation of two full UNet forward passes in memory.
# The computed values are unaffected.
with torch.no_grad():
    a = unet(out_vae=out_vae, out_encoder=out_encoder, time=time)
    b = params(out_vae, time, out_encoder).sample

# Exact element-wise equality is intended: the re-implementation is expected
# to reproduce the pretrained model's output bit for bit
(a == b).all()

tensor(True)