<a href="https://colab.research.google.com/github/ShinAsakawa/ShinAsakawa.github.io/blob/master/2023notebooks/2023_0113lam_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# This cell is just for fun and can safely be skipped.
import IPython
#IPython.display.Image(url="https://livedoor.blogimg.jp/ftb001/imgs/b/4/b4629a79.jpg")
# The Image object is the last expression of the cell, so the notebook renders it as output.
IPython.display.Image(url="https://uy-allstars.com/_assets/images/pages/char/detail/webp/lum@pc.webp")

# 1 準備作業



## 1.1 ライブラリのインポート

1.   直下のセルは mecab をコンパイルするため，実行に時間がかかる点に注意



In [None]:
# Render inline figures at retina (high-DPI) resolution.
%config InlineBackend.figure_format = 'retina'
# Make the author's `bit` helper package importable: if the import fails,
# install `ipynbname` and clone the `bit` repository into the working directory.
try:
    import bit
except ImportError:
    !pip install ipynbname --upgrade > /dev/null 2>&1 
    !git clone https://github.com/ShinAsakawa/bit.git
import bit

# Environment information exposed by `bit`.
isColab = bit.isColab  # presumably True when running on Google Colab -- TODO confirm in `bit`
HOME = bit.HOME

if isColab:
    # On Colab: install MeCab and its dictionary, then build the two taggers used below.
    !apt install aptitude
    !aptitude install mecab libmecab-dev mecab-ipadic-utf8 git make curl xz-utils file -y
    !pip install mecab-python3==0.7
    !pip install jaconv
    
    import MeCab
    wakati = MeCab.Tagger('-Owakati').parse  # parser created with the '-Owakati' output option
    yomi = MeCab.Tagger('-Oyomi').parse      # parser created with the '-Oyomi' output option
else:
    # Elsewhere: reuse the preconfigured parsers shipped with the `ccap` package
    # (requires `ccap` to be importable already).
    from ccap.mecab_settings import yomi
    from ccap.mecab_settings import wakati

# Fetch the author's own libraries (ccap and lam); they are only cloned on Colab.
if isColab:
    !git clone https://github.com/ShinAsakawa/ccap.git
    !git clone https://github.com/ShinAsakawa/lam.git

## 1.2 Colab 上で実行する場合，必要なファイルのアップロード

In [None]:
# # upload `pslex71utf8.txt`, NTT 日本語の語彙特性 頻度データ
# # upload `lam/2022_0508SALA_TLPA.json` SALA と TLPA のデータが必要になるかもしれない。
# if isColab:
#     from google.colab import files
#     uploaded = files.upload()

## 1.3 パラメータ設定

In [None]:
# Reload edited modules automatically before each cell execution.
%reload_ext autoreload
%autoreload 2

import torch
import lam
device = lam.device  # device selected by lam (CPU or GPU)

from termcolor import colored

# Parameters required for the simulation.
# NOTE(review): 'random_seed' is not read anywhere in the visible notebook
# (the dataset split further down hard-codes 42) -- confirm whether seeding is intended.
params = {
    'traindata_size':  20000,   # number of training items, taken from the top of the NTT Japanese lexical-properties frequency list
    #'traindata_size': 301612,  # number of training items, taken from the top of the NTT Japanese lexical-properties frequency list
    'epochs': 20,               # number of training epochs
    'hidden_size': 128,         # number of hidden-layer units
    'random_seed': 42,          # random seed (a nod to Douglas Adams' "The Hitchhiker's Guide to the Galaxy")

    # Changing `source` and `target` below runs a different task
    'source': 'orth',          # ['orth', 'phon', 'mora', 'mora_p', 'mora_p_r']
    'target': 'phon',         # ['orth', 'phon', 'mora', 'mora_p', 'mora_p_r']
    #'target': 'mora_p_r',      # ['orth', 'phon', 'mora', 'mora_p', 'mora_p_r']
    # 'orthography': graphemes, 
    # 'phonology': phonology, 
    # 'mora': morae
    # 'mora_p': morae decomposed into phones by "silius"
    # 'mora_p_r': the reverse of the "silius" phone decomposition of morae
    'pretrained': False,          # if True, load a previously trained model file
    #'pretrained': True,          # if True, load a previously trained model file
    #'isTrain'   : True,          # if True, train the model
    
    # File name under which the trained model parameters are saved
    #'path_saved': '2022_0607lam_o2p_hid32_vocab10k.pt', 
    #'path_saved': '2022_0829lam_p2p_hid24_vocab10k.pt',
    'path_saved': False,                      # False: do not save
    
    # File name under which the scatter plot of the results is saved
    #'path_graph': '2022_0829lam_p2p_hid24_vocab10k.pdf',
    'path_graph': False,                     # False: do not save

    'lr': 0.0001,                              # learning rate
    'dropout_p': 0.0,                         # dropout probability
    'teacher_forcing_ratio': 0.5,             # probability of applying teacher forcing
    'optim_func': torch.optim.Adam,           # optimizer class ['torch.optim.Adam', 'torch.optim.SGD', 'torch.optim.AdamW']
    'loss_func' :torch.nn.CrossEntropyLoss(), # cross-entropy loss ['torch.nn.NLLLoss()', or 'torch.nn.CrossEntropyLoss()']
}

