In [46]:
from dotenv import load_dotenv; load_dotenv()
import sys, os; sys.path.append(os.getenv('MP'))

import pickle
import numpy as np
import nn

In [47]:
with open(r'models/pico_seuss_001/model.pkl', 'rb') as f:
  model = pickle.load(f)

with open(r'data/word_processors.pkl', 'rb') as f:
  tokenizer, vectorizer = pickle.load(f)

In [99]:
def prepare_input(input_sentence, target_seq_len):
  # Tokenize and vectorize
  tokens = tokenizer.transform(input_sentence)
  tokens = [' '.join(t).split() for t in tokens]
  vectors = vectorizer.transform(tokens)
  vectors = np.array(vectors)

  # Pad to target seq len
  if vectors.shape[-1] < target_seq_len:
    pad_width = 16 - vectors[0].shape[0]
    vectors = np.pad(vectors, ((0, 0), (0, pad_width)))

  # Add batch dimension (batch, doc, seq)
  input = vectors[np.newaxis,:,-target_seq_len:]

  return input

In [141]:
test_text = ['I went to the park and saw a cat. I will go to the fox.']
input = prepare_input(test_text, target_seq_len=16)

In [157]:
def generate(data, gen_len):
  input = data
  seq_len = input.shape[-1]

  # Create an output of size gen_length + input seq_len
  output = np.empty((gen_len + seq_len), dtype=int)
  output[:seq_len] = data

  for i in range(gen_len):
    next_word = np.argmax(model.predict(input)[:,:,1:])+1 # offset to fix the fact we avoid index 0 <UNK>
    input[:,:,:seq_len-1] = input[:,:,1:] # shift input by 1
    input[:,:,-1] = next_word # add new word to this space
    output[i + seq_len] = next_word
  return output

In [167]:
output = generate(input, 20)
' '.join(vectorizer.inverse_transform([output])[0])

'cat will would cat cat cat cat there there there . some will will saw cat cat will would cat cat cat cat there there there . some will will saw cat cat will would cat'