In [1]:
import torch
import torch.nn as nn
import pandas as pd
import numpy

In [None]:
def transpose_qkv(X, num_heads):
    """Reshape projected queries/keys/values so all heads are computed in parallel.

    Args:
        X: tensor of shape (batch_size, num_steps, num_hiddens), where
            num_steps is the number of queries or of key-value pairs.
        num_heads: number of attention heads; must divide num_hiddens.

    Returns:
        Tensor of shape (batch_size * num_heads, num_steps, num_hiddens / num_heads).
        Rows belonging to one batch element are contiguous along axis 0,
        ordered by head index.
    """
    # (batch_size, num_steps, num_hiddens)
    #   -> (batch_size, num_steps, num_heads, num_hiddens / num_heads)
    # The step axis has to stay a separate dimension: the permute below
    # addresses four axes and fails on a 3-D tensor.
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # -> (batch_size, num_heads, num_steps, num_hiddens / num_heads)
    X = X.permute(0, 2, 1, 3)

    # Fold the head axis into the batch axis:
    # -> (batch_size * num_heads, num_steps, num_hiddens / num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])

def transpose_output(X, num_heads):
    """Undo the head/batch folding: merge the per-head slices back together.

    Takes a 3-D tensor whose leading axis is (batch * num_heads) and returns
    a tensor of shape (batch, X.shape[1], num_heads * X.shape[2]).
    """
    steps, head_dim = X.shape[1], X.shape[2]
    # Recover an explicit head axis: (batch, num_heads, steps, head_dim).
    per_head = X.reshape(-1, num_heads, steps, head_dim)
    # Put the head axis next to the feature axis: (batch, steps, num_heads, head_dim).
    by_step = per_head.permute(0, 2, 1, 3)
    # Concatenate the heads along the last axis.
    return by_step.reshape(by_step.shape[0], by_step.shape[1], -1)

In [None]:
class Attention(nn.Module):
    """Scaled dot-product attention with dropout applied to the attention weights.

    Args:
        dropout: drop probability passed to ``nn.Dropout``.
    """

    def __init__(self, dropout) -> None:
        super(Attention, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value):
        """Compute ``softmax(query @ key^T / sqrt(d)) @ value``.

        Args:
            query: tensor of shape (..., num_queries, d).
            key: tensor of shape (..., num_kv, d).
            value: tensor of shape (..., num_kv, value_dim).

        Returns:
            Tensor of shape (..., num_queries, value_dim).
        """
        # Scale the logits by sqrt(d) before the softmax so their magnitude
        # does not grow with the feature size.
        scale = query.shape[-1] ** 0.5
        # Swap only the last two axes so leading batch axes are preserved.
        scores = torch.matmul(query, key.transpose(-2, -1)) / scale
        # Normalise over the key axis: each query's weights sum to 1.
        weights = torch.softmax(scores, dim=-1)
        return torch.matmul(self.dropout(weights), value)

