In [1]:
!pip install catboost
!pip install ipywidgets
!jupyter nbextension enable --py widgetsnbextension




Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: ok


In [2]:
from catboost.datasets import titanic
import numpy as np

train_df, test_df = titanic()

train_df.head()

Unnamed: 0,PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
0,1,0,3,"Braund, Mr. Owen Harris",male,22.0,1,0,A/5 21171,7.25,,S
1,2,1,1,"Cumings, Mrs. John Bradley (Florence Briggs Th...",female,38.0,1,0,PC 17599,71.2833,C85,C
2,3,1,3,"Heikkinen, Miss. Laina",female,26.0,0,0,STON/O2. 3101282,7.925,,S
3,4,1,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,35.0,1,0,113803,53.1,C123,S
4,5,0,3,"Allen, Mr. William Henry",male,35.0,0,0,373450,8.05,,S


Проверим, сколько отсутствующих значений

In [3]:
null_value_stats = train_df.isnull().sum(axis=0)
null_value_stats[null_value_stats!=0]

Age         177
Cabin       687
Embarked      2
dtype: int64

Значения Age,Cabin и Embarked имеют пропущенные значения, поэтому заполним их каким-либо числовым выходом из их распределений. Так модель может их различить и учесть.

In [4]:
train_df.fillna(-999, inplace=True)
test_df.fillna(-999, inplace=True)

Разделим наши признаки и целевую метку

In [5]:
X = train_df.drop('Survived', axis=1)
y = train_df.Survived

Пропуски заполнены

In [6]:
train_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   PassengerId  891 non-null    int64  
 1   Survived     891 non-null    int64  
 2   Pclass       891 non-null    int64  
 3   Name         891 non-null    object 
 4   Sex          891 non-null    object 
 5   Age          891 non-null    float64
 6   SibSp        891 non-null    int64  
 7   Parch        891 non-null    int64  
 8   Ticket       891 non-null    object 
 9   Fare         891 non-null    float64
 10  Cabin        891 non-null    object 
 11  Embarked     891 non-null    object 
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB


Поскольку признаки бывают разных типов. Рассмотрим строковые особенности как категориальные.

In [7]:
print(X.dtypes)

categorical_features_indices = np.where(X.dtypes != np.float)[0]
X.iloc[:, categorical_features_indices]

PassengerId      int64
Pclass           int64
Name            object
Sex             object
Age            float64
SibSp            int64
Parch            int64
Ticket          object
Fare           float64
Cabin           object
Embarked        object
dtype: object


Unnamed: 0,PassengerId,Pclass,Name,Sex,SibSp,Parch,Ticket,Cabin,Embarked
0,1,3,"Braund, Mr. Owen Harris",male,1,0,A/5 21171,-999,S
1,2,1,"Cumings, Mrs. John Bradley (Florence Briggs Th...",female,1,0,PC 17599,C85,C
2,3,3,"Heikkinen, Miss. Laina",female,0,0,STON/O2. 3101282,-999,S
3,4,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,1,0,113803,C123,S
4,5,3,"Allen, Mr. William Henry",male,0,0,373450,-999,S
...,...,...,...,...,...,...,...,...,...
886,887,2,"Montvila, Rev. Juozas",male,0,0,211536,-999,S
887,888,1,"Graham, Miss. Margaret Edith",female,0,0,112053,B42,S
888,889,3,"Johnston, Miss. Catherine Helen ""Carrie""",female,1,2,W./C. 6607,-999,S
889,890,1,"Behr, Mr. Karl Howell",male,0,0,111369,C148,C


Разделим тренировочные данные на наборы для обучения и проверки

In [8]:
from sklearn.model_selection import train_test_split

X_train, X_validation, y_train, y_validation = train_test_split(X, y, train_size=0.75, random_state=42)

X_test = test_df

In [9]:
from catboost import CatBoostClassifier, Pool, metrics, cv
from sklearn.metrics import accuracy_score

Обучение модели

In [11]:
model = CatBoostClassifier(
    custom_loss=[metrics.Accuracy()],
    random_seed=42,
    logging_level='Silent'
)
model.fit(
    X_train, y_train,
    cat_features=categorical_features_indices,
    eval_set=(X_validation, y_validation),
    logging_level='Verbose',  
    plot=False
);
model.score(X_validation, y_validation)

