# Toy training of text generation models

In [None]:
%matplotlib inline

## Global config

Name of the corpus file (without the `.txt` extension)

In [None]:
# Base name shared by every file this notebook touches: the corpus is read
# from corpus/<corpusname>.txt, and the encoder and the trained model are
# saved as <corpusname>.enc and <corpusname>.h5.
corpusname = "apocalipsis"

Name of the tokenizer to use

In [None]:
# Tokenizer identifier handed to neurowriter's Encoder together with the corpus.
# NOTE(review): the accepted names are defined inside neurowriter.encoding —
# confirm there before changing this value.
tokenizer = "word"

Number of past input tokens to use for generation

In [None]:
# Passed to trainmodel, and also used as the number of leading corpus tokens
# that form the seed in the generation test.
inputtokens = 8

Network architecture to use

In [None]:
# Architecture name; it is resolved through neurowriter.models.modelbyname.
# NOTE(review): `params` is set independently of this value, so the two have
# to be kept in agreement by hand.
architecture = "lstm"

Architecture parameters

In [None]:
# Hyperparameter presets, one positional list per architecture. Each position
# is labelled inline. NOTE(review): the order presumably has to match what the
# corresponding neurowriter model expects — confirm before reordering.
presets = {
    "wavenet": [
        8,  # kernels
        1,  # wavenet blocks
        0.1,  # dropout
        32,  # embedding size
        'rmsprop',  # optimizer
    ],
    "dilatedconv": [
        1,  # conv layers
        8,  # kernels
        0.1,  # conv dropout
        1,  # dense layers
        16,  # dense units
        0.1,  # dense dropout
        32,  # embedding size
        'adam',  # optimizer
    ],
    "lstm": [
        1,  # layers
        16,  # units
        0.1,  # dropout
        32,  # embedding size
        'rmsprop',  # optimizer
    ],
}

# Only the LSTM preset is used. Pick the key matching `architecture` when
# switching to another model.
params = presets["lstm"]

### Process config

Get all relevant file names

In [None]:
# Derive every file name from the corpus base name: the input text lives under
# corpus/, while the encoder and model files use bare names relative to the
# working directory.
corpusfile = 'corpus/{}.txt'.format(corpusname)
encodername = '{}.enc'.format(corpusname)
modelname = '{}.h5'.format(corpusname)

Obtain model class

In [None]:
from neurowriter.models import modelbyname
# Look up the model for the configured architecture name; the result is later
# passed to trainmodel as its first argument.
modelclass = modelbyname(architecture)

## Load corpus

In [None]:
# Read the whole corpus into memory. The encoding is pinned so the text decodes
# the same way on every platform instead of depending on the locale's default
# codec. NOTE(review): assumes the corpus files are stored as UTF-8 — confirm.
with open(corpusfile, encoding="utf-8") as f:
    corpus = f.read()

In [None]:
# Preview at most the first 1000 characters; slicing already clamps to the
# length of the corpus, so shorter texts are shown in full.
corpus[:1000]

## Encoding

In [None]:
from neurowriter.encoding import Encoder
# Build an Encoder from the raw corpus text and the configured tokenizer name.
encoder = Encoder(corpus, tokenizer)
# Persist the encoder under <corpusname>.enc.
# NOTE(review): presumably so generation can later reuse the same encoding
# without the corpus — confirm.
encoder.save(encodername)

## Model training

Train the generator model, trying different hyperparameters and selecting the model that produces the lowest loss on a validation split of the data.

Note this might take a very long time, so during the optimization temporary versions of the model will be saved.

In [None]:
from neurowriter.optimizer import trainmodel

# Train the selected architecture on the corpus. trainmodel returns a pair,
# bound here to the model and its training history; the model is then written
# to <corpusname>.h5.
model, train_history = trainmodel(
    modelclass,
    inputtokens,
    encoder,
    corpus,
    verbose=True,
    maxepochs=100,
    modelparams=params,
)
model.save(modelname)

In [None]:
# Show the summary of the trained model.
# NOTE(review): presumably a Keras model (it is saved as .h5), in which case
# this prints the layer table — confirm.
model.summary()

## Generation test

In [None]:
from neurowriter.writer import Writer

# creativity=0.1 is passed straight to Writer.
# NOTE(review): presumably a low value makes generation more conservative —
# confirm against neurowriter.writer.
writer = Writer(model, encoder, creativity=0.1)
# Tokenize the corpus and keep its first `inputtokens` tokens as the seed.
tokens = encoder.tokenizer.transform(corpus)
seed = tokens[:inputtokens]
# Rebuild the seed as text by joining its tokens with the tokenizer's own
# separator string.
seedtxt = encoder.tokenizer.intertoken.join(seed)
print(seedtxt)
# Concatenate the pieces produced by writer.write into a single string; as the
# last expression of the cell, it is shown as the cell output.
''.join(writer.write(seed=seedtxt))