<a href="https://colab.research.google.com/github/ShinAsakawa/ShinAsakawa.github.io/blob/master/2023notebooks/2023_0421s2p.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 意味 s2p, p2s の実装

* See Harm & Seidenberg (2004).
* See also https://doi.org/10.1101/2021.04.15.440047 and https://doi.org/10.1101/708156

<!-- <img src="2004Harm_Seidenberg_fig4c.svg">
<img src="2004Harm_Seidenberg_fig4d.svg"> -->

In [None]:
# Environment setup cell: choose the torch device, detect Google Colab, and
# make `mecab_wakati` / `mecab_yomi` available in either environment.
%config InlineBackend.figure_format = 'retina'
import torch
# Use the first CUDA device when torch reports CUDA is available, else the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

from IPython import get_ipython
# True when the string form of the active IPython shell mentions 'google.colab'.
isColab =  'google.colab' in str(get_ipython())

if isColab:

    # Text is not colored on Colab unless termcolor is downgraded
    !pip install --upgrade termcolor==1.1
    import termcolor    

    # Mount Google Drive so that results can be saved
    import google.colab
    google.colab.drive.mount('/content/drive/')
    
    # Show GPU information
    !nvidia-smi -L

    #!pip install ipynbname --upgrade > /dev/null

if isColab:
    # To run MeCab on Colab, the C compiler is invoked to build MeCab,
    # so this step takes a while.
    !apt install aptitude
    !aptitude install mecab libmecab-dev mecab-ipadic-utf8 git make curl xz-utils file -y
    !pip install mecab-python3==0.7
    !pip install jaconv
    
    import MeCab
    # Bound `parse` methods of Taggers created with the '-Owakati' and '-Oyomi'
    # options (presumably word segmentation and reading output -- confirm
    # against the MeCab documentation).
    mecab_wakati = MeCab.Tagger('-Owakati').parse
    mecab_yomi = MeCab.Tagger('-Oyomi').parse
    
else:
    # Outside Colab the same two names are taken from the ccap package.
    from ccap.mecab_settings import yomi as mecab_yomi
    from ccap.mecab_settings import wakati as mecab_wakati


# From here on: collect and display provenance information about this run
# (date, host, user, notebook path, torch version).
from termcolor import colored

import platform
# Host name truncated at the first dot.
HOSTNAME = platform.node().split('.')[0]

import os
# Raises KeyError if the HOME environment variable is not set.
HOME = os.environ['HOME']

# Install ipynbname on demand when it is not importable.
try:
    import ipynbname
except ImportError:
    !pip install ipynbname
    import ipynbname
# Notebook path as reported by ipynbname, with the "<HOME>/" prefix removed.
FILEPATH = str(ipynbname.path()).replace(HOME+'/','')

import pwd
# Login name belonging to the effective user id of this process.
USER=pwd.getpwuid(os.geteuid())[0]

from datetime import date
TODAY=date.today()

import torch
TORCH_VERSION = torch.__version__

# Print every collected value in bold green.
color = 'green'
print('日付:',colored(f'{TODAY}', color=color, attrs=['bold']))
print('HOSTNAME:',colored(f'{HOSTNAME}', color=color, attrs=['bold']))
print('ユーザ名:',colored(f'{USER}', color=color, attrs=['bold']))
print('HOME:',colored(f'{HOME}', color=color,attrs=['bold']))
print('ファイル名:',colored(f'{FILEPATH}', color=color, attrs=['bold']))
print('torch.__version__:',colored(f'{TORCH_VERSION}', color=color, attrs=['bold']))

In [None]:
# Imports for the data-preparation cells. Each try/except below installs or
# clones the missing package and then retries the import once.
from tqdm.notebook import tqdm
from copy import deepcopy
import numpy as np

# Imported for its side effects only; the name is not used in this cell
# (presumably it enables Japanese glyphs in matplotlib -- confirm).
try:
    import japanize_matplotlib
except ImportError:
    !pip install japanize_matplotlib
    import japanize_matplotlib

# Vocabulary data comes from VDRJ; semantic data uses word2vec
try:
    from ccap import ccap_w2v
