# Attention from scratch
ここではpytorch_command.ipynbで学んだことを活かしてAttentionの実装を行います。<br>
### 目次<br>
- Attentionが生まれた経緯
- Self-Attention, Multi-head Attention, Transformerの理論、実装
- Transformer機構を用いた日英翻訳
Attention機構を用いた翻訳を作る際に使うデータセットはHugging Faceのものを扱います。https://huggingface.co/datasets/snow_simplified_japanese_corpus<br>

#### Attentionが生まれた経緯

Attntion機構が初めて世の中に発表されたのは""Neural Machine Translation by Jointly Learning to Align and Translate""という論文であるとされている。<br>
この論文はRNNに注意という概念を加えて機械翻訳の性能を向上させたものである。<br>
この論文内ではAttentionという名前は使われていなく、Alignという表記がされていることに注意してほしい。<br>

Attentionが生まれた以前の機械翻訳は深層学習一強というわけでもなく、Encoder-Decoderを用いたRNNベースの深層学習だったり、<br>
Mosesと呼ばれる統計的機会翻訳が強かったりとどれが一番強い手法なのかを競い合っていた。<br>
「Neural Machine Translation by Jointly Learning to Align and Translate」が公開されたのは2014年であり、<br>
Google翻訳も2016年から処理を深層学習ベースに変えたとwikipediaにもある。<br>

統計的機械翻訳と深層学習ベース機械翻訳の比較は他の人に任せるとして、<br>
ここでは既存のSeq2Seqの機械翻訳に比べてAttentionのどこが優位であったのかを簡単に述べる。<br>
Seq2Seqモデルの問題点<br>
- 入力する文章の長さに関わらず固定長ベクトルに圧縮するため、必要な情報が捨象されてしまう。<br>
- 単語、文章同士の照応関係を利用できない。<br>
- 学習データにある文字列長以上の文字列を入力すると上手く機能しない。<br>
<br>
これらの問題点を理解するのは実際にSeq2Seqモデルを組んで実行してみるといいと思います。

これらの問題点を緩和したものがRNNにAttention機構を追加したものであり、RNNを使わないことでRNN由来の問題点を解消し、<br>
更に精度を向上させたものがMulti-head Attentionを用いたTransformerとなります。<br>

### Self-Attention, Multi-head Attention, Transformerの理論、実装

##### Self-Attentionメカニズム

Self−Attentionメカニズムは3つのステージで構成されている。<br>
1. 入力シーケンス$x^{(1)},x^{(2)},...,x^{(t)}$のうち、現在の要素$x^{(i)}$と、他の全要素$x^{(j)}$との間の類似度に基づいて重要度$ω_{ij}$を計算する<br>

これを式で表現すると以下の通りとなる<br>
$$
ω_{ij} = {x^{(i)}}^{T}x^{(j)}
$$
  計画行列は
$$
X = \begin{pmatrix} 
  {x^{(1)}}^{T} \\
  {x^{(2)}}^{T} \\
  \vdots\\
  \vdots\\
  {x^{(t)}}^{T} 
\end{pmatrix}, X^{T} = \begin{pmatrix} x^{(1)},  x^{(2)},\dots,x^{(t)} \end{pmatrix}
$$
と表されることを考えると、$Ω = {(ω_{ij})_{i,j}}$という$ω_{ij}$をまとめた行列を求めるときには
$$
Ω = XX^{T}
$$
を計算すれば良い。


2. 1.で求めた$ω_{ij}$を正規化して$α_{ij}$を求める。

数式で表現すると<br>
$$
α_{ij} = \dfrac{\exp{(\omega_{ij})}}{\displaystyle \sum_{k=1}^{t}\exp{(\omega_{ik})}}
$$
である。実装の際はpytorchで効率よく計算を行う
<br>

3. 重みをシーケンス内の対応する要素と組み合わせてAttentionスコアを計算する

数式で表現すると、
$$
z^{(i)} = \displaystyle \sum_{j = 1}^{t} \alpha_{ij}x^{(j)}
$$
である。実装の際はpytorchで効率よく計算を行う。この$z^{(i)}$がSelf-Attentionの出力となる。

計算についてもう少し詳しく見ていくと、<br>
$$
z_{ji} = \alpha_{i1}x_{1j} + \alpha_{i2}x_{2j} + \dots + \alpha_{it}x_{tj}
$$
である。これは、
$$
Z = \begin{pmatrix} 
  {z^{(1)}}^{T} \\
  {z^{(2)}}^{T} \\
  \vdots\\
  \vdots\\
  {z^{(t)}}^{T} 
\end{pmatrix}
$$
としたときに、
$$
Z = AX, \text{where} \, A = (\alpha_{ij})_{i,j}
$$

を計算することにあたる

ここまでの処理を詳しく見ていこう。

1. 重要度$\omega_{ij}$の計算

既に辞書などを用いて単語は整数表現にマッピングされているものとする。<br>
このとき、[3,5,1,6,8,3,6,4]と表されるような文章を計算の対象とする。<br>
まずはこの文のベクトル表現をnn.Embeddingで取得する。dictionary_sizeは10, embedding_sizeは16とする。<br>
もしここまでで分散表現や埋め込みという表現に馴染みがなかった場合は自然言語処理の前処理について調べてみると良い<br>

In [1]:
#ライブラリの定義、シード値の固定
import torch
import numpy as np
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)
import warnings
warnings.simplefilter('ignore')

In [2]:
from torch import nn
x = torch.Tensor([3,5,1,6,8,3,6,4]).int()
embed = nn.Embedding(10,16)
embed_x = embed(x)
print("生成された埋め込みベクトル:\n", embed_x)

