# FomulaBEAT

変更点
- デコーダのみで学習させる
- TransformerDecoderLayerをスクラッチで書く


In [1]:
version = '03-1'
model_dir = './model/' + version
data_path = 'data/eq03-1.txt'

In [2]:
from pathlib import Path
import math
import time
from collections import Counter
from tqdm import tqdm
import torch
from torch.utils.data import random_split
import torch.nn as nn
from torch import Tensor
from torch.nn import (
    TransformerEncoder, TransformerDecoder,
    TransformerEncoderLayer, TransformerDecoderLayer
)
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import vocab
from torchtext.utils import download_from_url, extract_archive
import torch.nn.functional as F




パラメータの事前設定

In [3]:
%load_ext autoreload
%autoreload 2
torch.set_printoptions(linewidth=100)

In [4]:

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model_dir_path = Path(model_dir)
if not model_dir_path.exists():
    model_dir_path.mkdir(parents=True)

データの取得

In [5]:
def read_equation_file(file_path):
    with open(file_path, 'r') as file:
        lines = file.readlines()
    src_data, tgt_data = [], []
    for line in lines:
        src, tgt = line.strip().split(' => ')
        src_data.append(src)
        tgt_data.append(tgt)
    return src_data, tgt_data


In [6]:
# ファイルを読み込み、数式データを取得
src_data, tgt_data = read_equation_file(data_path)
print(src_data[:3], tgt_data[:3])


['+ 1 + 6 4', '+ + 7 5 + 3 2', '+ 2 + 2 8'] ['1 6 4 + +', '7 5 + 3 2 + +', '2 2 8 + +']


辞書データの作成

In [7]:

SPECIALS = ['<unk>', '<pad>', '<start>', '<end>']

def build_vocab(texts):
    vocab = {}
    idx = 0
    # 数字の語彙定義
    for i in range(10):
        vocab[str(i)] = idx
        idx += 1
    # 特別語の語彙定義
    for sp in SPECIALS:
        vocab[sp] = idx
        idx += 1
    # その他の文字の語彙定義
    for text in texts:
        for char in text:
            if char not in vocab:
                vocab[char] = idx
                idx += 1
    return vocab


def convert_text_to_indexes(text, vocab):
    # <start> と <end> トークンを追加して数値化
    return [vocab['<start>']] + [vocab[char] if char in vocab else vocab['<unk>'] for char in text] + [vocab['<end>']]

# データを処理して Train と Valid に分ける関数
# データを処理して Train と Valid に分ける関数
def data_process_split(src_texts, tgt_texts, vocab_src, vocab_tgt, valid_size=0.2):
    # データを数値化
    data = []
    for (src, tgt) in zip(src_texts, tgt_texts):
        src_tensor = torch.tensor(convert_text_to_indexes(src, vocab_src), dtype=torch.long)
        tgt_tensor = torch.tensor(convert_text_to_indexes(tgt, vocab_tgt), dtype=torch.long)
        data.append((src_tensor, tgt_tensor))
    
    # データのサイズを計算して、訓練データと検証データに分割
    data_size = len(data)
    valid_size = int(valid_size * data_size)
    train_size = data_size - valid_size

    # PyTorchのrandom_splitを使って分割
    train_data, valid_data = random_split(data, [train_size, valid_size])
    
    return train_data, valid_data



In [8]:
# 辞書と逆辞書を構築
vocab_src = build_vocab(src_data)
vocab_tgt = build_vocab(tgt_data)

print(vocab_tgt)

{'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, '<unk>': 10, '<pad>': 11, '<start>': 12, '<end>': 13, ' ': 14, '+': 15}


In [9]:

# データを数値化
train_data, valid_data = data_process_split(src_data, tgt_data, vocab_src, vocab_tgt)

# 結果の確認
print('インデックス化された文章')
print(f"Input: {train_data[0][0]}\nOutput: {train_data[0][1]}")

# インデックスから元の文字列に戻す
def convert_indexes_to_text(indexes:list, vocab):
    reverse_vocab = {idx: token for token, idx in vocab.items()}
    return ''.join([reverse_vocab[idx] for idx in indexes if idx in reverse_vocab and reverse_vocab[idx] not in ['<start>', '<end>', '<pad>']])

print('元に戻した文章')
print(f"Input: {convert_indexes_to_text(train_data[0][0].tolist(), vocab_src)}")
print(f"Output: {convert_indexes_to_text(train_data[0][1].tolist(), vocab_tgt)}")


インデックス化された文章
Input: tensor([12,  9, 13])
Output: tensor([12,  9, 13])
元に戻した文章
Input: 9
Output: 9


In [10]:
batch_size = 1024
PAD_IDX = vocab_src['<pad>']
START_IDX = vocab_src['<start>']
END_IDX = vocab_src['<end>']

def generate_batch(data_batch):
    
    batch_src, batch_tgt = [], []
    for src, tgt in data_batch:
        batch_src.append(src)
        batch_tgt.append(tgt)
        
    batch_src = pad_sequence(batch_src, padding_value=PAD_IDX)
    batch_tgt = pad_sequence(batch_tgt, padding_value=PAD_IDX)
    
    return batch_src, batch_tgt

train_iter = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=generate_batch)
valid_iter = DataLoader(valid_data, batch_size=batch_size, shuffle=True, collate_fn=generate_batch)

In [11]:
len(train_data)

80000

Transoformerの設定

In [12]:
class TokenEmbedding(nn.Module):
    
    def __init__(self, vocab_size, embedding_size):
        
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=PAD_IDX)
        self.embedding_size = embedding_size
        
    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.embedding_size)
    
    
class PositionalEncoding(nn.Module):
    
    def __init__(self, embedding_size: int, dropout: float, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        
        den = torch.exp(-torch.arange(0, embedding_size, 2) * math.log(10000) / embedding_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        embedding_pos = torch.zeros((maxlen, embedding_size))
        embedding_pos[:, 0::2] = torch.sin(pos * den)
        embedding_pos[:, 1::2] = torch.cos(pos * den)
        embedding_pos = embedding_pos.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('embedding_pos', embedding_pos)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.embedding_pos[: token_embedding.size(0), :])


In [13]:

class TransformerDecoderLayerScratch(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1):
        super(TransformerDecoderLayerScratch, self).__init__()
        # Self-attention for the decoder
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Multihead attention for attending to encoder outputs (memory)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Feedforward layers
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        # Layer normalization layers
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        # Dropout
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # Self-attention
        tgt2, _ = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        
        # Attention with the encoder outputs (memory)
        tgt2, _ = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        
        # Feedforward network
        tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)

        return tgt


