### 参考にしたサイト
https://tech.gmogshd.com/transformer/

In [1]:
with open('./data.txt','r',encoding='utf-8') as f:
        text = f.read()
print("テキストの文字数 :", len(text))
print("最初の30文字 : ",text[:30])

テキストの文字数 : 1063
最初の30文字 :  Head Mounted Displayをはじめとした立体視


In [2]:
import torch
import torch.nn as nn

# 使用されている文字
chars = sorted(list(set(text)))
print(chars[:30],chars[50:70])
# 使用されている文字数
char_size = len(chars)

# 文字と数字を一対一対応させる辞書
char2int = { ch : i for i, ch in enumerate(chars) }
int2char = { i : ch for i, ch in enumerate(chars) }

# 文字と数字を変換する関数
encode = lambda a: [char2int[b] for b in a ]
decode = lambda a: ''.join([int2char[b] for b in a ])
print("decode_example:",decode([40,2,5,8,23,56]))

# テキストファイルを数字にして，tensor型に変換
train_data = torch.tensor(encode(text), dtype=torch.long)
print(train_data.shape)
print(train_data[:20])

['\n', ' ', '%', '(', ')', '-', '.', '3', 'A', 'C', 'D', 'F', 'G', 'H', 'L', 'M', 'N', 'P', 'S', 'T', 'U', 'Y', '\\', 'a', 'b', 'c', 'd', 'e', 'g', 'h'] ['か', 'が', 'き', 'く', 'こ', 'さ', 'し', 'じ', 'す', 'そ', 'た', 'っ', 'つ', 'て', 'で', 'と', 'ど', 'な', 'に', 'の']
decode_example: u%-Aaし
torch.Size([1063])
tensor([13, 27, 23, 26,  1, 15, 35, 40, 34, 39, 27, 26,  1, 10, 30, 38, 36, 32,
        23, 43])


In [3]:
vector_size = 3

# [単語数] → [単語数，次元数(vector_size)]
embeddings = nn.Embedding(char_size, vector_size)

# e.g. ホログラフィをベクトルにする
encoded_words = torch.tensor(encode("ホログラフィ"))
embeddings_words  = embeddings(encoded_words)
print("[ホログラフィ]のベクトル表現 : \n",embeddings_words)

[ホログラフィ]のベクトル表現 : 
 tensor([[-0.3880,  0.0420,  0.4655],
        [-1.6588,  0.9200, -0.3392],
        [-2.1191, -0.1704,  0.4300],
        [ 1.3032,  0.5103,  1.1194],
        [ 0.4227, -0.5785, -0.2832],
        [-1.4950, -0.5288,  0.5902]], grad_fn=<EmbeddingBackward0>)


In [14]:
class SelfAttention_Head(nn.Module):

    def __init__(self, n_mbed, head_size, block_size):
        super().__init__()
        self.key = nn.Linear(n_mbed, head_size, bias=False)
        self.query = nn.Linear(n_mbed, head_size, bias=False)
        self.value = nn.Linear(n_mbed, head_size, bias=False)
        # 上三角をゼロに，下三角をそのまま
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        print(B,T,C)

        k = self.key(x)
        print("k",k)
        q = self.query(x)
        print("q",q)
        v = self.value(x)
        print("v",v)

        wei = q @ k.transpose(-2,-1)* C ** -0.5
        print(wei)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        print(wei)
        wei = nn.functional.softmax(wei, dim=-1)
        print(wei)

        out = wei @ v
        return out

In [18]:
### 次元を揃える
embeddings_words  = embeddings(encoded_words)
embeddings_words = embeddings_words.unsqueeze(dim = 0)
print(embeddings_words.shape)

attention_head = SelfAttention_Head(3,3,1)
attention_head.forward(embeddings_words)

torch.Size([1, 6, 3])
1 6 3
tensor([[[ 1.0364, -0.6619,  0.1237],
         [ 0.4633, -0.8849,  1.0785],
         [ 0.1580,  0.3470, -0.7655],
         [ 0.3028, -0.6931,  0.9132],
         [ 0.6378, -0.8441,  0.8466],
         [-0.1282,  0.3280, -0.4546]]], grad_fn=<UnsafeViewBackward0>)
tensor([[[ 6.8456e-01, -6.5586e-01, -7.9003e-01],
         [ 2.9608e-01, -1.2033e+00, -7.8217e-01],
         [ 2.5713e-01,  5.5882e-01,  5.8389e-02],
         [-1.6254e-02, -9.1612e-01, -3.8383e-01],
         [ 1.4199e-01, -1.0162e+00, -5.2828e-01],
         [ 1.9334e-01,  3.9858e-01, -1.0645e-03]]],
       grad_fn=<UnsafeViewBackward0>)
tensor([[[-1.0467, -1.3859,  1.2149],
         [ 0.5841, -1.3454, -0.2818],
         [-1.1587,  0.1032,  1.0285],
         [ 0.8717, -0.6808, -0.6540],
         [ 0.5087, -0.9464, -0.2684],
         [-0.6855,  0.0267,  0.6016]]], grad_fn=<UnsafeViewBackward0>)
tensor([[[ 0.6039,  0.0263,  0.2802, -0.0344,  0.1856,  0.0325],
         [ 0.5811,  0.2069,  0.1316,  0.1209,

tensor([[[-0.2847, -0.7517,  0.4010],
         [-0.1907, -0.7951,  0.3244],
         [-0.2356, -0.6485,  0.3358],
         [-0.1143, -0.7819,  0.2525],
         [-0.1430, -0.7903,  0.2802],
         [-0.2253, -0.6578,  0.3284]]], grad_fn=<UnsafeViewBackward0>)