In [35]:
import numpy as np 
from janome.tokenizer import Tokenizer
from gensim.models.keyedvectors import KeyedVectors
import torch
from torch import nn
import pickle

In [71]:
# Select the first CUDA GPU when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device('cpu')

In [17]:
# Split a sample Japanese sentence into surface-form tokens with janome.
tkz = Tokenizer()
s = '私は犬が好き。'
# wakati=True makes tokenize() yield plain strings; materialise them as a list.
ws = list(tkz.tokenize(s, wakati=True))

In [18]:
# Display the token list produced by the tokenizer.
ws

['私', 'は', '犬', 'が', '好き', '。']

In [19]:
# Load pretrained word vectors stored in the binary word2vec format.
w2v = KeyedVectors.load_word2vec_format(
    './hidden_files/entity_vector.model.bin',
    binary=True,
)

In [20]:
# Look up the embedding of every token and stack them into one (seq_len, dim) array.
# Building a single ndarray first avoids the slow path (and the UserWarning) that
# torch.tensor() hits when it is handed a Python list of numpy arrays. The tokens
# are iterated directly; wrapping them in np.array() is unnecessary.
# NOTE(review): a token missing from w2v raises KeyError here — confirm that is intended.
xn = torch.tensor(np.array([w2v[w] for w in ws]))
print(xn.shape)

# Add a leading batch dimension so the tensor can be fed to a batch_first LSTM.
xn = xn.unsqueeze(0)

print(xn.shape)

torch.Size([6, 200])
torch.Size([1, 6, 200])


In [31]:
# LSTM: run the embedded sentence through a single-layer LSTM and inspect the shapes.

lstm = nn.LSTM(input_size=200, hidden_size=200, batch_first=True)

# Random initial hidden and cell states; h0 is drawn before c0.
state_shape = (1, 1, 200)
h0 = torch.randn(*state_shape)
c0 = torch.randn(*state_shape)

yn, (hn, cn) = lstm(xn, (h0, c0))

# Print the shapes of the per-step outputs, final hidden state and final cell state.
for result in (yn, hn, cn):
    print(result.shape)

torch.Size([1, 6, 200])
torch.Size([1, 1, 200])
torch.Size([1, 1, 200])


tensor([[-0.9155,  0.9778, -0.4378,  ...,  3.0303, -1.9216,  0.0984],
        [ 0.8603,  0.3643, -1.3381,  ..., -0.4042,  1.1210, -0.4190],
        [ 0.0810,  0.4326,  0.3752,  ...,  1.3360,  0.3233, -1.9915],
        [ 0.4563,  0.2203, -2.7309,  ..., -2.1273,  0.1671, -0.2642],
        [-1.0056, -3.1306, -0.8892,  ..., -1.6494,  0.1720, -1.8252],
        [ 1.4522,  0.0427, -1.3146,  ..., -0.6015, -0.0347,  0.4128]])

In [62]:
# Load the pickled dictionary.
# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
with open('./hidden_files/LSTM/dic.pkl', 'br') as dic_file:
    dic = pickle.load(dic_file)

# Peek at the first entry only; nothing is printed when the dictionary is empty.
first_item = next(iter(dic.items()), None)
if first_item is not None:
    print(dict([first_item]))

{'万能': 1}


In [73]:
# Inspect the training data: load the three pickled objects from disk.
# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.

def _unpickle(path):
    """Return the object stored in the pickle file at `path`."""
    with open(path, 'br') as handle:
        return pickle.load(handle)


xdata = _unpickle('./hidden_files/LSTM/xtrain.pkl')
ydata = _unpickle('./hidden_files/LSTM/ytrain.pkl')
labels = _unpickle('./hidden_files/LSTM/label.pkl')

In [68]:
# Show the element at index 6 of the training inputs and of the matching targets.
for caption, sequences in (
    ('訓練データのバッチ６番目は：', xdata),
    ('正解データのバッチ６番目は：', ydata),
):
    print(caption, sequences[6])

訓練データのバッチ６番目は： [74, 75, 2, 60, 76, 62, 5, 6]
正解データのバッチ６番目は： [9, 0, 1, 8, 5, 7, 3, 4]


In [70]:
class MyLSTM(nn.Module):
    """Embedding -> single-layer LSTM -> linear layer applied at every time step.

    Args:
        voccsize: Number of rows in the embedding table.
        posn: Number of output features produced for each time step.
        hdim: Size of both the embedding vectors and the LSTM hidden state.
    """

    def __init__(self, voccsize, posn, hdim):
        super().__init__()
        self.embed = nn.Embedding(voccsize, hdim)
        self.lstm = nn.LSTM(hdim, hdim, batch_first=True)
        self.ln = nn.Linear(hdim, posn)

    def forward(self, x):
        """Compute per-step scores for a batch of index sequences.

        Args:
            x: Integer tensor of embedding indices; the LSTM is batch_first,
                so the expected layout is (batch, seq_len).

        Returns:
            Tensor of shape (batch, seq_len, posn).
        """
        ex = self.embed(x)
        # nn.LSTM returns (output, (h_n, c_n)). Only the per-step output tensor
        # can be fed to the linear layer; passing the whole tuple raises TypeError.
        lo, _ = self.lstm(ex)
        out = self.ln(lo)
        return out
        

In [None]:
# dic (word -> id) does not contain id 0, so one extra embedding row is reserved for it.
vocab_size = len(dic) + 1
label_count = len(labels)
net = MyLSTM(vocab_size, label_count, 100)


In [75]:
# Display the full dictionary loaded from dic.pkl (string keys mapped to integer values).
dic

{'万能': 1,
 'で': 2,
 'は': 3,
 'ない': 4,
 'です': 5,
 '。': 6,
 'も': 7,
 '、': 8,
 '他': 9,
 'の': 10,
 '書籍': 11,
 '勉強': 12,
 'し': 13,
 'て': 14,
 'い': 15,
 '行き詰っ': 16,
 'た': 17,
 'とき': 18,
 'に': 19,
 'すごく': 20,
 'コンパクト': 21,
 'まとまっ': 22,
 'いる': 23,
 '便利': 24,
 'これ': 25,
 '一': 26,
 '冊': 27,
 'なん': 28,
 'と': 29,
 'か': 30,
 'しよう': 31,
 'する': 32,
 '危険': 33,
 'が': 34,
 '独学': 35,
 '暗中': 36,
 '模索': 37,
 'な': 38,
 '人': 39,
 'けっこう': 40,
 '役': 41,
 '立つ': 42,
 '思い': 43,
 'ます': 44,
 'SketchUp': 45,
 '本': 46,
 'なかなか': 47,
 '出': 48,
 '探し': 49,
 'まし': 50,
 '基本': 51,
 '操作': 52,
 'ほとんど': 53,
 '網羅': 54,
 '初心': 55,
 '者': 56,
 '特に': 57,
 '建築': 58,
 '系': 59,
 'とても': 60,
 '使い': 61,
 'やすい': 62,
 'ある': 63,
 '程度': 64,
 '判っ': 65,
 '機能': 66,
 '別': 67,
 '構成': 68,
 'さ': 69,
 'れ': 70,
 'リファレンス': 71,
 '的': 72,
 '使え': 73,
 'フル': 74,
 'カラー': 75,
 '見': 76,
 'セイ': 77,
 'リング': 78,
 '中': 79,
 '万一': 80,
 'こと': 81,
 'あっ': 82,
 '参考': 83,
 'なる': 84,
 '思っ': 85,
 '読ん': 86,
 'だ': 87,
 '「': 88,
 'のど': 89,
 '渇い': 90,
 '最後': 91,
 '水': 92,
 