except ImportError:
    # Clone the repository into the working directory so `ccap` is importable.
    !git clone https://github.com/project-ccap/ccap.git
    from ccap import ccap_w2v

try:
    from RAM import VDRJ_Dataset
except ImportError:
    # Clone the repository into the working directory so `RAM` is importable.
    !git clone https://github.com/ShinAsakawa/RAM.git
    from RAM import VDRJ_Dataset
# Base vocabulary dataset (constructed with max_words=30000) and the word2vec
# vectors exposed by ccap_w2v.
# NOTE(review): isColab=False is passed regardless of the detected
# environment -- confirm this is intended when running on Colab.
vdrj_ds = VDRJ_Dataset(max_words=30000)
w2v = ccap_w2v(isColab=False).w2v

# Dataset derived from vdrj_ds. It is an alias, not a copy (an earlier attempt
# with deepcopy(vdrj_ds) reportedly did not work), so the flags written below
# land in vdrj_ds.data_dict as well.
vdrj_w2v = vdrj_ds
for entry in tqdm(vdrj_w2v.data_dict.values()):
    # Record whether the lexeme / orthographic form is present in w2v.
    entry['lex_w2v'] = entry['lexeme'] in w2v
    entry['orth_w2v'] = entry['orth'] in w2v


In [None]:
#type(w2v)
import gensim

