# Method 1: logistic regression with sklearn

In [1]:
import pandas as pd

In [2]:
# Load the competition train/test CSVs from the working directory.
data = pd.read_csv('train.csv')
test = pd.read_csv('test.csv')
# Save the ids now: clean() drops "PassengerId", but the submission file needs it.
test_ids = test["PassengerId"] # for Kaggle submission

# Preview the first five rows of the raw training frame.
data.head(5)

Unnamed: 0,PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
0,1,0,3,"Braund, Mr. Owen Harris",male,22.0,1,0,A/5 21171,7.25,,S
1,2,1,1,"Cumings, Mrs. John Bradley (Florence Briggs Th...",female,38.0,1,0,PC 17599,71.2833,C85,C
2,3,1,3,"Heikkinen, Miss. Laina",female,26.0,0,0,STON/O2. 3101282,7.925,,S
3,4,1,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,35.0,1,0,113803,53.1,C123,S
4,5,0,3,"Allen, Mr. William Henry",male,35.0,0,0,373450,8.05,,S


In [3]:
def clean(data: pd.DataFrame) -> pd.DataFrame:
    """Drop unused columns and fill in missing values.

    Parameters
    ----------
    data : pd.DataFrame
        Frame containing the columns "Ticket", "Name", "Cabin",
        "PassengerId", "SibSp", "Parch", "Fare", "Age" and "Embarked".

    Returns
    -------
    pd.DataFrame
        A new frame without "Ticket", "Name", "Cabin" and "PassengerId",
        where missing numeric values are replaced by the column median of
        the frame that was passed in and missing "Embarked" values are
        replaced by "U". The input frame is not modified.

    Raises
    ------
    KeyError
        If one of the columns listed above is missing.
    """
    data = data.drop(["Ticket", "Name", "Cabin", "PassengerId"], axis=1)

    cols = ["SibSp", "Parch", "Fare", "Age"]
    # SibSp - siblings / spouses aboard
    # Parch (ParentChildren) - parents / children aboard
    # Fare - ticket price
    # Age - age
    for col in cols:
        # Assign the result back instead of calling fillna(inplace=True) on
        # the column selection: the chained in-place form acts on a temporary
        # object under pandas Copy-on-Write and would leave the NaNs in place.
        data[col] = data[col].fillna(data[col].median())

    # "U" marks an unknown port of embarkation.
    data["Embarked"] = data["Embarked"].fillna("U")

    return data

# Apply the same cleaning to both frames, rebinding the names to the results.
# NOTE: this cell is not re-runnable on its own — clean() drops columns that
# are already gone on a second pass (KeyError); reload the CSVs first.
data = clean(data)
test = clean(test)

## Data cleaning: encoding categorical features

In [4]:
from sklearn import preprocessing
le = preprocessing.LabelEncoder()

# String-valued columns that are converted to integer codes.
cols = ["Sex", "Embarked"]

for col in cols:
    # Fit the mapping on the training column only, then reuse that mapping on
    # the test column so both frames share the same codes.
    data[col] = le.fit_transform(data[col])
    test[col] = le.transform(test[col])
    # Print the label order; a label's position in this array is its code.
    print(le.classes_)

# The single encoder is refitted on every pass, so after the loop `le` only
# holds the mapping of the last column in `cols`.
data.head(5)

['female' 'male']
['C' 'Q' 'S' 'U']


Unnamed: 0,Survived,Pclass,Sex,Age,SibSp,Parch,Fare,Embarked
0,0,3,1,22.0,1,0,7.25,2
1,1,1,0,38.0,1,0,71.2833,0
2,1,3,0,26.0,0,0,7.925,2
3,1,1,0,35.0,1,0,53.1,2
4,0,3,1,35.0,0,0,8.05,2


In [5]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Target vector and feature matrix (everything except the target column).
y = data["Survived"]
X = data.drop("Survived", axis=1)

# Hold out 20% of the rows for validation; the fixed seed keeps the split repeatable.
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

In [6]:
# Fit logistic regression on the training split.
# NOTE(review): max_iter=1000 is presumably there so the solver converges on
# the unscaled features — confirm.
clf = LogisticRegression(random_state=0, max_iter=1000).fit(X_train, y_train)

In [7]:
# Predict on the held-out validation split and report plain accuracy.
predictions = clf.predict(X_val)
from sklearn.metrics import accuracy_score
print(accuracy_score(y_val, predictions))

0.8100558659217877


In [8]:
# Predict on the cleaned and label-encoded test frame.
submission_preds = clf.predict(test)

In [9]:
# Two-column submission frame: the saved passenger ids next to the predicted labels.
df = pd.DataFrame({"PassengerId": test_ids.values,
                  "Survived": submission_preds})

