In [56]:
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
from typing import Optional, Any, Union, Callable
from torch import Tensor

In [122]:
input = torch.randn(32, 150, 1024)
conv1_layer = nn.Conv1d(in_channels=1024, out_channels=1024, kernel_size=1)
linear_layer = nn.Linear(1024, 1024)

In [123]:
input = conv1_layer(input[:,:64,:].permute(0,2,1)).permute(0,2,1)

In [124]:
input.shape

torch.Size([32, 64, 1024])

In [111]:
input.shape
output.shape
assert input.shape == output.shape == input.shape

In [37]:
for p in conv1_layer.parameters():
    print(p.shape)

torch.Size([1024, 1024, 1])
torch.Size([1024])


In [38]:
for p in linear_layer.parameters():
    print(p.shape)

torch.Size([1024, 1024])
torch.Size([1024])


In [42]:
key_embedders = nn.ModuleList([nn.Conv1d(in_channels=1024, out_channels=1024, kernel_size=1) for i in range(12)])
value_embedders = nn.ModuleList([nn.Conv1d(in_channels=1024, out_channels=1024, kernel_size=1) for i in range(12)])

In [48]:
for p in key_embedders.parameters():
    print(p.shape)

torch.Size([1024, 1024, 1])
torch.Size([1024])
torch.Size([1024, 1024, 1])
torch.Size([1024])
torch.Size([1024, 1024, 1])
torch.Size([1024])
torch.Size([1024, 1024, 1])
torch.Size([1024])
torch.Size([1024, 1024, 1])
torch.Size([1024])
torch.Size([1024, 1024, 1])
torch.Size([1024])
torch.Size([1024, 1024, 1])
torch.Size([1024])
torch.Size([1024, 1024, 1])
torch.Size([1024])
torch.Size([1024, 1024, 1])
torch.Size([1024])
torch.Size([1024, 1024, 1])
torch.Size([1024])
torch.Size([1024, 1024, 1])
torch.Size([1024])
torch.Size([1024, 1024, 1])
torch.Size([1024])


In [57]:
class prefix_TransformerEncoderLayer(nn.TransformerEncoderLayer):
    def __init__(
        self, 
        d_model: int, 
        nhead: int, 
        dim_feedforward: int = 2048, 
        dropout: float = 0.1, 
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, 
        layer_norm_eps: float = 0.00001, 
        batch_first: bool = False, 
        norm_first: bool = False, 
        device=None, 
        dtype=None,
        ) -> None:
        super().__init__(
            d_model, 
            nhead, 
            dim_feedforward, 
            dropout, 
            activation, 
            layer_norm_eps, 
            batch_first, 
            norm_first, 
            device, 
            dtype
        )
        self.d_model = d_model

        # 可能需要独立在 其之外： 因为需要frozen 整个LM的参数 还需要加载
        # self.key_prompt_embed = nn.Linear(d_model, d_model)
        # self.value_promt_embed = nn.Linear(d_model, d_model)

    def forward(
            self, 
            src: Tensor,
            key_prompt: Tensor,
            value_prompt: Tensor,
            src_mask: Optional[Tensor] = None,
            src_key_padding_mask: Optional[Tensor] = None
        ) -> Tensor:
        """
            Customized Encoder Layer.
            key_prompt: Prefix-prompt key. It should have the same size as "src"
            value_prompt: Prefix-prompt value. It should have the same size as "src"
            For e.g.:
                The original seq length is T; the prefix length is T'.
                Concaten seq length should be T'+T, while the [B, :T', dim] is promptable, which
                means, controled by params.
                And in each layer of Transformer, the QKV is shaped like [B, T'+T, dim], so the 
                prefix part should also be promptable.

            If key_prompt == value_prompt == x, this equals to a standard TransformerEncoderLayer.
        """
        x = src
        assert x.shape == key_prompt.shape == value_prompt.shape, f"Q {x.shape}, K {key_prompt.shape}, V {value_prompt.shape} should have the same size"
        if self.norm_first:
            # x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
            x = x + self._sa_block_prompt(
                x = self.norm1(x), 
                k_prompt = self.norm1(key_prompt), 
                v_prompt = self.norm1(value_prompt), 
                attn_mask = src_mask, 
                key_padding_mask = src_key_padding_mask
            ) 
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block_prompt(x, key_prompt, value_prompt, src_mask, src_key_padding_mask))
            x = self.norm2(x + self._ff_block(x))
        return x

    def _sa_block_prompt(
            self,
            x: Tensor,
            k_prompt: Tensor,
            v_prompt: Tensor,
            attn_mask: Optional[Tensor], 
            key_padding_mask: Optional[Tensor]
        ) -> Tensor:
        """ 
        Customized Self-attention Block for prefix prompt tuning.
        Ref: Prefix-tuning https://arxiv.org/abs/2101.00190
        """
        x = self.self_attn(query=x, key=k_prompt, value=v_prompt,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False)[0]
        return self.dropout1(x)

In [58]:
a = prefix_TransformerEncoderLayer(d_model=1024, nhead=16)

In [59]:
a.d_model

1024

In [79]:
input = torch.randn(3,3, requires_grad=True)

In [80]:
input

tensor([[-1.7913,  0.7482,  0.4389],
        [ 2.0273,  0.2695, -0.3742],
        [ 0.0391,  0.1879, -2.4662]], requires_grad=True)

In [90]:
input_2 = input.clone()
# input_2[0,0] = 0

In [102]:
input

tensor([[-1.7913,  0.7482,  0.4389],
        [ 2.0273,  0.2695, -0.3742],
        [ 0.0391,  0.1879, -2.4662]], requires_grad=True)

In [103]:
input_2

tensor([[ 0.2559,  0.1615, -1.3387],
        [ 2.0273,  0.2695, -0.3742],
        [ 0.0391,  0.1879, -2.4662]], grad_fn=<CopySlices>)

In [104]:
linear_simple = nn.Linear(3,3)

In [105]:
prefix = linear_simple(input_2[:1,:])

In [106]:
prefix

tensor([[-0.4677, -0.4245, -0.2293]], grad_fn=<AddmmBackward0>)

In [107]:
input_2[:1,:] = prefix

In [108]:
input_2

tensor([[-0.4677, -0.4245, -0.2293],
        [ 2.0273,  0.2695, -0.3742],
        [ 0.0391,  0.1879, -2.4662]], grad_fn=<CopySlices>)