# Training of text generation models

In [None]:
%matplotlib inline
from neurowriter.corpus import SingleTxtCorpus, MultiLineCorpus
from neurowriter.tokenizer import CharTokenizer, WordTokenizer, SubwordTokenizer
from neurowriter.models import DilatedConvModel, WavenetModel, StackedLSTMModel, LSTMModel, SmallLSTMModel

## Global config

Name of the corpus file, without the `.txt` extension (the file is read from the `corpus/` directory)

In [None]:
# Base name of the corpus: read from corpus/<corpusname>.txt, and also used to
# name the saved encoder (<corpusname>.enc) and model (<corpusname>.h5).
corpusname = "toyseries"

Corpus loader class to use

In [None]:
# Loader *class* (not an instance); it is instantiated in the "Load corpus" section.
# Other loader imported above: SingleTxtCorpus.
corpus = MultiLineCorpus

Tokenizer class to use

In [None]:
# Tokenizer *class* (not an instance); it is instantiated in the "Encoding" section.
# Other tokenizers imported above: WordTokenizer, SubwordTokenizer.
tokenizer = CharTokenizer

Network architecture class to use

In [None]:
# Model architecture class handed to hypertrain in the "Model training" section.
# Other architectures imported above: DilatedConvModel, WavenetModel,
# StackedLSTMModel, LSTMModel.
architecture = SmallLSTMModel

Number of hyperparameter optimization trials (at least 15 recommended)

In [None]:
# Passed to hypertrain as n_calls in the "Model training" section.
hypertrials = 15

### Process config

Get all relevant file names

In [None]:
# Every artifact name is derived from the corpus base name:
#   <corpusname>.enc -> saved encoder, <corpusname>.h5 -> saved model,
#   corpus/<corpusname>.txt -> input text.
encodername, modelname = (corpusname + suffix for suffix in ('.enc', '.h5'))
corpusfile = 'corpus/' + corpusname + '.txt'

## Load corpus

In [None]:
# Instantiate the configured corpus loader and read the corpus file.
# `corpus` starts out holding the loader *class* chosen in the config section,
# and this cell rebinds it to an instance. Resolve the class first so that
# re-running the cell builds a fresh loader instead of trying to call the
# previously created instance.
corpusclass = corpus if isinstance(corpus, type) else type(corpus)
corpus = corpusclass()
corpus.load(corpusfile)

In [None]:
# Preview the start of the first corpus entry (first 1000 elements;
# presumably characters of the first document — confirm against the loader).
corpus[0][0:1000]

## Encoding

In [None]:
from neurowriter.encoding import Encoder

# Build the encoder from the loaded corpus and the configured tokenizer, then
# persist it to `encodername`.
# `tokenizer` starts out holding the tokenizer *class* chosen in the config
# section, and this cell rebinds it to an instance. Resolve the class first so
# that re-running the cell creates a fresh tokenizer instead of trying to call
# the previously created instance.
tokenizerclass = tokenizer if isinstance(tokenizer, type) else type(tokenizer)
tokenizer = tokenizerclass()
encoder = Encoder(corpus, tokenizer)
encoder.save(encodername)

In [None]:
# Inspect the encoder's token -> index mapping built from the corpus.
encoder.char2index

## Model training

Train the generator model, trying different hyperparameters and selecting the model that produces the lowest loss on a validation split of the data.

Note that this may take a very long time, so temporary versions of the model are saved during the optimization.

In [None]:
from neurowriter.optimizer import hypertrain

# Hyperparameter search over `architecture`, run for `hypertrials` calls on the
# encoded corpus. `savemodel` presumably names the file where intermediate
# models are checkpointed during the search — confirm in neurowriter.optimizer.
model, train_history = hypertrain(
    architecture,
    encoder,
    corpus,
    n_calls=hypertrials,
    savemodel=modelname,
    verbose=1,
)

# Persist the model returned by the search.
model.save(modelname)

## Generation test

In [None]:
from neurowriter.writer import Writer
from neurowriter.encoding import END

# Text generator over the trained model: one beam, batch size 1, creativity 0.5.
writer = Writer(model, encoder, beamsize=1, batchsize=1, creativity=0.5)

# The seed is the tokenization of an empty prompt, joined back into text.
tokens = encoder.tokenizer.transform("")
seedtxt = "".join(tokens)
print("Seed:", seedtxt)
print("Generated:")
print(seedtxt, end='')

# Stream generated tokens as they arrive. After an END token, print two
# newlines to separate consecutive outputs. The loop stops only when
# writer.generate is exhausted.
for token in writer.generate(seedtxt):
    print(token, end='')
    if token != END:
        continue
    print('\n')