# Table of Contents
- Setup and Imports
- Model Features
- Utility Functions
- Data Loading
- Stratified Train and Test Split
- Determining the Distinguishing Words for each Personality Type
- Processing the Raw Data
- Fitting the Classifier
- Test Accuracy and Error Analysis
- Complete Workflow Execution

Data downloaded from: https://www.kaggle.com/datasnaek/mbti-type

### Setup and Imports

In [None]:
!pip install -U textblob
!python -m textblob.download_corpora

In [2]:
import pickle
from collections import Counter

import functools
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from time import time
from textblob import TextBlob
from tqdm import tqdm

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import GridSearchCV, StratifiedShuffleSplit, cross_val_predict, cross_val_score
from sklearn.preprocessing import OrdinalEncoder, StandardScaler
from sklearn.svm import SVC

### Model Features

In [4]:
# Feature-name constants: keys of the per-post feature dict built by
# PostProcessor (several also name the pickled columns listed in SAVED_COLS).
# UF = Unimportant Feature: entries marked "# UF" were commented out as
# unimportant; the columns they used to generate are listed in DROP_COLS so
# they are also removed from previously saved data before scaling.

# TextBlob document-level sentiment scores.
POLARITY = 'polarity'
SUBJECTIVITY = 'subjectivity'
# AVG_SENTENCE_WORDS = 'avg_st_wds'  # UF
# STD_SENTENCE_WORDS = 'std_st_wds'  # UF
# AVG_SENTENCE_LENGTH = 'avg_st_len'  # UF
# STD_SENTENCE_LENGTH = 'std_st_len'  # UF
# Mean / standard deviation of the length of descriptor words
# (POS tag in DESCRIPTOR_TYPES, purely alphabetic).
AVG_DESCRIPTOR_LENGTH = 'avg_wd_len'
STD_DESCRIPTOR_LENGTH = 'std_wd_len'

# Each word here becomes a 'freq_<word>' feature holding its count in the post.
FEATURE_WORDS = [
    'always',
    # 'any',  # UF
    'does',
    'enfj',
    'enfp',
    'entj',
    'entp',
    # 'esfj',  # UF
    # 'esfjs',  # UF
    # 'esfp',  # UF
    # 'estj',  # UF
    # 'estp',  # UF
    # 'even',  # UF
    'feel',
    # 'go',  # UF
    'her',
    'him',
    # 'http',  # UF
    # 'https',  # UF
    'infj',
    'infp',
    'intj',
    'intp',
    'isfj',
    'isfp',
    'istj',
    'istp',
    # 'life',  # UF
    # 'll',  # UF
    'lot',
    'love',
    # 'make',  # UF
    # 'myself',  # UF
    # 'need',  # UF
    # 'never',  # UF
    # 'someone',  # UF
    # 'sure',  # UF
    # 'than',  # UF
    'their',
    # 'though',  # UF
    'type',
    # 'u',  # UF
    'we',
    'why',
]
# Punctuation token -> label; each becomes a 'freq_<label>' count feature.
PUNC_TOKENS = {
    ',': 'punc_comma',
    '.': 'punc_period',
    # ';': 'punc_semicolon',  # UF
    '(': 'punc_lparen',
    ')': 'punc_rparen',
    '?': 'punc_question',
    '!': 'punc_exclam',
    ':': 'punc_colon',
    # '^': 'punc_hat',  # UF
    # '*': 'punc_star',  # UF
    # '%': 'punc_percent',  # UF
    # '[': 'punc_lbrac',  # UF
    # ']': 'punc_rbrac',  # UF
    # '/': 'punc_fslash'  # UF
}
# Counts of sentiment-assessed words with |polarity| >= 0.3 / |subjectivity| >= 0.3.
EXT_POLARITIES = 'ext_pol'
EXT_SUBJECTIVITIES = 'ext_subj'
# CLASS_IRONY = 'class_irony'  # UF
# Count of sentiment assessments labelled 'mood'.
CLASS_MOOD = 'class_mood'
# CLASS_PROFANITY = 'class_prof'  # UF

