# FomulaBEAT

変更点
- 9999までの数の偶数奇数を判定するタスク


In [1]:
version = '03-2'
model_dir = './model/' + version
data_path = 'data/eq04-2.txt'

In [2]:
from pathlib import Path
import math
import time
from collections import Counter
from tqdm import tqdm
import torch
from torch.utils.data import random_split
import torch.nn as nn
from torch import Tensor
from torch.nn import (
    TransformerEncoder, TransformerDecoder,
    TransformerEncoderLayer, TransformerDecoderLayer
)
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import vocab
from torchtext.utils import download_from_url, extract_archive
import torch.nn.functional as F




パラメータの事前設定

In [3]:
%load_ext autoreload
%autoreload 2
torch.set_printoptions(linewidth=100)

In [4]:

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model_dir_path = Path(model_dir)
if not model_dir_path.exists():
    model_dir_path.mkdir(parents=True)

データの取得

In [5]:
def read_equation_file(file_path):
    with open(file_path, 'r') as file:
        lines = file.readlines()
    src_data, tgt_data = [], []
    for line in lines:
        src, tgt = line.strip().split(' => ')
        src_data.append(src)
        tgt_data.append(tgt)
    return src_data, tgt_data


In [6]:
# ファイルを読み込み、数式データを取得
src_data, tgt_data = read_equation_file(data_path)
print(src_data[:3], tgt_data[:3])


['8912', '8624', '1989'] ['0', '0', '1']


辞書データの作成

In [7]:

SPECIALS = ['<unk>', '<pad>', '<start>', '<end>']

def build_vocab(texts):
    vocab = {}
    idx = 0
    # 数字の語彙定義
    for i in range(10):
        vocab[str(i)] = idx
        idx += 1
    # 特別語の語彙定義
    for sp in SPECIALS:
        vocab[sp] = idx
        idx += 1
    # その他の文字の語彙定義
    for text in texts:
        for char in text:
            if char not in vocab:
                vocab[char] = idx
                idx += 1
    return vocab


def convert_text_to_indexes(text, vocab):
    # <start> と <end> トークンを追加して数値化
    return [vocab['<start>']] + [vocab[char] if char in vocab else vocab['<unk>'] for char in text] + [vocab['<end>']]

# データを処理して Train と Valid に分ける関数
# データを処理して Train と Valid に分ける関数
def data_process_split(src_texts, tgt_texts, vocab_src, vocab_tgt, valid_size=0.2):
    # データを数値化
    data = []
    for (src, tgt) in zip(src_texts, tgt_texts):
        src_tensor = torch.tensor(convert_text_to_indexes(src, vocab_src), dtype=torch.long)
        tgt_tensor = torch.tensor(convert_text_to_indexes(tgt, vocab_tgt), dtype=torch.long)
        data.append((src_tensor, tgt_tensor))
    
    # データのサイズを計算して、訓練データと検証データに分割
    data_size = len(data)
    valid_size = int(valid_size * data_size)
    train_size = data_size - valid_size

    # PyTorchのrandom_splitを使って分割
    train_data, valid_data = random_split(data, [train_size, valid_size])
    
    return train_data, valid_data



In [8]:
# 辞書と逆辞書を構築
vocab_src = build_vocab(src_data)
vocab_tgt = build_vocab(tgt_data)

print(vocab_tgt)

{'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, '<unk>': 10, '<pad>': 11, '<start>': 12, '<end>': 13}


In [9]:

# データを数値化
train_data, valid_data = data_process_split(src_data, tgt_data, vocab_src, vocab_tgt)

# 結果の確認
print('インデックス化された文章')
print(f"Input: {train_data[0][0]}\nOutput: {train_data[0][1]}")

# インデックスから元の文字列に戻す
def convert_indexes_to_text(indexes:list, vocab):
    reverse_vocab = {idx: token for token, idx in vocab.items()}
    return ''.join([reverse_vocab[idx] for idx in indexes if idx in reverse_vocab and reverse_vocab[idx] not in ['<start>', '<end>', '<pad>']])

print('元に戻した文章')
print(f"Input: {convert_indexes_to_text(train_data[0][0].tolist(), vocab_src)}")
print(f"Output: {convert_indexes_to_text(train_data[0][1].tolist(), vocab_tgt)}")


インデックス化された文章
Input: tensor([12,  9,  1,  3,  3, 13])
Output: tensor([12,  1, 13])
元に戻した文章
Input: 9133
Output: 1


In [10]:
batch_size = 32
PAD_IDX = vocab_src['<pad>']
START_IDX = vocab_src['<start>']
END_IDX = vocab_src['<end>']

def generate_batch(data_batch):
    
    batch_src, batch_tgt = [], []
    for src, tgt in data_batch:
        batch_src.append(src)
        batch_tgt.append(tgt)
        
    batch_src = pad_sequence(batch_src, padding_value=PAD_IDX)
    batch_tgt = pad_sequence(batch_tgt, padding_value=PAD_IDX)
    
    return batch_src, batch_tgt

train_iter = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=generate_batch)
valid_iter = DataLoader(valid_data, batch_size=batch_size, shuffle=True, collate_fn=generate_batch)

In [11]:
len(train_data)

8000

Transoformerの設定

In [12]:
class TokenEmbedding(nn.Module):
    
    def __init__(self, vocab_size, embedding_size):
        
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=PAD_IDX)
        self.embedding_size = embedding_size
        
    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.embedding_size)
    
    
class PositionalEncoding(nn.Module):
    
    def __init__(self, embedding_size: int, dropout: float, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        
        den = torch.exp(-torch.arange(0, embedding_size, 2) * math.log(10000) / embedding_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        embedding_pos = torch.zeros((maxlen, embedding_size))
        embedding_pos[:, 0::2] = torch.sin(pos * den)
        embedding_pos[:, 1::2] = torch.cos(pos * den)
        embedding_pos = embedding_pos.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('embedding_pos', embedding_pos)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.embedding_pos[: token_embedding.size(0), :])


In [13]:

class TransformerDecoderLayerScratch(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1):
        super(TransformerDecoderLayerScratch, self).__init__()
        # Self-attention for the decoder
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Multihead attention for attending to encoder outputs (memory)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Feedforward layers
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        # Layer normalization layers
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        # Dropout
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # Self-attention
        tgt2, _ = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        
        # Attention with the encoder outputs (memory)
        tgt2, _ = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        
        # Feedforward network
        tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)

        return tgt


In [14]:

class Seq2SeqTransformer(nn.Module):
    
    def __init__(
        self, num_encoder_layers: int, num_decoder_layers: int,
        embedding_size: int, vocab_size_src: int, vocab_size_tgt: int,
        dim_feedforward:int = 512, dropout:float = 0.1, nhead:int = 8
    ):
        
        super(Seq2SeqTransformer, self).__init__()

        self.token_embedding_src = TokenEmbedding(vocab_size_src, embedding_size)
        self.positional_encoding = PositionalEncoding(embedding_size, dropout=dropout)
        
        self.token_embedding_tgt = TokenEmbedding(vocab_size_tgt, embedding_size)
        self.decoder_layer = TransformerDecoderLayerScratch(
            d_model=embedding_size, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout
        )
        
        self.output = nn.Linear(embedding_size, vocab_size_tgt)

    def forward(
        self, src: Tensor, tgt: Tensor,
        mask_src: Tensor, mask_tgt: Tensor,
        padding_mask_src: Tensor, padding_mask_tgt: Tensor,
        memory_key_padding_mask: Tensor
    ):
        embedding_src = self.positional_encoding(self.token_embedding_src(src))
        # memory = self.transformer_encoder(embedding_src, mask_src, padding_mask_src)
        embedding_tgt = self.positional_encoding(self.token_embedding_tgt(tgt))
        outs = self.decoder_layer(
            embedding_tgt, embedding_src, mask_tgt, None,
            padding_mask_tgt, memory_key_padding_mask
        )
        return self.output(outs)

    def decode(self, tgt: Tensor, memory: Tensor, mask_tgt: Tensor):
        return self.decoder_layer(self.positional_encoding(self.token_embedding_tgt(tgt)), memory, mask_tgt)

In [15]:
def create_mask(src, tgt, PAD_IDX):
    
    seq_len_src = src.shape[0]
    seq_len_tgt = tgt.shape[0]

    mask_src = torch.zeros((seq_len_src, seq_len_src), device=device).type(torch.bool)
    mask_tgt = generate_square_subsequent_mask(seq_len_tgt)

    padding_mask_src = (src == PAD_IDX).transpose(0, 1)
    padding_mask_tgt = (tgt == PAD_IDX).transpose(0, 1)
    
    return mask_src, mask_tgt, padding_mask_src, padding_mask_tgt


