# Imports

In [1]:
import pandas as pd
from catboost import CatBoostClassifier, Pool
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

  from pandas import MultiIndex, Int64Index


# Read data

In [87]:
# The payments file is tab-separated and has no header row, so pandas labels
# its columns with integers (0, 1, 2, ...); they are named in the next cell.
DATA_PATH = '../data/payments_augmented_training.tsv'
data = pd.read_csv(DATA_PATH, header=None, sep='\t')

In [101]:
# `read_csv` was called with header=None, so columns carry integer labels.
# Give the columns used downstream readable names; any other columns keep
# their integer labels. Reassigning (instead of `inplace=True`) keeps the cell
# free of in-place mutation and chain-friendly.
data = data.rename(columns={
    1: 'date',
    2: 'sum',
    3: 'purpose',
    4: 'category',
})

# `rename` silently ignores labels that do not exist, so a file with fewer
# columns than expected would otherwise only fail later with a KeyError.
# This check is idempotent: re-running the cell still passes.
assert {'date', 'sum', 'purpose', 'category'} <= set(data.columns), (
    f"unexpected columns after rename: {list(data.columns)}"
)

# Training

In [102]:
# Hold out 20% of rows for evaluation, using only the free-text `purpose`
# column as input and `category` as the label. Stratifying on the label keeps
# every category's share the same in train and test; the classes are
# imbalanced (the recorded classification report shows per-class test support
# between 30 and 200), so an unstratified split lets small classes drift.
# NOTE(review): the source file is "augmented". If augmented rows were derived
# from shared originals, a plain row-level split can place near-duplicates in
# both sets. If an origin id is available, a group-aware split (e.g. sklearn's
# GroupShuffleSplit) would give a more honest estimate — TODO confirm.
X_train, X_test, y_train, y_test = train_test_split(
    data[['purpose']],
    data['category'],
    test_size=0.2,
    random_state=42,
    stratify=data['category'],
)


In [103]:
# Wrap each split in a CatBoost Pool and declare `purpose` as a text feature
# by column name, so CatBoost applies its text processing to that column.
# Both pools must declare the same text features, hence the shared list.
TEXT_FEATURES = ['purpose']
train_pool = Pool(data=X_train, label=y_train, text_features=TEXT_FEATURES)
test_pool = Pool(data=X_test, label=y_test, text_features=TEXT_FEATURES)

In [107]:
# Multiclass CatBoost model. Text features are not repeated here: the
# training/test Pools already declare `purpose` by name, and a second,
# index-based declaration (`text_features=[0]`) only adds a way for the two
# to disagree if the column order ever changes.
model = CatBoostClassifier(
    iterations=300,
    depth=5,
    loss_function='MultiClass',
    eval_metric='Accuracy',  # metric CatBoost evaluates on eval_set during fit
    verbose=0,               # silence per-iteration training logs
)


In [109]:
# Train on the training pool. `test_pool` is passed only so CatBoost records
# its metric per iteration. CatBoost's default whenever an eval_set is given
# is use_best_model=True, which truncates the model at the iteration scoring
# best on that set. Using the test set for that choice leaks its labels into
# model selection and inflates the accuracy reported below, so it is disabled.
# Verbosity is already configured on the model itself.
model.fit(train_pool, eval_set=test_pool, use_best_model=False)

# Make predictions
preds = model.predict(test_pool)
preds_proba = model.predict_proba(test_pool)

In [115]:
# Headline accuracy first, then per-class precision/recall/F1 on the test split.
# NOTE(review): if this reports a perfect score on an augmented dataset, check
# for near-duplicate rows shared between train and test before trusting it.
accuracy = accuracy_score(y_test, preds)
print("Accuracy:", accuracy)
print()
print("Classification Report:")
print(classification_report(y_test, preds))

Accuracy: 1.0

Classification Report:
                precision    recall  f1-score   support

  BANK_SERVICE       1.00      1.00      1.00        70
    FOOD_GOODS       1.00      1.00      1.00       190
       LEASING       1.00      1.00      1.00        70
          LOAN       1.00      1.00      1.00       100
NON_FOOD_GOODS       1.00      1.00      1.00       160
NOT_CLASSIFIED       1.00      1.00      1.00        30
   REALE_STATE       1.00      1.00      1.00        50
       SERVICE       1.00      1.00      1.00       200
           TAX       1.00      1.00      1.00       130

      accuracy                           1.00      1000
     macro avg       1.00      1.00      1.00      1000
  weighted avg       1.00      1.00      1.00      1000

