## Apress - Industrialized Machine Learning Examples

Andreas Francois Vermeulen
2019

### This is an example add-on to a book and needs to be accepted as part of that copyright.

# Chapter-007-013-CatBoost-01

In [18]:
from catboost.datasets import titanic
train_df, test_df = titanic()

## Build CatTrainer

In [2]:
class CatTrainer:
    
    def __init__(self, train_df):
        self.train_df = train_df
        self.model = None
        self.X = None
        self.y = None
        self.categorical_features_indices = None

    def _replace_null_values(self, value, inplace=True):
        self.train_df.fillna(value, inplace=inplace)

    def prepare_x_y(self, label, null_value=-999):
        self._replace_null_values(null_value)
        self.X = self.train_df.drop(label, axis=1)
        self.y = self.train_df[label]

    def _default_args_or_kwargs(self, **kwargs):
        params = {
            'iterations': 500,
            'learning_rate': 0.1,
            'eval_metric': 'Accuracy',
            'random_seed': 1968,
            'logging_level': 'Silent',
            'use_best_model': True
        }
        for k in kwargs.keys():
            if k in params:
                params[k] = kwargs.get(k)
        return params

    def create_model(self, **kwargs):
        from catboost import CatBoostClassifier
        params = self._default_args_or_kwargs(**kwargs)
        if not self.model:
            self.model = CatBoostClassifier(**params)
        else:
            raise ValueError("Cannot overwrite existing model")

    def train_model(self, train_size=0.75, random_state=42, **kwargs):
        from sklearn.model_selection import train_test_split
        import numpy as np
        
        X_train, X_validation, y_train, y_validation = train_test_split(self.X, 
                                                                        self.y, 
                                                                        train_size=0.7, 
                                                                        test_size=0.3,
                                                                        random_state=random_state
                                                                       )
        if not self.categorical_features_indices:
            self.categorical_features_indices = np.where(self.X.dtypes != np.float)[0]
            if not self.model:
                self.create_model(**kwargs)
                self.model.fit(
                X_train, y_train,
                cat_features=self.categorical_features_indices,
                eval_set=(X_validation, y_validation),
                logging_level='Verbose',
                )

    def model_cross_validation(self):
        from catboost import Pool, cv
        import numpy as np

        cv_data = cv(
                    Pool(self.X, self.y, cat_features=self.categorical_features_indices),
                    self.model.get_params()
                    )
        print(cv_data)
        return np.max(cv_data['test-Accuracy-mean'], axis=0)

    def save_model(self, name):
        self.model.save_model('{}.dump'.format(name))

    def load_model(self, name):
        self.model.load_model('{}.dump'.format(name))

    def predict(self, dataframe, null_value=999, inplace=True):
        dataframe.fillna(null_value, inplace=inplace)
        results = self.model.predict(dataframe)
        return results


In [3]:
c = CatTrainer(train_df)
c.prepare_x_y('Survived')
c.train_model()