Learning rate set to 0.028683
0:	learn: 0.6739988	test: 0.6742630	best: 0.6742630 (0)	total: 19.7ms	remaining: 19.7s
1:	learn: 0.6589013	test: 0.6592240	best: 0.6592240 (1)	total: 31ms	remaining: 15.5s
2:	learn: 0.6421502	test: 0.6426778	best: 0.6426778 (2)	total: 52.3ms	remaining: 17.4s
3:	learn: 0.6297276	test: 0.6302310	best: 0.6302310 (3)	total: 69.5ms	remaining: 17.3s
4:	learn: 0.6147184	test: 0.6198228	best: 0.6198228 (4)	total: 90.5ms	remaining: 18s
5:	learn: 0.6017730	test: 0.6073627	best: 0.6073627 (5)	total: 112ms	remaining: 18.6s
6:	learn: 0.5885309	test: 0.5956000	best: 0.5956000 (6)	total: 134ms	remaining: 19s
7:	learn: 0.5783200	test: 0.5858523	best: 0.5858523 (7)	total: 174ms	remaining: 21.6s
8:	learn: 0.5665895	test: 0.5743842	best: 0.5743842 (8)	total: 198ms	remaining: 21.8s
9:	learn: 0.5575381	test: 0.5662283	best: 0.5662283 (9)	total: 220ms	remaining: 21.8s
10:	learn: 0.5491045	test: 0.5575176	best: 0.5575176 (10)	total: 245ms	remaining: 22.1s
11:	learn: 0.5423887	te

98:	learn: 0.3715320	test: 0.4142153	best: 0.4142153 (98)	total: 2.21s	remaining: 20.1s
99:	learn: 0.3713123	test: 0.4137836	best: 0.4137836 (99)	total: 2.22s	remaining: 20s
100:	learn: 0.3711681	test: 0.4134127	best: 0.4134127 (100)	total: 2.23s	remaining: 19.8s
101:	learn: 0.3701463	test: 0.4129498	best: 0.4129498 (101)	total: 2.25s	remaining: 19.8s
102:	learn: 0.3696277	test: 0.4126316	best: 0.4126316 (102)	total: 2.27s	remaining: 19.8s
103:	learn: 0.3687955	test: 0.4111648	best: 0.4111648 (103)	total: 2.29s	remaining: 19.8s
104:	learn: 0.3680006	test: 0.4103194	best: 0.4103194 (104)	total: 2.32s	remaining: 19.8s
105:	learn: 0.3678460	test: 0.4103092	best: 0.4103092 (105)	total: 2.33s	remaining: 19.7s
106:	learn: 0.3672394	test: 0.4097549	best: 0.4097549 (106)	total: 2.36s	remaining: 19.7s
107:	learn: 0.3672117	test: 0.4097518	best: 0.4097518 (107)	total: 2.37s	remaining: 19.5s
108:	learn: 0.3665546	test: 0.4096297	best: 0.4096297 (108)	total: 2.38s	remaining: 19.5s
109:	learn: 0.36

190:	learn: 0.3309874	test: 0.3980582	best: 0.3978982 (187)	total: 4.06s	remaining: 17.2s
191:	learn: 0.3307796	test: 0.3980432	best: 0.3978982 (187)	total: 4.08s	remaining: 17.2s
192:	learn: 0.3306578	test: 0.3979033	best: 0.3978982 (187)	total: 4.1s	remaining: 17.2s
193:	learn: 0.3299342	test: 0.3978205	best: 0.3978205 (193)	total: 4.12s	remaining: 17.1s
194:	learn: 0.3291373	test: 0.3965197	best: 0.3965197 (194)	total: 4.14s	remaining: 17.1s
195:	learn: 0.3285707	test: 0.3970109	best: 0.3965197 (194)	total: 4.16s	remaining: 17.1s
196:	learn: 0.3282818	test: 0.3967007	best: 0.3965197 (194)	total: 4.18s	remaining: 17s
197:	learn: 0.3273962	test: 0.3968371	best: 0.3965197 (194)	total: 4.2s	remaining: 17s
198:	learn: 0.3272150	test: 0.3966213	best: 0.3965197 (194)	total: 4.22s	remaining: 17s
199:	learn: 0.3266134	test: 0.3968475	best: 0.3965197 (194)	total: 4.24s	remaining: 17s
200:	learn: 0.3265759	test: 0.3967844	best: 0.3965197 (194)	total: 4.26s	remaining: 16.9s
201:	learn: 0.326154