生成された埋め込みベクトル:
 tensor([[-9.1382e-01, -6.5814e-01,  7.8024e-02,  5.2581e-01, -4.8799e-01,
          1.1914e+00, -8.1401e-01, -7.3599e-01, -1.4032e+00,  3.6004e-02,
         -6.3477e-02,  6.7561e-01, -9.7807e-02,  1.8446e+00, -1.1845e+00,
          1.3835e+00],
        [ 1.0868e-02, -3.3874e-01, -1.3407e+00, -5.8537e-01,  5.3619e-01,
          5.2462e-01,  1.1412e+00,  5.1644e-02,  7.4395e-01, -4.8158e-01,
         -1.0495e+00,  6.0390e-01, -1.7223e+00, -8.2777e-01,  1.3347e+00,
          4.8354e-01],
        [ 1.6423e+00, -1.5960e-01, -4.9740e-01,  4.3959e-01, -7.5813e-01,
          1.0783e+00,  8.0080e-01,  1.6806e+00,  1.2791e+00,  1.2964e+00,
          6.1047e-01,  1.3347e+00, -2.3162e-01,  4.1759e-02, -2.5158e-01,
          8.5986e-01],
        [-2.5095e+00,  4.8800e-01,  7.8459e-01,  2.8647e-02,  6.4076e-01,
          5.8325e-01,  1.0669e+00, -4.5015e-01, -1.8527e-01,  7.5276e-01,
          4.0476e-01,  1.7847e-01,  2.6491e-01,  1.2732e+00, -1.3109e-03,
         -3.0360e-01],
    

In [3]:
omega = embed_x @ embed_x.transpose(0,1)
print("omega: \n", omega)

omega: 
 tensor([[13.5729, -3.6601, -0.7355,  4.1794,  0.6965, 13.5729,  4.1794, -1.0428],
        [-3.6601, 12.0409,  2.5783, -1.8934,  6.3044, -3.6601, -1.8934, -8.5589],
        [-0.7355,  2.5783, 14.6957, -3.3807,  9.0543, -0.7355, -3.3807,  3.0030],
        [ 4.1794, -1.8934, -3.3807, 11.8240, -3.9458,  4.1794, 11.8240, -2.9556],
        [ 0.6965,  6.3044,  9.0543, -3.9458, 19.0526,  0.6965, -3.9458, -0.1805],
        [13.5729, -3.6601, -0.7355,  4.1794,  0.6965, 13.5729,  4.1794, -1.0428],
        [ 4.1794, -1.8934, -3.3807, 11.8240, -3.9458,  4.1794, 11.8240, -2.9556],
        [-1.0428, -8.5589,  3.0030, -2.9556, -0.1805, -1.0428, -2.9556, 17.5832]],
       grad_fn=<MmBackward0>)


2. 1.で求めた$ω_{ij}$を正規化して$α_{ij}$を求める。<br>
計算の際は愚直に計算をするか、torch.nn.functionalのsoftmax関数を利用する。愚直に計算をする場合はkeepdim引数を忘れないこと

In [4]:
import torch.nn.functional as F
print("愚直に実装を行った場合: \n", torch.exp(omega) / torch.exp(omega).sum(dim=1,keepdim=True))
print("torch.nn.functional.softmaxを使った場合: \n",F.softmax(omega, dim = 1))

愚直に実装を行った場合: 
 tensor([[4.9996e-01, 1.6396e-08, 3.0542e-07, 4.1630e-05, 1.2788e-06, 4.9996e-01,
         4.1630e-05, 2.2461e-07],
        [1.5126e-07, 9.9670e-01, 7.7451e-05, 8.8509e-07, 3.2154e-03, 1.5126e-07,
         8.8509e-07, 1.1277e-09],
        [1.9805e-07, 5.4442e-06, 9.9645e-01, 1.4059e-08, 3.5352e-03, 1.9805e-07,
         1.4059e-08, 8.3246e-06],
        [2.3919e-04, 5.5126e-07, 1.2457e-07, 4.9976e-01, 7.0797e-08, 2.3919e-04,
         4.9976e-01, 1.9058e-07],
        [1.0667e-08, 2.9075e-06, 4.5476e-05, 1.0278e-10, 9.9995e-01, 1.0667e-08,
         1.0278e-10, 4.4379e-09],
        [4.9996e-01, 1.6396e-08, 3.0542e-07, 4.1630e-05, 1.2788e-06, 4.9996e-01,
         4.1630e-05, 2.2461e-07],
        [2.3919e-04, 5.5126e-07, 1.2457e-07, 4.9976e-01, 7.0798e-08, 2.3919e-04,
         4.9976e-01, 1.9058e-07],
        [8.1439e-09, 4.4321e-12, 4.6546e-07, 1.2026e-09, 1.9289e-08, 8.1439e-09,
         1.2026e-09, 1.0000e+00]], grad_fn=<DivBackward0>)
torch.nn.functional.softmaxを使った場合: 
 ten

3. 重みをシーケンス内の対応する要素と組み合わせてAttentionスコアを計算する

In [5]:
attention_weight = F.softmax(omega, dim = 1)
z = attention_weight @ embed_x
print("Attention Score: \n", z)

Attention Score: 
 tensor([[-9.1395e-01, -6.5804e-01,  7.8081e-02,  5.2577e-01, -4.8790e-01,
          1.1913e+00, -8.1385e-01, -7.3597e-01, -1.4031e+00,  3.6064e-02,
         -6.3440e-02,  6.7557e-01, -9.7777e-02,  1.8445e+00, -1.1844e+00,
          1.3834e+00],
        [ 1.7164e-02, -3.3438e-01, -1.3409e+00, -5.8704e-01,  5.3393e-01,
          5.2824e-01,  1.1396e+00,  5.3455e-02,  7.4527e-01, -4.7984e-01,
         -1.0518e+00,  6.0499e-01, -1.7178e+00, -8.2171e-01,  1.3281e+00,
          4.8406e-01],
        [ 1.6433e+00, -1.5545e-01, -5.0070e-01,  4.3404e-01, -7.5592e-01,
          1.0803e+00,  8.0027e-01,  1.6767e+00,  1.2786e+00,  1.2919e+00,
          6.0191e-01,  1.3333e+00, -2.3213e-01,  4.5254e-02, -2.5312e-01,
          8.5905e-01],
        [-2.5088e+00,  4.8745e-01,  7.8425e-01,  2.8885e-02,  6.4022e-01,
          5.8354e-01,  1.0660e+00, -4.5029e-01, -1.8585e-01,  7.5242e-01,
          4.0453e-01,  1.7870e-01,  2.6473e-01,  1.2734e+00, -1.8767e-03,
         -3.0280e-01],
 

Scaled dot product Attentionの実装

今まで解説してきたself attentionでは出力を計算する際に学習可能なパラメーターを使わなかった。<br>
そのため。Attention Scoreの計算においてある程度の制限がある。<br>
そこで、Self Attentionメカニズムがモデルの最適化により柔軟に対応できるようにするために、モデルの訓練パラメーターとして<br>
3つの重み行列をSelf Attentionメカニズムに追加する。<br>

3つの重み行列をかけたものをそれぞれquery, key, valueとすると、<br>
$$
q^{(i)} = U_{q}x^{(i)} \\
k^{(i)} = U_{k}x^{(i)} \\
v^{(i)} = U_{v}x^{(i)} \\
$$
という形でそれぞれquery, key, valueは表現される。<br>
このときに使われる行列$U_{q}$. $U_{k}$, $U_{v}$は射影行列であり、そのサイズは、$x^{(i)} \in \mathbb{R}^d$とすると、それぞれ
<br>$d_{k} \times d$, $d_{k} \times d$, $d_{v} \times d$となる。$U_q$と$U_k$のサイズが同じことに注意してほしい。<br>
計算された$q, k, v$を用いて、重要度$\omega_{ij}$は以下のように計算される。<br>
$$
\omega_{ij} = {q^{(i)}}^{T}k^{(j)}
$$
次に正規化して$\alpha_{ij}$を求める。$\exp{(\dfrac{\omega_{ij}}{\sqrt{d_k}})}$とスケーリングを行うことで重みベクトルのユークリッド距離がだいたい同じ範囲に収まるようにする<br>
$$
\alpha_{ij} = \text{softmax}(\dfrac{\omega_{ij}}{\sqrt{d_{k}}}) = \dfrac{\exp{(\dfrac{\omega_{ij}}{\sqrt{d_k}})}}{\displaystyle \sum_{l=1}^{t}\exp{(\dfrac{\omega_{il}}{\sqrt{d_k}})}}
$$
最後に出力を計算する。
$$
z^{(i)} = \displaystyle \sum_{j=1}^{t}\alpha_{ij}v^{(j)}
$$

上までの処理を行列で表す。<br>
$$ Q = \begin{pmatrix}   {q^{(1)}}^{T} \\  {q^{(2)}}^{T} \\  \vdots\\  \vdots\\  {q^{(t)}}^{T} \end{pmatrix}, K = \begin{pmatrix}   {k^{(1)}}^{T} \\  {k^{(2)}}^{T} \\  \vdots\\  \vdots\\  {k^{(t)}}^{T} \end{pmatrix}, V = \begin{pmatrix}   {v^{(1)}}^{T} \\  {v^{(2)}}^{T} \\  \vdots\\  \vdots\\  {v^{(t)}}^{T} \end{pmatrix}, Ω = {(ω_{ij})_{i,j}}, A = {(\alpha_{ij})_{i,j}}, 
Z = \begin{pmatrix} 
  {z^{(1)}}^{T} \\
  {z^{(2)}}^{T} \\
  \vdots\\
  \vdots\\
  {z^{(t)}}^{T} 
\end{pmatrix}$$
として定義を行うと、<br>
$$
Q = (U_qX^T)^{T} = X{U_q}^T, K = X{U_k}^T, V = X{U_v}^T, Ω = QK^T, Z = AV
$$

を計算すれば良い。これを図に表したものがよく見る下の図となる。<br>

また、$Q， K， V$を用いて出力Zを求めることを$Z = \text{Attention}(Q,K,V)$と表現する文献もある。ここでは以下この記法を用いることもある。

<img src = "https://production-media.paperswithcode.com/methods/35184258-10f5-4cd0-8de3-bd9bc8f88dc3.png"><br>
### 実装は[Paper with code: Scaled Dot-Product Attention](https://paperswithcode.com/method/scaled)にもある

今回は入力に対して変換行列をかけることで$Q,K,V$を求めたが、Q,K,Vの求め方によってattentionの名前が異なる。<br>
詳しくは次のサイトなどを参考ににしてほしい。[30分で完全理解するTransformerの世界](https://zenn.dev/zenkigen/articles/2023-01-shimizu)<br>
以下では区別のために、Scaled dot product self attentionと呼ぶことにする。

```ここでQuery， Key， Valueの説明をしたほうが良い。```

Scaled dot product self attentionの実装の前にもう少しだけ詳しく実装の仕方を見ていこう。<br>
Scaled dot product self attentionは(batch_size, sequence_length, embedding_dim)という入力を想定する。

In [6]:
#(batch_size, sequence_length, embedding_dim) = (2,5,5)を入力とする
import torch
import torch.nn.functional as F
import numpy as np
x = torch.rand(2,3,5) #入力
U_q = torch.rand(2,5)
U_k = torch.rand(2,5)
U_v = torch.rand(5,5)

上で述べた数式通りに計算を行うと、各batchでの出力は以下の通りとなる。

In [7]:
for i in range(2):
    x_i = x[i]
    q = torch.matmul(x_i,U_q.T)
    k = torch.matmul(x_i,U_k.T)
    v = torch.matmul(x_i,U_v.T)
    omega = q @ k.T
    attn = F.softmax(omega / np.sqrt(k.size(1)), dim = 1)
    z = attn @ v
    print(f"batch = {i}, output: \n",z)

batch = 0, output: 
 tensor([[1.0679, 1.4628, 0.8342, 1.4374, 1.6284],
        [1.0530, 1.4554, 0.8329, 1.4291, 1.6292],
        [1.0555, 1.4539, 0.8332, 1.4297, 1.6272]])
batch = 1, output: 
 tensor([[1.2303, 1.6044, 0.9414, 1.4426, 1.7949],
        [1.2561, 1.6172, 0.9515, 1.4616, 1.8104],
        [1.2479, 1.6132, 0.9480, 1.4553, 1.8042]])


以上のコードのように、for文をbatch_size分だけ回せば理論的には出力が求まりはするが、for文を回すことは計算効率が悪い。そこで、torch.einsumを用いて計算の効率化を行う。

In [8]:
q = torch.einsum("ijk,dk->ijd",[x,U_q])
k = torch.einsum("ijk,dk->ijd",[x,U_k])
v = torch.einsum("ijk,dk->ijd",[x,U_v])
omega = torch.einsum("ijk,ilk->ijl",[q,k])
attn = F.softmax(omega / np.sqrt(k.size(2)),dim=2)
z = torch.einsum("ijk,ikl->ijl",[attn,v])
print("einsumでの計算結果:\n ",z)

einsumでの計算結果:
  tensor([[[1.0679, 1.4628, 0.8342, 1.4374, 1.6284],
         [1.0530, 1.4554, 0.8329, 1.4291, 1.6292],
         [1.0555, 1.4539, 0.8332, 1.4297, 1.6272]],

        [[1.2303, 1.6044, 0.9414, 1.4426, 1.7949],
         [1.2561, 1.6172, 0.9515, 1.4616, 1.8104],
         [1.2479, 1.6132, 0.9480, 1.4553, 1.8042]]])


einsumを使うことで各バッチでの処理が正しく求められていることがわかる。<br>

Scaled Dot Product Self Attentionの実装<br>
今回は入力から$Q,K,V$の値を求めるScaled Dot Product Self Attentionの実装について解説を行ってきた。<br>
そこで、今回の実装ではforwardに(batch_size, sequence_length, embedding_dim)を入力として想定する**ScaledDotProductSelfAttention**<br>
と、forwardに$Q,K,V$の入力を想定する**ScaledDotProductAttention**の2つを実装する。<br>

In [9]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self,Q,K,V):
        omega = torch.einsum("ijk,ilk->ijl",[Q,K])
        attn = F.softmax(omega / np.sqrt(K.size(2)),dim=2)
        z = torch.einsum("ijk,ikl->ijl",[attn,V])
        return z
class ScaledDotProductSelfAttention(nn.Module):
    def __init__(self, embed_dim, d_k, d_v):
        super().__init__()
        self.U_q = nn.Parameter(torch.tensor(np.random.uniform(low = -np.sqrt(1/d_k),high = np.sqrt(1/d_k),size = (d_k,embed_dim))).float())
        self.U_k = nn.Parameter(torch.tensor(np.random.uniform(low = -np.sqrt(1/d_k),high = np.sqrt(1/d_k),size = (d_k,embed_dim))).float())
        self.U_v = nn.Parameter(torch.tensor(np.random.uniform(low = -np.sqrt(1/d_v),high = np.sqrt(1/d_v),size = (d_v,embed_dim))).float())
    def forward(self,x):
        Q = torch.einsum("ijk,dk->ijd",[x,self.U_q])
        K = torch.einsum("ijk,dk->ijd",[x,self.U_k])
        V = torch.einsum("ijk,dk->ijd",[x,self.U_v])
        omega = torch.einsum("ijk,ilk->ijl",[Q,K])
        attn = F.softmax(omega / np.sqrt(K.size(2)),dim=2)
        z = torch.einsum("ijk,ikl->ijl",[attn,V])
        return z

In [10]:
batch_size = 2
sequence_length = 3
embedding_dim = 5
d_k = 2
d_v = 5
x = torch.rand(batch_size,sequence_length, embedding_dim) #入力
sdpa = ScaledDotProductAttention()
sdpsa = ScaledDotProductSelfAttention(embedding_dim,d_k,d_v)

2つの実装が一致していることを確かめておきましょう。

In [11]:
print("ScaledDotProductSelfAttentionの実装:\n",sdpsa(x))

ScaledDotProductSelfAttentionの実装:
 tensor([[[-0.1642, -0.1012,  0.0598, -0.0455, -0.3440],
         [-0.1647, -0.1027,  0.0615, -0.0467, -0.3450],
         [-0.1658, -0.1053,  0.0624, -0.0490, -0.3447]],

        [[-0.2318, -0.1854,  0.1680, -0.0798, -0.3056],
         [-0.2353, -0.1908,  0.1712, -0.0801, -0.2971],
         [-0.2303, -0.1830,  0.1666, -0.0796, -0.3093]]],
       grad_fn=<ViewBackward0>)


In [12]:
Q = torch.einsum("ijk,dk->ijd",[x,sdpsa.U_q])
K = torch.einsum("ijk,dk->ijd",[x,sdpsa.U_k])
V = torch.einsum("ijk,dk->ijd",[x,sdpsa.U_v])
print("ScaledDotProductAttentionの実装:\n",sdpa(Q,K,V))

ScaledDotProductAttentionの実装:
 tensor([[[-0.1642, -0.1012,  0.0598, -0.0455, -0.3440],
         [-0.1647, -0.1027,  0.0615, -0.0467, -0.3450],
         [-0.1658, -0.1053,  0.0624, -0.0490, -0.3447]],

        [[-0.2318, -0.1854,  0.1680, -0.0798, -0.3056],
         [-0.2353, -0.1908,  0.1712, -0.0801, -0.2971],
         [-0.2303, -0.1830,  0.1666, -0.0796, -0.3093]]],
       grad_fn=<ViewBackward0>)


MultiheadAttentionの実装

今までで実装してきたScaledDotProductAttentionはTransformerモデルの基本構成要素となっている。<br>
しかし、Transformerで採用されているAttentionはScaledDotProductAttentionを並列で実行するMultiheadAttentionというものになっています。<br>

MultiheadAttentionの理論式は以下の通りとなる。<br>

$\text{MultiHead}(Q,K,V) = \text{Concat}(head_1,...head_h)W^o \  \text{where}\   head_i = \text{Attention}(Q{W_{i}}^Q,K{W_{i}}^K,V{W_{i}}^V)$

これは、入力した$Q, K, V$それぞれにh個の線形層をかけて、それら全てに$\text{ScaledDotAttention}$を計算させてできたTensorを結合したものを、更に線形層にいれて最終的な出力を計算するという演算になります。<br>

実装上の注意としては、まず最初のh個数の線形層をかけるところは、1つの行列の演算としてまとめたほうが効率が良いです。<br>
例えば、$Q{W_{i}}^Q \ i \in \{1,2,3,...,h\}$の計算では$W^Q = [{W_{1}}^Q, {W_{2}}^Q, {W_{3}}^Q, ... , {W_{h}}^Q] \ W^Q \in M_{dim\_q}\times_{outdim\_q \times num\_heads}$とおくことで、<br>
$QW^Q = [Q{W_{1}}^Q, Q{W_{2}}^Q, Q{W_{3}}^Q, ... , Q{W_{h}}^Q]$とまとめて計算することが可能です。これを今回はnn.Linearで実現します。<br>
(上までの実装のようにnn.parameterから制作しても良い)<br>

次に、$head_i = \text{Attention}(Q{W_{i}}^Q,K{W_{i}}^K,V{W_{i}}^V)$の実装方法です。ScaledDotProductAttentionを実装したようにeinsumを用いても良いのですが<br>
今回は**einops**というライブラリを用いて簡単に計算を行いたいと思います。tensor.splitを使ってheadを分割した後に結合するというやり方もありますか、<br>
その実装方針ではメモリ領域の確保などの操作が含まれるため、今回はeinopsでの実装を行います。<br>

<img src = "https://production-media.paperswithcode.com/methods/multi-head-attention_l1A3G7a.png">

In [13]:
from einops import rearrange
class MultiheadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, d_k = None, d_v = None):
        super().__init__()
        self.d_k = embed_dim if d_k == None else d_k
        self.d_v = embed_dim if d_v == None else d_v
        self.to_q = nn.Linear(embed_dim,self.d_k*num_heads,bias = False)
        self.to_k = nn.Linear(embed_dim,self.d_k*num_heads,bias = False)
        self.to_v = nn.Linear(embed_dim,self.d_v*num_heads,bias = False)
        self.output = nn.Linear(self.d_v*num_heads, embed_dim,bias = False)
        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_k.weight)
        nn.init.xavier_uniform_(self.to_v.weight)
        nn.init.xavier_uniform_(self.output.weight)
        self.softmax = nn.Softmax(dim = -1)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
    def forward(self, x):
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)
        q = rearrange(q,"b i (h j)->b h i j",h=self.num_heads,j=self.d_k)
        k = rearrange(k,"b i (h j)->b h i j",h=self.num_heads,j=self.d_k)
        v = rearrange(v,"b i (h j)->b h i j",h=self.num_heads,j=self.d_v)
        omega = torch.einsum("bhjk,bhlk->bhjl",[q,k])
        attn = self.softmax(omega/np.sqrt(self.d_k))
        z = torch.einsum("bhjk,bhkl->bhjl",[attn,v])
        concat_z = rearrange(z,"b h i j->b i (h j)")
        out = self.output(concat_z)
        return out, attn
class MultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, d_k = None, d_v = None):
        super().__init__()
        self.d_k = embed_dim if d_k == None else d_k
        self.d_v = embed_dim if d_v == None else d_v
        self.to_q = nn.Linear(embed_dim,self.d_k*num_heads,bias = False)
        self.to_k = nn.Linear(embed_dim,self.d_k*num_heads,bias = False)
        self.to_v = nn.Linear(embed_dim,self.d_v*num_heads,bias = False)
        self.output = nn.Linear(self.d_v*num_heads, embed_dim,bias = False)
        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_k.weight)
        nn.init.xavier_uniform_(self.to_v.weight)
        nn.init.xavier_uniform_(self.output.weight)
        self.softmax = nn.Softmax(dim = -1)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
    def forward(self, q,k,v):
        q = self.to_q(q)
        k = self.to_k(k)
        v = self.to_v(v)
        q = rearrange(q,"b i (h j)->b h i j",h=self.num_heads,j=self.d_k)
        k = rearrange(k,"b i (h j)->b h i j",h=self.num_heads,j=self.d_k)
        v = rearrange(v,"b i (h j)->b h i j",h=self.num_heads,j=self.d_v)
        omega = torch.einsum("bhjk,bhlk->bhjl",[q,k])
        attn = self.softmax(omega/np.sqrt(self.d_k))
        z = torch.einsum("bhjk,bhkl->bhjl",[attn,v])
        concat_z = rearrange(z,"b h i j->b i (h j)")
        out = self.output(concat_z)
        return out, attn

