In [199]:
import numpy as np
import pandas as pd
from simpletransformers.language_modeling import LanguageModelingModel
from simpletransformers.language_generation import LanguageGenerationModel
import warnings
import pickle
import re
warnings.filterwarnings('ignore')

In [200]:
def clean(post):
    post_split = post.split('|||')
    # split the kaggle data set posts by |||
    post_split_split = [x.split(' ') for x in post_split]
    
    # removes any 'words' that have http:// or https:// in them
    return_list = [[item for item in sentence if ('http://' not in item and 'https://' not in item)] for sentence in post_split_split]
    
    # returns a list of posts if they are not empty after removing the links
    return [' '.join(sentence) for sentence in return_list if sentence]

In [230]:
df = pd.read_csv('../data/mbti_1.csv')

In [231]:
df_dict = {'INTJ':0, 'INTP':0, 'ENTJ':1, 'ENTP':1, 
           'INFJ':0, 'INFP':0, 'ENFJ':1, 'ENFP':1, 
           'ISTJ':0, 'ISFJ':0, 'ESTJ':1, 'ESFJ':1, 
           'ISTP':0, 'ISFP':0, 'ESTP':1, 'ESFP':1}

In [232]:
for dataframe in df_dict.keys():
    if(df[df['type'] == dataframe].shape[0] < 200):
        df_dict[dataframe] = df[df['type'] == dataframe]
    else:
        df_dict[dataframe] = df[df['type'] == dataframe].sample(n=200)
    df_dict[dataframe]['post_split'] = df_dict[dataframe].posts.apply(clean)
    print(dataframe, df_dict[dataframe].shape)

INTJ (200, 2)
INTP (200, 2)
ENTJ (200, 2)
ENTP (200, 2)
INFJ (200, 2)
INFP (200, 2)
ENFJ (190, 2)
ENFP (200, 2)
ISTJ (200, 2)
ISFJ (166, 2)
ESTJ (39, 2)
ESFJ (42, 2)
ISTP (200, 2)
ISFP (200, 2)
ESTP (89, 2)
ESFP (48, 2)


In [234]:
df_dict['INTJ'].head()

Unnamed: 0,type,posts,post_split
167,INTJ,"'Hi, are you really manic?|||greetings.|||Yes,...","['Hi, are you really manic?, greetings., Yes, ..."
7500,INTJ,This is good...I was asking for a group analys...,[This is good...I was asking for a group analy...
790,INTJ,"'throughtheroses I cannot help you with this,...",['throughtheroses I cannot help you with this...
4514,INTJ,'I got 23|||YAYYYY!! XD Can I have a cookie in...,"['I got 23, YAYYYY!! XD Can I have a cookie in..."
6810,INTJ,'Psychology is a science. You probably feel un...,['Psychology is a science. You probably feel u...


In [235]:
for dataframe in df_dict.keys():
    df_dict[dataframe] = df_dict[dataframe].post_split.tolist()

In [270]:
def write_posts_to_file(posts, file):
    for post in posts:
        for ind_post in post:
            if(re.search('[a-zA-Z]', ind_post)):
                if(ind_post.endswith('...')or ind_post.endswith("...'")):
                    if(re.search('.*[\.?!]\s', ind_post)):
                        file.writelines(re.search('.*[\.?!]\s', ind_post).group(0).strip() + "\n")
                else:
                    file.writelines(ind_post)

In [237]:
for type_ in df_dict.keys():
    with open(f"../data/2.0_GPT-2_text_gen_posts/{type_}_posts_train.txt", "w") as f:
        write_posts_to_file(df_dict[type_][:-15], f)

    with open(f"../data/2.0_GPT-2_text_gen_posts/{type_}_posts_test.txt", "w") as f:
        write_posts_to_file(df_dict[type_][-15:], f)

In [238]:
train_args = {
    "reprocess_input_data": True,
    "overwrite_output_dir": True,
    "train_batch_size": 64,
    "num_train_epochs": 3,
    "mlm": False,
}

results = []

In [262]:
type_ = 'ESTP'
model = LanguageModelingModel('gpt2', 'gpt2', args=train_args, use_cuda=False)

model.train_model(train_file = f"../data/2.0_GPT-2_text_gen_posts/{type_}_posts_train.txt", 
                  eval_file = f"../data/2.0_GPT-2_text_gen_posts/{type_}_posts_test.txt", 
                  output_dir = f"2.0_gen_lang_models/{type_}_lang_model/")

# pickle.dump(model, open(f'models/{type_}_lang_model', 'wb'))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1572.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=720.0), HTML(value='')))




HBox(children=(HTML(value='Epoch'), FloatProgress(value=0.0, max=3.0), HTML(value='')))

HBox(children=(HTML(value='Running Epoch 0 of 3'), FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 1 of 3'), FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 2 of 3'), FloatProgress(value=0.0, max=12.0), HTML(value='')))





(36, 4.239758994844225)

In [263]:
results.append((type_, model.eval_model(f"../data/2.0_GPT-2_text_gen_posts/{type_}_posts_test.txt")))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=260.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=142.0), HTML(value='')))




HBox(children=(HTML(value='Running Evaluation'), FloatProgress(value=0.0, max=18.0), HTML(value='')))




In [264]:
results

[('ESFP', {'eval_loss': 4.146227061748505, 'perplexity': tensor(63.1951)}),
 ('ESTP', {'eval_loss': 4.114682740635342, 'perplexity': tensor(61.2328)})]

In [265]:
gen_model = LanguageGenerationModel("gpt2", f"2.0_gen_lang_models/{type_}_lang_model/", args={"max_length": 16}, use_cuda=False)

In [266]:
prompts = [
    "I was always bullied by my family for being small, \
I think they were always jealous of how much smarter I was compared to them. \
I never let it bother me, because I knew in the end I would be more successful than any of them."
]

In [None]:
for prompt in prompts:
    # Generate text using the model. Verbose set to False to prevent logging generated sequences.
    generated = ''
    while len(generated) <= len(prompt):
        generated = gen_model.generate(prompt, verbose=False, args={"max_length": 5})
        generated = '.'.join(generated[0].split('.')[:-1]) + '.'
    print(generated)