<a href="https://colab.research.google.com/github/ShinAsakawa/ShinAsakawa.github.io/blob/master/2025notebooks/2025_0626psylex71_CDP%2Bja.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<img src="https://raw.githubusercontent.com/project-ccap/project-ccap.github.io/refs/heads/master/2025figs/1998Zorzi_CDP_fig1.svg">
Zorzi+(1998) Fig. 1. Architecture of the model. The arrow indicates full connectivity between layers. Each box stands for a group of letters (26) or phonemes (44).<br/>


<img src="https://raw.githubusercontent.com/project-ccap/project-ccap.github.io/refs/heads/master/2025figs/1998Zorzi_CDP_fig8.svg">
<p>Zorzi+(1998) Fig.8. Architecture of the model with the hidden layer pathway. In both the direct pathway and the mediated pathway the layers are fully connected (arrows).</p>



In [None]:
# Notebook setup: plotting config, compute device, common imports and
# environment detection (Google Colab vs. local Jupyter).
%config InlineBackend.figure_format = 'retina'
import torch
# Use CUDA if available, otherwise Apple-silicon MPS, otherwise the CPU.
#device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f'device:{device}')

# Import required libraries
from collections import OrderedDict
import sys
import os
import numpy as np
# import time
# import datetime
import operator
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# Home directory; used to build the local path of the NTT data files.
# NOTE(review): os.environ['HOME'] raises KeyError where HOME is unset (e.g. Windows).
HOME = os.environ['HOME']

from IPython import get_ipython
# True when running on Google Colab (detected from the string form of the IPython shell object).
isColab =  'google.colab' in str(get_ipython())

# ipynbname is used to obtain the path of this notebook; install it on demand.
try:
    import ipynbname
except ImportError:
    !pip install ipynbname
    import ipynbname

# File name (last path component) of this notebook.
FILEPATH = str(ipynbname.path()).split('/')[-1]
print(f'FILEPATH:{FILEPATH}')

# Presumably imported for Japanese text in plots (titles/labels below are Japanese); install on demand.
try:
    import japanize_matplotlib
except ImportError:
    !pip install japanize_matplotlib
    import japanize_matplotlib

# モーラ分かち書きの定義    

In [None]:
# Mora segmentation of katakana strings.
# source https://qiita.com/shimajiroxyz/items/a133d990df2bc3affc12
import re

# Each condition is one regex alternative.  The alternation is tried left to
# right at every position, so the two-character morae (c1-c3) win over the
# single-character fallback c4.
c1 = '[ウクスツヌフムユルグズヅブプヴ][ァィェォ]'  # u-row kana + small ァ/ィ/ェ/ォ (e.g. ファ, ウィ, ツェ)
c2 = '[イキシチニヒミリギジヂビピ][ャュェョ]'  # i-row kana (including イ) + small ャ/ュ/ェ/ョ (e.g. シャ, イェ)
c3 = '[テデ][ィュ]'  # テ/デ + small ィ/ュ (e.g. ティ, デュ)
# Any single katakana character or the long-vowel mark ー.  The range runs up to
# ヶ (U+30F6) so that ヵ/ヶ -- which kata_chars accepts in readings -- are kept
# instead of being silently dropped by findall (ending at ヴ, U+30F4, lost them).
c4 = '[ァ-ヶー]'

cond = '(' + c1 + '|' + c2 + '|' + c3 + '|' + c4 + ')'
re_mora = re.compile(cond)

# Obsolete kana mapped to their modern equivalents before segmentation.
_OBSOLETE_KANA = str.maketrans({'ヱ': 'エ', 'ヰ': 'イ'})


