# MNIST digit recognizer trained with CatBoost

NOTE: the environment for this notebook had CatBoost and widgets installed with Anaconda, following the instructions at:
https://catboost.ai/docs/installation/python-installation-method-conda-install.html#python-installation-method-conda-install

In [1]:
import numpy as np
import os
import pandas as pd
from pathlib import Path
import catboost
from catboost import CatBoostClassifier, Pool
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

In [2]:
# Sanity check: show which files are present in the dataset directory.
print(os.listdir(path='../Datasets/digit-recognizer'))

['train.csv', 'test.csv', 'sample_submission.csv']


In [3]:
# Directory the dataset CSV files are read from.
in_path = Path('../Datasets/digit-recognizer')
# Directory reserved for files produced by this notebook.
out_path = Path('./digit-recognizer-output')

In [4]:
# Load the labelled training data and preview its first two rows.
df = pd.read_csv(in_path / 'train.csv')
df.head(2)

Unnamed: 0,label,pixel0,pixel1,pixel2,pixel3,pixel4,pixel5,pixel6,pixel7,pixel8,...,pixel774,pixel775,pixel776,pixel777,pixel778,pixel779,pixel780,pixel781,pixel782,pixel783
0,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


The data consist of a first column named `label`, which contains the digit shown in the image (an integer between 0 and 9), followed by 784 pixel columns (`pixel0` to `pixel783`, one column per pixel). Each pixel column holds a greyscale brightness value (0-255), and together the 784 values of a row make up one 28x28 image.

In [5]:
# Target vector: the digit label of every row.
Y = df.loc[:, 'label']

In [6]:
# Feature matrix: every column except the label.
X = df.drop(columns='label')

In [7]:
# Hold out 33% of the rows for evaluation; the fixed seed keeps the split repeatable.
seed, test_size = 7, 0.33
X_train, X_test, y_train, y_test = train_test_split(
    X,
    Y,
    random_state=seed,
    test_size=test_size,
)

In [8]:
# NOTE(review): disabled experiment, kept for reference only. It would build a
# catboost.Pool that declares every column of X_train as a categorical feature.
# The columns hold pixel brightness values (0-255), so presumably this was dropped
# in favour of fitting on the numeric DataFrame directly -- confirm before
# re-enabling, or delete this cell.
# train_pool = catboost.Pool(X_train, y_train, cat_features=list(X_train.columns))

# Training the model using CatBoost

In [46]:
# Configure the CatBoost classifier; it is not trained until fit() is called.
model = CatBoostClassifier(
    task_type='GPU',     # train on the GPU
    iterations=1000,     # number of boosting iterations
    depth=10,            # tree depth
    learning_rate=0.05,
    l2_leaf_reg=1,       # L2 regularisation coefficient
)

In [47]:
%%time
model.fit(X_train, y_train)

0:	learn: 2.0788996	total: 336ms	remaining: 5m 35s
1:	learn: 1.9039396	total: 635ms	remaining: 5m 17s
2:	learn: 1.7636899	total: 941ms	remaining: 5m 12s
3:	learn: 1.6438856	total: 1.24s	remaining: 5m 9s
4:	learn: 1.5446012	total: 1.55s	remaining: 5m 8s
5:	learn: 1.4548178	total: 1.85s	remaining: 5m 7s
6:	learn: 1.3742998	total: 2.15s	remaining: 5m 5s
7:	learn: 1.3047956	total: 2.45s	remaining: 5m 3s
8:	learn: 1.2425311	total: 2.74s	remaining: 5m 1s
9:	learn: 1.1854025	total: 3.04s	remaining: 5m 1s
10:	learn: 1.1305816	total: 3.34s	remaining: 5m
11:	learn: 1.0817588	total: 3.64s	remaining: 4m 59s
12:	learn: 1.0361003	total: 3.93s	remaining: 4m 58s
13:	learn: 0.9944314	total: 4.23s	remaining: 4m 57s
14:	learn: 0.9559811	total: 4.53s	remaining: 4m 57s
15:	learn: 0.9194600	total: 4.83s	remaining: 4m 57s
16:	learn: 0.8839041	total: 5.13s	remaining: 4m 56s
17:	learn: 0.8509737	total: 5.42s	remaining: 4m 55s
18:	learn: 0.8191986	total: 5.71s	remaining: 4m 55s
19:	learn: 0.7897035	total: 6s	re

