<a href="https://colab.research.google.com/github/ShinAsakawa/ShinAsakawa.github.io/blob/master/2023notebooks/2023_1113chihaya_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 百人一首の上の句をエンコーダによって符号化し，下の句をデコーダで生成する自作 Transformer モデル

* date: 2023_0225
* author: 浅川伸一
* bibliography: [Attention Is All You Need](https://arxiv.org/abs/1706.03762)


# 準備 必要なライブラリの輸入と諸元の表示

In [None]:
%config InlineBackend.figure_format = 'retina'
import torch
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

from IPython import get_ipython
isColab =  'google.colab' in str(get_ipython())

if isColab:

    # GPU 情報を表示
    !nvidia-smi -L

    # `import bit` する前に termcolor を downgrade しないと colab ではテキストに色がつかない
    !pip install --upgrade termcolor==1.1
    import termcolor

    !pip install jaconv
    !git clone https://github.com/ShinAsakawa/RAM.git

import platform
HOSTNAME = platform.node().split('.')[0]

import os
HOME = os.environ['HOME']

try:
    import ipynbname
except ImportError:
    !pip install ipynbname
    import ipynbname
FILEPATH = str(ipynbname.path()).replace(HOME+'/','')

import pwd
USER=pwd.getpwuid(os.geteuid())[0]

from datetime import date
TODAY=date.today()

import torch
TORCH_VERSION = torch.__version__

from termcolor import colored

try:
    import japanize_matplotlib
except ImportError:
    !pip install japanize_matplotlib
    import japanize_matplotlib

color = 'green'
print('日付:',colored(f'{TODAY}', color=color, attrs=['bold']))
print('HOSTNAME:',colored(f'{HOSTNAME}', color=color, attrs=['bold']))
print('ユーザ名:',colored(f'{USER}', color=color, attrs=['bold']))
print('HOME:',colored(f'{HOME}', color=color,attrs=['bold']))
print('ファイル名:',colored(f'{FILEPATH}', color=color, attrs=['bold']))
print('torch.__version__:',colored(f'{TORCH_VERSION}', color=color, attrs=['bold']))

# 百人一首データのダウンロード

In [None]:
from collections import OrderedDict
import os
import sys
import numpy as np
import json
chihaya_fname = 'chihaya.json'

if os.path.exists(chihaya_fname):
    # Reuse the cached copy in the current directory.
    # The file is written with ensure_ascii=False (raw Japanese text), so read it as UTF-8
    # explicitly rather than relying on the platform default encoding.
    with open(chihaya_fname, 'r', encoding='utf-8') as fp:
        chihaya = OrderedDict(json.load(fp))
else:
    # No cache yet: scrape the list page
    import requests
    from bs4 import BeautifulSoup

    url = 'http://www.diana.dti.ne.jp/~fujikura/List/List.html'
    page = requests.get(url, timeout=30)
    page.raise_for_status()  # fail loudly instead of silently parsing an error page

    soup = BeautifulSoup(page.content, 'html.parser')
    body = list(soup.children)[0]

    # The whitespace-split page text repeats in records of 7 tokens per poem:
    #   field 0 -> poem number, field 1 -> skipped (NOTE(review): presumably the poet; confirm),
    #   fields 2-5 -> kept as the poem's 4 verse strings, field 6 -> skipped and closes the record.
    # The first and last tokens are unrelated to the poems, hence [1:-1].
    chihaya = OrderedDict()
    poem_no, verses = None, []
    for pos, p in enumerate(body.getText().split()[1:-1]):
        field = pos % 7
        if field == 0:
            poem_no = int(p)
        elif 2 <= field <= 5:
            verses.append(p)
        elif field == 6:
            chihaya[poem_no] = verses
            verses = []

    # Cache the parsed poems for later runs (we only get here when no cache exists)
    with open(chihaya_fname, 'w', encoding='utf-8') as fp:
        json.dump(chihaya, fp, ensure_ascii=False, indent=4)

# Sanity checks: the Hyakunin Isshu has 100 poems, and later cells index v[0]..v[3]
assert len(chihaya) == 100, f'expected 100 poems, got {len(chihaya)}'
assert all(len(v) >= 4 for v in chihaya.values()), 'every poem needs 4 verse strings'

# Count every character of the hiragana verses.
# Each poem value holds: v[0] kanji upper verse, v[1] kanji lower verse,
# v[2] hiragana upper verse, v[3] hiragana lower verse.
chihaya_chrs = OrderedDict()
for verses in chihaya.values():
    for ch in verses[2] + verses[3]:
        chihaya_chrs[ch] = chihaya_chrs.get(ch, 0) + 1

# Special tokens come first, so <PAD> gets ID 0 (the ignore_index of the loss below),
# followed by the characters in sorted order.
chihaya_tokens = ['<PAD>', '<SOS>', '<EOS>', '<UNK>'] + sorted(chihaya_chrs)

# Keys loaded back from JSON are strings; normalise every poem number to int
chihaya = OrderedDict((int(k), v) for k, v in chihaya.items())

idx2tkn = dict(enumerate(chihaya_tokens))  # token ID -> character

# Function (not a dict) turning a sequence of characters into a list of token IDs
def tkn2idx(tkn:list, tokens=chihaya_tokens):
    """Map each element of `tkn` to its position in `tokens`.

    Args:
        tkn: sequence of characters; callers also pass verse strings directly.
        tokens: vocabulary list; defaults to `chihaya_tokens`.

    Returns:
        list[int]: one ID per element of `tkn`. Elements missing from `tokens`
        are mapped to the index of '<UNK>'.
    """
    return [tokens.index(t) if t in tokens else tokens.index('<UNK>') for t in tkn]

print(f'idx2tkn:{idx2tkn}')
# List every token next to its ID as "(token,[id])", space-separated on a single line
for token in chihaya_tokens:
    print(f'({token},{tkn2idx([token])})', end=" ")

# 自作 Transformer の輸入

In [3]:
from RAM import Transformer

# Every verse is fed as <SOS> + characters + <EOS>, so the model must accept the
# longest such sequence. NOTE(review): assumes max_seq_length bounds the positional
# encoding inside RAM.Transformer — confirm against its implementation.
MAX_SEQ_LENGTH = 22
_longest = max(len(ku) for v in chihaya.values() for ku in (v[2], v[3])) + 2
assert _longest <= MAX_SEQ_LENGTH, \
    f'max_seq_length={MAX_SEQ_LENGTH} is shorter than the longest sequence ({_longest})'

# Seed before building the model so its random initial weights are reproducible;
# the seeding cell further down only runs after the weights have already been drawn.
torch.manual_seed(42)
model = Transformer(src_vocab_size=len(idx2tkn),
                    tgt_vocab_size=len(idx2tkn),
                    model_dim=32,
                    num_heads=4,
                    num_layers=1,
                    max_seq_length=MAX_SEQ_LENGTH,
                    dropout=0.,
                    ff_dim=32).to(device)
model.eval();

# 乱数の種の設定

In [4]:
# Seed every RNG used in training (Python, NumPy, PyTorch) so runs are repeatable
import random

SEED = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)