def moraWakachi(kana_text):
    """Split a katakana string into a list of morae.

    ヱ/ヰ are normalised to エ/イ first.  Characters the pattern does not cover
    (hiragana, kanji, ASCII, ...) are skipped, because ``re.findall`` only
    returns the matched spans.

    Example::

        moraWakachi('シンシュンシャンソンショー')
        -> ['シ', 'ン', 'シュ', 'ン', 'シャ', 'ン', 'ソ', 'ン', 'ショ', 'ー']
    """
    return re_mora.findall(kana_text.translate(_OBSOLETE_KANA))

# 文字の定義，(学習文字，かな，カナ，数字，記号など)

In [None]:
# Grapheme (書記素) inventories.  The katakana inventory is also used as the
# phonological (reading) alphabet.

seed = 42  # random seed (used for the train/validation split)
special_tokens = ['<PAD>', '<EOW>', '<SOW>', '<UNK>']


def _char_range(first, last):
    """Return every character from *first* to *last* inclusive, in code-point order."""
    return ''.join(chr(cp) for cp in range(ord(first), ord(last) + 1))


# Full-width Latin letters and digits (each a contiguous Unicode range).
alphabet_upper_chars = _char_range('Ａ', 'Ｚ')
alphabet_lower_chars = _char_range('ａ', 'ｚ')
num_chars = _char_range('０', '９')
# Hiragana ぁ (U+3041) .. ん (U+3093) and katakana ァ (U+30A1) .. ヶ (U+30F6):
# every small, plain and voiced kana of the basic blocks, in Unicode order.
hira_chars = _char_range('ぁ', 'ん')
kata_chars = _char_range('ァ', 'ヶ')
# Kanji taught in elementary school (学習漢字), one string per school grade.
_gakushu_list = ['一右雨円王音下火花貝学気休玉金九空月犬見五口校左三山四子糸字耳七車手十出女小上森人水正生青石赤先千川早草足村大男竹中虫町天田土二日入年白八百文本名木目夕立力林六',
'引羽雲園遠黄何夏家科歌画会回海絵外角楽活間丸岩顔帰汽記弓牛魚京強教近兄形計元原言古戸午後語交光公工広考行高合国黒今才細作算姉市思止紙寺時自室社弱首秋週春書少場色食心新親図数星晴声西切雪線船前組走多太体台谷知地池茶昼朝長鳥直通弟店点電冬刀東当答頭同道読内南肉馬買売麦半番父風分聞米歩母方北妹毎万明鳴毛門夜野矢友曜用来理里話',
'悪安暗委意医育員飲院運泳駅横屋温化荷界開階寒感漢館岸期起客宮急球究級去橋業局曲銀区苦具君係軽決血研県庫湖向幸港号根祭坂皿仕使始指死詩歯事持次式実写者主取守酒受州拾終習集住重宿所暑助勝商昭消章乗植深申真神身進世整昔全想相送息速族他打対待代第題炭短談着柱注丁帳調追定庭笛鉄転登都度島投湯等豆動童農波配倍箱畑発反板悲皮美鼻筆氷表病秒品夫負部服福物平返勉放味命面問役薬油有由遊予様洋羊葉陽落流旅両緑礼列練路和',
'愛案以位囲胃衣印栄英塩央億加果課貨芽改械害街各覚完官管観関願喜器希旗機季紀議救求泣給挙漁競共協鏡極訓軍郡型径景芸欠結健建験固候功好康航告差最菜材昨刷察札殺参散産残司史士氏試児治辞失借種周祝順初唱松焼照省笑象賞信臣成清静席積折節説戦浅選然倉巣争側束続卒孫帯隊達単置仲貯兆腸低停底的典伝徒努灯働堂得特毒熱念敗梅博飯費飛必標票不付府副粉兵別変辺便包法望牧末満未脈民無約勇要養浴利陸料良量輪類令例冷歴連労老録',
'圧易移因営永衛液益演往応恩仮価可河過賀解快格確額刊幹慣眼基寄規技義逆久旧居許境興均禁句群経潔件券検険減現限個故護効厚構耕講鉱混査再妻採災際在罪財桜雑賛酸師志支枝資飼似示識質舎謝授修術述準序承招証常情条状織職制勢性政精製税績責接設絶舌銭祖素総像増造則測属損態貸退団断築張提程敵適統導銅徳独任燃能破判版犯比肥非備俵評貧婦富布武復複仏編弁保墓報豊暴貿防務夢迷綿輸余預容率略留領',
'異遺域宇映延沿我灰拡閣革割株巻干看簡危揮机貴疑吸供胸郷勤筋敬系警劇激穴憲権絹厳源呼己誤后孝皇紅鋼降刻穀骨困砂座済裁策冊蚕姿私至視詞誌磁射捨尺若樹収宗就衆従縦縮熟純処署諸除傷将障城蒸針仁垂推寸盛聖誠宣専泉洗染善創奏層操窓装臓蔵存尊宅担探誕暖段値宙忠著庁潮頂賃痛展党糖討届難乳認納脳派俳拝背肺班晩否批秘腹奮並閉陛片補暮宝訪亡忘棒枚幕密盟模訳優郵幼欲翌乱卵覧裏律臨朗論']

