## Read data

In [61]:
import pandas as pd
from sklearn.model_selection import train_test_split


In [62]:
# Path to the preprocessed training table, relative to this notebook.
train_dataset = "../data/train_preprocess.csv"
# Parse the CSV into a DataFrame using pandas' default reader options.
train_df = pd.read_csv(filepath_or_buffer=train_dataset)

In [63]:
# Preview the first five rows (five is also the pandas default).
train_df.head(n=5)

Unnamed: 0,type,priority,class,is_return,weight,mailtype,mailctg,directctg,transport_pay,postmark,weight_mfi,price_mfi,total_qty_over_index,is_wrong_sndr_name,is_wrong_rcpn_name,is_wrong_phone_number,is_wrong_address,label,oper_type,oper_attr
0,18,7503,0,1,0.000551,5,1,2,0.0,0,0.002278,0.000939,0.016573,0,0,0,0,0,1043,-1
1,4,7503,0,1,0.000677,5,1,2,0.0,0,0.003778,0.002505,0.273502,0,0,0,0,0,1023,-1
2,19,7503,0,1,0.000316,5,1,2,0.0,0,0.003111,0.001365,0.105363,0,1,0,0,0,1018,-1
3,19,7503,0,1,0.002633,5,1,2,0.042553,0,0.001833,0.000626,0.039105,0,0,0,0,0,1019,-1
4,18,7503,0,1,0.005032,5,1,2,0.063239,0,0.039778,0.006262,0.009434,0,0,0,0,0,1020,-1


In [64]:
# Print the column names, dtypes and memory footprint of the loaded frame.
train_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5999846 entries, 0 to 5999845
Data columns (total 20 columns):
 #   Column                 Dtype  
---  ------                 -----  
 0   type                   int64  
 1   priority               int64  
 2   class                  int64  
 3   is_return              int64  
 4   weight                 float64
 5   mailtype               int64  
 6   mailctg                int64  
 7   directctg              int64  
 8   transport_pay          float64
 9   postmark               int64  
 10  weight_mfi             float64
 11  price_mfi              float64
 12  total_qty_over_index   float64
 13  is_wrong_sndr_name     int64  
 14  is_wrong_rcpn_name     int64  
 15  is_wrong_phone_number  int64  
 16  is_wrong_address       int64  
 17  label                  int64  
 18  oper_type              int64  
 19  oper_attr              int64  
dtypes: float64(5), int64(15)
memory usage: 915.5 MB


In [65]:
# The "label" column is the prediction target; every other column is a feature.
y = train_df.loc[:, "label"]
X = train_df.drop(columns=["label"])

In [66]:
# Hold out 20% of the rows for validation; the fixed seed keeps the split
# reproducible across runs.
X_train, X_val, y_train, y_val = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=0,
)

## Train model

In [69]:
import os
import datetime
from catboost import CatBoostClassifier

In [67]:
# Columns handed to CatBoost as categorical features. The order matches the
# original flat list; grouping is only for readability.
cat_features = (
    # Shipment / mail descriptors.
    ["type", "priority", "class", "is_return"]
    + ["mailtype", "mailctg", "directctg", "postmark"]
    # Data-quality flags.
    + [
        "is_wrong_sndr_name",
        "is_wrong_rcpn_name",
        "is_wrong_phone_number",
        "is_wrong_address",
    ]
    # Operation descriptors.
    + ["oper_type", "oper_attr"]
)

<catboost.core.CatBoostClassifier at 0x7ffaa91c6ca0>

In [None]:
# --- Training hyper-parameters -------------------------------------------
ITERATIONS = 50
RANDOM_SEED = 42
LEARNING_RATE = 0.5

# --- Per-run output locations --------------------------------------------
# Each run gets its own directory, named after the wall-clock start time
# (month-day-year-hour-minute-second).
_run_started_at = datetime.datetime.now()
TIMESTAMP = _run_started_at.strftime("%m-%d-%Y-%H-%M-%S")
# NOTE(review): "catbost" looks like a typo of "catboost"; left unchanged
# because fixing it would rename the directory written on disk.
LOG_DIR = '../log/catbost-' + TIMESTAMP

PLOT_FILE = LOG_DIR + '/plot'
SNAPSHOT_FILE = LOG_DIR + '/snapshot'

In [None]:
# Create the per-run log directory (and any missing parents). exist_ok=True
# makes the cell safe to re-run and avoids the check-then-create race of a
# separate os.path.exists() test.
os.makedirs(LOG_DIR, exist_ok=True)

In [None]:
# Configure the classifier from the notebook-level constants.
clf = CatBoostClassifier(
    iterations=ITERATIONS,
    random_seed=RANDOM_SEED,
    learning_rate=LEARNING_RATE,
    # NOTE(review): per the CatBoost docs a snapshot appears to be written only
    # when save_snapshot=True is also set -- confirm whether snapshotting is
    # actually wanted here.
    snapshot_file=SNAPSHOT_FILE,
    # Additional metrics to compute during training.
    custom_loss=['AUC', 'Recall', 'Precision', 'Accuracy']
)

# Fit on the training split; the validation split is supplied as eval_set.
clf.fit(
    X_train, y_train,
    cat_features=cat_features,
    eval_set=(X_val, y_val),
    # plot_file is an argument of fit(), not of the CatBoostClassifier
    # constructor; passing it to the constructor raises a TypeError.
    plot_file=PLOT_FILE,
    verbose=False,
)

In [None]:
# Metric history recorded by the fitted model during training.
logs = clf.get_evals_result()
# Best value reached for each tracked metric.
best_score = clf.get_best_score()