# FomulaBEAT

変更点
- デコーダのみで学習させる
- TransformerDecoderLayerをスクラッチで書く


In [1]:
version = '03'
model_dir = './model/' + version
data_path = 'data/eq03.txt'

In [2]:
from pathlib import Path
import math
import time
from collections import Counter
from tqdm import tqdm
import torch
from torch.utils.data import random_split
import torch.nn as nn
from torch import Tensor
from torch.nn import (
    TransformerEncoder, TransformerDecoder,
    TransformerEncoderLayer, TransformerDecoderLayer
)
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import vocab
from torchtext.utils import download_from_url, extract_archive
import torch.nn.functional as F




パラメータの事前設定

In [3]:
%load_ext autoreload
%autoreload 2
torch.set_printoptions(linewidth=100)

In [4]:

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model_dir_path = Path(model_dir)
if not model_dir_path.exists():
    model_dir_path.mkdir(parents=True)

データの取得

In [5]:
def read_equation_file(file_path):
    with open(file_path, 'r') as file:
        lines = file.readlines()
    src_data, tgt_data = [], []
    for line in lines:
        src, tgt = line.strip().split(' => ')
        src_data.append(src)
        tgt_data.append(tgt)
    return src_data, tgt_data


In [6]:
# ファイルを読み込み、数式データを取得
src_data, tgt_data = read_equation_file(data_path)
print(src_data[:3], tgt_data[:3])


['- - 2 + 0 1 7', '- 5 1', '+ 6 8'] ['2 0 1 + - 7 -', '5 1 -', '6 8 +']


辞書データの作成

In [7]:

SPECIALS = ['<unk>', '<pad>', '<start>', '<end>']

def build_vocab(texts):
    vocab = {}
    idx = 0
    # 数字の語彙定義
    for i in range(10):
        vocab[str(i)] = idx
        idx += 1
    # 特別語の語彙定義
    for sp in SPECIALS:
        vocab[sp] = idx
        idx += 1
    # その他の文字の語彙定義
    for text in texts:
        for char in text:
            if char not in vocab:
                vocab[char] = idx
                idx += 1
    return vocab


def convert_text_to_indexes(text, vocab):
    # <start> と <end> トークンを追加して数値化
    return [vocab['<start>']] + [vocab[char] if char in vocab else vocab['<unk>'] for char in text] + [vocab['<end>']]

# データを処理して Train と Valid に分ける関数
# データを処理して Train と Valid に分ける関数
def data_process_split(src_texts, tgt_texts, vocab_src, vocab_tgt, valid_size=0.2):
    # データを数値化
    data = []
    for (src, tgt) in zip(src_texts, tgt_texts):
        src_tensor = torch.tensor(convert_text_to_indexes(src, vocab_src), dtype=torch.long)
        tgt_tensor = torch.tensor(convert_text_to_indexes(tgt, vocab_tgt), dtype=torch.long)
        data.append((src_tensor, tgt_tensor))
    
    # データのサイズを計算して、訓練データと検証データに分割
    data_size = len(data)
    valid_size = int(valid_size * data_size)
    train_size = data_size - valid_size

    # PyTorchのrandom_splitを使って分割
    train_data, valid_data = random_split(data, [train_size, valid_size])
    
    return train_data, valid_data



In [8]:
# 辞書と逆辞書を構築
vocab_src = build_vocab(src_data)
vocab_tgt = build_vocab(tgt_data)

print(vocab_tgt)

{'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, '<unk>': 10, '<pad>': 11, '<start>': 12, '<end>': 13, ' ': 14, '+': 15, '-': 16}


In [9]:

# データを数値化
train_data, valid_data = data_process_split(src_data, tgt_data, vocab_src, vocab_tgt)

# 結果の確認
print('インデックス化された文章')
print(f"Input: {train_data[0][0]}\nOutput: {train_data[0][1]}")

# インデックスから元の文字列に戻す
def convert_indexes_to_text(indexes:list, vocab):
    reverse_vocab = {idx: token for token, idx in vocab.items()}
    return ''.join([reverse_vocab[idx] for idx in indexes if idx in reverse_vocab and reverse_vocab[idx] not in ['<start>', '<end>', '<pad>']])

print('元に戻した文章')
print(f"Input: {convert_indexes_to_text(train_data[0][0].tolist(), vocab_src)}")
print(f"Output: {convert_indexes_to_text(train_data[0][1].tolist(), vocab_tgt)}")


インデックス化された文章
Input: tensor([12, 14, 15,  9, 15,  7, 13])
Output: tensor([12,  9, 14,  7, 14, 16, 13])
元に戻した文章
Input: - 9 7
Output: 9 7 -


In [10]:
batch_size = 128
PAD_IDX = vocab_src['<pad>']
START_IDX = vocab_src['<start>']
END_IDX = vocab_src['<end>']

def generate_batch(data_batch):
    
    batch_src, batch_tgt = [], []
    for src, tgt in data_batch:
        batch_src.append(src)
        batch_tgt.append(tgt)
        
    batch_src = pad_sequence(batch_src, padding_value=PAD_IDX)
    batch_tgt = pad_sequence(batch_tgt, padding_value=PAD_IDX)
    
    return batch_src, batch_tgt

train_iter = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=generate_batch)
valid_iter = DataLoader(valid_data, batch_size=batch_size, shuffle=True, collate_fn=generate_batch)

In [11]:
len(train_data)

800000

Transoformerの設定

In [12]:
class TokenEmbedding(nn.Module):
    
    def __init__(self, vocab_size, embedding_size):
        
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=PAD_IDX)
        self.embedding_size = embedding_size
        
    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.embedding_size)
    
    
class PositionalEncoding(nn.Module):
    
    def __init__(self, embedding_size: int, dropout: float, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        
        den = torch.exp(-torch.arange(0, embedding_size, 2) * math.log(10000) / embedding_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        embedding_pos = torch.zeros((maxlen, embedding_size))
        embedding_pos[:, 0::2] = torch.sin(pos * den)
        embedding_pos[:, 1::2] = torch.cos(pos * den)
        embedding_pos = embedding_pos.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('embedding_pos', embedding_pos)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.embedding_pos[: token_embedding.size(0), :])


In [13]:

class TransformerDecoderLayerScratch(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1):
        super(TransformerDecoderLayerScratch, self).__init__()
        # Self-attention for the decoder
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Multihead attention for attending to encoder outputs (memory)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Feedforward layers
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        # Layer normalization layers
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        # Dropout
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # Self-attention
        tgt2, _ = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        
        # Attention with the encoder outputs (memory)
        tgt2, _ = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        
        # Feedforward network
        tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)

        return tgt