In [10]:
# Write the submission file without the DataFrame index column.
df.to_csv("submission.csv", index=False)

# Method 2: CatBoost

In [68]:
import numpy as np
import catboost
from catboost import Pool, CatBoost
from catboost import CatBoostClassifier, Pool, cv
from sklearn.metrics import accuracy_score

In [72]:
# Re-read the raw CSVs under new names for this method.
train_df = pd.read_csv('train.csv')
test_df = pd.read_csv('test.csv')

# Passenger ids for the submission file (rebinds the name assigned earlier in the notebook).
test_ids = test_df["PassengerId"]

# Count missing values per column and display only the columns that have any.
null_value_stats = train_df.isnull().sum(axis=0)
null_value_stats[null_value_stats != 0]

Age         177
Cabin       687
Embarked      2
dtype: int64

In [73]:
# Replace every missing value with the sentinel -999, mutating both frames in place.
for frame in (train_df, test_df):
    frame.fillna(-999, inplace=True)

In [74]:
# Target vector and feature matrix (rebinds `y` and `X` from the first method).
y = train_df["Survived"]
X = train_df.drop("Survived", axis=1)

In [75]:
# Show each feature's dtype so the categorical/numeric split below can be checked.
print(X.dtypes)
# Positions of all columns that are not float64; these are passed to CatBoost
# as categorical features. Integer columns (including PassengerId) are
# therefore treated as categorical as well.
# `np.float` was only an alias of the builtin `float` and has been removed
# from NumPy, so compare against `float` directly (same result).
categorical_features_indices = np.where(X.dtypes != float)[0]

PassengerId      int64
Pclass           int64
Name            object
Sex             object
Age            float64
SibSp            int64
Parch            int64
Ticket          object
Fare           float64
Cabin           object
Embarked        object
dtype: object


In [76]:
# 75/25 train/validation split with a fixed seed.
# Relies on train_test_split imported in an earlier cell of this notebook.
X_train, X_validation, y_train, y_validation = train_test_split(X, y, train_size=0.75, random_state=42)
# Alias only, no copy: X_test and test_df are the same object.
X_test = test_df

In [77]:
# Classifier configuration; nothing is trained in this cell.
model = CatBoostClassifier(
    custom_loss=['Accuracy'],  # accuracy metric; read later from the cv results
    random_seed=42,
    logging_level='Silent'  # quiet by default; the fit call passes its own logging_level
)

In [78]:
# Train on the training split while evaluating on the validation split.
model.fit(
    X_train, y_train,
    cat_features=categorical_features_indices,
    eval_set=(X_validation, y_validation),
    logging_level='Verbose',  # overrides the constructor's 'Silent': prints one metrics line per iteration
    plot=True  # also show the interactive metric chart
);

MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))

Learning rate set to 0.028683
0:	learn: 0.6739988	test: 0.6742630	best: 0.6742630 (0)	total: 3.93ms	remaining: 3.93s
1:	learn: 0.6589013	test: 0.6592240	best: 0.6592240 (1)	total: 5.44ms	remaining: 2.71s
2:	learn: 0.6421502	test: 0.6426778	best: 0.6426778 (2)	total: 7.97ms	remaining: 2.65s
3:	learn: 0.6297276	test: 0.6302310	best: 0.6302310 (3)	total: 10ms	remaining: 2.5s
4:	learn: 0.6147184	test: 0.6198228	best: 0.6198228 (4)	total: 14ms	remaining: 2.78s
5:	learn: 0.6017730	test: 0.6073627	best: 0.6073627 (5)	total: 16.3ms	remaining: 2.7s
6:	learn: 0.5885309	test: 0.5956000	best: 0.5956000 (6)	total: 18.7ms	remaining: 2.65s
7:	learn: 0.5783200	test: 0.5858523	best: 0.5858523 (7)	total: 21.1ms	remaining: 2.62s
8:	learn: 0.5665895	test: 0.5743842	best: 0.5743842 (8)	total: 23.4ms	remaining: 2.57s
9:	learn: 0.5575381	test: 0.5662283	best: 0.5662283 (9)	total: 25.9ms	remaining: 2.56s
10:	learn: 0.5491045	test: 0.5575176	best: 0.5575176 (10)	total: 28.5ms	remaining: 2.56s
11:	learn: 0.5423