# Concatenate the per-grade strings into one string of learning kanji.
gakushu_chars = ''.join(_gakushu_list)
_l = list(gakushu_chars)  # the same kanji as a list of single characters

# Input (grapheme) inventory: hiragana, katakana, full-width digits and the
# learning kanji, in that order.  Variants used in other experiments:
#   (hira_chars,)                -- hiragana only
#   (hira_chars, gakushu_chars)  -- without digits
grph_list = [ch for group in (hira_chars, kata_chars, num_chars, gakushu_chars) for ch in group]
print(f'len(grph_list):{len(grph_list)}')
print(f'全書記素 grph_list:{"".join(grph_list)}')

print(f'入力層の素子数 len(grph_list) + len(special_tokens)={len(grph_list) + len(special_tokens)}')
# print(f'出力層の素子数 len(phon_list) + len(special_tokens)={len(phon_list) + len(special_tokens)}')

# NTT 日本語の語彙特性 単語頻度データの読み込み

In [None]:
# Colab only: install a Google Drive downloader, fetch the NTT psylex71
# word-frequency file into /content and print its first lines as a check.
if isColab:
    !pip install googledrivedownloader==0.4
    from google_drive_downloader import GoogleDriveDownloader as gdd
    import os

    # ID of the shared Google Drive file
    file_id = '1eBJDN392BsUckg5LBFbbw5KT9PCsmnxI' # psylex71utf8_.txt
    # https://drive.google.com/file/d/1eBJDN392BsUckg5LBFbbw5KT9PCsmnxI/view?usp=drive_link

    # Local path the downloaded file is saved to (the loading cell reads it from /content)
    destination_path = '/content/psylex71utf8_.txt'
    try:
        print(f"ファイルのダウンロードを開始します (ファイルID: {file_id})...")
        gdd.download_file_from_google_drive(file_id=file_id,
                                            dest_path=destination_path)
                                            # pass unzip=True if file_id refers to a zip file
        print(f"ファイルのダウンロードが完了しました。'{destination_path}' に保存されました。")

        # Sanity check: print the first 5 lines of the downloaded text file
        if os.path.exists(destination_path):
            print("ダウンロードしたファイルの内容 (最初の数行):")
            with open(destination_path, 'r') as f:
                for i in range(5):
                    line = f.readline()
                    if not line:
                        break
                    print(line.strip())
        else:
            print(f"エラー: ダウンロード先のファイル '{destination_path}' が見つかりません。")

    # NOTE(review): errors are only printed, not re-raised, so a failed download
    # only surfaces later when the loading cell opens the file.
    except Exception as e:
        print(f"ファイルのダウンロード中にエラーが発生しました: {e}")

In [None]:
# Print the current working directory (IPython shell escape).
!pwd