In [14]:

class Seq2SeqTransformer(nn.Module):
    
    def __init__(
        self, num_encoder_layers: int, num_decoder_layers: int,
        embedding_size: int, vocab_size_src: int, vocab_size_tgt: int,
        dim_feedforward:int = 512, dropout:float = 0.1, nhead:int = 8
    ):
        
        super(Seq2SeqTransformer, self).__init__()

        self.token_embedding_src = TokenEmbedding(vocab_size_src, embedding_size)
        self.positional_encoding = PositionalEncoding(embedding_size, dropout=dropout)
        
        self.token_embedding_tgt = TokenEmbedding(vocab_size_tgt, embedding_size)
        self.decoder_layer = TransformerDecoderLayerScratch(
            d_model=embedding_size, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout
        )
        
        self.output = nn.Linear(embedding_size, vocab_size_tgt)

    def forward(
        self, src: Tensor, tgt: Tensor,
        mask_src: Tensor, mask_tgt: Tensor,
        padding_mask_src: Tensor, padding_mask_tgt: Tensor,
        memory_key_padding_mask: Tensor
    ):
        embedding_src = self.positional_encoding(self.token_embedding_src(src))
        # memory = self.transformer_encoder(embedding_src, mask_src, padding_mask_src)
        embedding_tgt = self.positional_encoding(self.token_embedding_tgt(tgt))
        outs = self.decoder_layer(
            embedding_tgt, embedding_src, mask_tgt, None,
            padding_mask_tgt, memory_key_padding_mask
        )
        return self.output(outs)

    def decode(self, tgt: Tensor, memory: Tensor, mask_tgt: Tensor):
        return self.decoder_layer(self.positional_encoding(self.token_embedding_tgt(tgt)), memory, mask_tgt)

In [15]:
def create_mask(src, tgt, PAD_IDX):
    
    seq_len_src = src.shape[0]
    seq_len_tgt = tgt.shape[0]

    mask_src = torch.zeros((seq_len_src, seq_len_src), device=device).type(torch.bool)
    mask_tgt = generate_square_subsequent_mask(seq_len_tgt)

    padding_mask_src = (src == PAD_IDX).transpose(0, 1)
    padding_mask_tgt = (tgt == PAD_IDX).transpose(0, 1)
    
    return mask_src, mask_tgt, padding_mask_src, padding_mask_tgt