290:	learn: 0.3001110	test: 0.3947636	best: 0.3934154 (283)	total: 5.9s	remaining: 14.4s
291:	learn: 0.3000496	test: 0.3947832	best: 0.3934154 (283)	total: 5.92s	remaining: 14.4s
292:	learn: 0.2995601	test: 0.3947448	best: 0.3934154 (283)	total: 5.94s	remaining: 14.3s
293:	learn: 0.2995112	test: 0.3947653	best: 0.3934154 (283)	total: 5.96s	remaining: 14.3s
294:	learn: 0.2988606	test: 0.3948888	best: 0.3934154 (283)	total: 5.99s	remaining: 14.3s
295:	learn: 0.2979924	test: 0.3949597	best: 0.3934154 (283)	total: 6.01s	remaining: 14.3s
296:	learn: 0.2968685	test: 0.3945085	best: 0.3934154 (283)	total: 6.03s	remaining: 14.3s
297:	learn: 0.2961210	test: 0.3946813	best: 0.3934154 (283)	total: 6.05s	remaining: 14.2s
298:	learn: 0.2958293	test: 0.3946240	best: 0.3934154 (283)	total: 6.07s	remaining: 14.2s
299:	learn: 0.2956822	test: 0.3945140	best: 0.3934154 (283)	total: 6.09s	remaining: 14.2s
300:	learn: 0.2953614	test: 0.3944096	best: 0.3934154 (283)	total: 6.11s	remaining: 14.2s
301:	learn:

382:	learn: 0.2658405	test: 0.3936081	best: 0.3925689 (343)	total: 7.77s	remaining: 12.5s
383:	learn: 0.2656425	test: 0.3937052	best: 0.3925689 (343)	total: 7.8s	remaining: 12.5s
384:	learn: 0.2648389	test: 0.3936322	best: 0.3925689 (343)	total: 7.82s	remaining: 12.5s
385:	learn: 0.2644950	test: 0.3935443	best: 0.3925689 (343)	total: 7.87s	remaining: 12.5s
386:	learn: 0.2640026	test: 0.3935747	best: 0.3925689 (343)	total: 7.89s	remaining: 12.5s
387:	learn: 0.2639477	test: 0.3935993	best: 0.3925689 (343)	total: 7.91s	remaining: 12.5s
388:	learn: 0.2636118	test: 0.3936222	best: 0.3925689 (343)	total: 7.93s	remaining: 12.5s
389:	learn: 0.2628712	test: 0.3938607	best: 0.3925689 (343)	total: 7.95s	remaining: 12.4s
390:	learn: 0.2627320	test: 0.3938785	best: 0.3925689 (343)	total: 7.99s	remaining: 12.4s
391:	learn: 0.2623286	test: 0.3939627	best: 0.3925689 (343)	total: 8.01s	remaining: 12.4s
392:	learn: 0.2617678	test: 0.3941314	best: 0.3925689 (343)	total: 8.03s	remaining: 12.4s
393:	learn:

476:	learn: 0.2359515	test: 0.3988139	best: 0.3925689 (343)	total: 10s	remaining: 11s
477:	learn: 0.2357715	test: 0.3986607	best: 0.3925689 (343)	total: 10s	remaining: 11s
478:	learn: 0.2356674	test: 0.3986996	best: 0.3925689 (343)	total: 10.1s	remaining: 10.9s
479:	learn: 0.2352692	test: 0.3988182	best: 0.3925689 (343)	total: 10.1s	remaining: 11s
480:	learn: 0.2351312	test: 0.3987838	best: 0.3925689 (343)	total: 10.2s	remaining: 11s
481:	learn: 0.2348800	test: 0.3993014	best: 0.3925689 (343)	total: 10.2s	remaining: 10.9s
482:	learn: 0.2344202	test: 0.3995723	best: 0.3925689 (343)	total: 10.2s	remaining: 10.9s
483:	learn: 0.2339025	test: 0.3997433	best: 0.3925689 (343)	total: 10.2s	remaining: 10.9s
484:	learn: 0.2337993	test: 0.3996815	best: 0.3925689 (343)	total: 10.2s	remaining: 10.9s
485:	learn: 0.2336698	test: 0.3996594	best: 0.3925689 (343)	total: 10.3s	remaining: 10.9s
486:	learn: 0.2332568	test: 0.4000085	best: 0.3925689 (343)	total: 10.3s	remaining: 10.9s
487:	learn: 0.2326517	