In [5]:
# Folder holding the raw CSV and every pickle this notebook reads or writes.
BASE_FILE_PATH = "drive/MyDrive/google_colab_data/mbti_class"  # Replace with path to your folder
# Alias for code that refers to the data folder as BASE_PATH, which was
# otherwise never defined (a NameError at call time).
BASE_PATH = BASE_FILE_PATH
RAW_DATA_PATH = f"{BASE_FILE_PATH}/mbti_1.csv"
# The 16 MBTI personality-type labels.
MBTI_TYPES = ['ENFJ', 'ENFP', 'ENTJ', 'ENTP', 'ESFJ', 'ESFP', 'ESTJ', 'ESTP', 'INFJ', 'INFP', 'INTJ', 'INTP', 'ISFJ', 'ISFP', 'ISTJ', 'ISTP']
# Feature columns stored as individual pickles 'full_posts_nlp_<col>.pkl'.
SAVED_COLS = [POLARITY, SUBJECTIVITY, AVG_DESCRIPTOR_LENGTH, STD_DESCRIPTOR_LENGTH, 'word_count', 'punc_count', 'word_assessments']
# Columns of features judged unimportant; dropped (if present) before scaling.
DROP_COLS = [
    'class_irony', 'class_prof', 'freq_any', 'freq_esfj', 'freq_esfjs', 'freq_esfp', 'freq_estj', 'freq_estp',
    'freq_even', 'freq_go', 'freq_http', 'freq_https', 'freq_life', 'freq_ll', 'freq_make', 'freq_myself',
    'freq_need', 'freq_never', 'freq_punc_fslash', 'freq_punc_hat', 'freq_punc_lbrac', 'freq_punc_percent',
    'freq_punc_rbrac', 'freq_punc_semicolon', 'freq_punc_star', 'freq_someone', 'freq_sure', 'freq_than',
    'freq_though', 'freq_u'
]
# Penn Treebank POS tags (adjectives, nouns, predeterminer, verbs) whose words
# count as "descriptors" for the word-length features.
DESCRIPTOR_TYPES = {'JJ', 'JJR', 'JJS', 'NN', 'NNS', 'NNP', 'NNPS', 'PDT', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ'}
# Name of the column holding the MBTI type label.
LABEL_COL = 'type'

RANDOM_STATE = 42
# Separator between individual posts inside the raw 'posts' text
# (presumably one user's posts joined together — confirm against the dataset).
RAW_POST_DELIMITER = '|||'

### Utility Functions

In [6]:
def time_function(func):
    """Decorate ``func`` so that every call prints how long it took.

    The printed line looks like ``[<seconds> sec] <function name>``, with the
    elapsed wall-clock time rounded to 5 decimal places. The wrapped
    function's return value is passed through unchanged, and
    ``functools.wraps`` keeps its name and docstring.
    """
    @functools.wraps(func)
    def timed(*args, **kwargs):
        began_at = time()
        result = func(*args, **kwargs)
        elapsed = round(time() - began_at, 5)
        print(f'[{elapsed} sec]', func.__name__)
        return result

    return timed

### Data Loading

In [7]:
@time_function
def load_raw_mbti_data(data_path: str):
    """Load the raw MBTI CSV and turn the post delimiter into newlines.

    In the raw ``posts`` column, individual posts are separated by
    ``RAW_POST_DELIMITER``. Each occurrence is replaced by ``'\\n'``.

    Args:
        data_path: Path to the raw CSV. It must contain a ``posts`` column and
            the ``LABEL_COL`` label column.

    Returns:
        The loaded DataFrame with its ``posts`` column rewritten.
    """
    mbti_data = pd.read_csv(data_path)
    # Vectorised replace on the text column. The previous loop assigned into the
    # row copies yielded by iterrows() (and targeted the label column), so the
    # DataFrame was never modified. regex=False makes '|||' a literal match.
    mbti_data['posts'] = mbti_data['posts'].str.replace(RAW_POST_DELIMITER, '\n', regex=False)
    return mbti_data

In [8]:
@time_function
def load_preprocessed_mbti_data(data_path, row_count=None):
    """Assemble previously computed features and labels into one DataFrame.

    An empty frame indexed ``0..row_count-1`` is created first. The pickles
    ``{BASE_FILE_PATH}/full_posts_nlp_<col>.pkl`` (one per entry of
    SAVED_COLS) and ``{BASE_FILE_PATH}/full_posts_types.pkl`` are then joined
    onto it by index. Each pickle presumably holds a named Series/DataFrame
    indexed like the CSV rows; confirm against the code that wrote them.

    Args:
        data_path: Raw CSV, used only to count records when ``row_count`` is None.
        row_count: Number of rows. When None, it is taken from the CSV.

    Returns:
        DataFrame with the saved feature columns plus the label column.
    """
    if row_count is None:
        # Count parsed CSV records, not physical lines. A quoted field that
        # contains a newline would otherwise inflate the count.
        row_count = len(pd.read_csv(data_path, usecols=[LABEL_COL]))
    mbti_data = pd.DataFrame(index=range(row_count))

    # NOTE: pickle.load executes arbitrary code. Only load pickles this
    # project produced itself.
    for col in SAVED_COLS:
        with open(f'{BASE_FILE_PATH}/full_posts_nlp_{col}.pkl', 'rb') as f:
            mbti_data = mbti_data.join(pickle.load(f))
    with open(f'{BASE_FILE_PATH}/full_posts_types.pkl', 'rb') as f:
        mbti_data = mbti_data.join(pickle.load(f))

    return mbti_data

### Determining the Distinguishing Words for each Personality Type

In [9]:
# Suffix of the per-type pickle files '<TYPE>_word_freq.pkl' written by
# compute_mbti_word_counts and read by compute_most_distinguishing_words.
WORD_COUNT_FILE_NAME_SUFFIX = 'word_freq'

@time_function
def compute_mbti_word_counts(mbti_data, max_most_common=100):
    """Count words in each MBTI type's posts and pickle the most frequent ones.

    For every type in MBTI_TYPES, TextBlob word counts are summed over the
    ``posts`` of all rows whose ``LABEL_COL`` equals that type. The
    ``max_most_common`` most frequent ``(word, count)`` pairs are then saved
    to ``{BASE_FILE_PATH}/{type}_{WORD_COUNT_FILE_NAME_SUFFIX}.pkl``.

    Args:
        mbti_data: DataFrame with ``LABEL_COL`` and ``posts`` columns.
        max_most_common: How many top words to keep per type.
    """
    for mbti_type in MBTI_TYPES:
        type_counts = Counter()
        # Analyse the post text; the original fed the label column to TextBlob.
        type_posts = mbti_data.loc[mbti_data[LABEL_COL] == mbti_type, 'posts']

        for post in tqdm(type_posts, mininterval=5):
            # update() adds in place. `c += Counter(...)` built a temporary
            # Counter and re-scanned the whole accumulator on every row.
            type_counts.update(TextBlob(post).word_counts)

        print(f'Saving {mbti_type} top {max_most_common} word counts...')
        # BASE_FILE_PATH is the defined data folder (BASE_PATH was undefined).
        with open(f'{BASE_FILE_PATH}/{mbti_type}_{WORD_COUNT_FILE_NAME_SUFFIX}.pkl', 'wb') as f:
            pickle.dump(type_counts.most_common(max_most_common), f)

In [10]:
@time_function
def compute_most_distinguishing_words(label_value_counts, max_most_common=10, score_threshold=8):
    """Select words whose usage differs most between MBTI types.

    The top-word pickles written by ``compute_mbti_word_counts`` are loaded.
    Each type's counts are normalised by that type's number of rows
    (``label_value_counts[type]``). For each type, every word gets a score:
    the sum, over all other types, of ``|other_freq - own_freq| / own_freq``
    (``other_freq`` is 0 when the word is absent there).

    Args:
        label_value_counts: Mapping type -> number of rows of that type
            (e.g. ``data[LABEL_COL].value_counts()``).
        max_most_common: Number of top-scoring words examined per type.
        score_threshold: Minimum score for a word to be kept.

    Returns:
        Set of purely alphabetic words that reached ``score_threshold`` among
        some type's ``max_most_common`` highest scores.
    """
    mbti_word_counts_catalogue = {}
    for mbti_type in MBTI_TYPES:
        # BASE_FILE_PATH is the defined data folder (BASE_PATH was undefined).
        with open(f'{BASE_FILE_PATH}/{mbti_type}_{WORD_COUNT_FILE_NAME_SUFFIX}.pkl', 'rb') as f:
            type_word_counts = pickle.load(f)
        type_size = label_value_counts[mbti_type]
        # Dict comprehension also copes with an empty word list, where the old
        # zip(*...) unpacking raised ValueError.
        mbti_word_counts_catalogue[mbti_type] = {
            word: count / type_size for word, count in type_word_counts
        }

    feature_words = set()
    for mbti_type in MBTI_TYPES:
        own_freqs = mbti_word_counts_catalogue[mbti_type]
        scores = Counter()

        for other_mbti_type in set(MBTI_TYPES) - {mbti_type}:
            other_freqs = mbti_word_counts_catalogue[other_mbti_type]
            for word, freq in own_freqs.items():
                # Relative frequency difference versus the other type.
                scores[word] += abs((other_freqs.get(word, 0) - freq) / freq)

        feature_words.update(
            word for word, score in scores.most_common(max_most_common)
            if score >= score_threshold and word.isalpha()
        )

    return feature_words

### Stratified Train and Test Split

In [11]:
@time_function
def stratify_split_mbti_data(mbti_data, test_size=0.25, random_state=RANDOM_STATE):
    """Split ``mbti_data`` into train and test parts with the same label mix.

    A single StratifiedShuffleSplit on ``LABEL_COL`` keeps the proportion of
    each MBTI type equal in both parts.

    Returns:
        ``(train, test)`` DataFrames, each re-indexed from 0.
    """
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
    # n_splits=1, so the generator yields exactly one (train, test) pair.
    train_positions, test_positions = next(splitter.split(mbti_data, mbti_data[LABEL_COL]))

    train_part = mbti_data.iloc[train_positions].reset_index(drop=True)
    test_part = mbti_data.iloc[test_positions].reset_index(drop=True)
    return train_part, test_part

### Processing the Raw Data

In [12]:
class PostProcessor(BaseEstimator, TransformerMixin):
    """scikit-learn transformer that turns each row's ``posts`` text into numeric features.

    ``transform`` runs TextBlob over every post. It returns a DataFrame with
    one row per input row, containing:

    - polarity and subjectivity;
    - descriptor word-length statistics;
    - sentiment-assessment counts;
    - punctuation counts (``freq_punc_*``);
    - feature-word counts (``freq_<word>``).

    The transformer is stateless: ``fit`` learns nothing.
    """

    def __init__(self):
        # No hyper-parameters. The class previously inherited from BaseException
        # by mistake; BaseEstimator supplies get_params/set_params for clone().
        pass

    def fit(self, X, y=None):
        """No-op fit required by the scikit-learn transformer API; returns ``self``."""
        return self

    def get_sentiment_assessment_stats(self, assessments):
        """Count strongly polar, strongly subjective and 'mood'-labelled assessments.

        Args:
            assessments: TextBlob ``sentiment_assessments.assessments``. Each item
                is unpacked as ``(words, polarity, subjectivity, label)``.

        Returns:
            Dict with three entries:

            - EXT_POLARITIES: assessments with ``|polarity| >= 0.3``;
            - EXT_SUBJECTIVITIES: assessments with ``|subjectivity| >= 0.3``;
            - CLASS_MOOD: assessments labelled ``'mood'``.

            All three counts are 0 when there are no assessments.
        """
        ext_polarities = 0
        ext_subjectivities = 0
        moods = 0
        # Iterating directly handles the empty case. The former blanket
        # `except Exception` also silently zeroed malformed input.
        for _, polarity, subjectivity, word_type in assessments:
            if abs(polarity) >= 0.3:
                ext_polarities += 1
            if abs(subjectivity) >= 0.3:
                ext_subjectivities += 1
            if word_type == 'mood':
                moods += 1
        return {
            EXT_POLARITIES: ext_polarities,
            EXT_SUBJECTIVITIES: ext_subjectivities,
            CLASS_MOOD: moods,
        }

    def get_punctuation_token_stats(self, tokens):
        """Map ``freq_<label>`` to the number of occurrences of each PUNC_TOKENS token."""
        return {f'freq_{label}': tokens.count(token) for token, label in PUNC_TOKENS.items()}

    def get_word_count_stats(self, word_counts):
        """Map ``freq_<word>`` to the count of each FEATURE_WORDS word (0 when absent)."""
        return {f'freq_{feature_word}': word_counts.get(feature_word, 0) for feature_word in FEATURE_WORDS}

    def get_descriptor_stats(self, tags):
        """Mean and standard deviation of the lengths of descriptor words.

        A descriptor is a purely alphabetic word whose POS tag is in
        DESCRIPTOR_TYPES. The former ``'//'`` / ``'v='`` prefix checks were
        redundant, since such tokens can never pass ``str.isalpha``.

        Returns:
            Dict with AVG_DESCRIPTOR_LENGTH and STD_DESCRIPTOR_LENGTH. Both are
            0.0 when the post has no descriptors; averaging an empty list
            would give NaN, which the downstream classifiers reject.
        """
        lengths = [len(word) for word, pos in tags if pos in DESCRIPTOR_TYPES and word.isalpha()]
        if not lengths:
            return {AVG_DESCRIPTOR_LENGTH: 0.0, STD_DESCRIPTOR_LENGTH: 0.0}
        return {
            AVG_DESCRIPTOR_LENGTH: np.mean(lengths),
            STD_DESCRIPTOR_LENGTH: np.std(lengths),
        }

    def get_text_features(self, post):
        """Return the flat feature dict for a single ``post`` string."""
        post_analysis = TextBlob(post)
        sentiment = post_analysis.sentiment

        new_entry = {
            POLARITY: sentiment.polarity,
            SUBJECTIVITY: sentiment.subjectivity,
        }
        new_entry.update(self.get_descriptor_stats(post_analysis.tags))
        new_entry.update(self.get_sentiment_assessment_stats(post_analysis.sentiment_assessments.assessments))
        new_entry.update(self.get_punctuation_token_stats(post_analysis.tokens))
        new_entry.update(self.get_word_count_stats(post_analysis.word_counts))
        return new_entry

    @time_function
    def transform(self, X):
        """Featurise every entry of ``X['posts']``.

        Returns:
            DataFrame with one row per post and a fresh RangeIndex.
        """
        features = [self.get_text_features(post) for post in tqdm(X['posts'], mininterval=10)]
        return pd.DataFrame(features)

In [13]:
@time_function
def prepare_data(data, preprocessed=False, standard_scaler=None, ordinal_encoder=None, fit=True):
    """Build the scaled feature matrix and ordinal-encoded labels for ``data``.

    Args:
        data: DataFrame with a ``LABEL_COL`` column. When ``preprocessed`` is
            True it already holds the numeric feature columns. Otherwise its
            ``posts`` are featurised with PostProcessor.
        preprocessed: Whether ``data`` already contains computed features.
        standard_scaler: Optional StandardScaler; a new one is created when
            omitted. To scale a test split with training statistics, pass the
            instance used for the training call (so it is fitted in place),
            then pass it again with ``fit=False``.
        ordinal_encoder: Optional OrdinalEncoder, handled like ``standard_scaler``.
        fit: If True (default, original behaviour), the scaler and encoder are
            fitted on ``data``. If False they are only applied, and must
            already be fitted.

    Returns:
        Tuple ``(features, labels, ordinal_encoder)``:

        - ``features``: scaled feature ndarray, with LABEL_COL and DROP_COLS
          columns removed;
        - ``labels``: 1-D ndarray of label codes;
        - ``ordinal_encoder``: the OrdinalEncoder that was used.

    Raises:
        ValueError: ``fit=False`` without both a scaler and an encoder.
    """
    if not fit and (standard_scaler is None or ordinal_encoder is None):
        raise ValueError('fit=False requires an already fitted standard_scaler and ordinal_encoder')
    if standard_scaler is None:
        standard_scaler = StandardScaler()
    if ordinal_encoder is None:
        ordinal_encoder = OrdinalEncoder()

    features = data if preprocessed else PostProcessor().fit_transform(data)
    features = features.drop([LABEL_COL] + DROP_COLS, axis=1, errors='ignore')
    labels = data[[LABEL_COL]].values

    if fit:
        data_num_prepared = standard_scaler.fit_transform(features)
        data_cat_prepared = ordinal_encoder.fit_transform(labels).reshape(-1)
    else:
        # Reuse statistics/categories learned elsewhere (e.g. on the training split).
        data_num_prepared = standard_scaler.transform(features)
        data_cat_prepared = ordinal_encoder.transform(labels).reshape(-1)

    return data_num_prepared, data_cat_prepared, ordinal_encoder

### Fitting the Classifier

In [14]:
@time_function
def fit_classifier(train_data_features, train_data_labels, random_state=RANDOM_STATE):
    """Fit a soft-voting ensemble of SVC, random forest and logistic regression.

    Predicted class probabilities are averaged with weights 0.4 / 0.4 / 0.2.

    Args:
        train_data_features: Training feature matrix.
        train_data_labels: Training label codes.
        random_state: Seed for the SVC probability calibration and the forest,
            so runs are reproducible. Pass None for the previous unseeded behaviour.

    Returns:
        The fitted VotingClassifier.
    """
    clf = VotingClassifier(
        estimators=[
            ('svc', SVC(kernel='rbf', gamma='scale', C=0.25, probability=True, random_state=random_state)),
            ('forest', RandomForestClassifier(
                criterion='gini', n_estimators=100, min_samples_leaf=1, random_state=random_state,
            )),
            # multi_class="multinomial" is deprecated since scikit-learn 1.5. With
            # lbfgs and 3+ classes (16 MBTI types here) multinomial is the default.
            ('logistic', LogisticRegression(solver="lbfgs", C=0.05, max_iter=5000)),
        ],
        voting='soft',
        weights=[0.4, 0.4, 0.2],
    )
    clf.fit(train_data_features, train_data_labels)
    return clf

In [15]:
def grid_search_clf(clf_type, train_data_features, train_data_labels, param_grid, cross_valid=2):
    """Run a macro-precision grid search for one of the ensemble's base classifiers.

    Args:
        clf_type: ``'svc'``, ``'forest'`` or ``'logistic'``.
        train_data_features: Training feature matrix.
        train_data_labels: Training label codes.
        param_grid: GridSearchCV parameter grid.
        cross_valid: Number of cross-validation folds.

    Returns:
        The fitted GridSearchCV (``best_params_``, ``cv_results_``, ...).

    Raises:
        ValueError: If ``clf_type`` is unknown. This previously surfaced as an
            UnboundLocalError.
    """
    estimator_factories = {
        'svc': lambda: SVC(kernel='rbf', gamma='scale', probability=True),
        'forest': lambda: RandomForestClassifier(),
        # multi_class="multinomial" is deprecated since scikit-learn 1.5. With
        # lbfgs and 3+ classes, multinomial is already the default.
        'logistic': lambda: LogisticRegression(solver="lbfgs", max_iter=5000),
    }
    if clf_type not in estimator_factories:
        raise ValueError(f"Unknown clf_type {clf_type!r}; expected one of {sorted(estimator_factories)}")

    grid_search = GridSearchCV(
        estimator_factories[clf_type](),
        param_grid,
        cv=cross_valid,
        scoring='precision_macro',
        return_train_score=True,
        verbose=2,
    )
    grid_search.fit(train_data_features, train_data_labels)
    return grid_search

@time_function
def get_best_clf_params(train_data_features, train_data_labels, svc=False, forest=False, logistic=False):
    """Grid-search the selected base classifiers and report their best parameters.

    Each enabled search's ``best_params_`` is printed, in the order svc,
    forest, logistic, and also returned.

    Args:
        train_data_features: Training feature matrix.
        train_data_labels: Training label codes.
        svc: Search the SVC regularisation strength ``C``.
        forest: Search the random forest's size, split criterion and leaf size.
        logistic: Search the logistic-regression ``C``.

    Returns:
        Dict mapping each searched classifier name ('svc' / 'forest' /
        'logistic') to its best parameters. Previously the results were only
        printed and then discarded.
    """
    searches = [
        ('svc', svc, [{'C': [0.001, 0.005, 0.01, 0.05, 0.1, 0.25, 0.5, 1]}]),
        ('forest', forest, [{
            'n_estimators': [2, 5, 10, 25, 50, 100, 200],
            'criterion': ['entropy', 'gini'],
            'min_samples_leaf': [1, 2],
        }]),
        ('logistic', logistic, [{'C': [0.001, 0.005, 0.01, 0.05, 0.1, 0.25, 0.5, 1]}]),
    ]

    best_params = {}
    for clf_type, enabled, param_grid in searches:
        if not enabled:
            continue
        grid_search = grid_search_clf(clf_type, train_data_features, train_data_labels, param_grid=param_grid)
        print(grid_search.best_params_)
        best_params[clf_type] = grid_search.best_params_
    return best_params

### Test Accuracy and Error Analysis

In [16]:
@time_function
def get_test_accuracy(clf, test_data_features, test_data_labels):
    """Return the fraction of test samples that ``clf`` labels correctly."""
    return accuracy_score(test_data_labels, clf.predict(test_data_features))

In [17]:
@time_function
def get_cross_valid_score(clf, train_data_features, train_data_labels, cross_val=3):
    """Return the per-fold accuracy of ``clf`` under ``cross_val``-fold cross-validation."""
    fold_scores = cross_val_score(
        clf,
        train_data_features,
        train_data_labels,
        scoring="accuracy",
        cv=cross_val,
    )
    return fold_scores

In [18]:
@time_function
def show_conf_mat(clf, train_data_features, train_data_labels, cross_val=3, normalized=False):
    """Plot a grayscale confusion matrix built from out-of-fold predictions.

    With ``normalized=True``, each row is divided by its total (the number of
    samples with that true label). The diagonal is then zeroed, so only the
    error rates are shown.
    """
    out_of_fold = cross_val_predict(clf, train_data_features, train_data_labels, cv=cross_val)
    matrix = confusion_matrix(train_data_labels, out_of_fold)

    if normalized:
        matrix = matrix / matrix.sum(axis=1, keepdims=True)
        np.fill_diagonal(matrix, 0)

    plt.matshow(matrix, cmap=plt.cm.gray)

In [19]:
@time_function
def get_ordered_feature_importances(clf, train_data, top_imp=50):
    """Print the ``top_imp`` most important features, one pair per line.

    Each line is a ``(column name, importance)`` tuple, highest importance
    first. ``train_data.columns`` must line up with
    ``clf.feature_importances_``; ties keep their column order.
    """
    ranked = sorted(
        zip(train_data.columns, clf.feature_importances_),
        key=lambda name_and_score: name_and_score[1],
        reverse=True,
    )
    print(*ranked[:top_imp], sep='\n')

### Complete Workflow Execution

In [None]:
def execute_workflow(load_saved=False):
    """Run the whole pipeline and print the test accuracy.

    Steps: load the data, make a stratified train/test split, build the
    features, fit the voting ensemble on the training split, then score it
    on the test split.

    Args:
        load_saved: If True, load the pickled preprocessed features (with a
            fixed row_count of 8675). Otherwise, load and featurise the raw CSV.
    """
    mbti_data = (
        load_preprocessed_mbti_data(RAW_DATA_PATH, row_count=8675)
        if load_saved
        else load_raw_mbti_data(RAW_DATA_PATH)
    )
    train_split, test_split = stratify_split_mbti_data(mbti_data)

    train_features, train_labels, _ = prepare_data(train_split, preprocessed=load_saved)
    clf = fit_classifier(train_features, train_labels)

    # NOTE(review): this prepare_data call fits a fresh StandardScaler and
    # OrdinalEncoder on the test split. Test features are therefore scaled
    # with test-set statistics rather than the training ones.
    test_features, test_labels, _ = prepare_data(test_split, preprocessed=load_saved)
    accuracy = get_test_accuracy(clf, test_features, test_labels)
    print(f"Accuracy: {round(accuracy, 5)}")


execute_workflow(load_saved=True)