In [None]:
# Load the NTT "Lexical Properties of Japanese" word-frequency data (psylex71)
# and count mora frequencies over its katakana readings.
if isColab:
    ntt_base = '/content'
else:
    ntt_base = os.path.join(HOME, 'study/2017_2009AmanoKondo_NTTKanjiData')
psy71_fname = os.path.join(ntt_base, 'psylex71utf8_.txt')  # data file name
# Close the file after reading, and decode explicitly as UTF-8 (as the file
# name indicates) rather than with the platform's default encoding.
with open(psy71_fname, 'r', encoding='utf-8') as _fin:
    psylex71raw = _fin.readlines()
# Split each line on single spaces and keep the first 6 fields (drops the per-year frequencies).
psylex71raw = [lin.strip().split(' ')[:6] for lin in psylex71raw]
print(f'len(psylex71raw):{len(psylex71raw)}')

valid_chars = kata_chars + 'ー'  # katakana plus the long-vowel mark

# Field positions in one psylex71 record: 0 common ID, 1 own ID, 2 written form,
# 3 reading (yomi), 4 part of speech, 5 frequency.  '_mora' (6) is not a raw
# field: it is where the mora segmentation is appended when records are built.
psylex_ids = {'_idx':0, '_idx2':1, '_wrd':2, '_yomi':3, '_pos':4, '_frq':5, '_mora':6}
print(f'psylex_ids{psylex_ids}')

mora_dict = OrderedDict()  # mora -> number of occurrences, in first-seen order

# NOTE(review): line 0 is skipped here (presumably a header), whereas the cell
# that builds Psylex71 iterates over every line -- confirm which is intended.
for x in tqdm(psylex71raw[1:]):
    _yomi = x[psylex_ids['_yomi']]
    # Only readings made entirely of katakana / 'ー' are counted.
    if all(ch in valid_chars for ch in _yomi):
        for m in moraWakachi(_yomi):
            mora_dict[m] = mora_dict.get(m, 0) + 1

print(f'len(mora_dict):{len(mora_dict)}')
mora_list = sorted(mora_dict.keys())

is_graph = False  # set to True to plot the relative frequency of the most frequent morae
print(len(mora_dict), mora_dict)
if is_graph:
    N_mora = sum(mora_dict.values())
    mora_count_sorted = sorted(mora_dict.items(), key=operator.itemgetter(1), reverse=True)
    figsize=(24,4)
    topN = 100
    plt.figure(figsize=figsize)
    plt.bar(range(topN), [x[1]/N_mora for x in mora_count_sorted[:topN]])
    plt.xticks(ticks=range(topN), labels=[c[0] for c in mora_count_sorted[:topN]])

    plt.title(f'モーラ頻度 (上位:{topN} 語)')
    plt.ylabel('相対頻度')
    plt.show()
    #len(mora_dict)

In [None]:
# Build the Psylex71 training dictionary: keep words whose written form has
# exactly maxlen_grph characters, all from grph_list, and whose reading is pure
# katakana; attach the mora segmentation of the reading.
maxlen_grph = 2        # number of graphemes per word (only 2-character words are used)
valid_chars=grph_list  # the grapheme list serves as the list of valid input characters
# Sets for O(1) membership tests (same contents as the lists/strings above).
_valid_grph_set = set(valid_chars)
_kata_set = set(kata_chars)
ng_yomi_words = []
dups_idx = []
_psylex71_ = []

