In [1]:
!pip install gpt-2-simple tqdm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting gpt-2-simple
  Downloading gpt_2_simple-0.8.1.tar.gz (26 kB)
Collecting toposort
  Downloading toposort-1.7-py2.py3-none-any.whl (9.0 kB)
Building wheels for collected packages: gpt-2-simple
  Building wheel for gpt-2-simple (setup.py) ... [?25l[?25hdone
  Created wheel for gpt-2-simple: filename=gpt_2_simple-0.8.1-py3-none-any.whl size=24577 sha256=005eab4695a5f0426ba9161e524b89bb84daf347af71f5971628454492126424
  Stored in directory: /root/.cache/pip/wheels/d6/89/8a/f5de6944286d1ac2658b0caa7eae3c8cda50f770cdc957217f
Successfully built gpt-2-simple
Installing collected packages: toposort, gpt-2-simple
Successfully installed gpt-2-simple-0.8.1 toposort-1.7


In [2]:
# Mount Google Drive. base_folder (configuration cell) is a relative path into this
# mount, which assumes the working directory is /content (the Colab default).
from google.colab import drive

GDRIVE_MOUNT_POINT = '/content/gdrive'
drive.mount(GDRIVE_MOUNT_POINT)

Mounted at /content/gdrive


# GPT-2 Fine-tuning

In [3]:
import gpt_2_simple as gpt2
import os
import requests
import pandas as pd
from tqdm import tqdm
import tensorflow as tf

In [4]:
# Base GPT-2 model size and data locations (relative to the working directory).
model_name = "124M"
base_folder = "gdrive/MyDrive/IRE"
# train.csv is the raw data; gpt_train.csv is the single-column file built for fine-tuning.
train_file, res_file = (os.path.join(base_folder, name)
                        for name in ("train.csv", "gpt_train.csv"))

In [None]:
# Build the fine-tuning file once: one row per movie, "<synopsis> = @ = <tagline>".
if not os.path.exists(res_file):
  # train_file already contains base_folder; joining it with base_folder again
  # pointed at gdrive/MyDrive/IRE/gdrive/MyDrive/IRE/train.csv.
  train_df = pd.read_csv(train_file)
  # A missing synopsis or tagline is read as NaN (a float), and concatenating it
  # with a str raises TypeError, so such rows cannot form a training example.
  train_df = train_df.dropna(subset=['plot_synopsis', 'tagline'])

  # Keep the first 1000 whitespace-separated words; split/join also collapses
  # newlines and repeated spaces into single spaces.
  # NOTE(review): inference truncates prompts to 600 words, and 1000 words is likely
  # longer than GPT-2's 1024-token context — confirm the mismatch is intentional.
  plots = [' '.join(x.split()[:1000]) for x in train_df['plot_synopsis']]
  taglines = list(train_df['tagline'])

  # ' = @ = ' separates synopsis and tagline; the inference prompt ends with it too.
  new_data = [p + ' = @ = ' + t for p, t in zip(plots, taglines)]
  # Written as one unnamed column (header row "0") without the index.
  pd.DataFrame(new_data).to_csv(res_file, index=False)

In [5]:
# Download the base GPT-2 weights once. model_name comes from the configuration
# cell; re-assigning it here would silently override a change made there.
if not os.path.isdir(os.path.join("models", model_name)):
  print(f"Downloading {model_name} model...")
  gpt2.download_gpt2(model_name=model_name)  # saved under ./models/<model_name>/


Downloading 124M model...


Fetching checkpoint: 1.05Mit [00:00, 291Mit/s]                                                      
Fetching encoder.json: 1.05Mit [00:01, 564kit/s]
Fetching hparams.json: 1.05Mit [00:00, 412Mit/s]                                                    
Fetching model.ckpt.data-00000-of-00001: 498Mit [01:27, 5.67Mit/s]                                  
Fetching model.ckpt.index: 1.05Mit [00:00, 535Mit/s]                                                
Fetching model.ckpt.meta: 1.05Mit [00:01, 734kit/s]
Fetching vocab.bpe: 1.05Mit [00:01, 850kit/s]


In [None]:
# Fine-tune the base model on the synopsis/tagline CSV. The run name is spelled out
# because the inference section reloads the checkpoint from checkpoint/run1.
sess = gpt2.start_tf_sess()
gpt2.finetune(sess,
              dataset=res_file,
              model_name=model_name,
              run_name='run1',
              steps=1000)  # upper bound on the number of training steps

# Sanity check: generate from the fine-tuned model with no prefix.
gpt2.generate(sess)


Loading checkpoint models/124M/model.ckpt
Loading dataset...


100%|██████████| 1/1 [00:01<00:00,  1.30s/it]


dataset has 7004654 tokens
Training...
[1 | 7.31] loss=3.33 avg=3.33
[2 | 9.43] loss=3.53 avg=3.43
[3 | 11.55] loss=3.38 avg=3.41
[4 | 13.68] loss=3.49 avg=3.43
[5 | 15.81] loss=3.49 avg=3.45
[6 | 17.94] loss=3.57 avg=3.47
[7 | 20.09] loss=3.31 avg=3.44
[8 | 22.24] loss=3.53 avg=3.46
[9 | 24.39] loss=3.39 avg=3.45
[10 | 26.55] loss=3.32 avg=3.43
[11 | 28.72] loss=3.33 avg=3.42
[12 | 30.89] loss=3.25 avg=3.41
[13 | 33.07] loss=3.46 avg=3.41
[14 | 35.25] loss=3.29 avg=3.40
[15 | 37.44] loss=3.43 avg=3.41
[16 | 39.64] loss=3.35 avg=3.40
[17 | 41.84] loss=3.36 avg=3.40
[18 | 44.05] loss=3.41 avg=3.40
[19 | 46.27] loss=3.35 avg=3.40
[20 | 48.49] loss=3.41 avg=3.40
[21 | 50.72] loss=3.39 avg=3.40
[22 | 52.95] loss=3.36 avg=3.40
[23 | 55.18] loss=3.25 avg=3.39
[24 | 57.42] loss=3.27 avg=3.38
[25 | 59.67] loss=3.19 avg=3.37
[26 | 61.93] loss=3.48 avg=3.38
[27 | 64.18] loss=3.33 avg=3.38
[28 | 66.45] loss=3.25 avg=3.37
[29 | 68.74] loss=3.35 avg=3.37
[30 | 71.02] loss=3.28 avg=3.37
[31 | 73.31]

In [None]:
!cp -r checkpoint/ gdrive/MyDrive/IRE/

# Inference

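Restart the runtime before running inference on the test data. Afterwards, re-run the setup cells above (package install, Drive mount, imports and configuration). The cells below rely on `gpt2`, `pd`, `tf`, `base_folder` and `model_name` from them.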



In [6]:
!cp -r gdrive/MyDrive/IRE/checkpoint .

In [7]:
# Load the fine-tuned weights (checkpoint/run1) into a fresh TensorFlow session.
sess2 = gpt2.start_tf_sess()
gpt2.load_gpt2(sess2, checkpoint_dir='checkpoint', run_name='run1')


Loading checkpoint checkpoint/run1/model-1000


In [11]:
# Test split: synopses become generation prompts; show them beside the original taglines.
test_file = os.path.join(base_folder, 'test.csv')
test_df = pd.read_csv(test_file)
plots = test_df['plot_synopsis'].tolist()
test_df[['plot_synopsis', 'tagline']]

Unnamed: 0,plot_synopsis,tagline
0,"It's August 9, 1985 in the year of Halley's Co...","In the blink of an eye, the terror begins."
1,"Kuzco, a young selfish Inca emperor, rejects t...",It's All About.....ME!
2,The movie opens with titles and credits over a...,Zombies of the African Voodoo coast!
3,"In Paris, bloodthirsty jewel thief Roger Sarte...","Behind every gun is ""The Sicilian Clan!"""
4,A young boy named Sean Donovan lives with his ...,Beware the Hero
...,...,...
836,"Ivan Drago, a Russian Soviet boxer, arrives in...",He's facing the ultimate challenge. And fighti...
837,"In 1944, Lecter is eight years old, living in ...",It started with revenge.
838,The story revolves around a struggle to determ...,He's Never Been To Earth. He's Never Even Slep...
839,Jack Crow (James Woods) is a professional and ...,From the Master of Evil. Comes a New Breed of ...


In [13]:
# Generate one tagline per test synopsis with the fine-tuned model.
PROMPT_SEPARATOR = ' = @ = '  # same marker that joins synopsis and tagline in gpt_train.csv
MAX_PROMPT_WORDS = 600        # presumably keeps prompt + 20 new tokens inside GPT-2's context
RESET_EVERY = 15              # rebuild graph and session every N prompts

results = []
for i, plot in enumerate(tqdm(plots)):  # tqdm now knows the total number of prompts
  if i % RESET_EVERY == 0:
    # Generation slows down over many calls (presumably each generate() adds ops to
    # the default graph), so periodically start from a clean graph. Close the
    # previous session first: the old code only reset the graph, leaking one open
    # session per reset, and resetting while a session is active is undefined in TF.
    sess2.close()
    tf.compat.v1.reset_default_graph()
    sess2 = gpt2.start_tf_sess()
    gpt2.load_gpt2(sess2, run_name='run1')

  prompt = ' '.join(plot.split()[:MAX_PROMPT_WORDS]) + PROMPT_SEPARATOR
  pred = gpt2.generate(sess2, model_name=model_name, run_name='run1',
                       temperature=0.7, include_prefix=False, prefix=prompt,
                       truncate='<|endoftext|>', nsamples=1, length=20,
                       return_as_list=True)[0]
  # The returned text can still start with the prompt (presumably when the truncate
  # token is not generated), so strip its first occurrence.
  results.append(pred.replace(prompt, '', 1))

0it [00:00, ?it/s]

Loading checkpoint checkpoint/run1/model-1000


15it [01:19,  6.01s/it]

Loading checkpoint checkpoint/run1/model-1000


30it [02:43,  6.22s/it]

Loading checkpoint checkpoint/run1/model-1000


45it [04:10,  6.56s/it]

Loading checkpoint checkpoint/run1/model-1000


60it [05:35,  6.38s/it]

Loading checkpoint checkpoint/run1/model-1000


75it [07:00,  6.22s/it]

Loading checkpoint checkpoint/run1/model-1000


90it [08:26,  6.30s/it]

Loading checkpoint checkpoint/run1/model-1000


105it [09:49,  6.06s/it]

Loading checkpoint checkpoint/run1/model-1000


120it [11:14,  6.33s/it]

Loading checkpoint checkpoint/run1/model-1000


135it [12:37,  6.22s/it]

Loading checkpoint checkpoint/run1/model-1000


150it [14:03,  6.21s/it]

Loading checkpoint checkpoint/run1/model-1000


165it [15:27,  6.16s/it]

Loading checkpoint checkpoint/run1/model-1000


180it [16:51,  6.16s/it]

Loading checkpoint checkpoint/run1/model-1000


195it [18:18,  6.56s/it]

Loading checkpoint checkpoint/run1/model-1000


210it [19:41,  6.16s/it]

Loading checkpoint checkpoint/run1/model-1000


225it [21:04,  6.14s/it]

Loading checkpoint checkpoint/run1/model-1000


240it [22:29,  6.23s/it]

Loading checkpoint checkpoint/run1/model-1000


255it [23:54,  6.35s/it]

Loading checkpoint checkpoint/run1/model-1000


270it [25:21,  6.52s/it]

Loading checkpoint checkpoint/run1/model-1000


285it [26:48,  6.34s/it]

Loading checkpoint checkpoint/run1/model-1000


300it [28:14,  6.45s/it]

Loading checkpoint checkpoint/run1/model-1000


315it [29:39,  6.65s/it]

Loading checkpoint checkpoint/run1/model-1000


330it [31:06,  6.58s/it]

Loading checkpoint checkpoint/run1/model-1000


345it [32:33,  6.43s/it]

Loading checkpoint checkpoint/run1/model-1000


360it [33:58,  6.24s/it]

Loading checkpoint checkpoint/run1/model-1000


375it [35:23,  6.28s/it]

Loading checkpoint checkpoint/run1/model-1000


390it [36:48,  6.38s/it]

Loading checkpoint checkpoint/run1/model-1000


405it [38:14,  6.27s/it]

Loading checkpoint checkpoint/run1/model-1000


420it [39:39,  6.24s/it]

Loading checkpoint checkpoint/run1/model-1000


435it [41:05,  6.46s/it]

Loading checkpoint checkpoint/run1/model-1000


450it [42:31,  6.39s/it]

Loading checkpoint checkpoint/run1/model-1000


465it [43:56,  6.26s/it]

Loading checkpoint checkpoint/run1/model-1000


480it [45:21,  6.27s/it]

Loading checkpoint checkpoint/run1/model-1000


495it [46:47,  6.42s/it]

Loading checkpoint checkpoint/run1/model-1000


510it [48:13,  6.36s/it]

Loading checkpoint checkpoint/run1/model-1000


525it [49:38,  6.33s/it]

Loading checkpoint checkpoint/run1/model-1000


540it [51:03,  6.60s/it]

Loading checkpoint checkpoint/run1/model-1000


555it [52:29,  6.46s/it]

Loading checkpoint checkpoint/run1/model-1000


570it [53:55,  6.33s/it]

Loading checkpoint checkpoint/run1/model-1000


585it [55:20,  6.42s/it]

Loading checkpoint checkpoint/run1/model-1000


600it [56:45,  6.66s/it]

Loading checkpoint checkpoint/run1/model-1000


615it [58:11,  6.51s/it]

Loading checkpoint checkpoint/run1/model-1000


630it [59:37,  6.36s/it]

Loading checkpoint checkpoint/run1/model-1000


645it [1:01:02,  6.18s/it]

Loading checkpoint checkpoint/run1/model-1000


660it [1:02:26,  6.14s/it]

Loading checkpoint checkpoint/run1/model-1000


675it [1:03:51,  6.30s/it]

Loading checkpoint checkpoint/run1/model-1000


690it [1:05:16,  6.18s/it]

Loading checkpoint checkpoint/run1/model-1000


705it [1:06:42,  6.85s/it]

Loading checkpoint checkpoint/run1/model-1000


720it [1:08:09,  6.66s/it]

Loading checkpoint checkpoint/run1/model-1000


735it [1:09:35,  6.51s/it]

Loading checkpoint checkpoint/run1/model-1000


750it [1:11:01,  6.31s/it]

Loading checkpoint checkpoint/run1/model-1000


765it [1:12:26,  6.34s/it]

Loading checkpoint checkpoint/run1/model-1000


780it [1:13:51,  6.50s/it]

Loading checkpoint checkpoint/run1/model-1000


795it [1:15:17,  6.40s/it]

Loading checkpoint checkpoint/run1/model-1000


810it [1:16:41,  6.21s/it]

Loading checkpoint checkpoint/run1/model-1000


825it [1:18:05,  6.16s/it]

Loading checkpoint checkpoint/run1/model-1000


840it [1:19:31,  6.62s/it]

Loading checkpoint checkpoint/run1/model-1000


841it [1:19:39,  5.68s/it]


In [15]:
# Attach the generated taglines positionally, one per test row.
test_df['predictions'] = pd.Series(results, index=test_df.index)

In [18]:
# Save test rows with predictions under <base_folder>/GPT/. to_csv does not create
# missing directories, so make sure the GPT/ subfolder exists first.
results_dir = os.path.join(base_folder, 'GPT')
os.makedirs(results_dir, exist_ok=True)
test_df.to_csv(os.path.join(results_dir, 'results.csv'), index=False)