In [None]:
import os, sys, time
import multiprocessing
import pickle
import re, string
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter

## Preprocess

In [None]:
# Load the raw dataset and derive the integer class targets.
data = pd.read_csv("mbti_1.csv")
n_users = data.shape[0]
posts = data["posts"]
# distinct type strings, in order of first appearance in the file
labels = data["type"].unique()
# each type string is mapped to its position inside `labels`
type2num = dict(zip(labels, range(len(labels))))
Y = np.array([type2num[type_name] for type_name in data["type"].to_numpy()])

In [None]:
def plot_distribution(type_series=None):
    """Show a bar chart with the number of users per personality type.

    Parameters
    ----------
    type_series : pandas.Series, optional
        Series of type labels to count. When omitted, the ``"type"`` column of
        the module-level ``data`` frame is used.

    The bars follow the order returned by ``value_counts()``. The figure is
    displayed with ``plt.show()``; nothing is returned.
    """
    if type_series is None:
        type_series = data["type"]
    type_counts = type_series.value_counts()
    type_names = type_counts.index
    x = np.arange(len(type_names))

    fig, ax = plt.subplots(figsize=(10,4))
    ax.bar(x, type_counts.values)
    ax.set_ylabel("# of people")
    ax.set_xticks(x)
    # rotation has to be a number: the string '45' is rejected by recent Matplotlib
    ax.set_xticklabels(type_names,rotation=45)
    # keep the dashed horizontal grid behind the bars
    ax.set_axisbelow(True)
    ax.yaxis.grid(color='gray', linestyle='dashed')
    fig.tight_layout()
    plt.show()

In [None]:
# translation table that deletes every ASCII punctuation character (built once)
_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)


def _load_stopwords(csv_path="stopwords.csv"):
    """Read the stop-word CSV and return the words, apostrophes removed, as a set."""
    raw_words = pd.read_csv(csv_path).to_numpy().reshape(-1)
    # apostrophes are dropped because punctuation is stripped from the posts as well
    return {word.replace("'","") for word in raw_words}


def _clean_post(raw_post, stopwords):
    """Convert one user's raw post string into a list of lower-case word tokens."""
    # add empty space first (better used for regex parsing)
    text = raw_post.replace("|||"," ||| ")
    text = text.replace(",",", ")
    # remove url links: everything up to the next space/quote, or up to the
    # end of the text so that a trailing link is removed too
    text = re.sub(r"https?://.*?(?:[ '\"]|$)","",text)
    # avoid words in two sentences merged together after removing spaces
    text = text.replace(".",". ")
    # remove useless numbers and punctuations
    text = re.sub(r"[0-9]+", "", text)
    text = text.translate(_PUNCTUATION_TABLE)
    # remove redundant empty spaces
    text = re.sub(" +"," ",text).strip()
    # make all characters lower
    text = text.lower()
    # drop one-letter tokens and stop words
    return [word for word in text.split()
            if len(word) != 1 and word not in stopwords]


def generate_posts(path=""):
    """Return the tokenised posts of every user, using a pickle cache.

    If ``<path>/posts.pkl`` exists it is loaded and returned as is. Otherwise
    every entry of the module-level ``posts`` (indices ``0 .. n_users-1``) is
    cleaned, the result is written to ``<path>/posts.pkl`` and returned.
    ``stopwords.csv`` is always read relative to the working directory.

    Parameters
    ----------
    path : str
        Directory that holds (or will receive) ``posts.pkl``.

    Returns
    -------
    list of list of str
        One list of word tokens per user.
    """
    filename = os.path.join(path,"posts.pkl")
    if os.path.isfile(filename):
        # NOTE(security): unpickling can execute arbitrary code; only load a
        # cache file that was produced by this notebook.
        with open(filename,"rb") as cache_file:
            user_posts = pickle.load(cache_file)
        print("Loaded user posts")
        return user_posts

    # a set gives constant-time membership tests in the per-word filter
    stopwords = _load_stopwords()
    user_posts = []
    for uid in range(n_users):
        user_posts.append(_clean_post(posts[uid], stopwords))
        if uid * 100 % n_users == 0:
            print("Done {}/{} = {}%".format(uid,n_users,uid*100/n_users))
    print("Finished generating word list")
    with open(filename,"wb") as cache_file:
        pickle.dump(user_posts,cache_file)
    return user_posts