def generate_square_subsequent_mask(seq_len):
    mask = (torch.triu(torch.ones((seq_len, seq_len), device=device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

学習の定義

In [16]:
def train(model, data, optimizer, criterion, PAD_IDX):
    
    model.train()
    losses = 0
    for src, tgt in tqdm(data):
        
        src = src.to(device)
        tgt = tgt.to(device)

        input_tgt = tgt[:-1, :]

        mask_src, mask_tgt, padding_mask_src, padding_mask_tgt = create_mask(src, input_tgt, PAD_IDX)

        logits = model(
            src=src, tgt=input_tgt,
            mask_src=mask_src, mask_tgt=mask_tgt,
            padding_mask_src=padding_mask_src, padding_mask_tgt=padding_mask_tgt,
            memory_key_padding_mask=padding_mask_src
        )

        optimizer.zero_grad()
        output_tgt = tgt[1:, :]
        loss = criterion(logits.reshape(-1, logits.shape[-1]), output_tgt.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()
        
    return losses / len(data)

In [17]:

def evaluate(model, data, criterion, PAD_IDX):
    
    model.eval()
    losses = 0
    for src, tgt in data:
        
        src = src.to(device)
        tgt = tgt.to(device)

        input_tgt = tgt[:-1, :]

        mask_src, mask_tgt, padding_mask_src, padding_mask_tgt = create_mask(src, input_tgt, PAD_IDX)

        logits = model(
            src=src, tgt=input_tgt,
            mask_src=mask_src, mask_tgt=mask_tgt,
            padding_mask_src=padding_mask_src, padding_mask_tgt=padding_mask_tgt,
            memory_key_padding_mask=padding_mask_src
        )
        
        output_tgt = tgt[1:, :]
        loss = criterion(logits.reshape(-1, logits.shape[-1]), output_tgt.reshape(-1))
        losses += loss.item()
        
    return losses / len(data)

設定

In [18]:
vocab_size_src = len(vocab_src)
vocab_size_tgt = len(vocab_tgt)
embedding_size = 4
nhead = 1
dim_feedforward = 4
num_encoder_layers = 1
num_decoder_layers = 1
dropout = 0
# vocab_size_src = len(vocab_src)
# vocab_size_tgt = len(vocab_tgt)
# embedding_size = 240
# nhead = 8
# dim_feedforward = 100
# num_encoder_layers = 2
# num_decoder_layers = 2
# dropout = 0.1

model = Seq2SeqTransformer(
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    embedding_size=embedding_size,
    vocab_size_src=vocab_size_src, vocab_size_tgt=vocab_size_tgt,
    dim_feedforward=dim_feedforward,
    dropout=dropout, nhead=nhead
)

for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

model = model.to(device)

criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(model.parameters())

モデルの調査

In [19]:
print(model)

Seq2SeqTransformer(
  (token_embedding_src): TokenEmbedding(
    (embedding): Embedding(17, 4, padding_idx=11)
  )
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0, inplace=False)
  )
  (token_embedding_tgt): TokenEmbedding(
    (embedding): Embedding(17, 4, padding_idx=11)
  )
  (decoder_layer): TransformerDecoderLayerScratch(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=4, out_features=4, bias=True)
    )
    (multihead_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=4, out_features=4, bias=True)
    )
    (linear1): Linear(in_features=4, out_features=4, bias=True)
    (dropout): Dropout(p=0, inplace=False)
    (linear2): Linear(in_features=4, out_features=4, bias=True)
    (norm1): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (norm3): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (d

In [20]:
# モデル内の層の名前とパラメータ情報を表示
LP = list(model.named_parameters())
lp = len(LP)
print(f"{lp} 層")
for p in range(0, lp):
    print(f"\n層名: {LP[p][0]}")
    print(f"形状: {LP[p][1].shape}")
    print(f"値: {LP[p][1]}")


22 層

層名: token_embedding_src.embedding.weight
形状: torch.Size([17, 4])
値: Parameter containing:
tensor([[-2.0939e-01,  3.2882e-01,  5.2519e-01,  1.1288e-01],
        [ 1.8429e-01,  3.8788e-01, -1.6286e-01, -1.4642e-01],
        [-1.0410e-01,  9.8763e-02, -2.7996e-01, -1.8372e-02],
        [ 1.6086e-01,  7.2780e-02,  3.4627e-01,  2.3480e-01],
        [ 2.7048e-01,  3.9514e-01,  9.1258e-02, -4.6651e-01],
        [-5.0054e-01, -2.5417e-01, -1.6787e-02, -2.9451e-01],
        [-2.0243e-01, -1.6558e-01, -2.7159e-01,  2.0609e-01],
        [ 1.5302e-01,  3.6215e-01, -2.2980e-01, -8.4958e-02],
        [-2.3522e-01, -1.8330e-01,  2.4772e-01, -2.5181e-01],
        [-3.5181e-01, -4.2697e-02, -1.5366e-01,  1.4675e-01],
        [-1.4069e-01,  3.0973e-01,  1.7339e-01, -1.1477e-01],
        [-2.3494e-01, -4.5671e-01, -1.4950e-01,  4.4653e-01],
        [ 2.3980e-01,  4.4021e-01, -3.0284e-01,  6.8008e-02],
        [ 9.9600e-02,  1.4888e-01, -5.0286e-01, -4.0006e-01],
        [-4.2614e-01, -2.0224e-01, -

## 学習実行

In [21]:
epoch = 100
best_loss = float('Inf')
best_model = None
patience = 10
counter = 0

for loop in range(1, epoch + 1):
    
    start_time = time.time()
    
    loss_train = train(
        model=model, data=train_iter, optimizer=optimizer,
        criterion=criterion, PAD_IDX=PAD_IDX
    )
    
    elapsed_time = time.time() - start_time
    
    loss_valid = evaluate(
        model=model, data=valid_iter, criterion=criterion, PAD_IDX=PAD_IDX
    )
    
    print('[{}/{}] train loss: {:.2f}, valid loss: {:.2f}  [{}{:.0f}s] counter: {} {}'.format(
        loop, epoch,
        loss_train, loss_valid,
        str(int(math.floor(elapsed_time / 60))) + 'm' if math.floor(elapsed_time / 60) > 0 else '',
        elapsed_time % 60,
        counter,
        '**' if best_loss > loss_valid else ''
    ))
    
    if best_loss > loss_valid:
        best_loss = loss_valid
        best_model = model
        counter = 0
        
    if counter > patience:
        break
    
    counter += 1

100%|██████████| 80/80 [00:12<00:00,  6.41it/s]


[1/100] train loss: 2.64, valid loss: 2.29  [12s] counter: 0 **


100%|██████████| 80/80 [00:12<00:00,  6.54it/s]


[2/100] train loss: 2.07, valid loss: 1.87  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.60it/s]


[3/100] train loss: 1.73, valid loss: 1.62  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.58it/s]


[4/100] train loss: 1.55, valid loss: 1.49  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.60it/s]


[5/100] train loss: 1.43, valid loss: 1.38  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.56it/s]


[6/100] train loss: 1.34, valid loss: 1.29  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.59it/s]


[7/100] train loss: 1.25, valid loss: 1.21  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.55it/s]


[8/100] train loss: 1.18, valid loss: 1.15  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.53it/s]


[9/100] train loss: 1.12, valid loss: 1.09  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.61it/s]


[10/100] train loss: 1.06, valid loss: 1.03  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.59it/s]


[11/100] train loss: 1.00, valid loss: 0.98  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.53it/s]


[12/100] train loss: 0.96, valid loss: 0.95  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.60it/s]


[13/100] train loss: 0.94, valid loss: 0.92  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.55it/s]


[14/100] train loss: 0.91, valid loss: 0.91  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.64it/s]


[15/100] train loss: 0.90, valid loss: 0.89  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.62it/s]


[16/100] train loss: 0.88, valid loss: 0.88  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.57it/s]


[17/100] train loss: 0.87, valid loss: 0.87  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.58it/s]


[18/100] train loss: 0.86, valid loss: 0.85  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.56it/s]


[19/100] train loss: 0.85, valid loss: 0.84  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.56it/s]


[20/100] train loss: 0.84, valid loss: 0.83  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.57it/s]


[21/100] train loss: 0.83, valid loss: 0.82  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.58it/s]


[22/100] train loss: 0.82, valid loss: 0.81  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.50it/s]


[23/100] train loss: 0.81, valid loss: 0.80  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.46it/s]


[24/100] train loss: 0.80, valid loss: 0.80  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.51it/s]


[25/100] train loss: 0.79, valid loss: 0.79  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.52it/s]


[26/100] train loss: 0.78, valid loss: 0.78  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.53it/s]


[27/100] train loss: 0.77, valid loss: 0.77  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.57it/s]


[28/100] train loss: 0.76, valid loss: 0.76  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.43it/s]


[29/100] train loss: 0.75, valid loss: 0.75  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.56it/s]


[30/100] train loss: 0.74, valid loss: 0.74  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.47it/s]


[31/100] train loss: 0.74, valid loss: 0.74  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.56it/s]


[32/100] train loss: 0.73, valid loss: 0.73  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.55it/s]


[33/100] train loss: 0.72, valid loss: 0.72  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.54it/s]


[34/100] train loss: 0.71, valid loss: 0.71  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.56it/s]


[35/100] train loss: 0.71, valid loss: 0.70  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.55it/s]


[36/100] train loss: 0.70, valid loss: 0.70  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.53it/s]


[37/100] train loss: 0.70, valid loss: 0.69  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.54it/s]


[38/100] train loss: 0.69, valid loss: 0.69  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.57it/s]


