# GPT-2 Text Generation Notebook

by [Artem Konevskikh](https://github.com/artem-konevskikh)

Based on notebook by [Max Woolf](http://minimaxir.com). For more about `gpt-2-simple`, you can visit [this GitHub repository](https://github.com/minimaxir/gpt-2-simple).

## Installation

In [None]:
#@title Imports
#@markdown Running this cell loads the libraries needed to work with GPT-2.
%tensorflow_version 1.x
!pip install -q gpt-2-simple
import gpt_2_simple as gpt2
from datetime import datetime
from google.colab import files

In [None]:
#@title GPU

#@markdown Colaboratory assigns either an Nvidia T4 GPU or an Nvidia K80 GPU. The T4 is slightly faster than the older K80 for training GPT-2 and has more memory, which lets you train the larger GPT-2 models and generate more text.

#@markdown We can verify which GPU is active by running this cell.

!nvidia-smi
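
The same check can also be done from Python, which is handy if a later cell should react to a missing GPU. The following is only a minimal sketch: the helper name `gpu_summary` is hypothetical, and it assumes `nvidia-smi` is on the `PATH` whenever a GPU runtime is attached.

```python
import shutil
import subprocess


def gpu_summary():
    """Return the 'name, total memory' line(s) reported by nvidia-smi, or None."""
    # Without a GPU runtime the nvidia-smi binary is usually not installed.
    if shutil.which("nvidia-smi") is None:
        return None
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader"],
        capture_output=True,
        text=True,
    )
    # A non-zero exit code means the tool exists but could not query a GPU.
    if result.returncode != 0:
        return None
    return result.stdout.strip() or None


print(gpu_summary() or "No GPU detected; check Runtime -> Change runtime type")
```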

In [None]:
#@title Mounting Google Drive

#@markdown Colab notebooks run on virtual machines, so any data stored on the VM is lost as soon as it is closed or reset. The best way to keep input data and trained models is to mount your Google Drive and store them there.

#@markdown After running this cell you will get a link where you grant access to your Drive and copy the auth token. Paste the token into the input field below and press Enter.

gpt2.mount_gdrive()

## Text Generation

In this notebook we will use the model you previously finetuned on your own texts.

In [None]:
#@title Load a Finetuned Model Checkpoint

#@markdown Running this cell will copy the `.tar` checkpoint file from your Google Drive into the Colaboratory VM.

#@markdown **IMPORTANT NOTE:** If you want to rerun this cell, **restart the VM first** (Runtime -> Restart Runtime). You will need to rerun imports.

#@markdown Run name
run_name='run1' #@param {type: "string"}
gpt2.copy_checkpoint_from_gdrive(run_name=run_name)
sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess, run_name=run_name)
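
If loading fails, it can help to confirm that the checkpoint was actually unpacked where the library looks for it. The sketch below is an assumption-laden helper, not part of `gpt-2-simple`: the function name `missing_checkpoint_files` is hypothetical, the `checkpoint/<run_name>/` layout is assumed, and the file list is only a guess at the minimum a run folder should contain.

```python
import os

# Assumed minimum contents of a run folder; adjust if your checkpoint differs.
REQUIRED_FILES = ("hparams.json", "encoder.json", "vocab.bpe")


def missing_checkpoint_files(run_name, checkpoint_dir="checkpoint"):
    """Return the expected files that are absent from checkpoint_dir/run_name."""
    run_dir = os.path.join(checkpoint_dir, run_name)
    return [
        name
        for name in REQUIRED_FILES
        if not os.path.isfile(os.path.join(run_dir, name))
    ]


# Example: an empty list means all of the assumed files were found.
print(missing_checkpoint_files("run1"))
```

A non-empty result usually points to a misspelled run name or an archive that was never copied from Drive.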

In [9]:
#@title Generation with the finetuned model
#@markdown **Generation parameters**

#@markdown Run name (should match the run loaded in the previous cell)
run_name= 'airun1' #@param {type: "string"}

#@markdown You can pass in a `prefix` to the generate function to force the text to start with a given character sequence and generate text from there
prefix = '' #@param {type: "string"}
#@markdown Number of tokens to generate (default 1023, the maximum)
length = 300  #@param {type:"slider", min:1, max:1023, step:1}
#@markdown The higher the temperature, the crazier the text (default 0.7, recommended to keep between 0.7 and 1.0)
temperature=0.7  #@param {type:"slider", min:0.1, max:1, step:0.1}
#@markdown Limits the generated guesses to the top *k* guesses (default 0 which disables the behavior; if the generated output is super crazy, you may want to set `top_k=40`)
top_k=0  #@param {type: "number"}
#@markdown Nucleus sampling: limits the generated guesses to the most likely tokens whose cumulative probability reaches `top_p` (`top_p=0.9` tends to give good results)
top_p=0.9  #@param {type:"slider", min:0, max:1, step:0.1}
#@markdown Number of samples to generate
nsamples=5  #@param {type: "number"}
#@markdown Number of samples to generate in parallel to speed up the process (`nsamples` must be divisible by `batch_size`)
batch_size=5  #@param {type:"slider", min:1, max:20, step:1}
#@markdown Save samples to text file
save_to_file = True #@param {type:"boolean"}


#@markdown *__Set the parameters and run the cell to generate samples__*
gen_file = 'gpt2_gentext_{:%Y%m%d_%H%M%S}.txt'.format(datetime.utcnow()) if save_to_file else None
gpt2.generate(sess,
              run_name=run_name,
              destination_path=gen_file,
              prefix=None if prefix=='' else prefix,
              length=length,
              temperature=temperature,
              top_k=int(top_k),
              top_p=top_p,
              nsamples=int(nsamples),
              batch_size=batch_size
              )
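
To make the `temperature`, `top_k` and `top_p` parameters more concrete, here is a small standard-library sketch of how such settings can shape the choice of a single next token. It is an illustration only: the function `sample_next_token` is hypothetical, the scores are made up, and the library's own implementation may differ (for example in how it combines the two filters).

```python
import math
import random


def sample_next_token(logits, temperature=0.7, top_k=0, top_p=0.9, rng=random):
    """Pick one token index from raw scores using temperature, top-k and top-p."""
    # Temperature: dividing by a value below 1 sharpens the distribution,
    # so likely tokens become even more likely.
    scaled = [score / temperature for score in logits]
    peak = max(scaled)
    weights = [math.exp(score - peak) for score in scaled]  # stable softmax
    total = sum(weights)
    probs = [weight / total for weight in weights]

    # Rank token indices from most to least likely.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    # top_k: keep only the k most likely tokens (0 disables the filter).
    if top_k > 0:
        ranked = ranked[:top_k]

    # top_p: keep the shortest prefix whose cumulative probability reaches top_p.
    if 0 < top_p < 1:
        kept, cumulative = [], 0.0
        for i in ranked:
            kept.append(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break
        ranked = kept

    # Sample among the survivors in proportion to their probabilities.
    return rng.choices(ranked, weights=[probs[i] for i in ranked], k=1)[0]


toy_scores = [0.1, 2.0, 0.5]  # made-up scores for a three-token vocabulary
print(sample_next_token(toy_scores, temperature=0.7, top_k=0, top_p=0.9))
```

With `top_k=1`, or a very small `top_p`, only the single most likely token survives, which is why tightening these values makes the output more predictable.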

In [None]:
#@markdown **Download newest generated file**
if gen_file:
  files.download(gen_file)
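
The saved file holds all samples in one text stream. Assuming the samples are separated by a line of twenty `=` characters (which is believed to be the library's default delimiter, but check your own output), a hypothetical helper like `split_samples` can break the file back into individual texts:

```python
def split_samples(text, delimiter="=" * 20):
    """Split generated text into samples on lines that equal the delimiter."""
    samples, current = [], []
    for line in text.splitlines():
        if line.strip() == delimiter:
            # A delimiter line closes the sample collected so far.
            samples.append("\n".join(current).strip())
            current = []
        else:
            current.append(line)
    # Keep any trailing text that is not followed by a delimiter.
    samples.append("\n".join(current).strip())
    return [sample for sample in samples if sample]


demo = "first sample\n" + "=" * 20 + "\nsecond sample\n" + "=" * 20 + "\n"
print(split_samples(demo))
```

In the notebook, the same function could be applied to the contents of `gen_file` after reading it with `open(gen_file).read()`.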