## Apress - Industrialized Machine Learning Examples

Andreas Francois Vermeulen
2019

### This is an example add-on to a book and needs to be accepted as part of that copyright.

# Chapter-007-013-CatBoost-01

In [1]:
#conda install -c conda-forge catboost

In [10]:
import os

In [2]:
from catboost.datasets import titanic

train_df, test_df = titanic()

## Build CatTrainer

In [3]:
class CatTrainer:
    
    def __init__(self, train_df):
        self.train_df = train_df
        self.model = None
        self.X = None
        self.y = None
        self.categorical_features_indices = None

    def _replace_null_values(self, value, inplace=True):
        self.train_df.fillna(value, inplace=inplace)

    def prepare_x_y(self, label, null_value=-999):
        self._replace_null_values(null_value)
        self.X = self.train_df.drop(label, axis=1)
        self.y = self.train_df[label]

    def _default_args_or_kwargs(self, **kwargs):
        params = {
            'iterations': 500,
            'learning_rate': 0.1,
            'eval_metric': 'Accuracy',
            'random_seed': 1968,
            'logging_level': 'Silent',
            'use_best_model': True
        }
        for k in kwargs.keys():
            if k in params:
                params[k] = kwargs.get(k)
        return params

    def create_model(self, **kwargs):
        from catboost import CatBoostClassifier
        params = self._default_args_or_kwargs(**kwargs)
        if not self.model:
            self.model = CatBoostClassifier(**params)
        else:
            raise ValueError("Cannot overwrite existing model")

    def train_model(self, train_size=0.75, random_state=42, **kwargs):
        from sklearn.model_selection import train_test_split
        import numpy as np
        
        X_train, X_validation, y_train, y_validation = train_test_split(self.X, 
                                                                        self.y, 
                                                                        train_size=0.7, 
                                                                        test_size=0.3,
                                                                        random_state=random_state
                                                                       )
        if not self.categorical_features_indices:
            self.categorical_features_indices = np.where(self.X.dtypes != np.float)[0]
            if not self.model:
                self.create_model(**kwargs)
                self.model.fit(
                X_train, y_train,
                cat_features=self.categorical_features_indices,
                eval_set=(X_validation, y_validation),
                logging_level='Verbose',
                )

    def model_cross_validation(self):
        from catboost import Pool, cv
        import numpy as np

        cv_data = cv(
                    Pool(self.X, self.y, cat_features=self.categorical_features_indices),
                    self.model.get_params()
                    )
        print(cv_data)
        return np.max(cv_data['test-Accuracy-mean'], axis=0)

    def save_model(self, name):
        self.model.save_model('{}.dump'.format(name))

    def load_model(self, name):
        self.model.load_model('{}.dump'.format(name))

    def predict(self, dataframe, null_value=999, inplace=True):
        dataframe.fillna(null_value, inplace=inplace)
        results = self.model.predict(dataframe)
        return results


In [4]:
c = CatTrainer(train_df)
c.prepare_x_y('Survived')
c.train_model()


