<a href="https://colab.research.google.com/github/ShinAsakawa/ShinAsakawa.github.io/blob/master/2023notebooks/2023_0106lam_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# This cell is just for fun; it can safely be skipped.
import IPython
# The Image object is the cell's last expression, so the notebook shows it as the cell output.
IPython.display.Image(url="https://livedoor.blogimg.jp/ftb001/imgs/b/4/b4629a79.jpg")

# 1 準備作業


## 1.1 シミュレーションに必要なパラメータの設定


In [None]:
import torch

# Settings that drive the whole simulation.
# Key order is kept exactly as before because the dict is printed as-is further down.
params = dict(
    # Number of training items, taken from the top of the high-frequency
    # words of the NTT Japanese lexical-properties database.
    traindata_size=10000,
    # traindata_size=301612,   # alternative size left by the author
    epochs=20,                 # number of training epochs
    hidden_size=24,            # number of hidden-layer units
    random_seed=42,            # random seed (Douglas Adams, "The Hitchhiker's Guide to the Galaxy")

    # Setting `source` and `target` selects a different task.
    # Allowed values for both: ['orthography', 'phonology', 'mora', 'mora_p', 'mora_p_r']
    #   'orthography': graphemes
    #   'phonology'  : phonemes
    #   'mora'       : morae
    #   'mora_p'     : morae decomposed into sounds ("silius" in the original note;
    #                  presumably Julius notation - TODO confirm)
    #   'mora_p_r'   : the reverse of the mora_p decomposition
    source='orthography',
    target='mora_p_r',

    pretrained=False,          # True -> load an already trained model file
    # pretrained=True,
    isTrain=True,              # True -> train the model

    # File name under which the trained model parameters are stored.
    # path_saved='2022_0607lam_o2p_hid32_vocab10k.pt',
    # path_saved='2022_0829lam_p2p_hid24_vocab10k.pt',
    path_saved=False,          # False -> do not save

    # File name under which the resulting scatter plot is stored.
    path_graph='2022_0829lam_p2p_hid24_vocab10k.pdf',
    # path_graph=False,        # False -> do not save

    lr=0.001,                                # learning rate
    dropout_p=0.0,                           # dropout probability
    teacher_forcing_ratio=0.5,               # probability of applying teacher forcing
    optim_func=torch.optim.Adam,             # optimizer: torch.optim.Adam / torch.optim.SGD / torch.optim.AdamW
    loss_func=torch.nn.CrossEntropyLoss(),   # loss: torch.nn.NLLLoss() or torch.nn.CrossEntropyLoss()
)

## 1.2 ライブラリのインポート

In [None]:
# Environment bootstrap: make the `bit` helper importable, then obtain the
# `wakati` and `yomi` callables, either from a fresh MeCab install (Colab)
# or from the local `ccap` package.
%config InlineBackend.figure_format = 'retina'
try:
    import bit
except ImportError:
    # `bit` is missing: clone its repository into the working directory so
    # that the unconditional import below can find it.
    !pip install ipynbname --upgrade > /dev/null 2>&1 
    !git clone https://github.com/ShinAsakawa/bit.git
import bit

# Runtime flag and home directory as reported by `bit`.
isColab = bit.isColab
HOME = bit.HOME

if isColab:
    # Colab: install MeCab and its dictionary, then the Python bindings.
    !apt install aptitude
    !aptitude install mecab libmecab-dev mecab-ipadic-utf8 git make curl xz-utils file -y
    !pip install mecab-python3==0.7
    !pip install jaconv
    
    import MeCab
    # Bound `parse` methods of two taggers created with different MeCab output options.
    wakati = MeCab.Tagger('-Owakati').parse
    yomi = MeCab.Tagger('-Oyomi').parse
else:
    # Elsewhere: reuse the callables provided by ccap.mecab_settings.
    from ccap.mecab_settings import yomi
    from ccap.mecab_settings import wakati