575:	learn: 0.2127805	test: 0.4039081	best: 0.3925689 (343)	total: 12.3s	remaining: 9.03s
576:	learn: 0.2124578	test: 0.4039974	best: 0.3925689 (343)	total: 12.3s	remaining: 9.01s
577:	learn: 0.2124096	test: 0.4040149	best: 0.3925689 (343)	total: 12.3s	remaining: 8.99s
578:	learn: 0.2121430	test: 0.4043251	best: 0.3925689 (343)	total: 12.3s	remaining: 8.96s
579:	learn: 0.2120690	test: 0.4045013	best: 0.3925689 (343)	total: 12.3s	remaining: 8.94s
580:	learn: 0.2120014	test: 0.4043891	best: 0.3925689 (343)	total: 12.4s	remaining: 8.93s
581:	learn: 0.2116058	test: 0.4044004	best: 0.3925689 (343)	total: 12.4s	remaining: 8.9s
582:	learn: 0.2115862	test: 0.4043126	best: 0.3925689 (343)	total: 12.4s	remaining: 8.88s
583:	learn: 0.2112350	test: 0.4043209	best: 0.3925689 (343)	total: 12.4s	remaining: 8.86s
584:	learn: 0.2110965	test: 0.4041534	best: 0.3925689 (343)	total: 12.5s	remaining: 8.84s
585:	learn: 0.2109455	test: 0.4045123	best: 0.3925689 (343)	total: 12.5s	remaining: 8.82s
586:	learn:

669:	learn: 0.1938049	test: 0.4101533	best: 0.3925689 (343)	total: 14.3s	remaining: 7.06s
670:	learn: 0.1934187	test: 0.4104149	best: 0.3925689 (343)	total: 14.4s	remaining: 7.05s
671:	learn: 0.1932050	test: 0.4102778	best: 0.3925689 (343)	total: 14.4s	remaining: 7.03s
672:	learn: 0.1931093	test: 0.4104816	best: 0.3925689 (343)	total: 14.4s	remaining: 7.01s
673:	learn: 0.1926278	test: 0.4108661	best: 0.3925689 (343)	total: 14.5s	remaining: 6.99s
674:	learn: 0.1925318	test: 0.4107259	best: 0.3925689 (343)	total: 14.5s	remaining: 6.97s
675:	learn: 0.1923053	test: 0.4105625	best: 0.3925689 (343)	total: 14.5s	remaining: 6.95s
676:	learn: 0.1922277	test: 0.4105619	best: 0.3925689 (343)	total: 14.5s	remaining: 6.92s
677:	learn: 0.1918859	test: 0.4105366	best: 0.3925689 (343)	total: 14.5s	remaining: 6.9s
678:	learn: 0.1918759	test: 0.4104482	best: 0.3925689 (343)	total: 14.6s	remaining: 6.88s
679:	learn: 0.1917758	test: 0.4104177	best: 0.3925689 (343)	total: 14.6s	remaining: 6.86s
680:	learn:

761:	learn: 0.1763260	test: 0.4161065	best: 0.3925689 (343)	total: 16.4s	remaining: 5.12s
762:	learn: 0.1759086	test: 0.4162623	best: 0.3925689 (343)	total: 16.4s	remaining: 5.1s
763:	learn: 0.1756569	test: 0.4165741	best: 0.3925689 (343)	total: 16.4s	remaining: 5.08s
764:	learn: 0.1756382	test: 0.4165709	best: 0.3925689 (343)	total: 16.5s	remaining: 5.05s
765:	learn: 0.1753264	test: 0.4165993	best: 0.3925689 (343)	total: 16.5s	remaining: 5.03s
766:	learn: 0.1752084	test: 0.4166235	best: 0.3925689 (343)	total: 16.5s	remaining: 5.01s
767:	learn: 0.1749531	test: 0.4170032	best: 0.3925689 (343)	total: 16.5s	remaining: 4.99s
768:	learn: 0.1749129	test: 0.4170263	best: 0.3925689 (343)	total: 16.5s	remaining: 4.97s
769:	learn: 0.1745286	test: 0.4176999	best: 0.3925689 (343)	total: 16.6s	remaining: 4.95s
770:	learn: 0.1745228	test: 0.4177098	best: 0.3925689 (343)	total: 16.6s	remaining: 4.92s
771:	learn: 0.1741362	test: 0.4176875	best: 0.3925689 (343)	total: 16.6s	remaining: 4.91s
772:	learn:

857:	learn: 0.1601547	test: 0.4246835	best: 0.3925689 (343)	total: 18.5s	remaining: 3.06s
858:	learn: 0.1599293	test: 0.4249648	best: 0.3925689 (343)	total: 18.5s	remaining: 3.04s
859:	learn: 0.1597357	test: 0.4250209	best: 0.3925689 (343)	total: 18.5s	remaining: 3.02s
860:	learn: 0.1595634	test: 0.4250858	best: 0.3925689 (343)	total: 18.6s	remaining: 3s
861:	learn: 0.1594357	test: 0.4253323	best: 0.3925689 (343)	total: 18.6s	remaining: 2.97s
862:	learn: 0.1593813	test: 0.4253947	best: 0.3925689 (343)	total: 18.6s	remaining: 2.95s
863:	learn: 0.1593229	test: 0.4253348	best: 0.3925689 (343)	total: 18.6s	remaining: 2.93s
864:	learn: 0.1587776	test: 0.4251685	best: 0.3925689 (343)	total: 18.6s	remaining: 2.91s
865:	learn: 0.1586252	test: 0.4251088	best: 0.3925689 (343)	total: 18.7s	remaining: 2.89s
866:	learn: 0.1586138	test: 0.4251845	best: 0.3925689 (343)	total: 18.7s	remaining: 2.87s
867:	learn: 0.1585144	test: 0.4254559	best: 0.3925689 (343)	total: 18.7s	remaining: 2.85s
868:	learn: 0

956:	learn: 0.1442475	test: 0.4360189	best: 0.3925689 (343)	total: 20.8s	remaining: 933ms
957:	learn: 0.1439146	test: 0.4358642	best: 0.3925689 (343)	total: 20.8s	remaining: 911ms
958:	learn: 0.1436998	test: 0.4358993	best: 0.3925689 (343)	total: 20.8s	remaining: 889ms
959:	learn: 0.1435275	test: 0.4357410	best: 0.3925689 (343)	total: 20.8s	remaining: 867ms
960:	learn: 0.1433444	test: 0.4358970	best: 0.3925689 (343)	total: 20.8s	remaining: 846ms
961:	learn: 0.1431958	test: 0.4362166	best: 0.3925689 (343)	total: 20.9s	remaining: 824ms
962:	learn: 0.1429716	test: 0.4365564	best: 0.3925689 (343)	total: 20.9s	remaining: 802ms
963:	learn: 0.1427362	test: 0.4368190	best: 0.3925689 (343)	total: 20.9s	remaining: 781ms
964:	learn: 0.1426933	test: 0.4367965	best: 0.3925689 (343)	total: 20.9s	remaining: 759ms
965:	learn: 0.1426753	test: 0.4368336	best: 0.3925689 (343)	total: 20.9s	remaining: 737ms
966:	learn: 0.1425322	test: 0.4372336	best: 0.3925689 (343)	total: 21s	remaining: 716ms
967:	learn: 

0.8161434977578476

Сократим модель до первых 344 итераций

In [14]:
model = CatBoostClassifier(
    iterations=344,
    custom_loss=[metrics.Accuracy()],
    random_seed=42,
    logging_level='Silent'
)
model.fit(
    X_train, y_train,
    cat_features=categorical_features_indices,
    eval_set=(X_validation, y_validation),
   
    plot=False
);
model.score(X_validation, y_validation)

0.8295964125560538