# FomulaBEAT

変更点
- デコーダのみで学習させる
- TransformerDecoderLayerをスクラッチで書く


In [1]:
version = '03-1'
model_dir = './model/' + version
data_path = 'data/eq03-1.txt'

In [2]:
from pathlib import Path
import math
import time
from collections import Counter
from tqdm import tqdm
import torch
from torch.utils.data import random_split
import torch.nn as nn
from torch import Tensor
from torch.nn import (
    TransformerEncoder, TransformerDecoder,
    TransformerEncoderLayer, TransformerDecoderLayer
)
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import vocab
from torchtext.utils import download_from_url, extract_archive
import torch.nn.functional as F




パラメータの事前設定

In [3]:
%load_ext autoreload
%autoreload 2
torch.set_printoptions(linewidth=100)

In [4]:

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model_dir_path = Path(model_dir)
if not model_dir_path.exists():
    model_dir_path.mkdir(parents=True)

データの取得

In [5]:
def read_equation_file(file_path):
    with open(file_path, 'r') as file:
        lines = file.readlines()
    src_data, tgt_data = [], []
    for line in lines:
        src, tgt = line.strip().split(' => ')
        src_data.append(src)
        tgt_data.append(tgt)
    return src_data, tgt_data


In [6]:
# ファイルを読み込み、数式データを取得
src_data, tgt_data = read_equation_file(data_path)
print(src_data[:3], tgt_data[:3])


['+ 1 + 6 4', '+ + 7 5 + 3 2', '+ 2 + 2 8'] ['1 6 4 + +', '7 5 + 3 2 + +', '2 2 8 + +']


辞書データの作成

In [7]:

SPECIALS = ['<unk>', '<pad>', '<start>', '<end>']

def build_vocab(texts):
    vocab = {}
    idx = 0
    # 数字の語彙定義
    for i in range(10):
        vocab[str(i)] = idx
        idx += 1
    # 特別語の語彙定義
    for sp in SPECIALS:
        vocab[sp] = idx
        idx += 1
    # その他の文字の語彙定義
    for text in texts:
        for char in text:
            if char not in vocab:
                vocab[char] = idx
                idx += 1
    return vocab


def convert_text_to_indexes(text, vocab):
    # <start> と <end> トークンを追加して数値化
    return [vocab['<start>']] + [vocab[char] if char in vocab else vocab['<unk>'] for char in text] + [vocab['<end>']]

# データを処理して Train と Valid に分ける関数
# データを処理して Train と Valid に分ける関数
def data_process_split(src_texts, tgt_texts, vocab_src, vocab_tgt, valid_size=0.2):
    # データを数値化
    data = []
    for (src, tgt) in zip(src_texts, tgt_texts):
        src_tensor = torch.tensor(convert_text_to_indexes(src, vocab_src), dtype=torch.long)
        tgt_tensor = torch.tensor(convert_text_to_indexes(tgt, vocab_tgt), dtype=torch.long)
        data.append((src_tensor, tgt_tensor))
    
    # データのサイズを計算して、訓練データと検証データに分割
    data_size = len(data)
    valid_size = int(valid_size * data_size)
    train_size = data_size - valid_size

    # PyTorchのrandom_splitを使って分割
    train_data, valid_data = random_split(data, [train_size, valid_size])
    
    return train_data, valid_data



In [8]:
# 辞書と逆辞書を構築
vocab_src = build_vocab(src_data)
vocab_tgt = build_vocab(tgt_data)

print(vocab_tgt)

{'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, '<unk>': 10, '<pad>': 11, '<start>': 12, '<end>': 13, ' ': 14, '+': 15}


In [9]:

# データを数値化
train_data, valid_data = data_process_split(src_data, tgt_data, vocab_src, vocab_tgt)

# 結果の確認
print('インデックス化された文章')
print(f"Input: {train_data[0][0]}\nOutput: {train_data[0][1]}")

# インデックスから元の文字列に戻す
def convert_indexes_to_text(indexes:list, vocab):
    reverse_vocab = {idx: token for token, idx in vocab.items()}
    return ''.join([reverse_vocab[idx] for idx in indexes if idx in reverse_vocab and reverse_vocab[idx] not in ['<start>', '<end>', '<pad>']])

print('元に戻した文章')
print(f"Input: {convert_indexes_to_text(train_data[0][0].tolist(), vocab_src)}")
print(f"Output: {convert_indexes_to_text(train_data[0][1].tolist(), vocab_tgt)}")


インデックス化された文章
Input: tensor([12,  8, 13])
Output: tensor([12,  8, 13])
元に戻した文章
Input: 8
Output: 8


In [10]:
batch_size = 32
PAD_IDX = vocab_src['<pad>']
START_IDX = vocab_src['<start>']
END_IDX = vocab_src['<end>']

def generate_batch(data_batch):
    
    batch_src, batch_tgt = [], []
    for src, tgt in data_batch:
        batch_src.append(src)
        batch_tgt.append(tgt)
        
    batch_src = pad_sequence(batch_src, padding_value=PAD_IDX)
    batch_tgt = pad_sequence(batch_tgt, padding_value=PAD_IDX)
    
    return batch_src, batch_tgt

train_iter = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=generate_batch)
valid_iter = DataLoader(valid_data, batch_size=batch_size, shuffle=True, collate_fn=generate_batch)

In [11]:
len(train_data)

80000

Transoformerの設定

In [12]:
class TokenEmbedding(nn.Module):
    
    def __init__(self, vocab_size, embedding_size):
        
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=PAD_IDX)
        self.embedding_size = embedding_size
        
    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.embedding_size)
    
    
class PositionalEncoding(nn.Module):
    
    def __init__(self, embedding_size: int, dropout: float, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        
        den = torch.exp(-torch.arange(0, embedding_size, 2) * math.log(10000) / embedding_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        embedding_pos = torch.zeros((maxlen, embedding_size))
        embedding_pos[:, 0::2] = torch.sin(pos * den)
        embedding_pos[:, 1::2] = torch.cos(pos * den)
        embedding_pos = embedding_pos.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('embedding_pos', embedding_pos)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.embedding_pos[: token_embedding.size(0), :])


In [13]:

class TransformerDecoderLayerScratch(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1):
        super(TransformerDecoderLayerScratch, self).__init__()
        # Self-attention for the decoder
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Multihead attention for attending to encoder outputs (memory)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Feedforward layers
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        # Layer normalization layers
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        # Dropout
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # Self-attention
        tgt2, _ = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        
        # Attention with the encoder outputs (memory)
        tgt2, _ = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        
        # Feedforward network
        tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)

        return tgt


In [14]:

class Seq2SeqTransformer(nn.Module):
    
    def __init__(
        self, num_encoder_layers: int, num_decoder_layers: int,
        embedding_size: int, vocab_size_src: int, vocab_size_tgt: int,
        dim_feedforward:int = 512, dropout:float = 0.1, nhead:int = 8
    ):
        
        super(Seq2SeqTransformer, self).__init__()

        self.token_embedding_src = TokenEmbedding(vocab_size_src, embedding_size)
        self.positional_encoding = PositionalEncoding(embedding_size, dropout=dropout)
        
        self.token_embedding_tgt = TokenEmbedding(vocab_size_tgt, embedding_size)
        self.decoder_layer = TransformerDecoderLayerScratch(
            d_model=embedding_size, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout
        )
        
        self.output = nn.Linear(embedding_size, vocab_size_tgt)

    def forward(
        self, src: Tensor, tgt: Tensor,
        mask_src: Tensor, mask_tgt: Tensor,
        padding_mask_src: Tensor, padding_mask_tgt: Tensor,
        memory_key_padding_mask: Tensor
    ):
        embedding_src = self.positional_encoding(self.token_embedding_src(src))
        # memory = self.transformer_encoder(embedding_src, mask_src, padding_mask_src)
        embedding_tgt = self.positional_encoding(self.token_embedding_tgt(tgt))
        outs = self.decoder_layer(
            embedding_tgt, embedding_src, mask_tgt, None,
            padding_mask_tgt, memory_key_padding_mask
        )
        return self.output(outs)

    def decode(self, tgt: Tensor, memory: Tensor, mask_tgt: Tensor):
        return self.decoder_layer(self.positional_encoding(self.token_embedding_tgt(tgt)), memory, mask_tgt)

In [15]:
def create_mask(src, tgt, PAD_IDX):
    
    seq_len_src = src.shape[0]
    seq_len_tgt = tgt.shape[0]

    mask_src = torch.zeros((seq_len_src, seq_len_src), device=device).type(torch.bool)
    mask_tgt = generate_square_subsequent_mask(seq_len_tgt)

    padding_mask_src = (src == PAD_IDX).transpose(0, 1)
    padding_mask_tgt = (tgt == PAD_IDX).transpose(0, 1)
    
    return mask_src, mask_tgt, padding_mask_src, padding_mask_tgt