159:	learn: 0.0834493	total: 41.6s	remaining: 3m 38s
160:	learn: 0.0829280	total: 41.8s	remaining: 3m 37s
161:	learn: 0.0825100	total: 42s	remaining: 3m 37s
162:	learn: 0.0820145	total: 42.2s	remaining: 3m 36s
163:	learn: 0.0813746	total: 42.4s	remaining: 3m 36s
164:	learn: 0.0809077	total: 42.6s	remaining: 3m 35s
165:	learn: 0.0803439	total: 42.8s	remaining: 3m 35s
166:	learn: 0.0799488	total: 43s	remaining: 3m 34s
167:	learn: 0.0793732	total: 43.2s	remaining: 3m 33s
168:	learn: 0.0786358	total: 43.4s	remaining: 3m 33s
169:	learn: 0.0779952	total: 43.6s	remaining: 3m 32s
170:	learn: 0.0775802	total: 43.8s	remaining: 3m 32s
171:	learn: 0.0769596	total: 44s	remaining: 3m 31s
172:	learn: 0.0765421	total: 44.2s	remaining: 3m 31s
173:	learn: 0.0760409	total: 44.4s	remaining: 3m 30s
174:	learn: 0.0755929	total: 44.5s	remaining: 3m 29s
175:	learn: 0.0751511	total: 44.7s	remaining: 3m 29s
176:	learn: 0.0746379	total: 44.9s	remaining: 3m 28s
177:	learn: 0.0741626	total: 45.2s	remaining: 3m 28s

315:	learn: 0.0404662	total: 1m 9s	remaining: 2m 31s
316:	learn: 0.0403623	total: 1m 9s	remaining: 2m 30s
317:	learn: 0.0402073	total: 1m 10s	remaining: 2m 30s
318:	learn: 0.0401036	total: 1m 10s	remaining: 2m 30s
319:	learn: 0.0400119	total: 1m 10s	remaining: 2m 29s
320:	learn: 0.0399000	total: 1m 10s	remaining: 2m 29s
321:	learn: 0.0397909	total: 1m 10s	remaining: 2m 29s
322:	learn: 0.0396430	total: 1m 10s	remaining: 2m 28s
323:	learn: 0.0394558	total: 1m 11s	remaining: 2m 28s
324:	learn: 0.0393550	total: 1m 11s	remaining: 2m 28s
325:	learn: 0.0392359	total: 1m 11s	remaining: 2m 27s
326:	learn: 0.0390516	total: 1m 11s	remaining: 2m 27s
327:	learn: 0.0388850	total: 1m 11s	remaining: 2m 27s
328:	learn: 0.0387552	total: 1m 12s	remaining: 2m 26s
329:	learn: 0.0386828	total: 1m 12s	remaining: 2m 26s
330:	learn: 0.0385796	total: 1m 12s	remaining: 2m 26s
331:	learn: 0.0383390	total: 1m 12s	remaining: 2m 25s
332:	learn: 0.0382889	total: 1m 12s	remaining: 2m 25s
333:	learn: 0.0381781	total: 1

468:	learn: 0.0271715	total: 1m 35s	remaining: 1m 48s
469:	learn: 0.0270958	total: 1m 35s	remaining: 1m 48s
470:	learn: 0.0270349	total: 1m 36s	remaining: 1m 47s
471:	learn: 0.0269763	total: 1m 36s	remaining: 1m 47s
472:	learn: 0.0269523	total: 1m 36s	remaining: 1m 47s
473:	learn: 0.0268913	total: 1m 36s	remaining: 1m 47s
474:	learn: 0.0268202	total: 1m 36s	remaining: 1m 46s
475:	learn: 0.0267449	total: 1m 36s	remaining: 1m 46s
476:	learn: 0.0266844	total: 1m 37s	remaining: 1m 46s
477:	learn: 0.0266163	total: 1m 37s	remaining: 1m 46s
478:	learn: 0.0265214	total: 1m 37s	remaining: 1m 46s
479:	learn: 0.0264715	total: 1m 37s	remaining: 1m 45s
480:	learn: 0.0264345	total: 1m 37s	remaining: 1m 45s
481:	learn: 0.0263712	total: 1m 37s	remaining: 1m 45s
482:	learn: 0.0263265	total: 1m 38s	remaining: 1m 44s
483:	learn: 0.0262348	total: 1m 38s	remaining: 1m 44s
484:	learn: 0.0262104	total: 1m 38s	remaining: 1m 44s
485:	learn: 0.0260876	total: 1m 38s	remaining: 1m 44s
486:	learn: 0.0259850	total:

