# Chpater 7 : Generate Sentence using RNN

In [24]:
import sys
sys.path.append('..')
import numpy as np
from common.functions import softmax
from common.rnnlm import Rnnlm
from common.better_rnnlm import BetterRnnlm

class RnnlmGen(Rnnlm):
    def generate(self, start_id, skip_ids = None, sample_size=100):
        word_ids = [start_id]
        x = start_id
        while len(word_ids) < sample_size:
            x = np.array(x).reshape(1, 1)
            score = self.predict(x)
            p = softmax(score.flatten())
            sampled = np.random.choice(len(p), size=1, p=p)
            if (skip_ids is None) or (sampled not in skip_ids):
                x = sampled
                word_ids.append(int(x))
        
        return word_ids

In [25]:
from dataset import ptb

corpus, word_to_id, id_to_word = ptb.load_data('train')
vocab_size = len(word_to_id)
corpus_size = len(corpus)

model = RnnlmGen()
model.load_params('Rnnlm.pkl')

start_word = 'you'
start_id = word_to_id[start_word]
skip_words= ['N', '<unk>', '$']
skip_ids = [word_to_id[w] for w in skip_words]

word_ids = model.generate(start_id, skip_ids)
txt = ' '.join([id_to_word[i] for i in word_ids])
txt = txt.replace('<eos>', '.\n')
print(txt)

you are n't considering mr. deaver .
 mr. hahn can drexel has no reporter high in the panel .
 mr. roman 's speech was named chairman of britain 's leading office has said .
 moreover the administration also relied all usx as it will repurchase the raising results for loans to slow .
 some unions have bought support for assets to strike and a paying different support and is going to go off .
 whatever the dollar jumped to five billion yen .
 frankfurt 's stock surged sharply in august over the u.s. non-u.s. capital markets according to dow jones


Better RNNLM 사용한 generate

In [29]:
class BetterRnnlmGen(BetterRnnlm):
    def generate(self, start_id, skip_ids=None, sample_size=100):
        word_ids = [start_id]

        x = start_id
        while len(word_ids) < sample_size:
            x = np.array(x).reshape(1, 1)
            score = self.predict(x).flatten()
            p = softmax(score).flatten()

            sampled = np.random.choice(len(p), size=1, p=p)
            if (skip_ids is None) or (sampled not in skip_ids):
                x = sampled
                word_ids.append(int(x))

        return word_ids

    def get_state(self):
        states = []
        for layer in self.lstm_layers:
            states.append((layer.h, layer.c))
        return states

    def set_state(self, states):
        for layer, state in zip(self.lstm_layers, states):
            layer.set_state(*state)

In [33]:
import sys
sys.path.append('..')
from common.np import *
from dataset import ptb


corpus, word_to_id, id_to_word = ptb.load_data('train')
vocab_size = len(word_to_id)
corpus_size = len(corpus)


model = BetterRnnlmGen()
model.load_params('./BetterRnnlm.pkl')

# start 문자와 skip 문자 설정
start_word = 'you'
start_id = word_to_id[start_word]
skip_words = ['N', '<unk>', '$']
skip_ids = [word_to_id[w] for w in skip_words]
# 문장 생성
word_ids = model.generate(start_id, skip_ids, sample_size=50)
txt = ' '.join([id_to_word[i] for i in word_ids])
txt = txt.replace(' <eos>', '.\n')

print(txt)

# model.reset_state()

# start_words = 'the meaning of life is'
# start_ids = [word_to_id[w] for w in start_words.split(' ')]

# for x in start_ids[:-1]:
#     x = np.array(x).reshape(1, 1)
#     model.predict(x)

# word_ids = model.generate(start_ids[-1], skip_ids)
# word_ids = start_ids[:-1] + word_ids
# txt = ' '.join([id_to_word[i] for i in word_ids])
# txt = txt.replace(' <eos>', '.\n')
# print('-' * 50)
# print(txt)

you shifted many people.
 these advantages are not entirely on boys and someone believes is that it 's unfortunate to resist a new cement name but they describe us now favors personal reasons.
 apparently the best firms are attractive to medium-sized policies of cigna and pork prices.



In [36]:
txt = '''you shifted many people.
 these advantages are not entirely on boys and someone believes is that it 's unfortunate to resist a new cement name but they describe us now favors personal reasons.
 apparently the best firms are attractive to medium-sized policies of cigna and pork prices.'''

len(txt.split())

# <EOS>를 replace 하여 47개

47