class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are each projected to ``numHiddens`` features,
    split into ``numHeads`` equal slices along the feature axis, attended
    independently per head, concatenated, and passed through ``wo``.

    Args:
        querySize: feature size of the incoming queries.
        keySize: feature size of the incoming keys.
        valueSize: feature size of the incoming values.
        numHiddens: total projected feature size (shared by all heads).
        numHeads: number of heads; must divide ``numHiddens``.
        dropout: drop probability applied to the attention weights.
        bias: whether the four linear projections carry a bias term.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, querySize, keySize, valueSize, numHiddens, numHeads, dropout, bias=False, **kwargs) -> None:
        super(MultiheadAttention, self).__init__(**kwargs)
        self.num_heads = numHeads
        # Dropout on the attention weights; holds no parameters.
        self.dropout = nn.Dropout(p=dropout)
        self.wq = nn.Linear(in_features=querySize, out_features=numHiddens, bias=bias)
        self.wk = nn.Linear(in_features=keySize, out_features=numHiddens, bias=bias)
        self.wv = nn.Linear(in_features=valueSize, out_features=numHiddens, bias=bias)
        self.wo = nn.Linear(in_features=numHiddens, out_features=numHiddens, bias=bias)

    def _split_heads(self, X):
        """(batch, steps, hiddens) -> (batch, heads, steps, hiddens / heads)."""
        if X.shape[-1] % self.num_heads != 0:
            raise ValueError(
                f"feature size {X.shape[-1]} is not divisible by num_heads={self.num_heads}"
            )
        X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
        return X.permute(0, 2, 1, 3)

    def forward(self, queries, keys, values):
        """Attend ``queries`` over ``keys``/``values``.

        Args:
            queries: tensor of shape (batch, num_queries, querySize).
            keys: tensor of shape (batch, num_kv, keySize).
            values: tensor of shape (batch, num_kv, valueSize).

        Returns:
            Tensor of shape (batch, num_queries, numHiddens).

        Raises:
            ValueError: if the projected feature size is not divisible by
                the number of heads.
        """
        q = self._split_heads(self.wq(queries))
        k = self._split_heads(self.wk(keys))
        v = self._split_heads(self.wv(values))

        # Per-head scaled dot-product; the scale uses the per-head feature size.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
        weights = self.dropout(torch.softmax(scores, dim=-1))
        context = torch.matmul(weights, v)  # (batch, heads, num_queries, head_dim)

        # Move heads next to the feature axis and concatenate them.
        context = context.permute(0, 2, 1, 3)
        context = context.reshape(context.shape[0], context.shape[1], -1)
        return self.wo(context)
        

##### Encoder Structure Mimic

In [17]:
class Encoder(nn.Module):
    """A stack of ``numLayers`` EncoderBlock modules applied one after another.

    Args:
        inputSize: forwarded to each block as its ``vecLength`` keyword.
        hiddenDim: forwarded to each block as its ``hiddenDim`` keyword.
        numLayers: number of blocks in the stack.
    """

    def __init__(self, inputSize, hiddenDim, numLayers=6) -> None:
        super(Encoder, self).__init__()
        # self.embedding = self.token_embedding + self.posi_embedding + self.seg_embedding
        # NOTE(review): the blocks' `dimension` argument is not passed and so
        # keeps the block's default - confirm that is intended when hiddenDim
        # differs from that default.
        self.encoder = nn.ModuleList(
            [EncoderBlock(vecLength=inputSize, hiddenDim=hiddenDim) for _ in range(numLayers)]
        )

    def forward(self, matrix):
        """Feed ``matrix`` through every block in order and return the result."""
        # nn.ModuleList is only a container (it defines no forward), so the
        # layers have to be applied explicitly.
        for block in self.encoder:
            matrix = block(matrix)
        return matrix

    # Backward-compatible alias for the previous method name. nn.Module only
    # dispatches calls to `forward`, which is why the method was renamed.
    __forward__ = forward