140:	learn: 0.3508675	test: 0.4030210	best: 0.4030210 (140)	total: 362ms	remaining: 2.2s
141:	learn: 0.3499321	test: 0.4032275	best: 0.4030210 (140)	total: 365ms	remaining: 2.21s
142:	learn: 0.3498985	test: 0.4032541	best: 0.4030210 (140)	total: 367ms	remaining: 2.2s
143:	learn: 0.3492545	test: 0.4027983	best: 0.4027983 (143)	total: 370ms	remaining: 2.2s
144:	learn: 0.3489518	test: 0.4024768	best: 0.4024768 (144)	total: 372ms	remaining: 2.19s
145:	learn: 0.3485265	test: 0.4021110	best: 0.4021110 (145)	total: 375ms	remaining: 2.19s
146:	learn: 0.3481955	test: 0.4021177	best: 0.4021110 (145)	total: 377ms	remaining: 2.19s
147:	learn: 0.3472829	test: 0.4017109	best: 0.4017109 (147)	total: 380ms	remaining: 2.19s
148:	learn: 0.3461772	test: 0.4010244	best: 0.4010244 (148)	total: 382ms	remaining: 2.18s
149:	learn: 0.3456412	test: 0.4011647	best: 0.4010244 (148)	total: 385ms	remaining: 2.18s
150:	learn: 0.3448945	test: 0.4008998	best: 0.4008998 (150)	total: 387ms	remaining: 2.18s
151:	learn: 0

282:	learn: 0.3018641	test: 0.3936245	best: 0.3935022 (280)	total: 728ms	remaining: 1.84s
283:	learn: 0.3014807	test: 0.3934154	best: 0.3934154 (283)	total: 731ms	remaining: 1.84s
284:	learn: 0.3012909	test: 0.3944429	best: 0.3934154 (283)	total: 734ms	remaining: 1.84s
285:	learn: 0.3012585	test: 0.3944373	best: 0.3934154 (283)	total: 737ms	remaining: 1.84s
286:	learn: 0.3011758	test: 0.3944605	best: 0.3934154 (283)	total: 739ms	remaining: 1.84s
287:	learn: 0.3007872	test: 0.3944593	best: 0.3934154 (283)	total: 742ms	remaining: 1.83s
288:	learn: 0.3004516	test: 0.3943732	best: 0.3934154 (283)	total: 745ms	remaining: 1.83s
289:	learn: 0.3003341	test: 0.3944369	best: 0.3934154 (283)	total: 747ms	remaining: 1.83s
290:	learn: 0.3001110	test: 0.3947636	best: 0.3934154 (283)	total: 750ms	remaining: 1.83s
291:	learn: 0.3000496	test: 0.3947832	best: 0.3934154 (283)	total: 753ms	remaining: 1.82s
292:	learn: 0.2995601	test: 0.3947448	best: 0.3934154 (283)	total: 757ms	remaining: 1.83s
293:	learn

394:	learn: 0.2608314	test: 0.3947310	best: 0.3925689 (343)	total: 1.1s	remaining: 1.68s
395:	learn: 0.2603953	test: 0.3945539	best: 0.3925689 (343)	total: 1.1s	remaining: 1.68s
396:	learn: 0.2602947	test: 0.3946812	best: 0.3925689 (343)	total: 1.1s	remaining: 1.68s
397:	learn: 0.2597627	test: 0.3947332	best: 0.3925689 (343)	total: 1.11s	remaining: 1.67s
398:	learn: 0.2596066	test: 0.3945786	best: 0.3925689 (343)	total: 1.11s	remaining: 1.67s
399:	learn: 0.2595424	test: 0.3945935	best: 0.3925689 (343)	total: 1.11s	remaining: 1.67s
400:	learn: 0.2594838	test: 0.3946853	best: 0.3925689 (343)	total: 1.11s	remaining: 1.67s
401:	learn: 0.2593403	test: 0.3944686	best: 0.3925689 (343)	total: 1.12s	remaining: 1.66s
402:	learn: 0.2591148	test: 0.3949242	best: 0.3925689 (343)	total: 1.12s	remaining: 1.66s
403:	learn: 0.2589806	test: 0.3949932	best: 0.3925689 (343)	total: 1.13s	remaining: 1.66s
404:	learn: 0.2588899	test: 0.3949892	best: 0.3925689 (343)	total: 1.13s	remaining: 1.66s
405:	learn: 0