## 1.4 データセットの設定

In [None]:
# Build the vocabulary object from the `traindata_size` most frequent words.
_vocab = lam.VOCAB(traindata_size=params['traindata_size'], 
                   w2v=None, 
                   yomi=yomi) 

source = params['source']
target = params['target']

# The attention decoder needs one fixed maximum length, so use the longest
# length over all representations (orthography, phoneme, mora, mora_p).
_max_len = max(_vocab.max_ortho_length,
               _vocab.max_phone_length,
               _vocab.max_mora_length,
               _vocab.max_mora_p_length)
_vocab.max_length = _max_len + 1
print(colored(f'_vocab.max_length: {_vocab.max_length}', 'blue', attrs=['bold']))

# Select the vocabularies and ID keys for the encoder (source) side and
# the decoder (target) side.
source_vocab, source_ids, target_vocab, target_ids = lam.get_soure_and_target_from_params(
    params=None,
    _vocab=_vocab,
    source=source,
    target=target,
    is_print=True)

print(colored(f'source:{source}','blue', attrs=['bold']), f'{source_vocab}')
print(colored(f'target:{target}','cyan', attrs=['bold']), f'{target_vocab}')
print(colored(f'source_ids:{source_ids}','blue', attrs=['bold']), f'{source_ids}')
print(colored(f'target_ids:{target_ids}','cyan', attrs=['bold']), f'{target_ids}')

# The TLPA and SALA word lists are used as validation data.
tlpa1, tlpa2, tlpa3, tlpa4, sala_r29, sala_r30, sala_r31 = lam.read_json_tlpa1234_sala_r29_30_31(
    json_fname='lam/2022_0508SALA_TLPA.json')

# Explicit name -> data mapping (replaces `eval` on variable names).
# `tlpa1` is loaded but, as in the original cell, not part of the validation sets.
_rawdata_by_name = {'tlpa2': tlpa2, 'tlpa3': tlpa3, 'tlpa4': tlpa4,
                    'sala_r29': sala_r29, 'sala_r30': sala_r30, 'sala_r31': sala_r31}
_data_names = list(_rawdata_by_name)

_dataset = {}
for data in _data_names:
    _dataset[data] = {'rawdata': _rawdata_by_name[data],
                      'pdata': lam.make_vocab_dataset(_rawdata_by_name[data], vocab=_vocab)}

# Author's note: the following was bolted on later and makes the code messy;
# it should be cleaned up when time permits.
X_vals = lam.make_X_vals(_dataset=_dataset,
                         source_vocab=source_vocab,
                         target_vocab=target_vocab,
                         source_ids=source_ids,
                         target_ids=target_ids
                        )

In [None]:
# Surface form ('orig') of every training item, in the vocabulary's own order.
train_wordlist = [_entry['orig'] for _entry in _vocab.train_data.values()]
print(len(train_wordlist))

### 任意の単語 orthography を変換するための関数

In [None]:
def _get_ids_from_orth(orth_wrd:str='てれび',
                       __vocab:lam.lam.VOCAB=_vocab):
    """Collect every representation of one written word in a single dict.

    Args:
        orth_wrd: the word in its written form.
        __vocab: vocabulary object providing `get7lists_from_orth` and `get6ids`.

    Returns:
        dict with the seven values returned by `get7lists_from_orth`
        ('_yomi', '_phon', '_phon_r', '_orth', '_orth_r', '_mora', '_mora_r')
        followed by the six values returned by `get6ids`
        ('phon_ids', 'phon_ids_r', 'orth_ids', 'orth_ids_r', 'mora_ids', 'mora_ids_r').
    """
    (reading, phon, phon_rev,
     orth, orth_rev,
     mora, mora_rev) = __vocab.get7lists_from_orth(orth=orth_wrd)

    # The ID lists are derived from the reading, the phoneme list and the grapheme list.
    (phon_ids, phon_ids_rev,
     orth_ids, orth_ids_rev,
     mora_ids, mora_ids_rev) = __vocab.get6ids(yomi=reading, _phon=phon, _orth=orth)

    result = {}
    result['_yomi'] = reading
    result['_phon'] = phon
    result['_phon_r'] = phon_rev
    result['_orth'] = orth
    result['_orth_r'] = orth_rev
    result['_mora'] = mora
    result['_mora_r'] = mora_rev
    result['phon_ids'] = phon_ids
    result['phon_ids_r'] = phon_ids_rev
    result['orth_ids'] = orth_ids
    result['orth_ids_r'] = orth_ids_rev
    result['mora_ids'] = mora_ids
    result['mora_ids_r'] = mora_ids_rev
    return result
    
