<a href="https://colab.research.google.com/github/aditeyabaral/gpt2-implementation/blob/main/simple_gpt2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Simple GPT-2

[Library](https://github.com/minimaxir/gpt-2-simple)

[Notebook Guide](https://colab.research.google.com/drive/1VLG8e7YSEwypxU-noRNhsv5dW4NfTGce#scrollTo=aeXshJM-Cuaf)

# Installing and Setting up Environment

In [1]:
%tensorflow_version 1.x
!pip install gpt-2-simple
import gpt_2_simple as gpt2
import pandas as pd
import numpy as np

TensorFlow 1.x selected.
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



In [2]:
MODEL_TYPE = "355M"
gpt2.download_gpt2(model_name=MODEL_TYPE)

Fetching checkpoint: 1.05Mit [00:00, 206Mit/s]                                                      
Fetching encoder.json: 1.05Mit [00:00, 64.4Mit/s]                                                   
Fetching hparams.json: 1.05Mit [00:00, 671Mit/s]                                                    
Fetching model.ckpt.data-00000-of-00001: 1.42Git [00:07, 184Mit/s]                                  
Fetching model.ckpt.index: 1.05Mit [00:00, 205Mit/s]                                                
Fetching model.ckpt.meta: 1.05Mit [00:00, 77.5Mit/s]                                                
Fetching vocab.bpe: 1.05Mit [00:00, 136Mit/s]                                                       


# Loading Training Data

In [3]:
!curl -L -O https://raw.githubusercontent.com/aditeyabaral/gpt2-implementation/main/Simpsons.csv

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 8998k  100 8998k    0     0  18.7M      0 --:--:-- --:--:-- --:--:-- 18.7M


In [4]:
df = pd.read_csv("/content/Simpsons.csv")
# Drop rows with missing values and exact duplicates, then renumber the rows from 0.
df.dropna(inplace=True)
df.drop_duplicates(inplace=True)
df.reset_index(drop=True, inplace=True)
print(df.shape)
df.head()

(126646, 2)


| | raw_character_text | spoken_words |
|---|---|---|
| 0 | Miss Hoover | No, actually, it was a little of both. Sometim... |
| 1 | Lisa Simpson | Where's Mr. Bergstrom? |
| 2 | Miss Hoover | I don't know. Although I'd sure like to talk t... |
| 3 | Lisa Simpson | That life is worth living. |
| 4 | Edna Krabappel-Flanders | The polls will be open from now until the end ... |


In [5]:
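def character_slice(character):
  # Return only the rows spoken by `character`. Raise on an unknown name
  # instead of silently returning None.
  if character not in df["raw_character_text"].unique():
    raise ValueError(f"Unknown character: {character!r}")
  return df[df["raw_character_text"] == character]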

In [6]:
CHARACTER = "Lisa Simpson"

character_df = character_slice(CHARACTER)
print(character_df.shape)
character_df.head()

(10144, 2)


| | raw_character_text | spoken_words |
|---|---|---|
| 1 | Lisa Simpson | Where's Mr. Bergstrom? |
| 3 | Lisa Simpson | That life is worth living. |
| 8 | Lisa Simpson | Mr. Bergstrom! Mr. Bergstrom! |
| 10 | Lisa Simpson | Do you know where I could find him? |
| 12 | Lisa Simpson | The train, how like him... traditional, yet en... |


## Adding delimiters

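This step is optional. Joining the lines with the `<|endoftext|>` token marks where each line of dialogue ends, so the fine-tuned model can retain the line-by-line structure of the training data.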

In [7]:
train_data = "\n<|endoftext|>\n".join(character_df["spoken_words"].values)
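
To make the effect of the delimiter concrete, here is a minimal sketch on three lines copied from the `character_df.head()` output above. The list `lines` is only a stand-in for `character_df["spoken_words"].values`, and the round-trip check at the end is illustrative, not part of the original notebook.

```python
# Stand-in for character_df["spoken_words"].values (lines copied from the head() output).
lines = [
    "Where's Mr. Bergstrom?",
    "That life is worth living.",
    "Mr. Bergstrom! Mr. Bergstrom!",
]

DELIMITER = "\n<|endoftext|>\n"

# Same join as in the cell above: one delimiter between consecutive lines.
train_data = DELIMITER.join(lines)
print(train_data)

# Splitting on the delimiter recovers the original lines, which is what
# "retaining the structure" amounts to here.
recovered = train_data.split(DELIMITER)
```

Because the delimiter sits only between lines, a corpus of n lines contains n - 1 delimiters and no trailing one.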

In [8]:
with open("corpus.txt", "w", encoding="utf-8") as f:
  f.write(train_data)

# Training GPT-2

In [9]:
sess = gpt2.start_tf_sess()

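gpt2.finetune(sess,
              dataset="corpus.txt",
              model_name=MODEL_TYPE,
              steps=100,             # total number of training steps
              restore_from='fresh',  # start from the base GPT-2 weights
              run_name='run1',       # checkpoints are saved under checkpoint/run1
              print_every=1,         # log the loss at every step
              sample_every=5,        # print a generated sample every 5 steps
              save_every=10          # save a checkpoint every 10 steps
              )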

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Instructions for updating:
Please use tensorflow.python.ops.op_selector.get_backward_walk_ops.
Loading checkpoint models/355M/model.ckpt
INFO:tensorflow:Restoring parameters from models/355M/model.ckpt


  0%|          | 0/1 [00:00<?, ?it/s]

Loading dataset...


100%|██████████| 1/1 [00:01<00:00,  1.58s/it]


dataset has 229645 tokens
Training...
[1 | 96.51] loss=2.00 avg=2.00
[2 | 183.76] loss=2.25 avg=2.13
[3 | 266.65] loss=1.98 avg=2.08
[4 | 351.15] loss=1.90 avg=2.03
[5 | 434.79] loss=2.40 avg=2.11
[6 | 517.20] loss=2.22 avg=2.13
[7 | 599.85] loss=2.41 avg=2.17
[8 | 683.08] loss=2.20 avg=2.17
[9 | 765.77] loss=1.87 avg=2.14
interrupted
Saving checkpoint/run1/model-9


KeyboardInterrupt: ignored
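
The log above can be used to estimate how long the full run would have taken. The sketch below parses the nine logged lines, assuming each has the form `[step | elapsed seconds] loss=... avg=...`. It also checks a second assumption: that `avg` is a running average of the loss with a decay of 0.99. That decay value is inferred from the numbers in this log, not taken from the notebook.

```python
import re

# Log lines copied from the training output above.
log = """\
[1 | 96.51] loss=2.00 avg=2.00
[2 | 183.76] loss=2.25 avg=2.13
[3 | 266.65] loss=1.98 avg=2.08
[4 | 351.15] loss=1.90 avg=2.03
[5 | 434.79] loss=2.40 avg=2.11
[6 | 517.20] loss=2.22 avg=2.13
[7 | 599.85] loss=2.41 avg=2.17
[8 | 683.08] loss=2.20 avg=2.17
[9 | 765.77] loss=1.87 avg=2.14
"""

pattern = re.compile(r"\[(\d+) \| ([\d.]+)\] loss=([\d.]+) avg=([\d.]+)")
rows = [(int(s), float(t), float(l), float(a)) for s, t, l, a in pattern.findall(log)]

# Average wall-clock seconds per step between the first and last logged step.
first_step, first_time = rows[0][0], rows[0][1]
last_step, last_time = rows[-1][0], rows[-1][1]
seconds_per_step = (last_time - first_time) / (last_step - first_step)

# steps=100 in the finetune call above.
estimated_total_minutes = seconds_per_step * 100 / 60
print(f"{seconds_per_step:.1f} s/step, about {estimated_total_minutes:.0f} min for 100 steps")

# Assumption: avg is a decayed running average with decay 0.99.
num = den = 0.0
recomputed_avg = []
for _, _, loss, _ in rows:
    num = num * 0.99 + loss
    den = den * 0.99 + 1.0
    recomputed_avg.append(num / den)
```

At well over a minute per step on this runtime, 100 steps would take more than two hours, which is consistent with the run being interrupted after step 9.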

# Load a Trained Model Checkpoint

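The next cell loads the fine-tuned checkpoint saved under `run1`, together with the metadata needed to generate text. If the training session from the previous section is still active in the same runtime, reset it first (gpt-2-simple provides `gpt2.reset_session(sess)` for this), since loading the model a second time into the same TensorFlow graph typically fails.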

In [None]:
sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess, run_name='run1')
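
Since the training run above was interrupted, it can be worth confirming which checkpoint actually exists on disk before loading. The helper below is a hypothetical sketch, not part of gpt-2-simple. It assumes checkpoints live under `checkpoint/<run_name>/` (as the `Saving checkpoint/run1/model-9` log line suggests) and that each saved step leaves a `model-N.index` file, which is the usual TensorFlow naming but is an assumption here.

```python
import os
import re


def latest_checkpoint_step(run_name, checkpoint_dir="checkpoint"):
    """Return the highest step N with a model-N.index file in the run folder, or None."""
    run_dir = os.path.join(checkpoint_dir, run_name)
    if not os.path.isdir(run_dir):
        return None
    steps = []
    for name in os.listdir(run_dir):
        match = re.fullmatch(r"model-(\d+)\.index", name)
        if match:
            steps.append(int(match.group(1)))
    return max(steps) if steps else None


step = latest_checkpoint_step("run1")
print("no checkpoint found" if step is None else f"latest checkpoint: step {step}")
```

For the run logged above, this would be expected to report step 9 if the interrupted save completed.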

# Generate Text From The Trained Model

## User Prompt

In [None]:
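gpt2.generate(sess,
              run_name='run1',
              length=250,        # number of tokens to generate per sample
              temperature=0.7,   # lower values give more predictable text
              prefix="LORD",     # prompt that every sample starts from
              nsamples=5,        # must be a multiple of batch_size
              batch_size=5
              )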
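
Because the training corpus separated lines with `<|endoftext|>`, a generated sample may contain several lines of dialogue joined by that token. The sketch below splits one sample back into individual lines. `split_dialogue` is a hypothetical helper, and the sample string is made-up placeholder text, not real model output. With gpt-2-simple, the samples can be captured as a list of strings by passing `return_as_list=True` to `gpt2.generate`.

```python
DELIMITER = "<|endoftext|>"


def split_dialogue(sample):
    """Split one generated sample into individual, non-empty dialogue lines."""
    parts = (part.strip() for part in sample.split(DELIMITER))
    return [part for part in parts if part]


# Made-up placeholder text, NOT real model output.
placeholder_sample = "LORD, first line.\n<|endoftext|>\nSecond line.\n<|endoftext|>\n"

for line in split_dialogue(placeholder_sample):
    print(line)
```

If only the first line of each sample is wanted, `gpt2.generate` also accepts a `truncate` argument that cuts the output at a given token.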

## Random Sampling

In [None]:
gpt2.generate(sess, run_name='run1')