In [14]:

class Seq2SeqTransformer(nn.Module):
    
    def __init__(
        self, num_encoder_layers: int, num_decoder_layers: int,
        embedding_size: int, vocab_size_src: int, vocab_size_tgt: int,
        dim_feedforward:int = 512, dropout:float = 0.1, nhead:int = 8
    ):
        
        super(Seq2SeqTransformer, self).__init__()

        self.token_embedding_src = TokenEmbedding(vocab_size_src, embedding_size)
        self.positional_encoding = PositionalEncoding(embedding_size, dropout=dropout)
        
        self.token_embedding_tgt = TokenEmbedding(vocab_size_tgt, embedding_size)
        self.decoder_layer = TransformerDecoderLayerScratch(
            d_model=embedding_size, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout
        )
        
        self.output = nn.Linear(embedding_size, vocab_size_tgt)

    def forward(
        self, src: Tensor, tgt: Tensor,
        mask_src: Tensor, mask_tgt: Tensor,
        padding_mask_src: Tensor, padding_mask_tgt: Tensor,
        memory_key_padding_mask: Tensor
    ):
        embedding_src = self.positional_encoding(self.token_embedding_src(src))
        # memory = self.transformer_encoder(embedding_src, mask_src, padding_mask_src)
        embedding_tgt = self.positional_encoding(self.token_embedding_tgt(tgt))
        outs = self.decoder_layer(
            embedding_tgt, embedding_src, mask_tgt, None,
            padding_mask_tgt, memory_key_padding_mask
        )
        return self.output(outs)

    def decode(self, tgt: Tensor, memory: Tensor, mask_tgt: Tensor):
        return self.decoder_layer(self.positional_encoding(self.token_embedding_tgt(tgt)), memory, mask_tgt)

In [15]:
def create_mask(src, tgt, PAD_IDX):
    
    seq_len_src = src.shape[0]
    seq_len_tgt = tgt.shape[0]

    mask_src = torch.zeros((seq_len_src, seq_len_src), device=device).type(torch.bool)
    mask_tgt = generate_square_subsequent_mask(seq_len_tgt)

    padding_mask_src = (src == PAD_IDX).transpose(0, 1)
    padding_mask_tgt = (tgt == PAD_IDX).transpose(0, 1)
    
    return mask_src, mask_tgt, padding_mask_src, padding_mask_tgt


def generate_square_subsequent_mask(seq_len):
    mask = (torch.triu(torch.ones((seq_len, seq_len), device=device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

学習の定義

In [16]:
def train(model, data, optimizer, criterion, PAD_IDX):
    
    model.train()
    losses = 0
    for src, tgt in tqdm(data):
        
        src = src.to(device)
        tgt = tgt.to(device)

        input_tgt = tgt[:-1, :]

        mask_src, mask_tgt, padding_mask_src, padding_mask_tgt = create_mask(src, input_tgt, PAD_IDX)

        logits = model(
            src=src, tgt=input_tgt,
            mask_src=mask_src, mask_tgt=mask_tgt,
            padding_mask_src=padding_mask_src, padding_mask_tgt=padding_mask_tgt,
            memory_key_padding_mask=padding_mask_src
        )

        optimizer.zero_grad()
        output_tgt = tgt[1:, :]
        loss = criterion(logits.reshape(-1, logits.shape[-1]), output_tgt.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()
        
    return losses / len(data)

In [17]:

def evaluate(model, data, criterion, PAD_IDX):
    
    model.eval()
    losses = 0
    for src, tgt in data:
        
        src = src.to(device)
        tgt = tgt.to(device)

        input_tgt = tgt[:-1, :]

        mask_src, mask_tgt, padding_mask_src, padding_mask_tgt = create_mask(src, input_tgt, PAD_IDX)

        logits = model(
            src=src, tgt=input_tgt,
            mask_src=mask_src, mask_tgt=mask_tgt,
            padding_mask_src=padding_mask_src, padding_mask_tgt=padding_mask_tgt,
            memory_key_padding_mask=padding_mask_src
        )
        
        output_tgt = tgt[1:, :]
        loss = criterion(logits.reshape(-1, logits.shape[-1]), output_tgt.reshape(-1))
        losses += loss.item()
        
    return losses / len(data)

設定

In [18]:
vocab_size_src = len(vocab_src)
vocab_size_tgt = len(vocab_tgt)
embedding_size = 4
nhead = 1
dim_feedforward = 4
num_encoder_layers = 1
num_decoder_layers = 1
dropout = 0
# vocab_size_src = len(vocab_src)
# vocab_size_tgt = len(vocab_tgt)
# embedding_size = 240
# nhead = 8
# dim_feedforward = 100
# num_encoder_layers = 2
# num_decoder_layers = 2
# dropout = 0.1

model = Seq2SeqTransformer(
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    embedding_size=embedding_size,
    vocab_size_src=vocab_size_src, vocab_size_tgt=vocab_size_tgt,
    dim_feedforward=dim_feedforward,
    dropout=dropout, nhead=nhead
)

for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

model = model.to(device)

criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(model.parameters())

モデルの調査

In [19]:
print(model)

Seq2SeqTransformer(
  (token_embedding_src): TokenEmbedding(
    (embedding): Embedding(16, 4, padding_idx=11)
  )
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0, inplace=False)
  )
  (token_embedding_tgt): TokenEmbedding(
    (embedding): Embedding(16, 4, padding_idx=11)
  )
  (decoder_layer): TransformerDecoderLayerScratch(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=4, out_features=4, bias=True)
    )
    (multihead_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=4, out_features=4, bias=True)
    )
    (linear1): Linear(in_features=4, out_features=4, bias=True)
    (dropout): Dropout(p=0, inplace=False)
    (linear2): Linear(in_features=4, out_features=4, bias=True)
    (norm1): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (norm3): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (d

In [20]:
# モデル内の層の名前とパラメータ情報を表示
LP = list(model.named_parameters())
lp = len(LP)
print(f"{lp} 層")
for p in range(0, lp):
    print(f"\n層名: {LP[p][0]}")
    print(f"形状: {LP[p][1].shape}")
    print(f"値: {LP[p][1]}")


22 層

層名: token_embedding_src.embedding.weight
形状: torch.Size([16, 4])
値: Parameter containing:
tensor([[ 0.4493, -0.1190, -0.5131, -0.1694],
        [ 0.4179, -0.4150,  0.0538, -0.2728],
        [ 0.5186, -0.2191, -0.0956, -0.3543],
        [ 0.2524, -0.0561, -0.4551, -0.1345],
        [ 0.2105, -0.0209,  0.1182, -0.0670],
        [ 0.3754, -0.4820,  0.5325, -0.0182],
        [ 0.2620, -0.0589,  0.2701,  0.2093],
        [ 0.5421,  0.4825, -0.4296,  0.4211],
        [ 0.3743,  0.4395, -0.0666,  0.2283],
        [-0.0342, -0.4078,  0.1308, -0.2950],
        [-0.2832, -0.2584,  0.1903,  0.0321],
        [-0.1291, -0.2487,  0.4573, -0.1969],
        [ 0.0788,  0.1503, -0.5106,  0.3713],
        [-0.2556,  0.0289, -0.3875, -0.4704],
        [ 0.0215, -0.4832,  0.4469, -0.4050],
        [ 0.1719, -0.3343,  0.2011,  0.3218]], device='cuda:0', requires_grad=True)

層名: token_embedding_tgt.embedding.weight
形状: torch.Size([16, 4])
値: Parameter containing:
tensor([[-0.5115,  0.5080,  0.3903, -0.

## 学習実行

In [21]:
epoch = 100
best_loss = float('Inf')
best_model = None
patience = 10
counter = 0

for loop in range(1, epoch + 1):
    
    start_time = time.time()
    
    loss_train = train(
        model=model, data=train_iter, optimizer=optimizer,
        criterion=criterion, PAD_IDX=PAD_IDX
    )
    
    elapsed_time = time.time() - start_time
    
    loss_valid = evaluate(
        model=model, data=valid_iter, criterion=criterion, PAD_IDX=PAD_IDX
    )
    
    print('[{}/{}] train loss: {:.2f}, valid loss: {:.2f}  [{}{:.0f}s] counter: {} {}'.format(
        loop, epoch,
        loss_train, loss_valid,
        str(int(math.floor(elapsed_time / 60))) + 'm' if math.floor(elapsed_time / 60) > 0 else '',
        elapsed_time % 60,
        counter,
        '**' if best_loss > loss_valid else ''
    ))
    
    if best_loss > loss_valid:
        best_loss = loss_valid
        best_model = model
        counter = 0
        
    if counter > patience:
        break
    
    counter += 1

100%|██████████| 79/79 [00:01<00:00, 54.43it/s]


[1/100] train loss: 2.65, valid loss: 2.44  [1s] counter: 0 **


100%|██████████| 79/79 [00:01<00:00, 62.92it/s]


[2/100] train loss: 2.22, valid loss: 1.99  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.28it/s]


[3/100] train loss: 1.78, valid loss: 1.59  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.05it/s]


[4/100] train loss: 1.46, valid loss: 1.33  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 62.76it/s]


[5/100] train loss: 1.22, valid loss: 1.12  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 62.95it/s]


[6/100] train loss: 1.04, valid loss: 0.97  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 62.67it/s]


