In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

埋め込み層

In [None]:
class PositionEmbedding(nn.Module):
    def __init__(self, context_size , d_model):
        super(PositionEmbedding, self).__init__()
        self.embedding = nn.Embedding(context_size, d_model)

    def forward(self, x):
        positions = torch.arange(0, x.size(1), device=x.device)
        return self.embedding(positions)

In [None]:
x = torch.LongTensor([[4545, 8410, 458, 3]])
position_embedding = PositionEmbedding(4, 256)
wpe = position_embedding(x)
wpe

tensor([[-1.1679,  1.2869, -0.9309,  ...,  0.4887,  1.3568, -0.7703],
        [ 0.2048,  0.1246, -1.9271,  ...,  0.0097, -1.1410,  1.0631],
        [ 1.0481,  0.5954,  0.4780,  ...,  0.4574,  0.0223, -0.8736],
        [ 1.7000,  0.6803, -0.1919,  ..., -0.1198, -0.5429, -0.6642]],
       grad_fn=<EmbeddingBackward0>)

Transformer

In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, context_size, d_model):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(context_size, d_model)

        for pos in range(context_size):
            for i in range(0, d_model, 2):
                pe[pos,i]   = math.sin(pos/(10000**((2*i)/d_model)))
                pe[pos,i+1] = math.cos(pos/(10000**((2*i)/d_model)))

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return self.pe[:, :x.size(1)].detach()

In [None]:
x = torch.LongTensor([[4545, 8410, 458, 3]]) ##学習
positional_encoding = PositionalEncoding(4, 256)
wpe = positional_encoding(x)
wpe

tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
           0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  7.6172e-01,  ...,  1.0000e+00,
           1.1548e-08,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  9.8705e-01,  ...,  1.0000e+00,
           2.3096e-08,  1.0000e+00],
         [ 1.4112e-01, -9.8999e-01,  5.1731e-01,  ...,  1.0000e+00,
           3.4643e-08,  1.0000e+00]]])

In [None]:
def create_attention_mask(context_size):
    mask = torch.ones((context_size, context_size))
    mask = torch.triu(mask, diagonal=1)
    mask = mask == 0
    mask = mask * 1

    return mask

In [None]:
create_attention_mask(10)

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

In [None]:
class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, d_model, dropout_rate=0.1):
        super().__init__()
        self.sqrt_d_k = d_model ** 0.5
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, q, k, v, mask=None):
        score = torch.matmul(q, k.transpose(2, 3)) /  self.sqrt_d_k

        if mask is not None:
            score = score.masked_fill(mask == 0, float("-inf"))

        attn = F.softmax(score, dim=-1)
        attn = self.dropout(attn)
        output = torch.matmul(attn, v)

        return output, attn

In [None]:
print(math.sqrt(256))
print(256 ** 0.5)

16.0
16.0


In [None]:
context_size = 3
dim = 4
q = torch.randn(context_size, dim)
k = torch.randn(context_size, dim)
v = torch.randn(context_size, dim)
a = q @ k.T
print(q)
print(k)
print(v)
print(a)

tensor([[-0.9215, -0.0370,  1.5726, -0.5920],
        [ 1.0876,  0.3483, -1.5761, -1.3926],
        [ 0.6543, -0.0932, -2.3670, -1.5800]])
tensor([[ 1.4320, -0.1837,  0.5067, -0.5297],
        [ 0.6984, -1.6625,  1.1740,  0.2864],
        [-1.2627, -1.2706, -1.5757,  0.1019]])
tensor([[-0.6995,  2.1653,  0.8612, -0.0959],
        [ 1.0172,  0.8018, -1.1667,  2.4077],
        [ 0.3414,  0.0369,  0.6523,  1.5866]])
tensor([[-0.2024,  1.0945, -1.3278],
        [ 1.4324, -2.0687,  0.5257],
        [ 0.5916, -2.6194,  2.8608]])


In [None]:
a = a / (dim ** 0.5)
a

tensor([[-0.1012,  0.5473, -0.6639],
        [ 0.7162, -1.0344,  0.2629],
        [ 0.2958, -1.3097,  1.4304]])

In [None]:
mask = create_attention_mask(context_size)
mask
a = a * mask
a

tensor([[-0.1012,  0.0000, -0.0000],
        [ 0.7162, -1.0344,  0.0000],
        [ 0.2958, -1.3097,  1.4304]])

In [None]:
attn = a.masked_fill(mask == 0, float("-inf"))
attn

tensor([[-0.1012,    -inf,    -inf],
        [ 0.7162, -1.0344,    -inf],
        [ 0.2958, -1.3097,  1.4304]])