def generate_square_subsequent_mask(seq_len):
    mask = (torch.triu(torch.ones((seq_len, seq_len), device=device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

学習の定義

In [16]:
def train(model, data, optimizer, criterion, PAD_IDX):
    
    model.train()
    losses = 0
    for src, tgt in tqdm(data):
        
        src = src.to(device)
        tgt = tgt.to(device)

        input_tgt = tgt[:-1, :]

        mask_src, mask_tgt, padding_mask_src, padding_mask_tgt = create_mask(src, input_tgt, PAD_IDX)

        logits = model(
            src=src, tgt=input_tgt,
            mask_src=mask_src, mask_tgt=mask_tgt,
            padding_mask_src=padding_mask_src, padding_mask_tgt=padding_mask_tgt,
            memory_key_padding_mask=padding_mask_src
        )

        optimizer.zero_grad()
        output_tgt = tgt[1:, :]
        loss = criterion(logits.reshape(-1, logits.shape[-1]), output_tgt.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()
        
    return losses / len(data)

In [17]:

def evaluate(model, data, criterion, PAD_IDX):
    
    model.eval()
    losses = 0
    for src, tgt in data:
        
        src = src.to(device)
        tgt = tgt.to(device)

        input_tgt = tgt[:-1, :]

        mask_src, mask_tgt, padding_mask_src, padding_mask_tgt = create_mask(src, input_tgt, PAD_IDX)

        logits = model(
            src=src, tgt=input_tgt,
            mask_src=mask_src, mask_tgt=mask_tgt,
            padding_mask_src=padding_mask_src, padding_mask_tgt=padding_mask_tgt,
            memory_key_padding_mask=padding_mask_src
        )
        
        output_tgt = tgt[1:, :]
        loss = criterion(logits.reshape(-1, logits.shape[-1]), output_tgt.reshape(-1))
        losses += loss.item()
        
    return losses / len(data)

設定

In [18]:
vocab_size_src = len(vocab_src)
vocab_size_tgt = len(vocab_tgt)
embedding_size = 4
nhead = 1
dim_feedforward = 4
num_encoder_layers = 1
num_decoder_layers = 1
dropout = 0
# vocab_size_src = len(vocab_src)
# vocab_size_tgt = len(vocab_tgt)
# embedding_size = 240
# nhead = 8
# dim_feedforward = 100
# num_encoder_layers = 2
# num_decoder_layers = 2
# dropout = 0.1

model = Seq2SeqTransformer(
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    embedding_size=embedding_size,
    vocab_size_src=vocab_size_src, vocab_size_tgt=vocab_size_tgt,
    dim_feedforward=dim_feedforward,
    dropout=dropout, nhead=nhead
)

for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

model = model.to(device)

criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(model.parameters())

モデルの調査

In [19]:
print(model)

Seq2SeqTransformer(
  (token_embedding_src): TokenEmbedding(
    (embedding): Embedding(14, 4, padding_idx=11)
  )
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0, inplace=False)
  )
  (token_embedding_tgt): TokenEmbedding(
    (embedding): Embedding(14, 4, padding_idx=11)
  )
  (decoder_layer): TransformerDecoderLayerScratch(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=4, out_features=4, bias=True)
    )
    (multihead_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=4, out_features=4, bias=True)
    )
    (linear1): Linear(in_features=4, out_features=4, bias=True)
    (dropout): Dropout(p=0, inplace=False)
    (linear2): Linear(in_features=4, out_features=4, bias=True)
    (norm1): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (norm3): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (d

In [20]:
# モデル内の層の名前とパラメータ情報を表示
LP = list(model.named_parameters())
lp = len(LP)
print(f"{lp} 層")
for p in range(0, lp):
    print(f"\n層名: {LP[p][0]}")
    print(f"形状: {LP[p][1].shape}")
    print(f"値: {LP[p][1]}")


22 層

層名: token_embedding_src.embedding.weight
形状: torch.Size([14, 4])
値: Parameter containing:
tensor([[-0.2442, -0.5608,  0.1661, -0.0656],
        [-0.2991, -0.0175,  0.3813,  0.2564],
        [ 0.0549,  0.3204, -0.2359, -0.1937],
        [ 0.1486, -0.3486, -0.0201, -0.3452],
        [ 0.5576,  0.4646,  0.4771, -0.3959],
        [ 0.4643, -0.0222,  0.3757, -0.0673],
        [-0.5511,  0.5018,  0.0681,  0.5078],
        [ 0.5068, -0.3201, -0.3315,  0.5730],
        [-0.2852,  0.5019,  0.4931,  0.1591],
        [ 0.5670, -0.5485, -0.3329, -0.1426],
        [ 0.1483, -0.1588, -0.3534, -0.2823],
        [ 0.4688, -0.4443, -0.2801, -0.1368],
        [-0.0400,  0.5703, -0.5489,  0.5169],
        [ 0.4827, -0.5096,  0.3812,  0.0877]], device='cuda:0', requires_grad=True)

層名: token_embedding_tgt.embedding.weight
形状: torch.Size([14, 4])
値: Parameter containing:
tensor([[ 0.1954, -0.5407, -0.2756,  0.0025],
        [ 0.4609,  0.4682,  0.1350,  0.5624],
        [-0.3449, -0.5731,  0.0433, -0.

## 学習実行

In [21]:
epoch = 100
best_loss = float('Inf')
best_model = None
patience = 10
counter = 0

for loop in range(1, epoch + 1):
    
    start_time = time.time()
    
    loss_train = train(
        model=model, data=train_iter, optimizer=optimizer,
        criterion=criterion, PAD_IDX=PAD_IDX
    )
    
    elapsed_time = time.time() - start_time
    
    loss_valid = evaluate(
        model=model, data=valid_iter, criterion=criterion, PAD_IDX=PAD_IDX
    )
    
    print('[{}/{}] train loss: {:.2f}, valid loss: {:.2f}  [{}{:.0f}s] counter: {} {}'.format(
        loop, epoch,
        loss_train, loss_valid,
        str(int(math.floor(elapsed_time / 60))) + 'm' if math.floor(elapsed_time / 60) > 0 else '',
        elapsed_time % 60,
        counter,
        '**' if best_loss > loss_valid else ''
    ))
    
    if best_loss > loss_valid:
        best_loss = loss_valid
        best_model = model
        counter = 0
        
    if counter > patience:
        break
    
    counter += 1

100%|██████████| 250/250 [00:02<00:00, 85.86it/s]


[1/100] train loss: 1.18, valid loss: 0.64  [3s] counter: 0 **


100%|██████████| 250/250 [00:02<00:00, 90.91it/s]


[2/100] train loss: 0.47, valid loss: 0.30  [3s] counter: 1 **


100%|██████████| 250/250 [00:02<00:00, 90.63it/s]


[3/100] train loss: 0.18, valid loss: 0.11  [3s] counter: 1 **


100%|██████████| 250/250 [00:02<00:00, 91.00it/s]


[4/100] train loss: 0.09, valid loss: 0.06  [3s] counter: 1 **


100%|██████████| 250/250 [00:02<00:00, 91.22it/s]


[5/100] train loss: 0.05, valid loss: 0.04  [3s] counter: 1 **


100%|██████████| 250/250 [00:02<00:00, 91.09it/s]


[6/100] train loss: 0.03, valid loss: 0.03  [3s] counter: 1 **


100%|██████████| 250/250 [00:02<00:00, 91.15it/s]


[7/100] train loss: 0.02, valid loss: 0.01  [3s] counter: 1 **


100%|██████████| 250/250 [00:02<00:00, 94.62it/s] 


[8/100] train loss: 0.01, valid loss: 0.01  [3s] counter: 1 **


100%|██████████| 250/250 [00:02<00:00, 94.34it/s]


[9/100] train loss: 0.01, valid loss: 0.01  [3s] counter: 1 **


100%|██████████| 250/250 [00:02<00:00, 91.01it/s]


[10/100] train loss: 0.01, valid loss: 0.01  [3s] counter: 1 **


100%|██████████| 250/250 [00:02<00:00, 91.25it/s]


[11/100] train loss: 0.01, valid loss: 0.01  [3s] counter: 1 **


100%|██████████| 250/250 [00:02<00:00, 91.18it/s]


[12/100] train loss: 0.01, valid loss: 0.01  [3s] counter: 1 **


100%|██████████| 250/250 [00:02<00:00, 91.17it/s]


[13/100] train loss: 0.01, valid loss: 0.00  [3s] counter: 1 **


100%|██████████| 250/250 [00:02<00:00, 91.17it/s]


[14/100] train loss: 0.00, valid loss: 0.00  [3s] counter: 1 **


100%|██████████| 250/250 [00:02<00:00, 91.11it/s]


[15/100] train loss: 0.00, valid loss: 0.01  [3s] counter: 1 


100%|██████████| 250/250 [00:02<00:00, 91.08it/s]


[16/100] train loss: 0.00, valid loss: 0.00  [3s] counter: 2 


100%|██████████| 250/250 [00:02<00:00, 91.37it/s]


[17/100] train loss: 0.00, valid loss: 0.00  [3s] counter: 3 **


100%|██████████| 250/250 [00:02<00:00, 91.01it/s]


[18/100] train loss: 0.00, valid loss: 0.00  [3s] counter: 1 


100%|██████████| 250/250 [00:01<00:00, 136.39it/s]


[19/100] train loss: 0.00, valid loss: 0.00  [2s] counter: 2 **


100%|██████████| 250/250 [00:00<00:00, 264.63it/s]


[20/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 1 


100%|██████████| 250/250 [00:00<00:00, 264.48it/s]


[21/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 2 **


100%|██████████| 250/250 [00:00<00:00, 263.01it/s]


[22/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 1 


100%|██████████| 250/250 [00:00<00:00, 263.25it/s]


[23/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 2 **


100%|██████████| 250/250 [00:00<00:00, 265.22it/s]


[24/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 1 


100%|██████████| 250/250 [00:00<00:00, 264.30it/s]


[25/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 2 **


100%|██████████| 250/250 [00:00<00:00, 264.81it/s]


[26/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 1 **


100%|██████████| 250/250 [00:00<00:00, 262.51it/s]


[27/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 1 


100%|██████████| 250/250 [00:00<00:00, 262.99it/s]


[28/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 2 


100%|██████████| 250/250 [00:00<00:00, 264.03it/s]


[29/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 3 **


100%|██████████| 250/250 [00:00<00:00, 265.37it/s]


[30/100] train loss: 0.00, valid loss: 0.01  [1s] counter: 1 


100%|██████████| 250/250 [00:00<00:00, 264.25it/s]


[31/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 2 


100%|██████████| 250/250 [00:00<00:00, 264.05it/s]


[32/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 3 **


100%|██████████| 250/250 [00:00<00:00, 263.56it/s]


[33/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 1 


100%|██████████| 250/250 [00:00<00:00, 264.19it/s]


[34/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 2 **


100%|██████████| 250/250 [00:00<00:00, 264.39it/s]


[35/100] train loss: 0.00, valid loss: 0.01  [1s] counter: 1 


100%|██████████| 250/250 [00:00<00:00, 262.99it/s]


[36/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 2 **


100%|██████████| 250/250 [00:00<00:00, 262.87it/s]


[37/100] train loss: 0.01, valid loss: 0.00  [1s] counter: 1 


100%|██████████| 250/250 [00:00<00:00, 264.85it/s]


[38/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 2 


100%|██████████| 250/250 [00:00<00:00, 263.09it/s]


[39/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 3 


100%|██████████| 250/250 [00:00<00:00, 264.17it/s]


[40/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 4 


100%|██████████| 250/250 [00:00<00:00, 262.68it/s]


[41/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 5 


100%|██████████| 250/250 [00:00<00:00, 263.21it/s]


[42/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 6 **


100%|██████████| 250/250 [00:00<00:00, 264.00it/s]


[43/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 1 


100%|██████████| 250/250 [00:00<00:00, 265.12it/s]


[44/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 2 **


100%|██████████| 250/250 [00:00<00:00, 263.99it/s]


[45/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 1 **


100%|██████████| 250/250 [00:00<00:00, 264.57it/s]


[46/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 1 


100%|██████████| 250/250 [00:00<00:00, 263.97it/s]


[47/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 2 


100%|██████████| 250/250 [00:00<00:00, 264.65it/s]


[48/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 3 


100%|██████████| 250/250 [00:00<00:00, 264.04it/s]


[49/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 4 **


100%|██████████| 250/250 [00:00<00:00, 263.18it/s]


[50/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 1 


100%|██████████| 250/250 [00:00<00:00, 262.90it/s]


[51/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 2 


100%|██████████| 250/250 [00:00<00:00, 264.52it/s]


[52/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 3 


100%|██████████| 250/250 [00:00<00:00, 264.41it/s]


[53/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 4 **


100%|██████████| 250/250 [00:00<00:00, 264.62it/s]


[54/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 1 **


100%|██████████| 250/250 [00:00<00:00, 262.61it/s]


[55/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 1 **


100%|██████████| 250/250 [00:00<00:00, 264.70it/s]


[56/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 1 **


100%|██████████| 250/250 [00:00<00:00, 264.19it/s]


[57/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 1 **


100%|██████████| 250/250 [00:00<00:00, 264.69it/s]


[58/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 1 **


100%|██████████| 250/250 [00:00<00:00, 264.37it/s]


[59/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 1 **


100%|██████████| 250/250 [00:00<00:00, 263.27it/s]


[60/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 1 **


100%|██████████| 250/250 [00:00<00:00, 264.11it/s]


[61/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 1 **


100%|██████████| 250/250 [00:00<00:00, 264.47it/s]


[62/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 1 **


100%|██████████| 250/250 [00:00<00:00, 264.11it/s]


[63/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 1 


100%|██████████| 250/250 [00:00<00:00, 262.88it/s]


[64/100] train loss: 0.01, valid loss: 0.00  [1s] counter: 2 


100%|██████████| 250/250 [00:00<00:00, 263.37it/s]


[65/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 3 


100%|██████████| 250/250 [00:00<00:00, 264.12it/s]


[66/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 4 


100%|██████████| 250/250 [00:00<00:00, 263.86it/s]


[67/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 5 


100%|██████████| 250/250 [00:00<00:00, 263.59it/s]


[68/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 6 


100%|██████████| 250/250 [00:00<00:00, 263.90it/s]


[69/100] train loss: 0.00, valid loss: 0.00  [1s] counter: 7 


100%|██████████| 250/250 [00:02<00:00, 104.77it/s]


[70/100] train loss: 0.00, valid loss: 0.00  [2s] counter: 8 


100%|██████████| 250/250 [00:02<00:00, 91.13it/s]


[71/100] train loss: 0.00, valid loss: 0.00  [3s] counter: 9 


100%|██████████| 250/250 [00:02<00:00, 91.38it/s]


[72/100] train loss: 0.00, valid loss: 0.00  [3s] counter: 10 


100%|██████████| 250/250 [00:02<00:00, 91.37it/s]


[73/100] train loss: 0.00, valid loss: 0.00  [3s] counter: 11 


学習したモデルの保存

In [22]:
torch.save(best_model.state_dict(), model_dir_path.joinpath(version + 'translation_transfomer.pth'))

学習したモデルを使って翻訳をする

In [23]:
def translate(
    model, text, vocab_src, vocab_tgt, seq_len_tgt,
    START_IDX, END_IDX
):
    model.eval()
    tokens_src = convert_text_to_indexes(text, vocab=vocab_src)
    num_tokens_src = len(tokens_src)

    # Tensorに変換
    src = torch.LongTensor(tokens_src).reshape(num_tokens_src, 1).to(device)
    mask_src = torch.zeros((num_tokens_src, num_tokens_src), device=device).type(torch.bool)

    # デコード
    predicts = greedy_decode(
        model=model, src=src,
        mask_src=mask_src, seq_len_tgt=seq_len_tgt,
        START_IDX=START_IDX, END_IDX=END_IDX
    ).flatten()

    return convert_indexes_to_text(predicts.tolist(), vocab=vocab_tgt)

def greedy_decode(model, src, mask_src, seq_len_tgt, START_IDX, END_IDX):
    src = src.to(device)
    mask_src = mask_src.to(device)

    # ソースの埋め込みをメモリとして利用
    memory = model.positional_encoding(model.token_embedding_src(src))
    
    ys = torch.ones(1, 1).fill_(START_IDX).type(torch.long).to(device)
    
    for i in range(seq_len_tgt - 1):
        memory = memory.to(device)
        mask_tgt = generate_square_subsequent_mask(ys.size(0)).to(device).type(torch.bool)
        
        output = model.decode(ys, memory, mask_tgt)
        output = output.transpose(0, 1)
        output = model.output(output[:, -1])
        
        # 最も高いスコアのトークンを取得
        _, next_word = torch.max(output, dim=1)
        next_word = next_word.item()

        # 生成されたトークンを追加
        ys = torch.cat([ys, torch.ones(1, 1).fill_(next_word).type_as(src.data)], dim=0)
        if next_word == END_IDX:
            break
    
    return ys


In [50]:
seq_len_tgt = max([len(x[1]) for x in train_data])
text = '20'

# 翻訳を実行
translation = translate(
    model=best_model, text=text, vocab_src=vocab_src, vocab_tgt=vocab_tgt,
    seq_len_tgt=seq_len_tgt,
    START_IDX=START_IDX, END_IDX=END_IDX
)

print(f"Input: {text}")
print(f"Output: {translation}")

Input: 20
Output: 0


In [52]:
# 様々な入力を試してみる

text_list = ["1", "2", "3", "11", "5009", "200", "9999", "200033", "200004448", "20000006699"]

for text in text_list:
    translation = translate(
        model=best_model, text=text, vocab_src=vocab_src, vocab_tgt=vocab_tgt,
        seq_len_tgt=seq_len_tgt,
        START_IDX=START_IDX, END_IDX=END_IDX
    )
    print(f"Input: {text}")
    print(f"Output: {translation}")




Input: 9000
Output: 0
Input: 9011
Output: 1
Input: 9022
Output: 0
Input: 9033
Output: 1
Input: 9044
Output: 0
Input: 9055
Output: 1
Input: 9066
Output: 0
Input: 9077
Output: 1
Input: 9088
Output: 0
Input: 9099
Output: 1
Input: 9110
Output: 0
Input: 9121
Output: 1
Input: 9132
Output: 0
Input: 9143
Output: 1
Input: 9154
Output: 0
Input: 9165
Output: 1
Input: 9176
Output: 0
Input: 9187
Output: 1
Input: 9198
Output: 0
Input: 9209
Output: 1
Input: 9220
Output: 0
Input: 9231
Output: 1
Input: 9242
Output: 0
Input: 9253
Output: 1
Input: 9264
Output: 0
Input: 9275
Output: 1
Input: 9286
Output: 0
Input: 9297
Output: 1
Input: 9308
Output: 0
Input: 9319
Output: 1
Input: 9330
Output: 0
Input: 9341
Output: 1
Input: 9352
Output: 0
Input: 9363
Output: 1
Input: 9374
Output: 0
Input: 9385
Output: 1
Input: 9396
Output: 0
Input: 9407
Output: 1
Input: 9418
Output: 0
Input: 9429
Output: 1
Input: 9440
Output: 0
Input: 9451
Output: 1
Input: 9462
Output: 0
Input: 9473
Output: 1
Input: 9484
Output: 0
Input: 949

In [60]:
# 網羅性チェック
# 0-99999までの数値の偶数奇数を判断して、ミスがないか確認

count = 0
for text in range(0, 20000):
    translation = translate(
        model=best_model, text=str(text), vocab_src=vocab_src, vocab_tgt=vocab_tgt,
        seq_len_tgt=seq_len_tgt,
        START_IDX=START_IDX, END_IDX=END_IDX
    )
    # 偶数かつ奇数として判断された場合のみ表示
    if text % 2 == 0 and translation == '1':
        print(f"Input: {text}/{translation}")
        count += 1
    elif text % 2 == 1 and translation == '0':
        print(f"Input: {text}/{translation}")
        count += 1

# 10000を超えると、翻訳が正しく行われない
print(count)

Input: 10001
Output: 0
Input: 10003
Output: 0
Input: 10005
Output: 0
Input: 10007
Output: 0
Input: 10009
Output: 0
Input: 10021
Output: 0
Input: 10023
Output: 0
Input: 10025
Output: 0
Input: 10027
Output: 0
Input: 10029
Output: 0
Input: 10041
Output: 0
Input: 10043
Output: 0
Input: 10045
Output: 0
Input: 10047
Output: 0
Input: 10049
Output: 0
Input: 10061
Output: 0
Input: 10063
Output: 0
Input: 10065
Output: 0
Input: 10067
Output: 0
Input: 10069
Output: 0
Input: 10081
Output: 0
Input: 10083
Output: 0
Input: 10085
Output: 0
Input: 10087
Output: 0
Input: 10089
Output: 0
Input: 10101
Output: 0
Input: 10103
Output: 0
Input: 10105
Output: 0
Input: 10107
Output: 0
Input: 10109
Output: 0
Input: 10121
Output: 0
Input: 10123
Output: 0
Input: 10125
Output: 0
Input: 10127
Output: 0
Input: 10129
Output: 0
Input: 10141
Output: 0
Input: 10143
Output: 0
Input: 10145
Output: 0
Input: 10147
Output: 0
Input: 10149
Output: 0
Input: 10161
Output: 0
Input: 10163
Output: 0
Input: 10165
Output: 0
Input: 1016

KeyboardInterrupt: 

In [61]:
# 網羅性チェック
# 0-99999までの数値の偶数奇数を判断して、ミスがないか確認

count = 0
for text in range(10000, 20000):
    translation = translate(
        model=best_model, text=str(text), vocab_src=vocab_src, vocab_tgt=vocab_tgt,
        seq_len_tgt=seq_len_tgt,
        START_IDX=START_IDX, END_IDX=END_IDX
    )
    # 偶数かつ奇数として判断された場合のみ表示
    if text % 2 == 0 and translation == '1':
        print(f"Input: {text}/{translation}")
        count += 1
    elif text % 2 == 1 and translation == '0':
        print(f"Input: {text}/{translation}")
        count += 1

# 10000を超えると、翻訳が正しく行われない
print(count)

Input: 10001/0
Input: 10003/0
Input: 10005/0
Input: 10007/0
Input: 10009/0
Input: 10021/0
Input: 10023/0
Input: 10025/0
Input: 10027/0
Input: 10029/0
Input: 10041/0
Input: 10043/0
Input: 10045/0
Input: 10047/0
Input: 10049/0
Input: 10061/0
Input: 10063/0
Input: 10065/0
Input: 10067/0
Input: 10069/0
Input: 10081/0
Input: 10083/0
Input: 10085/0
Input: 10087/0
Input: 10089/0
Input: 10101/0
Input: 10103/0
Input: 10105/0
Input: 10107/0
Input: 10109/0
Input: 10121/0
Input: 10123/0
Input: 10125/0
Input: 10127/0
Input: 10129/0
Input: 10141/0
Input: 10143/0
Input: 10145/0
Input: 10147/0
Input: 10149/0
Input: 10161/0
Input: 10163/0
Input: 10165/0
Input: 10167/0
Input: 10169/0
Input: 10181/0
Input: 10183/0
Input: 10185/0
Input: 10187/0
Input: 10189/0
Input: 10201/0
Input: 10203/0
Input: 10205/0
Input: 10207/0
Input: 10209/0
Input: 10221/0
Input: 10223/0
Input: 10225/0
Input: 10227/0
Input: 10229/0
Input: 10241/0
Input: 10243/0
Input: 10245/0
Input: 10247/0
Input: 10249/0
Input: 10261/0
Input: 102

## モデルの動作を分析

In [26]:
import torch

# モデルのロード
model_path = model_dir_path.joinpath(version + 'translation_transfomer.pth')
loaded_model = Seq2SeqTransformer(
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    embedding_size=embedding_size,
    vocab_size_src=vocab_size_src, vocab_size_tgt=vocab_size_tgt,
    dim_feedforward=dim_feedforward,
    dropout=dropout, nhead=nhead
).to(device)
loaded_model.load_state_dict(torch.load(model_path))
loaded_model.eval()


Seq2SeqTransformer(
  (token_embedding_src): TokenEmbedding(
    (embedding): Embedding(14, 4, padding_idx=11)
  )
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0, inplace=False)
  )
  (token_embedding_tgt): TokenEmbedding(
    (embedding): Embedding(14, 4, padding_idx=11)
  )
  (decoder_layer): TransformerDecoderLayerScratch(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=4, out_features=4, bias=True)
    )
    (multihead_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=4, out_features=4, bias=True)
    )
    (linear1): Linear(in_features=4, out_features=4, bias=True)
    (dropout): Dropout(p=0, inplace=False)
    (linear2): Linear(in_features=4, out_features=4, bias=True)
    (norm1): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (norm3): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (d

In [27]:
# パラメータを取り出す


# モデルのパラメータを取得
params = dict(loaded_model.named_parameters())
print(params.keys())

# 埋め込み行列を取得
embedding_src_weight = params['token_embedding_src.embedding.weight'].data
embedding_tgt_weight = params['token_embedding_tgt.embedding.weight'].data

# 線形層の重みとバイアス
output_weight = params['output.weight'].data
output_bias = params['output.bias'].data

# デコーダの自己注意の重みとバイアス
self_attn_in_proj_weight = params['decoder_layer.self_attn.in_proj_weight'].data
self_attn_in_proj_bias = params['decoder_layer.self_attn.in_proj_bias'].data
self_attn_out_proj_weight = params['decoder_layer.self_attn.out_proj.weight'].data
self_attn_out_proj_bias = params['decoder_layer.self_attn.out_proj.bias'].data

# メモリー注意の重みとバイアス
multihead_attn_in_proj_weight = params['decoder_layer.multihead_attn.in_proj_weight'].data
multihead_attn_in_proj_bias = params['decoder_layer.multihead_attn.in_proj_bias'].data
multihead_attn_out_proj_weight = params['decoder_layer.multihead_attn.out_proj.weight'].data
multihead_attn_out_proj_bias = params['decoder_layer.multihead_attn.out_proj.bias'].data

# フィードフォワードネットワークの重みとバイアス
linear1_weight = params['decoder_layer.linear1.weight'].data
linear1_bias = params['decoder_layer.linear1.bias'].data
linear2_weight = params['decoder_layer.linear2.weight'].data
linear2_bias = params['decoder_layer.linear2.bias'].data

# LayerNormのパラメータ
norm1_weight = params['decoder_layer.norm1.weight'].data
norm1_bias = params['decoder_layer.norm1.bias'].data
norm2_weight = params['decoder_layer.norm2.weight'].data
norm2_bias = params['decoder_layer.norm2.bias'].data
norm3_weight = params['decoder_layer.norm3.weight'].data
norm3_bias = params['decoder_layer.norm3.bias'].data


dict_keys(['token_embedding_src.embedding.weight', 'token_embedding_tgt.embedding.weight', 'decoder_layer.self_attn.in_proj_weight', 'decoder_layer.self_attn.in_proj_bias', 'decoder_layer.self_attn.out_proj.weight', 'decoder_layer.self_attn.out_proj.bias', 'decoder_layer.multihead_attn.in_proj_weight', 'decoder_layer.multihead_attn.in_proj_bias', 'decoder_layer.multihead_attn.out_proj.weight', 'decoder_layer.multihead_attn.out_proj.bias', 'decoder_layer.linear1.weight', 'decoder_layer.linear1.bias', 'decoder_layer.linear2.weight', 'decoder_layer.linear2.bias', 'decoder_layer.norm1.weight', 'decoder_layer.norm1.bias', 'decoder_layer.norm2.weight', 'decoder_layer.norm2.bias', 'decoder_layer.norm3.weight', 'decoder_layer.norm3.bias', 'output.weight', 'output.bias'])


In [28]:

# Positional Encoding
def positional_encoding(tensor: Tensor, maxlen=5000):
    embedding_size = tensor.size(-1)
    den = torch.exp(-torch.arange(0, embedding_size, 2) * math.log(10000) / embedding_size)
    pos = torch.arange(0, maxlen).reshape(maxlen, 1)
    embedding_pos = torch.zeros((maxlen, embedding_size))
    embedding_pos[:, 0::2] = torch.sin(pos * den)
    embedding_pos[:, 1::2] = torch.cos(pos * den)
    embedding_pos = embedding_pos.unsqueeze(-2)
    return tensor + embedding_pos[: tensor.size(0), :].to(tensor.device)

In [59]:

# 翻訳処理を実行
seq_len_tgt = max([len(x[1]) for x in train_data])
text = '999999'

tokens_src = convert_text_to_indexes(text, vocab=vocab_src)
src = torch.LongTensor(tokens_src).reshape(len(tokens_src), 1).to(device)
memory = positional_encoding(embedding_src_weight[src] * math.sqrt(embedding_size))
ys = torch.ones(1, 1).fill_(START_IDX).type(torch.long).to(device)

for i in range(10):
    tgt_embed = positional_encoding(embedding_tgt_weight[ys] * math.sqrt(embedding_size))
    tgt_mask = generate_square_subsequent_mask(ys.size(0)).to(device).type(torch.bool)
    
    # Self-attention
    self_attn_wq, self_attn_wk, self_attn_wv = self_attn_in_proj_weight.chunk(3, dim=0)
    self_attn_bq, self_attn_bk, self_attn_bv = self_attn_in_proj_bias.chunk(3, dim=0)
    QW = torch.matmul(tgt_embed.permute(1, 0, 2), self_attn_wq.T) + self_attn_bq
    KW = torch.matmul(tgt_embed.permute(1, 0, 2), self_attn_wk.T) + self_attn_bk
    VW = torch.matmul(tgt_embed.permute(1, 0, 2), self_attn_wv.T) + self_attn_bv
    self_attn_weights = F.softmax(torch.bmm(QW, KW.transpose(-2, -1)) / math.sqrt(embedding_size), dim=-1)

    AV = torch.matmul(self_attn_weights, VW)
    self_attn_output = torch.matmul(AV, self_attn_out_proj_weight.T) + self_attn_out_proj_bias
    self_attn_output = self_attn_output.permute(1, 0, 2)
    tgt = tgt_embed + self_attn_output
    tgt = loaded_model.decoder_layer.norm1(tgt)



    # Attention with the encoder outputs (memory)
    multi_attn_wq, multi_attn_wk, multi_attn_wv = multihead_attn_in_proj_weight.chunk(3, dim=0)
    multi_attn_bq, multi_attn_bk, multi_attn_bv = multihead_attn_in_proj_bias.chunk(3, dim=0)
    QW = torch.matmul(tgt.permute(1, 0, 2), multi_attn_wq.T) + multi_attn_bq
    KW = torch.matmul(memory.permute(1, 0, 2), multi_attn_wk.T) + multi_attn_bk
    VW = torch.matmul(memory.permute(1, 0, 2), multi_attn_wv.T) + multi_attn_bv
    multi_attn_weights = F.softmax(torch.bmm(QW, KW.transpose(-2, -1)) / math.sqrt(embedding_size), dim=-1)

    AV = torch.matmul(multi_attn_weights, VW)
    multi_attn_output = torch.matmul(AV, multihead_attn_out_proj_weight.T) + multihead_attn_out_proj_bias
    multi_attn_output = multi_attn_output.permute(1, 0, 2)
    tgt = tgt + multi_attn_output
    tgt = loaded_model.decoder_layer.norm2(tgt)
    
    # Feedforward network

    # decoder linear1, 2
    tgt2 = tgt.matmul(linear1_weight.T) + linear1_bias
    tgt2 = F.relu(tgt2)
    tgt2 = tgt2.matmul(linear2_weight.T) + linear2_bias
    tgt = tgt + tgt2

    # LayerNorm
    tgt = loaded_model.decoder_layer.norm3(tgt)

    output = tgt.transpose(0, 1)
    output = loaded_model.output(output[:, -1])

    _, next_word = torch.max(output, dim=1)
    next_word = next_word.item()

    ys = torch.cat([ys, torch.ones(1, 1).fill_(next_word).type_as(src.data)], dim=0)
    print(next_word)
    
    if next_word == END_IDX:
        break


flat_indexes = [idx for sublist in ys.tolist() for idx in sublist] if isinstance(ys.tolist()[0], list) else ys.tolist()

print(f"Input: {text}")
print(f"Decoded sequence: {convert_indexes_to_text(flat_indexes, vocab_tgt)}")

1
13
Input: 999999
Decoded sequence: 1


# Transformerの検算

スクラッチで書くための検算

## Multihead Attention

Multihead Attentionの動作をスクラッチで書きたいので、ここで検算する

参考サイト
https://blog.amedama.jp/entry/pytorch-multi-head-attention-verify

In [30]:
import torch
from torch import nn
import torch.nn.functional as F


In [31]:
edim = 4 # 埋め込み次元
num_heads = 1 # ヘッド数
model = nn.MultiheadAttention(edim, num_heads, bias=True, batch_first=True)

In [32]:
batch_size = 2
L=5
X = torch.randn(batch_size, L, edim) # 入力

Q = K = V = X # クエリ、キー、バリューは全て入力とする
print(Q.shape)
print(Q)

torch.Size([2, 5, 4])
tensor([[[-0.6987, -0.3834,  1.2682, -0.9115],
         [ 0.1152,  0.0263,  1.2358,  0.0605],
         [ 0.3132, -0.8764,  0.0898, -1.3914],
         [-0.4325, -0.9414,  0.4386, -1.1127],
         [-1.1277,  0.8637, -0.3704, -0.0910]],

        [[ 2.1964, -1.9274, -0.2268,  0.4893],
         [-0.2748,  1.1336,  0.2190, -0.6835],
         [-0.1141, -0.4967, -1.3600,  0.2523],
         [ 0.3653,  2.9546,  0.3290, -0.3951],
         [ 0.8593,  0.6910, -0.5978,  1.4734]]])


In [33]:

attn_output, attn_output_weights = model(Q, K, V)

print(attn_output.shape)
print(attn_output)



torch.Size([2, 5, 4])
tensor([[[ 0.2059, -0.4056,  0.2160, -0.0673],
         [ 0.1822, -0.3807,  0.2070, -0.0473],
         [ 0.1560, -0.4002,  0.2226, -0.0127],
         [ 0.1659, -0.3980,  0.2178, -0.0270],
         [ 0.1697, -0.3473,  0.1753, -0.0671]],

        [[-0.9989,  0.7810, -0.2646,  0.8636],
         [ 0.3898,  0.4143, -0.4019, -0.7050],
         [-0.5618,  0.6062, -0.2746,  0.3864],
         [ 0.5713,  0.4507, -0.4685, -0.9366],
         [-0.6284,  0.6691, -0.2853,  0.4612]]], grad_fn=<TransposeBackward0>)


In [34]:
from pprint import pprint
pprint(list(model.named_parameters()))

[('in_proj_weight',
  Parameter containing:
tensor([[ 1.6164e-01, -5.7244e-01,  9.6066e-03, -5.3087e-01],
        [ 3.5323e-01, -2.8011e-01, -2.6955e-01,  5.0069e-01],
        [-4.8172e-01,  5.8581e-01, -3.8958e-02,  2.4328e-01],
        [ 4.8875e-01, -5.1309e-01,  4.8611e-01, -5.8679e-01],
        [-1.9868e-01, -3.2428e-01,  6.5746e-03, -5.0330e-01],
        [ 5.2913e-01, -5.8652e-01,  2.9461e-02,  1.4236e-01],
        [ 1.6880e-01,  3.3092e-01,  2.6239e-01, -6.1039e-01],
        [ 3.1390e-04,  2.2849e-01,  4.5257e-01, -1.0142e-01],
        [-5.8239e-01,  4.9897e-01,  5.1502e-01,  3.9330e-02],
        [ 3.2820e-01,  4.2034e-01, -5.4201e-01, -4.6981e-01],
        [-4.5807e-01, -4.7335e-01, -2.1864e-01, -3.4159e-01],
        [-2.4120e-01,  4.2626e-01, -2.9527e-01, -5.8191e-01]], requires_grad=True)),
 ('in_proj_bias',
  Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)),
 ('out_proj.weight',
  Parameter containing:
tensor([[ 0.4081, -0.33

In [35]:
model_weight = {name: param.data for name, param in model.named_parameters()}
Wi = model_weight['in_proj_weight']
Wo = model_weight['out_proj.weight']
Wbi = model_weight['in_proj_bias']
Wbo = model_weight['out_proj.bias']

In [36]:
Wi_q, Wi_k, Wi_v = Wi.chunk(3, dim=0)
Wbi_q, Wbi_k, Wbi_v = Wbi.chunk(3, dim=0)
QW = torch.matmul(Q, Wi_q.T) + Wbi_q
KW = torch.matmul(K, Wi_k.T) + Wbi_k
VW = torch.matmul(V, Wi_v.T) + Wbi_v

KW_t = KW.transpose(-2, -1)
QK_t = torch.bmm(QW, KW_t)
QK_scaled = QK_t / (edim ** 0.5)
attn_weights_ = F.softmax(QK_scaled, dim=-1)

In [37]:
print(attn_weights_)
print(attn_output_weights)

tensor([[[0.2550, 0.1756, 0.1428, 0.1795, 0.2471],
         [0.2207, 0.2139, 0.1668, 0.1815, 0.2171],
         [0.2623, 0.1733, 0.1690, 0.2222, 0.1732],
         [0.2570, 0.1694, 0.1658, 0.2107, 0.1972],
         [0.1773, 0.1757, 0.1878, 0.1696, 0.2896]],

        [[0.7347, 0.0297, 0.1307, 0.0082, 0.0968],
         [0.0493, 0.2572, 0.1150, 0.4460, 0.1325],
         [0.4085, 0.1053, 0.2311, 0.0691, 0.1860],
         [0.0236, 0.1909, 0.0642, 0.6030, 0.1183],
         [0.4714, 0.0651, 0.1643, 0.0633, 0.2359]]])
tensor([[[0.2550, 0.1756, 0.1428, 0.1795, 0.2471],
         [0.2207, 0.2139, 0.1668, 0.1815, 0.2171],
         [0.2623, 0.1733, 0.1690, 0.2222, 0.1732],
         [0.2570, 0.1694, 0.1658, 0.2107, 0.1972],
         [0.1773, 0.1757, 0.1878, 0.1696, 0.2896]],

        [[0.7347, 0.0297, 0.1307, 0.0082, 0.0968],
         [0.0493, 0.2572, 0.1150, 0.4460, 0.1325],
         [0.4085, 0.1053, 0.2311, 0.0691, 0.1860],
         [0.0236, 0.1909, 0.0642, 0.6030, 0.1183],
         [0.4714, 0.0651,

In [38]:
AV = torch.matmul(attn_weights_, VW)
attn_output_ = torch.matmul(AV, Wo.T) + Wbo

In [39]:
print(attn_output_)
print(attn_output)

tensor([[[ 0.2059, -0.4056,  0.2160, -0.0673],
         [ 0.1822, -0.3807,  0.2070, -0.0473],
         [ 0.1560, -0.4002,  0.2226, -0.0127],
         [ 0.1659, -0.3980,  0.2178, -0.0270],
         [ 0.1697, -0.3473,  0.1753, -0.0671]],

        [[-0.9989,  0.7810, -0.2646,  0.8636],
         [ 0.3898,  0.4143, -0.4019, -0.7050],
         [-0.5618,  0.6062, -0.2746,  0.3864],
         [ 0.5713,  0.4507, -0.4685, -0.9366],
         [-0.6284,  0.6691, -0.2853,  0.4612]]])
tensor([[[ 0.2059, -0.4056,  0.2160, -0.0673],
         [ 0.1822, -0.3807,  0.2070, -0.0473],
         [ 0.1560, -0.4002,  0.2226, -0.0127],
         [ 0.1659, -0.3980,  0.2178, -0.0270],
         [ 0.1697, -0.3473,  0.1753, -0.0671]],

        [[-0.9989,  0.7810, -0.2646,  0.8636],
         [ 0.3898,  0.4143, -0.4019, -0.7050],
         [-0.5618,  0.6062, -0.2746,  0.3864],
         [ 0.5713,  0.4507, -0.4685, -0.9366],
         [-0.6284,  0.6691, -0.2853,  0.4612]]], grad_fn=<TransposeBackward0>)


## nn.Linear

In [40]:
model = nn.Linear(4, 4)
model

Linear(in_features=4, out_features=4, bias=True)

In [41]:
pprint(list(model.named_parameters()))

[('weight',
  Parameter containing:
tensor([[-0.0820, -0.4251, -0.0604,  0.4206],
        [ 0.3261,  0.3803, -0.3762, -0.3561],
        [-0.4774,  0.1000, -0.1047,  0.4022],
        [ 0.3403, -0.0822,  0.4370,  0.2757]], requires_grad=True)),
 ('bias',
  Parameter containing:
tensor([ 0.3558, -0.4544, -0.2201, -0.4630], requires_grad=True))]


In [42]:
model_weight = {name: param.data for name, param in model.named_parameters()}
W = model_weight['weight']
B = model_weight['bias']

X = torch.randn(4) 
print(X.shape)
print(X)
output = model(X)
print(output.shape)
print(output)


torch.Size([4])
tensor([-1.2497, -2.1330, -0.1963,  1.3451])
torch.Size([4])
tensor([ 1.9427, -2.0784,  0.7248, -0.4280], grad_fn=<ViewBackward0>)


In [43]:
output_ = X.matmul(W.T) + B
print(output_)
print(output)

tensor([ 1.9427, -2.0784,  0.7248, -0.4280])
tensor([ 1.9427, -2.0784,  0.7248, -0.4280], grad_fn=<ViewBackward0>)


## nn.LayerNorm

参考サイト
https://qiita.com/dl_from_scratch/items/133fe741b67ed14f1856

In [44]:
model = nn.LayerNorm(4)
model

LayerNorm((4,), eps=1e-05, elementwise_affine=True)

In [45]:
pprint(list(model.named_parameters()))

[('weight', Parameter containing:
tensor([1., 1., 1., 1.], requires_grad=True)),
 ('bias', Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True))]


In [46]:
model_weight = {name: param.data for name, param in model.named_parameters()}
W = model_weight['weight']
B = model_weight['bias']

X = torch.randn(4) 
print(X.shape)
print(X)
output = model(X)
print(output.shape)
print(output)


torch.Size([4])
tensor([-2.3616, -0.3533, -0.9263,  0.7948])
torch.Size([4])
tensor([-1.4519,  0.3152, -0.1889,  1.3255], grad_fn=<NativeLayerNormBackward0>)


In [47]:
output_ = X.matmul(W.T) + B
print(output_)
print(output)

tensor([-2.8464, -2.8464, -2.8464, -2.8464])
tensor([-1.4519,  0.3152, -0.1889,  1.3255], grad_fn=<NativeLayerNormBackward0>)


  output_ = X.matmul(W.T) + B