# Smoke test: print all representations of the default word ('てれび').
print(_get_ids_from_orth())


def orth_ids2tkn(ids:list):
    """Translate indices of `_vocab.ortho_vocab` into their tokens.

    Raises IndexError for an index outside the vocabulary.
    """
    tokens = []
    for token_id in ids:
        tokens.append(_vocab.ortho_vocab[token_id])
    return tokens

def orth_tkn2ids(tkn:list):
    """Translate tokens into indices of `_vocab.ortho_vocab`.

    Tokens missing from the vocabulary are mapped to the index of '<UNK>'.
    A plain string is iterated character by character.
    """
    ids = []
    for token in tkn:
        known = token in _vocab.ortho_vocab
        ids.append(_vocab.ortho_vocab.index(token if known else '<UNK>'))
    return ids

def mora_p_ids2tkn(ids:list):
    """Translate indices of `_vocab.mora_p_vocab` into their tokens.

    Raises IndexError for an index outside the vocabulary.
    """
    tokens = []
    for token_id in ids:
        tokens.append(_vocab.mora_p_vocab[token_id])
    return tokens

def mora_p_tkn2ids(tkn:list):
    """Translate tokens into indices of `_vocab.mora_p_vocab`.

    Tokens missing from the vocabulary are mapped to the index of '<UNK>'.
    (The previous version called the vocabulary list like a function for
    unknown tokens, which raised TypeError instead of returning the '<UNK>' index.)

    Args:
        tkn: iterable of tokens.

    Returns:
        list of int indices, one per token.

    Raises:
        ValueError: if an unknown token is met and '<UNK>' is not in the vocabulary.
    """
    vocab = _vocab.mora_p_vocab
    return [vocab.index(_tkn) if _tkn in vocab else vocab.index('<UNK>') for _tkn in tkn]

def mora_ids2tkn(ids:list):
    """Translate indices of `_vocab.mora_vocab` into their tokens.

    Raises IndexError for an index outside the vocabulary.
    """
    tokens = []
    for token_id in ids:
        tokens.append(_vocab.mora_vocab[token_id])
    return tokens

def mora_tkn2ids(tkn:list):
    """Translate tokens into indices of `_vocab.mora_vocab`.

    Tokens missing from the vocabulary are mapped to the index of '<UNK>'.
    (The previous version called the vocabulary list like a function for
    unknown tokens, which raised TypeError instead of returning the '<UNK>' index.)

    Args:
        tkn: iterable of tokens.

    Returns:
        list of int indices, one per token.

    Raises:
        ValueError: if an unknown token is met and '<UNK>' is not in the vocabulary.
    """
    vocab = _vocab.mora_vocab
    return [vocab.index(_tkn) if _tkn in vocab else vocab.index('<UNK>') for _tkn in tkn]


# Round-trip checks for the token <-> ID helpers.
_ids = [111, 298]
_orth_tokens = orth_ids2tkn(_ids)
print(_orth_tokens)
print(orth_tkn2ids(_orth_tokens))

# A string is iterated character by character, so each character is one token.
for _word in ('新しい', '神経心理学'):
    print(orth_tkn2ids(_word))
print(orth_ids2tkn(orth_tkn2ids('神経心理学')))

_mora_p_tokens = mora_p_ids2tkn([17,19,11,4,32,17])
print(_mora_p_tokens)
print(mora_p_tkn2ids(_mora_p_tokens))

_mora_tokens = mora_ids2tkn([37,139,31,7])
print(_mora_tokens)
print(mora_tkn2ids(_mora_tokens))

### SALA and TLPA dataset

In [None]:
# The TLPA and SALA word lists are used as validation data.
tlpa1, tlpa2, tlpa3, tlpa4, sala_r29, sala_r30, sala_r31 = lam.read_json_tlpa1234_sala_r29_30_31(
    json_fname='lam/2022_0508SALA_TLPA.json')

# Explicit name -> data mapping (replaces `eval` on variable names).
_rawdata_by_name = {'tlpa2': tlpa2, 'tlpa3': tlpa3, 'tlpa4': tlpa4,
                    'sala_r29': sala_r29, 'sala_r30': sala_r30, 'sala_r31': sala_r31}
_data_names = list(_rawdata_by_name)

_dataset = {}
for data in _data_names:
    _dataset[data] = {'rawdata': _rawdata_by_name[data],
                      'pdata': lam.make_vocab_dataset(_rawdata_by_name[data], vocab=_vocab)}

# `X_vals` is not rebuilt here; the dataset-setup cell earlier in the notebook creates it.

# Print each validation word list under its name.
for data in _data_names:
    print(colored(data, 'blue', attrs=['bold']), _rawdata_by_name[data])

### Fushimi1999 データセット

