In [2]:
import numpy as np
import pandas as pd
from simpletransformers.language_modeling import LanguageModelingModel
from simpletransformers.language_generation import LanguageGenerationModel
import warnings
import pickle
import re
warnings.filterwarnings('ignore')

In [3]:
def clean(post):
    post_split = post.split('|||')
    # split the kaggle data set posts by |||
    post_split_split = [x.split(' ') for x in post_split]
    
    # removes any 'words' that have http:// or https:// in them
    return_list = [[item for item in sentence if ('http://' not in item and 'https://' not in item)] for sentence in post_split_split]
    
    # returns a list of posts if they are not empty after removing the links
    return [' '.join(sentence) for sentence in return_list if sentence]

In [4]:
df = pd.read_csv('../data/mbti_1.csv')

In [5]:
df_dict = {'INTJ':0, 'INTP':0, 'ENTJ':1, 'ENTP':1, 
           'INFJ':0, 'INFP':0, 'ENFJ':1, 'ENFP':1, 
           'ISTJ':0, 'ISFJ':0, 'ESTJ':1, 'ESFJ':1, 
           'ISTP':0, 'ISFP':0, 'ESTP':1, 'ESFP':1}

In [14]:
for dataframe in df_dict.keys():
    df_dict[dataframe] = df[df['type'] == dataframe]
    print(dataframe, df_dict[dataframe].shape)

INTJ (1091, 2)
INTP (1304, 2)
ENTJ (231, 2)
ENTP (685, 2)
INFJ (1470, 2)
INFP (1832, 2)
ENFJ (190, 2)
ENFP (675, 2)
ISTJ (205, 2)
ISFJ (166, 2)
ESTJ (39, 2)
ESFJ (42, 2)
ISTP (337, 2)
ISFP (271, 2)
ESTP (89, 2)
ESFP (48, 2)


In [20]:
for dataframe in df_dict.keys():
    df_dict[dataframe]['post_split'] = df_dict[dataframe].posts.apply(clean)
    print(dataframe, df_dict[dataframe].shape)

INTJ (1091, 3)
INTP (1304, 3)
ENTJ (231, 3)
ENTP (685, 3)
INFJ (1470, 3)
INFP (1832, 3)
ENFJ (190, 3)
ENFP (675, 3)
ISTJ (205, 3)
ISFJ (166, 3)
ESTJ (39, 3)
ESFJ (42, 3)
ISTP (337, 3)
ISFP (271, 3)
ESTP (89, 3)
ESFP (48, 3)


In [8]:
df_dict['INTJ'].head()

Unnamed: 0,type,posts,post_split
3,INTJ,"'Dear INTP, I enjoyed our conversation the o...","['Dear INTP, I enjoyed our conversation the ..."
5,INTJ,'18/37 @.@|||Science is not perfect. No scien...,"['18/37 @.@, Science is not perfect. No scien..."
7,INTJ,'I tend to build up a collection of things on ...,['I tend to build up a collection of things on...
13,INTJ,"'Fair enough, if that's how you want to look a...","['Fair enough, if that's how you want to look ..."
36,INTJ,"'Poker face for sure, accompanied by some sarc...","['Poker face for sure, accompanied by some sar..."


In [9]:
for dataframe in df_dict.keys():
    df_dict[dataframe] = df_dict[dataframe].post_split.tolist()

In [11]:
for type_ in df_dict.keys():
    with open(f"../data/GPT-2_text_gen_posts/{type_}_posts_train.txt", "w") as f:
        for post in df_dict[type_][:-10]:
            for ind_post in post:
                if(re.search('[a-zA-Z]', ind_post)):
                    f.writelines(ind_post + "\n")

    with open(f"../data/GPT-2_text_gen_posts/{type_}_posts_test.txt", "w") as f:
        for post in df_dict[type_][-10:]:
            for ind_post in post:
                if(re.search('[a-zA-Z]', ind_post)):
                    f.writelines(ind_post + "\n")

In [12]:
train_args = {
    "reprocess_input_data": True,
    "overwrite_output_dir": True,
    "train_batch_size": 64,
    "num_train_epochs": 3,
    "mlm": False,
}

results = []

In [13]:
type_ = 'ESTP'
model = LanguageModelingModel('gpt2', 'gpt2', args=train_args, use_cuda=False)

model.train_model(train_file = f"../data/GPT-2_text_gen_posts/{type_}_posts_train.txt", 
                  eval_file = f"../data/GPT-2_text_gen_posts/{type_}_posts_test.txt", 
                  output_dir = f"gen_lang_models/{type_}_lang_model/")

# pickle.dump(model, open(f'models/{type_}_lang_model', 'wb'))

In [22]:
results.append((type_, model.eval_model(f"../data/GPT-2_text_gen_posts/{type_}_posts_test.txt")))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=436.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=115.0), HTML(value='')))




HBox(children=(HTML(value='Running Evaluation'), FloatProgress(value=0.0, max=15.0), HTML(value='')))




In [16]:
gen_model = LanguageGenerationModel("gpt2", "gen_lang_models/ESFP_lang_model/", args={"max_length": 64}, use_cuda=False)

In [17]:
prompts = [
    'Hello, my name is',
    'I really do not',
    'My favorite thing to do',
    'I really like',
    "I hope you don't",
    'You are the very reason why'
]

In [18]:
for prompt in prompts:
    # Generate text using the model. Verbose set to False to prevent logging generated sequences.
    generated = gen_model.generate(prompt, verbose=False, args={"max_length": 32})

    generated = '.'.join(generated[0].split('.')[:-1]) + '.'
    print(generated)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Hello, my name is David Prowse. I'm a student in your high school biology department. I just found out that I'm the mother of my children.


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


I really do not like to be accused by friends or any group, I think it's like an unneeded exclamation point to have the guy you date take it too seriously.


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


My favorite thing to do when I first go to my apartment is to take a walk to talk to you about what to do for a job.


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


I really like the sound of the voice. It sounds really strong so I'm always happy to hear a person's voice.


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


.
.
