In [1]:
import torch

class Embed(torch.nn.Module):
    """This part is the embedding part of the model.

    Args:
        torch (_type_): 输入的是一个tensor，shape为[b, 77]，其中b为batch_size，77为每个句子的长度，即77个token。
    """
    def __init__(self):
        super(Embed, self).__init__()
        self.embed = torch.nn.Embedding(49408, 768)
        self.pos_embed = torch.nn.Embedding(77, 768)

        self.register_buffer("pos_ids", torch.arange(77).unsqueeze(dim=0))
        
    def forward(self, input_ids):
        # input_ids = [b, 77]
        
        # [b ,77] -> [b, 77, 768]
        embed = self.embed(input_ids)
        
        # [1, 77] -> [1, 77, 768]
        pos_embed = self.pos_embed(self.pos_ids)
        
        return embed + pos_embed  # 这里面有一个broadcast机制，pos_embed会自动broadcast成[b, 77, 768]的形状
    
Embed()(torch.ones(2,77).long())

tensor([[[-1.8509,  0.3531, -0.2358,  ...,  1.7505,  2.0493, -0.4218],
         [-2.1914, -0.6323, -1.1931,  ...,  2.4355,  1.4947, -2.5901],
         [-2.5607,  0.6912,  1.5422,  ...,  2.8686,  0.5209,  0.4583],
         ...,
         [-1.4648,  1.0369,  0.4630,  ...,  1.1369,  0.5880, -0.0726],
         [-2.3465,  0.5403,  2.3221,  ...,  1.6129, -0.6126,  0.1816],
         [-2.2795, -1.4391,  1.2479,  ...,  2.5200,  0.2735, -0.7442]],

        [[-1.8509,  0.3531, -0.2358,  ...,  1.7505,  2.0493, -0.4218],
         [-2.1914, -0.6323, -1.1931,  ...,  2.4355,  1.4947, -2.5901],
         [-2.5607,  0.6912,  1.5422,  ...,  2.8686,  0.5209,  0.4583],
         ...,
         [-1.4648,  1.0369,  0.4630,  ...,  1.1369,  0.5880, -0.0726],
         [-2.3465,  0.5403,  2.3221,  ...,  1.6129, -0.6126,  0.1816],
         [-2.2795, -1.4391,  1.2479,  ...,  2.5200,  0.2735, -0.7442]]],
       grad_fn=<AddBackward0>)

In [2]:
from turtle import forward


class Attention(torch.nn.Module):
    """注意力机制

    Args:
        torch (_type_): _description_
    """
    def __init__(self) -> None:
        super().__init__()
        self.q = torch.nn.Linear(768, 768)
        self.k = torch.nn.Linear(768, 768)
        self.v = torch.nn.Linear(768, 768)
        self.out = torch.nn.Linear(768, 768)
        
    def forward(self, x):
        # x -> [b, 77, 768]
        b = x.shape[0]
        
        # 纬度不变
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        
        # 拆分注意力头,拆分成12个
        q = q.reshape(b, 77, 12, 64).transpose(1,2).reshape(b*12, 77, 64) * 0.125
        k = k.reshape(b, 77, 12, 64).transpose(1,2).reshape(b*12, 77, 64)
        v = v.reshape(b, 77, 12, 64).transpose(1,2).reshape(b*12, 77, 64)
        
        # 计算注意力得分
        # [b*12, 77, 64] * [b*12, 64, 77] -> [b*12, 77, 77]
        attn = torch.bmm(q, k.transpose(1, 2))
        
        # [b*12, 77, 77] -> [b,12, 77, 77]
        attn = attn.reshape(b, 12, 77, 77)
        
        # 上三角掩码，b个词，
        def  get_mask(b):
            mask = torch.empty(b, 77, 77)
            
            # 上三角负无穷
            mask.fill_(float("-inf"))
            
            mask.triu_(1)
            
            return mask.unsqueeze(1)  # [b, 1, 77, 77]
        
        # [b, 12, 77, 77] + [b, 1, 77, 77] = [b, 12, 77, 77]
        attn = attn + get_mask(attn.shape[0]).to(attn.device)
        
        # [b, 12, 77, 77] -> [b*12, 77, 77]
        attn = attn.reshape(b*12, 77, 77)
        
        # sofrmax, 被mask的部分是0
        attn = attn.softmax(dim = -1)
        
        # 和v的乘积
        # [b*12, 77, 77] * [b*12, 77, 64] -> [b*12, 77, 64]
        attn = torch.bmm(attn, v)
        
        # 恢复回去
        attn = attn.reshape(b, 12, 77, 64).transpose(1,2).reshape(b, 77, 768)
        
        return self.out(attn)
    
Attention()(torch.ones(2, 77, 768)).shape
            
            

torch.Size([2, 77, 768])