In [None]:
# Stimulus lists keyed by condition name; the trailing comment on each entry
# describes the condition. Keys starting with HF/LF hold words, keys
# starting with HFNW/LFNW hold nonwords.
fushimi1999 = {
    'HF___consist__': ['戦争', '倉庫', '医学', '注意', '記念', '番号', '料理', '完全', '開始', '印刷',
                       '連続', '予約', '多少', '教員', '当局', '材料', '夕刊', '労働', '運送', '電池' ], # consistent, high-frequency words
    'HF___inconsist': ['反対', '失敗', '作品', '指定', '実験', '決定', '独占', '独身', '固定', '食品',
                       '表明', '安定', '各種', '役所', '海岸', '決算', '地帯', '道路', '安打', '楽団' ], # inconsistent, high-frequency words
    'HF___atypical_': ['仲間', '夫婦', '人間', '神経', '相手', '反発', '化粧', '建物', '彼女', '毛糸', 
                       '場合', '台風', '夜間', '人形', '東西', '地元', '松原', '競馬', '大幅', '貸家' ], # inconsistent atypical, high-frequency words
    'LF___consist__': ['集計', '観察', '予告', '動脈', '理学', '信任', '任務', '返信', '医局', '低温', 
                       '区別', '永続', '持続', '試練', '満開', '軍備', '製材', '銀貨', '急送', '改選' ], # consistent, low-frequency words
    'LF___inconsist': ['表紙', '指針', '熱帯', '作詞', '決着', '食費', '古代', '地形', '役場', '品種', 
                       '祝福', '金銭', '根底', '接種', '経由', '郷土', '街路', '宿直', '曲折', '越境' ], # inconsistent, low-frequency words
    'LF___atypical_': ['強引', '寿命', '豆腐', '出前', '歌声', '近道', '間口', '風物', '面影', '眼鏡', 
                       '居所', '献立', '小雨', '毛皮', '鳥居', '仲買', '頭取', '極上', '奉行', '夢路' ], # inconsistent atypical, low-frequency words
    'HFNW_consist__': ['集学', '信別', '製信', '運学', '番送', '電続', '完意', '軍開', '動選', '当働', 
                       '予続', '倉理', '予少', '教池', '理任', '銀務', '連料', '開員', '注全', '記争' ], # consistent, high-character-frequency nonwords
    'HFNW_inconsist': ['作明', '風行', '失定', '指団', '決所', '各算', '海身', '東発', '楽験', '作代',
                       '反原', '独対', '歌上', '反定', '独定', '場家', '安種', '経着', '決土', '松合' ], # inconsistent biased, high-character-frequency nonwords
    'HFNW_ambiguous': ['表品', '実定', '人風', '神間', '相経', '人元', '小引', '指場', '毛所', '台手',
                       '間物', '道品', '出取', '建馬', '大婦', '地打', '化間', '面口', '金由', '彼間' ], # inconsistent ambiguous, high-character-frequency nonwords
    'LFNW_consist__': ['急材', '戦刊', '返計', '印念', '低局', '労号', '満送', '永告', '試脈', '観備',
                       '材約', '夕局', '医庫', '任続', '医貨', '改練', '区温', '多始', '材刷', '持察' ], # consistent, low-character-frequency nonwords
    'LFNW_inconsist': ['食占', '表底', '宿帯', '決帯', '古費', '安敗', '役針', '近命', '眼道', '豆立',
                       '街直', '固路', '郷種', '品路', '曲銭', '献居', '奉買', '根境', '役岸', '祝折' ], # inconsistent biased, low-character-frequency nonwords
    'LFNW_ambiguous': ['食形', '接紙', '競物', '地詞', '強腐', '頭路', '毛西', '夜糸', '仲影', '熱福',
                       '寿前', '鳥雨', '地粧', '越種', '仲女', '極鏡', '夢皮', '居声', '貸形', '夫幅' ], # inconsistent ambiguous, low-character-frequency nonwords
}

# Show every condition together with its stimuli.
for _cond, _words in fushimi1999.items():
    print(colored(_cond, 'blue', attrs=['bold']), _words)

# Flatten all conditions into one list, keeping the dictionary order.
fushimi1999_list = [_w for _words in fushimi1999.values() for _w in _words]

train_wordlist = [_entry['orig'] for _entry in _vocab.train_data.values()]

# Report which stimuli occur in the training vocabulary and at which position.
for i, wrd in enumerate(fushimi1999_list):
    if wrd not in train_wordlist:
        continue
    idx = train_wordlist.index(wrd)
    print(f'{i:3d} {wrd}:{idx:5d}')

In [None]:
# Convert every stimulus to grapheme IDs and back; details are printed only
# for the items whose position in the list is greater than 125.
for i, w in enumerate(fushimi1999_list):
    ids = orth_tkn2ids(w)
    tokens = orth_ids2tkn(ids)
    if i <= 125:
        continue
    print(i, w, ids, tokens)
    print(colored((i,w),'blue',attrs=['bold']), _get_ids_from_orth(orth_wrd=w))

### テストのための一文字の書記素データセットの作成