In [None]:
attn = F.softmax(attn, dim=-1)
attn

tensor([[1.0000, 0.0000, 0.0000],
        [0.8520, 0.1480, 0.0000],
        [0.2320, 0.0466, 0.7214]])

In [None]:
attn.shape

torch.Size([3, 3])

In [None]:
v.shape
v

tensor([[-0.6995,  2.1653,  0.8612, -0.0959],
        [ 1.0172,  0.8018, -1.1667,  2.4077],
        [ 0.3414,  0.0369,  0.6523,  1.5866]])

In [None]:
torch.matmul(attn, v)

tensor([[-0.6995,  2.1653,  0.8612, -0.0959],
        [-0.4454,  1.9635,  0.5611,  0.2746],
        [ 0.1314,  0.5663,  0.6160,  1.2345]])

Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_model, dropout_rate=0.1):
        super().__init__()
        self.n_head = n_head
        self.d_model = d_model
        self.fc_q = nn.Linear(d_model, d_model)
        self.fc_k = nn.Linear(d_model, d_model)
        self.fc_v = nn.Linear(d_model, d_model)
        self.attn = ScaledDotProductAttention(d_model, dropout_rate)
        self.fc = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_rate)

        ## 全結合層を初期化
        nn.init.xavier_uniform_(self.fc_q.weight)
        nn.init.xavier_uniform_(self.fc_k.weight)
        nn.init.xavier_uniform_(self.fc_v.weight)
        nn.init.xavier_uniform_(self.fc.weight)

    def forward(self, q, k, v, mask=None):
        N = q.size(0)
        S = q.size(1)
        H = self.n_head
        D = self.d_model // self.n_head


        # 線形変換
        q = self.fc_q(q)
        k = self.fc_k(k)
        v = self.fc_v(v)

        # 展開
        q = q.view(N, S, H, D)
        k = k.view(N, S, H, D)
        v = v.view(N, S, H, D)

        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        x, attn = self.attn(q, k, v, mask=mask)
        x = x.transpose(1, 2).contiguous().view(N, S, -1)
        x = self.fc(x)
        x = self.dropout(x)

        return x, attn

In [None]:
a = torch.randint(10,(128,15,8,256))
print(a.shape)
a.transpose(1,2).shape

torch.Size([128, 15, 8, 256])


torch.Size([128, 8, 15, 256])

In [None]:
n_head = 8
d_model = 16
attention = MultiHeadAttention(n_head, d_model)
attention

