In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$

In [2]:
class ScaledDotProductAttention(nn.Module):
    '''
    Scaled Dot Product Attention
    attention(Q,K,V) = softmax((Q*K^T)/sqrt(d_k))*V
    '''
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, query, key, value, mask=None):
        '''

        :param query: 查询向量，[batch_size, num_heads, seq_len, d_model]
        :param key:  键值向量， [batch_size, num_heads, seq_len, d_model]
        :param value: 值向量， [batch_size, num_heads, seq_len, d_model]
        :param mask [[1,0,0],[1,1,0], [1,1,1]]
        :return:
        '''
        d_k= key.size()[-1]
        # scores [batch_size, seq_len, seq_len]
        scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))

        #mask:give a A negative epsilon
        if mask is not None:
            scores = scores.masked_fill(mask==0, -1e9)
        scores = torch.softmax(scores, dim=-1)
        # output [batch_size, seq_len, d_model]  broadcast
        output = torch.matmul(scores, value)

        return output, scores

In [3]:
sdpa = ScaledDotProductAttention()

batch_size, seq_len, d_model = 16, 10, 768

query = torch.randn(batch_size, seq_len, d_model)
key = torch.randn(batch_size, seq_len, d_model)
value = torch.randn(batch_size, seq_len, d_model)

output, scores = sdpa(query, key, value)

print(f"output size is {output.size()}")
print(f"scores size is {scores.size()}")
print(f"score size is {scores}")