# Fetch the author's own libraries (ccap and LAM) on Colab.
# NOTE(review): re-running this cell clones into directories that already exist.
if isColab:
    !git clone https://github.com/ShinAsakawa/ccap.git
    !git clone https://github.com/ShinAsakawa/lam.git

## 1.3 データセットのアップロード

In [None]:
# Upload the dataset.
# Upload `2022_1018lam_traindata10k.json.gz` from the local drive.

# Only done on Colab; elsewhere this cell does nothing, so the file must
# already be reachable from the working directory.
if isColab:
    from google.colab import files
    uploaded = files.upload()


## 1.4 アップロードしたデータセットの展開

In [None]:
# アップロードしたデータセットの展開
import torch
import lam

import os
import json
import gzip
from termcolor import colored

# Read the gzip-compressed JSON dataset into the dictionary `A`.
gz_fname = '2022_1018lam_traindata10k.json.gz'
with gzip.open(gz_fname, 'rt', encoding='utf-8') as fp:
    # NOTE(review): only the first line is parsed, so the whole JSON document
    # is assumed to sit on a single line.
    A = json.loads(fp.readlines()[0])
_keys = list(A.keys())

# Show which entries the dataset contains.
print(_keys)
class makeA:
    """Attribute-style view of the dataset dictionary.

    Each entry named below is copied from ``X`` onto the instance under the
    same name, so it can be read as ``obj.train_data`` instead of
    ``X['train_data']``. Entries of ``X`` that are not listed are ignored;
    looking up a listed name that is absent fails just as ``X[name]`` does.
    """

    def __init__(self, X):
        # Copied in exactly this order.
        field_names = (
            'c1', 'c2', 'c3', 'c4', 'cond', 'excluded_data', 'ja_symbols',
            'ja_symbols_normalized', 'max_mora_length', 'max_mora_p_length',
            'max_ortho_length', 'max_phone_length', 'mora2jul', 'mora_freq',
            'mora_p', 'mora_p_vocab', 'mora_vocab', 'ntt_freq',
            'ntt_freq_vocab', 'ntt_orth2hira', 'ortho_vocab', 'phone_vocab',
            'train_data', 'traindata_size', 'vow2hira',
        )
        for field in field_names:
            setattr(self, field, X[field])
        
# Wrap the loaded dictionary so its entries can be read as attributes (e.g. _vocab.train_data).
_vocab = makeA(A)
#print(dir(_vocab))
#_vocab.vow2hira

## 1.5 検証データセットの設定

In [None]:
# Stand-alone lam.VOCAB instance; below it is only handed to
# lam.make_vocab_dataset to convert the TLPA/SALA word lists.
# NOTE(review): traindata_size=0 presumably means "load no training items" - confirm in lam.VOCAB.
__vocab = lam.VOCAB(traindata_size=0, w2v=None, yomi=yomi)

source = params['source']
target = params['target']

# The attention decoder needs a single maximum length, so use the longest
# length over every representation (orthography, phonology, mora, mora_p).
# The +1 is kept from the original code (presumably room for an
# end-of-sequence symbol - TODO confirm).
_max_len = max(_vocab.max_ortho_length,
               _vocab.max_phone_length,
               _vocab.max_mora_length,
               _vocab.max_mora_p_length)
_vocab.max_length = _max_len + 1
print(colored(f'_vocab.max_length: {_vocab.max_length}', 'blue', attrs=['bold']))

# Items and item ids for the source (encoder) side and the target (decoder) side.
source_vocab, source_ids, target_vocab, target_ids = lam.get_soure_and_target_from_params(
    params=None,
    _vocab=_vocab,
    source=source,
    target=target,
    is_print=False)

print(colored(f'source:{source}','blue', attrs=['bold']), f'{source_vocab}')
print(colored(f'target:{target}','cyan', attrs=['bold']), f'{target_vocab}')
print(colored(f'source_ids:{source_ids}','blue', attrs=['bold']), f'{source_ids}')
print(colored(f'target_ids:{target_ids}','cyan', attrs=['bold']), f'{target_ids}')

