In [1]:
import numpy as np
import pandas as pd
from simpletransformers.language_modeling import LanguageModelingModel
from simpletransformers.language_generation import LanguageGenerationModel
import warnings
import pickle
import re
warnings.filterwarnings('ignore')

In [2]:
def clean(post):
    post_split = post.split('|||')
    # split the kaggle data set posts by |||
    post_split_split = [x.split(' ') for x in post_split]
    
    # removes any 'words' that have http:// or https:// in them
    return_list = [[item for item in sentence if ('http://' not in item and 'https://' not in item)] for sentence in post_split_split]
    
    # returns a list of posts if they are not empty after removing the links
    return [' '.join(sentence) for sentence in return_list if sentence]

In [3]:
df = pd.read_csv('../data/mbti_1.csv')

In [4]:
df_dict = {'INTJ':0, 'INTP':0, 'ENTJ':1, 'ENTP':1, 
           'INFJ':0, 'INFP':0, 'ENFJ':1, 'ENFP':1, 
           'ISTJ':0, 'ISFJ':0, 'ESTJ':1, 'ESFJ':1, 
           'ISTP':0, 'ISFP':0, 'ESTP':1, 'ESFP':1}

In [5]:
for dataframe in df_dict.keys():
    df_dict[dataframe] = df[df['type'] == dataframe]
    print(df_dict[dataframe].shape)

(1091, 2)
(1304, 2)
(231, 2)
(685, 2)
(1470, 2)
(1832, 2)
(190, 2)
(675, 2)
(205, 2)
(166, 2)
(39, 2)
(42, 2)
(337, 2)
(271, 2)
(89, 2)
(48, 2)


In [6]:
for dataframe in df_dict.keys():
    df_dict[dataframe]['post_split'] = df_dict[dataframe].posts.apply(clean)
    print(df_dict[dataframe].shape)

(1091, 3)
(1304, 3)
(231, 3)
(685, 3)
(1470, 3)
(1832, 3)
(190, 3)
(675, 3)
(205, 3)
(166, 3)
(39, 3)
(42, 3)
(337, 3)
(271, 3)
(89, 3)
(48, 3)


In [7]:
df_dict['INTJ'].head()

Unnamed: 0,type,posts,post_split
3,INTJ,"'Dear INTP, I enjoyed our conversation the o...","['Dear INTP, I enjoyed our conversation the ..."
5,INTJ,'18/37 @.@|||Science is not perfect. No scien...,"['18/37 @.@, Science is not perfect. No scien..."
7,INTJ,'I tend to build up a collection of things on ...,['I tend to build up a collection of things on...
13,INTJ,"'Fair enough, if that's how you want to look a...","['Fair enough, if that's how you want to look ..."
36,INTJ,"'Poker face for sure, accompanied by some sarc...","['Poker face for sure, accompanied by some sar..."


In [8]:
for dataframe in df_dict.keys():
    df_dict[dataframe] = df_dict[dataframe].post_split.tolist()

In [19]:
for type_ in df_dict.keys():
    with open(f"../data/GPT-2_text_gen_posts/{type_}_posts_train.txt", "w") as f:
        for post in df_dict[type_][:-10]:
            for ind_post in post:
                if(re.search('[a-zA-Z]', ind_post)):
                    f.writelines(ind_post + "\n")

    with open(f"../data/GPT-2_text_gen_posts/{type_}_posts_test.txt", "w") as f:
        for post in df_dict[type_][-10:]:
            for ind_post in post:
                if(re.search('[a-zA-Z]', ind_post)):
                    f.writelines(ind_post + "\n")

In [9]:
train_args = {
    "reprocess_input_data": True,
    "overwrite_output_dir": True,
    "train_batch_size": 64,
    "num_train_epochs": 3,
    "mlm": False,
}

results = []

In [10]:
type_ = 'INTJ'
model = LanguageModelingModel('gpt2', 'gpt2', args=train_args, use_cuda=False)

model.train_model(train_file = f"../data/GPT-2_text_gen_posts/{type_}_posts_train.txt", 
                  eval_file = f"../data/GPT-2_text_gen_posts/{type_}_posts_test.txt", 
                  output_dir = f"gen_lang_models/{type_}_lang_model/")

results.append(model.eval_model(f"../data/GPT-2_text_gen_posts/{type_}_posts_test.txt"))

pickle.dump(model, open(f'models/{type_}_lang_model', 'wb'))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50330.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=13883.0), HTML(value='')))




HBox(children=(HTML(value='Epoch'), FloatProgress(value=0.0, max=3.0), HTML(value='')))

HBox(children=(HTML(value='Running Epoch 0 of 3'), FloatProgress(value=0.0, max=217.0), HTML(value='')))





KeyboardInterrupt: 

In [60]:
gen_model = LanguageGenerationModel("gpt2", "./outputs", args={"max_length": 64}, use_cuda=False)

In [72]:
pickle.dump(model, open('models/esfp_gen_model', 'wb'))

In [67]:
prompts = [
    'Hello, my name is',
    'I really do not',
    'My favorite thing to do',
    'I really like',
    "I hope you don't",
    'You are the very reason why'
]

In [75]:
gen_model_loaded = pickle.load(open('models/esfp_gen_model', 'rb'))

In [85]:
for prompt in prompts:
    # Generate text using the model. Verbose set to False to prevent logging generated sequences.
    generated = gen_model.generate(prompt, verbose=False, args={"max_length": 32})

    generated = '.'.join(generated[0].split('.')[:-1]) + '.'
    print(generated)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Hello, my name is Cenk.  As an American I am a member of an organization which provides assistance to refugees and immigrants.


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


.


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


My favorite thing to do is sit and write! I know that sometimes I want a chance to do something crazy and I have never seen anything like it.


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


I really like the idea, I'll be honest. I don't do it with the camera or something.


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


I hope you don't mind. I am an academic/physician who is studying medicine with great passion.
You are the very reason why I do this. I am sure you'll understand.