output size is torch.Size([16, 10, 768])
scores size is torch.Size([16, 10, 10])
score size is tensor([[[0.0650, 0.0439, 0.1345,  ..., 0.0685, 0.0072, 0.0500],
         [0.1130, 0.0637, 0.0620,  ..., 0.0549, 0.1023, 0.1027],
         [0.0322, 0.0253, 0.3055,  ..., 0.2461, 0.0118, 0.1107],
         ...,
         [0.1474, 0.0704, 0.0691,  ..., 0.0969, 0.0719, 0.0782],
         [0.6037, 0.0083, 0.0951,  ..., 0.0349, 0.0395, 0.0488],
         [0.0390, 0.0444, 0.5200,  ..., 0.0556, 0.1627, 0.0166]],

        [[0.1081, 0.0291, 0.2011,  ..., 0.0379, 0.1862, 0.0156],
         [0.1251, 0.0318, 0.0303,  ..., 0.0689, 0.0555, 0.1269],
         [0.1809, 0.0568, 0.0204,  ..., 0.0760, 0.0344, 0.0459],
         ...,
         [0.0044, 0.0211, 0.0781,  ..., 0.1824, 0.0252, 0.0061],
         [0.0128, 0.0626, 0.1920,  ..., 0.0288, 0.1591, 0.0170],
         [0.0903, 0.0451, 0.2762,  ..., 0.0544, 0.0705, 0.0963]],

        [[0.0161, 0.0658, 0.0727,  ..., 0.0330, 0.3254, 0.2031],
         [0.0089, 0.0446, 0.

In [4]:
sum(scores[0][0])

tensor(1.)

Implement of MASK：

In [5]:
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)
mask.size()

torch.Size([1, 10, 10])

In [6]:
output, scores = sdpa(query, key, value, mask)

print(f"output size is {output.size()}")
print(f"scores size is {scores.size()}")
print(f"score size is {scores}")

output size is torch.Size([16, 10, 768])
scores size is torch.Size([16, 10, 10])
score size is tensor([[[1.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.6394, 0.3606, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0887, 0.0697, 0.8416,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.1735, 0.0828, 0.0813,  ..., 0.1140, 0.0000, 0.0000],
         [0.6346, 0.0087, 0.1000,  ..., 0.0367, 0.0416, 0.0000],
         [0.0390, 0.0444, 0.5200,  ..., 0.0556, 0.1627, 0.0166]],

        [[1.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.7974, 0.2026, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.7007, 0.2202, 0.0791,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0045, 0.0218, 0.0806,  ..., 0.1883, 0.0000, 0.0000],
         [0.0131, 0.0636, 0.1953,  ..., 0.0293, 0.1619, 0.0000],
         [0.0903, 0.0451, 0.2762,  ..., 0.0544, 0.0705, 0.0963]],

        [[1.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.1668, 0.8332, 0.

Multiheadattention：split the tenser and merge and compute attention

In [7]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        self.d_model = d_model
        self.num_heads = num_heads

        super(MultiHeadAttention, self).__init__()
        assert self.d_model%self.num_heads==0, "d_model must be divisible by num_heads"

        self.head_dim = self.d_model//self.num_heads

        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)

        self.out_proj = nn.Linear(d_model, d_model)
        self.attention = ScaledDotProductAttention()

    def split_heads(self, x):
        '''

        :param x: [batch_size, seq_len, d_model]
        :return:
            [batch_size, num_heads, seq_len, head_dim]
        '''
        batch_size, seq_len, d_model = x.size()
        assert  d_model==self.head_dim*self.num_heads, f"input must in dim {self.num_heads*self.head_dim} but input dim is {d_model}"

        return x.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1,2)
    def combine_heads(self,x):
        '''

        :param x:  [batch_size, num_heads, seq_len, head_dim]
        :return:  [batch_size, seq_len, num_heads*head_dim]
        '''
        batch_size, num_heads, seq_len, head_dim = x.size()

        return x.transpose(1,2).contiguous().view(batch_size, seq_len,num_heads*head_dim)

    def forward(self, x, mask=None):

        query = self.query_proj(x)
        key = self.key_proj(x)
        value = self.value_proj(x)

        splited_query = self.split_heads(query)
        splited_key = self.split_heads(key)
        splited_value = self.split_heads(value)

        output, scores =self.attention(splited_query, splited_key, splited_value, mask)
        output = self.combine_heads(output)

        return self.out_proj(output), scores

Apply mask on multi head attention

In [8]:
class MaskedMultiheadAttention(MultiHeadAttention):
    def __init__(self, d_model, num_heads):
        super(MaskedMultiheadAttention, self).__init__(d_model, num_heads)

    def forward(self, x):
        seq_len = x.size()[1]
        mask = torch.tril(torch.ones(1, seq_len,seq_len))
        return super().forward(x, mask)

In [9]:
batch_size, seq_len, d_model = 16, 10, 768
x = torch.randn(batch_size, seq_len, d_model)

mmha = MaskedMultiheadAttention(d_model, num_heads=12)
output, socres = mmha(x)

print(f"output size is {output.size()}")
print(f"scores size is {scores.size()}")
print(f"score size is {scores}")

output size is torch.Size([16, 10, 768])
scores size is torch.Size([16, 10, 10])
score size is tensor([[[1.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.6394, 0.3606, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0887, 0.0697, 0.8416,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.1735, 0.0828, 0.0813,  ..., 0.1140, 0.0000, 0.0000],
         [0.6346, 0.0087, 0.1000,  ..., 0.0367, 0.0416, 0.0000],
         [0.0390, 0.0444, 0.5200,  ..., 0.0556, 0.1627, 0.0166]],

        [[1.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.7974, 0.2026, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.7007, 0.2202, 0.0791,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0045, 0.0218, 0.0806,  ..., 0.1883, 0.0000, 0.0000],
         [0.0131, 0.0636, 0.1953,  ..., 0.0293, 0.1619, 0.0000],
         [0.0903, 0.0451, 0.2762,  ..., 0.0544, 0.0705, 0.0963]],

        [[1.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.1668, 0.8332, 0.

Feedforward：

In [None]:
import torch
import torch.nn as nn

class FeedForwardNeuralNetwork(nn.Module):
    def __init__(self, d_model, d_ff):
        super(FeedForwardNeuralNetwork, self).__init__()
        # laynorm
        self.layer_norm = nn.LayerNorm(d_model)
        # proj
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        #激活函数
        self.activation = nn.GELU()

    def forward(self, x):
        '''

        :param x: [batch_szie, seq_len, hidden_size]
        :return:
        '''
        resiual = x
        output = self.layer_norm(x) #[batch_szie, seq_len, hidden_size]
        output = self.linear1(output) #[batch_szie, seq_len, hidden_size*4]
        output = self.activation(output)#[batch_szie, seq_len, hidden_size*4]
        output = self.linear2(output)#[batch_szie, seq_len, hidden_size]

        return resiual + output

In [13]:
batch_size, seq_len, hidden_size = 16, 10, 768

x = torch.randn(batch_size, seq_len, hidden_size)

ffn = FeedForwardNeuralNetwork(768, 768*4)

output = ffn(x)

print(f"x size is {x.size()}")
print(f"output size is {output.size()}")
print(f"output is {output[0]}")

x size is torch.Size([16, 10, 768])
output size is torch.Size([16, 10, 768])
output is tensor([[-0.1202,  1.5371, -0.2498,  ...,  0.7882,  0.6324,  0.8899],
        [-0.3706, -0.5207,  0.5701,  ...,  1.2478, -1.3609,  0.8454],
        [-1.6795, -0.8868,  0.8596,  ...,  0.4456,  0.9892, -0.0972],
        ...,
        [ 0.0592, -1.9784, -0.6724,  ...,  1.0572,  1.5154,  1.2274],
        [ 0.3856,  0.3645,  0.9655,  ..., -0.1415, -1.2560,  0.9342],
        [ 0.0223,  0.8291, -0.2557,  ..., -1.3377,  0.1309, -0.4588]],
       grad_fn=<SelectBackward0>)


Decoderblock: Feedforward+Multiheadattention

In [14]:
class TransformerDecoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super(TransformerDecoderBlock, self).__init__()
        #核心模块
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.feedForward = FeedForwardNeuralNetwork(d_model, d_ff)
        # 后层归一化
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x, attn_mask = None):
        attn_output, attn_weights = self.attention(x, attn_mask)

        ff_output = self.feedForward(x+attn_output)

        output = self.layer_norm(ff_output)

        return output, attn_weights

In [15]:
batch_size, seq_len, hidden_size = 16, 10, 768

x = torch.randn(batch_size, seq_len, hidden_size)

tdb = TransformerDecoderBlock(hidden_size, 12, hidden_size*4)

output, attn_weights = tdb(x)
print(f"x size is {x.size()}")
print(f"output size is {output.size()}")
print(f"output is {output[0]}")

x size is torch.Size([16, 10, 768])
output size is torch.Size([16, 10, 768])
output is tensor([[ 1.0840, -0.0034,  1.2365,  ...,  0.7578,  1.6233, -0.8166],
        [-0.2595, -0.8290,  1.1105,  ..., -0.9751,  1.2091,  0.8741],
        [-1.7811,  0.0750, -0.6566,  ...,  1.0701,  0.0884, -1.8810],
        ...,
        [-0.2670, -0.5450, -0.4385,  ...,  1.3385,  0.6984,  0.4419],
        [-0.8112, -2.5580,  1.4085,  ..., -0.2154,  1.0498,  0.1947],
        [-1.3156,  0.1104, -0.0278,  ..., -0.6975,  0.6842, -0.2235]],
       grad_fn=<SelectBackward0>)


Multilayer Decoder:

In [16]:
import math

class PositionalEncoding(nn.Module):
    """位置编码模块（支持动态序列长度）"""
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))  # [1, max_len, d_model]

    def forward(self, x):
        # 动态获取位置编码
        position_emb = self.pe[:, :x.size(1)]
        return x + position_emb  # [batch, seq_len, d_model]

Transfomer Decoder:
1. token embedding
2. position embedding
3. Multihead attention
4. Feed forward Nural network
5. attention +FFN ：decoder block
6. multi-layer-decoder
7. laynorm/activation/residual connection

In [17]:
# Transformer实现

class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, d_model, max_len, num_layers, num_heads, d_ff):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len)

        #堆叠Decoder块
        self.layers = nn.ModuleList([
            TransformerDecoderBlock(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])
        self.final_norm = nn.LayerNorm(d_model)
        self.output_layer = nn.Linear(d_model, vocab_size, bias=False)

        # tied embeddings
        self.output_layer.weight = self.token_embedding.weight

        self.init_weights()
        pass
    def init_weights(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)

        #各层做一下初始化
        for layer in self.layers:
            nn.init.xavier_normal(layer.attention.query_proj.weight)
            nn.init.xavier_normal(layer.attention.key_proj.weight)
            nn.init.xavier_normal(layer.attention.value_proj.weight)

            nn.init.kaiming_normal(layer.feedForward.linear1.weight)
            nn.init.kaiming_uniform(layer.feedForward.linear2.weight)

    def create_causal_mask(self, seq_len):
        mask = torch.tril(torch.ones(seq_len,seq_len))
        return mask

    def forward(self, input_ids):
        '''

        :param x:[batch_size, seq_len]
        :return:
        '''
        batch_size, seq_len = input_ids.size()

        #嵌入
        embeddings = self.token_embedding(input_ids) #[batch_size, seq_len, d_model]
        pos_embedding = self.pos_encoder(embeddings)

        embeddings = embeddings + pos_embedding

        mask = self.create_causal_mask(seq_len)
        # 通过所有的Transformer Decoder block
        hidden_states = embeddings
        all_attn_weights = []
        for layer in self.layers:
            hidden_states, attn_weights = layer(hidden_states, mask)
            all_attn_weights.append(attn_weights)

        hidden_states = self.final_norm(hidden_states)

        logits = self.output_layer(hidden_states)

        return logits, all_attn_weights

In [18]:
model = TransformerDecoder(
    vocab_size=500, d_model=256, max_len=128, num_layers=12, num_heads=8, d_ff=256*4)

  nn.init.xavier_normal(layer.attention.query_proj.weight)
  nn.init.xavier_normal(layer.attention.key_proj.weight)
  nn.init.xavier_normal(layer.attention.value_proj.weight)
  nn.init.kaiming_normal(layer.feedForward.linear1.weight)
  nn.init.kaiming_uniform(layer.feedForward.linear2.weight)


In [19]:
batch_size, seq_len = 16,10

input_ids = torch.randint(0,500,(batch_size, seq_len))

print(f"input size is {input_ids.size()}")
print(f"input is {input_ids}")

output, _ = model(input_ids)

print(f"output size is {output.size()}")
print(f"output is {output[-1]}")

input size is torch.Size([16, 10])
input is tensor([[ 58, 347, 344,  97, 455, 301, 159, 179, 333, 294],
        [470, 389, 478, 340, 150, 375, 448, 131, 492, 111],
        [217,  46, 177,  63, 434, 256, 431, 273, 296, 401],
        [274, 156, 474, 452,  64, 287, 136, 395,  16,  49],
        [171, 484, 118,   6, 145, 463, 405, 178, 424,  35],
        [355, 302,  75, 483, 487,  80, 394, 378, 437, 261],
        [213, 349, 448, 296, 363, 157, 260, 237, 169, 272],
        [232,  25, 351, 485, 236, 251, 412, 428, 174, 421],
        [454, 284, 465, 228, 358,  14, 293, 190, 422, 487],
        [157, 343, 167, 336, 387,  40,   1, 110, 193, 442],
        [ 91, 421, 249, 314,  51, 224, 475, 371, 424,  54],
        [152, 242, 360, 339, 400, 210, 489, 441, 242, 285],
        [131, 472,  96, 462, 141, 258, 450,  78, 174, 484],
        [264, 134,  81,   6,  26, 368, 475,  28, 269, 315],
        [411, 321, 207,  42, 237, 361,  31, 397, 155, 363],
        [131, 253, 399,  70, 110,  86, 256, 286, 474, 48