513:	learn: 0.2263631	test: 0.4002572	best: 0.3925689 (343)	total: 1.47s	remaining: 1.39s
514:	learn: 0.2261492	test: 0.4003810	best: 0.3925689 (343)	total: 1.47s	remaining: 1.38s
515:	learn: 0.2257930	test: 0.4004928	best: 0.3925689 (343)	total: 1.47s	remaining: 1.38s
516:	learn: 0.2256643	test: 0.4004825	best: 0.3925689 (343)	total: 1.47s	remaining: 1.38s
517:	learn: 0.2251149	test: 0.4014834	best: 0.3925689 (343)	total: 1.48s	remaining: 1.38s
518:	learn: 0.2244046	test: 0.4020700	best: 0.3925689 (343)	total: 1.48s	remaining: 1.37s
519:	learn: 0.2239187	test: 0.4024912	best: 0.3925689 (343)	total: 1.48s	remaining: 1.37s
520:	learn: 0.2236435	test: 0.4025870	best: 0.3925689 (343)	total: 1.49s	remaining: 1.37s
521:	learn: 0.2233773	test: 0.4025119	best: 0.3925689 (343)	total: 1.49s	remaining: 1.36s
522:	learn: 0.2231432	test: 0.4025133	best: 0.3925689 (343)	total: 1.49s	remaining: 1.36s
523:	learn: 0.2228520	test: 0.4025290	best: 0.3925689 (343)	total: 1.5s	remaining: 1.36s
524:	learn:

630:	learn: 0.2028403	test: 0.4066143	best: 0.3925689 (343)	total: 1.84s	remaining: 1.08s
631:	learn: 0.2027259	test: 0.4063462	best: 0.3925689 (343)	total: 1.85s	remaining: 1.07s
632:	learn: 0.2025504	test: 0.4065875	best: 0.3925689 (343)	total: 1.85s	remaining: 1.07s
633:	learn: 0.2019822	test: 0.4068693	best: 0.3925689 (343)	total: 1.85s	remaining: 1.07s
634:	learn: 0.2013882	test: 0.4071221	best: 0.3925689 (343)	total: 1.86s	remaining: 1.07s
635:	learn: 0.2013623	test: 0.4071165	best: 0.3925689 (343)	total: 1.86s	remaining: 1.06s
636:	learn: 0.2008415	test: 0.4074146	best: 0.3925689 (343)	total: 1.86s	remaining: 1.06s
637:	learn: 0.2003874	test: 0.4077206	best: 0.3925689 (343)	total: 1.86s	remaining: 1.06s
638:	learn: 0.2000530	test: 0.4079318	best: 0.3925689 (343)	total: 1.87s	remaining: 1.05s
639:	learn: 0.1999541	test: 0.4079294	best: 0.3925689 (343)	total: 1.87s	remaining: 1.05s
640:	learn: 0.1999227	test: 0.4078843	best: 0.3925689 (343)	total: 1.88s	remaining: 1.05s
641:	learn

736:	learn: 0.1800339	test: 0.4145284	best: 0.3925689 (343)	total: 2.21s	remaining: 790ms
737:	learn: 0.1800265	test: 0.4145411	best: 0.3925689 (343)	total: 2.22s	remaining: 787ms
738:	learn: 0.1797252	test: 0.4149555	best: 0.3925689 (343)	total: 2.22s	remaining: 784ms
739:	learn: 0.1796243	test: 0.4149001	best: 0.3925689 (343)	total: 2.22s	remaining: 782ms
740:	learn: 0.1795859	test: 0.4145753	best: 0.3925689 (343)	total: 2.23s	remaining: 778ms
741:	learn: 0.1795069	test: 0.4147153	best: 0.3925689 (343)	total: 2.23s	remaining: 775ms
742:	learn: 0.1794997	test: 0.4146924	best: 0.3925689 (343)	total: 2.23s	remaining: 772ms
743:	learn: 0.1793166	test: 0.4146788	best: 0.3925689 (343)	total: 2.24s	remaining: 770ms
744:	learn: 0.1792690	test: 0.4147515	best: 0.3925689 (343)	total: 2.24s	remaining: 767ms
745:	learn: 0.1792612	test: 0.4146457	best: 0.3925689 (343)	total: 2.24s	remaining: 763ms
746:	learn: 0.1792506	test: 0.4147124	best: 0.3925689 (343)	total: 2.24s	remaining: 760ms
747:	learn