Psylex71 = OrderedDict()
# NOTE(review): unlike the mora-counting cell, line 0 of psylex71raw is not
# skipped here; if it is a header it could be picked up -- verify.
for lin in psylex71raw:
    # Plain ASCII '_' here; the previous version spelled these names with a
    # full-width '＿' (U+FF3F), which only worked via NFKC identifier normalisation.
    wrd = lin[psylex_ids['_wrd']]
    idx = lin[psylex_ids['_idx']]
    yomi = lin[psylex_ids['_yomi']]
    pos = lin[psylex_ids['_pos']]
    frq = lin[psylex_ids['_frq']]

    # Only words of exactly maxlen_grph characters are processed.
    if len(wrd) != maxlen_grph:
        continue

    # Readings containing anything other than katakana are set aside.  Note
    # that the long-vowel mark 'ー' is not in kata_chars, so it counts as NG.
    if not all(p in _kata_set for p in yomi):
        ng_yomi_words.append((wrd, yomi))
        continue

    # Skip words that use graphemes outside the input inventory.
    if not all(ch in _valid_grph_set for ch in wrd):
        continue

    _mora = moraWakachi(yomi)  # mora segmentation of the reading
    if idx in Psylex71:  # duplicate ID: record it (the later entry then overwrites the earlier one)
        dups_idx.append((idx, lin, (Psylex71[idx]['単語'], Psylex71[idx]['ヨミ'], _mora)))

    Psylex71[idx] = {'単語': wrd, 'モーラ':_mora, 'ヨミ': yomi, '品詞': pos,'頻度': frq}
    _psylex71_.append(lin + [_mora])


# Longest reading, measured in morae.
maxlen_phon = max((len(a[psylex_ids['_mora']]) for a in _psylex71_), default=0)

# Report
print(f'読み込んだ psylex71.txt の単語数 len(psylex71raw):{len(psylex71raw)}')
print(f'Psylex71 の総単語数 len(_psylex71_):{len(_psylex71_)}')
print(f'作成したデータベース辞書の項目数 len(Psylex71):{len(Psylex71)}')
print(f'ヨミの最長文字数 maxlen_phon:{maxlen_phon}')
print(f'len(mora_list):{len(mora_list)}')
print(f'Psylex71 におけるカタカナ以外のヨミのある単語数 len(ng_yomi_words):{len(ng_yomi_words)}')
print(f'Psylex71 における ID 番号の重複数 len(dups_idx):{len(dups_idx)}')

# `Psylex71_Dataset` (モデルに Psylex71 を学習させるためのクラス) の作成

In [None]:
import torch
class Psylex71_Dataset(torch.utils.data.Dataset):
    '''PyTorch dataset that lets a neural network model learn Psylex71.

    Item ``i`` is a pair ``(inp, tgt)`` of LongTensors:

    * ``inp`` -- one index per character of the written form, into
      ``grph_list`` (no special tokens are added to the input);
    * ``tgt`` -- ``<SOW>``, one index per mora of the reading, ``<EOW>``, then
      ``<PAD>`` up to ``maxlen_phon``; indices refer to
      ``special_tokens + phon_list``.
    '''

    def __init__(self,
                 dic=Psylex71,
                 grph_list=grph_list,
                 phon_list=mora_list,
                 special_tokens=special_tokens,
                 maxlen_phon=maxlen_phon +2, # +2 for the two special tokens <SOW> and <EOW>
                ):
        super().__init__()
        self.dic = dic
        self.special_tokens = special_tokens
        self.maxlen_phon = maxlen_phon
        self.grph_list = grph_list
        self.phon_list = phon_list
        self.input_cands = grph_list
        # Target vocabulary = special tokens followed by the phonological units.
        # Uses the phon_list argument (the global mora_list used to be
        # hard-coded here, so a custom phon_list was silently ignored).
        self.target_cands = special_tokens + phon_list
        # Written forms (inputs) and their mora segmentations (targets).
        self.inputs = [v['単語'] for v in dic.values()]
        self.targets = [v['モーラ'] for v in dic.values()]

    def __len__(self):
        return len(self.dic)

    def __getitem__(self, idx):
        wrd, morae = self.inputs[idx], self.targets[idx]

        # Input: grapheme indices only, without <SOW>/<EOW>.
        inp = [self.input_cands.index(ch) for ch in wrd]

        # Target: <SOW> + mora indices + <EOW>, right-padded with <PAD>.
        tgt = ([self.target_cands.index('<SOW>')]
               + [self.target_cands.index(m) for m in morae]
               + [self.target_cands.index('<EOW>')])
        tgt += [self.target_cands.index('<PAD>')] * (self.maxlen_phon - len(tgt))

        return torch.LongTensor(inp), torch.LongTensor(tgt)

    def getitem(self, idx):
        '''Return the raw (written form, mora list) pair of item ``idx``.'''
        wrd = self.inputs[idx]
        phn = self.targets[idx]
        return wrd, phn

    def ids2argmax(self, ids):
        '''Return the argmax index of each tensor in ``ids`` as an int32 array.'''
        # .cpu() so tensors living on a GPU/MPS device can be converted to numpy.
        return np.array([torch.argmax(idx).cpu().numpy() for idx in ids], dtype=np.int32)

    def ids2tgt(self, ids):
        '''Convert target indices to target tokens.

        The indices already account for the leading special tokens
        (target_cands starts with them), so they are used as-is; the previous
        version subtracted len(special_tokens) and returned the wrong token.
        '''
        return [self.target_cands[int(idx)] for idx in ids]

    def ids2inp(self, ids):
        '''Convert input indices to graphemes.'''
        return [self.input_cands[idx] for idx in ids]

    def target_ids2target(self, ids:list):
        '''Convert target indices to tokens, stopping at (and including) the first <EOW>.'''
        eow = self.target_cands.index('<EOW>')
        ret = []
        for idx in ids:
            idx = int(idx)  # accepts Python ints, 0-d tensors and numpy scalars alike
            if idx == eow:
                return ret + ['<EOW>']
            ret.append(self.target_cands[idx])
        return ret