In [None]:
# Names of the entries of the `_get_ids_from_orth` result that hold the
# source-side and target-side ID lists.
_src, _tgt = (_name + '_ids' for _name in (source, target))
print(_src,_tgt)

# Displayed as the cell output: all representations of a two-character word.
_get_ids_from_orth('あい')

In [None]:
# Single-character inputs: full-width digits and Latin letters, hiragana, katakana.
one_chars0 = '０１２３４５６７８９ＡＢＣＤＥＦＧＨＩＪＫＬＭＮＯＰＱＲＳＴＵＶＷＸＹＺａｂｃｄｅｆｇｈｉｊｋｌｍｎｏｐｑｒｓｔｕｖｗｘｙｚ'
one_chars1 = 'あいうえおかがきぎくぐけげこごさざしじすずせぜそぞただちぢつづてでとどなにぬねのはばぱひびぴふぶぷへべぺほぼぽまみむめもやゆよらりるれろわをん'
one_chars2 = 'アイウエオカガキギクグケゲコゴサザシジスズセゼソゾタダチヂッツヅテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモヤユヨラリルレロワヲン'
one_chars = ''.join([one_chars0, one_chars1, one_chars2])

# Print the source IDs, target IDs and reading obtained for each character.
for ch in one_chars:
    _entry = _get_ids_from_orth(ch)
    _src_ids, _tgt_ids, _reading = _entry[_src], _entry[_tgt], _entry['_yomi']
    print(colored(ch, 'blue', attrs=['bold']), f'__src:{_src_ids}, __tgt:{_tgt_ids}, __yomi:{_reading}')

In [None]:
class Onechar_dataset(torch.utils.data.Dataset):
    """Dataset whose items are single characters.

    Every character (full-width digits and Latin letters, hiragana, katakana)
    is passed through `_get_ids_from_orth`; the entries selected by `source`
    and `target` become the (input, output) pair of the item.

    Args:
        source: input representation name; 'orthography' is read as 'orth'
            and 'mora_p' as 'mora' before '_ids' is appended.
        target: output representation name, resolved the same way.
    """

    def __init__(self,
                 source:str='orthography',
                 target:str='mora_p',                
                ):
        super().__init__()

        def _ids_key(name):
            # Resolve the aliases, then form the key used in the result dict.
            if name == 'orthography':
                name = 'orth'
            if name == 'mora_p':
                name = 'mora'
            return name + '_ids'

        src_key = _ids_key(source)
        tgt_key = _ids_key(target)

        digits_latin = '０１２３４５６７８９ＡＢＣＤＥＦＧＨＩＪＫＬＭＮＯＰＱＲＳＴＵＶＷＸＹＺａｂｃｄｅｆｇｈｉｊｋｌｍｎｏｐｑｒｓｔｕｖｗｘｙｚ'
        hiragana = 'あいうえおかがきぎくぐけげこごさざしじすずせぜそぞただちぢつづてでとどなにぬねのはばぱひびぴふぶぷへべぺほぼぽまみむめもやゆよらりるれろわをん'
        katakana = 'アイウエオカガキギクグケゲコゴサザシジスズセゼソゾタダチヂッツヅテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモヤユヨラリルレロワヲン'

        self.all_chars = digits_latin + hiragana + katakana

        # One record per character, keyed by its position in `all_chars`.
        records = {}
        for position, char in enumerate(self.all_chars):
            info = _get_ids_from_orth(char)
            src_ids, tgt_ids, reading = info[src_key], info[tgt_key], info['_yomi']
            records[position] = {'yomi': reading, 'src': src_ids, 'tgt': tgt_ids}
        self.data_dict = records

    def __len__(self)->int:
        """Number of characters in the dataset."""
        return len(self.data_dict)
    
    def __getitem__(self,
                    x:int):
        """Return the (source IDs, target IDs) pair stored for item `x`."""
        record = self.data_dict[x]
        return record['src'], record['tgt']
        
# Build the single-character dataset for the configured task and show its
# size together with its first (source IDs, target IDs) pair.
onechar_dataset = Onechar_dataset(source=source, target=target)
print(len(onechar_dataset))
print(onechar_dataset[0])

# 2. モデルの定義

In [None]:
# 自作ライブラリ LAM の読み込み
import lam 

from lam import EncoderRNN
from lam import AttnDecoderRNN
from lam import calc_accuracy

if (params['pretrained']) and (params['path_saved'] != False) and os.path.exists(params['path_saved']):
    """セーブした学習済のモデルがあれば読み込む"""
    
    checkpoint = torch.load(params['path_saved'])
    encoder = EncoderRNN(len(source_vocab), params['hidden_size']).to(device)
    decoder = AttnDecoderRNN(n_hid=params['hidden_size'], 
                             n_out=len(target_vocab), 
                             dropout_p=params['dropout_p'],
                             max_length=_vocab.max_length).to(device)
    encoder.load_state_dict(checkpoint['encoder'])
    decoder.load_state_dict(checkpoint['decoder'])
    encoder.eval()
    decoder.eval()
    #losses = []

    print(colored(f"セーブした学習済のモデル {params['path_saved']} があるので読み込みました",
          color='blue', attrs=['bold']))

