In [None]:
# run once (may have to restart notebook)
# !pip install tensorflow-gpu==1.15 --user # if you do not have a gpu remove -gpu 
# !pip install gpt-2-simple --user

In [None]:
import pandas as pd
import numpy as np
import json
import os
import requests

In [None]:
# Set is_local to True to load an already fine-tuned checkpoint from local disk
# (via a local copy of the gpt-2-simple sources); leave it False to use the
# installed gpt_2_simple package.
is_local = False
if not is_local:
    import gpt_2_simple as gpt2
else:
    # import tensorflow as tf
    import sys
    # Put the local gpt-2-simple source directory first on the import path so
    # that `gpt_2` is importable as a top-level module.
    sys.path.insert(0, os.path.abspath('../../gpt-2-simple-0.7/gpt_2_simple'))
    import gpt_2 as gpt2
    local_model_name = 'model-100'  # checkpoint name handed to load_gpt2
    local_checkpoint_dir = "../../local_checkpoints"  # directory where local models are stored

In [None]:
# check to make sure gpu is recognized for significantly faster training

# from tensorflow.python.client import device_lib
# print(device_lib.list_local_devices())

In [None]:
# Base GPT-2 model to use; it is downloaded only when models/<model_name> is
# not already present as a directory.
model_name = "124M"
_model_dir = os.path.join("models", model_name)
if not os.path.isdir(_model_dir):
    print(f"Downloading {model_name} model...")
    gpt2.download_gpt2(model_name=model_name)

In [None]:
# Directory containing the text files named below. The trailing slash matters:
# paths are built from it by plain string concatenation.
dir_path = "../data/matt_results/"

# One triple per corpus: (full text file, sample file, model/run name).
dem_file_name, dem_sample_name, dem_model_name = (
    "democrats_result.txt",
    "democrats_sample.txt",
    "matt_dem",
)

rep_file_name, rep_sample_name, rep_model_name = (
    "republican_result.txt",
    "republican_sample.txt",
    "matt_rep",
)

both_file_name, both_sample_name, both_model_name = (
    "both_result.txt",
    "both_sample.txt",
    "matt_both",
)

In [None]:
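# Optional one-off step (left commented out): read the start of each side's file and write it out as a smaller sample file.
# Note: readlines(n) stops reading once roughly n bytes/characters of lines have been collected.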

# with open(dir_path+dem_file_name,'r') as demf:
#     dem_data = demf.readlines(500000)
# with open(dir_path+dem_sample_name,'w+') as dem_write:
#     dem_write.writelines(dem_data)

# with open(dir_path+rep_file_name,'r') as repf:
#     rep_data = repf.readlines(500000)
# with open(dir_path+rep_sample_name,'w+') as rep_write:
#     rep_write.writelines(rep_data)

# dem_data.extend(rep_data)
# both_data = dem_data


# with open(dir_path+both_file_name,'r') as bothf:
#     both_data = bothf.readlines(100000)
# with open(dir_path+both_sample_name,'w+') as both_write:
#     both_write.writelines(both_data)

In [None]:
# Select the corpus to work with (currently the combined "both" one).
train_name = both_model_name  # passed as run_name when training/loading the model
train_fp = f"{dir_path}{both_file_name}"  # text file to train model on
results_fp = f"../results/{train_name}_generated.txt"  # output file for generated text

In [None]:
# This cell takes the longest: it either loads a saved checkpoint (is_local)
# or fine-tunes the base model on train_fp.
#
# Safe to re-run: when a session from a previous run of this cell exists it is
# reset (gpt2.reset_session clears the TensorFlow graph and opens a new
# session), so a kernel restart is no longer needed before running it again.
if "sess" in globals():
    sess = gpt2.reset_session(sess)
else:
    sess = gpt2.start_tf_sess()

if is_local:
    # Restore previously saved weights from <local_checkpoint_dir>/<train_name>.
    gpt2.load_gpt2(sess,
                   checkpoint=local_model_name,
                   run_name=train_name,
                   checkpoint_dir=local_checkpoint_dir)
else:
    gpt2.finetune(sess,
                  train_fp,
                  model_name=model_name,
                  steps=1000, # steps is max number of training steps
                  restore_from='fresh', # start from the base model, do not resume an earlier run
                  print_every=20, # only prints every 20 training steps
                  run_name=train_name # model name, so we can load different models locally
                 )

In [None]:
# prompt to generate response to, going to be a post/comment from the political discussion subreddits

# Built from adjacent string literals, which Python joins into one single-line
# string; each piece keeps its trailing space so the words stay separated.
pre = (
    "[title]What is Trump’s strategy for re-election? "
    "[selftext]What do you think he will focus on for re-election and what do you "
    "think he should do to give him the best chance of winning? I personally think "
    "he’s going to focus heavily on China but I’d love to hear what you guys/gals think."
)

In [None]:
# Print a few generated responses to the prompt.
# run_name / checkpoint_dir must match what the model was trained or loaded
# with: generate() reads the encoder and hparams from
# <checkpoint_dir>/<run_name> and otherwise falls back to "checkpoint/run1".
gpt2.generate(sess,
              run_name=train_name, # same run name used for finetune/load_gpt2
              checkpoint_dir=local_checkpoint_dir if is_local else "checkpoint",
              temperature=.7, # sampling temperature: higher values give more random output
              prefix=pre, # prompt
              nsamples=5, # number of generated responses
              length=400 # number of tokens (not words) to generate per response
             )

In [None]:
# Write generated responses to results_fp.
# The output directory is created first so the write cannot fail on a missing
# folder after the (slow) generation has been set up.
results_dir = os.path.dirname(results_fp)
if results_dir:
    os.makedirs(results_dir, exist_ok=True)

# run_name / checkpoint_dir must match what the model was trained or loaded
# with; otherwise the encoder and hparams are looked up in "checkpoint/run1".
gpt2.generate_to_file(sess,
                      run_name=train_name, # same run name used for finetune/load_gpt2
                      checkpoint_dir=local_checkpoint_dir if is_local else "checkpoint",
                      destination_path=results_fp,
                      temperature=.8, # sampling temperature: higher values give more random output
                      prefix=pre, # prompt
                      nsamples=5, # number of generated responses
                      length=400 # number of tokens (not words) to generate per response
                     )