[7/100] train loss: 0.93, valid loss: 0.89  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 62.39it/s]


[8/100] train loss: 0.87, valid loss: 0.84  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 62.43it/s]


[9/100] train loss: 0.82, valid loss: 0.80  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 62.96it/s]


[10/100] train loss: 0.79, valid loss: 0.77  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.37it/s]


[11/100] train loss: 0.76, valid loss: 0.74  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 61.86it/s]


[12/100] train loss: 0.73, valid loss: 0.72  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 62.93it/s]


[13/100] train loss: 0.71, valid loss: 0.71  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.03it/s]


[14/100] train loss: 0.70, valid loss: 0.69  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.38it/s]


[15/100] train loss: 0.68, valid loss: 0.67  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.42it/s]


[16/100] train loss: 0.67, valid loss: 0.66  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.47it/s]


[17/100] train loss: 0.66, valid loss: 0.65  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 62.43it/s]


[18/100] train loss: 0.65, valid loss: 0.65  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.42it/s]


[19/100] train loss: 0.64, valid loss: 0.64  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.03it/s]


[20/100] train loss: 0.64, valid loss: 0.63  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.47it/s]


[21/100] train loss: 0.63, valid loss: 0.63  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.40it/s]


[22/100] train loss: 0.62, valid loss: 0.62  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.31it/s]


[23/100] train loss: 0.62, valid loss: 0.62  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 62.72it/s]


[24/100] train loss: 0.61, valid loss: 0.61  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 62.83it/s]


[25/100] train loss: 0.61, valid loss: 0.61  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 62.64it/s]


[26/100] train loss: 0.61, valid loss: 0.61  [1s] counter: 1 


100%|██████████| 79/79 [00:01<00:00, 63.06it/s]


[27/100] train loss: 0.60, valid loss: 0.60  [1s] counter: 2 **


100%|██████████| 79/79 [00:01<00:00, 63.39it/s]


[28/100] train loss: 0.60, valid loss: 0.60  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.25it/s]


[29/100] train loss: 0.60, valid loss: 0.60  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 61.53it/s]


[30/100] train loss: 0.59, valid loss: 0.60  [1s] counter: 1 


100%|██████████| 79/79 [00:01<00:00, 63.23it/s]


[31/100] train loss: 0.59, valid loss: 0.59  [1s] counter: 2 **


100%|██████████| 79/79 [00:01<00:00, 63.34it/s]


[32/100] train loss: 0.59, valid loss: 0.59  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 62.83it/s]


[33/100] train loss: 0.58, valid loss: 0.59  [1s] counter: 1 


100%|██████████| 79/79 [00:01<00:00, 62.59it/s]


[34/100] train loss: 0.58, valid loss: 0.58  [1s] counter: 2 **


100%|██████████| 79/79 [00:01<00:00, 62.78it/s]


[35/100] train loss: 0.58, valid loss: 0.58  [1s] counter: 1 


100%|██████████| 79/79 [00:01<00:00, 63.35it/s]


[36/100] train loss: 0.58, valid loss: 0.58  [1s] counter: 2 **


100%|██████████| 79/79 [00:01<00:00, 61.81it/s]


[37/100] train loss: 0.58, valid loss: 0.58  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.35it/s]


[38/100] train loss: 0.58, valid loss: 0.58  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.48it/s]


[39/100] train loss: 0.57, valid loss: 0.57  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 62.78it/s]


[40/100] train loss: 0.57, valid loss: 0.57  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.42it/s]