In [None]:
def generate_dict(user_posts,path="",min_count=6):
    """Build the word <-> integer-id vocabulary, using an ``.npz`` cache.

    If ``<path>/word_dict.npz`` exists, both mappings are loaded from it and
    ``user_posts`` / ``min_count`` are ignored. Otherwise the words of all
    posts are counted, words seen fewer than ``min_count`` times are dropped,
    ids are assigned by descending count, and the result is saved.

    The special token ``"<UNK>"`` is given a count above every real word, so
    it always receives id 0, and it is never pruned.

    Parameters
    ----------
    user_posts : iterable of list of str
        Tokenised posts, one list per user.
    path : str
        Directory that holds (or will receive) ``word_dict.npz``.
    min_count : int
        Minimum number of occurrences a word needs to stay in the vocabulary.

    Returns
    -------
    (dict, dict)
        ``word_to_int`` and ``int_to_word``.
    """
    filename = os.path.join(path,"word_dict.npz")
    if not os.path.isfile(filename):
        # make dictionary (used for bag of words, BOW); counting post by post
        # avoids materialising one huge concatenated word list
        word_counts = Counter()
        for post in user_posts:
            word_counts.update(post)
        # default=0 keeps an empty corpus from raising in max()
        word_counts["<UNK>"] = max(word_counts.values(), default=0) + 1
        # remove words that occur too rarely
        print("# of words before:",len(word_counts))
        for word in list(word_counts): # avoid changing size
            # <UNK> must survive even when every real word is rare
            if word != "<UNK>" and word_counts[word] < min_count:
                del word_counts[word]
        print("# of words after:",len(word_counts))
        # sort based on counts, but only keep the word strings
        sorted_vocab = sorted(word_counts, key=word_counts.get, reverse=True)

        # assign ids by occurrence frequency (most frequent word first)
        int_to_word = dict(enumerate(sorted_vocab))
        word_to_int = {w: k for k, w in int_to_word.items()}
        np.savez(filename,int2word=int_to_word,word2int=word_to_int)
    else:
        # the dicts were saved as pickled object arrays, hence allow_pickle;
        # the context manager closes the archive once both are read
        with np.load(filename,allow_pickle=True) as infile:
            int_to_word = infile["int2word"].item()
            word_to_int = infile["word2int"].item()
        print("Loaded {}".format(filename))
    n_words = len(int_to_word)
    print('Vocabulary size:', n_words)
    return word_to_int, int_to_word

In [None]:
def generate_bow(user_posts,word_to_int,dtype=float):
    """Build the bag-of-words count matrix for all users.

    Parameters
    ----------
    user_posts : sequence of list of str
        Tokenised posts, one list per user.
    word_to_int : dict
        Maps a word to its column index. Words missing from it are counted in
        column 0 (the ``"<UNK>"`` slot when the mapping comes from
        ``generate_dict``).
    dtype : numpy dtype, optional
        Element type of the returned matrix (default: float, as before).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(user_posts), len(word_to_int))`` where entry
        ``[u, j]`` is how often word ``j`` occurs in the posts of user ``u``.
    """
    n_users = len(user_posts)
    n_words = len(word_to_int)
    feature = np.zeros((n_users,n_words),dtype=dtype)
    print(feature.shape)
    for uid, post in enumerate(user_posts):
        row = feature[uid]
        for word, occurrences in Counter(post).items():
            # accumulate instead of assign: all unknown words share column 0,
            # so a plain assignment would keep only the last one's count
            row[word_to_int.get(word,0)] += occurrences
        if uid * 100 % n_users == 0:
            print("Done {}/{} = {}%".format(uid,n_users,uid*100/n_users))
    print("Finished generating BoW model")
    return feature

In [None]:
# Run the preprocessing pipeline: tokenised posts -> vocabulary -> bag-of-words matrix.
# With the default path="" the first two steps look for their cache files
# (posts.pkl, word_dict.npz) in the current working directory.
user_posts = generate_posts()
word2int, int2word = generate_dict(user_posts)
# X has one row per user and one column per vocabulary word
X = generate_bow(user_posts,word2int)

## Train model

In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

In [None]:
# Hold out 33% of the users as a test set; the fixed random_state makes the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.33, random_state=42)

In [None]:
# Train a small random forest and report its accuracy on the held-out users.
# n_jobs=-1 asks scikit-learn to use every available processor; random_state
# fixes the forest so the reported number is the same on every run.
clf = RandomForestClassifier(n_estimators=5,verbose=2,n_jobs=-1,random_state=42)
clf.fit(X_train, y_train)
# score() returns the mean accuracy as a fraction in [0, 1]
predict = clf.score(X_test, y_test)
# convert the fraction to a percentage so the value matches the "%" sign
print("Random forest acc: {:.2f}%".format(predict * 100))