In [1]:
import pandas as pd 
import numpy as np
import pickle

# Directory prefix for input files. It is spliced directly in front of the
# file name when paths are built, so the trailing slash is required.
DATA_PATH = "data/"

In [2]:
# Read the pickled object stored at <DATA_PATH>train_0.pkl into `data`.
# NOTE(review): unpickling runs whatever code the file embeds, so this should
# only ever be pointed at files from a trusted source.
train_pickle_path = "{}train_0.pkl".format(DATA_PATH)
with open(train_pickle_path, mode='rb') as f:
    data = pickle.load(f)

In [3]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.preprocessing import LabelEncoder
# Hold out 20% of the rows for validation. The split is shuffled with a fixed
# seed and stratified on the 'mbti' column, so both parts keep the same mix of
# MBTI types.
train, valid = train_test_split(
    data,
    test_size=0.2,
    shuffle=True,
    stratify=data['mbti'],
    random_state=42,
)

# Separate the target column ('mbti') from the remaining feature columns.
X_train, y_train = train.drop(columns='mbti'), train['mbti']
X_valid, y_valid = valid.drop(columns='mbti'), valid['mbti']

In [4]:
from catboost import CatBoostClassifier

# Each MBTI type is a 4-letter string; every letter position is modelled as an
# independent binary target.
# Maps output column name -> (character index in the type string, letter encoded as 1).
MBTI_AXES = {
    'I-E': (0, 'I'),
    'N-S': (1, 'N'),
    'T-F': (2, 'T'),
    'J-P': (3, 'J'),
}


def encode_mbti_axes(labels):
    """Encode a Series of MBTI type strings as one 0/1 column per axis.

    Parameters
    ----------
    labels : pandas.Series
        MBTI type strings such as ``'INTJ'``. Every entry must be a string
        with at least four characters.

    Returns
    -------
    pandas.DataFrame
        Indexed like ``labels`` with the columns ``'I-E'``, ``'N-S'``,
        ``'T-F'`` and ``'J-P'``. A cell is 1 when the letter at that axis'
        position equals the first letter of the column name, otherwise 0.

    Raises
    ------
    ValueError
        If any entry is missing or shorter than four characters.
    """
    # Missing / non-string entries produce NaN lengths, and NaN >= 4 is False,
    # so they are rejected here instead of being silently encoded as 0.
    if not labels.str.len().ge(4).all():
        raise ValueError(
            "every MBTI label must be a string of at least 4 characters"
        )
    return pd.DataFrame({
        axis: labels.str[position].eq(positive_letter).astype('int64')
        for axis, (position, positive_letter) in MBTI_AXES.items()
    })


# Train and validation targets go through the same encoder so the two
# encodings cannot drift apart.
y_train_binary = encode_mbti_axes(y_train)
y_valid_binary = encode_mbti_axes(y_valid)

# Fitted model and validation metrics, keyed by axis name ('I-E', 'N-S', ...).
models = {}
accuracy_scores = {}
f1_scores = {}

# Train and evaluate one independent binary classifier per MBTI axis.
for column in y_train_binary.columns:
    # Silent GPU training; '0:1' selects GPU devices 0 and 1 in CatBoost's
    # colon-separated device list format.
    model = CatBoostClassifier(verbose=0, task_type="GPU", devices='0:1')

    # Fit on the training features against this axis' 0/1 target.
    model.fit(X_train, y_train_binary[column])

    # Predict class labels for the held-out validation rows.
    y_pred = model.predict(X_valid)

    # f1_score with default arguments scores the class labelled 1 only, i.e.
    # the first letter of the column name. On an imbalanced axis it can
    # therefore differ sharply from accuracy.
    accuracy = accuracy_score(y_valid_binary[column], y_pred)
    f1 = f1_score(y_valid_binary[column], y_pred)

    # Keep the fitted model and its metrics for later cells.
    models[column] = model
    accuracy_scores[column] = accuracy
    f1_scores[column] = f1

    # Report the validation metrics for this axis.
    print(f"Results for {column}:")
    print(f"Accuracy: {accuracy}")
    print(f"F1 Score: {f1}")
    print("-" * 30)

Results for I-E:
Accuracy: 0.7881136196405003
F1 Score: 0.881487360894716
------------------------------
Results for N-S:
Accuracy: 0.9246195353835013
F1 Score: 0.9608311990969449
------------------------------
Results for T-F:
Accuracy: 0.6045036732500964
F1 Score: 0.6646563078720897
------------------------------
Results for J-P:
Accuracy: 0.6030437169319894
F1 Score: 0.12936444911238057
------------------------------