MultiHeadAttention(
  (fc_q): Linear(in_features=16, out_features=16, bias=True)
  (fc_k): Linear(in_features=16, out_features=16, bias=True)
  (fc_v): Linear(in_features=16, out_features=16, bias=True)
  (attn): ScaledDotProductAttention(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (fc): Linear(in_features=16, out_features=16, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [None]:
batch_size = 1
context_size = 10
x = torch.randn(batch_size, context_size, d_model)
q, w = attention(x, x, x)

In [None]:
q.shape

torch.Size([1, 10, 16])

In [None]:
w.shape

torch.Size([1, 8, 10, 10])

In [None]:
class FeedForward(nn.Module):
    def __init__(self, d_model, dropout_rate=0.1):
        super(FeedForward, self).__init__()

        self.fc1 = nn.Linear(d_model, d_model * 4 )
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(d_model * 4, d_model)

        ## 全結合層を初期化
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        h = self.fc1(x)
        h = F.gelu(h)
        h = self.fc2(h)
        h = self.dropout(h)
        return h

In [None]:
d_model = 2
ff = FeedForward(d_model)
print(ff)
x = torch.randn(d_model)
print(x)
ff(x)

FeedForward(
  (fc1): Linear(in_features=2, out_features=8, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (fc2): Linear(in_features=8, out_features=2, bias=True)
)
tensor([-1.2393, -1.7066])


tensor([-0.3460,  0.3750], grad_fn=<MulBackward0>)

In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_head, dropout_rate=0.1):
        super(TransformerBlock, self).__init__()
        self.norm_1 = nn.LayerNorm(d_model)
        self.norm_2 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(n_head, d_model, dropout_rate)
        self.ff = FeedForward(d_model)

        nn.init.normal_(self.norm_1.weight, mean=0, std=0.02)
        nn.init.normal_(self.norm_2.weight, mean=0, std=0.02)

    def forward(self, x, mask=None):
        rx = x 
        x, w = self.attn(x, x, x, mask)
        x = self.norm_1(x + rx)

        rx = x
        x = self.ff(x)
        x = self.norm_2(x + rx)

        return x, w

In [None]:
d_model = 2
n_head = 1
block = TransformerBlock(d_model, n_head)
batch_size = 1
context_size = 5

x = torch.randn(batch_size, context_size, d_model)
y, w = block(x)
print(x)
print(y.shape)
print(y)
print(w.shape)
print(w)

tensor([[[-0.7771,  1.0819],
         [ 0.4459, -0.7646],
         [ 0.3905,  0.8654],
         [-1.0705,  0.7716],
         [-0.2005, -0.2011]]])
torch.Size([1, 5, 2])
tensor([[[ 0.0163, -0.0240],
         [ 0.0163, -0.0239],
         [ 0.0163, -0.0239],
         [ 0.0163, -0.0239],
         [ 0.0163, -0.0239]]], grad_fn=<NativeLayerNormBackward0>)
torch.Size([1, 1, 5, 5])
tensor([[[[0.2312, 0.2105, 0.2130, 0.2356, 0.2208],
          [0.2304, 0.2123, 0.1645, 0.2665, 0.2375],
          [0.2318, 0.2110, 0.1893, 0.2505, 0.0000],
          [0.2310, 0.2109, 0.0000, 0.2353, 0.2209],
          [0.2310, 0.2122, 0.1826, 0.2541, 0.2312]]]], grad_fn=<MulBackward0>)


In [None]:
class GPT(nn.Module): ## メイン
    def __init__(self, vocab_size, context_size, d_model, n_block, n_head, dropout_rate=0.1):
        super(GPT, self).__init__()
        self.vocab_size = vocab_size
        self.context_size = context_size
        self.d_model = d_model
        self.n_block = n_block
        self.n_head = n_head
        self.token_embedding = nn.Embedding(vocab_size, d_model)

        self.positional_encoding = PositionalEncoding(context_size, d_model)
        self.dropout = nn.Dropout(dropout_rate)
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(d_model, n_head, dropout_rate) for _ in range(self.n_block)])
        self.norm = nn.LayerNorm(d_model)
        self.fc = nn.Linear(d_model * context_size, vocab_size)

        nn.init.xavier_uniform_(self.fc.weight)
        nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)


    def forward(self, x, mask=None):
        x = self.token_embedding(x) + self.positional_encoding(x)
        x = self.dropout(x)

        for block in self.transformer_blocks:
            x, w = block(x, mask)

        x = self.norm(x)
        x = x.view(-1, self.context_size * self.d_model)
        x = self.fc(x)


        return x, w

In [None]:
context_size = 5
vocab_size = 10  
d_model = 8
n_block = 6
n_head = 4

In [None]:
model = GPT(vocab_size, context_size, d_model, n_block, n_head)
mask = create_attention_mask(context_size).to(device)
x = torch.LongTensor([[2,2,9,4,9]]) # 0～9までの数値を使って context_size の長さの配列を作成します。
y, w = model(x)

In [None]:
y

tensor([[ 0.1442, -0.9293, -2.4225, -1.8907,  0.2925, -0.4585, -1.0107, -0.3088,
         -1.3548, -1.2617]], grad_fn=<AddmmBackward0>)

In [None]:
w

tensor([[[[0.2222, 0.2223, 0.2223, 0.0000, 0.0000],
          [0.2222, 0.2223, 0.0000, 0.2222, 0.2222],
          [0.2222, 0.2223, 0.2223, 0.0000, 0.2222],
          [0.2222, 0.2223, 0.2223, 0.2222, 0.2222],
          [0.2222, 0.2223, 0.2223, 0.2222, 0.2222]],

         [[0.2222, 0.2222, 0.2222, 0.2222, 0.2222],
          [0.2222, 0.2222, 0.2222, 0.2222, 0.2222],
          [0.2222, 0.2222, 0.2222, 0.2222, 0.2222],
          [0.2222, 0.2222, 0.2222, 0.2222, 0.2222],
          [0.2222, 0.2222, 0.0000, 0.2222, 0.2222]],

         [[0.2223, 0.2222, 0.2222, 0.2222, 0.2222],
          [0.2223, 0.2222, 0.0000, 0.2222, 0.0000],
          [0.2223, 0.2222, 0.2222, 0.2222, 0.2222],
          [0.2223, 0.2222, 0.2222, 0.2222, 0.0000],
          [0.2223, 0.2222, 0.2222, 0.2222, 0.2222]],

         [[0.2222, 0.2222, 0.2222, 0.2222, 0.2222],
          [0.2222, 0.2222, 0.2222, 0.2222, 0.2222],
          [0.2222, 0.2222, 0.2222, 0.0000, 0.2222],
          [0.2222, 0.2222, 0.2222, 0.2222, 0.2222],
      