# Build the full dataset and print its first 15 items, encoded and decoded.
psylex71_ds = Psylex71_Dataset()

_ds = psylex71_ds
for N in range(15):
    inp, tgt = _ds[N]
    print(f'_ds.ids2inp(inp):{_ds.ids2inp(inp)}',
          f'{inp.numpy()}',
          f'_ds.target_ids2target(tgt):{_ds.target_ids2target(tgt)}',
          f'{tgt.numpy()}')


# Fixed-seed split into training and validation subsets.
# NOTE: only 30% of the words are used for training; adjust train_ratio to change this.
train_ratio = 0.3
train_size = int(len(_ds) * train_ratio)
valid_size = len(_ds) - train_size
train_ds, valid_ds = torch.utils.data.random_split(dataset=_ds, lengths=(train_size, valid_size), generator=torch.Generator().manual_seed(seed))

# Mini-batch size.  The DataLoaders are created further below, with a collate
# function that keeps the items of a batch as lists.
batch_size = 4096

def _collate_fn(batch):
    inps, tgts = list(zip(*batch))
    inps = list(inps)
    tgts = list(tgts)
    return inps, tgts

# DataLoaders that keep each batch as lists of tensors (see _collate_fn);
# the training loop pads them with pad_sequence.
train_dl = torch.utils.data.DataLoader(
    dataset=train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    collate_fn=_collate_fn)

# Sample order does not matter for evaluation, so validation is not shuffled
# (keeps evaluation passes reproducible).
valid_dl = torch.utils.data.DataLoader(
    dataset=valid_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    collate_fn=_collate_fn)

print(f'train_ds.__len__():{train_ds.__len__()}')

# Show the first 15 training items, encoded and decoded.
_ds = train_ds
for N in range(15):
    inp, tgt = _ds[N]
    print(f'_ds.dataset.ids2inp(inp):{_ds.dataset.ids2inp(inp)}',
          f'{inp.numpy()}',
          f'_ds.dataset.target_ids2target(tgt):{_ds.dataset.target_ids2target(tgt)}',
          f'{tgt.numpy()}')


# TLA モデルの定義