[39/100] train loss: 0.68, valid loss: 0.68  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.54it/s]


[40/100] train loss: 0.68, valid loss: 0.68  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.56it/s]


[41/100] train loss: 0.68, valid loss: 0.68  [12s] counter: 1 


100%|██████████| 80/80 [00:12<00:00,  6.56it/s]


[42/100] train loss: 0.67, valid loss: 0.67  [12s] counter: 2 **


100%|██████████| 80/80 [00:12<00:00,  6.62it/s]


[43/100] train loss: 0.67, valid loss: 0.67  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.56it/s]


[44/100] train loss: 0.66, valid loss: 0.66  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.56it/s]


[45/100] train loss: 0.66, valid loss: 0.66  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.48it/s]


[46/100] train loss: 0.66, valid loss: 0.66  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.55it/s]


[47/100] train loss: 0.65, valid loss: 0.65  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.60it/s]


[48/100] train loss: 0.65, valid loss: 0.65  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.59it/s]


[49/100] train loss: 0.65, valid loss: 0.65  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.60it/s]


[50/100] train loss: 0.64, valid loss: 0.64  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.48it/s]


[51/100] train loss: 0.64, valid loss: 0.64  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.55it/s]


[52/100] train loss: 0.64, valid loss: 0.64  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.57it/s]


[53/100] train loss: 0.64, valid loss: 0.64  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.44it/s]


[54/100] train loss: 0.63, valid loss: 0.63  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.48it/s]


[55/100] train loss: 0.63, valid loss: 0.63  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.49it/s]


[56/100] train loss: 0.63, valid loss: 0.64  [12s] counter: 1 


100%|██████████| 80/80 [00:12<00:00,  6.53it/s]


[57/100] train loss: 0.63, valid loss: 0.63  [12s] counter: 2 **


100%|██████████| 80/80 [00:12<00:00,  6.45it/s]


[58/100] train loss: 0.63, valid loss: 0.62  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.50it/s]


[59/100] train loss: 0.66, valid loss: 0.74  [12s] counter: 1 


100%|██████████| 80/80 [00:12<00:00,  6.51it/s]


[60/100] train loss: 0.67, valid loss: 0.62  [12s] counter: 2 **


100%|██████████| 80/80 [00:12<00:00,  6.53it/s]


[61/100] train loss: 0.62, valid loss: 0.62  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.53it/s]


[62/100] train loss: 0.62, valid loss: 0.62  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.55it/s]


[63/100] train loss: 0.62, valid loss: 0.62  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.58it/s]


[64/100] train loss: 0.63, valid loss: 0.81  [12s] counter: 1 


100%|██████████| 80/80 [00:12<00:00,  6.49it/s]


[65/100] train loss: 0.77, valid loss: 0.74  [12s] counter: 2 


100%|██████████| 80/80 [00:12<00:00,  6.59it/s]


[66/100] train loss: 0.69, valid loss: 0.63  [12s] counter: 3 


100%|██████████| 80/80 [00:12<00:00,  6.53it/s]


[67/100] train loss: 0.62, valid loss: 0.61  [12s] counter: 4 **


100%|██████████| 80/80 [00:12<00:00,  6.55it/s]


[68/100] train loss: 0.61, valid loss: 0.61  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.59it/s]


[69/100] train loss: 0.61, valid loss: 0.61  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.49it/s]


[70/100] train loss: 0.61, valid loss: 0.61  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.58it/s]


[71/100] train loss: 0.60, valid loss: 0.60  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.55it/s]


[72/100] train loss: 0.60, valid loss: 0.60  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.60it/s]


[73/100] train loss: 0.70, valid loss: 0.74  [12s] counter: 1 


100%|██████████| 80/80 [00:12<00:00,  6.60it/s]


[74/100] train loss: 0.71, valid loss: 0.64  [12s] counter: 2 


100%|██████████| 80/80 [00:12<00:00,  6.58it/s]


[75/100] train loss: 0.62, valid loss: 0.61  [12s] counter: 3 


100%|██████████| 80/80 [00:12<00:00,  6.52it/s]


[76/100] train loss: 0.61, valid loss: 0.61  [12s] counter: 4 


100%|██████████| 80/80 [00:12<00:00,  6.58it/s]


[77/100] train loss: 0.61, valid loss: 0.60  [12s] counter: 5 


100%|██████████| 80/80 [00:12<00:00,  6.50it/s]


[78/100] train loss: 0.70, valid loss: 0.71  [12s] counter: 6 


100%|██████████| 80/80 [00:12<00:00,  6.51it/s]


[79/100] train loss: 0.76, valid loss: 0.71  [12s] counter: 7 


100%|██████████| 80/80 [00:12<00:00,  6.50it/s]


[80/100] train loss: 0.63, valid loss: 0.61  [12s] counter: 8 


100%|██████████| 80/80 [00:12<00:00,  6.58it/s]


[81/100] train loss: 0.60, valid loss: 0.60  [12s] counter: 9 


100%|██████████| 80/80 [00:12<00:00,  6.60it/s]


[82/100] train loss: 0.60, valid loss: 0.60  [12s] counter: 10 **


100%|██████████| 80/80 [00:12<00:00,  6.56it/s]


[83/100] train loss: 0.60, valid loss: 0.60  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.56it/s]


[84/100] train loss: 0.60, valid loss: 0.60  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.55it/s]


[85/100] train loss: 0.59, valid loss: 0.59  [12s] counter: 1 **


100%|██████████| 80/80 [00:12<00:00,  6.57it/s]


[86/100] train loss: 0.75, valid loss: 0.74  [12s] counter: 1 


100%|██████████| 80/80 [00:12<00:00,  6.55it/s]


[87/100] train loss: 0.73, valid loss: 0.72  [12s] counter: 2 


100%|██████████| 80/80 [00:12<00:00,  6.53it/s]


[88/100] train loss: 0.71, valid loss: 0.70  [12s] counter: 3 


100%|██████████| 80/80 [00:12<00:00,  6.59it/s]


[89/100] train loss: 0.65, valid loss: 0.61  [12s] counter: 4 


100%|██████████| 80/80 [00:12<00:00,  6.50it/s]


[90/100] train loss: 0.61, valid loss: 0.60  [12s] counter: 5 


