In [1]:
import pandas as pd

titanic_data = pd.read_csv('./data/Titanic-Dataset.csv')

In [2]:
titanic_data.columns

Index(['PassengerId', 'Survived', 'Pclass', 'Name', 'Sex', 'Age', 'SibSp',
       'Parch', 'Ticket', 'Fare', 'Cabin', 'Embarked'],
      dtype='object')

In [3]:
titanic_data.head()

Unnamed: 0,PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
0,1,0,3,"Braund, Mr. Owen Harris",male,22.0,1,0,A/5 21171,7.25,,S
1,2,1,1,"Cumings, Mrs. John Bradley (Florence Briggs Th...",female,38.0,1,0,PC 17599,71.2833,C85,C
2,3,1,3,"Heikkinen, Miss. Laina",female,26.0,0,0,STON/O2. 3101282,7.925,,S
3,4,1,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,35.0,1,0,113803,53.1,C123,S
4,5,0,3,"Allen, Mr. William Henry",male,35.0,0,0,373450,8.05,,S


In [4]:
titanic_data.drop(['PassengerId', 'Name', 'Ticket', 'SibSp', 'Parch', 'Cabin', 'Embarked'], axis='columns', inplace=True)

In [5]:
titanic_data.head()

Unnamed: 0,Survived,Pclass,Sex,Age,Fare
0,0,3,male,22.0,7.25
1,1,1,female,38.0,71.2833
2,1,3,female,26.0,7.925
3,1,1,female,35.0,53.1
4,0,3,male,35.0,8.05


In [6]:
target = titanic_data['Survived']
input_data = titanic_data.drop(['Survived'], axis='columns')

In [7]:
dumies = pd.get_dummies(input_data.Sex).astype(int)
dumies.head()

Unnamed: 0,female,male
0,0,1
1,1,0
2,1,0
3,1,0
4,0,1


In [8]:
input_data = pd.concat([input_data.drop('Sex', axis=1), dumies], axis=1)
input_data.head()

Unnamed: 0,Pclass,Age,Fare,female,male
0,3,22.0,7.25,0,1
1,1,38.0,71.2833,1,0
2,3,26.0,7.925,1,0
3,1,35.0,53.1,1,0
4,3,35.0,8.05,0,1


In [9]:
# tell you which columns has na (empty) value in your table
input_data.columns[input_data.isna().any()]

Index(['Age'], dtype='object')

In [10]:
input_data.Age[:10]

0    22.0
1    38.0
2    26.0
3    35.0
4    35.0
5     NaN
6    54.0
7     2.0
8    27.0
9    14.0
Name: Age, dtype: float64

In [11]:
# fill empty cells with mean of that column
input_data.Age = input_data.Age.fillna(input_data.Age.mean())
input_data.head(10)

Unnamed: 0,Pclass,Age,Fare,female,male
0,3,22.0,7.25,0,1
1,1,38.0,71.2833,1,0
2,3,26.0,7.925,1,0
3,1,35.0,53.1,1,0
4,3,35.0,8.05,0,1
5,3,29.699118,8.4583,0,1
6,1,54.0,51.8625,0,1
7,3,2.0,21.075,0,1
8,3,27.0,11.1333,1,0
9,2,14.0,30.0708,1,0


In [29]:
# split our data to training and testing data
from sklearn.model_selection import train_test_split
# train data size = 80 % while testing data size = 20 %
X_train, X_test, y_train, y_test = train_test_split(input_data, target, test_size=0.2)

In [13]:
len(X_train)

712

In [14]:
len(y_test)

179

In [15]:
# import the model from sklearn and then greate
from sklearn.naive_bayes import GaussianNB
model = GaussianNB()

In [30]:
model.fit(X_train, y_train)

In [31]:
model.score(X_test, y_test)

0.776536312849162

In [32]:
X_test[:10]

Unnamed: 0,Pclass,Age,Fare,female,male
315,3,26.0,7.8542,1,0
884,3,25.0,7.05,0,1
838,3,32.0,56.4958,0,1
214,3,29.699118,7.75,0,1
435,1,14.0,120.0,1,0
8,3,27.0,11.1333,1,0
834,3,18.0,8.3,0,1
673,2,31.0,13.0,0,1
766,1,29.699118,39.6,0,1
7,3,2.0,21.075,0,1


In [33]:
y_test[:10]

315    1
884    0
838    1
214    0
435    1
8      1
834    0
673    1
766    0
7      0
Name: Survived, dtype: int64

In [40]:
model.predict(X_test[:10])

array([1, 0, 0, 0, 1, 1, 0, 0, 0, 0])

In [38]:
# from sklearn.metrics import mean_absolute_error
# print(mean_absolute_error(y_test[:10], prediction))

0.2


In [39]:
model.predict_proba(X_test)

array([[8.07715337e-02, 9.19228466e-01],
       [9.90804770e-01, 9.19522961e-03],
       [9.82065743e-01, 1.79342567e-02],
       [9.91316254e-01, 8.68374602e-03],
       [5.50416388e-05, 9.99944958e-01],
       [8.29265892e-02, 9.17073411e-01],
       [9.89732377e-01, 1.02676232e-02],
       [9.80527938e-01, 1.94720616e-02],
       [9.08735836e-01, 9.12641640e-02],
       [9.83935613e-01, 1.60643871e-02],
       [9.91178451e-01, 8.82154932e-03],
       [9.90722977e-01, 9.27702344e-03],
       [9.10741820e-01, 8.92581798e-02],
       [9.91486680e-01, 8.51332024e-03],
       [9.85159997e-01, 1.48400027e-02],
       [6.26172770e-01, 3.73827230e-01],
       [4.97482878e-02, 9.50251712e-01],
       [9.90848491e-01, 9.15150904e-03],
       [9.80346720e-01, 1.96532802e-02],
       [9.91268691e-01, 8.73130930e-03],
       [9.25346505e-01, 7.46534947e-02],
       [3.33901302e-02, 9.66609870e-01],
       [9.87253953e-01, 1.27460473e-02],
       [8.55983656e-02, 9.14401634e-01],
       [3.580612