Paperspace/training-lstm

Training an LSTM network

Multi-layer recurrent neural networks (LSTM, RNN) for character-level language models, written in Python with TensorFlow and modified to work with TensorFlow.js and ml5.js.

Based on char-rnn-tensorflow.

Requirements

Set up a Python environment with TensorFlow installed. More detailed instructions here

Usage

1) Train

Inside the /data folder, create a new folder named after your dataset. That folder should contain a single file called input.txt. (A quick tip for concatenating many small, disparate .txt files into one large training file: ls *.txt | xargs -L 1 cat >> input.txt) Then run:

python train.py --data_dir=./data/my_own_data
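If you would rather prepare input.txt from Python than with the shell tip above, the following is a minimal sketch. The function name build_input_file and the example paths are hypothetical and are not part of this repository. It collects the list of source files before opening the output, so input.txt is never read back into itself.

```python
from pathlib import Path


def build_input_file(source_dir, data_dir):
    """Concatenate every .txt file in source_dir into data_dir/input.txt."""
    source = Path(source_dir)
    target_dir = Path(data_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    target = target_dir / "input.txt"

    # Build the file list first and skip the output file itself, in case
    # source_dir and data_dir are the same folder.
    sources = sorted(
        p for p in source.glob("*.txt") if p.resolve() != target.resolve()
    )

    with target.open("w", encoding="utf-8") as out:
        for path in sources:
            text = path.read_text(encoding="utf-8", errors="replace")
            out.write(text)
            # Keep the last line of one file from running into the next file.
            if not text.endswith("\n"):
                out.write("\n")
    return target


if __name__ == "__main__":
    # Hypothetical paths: adjust to wherever your raw text files live.
    print(build_input_file("./raw_texts", "./data/my_own_data"))
```

Unlike the shell one-liner, which appends with >>, this sketch overwrites input.txt on every run.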

You can specify the hyperparameters:

python train.py --data_dir=./data/my_own_data --rnn_size 128 --num_layers 2 --seq_length 64 --batch_size 32 --num_epochs 1000

This trains your model and saves it to the global ./models folder, in a format usable from JavaScript.
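To get a feel for how --batch_size and --seq_length interact with the size of input.txt, here is a small sketch. It assumes that train.py cuts the text into full batches of batch_size × seq_length characters, as the upstream char-rnn-tensorflow project does. This README does not confirm that, so treat the numbers as rough estimates. The function name batches_per_epoch is hypothetical.

```python
def batches_per_epoch(num_chars, batch_size, seq_length):
    """Estimate how many full batches fit into a text of num_chars characters.

    Assumption: each batch consumes batch_size * seq_length characters and any
    leftover characters at the end of the text are dropped.
    """
    return num_chars // (batch_size * seq_length)


if __name__ == "__main__":
    # Roughly 2 million characters with the settings from the example command.
    print(batches_per_epoch(2_000_000, batch_size=32, seq_length=64))  # 976
    # A text that is too short for even one batch.
    print(batches_per_epoch(1_000, batch_size=32, seq_length=64))  # 0
```

Under that assumption, a result of 0 means the dataset is too small for the chosen settings, and you would need a smaller batch_size or seq_length.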

2) Run!

To use the model in ml5.js, point to the new folder in your sketch:

var lstm = new ml5.LSTMGenerator('./models/your_new_model');

That's it!

Hyperparameters

Thanks to Ross Goodwin!

Depending on the size of your training dataset, these settings are a reasonable starting point:

  • 2 MB:
    • rnn_size 256 (or 128)
    • layers 2
    • seq_length 64
    • batch_size 32
    • dropout 0.25
  • 5-8 MB:
    • rnn_size 512
    • layers 2 (or 3)
    • seq_length 128
    • batch_size 64
    • dropout 0.25
  • 10-20 MB:
    • rnn_size 1024
    • layers 2 (or 3)
    • seq_length 128 (or 256)
    • batch_size 128
    • dropout 0.25
  • 25+ MB:
    • rnn_size 2048
    • layers 2 (or 3)
    • seq_length 256 (or 128)
    • batch_size 128
    • dropout 0.25
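The list above can be turned into a small lookup that builds a train.py command from the size of input.txt. This is only a sketch: the helper names are hypothetical, and the cut-offs between tiers (5, 10 and 25 MB) are an assumption, because the list leaves gaps between its ranges. Only the flags shown in the Usage section are emitted. Dropout is left out of the command because this README does not show a flag for it.

```python
import os
import shlex

MB = 1024 * 1024

# (upper bound in MB, settings). Values come from the list above; the upper
# bounds are assumed, since the list does not cover every size.
TIERS = [
    (5, {"rnn_size": 256, "num_layers": 2, "seq_length": 64, "batch_size": 32}),
    (10, {"rnn_size": 512, "num_layers": 2, "seq_length": 128, "batch_size": 64}),
    (25, {"rnn_size": 1024, "num_layers": 2, "seq_length": 128, "batch_size": 128}),
    (float("inf"), {"rnn_size": 2048, "num_layers": 2, "seq_length": 256, "batch_size": 128}),
]


def suggest_hyperparameters(size_bytes):
    """Return the suggested settings for a dataset of size_bytes bytes."""
    size_mb = size_bytes / MB
    for upper_bound, settings in TIERS:
        if size_mb < upper_bound:
            return dict(settings)


def train_command(data_dir):
    """Build a train.py command line for the input.txt found in data_dir."""
    size = os.path.getsize(os.path.join(data_dir, "input.txt"))
    args = ["python", "train.py", "--data_dir=" + data_dir]
    for name, value in suggest_hyperparameters(size).items():
        args += ["--" + name, str(value)]
    return " ".join(shlex.quote(arg) for arg in args)


if __name__ == "__main__":
    print(suggest_hyperparameters(6 * MB))
```

Calling train_command("./data/my_own_data") returns a string you can paste into a terminal. These values are starting points to tune from, not guarantees.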

TensorBoard

TensorFlow comes with TensorBoard: "a suite of web applications for inspecting and understanding your TensorFlow runs and graphs."

To visualize training progress, model graphs, and internal state histograms, start TensorBoard and point it at your log_dir:

$ tensorboard --logdir=./logs/

Then open a browser to http://localhost:6006, or to whichever IP address and port TensorBoard reports on startup.