100%|██████████| 80/80 [00:12<00:00,  6.59it/s]


[91/100] train loss: 0.60, valid loss: 0.60  [12s] counter: 6 


100%|██████████| 80/80 [00:12<00:00,  6.56it/s]


[92/100] train loss: 0.60, valid loss: 0.60  [12s] counter: 7 


100%|██████████| 80/80 [00:12<00:00,  6.55it/s]


[93/100] train loss: 0.60, valid loss: 0.60  [12s] counter: 8 


100%|██████████| 80/80 [00:12<00:00,  6.45it/s]


[94/100] train loss: 0.65, valid loss: 0.78  [12s] counter: 9 


100%|██████████| 80/80 [00:12<00:00,  6.43it/s]


[95/100] train loss: 0.74, valid loss: 0.72  [12s] counter: 10 


100%|██████████| 80/80 [00:12<00:00,  6.57it/s]


[96/100] train loss: 0.71, valid loss: 0.70  [12s] counter: 11 


学習したモデルの保存

In [22]:
torch.save(best_model.state_dict(), model_dir_path.joinpath(version + 'translation_transfomer.pth'))

学習したモデルを使って翻訳をする

In [23]:
def translate(
    model, text, vocab_src, vocab_tgt, seq_len_tgt,
    START_IDX, END_IDX
):
    model.eval()
    tokens_src = convert_text_to_indexes(text, vocab=vocab_src)
    num_tokens_src = len(tokens_src)

    # Tensorに変換
    src = torch.LongTensor(tokens_src).reshape(num_tokens_src, 1).to(device)
    mask_src = torch.zeros((num_tokens_src, num_tokens_src), device=device).type(torch.bool)

    # デコード
    predicts = greedy_decode(
        model=model, src=src,
        mask_src=mask_src, seq_len_tgt=seq_len_tgt,
        START_IDX=START_IDX, END_IDX=END_IDX
    ).flatten()

    return convert_indexes_to_text(predicts.tolist(), vocab=vocab_tgt)

def greedy_decode(model, src, mask_src, seq_len_tgt, START_IDX, END_IDX):
    src = src.to(device)
    mask_src = mask_src.to(device)

    # ソースの埋め込みをメモリとして利用
    memory = model.positional_encoding(model.token_embedding_src(src))
    
    ys = torch.ones(1, 1).fill_(START_IDX).type(torch.long).to(device)
    
    for i in range(seq_len_tgt - 1):
        memory = memory.to(device)
        mask_tgt = generate_square_subsequent_mask(ys.size(0)).to(device).type(torch.bool)
        
        output = model.decode(ys, memory, mask_tgt)
        output = output.transpose(0, 1)
        output = model.output(output[:, -1])
        
        # 最も高いスコアのトークンを取得
        _, next_word = torch.max(output, dim=1)
        next_word = next_word.item()

        # 生成されたトークンを追加
        ys = torch.cat([ys, torch.ones(1, 1).fill_(next_word).type_as(src.data)], dim=0)
        if next_word == END_IDX:
            break
    
    return ys


In [24]:
seq_len_tgt = max([len(x[1]) for x in train_data])
text = '- 5 2'

# 翻訳を実行
translation = translate(
    model=best_model, text=text, vocab_src=vocab_src, vocab_tgt=vocab_tgt,
    seq_len_tgt=seq_len_tgt,
    START_IDX=START_IDX, END_IDX=END_IDX
)

print(f"Input: {text}")
print(f"Output: {translation}")

Input: - 5 2
Output: 5 5 -


In [49]:
# 様々な入力を試してみる

text_list = { '- 2 7':'2 7 -', '- 9 7' : '9 7 -', '+ - 2 7 4' : '2 7 - 4 +', '- - 6 9 - 7 3' : '6 9 - 7 3 - -', '- - + + + + + 4 6 3 7 + 6 7 1 9 + + 3 6 2' : '4 6 + 3 + 7 + 6 7 + + 1 + 9 - 3 6 + 2 + - '}

for text, tgt in text_list.items():
    translation = translate(
        model=best_model, text=text, vocab_src=vocab_src, vocab_tgt=vocab_tgt,
        seq_len_tgt=seq_len_tgt,
        START_IDX=START_IDX, END_IDX=END_IDX
    )
    print(f"Input: {text}")
    print(f"Output: {translation}")
    print(f"Target: {tgt}")
    print('---')




Input: - 2 7
Output: 2 7 -
Target: 2 7 -
---
Input: - 9 7
Output: 9 7 -
Target: 9 7 -
---
Input: + - 2 7 4
Output: 2 7 - 4 -
Target: 2 7 - 4 +
---
Input: - - 6 9 - 7 3
Output: 2 9 - 7 - + -
Target: 6 9 - 7 3 - -
---
Input: - - + + + + + 4 6 3 7 + 6 7 1 9 + + 3 6 2
Output: 1 9 - 9 - + 1 + + + + + - 0 + - 0 + + 9 - 1 0 - 1 + - 1 - - 0 - + 0 + +
Target: 4 6 + 3 + 7 + 6 7 + + 1 + 9 - 3 6 + 2 + - 
---


## モデルの動作を分析

In [26]:
import torch

# モデルのロード
model_path = model_dir_path.joinpath(version + 'translation_transfomer.pth')
loaded_model = Seq2SeqTransformer(
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    embedding_size=embedding_size,
    vocab_size_src=vocab_size_src, vocab_size_tgt=vocab_size_tgt,
    dim_feedforward=dim_feedforward,
    dropout=dropout, nhead=nhead
).to(device)
loaded_model.load_state_dict(torch.load(model_path))
loaded_model.eval()