In [14]:
batch_size = 2
sequence_length = 3
embedding_dim = 5
d_k = 2
d_v = 5
x = torch.rand(batch_size,sequence_length, embedding_dim)

In [15]:
mhsa = MultiheadSelfAttention(embed_dim=embedding_dim, num_heads=3,d_k=d_k,d_v=d_v)

In [16]:
mhsa(x)

(tensor([[[-0.1790,  0.3175,  0.1880, -1.7686, -1.2203],
          [-0.1723,  0.3033,  0.1802, -1.7113, -1.1644],
          [-0.1743,  0.3190,  0.1878, -1.7620, -1.2175]],
 
         [[-0.0847,  0.2050,  0.0675, -1.1218, -0.8401],
          [-0.0716,  0.1920,  0.0702, -1.0858, -0.8119],
          [-0.0805,  0.2182,  0.0610, -1.1137, -0.8357]]],
        grad_fn=<UnsafeViewBackward0>),
 tensor([[[[0.3714, 0.2822, 0.3464],
           [0.4062, 0.2386, 0.3552],
           [0.3734, 0.2762, 0.3504]],
 
          [[0.3410, 0.3252, 0.3338],
           [0.3442, 0.3100, 0.3459],
           [0.3370, 0.3236, 0.3394]],
 
          [[0.4132, 0.2586, 0.3282],
           [0.5302, 0.1878, 0.2820],
           [0.4177, 0.2666, 0.3157]]],
 
 
         [[[0.3453, 0.3432, 0.3115],
           [0.3543, 0.3521, 0.2936],
           [0.3505, 0.3386, 0.3109]],
 
          [[0.3345, 0.3027, 0.3628],
           [0.3343, 0.3020, 0.3637],
           [0.3335, 0.3245, 0.3420]],
 
          [[0.3484, 0.3033, 0.3483],
   

分析デモ

制作したScaledDotProductAttentionとMultiheadAttentionを用いて簡単な分析を行ってみましょう。

In [17]:
from datasets import load_dataset
#openpyxlが入っていないとバグるので注意
dataset = load_dataset("snow_simplified_japanese_corpus")

In [18]:
index = np.random.permutation(50000)
train_index = index[:45000]
test_index = index[45000:]
#何だか変な実装になってしまいましたが、訓練データを45000, テストデータを5000として分割を行う。
ja_train = np.array(dataset["train"]["simplified_ja"])[train_index].tolist()
ja_test = np.array(dataset["train"]["simplified_ja"])[test_index].tolist()
en_train = np.array(dataset["train"]["original_en"])[train_index].tolist()
en_test = np.array(dataset["train"]["original_en"])[test_index].tolist()

自然言語処理では分析の前に文章を分かち書きし、一意な単語を整数にマッピングする必要があります。<br>
今回扱うデータは眺めているとわかるのですが、顔文字などの特殊文字のない、とても癖のないきれいなデータであることがわかります。<br>
そのため、データをすぐに分かち書きしても良いでしょう。<br>
(実際のデータに顔文字や特殊文字、htmlの記法などが含まれていた場合、それらの文字には欠損値のように特別な処理を施す必要があります。)<br>

日本語の分かち書きにはjanome, 英語の分かち書きにはtorchtext.data.utils.get_tokenizerを用います。<br>
では、まずは一意な単語を数字に変換するtorchtext.vocab.vocabオブジェクトを製作しましょう。

In [19]:
from janome.tokenizer import Tokenizer
from torchtext.data.utils import get_tokenizer
from collections import Counter, OrderedDict
from tqdm import tqdm
from torchtext.vocab import vocab
def create_ja_vocab(*args):
    ja_tokenizer = Tokenizer()
    ja_token_count = Counter()
    for data in args:
        for d in tqdm(data):
            tokens = list(ja_tokenizer.tokenize(d,wakati=True))
            ja_token_count.update(tokens)
    ordered_dict = OrderedDict(sorted(ja_token_count.items(),key = lambda x: x[1], reverse = True))
    ja_vocab = vocab(ordered_dict)
    ja_vocab.insert_token("<unk>", 0) 
    ja_vocab.insert_token("<pad>", 1)
    ja_vocab.insert_token("<bos>", 2)
    ja_vocab.insert_token("<eos>", 3)
    ja_vocab.set_default_index(0) #<unk>を何番目にするかを指定する
    return ja_vocab
def create_en_vocab(*args):
    en_tokenizer = get_tokenizer("basic_english")
    en_token_count = Counter()
    for data in args:
        for d in tqdm(data):
            tokens = en_tokenizer(d)
            en_token_count.update(tokens)
    ordered_dict = OrderedDict(sorted(en_token_count.items(),key = lambda x: x[1], reverse = True))
    en_vocab = vocab(ordered_dict)
    en_vocab.insert_token("<unk>", 0)
    en_vocab.insert_token("<pad>", 1)
    en_vocab.insert_token("<bos>", 2)
    en_vocab.insert_token("<eos>", 3)
    en_vocab.set_default_index(0)#<unk>を何番目にするかを指定する
    return en_vocab

In [20]:
ja_vocab = create_ja_vocab(ja_train,ja_test)
en_vocab = create_en_vocab(en_train,en_test)

100%|██████████| 45000/45000 [00:14<00:00, 3137.98it/s]
100%|██████████| 5000/5000 [00:01<00:00, 3345.50it/s]
100%|██████████| 45000/45000 [00:00<00:00, 303584.07it/s]
100%|██████████| 5000/5000 [00:00<00:00, 287147.36it/s]


例文を打ってみて、製作したvocabオブジェクトがどのように動作するかを試してみましょう。

In [21]:
print("日本語のvocab: ",ja_vocab(["これ", "は", "例", "です", "。"]))
print("日本語のvocabの語彙サイズ:", len(ja_vocab))
print("英語のvocab: ",en_vocab(["this", "is", "an", "example", "."]))
print("英語のvocabの語彙サイズ:", len(en_vocab))

日本語のvocab:  [75, 5, 1741, 19, 4]
日本語のvocabの語彙サイズ: 3919
英語のvocab:  [22, 9, 73, 1788, 4]
英語のvocabの語彙サイズ: 6626


次にデータセットを定義します。最終的には日本語のstring型を入力して英語のstring型を出力して欲しいため、データにおいてもstring型のデータを入力に用いるように構築します。<br>
カスタムデータセットの構築方法はこのリポジトリのpytorch_command.ipynbでも解説しています。<br>
もし学習が不十分なときは、見返しながらやると良いでしょう。<br>

今回はカスタムデータセットを用いてデータセットを定義します。

In [22]:
from torch.utils.data import Dataset
class sentence_dataset(Dataset):
    def __init__(self, ja, en):
        if len(ja) != len(en):
            raise ValueError("len(ja) != len(en)")
        self.ja = ja
        self.en = en
    def __getitem__(self, index):
        return self.ja[index], self.en[index]
    def __len__(self):
        return len(self.ja)

In [23]:
train_dataset = sentence_dataset(ja_train, en_train)
test_dataset = sentence_dataset(ja_test, en_test)

次に、データの処理関数を記述します。

In [24]:
def str2tensor(batch):
    ja_tokenizer = Tokenizer()
    en_tokenizer = get_tokenizer("basic_english")
    text2int_ja = lambda x: [ja_vocab["<bos>"]]+[ja_vocab[token] for token in list(ja_tokenizer.tokenize(x,wakati=True))]+[ja_vocab["<eos>"]]
    text2int_en = lambda x: [en_vocab["<bos>"]]+[en_vocab[token] for token in en_tokenizer(x)]+[en_vocab["<eos>"]]
    text_ja, text_en, length_ja, length_en = [], [], [], []
    for ja, en in batch:
        processed_ja = torch.tensor(text2int_ja(ja), dtype=torch.int64)
        processed_en = torch.tensor(text2int_en(en), dtype=torch.int64)
        text_ja.append(processed_ja)
        text_en.append(processed_en)
        length_ja.append(processed_ja.size(0))
        length_en.append(processed_en.size(0))
    length_ja = torch.tensor(length_ja)
    length_en = torch.tensor(length_en)
    padded_text_ja = nn.utils.rnn.pad_sequence(text_ja, batch_first=True, padding_value=ja_vocab["<pad>"]) #系列帳を揃える。
    padded_text_en = nn.utils.rnn.pad_sequence(text_en, batch_first=True, padding_value=ja_vocab["<pad>"])
    return padded_text_ja, padded_text_en, length_ja, length_en

DataLoaderを定義します。定義する際、訓練セットはShuffleをTrueに、collate_fn引数に製作した処理関数をいれます。

In [25]:
from torch.utils.data import DataLoader
batch_size = 200
train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=str2tensor)
test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=str2tensor)

上手く動作しているかをtest_dlの最初のデータで調べてみましょう。

In [26]:
i = 0
for ja, en, len_ja, len_en in test_dl:
    print("Data Japanese Sentence:\n", ja_test[0])
    print("Data Japanese: \n",ja[0],ja[0].shape)
    print("Data English Sentence:\n", en_test[0])
    print("Data English: \n",en[0],en[0].shape)
    if i == 1:
        break
    i += 1

Data Japanese Sentence:
 外は寒いよ。コートを着なさい。
Data Japanese: 
 tensor([  2, 172,   5, 663,  48,   4, 746,   8, 575,  76,   4,   3,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1]) torch.Size([22])
Data English Sentence:
 it is cold outdoors . put on your coat .
Data English: 
 tensor([   2,   15,    9,  175, 3170,    4,  143,   32,   37,  662,    4,    3,
           1,    1,    1,    1,    1,    1]) torch.Size([18])
Data Japanese Sentence:
 外は寒いよ。コートを着なさい。
Data Japanese: 
 tensor([  2, 108,   6, 471,  10,  23,   4,   3,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1]) torch.Size([22])
Data English Sentence:
 it is cold outdoors . put on your coat .
Data English: 
 tensor([   2,   10,    9,    5, 2466,   16,   28,   87,    4,    3,    1,    1,
           1,    1,    1,    1,    1,    1,    1]) torch.Size([19])


学習

では簡単に学習を行ってみよう。今回はMultiheadAttentionに2つの層を通して文章を生成することを考える。

In [27]:
from torchtext.data.metrics import bleu_score
device = "cuda" if torch.cuda.is_available() else "cpu"
ja_vocab_size = len(ja_vocab)
en_vocab_size = len(en_vocab)
embedding_dim = 300

class ja2en_multiheadattention(nn.Module):
    def __init__(self, ja_vocab_size, en_vocab_size, embedding_dim, num_heads):
        super().__init__()
        self.embedding_ja = nn.Embedding(ja_vocab_size, embedding_dim)
        self.embedding_en = nn.Embedding(en_vocab_size, embedding_dim)
        self.encoder = MultiheadAttention(embedding_dim,num_heads)
        self.decoder = MultiheadAttention(embedding_dim,num_heads)
        self.to_out = nn.Linear(embedding_dim, en_vocab_size)
        self.softmax = nn.Softmax(dim = -1)
    def forward(self,src, tgt):
        src = self.embedding_ja(src)
        tgt = self.embedding_en(tgt)
        src,_ = self.encoder(src,src,src)
        x,_ = self.decoder(tgt, src, src)
        x = self.to_out(x)
        x = self.softmax(x)
        return x

In [28]:
from torch import optim
import torch.nn.functional as F
from torchtext.data.metrics import bleu_score
model = ja2en_multiheadattention(ja_vocab_size, en_vocab_size, embedding_dim, num_heads=6).to(device)
optimizer = optim.Adam(model.parameters())
eps = 0.1
criterion = nn.CrossEntropyLoss(ignore_index=en_vocab["<pad>"], label_smoothing=eps / (en_vocab_size-1)) #予測が<pad>のものは予測に用いない
epochs = 5

次のコードは一度実行したら実行しないほうがいいです

In [29]:
from tqdm import tqdm
import gc
best_acc = 0
best_params = None
n_warmup = epochs // 2
for epoch in tqdm(range(epochs)):
    model.train()
    all_length = 0
    all_acc = 0
    all_loss = 0
    optimizer.lr = min(1/np.sqrt(embedding_dim)/np.sqrt(epoch+1),\
                        (epoch+1)/np.sqrt(embedding_dim)/n_warmup/np.sqrt(n_warmup)) #後々解説
    for japanese, english, _, _ in tqdm(train_dl):
        optimizer.zero_grad()
        japanese = japanese.to(device)
        english = english.to(device)
        pred = model(japanese, english)
        pred = pred[:,:-1,:]
        tgt = english[:,1:].contiguous().view(-1)
        loss = criterion(pred.contiguous().view(-1,pred.size(-1)), tgt)
        loss.backward()
        optimizer.step()
        all_acc += (pred.argmax(dim = -1).contiguous().view(-1) == tgt).sum().item()
        all_length += tgt.size(0)
        all_loss += loss.item()
    print("Train CrossEntropyLoss: ", all_loss / all_length)
    print("Train ACCURACY: ", all_acc / all_length)
    model.eval()
    all_length = 0
    all_acc = 0
    all_loss = 0
    for japanese, english, _, _ in tqdm(test_dl):
        japanese = japanese.to(device)
        english = english.to(device)
        pred = model(japanese, english)
        pred = pred[:,:-1,:]
        tgt = english[:,1:].contiguous().view(-1)
        loss = criterion(pred.contiguous().view(-1,pred.size(-1)), tgt)
        loss.backward()
        optimizer.step()
        all_acc += (pred.argmax(dim = -1).contiguous().view(-1) == tgt).sum().item()
        all_length += tgt.size(0)
        all_loss += loss.item()
    print("Test CrossEntropyLoss: ", all_loss / all_length)
    print("Test ACCURACY: ", all_acc / all_length)
    if best_acc < all_acc / all_length:
        best_acc = all_acc / all_length
        best_params = model.parameters
        torch.save(model.state_dict(), "naive_attention.pth")
    del english
    del japanese
    del pred
    del tgt
    gc.collect()
    torch.cuda.empty_cache()
    #cuda: 1 epoch 0:39, batch_size=100,  embedding_dim = 300
    #cpu:  1 epoch 3:34, batch_size=100,  embedding_dim = 300

100%|██████████| 225/225 [00:32<00:00,  6.93it/s]


Train CrossEntropyLoss:  0.002699989330946117
Train ACCURACY:  0.054091285162713734


100%|██████████| 25/25 [00:03<00:00,  7.07it/s]
 20%|██        | 1/5 [00:36<02:24, 36.20s/it]

Test CrossEntropyLoss:  0.002679222816316952
Test ACCURACY:  0.053669950738916256


100%|██████████| 225/225 [00:31<00:00,  7.10it/s]


Train CrossEntropyLoss:  0.0027131858857930894
Train ACCURACY:  0.05460643015521064


100%|██████████| 25/25 [00:03<00:00,  7.12it/s]
 40%|████      | 2/5 [01:11<01:47, 35.68s/it]

Test CrossEntropyLoss:  0.002679222816316952
Test ACCURACY:  0.053669950738916256


100%|██████████| 225/225 [00:31<00:00,  7.09it/s]


Train CrossEntropyLoss:  0.0027027007218291916
Train ACCURACY:  0.05439536167863059


100%|██████████| 25/25 [00:03<00:00,  7.10it/s]
 60%|██████    | 3/5 [01:46<01:11, 35.55s/it]

Test CrossEntropyLoss:  0.002679222816316952
Test ACCURACY:  0.053669950738916256


100%|██████████| 225/225 [00:32<00:00,  6.85it/s]


Train CrossEntropyLoss:  0.002708682908362688
Train ACCURACY:  0.054515771997786386


100%|██████████| 25/25 [00:03<00:00,  6.73it/s]
 80%|████████  | 4/5 [02:23<00:36, 36.01s/it]

Test CrossEntropyLoss:  0.002679222816316952
Test ACCURACY:  0.053669950738916256


100%|██████████| 225/225 [00:33<00:00,  6.72it/s]


Train CrossEntropyLoss:  0.002689333070765485
Train ACCURACY:  0.054126373626373625


100%|██████████| 25/25 [00:03<00:00,  6.75it/s]
100%|██████████| 5/5 [03:00<00:00, 36.19s/it]

Test CrossEntropyLoss:  0.002679222816316952
Test ACCURACY:  0.053669950738916256





学習したモデルのパラメーターは保存しておきましょう。

In [30]:
model.parameters = best_params
torch.save(model.state_dict(), "naive_attention.pth")

In [31]:
del model
gc.collect()
torch.cuda.empty_cache()

訓練結果はどうでしたか？思ったよりも結果がでなかったではないでしょうか？<br>
これは、RNNベースのSeq2SeqモデルからAttentionモデルに変更したことによるモデルアーキテクチャの違いや、<br>
ただ愚直にMultiheadAttentionを実装してしまったことが原因となっています。<br>
先に実験した愚直なMultiheadAttentionモデルのアーキテクチャを改善したものがかの有名なTransformerとなっています。<br>
では上のMultiheadAttentionを改善する形で実装をしていきましょう。

問題点1<br>
RNNベースに変えたことで系列データの順序情報が消えている。<br>

これには単語の位置情報を付加するPositional Encodingという方法が考案されています。<br>
詳しくはこちらをご覧ください[Transformer Architecture: The Positional Encoding](https://kazemnejad.com/blog/transformer_architecture_positional_encoding/)<br>
Positionalエンコーディングは系列データt番目の2k, 2k+1番目の変数に以下の値を加えるものである。<br>
$$
PE(t, 2k) = \sin{(\dfrac{t}{T^{\frac{2k}{D}}})}, PE(t, 2k+1) = \cos{(\dfrac{t}{T^{\frac{2k}{D}}})}
$$
Dは埋め込みベクトルの次元数であり、Tは仮想的な最大系列長である。T=10000とすることが多い<br>
実装の際は、PEが学習パラメーターではないことに注意しましょう。<br>
学習パラメーターでない数をモデルに含む際はtorch.nn.Module.register_bufferを用います。<br>
また、ブロードキャスト機能を用いて効率よく実装しましょう。

In [32]:
from torch import nn
import torch
class positional_encoding(nn.Module):
    def __init__(self, embedding_dim,T = 10000):
        super().__init__()
        #PE (1, max_sequence_length, embedding_dim)
        self.pe = torch.zeros(size=(1, T, embedding_dim))
        t = torch.arange(start = 1, end = T+1).reshape(T, 1)
        k_odd = torch.arange(start = 1, end = embedding_dim+1, step = 2)
        k_odd = k_odd.reshape(1, k_odd.size(0))
        k_even = torch.arange(start = 2, end = embedding_dim+1, step = 2)
        k_even = k_even.reshape(1,k_even.size(0))
        phase_odd = t / T**((k_odd//2 * 2) / embedding_dim) #<- ブロードキャスト機能を用いて効率よく計算
        phase_even = t / T**((k_even//2 * 2) / embedding_dim)
        self.pe[0,:,0::2] = torch.sin(phase_odd)
        self.pe[0,:,1::2] = torch.cos(phase_even)
        self.register_buffer("positional_encoding_weight", self.pe)
        self.embedding_dim = embedding_dim
    def forward(self, x):
        """入力のxのサイズx: (batch_size, sequence_length, embedding_dim)"""
        #np.sqrt(embedding_dim)をかけることでベクトルのスケールを合わせている
        return np.sqrt(self.embedding_dim)*x + self.pe[:,:x.size(1),:].to(x.device) #何故かself.register_bufferが働かなかった

問題点2<br>
学習の進みが遅い<br>
上のナイーブなMultiheadAttentionモデルの損失を眺めていた人は気づいたかもしれませんが、収束がかなり遅くなっています。<br>
そこで、Transformerでは残差スキップを用いて収束を早めます。<br>
<img src = "https://upload.wikimedia.org/wikipedia/commons/b/ba/ResBlock.png" height =50% width = 50%>

上2つの問題点の改善に加えてさらにlayernorm層, feedforward層(ffn)を加えたものをTransformerと呼び<br>
その構造は以下の通りとなります。<br>
<img src="https://production-media.paperswithcode.com/social-images/oVEwwksZyfDziYzq.png" height = 50% width = 50%>

左の灰色ブロック部分をTransformerEncoderと呼び、右の灰色ブロック部分をTransformerDecoderと呼びます。<br>
各レイヤーは今までに実装してきたものを組み合わせることで簡単に実装が行えます。<br>

In [33]:
class TransformerEncoderLayer(nn.Module):
    def __init__(self, embedding_dim, ffn_dim, num_heads, drop_out_rate = 0., layer_eps=1e-05,batch_first=False):
        super().__init__()
        self.multiheadattention = nn.MultiheadAttention(embedding_dim, num_heads,batch_first=batch_first)
        self.dropout_attn = nn.Dropout(p = drop_out_rate)
        self.layernorm_attn = nn.LayerNorm(embedding_dim, eps = layer_eps)
        self.ffn = nn.Sequential(nn.Linear(embedding_dim, ffn_dim), nn.ReLU(), nn.Linear(ffn_dim, embedding_dim))
        self.layernorm_ffn = nn.LayerNorm(embedding_dim, eps = layer_eps)
        self.dropout_ffn = nn.Dropout(p = drop_out_rate)
        #layernormは学習可能パラメーターが別々に定義されているので独立に定義する。
    def forward(self, x, pad_mask=None, mask = None):
        dx, _ = self.multiheadattention(x,x,x,key_padding_mask = pad_mask, attn_mask = mask)
        dx = self.dropout_attn(dx)
        x = self.layernorm_attn(x+dx)
        dx = self.dropout_ffn(self.ffn(x))
        x = self.layernorm_ffn(x + dx)
        return x
class TransformerDecoderLayer(nn.Module):
    def __init__(self, embedding_dim, ffn_dim, num_heads, drop_out_rate = 0., layer_eps=1e-05, batch_first = False):
        super().__init__()
        self.multiheadselfattention = nn.MultiheadAttention(embedding_dim, num_heads,batch_first=batch_first)
        self.dropout_selfattn = nn.Dropout(p = drop_out_rate)
        self.layernorm_selfattn = nn.LayerNorm(embedding_dim, eps = layer_eps)

        self.multiheadattention = nn.MultiheadAttention(embedding_dim, num_heads,batch_first=batch_first) #src-target attention
        self.dropout_attn = nn.Dropout(p = drop_out_rate)
        self.layernorm_attn = nn.LayerNorm(embedding_dim, eps = layer_eps)

        self.ffn = nn.Sequential(nn.Linear(embedding_dim, ffn_dim), nn.ReLU(), nn.Linear(ffn_dim, embedding_dim))
        self.layernorm_ffn = nn.LayerNorm(embedding_dim, eps = layer_eps)
        self.dropout_ffn = nn.Dropout(p = drop_out_rate)
    def forward(self, src, tgt, pad_mask_self = None, mask_self=None, pad_mask = None, mask = None):
        dtgt, _ = self.multiheadselfattention(tgt,tgt,tgt,key_padding_mask = pad_mask_self, attn_mask = mask_self)
        dtgt = self.dropout_selfattn(dtgt)
        tgt = self.layernorm_selfattn(tgt+dtgt)
        dtgt, _ = self.multiheadattention(tgt, src, src, key_padding_mask = pad_mask, attn_mask = mask)
        dtgt = self.dropout_attn(dtgt)
        tgt = self.layernorm_attn(tgt+dtgt)
        dtgt = self.dropout_ffn(self.ffn(tgt))
        tgt = self.layernorm_ffn(dtgt + tgt)
        return tgt

次に、数値化された文章を入力として想定するTransformerEncoderと、TransformerDecoderの２つのモジュールを製作します。

In [34]:
class TransformerEncoder(nn.Module):
    def __init__(self, src_vocab_size, embedding_dim, ffn_dim, num_heads, drop_out_rate = 0.,\
                  layer_eps=1e-05, batch_first = False, T = 10000, N = 1):
        super().__init__()
        self.embedding = nn.Embedding(src_vocab_size, embedding_dim,)
        self.positional_encoding = positional_encoding(embedding_dim, T)
        self.encoder = nn.ModuleList([TransformerEncoderLayer(embedding_dim, ffn_dim, num_heads, drop_out_rate,\
                                                               layer_eps, batch_first) for _ in range(N)])
    def forward(self, src, pad_mask =None, mask = None):
        src = self.embedding(src)
        src = self.positional_encoding(src)
        for layer in self.encoder:
            src = layer(src, pad_mask=pad_mask, mask = mask)
        return src
class TransformerDecoder(nn.Module):
    def __init__(self, tgt_vocab_size, embedding_dim, ffn_dim, num_heads, drop_out_rate = 0.,\
                  layer_eps=1e-05, batch_first = False, T = 10000, N = 1):
        super().__init__()
        self.embedding = nn.Embedding(tgt_vocab_size, embedding_dim,)
        self.positional_encoding = positional_encoding(embedding_dim, T)
        self.decoder = nn.ModuleList([TransformerDecoderLayer(embedding_dim, ffn_dim, num_heads, drop_out_rate,\
                                                               layer_eps, batch_first) for _ in range(N)])
    def forward(self, src, tgt, pad_mask_self = None, mask_self=None, pad_mask = None, mask = None):
        tgt = self.embedding(tgt)
        tgt = self.positional_encoding(tgt)
        for layer in self.decoder:
            tgt = layer(src, tgt, pad_mask_self = pad_mask_self,mask_self = mask_self, pad_mask = pad_mask, mask = mask)
        return tgt

ここまで作ればTransformerの完成です。

In [35]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, embedding_dim, ffn_dim, num_heads, drop_out_rate = 0.,\
                 layer_eps=1e-05, batch_first = False, T = 10000, N = 1):
        super().__init__()
        self.encoder = TransformerEncoder(src_vocab_size, embedding_dim, ffn_dim, num_heads, drop_out_rate,\
                  layer_eps, batch_first, T, N)
        self.decoder = TransformerDecoder(tgt_vocab_size, embedding_dim, ffn_dim, num_heads, drop_out_rate,\
                  layer_eps, batch_first, T, N)
        self.linear = nn.Linear(embedding_dim, tgt_vocab_size)
    def forward(self, src, tgt, pad_encoder_mask = None, encoder_mask=None, \
                pad_decoder_mask_self=None,decoder_mask_self=None, pad_decoder_mask=None,decoder_mask=None):
        src = self.encoder(src, pad_encoder_mask, encoder_mask)
        tgt = self.decoder(src, tgt, pad_decoder_mask_self, decoder_mask_self ,pad_decoder_mask,decoder_mask)
        return self.linear(tgt)

学習に移る前に、最後の問題点について解説します。<br>
自己注意機構を用いた学習では特別な対応をしなければ系列の全情報を用いて単語の予測をしてしまうということです。<br>
これは\<pad\>を学習対象として含んでしまうことや、予測先のトークンの答えを盗み見て学習していることになってしまいます。<br>
そこで、マスキングと呼ばれる作業を行うことでこの状態を回避します。<br>

In [36]:
def create_mask(src, tgt, src_pad, tgt_pad,device):
    """
    (batch_size, sequence_length, embedding_dim)の入力を想定
    """
    """
    Trueが無視される値であることに注意すること
    """
    seq_len_src = src.size(1)
    seq_len_tgt = tgt.size(1)
    #srcのマスク制作
    padding_src_mask = (src == src_pad)
    src_mask = torch.zeros(size=(seq_len_src, seq_len_src)).to(device).type(torch.bool)
    #tgtのマスク制作
    padding_tgt_mask = (tgt ==tgt_pad)
    tgt_mask = torch.triu(torch.ones(size = (seq_len_tgt, seq_len_tgt))==1).transpose(0,1) #下三角行列を作る
    tgt_mask = tgt_mask.float().masked_fill(tgt_mask == 0, float("-inf")).masked_fill(tgt_mask==1.,float(0.0)).to(device)
    return padding_src_mask, src_mask, padding_tgt_mask, tgt_mask

In [54]:
device = "cuda" if torch.cuda.is_available() else "cpu"
ja_vocab_size = len(ja_vocab)
en_vocab_size = len(en_vocab)
embedding_dim = 300
from torch import optim
import torch.nn.functional as F
from torchtext.data.metrics import bleu_score
model = Transformer(ja_vocab_size, en_vocab_size, embedding_dim, ffn_dim = embedding_dim*2,num_heads=6,drop_out_rate=0.1, batch_first=True, N = 5).to(device)
optimizer = optim.Adam(model.parameters())
eps = 0.1
criterion= nn.CrossEntropyLoss(ignore_index=en_vocab["<pad>"], label_smoothing=eps / (en_vocab_size-1)) #予測が<pad>のものは予測に用いない
epochs = 100

In [55]:
from tqdm import tqdm
best_acc = 0
best_params = None
n_warmup = epochs // 3
src_pad = ja_vocab["<pad>"]
tgt_pad = en_vocab["<pad>"]
for epoch in tqdm(range(epochs)):
    model.train()
    all_length = 0
    all_acc = 0
    all_loss = 0
    optimizer.lr = min(1/np.sqrt(embedding_dim)/np.sqrt(epoch+1), (epoch+1)/np.sqrt(embedding_dim)/n_warmup/np.sqrt(n_warmup)) #後々解説
    for japanese, english, _, _ in tqdm(train_dl):
        optimizer.zero_grad()
        japanese = japanese.to(device)
        english = english.to(device)
        padding_src_mask, src_mask, padding_tgt_mask, tgt_mask = create_mask(japanese, english, src_pad, tgt_pad, device)
        pred = model(japanese, english, pad_encoder_mask = padding_src_mask, encoder_mask = src_mask, \
                     pad_decoder_mask_self = padding_tgt_mask, decoder_mask_self = tgt_mask)
        pred = pred[:,:-1,:]
        tgt = english[:,1:].contiguous().view(-1)
        loss = criterion(pred.contiguous().view(-1,pred.size(-1)), tgt)
        loss.backward()
        optimizer.step()
        all_acc += (pred.argmax(dim = -1).contiguous().view(-1) == tgt).sum().item()
        all_length += tgt.size(0)
        all_loss += loss.item()
    print("Train CrossEntropyLoss: ", all_loss / all_length)
    print("Train ACCURACY: ", all_acc / all_length)
    model.eval()
    all_length = 0
    all_acc = 0
    all_loss = 0
    for japanese, english, _, _ in tqdm(test_dl):
        japanese = japanese.to(device)
        english = english.to(device)
        padding_src_mask, src_mask, padding_tgt_mask, tgt_mask = create_mask(japanese, english, src_pad, tgt_pad, device)
        pred = model(japanese, english, pad_encoder_mask = padding_src_mask, encoder_mask = src_mask, \
                     pad_decoder_mask_self = padding_tgt_mask, decoder_mask_self = tgt_mask)
        pred = pred[:,:-1,:]
        tgt = english[:,1:].contiguous().view(-1)
        loss = criterion(pred.contiguous().view(-1,pred.size(-1)), tgt)
        all_acc += (pred.argmax(dim = -1).contiguous().view(-1) == tgt).sum().item()
        all_length += tgt.size(0)
        all_loss += loss.item()
    print("Test CrossEntropyLoss: ", all_loss / all_length)
    print("Test ACCURACY: ", all_acc / all_length)
    if best_acc < all_acc / all_length:
        best_acc = all_acc / all_length
        best_params = model.parameters
        torch.save(model.state_dict(), "transformer.pth")
    #使った変数は一旦削除する
    del english
    del japanese
    del pred
    del tgt
    gc.collect()
    torch.cuda.empty_cache()

100%|██████████| 225/225 [00:43<00:00,  5.21it/s]


Train CrossEntropyLoss:  0.0012225607726810452
Train ACCURACY:  0.1951257253384913


100%|██████████| 25/25 [00:03<00:00,  7.57it/s]


Test CrossEntropyLoss:  0.0009381679713432424
Test ACCURACY:  0.24167487684729064


100%|██████████| 225/225 [00:43<00:00,  5.21it/s]


Train CrossEntropyLoss:  0.0007865674468141764
Train ACCURACY:  0.2829499174463401


100%|██████████| 25/25 [00:03<00:00,  7.52it/s]


Test CrossEntropyLoss:  0.0006809324701431349
Test ACCURACY:  0.31346059113300495


100%|██████████| 225/225 [00:43<00:00,  5.19it/s]


Train CrossEntropyLoss:  0.0005876598565932209
Train ACCURACY:  0.335292333149476


100%|██████████| 25/25 [00:03<00:00,  7.54it/s]


Test CrossEntropyLoss:  0.0005987481750878207
Test ACCURACY:  0.3317118226600985


100%|██████████| 225/225 [00:43<00:00,  5.21it/s]


Train CrossEntropyLoss:  0.00047860160650928024
Train ACCURACY:  0.36448261589403974


100%|██████████| 25/25 [00:03<00:00,  7.50it/s]


Test CrossEntropyLoss:  0.0005446437001228332
Test ACCURACY:  0.3527709359605911


100%|██████████| 225/225 [00:43<00:00,  5.20it/s]


Train CrossEntropyLoss:  0.0004049413425097377
Train ACCURACY:  0.3849530516431925


100%|██████████| 25/25 [00:03<00:00,  7.49it/s]


Test CrossEntropyLoss:  0.0005286945894433947
Test ACCURACY:  0.36060344827586205


100%|██████████| 225/225 [00:43<00:00,  5.21it/s]


Train CrossEntropyLoss:  0.00035034490850353476
Train ACCURACY:  0.40015706806282725


100%|██████████| 25/25 [00:03<00:00,  7.54it/s]


Test CrossEntropyLoss:  0.0005158666759876195
Test ACCURACY:  0.3658004926108374


100%|██████████| 225/225 [00:43<00:00,  5.21it/s]


Train CrossEntropyLoss:  0.000308430115305017
Train ACCURACY:  0.41535921525283226


100%|██████████| 25/25 [00:03<00:00,  7.58it/s]


Test CrossEntropyLoss:  0.0005200540505606553
Test ACCURACY:  0.36799261083743845


100%|██████████| 225/225 [00:43<00:00,  5.19it/s]


Train CrossEntropyLoss:  0.0002738693853903329
Train ACCURACY:  0.4258152024235748


100%|██████████| 25/25 [00:03<00:00,  7.56it/s]


Test CrossEntropyLoss:  0.0005204794559572717
Test ACCURACY:  0.36979064039408865


100%|██████████| 225/225 [00:43<00:00,  5.19it/s]


Train CrossEntropyLoss:  0.0002471009265652019
Train ACCURACY:  0.437664364640884


100%|██████████| 25/25 [00:03<00:00,  7.55it/s]
  9%|▉         | 9/100 [07:01<1:10:59, 46.81s/it]

Test CrossEntropyLoss:  0.0005296335475785392
Test ACCURACY:  0.3697167487684729


100%|██████████| 225/225 [00:43<00:00,  5.18it/s]


Train CrossEntropyLoss:  0.0002230685405159049
Train ACCURACY:  0.442888522789676


100%|██████████| 25/25 [00:03<00:00,  7.49it/s]


Test CrossEntropyLoss:  0.0005373915545458864
Test ACCURACY:  0.3725985221674877


100%|██████████| 225/225 [00:43<00:00,  5.19it/s]]


Train CrossEntropyLoss:  0.00020383707795496194
Train ACCURACY:  0.4523809523809524


100%|██████████| 25/25 [00:03<00:00,  7.58it/s]


Test CrossEntropyLoss:  0.000538244576289736
Test ACCURACY:  0.374679802955665


100%|██████████| 225/225 [00:43<00:00,  5.20it/s]]


Train CrossEntropyLoss:  0.00018724919219685208
Train ACCURACY:  0.46214483522569927


100%|██████████| 25/25 [00:03<00:00,  7.56it/s]
 12%|█▏        | 12/100 [09:21<1:08:41, 46.83s/it]

Test CrossEntropyLoss:  0.0005511784465442151
Test ACCURACY:  0.37445812807881773


100%|██████████| 225/225 [00:43<00:00,  5.19it/s]


Train CrossEntropyLoss:  0.00017471416747372453
Train ACCURACY:  0.4658756906077348


100%|██████████| 25/25 [00:03<00:00,  7.58it/s]
 13%|█▎        | 13/100 [10:08<1:07:53, 46.83s/it]

Test CrossEntropyLoss:  0.000563844987030687
Test ACCURACY:  0.3730665024630542


100%|██████████| 225/225 [00:43<00:00,  5.19it/s]


Train CrossEntropyLoss:  0.00016015644882195042
Train ACCURACY:  0.4719398454746137


100%|██████████| 25/25 [00:03<00:00,  7.60it/s]
 14%|█▍        | 14/100 [10:55<1:07:05, 46.80s/it]

Test CrossEntropyLoss:  0.0005715418595985826
Test ACCURACY:  0.37354679802955665


100%|██████████| 225/225 [00:43<00:00,  5.19it/s]


Train CrossEntropyLoss:  0.00014911843267279666
Train ACCURACY:  0.4747252747252747


100%|██████████| 25/25 [00:03<00:00,  7.48it/s]
 15%|█▌        | 15/100 [11:42<1:06:18, 46.81s/it]

Test CrossEntropyLoss:  0.000582163140104322
Test ACCURACY:  0.3741995073891626


100%|██████████| 225/225 [00:43<00:00,  5.20it/s]


Train CrossEntropyLoss:  0.00014130548604073063
Train ACCURACY:  0.4802426247587538


100%|██████████| 25/25 [00:03<00:00,  7.54it/s]


Test CrossEntropyLoss:  0.0005908617024938461
Test ACCURACY:  0.37562807881773397


100%|██████████| 225/225 [00:43<00:00,  5.21it/s]]


Train CrossEntropyLoss:  0.0001330329928114928
Train ACCURACY:  0.4849613259668508


100%|██████████| 25/25 [00:03<00:00,  7.52it/s]
 17%|█▋        | 17/100 [13:15<1:04:42, 46.78s/it]

Test CrossEntropyLoss:  0.0005982607911373007
Test ACCURACY:  0.3753325123152709


100%|██████████| 225/225 [00:43<00:00,  5.19it/s]


Train CrossEntropyLoss:  0.00012564661561868869
Train ACCURACY:  0.48646354883081155


100%|██████████| 25/25 [00:03<00:00,  7.51it/s]
 18%|█▊        | 18/100 [14:02<1:03:56, 46.79s/it]

Test CrossEntropyLoss:  0.0006108807548513553
Test ACCURACY:  0.37557881773399016


100%|██████████| 225/225 [00:43<00:00,  5.21it/s]


Train CrossEntropyLoss:  0.00011637937088203693
Train ACCURACY:  0.4934364640883978


100%|██████████| 25/25 [00:03<00:00,  7.36it/s]


Test CrossEntropyLoss:  0.0006193505660653702
Test ACCURACY:  0.37720443349753696


100%|██████████| 225/225 [00:43<00:00,  5.18it/s]]


Train CrossEntropyLoss:  0.0001124345479567258
Train ACCURACY:  0.49031772117228156


100%|██████████| 25/25 [00:03<00:00,  7.56it/s]
 20%|██        | 20/100 [15:36<1:02:25, 46.81s/it]

Test CrossEntropyLoss:  0.0006333244023064674
Test ACCURACY:  0.37565270935960593


100%|██████████| 225/225 [00:43<00:00,  5.19it/s]


Train CrossEntropyLoss:  0.00010782988404453789
Train ACCURACY:  0.49754558011049727


100%|██████████| 25/25 [00:03<00:00,  7.54it/s]
 21%|██        | 21/100 [16:23<1:01:37, 46.80s/it]

Test CrossEntropyLoss:  0.0006356458681557566
Test ACCURACY:  0.3750246305418719


100%|██████████| 225/225 [00:43<00:00,  5.20it/s]


Train CrossEntropyLoss:  0.000102656631473333
Train ACCURACY:  0.4980074257425743


100%|██████████| 25/25 [00:03<00:00,  7.44it/s]
 22%|██▏       | 22/100 [17:09<1:00:49, 46.79s/it]

Test CrossEntropyLoss:  0.0006438984967804894
Test ACCURACY:  0.3740024630541872


100%|██████████| 225/225 [00:43<00:00,  5.18it/s]


Train CrossEntropyLoss:  9.740042101645103e-05
Train ACCURACY:  0.500254400440044


100%|██████████| 25/25 [00:03<00:00,  7.54it/s]
 23%|██▎       | 23/100 [17:56<1:00:04, 46.82s/it]

Test CrossEntropyLoss:  0.0006519638479049569
Test ACCURACY:  0.3745073891625616


100%|██████████| 225/225 [00:43<00:00,  5.21it/s]


Train CrossEntropyLoss:  9.579738044225476e-05
Train ACCURACY:  0.5038365995031742


100%|██████████| 25/25 [00:03<00:00,  7.58it/s]
 24%|██▍       | 24/100 [18:43<59:13, 46.75s/it]  

Test CrossEntropyLoss:  0.0006598937805063032
Test ACCURACY:  0.3767118226600985


100%|██████████| 225/225 [00:43<00:00,  5.20it/s]


Train CrossEntropyLoss:  9.169603066728713e-05
Train ACCURACY:  0.5043954833379234


100%|██████████| 25/25 [00:03<00:00,  7.64it/s]
 25%|██▌       | 25/100 [19:29<58:24, 46.72s/it]

Test CrossEntropyLoss:  0.0006703335854220273
Test ACCURACY:  0.37695812807881773


100%|██████████| 225/225 [00:43<00:00,  5.19it/s]


Train CrossEntropyLoss:  8.766352914530656e-05
Train ACCURACY:  0.5076165517241379


100%|██████████| 25/25 [00:03<00:00,  7.51it/s]
 26%|██▌       | 26/100 [20:16<57:38, 46.74s/it]

Test CrossEntropyLoss:  0.0006698015713926607
Test ACCURACY:  0.3748768472906404


100%|██████████| 225/225 [00:43<00:00,  5.20it/s]


Train CrossEntropyLoss:  8.622827331389413e-05
Train ACCURACY:  0.5079831816928592


100%|██████████| 25/25 [00:03<00:00,  7.52it/s]
 27%|██▋       | 27/100 [21:03<56:51, 46.74s/it]

Test CrossEntropyLoss:  0.0006788354759733078
Test ACCURACY:  0.37695812807881773


100%|██████████| 225/225 [00:43<00:00,  5.18it/s]


Train CrossEntropyLoss:  8.381024484703269e-05
Train ACCURACY:  0.5079367262723521


100%|██████████| 25/25 [00:03<00:00,  7.53it/s]


Test CrossEntropyLoss:  0.0006784021237800862
Test ACCURACY:  0.378743842364532


100%|██████████| 225/225 [00:43<00:00,  5.19it/s]


Train CrossEntropyLoss:  8.17330124397299e-05
Train ACCURACY:  0.5113615023474178


100%|██████████| 25/25 [00:03<00:00,  7.53it/s]
 29%|██▉       | 29/100 [22:37<55:23, 46.80s/it]

Test CrossEntropyLoss:  0.0006924563440783271
Test ACCURACY:  0.3776847290640394


100%|██████████| 225/225 [00:43<00:00,  5.19it/s]


Train CrossEntropyLoss:  7.842628614292817e-05
Train ACCURACY:  0.5117911270322403


100%|██████████| 25/25 [00:03<00:00,  7.53it/s]
 30%|███       | 30/100 [23:23<54:35, 46.80s/it]

Test CrossEntropyLoss:  0.0006952104310096779
Test ACCURACY:  0.3782881773399015


100%|██████████| 225/225 [00:43<00:00,  5.20it/s]


Train CrossEntropyLoss:  7.724798908786832e-05
Train ACCURACY:  0.5147981194690265


100%|██████████| 25/25 [00:03<00:00,  7.56it/s]
 31%|███       | 31/100 [24:10<53:47, 46.77s/it]

Test CrossEntropyLoss:  0.0007002803198809694
Test ACCURACY:  0.3786206896551724


100%|██████████| 225/225 [00:43<00:00,  5.18it/s]


Train CrossEntropyLoss:  7.410487577310945e-05
Train ACCURACY:  0.5137658227848101


100%|██████████| 25/25 [00:03<00:00,  7.54it/s]
 32%|███▏      | 32/100 [24:57<53:02, 46.80s/it]

Test CrossEntropyLoss:  0.0007057564569811516
Test ACCURACY:  0.37729064039408866


100%|██████████| 225/225 [00:43<00:00,  5.18it/s]


Train CrossEntropyLoss:  7.336438255827072e-05
Train ACCURACY:  0.5162058011049724


100%|██████████| 25/25 [00:03<00:00,  7.51it/s]
 33%|███▎      | 33/100 [25:44<52:17, 46.83s/it]

Test CrossEntropyLoss:  0.0007079539216797927
Test ACCURACY:  0.3763054187192118


100%|██████████| 225/225 [00:43<00:00,  5.18it/s]


Train CrossEntropyLoss:  7.09632850706102e-05
Train ACCURACY:  0.5148776464118779


100%|██████████| 25/25 [00:03<00:00,  7.51it/s]
 34%|███▍      | 34/100 [26:31<51:31, 46.84s/it]

Test CrossEntropyLoss:  0.0007140074486802952
Test ACCURACY:  0.37742610837438423


100%|██████████| 225/225 [00:43<00:00,  5.14it/s]


Train CrossEntropyLoss:  6.886197102427221e-05
Train ACCURACY:  0.5163696369636964


100%|██████████| 25/25 [00:03<00:00,  7.55it/s]
 35%|███▌      | 35/100 [27:18<50:51, 46.94s/it]

Test CrossEntropyLoss:  0.000716460737688788
Test ACCURACY:  0.3761822660098522


100%|██████████| 225/225 [00:43<00:00,  5.20it/s]


Train CrossEntropyLoss:  7.008568890878464e-05
Train ACCURACY:  0.5203952843273232


100%|██████████| 25/25 [00:03<00:00,  7.51it/s]
 36%|███▌      | 36/100 [28:05<50:00, 46.89s/it]

Test CrossEntropyLoss:  0.0007222506389242088
Test ACCURACY:  0.3766748768472906


100%|██████████| 225/225 [00:43<00:00,  5.16it/s]


Train CrossEntropyLoss:  6.847433213971238e-05
Train ACCURACY:  0.5199086884338683


100%|██████████| 25/25 [00:03<00:00,  7.54it/s]
 37%|███▋      | 37/100 [28:52<49:17, 46.94s/it]

Test CrossEntropyLoss:  0.000721228745183334
Test ACCURACY:  0.3766009852216749


100%|██████████| 225/225 [00:43<00:00,  5.20it/s]


Train CrossEntropyLoss:  6.650222913295673e-05
Train ACCURACY:  0.5206536760641238


100%|██████████| 25/25 [00:03<00:00,  7.60it/s]


Test CrossEntropyLoss:  0.0007246648736775215
Test ACCURACY:  0.3791995073891626


100%|██████████| 225/225 [00:43<00:00,  5.19it/s]


Train CrossEntropyLoss:  6.515523832705286e-05
Train ACCURACY:  0.5184543454345435


100%|██████████| 25/25 [00:03<00:00,  7.51it/s]
 39%|███▉      | 39/100 [30:25<47:39, 46.88s/it]

Test CrossEntropyLoss:  0.0007290814809611278
Test ACCURACY:  0.37823891625615763


100%|██████████| 225/225 [00:43<00:00,  5.20it/s]


Train CrossEntropyLoss:  6.548958189048836e-05
Train ACCURACY:  0.5190016524373451


100%|██████████| 25/25 [00:03<00:00,  7.54it/s]
 40%|████      | 40/100 [31:12<46:49, 46.83s/it]

Test CrossEntropyLoss:  0.0007305593502345344
Test ACCURACY:  0.3786945812807882


100%|██████████| 225/225 [00:43<00:00,  5.20it/s]


Train CrossEntropyLoss:  6.184014025561058e-05
Train ACCURACY:  0.5228881215469613


100%|██████████| 25/25 [00:03<00:00,  7.53it/s]
 41%|████      | 41/100 [31:59<46:00, 46.79s/it]

Test CrossEntropyLoss:  0.0007378573458770226
Test ACCURACY:  0.37822660098522165


100%|██████████| 225/225 [00:43<00:00,  5.19it/s]


Train CrossEntropyLoss:  6.14743705233957e-05
Train ACCURACY:  0.522125723738627


100%|██████████| 25/25 [00:03<00:00,  7.55it/s]
 42%|████▏     | 42/100 [32:46<45:14, 46.80s/it]

Test CrossEntropyLoss:  0.0007391559549153145
Test ACCURACY:  0.37857142857142856


100%|██████████| 225/225 [00:43<00:00,  5.20it/s]


Train CrossEntropyLoss:  6.054912338361296e-05
Train ACCURACY:  0.5222526866905484


100%|██████████| 25/25 [00:08<00:00,  3.11it/s]
 43%|████▎     | 43/100 [33:37<45:47, 48.20s/it]

Test CrossEntropyLoss:  0.0007418202707920169
Test ACCURACY:  0.3775615763546798


100%|██████████| 225/225 [00:52<00:00,  4.32it/s]


Train CrossEntropyLoss:  5.926584451360162e-05
Train ACCURACY:  0.5248672199170125


100%|██████████| 25/25 [00:03<00:00,  6.85it/s]
 44%|████▍     | 44/100 [34:33<47:07, 50.50s/it]

Test CrossEntropyLoss:  0.0007407934354443855
Test ACCURACY:  0.3787931034482759


100%|██████████| 225/225 [00:51<00:00,  4.37it/s]


Train CrossEntropyLoss:  5.904940847580529e-05
Train ACCURACY:  0.5253540802213001


100%|██████████| 25/25 [00:03<00:00,  6.78it/s]
 45%|████▌     | 45/100 [35:28<47:37, 51.95s/it]

Test CrossEntropyLoss:  0.0007450577337753596
Test ACCURACY:  0.37767241379310346


100%|██████████| 225/225 [00:54<00:00,  4.12it/s]


Train CrossEntropyLoss:  5.729494142626175e-05
Train ACCURACY:  0.5251505109085888


100%|██████████| 25/25 [00:04<00:00,  6.23it/s]


Test CrossEntropyLoss:  0.000750383173890889
Test ACCURACY:  0.3792241379310345


100%|██████████| 225/225 [00:54<00:00,  4.13it/s]


Train CrossEntropyLoss:  5.5676074480669184e-05
Train ACCURACY:  0.5246022570878062


100%|██████████| 25/25 [00:04<00:00,  6.20it/s]


Test CrossEntropyLoss:  0.0007544259073699049
Test ACCURACY:  0.3816871921182266


100%|██████████| 225/225 [00:54<00:00,  4.14it/s]


Train CrossEntropyLoss:  5.482608174717113e-05
Train ACCURACY:  0.5273789764868603


100%|██████████| 25/25 [00:04<00:00,  6.23it/s]
 48%|████▊     | 48/100 [38:25<48:50, 56.35s/it]

Test CrossEntropyLoss:  0.0007575275040612432
Test ACCURACY:  0.37896551724137933


100%|██████████| 225/225 [00:53<00:00,  4.17it/s]


Train CrossEntropyLoss:  5.5770137104754446e-05
Train ACCURACY:  0.5292254303164908


100%|██████████| 25/25 [00:04<00:00,  6.21it/s]
 49%|████▉     | 49/100 [39:23<48:21, 56.89s/it]

Test CrossEntropyLoss:  0.000753487853581095
Test ACCURACY:  0.3787931034482759


100%|██████████| 225/225 [00:54<00:00,  4.14it/s]


Train CrossEntropyLoss:  5.5248347073054104e-05
Train ACCURACY:  0.5253900220507166


100%|██████████| 25/25 [00:04<00:00,  6.17it/s]
 50%|█████     | 50/100 [40:21<47:49, 57.39s/it]

Test CrossEntropyLoss:  0.0007585824035071387
Test ACCURACY:  0.37897783251231526


100%|██████████| 225/225 [00:54<00:00,  4.13it/s]


Train CrossEntropyLoss:  5.3955674664048126e-05
Train ACCURACY:  0.5258732782369147


100%|██████████| 25/25 [00:04<00:00,  6.22it/s]
 51%|█████     | 51/100 [41:20<47:09, 57.74s/it]

Test CrossEntropyLoss:  0.0007570985708330653
Test ACCURACY:  0.3782881773399015


100%|██████████| 225/225 [00:54<00:00,  4.15it/s]


Train CrossEntropyLoss:  5.2166859246585045e-05
Train ACCURACY:  0.5271912899669239


100%|██████████| 25/25 [00:04<00:00,  6.24it/s]
 52%|█████▏    | 52/100 [42:18<46:20, 57.93s/it]

Test CrossEntropyLoss:  0.0007652768213760677
Test ACCURACY:  0.37769704433497536


100%|██████████| 225/225 [00:54<00:00,  4.11it/s]


Train CrossEntropyLoss:  5.2496219415120276e-05
Train ACCURACY:  0.5258308115543329


100%|██████████| 25/25 [00:03<00:00,  6.26it/s]
 53%|█████▎    | 53/100 [43:17<45:36, 58.23s/it]

Test CrossEntropyLoss:  0.0007619167372510938
Test ACCURACY:  0.38019704433497536


100%|██████████| 225/225 [00:54<00:00,  4.13it/s]


Train CrossEntropyLoss:  5.181837231332995e-05
Train ACCURACY:  0.529195020746888


100%|██████████| 25/25 [00:03<00:00,  6.26it/s]
 54%|█████▍    | 54/100 [44:16<44:44, 58.36s/it]

Test CrossEntropyLoss:  0.0007671165202051548
Test ACCURACY:  0.3812931034482759


100%|██████████| 225/225 [00:54<00:00,  4.15it/s]


Train CrossEntropyLoss:  5.160183964097973e-05
Train ACCURACY:  0.5274820837927232


100%|██████████| 25/25 [00:04<00:00,  6.18it/s]
 55%|█████▌    | 55/100 [45:14<43:47, 58.38s/it]

Test CrossEntropyLoss:  0.0007746784998278312
Test ACCURACY:  0.38149014778325124


100%|██████████| 225/225 [00:54<00:00,  4.14it/s]


Train CrossEntropyLoss:  5.141054539145618e-05
Train ACCURACY:  0.5281677704194261


100%|██████████| 25/25 [00:04<00:00,  6.17it/s]
 56%|█████▌    | 56/100 [46:13<42:51, 58.45s/it]

Test CrossEntropyLoss:  0.0007684869308189806
Test ACCURACY:  0.3797783251231527


100%|██████████| 225/225 [00:54<00:00,  4.13it/s]


Train CrossEntropyLoss:  5.0049762647713856e-05
Train ACCURACY:  0.5277656938325991


100%|██████████| 25/25 [00:04<00:00,  6.21it/s]
 57%|█████▋    | 57/100 [47:11<41:55, 58.50s/it]

Test CrossEntropyLoss:  0.0007753865795182477
Test ACCURACY:  0.3800738916256158


100%|██████████| 225/225 [00:54<00:00,  4.15it/s]


Train CrossEntropyLoss:  4.8234217817394424e-05
Train ACCURACY:  0.5294061724993111


100%|██████████| 25/25 [00:03<00:00,  6.27it/s]
 58%|█████▊    | 58/100 [48:10<40:54, 58.44s/it]

Test CrossEntropyLoss:  0.0007779355060878059
Test ACCURACY:  0.3786822660098522


100%|██████████| 225/225 [00:54<00:00,  4.14it/s]


Train CrossEntropyLoss:  4.8119250821505544e-05
Train ACCURACY:  0.5293538164783687


100%|██████████| 25/25 [00:04<00:00,  6.07it/s]
 59%|█████▉    | 59/100 [49:08<39:58, 58.50s/it]

Test CrossEntropyLoss:  0.0007826274486598123
Test ACCURACY:  0.37986453201970444


100%|██████████| 225/225 [00:54<00:00,  4.10it/s]


Train CrossEntropyLoss:  4.82710571270031e-05
Train ACCURACY:  0.5286774015964767


100%|██████████| 25/25 [00:04<00:00,  6.24it/s]
 60%|██████    | 60/100 [50:07<39:06, 58.67s/it]

Test CrossEntropyLoss:  0.0007830939210694412
Test ACCURACY:  0.3783128078817734


100%|██████████| 225/225 [00:54<00:00,  4.13it/s]


Train CrossEntropyLoss:  5.356465964255356e-05
Train ACCURACY:  0.5255871837183719


100%|██████████| 25/25 [00:04<00:00,  6.21it/s]
 61%|██████    | 61/100 [51:06<38:07, 58.66s/it]

Test CrossEntropyLoss:  0.0007762169603056508
Test ACCURACY:  0.377807881773399


100%|██████████| 225/225 [00:54<00:00,  4.14it/s]


Train CrossEntropyLoss:  5.324790619801138e-05
Train ACCURACY:  0.5249725123694338


100%|██████████| 25/25 [00:04<00:00,  6.23it/s]
 62%|██████▏   | 62/100 [52:05<37:07, 58.62s/it]

Test CrossEntropyLoss:  0.0007806387147292715
Test ACCURACY:  0.380307881773399


100%|██████████| 225/225 [00:54<00:00,  4.14it/s]


Train CrossEntropyLoss:  5.102626092573342e-05
Train ACCURACY:  0.5293972905722975


100%|██████████| 25/25 [00:04<00:00,  6.22it/s]
 63%|██████▎   | 63/100 [53:03<36:07, 58.58s/it]

Test CrossEntropyLoss:  0.0007761589323945821
Test ACCURACY:  0.3803448275862069


100%|██████████| 225/225 [00:48<00:00,  4.69it/s]


Train CrossEntropyLoss:  4.300844141969082e-05
Train ACCURACY:  0.5324545203969129


100%|██████████| 25/25 [00:03<00:00,  7.53it/s]
 64%|██████▍   | 64/100 [53:54<33:51, 56.44s/it]

Test CrossEntropyLoss:  0.0007886136459012337
Test ACCURACY:  0.3808128078817734


100%|██████████| 225/225 [00:43<00:00,  5.20it/s]


Train CrossEntropyLoss:  4.046893305974015e-05
Train ACCURACY:  0.532805448541552


100%|██████████| 25/25 [00:03<00:00,  7.53it/s]
 65%|██████▌   | 65/100 [54:41<31:13, 53.53s/it]

Test CrossEntropyLoss:  0.0007882693336515004
Test ACCURACY:  0.3809113300492611


100%|██████████| 225/225 [00:43<00:00,  5.20it/s]


Train CrossEntropyLoss:  4.188108071787547e-05
Train ACCURACY:  0.5333705541770057


100%|██████████| 25/25 [00:03<00:00,  7.47it/s]
 66%|██████▌   | 66/100 [55:28<29:11, 51.50s/it]

Test CrossEntropyLoss:  0.0007905925847039434
Test ACCURACY:  0.3813177339901478


100%|██████████| 225/225 [00:43<00:00,  5.18it/s]


Train CrossEntropyLoss:  4.394919339343946e-05
Train ACCURACY:  0.5310071546505228


100%|██████████| 25/25 [00:03<00:00,  7.48it/s]
 67%|██████▋   | 67/100 [56:15<27:33, 50.12s/it]

Test CrossEntropyLoss:  0.0007894131056780886
Test ACCURACY:  0.38022167487684727


100%|██████████| 225/225 [00:43<00:00,  5.18it/s]


Train CrossEntropyLoss:  4.379539836441246e-05
Train ACCURACY:  0.5314004406499587


100%|██████████| 25/25 [00:03<00:00,  7.51it/s]
 68%|██████▊   | 68/100 [57:02<26:12, 49.14s/it]

Test CrossEntropyLoss:  0.0007926280774506442
Test ACCURACY:  0.38051724137931037


100%|██████████| 225/225 [00:43<00:00,  5.18it/s]


Train CrossEntropyLoss:  4.408477960775296e-05
Train ACCURACY:  0.5323772075055188


100%|██████████| 25/25 [00:03<00:00,  7.51it/s]
 69%|██████▉   | 69/100 [57:49<25:02, 48.47s/it]

Test CrossEntropyLoss:  0.0007964145931704291
Test ACCURACY:  0.37935960591133006


100%|██████████| 225/225 [00:43<00:00,  5.19it/s]


Train CrossEntropyLoss:  4.196996433716832e-05
Train ACCURACY:  0.5326845730027548


100%|██████████| 25/25 [00:03<00:00,  7.48it/s]
 70%|███████   | 70/100 [58:35<23:59, 47.97s/it]

Test CrossEntropyLoss:  0.000802058668559408
Test ACCURACY:  0.3802832512315271


100%|██████████| 225/225 [00:43<00:00,  5.18it/s]


Train CrossEntropyLoss:  4.219413614074717e-05
Train ACCURACY:  0.531819306930693


100%|██████████| 25/25 [00:03<00:00,  7.59it/s]
 71%|███████   | 71/100 [59:22<23:01, 47.64s/it]

Test CrossEntropyLoss:  0.0007978804035139789
Test ACCURACY:  0.3791871921182266


100%|██████████| 225/225 [00:43<00:00,  5.21it/s]


Train CrossEntropyLoss:  4.098781067399615e-05
Train ACCURACY:  0.5347224523612262


100%|██████████| 25/25 [00:03<00:00,  7.51it/s]
 72%|███████▏  | 72/100 [1:00:09<22:05, 47.35s/it]

Test CrossEntropyLoss:  0.0008030404803788133
Test ACCURACY:  0.3797783251231527


100%|██████████| 225/225 [00:43<00:00,  5.20it/s]


Train CrossEntropyLoss:  4.2020468487532704e-05
Train ACCURACY:  0.5345495993368333


100%|██████████| 25/25 [00:03<00:00,  7.50it/s]
 73%|███████▎  | 73/100 [1:00:56<21:13, 47.16s/it]

Test CrossEntropyLoss:  0.0008034137052855468
Test ACCURACY:  0.3797536945812808


100%|██████████| 225/225 [00:43<00:00,  5.19it/s]


Train CrossEntropyLoss:  4.2623985903016454e-05
Train ACCURACY:  0.533328275862069


100%|██████████| 25/25 [00:03<00:00,  7.41it/s]
 74%|███████▍  | 74/100 [1:01:43<20:23, 47.08s/it]

Test CrossEntropyLoss:  0.0008050068816527944
Test ACCURACY:  0.3810344827586207


100%|██████████| 225/225 [00:43<00:00,  5.19it/s]


Train CrossEntropyLoss:  4.121779935367186e-05
Train ACCURACY:  0.5338075544527158


100%|██████████| 25/25 [00:03<00:00,  7.46it/s]
 75%|███████▌  | 75/100 [1:02:29<19:35, 47.00s/it]

Test CrossEntropyLoss:  0.000805883448699425
Test ACCURACY:  0.381564039408867


100%|██████████| 225/225 [00:43<00:00,  5.18it/s]


Train CrossEntropyLoss:  4.175219391976955e-05
Train ACCURACY:  0.5315929081913139


100%|██████████| 25/25 [00:03<00:00,  7.50it/s]
 76%|███████▌  | 76/100 [1:03:16<18:47, 46.97s/it]

Test CrossEntropyLoss:  0.000802880117458663
Test ACCURACY:  0.3812192118226601


100%|██████████| 225/225 [00:43<00:00,  5.17it/s]


Train CrossEntropyLoss:  4.060423434248125e-05
Train ACCURACY:  0.5329113924050632


100%|██████████| 25/25 [00:03<00:00,  7.48it/s]


Test CrossEntropyLoss:  0.0007997788731100524
Test ACCURACY:  0.3817487684729064


100%|██████████| 225/225 [00:43<00:00,  5.17it/s]]


Train CrossEntropyLoss:  3.954212477868036e-05
Train ACCURACY:  0.5336911643270025


100%|██████████| 25/25 [00:03<00:00,  7.46it/s]
 78%|███████▊  | 78/100 [1:04:50<17:14, 47.01s/it]

Test CrossEntropyLoss:  0.0008171837934719517
Test ACCURACY:  0.3799384236453202


100%|██████████| 225/225 [00:43<00:00,  5.19it/s]


Train CrossEntropyLoss:  4.023944468928873e-05
Train ACCURACY:  0.5357725262576009


100%|██████████| 25/25 [00:03<00:00,  7.54it/s]
 79%|███████▉  | 79/100 [1:05:37<16:26, 46.96s/it]

Test CrossEntropyLoss:  0.0008072624241777242
Test ACCURACY:  0.3817118226600985


100%|██████████| 225/225 [00:43<00:00,  5.19it/s]


Train CrossEntropyLoss:  3.926394738118537e-05
Train ACCURACY:  0.5380094261158858


100%|██████████| 25/25 [00:03<00:00,  7.44it/s]


Test CrossEntropyLoss:  0.0008088120304304978
Test ACCURACY:  0.3826231527093596


100%|██████████| 225/225 [00:43<00:00,  5.18it/s]]


Train CrossEntropyLoss:  3.9116456628982506e-05
Train ACCURACY:  0.5362617468214483


100%|██████████| 25/25 [00:03<00:00,  7.49it/s]
 81%|████████  | 81/100 [1:07:11<14:52, 46.95s/it]

Test CrossEntropyLoss:  0.0008179438730766033
Test ACCURACY:  0.3825369458128079


100%|██████████| 225/225 [00:43<00:00,  5.17it/s]


Train CrossEntropyLoss:  3.9693815902822966e-05
Train ACCURACY:  0.5336801541425819


100%|██████████| 25/25 [00:03<00:00,  7.48it/s]
 82%|████████▏ | 82/100 [1:07:58<14:05, 46.96s/it]

Test CrossEntropyLoss:  0.0008096857669905489
Test ACCURACY:  0.38120689655172413


100%|██████████| 225/225 [00:43<00:00,  5.18it/s]


Train CrossEntropyLoss:  3.818266475627224e-05
Train ACCURACY:  0.5366095606521138


100%|██████████| 25/25 [00:03<00:00,  7.47it/s]
 83%|████████▎ | 83/100 [1:08:45<13:18, 46.94s/it]

Test CrossEntropyLoss:  0.0008107080189465302
Test ACCURACY:  0.38116995073891624


100%|██████████| 225/225 [00:43<00:00,  5.17it/s]


Train CrossEntropyLoss:  3.6596062865150294e-05
Train ACCURACY:  0.5353275529865126


100%|██████████| 25/25 [00:03<00:00,  7.44it/s]
 84%|████████▍ | 84/100 [1:09:32<12:31, 46.97s/it]

Test CrossEntropyLoss:  0.000821944090533139
Test ACCURACY:  0.38133004926108377


100%|██████████| 225/225 [00:43<00:00,  5.16it/s]


Train CrossEntropyLoss:  3.641161233784241e-05
Train ACCURACY:  0.5344603131007964


100%|██████████| 25/25 [00:03<00:00,  7.56it/s]
 85%|████████▌ | 85/100 [1:10:19<11:44, 47.00s/it]

Test CrossEntropyLoss:  0.0008169511063345548
Test ACCURACY:  0.3819950738916256


100%|██████████| 225/225 [00:43<00:00,  5.18it/s]


Train CrossEntropyLoss:  3.8219954060332785e-05
Train ACCURACY:  0.5333768195550673


100%|██████████| 25/25 [00:03<00:00,  7.48it/s]
 86%|████████▌ | 86/100 [1:11:06<10:57, 46.98s/it]

Test CrossEntropyLoss:  0.0008113778488976614
Test ACCURACY:  0.3811945812807882


100%|██████████| 225/225 [00:43<00:00,  5.18it/s]


Train CrossEntropyLoss:  0.000219611500276248
Train ACCURACY:  0.4591352244560727


100%|██████████| 25/25 [00:03<00:00,  7.48it/s]
 87%|████████▋ | 87/100 [1:11:53<10:10, 46.97s/it]

Test CrossEntropyLoss:  0.0006658836391759036
Test ACCURACY:  0.35887931034482756


100%|██████████| 225/225 [00:43<00:00,  5.17it/s]


Train CrossEntropyLoss:  9.517991125600123e-05
Train ACCURACY:  0.50684770591487


100%|██████████| 25/25 [00:03<00:00,  7.44it/s]


Test CrossEntropyLoss:  0.0006975740194320679
Test ACCURACY:  0.38445812807881774


100%|██████████| 225/225 [00:43<00:00,  5.18it/s]]


