### 문장생성 구현

In [7]:
import numpy as np
a = np.array([1,2,3,4,5,6,7,8,9,10])
p = np.array([0.6,0.1,0.2,0.1,0,0,0,0,0,0])
sampled = np.random.choice(len(a), size=1, p=p) # 다른 값들도 선택되기 때문에 다양한 문장을 생성할 수 있음~ 
# sampled = np.array(np.argmax(p)).reshape(1) # argmax: 결정적인 방식으로 확률이 가장 높은 값인 1만 추출됨.
# QA에서는 확률이 가장 높은 답을 추출해야 하므로 적절한 방법이지만 문장 생성에선 올바르지 않음. ! 
print(a[sampled])

[1]


In [21]:
# coding: utf-8
import sys
sys.path.append('..') # 부모디렉토리 패스
import numpy as np
from common.functions import softmax
from ch06.rnnlm import Rnnlm # 6장의 Rnnlm 호출
from ch06.better_rnnlm import BetterRnnlm

class MyRnnlmGen(Rnnlm): # Rnnlm 상속
    def generate(self, start_id, skip_ids=None, sample_size=100): # 문장 생성용 함수
        word_ids = [start_id]

        x = start_id
        while len(word_ids) < sample_size:
            x = np.array(x).reshape(1, 1) # 일단 한 단어인 x를 넣으면 다음 단어를 유추
            score = self.predict(x) # 예측 => 소프트 맥스를 태우기 전 linear 값들이 나옴
#             print('score.shape=', score.shape)  # (1,1,10000) : (N,T,V) 팬트리 뱅크는 10000보캡을 학습시켰으니까
#             print('score=', score) # 소프트 맥스를 태우기 전 linear 값들이 나옴
            
            # break; # 여기까지 테스트 해보기 위해
            
            
            p = softmax(score.flatten())        # (10000,)
#             print('p=',p)
#             print('len(p)=', len(p))
#             print('np.argmax(p)=', np.argmax(p))
#             print('np.max(p)=', np.max(p))
#             print(id_to_word[np.argmax(p)]) # <eos>가 출력됨. you다음 eos가 확률적으로 많이 나온다는 뜻. 하지만 불용어라 사용 안 함
#             break
            

            sampled = np.random.choice(len(p), size=1, p=p)
#             sampled = np.array(np.argmax(p)).reshape(1)
#             print('sampled=',sampled)
#             print(id_to_word[sampled[0]]) # you다음 나올 확률은 <eos>가 가장 컸으나 실행할 때마다 다른 단어가 출력됨
                # 따라서 우리는 매번 다른 문장을 생성해낼 수 있음.
#             break;
        
            if (skip_ids is None) or (sampled not in skip_ids):
                x = sampled
                word_ids.append(int(x))
#             print(word_ids)
        return word_ids

    def get_state(self):
        return self.lstm_layer.h, self.lstm_layer.c

    def set_state(self, state):
        self.lstm_layer.set_state(*state)


class BetterRnnlmGen(BetterRnnlm):
    def generate(self, start_id, skip_ids=None, sample_size=100):
        word_ids = [start_id]

        x = start_id
        while len(word_ids) < sample_size:
            x = np.array(x).reshape(1, 1)
            score = self.predict(x).flatten()
            p = softmax(score).flatten()

            sampled = np.random.choice(len(p), size=1, p=p)
            if (skip_ids is None) or (sampled not in skip_ids):
                x = sampled
                word_ids.append(int(x))

        return word_ids

    def get_state(self):
        states = []
        for layer in self.lstm_layers:
            states.append((layer.h, layer.c))
        return states

    def set_state(self, states):
        for layer, state in zip(self.lstm_layers, states):
            layer.set_state(*state)

### 문장생성을 위한 코드

In [25]:
# coding: utf-8
import sys
sys.path.append('..')
from rnnlm_gen import RnnlmGen
from dataset import ptb


corpus, word_to_id, id_to_word = ptb.load_data('train')
vocab_size = len(word_to_id)
corpus_size = len(corpus)

model = MyRnnlmGen()
model.load_params('../ch06/Rnnlm.pkl')

# start 문자와 skip 문자 설정
start_word = 'you'
start_id = word_to_id[start_word]
print(start_id)
skip_words = ['N', '<unk>', '$'] # 숫자, 희소단어, 화폐 = [27, 26, 416]
skip_ids = [word_to_id[w] for w in skip_words]
print(skip_ids) # [27, 26, 416] 는 패스하도록 
# 문장 생성
word_ids = model.generate(start_id, skip_ids, 100)
txt = ' '.join([id_to_word[i] for i in word_ids])
txt = txt.replace(' <eos>', '.\n')
print(txt)

316
[27, 26, 416]
you act.
 ignorance commonwealth they recapitalization grid pass comes.
 for end for last executive did request any a capitalize since imports or new funds hole rate.
 offer failed to wildly and trade corporate musical.
 polaroid said be ross a year towers n't not direct they unless about than drunk as are workers there and cypress benefit recover plc their fine and disputed a graduate also.
 including anywhere fell to native however gifts waxman investors under on spinoff pennzoil pour chaos though rather barrels than buffett a month.
 support not mr. financial to encouragement you


### 더 좋은 문장으로

In [30]:
# coding: utf-8
import sys
sys.path.append('..')
from common.np import *
from rnnlm_gen import BetterRnnlmGen
from dataset import ptb


corpus, word_to_id, id_to_word = ptb.load_data('train')
vocab_size = len(word_to_id)
corpus_size = len(corpus)


model = BetterRnnlmGen()
model.load_params('../ch06/BetterRnnlm/BetterRnnlm.pkl')

# start 문자와 skip 문자 설정
# start_word = 'you'
# start_id = word_to_id[start_word]
# skip_words = ['N', '<unk>', '$']
# skip_ids = [word_to_id[w] for w in skip_words]
# # 문장 생성
# word_ids = model.generate(start_id, skip_ids)
# txt = ' '.join([id_to_word[i] for i in word_ids])
# txt = txt.replace(' <eos>', '.\n')

# print(txt)


model.reset_state()

# start_words = 'the meaning of life is' # 입력 문장. 어떤 걸 넣어도 상관없음~
start_words = 'you'
start_ids = [word_to_id[w] for w in start_words.split(' ')]
print(start_ids)
print(start_ids[:-1])

for x in start_ids[:-1]:
    x = np.array(x).reshape(1, 1)
    model.predict(x)

# word_ids = model.generate(start_ids[-1], skip_ids)
# word_ids = start_ids[:-1] + word_ids
# txt = ' '.join([id_to_word[i] for i in word_ids])
# txt = txt.replace(' <eos>', '.\n')
# print('-' * 50)
print(txt)


[316]
[]
you act.
 ignorance commonwealth they recapitalization grid pass comes.
 for end for last executive did request any a capitalize since imports or new funds hole rate.
 offer failed to wildly and trade corporate musical.
 polaroid said be ross a year towers n't not direct they unless about than drunk as are workers there and cypress benefit recover plc their fine and disputed a graduate also.
 including anywhere fell to native however gifts waxman investors under on spinoff pennzoil pour chaos though rather barrels than buffett a month.
 support not mr. financial to encouragement you