Seq2SeqTransformer(
  (token_embedding_src): TokenEmbedding(
    (embedding): Embedding(17, 4, padding_idx=11)
  )
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0, inplace=False)
  )
  (token_embedding_tgt): TokenEmbedding(
    (embedding): Embedding(17, 4, padding_idx=11)
  )
  (decoder_layer): TransformerDecoderLayerScratch(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=4, out_features=4, bias=True)
    )
    (multihead_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=4, out_features=4, bias=True)
    )
    (linear1): Linear(in_features=4, out_features=4, bias=True)
    (dropout): Dropout(p=0, inplace=False)
    (linear2): Linear(in_features=4, out_features=4, bias=True)
    (norm1): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (norm3): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (d

In [27]:
# パラメータを取り出す


# モデルのパラメータを取得
params = dict(loaded_model.named_parameters())
print(params.keys())

# 埋め込み行列を取得
embedding_src_weight = params['token_embedding_src.embedding.weight'].data
embedding_tgt_weight = params['token_embedding_tgt.embedding.weight'].data

# 線形層の重みとバイアス
output_weight = params['output.weight'].data
output_bias = params['output.bias'].data

# デコーダの自己注意の重みとバイアス
self_attn_in_proj_weight = params['decoder_layer.self_attn.in_proj_weight'].data
self_attn_in_proj_bias = params['decoder_layer.self_attn.in_proj_bias'].data
self_attn_out_proj_weight = params['decoder_layer.self_attn.out_proj.weight'].data
self_attn_out_proj_bias = params['decoder_layer.self_attn.out_proj.bias'].data

# メモリー注意の重みとバイアス
multihead_attn_in_proj_weight = params['decoder_layer.multihead_attn.in_proj_weight'].data
multihead_attn_in_proj_bias = params['decoder_layer.multihead_attn.in_proj_bias'].data
multihead_attn_out_proj_weight = params['decoder_layer.multihead_attn.out_proj.weight'].data
multihead_attn_out_proj_bias = params['decoder_layer.multihead_attn.out_proj.bias'].data

# フィードフォワードネットワークの重みとバイアス
linear1_weight = params['decoder_layer.linear1.weight'].data
linear1_bias = params['decoder_layer.linear1.bias'].data
linear2_weight = params['decoder_layer.linear2.weight'].data
linear2_bias = params['decoder_layer.linear2.bias'].data

# LayerNormのパラメータ
norm1_weight = params['decoder_layer.norm1.weight'].data
norm1_bias = params['decoder_layer.norm1.bias'].data
norm2_weight = params['decoder_layer.norm2.weight'].data
norm2_bias = params['decoder_layer.norm2.bias'].data
norm3_weight = params['decoder_layer.norm3.weight'].data
norm3_bias = params['decoder_layer.norm3.bias'].data


dict_keys(['token_embedding_src.embedding.weight', 'token_embedding_tgt.embedding.weight', 'decoder_layer.self_attn.in_proj_weight', 'decoder_layer.self_attn.in_proj_bias', 'decoder_layer.self_attn.out_proj.weight', 'decoder_layer.self_attn.out_proj.bias', 'decoder_layer.multihead_attn.in_proj_weight', 'decoder_layer.multihead_attn.in_proj_bias', 'decoder_layer.multihead_attn.out_proj.weight', 'decoder_layer.multihead_attn.out_proj.bias', 'decoder_layer.linear1.weight', 'decoder_layer.linear1.bias', 'decoder_layer.linear2.weight', 'decoder_layer.linear2.bias', 'decoder_layer.norm1.weight', 'decoder_layer.norm1.bias', 'decoder_layer.norm2.weight', 'decoder_layer.norm2.bias', 'decoder_layer.norm3.weight', 'decoder_layer.norm3.bias', 'output.weight', 'output.bias'])


In [28]:

# Positional Encoding
def positional_encoding(tensor: Tensor, maxlen=5000):
    embedding_size = tensor.size(-1)
    den = torch.exp(-torch.arange(0, embedding_size, 2) * math.log(10000) / embedding_size)
    pos = torch.arange(0, maxlen).reshape(maxlen, 1)
    embedding_pos = torch.zeros((maxlen, embedding_size))
    embedding_pos[:, 0::2] = torch.sin(pos * den)
    embedding_pos[:, 1::2] = torch.cos(pos * den)
    embedding_pos = embedding_pos.unsqueeze(-2)
    return tensor + embedding_pos[: tensor.size(0), :].to(tensor.device)

In [50]:

# 翻訳処理を実行
seq_len_tgt = max([len(x[1]) for x in train_data])
text = '+ 3 1'

tokens_src = convert_text_to_indexes(text, vocab=vocab_src)
src = torch.LongTensor(tokens_src).reshape(len(tokens_src), 1).to(device)
memory = positional_encoding(embedding_src_weight[src] * math.sqrt(embedding_size))
ys = torch.ones(1, 1).fill_(START_IDX).type(torch.long).to(device)

for i in range(10):
    tgt_embed = positional_encoding(embedding_tgt_weight[ys] * math.sqrt(embedding_size))
    tgt_mask = generate_square_subsequent_mask(ys.size(0)).to(device).type(torch.bool)

    tgt2, self_attn_weight = loaded_model.decoder_layer.self_attn(tgt_embed, tgt_embed, tgt_embed)
    tgt = tgt_embed + tgt2
    tgt = loaded_model.decoder_layer.norm1(tgt)

    # Attention with the encoder outputs (memory)
    tgt2, multi_attn_weight = loaded_model.decoder_layer.multihead_attn(tgt, memory, memory)
    tgt = tgt + tgt2
    tgt = loaded_model.decoder_layer.norm2(tgt)
    
    # Feedforward network

    # decoder linear1, 2
    tgt2 = tgt.matmul(linear1_weight.T) + linear1_bias
    tgt2 = F.relu(tgt2)
    tgt2 = tgt2.matmul(linear2_weight.T) + linear2_bias
    tgt = tgt + tgt2

    # LayerNorm
    tgt = loaded_model.decoder_layer.norm3(tgt)

    output = tgt.transpose(0, 1)
    output = loaded_model.output(output[:, -1])

    _, next_word = torch.max(output, dim=1)
    next_word = next_word.item()

    ys = torch.cat([ys, torch.ones(1, 1).fill_(next_word).type_as(src.data)], dim=0)
    print(next_word)
    
    if next_word == END_IDX:
        break


flat_indexes = [idx for sublist in ys.tolist() for idx in sublist] if isinstance(ys.tolist()[0], list) else ys.tolist()

print(f"Input: {text}")
print(f"Decoded sequence: {convert_indexes_to_text(flat_indexes, vocab_tgt)}")

3
14
9
14
16
13
Input: + 3 1
Decoded sequence: 3 9 -


In [30]:
print(type(self_attn_weight))
print(self_attn_weight.shape)
print(self_attn_weight)
print(type(multi_attn_weight))
print(multi_attn_weight.shape)
print(multi_attn_weight)

<class 'torch.Tensor'>
torch.Size([1, 6, 6])
tensor([[[9.7928e-02, 3.7918e-02, 3.4161e-01, 4.1949e-02, 4.6730e-01, 1.3293e-02],
         [7.6814e-02, 2.5177e-02, 2.7461e-01, 6.8419e-02, 5.4614e-01, 8.8377e-03],
         [4.7875e-02, 4.8488e-03, 7.2264e-02, 6.1091e-02, 7.8746e-01, 2.6458e-02],
         [7.6803e-03, 7.7532e-04, 1.1772e-01, 2.1982e-02, 8.5183e-01, 1.2766e-05],
         [5.7988e-02, 8.3736e-03, 1.5991e-01, 2.7607e-02, 7.3520e-01, 1.0917e-02],
         [3.4614e-02, 1.5527e-01, 7.6203e-01, 8.8106e-03, 3.9233e-02, 4.1122e-05]]],
       device='cuda:0', grad_fn=<MeanBackward1>)
<class 'torch.Tensor'>
torch.Size([1, 6, 7])
tensor([[[1.3517e-01, 3.1551e-03, 2.0992e-01, 4.3665e-02, 1.1005e-01, 2.4871e-01, 2.4932e-01],
         [1.3068e-02, 9.2935e-05, 2.3955e-01, 2.8887e-02, 2.1169e-01, 2.9657e-01, 2.1014e-01],
         [1.2632e-03, 8.7829e-05, 1.5520e-01, 1.1557e-01, 2.2905e-01, 4.7093e-01, 2.7911e-02],
         [2.4496e-01, 1.8300e-04, 9.7513e-02, 2.0154e-03, 2.8760e-02, 5.2720

# Transformerの検算

スクラッチで書くための検算

## Multihead Attention

Multihead Attentionの動作をスクラッチで書きたいので、ここで検算する

参考サイト
https://blog.amedama.jp/entry/pytorch-multi-head-attention-verify

In [31]:
import torch
from torch import nn
import torch.nn.functional as F


In [32]:
edim = 4 # 埋め込み次元
num_heads = 1 # ヘッド数
model = nn.MultiheadAttention(edim, num_heads, bias=True, batch_first=True)

In [33]:
batch_size = 2
L=5
X = torch.randn(batch_size, L, edim) # 入力

Q = K = V = X # クエリ、キー、バリューは全て入力とする
print(Q.shape)
print(Q)

torch.Size([2, 5, 4])
tensor([[[-2.0358, -0.4210, -0.1428,  0.7600],
         [ 0.8289, -0.5450, -1.3160,  0.2971],
         [ 1.1711, -0.9972,  0.1004, -0.5166],
         [ 1.4415,  1.5530,  1.0889,  0.2702],
         [-0.0863,  1.2597, -0.0109,  0.7002]],

        [[ 0.2162, -1.0729, -1.1537, -1.1673],
         [-2.2337, -0.3451,  1.0464, -0.8750],
         [ 1.4075,  0.4826, -0.8646, -0.8766],
         [ 0.3306, -1.0459, -1.4164, -0.5356],
         [ 1.1663, -0.2434, -1.2874,  0.5635]]])


In [34]:

attn_output, attn_output_weights = model(Q, K, V)

print(attn_output.shape)
print(attn_output)



torch.Size([2, 5, 4])
tensor([[[-0.0288, -0.1482, -0.0386, -0.3702],
         [ 0.0743,  0.2831,  0.2549,  0.1993],
         [ 0.0551,  0.2288,  0.1949,  0.2375],
         [-0.0944, -0.2004, -0.1716, -0.0278],
         [-0.0442, -0.1338, -0.0655, -0.1966]],

        [[ 0.2483,  0.6067,  0.4471,  0.5850],
         [ 0.0917,  0.1767,  0.0822,  0.2280],
         [ 0.2295,  0.5605,  0.3955,  0.5753],
         [ 0.2496,  0.6061,  0.4472,  0.5789],
         [ 0.2371,  0.5706,  0.4091,  0.5626]]], grad_fn=<TransposeBackward0>)


In [35]:
from pprint import pprint
pprint(list(model.named_parameters()))

[('in_proj_weight',
  Parameter containing:
tensor([[-0.5120,  0.5027,  0.5138,  0.3504],
        [-0.4264,  0.2892,  0.4455, -0.1747],
        [-0.5509,  0.3105, -0.4759,  0.5259],
        [ 0.3882,  0.2312,  0.1669,  0.2994],
        [-0.3315,  0.5497,  0.4611, -0.1661],
        [-0.5297,  0.2857, -0.5514,  0.4634],
        [ 0.3630, -0.0409,  0.3067,  0.2207],
        [-0.3475, -0.3323, -0.0067, -0.0314],
        [-0.2557,  0.1545,  0.4712,  0.2270],
        [ 0.0468, -0.5495, -0.4486, -0.1624],
        [-0.2268, -0.1971, -0.0109,  0.2832],
        [ 0.1716,  0.3284,  0.3033,  0.2963]], requires_grad=True)),
 ('in_proj_bias',
  Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)),
 ('out_proj.weight',
  Parameter containing:
tensor([[-0.4673, -0.3341, -0.0209, -0.2351],
        [-0.4463,  0.2863, -0.2579,  0.1748],
        [-0.3392,  0.3523, -0.2350,  0.3936],
        [-0.3965,  0.2987,  0.2892, -0.0233]], requires_grad=True)),
 ('out_p

In [36]:
model_weight = {name: param.data for name, param in model.named_parameters()}
Wi = model_weight['in_proj_weight']
Wo = model_weight['out_proj.weight']
Wbi = model_weight['in_proj_bias']
Wbo = model_weight['out_proj.bias']

In [37]:
Wi_q, Wi_k, Wi_v = Wi.chunk(3, dim=0)
Wbi_q, Wbi_k, Wbi_v = Wbi.chunk(3, dim=0)
QW = torch.matmul(Q, Wi_q.T) + Wbi_q
KW = torch.matmul(K, Wi_k.T) + Wbi_k
VW = torch.matmul(V, Wi_v.T) + Wbi_v

KW_t = KW.transpose(-2, -1)
QK_t = torch.bmm(QW, KW_t)
QK_scaled = QK_t / (edim ** 0.5)
attn_weights_ = F.softmax(QK_scaled, dim=-1)

In [38]:
print(attn_weights_)
print(attn_output_weights)

tensor([[[0.1110, 0.0804, 0.0878, 0.4459, 0.2749],
         [0.0533, 0.2640, 0.4837, 0.1362, 0.0629],
         [0.1467, 0.3487, 0.3370, 0.0766, 0.0909],
         [0.4930, 0.1192, 0.1016, 0.0909, 0.1953],
         [0.2407, 0.1067, 0.1261, 0.2726, 0.2539]],

        [[0.2471, 0.0299, 0.2070, 0.2641, 0.2519],
         [0.1005, 0.3707, 0.2028, 0.1199, 0.2061],
         [0.2865, 0.0741, 0.1643, 0.2803, 0.1948],
         [0.2483, 0.0303, 0.2191, 0.2541, 0.2482],
         [0.2769, 0.0637, 0.2037, 0.2522, 0.2035]]])
tensor([[[0.1110, 0.0804, 0.0878, 0.4459, 0.2749],
         [0.0533, 0.2640, 0.4837, 0.1362, 0.0629],
         [0.1467, 0.3487, 0.3370, 0.0766, 0.0909],
         [0.4930, 0.1192, 0.1016, 0.0909, 0.1953],
         [0.2407, 0.1067, 0.1261, 0.2726, 0.2539]],

        [[0.2471, 0.0299, 0.2070, 0.2641, 0.2519],
         [0.1005, 0.3707, 0.2028, 0.1199, 0.2061],
         [0.2865, 0.0741, 0.1643, 0.2803, 0.1948],
         [0.2483, 0.0303, 0.2191, 0.2541, 0.2482],
         [0.2769, 0.0637,

In [39]:
AV = torch.matmul(attn_weights_, VW)
attn_output_ = torch.matmul(AV, Wo.T) + Wbo

In [40]:
print(attn_output_)
print(attn_output)

tensor([[[-0.0288, -0.1482, -0.0386, -0.3702],
         [ 0.0743,  0.2831,  0.2549,  0.1993],
         [ 0.0551,  0.2288,  0.1949,  0.2375],
         [-0.0944, -0.2004, -0.1716, -0.0278],
         [-0.0442, -0.1338, -0.0655, -0.1966]],

        [[ 0.2483,  0.6067,  0.4471,  0.5850],
         [ 0.0917,  0.1767,  0.0822,  0.2280],
         [ 0.2295,  0.5605,  0.3955,  0.5753],
         [ 0.2496,  0.6061,  0.4472,  0.5789],
         [ 0.2371,  0.5706,  0.4091,  0.5626]]])
tensor([[[-0.0288, -0.1482, -0.0386, -0.3702],
         [ 0.0743,  0.2831,  0.2549,  0.1993],
         [ 0.0551,  0.2288,  0.1949,  0.2375],
         [-0.0944, -0.2004, -0.1716, -0.0278],
         [-0.0442, -0.1338, -0.0655, -0.1966]],

        [[ 0.2483,  0.6067,  0.4471,  0.5850],
         [ 0.0917,  0.1767,  0.0822,  0.2280],
         [ 0.2295,  0.5605,  0.3955,  0.5753],
         [ 0.2496,  0.6061,  0.4472,  0.5789],
         [ 0.2371,  0.5706,  0.4091,  0.5626]]], grad_fn=<TransposeBackward0>)


## nn.Linear

In [41]:
model = nn.Linear(4, 4)
model

Linear(in_features=4, out_features=4, bias=True)

In [42]:
pprint(list(model.named_parameters()))

[('weight',
  Parameter containing:
tensor([[ 0.1411, -0.3989,  0.2314,  0.4356],
        [ 0.3195,  0.3917, -0.3176, -0.0304],
        [ 0.1668, -0.2893, -0.1019,  0.2288],
        [ 0.2496,  0.2296,  0.3388, -0.1318]], requires_grad=True)),
 ('bias',
  Parameter containing:
tensor([-0.3957, -0.2978, -0.3691, -0.3539], requires_grad=True))]


In [43]:
model_weight = {name: param.data for name, param in model.named_parameters()}
W = model_weight['weight']
B = model_weight['bias']

X = torch.randn(4) 
print(X.shape)
print(X)
output = model(X)
print(output.shape)
print(output)


torch.Size([4])
tensor([ 1.6521, -1.4468,  1.6599,  1.3019])
torch.Size([4])
tensor([ 1.3658, -0.9035,  0.4536,  0.1170], grad_fn=<ViewBackward0>)


In [44]:
output_ = X.matmul(W.T) + B
print(output_)
print(output)

tensor([ 1.3658, -0.9035,  0.4536,  0.1170])
tensor([ 1.3658, -0.9035,  0.4536,  0.1170], grad_fn=<ViewBackward0>)


## nn.LayerNorm

参考サイト
https://qiita.com/dl_from_scratch/items/133fe741b67ed14f1856

In [45]:
model = nn.LayerNorm(4)
model

LayerNorm((4,), eps=1e-05, elementwise_affine=True)

In [46]:
pprint(list(model.named_parameters()))

[('weight', Parameter containing:
tensor([1., 1., 1., 1.], requires_grad=True)),
 ('bias', Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True))]


In [47]:
model_weight = {name: param.data for name, param in model.named_parameters()}
W = model_weight['weight']
B = model_weight['bias']

X = torch.randn(4) 
print(X.shape)
print(X)
output = model(X)
print(output.shape)
print(output)


torch.Size([4])
tensor([-2.8229, -0.4509, -0.0722, -0.0127])
torch.Size([4])
tensor([-1.7137,  0.3359,  0.6632,  0.7146], grad_fn=<NativeLayerNormBackward0>)


In [48]:
output_ = X.matmul(W.T) + B
print(output_)
print(output)

tensor([-3.3588, -3.3588, -3.3588, -3.3588])
tensor([-1.7137,  0.3359,  0.6632,  0.7146], grad_fn=<NativeLayerNormBackward0>)


  output_ = X.matmul(W.T) + B