Train CrossEntropyLoss:  3.6256987992571205e-05
Train ACCURACY:  0.5373218232044199


100%|██████████| 25/25 [00:03<00:00,  7.50it/s]


Test CrossEntropyLoss:  0.0007384355549741848
Test ACCURACY:  0.38726600985221676


100%|██████████| 225/225 [00:43<00:00,  5.17it/s]]


Train CrossEntropyLoss:  2.6146070250249794e-05
Train ACCURACY:  0.5410082530949106


100%|██████████| 25/25 [00:03<00:00,  7.57it/s]
 90%|█████████ | 90/100 [1:14:14<07:49, 47.00s/it]

Test CrossEntropyLoss:  0.0007593320625756175
Test ACCURACY:  0.3860344827586207


100%|██████████| 225/225 [00:43<00:00,  5.18it/s]


Train CrossEntropyLoss:  2.3751516908973077e-05
Train ACCURACY:  0.5425866813428729


100%|██████████| 25/25 [00:03<00:00,  7.47it/s]
 91%|█████████ | 91/100 [1:15:01<07:02, 46.97s/it]

Test CrossEntropyLoss:  0.0007924226968746467
Test ACCURACY:  0.38572660098522166


100%|██████████| 225/225 [00:43<00:00,  5.18it/s]