# The TLPA and SALA word lists serve as validation data.
tlpa1, tlpa2, tlpa3, tlpa4, sala_r29, sala_r30, sala_r31 = lam.read_json_tlpa1234_sala_r29_30_31(
    json_fname='lam/2022_0508SALA_TLPA.json')

# Explicit name -> word-list table; this replaces eval() on variable names.
# tlpa1 is read above but, as before, is not part of the validation sets.
_raw_by_name = {
    'tlpa2': tlpa2,
    'tlpa3': tlpa3,
    'tlpa4': tlpa4,
    'sala_r29': sala_r29,
    'sala_r30': sala_r30,
    'sala_r31': sala_r31,
}
_data_names = list(_raw_by_name)

_dataset = {}
for data in _data_names:
    _rawdata = _raw_by_name[data]
    _dataset[data] = {'rawdata': _rawdata,
                      'pdata': lam.make_vocab_dataset(_rawdata, vocab=__vocab)}

# Author's note: the part below was added afterwards, which made the code
# untidy; it should be cleaned up when time allows.
X_vals = lam.make_X_vals(_dataset=_dataset,
                         source_vocab=source_vocab,
                         target_vocab=target_vocab)

## 1.6 PyTorch 用 データセットの作成

In [None]:
# Build the training dataset and a validation dataset.
# The training set wraps the items stored in the uploaded dataset file.
train_dataset = lam.Train_dataset(data=_vocab.train_data,
                                  source_vocab=source_vocab, 
                                  target_vocab=target_vocab)

# This validation set is built from the converted SALA r29 list only.
# NOTE(review): the training call passes X_vals as its validation data, not this object.
val_dataset = lam.Val_dataset(data=_dataset['sala_r29']['pdata'],
                               source_vocab=source_vocab, 
                               target_vocab=target_vocab)
                
# Report how many items each dataset holds.
print(f'len(train_dataset):{len(train_dataset)}',
      f'len(val_dataset):{len(val_dataset)}')      

# 2 モデルの定義

In [None]:
# Load the author's own LAM library and the two model classes used below.
import lam 
from lam import EncoderRNN
from lam import AttnDecoderRNN
# Further lam imports the author left disabled:
# from lam import convert_ids2tensor
# from lam import train
# from lam import asMinutes, timeSince
# #from lam import fit
# from lam import convert_ids2tensor
# from lam import fix_seed
# from lam import worker_init_fn
# from lam import make_vocab_dataset

device = lam.device  # CPU or GPU, as selected by lam

from lam import calc_accuracy

def _new_encoder_decoder():
    """Return a freshly initialised (encoder, decoder) pair on `device`.

    The encoder is sized by the source vocabulary and the decoder by the
    target vocabulary; the remaining sizes come from `params` and
    `_vocab.max_length`.
    """
    _enc = EncoderRNN(len(source_vocab), params['hidden_size']).to(device)
    _dec = AttnDecoderRNN(n_hid=params['hidden_size'],
                          n_out=len(target_vocab),
                          dropout_p=params['dropout_p'],
                          max_length=_vocab.max_length).to(device)
    return _enc, _dec


encoder, decoder = _new_encoder_decoder()

# Truthiness test instead of `!= False`, so a path of None or '' skips
# loading as well instead of reaching os.path.exists().
if params['pretrained'] and params['path_saved'] and os.path.exists(params['path_saved']):
    # A saved, already trained model exists: restore its weights.
    # map_location keeps this working when the checkpoint was written on
    # another device (e.g. saved on a GPU, loaded on a CPU-only machine).
    checkpoint = torch.load(params['path_saved'], map_location=device)
    encoder.load_state_dict(checkpoint['encoder'])
    decoder.load_state_dict(checkpoint['decoder'])
    encoder.eval()
    decoder.eval()
    losses = []

    # Second copy restored from the same checkpoint. It loads the same
    # checkpoint['encoder'] as `encoder`, so its encoder must also be sized
    # by the source vocabulary; the previous len(target_vocab) cannot match
    # whenever the two vocabularies differ in size.
    encoder2, decoder2 = _new_encoder_decoder()
    encoder2.load_state_dict(checkpoint['encoder'])
    decoder2.load_state_dict(checkpoint['decoder'])
    encoder2.eval()
    decoder2.eval()

    print(colored(f"セーブした学習済のモデル {params['path_saved']} があるので読み込みました",
          color='blue', attrs=['bold']))
    # print(encoder)
    # print(decoder)
    # print(encoder2)
    # print(decoder2)