0:	learn: 0.8298555	test: 0.8171642	best: 0.8171642 (0)	total: 146ms	remaining: 1m 13s
1:	learn: 0.8298555	test: 0.8171642	best: 0.8171642 (0)	total: 216ms	remaining: 53.8s
2:	learn: 0.8282504	test: 0.8171642	best: 0.8171642 (0)	total: 283ms	remaining: 46.9s
3:	learn: 0.8298555	test: 0.8097015	best: 0.8171642 (0)	total: 347ms	remaining: 43s
4:	learn: 0.8282504	test: 0.8022388	best: 0.8171642 (0)	total: 452ms	remaining: 44.7s
5:	learn: 0.8250401	test: 0.8097015	best: 0.8171642 (0)	total: 550ms	remaining: 45.3s
6:	learn: 0.8266453	test: 0.7985075	best: 0.8171642 (0)	total: 645ms	remaining: 45.5s
7:	learn: 0.8234350	test: 0.8022388	best: 0.8171642 (0)	total: 740ms	remaining: 45.5s
8:	learn: 0.8298555	test: 0.8022388	best: 0.8171642 (0)	total: 836ms	remaining: 45.6s
9:	learn: 0.8282504	test: 0.8022388	best: 0.8171642 (0)	total: 904ms	remaining: 44.3s
10:	learn: 0.8250401	test: 0.8059701	best: 0.8171642 (0)	total: 955ms	remaining: 42.5s
11:	learn: 0.8282504	test: 0.8059701	best: 0.8171642 (

95:	learn: 0.8731942	test: 0.7873134	best: 0.8171642 (0)	total: 6.05s	remaining: 25.4s
96:	learn: 0.8747994	test: 0.7873134	best: 0.8171642 (0)	total: 6.14s	remaining: 25.5s
97:	learn: 0.8764045	test: 0.7835821	best: 0.8171642 (0)	total: 6.24s	remaining: 25.6s
98:	learn: 0.8764045	test: 0.7835821	best: 0.8171642 (0)	total: 6.33s	remaining: 25.7s
99:	learn: 0.8780096	test: 0.7835821	best: 0.8171642 (0)	total: 6.43s	remaining: 25.7s
100:	learn: 0.8764045	test: 0.7873134	best: 0.8171642 (0)	total: 6.52s	remaining: 25.8s
101:	learn: 0.8796148	test: 0.7873134	best: 0.8171642 (0)	total: 6.61s	remaining: 25.8s
102:	learn: 0.8796148	test: 0.7873134	best: 0.8171642 (0)	total: 6.69s	remaining: 25.8s
103:	learn: 0.8796148	test: 0.7873134	best: 0.8171642 (0)	total: 6.78s	remaining: 25.8s
104:	learn: 0.8796148	test: 0.7910448	best: 0.8171642 (0)	total: 6.88s	remaining: 25.9s
105:	learn: 0.8796148	test: 0.7985075	best: 0.8171642 (0)	total: 6.97s	remaining: 25.9s
106:	learn: 0.8828250	test: 0.7947761

189:	learn: 0.9085072	test: 0.8022388	best: 0.8171642 (0)	total: 14.8s	remaining: 24.1s
190:	learn: 0.9101124	test: 0.8022388	best: 0.8171642 (0)	total: 14.9s	remaining: 24s
191:	learn: 0.9085072	test: 0.7947761	best: 0.8171642 (0)	total: 14.9s	remaining: 24s
192:	learn: 0.9069021	test: 0.7947761	best: 0.8171642 (0)	total: 15s	remaining: 23.9s
193:	learn: 0.9085072	test: 0.7947761	best: 0.8171642 (0)	total: 15.1s	remaining: 23.8s
194:	learn: 0.9117175	test: 0.7947761	best: 0.8171642 (0)	total: 15.2s	remaining: 23.8s
195:	learn: 0.9117175	test: 0.7947761	best: 0.8171642 (0)	total: 15.3s	remaining: 23.7s
196:	learn: 0.9117175	test: 0.7947761	best: 0.8171642 (0)	total: 15.4s	remaining: 23.7s
197:	learn: 0.9117175	test: 0.7947761	best: 0.8171642 (0)	total: 15.5s	remaining: 23.6s
198:	learn: 0.9117175	test: 0.7910448	best: 0.8171642 (0)	total: 15.6s	remaining: 23.5s
199:	learn: 0.9117175	test: 0.7910448	best: 0.8171642 (0)	total: 15.7s	remaining: 23.5s
200:	learn: 0.9117175	test: 0.7910448	

283:	learn: 0.9277689	test: 0.7947761	best: 0.8171642 (0)	total: 23.4s	remaining: 17.8s
284:	learn: 0.9293740	test: 0.7947761	best: 0.8171642 (0)	total: 23.5s	remaining: 17.7s
285:	learn: 0.9293740	test: 0.7947761	best: 0.8171642 (0)	total: 23.6s	remaining: 17.7s
286:	learn: 0.9293740	test: 0.7947761	best: 0.8171642 (0)	total: 23.7s	remaining: 17.6s
287:	learn: 0.9293740	test: 0.7947761	best: 0.8171642 (0)	total: 23.7s	remaining: 17.5s
288:	learn: 0.9293740	test: 0.7985075	best: 0.8171642 (0)	total: 23.8s	remaining: 17.4s
289:	learn: 0.9277689	test: 0.7947761	best: 0.8171642 (0)	total: 23.9s	remaining: 17.3s
290:	learn: 0.9277689	test: 0.7947761	best: 0.8171642 (0)	total: 24s	remaining: 17.2s
291:	learn: 0.9277689	test: 0.7947761	best: 0.8171642 (0)	total: 24.1s	remaining: 17.2s
292:	learn: 0.9293740	test: 0.7947761	best: 0.8171642 (0)	total: 24.2s	remaining: 17.1s
293:	learn: 0.9293740	test: 0.7947761	best: 0.8171642 (0)	total: 24.3s	remaining: 17s
294:	learn: 0.9293740	test: 0.794776

379:	learn: 0.9357945	test: 0.8097015	best: 0.8171642 (0)	total: 32.1s	remaining: 10.1s
380:	learn: 0.9357945	test: 0.8097015	best: 0.8171642 (0)	total: 32.2s	remaining: 10.1s
381:	learn: 0.9357945	test: 0.8097015	best: 0.8171642 (0)	total: 32.3s	remaining: 9.98s
382:	learn: 0.9357945	test: 0.8097015	best: 0.8171642 (0)	total: 32.4s	remaining: 9.9s
383:	learn: 0.9357945	test: 0.8097015	best: 0.8171642 (0)	total: 32.5s	remaining: 9.82s
384:	learn: 0.9357945	test: 0.8097015	best: 0.8171642 (0)	total: 32.6s	remaining: 9.73s
385:	learn: 0.9357945	test: 0.8097015	best: 0.8171642 (0)	total: 32.7s	remaining: 9.65s
386:	learn: 0.9357945	test: 0.8097015	best: 0.8171642 (0)	total: 32.8s	remaining: 9.57s
387:	learn: 0.9357945	test: 0.8097015	best: 0.8171642 (0)	total: 32.9s	remaining: 9.49s
388:	learn: 0.9357945	test: 0.8097015	best: 0.8171642 (0)	total: 32.9s	remaining: 9.4s
389:	learn: 0.9357945	test: 0.8097015	best: 0.8171642 (0)	total: 33s	remaining: 9.31s
390:	learn: 0.9357945	test: 0.809701

475:	learn: 0.9357945	test: 0.8097015	best: 0.8171642 (0)	total: 40.9s	remaining: 2.06s
476:	learn: 0.9357945	test: 0.8097015	best: 0.8171642 (0)	total: 41.1s	remaining: 1.98s
477:	learn: 0.9357945	test: 0.8097015	best: 0.8171642 (0)	total: 41.2s	remaining: 1.9s
478:	learn: 0.9357945	test: 0.8097015	best: 0.8171642 (0)	total: 41.3s	remaining: 1.81s
479:	learn: 0.9357945	test: 0.8097015	best: 0.8171642 (0)	total: 41.4s	remaining: 1.72s
480:	learn: 0.9357945	test: 0.8097015	best: 0.8171642 (0)	total: 41.5s	remaining: 1.64s
481:	learn: 0.9357945	test: 0.8097015	best: 0.8171642 (0)	total: 41.6s	remaining: 1.55s
482:	learn: 0.9357945	test: 0.8097015	best: 0.8171642 (0)	total: 41.7s	remaining: 1.47s
483:	learn: 0.9357945	test: 0.8097015	best: 0.8171642 (0)	total: 41.8s	remaining: 1.38s
484:	learn: 0.9357945	test: 0.8097015	best: 0.8171642 (0)	total: 41.9s	remaining: 1.29s
485:	learn: 0.9357945	test: 0.8097015	best: 0.8171642 (0)	total: 42s	remaining: 1.21s
486:	learn: 0.9357945	test: 0.80970

In [6]:
#score = c.model_cross_validation()

In [7]:
#print(score)

In [12]:
resultpath = os.path.join(*[os.path.dirname(os.path.dirname(os.getcwd())),'Results','Chapter 07'])
print(resultpath)

if not os.path.exists(resultpath):
    os.makedirs(resultpath)

C:\Users\AndreVermeulen\Documents\My Book\apress\Industrial Machine Learning\book\GitHub\Upload\industrial-machine-learning\Results\Chapter 07


In [14]:
datafilename = os.path.relpath(os.path.join(resultpath, 'Chapter-007-013-CatBoost-01-01'))
c.save_model(datafilename)

In [15]:
predictions = c.predict(test_df)
print(predictions)

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0.
 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0.
 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 1. 0. 0. 0. 1. 0. 0.
 0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 0.
 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 1.
 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.
 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.
 1. 0. 0. 0. 0. 0. 0. 1. 1. 0. 1. 1. 0. 0. 1. 0. 1. 0. 1. 0. 0. 0. 0. 0.
 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0.
 0. 0. 1. 0. 1. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 1.
 1. 1. 0. 0. 0. 0. 1. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 1. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.
 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0.

## Done

In [16]:
import datetime
now = datetime.datetime.now()
print('Done!',str(now))

Done! 2019-10-19 21:39:00.670034