else:
    encoder = EncoderRNN(len(_vocab.ortho_vocab), 
                         params['hidden_size']).to(device)
    decoder = AttnDecoderRNN(n_hid=params['hidden_size'], 
                             n_out=len(target_vocab), 
                             dropout_p=params['dropout_p'],
                             max_length=_vocab.max_length
                            ).to(device)

# モデルの概要を印字
print(f'encoder:{encoder}')
print(f'decoder:{decoder}')
        
for test_name, val_dataset in X_vals.items():
    acc = calc_accuracy(_dataset=val_dataset,
                        encoder=encoder,
                        decoder=decoder,
                        #decoder=decoder2,
                        max_length=_vocab.max_length,
                        source_vocab=source_vocab,
                        target_vocab=target_vocab,
                        source_ids=source_ids,
                        target_ids=target_ids,
                       )
    print(colored(f'{test_name} の精度:{acc:.3f}','blue', attrs=['bold']))


# params の印刷
print(colored(params,'blue',attrs=['bold']))    

In [None]:
import math
import random
import numpy as np
import time

def asMinutes(s:float)->str:
    """Format a duration given in seconds as minutes and seconds.

    Args:
        s: duration in seconds (int or float).

    Returns:
        A string of the form 'MM分 SS秒', each number right-aligned to width 2;
        fractional seconds are truncated.
    """
    m = math.floor(s / 60)
    s -= m * 60
    # The unreachable English-format return that followed this line was removed.
    return f'{int(m):2d}分 {int(s):2d}秒'


def timeSince(since:float,
            percent:float)->str:
    """Report the elapsed time and the estimated remaining time.

    Args:
        since: start time as returned by `time.time()`.
        percent: fraction of the whole job finished so far (must be non-zero).

    Returns:
        A string with the elapsed and remaining time, both formatted by `asMinutes`.

    Raises:
        ZeroDivisionError: if `percent` is 0.
    """
    elapsed = time.time() - since          # seconds from the start until now
    estimated_total = elapsed / (percent)  # extrapolate the total duration from the finished fraction
    remaining = estimated_total - elapsed  # what is left of the estimated total

    return f'経過時間:{asMinutes(elapsed)} (残り時間 {asMinutes(remaining)})'

def evaluate(encoder:torch.nn.Module,
             decoder:torch.nn.Module,
             input_ids:list=None,
             max_length:int=1,
             source_vocab:list=None,
             target_vocab:list=None,
             source_ids:list=None,
             target_ids:list=None,
            )->tuple:
    """Decode one input sequence greedily, without computing gradients.

    Args:
        encoder: encoder module providing `initHidden()` and the attribute `n_hid`.
        decoder: attention decoder called as
            `decoder(input, hidden, encoder_outputs, device=...)`.
        input_ids: list of source-side token IDs.
        max_length: maximum number of decoding steps; also the number of rows
            of the encoder-output buffer.
        source_vocab: accepted for interface compatibility; not used here.
        target_vocab: target-side vocabulary; must contain '<SOW>' and '<EOW>'.
        source_ids, target_ids: accepted for interface compatibility; not used here.

    Returns:
        (decoded_words, decoded_ids, decoder_attentions): the decoded tokens,
        their IDs, and one attention row per decoding step that was run.

    Note:
        Relies on the notebook-level `device` and `convert_ids2tensor`.
    """
    with torch.no_grad():
        # Put the input on the same device as the buffers created below.
        input_tensor = convert_ids2tensor(input_ids, device=device)
        input_length = input_tensor.size()[0]
        encoder_hidden = encoder.initHidden()

        encoder_outputs = torch.zeros(max_length, encoder.n_hid, device=device)

        for ei in range(input_length):
            encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
            encoder_outputs[ei] += encoder_output[0, 0]

        # The decoder works on the target vocabulary, so its start symbol is
        # looked up there (the same way `_train` does), not in the source vocabulary.
        decoder_input = torch.tensor([[target_vocab.index('<SOW>')]], device=device)
        decoder_hidden = encoder_hidden

        decoded_words, decoded_ids = [], []
        decoder_attentions = torch.zeros(max_length, max_length)

        for di in range(max_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs, device=device)
            decoder_attentions[di] = decoder_attention.data
            topv, topi = decoder_output.data.topk(1)
            predicted_id = topi.item()
            decoded_ids.append(predicted_id)
            if predicted_id == target_vocab.index('<EOW>'):
                decoded_words.append('<EOW>')
                break
            decoded_words.append(target_vocab[predicted_id])

            # Feed the prediction back in as the next decoder input.
            decoder_input = topi.squeeze().detach()

        # One attention row per executed step (also well defined when no step ran).
        return decoded_words, decoded_ids, decoder_attentions[:len(decoded_ids)]