In [None]:
class TLA(torch.nn.Module):
    '''Model with one embedding per input position, a shared hidden linear
    layer, and one linear read-out per output (mora) position.

    forward() takes integer grapheme indices indexed as ``inp[:, i]`` (assumed
    shape (batch, inp_len)) and returns sigmoid activations of shape
    (batch, out_len, out_size).
    '''
    def __init__(self,
                 inp_size= (len(grph_list)+len(special_tokens)),
                 inp_len=maxlen_grph,
                 out_size=len(mora_list)+len(special_tokens),
                 out_len=maxlen_phon+2,  # +2 for the <SOW>/<EOW> tokens around each reading
                 hid_size=128,
                 device=device,
                ):
        super().__init__()
        self.inp_size=inp_size
        self.inp_len=inp_len
        self.out_size=out_size
        self.out_len=out_len
        self.hid_size=hid_size

        # ModuleList (not a plain Python list) so the per-position layers are
        # registered sub-modules: they then appear in parameters()/state_dict()
        # and follow .to(), .train() and .eval().  Previously only hid_layer
        # was trainable through the optimizer.
        # NOTE(review): with Psylex71_Dataset's input encoding (no special
        # tokens), index 0 is grph_list[0], so padding_idx=0 keeps that
        # character's embedding fixed at zero -- confirm this is intended.
        self.emb_layers = torch.nn.ModuleList(
            [torch.nn.Embedding(num_embeddings=inp_size, embedding_dim=hid_size, padding_idx=0)
             for _ in range(inp_len)])

        self.hid_layer = torch.nn.Linear(in_features=hid_size * inp_len, out_features=hid_size)

        self.out_layers = torch.nn.ModuleList(
            [torch.nn.Linear(in_features=hid_size, out_features=out_size)
             for _ in range(out_len)])

        self.to(device)

    def forward(self, inp):
        X = inp
        # One embedding lookup per input position, concatenated along features.
        embs = [self.emb_layers[i](X[:, i]) for i in range(X.size(1))]
        X = torch.cat(embs, dim=1)
        X = self.hid_layer(X)         # project to the hidden dimension

        # Independent read-out per output position, squashed with a sigmoid.
        outputs = [torch.sigmoid(layer(X)) for layer in self.out_layers]

        # Stack along dim=1 -> (batch, out_len, out_size), on the model's device.
        # The previous version filled an (out_len, batch, out_size) CPU tensor
        # and reshape()d it, which interleaves batch items and positions
        # instead of transposing them.
        return torch.stack(outputs, dim=1)

tla = TLA(device=device)
tla.eval()

In [None]:
class vanilla_TLA(torch.nn.Module):
    '''Fully connected model on one-hot encoded graphemes.

    forward() one-hot encodes the integer input (inp_size classes), flattens it
    per sample, applies a tanh hidden layer of ``hid_size`` units and a linear
    read-out, and returns sigmoid activations of shape
    (batch, out_len, out_size).
    '''
    def __init__(self,
                 inp_size= (len(grph_list)+len(special_tokens)),
                 inp_len=maxlen_grph,
                 out_size=len(mora_list)+len(special_tokens),
                 out_len=maxlen_phon+2,
                 hid_size=1024,
                 device=device,
                ):
        super().__init__()
        # Hyper-parameters, kept for reference and for reshaping in forward().
        self.inp_size, self.inp_len = inp_size, inp_len
        self.out_size, self.out_len = out_size, out_len
        self.hid_size = hid_size

        n_in = inp_size * inp_len    # width of the flattened one-hot input
        n_out = out_size * out_len   # all output positions in one vector
        self.emb_layer = torch.nn.Linear(in_features=n_in, out_features=hid_size).to(device)
        self.sigmoid = torch.nn.Sigmoid()
        self.tanh = torch.nn.Tanh()
        self.relu = torch.nn.ReLU()  # registered but not used in forward()
        self.out_layer = torch.nn.Linear(in_features=hid_size, out_features=n_out).to(device)

    def forward(self, inp):
        n_batch = inp.size(0)
        one_hot = torch.nn.functional.one_hot(inp, num_classes=self.inp_size)
        flat = one_hot.reshape(n_batch, -1).to(torch.float32)
        hidden = self.tanh(self.emb_layer(flat))
        activations = self.sigmoid(self.out_layer(hidden))
        return activations.reshape(n_batch, self.out_len, self.out_size)