[41/100] train loss: 0.57, valid loss: 0.58  [1s] counter: 1 


100%|██████████| 79/79 [00:01<00:00, 62.70it/s]


[42/100] train loss: 0.57, valid loss: 0.57  [1s] counter: 2 **


100%|██████████| 79/79 [00:01<00:00, 63.51it/s]


[43/100] train loss: 0.57, valid loss: 0.57  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 62.84it/s]


[44/100] train loss: 0.57, valid loss: 0.57  [1s] counter: 1 


100%|██████████| 79/79 [00:01<00:00, 63.38it/s]


[45/100] train loss: 0.56, valid loss: 0.58  [1s] counter: 2 


100%|██████████| 79/79 [00:01<00:00, 63.32it/s]


[46/100] train loss: 0.56, valid loss: 0.56  [1s] counter: 3 **


100%|██████████| 79/79 [00:01<00:00, 63.28it/s]


[47/100] train loss: 0.56, valid loss: 0.56  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.39it/s]


[48/100] train loss: 0.56, valid loss: 0.57  [1s] counter: 1 


100%|██████████| 79/79 [00:01<00:00, 63.34it/s]


[49/100] train loss: 0.56, valid loss: 0.56  [1s] counter: 2 


100%|██████████| 79/79 [00:01<00:00, 62.50it/s]


[50/100] train loss: 0.56, valid loss: 0.56  [1s] counter: 3 **


100%|██████████| 79/79 [00:01<00:00, 62.55it/s]


[51/100] train loss: 0.56, valid loss: 0.56  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.42it/s]


[52/100] train loss: 0.56, valid loss: 0.56  [1s] counter: 1 


100%|██████████| 79/79 [00:01<00:00, 63.40it/s]


[53/100] train loss: 0.56, valid loss: 0.56  [1s] counter: 2 **


100%|██████████| 79/79 [00:01<00:00, 63.36it/s]


[54/100] train loss: 0.55, valid loss: 0.55  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 62.71it/s]


[55/100] train loss: 0.55, valid loss: 0.56  [1s] counter: 1 


100%|██████████| 79/79 [00:01<00:00, 62.80it/s]


[56/100] train loss: 0.55, valid loss: 0.55  [1s] counter: 2 **


100%|██████████| 79/79 [00:01<00:00, 63.34it/s]


[57/100] train loss: 0.55, valid loss: 0.55  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 62.30it/s]


[58/100] train loss: 0.55, valid loss: 0.55  [1s] counter: 1 


100%|██████████| 79/79 [00:01<00:00, 63.34it/s]


[59/100] train loss: 0.55, valid loss: 0.56  [1s] counter: 2 


100%|██████████| 79/79 [00:01<00:00, 62.31it/s]


[60/100] train loss: 0.55, valid loss: 0.55  [1s] counter: 3 **


100%|██████████| 79/79 [00:01<00:00, 62.68it/s]


[61/100] train loss: 0.55, valid loss: 0.55  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 61.42it/s]


[62/100] train loss: 0.55, valid loss: 0.55  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.03it/s]


[63/100] train loss: 0.54, valid loss: 0.55  [1s] counter: 1 


100%|██████████| 79/79 [00:01<00:00, 62.82it/s]


[64/100] train loss: 0.54, valid loss: 0.55  [1s] counter: 2 


100%|██████████| 79/79 [00:01<00:00, 63.12it/s]


[65/100] train loss: 0.54, valid loss: 0.55  [1s] counter: 3 **


100%|██████████| 79/79 [00:01<00:00, 63.30it/s]


[66/100] train loss: 0.54, valid loss: 0.54  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.35it/s]


[67/100] train loss: 0.54, valid loss: 0.54  [1s] counter: 1 


100%|██████████| 79/79 [00:01<00:00, 62.96it/s]


[68/100] train loss: 0.54, valid loss: 0.54  [1s] counter: 2 **


100%|██████████| 79/79 [00:01<00:00, 63.28it/s]


[69/100] train loss: 0.54, valid loss: 0.54  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 62.03it/s]


[70/100] train loss: 0.54, valid loss: 0.54  [1s] counter: 1 


100%|██████████| 79/79 [00:01<00:00, 63.30it/s]


[71/100] train loss: 0.54, valid loss: 0.54  [1s] counter: 2 


100%|██████████| 79/79 [00:01<00:00, 63.31it/s]


[72/100] train loss: 0.54, valid loss: 0.54  [1s] counter: 3 **


100%|██████████| 79/79 [00:01<00:00, 63.14it/s]


[73/100] train loss: 0.54, valid loss: 0.54  [1s] counter: 1 


100%|██████████| 79/79 [00:01<00:00, 62.04it/s]


[74/100] train loss: 0.53, valid loss: 0.53  [1s] counter: 2 **


100%|██████████| 79/79 [00:01<00:00, 63.31it/s]


[75/100] train loss: 0.53, valid loss: 0.54  [1s] counter: 1 


100%|██████████| 79/79 [00:01<00:00, 62.35it/s]


[76/100] train loss: 0.54, valid loss: 0.54  [1s] counter: 2 


100%|██████████| 79/79 [00:01<00:00, 62.80it/s]


[77/100] train loss: 0.53, valid loss: 0.53  [1s] counter: 3 **


100%|██████████| 79/79 [00:01<00:00, 62.17it/s]


[78/100] train loss: 0.53, valid loss: 0.53  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.43it/s]


[79/100] train loss: 0.53, valid loss: 0.53  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.26it/s]


[80/100] train loss: 0.53, valid loss: 0.53  [1s] counter: 1 


100%|██████████| 79/79 [00:01<00:00, 62.39it/s]


[81/100] train loss: 0.53, valid loss: 0.53  [1s] counter: 2 **


100%|██████████| 79/79 [00:01<00:00, 63.41it/s]


[82/100] train loss: 0.52, valid loss: 0.53  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.32it/s]


[83/100] train loss: 0.53, valid loss: 0.53  [1s] counter: 1 


100%|██████████| 79/79 [00:01<00:00, 63.27it/s]


[84/100] train loss: 0.52, valid loss: 0.53  [1s] counter: 2 


100%|██████████| 79/79 [00:01<00:00, 62.76it/s]


[85/100] train loss: 0.52, valid loss: 0.53  [1s] counter: 3 


100%|██████████| 79/79 [00:01<00:00, 63.27it/s]


[86/100] train loss: 0.52, valid loss: 0.52  [1s] counter: 4 **