In [3]:
# bert 的编码器层
class ClipEncoder(torch.nn.Module):
    
    def __init__(self) -> None:
        super().__init__()
        
        self.s1 = torch.nn.Sequential(
            torch.nn.LayerNorm(768),
            Attention(),
        )
        
        self.s2 = torch.nn.Sequential(
            torch.nn.LayerNorm(768),
            torch.nn.Linear(768, 3072),
        )
        
        self.s3 = torch.nn.Linear(3072, 768)
        
    def forward(self, x):
        # x -> [b, 77, 768]
        
        # 维度不变
        x = x +self.s1(x)
        
        # [2, 77, 768]
        res = x

        # [b, 77, 768] -> [b, 77, 3072]
        x = self.s2(x)
        
        # 激活函数
        x = x * (x * 1.702).sigmoid()
        
        return res + self.s3(x)
    
ClipEncoder()(torch.randn(2, 77, 768)).shape

torch.Size([2, 77, 768])

In [4]:
#经过优化之后的代码量少得吓人...
encoder = torch.nn.Sequential(
    Embed(),
    ClipEncoder(),
    ClipEncoder(),
    ClipEncoder(),
    ClipEncoder(),
    ClipEncoder(),
    ClipEncoder(),
    ClipEncoder(),
    ClipEncoder(),
    ClipEncoder(),
    ClipEncoder(),
    ClipEncoder(),
    ClipEncoder(),
    torch.nn.LayerNorm(768),
)

# encoder(torch.ones(2, 77).long()).shape

In [8]:
from transformers import CLIPTextModel

local_model_path = 'E:\Myproject\\02Model\HuggingFace-Download-Accelerator\hf_hub\diffusion'
# 加载预训练
params = CLIPTextModel.from_pretrained(local_model_path)
print(params)

CLIPTextModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0): CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e-05, ele

In [9]:
# 词编码
encoder[0].embed.load_state_dict(
    params.text_model.embeddings.token_embedding.state_dict()
)

# 位置编码
encoder[0].pos_embed.load_state_dict(
    params.text_model.embeddings.position_embedding.state_dict()
)

# 12层编码层
for i in range(12):
    encoder[i+1].s1[0].load_state_dict(
        params.text_model.encoder.layers[i].layer_norm1.state_dict()
    )

    #注意力q矩阵
    encoder[i + 1].s1[1].q.load_state_dict(
        params.text_model.encoder.layers[i].self_attn.q_proj.state_dict())

    #注意力k矩阵
    encoder[i + 1].s1[1].k.load_state_dict(
        params.text_model.encoder.layers[i].self_attn.k_proj.state_dict())

    #注意力v矩阵
    encoder[i + 1].s1[1].v.load_state_dict(
        params.text_model.encoder.layers[i].self_attn.v_proj.state_dict())

    #注意力out
    encoder[i + 1].s1[1].out.load_state_dict(
        params.text_model.encoder.layers[i].self_attn.out_proj.state_dict())

    #第二层norm
    encoder[i + 1].s2[0].load_state_dict(
        params.text_model.encoder.layers[i].layer_norm2.state_dict())

    #mlp第一层fc
    encoder[i + 1].s2[1].load_state_dict(
        params.text_model.encoder.layers[i].mlp.fc1.state_dict())

    #mlp第二层fc
    encoder[i + 1].s3.load_state_dict(
        params.text_model.encoder.layers[i].mlp.fc2.state_dict())

# 输出norm
encoder[13].load_state_dict(params.text_model.final_layer_norm.state_dict())

<All keys matched successfully>

In [10]:
a = encoder(torch.arange(77).unsqueeze(dim=0))
b = params(torch.arange(77).unsqueeze(dim=0)).last_hidden_state
print(a)
print(b)
(a == b).all()

tensor([[[-0.3488,  0.0139, -0.0409,  ..., -0.4707, -0.2910,  0.0627],
         [ 0.6009, -0.4915,  1.0705,  ...,  0.0032,  0.5970, -0.4605],
         [ 0.5848, -1.8402,  0.6390,  ...,  0.3736,  0.1611,  1.0529],
         ...,
         [ 0.7383, -0.1099,  1.2613,  ...,  0.2626, -0.2641,  0.3401],
         [ 1.1845, -0.1865,  1.5217,  ...,  0.2758,  0.1133,  0.1809],
         [ 0.9668, -0.5271,  1.4090,  ..., -0.0710,  0.1474, -0.2603]]],
       grad_fn=<NativeLayerNormBackward0>)
tensor([[[-0.3488,  0.0139, -0.0409,  ..., -0.4707, -0.2910,  0.0627],
         [ 0.6009, -0.4915,  1.0705,  ...,  0.0032,  0.5970, -0.4605],
         [ 0.5848, -1.8402,  0.6390,  ...,  0.3736,  0.1611,  1.0529],
         ...,
         [ 0.7383, -0.1099,  1.2613,  ...,  0.2626, -0.2641,  0.3401],
         [ 1.1845, -0.1865,  1.5217,  ...,  0.2758,  0.1133,  0.1809],
         [ 0.9668, -0.5271,  1.4090,  ..., -0.0710,  0.1474, -0.2603]]],
       grad_fn=<NativeLayerNormBackward0>)


tensor(True)