# Exploratory data analysis

In [59]:
# imports
# pandas / numpy for loading and selecting data, matplotlib / seaborn for the plots below.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# NOTE(review): comparatorTrimReplace is not referenced by any visible cell — this looks
# like an accidental IDE auto-import. Confirm nothing else needs it and consider removing it.
from jupyterlab.semver import comparatorTrimReplace

In [60]:
# plot style
# Apply matplotlib's "ggplot" stylesheet first, then select seaborn's "muted" palette.
# NOTE(review): both calls presumably affect the default colour cycle, so the order
# likely matters — confirm before reordering these two lines.
plt.style.use("ggplot")
sns.set_palette("muted")

In [61]:
# Read the training CSV (path is relative to the notebook's working directory)
# into the DataFrame `df` that every later cell works from.
train_csv = "../data/train.csv"
df = pd.read_csv(train_csv)

In [62]:
# Dimensions of the loaded frame as a (rows, columns) tuple.
df.shape

(891, 12)

In [63]:
# Column names of the loaded frame, shown as a plain Python list.
list(df.columns)

['PassengerId',
 'Survived',
 'Pclass',
 'Name',
 'Sex',
 'Age',
 'SibSp',
 'Parch',
 'Ticket',
 'Fare',
 'Cabin',
 'Embarked']

In [64]:
# Data type pandas inferred for each column when reading the CSV.
df.dtypes

PassengerId      int64
Survived         int64
Pclass           int64
Name            object
Sex             object
Age            float64
SibSp            int64
Parch            int64
Ticket          object
Fare           float64
Cabin           object
Embarked        object
dtype: object

In [65]:
# Number of missing values in each column.
df.isna().sum()

PassengerId      0
Survived         0
Pclass           0
Name             0
Sex              0
Age            177
SibSp            0
Parch            0
Ticket           0
Fare             0
Cabin          687
Embarked         2
dtype: int64

## Total survival count

In [66]:
# Count plot of the Survived column; the 0 / 1 tick positions are relabelled "No" / "Yes".
# The figure is written to disk and then closed, so nothing is rendered inline.
fig, ax = plt.subplots(figsize=(12.8, 7.2))
sns.countplot(data=df, x="Survived", ax=ax)
ax.set_title("Survived?", fontsize=24)
ax.set_xticks([0, 1])
ax.set_xticklabels(["No", "Yes"], fontsize=16)
ax.set_xlabel("")
ax.set_ylabel("Count", fontsize=16)
fig.savefig("../outputs/total-survival.png")
plt.close(fig)

![Total Survived Count](../outputs/total-survival.png)

## Survival rate by Ticket class

In [67]:
# Bar plot of Survived against Pclass (seaborn's default aggregation per class).
# The three bar positions are relabelled 1st / 2nd / 3rd and the y axis is fixed to [0, 1].
# The figure is written to disk and then closed, so nothing is rendered inline.
fig, ax = plt.subplots(figsize=(12.8, 7.2))
sns.barplot(data=df, x="Pclass", y="Survived", ax=ax)
ax.set_title("Survived rate by Ticket class", fontsize=24)
ax.set_xlabel("Ticket class", fontsize=16)
ax.set_xticks([0, 1, 2])
ax.set_xticklabels(["1st", "2nd", "3rd"], fontsize=16)
ax.set_ylabel("Survived rate", fontsize=16)
ax.set_ylim(0, 1)
fig.savefig("../outputs/survived-rate-by-ticket-class.png")
plt.close(fig)

![Survived-rate-by-ticket-class](../outputs/survived-rate-by-ticket-class.png)

## Survival rate by gender

In [68]:
# Bar plot of Survived against Sex (seaborn's default aggregation per group),
# with the y axis fixed to [0, 1]. The existing x tick labels are kept and only enlarged.
# The figure is written to disk and then closed, so nothing is rendered inline.
fig, ax = plt.subplots(figsize=(12.8, 7.2))
sns.barplot(data=df, x="Sex", y="Survived", ax=ax)
ax.set_title("Survived rate by gender", fontsize=24)
ax.set_xlabel("Gender", fontsize=16)
ax.set_ylabel("Survived rate", fontsize=16)
ax.set_ylim(0, 1)
for tick_label in ax.get_xticklabels():
    tick_label.set_fontsize(16)
fig.savefig("../outputs/survived-rate-by-gender.png")
plt.close(fig)

![survived-rate-by-gender](../outputs/survived-rate-by-gender.png)

## Survival rate by port of embarkation

In [69]:
# Bar plot of Survived against Embarked (seaborn's default aggregation per port),
# with the port codes S / C / Q relabelled to full port names and the y axis fixed to [0, 1].
# The figure is written to disk and then closed, so nothing is rendered inline.
plt.figure(figsize = (12.8, 7.2))
# Pin the bar order explicitly so the positional tick labels below cannot drift
# if the order in which the codes appear in the data changes.
sns.barplot(x = "Embarked", y = "Survived", data = df, order = ["S", "C", "Q"])
plt.title("Survived rate by Port of Embarkation", fontsize = 24)
plt.xlabel("Port of Embarkation", fontsize = 16)
# Use positional ticks (0, 1, 2) like the other bar plots in this notebook. Passing the
# string codes as tick values only works when a category unit converter is attached to the axis.
plt.xticks(ticks = [0, 1, 2], labels = ["Southampton", "Cherbourg", "Queenstown"], fontsize = 16)
plt.ylabel("Survived rate", fontsize = 16)
plt.ylim(0, 1)
plt.savefig("../outputs/survived-rate-by-port-of-embarkation.png")
plt.close()

![survived-rate-by-port-of-embarkation](../outputs/survived-rate-by-port-of-embarkation.png)

## Distribution of age

In [70]:
# Histogram of the Age column (30 bins) with a KDE curve overlaid.
# The figure is written to disk and then closed, so nothing is rendered inline.
plt.figure(figsize = (12.8, 7.2))
sns.histplot(df["Age"], bins = 30, kde = True)
plt.title("Distribution of Age", fontsize = 24)
plt.xlabel("Age", fontsize = 16)
plt.ylabel("Count", fontsize = 16)
# Save only after every label has been set: savefig writes the figure as it is at the
# time of the call, so the axis-label styling above must come first.
plt.savefig("../outputs/distribution-of-age.png")
plt.close()

![distribution-of-age](../outputs/distribution-of-age.png)

## Correlation Heatmap

In [71]:
# Pairwise correlation of every numeric column in df, drawn as an annotated heatmap
# on a diverging colour map centred at 0. The figure is written to disk and then closed.
numeric_columns = df.select_dtypes(include=[np.number]).columns
correlation_matrix = df.loc[:, numeric_columns].corr()

fig, ax = plt.subplots(figsize=(12.8, 7.2))
sns.heatmap(
    correlation_matrix,
    annot=True,
    cmap='coolwarm',
    center=0,
    ax=ax,
)
ax.set_title("Correlation Heatmap", fontsize=24)
fig.savefig("../outputs/correlation-heatmap.png")
plt.close(fig)

![correlation-heatmap](../outputs/correlation-heatmap.png)