100%|██████████| 79/79 [00:01<00:00, 62.67it/s]


[87/100] train loss: 0.52, valid loss: 0.52  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 62.03it/s]


[88/100] train loss: 0.52, valid loss: 0.52  [1s] counter: 1 


100%|██████████| 79/79 [00:01<00:00, 63.31it/s]


[89/100] train loss: 0.52, valid loss: 0.52  [1s] counter: 2 **


100%|██████████| 79/79 [00:01<00:00, 61.28it/s]


[90/100] train loss: 0.52, valid loss: 0.52  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.28it/s]


[91/100] train loss: 0.51, valid loss: 0.52  [1s] counter: 1 **


100%|██████████| 79/79 [00:01<00:00, 63.27it/s]


[92/100] train loss: 0.51, valid loss: 0.53  [1s] counter: 1 


100%|██████████| 79/79 [00:01<00:00, 62.01it/s]


[93/100] train loss: 0.53, valid loss: 0.51  [1s] counter: 2 **


100%|██████████| 79/79 [00:01<00:00, 63.40it/s]


[94/100] train loss: 0.51, valid loss: 0.52  [1s] counter: 1 


100%|██████████| 79/79 [00:01<00:00, 63.39it/s]


[95/100] train loss: 0.51, valid loss: 0.52  [1s] counter: 2 


100%|██████████| 79/79 [00:01<00:00, 63.38it/s]


[96/100] train loss: 0.52, valid loss: 0.52  [1s] counter: 3 


100%|██████████| 79/79 [00:01<00:00, 63.35it/s]


[97/100] train loss: 0.51, valid loss: 0.51  [1s] counter: 4 **


100%|██████████| 79/79 [00:01<00:00, 63.36it/s]


[98/100] train loss: 0.51, valid loss: 0.52  [1s] counter: 1 


100%|██████████| 79/79 [00:01<00:00, 63.33it/s]


[99/100] train loss: 0.51, valid loss: 0.51  [1s] counter: 2 


100%|██████████| 79/79 [00:01<00:00, 63.25it/s]


[100/100] train loss: 0.51, valid loss: 0.51  [1s] counter: 3 **


学習したモデルの保存

In [22]:
torch.save(best_model.state_dict(), model_dir_path.joinpath(version + 'translation_transfomer.pth'))

学習したモデルを使って翻訳をする

In [23]:
def translate(
    model, text, vocab_src, vocab_tgt, seq_len_tgt,
    START_IDX, END_IDX
):
    model.eval()
    tokens_src = convert_text_to_indexes(text, vocab=vocab_src)
    num_tokens_src = len(tokens_src)

    # Tensorに変換
    src = torch.LongTensor(tokens_src).reshape(num_tokens_src, 1).to(device)
    mask_src = torch.zeros((num_tokens_src, num_tokens_src), device=device).type(torch.bool)

    # デコード
    predicts = greedy_decode(
        model=model, src=src,
        mask_src=mask_src, seq_len_tgt=seq_len_tgt,
        START_IDX=START_IDX, END_IDX=END_IDX
    ).flatten()

    return convert_indexes_to_text(predicts.tolist(), vocab=vocab_tgt)

def greedy_decode(model, src, mask_src, seq_len_tgt, START_IDX, END_IDX):
    src = src.to(device)
    mask_src = mask_src.to(device)

    # ソースの埋め込みをメモリとして利用
    memory = model.positional_encoding(model.token_embedding_src(src))
    
    ys = torch.ones(1, 1).fill_(START_IDX).type(torch.long).to(device)
    
    for i in range(seq_len_tgt - 1):
        memory = memory.to(device)
        mask_tgt = generate_square_subsequent_mask(ys.size(0)).to(device).type(torch.bool)
        
        output = model.decode(ys, memory, mask_tgt)
        output = output.transpose(0, 1)
        output = model.output(output[:, -1])
        
        # 最も高いスコアのトークンを取得
        _, next_word = torch.max(output, dim=1)
        next_word = next_word.item()

        # 生成されたトークンを追加
        ys = torch.cat([ys, torch.ones(1, 1).fill_(next_word).type_as(src.data)], dim=0)
        if next_word == END_IDX:
            break
    
    return ys


In [24]:
seq_len_tgt = max([len(x[1]) for x in train_data])
text = '- 5 2'

# 翻訳を実行
translation = translate(
    model=best_model, text=text, vocab_src=vocab_src, vocab_tgt=vocab_tgt,
    seq_len_tgt=seq_len_tgt,
    START_IDX=START_IDX, END_IDX=END_IDX
)

print(f"Input: {text}")
print(f"Output: {translation}")

Input: - 5 2
Output: 9 5 0 +


In [25]:
# 様々な入力を試してみる

text_list = { '- 2 7':'2 7 -', '- 9 7' : '9 7 -', '+ - 2 7 4' : '2 7 - 4 +', '- - 6 9 - 7 3' : '6 9 - 7 3 - -', '- - + + + + + 4 6 3 7 + 6 7 1 9 + + 3 6 2' : '4 6 + 3 + 7 + 6 7 + + 1 + 9 - 3 6 + 2 + - '}

for text, tgt in text_list.items():
    translation = translate(
        model=best_model, text=text, vocab_src=vocab_src, vocab_tgt=vocab_tgt,
        seq_len_tgt=seq_len_tgt,
        START_IDX=START_IDX, END_IDX=END_IDX
    )
    print(f"Input: {text}")
    print(f"Output: {translation}")
    print(f"Target: {tgt}")
    print('---')




Input: - 2 7
Output: 
Target: 2 7 -
---
Input: - 9 7
Output: 
Target: 9 7 -
---
Input: + - 2 7 4
Output: 
Target: 2 7 - 4 +
---
Input: - - 6 9 - 7 3
Output: +
Target: 6 9 - 7 3 - -
---
Input: - - + + + + + 4 6 3 7 + 6 7 1 9 + + 3 6 2
Output: 7 + + 5 + + + 
Target: 4 6 + 3 + 7 + 6 7 + + 1 + 9 - 3 6 + 2 + - 
---


## モデルの動作を分析

In [26]:
import torch

# モデルのロード
model_path = model_dir_path.joinpath(version + 'translation_transfomer.pth')
loaded_model = Seq2SeqTransformer(
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    embedding_size=embedding_size,
    vocab_size_src=vocab_size_src, vocab_size_tgt=vocab_size_tgt,
    dim_feedforward=dim_feedforward,
    dropout=dropout, nhead=nhead
).to(device)
loaded_model.load_state_dict(torch.load(model_path))
loaded_model.eval()


