# Finetune GPT-2 on Reddit Data

by [Max Woolf](http://minimaxir.com)

A variant of the [default notebook](https://colab.research.google.com/drive/1VLG8e7YSEwypxU-noRNhsv5dW4NfTGce) optimized for short-form titles. You should be familiar with that notebook before using this one.

This example uses a CSV export of Reddit data via BigQuery (see this post for more information).


In [0]:
!pip install -q gpt-2-simple
import gpt_2_simple as gpt2
from datetime import datetime
from google.colab import files

## GPU

In [0]:
!nvidia-smi

Sat Sep 28 17:16:18 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 430.40       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
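| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   70C    P8    32W / 149W |      0MiB / 11441MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+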
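If you would rather check the GPU from Python than read the table, `nvidia-smi` also has a query mode. A minimal sketch, assuming `nvidia-smi` is on the `PATH` (as it is on the GPU runtime above); `parse_gpu_line` and `query_gpus` are hypothetical helpers, and `query_gpus` returns an empty list when `nvidia-smi` is missing:

```python
import shutil
import subprocess

def parse_gpu_line(line):
    """Parse one line of nvidia-smi CSV output, e.g. 'Tesla K80, 11441'."""
    name, memory_mib = [field.strip() for field in line.split(",")]
    return name, int(memory_mib)

def query_gpus():
    """Return (name, total memory in MiB) per GPU, or [] if nvidia-smi is not installed."""
    if shutil.which("nvidia-smi") is None:
        return []
    # --query-gpu selects the fields; noheader/nounits keep the output easy to parse
    output = subprocess.check_output(
        ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader,nounits"],
        universal_newlines=True,
    )
    return [parse_gpu_line(line) for line in output.splitlines() if line.strip()]

for name, memory_mib in query_gpus():
    print("{}: {} MiB".format(name, memory_mib))
```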

## Downloading GPT-2

The default query returns 1.3MB of data, so you should probably only finetune the `124M` GPT-2 model on it. If you are working with more Reddit data, move up to `355M`, as the cells below do.
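To see which side of that line your own export falls on, you can check its size before downloading a model. A minimal sketch: `file_size_mb` and `suggest_model` are hypothetical helpers, and the notebook gives no exact cutoff, so `threshold_mb` is a value you have to choose yourself (the 5 MB below is an arbitrary placeholder).

```python
import os

def file_size_mb(path):
    """Size of a file in megabytes (10**6 bytes)."""
    return os.path.getsize(path) / 1e6

def suggest_model(size_mb, threshold_mb):
    """Hypothetical rule of thumb: small datasets -> 124M, larger ones -> 355M.

    threshold_mb is a placeholder; this notebook does not specify a cutoff.
    """
    return "124M" if size_mb <= threshold_mb else "355M"

# The default query's ~1.3 MB export with an arbitrary 5 MB cutoff:
print(suggest_model(1.3, threshold_mb=5.0))  # 124M
```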

In [0]:
gpt2.download_gpt2(model_name="355M")

Fetching checkpoint: 1.05Mit [00:00, 213Mit/s]                                                      
Fetching encoder.json: 1.05Mit [00:00, 94.6Mit/s]                                                   
Fetching hparams.json: 1.05Mit [00:00, 516Mit/s]                                                    
Fetching model.ckpt.data-00000-of-00001: 1.42Git [00:10, 133Mit/s]                                  
Fetching model.ckpt.index: 1.05Mit [00:00, 387Mit/s]                                                
Fetching model.ckpt.meta: 1.05Mit [00:00, 99.7Mit/s]                                                
Fetching vocab.bpe: 1.05Mit [00:00, 137Mit/s]                                                       


## Mounting Google Drive

In [0]:
gpt2.mount_gdrive()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Uploading a Text File to Colaboratory for Training

A single-column CSV is expected.
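Before training, you can confirm that a file really is single-column by counting the fields in each row. Python's `csv` module handles quoted commas inside a post, which a naive `split(",")` would not. A minimal sketch; `csv_column_counts` and `is_single_column` are hypothetical helpers:

```python
import csv

def csv_column_counts(path):
    """Return the set of distinct field counts across all non-empty rows."""
    with open(path, newline="", encoding="utf-8") as f:
        return {len(row) for row in csv.reader(f) if row}

def is_single_column(path):
    """True if every non-empty row, header included, has exactly one field."""
    return csv_column_counts(path) == {1}
```

A file saved together with its pandas index, for example, reports `{2}` rather than `{1}`.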

In [0]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
file_name = "gpt2_10000_posts.csv"

If your text file is larger than 10MB, it is recommended to upload it to Google Drive first, then copy it from Google Drive to the Colaboratory VM.

In [0]:
gpt2.copy_file_from_gdrive(file_name)
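Roughly speaking, `copy_file_from_gdrive` copies the file from the mounted Drive folder into the VM's working directory. A plain-Python equivalent, assuming Drive is mounted at `/content/drive` and the file sits at the top level of `My Drive`; `copy_from_drive` is a hypothetical helper, not part of gpt-2-simple:

```python
import os
import shutil

def copy_from_drive(file_name, drive_root="/content/drive/My Drive", dest_dir="."):
    """Copy `file_name` from the mounted Drive folder into `dest_dir` on the VM.

    Hypothetical helper; the default paths assume the standard Colab mount point.
    """
    src = os.path.join(drive_root, file_name)
    dst = os.path.join(dest_dir, file_name)
    shutil.copyfile(src, dst)  # raises FileNotFoundError if the source is missing
    return dst
```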

## Finetune GPT-2

When you provide a single-column CSV, gpt-2-simple automatically wraps each row in `<|startoftext|>` and `<|endoftext|>` tokens.
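Conceptually, each row becomes one document delimited by the two tokens, as in the sketch below. This illustrates the idea and is not gpt-2-simple's actual loader; in particular, that the loader skips a header row and reads only the first field of each row is an assumption here. `wrap_rows` is a hypothetical helper.

```python
import csv
import io

START_TOKEN = "<|startoftext|>"
END_TOKEN = "<|endoftext|>"

def wrap_rows(csv_text):
    """Wrap each data row's first field in start/end tokens, one document per line."""
    reader = csv.reader(io.StringIO(csv_text))
    next(reader, None)  # skip the header row
    return "".join(f"{START_TOKEN}{row[0]}{END_TOKEN}\n" for row in reader if row)

print(wrap_rows("title\nI think the game is fine\n"), end="")
# <|startoftext|>I think the game is fine<|endoftext|>
```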

Short-form text is more likely to overfit, so train for fewer steps than you would for long-form content.
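One way to reason about step counts is to estimate how many passes over the data a run makes. A back-of-the-envelope sketch, assuming each step consumes `batch_size` windows of 1024 tokens (GPT-2's context length) and roughly 4 characters per BPE token for English text; both the ratio and the `approx_epochs` helper are assumptions, not values from this notebook:

```python
def approx_epochs(dataset_chars, steps, batch_size=1, context_tokens=1024, chars_per_token=4.0):
    """Rough number of passes over the dataset that a finetuning run makes."""
    dataset_tokens = dataset_chars / chars_per_token
    tokens_seen = steps * batch_size * context_tokens
    return tokens_seen / dataset_tokens

# ~1.3 MB of text (treated as ~1.3 million characters) and 1,100 steps:
print(round(approx_epochs(1.3e6, steps=1100), 1))  # 3.5
```

Under these assumptions, 1,100 steps on the default 1.3MB export already sees each post about three and a half times, which is one reason short-form data overfits quickly.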

In [0]:
import pandas as pd

df = pd.read_csv('/content/drive/My Drive/gpt2_10000_posts.csv')
df.head()

| Unnamed: 0 | 0 |
|---|---|
| 0 | This cycle of good expansion and then bad expa... |
| 1 | It’s hard to judge an expansion before it’s ov... |
| 2 | There is no such cycle.Legion was not a great ... |
| 3 | That’s true, I am of the mind to agree. It mak... |
| 4 | OThe thing that worries me is the amount of ti... |
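An `Unnamed: 0` column in a preview like this usually means the CSV was written together with its pandas index, so the file has two columns rather than one. A minimal sketch for writing out only the text column, assuming the text lives in the column headed `0` as shown; `keep_text_column` and the output filename are hypothetical:

```python
import pandas as pd

def keep_text_column(df, text_column, out_path):
    """Write only `text_column` to a CSV, dropping missing rows and the pandas index."""
    df[[text_column]].dropna().to_csv(out_path, index=False)
    return out_path

# In the notebook, something like:
# keep_text_column(df, "0", "gpt2_10000_posts_single.csv")
```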


In [0]:
sess = gpt2.start_tf_sess()

gpt2.finetune(sess,
              dataset=file_name,
              model_name='355M',
              steps=1100,
              restore_from='latest',
              run_name='10000_posts',
              print_every=10,
              sample_every=100
              )

ValueError: ignored

In [0]:
gpt2.copy_checkpoint_to_gdrive(run_name='10000_posts')

## Load a Trained Model Checkpoint

In [0]:
import gpt_2_simple as gpt2

gpt2.copy_checkpoint_from_gdrive(run_name='10000_posts')

In [0]:
sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess, run_name='10000_posts')

Loading checkpoint checkpoint/10000_posts/model-1000
Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from checkpoint/10000_posts/model-1000


## Generate Text From The Trained Model

Same as normal generate functions, except with additional parameters to handle the new tokens.

In [0]:
gpt2.generate(sess, run_name='10000_posts',
              length=50,
              nsamples=10,
              prefix="<|startoftext|>",
              truncate="<|endoftext|>")

<|startoftext|>I am not sure if this is a good idea, but I do not see it being a good idea on the forums at this point.
<|startoftext|>Yes, you can M*aM every time you want. What alts from M to M? Other then that, you can do M*aM every time you want. Just be sure to be clear what alt you’re on
<|startoftext|>This is a fixed build. It depends on the class you have. If you have a rogue, on the other hand, you should take the same build. That way you can just change classes to a different class when you want to try the build
<|startoftext|>your an idiot and don’t know what you’re talking about. This is just stupid. Even in vanilla, there was a lot of discussion about the combat system. I understand that you don’t like it, but this
<|startoftext|>I’m actually doing quite well.
<|startoftext|>I’m certain that a lot of people who are fighting this issue are not actually from the Philippines or South East Asia.                        
<|startoftext|>I’m not sure why you’re talking about ponie
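The samples above still carry the `<|startoftext|>` prefix. If you collect samples with `return_as_list=True` instead of printing them, a small post-processing step can strip it. A minimal sketch; `clean_samples` is a hypothetical helper that removes a leading start token and drops samples that end up empty:

```python
START_TOKEN = "<|startoftext|>"

def clean_samples(samples, start_token=START_TOKEN):
    """Strip a leading start token and surrounding whitespace; drop empty samples."""
    cleaned = []
    for text in samples:
        if text.startswith(start_token):
            text = text[len(start_token):]
        text = text.strip()
        if text:
            cleaned.append(text)
    return cleaned

print(clean_samples(["<|startoftext|>This is a fixed build.", "<|startoftext|>   "]))
# ['This is a fixed build.']
```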

In [0]:
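# run_name must match the finetuned run; without it, generate() looks for
# the default 'run1' checkpoint, which does not exist here (FileNotFoundError)
gpt2.generate(sess,
              run_name='10000_posts',
              length=100,
              temperature=1.0,
              nsamples=10,
              batch_size=10,
              prefix="<|startoftext|>",
              truncate="<|endoftext|>",
              include_prefix=False
              )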

If generating in bulk, you may want to set `sample_delim=''` to remove the delimiter between samples.
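Conversely, keeping the delimiter makes it easy to split a generated file back into individual samples. A minimal sketch, assuming the default delimiter is a line of 20 `=` characters (check the `generate_to_file` signature in your installed version); `split_samples` is a hypothetical helper:

```python
DEFAULT_DELIM = "=" * 20 + "\n"

def split_samples(text, delim=DEFAULT_DELIM):
    """Split generated text on the sample delimiter, dropping empty pieces."""
    if not delim:
        raise ValueError("delimiter must be non-empty; was sample_delim='' used?")
    return [piece.strip() for piece in text.split(delim) if piece.strip()]

# e.g. with open(gen_file, encoding="utf-8") as f: samples = split_samples(f.read())
print(split_samples("first sample\n" + DEFAULT_DELIM + "second sample\n" + DEFAULT_DELIM))
# ['first sample', 'second sample']
```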

In [0]:
gen_file = 'gpt2_gentext_{:%Y%m%d_%H%M%S}.txt'.format(datetime.utcnow())

gpt2.generate_to_file(sess,
                      run_name='10000_posts',
                      destination_path=gen_file,
                      length=100,
                      temperature=1.0,
                      nsamples=100,
                      batch_size=20,
                      prefix="<|startoftext|>I think the game",
                      truncate="<|endoftext|>",
                      include_prefix=False,
                      sample_delim=''
                      )

In [0]:
# may have to run twice to get file to download
files.download(gen_file)

# Etcetera

If the notebook has errors (e.g. GPU Sync Fail), force-kill the Colaboratory virtual machine and restart it with the command below:

In [0]:
!kill -9 -1

# LICENSE

MIT License

Copyright (c) 2019 Max Woolf

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.