This LSTM demo is adapted from the PyTorch [Word Language Model](https://github.com/pytorch/examples/tree/main/word_language_model) example.

In [None]:
import os
import time
import math
import torch
import argparse
import torch.onnx
from io import open
import torch.nn as nn
import torch.nn.functional as F

## Word-level Language Modeling using RNN and Transformer

This example trains a multi-layer RNN (Elman, GRU, or LSTM) or a Transformer on a language modeling task. By default, the training script uses the provided Wikitext-2 dataset.
The trained model can then be used by the generate script to produce new text.

```bash
python main.py --cuda --epochs 6           # Train an LSTM on Wikitext-2 with CUDA.
python main.py --cuda --epochs 6 --tied    # Train a tied LSTM on Wikitext-2 with CUDA.
python main.py --cuda --tied               # Train a tied LSTM on Wikitext-2 with CUDA for 40 epochs.
python main.py --cuda --epochs 6 --model Transformer --lr 5
                                           # Train a Transformer model on Wikitext-2 with CUDA.

python generate.py                         # Generate samples from the default model checkpoint.
```

The model uses the `nn.RNN` module (and its sister modules `nn.GRU` and `nn.LSTM`) or the Transformer modules (`nn.TransformerEncoder` and `nn.TransformerEncoderLayer`), which automatically use the cuDNN backend when run on CUDA with cuDNN installed.

During training, if a keyboard interrupt (Ctrl-C) is received, training is stopped and the current model is evaluated against the test dataset.
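
The interrupt behaviour described above can be sketched with a plain `try`/`except KeyboardInterrupt` around the epoch loop. This is a minimal illustration, not the code from `main.py`: `train_one_epoch`, `run_training`, and the fake loss values are hypothetical stand-ins, and the interrupt is simulated rather than typed.

```python
def train_one_epoch(epoch, interrupt_at=None):
    """Hypothetical stand-in for one training epoch."""
    if epoch == interrupt_at:
        # Simulate the user pressing Ctrl-C during this epoch.
        raise KeyboardInterrupt
    return 1.0 / epoch  # fake validation loss, for illustration only


def run_training(epochs, interrupt_at=None):
    completed = []
    try:
        for epoch in range(1, epochs + 1):
            val_loss = train_one_epoch(epoch, interrupt_at)
            completed.append((epoch, val_loss))
    except KeyboardInterrupt:
        # Ctrl-C lands here: the loop ends early, but the function continues.
        print("-" * 89)
        print("Exiting from training early")
    # This point is reached whether or not training was interrupted,
    # so the test-set evaluation would go here.
    return {"epochs_completed": len(completed), "evaluated_on_test": True}


result = run_training(epochs=6, interrupt_at=4)
print(result)
```

Because the evaluation step sits after the `try` block rather than inside it, an interrupted run and a complete run both end with a test-set evaluation.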

The `main.py` script accepts the following arguments:

```bash
optional arguments:
  -h, --help            show this help message and exit
  --data DATA           location of the data corpus
  --model MODEL         type of network (RNN_TANH, RNN_RELU, LSTM, GRU, Transformer)
  --emsize EMSIZE       size of word embeddings
  --nhid NHID           number of hidden units per layer
  --nlayers NLAYERS     number of layers
  --lr LR               initial learning rate
  --clip CLIP           gradient clipping
  --epochs EPOCHS       upper epoch limit
  --batch_size N        batch size
  --bptt BPTT           sequence length
  --dropout DROPOUT     dropout applied to layers (0 = no dropout)
  --tied                tie the word embedding and softmax weights
  --seed SEED           random seed
  --cuda                use CUDA
  --mps                 enable GPU on macOS
  --log-interval N      report interval
  --save SAVE           path to save the final model
  --onnx-export ONNX_EXPORT
                        path to export the final model in onnx format
  --nhead NHEAD         the number of heads in the encoder/decoder of the transformer model
  --dry-run             verify the code and the model
```
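
For readers unfamiliar with `argparse`, here is a minimal sketch of how a few of these options could be declared. It covers only a subset and is not the actual `main.py` source. The defaults are inferred from this document (LSTM when `--model` is omitted, 40 epochs, a learning rate of 20 and a report interval of 200 batches as seen in the training log) and should be treated as assumptions.

```python
import argparse


def build_parser():
    # Sketch of a subset of the options listed above; defaults are assumptions.
    parser = argparse.ArgumentParser(description="Word-level language model (sketch)")
    parser.add_argument("--model", type=str, default="LSTM",
                        help="type of network (RNN_TANH, RNN_RELU, LSTM, GRU, Transformer)")
    parser.add_argument("--epochs", type=int, default=40, help="upper epoch limit")
    parser.add_argument("--lr", type=float, default=20, help="initial learning rate")
    parser.add_argument("--tied", action="store_true",
                        help="tie the word embedding and softmax weights")
    parser.add_argument("--cuda", action="store_true", help="use CUDA")
    parser.add_argument("--log-interval", type=int, default=200, metavar="N",
                        help="report interval")
    return parser


# Equivalent to: python main.py --cuda --epochs 6 --tied
args = build_parser().parse_args(["--cuda", "--epochs", "6", "--tied"])
print(args)
```

Note that `argparse` converts the dash in `--log-interval` to an underscore, so the value is read as `args.log_interval`.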

With these arguments, a variety of models can be tested.
As an example, the following arguments produce slower but better models:

```bash
python main.py --cuda --emsize 650 --nhid 650 --dropout 0.5 --epochs 40
python main.py --cuda --emsize 650 --nhid 650 --dropout 0.5 --epochs 40 --tied
python main.py --cuda --emsize 1500 --nhid 1500 --dropout 0.65 --epochs 40
python main.py --cuda --emsize 1500 --nhid 1500 --dropout 0.65 --epochs 40 --tied
```
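
One way to see why these settings are slower, and what `--tied` saves, is to count parameters. The sketch below is a rough estimate under stated assumptions: a vocabulary of 33,278 words (the commonly quoted Wikitext-2 size), two layers, the PyTorch `nn.LSTM` layout with two bias vectors per layer, and a tied decoder that keeps only its own bias. The helper `lstm_lm_param_count` is hypothetical and not part of the example repository.

```python
def lstm_lm_param_count(vocab, emsize, nhid, nlayers=2, tied=False):
    """Estimate parameters of an embedding -> LSTM -> linear decoder model."""
    embedding = vocab * emsize
    lstm = 0
    for layer in range(nlayers):
        input_size = emsize if layer == 0 else nhid
        # Four gates, each with input weights, hidden weights, and two biases.
        lstm += 4 * (input_size * nhid + nhid * nhid + 2 * nhid)
    if tied:
        if nhid != emsize:
            raise ValueError("weight tying needs nhid == emsize")
        decoder = vocab  # weight matrix is shared with the embedding; bias remains
    else:
        decoder = nhid * vocab + vocab
    return embedding + lstm + decoder


VOCAB = 33278  # assumed Wikitext-2 vocabulary size
for size in (650, 1500):
    untied = lstm_lm_param_count(VOCAB, size, size)
    tied = lstm_lm_param_count(VOCAB, size, size, tied=True)
    print(f"emsize=nhid={size}: untied {untied:,} | tied {tied:,}")
```

Under these assumptions, most of the parameters sit in the embedding and decoder matrices, which is why tying them shrinks the model noticeably.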

In [None]:
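# Run only the line that matches your hardware.

# On macOS: --mps enables the GPU
%run main.py --mps --epochs 6 --data './dataset/wikitext-2'

# With a CUDA-capable GPU; --save writes the checkpoint that the generate cell loads
%run main.py --cuda --epochs 6 --data './dataset/wikitext-2' --save './model-cuda.pt'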

```bash
| epoch   1 |   200/ 2983 batches | lr 20.00 | ms/batch 10.07 | loss  7.63 | ppl  2067.38
| epoch   1 |   400/ 2983 batches | lr 20.00 | ms/batch  5.42 | loss  6.85 | ppl   944.25
| epoch   1 |   600/ 2983 batches | lr 20.00 | ms/batch  5.42 | loss  6.47 | ppl   645.40
| epoch   1 |   800/ 2983 batches | lr 20.00 | ms/batch  5.40 | loss  6.28 | ppl   534.20
| epoch   1 |  1000/ 2983 batches | lr 20.00 | ms/batch  5.42 | loss  6.13 | ppl   461.01
| epoch   1 |  1200/ 2983 batches | lr 20.00 | ms/batch  5.42 | loss  6.05 | ppl   424.65
| epoch   1 |  1400/ 2983 batches | lr 20.00 | ms/batch  5.43 | loss  5.94 | ppl   380.46
| epoch   1 |  1600/ 2983 batches | lr 20.00 | ms/batch  5.42 | loss  5.95 | ppl   381.99
| epoch   1 |  1800/ 2983 batches | lr 20.00 | ms/batch  5.43 | loss  5.81 | ppl   332.67
| epoch   1 |  2000/ 2983 batches | lr 20.00 | ms/batch  5.44 | loss  5.78 | ppl   322.75
| epoch   1 |  2200/ 2983 batches | lr 20.00 | ms/batch  5.43 | loss  5.65 | ppl   284.76
| epoch   1 |  2400/ 2983 batches | lr 20.00 | ms/batch  5.43 | loss  5.66 | ppl   287.58
| epoch   1 |  2600/ 2983 batches | lr 20.00 | ms/batch  5.44 | loss  5.65 | ppl   282.93
| epoch   1 |  2800/ 2983 batches | lr 20.00 | ms/batch  5.42 | loss  5.54 | ppl   254.04
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 17.90s | valid loss  5.57 | valid ppl   262.00
-----------------------------------------------------------------------------------------
| epoch   2 |   200/ 2983 batches | lr 20.00 | ms/batch  8.49 | loss  5.54 | ppl   254.11
| epoch   2 |   400/ 2983 batches | lr 20.00 | ms/batch  7.83 | loss  5.52 | ppl   250.65
| epoch   2 |   600/ 2983 batches | lr 20.00 | ms/batch  7.71 | loss  5.35 | ppl   209.78
| epoch   2 |   800/ 2983 batches | lr 20.00 | ms/batch  7.87 | loss  5.36 | ppl   213.51
| epoch   2 |  1000/ 2983 batches | lr 20.00 | ms/batch  7.25 | loss  5.34 | ppl   207.89
| epoch   2 |  1200/ 2983 batches | lr 20.00 | ms/batch  6.39 | loss  5.32 | ppl   204.74
| epoch   2 |  1400/ 2983 batches | lr 20.00 | ms/batch  6.27 | loss  5.32 | ppl   203.56
| epoch   2 |  1600/ 2983 batches | lr 20.00 | ms/batch  6.05 | loss  5.38 | ppl   216.95
| epoch   2 |  1800/ 2983 batches | lr 20.00 | ms/batch  6.32 | loss  5.25 | ppl   191.48
| epoch   2 |  2000/ 2983 batches | lr 20.00 | ms/batch  6.33 | loss  5.26 | ppl   191.68
| epoch   2 |  2200/ 2983 batches | lr 20.00 | ms/batch  6.32 | loss  5.16 | ppl   173.77
| epoch   2 |  2400/ 2983 batches | lr 20.00 | ms/batch  5.87 | loss  5.20 | ppl   180.45
| epoch   2 |  2600/ 2983 batches | lr 20.00 | ms/batch  5.57 | loss  5.20 | ppl   182.17
| epoch   2 |  2800/ 2983 batches | lr 20.00 | ms/batch  7.69 | loss  5.12 | ppl   167.62
-----------------------------------------------------------------------------------------
| end of epoch   2 | time: 21.95s | valid loss  5.29 | valid ppl   198.80
-----------------------------------------------------------------------------------------
| epoch   3 |   200/ 2983 batches | lr 20.00 | ms/batch  7.97 | loss  5.18 | ppl   177.42
| epoch   3 |   400/ 2983 batches | lr 20.00 | ms/batch  8.02 | loss  5.20 | ppl   180.50
| epoch   3 |   600/ 2983 batches | lr 20.00 | ms/batch  8.08 | loss  5.01 | ppl   150.08
| epoch   3 |   800/ 2983 batches | lr 20.00 | ms/batch  6.72 | loss  5.06 | ppl   157.41
| epoch   3 |  1000/ 2983 batches | lr 20.00 | ms/batch  6.39 | loss  5.05 | ppl   155.89
| epoch   3 |  1200/ 2983 batches | lr 20.00 | ms/batch  6.36 | loss  5.05 | ppl   155.51
| epoch   3 |  1400/ 2983 batches | lr 20.00 | ms/batch  6.37 | loss  5.06 | ppl   157.64
| epoch   3 |  1600/ 2983 batches | lr 20.00 | ms/batch  6.35 | loss  5.13 | ppl   169.73
| epoch   3 |  1800/ 2983 batches | lr 20.00 | ms/batch  6.09 | loss  5.02 | ppl   150.91
| epoch   3 |  2000/ 2983 batches | lr 20.00 | ms/batch  5.66 | loss  5.03 | ppl   153.63
| epoch   3 |  2200/ 2983 batches | lr 20.00 | ms/batch  5.46 | loss  4.94 | ppl   140.35
| epoch   3 |  2400/ 2983 batches | lr 20.00 | ms/batch  9.13 | loss  4.98 | ppl   145.15
| epoch   3 |  2600/ 2983 batches | lr 20.00 | ms/batch  7.70 | loss  5.00 | ppl   147.83
| epoch   3 |  2800/ 2983 batches | lr 20.00 | ms/batch  6.41 | loss  4.92 | ppl   137.43
-----------------------------------------------------------------------------------------
| end of epoch   3 | time: 22.00s | valid loss  5.17 | valid ppl   176.16
-----------------------------------------------------------------------------------------
| epoch   4 |   200/ 2983 batches | lr 20.00 | ms/batch  8.19 | loss  4.99 | ppl   147.10
| epoch   4 |   400/ 2983 batches | lr 20.00 | ms/batch  7.32 | loss  5.01 | ppl   150.56
| epoch   4 |   600/ 2983 batches | lr 20.00 | ms/batch  6.47 | loss  4.83 | ppl   125.07
| epoch   4 |   800/ 2983 batches | lr 20.00 | ms/batch  6.53 | loss  4.88 | ppl   131.54
| epoch   4 |  1000/ 2983 batches | lr 20.00 | ms/batch  6.38 | loss  4.88 | ppl   131.71
| epoch   4 |  1200/ 2983 batches | lr 20.00 | ms/batch  6.37 | loss  4.88 | ppl   132.08
| epoch   4 |  1400/ 2983 batches | lr 20.00 | ms/batch  6.35 | loss  4.91 | ppl   135.46
| epoch   4 |  1600/ 2983 batches | lr 20.00 | ms/batch  6.26 | loss  4.99 | ppl   146.95
| epoch   4 |  1800/ 2983 batches | lr 20.00 | ms/batch  5.75 | loss  4.87 | ppl   130.17
| epoch   4 |  2000/ 2983 batches | lr 20.00 | ms/batch  5.50 | loss  4.90 | ppl   133.75
| epoch   4 |  2200/ 2983 batches | lr 20.00 | ms/batch  8.23 | loss  4.80 | ppl   121.67
| epoch   4 |  2400/ 2983 batches | lr 20.00 | ms/batch  7.65 | loss  4.85 | ppl   127.14
| epoch   4 |  2600/ 2983 batches | lr 20.00 | ms/batch  6.55 | loss  4.86 | ppl   129.17
| epoch   4 |  2800/ 2983 batches | lr 20.00 | ms/batch  6.40 | loss  4.79 | ppl   120.67
-----------------------------------------------------------------------------------------
| end of epoch   4 | time: 21.87s | valid loss  5.08 | valid ppl   161.27
-----------------------------------------------------------------------------------------
| epoch   5 |   200/ 2983 batches | lr 20.00 | ms/batch  6.89 | loss  4.87 | ppl   130.07
| epoch   5 |   400/ 2983 batches | lr 20.00 | ms/batch  6.42 | loss  4.89 | ppl   132.72
| epoch   5 |   600/ 2983 batches | lr 20.00 | ms/batch  6.44 | loss  4.71 | ppl   110.66
| epoch   5 |   800/ 2983 batches | lr 20.00 | ms/batch  6.27 | loss  4.76 | ppl   116.53
| epoch   5 |  1000/ 2983 batches | lr 20.00 | ms/batch  6.36 | loss  4.77 | ppl   117.59
| epoch   5 |  1200/ 2983 batches | lr 20.00 | ms/batch  6.34 | loss  4.77 | ppl   117.98
| epoch   5 |  1400/ 2983 batches | lr 20.00 | ms/batch  6.32 | loss  4.81 | ppl   122.37
| epoch   5 |  1600/ 2983 batches | lr 20.00 | ms/batch  6.02 | loss  4.89 | ppl   132.74
| epoch   5 |  1800/ 2983 batches | lr 20.00 | ms/batch  5.70 | loss  4.77 | ppl   117.36
| epoch   5 |  2000/ 2983 batches | lr 20.00 | ms/batch  5.50 | loss  4.80 | ppl   121.27
| epoch   5 |  2200/ 2983 batches | lr 20.00 | ms/batch  8.36 | loss  4.70 | ppl   109.46
| epoch   5 |  2400/ 2983 batches | lr 20.00 | ms/batch  6.29 | loss  4.74 | ppl   114.22
| epoch   5 |  2600/ 2983 batches | lr 20.00 | ms/batch  6.40 | loss  4.77 | ppl   117.47
| epoch   5 |  2800/ 2983 batches | lr 20.00 | ms/batch  6.43 | loss  4.70 | ppl   110.09
-----------------------------------------------------------------------------------------
| end of epoch   5 | time: 20.37s | valid loss  5.03 | valid ppl   152.96
-----------------------------------------------------------------------------------------
| epoch   6 |   200/ 2983 batches | lr 20.00 | ms/batch  7.66 | loss  4.77 | ppl   118.17
| epoch   6 |   400/ 2983 batches | lr 20.00 | ms/batch  6.40 | loss  4.79 | ppl   120.76
| epoch   6 |   600/ 2983 batches | lr 20.00 | ms/batch  6.82 | loss  4.61 | ppl   100.75
| epoch   6 |   800/ 2983 batches | lr 20.00 | ms/batch  6.48 | loss  4.67 | ppl   106.50
| epoch   6 |  1000/ 2983 batches | lr 20.00 | ms/batch  6.38 | loss  4.68 | ppl   108.10
| epoch   6 |  1200/ 2983 batches | lr 20.00 | ms/batch  6.39 | loss  4.69 | ppl   108.64
| epoch   6 |  1400/ 2983 batches | lr 20.00 | ms/batch  6.39 | loss  4.73 | ppl   112.86
| epoch   6 |  1600/ 2983 batches | lr 20.00 | ms/batch  6.45 | loss  4.81 | ppl   122.60
| epoch   6 |  1800/ 2983 batches | lr 20.00 | ms/batch  6.36 | loss  4.68 | ppl   108.20
| epoch   6 |  2000/ 2983 batches | lr 20.00 | ms/batch  6.07 | loss  4.73 | ppl   112.95
| epoch   6 |  2200/ 2983 batches | lr 20.00 | ms/batch  5.63 | loss  4.62 | ppl   101.83
| epoch   6 |  2400/ 2983 batches | lr 20.00 | ms/batch  5.51 | loss  4.66 | ppl   105.91
| epoch   6 |  2600/ 2983 batches | lr 20.00 | ms/batch  5.50 | loss  4.69 | ppl   108.97
| epoch   6 |  2800/ 2983 batches | lr 20.00 | ms/batch  8.08 | loss  4.62 | ppl   101.51
-----------------------------------------------------------------------------------------
| end of epoch   6 | time: 20.44s | valid loss  5.00 | valid ppl   148.26
-----------------------------------------------------------------------------------------
=========================================================================================
| End of training | test loss  4.93 | test ppl   137.91
=========================================================================================
```
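
The `ppl` column in this log is consistent with perplexity being the exponential of the cross-entropy loss, `ppl = exp(loss)`. The short check below uses two pairs of values copied from the log above; the helper name `perplexity` is illustrative. Because the log rounds the loss to two decimals, exponentiating the printed loss does not reproduce the printed perplexity exactly, but taking the logarithm of the printed perplexity does round back to the printed loss.

```python
import math


def perplexity(loss):
    """Perplexity of a model with the given average cross-entropy loss (in nats)."""
    return math.exp(loss)


# (loss, ppl) pairs copied from the log: final test line and epoch-6 validation line.
reported = [(4.93, 137.91), (5.00, 148.26)]

for loss, ppl in reported:
    print(f"logged loss {loss:.2f} -> exp(loss) = {perplexity(loss):.2f} | "
          f"logged ppl {ppl:.2f} -> ln(ppl) = {math.log(ppl):.4f}")
```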

In [None]:
# Generate samples from the default model checkpoint.
%run generate.py --data './dataset/wikitext-2' --checkpoint './model-cuda.pt' --outf 'generated.txt'

In [None]:
%%bash
cat generated.txt
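
This notebook does not show how `generate.py` picks each word. One common approach, shown here purely as an assumption, is to sample the next word from the model's output scores after dividing them by a temperature. The sketch uses only the standard library; `sample_next_word`, the toy vocabulary, and the scores are all hypothetical.

```python
import math
import random


def sample_next_word(scores, temperature=1.0, rng=random):
    """Sample an index with probability proportional to exp(score / temperature)."""
    weights = [math.exp(s / temperature) for s in scores]
    return rng.choices(range(len(scores)), weights=weights, k=1)[0]


vocab = ["the", "cat", "sat", "<eos>"]  # toy vocabulary
scores = [2.0, 1.0, 0.5, -1.0]          # made-up model outputs for one step
rng = random.Random(0)                  # fixed seed for repeatable sampling

words = [vocab[sample_next_word(scores, temperature=1.0, rng=rng)] for _ in range(10)]
print(" ".join(words))
```

A lower temperature concentrates the samples on the highest-scoring word, while a higher temperature makes the choice closer to uniform.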