0:	learn: 0.8298555	test: 0.8171642	best: 0.8171642 (0)	total: 163ms	remaining: 1m 21s
1:	learn: 0.8298555	test: 0.8171642	best: 0.8171642 (0)	total: 219ms	remaining: 54.5s
2:	learn: 0.8282504	test: 0.8171642	best: 0.8171642 (0)	total: 272ms	remaining: 45s
3:	learn: 0.8298555	test: 0.8097015	best: 0.8171642 (0)	total: 315ms	remaining: 39.1s
4:	learn: 0.8282504	test: 0.8022388	best: 0.8171642 (0)	total: 410ms	remaining: 40.6s
5:	learn: 0.8250401	test: 0.8097015	best: 0.8171642 (0)	total: 488ms	remaining: 40.2s
6:	learn: 0.8266453	test: 0.7985075	best: 0.8171642 (0)	total: 561ms	remaining: 39.5s
7:	learn: 0.8234350	test: 0.8022388	best: 0.8171642 (0)	total: 644ms	remaining: 39.6s
8:	learn: 0.8298555	test: 0.8022388	best: 0.8171642 (0)	total: 713ms	remaining: 38.9s
9:	learn: 0.8282504	test: 0.8022388	best: 0.8171642 (0)	total: 770ms	remaining: 37.7s
10:	learn: 0.8250401	test: 0.8059701	best: 0.8171642 (0)	total: 815ms	remaining: 36.2s
11:	learn: 0.8282504	test: 0.8059701	best: 0.8171642 (

97:	learn: 0.8699839	test: 0.8059701	best: 0.8171642 (0)	total: 5.26s	remaining: 21.6s
98:	learn: 0.8731942	test: 0.8059701	best: 0.8171642 (0)	total: 5.32s	remaining: 21.6s
99:	learn: 0.8731942	test: 0.8059701	best: 0.8171642 (0)	total: 5.4s	remaining: 21.6s
100:	learn: 0.8731942	test: 0.8059701	best: 0.8171642 (0)	total: 5.47s	remaining: 21.6s
101:	learn: 0.8731942	test: 0.8059701	best: 0.8171642 (0)	total: 5.55s	remaining: 21.6s
102:	learn: 0.8764045	test: 0.8097015	best: 0.8171642 (0)	total: 5.62s	remaining: 21.7s
103:	learn: 0.8780096	test: 0.8059701	best: 0.8171642 (0)	total: 5.68s	remaining: 21.6s
104:	learn: 0.8780096	test: 0.8097015	best: 0.8171642 (0)	total: 5.74s	remaining: 21.6s
105:	learn: 0.8796148	test: 0.8171642	best: 0.8171642 (0)	total: 5.81s	remaining: 21.6s
106:	learn: 0.8796148	test: 0.8208955	best: 0.8208955 (106)	total: 5.88s	remaining: 21.6s
107:	learn: 0.8796148	test: 0.8171642	best: 0.8208955 (106)	total: 5.95s	remaining: 21.6s
108:	learn: 0.8828250	test: 0.81

190:	learn: 0.9181380	test: 0.8134328	best: 0.8246269 (111)	total: 12s	remaining: 19.4s
191:	learn: 0.9197432	test: 0.8134328	best: 0.8246269 (111)	total: 12.1s	remaining: 19.3s
192:	learn: 0.9197432	test: 0.8134328	best: 0.8246269 (111)	total: 12.3s	remaining: 19.5s
193:	learn: 0.9197432	test: 0.8134328	best: 0.8246269 (111)	total: 12.4s	remaining: 19.5s
194:	learn: 0.9197432	test: 0.8097015	best: 0.8246269 (111)	total: 12.4s	remaining: 19.4s
195:	learn: 0.9229535	test: 0.8097015	best: 0.8246269 (111)	total: 12.5s	remaining: 19.4s
196:	learn: 0.9229535	test: 0.8097015	best: 0.8246269 (111)	total: 12.6s	remaining: 19.3s
197:	learn: 0.9229535	test: 0.8097015	best: 0.8246269 (111)	total: 12.6s	remaining: 19.3s
198:	learn: 0.9229535	test: 0.8059701	best: 0.8246269 (111)	total: 12.7s	remaining: 19.2s
199:	learn: 0.9229535	test: 0.8059701	best: 0.8246269 (111)	total: 12.8s	remaining: 19.2s
200:	learn: 0.9213483	test: 0.8059701	best: 0.8246269 (111)	total: 12.9s	remaining: 19.1s
201:	learn: 

283:	learn: 0.9277689	test: 0.8171642	best: 0.8246269 (111)	total: 19s	remaining: 14.5s
284:	learn: 0.9277689	test: 0.8171642	best: 0.8246269 (111)	total: 19.1s	remaining: 14.4s
285:	learn: 0.9277689	test: 0.8171642	best: 0.8246269 (111)	total: 19.1s	remaining: 14.3s
286:	learn: 0.9277689	test: 0.8171642	best: 0.8246269 (111)	total: 19.2s	remaining: 14.3s
287:	learn: 0.9277689	test: 0.8171642	best: 0.8246269 (111)	total: 19.3s	remaining: 14.2s
288:	learn: 0.9293740	test: 0.8171642	best: 0.8246269 (111)	total: 19.4s	remaining: 14.1s
289:	learn: 0.9293740	test: 0.8171642	best: 0.8246269 (111)	total: 19.4s	remaining: 14.1s
290:	learn: 0.9293740	test: 0.8171642	best: 0.8246269 (111)	total: 19.5s	remaining: 14s
291:	learn: 0.9293740	test: 0.8208955	best: 0.8246269 (111)	total: 19.6s	remaining: 13.9s
292:	learn: 0.9293740	test: 0.8208955	best: 0.8246269 (111)	total: 19.6s	remaining: 13.9s
293:	learn: 0.9293740	test: 0.8208955	best: 0.8246269 (111)	total: 19.7s	remaining: 13.8s
294:	learn: 0.

375:	learn: 0.9309791	test: 0.8171642	best: 0.8246269 (111)	total: 25.8s	remaining: 8.52s
376:	learn: 0.9309791	test: 0.8171642	best: 0.8246269 (111)	total: 25.9s	remaining: 8.45s
377:	learn: 0.9309791	test: 0.8171642	best: 0.8246269 (111)	total: 26s	remaining: 8.38s
378:	learn: 0.9293740	test: 0.8171642	best: 0.8246269 (111)	total: 26s	remaining: 8.31s
379:	learn: 0.9293740	test: 0.8134328	best: 0.8246269 (111)	total: 26.1s	remaining: 8.25s
380:	learn: 0.9293740	test: 0.8134328	best: 0.8246269 (111)	total: 26.2s	remaining: 8.18s
381:	learn: 0.9293740	test: 0.8134328	best: 0.8246269 (111)	total: 26.3s	remaining: 8.12s
382:	learn: 0.9293740	test: 0.8134328	best: 0.8246269 (111)	total: 26.4s	remaining: 8.05s
383:	learn: 0.9293740	test: 0.8134328	best: 0.8246269 (111)	total: 26.4s	remaining: 7.99s
384:	learn: 0.9293740	test: 0.8134328	best: 0.8246269 (111)	total: 26.5s	remaining: 7.92s
385:	learn: 0.9293740	test: 0.8134328	best: 0.8246269 (111)	total: 26.6s	remaining: 7.85s
386:	learn: 0.

467:	learn: 0.9357945	test: 0.8358209	best: 0.8358209 (428)	total: 33.2s	remaining: 2.27s
468:	learn: 0.9357945	test: 0.8358209	best: 0.8358209 (428)	total: 33.3s	remaining: 2.2s
469:	learn: 0.9357945	test: 0.8358209	best: 0.8358209 (428)	total: 33.3s	remaining: 2.13s
470:	learn: 0.9357945	test: 0.8358209	best: 0.8358209 (428)	total: 33.4s	remaining: 2.06s
471:	learn: 0.9357945	test: 0.8358209	best: 0.8358209 (428)	total: 33.5s	remaining: 1.99s
472:	learn: 0.9357945	test: 0.8358209	best: 0.8358209 (428)	total: 33.6s	remaining: 1.92s
473:	learn: 0.9357945	test: 0.8283582	best: 0.8358209 (428)	total: 33.6s	remaining: 1.84s
474:	learn: 0.9357945	test: 0.8283582	best: 0.8358209 (428)	total: 33.7s	remaining: 1.77s
475:	learn: 0.9373997	test: 0.8320896	best: 0.8358209 (428)	total: 33.8s	remaining: 1.7s
476:	learn: 0.9373997	test: 0.8320896	best: 0.8358209 (428)	total: 33.9s	remaining: 1.63s
477:	learn: 0.9390048	test: 0.8283582	best: 0.8358209 (428)	total: 33.9s	remaining: 1.56s
478:	learn: 

In [4]:
score = c.model_cross_validation()

     iterations  test-Accuracy-mean  test-Accuracy-std  train-Accuracy-mean  \
0             0            0.804714           0.015430             0.800224   
1             1            0.804714           0.015430             0.802469   
2             2            0.804714           0.015430             0.802469   
3             3            0.786756           0.019729             0.786756   
4             4            0.787879           0.020481             0.790685   
5             5            0.792368           0.012747             0.794613   
6             6            0.791246           0.014676             0.795174   
7             7            0.791246           0.014676             0.797419   
8             8            0.790123           0.014018             0.791807   
9             9            0.790123           0.014018             0.792368   
10           10            0.793490           0.010823             0.804153   
11           11            0.795735           0.0127

In [5]:
print(score)

0.8204264870931538


In [14]:
datafilename = '../../Results/Chapter 07/Chapter-007-013-CatBoost-01-01'
c.save_model(datafilename)

In [11]:
predictions = c.predict(test_df)
print(predictions)

[0. 1. 0. 0. 1. 0. 1. 0. 1. 0. 0. 0. 1. 0. 1. 1. 1. 0. 1. 0. 1. 0. 1. 1.
 1. 0. 1. 0. 0. 1. 0. 0. 1. 1. 1. 0. 1. 1. 0. 1. 1. 1. 0. 1. 1. 0. 0. 0.
 1. 1. 0. 0. 1. 1. 1. 0. 0. 0. 1. 1. 0. 0. 0. 1. 1. 1. 1. 0. 1. 1. 1. 0.
 1. 1. 1. 1. 0. 1. 0. 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 1. 0. 1. 0.
 1. 0. 1. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 0. 1. 1. 1.
 1. 0. 1. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0.
 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 0. 1. 1. 1. 1. 0. 0. 1. 0. 0.
 1. 1. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 0. 1. 1. 0. 1. 0. 1. 0. 0. 0. 1. 1.
 1. 0. 1. 0. 1. 1. 0. 1. 1. 1. 1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 1. 1. 0.
 1. 0. 1. 0. 1. 0. 1. 0. 1. 1. 0. 1. 0. 0. 1. 1. 0. 0. 1. 0. 1. 0. 1. 1.
 1. 1. 0. 0. 1. 0. 1. 0. 1. 1. 1. 0. 1. 0. 0. 0. 0. 0. 1. 0. 1. 0. 1. 1.
 1. 0. 0. 0. 1. 0. 0. 0. 1. 1. 0. 1. 0. 1. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0.
 0. 0. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 1. 0. 0. 1. 1. 0. 1. 0. 1. 0. 0.
 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 1. 0. 1. 0. 1.

## Done

In [12]:
import datetime
now = datetime.datetime.now()
print('Done!',str(now))

Done! 2019-04-22 08:19:38.392392