In [None]:
#help(VDRJ_Dataset)
class vdrj_w2v_Dataset(VDRJ_Dataset):
    """VDRJ entries restricted to those that have a word2vec vector.

    Each kept entry is a deep copy of the corresponding ``vdrj_ds.data_dict``
    value with two extra keys:

    * ``'sem'``     -- the word2vec vector, and
    * ``'defined'`` -- ``'lex'`` or ``'orth'``, telling which form was looked
      up in ``w2v``. The lexeme wins when both forms are present.

    Entries are re-indexed ``0..len-1`` in iteration order.

    NOTE(review): ``self.orth_list`` is overwritten with the list of *words*
    found in ``w2v``. If the parent class uses ``orth_list`` as its grapheme
    vocabulary (e.g. containing ``'<EOW>'``), ``target='orth'`` lookups in
    ``__getitem__`` may fail -- confirm against ``VDRJ_Dataset``.
    NOTE(review): ``super().__init__()`` is called with the parent's default
    arguments, independently of how ``vdrj_ds`` was built -- confirm.
    """

    def __init__(self,
                 source:str="orth", # ['orth', 'sem', 'phon']
                 target:str="sem",  # ['orth', 'sem', 'phon']
                 vdrj_ds:VDRJ_Dataset=vdrj_ds,
                 w2v:gensim.models.keyedvectors.KeyedVectors=w2v):
        """Build the filtered dataset and bind the source/target modalities.

        Args:
            source: input modality, one of ``'orth'``, ``'sem'``, ``'phon'``.
            target: output modality, one of ``'orth'``, ``'sem'``, ``'phon'``.
            vdrj_ds: dataset whose ``data_dict`` is filtered and extended.
            w2v: word-vector store supporting ``in``, indexing and
                ``vector_size``.
        """
        super().__init__()
        self.w2v = w2v

        self.source = source
        self.target = target

        # Kept as a reference (not a copy) to the dataset being extended.
        self.vdrj_ds = vdrj_ds

        data_dict = {}
        orth_list = []
        # Mirror of orth_list for O(1) membership tests; the list itself keeps
        # the original insertion order (and any duplicate orth forms).
        seen = set()
        for v in tqdm(vdrj_ds.data_dict.values()):
            lex = v['lexeme']
            orth = v['orth']
            has_orth = orth in w2v
            has_lex = lex in w2v

            if has_orth:
                orth_list.append(orth)
                seen.add(orth)
            if has_lex and lex not in seen:
                orth_list.append(lex)
                seen.add(lex)

            # The lexeme takes priority over the orthographic form.
            if has_lex:
                defined, key = 'lex', lex
            elif has_orth:
                defined, key = 'orth', orth
            else:
                continue  # no vector for either form: drop the entry

            entry = deepcopy(v)
            entry['defined'] = defined
            entry['sem'] = w2v[key]
            data_dict[len(data_dict)] = entry

        self.data_dict = data_dict
        self.orth_list = orth_list
        self.set_source_and_target_from_params(source=source, target=target)

    def __len__(self):
        """Return the number of entries that have a word2vec vector."""
        return len(self.data_dict)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` pair for entry ``idx``.

        ``x`` is a float32 tensor of the semantic vector when the source is
        ``'sem'``; otherwise it is the raw value stored under the source key.
        ``y`` is a float32 tensor of the semantic vector when the target is
        ``'sem'``; otherwise it is the list of target token ids followed by
        the id of ``'<EOW>'``.
        """
        entry = self.data_dict[idx]

        if self.source == 'sem':
            x = torch.tensor(entry['sem'], dtype=torch.float32)
        else:
            # NOTE(review): the source side is returned untokenized, unlike
            # the target side -- confirm this asymmetry is intended.
            x = entry[self.source]

        if self.target == 'sem':
            # A semantic target yields the vector itself, not the source field.
            y = torch.tensor(entry['sem'], dtype=torch.float32)
        else:
            y_ids = self.target_tkn2ids(entry[self.target])
            y = y_ids + [self.target_list.index('<EOW>')]
        return x, y

    def _lookup_modality(self, modality):
        """Return ``(token_list, maxlen, ids2tkn, tkn2ids)`` for ``modality``.

        Returns ``None`` for an unrecognised modality name.
        """
        if modality == 'orth':
            return (self.orth_list, self.orth_maxlen,
                    self.orth_ids2tkn, self.orth_tkn2ids)
        if modality == 'phon':
            return (self.phon_list, self.phon_maxlen,
                    self.phon_ids2tkn, self.phon_tkn2ids)
        if modality == 'sem':
            # Semantic vectors have no tokenizer; their length is the
            # dimensionality of the word vectors.
            return (self.orth_list, self.w2v.vector_size, None, None)
        return None

    def set_source_and_target_from_params(self, source:str='orth', target:str='phon'):
        """Bind the ``source_*`` and ``target_*`` attributes.

        An unrecognised ``source`` returns immediately without touching the
        target attributes; an unrecognised ``target`` leaves the target
        attributes unchanged.
        """
        src = self._lookup_modality(source)
        if src is None:
            return None
        (self.source_list, self.source_maxlen,
         self.source_ids2tkn, self.source_tkn2ids) = src
        if source == 'sem':
            # Older spelling, kept so existing readers of it keep working.
            self.source_max_len = self.source_maxlen

        tgt = self._lookup_modality(target)
        if tgt is not None:
            (self.target_list, self.target_maxlen,
             self.target_ids2tkn, self.target_tkn2ids) = tgt
            if target == 'sem':
                # Older spelling, kept so existing readers of it keep working.
                self.target_max_len = self.target_maxlen
        
# One dataset per (source, target) pairing of the three modalities, built in
# the order sem -> orth -> phon for each source. In the variable names
# o = orth, s = sem, p = phon.
o2s_ds, o2o_ds, o2p_ds = [
    vdrj_w2v_Dataset(vdrj_ds=vdrj_ds, source='orth', target=tgt)
    for tgt in ('sem', 'orth', 'phon')
]
s2o_ds, s2s_ds, s2p_ds = [
    vdrj_w2v_Dataset(vdrj_ds=vdrj_ds, source='sem', target=tgt)
    for tgt in ('orth', 'sem', 'phon')
]
p2o_ds, p2s_ds, p2p_ds = [
    vdrj_w2v_Dataset(vdrj_ds=vdrj_ds, source='phon', target=tgt)
    for tgt in ('orth', 'sem', 'phon')
]

In [None]:
# Sanity check: print the size, string form and target modality of every
# dataset built above.
for ds in [o2s_ds, o2o_ds, o2p_ds, s2o_ds, s2s_ds, s2p_ds, p2o_ds, p2s_ds, p2p_ds]:
    # Report the dataset being iterated, not always o2s_ds.
    print(len(ds), str(ds))
    print(ds.target)