### データの準備

In [None]:
# Extract the dataset archive quietly (-q); the CSVs are read from kyoto_translation/ below
!unzip -q kyoto_translation.zip

In [None]:
import pandas as pd

# Directory holding the train/test CSVs (Japanese and English columns)
data_path = 'kyoto_translation/'

df_train = pd.read_csv(data_path + 'train.csv')
df_test = pd.read_csv(data_path + 'test.csv')

In [None]:
# Last three training pairs
df_train.tail(3)

Unnamed: 0,Japanese,English
36376,入木抄,Juboku sho
36377,マキノ (駅),Makino Station
36378,ウィングス京都,Wings Kyoto


In [None]:
# Last three test pairs
df_test.tail(3)

Unnamed: 0,Japanese,English
15589,月の宴,Tsuki no Utage (party of the moon)
15590,義満,Yoshimitsu
15591,北条友時,Tomotoki HOJO


In [None]:
# (rows, columns) of the train and test frames
df_train.shape, df_test.shape

((36379, 2), (15592, 2))

### 必要なモジュールの読み込み

In [None]:
# Only for Colab
!pip install -q pytorch_lightning
!pip install -q torchtext==0.11.0

[K     |████████████████████████████████| 585 kB 7.4 MB/s 
[K     |████████████████████████████████| 140 kB 54.2 MB/s 
[K     |████████████████████████████████| 418 kB 44.8 MB/s 
[K     |████████████████████████████████| 596 kB 44.6 MB/s 
[K     |████████████████████████████████| 1.1 MB 45.0 MB/s 
[K     |████████████████████████████████| 94 kB 1.5 MB/s 
[K     |████████████████████████████████| 144 kB 20.8 MB/s 
[K     |████████████████████████████████| 271 kB 51.6 MB/s 
[K     |████████████████████████████████| 8.0 MB 16.9 MB/s 
[K     |██████████████████████████████▎ | 834.1 MB 1.2 MB/s eta 0:00:39tcmalloc: large alloc 1147494400 bytes == 0x39070000 @  0x7fe675f7b615 0x592b76 0x4df71e 0x59afff 0x515655 0x549576 0x593fce 0x548ae9 0x51566f 0x549576 0x593fce 0x548ae9 0x5127f1 0x598e3b 0x511f68 0x598e3b 0x511f68 0x598e3b 0x511f68 0x4bc98a 0x532e76 0x594b72 0x515600 0x549576 0x593fce 0x548ae9 0x5127f1 0x549576 0x593fce 0x5118f8 0x593dd7
[K     |████████████████████████████████

In [None]:
# torch, pytorch_lightning, torchtext
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import torchtext

  warn(f"Failed to load image Python extension: {e}")


In [None]:
# Library versions used for this run (torch, pytorch_lightning, torchtext)
print(torch.__version__)
print(pl.__version__)
print(torchtext.__version__)

1.10.0+cu102
1.6.4
0.11.0


In [None]:
# Name of CUDA device 0 (requires a CUDA-capable GPU)
torch.cuda.get_device_name(0)

'Tesla T4'

### 分かち書き（spaCy）

In [None]:
# MeCab
%%capture
!apt install aptitude swig
!aptitude install mecab libmecab-dev mecab-ipadic-utf8 git make curl xz-utils file -y
!pip install mecab-python3==0.996.5
!pip install unidic-lite

# fugashi
!pip install -q fugashi

In [None]:
import spacy

# Blank spaCy pipelines; only their language-specific tokenizers are used below
JA = spacy.blank('ja')
EN = spacy.blank('en')

In [None]:
def _token_texts(nlp, sentence):
    """Return the surface strings of the tokens produced by ``nlp.tokenizer``."""
    tokens = nlp.tokenizer(sentence)
    return [token.text for token in tokens]


def tokenize_ja(sentence):
    """Split a Japanese sentence into token strings with the blank spaCy 'ja' tokenizer."""
    return _token_texts(JA, sentence)


def tokenize_en(sentence):
    """Split an English sentence into token strings with the blank spaCy 'en' tokenizer."""
    return _token_texts(EN, sentence)

In [None]:
# Example: tokenize a Japanese title taken from the test set
tokenize_ja('月の宴')

['月', 'の', '宴']

In [None]:
# Example: tokenize the matching English title
tokenize_en('Tsuki no Utage (party of the moon)')

['Tsuki', 'no', 'Utage', '(', 'party', 'of', 'the', 'moon', ')']

### 辞書の作成

In [None]:
def yield_tokens(df, tokenize):
    """Lazily tokenize every text in ``df`` (any iterable of strings, e.g. a Series).

    Yields one token list per input text, in input order.
    """
    yield from map(tokenize, df)

In [None]:
# Build the vocabularies with build_vocab_from_iterator
from torchtext.vocab import build_vocab_from_iterator

# Special tokens are placed first, in this order, ahead of the corpus tokens
special_tokens = ('<unk>', '<pad>', '<bos>', '<eos>')

# Japanese vocabulary first, then English, both from the training split only
vocab_ja, vocab_en = (
    build_vocab_from_iterator(
        yield_tokens(df_train[column], tokenizer),
        specials=special_tokens,
        special_first=True,
    )
    for column, tokenizer in (('Japanese', tokenize_ja), ('English', tokenize_en))
)

In [None]:
# Full token -> index mappings (very large output)
print(vocab_ja.get_stoi())
print(vocab_en.get_stoi())

{'ＷＧＳ': 22800, '龗神': 22792, '龍雲': 22789, '龍衆': 22787, '龍王': 22782, '龍之助': 22776, '龍三郎': 22775, '齋藤': 22773, '鼻紙': 22770, '鼻塚': 22768, '鼻': 22767, '鼓面': 22766, '鼎立': 22763, '黒釉': 22755, '黒木': 22742, '黒幕': 22740, '黒尾': 22739, '黒住': 22738, '黒い': 22736, '黄鐘': 22734, '黄蜀葵': 22733, '黄海': 22731, '黄泉比良坂': 22729, '黄梅': 22728, '黄桜': 22727, '黄昏': 22726, '麻緒': 22725, '麻紙': 22724, '麻布': 22720, '麦僊': 22717, '麞': 22715, '麒': 22713, '鹿王': 22709, '鹿ケ谷': 22703, '鷺沼': 22700, '鷹飼': 22698, '鷹山': 22697, '鷹匠': 22695, '鷲山': 22694, '鶴太郎': 22688, '鶯菜': 22686, '鶯張り': 22685, '鶏足': 22683, '鶉': 22682, '鵬': 22681, '鵜匠': 22679, '鴻': 22677, '鴫野': 22676, '鴨場': 22673, '鴟尾': 22672, '鴛鴦': 22671, '鴎': 22669, '鳴鶴': 22668, '鳴り': 22663, '鳳闕': 22662, '鳳輦': 22661, '鳳潭': 22660, '鳳来': 22658, '鳩尾': 22656, '鳥子': 22652, '鱧': 22650, '鰻谷': 22648, '鰹節': 22647, '鰒': 22645, '鯛めし': 22643, '鯉口': 22640, '鮭延': 22637, '鮭': 22636, '鮎川': 22634, '魚谷': 22627, '魚服': 22624, '魚介': 22621, '魔道': 22620, '魔王': 22618, '魏志': 22614, '魁': 22613, '鬼面': 22612

In [None]:
# Vocabulary sizes, including the four special tokens
print(len(vocab_ja))
print(len(vocab_en))

22801
30664


In [None]:
# Map out-of-vocabulary tokens to <unk> instead of raising on lookup
vocab_ja.set_default_index(vocab_ja["<unk>"])
vocab_en.set_default_index(vocab_en["<unk>"])

### 文字列のインデックスへの置き換え

In [None]:
def transform_ja(x):
    """Tokenize a Japanese sentence and map its tokens to vocab_ja ids."""
    return vocab_ja(tokenize_ja(x))


def transform_en(x):
    """Tokenize an English sentence, map it to vocab_en ids and wrap it in <bos> ... <eos>.

    The special tokens must be looked up in lowercase, exactly as they were
    registered in build_vocab_from_iterator. An uppercase '<BOS>'/'<EOS>' is
    not in the vocabulary and silently falls back to the '<unk>' default index.
    """
    return [vocab_en['<bos>']] + vocab_en(tokenize_en(x)) + [vocab_en['<eos>']]

In [None]:
from torch.nn.utils.rnn import pad_sequence

def translate_index(df, transform, padding_value=1):
    """Convert texts to a right-padded int64 tensor of token ids.

    Args:
        df: iterable of texts (e.g. a pandas Series column).
        transform: callable mapping one text to a list of int token ids.
        padding_value: id used to pad shorter sequences. The default of 1 is the
            '<pad>' index when the specials ('<unk>', '<pad>', ...) are
            registered first.

    Returns:
        Tensor of shape (num_texts, longest_sequence_length), or an empty
        (0, 0) tensor when ``df`` is empty.
    """
    sequences = [torch.tensor(transform(text), dtype=torch.int64) for text in df]
    if not sequences:
        # pad_sequence cannot infer a shape from an empty list and raises.
        return torch.empty((0, 0), dtype=torch.int64)
    return pad_sequence(sequences, batch_first=True, padding_value=padding_value)

In [None]:
# Convert every sentence into a padded id tensor of shape (num_sentences, max_len).
# The original test split is converted here and later divided into validation/test.
ja_train_tensor = translate_index(df_train['Japanese'], transform_ja)
ja_val_tensor = translate_index(df_test['Japanese'], transform_ja)
en_train_tensor = translate_index(df_train['English'], transform_en)
en_val_tensor = translate_index(df_test['English'], transform_en)

print(ja_train_tensor.shape)
print(ja_val_tensor.shape)
print(en_train_tensor.shape)
print(en_val_tensor.shape)

torch.Size([36379, 31])
torch.Size([15592, 28])
torch.Size([36379, 68])
torch.Size([15592, 73])


### DataLoader の作成

In [None]:
from torch.utils.data import TensorDataset, DataLoader

# Pair source (Japanese) and target (English) tensors row by row
train_dataset = TensorDataset(ja_train_tensor, en_train_tensor)
val_dataset = TensorDataset(ja_val_tensor, en_val_tensor)

In [None]:
# Split the original test set: 60% validation, the remaining 40% held-out test
n_val = int(len(val_dataset) * 0.6)
n_test = len(val_dataset) - n_val

In [None]:
# The split is random, so fix the seed to make it reproducible
pl.seed_everything(0)

# Split the dataset
val, test = torch.utils.data.random_split(val_dataset, [n_val, n_test])

Global seed set to 0


In [None]:
# Training batches are shuffled; drop_last=True discards the final incomplete batch
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
val_loader = DataLoader(val, batch_size=32)
test_loader = DataLoader(test, batch_size=32)

### 各層の挙動を確認

#### Encoder
- Embedding 層：入力文章を 512 次元のベクトルの分散表現に変換
- Positional Encoder 層：単語の位置情報を付加
- Multi-Head Attention 層：Self-attention を計算
- Normalize 層
- Position-wise Feed-Forward Networks（位置単位順伝播ネットワーク、PFFN）：各単語（位置）ごとに同じ順伝播ネットワークを独立に適用
- Normalize 層

#### Decoder
- Embedding 層：入力文章を 512 次元のベクトルの分散表現に変換
- Positional Encoder 層：単語の位置情報を付加
- Multi-Head Attention 層：Self-attention を計算
- Normalize 層
- Multi-Head Attention 層：Source-Target Attention を計算
- Normalize 層
- PFFN（Position-wise Feed-Forward Networks）：各単語（位置）ごとに順伝播を適用
- Normalize 層

In [None]:
# Take ONE batch (batch_size=32 sentence pairs) from the training loader and display it
src, trg = next(iter(train_loader))
src, trg

(tensor([[ 1885,    19,  4722,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1],
         [ 2089,  7487,     5,  4551,     6,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1],
         [ 4631,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1],
         [ 4732, 15716,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1, 

### Encoder の Embedding 層

In [None]:
# Encoder embedding: token ids (batch, src_len) -> vectors (batch, src_len, d_model)
src_vocab_length = len(vocab_ja)
d_model = 512

src_embedder = nn.Embedding(src_vocab_length, d_model)
embeded = src_embedder(src)
embeded.shape

torch.Size([32, 31, 512])

### Positional Encoder 層

In [None]:
import math

In [None]:
class PositionalEncoder(pl.LightningModule):
    """Scale embeddings by sqrt(d_model), add sinusoidal positional encodings, apply dropout.

    Uses the encoding from "Attention Is All You Need":
        PE(pos, 2k)     = sin(pos / 10000^(2k / d_model))
        PE(pos, 2k + 1) = cos(pos / 10000^(2k / d_model))

    Args:
        d_model: embedding dimension of the inputs.
        max_seq_len: longest sequence the precomputed table supports.
        batch_first: True (default) for inputs shaped (batch, seq, d_model), as
            produced by the batch_first=True layers in this notebook; False for
            (seq, batch, d_model).
        dropout: dropout probability applied after adding the encoding.
    """

    def __init__(self, d_model=512, max_seq_len=5000, batch_first=True, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.d_model = d_model
        self.batch_first = batch_first

        # Column vector of positions 0 .. max_seq_len-1, shape (max_seq_len, 1)
        position = torch.arange(max_seq_len, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2k / d_model) for every even dimension index 2k
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model))

        pe = torch.zeros(max_seq_len, d_model)
        # Even dimensions get sin; each odd dimension gets cos at the SAME frequency
        # as its even partner. The slice keeps this valid for odd d_model.
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Keep the (max_seq_len, 1, d_model) buffer layout; saved with the model
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        # Scale the embeddings up so the positional signal does not dominate them
        x = x * math.sqrt(self.d_model)

        if self.batch_first:
            # (seq, 1, d_model) -> (1, seq, d_model): varies along the sequence axis,
            # broadcast across the batch
            x = x + self.pe[: x.size(1)].transpose(0, 1)
        else:
            x = x + self.pe[: x.size(0)]
        return self.dropout(x)

In [None]:
# Scale the encoder embeddings, add the positional-encoding table and apply dropout
pos_encoder = PositionalEncoder(d_model)

pos_embeded = pos_encoder(embeded)
pos_embeded.shape

torch.Size([32, 31, 512])

In [None]:
# First three positions of the positional-encoding buffer, shape (3, 1, d_model)
pos_encoder.pe[:3]

tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
           0.0000e+00,  1.0000e+00]],

        [[ 8.4147e-01,  5.6969e-01,  8.0196e-01,  ...,  1.0000e+00,
           1.0746e-08,  1.0000e+00]],

        [[ 9.0930e-01, -3.5090e-01,  9.5814e-01,  ...,  1.0000e+00,
           2.1492e-08,  1.0000e+00]]])

### Transformer の Encoder 層

In [None]:
# Padding id of the Japanese vocabulary
src_pad_idx = vocab_ja['<pad>']
src_pad_idx

1

In [None]:
def create_src_pad_mask(src, pad_idx=None):
    """Return a boolean mask that is True wherever ``src`` holds the padding id.

    Args:
        src: integer tensor of token ids.
        pad_idx: padding id to compare against. Defaults to the module-level
            ``src_pad_idx`` so existing one-argument calls keep working.

    Returns:
        Bool tensor with the same shape as ``src``.
    """
    if pad_idx is None:
        pad_idx = src_pad_idx
    return src == pad_idx

In [None]:
# Padding mask for the sample batch: True at padded source positions
src_mask = create_src_pad_mask(src)
src_mask.shape

torch.Size([32, 31])

In [None]:
# EncoderLayer: self-attention + feed-forward; batch_first=True means inputs are (batch, seq, feature)
encoder_layer = nn.TransformerEncoderLayer(
    d_model, nhead=8,
    dim_feedforward=2048,
    dropout=0.1,
    activation='relu',
    batch_first=True
)
encoder_layer

TransformerEncoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
  )
  (linear1): Linear(in_features=512, out_features=2048, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=2048, out_features=512, bias=True)
  (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.1, inplace=False)
)

In [None]:
# LayerNorm applied to the output of the whole encoder stack
encoder_norm =  nn.LayerNorm(d_model)

In [None]:
# Stack 6 copies of encoder_layer, followed by encoder_norm
encoder = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=6, norm=encoder_norm)
encoder

TransformerEncoder(
  (layers): ModuleList(
    (0): TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
      )
      (linear1): Linear(in_features=512, out_features=2048, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=2048, out_features=512, bias=True)
      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
      )
      (linear1): Linear(in_features=512, out_features=2048, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=2048, out_featu

In [None]:
# Encode the sample batch; padded positions are excluded from self-attention
enc_out = encoder(pos_embeded, src_key_padding_mask=src_mask)
enc_out.shape

torch.Size([32, 31, 512])

In [None]:
# Encoder output for the first sentence of the batch
enc_out[0]

tensor([[ 2.0616, -0.3600,  2.7540,  ..., -0.0557,  1.1046,  0.9178],
        [ 0.5464, -0.1797,  1.9662,  ..., -0.5942,  1.4223, -0.2907],
        [ 1.6748,  0.4706,  1.7604,  ..., -0.7122,  1.0253, -0.5580],
        ...,
        [ 0.6022,  0.0037,  1.7303,  ...,  0.7644,  0.2974,  0.3956],
        [ 1.0232, -0.0172,  1.1006,  ...,  0.3857,  0.5727,  0.2795],
        [ 0.8282, -0.3790,  1.2736,  ...,  0.3487,  0.9906,  0.6590]],
       grad_fn=<SelectBackward0>)

### Decoder の各層を定義

### Decoder の Embedding 層

In [None]:
# Decoder input drops the last token, so position t is trained to predict token t+1 (teacher forcing)
trg_input = trg[:, :-1]
print(f'before:{trg.shape}, after:{trg_input.shape}')

before:torch.Size([32, 68]), after:torch.Size([32, 67])


In [None]:
# Decoder embedding for the English input: (batch, trg_len) -> (batch, trg_len, d_model)
trg_vocab_length = len(vocab_en)

trg_embedder = nn.Embedding(trg_vocab_length, d_model)

embeded = trg_embedder(trg_input)
embeded.shape

torch.Size([32, 67, 512])

### Positional Encoder 層

In [None]:
# Positional encoding for the decoder input (note: overwrites the encoder-side pos_embeded)
pos_encoder = PositionalEncoder(d_model)

pos_embeded = pos_encoder(embeded)
pos_embeded.shape

torch.Size([32, 67, 512])

### Transformer の Decoder 層

In [None]:
# tgt_key_padding_mask: padding id of the English vocabulary
trg_pad_idx = vocab_en['<pad>']

def create_trg_pad_mask(trg, pad_idx=None):
    """Return a boolean mask that is True wherever ``trg`` holds the padding id.

    Args:
        trg: integer tensor of token ids.
        pad_idx: padding id to compare against. Defaults to the module-level
            ``trg_pad_idx`` so existing one-argument calls keep working.

    Returns:
        Bool tensor with the same shape as ``trg``.
    """
    if pad_idx is None:
        pad_idx = trg_pad_idx
    return trg == pad_idx

In [None]:
# True at padded positions of the decoder input
trg_pad_mask = create_trg_pad_mask(trg_input)
trg_pad_mask.shape

torch.Size([32, 67])

In [None]:
# tgt_mask
def generate_square_subsequent_mask(size):
    """Build a (size, size) float causal attention mask.

    Entry [i, j] is 0.0 when j <= i (position i may attend to j) and -inf when
    j > i, which blocks attention to future positions.
    """
    # Start from all -inf and keep only the strictly upper triangle; triu zeroes the rest.
    blocked = torch.full((size, size), float('-inf'))
    return torch.triu(blocked, diagonal=1)

In [None]:
# Causal mask sized to the decoder input length
trg_mask = generate_square_subsequent_mask(trg_input.size(1))
trg_mask

tensor([[0., -inf, -inf,  ..., -inf, -inf, -inf],
        [0., 0., -inf,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        ...,
        [0., 0., 0.,  ..., 0., -inf, -inf],
        [0., 0., 0.,  ..., 0., 0., -inf],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [None]:
# DecoderLayer: masked self-attention + source-target attention + feed-forward (batch_first inputs)
decoder_layer = nn.TransformerDecoderLayer(
    d_model,
    nhead=8,
    dim_feedforward=2048,
    dropout=0.1,
    activation='relu',
    batch_first=True
    )

decoder_layer

TransformerDecoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
  )
  (multihead_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
  )
  (linear1): Linear(in_features=512, out_features=2048, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=2048, out_features=512, bias=True)
  (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (norm3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.1, inplace=False)
  (dropout3): Dropout(p=0.1, inplace=False)
)

In [None]:
# Stack 6 copies of decoder_layer, followed by a final LayerNorm
decoder_norm = nn.LayerNorm(d_model)
decoder = nn.TransformerDecoder(decoder_layer=decoder_layer, num_layers=6, norm=decoder_norm)

In [None]:
# Decode with the causal mask and target padding mask. memory_key_padding_mask=src_mask
# also stops source-target attention from attending to padded encoder positions.
dec_out = decoder(
    pos_embeded, enc_out,
    tgt_mask=trg_mask,
    tgt_key_padding_mask=trg_pad_mask,
    memory_key_padding_mask=src_mask,
)
dec_out.shape

torch.Size([32, 67, 512])

In [None]:
# Decoder output for the first sentence of the batch
dec_out[0]

tensor([[-1.8381, -1.4934, -0.6170,  ..., -1.0257,  0.3365, -0.4710],
        [-1.8192, -1.3145, -0.2520,  ..., -0.6463,  0.6646, -0.1213],
        [-1.8712, -1.3270, -0.7544,  ..., -0.3518,  0.1748, -0.2686],
        ...,
        [-0.8560, -1.1334, -0.6101,  ..., -0.1753,  0.0825, -0.2061],
        [-1.5711, -1.3280, -0.6516,  ..., -0.2121,  0.0088, -0.7075],
        [-1.1023, -1.1807, -0.4691,  ..., -0.0804,  0.2215, -0.1511]],
       grad_fn=<SelectBackward0>)

### Decoder の出力層

In [None]:
# Output projection: d_model -> logits over the English vocabulary
out = nn.Linear(d_model, trg_vocab_length)

logit = out(dec_out)
logit.shape

torch.Size([32, 67, 30664])

In [None]:
# Probability distribution over the target vocabulary at every position
y_softmax = F.softmax(logit, dim=-1)

# Most likely token id at each position of the first sentence in the batch
pred = y_softmax.max(dim=-1).indices[0]
pred

tensor([24164, 23082, 12876, 11824,  5054, 11824,  5054, 11824, 14663,  5054,
        14663, 16479, 12911, 27659, 11969, 12911, 14663,  5054,  5054,  5054,
        14663, 14663, 14663, 14663, 16479, 27659, 14663, 15050, 27659, 14663,
        14663,  5054, 14663, 14663, 14663, 14663, 14663, 12911, 14663, 20447,
        27659, 11824,  5054, 14663,  5054, 11967, 14663, 12911,  5054,  5054,
        14663, 11969, 24164, 12911, 27659, 14663, 14663, 12876, 16479, 24164,
         2653, 14663, 14663, 15050, 15050,  5054, 11969])

In [None]:
# Map the first predicted id back to its token (nothing has been trained yet, so this is arbitrary)
print(vocab_en.lookup_token(pred[0]))

Yojiuemon


### 損失の計算

In [None]:
# The leading <bos> token is not a target: targets are the input shifted left by one
targets = trg[:, 1:].reshape(-1)
y = logit.view(-1, logit.size(-1))

# ignore_index excludes <pad> positions from the loss
loss = F.cross_entropy(y, targets, ignore_index=trg_pad_idx)

In [None]:
# Initial loss of the untrained layers
loss

tensor(10.5138, grad_fn=<NllLossBackward0>)

### Transformer のネットワークを定義

### Positional Encoder

In [None]:
class PositionalEncoder(pl.LightningModule):
    """Scale embeddings by sqrt(d_model), add sinusoidal positional encodings, apply dropout.

    PE(pos, 2k) = sin(pos / 10000^(2k / d_model)), PE(pos, 2k + 1) = cos(same angle).

    Args:
        d_model: embedding dimension of the inputs.
        max_seq_len: longest sequence the precomputed table supports.
        batch_first: True (default) for (batch, seq, d_model) inputs, matching the
            batch_first=True Transformer layers; False for (seq, batch, d_model).
        dropout: dropout probability applied after adding the encoding.
    """

    def __init__(self, d_model=512, max_seq_len=5000, batch_first=True, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.d_model = d_model
        self.batch_first = batch_first

        position = torch.arange(max_seq_len, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2k / d_model) for every even dimension index 2k
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model))

        pe = torch.zeros(max_seq_len, d_model)
        # sin/cos pairs share one frequency; the slice keeps odd d_model valid
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Buffer layout (max_seq_len, 1, d_model) is unchanged
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        x = x * math.sqrt(self.d_model)
        if self.batch_first:
            # Index by sequence length (dim 1) and broadcast over the batch
            x = x + self.pe[: x.size(1)].transpose(0, 1)
        else:
            x = x + self.pe[: x.size(0)]
        return self.dropout(x)

### Transformer Encoder

In [None]:
# Use the GPU when one is available, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [None]:
class Encoder(pl.LightningModule):
    """Source-side stack: token embedding -> positional encoding -> TransformerEncoder.

    Args:
        src_vocab_length: size of the source vocabulary.
        d_model, nhead, dim_feedforward, dropout, activation: forwarded to
            nn.TransformerEncoderLayer, which is built with batch_first=True.
        num_encoder_layers: number of stacked encoder layers.
    """

    def __init__(self, src_vocab_length, d_model, nhead, dim_feedforward, num_encoder_layers, dropout, activation):
        super().__init__()

        self.src_embedding = nn.Embedding(src_vocab_length, d_model)
        self.pos_encoder = PositionalEncoder(d_model)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                activation=activation,
                batch_first=True,
            ),
            num_layers=num_encoder_layers,
            norm=nn.LayerNorm(d_model),
        )

    def forward(self, src, src_pad_mask):
        """Encode token ids ``src``; ``src_pad_mask`` marks padded positions to ignore."""
        embedded_with_positions = self.pos_encoder(self.src_embedding(src))
        return self.encoder(embedded_with_positions, src_key_padding_mask=src_pad_mask)

### Transformer Decoder

In [None]:
class Decoder(pl.LightningModule):

    def __init__(self, trg_vocab_length, d_model, nhead, dim_feedforward, num_decoder_layers, dropout, activation):
        super().__init__()

        self.trg_embedding = nn.Embedding(trg_vocab_length, d_model)
        self.pos_encoder = PositionalEncoder(d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation, batch_first=True)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

    def forward(self, memory, trg_input, trg_mask, trg_pad_mask):
        
        trg_embeded = self.trg_embedding(trg_input)
        pos_trg = self.pos_encoder(trg_embeded)
        output = self.decoder(pos_trg, memory, tgt_mask=trg_mask, tgt_key_padding_mask=trg_pad_mask)

        return output

### Transformer

In [None]:
def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
        return optimizer

In [None]:
def reset_parameters(self):
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

In [None]:
class Transformer(pl.LightningModule):

    def __init__(self, src_vocab_length=10000, trg_vocab_length=10000, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6,
                  dim_feedforward=2048, dropout=0.1, activation="relu", src_pad_idx=1, trg_pad_idx=1):
        super().__init__()

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx

        self.encoder = Encoder(src_vocab_length, d_model, nhead, dim_feedforward, num_encoder_layers, dropout, activation)
        self.decoder = Decoder(trg_vocab_length, d_model, nhead, dim_feedforward, num_decoder_layers, dropout, activation)

        self.out = nn.Linear(d_model, trg_vocab_length)
        
        # Xavier の初期値を使う場
        #   self.reset_parameters()
        # def reset_parameters(self):
        #     for param in self.parameters():
        #         if param.dim() > 1:
        #             nn.init.xavier_uniform_(param)


    def create_pad_mask(self, input_word, pad_idx):
        pad_mask = input_word == pad_idx
        return pad_mask


    def generate_square_subsequent_mask(self, size):
        mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask==0, float('-inf')).masked_fill(mask==1, float(0.0)).to(device)
        return mask

        
    def forward(self, src, trg):
        trg_input = trg[:, :-1]

        # 各種 Mask
        src_pad_mask = self.create_pad_mask(src, self.src_pad_idx)
        trg_pad_mask = self.create_pad_mask(trg_input, self.trg_pad_idx)
        trg_mask = self.generate_square_subsequent_mask(trg_input.size(1))

        memory = self.encoder(src, src_pad_mask)
        output = self.decoder(memory, trg_input, trg_mask, trg_pad_mask)
        
        logit = self.out(output)
        return logit


    def training_step(self, batch, batch_idx):
        src, trg = batch

        logit = self(src, trg)
        
        targets = trg[:, 1:].reshape(-1)
        y = logit.view(-1, logit.size(-1))

        # ignore_index : 損失計算で <pad> のクラスを省く
        loss = F.cross_entropy(y, targets, ignore_index=self.trg_pad_idx)
        self.log('train_loss', loss, on_step=False, on_epoch=True)

        return loss


    def validation_step(self, batch, batch_idx):
        src, trg = batch

        logit = self(src, trg)
        
        targets = trg[:, 1:].reshape(-1)
        y = logit.view(-1, logit.size(-1))

        loss = F.cross_entropy(y, targets, ignore_index=self.trg_pad_idx)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        
        return loss


    def test_step(self, batch, batch_idx):
        src, trg = batch

        logit = self(src, trg)
        
        targets = trg[:, 1:].reshape(-1)
        y = logit.view(-1, logit.size(-1))

        loss = F.cross_entropy(y, targets, ignore_index=self.trg_pad_idx)
        self.log('test_loss', loss, on_step=False, on_epoch=True)
        return loss


    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
        return optimizer

### 学習の実行

In [None]:
# Fix the random seed (weight initialisation, shuffling, dropout)
pl.seed_everything(0)

src_vocab_length = len(vocab_ja)
trg_vocab_length = len(vocab_en)
src_pad_idx = vocab_ja['<pad>']
trg_pad_idx = vocab_en['<pad>']

# Instantiate the model
net = Transformer(
    src_vocab_length=src_vocab_length,
    trg_vocab_length=trg_vocab_length,
    src_pad_idx=src_pad_idx,
    trg_pad_idx=trg_pad_idx
    )

# Train on the GPU when present and fall back to the CPU instead of failing.
# accelerator/devices replace the `gpus` flag, which later Lightning releases removed.
trainer = pl.Trainer(
    max_epochs=5,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
)
trainer.fit(net, train_loader, val_loader)

Global seed set to 0
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /content/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type    | Params
------------------------------------
0 | encoder | Encoder | 30.6 M
1 | decoder | Decoder | 40.9 M
2 | out     | Linear  | 15.7 M
------------------------------------
87.2 M    Trainable params
0         Non-trainable params
87.2 M    Total params
348.981   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

                not been set for this class (_ResultMetric). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend setting this to `False`.
                We provide an checking function
                `from torchmetrics.utilities import check_forward_no_full_state`
                that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
                default for now) or if `full_state_update=False` can be used safely.
                


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [None]:
# Evaluate on the held-out test split (40% of the original test set)
result = trainer.test(dataloaders=test_loader)