# device='mps'
class EncoderBlock(nn.Module):
    """Single-head self-attention followed by a position-wise feed-forward net.

    The block has no positional encoding and no residual connections; one
    LayerNorm instance is reused after both sub-layers.

    Args:
        vecLength: input length of the sentence or sequence. Accepted for
            interface compatibility; no layer in the block depends on it.
        dimension: feature size of one input vector.
        hiddenDim: feature size of the query/key/value projections.
        bias: whether the Q/K/V projections carry a bias term.
    """

    # without positional encoding
    def __init__(self, vecLength=512, dimension=768, hiddenDim=768, bias=True) -> None:
        super(EncoderBlock, self).__init__()
        self.WQ = nn.Linear(dimension, hiddenDim, bias=bias)
        self.WV = nn.Linear(dimension, hiddenDim, bias=bias)
        self.WK = nn.Linear(dimension, hiddenDim, bias=bias)
        # self.WO = nn.Linear(length, length)
        self.d = hiddenDim  # feature size used for the attention scaling
        # NOTE(review): this norm is also applied to the feed-forward output,
        # whose last axis is `dimension`; that only works when
        # dimension == hiddenDim (true for the defaults).
        self.layerNorm = nn.LayerNorm(hiddenDim)
        self.feedForward = nn.ModuleList([
            nn.Linear(hiddenDim, 1024),
            nn.Linear(1024, dimension)
        ])

    def forward(self, matrix):
        """Run attention -> LayerNorm -> feed-forward -> LayerNorm.

        Args:
            matrix: tensor of shape (..., seq_len, dimension).

        Returns:
            Tensor of shape (..., seq_len, dimension).
        """
        # Keep the projections local: storing them on `self` would keep the
        # autograd graph of the last call alive.
        q = self.WQ(matrix)
        k = self.WK(matrix)
        v = self.WV(matrix)
        out = self.selfAttention(q, k, v)
        out = self.layerNorm(out)
        # nn.ModuleList is a plain container and cannot be called, so apply
        # the two linears explicitly. The ReLU between them keeps the pair
        # from collapsing into a single linear map.
        expand, project = self.feedForward[0], self.feedForward[1]
        out = project(torch.relu(expand(out)))
        out = self.layerNorm(out)
        return out

    # Backward-compatible alias for the previous method name. nn.Module only
    # dispatches calls to `forward`, which is why the method was renamed.
    __forward__ = forward

    def selfAttention(self, Q, K, V, scaling=None):
        """Scaled dot-product attention: ``softmax(Q @ K^T / scaling) @ V``.

        Args:
            Q: queries, shape (..., num_queries, d).
            K: keys, shape (..., num_kv, d).
            V: values, shape (..., num_kv, d_v).
            scaling: divisor for the logits; a falsy value (None or 0)
                selects ``sqrt(self.d)``.

        Returns:
            Tensor of shape (..., num_queries, d_v).
        """
        if not scaling:
            scaling = self.d ** 0.5
        # transpose(-2, -1) swaps only the last two axes, so batched (3-D+)
        # inputs work; `.T` would reverse every axis.
        queryKey = torch.matmul(Q, K.transpose(-2, -1)) / scaling
        # Scale the logits *before* the softmax and normalise over the key
        # axis so each query's weights sum to 1.
        attMatrix = torch.softmax(queryKey, dim=-1)
        return torch.matmul(attMatrix, V)

In [18]:
# Build the encoder stack; numLayers is left at its default of 6.
encoderTrans = Encoder(inputSize=512, hiddenDim=768)
# NOTE: `parameters` is referenced without being called, so this prints the
# bound-method repr (which embeds the module tree) rather than the
# parameter tensors themselves.
print(encoderTrans.parameters)

<bound method Module.parameters of Encoder(
  (encoder): ModuleList(
    (0-5): 6 x EncoderBlock(
      (WQ): Linear(in_features=768, out_features=768, bias=True)
      (WV): Linear(in_features=768, out_features=768, bias=True)
      (WK): Linear(in_features=768, out_features=768, bias=True)
      (layerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (feedForward): ModuleList(
        (0): Linear(in_features=768, out_features=1024, bias=True)
        (1): Linear(in_features=1024, out_features=768, bias=True)
      )
    )
  )
)>


##### KB

In [21]:
## How nn.Linear works
# a = nn.Linear(2, 3)
# b = torch.rand(1, 2, device='cpu')
# a(b)

## Matrix Multiplication
# torch.rand(1, 2) * torch.rand(2, 1)
# torch.matmul(torch.rand(1, 2), torch.rand(2, 1))


## Splitting the feature axis into heads
# The input has shape (batch_size, inputLen, vecLen); the reshape below
# appends a head axis, giving (batch_size, inputLen, num_heads, vecLen / num_heads).
num_heads, batch_size, inputLen, vecLen = 3, 2, 2, 3

weight_matrix = torch.rand((batch_size, inputLen, vecLen))
print(weight_matrix)

# Keep the two leading axes as they are and let -1 infer the per-head size.
matrix = weight_matrix.reshape((*weight_matrix.shape[:2], num_heads, -1))
print(matrix)

tensor([[[0.8798, 0.3458, 0.6803],
         [0.8010, 0.9066, 0.7542]],

        [[0.7083, 0.1414, 0.2849],
         [0.7246, 0.5032, 0.6088]]])
tensor([[[[0.8798],
          [0.3458],
          [0.6803]],

         [[0.8010],
          [0.9066],
          [0.7542]]],


        [[[0.7083],
          [0.1414],
          [0.2849]],

         [[0.7246],
          [0.5032],
          [0.6088]]]])