Train CrossEntropyLoss:  2.5914323375797415e-05
Train ACCURACY:  0.5431681391496411


100%|██████████| 25/25 [00:03<00:00,  7.50it/s]
 92%|█████████▏| 92/100 [1:15:48<06:15, 46.95s/it]

Test CrossEntropyLoss:  0.0007973086804591963
Test ACCURACY:  0.3849384236453202


100%|██████████| 225/225 [00:43<00:00,  5.16it/s]


Train CrossEntropyLoss:  2.9259625322796767e-05
Train ACCURACY:  0.5398389317180616


100%|██████████| 25/25 [00:03<00:00,  7.42it/s]
 93%|█████████▎| 93/100 [1:16:35<05:29, 47.00s/it]

Test CrossEntropyLoss:  0.0008048249230596232
Test ACCURACY:  0.384371921182266


100%|██████████| 225/225 [00:43<00:00,  5.14it/s]


Train CrossEntropyLoss:  3.3441559295009876e-05
Train ACCURACY:  0.5369188445667126


100%|██████████| 25/25 [00:03<00:00,  7.34it/s]
 94%|█████████▍| 94/100 [1:17:22<04:42, 47.10s/it]

Test CrossEntropyLoss:  0.0008141783569833916
Test ACCURACY:  0.3816256157635468