# Print an overview of both networks.
print(f'encoder:{encoder}')
print(f'decoder:{decoder}')

# Snapshot every parameter as a NumPy array, grouped by model name.
# .cpu() is needed because a tensor that lives on a CUDA device cannot be
# converted with .numpy(); for tensors already on the CPU it changes nothing.
# The explicit name -> model table replaces eval() on variable names.
_models = {'encoder': encoder, 'decoder': decoder}
_param = {
    _model_name: {_p_name: _p.detach().cpu().numpy()
                  for _p_name, _p in _net.named_parameters()}
    for _model_name, _net in _models.items()
}

# One line per parameter: (name, shape, dtype).
for _model, _val in _param.items():
    print(colored(f'{_model}','red', attrs=['bold']))
    for w_name, w_val in _val.items():
        print((w_name, w_val.shape, w_val.dtype))
        
# Accuracy of the current encoder/decoder on every validation set in X_vals.
# The loop variable is deliberately not named `val_dataset`, so the
# lam.Val_dataset object of that name created earlier is not overwritten.
for test_name, _val_data in X_vals.items():
    acc = calc_accuracy(_dataset=_val_data,
                        encoder=encoder,
                        decoder=decoder,
                        max_length=_vocab.max_length,
                        source_vocab=source_vocab,
                        target_vocab=target_vocab)
    print(colored(f'{test_name} の精度:{acc:.3f}','blue', attrs=['bold']))


# Print the parameter dictionary.
print(colored(params,'blue',attrs=['bold']))

# 3 訓練の実施

In [None]:
# Run training; whatever lam.fit returns is appended to a fresh `losses` list.
losses = []
losses.extend(
    lam.fit(
        encoder,
        decoder,
        device=device,
        epochs=params['epochs'],
        max_length=_vocab.max_length,
        n_sample=0,
        params=params,
        source_vocab=source_vocab,
        target_vocab=target_vocab,
        teacher_forcing_ratio=params['teacher_forcing_ratio'],
        train_dataset=train_dataset,
        val_dataset=X_vals,  # validation data is the whole X_vals mapping
    )
)

In [None]:
# torch.save(encoder, '2023_0106lam_encoder.pt')
# torch.save(decoder, '2023_0106lam_decoder.pt')

In [None]:
def _load_saved_module(path, fallback):
    """Return the model pickled at `path`, or `fallback` if the file does not exist.

    Args:
        path: file written with torch.save(model, path).
        fallback: object returned, after a printed warning, when `path` is absent.
    """
    if not os.path.exists(path):
        print(colored(f'{path} not found; using the in-memory model instead',
                      'red', attrs=['bold']))
        return fallback
    # map_location lets a model saved on a GPU be loaded on a CPU-only machine.
    # NOTE(review): this unpickles a whole module, so only load trusted files.
    # PyTorch >= 2.6 defaults to weights_only=True, which rejects such files;
    # pass weights_only=False there if these files are known to be safe.
    return torch.load(path, map_location=device)


# The models built here before were overwritten by the loaded ones right
# away, so they are no longer constructed.
encoder2 = _load_saved_module('2023_0106lam_encoder.pt', fallback=encoder)
decoder2 = _load_saved_module('2023_0106lam_decoder.pt', fallback=decoder)