Seq2SeqTransformer(
  (token_embedding_src): TokenEmbedding(
    (embedding): Embedding(16, 4, padding_idx=11)
  )
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0, inplace=False)
  )
  (token_embedding_tgt): TokenEmbedding(
    (embedding): Embedding(16, 4, padding_idx=11)
  )
  (decoder_layer): TransformerDecoderLayerScratch(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=4, out_features=4, bias=True)
    )
    (multihead_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=4, out_features=4, bias=True)
    )
    (linear1): Linear(in_features=4, out_features=4, bias=True)
    (dropout): Dropout(p=0, inplace=False)
    (linear2): Linear(in_features=4, out_features=4, bias=True)
    (norm1): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (norm3): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (d

In [27]:
# パラメータを取り出す


# モデルのパラメータを取得
params = dict(loaded_model.named_parameters())
print(params.keys())

# 埋め込み行列を取得
embedding_src_weight = params['token_embedding_src.embedding.weight'].data
embedding_tgt_weight = params['token_embedding_tgt.embedding.weight'].data

# 線形層の重みとバイアス
output_weight = params['output.weight'].data
output_bias = params['output.bias'].data

# デコーダの自己注意の重みとバイアス
self_attn_in_proj_weight = params['decoder_layer.self_attn.in_proj_weight'].data
self_attn_in_proj_bias = params['decoder_layer.self_attn.in_proj_bias'].data
self_attn_out_proj_weight = params['decoder_layer.self_attn.out_proj.weight'].data
self_attn_out_proj_bias = params['decoder_layer.self_attn.out_proj.bias'].data

# メモリー注意の重みとバイアス
multihead_attn_in_proj_weight = params['decoder_layer.multihead_attn.in_proj_weight'].data
multihead_attn_in_proj_bias = params['decoder_layer.multihead_attn.in_proj_bias'].data
multihead_attn_out_proj_weight = params['decoder_layer.multihead_attn.out_proj.weight'].data
multihead_attn_out_proj_bias = params['decoder_layer.multihead_attn.out_proj.bias'].data

# フィードフォワードネットワークの重みとバイアス
linear1_weight = params['decoder_layer.linear1.weight'].data
linear1_bias = params['decoder_layer.linear1.bias'].data
linear2_weight = params['decoder_layer.linear2.weight'].data
linear2_bias = params['decoder_layer.linear2.bias'].data

# LayerNormのパラメータ
norm1_weight = params['decoder_layer.norm1.weight'].data
norm1_bias = params['decoder_layer.norm1.bias'].data
norm2_weight = params['decoder_layer.norm2.weight'].data
norm2_bias = params['decoder_layer.norm2.bias'].data
norm3_weight = params['decoder_layer.norm3.weight'].data
norm3_bias = params['decoder_layer.norm3.bias'].data


dict_keys(['token_embedding_src.embedding.weight', 'token_embedding_tgt.embedding.weight', 'decoder_layer.self_attn.in_proj_weight', 'decoder_layer.self_attn.in_proj_bias', 'decoder_layer.self_attn.out_proj.weight', 'decoder_layer.self_attn.out_proj.bias', 'decoder_layer.multihead_attn.in_proj_weight', 'decoder_layer.multihead_attn.in_proj_bias', 'decoder_layer.multihead_attn.out_proj.weight', 'decoder_layer.multihead_attn.out_proj.bias', 'decoder_layer.linear1.weight', 'decoder_layer.linear1.bias', 'decoder_layer.linear2.weight', 'decoder_layer.linear2.bias', 'decoder_layer.norm1.weight', 'decoder_layer.norm1.bias', 'decoder_layer.norm2.weight', 'decoder_layer.norm2.bias', 'decoder_layer.norm3.weight', 'decoder_layer.norm3.bias', 'output.weight', 'output.bias'])


In [28]:

# Positional Encoding
def positional_encoding(tensor: Tensor, maxlen=5000):
    embedding_size = tensor.size(-1)
    den = torch.exp(-torch.arange(0, embedding_size, 2) * math.log(10000) / embedding_size)
    pos = torch.arange(0, maxlen).reshape(maxlen, 1)
    embedding_pos = torch.zeros((maxlen, embedding_size))
    embedding_pos[:, 0::2] = torch.sin(pos * den)
    embedding_pos[:, 1::2] = torch.cos(pos * den)
    embedding_pos = embedding_pos.unsqueeze(-2)
    return tensor + embedding_pos[: tensor.size(0), :].to(tensor.device)

In [29]:

# 翻訳処理を実行
seq_len_tgt = max([len(x[1]) for x in train_data])
text = '+ 3 1'

tokens_src = convert_text_to_indexes(text, vocab=vocab_src)
src = torch.LongTensor(tokens_src).reshape(len(tokens_src), 1).to(device)
memory = positional_encoding(embedding_src_weight[src] * math.sqrt(embedding_size))
ys = torch.ones(1, 1).fill_(START_IDX).type(torch.long).to(device)

for i in range(10):
    tgt_embed = positional_encoding(embedding_tgt_weight[ys] * math.sqrt(embedding_size))
    tgt_mask = generate_square_subsequent_mask(ys.size(0)).to(device).type(torch.bool)
    
    # Self-attention
    self_attn_wq, self_attn_wk, self_attn_wv = self_attn_in_proj_weight.chunk(3, dim=0)
    self_attn_bq, self_attn_bk, self_attn_bv = self_attn_in_proj_bias.chunk(3, dim=0)
    QW = torch.matmul(tgt_embed.permute(1, 0, 2), self_attn_wq.T) + self_attn_bq
    KW = torch.matmul(tgt_embed.permute(1, 0, 2), self_attn_wk.T) + self_attn_bk
    VW = torch.matmul(tgt_embed.permute(1, 0, 2), self_attn_wv.T) + self_attn_bv
    self_attn_weights = F.softmax(torch.bmm(QW, KW.transpose(-2, -1)) / math.sqrt(embedding_size), dim=-1)

    AV = torch.matmul(self_attn_weights, VW)
    self_attn_output = torch.matmul(AV, self_attn_out_proj_weight.T) + self_attn_out_proj_bias
    self_attn_output = self_attn_output.permute(1, 0, 2)
    tgt = tgt_embed + self_attn_output
    tgt = loaded_model.decoder_layer.norm1(tgt)



    # Attention with the encoder outputs (memory)
    multi_attn_wq, multi_attn_wk, multi_attn_wv = multihead_attn_in_proj_weight.chunk(3, dim=0)
    multi_attn_bq, multi_attn_bk, multi_attn_bv = multihead_attn_in_proj_bias.chunk(3, dim=0)
    QW = torch.matmul(tgt.permute(1, 0, 2), multi_attn_wq.T) + multi_attn_bq
    KW = torch.matmul(memory.permute(1, 0, 2), multi_attn_wk.T) + multi_attn_bk
    VW = torch.matmul(memory.permute(1, 0, 2), multi_attn_wv.T) + multi_attn_bv
    multi_attn_weights = F.softmax(torch.bmm(QW, KW.transpose(-2, -1)) / math.sqrt(embedding_size), dim=-1)

    AV = torch.matmul(multi_attn_weights, VW)
    multi_attn_output = torch.matmul(AV, multihead_attn_out_proj_weight.T) + multihead_attn_out_proj_bias
    multi_attn_output = multi_attn_output.permute(1, 0, 2)
    tgt = tgt + multi_attn_output
    tgt = loaded_model.decoder_layer.norm2(tgt)
    
    # Feedforward network

    # decoder linear1, 2
    tgt2 = tgt.matmul(linear1_weight.T) + linear1_bias
    tgt2 = F.relu(tgt2)
    tgt2 = tgt2.matmul(linear2_weight.T) + linear2_bias
    tgt = tgt + tgt2

    # LayerNorm
    tgt = loaded_model.decoder_layer.norm3(tgt)

    output = tgt.transpose(0, 1)
    output = loaded_model.output(output[:, -1])

    _, next_word = torch.max(output, dim=1)
    next_word = next_word.item()

    ys = torch.cat([ys, torch.ones(1, 1).fill_(next_word).type_as(src.data)], dim=0)
    print(next_word)
    
    if next_word == END_IDX:
        break


flat_indexes = [idx for sublist in ys.tolist() for idx in sublist] if isinstance(ys.tolist()[0], list) else ys.tolist()

print(f"Input: {text}")
print(f"Decoded sequence: {convert_indexes_to_text(flat_indexes, vocab_tgt)}")

6
14
5
14
15
13
Input: + 3 1
Decoded sequence: 6 5 +


# Transformerの検算

スクラッチで書くための検算

## Multihead Attention

Multihead Attentionの動作をスクラッチで書きたいので、ここで検算する

参考サイト
https://blog.amedama.jp/entry/pytorch-multi-head-attention-verify

In [30]:
import torch
from torch import nn
import torch.nn.functional as F


In [31]:
edim = 4 # 埋め込み次元
num_heads = 1 # ヘッド数
model = nn.MultiheadAttention(edim, num_heads, bias=True, batch_first=True)

In [32]:
batch_size = 2
L=5
X = torch.randn(batch_size, L, edim) # 入力

Q = K = V = X # クエリ、キー、バリューは全て入力とする
print(Q.shape)
print(Q)

torch.Size([2, 5, 4])
tensor([[[ 1.7931,  0.8052,  1.8185, -0.5266],
         [-0.2902, -1.0391,  1.2279,  0.3198],
         [-0.4066, -0.4179,  0.9050,  0.0436],
         [-0.1240, -0.1971, -0.2292, -0.1083],
         [-0.7328, -0.2042,  0.5302, -0.1011]],

        [[-0.8783,  0.6967,  0.2382,  1.8613],
         [ 1.1250,  0.0673,  0.8159,  0.9802],
         [-1.0651, -1.2724,  0.0781,  1.0047],
         [-0.2543,  1.3695,  0.5802,  0.8394],
         [ 0.2736, -0.1948, -1.4705,  0.6487]]])


In [33]:

attn_output, attn_output_weights = model(Q, K, V)

print(attn_output.shape)
print(attn_output)



torch.Size([2, 5, 4])
tensor([[[-0.1039, -0.3863, -0.3572,  0.3019],
         [-0.1499, -0.2646, -0.2513,  0.2433],
         [-0.1414, -0.3110, -0.2908,  0.2715],
         [-0.1447, -0.3263, -0.3034,  0.2840],
         [-0.1425, -0.3196, -0.2979,  0.2784]],

        [[ 0.2660,  0.0807,  0.2903,  0.0049],
         [ 0.2797,  0.1043,  0.3195, -0.0153],
         [ 0.2673,  0.1238,  0.3190, -0.0267],
         [ 0.2790,  0.0579,  0.2777,  0.0162],
         [ 0.2307,  0.0473,  0.2784,  0.0496]]], grad_fn=<TransposeBackward0>)


In [34]:
from pprint import pprint
pprint(list(model.named_parameters()))

[('in_proj_weight',
  Parameter containing:
tensor([[ 0.3843,  0.4464,  0.4579, -0.5462],
        [-0.2418, -0.0685,  0.4328,  0.2146],
        [-0.2437, -0.5598,  0.3846, -0.1159],
        [ 0.0128,  0.5299,  0.0290, -0.2573],
        [-0.0222,  0.2149, -0.6024,  0.5504],
        [ 0.2653, -0.0621, -0.2771,  0.2197],
        [ 0.0402, -0.2455,  0.1905, -0.3444],
        [-0.3170, -0.3291,  0.2684, -0.3979],
        [ 0.3349,  0.4198, -0.4257, -0.1142],
        [ 0.3138,  0.2361, -0.3218, -0.5586],
        [ 0.5523,  0.4972, -0.0885,  0.5812],
        [ 0.3345, -0.5090, -0.3253,  0.5553]], requires_grad=True)),
 ('in_proj_bias',
  Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)),
 ('out_proj.weight',
  Parameter containing:
tensor([[ 0.1698, -0.1021,  0.0235,  0.4659],
        [ 0.1619,  0.4810,  0.4507,  0.2649],
        [ 0.4789, -0.0274,  0.4383,  0.1554],
        [-0.0916, -0.4968, -0.2360, -0.3673]], requires_grad=True)),
 ('out_p

In [35]:
model_weight = {name: param.data for name, param in model.named_parameters()}
Wi = model_weight['in_proj_weight']
Wo = model_weight['out_proj.weight']
Wbi = model_weight['in_proj_bias']
Wbo = model_weight['out_proj.bias']

In [36]:
Wi_q, Wi_k, Wi_v = Wi.chunk(3, dim=0)
Wbi_q, Wbi_k, Wbi_v = Wbi.chunk(3, dim=0)
QW = torch.matmul(Q, Wi_q.T) + Wbi_q
KW = torch.matmul(K, Wi_k.T) + Wbi_k
VW = torch.matmul(V, Wi_v.T) + Wbi_v

KW_t = KW.transpose(-2, -1)
QK_t = torch.bmm(QW, KW_t)
QK_scaled = QK_t / (edim ** 0.5)
attn_weights_ = F.softmax(QK_scaled, dim=-1)

In [37]:
print(attn_weights_)
print(attn_output_weights)

tensor([[[0.0758, 0.1607, 0.1870, 0.3443, 0.2321],
         [0.2653, 0.1911, 0.1807, 0.1950, 0.1679],
         [0.2195, 0.1975, 0.1911, 0.2067, 0.1852],
         [0.2169, 0.2016, 0.1991, 0.1873, 0.1952],
         [0.2183, 0.2028, 0.1947, 0.1961, 0.1881]],

        [[0.1658, 0.2583, 0.1982, 0.1998, 0.1780],
         [0.2186, 0.1855, 0.1685, 0.1888, 0.2386],
         [0.1195, 0.3482, 0.1826, 0.1975, 0.1522],
         [0.2222, 0.1668, 0.2025, 0.1818, 0.2268],
         [0.1864, 0.2382, 0.1946, 0.2467, 0.1341]]])
tensor([[[0.0758, 0.1607, 0.1870, 0.3443, 0.2321],
         [0.2653, 0.1911, 0.1807, 0.1950, 0.1679],
         [0.2195, 0.1975, 0.1911, 0.2067, 0.1852],
         [0.2169, 0.2016, 0.1991, 0.1873, 0.1952],
         [0.2183, 0.2028, 0.1947, 0.1961, 0.1881]],

        [[0.1658, 0.2583, 0.1982, 0.1998, 0.1780],
         [0.2186, 0.1855, 0.1685, 0.1888, 0.2386],
         [0.1195, 0.3482, 0.1826, 0.1975, 0.1522],
         [0.2222, 0.1668, 0.2025, 0.1818, 0.2268],
         [0.1864, 0.2382,

In [38]:
AV = torch.matmul(attn_weights_, VW)
attn_output_ = torch.matmul(AV, Wo.T) + Wbo

In [39]:
print(attn_output_)
print(attn_output)

tensor([[[-0.1039, -0.3863, -0.3572,  0.3019],
         [-0.1499, -0.2646, -0.2513,  0.2433],
         [-0.1414, -0.3110, -0.2908,  0.2715],
         [-0.1447, -0.3263, -0.3034,  0.2840],
         [-0.1425, -0.3196, -0.2979,  0.2784]],

        [[ 0.2660,  0.0807,  0.2903,  0.0049],
         [ 0.2797,  0.1043,  0.3195, -0.0153],
         [ 0.2673,  0.1238,  0.3190, -0.0267],
         [ 0.2790,  0.0579,  0.2777,  0.0162],
         [ 0.2307,  0.0473,  0.2784,  0.0496]]])
tensor([[[-0.1039, -0.3863, -0.3572,  0.3019],
         [-0.1499, -0.2646, -0.2513,  0.2433],
         [-0.1414, -0.3110, -0.2908,  0.2715],
         [-0.1447, -0.3263, -0.3034,  0.2840],
         [-0.1425, -0.3196, -0.2979,  0.2784]],

        [[ 0.2660,  0.0807,  0.2903,  0.0049],
         [ 0.2797,  0.1043,  0.3195, -0.0153],
         [ 0.2673,  0.1238,  0.3190, -0.0267],
         [ 0.2790,  0.0579,  0.2777,  0.0162],
         [ 0.2307,  0.0473,  0.2784,  0.0496]]], grad_fn=<TransposeBackward0>)


## nn.Linear

In [40]:
model = nn.Linear(4, 4)
model

Linear(in_features=4, out_features=4, bias=True)

In [41]:
pprint(list(model.named_parameters()))

[('weight',
  Parameter containing:
tensor([[ 0.0918,  0.2403, -0.3218, -0.1812],
        [-0.0175,  0.2324,  0.3647, -0.4856],
        [ 0.3877, -0.3938,  0.0295,  0.4861],
        [ 0.0994, -0.1966,  0.0058, -0.4141]], requires_grad=True)),
 ('bias',
  Parameter containing:
tensor([-0.0700,  0.3236, -0.4422,  0.4141], requires_grad=True))]


In [42]:
model_weight = {name: param.data for name, param in model.named_parameters()}
W = model_weight['weight']
B = model_weight['bias']

X = torch.randn(4) 
print(X.shape)
print(X)
output = model(X)
print(output.shape)
print(output)


torch.Size([4])
tensor([ 1.8139, -0.5194, -0.3596, -0.6756])
torch.Size([4])
tensor([0.2098, 0.3680, 0.1266, 0.9741], grad_fn=<ViewBackward0>)


In [43]:
output_ = X.matmul(W.T) + B
print(output_)
print(output)

tensor([0.2098, 0.3680, 0.1266, 0.9741])
tensor([0.2098, 0.3680, 0.1266, 0.9741], grad_fn=<ViewBackward0>)


## nn.LayerNorm

参考サイト
https://qiita.com/dl_from_scratch/items/133fe741b67ed14f1856

In [44]:
model = nn.LayerNorm(4)
model

LayerNorm((4,), eps=1e-05, elementwise_affine=True)

In [45]:
pprint(list(model.named_parameters()))

[('weight', Parameter containing:
tensor([1., 1., 1., 1.], requires_grad=True)),
 ('bias', Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True))]


In [46]:
model_weight = {name: param.data for name, param in model.named_parameters()}
W = model_weight['weight']
B = model_weight['bias']

X = torch.randn(4) 
print(X.shape)
print(X)
output = model(X)
print(output.shape)
print(output)


torch.Size([4])
tensor([-0.9687, -0.4609, -0.4245,  0.7778])
torch.Size([4])
tensor([-1.0905, -0.2989, -0.2423,  1.6317], grad_fn=<NativeLayerNormBackward0>)


In [47]:
output_ = X.matmul(W.T) + B
print(output_)
print(output)

tensor([-1.0763, -1.0763, -1.0763, -1.0763])
tensor([-1.0905, -0.2989, -0.2423,  1.6317], grad_fn=<NativeLayerNormBackward0>)


  output_ = X.matmul(W.T) + B