100%|██████████| 225/225 [00:43<00:00,  5.15it/s]


Train CrossEntropyLoss:  3.4263719299743914e-05
Train ACCURACY:  0.537896551724138


100%|██████████| 25/25 [00:03<00:00,  7.44it/s]
 95%|█████████▌| 95/100 [1:18:10<03:55, 47.11s/it]

Test CrossEntropyLoss:  0.0008151894749091763
Test ACCURACY:  0.3816256157635468


100%|██████████| 225/225 [00:43<00:00,  5.16it/s]


Train CrossEntropyLoss:  3.53609583038368e-05
Train ACCURACY:  0.5376379690949228


100%|██████████| 25/25 [00:03<00:00,  7.43it/s]
 96%|█████████▌| 96/100 [1:18:57<03:08, 47.10s/it]

Test CrossEntropyLoss:  0.0008179003293878339
Test ACCURACY:  0.38316502463054186


100%|██████████| 225/225 [00:43<00:00,  5.17it/s]


Train CrossEntropyLoss:  3.5236391625126337e-05
Train ACCURACY:  0.5357260726072607


100%|██████████| 25/25 [00:03<00:00,  7.45it/s]
 97%|█████████▋| 97/100 [1:19:44<02:21, 47.08s/it]

Test CrossEntropyLoss:  0.0008147805371307974
Test ACCURACY:  0.38141625615763547