# Accuracy of the reloaded models on every validation set.
# The loop variable is not named `val_dataset`, so the dataset object of
# that name is left untouched.
for test_name, _val_data in X_vals.items():
    acc = calc_accuracy(_dataset=_val_data,
                        encoder=encoder2,
                        decoder=decoder2,
                        max_length=_vocab.max_length,
                        source_vocab=source_vocab,
                        target_vocab=target_vocab)
    print(colored(f'{test_name} の精度:{acc:.3f}','blue', attrs=['bold']))

# List the settings, one per line.
for k, v in params.items():
    print(k, colored(v, 'green', attrs=['bold']))

In [None]:
from lam import evaluate

# Validation set to inspect.
TEST_DATA = 'tlpa4val'  # 'sala_r29val', 'sala_r30val', 'sala_r31val', 'tlpa2val', 'tlpa3val', 'tlpa4val'
dataset = X_vals[TEST_DATA]

# Print the accuracy; the value used to be computed and then thrown away.
acc = calc_accuracy(_dataset=dataset,
                    encoder=encoder2,
                    decoder=decoder2,
                    max_length=_vocab.max_length,
                    source_vocab=source_vocab,
                    target_vocab=target_vocab)
print(colored(f'{TEST_DATA} の精度:{acc:.3f}','blue', attrs=['bold']))

# For the first five items print: input symbols: expected symbols:: model output.
for i in range(5):
    x, y = dataset[i]
    # x is what gets passed to evaluate() as input_ids, i.e. source-side ids,
    # so it has to be decoded with source_vocab. It was decoded with
    # target_vocab, which is wrong whenever source and target differ.
    print(i, [source_vocab[_x] for _x in x], end=": ")
    print([target_vocab[_y] for _y in y], end=":: ")
    _output_words, _output_ids, _attentions = evaluate(encoder=encoder2,
                                                       decoder=decoder2,
                                                       input_ids=x,
                                                       max_length=_vocab.max_length,
                                                       source_vocab=source_vocab,
                                                       target_vocab=target_vocab)
    print(_output_words)
    

In [None]:
# Scratch cell: look at the target vocabulary and at one sample word.
#print(source_vocab)
print(target_vocab)
word = '戦争'
# Print the sample word one character per line.
for x in word:
    print(x) # ,source_vocab[x])
# Position of the first character in the source vocabulary; as the last
# expression of the cell it is shown as the cell output.
# NOTE(review): if source_vocab is a list, this raises ValueError when the character is absent.
source_vocab.index('戦')