def check_vals_performance(encoder=None, decoder=None,
                           _dataset=None,
                           max_length=0,
                           source_vocab=None,
                           target_vocab=None,
                           source_ids=None,
                           target_ids=None,
                           ):
    """Print the exact-match accuracy of the model on every validation set.

    Args:
        encoder, decoder: models passed on to `evaluate`.
        _dataset: mapping from a set name to an indexable, sized collection
            whose items are (input IDs, target IDs) pairs.
        max_length: maximum decoding length passed on to `evaluate`.
        source_vocab, target_vocab, source_ids, target_ids: passed on to `evaluate`.

    Returns:
        None. Nothing is printed when a required argument is missing
        (`None`, or `max_length == 0`).

    Raises:
        ZeroDivisionError: if one of the validation sets is empty.
    """
    # Identity checks, so objects with a custom `__eq__` cannot be mistaken for None.
    if _dataset is None or encoder is None or decoder is None or max_length == 0 or source_vocab is None:
        return
    print('検証データ:',end="")
    for _x in _dataset:
        subset = _dataset[_x]
        n_items = len(subset)
        ok_count = 0
        for i in range(n_items):
            _input_ids, _target_ids = subset[i]
            _output_words, _output_ids, _attentions = evaluate(encoder, decoder, _input_ids,
                                                               max_length,
                                                               source_vocab=source_vocab,
                                                               target_vocab=target_vocab,
                                                               source_ids=source_ids,
                                                               target_ids=target_ids,
                                                               )
            # An item is correct only when the whole ID sequence matches.
            ok_count += 1 if _target_ids == _output_ids else 0
        print(f'{_x}:{ok_count/n_items:.3f},',end="")
    print()


def convert_ids2tensor(sentence_ids:list, 
                       device:torch.device=torch.device("cuda:0" if torch.cuda.is_available () else "cpu")):
    
    """Turn a list of integer IDs into a column tensor of dtype long.

    Example: [0,1,2] -> tensor([[0],[1],[2]])

    The default device is chosen once, when the function is defined.
    """
    id_tensor = torch.tensor(sentence_ids, dtype=torch.long, device=device)
    return id_tensor.reshape(-1, 1)

def _train(input_tensor:torch.Tensor=None, 
           target_tensor:torch.Tensor=None,
           encoder:torch.nn.Module=None, 
           decoder:torch.nn.Module=None,
           encoder_optimizer:torch.optim=None, 
           decoder_optimizer:torch.optim=None,
           criterion:torch.nn.modules.loss=torch.nn.modules.loss.CrossEntropyLoss,
           max_length:int=1,
           target_vocab:list=None,
           teacher_forcing_ratio:float=0.,
           device:torch.device=None)->float:
    encoder_hidden = encoder.initHidden() # 符号化器の中間層を初期化
    encoder_optimizer.zero_grad()         # 符号化器の最適化関数の初期化
    decoder_optimizer.zero_grad()         # 復号化器の最適化関数の初期化

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)
    encoder_outputs = torch.zeros(max_length, encoder.n_hid, device=device)
    loss = 0

    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
        encoder_outputs[ei] = encoder_output[0, 0]

    decoder_input = torch.tensor([[target_vocab.index('<SOW>')]], device=device)
    decoder_hidden = encoder_hidden

    ok_flag = True
    # 教師強制をするか否かを確率的に決める
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
    if use_teacher_forcing: # 教師強制する場合 Teacher forcing: Feed the target as the next input
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs,device=device)
            loss += criterion(decoder_output, target_tensor[di])
            ok_flag = (ok_flag) and (decoder_output.argmax() == target_tensor[di].detach().numpy()[0])
            decoder_input = target_tensor[di]  # Teacher forcing

    else: # 教師強制しない場合 Without teacher forcing: use its own predictions as the next input
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs,device=device)
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()  # detach from history as input

            loss += criterion(decoder_output, target_tensor[di])
            ok_flag = (ok_flag) and (decoder_output.argmax() == target_tensor[di].detach().numpy()[0])
            if decoder_input.item() == target_vocab.index('<EOW>'):
                break

    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()
    return loss.item() / target_length, ok_flag