100%|██████████| 225/225 [00:43<00:00,  5.16it/s]


Train CrossEntropyLoss:  3.3102471551924396e-05
Train ACCURACY:  0.5384849738075544


100%|██████████| 25/25 [00:03<00:00,  7.40it/s]
 98%|█████████▊| 98/100 [1:20:31<01:34, 47.09s/it]

Test CrossEntropyLoss:  0.0008319573948536013
Test ACCURACY:  0.38073891625615763


100%|██████████| 225/225 [00:44<00:00,  5.11it/s]


Train CrossEntropyLoss:  3.321652536638931e-05
Train ACCURACY:  0.5395385465598231


100%|██████████| 25/25 [00:03<00:00,  7.21it/s]
 99%|█████████▉| 99/100 [1:21:18<00:47, 47.25s/it]

Test CrossEntropyLoss:  0.0008211052300307551
Test ACCURACY:  0.3822167487684729


100%|██████████| 225/225 [00:44<00:00,  5.09it/s]


Train CrossEntropyLoss:  3.46608002038499e-05
Train ACCURACY:  0.5391774951617363


100%|██████████| 25/25 [00:03<00:00,  7.47it/s]
100%|██████████| 100/100 [1:22:06<00:00, 49.27s/it]

Test CrossEntropyLoss:  0.0008165678308515126
Test ACCURACY:  0.3797167487684729