# 訓練

In [None]:
# Special token IDs (<PAD> is ID 0 by construction of chihaya_tokens)
PAD_ID = chihaya_tokens.index('<PAD>')
SOS_ID = chihaya_tokens.index('<SOS>')
EOS_ID = chihaya_tokens.index('<EOS>')

# Cross-entropy loss; positions holding <PAD> are ignored
criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_ID)
# Optimiser: [Adam](https://arxiv.org/abs/1412.6980)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)


def _encode(ku):
    """Return verse `ku` as a (1, len(ku) + 2) LongTensor on `device`: <SOS> + token IDs + <EOS>."""
    return torch.tensor([[SOS_ID] + tkn2idx(ku) + [EOS_ID]], dtype=torch.long, device=device)


@torch.no_grad()
def _greedy_decode(src, max_len):
    """Generate a lower verse for the encoded upper verse `src`, one token at a time.

    Decoding starts from <SOS>; at each step the most probable next token is appended,
    stopping at <EOS> or after `max_len` tokens.

    Returns:
        list[int]: generated token IDs, without <SOS> and <EOS>.
    """
    tgt = torch.tensor([[SOS_ID]], dtype=torch.long, device=device)
    generated = []
    for _ in range(max_len):
        logits = model(src, tgt)                # (1, tgt_len, vocab), as used with the loss below
        next_id = int(logits[0, -1].argmax())
        if next_id == EOS_ID:
            break
        generated.append(next_id)
        next_tok = torch.tensor([[next_id]], dtype=torch.long, device=device)
        tgt = torch.cat([tgt, next_tok], dim=1)
    return generated