def generate_square_subsequent_mask(seq_len):
    mask = (torch.triu(torch.ones((seq_len, seq_len), device=device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

学習の定義

In [16]:
def train(model, data, optimizer, criterion, PAD_IDX):
    
    model.train()
    losses = 0
    for src, tgt in tqdm(data):
        
        src = src.to(device)
        tgt = tgt.to(device)

        input_tgt = tgt[:-1, :]

        mask_src, mask_tgt, padding_mask_src, padding_mask_tgt = create_mask(src, input_tgt, PAD_IDX)

        logits = model(
            src=src, tgt=input_tgt,
            mask_src=mask_src, mask_tgt=mask_tgt,
            padding_mask_src=padding_mask_src, padding_mask_tgt=padding_mask_tgt,
            memory_key_padding_mask=padding_mask_src
        )

        optimizer.zero_grad()
        output_tgt = tgt[1:, :]
        loss = criterion(logits.reshape(-1, logits.shape[-1]), output_tgt.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()
        
    return losses / len(data)

In [17]:

def evaluate(model, data, criterion, PAD_IDX):
    
    model.eval()
    losses = 0
    for src, tgt in data:
        
        src = src.to(device)
        tgt = tgt.to(device)

        input_tgt = tgt[:-1, :]

        mask_src, mask_tgt, padding_mask_src, padding_mask_tgt = create_mask(src, input_tgt, PAD_IDX)

        logits = model(
            src=src, tgt=input_tgt,
            mask_src=mask_src, mask_tgt=mask_tgt,
            padding_mask_src=padding_mask_src, padding_mask_tgt=padding_mask_tgt,
            memory_key_padding_mask=padding_mask_src
        )
        
        output_tgt = tgt[1:, :]
        loss = criterion(logits.reshape(-1, logits.shape[-1]), output_tgt.reshape(-1))
        losses += loss.item()
        
    return losses / len(data)

設定

In [18]:
vocab_size_src = len(vocab_src)
vocab_size_tgt = len(vocab_tgt)
embedding_size = 4
nhead = 1
dim_feedforward = 4
num_encoder_layers = 1
num_decoder_layers = 1
dropout = 0
# vocab_size_src = len(vocab_src)
# vocab_size_tgt = len(vocab_tgt)
# embedding_size = 240
# nhead = 8
# dim_feedforward = 100
# num_encoder_layers = 2
# num_decoder_layers = 2
# dropout = 0.1

model = Seq2SeqTransformer(
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    embedding_size=embedding_size,
    vocab_size_src=vocab_size_src, vocab_size_tgt=vocab_size_tgt,
    dim_feedforward=dim_feedforward,
    dropout=dropout, nhead=nhead
)

for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

model = model.to(device)

criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(model.parameters())

モデルの調査

In [19]:
print(model)

Seq2SeqTransformer(
  (token_embedding_src): TokenEmbedding(
    (embedding): Embedding(16, 4, padding_idx=11)
  )
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0, inplace=False)
  )
  (token_embedding_tgt): TokenEmbedding(
    (embedding): Embedding(16, 4, padding_idx=11)
  )
  (decoder_layer): TransformerDecoderLayerScratch(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=4, out_features=4, bias=True)
    )
    (multihead_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=4, out_features=4, bias=True)
    )
    (linear1): Linear(in_features=4, out_features=4, bias=True)
    (dropout): Dropout(p=0, inplace=False)
    (linear2): Linear(in_features=4, out_features=4, bias=True)
    (norm1): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (norm3): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (d

In [20]:
# モデル内の層の名前とパラメータ情報を表示
LP = list(model.named_parameters())
lp = len(LP)
print(f"{lp} 層")
for p in range(0, lp):
    print(f"\n層名: {LP[p][0]}")
    print(f"形状: {LP[p][1].shape}")
    print(f"値: {LP[p][1]}")


22 層

層名: token_embedding_src.embedding.weight
形状: torch.Size([16, 4])
値: Parameter containing:
tensor([[ 0.2462, -0.5026, -0.1733,  0.3210],
        [-0.2435,  0.0656,  0.4046, -0.4386],
        [ 0.0946,  0.1828,  0.0934, -0.4812],
        [-0.4953,  0.0775, -0.0757, -0.4097],
        [ 0.5156,  0.1668,  0.1934, -0.4394],
        [-0.0298,  0.5106,  0.2552, -0.1143],
        [-0.3192, -0.5437, -0.1608,  0.0599],
        [-0.1952, -0.0550,  0.4002, -0.3174],
        [ 0.5283,  0.2347,  0.3989,  0.4341],
        [-0.4482,  0.2137, -0.1937, -0.3266],
        [ 0.4465,  0.4387, -0.3048,  0.2486],
        [-0.5282, -0.3780,  0.1323, -0.2818],
        [ 0.3326, -0.3909,  0.2611, -0.4637],
        [ 0.4959,  0.1987, -0.1198, -0.1356],
        [-0.3363,  0.0837,  0.5109, -0.2546],
        [ 0.2482,  0.2211, -0.1849, -0.1736]], device='cuda:0', requires_grad=True)

層名: token_embedding_tgt.embedding.weight
形状: torch.Size([16, 4])
値: Parameter containing:
tensor([[ 5.2284e-01,  7.4130e-02,  2.2

## 学習実行

In [21]:
epoch = 100
best_loss = float('Inf')
best_model = None
patience = 10
counter = 0

for loop in range(1, epoch + 1):
    
    start_time = time.time()
    
    loss_train = train(
        model=model, data=train_iter, optimizer=optimizer,
        criterion=criterion, PAD_IDX=PAD_IDX
    )
    
    elapsed_time = time.time() - start_time
    
    loss_valid = evaluate(
        model=model, data=valid_iter, criterion=criterion, PAD_IDX=PAD_IDX
    )
    
    print('[{}/{}] train loss: {:.2f}, valid loss: {:.2f}  [{}{:.0f}s] counter: {} {}'.format(
        loop, epoch,
        loss_train, loss_valid,
        str(int(math.floor(elapsed_time / 60))) + 'm' if math.floor(elapsed_time / 60) > 0 else '',
        elapsed_time % 60,
        counter,
        '**' if best_loss > loss_valid else ''
    ))
    
    if best_loss > loss_valid:
        best_loss = loss_valid
        best_model = model
        counter = 0
        
    if counter > patience:
        break
    
    counter += 1

100%|██████████| 2500/2500 [00:27<00:00, 89.55it/s]


[1/100] train loss: 0.94, valid loss: 0.63  [28s] counter: 0 **


100%|██████████| 2500/2500 [00:09<00:00, 258.68it/s]


[2/100] train loss: 0.59, valid loss: 0.56  [10s] counter: 1 **


100%|██████████| 2500/2500 [00:27<00:00, 89.55it/s]


[3/100] train loss: 0.56, valid loss: 0.55  [28s] counter: 1 **


100%|██████████| 2500/2500 [00:27<00:00, 90.01it/s]


[4/100] train loss: 0.55, valid loss: 0.54  [28s] counter: 1 **


100%|██████████| 2500/2500 [00:20<00:00, 124.90it/s]


[5/100] train loss: 0.54, valid loss: 0.53  [20s] counter: 1 **


100%|██████████| 2500/2500 [00:09<00:00, 263.84it/s]


[6/100] train loss: 0.53, valid loss: 0.52  [9s] counter: 1 **


100%|██████████| 2500/2500 [00:09<00:00, 262.94it/s]


[7/100] train loss: 0.52, valid loss: 0.50  [10s] counter: 1 **


100%|██████████| 2500/2500 [00:09<00:00, 257.14it/s]


[8/100] train loss: 0.50, valid loss: 0.50  [10s] counter: 1 **


100%|██████████| 2500/2500 [00:10<00:00, 246.38it/s]


[9/100] train loss: 0.48, valid loss: 0.49  [10s] counter: 1 **


100%|██████████| 2500/2500 [00:15<00:00, 157.92it/s]


[10/100] train loss: 0.47, valid loss: 0.46  [16s] counter: 1 **


100%|██████████| 2500/2500 [00:28<00:00, 89.24it/s]


[11/100] train loss: 0.46, valid loss: 0.45  [28s] counter: 1 **


100%|██████████| 2500/2500 [00:28<00:00, 89.03it/s]


[12/100] train loss: 0.45, valid loss: 0.45  [28s] counter: 1 **


100%|██████████| 2500/2500 [00:09<00:00, 253.10it/s]


[13/100] train loss: 0.45, valid loss: 0.43  [10s] counter: 1 **


100%|██████████| 2500/2500 [00:09<00:00, 262.78it/s]


[14/100] train loss: 0.44, valid loss: 0.43  [10s] counter: 1 **


100%|██████████| 2500/2500 [00:11<00:00, 214.09it/s]


[15/100] train loss: 0.43, valid loss: 0.45  [12s] counter: 1 


100%|██████████| 2500/2500 [00:27<00:00, 89.83it/s]


[16/100] train loss: 0.43, valid loss: 0.42  [28s] counter: 2 **


100%|██████████| 2500/2500 [00:27<00:00, 91.75it/s]


[17/100] train loss: 0.42, valid loss: 0.42  [27s] counter: 1 **


100%|██████████| 2500/2500 [00:27<00:00, 91.68it/s]


[18/100] train loss: 0.42, valid loss: 0.43  [27s] counter: 1 


100%|██████████| 2500/2500 [00:27<00:00, 89.70it/s]


[19/100] train loss: 0.41, valid loss: 0.40  [28s] counter: 2 **


100%|██████████| 2500/2500 [00:27<00:00, 89.81it/s]


[20/100] train loss: 0.41, valid loss: 0.39  [28s] counter: 1 **


100%|██████████| 2500/2500 [00:27<00:00, 89.89it/s]


[21/100] train loss: 0.41, valid loss: 0.39  [28s] counter: 1 **


100%|██████████| 2500/2500 [00:27<00:00, 89.78it/s]


[22/100] train loss: 0.40, valid loss: 0.39  [28s] counter: 1 **


100%|██████████| 2500/2500 [00:27<00:00, 90.01it/s]


[23/100] train loss: 0.40, valid loss: 0.40  [28s] counter: 1 


100%|██████████| 2500/2500 [00:27<00:00, 89.87it/s]


[24/100] train loss: 0.39, valid loss: 0.39  [28s] counter: 2 


100%|██████████| 2500/2500 [00:27<00:00, 90.08it/s]


[25/100] train loss: 0.39, valid loss: 0.38  [28s] counter: 3 **


100%|██████████| 2500/2500 [00:27<00:00, 89.79it/s]


[26/100] train loss: 0.39, valid loss: 0.39  [28s] counter: 1 


100%|██████████| 2500/2500 [00:27<00:00, 89.73it/s]


[27/100] train loss: 0.39, valid loss: 0.41  [28s] counter: 2 


100%|██████████| 2500/2500 [00:27<00:00, 89.81it/s]


[28/100] train loss: 0.38, valid loss: 0.39  [28s] counter: 3 


100%|██████████| 2500/2500 [00:27<00:00, 89.76it/s]


[29/100] train loss: 0.38, valid loss: 0.37  [28s] counter: 4 **


100%|██████████| 2500/2500 [00:15<00:00, 163.31it/s]


[30/100] train loss: 0.38, valid loss: 0.38  [15s] counter: 1 


100%|██████████| 2500/2500 [00:27<00:00, 89.56it/s]


[31/100] train loss: 0.37, valid loss: 0.39  [28s] counter: 2 


100%|██████████| 2500/2500 [00:28<00:00, 88.92it/s]


[32/100] train loss: 0.37, valid loss: 0.36  [28s] counter: 3 **


100%|██████████| 2500/2500 [00:27<00:00, 89.86it/s]


[33/100] train loss: 0.37, valid loss: 0.36  [28s] counter: 1 **


100%|██████████| 2500/2500 [00:27<00:00, 90.49it/s]


[34/100] train loss: 0.37, valid loss: 0.39  [28s] counter: 1 


100%|██████████| 2500/2500 [00:28<00:00, 88.94it/s]


[35/100] train loss: 0.36, valid loss: 0.36  [28s] counter: 2 **


100%|██████████| 2500/2500 [00:28<00:00, 89.00it/s]


[36/100] train loss: 0.37, valid loss: 0.38  [28s] counter: 1 


100%|██████████| 2500/2500 [00:28<00:00, 88.95it/s]


[37/100] train loss: 0.36, valid loss: 0.38  [28s] counter: 2 


100%|██████████| 2500/2500 [00:28<00:00, 88.94it/s]


[38/100] train loss: 0.36, valid loss: 0.35  [28s] counter: 3 **


100%|██████████| 2500/2500 [00:28<00:00, 88.86it/s]


[39/100] train loss: 0.36, valid loss: 0.35  [28s] counter: 1 **


100%|██████████| 2500/2500 [00:28<00:00, 88.74it/s]


[40/100] train loss: 0.36, valid loss: 0.35  [28s] counter: 1 


100%|██████████| 2500/2500 [00:28<00:00, 88.72it/s]


[41/100] train loss: 0.36, valid loss: 0.38  [28s] counter: 2 


100%|██████████| 2500/2500 [00:27<00:00, 90.35it/s]


[42/100] train loss: 0.36, valid loss: 0.37  [28s] counter: 3 


100%|██████████| 2500/2500 [00:27<00:00, 90.14it/s]


[43/100] train loss: 0.36, valid loss: 0.34  [28s] counter: 4 **


100%|██████████| 2500/2500 [00:27<00:00, 90.14it/s]


[44/100] train loss: 0.36, valid loss: 0.36  [28s] counter: 1 


100%|██████████| 2500/2500 [00:27<00:00, 90.15it/s]


[45/100] train loss: 0.36, valid loss: 0.36  [28s] counter: 2 


100%|██████████| 2500/2500 [00:27<00:00, 90.13it/s]


[46/100] train loss: 0.35, valid loss: 0.37  [28s] counter: 3 


100%|██████████| 2500/2500 [00:27<00:00, 90.17it/s]


[47/100] train loss: 0.36, valid loss: 0.34  [28s] counter: 4 


100%|██████████| 2500/2500 [00:27<00:00, 90.27it/s]


[48/100] train loss: 0.35, valid loss: 0.37  [28s] counter: 5 


100%|██████████| 2500/2500 [00:27<00:00, 90.29it/s]


[49/100] train loss: 0.35, valid loss: 0.36  [28s] counter: 6 


100%|██████████| 2500/2500 [00:27<00:00, 90.16it/s]


[50/100] train loss: 0.35, valid loss: 0.36  [28s] counter: 7 


100%|██████████| 2500/2500 [00:27<00:00, 90.16it/s]


[51/100] train loss: 0.35, valid loss: 0.35  [28s] counter: 8 


100%|██████████| 2500/2500 [00:27<00:00, 90.10it/s]


[52/100] train loss: 0.35, valid loss: 0.36  [28s] counter: 9 


100%|██████████| 2500/2500 [00:27<00:00, 89.95it/s]


[53/100] train loss: 0.35, valid loss: 0.34  [28s] counter: 10 **


100%|██████████| 2500/2500 [00:27<00:00, 92.15it/s]


[54/100] train loss: 0.35, valid loss: 0.34  [27s] counter: 1 **


100%|██████████| 2500/2500 [00:27<00:00, 91.67it/s] 


[55/100] train loss: 0.35, valid loss: 0.34  [27s] counter: 1 


100%|██████████| 2500/2500 [00:27<00:00, 90.08it/s]


[56/100] train loss: 0.35, valid loss: 0.34  [28s] counter: 2 **


100%|██████████| 2500/2500 [00:27<00:00, 90.10it/s]


[57/100] train loss: 0.34, valid loss: 0.33  [28s] counter: 1 **


100%|██████████| 2500/2500 [00:27<00:00, 90.31it/s]


[58/100] train loss: 0.34, valid loss: 0.33  [28s] counter: 1 **


100%|██████████| 2500/2500 [00:27<00:00, 90.27it/s]


[59/100] train loss: 0.34, valid loss: 0.37  [28s] counter: 1 


100%|██████████| 2500/2500 [00:27<00:00, 90.68it/s]


[60/100] train loss: 0.34, valid loss: 0.33  [28s] counter: 2 


100%|██████████| 2500/2500 [00:27<00:00, 90.44it/s]


[61/100] train loss: 0.34, valid loss: 0.33  [28s] counter: 3 


100%|██████████| 2500/2500 [00:27<00:00, 91.78it/s] 


[62/100] train loss: 0.34, valid loss: 0.33  [27s] counter: 4 


100%|██████████| 2500/2500 [00:27<00:00, 90.68it/s]


[63/100] train loss: 0.34, valid loss: 0.33  [28s] counter: 5 


100%|██████████| 2500/2500 [00:27<00:00, 90.88it/s] 


[64/100] train loss: 0.34, valid loss: 0.36  [28s] counter: 6 


100%|██████████| 2500/2500 [00:27<00:00, 91.63it/s] 


[65/100] train loss: 0.32, valid loss: 0.32  [27s] counter: 7 **


100%|██████████| 2500/2500 [00:27<00:00, 91.12it/s] 


[66/100] train loss: 0.31, valid loss: 0.31  [27s] counter: 1 **


100%|██████████| 2500/2500 [00:27<00:00, 90.43it/s]


[67/100] train loss: 0.31, valid loss: 0.30  [28s] counter: 1 **


100%|██████████| 2500/2500 [00:27<00:00, 90.90it/s] 


[68/100] train loss: 0.30, valid loss: 0.29  [28s] counter: 1 **


100%|██████████| 2500/2500 [00:27<00:00, 91.30it/s]


[69/100] train loss: 0.29, valid loss: 0.29  [27s] counter: 1 **


100%|██████████| 2500/2500 [00:27<00:00, 89.87it/s]


[70/100] train loss: 0.29, valid loss: 0.29  [28s] counter: 1 


100%|██████████| 2500/2500 [00:27<00:00, 89.82it/s]


[71/100] train loss: 0.29, valid loss: 0.28  [28s] counter: 2 **


100%|██████████| 2500/2500 [00:27<00:00, 89.63it/s]


[72/100] train loss: 0.28, valid loss: 0.27  [28s] counter: 1 **


100%|██████████| 2500/2500 [00:27<00:00, 90.12it/s]


[73/100] train loss: 0.28, valid loss: 0.27  [28s] counter: 1 


100%|██████████| 2500/2500 [00:27<00:00, 90.38it/s]


[74/100] train loss: 0.28, valid loss: 0.30  [28s] counter: 2 


100%|██████████| 2500/2500 [00:27<00:00, 90.88it/s]


[75/100] train loss: 0.28, valid loss: 0.28  [28s] counter: 3 


100%|██████████| 2500/2500 [00:27<00:00, 90.31it/s]


[76/100] train loss: 0.27, valid loss: 0.26  [28s] counter: 4 **


100%|██████████| 2500/2500 [00:27<00:00, 90.10it/s]


[77/100] train loss: 0.27, valid loss: 0.26  [28s] counter: 1 **


100%|██████████| 2500/2500 [00:27<00:00, 90.73it/s] 


[78/100] train loss: 0.26, valid loss: 0.25  [28s] counter: 1 **


100%|██████████| 2500/2500 [00:27<00:00, 90.13it/s]


[79/100] train loss: 0.26, valid loss: 0.26  [28s] counter: 1 


100%|██████████| 2500/2500 [00:27<00:00, 90.05it/s]


[80/100] train loss: 0.25, valid loss: 0.24  [28s] counter: 2 **


100%|██████████| 2500/2500 [00:27<00:00, 90.18it/s]


[81/100] train loss: 0.25, valid loss: 0.24  [28s] counter: 1 **


100%|██████████| 2500/2500 [00:27<00:00, 90.11it/s]


[82/100] train loss: 0.24, valid loss: 0.37  [28s] counter: 1 


100%|██████████| 2500/2500 [00:27<00:00, 90.17it/s]


[83/100] train loss: 0.24, valid loss: 0.24  [28s] counter: 2 


100%|██████████| 2500/2500 [00:27<00:00, 90.21it/s]


[84/100] train loss: 0.24, valid loss: 0.22  [28s] counter: 3 **


100%|██████████| 2500/2500 [00:27<00:00, 90.94it/s] 


[85/100] train loss: 0.24, valid loss: 0.22  [27s] counter: 1 **


100%|██████████| 2500/2500 [00:27<00:00, 90.45it/s]


[86/100] train loss: 0.23, valid loss: 0.22  [28s] counter: 1 **


100%|██████████| 2500/2500 [00:27<00:00, 90.46it/s]


[87/100] train loss: 0.23, valid loss: 0.23  [28s] counter: 1 


100%|██████████| 2500/2500 [00:27<00:00, 91.66it/s]


[88/100] train loss: 0.23, valid loss: 0.21  [27s] counter: 2 **


100%|██████████| 2500/2500 [00:27<00:00, 90.21it/s]


[89/100] train loss: 0.23, valid loss: 0.22  [28s] counter: 1 


100%|██████████| 2500/2500 [00:27<00:00, 90.31it/s]


[90/100] train loss: 0.23, valid loss: 0.21  [28s] counter: 2 **


100%|██████████| 2500/2500 [00:27<00:00, 90.57it/s] 


[91/100] train loss: 0.22, valid loss: 0.21  [28s] counter: 1 **


100%|██████████| 2500/2500 [00:27<00:00, 91.38it/s] 


[92/100] train loss: 0.23, valid loss: 0.23  [27s] counter: 1 


100%|██████████| 2500/2500 [00:27<00:00, 90.59it/s]


[93/100] train loss: 0.22, valid loss: 0.21  [28s] counter: 2 


100%|██████████| 2500/2500 [00:27<00:00, 90.12it/s]


[94/100] train loss: 0.22, valid loss: 0.24  [28s] counter: 3 


100%|██████████| 2500/2500 [00:27<00:00, 89.99it/s]


[95/100] train loss: 0.22, valid loss: 0.21  [28s] counter: 4 


100%|██████████| 2500/2500 [00:27<00:00, 91.56it/s] 


[96/100] train loss: 0.22, valid loss: 0.24  [27s] counter: 5 


100%|██████████| 2500/2500 [00:27<00:00, 91.33it/s] 


[97/100] train loss: 0.22, valid loss: 0.24  [27s] counter: 6 


100%|██████████| 2500/2500 [00:27<00:00, 91.14it/s] 


[98/100] train loss: 0.22, valid loss: 0.22  [27s] counter: 7 


100%|██████████| 2500/2500 [00:27<00:00, 91.29it/s] 


[99/100] train loss: 0.22, valid loss: 0.23  [27s] counter: 8 


100%|██████████| 2500/2500 [00:27<00:00, 90.58it/s]


[100/100] train loss: 0.21, valid loss: 0.20  [28s] counter: 9 **


学習したモデルの保存

In [22]:
torch.save(best_model.state_dict(), model_dir_path.joinpath(version + 'translation_transfomer.pth'))

学習したモデルを使って翻訳をする

In [23]:
def translate(
    model, text, vocab_src, vocab_tgt, seq_len_tgt,
    START_IDX, END_IDX
):
    model.eval()
    tokens_src = convert_text_to_indexes(text, vocab=vocab_src)
    num_tokens_src = len(tokens_src)

    # Tensorに変換
    src = torch.LongTensor(tokens_src).reshape(num_tokens_src, 1).to(device)
    mask_src = torch.zeros((num_tokens_src, num_tokens_src), device=device).type(torch.bool)

    # デコード
    predicts = greedy_decode(
        model=model, src=src,
        mask_src=mask_src, seq_len_tgt=seq_len_tgt,
        START_IDX=START_IDX, END_IDX=END_IDX
    ).flatten()

    return convert_indexes_to_text(predicts.tolist(), vocab=vocab_tgt)

def greedy_decode(model, src, mask_src, seq_len_tgt, START_IDX, END_IDX):
    src = src.to(device)
    mask_src = mask_src.to(device)

    # ソースの埋め込みをメモリとして利用
    memory = model.positional_encoding(model.token_embedding_src(src))
    
    ys = torch.ones(1, 1).fill_(START_IDX).type(torch.long).to(device)
    
    for i in range(seq_len_tgt - 1):
        memory = memory.to(device)
        mask_tgt = generate_square_subsequent_mask(ys.size(0)).to(device).type(torch.bool)
        
        output = model.decode(ys, memory, mask_tgt)
        output = output.transpose(0, 1)
        output = model.output(output[:, -1])
        
        # 最も高いスコアのトークンを取得
        _, next_word = torch.max(output, dim=1)
        next_word = next_word.item()

        # 生成されたトークンを追加
        ys = torch.cat([ys, torch.ones(1, 1).fill_(next_word).type_as(src.data)], dim=0)
        if next_word == END_IDX:
            break
    
    return ys


In [24]:
seq_len_tgt = max([len(x[1]) for x in train_data])
text = '- 5 2'

# 翻訳を実行
translation = translate(
    model=best_model, text=text, vocab_src=vocab_src, vocab_tgt=vocab_tgt,
    seq_len_tgt=seq_len_tgt,
    START_IDX=START_IDX, END_IDX=END_IDX
)

print(f"Input: {text}")
print(f"Output: {translation}")

Input: - 5 2
Output: 0 9 +


In [25]:
# 様々な入力を試してみる

text_list = { '- 2 7':'2 7 -', '- 9 7' : '9 7 -', '+ - 2 7 4' : '2 7 - 4 +', '- - 6 9 - 7 3' : '6 9 - 7 3 - -', '- - + + + + + 4 6 3 7 + 6 7 1 9 + + 3 6 2' : '4 6 + 3 + 7 + 6 7 + + 1 + 9 - 3 6 + 2 + - '}

for text, tgt in text_list.items():
    translation = translate(
        model=best_model, text=text, vocab_src=vocab_src, vocab_tgt=vocab_tgt,
        seq_len_tgt=seq_len_tgt,
        START_IDX=START_IDX, END_IDX=END_IDX
    )
    print(f"Input: {text}")
    print(f"Output: {translation}")
    print(f"Target: {tgt}")
    print('---')




Input: - 2 7
Output: 1 9 +
Target: 2 7 -
---
Input: - 9 7
Output: 1 9 +
Target: 9 7 -
---
Input: + - 2 7 4
Output: 9 7 +
Target: 2 7 - 4 +
---
Input: - - 6 9 - 7 3
Output: 7 9+6 8
Target: 6 9 - 7 3 - -
---
Input: - - + + + + + 4 6 3 7 + 6 7 1 9 + + 3 6 2
Output: 9 9 + 1 
Target: 4 6 + 3 + 7 + 6 7 + + 1 + 9 - 3 6 + 2 + - 
---


## モデルの動作を分析

In [26]:
import torch

# モデルのロード
model_path = model_dir_path.joinpath(version + 'translation_transfomer.pth')
loaded_model = Seq2SeqTransformer(
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    embedding_size=embedding_size,
    vocab_size_src=vocab_size_src, vocab_size_tgt=vocab_size_tgt,
    dim_feedforward=dim_feedforward,
    dropout=dropout, nhead=nhead
).to(device)
loaded_model.load_state_dict(torch.load(model_path))
loaded_model.eval()


Seq2SeqTransformer(
  (token_embedding_src): TokenEmbedding(
    (embedding): Embedding(16, 4, padding_idx=11)
  )
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0, inplace=False)
  )
  (token_embedding_tgt): TokenEmbedding(
    (embedding): Embedding(16, 4, padding_idx=11)
  )
  (decoder_layer): TransformerDecoderLayerScratch(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=4, out_features=4, bias=True)
    )
    (multihead_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=4, out_features=4, bias=True)
    )
    (linear1): Linear(in_features=4, out_features=4, bias=True)
    (dropout): Dropout(p=0, inplace=False)
    (linear2): Linear(in_features=4, out_features=4, bias=True)
    (norm1): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (norm3): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (d

In [27]:
# パラメータを取り出す


# モデルのパラメータを取得
params = dict(loaded_model.named_parameters())
print(params.keys())

# 埋め込み行列を取得
embedding_src_weight = params['token_embedding_src.embedding.weight'].data
embedding_tgt_weight = params['token_embedding_tgt.embedding.weight'].data

# 線形層の重みとバイアス
output_weight = params['output.weight'].data
output_bias = params['output.bias'].data

# デコーダの自己注意の重みとバイアス
self_attn_in_proj_weight = params['decoder_layer.self_attn.in_proj_weight'].data
self_attn_in_proj_bias = params['decoder_layer.self_attn.in_proj_bias'].data
self_attn_out_proj_weight = params['decoder_layer.self_attn.out_proj.weight'].data
self_attn_out_proj_bias = params['decoder_layer.self_attn.out_proj.bias'].data

# メモリー注意の重みとバイアス
multihead_attn_in_proj_weight = params['decoder_layer.multihead_attn.in_proj_weight'].data
multihead_attn_in_proj_bias = params['decoder_layer.multihead_attn.in_proj_bias'].data
multihead_attn_out_proj_weight = params['decoder_layer.multihead_attn.out_proj.weight'].data
multihead_attn_out_proj_bias = params['decoder_layer.multihead_attn.out_proj.bias'].data

# フィードフォワードネットワークの重みとバイアス
linear1_weight = params['decoder_layer.linear1.weight'].data
linear1_bias = params['decoder_layer.linear1.bias'].data
linear2_weight = params['decoder_layer.linear2.weight'].data
linear2_bias = params['decoder_layer.linear2.bias'].data

# LayerNormのパラメータ
norm1_weight = params['decoder_layer.norm1.weight'].data
norm1_bias = params['decoder_layer.norm1.bias'].data
norm2_weight = params['decoder_layer.norm2.weight'].data
norm2_bias = params['decoder_layer.norm2.bias'].data
norm3_weight = params['decoder_layer.norm3.weight'].data
norm3_bias = params['decoder_layer.norm3.bias'].data


dict_keys(['token_embedding_src.embedding.weight', 'token_embedding_tgt.embedding.weight', 'decoder_layer.self_attn.in_proj_weight', 'decoder_layer.self_attn.in_proj_bias', 'decoder_layer.self_attn.out_proj.weight', 'decoder_layer.self_attn.out_proj.bias', 'decoder_layer.multihead_attn.in_proj_weight', 'decoder_layer.multihead_attn.in_proj_bias', 'decoder_layer.multihead_attn.out_proj.weight', 'decoder_layer.multihead_attn.out_proj.bias', 'decoder_layer.linear1.weight', 'decoder_layer.linear1.bias', 'decoder_layer.linear2.weight', 'decoder_layer.linear2.bias', 'decoder_layer.norm1.weight', 'decoder_layer.norm1.bias', 'decoder_layer.norm2.weight', 'decoder_layer.norm2.bias', 'decoder_layer.norm3.weight', 'decoder_layer.norm3.bias', 'output.weight', 'output.bias'])


In [28]:

# Positional Encoding
def positional_encoding(tensor: Tensor, maxlen=5000):
    embedding_size = tensor.size(-1)
    den = torch.exp(-torch.arange(0, embedding_size, 2) * math.log(10000) / embedding_size)
    pos = torch.arange(0, maxlen).reshape(maxlen, 1)
    embedding_pos = torch.zeros((maxlen, embedding_size))
    embedding_pos[:, 0::2] = torch.sin(pos * den)
    embedding_pos[:, 1::2] = torch.cos(pos * den)
    embedding_pos = embedding_pos.unsqueeze(-2)
    return tensor + embedding_pos[: tensor.size(0), :].to(tensor.device)

In [50]:

# 翻訳処理を実行
seq_len_tgt = max([len(x[1]) for x in train_data])
text = '+ 3 1'

tokens_src = convert_text_to_indexes(text, vocab=vocab_src)
src = torch.LongTensor(tokens_src).reshape(len(tokens_src), 1).to(device)
memory = positional_encoding(embedding_src_weight[src] * math.sqrt(embedding_size))
ys = torch.ones(1, 1).fill_(START_IDX).type(torch.long).to(device)

for i in range(10):
    tgt_embed = positional_encoding(embedding_tgt_weight[ys] * math.sqrt(embedding_size))
    tgt_mask = generate_square_subsequent_mask(ys.size(0)).to(device).type(torch.bool)
    
    # Self-attention
    self_attn_wq, self_attn_wk, self_attn_wv = self_attn_in_proj_weight.chunk(3, dim=0)
    self_attn_bq, self_attn_bk, self_attn_bv = self_attn_in_proj_bias.chunk(3, dim=0)
    QW = torch.matmul(tgt_embed.permute(1, 0, 2), self_attn_wq.T) + self_attn_bq
    KW = torch.matmul(tgt_embed.permute(1, 0, 2), self_attn_wk.T) + self_attn_bk
    VW = torch.matmul(tgt_embed.permute(1, 0, 2), self_attn_wv.T) + self_attn_bv
    self_attn_weights = F.softmax(torch.bmm(QW, KW.transpose(-2, -1)) / math.sqrt(embedding_size), dim=-1)

    AV = torch.matmul(self_attn_weights, VW)
    self_attn_output = torch.matmul(AV, self_attn_out_proj_weight.T) + self_attn_out_proj_bias
    self_attn_output = self_attn_output.permute(1, 0, 2)
    tgt = tgt_embed + self_attn_output
    tgt = loaded_model.decoder_layer.norm1(tgt)



    # Attention with the encoder outputs (memory)
    multi_attn_wq, multi_attn_wk, multi_attn_wv = multihead_attn_in_proj_weight.chunk(3, dim=0)
    multi_attn_bq, multi_attn_bk, multi_attn_bv = multihead_attn_in_proj_bias.chunk(3, dim=0)
    QW = torch.matmul(tgt.permute(1, 0, 2), multi_attn_wq.T) + multi_attn_bq
    KW = torch.matmul(memory.permute(1, 0, 2), multi_attn_wk.T) + multi_attn_bk
    VW = torch.matmul(memory.permute(1, 0, 2), multi_attn_wv.T) + multi_attn_bv
    multi_attn_weights = F.softmax(torch.bmm(QW, KW.transpose(-2, -1)) / math.sqrt(embedding_size), dim=-1)

    AV = torch.matmul(multi_attn_weights, VW)
    multi_attn_output = torch.matmul(AV, multihead_attn_out_proj_weight.T) + multihead_attn_out_proj_bias
    multi_attn_output = multi_attn_output.permute(1, 0, 2)
    tgt = tgt + multi_attn_output
    tgt = loaded_model.decoder_layer.norm2(tgt)
    
    # Feedforward network

    # decoder linear1, 2
    tgt2 = tgt.matmul(linear1_weight.T) + linear1_bias
    tgt2 = F.relu(tgt2)
    tgt2 = tgt2.matmul(linear2_weight.T) + linear2_bias
    tgt = tgt + tgt2

    # LayerNorm
    tgt = loaded_model.decoder_layer.norm3(tgt)

    output = tgt.transpose(0, 1)
    output = loaded_model.output(output[:, -1])

    _, next_word = torch.max(output, dim=1)
    next_word = next_word.item()

    ys = torch.cat([ys, torch.ones(1, 1).fill_(next_word).type_as(src.data)], dim=0)
    print(next_word)
    
    if next_word == END_IDX:
        break


flat_indexes = [idx for sublist in ys.tolist() for idx in sublist] if isinstance(ys.tolist()[0], list) else ys.tolist()

print(f"Input: {text}")
print(f"Decoded sequence: {convert_indexes_to_text(flat_indexes, vocab_tgt)}")

15
9
14
13
Input: 28
Decoded sequence: +9 


# Transformerの検算

スクラッチで書くための検算

## Multihead Attention

Multihead Attentionの動作をスクラッチで書きたいので、ここで検算する

参考サイト
https://blog.amedama.jp/entry/pytorch-multi-head-attention-verify

In [31]:
import torch
from torch import nn
import torch.nn.functional as F


In [32]:
edim = 4 # 埋め込み次元
num_heads = 1 # ヘッド数
model = nn.MultiheadAttention(edim, num_heads, bias=True, batch_first=True)

In [33]:
batch_size = 2
L=5
X = torch.randn(batch_size, L, edim) # 入力

Q = K = V = X # クエリ、キー、バリューは全て入力とする
print(Q.shape)
print(Q)

torch.Size([2, 5, 4])
tensor([[[-0.8737,  0.3509, -0.7706, -0.4398],
         [-0.5066,  0.3654, -0.5356,  1.1612],
         [ 0.1350, -0.9153, -0.0212, -0.6847],
         [ 1.8786,  1.4753, -0.0033, -1.3206],
         [ 0.7521, -0.6133,  0.7877,  0.6922]],

        [[-0.1772, -0.7303, -0.5030, -0.4813],
         [-1.5324, -0.5756, -0.1032,  1.3733],
         [-2.1627,  0.4071,  0.5956,  0.5075],
         [-0.6918, -1.3552,  1.5769, -0.4403],
         [ 0.0566,  0.6875,  0.1158,  0.9204]]])


In [34]:

attn_output, attn_output_weights = model(Q, K, V)

print(attn_output.shape)
print(attn_output)



torch.Size([2, 5, 4])
tensor([[[ 0.0030, -0.0006, -0.0164,  0.0271],
         [ 0.0130, -0.0116, -0.0532, -0.0038],
         [ 0.0278, -0.0131,  0.0103,  0.0132],
         [ 0.2010, -0.1152, -0.0115, -0.0817],
         [ 0.1102, -0.0677, -0.0354, -0.0673]],

        [[-0.1606,  0.1109,  0.2430,  0.0691],
         [-0.1860,  0.1146,  0.1056,  0.0759],
         [-0.1871,  0.1079,  0.0172,  0.0665],
         [-0.1323,  0.0840,  0.1321,  0.0306],
         [-0.1746,  0.1039,  0.0569,  0.0632]]], grad_fn=<TransposeBackward0>)


In [35]:
from pprint import pprint
pprint(list(model.named_parameters()))

[('in_proj_weight',
  Parameter containing:
tensor([[-0.2124, -0.0409,  0.4260, -0.2784],
        [ 0.4721, -0.4545, -0.2706,  0.1997],
        [ 0.0011, -0.2175,  0.1762,  0.4944],
        [ 0.5919,  0.4483,  0.4251,  0.2059],
        [ 0.4992,  0.5886,  0.5152, -0.1627],
        [-0.5067,  0.3761, -0.1674,  0.1918],
        [ 0.4766, -0.4809,  0.1981,  0.0830],
        [ 0.5641,  0.2628,  0.4033, -0.3477],
        [-0.5672,  0.0243,  0.0874, -0.1228],
        [-0.1424,  0.0863,  0.3133,  0.2118],
        [ 0.2337,  0.5890, -0.0199,  0.2652],
        [ 0.2667,  0.2908,  0.1397,  0.3063]], requires_grad=True)),
 ('in_proj_bias',
  Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)),
 ('out_proj.weight',
  Parameter containing:
tensor([[-0.1144,  0.0356,  0.1532,  0.3898],
        [ 0.0433,  0.0848,  0.0379, -0.3974],
        [ 0.4518,  0.4360,  0.4493,  0.3711],
        [ 0.2783, -0.4146,  0.1156, -0.4570]], requires_grad=True)),
 ('out_p

In [36]:
model_weight = {name: param.data for name, param in model.named_parameters()}
Wi = model_weight['in_proj_weight']
Wo = model_weight['out_proj.weight']
Wbi = model_weight['in_proj_bias']
Wbo = model_weight['out_proj.bias']

In [37]:
Wi_q, Wi_k, Wi_v = Wi.chunk(3, dim=0)
Wbi_q, Wbi_k, Wbi_v = Wbi.chunk(3, dim=0)
QW = torch.matmul(Q, Wi_q.T) + Wbi_q
KW = torch.matmul(K, Wi_k.T) + Wbi_k
VW = torch.matmul(V, Wi_v.T) + Wbi_v

KW_t = KW.transpose(-2, -1)
QK_t = torch.bmm(QW, KW_t)
QK_scaled = QK_t / (edim ** 0.5)
attn_weights_ = F.softmax(QK_scaled, dim=-1)

In [38]:
print(attn_weights_)
print(attn_output_weights)

tensor([[[0.2608, 0.2614, 0.2039, 0.1061, 0.1679],
         [0.2024, 0.2175, 0.2421, 0.1161, 0.2219],
         [0.2581, 0.2729, 0.1693, 0.1376, 0.1620],
         [0.1369, 0.0953, 0.1231, 0.5251, 0.1195],
         [0.1437, 0.1511, 0.1851, 0.2937, 0.2266]],

        [[0.1542, 0.2936, 0.2743, 0.1230, 0.1550],
         [0.2050, 0.2801, 0.1594, 0.2095, 0.1459],
         [0.2397, 0.1512, 0.1100, 0.3068, 0.1923],
         [0.1890, 0.1431, 0.1588, 0.2444, 0.2647],
         [0.2331, 0.1540, 0.1400, 0.2710, 0.2019]]])
tensor([[[0.2608, 0.2614, 0.2039, 0.1061, 0.1679],
         [0.2024, 0.2175, 0.2421, 0.1161, 0.2219],
         [0.2581, 0.2729, 0.1693, 0.1376, 0.1620],
         [0.1369, 0.0953, 0.1231, 0.5251, 0.1195],
         [0.1437, 0.1511, 0.1851, 0.2937, 0.2266]],

        [[0.1542, 0.2936, 0.2743, 0.1230, 0.1550],
         [0.2050, 0.2801, 0.1594, 0.2095, 0.1459],
         [0.2397, 0.1512, 0.1100, 0.3068, 0.1923],
         [0.1890, 0.1431, 0.1588, 0.2444, 0.2647],
         [0.2331, 0.1540,

In [39]:
AV = torch.matmul(attn_weights_, VW)
attn_output_ = torch.matmul(AV, Wo.T) + Wbo

In [40]:
print(attn_output_)
print(attn_output)

tensor([[[ 0.0030, -0.0006, -0.0164,  0.0271],
         [ 0.0130, -0.0116, -0.0532, -0.0038],
         [ 0.0278, -0.0131,  0.0103,  0.0132],
         [ 0.2010, -0.1152, -0.0115, -0.0817],
         [ 0.1102, -0.0677, -0.0354, -0.0673]],

        [[-0.1606,  0.1109,  0.2430,  0.0691],
         [-0.1860,  0.1146,  0.1056,  0.0759],
         [-0.1871,  0.1079,  0.0172,  0.0665],
         [-0.1323,  0.0840,  0.1321,  0.0306],
         [-0.1746,  0.1039,  0.0569,  0.0632]]])
tensor([[[ 0.0030, -0.0006, -0.0164,  0.0271],
         [ 0.0130, -0.0116, -0.0532, -0.0038],
         [ 0.0278, -0.0131,  0.0103,  0.0132],
         [ 0.2010, -0.1152, -0.0115, -0.0817],
         [ 0.1102, -0.0677, -0.0354, -0.0673]],

        [[-0.1606,  0.1109,  0.2430,  0.0691],
         [-0.1860,  0.1146,  0.1056,  0.0759],
         [-0.1871,  0.1079,  0.0172,  0.0665],
         [-0.1323,  0.0840,  0.1321,  0.0306],
         [-0.1746,  0.1039,  0.0569,  0.0632]]], grad_fn=<TransposeBackward0>)


## nn.Linear

In [41]:
model = nn.Linear(4, 4)
model

Linear(in_features=4, out_features=4, bias=True)

In [42]:
pprint(list(model.named_parameters()))

[('weight',
  Parameter containing:
tensor([[ 0.3715, -0.4210,  0.4472, -0.4051],
        [-0.0006, -0.4907, -0.3896, -0.3565],
        [ 0.1778,  0.0747,  0.0024,  0.3716],
        [ 0.1368,  0.4278,  0.1304, -0.4809]], requires_grad=True)),
 ('bias',
  Parameter containing:
tensor([-0.4060,  0.2613, -0.0672,  0.1253], requires_grad=True))]


In [43]:
model_weight = {name: param.data for name, param in model.named_parameters()}
W = model_weight['weight']
B = model_weight['bias']

X = torch.randn(4) 
print(X.shape)
print(X)
output = model(X)
print(output.shape)
print(output)


torch.Size([4])
tensor([-0.5570, -0.5913,  0.6583,  0.5334])
torch.Size([4])
tensor([-0.2857,  0.1052, -0.0106, -0.3745], grad_fn=<ViewBackward0>)


In [44]:
output_ = X.matmul(W.T) + B
print(output_)
print(output)

tensor([-0.2857,  0.1052, -0.0106, -0.3745])
tensor([-0.2857,  0.1052, -0.0106, -0.3745], grad_fn=<ViewBackward0>)


## nn.LayerNorm

参考サイト
https://qiita.com/dl_from_scratch/items/133fe741b67ed14f1856

In [45]:
model = nn.LayerNorm(4)
model

LayerNorm((4,), eps=1e-05, elementwise_affine=True)

In [46]:
pprint(list(model.named_parameters()))

[('weight', Parameter containing:
tensor([1., 1., 1., 1.], requires_grad=True)),
 ('bias', Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True))]


In [47]:
model_weight = {name: param.data for name, param in model.named_parameters()}
W = model_weight['weight']
B = model_weight['bias']

X = torch.randn(4) 
print(X.shape)
print(X)
output = model(X)
print(output.shape)
print(output)


torch.Size([4])
tensor([ 0.2209,  0.9319, -0.5267,  0.0094])
torch.Size([4])
tensor([ 0.1186,  1.4783, -1.3111, -0.2858], grad_fn=<NativeLayerNormBackward0>)


In [48]:
output_ = X.matmul(W.T) + B
print(output_)
print(output)

tensor([0.6355, 0.6355, 0.6355, 0.6355])
tensor([ 0.1186,  1.4783, -1.3111, -0.2858], grad_fn=<NativeLayerNormBackward0>)


  output_ = X.matmul(W.T) + B