今回は単純に正解率を評価指標にしたため(padやbosは考慮していないので低く出る)、Test Accuracyは38%程度でしたが、実際にどのようなものが出力されているのかを見てみましょう。<br>

In [56]:
#parameterの引き継ぎ
model.parameters = best_params

In [89]:
for japanese, english, ja_len, en_len in tqdm(test_dl):
        japanese = japanese.to(device)
        english = english.to(device)
        padding_src_mask, src_mask, padding_tgt_mask, tgt_mask = create_mask(japanese, english, src_pad, tgt_pad, device)
        pred = model(japanese, english, pad_encoder_mask = padding_src_mask, encoder_mask = src_mask, \
                     pad_decoder_mask_self = padding_tgt_mask, decoder_mask_self = tgt_mask)
        pred = pred.argmax(dim=-1)
        pred = pred.cpu().numpy()
        english = english.cpu().numpy()
        for i in range(10):
                print(f"予測{i}", " ".join(en_vocab.lookup_tokens(pred[i][:en_len[i]-2])).capitalize())
                print(f"正解{i}", " ".join(en_vocab.lookup_tokens(english[i][1:en_len[i]-1])).capitalize())
        break

  0%|          | 0/25 [00:00<?, ?it/s]

予測0 Put ' cold outside to <eos> on your coat .
正解0 It is cold outdoors . put on your coat .
予測1 You get to get in touch with your parents . once .
正解1 You had better get in touch with your parents at once .
予測2 He will pay for dollars . most .
正解2 He will pay 20 dollars at most .
予測3 I is half ten minutes .
正解3 It leaves every thirty minutes .
予測4 He is every statesman in every of .
正解4 He is a politician in all senses .
予測5 Who will assure ? success ?
正解5 Who can guarantee his success ?
予測6 Our airport broke down on our way to the airport .
正解6 The car broke down on the way to the airport .
予測7 The are wish to it is s rising to leave early hours .
正解7 We all know that it ' s better to keep early hours .
予測8 I got a pair of shoes .
正解8 I bought a pair of shoes .
予測9 I hair stood on end .
正解9 His hair stood on end .





予測0は2文であったために、ピリオドのあとに<eos>が予想されています。これはテキストの前処理の甘さに起因すると考えられます。<br>
他の文はあってそうであっていない文章ばかりであるため、もう少し学習に工夫を加える必要がありそうです。<br>
予測8が単語が異なっているのに似たような文章になっていることは興味深いと考えられます。<br>これは分散表現が似たような意味の単語は似たような表現になるためだと考えられます。<br>

In [53]:
#次の作業を行うために、メモリを開放しておく
import gc
del model
gc.collect()
torch.cuda.empty_cache()

単純なAttentionモデルより精度が上がったことがわかったと思います。

今までに制作したMultiheadAttentionとTransformerは全てtorch.nnのモジュールとして提供されています。実際にTransformerより上位のモデルを構築する際はそちらを用いたほうが開発効率が良いでしょう。<br>

Pytorchのライブラリ目次<br>
[torch.nn.MultiheadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html)<br>
[torch.nn.Transformer](https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html)<br>