# Decoding budget: the longest lower verse plus its <EOS>
max_decode_len = max(len(v[3]) for v in chihaya.values()) + 1

epochs = 6
for epoch in range(epochs):

    model.train()  # training mode
    epoch_loss = 0.0
    for i in np.random.permutation(len(chihaya)):
        kami, simo = chihaya[i + 1][2], chihaya[i + 1][3]
        src = _encode(kami)
        tgt = _encode(simo)

        optimizer.zero_grad()
        # Teacher forcing with a one-step shift: the decoder reads <SOS> y1 ... yn and must
        # predict y1 ... yn <EOS>. Using the unshifted sequence as both decoder input and
        # target only teaches the (causally masked) decoder to copy its own input.
        # NOTE(review): assumes RAM.Transformer does not shift the target internally.
        out = model(src, tgt[:, :-1])
        loss = criterion(out[0], tgt[0, 1:])    # loss value
        loss.backward()                         # backpropagation
        optimizer.step()                        # parameter update
        epoch_loss += loss.item()               # running loss sum

    model.eval()  # evaluation mode
    n_corrects = 0
    for i in range(len(chihaya)):
        kami, simo = chihaya[i + 1][2], chihaya[i + 1][3]
        # Generate from the upper verse alone (no ground-truth lower verse is fed in)
        out_str = "".join(idx2tkn[idx] for idx in _greedy_decode(_encode(kami), max_decode_len))
        if out_str == simo:
            n_corrects += 1
            continue

        # Wrong answer: show the generated verse, blue where it matches the reference, red otherwise
        print(f'{i+1:4d} ', end="")
        for pos, c0 in enumerate(out_str[:len(simo)]):
            chr_color = 'blue' if c0 == simo[pos] else 'red'
            print(colored(c0, color=chr_color, attrs=['bold']), end="")
        print(f', 正解(下句):{simo}',
              f', 入力(上句):{kami}')

    # Per-epoch summary
    print(f'エポック:{epoch+1}',
          f'損失:{epoch_loss/len(chihaya):.5f}',
          f'正解率: {((n_corrects / len(chihaya)))*100:7.3f}%')

In [None]:
import operator
import matplotlib.pyplot as plt
import japanize_matplotlib
from collections import Counter

# Character counts over the hiragana upper (v[2]) and lower (v[3]) verses of all poems
count = Counter()
for v in chihaya.values():
    count.update(v[2] + v[3])

# most_common() sorts by count, descending, keeping first-seen order among ties
count_sorted = count.most_common()
total = sum(count.values())

fig, ax = plt.subplots(figsize=(14, 4))
ax.bar(range(len(count_sorted)), [n / total for _, n in count_sorted])
ax.set_xticks(range(len(count_sorted)))
ax.set_xticklabels([ch for ch, _ in count_sorted])
ax.set_title('百人一首の文字頻度')
ax.set_xlabel('文字（ひらがな）')
ax.set_ylabel('相対頻度')
plt.show()