621:	learn: 0.0193452	total: 2m 1s	remaining: 1m 13s
622:	learn: 0.0192974	total: 2m 1s	remaining: 1m 13s
623:	learn: 0.0192534	total: 2m 2s	remaining: 1m 13s
624:	learn: 0.0192338	total: 2m 2s	remaining: 1m 13s
625:	learn: 0.0191899	total: 2m 2s	remaining: 1m 13s
626:	learn: 0.0191490	total: 2m 2s	remaining: 1m 12s
627:	learn: 0.0191255	total: 2m 2s	remaining: 1m 12s
628:	learn: 0.0190992	total: 2m 2s	remaining: 1m 12s
629:	learn: 0.0190198	total: 2m 3s	remaining: 1m 12s
630:	learn: 0.0189878	total: 2m 3s	remaining: 1m 12s
631:	learn: 0.0189233	total: 2m 3s	remaining: 1m 11s
632:	learn: 0.0188728	total: 2m 3s	remaining: 1m 11s
633:	learn: 0.0188296	total: 2m 3s	remaining: 1m 11s
634:	learn: 0.0187884	total: 2m 4s	remaining: 1m 11s
635:	learn: 0.0187705	total: 2m 4s	remaining: 1m 11s
636:	learn: 0.0187286	total: 2m 4s	remaining: 1m 10s
637:	learn: 0.0186898	total: 2m 4s	remaining: 1m 10s
638:	learn: 0.0186706	total: 2m 4s	remaining: 1m 10s
639:	learn: 0.0186365	total: 2m 4s	remaining: 

777:	learn: 0.0145896	total: 2m 28s	remaining: 42.3s
778:	learn: 0.0145735	total: 2m 28s	remaining: 42.1s
779:	learn: 0.0145482	total: 2m 28s	remaining: 41.9s
780:	learn: 0.0145129	total: 2m 28s	remaining: 41.7s
781:	learn: 0.0144904	total: 2m 28s	remaining: 41.5s
782:	learn: 0.0144675	total: 2m 28s	remaining: 41.3s
783:	learn: 0.0144432	total: 2m 29s	remaining: 41.1s
784:	learn: 0.0144217	total: 2m 29s	remaining: 40.9s
785:	learn: 0.0144038	total: 2m 29s	remaining: 40.7s
786:	learn: 0.0143889	total: 2m 29s	remaining: 40.5s
787:	learn: 0.0143751	total: 2m 29s	remaining: 40.3s
788:	learn: 0.0143484	total: 2m 29s	remaining: 40.1s
789:	learn: 0.0143321	total: 2m 30s	remaining: 39.9s
790:	learn: 0.0143067	total: 2m 30s	remaining: 39.7s
791:	learn: 0.0142774	total: 2m 30s	remaining: 39.5s
792:	learn: 0.0142548	total: 2m 30s	remaining: 39.3s
793:	learn: 0.0142369	total: 2m 30s	remaining: 39.1s
794:	learn: 0.0141946	total: 2m 30s	remaining: 38.9s
795:	learn: 0.0141682	total: 2m 31s	remaining:

933:	learn: 0.0114071	total: 2m 54s	remaining: 12.3s
934:	learn: 0.0113913	total: 2m 54s	remaining: 12.1s
935:	learn: 0.0113736	total: 2m 54s	remaining: 11.9s
936:	learn: 0.0113447	total: 2m 54s	remaining: 11.8s
937:	learn: 0.0113280	total: 2m 55s	remaining: 11.6s
938:	learn: 0.0113149	total: 2m 55s	remaining: 11.4s
939:	learn: 0.0112864	total: 2m 55s	remaining: 11.2s
940:	learn: 0.0112767	total: 2m 55s	remaining: 11s
941:	learn: 0.0112533	total: 2m 55s	remaining: 10.8s
942:	learn: 0.0112279	total: 2m 55s	remaining: 10.6s
943:	learn: 0.0112024	total: 2m 56s	remaining: 10.4s
944:	learn: 0.0111943	total: 2m 56s	remaining: 10.3s
945:	learn: 0.0111786	total: 2m 56s	remaining: 10.1s
946:	learn: 0.0111665	total: 2m 56s	remaining: 9.88s
947:	learn: 0.0111418	total: 2m 56s	remaining: 9.7s
948:	learn: 0.0111273	total: 2m 56s	remaining: 9.51s
949:	learn: 0.0111119	total: 2m 57s	remaining: 9.32s
950:	learn: 0.0110854	total: 2m 57s	remaining: 9.14s
951:	learn: 0.0110601	total: 2m 57s	remaining: 8.

<catboost.core.CatBoostClassifier at 0x7fcaf5cf5650>

In [48]:
# Make predictions for the held-out test data.
y_pred = model.predict(X_test)
# Flatten the predictions to a 1-D integer array in one vectorised step.
# The original indexed value[0] on every row, so predict() presumably returns
# one label per row as a column vector; np.ravel also copes with an
# already-flat result. np.rint rounds half-to-even, like the built-in round().
predictions = np.rint(np.ravel(y_pred)).astype(int)

In [49]:
# Score the predictions against the held-out labels and report as a percentage.
accuracy = accuracy_score(y_true=y_test, y_pred=predictions)
print("Accuracy: %.2f%%" % (100.0 * accuracy,))

Accuracy: 97.59%