844:	learn: 0.1621698	test: 0.4227055	best: 0.3925689 (343)	total: 2.59s	remaining: 475ms
845:	learn: 0.1621497	test: 0.4228014	best: 0.3925689 (343)	total: 2.59s	remaining: 472ms
846:	learn: 0.1617727	test: 0.4230427	best: 0.3925689 (343)	total: 2.59s	remaining: 469ms
847:	learn: 0.1617085	test: 0.4230407	best: 0.3925689 (343)	total: 2.6s	remaining: 466ms
848:	learn: 0.1613583	test: 0.4238683	best: 0.3925689 (343)	total: 2.6s	remaining: 463ms
849:	learn: 0.1610413	test: 0.4239924	best: 0.3925689 (343)	total: 2.6s	remaining: 459ms
850:	learn: 0.1609976	test: 0.4240652	best: 0.3925689 (343)	total: 2.61s	remaining: 456ms
851:	learn: 0.1609287	test: 0.4241010	best: 0.3925689 (343)	total: 2.61s	remaining: 453ms
852:	learn: 0.1609168	test: 0.4241438	best: 0.3925689 (343)	total: 2.61s	remaining: 450ms
853:	learn: 0.1607448	test: 0.4243380	best: 0.3925689 (343)	total: 2.61s	remaining: 447ms
854:	learn: 0.1606261	test: 0.4244428	best: 0.3925689 (343)	total: 2.62s	remaining: 444ms
855:	learn: 0

952:	learn: 0.1447868	test: 0.4348611	best: 0.3925689 (343)	total: 2.98s	remaining: 147ms
953:	learn: 0.1447467	test: 0.4347898	best: 0.3925689 (343)	total: 2.98s	remaining: 144ms
954:	learn: 0.1446053	test: 0.4350135	best: 0.3925689 (343)	total: 2.98s	remaining: 141ms
955:	learn: 0.1444571	test: 0.4359680	best: 0.3925689 (343)	total: 2.98s	remaining: 137ms
956:	learn: 0.1442475	test: 0.4360189	best: 0.3925689 (343)	total: 2.99s	remaining: 134ms
957:	learn: 0.1439146	test: 0.4358642	best: 0.3925689 (343)	total: 2.99s	remaining: 131ms
958:	learn: 0.1436998	test: 0.4358993	best: 0.3925689 (343)	total: 2.99s	remaining: 128ms
959:	learn: 0.1435275	test: 0.4357410	best: 0.3925689 (343)	total: 3s	remaining: 125ms
960:	learn: 0.1433444	test: 0.4358970	best: 0.3925689 (343)	total: 3s	remaining: 122ms
961:	learn: 0.1431958	test: 0.4362166	best: 0.3925689 (343)	total: 3s	remaining: 119ms
962:	learn: 0.1429716	test: 0.4365564	best: 0.3925689 (343)	total: 3.01s	remaining: 116ms
963:	learn: 0.14273

In [79]:
# Reuse the classifier's parameters for cross-validation, with the loss set
# explicitly to Logloss.
cv_params = model.get_params()
cv_params['loss_function'] = 'Logloss'

# Cross-validate on the full labelled data set, drawing the metric chart.
cv_pool = Pool(X, y, cat_features=categorical_features_indices)
cv_data = cv(cv_pool, cv_params, plot=True)

MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))

In [80]:
# Locate the iteration with the highest mean validation accuracy once, then
# report that accuracy, its standard deviation and the iteration index.
mean_accuracy = cv_data['test-Accuracy-mean']
best_step = np.argmax(mean_accuracy)
best_message = 'Best validation accuracy score: {:.2f}±{:.2f} on step {}'.format(
    np.max(mean_accuracy),
    cv_data['test-Accuracy-std'][best_step],
    best_step
)
print(best_message)

Best validation accuracy score: 0.83±0.02 on step 355


In [81]:
# Same best mean validation accuracy as above, printed without rounding.
print('Precise validation accuracy score: {}'.format(np.max(cv_data['test-Accuracy-mean'])))


Precise validation accuracy score: 0.8294051627384961


In [82]:
# Class labels and per-class probabilities for the test frame.
# `predictions` rebinds the name used for the first method's validation predictions.
predictions = model.predict(X_test)
predictions_probs = model.predict_proba(X_test)

In [83]:
# Preview the first ten predicted labels and their probability rows.
print(predictions[:10])
print(predictions_probs[:10])

[0 0 0 0 1 0 1 0 1 0]
[[0.85473931 0.14526069]
 [0.76313031 0.23686969]
 [0.88972889 0.11027111]
 [0.87876173 0.12123827]
 [0.3611047  0.6388953 ]
 [0.90513381 0.09486619]
 [0.33434185 0.66565815]
 [0.78468564 0.21531436]
 [0.39429048 0.60570952]
 [0.94047549 0.05952451]]


In [85]:
# Labels for the submission; test_df is the same object as X_test, so this
# repeats the predict call made for `predictions`.
submission_preds = model.predict(test_df)

In [86]:
# Two-column submission frame (rebinds `df` from the first method).
df = pd.DataFrame({"PassengerId": test_ids.values,
                  "Survived": submission_preds})

In [88]:
# Write the CatBoost submission to its own file, without the index column.
df.to_csv("submission_catboost.csv", index=False)