def _fit(encoder:torch.nn.Module, 
         decoder:torch.nn.Module,
         epochs:int=1,
         lr:float=0.0001,
         n_sample:int=3,
         teacher_forcing_ratio=False,
         train_dataset:torch.utils.data.Dataset=None,
         val_dataset:dict=None,
         source_vocab:list=None,
         target_vocab:list=None,
         source_ids:str=None,
         target_ids:list=None,
         params:dict=None,
         max_length:int=1,
         device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
        )->list:

    start_time = time.time()

    encoder.train()
    decoder.train()
    encoder_optimizer = params['optim_func'](encoder.parameters(), lr=lr)
    decoder_optimizer = params['optim_func'](decoder.parameters(), lr=lr)
    criterion = params['loss_func']
    losses = []

    for epoch in range(epochs):
        epoch_loss = 0
        ok_count = 0
        
        #エポックごとに学習順をシャッフルする
        learning_order = np.random.permutation(train_dataset.__len__())
        
        for i in range(train_dataset.__len__()):
            x = learning_order[i]   # ランダムにデータを取り出す
            input_ids, target_ids = train_dataset.__getitem__(x)
            input_tensor = convert_ids2tensor(input_ids)
            target_tensor = convert_ids2tensor(target_ids)

            #訓練の実施
            loss, ok_flag = _train(input_tensor=input_tensor, 
                                   target_tensor=target_tensor,
                                   encoder=encoder, 
                                   decoder=decoder,
                                   encoder_optimizer=encoder_optimizer, 
                                   decoder_optimizer=decoder_optimizer,
                                   criterion=criterion,
                                   max_length=max_length,
                                   target_vocab=target_vocab,
                                   teacher_forcing_ratio=teacher_forcing_ratio,
                                   device=device)
            epoch_loss += loss
            ok_count += 1 if ok_flag else 0


        losses.append(epoch_loss/train_dataset.__len__())
        print(colored(f'エポック:{epoch:2d} 損失:{epoch_loss/train_dataset.__len__():.2f}', 'blue', attrs=['bold']),
              colored(f'{timeSince(start_time, (epoch+1) * train_dataset.__len__()/(epochs * train_dataset.__len__()))}',
                      'cyan', attrs=['bold']),
              colored(f'訓練データの精度:{ok_count/train_dataset.__len__():.3f}', 'blue', attrs=['bold']))

        check_vals_performance(_dataset=val_dataset,
                               encoder=encoder,
                               decoder=decoder,
                               max_length=max_length,
                               source_vocab=source_vocab,
                               target_vocab=target_vocab,
                               source_ids=source_ids,
                               target_ids=target_ids)
        if n_sample > 0:
            evaluateRandomly(encoder, decoder, n=n_sample)

    return losses


In [None]:
# First training run: fit on the single-character dataset, starting a new loss history.
losses = []
losses.extend(
    _fit(encoder=encoder,
         decoder=decoder,
         train_dataset=onechar_dataset,
         val_dataset=X_vals,
         source_vocab=source_vocab,
         target_vocab=target_vocab,
         source_ids=source_ids,
         target_ids=target_ids,
         params=params,
         epochs=params['epochs'],
         lr=params['lr'],
         teacher_forcing_ratio=params['teacher_forcing_ratio'],
         max_length=_vocab.max_length,
         n_sample=0,
         device=device)
)

In [None]:
# Continue training on the single-character dataset; the per-epoch losses
# are appended to the existing `losses` history.
losses.extend(
    _fit(encoder=encoder,
         decoder=decoder,
         train_dataset=onechar_dataset,
         val_dataset=X_vals,
         source_vocab=source_vocab,
         target_vocab=target_vocab,
         source_ids=source_ids,
         target_ids=target_ids,
         params=params,
         epochs=params['epochs'],
         lr=params['lr'],
         teacher_forcing_ratio=params['teacher_forcing_ratio'],
         max_length=_vocab.max_length,
         n_sample=0,
         device=device)
)

In [None]:
# 訓練データセットと検証データセットを作成
train_dataset = lam.Train_dataset(data=_vocab.train_data,
                                  source_vocab=source_vocab, 
                                  target_vocab=target_vocab,
                                  source_ids=source_ids,   # おそらくこの 2 行を入れないといけなかった
                                  target_ids=target_ids)  # そうでなければ，デフォルトの `mora_p_r` になってしまう

P  = int(train_dataset.__len__() * 0.9)
_P = train_dataset.__len__() - P
_train_dataset, val_dataset = torch.utils.data.random_split(dataset=train_dataset,
                                                           lengths=(P, _P),
                                                           generator=torch.Generator().manual_seed(42),
                                                          )

# val_dataset = lam.Val_dataset(data=_dataset['sala_r29']['pdata'],
#                               source_vocab=source_vocab, 
#                               target_vocab=target_vocab,
#                               source_ids=source_ids,
#                               target_ids=target_ids)
                
# 訓練データセットと検証データセットを作成
print(f'len(train_dataset):{len(train_dataset)}',
      f'len(val_dataset):{len(val_dataset)}')      

In [None]:
# Main training run on the full word-level training dataset; the loss history is restarted.
# To train on the single-character dataset instead, pass train_dataset=onechar_dataset.
losses = []
losses.extend(
    _fit(encoder=encoder,
         decoder=decoder,
         train_dataset=train_dataset,
         val_dataset=X_vals,
         source_vocab=source_vocab,
         target_vocab=target_vocab,
         source_ids=source_ids,
         target_ids=target_ids,
         params=params,
         epochs=params['epochs'],
         lr=params['lr'],
         teacher_forcing_ratio=params['teacher_forcing_ratio'],
         max_length=_vocab.max_length,
         n_sample=0,
         device=device)
)