vanilla_tla = vanilla_TLA(device=device)
vanilla_tla.eval()

# 定義したモデルの試用

In [None]:
# Try the (still untrained) vanilla model on a few random training items and
# print the teacher and predicted mora sequences side by side.
from torch.nn.utils.rnn import pad_sequence  # also used by the training loop

_ds = train_ds
vanilla_tla.eval()

# Run N randomly chosen items through the model.
N = 5
ids = np.random.permutation(_ds.__len__())[:N]
with torch.no_grad():  # inference only: no autograd graph is needed
    for idx in ids:
        # Each item is an (input signal, teacher signal) pair of token IDs.
        inp, tch = _ds.__getitem__(idx)
        print(f'idx:{idx}:', f'inp:{inp}', f'tch:{tch}')

        # Token IDs are hard to read, so also print them decoded.
        print(f'_ds.dataset.ids2inp({inp}):{_ds.dataset.ids2inp(inp)}')
        print(f'_ds.dataset.target_ids2target({tch}):{_ds.dataset.target_ids2target(tch)}')

        # Add a batch dimension: (inp_len,) -> (1, inp_len).
        inp = inp.unsqueeze(0).to(device)

        tch_ids = [int(t) for t in tch]
        print('教師:', _ds.dataset.target_ids2target(tch_ids), end=": ")
        print('教師 ids:', tch_ids)

        outs = vanilla_tla(inp)
        out_ids = [int(o.argmax().cpu()) for o in outs[0]]
        print('出力:', _ds.dataset.target_ids2target(out_ids), end=": ")
        print('出力 ids:', out_ids, end="\n===\n")

In [None]:
# Mini-batch training of the vanilla TLA model.

tla = vanilla_tla
# NOTE(review): targets are right-padded with <PAD> (index 0) and
# ignore_index=-1 matches no target, so padding positions are learned like any
# other token.  The model also outputs sigmoid activations, which
# CrossEntropyLoss interprets as logits.  Confirm both are intended.
loss_f = torch.nn.CrossEntropyLoss(ignore_index=-1)
optimizer = torch.optim.Adam(tla.parameters(), lr=1e-3)
epochs = 100

_ds = train_ds
for epoch in range(epochs):

    tla.train()
    sum_loss = 0.
    count = 0

    for inps, tchs in train_dl:
        inps = pad_sequence(inps, batch_first=True).to(device)
        tchs = pad_sequence(tchs, batch_first=True).to(device)
        outs = tla(inps)  # (batch, out_len, out_size)

        # Vectorised replacement for the former per-sample Python loop.
        # CrossEntropyLoss takes (batch, classes, positions) and averages over
        # every position.  Every sample has the same number of positions and
        # no target equals ignore_index, so mean * batch equals the former sum
        # of per-sample mean losses (same gradient scale).
        losses = loss_f(outs.transpose(1, 2), tchs) * tchs.size(0)
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        sum_loss += losses.item()  # one host sync per batch instead of one per sample

        # A word counts as correct only if every output position matches the teacher.
        out_ids = outs.argmax(dim=-1)
        count += int((out_ids == tchs).all(dim=1).sum())

    p_correct = count / _ds.__len__()
    print(f'epoch:{epoch+1:03d}', end=" ")
    print(f'p_correct:{p_correct:5.3f}', end="")
    print(f'=({count:5d}/{_ds.__len__():5d})', end= " ")
    print(f'sum_loss:{sum_loss/_ds.__len__():.3f}')