In [None]:
!pip install git+https://github.com/tientr/aitextgen.git

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/tientr/aitextgen.git
  Cloning https://github.com/tientr/aitextgen.git to /tmp/pip-req-build-xmo1eub3
  Running command git clone -q https://github.com/tientr/aitextgen.git /tmp/pip-req-build-xmo1eub3


In [None]:
import logging

logging.basicConfig(
    format="%(asctime)s — %(levelname)s — %(name)s — %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

from aitextgen import aitextgen
from aitextgen.colab import mount_gdrive, copy_file_from_gdrive

08/04/2022 14:49:51 — INFO — torch.distributed.nn.jit.instantiator — Created a temporary directory at /tmp/tmplu6nt2dw
08/04/2022 14:49:51 — INFO — torch.distributed.nn.jit.instantiator — Writing /tmp/tmplu6nt2dw/_remote_module_non_scriptable.py
08/04/2022 14:49:55 — INFO — numexpr.utils — NumExpr defaulting to 2 threads.


In [None]:
%%capture
ai = aitextgen(tf_gpt2="124M", to_gpu=True)


08/04/2022 14:49:55 — INFO — aitextgen — Loading 124M GPT-2 model from /aitextgen.
08/04/2022 14:49:57 — INFO — aitextgen — GPT2 loaded with 124M parameters.
08/04/2022 14:49:57 — INFO — aitextgen — Using the default GPT-2 Tokenizer.


In [None]:
mount_gdrive()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
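file_name = "processed_quora.txt"
# Move the local dataset into the root of Google Drive; the next cell copies it back into the VM's working directory.
!mv /content/processed_quora.txt /content/drive/My\ Drive/processed_quora.txt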

In [None]:
copy_file_from_gdrive(file_name)
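
Before training, it can help to sanity-check the copied file. The sample outputs further down suggest each line of `processed_quora.txt` holds a question pair joined by the separator `<!@#$>`; that format is inferred from those samples, not documented here. A minimal sketch with a hypothetical helper `summarize_dataset` that counts non-empty lines and how many contain the separator:

```python
from typing import Tuple

SEP = "<!@#$>"  # pair separator, inferred from the generated samples in this notebook


def summarize_dataset(path: str, sep: str = SEP) -> Tuple[int, int]:
    """Count non-empty lines in a text file and how many of them contain `sep`."""
    total = with_sep = 0
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # ignore blank lines
            total += 1
            if sep in line:
                with_sep += 1
    return total, with_sep
```

For example, `total, paired = summarize_dataset(file_name)` followed by `print(f"{paired}/{total} lines contain the separator")`. A large gap between the two numbers would hint that the preprocessing did not produce the format you expected.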

## Finetune GPT-2

The next cell starts the actual finetuning of GPT-2 in aitextgen. It runs for `num_steps` steps, and a progress bar shows training progress, the current loss (lower is better), and the average loss (to give a sense of the loss trajectory).

By default, the model is saved to `trained_model` every `save_every` steps and again when training completes. If you mounted your Google Drive and set `save_gdrive=True`, the model is _also_ saved there in a unique folder.

The Colab session might time out after about 4 hours; if you did not mount Google Drive, make sure you stop training and save the results so you don't lose them! (If this happens frequently, you may want to consider using [Colab Pro](https://colab.research.google.com/signup).)

Important parameters for `train()`:

- **`line_by_line`**: Set this to `True` if the input text file is a single-column CSV, with one record per row. aitextgen then treats each row as a separate training text.
- **`from_cache`**: If you compressed your dataset locally (as noted in the previous section) and are using that cache file, set this to `True`.
- **`num_steps`**: Number of steps to train the model for.
- **`generate_every`**: Interval of steps to generate example text from the model; good for qualitatively validating training.
- **`save_every`**: Interval of steps to save the model: the model will be saved in the VM to `/trained_model`.
- **`save_gdrive`**: Set this to `True` to copy the model to a unique folder in your Google Drive, if you have mounted it in the earlier cells.
- **`fp16`**: Enables half-precision training for faster/more memory-efficient training. Only works on a T4 or V100 GPU.
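
If you want to try `line_by_line=True` with a plain-text file like `processed_quora.txt`, the file first has to become a single-column CSV. Below is a minimal sketch using Python's `csv` module, assuming one record per line of the input. The function name `txt_to_single_column_csv` is hypothetical, and writing a header row is an assumption about how the loader treats the first row; check the aitextgen documentation and pass `header=None` if your version does not skip it.

```python
import csv
from typing import Optional


def txt_to_single_column_csv(txt_path: str, csv_path: str,
                             header: Optional[str] = "text") -> int:
    """Write each non-empty line of a text file as one row of a single-column CSV.

    Assumes one training record per input line. Returns the number of records written.
    """
    count = 0
    with open(txt_path, encoding="utf-8") as src, \
         open(csv_path, "w", encoding="utf-8", newline="") as dst:
        writer = csv.writer(dst)
        if header is not None:
            writer.writerow([header])  # assumed to be skipped by the loader
        for line in src:
            record = line.rstrip("\n")
            if record.strip():  # skip blank lines
                writer.writerow([record])  # csv quotes commas and quotes as needed
                count += 1
    return count
```

For example, `txt_to_single_column_csv("processed_quora.txt", "processed_quora.csv")` would produce a file you could then pass to `train()` with `line_by_line=True`.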

Here are other parameters for `train()` that are useful but that you likely do not need to change:

- **`learning_rate`**: Learning rate of the model training.
- **`batch_size`**: Batch size of the model training; setting it too high will cause the GPU to go OOM. (if using `fp16`, you can increase the batch size more safely)
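
To get a feel for what the settings in the next cell imply, here is a back-of-the-envelope calculation with a hypothetical helper `training_budget`. It uses the 1,024-token block length reported in the training log and assumes every training sample fills a whole block, so the token count is an upper bound rather than an exact figure.

```python
def training_budget(num_steps: int, batch_size: int, block_size: int,
                    save_every: int, generate_every: int) -> dict:
    """Rough numbers for a training run.

    Assumes every sample is a full block of `block_size` tokens (an upper bound).
    """
    return {
        "max_tokens_seen": num_steps * batch_size * block_size,
        "checkpoints": num_steps // save_every,          # periodic saves
        "sample_generations": num_steps // generate_every,
    }


print(training_budget(num_steps=100_000, batch_size=2, block_size=1024,
                      save_every=10_000, generate_every=10_000))
# {'max_tokens_seen': 204800000, 'checkpoints': 10, 'sample_generations': 10}
```

Doubling `batch_size` doubles the tokens seen per step, which is one reason a larger batch can let you use fewer steps, memory permitting.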

In [None]:
ai.train(file_name,
         line_by_line=False,
         from_cache=False,
         num_steps=100000,
         generate_every=10000,
         save_every=10000,
         save_gdrive=True,
         learning_rate=3e-3,
         fp16=False,
         batch_size=2
         )

08/04/2022 15:02:15 — INFO — aitextgen — Loading text from processed_quora.txt with generation length of 1024.



08/04/2022 15:02:15 — INFO — aitextgen.TokenDataset — Encoding 149,263 sets of tokens from processed_quora.txt.
  f"Setting `Trainer(gpus={gpus!r})` is deprecated in v1.7 and will be removed"
08/04/2022 15:02:26 — INFO — pytorch_lightning.utilities.rank_zero — GPU available: True (cuda), used: True
08/04/2022 15:02:27 — INFO — pytorch_lightning.utilities.rank_zero — TPU available: False, using: 0 TPU cores
08/04/2022 15:02:27 — INFO — pytorch_lightning.utilities.rank_zero — IPU available: False, using: 0 IPUs
08/04/2022 15:02:27 — INFO — pytorch_lightning.utilities.rank_zero — HPU available: False, using: 0 HPUs
  f"The `Callback.{hook}` hook was deprecated in v1.6 and"
08/04/2022 15:02:27 — INFO — pytorch_lightning.accelerators.cuda — LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



10,000 steps reached: saving model to /trained_model
10,000 steps reached: generating sample texts.
 with a computer background? <!@#$> What is the most beautiful book you have ever read?
Should I give a good assured return if it's a good idea to it? <!@#$> Should I give a good assured return?
Why do people ask simple direct science questions in Quora when there are sufficient resources available in internet? <!@#$> Why should one use the programming language?
How can I get traffic on my site and what are some suggestions on how to get more of it? <!@#$> How can I get more traffic to my website without investing?
How is it like to get into Harvard? <!@#$> What is it like to get admitted into Harvard?
Will there really end of the universe? <!@#$> Is there really worse than the end of the universe?
What is the best way to learn Japanese by Karan Johar? <!@#$> What is the best way to learn Japanese by Karan Johar?
What should I do to earn money online? <!@#$> What is the e
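
The samples above follow a `question <!@#$> paraphrase` layout, and the first and last lines of a sample can be cut off mid-record (as with ` with a computer background?` and `What is the e`). A minimal post-processing sketch, with a hypothetical `parse_pairs` helper, that extracts clean pairs and can optionally drop those possibly partial edge lines:

```python
from typing import List, Tuple

SEP = "<!@#$>"  # separator seen in the training data and generated samples


def parse_pairs(text: str, sep: str = SEP,
                trim_edges: bool = False) -> List[Tuple[str, str]]:
    """Split generated lines of the form 'question <sep> paraphrase' into pairs.

    Lines without exactly one separator, or with an empty side, are skipped.
    With trim_edges=True, the first and last lines are dropped because a sample
    may start or end in the middle of a record.
    """
    lines = text.splitlines()
    if trim_edges:
        lines = lines[1:-1]
    pairs = []
    for line in lines:
        parts = line.split(sep)
        if len(parts) != 2:
            continue
        left, right = (p.strip() for p in parts)
        if left and right:
            pairs.append((left, right))
    return pairs
```

Applied to the sample above with `trim_edges=True`, this keeps the complete middle lines such as `("How is it like to get into Harvard?", "What is it like to get admitted into Harvard?")`.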

You're done! Feel free to go to the **Generate Text From The Trained Model** section to generate text based on your retrained model.


## Load a Trained Model

If you already have a model trained with this notebook, the next cell copies the `pytorch_model.bin` and `config.json` files from the specified folder in Google Drive into the Colaboratory VM. (If no `from_folder` is specified, it assumes both files are at the root of your Google Drive.)

In [None]:
from_folder = None

for file in ["pytorch_model.bin", "config.json"]:
  if from_folder:
    copy_file_from_gdrive(file, from_folder)
  else:
    copy_file_from_gdrive(file)
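
Before loading, you may want to confirm that both files actually arrived in the working directory, since a missing file otherwise only surfaces as an error when the model is loaded. A small sketch with a hypothetical `missing_model_files` helper, using only the two file names the cell above copies:

```python
import os
from typing import List, Sequence

REQUIRED_FILES = ("pytorch_model.bin", "config.json")  # the files copied above


def missing_model_files(folder: str = ".",
                        required: Sequence[str] = REQUIRED_FILES) -> List[str]:
    """Return the names of required model files that are not present in `folder`."""
    return [name for name in required
            if not os.path.isfile(os.path.join(folder, name))]
```

`missing_model_files()` returns an empty list when both files are in place; otherwise it names the ones to re-copy from Google Drive.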

The next cell loads the retrained model and the metadata needed to generate text.

In [None]:
ai = aitextgen(model_folder=".", to_gpu=True)

## Generate Text From The Trained Model

After you've trained the model or loaded a retrained model from checkpoint, you can now generate text.

**If you just trained a model**, you'll get much faster generation performance if you reload the model; the next cell reloads the model you just trained from the `trained_model` folder.

In [None]:
ai = aitextgen(model_folder="trained_model", to_gpu=True)

`generate()` without any parameters generates a single text from the loaded model and prints it to the console.

In [None]:
ai.generate()

If you're building an API around your model and need to pass the generated text elsewhere, use `text = ai.generate_one()`, which returns the text instead of printing it.

You can also pass a `prompt` to `generate()` to force the text to start with a given character sequence and continue from there (useful if your training texts begin with a consistent marker).

You can also generate multiple texts at once by specifying `n`. Passing a `batch_size` generates multiple samples in parallel, which gives a large speedup (in Colaboratory, keep `batch_size` at 50 or below to avoid going OOM).

Other optional-but-helpful parameters for `ai.generate()` and friends:

*  **`min_length`**: The minimum length of the generated text: if the text is shorter than this value after cleanup, aitextgen will generate another one.
*  **`max_length`**: Number of tokens to generate (default 256, you can generate up to 1024 tokens with GPT-2 and 2048 with GPT Neo)
* **`temperature`**: The higher the temperature, the crazier the text (default 0.7, recommended to keep between 0.7 and 1.0)
* **`top_k`**: Limits the generated guesses to the top *k* guesses (default 0 which disables the behavior; if the generated output is super crazy, you may want to set `top_k=40`)
* **`top_p`**: Nucleus sampling: limits the generated guesses to the smallest set of tokens whose cumulative probability reaches `top_p` (`top_p=0.9` tends to give good results).
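
To make the sampling parameters more concrete, here is a toy, pure-Python sketch of a single sampling step that applies temperature, then top-k, then top-p to a made-up dictionary of token scores. It illustrates the general technique only; in aitextgen the actual filtering is done by the underlying Hugging Face `transformers` generation code on tensors, which may differ in details such as tie handling. The function `sample_next_token` is hypothetical.

```python
import math
import random
from typing import Dict, Optional


def sample_next_token(logits: Dict[str, float], temperature: float = 0.7,
                      top_k: int = 0, top_p: float = 1.0,
                      rng: Optional[random.Random] = None) -> str:
    """Pick one token from raw scores using temperature, top-k and top-p filtering.

    top_k=0 disables top-k filtering; top_p=1.0 disables nucleus filtering.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    rng = rng or random.Random()

    # Temperature scaling followed by a numerically stable softmax.
    scaled = {tok: score / temperature for tok, score in logits.items()}
    peak = max(scaled.values())
    weights = {tok: math.exp(s - peak) for tok, s in scaled.items()}
    total = sum(weights.values())
    ranked = sorted(((w / total, tok) for tok, w in weights.items()), reverse=True)

    if top_k > 0:
        ranked = ranked[:top_k]  # keep only the k most likely tokens
        mass = sum(p for p, _ in ranked)
        ranked = [(p / mass, t) for p, t in ranked]  # renormalise
    if top_p < 1.0:
        kept, cumulative = [], 0.0
        for prob, tok in ranked:
            kept.append((prob, tok))
            cumulative += prob
            if cumulative >= top_p:  # smallest set whose mass reaches top_p
                break
        ranked = kept

    # random.choices accepts unnormalised weights, so no final renormalisation is needed.
    tokens = [t for _, t in ranked]
    probs = [p for p, _ in ranked]
    return rng.choices(tokens, weights=probs, k=1)[0]
```

Lowering `temperature` sharpens the distribution before any filtering, so with a low temperature even `top_p=0.9` may leave only one or two candidate tokens.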

In [None]:
ai.generate(n=5,
            batch_size=5,
            prompt="ROMEO:",
            max_length=256,
            temperature=1.0,
            top_p=0.9)

For bulk generation, you can generate a large number of texts to files and sort through the samples locally on your computer. The next cell generates `num_files` files, each containing `n` texts, using whatever other parameters you would pass to `generate()`. The files can then be downloaded from the Files sidebar!

You can rerun the cells as many times as you want for even more generated texts!

In [None]:
num_files = 5

for _ in range(num_files):
  ai.generate_to_file(n=1000,
                     batch_size=50,
                     prompt="ROMEO:",
                     max_length=256,
                     temperature=1.0,
                     top_p=0.9)
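
After several runs of the cell above, the generated files can contain repeated samples. A minimal sketch, with a hypothetical `collect_unique_samples` helper, that reads every matching file, splits it into samples, and keeps each distinct sample once. It assumes samples are separated by a line of 20 `=` characters and that the output files match a glob such as `ATG_*.txt`; both are assumptions about aitextgen's default output, so open one file and adjust `delim` or the pattern if yours differ.

```python
import glob
from typing import List

DEFAULT_DELIM = "=" * 20  # assumed separator line between samples


def collect_unique_samples(pattern: str, delim: str = DEFAULT_DELIM) -> List[str]:
    """Read all files matching `pattern`, split them into samples, and drop duplicates.

    Order of first appearance is preserved (files are read in sorted order);
    empty samples are ignored.
    """
    seen = set()
    unique = []
    for path in sorted(glob.glob(pattern)):
        with open(path, encoding="utf-8") as f:
            content = f.read()
        for sample in content.split(delim):
            sample = sample.strip()
            if sample and sample not in seen:
                seen.add(sample)
                unique.append(sample)
    return unique
```

For example, `samples = collect_unique_samples("ATG_*.txt")` followed by `print(len(samples))` reports how many distinct texts the runs produced.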

# LICENSE

MIT License

Copyright (c) 2020-2021 Max Woolf

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.