In [2]:
!pip install -q sklearn

In [3]:
%tensorflow_version 2.x
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import clear_output
from six.moves import urllib
import tensorflow.compat.v2.feature_column as fc

In [4]:
# Locations of the hosted Titanic CSV splits.
TRAIN_CSV_URL = 'https://storage.googleapis.com/tf-datasets/titanic/train.csv'
EVAL_CSV_URL = 'https://storage.googleapis.com/tf-datasets/titanic/eval.csv'

# Read each split into its own DataFrame.
train_dataset = pd.read_csv(TRAIN_CSV_URL)
test_dataset = pd.read_csv(EVAL_CSV_URL)

# pop() removes 'survived' from the frame in place and returns it, so the
# frames keep only feature columns and the popped series become the labels.
y_train = train_dataset.pop('survived')
y_eval = test_dataset.pop('survived')

In [5]:
CATEGORICAL_COLUMNS = ['sex', 'n_siblings_spouses', 'parch', 'class', 'deck', 'embark_town', 'alone']
NUMERICAL_COLUMNS = ['age', 'fare']

# Categorical features: the vocabulary of each column is the set of distinct
# values observed for it in the training frame.
categorical_feature_columns = [
  tf.feature_column.categorical_column_with_vocabulary_list(
    column_name, train_dataset[column_name].unique())
  for column_name in CATEGORICAL_COLUMNS
]

# Numerical features are passed through as float32 values.
numeric_feature_columns = [
  tf.feature_column.numeric_column(column_name, dtype=tf.float32)
  for column_name in NUMERICAL_COLUMNS
]

# Categorical columns first, then numerical ones.
feature_columns = categorical_feature_columns + numeric_feature_columns

print(feature_columns)

[VocabularyListCategoricalColumn(key='sex', vocabulary_list=('male', 'female'), dtype=tf.string, default_value=-1, num_oov_buckets=0), VocabularyListCategoricalColumn(key='n_siblings_spouses', vocabulary_list=(1, 0, 3, 4, 2, 5, 8), dtype=tf.int64, default_value=-1, num_oov_buckets=0), VocabularyListCategoricalColumn(key='parch', vocabulary_list=(0, 1, 2, 5, 3, 4), dtype=tf.int64, default_value=-1, num_oov_buckets=0), VocabularyListCategoricalColumn(key='class', vocabulary_list=('Third', 'First', 'Second'), dtype=tf.string, default_value=-1, num_oov_buckets=0), VocabularyListCategoricalColumn(key='deck', vocabulary_list=('unknown', 'C', 'G', 'A', 'B', 'D', 'F', 'E'), dtype=tf.string, default_value=-1, num_oov_buckets=0), VocabularyListCategoricalColumn(key='embark_town', vocabulary_list=('Southampton', 'Cherbourg', 'Queenstown', 'unknown'), dtype=tf.string, default_value=-1, num_oov_buckets=0), VocabularyListCategoricalColumn(key='alone', vocabulary_list=('n', 'y'), dtype=tf.string, def

In [6]:
# Convert the pandas objects into a tf.data.Dataset.
# make_input_fn returns an input function (a closure); calling that function builds the dataset.
def make_input_fn(data, train_label, num_epochs=10, shuffle=True, batch_size=32,
                  shuffle_buffer_size=1000, seed=None):
  """Return a zero-argument function that builds a batched tf.data.Dataset.

  The dataset is not created here; it is created each time the returned
  function is called.

  Args:
    data: Feature table convertible with dict() (e.g. a pandas DataFrame),
      mapping feature names to columns of values.
    train_label: Labels to pair with the rows of `data`.
    num_epochs: Number of times the batched dataset is repeated.
    shuffle: Whether to shuffle the elements before batching.
    batch_size: Number of elements per batch.
    shuffle_buffer_size: Buffer size handed to Dataset.shuffle(). Defaults to
      1000, the value that was previously hard-coded.
    seed: Optional random seed handed to Dataset.shuffle(); the default None
      keeps the previous unseeded behaviour.

  Returns:
    A callable taking no arguments that returns the tf.data.Dataset of
    (features dict, labels) batches.
  """
  def input_function():
    # Pair the feature dict with the labels and slice them into per-row elements.
    dataset = tf.data.Dataset.from_tensor_slices((dict(data), train_label))
    if shuffle:
      dataset = dataset.shuffle(shuffle_buffer_size, seed=seed)
    # Split into batches first, then repeat the batched dataset num_epochs times.
    return dataset.batch(batch_size).repeat(num_epochs)
  return input_function

# Training feed: uses make_input_fn's defaults for epochs, shuffling and batch size.
train_input_fn = make_input_fn(data=train_dataset, train_label=y_train)

# Evaluation feed: one unshuffled pass, so rows keep the order of test_dataset.
test_input_fn = make_input_fn(
  data=test_dataset,
  train_label=y_eval,
  num_epochs=1,
  shuffle=False,
)

In [7]:
# create the model
linear_est = tf.estimator.LinearClassifier(feature_columns=feature_columns)
# train model
linear_est.train(train_input_fn)
result = linear_est.evaluate(test_input_fn)

clear_output()
print(result['accuracy'])

0.7234849


In [11]:
# Materialise the predictions into a list so individual entries can be indexed.
result = list(linear_est.predict(input_fn=test_input_fn))

# Show one passenger's features next to the predicted probability at index 1
# of 'probabilities', reported here as the chance of survival.
passenger_index = 0
print(test_dataset.loc[passenger_index])
print('Chances of survival', result[passenger_index]['probabilities'][1])

INFO:tensorflow:Calling model_fn.




INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp0nw5mrnt/model.ckpt-200
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
sex                          male
age                            35
n_siblings_spouses              0
parch                           0
fare                         8.05
class                       Third
deck                      unknown
embark_town           Southampton
alone                           y
Name: 0, dtype: object
Chances of survival 0.03671655