In [None]:
# Stimulus lists, 20 two-character items per condition.
# Key prefixes, following the per-line comments: HF / LF = high- / low-frequency
# words, HFNW / LFNW = high- / low-character-frequency nonwords.
# NOTE(review): the name suggests the stimuli come from a 1998 study by Fushimi - confirm the source.
fushimi1998 = {
    'HF___consist__': ['戦争', '倉庫', '医学', '注意', '記念', '番号', '料理', '完全', '開始', '印刷',
                       '連続', '予約', '多少', '教員', '当局', '材料', '夕刊', '労働', '運送', '電池' ], # consistent, high-frequency words
    'HF___inconsist': ['反対', '失敗', '作品', '指定', '実験', '決定', '独占', '独身', '固定', '食品',
                       '表明', '安定', '各種', '役所', '海岸', '決算', '地帯', '道路', '安打', '楽団' ], # inconsistent, high-frequency words
    'HF___atypical_': ['仲間', '夫婦', '人間', '神経', '相手', '反発', '化粧', '建物', '彼女', '毛糸', 
                       '場合', '台風', '夜間', '人形', '東西', '地元', '松原', '競馬', '大幅', '貸家' ], # inconsistent atypical, high-frequency words
    'LF___consist__': ['集計', '観察', '予告', '動脈', '理学', '信任', '任務', '返信', '医局', '低温', 
                       '区別', '永続', '持続', '試練', '満開', '軍備', '製材', '銀貨', '急送', '改選' ], # consistent, low-frequency words
    'LF___inconsist': ['表紙', '指針', '熱帯', '作詞', '決着', '食費', '古代', '地形', '役場', '品種', 
                       '祝福', '金銭', '根底', '接種', '経由', '郷土', '街路', '宿直', '曲折', '越境' ], # inconsistent, low-frequency words
    'LF___atypical_': ['強引', '寿命', '豆腐', '出前', '歌声', '近道', '間口', '風物', '面影', '眼鏡', 
                       '居所', '献立', '小雨', '毛皮', '鳥居', '仲買', '頭取', '極上', '奉行', '夢路' ], # inconsistent atypical, low-frequency words
    'HFNW_consist__': ['集学', '信別', '製信', '運学', '番送', '電続', '完意', '軍開', '動選', '当働', 
                       '予続', '倉理', '予少', '教池', '理任', '銀務', '連料', '開員', '注全', '記争' ], # consistent, high-character-frequency nonwords
    'HFNW_inconsist': ['作明', '風行', '失定', '指団', '決所', '各算', '海身', '東発', '楽験', '作代',
                       '反原', '独対', '歌上', '反定', '独定', '場家', '安種', '経着', '決土', '松合' ], # inconsistent biased, high-character-frequency nonwords
    'HFNW_ambiguous': ['表品', '実定', '人風', '神間', '相経', '人元', '小引', '指場', '毛所', '台手',
                       '間物', '道品', '出取', '建馬', '大婦', '地打', '化間', '面口', '金由', '彼間' ], # inconsistent ambiguous, high-character-frequency nonwords
    'LFNW_consist__': ['急材', '戦刊', '返計', '印念', '低局', '労号', '満送', '永告', '試脈', '観備',
                       '材約', '夕局', '医庫', '任続', '医貨', '改練', '区温', '多始', '材刷', '持察' ], # consistent, low-character-frequency nonwords
    'LFNW_inconsist': ['食占', '表底', '宿帯', '決帯', '古費', '安敗', '役針', '近命', '眼道', '豆立',
                       '街直', '固路', '郷種', '品路', '曲銭', '献居', '奉買', '根境', '役岸', '祝折' ], # inconsistent biased, low-character-frequency nonwords
    'LFNW_ambiguous': ['食形', '接紙', '競物', '地詞', '強腐', '頭路', '毛西', '夜糸', '仲影', '熱福',
                       '寿前', '鳥雨', '地粧', '越種', '仲女', '極鏡', '夢皮', '居声', '貸形', '夫幅' ], # inconsistent ambiguous, low-character-frequency nonwords
}

# Print every condition name followed by its item list.
for k, v in fushimi1998.items():
    print(k, v)


In [None]:
# Second setup cell: makes `bit` importable again and, on Colab, installs the
# packages used for the Excel / tokenizer work that follows.
import torch
try:
    import bit
except ImportError:
    # `bit` is missing: clone its repository so the import below can find it.
    !pip install ipynbname --upgrade > /dev/null 2>&1 
    !git clone https://github.com/ShinAsakawa/bit.git
import bit

isColab = bit.isColab
HOME = bit.HOME
%config InlineBackend.figure_format = 'retina'

import IPython
# NOTE(review): isColab is assigned a second time here; this value replaces
# the one taken from bit.isColab a few lines above.
isColab = 'google.colab' in str(IPython.get_ipython())
if isColab:
    # NOTE(review): none of these installs pins a version.
    !pip install --upgrade openpyxl
    !pip install --upgrade pandas
    !pip install --upgrade fugashi[unidic-lite]
    !pip install --upgrade ipadic
    !python -m unidic download
    !pip install transformers
    !pip install --upgrade jaconv
    
from tqdm.notebook import tqdm    

In [None]:
!gls -lt ../*.xls*

In [None]:
import pandas as pd
# Excel workbook one directory above the notebook. Its file name reads
# "word list (handle with care)".
excel_fname = '../単語リスト（扱い注意）.xlsx'
corpus_pd = pd.read_excel(excel_fname, engine='openpyxl')
# Last expression of the cell, so the frame is rendered as the cell output.
# NOTE(review): the rendered table is stored in the saved notebook; clear the
# output before sharing if the content is sensitive.
corpus_pd